"""
Dataset classes for contrastive learning on logo images.
"""

import logging
import random
import sqlite3
from contextlib import closing
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms

logger = logging.getLogger(__name__)

# CLIP normalization values
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]


def get_train_transforms(strength: str = "medium") -> transforms.Compose:
    """
    Get training data augmentation transforms.

    Args:
        strength: Augmentation strength - "light", "medium", or "strong".
            Any other value falls back to the "strong" pipeline (kept for
            backward compatibility) and logs a warning.

    Returns:
        Composed transforms for training
    """
    if strength == "light":
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
        ])

    if strength == "medium":
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(
                brightness=0.2,
                contrast=0.2,
                saturation=0.2,
                hue=0.05,
            ),
            transforms.RandomAffine(
                degrees=0,
                translate=(0.1, 0.1),
                scale=(0.9, 1.1),
            ),
            transforms.RandomGrayscale(p=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
        ])

    if strength != "strong":
        # A typo such as "meduim" used to select the strongest pipeline
        # without any hint; keep the fallback but make it visible.
        logger.warning(
            "Unknown augmentation strength %r; falling back to 'strong'",
            strength,
        )

    return transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.1),
        transforms.RandomRotation(degrees=30),
        transforms.ColorJitter(
            brightness=0.3,
            contrast=0.3,
            saturation=0.3,
            hue=0.1,
        ),
        transforms.RandomAffine(
            degrees=0,
            translate=(0.15, 0.15),
            scale=(0.8, 1.2),
            shear=10,
        ),
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
    ])


def get_val_transforms() -> transforms.Compose:
    """Get validation/test transforms (no augmentation)."""
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
    ])


class LogoDataset:
    """
    Manages logo data from the SQLite database.

    Handles loading logo-to-image mappings and splitting by logo brand.
    """

    def __init__(
        self,
        db_path: str,
        reference_dir: str,
        train_split: float = 0.7,
        val_split: float = 0.15,
        test_split: float = 0.15,
        seed: int = 42,
    ):
        """
        Load the logo mappings and compute the brand-level splits.

        Args:
            db_path: Path to the SQLite database
            reference_dir: Directory that the stored filenames are
                relative to
            train_split: Fraction of logos assigned to training
            val_split: Fraction of logos assigned to validation
            test_split: Fraction of logos for testing (see _split_logos:
                the test split is whatever remains after train and val)
            seed: Seed used when shuffling logos before splitting
        """
        self.db_path = Path(db_path)
        self.reference_dir = Path(reference_dir)
        self.seed = seed

        # Load logo-to-images mapping from database
        self.logo_to_images = self._load_logo_mappings()
        self.all_logos = list(self.logo_to_images.keys())

        # Create logo-level splits
        self.train_logos, self.val_logos, self.test_logos = self._split_logos(
            train_split, val_split, test_split
        )

    def _load_logo_mappings(self) -> Dict[str, List[Path]]:
        """Load logo name to image paths mapping from database."""
        logo_to_images: Dict[str, List[Path]] = {}

        # closing() guarantees the connection is released even when the
        # query raises (e.g. a missing table).
        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute("""
            SELECT ln.name, rl.filename
            FROM reference_logos rl
            JOIN logo_names ln ON rl.logo_name_id = ln.id
            ORDER BY ln.name
        """)
            for logo_name, filename in cursor.fetchall():
                logo_to_images.setdefault(logo_name, []).append(
                    self.reference_dir / filename
                )

        return logo_to_images

    def _split_logos(
        self,
        train_split: float,
        val_split: float,
        test_split: float,
    ) -> Tuple[List[str], List[str], List[str]]:
        """
        Split logos at brand level for train/val/test.

        The train and val sizes are int(n * fraction); the test split is
        every logo left over, so ``test_split`` itself is not used in the
        computation.

        Note: this reseeds the module-level ``random`` generator with
        ``self.seed`` before shuffling.
        """
        random.seed(self.seed)
        logos = self.all_logos.copy()
        random.shuffle(logos)

        n = len(logos)
        train_end = int(n * train_split)
        val_end = train_end + int(n * val_split)

        train_logos = logos[:train_end]
        val_logos = logos[train_end:val_end]
        test_logos = logos[val_end:]

        return train_logos, val_logos, test_logos

    def get_split_info(self) -> Dict[str, int]:
        """Return logo and image counts overall and per split."""

        def count_images(logos: List[str]) -> int:
            return sum(len(self.logo_to_images[logo]) for logo in logos)

        return {
            "total_logos": len(self.all_logos),
            "train_logos": len(self.train_logos),
            "val_logos": len(self.val_logos),
            "test_logos": len(self.test_logos),
            "train_images": count_images(self.train_logos),
            "val_images": count_images(self.val_logos),
            "test_images": count_images(self.test_logos),
        }


