Add CLIP fine-tuning pipeline for logo recognition
Implement contrastive learning with LoRA to fine-tune CLIP's vision encoder on the LogoDet-3K dataset for improved logo embedding similarity.

New training module (training/):
- config.py: TrainingConfig dataclass with all hyperparameters
- dataset.py: LogoContrastiveDataset with logo-level splits
- model.py: LogoFineTunedCLIP wrapper with LoRA support
- losses.py: InfoNCE, TripletLoss, SupConLoss implementations
- trainer.py: Training loop with mixed precision and checkpointing
- evaluation.py: EmbeddingEvaluator for validation metrics

New scripts:
- train_clip_logo.py: Main training entry point
- export_model.py: Export to HuggingFace-compatible format

Configurations:
- configs/jetson_orin.yaml: Optimized for Jetson Orin AGX
- configs/cloud_rtx4090.yaml: Optimized for 24GB cloud GPUs
- configs/cloud_a100.yaml: Optimized for 80GB cloud GPUs

Documentation:
- CLIP_FINETUNING.md: Training guide and usage instructions
- CLOUD_TRAINING.md: Cloud GPU recommendations and cost estimates

Modified:
- logo_detection_detr.py: Add fine-tuned model loading support
- pyproject.toml: Add peft, pyyaml, torchvision dependencies
training/__init__.py  (new file, 24 lines)
@@ -0,0 +1,24 @@
"""
CLIP fine-tuning module for logo recognition.

This module provides tools for fine-tuning CLIP's vision encoder using
contrastive learning on the LogoDet-3K dataset.
"""

from .config import TrainingConfig
from .dataset import LogoContrastiveDataset, create_dataloaders
from .model import LogoFineTunedCLIP
from .losses import InfoNCELoss, TripletLoss
from .trainer import Trainer
from .evaluation import EmbeddingEvaluator

__all__ = [
    "TrainingConfig",
    "LogoContrastiveDataset",
    "create_dataloaders",
    "LogoFineTunedCLIP",
    "InfoNCELoss",
    "TripletLoss",
    "Trainer",
    "EmbeddingEvaluator",
]
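
For orientation, a minimal sketch of how these exports are meant to fit together (presumably what train_clip_logo.py wraps). The database and directory paths are placeholders; note that create_model lives in training.model and is not re-exported by this __init__.

import torch

from training import TrainingConfig, Trainer, create_dataloaders
from training.model import create_model

# Placeholder paths; substitute your own dataset locations.
config = TrainingConfig(db_path="test_data_mapping.db", reference_dir="reference_logos")
for warning in config.validate():
    print(f"Config warning: {warning}")

train_loader, val_loader, test_loader = create_dataloaders(
    db_path=config.db_path,
    reference_dir=config.reference_dir,
    logos_per_batch=config.logos_per_batch,
    samples_per_logo=config.samples_per_logo,
    num_workers=config.num_workers,
    augmentation_strength=config.augmentation_strength,
    seed=config.seed,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, processor = create_model(
    base_model=config.base_model,
    lora_r=config.lora_r,
    lora_alpha=config.lora_alpha,
    lora_dropout=config.lora_dropout,
    freeze_layers=config.freeze_layers,
    device=device,
)

trainer = Trainer(model, train_loader, val_loader, config)
trainer.train()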
training/config.py  (new file, 141 lines)
@@ -0,0 +1,141 @@
"""
Training configuration for CLIP fine-tuning.
"""

from dataclasses import dataclass
from pathlib import Path
from typing import List

import yaml


@dataclass
class TrainingConfig:
    """Configuration for CLIP logo fine-tuning."""

    # Base model
    base_model: str = "openai/clip-vit-large-patch14"

    # Dataset paths
    dataset_dir: str = "LogoDet-3K"
    reference_dir: str = "reference_logos"
    db_path: str = "test_data_mapping.db"

    # Data split ratios
    train_split: float = 0.7
    val_split: float = 0.15
    test_split: float = 0.15

    # Batch construction
    batch_size: int = 16
    logos_per_batch: int = 32
    samples_per_logo: int = 4
    gradient_accumulation_steps: int = 8
    num_workers: int = 4

    # Model architecture
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.1
    freeze_layers: int = 12
    use_gradient_checkpointing: bool = True

    # Training hyperparameters
    learning_rate: float = 1e-5
    weight_decay: float = 0.01
    warmup_steps: int = 500
    max_epochs: int = 20
    mixed_precision: bool = True

    # Loss function
    temperature: float = 0.07
    loss_type: str = "infonce"  # "infonce" or "triplet"
    triplet_margin: float = 0.3

    # Early stopping
    patience: int = 5
    min_delta: float = 0.001

    # Checkpoints and output
    checkpoint_dir: str = "checkpoints"
    output_dir: str = "models/logo_detection/clip_finetuned"
    save_every_n_epochs: int = 5

    # Logging
    log_every_n_steps: int = 10
    eval_every_n_epochs: int = 1

    # Random seed for reproducibility
    seed: int = 42

    # Hard negative mining
    use_hard_negatives: bool = False
    hard_negative_start_epoch: int = 5
    hard_negatives_per_logo: int = 10

    # Data augmentation
    use_augmentation: bool = True
    augmentation_strength: str = "medium"  # "light", "medium", "strong"

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "TrainingConfig":
        """Load configuration from YAML file."""
        with open(yaml_path, "r") as f:
            config_dict = yaml.safe_load(f)
        return cls(**config_dict)

    def to_yaml(self, yaml_path: str) -> None:
        """Save configuration to YAML file."""
        Path(yaml_path).parent.mkdir(parents=True, exist_ok=True)
        with open(yaml_path, "w") as f:
            yaml.dump(self.__dict__, f, default_flow_style=False, sort_keys=False)

    def validate(self) -> List[str]:
        """Validate configuration and return list of warnings."""
        warnings = []

        # Check split ratios
        total_split = self.train_split + self.val_split + self.test_split
        if abs(total_split - 1.0) > 0.01:
            warnings.append(
                f"Split ratios sum to {total_split}, expected 1.0"
            )

        # Check batch construction
        effective_batch = self.batch_size * self.gradient_accumulation_steps
        if effective_batch < 64:
            warnings.append(
                f"Effective batch size ({effective_batch}) is small for contrastive learning. "
                "Consider increasing batch_size or gradient_accumulation_steps."
            )

        # Check LoRA config
        if self.lora_r > 0 and self.lora_alpha < self.lora_r:
            warnings.append(
                f"lora_alpha ({self.lora_alpha}) < lora_r ({self.lora_r}). "
                "This may reduce LoRA effectiveness."
            )

        # Check freeze layers
        if self.freeze_layers < 0:
            warnings.append("freeze_layers should be >= 0")

        # Check temperature
        if self.temperature <= 0:
            warnings.append("temperature must be positive")
        elif self.temperature > 1.0:
            warnings.append(
                f"temperature ({self.temperature}) is high. "
                "Typical values are 0.05-0.1."
            )

        return warnings

    @property
    def effective_batch_size(self) -> int:
        """Calculate effective batch size with gradient accumulation."""
        return self.batch_size * self.gradient_accumulation_steps

    @property
    def samples_per_batch(self) -> int:
        """Total samples in one batch (logos_per_batch * samples_per_logo)."""
        return self.logos_per_batch * self.samples_per_logo
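
Because from_yaml() passes the parsed mapping straight to the dataclass constructor, a YAML config (such as the configs/*.yaml files added in this commit) is just a flat mapping of field names, and unknown keys raise a TypeError. A round-trip sketch, with illustrative values:

from training.config import TrainingConfig

config = TrainingConfig(
    batch_size=8,
    gradient_accumulation_steps=16,  # keeps the effective batch at 128
    lora_r=16,
    loss_type="infonce",
)
config.to_yaml("configs/example.yaml")  # writes every field as a flat mapping

loaded = TrainingConfig.from_yaml("configs/example.yaml")
assert loaded == config                 # dataclass equality holds field-wise
print(loaded.effective_batch_size)      # 128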
training/dataset.py  (new file, 467 lines)
@@ -0,0 +1,467 @@
"""
Dataset classes for contrastive learning on logo images.
"""

import random
import sqlite3
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms


# CLIP normalization values
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]


def get_train_transforms(strength: str = "medium") -> transforms.Compose:
    """
    Get training data augmentation transforms.

    Args:
        strength: Augmentation strength - "light", "medium", or "strong"

    Returns:
        Composed transforms for training
    """
    if strength == "light":
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
        ])
    elif strength == "medium":
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05
            ),
            transforms.RandomAffine(
                degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)
            ),
            transforms.RandomGrayscale(p=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
        ])
    else:  # strong
        return transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.1),
            transforms.RandomRotation(degrees=30),
            transforms.ColorJitter(
                brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1
            ),
            transforms.RandomAffine(
                degrees=0, translate=(0.15, 0.15), scale=(0.8, 1.2), shear=10
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
        ])


def get_val_transforms() -> transforms.Compose:
    """Get validation/test transforms (no augmentation)."""
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
    ])


class LogoDataset:
    """
    Manages logo data from the SQLite database.

    Handles loading logo-to-image mappings and splitting by logo brand.
    """

    def __init__(
        self,
        db_path: str,
        reference_dir: str,
        train_split: float = 0.7,
        val_split: float = 0.15,
        test_split: float = 0.15,
        seed: int = 42,
    ):
        self.db_path = Path(db_path)
        self.reference_dir = Path(reference_dir)
        self.seed = seed

        # Load logo-to-images mapping from database
        self.logo_to_images = self._load_logo_mappings()
        self.all_logos = list(self.logo_to_images.keys())

        # Create logo-level splits
        self.train_logos, self.val_logos, self.test_logos = self._split_logos(
            train_split, val_split, test_split
        )

    def _load_logo_mappings(self) -> Dict[str, List[Path]]:
        """Load logo name to image paths mapping from database."""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        cursor.execute("""
            SELECT ln.name, rl.filename
            FROM reference_logos rl
            JOIN logo_names ln ON rl.logo_name_id = ln.id
            ORDER BY ln.name
        """)

        logo_to_images: Dict[str, List[Path]] = {}
        for logo_name, filename in cursor.fetchall():
            if logo_name not in logo_to_images:
                logo_to_images[logo_name] = []
            logo_to_images[logo_name].append(self.reference_dir / filename)

        conn.close()
        return logo_to_images

    def _split_logos(
        self,
        train_split: float,
        val_split: float,
        test_split: float,
    ) -> Tuple[List[str], List[str], List[str]]:
        """Split logos at brand level for train/val/test."""
        random.seed(self.seed)
        logos = self.all_logos.copy()
        random.shuffle(logos)

        n = len(logos)
        train_end = int(n * train_split)
        val_end = train_end + int(n * val_split)

        train_logos = logos[:train_end]
        val_logos = logos[train_end:val_end]
        test_logos = logos[val_end:]

        return train_logos, val_logos, test_logos

    def get_split_info(self) -> Dict[str, int]:
        """Return information about the splits."""
        return {
            "total_logos": len(self.all_logos),
            "train_logos": len(self.train_logos),
            "val_logos": len(self.val_logos),
            "test_logos": len(self.test_logos),
            "train_images": sum(
                len(self.logo_to_images[l]) for l in self.train_logos
            ),
            "val_images": sum(
                len(self.logo_to_images[l]) for l in self.val_logos
            ),
            "test_images": sum(
                len(self.logo_to_images[l]) for l in self.test_logos
            ),
        }


class LogoContrastiveDataset(Dataset):
    """
    Dataset for contrastive learning on logos.

    Each __getitem__ call returns a batch of images organized for contrastive
    learning: K different logos with M samples each, ensuring positive pairs
    exist within each batch.
    """

    def __init__(
        self,
        logo_data: LogoDataset,
        split: str = "train",
        logos_per_batch: int = 32,
        samples_per_logo: int = 4,
        transform: Optional[transforms.Compose] = None,
        batches_per_epoch: int = 1000,
    ):
        """
        Initialize the contrastive dataset.

        Args:
            logo_data: LogoDataset instance with logo mappings
            split: One of "train", "val", or "test"
            logos_per_batch: Number of different logos per batch
            samples_per_logo: Number of samples for each logo
            transform: Image transforms to apply
            batches_per_epoch: Number of batches per epoch
        """
        self.logo_data = logo_data
        self.logos_per_batch = logos_per_batch
        self.samples_per_logo = samples_per_logo
        self.transform = transform
        self.batches_per_epoch = batches_per_epoch

        # Get logos for this split
        if split == "train":
            self.logos = logo_data.train_logos
        elif split == "val":
            self.logos = logo_data.val_logos
        else:
            self.logos = logo_data.test_logos

        # Filter logos with enough samples
        self.valid_logos = [
            logo for logo in self.logos
            if len(logo_data.logo_to_images[logo]) >= samples_per_logo
        ]

        # For logos with fewer samples, we'll sample with replacement
        self.logos_needing_replacement = [
            logo for logo in self.logos
            if len(logo_data.logo_to_images[logo]) < samples_per_logo
        ]

        # Create label mapping
        self.logo_to_label = {
            logo: idx for idx, logo in enumerate(self.logos)
        }

    def __len__(self) -> int:
        return self.batches_per_epoch

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a batch of images for contrastive learning.

        Returns:
            images: Tensor of shape [K*M, 3, 224, 224]
            labels: Tensor of shape [K*M] with logo class indices
        """
        images = []
        labels = []

        # Sample K logos for this batch
        k = min(self.logos_per_batch, len(self.logos))
        batch_logos = random.sample(self.logos, k)

        for logo in batch_logos:
            logo_images = self.logo_data.logo_to_images[logo]

            # Sample M images for this logo
            if len(logo_images) >= self.samples_per_logo:
                sampled_paths = random.sample(logo_images, self.samples_per_logo)
            else:
                # Sample with replacement if not enough images
                sampled_paths = random.choices(
                    logo_images, k=self.samples_per_logo
                )

            # Load and transform images
            for img_path in sampled_paths:
                try:
                    img = Image.open(img_path).convert("RGB")
                    if self.transform:
                        img = self.transform(img)
                    else:
                        img = get_val_transforms()(img)
                    images.append(img)
                    labels.append(self.logo_to_label[logo])
                except Exception:
                    # Skip unreadable or corrupt images; the batch simply
                    # ends up slightly smaller.
                    continue

        # Stack into tensors
        if len(images) == 0:
            # Fallback: return dummy batch
            return (
                torch.zeros(1, 3, 224, 224),
                torch.zeros(1, dtype=torch.long),
            )

        images_tensor = torch.stack(images)
        labels_tensor = torch.tensor(labels, dtype=torch.long)

        return images_tensor, labels_tensor


class BalancedBatchSampler(Sampler):
    """
    Sampler that ensures each batch has a balanced distribution of logos.

    Used with a flattened dataset where each sample is a single image.
    """

    def __init__(
        self,
        logo_labels: List[int],
        logos_per_batch: int,
        samples_per_logo: int,
        num_batches: int,
    ):
        self.logo_labels = logo_labels
        self.logos_per_batch = logos_per_batch
        self.samples_per_logo = samples_per_logo
        self.num_batches = num_batches

        # Group indices by logo
        self.logo_to_indices: Dict[int, List[int]] = {}
        for idx, label in enumerate(logo_labels):
            if label not in self.logo_to_indices:
                self.logo_to_indices[label] = []
            self.logo_to_indices[label].append(idx)

        self.all_logos = list(self.logo_to_indices.keys())

    def __iter__(self):
        for _ in range(self.num_batches):
            batch_indices = []

            # Sample logos for this batch
            logos = random.sample(
                self.all_logos,
                min(self.logos_per_batch, len(self.all_logos)),
            )

            for logo in logos:
                indices = self.logo_to_indices[logo]
                if len(indices) >= self.samples_per_logo:
                    sampled = random.sample(indices, self.samples_per_logo)
                else:
                    sampled = random.choices(indices, k=self.samples_per_logo)
                batch_indices.extend(sampled)

            yield batch_indices

    def __len__(self):
        return self.num_batches


def create_dataloaders(
    db_path: str,
    reference_dir: str,
    batch_size: int = 16,
    logos_per_batch: int = 32,
    samples_per_logo: int = 4,
    num_workers: int = 4,
    train_split: float = 0.7,
    val_split: float = 0.15,
    test_split: float = 0.15,
    seed: int = 42,
    augmentation_strength: str = "medium",
    batches_per_epoch: int = 1000,
) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
    """
    Create train, validation, and optionally test dataloaders.

    Args:
        db_path: Path to SQLite database
        reference_dir: Directory containing reference logo images
        batch_size: Not used directly (see logos_per_batch and samples_per_logo)
        logos_per_batch: Number of different logos per batch
        samples_per_logo: Samples per logo in batch
        num_workers: Number of data loading workers
        train_split: Fraction for training
        val_split: Fraction for validation
        test_split: Fraction for testing
        seed: Random seed
        augmentation_strength: "light", "medium", or "strong"
        batches_per_epoch: Number of batches per training epoch

    Returns:
        Tuple of (train_loader, val_loader, test_loader)
    """
    # Load logo data
    logo_data = LogoDataset(
        db_path=db_path,
        reference_dir=reference_dir,
        train_split=train_split,
        val_split=val_split,
        test_split=test_split,
        seed=seed,
    )

    # Print split info
    split_info = logo_data.get_split_info()
    print("Dataset loaded:")
    print(f"  Total logos: {split_info['total_logos']}")
    print(f"  Train: {split_info['train_logos']} logos, {split_info['train_images']} images")
    print(f"  Val: {split_info['val_logos']} logos, {split_info['val_images']} images")
    print(f"  Test: {split_info['test_logos']} logos, {split_info['test_images']} images")

    # Create datasets
    train_dataset = LogoContrastiveDataset(
        logo_data=logo_data,
        split="train",
        logos_per_batch=logos_per_batch,
        samples_per_logo=samples_per_logo,
        transform=get_train_transforms(augmentation_strength),
        batches_per_epoch=batches_per_epoch,
    )

    val_dataset = LogoContrastiveDataset(
        logo_data=logo_data,
        split="val",
        logos_per_batch=logos_per_batch,
        samples_per_logo=samples_per_logo,
        transform=get_val_transforms(),
        batches_per_epoch=batches_per_epoch // 10,  # Fewer val batches
    )

    test_dataset = LogoContrastiveDataset(
        logo_data=logo_data,
        split="test",
        logos_per_batch=logos_per_batch,
        samples_per_logo=samples_per_logo,
        transform=get_val_transforms(),
        batches_per_epoch=batches_per_epoch // 10,
    ) if test_split > 0 else None

    # Create dataloaders
    # Note: batch_size=1 because each __getitem__ already returns a batch
    train_loader = DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=_collate_contrastive_batch,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=_collate_contrastive_batch,
    )

    test_loader = None
    if test_dataset is not None:
        test_loader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=_collate_contrastive_batch,
        )

    return train_loader, val_loader, test_loader


def _collate_contrastive_batch(
    batch: List[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Collate function that unpacks pre-batched data.

    Since LogoContrastiveDataset already returns batched data,
    we just squeeze the outer dimension.
    """
    images, labels = batch[0]
    return images, labels
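
A usage sketch showing what one iteration yields. Each __getitem__ already builds a complete contrastive batch, which is why the DataLoader runs with batch_size=1 and _collate_contrastive_batch unwraps it. The paths below are placeholders.

from training.dataset import create_dataloaders

train_loader, val_loader, test_loader = create_dataloaders(
    db_path="test_data_mapping.db",   # placeholder
    reference_dir="reference_logos",  # placeholder
    logos_per_batch=32,
    samples_per_logo=4,
    num_workers=0,                    # 0 is simplest to debug; raise for throughput
    batches_per_epoch=10,
)

images, labels = next(iter(train_loader))
# Up to logos_per_batch * samples_per_logo samples per batch:
print(images.shape)  # e.g. torch.Size([128, 3, 224, 224])
print(labels.shape)  # torch.Size([128]); samples_per_logo entries share each label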
training/evaluation.py  (new file, 339 lines)
@@ -0,0 +1,339 @@
"""
Evaluation metrics for embedding quality.
"""

from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
import numpy as np


class EmbeddingEvaluator:
    """
    Evaluator for embedding quality metrics.

    Computes metrics that indicate how well the embeddings
    separate different logo classes.
    """

    def compute_metrics(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> Dict[str, float]:
        """
        Compute embedding quality metrics.

        Args:
            embeddings: [N, D] L2-normalized embeddings
            labels: [N] integer class labels

        Returns:
            Dict with metric names and values
        """
        device = embeddings.device
        batch_size = embeddings.shape[0]

        if batch_size <= 1:
            return {
                "mean_pos_sim": 0.0,
                "mean_neg_sim": 0.0,
                "separation": 0.0,
                "recall_at_1": 0.0,
                "recall_at_5": 0.0,
            }

        # Compute similarity matrix
        similarity = embeddings @ embeddings.T

        # Create masks
        labels_col = labels.unsqueeze(0)
        labels_row = labels.unsqueeze(1)
        positive_mask = (labels_row == labels_col).float()
        negative_mask = 1 - positive_mask

        # Remove diagonal from positive mask
        identity = torch.eye(batch_size, device=device)
        positive_mask = positive_mask - identity

        # Count pairs
        num_positives = positive_mask.sum()
        num_negatives = negative_mask.sum()

        # Mean positive similarity (excluding self)
        if num_positives > 0:
            pos_sims = (similarity * positive_mask).sum() / num_positives
            mean_pos_sim = pos_sims.item()
        else:
            mean_pos_sim = 0.0

        # Mean negative similarity
        if num_negatives > 0:
            neg_sims = (similarity * negative_mask).sum() / num_negatives
            mean_neg_sim = neg_sims.item()
        else:
            mean_neg_sim = 0.0

        # Separation: gap between positive and negative similarity
        separation = mean_pos_sim - mean_neg_sim

        # Recall@K metrics
        recall_at_1 = self._compute_recall_at_k(similarity, labels, k=1)
        recall_at_5 = self._compute_recall_at_k(similarity, labels, k=5)

        return {
            "mean_pos_sim": mean_pos_sim,
            "mean_neg_sim": mean_neg_sim,
            "separation": separation,
            "recall_at_1": recall_at_1,
            "recall_at_5": recall_at_5,
        }

    def _compute_recall_at_k(
        self,
        similarity: torch.Tensor,
        labels: torch.Tensor,
        k: int = 1,
    ) -> float:
        """
        Compute Recall@K for nearest neighbor retrieval.

        For each sample, check if the k nearest neighbors (excluding self)
        contain at least one sample with the same label.

        Args:
            similarity: [N, N] similarity matrix
            labels: [N] class labels
            k: Number of neighbors to consider

        Returns:
            Recall@K score (0 to 1)
        """
        batch_size = similarity.shape[0]
        if batch_size <= 1:
            return 0.0

        # Mask out self-similarity
        similarity = similarity.clone()
        similarity.fill_diagonal_(float("-inf"))

        # Get top-k indices
        _, top_k_indices = similarity.topk(min(k, batch_size - 1), dim=1)

        # Check if any of top-k have same label
        correct = 0
        for i in range(batch_size):
            query_label = labels[i]
            retrieved_labels = labels[top_k_indices[i]]
            if (retrieved_labels == query_label).any():
                correct += 1

        return correct / batch_size

    def compute_detailed_metrics(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
        label_names: Optional[List[str]] = None,
    ) -> Dict:
        """
        Compute detailed per-class metrics.

        Args:
            embeddings: [N, D] embeddings
            labels: [N] class labels
            label_names: Optional list of label names

        Returns:
            Dict with detailed metrics including per-class stats
        """
        basic_metrics = self.compute_metrics(embeddings, labels)

        # Per-class statistics
        unique_labels = labels.unique()
        per_class_stats = {}

        similarity = embeddings @ embeddings.T

        for label in unique_labels:
            mask = labels == label
            class_embeddings = embeddings[mask]
            class_size = mask.sum().item()

            if class_size > 1:
                # Intra-class similarity
                class_sim = class_embeddings @ class_embeddings.T
                # Exclude diagonal
                mask_diag = ~torch.eye(class_size, dtype=torch.bool, device=class_sim.device)
                intra_sim = class_sim[mask_diag].mean().item()
            else:
                intra_sim = 1.0

            # Inter-class similarity (to other classes)
            other_mask = labels != label
            if other_mask.any():
                inter_sim = similarity[mask][:, other_mask].mean().item()
            else:
                inter_sim = 0.0

            class_name = label_names[label.item()] if label_names else str(label.item())
            per_class_stats[class_name] = {
                "size": class_size,
                "intra_class_sim": intra_sim,
                "inter_class_sim": inter_sim,
                "class_separation": intra_sim - inter_sim,
            }

        # Aggregate per-class stats
        if per_class_stats:
            separations = [s["class_separation"] for s in per_class_stats.values()]
            min_separation = min(separations)
            max_separation = max(separations)
            std_separation = np.std(separations)
        else:
            min_separation = max_separation = std_separation = 0.0

        return {
            **basic_metrics,
            "per_class": per_class_stats,
            "min_class_separation": min_separation,
            "max_class_separation": max_separation,
            "std_class_separation": std_separation,
        }


class SimilarityAnalyzer:
    """
    Analyze similarity distributions for debugging and tuning.
    """

    @staticmethod
    def analyze_similarity_distribution(
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> Dict[str, np.ndarray]:
        """
        Get similarity distributions for positive and negative pairs.

        Useful for choosing appropriate thresholds.

        Args:
            embeddings: [N, D] embeddings
            labels: [N] class labels

        Returns:
            Dict with 'positive_sims' and 'negative_sims' arrays
        """
        similarity = (embeddings @ embeddings.T).cpu().numpy()
        labels_np = labels.cpu().numpy()

        batch_size = len(labels_np)
        positive_sims = []
        negative_sims = []

        for i in range(batch_size):
            for j in range(i + 1, batch_size):
                if labels_np[i] == labels_np[j]:
                    positive_sims.append(similarity[i, j])
                else:
                    negative_sims.append(similarity[i, j])

        return {
            "positive_sims": np.array(positive_sims),
            "negative_sims": np.array(negative_sims),
        }

    @staticmethod
    def find_hard_pairs(
        embeddings: torch.Tensor,
        labels: torch.Tensor,
        n_hard: int = 10,
    ) -> Tuple[List[Tuple[int, int, float]], List[Tuple[int, int, float]]]:
        """
        Find hardest positive and negative pairs.

        Hard positives: same label but low similarity
        Hard negatives: different label but high similarity

        Args:
            embeddings: [N, D] embeddings
            labels: [N] class labels
            n_hard: Number of hard pairs to return

        Returns:
            Tuple of (hard_positives, hard_negatives)
            Each is a list of (idx1, idx2, similarity) tuples
        """
        similarity = embeddings @ embeddings.T
        batch_size = len(labels)

        hard_positives = []  # Low similarity, same label
        hard_negatives = []  # High similarity, different label

        for i in range(batch_size):
            for j in range(i + 1, batch_size):
                sim = similarity[i, j].item()
                if labels[i] == labels[j]:
                    hard_positives.append((i, j, sim))
                else:
                    hard_negatives.append((i, j, sim))

        # Sort: hard positives by ascending similarity (lowest first)
        hard_positives.sort(key=lambda x: x[2])

        # Sort: hard negatives by descending similarity (highest first)
        hard_negatives.sort(key=lambda x: -x[2])

        return hard_positives[:n_hard], hard_negatives[:n_hard]

    @staticmethod
    def compute_confusion_pairs(
        embeddings: torch.Tensor,
        labels: torch.Tensor,
        label_names: Optional[List[str]] = None,
        top_k: int = 10,
    ) -> List[Dict]:
        """
        Find pairs of classes that are most confused (highest cross-class similarity).

        Args:
            embeddings: [N, D] embeddings
            labels: [N] class labels
            label_names: Optional label names
            top_k: Number of confused pairs to return

        Returns:
            List of dicts with class pairs and their similarity
        """
        unique_labels = labels.unique()
        class_centroids = {}

        # Compute class centroids
        for label in unique_labels:
            mask = labels == label
            centroid = embeddings[mask].mean(dim=0)
            centroid = F.normalize(centroid, dim=0)
            class_centroids[label.item()] = centroid

        # Compute pairwise centroid similarities
        confusions = []
        label_list = list(class_centroids.keys())

        for i, label1 in enumerate(label_list):
            for label2 in label_list[i + 1:]:
                sim = (class_centroids[label1] @ class_centroids[label2]).item()
                name1 = label_names[label1] if label_names else str(label1)
                name2 = label_names[label2] if label_names else str(label2)
                confusions.append({
                    "class1": name1,
                    "class2": name2,
                    "label1": label1,
                    "label2": label2,
                    "centroid_similarity": sim,
                })

        # Sort by similarity (highest first)
        confusions.sort(key=lambda x: -x["centroid_similarity"])

        return confusions[:top_k]
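
A sanity-check sketch for the evaluator on synthetic embeddings: two samples per class are drawn near a shared random direction, so positive similarity should clearly exceed negative similarity and both separation and Recall@1 should come out near 1.0.

import torch
import torch.nn.functional as F

from training.evaluation import EmbeddingEvaluator

torch.manual_seed(0)
num_classes, per_class, dim = 8, 2, 768
centers = F.normalize(torch.randn(num_classes, dim), dim=-1)
embeddings = F.normalize(
    centers.repeat_interleave(per_class, dim=0)
    + 0.05 * torch.randn(num_classes * per_class, dim),
    dim=-1,
)
labels = torch.arange(num_classes).repeat_interleave(per_class)

metrics = EmbeddingEvaluator().compute_metrics(embeddings, labels)
print(f"separation={metrics['separation']:.3f}, R@1={metrics['recall_at_1']:.2f}")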
training/losses.py  (new file, 326 lines)
@@ -0,0 +1,326 @@
"""
Loss functions for contrastive learning on logo embeddings.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class InfoNCELoss(nn.Module):
    """
    Normalized Temperature-scaled Cross Entropy Loss (InfoNCE).

    This is the contrastive loss used in CLIP training. It maximizes
    similarity between embeddings of the same logo class while
    minimizing similarity to embeddings of different classes.

    For a batch with N samples:
    - Each sample is an anchor
    - Positive pairs: samples with the same label
    - Negative pairs: samples with different labels

    The loss for each anchor i with positive set P(i) is:
        L_i = -1/|P(i)| * sum_{p in P(i)} log(
            exp(sim(i, p)/temp) / sum_{a != i} exp(sim(i, a)/temp)
        )
    """

    def __init__(self, temperature: float = 0.07):
        """
        Initialize InfoNCE loss.

        Args:
            temperature: Scaling factor for similarities (0.05-0.1 typical).
                Lower temperature makes the distribution sharper.
        """
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute InfoNCE loss for a batch of embeddings.

        Args:
            embeddings: [N, D] L2-normalized embeddings
            labels: [N] integer logo class labels

        Returns:
            Scalar loss value
        """
        device = embeddings.device
        batch_size = embeddings.shape[0]

        if batch_size <= 1:
            return torch.tensor(0.0, device=device, requires_grad=True)

        # Compute similarity matrix [N, N]
        # Since embeddings are L2-normalized, dot product = cosine similarity
        similarity = embeddings @ embeddings.T / self.temperature

        # Create positive mask: same label = 1, different = 0
        labels_col = labels.unsqueeze(0)  # [1, N]
        labels_row = labels.unsqueeze(1)  # [N, 1]
        positive_mask = (labels_row == labels_col).float()  # [N, N]

        # Remove self-similarity from positives (diagonal)
        identity = torch.eye(batch_size, device=device)
        positive_mask = positive_mask - identity

        # Count positives per anchor (avoid division by zero)
        num_positives = positive_mask.sum(dim=1)
        has_positives = num_positives > 0

        # If no positives exist for any anchor, return zero loss
        if not has_positives.any():
            return torch.tensor(0.0, device=device, requires_grad=True)

        # Mask out self-similarity with large negative value
        similarity = similarity - identity * 1e9

        # Compute log-softmax over similarities
        log_softmax = F.log_softmax(similarity, dim=1)

        # Sum log probabilities of positive pairs
        positive_log_probs = (log_softmax * positive_mask).sum(dim=1)

        # Average over number of positives (only for anchors with positives)
        loss_per_anchor = torch.zeros(batch_size, device=device)
        loss_per_anchor[has_positives] = (
            -positive_log_probs[has_positives] / num_positives[has_positives]
        )

        return loss_per_anchor.mean()


class SupConLoss(nn.Module):
    """
    Supervised Contrastive Loss.

    Similar to InfoNCE but uses a different formulation that
    considers each positive pair separately rather than averaging.

    Reference: https://arxiv.org/abs/2004.11362
    """

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute Supervised Contrastive loss.

        Args:
            embeddings: [N, D] L2-normalized embeddings
            labels: [N] integer logo class labels

        Returns:
            Scalar loss value
        """
        device = embeddings.device
        batch_size = embeddings.shape[0]

        if batch_size <= 1:
            return torch.tensor(0.0, device=device, requires_grad=True)

        # Compute similarity matrix
        similarity = embeddings @ embeddings.T / self.temperature

        # Create masks
        labels_col = labels.unsqueeze(0)
        labels_row = labels.unsqueeze(1)
        positive_mask = (labels_row == labels_col).float()
        identity = torch.eye(batch_size, device=device)

        # Remove self from positives
        positive_mask = positive_mask - identity

        # Number of positives per anchor
        num_positives = positive_mask.sum(dim=1)
        has_positives = num_positives > 0

        if not has_positives.any():
            return torch.tensor(0.0, device=device, requires_grad=True)

        # For numerical stability, subtract max similarity
        sim_max, _ = similarity.max(dim=1, keepdim=True)
        similarity = similarity - sim_max.detach()

        # Compute exp(similarity) with self masked out
        exp_sim = torch.exp(similarity) * (1 - identity)

        # Denominator: sum of exp over all pairs except self
        log_prob = similarity - torch.log(exp_sim.sum(dim=1, keepdim=True) + 1e-8)

        # Mean of log-prob over positive pairs
        mean_log_prob_pos = (positive_mask * log_prob).sum(dim=1) / (
            num_positives + 1e-8
        )

        # Loss is negative mean log probability
        loss = -mean_log_prob_pos[has_positives].mean()

        return loss


class TripletLoss(nn.Module):
    """
    Triplet loss with online hard mining.

    For each anchor:
    - Hardest positive: most distant sample with same label
    - Hardest negative: closest sample with different label

    Loss = max(0, d(anchor, hardest_pos) - d(anchor, hardest_neg) + margin)

    This is an alternative to InfoNCE for when batch sizes are small.
    """

    def __init__(self, margin: float = 0.3):
        """
        Initialize Triplet loss.

        Args:
            margin: Minimum required gap between positive and negative distances
        """
        super().__init__()
        self.margin = margin

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute triplet loss with online hard mining.

        Args:
            embeddings: [N, D] L2-normalized embeddings
            labels: [N] integer logo class labels

        Returns:
            Scalar loss value
        """
        device = embeddings.device
        batch_size = embeddings.shape[0]

        if batch_size <= 1:
            return torch.tensor(0.0, device=device, requires_grad=True)

        # Compute pairwise cosine distances (1 - cosine_similarity)
        # For normalized vectors: distance = 1 - dot_product
        similarity = embeddings @ embeddings.T
        distances = 1 - similarity

        # Create masks
        labels_col = labels.unsqueeze(0)
        labels_row = labels.unsqueeze(1)
        positive_mask = (labels_row == labels_col).float()
        negative_mask = 1 - positive_mask

        # Remove self from positives (diagonal)
        identity = torch.eye(batch_size, device=device)
        positive_mask = positive_mask - identity

        # Check if we have any valid triplets
        has_positives = positive_mask.sum(dim=1) > 0
        has_negatives = negative_mask.sum(dim=1) > 0
        valid_anchors = has_positives & has_negatives

        if not valid_anchors.any():
            return torch.tensor(0.0, device=device, requires_grad=True)

        # For each anchor, find hardest positive (max distance among positives)
        # Set negatives to -inf so they don't affect max
        pos_distances = distances.clone()
        pos_distances[positive_mask == 0] = float("-inf")
        hardest_positive, _ = pos_distances.max(dim=1)

        # For each anchor, find hardest negative (min distance among negatives)
        # Set positives to inf so they don't affect min
        neg_distances = distances.clone()
        neg_distances[negative_mask == 0] = float("inf")
        hardest_negative, _ = neg_distances.min(dim=1)

        # Triplet loss: want positive to be closer than negative by margin
        triplet_loss = F.relu(
            hardest_positive - hardest_negative + self.margin
        )

        # Average over valid anchors only
        loss = triplet_loss[valid_anchors].mean()

        return loss


class CombinedLoss(nn.Module):
    """
    Combined loss function with weighted InfoNCE and Triplet losses.

    Can help stabilize training by combining the benefits of both losses.
    """

    def __init__(
        self,
        temperature: float = 0.07,
        triplet_margin: float = 0.3,
        infonce_weight: float = 1.0,
        triplet_weight: float = 0.5,
    ):
        super().__init__()
        self.infonce = InfoNCELoss(temperature=temperature)
        self.triplet = TripletLoss(margin=triplet_margin)
        self.infonce_weight = infonce_weight
        self.triplet_weight = triplet_weight

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        infonce_loss = self.infonce(embeddings, labels)
        triplet_loss = self.triplet(embeddings, labels)

        return (
            self.infonce_weight * infonce_loss +
            self.triplet_weight * triplet_loss
        )


def get_loss_function(
    loss_type: str = "infonce",
    temperature: float = 0.07,
    triplet_margin: float = 0.3,
) -> nn.Module:
    """
    Factory function to create loss function.

    Args:
        loss_type: One of "infonce", "supcon", "triplet", or "combined"
        temperature: Temperature for InfoNCE/SupCon
        triplet_margin: Margin for triplet loss

    Returns:
        Loss function module
    """
    if loss_type == "infonce":
        return InfoNCELoss(temperature=temperature)
    elif loss_type == "supcon":
        return SupConLoss(temperature=temperature)
    elif loss_type == "triplet":
        return TripletLoss(margin=triplet_margin)
    elif loss_type == "combined":
        return CombinedLoss(
            temperature=temperature,
            triplet_margin=triplet_margin,
        )
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")
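
A quick numerical check of the loss factory: with embeddings pulled toward their class centers the InfoNCE loss should drop toward zero, while random embeddings sit near log(N-1). The same harness works for "supcon", "triplet", and "combined".

import torch
import torch.nn.functional as F

from training.losses import get_loss_function

torch.manual_seed(0)
criterion = get_loss_function("infonce", temperature=0.07)

labels = torch.tensor([0, 0, 1, 1, 2, 2])
random_emb = F.normalize(torch.randn(6, 768), dim=-1)

centers = F.normalize(torch.randn(3, 768), dim=-1)
clustered_emb = F.normalize(centers[labels] + 0.05 * torch.randn(6, 768), dim=-1)

print(criterion(random_emb, labels).item())     # high: positives look random
print(criterion(clustered_emb, labels).item())  # near zero: positives cluster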
training/model.py  (new file, 335 lines)
@@ -0,0 +1,335 @@
"""
Fine-tunable CLIP model wrapper with LoRA support.
"""

import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPModel, CLIPProcessor

# Check if peft is available for LoRA
try:
    from peft import LoraConfig, get_peft_model, PeftModel
    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False
    LoraConfig = None
    get_peft_model = None
    PeftModel = None


class LogoFineTunedCLIP(nn.Module):
    """
    CLIP vision encoder fine-tuned for logo similarity.

    Preserves the embedding interface for compatibility with DetectLogosDETR:
    - Fixed embedding dimensionality (the vision hidden size; 1024 for ViT-L/14)
    - L2-normalized outputs
    - Works with the existing get_image_features() pattern

    Supports:
    - LoRA for memory-efficient fine-tuning
    - Layer freezing for transfer learning
    - Gradient checkpointing for memory optimization
    """

    def __init__(
        self,
        vision_model: nn.Module,
        lora_r: int = 16,
        lora_alpha: int = 32,
        lora_dropout: float = 0.1,
        freeze_layers: int = 12,
        use_gradient_checkpointing: bool = True,
        add_projection_head: bool = True,
    ):
        """
        Initialize the fine-tunable CLIP wrapper.

        Args:
            vision_model: CLIP vision model (CLIPVisionModel)
            lora_r: Rank of LoRA low-rank matrices (0 to disable)
            lora_alpha: LoRA scaling factor
            lora_dropout: Dropout for LoRA layers
            freeze_layers: Number of transformer layers to freeze (from bottom)
            use_gradient_checkpointing: Enable gradient checkpointing
            add_projection_head: Add trainable projection head
        """
        super().__init__()

        self.vision_model = vision_model
        self.embedding_dim = vision_model.config.hidden_size
        self.freeze_layers = freeze_layers
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha

        # Enable gradient checkpointing for memory efficiency
        if use_gradient_checkpointing:
            if hasattr(self.vision_model, "gradient_checkpointing_enable"):
                self.vision_model.gradient_checkpointing_enable()

        # Freeze lower layers
        self._freeze_layers(freeze_layers)

        # Apply LoRA to attention layers in upper blocks
        self.peft_applied = False
        if PEFT_AVAILABLE and lora_r > 0:
            self._apply_lora(lora_r, lora_alpha, lora_dropout)
            self.peft_applied = True
        elif lora_r > 0 and not PEFT_AVAILABLE:
            print(
                "Warning: peft not installed. LoRA disabled. "
                "Install with: pip install peft"
            )

        # Optional projection head for fine-tuning
        self.add_projection_head = add_projection_head
        if add_projection_head:
            self.projection = nn.Sequential(
                nn.Linear(self.embedding_dim, self.embedding_dim),
                nn.LayerNorm(self.embedding_dim),
            )
        else:
            self.projection = nn.Identity()

    def _freeze_layers(self, num_layers: int) -> None:
        """Freeze the first N transformer layers and embeddings."""
        if num_layers <= 0:
            return

        # Freeze embeddings
        if hasattr(self.vision_model, "embeddings"):
            for param in self.vision_model.embeddings.parameters():
                param.requires_grad = False

        # Freeze specified number of encoder layers
        if hasattr(self.vision_model, "encoder"):
            for i, layer in enumerate(self.vision_model.encoder.layers):
                if i < num_layers:
                    for param in layer.parameters():
                        param.requires_grad = False

    def _apply_lora(
        self,
        r: int,
        alpha: int,
        dropout: float,
    ) -> None:
        """Apply LoRA adapters to attention layers."""
        if not PEFT_AVAILABLE:
            return

        # Configure LoRA for vision transformer
        lora_config = LoraConfig(
            r=r,
            lora_alpha=alpha,
            lora_dropout=dropout,
            target_modules=["q_proj", "v_proj"],
            bias="none",
            modules_to_save=[],  # Don't save any full modules
        )

        self.vision_model = get_peft_model(self.vision_model, lora_config)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Extract normalized embeddings for logo images.

        Args:
            pixel_values: [batch, 3, 224, 224] preprocessed images

        Returns:
            embeddings: [batch, embedding_dim] L2-normalized
        """
        # Get vision features
        outputs = self.vision_model(pixel_values=pixel_values)

        # Use pooler output (CLS token projection) if available
        if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
            features = outputs.pooler_output
        else:
            # Fall back to CLS token from last hidden state
            features = outputs.last_hidden_state[:, 0, :]

        # Apply projection head
        features = self.projection(features)

        # L2 normalize for cosine similarity
        features = F.normalize(features, dim=-1)

        return features

    def get_image_features(self, **kwargs) -> torch.Tensor:
        """
        Compatibility method matching CLIP's interface.

        Used by DetectLogosDETR._get_embedding_pil().
        """
        return self.forward(kwargs["pixel_values"])

    def get_trainable_parameters(self) -> List[torch.nn.Parameter]:
        """Return list of trainable parameters."""
        return [p for p in self.parameters() if p.requires_grad]

    def get_parameter_count(self) -> Dict[str, int]:
        """Return count of trainable and total parameters."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {
            "total": total,
            "trainable": trainable,
            "frozen": total - trainable,
            "trainable_percent": 100 * trainable / total if total > 0 else 0,
        }

    def save_pretrained(self, output_dir: str) -> None:
        """
        Save model in HuggingFace-compatible format.

        Args:
            output_dir: Directory to save model files
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Save model weights
        if self.peft_applied and PEFT_AVAILABLE:
            # Save LoRA weights separately
            self.vision_model.save_pretrained(output_path / "vision_lora")
            # Save projection head
            torch.save(
                self.projection.state_dict(),
                output_path / "projection_head.bin",
            )
        else:
            # Save full model state
            torch.save(self.state_dict(), output_path / "pytorch_model.bin")

        # Save config
        config = {
            "model_type": "clip_logo_finetuned",
            "embedding_dim": self.embedding_dim,
            "lora_r": self.lora_r,
            "lora_alpha": self.lora_alpha,
            "freeze_layers": self.freeze_layers,
            "add_projection_head": self.add_projection_head,
            "peft_applied": self.peft_applied,
        }

        with open(output_path / "config.json", "w") as f:
            json.dump(config, f, indent=2)

    @classmethod
    def from_pretrained(
        cls,
        model_path: str,
        base_model: str = "openai/clip-vit-large-patch14",
        device: Optional[torch.device] = None,
    ) -> "LogoFineTunedCLIP":
        """
        Load a fine-tuned model from saved weights.

        Args:
            model_path: Path to saved model directory
            base_model: Base CLIP model name (for architecture)
            device: Device to load model on

        Returns:
            Loaded LogoFineTunedCLIP model
        """
        model_path = Path(model_path)

        # Load config
        with open(model_path / "config.json", "r") as f:
            config = json.load(f)

        # Load base CLIP model
        clip_model = CLIPModel.from_pretrained(base_model)

        # Create model instance. LoRA is deliberately disabled here (lora_r=0)
        # so the saved adapters can be attached below without wrapping the
        # vision model in a second, freshly initialized PEFT adapter.
        model = cls(
            vision_model=clip_model.vision_model,
            lora_r=0,
            lora_alpha=config.get("lora_alpha", 1),
            freeze_layers=config.get("freeze_layers", 12),
            add_projection_head=config.get("add_projection_head", True),
            use_gradient_checkpointing=False,  # Not needed for inference
        )
        model.lora_r = config.get("lora_r", 0)

        # Load weights
        if config.get("peft_applied", False) and PEFT_AVAILABLE:
            # Load LoRA adapter weights
            lora_path = model_path / "vision_lora"
            if lora_path.exists():
                model.vision_model = PeftModel.from_pretrained(
                    model.vision_model, lora_path
                )
                model.peft_applied = True
            # Load projection head
            proj_path = model_path / "projection_head.bin"
            if proj_path.exists():
                model.projection.load_state_dict(
                    torch.load(proj_path, map_location="cpu")
                )
        else:
            # Load full model state
            weights_path = model_path / "pytorch_model.bin"
            if weights_path.exists():
                model.load_state_dict(torch.load(weights_path, map_location="cpu"))

        if device is not None:
            model = model.to(device)

        return model


def create_model(
    base_model: str = "openai/clip-vit-large-patch14",
    lora_r: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.1,
    freeze_layers: int = 12,
    use_gradient_checkpointing: bool = True,
    device: Optional[torch.device] = None,
) -> Tuple[LogoFineTunedCLIP, CLIPProcessor]:
    """
    Create a fine-tunable CLIP model and processor.

    Args:
        base_model: HuggingFace model name or path
        lora_r: LoRA rank (0 to disable)
        lora_alpha: LoRA scaling factor
        lora_dropout: LoRA dropout
        freeze_layers: Number of layers to freeze
        use_gradient_checkpointing: Enable gradient checkpointing
        device: Device to load model on

    Returns:
        Tuple of (model, processor)
    """
    # Load base CLIP model
    clip_model = CLIPModel.from_pretrained(base_model)
    processor = CLIPProcessor.from_pretrained(base_model)

    # Create fine-tunable wrapper
    model = LogoFineTunedCLIP(
        vision_model=clip_model.vision_model,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        freeze_layers=freeze_layers,
        use_gradient_checkpointing=use_gradient_checkpointing,
    )

    if device is not None:
        model = model.to(device)

    # Print parameter info
    param_info = model.get_parameter_count()
    print("Model created:")
    print(f"  Total parameters: {param_info['total']:,}")
    print(f"  Trainable: {param_info['trainable']:,} ({param_info['trainable_percent']:.2f}%)")
    print(f"  Frozen: {param_info['frozen']:,}")

    return model, processor
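
A sketch of building the wrapper and inspecting what actually trains. With lora_r=16 and the first 12 layers frozen, only a small fraction of the ViT-L/14 weights should be trainable; exact counts depend on the installed peft version.

import torch

from training.model import create_model

model, processor = create_model(
    base_model="openai/clip-vit-large-patch14",
    lora_r=16,
    lora_alpha=32,
    freeze_layers=12,
    device=torch.device("cpu"),
)

print(model.get_parameter_count())  # total / trainable / frozen breakdown

# Embeddings come out as L2-normalized rows of size model.embedding_dim
pixel_values = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    emb = model(pixel_values)
print(emb.shape)         # [2, model.embedding_dim]
print(emb.norm(dim=-1))  # ~1.0 for every row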
405
training/trainer.py
Normal file
405
training/trainer.py
Normal file
@ -0,0 +1,405 @@
|
||||
"""
|
||||
Training loop with checkpointing, mixed precision, and evaluation.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from .config import TrainingConfig
|
||||
from .losses import get_loss_function
|
||||
from .evaluation import EmbeddingEvaluator
|
||||
|
||||
# Check if amp is available
|
||||
try:
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
AMP_AVAILABLE = True
|
||||
except ImportError:
|
||||
AMP_AVAILABLE = False
|
||||
autocast = None
|
||||
GradScaler = None
|
||||
|
||||
|
||||
class Trainer:
|
||||
"""
|
||||
Trainer for fine-tuning CLIP on logo recognition.
|
||||
|
||||
Features:
|
||||
- Mixed precision training (FP16)
|
||||
- Gradient accumulation
|
||||
- Gradient checkpointing (via model)
|
||||
- Cosine annealing LR scheduler
|
||||
- Early stopping
|
||||
- Checkpoint saving/loading
|
||||
- Evaluation during training
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: TrainingConfig,
        logger: Optional[logging.Logger] = None,
    ):
        """
        Initialize the trainer.

        Args:
            model: LogoFineTunedCLIP model
            train_loader: Training dataloader
            val_loader: Validation dataloader
            config: Training configuration
            logger: Optional logger instance
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.logger = logger or logging.getLogger(__name__)

        # Device setup
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model.to(self.device)
        self.logger.info(f"Using device: {self.device}")

        # Optimizer - only trainable parameters
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        self.logger.info(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")

        self.optimizer = AdamW(
            trainable_params,
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )

        # Learning rate scheduler. OneCycleLR is stepped once per optimizer
        # update, so total_steps must be counted in optimizer steps
        # (batches / gradient_accumulation_steps), not dataloader batches.
        steps_per_epoch = math.ceil(len(train_loader) / config.gradient_accumulation_steps)
        total_steps = steps_per_epoch * config.max_epochs
        pct_start = config.warmup_steps / total_steps if total_steps > 0 else 0.1
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config.learning_rate,
            total_steps=total_steps,
            pct_start=min(pct_start, 0.3),  # guard against warmup exceeding the schedule
            anneal_strategy="cos",
        )

        # Mixed precision training
        self.use_amp = config.mixed_precision and AMP_AVAILABLE and self.device.type == "cuda"
        if self.use_amp:
            self.scaler = GradScaler()
            self.logger.info("Mixed precision training enabled")
        else:
            self.scaler = None
            if config.mixed_precision and not AMP_AVAILABLE:
                self.logger.warning("Mixed precision requested but not available")

        # Loss function
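        # (temperature parameterizes the InfoNCE objective; triplet_margin is
        # used only when loss_type == "triplet")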
        self.criterion = get_loss_function(
            loss_type=config.loss_type,
            temperature=config.temperature,
            triplet_margin=config.triplet_margin,
        )

        # Evaluator
        self.evaluator = EmbeddingEvaluator()

        # Training state
        self.epoch = 0
        self.global_step = 0
        self.best_val_loss = float("inf")
        self.best_val_separation = float("-inf")
        self.patience_counter = 0
        self.training_history = []

    def train(self) -> Dict[str, float]:
        """
        Main training loop.

        Returns:
            Dict with final training metrics
        """
        self.logger.info("Starting training...")
        self.logger.info(f"  Epochs: {self.config.max_epochs}")
        self.logger.info(f"  Batch size: {self.config.batch_size}")
        self.logger.info(f"  Gradient accumulation: {self.config.gradient_accumulation_steps}")
        self.logger.info(f"  Effective batch: {self.config.effective_batch_size}")
        self.logger.info(f"  Learning rate: {self.config.learning_rate}")

        start_time = time.time()

        for epoch in range(self.epoch, self.config.max_epochs):
            self.epoch = epoch
            self.logger.info(f"\nEpoch {epoch + 1}/{self.config.max_epochs}")

            # Training epoch
            train_metrics = self._train_epoch()
            self.logger.info(
                f"Train - Loss: {train_metrics['loss']:.4f}, "
                f"LR: {train_metrics['lr']:.2e}"
            )

            # Validation
            if (epoch + 1) % self.config.eval_every_n_epochs == 0:
                val_metrics = self._validate()
                self.logger.info(
                    f"Val - Loss: {val_metrics['loss']:.4f}, "
                    f"Pos Sim: {val_metrics['mean_pos_sim']:.3f}, "
                    f"Neg Sim: {val_metrics['mean_neg_sim']:.3f}, "
                    f"Separation: {val_metrics['separation']:.3f}"
                )

                # Record history
                self.training_history.append({
                    "epoch": epoch + 1,
                    "train_loss": train_metrics["loss"],
                    "val_loss": val_metrics["loss"],
                    "val_separation": val_metrics["separation"],
                    "val_pos_sim": val_metrics["mean_pos_sim"],
                    "val_neg_sim": val_metrics["mean_neg_sim"],
                })

                # Checkpointing based on separation (primary) or loss (secondary)
                improved = False
                if val_metrics["separation"] > self.best_val_separation + self.config.min_delta:
                    self.best_val_separation = val_metrics["separation"]
                    improved = True
                elif val_metrics["loss"] < self.best_val_loss - self.config.min_delta:
                    self.best_val_loss = val_metrics["loss"]
                    improved = True

                if improved:
                    self.patience_counter = 0
                    self._save_checkpoint("best.pt")
                    self.logger.info("New best model saved!")
                else:
                    self.patience_counter += 1

                # Early stopping
                if self.patience_counter >= self.config.patience:
                    self.logger.info(
                        f"Early stopping triggered at epoch {epoch + 1} "
                        f"(no improvement for {self.config.patience} epochs)"
                    )
                    break

            # Periodic checkpoint
            if (epoch + 1) % self.config.save_every_n_epochs == 0:
                self._save_checkpoint(f"epoch_{epoch + 1}.pt")

        # Training complete
        total_time = time.time() - start_time
        self.logger.info(f"\nTraining completed in {total_time / 60:.1f} minutes")

        # Load best model
        best_path = Path(self.config.checkpoint_dir) / "best.pt"
        if best_path.exists():
            self.load_checkpoint("best.pt")
            self.logger.info("Loaded best model checkpoint")

        return {
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "total_epochs": self.epoch + 1,
            "total_time_minutes": total_time / 60,
        }

    def _train_epoch(self) -> Dict[str, float]:
        """Run a single training epoch."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        accumulation_steps = 0

        progress_bar = tqdm(
            self.train_loader,
            desc=f"Epoch {self.epoch + 1}",
            leave=False,
        )

        self.optimizer.zero_grad()

        for batch_idx, (images, labels) in enumerate(progress_bar):
            images = images.to(self.device)
            labels = labels.to(self.device)

            # Forward pass with mixed precision
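            # The loss is scaled down by gradient_accumulation_steps so that
            # gradients summed over one accumulation window match a single
            # large batch of batch_size * gradient_accumulation_steps samples.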
            if self.use_amp:
                with autocast():
                    embeddings = self.model(images)
                    loss = self.criterion(embeddings, labels)
                    loss = loss / self.config.gradient_accumulation_steps

                self.scaler.scale(loss).backward()
            else:
                embeddings = self.model(images)
                loss = self.criterion(embeddings, labels)
                loss = loss / self.config.gradient_accumulation_steps
                loss.backward()

            accumulation_steps += 1

            # Optimizer step after accumulation
            if accumulation_steps >= self.config.gradient_accumulation_steps:
                if self.use_amp:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.optimizer.zero_grad()
                self.scheduler.step()
                self.global_step += 1
                accumulation_steps = 0

            total_loss += loss.item() * self.config.gradient_accumulation_steps
            num_batches += 1

            # Update progress bar
            progress_bar.set_postfix({
                "loss": total_loss / num_batches,
                "lr": self.scheduler.get_last_lr()[0],
            })

            # Logging
            if (batch_idx + 1) % self.config.log_every_n_steps == 0:
                self.logger.debug(
                    f"Step {self.global_step}: loss={total_loss / num_batches:.4f}"
                )

        # Step once more if the epoch ended mid-accumulation, so leftover
        # gradients are applied rather than silently discarded.
        if accumulation_steps > 0:
            if self.use_amp:
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()
            self.global_step += 1

        return {
            "loss": total_loss / max(num_batches, 1),
            "lr": self.scheduler.get_last_lr()[0],
        }

    def _validate(self) -> Dict[str, float]:
        """Run validation and compute metrics."""
        self.model.eval()
        total_loss = 0.0
        all_embeddings = []
        all_labels = []

        with torch.no_grad():
            for images, labels in tqdm(self.val_loader, desc="Validating", leave=False):
                images = images.to(self.device)
                labels = labels.to(self.device)

                if self.use_amp:
                    with autocast():
                        embeddings = self.model(images)
                        loss = self.criterion(embeddings, labels)
                else:
                    embeddings = self.model(images)
                    loss = self.criterion(embeddings, labels)

                total_loss += loss.item()
                all_embeddings.append(embeddings.cpu())
                all_labels.append(labels.cpu())

        # Combine batches
        all_embeddings = torch.cat(all_embeddings, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        # Compute embedding quality metrics
        metrics = self.evaluator.compute_metrics(all_embeddings, all_labels)
        metrics["loss"] = total_loss / max(len(self.val_loader), 1)

        return metrics

    def _save_checkpoint(self, filename: str) -> None:
        """Save training checkpoint."""
        checkpoint_dir = Path(self.config.checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        checkpoint = {
            "epoch": self.epoch,
            "global_step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "patience_counter": self.patience_counter,
            "training_history": self.training_history,
            "config": self.config.__dict__,
        }

        if self.scaler is not None:
            checkpoint["scaler_state_dict"] = self.scaler.state_dict()

        torch.save(checkpoint, checkpoint_dir / filename)
        self.logger.debug(f"Saved checkpoint: {filename}")

    def load_checkpoint(self, filename: str) -> None:
        """Load training checkpoint."""
        checkpoint_path = Path(self.config.checkpoint_dir) / filename
        if not checkpoint_path.exists():
            self.logger.warning(f"Checkpoint not found: {checkpoint_path}")
            return

        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        self.epoch = checkpoint["epoch"]
        self.global_step = checkpoint["global_step"]
        self.best_val_loss = checkpoint["best_val_loss"]
        self.best_val_separation = checkpoint.get("best_val_separation", float("-inf"))
        self.patience_counter = checkpoint.get("patience_counter", 0)
        self.training_history = checkpoint.get("training_history", [])

        if self.scaler is not None and "scaler_state_dict" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])

        self.logger.info(f"Resumed from epoch {self.epoch + 1}")

    def export_model(self, output_dir: Optional[str] = None) -> str:
        """
        Export the trained model for inference.

        Args:
            output_dir: Output directory (uses config.output_dir if not specified)

        Returns:
            Path to exported model directory
        """
        output_dir = output_dir or self.config.output_dir
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Save model
        self.model.save_pretrained(output_dir)

        # Save training config
        config_path = output_path / "training_config.json"
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f, indent=2)

        # Save training history
        history_path = output_path / "training_history.json"
        with open(history_path, "w") as f:
            json.dump(self.training_history, f, indent=2)

        self.logger.info(f"Model exported to: {output_path}")
        return str(output_path)

    def get_training_summary(self) -> Dict:
        """Get summary of training."""
        return {
            "epochs_completed": self.epoch + 1,
            "global_steps": self.global_step,
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "history": self.training_history,
        }
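
Putting the pieces together, a minimal end-to-end sketch (a sketch only: create_dataloaders is re-exported by training/__init__.py, but its exact signature and the create_model factory name are assumptions):

# Sketch only: config -> data -> model -> training -> export.
from training import TrainingConfig, Trainer, create_dataloaders
from training.model import create_model  # hypothetical factory name

config = TrainingConfig()
train_loader, val_loader, test_loader = create_dataloaders(config)  # assumed signature
model, processor = create_model()

trainer = Trainer(model, train_loader, val_loader, config)
results = trainer.train()             # trains with AMP, accumulation, early stopping
export_path = trainer.export_model()  # writes model + configs under config.output_dir
print(results["best_val_separation"], export_path)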