Files
logo_test/training/dataset.py
Rick McEwen 14a1bda3fa Add image-level split support for CLIP fine-tuning
Image-level splits allow the model to see some images from each logo
brand during training, unlike logo-level splits where test brands are
completely unseen. This is less rigorous but more representative of
real-world use.

Changes:
- Add configs/image_level_splits.yaml with gentler training settings:
  - split_level: "image" for image-level splits
  - temperature: 0.15 (softer contrastive learning)
  - learning_rate: 5e-6 (slower learning)
  - max_epochs: 30 (more epochs)

- Update training/dataset.py:
  - Add split_level parameter to LogoDataset
  - Implement _split_images() for image-level splitting
  - Update LogoContrastiveDataset to use split-specific image mappings

- Update training/config.py:
  - Add split_level field to TrainingConfig

- Update train_clip_logo.py:
  - Pass split_level to create_dataloaders

Usage:
  uv run python train_clip_logo.py --config configs/image_level_splits.yaml
2026-01-05 15:10:45 -05:00

568 lines
19 KiB
Python

"""
Dataset classes for contrastive learning on logo images.
"""
import random
import sqlite3
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
# CLIP normalization values: per-channel (R, G, B) mean and standard deviation
# passed to transforms.Normalize after ToTensor in every pipeline in this module.
# NOTE(review): presumably the statistics used by the original CLIP preprocessing;
# keep them in sync with the pretrained model being fine-tuned.
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
def get_train_transforms(strength: str = "medium") -> transforms.Compose:
    """
    Get training data augmentation transforms.

    Every pipeline produces 224x224 images and ends with ToTensor followed by
    CLIP normalization; only the augmentation prefix differs by strength.

    Args:
        strength: Augmentation strength - "light", "medium", or "strong".

    Returns:
        Composed transforms for training.

    Raises:
        ValueError: If ``strength`` is not one of the supported levels. An
            unknown value must not silently select the most aggressive
            pipeline (e.g. a typo in a config file).
    """
    if strength == "light":
        augmentations = [
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
        ]
    elif strength == "medium":
        augmentations = [
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05
            ),
            transforms.RandomAffine(
                degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)
            ),
            transforms.RandomGrayscale(p=0.1),
        ]
    elif strength == "strong":
        augmentations = [
            # Resize larger, then randomly crop back to 224 for position jitter.
            transforms.Resize((256, 256)),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.1),
            transforms.RandomRotation(degrees=30),
            transforms.ColorJitter(
                brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1
            ),
            transforms.RandomAffine(
                degrees=0, translate=(0.15, 0.15), scale=(0.8, 1.2), shear=10
            ),
            transforms.RandomGrayscale(p=0.2),
            # Blur operates on the PIL image, so it must precede ToTensor.
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        ]
    else:
        raise ValueError(
            f"Unknown augmentation strength {strength!r}; "
            "expected 'light', 'medium', or 'strong'"
        )

    return transforms.Compose([
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
    ])
def get_val_transforms() -> transforms.Compose:
    """
    Build the deterministic evaluation pipeline.

    No random augmentation: images are resized directly to 224x224,
    converted to tensors, and normalized with the CLIP channel statistics.
    """
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
    ]
    return transforms.Compose(steps)
