Add CLIP fine-tuning pipeline for logo recognition
Implement contrastive learning with LoRA to fine-tune CLIP's vision encoder on LogoDet-3K dataset for improved logo embedding similarity. New training module (training/): - config.py: TrainingConfig dataclass with all hyperparameters - dataset.py: LogoContrastiveDataset with logo-level splits - model.py: LogoFineTunedCLIP wrapper with LoRA support - losses.py: InfoNCE, TripletLoss, SupConLoss implementations - trainer.py: Training loop with mixed precision and checkpointing - evaluation.py: EmbeddingEvaluator for validation metrics New scripts: - train_clip_logo.py: Main training entry point - export_model.py: Export to HuggingFace-compatible format Configurations: - configs/jetson_orin.yaml: Optimized for Jetson Orin AGX - configs/cloud_rtx4090.yaml: Optimized for 24GB cloud GPUs - configs/cloud_a100.yaml: Optimized for 80GB cloud GPUs Documentation: - CLIP_FINETUNING.md: Training guide and usage instructions - CLOUD_TRAINING.md: Cloud GPU recommendations and cost estimates Modified: - logo_detection_detr.py: Add fine-tuned model loading support - pyproject.toml: Add peft, pyyaml, torchvision dependencies
This commit is contained in:
467
training/dataset.py
Normal file
467
training/dataset.py
Normal file
@ -0,0 +1,467 @@
|
||||
"""
|
||||
Dataset classes for contrastive learning on logo images.
|
||||
"""
|
||||
|
||||
import random
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset, DataLoader, Sampler
|
||||
from torchvision import transforms
|
||||
|
||||
|
||||
# CLIP normalization values: per-channel RGB mean/std used by CLIP's image
# preprocessing — kept identical here so fine-tuned embeddings stay compatible
# with the pretrained vision encoder's expected input distribution.
CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
|
||||
|
||||
|
||||
def get_train_transforms(strength: str = "medium") -> transforms.Compose:
    """
    Build the training-time augmentation pipeline.

    Args:
        strength: Augmentation strength — "light", "medium", or "strong".
            Any unrecognized value falls through to the "strong" pipeline.

    Returns:
        A ``transforms.Compose`` producing CLIP-normalized 224x224 tensors.
    """
    # Every pipeline ends the same way: tensorize + CLIP normalization.
    finalize = [
        transforms.ToTensor(),
        transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
    ]

    if strength == "light":
        augment = [
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
        ]
    elif strength == "medium":
        augment = [
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05
            ),
            transforms.RandomAffine(
                degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)
            ),
            transforms.RandomGrayscale(p=0.1),
        ]
    else:  # "strong" (and the fallback for unknown values)
        augment = [
            transforms.Resize((256, 256)),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.1),
            transforms.RandomRotation(degrees=30),
            transforms.ColorJitter(
                brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1
            ),
            transforms.RandomAffine(
                degrees=0, translate=(0.15, 0.15), scale=(0.8, 1.2), shear=10
            ),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        ]

    return transforms.Compose(augment + finalize)
|
||||
|
||||
|
||||
def get_val_transforms() -> transforms.Compose:
    """Return deterministic eval-time preprocessing (no augmentation):
    resize to 224x224, tensorize, CLIP-normalize."""
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD),
    ]
    return transforms.Compose(steps)
