Add image-level split support for CLIP fine-tuning
Image-level splits allow the model to see some images from each logo brand during training, unlike logo-level splits where test brands are completely unseen. This is less rigorous but more representative of real-world use.

Changes:
- Add configs/image_level_splits.yaml with gentler training settings:
  - split_level: "image" for image-level splits
  - temperature: 0.15 (softer contrastive learning)
  - learning_rate: 5e-6 (slower learning)
  - max_epochs: 30 (more epochs)
- Update training/dataset.py:
  - Add split_level parameter to LogoDataset
  - Implement _split_images() for image-level splitting
  - Update LogoContrastiveDataset to use split-specific image mappings
- Update training/config.py:
  - Add split_level field to TrainingConfig
- Update train_clip_logo.py:
  - Pass split_level to create_dataloaders

Usage:
  uv run python train_clip_logo.py --config configs/image_level_splits.yaml
This commit is contained in:
78
configs/image_level_splits.yaml
Normal file
78
configs/image_level_splits.yaml
Normal file
@ -0,0 +1,78 @@
|
||||
# Training configuration with IMAGE-LEVEL splits
#
# Unlike logo-level splits where test logos are completely unseen brands,
# image-level splits allow the model to see some images from each brand
# during training. This is less rigorous but more representative of
# real-world use where you have reference images for logos you want to detect.
#
# Also uses gentler contrastive learning settings to prevent over-separation.
#
# Usage:
#   uv run python train_clip_logo.py --config configs/image_level_splits.yaml

# Base model
base_model: "openai/clip-vit-large-patch14"

# Dataset paths (relative to project root)
dataset_dir: "LogoDet-3K"
reference_dir: "reference_logos"
db_path: "test_data_mapping.db"

# Data split configuration
# split_level: "image" means images are split, not logo brands
# This allows test set to contain images from brands seen during training
split_level: "image"
train_split: 0.7
val_split: 0.15
test_split: 0.15

# Batch construction
batch_size: 16
logos_per_batch: 32
samples_per_logo: 4
gradient_accumulation_steps: 8
num_workers: 4

# Model architecture - same as before
lora_r: 16
lora_alpha: 32
lora_dropout: 0.1
freeze_layers: 12
use_gradient_checkpointing: true

# Training hyperparameters - GENTLER settings
learning_rate: 5.0e-6  # Reduced from 1e-5
weight_decay: 0.01
warmup_steps: 500
max_epochs: 30  # More epochs with slower learning
mixed_precision: true

# Loss function - HIGHER temperature for softer contrastive learning
temperature: 0.15  # Increased from 0.07
loss_type: "infonce"
triplet_margin: 0.2  # Reduced from 0.3

# Early stopping
patience: 7  # More patience with gentler learning
min_delta: 0.001

# Checkpoints and output
checkpoint_dir: "checkpoints_image_split"
output_dir: "models/logo_detection/clip_finetuned_image_split"
save_every_n_epochs: 5

# Logging
log_every_n_steps: 10
eval_every_n_epochs: 1

# Reproducibility
seed: 42

# Hard negative mining
use_hard_negatives: false
hard_negative_start_epoch: 10
hard_negatives_per_logo: 10

