Implement contrastive learning with LoRA to fine-tune CLIP's vision encoder on the LogoDet-3K dataset for improved logo embedding similarity.

New training module (training/):
- config.py: TrainingConfig dataclass with all hyperparameters
- dataset.py: LogoContrastiveDataset with logo-level splits
- model.py: LogoFineTunedCLIP wrapper with LoRA support
- losses.py: InfoNCE, TripletLoss, SupConLoss implementations
- trainer.py: Training loop with mixed precision and checkpointing
- evaluation.py: EmbeddingEvaluator for validation metrics

New scripts:
- train_clip_logo.py: Main training entry point
- export_model.py: Export to HuggingFace-compatible format

Configurations:
- configs/jetson_orin.yaml: Optimized for Jetson Orin AGX
- configs/cloud_rtx4090.yaml: Optimized for 24GB cloud GPUs
- configs/cloud_a100.yaml: Optimized for 80GB cloud GPUs

Documentation:
- CLIP_FINETUNING.md: Training guide and usage instructions
- CLOUD_TRAINING.md: Cloud GPU recommendations and cost estimates

Modified:
- logo_detection_detr.py: Add fine-tuned model loading support
- pyproject.toml: Add peft, pyyaml, torchvision dependencies
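For reviewers, a minimal sketch of how the new pieces are meant to fit together. The Trainer calls mirror the trainer.py code included below; the constructor arguments for TrainingConfig, LogoContrastiveDataset, and LogoFineTunedCLIP (for example the data_dir field and the split keyword) are illustrative assumptions, since those files are not shown in this excerpt.

    from torch.utils.data import DataLoader

    from training.config import TrainingConfig
    from training.dataset import LogoContrastiveDataset
    from training.model import LogoFineTunedCLIP
    from training.trainer import Trainer

    # Hyperparameters live in the TrainingConfig dataclass; the YAML files under
    # configs/ populate the same fields. Field and argument names below are
    # placeholders, not the final API.
    config = TrainingConfig()

    # Logo-level train/val splits (argument names illustrative)
    train_ds = LogoContrastiveDataset(config.data_dir, split="train")
    val_ds = LogoContrastiveDataset(config.data_dir, split="val")
    train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=config.batch_size)

    model = LogoFineTunedCLIP(config)  # CLIP vision encoder with LoRA adapters

    trainer = Trainer(model, train_loader, val_loader, config)
    results = trainer.train()            # returns best_val_loss, best_val_separation, ...
    export_dir = trainer.export_model()  # writes weights, config, and history to config.output_dir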
training/trainer.py (406 lines, 14 KiB, Python)
"""
|
|
Training loop with checkpointing, mixed precision, and evaluation.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Dict, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.optim import AdamW
|
|
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR
|
|
from torch.utils.data import DataLoader
|
|
from tqdm import tqdm
|
|
|
|
from .config import TrainingConfig
|
|
from .losses import get_loss_function
|
|
from .evaluation import EmbeddingEvaluator
|
|
|
|
# Check if amp is available
|
|
try:
|
|
from torch.cuda.amp import autocast, GradScaler
|
|
AMP_AVAILABLE = True
|
|
except ImportError:
|
|
AMP_AVAILABLE = False
|
|
autocast = None
|
|
GradScaler = None
|
|
|
|
|
|
class Trainer:
    """
    Trainer for fine-tuning CLIP on logo recognition.

    Features:
    - Mixed precision training (FP16)
    - Gradient accumulation
    - Gradient checkpointing (via model)
    - One-cycle LR schedule with cosine annealing
    - Early stopping
    - Checkpoint saving/loading
    - Evaluation during training
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: TrainingConfig,
        logger: Optional[logging.Logger] = None,
    ):
        """
        Initialize the trainer.

        Args:
            model: LogoFineTunedCLIP model
            train_loader: Training dataloader
            val_loader: Validation dataloader
            config: Training configuration
            logger: Optional logger instance
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.logger = logger or logging.getLogger(__name__)

        # Device setup
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.logger.info(f"Using device: {self.device}")

        # Optimizer - only trainable parameters
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        self.logger.info(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")

        self.optimizer = AdamW(
            trainable_params,
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )

        # Learning rate scheduler. The scheduler is stepped once per optimizer
        # update, so the schedule length is measured in optimizer steps
        # (batches divided by the accumulation factor), not dataloader batches.
        steps_per_epoch = max(
            len(train_loader) // max(config.gradient_accumulation_steps, 1), 1
        )
        total_steps = steps_per_epoch * config.max_epochs
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config.learning_rate,
            total_steps=total_steps,
            pct_start=config.warmup_steps / total_steps if total_steps > 0 else 0.1,
            anneal_strategy="cos",
        )

        # Mixed precision training
        self.use_amp = config.mixed_precision and AMP_AVAILABLE and self.device.type == "cuda"
        if self.use_amp:
            self.scaler = GradScaler()
            self.logger.info("Mixed precision training enabled")
        else:
            self.scaler = None
            if config.mixed_precision and not AMP_AVAILABLE:
                self.logger.warning("Mixed precision requested but not available")

        # Loss function
        self.criterion = get_loss_function(
            loss_type=config.loss_type,
            temperature=config.temperature,
            triplet_margin=config.triplet_margin,
        )

        # Evaluator
        self.evaluator = EmbeddingEvaluator()

        # Training state
        self.epoch = 0
        self.global_step = 0
        self.best_val_loss = float("inf")
        self.best_val_separation = float("-inf")
        self.patience_counter = 0
        self.training_history = []

    def train(self) -> Dict[str, float]:
        """
        Main training loop.

        Returns:
            Dict with final training metrics
        """
        self.logger.info("Starting training...")
        self.logger.info(f" Epochs: {self.config.max_epochs}")
        self.logger.info(f" Batch size: {self.config.batch_size}")
        self.logger.info(f" Gradient accumulation: {self.config.gradient_accumulation_steps}")
        self.logger.info(f" Effective batch: {self.config.effective_batch_size}")
        self.logger.info(f" Learning rate: {self.config.learning_rate}")

        start_time = time.time()

        for epoch in range(self.epoch, self.config.max_epochs):
            self.epoch = epoch
            self.logger.info(f"\nEpoch {epoch + 1}/{self.config.max_epochs}")

            # Training epoch
            train_metrics = self._train_epoch()
            self.logger.info(
                f"Train - Loss: {train_metrics['loss']:.4f}, "
                f"LR: {train_metrics['lr']:.2e}"
            )

            # Validation
            if (epoch + 1) % self.config.eval_every_n_epochs == 0:
                val_metrics = self._validate()
                self.logger.info(
                    f"Val - Loss: {val_metrics['loss']:.4f}, "
                    f"Pos Sim: {val_metrics['mean_pos_sim']:.3f}, "
                    f"Neg Sim: {val_metrics['mean_neg_sim']:.3f}, "
                    f"Separation: {val_metrics['separation']:.3f}"
                )

                # Record history
                self.training_history.append({
                    "epoch": epoch + 1,
                    "train_loss": train_metrics["loss"],
                    "val_loss": val_metrics["loss"],
                    "val_separation": val_metrics["separation"],
                    "val_pos_sim": val_metrics["mean_pos_sim"],
                    "val_neg_sim": val_metrics["mean_neg_sim"],
                })

                # Checkpointing based on separation (primary) or loss (secondary)
                improved = False
                if val_metrics["separation"] > self.best_val_separation + self.config.min_delta:
                    self.best_val_separation = val_metrics["separation"]
                    improved = True
                elif val_metrics["loss"] < self.best_val_loss - self.config.min_delta:
                    self.best_val_loss = val_metrics["loss"]
                    improved = True

                if improved:
                    self.patience_counter = 0
                    self._save_checkpoint("best.pt")
                    self.logger.info("New best model saved!")
                else:
                    self.patience_counter += 1

                # Early stopping
                if self.patience_counter >= self.config.patience:
                    self.logger.info(
                        f"Early stopping triggered at epoch {epoch + 1} "
                        f"(no improvement for {self.config.patience} epochs)"
                    )
                    break

            # Periodic checkpoint
            if (epoch + 1) % self.config.save_every_n_epochs == 0:
                self._save_checkpoint(f"epoch_{epoch + 1}.pt")

        # Training complete
        total_time = time.time() - start_time
        self.logger.info(f"\nTraining completed in {total_time / 60:.1f} minutes")

        # Load best model
        best_path = Path(self.config.checkpoint_dir) / "best.pt"
        if best_path.exists():
            self.load_checkpoint("best.pt")
            self.logger.info("Loaded best model checkpoint")

        return {
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "total_epochs": self.epoch + 1,
            "total_time_minutes": total_time / 60,
        }

    def _train_epoch(self) -> Dict[str, float]:
        """Run a single training epoch."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        accumulation_steps = 0

        progress_bar = tqdm(
            self.train_loader,
            desc=f"Epoch {self.epoch + 1}",
            leave=False,
        )

        self.optimizer.zero_grad()

        for batch_idx, (images, labels) in enumerate(progress_bar):
            images = images.to(self.device)
            labels = labels.to(self.device)

            # Forward pass with mixed precision
            if self.use_amp:
                with autocast():
                    embeddings = self.model(images)
                    loss = self.criterion(embeddings, labels)
                    loss = loss / self.config.gradient_accumulation_steps

                self.scaler.scale(loss).backward()
            else:
                embeddings = self.model(images)
                loss = self.criterion(embeddings, labels)
                loss = loss / self.config.gradient_accumulation_steps
                loss.backward()

            accumulation_steps += 1

            # Optimizer step after accumulation
            if accumulation_steps >= self.config.gradient_accumulation_steps:
                if self.use_amp:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.optimizer.zero_grad()
                self.scheduler.step()
                self.global_step += 1
                accumulation_steps = 0

            total_loss += loss.item() * self.config.gradient_accumulation_steps
            num_batches += 1

            # Update progress bar
            progress_bar.set_postfix({
                "loss": total_loss / num_batches,
                "lr": self.scheduler.get_last_lr()[0],
            })

            # Logging
            if (batch_idx + 1) % self.config.log_every_n_steps == 0:
                self.logger.debug(
                    f"Step {self.global_step}: loss={total_loss / num_batches:.4f}"
                )

        return {
            "loss": total_loss / max(num_batches, 1),
            "lr": self.scheduler.get_last_lr()[0],
        }

    def _validate(self) -> Dict[str, float]:
        """Run validation and compute metrics."""
        self.model.eval()
        total_loss = 0.0
        all_embeddings = []
        all_labels = []

        with torch.no_grad():
            for images, labels in tqdm(self.val_loader, desc="Validating", leave=False):
                images = images.to(self.device)
                labels = labels.to(self.device)

                if self.use_amp:
                    with autocast():
                        embeddings = self.model(images)
                        loss = self.criterion(embeddings, labels)
                else:
                    embeddings = self.model(images)
                    loss = self.criterion(embeddings, labels)

                total_loss += loss.item()
                all_embeddings.append(embeddings.cpu())
                all_labels.append(labels.cpu())

        # Combine batches
        all_embeddings = torch.cat(all_embeddings, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        # Compute embedding quality metrics
        metrics = self.evaluator.compute_metrics(all_embeddings, all_labels)
        metrics["loss"] = total_loss / max(len(self.val_loader), 1)

        return metrics

    def _save_checkpoint(self, filename: str) -> None:
        """Save training checkpoint."""
        checkpoint_dir = Path(self.config.checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        checkpoint = {
            "epoch": self.epoch,
            "global_step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "patience_counter": self.patience_counter,
            "training_history": self.training_history,
            "config": self.config.__dict__,
        }

        if self.scaler is not None:
            checkpoint["scaler_state_dict"] = self.scaler.state_dict()

        torch.save(checkpoint, checkpoint_dir / filename)
        self.logger.debug(f"Saved checkpoint: {filename}")

    def load_checkpoint(self, filename: str) -> None:
        """Load training checkpoint."""
        checkpoint_path = Path(self.config.checkpoint_dir) / filename
        if not checkpoint_path.exists():
            self.logger.warning(f"Checkpoint not found: {checkpoint_path}")
            return

        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        self.epoch = checkpoint["epoch"]
        self.global_step = checkpoint["global_step"]
        self.best_val_loss = checkpoint["best_val_loss"]
        self.best_val_separation = checkpoint.get("best_val_separation", float("-inf"))
        self.patience_counter = checkpoint.get("patience_counter", 0)
        self.training_history = checkpoint.get("training_history", [])

        if self.scaler is not None and "scaler_state_dict" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])

        self.logger.info(f"Resumed from epoch {self.epoch + 1}")

    def export_model(self, output_dir: Optional[str] = None) -> str:
        """
        Export the trained model for inference.

        Args:
            output_dir: Output directory (uses config.output_dir if not specified)

        Returns:
            Path to exported model directory
        """
        output_dir = output_dir or self.config.output_dir
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Save model
        self.model.save_pretrained(output_dir)

        # Save training config
        config_path = output_path / "training_config.json"
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f, indent=2)

        # Save training history
        history_path = output_path / "training_history.json"
        with open(history_path, "w") as f:
            json.dump(self.training_history, f, indent=2)

        self.logger.info(f"Model exported to: {output_path}")
        return str(output_path)

    def get_training_summary(self) -> Dict:
        """Get summary of training."""
        return {
            "epochs_completed": self.epoch + 1,
            "global_steps": self.global_step,
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "history": self.training_history,
        }
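
# Usage note (illustrative): resuming and exporting with the methods above.
# The "best.pt" filename matches _save_checkpoint(); the checkpoint and output
# directories come from TrainingConfig. Building the Trainer instance itself is
# handled by train_clip_logo.py, which is not shown in this excerpt.
#
#   trainer.load_checkpoint("best.pt")   # warns and returns if the file is missing
#   trainer.train()                      # resumes from the stored epoch counter
#   trainer.export_model()               # writes weights, training_config.json,
#                                        # and training_history.json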