Add CLIP fine-tuning pipeline for logo recognition
Implement contrastive learning with LoRA to fine-tune CLIP's vision encoder on LogoDet-3K dataset for improved logo embedding similarity. New training module (training/): - config.py: TrainingConfig dataclass with all hyperparameters - dataset.py: LogoContrastiveDataset with logo-level splits - model.py: LogoFineTunedCLIP wrapper with LoRA support - losses.py: InfoNCE, TripletLoss, SupConLoss implementations - trainer.py: Training loop with mixed precision and checkpointing - evaluation.py: EmbeddingEvaluator for validation metrics New scripts: - train_clip_logo.py: Main training entry point - export_model.py: Export to HuggingFace-compatible format Configurations: - configs/jetson_orin.yaml: Optimized for Jetson Orin AGX - configs/cloud_rtx4090.yaml: Optimized for 24GB cloud GPUs - configs/cloud_a100.yaml: Optimized for 80GB cloud GPUs Documentation: - CLIP_FINETUNING.md: Training guide and usage instructions - CLOUD_TRAINING.md: Cloud GPU recommendations and cost estimates Modified: - logo_detection_detr.py: Add fine-tuned model loading support - pyproject.toml: Add peft, pyyaml, torchvision dependencies
This commit is contained in:
405
training/trainer.py
Normal file
405
training/trainer.py
Normal file
@ -0,0 +1,405 @@
|
||||
"""
|
||||
Training loop with checkpointing, mixed precision, and evaluation.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from .config import TrainingConfig
|
||||
from .losses import get_loss_function
|
||||
from .evaluation import EmbeddingEvaluator
|
||||
|
||||
# Check if amp is available.
# NOTE(review): torch.cuda.amp.autocast/GradScaler are the legacy (pre-torch.amp)
# API and emit a FutureWarning on torch >= 2.4 — confirm the pinned torch version
# before migrating, since torch.amp.autocast takes a required device_type argument.
try:
    from torch.cuda.amp import autocast, GradScaler

    AMP_AVAILABLE = True
except ImportError:
    # CPU-only or very old torch builds: flag AMP as unusable and stub the
    # names so module import still succeeds.
    AMP_AVAILABLE = False
    autocast = None
    GradScaler = None
|
||||
|
||||
|
||||
class Trainer:
    """
    Trainer for fine-tuning CLIP on logo recognition.

    Features:
    - Mixed precision training (FP16)
    - Gradient accumulation
    - Gradient checkpointing (via model)
    - OneCycle cosine LR schedule with warmup
    - Early stopping
    - Checkpoint saving/loading
    - Evaluation during training
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: TrainingConfig,
        logger: Optional[logging.Logger] = None,
    ):
        """
        Initialize the trainer.

        Args:
            model: LogoFineTunedCLIP model
            train_loader: Training dataloader
            val_loader: Validation dataloader
            config: Training configuration
            logger: Optional logger instance
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.logger = logger or logging.getLogger(__name__)

        # Device setup
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model.to(self.device)
        self.logger.info(f"Using device: {self.device}")

        # Optimizer - only trainable parameters (e.g. the LoRA adapters;
        # frozen CLIP weights have requires_grad=False and are excluded).
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        self.logger.info(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")

        self.optimizer = AdamW(
            trainable_params,
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )

        # Learning rate scheduler.
        # OneCycleLR is stepped once per *optimizer* step (once per
        # accumulation cycle in _train_epoch), so total_steps must be
        # counted in optimizer steps, not batches.  Using the raw batch
        # count (len(train_loader) * max_epochs) would leave the schedule
        # incomplete whenever gradient_accumulation_steps > 1.
        steps_per_epoch = max(
            1, len(train_loader) // config.gradient_accumulation_steps
        )
        total_steps = steps_per_epoch * config.max_epochs
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config.learning_rate,
            total_steps=total_steps,
            # warmup_steps is assumed to be in optimizer steps — TODO confirm
            # against TrainingConfig docs.
            pct_start=config.warmup_steps / total_steps if total_steps > 0 else 0.1,
            anneal_strategy="cos",
        )

        # Mixed precision training (only meaningful on CUDA)
        self.use_amp = config.mixed_precision and AMP_AVAILABLE and self.device.type == "cuda"
        if self.use_amp:
            self.scaler = GradScaler()
            self.logger.info("Mixed precision training enabled")
        else:
            self.scaler = None
            if config.mixed_precision and not AMP_AVAILABLE:
                self.logger.warning("Mixed precision requested but not available")

        # Loss function (InfoNCE / triplet / SupCon, chosen by config)
        self.criterion = get_loss_function(
            loss_type=config.loss_type,
            temperature=config.temperature,
            triplet_margin=config.triplet_margin,
        )

        # Evaluator for embedding-quality metrics on the validation set
        self.evaluator = EmbeddingEvaluator()

        # Training state
        self.epoch = 0
        self.global_step = 0
        self.best_val_loss = float("inf")
        self.best_val_separation = float("-inf")
        self.patience_counter = 0
        self.training_history = []

    def train(self) -> Dict[str, float]:
        """
        Main training loop.

        Returns:
            Dict with final training metrics
        """
        self.logger.info("Starting training...")
        self.logger.info(f" Epochs: {self.config.max_epochs}")
        self.logger.info(f" Batch size: {self.config.batch_size}")
        self.logger.info(f" Gradient accumulation: {self.config.gradient_accumulation_steps}")
        self.logger.info(f" Effective batch: {self.config.effective_batch_size}")
        self.logger.info(f" Learning rate: {self.config.learning_rate}")

        start_time = time.time()

        for epoch in range(self.epoch, self.config.max_epochs):
            self.epoch = epoch
            self.logger.info(f"\nEpoch {epoch + 1}/{self.config.max_epochs}")

            # Training epoch
            train_metrics = self._train_epoch()
            self.logger.info(
                f"Train - Loss: {train_metrics['loss']:.4f}, "
                f"LR: {train_metrics['lr']:.2e}"
            )

            # Validation.  All val-dependent bookkeeping (history,
            # best-checkpointing, early stopping) must live inside this
            # branch: when eval_every_n_epochs > 1 there are no fresh val
            # metrics in the intervening epochs, and referencing
            # val_metrics outside would raise NameError on the first
            # skipped epoch (or silently reuse stale metrics later).
            if (epoch + 1) % self.config.eval_every_n_epochs == 0:
                val_metrics = self._validate()
                self.logger.info(
                    f"Val - Loss: {val_metrics['loss']:.4f}, "
                    f"Pos Sim: {val_metrics['mean_pos_sim']:.3f}, "
                    f"Neg Sim: {val_metrics['mean_neg_sim']:.3f}, "
                    f"Separation: {val_metrics['separation']:.3f}"
                )

                # Record history
                self.training_history.append({
                    "epoch": epoch + 1,
                    "train_loss": train_metrics["loss"],
                    "val_loss": val_metrics["loss"],
                    "val_separation": val_metrics["separation"],
                    "val_pos_sim": val_metrics["mean_pos_sim"],
                    "val_neg_sim": val_metrics["mean_neg_sim"],
                })

                # Checkpointing based on separation (primary) or loss (secondary)
                improved = False
                if val_metrics["separation"] > self.best_val_separation + self.config.min_delta:
                    self.best_val_separation = val_metrics["separation"]
                    improved = True
                elif val_metrics["loss"] < self.best_val_loss - self.config.min_delta:
                    self.best_val_loss = val_metrics["loss"]
                    improved = True

                if improved:
                    self.patience_counter = 0
                    self._save_checkpoint("best.pt")
                    self.logger.info("New best model saved!")
                else:
                    self.patience_counter += 1

                # Early stopping
                if self.patience_counter >= self.config.patience:
                    self.logger.info(
                        f"Early stopping triggered at epoch {epoch + 1} "
                        f"(no improvement for {self.config.patience} epochs)"
                    )
                    break

            # Periodic checkpoint
            if (epoch + 1) % self.config.save_every_n_epochs == 0:
                self._save_checkpoint(f"epoch_{epoch + 1}.pt")

        # Training complete
        total_time = time.time() - start_time
        self.logger.info(f"\nTraining completed in {total_time / 60:.1f} minutes")

        # Capture the epoch count *before* reloading the best checkpoint:
        # load_checkpoint() overwrites self.epoch with the best epoch,
        # which would corrupt the "total_epochs" we report.
        completed_epochs = self.epoch + 1

        # Load best model
        best_path = Path(self.config.checkpoint_dir) / "best.pt"
        if best_path.exists():
            self.load_checkpoint("best.pt")
            self.logger.info("Loaded best model checkpoint")

        return {
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "total_epochs": completed_epochs,
            "total_time_minutes": total_time / 60,
        }

    def _train_epoch(self) -> Dict[str, float]:
        """Run a single training epoch and return mean loss and last LR."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        accumulation_steps = 0

        progress_bar = tqdm(
            self.train_loader,
            desc=f"Epoch {self.epoch + 1}",
            leave=False,
        )

        self.optimizer.zero_grad()

        for batch_idx, (images, labels) in enumerate(progress_bar):
            images = images.to(self.device)
            labels = labels.to(self.device)

            # Forward pass with mixed precision.  The loss is scaled down
            # by the accumulation factor so accumulated gradients match a
            # single large-batch step.
            if self.use_amp:
                with autocast():
                    embeddings = self.model(images)
                    loss = self.criterion(embeddings, labels)
                    loss = loss / self.config.gradient_accumulation_steps

                self.scaler.scale(loss).backward()
            else:
                embeddings = self.model(images)
                loss = self.criterion(embeddings, labels)
                loss = loss / self.config.gradient_accumulation_steps
                loss.backward()

            accumulation_steps += 1

            # Optimizer step after accumulation.  NOTE: a partial
            # accumulation cycle at the end of the epoch is intentionally
            # discarded by the zero_grad() at the top of the next epoch.
            if accumulation_steps >= self.config.gradient_accumulation_steps:
                if self.use_amp:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.optimizer.zero_grad()
                self.scheduler.step()
                self.global_step += 1
                accumulation_steps = 0

            # Undo the accumulation scaling so the logged loss is per-batch.
            total_loss += loss.item() * self.config.gradient_accumulation_steps
            num_batches += 1

            # Update progress bar
            progress_bar.set_postfix({
                "loss": total_loss / num_batches,
                "lr": self.scheduler.get_last_lr()[0],
            })

            # Logging
            if (batch_idx + 1) % self.config.log_every_n_steps == 0:
                self.logger.debug(
                    f"Step {self.global_step}: loss={total_loss / num_batches:.4f}"
                )

        return {
            "loss": total_loss / max(num_batches, 1),
            "lr": self.scheduler.get_last_lr()[0],
        }

    def _validate(self) -> Dict[str, float]:
        """Run validation and compute loss plus embedding-quality metrics."""
        self.model.eval()
        total_loss = 0.0
        all_embeddings = []
        all_labels = []

        with torch.no_grad():
            for images, labels in tqdm(self.val_loader, desc="Validating", leave=False):
                images = images.to(self.device)
                labels = labels.to(self.device)

                if self.use_amp:
                    with autocast():
                        embeddings = self.model(images)
                        loss = self.criterion(embeddings, labels)
                else:
                    embeddings = self.model(images)
                    loss = self.criterion(embeddings, labels)

                total_loss += loss.item()
                all_embeddings.append(embeddings.cpu())
                all_labels.append(labels.cpu())

        # Combine batches
        all_embeddings = torch.cat(all_embeddings, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        # Compute embedding quality metrics (pos/neg similarity, separation)
        metrics = self.evaluator.compute_metrics(all_embeddings, all_labels)
        metrics["loss"] = total_loss / max(len(self.val_loader), 1)

        return metrics

    def _save_checkpoint(self, filename: str) -> None:
        """Save a full training checkpoint (model, optimizer, scheduler, state)."""
        checkpoint_dir = Path(self.config.checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        checkpoint = {
            "epoch": self.epoch,
            "global_step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "patience_counter": self.patience_counter,
            "training_history": self.training_history,
            "config": self.config.__dict__,
        }

        if self.scaler is not None:
            checkpoint["scaler_state_dict"] = self.scaler.state_dict()

        # Log the real destination path (the original message was a broken
        # placeholder that never interpolated the filename).
        checkpoint_path = checkpoint_dir / filename
        torch.save(checkpoint, checkpoint_path)
        self.logger.debug(f"Saved checkpoint: {checkpoint_path}")

    def load_checkpoint(self, filename: str) -> None:
        """Load a training checkpoint; a missing file logs a warning and returns."""
        checkpoint_path = Path(self.config.checkpoint_dir) / filename
        if not checkpoint_path.exists():
            self.logger.warning(f"Checkpoint not found: {checkpoint_path}")
            return

        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        self.epoch = checkpoint["epoch"]
        self.global_step = checkpoint["global_step"]
        self.best_val_loss = checkpoint["best_val_loss"]
        # .get() defaults keep older checkpoints (without these keys) loadable.
        self.best_val_separation = checkpoint.get("best_val_separation", float("-inf"))
        self.patience_counter = checkpoint.get("patience_counter", 0)
        self.training_history = checkpoint.get("training_history", [])

        if self.scaler is not None and "scaler_state_dict" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])

        self.logger.info(f"Resumed from epoch {self.epoch + 1}")

    def export_model(self, output_dir: Optional[str] = None) -> str:
        """
        Export the trained model for inference.

        Args:
            output_dir: Output directory (uses config.output_dir if not specified)

        Returns:
            Path to exported model directory
        """
        output_dir = output_dir or self.config.output_dir
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Save model (HuggingFace-compatible layout via the model wrapper)
        self.model.save_pretrained(output_dir)

        # Save training config
        config_path = output_path / "training_config.json"
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f, indent=2)

        # Save training history
        history_path = output_path / "training_history.json"
        with open(history_path, "w") as f:
            json.dump(self.training_history, f, indent=2)

        self.logger.info(f"Model exported to: {output_path}")
        return str(output_path)

    def get_training_summary(self) -> Dict:
        """Get summary of training."""
        return {
            "epochs_completed": self.epoch + 1,
            "global_steps": self.global_step,
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "history": self.training_history,
        }
|
||||
Reference in New Issue
Block a user