logo_test/training/config.py
Rick McEwen 14a1bda3fa Add image-level split support for CLIP fine-tuning
Image-level splits allow the model to see some images from each logo
brand during training, unlike logo-level splits where test brands are
completely unseen. This is a less rigorous test of generalization to
unseen brands, but it is more representative of real-world use.
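
For illustration, a minimal sketch of the distinction (a hypothetical helper,
not the repo's _split_images() implementation): with an image-level split every
brand contributes images to train, val, and test, whereas a logo-level split
assigns whole brands to a single split.

import random

def split_image_level(images_by_brand, train=0.7, val=0.15, seed=42):
    """Split each brand's images across train/val/test (illustrative only)."""
    rng = random.Random(seed)
    splits = {"train": {}, "val": {}, "test": {}}
    for brand, images in images_by_brand.items():
        imgs = list(images)
        rng.shuffle(imgs)
        n_train = int(len(imgs) * train)
        n_val = int(len(imgs) * val)
        # Every brand appears in every split; each image lands in exactly one.
        splits["train"][brand] = imgs[:n_train]
        splits["val"][brand] = imgs[n_train:n_train + n_val]
        splits["test"][brand] = imgs[n_train + n_val:]
    return splits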

Changes:
- Add configs/image_level_splits.yaml with gentler training settings (see the sketch after the usage example below):
  - split_level: "image" for image-level splits
  - temperature: 0.15 (softer contrastive learning)
  - learning_rate: 5e-6 (slower learning)
  - max_epochs: 30 (more epochs)

- Update training/dataset.py:
  - Add split_level parameter to LogoDataset
  - Implement _split_images() for image-level splitting
  - Update LogoContrastiveDataset to use split-specific image mappings

- Update training/config.py:
  - Add split_level field to TrainingConfig

- Update train_clip_logo.py:
  - Pass split_level to create_dataloaders

Usage:
  uv run python train_clip_logo.py --config configs/image_level_splits.yaml
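
For reference, a sketch of producing the new config programmatically from the
TrainingConfig defaults in training/config.py below (the import path is an
assumption; the committed YAML may simply list these four overrides):

from training.config import TrainingConfig  # import path assumed

cfg = TrainingConfig(
    split_level="image",   # image-level splits
    temperature=0.15,      # softer contrastive learning
    learning_rate=5e-6,    # slower learning
    max_epochs=30,         # more epochs
)
assert cfg.validate() == []  # these values trigger none of validate()'s warnings
cfg.to_yaml("configs/image_level_splits.yaml")
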
2026-01-05 15:10:45 -05:00

143 lines
4.3 KiB
Python

"""
Training configuration for CLIP fine-tuning.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional
import yaml


@dataclass
class TrainingConfig:
    """Configuration for CLIP logo fine-tuning."""

    # Base model
    base_model: str = "openai/clip-vit-large-patch14"

    # Dataset paths
    dataset_dir: str = "LogoDet-3K"
    reference_dir: str = "reference_logos"
    db_path: str = "test_data_mapping.db"

    # Data split configuration
    split_level: str = "logo"  # "logo" for brand-level, "image" for image-level
    train_split: float = 0.7
    val_split: float = 0.15
    test_split: float = 0.15

    # Batch construction
    batch_size: int = 16
    logos_per_batch: int = 32
    samples_per_logo: int = 4
    gradient_accumulation_steps: int = 8
    num_workers: int = 4

    # Model architecture
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.1
    freeze_layers: int = 12
    use_gradient_checkpointing: bool = True

    # Training hyperparameters
    learning_rate: float = 1e-5
    weight_decay: float = 0.01
    warmup_steps: int = 500
    max_epochs: int = 20
    mixed_precision: bool = True

    # Loss function
    temperature: float = 0.07
    loss_type: str = "infonce"  # "infonce" or "triplet"
    triplet_margin: float = 0.3

    # Early stopping
    patience: int = 5
    min_delta: float = 0.001

    # Checkpoints and output
    checkpoint_dir: str = "checkpoints"
    output_dir: str = "models/logo_detection/clip_finetuned"
    save_every_n_epochs: int = 5

    # Logging
    log_every_n_steps: int = 10
    eval_every_n_epochs: int = 1

    # Random seed for reproducibility
    seed: int = 42

    # Hard negative mining
    use_hard_negatives: bool = False
    hard_negative_start_epoch: int = 5
    hard_negatives_per_logo: int = 10

    # Data augmentation
    use_augmentation: bool = True
    augmentation_strength: str = "medium"  # "light", "medium", "strong"
    @classmethod
    def from_yaml(cls, yaml_path: str) -> "TrainingConfig":
        """Load configuration from YAML file."""
        with open(yaml_path, "r") as f:
            config_dict = yaml.safe_load(f)
        return cls(**config_dict)

    def to_yaml(self, yaml_path: str) -> None:
        """Save configuration to YAML file."""
        Path(yaml_path).parent.mkdir(parents=True, exist_ok=True)
        with open(yaml_path, "w") as f:
            yaml.dump(self.__dict__, f, default_flow_style=False, sort_keys=False)

    def validate(self) -> List[str]:
        """Validate configuration and return list of warnings."""
        warnings = []

        # Check split ratios
        total_split = self.train_split + self.val_split + self.test_split
        if abs(total_split - 1.0) > 0.01:
            warnings.append(
                f"Split ratios sum to {total_split}, expected 1.0"
            )

        # Check batch construction
        effective_batch = self.batch_size * self.gradient_accumulation_steps
        if effective_batch < 64:
            warnings.append(
                f"Effective batch size ({effective_batch}) is small for contrastive learning. "
                "Consider increasing batch_size or gradient_accumulation_steps."
            )

        # Check LoRA config
        if self.lora_r > 0 and self.lora_alpha < self.lora_r:
            warnings.append(
                f"lora_alpha ({self.lora_alpha}) < lora_r ({self.lora_r}). "
                "This may reduce LoRA effectiveness."
            )

        # Check freeze layers
        if self.freeze_layers < 0:
            warnings.append("freeze_layers should be >= 0")

        # Check temperature
        if self.temperature <= 0:
            warnings.append("temperature must be positive")
        elif self.temperature > 1.0:
            warnings.append(
                f"temperature ({self.temperature}) is high. "
                "Typical values are 0.05-0.1."
            )

        return warnings

    @property
    def effective_batch_size(self) -> int:
        """Calculate effective batch size with gradient accumulation."""
        return self.batch_size * self.gradient_accumulation_steps

    @property
    def samples_per_batch(self) -> int:
        """Total samples in one batch (logos_per_batch * samples_per_logo)."""
        return self.logos_per_batch * self.samples_per_logo