|
||||
|
||||
|
||||
class LogoDataset:
    """
    Manages logo data from the SQLite database.

    Loads the logo-name -> reference-image-paths mapping and splits it at
    the brand (logo) level, so train/val/test never share a brand.
    """

    def __init__(
        self,
        db_path: str,
        reference_dir: str,
        train_split: float = 0.7,
        val_split: float = 0.15,
        test_split: float = 0.15,
        seed: int = 42,
    ):
        """
        Args:
            db_path: Path to the SQLite database with logo metadata.
            reference_dir: Directory containing the reference logo images.
            train_split: Fraction of logos assigned to training.
            val_split: Fraction of logos assigned to validation.
            test_split: Fraction of logos assigned to testing (the test
                split actually receives whatever remains after train/val).
            seed: Seed for the deterministic logo-level shuffle.
        """
        self.db_path = Path(db_path)
        self.reference_dir = Path(reference_dir)
        self.seed = seed

        # Load logo-to-images mapping from database
        self.logo_to_images = self._load_logo_mappings()
        self.all_logos = list(self.logo_to_images.keys())

        # Create logo-level splits
        self.train_logos, self.val_logos, self.test_logos = self._split_logos(
            train_split, val_split, test_split
        )

    def _load_logo_mappings(self) -> Dict[str, List[Path]]:
        """Load the logo name -> image paths mapping from the database.

        The connection is closed even if the query raises.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute("""
                SELECT ln.name, rl.filename
                FROM reference_logos rl
                JOIN logo_names ln ON rl.logo_name_id = ln.id
                ORDER BY ln.name
            """)

            logo_to_images: Dict[str, List[Path]] = {}
            for logo_name, filename in cursor.fetchall():
                logo_to_images.setdefault(logo_name, []).append(
                    self.reference_dir / filename
                )
            return logo_to_images
        finally:
            conn.close()

    def _split_logos(
        self,
        train_split: float,
        val_split: float,
        test_split: float,
    ) -> Tuple[List[str], List[str], List[str]]:
        """Split logos at brand level into train/val/test lists.

        Uses a private ``random.Random(self.seed)`` so the split is
        reproducible without reseeding the module-global RNG (the previous
        ``random.seed`` call clobbered global random state as a side effect).
        """
        logos = self.all_logos.copy()
        random.Random(self.seed).shuffle(logos)

        n = len(logos)
        train_end = int(n * train_split)
        val_end = train_end + int(n * val_split)

        # The test split takes the remainder, so all logos are covered even
        # when the fractions don't divide n evenly.
        return logos[:train_end], logos[train_end:val_end], logos[val_end:]

    def get_split_info(self) -> Dict[str, int]:
        """Return logo and image counts for each split."""
        def image_count(logo_names: List[str]) -> int:
            # Total number of reference images across the given logos.
            return sum(len(self.logo_to_images[name]) for name in logo_names)

        return {
            "total_logos": len(self.all_logos),
            "train_logos": len(self.train_logos),
            "val_logos": len(self.val_logos),
            "test_logos": len(self.test_logos),
            "train_images": image_count(self.train_logos),
            "val_images": image_count(self.val_logos),
            "test_images": image_count(self.test_logos),
        }
|
||||
|
||||
|
||||
class LogoContrastiveDataset(Dataset):
    """
    Dataset for contrastive learning on logos.

    Each ``__getitem__`` call returns a whole batch: up to K different logos
    with M samples each, ensuring positive pairs exist within every batch.
    Intended to be used with a DataLoader whose ``batch_size=1``.
    """

    def __init__(
        self,
        logo_data: LogoDataset,
        split: str = "train",
        logos_per_batch: int = 32,
        samples_per_logo: int = 4,
        transform: Optional[transforms.Compose] = None,
        batches_per_epoch: int = 1000,
    ):
        """
        Initialize the contrastive dataset.

        Args:
            logo_data: LogoDataset instance with logo mappings
            split: One of "train", "val", or "test" (anything other than
                "train"/"val" selects the test split)
            logos_per_batch: Number of different logos per batch
            samples_per_logo: Number of samples for each logo
            transform: Image transforms to apply; falls back to the
                validation transforms when None
            batches_per_epoch: Number of batches per epoch
        """
        self.logo_data = logo_data
        self.logos_per_batch = logos_per_batch
        self.samples_per_logo = samples_per_logo
        self.transform = transform
        self.batches_per_epoch = batches_per_epoch

        # Get logos for this split
        if split == "train":
            self.logos = logo_data.train_logos
        elif split == "val":
            self.logos = logo_data.val_logos
        else:
            self.logos = logo_data.test_logos

        # Logos with enough images to sample without replacement
        self.valid_logos = [
            logo for logo in self.logos
            if len(logo_data.logo_to_images[logo]) >= samples_per_logo
        ]

        # Logos with fewer images are sampled with replacement in __getitem__
        self.logos_needing_replacement = [
            logo for logo in self.logos
            if len(logo_data.logo_to_images[logo]) < samples_per_logo
        ]

        # Stable logo-name -> integer class label mapping for this split
        self.logo_to_label = {
            logo: idx for idx, logo in enumerate(self.logos)
        }

    def __len__(self) -> int:
        # "Length" here is the number of pre-built batches per epoch,
        # not the number of individual images.
        return self.batches_per_epoch

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a batch of images for contrastive learning.

        Args:
            idx: Ignored — every batch is drawn randomly.

        Returns:
            images: Tensor of shape [K*M, 3, 224, 224]
            labels: Tensor of shape [K*M] with logo class indices
        """
        images: List[torch.Tensor] = []
        labels: List[int] = []

        # Resolve the fallback transform once per batch instead of building
        # a fresh Compose for every single image.
        transform = self.transform if self.transform is not None else get_val_transforms()

        # Sample K logos for this batch
        k = min(self.logos_per_batch, len(self.logos))
        batch_logos = random.sample(self.logos, k)

        for logo in batch_logos:
            logo_images = self.logo_data.logo_to_images[logo]

            # Sample M images for this logo (with replacement when the
            # logo has fewer than M images available)
            if len(logo_images) >= self.samples_per_logo:
                sampled_paths = random.sample(logo_images, self.samples_per_logo)
            else:
                sampled_paths = random.choices(
                    logo_images, k=self.samples_per_logo
                )

            # Load and transform images
            for img_path in sampled_paths:
                try:
                    img = Image.open(img_path).convert("RGB")
                    images.append(transform(img))
                    labels.append(self.logo_to_label[logo])
                except Exception:
                    # Skip unreadable/corrupt images; the batch simply ends
                    # up smaller (no replacement image is drawn).
                    continue

        if not images:
            # Fallback: every sampled image failed to load — return a dummy
            # batch so the DataLoader pipeline keeps running.
            return (
                torch.zeros(1, 3, 224, 224),
                torch.zeros(1, dtype=torch.long),
            )

        return torch.stack(images), torch.tensor(labels, dtype=torch.long)