class LogoDataset:
    """
    Manages logo data from the SQLite database.

    Loads the logo-name -> reference-image mapping once and splits it into
    train/val/test either by logo brand (``split_level="logo"``) or by image
    within each brand (``split_level="image"``). Splits are reproducible for a
    given ``seed`` and use a private RNG, so constructing a LogoDataset does
    not reseed or advance the global ``random`` module.
    """

    def __init__(
        self,
        db_path: str,
        reference_dir: str,
        train_split: float = 0.7,
        val_split: float = 0.15,
        test_split: float = 0.15,
        seed: int = 42,
        split_level: str = "logo",
    ):
        """
        Initialize the logo dataset.

        Args:
            db_path: Path to SQLite database.
            reference_dir: Directory containing reference logo images.
            train_split: Fraction for training.
            val_split: Fraction for validation.
            test_split: Fraction for testing. The test split actually receives
                whatever remains after the train and val slices.
            seed: Random seed for reproducible splits.
            split_level: "logo" for brand-level splits (test on unseen brands),
                "image" for image-level splits (test on unseen images
                from seen brands).

        Raises:
            ValueError: If ``split_level`` is neither "logo" nor "image".
        """
        if split_level not in ("logo", "image"):
            raise ValueError(
                f"split_level must be 'logo' or 'image', got {split_level!r}"
            )
        self.db_path = Path(db_path)
        self.reference_dir = Path(reference_dir)
        self.seed = seed
        self.split_level = split_level

        # Load logo-to-images mapping from database
        self.logo_to_images = self._load_logo_mappings()
        self.all_logos = list(self.logo_to_images.keys())

        if split_level == "logo":
            # Logo-level splits: test logos are completely unseen brands
            self.train_logos, self.val_logos, self.test_logos = self._split_logos(
                train_split, val_split, test_split
            )
            # For logo-level splits, each split has its own logos
            self.train_logo_to_images = {
                l: self.logo_to_images[l] for l in self.train_logos
            }
            self.val_logo_to_images = {
                l: self.logo_to_images[l] for l in self.val_logos
            }
            self.test_logo_to_images = {
                l: self.logo_to_images[l] for l in self.test_logos
            }
        else:
            # Image-level splits: a logo appears in every split for which it
            # has enough images; each split gets different images.
            (
                self.train_logo_to_images,
                self.val_logo_to_images,
                self.test_logo_to_images,
            ) = self._split_images(train_split, val_split, test_split)
            self.train_logos = list(self.train_logo_to_images.keys())
            self.val_logos = list(self.val_logo_to_images.keys())
            self.test_logos = list(self.test_logo_to_images.keys())

    def _load_logo_mappings(self) -> Dict[str, List[Path]]:
        """
        Load the logo name -> image paths mapping from the database.

        Logos appear in ``ln.name`` order (the query's ORDER BY), so
        ``all_logos`` has a deterministic order across runs.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            rows = conn.execute(
                """
                SELECT ln.name, rl.filename
                FROM reference_logos rl
                JOIN logo_names ln ON rl.logo_name_id = ln.id
                ORDER BY ln.name
                """
            ).fetchall()
        finally:
            # Close the connection even if the query fails (e.g. missing table).
            conn.close()

        logo_to_images: Dict[str, List[Path]] = {}
        for logo_name, filename in rows:
            logo_to_images.setdefault(logo_name, []).append(
                self.reference_dir / filename
            )
        return logo_to_images

    def _split_logos(
        self,
        train_split: float,
        val_split: float,
        test_split: float,
    ) -> Tuple[List[str], List[str], List[str]]:
        """
        Split logos at brand level for train/val/test.

        ``test_split`` is not used directly: the test split receives every
        logo left over after the train and val slices.
        """
        # Private RNG: same shuffle as seeding the global module, without
        # clobbering global random state for the rest of the program.
        rng = random.Random(self.seed)
        logos = self.all_logos.copy()
        rng.shuffle(logos)

        n = len(logos)
        train_end = int(n * train_split)
        val_end = train_end + int(n * val_split)

        train_logos = logos[:train_end]
        val_logos = logos[train_end:val_end]
        test_logos = logos[val_end:]
        return train_logos, val_logos, test_logos

    def _split_images(
        self,
        train_split: float,
        val_split: float,
        test_split: float,
    ) -> Tuple[Dict[str, List[Path]], Dict[str, List[Path]], Dict[str, List[Path]]]:
        """
        Split images within each logo brand for train/val/test.

        Every logo keeps at least one image in train, so the model sees some
        examples of each brand during training. Logos with one image go to
        train only; logos with two images get one train and one val image.
        For 3+ images, train and val each get at least one image and test
        gets the remainder (``test_split`` is not used directly). A logo is
        omitted from a split's mapping when it gets no images there.
        """
        # Private RNG (see _split_logos): identical shuffles, no global side effect.
        rng = random.Random(self.seed)
        train_logo_to_images: Dict[str, List[Path]] = {}
        val_logo_to_images: Dict[str, List[Path]] = {}
        test_logo_to_images: Dict[str, List[Path]] = {}

        for logo, images in self.logo_to_images.items():
            shuffled_images = images.copy()
            rng.shuffle(shuffled_images)
            n = len(shuffled_images)

            if n == 1:
                # Only one image: put in train only
                train_logo_to_images[logo] = shuffled_images
                continue
            if n == 2:
                # Two images: one train, one val
                train_logo_to_images[logo] = [shuffled_images[0]]
                val_logo_to_images[logo] = [shuffled_images[1]]
                continue

            # Normal split for 3+ images; max(1, ...) guarantees non-empty
            # train and val slices.
            train_end = max(1, int(n * train_split))
            val_end = train_end + max(1, int(n * val_split))
            train_images = shuffled_images[:train_end]
            val_images = shuffled_images[train_end:val_end]
            test_images = shuffled_images[val_end:]

            # Only register a logo in the splits where it actually has images.
            if train_images:
                train_logo_to_images[logo] = train_images
            if val_images:
                val_logo_to_images[logo] = val_images
            if test_images:
                test_logo_to_images[logo] = test_images

        return train_logo_to_images, val_logo_to_images, test_logo_to_images

    def get_split_info(self) -> Dict[str, object]:
        """Return the split level plus logo and image counts for each split."""
        return {
            "split_level": self.split_level,
            "total_logos": len(self.all_logos),
            "train_logos": len(self.train_logos),
            "val_logos": len(self.val_logos),
            "test_logos": len(self.test_logos),
            "train_images": sum(
                len(imgs) for imgs in self.train_logo_to_images.values()
            ),
            "val_images": sum(
                len(imgs) for imgs in self.val_logo_to_images.values()
            ),
            "test_images": sum(
                len(imgs) for imgs in self.test_logo_to_images.values()
            ),
        }
class LogoContrastiveDataset(Dataset):
    """
    Dataset for contrastive learning on logos.

    Each __getitem__ call returns a whole pre-built batch organized for
    contrastive learning: up to K different logos with M samples each, so
    positive pairs exist within every batch. The index passed to
    __getitem__ is ignored (every call draws a fresh random batch);
    ``batches_per_epoch`` only determines len().
    """

    def __init__(
        self,
        logo_data: LogoDataset,
        split: str = "train",
        logos_per_batch: int = 32,
        samples_per_logo: int = 4,
        transform: Optional[transforms.Compose] = None,
        batches_per_epoch: int = 1000,
    ):
        """
        Initialize the contrastive dataset.

        Args:
            logo_data: LogoDataset instance with logo mappings.
            split: One of "train", "val", or "test".
            logos_per_batch: Number of different logos per batch (K).
            samples_per_logo: Number of samples for each logo (M).
            transform: Image transforms to apply; when falsy, the transforms
                from get_val_transforms() are used.
            batches_per_epoch: Number of batches per epoch (the value of len()).

        Raises:
            ValueError: If ``split`` is not "train", "val", or "test".
        """
        self.logo_data = logo_data
        self.logos_per_batch = logos_per_batch
        self.samples_per_logo = samples_per_logo
        self.transform = transform
        self.batches_per_epoch = batches_per_epoch

        # Per-split logos and image mappings; LogoDataset provides these for
        # both logo-level and image-level splits.
        if split == "train":
            self.logos = logo_data.train_logos
            self.logo_to_images = logo_data.train_logo_to_images
        elif split == "val":
            self.logos = logo_data.val_logos
            self.logo_to_images = logo_data.val_logo_to_images
        elif split == "test":
            self.logos = logo_data.test_logos
            self.logo_to_images = logo_data.test_logo_to_images
        else:
            raise ValueError(
                f"split must be 'train', 'val', or 'test', got {split!r}"
            )

        # Logos with at least M images in this split (sampled without replacement).
        self.valid_logos = [
            logo for logo in self.logos
            if logo in self.logo_to_images
            and len(self.logo_to_images[logo]) >= samples_per_logo
        ]
        # Logos with fewer than M images (sampled with replacement in __getitem__).
        # NOTE(review): neither list is read by __getitem__; presumably kept
        # for callers that inspect split coverage.
        self.logos_needing_replacement = [
            logo for logo in self.logos
            if logo in self.logo_to_images
            and len(self.logo_to_images[logo]) < samples_per_logo
        ]

        # Labels index into ALL logos, so a brand keeps the same label in
        # every split.
        self.logo_to_label = {
            logo: idx for idx, logo in enumerate(logo_data.all_logos)
        }

    def __len__(self) -> int:
        return self.batches_per_epoch

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a randomly drawn batch of images for contrastive learning.

        Args:
            idx: Ignored; every call samples a new batch.

        Returns:
            images: Tensor of shape [N, 3, 224, 224] (with this module's
                transforms), where N <= K*M because images that fail to load
                or transform are skipped.
            labels: Tensor of shape [N] with logo class indices.
            If no image could be produced, a placeholder batch of one all-zero
            3x224x224 image with label 0 is returned instead.
        """
        images: List[torch.Tensor] = []
        labels: List[int] = []

        # Resolve the transform once per batch rather than rebuilding the
        # fallback pipeline for every image.
        transform = self.transform if self.transform else get_val_transforms()

        # Sample K logos for this batch (only from logos that have images in this split)
        available_logos = [l for l in self.logos if l in self.logo_to_images]
        k = min(self.logos_per_batch, len(available_logos))
        batch_logos = random.sample(available_logos, k)

        for logo in batch_logos:
            logo_images = self.logo_to_images[logo]

            # Sample M images for this logo, with replacement if it has too few.
            if len(logo_images) >= self.samples_per_logo:
                sampled_paths = random.sample(logo_images, self.samples_per_logo)
            else:
                sampled_paths = random.choices(
                    logo_images, k=self.samples_per_logo
                )

            for img_path in sampled_paths:
                try:
                    # The context manager closes the underlying file handle;
                    # convert() returns an independent, fully loaded copy.
                    with Image.open(img_path) as handle:
                        img = handle.convert("RGB")
                    tensor = transform(img)
                except Exception as exc:
                    # Best-effort: skip the problematic image (it is not
                    # replaced, so this logo contributes fewer samples) but
                    # make the failure visible instead of swallowing it.
                    warnings.warn(
                        f"Skipping logo image {img_path}: {exc}", RuntimeWarning
                    )
                    continue
                images.append(tensor)
                labels.append(self.logo_to_label[logo])

        if not images:
            # Fallback: return dummy batch
            return (
                torch.zeros(1, 3, 224, 224),
                torch.zeros(1, dtype=torch.long),
            )

        images_tensor = torch.stack(images)
        labels_tensor = torch.tensor(labels, dtype=torch.long)
        return images_tensor, labels_tensor
