Implement contrastive learning with LoRA to fine-tune CLIP's vision encoder on LogoDet-3K dataset for improved logo embedding similarity. New training module (training/): - config.py: TrainingConfig dataclass with all hyperparameters - dataset.py: LogoContrastiveDataset with logo-level splits - model.py: LogoFineTunedCLIP wrapper with LoRA support - losses.py: InfoNCE, TripletLoss, SupConLoss implementations - trainer.py: Training loop with mixed precision and checkpointing - evaluation.py: EmbeddingEvaluator for validation metrics New scripts: - train_clip_logo.py: Main training entry point - export_model.py: Export to HuggingFace-compatible format Configurations: - configs/jetson_orin.yaml: Optimized for Jetson Orin AGX - configs/cloud_rtx4090.yaml: Optimized for 24GB cloud GPUs - configs/cloud_a100.yaml: Optimized for 80GB cloud GPUs Documentation: - CLIP_FINETUNING.md: Training guide and usage instructions - CLOUD_TRAINING.md: Cloud GPU recommendations and cost estimates Modified: - logo_detection_detr.py: Add fine-tuned model loading support - pyproject.toml: Add peft, pyyaml, torchvision dependencies
468 lines · 15 KiB · Python
"""
|
|
Dataset classes for contrastive learning on logo images.
|
|
"""
|
|
|
|
import random
|
|
import sqlite3
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import torch
|
|
from PIL import Image
|
|
from torch.utils.data import Dataset, DataLoader, Sampler
|
|
from torchvision import transforms
|
|
|
|
|
|
# CLIP normalization values
# Per-channel RGB mean/std used by every transform pipeline below
# (see the Normalize steps); inputs must match the pretrained CLIP
# encoder's preprocessing for embeddings to be meaningful.
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
|
|
|
|
|
|
def get_train_transforms(strength: str = "medium") -> transforms.Compose:
    """
    Build the image augmentation pipeline used during training.

    Args:
        strength: Augmentation strength - "light", "medium", or "strong"

    Returns:
        Composed transforms for training
    """
    # Shared tail of every pipeline: tensor conversion + CLIP normalization.
    normalize = transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD)

    if strength == "light":
        augmentations = [
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
        ]
    elif strength == "medium":
        augmentations = [
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05
            ),
            transforms.RandomAffine(
                degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)
            ),
            transforms.RandomGrayscale(p=0.1),
        ]
    else:  # any other value falls through to "strong"
        augmentations = [
            transforms.Resize((256, 256)),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.1),
            transforms.RandomRotation(degrees=30),
            transforms.ColorJitter(
                brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1
            ),
            transforms.RandomAffine(
                degrees=0, translate=(0.15, 0.15), scale=(0.8, 1.2), shear=10
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        ]

    return transforms.Compose(
        augmentations + [transforms.ToTensor(), normalize]
    )
|
|
|
|
|
|
def get_val_transforms() -> transforms.Compose:
    """Get validation/test transforms (deterministic: resize + normalize only)."""
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
    ]
    return transforms.Compose(steps)
