""" Dataset classes for contrastive learning on logo images. """ import random import sqlite3 from pathlib import Path from typing import Dict, List, Optional, Tuple import torch from PIL import Image from torch.utils.data import Dataset, DataLoader, Sampler from torchvision import transforms # CLIP normalization values CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073] CLIP_STD = [0.26862954, 0.26130258, 0.27577711] def get_train_transforms(strength: str = "medium") -> transforms.Compose: """ Get training data augmentation transforms. Args: strength: Augmentation strength - "light", "medium", or "strong" Returns: Composed transforms for training """ if strength == "light": return transforms.Compose([ transforms.Resize((224, 224)), transforms.RandomHorizontalFlip(p=0.5), transforms.ColorJitter(brightness=0.1, contrast=0.1), transforms.ToTensor(), transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD), ]) elif strength == "medium": return transforms.Compose([ transforms.Resize((224, 224)), transforms.RandomHorizontalFlip(p=0.5), transforms.RandomRotation(degrees=15), transforms.ColorJitter( brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05 ), transforms.RandomAffine( degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1) ), transforms.RandomGrayscale(p=0.1), transforms.ToTensor(), transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD), ]) else: # strong return transforms.Compose([ transforms.Resize((256, 256)), transforms.RandomCrop(224), transforms.RandomHorizontalFlip(p=0.5), transforms.RandomVerticalFlip(p=0.1), transforms.RandomRotation(degrees=30), transforms.ColorJitter( brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1 ), transforms.RandomAffine( degrees=0, translate=(0.15, 0.15), scale=(0.8, 1.2), shear=10 ), transforms.RandomGrayscale(p=0.2), transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)), transforms.ToTensor(), transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD), ]) def get_val_transforms() -> transforms.Compose: """Get validation/test transforms (no augmentation).""" return transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD), ]) class LogoDataset: """ Manages logo data from the SQLite database. Handles loading logo-to-image mappings and splitting by logo brand or image. """ def __init__( self, db_path: str, reference_dir: str, train_split: float = 0.7, val_split: float = 0.15, test_split: float = 0.15, seed: int = 42, split_level: str = "logo", ): """ Initialize the logo dataset. Args: db_path: Path to SQLite database reference_dir: Directory containing reference logo images train_split: Fraction for training val_split: Fraction for validation test_split: Fraction for testing seed: Random seed for reproducibility split_level: "logo" for brand-level splits (test on unseen brands), "image" for image-level splits (test on unseen images from seen brands) """ self.db_path = Path(db_path) self.reference_dir = Path(reference_dir) self.seed = seed self.split_level = split_level # Load logo-to-images mapping from database self.logo_to_images = self._load_logo_mappings() self.all_logos = list(self.logo_to_images.keys()) if split_level == "logo": # Logo-level splits: test logos are completely unseen brands self.train_logos, self.val_logos, self.test_logos = self._split_logos( train_split, val_split, test_split ) # For logo-level splits, each split has its own logos self.train_logo_to_images = { l: self.logo_to_images[l] for l in self.train_logos } self.val_logo_to_images = { l: self.logo_to_images[l] for l in self.val_logos } self.test_logo_to_images = { l: self.logo_to_images[l] for l in self.test_logos } else: # Image-level splits: all logos present in all splits, different images ( self.train_logo_to_images, self.val_logo_to_images, self.test_logo_to_images, ) = self._split_images(train_split, val_split, test_split) # All logos are in all splits self.train_logos = list(self.train_logo_to_images.keys()) self.val_logos = list(self.val_logo_to_images.keys()) self.test_logos = list(self.test_logo_to_images.keys()) def _load_logo_mappings(self) -> Dict[str, List[Path]]: """Load logo name to image paths mapping from database.""" conn = sqlite3.connect(self.db_path) cursor = conn.cursor() cursor.execute(""" SELECT ln.name, rl.filename FROM reference_logos rl JOIN logo_names ln ON rl.logo_name_id = ln.id ORDER BY ln.name """) logo_to_images: Dict[str, List[Path]] = {} for logo_name, filename in cursor.fetchall(): if logo_name not in logo_to_images: logo_to_images[logo_name] = [] logo_to_images[logo_name].append(self.reference_dir / filename) conn.close() return logo_to_images def _split_logos( self, train_split: float, val_split: float, test_split: float, ) -> Tuple[List[str], List[str], List[str]]: """Split logos at brand level for train/val/test.""" random.seed(self.seed) logos = self.all_logos.copy() random.shuffle(logos) n = len(logos) train_end = int(n * train_split) val_end = train_end + int(n * val_split) train_logos = logos[:train_end] val_logos = logos[train_end:val_end] test_logos = logos[val_end:] return train_logos, val_logos, test_logos def _split_images( self, train_split: float, val_split: float, test_split: float, ) -> Tuple[Dict[str, List[Path]], Dict[str, List[Path]], Dict[str, List[Path]]]: """ Split images within each logo brand for train/val/test. Each logo brand will have images in all splits, allowing the model to see some examples of each brand during training. """ random.seed(self.seed) train_logo_to_images: Dict[str, List[Path]] = {} val_logo_to_images: Dict[str, List[Path]] = {} test_logo_to_images: Dict[str, List[Path]] = {} for logo, images in self.logo_to_images.items(): # Shuffle images for this logo shuffled_images = images.copy() random.shuffle(shuffled_images) n = len(shuffled_images) if n == 1: # Only one image: put in train only train_logo_to_images[logo] = shuffled_images continue elif n == 2: # Two images: one train, one val train_logo_to_images[logo] = [shuffled_images[0]] val_logo_to_images[logo] = [shuffled_images[1]] continue # Normal split for 3+ images train_end = max(1, int(n * train_split)) val_end = train_end + max(1, int(n * val_split)) train_images = shuffled_images[:train_end] val_images = shuffled_images[train_end:val_end] test_images = shuffled_images[val_end:] # Ensure at least one image in train if train_images: train_logo_to_images[logo] = train_images if val_images: val_logo_to_images[logo] = val_images if test_images: test_logo_to_images[logo] = test_images return train_logo_to_images, val_logo_to_images, test_logo_to_images def get_split_info(self) -> Dict[str, any]: """Return information about the splits.""" return { "split_level": self.split_level, "total_logos": len(self.all_logos), "train_logos": len(self.train_logos), "val_logos": len(self.val_logos), "test_logos": len(self.test_logos), "train_images": sum( len(imgs) for imgs in self.train_logo_to_images.values() ), "val_images": sum( len(imgs) for imgs in self.val_logo_to_images.values() ), "test_images": sum( len(imgs) for imgs in self.test_logo_to_images.values() ), } class LogoContrastiveDataset(Dataset): """ Dataset for contrastive learning on logos. Each __getitem__ call returns a batch of images organized for contrastive learning: K different logos with M samples each, ensuring positive pairs exist within each batch. """ def __init__( self, logo_data: LogoDataset, split: str = "train", logos_per_batch: int = 32, samples_per_logo: int = 4, transform: Optional[transforms.Compose] = None, batches_per_epoch: int = 1000, ): """ Initialize the contrastive dataset. Args: logo_data: LogoDataset instance with logo mappings split: One of "train", "val", or "test" logos_per_batch: Number of different logos per batch samples_per_logo: Number of samples for each logo transform: Image transforms to apply batches_per_epoch: Number of batches per epoch """ self.logo_data = logo_data self.logos_per_batch = logos_per_batch self.samples_per_logo = samples_per_logo self.transform = transform self.batches_per_epoch = batches_per_epoch # Get logos and their images for this split # This respects both logo-level and image-level splits if split == "train": self.logos = logo_data.train_logos self.logo_to_images = logo_data.train_logo_to_images elif split == "val": self.logos = logo_data.val_logos self.logo_to_images = logo_data.val_logo_to_images else: self.logos = logo_data.test_logos self.logo_to_images = logo_data.test_logo_to_images # Filter logos with enough samples for this split self.valid_logos = [ logo for logo in self.logos if logo in self.logo_to_images and len(self.logo_to_images[logo]) >= samples_per_logo ] # For logos with fewer samples, we'll use with replacement self.logos_needing_replacement = [ logo for logo in self.logos if logo in self.logo_to_images and len(self.logo_to_images[logo]) < samples_per_logo ] # Create label mapping (use all logos from the full dataset for consistent labels) self.logo_to_label = { logo: idx for idx, logo in enumerate(logo_data.all_logos) } def __len__(self) -> int: return self.batches_per_epoch def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: """ Get a batch of images for contrastive learning. Returns: images: Tensor of shape [K*M, 3, 224, 224] labels: Tensor of shape [K*M] with logo class indices """ images = [] labels = [] # Sample K logos for this batch (only from logos that have images in this split) available_logos = [l for l in self.logos if l in self.logo_to_images] k = min(self.logos_per_batch, len(available_logos)) batch_logos = random.sample(available_logos, k) for logo in batch_logos: logo_images = self.logo_to_images[logo] # Sample M images for this logo if len(logo_images) >= self.samples_per_logo: sampled_paths = random.sample(logo_images, self.samples_per_logo) else: # Sample with replacement if not enough images sampled_paths = random.choices( logo_images, k=self.samples_per_logo ) # Load and transform images for img_path in sampled_paths: try: img = Image.open(img_path).convert("RGB") if self.transform: img = self.transform(img) else: img = get_val_transforms()(img) images.append(img) labels.append(self.logo_to_label[logo]) except Exception as e: # Skip problematic images, sample another continue # Stack into tensors if len(images) == 0: # Fallback: return dummy batch return ( torch.zeros(1, 3, 224, 224), torch.zeros(1, dtype=torch.long), ) images_tensor = torch.stack(images) labels_tensor = torch.tensor(labels, dtype=torch.long) return images_tensor, labels_tensor class BalancedBatchSampler(Sampler): """ Sampler that ensures each batch has a balanced distribution of logos. Used with a flattened dataset where each sample is a single image. """ def __init__( self, logo_labels: List[int], logos_per_batch: int, samples_per_logo: int, num_batches: int, ): self.logo_labels = logo_labels self.logos_per_batch = logos_per_batch self.samples_per_logo = samples_per_logo self.num_batches = num_batches # Group indices by logo self.logo_to_indices: Dict[int, List[int]] = {} for idx, label in enumerate(logo_labels): if label not in self.logo_to_indices: self.logo_to_indices[label] = [] self.logo_to_indices[label].append(idx) self.all_logos = list(self.logo_to_indices.keys()) def __iter__(self): for _ in range(self.num_batches): batch_indices = [] # Sample logos for this batch logos = random.sample( self.all_logos, min(self.logos_per_batch, len(self.all_logos)), ) for logo in logos: indices = self.logo_to_indices[logo] if len(indices) >= self.samples_per_logo: sampled = random.sample(indices, self.samples_per_logo) else: sampled = random.choices(indices, k=self.samples_per_logo) batch_indices.extend(sampled) yield batch_indices def __len__(self): return self.num_batches def create_dataloaders( db_path: str, reference_dir: str, batch_size: int = 16, logos_per_batch: int = 32, samples_per_logo: int = 4, num_workers: int = 4, train_split: float = 0.7, val_split: float = 0.15, test_split: float = 0.15, seed: int = 42, augmentation_strength: str = "medium", batches_per_epoch: int = 1000, split_level: str = "logo", ) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]: """ Create train, validation, and optionally test dataloaders. Args: db_path: Path to SQLite database reference_dir: Directory containing reference logo images batch_size: Not used directly (see logos_per_batch and samples_per_logo) logos_per_batch: Number of different logos per batch samples_per_logo: Samples per logo in batch num_workers: Number of data loading workers train_split: Fraction for training val_split: Fraction for validation test_split: Fraction for testing seed: Random seed augmentation_strength: "light", "medium", or "strong" batches_per_epoch: Number of batches per training epoch split_level: "logo" for brand-level splits, "image" for image-level splits Returns: Tuple of (train_loader, val_loader, test_loader) """ # Load logo data logo_data = LogoDataset( db_path=db_path, reference_dir=reference_dir, train_split=train_split, val_split=val_split, test_split=test_split, seed=seed, split_level=split_level, ) # Print split info split_info = logo_data.get_split_info() print(f"Dataset loaded:") print(f" Split level: {split_info['split_level']}") print(f" Total logos: {split_info['total_logos']}") print(f" Train: {split_info['train_logos']} logos, {split_info['train_images']} images") print(f" Val: {split_info['val_logos']} logos, {split_info['val_images']} images") print(f" Test: {split_info['test_logos']} logos, {split_info['test_images']} images") # Create datasets train_dataset = LogoContrastiveDataset( logo_data=logo_data, split="train", logos_per_batch=logos_per_batch, samples_per_logo=samples_per_logo, transform=get_train_transforms(augmentation_strength), batches_per_epoch=batches_per_epoch, ) val_dataset = LogoContrastiveDataset( logo_data=logo_data, split="val", logos_per_batch=logos_per_batch, samples_per_logo=samples_per_logo, transform=get_val_transforms(), batches_per_epoch=batches_per_epoch // 10, # Fewer val batches ) test_dataset = LogoContrastiveDataset( logo_data=logo_data, split="test", logos_per_batch=logos_per_batch, samples_per_logo=samples_per_logo, transform=get_val_transforms(), batches_per_epoch=batches_per_epoch // 10, ) if test_split > 0 else None # Create dataloaders # Note: batch_size=1 because each __getitem__ already returns a batch train_loader = DataLoader( train_dataset, batch_size=1, shuffle=True, num_workers=num_workers, pin_memory=True, collate_fn=_collate_contrastive_batch, ) val_loader = DataLoader( val_dataset, batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True, collate_fn=_collate_contrastive_batch, ) test_loader = None if test_dataset is not None: test_loader = DataLoader( test_dataset, batch_size=1, shuffle=False, num_workers=num_workers, pin_memory=True, collate_fn=_collate_contrastive_batch, ) return train_loader, val_loader, test_loader def _collate_contrastive_batch( batch: List[Tuple[torch.Tensor, torch.Tensor]] ) -> Tuple[torch.Tensor, torch.Tensor]: """ Collate function that unpacks pre-batched data. Since LogoContrastiveDataset already returns batched data, we just squeeze the outer dimension. """ images, labels = batch[0] return images, labels