class BalancedBatchSampler(Sampler):
    """
    Batch sampler that yields index lists with a balanced logo composition.

    Meant for a flattened dataset in which each index is a single image.
    Every batch draws up to ``logos_per_batch`` distinct labels and then
    ``samples_per_logo`` indices per label. Sampling is without replacement
    when a label has enough indices and with replacement otherwise.
    """

    def __init__(
        self,
        logo_labels: List[int],
        logos_per_batch: int,
        samples_per_logo: int,
        num_batches: int,
    ):
        self.logo_labels = logo_labels
        self.logos_per_batch = logos_per_batch
        self.samples_per_logo = samples_per_logo
        self.num_batches = num_batches

        # Bucket dataset positions by label, keeping first-seen label order.
        buckets: Dict[int, List[int]] = {}
        for position, label in enumerate(logo_labels):
            buckets.setdefault(label, []).append(position)
        self.logo_to_indices = buckets
        self.all_logos = list(buckets)

    def __iter__(self):
        for _ in range(self.num_batches):
            wanted = min(self.logos_per_batch, len(self.all_logos))
            chosen_labels = random.sample(self.all_logos, wanted)

            batch = []
            for label in chosen_labels:
                pool = self.logo_to_indices[label]
                if len(pool) < self.samples_per_logo:
                    # Too few images for this label: allow repeats.
                    batch += random.choices(pool, k=self.samples_per_logo)
                else:
                    batch += random.sample(pool, self.samples_per_logo)
            yield batch

    def __len__(self):
        return self.num_batches
