Add CLIP fine-tuning pipeline for logo recognition
Implement contrastive learning with LoRA to fine-tune CLIP's vision encoder on LogoDet-3K dataset for improved logo embedding similarity. New training module (training/): - config.py: TrainingConfig dataclass with all hyperparameters - dataset.py: LogoContrastiveDataset with logo-level splits - model.py: LogoFineTunedCLIP wrapper with LoRA support - losses.py: InfoNCE, TripletLoss, SupConLoss implementations - trainer.py: Training loop with mixed precision and checkpointing - evaluation.py: EmbeddingEvaluator for validation metrics New scripts: - train_clip_logo.py: Main training entry point - export_model.py: Export to HuggingFace-compatible format Configurations: - configs/jetson_orin.yaml: Optimized for Jetson Orin AGX - configs/cloud_rtx4090.yaml: Optimized for 24GB cloud GPUs - configs/cloud_a100.yaml: Optimized for 80GB cloud GPUs Documentation: - CLIP_FINETUNING.md: Training guide and usage instructions - CLOUD_TRAINING.md: Cloud GPU recommendations and cost estimates Modified: - logo_detection_detr.py: Add fine-tuned model loading support - pyproject.toml: Add peft, pyyaml, torchvision dependencies
This commit is contained in:
266
CLIP_FINETUNING.md
Normal file
266
CLIP_FINETUNING.md
Normal file
@ -0,0 +1,266 @@
|
|||||||
|
# CLIP Fine-Tuning for Logo Recognition
|
||||||
|
|
||||||
|
This document describes the CLIP fine-tuning pipeline for improving logo embedding similarity using the LogoDet-3K dataset.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The fine-tuning approach uses **contrastive learning** with **LoRA** (Low-Rank Adaptation) to train CLIP's vision encoder for better logo similarity matching while maintaining compatibility with the existing `DetectLogosDETR` class.
|
||||||
|
|
||||||
|
**Goal**: Improve F1 from ~60% to >72% on logo matching tasks.
|
||||||
|
|
||||||
|
## Files Created
|
||||||
|
|
||||||
|
### Training Module (`training/`)
|
||||||
|
|
||||||
|
| File | Description |
|
||||||
|
|------|-------------|
|
||||||
|
| `__init__.py` | Module exports |
|
||||||
|
| `config.py` | `TrainingConfig` dataclass with all hyperparameters |
|
||||||
|
| `dataset.py` | `LogoContrastiveDataset` with logo-level splits and augmentations |
|
||||||
|
| `model.py` | `LogoFineTunedCLIP` wrapper with LoRA support |
|
||||||
|
| `losses.py` | `InfoNCELoss`, `TripletLoss`, `SupConLoss`, `CombinedLoss` |
|
||||||
|
| `trainer.py` | Training loop with mixed precision, checkpointing, early stopping |
|
||||||
|
| `evaluation.py` | `EmbeddingEvaluator` for validation metrics |
|
||||||
|
|
||||||
|
### Scripts
|
||||||
|
|
||||||
|
| File | Description |
|
||||||
|
|------|-------------|
|
||||||
|
| `train_clip_logo.py` | Main training entry point |
|
||||||
|
| `export_model.py` | Export trained models to HuggingFace-compatible format |
|
||||||
|
|
||||||
|
### Configuration
|
||||||
|
|
||||||
|
| File | Description |
|
||||||
|
|------|-------------|
|
||||||
|
| `configs/jetson_orin.yaml` | Training config optimized for Jetson Orin AGX |
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
1. **Install dependencies**:
|
||||||
|
```bash
|
||||||
|
uv sync
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Prepare test data** (if not already done):
|
||||||
|
```bash
|
||||||
|
uv run python prepare_test_data.py
|
||||||
|
```
|
||||||
|
|
||||||
|
This creates:
|
||||||
|
- `reference_logos/` - Cropped logo images organized by category/brand
|
||||||
|
- `test_images/` - Full images for testing
|
||||||
|
- `test_data_mapping.db` - SQLite database with mappings
|
||||||
|
|
||||||
|
## Training
|
||||||
|
|
||||||
|
### Basic Training
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python train_clip_logo.py --config configs/jetson_orin.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### Training with Overrides
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
|
||||||
|
--learning-rate 5e-6 \
|
||||||
|
--max-epochs 30 \
|
||||||
|
--batch-size 8
|
||||||
|
```
|
||||||
|
|
||||||
|
### Resume from Checkpoint
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
|
||||||
|
--resume checkpoints/epoch_10.pt
|
||||||
|
```
|
||||||
|
|
||||||
|
### Training Output
|
||||||
|
|
||||||
|
- Checkpoints saved to `checkpoints/`
|
||||||
|
- Best model saved as `checkpoints/best.pt`
|
||||||
|
- Final model exported to `models/logo_detection/clip_finetuned/`
|
||||||
|
|
||||||
|
## Configuration Options
|
||||||
|
|
||||||
|
Key parameters in `configs/jetson_orin.yaml`:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Model
|
||||||
|
base_model: "openai/clip-vit-large-patch14"
|
||||||
|
lora_r: 16 # LoRA rank (0 to disable)
|
||||||
|
lora_alpha: 32 # LoRA scaling factor
|
||||||
|
freeze_layers: 12 # Freeze first N transformer layers
|
||||||
|
|
||||||
|
# Batch construction
|
||||||
|
batch_size: 16
|
||||||
|
logos_per_batch: 32 # Different logos per batch
|
||||||
|
samples_per_logo: 4 # Samples per logo (creates positive pairs)
|
||||||
|
gradient_accumulation_steps: 8 # Effective batch = 128
|
||||||
|
|
||||||
|
# Training
|
||||||
|
learning_rate: 1.0e-5
|
||||||
|
max_epochs: 20
|
||||||
|
mixed_precision: true
|
||||||
|
temperature: 0.07 # InfoNCE temperature
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
patience: 5
|
||||||
|
min_delta: 0.001
|
||||||
|
```
|
||||||
|
|
||||||
|
## Evaluation
|
||||||
|
|
||||||
|
### Test Fine-Tuned Model
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python test_logo_detection.py -n 50 \
|
||||||
|
-e models/logo_detection/clip_finetuned \
|
||||||
|
--matching-method multi-ref \
|
||||||
|
--seed 42
|
||||||
|
```
|
||||||
|
|
||||||
|
### Compare with Baseline
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Baseline CLIP
|
||||||
|
uv run python test_logo_detection.py -n 50 \
|
||||||
|
-e openai/clip-vit-large-patch14 \
|
||||||
|
--matching-method multi-ref \
|
||||||
|
--seed 42
|
||||||
|
|
||||||
|
# Fine-tuned model
|
||||||
|
uv run python test_logo_detection.py -n 50 \
|
||||||
|
-e models/logo_detection/clip_finetuned \
|
||||||
|
--matching-method multi-ref \
|
||||||
|
--seed 42
|
||||||
|
```
|
||||||
|
|
||||||
|
### Expected Metrics
|
||||||
|
|
||||||
|
| Metric | Baseline CLIP | Target (Fine-tuned) |
|
||||||
|
|--------|---------------|---------------------|
|
||||||
|
| Precision | ~49% | >70% |
|
||||||
|
| Recall | ~77% | >75% |
|
||||||
|
| F1 Score | ~60% | >72% |
|
||||||
|
|
||||||
|
Training metrics to monitor:
|
||||||
|
- Mean positive similarity: target > 0.85
|
||||||
|
- Mean negative similarity: target < 0.50
|
||||||
|
- Embedding separation: target > 0.35
|
||||||
|
|
||||||
|
## Export Model
|
||||||
|
|
||||||
|
To export a checkpoint to HuggingFace format:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python export_model.py \
|
||||||
|
--checkpoint checkpoints/best.pt \
|
||||||
|
--output models/logo_detection/clip_finetuned
|
||||||
|
```
|
||||||
|
|
||||||
|
With LoRA weight merging (reduces inference overhead):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python export_model.py \
|
||||||
|
--checkpoint checkpoints/best.pt \
|
||||||
|
--output models/logo_detection/clip_finetuned \
|
||||||
|
--merge-lora
|
||||||
|
```
|
||||||
|
|
||||||
|
## Using Fine-Tuned Model with DetectLogosDETR
|
||||||
|
|
||||||
|
The fine-tuned model works as a drop-in replacement:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from logo_detection_detr import DetectLogosDETR
|
||||||
|
|
||||||
|
# Use fine-tuned model
|
||||||
|
detector = DetectLogosDETR(
|
||||||
|
logger=logger,
|
||||||
|
embedding_model="models/logo_detection/clip_finetuned",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Or use baseline for comparison
|
||||||
|
detector_baseline = DetectLogosDETR(
|
||||||
|
logger=logger,
|
||||||
|
embedding_model="openai/clip-vit-large-patch14",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Architecture Details
|
||||||
|
|
||||||
|
### Training Approach
|
||||||
|
|
||||||
|
1. **Contrastive Learning**: Uses InfoNCE loss to maximize similarity between embeddings of the same logo while minimizing similarity to different logos.
|
||||||
|
|
||||||
|
2. **LoRA (Low-Rank Adaptation)**: Adds small trainable matrices to attention layers instead of fine-tuning all weights. This is memory-efficient and prevents catastrophic forgetting.
|
||||||
|
|
||||||
|
3. **Layer Freezing**: Freezes the first 12 of 24 transformer layers to preserve CLIP's low-level visual features while adapting high-level semantics.
|
||||||
|
|
||||||
|
4. **Logo-Level Splits**: Splits data by logo brand (not by image) to test generalization to unseen logos.
|
||||||
|
|
||||||
|
### Batch Construction
|
||||||
|
|
||||||
|
Each batch contains:
|
||||||
|
- K different logo brands (default: 32)
|
||||||
|
- M samples per brand (default: 4)
|
||||||
|
- Total samples: K × M = 128
|
||||||
|
|
||||||
|
This ensures positive pairs (same logo) exist within each batch for contrastive learning.
|
||||||
|
|
||||||
|
### Data Augmentation
|
||||||
|
|
||||||
|
Medium strength augmentations:
|
||||||
|
- Random horizontal flip
|
||||||
|
- Random rotation (±15°)
|
||||||
|
- Color jitter (brightness, contrast, saturation)
|
||||||
|
- Random affine transforms
|
||||||
|
- Random grayscale (10% of images)
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Out of Memory
|
||||||
|
|
||||||
|
Reduce batch size and increase gradient accumulation:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
|
||||||
|
--batch-size 8 \
|
||||||
|
--gradient-accumulation-steps 16
|
||||||
|
```
|
||||||
|
|
||||||
|
### Slow Training
|
||||||
|
|
||||||
|
Ensure mixed precision is enabled:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python train_clip_logo.py --config configs/jetson_orin.yaml
|
||||||
|
# mixed_precision: true is default in jetson_orin.yaml
|
||||||
|
```
|
||||||
|
|
||||||
|
### No Improvement
|
||||||
|
|
||||||
|
Try adjusting:
|
||||||
|
- Lower learning rate: `--learning-rate 5e-6`
|
||||||
|
- Higher temperature: `--temperature 0.1`
|
||||||
|
- Different loss: edit config to use `loss_type: "combined"`
|
||||||
|
|
||||||
|
### Import Error for Fine-Tuned Model
|
||||||
|
|
||||||
|
Ensure the `training/` module is in your Python path:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export PYTHONPATH="${PYTHONPATH}:/data/dev.python/logo_test"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Dependencies Added
|
||||||
|
|
||||||
|
The following were added to `pyproject.toml`:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
peft>=0.7.0 # LoRA support
|
||||||
|
pyyaml>=6.0 # Config file parsing
|
||||||
|
torchvision>=0.20.0 # Image transforms
|
||||||
|
```
|
||||||
269
CLOUD_TRAINING.md
Normal file
269
CLOUD_TRAINING.md
Normal file
@ -0,0 +1,269 @@
|
|||||||
|
# Cloud GPU Training for CLIP Fine-Tuning
|
||||||
|
|
||||||
|
This document provides guidance on using cloud GPU instances (e.g., RunPod) for faster CLIP fine-tuning compared to local training on Jetson Orin AGX.
|
||||||
|
|
||||||
|
## Training Time Comparison
|
||||||
|
|
||||||
|
Local training on Jetson Orin AGX takes approximately 24 hours. Cloud GPUs offer significantly faster training:
|
||||||
|
|
||||||
|
| GPU | VRAM | Est. Training Time | Hourly Rate | Est. Total Cost |
|
||||||
|
|-----|------|-------------------|-------------|-----------------|
|
||||||
|
| **RTX 4090** | 24GB | 4-6 hours | $0.59/hr | **$2.40-$3.50** |
|
||||||
|
| **RTX 3090** | 24GB | 5-7 hours | $0.39/hr | **$2.00-$2.75** |
|
||||||
|
| **A100 80GB** | 80GB | 2-3 hours | $1.99/hr | **$4.00-$6.00** |
|
||||||
|
| **L40S** | 48GB | 3-4 hours | $0.89/hr | **$2.70-$3.60** |
|
||||||
|
| **H100 80GB** | 80GB | 1.5-2 hours | $1.99/hr | **$3.00-$4.00** |
|
||||||
|
|
||||||
|
*Prices from RunPod Community Cloud as of January 2025. Rates may vary.*
|
||||||
|
|
||||||
|
## Recommendations
|
||||||
|
|
||||||
|
### Best Value: RTX 4090 ($0.59/hr)
|
||||||
|
- 24GB VRAM is sufficient for ViT-L/14 with LoRA
|
||||||
|
- Good balance of speed and cost
|
||||||
|
- Widely available on Community Cloud
|
||||||
|
- **Total cost: ~$3 for complete training**
|
||||||
|
|
||||||
|
### Best Speed: H100 80GB ($1.99/hr)
|
||||||
|
- Fastest training (1.5-2 hours)
|
||||||
|
- 80GB VRAM allows larger batch sizes
|
||||||
|
- Can increase `batch_size` to 32+ and reduce `gradient_accumulation_steps`
|
||||||
|
- **Total cost: ~$3-4**
|
||||||
|
|
||||||
|
### Budget Option: RTX 3090 ($0.39/hr)
|
||||||
|
- Cheapest hourly rate
|
||||||
|
- 24GB VRAM works fine
|
||||||
|
- Slightly slower than 4090
|
||||||
|
- **Total cost: ~$2-3**
|
||||||
|
|
||||||
|
## Cloud-Optimized Configurations
|
||||||
|
|
||||||
|
### RTX 4090 / RTX 3090 (24GB VRAM)
|
||||||
|
|
||||||
|
Create `configs/cloud_rtx4090.yaml`:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Optimized for 24GB VRAM cloud GPUs
|
||||||
|
base_model: "openai/clip-vit-large-patch14"
|
||||||
|
|
||||||
|
# Dataset paths
|
||||||
|
dataset_dir: "LogoDet-3K"
|
||||||
|
reference_dir: "reference_logos"
|
||||||
|
db_path: "test_data_mapping.db"
|
||||||
|
|
||||||
|
# Data splits
|
||||||
|
train_split: 0.7
|
||||||
|
val_split: 0.15
|
||||||
|
test_split: 0.15
|
||||||
|
|
||||||
|
# Larger batches for faster training
|
||||||
|
batch_size: 32
|
||||||
|
logos_per_batch: 32
|
||||||
|
samples_per_logo: 4
|
||||||
|
gradient_accumulation_steps: 4 # Effective batch = 128
|
||||||
|
num_workers: 8
|
||||||
|
|
||||||
|
# Model architecture
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0.1
|
||||||
|
freeze_layers: 12
|
||||||
|
use_gradient_checkpointing: true
|
||||||
|
|
||||||
|
# Training
|
||||||
|
learning_rate: 1.0e-5
|
||||||
|
weight_decay: 0.01
|
||||||
|
warmup_steps: 500
|
||||||
|
max_epochs: 20
|
||||||
|
mixed_precision: true
|
||||||
|
|
||||||
|
# Loss
|
||||||
|
temperature: 0.07
|
||||||
|
loss_type: "infonce"
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
patience: 5
|
||||||
|
min_delta: 0.001
|
||||||
|
|
||||||
|
# Output
|
||||||
|
checkpoint_dir: "checkpoints"
|
||||||
|
output_dir: "models/logo_detection/clip_finetuned"
|
||||||
|
save_every_n_epochs: 5
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
log_every_n_steps: 10
|
||||||
|
eval_every_n_epochs: 1
|
||||||
|
|
||||||
|
seed: 42
|
||||||
|
use_augmentation: true
|
||||||
|
augmentation_strength: "medium"
|
||||||
|
```
|
||||||
|
|
||||||
|
### A100 / H100 (80GB VRAM)
|
||||||
|
|
||||||
|
Create `configs/cloud_a100.yaml`:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Optimized for 80GB VRAM cloud GPUs (A100, H100)
|
||||||
|
base_model: "openai/clip-vit-large-patch14"
|
||||||
|
|
||||||
|
# Dataset paths
|
||||||
|
dataset_dir: "LogoDet-3K"
|
||||||
|
reference_dir: "reference_logos"
|
||||||
|
db_path: "test_data_mapping.db"
|
||||||
|
|
||||||
|
# Data splits
|
||||||
|
train_split: 0.7
|
||||||
|
val_split: 0.15
|
||||||
|
test_split: 0.15
|
||||||
|
|
||||||
|
# Maximum batch sizes for 80GB VRAM
|
||||||
|
batch_size: 64
|
||||||
|
logos_per_batch: 32
|
||||||
|
samples_per_logo: 4
|
||||||
|
gradient_accumulation_steps: 2 # Effective batch = 128
|
||||||
|
num_workers: 8
|
||||||
|
|
||||||
|
# Model architecture (can disable gradient checkpointing with 80GB)
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0.1
|
||||||
|
freeze_layers: 12
|
||||||
|
use_gradient_checkpointing: false # Not needed with 80GB
|
||||||
|
|
||||||
|
# Training
|
||||||
|
learning_rate: 1.0e-5
|
||||||
|
weight_decay: 0.01
|
||||||
|
warmup_steps: 500
|
||||||
|
max_epochs: 20
|
||||||
|
mixed_precision: true
|
||||||
|
|
||||||
|
# Loss
|
||||||
|
temperature: 0.07
|
||||||
|
loss_type: "infonce"
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
patience: 5
|
||||||
|
min_delta: 0.001
|
||||||
|
|
||||||
|
# Output
|
||||||
|
checkpoint_dir: "checkpoints"
|
||||||
|
output_dir: "models/logo_detection/clip_finetuned"
|
||||||
|
save_every_n_epochs: 5
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
log_every_n_steps: 10
|
||||||
|
eval_every_n_epochs: 1
|
||||||
|
|
||||||
|
seed: 42
|
||||||
|
use_augmentation: true
|
||||||
|
augmentation_strength: "medium"
|
||||||
|
```
|
||||||
|
|
||||||
|
## RunPod Quick Start
|
||||||
|
|
||||||
|
### 1. Create a Pod
|
||||||
|
|
||||||
|
1. Go to [RunPod](https://www.runpod.io/)
|
||||||
|
2. Select GPU (RTX 4090 recommended)
|
||||||
|
3. Choose PyTorch template (CUDA 12.x)
|
||||||
|
4. Set volume size: 50GB (for dataset + models)
|
||||||
|
|
||||||
|
### 2. Setup Environment
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Connect via SSH or web terminal
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
pip install peft pyyaml torchvision transformers tqdm pillow
|
||||||
|
|
||||||
|
# Clone your repository (or upload files)
|
||||||
|
git clone <your-repo-url>
|
||||||
|
cd logo_test
|
||||||
|
|
||||||
|
# Or use runpodctl to sync files
|
||||||
|
# runpodctl send logo_test/
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Prepare Data
|
||||||
|
|
||||||
|
If data isn't already prepared:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# This creates reference_logos/ and test_data_mapping.db
|
||||||
|
python prepare_test_data.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Run Training
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# For RTX 4090
|
||||||
|
python train_clip_logo.py --config configs/cloud_rtx4090.yaml
|
||||||
|
|
||||||
|
# For A100/H100
|
||||||
|
python train_clip_logo.py --config configs/cloud_a100.yaml
|
||||||
|
|
||||||
|
# Or with command-line overrides
|
||||||
|
python train_clip_logo.py --config configs/jetson_orin.yaml \
|
||||||
|
--batch-size 32 \
|
||||||
|
--gradient-accumulation-steps 4 \
|
||||||
|
--num-workers 8
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. Download Results
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Export the trained model
|
||||||
|
python export_model.py \
|
||||||
|
--checkpoint checkpoints/best.pt \
|
||||||
|
--output models/logo_detection/clip_finetuned
|
||||||
|
|
||||||
|
# Download to local machine
|
||||||
|
# Option 1: Use runpodctl
|
||||||
|
runpodctl receive models/logo_detection/clip_finetuned
|
||||||
|
|
||||||
|
# Option 2: SCP
|
||||||
|
scp -r root@<pod-ip>:/workspace/logo_test/models/logo_detection/clip_finetuned ./
|
||||||
|
|
||||||
|
# Option 3: Compress and download via web
|
||||||
|
tar -czvf clip_finetuned.tar.gz models/logo_detection/clip_finetuned
|
||||||
|
```
|
||||||
|
|
||||||
|
## Cost Optimization Tips
|
||||||
|
|
||||||
|
### Use Spot/Interruptible Instances
|
||||||
|
- Community Cloud GPUs are already cheaper
|
||||||
|
- Some providers offer spot pricing for additional savings
|
||||||
|
- Save checkpoints frequently (`save_every_n_epochs: 2`)
|
||||||
|
|
||||||
|
### Minimize Storage Costs
|
||||||
|
- RunPod charges $0.10/GB/month for container disk
|
||||||
|
- Use network volumes only if needed
|
||||||
|
- Delete pods when training completes
|
||||||
|
|
||||||
|
### Monitor Training
|
||||||
|
- Watch for early convergence (may finish before 20 epochs)
|
||||||
|
- Early stopping will save time/cost if no improvement
|
||||||
|
|
||||||
|
### Batch Training Runs
|
||||||
|
- Test configuration locally first (1-2 epochs)
|
||||||
|
- Run full training on cloud only when config is validated
|
||||||
|
|
||||||
|
## Cost Comparison Summary
|
||||||
|
|
||||||
|
| Option | Time | Cost | Best For |
|
||||||
|
|--------|------|------|----------|
|
||||||
|
| Jetson Orin (local) | ~24 hrs | Free* | No cloud dependency |
|
||||||
|
| RTX 3090 (RunPod) | ~6 hrs | ~$2.50 | Lowest cost |
|
||||||
|
| RTX 4090 (RunPod) | ~5 hrs | ~$3.00 | Best value |
|
||||||
|
| L40S (RunPod) | ~3.5 hrs | ~$3.00 | Good balance |
|
||||||
|
| A100 80GB (RunPod) | ~2.5 hrs | ~$5.00 | Large batches |
|
||||||
|
| H100 80GB (RunPod) | ~1.5 hrs | ~$3.50 | Fastest |
|
||||||
|
|
||||||
|
*Local training has electricity cost but no cloud fees.
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
- [RunPod Pricing](https://www.runpod.io/pricing)
|
||||||
|
- [RunPod RTX 4090](https://www.runpod.io/gpu-models/rtx-4090)
|
||||||
|
- [RunPod Documentation](https://docs.runpod.io/)
|
||||||
64
configs/cloud_a100.yaml
Normal file
64
configs/cloud_a100.yaml
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
# Training configuration optimized for cloud A100 / H100 (80GB VRAM)
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# python train_clip_logo.py --config configs/cloud_a100.yaml
|
||||||
|
#
|
||||||
|
# Estimated training time: 1.5-3 hours
|
||||||
|
# Estimated cost on RunPod: ~$3-6
|
||||||
|
|
||||||
|
# Base model
|
||||||
|
base_model: "openai/clip-vit-large-patch14"
|
||||||
|
|
||||||
|
# Dataset paths
|
||||||
|
dataset_dir: "LogoDet-3K"
|
||||||
|
reference_dir: "reference_logos"
|
||||||
|
db_path: "test_data_mapping.db"
|
||||||
|
|
||||||
|
# Data splits
|
||||||
|
train_split: 0.7
|
||||||
|
val_split: 0.15
|
||||||
|
test_split: 0.15
|
||||||
|
|
||||||
|
# Maximum batch sizes for 80GB VRAM
|
||||||
|
batch_size: 64
|
||||||
|
logos_per_batch: 32
|
||||||
|
samples_per_logo: 4
|
||||||
|
gradient_accumulation_steps: 2 # Effective batch = 128
|
||||||
|
num_workers: 8
|
||||||
|
|
||||||
|
# Model architecture (no gradient checkpointing needed with 80GB)
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0.1
|
||||||
|
freeze_layers: 12
|
||||||
|
use_gradient_checkpointing: false
|
||||||
|
|
||||||
|
# Training
|
||||||
|
learning_rate: 1.0e-5
|
||||||
|
weight_decay: 0.01
|
||||||
|
warmup_steps: 500
|
||||||
|
max_epochs: 20
|
||||||
|
mixed_precision: true
|
||||||
|
|
||||||
|
# Loss
|
||||||
|
temperature: 0.07
|
||||||
|
loss_type: "infonce"
|
||||||
|
triplet_margin: 0.3
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
patience: 5
|
||||||
|
min_delta: 0.001
|
||||||
|
|
||||||
|
# Output
|
||||||
|
checkpoint_dir: "checkpoints"
|
||||||
|
output_dir: "models/logo_detection/clip_finetuned"
|
||||||
|
save_every_n_epochs: 2 # Save more frequently for cloud
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
log_every_n_steps: 10
|
||||||
|
eval_every_n_epochs: 1
|
||||||
|
|
||||||
|
seed: 42
|
||||||
|
use_hard_negatives: false
|
||||||
|
use_augmentation: true
|
||||||
|
augmentation_strength: "medium"
|
||||||
64
configs/cloud_rtx4090.yaml
Normal file
64
configs/cloud_rtx4090.yaml
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
# Training configuration optimized for cloud RTX 4090 / RTX 3090 (24GB VRAM)
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# python train_clip_logo.py --config configs/cloud_rtx4090.yaml
|
||||||
|
#
|
||||||
|
# Estimated training time: 4-6 hours
|
||||||
|
# Estimated cost on RunPod: ~$3
|
||||||
|
|
||||||
|
# Base model
|
||||||
|
base_model: "openai/clip-vit-large-patch14"
|
||||||
|
|
||||||
|
# Dataset paths
|
||||||
|
dataset_dir: "LogoDet-3K"
|
||||||
|
reference_dir: "reference_logos"
|
||||||
|
db_path: "test_data_mapping.db"
|
||||||
|
|
||||||
|
# Data splits
|
||||||
|
train_split: 0.7
|
||||||
|
val_split: 0.15
|
||||||
|
test_split: 0.15
|
||||||
|
|
||||||
|
# Larger batches for faster training on 24GB VRAM
|
||||||
|
batch_size: 32
|
||||||
|
logos_per_batch: 32
|
||||||
|
samples_per_logo: 4
|
||||||
|
gradient_accumulation_steps: 4 # Effective batch = 128
|
||||||
|
num_workers: 8
|
||||||
|
|
||||||
|
# Model architecture
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0.1
|
||||||
|
freeze_layers: 12
|
||||||
|
use_gradient_checkpointing: true
|
||||||
|
|
||||||
|
# Training
|
||||||
|
learning_rate: 1.0e-5
|
||||||
|
weight_decay: 0.01
|
||||||
|
warmup_steps: 500
|
||||||
|
max_epochs: 20
|
||||||
|
mixed_precision: true
|
||||||
|
|
||||||
|
# Loss
|
||||||
|
temperature: 0.07
|
||||||
|
loss_type: "infonce"
|
||||||
|
triplet_margin: 0.3
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
patience: 5
|
||||||
|
min_delta: 0.001
|
||||||
|
|
||||||
|
# Output
|
||||||
|
checkpoint_dir: "checkpoints"
|
||||||
|
output_dir: "models/logo_detection/clip_finetuned"
|
||||||
|
save_every_n_epochs: 2 # Save more frequently for cloud
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
log_every_n_steps: 10
|
||||||
|
eval_every_n_epochs: 1
|
||||||
|
|
||||||
|
seed: 42
|
||||||
|
use_hard_negatives: false
|
||||||
|
use_augmentation: true
|
||||||
|
augmentation_strength: "medium"
|
||||||
76
configs/jetson_orin.yaml
Normal file
76
configs/jetson_orin.yaml
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
# Training configuration optimized for Jetson Orin AGX (~64GB shared memory)
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# uv run python train_clip_logo.py --config configs/jetson_orin.yaml
|
||||||
|
|
||||||
|
# Base model
|
||||||
|
base_model: "openai/clip-vit-large-patch14"
|
||||||
|
|
||||||
|
# Dataset paths (relative to project root)
|
||||||
|
dataset_dir: "LogoDet-3K"
|
||||||
|
reference_dir: "reference_logos"
|
||||||
|
db_path: "test_data_mapping.db"
|
||||||
|
|
||||||
|
# Data split ratios (logo-level split for generalization testing)
|
||||||
|
train_split: 0.7
|
||||||
|
val_split: 0.15
|
||||||
|
test_split: 0.15
|
||||||
|
|
||||||
|
# Batch construction
|
||||||
|
# - batch_size: Number of batches loaded at once (keep low for memory)
|
||||||
|
# - logos_per_batch: Different logo classes per contrastive batch
|
||||||
|
# - samples_per_logo: Samples of each logo (creates positive pairs)
|
||||||
|
# - Effective samples per step = logos_per_batch * samples_per_logo = 128
|
||||||
|
batch_size: 16
|
||||||
|
logos_per_batch: 32
|
||||||
|
samples_per_logo: 4
|
||||||
|
gradient_accumulation_steps: 8 # Effective batch = 128
|
||||||
|
num_workers: 4
|
||||||
|
|
||||||
|
# Model architecture
|
||||||
|
# LoRA enables memory-efficient fine-tuning by training low-rank adapters
|
||||||
|
# instead of full model weights
|
||||||
|
lora_r: 16 # LoRA rank (0 to disable)
|
||||||
|
lora_alpha: 32 # LoRA scaling factor
|
||||||
|
lora_dropout: 0.1 # Dropout in LoRA layers
|
||||||
|
freeze_layers: 12 # Freeze first 12 of 24 transformer layers
|
||||||
|
use_gradient_checkpointing: true # Trade compute for memory
|
||||||
|
|
||||||
|
# Training hyperparameters
|
||||||
|
learning_rate: 1.0e-5 # Conservative LR for fine-tuning
|
||||||
|
weight_decay: 0.01 # L2 regularization
|
||||||
|
warmup_steps: 500 # LR warmup steps
|
||||||
|
max_epochs: 20 # Maximum training epochs
|
||||||
|
mixed_precision: true # FP16 training for memory efficiency
|
||||||
|
|
||||||
|
# Loss function
|
||||||
|
# InfoNCE is the contrastive loss used in CLIP training
|
||||||
|
temperature: 0.07 # Similarity scaling (0.05-0.1 typical)
|
||||||
|
loss_type: "infonce" # Options: infonce, supcon, triplet, combined
|
||||||
|
triplet_margin: 0.3 # Only used if loss_type is triplet
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
patience: 5 # Stop if no improvement for N epochs
|
||||||
|
min_delta: 0.001 # Minimum improvement threshold
|
||||||
|
|
||||||
|
# Checkpoints and output
|
||||||
|
checkpoint_dir: "checkpoints"
|
||||||
|
output_dir: "models/logo_detection/clip_finetuned"
|
||||||
|
save_every_n_epochs: 5
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
log_every_n_steps: 10
|
||||||
|
eval_every_n_epochs: 1
|
||||||
|
|
||||||
|
# Reproducibility
|
||||||
|
seed: 42
|
||||||
|
|
||||||
|
# Hard negative mining (advanced)
|
||||||
|
# Enable after initial training epochs for harder examples
|
||||||
|
use_hard_negatives: false
|
||||||
|
hard_negative_start_epoch: 5
|
||||||
|
hard_negatives_per_logo: 10
|
||||||
|
|
||||||
|
# Data augmentation
|
||||||
|
use_augmentation: true
|
||||||
|
augmentation_strength: "medium" # light, medium, or strong
|
||||||
169
export_model.py
Normal file
169
export_model.py
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Export a trained CLIP model to HuggingFace-compatible format.
|
||||||
|
|
||||||
|
This script converts a training checkpoint to a format that can be
|
||||||
|
loaded by DetectLogosDETR for inference.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
uv run python export_model.py \
|
||||||
|
--checkpoint checkpoints/best.pt \
|
||||||
|
--output models/logo_detection/clip_finetuned
|
||||||
|
|
||||||
|
# With custom base model
|
||||||
|
uv run python export_model.py \
|
||||||
|
--checkpoint checkpoints/best.pt \
|
||||||
|
--output models/logo_detection/clip_finetuned \
|
||||||
|
--base-model openai/clip-vit-large-patch14
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from training.config import TrainingConfig
|
||||||
|
from training.model import create_model, LogoFineTunedCLIP
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging() -> logging.Logger:
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
return logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Export trained CLIP model for inference",
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Path to training checkpoint (.pt file)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output",
|
||||||
|
type=str,
|
||||||
|
required=True,
|
||||||
|
help="Output directory for exported model",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--base-model",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Base CLIP model (reads from checkpoint config if not specified)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--merge-lora",
|
||||||
|
action="store_true",
|
||||||
|
help="Merge LoRA weights into base model (reduces inference overhead)",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
logger = setup_logging()
|
||||||
|
|
||||||
|
logger.info("CLIP Model Export")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
# Check checkpoint exists
|
||||||
|
checkpoint_path = Path(args.checkpoint)
|
||||||
|
if not checkpoint_path.exists():
|
||||||
|
logger.error(f"Checkpoint not found: {checkpoint_path}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Load checkpoint
|
||||||
|
logger.info(f"Loading checkpoint: {checkpoint_path}")
|
||||||
|
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||||
|
|
||||||
|
# Get config from checkpoint
|
||||||
|
if "config" in checkpoint:
|
||||||
|
config_dict = checkpoint["config"]
|
||||||
|
base_model = args.base_model or config_dict.get(
|
||||||
|
"base_model", "openai/clip-vit-large-patch14"
|
||||||
|
)
|
||||||
|
lora_r = config_dict.get("lora_r", 16)
|
||||||
|
lora_alpha = config_dict.get("lora_alpha", 32)
|
||||||
|
freeze_layers = config_dict.get("freeze_layers", 12)
|
||||||
|
else:
|
||||||
|
base_model = args.base_model or "openai/clip-vit-large-patch14"
|
||||||
|
lora_r = 16
|
||||||
|
lora_alpha = 32
|
||||||
|
freeze_layers = 12
|
||||||
|
|
||||||
|
logger.info(f"Base model: {base_model}")
|
||||||
|
logger.info(f"LoRA rank: {lora_r}")
|
||||||
|
logger.info(f"Freeze layers: {freeze_layers}")
|
||||||
|
|
||||||
|
# Create model with same architecture
|
||||||
|
logger.info("Creating model architecture...")
|
||||||
|
model, processor = create_model(
|
||||||
|
base_model=base_model,
|
||||||
|
lora_r=lora_r,
|
||||||
|
lora_alpha=lora_alpha,
|
||||||
|
freeze_layers=freeze_layers,
|
||||||
|
use_gradient_checkpointing=False, # Not needed for export
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load weights
|
||||||
|
logger.info("Loading trained weights...")
|
||||||
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
|
||||||
|
# Merge LoRA if requested
|
||||||
|
if args.merge_lora and model.peft_applied:
|
||||||
|
try:
|
||||||
|
logger.info("Merging LoRA weights into base model...")
|
||||||
|
model.vision_model = model.vision_model.merge_and_unload()
|
||||||
|
model.peft_applied = False
|
||||||
|
model.lora_r = 0
|
||||||
|
logger.info("LoRA weights merged successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not merge LoRA weights: {e}")
|
||||||
|
logger.warning("Exporting with separate LoRA weights")
|
||||||
|
|
||||||
|
# Create output directory
|
||||||
|
output_path = Path(args.output)
|
||||||
|
output_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Save model
|
||||||
|
logger.info(f"Exporting to: {output_path}")
|
||||||
|
model.save_pretrained(str(output_path))
|
||||||
|
|
||||||
|
# Save processor config for reference
|
||||||
|
processor.save_pretrained(str(output_path / "processor"))
|
||||||
|
|
||||||
|
# Save additional metadata
|
||||||
|
metadata = {
|
||||||
|
"base_model": base_model,
|
||||||
|
"source_checkpoint": str(checkpoint_path),
|
||||||
|
"training_epochs": checkpoint.get("epoch", -1) + 1,
|
||||||
|
"best_val_loss": checkpoint.get("best_val_loss", None),
|
||||||
|
"best_val_separation": checkpoint.get("best_val_separation", None),
|
||||||
|
"lora_merged": args.merge_lora and not model.peft_applied,
|
||||||
|
}
|
||||||
|
|
||||||
|
with open(output_path / "export_metadata.json", "w") as f:
|
||||||
|
json.dump(metadata, f, indent=2)
|
||||||
|
|
||||||
|
logger.info("\nExport complete!")
|
||||||
|
logger.info(f"Model saved to: {output_path}")
|
||||||
|
logger.info("\nTo use with DetectLogosDETR:")
|
||||||
|
logger.info(f" detector = DetectLogosDETR(embedding_model='{output_path}')")
|
||||||
|
logger.info("\nOr with test_logo_detection.py:")
|
||||||
|
logger.info(f" uv run python test_logo_detection.py -e {output_path}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@ -13,6 +13,7 @@ Supported embedding models:
|
|||||||
- DINOv2 models (facebook/dinov2-*): Self-supervised, excellent for visual similarity
|
- DINOv2 models (facebook/dinov2-*): Self-supervised, excellent for visual similarity
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@ -100,16 +101,20 @@ class DetectLogosDETR:
|
|||||||
embedding_model, default_embedding_dir, "Embedding"
|
embedding_model, default_embedding_dir, "Embedding"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Detect model type and initialize accordingly
|
# Check if this is a fine-tuned model
|
||||||
self.model_type = self._detect_model_type(embedding_model)
|
if self._is_finetuned_model(embedding_model_path):
|
||||||
self.logger.info(f"Loading {self.model_type} embedding model: {embedding_model_path}")
|
self._load_finetuned_embedding_model(embedding_model_path)
|
||||||
|
else:
|
||||||
|
# Detect model type and initialize accordingly
|
||||||
|
self.model_type = self._detect_model_type(embedding_model)
|
||||||
|
self.logger.info(f"Loading {self.model_type} embedding model: {embedding_model_path}")
|
||||||
|
|
||||||
if self.model_type == "clip":
|
if self.model_type == "clip":
|
||||||
self.embedding_model = CLIPModel.from_pretrained(embedding_model_path).to(self.device)
|
self.embedding_model = CLIPModel.from_pretrained(embedding_model_path).to(self.device)
|
||||||
self.embedding_processor = CLIPProcessor.from_pretrained(embedding_model_path)
|
self.embedding_processor = CLIPProcessor.from_pretrained(embedding_model_path)
|
||||||
else: # dinov2 or other transformer models
|
else: # dinov2 or other transformer models
|
||||||
self.embedding_model = AutoModel.from_pretrained(embedding_model_path).to(self.device)
|
self.embedding_model = AutoModel.from_pretrained(embedding_model_path).to(self.device)
|
||||||
self.embedding_processor = AutoImageProcessor.from_pretrained(embedding_model_path)
|
self.embedding_processor = AutoImageProcessor.from_pretrained(embedding_model_path)
|
||||||
|
|
||||||
self.logger.info("DetectLogosDETR initialization complete")
|
self.logger.info("DetectLogosDETR initialization complete")
|
||||||
|
|
||||||
@ -124,6 +129,62 @@ class DetectLogosDETR:
|
|||||||
# Default to generic transformer for unknown models
|
# Default to generic transformer for unknown models
|
||||||
return "transformer"
|
return "transformer"
|
||||||
|
|
||||||
|
def _is_finetuned_model(self, model_path: str) -> bool:
|
||||||
|
"""Check if a model path points to a fine-tuned CLIP model."""
|
||||||
|
config_path = Path(model_path) / "config.json"
|
||||||
|
if config_path.exists():
|
||||||
|
try:
|
||||||
|
with open(config_path, "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
return config.get("model_type") == "clip_logo_finetuned"
|
||||||
|
except (json.JSONDecodeError, IOError):
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _load_finetuned_embedding_model(self, model_path: str) -> None:
|
||||||
|
"""
|
||||||
|
Load a fine-tuned CLIP model from the training module.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to the fine-tuned model directory
|
||||||
|
"""
|
||||||
|
# Import the fine-tuned model class
|
||||||
|
try:
|
||||||
|
from training.model import LogoFineTunedCLIP
|
||||||
|
except ImportError as e:
|
||||||
|
self.logger.error(
|
||||||
|
f"Cannot import training.model for fine-tuned model: {e}"
|
||||||
|
)
|
||||||
|
raise ImportError(
|
||||||
|
"Fine-tuned model requires the training module. "
|
||||||
|
"Ensure the training/ directory is in your Python path."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# Load config
|
||||||
|
config_path = Path(model_path) / "config.json"
|
||||||
|
with open(config_path, "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
base_model = config.get("base_model", "openai/clip-vit-large-patch14")
|
||||||
|
|
||||||
|
self.logger.info(f"Loading fine-tuned CLIP model from: {model_path}")
|
||||||
|
self.logger.info(f" Base model: {base_model}")
|
||||||
|
|
||||||
|
# Load model using the from_pretrained method
|
||||||
|
self.embedding_model = LogoFineTunedCLIP.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
base_model=base_model,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
self.embedding_model.eval()
|
||||||
|
|
||||||
|
# Load processor from base model
|
||||||
|
self.embedding_processor = CLIPProcessor.from_pretrained(base_model)
|
||||||
|
|
||||||
|
# Set model type for embedding extraction
|
||||||
|
self.model_type = "clip_finetuned"
|
||||||
|
self.logger.info("Fine-tuned CLIP model loaded successfully")
|
||||||
|
|
||||||
def _resolve_model_path(
|
def _resolve_model_path(
|
||||||
self, model_name_or_path: str, default_local_dir: str, model_type: str
|
self, model_name_or_path: str, default_local_dir: str, model_type: str
|
||||||
) -> str:
|
) -> str:
|
||||||
@ -345,7 +406,7 @@ class DetectLogosDETR:
|
|||||||
"""
|
"""
|
||||||
Internal method to get embedding from PIL image.
|
Internal method to get embedding from PIL image.
|
||||||
|
|
||||||
Handles both CLIP and DINOv2 model types.
|
Handles CLIP, fine-tuned CLIP, and DINOv2 model types.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pil_image: PIL Image (RGB format)
|
pil_image: PIL Image (RGB format)
|
||||||
@ -360,6 +421,9 @@ class DetectLogosDETR:
|
|||||||
if self.model_type == "clip":
|
if self.model_type == "clip":
|
||||||
# CLIP has a dedicated method for image features
|
# CLIP has a dedicated method for image features
|
||||||
features = self.embedding_model.get_image_features(**inputs)
|
features = self.embedding_model.get_image_features(**inputs)
|
||||||
|
elif self.model_type == "clip_finetuned":
|
||||||
|
# Fine-tuned CLIP uses get_image_features or forward with pixel_values
|
||||||
|
features = self.embedding_model.get_image_features(**inputs)
|
||||||
else:
|
else:
|
||||||
# DINOv2 and other transformers use the CLS token or pooled output
|
# DINOv2 and other transformers use the CLS token or pooled output
|
||||||
outputs = self.embedding_model(**inputs)
|
outputs = self.embedding_model(**inputs)
|
||||||
@ -370,8 +434,9 @@ class DetectLogosDETR:
|
|||||||
# Use CLS token from last_hidden_state
|
# Use CLS token from last_hidden_state
|
||||||
features = outputs.last_hidden_state[:, 0, :]
|
features = outputs.last_hidden_state[:, 0, :]
|
||||||
|
|
||||||
# Normalize for cosine similarity
|
# Normalize for cosine similarity (fine-tuned model already normalizes)
|
||||||
features = F.normalize(features, dim=-1)
|
if self.model_type != "clip_finetuned":
|
||||||
|
features = F.normalize(features, dim=-1)
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|||||||
@ -12,4 +12,7 @@ dependencies = [
|
|||||||
"tqdm>=4.67.1",
|
"tqdm>=4.67.1",
|
||||||
"transformers>=4.57.3",
|
"transformers>=4.57.3",
|
||||||
"typing>=3.10.0.0",
|
"typing>=3.10.0.0",
|
||||||
|
"peft>=0.7.0",
|
||||||
|
"pyyaml>=6.0",
|
||||||
|
"torchvision>=0.20.0",
|
||||||
]
|
]
|
||||||
|
|||||||
309
train_clip_logo.py
Normal file
309
train_clip_logo.py
Normal file
@ -0,0 +1,309 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Fine-tune CLIP vision encoder for logo recognition.
|
||||||
|
|
||||||
|
This script trains a CLIP model using contrastive learning on the LogoDet-3K
|
||||||
|
dataset to improve logo embedding quality for similarity-based matching.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Train with YAML config
|
||||||
|
uv run python train_clip_logo.py --config configs/jetson_orin.yaml
|
||||||
|
|
||||||
|
# Train with command-line overrides
|
||||||
|
uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
|
||||||
|
--learning-rate 5e-6 --max-epochs 30
|
||||||
|
|
||||||
|
# Resume from checkpoint
|
||||||
|
uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
|
||||||
|
--resume checkpoints/epoch_10.pt
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from training.config import TrainingConfig
|
||||||
|
from training.dataset import create_dataloaders
|
||||||
|
from training.model import create_model
|
||||||
|
from training.trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(log_level: str = "INFO") -> logging.Logger:
|
||||||
|
"""Configure logging."""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=getattr(logging, log_level.upper()),
|
||||||
|
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||||
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
|
)
|
||||||
|
return logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def set_seed(seed: int) -> None:
|
||||||
|
"""Set random seeds for reproducibility."""
|
||||||
|
random.seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
"""Parse command-line arguments."""
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Fine-tune CLIP for logo recognition",
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Config file
|
||||||
|
parser.add_argument(
|
||||||
|
"--config",
|
||||||
|
type=str,
|
||||||
|
help="Path to YAML configuration file",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dataset paths
|
||||||
|
parser.add_argument(
|
||||||
|
"--dataset-dir",
|
||||||
|
type=str,
|
||||||
|
help="Path to LogoDet-3K dataset",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--reference-dir",
|
||||||
|
type=str,
|
||||||
|
help="Path to reference logos directory",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--db-path",
|
||||||
|
type=str,
|
||||||
|
help="Path to SQLite database",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Model
|
||||||
|
parser.add_argument(
|
||||||
|
"--base-model",
|
||||||
|
type=str,
|
||||||
|
help="Base CLIP model name or path",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora-r",
|
||||||
|
type=int,
|
||||||
|
help="LoRA rank (0 to disable)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--freeze-layers",
|
||||||
|
type=int,
|
||||||
|
help="Number of transformer layers to freeze",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training
|
||||||
|
parser.add_argument(
|
||||||
|
"--batch-size",
|
||||||
|
type=int,
|
||||||
|
help="Batch size",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--learning-rate",
|
||||||
|
type=float,
|
||||||
|
help="Learning rate",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-epochs",
|
||||||
|
type=int,
|
||||||
|
help="Maximum number of epochs",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gradient-accumulation-steps",
|
||||||
|
type=int,
|
||||||
|
help="Gradient accumulation steps",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Loss
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature",
|
||||||
|
type=float,
|
||||||
|
help="Temperature for InfoNCE loss",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--loss-type",
|
||||||
|
choices=["infonce", "supcon", "triplet", "combined"],
|
||||||
|
help="Loss function type",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Checkpointing
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint-dir",
|
||||||
|
type=str,
|
||||||
|
help="Directory for checkpoints",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-dir",
|
||||||
|
type=str,
|
||||||
|
help="Directory for final model output",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--resume",
|
||||||
|
type=str,
|
||||||
|
help="Path to checkpoint to resume from",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Other
|
||||||
|
parser.add_argument(
|
||||||
|
"--seed",
|
||||||
|
type=int,
|
||||||
|
help="Random seed",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--log-level",
|
||||||
|
type=str,
|
||||||
|
default="INFO",
|
||||||
|
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||||
|
help="Logging level",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-mixed-precision",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable mixed precision training",
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main training entry point."""
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logger = setup_logging(args.log_level)
|
||||||
|
logger.info("CLIP Logo Fine-Tuning")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
# Load or create configuration
|
||||||
|
if args.config:
|
||||||
|
logger.info(f"Loading config from: {args.config}")
|
||||||
|
config = TrainingConfig.from_yaml(args.config)
|
||||||
|
else:
|
||||||
|
logger.info("Using default configuration")
|
||||||
|
config = TrainingConfig()
|
||||||
|
|
||||||
|
# Apply command-line overrides
|
||||||
|
override_fields = [
|
||||||
|
"dataset_dir", "reference_dir", "db_path", "base_model",
|
||||||
|
"lora_r", "freeze_layers", "batch_size", "learning_rate",
|
||||||
|
"max_epochs", "gradient_accumulation_steps", "temperature",
|
||||||
|
"loss_type", "checkpoint_dir", "output_dir", "seed",
|
||||||
|
]
|
||||||
|
for field in override_fields:
|
||||||
|
arg_name = field.replace("_", "-")
|
||||||
|
arg_value = getattr(args, field.replace("-", "_"), None)
|
||||||
|
if arg_value is not None:
|
||||||
|
setattr(config, field, arg_value)
|
||||||
|
logger.info(f"Override: {field} = {arg_value}")
|
||||||
|
|
||||||
|
if args.no_mixed_precision:
|
||||||
|
config.mixed_precision = False
|
||||||
|
logger.info("Override: mixed_precision = False")
|
||||||
|
|
||||||
|
# Validate configuration
|
||||||
|
warnings = config.validate()
|
||||||
|
for warning in warnings:
|
||||||
|
logger.warning(f"Config warning: {warning}")
|
||||||
|
|
||||||
|
# Set random seed
|
||||||
|
set_seed(config.seed)
|
||||||
|
logger.info(f"Random seed: {config.seed}")
|
||||||
|
|
||||||
|
# Check paths exist
|
||||||
|
db_path = Path(config.db_path)
|
||||||
|
ref_dir = Path(config.reference_dir)
|
||||||
|
|
||||||
|
if not db_path.exists():
|
||||||
|
logger.error(f"Database not found: {db_path}")
|
||||||
|
logger.error("Run prepare_test_data.py first to create the database.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if not ref_dir.exists():
|
||||||
|
logger.error(f"Reference directory not found: {ref_dir}")
|
||||||
|
logger.error("Run prepare_test_data.py first to extract reference logos.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Create model
|
||||||
|
logger.info(f"Creating model from: {config.base_model}")
|
||||||
|
model, processor = create_model(
|
||||||
|
base_model=config.base_model,
|
||||||
|
lora_r=config.lora_r,
|
||||||
|
lora_alpha=config.lora_alpha,
|
||||||
|
lora_dropout=config.lora_dropout,
|
||||||
|
freeze_layers=config.freeze_layers,
|
||||||
|
use_gradient_checkpointing=config.use_gradient_checkpointing,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create dataloaders
|
||||||
|
logger.info("Creating dataloaders...")
|
||||||
|
train_loader, val_loader, test_loader = create_dataloaders(
|
||||||
|
db_path=str(config.db_path),
|
||||||
|
reference_dir=str(config.reference_dir),
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
logos_per_batch=config.logos_per_batch,
|
||||||
|
samples_per_logo=config.samples_per_logo,
|
||||||
|
num_workers=config.num_workers,
|
||||||
|
train_split=config.train_split,
|
||||||
|
val_split=config.val_split,
|
||||||
|
test_split=config.test_split,
|
||||||
|
seed=config.seed,
|
||||||
|
augmentation_strength=config.augmentation_strength,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create trainer
|
||||||
|
trainer = Trainer(
|
||||||
|
model=model,
|
||||||
|
train_loader=train_loader,
|
||||||
|
val_loader=val_loader,
|
||||||
|
config=config,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resume from checkpoint if specified
|
||||||
|
if args.resume:
|
||||||
|
resume_path = Path(args.resume)
|
||||||
|
if resume_path.exists():
|
||||||
|
logger.info(f"Resuming from: {resume_path}")
|
||||||
|
# Set checkpoint dir to resume path's parent
|
||||||
|
if resume_path.is_file():
|
||||||
|
config.checkpoint_dir = str(resume_path.parent)
|
||||||
|
trainer.load_checkpoint(resume_path.name)
|
||||||
|
else:
|
||||||
|
logger.warning(f"Resume checkpoint not found: {resume_path}")
|
||||||
|
|
||||||
|
# Train
|
||||||
|
logger.info("\nStarting training...")
|
||||||
|
final_metrics = trainer.train()
|
||||||
|
|
||||||
|
logger.info("\nTraining complete!")
|
||||||
|
logger.info(f" Best val loss: {final_metrics['best_val_loss']:.4f}")
|
||||||
|
logger.info(f" Best separation: {final_metrics['best_val_separation']:.4f}")
|
||||||
|
logger.info(f" Total epochs: {final_metrics['total_epochs']}")
|
||||||
|
logger.info(f" Total time: {final_metrics['total_time_minutes']:.1f} minutes")
|
||||||
|
|
||||||
|
# Export model
|
||||||
|
output_path = trainer.export_model()
|
||||||
|
logger.info(f"\nModel exported to: {output_path}")
|
||||||
|
|
||||||
|
# Print next steps
|
||||||
|
logger.info("\n" + "=" * 60)
|
||||||
|
logger.info("Next steps:")
|
||||||
|
logger.info(f"1. Test the fine-tuned model:")
|
||||||
|
logger.info(f" uv run python test_logo_detection.py -n 50 \\")
|
||||||
|
logger.info(f" -e {output_path} --matching-method multi-ref")
|
||||||
|
logger.info(f"")
|
||||||
|
logger.info(f"2. Compare with baseline:")
|
||||||
|
logger.info(f" uv run python test_logo_detection.py -n 50 \\")
|
||||||
|
logger.info(f" -e openai/clip-vit-large-patch14 --matching-method multi-ref")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
24
training/__init__.py
Normal file
24
training/__init__.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
"""
|
||||||
|
CLIP fine-tuning module for logo recognition.
|
||||||
|
|
||||||
|
This module provides tools for fine-tuning CLIP's vision encoder using
|
||||||
|
contrastive learning on the LogoDet-3K dataset.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from .config import TrainingConfig
|
||||||
|
from .dataset import LogoContrastiveDataset, create_dataloaders
|
||||||
|
from .model import LogoFineTunedCLIP
|
||||||
|
from .losses import InfoNCELoss, TripletLoss
|
||||||
|
from .trainer import Trainer
|
||||||
|
from .evaluation import EmbeddingEvaluator
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TrainingConfig",
|
||||||
|
"LogoContrastiveDataset",
|
||||||
|
"create_dataloaders",
|
||||||
|
"LogoFineTunedCLIP",
|
||||||
|
"InfoNCELoss",
|
||||||
|
"TripletLoss",
|
||||||
|
"Trainer",
|
||||||
|
"EmbeddingEvaluator",
|
||||||
|
]
|
||||||
141
training/config.py
Normal file
141
training/config.py
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
"""
|
||||||
|
Training configuration for CLIP fine-tuning.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional

import yaml


@dataclass
class TrainingConfig:
    """Configuration for CLIP logo fine-tuning."""

    # Base model
    base_model: str = "openai/clip-vit-large-patch14"

    # Dataset paths
    dataset_dir: str = "LogoDet-3K"
    reference_dir: str = "reference_logos"
    db_path: str = "test_data_mapping.db"

    # Data split ratios
    train_split: float = 0.7
    val_split: float = 0.15
    test_split: float = 0.15

    # Batch construction
    batch_size: int = 16
    logos_per_batch: int = 32
    samples_per_logo: int = 4
    gradient_accumulation_steps: int = 8
    num_workers: int = 4

    # Model architecture
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.1
    freeze_layers: int = 12
    use_gradient_checkpointing: bool = True

    # Training hyperparameters
    learning_rate: float = 1e-5
    weight_decay: float = 0.01
    warmup_steps: int = 500
    max_epochs: int = 20
    mixed_precision: bool = True

    # Loss function
    temperature: float = 0.07
    loss_type: str = "infonce"  # "infonce" or "triplet"
    triplet_margin: float = 0.3

    # Early stopping
    patience: int = 5
    min_delta: float = 0.001

    # Checkpoints and output
    checkpoint_dir: str = "checkpoints"
    output_dir: str = "models/logo_detection/clip_finetuned"
    save_every_n_epochs: int = 5

    # Logging
    log_every_n_steps: int = 10
    eval_every_n_epochs: int = 1

    # Random seed for reproducibility
    seed: int = 42

    # Hard negative mining
    use_hard_negatives: bool = False
    hard_negative_start_epoch: int = 5
    hard_negatives_per_logo: int = 10

    # Data augmentation
    use_augmentation: bool = True
    augmentation_strength: str = "medium"  # "light", "medium", "strong"

    @classmethod
    def from_yaml(cls, yaml_path: str) -> "TrainingConfig":
        """Load configuration from YAML file."""
        with open(yaml_path, "r") as f:
            config_dict = yaml.safe_load(f)
        return cls(**config_dict)

    def to_yaml(self, yaml_path: str) -> None:
        """Save configuration to YAML file."""
        Path(yaml_path).parent.mkdir(parents=True, exist_ok=True)
        with open(yaml_path, "w") as f:
            yaml.dump(self.__dict__, f, default_flow_style=False, sort_keys=False)

    def validate(self) -> List[str]:
        """Validate configuration and return list of warnings."""
        warnings = []

        # Check split ratios
        total_split = self.train_split + self.val_split + self.test_split
        if abs(total_split - 1.0) > 0.01:
            warnings.append(
                f"Split ratios sum to {total_split}, expected 1.0"
            )

        # Check batch construction
        effective_batch = self.batch_size * self.gradient_accumulation_steps
        if effective_batch < 64:
            warnings.append(
                f"Effective batch size ({effective_batch}) is small for contrastive learning. "
                "Consider increasing batch_size or gradient_accumulation_steps."
            )

        # Check LoRA config
        if self.lora_r > 0 and self.lora_alpha < self.lora_r:
            warnings.append(
                f"lora_alpha ({self.lora_alpha}) < lora_r ({self.lora_r}). "
                "This may reduce LoRA effectiveness."
            )

        # Check freeze layers
        if self.freeze_layers < 0:
            warnings.append("freeze_layers should be >= 0")

        # Check temperature
        if self.temperature <= 0:
            warnings.append("temperature must be positive")
        elif self.temperature > 1.0:
            warnings.append(
                f"temperature ({self.temperature}) is high. "
                "Typical values are 0.05-0.1."
            )

        return warnings

    @property
    def effective_batch_size(self) -> int:
        """Calculate effective batch size with gradient accumulation."""
        return self.batch_size * self.gradient_accumulation_steps

    @property
    def samples_per_batch(self) -> int:
        """Total samples in one batch (logos_per_batch * samples_per_logo)."""
        return self.logos_per_batch * self.samples_per_logo
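
As a quick orientation for reviewers (not part of the commit): a minimal sketch of how `TrainingConfig` is intended to be used, loading one of the shipped YAML configs and surfacing validation warnings before training. The config path is illustrative.

```python
from training.config import TrainingConfig

# Load hyperparameters from a YAML file (path is illustrative)
config = TrainingConfig.from_yaml("configs/jetson_orin.yaml")

# Surface non-fatal issues (split ratios, small effective batch, odd temperature, ...)
for warning in config.validate():
    print(f"[config warning] {warning}")

print(f"Effective batch size: {config.effective_batch_size}")  # batch_size * gradient_accumulation_steps
print(f"Samples per batch: {config.samples_per_batch}")        # logos_per_batch * samples_per_logo
```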
467
training/dataset.py
Normal file
@@ -0,0 +1,467 @@
"""
Dataset classes for contrastive learning on logo images.
"""

import random
import sqlite3
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms


# CLIP normalization values
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]


def get_train_transforms(strength: str = "medium") -> transforms.Compose:
    """
    Get training data augmentation transforms.

    Args:
        strength: Augmentation strength - "light", "medium", or "strong"

    Returns:
        Composed transforms for training
    """
    if strength == "light":
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
        ])
    elif strength == "medium":
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05
            ),
            transforms.RandomAffine(
                degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)
            ),
            transforms.RandomGrayscale(p=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
        ])
    else:  # strong
        return transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.1),
            transforms.RandomRotation(degrees=30),
            transforms.ColorJitter(
                brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1
            ),
            transforms.RandomAffine(
                degrees=0, translate=(0.15, 0.15), scale=(0.8, 1.2), shear=10
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
            transforms.ToTensor(),
            transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
        ])


def get_val_transforms() -> transforms.Compose:
    """Get validation/test transforms (no augmentation)."""
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
    ])


class LogoDataset:
    """
    Manages logo data from the SQLite database.

    Handles loading logo-to-image mappings and splitting by logo brand.
    """

    def __init__(
        self,
        db_path: str,
        reference_dir: str,
        train_split: float = 0.7,
        val_split: float = 0.15,
        test_split: float = 0.15,
        seed: int = 42,
    ):
        self.db_path = Path(db_path)
        self.reference_dir = Path(reference_dir)
        self.seed = seed

        # Load logo-to-images mapping from database
        self.logo_to_images = self._load_logo_mappings()
        self.all_logos = list(self.logo_to_images.keys())

        # Create logo-level splits
        self.train_logos, self.val_logos, self.test_logos = self._split_logos(
            train_split, val_split, test_split
        )

    def _load_logo_mappings(self) -> Dict[str, List[Path]]:
        """Load logo name to image paths mapping from database."""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        cursor.execute("""
            SELECT ln.name, rl.filename
            FROM reference_logos rl
            JOIN logo_names ln ON rl.logo_name_id = ln.id
            ORDER BY ln.name
        """)

        logo_to_images: Dict[str, List[Path]] = {}
        for logo_name, filename in cursor.fetchall():
            if logo_name not in logo_to_images:
                logo_to_images[logo_name] = []
            logo_to_images[logo_name].append(self.reference_dir / filename)

        conn.close()
        return logo_to_images

    def _split_logos(
        self,
        train_split: float,
        val_split: float,
        test_split: float,
    ) -> Tuple[List[str], List[str], List[str]]:
        """Split logos at brand level for train/val/test."""
        random.seed(self.seed)
        logos = self.all_logos.copy()
        random.shuffle(logos)

        n = len(logos)
        train_end = int(n * train_split)
        val_end = train_end + int(n * val_split)

        train_logos = logos[:train_end]
        val_logos = logos[train_end:val_end]
        test_logos = logos[val_end:]

        return train_logos, val_logos, test_logos

    def get_split_info(self) -> Dict[str, int]:
        """Return information about the splits."""
        return {
            "total_logos": len(self.all_logos),
            "train_logos": len(self.train_logos),
            "val_logos": len(self.val_logos),
            "test_logos": len(self.test_logos),
            "train_images": sum(
                len(self.logo_to_images[l]) for l in self.train_logos
            ),
            "val_images": sum(
                len(self.logo_to_images[l]) for l in self.val_logos
            ),
            "test_images": sum(
                len(self.logo_to_images[l]) for l in self.test_logos
            ),
        }


class LogoContrastiveDataset(Dataset):
    """
    Dataset for contrastive learning on logos.

    Each __getitem__ call returns a batch of images organized for contrastive
    learning: K different logos with M samples each, ensuring positive pairs
    exist within each batch.
    """

    def __init__(
        self,
        logo_data: LogoDataset,
        split: str = "train",
        logos_per_batch: int = 32,
        samples_per_logo: int = 4,
        transform: Optional[transforms.Compose] = None,
        batches_per_epoch: int = 1000,
    ):
        """
        Initialize the contrastive dataset.

        Args:
            logo_data: LogoDataset instance with logo mappings
            split: One of "train", "val", or "test"
            logos_per_batch: Number of different logos per batch
            samples_per_logo: Number of samples for each logo
            transform: Image transforms to apply
            batches_per_epoch: Number of batches per epoch
        """
        self.logo_data = logo_data
        self.logos_per_batch = logos_per_batch
        self.samples_per_logo = samples_per_logo
        self.transform = transform
        self.batches_per_epoch = batches_per_epoch

        # Get logos for this split
        if split == "train":
            self.logos = logo_data.train_logos
        elif split == "val":
            self.logos = logo_data.val_logos
        else:
            self.logos = logo_data.test_logos

        # Filter logos with enough samples
        self.valid_logos = [
            logo for logo in self.logos
            if len(logo_data.logo_to_images[logo]) >= samples_per_logo
        ]

        # For logos with fewer samples, we'll sample with replacement
        self.logos_needing_replacement = [
            logo for logo in self.logos
            if len(logo_data.logo_to_images[logo]) < samples_per_logo
        ]

        # Create label mapping
        self.logo_to_label = {
            logo: idx for idx, logo in enumerate(self.logos)
        }

    def __len__(self) -> int:
        return self.batches_per_epoch

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a batch of images for contrastive learning.

        Returns:
            images: Tensor of shape [K*M, 3, 224, 224]
            labels: Tensor of shape [K*M] with logo class indices
        """
        images = []
        labels = []

        # Sample K logos for this batch
        k = min(self.logos_per_batch, len(self.logos))
        batch_logos = random.sample(self.logos, k)

        for logo in batch_logos:
            logo_images = self.logo_data.logo_to_images[logo]

            # Sample M images for this logo
            if len(logo_images) >= self.samples_per_logo:
                sampled_paths = random.sample(logo_images, self.samples_per_logo)
            else:
                # Sample with replacement if not enough images
                sampled_paths = random.choices(
                    logo_images, k=self.samples_per_logo
                )

            # Load and transform images
            for img_path in sampled_paths:
                try:
                    img = Image.open(img_path).convert("RGB")
                    if self.transform:
                        img = self.transform(img)
                    else:
                        img = get_val_transforms()(img)
                    images.append(img)
                    labels.append(self.logo_to_label[logo])
                except Exception:
                    # Skip problematic images
                    continue

        # Stack into tensors
        if len(images) == 0:
            # Fallback: return dummy batch
            return (
                torch.zeros(1, 3, 224, 224),
                torch.zeros(1, dtype=torch.long),
            )

        images_tensor = torch.stack(images)
        labels_tensor = torch.tensor(labels, dtype=torch.long)

        return images_tensor, labels_tensor


class BalancedBatchSampler(Sampler):
    """
    Sampler that ensures each batch has a balanced distribution of logos.

    Used with a flattened dataset where each sample is a single image.
    """

    def __init__(
        self,
        logo_labels: List[int],
        logos_per_batch: int,
        samples_per_logo: int,
        num_batches: int,
    ):
        self.logo_labels = logo_labels
        self.logos_per_batch = logos_per_batch
        self.samples_per_logo = samples_per_logo
        self.num_batches = num_batches

        # Group indices by logo
        self.logo_to_indices: Dict[int, List[int]] = {}
        for idx, label in enumerate(logo_labels):
            if label not in self.logo_to_indices:
                self.logo_to_indices[label] = []
            self.logo_to_indices[label].append(idx)

        self.all_logos = list(self.logo_to_indices.keys())

    def __iter__(self):
        for _ in range(self.num_batches):
            batch_indices = []

            # Sample logos for this batch
            logos = random.sample(
                self.all_logos,
                min(self.logos_per_batch, len(self.all_logos)),
            )

            for logo in logos:
                indices = self.logo_to_indices[logo]
                if len(indices) >= self.samples_per_logo:
                    sampled = random.sample(indices, self.samples_per_logo)
                else:
                    sampled = random.choices(indices, k=self.samples_per_logo)
                batch_indices.extend(sampled)

            yield batch_indices

    def __len__(self):
        return self.num_batches


def create_dataloaders(
    db_path: str,
    reference_dir: str,
    batch_size: int = 16,
    logos_per_batch: int = 32,
    samples_per_logo: int = 4,
    num_workers: int = 4,
    train_split: float = 0.7,
    val_split: float = 0.15,
    test_split: float = 0.15,
    seed: int = 42,
    augmentation_strength: str = "medium",
    batches_per_epoch: int = 1000,
) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
    """
    Create train, validation, and optionally test dataloaders.

    Args:
        db_path: Path to SQLite database
        reference_dir: Directory containing reference logo images
        batch_size: Not used directly (see logos_per_batch and samples_per_logo)
        logos_per_batch: Number of different logos per batch
        samples_per_logo: Samples per logo in batch
        num_workers: Number of data loading workers
        train_split: Fraction for training
        val_split: Fraction for validation
        test_split: Fraction for testing
        seed: Random seed
        augmentation_strength: "light", "medium", or "strong"
        batches_per_epoch: Number of batches per training epoch

    Returns:
        Tuple of (train_loader, val_loader, test_loader)
    """
    # Load logo data
    logo_data = LogoDataset(
        db_path=db_path,
        reference_dir=reference_dir,
        train_split=train_split,
        val_split=val_split,
        test_split=test_split,
        seed=seed,
    )

    # Print split info
    split_info = logo_data.get_split_info()
    print("Dataset loaded:")
    print(f"  Total logos: {split_info['total_logos']}")
    print(f"  Train: {split_info['train_logos']} logos, {split_info['train_images']} images")
    print(f"  Val: {split_info['val_logos']} logos, {split_info['val_images']} images")
    print(f"  Test: {split_info['test_logos']} logos, {split_info['test_images']} images")

    # Create datasets
    train_dataset = LogoContrastiveDataset(
        logo_data=logo_data,
        split="train",
        logos_per_batch=logos_per_batch,
        samples_per_logo=samples_per_logo,
        transform=get_train_transforms(augmentation_strength),
        batches_per_epoch=batches_per_epoch,
    )

    val_dataset = LogoContrastiveDataset(
        logo_data=logo_data,
        split="val",
        logos_per_batch=logos_per_batch,
        samples_per_logo=samples_per_logo,
        transform=get_val_transforms(),
        batches_per_epoch=batches_per_epoch // 10,  # Fewer val batches
    )

    test_dataset = LogoContrastiveDataset(
        logo_data=logo_data,
        split="test",
        logos_per_batch=logos_per_batch,
        samples_per_logo=samples_per_logo,
        transform=get_val_transforms(),
        batches_per_epoch=batches_per_epoch // 10,
    ) if test_split > 0 else None

    # Create dataloaders
    # Note: batch_size=1 because each __getitem__ already returns a batch
    train_loader = DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=_collate_contrastive_batch,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=_collate_contrastive_batch,
    )

    test_loader = None
    if test_dataset is not None:
        test_loader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=_collate_contrastive_batch,
        )

    return train_loader, val_loader, test_loader


def _collate_contrastive_batch(
    batch: List[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Collate function that unpacks pre-batched data.

    Since LogoContrastiveDataset already returns batched data,
    we just squeeze the outer dimension.
    """
    images, labels = batch[0]
    return images, labels
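
A reviewer-oriented sketch (not part of the commit) of how `create_dataloaders` wires together: each item yielded by the loader is already a `[K*M, 3, 224, 224]` tensor plus labels, because `LogoContrastiveDataset.__getitem__` builds a full contrastive batch. The paths below are the defaults from `TrainingConfig`.

```python
from training.dataset import create_dataloaders

train_loader, val_loader, test_loader = create_dataloaders(
    db_path="test_data_mapping.db",
    reference_dir="reference_logos",
    logos_per_batch=32,    # K distinct brands per batch
    samples_per_logo=4,    # M crops per brand -> in-batch positives
    augmentation_strength="medium",
)

images, labels = next(iter(train_loader))
print(images.shape)  # roughly [32 * 4, 3, 224, 224]
print(labels.shape)  # [32 * 4] logo class indices
```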
339
training/evaluation.py
Normal file
@@ -0,0 +1,339 @@
"""
Evaluation metrics for embedding quality.
"""

from typing import Dict, List, Optional, Tuple

import torch
import torch.nn.functional as F
import numpy as np


class EmbeddingEvaluator:
    """
    Evaluator for embedding quality metrics.

    Computes metrics that indicate how well the embeddings
    separate different logo classes.
    """

    def compute_metrics(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> Dict[str, float]:
        """
        Compute embedding quality metrics.

        Args:
            embeddings: [N, D] L2-normalized embeddings
            labels: [N] integer class labels

        Returns:
            Dict with metric names and values
        """
        device = embeddings.device
        batch_size = embeddings.shape[0]

        if batch_size <= 1:
            return {
                "mean_pos_sim": 0.0,
                "mean_neg_sim": 0.0,
                "separation": 0.0,
                "recall_at_1": 0.0,
                "recall_at_5": 0.0,
            }

        # Compute similarity matrix
        similarity = embeddings @ embeddings.T

        # Create masks
        labels_col = labels.unsqueeze(0)
        labels_row = labels.unsqueeze(1)
        positive_mask = (labels_row == labels_col).float()
        negative_mask = 1 - positive_mask

        # Remove diagonal from positive mask
        identity = torch.eye(batch_size, device=device)
        positive_mask = positive_mask - identity

        # Count pairs
        num_positives = positive_mask.sum()
        num_negatives = negative_mask.sum()

        # Mean positive similarity (excluding self)
        if num_positives > 0:
            pos_sims = (similarity * positive_mask).sum() / num_positives
            mean_pos_sim = pos_sims.item()
        else:
            mean_pos_sim = 0.0

        # Mean negative similarity
        if num_negatives > 0:
            neg_sims = (similarity * negative_mask).sum() / num_negatives
            mean_neg_sim = neg_sims.item()
        else:
            mean_neg_sim = 0.0

        # Separation: gap between positive and negative similarity
        separation = mean_pos_sim - mean_neg_sim

        # Recall@K metrics
        recall_at_1 = self._compute_recall_at_k(similarity, labels, k=1)
        recall_at_5 = self._compute_recall_at_k(similarity, labels, k=5)

        return {
            "mean_pos_sim": mean_pos_sim,
            "mean_neg_sim": mean_neg_sim,
            "separation": separation,
            "recall_at_1": recall_at_1,
            "recall_at_5": recall_at_5,
        }

    def _compute_recall_at_k(
        self,
        similarity: torch.Tensor,
        labels: torch.Tensor,
        k: int = 1,
    ) -> float:
        """
        Compute Recall@K for nearest neighbor retrieval.

        For each sample, check if the k nearest neighbors (excluding self)
        contain at least one sample with the same label.

        Args:
            similarity: [N, N] similarity matrix
            labels: [N] class labels
            k: Number of neighbors to consider

        Returns:
            Recall@K score (0 to 1)
        """
        batch_size = similarity.shape[0]
        if batch_size <= 1:
            return 0.0

        # Mask out self-similarity
        similarity = similarity.clone()
        similarity.fill_diagonal_(float("-inf"))

        # Get top-k indices
        _, top_k_indices = similarity.topk(min(k, batch_size - 1), dim=1)

        # Check if any of top-k have same label
        correct = 0
        for i in range(batch_size):
            query_label = labels[i]
            retrieved_labels = labels[top_k_indices[i]]
            if (retrieved_labels == query_label).any():
                correct += 1

        return correct / batch_size

    def compute_detailed_metrics(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
        label_names: Optional[List[str]] = None,
    ) -> Dict:
        """
        Compute detailed per-class metrics.

        Args:
            embeddings: [N, D] embeddings
            labels: [N] class labels
            label_names: Optional list of label names

        Returns:
            Dict with detailed metrics including per-class stats
        """
        basic_metrics = self.compute_metrics(embeddings, labels)

        # Per-class statistics
        unique_labels = labels.unique()
        per_class_stats = {}

        similarity = embeddings @ embeddings.T

        for label in unique_labels:
            mask = labels == label
            class_embeddings = embeddings[mask]
            class_size = mask.sum().item()

            if class_size > 1:
                # Intra-class similarity
                class_sim = class_embeddings @ class_embeddings.T
                # Exclude diagonal
                mask_diag = ~torch.eye(class_size, dtype=torch.bool, device=class_sim.device)
                intra_sim = class_sim[mask_diag].mean().item()
            else:
                intra_sim = 1.0

            # Inter-class similarity (to other classes)
            other_mask = labels != label
            if other_mask.any():
                inter_sim = similarity[mask][:, other_mask].mean().item()
            else:
                inter_sim = 0.0

            class_name = label_names[label.item()] if label_names else str(label.item())
            per_class_stats[class_name] = {
                "size": class_size,
                "intra_class_sim": intra_sim,
                "inter_class_sim": inter_sim,
                "class_separation": intra_sim - inter_sim,
            }

        # Aggregate per-class stats
        if per_class_stats:
            separations = [s["class_separation"] for s in per_class_stats.values()]
            min_separation = min(separations)
            max_separation = max(separations)
            std_separation = np.std(separations)
        else:
            min_separation = max_separation = std_separation = 0.0

        return {
            **basic_metrics,
            "per_class": per_class_stats,
            "min_class_separation": min_separation,
            "max_class_separation": max_separation,
            "std_class_separation": std_separation,
        }


class SimilarityAnalyzer:
    """
    Analyze similarity distributions for debugging and tuning.
    """

    @staticmethod
    def analyze_similarity_distribution(
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> Dict[str, np.ndarray]:
        """
        Get similarity distributions for positive and negative pairs.

        Useful for choosing appropriate thresholds.

        Args:
            embeddings: [N, D] embeddings
            labels: [N] class labels

        Returns:
            Dict with 'positive_sims' and 'negative_sims' arrays
        """
        similarity = (embeddings @ embeddings.T).cpu().numpy()
        labels_np = labels.cpu().numpy()

        batch_size = len(labels_np)
        positive_sims = []
        negative_sims = []

        for i in range(batch_size):
            for j in range(i + 1, batch_size):
                if labels_np[i] == labels_np[j]:
                    positive_sims.append(similarity[i, j])
                else:
                    negative_sims.append(similarity[i, j])

        return {
            "positive_sims": np.array(positive_sims),
            "negative_sims": np.array(negative_sims),
        }

    @staticmethod
    def find_hard_pairs(
        embeddings: torch.Tensor,
        labels: torch.Tensor,
        n_hard: int = 10,
    ) -> Tuple[List[Tuple[int, int, float]], List[Tuple[int, int, float]]]:
        """
        Find hardest positive and negative pairs.

        Hard positives: same label but low similarity
        Hard negatives: different label but high similarity

        Args:
            embeddings: [N, D] embeddings
            labels: [N] class labels
            n_hard: Number of hard pairs to return

        Returns:
            Tuple of (hard_positives, hard_negatives)
            Each is a list of (idx1, idx2, similarity) tuples
        """
        similarity = embeddings @ embeddings.T
        batch_size = len(labels)

        hard_positives = []  # Low similarity, same label
        hard_negatives = []  # High similarity, different label

        for i in range(batch_size):
            for j in range(i + 1, batch_size):
                sim = similarity[i, j].item()
                if labels[i] == labels[j]:
                    hard_positives.append((i, j, sim))
                else:
                    hard_negatives.append((i, j, sim))

        # Sort: hard positives by ascending similarity (lowest first)
        hard_positives.sort(key=lambda x: x[2])

        # Sort: hard negatives by descending similarity (highest first)
        hard_negatives.sort(key=lambda x: -x[2])

        return hard_positives[:n_hard], hard_negatives[:n_hard]

    @staticmethod
    def compute_confusion_pairs(
        embeddings: torch.Tensor,
        labels: torch.Tensor,
        label_names: Optional[List[str]] = None,
        top_k: int = 10,
    ) -> List[Dict]:
        """
        Find pairs of classes that are most confused (highest cross-class similarity).

        Args:
            embeddings: [N, D] embeddings
            labels: [N] class labels
            label_names: Optional label names
            top_k: Number of confused pairs to return

        Returns:
            List of dicts with class pairs and their similarity
        """
        unique_labels = labels.unique()
        class_centroids = {}

        # Compute class centroids
        for label in unique_labels:
            mask = labels == label
            centroid = embeddings[mask].mean(dim=0)
            centroid = F.normalize(centroid, dim=0)
            class_centroids[label.item()] = centroid

        # Compute pairwise centroid similarities
        confusions = []
        label_list = list(class_centroids.keys())

        for i, label1 in enumerate(label_list):
            for label2 in label_list[i + 1:]:
                sim = (class_centroids[label1] @ class_centroids[label2]).item()
                name1 = label_names[label1] if label_names else str(label1)
                name2 = label_names[label2] if label_names else str(label2)
                confusions.append({
                    "class1": name1,
                    "class2": name2,
                    "label1": label1,
                    "label2": label2,
                    "centroid_similarity": sim,
                })

        # Sort by similarity (highest first)
        confusions.sort(key=lambda x: -x["centroid_similarity"])

        return confusions[:top_k]
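
A small sketch (not part of the commit) showing how `EmbeddingEvaluator` can be exercised on synthetic embeddings; `separation` and the Recall@K values are the quantities the trainer tracks during validation.

```python
import torch
import torch.nn.functional as F

from training.evaluation import EmbeddingEvaluator

# Fake batch: 8 samples, 4 classes, 768-dim embeddings, L2-normalized as the evaluator expects
embeddings = F.normalize(torch.randn(8, 768), dim=-1)
labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])

metrics = EmbeddingEvaluator().compute_metrics(embeddings, labels)
print(metrics)  # mean_pos_sim, mean_neg_sim, separation, recall_at_1, recall_at_5
```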
326
training/losses.py
Normal file
@@ -0,0 +1,326 @@
"""
Loss functions for contrastive learning on logo embeddings.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class InfoNCELoss(nn.Module):
    """
    Normalized Temperature-scaled Cross Entropy Loss (InfoNCE).

    This is the contrastive loss used in CLIP training. It maximizes
    similarity between embeddings of the same logo class while
    minimizing similarity to embeddings of different classes.

    For a batch with N samples:
    - Each sample is an anchor
    - Positive pairs: samples with the same label
    - Negative pairs: samples with different labels

    The loss for each anchor is:
        -log(sum(exp(sim(anchor, pos)/temp)) / sum(exp(sim(anchor, all)/temp)))
    """

    def __init__(self, temperature: float = 0.07):
        """
        Initialize InfoNCE loss.

        Args:
            temperature: Scaling factor for similarities (0.05-0.1 typical).
                Lower temperature makes the distribution sharper.
        """
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute InfoNCE loss for a batch of embeddings.

        Args:
            embeddings: [N, D] L2-normalized embeddings
            labels: [N] integer logo class labels

        Returns:
            Scalar loss value
        """
        device = embeddings.device
        batch_size = embeddings.shape[0]

        if batch_size <= 1:
            return torch.tensor(0.0, device=device, requires_grad=True)

        # Compute similarity matrix [N, N]
        # Since embeddings are L2-normalized, dot product = cosine similarity
        similarity = embeddings @ embeddings.T / self.temperature

        # Create positive mask: same label = 1, different = 0
        labels_col = labels.unsqueeze(0)  # [1, N]
        labels_row = labels.unsqueeze(1)  # [N, 1]
        positive_mask = (labels_row == labels_col).float()  # [N, N]

        # Remove self-similarity from positives (diagonal)
        identity = torch.eye(batch_size, device=device)
        positive_mask = positive_mask - identity

        # Count positives per anchor (avoid division by zero)
        num_positives = positive_mask.sum(dim=1)
        has_positives = num_positives > 0

        # If no positives exist for any anchor, return zero loss
        if not has_positives.any():
            return torch.tensor(0.0, device=device, requires_grad=True)

        # Mask out self-similarity with large negative value
        similarity = similarity - identity * 1e9

        # Compute log-softmax over similarities
        log_softmax = F.log_softmax(similarity, dim=1)

        # Sum log probabilities of positive pairs
        positive_log_probs = (log_softmax * positive_mask).sum(dim=1)

        # Average over number of positives (only for anchors with positives)
        loss_per_anchor = torch.zeros(batch_size, device=device)
        loss_per_anchor[has_positives] = (
            -positive_log_probs[has_positives] / num_positives[has_positives]
        )

        return loss_per_anchor.mean()


class SupConLoss(nn.Module):
    """
    Supervised Contrastive Loss.

    Similar to InfoNCE but uses a different formulation that
    considers each positive pair separately rather than averaging.

    Reference: https://arxiv.org/abs/2004.11362
    """

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute Supervised Contrastive loss.

        Args:
            embeddings: [N, D] L2-normalized embeddings
            labels: [N] integer logo class labels

        Returns:
            Scalar loss value
        """
        device = embeddings.device
        batch_size = embeddings.shape[0]

        if batch_size <= 1:
            return torch.tensor(0.0, device=device, requires_grad=True)

        # Compute similarity matrix
        similarity = embeddings @ embeddings.T / self.temperature

        # Create masks
        labels_col = labels.unsqueeze(0)
        labels_row = labels.unsqueeze(1)
        positive_mask = (labels_row == labels_col).float()
        identity = torch.eye(batch_size, device=device)

        # Remove self from positives
        positive_mask = positive_mask - identity

        # Number of positives per anchor
        num_positives = positive_mask.sum(dim=1)
        has_positives = num_positives > 0

        if not has_positives.any():
            return torch.tensor(0.0, device=device, requires_grad=True)

        # For numerical stability, subtract max similarity
        sim_max, _ = similarity.max(dim=1, keepdim=True)
        similarity = similarity - sim_max.detach()

        # Compute exp(similarity) with self masked out
        exp_sim = torch.exp(similarity) * (1 - identity)

        # Denominator: sum of exp over all pairs except self
        log_prob = similarity - torch.log(exp_sim.sum(dim=1, keepdim=True) + 1e-8)

        # Mean of log-prob over positive pairs
        mean_log_prob_pos = (positive_mask * log_prob).sum(dim=1) / (
            num_positives + 1e-8
        )

        # Loss is negative mean log probability
        loss = -mean_log_prob_pos[has_positives].mean()

        return loss


class TripletLoss(nn.Module):
    """
    Triplet loss with online hard mining.

    For each anchor:
    - Hardest positive: most distant sample with same label
    - Hardest negative: closest sample with different label

    Loss = max(0, d(anchor, hardest_pos) - d(anchor, hardest_neg) + margin)

    This is an alternative to InfoNCE for when batch sizes are small.
    """

    def __init__(self, margin: float = 0.3):
        """
        Initialize Triplet loss.

        Args:
            margin: Minimum required gap between positive and negative distances
        """
        super().__init__()
        self.margin = margin

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute triplet loss with online hard mining.

        Args:
            embeddings: [N, D] L2-normalized embeddings
            labels: [N] integer logo class labels

        Returns:
            Scalar loss value
        """
        device = embeddings.device
        batch_size = embeddings.shape[0]

        if batch_size <= 1:
            return torch.tensor(0.0, device=device, requires_grad=True)

        # Compute pairwise cosine distances (1 - cosine_similarity)
        # For normalized vectors: distance = 1 - dot_product
        similarity = embeddings @ embeddings.T
        distances = 1 - similarity

        # Create masks
        labels_col = labels.unsqueeze(0)
        labels_row = labels.unsqueeze(1)
        positive_mask = (labels_row == labels_col).float()
        negative_mask = 1 - positive_mask

        # Remove self from positives (diagonal)
        identity = torch.eye(batch_size, device=device)
        positive_mask = positive_mask - identity

        # Check if we have any valid triplets
        has_positives = positive_mask.sum(dim=1) > 0
        has_negatives = negative_mask.sum(dim=1) > 0
        valid_anchors = has_positives & has_negatives

        if not valid_anchors.any():
            return torch.tensor(0.0, device=device, requires_grad=True)

        # For each anchor, find hardest positive (max distance among positives)
        # Set negatives to -inf so they don't affect max
        pos_distances = distances.clone()
        pos_distances[positive_mask == 0] = float("-inf")
        hardest_positive, _ = pos_distances.max(dim=1)

        # For each anchor, find hardest negative (min distance among negatives)
        # Set positives to inf so they don't affect min
        neg_distances = distances.clone()
        neg_distances[negative_mask == 0] = float("inf")
        hardest_negative, _ = neg_distances.min(dim=1)

        # Triplet loss: want positive to be closer than negative by margin
        triplet_loss = F.relu(
            hardest_positive - hardest_negative + self.margin
        )

        # Average over valid anchors only
        loss = triplet_loss[valid_anchors].mean()

        return loss


class CombinedLoss(nn.Module):
    """
    Combined loss function with weighted InfoNCE and Triplet losses.

    Can help stabilize training by combining the benefits of both losses.
    """

    def __init__(
        self,
        temperature: float = 0.07,
        triplet_margin: float = 0.3,
        infonce_weight: float = 1.0,
        triplet_weight: float = 0.5,
    ):
        super().__init__()
        self.infonce = InfoNCELoss(temperature=temperature)
        self.triplet = TripletLoss(margin=triplet_margin)
        self.infonce_weight = infonce_weight
        self.triplet_weight = triplet_weight

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        infonce_loss = self.infonce(embeddings, labels)
        triplet_loss = self.triplet(embeddings, labels)

        return (
            self.infonce_weight * infonce_loss +
            self.triplet_weight * triplet_loss
        )


def get_loss_function(
    loss_type: str = "infonce",
    temperature: float = 0.07,
    triplet_margin: float = 0.3,
) -> nn.Module:
    """
    Factory function to create loss function.

    Args:
        loss_type: One of "infonce", "supcon", "triplet", or "combined"
        temperature: Temperature for InfoNCE/SupCon
        triplet_margin: Margin for triplet loss

    Returns:
        Loss function module
    """
    if loss_type == "infonce":
        return InfoNCELoss(temperature=temperature)
    elif loss_type == "supcon":
        return SupConLoss(temperature=temperature)
    elif loss_type == "triplet":
        return TripletLoss(margin=triplet_margin)
    elif loss_type == "combined":
        return CombinedLoss(
            temperature=temperature,
            triplet_margin=triplet_margin,
        )
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")
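
A quick sanity-check sketch (not part of the commit): `get_loss_function` is the factory the trainer uses, and all four loss types accept the same `(embeddings, labels)` call with L2-normalized embeddings and integer class labels.

```python
import torch
import torch.nn.functional as F

from training.losses import get_loss_function

embeddings = F.normalize(torch.randn(16, 768), dim=-1)  # [N, D], unit norm
labels = torch.randint(0, 4, (16,))                     # [N] logo class ids

for loss_type in ("infonce", "supcon", "triplet", "combined"):
    criterion = get_loss_function(loss_type=loss_type, temperature=0.07, triplet_margin=0.3)
    loss = criterion(embeddings, labels)
    print(f"{loss_type}: {loss.item():.4f}")
```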
335
training/model.py
Normal file
@@ -0,0 +1,335 @@
"""
Fine-tunable CLIP model wrapper with LoRA support.
"""

import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPModel, CLIPProcessor

# Check if peft is available for LoRA
try:
    from peft import LoraConfig, get_peft_model, PeftModel
    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False
    LoraConfig = None
    get_peft_model = None
    PeftModel = None


class LogoFineTunedCLIP(nn.Module):
    """
    CLIP vision encoder fine-tuned for logo similarity.

    Preserves embedding interface for compatibility with DetectLogosDETR:
    - Same embedding dimensionality (768 for ViT-L/14)
    - L2 normalized outputs
    - Works with existing get_image_features() pattern

    Supports:
    - LoRA for memory-efficient fine-tuning
    - Layer freezing for transfer learning
    - Gradient checkpointing for memory optimization
    """

    def __init__(
        self,
        vision_model: nn.Module,
        lora_r: int = 16,
        lora_alpha: int = 32,
        lora_dropout: float = 0.1,
        freeze_layers: int = 12,
        use_gradient_checkpointing: bool = True,
        add_projection_head: bool = True,
    ):
        """
        Initialize the fine-tunable CLIP wrapper.

        Args:
            vision_model: CLIP vision model (CLIPVisionModel)
            lora_r: Rank of LoRA low-rank matrices (0 to disable)
            lora_alpha: LoRA scaling factor
            lora_dropout: Dropout for LoRA layers
            freeze_layers: Number of transformer layers to freeze (from bottom)
            use_gradient_checkpointing: Enable gradient checkpointing
            add_projection_head: Add trainable projection head
        """
        super().__init__()

        self.vision_model = vision_model
        self.embedding_dim = vision_model.config.hidden_size
        self.freeze_layers = freeze_layers
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha

        # Enable gradient checkpointing for memory efficiency
        if use_gradient_checkpointing:
            if hasattr(self.vision_model, "gradient_checkpointing_enable"):
                self.vision_model.gradient_checkpointing_enable()

        # Freeze lower layers
        self._freeze_layers(freeze_layers)

        # Apply LoRA to attention layers in upper blocks
        self.peft_applied = False
        if PEFT_AVAILABLE and lora_r > 0:
            self._apply_lora(lora_r, lora_alpha, lora_dropout)
            self.peft_applied = True
        elif lora_r > 0 and not PEFT_AVAILABLE:
            print(
                "Warning: peft not installed. LoRA disabled. "
                "Install with: pip install peft"
            )

        # Optional projection head for fine-tuning
        self.add_projection_head = add_projection_head
        if add_projection_head:
            self.projection = nn.Sequential(
                nn.Linear(self.embedding_dim, self.embedding_dim),
                nn.LayerNorm(self.embedding_dim),
            )
        else:
            self.projection = nn.Identity()

    def _freeze_layers(self, num_layers: int) -> None:
        """Freeze the first N transformer layers and embeddings."""
        if num_layers <= 0:
            return

        # Freeze embeddings
        if hasattr(self.vision_model, "embeddings"):
            for param in self.vision_model.embeddings.parameters():
                param.requires_grad = False

        # Freeze specified number of encoder layers
        if hasattr(self.vision_model, "encoder"):
            for i, layer in enumerate(self.vision_model.encoder.layers):
                if i < num_layers:
                    for param in layer.parameters():
                        param.requires_grad = False

    def _apply_lora(
        self,
        r: int,
        alpha: int,
        dropout: float,
    ) -> None:
        """Apply LoRA adapters to attention layers."""
        if not PEFT_AVAILABLE:
            return

        # Configure LoRA for vision transformer
        lora_config = LoraConfig(
            r=r,
            lora_alpha=alpha,
            lora_dropout=dropout,
            target_modules=["q_proj", "v_proj"],
            bias="none",
            modules_to_save=[],  # Don't save any full modules
        )

        self.vision_model = get_peft_model(self.vision_model, lora_config)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Extract normalized embeddings for logo images.

        Args:
            pixel_values: [batch, 3, 224, 224] preprocessed images

        Returns:
            embeddings: [batch, embedding_dim] L2-normalized
        """
        # Get vision features
        outputs = self.vision_model(pixel_values=pixel_values)

        # Use pooler output (CLS token projection) if available
        if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
            features = outputs.pooler_output
        else:
            # Fall back to CLS token from last hidden state
            features = outputs.last_hidden_state[:, 0, :]

        # Apply projection head
        features = self.projection(features)

        # L2 normalize for cosine similarity
        features = F.normalize(features, dim=-1)

        return features

    def get_image_features(self, **kwargs) -> torch.Tensor:
        """
        Compatibility method matching CLIP's interface.

        Used by DetectLogosDETR._get_embedding_pil().
        """
        return self.forward(kwargs["pixel_values"])

    def get_trainable_parameters(self) -> List[torch.nn.Parameter]:
        """Return list of trainable parameters."""
        return [p for p in self.parameters() if p.requires_grad]

    def get_parameter_count(self) -> Dict[str, int]:
        """Return count of trainable and total parameters."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {
            "total": total,
            "trainable": trainable,
            "frozen": total - trainable,
            "trainable_percent": 100 * trainable / total if total > 0 else 0,
        }

    def save_pretrained(self, output_dir: str) -> None:
        """
        Save model in HuggingFace-compatible format.

        Args:
            output_dir: Directory to save model files
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Save model weights
        if self.peft_applied and PEFT_AVAILABLE:
            # Save LoRA weights separately
            self.vision_model.save_pretrained(output_path / "vision_lora")
            # Save projection head
            torch.save(
                self.projection.state_dict(),
                output_path / "projection_head.bin",
            )
        else:
            # Save full model state
            torch.save(self.state_dict(), output_path / "pytorch_model.bin")

        # Save config
        config = {
            "model_type": "clip_logo_finetuned",
            "embedding_dim": self.embedding_dim,
            "lora_r": self.lora_r,
            "lora_alpha": self.lora_alpha,
            "freeze_layers": self.freeze_layers,
            "add_projection_head": self.add_projection_head,
            "peft_applied": self.peft_applied,
        }

        with open(output_path / "config.json", "w") as f:
            json.dump(config, f, indent=2)

    @classmethod
    def from_pretrained(
        cls,
        model_path: str,
        base_model: str = "openai/clip-vit-large-patch14",
        device: Optional[torch.device] = None,
    ) -> "LogoFineTunedCLIP":
        """
        Load a fine-tuned model from saved weights.

        Args:
            model_path: Path to saved model directory
            base_model: Base CLIP model name (for architecture)
            device: Device to load model on

        Returns:
            Loaded LogoFineTunedCLIP model
        """
        model_path = Path(model_path)

        # Load config
        with open(model_path / "config.json", "r") as f:
            config = json.load(f)

        # Load base CLIP model
        clip_model = CLIPModel.from_pretrained(base_model)

        # Create model instance
        model = cls(
            vision_model=clip_model.vision_model,
            lora_r=config.get("lora_r", 0),
            lora_alpha=config.get("lora_alpha", 1),
            freeze_layers=config.get("freeze_layers", 12),
            add_projection_head=config.get("add_projection_head", True),
            use_gradient_checkpointing=False,  # Not needed for inference
        )

        # Load weights
        if config.get("peft_applied", False) and PEFT_AVAILABLE:
            # Load LoRA weights
            lora_path = model_path / "vision_lora"
            if lora_path.exists():
                model.vision_model = PeftModel.from_pretrained(
                    model.vision_model, lora_path
                )
            # Load projection head
            proj_path = model_path / "projection_head.bin"
            if proj_path.exists():
                model.projection.load_state_dict(torch.load(proj_path))
        else:
            # Load full model state
            weights_path = model_path / "pytorch_model.bin"
            if weights_path.exists():
                model.load_state_dict(torch.load(weights_path))

        if device is not None:
            model = model.to(device)

        return model


def create_model(
    base_model: str = "openai/clip-vit-large-patch14",
    lora_r: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.1,
    freeze_layers: int = 12,
    use_gradient_checkpointing: bool = True,
    device: Optional[torch.device] = None,
) -> Tuple[LogoFineTunedCLIP, CLIPProcessor]:
    """
    Create a fine-tunable CLIP model and processor.

    Args:
        base_model: HuggingFace model name or path
        lora_r: LoRA rank (0 to disable)
        lora_alpha: LoRA scaling factor
        lora_dropout: LoRA dropout
        freeze_layers: Number of layers to freeze
        use_gradient_checkpointing: Enable gradient checkpointing
        device: Device to load model on

    Returns:
        Tuple of (model, processor)
    """
    # Load base CLIP model
    clip_model = CLIPModel.from_pretrained(base_model)
    processor = CLIPProcessor.from_pretrained(base_model)

    # Create fine-tunable wrapper
    model = LogoFineTunedCLIP(
        vision_model=clip_model.vision_model,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        freeze_layers=freeze_layers,
        use_gradient_checkpointing=use_gradient_checkpointing,
    )

    if device is not None:
        model = model.to(device)

    # Print parameter info
    param_info = model.get_parameter_count()
    print("Model created:")
    print(f"  Total parameters: {param_info['total']:,}")
    print(f"  Trainable: {param_info['trainable']:,} ({param_info['trainable_percent']:.2f}%)")
    print(f"  Frozen: {param_info['frozen']:,}")

    return model, processor
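
A reviewer-oriented sketch (not part of the commit) of the `create_model` / save / reload round trip defined above; the output directory and the random pixel batch are illustrative only.

```python
import torch

from training.model import create_model, LogoFineTunedCLIP

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, processor = create_model(lora_r=16, lora_alpha=32, freeze_layers=12, device=device)

# Embed a preprocessed batch (random placeholder values)
pixel_values = torch.randn(2, 3, 224, 224, device=device)
with torch.no_grad():
    embeddings = model(pixel_values)                              # [2, 768], L2-normalized
    same = model.get_image_features(pixel_values=pixel_values)    # CLIP-compatible entry point

# Save LoRA adapters plus projection head, then reload for inference
model.save_pretrained("models/logo_detection/clip_finetuned")    # default output_dir from TrainingConfig
reloaded = LogoFineTunedCLIP.from_pretrained("models/logo_detection/clip_finetuned", device=device)
```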
405
training/trainer.py
Normal file
405
training/trainer.py
Normal file
@ -0,0 +1,405 @@
|
|||||||
|
"""
|
||||||
|
Training loop with checkpointing, mixed precision, and evaluation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.optim import AdamW
|
||||||
|
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .config import TrainingConfig
|
||||||
|
from .losses import get_loss_function
|
||||||
|
from .evaluation import EmbeddingEvaluator
|
||||||
|
|
||||||
|
# Check if amp is available
|
||||||
|
try:
|
||||||
|
from torch.cuda.amp import autocast, GradScaler
|
||||||
|
AMP_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
AMP_AVAILABLE = False
|
||||||
|
autocast = None
|
||||||
|
GradScaler = None
|
||||||
|
|
||||||
|
|
||||||
|
class Trainer:
    """
    Trainer for fine-tuning CLIP on logo recognition.

    Features:
    - Mixed precision training (FP16)
    - Gradient accumulation
    - Gradient checkpointing (via model)
    - Cosine annealing LR scheduler
    - Early stopping
    - Checkpoint saving/loading
    - Evaluation during training
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: TrainingConfig,
        logger: Optional[logging.Logger] = None,
    ):
        """
        Initialize the trainer.

        Args:
            model: LogoFineTunedCLIP model
            train_loader: Training dataloader
            val_loader: Validation dataloader
            config: Training configuration
            logger: Optional logger instance
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.logger = logger or logging.getLogger(__name__)

        # Device setup
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model.to(self.device)
        self.logger.info(f"Using device: {self.device}")

        # Optimizer - only trainable parameters
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        self.logger.info(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")

        self.optimizer = AdamW(
            trainable_params,
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )

        # Learning rate scheduler
        total_steps = len(train_loader) * config.max_epochs
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config.learning_rate,
            total_steps=total_steps,
            pct_start=config.warmup_steps / total_steps if total_steps > 0 else 0.1,
            anneal_strategy="cos",
        )

        # Mixed precision training
        self.use_amp = config.mixed_precision and AMP_AVAILABLE and self.device.type == "cuda"
        if self.use_amp:
            self.scaler = GradScaler()
            self.logger.info("Mixed precision training enabled")
        else:
            self.scaler = None
            if config.mixed_precision and not AMP_AVAILABLE:
                self.logger.warning("Mixed precision requested but not available")

        # Loss function
        self.criterion = get_loss_function(
            loss_type=config.loss_type,
            temperature=config.temperature,
            triplet_margin=config.triplet_margin,
        )

        # Evaluator
        self.evaluator = EmbeddingEvaluator()

        # Training state
        self.epoch = 0
        self.global_step = 0
        self.best_val_loss = float("inf")
        self.best_val_separation = float("-inf")
        self.patience_counter = 0
        self.training_history = []

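    # Worked example of the scheduler bookkeeping above (illustrative numbers,
    # not from any shipped config): with 500 train batches per epoch and
    # max_epochs=10, total_steps = 5000; warmup_steps=500 then gives
    # pct_start=0.1, i.e. the LR ramps up over the first 10% of the schedule
    # before cosine-annealing back down. Note that total_steps counts
    # dataloader batches, while scheduler.step() is called once per optimizer
    # update in _train_epoch, so with gradient_accumulation_steps > 1 only the
    # first 1/N of the one-cycle schedule is actually traversed as written.
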
    def train(self) -> Dict[str, float]:
        """
        Main training loop.

        Returns:
            Dict with final training metrics
        """
        self.logger.info("Starting training...")
        self.logger.info(f"  Epochs: {self.config.max_epochs}")
        self.logger.info(f"  Batch size: {self.config.batch_size}")
        self.logger.info(f"  Gradient accumulation: {self.config.gradient_accumulation_steps}")
        self.logger.info(f"  Effective batch: {self.config.effective_batch_size}")
        self.logger.info(f"  Learning rate: {self.config.learning_rate}")

        start_time = time.time()

        for epoch in range(self.epoch, self.config.max_epochs):
            self.epoch = epoch
            self.logger.info(f"\nEpoch {epoch + 1}/{self.config.max_epochs}")

            # Training epoch
            train_metrics = self._train_epoch()
            self.logger.info(
                f"Train - Loss: {train_metrics['loss']:.4f}, "
                f"LR: {train_metrics['lr']:.2e}"
            )

            # Validation
            if (epoch + 1) % self.config.eval_every_n_epochs == 0:
                val_metrics = self._validate()
                self.logger.info(
                    f"Val - Loss: {val_metrics['loss']:.4f}, "
                    f"Pos Sim: {val_metrics['mean_pos_sim']:.3f}, "
                    f"Neg Sim: {val_metrics['mean_neg_sim']:.3f}, "
                    f"Separation: {val_metrics['separation']:.3f}"
                )

                # Record history
                self.training_history.append({
                    "epoch": epoch + 1,
                    "train_loss": train_metrics["loss"],
                    "val_loss": val_metrics["loss"],
                    "val_separation": val_metrics["separation"],
                    "val_pos_sim": val_metrics["mean_pos_sim"],
                    "val_neg_sim": val_metrics["mean_neg_sim"],
                })

                # Checkpointing based on separation (primary) or loss (secondary)
                improved = False
                if val_metrics["separation"] > self.best_val_separation + self.config.min_delta:
                    self.best_val_separation = val_metrics["separation"]
                    improved = True
                elif val_metrics["loss"] < self.best_val_loss - self.config.min_delta:
                    self.best_val_loss = val_metrics["loss"]
                    improved = True

                if improved:
                    self.patience_counter = 0
                    self._save_checkpoint("best.pt")
                    self.logger.info("New best model saved!")
                else:
                    self.patience_counter += 1

                # Early stopping
                if self.patience_counter >= self.config.patience:
                    self.logger.info(
                        f"Early stopping triggered at epoch {epoch + 1} "
                        f"(no improvement for {self.config.patience} epochs)"
                    )
                    break

            # Periodic checkpoint
            if (epoch + 1) % self.config.save_every_n_epochs == 0:
                self._save_checkpoint(f"epoch_{epoch + 1}.pt")

        # Training complete
        total_time = time.time() - start_time
        self.logger.info(f"\nTraining completed in {total_time / 60:.1f} minutes")

        # Load best model
        best_path = Path(self.config.checkpoint_dir) / "best.pt"
        if best_path.exists():
            self.load_checkpoint("best.pt")
            self.logger.info("Loaded best model checkpoint")

        return {
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "total_epochs": self.epoch + 1,
            "total_time_minutes": total_time / 60,
        }

    def _train_epoch(self) -> Dict[str, float]:
        """Run a single training epoch."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        accumulation_steps = 0

        progress_bar = tqdm(
            self.train_loader,
            desc=f"Epoch {self.epoch + 1}",
            leave=False,
        )

        self.optimizer.zero_grad()

        for batch_idx, (images, labels) in enumerate(progress_bar):
            images = images.to(self.device)
            labels = labels.to(self.device)

            # Forward pass with mixed precision
            if self.use_amp:
                with autocast():
                    embeddings = self.model(images)
                    loss = self.criterion(embeddings, labels)
                    loss = loss / self.config.gradient_accumulation_steps

                self.scaler.scale(loss).backward()
            else:
                embeddings = self.model(images)
                loss = self.criterion(embeddings, labels)
                loss = loss / self.config.gradient_accumulation_steps
                loss.backward()

            accumulation_steps += 1

            # Optimizer step after accumulation
            if accumulation_steps >= self.config.gradient_accumulation_steps:
                if self.use_amp:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.optimizer.zero_grad()
                self.scheduler.step()
                self.global_step += 1
                accumulation_steps = 0

            total_loss += loss.item() * self.config.gradient_accumulation_steps
            num_batches += 1

            # Update progress bar
            progress_bar.set_postfix({
                "loss": total_loss / num_batches,
                "lr": self.scheduler.get_last_lr()[0],
            })

            # Logging
            if (batch_idx + 1) % self.config.log_every_n_steps == 0:
                self.logger.debug(
                    f"Step {self.global_step}: loss={total_loss / num_batches:.4f}"
                )

        return {
            "loss": total_loss / max(num_batches, 1),
            "lr": self.scheduler.get_last_lr()[0],
        }

    def _validate(self) -> Dict[str, float]:
        """Run validation and compute metrics."""
        self.model.eval()
        total_loss = 0.0
        all_embeddings = []
        all_labels = []

        with torch.no_grad():
            for images, labels in tqdm(self.val_loader, desc="Validating", leave=False):
                images = images.to(self.device)
                labels = labels.to(self.device)

                if self.use_amp:
                    with autocast():
                        embeddings = self.model(images)
                        loss = self.criterion(embeddings, labels)
                else:
                    embeddings = self.model(images)
                    loss = self.criterion(embeddings, labels)

                total_loss += loss.item()
                all_embeddings.append(embeddings.cpu())
                all_labels.append(labels.cpu())

        # Combine batches
        all_embeddings = torch.cat(all_embeddings, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        # Compute embedding quality metrics
        metrics = self.evaluator.compute_metrics(all_embeddings, all_labels)
        metrics["loss"] = total_loss / max(len(self.val_loader), 1)

        return metrics

    def _save_checkpoint(self, filename: str) -> None:
        """Save training checkpoint."""
        checkpoint_dir = Path(self.config.checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        checkpoint = {
            "epoch": self.epoch,
            "global_step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "patience_counter": self.patience_counter,
            "training_history": self.training_history,
            "config": self.config.__dict__,
        }

        if self.scaler is not None:
            checkpoint["scaler_state_dict"] = self.scaler.state_dict()

        torch.save(checkpoint, checkpoint_dir / filename)
        self.logger.debug(f"Saved checkpoint: {filename}")

    def load_checkpoint(self, filename: str) -> None:
        """Load training checkpoint."""
        checkpoint_path = Path(self.config.checkpoint_dir) / filename
        if not checkpoint_path.exists():
            self.logger.warning(f"Checkpoint not found: {checkpoint_path}")
            return

        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        self.epoch = checkpoint["epoch"]
        self.global_step = checkpoint["global_step"]
        self.best_val_loss = checkpoint["best_val_loss"]
        self.best_val_separation = checkpoint.get("best_val_separation", float("-inf"))
        self.patience_counter = checkpoint.get("patience_counter", 0)
        self.training_history = checkpoint.get("training_history", [])

        if self.scaler is not None and "scaler_state_dict" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])

        self.logger.info(f"Resumed from epoch {self.epoch + 1}")

    def export_model(self, output_dir: Optional[str] = None) -> str:
        """
        Export the trained model for inference.

        Args:
            output_dir: Output directory (uses config.output_dir if not specified)

        Returns:
            Path to exported model directory
        """
        output_dir = output_dir or self.config.output_dir
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Save model
        self.model.save_pretrained(output_dir)

        # Save training config
        config_path = output_path / "training_config.json"
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f, indent=2)

        # Save training history
        history_path = output_path / "training_history.json"
        with open(history_path, "w") as f:
            json.dump(self.training_history, f, indent=2)

        self.logger.info(f"Model exported to: {output_path}")
        return str(output_path)

    def get_training_summary(self) -> Dict:
        """Get summary of training."""
        return {
            "epochs_completed": self.epoch + 1,
            "global_steps": self.global_step,
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "history": self.training_history,
        }

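To see how the pieces fit, here is a self-contained sketch that drives the `Trainer` on synthetic data. The random tensors stand in for `LogoContrastiveDataset` batches, `TrainingConfig()` is assumed to supply defaults for every field, and the LoRA values are illustrative; the project's real entry point is `train_clip_logo.py`, not this snippet.

```python
import logging

import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import CLIPModel

from training.config import TrainingConfig
from training.model import LogoFineTunedCLIP
from training.trainer import Trainer

# Synthetic stand-in for the contrastive dataset: 64 random 224x224 "logos"
# spread over 8 fake brand labels, just to exercise the training loop.
images = torch.randn(64, 3, 224, 224)
labels = torch.randint(0, 8, (64,))
train_loader = DataLoader(TensorDataset(images, labels), batch_size=16, shuffle=True)
val_loader = DataLoader(TensorDataset(images, labels), batch_size=16)

# Wrap CLIP's vision encoder as in training/model.py (illustrative LoRA settings).
clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
model = LogoFineTunedCLIP(
    vision_model=clip.vision_model,
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    freeze_layers=0,
    use_gradient_checkpointing=False,
)

config = TrainingConfig()  # assumed to provide sensible defaults for every field

trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    logger=logging.getLogger("clip_finetune"),
)

results = trainer.train()        # dict with best_val_loss, best_val_separation, timings
export_dir = trainer.export_model()
print(results, export_dir)
```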