From 44e8b6ae7dd8f22a3b05856010de181363332b02 Mon Sep 17 00:00:00 2001 From: Rick McEwen Date: Sun, 4 Jan 2026 13:45:25 -0500 Subject: [PATCH] Add CLIP fine-tuning pipeline for logo recognition Implement contrastive learning with LoRA to fine-tune CLIP's vision encoder on LogoDet-3K dataset for improved logo embedding similarity. New training module (training/): - config.py: TrainingConfig dataclass with all hyperparameters - dataset.py: LogoContrastiveDataset with logo-level splits - model.py: LogoFineTunedCLIP wrapper with LoRA support - losses.py: InfoNCE, TripletLoss, SupConLoss implementations - trainer.py: Training loop with mixed precision and checkpointing - evaluation.py: EmbeddingEvaluator for validation metrics New scripts: - train_clip_logo.py: Main training entry point - export_model.py: Export to HuggingFace-compatible format Configurations: - configs/jetson_orin.yaml: Optimized for Jetson Orin AGX - configs/cloud_rtx4090.yaml: Optimized for 24GB cloud GPUs - configs/cloud_a100.yaml: Optimized for 80GB cloud GPUs Documentation: - CLIP_FINETUNING.md: Training guide and usage instructions - CLOUD_TRAINING.md: Cloud GPU recommendations and cost estimates Modified: - logo_detection_detr.py: Add fine-tuned model loading support - pyproject.toml: Add peft, pyyaml, torchvision dependencies --- CLIP_FINETUNING.md | 266 +++++++++++++++++++++ CLOUD_TRAINING.md | 269 +++++++++++++++++++++ configs/cloud_a100.yaml | 64 +++++ configs/cloud_rtx4090.yaml | 64 +++++ configs/jetson_orin.yaml | 76 ++++++ export_model.py | 169 ++++++++++++++ logo_detection_detr.py | 89 ++++++- pyproject.toml | 3 + train_clip_logo.py | 309 ++++++++++++++++++++++++ training/__init__.py | 24 ++ training/config.py | 141 +++++++++++ training/dataset.py | 467 +++++++++++++++++++++++++++++++++++++ training/evaluation.py | 339 +++++++++++++++++++++++++++ training/losses.py | 326 ++++++++++++++++++++++++++ training/model.py | 335 ++++++++++++++++++++++++++ training/trainer.py | 405 ++++++++++++++++++++++++++++++++ 16 files changed, 3334 insertions(+), 12 deletions(-) create mode 100644 CLIP_FINETUNING.md create mode 100644 CLOUD_TRAINING.md create mode 100644 configs/cloud_a100.yaml create mode 100644 configs/cloud_rtx4090.yaml create mode 100644 configs/jetson_orin.yaml create mode 100644 export_model.py create mode 100644 train_clip_logo.py create mode 100644 training/__init__.py create mode 100644 training/config.py create mode 100644 training/dataset.py create mode 100644 training/evaluation.py create mode 100644 training/losses.py create mode 100644 training/model.py create mode 100644 training/trainer.py diff --git a/CLIP_FINETUNING.md b/CLIP_FINETUNING.md new file mode 100644 index 0000000..352b72f --- /dev/null +++ b/CLIP_FINETUNING.md @@ -0,0 +1,266 @@ +# CLIP Fine-Tuning for Logo Recognition + +This document describes the CLIP fine-tuning pipeline for improving logo embedding similarity using the LogoDet-3K dataset. + +## Overview + +The fine-tuning approach uses **contrastive learning** with **LoRA** (Low-Rank Adaptation) to train CLIP's vision encoder for better logo similarity matching while maintaining compatibility with the existing `DetectLogosDETR` class. + +**Goal**: Improve F1 from ~60% to >72% on logo matching tasks. 
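For intuition, the sketch below shows the kind of multi-positive InfoNCE objective this pipeline optimizes: crops of the same brand within a batch are pulled together while all other crops are pushed apart. It is a self-contained illustration only — the actual implementations (`InfoNCELoss`, `SupConLoss`, etc.) live in `training/losses.py` and may differ in detail; the tensor sizes below are toy values.

```python
import torch
import torch.nn.functional as F


def contrastive_logo_loss(embeddings: torch.Tensor,
                          labels: torch.Tensor,
                          temperature: float = 0.07) -> torch.Tensor:
    """Multi-positive InfoNCE over one contrastive batch of logo crops.

    embeddings: [N, D] image embeddings (one row per crop)
    labels:     [N] brand indices; crops sharing a label are positives
    """
    # Cosine similarity between every pair of crops
    z = F.normalize(embeddings, dim=-1)
    sim = (z @ z.T) / temperature                      # [N, N]

    n = z.shape[0]
    self_mask = torch.eye(n, dtype=torch.bool, device=z.device)
    pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~self_mask

    # Softmax over all other crops; average log-probability of the positives
    sim = sim.masked_fill(self_mask, float("-inf"))
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)
    per_anchor = (log_prob * pos_mask.float()).sum(dim=1) / pos_mask.sum(dim=1).clamp(min=1)
    return -per_anchor.mean()


# Toy batch: 4 brands x 4 crops each (a small stand-in for
# logos_per_batch x samples_per_logo), with random embeddings
emb = torch.randn(16, 768)
lbl = torch.arange(4).repeat_interleave(4)
print(contrastive_logo_loss(emb, lbl))
```

The batch construction described later (`logos_per_batch` × `samples_per_logo`) exists precisely so that every anchor in a batch has several positives for this objective; the default `temperature: 0.07` in the configs corresponds to the scaling factor above.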
+ +## Files Created + +### Training Module (`training/`) + +| File | Description | +|------|-------------| +| `__init__.py` | Module exports | +| `config.py` | `TrainingConfig` dataclass with all hyperparameters | +| `dataset.py` | `LogoContrastiveDataset` with logo-level splits and augmentations | +| `model.py` | `LogoFineTunedCLIP` wrapper with LoRA support | +| `losses.py` | `InfoNCELoss`, `TripletLoss`, `SupConLoss`, `CombinedLoss` | +| `trainer.py` | Training loop with mixed precision, checkpointing, early stopping | +| `evaluation.py` | `EmbeddingEvaluator` for validation metrics | + +### Scripts + +| File | Description | +|------|-------------| +| `train_clip_logo.py` | Main training entry point | +| `export_model.py` | Export trained models to HuggingFace-compatible format | + +### Configuration + +| File | Description | +|------|-------------| +| `configs/jetson_orin.yaml` | Training config optimized for Jetson Orin AGX | + +## Prerequisites + +1. **Install dependencies**: + ```bash + uv sync + ``` + +2. **Prepare test data** (if not already done): + ```bash + uv run python prepare_test_data.py + ``` + + This creates: + - `reference_logos/` - Cropped logo images organized by category/brand + - `test_images/` - Full images for testing + - `test_data_mapping.db` - SQLite database with mappings + +## Training + +### Basic Training + +```bash +uv run python train_clip_logo.py --config configs/jetson_orin.yaml +``` + +### Training with Overrides + +```bash +uv run python train_clip_logo.py --config configs/jetson_orin.yaml \ + --learning-rate 5e-6 \ + --max-epochs 30 \ + --batch-size 8 +``` + +### Resume from Checkpoint + +```bash +uv run python train_clip_logo.py --config configs/jetson_orin.yaml \ + --resume checkpoints/epoch_10.pt +``` + +### Training Output + +- Checkpoints saved to `checkpoints/` +- Best model saved as `checkpoints/best.pt` +- Final model exported to `models/logo_detection/clip_finetuned/` + +## Configuration Options + +Key parameters in `configs/jetson_orin.yaml`: + +```yaml +# Model +base_model: "openai/clip-vit-large-patch14" +lora_r: 16 # LoRA rank (0 to disable) +lora_alpha: 32 # LoRA scaling factor +freeze_layers: 12 # Freeze first N transformer layers + +# Batch construction +batch_size: 16 +logos_per_batch: 32 # Different logos per batch +samples_per_logo: 4 # Samples per logo (creates positive pairs) +gradient_accumulation_steps: 8 # Effective batch = 128 + +# Training +learning_rate: 1.0e-5 +max_epochs: 20 +mixed_precision: true +temperature: 0.07 # InfoNCE temperature + +# Early stopping +patience: 5 +min_delta: 0.001 +``` + +## Evaluation + +### Test Fine-Tuned Model + +```bash +uv run python test_logo_detection.py -n 50 \ + -e models/logo_detection/clip_finetuned \ + --matching-method multi-ref \ + --seed 42 +``` + +### Compare with Baseline + +```bash +# Baseline CLIP +uv run python test_logo_detection.py -n 50 \ + -e openai/clip-vit-large-patch14 \ + --matching-method multi-ref \ + --seed 42 + +# Fine-tuned model +uv run python test_logo_detection.py -n 50 \ + -e models/logo_detection/clip_finetuned \ + --matching-method multi-ref \ + --seed 42 +``` + +### Expected Metrics + +| Metric | Baseline CLIP | Target (Fine-tuned) | +|--------|---------------|---------------------| +| Precision | ~49% | >70% | +| Recall | ~77% | >75% | +| F1 Score | ~60% | >72% | + +Training metrics to monitor: +- Mean positive similarity: target > 0.85 +- Mean negative similarity: target < 0.50 +- Embedding separation: target > 0.35 + +## Export Model + +To export a 
checkpoint to HuggingFace format: + +```bash +uv run python export_model.py \ + --checkpoint checkpoints/best.pt \ + --output models/logo_detection/clip_finetuned +``` + +With LoRA weight merging (reduces inference overhead): + +```bash +uv run python export_model.py \ + --checkpoint checkpoints/best.pt \ + --output models/logo_detection/clip_finetuned \ + --merge-lora +``` + +## Using Fine-Tuned Model with DetectLogosDETR + +The fine-tuned model works as a drop-in replacement: + +```python +from logo_detection_detr import DetectLogosDETR + +# Use fine-tuned model +detector = DetectLogosDETR( + logger=logger, + embedding_model="models/logo_detection/clip_finetuned", +) + +# Or use baseline for comparison +detector_baseline = DetectLogosDETR( + logger=logger, + embedding_model="openai/clip-vit-large-patch14", +) +``` + +## Architecture Details + +### Training Approach + +1. **Contrastive Learning**: Uses InfoNCE loss to maximize similarity between embeddings of the same logo while minimizing similarity to different logos. + +2. **LoRA (Low-Rank Adaptation)**: Adds small trainable matrices to attention layers instead of fine-tuning all weights. This is memory-efficient and prevents catastrophic forgetting. + +3. **Layer Freezing**: Freezes the first 12 of 24 transformer layers to preserve CLIP's low-level visual features while adapting high-level semantics. + +4. **Logo-Level Splits**: Splits data by logo brand (not by image) to test generalization to unseen logos. + +### Batch Construction + +Each batch contains: +- K different logo brands (default: 32) +- M samples per brand (default: 4) +- Total samples: K × M = 128 + +This ensures positive pairs (same logo) exist within each batch for contrastive learning. + +### Data Augmentation + +Medium strength augmentations: +- Random horizontal flip +- Random rotation (±15°) +- Color jitter (brightness, contrast, saturation) +- Random affine transforms +- Random grayscale (10% of images) + +## Troubleshooting + +### Out of Memory + +Reduce batch size and increase gradient accumulation: + +```bash +uv run python train_clip_logo.py --config configs/jetson_orin.yaml \ + --batch-size 8 \ + --gradient-accumulation-steps 16 +``` + +### Slow Training + +Ensure mixed precision is enabled: + +```bash +uv run python train_clip_logo.py --config configs/jetson_orin.yaml +# mixed_precision: true is default in jetson_orin.yaml +``` + +### No Improvement + +Try adjusting: +- Lower learning rate: `--learning-rate 5e-6` +- Higher temperature: `--temperature 0.1` +- Different loss: edit config to use `loss_type: "combined"` + +### Import Error for Fine-Tuned Model + +Ensure the `training/` module is in your Python path: + +```bash +export PYTHONPATH="${PYTHONPATH}:/data/dev.python/logo_test" +``` + +## Dependencies Added + +The following were added to `pyproject.toml`: + +```toml +peft>=0.7.0 # LoRA support +pyyaml>=6.0 # Config file parsing +torchvision>=0.20.0 # Image transforms +``` diff --git a/CLOUD_TRAINING.md b/CLOUD_TRAINING.md new file mode 100644 index 0000000..629159e --- /dev/null +++ b/CLOUD_TRAINING.md @@ -0,0 +1,269 @@ +# Cloud GPU Training for CLIP Fine-Tuning + +This document provides guidance on using cloud GPU instances (e.g., RunPod) for faster CLIP fine-tuning compared to local training on Jetson Orin AGX. + +## Training Time Comparison + +Local training on Jetson Orin AGX takes approximately 24 hours. Cloud GPUs offer significantly faster training: + +| GPU | VRAM | Est. Training Time | Hourly Rate | Est. 
Total Cost | +|-----|------|-------------------|-------------|-----------------| +| **RTX 4090** | 24GB | 4-6 hours | $0.59/hr | **$2.40-$3.50** | +| **RTX 3090** | 24GB | 5-7 hours | $0.39/hr | **$2.00-$2.75** | +| **A100 80GB** | 80GB | 2-3 hours | $1.99/hr | **$4.00-$6.00** | +| **L40S** | 48GB | 3-4 hours | $0.89/hr | **$2.70-$3.60** | +| **H100 80GB** | 80GB | 1.5-2 hours | $1.99/hr | **$3.00-$4.00** | + +*Prices from RunPod Community Cloud as of January 2025. Rates may vary.* + +## Recommendations + +### Best Value: RTX 4090 ($0.59/hr) +- 24GB VRAM is sufficient for ViT-L/14 with LoRA +- Good balance of speed and cost +- Widely available on Community Cloud +- **Total cost: ~$3 for complete training** + +### Best Speed: H100 80GB ($1.99/hr) +- Fastest training (1.5-2 hours) +- 80GB VRAM allows larger batch sizes +- Can increase `batch_size` to 32+ and reduce `gradient_accumulation_steps` +- **Total cost: ~$3-4** + +### Budget Option: RTX 3090 ($0.39/hr) +- Cheapest hourly rate +- 24GB VRAM works fine +- Slightly slower than 4090 +- **Total cost: ~$2-3** + +## Cloud-Optimized Configurations + +### RTX 4090 / RTX 3090 (24GB VRAM) + +Create `configs/cloud_rtx4090.yaml`: + +```yaml +# Optimized for 24GB VRAM cloud GPUs +base_model: "openai/clip-vit-large-patch14" + +# Dataset paths +dataset_dir: "LogoDet-3K" +reference_dir: "reference_logos" +db_path: "test_data_mapping.db" + +# Data splits +train_split: 0.7 +val_split: 0.15 +test_split: 0.15 + +# Larger batches for faster training +batch_size: 32 +logos_per_batch: 32 +samples_per_logo: 4 +gradient_accumulation_steps: 4 # Effective batch = 128 +num_workers: 8 + +# Model architecture +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0.1 +freeze_layers: 12 +use_gradient_checkpointing: true + +# Training +learning_rate: 1.0e-5 +weight_decay: 0.01 +warmup_steps: 500 +max_epochs: 20 +mixed_precision: true + +# Loss +temperature: 0.07 +loss_type: "infonce" + +# Early stopping +patience: 5 +min_delta: 0.001 + +# Output +checkpoint_dir: "checkpoints" +output_dir: "models/logo_detection/clip_finetuned" +save_every_n_epochs: 5 + +# Logging +log_every_n_steps: 10 +eval_every_n_epochs: 1 + +seed: 42 +use_augmentation: true +augmentation_strength: "medium" +``` + +### A100 / H100 (80GB VRAM) + +Create `configs/cloud_a100.yaml`: + +```yaml +# Optimized for 80GB VRAM cloud GPUs (A100, H100) +base_model: "openai/clip-vit-large-patch14" + +# Dataset paths +dataset_dir: "LogoDet-3K" +reference_dir: "reference_logos" +db_path: "test_data_mapping.db" + +# Data splits +train_split: 0.7 +val_split: 0.15 +test_split: 0.15 + +# Maximum batch sizes for 80GB VRAM +batch_size: 64 +logos_per_batch: 32 +samples_per_logo: 4 +gradient_accumulation_steps: 2 # Effective batch = 128 +num_workers: 8 + +# Model architecture (can disable gradient checkpointing with 80GB) +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0.1 +freeze_layers: 12 +use_gradient_checkpointing: false # Not needed with 80GB + +# Training +learning_rate: 1.0e-5 +weight_decay: 0.01 +warmup_steps: 500 +max_epochs: 20 +mixed_precision: true + +# Loss +temperature: 0.07 +loss_type: "infonce" + +# Early stopping +patience: 5 +min_delta: 0.001 + +# Output +checkpoint_dir: "checkpoints" +output_dir: "models/logo_detection/clip_finetuned" +save_every_n_epochs: 5 + +# Logging +log_every_n_steps: 10 +eval_every_n_epochs: 1 + +seed: 42 +use_augmentation: true +augmentation_strength: "medium" +``` + +## RunPod Quick Start + +### 1. Create a Pod + +1. Go to [RunPod](https://www.runpod.io/) +2. 
Select GPU (RTX 4090 recommended) +3. Choose PyTorch template (CUDA 12.x) +4. Set volume size: 50GB (for dataset + models) + +### 2. Setup Environment + +```bash +# Connect via SSH or web terminal + +# Install dependencies +pip install peft pyyaml torchvision transformers tqdm pillow + +# Clone your repository (or upload files) +git clone +cd logo_test + +# Or use runpodctl to sync files +# runpodctl send logo_test/ +``` + +### 3. Prepare Data + +If data isn't already prepared: + +```bash +# This creates reference_logos/ and test_data_mapping.db +python prepare_test_data.py +``` + +### 4. Run Training + +```bash +# For RTX 4090 +python train_clip_logo.py --config configs/cloud_rtx4090.yaml + +# For A100/H100 +python train_clip_logo.py --config configs/cloud_a100.yaml + +# Or with command-line overrides +python train_clip_logo.py --config configs/jetson_orin.yaml \ + --batch-size 32 \ + --gradient-accumulation-steps 4 \ + --num-workers 8 +``` + +### 5. Download Results + +```bash +# Export the trained model +python export_model.py \ + --checkpoint checkpoints/best.pt \ + --output models/logo_detection/clip_finetuned + +# Download to local machine +# Option 1: Use runpodctl +runpodctl receive models/logo_detection/clip_finetuned + +# Option 2: SCP +scp -r root@:/workspace/logo_test/models/logo_detection/clip_finetuned ./ + +# Option 3: Compress and download via web +tar -czvf clip_finetuned.tar.gz models/logo_detection/clip_finetuned +``` + +## Cost Optimization Tips + +### Use Spot/Interruptible Instances +- Community Cloud GPUs are already cheaper +- Some providers offer spot pricing for additional savings +- Save checkpoints frequently (`save_every_n_epochs: 2`) + +### Minimize Storage Costs +- RunPod charges $0.10/GB/month for container disk +- Use network volumes only if needed +- Delete pods when training completes + +### Monitor Training +- Watch for early convergence (may finish before 20 epochs) +- Early stopping will save time/cost if no improvement + +### Batch Training Runs +- Test configuration locally first (1-2 epochs) +- Run full training on cloud only when config is validated + +## Cost Comparison Summary + +| Option | Time | Cost | Best For | +|--------|------|------|----------| +| Jetson Orin (local) | ~24 hrs | Free* | No cloud dependency | +| RTX 3090 (RunPod) | ~6 hrs | ~$2.50 | Lowest cost | +| RTX 4090 (RunPod) | ~5 hrs | ~$3.00 | Best value | +| L40S (RunPod) | ~3.5 hrs | ~$3.00 | Good balance | +| A100 80GB (RunPod) | ~2.5 hrs | ~$5.00 | Large batches | +| H100 80GB (RunPod) | ~1.5 hrs | ~$3.50 | Fastest | + +*Local training has electricity cost but no cloud fees. 
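As suggested under "Batch Training Runs" above, validate the configuration locally with a short run before paying for GPU hours. A minimal smoke test using only flags that `train_clip_logo.py` already exposes (the override values here are illustrative):

```bash
# One-epoch dry run to confirm data paths, config, and checkpointing work
uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
    --max-epochs 1 \
    --batch-size 4 \
    --gradient-accumulation-steps 4
```

If this completes and writes a checkpoint to `checkpoints/`, the same config (without the overrides) should run unmodified on the cloud instance.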
+ +## References + +- [RunPod Pricing](https://www.runpod.io/pricing) +- [RunPod RTX 4090](https://www.runpod.io/gpu-models/rtx-4090) +- [RunPod Documentation](https://docs.runpod.io/) diff --git a/configs/cloud_a100.yaml b/configs/cloud_a100.yaml new file mode 100644 index 0000000..5551070 --- /dev/null +++ b/configs/cloud_a100.yaml @@ -0,0 +1,64 @@ +# Training configuration optimized for cloud A100 / H100 (80GB VRAM) +# +# Usage: +# python train_clip_logo.py --config configs/cloud_a100.yaml +# +# Estimated training time: 1.5-3 hours +# Estimated cost on RunPod: ~$3-6 + +# Base model +base_model: "openai/clip-vit-large-patch14" + +# Dataset paths +dataset_dir: "LogoDet-3K" +reference_dir: "reference_logos" +db_path: "test_data_mapping.db" + +# Data splits +train_split: 0.7 +val_split: 0.15 +test_split: 0.15 + +# Maximum batch sizes for 80GB VRAM +batch_size: 64 +logos_per_batch: 32 +samples_per_logo: 4 +gradient_accumulation_steps: 2 # Effective batch = 128 +num_workers: 8 + +# Model architecture (no gradient checkpointing needed with 80GB) +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0.1 +freeze_layers: 12 +use_gradient_checkpointing: false + +# Training +learning_rate: 1.0e-5 +weight_decay: 0.01 +warmup_steps: 500 +max_epochs: 20 +mixed_precision: true + +# Loss +temperature: 0.07 +loss_type: "infonce" +triplet_margin: 0.3 + +# Early stopping +patience: 5 +min_delta: 0.001 + +# Output +checkpoint_dir: "checkpoints" +output_dir: "models/logo_detection/clip_finetuned" +save_every_n_epochs: 2 # Save more frequently for cloud + +# Logging +log_every_n_steps: 10 +eval_every_n_epochs: 1 + +seed: 42 +use_hard_negatives: false +use_augmentation: true +augmentation_strength: "medium" diff --git a/configs/cloud_rtx4090.yaml b/configs/cloud_rtx4090.yaml new file mode 100644 index 0000000..30095c7 --- /dev/null +++ b/configs/cloud_rtx4090.yaml @@ -0,0 +1,64 @@ +# Training configuration optimized for cloud RTX 4090 / RTX 3090 (24GB VRAM) +# +# Usage: +# python train_clip_logo.py --config configs/cloud_rtx4090.yaml +# +# Estimated training time: 4-6 hours +# Estimated cost on RunPod: ~$3 + +# Base model +base_model: "openai/clip-vit-large-patch14" + +# Dataset paths +dataset_dir: "LogoDet-3K" +reference_dir: "reference_logos" +db_path: "test_data_mapping.db" + +# Data splits +train_split: 0.7 +val_split: 0.15 +test_split: 0.15 + +# Larger batches for faster training on 24GB VRAM +batch_size: 32 +logos_per_batch: 32 +samples_per_logo: 4 +gradient_accumulation_steps: 4 # Effective batch = 128 +num_workers: 8 + +# Model architecture +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0.1 +freeze_layers: 12 +use_gradient_checkpointing: true + +# Training +learning_rate: 1.0e-5 +weight_decay: 0.01 +warmup_steps: 500 +max_epochs: 20 +mixed_precision: true + +# Loss +temperature: 0.07 +loss_type: "infonce" +triplet_margin: 0.3 + +# Early stopping +patience: 5 +min_delta: 0.001 + +# Output +checkpoint_dir: "checkpoints" +output_dir: "models/logo_detection/clip_finetuned" +save_every_n_epochs: 2 # Save more frequently for cloud + +# Logging +log_every_n_steps: 10 +eval_every_n_epochs: 1 + +seed: 42 +use_hard_negatives: false +use_augmentation: true +augmentation_strength: "medium" diff --git a/configs/jetson_orin.yaml b/configs/jetson_orin.yaml new file mode 100644 index 0000000..04f6240 --- /dev/null +++ b/configs/jetson_orin.yaml @@ -0,0 +1,76 @@ +# Training configuration optimized for Jetson Orin AGX (~64GB shared memory) +# +# Usage: +# uv run python train_clip_logo.py --config configs/jetson_orin.yaml + +# Base 
model +base_model: "openai/clip-vit-large-patch14" + +# Dataset paths (relative to project root) +dataset_dir: "LogoDet-3K" +reference_dir: "reference_logos" +db_path: "test_data_mapping.db" + +# Data split ratios (logo-level split for generalization testing) +train_split: 0.7 +val_split: 0.15 +test_split: 0.15 + +# Batch construction +# - batch_size: Number of batches loaded at once (keep low for memory) +# - logos_per_batch: Different logo classes per contrastive batch +# - samples_per_logo: Samples of each logo (creates positive pairs) +# - Effective samples per step = logos_per_batch * samples_per_logo = 128 +batch_size: 16 +logos_per_batch: 32 +samples_per_logo: 4 +gradient_accumulation_steps: 8 # Effective batch = 128 +num_workers: 4 + +# Model architecture +# LoRA enables memory-efficient fine-tuning by training low-rank adapters +# instead of full model weights +lora_r: 16 # LoRA rank (0 to disable) +lora_alpha: 32 # LoRA scaling factor +lora_dropout: 0.1 # Dropout in LoRA layers +freeze_layers: 12 # Freeze first 12 of 24 transformer layers +use_gradient_checkpointing: true # Trade compute for memory + +# Training hyperparameters +learning_rate: 1.0e-5 # Conservative LR for fine-tuning +weight_decay: 0.01 # L2 regularization +warmup_steps: 500 # LR warmup steps +max_epochs: 20 # Maximum training epochs +mixed_precision: true # FP16 training for memory efficiency + +# Loss function +# InfoNCE is the contrastive loss used in CLIP training +temperature: 0.07 # Similarity scaling (0.05-0.1 typical) +loss_type: "infonce" # Options: infonce, supcon, triplet, combined +triplet_margin: 0.3 # Only used if loss_type is triplet + +# Early stopping +patience: 5 # Stop if no improvement for N epochs +min_delta: 0.001 # Minimum improvement threshold + +# Checkpoints and output +checkpoint_dir: "checkpoints" +output_dir: "models/logo_detection/clip_finetuned" +save_every_n_epochs: 5 + +# Logging +log_every_n_steps: 10 +eval_every_n_epochs: 1 + +# Reproducibility +seed: 42 + +# Hard negative mining (advanced) +# Enable after initial training epochs for harder examples +use_hard_negatives: false +hard_negative_start_epoch: 5 +hard_negatives_per_logo: 10 + +# Data augmentation +use_augmentation: true +augmentation_strength: "medium" # light, medium, or strong diff --git a/export_model.py b/export_model.py new file mode 100644 index 0000000..7f3c7d3 --- /dev/null +++ b/export_model.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 +""" +Export a trained CLIP model to HuggingFace-compatible format. + +This script converts a training checkpoint to a format that can be +loaded by DetectLogosDETR for inference. 
+ +Usage: + uv run python export_model.py \ + --checkpoint checkpoints/best.pt \ + --output models/logo_detection/clip_finetuned + + # With custom base model + uv run python export_model.py \ + --checkpoint checkpoints/best.pt \ + --output models/logo_detection/clip_finetuned \ + --base-model openai/clip-vit-large-patch14 +""" + +import argparse +import json +import logging +import sys +from pathlib import Path + +import torch + +from training.config import TrainingConfig +from training.model import create_model, LogoFineTunedCLIP + + +def setup_logging() -> logging.Logger: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + return logging.getLogger(__name__) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Export trained CLIP model for inference", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to training checkpoint (.pt file)", + ) + parser.add_argument( + "--output", + type=str, + required=True, + help="Output directory for exported model", + ) + parser.add_argument( + "--base-model", + type=str, + default=None, + help="Base CLIP model (reads from checkpoint config if not specified)", + ) + parser.add_argument( + "--merge-lora", + action="store_true", + help="Merge LoRA weights into base model (reduces inference overhead)", + ) + + return parser.parse_args() + + +def main(): + args = parse_args() + logger = setup_logging() + + logger.info("CLIP Model Export") + logger.info("=" * 60) + + # Check checkpoint exists + checkpoint_path = Path(args.checkpoint) + if not checkpoint_path.exists(): + logger.error(f"Checkpoint not found: {checkpoint_path}") + sys.exit(1) + + # Load checkpoint + logger.info(f"Loading checkpoint: {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location="cpu") + + # Get config from checkpoint + if "config" in checkpoint: + config_dict = checkpoint["config"] + base_model = args.base_model or config_dict.get( + "base_model", "openai/clip-vit-large-patch14" + ) + lora_r = config_dict.get("lora_r", 16) + lora_alpha = config_dict.get("lora_alpha", 32) + freeze_layers = config_dict.get("freeze_layers", 12) + else: + base_model = args.base_model or "openai/clip-vit-large-patch14" + lora_r = 16 + lora_alpha = 32 + freeze_layers = 12 + + logger.info(f"Base model: {base_model}") + logger.info(f"LoRA rank: {lora_r}") + logger.info(f"Freeze layers: {freeze_layers}") + + # Create model with same architecture + logger.info("Creating model architecture...") + model, processor = create_model( + base_model=base_model, + lora_r=lora_r, + lora_alpha=lora_alpha, + freeze_layers=freeze_layers, + use_gradient_checkpointing=False, # Not needed for export + ) + + # Load weights + logger.info("Loading trained weights...") + model.load_state_dict(checkpoint["model_state_dict"]) + + # Merge LoRA if requested + if args.merge_lora and model.peft_applied: + try: + logger.info("Merging LoRA weights into base model...") + model.vision_model = model.vision_model.merge_and_unload() + model.peft_applied = False + model.lora_r = 0 + logger.info("LoRA weights merged successfully") + except Exception as e: + logger.warning(f"Could not merge LoRA weights: {e}") + logger.warning("Exporting with separate LoRA weights") + + # Create output directory + output_path = Path(args.output) + output_path.mkdir(parents=True, exist_ok=True) + + # Save model + 
logger.info(f"Exporting to: {output_path}") + model.save_pretrained(str(output_path)) + + # Save processor config for reference + processor.save_pretrained(str(output_path / "processor")) + + # Save additional metadata + metadata = { + "base_model": base_model, + "source_checkpoint": str(checkpoint_path), + "training_epochs": checkpoint.get("epoch", -1) + 1, + "best_val_loss": checkpoint.get("best_val_loss", None), + "best_val_separation": checkpoint.get("best_val_separation", None), + "lora_merged": args.merge_lora and not model.peft_applied, + } + + with open(output_path / "export_metadata.json", "w") as f: + json.dump(metadata, f, indent=2) + + logger.info("\nExport complete!") + logger.info(f"Model saved to: {output_path}") + logger.info("\nTo use with DetectLogosDETR:") + logger.info(f" detector = DetectLogosDETR(embedding_model='{output_path}')") + logger.info("\nOr with test_logo_detection.py:") + logger.info(f" uv run python test_logo_detection.py -e {output_path}") + + +if __name__ == "__main__": + main() diff --git a/logo_detection_detr.py b/logo_detection_detr.py index 19fc4e6..7cbbd8d 100644 --- a/logo_detection_detr.py +++ b/logo_detection_detr.py @@ -13,6 +13,7 @@ Supported embedding models: - DINOv2 models (facebook/dinov2-*): Self-supervised, excellent for visual similarity """ +import json import os import torch import torch.nn.functional as F @@ -100,16 +101,20 @@ class DetectLogosDETR: embedding_model, default_embedding_dir, "Embedding" ) - # Detect model type and initialize accordingly - self.model_type = self._detect_model_type(embedding_model) - self.logger.info(f"Loading {self.model_type} embedding model: {embedding_model_path}") + # Check if this is a fine-tuned model + if self._is_finetuned_model(embedding_model_path): + self._load_finetuned_embedding_model(embedding_model_path) + else: + # Detect model type and initialize accordingly + self.model_type = self._detect_model_type(embedding_model) + self.logger.info(f"Loading {self.model_type} embedding model: {embedding_model_path}") - if self.model_type == "clip": - self.embedding_model = CLIPModel.from_pretrained(embedding_model_path).to(self.device) - self.embedding_processor = CLIPProcessor.from_pretrained(embedding_model_path) - else: # dinov2 or other transformer models - self.embedding_model = AutoModel.from_pretrained(embedding_model_path).to(self.device) - self.embedding_processor = AutoImageProcessor.from_pretrained(embedding_model_path) + if self.model_type == "clip": + self.embedding_model = CLIPModel.from_pretrained(embedding_model_path).to(self.device) + self.embedding_processor = CLIPProcessor.from_pretrained(embedding_model_path) + else: # dinov2 or other transformer models + self.embedding_model = AutoModel.from_pretrained(embedding_model_path).to(self.device) + self.embedding_processor = AutoImageProcessor.from_pretrained(embedding_model_path) self.logger.info("DetectLogosDETR initialization complete") @@ -124,6 +129,62 @@ class DetectLogosDETR: # Default to generic transformer for unknown models return "transformer" + def _is_finetuned_model(self, model_path: str) -> bool: + """Check if a model path points to a fine-tuned CLIP model.""" + config_path = Path(model_path) / "config.json" + if config_path.exists(): + try: + with open(config_path, "r") as f: + config = json.load(f) + return config.get("model_type") == "clip_logo_finetuned" + except (json.JSONDecodeError, IOError): + pass + return False + + def _load_finetuned_embedding_model(self, model_path: str) -> None: + """ + Load a fine-tuned CLIP 
model from the training module. + + Args: + model_path: Path to the fine-tuned model directory + """ + # Import the fine-tuned model class + try: + from training.model import LogoFineTunedCLIP + except ImportError as e: + self.logger.error( + f"Cannot import training.model for fine-tuned model: {e}" + ) + raise ImportError( + "Fine-tuned model requires the training module. " + "Ensure the training/ directory is in your Python path." + ) from e + + # Load config + config_path = Path(model_path) / "config.json" + with open(config_path, "r") as f: + config = json.load(f) + + base_model = config.get("base_model", "openai/clip-vit-large-patch14") + + self.logger.info(f"Loading fine-tuned CLIP model from: {model_path}") + self.logger.info(f" Base model: {base_model}") + + # Load model using the from_pretrained method + self.embedding_model = LogoFineTunedCLIP.from_pretrained( + model_path, + base_model=base_model, + device=self.device, + ) + self.embedding_model.eval() + + # Load processor from base model + self.embedding_processor = CLIPProcessor.from_pretrained(base_model) + + # Set model type for embedding extraction + self.model_type = "clip_finetuned" + self.logger.info("Fine-tuned CLIP model loaded successfully") + def _resolve_model_path( self, model_name_or_path: str, default_local_dir: str, model_type: str ) -> str: @@ -345,7 +406,7 @@ class DetectLogosDETR: """ Internal method to get embedding from PIL image. - Handles both CLIP and DINOv2 model types. + Handles CLIP, fine-tuned CLIP, and DINOv2 model types. Args: pil_image: PIL Image (RGB format) @@ -360,6 +421,9 @@ class DetectLogosDETR: if self.model_type == "clip": # CLIP has a dedicated method for image features features = self.embedding_model.get_image_features(**inputs) + elif self.model_type == "clip_finetuned": + # Fine-tuned CLIP uses get_image_features or forward with pixel_values + features = self.embedding_model.get_image_features(**inputs) else: # DINOv2 and other transformers use the CLS token or pooled output outputs = self.embedding_model(**inputs) @@ -370,8 +434,9 @@ class DetectLogosDETR: # Use CLS token from last_hidden_state features = outputs.last_hidden_state[:, 0, :] - # Normalize for cosine similarity - features = F.normalize(features, dim=-1) + # Normalize for cosine similarity (fine-tuned model already normalizes) + if self.model_type != "clip_finetuned": + features = F.normalize(features, dim=-1) return features diff --git a/pyproject.toml b/pyproject.toml index c7a24be..a80d6fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,4 +12,7 @@ dependencies = [ "tqdm>=4.67.1", "transformers>=4.57.3", "typing>=3.10.0.0", + "peft>=0.7.0", + "pyyaml>=6.0", + "torchvision>=0.20.0", ] diff --git a/train_clip_logo.py b/train_clip_logo.py new file mode 100644 index 0000000..44e74b5 --- /dev/null +++ b/train_clip_logo.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python3 +""" +Fine-tune CLIP vision encoder for logo recognition. + +This script trains a CLIP model using contrastive learning on the LogoDet-3K +dataset to improve logo embedding quality for similarity-based matching. 
+ +Usage: + # Train with YAML config + uv run python train_clip_logo.py --config configs/jetson_orin.yaml + + # Train with command-line overrides + uv run python train_clip_logo.py --config configs/jetson_orin.yaml \ + --learning-rate 5e-6 --max-epochs 30 + + # Resume from checkpoint + uv run python train_clip_logo.py --config configs/jetson_orin.yaml \ + --resume checkpoints/epoch_10.pt +""" + +import argparse +import logging +import random +import sys +from pathlib import Path + +import numpy as np +import torch + +from training.config import TrainingConfig +from training.dataset import create_dataloaders +from training.model import create_model +from training.trainer import Trainer + + +def setup_logging(log_level: str = "INFO") -> logging.Logger: + """Configure logging.""" + logging.basicConfig( + level=getattr(logging, log_level.upper()), + format="%(asctime)s [%(levelname)s] %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + return logging.getLogger(__name__) + + +def set_seed(seed: int) -> None: + """Set random seeds for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def parse_args() -> argparse.Namespace: + """Parse command-line arguments.""" + parser = argparse.ArgumentParser( + description="Fine-tune CLIP for logo recognition", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Config file + parser.add_argument( + "--config", + type=str, + help="Path to YAML configuration file", + ) + + # Dataset paths + parser.add_argument( + "--dataset-dir", + type=str, + help="Path to LogoDet-3K dataset", + ) + parser.add_argument( + "--reference-dir", + type=str, + help="Path to reference logos directory", + ) + parser.add_argument( + "--db-path", + type=str, + help="Path to SQLite database", + ) + + # Model + parser.add_argument( + "--base-model", + type=str, + help="Base CLIP model name or path", + ) + parser.add_argument( + "--lora-r", + type=int, + help="LoRA rank (0 to disable)", + ) + parser.add_argument( + "--freeze-layers", + type=int, + help="Number of transformer layers to freeze", + ) + + # Training + parser.add_argument( + "--batch-size", + type=int, + help="Batch size", + ) + parser.add_argument( + "--learning-rate", + type=float, + help="Learning rate", + ) + parser.add_argument( + "--max-epochs", + type=int, + help="Maximum number of epochs", + ) + parser.add_argument( + "--gradient-accumulation-steps", + type=int, + help="Gradient accumulation steps", + ) + + # Loss + parser.add_argument( + "--temperature", + type=float, + help="Temperature for InfoNCE loss", + ) + parser.add_argument( + "--loss-type", + choices=["infonce", "supcon", "triplet", "combined"], + help="Loss function type", + ) + + # Checkpointing + parser.add_argument( + "--checkpoint-dir", + type=str, + help="Directory for checkpoints", + ) + parser.add_argument( + "--output-dir", + type=str, + help="Directory for final model output", + ) + parser.add_argument( + "--resume", + type=str, + help="Path to checkpoint to resume from", + ) + + # Other + parser.add_argument( + "--seed", + type=int, + help="Random seed", + ) + parser.add_argument( + "--log-level", + type=str, + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + help="Logging level", + ) + parser.add_argument( + "--no-mixed-precision", + action="store_true", + help="Disable mixed precision training", + ) + + return parser.parse_args() + + +def main(): + """Main training entry point.""" + args = parse_args() + + 
# Setup logging + logger = setup_logging(args.log_level) + logger.info("CLIP Logo Fine-Tuning") + logger.info("=" * 60) + + # Load or create configuration + if args.config: + logger.info(f"Loading config from: {args.config}") + config = TrainingConfig.from_yaml(args.config) + else: + logger.info("Using default configuration") + config = TrainingConfig() + + # Apply command-line overrides + override_fields = [ + "dataset_dir", "reference_dir", "db_path", "base_model", + "lora_r", "freeze_layers", "batch_size", "learning_rate", + "max_epochs", "gradient_accumulation_steps", "temperature", + "loss_type", "checkpoint_dir", "output_dir", "seed", + ] + for field in override_fields: + arg_name = field.replace("_", "-") + arg_value = getattr(args, field.replace("-", "_"), None) + if arg_value is not None: + setattr(config, field, arg_value) + logger.info(f"Override: {field} = {arg_value}") + + if args.no_mixed_precision: + config.mixed_precision = False + logger.info("Override: mixed_precision = False") + + # Validate configuration + warnings = config.validate() + for warning in warnings: + logger.warning(f"Config warning: {warning}") + + # Set random seed + set_seed(config.seed) + logger.info(f"Random seed: {config.seed}") + + # Check paths exist + db_path = Path(config.db_path) + ref_dir = Path(config.reference_dir) + + if not db_path.exists(): + logger.error(f"Database not found: {db_path}") + logger.error("Run prepare_test_data.py first to create the database.") + sys.exit(1) + + if not ref_dir.exists(): + logger.error(f"Reference directory not found: {ref_dir}") + logger.error("Run prepare_test_data.py first to extract reference logos.") + sys.exit(1) + + # Create model + logger.info(f"Creating model from: {config.base_model}") + model, processor = create_model( + base_model=config.base_model, + lora_r=config.lora_r, + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout, + freeze_layers=config.freeze_layers, + use_gradient_checkpointing=config.use_gradient_checkpointing, + ) + + # Create dataloaders + logger.info("Creating dataloaders...") + train_loader, val_loader, test_loader = create_dataloaders( + db_path=str(config.db_path), + reference_dir=str(config.reference_dir), + batch_size=config.batch_size, + logos_per_batch=config.logos_per_batch, + samples_per_logo=config.samples_per_logo, + num_workers=config.num_workers, + train_split=config.train_split, + val_split=config.val_split, + test_split=config.test_split, + seed=config.seed, + augmentation_strength=config.augmentation_strength, + ) + + # Create trainer + trainer = Trainer( + model=model, + train_loader=train_loader, + val_loader=val_loader, + config=config, + logger=logger, + ) + + # Resume from checkpoint if specified + if args.resume: + resume_path = Path(args.resume) + if resume_path.exists(): + logger.info(f"Resuming from: {resume_path}") + # Set checkpoint dir to resume path's parent + if resume_path.is_file(): + config.checkpoint_dir = str(resume_path.parent) + trainer.load_checkpoint(resume_path.name) + else: + logger.warning(f"Resume checkpoint not found: {resume_path}") + + # Train + logger.info("\nStarting training...") + final_metrics = trainer.train() + + logger.info("\nTraining complete!") + logger.info(f" Best val loss: {final_metrics['best_val_loss']:.4f}") + logger.info(f" Best separation: {final_metrics['best_val_separation']:.4f}") + logger.info(f" Total epochs: {final_metrics['total_epochs']}") + logger.info(f" Total time: {final_metrics['total_time_minutes']:.1f} minutes") + + # Export model + 
output_path = trainer.export_model() + logger.info(f"\nModel exported to: {output_path}") + + # Print next steps + logger.info("\n" + "=" * 60) + logger.info("Next steps:") + logger.info(f"1. Test the fine-tuned model:") + logger.info(f" uv run python test_logo_detection.py -n 50 \\") + logger.info(f" -e {output_path} --matching-method multi-ref") + logger.info(f"") + logger.info(f"2. Compare with baseline:") + logger.info(f" uv run python test_logo_detection.py -n 50 \\") + logger.info(f" -e openai/clip-vit-large-patch14 --matching-method multi-ref") + + +if __name__ == "__main__": + main() diff --git a/training/__init__.py b/training/__init__.py new file mode 100644 index 0000000..cc8f7fd --- /dev/null +++ b/training/__init__.py @@ -0,0 +1,24 @@ +""" +CLIP fine-tuning module for logo recognition. + +This module provides tools for fine-tuning CLIP's vision encoder using +contrastive learning on the LogoDet-3K dataset. +""" + +from .config import TrainingConfig +from .dataset import LogoContrastiveDataset, create_dataloaders +from .model import LogoFineTunedCLIP +from .losses import InfoNCELoss, TripletLoss +from .trainer import Trainer +from .evaluation import EmbeddingEvaluator + +__all__ = [ + "TrainingConfig", + "LogoContrastiveDataset", + "create_dataloaders", + "LogoFineTunedCLIP", + "InfoNCELoss", + "TripletLoss", + "Trainer", + "EmbeddingEvaluator", +] diff --git a/training/config.py b/training/config.py new file mode 100644 index 0000000..caf6b78 --- /dev/null +++ b/training/config.py @@ -0,0 +1,141 @@ +""" +Training configuration for CLIP fine-tuning. +""" + +from dataclasses import dataclass, field +from pathlib import Path +from typing import List, Optional +import yaml + + +@dataclass +class TrainingConfig: + """Configuration for CLIP logo fine-tuning.""" + + # Base model + base_model: str = "openai/clip-vit-large-patch14" + + # Dataset paths + dataset_dir: str = "LogoDet-3K" + reference_dir: str = "reference_logos" + db_path: str = "test_data_mapping.db" + + # Data split ratios + train_split: float = 0.7 + val_split: float = 0.15 + test_split: float = 0.15 + + # Batch construction + batch_size: int = 16 + logos_per_batch: int = 32 + samples_per_logo: int = 4 + gradient_accumulation_steps: int = 8 + num_workers: int = 4 + + # Model architecture + lora_r: int = 16 + lora_alpha: int = 32 + lora_dropout: float = 0.1 + freeze_layers: int = 12 + use_gradient_checkpointing: bool = True + + # Training hyperparameters + learning_rate: float = 1e-5 + weight_decay: float = 0.01 + warmup_steps: int = 500 + max_epochs: int = 20 + mixed_precision: bool = True + + # Loss function + temperature: float = 0.07 + loss_type: str = "infonce" # "infonce" or "triplet" + triplet_margin: float = 0.3 + + # Early stopping + patience: int = 5 + min_delta: float = 0.001 + + # Checkpoints and output + checkpoint_dir: str = "checkpoints" + output_dir: str = "models/logo_detection/clip_finetuned" + save_every_n_epochs: int = 5 + + # Logging + log_every_n_steps: int = 10 + eval_every_n_epochs: int = 1 + + # Random seed for reproducibility + seed: int = 42 + + # Hard negative mining + use_hard_negatives: bool = False + hard_negative_start_epoch: int = 5 + hard_negatives_per_logo: int = 10 + + # Data augmentation + use_augmentation: bool = True + augmentation_strength: str = "medium" # "light", "medium", "strong" + + @classmethod + def from_yaml(cls, yaml_path: str) -> "TrainingConfig": + """Load configuration from YAML file.""" + with open(yaml_path, "r") as f: + config_dict = yaml.safe_load(f) + return 
cls(**config_dict) + + def to_yaml(self, yaml_path: str) -> None: + """Save configuration to YAML file.""" + Path(yaml_path).parent.mkdir(parents=True, exist_ok=True) + with open(yaml_path, "w") as f: + yaml.dump(self.__dict__, f, default_flow_style=False, sort_keys=False) + + def validate(self) -> List[str]: + """Validate configuration and return list of warnings.""" + warnings = [] + + # Check split ratios + total_split = self.train_split + self.val_split + self.test_split + if abs(total_split - 1.0) > 0.01: + warnings.append( + f"Split ratios sum to {total_split}, expected 1.0" + ) + + # Check batch construction + effective_batch = self.batch_size * self.gradient_accumulation_steps + if effective_batch < 64: + warnings.append( + f"Effective batch size ({effective_batch}) is small for contrastive learning. " + "Consider increasing batch_size or gradient_accumulation_steps." + ) + + # Check LoRA config + if self.lora_r > 0 and self.lora_alpha < self.lora_r: + warnings.append( + f"lora_alpha ({self.lora_alpha}) < lora_r ({self.lora_r}). " + "This may reduce LoRA effectiveness." + ) + + # Check freeze layers + if self.freeze_layers < 0: + warnings.append("freeze_layers should be >= 0") + + # Check temperature + if self.temperature <= 0: + warnings.append("temperature must be positive") + elif self.temperature > 1.0: + warnings.append( + f"temperature ({self.temperature}) is high. " + "Typical values are 0.05-0.1." + ) + + return warnings + + @property + def effective_batch_size(self) -> int: + """Calculate effective batch size with gradient accumulation.""" + return self.batch_size * self.gradient_accumulation_steps + + @property + def samples_per_batch(self) -> int: + """Total samples in one batch (logos_per_batch * samples_per_logo).""" + return self.logos_per_batch * self.samples_per_logo diff --git a/training/dataset.py b/training/dataset.py new file mode 100644 index 0000000..95b2fe1 --- /dev/null +++ b/training/dataset.py @@ -0,0 +1,467 @@ +""" +Dataset classes for contrastive learning on logo images. +""" + +import random +import sqlite3 +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import torch +from PIL import Image +from torch.utils.data import Dataset, DataLoader, Sampler +from torchvision import transforms + + +# CLIP normalization values +CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073] +CLIP_STD = [0.26862954, 0.26130258, 0.27577711] + + +def get_train_transforms(strength: str = "medium") -> transforms.Compose: + """ + Get training data augmentation transforms. 
+ + Args: + strength: Augmentation strength - "light", "medium", or "strong" + + Returns: + Composed transforms for training + """ + if strength == "light": + return transforms.Compose([ + transforms.Resize((224, 224)), + transforms.RandomHorizontalFlip(p=0.5), + transforms.ColorJitter(brightness=0.1, contrast=0.1), + transforms.ToTensor(), + transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD), + ]) + elif strength == "medium": + return transforms.Compose([ + transforms.Resize((224, 224)), + transforms.RandomHorizontalFlip(p=0.5), + transforms.RandomRotation(degrees=15), + transforms.ColorJitter( + brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05 + ), + transforms.RandomAffine( + degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1) + ), + transforms.RandomGrayscale(p=0.1), + transforms.ToTensor(), + transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD), + ]) + else: # strong + return transforms.Compose([ + transforms.Resize((256, 256)), + transforms.RandomCrop(224), + transforms.RandomHorizontalFlip(p=0.5), + transforms.RandomVerticalFlip(p=0.1), + transforms.RandomRotation(degrees=30), + transforms.ColorJitter( + brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1 + ), + transforms.RandomAffine( + degrees=0, translate=(0.15, 0.15), scale=(0.8, 1.2), shear=10 + ), + transforms.RandomGrayscale(p=0.2), + transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)), + transforms.ToTensor(), + transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD), + ]) + + +def get_val_transforms() -> transforms.Compose: + """Get validation/test transforms (no augmentation).""" + return transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD), + ]) + + +class LogoDataset: + """ + Manages logo data from the SQLite database. + + Handles loading logo-to-image mappings and splitting by logo brand. 
+ """ + + def __init__( + self, + db_path: str, + reference_dir: str, + train_split: float = 0.7, + val_split: float = 0.15, + test_split: float = 0.15, + seed: int = 42, + ): + self.db_path = Path(db_path) + self.reference_dir = Path(reference_dir) + self.seed = seed + + # Load logo-to-images mapping from database + self.logo_to_images = self._load_logo_mappings() + self.all_logos = list(self.logo_to_images.keys()) + + # Create logo-level splits + self.train_logos, self.val_logos, self.test_logos = self._split_logos( + train_split, val_split, test_split + ) + + def _load_logo_mappings(self) -> Dict[str, List[Path]]: + """Load logo name to image paths mapping from database.""" + conn = sqlite3.connect(self.db_path) + cursor = conn.cursor() + + cursor.execute(""" + SELECT ln.name, rl.filename + FROM reference_logos rl + JOIN logo_names ln ON rl.logo_name_id = ln.id + ORDER BY ln.name + """) + + logo_to_images: Dict[str, List[Path]] = {} + for logo_name, filename in cursor.fetchall(): + if logo_name not in logo_to_images: + logo_to_images[logo_name] = [] + logo_to_images[logo_name].append(self.reference_dir / filename) + + conn.close() + return logo_to_images + + def _split_logos( + self, + train_split: float, + val_split: float, + test_split: float, + ) -> Tuple[List[str], List[str], List[str]]: + """Split logos at brand level for train/val/test.""" + random.seed(self.seed) + logos = self.all_logos.copy() + random.shuffle(logos) + + n = len(logos) + train_end = int(n * train_split) + val_end = train_end + int(n * val_split) + + train_logos = logos[:train_end] + val_logos = logos[train_end:val_end] + test_logos = logos[val_end:] + + return train_logos, val_logos, test_logos + + def get_split_info(self) -> Dict[str, int]: + """Return information about the splits.""" + return { + "total_logos": len(self.all_logos), + "train_logos": len(self.train_logos), + "val_logos": len(self.val_logos), + "test_logos": len(self.test_logos), + "train_images": sum( + len(self.logo_to_images[l]) for l in self.train_logos + ), + "val_images": sum( + len(self.logo_to_images[l]) for l in self.val_logos + ), + "test_images": sum( + len(self.logo_to_images[l]) for l in self.test_logos + ), + } + + +class LogoContrastiveDataset(Dataset): + """ + Dataset for contrastive learning on logos. + + Each __getitem__ call returns a batch of images organized for contrastive + learning: K different logos with M samples each, ensuring positive pairs + exist within each batch. + """ + + def __init__( + self, + logo_data: LogoDataset, + split: str = "train", + logos_per_batch: int = 32, + samples_per_logo: int = 4, + transform: Optional[transforms.Compose] = None, + batches_per_epoch: int = 1000, + ): + """ + Initialize the contrastive dataset. 
+ + Args: + logo_data: LogoDataset instance with logo mappings + split: One of "train", "val", or "test" + logos_per_batch: Number of different logos per batch + samples_per_logo: Number of samples for each logo + transform: Image transforms to apply + batches_per_epoch: Number of batches per epoch + """ + self.logo_data = logo_data + self.logos_per_batch = logos_per_batch + self.samples_per_logo = samples_per_logo + self.transform = transform + self.batches_per_epoch = batches_per_epoch + + # Get logos for this split + if split == "train": + self.logos = logo_data.train_logos + elif split == "val": + self.logos = logo_data.val_logos + else: + self.logos = logo_data.test_logos + + # Filter logos with enough samples + self.valid_logos = [ + logo for logo in self.logos + if len(logo_data.logo_to_images[logo]) >= samples_per_logo + ] + + # For logos with fewer samples, we'll use with replacement + self.logos_needing_replacement = [ + logo for logo in self.logos + if len(logo_data.logo_to_images[logo]) < samples_per_logo + ] + + # Create label mapping + self.logo_to_label = { + logo: idx for idx, logo in enumerate(self.logos) + } + + def __len__(self) -> int: + return self.batches_per_epoch + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get a batch of images for contrastive learning. + + Returns: + images: Tensor of shape [K*M, 3, 224, 224] + labels: Tensor of shape [K*M] with logo class indices + """ + images = [] + labels = [] + + # Sample K logos for this batch + k = min(self.logos_per_batch, len(self.logos)) + batch_logos = random.sample(self.logos, k) + + for logo in batch_logos: + logo_images = self.logo_data.logo_to_images[logo] + + # Sample M images for this logo + if len(logo_images) >= self.samples_per_logo: + sampled_paths = random.sample(logo_images, self.samples_per_logo) + else: + # Sample with replacement if not enough images + sampled_paths = random.choices( + logo_images, k=self.samples_per_logo + ) + + # Load and transform images + for img_path in sampled_paths: + try: + img = Image.open(img_path).convert("RGB") + if self.transform: + img = self.transform(img) + else: + img = get_val_transforms()(img) + images.append(img) + labels.append(self.logo_to_label[logo]) + except Exception as e: + # Skip problematic images, sample another + continue + + # Stack into tensors + if len(images) == 0: + # Fallback: return dummy batch + return ( + torch.zeros(1, 3, 224, 224), + torch.zeros(1, dtype=torch.long), + ) + + images_tensor = torch.stack(images) + labels_tensor = torch.tensor(labels, dtype=torch.long) + + return images_tensor, labels_tensor + + +class BalancedBatchSampler(Sampler): + """ + Sampler that ensures each batch has a balanced distribution of logos. + + Used with a flattened dataset where each sample is a single image. 
+ """ + + def __init__( + self, + logo_labels: List[int], + logos_per_batch: int, + samples_per_logo: int, + num_batches: int, + ): + self.logo_labels = logo_labels + self.logos_per_batch = logos_per_batch + self.samples_per_logo = samples_per_logo + self.num_batches = num_batches + + # Group indices by logo + self.logo_to_indices: Dict[int, List[int]] = {} + for idx, label in enumerate(logo_labels): + if label not in self.logo_to_indices: + self.logo_to_indices[label] = [] + self.logo_to_indices[label].append(idx) + + self.all_logos = list(self.logo_to_indices.keys()) + + def __iter__(self): + for _ in range(self.num_batches): + batch_indices = [] + + # Sample logos for this batch + logos = random.sample( + self.all_logos, + min(self.logos_per_batch, len(self.all_logos)), + ) + + for logo in logos: + indices = self.logo_to_indices[logo] + if len(indices) >= self.samples_per_logo: + sampled = random.sample(indices, self.samples_per_logo) + else: + sampled = random.choices(indices, k=self.samples_per_logo) + batch_indices.extend(sampled) + + yield batch_indices + + def __len__(self): + return self.num_batches + + +def create_dataloaders( + db_path: str, + reference_dir: str, + batch_size: int = 16, + logos_per_batch: int = 32, + samples_per_logo: int = 4, + num_workers: int = 4, + train_split: float = 0.7, + val_split: float = 0.15, + test_split: float = 0.15, + seed: int = 42, + augmentation_strength: str = "medium", + batches_per_epoch: int = 1000, +) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]: + """ + Create train, validation, and optionally test dataloaders. + + Args: + db_path: Path to SQLite database + reference_dir: Directory containing reference logo images + batch_size: Not used directly (see logos_per_batch and samples_per_logo) + logos_per_batch: Number of different logos per batch + samples_per_logo: Samples per logo in batch + num_workers: Number of data loading workers + train_split: Fraction for training + val_split: Fraction for validation + test_split: Fraction for testing + seed: Random seed + augmentation_strength: "light", "medium", or "strong" + batches_per_epoch: Number of batches per training epoch + + Returns: + Tuple of (train_loader, val_loader, test_loader) + """ + # Load logo data + logo_data = LogoDataset( + db_path=db_path, + reference_dir=reference_dir, + train_split=train_split, + val_split=val_split, + test_split=test_split, + seed=seed, + ) + + # Print split info + split_info = logo_data.get_split_info() + print(f"Dataset loaded:") + print(f" Total logos: {split_info['total_logos']}") + print(f" Train: {split_info['train_logos']} logos, {split_info['train_images']} images") + print(f" Val: {split_info['val_logos']} logos, {split_info['val_images']} images") + print(f" Test: {split_info['test_logos']} logos, {split_info['test_images']} images") + + # Create datasets + train_dataset = LogoContrastiveDataset( + logo_data=logo_data, + split="train", + logos_per_batch=logos_per_batch, + samples_per_logo=samples_per_logo, + transform=get_train_transforms(augmentation_strength), + batches_per_epoch=batches_per_epoch, + ) + + val_dataset = LogoContrastiveDataset( + logo_data=logo_data, + split="val", + logos_per_batch=logos_per_batch, + samples_per_logo=samples_per_logo, + transform=get_val_transforms(), + batches_per_epoch=batches_per_epoch // 10, # Fewer val batches + ) + + test_dataset = LogoContrastiveDataset( + logo_data=logo_data, + split="test", + logos_per_batch=logos_per_batch, + samples_per_logo=samples_per_logo, + 
transform=get_val_transforms(), + batches_per_epoch=batches_per_epoch // 10, + ) if test_split > 0 else None + + # Create dataloaders + # Note: batch_size=1 because each __getitem__ already returns a batch + train_loader = DataLoader( + train_dataset, + batch_size=1, + shuffle=True, + num_workers=num_workers, + pin_memory=True, + collate_fn=_collate_contrastive_batch, + ) + + val_loader = DataLoader( + val_dataset, + batch_size=1, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + collate_fn=_collate_contrastive_batch, + ) + + test_loader = None + if test_dataset is not None: + test_loader = DataLoader( + test_dataset, + batch_size=1, + shuffle=False, + num_workers=num_workers, + pin_memory=True, + collate_fn=_collate_contrastive_batch, + ) + + return train_loader, val_loader, test_loader + + +def _collate_contrastive_batch( + batch: List[Tuple[torch.Tensor, torch.Tensor]] +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Collate function that unpacks pre-batched data. + + Since LogoContrastiveDataset already returns batched data, + we just squeeze the outer dimension. + """ + images, labels = batch[0] + return images, labels diff --git a/training/evaluation.py b/training/evaluation.py new file mode 100644 index 0000000..3b8519b --- /dev/null +++ b/training/evaluation.py @@ -0,0 +1,339 @@ +""" +Evaluation metrics for embedding quality. +""" + +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F +import numpy as np + + +class EmbeddingEvaluator: + """ + Evaluator for embedding quality metrics. + + Computes metrics that indicate how well the embeddings + separate different logo classes. + """ + + def compute_metrics( + self, + embeddings: torch.Tensor, + labels: torch.Tensor, + ) -> Dict[str, float]: + """ + Compute embedding quality metrics. 
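+
+        The returned dict contains "mean_pos_sim", "mean_neg_sim",
+        "separation" (mean_pos_sim - mean_neg_sim), "recall_at_1",
+        and "recall_at_5".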
+ + Args: + embeddings: [N, D] L2-normalized embeddings + labels: [N] integer class labels + + Returns: + Dict with metric names and values + """ + device = embeddings.device + batch_size = embeddings.shape[0] + + if batch_size <= 1: + return { + "mean_pos_sim": 0.0, + "mean_neg_sim": 0.0, + "separation": 0.0, + "recall_at_1": 0.0, + "recall_at_5": 0.0, + } + + # Compute similarity matrix + similarity = embeddings @ embeddings.T + + # Create masks + labels_col = labels.unsqueeze(0) + labels_row = labels.unsqueeze(1) + positive_mask = (labels_row == labels_col).float() + negative_mask = 1 - positive_mask + + # Remove diagonal from positive mask + identity = torch.eye(batch_size, device=device) + positive_mask = positive_mask - identity + + # Count pairs + num_positives = positive_mask.sum() + num_negatives = negative_mask.sum() + + # Mean positive similarity (excluding self) + if num_positives > 0: + pos_sims = (similarity * positive_mask).sum() / num_positives + mean_pos_sim = pos_sims.item() + else: + mean_pos_sim = 0.0 + + # Mean negative similarity + if num_negatives > 0: + neg_sims = (similarity * negative_mask).sum() / num_negatives + mean_neg_sim = neg_sims.item() + else: + mean_neg_sim = 0.0 + + # Separation: gap between positive and negative similarity + separation = mean_pos_sim - mean_neg_sim + + # Recall@K metrics + recall_at_1 = self._compute_recall_at_k(similarity, labels, k=1) + recall_at_5 = self._compute_recall_at_k(similarity, labels, k=5) + + return { + "mean_pos_sim": mean_pos_sim, + "mean_neg_sim": mean_neg_sim, + "separation": separation, + "recall_at_1": recall_at_1, + "recall_at_5": recall_at_5, + } + + def _compute_recall_at_k( + self, + similarity: torch.Tensor, + labels: torch.Tensor, + k: int = 1, + ) -> float: + """ + Compute Recall@K for nearest neighbor retrieval. + + For each sample, check if the k nearest neighbors (excluding self) + contain at least one sample with the same label. + + Args: + similarity: [N, N] similarity matrix + labels: [N] class labels + k: Number of neighbors to consider + + Returns: + Recall@K score (0 to 1) + """ + batch_size = similarity.shape[0] + if batch_size <= 1: + return 0.0 + + # Mask out self-similarity + similarity = similarity.clone() + similarity.fill_diagonal_(float("-inf")) + + # Get top-k indices + _, top_k_indices = similarity.topk(min(k, batch_size - 1), dim=1) + + # Check if any of top-k have same label + correct = 0 + for i in range(batch_size): + query_label = labels[i] + retrieved_labels = labels[top_k_indices[i]] + if (retrieved_labels == query_label).any(): + correct += 1 + + return correct / batch_size + + def compute_detailed_metrics( + self, + embeddings: torch.Tensor, + labels: torch.Tensor, + label_names: Optional[List[str]] = None, + ) -> Dict: + """ + Compute detailed per-class metrics. 
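+
+        Extends compute_metrics with per-class intra-class similarity,
+        inter-class similarity, and their difference ("class_separation"),
+        plus the min/max/std of the separations across classes.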
+ + Args: + embeddings: [N, D] embeddings + labels: [N] class labels + label_names: Optional list of label names + + Returns: + Dict with detailed metrics including per-class stats + """ + basic_metrics = self.compute_metrics(embeddings, labels) + + # Per-class statistics + unique_labels = labels.unique() + per_class_stats = {} + + similarity = embeddings @ embeddings.T + + for label in unique_labels: + mask = labels == label + class_embeddings = embeddings[mask] + class_size = mask.sum().item() + + if class_size > 1: + # Intra-class similarity + class_sim = class_embeddings @ class_embeddings.T + # Exclude diagonal + mask_diag = ~torch.eye(class_size, dtype=torch.bool, device=class_sim.device) + intra_sim = class_sim[mask_diag].mean().item() + else: + intra_sim = 1.0 + + # Inter-class similarity (to other classes) + other_mask = labels != label + if other_mask.any(): + inter_sim = similarity[mask][:, other_mask].mean().item() + else: + inter_sim = 0.0 + + class_name = label_names[label.item()] if label_names else str(label.item()) + per_class_stats[class_name] = { + "size": class_size, + "intra_class_sim": intra_sim, + "inter_class_sim": inter_sim, + "class_separation": intra_sim - inter_sim, + } + + # Aggregate per-class stats + if per_class_stats: + separations = [s["class_separation"] for s in per_class_stats.values()] + min_separation = min(separations) + max_separation = max(separations) + std_separation = np.std(separations) + else: + min_separation = max_separation = std_separation = 0.0 + + return { + **basic_metrics, + "per_class": per_class_stats, + "min_class_separation": min_separation, + "max_class_separation": max_separation, + "std_class_separation": std_separation, + } + + +class SimilarityAnalyzer: + """ + Analyze similarity distributions for debugging and tuning. + """ + + @staticmethod + def analyze_similarity_distribution( + embeddings: torch.Tensor, + labels: torch.Tensor, + ) -> Dict[str, np.ndarray]: + """ + Get similarity distributions for positive and negative pairs. + + Useful for choosing appropriate thresholds. + + Args: + embeddings: [N, D] embeddings + labels: [N] class labels + + Returns: + Dict with 'positive_sims' and 'negative_sims' arrays + """ + similarity = (embeddings @ embeddings.T).cpu().numpy() + labels_np = labels.cpu().numpy() + + batch_size = len(labels_np) + positive_sims = [] + negative_sims = [] + + for i in range(batch_size): + for j in range(i + 1, batch_size): + if labels_np[i] == labels_np[j]: + positive_sims.append(similarity[i, j]) + else: + negative_sims.append(similarity[i, j]) + + return { + "positive_sims": np.array(positive_sims), + "negative_sims": np.array(negative_sims), + } + + @staticmethod + def find_hard_pairs( + embeddings: torch.Tensor, + labels: torch.Tensor, + n_hard: int = 10, + ) -> Tuple[List[Tuple[int, int, float]], List[Tuple[int, int, float]]]: + """ + Find hardest positive and negative pairs. 
+
+        Hard positives: same label but low similarity
+        Hard negatives: different label but high similarity
+
+        Args:
+            embeddings: [N, D] embeddings
+            labels: [N] class labels
+            n_hard: Number of hard pairs to return
+
+        Returns:
+            Tuple of (hard_positives, hard_negatives)
+            Each is a list of (idx1, idx2, similarity) tuples
+        """
+        similarity = embeddings @ embeddings.T
+        batch_size = len(labels)
+
+        hard_positives = []  # Low similarity, same label
+        hard_negatives = []  # High similarity, different label
+
+        for i in range(batch_size):
+            for j in range(i + 1, batch_size):
+                sim = similarity[i, j].item()
+                if labels[i] == labels[j]:
+                    hard_positives.append((i, j, sim))
+                else:
+                    hard_negatives.append((i, j, sim))
+
+        # Sort: hard positives by ascending similarity (lowest first)
+        hard_positives.sort(key=lambda x: x[2])
+
+        # Sort: hard negatives by descending similarity (highest first)
+        hard_negatives.sort(key=lambda x: -x[2])
+
+        return hard_positives[:n_hard], hard_negatives[:n_hard]
+
+    @staticmethod
+    def compute_confusion_pairs(
+        embeddings: torch.Tensor,
+        labels: torch.Tensor,
+        label_names: Optional[List[str]] = None,
+        top_k: int = 10,
+    ) -> List[Dict]:
+        """
+        Find pairs of classes that are most confused (highest cross-class similarity).
+
+        Args:
+            embeddings: [N, D] embeddings
+            labels: [N] class labels
+            label_names: Optional label names
+            top_k: Number of confused pairs to return
+
+        Returns:
+            List of dicts with class pairs and their similarity
+        """
+        unique_labels = labels.unique()
+        class_centroids = {}
+
+        # Compute class centroids
+        for label in unique_labels:
+            mask = labels == label
+            centroid = embeddings[mask].mean(dim=0)
+            centroid = F.normalize(centroid, dim=0)
+            class_centroids[label.item()] = centroid
+
+        # Compute pairwise centroid similarities
+        confusions = []
+        label_list = list(class_centroids.keys())
+
+        for i, label1 in enumerate(label_list):
+            for label2 in label_list[i + 1:]:
+                sim = (class_centroids[label1] @ class_centroids[label2]).item()
+                name1 = label_names[label1] if label_names else str(label1)
+                name2 = label_names[label2] if label_names else str(label2)
+                confusions.append({
+                    "class1": name1,
+                    "class2": name2,
+                    "label1": label1,
+                    "label2": label2,
+                    "centroid_similarity": sim,
+                })
+
+        # Sort by similarity (highest first)
+        confusions.sort(key=lambda x: -x["centroid_similarity"])
+
+        return confusions[:top_k]
diff --git a/training/losses.py b/training/losses.py
new file mode 100644
index 0000000..c6b6b2e
--- /dev/null
+++ b/training/losses.py
@@ -0,0 +1,326 @@
+"""
+Loss functions for contrastive learning on logo embeddings.
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional
+
+
+class InfoNCELoss(nn.Module):
+    """
+    Normalized Temperature-scaled Cross Entropy Loss (InfoNCE).
+
+    This is the contrastive loss used in CLIP training. It maximizes
+    similarity between embeddings of the same logo class while
+    minimizing similarity to embeddings of different classes.
+
+    For a batch with N samples:
+    - Each sample is an anchor
+    - Positive pairs: samples with the same label
+    - Negative pairs: samples with different labels
+
+    The loss for each anchor is the mean, over its positive pairs, of:
+        -log(exp(sim(anchor, pos)/temp) / sum(exp(sim(anchor, other)/temp)))
+    where the denominator sums over all other samples in the batch
+    (self-similarity is excluded).
+    """
+
+    def __init__(self, temperature: float = 0.07):
+        """
+        Initialize InfoNCE loss.
+
+        Args:
+            temperature: Scaling factor for similarities (0.05-0.1 typical).
+                Lower temperature makes the distribution sharper.
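+                For example, with the default 0.07 a cosine-similarity gap
+                of 0.1 between two candidates becomes a logit gap of about
+                1.4 (0.1 / 0.07) before the softmax, so small similarity
+                differences are strongly amplified.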
+ """ + super().__init__() + self.temperature = temperature + + def forward( + self, + embeddings: torch.Tensor, + labels: torch.Tensor, + ) -> torch.Tensor: + """ + Compute InfoNCE loss for a batch of embeddings. + + Args: + embeddings: [N, D] L2-normalized embeddings + labels: [N] integer logo class labels + + Returns: + Scalar loss value + """ + device = embeddings.device + batch_size = embeddings.shape[0] + + if batch_size <= 1: + return torch.tensor(0.0, device=device, requires_grad=True) + + # Compute similarity matrix [N, N] + # Since embeddings are L2-normalized, dot product = cosine similarity + similarity = embeddings @ embeddings.T / self.temperature + + # Create positive mask: same label = 1, different = 0 + labels_col = labels.unsqueeze(0) # [1, N] + labels_row = labels.unsqueeze(1) # [N, 1] + positive_mask = (labels_row == labels_col).float() # [N, N] + + # Remove self-similarity from positives (diagonal) + identity = torch.eye(batch_size, device=device) + positive_mask = positive_mask - identity + + # Count positives per anchor (avoid division by zero) + num_positives = positive_mask.sum(dim=1) + has_positives = num_positives > 0 + + # If no positives exist for any anchor, return zero loss + if not has_positives.any(): + return torch.tensor(0.0, device=device, requires_grad=True) + + # Mask out self-similarity with large negative value + similarity = similarity - identity * 1e9 + + # Compute log-softmax over similarities + log_softmax = F.log_softmax(similarity, dim=1) + + # Sum log probabilities of positive pairs + positive_log_probs = (log_softmax * positive_mask).sum(dim=1) + + # Average over number of positives (only for anchors with positives) + loss_per_anchor = torch.zeros(batch_size, device=device) + loss_per_anchor[has_positives] = ( + -positive_log_probs[has_positives] / num_positives[has_positives] + ) + + return loss_per_anchor.mean() + + +class SupConLoss(nn.Module): + """ + Supervised Contrastive Loss. + + Similar to InfoNCE but uses a different formulation that + considers each positive pair separately rather than averaging. + + Reference: https://arxiv.org/abs/2004.11362 + """ + + def __init__(self, temperature: float = 0.07): + super().__init__() + self.temperature = temperature + + def forward( + self, + embeddings: torch.Tensor, + labels: torch.Tensor, + ) -> torch.Tensor: + """ + Compute Supervised Contrastive loss. 
+ + Args: + embeddings: [N, D] L2-normalized embeddings + labels: [N] integer logo class labels + + Returns: + Scalar loss value + """ + device = embeddings.device + batch_size = embeddings.shape[0] + + if batch_size <= 1: + return torch.tensor(0.0, device=device, requires_grad=True) + + # Compute similarity matrix + similarity = embeddings @ embeddings.T / self.temperature + + # Create masks + labels_col = labels.unsqueeze(0) + labels_row = labels.unsqueeze(1) + positive_mask = (labels_row == labels_col).float() + identity = torch.eye(batch_size, device=device) + + # Remove self from positives + positive_mask = positive_mask - identity + + # Number of positives per anchor + num_positives = positive_mask.sum(dim=1) + has_positives = num_positives > 0 + + if not has_positives.any(): + return torch.tensor(0.0, device=device, requires_grad=True) + + # For numerical stability, subtract max similarity + sim_max, _ = similarity.max(dim=1, keepdim=True) + similarity = similarity - sim_max.detach() + + # Compute exp(similarity) with self masked out + exp_sim = torch.exp(similarity) * (1 - identity) + + # Denominator: sum of exp over all pairs except self + log_prob = similarity - torch.log(exp_sim.sum(dim=1, keepdim=True) + 1e-8) + + # Mean of log-prob over positive pairs + mean_log_prob_pos = (positive_mask * log_prob).sum(dim=1) / ( + num_positives + 1e-8 + ) + + # Loss is negative mean log probability + loss = -mean_log_prob_pos[has_positives].mean() + + return loss + + +class TripletLoss(nn.Module): + """ + Triplet loss with online hard mining. + + For each anchor: + - Hardest positive: most distant sample with same label + - Hardest negative: closest sample with different label + + Loss = max(0, d(anchor, hardest_pos) - d(anchor, hardest_neg) + margin) + + This is an alternative to InfoNCE for when batch sizes are small. + """ + + def __init__(self, margin: float = 0.3): + """ + Initialize Triplet loss. + + Args: + margin: Minimum required gap between positive and negative distances + """ + super().__init__() + self.margin = margin + + def forward( + self, + embeddings: torch.Tensor, + labels: torch.Tensor, + ) -> torch.Tensor: + """ + Compute triplet loss with online hard mining. 
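+
+        For example, with margin 0.3 an anchor contributes zero loss only
+        once its hardest positive is at least 0.3 closer (in cosine
+        distance) than its hardest negative; otherwise the loss equals the
+        remaining gap.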
+ + Args: + embeddings: [N, D] L2-normalized embeddings + labels: [N] integer logo class labels + + Returns: + Scalar loss value + """ + device = embeddings.device + batch_size = embeddings.shape[0] + + if batch_size <= 1: + return torch.tensor(0.0, device=device, requires_grad=True) + + # Compute pairwise cosine distances (1 - cosine_similarity) + # For normalized vectors: distance = 1 - dot_product + similarity = embeddings @ embeddings.T + distances = 1 - similarity + + # Create masks + labels_col = labels.unsqueeze(0) + labels_row = labels.unsqueeze(1) + positive_mask = (labels_row == labels_col).float() + negative_mask = 1 - positive_mask + + # Remove self from positives (diagonal) + identity = torch.eye(batch_size, device=device) + positive_mask = positive_mask - identity + + # Check if we have any valid triplets + has_positives = positive_mask.sum(dim=1) > 0 + has_negatives = negative_mask.sum(dim=1) > 0 + valid_anchors = has_positives & has_negatives + + if not valid_anchors.any(): + return torch.tensor(0.0, device=device, requires_grad=True) + + # For each anchor, find hardest positive (max distance among positives) + # Set negatives to -inf so they don't affect max + pos_distances = distances.clone() + pos_distances[positive_mask == 0] = float("-inf") + hardest_positive, _ = pos_distances.max(dim=1) + + # For each anchor, find hardest negative (min distance among negatives) + # Set positives to inf so they don't affect min + neg_distances = distances.clone() + neg_distances[negative_mask == 0] = float("inf") + hardest_negative, _ = neg_distances.min(dim=1) + + # Triplet loss: want positive to be closer than negative by margin + triplet_loss = F.relu( + hardest_positive - hardest_negative + self.margin + ) + + # Average over valid anchors only + loss = triplet_loss[valid_anchors].mean() + + return loss + + +class CombinedLoss(nn.Module): + """ + Combined loss function with weighted InfoNCE and Triplet losses. + + Can help stabilize training by combining the benefits of both losses. + """ + + def __init__( + self, + temperature: float = 0.07, + triplet_margin: float = 0.3, + infonce_weight: float = 1.0, + triplet_weight: float = 0.5, + ): + super().__init__() + self.infonce = InfoNCELoss(temperature=temperature) + self.triplet = TripletLoss(margin=triplet_margin) + self.infonce_weight = infonce_weight + self.triplet_weight = triplet_weight + + def forward( + self, + embeddings: torch.Tensor, + labels: torch.Tensor, + ) -> torch.Tensor: + infonce_loss = self.infonce(embeddings, labels) + triplet_loss = self.triplet(embeddings, labels) + + return ( + self.infonce_weight * infonce_loss + + self.triplet_weight * triplet_loss + ) + + +def get_loss_function( + loss_type: str = "infonce", + temperature: float = 0.07, + triplet_margin: float = 0.3, +) -> nn.Module: + """ + Factory function to create loss function. 
+ + Args: + loss_type: One of "infonce", "supcon", "triplet", or "combined" + temperature: Temperature for InfoNCE/SupCon + triplet_margin: Margin for triplet loss + + Returns: + Loss function module + """ + if loss_type == "infonce": + return InfoNCELoss(temperature=temperature) + elif loss_type == "supcon": + return SupConLoss(temperature=temperature) + elif loss_type == "triplet": + return TripletLoss(margin=triplet_margin) + elif loss_type == "combined": + return CombinedLoss( + temperature=temperature, + triplet_margin=triplet_margin, + ) + else: + raise ValueError(f"Unknown loss type: {loss_type}") diff --git a/training/model.py b/training/model.py new file mode 100644 index 0000000..90a30b9 --- /dev/null +++ b/training/model.py @@ -0,0 +1,335 @@ +""" +Fine-tunable CLIP model wrapper with LoRA support. +""" + +import json +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import CLIPModel, CLIPProcessor + +# Check if peft is available for LoRA +try: + from peft import LoraConfig, get_peft_model, PeftModel + PEFT_AVAILABLE = True +except ImportError: + PEFT_AVAILABLE = False + LoraConfig = None + get_peft_model = None + PeftModel = None + + +class LogoFineTunedCLIP(nn.Module): + """ + CLIP vision encoder fine-tuned for logo similarity. + + Preserves embedding interface for compatibility with DetectLogosDETR: + - Same embedding dimensionality (768 for ViT-L/14) + - L2 normalized outputs + - Works with existing get_image_features() pattern + + Supports: + - LoRA for memory-efficient fine-tuning + - Layer freezing for transfer learning + - Gradient checkpointing for memory optimization + """ + + def __init__( + self, + vision_model: nn.Module, + lora_r: int = 16, + lora_alpha: int = 32, + lora_dropout: float = 0.1, + freeze_layers: int = 12, + use_gradient_checkpointing: bool = True, + add_projection_head: bool = True, + ): + """ + Initialize the fine-tunable CLIP wrapper. + + Args: + vision_model: CLIP vision model (CLIPVisionModel) + lora_r: Rank of LoRA low-rank matrices (0 to disable) + lora_alpha: LoRA scaling factor + lora_dropout: Dropout for LoRA layers + freeze_layers: Number of transformer layers to freeze (from bottom) + use_gradient_checkpointing: Enable gradient checkpointing + add_projection_head: Add trainable projection head + """ + super().__init__() + + self.vision_model = vision_model + self.embedding_dim = vision_model.config.hidden_size + self.freeze_layers = freeze_layers + self.lora_r = lora_r + self.lora_alpha = lora_alpha + + # Enable gradient checkpointing for memory efficiency + if use_gradient_checkpointing: + if hasattr(self.vision_model, "gradient_checkpointing_enable"): + self.vision_model.gradient_checkpointing_enable() + + # Freeze lower layers + self._freeze_layers(freeze_layers) + + # Apply LoRA to attention layers in upper blocks + self.peft_applied = False + if PEFT_AVAILABLE and lora_r > 0: + self._apply_lora(lora_r, lora_alpha, lora_dropout) + self.peft_applied = True + elif lora_r > 0 and not PEFT_AVAILABLE: + print( + "Warning: peft not installed. LoRA disabled. 
" + "Install with: pip install peft" + ) + + # Optional projection head for fine-tuning + self.add_projection_head = add_projection_head + if add_projection_head: + self.projection = nn.Sequential( + nn.Linear(self.embedding_dim, self.embedding_dim), + nn.LayerNorm(self.embedding_dim), + ) + else: + self.projection = nn.Identity() + + def _freeze_layers(self, num_layers: int) -> None: + """Freeze the first N transformer layers and embeddings.""" + if num_layers <= 0: + return + + # Freeze embeddings + if hasattr(self.vision_model, "embeddings"): + for param in self.vision_model.embeddings.parameters(): + param.requires_grad = False + + # Freeze specified number of encoder layers + if hasattr(self.vision_model, "encoder"): + for i, layer in enumerate(self.vision_model.encoder.layers): + if i < num_layers: + for param in layer.parameters(): + param.requires_grad = False + + def _apply_lora( + self, + r: int, + alpha: int, + dropout: float, + ) -> None: + """Apply LoRA adapters to attention layers.""" + if not PEFT_AVAILABLE: + return + + # Configure LoRA for vision transformer + lora_config = LoraConfig( + r=r, + lora_alpha=alpha, + lora_dropout=dropout, + target_modules=["q_proj", "v_proj"], + bias="none", + modules_to_save=[], # Don't save any full modules + ) + + self.vision_model = get_peft_model(self.vision_model, lora_config) + + def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: + """ + Extract normalized embeddings for logo images. + + Args: + pixel_values: [batch, 3, 224, 224] preprocessed images + + Returns: + embeddings: [batch, embedding_dim] L2-normalized + """ + # Get vision features + outputs = self.vision_model(pixel_values=pixel_values) + + # Use pooler output (CLS token projection) if available + if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None: + features = outputs.pooler_output + else: + # Fall back to CLS token from last hidden state + features = outputs.last_hidden_state[:, 0, :] + + # Apply projection head + features = self.projection(features) + + # L2 normalize for cosine similarity + features = F.normalize(features, dim=-1) + + return features + + def get_image_features(self, **kwargs) -> torch.Tensor: + """ + Compatibility method matching CLIP's interface. + + Used by DetectLogosDETR._get_embedding_pil(). + """ + return self.forward(kwargs["pixel_values"]) + + def get_trainable_parameters(self) -> List[torch.nn.Parameter]: + """Return list of trainable parameters.""" + return [p for p in self.parameters() if p.requires_grad] + + def get_parameter_count(self) -> Dict[str, int]: + """Return count of trainable and total parameters.""" + total = sum(p.numel() for p in self.parameters()) + trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) + return { + "total": total, + "trainable": trainable, + "frozen": total - trainable, + "trainable_percent": 100 * trainable / total if total > 0 else 0, + } + + def save_pretrained(self, output_dir: str) -> None: + """ + Save model in HuggingFace-compatible format. 
+ + Args: + output_dir: Directory to save model files + """ + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + # Save model weights + if self.peft_applied and PEFT_AVAILABLE: + # Save LoRA weights separately + self.vision_model.save_pretrained(output_path / "vision_lora") + # Save projection head + torch.save( + self.projection.state_dict(), + output_path / "projection_head.bin", + ) + else: + # Save full model state + torch.save(self.state_dict(), output_path / "pytorch_model.bin") + + # Save config + config = { + "model_type": "clip_logo_finetuned", + "embedding_dim": self.embedding_dim, + "lora_r": self.lora_r, + "lora_alpha": self.lora_alpha, + "freeze_layers": self.freeze_layers, + "add_projection_head": self.add_projection_head, + "peft_applied": self.peft_applied, + } + + with open(output_path / "config.json", "w") as f: + json.dump(config, f, indent=2) + + @classmethod + def from_pretrained( + cls, + model_path: str, + base_model: str = "openai/clip-vit-large-patch14", + device: Optional[torch.device] = None, + ) -> "LogoFineTunedCLIP": + """ + Load a fine-tuned model from saved weights. + + Args: + model_path: Path to saved model directory + base_model: Base CLIP model name (for architecture) + device: Device to load model on + + Returns: + Loaded LogoFineTunedCLIP model + """ + model_path = Path(model_path) + + # Load config + with open(model_path / "config.json", "r") as f: + config = json.load(f) + + # Load base CLIP model + clip_model = CLIPModel.from_pretrained(base_model) + + # Create model instance + model = cls( + vision_model=clip_model.vision_model, + lora_r=config.get("lora_r", 0), + lora_alpha=config.get("lora_alpha", 1), + freeze_layers=config.get("freeze_layers", 12), + add_projection_head=config.get("add_projection_head", True), + use_gradient_checkpointing=False, # Not needed for inference + ) + + # Load weights + if config.get("peft_applied", False) and PEFT_AVAILABLE: + # Load LoRA weights + lora_path = model_path / "vision_lora" + if lora_path.exists(): + model.vision_model = PeftModel.from_pretrained( + model.vision_model, lora_path + ) + # Load projection head + proj_path = model_path / "projection_head.bin" + if proj_path.exists(): + model.projection.load_state_dict(torch.load(proj_path)) + else: + # Load full model state + weights_path = model_path / "pytorch_model.bin" + if weights_path.exists(): + model.load_state_dict(torch.load(weights_path)) + + if device is not None: + model = model.to(device) + + return model + + +def create_model( + base_model: str = "openai/clip-vit-large-patch14", + lora_r: int = 16, + lora_alpha: int = 32, + lora_dropout: float = 0.1, + freeze_layers: int = 12, + use_gradient_checkpointing: bool = True, + device: Optional[torch.device] = None, +) -> Tuple[LogoFineTunedCLIP, CLIPProcessor]: + """ + Create a fine-tunable CLIP model and processor. 
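+
+    Example (illustrative sketch; assumes `img` is a PIL image already
+    loaded by the caller):
+        >>> model, processor = create_model(lora_r=16, freeze_layers=12)
+        >>> inputs = processor(images=img, return_tensors="pt")
+        >>> embeddings = model(inputs["pixel_values"])  # [1, D], L2-normalized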
+ + Args: + base_model: HuggingFace model name or path + lora_r: LoRA rank (0 to disable) + lora_alpha: LoRA scaling factor + lora_dropout: LoRA dropout + freeze_layers: Number of layers to freeze + use_gradient_checkpointing: Enable gradient checkpointing + device: Device to load model on + + Returns: + Tuple of (model, processor) + """ + # Load base CLIP model + clip_model = CLIPModel.from_pretrained(base_model) + processor = CLIPProcessor.from_pretrained(base_model) + + # Create fine-tunable wrapper + model = LogoFineTunedCLIP( + vision_model=clip_model.vision_model, + lora_r=lora_r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + freeze_layers=freeze_layers, + use_gradient_checkpointing=use_gradient_checkpointing, + ) + + if device is not None: + model = model.to(device) + + # Print parameter info + param_info = model.get_parameter_count() + print(f"Model created:") + print(f" Total parameters: {param_info['total']:,}") + print(f" Trainable: {param_info['trainable']:,} ({param_info['trainable_percent']:.2f}%)") + print(f" Frozen: {param_info['frozen']:,}") + + return model, processor diff --git a/training/trainer.py b/training/trainer.py new file mode 100644 index 0000000..e624283 --- /dev/null +++ b/training/trainer.py @@ -0,0 +1,405 @@ +""" +Training loop with checkpointing, mixed precision, and evaluation. +""" + +import json +import logging +import time +from pathlib import Path +from typing import Dict, Optional, Tuple + +import torch +import torch.nn as nn +from torch.optim import AdamW +from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR +from torch.utils.data import DataLoader +from tqdm import tqdm + +from .config import TrainingConfig +from .losses import get_loss_function +from .evaluation import EmbeddingEvaluator + +# Check if amp is available +try: + from torch.cuda.amp import autocast, GradScaler + AMP_AVAILABLE = True +except ImportError: + AMP_AVAILABLE = False + autocast = None + GradScaler = None + + +class Trainer: + """ + Trainer for fine-tuning CLIP on logo recognition. + + Features: + - Mixed precision training (FP16) + - Gradient accumulation + - Gradient checkpointing (via model) + - Cosine annealing LR scheduler + - Early stopping + - Checkpoint saving/loading + - Evaluation during training + """ + + def __init__( + self, + model: nn.Module, + train_loader: DataLoader, + val_loader: DataLoader, + config: TrainingConfig, + logger: Optional[logging.Logger] = None, + ): + """ + Initialize the trainer. 
+
+        Args:
+            model: LogoFineTunedCLIP model
+            train_loader: Training dataloader
+            val_loader: Validation dataloader
+            config: Training configuration
+            logger: Optional logger instance
+        """
+        self.model = model
+        self.train_loader = train_loader
+        self.val_loader = val_loader
+        self.config = config
+        self.logger = logger or logging.getLogger(__name__)
+
+        # Device setup
+        self.device = torch.device(
+            "cuda" if torch.cuda.is_available() else "cpu"
+        )
+        self.model.to(self.device)
+        self.logger.info(f"Using device: {self.device}")
+
+        # Optimizer - only trainable parameters
+        trainable_params = [p for p in model.parameters() if p.requires_grad]
+        self.logger.info(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")
+
+        self.optimizer = AdamW(
+            trainable_params,
+            lr=config.learning_rate,
+            weight_decay=config.weight_decay,
+        )
+
+        # Learning rate scheduler.
+        # OneCycleLR is stepped once per optimizer step (see _train_epoch),
+        # so the cycle length must account for gradient accumulation;
+        # warmup_steps is interpreted in optimizer steps.
+        steps_per_epoch = max(
+            1, len(train_loader) // config.gradient_accumulation_steps
+        )
+        total_steps = steps_per_epoch * config.max_epochs
+        self.scheduler = OneCycleLR(
+            self.optimizer,
+            max_lr=config.learning_rate,
+            total_steps=total_steps,
+            pct_start=config.warmup_steps / total_steps if total_steps > 0 else 0.1,
+            anneal_strategy="cos",
+        )
+
+        # Mixed precision training
+        self.use_amp = config.mixed_precision and AMP_AVAILABLE and self.device.type == "cuda"
+        if self.use_amp:
+            self.scaler = GradScaler()
+            self.logger.info("Mixed precision training enabled")
+        else:
+            self.scaler = None
+            if config.mixed_precision and not AMP_AVAILABLE:
+                self.logger.warning("Mixed precision requested but not available")
+
+        # Loss function
+        self.criterion = get_loss_function(
+            loss_type=config.loss_type,
+            temperature=config.temperature,
+            triplet_margin=config.triplet_margin,
+        )
+
+        # Evaluator
+        self.evaluator = EmbeddingEvaluator()
+
+        # Training state
+        self.epoch = 0
+        self.global_step = 0
+        self.best_val_loss = float("inf")
+        self.best_val_separation = float("-inf")
+        self.patience_counter = 0
+        self.training_history = []
+
+    def train(self) -> Dict[str, float]:
+        """
+        Main training loop.
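+
+        Runs one training epoch at a time, validates every
+        eval_every_n_epochs epochs, saves checkpoints/best.pt whenever
+        validation separation (or loss) improves by at least min_delta,
+        and stops early after `patience` evaluations without improvement.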
+ + Returns: + Dict with final training metrics + """ + self.logger.info("Starting training...") + self.logger.info(f" Epochs: {self.config.max_epochs}") + self.logger.info(f" Batch size: {self.config.batch_size}") + self.logger.info(f" Gradient accumulation: {self.config.gradient_accumulation_steps}") + self.logger.info(f" Effective batch: {self.config.effective_batch_size}") + self.logger.info(f" Learning rate: {self.config.learning_rate}") + + start_time = time.time() + + for epoch in range(self.epoch, self.config.max_epochs): + self.epoch = epoch + self.logger.info(f"\nEpoch {epoch + 1}/{self.config.max_epochs}") + + # Training epoch + train_metrics = self._train_epoch() + self.logger.info( + f"Train - Loss: {train_metrics['loss']:.4f}, " + f"LR: {train_metrics['lr']:.2e}" + ) + + # Validation + if (epoch + 1) % self.config.eval_every_n_epochs == 0: + val_metrics = self._validate() + self.logger.info( + f"Val - Loss: {val_metrics['loss']:.4f}, " + f"Pos Sim: {val_metrics['mean_pos_sim']:.3f}, " + f"Neg Sim: {val_metrics['mean_neg_sim']:.3f}, " + f"Separation: {val_metrics['separation']:.3f}" + ) + + # Record history + self.training_history.append({ + "epoch": epoch + 1, + "train_loss": train_metrics["loss"], + "val_loss": val_metrics["loss"], + "val_separation": val_metrics["separation"], + "val_pos_sim": val_metrics["mean_pos_sim"], + "val_neg_sim": val_metrics["mean_neg_sim"], + }) + + # Checkpointing based on separation (primary) or loss (secondary) + improved = False + if val_metrics["separation"] > self.best_val_separation + self.config.min_delta: + self.best_val_separation = val_metrics["separation"] + improved = True + elif val_metrics["loss"] < self.best_val_loss - self.config.min_delta: + self.best_val_loss = val_metrics["loss"] + improved = True + + if improved: + self.patience_counter = 0 + self._save_checkpoint("best.pt") + self.logger.info("New best model saved!") + else: + self.patience_counter += 1 + + # Early stopping + if self.patience_counter >= self.config.patience: + self.logger.info( + f"Early stopping triggered at epoch {epoch + 1} " + f"(no improvement for {self.config.patience} epochs)" + ) + break + + # Periodic checkpoint + if (epoch + 1) % self.config.save_every_n_epochs == 0: + self._save_checkpoint(f"epoch_{epoch + 1}.pt") + + # Training complete + total_time = time.time() - start_time + self.logger.info(f"\nTraining completed in {total_time / 60:.1f} minutes") + + # Load best model + best_path = Path(self.config.checkpoint_dir) / "best.pt" + if best_path.exists(): + self.load_checkpoint("best.pt") + self.logger.info("Loaded best model checkpoint") + + return { + "best_val_loss": self.best_val_loss, + "best_val_separation": self.best_val_separation, + "total_epochs": self.epoch + 1, + "total_time_minutes": total_time / 60, + } + + def _train_epoch(self) -> Dict[str, float]: + """Run a single training epoch.""" + self.model.train() + total_loss = 0.0 + num_batches = 0 + accumulation_steps = 0 + + progress_bar = tqdm( + self.train_loader, + desc=f"Epoch {self.epoch + 1}", + leave=False, + ) + + self.optimizer.zero_grad() + + for batch_idx, (images, labels) in enumerate(progress_bar): + images = images.to(self.device) + labels = labels.to(self.device) + + # Forward pass with mixed precision + if self.use_amp: + with autocast(): + embeddings = self.model(images) + loss = self.criterion(embeddings, labels) + loss = loss / self.config.gradient_accumulation_steps + + self.scaler.scale(loss).backward() + else: + embeddings = self.model(images) + loss = 
self.criterion(embeddings, labels) + loss = loss / self.config.gradient_accumulation_steps + loss.backward() + + accumulation_steps += 1 + + # Optimizer step after accumulation + if accumulation_steps >= self.config.gradient_accumulation_steps: + if self.use_amp: + self.scaler.step(self.optimizer) + self.scaler.update() + else: + self.optimizer.step() + + self.optimizer.zero_grad() + self.scheduler.step() + self.global_step += 1 + accumulation_steps = 0 + + total_loss += loss.item() * self.config.gradient_accumulation_steps + num_batches += 1 + + # Update progress bar + progress_bar.set_postfix({ + "loss": total_loss / num_batches, + "lr": self.scheduler.get_last_lr()[0], + }) + + # Logging + if (batch_idx + 1) % self.config.log_every_n_steps == 0: + self.logger.debug( + f"Step {self.global_step}: loss={total_loss / num_batches:.4f}" + ) + + return { + "loss": total_loss / max(num_batches, 1), + "lr": self.scheduler.get_last_lr()[0], + } + + def _validate(self) -> Dict[str, float]: + """Run validation and compute metrics.""" + self.model.eval() + total_loss = 0.0 + all_embeddings = [] + all_labels = [] + + with torch.no_grad(): + for images, labels in tqdm(self.val_loader, desc="Validating", leave=False): + images = images.to(self.device) + labels = labels.to(self.device) + + if self.use_amp: + with autocast(): + embeddings = self.model(images) + loss = self.criterion(embeddings, labels) + else: + embeddings = self.model(images) + loss = self.criterion(embeddings, labels) + + total_loss += loss.item() + all_embeddings.append(embeddings.cpu()) + all_labels.append(labels.cpu()) + + # Combine batches + all_embeddings = torch.cat(all_embeddings, dim=0) + all_labels = torch.cat(all_labels, dim=0) + + # Compute embedding quality metrics + metrics = self.evaluator.compute_metrics(all_embeddings, all_labels) + metrics["loss"] = total_loss / max(len(self.val_loader), 1) + + return metrics + + def _save_checkpoint(self, filename: str) -> None: + """Save training checkpoint.""" + checkpoint_dir = Path(self.config.checkpoint_dir) + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + checkpoint = { + "epoch": self.epoch, + "global_step": self.global_step, + "model_state_dict": self.model.state_dict(), + "optimizer_state_dict": self.optimizer.state_dict(), + "scheduler_state_dict": self.scheduler.state_dict(), + "best_val_loss": self.best_val_loss, + "best_val_separation": self.best_val_separation, + "patience_counter": self.patience_counter, + "training_history": self.training_history, + "config": self.config.__dict__, + } + + if self.scaler is not None: + checkpoint["scaler_state_dict"] = self.scaler.state_dict() + + torch.save(checkpoint, checkpoint_dir / filename) + self.logger.debug(f"Saved checkpoint: {filename}") + + def load_checkpoint(self, filename: str) -> None: + """Load training checkpoint.""" + checkpoint_path = Path(self.config.checkpoint_dir) / filename + if not checkpoint_path.exists(): + self.logger.warning(f"Checkpoint not found: {checkpoint_path}") + return + + checkpoint = torch.load(checkpoint_path, map_location=self.device) + + self.model.load_state_dict(checkpoint["model_state_dict"]) + self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) + self.epoch = checkpoint["epoch"] + self.global_step = checkpoint["global_step"] + self.best_val_loss = checkpoint["best_val_loss"] + self.best_val_separation = checkpoint.get("best_val_separation", float("-inf")) + self.patience_counter = 
checkpoint.get("patience_counter", 0) + self.training_history = checkpoint.get("training_history", []) + + if self.scaler is not None and "scaler_state_dict" in checkpoint: + self.scaler.load_state_dict(checkpoint["scaler_state_dict"]) + + self.logger.info(f"Resumed from epoch {self.epoch + 1}") + + def export_model(self, output_dir: Optional[str] = None) -> str: + """ + Export the trained model for inference. + + Args: + output_dir: Output directory (uses config.output_dir if not specified) + + Returns: + Path to exported model directory + """ + output_dir = output_dir or self.config.output_dir + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + # Save model + self.model.save_pretrained(output_dir) + + # Save training config + config_path = output_path / "training_config.json" + with open(config_path, "w") as f: + json.dump(self.config.__dict__, f, indent=2) + + # Save training history + history_path = output_path / "training_history.json" + with open(history_path, "w") as f: + json.dump(self.training_history, f, indent=2) + + self.logger.info(f"Model exported to: {output_path}") + return str(output_path) + + def get_training_summary(self) -> Dict: + """Get summary of training.""" + return { + "epochs_completed": self.epoch + 1, + "global_steps": self.global_step, + "best_val_loss": self.best_val_loss, + "best_val_separation": self.best_val_separation, + "history": self.training_history, + }
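+
+
+# Minimal wiring sketch (illustrative only; paths are placeholders and
+# TrainingConfig defaults are assumed -- see training/config.py):
+#
+#   from training.config import TrainingConfig
+#   from training.dataset import create_dataloaders
+#   from training.model import create_model
+#   from training.trainer import Trainer
+#
+#   config = TrainingConfig()
+#   train_loader, val_loader, _ = create_dataloaders(
+#       db_path="path/to/logo_mapping.db",
+#       reference_dir="path/to/reference_logos",
+#   )
+#   model, _ = create_model()
+#   trainer = Trainer(model, train_loader, val_loader, config)
+#   trainer.train()
+#   trainer.export_model()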