# Data augmentation
use_augmentation: true
augmentation_strength: "medium"
|
||||
@ -256,6 +256,7 @@ def main():
|
||||
test_split=config.test_split,
|
||||
seed=config.seed,
|
||||
augmentation_strength=config.augmentation_strength,
|
||||
split_level=getattr(config, 'split_level', 'logo'),
|
||||
)
|
||||
|
||||
# Create trainer
|
||||
|
||||
@ -20,7 +20,8 @@ class TrainingConfig:
|
||||
reference_dir: str = "reference_logos"
|
||||
db_path: str = "test_data_mapping.db"
|
||||
|
||||
# Data split ratios
|
||||
# Data split configuration
|
||||
split_level: str = "logo" # "logo" for brand-level, "image" for image-level
|
||||
train_split: float = 0.7
|
||||
val_split: float = 0.15
|
||||
test_split: float = 0.15
|
||||
|
||||
@ -84,7 +84,7 @@ class LogoDataset:
|
||||
"""
|
||||
Manages logo data from the SQLite database.
|
||||
|
||||
Handles loading logo-to-image mappings and splitting by logo brand.
|
||||
Handles loading logo-to-image mappings and splitting by logo brand or image.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -95,19 +95,57 @@ class LogoDataset:
|
||||
val_split: float = 0.15,
|
||||
test_split: float = 0.15,
|
||||
seed: int = 42,
|
||||
split_level: str = "logo",
|
||||
):
|
||||
"""
|
||||
Initialize the logo dataset.
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database
|
||||
reference_dir: Directory containing reference logo images
|
||||
train_split: Fraction for training
|
||||
val_split: Fraction for validation
|
||||
test_split: Fraction for testing
|
||||
seed: Random seed for reproducibility
|
||||
split_level: "logo" for brand-level splits (test on unseen brands),
|
||||
"image" for image-level splits (test on unseen images
|
||||
from seen brands)
|
||||
"""
|
||||
self.db_path = Path(db_path)
|
||||
self.reference_dir = Path(reference_dir)
|
||||
self.seed = seed
|
||||
self.split_level = split_level
|
||||
|
||||
# Load logo-to-images mapping from database
|
||||
self.logo_to_images = self._load_logo_mappings()
|
||||
self.all_logos = list(self.logo_to_images.keys())
|
||||
|
||||
# Create logo-level splits
|
||||
self.train_logos, self.val_logos, self.test_logos = self._split_logos(
|
||||
train_split, val_split, test_split
|
||||
)
|
||||
if split_level == "logo":
|
||||
# Logo-level splits: test logos are completely unseen brands
|
||||
self.train_logos, self.val_logos, self.test_logos = self._split_logos(
|
||||
train_split, val_split, test_split
|
||||
)
|
||||
# For logo-level splits, each split has its own logos
|
||||
self.train_logo_to_images = {
|
||||
l: self.logo_to_images[l] for l in self.train_logos
|
||||
}
|
||||
self.val_logo_to_images = {
|
||||
l: self.logo_to_images[l] for l in self.val_logos
|
||||
}
|
||||
self.test_logo_to_images = {
|
||||
l: self.logo_to_images[l] for l in self.test_logos
|
||||
}
|
||||
else:
|
||||
# Image-level splits: all logos present in all splits, different images
|
||||
(
|
||||
self.train_logo_to_images,
|
||||
self.val_logo_to_images,
|
||||
self.test_logo_to_images,
|
||||
) = self._split_images(train_split, val_split, test_split)
|
||||
# All logos are in all splits
|
||||
self.train_logos = list(self.train_logo_to_images.keys())
|
||||
self.val_logos = list(self.val_logo_to_images.keys())
|
||||
self.test_logos = list(self.test_logo_to_images.keys())
|
||||
|
||||
def _load_logo_mappings(self) -> Dict[str, List[Path]]:
|
||||
"""Load logo name to image paths mapping from database."""
|
||||
@ -151,21 +189,74 @@ class LogoDataset:
|
||||
|
||||
return train_logos, val_logos, test_logos
|
||||
|
||||
def get_split_info(self) -> Dict[str, int]:
|
||||
def _split_images(
|
||||
self,
|
||||
train_split: float,
|
||||
val_split: float,
|
||||
test_split: float,
|
||||
) -> Tuple[Dict[str, List[Path]], Dict[str, List[Path]], Dict[str, List[Path]]]:
|
||||
"""
|
||||
Split images within each logo brand for train/val/test.
|
||||
|
||||
Each logo brand will have images in all splits, allowing the model
|
||||
to see some examples of each brand during training.
|
||||
"""
|
||||
random.seed(self.seed)
|
||||
|
||||
train_logo_to_images: Dict[str, List[Path]] = {}
|
||||
val_logo_to_images: Dict[str, List[Path]] = {}
|
||||
test_logo_to_images: Dict[str, List[Path]] = {}
|
||||
|
||||
for logo, images in self.logo_to_images.items():
|
||||
# Shuffle images for this logo
|
||||
shuffled_images = images.copy()
|
||||
random.shuffle(shuffled_images)
|
||||
|
||||
n = len(shuffled_images)
|
||||
if n == 1:
|
||||
# Only one image: put in train only
|
||||
train_logo_to_images[logo] = shuffled_images
|
||||
continue
|
||||
elif n == 2:
|
||||
# Two images: one train, one val
|
||||
train_logo_to_images[logo] = [shuffled_images[0]]
|
||||
val_logo_to_images[logo] = [shuffled_images[1]]
|
||||
continue
|
||||
|
||||
# Normal split for 3+ images
|
||||
train_end = max(1, int(n * train_split))
|
||||
val_end = train_end + max(1, int(n * val_split))
|
||||
|
||||
train_images = shuffled_images[:train_end]
|
||||
val_images = shuffled_images[train_end:val_end]
|
||||
test_images = shuffled_images[val_end:]
|
||||
|
||||
# Ensure at least one image in train
|
||||
if train_images:
|
||||
train_logo_to_images[logo] = train_images
|
||||
if val_images:
|
||||
val_logo_to_images[logo] = val_images
|
||||
if test_images:
|
||||
test_logo_to_images[logo] = test_images
|
||||
|
||||
return train_logo_to_images, val_logo_to_images, test_logo_to_images
|
||||
|
||||
def get_split_info(self) -> Dict[str, any]:
|
||||
"""Return information about the splits."""
|
||||
return {
|
||||
"split_level": self.split_level,
|
||||
"total_logos": len(self.all_logos),
|
||||
"train_logos": len(self.train_logos),
|
||||
"val_logos": len(self.val_logos),
|
||||
"test_logos": len(self.test_logos),
|
||||
"train_images": sum(
|
||||
len(self.logo_to_images[l]) for l in self.train_logos
|
||||
len(imgs) for imgs in self.train_logo_to_images.values()
|
||||
),
|
||||
"val_images": sum(
|
||||
len(self.logo_to_images[l]) for l in self.val_logos
|
||||
len(imgs) for imgs in self.val_logo_to_images.values()
|
||||
),
|
||||
"test_images": sum(
|
||||
len(self.logo_to_images[l]) for l in self.test_logos
|
||||
len(imgs) for imgs in self.test_logo_to_images.values()
|
||||
),
|
||||
}
|
||||
|
||||
@ -205,29 +296,33 @@ class LogoContrastiveDataset(Dataset):
|
||||
self.transform = transform
|
||||
self.batches_per_epoch = batches_per_epoch
|
||||
|
||||
# Get logos for this split
|
||||
# Get logos and their images for this split
|
||||
# This respects both logo-level and image-level splits
|
||||
if split == "train":
|
||||
self.logos = logo_data.train_logos
|
||||
self.logo_to_images = logo_data.train_logo_to_images
|
||||
elif split == "val":
|
||||
self.logos = logo_data.val_logos
|
||||
self.logo_to_images = logo_data.val_logo_to_images
|
||||
else:
|
||||
self.logos = logo_data.test_logos
|
||||
self.logo_to_images = logo_data.test_logo_to_images
|
||||
|
||||
# Filter logos with enough samples
|
||||
# Filter logos with enough samples for this split
|
||||
self.valid_logos = [
|
||||
logo for logo in self.logos
|
||||
if len(logo_data.logo_to_images[logo]) >= samples_per_logo
|
||||
if logo in self.logo_to_images and len(self.logo_to_images[logo]) >= samples_per_logo
|
||||
]
|
||||
|
||||
# For logos with fewer samples, we'll use with replacement
|
||||
self.logos_needing_replacement = [
|
||||
logo for logo in self.logos
|
||||
if len(logo_data.logo_to_images[logo]) < samples_per_logo
|
||||
if logo in self.logo_to_images and len(self.logo_to_images[logo]) < samples_per_logo
|
||||
]
|
||||
|
||||
# Create label mapping
|
||||
# Create label mapping (use all logos from the full dataset for consistent labels)
|
||||
self.logo_to_label = {
|
||||
logo: idx for idx, logo in enumerate(self.logos)
|
||||
logo: idx for idx, logo in enumerate(logo_data.all_logos)
|
||||
}
|
||||
|
||||
def __len__(self) -> int:
|
||||
@ -244,12 +339,13 @@ class LogoContrastiveDataset(Dataset):
|
||||
images = []
|
||||
labels = []
|
||||
|
||||
# Sample K logos for this batch
|
||||
k = min(self.logos_per_batch, len(self.logos))
|
||||
batch_logos = random.sample(self.logos, k)
|
||||
# Sample K logos for this batch (only from logos that have images in this split)
|
||||
available_logos = [l for l in self.logos if l in self.logo_to_images]
|
||||
k = min(self.logos_per_batch, len(available_logos))
|
||||
batch_logos = random.sample(available_logos, k)
|
||||
|
||||
for logo in batch_logos:
|
||||
logo_images = self.logo_data.logo_to_images[logo]
|
||||
logo_images = self.logo_to_images[logo]
|
||||
|
||||
# Sample M images for this logo
|
||||
if len(logo_images) >= self.samples_per_logo:
|
||||
@ -353,6 +449,7 @@ def create_dataloaders(
|
||||
seed: int = 42,
|
||||
augmentation_strength: str = "medium",
|
||||
batches_per_epoch: int = 1000,
|
||||
split_level: str = "logo",
|
||||
) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
|
||||
"""
|
||||
Create train, validation, and optionally test dataloaders.
|
||||
@ -370,6 +467,7 @@ def create_dataloaders(
|
||||
seed: Random seed
|
||||
augmentation_strength: "light", "medium", or "strong"
|
||||
batches_per_epoch: Number of batches per training epoch
|
||||
split_level: "logo" for brand-level splits, "image" for image-level splits
|
||||
|
||||
Returns:
|
||||
Tuple of (train_loader, val_loader, test_loader)
|
||||
@ -382,11 +480,13 @@ def create_dataloaders(
|
||||
val_split=val_split,
|
||||
test_split=test_split,
|
||||
seed=seed,
|
||||
split_level=split_level,
|
||||
)
|
||||
|
||||
# Print split info
|
||||
split_info = logo_data.get_split_info()
|
||||
print(f"Dataset loaded:")
|
||||
print(f" Split level: {split_info['split_level']}")
|
||||
print(f" Total logos: {split_info['total_logos']}")
|
||||
print(f" Train: {split_info['train_logos']} logos, {split_info['train_images']} images")
|
||||
print(f" Val: {split_info['val_logos']} logos, {split_info['val_images']} images")
|
||||
|
||||
Reference in New Issue
Block a user