def _make_contrastive_loader(
    dataset: Dataset, shuffle: bool, num_workers: int
) -> DataLoader:
    """Wrap a pre-batched contrastive dataset in a DataLoader (one batch per step)."""
    return DataLoader(
        dataset,
        # batch_size=1 because each __getitem__ already returns a full batch.
        batch_size=1,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinned memory only helps host->CUDA copies; enabling it without
        # CUDA just emits a warning.
        pin_memory=torch.cuda.is_available(),
        collate_fn=_collate_contrastive_batch,
    )


def create_dataloaders(
    db_path: str,
    reference_dir: str,
    batch_size: int = 16,
    logos_per_batch: int = 32,
    samples_per_logo: int = 4,
    num_workers: int = 4,
    train_split: float = 0.7,
    val_split: float = 0.15,
    test_split: float = 0.15,
    seed: int = 42,
    augmentation_strength: str = "medium",
    batches_per_epoch: int = 1000,
    split_level: str = "logo",
) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
    """
    Create train, validation, and optionally test dataloaders.

    Args:
        db_path: Path to SQLite database.
        reference_dir: Directory containing reference logo images.
        batch_size: Not used directly (see logos_per_batch and samples_per_logo).
        logos_per_batch: Number of different logos per batch.
        samples_per_logo: Samples per logo in batch.
        num_workers: Number of data loading workers.
        train_split: Fraction for training.
        val_split: Fraction for validation.
        test_split: Fraction for testing; 0 disables the test loader.
        seed: Random seed for the splits.
        augmentation_strength: "light", "medium", or "strong".
        batches_per_epoch: Number of batches per training epoch. Val and test
            use a tenth of this, but always at least one batch.
        split_level: "logo" for brand-level splits, "image" for image-level splits.

    Returns:
        Tuple of (train_loader, val_loader, test_loader); test_loader is None
        when test_split is 0.
    """
    # Load logo data
    logo_data = LogoDataset(
        db_path=db_path,
        reference_dir=reference_dir,
        train_split=train_split,
        val_split=val_split,
        test_split=test_split,
        seed=seed,
        split_level=split_level,
    )

    # Print split info
    split_info = logo_data.get_split_info()
    print("Dataset loaded:")
    print(f" Split level: {split_info['split_level']}")
    print(f" Total logos: {split_info['total_logos']}")
    print(f" Train: {split_info['train_logos']} logos, {split_info['train_images']} images")
    print(f" Val: {split_info['val_logos']} logos, {split_info['val_images']} images")
    print(f" Test: {split_info['test_logos']} logos, {split_info['test_images']} images")

    # Fewer evaluation batches than training, but never zero: with
    # batches_per_epoch < 10, `// 10` alone would produce empty val/test loaders.
    eval_batches = max(1, batches_per_epoch // 10)

    train_dataset = LogoContrastiveDataset(
        logo_data=logo_data,
        split="train",
        logos_per_batch=logos_per_batch,
        samples_per_logo=samples_per_logo,
        transform=get_train_transforms(augmentation_strength),
        batches_per_epoch=batches_per_epoch,
    )
    val_dataset = LogoContrastiveDataset(
        logo_data=logo_data,
        split="val",
        logos_per_batch=logos_per_batch,
        samples_per_logo=samples_per_logo,
        transform=get_val_transforms(),
        batches_per_epoch=eval_batches,
    )
    test_dataset = LogoContrastiveDataset(
        logo_data=logo_data,
        split="test",
        logos_per_batch=logos_per_batch,
        samples_per_logo=samples_per_logo,
        transform=get_val_transforms(),
        batches_per_epoch=eval_batches,
    ) if test_split > 0 else None

    train_loader = _make_contrastive_loader(
        train_dataset, shuffle=True, num_workers=num_workers
    )
    val_loader = _make_contrastive_loader(
        val_dataset, shuffle=False, num_workers=num_workers
    )
    test_loader = None
    if test_dataset is not None:
        test_loader = _make_contrastive_loader(
            test_dataset, shuffle=False, num_workers=num_workers
        )

    return train_loader, val_loader, test_loader
def _collate_contrastive_batch(
batch: List[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Collate function that unpacks pre-batched data.
Since LogoContrastiveDataset already returns batched data,
we just squeeze the outer dimension.
"""
images, labels = batch[0]
return images, labels