Previously the trainer saved a new "best" model if either separation OR loss improved, with loss checked as a fallback. This caused confusing behavior: a model with lower separation could overwrite a better one. Now only separation (the gap between positive and negative similarity) determines the best model, since that is the key quality metric for contrastive learning.
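In pseudocode, the new selection rule looks like the sketch below (the removed fallback is paraphrased from the description above, not copied from the old code):

    # New behavior: separation alone decides whether best.pt is overwritten.
    if val_metrics["separation"] > best_val_separation + config.min_delta:
        best_val_separation = val_metrics["separation"]
        best_val_loss = val_metrics["loss"]  # kept only for reference
        save_checkpoint("best.pt")

    # Old behavior (roughly): a drop in val loss was also accepted as a
    # fallback, so a checkpoint with worse separation could still win:
    #     if separation_improved or val_metrics["loss"] < best_val_loss:
    #         save_checkpoint("best.pt")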
"""
|
|
Training loop with checkpointing, mixed precision, and evaluation.
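
Typical usage (a sketch; assumes the model, the dataloaders, and a
TrainingConfig have already been constructed elsewhere):

    trainer = Trainer(model, train_loader, val_loader, config)
    results = trainer.train()   # returns best metrics, reloads best.pt
    trainer.export_model()      # writes weights + config/history JSON to config.output_dir

Call trainer.load_checkpoint("best.pt") before train() to resume from a
saved checkpoint.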
"""

import json
import logging
import time
from pathlib import Path
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR
from torch.utils.data import DataLoader
from tqdm import tqdm

from .config import TrainingConfig
from .losses import get_loss_function
from .evaluation import EmbeddingEvaluator

# Check if amp is available
try:
    from torch.cuda.amp import autocast, GradScaler
    AMP_AVAILABLE = True
except ImportError:
    AMP_AVAILABLE = False
    autocast = None
    GradScaler = None


class Trainer:
    """
    Trainer for fine-tuning CLIP on logo recognition.

    Features:
    - Mixed precision training (FP16)
    - Gradient accumulation
    - Gradient checkpointing (via model)
    - Cosine annealing LR scheduler
    - Early stopping
    - Checkpoint saving/loading
    - Evaluation during training
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: TrainingConfig,
        logger: Optional[logging.Logger] = None,
    ):
        """
        Initialize the trainer.

        Args:
            model: LogoFineTunedCLIP model
            train_loader: Training dataloader
            val_loader: Validation dataloader
            config: Training configuration
            logger: Optional logger instance
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.logger = logger or logging.getLogger(__name__)

        # Device setup
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model.to(self.device)
        self.logger.info(f"Using device: {self.device}")

        # Optimizer - only trainable parameters
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        self.logger.info(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")

        self.optimizer = AdamW(
            trainable_params,
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )

        # Learning rate scheduler.
        # The scheduler is stepped once per optimizer step (i.e. once per
        # gradient-accumulation window), so total_steps counts optimizer
        # steps rather than dataloader batches.
        steps_per_epoch = max(len(train_loader) // config.gradient_accumulation_steps, 1)
        total_steps = steps_per_epoch * config.max_epochs
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config.learning_rate,
            total_steps=total_steps,
            pct_start=config.warmup_steps / total_steps if total_steps > 0 else 0.1,
            anneal_strategy="cos",
        )

        # Mixed precision training
        self.use_amp = config.mixed_precision and AMP_AVAILABLE and self.device.type == "cuda"
        if self.use_amp:
            self.scaler = GradScaler()
            self.logger.info("Mixed precision training enabled")
        else:
            self.scaler = None
            if config.mixed_precision and not AMP_AVAILABLE:
                self.logger.warning("Mixed precision requested but not available")

        # Loss function
        self.criterion = get_loss_function(
            loss_type=config.loss_type,
            temperature=config.temperature,
            triplet_margin=config.triplet_margin,
        )

        # Evaluator
        self.evaluator = EmbeddingEvaluator()

        # Training state
        self.epoch = 0
        self.global_step = 0
        self.best_val_loss = float("inf")
        self.best_val_separation = float("-inf")
        self.patience_counter = 0
        self.training_history = []

    def train(self) -> Dict[str, float]:
        """
        Main training loop.

        Returns:
            Dict with final training metrics
        """
        self.logger.info("Starting training...")
        self.logger.info(f" Epochs: {self.config.max_epochs}")
        self.logger.info(f" Batch size: {self.config.batch_size}")
        self.logger.info(f" Gradient accumulation: {self.config.gradient_accumulation_steps}")
        self.logger.info(f" Effective batch: {self.config.effective_batch_size}")
        self.logger.info(f" Learning rate: {self.config.learning_rate}")

        start_time = time.time()

        for epoch in range(self.epoch, self.config.max_epochs):
            self.epoch = epoch
            self.logger.info(f"\nEpoch {epoch + 1}/{self.config.max_epochs}")

            # Training epoch
            train_metrics = self._train_epoch()
            self.logger.info(
                f"Train - Loss: {train_metrics['loss']:.4f}, "
                f"LR: {train_metrics['lr']:.2e}"
            )

            # Validation
            if (epoch + 1) % self.config.eval_every_n_epochs == 0:
                val_metrics = self._validate()
                self.logger.info(
                    f"Val - Loss: {val_metrics['loss']:.4f}, "
                    f"Pos Sim: {val_metrics['mean_pos_sim']:.3f}, "
                    f"Neg Sim: {val_metrics['mean_neg_sim']:.3f}, "
                    f"Separation: {val_metrics['separation']:.3f}"
                )

                # Record history
                self.training_history.append({
                    "epoch": epoch + 1,
                    "train_loss": train_metrics["loss"],
                    "val_loss": val_metrics["loss"],
                    "val_separation": val_metrics["separation"],
                    "val_pos_sim": val_metrics["mean_pos_sim"],
                    "val_neg_sim": val_metrics["mean_neg_sim"],
                })

                # Checkpointing based on separation (gap between pos and neg similarity)
                # This is the key metric for contrastive learning quality
                if val_metrics["separation"] > self.best_val_separation + self.config.min_delta:
                    self.best_val_separation = val_metrics["separation"]
                    self.best_val_loss = val_metrics["loss"]  # Track for reference
                    self.patience_counter = 0
                    self._save_checkpoint("best.pt")
                    self.logger.info("New best model saved!")
                else:
                    self.patience_counter += 1

                # Early stopping
                if self.patience_counter >= self.config.patience:
                    self.logger.info(
                        f"Early stopping triggered at epoch {epoch + 1} "
                        f"(no improvement for {self.config.patience} epochs)"
                    )
                    break

            # Periodic checkpoint
            if (epoch + 1) % self.config.save_every_n_epochs == 0:
                self._save_checkpoint(f"epoch_{epoch + 1}.pt")

        # Training complete
        total_time = time.time() - start_time
        self.logger.info(f"\nTraining completed in {total_time / 60:.1f} minutes")

        # Load best model
        best_path = Path(self.config.checkpoint_dir) / "best.pt"
        if best_path.exists():
            self.load_checkpoint("best.pt")
            self.logger.info("Loaded best model checkpoint")

        return {
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "total_epochs": self.epoch + 1,
            "total_time_minutes": total_time / 60,
        }

    def _train_epoch(self) -> Dict[str, float]:
        """Run a single training epoch."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        accumulation_steps = 0

        progress_bar = tqdm(
            self.train_loader,
            desc=f"Epoch {self.epoch + 1}",
            leave=False,
        )

        self.optimizer.zero_grad()

        for batch_idx, (images, labels) in enumerate(progress_bar):
            images = images.to(self.device)
            labels = labels.to(self.device)
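
            # The loss is divided by gradient_accumulation_steps so the
            # gradients accumulated over the window average (rather than sum)
            # across micro-batches; the division is undone below when logging.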
            # Forward pass with mixed precision
            if self.use_amp:
                with autocast():
                    embeddings = self.model(images)
                    loss = self.criterion(embeddings, labels)
                    loss = loss / self.config.gradient_accumulation_steps

                self.scaler.scale(loss).backward()
            else:
                embeddings = self.model(images)
                loss = self.criterion(embeddings, labels)
                loss = loss / self.config.gradient_accumulation_steps
                loss.backward()

            accumulation_steps += 1

            # Optimizer step after accumulation
            if accumulation_steps >= self.config.gradient_accumulation_steps:
                if self.use_amp:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.optimizer.zero_grad()
                self.scheduler.step()
                self.global_step += 1
                accumulation_steps = 0

            total_loss += loss.item() * self.config.gradient_accumulation_steps
            num_batches += 1

            # Update progress bar
            progress_bar.set_postfix({
                "loss": total_loss / num_batches,
                "lr": self.scheduler.get_last_lr()[0],
            })

            # Logging
            if (batch_idx + 1) % self.config.log_every_n_steps == 0:
                self.logger.debug(
                    f"Step {self.global_step}: loss={total_loss / num_batches:.4f}"
                )

        return {
            "loss": total_loss / max(num_batches, 1),
            "lr": self.scheduler.get_last_lr()[0],
        }

    def _validate(self) -> Dict[str, float]:
        """Run validation and compute metrics."""
        self.model.eval()
        total_loss = 0.0
        all_embeddings = []
        all_labels = []

        with torch.no_grad():
            for images, labels in tqdm(self.val_loader, desc="Validating", leave=False):
                images = images.to(self.device)
                labels = labels.to(self.device)

                if self.use_amp:
                    with autocast():
                        embeddings = self.model(images)
                        loss = self.criterion(embeddings, labels)
                else:
                    embeddings = self.model(images)
                    loss = self.criterion(embeddings, labels)

                total_loss += loss.item()
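                # Collect outputs on the CPU so GPU memory does not grow with
                # the size of the validation set.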
                all_embeddings.append(embeddings.cpu())
                all_labels.append(labels.cpu())

        # Combine batches
        all_embeddings = torch.cat(all_embeddings, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        # Compute embedding quality metrics
        metrics = self.evaluator.compute_metrics(all_embeddings, all_labels)
        metrics["loss"] = total_loss / max(len(self.val_loader), 1)

        return metrics

    def _save_checkpoint(self, filename: str) -> None:
        """Save training checkpoint."""
        checkpoint_dir = Path(self.config.checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        checkpoint = {
            "epoch": self.epoch,
            "global_step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "patience_counter": self.patience_counter,
            "training_history": self.training_history,
            "config": self.config.__dict__,
        }

        if self.scaler is not None:
            checkpoint["scaler_state_dict"] = self.scaler.state_dict()

        torch.save(checkpoint, checkpoint_dir / filename)
        self.logger.debug(f"Saved checkpoint: {filename}")

    def load_checkpoint(self, filename: str) -> None:
        """Load training checkpoint."""
        checkpoint_path = Path(self.config.checkpoint_dir) / filename
        if not checkpoint_path.exists():
            self.logger.warning(f"Checkpoint not found: {checkpoint_path}")
            return

        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        self.epoch = checkpoint["epoch"]
        self.global_step = checkpoint["global_step"]
        self.best_val_loss = checkpoint["best_val_loss"]
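        # .get() with defaults keeps older checkpoints, saved before these
        # fields existed, loadable.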
        self.best_val_separation = checkpoint.get("best_val_separation", float("-inf"))
        self.patience_counter = checkpoint.get("patience_counter", 0)
        self.training_history = checkpoint.get("training_history", [])

        if self.scaler is not None and "scaler_state_dict" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])

        self.logger.info(f"Resumed from epoch {self.epoch + 1}")

    def export_model(self, output_dir: Optional[str] = None) -> str:
        """
        Export the trained model for inference.

        Args:
            output_dir: Output directory (uses config.output_dir if not specified)

        Returns:
            Path to exported model directory
        """
        output_dir = output_dir or self.config.output_dir
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Save model
        self.model.save_pretrained(output_dir)

        # Save training config
        config_path = output_path / "training_config.json"
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f, indent=2)

        # Save training history
        history_path = output_path / "training_history.json"
        with open(history_path, "w") as f:
            json.dump(self.training_history, f, indent=2)

        self.logger.info(f"Model exported to: {output_path}")
        return str(output_path)

    def get_training_summary(self) -> Dict:
        """Get summary of training."""
        return {
            "epochs_completed": self.epoch + 1,
            "global_steps": self.global_step,
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "history": self.training_history,
        }