|
||||
|
||||
|
||||
class BalancedBatchSampler(Sampler):
    """
    Batch sampler that yields index lists with a balanced logo distribution.

    Intended for a flattened dataset in which every sample is a single image.
    """

    def __init__(
        self,
        logo_labels: List[int],
        logos_per_batch: int,
        samples_per_logo: int,
        num_batches: int,
    ):
        """
        Args:
            logo_labels: Per-sample logo class label, aligned with dataset indices.
            logos_per_batch: Number of distinct logos drawn for each batch.
            samples_per_logo: Number of images drawn per selected logo.
            num_batches: Number of batches yielded per epoch.
        """
        self.logo_labels = logo_labels
        self.logos_per_batch = logos_per_batch
        self.samples_per_logo = samples_per_logo
        self.num_batches = num_batches

        # Bucket dataset indices by their logo label.
        self.logo_to_indices: Dict[int, List[int]] = {}
        for position, label in enumerate(logo_labels):
            self.logo_to_indices.setdefault(label, []).append(position)

        self.all_logos = list(self.logo_to_indices.keys())

    def __iter__(self):
        # At most len(all_logos) distinct logos can be drawn per batch.
        draw_count = min(self.logos_per_batch, len(self.all_logos))

        for _ in range(self.num_batches):
            chosen_logos = random.sample(self.all_logos, draw_count)

            batch: List[int] = []
            for label in chosen_logos:
                pool = self.logo_to_indices[label]
                if len(pool) >= self.samples_per_logo:
                    batch.extend(random.sample(pool, self.samples_per_logo))
                else:
                    # Not enough images for this logo: draw with replacement.
                    batch.extend(random.choices(pool, k=self.samples_per_logo))

            yield batch

    def __len__(self):
        return self.num_batches
