Add CLIP fine-tuning pipeline for logo recognition

Implement contrastive learning with LoRA to fine-tune CLIP's vision
encoder on the LogoDet-3K dataset for improved logo embedding similarity.

New training module (training/):
- config.py: TrainingConfig dataclass with all hyperparameters
- dataset.py: LogoContrastiveDataset with logo-level splits
- model.py: LogoFineTunedCLIP wrapper with LoRA support
- losses.py: InfoNCE, SupCon, Triplet, and combined loss implementations
- trainer.py: Training loop with mixed precision and checkpointing
- evaluation.py: EmbeddingEvaluator for validation metrics

New scripts:
- train_clip_logo.py: Main training entry point
- export_model.py: Export to HuggingFace-compatible format

Configurations:
- configs/jetson_orin.yaml: Optimized for Jetson Orin AGX
- configs/cloud_rtx4090.yaml: Optimized for 24GB cloud GPUs
- configs/cloud_a100.yaml: Optimized for 80GB cloud GPUs

Documentation:
- CLIP_FINETUNING.md: Training guide and usage instructions
- CLOUD_TRAINING.md: Cloud GPU recommendations and cost estimates

Modified:
- logo_detection_detr.py: Add fine-tuned model loading support
- pyproject.toml: Add peft, pyyaml, torchvision dependencies
Author: Rick McEwen
Date:   2026-01-04 13:45:25 -05:00
Parent: 1551360028
Commit: 44e8b6ae7d
16 changed files with 3334 additions and 12 deletions
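
For orientation, here is a minimal sketch of loading one of the new configurations and sanity-checking it. It assumes the repository root is on the import path and that the YAML keys mirror the TrainingConfig dataclass fields, as the loader expects.

# Hedged example: load a shipped config and report derived settings.
from training.config import TrainingConfig

config = TrainingConfig.from_yaml("configs/jetson_orin.yaml")
for warning in config.validate():
    print("config warning:", warning)
print("effective batch size:", config.effective_batch_size)
print("samples per contrastive batch:", config.samples_per_batch)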

training/__init__.py (new file, 24 lines)
"""
CLIP fine-tuning module for logo recognition.
This module provides tools for fine-tuning CLIP's vision encoder using
contrastive learning on the LogoDet-3K dataset.
"""
from .config import TrainingConfig
from .dataset import LogoContrastiveDataset, create_dataloaders
from .model import LogoFineTunedCLIP
from .losses import InfoNCELoss, TripletLoss
from .trainer import Trainer
from .evaluation import EmbeddingEvaluator
__all__ = [
"TrainingConfig",
"LogoContrastiveDataset",
"create_dataloaders",
"LogoFineTunedCLIP",
"InfoNCELoss",
"TripletLoss",
"Trainer",
"EmbeddingEvaluator",
]
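
A hedged quick-start sketch of how these exports fit together; it assumes the LogoDet-3K reference images and the SQLite mapping database at the default TrainingConfig paths are present, and uses create_model from training.model (not re-exported here):

# Sketch: wire config, data, model, and trainer together with defaults.
from training import TrainingConfig, Trainer, create_dataloaders
from training.model import create_model

config = TrainingConfig()
train_loader, val_loader, _ = create_dataloaders(
    db_path=config.db_path,
    reference_dir=config.reference_dir,
    logos_per_batch=config.logos_per_batch,
    samples_per_logo=config.samples_per_logo,
    augmentation_strength=config.augmentation_strength,
)
model, processor = create_model(
    base_model=config.base_model,
    lora_r=config.lora_r,
    lora_alpha=config.lora_alpha,
    freeze_layers=config.freeze_layers,
)
trainer = Trainer(model, train_loader, val_loader, config)
trainer.train()
trainer.export_model()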

training/config.py (new file, 141 lines)
"""
Training configuration for CLIP fine-tuning.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional
import yaml
@dataclass
class TrainingConfig:
"""Configuration for CLIP logo fine-tuning."""
# Base model
base_model: str = "openai/clip-vit-large-patch14"
# Dataset paths
dataset_dir: str = "LogoDet-3K"
reference_dir: str = "reference_logos"
db_path: str = "test_data_mapping.db"
# Data split ratios
train_split: float = 0.7
val_split: float = 0.15
test_split: float = 0.15
# Batch construction
batch_size: int = 16
logos_per_batch: int = 32
samples_per_logo: int = 4
gradient_accumulation_steps: int = 8
num_workers: int = 4
# Model architecture
lora_r: int = 16
lora_alpha: int = 32
lora_dropout: float = 0.1
freeze_layers: int = 12
use_gradient_checkpointing: bool = True
# Training hyperparameters
learning_rate: float = 1e-5
weight_decay: float = 0.01
warmup_steps: int = 500
max_epochs: int = 20
mixed_precision: bool = True
# Loss function
temperature: float = 0.07
loss_type: str = "infonce"  # "infonce", "supcon", "triplet", or "combined"
triplet_margin: float = 0.3
# Early stopping
patience: int = 5
min_delta: float = 0.001
# Checkpoints and output
checkpoint_dir: str = "checkpoints"
output_dir: str = "models/logo_detection/clip_finetuned"
save_every_n_epochs: int = 5
# Logging
log_every_n_steps: int = 10
eval_every_n_epochs: int = 1
# Random seed for reproducibility
seed: int = 42
# Hard negative mining
use_hard_negatives: bool = False
hard_negative_start_epoch: int = 5
hard_negatives_per_logo: int = 10
# Data augmentation
use_augmentation: bool = True
augmentation_strength: str = "medium" # "light", "medium", "strong"
@classmethod
def from_yaml(cls, yaml_path: str) -> "TrainingConfig":
"""Load configuration from YAML file."""
with open(yaml_path, "r") as f:
config_dict = yaml.safe_load(f)
return cls(**config_dict)
def to_yaml(self, yaml_path: str) -> None:
"""Save configuration to YAML file."""
Path(yaml_path).parent.mkdir(parents=True, exist_ok=True)
with open(yaml_path, "w") as f:
yaml.dump(self.__dict__, f, default_flow_style=False, sort_keys=False)
def validate(self) -> List[str]:
"""Validate configuration and return list of warnings."""
warnings = []
# Check split ratios
total_split = self.train_split + self.val_split + self.test_split
if abs(total_split - 1.0) > 0.01:
warnings.append(
f"Split ratios sum to {total_split}, expected 1.0"
)
# Check batch construction
effective_batch = self.batch_size * self.gradient_accumulation_steps
if effective_batch < 64:
warnings.append(
f"Effective batch size ({effective_batch}) is small for contrastive learning. "
"Consider increasing batch_size or gradient_accumulation_steps."
)
# Check LoRA config
if self.lora_r > 0 and self.lora_alpha < self.lora_r:
warnings.append(
f"lora_alpha ({self.lora_alpha}) < lora_r ({self.lora_r}). "
"This may reduce LoRA effectiveness."
)
# Check freeze layers
if self.freeze_layers < 0:
warnings.append("freeze_layers should be >= 0")
# Check temperature
if self.temperature <= 0:
warnings.append("temperature must be positive")
elif self.temperature > 1.0:
warnings.append(
f"temperature ({self.temperature}) is high. "
"Typical values are 0.05-0.1."
)
return warnings
@property
def effective_batch_size(self) -> int:
"""Calculate effective batch size with gradient accumulation."""
return self.batch_size * self.gradient_accumulation_steps
@property
def samples_per_batch(self) -> int:
"""Total samples in one batch (logos_per_batch * samples_per_logo)."""
return self.logos_per_batch * self.samples_per_logo

training/dataset.py (new file, 467 lines)
"""
Dataset classes for contrastive learning on logo images.
"""
import random
import sqlite3
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
# CLIP normalization values
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
def get_train_transforms(strength: str = "medium") -> transforms.Compose:
"""
Get training data augmentation transforms.
Args:
strength: Augmentation strength - "light", "medium", or "strong"
Returns:
Composed transforms for training
"""
if strength == "light":
return transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ColorJitter(brightness=0.1, contrast=0.1),
transforms.ToTensor(),
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
])
elif strength == "medium":
return transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(degrees=15),
transforms.ColorJitter(
brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05
),
transforms.RandomAffine(
degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)
),
transforms.RandomGrayscale(p=0.1),
transforms.ToTensor(),
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
])
else: # strong
return transforms.Compose([
transforms.Resize((256, 256)),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.1),
transforms.RandomRotation(degrees=30),
transforms.ColorJitter(
brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1
),
transforms.RandomAffine(
degrees=0, translate=(0.15, 0.15), scale=(0.8, 1.2), shear=10
),
transforms.RandomGrayscale(p=0.2),
transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
transforms.ToTensor(),
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
])
def get_val_transforms() -> transforms.Compose:
"""Get validation/test transforms (no augmentation)."""
return transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
])
class LogoDataset:
"""
Manages logo data from the SQLite database.
Handles loading logo-to-image mappings and splitting by logo brand.
"""
def __init__(
self,
db_path: str,
reference_dir: str,
train_split: float = 0.7,
val_split: float = 0.15,
test_split: float = 0.15,
seed: int = 42,
):
self.db_path = Path(db_path)
self.reference_dir = Path(reference_dir)
self.seed = seed
# Load logo-to-images mapping from database
self.logo_to_images = self._load_logo_mappings()
self.all_logos = list(self.logo_to_images.keys())
# Create logo-level splits
self.train_logos, self.val_logos, self.test_logos = self._split_logos(
train_split, val_split, test_split
)
def _load_logo_mappings(self) -> Dict[str, List[Path]]:
"""Load logo name to image paths mapping from database."""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("""
SELECT ln.name, rl.filename
FROM reference_logos rl
JOIN logo_names ln ON rl.logo_name_id = ln.id
ORDER BY ln.name
""")
logo_to_images: Dict[str, List[Path]] = {}
for logo_name, filename in cursor.fetchall():
if logo_name not in logo_to_images:
logo_to_images[logo_name] = []
logo_to_images[logo_name].append(self.reference_dir / filename)
conn.close()
return logo_to_images
def _split_logos(
self,
train_split: float,
val_split: float,
test_split: float,
) -> Tuple[List[str], List[str], List[str]]:
"""Split logos at brand level for train/val/test."""
random.seed(self.seed)
logos = self.all_logos.copy()
random.shuffle(logos)
n = len(logos)
train_end = int(n * train_split)
val_end = train_end + int(n * val_split)
train_logos = logos[:train_end]
val_logos = logos[train_end:val_end]
test_logos = logos[val_end:]
return train_logos, val_logos, test_logos
def get_split_info(self) -> Dict[str, int]:
"""Return information about the splits."""
return {
"total_logos": len(self.all_logos),
"train_logos": len(self.train_logos),
"val_logos": len(self.val_logos),
"test_logos": len(self.test_logos),
"train_images": sum(
len(self.logo_to_images[l]) for l in self.train_logos
),
"val_images": sum(
len(self.logo_to_images[l]) for l in self.val_logos
),
"test_images": sum(
len(self.logo_to_images[l]) for l in self.test_logos
),
}
class LogoContrastiveDataset(Dataset):
"""
Dataset for contrastive learning on logos.
Each __getitem__ call returns a batch of images organized for contrastive
learning: K different logos with M samples each, ensuring positive pairs
exist within each batch.
"""
def __init__(
self,
logo_data: LogoDataset,
split: str = "train",
logos_per_batch: int = 32,
samples_per_logo: int = 4,
transform: Optional[transforms.Compose] = None,
batches_per_epoch: int = 1000,
):
"""
Initialize the contrastive dataset.
Args:
logo_data: LogoDataset instance with logo mappings
split: One of "train", "val", or "test"
logos_per_batch: Number of different logos per batch
samples_per_logo: Number of samples for each logo
transform: Image transforms to apply
batches_per_epoch: Number of batches per epoch
"""
self.logo_data = logo_data
self.logos_per_batch = logos_per_batch
self.samples_per_logo = samples_per_logo
self.transform = transform
self.batches_per_epoch = batches_per_epoch
# Get logos for this split
if split == "train":
self.logos = logo_data.train_logos
elif split == "val":
self.logos = logo_data.val_logos
else:
self.logos = logo_data.test_logos
# Filter logos with enough samples
self.valid_logos = [
logo for logo in self.logos
if len(logo_data.logo_to_images[logo]) >= samples_per_logo
]
# For logos with fewer samples, we'll use with replacement
self.logos_needing_replacement = [
logo for logo in self.logos
if len(logo_data.logo_to_images[logo]) < samples_per_logo
]
# Create label mapping
self.logo_to_label = {
logo: idx for idx, logo in enumerate(self.logos)
}
def __len__(self) -> int:
return self.batches_per_epoch
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get a batch of images for contrastive learning.
Returns:
images: Tensor of shape [K*M, 3, 224, 224]
labels: Tensor of shape [K*M] with logo class indices
"""
images = []
labels = []
# Sample K logos for this batch
k = min(self.logos_per_batch, len(self.logos))
batch_logos = random.sample(self.logos, k)
for logo in batch_logos:
logo_images = self.logo_data.logo_to_images[logo]
# Sample M images for this logo
if len(logo_images) >= self.samples_per_logo:
sampled_paths = random.sample(logo_images, self.samples_per_logo)
else:
# Sample with replacement if not enough images
sampled_paths = random.choices(
logo_images, k=self.samples_per_logo
)
# Load and transform images
for img_path in sampled_paths:
try:
img = Image.open(img_path).convert("RGB")
if self.transform:
img = self.transform(img)
else:
img = get_val_transforms()(img)
images.append(img)
labels.append(self.logo_to_label[logo])
except Exception:
# Skip unreadable or corrupt images; the batch may simply
# contain fewer samples for this logo
continue
# Stack into tensors
if len(images) == 0:
# Fallback: return dummy batch
return (
torch.zeros(1, 3, 224, 224),
torch.zeros(1, dtype=torch.long),
)
images_tensor = torch.stack(images)
labels_tensor = torch.tensor(labels, dtype=torch.long)
return images_tensor, labels_tensor
class BalancedBatchSampler(Sampler):
"""
Sampler that ensures each batch has a balanced distribution of logos.
Used with a flattened dataset where each sample is a single image.
"""
def __init__(
self,
logo_labels: List[int],
logos_per_batch: int,
samples_per_logo: int,
num_batches: int,
):
self.logo_labels = logo_labels
self.logos_per_batch = logos_per_batch
self.samples_per_logo = samples_per_logo
self.num_batches = num_batches
# Group indices by logo
self.logo_to_indices: Dict[int, List[int]] = {}
for idx, label in enumerate(logo_labels):
if label not in self.logo_to_indices:
self.logo_to_indices[label] = []
self.logo_to_indices[label].append(idx)
self.all_logos = list(self.logo_to_indices.keys())
def __iter__(self):
for _ in range(self.num_batches):
batch_indices = []
# Sample logos for this batch
logos = random.sample(
self.all_logos,
min(self.logos_per_batch, len(self.all_logos)),
)
for logo in logos:
indices = self.logo_to_indices[logo]
if len(indices) >= self.samples_per_logo:
sampled = random.sample(indices, self.samples_per_logo)
else:
sampled = random.choices(indices, k=self.samples_per_logo)
batch_indices.extend(sampled)
yield batch_indices
def __len__(self):
return self.num_batches
def create_dataloaders(
db_path: str,
reference_dir: str,
batch_size: int = 16,
logos_per_batch: int = 32,
samples_per_logo: int = 4,
num_workers: int = 4,
train_split: float = 0.7,
val_split: float = 0.15,
test_split: float = 0.15,
seed: int = 42,
augmentation_strength: str = "medium",
batches_per_epoch: int = 1000,
) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
"""
Create train, validation, and optionally test dataloaders.
Args:
db_path: Path to SQLite database
reference_dir: Directory containing reference logo images
batch_size: Not used directly (see logos_per_batch and samples_per_logo)
logos_per_batch: Number of different logos per batch
samples_per_logo: Samples per logo in batch
num_workers: Number of data loading workers
train_split: Fraction for training
val_split: Fraction for validation
test_split: Fraction for testing
seed: Random seed
augmentation_strength: "light", "medium", or "strong"
batches_per_epoch: Number of batches per training epoch
Returns:
Tuple of (train_loader, val_loader, test_loader)
"""
# Load logo data
logo_data = LogoDataset(
db_path=db_path,
reference_dir=reference_dir,
train_split=train_split,
val_split=val_split,
test_split=test_split,
seed=seed,
)
# Print split info
split_info = logo_data.get_split_info()
print(f"Dataset loaded:")
print(f" Total logos: {split_info['total_logos']}")
print(f" Train: {split_info['train_logos']} logos, {split_info['train_images']} images")
print(f" Val: {split_info['val_logos']} logos, {split_info['val_images']} images")
print(f" Test: {split_info['test_logos']} logos, {split_info['test_images']} images")
# Create datasets
train_dataset = LogoContrastiveDataset(
logo_data=logo_data,
split="train",
logos_per_batch=logos_per_batch,
samples_per_logo=samples_per_logo,
transform=get_train_transforms(augmentation_strength),
batches_per_epoch=batches_per_epoch,
)
val_dataset = LogoContrastiveDataset(
logo_data=logo_data,
split="val",
logos_per_batch=logos_per_batch,
samples_per_logo=samples_per_logo,
transform=get_val_transforms(),
batches_per_epoch=batches_per_epoch // 10, # Fewer val batches
)
test_dataset = LogoContrastiveDataset(
logo_data=logo_data,
split="test",
logos_per_batch=logos_per_batch,
samples_per_logo=samples_per_logo,
transform=get_val_transforms(),
batches_per_epoch=batches_per_epoch // 10,
) if test_split > 0 else None
# Create dataloaders
# Note: batch_size=1 because each __getitem__ already returns a batch
train_loader = DataLoader(
train_dataset,
batch_size=1,
shuffle=True,
num_workers=num_workers,
pin_memory=True,
collate_fn=_collate_contrastive_batch,
)
val_loader = DataLoader(
val_dataset,
batch_size=1,
shuffle=False,
num_workers=num_workers,
pin_memory=True,
collate_fn=_collate_contrastive_batch,
)
test_loader = None
if test_dataset is not None:
test_loader = DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
num_workers=num_workers,
pin_memory=True,
collate_fn=_collate_contrastive_batch,
)
return train_loader, val_loader, test_loader
def _collate_contrastive_batch(
batch: List[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Collate function that unpacks pre-batched data.
Since LogoContrastiveDataset already returns batched data,
we just squeeze the outer dimension.
"""
images, labels = batch[0]
return images, labels
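
For a quick, data-free check of the augmentation pipeline, something along these lines can be run; it uses a synthetic image, so neither LogoDet-3K nor the mapping database is required:

# Sketch: confirm the train/val transforms produce CLIP-sized tensors.
import numpy as np
from PIL import Image
from training.dataset import get_train_transforms, get_val_transforms

dummy = Image.fromarray(np.random.randint(0, 256, (300, 500, 3), dtype=np.uint8))
train_tensor = get_train_transforms("medium")(dummy)
val_tensor = get_val_transforms()(dummy)
print(train_tensor.shape, val_tensor.shape)  # both torch.Size([3, 224, 224])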

training/evaluation.py (new file, 339 lines)
"""
Evaluation metrics for embedding quality.
"""
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
import numpy as np
class EmbeddingEvaluator:
"""
Evaluator for embedding quality metrics.
Computes metrics that indicate how well the embeddings
separate different logo classes.
"""
def compute_metrics(
self,
embeddings: torch.Tensor,
labels: torch.Tensor,
) -> Dict[str, float]:
"""
Compute embedding quality metrics.
Args:
embeddings: [N, D] L2-normalized embeddings
labels: [N] integer class labels
Returns:
Dict with metric names and values
"""
device = embeddings.device
batch_size = embeddings.shape[0]
if batch_size <= 1:
return {
"mean_pos_sim": 0.0,
"mean_neg_sim": 0.0,
"separation": 0.0,
"recall_at_1": 0.0,
"recall_at_5": 0.0,
}
# Compute similarity matrix
similarity = embeddings @ embeddings.T
# Create masks
labels_col = labels.unsqueeze(0)
labels_row = labels.unsqueeze(1)
positive_mask = (labels_row == labels_col).float()
negative_mask = 1 - positive_mask
# Remove diagonal from positive mask
identity = torch.eye(batch_size, device=device)
positive_mask = positive_mask - identity
# Count pairs
num_positives = positive_mask.sum()
num_negatives = negative_mask.sum()
# Mean positive similarity (excluding self)
if num_positives > 0:
pos_sims = (similarity * positive_mask).sum() / num_positives
mean_pos_sim = pos_sims.item()
else:
mean_pos_sim = 0.0
# Mean negative similarity
if num_negatives > 0:
neg_sims = (similarity * negative_mask).sum() / num_negatives
mean_neg_sim = neg_sims.item()
else:
mean_neg_sim = 0.0
# Separation: gap between positive and negative similarity
separation = mean_pos_sim - mean_neg_sim
# Recall@K metrics
recall_at_1 = self._compute_recall_at_k(similarity, labels, k=1)
recall_at_5 = self._compute_recall_at_k(similarity, labels, k=5)
return {
"mean_pos_sim": mean_pos_sim,
"mean_neg_sim": mean_neg_sim,
"separation": separation,
"recall_at_1": recall_at_1,
"recall_at_5": recall_at_5,
}
def _compute_recall_at_k(
self,
similarity: torch.Tensor,
labels: torch.Tensor,
k: int = 1,
) -> float:
"""
Compute Recall@K for nearest neighbor retrieval.
For each sample, check if the k nearest neighbors (excluding self)
contain at least one sample with the same label.
Args:
similarity: [N, N] similarity matrix
labels: [N] class labels
k: Number of neighbors to consider
Returns:
Recall@K score (0 to 1)
"""
batch_size = similarity.shape[0]
if batch_size <= 1:
return 0.0
# Mask out self-similarity
similarity = similarity.clone()
similarity.fill_diagonal_(float("-inf"))
# Get top-k indices
_, top_k_indices = similarity.topk(min(k, batch_size - 1), dim=1)
# Check if any of top-k have same label
correct = 0
for i in range(batch_size):
query_label = labels[i]
retrieved_labels = labels[top_k_indices[i]]
if (retrieved_labels == query_label).any():
correct += 1
return correct / batch_size
def compute_detailed_metrics(
self,
embeddings: torch.Tensor,
labels: torch.Tensor,
label_names: Optional[List[str]] = None,
) -> Dict:
"""
Compute detailed per-class metrics.
Args:
embeddings: [N, D] embeddings
labels: [N] class labels
label_names: Optional list of label names
Returns:
Dict with detailed metrics including per-class stats
"""
basic_metrics = self.compute_metrics(embeddings, labels)
# Per-class statistics
unique_labels = labels.unique()
per_class_stats = {}
similarity = embeddings @ embeddings.T
for label in unique_labels:
mask = labels == label
class_embeddings = embeddings[mask]
class_size = mask.sum().item()
if class_size > 1:
# Intra-class similarity
class_sim = class_embeddings @ class_embeddings.T
# Exclude diagonal
mask_diag = ~torch.eye(class_size, dtype=torch.bool, device=class_sim.device)
intra_sim = class_sim[mask_diag].mean().item()
else:
intra_sim = 1.0
# Inter-class similarity (to other classes)
other_mask = labels != label
if other_mask.any():
inter_sim = similarity[mask][:, other_mask].mean().item()
else:
inter_sim = 0.0
class_name = label_names[label.item()] if label_names else str(label.item())
per_class_stats[class_name] = {
"size": class_size,
"intra_class_sim": intra_sim,
"inter_class_sim": inter_sim,
"class_separation": intra_sim - inter_sim,
}
# Aggregate per-class stats
if per_class_stats:
separations = [s["class_separation"] for s in per_class_stats.values()]
min_separation = min(separations)
max_separation = max(separations)
std_separation = np.std(separations)
else:
min_separation = max_separation = std_separation = 0.0
return {
**basic_metrics,
"per_class": per_class_stats,
"min_class_separation": min_separation,
"max_class_separation": max_separation,
"std_class_separation": std_separation,
}
class SimilarityAnalyzer:
"""
Analyze similarity distributions for debugging and tuning.
"""
@staticmethod
def analyze_similarity_distribution(
embeddings: torch.Tensor,
labels: torch.Tensor,
) -> Dict[str, np.ndarray]:
"""
Get similarity distributions for positive and negative pairs.
Useful for choosing appropriate thresholds.
Args:
embeddings: [N, D] embeddings
labels: [N] class labels
Returns:
Dict with 'positive_sims' and 'negative_sims' arrays
"""
similarity = (embeddings @ embeddings.T).cpu().numpy()
labels_np = labels.cpu().numpy()
batch_size = len(labels_np)
positive_sims = []
negative_sims = []
for i in range(batch_size):
for j in range(i + 1, batch_size):
if labels_np[i] == labels_np[j]:
positive_sims.append(similarity[i, j])
else:
negative_sims.append(similarity[i, j])
return {
"positive_sims": np.array(positive_sims),
"negative_sims": np.array(negative_sims),
}
@staticmethod
def find_hard_pairs(
embeddings: torch.Tensor,
labels: torch.Tensor,
n_hard: int = 10,
) -> Tuple[List[Tuple[int, int, float]], List[Tuple[int, int, float]]]:
"""
Find hardest positive and negative pairs.
Hard positives: same label but low similarity
Hard negatives: different label but high similarity
Args:
embeddings: [N, D] embeddings
labels: [N] class labels
n_hard: Number of hard pairs to return
Returns:
Tuple of (hard_positives, hard_negatives)
Each is a list of (idx1, idx2, similarity) tuples
"""
similarity = embeddings @ embeddings.T
batch_size = len(labels)
hard_positives = [] # Low similarity, same label
hard_negatives = [] # High similarity, different label
for i in range(batch_size):
for j in range(i + 1, batch_size):
sim = similarity[i, j].item()
if labels[i] == labels[j]:
hard_positives.append((i, j, sim))
else:
hard_negatives.append((i, j, sim))
# Sort: hard positives by ascending similarity (lowest first)
hard_positives.sort(key=lambda x: x[2])
# Sort: hard negatives by descending similarity (highest first)
hard_negatives.sort(key=lambda x: -x[2])
return hard_positives[:n_hard], hard_negatives[:n_hard]
@staticmethod
def compute_confusion_pairs(
embeddings: torch.Tensor,
labels: torch.Tensor,
label_names: Optional[List[str]] = None,
top_k: int = 10,
) -> List[Dict]:
"""
Find pairs of classes that are most confused (highest cross-class similarity).
Args:
embeddings: [N, D] embeddings
labels: [N] class labels
label_names: Optional label names
top_k: Number of confused pairs to return
Returns:
List of dicts with class pairs and their similarity
"""
unique_labels = labels.unique()
class_centroids = {}
# Compute class centroids
for label in unique_labels:
mask = labels == label
centroid = embeddings[mask].mean(dim=0)
centroid = F.normalize(centroid, dim=0)
class_centroids[label.item()] = centroid
# Compute pairwise centroid similarities
confusions = []
label_list = list(class_centroids.keys())
for i, label1 in enumerate(label_list):
for label2 in label_list[i + 1:]:
sim = (class_centroids[label1] @ class_centroids[label2]).item()
name1 = label_names[label1] if label_names else str(label1)
name2 = label_names[label2] if label_names else str(label2)
confusions.append({
"class1": name1,
"class2": name2,
"label1": label1,
"label2": label2,
"centroid_similarity": sim,
})
# Sort by similarity (highest first)
confusions.sort(key=lambda x: -x["centroid_similarity"])
return confusions[:top_k]
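
The evaluator can be exercised on synthetic embeddings to see the metric names and rough scale; with random unit vectors, positive and negative similarities should both sit near zero and Recall@K near chance:

# Sketch: run EmbeddingEvaluator on random embeddings (8 classes x 4 samples).
import torch
import torch.nn.functional as F
from training.evaluation import EmbeddingEvaluator

embeddings = F.normalize(torch.randn(32, 768), dim=-1)
labels = torch.arange(8).repeat_interleave(4)
metrics = EmbeddingEvaluator().compute_metrics(embeddings, labels)
print({name: round(value, 3) for name, value in metrics.items()})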

training/losses.py (new file, 326 lines)
"""
Loss functions for contrastive learning on logo embeddings.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
class InfoNCELoss(nn.Module):
"""
Normalized Temperature-scaled Cross Entropy Loss (InfoNCE).
This is the contrastive loss used in CLIP training. It maximizes
similarity between embeddings of the same logo class while
minimizing similarity to embeddings of different classes.
For a batch with N samples:
- Each sample is an anchor
- Positive pairs: samples with the same label
- Negative pairs: samples with different labels
The loss for an anchor i with positive set P(i) is the mean over p in P(i) of:
-log(exp(sim(i, p)/temp) / sum(exp(sim(i, a)/temp) for a != i))
"""
def __init__(self, temperature: float = 0.07):
"""
Initialize InfoNCE loss.
Args:
temperature: Scaling factor for similarities (0.05-0.1 typical).
Lower temperature makes the distribution sharper.
"""
super().__init__()
self.temperature = temperature
def forward(
self,
embeddings: torch.Tensor,
labels: torch.Tensor,
) -> torch.Tensor:
"""
Compute InfoNCE loss for a batch of embeddings.
Args:
embeddings: [N, D] L2-normalized embeddings
labels: [N] integer logo class labels
Returns:
Scalar loss value
"""
device = embeddings.device
batch_size = embeddings.shape[0]
if batch_size <= 1:
return torch.tensor(0.0, device=device, requires_grad=True)
# Compute similarity matrix [N, N]
# Since embeddings are L2-normalized, dot product = cosine similarity
similarity = embeddings @ embeddings.T / self.temperature
# Create positive mask: same label = 1, different = 0
labels_col = labels.unsqueeze(0) # [1, N]
labels_row = labels.unsqueeze(1) # [N, 1]
positive_mask = (labels_row == labels_col).float() # [N, N]
# Remove self-similarity from positives (diagonal)
identity = torch.eye(batch_size, device=device)
positive_mask = positive_mask - identity
# Count positives per anchor (avoid division by zero)
num_positives = positive_mask.sum(dim=1)
has_positives = num_positives > 0
# If no positives exist for any anchor, return zero loss
if not has_positives.any():
return torch.tensor(0.0, device=device, requires_grad=True)
# Mask out self-similarity with large negative value
similarity = similarity - identity * 1e9
# Compute log-softmax over similarities
log_softmax = F.log_softmax(similarity, dim=1)
# Sum log probabilities of positive pairs
positive_log_probs = (log_softmax * positive_mask).sum(dim=1)
# Average over number of positives (only for anchors with positives)
loss_per_anchor = torch.zeros(batch_size, device=device)
loss_per_anchor[has_positives] = (
-positive_log_probs[has_positives] / num_positives[has_positives]
)
return loss_per_anchor.mean()
class SupConLoss(nn.Module):
"""
Supervised Contrastive Loss.
Similar to the InfoNCE loss above; this implementation follows the
supervised contrastive formulation and subtracts the per-row maximum
similarity for numerical stability.
Reference: https://arxiv.org/abs/2004.11362
"""
def __init__(self, temperature: float = 0.07):
super().__init__()
self.temperature = temperature
def forward(
self,
embeddings: torch.Tensor,
labels: torch.Tensor,
) -> torch.Tensor:
"""
Compute Supervised Contrastive loss.
Args:
embeddings: [N, D] L2-normalized embeddings
labels: [N] integer logo class labels
Returns:
Scalar loss value
"""
device = embeddings.device
batch_size = embeddings.shape[0]
if batch_size <= 1:
return torch.tensor(0.0, device=device, requires_grad=True)
# Compute similarity matrix
similarity = embeddings @ embeddings.T / self.temperature
# Create masks
labels_col = labels.unsqueeze(0)
labels_row = labels.unsqueeze(1)
positive_mask = (labels_row == labels_col).float()
identity = torch.eye(batch_size, device=device)
# Remove self from positives
positive_mask = positive_mask - identity
# Number of positives per anchor
num_positives = positive_mask.sum(dim=1)
has_positives = num_positives > 0
if not has_positives.any():
return torch.tensor(0.0, device=device, requires_grad=True)
# For numerical stability, subtract max similarity
sim_max, _ = similarity.max(dim=1, keepdim=True)
similarity = similarity - sim_max.detach()
# Compute exp(similarity) with self masked out
exp_sim = torch.exp(similarity) * (1 - identity)
# Denominator: sum of exp over all pairs except self
log_prob = similarity - torch.log(exp_sim.sum(dim=1, keepdim=True) + 1e-8)
# Mean of log-prob over positive pairs
mean_log_prob_pos = (positive_mask * log_prob).sum(dim=1) / (
num_positives + 1e-8
)
# Loss is negative mean log probability
loss = -mean_log_prob_pos[has_positives].mean()
return loss
class TripletLoss(nn.Module):
"""
Triplet loss with online hard mining.
For each anchor:
- Hardest positive: most distant sample with same label
- Hardest negative: closest sample with different label
Loss = max(0, d(anchor, hardest_pos) - d(anchor, hardest_neg) + margin)
This is an alternative to InfoNCE for when batch sizes are small.
"""
def __init__(self, margin: float = 0.3):
"""
Initialize Triplet loss.
Args:
margin: Minimum required gap between positive and negative distances
"""
super().__init__()
self.margin = margin
def forward(
self,
embeddings: torch.Tensor,
labels: torch.Tensor,
) -> torch.Tensor:
"""
Compute triplet loss with online hard mining.
Args:
embeddings: [N, D] L2-normalized embeddings
labels: [N] integer logo class labels
Returns:
Scalar loss value
"""
device = embeddings.device
batch_size = embeddings.shape[0]
if batch_size <= 1:
return torch.tensor(0.0, device=device, requires_grad=True)
# Compute pairwise cosine distances (1 - cosine_similarity)
# For normalized vectors: distance = 1 - dot_product
similarity = embeddings @ embeddings.T
distances = 1 - similarity
# Create masks
labels_col = labels.unsqueeze(0)
labels_row = labels.unsqueeze(1)
positive_mask = (labels_row == labels_col).float()
negative_mask = 1 - positive_mask
# Remove self from positives (diagonal)
identity = torch.eye(batch_size, device=device)
positive_mask = positive_mask - identity
# Check if we have any valid triplets
has_positives = positive_mask.sum(dim=1) > 0
has_negatives = negative_mask.sum(dim=1) > 0
valid_anchors = has_positives & has_negatives
if not valid_anchors.any():
return torch.tensor(0.0, device=device, requires_grad=True)
# For each anchor, find hardest positive (max distance among positives)
# Set negatives to -inf so they don't affect max
pos_distances = distances.clone()
pos_distances[positive_mask == 0] = float("-inf")
hardest_positive, _ = pos_distances.max(dim=1)
# For each anchor, find hardest negative (min distance among negatives)
# Set positives to inf so they don't affect min
neg_distances = distances.clone()
neg_distances[negative_mask == 0] = float("inf")
hardest_negative, _ = neg_distances.min(dim=1)
# Triplet loss: want positive to be closer than negative by margin
triplet_loss = F.relu(
hardest_positive - hardest_negative + self.margin
)
# Average over valid anchors only
loss = triplet_loss[valid_anchors].mean()
return loss
class CombinedLoss(nn.Module):
"""
Combined loss function with weighted InfoNCE and Triplet losses.
Can help stabilize training by combining the benefits of both losses.
"""
def __init__(
self,
temperature: float = 0.07,
triplet_margin: float = 0.3,
infonce_weight: float = 1.0,
triplet_weight: float = 0.5,
):
super().__init__()
self.infonce = InfoNCELoss(temperature=temperature)
self.triplet = TripletLoss(margin=triplet_margin)
self.infonce_weight = infonce_weight
self.triplet_weight = triplet_weight
def forward(
self,
embeddings: torch.Tensor,
labels: torch.Tensor,
) -> torch.Tensor:
infonce_loss = self.infonce(embeddings, labels)
triplet_loss = self.triplet(embeddings, labels)
return (
self.infonce_weight * infonce_loss +
self.triplet_weight * triplet_loss
)
def get_loss_function(
loss_type: str = "infonce",
temperature: float = 0.07,
triplet_margin: float = 0.3,
) -> nn.Module:
"""
Factory function to create loss function.
Args:
loss_type: One of "infonce", "supcon", "triplet", or "combined"
temperature: Temperature for InfoNCE/SupCon
triplet_margin: Margin for triplet loss
Returns:
Loss function module
"""
if loss_type == "infonce":
return InfoNCELoss(temperature=temperature)
elif loss_type == "supcon":
return SupConLoss(temperature=temperature)
elif loss_type == "triplet":
return TripletLoss(margin=triplet_margin)
elif loss_type == "combined":
return CombinedLoss(
temperature=temperature,
triplet_margin=triplet_margin,
)
else:
raise ValueError(f"Unknown loss type: {loss_type}")
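
A small sanity check of the loss factory on synthetic embeddings: tight same-class clusters should score a clearly lower InfoNCE loss than random vectors.

# Sketch: compare InfoNCE on random vs. well-clustered embeddings.
import torch
import torch.nn.functional as F
from training.losses import get_loss_function

criterion = get_loss_function("infonce", temperature=0.07)
labels = torch.arange(8).repeat_interleave(4)  # 8 logos, 4 samples each

random_emb = F.normalize(torch.randn(32, 768), dim=-1)
centroids = torch.randn(8, 768).repeat_interleave(4, dim=0)
clustered_emb = F.normalize(centroids + 0.05 * torch.randn(32, 768), dim=-1)

print("random:   ", criterion(random_emb, labels).item())
print("clustered:", criterion(clustered_emb, labels).item())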

training/model.py (new file, 335 lines)
"""
Fine-tunable CLIP model wrapper with LoRA support.
"""
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPModel, CLIPProcessor
# Check if peft is available for LoRA
try:
from peft import LoraConfig, get_peft_model, PeftModel
PEFT_AVAILABLE = True
except ImportError:
PEFT_AVAILABLE = False
LoraConfig = None
get_peft_model = None
PeftModel = None
class LogoFineTunedCLIP(nn.Module):
"""
CLIP vision encoder fine-tuned for logo similarity.
Preserves embedding interface for compatibility with DetectLogosDETR:
- Same embedding dimensionality (taken from the vision encoder's hidden size)
- L2 normalized outputs
- Works with existing get_image_features() pattern
Supports:
- LoRA for memory-efficient fine-tuning
- Layer freezing for transfer learning
- Gradient checkpointing for memory optimization
"""
def __init__(
self,
vision_model: nn.Module,
lora_r: int = 16,
lora_alpha: int = 32,
lora_dropout: float = 0.1,
freeze_layers: int = 12,
use_gradient_checkpointing: bool = True,
add_projection_head: bool = True,
):
"""
Initialize the fine-tunable CLIP wrapper.
Args:
vision_model: CLIP vision model (CLIPVisionModel)
lora_r: Rank of LoRA low-rank matrices (0 to disable)
lora_alpha: LoRA scaling factor
lora_dropout: Dropout for LoRA layers
freeze_layers: Number of transformer layers to freeze (from bottom)
use_gradient_checkpointing: Enable gradient checkpointing
add_projection_head: Add trainable projection head
"""
super().__init__()
self.vision_model = vision_model
self.embedding_dim = vision_model.config.hidden_size
self.freeze_layers = freeze_layers
self.lora_r = lora_r
self.lora_alpha = lora_alpha
# Enable gradient checkpointing for memory efficiency
if use_gradient_checkpointing:
if hasattr(self.vision_model, "gradient_checkpointing_enable"):
self.vision_model.gradient_checkpointing_enable()
# Freeze lower layers
self._freeze_layers(freeze_layers)
# Apply LoRA to attention layers in upper blocks
self.peft_applied = False
if PEFT_AVAILABLE and lora_r > 0:
self._apply_lora(lora_r, lora_alpha, lora_dropout)
self.peft_applied = True
elif lora_r > 0 and not PEFT_AVAILABLE:
print(
"Warning: peft not installed. LoRA disabled. "
"Install with: pip install peft"
)
# Optional projection head for fine-tuning
self.add_projection_head = add_projection_head
if add_projection_head:
self.projection = nn.Sequential(
nn.Linear(self.embedding_dim, self.embedding_dim),
nn.LayerNorm(self.embedding_dim),
)
else:
self.projection = nn.Identity()
def _freeze_layers(self, num_layers: int) -> None:
"""Freeze the first N transformer layers and embeddings."""
if num_layers <= 0:
return
# Freeze embeddings
if hasattr(self.vision_model, "embeddings"):
for param in self.vision_model.embeddings.parameters():
param.requires_grad = False
# Freeze specified number of encoder layers
if hasattr(self.vision_model, "encoder"):
for i, layer in enumerate(self.vision_model.encoder.layers):
if i < num_layers:
for param in layer.parameters():
param.requires_grad = False
def _apply_lora(
self,
r: int,
alpha: int,
dropout: float,
) -> None:
"""Apply LoRA adapters to attention layers."""
if not PEFT_AVAILABLE:
return
# Configure LoRA for vision transformer
lora_config = LoraConfig(
r=r,
lora_alpha=alpha,
lora_dropout=dropout,
target_modules=["q_proj", "v_proj"],
bias="none",
modules_to_save=[], # Don't save any full modules
)
self.vision_model = get_peft_model(self.vision_model, lora_config)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""
Extract normalized embeddings for logo images.
Args:
pixel_values: [batch, 3, 224, 224] preprocessed images
Returns:
embeddings: [batch, embedding_dim] L2-normalized
"""
# Get vision features
outputs = self.vision_model(pixel_values=pixel_values)
# Use pooler output (CLS token projection) if available
if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
features = outputs.pooler_output
else:
# Fall back to CLS token from last hidden state
features = outputs.last_hidden_state[:, 0, :]
# Apply projection head
features = self.projection(features)
# L2 normalize for cosine similarity
features = F.normalize(features, dim=-1)
return features
def get_image_features(self, **kwargs) -> torch.Tensor:
"""
Compatibility method matching CLIP's interface.
Used by DetectLogosDETR._get_embedding_pil().
"""
return self.forward(kwargs["pixel_values"])
def get_trainable_parameters(self) -> List[torch.nn.Parameter]:
"""Return list of trainable parameters."""
return [p for p in self.parameters() if p.requires_grad]
def get_parameter_count(self) -> Dict[str, int]:
"""Return count of trainable and total parameters."""
total = sum(p.numel() for p in self.parameters())
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
return {
"total": total,
"trainable": trainable,
"frozen": total - trainable,
"trainable_percent": 100 * trainable / total if total > 0 else 0,
}
def save_pretrained(self, output_dir: str) -> None:
"""
Save model in HuggingFace-compatible format.
Args:
output_dir: Directory to save model files
"""
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# Save model weights
if self.peft_applied and PEFT_AVAILABLE:
# Save LoRA weights separately
self.vision_model.save_pretrained(output_path / "vision_lora")
# Save projection head
torch.save(
self.projection.state_dict(),
output_path / "projection_head.bin",
)
else:
# Save full model state
torch.save(self.state_dict(), output_path / "pytorch_model.bin")
# Save config
config = {
"model_type": "clip_logo_finetuned",
"embedding_dim": self.embedding_dim,
"lora_r": self.lora_r,
"lora_alpha": self.lora_alpha,
"freeze_layers": self.freeze_layers,
"add_projection_head": self.add_projection_head,
"peft_applied": self.peft_applied,
}
with open(output_path / "config.json", "w") as f:
json.dump(config, f, indent=2)
@classmethod
def from_pretrained(
cls,
model_path: str,
base_model: str = "openai/clip-vit-large-patch14",
device: Optional[torch.device] = None,
) -> "LogoFineTunedCLIP":
"""
Load a fine-tuned model from saved weights.
Args:
model_path: Path to saved model directory
base_model: Base CLIP model name (for architecture)
device: Device to load model on
Returns:
Loaded LogoFineTunedCLIP model
"""
model_path = Path(model_path)
# Load config
with open(model_path / "config.json", "r") as f:
config = json.load(f)
# Load base CLIP model
clip_model = CLIPModel.from_pretrained(base_model)
# Create model instance
model = cls(
vision_model=clip_model.vision_model,
lora_r=config.get("lora_r", 0),
lora_alpha=config.get("lora_alpha", 1),
freeze_layers=config.get("freeze_layers", 12),
add_projection_head=config.get("add_projection_head", True),
use_gradient_checkpointing=False, # Not needed for inference
)
# Load weights
if config.get("peft_applied", False) and PEFT_AVAILABLE:
# Load LoRA weights
lora_path = model_path / "vision_lora"
if lora_path.exists():
model.vision_model = PeftModel.from_pretrained(
model.vision_model, lora_path
)
# Load projection head
proj_path = model_path / "projection_head.bin"
if proj_path.exists():
model.projection.load_state_dict(torch.load(proj_path))
else:
# Load full model state
weights_path = model_path / "pytorch_model.bin"
if weights_path.exists():
model.load_state_dict(torch.load(weights_path))
if device is not None:
model = model.to(device)
return model
def create_model(
base_model: str = "openai/clip-vit-large-patch14",
lora_r: int = 16,
lora_alpha: int = 32,
lora_dropout: float = 0.1,
freeze_layers: int = 12,
use_gradient_checkpointing: bool = True,
device: Optional[torch.device] = None,
) -> Tuple[LogoFineTunedCLIP, CLIPProcessor]:
"""
Create a fine-tunable CLIP model and processor.
Args:
base_model: HuggingFace model name or path
lora_r: LoRA rank (0 to disable)
lora_alpha: LoRA scaling factor
lora_dropout: LoRA dropout
freeze_layers: Number of layers to freeze
use_gradient_checkpointing: Enable gradient checkpointing
device: Device to load model on
Returns:
Tuple of (model, processor)
"""
# Load base CLIP model
clip_model = CLIPModel.from_pretrained(base_model)
processor = CLIPProcessor.from_pretrained(base_model)
# Create fine-tunable wrapper
model = LogoFineTunedCLIP(
vision_model=clip_model.vision_model,
lora_r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
freeze_layers=freeze_layers,
use_gradient_checkpointing=use_gradient_checkpointing,
)
if device is not None:
model = model.to(device)
# Print parameter info
param_info = model.get_parameter_count()
print(f"Model created:")
print(f" Total parameters: {param_info['total']:,}")
print(f" Trainable: {param_info['trainable']:,} ({param_info['trainable_percent']:.2f}%)")
print(f" Frozen: {param_info['frozen']:,}")
return model, processor
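
To smoke-test the wrapper end to end, a sketch along these lines works; it downloads the base CLIP weights on first use, and LoRA is disabled here so peft is not needed (passing lora_r=16 would attach adapters when peft is installed):

# Sketch: build the wrapped encoder and embed one synthetic image.
import numpy as np
import torch
from PIL import Image
from training.model import create_model

model, processor = create_model(lora_r=0, freeze_layers=12)
model.eval()

image = Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    embedding = model.get_image_features(pixel_values=inputs["pixel_values"])
print(embedding.shape, embedding.norm(dim=-1))  # L2-normalized output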

training/trainer.py (new file, 405 lines)
"""
Training loop with checkpointing, mixed precision, and evaluation.
"""
import json
import logging
import time
from pathlib import Path
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from .config import TrainingConfig
from .losses import get_loss_function
from .evaluation import EmbeddingEvaluator
# Check if amp is available
try:
from torch.cuda.amp import autocast, GradScaler
AMP_AVAILABLE = True
except ImportError:
AMP_AVAILABLE = False
autocast = None
GradScaler = None
class Trainer:
"""
Trainer for fine-tuning CLIP on logo recognition.
Features:
- Mixed precision training (FP16)
- Gradient accumulation
- Gradient checkpointing (via model)
- Cosine annealing LR scheduler
- Early stopping
- Checkpoint saving/loading
- Evaluation during training
"""
def __init__(
self,
model: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
config: TrainingConfig,
logger: Optional[logging.Logger] = None,
):
"""
Initialize the trainer.
Args:
model: LogoFineTunedCLIP model
train_loader: Training dataloader
val_loader: Validation dataloader
config: Training configuration
logger: Optional logger instance
"""
self.model = model
self.train_loader = train_loader
self.val_loader = val_loader
self.config = config
self.logger = logger or logging.getLogger(__name__)
# Device setup
self.device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
self.model.to(self.device)
self.logger.info(f"Using device: {self.device}")
# Optimizer - only trainable parameters
trainable_params = [p for p in model.parameters() if p.requires_grad]
self.logger.info(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")
self.optimizer = AdamW(
trainable_params,
lr=config.learning_rate,
weight_decay=config.weight_decay,
)
# Learning rate scheduler. The scheduler is stepped once per optimizer
# update (every gradient_accumulation_steps batches), so size
# total_steps to the number of optimizer updates, not raw batches.
total_steps = max(
1, (len(train_loader) * config.max_epochs) // config.gradient_accumulation_steps
)
self.scheduler = OneCycleLR(
self.optimizer,
max_lr=config.learning_rate,
total_steps=total_steps,
# Clamp so a long warmup cannot push pct_start out of OneCycleLR's valid range
pct_start=min(0.3, config.warmup_steps / total_steps),
anneal_strategy="cos",
)
# Mixed precision training
self.use_amp = config.mixed_precision and AMP_AVAILABLE and self.device.type == "cuda"
if self.use_amp:
self.scaler = GradScaler()
self.logger.info("Mixed precision training enabled")
else:
self.scaler = None
if config.mixed_precision and not AMP_AVAILABLE:
self.logger.warning("Mixed precision requested but not available")
# Loss function
self.criterion = get_loss_function(
loss_type=config.loss_type,
temperature=config.temperature,
triplet_margin=config.triplet_margin,
)
# Evaluator
self.evaluator = EmbeddingEvaluator()
# Training state
self.epoch = 0
self.global_step = 0
self.best_val_loss = float("inf")
self.best_val_separation = float("-inf")
self.patience_counter = 0
self.training_history = []
def train(self) -> Dict[str, float]:
"""
Main training loop.
Returns:
Dict with final training metrics
"""
self.logger.info("Starting training...")
self.logger.info(f" Epochs: {self.config.max_epochs}")
self.logger.info(f" Batch size: {self.config.batch_size}")
self.logger.info(f" Gradient accumulation: {self.config.gradient_accumulation_steps}")
self.logger.info(f" Effective batch: {self.config.effective_batch_size}")
self.logger.info(f" Learning rate: {self.config.learning_rate}")
start_time = time.time()
for epoch in range(self.epoch, self.config.max_epochs):
self.epoch = epoch
self.logger.info(f"\nEpoch {epoch + 1}/{self.config.max_epochs}")
# Training epoch
train_metrics = self._train_epoch()
self.logger.info(
f"Train - Loss: {train_metrics['loss']:.4f}, "
f"LR: {train_metrics['lr']:.2e}"
)
# Validation
if (epoch + 1) % self.config.eval_every_n_epochs == 0:
val_metrics = self._validate()
self.logger.info(
f"Val - Loss: {val_metrics['loss']:.4f}, "
f"Pos Sim: {val_metrics['mean_pos_sim']:.3f}, "
f"Neg Sim: {val_metrics['mean_neg_sim']:.3f}, "
f"Separation: {val_metrics['separation']:.3f}"
)
# Record history
self.training_history.append({
"epoch": epoch + 1,
"train_loss": train_metrics["loss"],
"val_loss": val_metrics["loss"],
"val_separation": val_metrics["separation"],
"val_pos_sim": val_metrics["mean_pos_sim"],
"val_neg_sim": val_metrics["mean_neg_sim"],
})
# Checkpointing based on separation (primary) or loss (secondary)
improved = False
if val_metrics["separation"] > self.best_val_separation + self.config.min_delta:
self.best_val_separation = val_metrics["separation"]
improved = True
elif val_metrics["loss"] < self.best_val_loss - self.config.min_delta:
self.best_val_loss = val_metrics["loss"]
improved = True
if improved:
self.patience_counter = 0
self._save_checkpoint("best.pt")
self.logger.info("New best model saved!")
else:
self.patience_counter += 1
# Early stopping
if self.patience_counter >= self.config.patience:
self.logger.info(
f"Early stopping triggered at epoch {epoch + 1} "
f"(no improvement for {self.config.patience} epochs)"
)
break
# Periodic checkpoint
if (epoch + 1) % self.config.save_every_n_epochs == 0:
self._save_checkpoint(f"epoch_{epoch + 1}.pt")
# Training complete
total_time = time.time() - start_time
self.logger.info(f"\nTraining completed in {total_time / 60:.1f} minutes")
# Load best model
best_path = Path(self.config.checkpoint_dir) / "best.pt"
if best_path.exists():
self.load_checkpoint("best.pt")
self.logger.info("Loaded best model checkpoint")
return {
"best_val_loss": self.best_val_loss,
"best_val_separation": self.best_val_separation,
"total_epochs": self.epoch + 1,
"total_time_minutes": total_time / 60,
}
def _train_epoch(self) -> Dict[str, float]:
"""Run a single training epoch."""
self.model.train()
total_loss = 0.0
num_batches = 0
accumulation_steps = 0
progress_bar = tqdm(
self.train_loader,
desc=f"Epoch {self.epoch + 1}",
leave=False,
)
self.optimizer.zero_grad()
for batch_idx, (images, labels) in enumerate(progress_bar):
images = images.to(self.device)
labels = labels.to(self.device)
# Forward pass with mixed precision
if self.use_amp:
with autocast():
embeddings = self.model(images)
loss = self.criterion(embeddings, labels)
loss = loss / self.config.gradient_accumulation_steps
self.scaler.scale(loss).backward()
else:
embeddings = self.model(images)
loss = self.criterion(embeddings, labels)
loss = loss / self.config.gradient_accumulation_steps
loss.backward()
accumulation_steps += 1
# Optimizer step after accumulation
if accumulation_steps >= self.config.gradient_accumulation_steps:
if self.use_amp:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
self.optimizer.step()
self.optimizer.zero_grad()
self.scheduler.step()
self.global_step += 1
accumulation_steps = 0
total_loss += loss.item() * self.config.gradient_accumulation_steps
num_batches += 1
# Update progress bar
progress_bar.set_postfix({
"loss": total_loss / num_batches,
"lr": self.scheduler.get_last_lr()[0],
})
# Logging
if (batch_idx + 1) % self.config.log_every_n_steps == 0:
self.logger.debug(
f"Step {self.global_step}: loss={total_loss / num_batches:.4f}"
)
return {
"loss": total_loss / max(num_batches, 1),
"lr": self.scheduler.get_last_lr()[0],
}
def _validate(self) -> Dict[str, float]:
"""Run validation and compute metrics."""
self.model.eval()
total_loss = 0.0
all_embeddings = []
all_labels = []
with torch.no_grad():
for images, labels in tqdm(self.val_loader, desc="Validating", leave=False):
images = images.to(self.device)
labels = labels.to(self.device)
if self.use_amp:
with autocast():
embeddings = self.model(images)
loss = self.criterion(embeddings, labels)
else:
embeddings = self.model(images)
loss = self.criterion(embeddings, labels)
total_loss += loss.item()
all_embeddings.append(embeddings.cpu())
all_labels.append(labels.cpu())
# Combine batches
all_embeddings = torch.cat(all_embeddings, dim=0)
all_labels = torch.cat(all_labels, dim=0)
# Compute embedding quality metrics
metrics = self.evaluator.compute_metrics(all_embeddings, all_labels)
metrics["loss"] = total_loss / max(len(self.val_loader), 1)
return metrics
def _save_checkpoint(self, filename: str) -> None:
"""Save training checkpoint."""
checkpoint_dir = Path(self.config.checkpoint_dir)
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint = {
"epoch": self.epoch,
"global_step": self.global_step,
"model_state_dict": self.model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"scheduler_state_dict": self.scheduler.state_dict(),
"best_val_loss": self.best_val_loss,
"best_val_separation": self.best_val_separation,
"patience_counter": self.patience_counter,
"training_history": self.training_history,
"config": self.config.__dict__,
}
if self.scaler is not None:
checkpoint["scaler_state_dict"] = self.scaler.state_dict()
torch.save(checkpoint, checkpoint_dir / filename)
self.logger.debug(f"Saved checkpoint: {filename}")
def load_checkpoint(self, filename: str) -> None:
"""Load training checkpoint."""
checkpoint_path = Path(self.config.checkpoint_dir) / filename
if not checkpoint_path.exists():
self.logger.warning(f"Checkpoint not found: {checkpoint_path}")
return
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint["model_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
self.epoch = checkpoint["epoch"]
self.global_step = checkpoint["global_step"]
self.best_val_loss = checkpoint["best_val_loss"]
self.best_val_separation = checkpoint.get("best_val_separation", float("-inf"))
self.patience_counter = checkpoint.get("patience_counter", 0)
self.training_history = checkpoint.get("training_history", [])
if self.scaler is not None and "scaler_state_dict" in checkpoint:
self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
self.logger.info(f"Resumed from epoch {self.epoch + 1}")
def export_model(self, output_dir: Optional[str] = None) -> str:
"""
Export the trained model for inference.
Args:
output_dir: Output directory (uses config.output_dir if not specified)
Returns:
Path to exported model directory
"""
output_dir = output_dir or self.config.output_dir
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# Save model
self.model.save_pretrained(output_dir)
# Save training config
config_path = output_path / "training_config.json"
with open(config_path, "w") as f:
json.dump(self.config.__dict__, f, indent=2)
# Save training history
history_path = output_path / "training_history.json"
with open(history_path, "w") as f:
json.dump(self.training_history, f, indent=2)
self.logger.info(f"Model exported to: {output_path}")
return str(output_path)
def get_training_summary(self) -> Dict:
"""Get summary of training."""
return {
"epochs_completed": self.epoch + 1,
"global_steps": self.global_step,
"best_val_loss": self.best_val_loss,
"best_val_separation": self.best_val_separation,
"history": self.training_history,
}
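
Resuming and exporting reuse the same Trainer object; here is a hedged sketch with the same assumptions about local data as the quick-start above, plus a checkpoints/best.pt left by an earlier run:

# Sketch: rebuild the trainer, resume from the best checkpoint, export.
from training import TrainingConfig, Trainer, create_dataloaders
from training.model import create_model

config = TrainingConfig.from_yaml("configs/jetson_orin.yaml")
train_loader, val_loader, _ = create_dataloaders(
    db_path=config.db_path, reference_dir=config.reference_dir
)
model, _ = create_model(base_model=config.base_model, lora_r=config.lora_r)
trainer = Trainer(model, train_loader, val_loader, config)

trainer.load_checkpoint("best.pt")   # logs a warning and continues if missing
summary = trainer.train()            # picks up from the stored epoch counter
print(summary["best_val_separation"], summary["total_epochs"])
print("exported to:", trainer.export_model())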