From 14a1bda3fa6c7cdcfc555d57ff4c914773a74a94 Mon Sep 17 00:00:00 2001 From: Rick McEwen Date: Mon, 5 Jan 2026 15:10:45 -0500 Subject: [PATCH] Add image-level split support for CLIP fine-tuning Image-level splits allow the model to see some images from each logo brand during training, unlike logo-level splits where test brands are completely unseen. This is less rigorous but more representative of real-world use. Changes: - Add configs/image_level_splits.yaml with gentler training settings: - split_level: "image" for image-level splits - temperature: 0.15 (softer contrastive learning) - learning_rate: 5e-6 (slower learning) - max_epochs: 30 (more epochs) - Update training/dataset.py: - Add split_level parameter to LogoDataset - Implement _split_images() for image-level splitting - Update LogoContrastiveDataset to use split-specific image mappings - Update training/config.py: - Add split_level field to TrainingConfig - Update train_clip_logo.py: - Pass split_level to create_dataloaders Usage: uv run python train_clip_logo.py --config configs/image_level_splits.yaml --- configs/image_level_splits.yaml | 78 ++++++++++++++++++ train_clip_logo.py | 1 + training/config.py | 3 +- training/dataset.py | 138 +++++++++++++++++++++++++++----- 4 files changed, 200 insertions(+), 20 deletions(-) create mode 100644 configs/image_level_splits.yaml diff --git a/configs/image_level_splits.yaml b/configs/image_level_splits.yaml new file mode 100644 index 0000000..04b9011 --- /dev/null +++ b/configs/image_level_splits.yaml @@ -0,0 +1,78 @@ +# Training configuration with IMAGE-LEVEL splits +# +# Unlike logo-level splits where test logos are completely unseen brands, +# image-level splits allow the model to see some images from each brand +# during training. This is less rigorous but more representative of +# real-world use where you have reference images for logos you want to detect. +# +# Also uses gentler contrastive learning settings to prevent over-separation. +# +# Usage: +# uv run python train_clip_logo.py --config configs/image_level_splits.yaml + +# Base model +base_model: "openai/clip-vit-large-patch14" + +# Dataset paths (relative to project root) +dataset_dir: "LogoDet-3K" +reference_dir: "reference_logos" +db_path: "test_data_mapping.db" + +# Data split configuration +# split_level: "image" means images are split, not logo brands +# This allows test set to contain images from brands seen during training +split_level: "image" +train_split: 0.7 +val_split: 0.15 +test_split: 0.15 + +# Batch construction +batch_size: 16 +logos_per_batch: 32 +samples_per_logo: 4 +gradient_accumulation_steps: 8 +num_workers: 4 + +# Model architecture - same as before +lora_r: 16 +lora_alpha: 32 +lora_dropout: 0.1 +freeze_layers: 12 +use_gradient_checkpointing: true + +# Training hyperparameters - GENTLER settings +learning_rate: 5.0e-6 # Reduced from 1e-5 +weight_decay: 0.01 +warmup_steps: 500 +max_epochs: 30 # More epochs with slower learning +mixed_precision: true + +# Loss function - HIGHER temperature for softer contrastive learning +temperature: 0.15 # Increased from 0.07 +loss_type: "infonce" +triplet_margin: 0.2 # Reduced from 0.3 + +# Early stopping +patience: 7 # More patience with gentler learning +min_delta: 0.001 + +# Checkpoints and output +checkpoint_dir: "checkpoints_image_split" +output_dir: "models/logo_detection/clip_finetuned_image_split" +save_every_n_epochs: 5 + +# Logging +log_every_n_steps: 10 +eval_every_n_epochs: 1 + +# Reproducibility +seed: 42 + +# Hard negative mining +use_hard_negatives: false +hard_negative_start_epoch: 10 +hard_negatives_per_logo: 10 + +# Data augmentation +use_augmentation: true +augmentation_strength: "medium" diff --git a/train_clip_logo.py b/train_clip_logo.py index 44e74b5..2724fcb 100644 --- a/train_clip_logo.py +++ b/train_clip_logo.py @@ -256,6 +256,7 @@ def main(): test_split=config.test_split, seed=config.seed, augmentation_strength=config.augmentation_strength, + split_level=getattr(config, 'split_level', 'logo'), ) # Create trainer diff --git a/training/config.py b/training/config.py index caf6b78..772e5f6 100644 --- a/training/config.py +++ b/training/config.py @@ -20,7 +20,8 @@ class TrainingConfig: reference_dir: str = "reference_logos" db_path: str = "test_data_mapping.db" - # Data split ratios + # Data split configuration + split_level: str = "logo" # "logo" for brand-level, "image" for image-level train_split: float = 0.7 val_split: float = 0.15 test_split: float = 0.15 diff --git a/training/dataset.py b/training/dataset.py index 95b2fe1..ca44b57 100644 --- a/training/dataset.py +++ b/training/dataset.py @@ -84,7 +84,7 @@ class LogoDataset: """ Manages logo data from the SQLite database. - Handles loading logo-to-image mappings and splitting by logo brand. + Handles loading logo-to-image mappings and splitting by logo brand or image. """ def __init__( @@ -95,19 +95,57 @@ class LogoDataset: val_split: float = 0.15, test_split: float = 0.15, seed: int = 42, + split_level: str = "logo", ): + """ + Initialize the logo dataset. + + Args: + db_path: Path to SQLite database + reference_dir: Directory containing reference logo images + train_split: Fraction for training + val_split: Fraction for validation + test_split: Fraction for testing + seed: Random seed for reproducibility + split_level: "logo" for brand-level splits (test on unseen brands), + "image" for image-level splits (test on unseen images + from seen brands) + """ self.db_path = Path(db_path) self.reference_dir = Path(reference_dir) self.seed = seed + self.split_level = split_level # Load logo-to-images mapping from database self.logo_to_images = self._load_logo_mappings() self.all_logos = list(self.logo_to_images.keys()) - # Create logo-level splits - self.train_logos, self.val_logos, self.test_logos = self._split_logos( - train_split, val_split, test_split - ) + if split_level == "logo": + # Logo-level splits: test logos are completely unseen brands + self.train_logos, self.val_logos, self.test_logos = self._split_logos( + train_split, val_split, test_split + ) + # For logo-level splits, each split has its own logos + self.train_logo_to_images = { + l: self.logo_to_images[l] for l in self.train_logos + } + self.val_logo_to_images = { + l: self.logo_to_images[l] for l in self.val_logos + } + self.test_logo_to_images = { + l: self.logo_to_images[l] for l in self.test_logos + } + else: + # Image-level splits: all logos present in all splits, different images + ( + self.train_logo_to_images, + self.val_logo_to_images, + self.test_logo_to_images, + ) = self._split_images(train_split, val_split, test_split) + # All logos are in all splits + self.train_logos = list(self.train_logo_to_images.keys()) + self.val_logos = list(self.val_logo_to_images.keys()) + self.test_logos = list(self.test_logo_to_images.keys()) def _load_logo_mappings(self) -> Dict[str, List[Path]]: """Load logo name to image paths mapping from database.""" @@ -151,21 +189,74 @@ class LogoDataset: return train_logos, val_logos, test_logos - def get_split_info(self) -> Dict[str, int]: + def _split_images( + self, + train_split: float, + val_split: float, + test_split: float, + ) -> Tuple[Dict[str, List[Path]], Dict[str, List[Path]], Dict[str, List[Path]]]: + """ + Split images within each logo brand for train/val/test. + + Each logo brand will have images in all splits, allowing the model + to see some examples of each brand during training. + """ + random.seed(self.seed) + + train_logo_to_images: Dict[str, List[Path]] = {} + val_logo_to_images: Dict[str, List[Path]] = {} + test_logo_to_images: Dict[str, List[Path]] = {} + + for logo, images in self.logo_to_images.items(): + # Shuffle images for this logo + shuffled_images = images.copy() + random.shuffle(shuffled_images) + + n = len(shuffled_images) + if n == 1: + # Only one image: put in train only + train_logo_to_images[logo] = shuffled_images + continue + elif n == 2: + # Two images: one train, one val + train_logo_to_images[logo] = [shuffled_images[0]] + val_logo_to_images[logo] = [shuffled_images[1]] + continue + + # Normal split for 3+ images + train_end = max(1, int(n * train_split)) + val_end = train_end + max(1, int(n * val_split)) + + train_images = shuffled_images[:train_end] + val_images = shuffled_images[train_end:val_end] + test_images = shuffled_images[val_end:] + + # Ensure at least one image in train + if train_images: + train_logo_to_images[logo] = train_images + if val_images: + val_logo_to_images[logo] = val_images + if test_images: + test_logo_to_images[logo] = test_images + + return train_logo_to_images, val_logo_to_images, test_logo_to_images + + def get_split_info(self) -> Dict[str, any]: """Return information about the splits.""" return { + "split_level": self.split_level, "total_logos": len(self.all_logos), "train_logos": len(self.train_logos), "val_logos": len(self.val_logos), "test_logos": len(self.test_logos), "train_images": sum( - len(self.logo_to_images[l]) for l in self.train_logos + len(imgs) for imgs in self.train_logo_to_images.values() ), "val_images": sum( - len(self.logo_to_images[l]) for l in self.val_logos + len(imgs) for imgs in self.val_logo_to_images.values() ), "test_images": sum( - len(self.logo_to_images[l]) for l in self.test_logos + len(imgs) for imgs in self.test_logo_to_images.values() ), } @@ -205,29 +296,33 @@ class LogoContrastiveDataset(Dataset): self.transform = transform self.batches_per_epoch = batches_per_epoch - # Get logos for this split + # Get logos and their images for this split + # This respects both logo-level and image-level splits if split == "train": self.logos = logo_data.train_logos + self.logo_to_images = logo_data.train_logo_to_images elif split == "val": self.logos = logo_data.val_logos + self.logo_to_images = logo_data.val_logo_to_images else: self.logos = logo_data.test_logos + self.logo_to_images = logo_data.test_logo_to_images - # Filter logos with enough samples + # Filter logos with enough samples for this split self.valid_logos = [ logo for logo in self.logos - if len(logo_data.logo_to_images[logo]) >= samples_per_logo + if logo in self.logo_to_images and len(self.logo_to_images[logo]) >= samples_per_logo ] # For logos with fewer samples, we'll use with replacement self.logos_needing_replacement = [ logo for logo in self.logos - if len(logo_data.logo_to_images[logo]) < samples_per_logo + if logo in self.logo_to_images and len(self.logo_to_images[logo]) < samples_per_logo ] - # Create label mapping + # Create label mapping (use all logos from the full dataset for consistent labels) self.logo_to_label = { - logo: idx for idx, logo in enumerate(self.logos) + logo: idx for idx, logo in enumerate(logo_data.all_logos) } def __len__(self) -> int: @@ -244,12 +339,13 @@ class LogoContrastiveDataset(Dataset): images = [] labels = [] - # Sample K logos for this batch - k = min(self.logos_per_batch, len(self.logos)) - batch_logos = random.sample(self.logos, k) + # Sample K logos for this batch (only from logos that have images in this split) + available_logos = [l for l in self.logos if l in self.logo_to_images] + k = min(self.logos_per_batch, len(available_logos)) + batch_logos = random.sample(available_logos, k) for logo in batch_logos: - logo_images = self.logo_data.logo_to_images[logo] + logo_images = self.logo_to_images[logo] # Sample M images for this logo if len(logo_images) >= self.samples_per_logo: @@ -353,6 +449,7 @@ def create_dataloaders( seed: int = 42, augmentation_strength: str = "medium", batches_per_epoch: int = 1000, + split_level: str = "logo", ) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]: """ Create train, validation, and optionally test dataloaders. @@ -370,6 +467,7 @@ def create_dataloaders( seed: Random seed augmentation_strength: "light", "medium", or "strong" batches_per_epoch: Number of batches per training epoch + split_level: "logo" for brand-level splits, "image" for image-level splits Returns: Tuple of (train_loader, val_loader, test_loader) @@ -382,11 +480,13 @@ def create_dataloaders( val_split=val_split, test_split=test_split, seed=seed, + split_level=split_level, ) # Print split info split_info = logo_data.get_split_info() print(f"Dataset loaded:") + print(f" Split level: {split_info['split_level']}") print(f" Total logos: {split_info['total_logos']}") print(f" Train: {split_info['train_logos']} logos, {split_info['train_images']} images") print(f" Val: {split_info['val_logos']} logos, {split_info['val_images']} images")