Compare commits: 1551360028...master (25 commits)

Commits: f2ae80c9e5, 8b67b50d19, 5ce6265a90, 512f678310, f598866d37, 91d1c9cd59, ea6fcec9ce, f777b049a3, 49f982611a, 78f46f04bf, b5432c9ef7, 440e8fcdb4, 2f28aa6052, 569285f664, c086e8bbf7, 304d743df8, 55abb1217c, 14a1bda3fa, 32bfefc022, f74d4b6981, 6685af72d9, 1bf9985def, e5482a2d9e, 99e5781c91, 44e8b6ae7d

---

**CLIP_FINETUNING.md** (new file, 301 lines)

# CLIP Fine-Tuning for Logo Recognition

This document describes the CLIP fine-tuning pipeline for improving logo embedding similarity using the LogoDet-3K dataset.

## Overview

The fine-tuning approach uses **contrastive learning** with **LoRA** (Low-Rank Adaptation) to train CLIP's vision encoder for better logo similarity matching while maintaining compatibility with the existing `DetectLogosDETR` class.

**Goal**: Improve F1 from ~60% to >72% on logo matching tasks.

## Files Created

### Training Module (`training/`)

| File | Description |
|------|-------------|
| `__init__.py` | Module exports |
| `config.py` | `TrainingConfig` dataclass with all hyperparameters |
| `dataset.py` | `LogoContrastiveDataset` with logo-level splits and augmentations |
| `model.py` | `LogoFineTunedCLIP` wrapper with LoRA support |
| `losses.py` | `InfoNCELoss`, `TripletLoss`, `SupConLoss`, `CombinedLoss` |
| `trainer.py` | Training loop with mixed precision, checkpointing, early stopping |
| `evaluation.py` | `EmbeddingEvaluator` for validation metrics |

### Scripts

| File | Description |
|------|-------------|
| `train_clip_logo.py` | Main training entry point |
| `export_model.py` | Export trained models to HuggingFace-compatible format |

### Configuration

| File | Description |
|------|-------------|
| `configs/jetson_orin.yaml` | Training config optimized for Jetson Orin AGX |

## Prerequisites

1. **Install dependencies**:

```bash
uv sync
```

2. **Prepare test data** (if not already done):

```bash
uv run python prepare_test_data.py
```

This creates:
- `reference_logos/` - Cropped logo images organized by category/brand
- `test_images/` - Full images for testing
- `test_data_mapping.db` - SQLite database with mappings

## Training

### Basic Training

```bash
uv run python train_clip_logo.py --config configs/jetson_orin.yaml
```

### Training with Overrides

```bash
uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
    --learning-rate 5e-6 \
    --max-epochs 30 \
    --batch-size 8
```

### Resume from Checkpoint

```bash
uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
    --resume checkpoints/epoch_10.pt
```

### Training Output

- Checkpoints saved to `checkpoints/`
- Best model saved as `checkpoints/best.pt`
- Final model exported to `models/logo_detection/clip_finetuned/`

## Configuration Options

Key parameters in `configs/jetson_orin.yaml`:

```yaml
# Model
base_model: "openai/clip-vit-large-patch14"
lora_r: 16              # LoRA rank (0 to disable)
lora_alpha: 32          # LoRA scaling factor
freeze_layers: 12       # Freeze first N transformer layers

# Batch construction
batch_size: 16
logos_per_batch: 32     # Different logos per batch
samples_per_logo: 4     # Samples per logo (creates positive pairs)
gradient_accumulation_steps: 8  # Effective batch = 128

# Training
learning_rate: 1.0e-5
max_epochs: 20
mixed_precision: true
temperature: 0.07       # InfoNCE temperature

# Early stopping
patience: 5
min_delta: 0.001
```
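
For orientation, here is a minimal sketch of how a YAML file like this could be mapped onto the `TrainingConfig` dataclass. The field subset and the `load_config` helper below are illustrative assumptions; the real `training/config.py` may differ.

```python
# Hypothetical loader sketch; the actual training/config.py may differ.
from dataclasses import dataclass, fields

import yaml


@dataclass
class TrainingConfig:
    base_model: str = "openai/clip-vit-large-patch14"
    lora_r: int = 16
    lora_alpha: int = 32
    freeze_layers: int = 12
    batch_size: int = 16
    logos_per_batch: int = 32
    samples_per_logo: int = 4
    gradient_accumulation_steps: int = 8
    learning_rate: float = 1.0e-5
    max_epochs: int = 20
    mixed_precision: bool = True
    temperature: float = 0.07
    patience: int = 5
    min_delta: float = 0.001


def load_config(path: str) -> TrainingConfig:
    """Read a YAML file and keep only the keys TrainingConfig defines."""
    with open(path) as f:
        raw = yaml.safe_load(f) or {}
    known = {f.name for f in fields(TrainingConfig)}
    return TrainingConfig(**{k: v for k, v in raw.items() if k in known})
```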

## Evaluation

### Test Fine-Tuned Model

**Important**: The fine-tuned model requires a higher threshold (0.82) than the baseline (0.75).

```bash
uv run python test_logo_detection.py -n 50 \
    -e models/logo_detection/clip_finetuned \
    -t 0.82 \
    --matching-method multi-ref \
    --seed 42
```

### Compare with Baseline

```bash
# Baseline CLIP (threshold 0.75)
uv run python test_logo_detection.py -n 50 \
    -e openai/clip-vit-large-patch14 \
    -t 0.75 \
    --matching-method multi-ref \
    --seed 42

# Fine-tuned model (threshold 0.82)
uv run python test_logo_detection.py -n 50 \
    -e models/logo_detection/clip_finetuned \
    -t 0.82 \
    --matching-method multi-ref \
    --seed 42
```

### Threshold Selection

The fine-tuned model requires a **higher similarity threshold** than baseline CLIP because contrastive learning pushed the similarities of non-matching logos much lower, shifting the overall score distribution.

#### Similarity Distribution Analysis

| Metric | Baseline | Fine-tuned |
|--------|----------|------------|
| Wrong logos, mean similarity | 0.66 | **0.44** |
| Wrong logos above 0.75 | 23.2% | **0.6%** |
| Correct logos, mean similarity | 0.75 | 0.64 |
| Optimal threshold | 0.756 | **0.819** |
| F1 at optimal threshold | 67.1% | **71.9%** |

**Key insight**: The fine-tuned model dramatically reduced similarities to wrong logos (mean 0.66 to 0.44). At threshold 0.75 it therefore correctly rejects far more non-matches, but it needs a higher threshold to avoid false positives from the scores that bunch up just above 0.75.

#### Analyze Similarity Distribution

To find the optimal threshold for your model:

```bash
# Run detailed similarity analysis
./analyze_similarity_distribution.sh --model finetuned

# Or analyze both models
./analyze_similarity_distribution.sh --model both
```

This outputs distribution statistics and suggests an optimal threshold based on the data.
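
As a rough sketch of what that analysis amounts to, the snippet below sweeps thresholds over two score arrays (similarities of correct matches and of wrong matches, as collected from the `--similarity-details` output) and reports the F1-optimal threshold. It is illustrative only; the script's actual implementation may differ.

```python
import numpy as np


def best_threshold(correct: np.ndarray, wrong: np.ndarray) -> tuple[float, float]:
    """Sweep thresholds and return (threshold, F1) with the highest F1."""
    best_t, best_f1 = 0.0, 0.0
    for t in np.arange(0.50, 0.95, 0.001):
        tp = int((correct >= t).sum())  # correct matches accepted
        fp = int((wrong >= t).sum())    # wrong matches accepted
        fn = int((correct < t).sum())   # correct matches rejected
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        if f1 > best_f1:
            best_t, best_f1 = float(t), f1
    return best_t, best_f1
```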

### Expected Metrics

| Metric | Baseline (t=0.75) | Fine-tuned (t=0.82) |
|--------|-------------------|---------------------|
| Precision | ~49% | >65% |
| Recall | ~77% | >70% |
| F1 Score | ~60% | >70% |

Training metrics to monitor (see the sketch after this list for how they can be computed):
- Mean positive similarity: target > 0.85
- Mean negative similarity: target < 0.50
- Embedding separation: target > 0.35
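
A minimal sketch of how those three quantities can be computed from a validation batch of L2-normalized embeddings and their logo labels (the actual `EmbeddingEvaluator` may compute them differently):

```python
import torch


def separation_metrics(embeddings: torch.Tensor, labels: torch.Tensor) -> dict[str, float]:
    """Mean positive/negative cosine similarity and their gap.

    embeddings: (N, D), assumed L2-normalized; labels: (N,) logo ids.
    Assumes the batch contains at least one positive pair.
    """
    sim = embeddings @ embeddings.T                       # pairwise cosine similarities
    same = labels.unsqueeze(0) == labels.unsqueeze(1)     # same-logo mask
    eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)
    pos = sim[same & ~eye]                                # same logo, different sample
    neg = sim[~same]                                      # different logos
    return {
        "mean_positive_similarity": pos.mean().item(),
        "mean_negative_similarity": neg.mean().item(),
        "separation": (pos.mean() - neg.mean()).item(),
    }
```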

## Export Model

To export a checkpoint to HuggingFace format:

```bash
uv run python export_model.py \
    --checkpoint checkpoints/best.pt \
    --output models/logo_detection/clip_finetuned
```

With LoRA weight merging (reduces inference overhead):

```bash
uv run python export_model.py \
    --checkpoint checkpoints/best.pt \
    --output models/logo_detection/clip_finetuned \
    --merge-lora
```

## Using Fine-Tuned Model with DetectLogosDETR

The fine-tuned model works as a drop-in replacement:

```python
from logo_detection_detr import DetectLogosDETR

# Use fine-tuned model
detector = DetectLogosDETR(
    logger=logger,
    embedding_model="models/logo_detection/clip_finetuned",
)

# Or use baseline for comparison
detector_baseline = DetectLogosDETR(
    logger=logger,
    embedding_model="openai/clip-vit-large-patch14",
)
```

## Architecture Details

### Training Approach

1. **Contrastive Learning**: Uses InfoNCE loss to maximize similarity between embeddings of the same logo while minimizing similarity to different logos (see the loss sketch after this list).

2. **LoRA (Low-Rank Adaptation)**: Adds small trainable matrices to attention layers instead of fine-tuning all weights. This is memory-efficient and prevents catastrophic forgetting.

3. **Layer Freezing**: Freezes the first 12 of 24 transformer layers to preserve CLIP's low-level visual features while adapting high-level semantics.

4. **Logo-Level Splits**: Splits data by logo brand (not by image) to test generalization to unseen logos.
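
The following is a minimal sketch of a supervised InfoNCE-style loss over a batch of image embeddings, where every other sample of the same logo in the batch is treated as a positive. The real `InfoNCELoss` in `training/losses.py` may differ in its details:

```python
import torch
import torch.nn.functional as F


def info_nce_loss(embeddings: torch.Tensor, labels: torch.Tensor,
                  temperature: float = 0.07) -> torch.Tensor:
    """Pull same-logo embeddings together, push different logos apart."""
    z = F.normalize(embeddings, dim=-1)
    logits = z @ z.T / temperature                   # (N, N) cosine similarities / T
    n = len(labels)
    eye = torch.eye(n, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(eye, float("-inf"))  # exclude self-similarity

    pos_mask = (labels.unsqueeze(0) == labels.unsqueeze(1)) & ~eye
    log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)
    # Average log-probability over each sample's positives, then over the batch
    per_sample = (log_prob.masked_fill(~pos_mask, 0.0).sum(dim=1)
                  / pos_mask.sum(dim=1).clamp(min=1))
    return -per_sample.mean()
```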

### Batch Construction

Each batch contains:
- K different logo brands (default: 32)
- M samples per brand (default: 4)
- Total samples: K × M = 128

This ensures positive pairs (same logo) exist within each batch for contrastive learning.
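
A minimal sketch of how such K×M batches could be sampled. The real `LogoContrastiveDataset` batching logic may differ; `samples_by_brand` here is a hypothetical mapping from brand name to the dataset indices of its samples:

```python
import random
from typing import Iterator


def kxm_batches(samples_by_brand: dict[str, list[int]],
                logos_per_batch: int = 32,
                samples_per_logo: int = 4,
                seed: int = 42) -> Iterator[list[int]]:
    """Yield batches of dataset indices containing K brands x M samples each."""
    rng = random.Random(seed)
    brands = [b for b, idxs in samples_by_brand.items() if len(idxs) >= samples_per_logo]
    rng.shuffle(brands)
    for i in range(0, len(brands) - logos_per_batch + 1, logos_per_batch):
        batch: list[int] = []
        for brand in brands[i:i + logos_per_batch]:
            batch.extend(rng.sample(samples_by_brand[brand], samples_per_logo))
        yield batch
```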

### Data Augmentation

Medium-strength augmentations (a torchvision sketch follows this list):
- Random horizontal flip
- Random rotation (±15°)
- Color jitter (brightness, contrast, saturation)
- Random affine transforms
- Random grayscale (10% of images)
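
A sketch of what this "medium" pipeline might look like with torchvision; the exact parameters used in `training/dataset.py` may differ:

```python
from torchvision import transforms

# Illustrative "medium" pipeline; the real dataset code may use different values.
medium_augmentations = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomAffine(degrees=0, translate=(0.05, 0.05), scale=(0.9, 1.1)),
    transforms.RandomGrayscale(p=0.1),
])
```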

## Troubleshooting

### Out of Memory

Reduce batch size and increase gradient accumulation:

```bash
uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
    --batch-size 8 \
    --gradient-accumulation-steps 16
```

### Slow Training

Ensure mixed precision is enabled:

```bash
uv run python train_clip_logo.py --config configs/jetson_orin.yaml
# mixed_precision: true is default in jetson_orin.yaml
```

### No Improvement

Try adjusting:
- Lower learning rate: `--learning-rate 5e-6`
- Higher temperature: `--temperature 0.1`
- Different loss: edit the config to use `loss_type: "combined"`

### Import Error for Fine-Tuned Model

Ensure the `training/` module is on your Python path:

```bash
export PYTHONPATH="${PYTHONPATH}:/data/dev.python/logo_test"
```

## Dependencies Added

The following were added to `pyproject.toml`:

```toml
peft>=0.7.0          # LoRA support
pyyaml>=6.0          # Config file parsing
torchvision>=0.20.0  # Image transforms
```

---

**CLOUD_TRAINING.md** (new file, 269 lines)

# Cloud GPU Training for CLIP Fine-Tuning

This document provides guidance on using cloud GPU instances (e.g., RunPod) for faster CLIP fine-tuning compared to local training on the Jetson Orin AGX.

## Training Time Comparison

Local training on the Jetson Orin AGX takes approximately 24 hours. Cloud GPUs offer significantly faster training:

| GPU | VRAM | Est. Training Time | Hourly Rate | Est. Total Cost |
|-----|------|--------------------|-------------|-----------------|
| **RTX 4090** | 24GB | 4-6 hours | $0.59/hr | **$2.40-$3.50** |
| **RTX 3090** | 24GB | 5-7 hours | $0.39/hr | **$2.00-$2.75** |
| **A100 80GB** | 80GB | 2-3 hours | $1.99/hr | **$4.00-$6.00** |
| **L40S** | 48GB | 3-4 hours | $0.89/hr | **$2.70-$3.60** |
| **H100 80GB** | 80GB | 1.5-2 hours | $1.99/hr | **$3.00-$4.00** |

*Prices from RunPod Community Cloud as of January 2025. Rates may vary.*

## Recommendations

### Best Value: RTX 4090 ($0.59/hr)
- 24GB VRAM is sufficient for ViT-L/14 with LoRA
- Good balance of speed and cost
- Widely available on Community Cloud
- **Total cost: ~$3 for complete training**

### Best Speed: H100 80GB ($1.99/hr)
- Fastest training (1.5-2 hours)
- 80GB VRAM allows larger batch sizes
- Can increase `batch_size` to 32+ and reduce `gradient_accumulation_steps`
- **Total cost: ~$3-4**

### Budget Option: RTX 3090 ($0.39/hr)
- Cheapest hourly rate
- 24GB VRAM works fine
- Slightly slower than the 4090
- **Total cost: ~$2-3**

## Cloud-Optimized Configurations

### RTX 4090 / RTX 3090 (24GB VRAM)

Create `configs/cloud_rtx4090.yaml`:

```yaml
# Optimized for 24GB VRAM cloud GPUs
base_model: "openai/clip-vit-large-patch14"

# Dataset paths
dataset_dir: "LogoDet-3K"
reference_dir: "reference_logos"
db_path: "test_data_mapping.db"

# Data splits
train_split: 0.7
val_split: 0.15
test_split: 0.15

# Larger batches for faster training
batch_size: 32
logos_per_batch: 32
samples_per_logo: 4
gradient_accumulation_steps: 4  # Effective batch = 128
num_workers: 8

# Model architecture
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
freeze_layers: 12
use_gradient_checkpointing: true

# Training
learning_rate: 1.0e-5
weight_decay: 0.01
warmup_steps: 500
max_epochs: 20
mixed_precision: true

# Loss
temperature: 0.07
loss_type: "infonce"

# Early stopping
patience: 5
min_delta: 0.001

# Output
checkpoint_dir: "checkpoints"
output_dir: "models/logo_detection/clip_finetuned"
save_every_n_epochs: 5

# Logging
log_every_n_steps: 10
eval_every_n_epochs: 1

seed: 42
use_augmentation: true
augmentation_strength: "medium"
```

### A100 / H100 (80GB VRAM)

Create `configs/cloud_a100.yaml`:

```yaml
# Optimized for 80GB VRAM cloud GPUs (A100, H100)
base_model: "openai/clip-vit-large-patch14"

# Dataset paths
dataset_dir: "LogoDet-3K"
reference_dir: "reference_logos"
db_path: "test_data_mapping.db"

# Data splits
train_split: 0.7
val_split: 0.15
test_split: 0.15

# Maximum batch sizes for 80GB VRAM
batch_size: 64
logos_per_batch: 32
samples_per_logo: 4
gradient_accumulation_steps: 2  # Effective batch = 128
num_workers: 8

# Model architecture (can disable gradient checkpointing with 80GB)
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
freeze_layers: 12
use_gradient_checkpointing: false  # Not needed with 80GB

# Training
learning_rate: 1.0e-5
weight_decay: 0.01
warmup_steps: 500
max_epochs: 20
mixed_precision: true

# Loss
temperature: 0.07
loss_type: "infonce"

# Early stopping
patience: 5
min_delta: 0.001

# Output
checkpoint_dir: "checkpoints"
output_dir: "models/logo_detection/clip_finetuned"
save_every_n_epochs: 5

# Logging
log_every_n_steps: 10
eval_every_n_epochs: 1

seed: 42
use_augmentation: true
augmentation_strength: "medium"
```

## RunPod Quick Start

### 1. Create a Pod

1. Go to [RunPod](https://www.runpod.io/)
2. Select a GPU (RTX 4090 recommended)
3. Choose a PyTorch template (CUDA 12.x)
4. Set volume size: 50GB (for dataset + models)

### 2. Setup Environment

```bash
# Connect via SSH or web terminal

# Install dependencies
pip install peft pyyaml torchvision transformers tqdm pillow

# Clone your repository (or upload files)
git clone <your-repo-url>
cd logo_test

# Or use runpodctl to sync files
# runpodctl send logo_test/
```

### 3. Prepare Data

If data isn't already prepared:

```bash
# This creates reference_logos/ and test_data_mapping.db
python prepare_test_data.py
```

### 4. Run Training

```bash
# For RTX 4090
python train_clip_logo.py --config configs/cloud_rtx4090.yaml

# For A100/H100
python train_clip_logo.py --config configs/cloud_a100.yaml

# Or with command-line overrides
python train_clip_logo.py --config configs/jetson_orin.yaml \
    --batch-size 32 \
    --gradient-accumulation-steps 4 \
    --num-workers 8
```

### 5. Download Results

```bash
# Export the trained model
python export_model.py \
    --checkpoint checkpoints/best.pt \
    --output models/logo_detection/clip_finetuned

# Download to your local machine
# Option 1: Use runpodctl
runpodctl receive models/logo_detection/clip_finetuned

# Option 2: SCP
scp -r root@<pod-ip>:/workspace/logo_test/models/logo_detection/clip_finetuned ./

# Option 3: Compress and download via web
tar -czvf clip_finetuned.tar.gz models/logo_detection/clip_finetuned
```

## Cost Optimization Tips

### Use Spot/Interruptible Instances
- Community Cloud GPUs are already cheaper
- Some providers offer spot pricing for additional savings
- Save checkpoints frequently (`save_every_n_epochs: 2`)

### Minimize Storage Costs
- RunPod charges $0.10/GB/month for container disk
- Use network volumes only if needed
- Delete pods when training completes

### Monitor Training
- Watch for early convergence (training may finish before 20 epochs)
- Early stopping will save time and cost if there is no improvement

### Batch Training Runs
- Test the configuration locally first (1-2 epochs)
- Run full training on the cloud only when the config is validated

## Cost Comparison Summary

| Option | Time | Cost | Best For |
|--------|------|------|----------|
| Jetson Orin (local) | ~24 hrs | Free* | No cloud dependency |
| RTX 3090 (RunPod) | ~6 hrs | ~$2.50 | Lowest cost |
| RTX 4090 (RunPod) | ~5 hrs | ~$3.00 | Best value |
| L40S (RunPod) | ~3.5 hrs | ~$3.00 | Good balance |
| A100 80GB (RunPod) | ~2.5 hrs | ~$5.00 | Large batches |
| H100 80GB (RunPod) | ~1.5 hrs | ~$3.50 | Fastest |

*Local training has an electricity cost but no cloud fees.

## References

- [RunPod Pricing](https://www.runpod.io/pricing)
- [RunPod RTX 4090](https://www.runpod.io/gpu-models/rtx-4090)
- [RunPod Documentation](https://docs.runpod.io/)

---

**README.md** (modified; 121 lines)

@@ -2,6 +2,110 @@

A testing framework for evaluating logo detection accuracy using DETR (DEtection TRansformer) and CLIP (Contrastive Language-Image Pre-training) models.

## Burnley Test: Averaged Embeddings with DINOv2

A targeted test using `DetectLogosEmbeddings` to detect two specific logos (barnfield and vertu) in 516 Burnley match images. Reference embeddings are averaged across all images in each reference directory, and matching uses margin-based comparison (margin = 0.05).

**Test command:**
```bash
uv run python test_burnley_detection.py -e dinov2 -t 0.7 --margin 0.05 --output-file results_average_embeddings.txt
```

**Results (DINOv2, threshold 0.70, margin 0.05):**

| Metric | Value |
|--------|-------|
| True Positives | 28 |
| False Positives | 36 |
| False Negatives | 125 |
| Total Expected | 146 |
| **Precision** | **43.8%** |
| **Recall** | **19.2%** |
| **F1 Score** | **26.7%** |

Ground truth is derived from filename prefixes: `vertu_` (vertu logo), `barnfield_` (barnfield logo), `barnfield+vertu_` (both logos). Images without these prefixes are treated as negatives.
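
A minimal sketch of that prefix-based labeling; the test script's actual parsing may differ:

```python
from pathlib import Path


def expected_logos(image_path: str) -> set[str]:
    """Derive the expected logo set from a Burnley image filename prefix."""
    name = Path(image_path).name
    if name.startswith("barnfield+vertu_"):
        return {"barnfield", "vertu"}
    if name.startswith("barnfield_"):
        return {"barnfield"}
    if name.startswith("vertu_"):
        return {"vertu"}
    return set()  # no prefix: treated as a negative image
```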

Low recall suggests many logos go undetected by DETR or fall below the similarity threshold. The relatively low precision indicates that DINOv2 averaged embeddings struggle to discriminate between the two logos in this domain. Further tuning of the threshold, margin, and embedding model (e.g. CLIP or SigLIP) may improve results.

---

## Recommended Settings

Based on extensive testing with the LogoDet-3K dataset, these are the optimal settings:

| Parameter | Recommended Value | Notes |
|-----------|-------------------|-------|
| **Matching Method** | `multi-ref` | Best balance of precision and recall |
| **Similarity Aggregation** | `max` (default) | Max outperforms mean aggregation |
| **Embedding Model** | `openai/clip-vit-large-patch14` | Significantly outperforms DINOv2 |
| **CLIP Threshold** | `0.70` | Good precision/recall balance |
| **DETR Threshold** | `0.50` | Default detection confidence |
| **Margin** | `0.05` | Reduces false positives |
| **Refs per Logo** | `7-10` | More references = better accuracy |
| **Preprocessing** | `default` | Best precision; letterbox/stretch hurt precision |

**Example command with recommended settings:**
```bash
uv run python test_logo_detection.py \
    --matching-method multi-ref \
    --refs-per-logo 10 \
    --threshold 0.70 \
    --margin 0.05 \
    --use-max-similarity
```

### Performance Benchmarks

With recommended settings (multi-ref max, threshold 0.70, margin 0.05):

| Refs/Logo | Precision | Recall | F1 Score |
|-----------|-----------|--------|----------|
| 1 | 45.8% | 65.9% | 54.0% |
| 3 | 40.5% | 72.4% | 51.9% |
| 5 | 47.2% | 72.6% | 57.2% |
| 7 | **51.0%** | **79.9%** | **62.3%** |
| 10 | 50.2% | 81.6% | 62.1% |

**Key findings:**
- More reference images per logo consistently improves recall
- 7+ refs provide the best precision/recall balance
- Diminishing returns beyond 10 refs

### Matching Method Comparison

| Method | Precision | Recall | F1 | Use Case |
|--------|-----------|--------|-----|----------|
| `simple` | 1.3% | 203%* | 2.5% | Not recommended (too many FPs) |
| `margin` | 69.8% | 16.3% | 26.4% | High precision, low recall |
| `multi-ref` (mean) | 51.8% | 63.1% | 56.9% | Balanced |
| `multi-ref` (max) | **51.8%** | **75.3%** | **61.4%** | **Best overall** |

*The simple method returns all matches above the threshold, causing many duplicate detections (hence recall above 100%).

### Embedding Model Comparison

| Model | Precision | Recall | F1 | Recommendation |
|-------|-----------|--------|-----|----------------|
| `openai/clip-vit-large-patch14` | **49.1%** | **77.0%** | **59.9%** | **Recommended** |
| `facebook/dinov2-small` | 22.4% | 42.8% | 29.5% | Not recommended |
| `facebook/dinov2-large` | 32.2% | 28.5% | 30.2% | Not recommended |

CLIP significantly outperforms DINOv2 for logo matching tasks.

### Preprocessing Mode Comparison

| Mode | Precision | Recall | F1 | Notes |
|------|-----------|--------|-----|-------|
| `default` | **50.2%** | 81.6% | 62.1% | **Recommended** - best precision |
| `letterbox` | 42.4% | 119%* | 62.6% | Higher recall but worse precision |
| `stretch` | 34.5% | 113%* | 52.9% | Not recommended |

*Recall >100% indicates multiple detections per expected logo.

**Recommendation:** Use `default` preprocessing. While letterbox shows a marginally higher F1, it has significantly worse precision (more false positives).

---

## Overview

This project provides tools to:

@@ -97,9 +201,9 @@ uv run python test_logo_detection.py -n 50 --seed 42
 | `--clear-cache` | False | Clear embedding cache before running |
 
 **Matching Methods:**
-- `simple` - Returns all logos above threshold (baseline, most permissive)
-- `margin` - Requires margin over second-best match (reduces false positives)
-- `multi-ref` - Aggregates scores across multiple reference images per logo
+- `simple` - Returns all logos above threshold (not recommended - too many false positives)
+- `margin` - Requires margin over second-best match (high precision, low recall)
+- `multi-ref` - **Recommended.** Aggregates scores across multiple reference images per logo
 
 See `--help` for all options.

@@ -114,13 +218,18 @@ See `--help` for all options.
 
 # Compare embedding models (CLIP vs DINOv2)
 ./run_model_comparison.sh
+
+# Test different refs-per-logo values
+./run_refs_per_logo_test.sh
 ```
 
 | Script | Purpose | Output File |
 |--------|---------|-------------|
-| `run_comparison_tests.sh` | Compare all 4 matching methods | `comparison_results.txt` |
-| `run_threshold_tests.sh` | Test threshold/margin combinations | `threshold_test_results.txt` |
-| `run_model_comparison.sh` | Compare CLIP vs DINOv2 models | `model_comparison_results.txt` |
+| `run_comparison_tests.sh` | Compare matching methods | `test_results/comparison_*.txt` |
+| `run_threshold_tests.sh` | Test threshold/margin combinations | `test_results/threshold_*.txt` |
+| `run_model_comparison.sh` | Compare CLIP vs DINOv2 models | `test_results/model_comparison_results.txt` |
+| `run_refs_per_logo_test.sh` | Test refs-per-logo values | `test_results/refs_per_logo_analysis.txt` |
+| `run_preprocess_test.sh` | Compare preprocessing modes | `test_results/preprocessing_comparison.txt` |
 
 ## Project Structure

---

**analyze_similarity_distribution.sh** (new executable file, 141 lines)

```bash
#!/bin/bash
#
# Analyze similarity distribution for baseline and fine-tuned models.
#
# This script runs the test with --similarity-details to output detailed
# statistics about how the models score matches vs non-matches.
#
# Usage:
#   ./analyze_similarity_distribution.sh
#   ./analyze_similarity_distribution.sh --model finetuned
#   ./analyze_similarity_distribution.sh --model baseline
#

set -e

# Default parameters
NUM_LOGOS="${NUM_LOGOS:-50}"
SEED="${SEED:-42}"
THRESHOLD="${THRESHOLD:-0.75}"
REFS_PER_LOGO="${REFS_PER_LOGO:-3}"
MARGIN="${MARGIN:-0.05}"
MODEL="${MODEL:-both}"

# Model paths
BASELINE_MODEL="openai/clip-vit-large-patch14"
FINETUNED_MODEL="models/logo_detection/clip_finetuned"

# Output directory
OUTPUT_DIR="similarity_analysis"
TIMESTAMP=$(date +%Y%m%d_%H%M%S)

# Parse command line arguments
while [[ $# -gt 0 ]]; do
    case $1 in
        -n|--num-logos)
            NUM_LOGOS="$2"
            shift 2
            ;;
        -s|--seed)
            SEED="$2"
            shift 2
            ;;
        -t|--threshold)
            THRESHOLD="$2"
            shift 2
            ;;
        --model)
            MODEL="$2"
            shift 2
            ;;
        --finetuned-path)
            FINETUNED_MODEL="$2"
            shift 2
            ;;
        -h|--help)
            echo "Usage: $0 [OPTIONS]"
            echo ""
            echo "Options:"
            echo "  -n, --num-logos NUM      Number of logos to test (default: 50)"
            echo "  -s, --seed SEED          Random seed (default: 42)"
            echo "  -t, --threshold VAL      Similarity threshold (default: 0.75)"
            echo "  --model MODEL            Which model: 'baseline', 'finetuned', or 'both' (default: both)"
            echo "  --finetuned-path PATH    Path to fine-tuned model"
            echo "  -h, --help               Show this help message"
            exit 0
            ;;
        *)
            echo "Unknown option: $1"
            exit 1
            ;;
    esac
done

# Create output directory
mkdir -p "${OUTPUT_DIR}"

echo "============================================================"
echo "SIMILARITY DISTRIBUTION ANALYSIS"
echo "============================================================"
echo ""
echo "Parameters:"
echo "  Number of logos: ${NUM_LOGOS}"
echo "  Random seed:     ${SEED}"
echo "  Threshold:       ${THRESHOLD}"
echo "  Refs per logo:   ${REFS_PER_LOGO}"
echo "  Margin:          ${MARGIN}"
echo "  Model:           ${MODEL}"
echo ""

# Common test arguments
TEST_ARGS=(
    -n "${NUM_LOGOS}"
    -s "${SEED}"
    -t "${THRESHOLD}"
    --refs-per-logo "${REFS_PER_LOGO}"
    --margin "${MARGIN}"
    --matching-method multi-ref
    --similarity-details
    --clear-cache
)

run_analysis() {
    local model_name="$1"
    local model_path="$2"
    local output_file="${OUTPUT_DIR}/${model_name}_similarity_${TIMESTAMP}.txt"

    echo "============================================================"
    echo "Analyzing: ${model_name}"
    echo "Model:     ${model_path}"
    echo "Output:    ${output_file}"
    echo "============================================================"
    echo ""

    uv run python test_logo_detection.py \
        "${TEST_ARGS[@]}" \
        -e "${model_path}" \
        2>&1 | tee "${output_file}"

    echo ""
    echo "Results saved to: ${output_file}"
    echo ""
}

# Run analysis based on model selection
if [[ "${MODEL}" == "baseline" ]] || [[ "${MODEL}" == "both" ]]; then
    run_analysis "baseline" "${BASELINE_MODEL}"
fi

if [[ "${MODEL}" == "finetuned" ]] || [[ "${MODEL}" == "both" ]]; then
    if [ ! -d "${FINETUNED_MODEL}" ]; then
        echo "Warning: Fine-tuned model not found at ${FINETUNED_MODEL}"
        echo "Skipping fine-tuned model analysis."
    else
        run_analysis "finetuned" "${FINETUNED_MODEL}"
    fi
fi

echo "============================================================"
echo "Analysis complete!"
echo "Results saved to: ${OUTPUT_DIR}/"
echo "============================================================"
```

---

**compare_finetuned_vs_baseline.sh** (new executable file, 191 lines)

```bash
#!/bin/bash
#
# Compare fine-tuned CLIP model against baseline CLIP for logo recognition.
#
# This script runs the same test suite on both models and outputs results
# for easy comparison.
#
# Usage:
#   ./compare_finetuned_vs_baseline.sh
#   ./compare_finetuned_vs_baseline.sh --num-logos 100
#

set -e

# Default parameters
NUM_LOGOS="${NUM_LOGOS:-50}"
SEED="${SEED:-42}"
THRESHOLD="${THRESHOLD:-0.7}"
DETR_THRESHOLD="${DETR_THRESHOLD:-0.5}"
REFS_PER_LOGO="${REFS_PER_LOGO:-3}"
MARGIN="${MARGIN:-0.05}"
POSITIVE_SAMPLES="${POSITIVE_SAMPLES:-5}"
NEGATIVE_SAMPLES="${NEGATIVE_SAMPLES:-20}"

# Model paths
BASELINE_MODEL="openai/clip-vit-large-patch14"
FINETUNED_MODEL="models/logo_detection/clip_finetuned"

# Output files
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
OUTPUT_DIR="comparison_results"
BASELINE_OUTPUT="${OUTPUT_DIR}/baseline_${TIMESTAMP}.txt"
FINETUNED_OUTPUT="${OUTPUT_DIR}/finetuned_${TIMESTAMP}.txt"
SUMMARY_OUTPUT="${OUTPUT_DIR}/comparison_summary_${TIMESTAMP}.txt"

# Parse command line arguments
while [[ $# -gt 0 ]]; do
    case $1 in
        -n|--num-logos)
            NUM_LOGOS="$2"
            shift 2
            ;;
        -s|--seed)
            SEED="$2"
            shift 2
            ;;
        -t|--threshold)
            THRESHOLD="$2"
            shift 2
            ;;
        --refs-per-logo)
            REFS_PER_LOGO="$2"
            shift 2
            ;;
        --margin)
            MARGIN="$2"
            shift 2
            ;;
        --finetuned-model)
            FINETUNED_MODEL="$2"
            shift 2
            ;;
        -h|--help)
            echo "Usage: $0 [OPTIONS]"
            echo ""
            echo "Options:"
            echo "  -n, --num-logos NUM      Number of logos to test (default: 50)"
            echo "  -s, --seed SEED          Random seed for reproducibility (default: 42)"
            echo "  -t, --threshold VAL      Similarity threshold (default: 0.7)"
            echo "  --refs-per-logo NUM      Reference images per logo (default: 3)"
            echo "  --margin VAL             Margin for matching (default: 0.05)"
            echo "  --finetuned-model PATH   Path to fine-tuned model"
            echo "  -h, --help               Show this help message"
            exit 0
            ;;
        *)
            echo "Unknown option: $1"
            exit 1
            ;;
    esac
done

# Create output directory
mkdir -p "${OUTPUT_DIR}"

# Check if fine-tuned model exists
if [ ! -d "${FINETUNED_MODEL}" ]; then
    echo "Error: Fine-tuned model not found at ${FINETUNED_MODEL}"
    echo "Please train the model first using: uv run python train_clip_logo.py --config configs/jetson_orin.yaml"
    exit 1
fi

echo "============================================================"
echo "CLIP Logo Recognition: Fine-tuned vs Baseline Comparison"
echo "============================================================"
echo ""
echo "Parameters:"
echo "  Number of logos:  ${NUM_LOGOS}"
echo "  Random seed:      ${SEED}"
echo "  Threshold:        ${THRESHOLD}"
echo "  DETR threshold:   ${DETR_THRESHOLD}"
echo "  Refs per logo:    ${REFS_PER_LOGO}"
echo "  Margin:           ${MARGIN}"
echo "  Positive samples: ${POSITIVE_SAMPLES}"
echo "  Negative samples: ${NEGATIVE_SAMPLES}"
echo ""
echo "Models:"
echo "  Baseline:   ${BASELINE_MODEL}"
echo "  Fine-tuned: ${FINETUNED_MODEL}"
echo ""
echo "Output:"
echo "  Baseline results:   ${BASELINE_OUTPUT}"
echo "  Fine-tuned results: ${FINETUNED_OUTPUT}"
echo "  Summary:            ${SUMMARY_OUTPUT}"
echo ""

# Common test arguments
TEST_ARGS=(
    -n "${NUM_LOGOS}"
    -s "${SEED}"
    -t "${THRESHOLD}"
    -d "${DETR_THRESHOLD}"
    --refs-per-logo "${REFS_PER_LOGO}"
    --margin "${MARGIN}"
    --positive-samples "${POSITIVE_SAMPLES}"
    --negative-samples "${NEGATIVE_SAMPLES}"
    --matching-method multi-ref
    --clear-cache
)

# Run baseline test
echo "============================================================"
echo "Testing BASELINE model: ${BASELINE_MODEL}"
echo "============================================================"
echo ""

uv run python test_logo_detection.py \
    "${TEST_ARGS[@]}" \
    -e "${BASELINE_MODEL}" \
    2>&1 | tee "${BASELINE_OUTPUT}"

echo ""
echo "Baseline results saved to: ${BASELINE_OUTPUT}"
echo ""

# Run fine-tuned test
echo "============================================================"
echo "Testing FINE-TUNED model: ${FINETUNED_MODEL}"
echo "============================================================"
echo ""

uv run python test_logo_detection.py \
    "${TEST_ARGS[@]}" \
    -e "${FINETUNED_MODEL}" \
    2>&1 | tee "${FINETUNED_OUTPUT}"

echo ""
echo "Fine-tuned results saved to: ${FINETUNED_OUTPUT}"
echo ""

# Extract and compare key metrics
echo "============================================================"
echo "COMPARISON SUMMARY"
echo "============================================================" | tee "${SUMMARY_OUTPUT}"
echo "" | tee -a "${SUMMARY_OUTPUT}"
echo "Test Parameters:" | tee -a "${SUMMARY_OUTPUT}"
echo "  Logos: ${NUM_LOGOS}, Seed: ${SEED}, Threshold: ${THRESHOLD}" | tee -a "${SUMMARY_OUTPUT}"
echo "  Method: multi-ref, Refs/logo: ${REFS_PER_LOGO}, Margin: ${MARGIN}" | tee -a "${SUMMARY_OUTPUT}"
echo "" | tee -a "${SUMMARY_OUTPUT}"

echo "BASELINE (${BASELINE_MODEL}):" | tee -a "${SUMMARY_OUTPUT}"
grep -E "(Precision|Recall|F1 Score|True Positives|False Positives|False Negatives)" "${BASELINE_OUTPUT}" | head -6 | tee -a "${SUMMARY_OUTPUT}"
echo "" | tee -a "${SUMMARY_OUTPUT}"

echo "FINE-TUNED (${FINETUNED_MODEL}):" | tee -a "${SUMMARY_OUTPUT}"
grep -E "(Precision|Recall|F1 Score|True Positives|False Positives|False Negatives)" "${FINETUNED_OUTPUT}" | head -6 | tee -a "${SUMMARY_OUTPUT}"
echo "" | tee -a "${SUMMARY_OUTPUT}"

# Extract F1 scores for quick comparison
BASELINE_F1=$(grep "F1 Score" "${BASELINE_OUTPUT}" | head -1 | grep -oE "[0-9]+\.[0-9]+%" | head -1 || echo "N/A")
FINETUNED_F1=$(grep "F1 Score" "${FINETUNED_OUTPUT}" | head -1 | grep -oE "[0-9]+\.[0-9]+%" | head -1 || echo "N/A")

echo "------------------------------------------------------------" | tee -a "${SUMMARY_OUTPUT}"
echo "F1 SCORE COMPARISON:" | tee -a "${SUMMARY_OUTPUT}"
echo "  Baseline:   ${BASELINE_F1}" | tee -a "${SUMMARY_OUTPUT}"
echo "  Fine-tuned: ${FINETUNED_F1}" | tee -a "${SUMMARY_OUTPUT}"
echo "------------------------------------------------------------" | tee -a "${SUMMARY_OUTPUT}"
echo "" | tee -a "${SUMMARY_OUTPUT}"
echo "Full results saved to: ${OUTPUT_DIR}/" | tee -a "${SUMMARY_OUTPUT}"
echo ""
echo "Done!"
```

---

**configs/cloud_a100.yaml** (new file, 64 lines)

```yaml
# Training configuration optimized for cloud A100 / H100 (80GB VRAM)
#
# Usage:
#   python train_clip_logo.py --config configs/cloud_a100.yaml
#
# Estimated training time: 1.5-3 hours
# Estimated cost on RunPod: ~$3-6

# Base model
base_model: "openai/clip-vit-large-patch14"

# Dataset paths
dataset_dir: "LogoDet-3K"
reference_dir: "reference_logos"
db_path: "test_data_mapping.db"

# Data splits
train_split: 0.7
val_split: 0.15
test_split: 0.15

# Maximum batch sizes for 80GB VRAM
batch_size: 64
logos_per_batch: 32
samples_per_logo: 4
gradient_accumulation_steps: 2  # Effective batch = 128
num_workers: 8

# Model architecture (no gradient checkpointing needed with 80GB)
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
freeze_layers: 12
use_gradient_checkpointing: false

# Training
learning_rate: 1.0e-5
weight_decay: 0.01
warmup_steps: 500
max_epochs: 20
mixed_precision: true

# Loss
temperature: 0.07
loss_type: "infonce"
triplet_margin: 0.3

# Early stopping
patience: 5
min_delta: 0.001

# Output
checkpoint_dir: "checkpoints"
output_dir: "models/logo_detection/clip_finetuned"
save_every_n_epochs: 2  # Save more frequently for cloud

# Logging
log_every_n_steps: 10
eval_every_n_epochs: 1

seed: 42
use_hard_negatives: false
use_augmentation: true
augmentation_strength: "medium"
```

---

**configs/cloud_rtx4090.yaml** (new file, 64 lines)

```yaml
# Training configuration optimized for cloud RTX 4090 / RTX 3090 (24GB VRAM)
#
# Usage:
#   python train_clip_logo.py --config configs/cloud_rtx4090.yaml
#
# Estimated training time: 4-6 hours
# Estimated cost on RunPod: ~$3

# Base model
base_model: "openai/clip-vit-large-patch14"

# Dataset paths
dataset_dir: "LogoDet-3K"
reference_dir: "reference_logos"
db_path: "test_data_mapping.db"

# Data splits
train_split: 0.7
val_split: 0.15
test_split: 0.15

# Larger batches for faster training on 24GB VRAM
batch_size: 32
logos_per_batch: 32
samples_per_logo: 4
gradient_accumulation_steps: 4  # Effective batch = 128
num_workers: 8

# Model architecture
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
freeze_layers: 12
use_gradient_checkpointing: true

# Training
learning_rate: 1.0e-5
weight_decay: 0.01
warmup_steps: 500
max_epochs: 20
mixed_precision: true

# Loss
temperature: 0.07
loss_type: "infonce"
triplet_margin: 0.3

# Early stopping
patience: 5
min_delta: 0.001

# Output
checkpoint_dir: "checkpoints"
output_dir: "models/logo_detection/clip_finetuned"
save_every_n_epochs: 2  # Save more frequently for cloud

# Logging
log_every_n_steps: 10
eval_every_n_epochs: 1

seed: 42
use_hard_negatives: false
use_augmentation: true
augmentation_strength: "medium"
```

---

**configs/cloud_rtx4090_image_split.yaml** (new file, 70 lines)

```yaml
# Training configuration for RTX 4090 (24GB VRAM) with IMAGE-LEVEL splits
#
# Combines RTX 4090 hardware optimizations with image-level splitting and
# gentler contrastive learning for better generalization.
#
# Usage:
#   python train_clip_logo.py --config configs/cloud_rtx4090_image_split.yaml
#
# Estimated training time: 5-7 hours (more epochs than logo-level)
# Estimated cost on RunPod: ~$4

# Base model
base_model: "openai/clip-vit-large-patch14"

# Dataset paths
dataset_dir: "LogoDet-3K"
reference_dir: "reference_logos"
db_path: "test_data_mapping.db"

# Data split configuration - IMAGE LEVEL
# Each logo brand will have images in all splits, allowing the model
# to see some examples of each brand during training.
split_level: "image"
train_split: 0.7
val_split: 0.15
test_split: 0.15

# Larger batches for faster training on 24GB VRAM
batch_size: 32
logos_per_batch: 32
samples_per_logo: 4
gradient_accumulation_steps: 4  # Effective batch = 128
num_workers: 8

# Model architecture
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
freeze_layers: 12
use_gradient_checkpointing: true

# Training - GENTLER settings for better generalization
learning_rate: 5.0e-6   # Reduced from 1e-5
weight_decay: 0.01
warmup_steps: 500
max_epochs: 30          # More epochs with slower learning
mixed_precision: true

# Loss - HIGHER temperature for softer contrastive learning
temperature: 0.15       # Increased from 0.07
loss_type: "infonce"
triplet_margin: 0.2     # Reduced from 0.3

# Early stopping - more patience with gentler learning
patience: 7
min_delta: 0.001

# Output - separate directory for image-split model
checkpoint_dir: "checkpoints_image_split"
output_dir: "models/logo_detection/clip_finetuned_image_split"
save_every_n_epochs: 2  # Save frequently for cloud

# Logging
log_every_n_steps: 10
eval_every_n_epochs: 1

seed: 42
use_hard_negatives: false
use_augmentation: true
augmentation_strength: "medium"
```

---

**configs/image_level_splits.yaml** (new file, 78 lines)

```yaml
# Training configuration with IMAGE-LEVEL splits
#
# Unlike logo-level splits where test logos are completely unseen brands,
# image-level splits allow the model to see some images from each brand
# during training. This is less rigorous but more representative of
# real-world use where you have reference images for logos you want to detect.
#
# Also uses gentler contrastive learning settings to prevent over-separation.
#
# Usage:
#   uv run python train_clip_logo.py --config configs/image_level_splits.yaml

# Base model
base_model: "openai/clip-vit-large-patch14"

# Dataset paths (relative to project root)
dataset_dir: "LogoDet-3K"
reference_dir: "reference_logos"
db_path: "test_data_mapping.db"

# Data split configuration
# split_level: "image" means images are split, not logo brands
# This allows the test set to contain images from brands seen during training
split_level: "image"
train_split: 0.7
val_split: 0.15
test_split: 0.15

# Batch construction
batch_size: 16
logos_per_batch: 32
samples_per_logo: 4
gradient_accumulation_steps: 8
num_workers: 4

# Model architecture - same as before
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
freeze_layers: 12
use_gradient_checkpointing: true

# Training hyperparameters - GENTLER settings
learning_rate: 5.0e-6   # Reduced from 1e-5
weight_decay: 0.01
warmup_steps: 500
max_epochs: 30          # More epochs with slower learning
mixed_precision: true

# Loss function - HIGHER temperature for softer contrastive learning
temperature: 0.15       # Increased from 0.07
loss_type: "infonce"
triplet_margin: 0.2     # Reduced from 0.3

# Early stopping
patience: 7             # More patience with gentler learning
min_delta: 0.001

# Checkpoints and output
checkpoint_dir: "checkpoints_image_split"
output_dir: "models/logo_detection/clip_finetuned_image_split"
save_every_n_epochs: 5

# Logging
log_every_n_steps: 10
eval_every_n_epochs: 1

# Reproducibility
seed: 42

# Hard negative mining
use_hard_negatives: false
hard_negative_start_epoch: 10
hard_negatives_per_logo: 10

# Data augmentation
use_augmentation: true
augmentation_strength: "medium"
```

---

**configs/jetson_orin.yaml** (new file, 76 lines)

```yaml
# Training configuration optimized for Jetson Orin AGX (~64GB shared memory)
#
# Usage:
#   uv run python train_clip_logo.py --config configs/jetson_orin.yaml

# Base model
base_model: "openai/clip-vit-large-patch14"

# Dataset paths (relative to project root)
dataset_dir: "LogoDet-3K"
reference_dir: "reference_logos"
db_path: "test_data_mapping.db"

# Data split ratios (logo-level split for generalization testing)
train_split: 0.7
val_split: 0.15
test_split: 0.15

# Batch construction
# - batch_size: Number of batches loaded at once (keep low for memory)
# - logos_per_batch: Different logo classes per contrastive batch
# - samples_per_logo: Samples of each logo (creates positive pairs)
# - Effective samples per step = logos_per_batch * samples_per_logo = 128
batch_size: 16
logos_per_batch: 32
samples_per_logo: 4
gradient_accumulation_steps: 8  # Effective batch = 128
num_workers: 4

# Model architecture
# LoRA enables memory-efficient fine-tuning by training low-rank adapters
# instead of full model weights
lora_r: 16                        # LoRA rank (0 to disable)
lora_alpha: 32                    # LoRA scaling factor
lora_dropout: 0.1                 # Dropout in LoRA layers
freeze_layers: 12                 # Freeze first 12 of 24 transformer layers
use_gradient_checkpointing: true  # Trade compute for memory

# Training hyperparameters
learning_rate: 1.0e-5   # Conservative LR for fine-tuning
weight_decay: 0.01      # L2 regularization
warmup_steps: 500       # LR warmup steps
max_epochs: 20          # Maximum training epochs
mixed_precision: true   # FP16 training for memory efficiency

# Loss function
# InfoNCE is the contrastive loss used in CLIP training
temperature: 0.07       # Similarity scaling (0.05-0.1 typical)
loss_type: "infonce"    # Options: infonce, supcon, triplet, combined
triplet_margin: 0.3     # Only used if loss_type is triplet

# Early stopping
patience: 5             # Stop if no improvement for N epochs
min_delta: 0.001        # Minimum improvement threshold

# Checkpoints and output
checkpoint_dir: "checkpoints"
output_dir: "models/logo_detection/clip_finetuned"
save_every_n_epochs: 5

# Logging
log_every_n_steps: 10
eval_every_n_epochs: 1

# Reproducibility
seed: 42

# Hard negative mining (advanced)
# Enable after initial training epochs for harder examples
use_hard_negatives: false
hard_negative_start_epoch: 5
hard_negatives_per_logo: 10

# Data augmentation
use_augmentation: true
augmentation_strength: "medium"  # light, medium, or strong
```

---

**export_model.py** (new file, 169 lines)

```python
#!/usr/bin/env python3
"""
Export a trained CLIP model to HuggingFace-compatible format.

This script converts a training checkpoint to a format that can be
loaded by DetectLogosDETR for inference.

Usage:
    uv run python export_model.py \
        --checkpoint checkpoints/best.pt \
        --output models/logo_detection/clip_finetuned

    # With custom base model
    uv run python export_model.py \
        --checkpoint checkpoints/best.pt \
        --output models/logo_detection/clip_finetuned \
        --base-model openai/clip-vit-large-patch14
"""

import argparse
import json
import logging
import sys
from pathlib import Path

import torch

from training.config import TrainingConfig
from training.model import create_model, LogoFineTunedCLIP


def setup_logging() -> logging.Logger:
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    return logging.getLogger(__name__)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Export trained CLIP model for inference",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to training checkpoint (.pt file)",
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output directory for exported model",
    )
    parser.add_argument(
        "--base-model",
        type=str,
        default=None,
        help="Base CLIP model (reads from checkpoint config if not specified)",
    )
    parser.add_argument(
        "--merge-lora",
        action="store_true",
        help="Merge LoRA weights into base model (reduces inference overhead)",
    )

    return parser.parse_args()


def main():
    args = parse_args()
    logger = setup_logging()

    logger.info("CLIP Model Export")
    logger.info("=" * 60)

    # Check checkpoint exists
    checkpoint_path = Path(args.checkpoint)
    if not checkpoint_path.exists():
        logger.error(f"Checkpoint not found: {checkpoint_path}")
        sys.exit(1)

    # Load checkpoint
    logger.info(f"Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # Get config from checkpoint
    if "config" in checkpoint:
        config_dict = checkpoint["config"]
        base_model = args.base_model or config_dict.get(
            "base_model", "openai/clip-vit-large-patch14"
        )
        lora_r = config_dict.get("lora_r", 16)
        lora_alpha = config_dict.get("lora_alpha", 32)
        freeze_layers = config_dict.get("freeze_layers", 12)
    else:
        base_model = args.base_model or "openai/clip-vit-large-patch14"
        lora_r = 16
        lora_alpha = 32
        freeze_layers = 12

    logger.info(f"Base model: {base_model}")
    logger.info(f"LoRA rank: {lora_r}")
    logger.info(f"Freeze layers: {freeze_layers}")

    # Create model with same architecture
    logger.info("Creating model architecture...")
    model, processor = create_model(
        base_model=base_model,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        freeze_layers=freeze_layers,
        use_gradient_checkpointing=False,  # Not needed for export
    )

    # Load weights
    logger.info("Loading trained weights...")
    model.load_state_dict(checkpoint["model_state_dict"])

    # Merge LoRA if requested
    if args.merge_lora and model.peft_applied:
        try:
            logger.info("Merging LoRA weights into base model...")
            model.vision_model = model.vision_model.merge_and_unload()
            model.peft_applied = False
            model.lora_r = 0
            logger.info("LoRA weights merged successfully")
        except Exception as e:
            logger.warning(f"Could not merge LoRA weights: {e}")
            logger.warning("Exporting with separate LoRA weights")

    # Create output directory
    output_path = Path(args.output)
    output_path.mkdir(parents=True, exist_ok=True)

    # Save model
    logger.info(f"Exporting to: {output_path}")
    model.save_pretrained(str(output_path))

    # Save processor config for reference
    processor.save_pretrained(str(output_path / "processor"))

    # Save additional metadata
    metadata = {
        "base_model": base_model,
        "source_checkpoint": str(checkpoint_path),
        "training_epochs": checkpoint.get("epoch", -1) + 1,
        "best_val_loss": checkpoint.get("best_val_loss", None),
        "best_val_separation": checkpoint.get("best_val_separation", None),
        "lora_merged": args.merge_lora and not model.peft_applied,
    }

    with open(output_path / "export_metadata.json", "w") as f:
        json.dump(metadata, f, indent=2)

    logger.info("\nExport complete!")
    logger.info(f"Model saved to: {output_path}")
    logger.info("\nTo use with DetectLogosDETR:")
    logger.info(f"  detector = DetectLogosDETR(embedding_model='{output_path}')")
    logger.info("\nOr with test_logo_detection.py:")
    logger.info(f"  uv run python test_logo_detection.py -e {output_path}")


if __name__ == "__main__":
    main()
```

---

**find_optimal_threshold.sh** (new executable file, 168 lines)

```bash
#!/bin/bash
#
# Find optimal similarity threshold for logo detection.
#
# Tests a range of thresholds and outputs precision/recall/F1 for each.
#
# Usage:
#   ./find_optimal_threshold.sh
#   ./find_optimal_threshold.sh --model finetuned
#   ./find_optimal_threshold.sh --model baseline
#   ./find_optimal_threshold.sh --thresholds "0.70 0.75 0.80 0.85"
#

set -e

# Default parameters
NUM_LOGOS="${NUM_LOGOS:-50}"
SEED="${SEED:-42}"
REFS_PER_LOGO="${REFS_PER_LOGO:-3}"
MARGIN="${MARGIN:-0.05}"
MODEL="${MODEL:-finetuned}"
USE_MAX_SIM="${USE_MAX_SIM:-false}"

# Default thresholds to test
THRESHOLDS="${THRESHOLDS:-0.70 0.72 0.74 0.76 0.78 0.80 0.82 0.84 0.86}"

# Model paths
BASELINE_MODEL="openai/clip-vit-large-patch14"
FINETUNED_MODEL="models/logo_detection/clip_finetuned"

# Output
OUTPUT_DIR="threshold_analysis"
TIMESTAMP=$(date +%Y%m%d_%H%M%S)

# Parse command line arguments
while [[ $# -gt 0 ]]; do
    case $1 in
        -n|--num-logos)
            NUM_LOGOS="$2"
            shift 2
            ;;
        -s|--seed)
            SEED="$2"
            shift 2
            ;;
        --model)
            MODEL="$2"
            shift 2
            ;;
        --thresholds)
            THRESHOLDS="$2"
            shift 2
            ;;
        --finetuned-path)
            FINETUNED_MODEL="$2"
            shift 2
            ;;
        --use-max-similarity)
            USE_MAX_SIM="true"
            shift
            ;;
        -h|--help)
            echo "Usage: $0 [OPTIONS]"
            echo ""
            echo "Options:"
            echo "  -n, --num-logos NUM        Number of logos to test (default: 50)"
            echo "  -s, --seed SEED            Random seed (default: 42)"
            echo "  --model MODEL              Which model: 'baseline' or 'finetuned' (default: finetuned)"
            echo "  --thresholds \"T1 T2 ...\"   Space-separated thresholds to test"
            echo "  --finetuned-path PATH      Path to fine-tuned model"
            echo "  --use-max-similarity       Use max instead of mean for multi-ref aggregation"
            echo "  -h, --help                 Show this help message"
            exit 0
            ;;
        *)
            echo "Unknown option: $1"
            exit 1
            ;;
    esac
done

# Select model path
if [[ "${MODEL}" == "baseline" ]]; then
    MODEL_PATH="${BASELINE_MODEL}"
else
    MODEL_PATH="${FINETUNED_MODEL}"
fi

# Check if fine-tuned model exists
if [[ "${MODEL}" == "finetuned" ]] && [ ! -d "${FINETUNED_MODEL}" ]; then
    echo "Error: Fine-tuned model not found at ${FINETUNED_MODEL}"
    exit 1
fi

# Create output directory
mkdir -p "${OUTPUT_DIR}"
OUTPUT_FILE="${OUTPUT_DIR}/${MODEL}_thresholds_${TIMESTAMP}.txt"

echo "============================================================"
echo "THRESHOLD OPTIMIZATION"
echo "============================================================"
echo ""
echo "Model:      ${MODEL} (${MODEL_PATH})"
echo "Thresholds: ${THRESHOLDS}"
echo "Logos:      ${NUM_LOGOS}"
echo "Seed:       ${SEED}"
echo "Max sim:    ${USE_MAX_SIM}"
echo "Output:     ${OUTPUT_FILE}"
echo ""

# Header for results
echo "============================================================" | tee "${OUTPUT_FILE}"
echo "THRESHOLD OPTIMIZATION RESULTS" | tee -a "${OUTPUT_FILE}"
echo "Model: ${MODEL} (${MODEL_PATH})" | tee -a "${OUTPUT_FILE}"
echo "============================================================" | tee -a "${OUTPUT_FILE}"
echo "" | tee -a "${OUTPUT_FILE}"
printf "%-10s %8s %8s %8s %8s %8s %8s\n" "Threshold" "TP" "FP" "FN" "Prec" "Recall" "F1" | tee -a "${OUTPUT_FILE}"
echo "--------------------------------------------------------------------" | tee -a "${OUTPUT_FILE}"

# Track best F1
BEST_F1=0
BEST_THRESHOLD=""

# Build extra args
EXTRA_ARGS=""
if [[ "${USE_MAX_SIM}" == "true" ]]; then
    EXTRA_ARGS="--use-max-similarity"
fi

# Test each threshold
for THRESHOLD in ${THRESHOLDS}; do
    # Run test and capture output
    OUTPUT=$(uv run python test_logo_detection.py \
        -n "${NUM_LOGOS}" \
        -s "${SEED}" \
        -t "${THRESHOLD}" \
        --refs-per-logo "${REFS_PER_LOGO}" \
        --margin "${MARGIN}" \
        --matching-method multi-ref \
        -e "${MODEL_PATH}" \
        ${EXTRA_ARGS} \
```
|
||||
2>/dev/null)
|
||||
|
||||
# Extract metrics
|
||||
TP=$(echo "${OUTPUT}" | grep "True Positives" | grep -oE "[0-9]+" | head -1)
|
||||
FP=$(echo "${OUTPUT}" | grep "False Positives" | grep -oE "[0-9]+" | head -1)
|
||||
FN=$(echo "${OUTPUT}" | grep "False Negatives" | grep -oE "[0-9]+" | head -1)
|
||||
PREC=$(echo "${OUTPUT}" | grep "Precision:" | grep -oE "[0-9]+\.[0-9]+%" | head -1)
|
||||
RECALL=$(echo "${OUTPUT}" | grep "Recall:" | grep -oE "[0-9]+\.[0-9]+%" | head -1)
|
||||
F1=$(echo "${OUTPUT}" | grep "F1 Score:" | grep -oE "[0-9]+\.[0-9]+%" | head -1)
|
||||
|
||||
# Print row
|
||||
printf "%-10s %8s %8s %8s %8s %8s %8s\n" "${THRESHOLD}" "${TP}" "${FP}" "${FN}" "${PREC}" "${RECALL}" "${F1}" | tee -a "${OUTPUT_FILE}"
|
||||
|
||||
# Track best F1
|
||||
F1_NUM=$(echo "${F1}" | tr -d '%')
|
||||
BEST_NUM=$(echo "${BEST_F1}" | tr -d '%')
|
||||
if (( $(echo "${F1_NUM} > ${BEST_NUM}" | bc -l) )); then
|
||||
BEST_F1="${F1}"
|
||||
BEST_THRESHOLD="${THRESHOLD}"
|
||||
fi
|
||||
done
|
||||
|
||||
echo "--------------------------------------------------------------------" | tee -a "${OUTPUT_FILE}"
|
||||
echo "" | tee -a "${OUTPUT_FILE}"
|
||||
echo "BEST THRESHOLD: ${BEST_THRESHOLD} (F1 = ${BEST_F1})" | tee -a "${OUTPUT_FILE}"
|
||||
echo "" | tee -a "${OUTPUT_FILE}"
|
||||
echo "Results saved to: ${OUTPUT_FILE}"
|
||||
@ -13,6 +13,7 @@ Supported embedding models:
|
||||
- DINOv2 models (facebook/dinov2-*): Self-supervised, excellent for visual similarity
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -48,6 +49,7 @@ class DetectLogosDETR:
|
||||
detr_threshold: float = 0.5,
|
||||
min_box_size: int = 20,
|
||||
nms_iou_threshold: float = 0.5,
|
||||
preprocess_mode: str = "default",
|
||||
):
|
||||
"""
|
||||
Initialize DETR and embedding models.
|
||||
@ -63,12 +65,17 @@ class DetectLogosDETR:
|
||||
detr_threshold: Confidence threshold for DETR detections (0-1)
|
||||
min_box_size: Minimum width/height in pixels for detected boxes (filters noise)
|
||||
nms_iou_threshold: IoU threshold for Non-Maximum Suppression
|
||||
preprocess_mode: Image preprocessing mode for CLIP:
|
||||
- "default": Use CLIP's default (resize shortest edge + center crop)
|
||||
- "letterbox": Pad to square with black bars, preserving aspect ratio
|
||||
- "stretch": Stretch to square (distorts aspect ratio)
|
||||
"""
|
||||
self.logger = logger
|
||||
self.detr_threshold = detr_threshold
|
||||
self.min_box_size = min_box_size
|
||||
self.nms_iou_threshold = nms_iou_threshold
|
||||
self.embedding_model_name = embedding_model
|
||||
self.preprocess_mode = preprocess_mode
|
||||
|
||||
# Set device
|
||||
self.device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
@ -100,6 +107,10 @@ class DetectLogosDETR:
|
||||
embedding_model, default_embedding_dir, "Embedding"
|
||||
)
|
||||
|
||||
# Check if this is a fine-tuned model
|
||||
if self._is_finetuned_model(embedding_model_path):
|
||||
self._load_finetuned_embedding_model(embedding_model_path)
|
||||
else:
|
||||
# Detect model type and initialize accordingly
|
||||
self.model_type = self._detect_model_type(embedding_model)
|
||||
self.logger.info(f"Loading {self.model_type} embedding model: {embedding_model_path}")
|
||||
@ -111,6 +122,8 @@ class DetectLogosDETR:
|
||||
self.embedding_model = AutoModel.from_pretrained(embedding_model_path).to(self.device)
|
||||
self.embedding_processor = AutoImageProcessor.from_pretrained(embedding_model_path)
|
||||
|
||||
if self.preprocess_mode != "default":
|
||||
self.logger.info(f"Image preprocessing mode: {self.preprocess_mode}")
|
||||
self.logger.info("DetectLogosDETR initialization complete")
|
||||
|
||||
def _detect_model_type(self, model_name: str) -> str:
|
||||
@ -124,6 +137,62 @@ class DetectLogosDETR:
|
||||
# Default to generic transformer for unknown models
|
||||
return "transformer"
|
||||
|
||||
def _is_finetuned_model(self, model_path: str) -> bool:
|
||||
"""Check if a model path points to a fine-tuned CLIP model."""
|
||||
config_path = Path(model_path) / "config.json"
|
||||
if config_path.exists():
|
||||
try:
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
return config.get("model_type") == "clip_logo_finetuned"
|
||||
except (json.JSONDecodeError, IOError):
|
||||
pass
|
||||
return False
|
||||
|
||||
def _load_finetuned_embedding_model(self, model_path: str) -> None:
|
||||
"""
|
||||
Load a fine-tuned CLIP model from the training module.
|
||||
|
||||
Args:
|
||||
model_path: Path to the fine-tuned model directory
|
||||
"""
|
||||
# Import the fine-tuned model class
|
||||
try:
|
||||
from training.model import LogoFineTunedCLIP
|
||||
except ImportError as e:
|
||||
self.logger.error(
|
||||
f"Cannot import training.model for fine-tuned model: {e}"
|
||||
)
|
||||
raise ImportError(
|
||||
"Fine-tuned model requires the training module. "
|
||||
"Ensure the training/ directory is in your Python path."
|
||||
) from e
|
||||
|
||||
# Load config
|
||||
config_path = Path(model_path) / "config.json"
|
||||
with open(config_path, "r") as f:
|
||||
config = json.load(f)
|
||||
|
||||
base_model = config.get("base_model", "openai/clip-vit-large-patch14")
|
||||
|
||||
self.logger.info(f"Loading fine-tuned CLIP model from: {model_path}")
|
||||
self.logger.info(f" Base model: {base_model}")
|
||||
|
||||
# Load model using the from_pretrained method
|
||||
self.embedding_model = LogoFineTunedCLIP.from_pretrained(
|
||||
model_path,
|
||||
base_model=base_model,
|
||||
device=self.device,
|
||||
)
|
||||
self.embedding_model.eval()
|
||||
|
||||
# Load processor from base model
|
||||
self.embedding_processor = CLIPProcessor.from_pretrained(base_model)
|
||||
|
||||
# Set model type for embedding extraction
|
||||
self.model_type = "clip_finetuned"
|
||||
self.logger.info("Fine-tuned CLIP model loaded successfully")
|
||||
|
||||
def _resolve_model_path(
|
||||
self, model_name_or_path: str, default_local_dir: str, model_type: str
|
||||
) -> str:
|
||||
@ -341,11 +410,51 @@ class DetectLogosDETR:
|
||||
|
||||
return self._get_embedding_pil(pil_image)
|
||||
|
||||
def _preprocess_image(self, pil_image: Image.Image, target_size: int = 224) -> Image.Image:
|
||||
"""
|
||||
Preprocess image based on the configured preprocessing mode.
|
||||
|
||||
Args:
|
||||
pil_image: PIL Image (RGB format)
|
||||
target_size: Target size for the square output (default 224 for CLIP)
|
||||
|
||||
Returns:
|
||||
Preprocessed PIL Image
|
||||
"""
|
||||
if self.preprocess_mode == "default":
|
||||
# Let the processor handle it (resize shortest edge + center crop)
|
||||
return pil_image
|
||||
|
||||
width, height = pil_image.size
|
||||
|
||||
if self.preprocess_mode == "letterbox":
|
||||
# Pad to square with black bars, preserving aspect ratio
|
||||
max_dim = max(width, height)
|
||||
|
||||
# Create a black square canvas
|
||||
new_image = Image.new("RGB", (max_dim, max_dim), (0, 0, 0))
|
||||
|
||||
# Paste the original image centered
|
||||
paste_x = (max_dim - width) // 2
|
||||
paste_y = (max_dim - height) // 2
|
||||
new_image.paste(pil_image, (paste_x, paste_y))
|
||||
|
||||
# Resize to target size
|
||||
return new_image.resize((target_size, target_size), Image.LANCZOS)
|
||||
|
||||
elif self.preprocess_mode == "stretch":
|
||||
# Stretch to square (distorts aspect ratio)
|
||||
return pil_image.resize((target_size, target_size), Image.LANCZOS)
|
||||
|
||||
else:
|
||||
# Unknown mode, return original
|
||||
return pil_image
|
||||
|
||||
def _get_embedding_pil(self, pil_image: Image.Image) -> torch.Tensor:
|
||||
"""
|
||||
Internal method to get embedding from PIL image.
|
||||
|
||||
Handles both CLIP and DINOv2 model types.
|
||||
Handles CLIP, fine-tuned CLIP, and DINOv2 model types.
|
||||
|
||||
Args:
|
||||
pil_image: PIL Image (RGB format)
|
||||
@ -353,6 +462,10 @@ class DetectLogosDETR:
|
||||
Returns:
|
||||
Normalized feature embedding (torch.Tensor)
|
||||
"""
|
||||
# Apply preprocessing if configured
|
||||
if self.preprocess_mode != "default":
|
||||
pil_image = self._preprocess_image(pil_image)
|
||||
|
||||
# Process image through the embedding model
|
||||
inputs = self.embedding_processor(images=pil_image, return_tensors="pt").to(self.device)
|
||||
|
||||
@ -360,6 +473,9 @@ class DetectLogosDETR:
|
||||
if self.model_type == "clip":
|
||||
# CLIP has a dedicated method for image features
|
||||
features = self.embedding_model.get_image_features(**inputs)
|
||||
elif self.model_type == "clip_finetuned":
|
||||
# Fine-tuned CLIP uses get_image_features or forward with pixel_values
|
||||
features = self.embedding_model.get_image_features(**inputs)
|
||||
else:
|
||||
# DINOv2 and other transformers use the CLS token or pooled output
|
||||
outputs = self.embedding_model(**inputs)
|
||||
@ -370,7 +486,8 @@ class DetectLogosDETR:
|
||||
# Use CLS token from last_hidden_state
|
||||
features = outputs.last_hidden_state[:, 0, :]
|
||||
|
||||
# Normalize for cosine similarity
|
||||
# Normalize for cosine similarity (fine-tuned model already normalizes)
|
||||
if self.model_type != "clip_finetuned":
|
||||
features = F.normalize(features, dim=-1)
|
||||
|
||||
return features
|
||||
|
||||
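With the fine-tuned loading path in place, `DetectLogosDETR` can be pointed either at a stock HuggingFace checkpoint or at the exported directory; `_is_finetuned_model` decides which branch is taken. A minimal usage sketch follows. The module name, the `detect()` call, and the constructor defaults not visible in this diff are assumptions based on how the class is invoked elsewhere in the repo (e.g. by `test_logo_detection.py`).

```python
# Sketch only: construct the detector against the exported fine-tuned model.
# Module name and detect() output are assumed, not confirmed by this diff.
import logging

import cv2

from logo_detection import DetectLogosDETR  # module name assumed

logging.basicConfig(level=logging.INFO)
detector = DetectLogosDETR(
    logger=logging.getLogger("logo_detection"),
    embedding_model="models/logo_detection/clip_finetuned",
    preprocess_mode="letterbox",  # pad to square, preserving aspect ratio
)

image = cv2.imread("test_images/example.jpg")  # placeholder test image
detections = detector.detect(image)
print(f"{len(detections)} candidate logo regions")
```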
364
logo_detection_embeddings.py
Normal file
@ -0,0 +1,364 @@
|
||||
"""
|
||||
Logo detection using DETR for object detection and selectable embedding models for feature matching.
|
||||
|
||||
This module provides a class for detecting logos in images using:
|
||||
1. DETR (DEtection TRansformer) for initial logo region detection
|
||||
2. Selectable embedding model (CLIP, DINOv2, or SigLIP) for feature extraction and matching
|
||||
|
||||
Key features:
|
||||
- Multiple reference images per logo entry, averaged into a single embedding
|
||||
- Cache-aware: averaged embeddings are only recalculated when the filenames list changes
|
||||
- Supports local model directories with fallback to HuggingFace
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from transformers import (
|
||||
AutoImageProcessor,
|
||||
AutoModel,
|
||||
AutoProcessor,
|
||||
CLIPModel,
|
||||
CLIPProcessor,
|
||||
Dinov2Model,
|
||||
pipeline,
|
||||
)
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
class DetectLogosEmbeddings:
|
||||
"""
|
||||
Logo detection class using DETR and a selectable embedding model.
|
||||
|
||||
This class detects logos in images by:
|
||||
1. Using DETR to find potential logo regions (bounding boxes)
|
||||
2. Extracting embeddings for each detected region using the selected model
|
||||
3. Comparing embeddings with averaged reference logo embeddings for identification
|
||||
|
||||
Supported embedding models:
|
||||
- clip: openai/clip-vit-large-patch14
|
||||
- dinov2: facebook/dinov2-base (recommended for visual similarity)
|
||||
- siglip: google/siglip-base-patch16-224
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logger,
|
||||
detr_model: str = "Pravallika6/detr-finetuned-logo-detection_v2",
|
||||
embedding_model_type: str = "dinov2",
|
||||
detr_threshold: float = 0.5,
|
||||
):
|
||||
"""
|
||||
Initialize DETR and embedding models.
|
||||
|
||||
Args:
|
||||
logger: Logger instance for logging
|
||||
detr_model: HuggingFace model name or local path for DETR object detection
|
||||
embedding_model_type: One of "clip", "dinov2", or "siglip"
|
||||
detr_threshold: Confidence threshold for DETR detections (0-1)
|
||||
"""
|
||||
self.logger = logger
|
||||
self.detr_threshold = detr_threshold
|
||||
self.embedding_model_type = embedding_model_type
|
||||
|
||||
# Set device
|
||||
self.device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
self.device_index = 0 if torch.cuda.is_available() else -1
|
||||
self.device = torch.device(self.device_str)
|
||||
|
||||
self.logger.info(
|
||||
f"Initializing DetectLogosEmbeddings on device: {self.device_str}, "
|
||||
f"embedding model: {embedding_model_type}"
|
||||
)
|
||||
|
||||
# --- DETR model ---
|
||||
default_detr_dir = os.environ.get(
|
||||
"LOGO_DETR_MODEL_DIR", "models/logo_detection/detr"
|
||||
)
|
||||
detr_model_path = self._resolve_model_path(detr_model, default_detr_dir, "DETR")
|
||||
|
||||
self.logger.info(f"Loading DETR model: {detr_model_path}")
|
||||
self.detr_pipe = pipeline(
|
||||
task="object-detection",
|
||||
model=detr_model_path,
|
||||
device=self.device_index,
|
||||
use_fast=True,
|
||||
)
|
||||
|
||||
# --- Embedding model ---
|
||||
self._load_embedding_model(embedding_model_type)
|
||||
|
||||
self.logger.info("DetectLogosEmbeddings initialization complete")
|
||||
|
||||
def _load_embedding_model(self, model_type: str) -> None:
|
||||
"""
|
||||
Load the selected embedding model.
|
||||
|
||||
Args:
|
||||
model_type: One of "clip", "dinov2", or "siglip"
|
||||
"""
|
||||
default_embedding_dir = os.environ.get(
|
||||
"LOGO_EMBEDDING_MODEL_DIR", f"models/logo_detection/{model_type}"
|
||||
)
|
||||
|
||||
if model_type == "clip":
|
||||
model_name = "openai/clip-vit-large-patch14"
|
||||
model_path = self._resolve_model_path(
|
||||
model_name, default_embedding_dir, "CLIP"
|
||||
)
|
||||
self.logger.info(f"Loading CLIP model: {model_path}")
|
||||
self._clip_model = CLIPModel.from_pretrained(model_path).to(self.device)
|
||||
self._clip_processor = CLIPProcessor.from_pretrained(model_path)
|
||||
self._clip_model.eval()
|
||||
|
||||
def embed_fn(pil_image):
|
||||
inputs = self._clip_processor(
|
||||
images=pil_image, return_tensors="pt"
|
||||
).to(self.device)
|
||||
with torch.no_grad():
|
||||
features = self._clip_model.get_image_features(**inputs)
|
||||
return F.normalize(features, dim=-1)
|
||||
|
||||
elif model_type == "dinov2":
|
||||
model_name = "facebook/dinov2-base"
|
||||
model_path = self._resolve_model_path(
|
||||
model_name, default_embedding_dir, "DINOv2"
|
||||
)
|
||||
self.logger.info(f"Loading DINOv2 model: {model_path}")
|
||||
self._dinov2_model = Dinov2Model.from_pretrained(model_path).to(self.device)
|
||||
self._dinov2_processor = AutoImageProcessor.from_pretrained(model_path)
|
||||
self._dinov2_model.eval()
|
||||
|
||||
def embed_fn(pil_image):
|
||||
inputs = self._dinov2_processor(
|
||||
images=pil_image, return_tensors="pt"
|
||||
).to(self.device)
|
||||
with torch.no_grad():
|
||||
outputs = self._dinov2_model(**inputs)
|
||||
# Use CLS token embedding
|
||||
features = outputs.last_hidden_state[:, 0, :]
|
||||
return F.normalize(features, dim=-1)
|
||||
|
||||
elif model_type == "siglip":
|
||||
model_name = "google/siglip-base-patch16-224"
|
||||
model_path = self._resolve_model_path(
|
||||
model_name, default_embedding_dir, "SigLIP"
|
||||
)
|
||||
self.logger.info(f"Loading SigLIP model: {model_path}")
|
||||
self._siglip_model = AutoModel.from_pretrained(model_path).to(self.device)
|
||||
self._siglip_processor = AutoProcessor.from_pretrained(model_path)
|
||||
self._siglip_model.eval()
|
||||
|
||||
def embed_fn(pil_image):
|
||||
inputs = self._siglip_processor(
|
||||
images=pil_image, return_tensors="pt"
|
||||
).to(self.device)
|
||||
with torch.no_grad():
|
||||
features = self._siglip_model.get_image_features(**inputs)
|
||||
return F.normalize(features, dim=-1)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown embedding model type: {model_type}. "
|
||||
f"Use 'clip', 'dinov2', or 'siglip'"
|
||||
)
|
||||
|
||||
self._embed_fn = embed_fn
|
||||
|
||||
def _resolve_model_path(
|
||||
self, model_name_or_path: str, default_local_dir: str, model_type: str
|
||||
) -> str:
|
||||
"""
|
||||
Resolve model path, checking for local models before using HuggingFace.
|
||||
|
||||
Args:
|
||||
model_name_or_path: HuggingFace model name or absolute path
|
||||
default_local_dir: Default local directory to check
|
||||
model_type: Type of model (for logging)
|
||||
|
||||
Returns:
|
||||
Resolved model path (local path or HuggingFace model name)
|
||||
"""
|
||||
# If it's an absolute path, use it directly
|
||||
if os.path.isabs(model_name_or_path):
|
||||
if os.path.exists(model_name_or_path):
|
||||
self.logger.info(
|
||||
f"{model_type} model: Using local model at {model_name_or_path}"
|
||||
)
|
||||
return model_name_or_path
|
||||
else:
|
||||
self.logger.warning(
|
||||
f"{model_type} model: Local path {model_name_or_path} does not exist, "
|
||||
f"falling back to HuggingFace"
|
||||
)
|
||||
return model_name_or_path
|
||||
|
||||
# Check if default local directory exists
|
||||
if os.path.exists(default_local_dir):
|
||||
config_file = os.path.join(default_local_dir, "config.json")
|
||||
if os.path.exists(config_file):
|
||||
abs_path = os.path.abspath(default_local_dir)
|
||||
self.logger.info(
|
||||
f"{model_type} model: Found local model at {abs_path}"
|
||||
)
|
||||
return abs_path
|
||||
else:
|
||||
self.logger.warning(
|
||||
f"{model_type} model: Local directory {default_local_dir} exists but "
|
||||
f"is not a valid model (missing config.json)"
|
||||
)
|
||||
|
||||
# Use HuggingFace model name
|
||||
self.logger.info(
|
||||
f"{model_type} model: No local model found, will download from HuggingFace: "
|
||||
f"{model_name_or_path}"
|
||||
)
|
||||
return model_name_or_path
|
||||
|
||||
def detect(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Detect logos in an image and return bounding boxes with embeddings.
|
||||
|
||||
Args:
|
||||
image: OpenCV image (BGR format, numpy array)
|
||||
|
||||
Returns:
|
||||
List of dictionaries, each containing:
|
||||
- 'box': dict with 'xmin', 'ymin', 'xmax', 'ymax' (pixel coordinates)
|
||||
- 'score': DETR confidence score (float 0-1)
|
||||
- 'embedding': Feature embedding (torch.Tensor)
|
||||
- 'label': DETR predicted label (string)
|
||||
"""
|
||||
# Convert OpenCV BGR to RGB PIL Image
|
||||
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
pil_image = Image.fromarray(image_rgb)
|
||||
|
||||
# Run DETR detection
|
||||
predictions = self.detr_pipe(pil_image)
|
||||
|
||||
# Filter by threshold and add embeddings
|
||||
detections = []
|
||||
for pred in predictions:
|
||||
score = pred.get("score", 0.0)
|
||||
if score < self.detr_threshold:
|
||||
continue
|
||||
|
||||
box = pred.get("box", {})
|
||||
xmin = box.get("xmin", 0)
|
||||
ymin = box.get("ymin", 0)
|
||||
xmax = box.get("xmax", 0)
|
||||
ymax = box.get("ymax", 0)
|
||||
|
||||
# Extract bounding box region
|
||||
bbox_crop = pil_image.crop((xmin, ymin, xmax, ymax))
|
||||
|
||||
# Get embedding for this region
|
||||
embedding = self._embed_fn(bbox_crop)
|
||||
|
||||
detections.append(
|
||||
{
|
||||
"box": {"xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax},
|
||||
"score": score,
|
||||
"embedding": embedding,
|
||||
"label": pred.get("label", "logo"),
|
||||
}
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Detected {len(detections)} logos (threshold: {self.detr_threshold})"
|
||||
)
|
||||
return detections
|
||||
|
||||
def get_embedding(self, image: np.ndarray) -> torch.Tensor:
|
||||
"""
|
||||
Get embedding for a single reference logo image.
|
||||
|
||||
Args:
|
||||
image: OpenCV image (BGR format, numpy array)
|
||||
|
||||
Returns:
|
||||
Normalized feature embedding (torch.Tensor)
|
||||
"""
|
||||
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
pil_image = Image.fromarray(image_rgb)
|
||||
return self._embed_fn(pil_image)
|
||||
|
||||
def get_averaged_embedding(self, images: List[np.ndarray]) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Compute averaged embedding from multiple reference logo images.
|
||||
|
||||
Follows the averaging pattern from db_embeddings.py:
|
||||
1. Compute embedding for each image
|
||||
2. Stack and average across all images
|
||||
3. Re-normalize the averaged embedding
|
||||
|
||||
Args:
|
||||
images: List of OpenCV images (BGR format, numpy arrays)
|
||||
|
||||
Returns:
|
||||
Normalized averaged embedding (torch.Tensor, shape [1, D]),
|
||||
or None if no valid embeddings could be computed
|
||||
"""
|
||||
embeddings = []
|
||||
for img in images:
|
||||
try:
|
||||
emb = self.get_embedding(img)
|
||||
embeddings.append(emb)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to compute embedding for reference image: {e}")
|
||||
|
||||
if not embeddings:
|
||||
return None
|
||||
|
||||
# Stack: (N, D), average: (1, D), re-normalize
|
||||
stacked = torch.cat(embeddings, dim=0)
|
||||
avg_emb = stacked.mean(dim=0, keepdim=True)
|
||||
avg_emb = F.normalize(avg_emb, dim=-1)
|
||||
|
||||
self.logger.debug(
|
||||
f"Computed averaged embedding from {len(embeddings)} reference image(s)"
|
||||
)
|
||||
return avg_emb
|
||||
|
||||
def compare_embeddings(
|
||||
self, embedding1: torch.Tensor, embedding2: torch.Tensor
|
||||
) -> float:
|
||||
"""
|
||||
Compute cosine similarity between two embeddings.
|
||||
|
||||
Args:
|
||||
embedding1: First embedding (torch.Tensor)
|
||||
embedding2: Second embedding (torch.Tensor)
|
||||
|
||||
Returns:
|
||||
Cosine similarity score (float, range: -1 to 1, typically 0 to 1)
|
||||
"""
|
||||
# Ensure tensors are on the same device
|
||||
if embedding1.device != embedding2.device:
|
||||
embedding2 = embedding2.to(embedding1.device)
|
||||
|
||||
similarity = F.cosine_similarity(embedding1, embedding2, dim=-1)
|
||||
return similarity.item()
|
||||
|
||||
@staticmethod
|
||||
def make_filenames_hash(filenames: List[str]) -> str:
|
||||
"""
|
||||
Compute a deterministic hash of a filenames list.
|
||||
|
||||
Used for cache invalidation — if the filenames list changes,
|
||||
the hash changes, triggering re-computation of averaged embeddings.
|
||||
|
||||
Args:
|
||||
filenames: List of filename strings
|
||||
|
||||
Returns:
|
||||
16-character hex hash string
|
||||
"""
|
||||
canonical = json.dumps(sorted(filenames))
|
||||
return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16]
|
||||
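Since the class above is self-contained, a short end-to-end sketch may help: build an averaged reference embedding for one brand, detect logos in a test image, and compare each detection against the reference. Directory and image paths are placeholders, and 0.7 is simply the similarity threshold used elsewhere in the repo.

```python
# End-to-end sketch for DetectLogosEmbeddings (paths are placeholders).
import logging
from pathlib import Path

import cv2

from logo_detection_embeddings import DetectLogosEmbeddings

logging.basicConfig(level=logging.INFO)
detector = DetectLogosEmbeddings(
    logger=logging.getLogger("logo_embeddings"),
    embedding_model_type="dinov2",
    detr_threshold=0.5,
)

# Averaged reference embedding for a single brand
ref_paths = sorted(Path("reference_logos/some_brand").glob("*.jpg"))
ref_images = [img for img in (cv2.imread(str(p)) for p in ref_paths) if img is not None]
ref_emb = detector.get_averaged_embedding(ref_images)
if ref_emb is None:
    raise SystemExit("No usable reference images")

# Hash of the reference filenames, usable as a cache-invalidation key
print("refs hash:", DetectLogosEmbeddings.make_filenames_hash([p.name for p in ref_paths]))

# Detect and match
test_img = cv2.imread("test_images/example.jpg")
for det in detector.detect(test_img):
    sim = detector.compare_embeddings(det["embedding"], ref_emb)
    if sim >= 0.7:
        print(f"match at {det['box']} (DETR {det['score']:.2f}, sim {sim:.3f})")
```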
@ -113,11 +113,14 @@ def get_or_create_logo_name(cursor: sqlite3.Cursor, name: str) -> int:
|
||||
|
||||
|
||||
def main():
|
||||
# Paths
|
||||
dataset_dir = Path("/data/dev.python/logo_test/LogoDet-3K")
|
||||
reference_dir = Path("/data/dev.python/logo_test/reference_logos")
|
||||
test_images_dir = Path("/data/dev.python/logo_test/test_images")
|
||||
db_path = Path("/data/dev.python/logo_test/test_data_mapping.db")
|
||||
# Use script directory as base path for portability
|
||||
base_dir = Path(__file__).parent.resolve()
|
||||
|
||||
# Paths relative to script location
|
||||
dataset_dir = base_dir / "LogoDet-3K"
|
||||
reference_dir = base_dir / "reference_logos"
|
||||
test_images_dir = base_dir / "test_images"
|
||||
db_path = base_dir / "test_data_mapping.db"
|
||||
|
||||
# Ensure output directories exist
|
||||
reference_dir.mkdir(exist_ok=True)
|
||||
|
||||
@ -12,4 +12,7 @@ dependencies = [
|
||||
"tqdm>=4.67.1",
|
||||
"transformers>=4.57.3",
|
||||
"typing>=3.10.0.0",
|
||||
"peft>=0.7.0",
|
||||
"pyyaml>=6.0",
|
||||
"torchvision>=0.20.0",
|
||||
]
|
||||
|
||||
23
requirements-training.txt
Normal file
@ -0,0 +1,23 @@
|
||||
# Requirements for CLIP logo fine-tuning on RTX 4090
|
||||
#
|
||||
# Only includes packages not already installed on the training server.
|
||||
# Does NOT upgrade existing packages (torch, torchvision, numpy, pillow,
|
||||
# pyyaml, opencv-python) which are already installed and compatible.
|
||||
#
|
||||
# Usage:
|
||||
# pip install -r requirements-training.txt
|
||||
|
||||
# CLIP models and tokenizers
|
||||
transformers>=4.36.0
|
||||
|
||||
# LoRA fine-tuning
|
||||
peft>=0.7.0
|
||||
|
||||
# Progress bars
|
||||
tqdm>=4.66.0
|
||||
|
||||
# HuggingFace Hub for model downloads
|
||||
huggingface-hub>=0.19.0
|
||||
|
||||
# Accelerate for efficient training (optional but recommended)
|
||||
accelerate>=0.25.0
|
||||
52
results_average_embeddings.txt
Normal file
@ -0,0 +1,52 @@
|
||||
======================================================================
|
||||
BURNLEY LOGO DETECTION TEST
|
||||
Model: dinov2
|
||||
Method: Margin-based (margin=0.05)
|
||||
======================================================================
|
||||
Date: 2026-03-31 11:45:03
|
||||
|
||||
Configuration:
|
||||
Embedding model: dinov2
|
||||
Similarity threshold: 0.7
|
||||
DETR threshold: 0.5
|
||||
Matching margin: 0.05
|
||||
Test images processed: 516
|
||||
Reference logos: barnfield, vertu
|
||||
|
||||
Results:
|
||||
True Positives: 28
|
||||
False Positives: 36
|
||||
False Negatives: 125
|
||||
Total Expected: 146
|
||||
|
||||
Scores:
|
||||
Precision: 0.4375 (43.8%)
|
||||
Recall: 0.1918 (19.2%)
|
||||
F1 Score: 0.2667 (26.7%)
|
||||
|
||||
======================================================================
|
||||
BURNLEY LOGO DETECTION TEST
|
||||
Model: dinov2
|
||||
Method: Margin-based (margin=0.05)
|
||||
======================================================================
|
||||
Date: 2026-03-31 12:29:32
|
||||
|
||||
Configuration:
|
||||
Embedding model: dinov2
|
||||
Similarity threshold: 0.7
|
||||
DETR threshold: 0.5
|
||||
Matching margin: 0.05
|
||||
Test images processed: 516
|
||||
Reference logos: barnfield, vertu
|
||||
|
||||
Results:
|
||||
True Positives: 28
|
||||
False Positives: 36
|
||||
False Negatives: 125
|
||||
Total Expected: 146
|
||||
|
||||
Scores:
|
||||
Precision: 0.4375 (43.8%)
|
||||
Recall: 0.1918 (19.2%)
|
||||
F1 Score: 0.2667 (26.7%)
|
||||
|
||||
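The scores in this results file follow directly from the counts (the test script computes recall against the total expected matches rather than TP + FN). A quick check:

```python
# Reproduce the reported scores from the raw counts above.
tp, fp, total_expected = 28, 36, 146

precision = tp / (tp + fp)                           # 28 / 64  = 0.4375
recall = tp / total_expected                         # 28 / 146 ≈ 0.1918
f1 = 2 * precision * recall / (precision + recall)   # ≈ 0.2667

print(f"Precision: {precision:.4f}  Recall: {recall:.4f}  F1: {f1:.4f}")
```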
149
run_preprocess_test.sh
Executable file
@ -0,0 +1,149 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Test different image preprocessing modes to determine if they improve
|
||||
# CLIP embedding accuracy for logo matching.
|
||||
#
|
||||
# Preprocessing modes tested:
|
||||
# - default: CLIP's default (resize shortest edge + center crop)
|
||||
# - letterbox: Pad to square with black bars, preserving aspect ratio
|
||||
# - stretch: Stretch to square (distorts aspect ratio)
|
||||
#
|
||||
# Usage:
|
||||
# ./run_preprocess_test.sh
|
||||
#
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
OUTPUT_FILE="${SCRIPT_DIR}/test_results/preprocessing_comparison.txt"
|
||||
|
||||
# Model - baseline CLIP (testing preprocessing effect on standard model)
|
||||
MODEL="openai/clip-vit-large-patch14"
|
||||
|
||||
# Fixed parameters (same as refs_per_logo test for comparability)
|
||||
NUM_LOGOS=20
|
||||
REFS_PER_LOGO=10
|
||||
POSITIVE_SAMPLES=20
|
||||
NEGATIVE_SAMPLES=100
|
||||
MIN_MATCHING_REFS=1
|
||||
THRESHOLD=0.70
|
||||
MARGIN=0.05
|
||||
SEED=42
|
||||
|
||||
# Preprocessing modes to test
|
||||
MODES="default letterbox stretch"
|
||||
|
||||
# Create output directory if needed
|
||||
mkdir -p "${SCRIPT_DIR}/test_results"
|
||||
|
||||
# Clear output file and write header
|
||||
cat > "$OUTPUT_FILE" << EOF
|
||||
Image Preprocessing Comparison Test
|
||||
====================================
|
||||
Date: $(date)
|
||||
|
||||
Model: ${MODEL}
|
||||
Method: multi-ref (max)
|
||||
|
||||
Fixed Parameters:
|
||||
Number of logo brands: ${NUM_LOGOS}
|
||||
Refs per logo: ${REFS_PER_LOGO}
|
||||
Similarity threshold: ${THRESHOLD}
|
||||
Margin: ${MARGIN}
|
||||
Min matching refs: ${MIN_MATCHING_REFS}
|
||||
Positive samples/logo: ${POSITIVE_SAMPLES}
|
||||
Negative samples/logo: ${NEGATIVE_SAMPLES}
|
||||
Seed: ${SEED}
|
||||
|
||||
Testing preprocessing modes: ${MODES}
|
||||
|
||||
EOF
|
||||
|
||||
echo "Image Preprocessing Comparison Test"
|
||||
echo "===================================="
|
||||
echo "Model: ${MODEL}"
|
||||
echo "Testing preprocessing modes: ${MODES}"
|
||||
echo ""
|
||||
|
||||
# Results table header
|
||||
echo "Results Summary:" >> "$OUTPUT_FILE"
|
||||
echo "----------------" >> "$OUTPUT_FILE"
|
||||
printf "%-12s %8s %8s %8s %8s %8s %8s\n" "Mode" "TP" "FP" "FN" "Prec" "Recall" "F1" >> "$OUTPUT_FILE"
|
||||
echo "------------------------------------------------------------------------" >> "$OUTPUT_FILE"
|
||||
|
||||
# Track best result
|
||||
BEST_F1=0
|
||||
BEST_MODE="default"
|
||||
|
||||
for MODE in ${MODES}; do
|
||||
echo "=== Testing preprocess_mode=${MODE} ==="
|
||||
|
||||
# Clear cache to ensure fresh embeddings with new preprocessing
|
||||
rm -f "${SCRIPT_DIR}/.embedding_cache.pkl"
|
||||
|
||||
# Run test and capture output
|
||||
OUTPUT=$(uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
||||
--num-logos $NUM_LOGOS \
|
||||
--refs-per-logo $REFS_PER_LOGO \
|
||||
--positive-samples $POSITIVE_SAMPLES \
|
||||
--negative-samples $NEGATIVE_SAMPLES \
|
||||
--matching-method multi-ref \
|
||||
--min-matching-refs $MIN_MATCHING_REFS \
|
||||
--use-max-similarity \
|
||||
--threshold $THRESHOLD \
|
||||
--margin $MARGIN \
|
||||
--seed $SEED \
|
||||
--embedding-model "$MODEL" \
|
||||
--preprocess-mode "$MODE" \
|
||||
--no-cache \
|
||||
2>&1)
|
||||
|
||||
# Extract metrics
|
||||
TP=$(echo "${OUTPUT}" | grep "True Positives" | grep -oE "[0-9]+" | head -1)
|
||||
FP=$(echo "${OUTPUT}" | grep "False Positives" | grep -oE "[0-9]+" | head -1)
|
||||
FN=$(echo "${OUTPUT}" | grep "False Negatives" | grep -oE "[0-9]+" | head -1)
|
||||
PREC=$(echo "${OUTPUT}" | grep "Precision:" | grep -oE "[0-9]+\.[0-9]+%" | head -1)
|
||||
RECALL=$(echo "${OUTPUT}" | grep "Recall:" | grep -oE "[0-9]+\.[0-9]+%" | head -1)
|
||||
F1=$(echo "${OUTPUT}" | grep "F1 Score:" | grep -oE "[0-9]+\.[0-9]+%" | head -1)
|
||||
|
||||
# Print to console
|
||||
echo " TP: ${TP}, FP: ${FP}, FN: ${FN}"
|
||||
echo " Precision: ${PREC}, Recall: ${RECALL}, F1: ${F1}"
|
||||
echo ""
|
||||
|
||||
# Add to results table
|
||||
printf "%-12s %8s %8s %8s %8s %8s %8s\n" "${MODE}" "${TP}" "${FP}" "${FN}" "${PREC}" "${RECALL}" "${F1}" >> "$OUTPUT_FILE"
|
||||
|
||||
# Track best F1
|
||||
F1_NUM=$(echo "${F1}" | tr -d '%')
|
||||
if [ -n "$F1_NUM" ]; then
|
||||
BETTER=$(echo "${F1_NUM} > ${BEST_F1}" | bc -l 2>/dev/null || echo "0")
|
||||
if [ "$BETTER" = "1" ]; then
|
||||
BEST_F1="${F1_NUM}"
|
||||
BEST_MODE="${MODE}"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Also append full output for this test
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
echo "======================================================================" >> "$OUTPUT_FILE"
|
||||
echo "DETAILED RESULTS: preprocess_mode=${MODE}" >> "$OUTPUT_FILE"
|
||||
echo "======================================================================" >> "$OUTPUT_FILE"
|
||||
echo "${OUTPUT}" | grep -A 50 "Configuration:" | head -30 >> "$OUTPUT_FILE"
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
done
|
||||
|
||||
# Summary
|
||||
echo "------------------------------------------------------------------------" >> "$OUTPUT_FILE"
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
echo "BEST PREPROCESSING MODE: ${BEST_MODE} (F1 = ${BEST_F1}%)" >> "$OUTPUT_FILE"
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
echo "Notes:" >> "$OUTPUT_FILE"
|
||||
echo " - default: CLIP's standard preprocessing (resize shortest edge + center crop)" >> "$OUTPUT_FILE"
|
||||
echo " - letterbox: Pads image to square with black bars, preserving aspect ratio" >> "$OUTPUT_FILE"
|
||||
echo " - stretch: Resizes image to square, distorting aspect ratio" >> "$OUTPUT_FILE"
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
|
||||
echo "======================================="
|
||||
echo "BEST: preprocess_mode=${BEST_MODE} (F1 = ${BEST_F1}%)"
|
||||
echo "======================================="
|
||||
echo ""
|
||||
echo "Results saved to: $OUTPUT_FILE"
|
||||
132
run_refs_per_logo_test.sh
Executable file
@ -0,0 +1,132 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Test different numbers of reference logos per brand to find optimal setting.
|
||||
# Uses baseline CLIP with multi-ref (max) matching method.
|
||||
#
|
||||
# Usage:
|
||||
# ./run_refs_per_logo_test.sh
|
||||
#
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
OUTPUT_FILE="${SCRIPT_DIR}/test_results/refs_per_logo_analysis.txt"
|
||||
|
||||
# Model - baseline CLIP (best for unknown logos)
|
||||
MODEL="openai/clip-vit-large-patch14"
|
||||
|
||||
# Fixed parameters
|
||||
NUM_LOGOS=20
|
||||
POSITIVE_SAMPLES=20
|
||||
NEGATIVE_SAMPLES=100
|
||||
MIN_MATCHING_REFS=1
|
||||
THRESHOLD=0.70
|
||||
MARGIN=0.05
|
||||
SEED=42
|
||||
|
||||
# Refs per logo values to test
|
||||
REFS_TO_TEST="1 2 3 5 7 10 15 20"
|
||||
|
||||
# Create output directory if needed
|
||||
mkdir -p "${SCRIPT_DIR}/test_results"
|
||||
|
||||
# Clear output file and write header
|
||||
cat > "$OUTPUT_FILE" << EOF
|
||||
Reference Logos Per Brand Optimization
|
||||
======================================
|
||||
Date: $(date)
|
||||
|
||||
Model: ${MODEL}
|
||||
Method: multi-ref (max)
|
||||
|
||||
Fixed Parameters:
|
||||
Number of logo brands: ${NUM_LOGOS}
|
||||
Similarity threshold: ${THRESHOLD}
|
||||
Margin: ${MARGIN}
|
||||
Min matching refs: ${MIN_MATCHING_REFS}
|
||||
Positive samples/logo: ${POSITIVE_SAMPLES}
|
||||
Negative samples/logo: ${NEGATIVE_SAMPLES}
|
||||
Seed: ${SEED}
|
||||
|
||||
Testing refs per logo: ${REFS_TO_TEST}
|
||||
|
||||
EOF
|
||||
|
||||
echo "Reference Logos Per Brand Optimization"
|
||||
echo "======================================="
|
||||
echo "Model: ${MODEL}"
|
||||
echo "Testing refs per logo: ${REFS_TO_TEST}"
|
||||
echo ""
|
||||
|
||||
# Results table header
|
||||
echo "Results Summary:" >> "$OUTPUT_FILE"
|
||||
echo "----------------" >> "$OUTPUT_FILE"
|
||||
printf "%-12s %8s %8s %8s %8s %8s %8s\n" "Refs/Logo" "TP" "FP" "FN" "Prec" "Recall" "F1" >> "$OUTPUT_FILE"
|
||||
echo "------------------------------------------------------------------------" >> "$OUTPUT_FILE"
|
||||
|
||||
# Track best result
|
||||
BEST_F1=0
|
||||
BEST_REFS=0
|
||||
|
||||
for REFS in ${REFS_TO_TEST}; do
|
||||
echo "=== Testing refs_per_logo=${REFS} ==="
|
||||
|
||||
# Run test and capture output
|
||||
OUTPUT=$(uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
||||
--num-logos $NUM_LOGOS \
|
||||
--refs-per-logo $REFS \
|
||||
--positive-samples $POSITIVE_SAMPLES \
|
||||
--negative-samples $NEGATIVE_SAMPLES \
|
||||
--matching-method multi-ref \
|
||||
--min-matching-refs $MIN_MATCHING_REFS \
|
||||
--use-max-similarity \
|
||||
--threshold $THRESHOLD \
|
||||
--margin $MARGIN \
|
||||
--seed $SEED \
|
||||
--embedding-model "$MODEL" \
|
||||
2>&1)
|
||||
|
||||
# Extract metrics
|
||||
TP=$(echo "${OUTPUT}" | grep "True Positives" | grep -oE "[0-9]+" | head -1)
|
||||
FP=$(echo "${OUTPUT}" | grep "False Positives" | grep -oE "[0-9]+" | head -1)
|
||||
FN=$(echo "${OUTPUT}" | grep "False Negatives" | grep -oE "[0-9]+" | head -1)
|
||||
PREC=$(echo "${OUTPUT}" | grep "Precision:" | grep -oE "[0-9]+\.[0-9]+%" | head -1)
|
||||
RECALL=$(echo "${OUTPUT}" | grep "Recall:" | grep -oE "[0-9]+\.[0-9]+%" | head -1)
|
||||
F1=$(echo "${OUTPUT}" | grep "F1 Score:" | grep -oE "[0-9]+\.[0-9]+%" | head -1)
|
||||
|
||||
# Print to console
|
||||
echo " TP: ${TP}, FP: ${FP}, FN: ${FN}"
|
||||
echo " Precision: ${PREC}, Recall: ${RECALL}, F1: ${F1}"
|
||||
echo ""
|
||||
|
||||
# Add to results table
|
||||
printf "%-12s %8s %8s %8s %8s %8s %8s\n" "${REFS}" "${TP}" "${FP}" "${FN}" "${PREC}" "${RECALL}" "${F1}" >> "$OUTPUT_FILE"
|
||||
|
||||
# Track best F1
|
||||
F1_NUM=$(echo "${F1}" | tr -d '%')
|
||||
if [ -n "$F1_NUM" ]; then
|
||||
BETTER=$(echo "${F1_NUM} > ${BEST_F1}" | bc -l 2>/dev/null || echo "0")
|
||||
if [ "$BETTER" = "1" ]; then
|
||||
BEST_F1="${F1_NUM}"
|
||||
BEST_REFS="${REFS}"
|
||||
fi
|
||||
fi
|
||||
|
||||
# Also append full output for this test
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
echo "======================================================================" >> "$OUTPUT_FILE"
|
||||
echo "DETAILED RESULTS: refs_per_logo=${REFS}" >> "$OUTPUT_FILE"
|
||||
echo "======================================================================" >> "$OUTPUT_FILE"
|
||||
echo "${OUTPUT}" | grep -A 50 "Configuration:" | head -30 >> "$OUTPUT_FILE"
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
done
|
||||
|
||||
# Summary
|
||||
echo "------------------------------------------------------------------------" >> "$OUTPUT_FILE"
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
echo "OPTIMAL SETTING: refs_per_logo=${BEST_REFS} (F1 = ${BEST_F1}%)" >> "$OUTPUT_FILE"
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
|
||||
echo "======================================="
|
||||
echo "OPTIMAL: refs_per_logo=${BEST_REFS} (F1 = ${BEST_F1}%)"
|
||||
echo "======================================="
|
||||
echo ""
|
||||
echo "Results saved to: $OUTPUT_FILE"
|
||||
181
run_threshold_tests_image_split.sh
Executable file
@ -0,0 +1,181 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Run logo detection tests with the image-split fine-tuned model.
|
||||
# Tests various threshold and margin settings to find optimal parameters.
|
||||
#
|
||||
# Usage:
|
||||
# ./run_threshold_tests_image_split.sh
|
||||
#
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
OUTPUT_FILE="${SCRIPT_DIR}/threshold_test_results_image_split.txt"
|
||||
|
||||
# Model path
|
||||
MODEL_PATH="models/logo_detection/clip_finetuned_image_split"
|
||||
|
||||
# Common parameters
|
||||
NUM_LOGOS=20
|
||||
REFS_PER_LOGO=10
|
||||
POSITIVE_SAMPLES=20
|
||||
NEGATIVE_SAMPLES=100
|
||||
MIN_MATCHING_REFS=3
|
||||
SEED=42
|
||||
|
||||
# Check if model exists
|
||||
if [ ! -d "${SCRIPT_DIR}/${MODEL_PATH}" ]; then
|
||||
echo "Error: Image-split model not found at ${SCRIPT_DIR}/${MODEL_PATH}"
|
||||
echo "Train the model first with: python train_clip_logo.py --config configs/cloud_rtx4090_image_split.yaml"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Clear output file and write header
|
||||
echo "Threshold Optimization Tests - Image-Split Model" > "$OUTPUT_FILE"
|
||||
echo "=================================================" >> "$OUTPUT_FILE"
|
||||
echo "Date: $(date)" >> "$OUTPUT_FILE"
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
echo "Model: ${MODEL_PATH}" >> "$OUTPUT_FILE"
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
echo "Common Parameters:" >> "$OUTPUT_FILE"
|
||||
echo " Matching method: multi-ref (max)" >> "$OUTPUT_FILE"
|
||||
echo " Reference logos: $NUM_LOGOS" >> "$OUTPUT_FILE"
|
||||
echo " Refs per logo: $REFS_PER_LOGO" >> "$OUTPUT_FILE"
|
||||
echo " Positive samples: $POSITIVE_SAMPLES" >> "$OUTPUT_FILE"
|
||||
echo " Negative samples: $NEGATIVE_SAMPLES" >> "$OUTPUT_FILE"
|
||||
echo " Min matching refs: $MIN_MATCHING_REFS" >> "$OUTPUT_FILE"
|
||||
echo " Seed: $SEED" >> "$OUTPUT_FILE"
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
|
||||
echo "Running threshold optimization tests for image-split model..."
|
||||
echo " Model: ${MODEL_PATH}"
|
||||
echo " Matching method: multi-ref (max)"
|
||||
echo " Reference logos: $NUM_LOGOS"
|
||||
echo " Refs per logo: $REFS_PER_LOGO"
|
||||
echo " Seed: $SEED"
|
||||
echo ""
|
||||
|
||||
# Test 1: Lower threshold (image-split model may have different distribution)
|
||||
echo "=== Test 1: threshold=0.65, margin=0.05 ==="
|
||||
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
||||
--num-logos $NUM_LOGOS \
|
||||
--refs-per-logo $REFS_PER_LOGO \
|
||||
--positive-samples $POSITIVE_SAMPLES \
|
||||
--negative-samples $NEGATIVE_SAMPLES \
|
||||
--matching-method multi-ref \
|
||||
--min-matching-refs $MIN_MATCHING_REFS \
|
||||
--use-max-similarity \
|
||||
--threshold 0.65 \
|
||||
--margin 0.05 \
|
||||
--seed $SEED \
|
||||
--embedding-model "$MODEL_PATH" \
|
||||
--output-file "$OUTPUT_FILE"
|
||||
|
||||
echo ""
|
||||
|
||||
# Test 2: Default threshold
|
||||
echo "=== Test 2: threshold=0.70, margin=0.05 ==="
|
||||
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
||||
--num-logos $NUM_LOGOS \
|
||||
--refs-per-logo $REFS_PER_LOGO \
|
||||
--positive-samples $POSITIVE_SAMPLES \
|
||||
--negative-samples $NEGATIVE_SAMPLES \
|
||||
--matching-method multi-ref \
|
||||
--min-matching-refs $MIN_MATCHING_REFS \
|
||||
--use-max-similarity \
|
||||
--threshold 0.70 \
|
||||
--margin 0.05 \
|
||||
--seed $SEED \
|
||||
--embedding-model "$MODEL_PATH" \
|
||||
--output-file "$OUTPUT_FILE"
|
||||
|
||||
echo ""
|
||||
|
||||
# Test 3: threshold=0.75
|
||||
echo "=== Test 3: threshold=0.75, margin=0.05 ==="
|
||||
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
||||
--num-logos $NUM_LOGOS \
|
||||
--refs-per-logo $REFS_PER_LOGO \
|
||||
--positive-samples $POSITIVE_SAMPLES \
|
||||
--negative-samples $NEGATIVE_SAMPLES \
|
||||
--matching-method multi-ref \
|
||||
--min-matching-refs $MIN_MATCHING_REFS \
|
||||
--use-max-similarity \
|
||||
--threshold 0.75 \
|
||||
--margin 0.05 \
|
||||
--seed $SEED \
|
||||
--embedding-model "$MODEL_PATH" \
|
||||
--output-file "$OUTPUT_FILE"
|
||||
|
||||
echo ""
|
||||
|
||||
# Test 4: threshold=0.80
|
||||
echo "=== Test 4: threshold=0.80, margin=0.05 ==="
|
||||
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
||||
--num-logos $NUM_LOGOS \
|
||||
--refs-per-logo $REFS_PER_LOGO \
|
||||
--positive-samples $POSITIVE_SAMPLES \
|
||||
--negative-samples $NEGATIVE_SAMPLES \
|
||||
--matching-method multi-ref \
|
||||
--min-matching-refs $MIN_MATCHING_REFS \
|
||||
--use-max-similarity \
|
||||
--threshold 0.80 \
|
||||
--margin 0.05 \
|
||||
--seed $SEED \
|
||||
--embedding-model "$MODEL_PATH" \
|
||||
--output-file "$OUTPUT_FILE"
|
||||
|
||||
echo ""
|
||||
|
||||
# Test 5: threshold=0.80 with larger margin
|
||||
echo "=== Test 5: threshold=0.80, margin=0.10 ==="
|
||||
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
||||
--num-logos $NUM_LOGOS \
|
||||
--refs-per-logo $REFS_PER_LOGO \
|
||||
--positive-samples $POSITIVE_SAMPLES \
|
||||
--negative-samples $NEGATIVE_SAMPLES \
|
||||
--matching-method multi-ref \
|
||||
--min-matching-refs $MIN_MATCHING_REFS \
|
||||
--use-max-similarity \
|
||||
--threshold 0.80 \
|
||||
--margin 0.10 \
|
||||
--seed $SEED \
|
||||
--embedding-model "$MODEL_PATH" \
|
||||
--output-file "$OUTPUT_FILE"
|
||||
|
||||
echo ""
|
||||
|
||||
# Test 6: threshold=0.85
|
||||
echo "=== Test 6: threshold=0.85, margin=0.10 ==="
|
||||
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
||||
--num-logos $NUM_LOGOS \
|
||||
--refs-per-logo $REFS_PER_LOGO \
|
||||
--positive-samples $POSITIVE_SAMPLES \
|
||||
--negative-samples $NEGATIVE_SAMPLES \
|
||||
--matching-method multi-ref \
|
||||
--min-matching-refs $MIN_MATCHING_REFS \
|
||||
--use-max-similarity \
|
||||
--threshold 0.85 \
|
||||
--margin 0.10 \
|
||||
--seed $SEED \
|
||||
--embedding-model "$MODEL_PATH" \
|
||||
--output-file "$OUTPUT_FILE"
|
||||
|
||||
echo ""
|
||||
|
||||
# Test 7: threshold=0.90
|
||||
echo "=== Test 7: threshold=0.90, margin=0.10 ==="
|
||||
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
||||
--num-logos $NUM_LOGOS \
|
||||
--refs-per-logo $REFS_PER_LOGO \
|
||||
--positive-samples $POSITIVE_SAMPLES \
|
||||
--negative-samples $NEGATIVE_SAMPLES \
|
||||
--matching-method multi-ref \
|
||||
--min-matching-refs $MIN_MATCHING_REFS \
|
||||
--use-max-similarity \
|
||||
--threshold 0.90 \
|
||||
--margin 0.10 \
|
||||
--seed $SEED \
|
||||
--embedding-model "$MODEL_PATH" \
|
||||
--output-file "$OUTPUT_FILE"
|
||||
|
||||
echo ""
|
||||
echo "Results saved to: $OUTPUT_FILE"
|
||||
521
test_burnley_detection.py
Normal file
@ -0,0 +1,521 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for logo detection accuracy on Burnley test images.
|
||||
|
||||
Uses DetectLogosEmbeddings from logo_detection_embeddings.py to detect
|
||||
barnfield and vertu logos. Ground truth is determined by filename prefix:
|
||||
- "vertu_" → contains vertu logo
|
||||
- "barnfield_" → contains barnfield logo
|
||||
- "barnfield+vertu_" → contains both logos
|
||||
- anything else → no target logos
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import pickle
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from logo_detection_embeddings import DetectLogosEmbeddings
|
||||
|
||||
|
||||
def setup_logging(verbose: bool = False) -> logging.Logger:
|
||||
"""Configure logging."""
|
||||
level = logging.DEBUG if verbose else logging.INFO
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_image(image_path: Path) -> Optional[cv2.Mat]:
|
||||
"""Load an image using OpenCV."""
|
||||
img = cv2.imread(str(image_path))
|
||||
if img is None:
|
||||
return None
|
||||
return img
|
||||
|
||||
|
||||
class EmbeddingCache:
|
||||
"""Simple file-based cache for embeddings."""
|
||||
|
||||
def __init__(self, cache_path: Path):
|
||||
self.cache_path = cache_path
|
||||
self.cache: Dict[str, Any] = {}
|
||||
self._load()
|
||||
|
||||
def _load(self):
|
||||
if self.cache_path.exists():
|
||||
try:
|
||||
with open(self.cache_path, "rb") as f:
|
||||
self.cache = pickle.load(f)
|
||||
except Exception:
|
||||
self.cache = {}
|
||||
|
||||
def save(self):
|
||||
self.cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(self.cache_path, "wb") as f:
|
||||
pickle.dump(self.cache, f)
|
||||
|
||||
def get(self, key: str):
|
||||
return self.cache.get(key)
|
||||
|
||||
def put(self, key: str, value):
|
||||
if isinstance(value, torch.Tensor):
|
||||
self.cache[key] = value.cpu()
|
||||
else:
|
||||
self.cache[key] = value
|
||||
|
||||
def __len__(self):
|
||||
return len(self.cache)
|
||||
|
||||
|
||||
def get_expected_logos(filename: str) -> Set[str]:
|
||||
"""Determine expected logos from filename prefix."""
|
||||
name = filename.lower()
|
||||
if name.startswith("barnfield+vertu_"):
|
||||
return {"barnfield", "vertu"}
|
||||
elif name.startswith("barnfield_"):
|
||||
return {"barnfield"}
|
||||
elif name.startswith("vertu_"):
|
||||
return {"vertu"}
|
||||
return set()
|
||||
|
||||
|
||||
def load_reference_images(ref_dir: Path, logger: logging.Logger) -> List[cv2.Mat]:
|
||||
"""Load all images from a reference directory."""
|
||||
images = []
|
||||
for path in sorted(ref_dir.iterdir()):
|
||||
if path.suffix.lower() in (".jpg", ".jpeg", ".png", ".bmp"):
|
||||
img = load_image(path)
|
||||
if img is not None:
|
||||
images.append(img)
|
||||
else:
|
||||
logger.warning(f"Failed to load reference image: {path}")
|
||||
return images
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Test logo detection on Burnley test images using DetectLogosEmbeddings"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t", "--threshold",
|
||||
type=float,
|
||||
default=0.7,
|
||||
help="Similarity threshold for matching (default: 0.7)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d", "--detr-threshold",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="DETR detection confidence threshold (default: 0.5)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e", "--embedding-model",
|
||||
type=str,
|
||||
choices=["clip", "dinov2", "siglip"],
|
||||
default="dinov2",
|
||||
help="Embedding model type (default: dinov2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--margin",
|
||||
type=float,
|
||||
default=0.05,
|
||||
help="Required margin between best and second-best match (default: 0.05)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v", "--verbose",
|
||||
action="store_true",
|
||||
help="Enable verbose logging",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--similarity-details",
|
||||
action="store_true",
|
||||
help="Output detailed similarity scores for each detection",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-cache",
|
||||
action="store_true",
|
||||
help="Disable embedding cache",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clear-cache",
|
||||
action="store_true",
|
||||
help="Clear embedding cache before running",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Append results summary to this file",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
logger = setup_logging(args.verbose)
|
||||
|
||||
# Paths
|
||||
base_dir = Path(__file__).resolve().parent
|
||||
test_images_dir = base_dir / "burnley_test_images"
|
||||
barnfield_ref_dir = base_dir / "barnfield_reference_images"
|
||||
vertu_ref_dir = base_dir / "vertu_reference_images"
|
||||
cache_path = base_dir / ".burnley_embedding_cache.pkl"
|
||||
|
||||
# Verify directories exist
|
||||
for d, name in [(test_images_dir, "Test images"), (barnfield_ref_dir, "Barnfield refs"), (vertu_ref_dir, "Vertu refs")]:
|
||||
if not d.exists():
|
||||
logger.error(f"{name} directory not found: {d}")
|
||||
sys.exit(1)
|
||||
|
||||
# Handle cache
|
||||
if args.clear_cache and cache_path.exists():
|
||||
cache_path.unlink()
|
||||
logger.info("Cleared embedding cache")
|
||||
|
||||
cache = EmbeddingCache(cache_path) if not args.no_cache else None
|
||||
if cache:
|
||||
logger.info(f"Loaded {len(cache)} cached embeddings")
|
||||
|
||||
# Initialize detector
|
||||
logger.info(f"Initializing detector with embedding model: {args.embedding_model}")
|
||||
detector = DetectLogosEmbeddings(
|
||||
logger=logger,
|
||||
detr_threshold=args.detr_threshold,
|
||||
embedding_model_type=args.embedding_model,
|
||||
)
|
||||
|
||||
# Compute averaged reference embeddings
|
||||
logger.info("Computing reference embeddings...")
|
||||
|
||||
reference_embeddings: Dict[str, torch.Tensor] = {}
|
||||
for logo_name, ref_dir in [("barnfield", barnfield_ref_dir), ("vertu", vertu_ref_dir)]:
|
||||
cache_key = f"avg_ref:{logo_name}:{args.embedding_model}"
|
||||
cached = cache.get(cache_key) if cache else None
|
||||
|
||||
if cached is not None:
|
||||
reference_embeddings[logo_name] = cached
|
||||
logger.info(f"Loaded cached averaged embedding for {logo_name}")
|
||||
else:
|
||||
ref_images = load_reference_images(ref_dir, logger)
|
||||
logger.info(f"Computing averaged embedding for {logo_name} from {len(ref_images)} images")
|
||||
avg_emb = detector.get_averaged_embedding(ref_images)
|
||||
if avg_emb is None:
|
||||
logger.error(f"Failed to compute embedding for {logo_name}")
|
||||
sys.exit(1)
|
||||
reference_embeddings[logo_name] = avg_emb
|
||||
if cache:
|
||||
cache.put(cache_key, avg_emb)
|
||||
|
||||
# Collect test images
|
||||
test_files = sorted([
|
||||
f.name for f in test_images_dir.iterdir()
|
||||
if f.suffix.lower() in (".jpg", ".jpeg", ".png", ".bmp")
|
||||
])
|
||||
logger.info(f"Found {len(test_files)} test images")
|
||||
|
||||
# Metrics
|
||||
true_positives = 0
|
||||
false_positives = 0
|
||||
false_negatives = 0
|
||||
total_expected = 0
|
||||
results = []
|
||||
|
||||
similarity_details = {
|
||||
"true_positive_sims": [],
|
||||
"false_positive_sims": [],
|
||||
"missed_best_sims": [],
|
||||
"detection_details": [],
|
||||
}
|
||||
|
||||
# Process test images
|
||||
for test_filename in tqdm(test_files, desc="Testing"):
|
||||
test_path = test_images_dir / test_filename
|
||||
expected_logos = get_expected_logos(test_filename)
|
||||
total_expected += len(expected_logos)
|
||||
|
||||
# Check cache for detections
|
||||
det_cache_key = f"det:{test_filename}:{args.embedding_model}"
|
||||
cached_detections = cache.get(det_cache_key) if cache else None
|
||||
|
||||
if cached_detections is not None:
|
||||
detections = cached_detections
|
||||
else:
|
||||
test_img = load_image(test_path)
|
||||
if test_img is None:
|
||||
logger.warning(f"Failed to load test image: {test_path}")
|
||||
continue
|
||||
detections = detector.detect(test_img)
|
||||
if cache:
|
||||
cache.put(det_cache_key, detections)
|
||||
|
||||
# Match each detection against reference embeddings with margin
|
||||
matched_logos: Set[str] = set()
|
||||
for det_idx, detection in enumerate(detections):
|
||||
# Compute similarity to each reference logo
|
||||
sims: Dict[str, float] = {}
|
||||
for logo_name, ref_emb in reference_embeddings.items():
|
||||
sims[logo_name] = detector.compare_embeddings(
|
||||
detection["embedding"], ref_emb
|
||||
)
|
||||
|
||||
sorted_sims = sorted(sims.items(), key=lambda x: -x[1])
|
||||
|
||||
if args.similarity_details:
|
||||
similarity_details["detection_details"].append({
|
||||
"image": test_filename,
|
||||
"detection_idx": det_idx,
|
||||
"expected_logos": list(expected_logos),
|
||||
"similarities": sorted_sims,
|
||||
"detr_score": detection.get("score", 0),
|
||||
})
|
||||
|
||||
# Best match with margin check
|
||||
if not sorted_sims:
|
||||
continue
|
||||
|
||||
best_name, best_sim = sorted_sims[0]
|
||||
if best_sim < args.threshold:
|
||||
continue
|
||||
|
||||
# Check margin over second best
|
||||
if len(sorted_sims) > 1:
|
||||
second_sim = sorted_sims[1][1]
|
||||
if best_sim - second_sim < args.margin:
|
||||
continue
|
||||
|
||||
matched_logos.add(best_name)
|
||||
is_correct = best_name in expected_logos
|
||||
|
||||
if is_correct:
|
||||
true_positives += 1
|
||||
if args.similarity_details:
|
||||
similarity_details["true_positive_sims"].append(best_sim)
|
||||
else:
|
||||
false_positives += 1
|
||||
if args.similarity_details:
|
||||
similarity_details["false_positive_sims"].append(best_sim)
|
||||
|
||||
results.append({
|
||||
"test_image": test_filename,
|
||||
"matched_logo": best_name,
|
||||
"similarity": best_sim,
|
||||
"correct": is_correct,
|
||||
})
|
||||
|
||||
# Count missed detections
|
||||
missed = expected_logos - matched_logos
|
||||
false_negatives += len(missed)
|
||||
|
||||
for missed_logo in missed:
|
||||
if args.similarity_details and detections:
|
||||
best_sim_for_missed = 0
|
||||
ref_emb = reference_embeddings[missed_logo]
|
||||
for detection in detections:
|
||||
sim = detector.compare_embeddings(detection["embedding"], ref_emb)
|
||||
best_sim_for_missed = max(best_sim_for_missed, sim)
|
||||
similarity_details["missed_best_sims"].append(best_sim_for_missed)
|
||||
|
||||
results.append({
|
||||
"test_image": test_filename,
|
||||
"matched_logo": None,
|
||||
"expected_logo": missed_logo,
|
||||
"similarity": None,
|
||||
"correct": False,
|
||||
})
|
||||
|
||||
# Save cache
|
||||
if cache:
|
||||
cache.save()
|
||||
logger.info(f"Saved {len(cache)} embeddings to cache")
|
||||
|
||||
# Calculate metrics
|
||||
precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
|
||||
recall = true_positives / total_expected if total_expected > 0 else 0
|
||||
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
|
||||
|
||||
# Print results
|
||||
print("\n" + "=" * 60)
|
||||
print("BURNLEY LOGO DETECTION TEST RESULTS")
|
||||
print("=" * 60)
|
||||
print(f"\nConfiguration:")
|
||||
print(f" Embedding model: {args.embedding_model}")
|
||||
print(f" Similarity threshold: {args.threshold}")
|
||||
print(f" DETR confidence threshold: {args.detr_threshold}")
|
||||
print(f" Matching margin: {args.margin}")
|
||||
print(f" Test images processed: {len(test_files)}")
|
||||
print(f" Reference logos: barnfield, vertu")
|
||||
|
||||
print(f"\nMetrics:")
|
||||
print(f" True Positives (correct matches): {true_positives}")
|
||||
print(f" False Positives (wrong matches): {false_positives}")
|
||||
print(f" False Negatives (missed logos): {false_negatives}")
|
||||
print(f" Total expected matches: {total_expected}")
|
||||
|
||||
print(f"\nScores:")
|
||||
print(f" Precision: {precision:.4f} ({precision*100:.1f}%)")
|
||||
print(f" Recall: {recall:.4f} ({recall*100:.1f}%)")
|
||||
print(f" F1 Score: {f1:.4f} ({f1*100:.1f}%)")
|
||||
|
||||
# Show false positive examples
|
||||
false_positive_examples = [r for r in results if r.get("matched_logo") and not r["correct"]]
|
||||
if false_positive_examples:
|
||||
print(f"\nExample False Positives (first 5):")
|
||||
for r in false_positive_examples[:5]:
|
||||
print(f" - Image: {r['test_image']}")
|
||||
print(f" Matched: {r['matched_logo']} (similarity: {r['similarity']:.3f})")
|
||||
|
||||
# Show false negative examples
|
||||
false_negative_examples = [r for r in results if r.get("expected_logo")]
|
||||
if false_negative_examples:
|
||||
print(f"\nExample False Negatives (first 5):")
|
||||
for r in false_negative_examples[:5]:
|
||||
print(f" - Image: {r['test_image']}")
|
||||
print(f" Expected: {r['expected_logo']}")
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
# Print similarity details if requested
|
||||
if args.similarity_details:
|
||||
print_similarity_details(similarity_details, args.threshold)
|
||||
|
||||
# Write results to file if requested
|
||||
if args.output_file:
|
||||
write_results_to_file(
|
||||
output_path=Path(args.output_file),
|
||||
args=args,
|
||||
num_test_images=len(test_files),
|
||||
true_positives=true_positives,
|
||||
false_positives=false_positives,
|
||||
false_negatives=false_negatives,
|
||||
total_expected=total_expected,
|
||||
precision=precision,
|
||||
recall=recall,
|
||||
f1=f1,
|
||||
)
|
||||
print(f"\nResults appended to: {args.output_file}")
|
||||
|
||||
|
||||
def print_similarity_details(details: dict, threshold: float):
|
||||
"""Print detailed similarity distribution analysis."""
|
||||
import statistics
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("SIMILARITY DISTRIBUTION ANALYSIS")
|
||||
print("=" * 60)
|
||||
|
||||
def compute_stats(values, name):
|
||||
if not values:
|
||||
print(f"\n{name}: No data")
|
||||
return
|
||||
print(f"\n{name} (n={len(values)}):")
|
||||
print(f" Min: {min(values):.4f}")
|
||||
print(f" Max: {max(values):.4f}")
|
||||
print(f" Mean: {statistics.mean(values):.4f}")
|
||||
if len(values) > 1:
|
||||
print(f" StdDev: {statistics.stdev(values):.4f}")
|
||||
print(f" Median: {statistics.median(values):.4f}")
|
||||
|
||||
above = sum(1 for v in values if v >= threshold)
|
||||
below = sum(1 for v in values if v < threshold)
|
||||
print(f" Above threshold ({threshold}): {above} ({100*above/len(values):.1f}%)")
|
||||
print(f" Below threshold ({threshold}): {below} ({100*below/len(values):.1f}%)")
|
||||
|
||||
compute_stats(details["true_positive_sims"], "TRUE POSITIVE similarities")
|
||||
compute_stats(details["false_positive_sims"], "FALSE POSITIVE similarities")
|
||||
compute_stats(details["missed_best_sims"], "MISSED LOGO best similarities")
|
||||
|
||||
# Overlap analysis
|
||||
tp_sims = details["true_positive_sims"]
|
||||
fp_sims = details["false_positive_sims"]
|
||||
if tp_sims and fp_sims:
|
||||
print("\n" + "-" * 40)
|
||||
print("OVERLAP ANALYSIS:")
|
||||
tp_min, tp_max = min(tp_sims), max(tp_sims)
|
||||
fp_min, fp_max = min(fp_sims), max(fp_sims)
|
||||
print(f" True Positives range: [{tp_min:.4f}, {tp_max:.4f}]")
|
||||
print(f" False Positives range: [{fp_min:.4f}, {fp_max:.4f}]")
|
||||
|
||||
overlap_min = max(tp_min, fp_min)
|
||||
overlap_max = min(tp_max, fp_max)
|
||||
if overlap_min < overlap_max:
|
||||
print(f" OVERLAP REGION: [{overlap_min:.4f}, {overlap_max:.4f}]")
|
||||
else:
|
||||
print(" NO OVERLAP - distributions are separable!")
|
||||
|
||||
# Sample detection details
|
||||
det_details = details["detection_details"]
|
||||
if det_details:
|
||||
print("\n" + "-" * 40)
|
||||
print(f"SAMPLE DETECTION DETAILS (first 20 of {len(det_details)}):")
|
||||
for i, det in enumerate(det_details[:20]):
|
||||
expected = det["expected_logos"]
|
||||
sims = det["similarities"]
|
||||
print(f"\n [{i+1}] Image: {det['image']}")
|
||||
print(f" Expected: {expected if expected else '(none)'}")
|
||||
print(f" DETR score: {det['detr_score']:.3f}")
|
||||
print(f" Similarities:")
|
||||
for logo, sim in sims:
|
||||
marker = " <-- CORRECT" if logo in expected else ""
|
||||
print(f" {sim:.4f} {logo}{marker}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
|
||||
|
||||
def write_results_to_file(
|
||||
output_path: Path,
|
||||
args,
|
||||
num_test_images: int,
|
||||
true_positives: int,
|
||||
false_positives: int,
|
||||
false_negatives: int,
|
||||
total_expected: int,
|
||||
precision: float,
|
||||
recall: float,
|
||||
f1: float,
|
||||
):
|
||||
"""Write results summary to file."""
|
||||
from datetime import datetime
|
||||
|
||||
lines = [
|
||||
"=" * 70,
|
||||
"BURNLEY LOGO DETECTION TEST",
|
||||
f"Model: {args.embedding_model}",
|
||||
f"Method: Margin-based (margin={args.margin})",
|
||||
"=" * 70,
|
||||
f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
|
||||
"",
|
||||
"Configuration:",
|
||||
f" Embedding model: {args.embedding_model}",
|
||||
f" Similarity threshold: {args.threshold}",
|
||||
f" DETR threshold: {args.detr_threshold}",
|
||||
f" Matching margin: {args.margin}",
|
||||
f" Test images processed: {num_test_images}",
|
||||
f" Reference logos: barnfield, vertu",
|
||||
"",
|
||||
"Results:",
|
||||
f" True Positives: {true_positives:>6}",
|
||||
f" False Positives: {false_positives:>6}",
|
||||
f" False Negatives: {false_negatives:>6}",
|
||||
f" Total Expected: {total_expected:>6}",
|
||||
"",
|
||||
"Scores:",
|
||||
f" Precision: {precision:.4f} ({precision*100:.1f}%)",
|
||||
f" Recall: {recall:.4f} ({recall*100:.1f}%)",
|
||||
f" F1 Score: {f1:.4f} ({f1*100:.1f}%)",
|
||||
"",
|
||||
"",
|
||||
]
|
||||
|
||||
with open(output_path, "a") as f:
|
||||
f.write("\n".join(lines))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -18,7 +18,7 @@ import random
|
||||
import sqlite3
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
@ -265,6 +265,11 @@ def main():
|
||||
action="store_true",
|
||||
help="Enable verbose logging",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--similarity-details",
|
||||
action="store_true",
|
||||
help="Output detailed similarity scores for each detection (for analyzing score distributions)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-cache",
|
||||
action="store_true",
|
||||
@ -281,6 +286,14 @@ def main():
|
||||
default=None,
|
||||
help="Append results summary to this file (no progress output, just results)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preprocess-mode",
|
||||
type=str,
|
||||
choices=["default", "letterbox", "stretch"],
|
||||
default="default",
|
||||
help="Image preprocessing mode for CLIP: 'default' (resize+center crop), "
|
||||
"'letterbox' (pad to square with black bars), 'stretch' (distort to square)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
logger = setup_logging(args.verbose)
|
||||
@ -310,10 +323,13 @@ def main():
|
||||
|
||||
# Initialize detector
|
||||
logger.info(f"Initializing logo detector with embedding model: {args.embedding_model}")
|
||||
if args.preprocess_mode != "default":
|
||||
logger.info(f"Using preprocessing mode: {args.preprocess_mode}")
|
||||
detector = DetectLogosDETR(
|
||||
logger=logger,
|
||||
detr_threshold=args.detr_threshold,
|
||||
embedding_model=args.embedding_model,
|
||||
preprocess_mode=args.preprocess_mode,
|
||||
)
|
||||
|
||||
# Load ground truth (both mappings)
|
||||
@ -349,6 +365,7 @@ def main():
|
||||
cache_key = f"ref:{ref_filename}"
|
||||
embedding = cache.get(cache_key) if cache else None
|
||||
|
||||
# Load image if needed for embedding
|
||||
if embedding is None:
|
||||
img = load_image(ref_path)
|
||||
if img is None:
|
||||
@ -411,6 +428,16 @@ def main():
|
||||
# Detailed results for analysis
|
||||
results = []
|
||||
|
||||
# Similarity distribution tracking (for --similarity-details)
|
||||
similarity_details = {
|
||||
"true_positive_sims": [], # Similarities for correct matches
|
||||
"false_positive_sims": [], # Similarities for wrong matches
|
||||
"missed_best_sims": [], # Best similarity for logos that should have matched but didn't
|
||||
"all_positive_sims": [], # All similarities between detected regions and correct logos
|
||||
"all_negative_sims": [], # All similarities between detected regions and wrong logos
|
||||
"detection_details": [], # Per-detection breakdown
|
||||
}
|
||||
|
||||
# Process test images
|
||||
for test_filename in tqdm(test_images, desc="Testing"):
|
||||
test_path = test_images_dir / test_filename
|
||||
@ -427,17 +454,19 @@ def main():
|
||||
cache_key = f"det:{test_filename}"
|
||||
cached_detections = cache.get(cache_key) if cache else None
|
||||
|
||||
test_img = None
|
||||
if cached_detections is not None:
|
||||
# Cached detections contain serialized box data and embeddings
|
||||
detections = cached_detections
|
||||
else:
|
||||
# Load and detect
|
||||
img = load_image(test_path)
|
||||
if img is None:
|
||||
if test_img is None:
|
||||
test_img = load_image(test_path)
|
||||
if test_img is None:
|
||||
logger.warning(f"Failed to load test image: {test_path}")
|
||||
continue
|
||||
|
||||
detections = detector.detect(img)
|
||||
detections = detector.detect(test_img)
|
||||
|
||||
# Cache the detections
|
||||
if cache:
|
||||
@ -445,7 +474,38 @@ def main():
|
||||
|
||||
# Match detections against references using selected method
|
||||
matched_logos: Set[str] = set()
|
||||
for detection in detections:
|
||||
for det_idx, detection in enumerate(detections):
|
||||
# Compute similarities to all reference logos for detailed analysis
|
||||
if args.similarity_details:
|
||||
all_sims = {}
|
||||
for logo_name, ref_emb_list in multi_ref_embeddings.items():
|
||||
sims = []
|
||||
for ref_emb in ref_emb_list:
|
||||
sim = detector.compare_embeddings(detection["embedding"], ref_emb)
|
||||
sims.append(sim)
|
||||
# Use mean or max based on setting
|
||||
if args.use_max_similarity:
|
||||
all_sims[logo_name] = max(sims) if sims else 0
|
||||
else:
|
||||
all_sims[logo_name] = sum(sims) / len(sims) if sims else 0
|
||||
|
||||
# Track positive vs negative similarities
|
||||
for sim in sims:
|
||||
if logo_name in expected_logos:
|
||||
similarity_details["all_positive_sims"].append(sim)
|
||||
else:
|
||||
similarity_details["all_negative_sims"].append(sim)
|
||||
|
||||
# Store detection details
|
||||
sorted_sims = sorted(all_sims.items(), key=lambda x: -x[1])
|
||||
similarity_details["detection_details"].append({
|
||||
"image": test_filename,
|
||||
"detection_idx": det_idx,
|
||||
"expected_logos": list(expected_logos),
|
||||
"top_5_matches": sorted_sims[:5],
|
||||
"detr_score": detection.get("score", 0),
|
||||
})
|
||||
|
||||
if args.matching_method == "simple":
|
||||
# Simple matching: return ALL logos above threshold
|
||||
all_matches = detector.find_all_matches(
|
||||
@ -457,16 +517,21 @@ def main():
|
||||
matched_logos.add(label)
|
||||
|
||||
# Check if this is a correct match
|
||||
if label in expected_logos:
|
||||
is_correct = label in expected_logos
|
||||
if is_correct:
|
||||
true_positives += 1
|
||||
if args.similarity_details:
|
||||
similarity_details["true_positive_sims"].append(similarity)
|
||||
else:
|
||||
false_positives += 1
|
||||
if args.similarity_details:
|
||||
similarity_details["false_positive_sims"].append(similarity)
|
||||
|
||||
results.append({
|
||||
"test_image": test_filename,
|
||||
"matched_logo": label,
|
||||
"similarity": similarity,
|
||||
"correct": label in expected_logos,
|
||||
"correct": is_correct,
|
||||
})
|
||||
|
||||
elif args.matching_method == "margin":
|
||||
@ -481,19 +546,24 @@ def main():
|
||||
label, similarity = match_result
|
||||
matched_logos.add(label)
|
||||
|
||||
if label in expected_logos:
|
||||
is_correct = label in expected_logos
|
||||
if is_correct:
|
||||
true_positives += 1
|
||||
if args.similarity_details:
|
||||
similarity_details["true_positive_sims"].append(similarity)
|
||||
else:
|
||||
false_positives += 1
|
||||
if args.similarity_details:
|
||||
similarity_details["false_positive_sims"].append(similarity)
|
||||
|
||||
results.append({
|
||||
"test_image": test_filename,
|
||||
"matched_logo": label,
|
||||
"similarity": similarity,
|
||||
"correct": label in expected_logos,
|
||||
"correct": is_correct,
|
||||
})
|
||||
|
||||
else: # multi-ref
|
||||
elif args.matching_method == "multi-ref":
|
||||
# Multi-ref matching: aggregates scores across reference images
|
||||
match_result = detector.find_best_match_multi_ref(
|
||||
detection["embedding"],
|
||||
@ -507,16 +577,21 @@ def main():
|
||||
label, similarity, num_matching = match_result
|
||||
matched_logos.add(label)
|
||||
|
||||
if label in expected_logos:
|
||||
is_correct = label in expected_logos
|
||||
if is_correct:
|
||||
true_positives += 1
|
||||
if args.similarity_details:
|
||||
similarity_details["true_positive_sims"].append(similarity)
|
||||
else:
|
||||
false_positives += 1
|
||||
if args.similarity_details:
|
||||
similarity_details["false_positive_sims"].append(similarity)
|
||||
|
||||
results.append({
|
||||
"test_image": test_filename,
|
||||
"matched_logo": label,
|
||||
"similarity": similarity,
|
||||
"correct": label in expected_logos,
|
||||
"correct": is_correct,
|
||||
})
|
||||
|
||||
# Count missed detections (false negatives)
|
||||
@ -524,6 +599,15 @@ def main():
|
||||
false_negatives += len(missed)
|
||||
|
||||
for missed_logo in missed:
|
||||
# Track best similarity for missed logos (if we have detections)
|
||||
if args.similarity_details and detections:
|
||||
best_sim_for_missed = 0
|
||||
for detection in detections:
|
||||
for ref_emb in multi_ref_embeddings.get(missed_logo, []):
|
||||
sim = detector.compare_embeddings(detection["embedding"], ref_emb)
|
||||
best_sim_for_missed = max(best_sim_for_missed, sim)
|
||||
similarity_details["missed_best_sims"].append(best_sim_for_missed)
|
||||
|
||||
results.append({
|
||||
"test_image": test_filename,
|
||||
"matched_logo": None,
|
||||
@ -555,6 +639,7 @@ def main():
|
||||
print(f" Test images processed: {len(test_images)}")
|
||||
print(f" CLIP similarity threshold: {args.threshold}")
|
||||
print(f" DETR confidence threshold: {args.detr_threshold}")
|
||||
print(f" Preprocess mode: {args.preprocess_mode}")
|
||||
print(f" Matching method: {args.matching_method}")
|
||||
if args.matching_method in ("margin", "multi-ref"):
|
||||
print(f" Matching margin: {args.margin}")
|
||||
@ -593,6 +678,10 @@ def main():
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
# Print similarity distribution details if requested
|
||||
if args.similarity_details:
|
||||
print_similarity_details(similarity_details, args.threshold)
|
||||
|
||||
# Write results to file if requested
|
||||
if args.output_file:
|
||||
write_results_to_file(
|
||||
@ -612,6 +701,116 @@ def main():
|
||||
print(f"\nResults appended to: {args.output_file}")
|
||||
|
||||
|
||||
def print_similarity_details(details: dict, threshold: float):
|
||||
"""Print detailed similarity distribution analysis."""
|
||||
import statistics
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("SIMILARITY DISTRIBUTION ANALYSIS")
|
||||
print("=" * 60)
|
||||
|
||||
# Helper to compute stats
|
||||
def compute_stats(values, name):
|
||||
if not values:
|
||||
print(f"\n{name}: No data")
|
||||
return
|
||||
print(f"\n{name} (n={len(values)}):")
|
||||
print(f" Min: {min(values):.4f}")
|
||||
print(f" Max: {max(values):.4f}")
|
||||
print(f" Mean: {statistics.mean(values):.4f}")
|
||||
if len(values) > 1:
|
||||
print(f" StdDev: {statistics.stdev(values):.4f}")
|
||||
print(f" Median: {statistics.median(values):.4f}")
|
||||
|
||||
# Percentiles
|
||||
sorted_vals = sorted(values)
|
||||
n = len(sorted_vals)
|
||||
p10 = sorted_vals[int(n * 0.10)] if n > 10 else sorted_vals[0]
|
||||
p25 = sorted_vals[int(n * 0.25)] if n > 4 else sorted_vals[0]
|
||||
p75 = sorted_vals[int(n * 0.75)] if n > 4 else sorted_vals[-1]
|
||||
p90 = sorted_vals[int(n * 0.90)] if n > 10 else sorted_vals[-1]
|
||||
print(f" P10: {p10:.4f}")
|
||||
print(f" P25: {p25:.4f}")
|
||||
print(f" P75: {p75:.4f}")
|
||||
print(f" P90: {p90:.4f}")
|
||||
|
||||
# Count above/below threshold
|
||||
above = sum(1 for v in values if v >= threshold)
|
||||
below = sum(1 for v in values if v < threshold)
|
||||
print(f" Above threshold ({threshold}): {above} ({100*above/len(values):.1f}%)")
|
||||
print(f" Below threshold ({threshold}): {below} ({100*below/len(values):.1f}%)")
|
||||
|
||||
# Print distribution stats
|
||||
compute_stats(details["true_positive_sims"], "TRUE POSITIVE similarities (correct matches)")
|
||||
compute_stats(details["false_positive_sims"], "FALSE POSITIVE similarities (wrong matches)")
|
||||
compute_stats(details["missed_best_sims"], "MISSED LOGO best similarities (false negatives)")
|
||||
compute_stats(details["all_positive_sims"], "ALL similarities to CORRECT logos (per-ref)")
|
||||
compute_stats(details["all_negative_sims"], "ALL similarities to WRONG logos (per-ref)")
|
||||
|
||||
# Overlap analysis
|
||||
tp_sims = details["true_positive_sims"]
|
||||
fp_sims = details["false_positive_sims"]
|
||||
if tp_sims and fp_sims:
|
||||
print("\n" + "-" * 40)
|
||||
print("OVERLAP ANALYSIS:")
|
||||
tp_min, tp_max = min(tp_sims), max(tp_sims)
|
||||
fp_min, fp_max = min(fp_sims), max(fp_sims)
|
||||
print(f" True Positives range: [{tp_min:.4f}, {tp_max:.4f}]")
|
||||
print(f" False Positives range: [{fp_min:.4f}, {fp_max:.4f}]")
|
||||
|
||||
# Check overlap
|
||||
overlap_min = max(tp_min, fp_min)
|
||||
overlap_max = min(tp_max, fp_max)
|
||||
if overlap_min < overlap_max:
|
||||
print(f" OVERLAP REGION: [{overlap_min:.4f}, {overlap_max:.4f}]")
|
||||
tp_in_overlap = sum(1 for v in tp_sims if overlap_min <= v <= overlap_max)
|
||||
fp_in_overlap = sum(1 for v in fp_sims if overlap_min <= v <= overlap_max)
|
||||
print(f" TPs in overlap: {tp_in_overlap} ({100*tp_in_overlap/len(tp_sims):.1f}%)")
|
||||
print(f" FPs in overlap: {fp_in_overlap} ({100*fp_in_overlap/len(fp_sims):.1f}%)")
|
||||
else:
|
||||
print(" NO OVERLAP - distributions are separable!")
|
||||
|
||||
# Suggest optimal threshold
|
||||
all_points = [(s, "tp") for s in tp_sims] + [(s, "fp") for s in fp_sims]
|
||||
all_points.sort()
|
||||
best_thresh = threshold
|
||||
best_f1 = 0
|
||||
total_tp = len(tp_sims)
|
||||
total_fp = len(fp_sims)
|
||||
|
||||
for thresh in [p[0] for p in all_points]:
|
||||
# At this threshold:
|
||||
tp_above = sum(1 for s in tp_sims if s >= thresh)
|
||||
fp_above = sum(1 for s in fp_sims if s >= thresh)
|
||||
prec = tp_above / (tp_above + fp_above) if (tp_above + fp_above) > 0 else 0
|
||||
rec = tp_above / total_tp if total_tp > 0 else 0
|
||||
f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0
|
||||
if f1 > best_f1:
|
||||
best_f1 = f1
|
||||
best_thresh = thresh
|
||||
|
||||
print(f"\n SUGGESTED OPTIMAL THRESHOLD: {best_thresh:.4f}")
|
||||
print(f" (would give F1 = {best_f1:.4f} on this data)")
|
||||
|
||||
# Print sample detection details
|
||||
det_details = details["detection_details"]
|
||||
if det_details:
|
||||
print("\n" + "-" * 40)
|
||||
print(f"SAMPLE DETECTION DETAILS (first 20 of {len(det_details)}):")
|
||||
for i, det in enumerate(det_details[:20]):
|
||||
expected = det["expected_logos"]
|
||||
top5 = det["top_5_matches"]
|
||||
print(f"\n [{i+1}] Image: {det['image']}")
|
||||
print(f" Expected: {expected if expected else '(none)'}")
|
||||
print(f" DETR score: {det['detr_score']:.3f}")
|
||||
print(f" Top 5 matches:")
|
||||
for logo, sim in top5:
|
||||
marker = " <-- CORRECT" if logo in expected else ""
|
||||
print(f" {sim:.4f} {logo}{marker}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
|
||||
|
||||
def write_results_to_file(
|
||||
output_path: Path,
|
||||
args,
|
||||
@ -648,6 +847,7 @@ def write_results_to_file(
|
||||
"",
|
||||
"Configuration:",
|
||||
f" Embedding model: {args.embedding_model}",
|
||||
f" Preprocess mode: {args.preprocess_mode}",
|
||||
f" Reference logos: {num_logos}",
|
||||
f" Refs per logo: {args.refs_per_logo}",
|
||||
f" Total reference embeddings:{total_refs}",
|
||||
|
||||
216
test_results/FINAL_MODEL_ANALYSIS.md
Normal file
216
test_results/FINAL_MODEL_ANALYSIS.md
Normal file
@ -0,0 +1,216 @@
|
||||
# Logo Recognition Model Analysis
|
||||
|
||||
**Date:** January 7, 2026
|
||||
**Purpose:** Determine the best model and threshold for recognizing logos that are not currently in the test set.
|
||||
|
||||
---
|
||||
|
||||
## Executive Summary
|
||||
|
||||
| Model | Best Threshold | F1 Score | Precision | Recall | Recommended Use |
|
||||
|-------|---------------|----------|-----------|--------|-----------------|
|
||||
| **Image-Split Fine-tuned** | 0.70-0.75 | **67-68%** | 66-80% | 59-68% | Known logos (in reference set) |
|
||||
| Baseline CLIP | 0.70 | 57-60% | 48-49% | 72-77% | Unknown logos (never seen before) |
|
||||
| Logo-Split Fine-tuned | 0.76 | 56% | 49% | 64% | Not recommended |
|
||||
| DINOv2 (small/large) | - | 29-30% | 22-32% | 28-43% | Not suitable |
|
||||
|
||||
**Winner: Image-Split Fine-tuned Model** at threshold **0.70-0.75**
|
||||
|
||||
---
|
||||
|
||||
## Detailed Model Comparison
|
||||
|
||||
### 1. Baseline CLIP (openai/clip-vit-large-patch14)
|
||||
|
||||
The pre-trained CLIP model without any fine-tuning.
|
||||
|
||||
**Threshold Performance:**
|
||||
|
||||
| Threshold | Precision | Recall | F1 |
|
||||
|-----------|-----------|--------|-----|
|
||||
| 0.70 | 47.9% | 71.8% | 57.5% |
|
||||
| 0.80 | 33.0% | 63.1% | 43.4% |
|
||||
| 0.85 | 26.9% | 43.4% | 33.2% |
|
||||
| 0.90 | 54.9% | 22.8% | 32.2% |
|
||||
|
||||
**Similarity Distribution:**
|
||||
- True Positive mean: 0.854 (range: 0.75-0.95)
|
||||
- False Positive mean: 0.846 (range: 0.75-0.95)
|
||||
- **Problem:** TP and FP distributions almost completely overlap
|
||||
|
||||
**Suggested optimal threshold:** 0.756 (predicted F1 = 67.1%)
|
||||
|
||||
**Strengths:**
|
||||
- Good recall at low thresholds
|
||||
- Works on completely unseen logos
|
||||
- No training required
|
||||
|
||||
**Weaknesses:**
|
||||
- Poor separation between correct and incorrect matches
|
||||
- High false positive rate
|
||||
|
||||
---
|
||||
|
||||
### 2. Fine-tuned CLIP (Logo-Level Splits)
|
||||
|
||||
Trained with contrastive learning, tested on completely unseen logo brands.
|
||||
|
||||
**Threshold Performance:**
|
||||
|
||||
| Threshold | Precision | Recall | F1 |
|
||||
|-----------|-----------|--------|-----|
|
||||
| 0.70 | 25.9% | 67.1% | 37.4% |
|
||||
| 0.76 | **49.1%** | 64.3% | **55.7%** |
|
||||
| 0.82 | 75.7% | 41.4% | 53.5% |
|
||||
| 0.86 | 88.6% | 28.1% | 42.7% |
|
||||
|
||||
**Similarity Distribution:**
|
||||
- True Positive mean: 0.853
|
||||
- False Positive mean: 0.787 (better separation than baseline)
|
||||
- Missed logos mean: 0.711 (only 43.7% above 0.75)
|
||||
|
||||
**Suggested optimal threshold:** 0.82 (predicted F1 = 71.9%)
|
||||
|
||||
**Strengths:**
|
||||
- Better TP/FP separation than baseline
|
||||
- Very high precision at high thresholds (88.6% at t=0.86)
|
||||
|
||||
**Weaknesses:**
|
||||
- Does not generalize well to unseen logo brands
|
||||
- Many correct logos score below threshold (56% of missed logos below 0.75)
|
||||
- Worse than baseline at threshold 0.70
|
||||
|
||||
---
|
||||
|
||||
### 3. Fine-tuned CLIP (Image-Level Splits) ⭐ BEST
|
||||
|
||||
Trained with contrastive learning, all logo brands seen during training (different images held out for testing).
|
||||
|
||||
**Threshold Performance:**
|
||||
|
||||
| Threshold | Precision | Recall | F1 |
|
||||
|-----------|-----------|--------|-----|
|
||||
| 0.65 | 56.9% | **75.9%** | 65.0% |
|
||||
| 0.70 | 66.3% | 68.3% | **67.3%** |
|
||||
| 0.75 | **79.9%** | 59.3% | **68.1%** |
|
||||
| 0.80 | 83.7% | 52.8% | 64.8% |
|
||||
| 0.85 | 92.4% | 42.8% | 58.5% |
|
||||
| 0.90 | 98.9% | 24.7% | 39.5% |
|
||||
|
||||
**Similarity Distribution:**
|
||||
- True Positive mean: 0.866 (higher than other models)
|
||||
- False Positive mean: 0.807
|
||||
- TP-FP gap: 0.059 (best separation)
|
||||
- At t=0.75: 92 TP vs only 38 FP (excellent ratio)
|
||||
|
||||
**Suggested optimal threshold:** 0.755 (predicted F1 = 85.6%)
|
||||
|
||||
**Strengths:**
|
||||
- Best overall F1 score (68.1% at t=0.75)
|
||||
- Best precision at any threshold (79.9-98.9%)
|
||||
- Excellent TP/FP ratio
|
||||
- Highest true positive similarity scores
|
||||
|
||||
**Weaknesses:**
|
||||
- Requires logos to be in the reference set during training
|
||||
- May not generalize to completely novel logos
|
||||
|
||||
---
|
||||
|
||||
### 4. DINOv2 Models
|
||||
|
||||
Tested for comparison but significantly underperformed.
|
||||
|
||||
| Model | Precision | Recall | F1 |
|
||||
|-------|-----------|--------|-----|
|
||||
| DINOv2-small | 22.4% | 42.8% | 29.5% |
|
||||
| DINOv2-large | 32.2% | 28.5% | 30.2% |
|
||||
|
||||
**Not recommended** for logo recognition tasks.
|
||||
|
||||
---
|
||||
|
||||
## Recommendations
|
||||
|
||||
### For Logo Recognition of Known Logos (logos in your reference set)
|
||||
|
||||
**Use: Image-Split Fine-tuned Model**
|
||||
|
||||
```bash
|
||||
# Recommended configuration
|
||||
python test_logo_detection.py \
|
||||
-e models/logo_detection/clip_finetuned_image_split \
|
||||
-t 0.70 \
|
||||
--matching-method multi-ref \
|
||||
--use-max-similarity
|
||||
```
|
||||
|
||||
| Use Case | Threshold | Expected Performance |
|
||||
|----------|-----------|---------------------|
|
||||
| Balanced (recommended) | 0.70 | 66% precision, 68% recall, 67% F1 |
|
||||
| High precision | 0.75 | 80% precision, 59% recall, 68% F1 |
|
||||
| Very high precision | 0.80 | 84% precision, 53% recall, 65% F1 |
|
||||
| Maximum precision | 0.85+ | 92%+ precision, <43% recall |
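
The multi-ref matching with max similarity invoked above can be pictured as follows. This is a minimal sketch, assuming L2-normalised embeddings stored as NumPy arrays and a `references` dict of brand name to reference embeddings; the real pipeline does this inside `DetectLogosDETR.find_best_match_multi_ref` and additionally enforces `min_matching_refs`, which is omitted here.

```python
import numpy as np

def match_detection(det_emb, references, threshold=0.70, margin=0.05):
    """Illustrative multi-ref matching with max aggregation.

    det_emb: L2-normalised embedding of one detected region (1-D array).
    references: dict mapping brand name -> list of L2-normalised
    reference embeddings for that brand.
    """
    scores = {}
    for brand, ref_embs in references.items():
        # Max cosine similarity across that brand's reference crops.
        scores[brand] = max(float(np.dot(det_emb, r)) for r in ref_embs)

    ranked = sorted(scores.items(), key=lambda kv: -kv[1])
    if not ranked:
        return None
    best_brand, best_sim = ranked[0]
    if best_sim < threshold:
        return None  # nothing similar enough
    if len(ranked) > 1 and best_sim - ranked[1][1] < margin:
        return None  # ambiguous: the runner-up brand is too close
    return best_brand, best_sim
```

Max aggregation means a single close reference crop is enough to claim a match, which is consistent with the comparison results where max aggregation recovers more true positives (higher recall) than mean aggregation.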
|
||||
|
||||
### For Logo Recognition of Unknown Logos (completely novel brands)
|
||||
|
||||
**Use: Baseline CLIP** (the fine-tuned models don't generalize well)
|
||||
|
||||
```bash
|
||||
# Recommended configuration
|
||||
python test_logo_detection.py \
|
||||
-e openai/clip-vit-large-patch14 \
|
||||
-t 0.70 \
|
||||
--matching-method multi-ref \
|
||||
--use-max-similarity
|
||||
```
|
||||
|
||||
Expected: ~48% precision, ~72% recall, ~58% F1
|
||||
|
||||
---
|
||||
|
||||
## Key Findings
|
||||
|
||||
### 1. Image-Level Splits Dramatically Improve Performance
|
||||
|
||||
The image-split fine-tuned model outperforms all others because:
|
||||
- It learns brand-specific features during training
|
||||
- Test images are different but come from the same brands
|
||||
- It better represents real-world use, where you have reference images for the logos you want to detect (see the split sketch below)
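
As a rough sketch of the two split strategies compared here (the function name and sampling scheme are illustrative, not the training module's actual code):

```python
import random
from collections import defaultdict

def split_samples(samples, by="image", val_fraction=0.2, seed=42):
    """Illustrative split of (image_path, brand) pairs.

    by="brand": hold out whole brands (logo-level split, tests
    generalization to unseen logos).
    by="image": hold out images within each brand (image-level split,
    every brand is seen during training).
    """
    rng = random.Random(seed)
    if by == "brand":
        brands = sorted({brand for _, brand in samples})
        rng.shuffle(brands)
        n_val = max(1, int(len(brands) * val_fraction))
        val_brands = set(brands[:n_val])
        train = [s for s in samples if s[1] not in val_brands]
        val = [s for s in samples if s[1] in val_brands]
    else:
        by_brand = defaultdict(list)
        for s in samples:
            by_brand[s[1]].append(s)
        train, val = [], []
        for items in by_brand.values():
            rng.shuffle(items)
            n_val = max(1, int(len(items) * val_fraction))
            val.extend(items[:n_val])
            train.extend(items[n_val:])
    return train, val
```

With `by="brand"` the validation brands are never seen in training; with `by="image"` every brand contributes images to both sides, which matches the deployment scenario where reference images exist for every target logo.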
|
||||
|
||||
### 2. Logo-Level Splits Test True Generalization (but results are poor)
|
||||
|
||||
The logo-split model tests whether fine-tuning helps with completely unseen logos:
|
||||
- Result: It doesn't help much (56% F1 vs 58% baseline)
|
||||
- Contrastive learning doesn't transfer well to novel brands
|
||||
- Use baseline CLIP for novel logo detection
|
||||
|
||||
### 3. Threshold Sweet Spot is 0.70-0.75
|
||||
|
||||
For all models, the optimal F1 occurs around threshold 0.70-0.75:
|
||||
- Lower thresholds: Too many false positives
|
||||
- Higher thresholds: Misses too many correct logos
|
||||
- At 0.90+: Precision is high but recall drops below 25%
|
||||
|
||||
### 4. Precision-Recall Tradeoff
|
||||
|
||||
| Priority | Threshold | Tradeoff |
|
||||
|----------|-----------|----------|
|
||||
| Recall | 0.65-0.70 | More matches, more false positives |
|
||||
| Balanced | 0.70-0.75 | Best F1 score |
|
||||
| Precision | 0.75-0.80 | Fewer false positives, misses some matches |
|
||||
| High Precision | 0.85+ | Very few false positives, misses many matches |
|
||||
|
||||
---
|
||||
|
||||
## Conclusion
|
||||
|
||||
**For production use with known logos:**
|
||||
- Use **Image-Split Fine-tuned Model** at **threshold 0.70-0.75**
|
||||
- Expected F1: 67-68%, Precision: 66-80%
|
||||
|
||||
**For discovering unknown logos:**
|
||||
- Use **Baseline CLIP** at **threshold 0.70**
|
||||
- Expected F1: ~58%, Precision: ~48%
|
||||
|
||||
The image-split fine-tuning provides significant improvements (+8-10% F1) over baseline for known logos, but does not help with completely novel brands. For a production system, ensure all target logos are included in the training/reference set.
|
||||
87
test_results/comparison_results/baseline_20260105_100740.txt
Normal file
87
test_results/comparison_results/baseline_20260105_100740.txt
Normal file
File diff suppressed because one or more lines are too long
@ -0,0 +1,29 @@
|
||||
============================================================
|
||||
|
||||
Test Parameters:
|
||||
Logos: 50, Seed: 42, Threshold: 0.7
|
||||
Method: multi-ref, Refs/logo: 3, Margin: 0.05
|
||||
|
||||
BASELINE (openai/clip-vit-large-patch14):
|
||||
True Positives (correct matches): 101
|
||||
False Positives (wrong matches): 104
|
||||
False Negatives (missed logos): 156
|
||||
Precision: 0.4927 (49.3%)
|
||||
Recall: 0.4056 (40.6%)
|
||||
F1 Score: 0.4449 (44.5%)
|
||||
|
||||
FINE-TUNED (models/logo_detection/clip_finetuned):
|
||||
True Positives (correct matches): 164
|
||||
False Positives (wrong matches): 414
|
||||
False Negatives (missed logos): 115
|
||||
Precision: 0.2837 (28.4%)
|
||||
Recall: 0.6586 (65.9%)
|
||||
F1 Score: 0.3966 (39.7%)
|
||||
|
||||
------------------------------------------------------------
|
||||
F1 SCORE COMPARISON:
|
||||
Baseline: 44.5%
|
||||
Fine-tuned: 39.7%
|
||||
------------------------------------------------------------
|
||||
|
||||
Full results saved to: comparison_results/
|
||||
File diff suppressed because one or more lines are too long
124
test_results/comparison_results_clip_defaults_all_methods.txt
Normal file
124
test_results/comparison_results_clip_defaults_all_methods.txt
Normal file
@ -0,0 +1,124 @@
|
||||
Logo Detection Comparison Tests
|
||||
================================
|
||||
Date: Wed Dec 31 03:43:45 PM MST 2025
|
||||
|
||||
Common Parameters:
|
||||
Reference logos: 20
|
||||
Refs per logo: 10
|
||||
Positive samples: 20
|
||||
Negative samples: 100
|
||||
Min matching refs: 3
|
||||
Seed: 42
|
||||
|
||||
======================================================================
|
||||
TEST: SIMPLE MATCHING
|
||||
Method: Simple (all matches above threshold)
|
||||
======================================================================
|
||||
Date: 2025-12-31 16:02:25
|
||||
|
||||
Configuration:
|
||||
Reference logos: 20
|
||||
Refs per logo: 10
|
||||
Total reference embeddings:189
|
||||
Positive samples/logo: 20
|
||||
Negative samples/logo: 100
|
||||
Test images processed: 2355
|
||||
CLIP threshold: 0.7
|
||||
DETR threshold: 0.5
|
||||
Random seed: 42
|
||||
|
||||
Results:
|
||||
True Positives: 751
|
||||
False Positives: 58221
|
||||
False Negatives: 9
|
||||
Total Expected: 369
|
||||
|
||||
Scores:
|
||||
Precision: 0.0127 (1.3%)
|
||||
Recall: 2.0352 (203.5%)
|
||||
F1 Score: 0.0253 (2.5%)
|
||||
|
||||
======================================================================
|
||||
TEST: MARGIN MATCHING
|
||||
Method: Margin-based (margin=0.05)
|
||||
======================================================================
|
||||
Date: 2025-12-31 16:20:42
|
||||
|
||||
Configuration:
|
||||
Reference logos: 20
|
||||
Refs per logo: 10
|
||||
Total reference embeddings:189
|
||||
Positive samples/logo: 20
|
||||
Negative samples/logo: 100
|
||||
Test images processed: 2361
|
||||
CLIP threshold: 0.7
|
||||
DETR threshold: 0.5
|
||||
Random seed: 42
|
||||
|
||||
Results:
|
||||
True Positives: 60
|
||||
False Positives: 26
|
||||
False Negatives: 310
|
||||
Total Expected: 369
|
||||
|
||||
Scores:
|
||||
Precision: 0.6977 (69.8%)
|
||||
Recall: 0.1626 (16.3%)
|
||||
F1 Score: 0.2637 (26.4%)
|
||||
|
||||
======================================================================
|
||||
TEST: MULTI-REF MATCHING
|
||||
Method: Multi-ref (mean, min_refs=3, margin=0.05)
|
||||
======================================================================
|
||||
Date: 2025-12-31 16:38:59
|
||||
|
||||
Configuration:
|
||||
Reference logos: 20
|
||||
Refs per logo: 10
|
||||
Total reference embeddings:189
|
||||
Positive samples/logo: 20
|
||||
Negative samples/logo: 100
|
||||
Test images processed: 2352
|
||||
CLIP threshold: 0.7
|
||||
DETR threshold: 0.5
|
||||
Random seed: 42
|
||||
|
||||
Results:
|
||||
True Positives: 233
|
||||
False Positives: 217
|
||||
False Negatives: 170
|
||||
Total Expected: 369
|
||||
|
||||
Scores:
|
||||
Precision: 0.5178 (51.8%)
|
||||
Recall: 0.6314 (63.1%)
|
||||
F1 Score: 0.5690 (56.9%)
|
||||
|
||||
======================================================================
|
||||
TEST: MULTI-REF MATCHING
|
||||
Method: Multi-ref (max, min_refs=3, margin=0.05)
|
||||
======================================================================
|
||||
Date: 2025-12-31 16:56:49
|
||||
|
||||
Configuration:
|
||||
Reference logos: 20
|
||||
Refs per logo: 10
|
||||
Total reference embeddings:189
|
||||
Positive samples/logo: 20
|
||||
Negative samples/logo: 100
|
||||
Test images processed: 2350
|
||||
CLIP threshold: 0.7
|
||||
DETR threshold: 0.5
|
||||
Random seed: 42
|
||||
|
||||
Results:
|
||||
True Positives: 278
|
||||
False Positives: 259
|
||||
False Negatives: 136
|
||||
Total Expected: 369
|
||||
|
||||
Scores:
|
||||
Precision: 0.5177 (51.8%)
|
||||
Recall: 0.7534 (75.3%)
|
||||
F1 Score: 0.6137 (61.4%)
|
||||
|
||||
105
test_results/model_comparison_results.txt
Normal file
105
test_results/model_comparison_results.txt
Normal file
@ -0,0 +1,105 @@
|
||||
Embedding Model Comparison Tests
|
||||
=================================
|
||||
Date: Fri Jan 2 12:47:03 PM MST 2026
|
||||
|
||||
Common Parameters:
|
||||
Matching method: multi-ref (max)
|
||||
Reference logos: 20
|
||||
Refs per logo: 10
|
||||
Positive samples: 20
|
||||
Negative samples: 100
|
||||
Min matching refs: 3
|
||||
Threshold: 0.70
|
||||
Margin: 0.05
|
||||
Seed: 42
|
||||
|
||||
======================================================================
|
||||
TEST: MULTI-REF MATCHING
|
||||
Model: openai/clip-vit-large-patch14
|
||||
Method: Multi-ref (max, min_refs=3, margin=0.05)
|
||||
======================================================================
|
||||
Date: 2026-01-02 13:05:17
|
||||
|
||||
Configuration:
|
||||
Embedding model: openai/clip-vit-large-patch14
|
||||
Reference logos: 20
|
||||
Refs per logo: 10
|
||||
Total reference embeddings:189
|
||||
Positive samples/logo: 20
|
||||
Negative samples/logo: 100
|
||||
Test images processed: 2355
|
||||
Similarity threshold: 0.7
|
||||
DETR threshold: 0.5
|
||||
Random seed: 42
|
||||
|
||||
Results:
|
||||
True Positives: 284
|
||||
False Positives: 295
|
||||
False Negatives: 124
|
||||
Total Expected: 369
|
||||
|
||||
Scores:
|
||||
Precision: 0.4905 (49.1%)
|
||||
Recall: 0.7696 (77.0%)
|
||||
F1 Score: 0.5992 (59.9%)
|
||||
|
||||
======================================================================
|
||||
TEST: MULTI-REF MATCHING
|
||||
Model: facebook/dinov2-small
|
||||
Method: Multi-ref (max, min_refs=3, margin=0.05)
|
||||
======================================================================
|
||||
Date: 2026-01-02 13:19:01
|
||||
|
||||
Configuration:
|
||||
Embedding model: facebook/dinov2-small
|
||||
Reference logos: 20
|
||||
Refs per logo: 10
|
||||
Total reference embeddings:189
|
||||
Positive samples/logo: 20
|
||||
Negative samples/logo: 100
|
||||
Test images processed: 2358
|
||||
Similarity threshold: 0.7
|
||||
DETR threshold: 0.5
|
||||
Random seed: 42
|
||||
|
||||
Results:
|
||||
True Positives: 158
|
||||
False Positives: 546
|
||||
False Negatives: 234
|
||||
Total Expected: 369
|
||||
|
||||
Scores:
|
||||
Precision: 0.2244 (22.4%)
|
||||
Recall: 0.4282 (42.8%)
|
||||
F1 Score: 0.2945 (29.5%)
|
||||
|
||||
======================================================================
|
||||
TEST: MULTI-REF MATCHING
|
||||
Model: facebook/dinov2-large
|
||||
Method: Multi-ref (max, min_refs=3, margin=0.05)
|
||||
======================================================================
|
||||
Date: 2026-01-02 13:39:33
|
||||
|
||||
Configuration:
|
||||
Embedding model: facebook/dinov2-large
|
||||
Reference logos: 20
|
||||
Refs per logo: 10
|
||||
Total reference embeddings:189
|
||||
Positive samples/logo: 20
|
||||
Negative samples/logo: 100
|
||||
Test images processed: 2355
|
||||
Similarity threshold: 0.7
|
||||
DETR threshold: 0.5
|
||||
Random seed: 42
|
||||
|
||||
Results:
|
||||
True Positives: 105
|
||||
False Positives: 221
|
||||
False Negatives: 277
|
||||
Total Expected: 369
|
||||
|
||||
Scores:
|
||||
Precision: 0.3221 (32.2%)
|
||||
Recall: 0.2846 (28.5%)
|
||||
F1 Score: 0.3022 (30.2%)
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -346,6 +346,131 @@ DINOv2 Small produces over 3x as many false positives as true positives, making
|
||||
|
||||
---
|
||||
|
||||
## Summary and Recommendations
|
||||
|
||||
This section synthesizes findings from all test runs to provide actionable recommendations for logo detection configuration and future improvements.
|
||||
|
||||
### Best Configuration
|
||||
|
||||
Based on all tests conducted, the optimal configuration is:
|
||||
|
||||
| Parameter | Recommended Value | Rationale |
|
||||
|-----------|-------------------|-----------|
|
||||
| **Embedding Model** | `openai/clip-vit-large-patch14` | 2x better F1 than DINOv2 alternatives |
|
||||
| **Matching Method** | `multi-ref` with max similarity | Best F1 (59.9%) and recall (77.0%) |
|
||||
| **Similarity Threshold** | 0.70 | Lower thresholds outperform higher ones |
|
||||
| **Margin** | 0.05 | Minimal impact; keep low to avoid rejecting valid matches |
|
||||
| **Min Matching Refs** | 3 | Provides better discrimination than threshold alone |
|
||||
| **Refs Per Logo** | 10 | More references improve robustness |
|
||||
| **DETR Threshold** | 0.50 | Standard detection confidence |
|
||||
|
||||
### Performance Expectations
|
||||
|
||||
With the recommended configuration:
|
||||
|
||||
| Metric | Expected Value | Interpretation |
|
||||
|--------|----------------|----------------|
|
||||
| Precision | ~49% | About half of detections are correct |
|
||||
| Recall | ~77% | Finds most logos present in images |
|
||||
| F1 Score | ~60% | Moderate overall accuracy |
|
||||
| FP:TP Ratio | ~1:1 | Approximately equal true and false positives |
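
For concreteness, these expectations follow directly from the multi-ref (max) CLIP run recorded in `test_results/model_comparison_results.txt` (TP=284, FP=295, 369 expected matches), using the same recall definition as the test script (recall = TP / total expected):

```python
tp, fp, total_expected = 284, 295, 369  # multi-ref (max) CLIP run

precision = tp / (tp + fp)                           # 0.4905 -> ~49%
recall = tp / total_expected                         # 0.7696 -> ~77%
f1 = 2 * precision * recall / (precision + recall)   # 0.5992 -> ~60%
print(f"precision={precision:.4f} recall={recall:.4f} f1={f1:.4f}")
```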
|
||||
|
||||
**Important**: These results indicate the system is suitable for applications that can tolerate a high false positive rate, such as:
|
||||
- Initial screening with human review
|
||||
- Flagging content for further analysis
|
||||
- Low-stakes logo presence detection
|
||||
|
||||
The system is **not suitable** for high-precision applications without additional filtering or verification steps.
|
||||
|
||||
### Key Insights from Testing
|
||||
|
||||
#### What Works
|
||||
|
||||
1. **Multi-ref matching with max aggregation** consistently outperforms other methods
|
||||
2. **Multiple references per logo** (10) provides robustness against logo variations
|
||||
3. **min_matching_refs=3** is more effective at discrimination than threshold tuning
|
||||
4. **CLIP embeddings** significantly outperform self-supervised alternatives (DINOv2)
|
||||
|
||||
#### What Doesn't Work
|
||||
|
||||
1. **Raising similarity threshold** paradoxically increases false positives in the 0.70-0.85 range
|
||||
2. **Margin-only matching** fails with multiple references (same-logo refs compete)
|
||||
3. **DINOv2 models** produce 2-3x worse results than CLIP
|
||||
4. **Simple threshold-based matching** produces unacceptable 78:1 FP:TP ratio
|
||||
|
||||
#### Limitations
|
||||
|
||||
1. **~50% precision ceiling**: Even the best configuration produces nearly as many false positives as true positives
|
||||
2. **No clean threshold separation**: CLIP's embedding space doesn't provide clear decision boundaries for logos
|
||||
3. **General-purpose models**: Neither CLIP nor DINOv2 are optimized for fine-grained logo discrimination
|
||||
4. **Pipeline dependencies**: Results depend heavily on DETR detection quality
|
||||
|
||||
### Recommendations for Future Improvements
|
||||
|
||||
#### Short-Term Improvements
|
||||
|
||||
| Improvement | Expected Impact | Effort |
|
||||
|-------------|-----------------|--------|
|
||||
| **Post-processing filters** | Reduce FP by 20-30% | Low |
|
||||
| Add color histogram matching | Filter matches with wrong colors | |
|
||||
| Add aspect ratio validation | Reject shape mismatches | |
|
||||
| Add text detection | Filter if expected text is missing | |
|
||||
| **Reference curation** | Improve TP by 10-20% | Low |
|
||||
| Remove low-quality references | Reduce noise in ref embeddings | |
|
||||
| Ensure diverse logo variants | Improve coverage | |
|
||||
| **Ensemble scoring** | Improve F1 by 10-15% | Medium |
|
||||
| Combine CLIP + color features | Multi-signal confidence | |
|
||||
| Weighted voting across refs | More robust aggregation | |
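
A sketch of what the colour-histogram filter and the CLIP + colour ensemble suggested above could look like, assuming BGR crops and OpenCV (already a dependency of the test scripts). The function names and the 0.8 weighting are illustrative only, not part of the existing pipeline:

```python
import cv2
import numpy as np

def color_hist_score(crop_bgr: np.ndarray, ref_bgr: np.ndarray, bins: int = 16) -> float:
    """Correlation between hue/saturation histograms of two crops.

    Returns roughly [-1, 1]; values near 1 mean similar colours.
    """
    def hs_hist(img_bgr):
        hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
        hist = cv2.calcHist([hsv], [0, 1], None, [bins, bins], [0, 180, 0, 256])
        return cv2.normalize(hist, hist).flatten()

    return float(cv2.compareHist(hs_hist(crop_bgr), hs_hist(ref_bgr), cv2.HISTCMP_CORREL))

def combined_score(clip_sim: float, color_sim: float, w_clip: float = 0.8) -> float:
    """Simple ensemble: weighted sum of CLIP similarity and colour correlation."""
    return w_clip * clip_sim + (1.0 - w_clip) * max(color_sim, 0.0)
```

A detection whose best CLIP match has the wrong colours (low histogram correlation) would then be rejected or down-weighted before the threshold check.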
|
||||
|
||||
#### Medium-Term Improvements
|
||||
|
||||
| Improvement | Expected Impact | Effort |
|
||||
|-------------|-----------------|--------|
|
||||
| **Fine-tune CLIP on logos** | Improve F1 by 20-40% | Medium |
|
||||
| Contrastive training on logo pairs | Better embedding separation | |
|
||||
| Use LogoDet-3K for training data | Domain-specific features | |
|
||||
| **Alternative detection models** | Improve detection quality | Medium |
|
||||
| Test YOLOv8 for logo detection | Faster, potentially more accurate | |
|
||||
| Train custom detector on logo data | Better region proposals | |
|
||||
| **Learned similarity metric** | Improve precision by 30-50% | Medium |
|
||||
| Train siamese network for logo matching | Replace cosine similarity | |
|
||||
| Learn logo-specific distance function | Better discrimination | |
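
The "contrastive training on logo pairs" row above corresponds to an InfoNCE-style objective such as the one selectable via `--loss-type infonce` in `train_clip_logo.py`. Below is a minimal PyTorch sketch, assuming two augmented crops of the same logo per batch row; the training module's actual loss implementation may differ:

```python
import torch
import torch.nn.functional as F

def info_nce_loss(z_a: torch.Tensor, z_b: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """In-batch InfoNCE over paired logo crops.

    z_a and z_b are (N, D) embeddings of two crops of the same logo per row;
    row i of z_a is the positive for row i of z_b and a negative for every
    other row.
    """
    z_a = F.normalize(z_a, dim=-1)
    z_b = F.normalize(z_b, dim=-1)
    logits = z_a @ z_b.t() / temperature              # (N, N) similarity matrix
    targets = torch.arange(z_a.size(0), device=z_a.device)
    # Symmetric loss: match a -> b and b -> a
    return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))
```

The temperature plays the same role as the `--temperature` argument: lower values sharpen the softmax and penalise near-miss negatives more heavily.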
|
||||
|
||||
#### Long-Term Improvements
|
||||
|
||||
| Improvement | Expected Impact | Effort |
|
||||
|-------------|-----------------|--------|
|
||||
| **End-to-end logo recognition model** | F1 > 85% | High |
|
||||
| Single model for detection + recognition | Eliminate pipeline errors | |
|
||||
| Train on large-scale logo dataset | Comprehensive coverage | |
|
||||
| **Logo-specific foundation model** | F1 > 90% | High |
|
||||
| Pre-train on millions of logo images | Domain expertise | |
|
||||
| Fine-tune for specific brand sets | Production-ready accuracy | |
|
||||
|
||||
### Decision Framework
|
||||
|
||||
Use this framework to choose between precision and recall:
|
||||
|
||||
| Use Case | Priority | Recommended Adjustments |
|
||||
|----------|----------|------------------------|
|
||||
| **Content moderation** | High recall | Use defaults; accept FPs for human review |
|
||||
| **Brand monitoring** | Balanced | Use defaults; filter obvious FPs |
|
||||
| **Automated licensing** | High precision | Use threshold=0.90; accept low recall |
|
||||
| **Search/discovery** | High recall | Lower threshold to 0.65; more refs |
|
||||
|
||||
### Conclusion
|
||||
|
||||
The current DETR + CLIP pipeline with multi-ref matching achieves moderate accuracy (~60% F1) that is suitable for screening applications but falls short of production requirements for automated decision-making. The fundamental limitation is that general-purpose vision models lack the fine-grained discrimination needed for logo recognition.
|
||||
|
||||
**To achieve production-quality accuracy (>85% F1), the system requires:**
|
||||
1. A logo-specific embedding model (fine-tuned or trained from scratch)
|
||||
2. Additional visual features beyond CLIP embeddings
|
||||
3. Potentially an end-to-end architecture designed for logo recognition
|
||||
|
||||
The test framework established here provides the foundation for evaluating these future improvements systematically.
|
||||
|
||||
---
|
||||
|
||||
## Test Run: [Next Test Name]
|
||||
|
||||
*Results pending...*
|
||||
@ -0,0 +1,20 @@
|
||||
============================================================
|
||||
THRESHOLD OPTIMIZATION RESULTS
|
||||
Model: finetuned (models/logo_detection/clip_finetuned)
|
||||
============================================================
|
||||
|
||||
Threshold TP FP FN Prec Recall F1
|
||||
--------------------------------------------------------------------
|
||||
0.70 167 477 120 25.9% 67.1% 37.4%
|
||||
0.72 158 339 116 31.8% 63.5% 42.4%
|
||||
0.74 150 252 123 37.3% 60.2% 46.1%
|
||||
0.76 160 166 119 49.1% 64.3% 55.7%
|
||||
0.78 120 102 147 54.1% 48.2% 51.0%
|
||||
0.80 110 73 151 60.1% 44.2% 50.9%
|
||||
0.82 103 33 159 75.7% 41.4% 53.5%
|
||||
0.84 74 18 180 80.4% 29.7% 43.4%
|
||||
0.86 70 9 187 88.6% 28.1% 42.7%
|
||||
--------------------------------------------------------------------
|
||||
|
||||
BEST THRESHOLD: 0.76 (F1 = 55.7%)
|
||||
|
||||
193
test_results/threshold_analysis/threshold_test_results.txt
Normal file
193
test_results/threshold_analysis/threshold_test_results.txt
Normal file
@ -0,0 +1,193 @@
|
||||
Threshold Optimization Tests
|
||||
=============================
|
||||
Date: Fri Jan 2 10:11:34 AM MST 2026
|
||||
|
||||
Common Parameters:
|
||||
Matching method: multi-ref (max)
|
||||
Reference logos: 20
|
||||
Refs per logo: 10
|
||||
Positive samples: 20
|
||||
Negative samples: 100
|
||||
Min matching refs: 3
|
||||
Seed: 42
|
||||
|
||||
======================================================================
|
||||
TEST: MULTI-REF MATCHING
|
||||
Model: openai/clip-vit-large-patch14
|
||||
Method: Multi-ref (max, min_refs=3, margin=0.05)
|
||||
======================================================================
|
||||
Date: 2026-01-02 10:29:26
|
||||
|
||||
Configuration:
|
||||
Embedding model: openai/clip-vit-large-patch14
|
||||
Reference logos: 20
|
||||
Refs per logo: 10
|
||||
Total reference embeddings:189
|
||||
Positive samples/logo: 20
|
||||
Negative samples/logo: 100
|
||||
Test images processed: 2358
|
||||
Similarity threshold: 0.7
|
||||
DETR threshold: 0.5
|
||||
Random seed: 42
|
||||
|
||||
Results:
|
||||
True Positives: 265
|
||||
False Positives: 288
|
||||
False Negatives: 141
|
||||
Total Expected: 369
|
||||
|
||||
Scores:
|
||||
Precision: 0.4792 (47.9%)
|
||||
Recall: 0.7182 (71.8%)
|
||||
F1 Score: 0.5748 (57.5%)
|
||||
|
||||
======================================================================
|
||||
TEST: MULTI-REF MATCHING
|
||||
Model: openai/clip-vit-large-patch14
|
||||
Method: Multi-ref (max, min_refs=3, margin=0.05)
|
||||
======================================================================
|
||||
Date: 2026-01-02 10:47:35
|
||||
|
||||
Configuration:
|
||||
Embedding model: openai/clip-vit-large-patch14
|
||||
Reference logos: 20
|
||||
Refs per logo: 10
|
||||
Total reference embeddings:189
|
||||
Positive samples/logo: 20
|
||||
Negative samples/logo: 100
|
||||
Test images processed: 2348
|
||||
Similarity threshold: 0.8
|
||||
DETR threshold: 0.5
|
||||
Random seed: 42
|
||||
|
||||
Results:
|
||||
True Positives: 233
|
||||
False Positives: 472
|
||||
False Negatives: 165
|
||||
Total Expected: 369
|
||||
|
||||
Scores:
|
||||
Precision: 0.3305 (33.0%)
|
||||
Recall: 0.6314 (63.1%)
|
||||
F1 Score: 0.4339 (43.4%)
|
||||
|
||||
======================================================================
|
||||
TEST: MULTI-REF MATCHING
|
||||
Model: openai/clip-vit-large-patch14
|
||||
Method: Multi-ref (max, min_refs=3, margin=0.1)
|
||||
======================================================================
|
||||
Date: 2026-01-02 11:05:34
|
||||
|
||||
Configuration:
|
||||
Embedding model: openai/clip-vit-large-patch14
|
||||
Reference logos: 20
|
||||
Refs per logo: 10
|
||||
Total reference embeddings:189
|
||||
Positive samples/logo: 20
|
||||
Negative samples/logo: 100
|
||||
Test images processed: 2357
|
||||
Similarity threshold: 0.8
|
||||
DETR threshold: 0.5
|
||||
Random seed: 42
|
||||
|
||||
Results:
|
||||
True Positives: 187
|
||||
False Positives: 375
|
||||
False Negatives: 208
|
||||
Total Expected: 369
|
||||
|
||||
Scores:
|
||||
Precision: 0.3327 (33.3%)
|
||||
Recall: 0.5068 (50.7%)
|
||||
F1 Score: 0.4017 (40.2%)
|
||||
|
||||
======================================================================
|
||||
TEST: MULTI-REF MATCHING
|
||||
Model: openai/clip-vit-large-patch14
|
||||
Method: Multi-ref (max, min_refs=3, margin=0.1)
|
||||
======================================================================
|
||||
Date: 2026-01-02 11:23:33
|
||||
|
||||
Configuration:
|
||||
Embedding model: openai/clip-vit-large-patch14
|
||||
Reference logos: 20
|
||||
Refs per logo: 10
|
||||
Total reference embeddings:189
|
||||
Positive samples/logo: 20
|
||||
Negative samples/logo: 100
|
||||
Test images processed: 2356
|
||||
Similarity threshold: 0.85
|
||||
DETR threshold: 0.5
|
||||
Random seed: 42
|
||||
|
||||
Results:
|
||||
True Positives: 160
|
||||
False Positives: 434
|
||||
False Negatives: 223
|
||||
Total Expected: 369
|
||||
|
||||
Scores:
|
||||
Precision: 0.2694 (26.9%)
|
||||
Recall: 0.4336 (43.4%)
|
||||
F1 Score: 0.3323 (33.2%)
|
||||
|
||||
======================================================================
|
||||
TEST: MULTI-REF MATCHING
|
||||
Model: openai/clip-vit-large-patch14
|
||||
Method: Multi-ref (max, min_refs=3, margin=0.15)
|
||||
======================================================================
|
||||
Date: 2026-01-02 11:41:47
|
||||
|
||||
Configuration:
|
||||
Embedding model: openai/clip-vit-large-patch14
|
||||
Reference logos: 20
|
||||
Refs per logo: 10
|
||||
Total reference embeddings:189
|
||||
Positive samples/logo: 20
|
||||
Negative samples/logo: 100
|
||||
Test images processed: 2359
|
||||
Similarity threshold: 0.85
|
||||
DETR threshold: 0.5
|
||||
Random seed: 42
|
||||
|
||||
Results:
|
||||
True Positives: 163
|
||||
False Positives: 410
|
||||
False Negatives: 220
|
||||
Total Expected: 369
|
||||
|
||||
Scores:
|
||||
Precision: 0.2845 (28.4%)
|
||||
Recall: 0.4417 (44.2%)
|
||||
F1 Score: 0.3461 (34.6%)
|
||||
|
||||
======================================================================
|
||||
TEST: MULTI-REF MATCHING
|
||||
Model: openai/clip-vit-large-patch14
|
||||
Method: Multi-ref (max, min_refs=3, margin=0.15)
|
||||
======================================================================
|
||||
Date: 2026-01-02 12:00:00
|
||||
|
||||
Configuration:
|
||||
Embedding model: openai/clip-vit-large-patch14
|
||||
Reference logos: 20
|
||||
Refs per logo: 10
|
||||
Total reference embeddings:189
|
||||
Positive samples/logo: 20
|
||||
Negative samples/logo: 100
|
||||
Test images processed: 2363
|
||||
Similarity threshold: 0.9
|
||||
DETR threshold: 0.5
|
||||
Random seed: 42
|
||||
|
||||
Results:
|
||||
True Positives: 84
|
||||
False Positives: 69
|
||||
False Negatives: 288
|
||||
Total Expected: 369
|
||||
|
||||
Scores:
|
||||
Precision: 0.5490 (54.9%)
|
||||
Recall: 0.2276 (22.8%)
|
||||
F1 Score: 0.3218 (32.2%)
|
||||
|
||||
310
train_clip_logo.py
Normal file
310
train_clip_logo.py
Normal file
@ -0,0 +1,310 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Fine-tune CLIP vision encoder for logo recognition.
|
||||
|
||||
This script trains a CLIP model using contrastive learning on the LogoDet-3K
|
||||
dataset to improve logo embedding quality for similarity-based matching.
|
||||
|
||||
Usage:
|
||||
# Train with YAML config
|
||||
uv run python train_clip_logo.py --config configs/jetson_orin.yaml
|
||||
|
||||
# Train with command-line overrides
|
||||
uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
|
||||
--learning-rate 5e-6 --max-epochs 30
|
||||
|
||||
# Resume from checkpoint
|
||||
uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
|
||||
--resume checkpoints/epoch_10.pt
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import random
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from training.config import TrainingConfig
|
||||
from training.dataset import create_dataloaders
|
||||
from training.model import create_model
|
||||
from training.trainer import Trainer
|
||||
|
||||
|
||||
def setup_logging(log_level: str = "INFO") -> logging.Logger:
|
||||
"""Configure logging."""
|
||||
logging.basicConfig(
|
||||
level=getattr(logging, log_level.upper()),
|
||||
format="%(asctime)s [%(levelname)s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
|
||||
def set_seed(seed: int) -> None:
|
||||
"""Set random seeds for reproducibility."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""Parse command-line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Fine-tune CLIP for logo recognition",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
# Config file
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
help="Path to YAML configuration file",
|
||||
)
|
||||
|
||||
# Dataset paths
|
||||
parser.add_argument(
|
||||
"--dataset-dir",
|
||||
type=str,
|
||||
help="Path to LogoDet-3K dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reference-dir",
|
||||
type=str,
|
||||
help="Path to reference logos directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--db-path",
|
||||
type=str,
|
||||
help="Path to SQLite database",
|
||||
)
|
||||
|
||||
# Model
|
||||
parser.add_argument(
|
||||
"--base-model",
|
||||
type=str,
|
||||
help="Base CLIP model name or path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-r",
|
||||
type=int,
|
||||
help="LoRA rank (0 to disable)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--freeze-layers",
|
||||
type=int,
|
||||
help="Number of transformer layers to freeze",
|
||||
)
|
||||
|
||||
# Training
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
type=int,
|
||||
help="Batch size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning-rate",
|
||||
type=float,
|
||||
help="Learning rate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-epochs",
|
||||
type=int,
|
||||
help="Maximum number of epochs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient-accumulation-steps",
|
||||
type=int,
|
||||
help="Gradient accumulation steps",
|
||||
)
|
||||
|
||||
# Loss
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
help="Temperature for InfoNCE loss",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--loss-type",
|
||||
choices=["infonce", "supcon", "triplet", "combined"],
|
||||
help="Loss function type",
|
||||
)
|
||||
|
||||
# Checkpointing
|
||||
parser.add_argument(
|
||||
"--checkpoint-dir",
|
||||
type=str,
|
||||
help="Directory for checkpoints",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
help="Directory for final model output",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
type=str,
|
||||
help="Path to checkpoint to resume from",
|
||||
)
|
||||
|
||||
# Other
|
||||
parser.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
help="Random seed",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-level",
|
||||
type=str,
|
||||
default="INFO",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||
help="Logging level",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-mixed-precision",
|
||||
action="store_true",
|
||||
help="Disable mixed precision training",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main training entry point."""
|
||||
args = parse_args()
|
||||
|
||||
# Setup logging
|
||||
logger = setup_logging(args.log_level)
|
||||
logger.info("CLIP Logo Fine-Tuning")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Load or create configuration
|
||||
if args.config:
|
||||
logger.info(f"Loading config from: {args.config}")
|
||||
config = TrainingConfig.from_yaml(args.config)
|
||||
else:
|
||||
logger.info("Using default configuration")
|
||||
config = TrainingConfig()
|
||||
|
||||
# Apply command-line overrides
|
||||
override_fields = [
|
||||
"dataset_dir", "reference_dir", "db_path", "base_model",
|
||||
"lora_r", "freeze_layers", "batch_size", "learning_rate",
|
||||
"max_epochs", "gradient_accumulation_steps", "temperature",
|
||||
"loss_type", "checkpoint_dir", "output_dir", "seed",
|
||||
]
|
||||
for field in override_fields:
|
||||
arg_value = getattr(args, field, None)
|
||||
if arg_value is not None:
|
||||
setattr(config, field, arg_value)
|
||||
logger.info(f"Override: {field} = {arg_value}")
|
||||
|
||||
if args.no_mixed_precision:
|
||||
config.mixed_precision = False
|
||||
logger.info("Override: mixed_precision = False")
|
||||
|
||||
# Validate configuration
|
||||
warnings = config.validate()
|
||||
for warning in warnings:
|
||||
logger.warning(f"Config warning: {warning}")
|
||||
|
||||
# Set random seed
|
||||
set_seed(config.seed)
|
||||
logger.info(f"Random seed: {config.seed}")
|
||||
|
||||
# Check paths exist
|
||||
db_path = Path(config.db_path)
|
||||
ref_dir = Path(config.reference_dir)
|
||||
|
||||
if not db_path.exists():
|
||||
logger.error(f"Database not found: {db_path}")
|
||||
logger.error("Run prepare_test_data.py first to create the database.")
|
||||
sys.exit(1)
|
||||
|
||||
if not ref_dir.exists():
|
||||
logger.error(f"Reference directory not found: {ref_dir}")
|
||||
logger.error("Run prepare_test_data.py first to extract reference logos.")
|
||||
sys.exit(1)
|
||||
|
||||
# Create model
|
||||
logger.info(f"Creating model from: {config.base_model}")
|
||||
model, processor = create_model(
|
||||
base_model=config.base_model,
|
||||
lora_r=config.lora_r,
|
||||
lora_alpha=config.lora_alpha,
|
||||
lora_dropout=config.lora_dropout,
|
||||
freeze_layers=config.freeze_layers,
|
||||
use_gradient_checkpointing=config.use_gradient_checkpointing,
|
||||
)
|
||||
|
||||
# Create dataloaders
|
||||
logger.info("Creating dataloaders...")
|
||||
train_loader, val_loader, test_loader = create_dataloaders(
|
||||
db_path=str(config.db_path),
|
||||
reference_dir=str(config.reference_dir),
|
||||
batch_size=config.batch_size,
|
||||
logos_per_batch=config.logos_per_batch,
|
||||
samples_per_logo=config.samples_per_logo,
|
||||
num_workers=config.num_workers,
|
||||
train_split=config.train_split,
|
||||
val_split=config.val_split,
|
||||
test_split=config.test_split,
|
||||
seed=config.seed,
|
||||
augmentation_strength=config.augmentation_strength,
|
||||
split_level=getattr(config, 'split_level', 'logo'),
|
||||
)
|
||||
|
||||
# Create trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
train_loader=train_loader,
|
||||
val_loader=val_loader,
|
||||
config=config,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
# Resume from checkpoint if specified
|
||||
if args.resume:
|
||||
resume_path = Path(args.resume)
|
||||
if resume_path.exists():
|
||||
logger.info(f"Resuming from: {resume_path}")
|
||||
# Set checkpoint dir to resume path's parent
|
||||
if resume_path.is_file():
|
||||
config.checkpoint_dir = str(resume_path.parent)
|
||||
trainer.load_checkpoint(resume_path.name)
|
||||
else:
|
||||
logger.warning(f"Resume checkpoint not found: {resume_path}")
|
||||
|
||||
# Train
|
||||
logger.info("\nStarting training...")
|
||||
final_metrics = trainer.train()
|
||||
|
||||
logger.info("\nTraining complete!")
|
||||
logger.info(f" Best val loss: {final_metrics['best_val_loss']:.4f}")
|
||||
logger.info(f" Best separation: {final_metrics['best_val_separation']:.4f}")
|
||||
logger.info(f" Total epochs: {final_metrics['total_epochs']}")
|
||||
logger.info(f" Total time: {final_metrics['total_time_minutes']:.1f} minutes")
|
||||
|
||||
# Export model
|
||||
output_path = trainer.export_model()
|
||||
logger.info(f"\nModel exported to: {output_path}")
|
||||
|
||||
# Print next steps
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("Next steps:")
|
||||
logger.info(f"1. Test the fine-tuned model:")
|
||||
logger.info(f" uv run python test_logo_detection.py -n 50 \\")
|
||||
logger.info(f" -e {output_path} --matching-method multi-ref")
|
||||
logger.info(f"")
|
||||
logger.info(f"2. Compare with baseline:")
|
||||
logger.info(f" uv run python test_logo_detection.py -n 50 \\")
|
||||
logger.info(f" -e openai/clip-vit-large-patch14 --matching-method multi-ref")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
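The script above is a thin wrapper around the training modules, so the same run can be driven programmatically. A minimal sketch that mirrors the calls in main(), assuming the default paths produced by prepare_test_data.py exist:

```python
import logging

from training.config import TrainingConfig
from training.dataset import create_dataloaders
from training.model import create_model
from training.trainer import Trainer

config = TrainingConfig(db_path="test_data_mapping.db", reference_dir="reference_logos")

model, processor = create_model(
    base_model=config.base_model,
    lora_r=config.lora_r,
    lora_alpha=config.lora_alpha,
    lora_dropout=config.lora_dropout,
    freeze_layers=config.freeze_layers,
    use_gradient_checkpointing=config.use_gradient_checkpointing,
)

train_loader, val_loader, _ = create_dataloaders(
    db_path=config.db_path,
    reference_dir=config.reference_dir,
    logos_per_batch=config.logos_per_batch,
    samples_per_logo=config.samples_per_logo,
    seed=config.seed,
)

trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    logger=logging.getLogger(__name__),
)
metrics = trainer.train()
print(metrics["best_val_loss"], metrics["total_epochs"])
```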
training/__init__.py (new file, 24 lines)
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
CLIP fine-tuning module for logo recognition.
|
||||
|
||||
This module provides tools for fine-tuning CLIP's vision encoder using
|
||||
contrastive learning on the LogoDet-3K dataset.
|
||||
"""
|
||||
|
||||
from .config import TrainingConfig
|
||||
from .dataset import LogoContrastiveDataset, create_dataloaders
|
||||
from .model import LogoFineTunedCLIP
|
||||
from .losses import InfoNCELoss, TripletLoss
|
||||
from .trainer import Trainer
|
||||
from .evaluation import EmbeddingEvaluator
|
||||
|
||||
__all__ = [
|
||||
"TrainingConfig",
|
||||
"LogoContrastiveDataset",
|
||||
"create_dataloaders",
|
||||
"LogoFineTunedCLIP",
|
||||
"InfoNCELoss",
|
||||
"TripletLoss",
|
||||
"Trainer",
|
||||
"EmbeddingEvaluator",
|
||||
]
|
||||
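Since the package re-exports its public classes, downstream code can import them from `training` directly; a small illustrative snippet:

```python
from training import TrainingConfig, InfoNCELoss, create_dataloaders

config = TrainingConfig()
loss_fn = InfoNCELoss(temperature=config.temperature)
```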
training/config.py (new file, 142 lines)
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
Training configuration for CLIP fine-tuning.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
import yaml
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingConfig:
|
||||
"""Configuration for CLIP logo fine-tuning."""
|
||||
|
||||
# Base model
|
||||
base_model: str = "openai/clip-vit-large-patch14"
|
||||
|
||||
# Dataset paths
|
||||
dataset_dir: str = "LogoDet-3K"
|
||||
reference_dir: str = "reference_logos"
|
||||
db_path: str = "test_data_mapping.db"
|
||||
|
||||
# Data split configuration
|
||||
split_level: str = "logo" # "logo" for brand-level, "image" for image-level
|
||||
train_split: float = 0.7
|
||||
val_split: float = 0.15
|
||||
test_split: float = 0.15
|
||||
|
||||
# Batch construction
|
||||
batch_size: int = 16
|
||||
logos_per_batch: int = 32
|
||||
samples_per_logo: int = 4
|
||||
gradient_accumulation_steps: int = 8
|
||||
num_workers: int = 4
|
||||
|
||||
# Model architecture
|
||||
lora_r: int = 16
|
||||
lora_alpha: int = 32
|
||||
lora_dropout: float = 0.1
|
||||
freeze_layers: int = 12
|
||||
use_gradient_checkpointing: bool = True
|
||||
|
||||
# Training hyperparameters
|
||||
learning_rate: float = 1e-5
|
||||
weight_decay: float = 0.01
|
||||
warmup_steps: int = 500
|
||||
max_epochs: int = 20
|
||||
mixed_precision: bool = True
|
||||
|
||||
# Loss function
|
||||
temperature: float = 0.07
|
||||
loss_type: str = "infonce" # "infonce" or "triplet"
|
||||
triplet_margin: float = 0.3
|
||||
|
||||
# Early stopping
|
||||
patience: int = 5
|
||||
min_delta: float = 0.001
|
||||
|
||||
# Checkpoints and output
|
||||
checkpoint_dir: str = "checkpoints"
|
||||
output_dir: str = "models/logo_detection/clip_finetuned"
|
||||
save_every_n_epochs: int = 5
|
||||
|
||||
# Logging
|
||||
log_every_n_steps: int = 10
|
||||
eval_every_n_epochs: int = 1
|
||||
|
||||
# Random seed for reproducibility
|
||||
seed: int = 42
|
||||
|
||||
# Hard negative mining
|
||||
use_hard_negatives: bool = False
|
||||
hard_negative_start_epoch: int = 5
|
||||
hard_negatives_per_logo: int = 10
|
||||
|
||||
# Data augmentation
|
||||
use_augmentation: bool = True
|
||||
augmentation_strength: str = "medium" # "light", "medium", "strong"
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, yaml_path: str) -> "TrainingConfig":
|
||||
"""Load configuration from YAML file."""
|
||||
with open(yaml_path, "r") as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
return cls(**config_dict)
|
||||
|
||||
def to_yaml(self, yaml_path: str) -> None:
|
||||
"""Save configuration to YAML file."""
|
||||
Path(yaml_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(yaml_path, "w") as f:
|
||||
yaml.dump(self.__dict__, f, default_flow_style=False, sort_keys=False)
|
||||
|
||||
def validate(self) -> List[str]:
|
||||
"""Validate configuration and return list of warnings."""
|
||||
warnings = []
|
||||
|
||||
# Check split ratios
|
||||
total_split = self.train_split + self.val_split + self.test_split
|
||||
if abs(total_split - 1.0) > 0.01:
|
||||
warnings.append(
|
||||
f"Split ratios sum to {total_split}, expected 1.0"
|
||||
)
|
||||
|
||||
# Check batch construction
|
||||
effective_batch = self.batch_size * self.gradient_accumulation_steps
|
||||
if effective_batch < 64:
|
||||
warnings.append(
|
||||
f"Effective batch size ({effective_batch}) is small for contrastive learning. "
|
||||
"Consider increasing batch_size or gradient_accumulation_steps."
|
||||
)
|
||||
|
||||
# Check LoRA config
|
||||
if self.lora_r > 0 and self.lora_alpha < self.lora_r:
|
||||
warnings.append(
|
||||
f"lora_alpha ({self.lora_alpha}) < lora_r ({self.lora_r}). "
|
||||
"This may reduce LoRA effectiveness."
|
||||
)
|
||||
|
||||
# Check freeze layers
|
||||
if self.freeze_layers < 0:
|
||||
warnings.append("freeze_layers should be >= 0")
|
||||
|
||||
# Check temperature
|
||||
if self.temperature <= 0:
|
||||
warnings.append("temperature must be positive")
|
||||
elif self.temperature > 1.0:
|
||||
warnings.append(
|
||||
f"temperature ({self.temperature}) is high. "
|
||||
"Typical values are 0.05-0.1."
|
||||
)
|
||||
|
||||
return warnings
|
||||
|
||||
@property
|
||||
def effective_batch_size(self) -> int:
|
||||
"""Calculate effective batch size with gradient accumulation."""
|
||||
return self.batch_size * self.gradient_accumulation_steps
|
||||
|
||||
@property
|
||||
def samples_per_batch(self) -> int:
|
||||
"""Total samples in one batch (logos_per_batch * samples_per_logo)."""
|
||||
return self.logos_per_batch * self.samples_per_logo
|
||||
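As a quick sanity check on the defaults above, the two derived properties and the YAML round trip behave as follows (the temporary path is an arbitrary example):

```python
from training.config import TrainingConfig

cfg = TrainingConfig()

# 16 samples per optimizer-visible step * 8 accumulation steps = 128 effective samples
assert cfg.effective_batch_size == 16 * 8 == 128
# 32 logos per contrastive batch * 4 samples each = 128 images per batch
assert cfg.samples_per_batch == 32 * 4 == 128

# Round trip through YAML preserves all fields.
cfg.to_yaml("/tmp/clip_logo_config.yaml")
restored = TrainingConfig.from_yaml("/tmp/clip_logo_config.yaml")
assert restored.learning_rate == cfg.learning_rate
```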
training/dataset.py (new file, 567 lines)
@@ -0,0 +1,567 @@
|
||||
"""
|
||||
Dataset classes for contrastive learning on logo images.
|
||||
"""
|
||||
|
||||
import random
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset, DataLoader, Sampler
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
# CLIP normalization values
|
||||
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
|
||||
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
|
||||
|
||||
|
||||
def get_train_transforms(strength: str = "medium") -> transforms.Compose:
|
||||
"""
|
||||
Get training data augmentation transforms.
|
||||
|
||||
Args:
|
||||
strength: Augmentation strength - "light", "medium", or "strong"
|
||||
|
||||
Returns:
|
||||
Composed transforms for training
|
||||
"""
|
||||
if strength == "light":
|
||||
return transforms.Compose([
|
||||
transforms.Resize((224, 224)),
|
||||
transforms.RandomHorizontalFlip(p=0.5),
|
||||
transforms.ColorJitter(brightness=0.1, contrast=0.1),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
|
||||
])
|
||||
elif strength == "medium":
|
||||
return transforms.Compose([
|
||||
transforms.Resize((224, 224)),
|
||||
transforms.RandomHorizontalFlip(p=0.5),
|
||||
transforms.RandomRotation(degrees=15),
|
||||
transforms.ColorJitter(
|
||||
brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05
|
||||
),
|
||||
transforms.RandomAffine(
|
||||
degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)
|
||||
),
|
||||
transforms.RandomGrayscale(p=0.1),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
|
||||
])
|
||||
else: # strong
|
||||
return transforms.Compose([
|
||||
transforms.Resize((256, 256)),
|
||||
transforms.RandomCrop(224),
|
||||
transforms.RandomHorizontalFlip(p=0.5),
|
||||
transforms.RandomVerticalFlip(p=0.1),
|
||||
transforms.RandomRotation(degrees=30),
|
||||
transforms.ColorJitter(
|
||||
brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1
|
||||
),
|
||||
transforms.RandomAffine(
|
||||
degrees=0, translate=(0.15, 0.15), scale=(0.8, 1.2), shear=10
|
||||
),
|
||||
transforms.RandomGrayscale(p=0.2),
|
||||
transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
|
||||
])
|
||||
|
||||
|
||||
def get_val_transforms() -> transforms.Compose:
|
||||
"""Get validation/test transforms (no augmentation)."""
|
||||
return transforms.Compose([
|
||||
transforms.Resize((224, 224)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
|
||||
])
|
||||
|
||||
|
||||
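Both pipelines map a PIL image to a CLIP-normalized tensor of shape [3, 224, 224]; a short illustrative usage (the file path is hypothetical):

```python
from PIL import Image
from training.dataset import get_train_transforms, get_val_transforms

img = Image.open("reference_logos/Food/SomeBrand/0.jpg").convert("RGB")  # hypothetical path

augmented = get_train_transforms("medium")(img)   # random flip/rotation/jitter, then normalize
deterministic = get_val_transforms()(img)         # resize + normalize only
print(augmented.shape, deterministic.shape)       # torch.Size([3, 224, 224]) for both
```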
class LogoDataset:
|
||||
"""
|
||||
Manages logo data from the SQLite database.
|
||||
|
||||
Handles loading logo-to-image mappings and splitting by logo brand or image.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_path: str,
|
||||
reference_dir: str,
|
||||
train_split: float = 0.7,
|
||||
val_split: float = 0.15,
|
||||
test_split: float = 0.15,
|
||||
seed: int = 42,
|
||||
split_level: str = "logo",
|
||||
):
|
||||
"""
|
||||
Initialize the logo dataset.
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database
|
||||
reference_dir: Directory containing reference logo images
|
||||
train_split: Fraction for training
|
||||
val_split: Fraction for validation
|
||||
test_split: Fraction for testing
|
||||
seed: Random seed for reproducibility
|
||||
split_level: "logo" for brand-level splits (test on unseen brands),
|
||||
"image" for image-level splits (test on unseen images
|
||||
from seen brands)
|
||||
"""
|
||||
self.db_path = Path(db_path)
|
||||
self.reference_dir = Path(reference_dir)
|
||||
self.seed = seed
|
||||
self.split_level = split_level
|
||||
|
||||
# Load logo-to-images mapping from database
|
||||
self.logo_to_images = self._load_logo_mappings()
|
||||
self.all_logos = list(self.logo_to_images.keys())
|
||||
|
||||
if split_level == "logo":
|
||||
# Logo-level splits: test logos are completely unseen brands
|
||||
self.train_logos, self.val_logos, self.test_logos = self._split_logos(
|
||||
train_split, val_split, test_split
|
||||
)
|
||||
# For logo-level splits, each split has its own logos
|
||||
self.train_logo_to_images = {
|
||||
l: self.logo_to_images[l] for l in self.train_logos
|
||||
}
|
||||
self.val_logo_to_images = {
|
||||
l: self.logo_to_images[l] for l in self.val_logos
|
||||
}
|
||||
self.test_logo_to_images = {
|
||||
l: self.logo_to_images[l] for l in self.test_logos
|
||||
}
|
||||
else:
|
||||
# Image-level splits: all logos present in all splits, different images
|
||||
(
|
||||
self.train_logo_to_images,
|
||||
self.val_logo_to_images,
|
||||
self.test_logo_to_images,
|
||||
) = self._split_images(train_split, val_split, test_split)
|
||||
# All logos are in all splits
|
||||
self.train_logos = list(self.train_logo_to_images.keys())
|
||||
self.val_logos = list(self.val_logo_to_images.keys())
|
||||
self.test_logos = list(self.test_logo_to_images.keys())
|
||||
|
||||
def _load_logo_mappings(self) -> Dict[str, List[Path]]:
|
||||
"""Load logo name to image paths mapping from database."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("""
|
||||
SELECT ln.name, rl.filename
|
||||
FROM reference_logos rl
|
||||
JOIN logo_names ln ON rl.logo_name_id = ln.id
|
||||
ORDER BY ln.name
|
||||
""")
|
||||
|
||||
logo_to_images: Dict[str, List[Path]] = {}
|
||||
for logo_name, filename in cursor.fetchall():
|
||||
if logo_name not in logo_to_images:
|
||||
logo_to_images[logo_name] = []
|
||||
logo_to_images[logo_name].append(self.reference_dir / filename)
|
||||
|
||||
conn.close()
|
||||
return logo_to_images
|
||||
|
||||
def _split_logos(
|
||||
self,
|
||||
train_split: float,
|
||||
val_split: float,
|
||||
test_split: float,
|
||||
) -> Tuple[List[str], List[str], List[str]]:
|
||||
"""Split logos at brand level for train/val/test."""
|
||||
random.seed(self.seed)
|
||||
logos = self.all_logos.copy()
|
||||
random.shuffle(logos)
|
||||
|
||||
n = len(logos)
|
||||
train_end = int(n * train_split)
|
||||
val_end = train_end + int(n * val_split)
|
||||
|
||||
train_logos = logos[:train_end]
|
||||
val_logos = logos[train_end:val_end]
|
||||
test_logos = logos[val_end:]
|
||||
|
||||
return train_logos, val_logos, test_logos
|
||||
|
||||
def _split_images(
|
||||
self,
|
||||
train_split: float,
|
||||
val_split: float,
|
||||
test_split: float,
|
||||
) -> Tuple[Dict[str, List[Path]], Dict[str, List[Path]], Dict[str, List[Path]]]:
|
||||
"""
|
||||
Split images within each logo brand for train/val/test.
|
||||
|
||||
Each logo brand will have images in all splits, allowing the model
|
||||
to see some examples of each brand during training.
|
||||
"""
|
||||
random.seed(self.seed)
|
||||
|
||||
train_logo_to_images: Dict[str, List[Path]] = {}
|
||||
val_logo_to_images: Dict[str, List[Path]] = {}
|
||||
test_logo_to_images: Dict[str, List[Path]] = {}
|
||||
|
||||
for logo, images in self.logo_to_images.items():
|
||||
# Shuffle images for this logo
|
||||
shuffled_images = images.copy()
|
||||
random.shuffle(shuffled_images)
|
||||
|
||||
n = len(shuffled_images)
|
||||
if n == 1:
|
||||
# Only one image: put in train only
|
||||
train_logo_to_images[logo] = shuffled_images
|
||||
continue
|
||||
elif n == 2:
|
||||
# Two images: one train, one val
|
||||
train_logo_to_images[logo] = [shuffled_images[0]]
|
||||
val_logo_to_images[logo] = [shuffled_images[1]]
|
||||
continue
|
||||
|
||||
# Normal split for 3+ images
|
||||
train_end = max(1, int(n * train_split))
|
||||
val_end = train_end + max(1, int(n * val_split))
|
||||
|
||||
train_images = shuffled_images[:train_end]
|
||||
val_images = shuffled_images[train_end:val_end]
|
||||
test_images = shuffled_images[val_end:]
|
||||
|
||||
# Ensure at least one image in train
|
||||
if train_images:
|
||||
train_logo_to_images[logo] = train_images
|
||||
if val_images:
|
||||
val_logo_to_images[logo] = val_images
|
||||
if test_images:
|
||||
test_logo_to_images[logo] = test_images
|
||||
|
||||
return train_logo_to_images, val_logo_to_images, test_logo_to_images
|
||||
|
||||
def get_split_info(self) -> Dict[str, Any]:
|
||||
"""Return information about the splits."""
|
||||
return {
|
||||
"split_level": self.split_level,
|
||||
"total_logos": len(self.all_logos),
|
||||
"train_logos": len(self.train_logos),
|
||||
"val_logos": len(self.val_logos),
|
||||
"test_logos": len(self.test_logos),
|
||||
"train_images": sum(
|
||||
len(imgs) for imgs in self.train_logo_to_images.values()
|
||||
),
|
||||
"val_images": sum(
|
||||
len(imgs) for imgs in self.val_logo_to_images.values()
|
||||
),
|
||||
"test_images": sum(
|
||||
len(imgs) for imgs in self.test_logo_to_images.values()
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
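A small sketch of how the splits can be inspected, assuming the database and reference crops from prepare_test_data.py are present:

```python
from training.dataset import LogoDataset

data = LogoDataset(
    db_path="test_data_mapping.db",
    reference_dir="reference_logos",
    split_level="logo",   # hold out whole brands; use "image" to share brands across splits
)
info = data.get_split_info()
print(f"{info['train_logos']} train / {info['val_logos']} val / {info['test_logos']} test brands")
```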
class LogoContrastiveDataset(Dataset):
|
||||
"""
|
||||
Dataset for contrastive learning on logos.
|
||||
|
||||
Each __getitem__ call returns a batch of images organized for contrastive
|
||||
learning: K different logos with M samples each, ensuring positive pairs
|
||||
exist within each batch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logo_data: LogoDataset,
|
||||
split: str = "train",
|
||||
logos_per_batch: int = 32,
|
||||
samples_per_logo: int = 4,
|
||||
transform: Optional[transforms.Compose] = None,
|
||||
batches_per_epoch: int = 1000,
|
||||
):
|
||||
"""
|
||||
Initialize the contrastive dataset.
|
||||
|
||||
Args:
|
||||
logo_data: LogoDataset instance with logo mappings
|
||||
split: One of "train", "val", or "test"
|
||||
logos_per_batch: Number of different logos per batch
|
||||
samples_per_logo: Number of samples for each logo
|
||||
transform: Image transforms to apply
|
||||
batches_per_epoch: Number of batches per epoch
|
||||
"""
|
||||
self.logo_data = logo_data
|
||||
self.logos_per_batch = logos_per_batch
|
||||
self.samples_per_logo = samples_per_logo
|
||||
self.transform = transform
|
||||
self.batches_per_epoch = batches_per_epoch
|
||||
|
||||
# Get logos and their images for this split
|
||||
# This respects both logo-level and image-level splits
|
||||
if split == "train":
|
||||
self.logos = logo_data.train_logos
|
||||
self.logo_to_images = logo_data.train_logo_to_images
|
||||
elif split == "val":
|
||||
self.logos = logo_data.val_logos
|
||||
self.logo_to_images = logo_data.val_logo_to_images
|
||||
else:
|
||||
self.logos = logo_data.test_logos
|
||||
self.logo_to_images = logo_data.test_logo_to_images
|
||||
|
||||
# Filter logos with enough samples for this split
|
||||
self.valid_logos = [
|
||||
logo for logo in self.logos
|
||||
if logo in self.logo_to_images and len(self.logo_to_images[logo]) >= samples_per_logo
|
||||
]
|
||||
|
||||
# For logos with fewer samples, we'll use with replacement
|
||||
self.logos_needing_replacement = [
|
||||
logo for logo in self.logos
|
||||
if logo in self.logo_to_images and len(self.logo_to_images[logo]) < samples_per_logo
|
||||
]
|
||||
|
||||
# Create label mapping (use all logos from the full dataset for consistent labels)
|
||||
self.logo_to_label = {
|
||||
logo: idx for idx, logo in enumerate(logo_data.all_logos)
|
||||
}
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.batches_per_epoch
|
||||
|
||||
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Get a batch of images for contrastive learning.
|
||||
|
||||
Returns:
|
||||
images: Tensor of shape [K*M, 3, 224, 224]
|
||||
labels: Tensor of shape [K*M] with logo class indices
|
||||
"""
|
||||
images = []
|
||||
labels = []
|
||||
|
||||
# Sample K logos for this batch (only from logos that have images in this split)
|
||||
available_logos = [l for l in self.logos if l in self.logo_to_images]
|
||||
k = min(self.logos_per_batch, len(available_logos))
|
||||
batch_logos = random.sample(available_logos, k)
|
||||
|
||||
for logo in batch_logos:
|
||||
logo_images = self.logo_to_images[logo]
|
||||
|
||||
# Sample M images for this logo
|
||||
if len(logo_images) >= self.samples_per_logo:
|
||||
sampled_paths = random.sample(logo_images, self.samples_per_logo)
|
||||
else:
|
||||
# Sample with replacement if not enough images
|
||||
sampled_paths = random.choices(
|
||||
logo_images, k=self.samples_per_logo
|
||||
)
|
||||
|
||||
# Load and transform images
|
||||
for img_path in sampled_paths:
|
||||
try:
|
||||
img = Image.open(img_path).convert("RGB")
|
||||
if self.transform:
|
||||
img = self.transform(img)
|
||||
else:
|
||||
img = get_val_transforms()(img)
|
||||
images.append(img)
|
||||
labels.append(self.logo_to_label[logo])
|
||||
except Exception:
# Skip images that fail to load or transform
continue
|
||||
|
||||
# Stack into tensors
|
||||
if len(images) == 0:
|
||||
# Fallback: return dummy batch
|
||||
return (
|
||||
torch.zeros(1, 3, 224, 224),
|
||||
torch.zeros(1, dtype=torch.long),
|
||||
)
|
||||
|
||||
images_tensor = torch.stack(images)
|
||||
labels_tensor = torch.tensor(labels, dtype=torch.long)
|
||||
|
||||
return images_tensor, labels_tensor
|
||||
|
||||
|
||||
class BalancedBatchSampler(Sampler):
|
||||
"""
|
||||
Sampler that ensures each batch has a balanced distribution of logos.
|
||||
|
||||
Used with a flattened dataset where each sample is a single image.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logo_labels: List[int],
|
||||
logos_per_batch: int,
|
||||
samples_per_logo: int,
|
||||
num_batches: int,
|
||||
):
|
||||
self.logo_labels = logo_labels
|
||||
self.logos_per_batch = logos_per_batch
|
||||
self.samples_per_logo = samples_per_logo
|
||||
self.num_batches = num_batches
|
||||
|
||||
# Group indices by logo
|
||||
self.logo_to_indices: Dict[int, List[int]] = {}
|
||||
for idx, label in enumerate(logo_labels):
|
||||
if label not in self.logo_to_indices:
|
||||
self.logo_to_indices[label] = []
|
||||
self.logo_to_indices[label].append(idx)
|
||||
|
||||
self.all_logos = list(self.logo_to_indices.keys())
|
||||
|
||||
def __iter__(self):
|
||||
for _ in range(self.num_batches):
|
||||
batch_indices = []
|
||||
|
||||
# Sample logos for this batch
|
||||
logos = random.sample(
|
||||
self.all_logos,
|
||||
min(self.logos_per_batch, len(self.all_logos)),
|
||||
)
|
||||
|
||||
for logo in logos:
|
||||
indices = self.logo_to_indices[logo]
|
||||
if len(indices) >= self.samples_per_logo:
|
||||
sampled = random.sample(indices, self.samples_per_logo)
|
||||
else:
|
||||
sampled = random.choices(indices, k=self.samples_per_logo)
|
||||
batch_indices.extend(sampled)
|
||||
|
||||
yield batch_indices
|
||||
|
||||
def __len__(self):
|
||||
return self.num_batches
|
||||
|
||||
|
||||
def create_dataloaders(
|
||||
db_path: str,
|
||||
reference_dir: str,
|
||||
batch_size: int = 16,
|
||||
logos_per_batch: int = 32,
|
||||
samples_per_logo: int = 4,
|
||||
num_workers: int = 4,
|
||||
train_split: float = 0.7,
|
||||
val_split: float = 0.15,
|
||||
test_split: float = 0.15,
|
||||
seed: int = 42,
|
||||
augmentation_strength: str = "medium",
|
||||
batches_per_epoch: int = 1000,
|
||||
split_level: str = "logo",
|
||||
) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
|
||||
"""
|
||||
Create train, validation, and optionally test dataloaders.
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database
|
||||
reference_dir: Directory containing reference logo images
|
||||
batch_size: Not used directly (see logos_per_batch and samples_per_logo)
|
||||
logos_per_batch: Number of different logos per batch
|
||||
samples_per_logo: Samples per logo in batch
|
||||
num_workers: Number of data loading workers
|
||||
train_split: Fraction for training
|
||||
val_split: Fraction for validation
|
||||
test_split: Fraction for testing
|
||||
seed: Random seed
|
||||
augmentation_strength: "light", "medium", or "strong"
|
||||
batches_per_epoch: Number of batches per training epoch
|
||||
split_level: "logo" for brand-level splits, "image" for image-level splits
|
||||
|
||||
Returns:
|
||||
Tuple of (train_loader, val_loader, test_loader)
|
||||
"""
|
||||
# Load logo data
|
||||
logo_data = LogoDataset(
|
||||
db_path=db_path,
|
||||
reference_dir=reference_dir,
|
||||
train_split=train_split,
|
||||
val_split=val_split,
|
||||
test_split=test_split,
|
||||
seed=seed,
|
||||
split_level=split_level,
|
||||
)
|
||||
|
||||
# Print split info
|
||||
split_info = logo_data.get_split_info()
|
||||
print(f"Dataset loaded:")
|
||||
print(f" Split level: {split_info['split_level']}")
|
||||
print(f" Total logos: {split_info['total_logos']}")
|
||||
print(f" Train: {split_info['train_logos']} logos, {split_info['train_images']} images")
|
||||
print(f" Val: {split_info['val_logos']} logos, {split_info['val_images']} images")
|
||||
print(f" Test: {split_info['test_logos']} logos, {split_info['test_images']} images")
|
||||
|
||||
# Create datasets
|
||||
train_dataset = LogoContrastiveDataset(
|
||||
logo_data=logo_data,
|
||||
split="train",
|
||||
logos_per_batch=logos_per_batch,
|
||||
samples_per_logo=samples_per_logo,
|
||||
transform=get_train_transforms(augmentation_strength),
|
||||
batches_per_epoch=batches_per_epoch,
|
||||
)
|
||||
|
||||
val_dataset = LogoContrastiveDataset(
|
||||
logo_data=logo_data,
|
||||
split="val",
|
||||
logos_per_batch=logos_per_batch,
|
||||
samples_per_logo=samples_per_logo,
|
||||
transform=get_val_transforms(),
|
||||
batches_per_epoch=batches_per_epoch // 10, # Fewer val batches
|
||||
)
|
||||
|
||||
test_dataset = LogoContrastiveDataset(
|
||||
logo_data=logo_data,
|
||||
split="test",
|
||||
logos_per_batch=logos_per_batch,
|
||||
samples_per_logo=samples_per_logo,
|
||||
transform=get_val_transforms(),
|
||||
batches_per_epoch=batches_per_epoch // 10,
|
||||
) if test_split > 0 else None
|
||||
|
||||
# Create dataloaders
|
||||
# Note: batch_size=1 because each __getitem__ already returns a batch
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=1,
|
||||
shuffle=True,
|
||||
num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
collate_fn=_collate_contrastive_batch,
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
collate_fn=_collate_contrastive_batch,
|
||||
)
|
||||
|
||||
test_loader = None
|
||||
if test_dataset is not None:
|
||||
test_loader = DataLoader(
|
||||
test_dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
collate_fn=_collate_contrastive_batch,
|
||||
)
|
||||
|
||||
return train_loader, val_loader, test_loader
|
||||
|
||||
|
||||
def _collate_contrastive_batch(
|
||||
batch: List[Tuple[torch.Tensor, torch.Tensor]]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Collate function that unpacks pre-batched data.
|
||||
|
||||
Since LogoContrastiveDataset already returns batched data,
|
||||
we just squeeze the outer dimension.
|
||||
"""
|
||||
images, labels = batch[0]
|
||||
return images, labels
|
||||
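Note that each DataLoader element is already a full contrastive batch: the dataset returns a [K*M, 3, 224, 224] tensor and the collate function strips the outer batch_size=1 dimension. A minimal smoke test, under the same path assumptions as above:

```python
from training.dataset import create_dataloaders

train_loader, val_loader, test_loader = create_dataloaders(
    db_path="test_data_mapping.db",
    reference_dir="reference_logos",
    logos_per_batch=32,
    samples_per_logo=4,
    num_workers=0,          # simpler for a quick check
)

images, labels = next(iter(train_loader))
print(images.shape, labels.shape)   # roughly [32*4, 3, 224, 224] and [32*4]
```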
training/evaluation.py (new file, 339 lines)
@@ -0,0 +1,339 @@
|
||||
"""
|
||||
Evaluation metrics for embedding quality.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
class EmbeddingEvaluator:
|
||||
"""
|
||||
Evaluator for embedding quality metrics.
|
||||
|
||||
Computes metrics that indicate how well the embeddings
|
||||
separate different logo classes.
|
||||
"""
|
||||
|
||||
def compute_metrics(
|
||||
self,
|
||||
embeddings: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
) -> Dict[str, float]:
|
||||
"""
|
||||
Compute embedding quality metrics.
|
||||
|
||||
Args:
|
||||
embeddings: [N, D] L2-normalized embeddings
|
||||
labels: [N] integer class labels
|
||||
|
||||
Returns:
|
||||
Dict with metric names and values
|
||||
"""
|
||||
device = embeddings.device
|
||||
batch_size = embeddings.shape[0]
|
||||
|
||||
if batch_size <= 1:
|
||||
return {
|
||||
"mean_pos_sim": 0.0,
|
||||
"mean_neg_sim": 0.0,
|
||||
"separation": 0.0,
|
||||
"recall_at_1": 0.0,
|
||||
"recall_at_5": 0.0,
|
||||
}
|
||||
|
||||
# Compute similarity matrix
|
||||
similarity = embeddings @ embeddings.T
|
||||
|
||||
# Create masks
|
||||
labels_col = labels.unsqueeze(0)
|
||||
labels_row = labels.unsqueeze(1)
|
||||
positive_mask = (labels_row == labels_col).float()
|
||||
negative_mask = 1 - positive_mask
|
||||
|
||||
# Remove diagonal from positive mask
|
||||
identity = torch.eye(batch_size, device=device)
|
||||
positive_mask = positive_mask - identity
|
||||
|
||||
# Count pairs
|
||||
num_positives = positive_mask.sum()
|
||||
num_negatives = negative_mask.sum()
|
||||
|
||||
# Mean positive similarity (excluding self)
|
||||
if num_positives > 0:
|
||||
pos_sims = (similarity * positive_mask).sum() / num_positives
|
||||
mean_pos_sim = pos_sims.item()
|
||||
else:
|
||||
mean_pos_sim = 0.0
|
||||
|
||||
# Mean negative similarity
|
||||
if num_negatives > 0:
|
||||
neg_sims = (similarity * negative_mask).sum() / num_negatives
|
||||
mean_neg_sim = neg_sims.item()
|
||||
else:
|
||||
mean_neg_sim = 0.0
|
||||
|
||||
# Separation: gap between positive and negative similarity
|
||||
separation = mean_pos_sim - mean_neg_sim
|
||||
|
||||
# Recall@K metrics
|
||||
recall_at_1 = self._compute_recall_at_k(similarity, labels, k=1)
|
||||
recall_at_5 = self._compute_recall_at_k(similarity, labels, k=5)
|
||||
|
||||
return {
|
||||
"mean_pos_sim": mean_pos_sim,
|
||||
"mean_neg_sim": mean_neg_sim,
|
||||
"separation": separation,
|
||||
"recall_at_1": recall_at_1,
|
||||
"recall_at_5": recall_at_5,
|
||||
}
|
||||
|
||||
def _compute_recall_at_k(
|
||||
self,
|
||||
similarity: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
k: int = 1,
|
||||
) -> float:
|
||||
"""
|
||||
Compute Recall@K for nearest neighbor retrieval.
|
||||
|
||||
For each sample, check if the k nearest neighbors (excluding self)
|
||||
contain at least one sample with the same label.
|
||||
|
||||
Args:
|
||||
similarity: [N, N] similarity matrix
|
||||
labels: [N] class labels
|
||||
k: Number of neighbors to consider
|
||||
|
||||
Returns:
|
||||
Recall@K score (0 to 1)
|
||||
"""
|
||||
batch_size = similarity.shape[0]
|
||||
if batch_size <= 1:
|
||||
return 0.0
|
||||
|
||||
# Mask out self-similarity
|
||||
similarity = similarity.clone()
|
||||
similarity.fill_diagonal_(float("-inf"))
|
||||
|
||||
# Get top-k indices
|
||||
_, top_k_indices = similarity.topk(min(k, batch_size - 1), dim=1)
|
||||
|
||||
# Check if any of top-k have same label
|
||||
correct = 0
|
||||
for i in range(batch_size):
|
||||
query_label = labels[i]
|
||||
retrieved_labels = labels[top_k_indices[i]]
|
||||
if (retrieved_labels == query_label).any():
|
||||
correct += 1
|
||||
|
||||
return correct / batch_size
|
||||
|
||||
def compute_detailed_metrics(
|
||||
self,
|
||||
embeddings: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
label_names: Optional[List[str]] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
Compute detailed per-class metrics.
|
||||
|
||||
Args:
|
||||
embeddings: [N, D] embeddings
|
||||
labels: [N] class labels
|
||||
label_names: Optional list of label names
|
||||
|
||||
Returns:
|
||||
Dict with detailed metrics including per-class stats
|
||||
"""
|
||||
basic_metrics = self.compute_metrics(embeddings, labels)
|
||||
|
||||
# Per-class statistics
|
||||
unique_labels = labels.unique()
|
||||
per_class_stats = {}
|
||||
|
||||
similarity = embeddings @ embeddings.T
|
||||
|
||||
for label in unique_labels:
|
||||
mask = labels == label
|
||||
class_embeddings = embeddings[mask]
|
||||
class_size = mask.sum().item()
|
||||
|
||||
if class_size > 1:
|
||||
# Intra-class similarity
|
||||
class_sim = class_embeddings @ class_embeddings.T
|
||||
# Exclude diagonal
|
||||
mask_diag = ~torch.eye(class_size, dtype=torch.bool, device=class_sim.device)
|
||||
intra_sim = class_sim[mask_diag].mean().item()
|
||||
else:
|
||||
intra_sim = 1.0
|
||||
|
||||
# Inter-class similarity (to other classes)
|
||||
other_mask = labels != label
|
||||
if other_mask.any():
|
||||
inter_sim = similarity[mask][:, other_mask].mean().item()
|
||||
else:
|
||||
inter_sim = 0.0
|
||||
|
||||
class_name = label_names[label.item()] if label_names else str(label.item())
|
||||
per_class_stats[class_name] = {
|
||||
"size": class_size,
|
||||
"intra_class_sim": intra_sim,
|
||||
"inter_class_sim": inter_sim,
|
||||
"class_separation": intra_sim - inter_sim,
|
||||
}
|
||||
|
||||
# Aggregate per-class stats
|
||||
if per_class_stats:
|
||||
separations = [s["class_separation"] for s in per_class_stats.values()]
|
||||
min_separation = min(separations)
|
||||
max_separation = max(separations)
|
||||
std_separation = np.std(separations)
|
||||
else:
|
||||
min_separation = max_separation = std_separation = 0.0
|
||||
|
||||
return {
|
||||
**basic_metrics,
|
||||
"per_class": per_class_stats,
|
||||
"min_class_separation": min_separation,
|
||||
"max_class_separation": max_separation,
|
||||
"std_class_separation": std_separation,
|
||||
}
|
||||
|
||||
|
||||
class SimilarityAnalyzer:
|
||||
"""
|
||||
Analyze similarity distributions for debugging and tuning.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def analyze_similarity_distribution(
|
||||
embeddings: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
) -> Dict[str, np.ndarray]:
|
||||
"""
|
||||
Get similarity distributions for positive and negative pairs.
|
||||
|
||||
Useful for choosing appropriate thresholds.
|
||||
|
||||
Args:
|
||||
embeddings: [N, D] embeddings
|
||||
labels: [N] class labels
|
||||
|
||||
Returns:
|
||||
Dict with 'positive_sims' and 'negative_sims' arrays
|
||||
"""
|
||||
similarity = (embeddings @ embeddings.T).cpu().numpy()
|
||||
labels_np = labels.cpu().numpy()
|
||||
|
||||
batch_size = len(labels_np)
|
||||
positive_sims = []
|
||||
negative_sims = []
|
||||
|
||||
for i in range(batch_size):
|
||||
for j in range(i + 1, batch_size):
|
||||
if labels_np[i] == labels_np[j]:
|
||||
positive_sims.append(similarity[i, j])
|
||||
else:
|
||||
negative_sims.append(similarity[i, j])
|
||||
|
||||
return {
|
||||
"positive_sims": np.array(positive_sims),
|
||||
"negative_sims": np.array(negative_sims),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def find_hard_pairs(
|
||||
embeddings: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
n_hard: int = 10,
|
||||
) -> Tuple[List[Tuple[int, int, float]], List[Tuple[int, int, float]]]:
|
||||
"""
|
||||
Find hardest positive and negative pairs.
|
||||
|
||||
Hard positives: same label but low similarity
|
||||
Hard negatives: different label but high similarity
|
||||
|
||||
Args:
|
||||
embeddings: [N, D] embeddings
|
||||
labels: [N] class labels
|
||||
n_hard: Number of hard pairs to return
|
||||
|
||||
Returns:
|
||||
Tuple of (hard_positives, hard_negatives)
|
||||
Each is a list of (idx1, idx2, similarity) tuples
|
||||
"""
|
||||
similarity = embeddings @ embeddings.T
|
||||
batch_size = len(labels)
|
||||
|
||||
hard_positives = [] # Low similarity, same label
|
||||
hard_negatives = [] # High similarity, different label
|
||||
|
||||
for i in range(batch_size):
|
||||
for j in range(i + 1, batch_size):
|
||||
sim = similarity[i, j].item()
|
||||
if labels[i] == labels[j]:
|
||||
hard_positives.append((i, j, sim))
|
||||
else:
|
||||
hard_negatives.append((i, j, sim))
|
||||
|
||||
# Sort: hard positives by ascending similarity (lowest first)
|
||||
hard_positives.sort(key=lambda x: x[2])
|
||||
|
||||
# Sort: hard negatives by descending similarity (highest first)
|
||||
hard_negatives.sort(key=lambda x: -x[2])
|
||||
|
||||
return hard_positives[:n_hard], hard_negatives[:n_hard]
|
||||
|
||||
@staticmethod
|
||||
def compute_confusion_pairs(
|
||||
embeddings: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
label_names: Optional[List[str]] = None,
|
||||
top_k: int = 10,
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Find pairs of classes that are most confused (highest cross-class similarity).
|
||||
|
||||
Args:
|
||||
embeddings: [N, D] embeddings
|
||||
labels: [N] class labels
|
||||
label_names: Optional label names
|
||||
top_k: Number of confused pairs to return
|
||||
|
||||
Returns:
|
||||
List of dicts with class pairs and their similarity
|
||||
"""
|
||||
unique_labels = labels.unique()
|
||||
class_centroids = {}
|
||||
|
||||
# Compute class centroids
|
||||
for label in unique_labels:
|
||||
mask = labels == label
|
||||
centroid = embeddings[mask].mean(dim=0)
|
||||
centroid = F.normalize(centroid, dim=0)
|
||||
class_centroids[label.item()] = centroid
|
||||
|
||||
# Compute pairwise centroid similarities
|
||||
confusions = []
|
||||
label_list = list(class_centroids.keys())
|
||||
|
||||
for i, label1 in enumerate(label_list):
|
||||
for label2 in label_list[i + 1:]:
|
||||
sim = (class_centroids[label1] @ class_centroids[label2]).item()
|
||||
name1 = label_names[label1] if label_names else str(label1)
|
||||
name2 = label_names[label2] if label_names else str(label2)
|
||||
confusions.append({
|
||||
"class1": name1,
|
||||
"class2": name2,
|
||||
"label1": label1,
|
||||
"label2": label2,
|
||||
"centroid_similarity": sim,
|
||||
})
|
||||
|
||||
# Sort by similarity (highest first)
|
||||
confusions.sort(key=lambda x: -x["centroid_similarity"])
|
||||
|
||||
return confusions[:top_k]
|
||||
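A toy example of the evaluator on four hand-made embeddings (two classes, two samples each), illustrating the separation and Recall@1 metrics:

```python
import torch
import torch.nn.functional as F

from training.evaluation import EmbeddingEvaluator

embeddings = F.normalize(torch.tensor([
    [1.0, 0.0],    # class 0
    [0.9, 0.1],    # class 0
    [0.0, 1.0],    # class 1
    [0.1, 0.9],    # class 1
]), dim=1)
labels = torch.tensor([0, 0, 1, 1])

metrics = EmbeddingEvaluator().compute_metrics(embeddings, labels)
print(metrics["separation"])    # positive pairs are much more similar than negative pairs
print(metrics["recall_at_1"])   # 1.0: every sample's nearest neighbour shares its label
```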
training/losses.py (new file, 326 lines)
@@ -0,0 +1,326 @@
|
||||
"""
|
||||
Loss functions for contrastive learning on logo embeddings.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class InfoNCELoss(nn.Module):
|
||||
"""
|
||||
Normalized Temperature-scaled Cross Entropy Loss (InfoNCE).
|
||||
|
||||
This is the contrastive loss used in CLIP training. It maximizes
|
||||
similarity between embeddings of the same logo class while
|
||||
minimizing similarity to embeddings of different classes.
|
||||
|
||||
For a batch with N samples:
|
||||
- Each sample is an anchor
|
||||
- Positive pairs: samples with the same label
|
||||
- Negative pairs: samples with different labels
|
||||
|
||||
The loss for each anchor is:
|
||||
-(1/|P(anchor)|) * sum over pos in P(anchor) of log( exp(sim(anchor, pos)/temp) / sum over all j != anchor of exp(sim(anchor, j)/temp) )
|
||||
"""
|
||||
|
||||
def __init__(self, temperature: float = 0.07):
|
||||
"""
|
||||
Initialize InfoNCE loss.
|
||||
|
||||
Args:
|
||||
temperature: Scaling factor for similarities (0.05-0.1 typical).
|
||||
Lower temperature makes the distribution sharper.
|
||||
"""
|
||||
super().__init__()
|
||||
self.temperature = temperature
|
||||
|
||||
def forward(
|
||||
self,
|
||||
embeddings: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute InfoNCE loss for a batch of embeddings.
|
||||
|
||||
Args:
|
||||
embeddings: [N, D] L2-normalized embeddings
|
||||
labels: [N] integer logo class labels
|
||||
|
||||
Returns:
|
||||
Scalar loss value
|
||||
"""
|
||||
device = embeddings.device
|
||||
batch_size = embeddings.shape[0]
|
||||
|
||||
if batch_size <= 1:
|
||||
return torch.tensor(0.0, device=device, requires_grad=True)
|
||||
|
||||
# Compute similarity matrix [N, N]
|
||||
# Since embeddings are L2-normalized, dot product = cosine similarity
|
||||
similarity = embeddings @ embeddings.T / self.temperature
|
||||
|
||||
# Create positive mask: same label = 1, different = 0
|
||||
labels_col = labels.unsqueeze(0) # [1, N]
|
||||
labels_row = labels.unsqueeze(1) # [N, 1]
|
||||
positive_mask = (labels_row == labels_col).float() # [N, N]
|
||||
|
||||
# Remove self-similarity from positives (diagonal)
|
||||
identity = torch.eye(batch_size, device=device)
|
||||
positive_mask = positive_mask - identity
|
||||
|
||||
# Count positives per anchor (avoid division by zero)
|
||||
num_positives = positive_mask.sum(dim=1)
|
||||
has_positives = num_positives > 0
|
||||
|
||||
# If no positives exist for any anchor, return zero loss
|
||||
if not has_positives.any():
|
||||
return torch.tensor(0.0, device=device, requires_grad=True)
|
||||
|
||||
# Mask out self-similarity with large negative value
|
||||
similarity = similarity - identity * 1e9
|
||||
|
||||
# Compute log-softmax over similarities
|
||||
log_softmax = F.log_softmax(similarity, dim=1)
|
||||
|
||||
# Sum log probabilities of positive pairs
|
||||
positive_log_probs = (log_softmax * positive_mask).sum(dim=1)
|
||||
|
||||
# Average over number of positives (only for anchors with positives)
|
||||
loss_per_anchor = torch.zeros(batch_size, device=device)
|
||||
loss_per_anchor[has_positives] = (
|
||||
-positive_log_probs[has_positives] / num_positives[has_positives]
|
||||
)
|
||||
|
||||
return loss_per_anchor.mean()
|
||||
|
||||
|
||||
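A quick numerical sanity check of the loss on a tiny batch of normalized embeddings: when same-class samples are identical vectors and the other class is orthogonal, the loss should be near zero, while mismatched labels should produce a much larger value.

```python
import torch
import torch.nn.functional as F

from training.losses import InfoNCELoss

loss_fn = InfoNCELoss(temperature=0.07)

embeddings = F.normalize(torch.tensor([
    [1.0, 0.0],
    [1.0, 0.0],
    [0.0, 1.0],
    [0.0, 1.0],
]), dim=1)

good_labels = torch.tensor([0, 0, 1, 1])   # positives are identical vectors
bad_labels = torch.tensor([0, 1, 0, 1])    # positives are orthogonal vectors

print(loss_fn(embeddings, good_labels))    # close to 0
print(loss_fn(embeddings, bad_labels))     # much larger
```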
class SupConLoss(nn.Module):
|
||||
"""
|
||||
Supervised Contrastive Loss.
|
||||
|
||||
Similar to InfoNCE but uses a different formulation that
|
||||
considers each positive pair separately rather than averaging.
|
||||
|
||||
Reference: https://arxiv.org/abs/2004.11362
|
||||
"""
|
||||
|
||||
def __init__(self, temperature: float = 0.07):
|
||||
super().__init__()
|
||||
self.temperature = temperature
|
||||
|
||||
def forward(
|
||||
self,
|
||||
embeddings: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute Supervised Contrastive loss.
|
||||
|
||||
Args:
|
||||
embeddings: [N, D] L2-normalized embeddings
|
||||
labels: [N] integer logo class labels
|
||||
|
||||
Returns:
|
||||
Scalar loss value
|
||||
"""
|
||||
device = embeddings.device
|
||||
batch_size = embeddings.shape[0]
|
||||
|
||||
if batch_size <= 1:
|
||||
return torch.tensor(0.0, device=device, requires_grad=True)
|
||||
|
||||
# Compute similarity matrix
|
||||
similarity = embeddings @ embeddings.T / self.temperature
|
||||
|
||||
# Create masks
|
||||
labels_col = labels.unsqueeze(0)
|
||||
labels_row = labels.unsqueeze(1)
|
||||
positive_mask = (labels_row == labels_col).float()
|
||||
identity = torch.eye(batch_size, device=device)
|
||||
|
||||
# Remove self from positives
|
||||
positive_mask = positive_mask - identity
|
||||
|
||||
# Number of positives per anchor
|
||||
num_positives = positive_mask.sum(dim=1)
|
||||
has_positives = num_positives > 0
|
||||
|
||||
if not has_positives.any():
|
||||
return torch.tensor(0.0, device=device, requires_grad=True)
|
||||
|
||||
# For numerical stability, subtract max similarity
|
||||
sim_max, _ = similarity.max(dim=1, keepdim=True)
|
||||
similarity = similarity - sim_max.detach()
|
||||
|
||||
# Compute exp(similarity) with self masked out
|
||||
exp_sim = torch.exp(similarity) * (1 - identity)
|
||||
|
||||
# Denominator: sum of exp over all pairs except self
|
||||
log_prob = similarity - torch.log(exp_sim.sum(dim=1, keepdim=True) + 1e-8)
|
||||
|
||||
# Mean of log-prob over positive pairs
|
||||
mean_log_prob_pos = (positive_mask * log_prob).sum(dim=1) / (
|
||||
num_positives + 1e-8
|
||||
)
|
||||
|
||||
# Loss is negative mean log probability
|
||||
loss = -mean_log_prob_pos[has_positives].mean()
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class TripletLoss(nn.Module):
|
||||
"""
|
||||
Triplet loss with online hard mining.
|
||||
|
||||
For each anchor:
|
||||
- Hardest positive: most distant sample with same label
|
||||
- Hardest negative: closest sample with different label
|
||||
|
||||
Loss = max(0, d(anchor, hardest_pos) - d(anchor, hardest_neg) + margin)
|
||||
|
||||
This is an alternative to InfoNCE for when batch sizes are small.
|
||||
"""
|
||||
|
||||
def __init__(self, margin: float = 0.3):
|
||||
"""
|
||||
Initialize Triplet loss.
|
||||
|
||||
Args:
|
||||
margin: Minimum required gap between positive and negative distances
|
||||
"""
|
||||
super().__init__()
|
||||
self.margin = margin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
embeddings: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute triplet loss with online hard mining.
|
||||
|
||||
Args:
|
||||
embeddings: [N, D] L2-normalized embeddings
|
||||
labels: [N] integer logo class labels
|
||||
|
||||
Returns:
|
||||
Scalar loss value
|
||||
"""
|
||||
device = embeddings.device
|
||||
batch_size = embeddings.shape[0]
|
||||
|
||||
if batch_size <= 1:
|
||||
return torch.tensor(0.0, device=device, requires_grad=True)
|
||||
|
||||
# Compute pairwise cosine distances (1 - cosine_similarity)
|
||||
# For normalized vectors: distance = 1 - dot_product
|
||||
similarity = embeddings @ embeddings.T
|
||||
distances = 1 - similarity
|
||||
|
||||
# Create masks
|
||||
labels_col = labels.unsqueeze(0)
|
||||
labels_row = labels.unsqueeze(1)
|
||||
positive_mask = (labels_row == labels_col).float()
|
||||
negative_mask = 1 - positive_mask
|
||||
|
||||
# Remove self from positives (diagonal)
|
||||
identity = torch.eye(batch_size, device=device)
|
||||
positive_mask = positive_mask - identity
|
||||
|
||||
# Check if we have any valid triplets
|
||||
has_positives = positive_mask.sum(dim=1) > 0
|
||||
has_negatives = negative_mask.sum(dim=1) > 0
|
||||
valid_anchors = has_positives & has_negatives
|
||||
|
||||
if not valid_anchors.any():
|
||||
return torch.tensor(0.0, device=device, requires_grad=True)
|
||||
|
||||
# For each anchor, find hardest positive (max distance among positives)
|
||||
# Set negatives to -inf so they don't affect max
|
||||
pos_distances = distances.clone()
|
||||
pos_distances[positive_mask == 0] = float("-inf")
|
||||
hardest_positive, _ = pos_distances.max(dim=1)
|
||||
|
||||
# For each anchor, find hardest negative (min distance among negatives)
|
||||
# Set positives to inf so they don't affect min
|
||||
neg_distances = distances.clone()
|
||||
neg_distances[negative_mask == 0] = float("inf")
|
||||
hardest_negative, _ = neg_distances.min(dim=1)
|
||||
|
||||
# Triplet loss: want positive to be closer than negative by margin
|
||||
triplet_loss = F.relu(
|
||||
hardest_positive - hardest_negative + self.margin
|
||||
)
|
||||
|
||||
# Average over valid anchors only
|
||||
loss = triplet_loss[valid_anchors].mean()
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class CombinedLoss(nn.Module):
|
||||
"""
|
||||
Combined loss function with weighted InfoNCE and Triplet losses.
|
||||
|
||||
Can help stabilize training by combining the benefits of both losses.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
temperature: float = 0.07,
|
||||
triplet_margin: float = 0.3,
|
||||
infonce_weight: float = 1.0,
|
||||
triplet_weight: float = 0.5,
|
||||
):
|
||||
super().__init__()
|
||||
self.infonce = InfoNCELoss(temperature=temperature)
|
||||
self.triplet = TripletLoss(margin=triplet_margin)
|
||||
self.infonce_weight = infonce_weight
|
||||
self.triplet_weight = triplet_weight
|
||||
|
||||
def forward(
|
||||
self,
|
||||
embeddings: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
infonce_loss = self.infonce(embeddings, labels)
|
||||
triplet_loss = self.triplet(embeddings, labels)
|
||||
|
||||
return (
|
||||
self.infonce_weight * infonce_loss +
|
||||
self.triplet_weight * triplet_loss
|
||||
)
|
||||
|
||||
|
||||
def get_loss_function(
|
||||
loss_type: str = "infonce",
|
||||
temperature: float = 0.07,
|
||||
triplet_margin: float = 0.3,
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Factory function to create loss function.
|
||||
|
||||
Args:
|
||||
loss_type: One of "infonce", "supcon", "triplet", or "combined"
|
||||
temperature: Temperature for InfoNCE/SupCon
|
||||
triplet_margin: Margin for triplet loss
|
||||
|
||||
Returns:
|
||||
Loss function module
|
||||
"""
|
||||
if loss_type == "infonce":
|
||||
return InfoNCELoss(temperature=temperature)
|
||||
elif loss_type == "supcon":
|
||||
return SupConLoss(temperature=temperature)
|
||||
elif loss_type == "triplet":
|
||||
return TripletLoss(margin=triplet_margin)
|
||||
elif loss_type == "combined":
|
||||
return CombinedLoss(
|
||||
temperature=temperature,
|
||||
triplet_margin=triplet_margin,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown loss type: {loss_type}")
|
||||
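The factory mirrors the `loss_type` config field; for example, the combined option weights the InfoNCE and triplet terms as defined above:

```python
from training.losses import get_loss_function

infonce = get_loss_function("infonce", temperature=0.07)
combined = get_loss_function("combined", temperature=0.07, triplet_margin=0.3)

# Both objects are nn.Modules taking (embeddings, labels) and returning a scalar loss,
# so they can be swapped without touching the training loop.
```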
training/model.py (new file, 351 lines)
@@ -0,0 +1,351 @@
|
||||
"""
|
||||
Fine-tunable CLIP model wrapper with LoRA support.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
# Check if peft is available for LoRA
|
||||
try:
|
||||
from peft import LoraConfig, get_peft_model, PeftModel
|
||||
PEFT_AVAILABLE = True
|
||||
except ImportError:
|
||||
PEFT_AVAILABLE = False
|
||||
LoraConfig = None
|
||||
get_peft_model = None
|
||||
PeftModel = None
|
||||
|
||||
|
||||
class LogoFineTunedCLIP(nn.Module):
|
||||
"""
|
||||
CLIP vision encoder fine-tuned for logo similarity.
|
||||
|
||||
Preserves embedding interface for compatibility with DetectLogosDETR:
|
||||
- Same embedding dimensionality (768 for ViT-L/14)
|
||||
- L2 normalized outputs
|
||||
- Works with existing get_image_features() pattern
|
||||
|
||||
Supports:
|
||||
- LoRA for memory-efficient fine-tuning
|
||||
- Layer freezing for transfer learning
|
||||
- Gradient checkpointing for memory optimization
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_model: nn.Module,
|
||||
lora_r: int = 16,
|
||||
lora_alpha: int = 32,
|
||||
lora_dropout: float = 0.1,
|
||||
freeze_layers: int = 12,
|
||||
use_gradient_checkpointing: bool = True,
|
||||
add_projection_head: bool = True,
|
||||
):
|
||||
"""
|
||||
Initialize the fine-tunable CLIP wrapper.
|
||||
|
||||
Args:
|
||||
vision_model: CLIP vision model (CLIPVisionModel)
|
||||
lora_r: Rank of LoRA low-rank matrices (0 to disable)
|
||||
lora_alpha: LoRA scaling factor
|
||||
lora_dropout: Dropout for LoRA layers
|
||||
freeze_layers: Number of transformer layers to freeze (from bottom)
|
||||
use_gradient_checkpointing: Enable gradient checkpointing
|
||||
add_projection_head: Add trainable projection head
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.vision_model = vision_model
|
||||
self.embedding_dim = vision_model.config.hidden_size
|
||||
self.freeze_layers = freeze_layers
|
||||
self.lora_r = lora_r
|
||||
self.lora_alpha = lora_alpha
|
||||
|
||||
# Enable gradient checkpointing for memory efficiency
|
||||
if use_gradient_checkpointing:
|
||||
if hasattr(self.vision_model, "gradient_checkpointing_enable"):
|
||||
self.vision_model.gradient_checkpointing_enable()
|
||||
|
||||
# Freeze lower layers
|
||||
self._freeze_layers(freeze_layers)
|
||||
|
||||
# Apply LoRA to attention layers in upper blocks
|
||||
self.peft_applied = False
|
||||
if PEFT_AVAILABLE and lora_r > 0:
|
||||
self._apply_lora(lora_r, lora_alpha, lora_dropout)
|
||||
self.peft_applied = True
|
||||
elif lora_r > 0 and not PEFT_AVAILABLE:
|
||||
print(
|
||||
"Warning: peft not installed. LoRA disabled. "
|
||||
"Install with: pip install peft"
|
||||
)
|
||||
|
||||
# Optional projection head for fine-tuning
|
||||
self.add_projection_head = add_projection_head
|
||||
if add_projection_head:
|
||||
self.projection = nn.Sequential(
|
||||
nn.Linear(self.embedding_dim, self.embedding_dim),
|
||||
nn.LayerNorm(self.embedding_dim),
|
||||
)
|
||||
else:
|
||||
self.projection = nn.Identity()
|
||||
|
||||
def _freeze_layers(self, num_layers: int) -> None:
|
||||
"""Freeze the first N transformer layers and embeddings."""
|
||||
if num_layers <= 0:
|
||||
return
|
||||
|
||||
# Freeze embeddings
|
||||
if hasattr(self.vision_model, "embeddings"):
|
||||
for param in self.vision_model.embeddings.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# Freeze specified number of encoder layers
|
||||
if hasattr(self.vision_model, "encoder"):
|
||||
for i, layer in enumerate(self.vision_model.encoder.layers):
|
||||
if i < num_layers:
|
||||
for param in layer.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
    def _apply_lora(
        self,
        r: int,
        alpha: int,
        dropout: float,
    ) -> None:
        """Apply LoRA adapters to attention layers."""
        if not PEFT_AVAILABLE:
            return

        # Configure LoRA for vision transformer
        lora_config = LoraConfig(
            r=r,
            lora_alpha=alpha,
            lora_dropout=dropout,
            target_modules=["q_proj", "v_proj"],
            bias="none",
            modules_to_save=[],  # Don't save any full modules
        )

        self.vision_model = get_peft_model(self.vision_model, lora_config)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Extract normalized embeddings for logo images.

        Args:
            pixel_values: [batch, 3, 224, 224] preprocessed images

        Returns:
            embeddings: [batch, embedding_dim] L2-normalized
        """
        # Get vision features
        outputs = self.vision_model(pixel_values=pixel_values)

        # Use pooler output (CLS token projection) if available
        if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
            features = outputs.pooler_output
        else:
            # Fall back to CLS token from last hidden state
            features = outputs.last_hidden_state[:, 0, :]

        # Apply projection head
        features = self.projection(features)

        # L2 normalize for cosine similarity
        features = F.normalize(features, dim=-1)

        return features

    def get_image_features(self, **kwargs) -> torch.Tensor:
        """
        Compatibility method matching CLIP's interface.

        Used by DetectLogosDETR._get_embedding_pil().
        """
        return self.forward(kwargs["pixel_values"])

    def get_trainable_parameters(self) -> List[torch.nn.Parameter]:
        """Return list of trainable parameters."""
        return [p for p in self.parameters() if p.requires_grad]

    def get_parameter_count(self) -> Dict[str, int]:
        """Return count of trainable and total parameters."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {
            "total": total,
            "trainable": trainable,
            "frozen": total - trainable,
            "trainable_percent": 100 * trainable / total if total > 0 else 0,
        }

    def save_pretrained(self, output_dir: str) -> None:
        """
        Save model in HuggingFace-compatible format.

        Args:
            output_dir: Directory to save model files
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Save model weights
        if self.peft_applied and PEFT_AVAILABLE:
            # Save LoRA weights separately
            self.vision_model.save_pretrained(output_path / "vision_lora")
            # Save projection head
            torch.save(
                self.projection.state_dict(),
                output_path / "projection_head.bin",
            )
        else:
            # Save full model state
            torch.save(self.state_dict(), output_path / "pytorch_model.bin")

        # Save config
        config = {
            "model_type": "clip_logo_finetuned",
            "embedding_dim": self.embedding_dim,
            "lora_r": self.lora_r,
            "lora_alpha": self.lora_alpha,
            "freeze_layers": self.freeze_layers,
            "add_projection_head": self.add_projection_head,
            "peft_applied": self.peft_applied,
        }

        with open(output_path / "config.json", "w") as f:
            json.dump(config, f, indent=2)

    @classmethod
    def from_pretrained(
        cls,
        model_path: str,
        base_model: str = "openai/clip-vit-large-patch14",
        device: Optional[torch.device] = None,
    ) -> "LogoFineTunedCLIP":
        """
        Load a fine-tuned model from saved weights.

        Args:
            model_path: Path to saved model directory
            base_model: Base CLIP model name (for architecture)
            device: Device to load model on

        Returns:
            Loaded LogoFineTunedCLIP model
        """
        model_path = Path(model_path)

        # Load config
        with open(model_path / "config.json", "r") as f:
            config = json.load(f)

        # Load base CLIP model
        clip_model = CLIPModel.from_pretrained(base_model)

        # Check if we need to load LoRA weights
        if config.get("peft_applied", False) and PEFT_AVAILABLE:
            # Create model WITHOUT LoRA (lora_r=0) - we'll load LoRA weights separately
            model = cls(
                vision_model=clip_model.vision_model,
                lora_r=0,  # Don't apply LoRA in constructor
                lora_alpha=config.get("lora_alpha", 1),
                freeze_layers=config.get("freeze_layers", 12),
                add_projection_head=config.get("add_projection_head", True),
                use_gradient_checkpointing=False,
            )

            # Load LoRA weights from checkpoint
            lora_path = model_path / "vision_lora"
            if lora_path.exists():
                model.vision_model = PeftModel.from_pretrained(
                    model.vision_model, lora_path
                )
                model.peft_applied = True
                model.lora_r = config.get("lora_r", 16)

            # Load projection head
            proj_path = model_path / "projection_head.bin"
            if proj_path.exists():
                model.projection.load_state_dict(
                    torch.load(proj_path, map_location="cpu")
                )
        else:
            # No LoRA - create model and load full state
            model = cls(
                vision_model=clip_model.vision_model,
                lora_r=0,
                lora_alpha=config.get("lora_alpha", 1),
                freeze_layers=config.get("freeze_layers", 12),
                add_projection_head=config.get("add_projection_head", True),
                use_gradient_checkpointing=False,
            )

            weights_path = model_path / "pytorch_model.bin"
            if weights_path.exists():
                model.load_state_dict(
                    torch.load(weights_path, map_location="cpu")
                )

        if device is not None:
            model = model.to(device)

        return model


def create_model(
    base_model: str = "openai/clip-vit-large-patch14",
    lora_r: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.1,
    freeze_layers: int = 12,
    use_gradient_checkpointing: bool = True,
    device: Optional[torch.device] = None,
) -> Tuple[LogoFineTunedCLIP, CLIPProcessor]:
    """
    Create a fine-tunable CLIP model and processor.

    Args:
        base_model: HuggingFace model name or path
        lora_r: LoRA rank (0 to disable)
        lora_alpha: LoRA scaling factor
        lora_dropout: LoRA dropout
        freeze_layers: Number of layers to freeze
        use_gradient_checkpointing: Enable gradient checkpointing
        device: Device to load model on

    Returns:
        Tuple of (model, processor)
    """
    # Load base CLIP model
    clip_model = CLIPModel.from_pretrained(base_model)
    processor = CLIPProcessor.from_pretrained(base_model)

    # Create fine-tunable wrapper
    model = LogoFineTunedCLIP(
        vision_model=clip_model.vision_model,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        freeze_layers=freeze_layers,
        use_gradient_checkpointing=use_gradient_checkpointing,
    )

    if device is not None:
        model = model.to(device)

    # Print parameter info
    param_info = model.get_parameter_count()
    print("Model created:")
    print(f"  Total parameters: {param_info['total']:,}")
    print(f"  Trainable: {param_info['trainable']:,} ({param_info['trainable_percent']:.2f}%)")
    print(f"  Frozen: {param_info['frozen']:,}")

    return model, processor
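
For reference, a minimal usage sketch of the wrapper above — assuming the imports at the top of `training/model.py` (torch, CLIPModel/CLIPProcessor, etc.) are available. The logo crop path and the export directory below are hypothetical placeholders, not paths produced by this repo:

```python
import torch
from PIL import Image

from training.model import LogoFineTunedCLIP, create_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# LoRA-wrapped ViT-L/14 vision encoder plus its matching CLIPProcessor
model, processor = create_model(lora_r=16, freeze_layers=12, device=device)
model.eval()

# Embed a cropped logo the same way DetectLogosDETR consumes embeddings
image = Image.open("some_category/some_brand/crop.jpg")  # hypothetical path
pixel_values = processor(images=image, return_tensors="pt")["pixel_values"].to(device)
with torch.no_grad():
    embedding = model.get_image_features(pixel_values=pixel_values)
print(embedding.shape)  # torch.Size([1, 768]); already L2-normalized

# Round-trip: save LoRA adapters + projection head, then reload for inference
model.save_pretrained("exported_clip_logo")  # hypothetical output directory
reloaded = LogoFineTunedCLIP.from_pretrained("exported_clip_logo", device=device)
```
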
400
training/trainer.py
Normal file
@ -0,0 +1,400 @@
"""
|
||||
Training loop with checkpointing, mixed precision, and evaluation.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from .config import TrainingConfig
|
||||
from .losses import get_loss_function
|
||||
from .evaluation import EmbeddingEvaluator
|
||||
|
||||
# Check if amp is available
|
||||
try:
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
AMP_AVAILABLE = True
|
||||
except ImportError:
|
||||
AMP_AVAILABLE = False
|
||||
autocast = None
|
||||
GradScaler = None
|
||||
|
||||
|
||||
class Trainer:
    """
    Trainer for fine-tuning CLIP on logo recognition.

    Features:
    - Mixed precision training (FP16)
    - Gradient accumulation
    - Gradient checkpointing (via model)
    - One-cycle LR schedule with cosine annealing
    - Early stopping
    - Checkpoint saving/loading
    - Evaluation during training
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        config: TrainingConfig,
        logger: Optional[logging.Logger] = None,
    ):
        """
        Initialize the trainer.

        Args:
            model: LogoFineTunedCLIP model
            train_loader: Training dataloader
            val_loader: Validation dataloader
            config: Training configuration
            logger: Optional logger instance
        """
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        self.logger = logger or logging.getLogger(__name__)

        # Device setup
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model.to(self.device)
        self.logger.info(f"Using device: {self.device}")

        # Optimizer - only trainable parameters
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        self.logger.info(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")

        self.optimizer = AdamW(
            trainable_params,
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )

        # Learning rate scheduler. The scheduler is stepped once per optimizer
        # step (after each accumulation window), so total_steps is counted in
        # optimizer steps rather than in batches.
        steps_per_epoch = max(len(train_loader) // config.gradient_accumulation_steps, 1)
        total_steps = steps_per_epoch * config.max_epochs
        self.scheduler = OneCycleLR(
            self.optimizer,
            max_lr=config.learning_rate,
            total_steps=total_steps,
            pct_start=config.warmup_steps / total_steps if total_steps > 0 else 0.1,
            anneal_strategy="cos",
        )

        # Mixed precision training
        self.use_amp = config.mixed_precision and AMP_AVAILABLE and self.device.type == "cuda"
        if self.use_amp:
            self.scaler = GradScaler()
            self.logger.info("Mixed precision training enabled")
        else:
            self.scaler = None
            if config.mixed_precision and not AMP_AVAILABLE:
                self.logger.warning("Mixed precision requested but not available")

        # Loss function
        self.criterion = get_loss_function(
            loss_type=config.loss_type,
            temperature=config.temperature,
            triplet_margin=config.triplet_margin,
        )

        # Evaluator
        self.evaluator = EmbeddingEvaluator()

        # Training state
        self.epoch = 0
        self.global_step = 0
        self.best_val_loss = float("inf")
        self.best_val_separation = float("-inf")
        self.patience_counter = 0
        self.training_history = []

    def train(self) -> Dict[str, float]:
        """
        Main training loop.

        Returns:
            Dict with final training metrics
        """
        self.logger.info("Starting training...")
        self.logger.info(f"  Epochs: {self.config.max_epochs}")
        self.logger.info(f"  Batch size: {self.config.batch_size}")
        self.logger.info(f"  Gradient accumulation: {self.config.gradient_accumulation_steps}")
        self.logger.info(f"  Effective batch: {self.config.effective_batch_size}")
        self.logger.info(f"  Learning rate: {self.config.learning_rate}")

        start_time = time.time()

        for epoch in range(self.epoch, self.config.max_epochs):
            self.epoch = epoch
            self.logger.info(f"\nEpoch {epoch + 1}/{self.config.max_epochs}")

            # Training epoch
            train_metrics = self._train_epoch()
            self.logger.info(
                f"Train - Loss: {train_metrics['loss']:.4f}, "
                f"LR: {train_metrics['lr']:.2e}"
            )

            # Validation
            if (epoch + 1) % self.config.eval_every_n_epochs == 0:
                val_metrics = self._validate()
                self.logger.info(
                    f"Val - Loss: {val_metrics['loss']:.4f}, "
                    f"Pos Sim: {val_metrics['mean_pos_sim']:.3f}, "
                    f"Neg Sim: {val_metrics['mean_neg_sim']:.3f}, "
                    f"Separation: {val_metrics['separation']:.3f}"
                )

                # Record history
                self.training_history.append({
                    "epoch": epoch + 1,
                    "train_loss": train_metrics["loss"],
                    "val_loss": val_metrics["loss"],
                    "val_separation": val_metrics["separation"],
                    "val_pos_sim": val_metrics["mean_pos_sim"],
                    "val_neg_sim": val_metrics["mean_neg_sim"],
                })

                # Checkpointing based on separation (gap between pos and neg similarity)
                # This is the key metric for contrastive learning quality
                if val_metrics["separation"] > self.best_val_separation + self.config.min_delta:
                    self.best_val_separation = val_metrics["separation"]
                    self.best_val_loss = val_metrics["loss"]  # Track for reference
                    self.patience_counter = 0
                    self._save_checkpoint("best.pt")
                    self.logger.info("New best model saved!")
                else:
                    self.patience_counter += 1

                # Early stopping
                if self.patience_counter >= self.config.patience:
                    self.logger.info(
                        f"Early stopping triggered at epoch {epoch + 1} "
                        f"(no improvement for {self.config.patience} epochs)"
                    )
                    break

            # Periodic checkpoint
            if (epoch + 1) % self.config.save_every_n_epochs == 0:
                self._save_checkpoint(f"epoch_{epoch + 1}.pt")

        # Training complete
        total_time = time.time() - start_time
        self.logger.info(f"\nTraining completed in {total_time / 60:.1f} minutes")

        # Load best model
        best_path = Path(self.config.checkpoint_dir) / "best.pt"
        if best_path.exists():
            self.load_checkpoint("best.pt")
            self.logger.info("Loaded best model checkpoint")

        return {
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "total_epochs": self.epoch + 1,
            "total_time_minutes": total_time / 60,
        }

    def _train_epoch(self) -> Dict[str, float]:
        """Run a single training epoch."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        accumulation_steps = 0

        progress_bar = tqdm(
            self.train_loader,
            desc=f"Epoch {self.epoch + 1}",
            leave=False,
        )

        self.optimizer.zero_grad()

        for batch_idx, (images, labels) in enumerate(progress_bar):
            images = images.to(self.device)
            labels = labels.to(self.device)

            # Forward pass with mixed precision
            if self.use_amp:
                with autocast():
                    embeddings = self.model(images)
                    loss = self.criterion(embeddings, labels)
                    loss = loss / self.config.gradient_accumulation_steps

                self.scaler.scale(loss).backward()
            else:
                embeddings = self.model(images)
                loss = self.criterion(embeddings, labels)
                loss = loss / self.config.gradient_accumulation_steps
                loss.backward()

            accumulation_steps += 1

            # Optimizer step after accumulation
            if accumulation_steps >= self.config.gradient_accumulation_steps:
                if self.use_amp:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.optimizer.zero_grad()
                self.scheduler.step()
                self.global_step += 1
                accumulation_steps = 0

            total_loss += loss.item() * self.config.gradient_accumulation_steps
            num_batches += 1

            # Update progress bar
            progress_bar.set_postfix({
                "loss": total_loss / num_batches,
                "lr": self.scheduler.get_last_lr()[0],
            })

            # Logging
            if (batch_idx + 1) % self.config.log_every_n_steps == 0:
                self.logger.debug(
                    f"Step {self.global_step}: loss={total_loss / num_batches:.4f}"
                )

        return {
            "loss": total_loss / max(num_batches, 1),
            "lr": self.scheduler.get_last_lr()[0],
        }

    def _validate(self) -> Dict[str, float]:
        """Run validation and compute metrics."""
        self.model.eval()
        total_loss = 0.0
        all_embeddings = []
        all_labels = []

        with torch.no_grad():
            for images, labels in tqdm(self.val_loader, desc="Validating", leave=False):
                images = images.to(self.device)
                labels = labels.to(self.device)

                if self.use_amp:
                    with autocast():
                        embeddings = self.model(images)
                        loss = self.criterion(embeddings, labels)
                else:
                    embeddings = self.model(images)
                    loss = self.criterion(embeddings, labels)

                total_loss += loss.item()
                all_embeddings.append(embeddings.cpu())
                all_labels.append(labels.cpu())

        # Combine batches
        all_embeddings = torch.cat(all_embeddings, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        # Compute embedding quality metrics
        metrics = self.evaluator.compute_metrics(all_embeddings, all_labels)
        metrics["loss"] = total_loss / max(len(self.val_loader), 1)

        return metrics

    def _save_checkpoint(self, filename: str) -> None:
        """Save training checkpoint."""
        checkpoint_dir = Path(self.config.checkpoint_dir)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        checkpoint = {
            "epoch": self.epoch,
            "global_step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "patience_counter": self.patience_counter,
            "training_history": self.training_history,
            "config": self.config.__dict__,
        }

        if self.scaler is not None:
            checkpoint["scaler_state_dict"] = self.scaler.state_dict()

        torch.save(checkpoint, checkpoint_dir / filename)
        self.logger.debug(f"Saved checkpoint: {filename}")

    def load_checkpoint(self, filename: str) -> None:
        """Load training checkpoint."""
        checkpoint_path = Path(self.config.checkpoint_dir) / filename
        if not checkpoint_path.exists():
            self.logger.warning(f"Checkpoint not found: {checkpoint_path}")
            return

        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        self.epoch = checkpoint["epoch"]
        self.global_step = checkpoint["global_step"]
        self.best_val_loss = checkpoint["best_val_loss"]
        self.best_val_separation = checkpoint.get("best_val_separation", float("-inf"))
        self.patience_counter = checkpoint.get("patience_counter", 0)
        self.training_history = checkpoint.get("training_history", [])

        if self.scaler is not None and "scaler_state_dict" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])

        self.logger.info(f"Resumed from epoch {self.epoch + 1}")

    def export_model(self, output_dir: Optional[str] = None) -> str:
        """
        Export the trained model for inference.

        Args:
            output_dir: Output directory (uses config.output_dir if not specified)

        Returns:
            Path to exported model directory
        """
        output_dir = output_dir or self.config.output_dir
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Save model
        self.model.save_pretrained(output_dir)

        # Save training config
        config_path = output_path / "training_config.json"
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f, indent=2)

        # Save training history
        history_path = output_path / "training_history.json"
        with open(history_path, "w") as f:
            json.dump(self.training_history, f, indent=2)

        self.logger.info(f"Model exported to: {output_path}")
        return str(output_path)

    def get_training_summary(self) -> Dict:
        """Get summary of training."""
        return {
            "epochs_completed": self.epoch + 1,
            "global_steps": self.global_step,
            "best_val_loss": self.best_val_loss,
            "best_val_separation": self.best_val_separation,
            "history": self.training_history,
        }
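
For orientation, a minimal end-to-end sketch of how `Trainer` is driven. The dataset here is a random stand-in purely to illustrate the expected `(images, labels)` batch format, and `TrainingConfig()` is assumed to provide usable defaults; in a real run the dataloaders and config come from the training module and YAML config:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

from training.config import TrainingConfig
from training.model import create_model
from training.trainer import Trainer

# Random stand-in crops/labels: 32 images of shape [3, 224, 224], 4 brand ids.
images = torch.randn(32, 3, 224, 224)
labels = torch.randint(0, 4, (32,))
train_loader = DataLoader(TensorDataset(images, labels), batch_size=8, shuffle=True)
val_loader = DataLoader(TensorDataset(images, labels), batch_size=8)

config = TrainingConfig()  # assumed defaults; normally loaded from a YAML config
model, _processor = create_model(lora_r=16, freeze_layers=12)

trainer = Trainer(model, train_loader, val_loader, config)
summary = trainer.train()            # epochs, validation, early stopping, best.pt
print(summary["best_val_separation"])

export_dir = trainer.export_model()  # weights + training_config.json + history
print(f"Exported to {export_dir}")
```
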