|
|
|
|
|
|
class LogoDataset:
    """
    Manages logo data from the SQLite database.

    Handles loading logo-to-image mappings and splitting by logo brand,
    so that no brand appears in more than one of train/val/test.
    """

    def __init__(
        self,
        db_path: str,
        reference_dir: str,
        train_split: float = 0.7,
        val_split: float = 0.15,
        test_split: float = 0.15,
        seed: int = 42,
    ):
        """
        Args:
            db_path: Path to the SQLite database with logo metadata.
            reference_dir: Directory containing reference logo images.
            train_split: Fraction of logos assigned to training.
            val_split: Fraction of logos assigned to validation.
            test_split: Fraction of logos assigned to testing.
            seed: Random seed for the (reproducible) brand-level shuffle.
        """
        self.db_path = Path(db_path)
        self.reference_dir = Path(reference_dir)
        self.seed = seed

        # Load logo-to-images mapping from database
        self.logo_to_images = self._load_logo_mappings()
        self.all_logos = list(self.logo_to_images.keys())

        # Create logo-level splits
        self.train_logos, self.val_logos, self.test_logos = self._split_logos(
            train_split, val_split, test_split
        )

    def _load_logo_mappings(self) -> Dict[str, List[Path]]:
        """Load logo name to image paths mapping from database."""
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT ln.name, rl.filename
                FROM reference_logos rl
                JOIN logo_names ln ON rl.logo_name_id = ln.id
                ORDER BY ln.name
            """)

            logo_to_images: Dict[str, List[Path]] = {}
            for logo_name, filename in cursor.fetchall():
                logo_to_images.setdefault(logo_name, []).append(
                    self.reference_dir / filename
                )
            return logo_to_images
        finally:
            # Ensure the connection is released even if the query fails.
            conn.close()

    def _split_logos(
        self,
        train_split: float,
        val_split: float,
        test_split: float,
    ) -> Tuple[List[str], List[str], List[str]]:
        """Split logos at brand level for train/val/test."""
        # Use a private RNG so we never clobber the global random state.
        # random.Random(seed).shuffle(...) produces the exact permutation
        # that random.seed(seed); random.shuffle(...) would, so the splits
        # are unchanged relative to earlier runs.
        rng = random.Random(self.seed)
        logos = self.all_logos.copy()
        rng.shuffle(logos)

        n = len(logos)
        train_end = int(n * train_split)
        val_end = train_end + int(n * val_split)

        train_logos = logos[:train_end]
        val_logos = logos[train_end:val_end]
        # Remainder (including rounding leftovers) goes to test.
        test_logos = logos[val_end:]

        return train_logos, val_logos, test_logos

    def get_split_info(self) -> Dict[str, int]:
        """Return logo and image counts for each split."""
        return {
            "total_logos": len(self.all_logos),
            "train_logos": len(self.train_logos),
            "val_logos": len(self.val_logos),
            "test_logos": len(self.test_logos),
            "train_images": sum(
                len(self.logo_to_images[logo]) for logo in self.train_logos
            ),
            "val_images": sum(
                len(self.logo_to_images[logo]) for logo in self.val_logos
            ),
            "test_images": sum(
                len(self.logo_to_images[logo]) for logo in self.test_logos
            ),
        }
|
|
|
|
|
|
class LogoContrastiveDataset(Dataset):
    """
    Dataset for contrastive learning on logos.

    Each __getitem__ call returns a batch of images organized for contrastive
    learning: K different logos with M samples each, ensuring positive pairs
    exist within each batch.
    """

    def __init__(
        self,
        logo_data: LogoDataset,
        split: str = "train",
        logos_per_batch: int = 32,
        samples_per_logo: int = 4,
        transform: Optional["transforms.Compose"] = None,
        batches_per_epoch: int = 1000,
    ):
        """
        Initialize the contrastive dataset.

        Args:
            logo_data: LogoDataset instance with logo mappings
            split: One of "train", "val", or "test"
            logos_per_batch: Number of different logos (K) per batch
            samples_per_logo: Number of samples (M) for each logo
            transform: Image transforms to apply (val transforms if None)
            batches_per_epoch: Number of batches per epoch
        """
        self.logo_data = logo_data
        self.logos_per_batch = logos_per_batch
        self.samples_per_logo = samples_per_logo
        self.transform = transform
        self.batches_per_epoch = batches_per_epoch

        # Get logos for this split
        if split == "train":
            self.logos = logo_data.train_logos
        elif split == "val":
            self.logos = logo_data.val_logos
        else:
            # Any other split name falls through to the test split.
            self.logos = logo_data.test_logos

        # Logos with enough images to sample without replacement.
        # NOTE(review): these two lists are informational only --
        # __getitem__ samples from self.logos and decides per-logo
        # whether replacement is needed.
        self.valid_logos = [
            logo for logo in self.logos
            if len(logo_data.logo_to_images[logo]) >= samples_per_logo
        ]

        # Logos with fewer images; they will be sampled with replacement.
        self.logos_needing_replacement = [
            logo for logo in self.logos
            if len(logo_data.logo_to_images[logo]) < samples_per_logo
        ]

        # Stable label mapping: logo name -> integer class index.
        self.logo_to_label = {
            logo: idx for idx, logo in enumerate(self.logos)
        }

    def __len__(self) -> int:
        # An "epoch" here is a fixed number of randomly composed batches,
        # not a pass over every image.
        return self.batches_per_epoch

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a batch of images for contrastive learning.

        Returns:
            images: Tensor of shape [K*M, 3, 224, 224]
            labels: Tensor of shape [K*M] with logo class indices
        """
        images = []
        labels = []

        # Sample K logos for this batch
        k = min(self.logos_per_batch, len(self.logos))
        batch_logos = random.sample(self.logos, k)

        for logo in batch_logos:
            logo_images = self.logo_data.logo_to_images[logo]

            # Sample M images for this logo
            if len(logo_images) >= self.samples_per_logo:
                sampled_paths = random.sample(logo_images, self.samples_per_logo)
            else:
                # Sample with replacement if not enough images
                sampled_paths = random.choices(
                    logo_images, k=self.samples_per_logo
                )

            # Load and transform images
            for img_path in sampled_paths:
                try:
                    img = Image.open(img_path).convert("RGB")
                    if self.transform:
                        img = self.transform(img)
                    else:
                        img = get_val_transforms()(img)
                    images.append(img)
                    labels.append(self.logo_to_label[logo])
                except Exception:
                    # Skip unreadable/corrupt images; no resampling is done,
                    # so the batch may end up slightly smaller than K*M.
                    continue

        # Stack into tensors
        if len(images) == 0:
            # Fallback: every image failed to load -- return a dummy batch
            # so the training loop does not crash.
            return (
                torch.zeros(1, 3, 224, 224),
                torch.zeros(1, dtype=torch.long),
            )

        images_tensor = torch.stack(images)
        labels_tensor = torch.tensor(labels, dtype=torch.long)

        return images_tensor, labels_tensor
