Add image-level split support for CLIP fine-tuning
Image-level splits allow the model to see some images from each logo brand during training, unlike logo-level splits, where test brands are completely unseen. This is less rigorous but more representative of real-world use.

Changes:
- Add configs/image_level_splits.yaml with gentler training settings:
  - split_level: "image" for image-level splits
  - temperature: 0.15 (softer contrastive learning)
  - learning_rate: 5e-6 (slower learning)
  - max_epochs: 30 (more epochs)
- Update training/dataset.py:
  - Add split_level parameter to LogoDataset
  - Implement _split_images() for image-level splitting
  - Update LogoContrastiveDataset to use split-specific image mappings
- Update training/config.py:
  - Add split_level field to TrainingConfig
- Update train_clip_logo.py:
  - Pass split_level to create_dataloaders

Usage:
  uv run python train_clip_logo.py --config configs/image_level_splits.yaml
This commit is contained in:
78
configs/image_level_splits.yaml
Normal file
78
configs/image_level_splits.yaml
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
# Training configuration with IMAGE-LEVEL splits
|
||||||
|
#
|
||||||
|
# Unlike logo-level splits where test logos are completely unseen brands,
|
||||||
|
# image-level splits allow the model to see some images from each brand
|
||||||
|
# during training. This is less rigorous but more representative of
|
||||||
|
# real-world use where you have reference images for logos you want to detect.
|
||||||
|
#
|
||||||
|
# Also uses gentler contrastive learning settings to prevent over-separation.
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# uv run python train_clip_logo.py --config configs/image_level_splits.yaml
|
||||||
|
|
||||||
|
# Base model
|
||||||
|
base_model: "openai/clip-vit-large-patch14"
|
||||||
|
|
||||||
|
# Dataset paths (relative to project root)
|
||||||
|
dataset_dir: "LogoDet-3K"
|
||||||
|
reference_dir: "reference_logos"
|
||||||
|
db_path: "test_data_mapping.db"
|
||||||
|
|
||||||
|
# Data split configuration
|
||||||
|
# split_level: "image" means images are split, not logo brands
|
||||||
|
# This allows test set to contain images from brands seen during training
|
||||||
|
split_level: "image"
|
||||||
|
train_split: 0.7
|
||||||
|
val_split: 0.15
|
||||||
|
test_split: 0.15
|
||||||
|
|
||||||
|
# Batch construction
|
||||||
|
batch_size: 16
|
||||||
|
logos_per_batch: 32
|
||||||
|
samples_per_logo: 4
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
num_workers: 4
|
||||||
|
|
||||||
|
# Model architecture - same as before
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0.1
|
||||||
|
freeze_layers: 12
|
||||||
|
use_gradient_checkpointing: true
|
||||||
|
|
||||||
|
# Training hyperparameters - GENTLER settings
|
||||||
|
learning_rate: 5.0e-6 # Reduced from 1e-5
|
||||||
|
weight_decay: 0.01
|
||||||
|
warmup_steps: 500
|
||||||
|
max_epochs: 30 # More epochs with slower learning
|
||||||
|
mixed_precision: true
|
||||||
|
|
||||||
|
# Loss function - HIGHER temperature for softer contrastive learning
|
||||||
|
temperature: 0.15 # Increased from 0.07
|
||||||
|
loss_type: "infonce"
|
||||||
|
triplet_margin: 0.2 # Reduced from 0.3
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
patience: 7 # More patience with gentler learning
|
||||||
|
min_delta: 0.001
|
||||||
|
|
||||||
|
# Checkpoints and output
|
||||||
|
checkpoint_dir: "checkpoints_image_split"
|
||||||
|
output_dir: "models/logo_detection/clip_finetuned_image_split"
|
||||||
|
save_every_n_epochs: 5
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
log_every_n_steps: 10
|
||||||
|
eval_every_n_epochs: 1
|
||||||
|
|
||||||
|
# Reproducibility
|
||||||
|
seed: 42
|
||||||
|
|
||||||
|
# Hard negative mining
|
||||||
|
use_hard_negatives: false
|
||||||
|
hard_negative_start_epoch: 10
|
||||||
|
hard_negatives_per_logo: 10
|
||||||
|
|
||||||
|
# Data augmentation
|
||||||
|
use_augmentation: true
|
||||||
|
augmentation_strength: "medium"
|
||||||
@ -256,6 +256,7 @@ def main():
|
|||||||
test_split=config.test_split,
|
test_split=config.test_split,
|
||||||
seed=config.seed,
|
seed=config.seed,
|
||||||
augmentation_strength=config.augmentation_strength,
|
augmentation_strength=config.augmentation_strength,
|
||||||
|
split_level=getattr(config, 'split_level', 'logo'),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create trainer
|
# Create trainer
|
||||||
|
|||||||
@ -20,7 +20,8 @@ class TrainingConfig:
|
|||||||
reference_dir: str = "reference_logos"
|
reference_dir: str = "reference_logos"
|
||||||
db_path: str = "test_data_mapping.db"
|
db_path: str = "test_data_mapping.db"
|
||||||
|
|
||||||
# Data split ratios
|
# Data split configuration
|
||||||
|
split_level: str = "logo" # "logo" for brand-level, "image" for image-level
|
||||||
train_split: float = 0.7
|
train_split: float = 0.7
|
||||||
val_split: float = 0.15
|
val_split: float = 0.15
|
||||||
test_split: float = 0.15
|
test_split: float = 0.15
|
||||||
|
|||||||
@ -84,7 +84,7 @@ class LogoDataset:
|
|||||||
"""
|
"""
|
||||||
Manages logo data from the SQLite database.
|
Manages logo data from the SQLite database.
|
||||||
|
|
||||||
Handles loading logo-to-image mappings and splitting by logo brand.
|
Handles loading logo-to-image mappings and splitting by logo brand or image.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -95,19 +95,57 @@ class LogoDataset:
|
|||||||
val_split: float = 0.15,
|
val_split: float = 0.15,
|
||||||
test_split: float = 0.15,
|
test_split: float = 0.15,
|
||||||
seed: int = 42,
|
seed: int = 42,
|
||||||
|
split_level: str = "logo",
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Initialize the logo dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path to SQLite database
|
||||||
|
reference_dir: Directory containing reference logo images
|
||||||
|
train_split: Fraction for training
|
||||||
|
val_split: Fraction for validation
|
||||||
|
test_split: Fraction for testing
|
||||||
|
seed: Random seed for reproducibility
|
||||||
|
split_level: "logo" for brand-level splits (test on unseen brands),
|
||||||
|
"image" for image-level splits (test on unseen images
|
||||||
|
from seen brands)
|
||||||
|
"""
|
||||||
self.db_path = Path(db_path)
|
self.db_path = Path(db_path)
|
||||||
self.reference_dir = Path(reference_dir)
|
self.reference_dir = Path(reference_dir)
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
|
self.split_level = split_level
|
||||||
|
|
||||||
# Load logo-to-images mapping from database
|
# Load logo-to-images mapping from database
|
||||||
self.logo_to_images = self._load_logo_mappings()
|
self.logo_to_images = self._load_logo_mappings()
|
||||||
self.all_logos = list(self.logo_to_images.keys())
|
self.all_logos = list(self.logo_to_images.keys())
|
||||||
|
|
||||||
# Create logo-level splits
|
if split_level == "logo":
|
||||||
self.train_logos, self.val_logos, self.test_logos = self._split_logos(
|
# Logo-level splits: test logos are completely unseen brands
|
||||||
train_split, val_split, test_split
|
self.train_logos, self.val_logos, self.test_logos = self._split_logos(
|
||||||
)
|
train_split, val_split, test_split
|
||||||
|
)
|
||||||
|
# For logo-level splits, each split has its own logos
|
||||||
|
self.train_logo_to_images = {
|
||||||
|
l: self.logo_to_images[l] for l in self.train_logos
|
||||||
|
}
|
||||||
|
self.val_logo_to_images = {
|
||||||
|
l: self.logo_to_images[l] for l in self.val_logos
|
||||||
|
}
|
||||||
|
self.test_logo_to_images = {
|
||||||
|
l: self.logo_to_images[l] for l in self.test_logos
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Image-level splits: each logo's images are divided across the splits (logos with very few images may be absent from val/test)
|
||||||
|
(
|
||||||
|
self.train_logo_to_images,
|
||||||
|
self.val_logo_to_images,
|
||||||
|
self.test_logo_to_images,
|
||||||
|
) = self._split_images(train_split, val_split, test_split)
|
||||||
|
# Each split lists only the logos that received at least one image in it
|
||||||
|
self.train_logos = list(self.train_logo_to_images.keys())
|
||||||
|
self.val_logos = list(self.val_logo_to_images.keys())
|
||||||
|
self.test_logos = list(self.test_logo_to_images.keys())
|
||||||
|
|
||||||
def _load_logo_mappings(self) -> Dict[str, List[Path]]:
|
def _load_logo_mappings(self) -> Dict[str, List[Path]]:
|
||||||
"""Load logo name to image paths mapping from database."""
|
"""Load logo name to image paths mapping from database."""
|
||||||
@ -151,21 +189,74 @@ class LogoDataset:
|
|||||||
|
|
||||||
return train_logos, val_logos, test_logos
|
return train_logos, val_logos, test_logos
|
||||||
|
|
||||||
def get_split_info(self) -> Dict[str, int]:
|
def _split_images(
|
||||||
|
self,
|
||||||
|
train_split: float,
|
||||||
|
val_split: float,
|
||||||
|
test_split: float,
|
||||||
|
) -> Tuple[Dict[str, List[Path]], Dict[str, List[Path]], Dict[str, List[Path]]]:
|
||||||
|
"""
|
||||||
|
Split images within each logo brand for train/val/test.
|
||||||
|
|
||||||
|
Brands with enough images contribute to every split (a brand with one image goes to train only, with two images to train and val only), allowing the model
|
||||||
|
to see some examples of each brand during training.
|
||||||
|
"""
|
||||||
|
random.seed(self.seed)
|
||||||
|
|
||||||
|
train_logo_to_images: Dict[str, List[Path]] = {}
|
||||||
|
val_logo_to_images: Dict[str, List[Path]] = {}
|
||||||
|
test_logo_to_images: Dict[str, List[Path]] = {}
|
||||||
|
|
||||||
|
for logo, images in self.logo_to_images.items():
|
||||||
|
# Shuffle images for this logo
|
||||||
|
shuffled_images = images.copy()
|
||||||
|
random.shuffle(shuffled_images)
|
||||||
|
|
||||||
|
n = len(shuffled_images)
|
||||||
|
if n == 1:
|
||||||
|
# Only one image: put in train only
|
||||||
|
train_logo_to_images[logo] = shuffled_images
|
||||||
|
continue
|
||||||
|
elif n == 2:
|
||||||
|
# Two images: one train, one val
|
||||||
|
train_logo_to_images[logo] = [shuffled_images[0]]
|
||||||
|
val_logo_to_images[logo] = [shuffled_images[1]]
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Normal split for 3+ images
|
||||||
|
train_end = max(1, int(n * train_split))
|
||||||
|
val_end = train_end + max(1, int(n * val_split))
|
||||||
|
|
||||||
|
train_images = shuffled_images[:train_end]
|
||||||
|
val_images = shuffled_images[train_end:val_end]
|
||||||
|
test_images = shuffled_images[val_end:]
|
||||||
|
|
||||||
|
# Record only non-empty splits, so each mapping contains just the logos that have images in it
|
||||||
|
if train_images:
|
||||||
|
train_logo_to_images[logo] = train_images
|
||||||
|
if val_images:
|
||||||
|
val_logo_to_images[logo] = val_images
|
||||||
|
if test_images:
|
||||||
|
test_logo_to_images[logo] = test_images
|
||||||
|
|
||||||
|
return train_logo_to_images, val_logo_to_images, test_logo_to_images
|
||||||
|
|
||||||
|
def get_split_info(self) -> Dict[str, any]:
|
||||||
"""Return information about the splits."""
|
"""Return information about the splits."""
|
||||||
return {
|
return {
|
||||||
|
"split_level": self.split_level,
|
||||||
"total_logos": len(self.all_logos),
|
"total_logos": len(self.all_logos),
|
||||||
"train_logos": len(self.train_logos),
|
"train_logos": len(self.train_logos),
|
||||||
"val_logos": len(self.val_logos),
|
"val_logos": len(self.val_logos),
|
||||||
"test_logos": len(self.test_logos),
|
"test_logos": len(self.test_logos),
|
||||||
"train_images": sum(
|
"train_images": sum(
|
||||||
len(self.logo_to_images[l]) for l in self.train_logos
|
len(imgs) for imgs in self.train_logo_to_images.values()
|
||||||
),
|
),
|
||||||
"val_images": sum(
|
"val_images": sum(
|
||||||
len(self.logo_to_images[l]) for l in self.val_logos
|
len(imgs) for imgs in self.val_logo_to_images.values()
|
||||||
),
|
),
|
||||||
"test_images": sum(
|
"test_images": sum(
|
||||||
len(self.logo_to_images[l]) for l in self.test_logos
|
len(imgs) for imgs in self.test_logo_to_images.values()
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -205,29 +296,33 @@ class LogoContrastiveDataset(Dataset):
|
|||||||
self.transform = transform
|
self.transform = transform
|
||||||
self.batches_per_epoch = batches_per_epoch
|
self.batches_per_epoch = batches_per_epoch
|
||||||
|
|
||||||
# Get logos for this split
|
# Get logos and their images for this split
|
||||||
|
# This respects both logo-level and image-level splits
|
||||||
if split == "train":
|
if split == "train":
|
||||||
self.logos = logo_data.train_logos
|
self.logos = logo_data.train_logos
|
||||||
|
self.logo_to_images = logo_data.train_logo_to_images
|
||||||
elif split == "val":
|
elif split == "val":
|
||||||
self.logos = logo_data.val_logos
|
self.logos = logo_data.val_logos
|
||||||
|
self.logo_to_images = logo_data.val_logo_to_images
|
||||||
else:
|
else:
|
||||||
self.logos = logo_data.test_logos
|
self.logos = logo_data.test_logos
|
||||||
|
self.logo_to_images = logo_data.test_logo_to_images
|
||||||
|
|
||||||
# Filter logos with enough samples
|
# Filter logos with enough samples for this split
|
||||||
self.valid_logos = [
|
self.valid_logos = [
|
||||||
logo for logo in self.logos
|
logo for logo in self.logos
|
||||||
if len(logo_data.logo_to_images[logo]) >= samples_per_logo
|
if logo in self.logo_to_images and len(self.logo_to_images[logo]) >= samples_per_logo
|
||||||
]
|
]
|
||||||
|
|
||||||
# For logos with fewer samples, we'll sample with replacement
|
# For logos with fewer samples, we'll sample with replacement
|
||||||
self.logos_needing_replacement = [
|
self.logos_needing_replacement = [
|
||||||
logo for logo in self.logos
|
logo for logo in self.logos
|
||||||
if len(logo_data.logo_to_images[logo]) < samples_per_logo
|
if logo in self.logo_to_images and len(self.logo_to_images[logo]) < samples_per_logo
|
||||||
]
|
]
|
||||||
|
|
||||||
# Create label mapping
|
# Create label mapping (use all logos from the full dataset for consistent labels)
|
||||||
self.logo_to_label = {
|
self.logo_to_label = {
|
||||||
logo: idx for idx, logo in enumerate(self.logos)
|
logo: idx for idx, logo in enumerate(logo_data.all_logos)
|
||||||
}
|
}
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
@ -244,12 +339,13 @@ class LogoContrastiveDataset(Dataset):
|
|||||||
images = []
|
images = []
|
||||||
labels = []
|
labels = []
|
||||||
|
|
||||||
# Sample K logos for this batch
|
# Sample K logos for this batch (only from logos that have images in this split)
|
||||||
k = min(self.logos_per_batch, len(self.logos))
|
available_logos = [l for l in self.logos if l in self.logo_to_images]
|
||||||
batch_logos = random.sample(self.logos, k)
|
k = min(self.logos_per_batch, len(available_logos))
|
||||||
|
batch_logos = random.sample(available_logos, k)
|
||||||
|
|
||||||
for logo in batch_logos:
|
for logo in batch_logos:
|
||||||
logo_images = self.logo_data.logo_to_images[logo]
|
logo_images = self.logo_to_images[logo]
|
||||||
|
|
||||||
# Sample M images for this logo
|
# Sample M images for this logo
|
||||||
if len(logo_images) >= self.samples_per_logo:
|
if len(logo_images) >= self.samples_per_logo:
|
||||||
@ -353,6 +449,7 @@ def create_dataloaders(
|
|||||||
seed: int = 42,
|
seed: int = 42,
|
||||||
augmentation_strength: str = "medium",
|
augmentation_strength: str = "medium",
|
||||||
batches_per_epoch: int = 1000,
|
batches_per_epoch: int = 1000,
|
||||||
|
split_level: str = "logo",
|
||||||
) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
|
) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
|
||||||
"""
|
"""
|
||||||
Create train, validation, and optionally test dataloaders.
|
Create train, validation, and optionally test dataloaders.
|
||||||
@ -370,6 +467,7 @@ def create_dataloaders(
|
|||||||
seed: Random seed
|
seed: Random seed
|
||||||
augmentation_strength: "light", "medium", or "strong"
|
augmentation_strength: "light", "medium", or "strong"
|
||||||
batches_per_epoch: Number of batches per training epoch
|
batches_per_epoch: Number of batches per training epoch
|
||||||
|
split_level: "logo" for brand-level splits, "image" for image-level splits
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (train_loader, val_loader, test_loader)
|
Tuple of (train_loader, val_loader, test_loader)
|
||||||
@ -382,11 +480,13 @@ def create_dataloaders(
|
|||||||
val_split=val_split,
|
val_split=val_split,
|
||||||
test_split=test_split,
|
test_split=test_split,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
split_level=split_level,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Print split info
|
# Print split info
|
||||||
split_info = logo_data.get_split_info()
|
split_info = logo_data.get_split_info()
|
||||||
print(f"Dataset loaded:")
|
print(f"Dataset loaded:")
|
||||||
|
print(f" Split level: {split_info['split_level']}")
|
||||||
print(f" Total logos: {split_info['total_logos']}")
|
print(f" Total logos: {split_info['total_logos']}")
|
||||||
print(f" Train: {split_info['train_logos']} logos, {split_info['train_images']} images")
|
print(f" Train: {split_info['train_logos']} logos, {split_info['train_images']} images")
|
||||||
print(f" Val: {split_info['val_logos']} logos, {split_info['val_images']} images")
|
print(f" Val: {split_info['val_logos']} logos, {split_info['val_images']} images")
|
||||||
|
|||||||
Reference in New Issue
Block a user