class LogoContrastiveDataset(Dataset):
    """
    Dataset for contrastive learning on logos.

    Each __getitem__ call returns a batch of images organized for
    contrastive learning: K different logos with M samples each,
    ensuring positive pairs exist within each batch.
    """

    def __init__(
        self,
        logo_data: LogoDataset,
        split: str = "train",
        logos_per_batch: int = 32,
        samples_per_logo: int = 4,
        transform: Optional[transforms.Compose] = None,
        batches_per_epoch: int = 1000,
    ):
        """
        Initialize the contrastive dataset.

        Args:
            logo_data: LogoDataset instance with logo mappings
            split: One of "train", "val", or "test"; any value other than
                "train" or "val" selects the test logos
            logos_per_batch: Number of different logos per batch
            samples_per_logo: Number of samples for each logo
            transform: Image transforms to apply; when not given,
                get_val_transforms() is used
            batches_per_epoch: Number of batches per epoch
        """
        self.logo_data = logo_data
        self.logos_per_batch = logos_per_batch
        self.samples_per_logo = samples_per_logo
        self.transform = transform
        self.batches_per_epoch = batches_per_epoch

        # Get logos for this split
        if split == "train":
            self.logos = logo_data.train_logos
        elif split == "val":
            self.logos = logo_data.val_logos
        else:
            self.logos = logo_data.test_logos

        # Logos with enough images to sample without replacement
        self.valid_logos = [
            logo for logo in self.logos
            if len(logo_data.logo_to_images[logo]) >= samples_per_logo
        ]

        # Logos with fewer images; these are sampled with replacement
        self.logos_needing_replacement = [
            logo for logo in self.logos
            if len(logo_data.logo_to_images[logo]) < samples_per_logo
        ]

        # Labels are indices into this split's logo list
        self.logo_to_label = {
            logo: idx for idx, logo in enumerate(self.logos)
        }

    def __len__(self) -> int:
        return self.batches_per_epoch

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a batch of images for contrastive learning.

        ``idx`` is ignored: every call draws a fresh random batch.

        Returns:
            images: Tensor of shape [K*M, 3, 224, 224] with the transforms
                defined in this module (fewer rows when images had to be
                skipped)
            labels: Tensor of shape [K*M] with logo class indices

            If no image could be loaded at all, a single all-zero image
            with label 0 is returned instead.
        """
        images = []
        labels = []

        # Resolve the transform once per batch rather than once per image.
        transform = self.transform or get_val_transforms()

        # Sample K logos for this batch
        k = min(self.logos_per_batch, len(self.logos))
        batch_logos = random.sample(self.logos, k)

        for logo in batch_logos:
            logo_images = self.logo_data.logo_to_images[logo]

            # Sample M images for this logo
            if len(logo_images) >= self.samples_per_logo:
                sampled_paths = random.sample(
                    logo_images, self.samples_per_logo
                )
            else:
                # Sample with replacement if not enough images
                sampled_paths = random.choices(
                    logo_images, k=self.samples_per_logo
                )

            label = self.logo_to_label[logo]
            for img_path in sampled_paths:
                try:
                    # The context manager closes the underlying file as
                    # soon as the RGB copy has been made.
                    with Image.open(img_path) as raw:
                        img = raw.convert("RGB")
                    img = transform(img)
                except Exception as exc:
                    # Best effort: a bad image is dropped from the batch
                    # (no replacement is drawn), but no longer silently.
                    logger.warning(
                        "Skipping image %s: %s", img_path, exc
                    )
                    continue
                images.append(img)
                labels.append(label)

        if not images:
            # Fallback: return dummy batch
            logger.warning(
                "No images could be loaded for this batch; "
                "returning a dummy batch"
            )
            return (
                torch.zeros(1, 3, 224, 224),
                torch.zeros(1, dtype=torch.long),
            )

        images_tensor = torch.stack(images)
        labels_tensor = torch.tensor(labels, dtype=torch.long)

        return images_tensor, labels_tensor


class BalancedBatchSampler(Sampler):
    """
    Sampler that ensures each batch has a balanced distribution of logos.

    Used with a flattened dataset where each sample is a single image.
    """

    def __init__(
        self,
        logo_labels: List[int],
        logos_per_batch: int,
        samples_per_logo: int,
        num_batches: int,
    ):
        """
        Args:
            logo_labels: Label of each sample, indexed by sample position
            logos_per_batch: Number of distinct labels drawn per batch
            samples_per_logo: Number of indices drawn for each label
            num_batches: Number of batches yielded per iteration
        """
        self.logo_labels = logo_labels
        self.logos_per_batch = logos_per_batch
        self.samples_per_logo = samples_per_logo
        self.num_batches = num_batches

        # Group indices by logo
        self.logo_to_indices: Dict[int, List[int]] = {}
        for idx, label in enumerate(logo_labels):
            self.logo_to_indices.setdefault(label, []).append(idx)

        self.all_logos = list(self.logo_to_indices.keys())

    def __iter__(self):
        """Yield ``num_batches`` lists of sample indices."""
        logos_in_batch = min(self.logos_per_batch, len(self.all_logos))

        for _ in range(self.num_batches):
            batch_indices = []

            # Sample logos for this batch
            logos = random.sample(self.all_logos, logos_in_batch)

            for logo in logos:
                indices = self.logo_to_indices[logo]
                if len(indices) >= self.samples_per_logo:
                    sampled = random.sample(indices, self.samples_per_logo)
                else:
                    # Too few samples for this logo: draw with replacement
                    sampled = random.choices(
                        indices, k=self.samples_per_logo
                    )
                batch_indices.extend(sampled)

            yield batch_indices

    def __len__(self):
        return self.num_batches


def create_dataloaders(
    db_path: str,
    reference_dir: str,
    batch_size: int = 16,
    logos_per_batch: int = 32,
    samples_per_logo: int = 4,
    num_workers: int = 4,
    train_split: float = 0.7,
    val_split: float = 0.15,
    test_split: float = 0.15,
    seed: int = 42,
    augmentation_strength: str = "medium",
    batches_per_epoch: int = 1000,
) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
    """
    Create train, validation, and optionally test dataloaders.

    Args:
        db_path: Path to SQLite database
        reference_dir: Directory containing reference logo images
        batch_size: Not used directly (see logos_per_batch and
            samples_per_logo)
        logos_per_batch: Number of different logos per batch
        samples_per_logo: Samples per logo in batch
        num_workers: Number of data loading workers
        train_split: Fraction for training
        val_split: Fraction for validation
        test_split: Fraction for testing; no test loader is created
            when this is not greater than 0
        seed: Random seed
        augmentation_strength: "light", "medium", or "strong"
        batches_per_epoch: Number of batches per training epoch; the
            val and test loaders use a tenth of this, but at least 1

    Returns:
        Tuple of (train_loader, val_loader, test_loader)
    """
    # Load logo data
    logo_data = LogoDataset(
        db_path=db_path,
        reference_dir=reference_dir,
        train_split=train_split,
        val_split=val_split,
        test_split=test_split,
        seed=seed,
    )

    # Print split info
    split_info = logo_data.get_split_info()
    print("Dataset loaded:")
    print(f"  Total logos: {split_info['total_logos']}")
    print(f"  Train: {split_info['train_logos']} logos, {split_info['train_images']} images")
    print(f"  Val: {split_info['val_logos']} logos, {split_info['val_images']} images")
    print(f"  Test: {split_info['test_logos']} logos, {split_info['test_images']} images")

    # Plain integer division yields 0 for batches_per_epoch < 10, which
    # would produce empty val/test loaders.
    eval_batches = max(1, batches_per_epoch // 10)

    # Create datasets
    train_dataset = LogoContrastiveDataset(
        logo_data=logo_data,
        split="train",
        logos_per_batch=logos_per_batch,
        samples_per_logo=samples_per_logo,
        transform=get_train_transforms(augmentation_strength),
        batches_per_epoch=batches_per_epoch,
    )

    val_dataset = LogoContrastiveDataset(
        logo_data=logo_data,
        split="val",
        logos_per_batch=logos_per_batch,
        samples_per_logo=samples_per_logo,
        transform=get_val_transforms(),
        batches_per_epoch=eval_batches,  # Fewer val batches
    )

    test_dataset = None
    if test_split > 0:
        test_dataset = LogoContrastiveDataset(
            logo_data=logo_data,
            split="test",
            logos_per_batch=logos_per_batch,
            samples_per_logo=samples_per_logo,
            transform=get_val_transforms(),
            batches_per_epoch=eval_batches,
        )

    # Create dataloaders
    train_loader = _build_loader(train_dataset, True, num_workers)
    val_loader = _build_loader(val_dataset, False, num_workers)

    test_loader = None
    if test_dataset is not None:
        test_loader = _build_loader(test_dataset, False, num_workers)

    return train_loader, val_loader, test_loader


def _build_loader(
    dataset: LogoContrastiveDataset,
    shuffle: bool,
    num_workers: int,
) -> DataLoader:
    """Wrap a pre-batching dataset in a DataLoader."""
    # Note: batch_size=1 because each __getitem__ already returns a batch
    return DataLoader(
        dataset,
        batch_size=1,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=_collate_contrastive_batch,
    )


def _collate_contrastive_batch(
    batch: List[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Collate function that unpacks pre-batched data.

    Since LogoContrastiveDataset already returns batched data and the
    loaders use batch_size=1, the list holds a single (images, labels)
    pair, which is returned as-is.
    """
    images, labels = batch[0]
    return images, labels