|
|
|
|
|
|
class BalancedBatchSampler(Sampler):
    """
    Sampler that ensures each batch has a balanced distribution of logos.

    Used with a flattened dataset where each sample is a single image.
    """

    def __init__(
        self,
        logo_labels: List[int],
        logos_per_batch: int,
        samples_per_logo: int,
        num_batches: int,
    ):
        self.logo_labels = logo_labels
        self.logos_per_batch = logos_per_batch
        self.samples_per_logo = samples_per_logo
        self.num_batches = num_batches

        # Bucket sample positions by their logo label.
        self.logo_to_indices: Dict[int, List[int]] = {}
        for position, label in enumerate(logo_labels):
            self.logo_to_indices.setdefault(label, []).append(position)

        self.all_logos = list(self.logo_to_indices.keys())

    def __iter__(self):
        # How many distinct logos can actually appear in one batch.
        logos_in_batch = min(self.logos_per_batch, len(self.all_logos))

        for _ in range(self.num_batches):
            chosen_logos = random.sample(self.all_logos, logos_in_batch)

            batch_indices: List[int] = []
            for label in chosen_logos:
                pool = self.logo_to_indices[label]
                if len(pool) < self.samples_per_logo:
                    # Too few distinct samples: draw with replacement.
                    picks = random.choices(pool, k=self.samples_per_logo)
                else:
                    picks = random.sample(pool, self.samples_per_logo)
                batch_indices.extend(picks)

            yield batch_indices

    def __len__(self):
        return self.num_batches
|
|
|
|
|
|
def create_dataloaders(
    db_path: str,
    reference_dir: str,
    batch_size: int = 16,
    logos_per_batch: int = 32,
    samples_per_logo: int = 4,
    num_workers: int = 4,
    train_split: float = 0.7,
    val_split: float = 0.15,
    test_split: float = 0.15,
    seed: int = 42,
    augmentation_strength: str = "medium",
    batches_per_epoch: int = 1000,
) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
    """
    Create train, validation, and optionally test dataloaders.

    Args:
        db_path: Path to SQLite database
        reference_dir: Directory containing reference logo images
        batch_size: Not used directly (see logos_per_batch and samples_per_logo)
        logos_per_batch: Number of different logos per batch
        samples_per_logo: Samples per logo in batch
        num_workers: Number of data loading workers
        train_split: Fraction for training
        val_split: Fraction for validation
        test_split: Fraction for testing
        seed: Random seed
        augmentation_strength: "light", "medium", or "strong"
        batches_per_epoch: Number of batches per training epoch

    Returns:
        Tuple of (train_loader, val_loader, test_loader); test_loader is
        None when test_split is 0.
    """
    # Load logo data and brand-level splits.
    logo_data = LogoDataset(
        db_path=db_path,
        reference_dir=reference_dir,
        train_split=train_split,
        val_split=val_split,
        test_split=test_split,
        seed=seed,
    )

    info = logo_data.get_split_info()
    print(f"Dataset loaded:")
    print(f" Total logos: {info['total_logos']}")
    print(f" Train: {info['train_logos']} logos, {info['train_images']} images")
    print(f" Val: {info['val_logos']} logos, {info['val_images']} images")
    print(f" Test: {info['test_logos']} logos, {info['test_images']} images")

    def _make_dataset(split_name, tf, n_batches):
        # All splits share the same batch geometry; only the split name,
        # transform, and batch count differ.
        return LogoContrastiveDataset(
            logo_data=logo_data,
            split=split_name,
            logos_per_batch=logos_per_batch,
            samples_per_logo=samples_per_logo,
            transform=tf,
            batches_per_epoch=n_batches,
        )

    eval_tf = get_val_transforms()
    train_dataset = _make_dataset(
        "train", get_train_transforms(augmentation_strength), batches_per_epoch
    )
    # Validation/test run far fewer batches than training.
    val_dataset = _make_dataset("val", eval_tf, batches_per_epoch // 10)
    test_dataset = (
        _make_dataset("test", eval_tf, batches_per_epoch // 10)
        if test_split > 0
        else None
    )

    def _make_loader(dataset, shuffle):
        # batch_size=1 because each __getitem__ already returns a full
        # batch; the collate fn strips the extra leading dimension.
        return DataLoader(
            dataset,
            batch_size=1,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=_collate_contrastive_batch,
        )

    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    test_loader = (
        _make_loader(test_dataset, shuffle=False)
        if test_dataset is not None
        else None
    )

    return train_loader, val_loader, test_loader
|
|
|
|
|
|
def _collate_contrastive_batch(
|
|
batch: List[Tuple[torch.Tensor, torch.Tensor]]
|
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
"""
|
|
Collate function that unpacks pre-batched data.
|
|
|
|
Since LogoContrastiveDataset already returns batched data,
|
|
we just squeeze the outer dimension.
|
|
"""
|
|
images, labels = batch[0]
|
|
return images, labels
|