|
||||
|
||||
|
||||
def create_dataloaders(
    db_path: str,
    reference_dir: str,
    batch_size: int = 16,
    logos_per_batch: int = 32,
    samples_per_logo: int = 4,
    num_workers: int = 4,
    train_split: float = 0.7,
    val_split: float = 0.15,
    test_split: float = 0.15,
    seed: int = 42,
    augmentation_strength: str = "medium",
    batches_per_epoch: int = 1000,
) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
    """
    Create train, validation, and optionally test dataloaders.

    Args:
        db_path: Path to SQLite database
        reference_dir: Directory containing reference logo images
        batch_size: Not used directly (see logos_per_batch and samples_per_logo)
        logos_per_batch: Number of different logos per batch
        samples_per_logo: Samples per logo in batch
        num_workers: Number of data loading workers
        train_split: Fraction for training
        val_split: Fraction for validation
        test_split: Fraction for testing; the test loader is None when 0
        seed: Random seed
        augmentation_strength: "light", "medium", or "strong"
        batches_per_epoch: Number of batches per training epoch

    Returns:
        Tuple of (train_loader, val_loader, test_loader)
    """
    # Load logo data and the brand-level splits
    logo_data = LogoDataset(
        db_path=db_path,
        reference_dir=reference_dir,
        train_split=train_split,
        val_split=val_split,
        test_split=test_split,
        seed=seed,
    )

    # Report how the logos/images were split
    split_info = logo_data.get_split_info()
    print(f"Dataset loaded:")
    print(f"  Total logos: {split_info['total_logos']}")
    print(f"  Train: {split_info['train_logos']} logos, {split_info['train_images']} images")
    print(f"  Val: {split_info['val_logos']} logos, {split_info['val_images']} images")
    print(f"  Test: {split_info['test_logos']} logos, {split_info['test_images']} images")

    def build_dataset(split, transform, n_batches):
        # All three datasets share the same logo data and batch geometry.
        return LogoContrastiveDataset(
            logo_data=logo_data,
            split=split,
            logos_per_batch=logos_per_batch,
            samples_per_logo=samples_per_logo,
            transform=transform,
            batches_per_epoch=n_batches,
        )

    def build_loader(dataset, shuffle):
        # batch_size=1 because each dataset __getitem__ already returns a
        # full batch; the collate_fn strips the extra outer dimension.
        return DataLoader(
            dataset,
            batch_size=1,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
            collate_fn=_collate_contrastive_batch,
        )

    train_loader = build_loader(
        build_dataset(
            "train", get_train_transforms(augmentation_strength), batches_per_epoch
        ),
        shuffle=True,
    )

    # Validation/test use fewer batches and deterministic transforms.
    val_loader = build_loader(
        build_dataset("val", get_val_transforms(), batches_per_epoch // 10),
        shuffle=False,
    )

    test_loader = None
    if test_split > 0:
        test_loader = build_loader(
            build_dataset("test", get_val_transforms(), batches_per_epoch // 10),
            shuffle=False,
        )

    return train_loader, val_loader, test_loader
|
||||
|
||||
|
||||
def _collate_contrastive_batch(
|
||||
batch: List[Tuple[torch.Tensor, torch.Tensor]]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Collate function that unpacks pre-batched data.
|
||||
|
||||
Since LogoContrastiveDataset already returns batched data,
|
||||
we just squeeze the outer dimension.
|
||||
"""
|
||||
images, labels = batch[0]
|
||||
return images, labels
|
||||
Reference in New Issue
Block a user