Add image-level split support for CLIP fine-tuning

Image-level splits allow the model to see some images from each logo
brand during training, unlike logo-level splits where test brands are
completely unseen. This is a less rigorous test of generalization to new
brands, but it is more representative of real-world use, where known
brands reappear at inference time.

Changes:
- Add configs/image_level_splits.yaml with gentler training settings:
  - split_level: "image" for image-level splits
  - temperature: 0.15 (softer contrastive learning)
  - learning_rate: 5e-6 (slower learning)
  - max_epochs: 30 (more epochs)

- Update training/dataset.py:
  - Add split_level parameter to LogoDataset
  - Implement _split_images() for image-level splitting
  - Update LogoContrastiveDataset to use split-specific image mappings

- Update training/config.py:
  - Add split_level field to TrainingConfig

- Update train_clip_logo.py:
  - Pass split_level to create_dataloaders

Usage:
  uv run python train_clip_logo.py --config configs/image_level_splits.yaml
This commit is contained in:
Rick McEwen
2026-01-05 15:10:45 -05:00
parent 32bfefc022
commit 14a1bda3fa
4 changed files with 200 additions and 20 deletions

View File

@ -20,7 +20,8 @@ class TrainingConfig:
reference_dir: str = "reference_logos"
db_path: str = "test_data_mapping.db"
# Data split ratios
# Data split configuration
split_level: str = "logo" # "logo" for brand-level, "image" for image-level
train_split: float = 0.7
val_split: float = 0.15
test_split: float = 0.15

View File

@ -84,7 +84,7 @@ class LogoDataset:
"""
Manages logo data from the SQLite database.
Handles loading logo-to-image mappings and splitting by logo brand.
Handles loading logo-to-image mappings and splitting by logo brand or image.
"""
def __init__(
@ -95,19 +95,57 @@ class LogoDataset:
val_split: float = 0.15,
test_split: float = 0.15,
seed: int = 42,
split_level: str = "logo",
):
"""
Initialize the logo dataset.
Args:
db_path: Path to SQLite database
reference_dir: Directory containing reference logo images
train_split: Fraction for training
val_split: Fraction for validation
test_split: Fraction for testing
seed: Random seed for reproducibility
split_level: "logo" for brand-level splits (test on unseen brands),
"image" for image-level splits (test on unseen images
from seen brands)
"""
self.db_path = Path(db_path)
self.reference_dir = Path(reference_dir)
self.seed = seed
self.split_level = split_level
# Load logo-to-images mapping from database
self.logo_to_images = self._load_logo_mappings()
self.all_logos = list(self.logo_to_images.keys())
# Create logo-level splits
self.train_logos, self.val_logos, self.test_logos = self._split_logos(
train_split, val_split, test_split
)
if split_level == "logo":
# Logo-level splits: test logos are completely unseen brands
self.train_logos, self.val_logos, self.test_logos = self._split_logos(
train_split, val_split, test_split
)
# For logo-level splits, each split has its own logos
self.train_logo_to_images = {
l: self.logo_to_images[l] for l in self.train_logos
}
self.val_logo_to_images = {
l: self.logo_to_images[l] for l in self.val_logos
}
self.test_logo_to_images = {
l: self.logo_to_images[l] for l in self.test_logos
}
else:
# Image-level splits: all logos present in all splits, different images
(
self.train_logo_to_images,
self.val_logo_to_images,
self.test_logo_to_images,
) = self._split_images(train_split, val_split, test_split)
# All logos are in all splits
self.train_logos = list(self.train_logo_to_images.keys())
self.val_logos = list(self.val_logo_to_images.keys())
self.test_logos = list(self.test_logo_to_images.keys())
def _load_logo_mappings(self) -> Dict[str, List[Path]]:
"""Load logo name to image paths mapping from database."""
@ -151,21 +189,74 @@ class LogoDataset:
return train_logos, val_logos, test_logos
def get_split_info(self) -> Dict[str, int]:
def _split_images(
self,
train_split: float,
val_split: float,
test_split: float,
) -> Tuple[Dict[str, List[Path]], Dict[str, List[Path]], Dict[str, List[Path]]]:
"""
Split images within each logo brand for train/val/test.
Each logo brand will have images in all splits, allowing the model
to see some examples of each brand during training.
"""
random.seed(self.seed)
train_logo_to_images: Dict[str, List[Path]] = {}
val_logo_to_images: Dict[str, List[Path]] = {}
test_logo_to_images: Dict[str, List[Path]] = {}
for logo, images in self.logo_to_images.items():
# Shuffle images for this logo
shuffled_images = images.copy()
random.shuffle(shuffled_images)
n = len(shuffled_images)
if n == 1:
# Only one image: put in train only
train_logo_to_images[logo] = shuffled_images
continue
elif n == 2:
# Two images: one train, one val
train_logo_to_images[logo] = [shuffled_images[0]]
val_logo_to_images[logo] = [shuffled_images[1]]
continue
# Normal split for 3+ images
train_end = max(1, int(n * train_split))
val_end = train_end + max(1, int(n * val_split))
train_images = shuffled_images[:train_end]
val_images = shuffled_images[train_end:val_end]
test_images = shuffled_images[val_end:]
# Ensure at least one image in train
if train_images:
train_logo_to_images[logo] = train_images
if val_images:
val_logo_to_images[logo] = val_images
if test_images:
test_logo_to_images[logo] = test_images
return train_logo_to_images, val_logo_to_images, test_logo_to_images
def get_split_info(self) -> Dict[str, object]:
    """
    Return summary statistics about the current data splits.

    Image counts are computed from the per-split mappings
    (train/val/test_logo_to_images), so the numbers are correct for
    both logo-level and image-level splits.

    Returns:
        Dict with the split level, total logo count, and per-split
        logo and image counts.
    """
    # NOTE: annotated Dict[str, object] rather than Dict[str, any] —
    # lowercase `any` is the builtin function, not a type.
    return {
        "split_level": self.split_level,
        "total_logos": len(self.all_logos),
        "train_logos": len(self.train_logos),
        "val_logos": len(self.val_logos),
        "test_logos": len(self.test_logos),
        "train_images": sum(
            len(imgs) for imgs in self.train_logo_to_images.values()
        ),
        "val_images": sum(
            len(imgs) for imgs in self.val_logo_to_images.values()
        ),
        "test_images": sum(
            len(imgs) for imgs in self.test_logo_to_images.values()
        ),
    }
