""" Training loop with checkpointing, mixed precision, and evaluation. """ import json import logging import time from pathlib import Path from typing import Dict, Optional, Tuple import torch import torch.nn as nn from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR from torch.utils.data import DataLoader from tqdm import tqdm from .config import TrainingConfig from .losses import get_loss_function from .evaluation import EmbeddingEvaluator # Check if amp is available try: from torch.cuda.amp import autocast, GradScaler AMP_AVAILABLE = True except ImportError: AMP_AVAILABLE = False autocast = None GradScaler = None class Trainer: """ Trainer for fine-tuning CLIP on logo recognition. Features: - Mixed precision training (FP16) - Gradient accumulation - Gradient checkpointing (via model) - Cosine annealing LR scheduler - Early stopping - Checkpoint saving/loading - Evaluation during training """ def __init__( self, model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, config: TrainingConfig, logger: Optional[logging.Logger] = None, ): """ Initialize the trainer. Args: model: LogoFineTunedCLIP model train_loader: Training dataloader val_loader: Validation dataloader config: Training configuration logger: Optional logger instance """ self.model = model self.train_loader = train_loader self.val_loader = val_loader self.config = config self.logger = logger or logging.getLogger(__name__) # Device setup self.device = torch.device( "cuda" if torch.cuda.is_available() else "cpu" ) self.model.to(self.device) self.logger.info(f"Using device: {self.device}") # Optimizer - only trainable parameters trainable_params = [p for p in model.parameters() if p.requires_grad] self.logger.info(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}") self.optimizer = AdamW( trainable_params, lr=config.learning_rate, weight_decay=config.weight_decay, ) # Learning rate scheduler total_steps = len(train_loader) * config.max_epochs self.scheduler = OneCycleLR( self.optimizer, max_lr=config.learning_rate, total_steps=total_steps, pct_start=config.warmup_steps / total_steps if total_steps > 0 else 0.1, anneal_strategy="cos", ) # Mixed precision training self.use_amp = config.mixed_precision and AMP_AVAILABLE and self.device.type == "cuda" if self.use_amp: self.scaler = GradScaler() self.logger.info("Mixed precision training enabled") else: self.scaler = None if config.mixed_precision and not AMP_AVAILABLE: self.logger.warning("Mixed precision requested but not available") # Loss function self.criterion = get_loss_function( loss_type=config.loss_type, temperature=config.temperature, triplet_margin=config.triplet_margin, ) # Evaluator self.evaluator = EmbeddingEvaluator() # Training state self.epoch = 0 self.global_step = 0 self.best_val_loss = float("inf") self.best_val_separation = float("-inf") self.patience_counter = 0 self.training_history = [] def train(self) -> Dict[str, float]: """ Main training loop. 
    def train(self) -> Dict[str, float]:
        """
        Main training loop.

        Returns:
            Dict with final training metrics
        """
        self.logger.info("Starting training...")
        self.logger.info(f"  Epochs: {self.config.max_epochs}")
        self.logger.info(f"  Batch size: {self.config.batch_size}")
        self.logger.info(
            f"  Gradient accumulation: {self.config.gradient_accumulation_steps}"
        )
        self.logger.info(f"  Effective batch: {self.config.effective_batch_size}")
        self.logger.info(f"  Learning rate: {self.config.learning_rate}")

        start_time = time.time()

        for epoch in range(self.epoch, self.config.max_epochs):
            self.epoch = epoch
            self.logger.info(f"\nEpoch {epoch + 1}/{self.config.max_epochs}")

            # Training epoch
            train_metrics = self._train_epoch()
            self.logger.info(
                f"Train - Loss: {train_metrics['loss']:.4f}, "
                f"LR: {train_metrics['lr']:.2e}"
            )

            # Validation
            if (epoch + 1) % self.config.eval_every_n_epochs == 0:
                val_metrics = self._validate()
                self.logger.info(
                    f"Val - Loss: {val_metrics['loss']:.4f}, "
                    f"Pos Sim: {val_metrics['mean_pos_sim']:.3f}, "
                    f"Neg Sim: {val_metrics['mean_neg_sim']:.3f}, "
                    f"Separation: {val_metrics['separation']:.3f}"
                )

                # Record history
                self.training_history.append({
                    "epoch": epoch + 1,
                    "train_loss": train_metrics["loss"],
                    "val_loss": val_metrics["loss"],
                    "val_separation": val_metrics["separation"],
                    "val_pos_sim": val_metrics["mean_pos_sim"],
                    "val_neg_sim": val_metrics["mean_neg_sim"],
                })

                # Checkpointing based on separation (the gap between positive
                # and negative similarity) - the key metric for contrastive
                # learning quality
                if val_metrics["separation"] > self.best_val_separation + self.config.min_delta:
                    self.best_val_separation = val_metrics["separation"]
                    self.best_val_loss = val_metrics["loss"]  # Track for reference
                    self.patience_counter = 0
                    self._save_checkpoint("best.pt")
                    self.logger.info("New best model saved!")
                else:
                    self.patience_counter += 1

                # Early stopping
                if self.patience_counter >= self.config.patience:
                    self.logger.info(
                        f"Early stopping triggered at epoch {epoch + 1} "
                        f"(no improvement for {self.config.patience} evaluations)"
                    )
                    break

            # Periodic checkpoint
            if (epoch + 1) % self.config.save_every_n_epochs == 0:
                self._save_checkpoint(f"epoch_{epoch + 1}.pt")

        # Training complete
        total_time = time.time() - start_time
        self.logger.info(f"\nTraining completed in {total_time / 60:.1f} minutes")

        # Load best model
        best_path = Path(self.config.checkpoint_dir) / "best.pt"
        if best_path.exists():
            self.load_checkpoint("best.pt")
            self.logger.info("Loaded best model checkpoint")

        return {
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "total_epochs": self.epoch + 1,
            "total_time_minutes": total_time / 60,
        }
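    # Resuming an interrupted run is a matter of loading a checkpoint before
    # calling train(); a minimal sketch (the "epoch_5.pt" filename is
    # illustrative):
    #
    #   trainer = Trainer(model, train_loader, val_loader, config)
    #   trainer.load_checkpoint("epoch_5.pt")  # restores epoch/step/optimizer
    #   trainer.train()                        # continues from self.epoch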
    def _train_epoch(self) -> Dict[str, float]:
        """Run a single training epoch."""
        self.model.train()

        total_loss = 0.0
        num_batches = 0
        accumulation_steps = 0

        progress_bar = tqdm(
            self.train_loader,
            desc=f"Epoch {self.epoch + 1}",
            leave=False,
        )

        self.optimizer.zero_grad()

        for batch_idx, (images, labels) in enumerate(progress_bar):
            images = images.to(self.device)
            labels = labels.to(self.device)

            # Forward pass with mixed precision
            if self.use_amp:
                with autocast():
                    embeddings = self.model(images)
                    loss = self.criterion(embeddings, labels)
                loss = loss / self.config.gradient_accumulation_steps
                self.scaler.scale(loss).backward()
            else:
                embeddings = self.model(images)
                loss = self.criterion(embeddings, labels)
                loss = loss / self.config.gradient_accumulation_steps
                loss.backward()

            accumulation_steps += 1

            # Optimizer step after accumulation
            if accumulation_steps >= self.config.gradient_accumulation_steps:
                if self.use_amp:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.optimizer.zero_grad()
                self.scheduler.step()
                self.global_step += 1
                accumulation_steps = 0

            total_loss += loss.item() * self.config.gradient_accumulation_steps
            num_batches += 1

            # Update progress bar
            progress_bar.set_postfix({
                "loss": total_loss / num_batches,
                "lr": self.scheduler.get_last_lr()[0],
            })

            # Logging
            if (batch_idx + 1) % self.config.log_every_n_steps == 0:
                self.logger.debug(
                    f"Step {self.global_step}: loss={total_loss / num_batches:.4f}"
                )

        # Gradients from a final partial accumulation window (when the batch
        # count is not divisible by gradient_accumulation_steps) are dropped
        # by the zero_grad() at the start of the next epoch.
        return {
            "loss": total_loss / max(num_batches, 1),
            "lr": self.scheduler.get_last_lr()[0],
        }

    def _validate(self) -> Dict[str, float]:
        """Run validation and compute metrics."""
        self.model.eval()

        total_loss = 0.0
        all_embeddings = []
        all_labels = []

        with torch.no_grad():
            for images, labels in tqdm(self.val_loader, desc="Validating", leave=False):
                images = images.to(self.device)
                labels = labels.to(self.device)

                if self.use_amp:
                    with autocast():
                        embeddings = self.model(images)
                        loss = self.criterion(embeddings, labels)
                else:
                    embeddings = self.model(images)
                    loss = self.criterion(embeddings, labels)

                total_loss += loss.item()
                all_embeddings.append(embeddings.cpu())
                all_labels.append(labels.cpu())

        # Combine batches
        all_embeddings = torch.cat(all_embeddings, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        # Compute embedding quality metrics
        metrics = self.evaluator.compute_metrics(all_embeddings, all_labels)
        metrics["loss"] = total_loss / max(len(self.val_loader), 1)

        return metrics

    def _save_checkpoint(self, filename: str) -> None:
        """Save training checkpoint."""
        checkpoint_dir = Path(self.config.checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        checkpoint = {
            "epoch": self.epoch,
            "global_step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "patience_counter": self.patience_counter,
            "training_history": self.training_history,
            "config": self.config.__dict__,
        }
        if self.scaler is not None:
            checkpoint["scaler_state_dict"] = self.scaler.state_dict()

        torch.save(checkpoint, checkpoint_dir / filename)
        self.logger.debug(f"Saved checkpoint: {filename}")

    def load_checkpoint(self, filename: str) -> None:
        """Load training checkpoint."""
        checkpoint_path = Path(self.config.checkpoint_dir) / filename
        if not checkpoint_path.exists():
            self.logger.warning(f"Checkpoint not found: {checkpoint_path}")
            return

        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

        self.epoch = checkpoint["epoch"]
        self.global_step = checkpoint["global_step"]
        self.best_val_loss = checkpoint["best_val_loss"]
        self.best_val_separation = checkpoint.get("best_val_separation", float("-inf"))
        self.patience_counter = checkpoint.get("patience_counter", 0)
        self.training_history = checkpoint.get("training_history", [])

        if self.scaler is not None and "scaler_state_dict" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])

        self.logger.info(f"Resumed from epoch {self.epoch + 1}")
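    # Checkpoints are plain dicts, so they can be inspected without
    # constructing a Trainer; a minimal sketch (the path is illustrative):
    #
    #   ckpt = torch.load("checkpoints/best.pt", map_location="cpu")
    #   print(ckpt["epoch"], ckpt["best_val_separation"])
    #   print(ckpt["config"])  # snapshot of TrainingConfig.__dict__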
    def export_model(self, output_dir: Optional[str] = None) -> str:
        """
        Export the trained model for inference.

        Args:
            output_dir: Output directory (uses config.output_dir if not specified)

        Returns:
            Path to exported model directory
        """
        output_dir = output_dir or self.config.output_dir
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Save model
        self.model.save_pretrained(output_dir)

        # Save training config
        config_path = output_path / "training_config.json"
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f, indent=2)

        # Save training history
        history_path = output_path / "training_history.json"
        with open(history_path, "w") as f:
            json.dump(self.training_history, f, indent=2)

        self.logger.info(f"Model exported to: {output_path}")
        return str(output_path)

    def get_training_summary(self) -> Dict:
        """Get summary of training."""
        return {
            "epochs_completed": self.epoch + 1,
            "global_steps": self.global_step,
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "history": self.training_history,
        }
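# ---------------------------------------------------------------------------
# End-to-end usage sketch. The model class and dataloader construction are
# assumptions about the surrounding package, not definitions made here:
#
#   from .config import TrainingConfig
#   from .model import LogoFineTunedCLIP  # hypothetical module path
#
#   config = TrainingConfig()
#   model = LogoFineTunedCLIP(config.model_name)  # hypothetical signature
#   train_loader, val_loader = ...  # built by the project's data pipeline
#
#   trainer = Trainer(model, train_loader, val_loader, config)
#   trainer.train()
#   trainer.export_model()  # writes model weights + JSON config/history
#   print(trainer.get_training_summary())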