Add CLIP fine-tuning pipeline for logo recognition

Implement contrastive learning with LoRA to fine-tune CLIP's vision
encoder on LogoDet-3K dataset for improved logo embedding similarity.

New training module (training/):
- config.py: TrainingConfig dataclass with all hyperparameters
- dataset.py: LogoContrastiveDataset with logo-level splits
- model.py: LogoFineTunedCLIP wrapper with LoRA support
- losses.py: InfoNCE, TripletLoss, SupConLoss implementations
- trainer.py: Training loop with mixed precision and checkpointing
- evaluation.py: EmbeddingEvaluator for validation metrics

New scripts:
- train_clip_logo.py: Main training entry point
- export_model.py: Export to HuggingFace-compatible format

Configurations:
- configs/jetson_orin.yaml: Optimized for Jetson Orin AGX
- configs/cloud_rtx4090.yaml: Optimized for 24GB cloud GPUs
- configs/cloud_a100.yaml: Optimized for 80GB cloud GPUs

Documentation:
- CLIP_FINETUNING.md: Training guide and usage instructions
- CLOUD_TRAINING.md: Cloud GPU recommendations and cost estimates

Modified:
- logo_detection_detr.py: Add fine-tuned model loading support
- pyproject.toml: Add peft, pyyaml, torchvision dependencies

Author: Rick McEwen
Date:   2026-01-04 13:45:25 -05:00
Parent: 1551360028
Commit: 44e8b6ae7d
16 changed files with 3334 additions and 12 deletions

CLIP_FINETUNING.md (new file)
@@ -0,0 +1,266 @@
# CLIP Fine-Tuning for Logo Recognition
This document describes the CLIP fine-tuning pipeline for improving logo embedding similarity using the LogoDet-3K dataset.
## Overview
The fine-tuning approach uses **contrastive learning** with **LoRA** (Low-Rank Adaptation) to train CLIP's vision encoder for better logo similarity matching while maintaining compatibility with the existing `DetectLogosDETR` class.
**Goal**: Improve F1 from ~60% to >72% on logo matching tasks.
## Files Created
### Training Module (`training/`)
| File | Description |
|------|-------------|
| `__init__.py` | Module exports |
| `config.py` | `TrainingConfig` dataclass with all hyperparameters |
| `dataset.py` | `LogoContrastiveDataset` with logo-level splits and augmentations |
| `model.py` | `LogoFineTunedCLIP` wrapper with LoRA support |
| `losses.py` | `InfoNCELoss`, `TripletLoss`, `SupConLoss`, `CombinedLoss` |
| `trainer.py` | Training loop with mixed precision, checkpointing, early stopping |
| `evaluation.py` | `EmbeddingEvaluator` for validation metrics |
### Scripts
| File | Description |
|------|-------------|
| `train_clip_logo.py` | Main training entry point |
| `export_model.py` | Export trained models to HuggingFace-compatible format |
### Configuration
| File | Description |
|------|-------------|
| `configs/jetson_orin.yaml` | Training config optimized for Jetson Orin AGX |
## Prerequisites
1. **Install dependencies**:
```bash
uv sync
```
2. **Prepare test data** (if not already done):
```bash
uv run python prepare_test_data.py
```
This creates:
- `reference_logos/` - Cropped logo images organized by category/brand
- `test_images/` - Full images for testing
- `test_data_mapping.db` - SQLite database with mappings
## Training
### Basic Training
```bash
uv run python train_clip_logo.py --config configs/jetson_orin.yaml
```
### Training with Overrides
```bash
uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
--learning-rate 5e-6 \
--max-epochs 30 \
--batch-size 8
```
### Resume from Checkpoint
```bash
uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
--resume checkpoints/epoch_10.pt
```
### Training Output
- Checkpoints saved to `checkpoints/`
- Best model saved as `checkpoints/best.pt`
- Final model exported to `models/logo_detection/clip_finetuned/`
## Configuration Options
Key parameters in `configs/jetson_orin.yaml`:
```yaml
# Model
base_model: "openai/clip-vit-large-patch14"
lora_r: 16 # LoRA rank (0 to disable)
lora_alpha: 32 # LoRA scaling factor
freeze_layers: 12 # Freeze first N transformer layers
# Batch construction
batch_size: 16
logos_per_batch: 32 # Different logos per batch
samples_per_logo: 4 # Samples per logo (creates positive pairs)
gradient_accumulation_steps: 8 # Effective batch = 128
# Training
learning_rate: 1.0e-5
max_epochs: 20
mixed_precision: true
temperature: 0.07 # InfoNCE temperature
# Early stopping
patience: 5
min_delta: 0.001
```
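The YAML maps directly onto the `TrainingConfig` dataclass in `training/config.py`; a minimal sketch of loading it programmatically (the path is illustrative):
```python
# Minimal sketch: load a YAML config and inspect the derived batch sizes.
from training.config import TrainingConfig

config = TrainingConfig.from_yaml("configs/jetson_orin.yaml")
print(config.effective_batch_size)  # batch_size * gradient_accumulation_steps = 128
print(config.samples_per_batch)     # logos_per_batch * samples_per_logo = 128
for warning in config.validate():   # e.g. split ratios not summing to 1.0
    print(f"Config warning: {warning}")
```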
## Evaluation
### Test Fine-Tuned Model
```bash
uv run python test_logo_detection.py -n 50 \
-e models/logo_detection/clip_finetuned \
--matching-method multi-ref \
--seed 42
```
### Compare with Baseline
```bash
# Baseline CLIP
uv run python test_logo_detection.py -n 50 \
-e openai/clip-vit-large-patch14 \
--matching-method multi-ref \
--seed 42
# Fine-tuned model
uv run python test_logo_detection.py -n 50 \
-e models/logo_detection/clip_finetuned \
--matching-method multi-ref \
--seed 42
```
### Expected Metrics
| Metric | Baseline CLIP | Target (Fine-tuned) |
|--------|---------------|---------------------|
| Precision | ~49% | >70% |
| Recall | ~77% | >75% |
| F1 Score | ~60% | >72% |
Training metrics to monitor:
- Mean positive similarity: target > 0.85
- Mean negative similarity: target < 0.50
- Embedding separation: target > 0.35
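These values are computed by `EmbeddingEvaluator` in `training/evaluation.py`; a minimal sketch of checking them on a batch of L2-normalized embeddings (the random tensors stand in for real model output):
```python
import torch
import torch.nn.functional as F
from training.evaluation import EmbeddingEvaluator

# Stand-in embeddings: 32 logos x 4 samples each, 768-dim, L2-normalized
embeddings = F.normalize(torch.randn(128, 768), dim=-1)
labels = torch.arange(32).repeat_interleave(4)

metrics = EmbeddingEvaluator().compute_metrics(embeddings, labels)
print(metrics["mean_pos_sim"], metrics["mean_neg_sim"], metrics["separation"])
```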
## Export Model
To export a checkpoint to HuggingFace format:
```bash
uv run python export_model.py \
--checkpoint checkpoints/best.pt \
--output models/logo_detection/clip_finetuned
```
With LoRA weight merging (reduces inference overhead):
```bash
uv run python export_model.py \
--checkpoint checkpoints/best.pt \
--output models/logo_detection/clip_finetuned \
--merge-lora
```
## Using Fine-Tuned Model with DetectLogosDETR
The fine-tuned model works as a drop-in replacement:
```python
from logo_detection_detr import DetectLogosDETR
# Use fine-tuned model
detector = DetectLogosDETR(
logger=logger,
embedding_model="models/logo_detection/clip_finetuned",
)
# Or use baseline for comparison
detector_baseline = DetectLogosDETR(
logger=logger,
embedding_model="openai/clip-vit-large-patch14",
)
```
## Architecture Details
### Training Approach
1. **Contrastive Learning**: Uses InfoNCE loss to maximize similarity between embeddings of the same logo while minimizing similarity to different logos (a minimal sketch follows this list).
2. **LoRA (Low-Rank Adaptation)**: Adds small trainable matrices to attention layers instead of fine-tuning all weights. This is memory-efficient and prevents catastrophic forgetting.
3. **Layer Freezing**: Freezes the first 12 of 24 transformer layers to preserve CLIP's low-level visual features while adapting high-level semantics.
4. **Logo-Level Splits**: Splits data by logo brand (not by image) to test generalization to unseen logos.
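The implementation lives in `training/losses.py` and is not reproduced in this document; the following is an illustrative sketch of the multi-positive InfoNCE objective from step 1, assuming L2-normalizable embeddings and integer logo labels:
```python
# Illustrative multi-positive InfoNCE (not the exact code in training/losses.py).
# Positives are the other in-batch samples that share the same logo label.
import torch
import torch.nn.functional as F

def info_nce(embeddings: torch.Tensor, labels: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    z = F.normalize(embeddings, dim=-1)                   # [N, D] unit vectors
    sim = z @ z.T / temperature                           # [N, N] scaled cosine similarities
    self_mask = torch.eye(len(labels), dtype=torch.bool, device=z.device)
    pos = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~self_mask
    sim = sim.masked_fill(self_mask, float("-inf"))       # never contrast a sample with itself
    log_prob = sim - torch.logsumexp(sim, dim=1, keepdim=True)
    has_pos = pos.any(dim=1)                              # anchors with at least one positive
    loss = -(log_prob * pos).sum(dim=1)[has_pos] / pos.sum(dim=1)[has_pos]
    return loss.mean()
```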
### Batch Construction
Each batch contains:
- K different logo brands (default: 32)
- M samples per brand (default: 4)
- Total samples: K × M = 128
This ensures positive pairs (same logo) exist within each batch for contrastive learning.
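A quick illustration of the resulting label layout (numbers only; the actual batches are built by `LogoContrastiveDataset` in `training/dataset.py`):
```python
import torch

K, M = 32, 4                                   # logos per batch, samples per logo
labels = torch.arange(K).repeat_interleave(M)  # [0,0,0,0, 1,1,1,1, ...], length K*M = 128
# Every sample has exactly M-1 in-batch positives (same label, excluding itself)
positives = (labels.unsqueeze(0) == labels.unsqueeze(1)).sum(dim=1) - 1
assert bool((positives == M - 1).all())
```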
### Data Augmentation
Medium strength augmentations:
- Random horizontal flip
- Random rotation (±15°)
- Color jitter (brightness, contrast, saturation)
- Random affine transforms
- Random grayscale (10% of images)
## Troubleshooting
### Out of Memory
Reduce batch size and increase gradient accumulation:
```bash
uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
--batch-size 8 \
--gradient-accumulation-steps 16
```
### Slow Training
Ensure mixed precision is enabled:
```bash
uv run python train_clip_logo.py --config configs/jetson_orin.yaml
# mixed_precision: true is default in jetson_orin.yaml
```
### No Improvement
Try adjusting:
- Lower learning rate: `--learning-rate 5e-6`
- Higher temperature: `--temperature 0.1`
- Different loss: edit config to use `loss_type: "combined"`
### Import Error for Fine-Tuned Model
Ensure the `training/` module is in your Python path:
```bash
export PYTHONPATH="${PYTHONPATH}:/data/dev.python/logo_test"
```
## Dependencies Added
The following were added to `pyproject.toml`:
```toml
peft>=0.7.0 # LoRA support
pyyaml>=6.0 # Config file parsing
torchvision>=0.20.0 # Image transforms
```

CLOUD_TRAINING.md (new file)
@@ -0,0 +1,269 @@
# Cloud GPU Training for CLIP Fine-Tuning
This document provides guidance on using cloud GPU instances (e.g., RunPod) for faster CLIP fine-tuning compared to local training on Jetson Orin AGX.
## Training Time Comparison
Local training on Jetson Orin AGX takes approximately 24 hours. Cloud GPUs offer significantly faster training:
| GPU | VRAM | Est. Training Time | Hourly Rate | Est. Total Cost |
|-----|------|-------------------|-------------|-----------------|
| **RTX 4090** | 24GB | 4-6 hours | $0.59/hr | **$2.40-$3.50** |
| **RTX 3090** | 24GB | 5-7 hours | $0.39/hr | **$2.00-$2.75** |
| **A100 80GB** | 80GB | 2-3 hours | $1.99/hr | **$4.00-$6.00** |
| **L40S** | 48GB | 3-4 hours | $0.89/hr | **$2.70-$3.60** |
| **H100 80GB** | 80GB | 1.5-2 hours | $1.99/hr | **$3.00-$4.00** |
*Prices from RunPod Community Cloud as of January 2025. Rates may vary.*
## Recommendations
### Best Value: RTX 4090 ($0.59/hr)
- 24GB VRAM is sufficient for ViT-L/14 with LoRA
- Good balance of speed and cost
- Widely available on Community Cloud
- **Total cost: ~$3 for complete training**
### Best Speed: H100 80GB ($1.99/hr)
- Fastest training (1.5-2 hours)
- 80GB VRAM allows larger batch sizes
- Can increase `batch_size` to 32+ and reduce `gradient_accumulation_steps`
- **Total cost: ~$3-4**
### Budget Option: RTX 3090 ($0.39/hr)
- Cheapest hourly rate
- 24GB VRAM works fine
- Slightly slower than 4090
- **Total cost: ~$2-3**
## Cloud-Optimized Configurations
### RTX 4090 / RTX 3090 (24GB VRAM)
Create `configs/cloud_rtx4090.yaml`:
```yaml
# Optimized for 24GB VRAM cloud GPUs
base_model: "openai/clip-vit-large-patch14"
# Dataset paths
dataset_dir: "LogoDet-3K"
reference_dir: "reference_logos"
db_path: "test_data_mapping.db"
# Data splits
train_split: 0.7
val_split: 0.15
test_split: 0.15
# Larger batches for faster training
batch_size: 32
logos_per_batch: 32
samples_per_logo: 4
gradient_accumulation_steps: 4 # Effective batch = 128
num_workers: 8
# Model architecture
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
freeze_layers: 12
use_gradient_checkpointing: true
# Training
learning_rate: 1.0e-5
weight_decay: 0.01
warmup_steps: 500
max_epochs: 20
mixed_precision: true
# Loss
temperature: 0.07
loss_type: "infonce"
# Early stopping
patience: 5
min_delta: 0.001
# Output
checkpoint_dir: "checkpoints"
output_dir: "models/logo_detection/clip_finetuned"
save_every_n_epochs: 5
# Logging
log_every_n_steps: 10
eval_every_n_epochs: 1
seed: 42
use_augmentation: true
augmentation_strength: "medium"
```
### A100 / H100 (80GB VRAM)
Create `configs/cloud_a100.yaml`:
```yaml
# Optimized for 80GB VRAM cloud GPUs (A100, H100)
base_model: "openai/clip-vit-large-patch14"
# Dataset paths
dataset_dir: "LogoDet-3K"
reference_dir: "reference_logos"
db_path: "test_data_mapping.db"
# Data splits
train_split: 0.7
val_split: 0.15
test_split: 0.15
# Maximum batch sizes for 80GB VRAM
batch_size: 64
logos_per_batch: 32
samples_per_logo: 4
gradient_accumulation_steps: 2 # Effective batch = 128
num_workers: 8
# Model architecture (can disable gradient checkpointing with 80GB)
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
freeze_layers: 12
use_gradient_checkpointing: false # Not needed with 80GB
# Training
learning_rate: 1.0e-5
weight_decay: 0.01
warmup_steps: 500
max_epochs: 20
mixed_precision: true
# Loss
temperature: 0.07
loss_type: "infonce"
# Early stopping
patience: 5
min_delta: 0.001
# Output
checkpoint_dir: "checkpoints"
output_dir: "models/logo_detection/clip_finetuned"
save_every_n_epochs: 5
# Logging
log_every_n_steps: 10
eval_every_n_epochs: 1
seed: 42
use_augmentation: true
augmentation_strength: "medium"
```
## RunPod Quick Start
### 1. Create a Pod
1. Go to [RunPod](https://www.runpod.io/)
2. Select GPU (RTX 4090 recommended)
3. Choose PyTorch template (CUDA 12.x)
4. Set volume size: 50GB (for dataset + models)
### 2. Setup Environment
```bash
# Connect via SSH or web terminal
# Install dependencies
pip install peft pyyaml torchvision transformers tqdm pillow
# Clone your repository (or upload files)
git clone <your-repo-url>
cd logo_test
# Or use runpodctl to sync files
# runpodctl send logo_test/
```
### 3. Prepare Data
If data isn't already prepared:
```bash
# This creates reference_logos/ and test_data_mapping.db
python prepare_test_data.py
```
### 4. Run Training
```bash
# For RTX 4090
python train_clip_logo.py --config configs/cloud_rtx4090.yaml
# For A100/H100
python train_clip_logo.py --config configs/cloud_a100.yaml
# Or with command-line overrides
python train_clip_logo.py --config configs/jetson_orin.yaml \
--batch-size 32 \
--gradient-accumulation-steps 4 \
--num-workers 8
```
### 5. Download Results
```bash
# Export the trained model
python export_model.py \
--checkpoint checkpoints/best.pt \
--output models/logo_detection/clip_finetuned
# Download to local machine
# Option 1: Use runpodctl
runpodctl receive models/logo_detection/clip_finetuned
# Option 2: SCP
scp -r root@<pod-ip>:/workspace/logo_test/models/logo_detection/clip_finetuned ./
# Option 3: Compress and download via web
tar -czvf clip_finetuned.tar.gz models/logo_detection/clip_finetuned
```
## Cost Optimization Tips
### Use Spot/Interruptible Instances
- Community Cloud GPUs are already cheaper
- Some providers offer spot pricing for additional savings
- Save checkpoints frequently (`save_every_n_epochs: 2`)
### Minimize Storage Costs
- RunPod charges $0.10/GB/month for container disk
- Use network volumes only if needed
- Delete pods when training completes
### Monitor Training
- Watch for early convergence (may finish before 20 epochs)
- Early stopping will save time/cost if no improvement
### Batch Training Runs
- Test configuration locally first (1-2 epochs)
- Run full training on cloud only when config is validated
## Cost Comparison Summary
| Option | Time | Cost | Best For |
|--------|------|------|----------|
| Jetson Orin (local) | ~24 hrs | Free* | No cloud dependency |
| RTX 3090 (RunPod) | ~6 hrs | ~$2.50 | Lowest cost |
| RTX 4090 (RunPod) | ~5 hrs | ~$3.00 | Best value |
| L40S (RunPod) | ~3.5 hrs | ~$3.00 | Good balance |
| A100 80GB (RunPod) | ~2.5 hrs | ~$5.00 | Large batches |
| H100 80GB (RunPod) | ~1.5 hrs | ~$3.50 | Fastest |
*Local training has electricity cost but no cloud fees.
## References
- [RunPod Pricing](https://www.runpod.io/pricing)
- [RunPod RTX 4090](https://www.runpod.io/gpu-models/rtx-4090)
- [RunPod Documentation](https://docs.runpod.io/)

configs/cloud_a100.yaml (new file)
@@ -0,0 +1,64 @@
# Training configuration optimized for cloud A100 / H100 (80GB VRAM)
#
# Usage:
# python train_clip_logo.py --config configs/cloud_a100.yaml
#
# Estimated training time: 1.5-3 hours
# Estimated cost on RunPod: ~$3-6
# Base model
base_model: "openai/clip-vit-large-patch14"
# Dataset paths
dataset_dir: "LogoDet-3K"
reference_dir: "reference_logos"
db_path: "test_data_mapping.db"
# Data splits
train_split: 0.7
val_split: 0.15
test_split: 0.15
# Maximum batch sizes for 80GB VRAM
batch_size: 64
logos_per_batch: 32
samples_per_logo: 4
gradient_accumulation_steps: 2 # Effective batch = 128
num_workers: 8
# Model architecture (no gradient checkpointing needed with 80GB)
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
freeze_layers: 12
use_gradient_checkpointing: false
# Training
learning_rate: 1.0e-5
weight_decay: 0.01
warmup_steps: 500
max_epochs: 20
mixed_precision: true
# Loss
temperature: 0.07
loss_type: "infonce"
triplet_margin: 0.3
# Early stopping
patience: 5
min_delta: 0.001
# Output
checkpoint_dir: "checkpoints"
output_dir: "models/logo_detection/clip_finetuned"
save_every_n_epochs: 2 # Save more frequently for cloud
# Logging
log_every_n_steps: 10
eval_every_n_epochs: 1
seed: 42
use_hard_negatives: false
use_augmentation: true
augmentation_strength: "medium"

configs/cloud_rtx4090.yaml (new file)
@@ -0,0 +1,64 @@
# Training configuration optimized for cloud RTX 4090 / RTX 3090 (24GB VRAM)
#
# Usage:
# python train_clip_logo.py --config configs/cloud_rtx4090.yaml
#
# Estimated training time: 4-6 hours
# Estimated cost on RunPod: ~$3
# Base model
base_model: "openai/clip-vit-large-patch14"
# Dataset paths
dataset_dir: "LogoDet-3K"
reference_dir: "reference_logos"
db_path: "test_data_mapping.db"
# Data splits
train_split: 0.7
val_split: 0.15
test_split: 0.15
# Larger batches for faster training on 24GB VRAM
batch_size: 32
logos_per_batch: 32
samples_per_logo: 4
gradient_accumulation_steps: 4 # Effective batch = 128
num_workers: 8
# Model architecture
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
freeze_layers: 12
use_gradient_checkpointing: true
# Training
learning_rate: 1.0e-5
weight_decay: 0.01
warmup_steps: 500
max_epochs: 20
mixed_precision: true
# Loss
temperature: 0.07
loss_type: "infonce"
triplet_margin: 0.3
# Early stopping
patience: 5
min_delta: 0.001
# Output
checkpoint_dir: "checkpoints"
output_dir: "models/logo_detection/clip_finetuned"
save_every_n_epochs: 2 # Save more frequently for cloud
# Logging
log_every_n_steps: 10
eval_every_n_epochs: 1
seed: 42
use_hard_negatives: false
use_augmentation: true
augmentation_strength: "medium"

configs/jetson_orin.yaml (new file)
@@ -0,0 +1,76 @@
# Training configuration optimized for Jetson Orin AGX (~64GB shared memory)
#
# Usage:
# uv run python train_clip_logo.py --config configs/jetson_orin.yaml
# Base model
base_model: "openai/clip-vit-large-patch14"
# Dataset paths (relative to project root)
dataset_dir: "LogoDet-3K"
reference_dir: "reference_logos"
db_path: "test_data_mapping.db"
# Data split ratios (logo-level split for generalization testing)
train_split: 0.7
val_split: 0.15
test_split: 0.15
# Batch construction
# - batch_size: Number of batches loaded at once (keep low for memory)
# - logos_per_batch: Different logo classes per contrastive batch
# - samples_per_logo: Samples of each logo (creates positive pairs)
# - Effective samples per step = logos_per_batch * samples_per_logo = 128
batch_size: 16
logos_per_batch: 32
samples_per_logo: 4
gradient_accumulation_steps: 8 # Effective batch = 128
num_workers: 4
# Model architecture
# LoRA enables memory-efficient fine-tuning by training low-rank adapters
# instead of full model weights
lora_r: 16 # LoRA rank (0 to disable)
lora_alpha: 32 # LoRA scaling factor
lora_dropout: 0.1 # Dropout in LoRA layers
freeze_layers: 12 # Freeze first 12 of 24 transformer layers
use_gradient_checkpointing: true # Trade compute for memory
# Training hyperparameters
learning_rate: 1.0e-5 # Conservative LR for fine-tuning
weight_decay: 0.01 # L2 regularization
warmup_steps: 500 # LR warmup steps
max_epochs: 20 # Maximum training epochs
mixed_precision: true # FP16 training for memory efficiency
# Loss function
# InfoNCE is the contrastive loss used in CLIP training
temperature: 0.07 # Similarity scaling (0.05-0.1 typical)
loss_type: "infonce" # Options: infonce, supcon, triplet, combined
triplet_margin: 0.3 # Only used if loss_type is triplet
# Early stopping
patience: 5 # Stop if no improvement for N epochs
min_delta: 0.001 # Minimum improvement threshold
# Checkpoints and output
checkpoint_dir: "checkpoints"
output_dir: "models/logo_detection/clip_finetuned"
save_every_n_epochs: 5
# Logging
log_every_n_steps: 10
eval_every_n_epochs: 1
# Reproducibility
seed: 42
# Hard negative mining (advanced)
# Enable after initial training epochs for harder examples
use_hard_negatives: false
hard_negative_start_epoch: 5
hard_negatives_per_logo: 10
# Data augmentation
use_augmentation: true
augmentation_strength: "medium" # light, medium, or strong

export_model.py (new file)
@@ -0,0 +1,169 @@
#!/usr/bin/env python3
"""
Export a trained CLIP model to HuggingFace-compatible format.
This script converts a training checkpoint to a format that can be
loaded by DetectLogosDETR for inference.
Usage:
uv run python export_model.py \
--checkpoint checkpoints/best.pt \
--output models/logo_detection/clip_finetuned
# With custom base model
uv run python export_model.py \
--checkpoint checkpoints/best.pt \
--output models/logo_detection/clip_finetuned \
--base-model openai/clip-vit-large-patch14
"""
import argparse
import json
import logging
import sys
from pathlib import Path
import torch
from training.config import TrainingConfig
from training.model import create_model, LogoFineTunedCLIP
def setup_logging() -> logging.Logger:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
return logging.getLogger(__name__)
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Export trained CLIP model for inference",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to training checkpoint (.pt file)",
)
parser.add_argument(
"--output",
type=str,
required=True,
help="Output directory for exported model",
)
parser.add_argument(
"--base-model",
type=str,
default=None,
help="Base CLIP model (reads from checkpoint config if not specified)",
)
parser.add_argument(
"--merge-lora",
action="store_true",
help="Merge LoRA weights into base model (reduces inference overhead)",
)
return parser.parse_args()
def main():
args = parse_args()
logger = setup_logging()
logger.info("CLIP Model Export")
logger.info("=" * 60)
# Check checkpoint exists
checkpoint_path = Path(args.checkpoint)
if not checkpoint_path.exists():
logger.error(f"Checkpoint not found: {checkpoint_path}")
sys.exit(1)
# Load checkpoint
logger.info(f"Loading checkpoint: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, map_location="cpu")
# Get config from checkpoint
if "config" in checkpoint:
config_dict = checkpoint["config"]
base_model = args.base_model or config_dict.get(
"base_model", "openai/clip-vit-large-patch14"
)
lora_r = config_dict.get("lora_r", 16)
lora_alpha = config_dict.get("lora_alpha", 32)
freeze_layers = config_dict.get("freeze_layers", 12)
else:
base_model = args.base_model or "openai/clip-vit-large-patch14"
lora_r = 16
lora_alpha = 32
freeze_layers = 12
logger.info(f"Base model: {base_model}")
logger.info(f"LoRA rank: {lora_r}")
logger.info(f"Freeze layers: {freeze_layers}")
# Create model with same architecture
logger.info("Creating model architecture...")
model, processor = create_model(
base_model=base_model,
lora_r=lora_r,
lora_alpha=lora_alpha,
freeze_layers=freeze_layers,
use_gradient_checkpointing=False, # Not needed for export
)
# Load weights
logger.info("Loading trained weights...")
model.load_state_dict(checkpoint["model_state_dict"])
# Merge LoRA if requested
if args.merge_lora and model.peft_applied:
try:
logger.info("Merging LoRA weights into base model...")
model.vision_model = model.vision_model.merge_and_unload()
model.peft_applied = False
model.lora_r = 0
logger.info("LoRA weights merged successfully")
except Exception as e:
logger.warning(f"Could not merge LoRA weights: {e}")
logger.warning("Exporting with separate LoRA weights")
# Create output directory
output_path = Path(args.output)
output_path.mkdir(parents=True, exist_ok=True)
# Save model
logger.info(f"Exporting to: {output_path}")
model.save_pretrained(str(output_path))
# Save processor config for reference
processor.save_pretrained(str(output_path / "processor"))
# Save additional metadata
metadata = {
"base_model": base_model,
"source_checkpoint": str(checkpoint_path),
"training_epochs": checkpoint.get("epoch", -1) + 1,
"best_val_loss": checkpoint.get("best_val_loss", None),
"best_val_separation": checkpoint.get("best_val_separation", None),
"lora_merged": args.merge_lora and not model.peft_applied,
}
with open(output_path / "export_metadata.json", "w") as f:
json.dump(metadata, f, indent=2)
logger.info("\nExport complete!")
logger.info(f"Model saved to: {output_path}")
logger.info("\nTo use with DetectLogosDETR:")
logger.info(f" detector = DetectLogosDETR(embedding_model='{output_path}')")
logger.info("\nOr with test_logo_detection.py:")
logger.info(f" uv run python test_logo_detection.py -e {output_path}")
if __name__ == "__main__":
main()

logo_detection_detr.py (modified)
@@ -13,6 +13,7 @@ Supported embedding models:
- DINOv2 models (facebook/dinov2-*): Self-supervised, excellent for visual similarity
"""
import json
import os
import torch
import torch.nn.functional as F
@@ -100,16 +101,20 @@ class DetectLogosDETR:
embedding_model, default_embedding_dir, "Embedding"
)
# Detect model type and initialize accordingly
self.model_type = self._detect_model_type(embedding_model)
self.logger.info(f"Loading {self.model_type} embedding model: {embedding_model_path}")
# Check if this is a fine-tuned model
if self._is_finetuned_model(embedding_model_path):
self._load_finetuned_embedding_model(embedding_model_path)
else:
# Detect model type and initialize accordingly
self.model_type = self._detect_model_type(embedding_model)
self.logger.info(f"Loading {self.model_type} embedding model: {embedding_model_path}")
if self.model_type == "clip":
self.embedding_model = CLIPModel.from_pretrained(embedding_model_path).to(self.device)
self.embedding_processor = CLIPProcessor.from_pretrained(embedding_model_path)
else: # dinov2 or other transformer models
self.embedding_model = AutoModel.from_pretrained(embedding_model_path).to(self.device)
self.embedding_processor = AutoImageProcessor.from_pretrained(embedding_model_path)
if self.model_type == "clip":
self.embedding_model = CLIPModel.from_pretrained(embedding_model_path).to(self.device)
self.embedding_processor = CLIPProcessor.from_pretrained(embedding_model_path)
else: # dinov2 or other transformer models
self.embedding_model = AutoModel.from_pretrained(embedding_model_path).to(self.device)
self.embedding_processor = AutoImageProcessor.from_pretrained(embedding_model_path)
self.logger.info("DetectLogosDETR initialization complete")
@@ -124,6 +129,62 @@ class DetectLogosDETR:
# Default to generic transformer for unknown models
return "transformer"
def _is_finetuned_model(self, model_path: str) -> bool:
"""Check if a model path points to a fine-tuned CLIP model."""
config_path = Path(model_path) / "config.json"
if config_path.exists():
try:
with open(config_path, "r") as f:
config = json.load(f)
return config.get("model_type") == "clip_logo_finetuned"
except (json.JSONDecodeError, IOError):
pass
return False
def _load_finetuned_embedding_model(self, model_path: str) -> None:
"""
Load a fine-tuned CLIP model from the training module.
Args:
model_path: Path to the fine-tuned model directory
"""
# Import the fine-tuned model class
try:
from training.model import LogoFineTunedCLIP
except ImportError as e:
self.logger.error(
f"Cannot import training.model for fine-tuned model: {e}"
)
raise ImportError(
"Fine-tuned model requires the training module. "
"Ensure the training/ directory is in your Python path."
) from e
# Load config
config_path = Path(model_path) / "config.json"
with open(config_path, "r") as f:
config = json.load(f)
base_model = config.get("base_model", "openai/clip-vit-large-patch14")
self.logger.info(f"Loading fine-tuned CLIP model from: {model_path}")
self.logger.info(f" Base model: {base_model}")
# Load model using the from_pretrained method
self.embedding_model = LogoFineTunedCLIP.from_pretrained(
model_path,
base_model=base_model,
device=self.device,
)
self.embedding_model.eval()
# Load processor from base model
self.embedding_processor = CLIPProcessor.from_pretrained(base_model)
# Set model type for embedding extraction
self.model_type = "clip_finetuned"
self.logger.info("Fine-tuned CLIP model loaded successfully")
def _resolve_model_path(
self, model_name_or_path: str, default_local_dir: str, model_type: str
) -> str:
@@ -345,7 +406,7 @@ class DetectLogosDETR:
"""
Internal method to get embedding from PIL image.
Handles both CLIP and DINOv2 model types.
Handles CLIP, fine-tuned CLIP, and DINOv2 model types.
Args:
pil_image: PIL Image (RGB format)
@@ -360,6 +421,9 @@ class DetectLogosDETR:
if self.model_type == "clip":
# CLIP has a dedicated method for image features
features = self.embedding_model.get_image_features(**inputs)
elif self.model_type == "clip_finetuned":
# Fine-tuned CLIP uses get_image_features or forward with pixel_values
features = self.embedding_model.get_image_features(**inputs)
else:
# DINOv2 and other transformers use the CLS token or pooled output
outputs = self.embedding_model(**inputs)
@@ -370,8 +434,9 @@ class DetectLogosDETR:
# Use CLS token from last_hidden_state
features = outputs.last_hidden_state[:, 0, :]
# Normalize for cosine similarity
features = F.normalize(features, dim=-1)
# Normalize for cosine similarity (fine-tuned model already normalizes)
if self.model_type != "clip_finetuned":
features = F.normalize(features, dim=-1)
return features

pyproject.toml (modified)
@@ -12,4 +12,7 @@ dependencies = [
"tqdm>=4.67.1",
"transformers>=4.57.3",
"typing>=3.10.0.0",
"peft>=0.7.0",
"pyyaml>=6.0",
"torchvision>=0.20.0",
]

train_clip_logo.py (new file)
@@ -0,0 +1,309 @@
#!/usr/bin/env python3
"""
Fine-tune CLIP vision encoder for logo recognition.
This script trains a CLIP model using contrastive learning on the LogoDet-3K
dataset to improve logo embedding quality for similarity-based matching.
Usage:
# Train with YAML config
uv run python train_clip_logo.py --config configs/jetson_orin.yaml
# Train with command-line overrides
uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
--learning-rate 5e-6 --max-epochs 30
# Resume from checkpoint
uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
--resume checkpoints/epoch_10.pt
"""
import argparse
import logging
import random
import sys
from pathlib import Path
import numpy as np
import torch
from training.config import TrainingConfig
from training.dataset import create_dataloaders
from training.model import create_model
from training.trainer import Trainer
def setup_logging(log_level: str = "INFO") -> logging.Logger:
"""Configure logging."""
logging.basicConfig(
level=getattr(logging, log_level.upper()),
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
return logging.getLogger(__name__)
def set_seed(seed: int) -> None:
"""Set random seeds for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def parse_args() -> argparse.Namespace:
"""Parse command-line arguments."""
parser = argparse.ArgumentParser(
description="Fine-tune CLIP for logo recognition",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Config file
parser.add_argument(
"--config",
type=str,
help="Path to YAML configuration file",
)
# Dataset paths
parser.add_argument(
"--dataset-dir",
type=str,
help="Path to LogoDet-3K dataset",
)
parser.add_argument(
"--reference-dir",
type=str,
help="Path to reference logos directory",
)
parser.add_argument(
"--db-path",
type=str,
help="Path to SQLite database",
)
# Model
parser.add_argument(
"--base-model",
type=str,
help="Base CLIP model name or path",
)
parser.add_argument(
"--lora-r",
type=int,
help="LoRA rank (0 to disable)",
)
parser.add_argument(
"--freeze-layers",
type=int,
help="Number of transformer layers to freeze",
)
# Training
parser.add_argument(
"--batch-size",
type=int,
help="Batch size",
)
parser.add_argument(
"--learning-rate",
type=float,
help="Learning rate",
)
parser.add_argument(
"--max-epochs",
type=int,
help="Maximum number of epochs",
)
parser.add_argument(
"--gradient-accumulation-steps",
type=int,
help="Gradient accumulation steps",
)
# Loss
parser.add_argument(
"--temperature",
type=float,
help="Temperature for InfoNCE loss",
)
parser.add_argument(
"--loss-type",
choices=["infonce", "supcon", "triplet", "combined"],
help="Loss function type",
)
# Checkpointing
parser.add_argument(
"--checkpoint-dir",
type=str,
help="Directory for checkpoints",
)
parser.add_argument(
"--output-dir",
type=str,
help="Directory for final model output",
)
parser.add_argument(
"--resume",
type=str,
help="Path to checkpoint to resume from",
)
# Other
parser.add_argument(
"--seed",
type=int,
help="Random seed",
)
parser.add_argument(
"--log-level",
type=str,
default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
help="Logging level",
)
parser.add_argument(
"--no-mixed-precision",
action="store_true",
help="Disable mixed precision training",
)
return parser.parse_args()
def main():
"""Main training entry point."""
args = parse_args()
# Setup logging
logger = setup_logging(args.log_level)
logger.info("CLIP Logo Fine-Tuning")
logger.info("=" * 60)
# Load or create configuration
if args.config:
logger.info(f"Loading config from: {args.config}")
config = TrainingConfig.from_yaml(args.config)
else:
logger.info("Using default configuration")
config = TrainingConfig()
# Apply command-line overrides
override_fields = [
"dataset_dir", "reference_dir", "db_path", "base_model",
"lora_r", "freeze_layers", "batch_size", "learning_rate",
"max_epochs", "gradient_accumulation_steps", "temperature",
"loss_type", "checkpoint_dir", "output_dir", "seed",
]
for field in override_fields:
arg_value = getattr(args, field, None)
if arg_value is not None:
setattr(config, field, arg_value)
logger.info(f"Override: {field} = {arg_value}")
if args.no_mixed_precision:
config.mixed_precision = False
logger.info("Override: mixed_precision = False")
# Validate configuration
warnings = config.validate()
for warning in warnings:
logger.warning(f"Config warning: {warning}")
# Set random seed
set_seed(config.seed)
logger.info(f"Random seed: {config.seed}")
# Check paths exist
db_path = Path(config.db_path)
ref_dir = Path(config.reference_dir)
if not db_path.exists():
logger.error(f"Database not found: {db_path}")
logger.error("Run prepare_test_data.py first to create the database.")
sys.exit(1)
if not ref_dir.exists():
logger.error(f"Reference directory not found: {ref_dir}")
logger.error("Run prepare_test_data.py first to extract reference logos.")
sys.exit(1)
# Create model
logger.info(f"Creating model from: {config.base_model}")
model, processor = create_model(
base_model=config.base_model,
lora_r=config.lora_r,
lora_alpha=config.lora_alpha,
lora_dropout=config.lora_dropout,
freeze_layers=config.freeze_layers,
use_gradient_checkpointing=config.use_gradient_checkpointing,
)
# Create dataloaders
logger.info("Creating dataloaders...")
train_loader, val_loader, test_loader = create_dataloaders(
db_path=str(config.db_path),
reference_dir=str(config.reference_dir),
batch_size=config.batch_size,
logos_per_batch=config.logos_per_batch,
samples_per_logo=config.samples_per_logo,
num_workers=config.num_workers,
train_split=config.train_split,
val_split=config.val_split,
test_split=config.test_split,
seed=config.seed,
augmentation_strength=config.augmentation_strength,
)
# Create trainer
trainer = Trainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
config=config,
logger=logger,
)
# Resume from checkpoint if specified
if args.resume:
resume_path = Path(args.resume)
if resume_path.exists():
logger.info(f"Resuming from: {resume_path}")
# Set checkpoint dir to resume path's parent
if resume_path.is_file():
config.checkpoint_dir = str(resume_path.parent)
trainer.load_checkpoint(resume_path.name)
else:
logger.warning(f"Resume checkpoint not found: {resume_path}")
# Train
logger.info("\nStarting training...")
final_metrics = trainer.train()
logger.info("\nTraining complete!")
logger.info(f" Best val loss: {final_metrics['best_val_loss']:.4f}")
logger.info(f" Best separation: {final_metrics['best_val_separation']:.4f}")
logger.info(f" Total epochs: {final_metrics['total_epochs']}")
logger.info(f" Total time: {final_metrics['total_time_minutes']:.1f} minutes")
# Export model
output_path = trainer.export_model()
logger.info(f"\nModel exported to: {output_path}")
# Print next steps
logger.info("\n" + "=" * 60)
logger.info("Next steps:")
logger.info(f"1. Test the fine-tuned model:")
logger.info(f" uv run python test_logo_detection.py -n 50 \\")
logger.info(f" -e {output_path} --matching-method multi-ref")
logger.info(f"")
logger.info(f"2. Compare with baseline:")
logger.info(f" uv run python test_logo_detection.py -n 50 \\")
logger.info(f" -e openai/clip-vit-large-patch14 --matching-method multi-ref")
if __name__ == "__main__":
main()

training/__init__.py (new file)
@@ -0,0 +1,24 @@
"""
CLIP fine-tuning module for logo recognition.
This module provides tools for fine-tuning CLIP's vision encoder using
contrastive learning on the LogoDet-3K dataset.
"""
from .config import TrainingConfig
from .dataset import LogoContrastiveDataset, create_dataloaders
from .model import LogoFineTunedCLIP
from .losses import InfoNCELoss, TripletLoss
from .trainer import Trainer
from .evaluation import EmbeddingEvaluator
__all__ = [
"TrainingConfig",
"LogoContrastiveDataset",
"create_dataloaders",
"LogoFineTunedCLIP",
"InfoNCELoss",
"TripletLoss",
"Trainer",
"EmbeddingEvaluator",
]

training/config.py (new file)
@@ -0,0 +1,141 @@
"""
Training configuration for CLIP fine-tuning.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional
import yaml
@dataclass
class TrainingConfig:
"""Configuration for CLIP logo fine-tuning."""
# Base model
base_model: str = "openai/clip-vit-large-patch14"
# Dataset paths
dataset_dir: str = "LogoDet-3K"
reference_dir: str = "reference_logos"
db_path: str = "test_data_mapping.db"
# Data split ratios
train_split: float = 0.7
val_split: float = 0.15
test_split: float = 0.15
# Batch construction
batch_size: int = 16
logos_per_batch: int = 32
samples_per_logo: int = 4
gradient_accumulation_steps: int = 8
num_workers: int = 4
# Model architecture
lora_r: int = 16
lora_alpha: int = 32
lora_dropout: float = 0.1
freeze_layers: int = 12
use_gradient_checkpointing: bool = True
# Training hyperparameters
learning_rate: float = 1e-5
weight_decay: float = 0.01
warmup_steps: int = 500
max_epochs: int = 20
mixed_precision: bool = True
# Loss function
temperature: float = 0.07
loss_type: str = "infonce" # "infonce", "supcon", "triplet", or "combined"
triplet_margin: float = 0.3
# Early stopping
patience: int = 5
min_delta: float = 0.001
# Checkpoints and output
checkpoint_dir: str = "checkpoints"
output_dir: str = "models/logo_detection/clip_finetuned"
save_every_n_epochs: int = 5
# Logging
log_every_n_steps: int = 10
eval_every_n_epochs: int = 1
# Random seed for reproducibility
seed: int = 42
# Hard negative mining
use_hard_negatives: bool = False
hard_negative_start_epoch: int = 5
hard_negatives_per_logo: int = 10
# Data augmentation
use_augmentation: bool = True
augmentation_strength: str = "medium" # "light", "medium", "strong"
@classmethod
def from_yaml(cls, yaml_path: str) -> "TrainingConfig":
"""Load configuration from YAML file."""
with open(yaml_path, "r") as f:
config_dict = yaml.safe_load(f)
return cls(**config_dict)
def to_yaml(self, yaml_path: str) -> None:
"""Save configuration to YAML file."""
Path(yaml_path).parent.mkdir(parents=True, exist_ok=True)
with open(yaml_path, "w") as f:
yaml.dump(self.__dict__, f, default_flow_style=False, sort_keys=False)
def validate(self) -> List[str]:
"""Validate configuration and return list of warnings."""
warnings = []
# Check split ratios
total_split = self.train_split + self.val_split + self.test_split
if abs(total_split - 1.0) > 0.01:
warnings.append(
f"Split ratios sum to {total_split}, expected 1.0"
)
# Check batch construction
effective_batch = self.batch_size * self.gradient_accumulation_steps
if effective_batch < 64:
warnings.append(
f"Effective batch size ({effective_batch}) is small for contrastive learning. "
"Consider increasing batch_size or gradient_accumulation_steps."
)
# Check LoRA config
if self.lora_r > 0 and self.lora_alpha < self.lora_r:
warnings.append(
f"lora_alpha ({self.lora_alpha}) < lora_r ({self.lora_r}). "
"This may reduce LoRA effectiveness."
)
# Check freeze layers
if self.freeze_layers < 0:
warnings.append("freeze_layers should be >= 0")
# Check temperature
if self.temperature <= 0:
warnings.append("temperature must be positive")
elif self.temperature > 1.0:
warnings.append(
f"temperature ({self.temperature}) is high. "
"Typical values are 0.05-0.1."
)
return warnings
@property
def effective_batch_size(self) -> int:
"""Calculate effective batch size with gradient accumulation."""
return self.batch_size * self.gradient_accumulation_steps
@property
def samples_per_batch(self) -> int:
"""Total samples in one batch (logos_per_batch * samples_per_logo)."""
return self.logos_per_batch * self.samples_per_logo

training/dataset.py (new file)
@@ -0,0 +1,467 @@
"""
Dataset classes for contrastive learning on logo images.
"""
import random
import sqlite3
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
# CLIP normalization values
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
def get_train_transforms(strength: str = "medium") -> transforms.Compose:
"""
Get training data augmentation transforms.
Args:
strength: Augmentation strength - "light", "medium", or "strong"
Returns:
Composed transforms for training
"""
if strength == "light":
return transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ColorJitter(brightness=0.1, contrast=0.1),
transforms.ToTensor(),
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
])
elif strength == "medium":
return transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(degrees=15),
transforms.ColorJitter(
brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05
),
transforms.RandomAffine(
degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)
),
transforms.RandomGrayscale(p=0.1),
transforms.ToTensor(),
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
])
else: # strong
return transforms.Compose([
transforms.Resize((256, 256)),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.1),
transforms.RandomRotation(degrees=30),
transforms.ColorJitter(
brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1
),
transforms.RandomAffine(
degrees=0, translate=(0.15, 0.15), scale=(0.8, 1.2), shear=10
),
transforms.RandomGrayscale(p=0.2),
transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
transforms.ToTensor(),
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
])
def get_val_transforms() -> transforms.Compose:
"""Get validation/test transforms (no augmentation)."""
return transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
])
class LogoDataset:
"""
Manages logo data from the SQLite database.
Handles loading logo-to-image mappings and splitting by logo brand.
"""
def __init__(
self,
db_path: str,
reference_dir: str,
train_split: float = 0.7,
val_split: float = 0.15,
test_split: float = 0.15,
seed: int = 42,
):
self.db_path = Path(db_path)
self.reference_dir = Path(reference_dir)
self.seed = seed
# Load logo-to-images mapping from database
self.logo_to_images = self._load_logo_mappings()
self.all_logos = list(self.logo_to_images.keys())
# Create logo-level splits
self.train_logos, self.val_logos, self.test_logos = self._split_logos(
train_split, val_split, test_split
)
def _load_logo_mappings(self) -> Dict[str, List[Path]]:
"""Load logo name to image paths mapping from database."""
conn = sqlite3.connect(self.db_path)
cursor = conn.cursor()
cursor.execute("""
SELECT ln.name, rl.filename
FROM reference_logos rl
JOIN logo_names ln ON rl.logo_name_id = ln.id
ORDER BY ln.name
""")
logo_to_images: Dict[str, List[Path]] = {}
for logo_name, filename in cursor.fetchall():
if logo_name not in logo_to_images:
logo_to_images[logo_name] = []
logo_to_images[logo_name].append(self.reference_dir / filename)
conn.close()
return logo_to_images
def _split_logos(
self,
train_split: float,
val_split: float,
test_split: float,
) -> Tuple[List[str], List[str], List[str]]:
"""Split logos at brand level for train/val/test."""
random.seed(self.seed)
logos = self.all_logos.copy()
random.shuffle(logos)
n = len(logos)
train_end = int(n * train_split)
val_end = train_end + int(n * val_split)
train_logos = logos[:train_end]
val_logos = logos[train_end:val_end]
test_logos = logos[val_end:]
return train_logos, val_logos, test_logos
def get_split_info(self) -> Dict[str, int]:
"""Return information about the splits."""
return {
"total_logos": len(self.all_logos),
"train_logos": len(self.train_logos),
"val_logos": len(self.val_logos),
"test_logos": len(self.test_logos),
"train_images": sum(
len(self.logo_to_images[l]) for l in self.train_logos
),
"val_images": sum(
len(self.logo_to_images[l]) for l in self.val_logos
),
"test_images": sum(
len(self.logo_to_images[l]) for l in self.test_logos
),
}
class LogoContrastiveDataset(Dataset):
"""
Dataset for contrastive learning on logos.
Each __getitem__ call returns a batch of images organized for contrastive
learning: K different logos with M samples each, ensuring positive pairs
exist within each batch.
"""
def __init__(
self,
logo_data: LogoDataset,
split: str = "train",
logos_per_batch: int = 32,
samples_per_logo: int = 4,
transform: Optional[transforms.Compose] = None,
batches_per_epoch: int = 1000,
):
"""
Initialize the contrastive dataset.
Args:
logo_data: LogoDataset instance with logo mappings
split: One of "train", "val", or "test"
logos_per_batch: Number of different logos per batch
samples_per_logo: Number of samples for each logo
transform: Image transforms to apply
batches_per_epoch: Number of batches per epoch
"""
self.logo_data = logo_data
self.logos_per_batch = logos_per_batch
self.samples_per_logo = samples_per_logo
self.transform = transform
self.batches_per_epoch = batches_per_epoch
# Get logos for this split
if split == "train":
self.logos = logo_data.train_logos
elif split == "val":
self.logos = logo_data.val_logos
else:
self.logos = logo_data.test_logos
# Filter logos with enough samples
self.valid_logos = [
logo for logo in self.logos
if len(logo_data.logo_to_images[logo]) >= samples_per_logo
]
# For logos with fewer samples, we'll use with replacement
self.logos_needing_replacement = [
logo for logo in self.logos
if len(logo_data.logo_to_images[logo]) < samples_per_logo
]
# Create label mapping
self.logo_to_label = {
logo: idx for idx, logo in enumerate(self.logos)
}
def __len__(self) -> int:
return self.batches_per_epoch
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get a batch of images for contrastive learning.
Returns:
images: Tensor of shape [K*M, 3, 224, 224]
labels: Tensor of shape [K*M] with logo class indices
"""
images = []
labels = []
# Sample K logos for this batch
k = min(self.logos_per_batch, len(self.logos))
batch_logos = random.sample(self.logos, k)
for logo in batch_logos:
logo_images = self.logo_data.logo_to_images[logo]
# Sample M images for this logo
if len(logo_images) >= self.samples_per_logo:
sampled_paths = random.sample(logo_images, self.samples_per_logo)
else:
# Sample with replacement if not enough images
sampled_paths = random.choices(
logo_images, k=self.samples_per_logo
)
# Load and transform images
for img_path in sampled_paths:
try:
img = Image.open(img_path).convert("RGB")
if self.transform:
img = self.transform(img)
else:
img = get_val_transforms()(img)
images.append(img)
labels.append(self.logo_to_label[logo])
except Exception:
# Skip images that fail to load or transform
continue
# Stack into tensors
if len(images) == 0:
# Fallback: return dummy batch
return (
torch.zeros(1, 3, 224, 224),
torch.zeros(1, dtype=torch.long),
)
images_tensor = torch.stack(images)
labels_tensor = torch.tensor(labels, dtype=torch.long)
return images_tensor, labels_tensor
class BalancedBatchSampler(Sampler):
"""
Sampler that ensures each batch has a balanced distribution of logos.
Used with a flattened dataset where each sample is a single image.
"""
def __init__(
self,
logo_labels: List[int],
logos_per_batch: int,
samples_per_logo: int,
num_batches: int,
):
self.logo_labels = logo_labels
self.logos_per_batch = logos_per_batch
self.samples_per_logo = samples_per_logo
self.num_batches = num_batches
# Group indices by logo
self.logo_to_indices: Dict[int, List[int]] = {}
for idx, label in enumerate(logo_labels):
if label not in self.logo_to_indices:
self.logo_to_indices[label] = []
self.logo_to_indices[label].append(idx)
self.all_logos = list(self.logo_to_indices.keys())
def __iter__(self):
for _ in range(self.num_batches):
batch_indices = []
# Sample logos for this batch
logos = random.sample(
self.all_logos,
min(self.logos_per_batch, len(self.all_logos)),
)
for logo in logos:
indices = self.logo_to_indices[logo]
if len(indices) >= self.samples_per_logo:
sampled = random.sample(indices, self.samples_per_logo)
else:
sampled = random.choices(indices, k=self.samples_per_logo)
batch_indices.extend(sampled)
yield batch_indices
def __len__(self):
return self.num_batches
def create_dataloaders(
db_path: str,
reference_dir: str,
batch_size: int = 16,
logos_per_batch: int = 32,
samples_per_logo: int = 4,
num_workers: int = 4,
train_split: float = 0.7,
val_split: float = 0.15,
test_split: float = 0.15,
seed: int = 42,
augmentation_strength: str = "medium",
batches_per_epoch: int = 1000,
) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
"""
Create train, validation, and optionally test dataloaders.
Args:
db_path: Path to SQLite database
reference_dir: Directory containing reference logo images
batch_size: Not used directly (see logos_per_batch and samples_per_logo)
logos_per_batch: Number of different logos per batch
samples_per_logo: Samples per logo in batch
num_workers: Number of data loading workers
train_split: Fraction for training
val_split: Fraction for validation
test_split: Fraction for testing
seed: Random seed
augmentation_strength: "light", "medium", or "strong"
batches_per_epoch: Number of batches per training epoch
Returns:
Tuple of (train_loader, val_loader, test_loader)
"""
# Load logo data
logo_data = LogoDataset(
db_path=db_path,
reference_dir=reference_dir,
train_split=train_split,
val_split=val_split,
test_split=test_split,
seed=seed,
)
# Print split info
split_info = logo_data.get_split_info()
print(f"Dataset loaded:")
print(f" Total logos: {split_info['total_logos']}")
print(f" Train: {split_info['train_logos']} logos, {split_info['train_images']} images")
print(f" Val: {split_info['val_logos']} logos, {split_info['val_images']} images")
print(f" Test: {split_info['test_logos']} logos, {split_info['test_images']} images")
# Create datasets
train_dataset = LogoContrastiveDataset(
logo_data=logo_data,
split="train",
logos_per_batch=logos_per_batch,
samples_per_logo=samples_per_logo,
transform=get_train_transforms(augmentation_strength),
batches_per_epoch=batches_per_epoch,
)
val_dataset = LogoContrastiveDataset(
logo_data=logo_data,
split="val",
logos_per_batch=logos_per_batch,
samples_per_logo=samples_per_logo,
transform=get_val_transforms(),
batches_per_epoch=batches_per_epoch // 10, # Fewer val batches
)
test_dataset = LogoContrastiveDataset(
logo_data=logo_data,
split="test",
logos_per_batch=logos_per_batch,
samples_per_logo=samples_per_logo,
transform=get_val_transforms(),
batches_per_epoch=batches_per_epoch // 10,
) if test_split > 0 else None
# Create dataloaders
# Note: batch_size=1 because each __getitem__ already returns a batch
train_loader = DataLoader(
train_dataset,
batch_size=1,
shuffle=True,
num_workers=num_workers,
pin_memory=True,
collate_fn=_collate_contrastive_batch,
)
val_loader = DataLoader(
val_dataset,
batch_size=1,
shuffle=False,
num_workers=num_workers,
pin_memory=True,
collate_fn=_collate_contrastive_batch,
)
test_loader = None
if test_dataset is not None:
test_loader = DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
num_workers=num_workers,
pin_memory=True,
collate_fn=_collate_contrastive_batch,
)
return train_loader, val_loader, test_loader
def _collate_contrastive_batch(
batch: List[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Collate function that unpacks pre-batched data.
Since LogoContrastiveDataset already returns batched data,
we just squeeze the outer dimension.
"""
images, labels = batch[0]
return images, labels

training/evaluation.py (new file)
@@ -0,0 +1,339 @@
"""
Evaluation metrics for embedding quality.
"""
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
import numpy as np
class EmbeddingEvaluator:
"""
Evaluator for embedding quality metrics.
Computes metrics that indicate how well the embeddings
separate different logo classes.
"""
def compute_metrics(
self,
embeddings: torch.Tensor,
labels: torch.Tensor,
) -> Dict[str, float]:
"""
Compute embedding quality metrics.
Args:
embeddings: [N, D] L2-normalized embeddings
labels: [N] integer class labels
Returns:
Dict with metric names and values
"""
device = embeddings.device
batch_size = embeddings.shape[0]
if batch_size <= 1:
return {
"mean_pos_sim": 0.0,
"mean_neg_sim": 0.0,
"separation": 0.0,
"recall_at_1": 0.0,
"recall_at_5": 0.0,
}
# Compute similarity matrix
similarity = embeddings @ embeddings.T
# Create masks
labels_col = labels.unsqueeze(0)
labels_row = labels.unsqueeze(1)
positive_mask = (labels_row == labels_col).float()
negative_mask = 1 - positive_mask
# Remove diagonal from positive mask
identity = torch.eye(batch_size, device=device)
positive_mask = positive_mask - identity
# Count pairs
num_positives = positive_mask.sum()
num_negatives = negative_mask.sum()
# Mean positive similarity (excluding self)
if num_positives > 0:
pos_sims = (similarity * positive_mask).sum() / num_positives
mean_pos_sim = pos_sims.item()
else:
mean_pos_sim = 0.0
# Mean negative similarity
if num_negatives > 0:
neg_sims = (similarity * negative_mask).sum() / num_negatives
mean_neg_sim = neg_sims.item()
else:
mean_neg_sim = 0.0
# Separation: gap between positive and negative similarity
separation = mean_pos_sim - mean_neg_sim
# Recall@K metrics
recall_at_1 = self._compute_recall_at_k(similarity, labels, k=1)
recall_at_5 = self._compute_recall_at_k(similarity, labels, k=5)
return {
"mean_pos_sim": mean_pos_sim,
"mean_neg_sim": mean_neg_sim,
"separation": separation,
"recall_at_1": recall_at_1,
"recall_at_5": recall_at_5,
}
def _compute_recall_at_k(
self,
similarity: torch.Tensor,
labels: torch.Tensor,
k: int = 1,
) -> float:
"""
Compute Recall@K for nearest neighbor retrieval.
For each sample, check if the k nearest neighbors (excluding self)
contain at least one sample with the same label.
Args:
similarity: [N, N] similarity matrix
labels: [N] class labels
k: Number of neighbors to consider
Returns:
Recall@K score (0 to 1)
"""
batch_size = similarity.shape[0]
if batch_size <= 1:
return 0.0
# Mask out self-similarity
similarity = similarity.clone()
similarity.fill_diagonal_(float("-inf"))
# Get top-k indices
_, top_k_indices = similarity.topk(min(k, batch_size - 1), dim=1)
# Check if any of top-k have same label
correct = 0
for i in range(batch_size):
query_label = labels[i]
retrieved_labels = labels[top_k_indices[i]]
if (retrieved_labels == query_label).any():
correct += 1
return correct / batch_size
def compute_detailed_metrics(
self,
embeddings: torch.Tensor,
labels: torch.Tensor,
label_names: Optional[List[str]] = None,
) -> Dict:
"""
Compute detailed per-class metrics.
Args:
embeddings: [N, D] embeddings
labels: [N] class labels
label_names: Optional list of label names
Returns:
Dict with detailed metrics including per-class stats
"""
basic_metrics = self.compute_metrics(embeddings, labels)
# Per-class statistics
unique_labels = labels.unique()
per_class_stats = {}
similarity = embeddings @ embeddings.T
for label in unique_labels:
mask = labels == label
class_embeddings = embeddings[mask]
class_size = mask.sum().item()
if class_size > 1:
# Intra-class similarity
class_sim = class_embeddings @ class_embeddings.T
# Exclude diagonal
mask_diag = ~torch.eye(class_size, dtype=torch.bool, device=class_sim.device)
intra_sim = class_sim[mask_diag].mean().item()
else:
intra_sim = 1.0
# Inter-class similarity (to other classes)
other_mask = labels != label
if other_mask.any():
inter_sim = similarity[mask][:, other_mask].mean().item()
else:
inter_sim = 0.0
class_name = label_names[label.item()] if label_names else str(label.item())
per_class_stats[class_name] = {
"size": class_size,
"intra_class_sim": intra_sim,
"inter_class_sim": inter_sim,
"class_separation": intra_sim - inter_sim,
}
# Aggregate per-class stats
if per_class_stats:
separations = [s["class_separation"] for s in per_class_stats.values()]
min_separation = min(separations)
max_separation = max(separations)
std_separation = np.std(separations)
else:
min_separation = max_separation = std_separation = 0.0
return {
**basic_metrics,
"per_class": per_class_stats,
"min_class_separation": min_separation,
"max_class_separation": max_separation,
"std_class_separation": std_separation,
}
class SimilarityAnalyzer:
"""
Analyze similarity distributions for debugging and tuning.
"""
@staticmethod
def analyze_similarity_distribution(
embeddings: torch.Tensor,
labels: torch.Tensor,
) -> Dict[str, np.ndarray]:
"""
Get similarity distributions for positive and negative pairs.
Useful for choosing appropriate thresholds.
Args:
embeddings: [N, D] embeddings
labels: [N] class labels
Returns:
Dict with 'positive_sims' and 'negative_sims' arrays
"""
similarity = (embeddings @ embeddings.T).cpu().numpy()
labels_np = labels.cpu().numpy()
batch_size = len(labels_np)
positive_sims = []
negative_sims = []
for i in range(batch_size):
for j in range(i + 1, batch_size):
if labels_np[i] == labels_np[j]:
positive_sims.append(similarity[i, j])
else:
negative_sims.append(similarity[i, j])
return {
"positive_sims": np.array(positive_sims),
"negative_sims": np.array(negative_sims),
}
@staticmethod
def find_hard_pairs(
embeddings: torch.Tensor,
labels: torch.Tensor,
n_hard: int = 10,
) -> Tuple[List[Tuple[int, int, float]], List[Tuple[int, int, float]]]:
"""
Find hardest positive and negative pairs.
Hard positives: same label but low similarity
Hard negatives: different label but high similarity
Args:
embeddings: [N, D] embeddings
labels: [N] class labels
n_hard: Number of hard pairs to return
Returns:
Tuple of (hard_positives, hard_negatives)
Each is a list of (idx1, idx2, similarity) tuples
"""
similarity = embeddings @ embeddings.T
batch_size = len(labels)
hard_positives = [] # Low similarity, same label
hard_negatives = [] # High similarity, different label
for i in range(batch_size):
for j in range(i + 1, batch_size):
sim = similarity[i, j].item()
if labels[i] == labels[j]:
hard_positives.append((i, j, sim))
else:
hard_negatives.append((i, j, sim))
# Sort: hard positives by ascending similarity (lowest first)
hard_positives.sort(key=lambda x: x[2])
# Sort: hard negatives by descending similarity (highest first)
hard_negatives.sort(key=lambda x: -x[2])
return hard_positives[:n_hard], hard_negatives[:n_hard]
@staticmethod
def compute_confusion_pairs(
embeddings: torch.Tensor,
labels: torch.Tensor,
label_names: Optional[List[str]] = None,
top_k: int = 10,
) -> List[Dict]:
"""
Find pairs of classes that are most confused (highest cross-class similarity).
Args:
embeddings: [N, D] embeddings
labels: [N] class labels
label_names: Optional label names
top_k: Number of confused pairs to return
Returns:
List of dicts with class pairs and their similarity
"""
unique_labels = labels.unique()
class_centroids = {}
# Compute class centroids
for label in unique_labels:
mask = labels == label
centroid = embeddings[mask].mean(dim=0)
centroid = F.normalize(centroid, dim=0)
class_centroids[label.item()] = centroid
# Compute pairwise centroid similarities
confusions = []
label_list = list(class_centroids.keys())
for i, label1 in enumerate(label_list):
for label2 in label_list[i + 1:]:
sim = (class_centroids[label1] @ class_centroids[label2]).item()
name1 = label_names[label1] if label_names else str(label1)
name2 = label_names[label2] if label_names else str(label2)
confusions.append({
"class1": name1,
"class2": name2,
"label1": label1,
"label2": label2,
"centroid_similarity": sim,
})
# Sort by similarity (highest first)
confusions.sort(key=lambda x: -x["centroid_similarity"])
return confusions[:top_k]
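# --- Usage sketch (illustrative, not part of the module) -------------------
# A minimal example of running the evaluator and analyzer on one batch of
# embeddings. The random tensors below are placeholders, not real logo data.
if __name__ == "__main__":
    fake_embeddings = F.normalize(torch.randn(64, 768), dim=-1)
    fake_labels = torch.randint(0, 8, (64,))

    evaluator = EmbeddingEvaluator()
    metrics = evaluator.compute_metrics(fake_embeddings, fake_labels)
    print(f"separation={metrics['separation']:.3f}  R@1={metrics['recall_at_1']:.3f}")

    dist = SimilarityAnalyzer.analyze_similarity_distribution(fake_embeddings, fake_labels)
    print(f"{len(dist['positive_sims'])} positive pairs, {len(dist['negative_sims'])} negative pairs")

    hard_pos, hard_neg = SimilarityAnalyzer.find_hard_pairs(fake_embeddings, fake_labels, n_hard=5)
    print("hardest negative pair (idx1, idx2, sim):", hard_neg[0] if hard_neg else None)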

training/losses.py (new file, 326 lines)
@@ -0,0 +1,326 @@
"""
Loss functions for contrastive learning on logo embeddings.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
class InfoNCELoss(nn.Module):
"""
Normalized Temperature-scaled Cross Entropy Loss (InfoNCE).
This is the contrastive loss used in CLIP training. It maximizes
similarity between embeddings of the same logo class while
minimizing similarity to embeddings of different classes.
For a batch with N samples:
- Each sample is an anchor
- Positive pairs: samples with the same label
- Negative pairs: samples with different labels
    The loss for each anchor is averaged over its positive pairs pos:
        -log( exp(sim(anchor, pos)/temp) / sum_{a != anchor} exp(sim(anchor, a)/temp) )
"""
def __init__(self, temperature: float = 0.07):
"""
Initialize InfoNCE loss.
Args:
temperature: Scaling factor for similarities (0.05-0.1 typical).
Lower temperature makes the distribution sharper.
"""
super().__init__()
self.temperature = temperature
def forward(
self,
embeddings: torch.Tensor,
labels: torch.Tensor,
) -> torch.Tensor:
"""
Compute InfoNCE loss for a batch of embeddings.
Args:
embeddings: [N, D] L2-normalized embeddings
labels: [N] integer logo class labels
Returns:
Scalar loss value
"""
device = embeddings.device
batch_size = embeddings.shape[0]
if batch_size <= 1:
return torch.tensor(0.0, device=device, requires_grad=True)
# Compute similarity matrix [N, N]
# Since embeddings are L2-normalized, dot product = cosine similarity
similarity = embeddings @ embeddings.T / self.temperature
# Create positive mask: same label = 1, different = 0
labels_col = labels.unsqueeze(0) # [1, N]
labels_row = labels.unsqueeze(1) # [N, 1]
positive_mask = (labels_row == labels_col).float() # [N, N]
# Remove self-similarity from positives (diagonal)
identity = torch.eye(batch_size, device=device)
positive_mask = positive_mask - identity
# Count positives per anchor (avoid division by zero)
num_positives = positive_mask.sum(dim=1)
has_positives = num_positives > 0
# If no positives exist for any anchor, return zero loss
if not has_positives.any():
return torch.tensor(0.0, device=device, requires_grad=True)
# Mask out self-similarity with large negative value
similarity = similarity - identity * 1e9
# Compute log-softmax over similarities
log_softmax = F.log_softmax(similarity, dim=1)
# Sum log probabilities of positive pairs
positive_log_probs = (log_softmax * positive_mask).sum(dim=1)
# Average over number of positives (only for anchors with positives)
loss_per_anchor = torch.zeros(batch_size, device=device)
loss_per_anchor[has_positives] = (
-positive_log_probs[has_positives] / num_positives[has_positives]
)
return loss_per_anchor.mean()
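# Worked example of the masking above, for labels [0, 0, 1, 1] (N=4):
#   positive_mask (self removed) = [[0,1,0,0],
#                                   [1,0,0,0],
#                                   [0,0,0,1],
#                                   [0,0,1,0]]
#   log_softmax is taken row-wise over all 4 candidates (self masked out with -1e9),
#   and each anchor's loss is the mean of -log p(positive | anchor) over its positives.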
class SupConLoss(nn.Module):
"""
Supervised Contrastive Loss.
Similar to InfoNCE but uses a different formulation that
considers each positive pair separately rather than averaging.
Reference: https://arxiv.org/abs/2004.11362
"""
def __init__(self, temperature: float = 0.07):
super().__init__()
self.temperature = temperature
def forward(
self,
embeddings: torch.Tensor,
labels: torch.Tensor,
) -> torch.Tensor:
"""
Compute Supervised Contrastive loss.
Args:
embeddings: [N, D] L2-normalized embeddings
labels: [N] integer logo class labels
Returns:
Scalar loss value
"""
device = embeddings.device
batch_size = embeddings.shape[0]
if batch_size <= 1:
return torch.tensor(0.0, device=device, requires_grad=True)
# Compute similarity matrix
similarity = embeddings @ embeddings.T / self.temperature
# Create masks
labels_col = labels.unsqueeze(0)
labels_row = labels.unsqueeze(1)
positive_mask = (labels_row == labels_col).float()
identity = torch.eye(batch_size, device=device)
# Remove self from positives
positive_mask = positive_mask - identity
# Number of positives per anchor
num_positives = positive_mask.sum(dim=1)
has_positives = num_positives > 0
if not has_positives.any():
return torch.tensor(0.0, device=device, requires_grad=True)
# For numerical stability, subtract max similarity
sim_max, _ = similarity.max(dim=1, keepdim=True)
similarity = similarity - sim_max.detach()
# Compute exp(similarity) with self masked out
exp_sim = torch.exp(similarity) * (1 - identity)
# Denominator: sum of exp over all pairs except self
log_prob = similarity - torch.log(exp_sim.sum(dim=1, keepdim=True) + 1e-8)
# Mean of log-prob over positive pairs
mean_log_prob_pos = (positive_mask * log_prob).sum(dim=1) / (
num_positives + 1e-8
)
# Loss is negative mean log probability
loss = -mean_log_prob_pos[has_positives].mean()
return loss
class TripletLoss(nn.Module):
"""
Triplet loss with online hard mining.
For each anchor:
- Hardest positive: most distant sample with same label
- Hardest negative: closest sample with different label
Loss = max(0, d(anchor, hardest_pos) - d(anchor, hardest_neg) + margin)
This is an alternative to InfoNCE for when batch sizes are small.
"""
def __init__(self, margin: float = 0.3):
"""
Initialize Triplet loss.
Args:
margin: Minimum required gap between positive and negative distances
"""
super().__init__()
self.margin = margin
def forward(
self,
embeddings: torch.Tensor,
labels: torch.Tensor,
) -> torch.Tensor:
"""
Compute triplet loss with online hard mining.
Args:
embeddings: [N, D] L2-normalized embeddings
labels: [N] integer logo class labels
Returns:
Scalar loss value
"""
device = embeddings.device
batch_size = embeddings.shape[0]
if batch_size <= 1:
return torch.tensor(0.0, device=device, requires_grad=True)
# Compute pairwise cosine distances (1 - cosine_similarity)
# For normalized vectors: distance = 1 - dot_product
similarity = embeddings @ embeddings.T
distances = 1 - similarity
# Create masks
labels_col = labels.unsqueeze(0)
labels_row = labels.unsqueeze(1)
positive_mask = (labels_row == labels_col).float()
negative_mask = 1 - positive_mask
# Remove self from positives (diagonal)
identity = torch.eye(batch_size, device=device)
positive_mask = positive_mask - identity
# Check if we have any valid triplets
has_positives = positive_mask.sum(dim=1) > 0
has_negatives = negative_mask.sum(dim=1) > 0
valid_anchors = has_positives & has_negatives
if not valid_anchors.any():
return torch.tensor(0.0, device=device, requires_grad=True)
# For each anchor, find hardest positive (max distance among positives)
# Set negatives to -inf so they don't affect max
pos_distances = distances.clone()
pos_distances[positive_mask == 0] = float("-inf")
hardest_positive, _ = pos_distances.max(dim=1)
# For each anchor, find hardest negative (min distance among negatives)
# Set positives to inf so they don't affect min
neg_distances = distances.clone()
neg_distances[negative_mask == 0] = float("inf")
hardest_negative, _ = neg_distances.min(dim=1)
# Triplet loss: want positive to be closer than negative by margin
triplet_loss = F.relu(
hardest_positive - hardest_negative + self.margin
)
# Average over valid anchors only
loss = triplet_loss[valid_anchors].mean()
return loss
class CombinedLoss(nn.Module):
"""
Combined loss function with weighted InfoNCE and Triplet losses.
Can help stabilize training by combining the benefits of both losses.
"""
def __init__(
self,
temperature: float = 0.07,
triplet_margin: float = 0.3,
infonce_weight: float = 1.0,
triplet_weight: float = 0.5,
):
super().__init__()
self.infonce = InfoNCELoss(temperature=temperature)
self.triplet = TripletLoss(margin=triplet_margin)
self.infonce_weight = infonce_weight
self.triplet_weight = triplet_weight
def forward(
self,
embeddings: torch.Tensor,
labels: torch.Tensor,
) -> torch.Tensor:
infonce_loss = self.infonce(embeddings, labels)
triplet_loss = self.triplet(embeddings, labels)
return (
self.infonce_weight * infonce_loss +
self.triplet_weight * triplet_loss
)
def get_loss_function(
loss_type: str = "infonce",
temperature: float = 0.07,
triplet_margin: float = 0.3,
) -> nn.Module:
"""
Factory function to create loss function.
Args:
loss_type: One of "infonce", "supcon", "triplet", or "combined"
temperature: Temperature for InfoNCE/SupCon
triplet_margin: Margin for triplet loss
Returns:
Loss function module
"""
if loss_type == "infonce":
return InfoNCELoss(temperature=temperature)
elif loss_type == "supcon":
return SupConLoss(temperature=temperature)
elif loss_type == "triplet":
return TripletLoss(margin=triplet_margin)
elif loss_type == "combined":
return CombinedLoss(
temperature=temperature,
triplet_margin=triplet_margin,
)
else:
raise ValueError(f"Unknown loss type: {loss_type}")
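# --- Usage sketch (illustrative, not part of the module) -------------------
# Quick sanity check of the loss factory on a toy batch: four L2-normalized
# random embeddings from two classes, standing in for real CLIP outputs.
if __name__ == "__main__":
    embeddings = F.normalize(torch.randn(4, 768), dim=-1)
    labels = torch.tensor([0, 0, 1, 1])

    for loss_type in ("infonce", "supcon", "triplet", "combined"):
        criterion = get_loss_function(loss_type=loss_type, temperature=0.07, triplet_margin=0.3)
        loss = criterion(embeddings, labels)
        print(f"{loss_type}: {loss.item():.4f}")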

training/model.py (new file, 335 lines)
@@ -0,0 +1,335 @@
"""
Fine-tunable CLIP model wrapper with LoRA support.
"""
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPModel, CLIPProcessor
# Check if peft is available for LoRA
try:
from peft import LoraConfig, get_peft_model, PeftModel
PEFT_AVAILABLE = True
except ImportError:
PEFT_AVAILABLE = False
LoraConfig = None
get_peft_model = None
PeftModel = None
class LogoFineTunedCLIP(nn.Module):
"""
CLIP vision encoder fine-tuned for logo similarity.
Preserves embedding interface for compatibility with DetectLogosDETR:
    - Embedding dimensionality equal to the vision encoder's hidden size (1024 for ViT-L/14)
- L2 normalized outputs
- Works with existing get_image_features() pattern
Supports:
- LoRA for memory-efficient fine-tuning
- Layer freezing for transfer learning
- Gradient checkpointing for memory optimization
"""
def __init__(
self,
vision_model: nn.Module,
lora_r: int = 16,
lora_alpha: int = 32,
lora_dropout: float = 0.1,
freeze_layers: int = 12,
use_gradient_checkpointing: bool = True,
add_projection_head: bool = True,
):
"""
Initialize the fine-tunable CLIP wrapper.
Args:
vision_model: CLIP vision model (CLIPVisionModel)
lora_r: Rank of LoRA low-rank matrices (0 to disable)
lora_alpha: LoRA scaling factor
lora_dropout: Dropout for LoRA layers
freeze_layers: Number of transformer layers to freeze (from bottom)
use_gradient_checkpointing: Enable gradient checkpointing
add_projection_head: Add trainable projection head
"""
super().__init__()
self.vision_model = vision_model
self.embedding_dim = vision_model.config.hidden_size
self.freeze_layers = freeze_layers
self.lora_r = lora_r
self.lora_alpha = lora_alpha
# Enable gradient checkpointing for memory efficiency
if use_gradient_checkpointing:
if hasattr(self.vision_model, "gradient_checkpointing_enable"):
self.vision_model.gradient_checkpointing_enable()
# Freeze lower layers
self._freeze_layers(freeze_layers)
# Apply LoRA to attention layers in upper blocks
self.peft_applied = False
if PEFT_AVAILABLE and lora_r > 0:
self._apply_lora(lora_r, lora_alpha, lora_dropout)
self.peft_applied = True
elif lora_r > 0 and not PEFT_AVAILABLE:
print(
"Warning: peft not installed. LoRA disabled. "
"Install with: pip install peft"
)
# Optional projection head for fine-tuning
self.add_projection_head = add_projection_head
if add_projection_head:
self.projection = nn.Sequential(
nn.Linear(self.embedding_dim, self.embedding_dim),
nn.LayerNorm(self.embedding_dim),
)
else:
self.projection = nn.Identity()
def _freeze_layers(self, num_layers: int) -> None:
"""Freeze the first N transformer layers and embeddings."""
if num_layers <= 0:
return
# Freeze embeddings
if hasattr(self.vision_model, "embeddings"):
for param in self.vision_model.embeddings.parameters():
param.requires_grad = False
# Freeze specified number of encoder layers
if hasattr(self.vision_model, "encoder"):
for i, layer in enumerate(self.vision_model.encoder.layers):
if i < num_layers:
for param in layer.parameters():
param.requires_grad = False
def _apply_lora(
self,
r: int,
alpha: int,
dropout: float,
) -> None:
"""Apply LoRA adapters to attention layers."""
if not PEFT_AVAILABLE:
return
# Configure LoRA for vision transformer
lora_config = LoraConfig(
r=r,
lora_alpha=alpha,
lora_dropout=dropout,
target_modules=["q_proj", "v_proj"],
bias="none",
modules_to_save=[], # Don't save any full modules
)
self.vision_model = get_peft_model(self.vision_model, lora_config)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""
Extract normalized embeddings for logo images.
Args:
pixel_values: [batch, 3, 224, 224] preprocessed images
Returns:
embeddings: [batch, embedding_dim] L2-normalized
"""
# Get vision features
outputs = self.vision_model(pixel_values=pixel_values)
# Use pooler output (CLS token projection) if available
if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
features = outputs.pooler_output
else:
# Fall back to CLS token from last hidden state
features = outputs.last_hidden_state[:, 0, :]
# Apply projection head
features = self.projection(features)
# L2 normalize for cosine similarity
features = F.normalize(features, dim=-1)
return features
def get_image_features(self, **kwargs) -> torch.Tensor:
"""
Compatibility method matching CLIP's interface.
Used by DetectLogosDETR._get_embedding_pil().
"""
return self.forward(kwargs["pixel_values"])
def get_trainable_parameters(self) -> List[torch.nn.Parameter]:
"""Return list of trainable parameters."""
return [p for p in self.parameters() if p.requires_grad]
def get_parameter_count(self) -> Dict[str, int]:
"""Return count of trainable and total parameters."""
total = sum(p.numel() for p in self.parameters())
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
return {
"total": total,
"trainable": trainable,
"frozen": total - trainable,
"trainable_percent": 100 * trainable / total if total > 0 else 0,
}
def save_pretrained(self, output_dir: str) -> None:
"""
Save model in HuggingFace-compatible format.
Args:
output_dir: Directory to save model files
"""
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# Save model weights
if self.peft_applied and PEFT_AVAILABLE:
# Save LoRA weights separately
self.vision_model.save_pretrained(output_path / "vision_lora")
# Save projection head
torch.save(
self.projection.state_dict(),
output_path / "projection_head.bin",
)
else:
# Save full model state
torch.save(self.state_dict(), output_path / "pytorch_model.bin")
# Save config
config = {
"model_type": "clip_logo_finetuned",
"embedding_dim": self.embedding_dim,
"lora_r": self.lora_r,
"lora_alpha": self.lora_alpha,
"freeze_layers": self.freeze_layers,
"add_projection_head": self.add_projection_head,
"peft_applied": self.peft_applied,
}
with open(output_path / "config.json", "w") as f:
json.dump(config, f, indent=2)
@classmethod
def from_pretrained(
cls,
model_path: str,
base_model: str = "openai/clip-vit-large-patch14",
device: Optional[torch.device] = None,
) -> "LogoFineTunedCLIP":
"""
Load a fine-tuned model from saved weights.
Args:
model_path: Path to saved model directory
base_model: Base CLIP model name (for architecture)
device: Device to load model on
Returns:
Loaded LogoFineTunedCLIP model
"""
model_path = Path(model_path)
# Load config
with open(model_path / "config.json", "r") as f:
config = json.load(f)
# Load base CLIP model
clip_model = CLIPModel.from_pretrained(base_model)
        # Create model instance.
        # Pass lora_r=0 so fresh, randomly initialized LoRA adapters are not
        # injected here; the saved adapters are attached below instead.
        model = cls(
            vision_model=clip_model.vision_model,
            lora_r=0,
            lora_alpha=config.get("lora_alpha", 1),
            freeze_layers=config.get("freeze_layers", 12),
            add_projection_head=config.get("add_projection_head", True),
            use_gradient_checkpointing=False,  # Not needed for inference
        )
        # Load weights
        if config.get("peft_applied", False) and PEFT_AVAILABLE:
            # Attach the saved LoRA adapters to the base vision model
            lora_path = model_path / "vision_lora"
            if lora_path.exists():
                model.vision_model = PeftModel.from_pretrained(
                    model.vision_model, lora_path
                )
                model.peft_applied = True
                model.lora_r = config.get("lora_r", 0)
                model.lora_alpha = config.get("lora_alpha", 1)
            # Load projection head
            proj_path = model_path / "projection_head.bin"
            if proj_path.exists():
                model.projection.load_state_dict(
                    torch.load(proj_path, map_location="cpu")
                )
else:
# Load full model state
weights_path = model_path / "pytorch_model.bin"
if weights_path.exists():
model.load_state_dict(torch.load(weights_path))
if device is not None:
model = model.to(device)
return model
def create_model(
base_model: str = "openai/clip-vit-large-patch14",
lora_r: int = 16,
lora_alpha: int = 32,
lora_dropout: float = 0.1,
freeze_layers: int = 12,
use_gradient_checkpointing: bool = True,
device: Optional[torch.device] = None,
) -> Tuple[LogoFineTunedCLIP, CLIPProcessor]:
"""
Create a fine-tunable CLIP model and processor.
Args:
base_model: HuggingFace model name or path
lora_r: LoRA rank (0 to disable)
lora_alpha: LoRA scaling factor
lora_dropout: LoRA dropout
freeze_layers: Number of layers to freeze
use_gradient_checkpointing: Enable gradient checkpointing
device: Device to load model on
Returns:
Tuple of (model, processor)
"""
# Load base CLIP model
clip_model = CLIPModel.from_pretrained(base_model)
processor = CLIPProcessor.from_pretrained(base_model)
# Create fine-tunable wrapper
model = LogoFineTunedCLIP(
vision_model=clip_model.vision_model,
lora_r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
freeze_layers=freeze_layers,
use_gradient_checkpointing=use_gradient_checkpointing,
)
if device is not None:
model = model.to(device)
# Print parameter info
param_info = model.get_parameter_count()
print(f"Model created:")
print(f" Total parameters: {param_info['total']:,}")
print(f" Trainable: {param_info['trainable']:,} ({param_info['trainable_percent']:.2f}%)")
print(f" Frozen: {param_info['frozen']:,}")
return model, processor
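# --- Usage sketch (illustrative, not part of the module) -------------------
# How a fine-tunable model might be created and used to embed one logo crop.
# The image path below is a placeholder; CLIPProcessor handles resizing and
# normalization, so the wrapper receives the expected pixel_values tensor.
if __name__ == "__main__":
    from PIL import Image

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, processor = create_model(lora_r=16, freeze_layers=12, device=device)
    model.eval()

    image = Image.open("reference_logos/some_brand/crop_0.png").convert("RGB")  # placeholder path
    inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        embedding = model.get_image_features(pixel_values=inputs["pixel_values"])

    # [1, embedding_dim], L2-normalized; ready for cosine-similarity matching
    print(embedding.shape, embedding.norm(dim=-1))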

training/trainer.py (new file, 405 lines)
@@ -0,0 +1,405 @@
"""
Training loop with checkpointing, mixed precision, and evaluation.
"""
import json
import logging
import time
from pathlib import Path
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR
from torch.utils.data import DataLoader
from tqdm import tqdm
from .config import TrainingConfig
from .losses import get_loss_function
from .evaluation import EmbeddingEvaluator
# Check if amp is available
try:
from torch.cuda.amp import autocast, GradScaler
AMP_AVAILABLE = True
except ImportError:
AMP_AVAILABLE = False
autocast = None
GradScaler = None
class Trainer:
"""
Trainer for fine-tuning CLIP on logo recognition.
Features:
- Mixed precision training (FP16)
- Gradient accumulation
- Gradient checkpointing (via model)
- Cosine annealing LR scheduler
- Early stopping
- Checkpoint saving/loading
- Evaluation during training
"""
def __init__(
self,
model: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
config: TrainingConfig,
logger: Optional[logging.Logger] = None,
):
"""
Initialize the trainer.
Args:
model: LogoFineTunedCLIP model
train_loader: Training dataloader
val_loader: Validation dataloader
config: Training configuration
logger: Optional logger instance
"""
self.model = model
self.train_loader = train_loader
self.val_loader = val_loader
self.config = config
self.logger = logger or logging.getLogger(__name__)
# Device setup
self.device = torch.device(
"cuda" if torch.cuda.is_available() else "cpu"
)
self.model.to(self.device)
self.logger.info(f"Using device: {self.device}")
# Optimizer - only trainable parameters
trainable_params = [p for p in model.parameters() if p.requires_grad]
self.logger.info(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")
self.optimizer = AdamW(
trainable_params,
lr=config.learning_rate,
weight_decay=config.weight_decay,
)
        # Learning rate scheduler.
        # OneCycleLR is stepped once per optimizer update, so the schedule
        # length must account for gradient accumulation.
        steps_per_epoch = max(
            1, len(train_loader) // config.gradient_accumulation_steps
        )
        total_steps = steps_per_epoch * config.max_epochs
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config.learning_rate,
            total_steps=total_steps,
            pct_start=config.warmup_steps / total_steps if total_steps > 0 else 0.1,
            anneal_strategy="cos",
        )
# Mixed precision training
self.use_amp = config.mixed_precision and AMP_AVAILABLE and self.device.type == "cuda"
if self.use_amp:
self.scaler = GradScaler()
self.logger.info("Mixed precision training enabled")
else:
self.scaler = None
if config.mixed_precision and not AMP_AVAILABLE:
self.logger.warning("Mixed precision requested but not available")
# Loss function
self.criterion = get_loss_function(
loss_type=config.loss_type,
temperature=config.temperature,
triplet_margin=config.triplet_margin,
)
# Evaluator
self.evaluator = EmbeddingEvaluator()
# Training state
self.epoch = 0
self.global_step = 0
self.best_val_loss = float("inf")
self.best_val_separation = float("-inf")
self.patience_counter = 0
self.training_history = []
def train(self) -> Dict[str, float]:
"""
Main training loop.
Returns:
Dict with final training metrics
"""
self.logger.info("Starting training...")
self.logger.info(f" Epochs: {self.config.max_epochs}")
self.logger.info(f" Batch size: {self.config.batch_size}")
self.logger.info(f" Gradient accumulation: {self.config.gradient_accumulation_steps}")
self.logger.info(f" Effective batch: {self.config.effective_batch_size}")
self.logger.info(f" Learning rate: {self.config.learning_rate}")
start_time = time.time()
for epoch in range(self.epoch, self.config.max_epochs):
self.epoch = epoch
self.logger.info(f"\nEpoch {epoch + 1}/{self.config.max_epochs}")
# Training epoch
train_metrics = self._train_epoch()
self.logger.info(
f"Train - Loss: {train_metrics['loss']:.4f}, "
f"LR: {train_metrics['lr']:.2e}"
)
# Validation
if (epoch + 1) % self.config.eval_every_n_epochs == 0:
val_metrics = self._validate()
self.logger.info(
f"Val - Loss: {val_metrics['loss']:.4f}, "
f"Pos Sim: {val_metrics['mean_pos_sim']:.3f}, "
f"Neg Sim: {val_metrics['mean_neg_sim']:.3f}, "
f"Separation: {val_metrics['separation']:.3f}"
)
# Record history
self.training_history.append({
"epoch": epoch + 1,
"train_loss": train_metrics["loss"],
"val_loss": val_metrics["loss"],
"val_separation": val_metrics["separation"],
"val_pos_sim": val_metrics["mean_pos_sim"],
"val_neg_sim": val_metrics["mean_neg_sim"],
})
# Checkpointing based on separation (primary) or loss (secondary)
improved = False
if val_metrics["separation"] > self.best_val_separation + self.config.min_delta:
self.best_val_separation = val_metrics["separation"]
improved = True
elif val_metrics["loss"] < self.best_val_loss - self.config.min_delta:
self.best_val_loss = val_metrics["loss"]
improved = True
if improved:
self.patience_counter = 0
self._save_checkpoint("best.pt")
self.logger.info("New best model saved!")
else:
self.patience_counter += 1
# Early stopping
if self.patience_counter >= self.config.patience:
self.logger.info(
f"Early stopping triggered at epoch {epoch + 1} "
f"(no improvement for {self.config.patience} epochs)"
)
break
# Periodic checkpoint
if (epoch + 1) % self.config.save_every_n_epochs == 0:
self._save_checkpoint(f"epoch_{epoch + 1}.pt")
# Training complete
total_time = time.time() - start_time
self.logger.info(f"\nTraining completed in {total_time / 60:.1f} minutes")
# Load best model
best_path = Path(self.config.checkpoint_dir) / "best.pt"
if best_path.exists():
self.load_checkpoint("best.pt")
self.logger.info("Loaded best model checkpoint")
return {
"best_val_loss": self.best_val_loss,
"best_val_separation": self.best_val_separation,
"total_epochs": self.epoch + 1,
"total_time_minutes": total_time / 60,
}
def _train_epoch(self) -> Dict[str, float]:
"""Run a single training epoch."""
self.model.train()
total_loss = 0.0
num_batches = 0
accumulation_steps = 0
progress_bar = tqdm(
self.train_loader,
desc=f"Epoch {self.epoch + 1}",
leave=False,
)
self.optimizer.zero_grad()
for batch_idx, (images, labels) in enumerate(progress_bar):
images = images.to(self.device)
labels = labels.to(self.device)
# Forward pass with mixed precision
if self.use_amp:
with autocast():
embeddings = self.model(images)
loss = self.criterion(embeddings, labels)
loss = loss / self.config.gradient_accumulation_steps
self.scaler.scale(loss).backward()
else:
embeddings = self.model(images)
loss = self.criterion(embeddings, labels)
loss = loss / self.config.gradient_accumulation_steps
loss.backward()
accumulation_steps += 1
# Optimizer step after accumulation
if accumulation_steps >= self.config.gradient_accumulation_steps:
if self.use_amp:
self.scaler.step(self.optimizer)
self.scaler.update()
else:
self.optimizer.step()
self.optimizer.zero_grad()
self.scheduler.step()
self.global_step += 1
accumulation_steps = 0
total_loss += loss.item() * self.config.gradient_accumulation_steps
num_batches += 1
# Update progress bar
progress_bar.set_postfix({
"loss": total_loss / num_batches,
"lr": self.scheduler.get_last_lr()[0],
})
# Logging
if (batch_idx + 1) % self.config.log_every_n_steps == 0:
self.logger.debug(
f"Step {self.global_step}: loss={total_loss / num_batches:.4f}"
)
return {
"loss": total_loss / max(num_batches, 1),
"lr": self.scheduler.get_last_lr()[0],
}
def _validate(self) -> Dict[str, float]:
"""Run validation and compute metrics."""
self.model.eval()
total_loss = 0.0
all_embeddings = []
all_labels = []
with torch.no_grad():
for images, labels in tqdm(self.val_loader, desc="Validating", leave=False):
images = images.to(self.device)
labels = labels.to(self.device)
if self.use_amp:
with autocast():
embeddings = self.model(images)
loss = self.criterion(embeddings, labels)
else:
embeddings = self.model(images)
loss = self.criterion(embeddings, labels)
total_loss += loss.item()
all_embeddings.append(embeddings.cpu())
all_labels.append(labels.cpu())
# Combine batches
all_embeddings = torch.cat(all_embeddings, dim=0)
all_labels = torch.cat(all_labels, dim=0)
# Compute embedding quality metrics
metrics = self.evaluator.compute_metrics(all_embeddings, all_labels)
metrics["loss"] = total_loss / max(len(self.val_loader), 1)
return metrics
def _save_checkpoint(self, filename: str) -> None:
"""Save training checkpoint."""
checkpoint_dir = Path(self.config.checkpoint_dir)
checkpoint_dir.mkdir(parents=True, exist_ok=True)
checkpoint = {
"epoch": self.epoch,
"global_step": self.global_step,
"model_state_dict": self.model.state_dict(),
"optimizer_state_dict": self.optimizer.state_dict(),
"scheduler_state_dict": self.scheduler.state_dict(),
"best_val_loss": self.best_val_loss,
"best_val_separation": self.best_val_separation,
"patience_counter": self.patience_counter,
"training_history": self.training_history,
"config": self.config.__dict__,
}
if self.scaler is not None:
checkpoint["scaler_state_dict"] = self.scaler.state_dict()
torch.save(checkpoint, checkpoint_dir / filename)
self.logger.debug(f"Saved checkpoint: {filename}")
def load_checkpoint(self, filename: str) -> None:
"""Load training checkpoint."""
checkpoint_path = Path(self.config.checkpoint_dir) / filename
if not checkpoint_path.exists():
self.logger.warning(f"Checkpoint not found: {checkpoint_path}")
return
checkpoint = torch.load(checkpoint_path, map_location=self.device)
self.model.load_state_dict(checkpoint["model_state_dict"])
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
self.epoch = checkpoint["epoch"]
self.global_step = checkpoint["global_step"]
self.best_val_loss = checkpoint["best_val_loss"]
self.best_val_separation = checkpoint.get("best_val_separation", float("-inf"))
self.patience_counter = checkpoint.get("patience_counter", 0)
self.training_history = checkpoint.get("training_history", [])
if self.scaler is not None and "scaler_state_dict" in checkpoint:
self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
self.logger.info(f"Resumed from epoch {self.epoch + 1}")
def export_model(self, output_dir: Optional[str] = None) -> str:
"""
Export the trained model for inference.
Args:
output_dir: Output directory (uses config.output_dir if not specified)
Returns:
Path to exported model directory
"""
output_dir = output_dir or self.config.output_dir
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# Save model
self.model.save_pretrained(output_dir)
# Save training config
config_path = output_path / "training_config.json"
with open(config_path, "w") as f:
json.dump(self.config.__dict__, f, indent=2)
# Save training history
history_path = output_path / "training_history.json"
with open(history_path, "w") as f:
json.dump(self.training_history, f, indent=2)
self.logger.info(f"Model exported to: {output_path}")
return str(output_path)
def get_training_summary(self) -> Dict:
"""Get summary of training."""
return {
"epochs_completed": self.epoch + 1,
"global_steps": self.global_step,
"best_val_loss": self.best_val_loss,
"best_val_separation": self.best_val_separation,
"history": self.training_history,
}
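# --- Usage sketch (illustrative, not part of the module) -------------------
# Rough outline of how the trainer is wired up; train_clip_logo.py is the real
# entry point. Assumptions: TrainingConfig() works with default values (real
# runs populate it from configs/*.yaml), and the toy TensorDataset below stands
# in for LogoContrastiveDataset, whose actual constructor lives in dataset.py.
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    from training.model import create_model  # assumes the repo root is on sys.path

    logging.basicConfig(level=logging.INFO)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, _ = create_model(lora_r=16, freeze_layers=12, device=device)

    # Toy data yielding (images, labels) batches, as _train_epoch/_validate expect
    images = torch.randn(32, 3, 224, 224)
    labels = torch.randint(0, 4, (32,))
    config = TrainingConfig()
    loader = DataLoader(TensorDataset(images, labels), batch_size=config.batch_size)

    trainer = Trainer(model, train_loader=loader, val_loader=loader, config=config)
    summary = trainer.train()
    trainer.export_model()
    print(summary)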