@ -205,29 +296,33 @@ class LogoContrastiveDataset(Dataset):
self.transform = transform
self.batches_per_epoch = batches_per_epoch
# Get logos for this split
# Get logos and their images for this split
# This respects both logo-level and image-level splits
if split == "train":
self.logos = logo_data.train_logos
self.logo_to_images = logo_data.train_logo_to_images
elif split == "val":
self.logos = logo_data.val_logos
self.logo_to_images = logo_data.val_logo_to_images
else:
self.logos = logo_data.test_logos
self.logo_to_images = logo_data.test_logo_to_images
# Filter logos with enough samples
# Filter logos with enough samples for this split
self.valid_logos = [
logo for logo in self.logos
if len(logo_data.logo_to_images[logo]) >= samples_per_logo
if logo in self.logo_to_images and len(self.logo_to_images[logo]) >= samples_per_logo
]
# For logos with fewer samples, we'll use with replacement
self.logos_needing_replacement = [
logo for logo in self.logos
if len(logo_data.logo_to_images[logo]) < samples_per_logo
if logo in self.logo_to_images and len(self.logo_to_images[logo]) < samples_per_logo
]
# Create label mapping
# Create label mapping (use all logos from the full dataset for consistent labels)
self.logo_to_label = {
logo: idx for idx, logo in enumerate(self.logos)
logo: idx for idx, logo in enumerate(logo_data.all_logos)
}
def __len__(self) -> int:
@ -244,12 +339,13 @@ class LogoContrastiveDataset(Dataset):
images = []
labels = []
# Sample K logos for this batch
k = min(self.logos_per_batch, len(self.logos))
batch_logos = random.sample(self.logos, k)
# Sample K logos for this batch (only from logos that have images in this split)
available_logos = [l for l in self.logos if l in self.logo_to_images]
k = min(self.logos_per_batch, len(available_logos))
batch_logos = random.sample(available_logos, k)
for logo in batch_logos:
logo_images = self.logo_data.logo_to_images[logo]
logo_images = self.logo_to_images[logo]
# Sample M images for this logo
if len(logo_images) >= self.samples_per_logo:
@ -353,6 +449,7 @@ def create_dataloaders(
seed: int = 42,
augmentation_strength: str = "medium",
batches_per_epoch: int = 1000,
split_level: str = "logo",
) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
"""
Create train, validation, and optionally test dataloaders.
@ -370,6 +467,7 @@ def create_dataloaders(
seed: Random seed
augmentation_strength: "light", "medium", or "strong"
batches_per_epoch: Number of batches per training epoch
split_level: "logo" for brand-level splits, "image" for image-level splits
Returns:
Tuple of (train_loader, val_loader, test_loader)
@ -382,11 +480,13 @@ def create_dataloaders(
val_split=val_split,
test_split=test_split,
seed=seed,
split_level=split_level,
)
# Print split info
split_info = logo_data.get_split_info()
print(f"Dataset loaded:")
print(f" Split level: {split_info['split_level']}")
print(f" Total logos: {split_info['total_logos']}")
print(f" Train: {split_info['train_logos']} logos, {split_info['train_images']} images")
print(f" Val: {split_info['val_logos']} logos, {split_info['val_images']} images")