logo_test/training/trainer.py
Rick McEwen 44e8b6ae7d Add CLIP fine-tuning pipeline for logo recognition
Implement contrastive learning with LoRA to fine-tune CLIP's vision
encoder on the LogoDet-3K dataset for improved logo embedding similarity.

New training module (training/):
- config.py: TrainingConfig dataclass with all hyperparameters
- dataset.py: LogoContrastiveDataset with logo-level splits
- model.py: LogoFineTunedCLIP wrapper with LoRA support
- losses.py: InfoNCE, TripletLoss, SupConLoss implementations
- trainer.py: Training loop with mixed precision and checkpointing
- evaluation.py: EmbeddingEvaluator for validation metrics

New scripts:
- train_clip_logo.py: Main training entry point
- export_model.py: Export to HuggingFace-compatible format

Configurations:
- configs/jetson_orin.yaml: Optimized for Jetson Orin AGX
- configs/cloud_rtx4090.yaml: Optimized for 24GB cloud GPUs
- configs/cloud_a100.yaml: Optimized for 80GB cloud GPUs

Documentation:
- CLIP_FINETUNING.md: Training guide and usage instructions
- CLOUD_TRAINING.md: Cloud GPU recommendations and cost estimates

Modified:
- logo_detection_detr.py: Add fine-tuned model loading support
- pyproject.toml: Add peft, pyyaml, torchvision dependencies
2026-01-04 13:45:25 -05:00
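
For orientation, the pieces wire together roughly like this (a minimal sketch, not the real entry point; use train_clip_logo.py for actual runs, and note that the LogoFineTunedCLIP and LogoContrastiveDataset constructor signatures below are assumptions):

    from torch.utils.data import DataLoader
    from training.config import TrainingConfig
    from training.dataset import LogoContrastiveDataset
    from training.model import LogoFineTunedCLIP
    from training.trainer import Trainer

    config = TrainingConfig()              # defaults; see config.py
    model = LogoFineTunedCLIP(config)      # hypothetical signature
    train_loader = DataLoader(
        LogoContrastiveDataset(config, split="train"),  # hypothetical signature
        batch_size=config.batch_size,
        shuffle=True,
    )
    val_loader = DataLoader(
        LogoContrastiveDataset(config, split="val"),
        batch_size=config.batch_size,
    )
    trainer = Trainer(model, train_loader, val_loader, config)
    trainer.train()
    trainer.export_model()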


"""
Training loop with checkpointing, mixed precision, and evaluation.
"""
import json
import logging
import time
from pathlib import Path
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from .config import TrainingConfig
from .losses import get_loss_function
from .evaluation import EmbeddingEvaluator
# Check if amp is available
try:
from torch.cuda.amp import autocast, GradScaler
AMP_AVAILABLE = True
except ImportError:
AMP_AVAILABLE = False
autocast = None
GradScaler = None


class Trainer:
    """
    Trainer for fine-tuning CLIP on logo recognition.

    Features:
    - Mixed precision training (FP16)
    - Gradient accumulation
    - Gradient checkpointing (via model)
    - One-cycle LR schedule with cosine annealing
    - Early stopping
    - Checkpoint saving/loading
    - Evaluation during training
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: TrainingConfig,
        logger: Optional[logging.Logger] = None,
    ):
        """
        Initialize the trainer.

        Args:
            model: LogoFineTunedCLIP model
            train_loader: Training dataloader
            val_loader: Validation dataloader
            config: Training configuration
            logger: Optional logger instance
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.logger = logger or logging.getLogger(__name__)

        # Device setup
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.logger.info(f"Using device: {self.device}")

        # Optimizer - only trainable parameters (the LoRA adapters and any
        # unfrozen layers); frozen CLIP weights are excluded.
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        self.logger.info(
            f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}"
        )
        self.optimizer = AdamW(
            trainable_params,
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )

        # Learning rate scheduler. The scheduler is stepped once per optimizer
        # update rather than once per batch, so the schedule length must be
        # the number of optimizer steps after gradient accumulation.
        steps_per_epoch = max(1, len(train_loader) // config.gradient_accumulation_steps)
        total_steps = steps_per_epoch * config.max_epochs
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config.learning_rate,
            total_steps=total_steps,
            pct_start=config.warmup_steps / total_steps if total_steps > 0 else 0.1,
            anneal_strategy="cos",
        )

        # Mixed precision training
        self.use_amp = (
            config.mixed_precision and AMP_AVAILABLE and self.device.type == "cuda"
        )
        if self.use_amp:
            self.scaler = GradScaler()
            self.logger.info("Mixed precision training enabled")
        else:
            self.scaler = None
            if config.mixed_precision and not AMP_AVAILABLE:
                self.logger.warning("Mixed precision requested but not available")

        # Loss function
        self.criterion = get_loss_function(
            loss_type=config.loss_type,
            temperature=config.temperature,
            triplet_margin=config.triplet_margin,
        )

        # Evaluator
        self.evaluator = EmbeddingEvaluator()

        # Training state
        self.epoch = 0
        self.global_step = 0
        self.best_val_loss = float("inf")
        self.best_val_separation = float("-inf")
        self.patience_counter = 0
        self.training_history = []

    def train(self) -> Dict[str, float]:
        """
        Main training loop.

        Returns:
            Dict with final training metrics
        """
        self.logger.info("Starting training...")
        self.logger.info(f" Epochs: {self.config.max_epochs}")
        self.logger.info(f" Batch size: {self.config.batch_size}")
        self.logger.info(f" Gradient accumulation: {self.config.gradient_accumulation_steps}")
        self.logger.info(f" Effective batch: {self.config.effective_batch_size}")
        self.logger.info(f" Learning rate: {self.config.learning_rate}")

        start_time = time.time()
        for epoch in range(self.epoch, self.config.max_epochs):
            self.epoch = epoch
            self.logger.info(f"\nEpoch {epoch + 1}/{self.config.max_epochs}")

            # Training epoch
            train_metrics = self._train_epoch()
            self.logger.info(
                f"Train - Loss: {train_metrics['loss']:.4f}, "
                f"LR: {train_metrics['lr']:.2e}"
            )

            # Validation
            if (epoch + 1) % self.config.eval_every_n_epochs == 0:
                val_metrics = self._validate()
                self.logger.info(
                    f"Val - Loss: {val_metrics['loss']:.4f}, "
                    f"Pos Sim: {val_metrics['mean_pos_sim']:.3f}, "
                    f"Neg Sim: {val_metrics['mean_neg_sim']:.3f}, "
                    f"Separation: {val_metrics['separation']:.3f}"
                )

                # Record history
                self.training_history.append({
                    "epoch": epoch + 1,
                    "train_loss": train_metrics["loss"],
                    "val_loss": val_metrics["loss"],
                    "val_separation": val_metrics["separation"],
                    "val_pos_sim": val_metrics["mean_pos_sim"],
                    "val_neg_sim": val_metrics["mean_neg_sim"],
                })

                # Checkpointing based on separation (primary) or loss (secondary)
                improved = False
                if val_metrics["separation"] > self.best_val_separation + self.config.min_delta:
                    self.best_val_separation = val_metrics["separation"]
                    improved = True
                elif val_metrics["loss"] < self.best_val_loss - self.config.min_delta:
                    self.best_val_loss = val_metrics["loss"]
                    improved = True

                if improved:
                    self.patience_counter = 0
                    self._save_checkpoint("best.pt")
                    self.logger.info("New best model saved!")
                else:
                    self.patience_counter += 1

                # Early stopping
                if self.patience_counter >= self.config.patience:
                    self.logger.info(
                        f"Early stopping triggered at epoch {epoch + 1} "
                        f"(no improvement for {self.config.patience} epochs)"
                    )
                    break

            # Periodic checkpoint
            if (epoch + 1) % self.config.save_every_n_epochs == 0:
                self._save_checkpoint(f"epoch_{epoch + 1}.pt")

        # Training complete
        total_time = time.time() - start_time
        self.logger.info(f"\nTraining completed in {total_time / 60:.1f} minutes")

        # Load best model
        best_path = Path(self.config.checkpoint_dir) / "best.pt"
        if best_path.exists():
            self.load_checkpoint("best.pt")
            self.logger.info("Loaded best model checkpoint")

        return {
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "total_epochs": self.epoch + 1,
            "total_time_minutes": total_time / 60,
        }

    def _train_epoch(self) -> Dict[str, float]:
        """Run a single training epoch."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        accumulation_steps = 0

        progress_bar = tqdm(
            self.train_loader,
            desc=f"Epoch {self.epoch + 1}",
            leave=False,
        )
        self.optimizer.zero_grad()

        for batch_idx, (images, labels) in enumerate(progress_bar):
            images = images.to(self.device)
            labels = labels.to(self.device)

            # Forward pass with mixed precision. The loss is divided by the
            # accumulation count so gradients average over the effective batch.
            if self.use_amp:
                with autocast():
                    embeddings = self.model(images)
                    loss = self.criterion(embeddings, labels)
                    loss = loss / self.config.gradient_accumulation_steps
                self.scaler.scale(loss).backward()
            else:
                embeddings = self.model(images)
                loss = self.criterion(embeddings, labels)
                loss = loss / self.config.gradient_accumulation_steps
                loss.backward()
            accumulation_steps += 1

            # Optimizer step after accumulation
            if accumulation_steps >= self.config.gradient_accumulation_steps:
                if self.use_amp:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()
                self.optimizer.zero_grad()
                self.scheduler.step()
                self.global_step += 1
                accumulation_steps = 0

            total_loss += loss.item() * self.config.gradient_accumulation_steps
            num_batches += 1

            # Update progress bar
            progress_bar.set_postfix({
                "loss": total_loss / num_batches,
                "lr": self.scheduler.get_last_lr()[0],
            })

            # Logging
            if (batch_idx + 1) % self.config.log_every_n_steps == 0:
                self.logger.debug(
                    f"Step {self.global_step}: loss={total_loss / num_batches:.4f}"
                )

        # Flush a trailing partial accumulation so the last few batches of an
        # epoch still produce an optimizer update (the scheduler is not
        # stepped here, since its length counts only full accumulation cycles).
        if accumulation_steps > 0:
            if self.use_amp:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.optimizer.step()
            self.optimizer.zero_grad()
            self.global_step += 1

        return {
            "loss": total_loss / max(num_batches, 1),
            "lr": self.scheduler.get_last_lr()[0],
        }

    def _validate(self) -> Dict[str, float]:
        """Run validation and compute metrics."""
        self.model.eval()
        total_loss = 0.0
        all_embeddings = []
        all_labels = []

        with torch.no_grad():
            for images, labels in tqdm(self.val_loader, desc="Validating", leave=False):
                images = images.to(self.device)
                labels = labels.to(self.device)

                if self.use_amp:
                    with autocast():
                        embeddings = self.model(images)
                        loss = self.criterion(embeddings, labels)
                else:
                    embeddings = self.model(images)
                    loss = self.criterion(embeddings, labels)

                total_loss += loss.item()
                all_embeddings.append(embeddings.cpu())
                all_labels.append(labels.cpu())

        # Combine batches
        all_embeddings = torch.cat(all_embeddings, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        # Compute embedding quality metrics
        metrics = self.evaluator.compute_metrics(all_embeddings, all_labels)
        metrics["loss"] = total_loss / max(len(self.val_loader), 1)
        return metrics

    def _save_checkpoint(self, filename: str) -> None:
        """Save training checkpoint."""
        checkpoint_dir = Path(self.config.checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        checkpoint = {
            "epoch": self.epoch,
            "global_step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "patience_counter": self.patience_counter,
            "training_history": self.training_history,
            "config": self.config.__dict__,
        }
        if self.scaler is not None:
            checkpoint["scaler_state_dict"] = self.scaler.state_dict()

        torch.save(checkpoint, checkpoint_dir / filename)
        self.logger.debug(f"Saved checkpoint: {filename}")

    def load_checkpoint(self, filename: str) -> None:
        """Load training checkpoint."""
        checkpoint_path = Path(self.config.checkpoint_dir) / filename
        if not checkpoint_path.exists():
            self.logger.warning(f"Checkpoint not found: {checkpoint_path}")
            return

        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        self.epoch = checkpoint["epoch"]
        self.global_step = checkpoint["global_step"]
        self.best_val_loss = checkpoint["best_val_loss"]
        self.best_val_separation = checkpoint.get("best_val_separation", float("-inf"))
        self.patience_counter = checkpoint.get("patience_counter", 0)
        self.training_history = checkpoint.get("training_history", [])
        if self.scaler is not None and "scaler_state_dict" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
        self.logger.info(f"Resumed from epoch {self.epoch + 1}")

    def export_model(self, output_dir: Optional[str] = None) -> str:
        """
        Export the trained model for inference.

        Args:
            output_dir: Output directory (uses config.output_dir if not specified)

        Returns:
            Path to exported model directory
        """
        output_dir = output_dir or self.config.output_dir
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Save model
        self.model.save_pretrained(output_dir)

        # Save training config
        config_path = output_path / "training_config.json"
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f, indent=2)

        # Save training history
        history_path = output_path / "training_history.json"
        with open(history_path, "w") as f:
            json.dump(self.training_history, f, indent=2)

        self.logger.info(f"Model exported to: {output_path}")
        return str(output_path)

    def get_training_summary(self) -> Dict:
        """Get summary of training."""
        return {
            "epochs_completed": self.epoch + 1,
            "global_steps": self.global_step,
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "history": self.training_history,
        }
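

# Each `training_history` entry (also written to training_history.json by
# export_model) has the shape below; the numbers are illustrative only:
#
#     {"epoch": 3, "train_loss": 0.412, "val_loss": 0.388,
#      "val_separation": 0.215, "val_pos_sim": 0.81, "val_neg_sim": 0.59}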