logo_test/training/dataset.py
Commit 44e8b6ae7d by Rick McEwen: Add CLIP fine-tuning pipeline for logo recognition
Implement contrastive learning with LoRA to fine-tune CLIP's vision
encoder on the LogoDet-3K dataset for improved logo embedding similarity.

New training module (training/):
- config.py: TrainingConfig dataclass with all hyperparameters
- dataset.py: LogoContrastiveDataset with logo-level splits
- model.py: LogoFineTunedCLIP wrapper with LoRA support
- losses.py: InfoNCE, TripletLoss, SupConLoss implementations
- trainer.py: Training loop with mixed precision and checkpointing
- evaluation.py: EmbeddingEvaluator for validation metrics

New scripts:
- train_clip_logo.py: Main training entry point
- export_model.py: Export to HuggingFace-compatible format

Configurations:
- configs/jetson_orin.yaml: Optimized for Jetson Orin AGX
- configs/cloud_rtx4090.yaml: Optimized for 24GB cloud GPUs
- configs/cloud_a100.yaml: Optimized for 80GB cloud GPUs

Documentation:
- CLIP_FINETUNING.md: Training guide and usage instructions
- CLOUD_TRAINING.md: Cloud GPU recommendations and cost estimates

Modified:
- logo_detection_detr.py: Add fine-tuned model loading support
- pyproject.toml: Add peft, pyyaml, torchvision dependencies
Committed 2026-01-04 13:45:25 -05:00

"""
Dataset classes for contrastive learning on logo images.
"""
import random
import sqlite3
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Sampler
from torchvision import transforms
# CLIP normalization values
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
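# These are the published preprocessing statistics for the OpenAI CLIP
# checkpoints. A quick sanity-check sketch, assuming the `transformers`
# package is available (it is not a dependency of this module):
#
#     from transformers import CLIPImageProcessor
#     proc = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
#     assert proc.image_mean == CLIP_MEAN and proc.image_std == CLIP_STD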
def get_train_transforms(strength: str = "medium") -> transforms.Compose:
"""
Get training data augmentation transforms.
Args:
strength: Augmentation strength - "light", "medium", or "strong"
Returns:
Composed transforms for training
"""
if strength == "light":
return transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.ColorJitter(brightness=0.1, contrast=0.1),
transforms.ToTensor(),
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
])
elif strength == "medium":
return transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(degrees=15),
transforms.ColorJitter(
brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05
),
transforms.RandomAffine(
degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)
),
transforms.RandomGrayscale(p=0.1),
transforms.ToTensor(),
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
])
    else:  # "strong"; also the fallback for unrecognized strength values
return transforms.Compose([
transforms.Resize((256, 256)),
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.1),
transforms.RandomRotation(degrees=30),
transforms.ColorJitter(
brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1
),
transforms.RandomAffine(
degrees=0, translate=(0.15, 0.15), scale=(0.8, 1.2), shear=10
),
transforms.RandomGrayscale(p=0.2),
transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
transforms.ToTensor(),
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
])
def get_val_transforms() -> transforms.Compose:
"""Get validation/test transforms (no augmentation)."""
return transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
])
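# Usage sketch for the transform builders (the image path is hypothetical):
#
#     train_tf = get_train_transforms("medium")
#     img = Image.open("reference_images/acme/logo_0001.png").convert("RGB")
#     x = train_tf(img)  # torch.Tensor of shape [3, 224, 224], CLIP-normalized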
class LogoDataset:
"""
Manages logo data from the SQLite database.
Handles loading logo-to-image mappings and splitting by logo brand.
"""
def __init__(
self,
db_path: str,
reference_dir: str,
train_split: float = 0.7,
val_split: float = 0.15,
test_split: float = 0.15,
seed: int = 42,
):
self.db_path = Path(db_path)
self.reference_dir = Path(reference_dir)
self.seed = seed
# Load logo-to-images mapping from database
self.logo_to_images = self._load_logo_mappings()
self.all_logos = list(self.logo_to_images.keys())
# Create logo-level splits
self.train_logos, self.val_logos, self.test_logos = self._split_logos(
train_split, val_split, test_split
)
    def _load_logo_mappings(self) -> Dict[str, List[Path]]:
        """Load the logo-name-to-image-paths mapping from the database."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT ln.name, rl.filename
                FROM reference_logos rl
                JOIN logo_names ln ON rl.logo_name_id = ln.id
                ORDER BY ln.name
            """)
            logo_to_images: Dict[str, List[Path]] = {}
            for logo_name, filename in cursor.fetchall():
                logo_to_images.setdefault(logo_name, []).append(
                    self.reference_dir / filename
                )
        finally:
            # Close the connection even if the query fails
            conn.close()
        return logo_to_images
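    # The query above implies a schema along these lines (a sketch inferred
    # from the JOIN, not the authoritative DDL):
    #
    #     CREATE TABLE logo_names (id INTEGER PRIMARY KEY, name TEXT);
    #     CREATE TABLE reference_logos (
    #         id INTEGER PRIMARY KEY,
    #         logo_name_id INTEGER REFERENCES logo_names(id),
    #         filename TEXT
    #     );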
def _split_logos(
self,
train_split: float,
val_split: float,
test_split: float,
) -> Tuple[List[str], List[str], List[str]]:
"""Split logos at brand level for train/val/test."""
random.seed(self.seed)
logos = self.all_logos.copy()
random.shuffle(logos)
n = len(logos)
train_end = int(n * train_split)
val_end = train_end + int(n * val_split)
train_logos = logos[:train_end]
val_logos = logos[train_end:val_end]
test_logos = logos[val_end:]
return train_logos, val_logos, test_logos
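    # Worked example: with 3,000 logos and the default 0.7/0.15/0.15 splits,
    # train_end = int(3000 * 0.7) = 2100 and val_end = 2100 + 450 = 2550,
    # giving 2100 train, 450 val, and 450 test brands; any remainder from
    # int() truncation falls into the test split.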
def get_split_info(self) -> Dict[str, int]:
"""Return information about the splits."""
return {
"total_logos": len(self.all_logos),
"train_logos": len(self.train_logos),
"val_logos": len(self.val_logos),
"test_logos": len(self.test_logos),
"train_images": sum(
len(self.logo_to_images[l]) for l in self.train_logos
),
"val_images": sum(
len(self.logo_to_images[l]) for l in self.val_logos
),
"test_images": sum(
len(self.logo_to_images[l]) for l in self.test_logos
),
}
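# Usage sketch for LogoDataset (both paths are hypothetical):
#
#     logo_data = LogoDataset(
#         db_path="logos.db",
#         reference_dir="reference_images",
#     )
#     print(logo_data.get_split_info())  # per-split logo and image counts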
class LogoContrastiveDataset(Dataset):
"""
Dataset for contrastive learning on logos.
Each __getitem__ call returns a batch of images organized for contrastive
learning: K different logos with M samples each, ensuring positive pairs
exist within each batch.
"""
def __init__(
self,
logo_data: LogoDataset,
split: str = "train",
logos_per_batch: int = 32,
samples_per_logo: int = 4,
transform: Optional[transforms.Compose] = None,
batches_per_epoch: int = 1000,
):
"""
Initialize the contrastive dataset.
Args:
logo_data: LogoDataset instance with logo mappings
split: One of "train", "val", or "test"
logos_per_batch: Number of different logos per batch
samples_per_logo: Number of samples for each logo
transform: Image transforms to apply
batches_per_epoch: Number of batches per epoch
"""
self.logo_data = logo_data
self.logos_per_batch = logos_per_batch
self.samples_per_logo = samples_per_logo
        # Fall back to deterministic val transforms when none are provided
        self.transform = transform if transform is not None else get_val_transforms()
self.batches_per_epoch = batches_per_epoch
# Get logos for this split
        if split == "train":
            self.logos = logo_data.train_logos
        elif split == "val":
            self.logos = logo_data.val_logos
        elif split == "test":
            self.logos = logo_data.test_logos
        else:
            raise ValueError(f"Unknown split: {split!r}")
        # Bookkeeping: logos with enough images for distinct sampling, and
        # logos that __getitem__ will have to sample with replacement
        self.valid_logos = [
            logo for logo in self.logos
            if len(logo_data.logo_to_images[logo]) >= samples_per_logo
        ]
        self.logos_needing_replacement = [
            logo for logo in self.logos
            if len(logo_data.logo_to_images[logo]) < samples_per_logo
        ]
# Create label mapping
self.logo_to_label = {
logo: idx for idx, logo in enumerate(self.logos)
}
def __len__(self) -> int:
return self.batches_per_epoch
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get a batch of images for contrastive learning.
Returns:
images: Tensor of shape [K*M, 3, 224, 224]
labels: Tensor of shape [K*M] with logo class indices
"""
images = []
labels = []
# Sample K logos for this batch
k = min(self.logos_per_batch, len(self.logos))
batch_logos = random.sample(self.logos, k)
for logo in batch_logos:
logo_images = self.logo_data.logo_to_images[logo]
# Sample M images for this logo
if len(logo_images) >= self.samples_per_logo:
sampled_paths = random.sample(logo_images, self.samples_per_logo)
else:
# Sample with replacement if not enough images
sampled_paths = random.choices(
logo_images, k=self.samples_per_logo
)
# Load and transform images
for img_path in sampled_paths:
try:
img = Image.open(img_path).convert("RGB")
                    img = self.transform(img)
images.append(img)
labels.append(self.logo_to_label[logo])
                except Exception:
                    # Skip unreadable or corrupt images; the batch simply
                    # ends up slightly smaller than K * M
                    continue
# Stack into tensors
        if len(images) == 0:
            # Fallback if every sampled image failed to load: return a dummy
            # batch so the training loop does not crash
            return (
                torch.zeros(1, 3, 224, 224),
                torch.zeros(1, dtype=torch.long),
            )
images_tensor = torch.stack(images)
labels_tensor = torch.tensor(labels, dtype=torch.long)
return images_tensor, labels_tensor
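# Shape sketch: with logos_per_batch=32 and samples_per_logo=4, each item is
# one pre-built contrastive batch (`ds` is a hypothetical instance backed by
# a populated LogoDataset):
#
#     images, labels = ds[0]
#     # images: [128, 3, 224, 224]; labels: [128] with 4 repeats of each of
#     # the 32 sampled logo ids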
class BalancedBatchSampler(Sampler):
"""
Sampler that ensures each batch has a balanced distribution of logos.
Used with a flattened dataset where each sample is a single image.
"""
def __init__(
self,
logo_labels: List[int],
logos_per_batch: int,
samples_per_logo: int,
num_batches: int,
):
self.logo_labels = logo_labels
self.logos_per_batch = logos_per_batch
self.samples_per_logo = samples_per_logo
self.num_batches = num_batches
        # Group sample indices by logo label
        self.logo_to_indices: Dict[int, List[int]] = {}
        for idx, label in enumerate(logo_labels):
            self.logo_to_indices.setdefault(label, []).append(idx)
self.all_logos = list(self.logo_to_indices.keys())
def __iter__(self):
for _ in range(self.num_batches):
batch_indices = []
# Sample logos for this batch
logos = random.sample(
self.all_logos,
min(self.logos_per_batch, len(self.all_logos)),
)
for logo in logos:
indices = self.logo_to_indices[logo]
if len(indices) >= self.samples_per_logo:
sampled = random.sample(indices, self.samples_per_logo)
else:
sampled = random.choices(indices, k=self.samples_per_logo)
batch_indices.extend(sampled)
yield batch_indices
def __len__(self):
return self.num_batches
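# Usage sketch: unlike LogoContrastiveDataset above, this sampler targets a
# *flattened* dataset where each item is a single (image, label) pair
# (`flat_dataset` and `flat_labels` are hypothetical):
#
#     sampler = BalancedBatchSampler(
#         logo_labels=flat_labels,
#         logos_per_batch=32,
#         samples_per_logo=4,
#         num_batches=1000,
#     )
#     loader = DataLoader(flat_dataset, batch_sampler=sampler)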
def create_dataloaders(
db_path: str,
reference_dir: str,
batch_size: int = 16,
logos_per_batch: int = 32,
samples_per_logo: int = 4,
num_workers: int = 4,
train_split: float = 0.7,
val_split: float = 0.15,
test_split: float = 0.15,
seed: int = 42,
augmentation_strength: str = "medium",
batches_per_epoch: int = 1000,
) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
"""
Create train, validation, and optionally test dataloaders.
Args:
db_path: Path to SQLite database
reference_dir: Directory containing reference logo images
        batch_size: Unused; the effective batch size is
            logos_per_batch * samples_per_logo (kept for config compatibility)
logos_per_batch: Number of different logos per batch
samples_per_logo: Samples per logo in batch
num_workers: Number of data loading workers
train_split: Fraction for training
val_split: Fraction for validation
test_split: Fraction for testing
seed: Random seed
augmentation_strength: "light", "medium", or "strong"
batches_per_epoch: Number of batches per training epoch
    Returns:
        Tuple of (train_loader, val_loader, test_loader); test_loader is
        None when test_split is 0.
    """
# Load logo data
logo_data = LogoDataset(
db_path=db_path,
reference_dir=reference_dir,
train_split=train_split,
val_split=val_split,
test_split=test_split,
seed=seed,
)
# Print split info
split_info = logo_data.get_split_info()
print(f"Dataset loaded:")
print(f" Total logos: {split_info['total_logos']}")
print(f" Train: {split_info['train_logos']} logos, {split_info['train_images']} images")
print(f" Val: {split_info['val_logos']} logos, {split_info['val_images']} images")
print(f" Test: {split_info['test_logos']} logos, {split_info['test_images']} images")
# Create datasets
train_dataset = LogoContrastiveDataset(
logo_data=logo_data,
split="train",
logos_per_batch=logos_per_batch,
samples_per_logo=samples_per_logo,
transform=get_train_transforms(augmentation_strength),
batches_per_epoch=batches_per_epoch,
)
val_dataset = LogoContrastiveDataset(
logo_data=logo_data,
split="val",
logos_per_batch=logos_per_batch,
samples_per_logo=samples_per_logo,
transform=get_val_transforms(),
batches_per_epoch=batches_per_epoch // 10, # Fewer val batches
)
    test_dataset = None
    if test_split > 0:
        test_dataset = LogoContrastiveDataset(
            logo_data=logo_data,
            split="test",
            logos_per_batch=logos_per_batch,
            samples_per_logo=samples_per_logo,
            transform=get_val_transforms(),
            batches_per_epoch=batches_per_epoch // 10,
        )
# Create dataloaders
# Note: batch_size=1 because each __getitem__ already returns a batch
train_loader = DataLoader(
train_dataset,
batch_size=1,
shuffle=True,
num_workers=num_workers,
pin_memory=True,
collate_fn=_collate_contrastive_batch,
)
val_loader = DataLoader(
val_dataset,
batch_size=1,
shuffle=False,
num_workers=num_workers,
pin_memory=True,
collate_fn=_collate_contrastive_batch,
)
test_loader = None
if test_dataset is not None:
test_loader = DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
num_workers=num_workers,
pin_memory=True,
collate_fn=_collate_contrastive_batch,
)
return train_loader, val_loader, test_loader
def _collate_contrastive_batch(
batch: List[Tuple[torch.Tensor, torch.Tensor]]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Collate function that unpacks pre-batched data.
    Since each LogoContrastiveDataset item is already a full batch and the
    DataLoader is created with batch_size=1, we simply unwrap the
    single-element list.
"""
images, labels = batch[0]
return images, labels
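# End-to-end usage sketch (the db and image paths are hypothetical):
#
#     train_loader, val_loader, test_loader = create_dataloaders(
#         db_path="logos.db",
#         reference_dir="reference_images",
#         logos_per_batch=32,
#         samples_per_logo=4,
#     )
#     images, labels = next(iter(train_loader))
#     # images: [K*M, 3, 224, 224]; labels: [K*M]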