Add CLIP fine-tuning pipeline for logo recognition
Implement contrastive learning with LoRA to fine-tune CLIP's vision encoder on the LogoDet-3K dataset for improved logo embedding similarity.

New training module (training/):
- config.py: TrainingConfig dataclass with all hyperparameters
- dataset.py: LogoContrastiveDataset with logo-level splits
- model.py: LogoFineTunedCLIP wrapper with LoRA support
- losses.py: InfoNCE, TripletLoss, SupConLoss implementations
- trainer.py: Training loop with mixed precision and checkpointing
- evaluation.py: EmbeddingEvaluator for validation metrics

New scripts:
- train_clip_logo.py: Main training entry point
- export_model.py: Export to a HuggingFace-compatible format

Configurations:
- configs/jetson_orin.yaml: Optimized for Jetson AGX Orin
- configs/cloud_rtx4090.yaml: Optimized for 24GB cloud GPUs
- configs/cloud_a100.yaml: Optimized for 80GB cloud GPUs

Documentation:
- CLIP_FINETUNING.md: Training guide and usage instructions
- CLOUD_TRAINING.md: Cloud GPU recommendations and cost estimates

Modified:
- logo_detection_detr.py: Add fine-tuned model loading support
- pyproject.toml: Add peft, pyyaml, torchvision dependencies
326
training/losses.py
Normal file
@@ -0,0 +1,326 @@
"""
Loss functions for contrastive learning on logo embeddings.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

class InfoNCELoss(nn.Module):
    """
    Normalized temperature-scaled cross-entropy loss (InfoNCE).

    A supervised, multi-positive variant of the contrastive loss used in
    CLIP training. It maximizes similarity between embeddings of the same
    logo class while minimizing similarity to embeddings of different classes.

    For a batch with N samples:
    - Each sample is an anchor
    - Positive pairs: samples with the same label
    - Negative pairs: samples with different labels

    The loss for an anchor a with positive set P(a) is:
        -(1/|P(a)|) * sum_{p in P(a)} log(
            exp(sim(a, p)/temp) / sum_{k != a} exp(sim(a, k)/temp)
        )
    """

    def __init__(self, temperature: float = 0.07):
        """
        Initialize InfoNCE loss.

        Args:
            temperature: Scaling factor for similarities (0.05-0.1 is typical).
                Lower temperature makes the distribution sharper.
        """
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute InfoNCE loss for a batch of embeddings.

        Args:
            embeddings: [N, D] L2-normalized embeddings
            labels: [N] integer logo class labels

        Returns:
            Scalar loss value
        """
        device = embeddings.device
        batch_size = embeddings.shape[0]

        if batch_size <= 1:
            return torch.tensor(0.0, device=device, requires_grad=True)

        # Compute similarity matrix [N, N].
        # Since embeddings are L2-normalized, dot product = cosine similarity.
        similarity = embeddings @ embeddings.T / self.temperature

        # Create positive mask: same label = 1, different = 0
        labels_col = labels.unsqueeze(0)  # [1, N]
        labels_row = labels.unsqueeze(1)  # [N, 1]
        positive_mask = (labels_row == labels_col).float()  # [N, N]

        # Remove self-similarity from positives (diagonal)
        identity = torch.eye(batch_size, device=device)
        positive_mask = positive_mask - identity

        # Count positives per anchor (avoid division by zero)
        num_positives = positive_mask.sum(dim=1)
        has_positives = num_positives > 0

        # If no anchor has a positive, return zero loss
        if not has_positives.any():
            return torch.tensor(0.0, device=device, requires_grad=True)

        # Mask out self-similarity with a large negative value
        similarity = similarity - identity * 1e9

        # Compute log-softmax over similarities
        log_softmax = F.log_softmax(similarity, dim=1)

        # Sum log probabilities of positive pairs
        positive_log_probs = (log_softmax * positive_mask).sum(dim=1)

        # Average over number of positives (only for anchors with positives)
        loss_per_anchor = torch.zeros(batch_size, device=device)
        loss_per_anchor[has_positives] = (
            -positive_log_probs[has_positives] / num_positives[has_positives]
        )

        # Average over anchors that have at least one positive, so anchors
        # without positives do not dilute the loss (consistent with SupConLoss)
        return loss_per_anchor[has_positives].mean()
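
# Usage sketch for InfoNCELoss (shapes and label layout are illustrative):
#
#     loss_fn = InfoNCELoss(temperature=0.07)
#     emb = F.normalize(torch.randn(8, 512, requires_grad=True), dim=1)
#     labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])  # >= 2 crops per logo class
#     loss = loss_fn(emb, labels)
#     loss.backward()
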

class SupConLoss(nn.Module):
    """
    Supervised contrastive loss.

    Computes the same multi-positive objective as InfoNCELoss above, but with
    an explicit max-subtraction (log-sum-exp trick) for numerical stability.

    Reference: https://arxiv.org/abs/2004.11362
    """

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute supervised contrastive loss.

        Args:
            embeddings: [N, D] L2-normalized embeddings
            labels: [N] integer logo class labels

        Returns:
            Scalar loss value
        """
        device = embeddings.device
        batch_size = embeddings.shape[0]

        if batch_size <= 1:
            return torch.tensor(0.0, device=device, requires_grad=True)

        # Compute similarity matrix
        similarity = embeddings @ embeddings.T / self.temperature

        # Create masks
        labels_col = labels.unsqueeze(0)
        labels_row = labels.unsqueeze(1)
        positive_mask = (labels_row == labels_col).float()
        identity = torch.eye(batch_size, device=device)

        # Remove self from positives
        positive_mask = positive_mask - identity

        # Number of positives per anchor
        num_positives = positive_mask.sum(dim=1)
        has_positives = num_positives > 0

        if not has_positives.any():
            return torch.tensor(0.0, device=device, requires_grad=True)

        # For numerical stability, subtract the per-row max similarity
        sim_max, _ = similarity.max(dim=1, keepdim=True)
        similarity = similarity - sim_max.detach()

        # Compute exp(similarity) with self masked out
        exp_sim = torch.exp(similarity) * (1 - identity)

        # Denominator: sum of exp over all pairs except self
        log_prob = similarity - torch.log(exp_sim.sum(dim=1, keepdim=True) + 1e-8)

        # Mean log-probability over positive pairs
        mean_log_prob_pos = (positive_mask * log_prob).sum(dim=1) / (
            num_positives + 1e-8
        )

        # Loss is the negative mean log probability over valid anchors
        loss = -mean_log_prob_pos[has_positives].mean()

        return loss
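
# Note: on the same batch, SupConLoss and the InfoNCELoss above compute
# essentially the same multi-positive objective; SupConLoss just performs the
# log-sum-exp stabilization by hand. A quick check (illustrative):
#
#     emb = F.normalize(torch.randn(8, 128), dim=1)
#     labels = torch.randint(0, 4, (8,))
#     SupConLoss()(emb, labels)  # ~= InfoNCELoss()(emb, labels)
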

class TripletLoss(nn.Module):
    """
    Triplet loss with online hard mining.

    For each anchor:
    - Hardest positive: most distant sample with the same label
    - Hardest negative: closest sample with a different label

    Loss = max(0, d(anchor, hardest_pos) - d(anchor, hardest_neg) + margin)

    This is an alternative to InfoNCE when batch sizes are small.
    """

    def __init__(self, margin: float = 0.3):
        """
        Initialize triplet loss.

        Args:
            margin: Minimum required gap between positive and negative distances
        """
        super().__init__()
        self.margin = margin

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute triplet loss with online hard mining.

        Args:
            embeddings: [N, D] L2-normalized embeddings
            labels: [N] integer logo class labels

        Returns:
            Scalar loss value
        """
        device = embeddings.device
        batch_size = embeddings.shape[0]

        if batch_size <= 1:
            return torch.tensor(0.0, device=device, requires_grad=True)

        # Compute pairwise cosine distances (1 - cosine similarity).
        # For L2-normalized vectors: distance = 1 - dot product
        similarity = embeddings @ embeddings.T
        distances = 1 - similarity

        # Create masks
        labels_col = labels.unsqueeze(0)
        labels_row = labels.unsqueeze(1)
        positive_mask = (labels_row == labels_col).float()
        negative_mask = 1 - positive_mask

        # Remove self from positives (diagonal)
        identity = torch.eye(batch_size, device=device)
        positive_mask = positive_mask - identity

        # Check if we have any valid triplets
        has_positives = positive_mask.sum(dim=1) > 0
        has_negatives = negative_mask.sum(dim=1) > 0
        valid_anchors = has_positives & has_negatives

        if not valid_anchors.any():
            return torch.tensor(0.0, device=device, requires_grad=True)

        # For each anchor, find the hardest positive (max distance among positives).
        # Set non-positives (including self) to -inf so they don't affect the max.
        pos_distances = distances.clone()
        pos_distances[positive_mask == 0] = float("-inf")
        hardest_positive, _ = pos_distances.max(dim=1)

        # For each anchor, find the hardest negative (min distance among negatives).
        # Set non-negatives to inf so they don't affect the min.
        neg_distances = distances.clone()
        neg_distances[negative_mask == 0] = float("inf")
        hardest_negative, _ = neg_distances.min(dim=1)

        # Triplet loss: the positive should be closer than the negative by the margin
        triplet_loss = F.relu(hardest_positive - hardest_negative + self.margin)

        # Average over valid anchors only
        loss = triplet_loss[valid_anchors].mean()

        return loss
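
# Usage sketch: the loss reaches zero once every anchor's hardest positive is
# at least `margin` closer, in cosine distance, than its hardest negative:
#
#     loss_fn = TripletLoss(margin=0.3)
#     emb = F.normalize(torch.randn(16, 128, requires_grad=True), dim=1)
#     labels = torch.randint(0, 4, (16,))
#     loss_fn(emb, labels).backward()
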

class CombinedLoss(nn.Module):
    """
    Combined loss function with weighted InfoNCE and triplet losses.

    Can help stabilize training by combining the benefits of both losses.
    """

    def __init__(
        self,
        temperature: float = 0.07,
        triplet_margin: float = 0.3,
        infonce_weight: float = 1.0,
        triplet_weight: float = 0.5,
    ):
        super().__init__()
        self.infonce = InfoNCELoss(temperature=temperature)
        self.triplet = TripletLoss(margin=triplet_margin)
        self.infonce_weight = infonce_weight
        self.triplet_weight = triplet_weight

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        infonce_loss = self.infonce(embeddings, labels)
        triplet_loss = self.triplet(embeddings, labels)

        return (
            self.infonce_weight * infonce_loss
            + self.triplet_weight * triplet_loss
        )
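
# With the default arguments, CombinedLoss()(emb, labels) equals
#     1.0 * InfoNCELoss(temperature=0.07)(emb, labels)
#   + 0.5 * TripletLoss(margin=0.3)(emb, labels)
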

def get_loss_function(
    loss_type: str = "infonce",
    temperature: float = 0.07,
    triplet_margin: float = 0.3,
) -> nn.Module:
    """
    Factory function to create a loss function.

    Args:
        loss_type: One of "infonce", "supcon", "triplet", or "combined"
        temperature: Temperature for InfoNCE/SupCon
        triplet_margin: Margin for triplet loss

    Returns:
        Loss function module
    """
    if loss_type == "infonce":
        return InfoNCELoss(temperature=temperature)
    elif loss_type == "supcon":
        return SupConLoss(temperature=temperature)
    elif loss_type == "triplet":
        return TripletLoss(margin=triplet_margin)
    elif loss_type == "combined":
        return CombinedLoss(
            temperature=temperature,
            triplet_margin=triplet_margin,
        )
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")
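
# Minimal smoke test for the factory and all loss variants; a sketch that
# relies only on the definitions above and can be run directly.
if __name__ == "__main__":
    torch.manual_seed(0)
    raw = torch.randn(16, 128, requires_grad=True)
    embeddings = F.normalize(raw, dim=1)
    labels = torch.randint(0, 5, (16,))
    for name in ("infonce", "supcon", "triplet", "combined"):
        loss = get_loss_function(name)(embeddings, labels)
        loss.backward(retain_graph=True)  # gradients accumulate on `raw`
        print(f"{name}: {loss.item():.4f}")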