Image-level splits allow the model to see some images from each logo brand during training, unlike logo-level splits, where test brands are completely unseen. This is less rigorous but more representative of real-world use.

Changes:

- Add configs/image_level_splits.yaml with gentler training settings (sketched below):
  - split_level: "image" for image-level splits
  - temperature: 0.15 (softer contrastive learning)
  - learning_rate: 5e-6 (slower learning)
  - max_epochs: 30 (more epochs)
- Update training/dataset.py (see the split sketch below):
  - Add split_level parameter to LogoDataset
  - Implement _split_images() for image-level splitting
  - Update LogoContrastiveDataset to use split-specific image mappings
- Update training/config.py (full file below):
  - Add split_level field to TrainingConfig
- Update train_clip_logo.py:
  - Pass split_level to create_dataloaders

Usage:

    uv run python train_clip_logo.py --config configs/image_level_splits.yaml
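For reference, the new config file probably amounts to just the four overrides from the list above, with every other field falling back to the TrainingConfig defaults shown further down. This is a sketch; the real file may set additional fields:

    # configs/image_level_splits.yaml (sketch)
    split_level: "image"    # split within each brand instead of holding out whole brands
    temperature: 0.15       # flatter InfoNCE softmax than the 0.07 default
    learning_rate: 5.0e-6   # note the decimal point: PyYAML parses a bare 5e-6 as a string
    max_epochs: 30          # more epochs to compensate for the lower learning rate

A higher temperature flattens the similarity distribution inside the contrastive softmax, so the loss pushes less aggressively on hard negatives, which is what "softer contrastive learning" means here.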
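The interesting part of the change is the new split logic in training/dataset.py, which is not shown in this diff. The following is a minimal standalone sketch of the idea, assuming a plain dict of brand name to image paths; the helper name split_images_per_brand and its signature are illustrative, not the actual _split_images() implementation:

    import random
    from typing import Dict, List, Tuple


    def split_images_per_brand(
        images_by_brand: Dict[str, List[str]],
        train: float = 0.7,
        val: float = 0.15,
        seed: int = 42,
    ) -> Tuple[List[str], List[str], List[str]]:
        """Image-level split: every brand contributes images to all three sets.

        Contrast with a logo-level split, which assigns whole brands to a
        single set so that test brands are never seen during training.
        """
        rng = random.Random(seed)
        train_set, val_set, test_set = [], [], []
        for brand in sorted(images_by_brand):  # sort for determinism across runs
            images = list(images_by_brand[brand])
            rng.shuffle(images)
            n_train = int(len(images) * train)
            n_val = int(len(images) * val)
            train_set += images[:n_train]
            val_set += images[n_train:n_train + n_val]
            test_set += images[n_train + n_val:]  # remainder goes to test
        return train_set, val_set, test_set

The full training/config.py after the change: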
"""
|
|
Training configuration for CLIP fine-tuning.
|
|
"""
|
|
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import List, Optional
|
|
import yaml
|
|
|
|
|
|
@dataclass
|
|
class TrainingConfig:
|
|
"""Configuration for CLIP logo fine-tuning."""
|
|
|
|
# Base model
|
|
base_model: str = "openai/clip-vit-large-patch14"
|
|
|
|
# Dataset paths
|
|
dataset_dir: str = "LogoDet-3K"
|
|
reference_dir: str = "reference_logos"
|
|
db_path: str = "test_data_mapping.db"
|
|
|
|
# Data split configuration
|
|
split_level: str = "logo" # "logo" for brand-level, "image" for image-level
|
|
train_split: float = 0.7
|
|
val_split: float = 0.15
|
|
test_split: float = 0.15
|
|
|
|
# Batch construction
|
|
batch_size: int = 16
|
|
logos_per_batch: int = 32
|
|
samples_per_logo: int = 4
|
|
gradient_accumulation_steps: int = 8
|
|
num_workers: int = 4
|
|
|
|
# Model architecture
|
|
lora_r: int = 16
|
|
lora_alpha: int = 32
|
|
lora_dropout: float = 0.1
|
|
freeze_layers: int = 12
|
|
use_gradient_checkpointing: bool = True
|
|
|
|
# Training hyperparameters
|
|
learning_rate: float = 1e-5
|
|
weight_decay: float = 0.01
|
|
warmup_steps: int = 500
|
|
max_epochs: int = 20
|
|
mixed_precision: bool = True
|
|
|
|
# Loss function
|
|
temperature: float = 0.07
|
|
loss_type: str = "infonce" # "infonce" or "triplet"
|
|
triplet_margin: float = 0.3
|
|
|
|
# Early stopping
|
|
patience: int = 5
|
|
min_delta: float = 0.001
|
|
|
|
# Checkpoints and output
|
|
checkpoint_dir: str = "checkpoints"
|
|
output_dir: str = "models/logo_detection/clip_finetuned"
|
|
save_every_n_epochs: int = 5
|
|
|
|
# Logging
|
|
log_every_n_steps: int = 10
|
|
eval_every_n_epochs: int = 1
|
|
|
|
# Random seed for reproducibility
|
|
seed: int = 42
|
|
|
|
# Hard negative mining
|
|
use_hard_negatives: bool = False
|
|
hard_negative_start_epoch: int = 5
|
|
hard_negatives_per_logo: int = 10
|
|
|
|
# Data augmentation
|
|
use_augmentation: bool = True
|
|
augmentation_strength: str = "medium" # "light", "medium", "strong"
|
|
|
|
@classmethod
|
|
def from_yaml(cls, yaml_path: str) -> "TrainingConfig":
|
|
"""Load configuration from YAML file."""
|
|
with open(yaml_path, "r") as f:
|
|
config_dict = yaml.safe_load(f)
|
|
return cls(**config_dict)
|
|
|
|
def to_yaml(self, yaml_path: str) -> None:
|
|
"""Save configuration to YAML file."""
|
|
Path(yaml_path).parent.mkdir(parents=True, exist_ok=True)
|
|
with open(yaml_path, "w") as f:
|
|
yaml.dump(self.__dict__, f, default_flow_style=False, sort_keys=False)
|
|
|
|
def validate(self) -> List[str]:
|
|
"""Validate configuration and return list of warnings."""
|
|
warnings = []
|
|
|
|
# Check split ratios
|
|
total_split = self.train_split + self.val_split + self.test_split
|
|
if abs(total_split - 1.0) > 0.01:
|
|
warnings.append(
|
|
f"Split ratios sum to {total_split}, expected 1.0"
|
|
)
|
|
|
|
# Check batch construction
|
|
effective_batch = self.batch_size * self.gradient_accumulation_steps
|
|
if effective_batch < 64:
|
|
warnings.append(
|
|
f"Effective batch size ({effective_batch}) is small for contrastive learning. "
|
|
"Consider increasing batch_size or gradient_accumulation_steps."
|
|
)
|
|
|
|
# Check LoRA config
|
|
if self.lora_r > 0 and self.lora_alpha < self.lora_r:
|
|
warnings.append(
|
|
f"lora_alpha ({self.lora_alpha}) < lora_r ({self.lora_r}). "
|
|
"This may reduce LoRA effectiveness."
|
|
)
|
|
|
|
# Check freeze layers
|
|
if self.freeze_layers < 0:
|
|
warnings.append("freeze_layers should be >= 0")
|
|
|
|
# Check temperature
|
|
if self.temperature <= 0:
|
|
warnings.append("temperature must be positive")
|
|
elif self.temperature > 1.0:
|
|
warnings.append(
|
|
f"temperature ({self.temperature}) is high. "
|
|
"Typical values are 0.05-0.1."
|
|
)
|
|
|
|
return warnings
|
|
|
|
@property
|
|
def effective_batch_size(self) -> int:
|
|
"""Calculate effective batch size with gradient accumulation."""
|
|
return self.batch_size * self.gradient_accumulation_steps
|
|
|
|
@property
|
|
def samples_per_batch(self) -> int:
|
|
"""Total samples in one batch (logos_per_batch * samples_per_logo)."""
|
|
return self.logos_per_batch * self.samples_per_logo
|