Add image-level split support for CLIP fine-tuning

Image-level splits allow the model to see some images from each logo
brand during training, unlike logo-level splits where test brands are
completely unseen. This is less rigorous but more representative of
real-world use.

Changes:
- Add configs/image_level_splits.yaml with gentler training settings:
  - split_level: "image" for image-level splits
  - temperature: 0.15 (softer contrastive learning)
  - learning_rate: 5e-6 (slower learning)
  - max_epochs: 30 (more epochs)

- Update training/dataset.py:
  - Add split_level parameter to LogoDataset
  - Implement _split_images() for image-level splitting
  - Update LogoContrastiveDataset to use split-specific image mappings

- Update training/config.py:
  - Add split_level field to TrainingConfig

- Update train_clip_logo.py:
  - Pass split_level to create_dataloaders

Usage:
  uv run python train_clip_logo.py --config configs/image_level_splits.yaml
This change is contained in commit 14a1bda3fa (parent 32bfefc022), authored by Rick McEwen on 2026-01-05 15:10:45 -05:00; it touches 4 files with 200 additions and 20 deletions.

View File

@ -0,0 +1,78 @@
# Training configuration with IMAGE-LEVEL splits
#
# Unlike logo-level splits where test logos are completely unseen brands,
# image-level splits allow the model to see some images from each brand
# during training. This is less rigorous but more representative of
# real-world use where you have reference images for logos you want to detect.
#
# Also uses gentler contrastive learning settings to prevent over-separation.
#
# Usage:
# uv run python train_clip_logo.py --config configs/image_level_splits.yaml
# Base model (HuggingFace model identifier)
base_model: "openai/clip-vit-large-patch14"
# Dataset paths (relative to project root)
dataset_dir: "LogoDet-3K"
reference_dir: "reference_logos"
db_path: "test_data_mapping.db"
# Data split configuration
# split_level: "image" means images are split, not logo brands
# This allows test set to contain images from brands seen during training
# The three split ratios below should sum to 1.0.
split_level: "image"
train_split: 0.7
val_split: 0.15
test_split: 0.15
# Batch construction
batch_size: 16
logos_per_batch: 32
samples_per_logo: 4
gradient_accumulation_steps: 8
num_workers: 4
# Model architecture (LoRA adapter settings) - same as before
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
freeze_layers: 12
use_gradient_checkpointing: true
# Training hyperparameters - GENTLER settings
learning_rate: 5.0e-6 # Reduced from 1e-5
weight_decay: 0.01
warmup_steps: 500
max_epochs: 30 # More epochs with slower learning
mixed_precision: true
# Loss function - HIGHER temperature for softer contrastive learning
temperature: 0.15 # Increased from 0.07
loss_type: "infonce"
# NOTE(review): triplet_margin presumably only applies when loss_type is
# "triplet" - confirm it is ignored for "infonce".
triplet_margin: 0.2 # Reduced from 0.3
# Early stopping
patience: 7 # More patience with gentler learning
min_delta: 0.001 # Minimum validation improvement needed to reset patience
# Checkpoints and output
checkpoint_dir: "checkpoints_image_split"
output_dir: "models/logo_detection/clip_finetuned_image_split"
save_every_n_epochs: 5
# Logging
log_every_n_steps: 10
eval_every_n_epochs: 1
# Reproducibility
seed: 42
# Hard negative mining (disabled in this config)
use_hard_negatives: false
hard_negative_start_epoch: 10
hard_negatives_per_logo: 10
# Data augmentation
use_augmentation: true
augmentation_strength: "medium"

View File

@ -256,6 +256,7 @@ def main():
test_split=config.test_split, test_split=config.test_split,
seed=config.seed, seed=config.seed,
augmentation_strength=config.augmentation_strength, augmentation_strength=config.augmentation_strength,
split_level=getattr(config, 'split_level', 'logo'),
) )
# Create trainer # Create trainer

View File

@ -20,7 +20,8 @@ class TrainingConfig:
reference_dir: str = "reference_logos" reference_dir: str = "reference_logos"
db_path: str = "test_data_mapping.db" db_path: str = "test_data_mapping.db"
# Data split ratios # Data split configuration
split_level: str = "logo" # "logo" for brand-level, "image" for image-level
train_split: float = 0.7 train_split: float = 0.7
val_split: float = 0.15 val_split: float = 0.15
test_split: float = 0.15 test_split: float = 0.15

View File

@ -84,7 +84,7 @@ class LogoDataset:
""" """
Manages logo data from the SQLite database. Manages logo data from the SQLite database.
Handles loading logo-to-image mappings and splitting by logo brand. Handles loading logo-to-image mappings and splitting by logo brand or image.
""" """
def __init__( def __init__(
@ -95,19 +95,57 @@ class LogoDataset:
val_split: float = 0.15, val_split: float = 0.15,
test_split: float = 0.15, test_split: float = 0.15,
seed: int = 42, seed: int = 42,
split_level: str = "logo",
): ):
"""
Initialize the logo dataset.
Args:
db_path: Path to SQLite database
reference_dir: Directory containing reference logo images
train_split: Fraction for training
val_split: Fraction for validation
test_split: Fraction for testing
seed: Random seed for reproducibility
split_level: "logo" for brand-level splits (test on unseen brands),
"image" for image-level splits (test on unseen images
from seen brands)
"""
self.db_path = Path(db_path) self.db_path = Path(db_path)
self.reference_dir = Path(reference_dir) self.reference_dir = Path(reference_dir)
self.seed = seed self.seed = seed
self.split_level = split_level
# Load logo-to-images mapping from database # Load logo-to-images mapping from database
self.logo_to_images = self._load_logo_mappings() self.logo_to_images = self._load_logo_mappings()
self.all_logos = list(self.logo_to_images.keys()) self.all_logos = list(self.logo_to_images.keys())
# Create logo-level splits if split_level == "logo":
# Logo-level splits: test logos are completely unseen brands
self.train_logos, self.val_logos, self.test_logos = self._split_logos( self.train_logos, self.val_logos, self.test_logos = self._split_logos(
train_split, val_split, test_split train_split, val_split, test_split
) )
# For logo-level splits, each split has its own logos
self.train_logo_to_images = {
l: self.logo_to_images[l] for l in self.train_logos
}
self.val_logo_to_images = {
l: self.logo_to_images[l] for l in self.val_logos
}
self.test_logo_to_images = {
l: self.logo_to_images[l] for l in self.test_logos
}
else:
# Image-level splits: all logos present in all splits, different images
(
self.train_logo_to_images,
self.val_logo_to_images,
self.test_logo_to_images,
) = self._split_images(train_split, val_split, test_split)
# All logos are in all splits
self.train_logos = list(self.train_logo_to_images.keys())
self.val_logos = list(self.val_logo_to_images.keys())
self.test_logos = list(self.test_logo_to_images.keys())
def _load_logo_mappings(self) -> Dict[str, List[Path]]: def _load_logo_mappings(self) -> Dict[str, List[Path]]:
"""Load logo name to image paths mapping from database.""" """Load logo name to image paths mapping from database."""
@ -151,21 +189,74 @@ class LogoDataset:
return train_logos, val_logos, test_logos return train_logos, val_logos, test_logos
def get_split_info(self) -> Dict[str, int]: def _split_images(
self,
train_split: float,
val_split: float,
test_split: float,
) -> Tuple[Dict[str, List[Path]], Dict[str, List[Path]], Dict[str, List[Path]]]:
"""
Split images within each logo brand for train/val/test.
Each logo brand will have images in all splits, allowing the model
to see some examples of each brand during training.
"""
random.seed(self.seed)
train_logo_to_images: Dict[str, List[Path]] = {}
val_logo_to_images: Dict[str, List[Path]] = {}
test_logo_to_images: Dict[str, List[Path]] = {}
for logo, images in self.logo_to_images.items():
# Shuffle images for this logo
shuffled_images = images.copy()
random.shuffle(shuffled_images)
n = len(shuffled_images)
if n == 1:
# Only one image: put in train only
train_logo_to_images[logo] = shuffled_images
continue
elif n == 2:
# Two images: one train, one val
train_logo_to_images[logo] = [shuffled_images[0]]
val_logo_to_images[logo] = [shuffled_images[1]]
continue
# Normal split for 3+ images
train_end = max(1, int(n * train_split))
val_end = train_end + max(1, int(n * val_split))
train_images = shuffled_images[:train_end]
val_images = shuffled_images[train_end:val_end]
test_images = shuffled_images[val_end:]
# Ensure at least one image in train
if train_images:
train_logo_to_images[logo] = train_images
if val_images:
val_logo_to_images[logo] = val_images
if test_images:
test_logo_to_images[logo] = test_images
return train_logo_to_images, val_logo_to_images, test_logo_to_images
def get_split_info(self) -> Dict[str, any]:
"""Return information about the splits.""" """Return information about the splits."""
return { return {
"split_level": self.split_level,
"total_logos": len(self.all_logos), "total_logos": len(self.all_logos),
"train_logos": len(self.train_logos), "train_logos": len(self.train_logos),
"val_logos": len(self.val_logos), "val_logos": len(self.val_logos),
"test_logos": len(self.test_logos), "test_logos": len(self.test_logos),
"train_images": sum( "train_images": sum(
len(self.logo_to_images[l]) for l in self.train_logos len(imgs) for imgs in self.train_logo_to_images.values()
), ),
"val_images": sum( "val_images": sum(
len(self.logo_to_images[l]) for l in self.val_logos len(imgs) for imgs in self.val_logo_to_images.values()
), ),
"test_images": sum( "test_images": sum(
len(self.logo_to_images[l]) for l in self.test_logos len(imgs) for imgs in self.test_logo_to_images.values()
), ),
} }
@ -205,29 +296,33 @@ class LogoContrastiveDataset(Dataset):
self.transform = transform self.transform = transform
self.batches_per_epoch = batches_per_epoch self.batches_per_epoch = batches_per_epoch
# Get logos for this split # Get logos and their images for this split
# This respects both logo-level and image-level splits
if split == "train": if split == "train":
self.logos = logo_data.train_logos self.logos = logo_data.train_logos
self.logo_to_images = logo_data.train_logo_to_images
elif split == "val": elif split == "val":
self.logos = logo_data.val_logos self.logos = logo_data.val_logos
self.logo_to_images = logo_data.val_logo_to_images
else: else:
self.logos = logo_data.test_logos self.logos = logo_data.test_logos
self.logo_to_images = logo_data.test_logo_to_images
# Filter logos with enough samples # Filter logos with enough samples for this split
self.valid_logos = [ self.valid_logos = [
logo for logo in self.logos logo for logo in self.logos
if len(logo_data.logo_to_images[logo]) >= samples_per_logo if logo in self.logo_to_images and len(self.logo_to_images[logo]) >= samples_per_logo
] ]
# For logos with fewer samples, we'll use with replacement # For logos with fewer samples, we'll use with replacement
self.logos_needing_replacement = [ self.logos_needing_replacement = [
logo for logo in self.logos logo for logo in self.logos
if len(logo_data.logo_to_images[logo]) < samples_per_logo if logo in self.logo_to_images and len(self.logo_to_images[logo]) < samples_per_logo
] ]
# Create label mapping # Create label mapping (use all logos from the full dataset for consistent labels)
self.logo_to_label = { self.logo_to_label = {
logo: idx for idx, logo in enumerate(self.logos) logo: idx for idx, logo in enumerate(logo_data.all_logos)
} }
def __len__(self) -> int: def __len__(self) -> int:
@ -244,12 +339,13 @@ class LogoContrastiveDataset(Dataset):
images = [] images = []
labels = [] labels = []
# Sample K logos for this batch # Sample K logos for this batch (only from logos that have images in this split)
k = min(self.logos_per_batch, len(self.logos)) available_logos = [l for l in self.logos if l in self.logo_to_images]
batch_logos = random.sample(self.logos, k) k = min(self.logos_per_batch, len(available_logos))
batch_logos = random.sample(available_logos, k)
for logo in batch_logos: for logo in batch_logos:
logo_images = self.logo_data.logo_to_images[logo] logo_images = self.logo_to_images[logo]
# Sample M images for this logo # Sample M images for this logo
if len(logo_images) >= self.samples_per_logo: if len(logo_images) >= self.samples_per_logo:
@ -353,6 +449,7 @@ def create_dataloaders(
seed: int = 42, seed: int = 42,
augmentation_strength: str = "medium", augmentation_strength: str = "medium",
batches_per_epoch: int = 1000, batches_per_epoch: int = 1000,
split_level: str = "logo",
) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]: ) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
""" """
Create train, validation, and optionally test dataloaders. Create train, validation, and optionally test dataloaders.
@ -370,6 +467,7 @@ def create_dataloaders(
seed: Random seed seed: Random seed
augmentation_strength: "light", "medium", or "strong" augmentation_strength: "light", "medium", or "strong"
batches_per_epoch: Number of batches per training epoch batches_per_epoch: Number of batches per training epoch
split_level: "logo" for brand-level splits, "image" for image-level splits
Returns: Returns:
Tuple of (train_loader, val_loader, test_loader) Tuple of (train_loader, val_loader, test_loader)
@ -382,11 +480,13 @@ def create_dataloaders(
val_split=val_split, val_split=val_split,
test_split=test_split, test_split=test_split,
seed=seed, seed=seed,
split_level=split_level,
) )
# Print split info # Print split info
split_info = logo_data.get_split_info() split_info = logo_data.get_split_info()
print(f"Dataset loaded:") print(f"Dataset loaded:")
print(f" Split level: {split_info['split_level']}")
print(f" Total logos: {split_info['total_logos']}") print(f" Total logos: {split_info['total_logos']}")
print(f" Train: {split_info['train_logos']} logos, {split_info['train_images']} images") print(f" Train: {split_info['train_logos']} logos, {split_info['train_images']} images")
print(f" Val: {split_info['val_logos']} logos, {split_info['val_images']} images") print(f" Val: {split_info['val_logos']} logos, {split_info['val_images']} images")