Implement contrastive learning with LoRA to fine-tune CLIP's vision encoder on LogoDet-3K dataset for improved logo embedding similarity. New training module (training/): - config.py: TrainingConfig dataclass with all hyperparameters - dataset.py: LogoContrastiveDataset with logo-level splits - model.py: LogoFineTunedCLIP wrapper with LoRA support - losses.py: InfoNCE, TripletLoss, SupConLoss implementations - trainer.py: Training loop with mixed precision and checkpointing - evaluation.py: EmbeddingEvaluator for validation metrics New scripts: - train_clip_logo.py: Main training entry point - export_model.py: Export to HuggingFace-compatible format Configurations: - configs/jetson_orin.yaml: Optimized for Jetson Orin AGX - configs/cloud_rtx4090.yaml: Optimized for 24GB cloud GPUs - configs/cloud_a100.yaml: Optimized for 80GB cloud GPUs Documentation: - CLIP_FINETUNING.md: Training guide and usage instructions - CLOUD_TRAINING.md: Cloud GPU recommendations and cost estimates Modified: - logo_detection_detr.py: Add fine-tuned model loading support - pyproject.toml: Add peft, pyyaml, torchvision dependencies
327 lines
10 KiB
Python
327 lines
10 KiB
Python
"""
|
|
Loss functions for contrastive learning on logo embeddings.
|
|
"""
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from typing import Optional
|
|
|
|
|
|
class InfoNCELoss(nn.Module):
    """
    Normalized Temperature-scaled Cross Entropy Loss (InfoNCE).

    This is the contrastive loss used in CLIP training. It maximizes
    similarity between embeddings of the same logo class while
    minimizing similarity to embeddings of different classes.

    For a batch with N samples:
    - Each sample is an anchor
    - Positive pairs: samples with the same label
    - Negative pairs: samples with different labels

    The loss for each anchor is:
        -log(sum(exp(sim(anchor, pos)/temp)) / sum(exp(sim(anchor, all)/temp)))
    and the batch loss is the mean over anchors that have at least one
    positive in the batch (anchors without positives are skipped, matching
    SupConLoss's normalization).
    """

    def __init__(self, temperature: float = 0.07):
        """
        Initialize InfoNCE loss.

        Args:
            temperature: Scaling factor for similarities (0.05-0.1 typical).
                Lower temperature makes the distribution sharper.
        """
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute InfoNCE loss for a batch of embeddings.

        Args:
            embeddings: [N, D] L2-normalized embeddings
            labels: [N] integer logo class labels

        Returns:
            Scalar loss value
        """
        device = embeddings.device
        batch_size = embeddings.shape[0]

        # A single sample (or empty batch) admits no contrastive pairs.
        if batch_size <= 1:
            return torch.tensor(0.0, device=device, requires_grad=True)

        # Compute similarity matrix [N, N].
        # Since embeddings are L2-normalized, dot product = cosine similarity.
        similarity = embeddings @ embeddings.T / self.temperature

        # Positive mask: same label = 1, different = 0 [N, N].
        labels_col = labels.unsqueeze(0)  # [1, N]
        labels_row = labels.unsqueeze(1)  # [N, 1]
        positive_mask = (labels_row == labels_col).float()

        # Remove self-similarity from positives (diagonal).
        identity = torch.eye(batch_size, device=device)
        positive_mask = positive_mask - identity

        # Count positives per anchor; anchors with none are excluded below.
        num_positives = positive_mask.sum(dim=1)
        has_positives = num_positives > 0

        # If no positives exist for any anchor, the loss is undefined: return 0.
        if not has_positives.any():
            return torch.tensor(0.0, device=device, requires_grad=True)

        # Mask out self-similarity with a large negative value so the anchor
        # never appears in its own softmax denominator's dominant terms.
        similarity = similarity - identity * 1e9

        # Log-softmax over each anchor's similarities.
        log_softmax = F.log_softmax(similarity, dim=1)

        # Sum log probabilities of positive pairs per anchor.
        positive_log_probs = (log_softmax * positive_mask).sum(dim=1)

        # FIX: average only over anchors that actually have positives.
        # Previously anchors without positives contributed zeros to a mean
        # taken over the full batch, silently deflating the loss and
        # disagreeing with SupConLoss, which averages over valid anchors only.
        loss_per_anchor = (
            -positive_log_probs[has_positives] / num_positives[has_positives]
        )

        return loss_per_anchor.mean()
|
|
|
|
|
|
class SupConLoss(nn.Module):
    """
    Supervised Contrastive Loss.

    Similar to InfoNCE but uses a different formulation that
    considers each positive pair separately rather than averaging.

    Reference: https://arxiv.org/abs/2004.11362
    """

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute Supervised Contrastive loss.

        Args:
            embeddings: [N, D] L2-normalized embeddings
            labels: [N] integer logo class labels

        Returns:
            Scalar loss value
        """
        dev = embeddings.device
        n = embeddings.shape[0]

        # A lone sample (or empty batch) has no pairs to contrast.
        if n <= 1:
            return torch.tensor(0.0, device=dev, requires_grad=True)

        # Temperature-scaled cosine similarities (inputs are L2-normalized,
        # so the matrix product already gives cosine similarity).
        sim = torch.matmul(embeddings, embeddings.T) / self.temperature

        # Pairs sharing a label are positives; the diagonal (self) is not.
        eye = torch.eye(n, device=dev)
        pos_mask = torch.eq(labels.view(-1, 1), labels.view(1, -1)).float() - eye

        pos_counts = pos_mask.sum(dim=1)
        anchored = pos_counts > 0
        if not anchored.any():
            return torch.tensor(0.0, device=dev, requires_grad=True)

        # Subtract each row's max before exponentiating, for numerical stability.
        row_max = sim.max(dim=1, keepdim=True).values
        sim = sim - row_max.detach()

        # exp(similarity) with the self-pair zeroed out of the denominator.
        exp_sim = torch.exp(sim) * (1.0 - eye)

        # log p(j | i) over all non-self pairs.
        log_prob = sim - torch.log(exp_sim.sum(dim=1, keepdim=True) + 1e-8)

        # Per-anchor mean log-probability across its positive pairs.
        mean_pos = (pos_mask * log_prob).sum(dim=1) / (pos_counts + 1e-8)

        # Negative log-likelihood, averaged over anchors that have positives.
        return -mean_pos[anchored].mean()
|
|
|
|
|
|
class TripletLoss(nn.Module):
    """
    Triplet loss with online hard mining.

    For each anchor:
    - Hardest positive: most distant sample with same label
    - Hardest negative: closest sample with different label

    Loss = max(0, d(anchor, hardest_pos) - d(anchor, hardest_neg) + margin)

    This is an alternative to InfoNCE for when batch sizes are small.
    """

    def __init__(self, margin: float = 0.3):
        """
        Initialize Triplet loss.

        Args:
            margin: Minimum required gap between positive and negative distances
        """
        super().__init__()
        self.margin = margin

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute triplet loss with online hard mining.

        Args:
            embeddings: [N, D] L2-normalized embeddings
            labels: [N] integer logo class labels

        Returns:
            Scalar loss value
        """
        dev = embeddings.device
        n = embeddings.shape[0]

        # No triplets are possible with fewer than two samples.
        if n <= 1:
            return torch.tensor(0.0, device=dev, requires_grad=True)

        # For L2-normalized vectors, cosine distance = 1 - dot product.
        dist = 1 - embeddings @ embeddings.T

        # Same-label pairs are positives (minus the diagonal); the rest are negatives.
        same = torch.eq(labels.view(-1, 1), labels.view(1, -1)).float()
        neg_mask = 1 - same
        pos_mask = same - torch.eye(n, device=dev)

        # An anchor is usable only with at least one positive AND one negative.
        usable = (pos_mask.sum(dim=1) > 0) & (neg_mask.sum(dim=1) > 0)
        if not usable.any():
            return torch.tensor(0.0, device=dev, requires_grad=True)

        # Hardest positive: farthest same-label sample.
        # Non-positives are pushed to -inf so they never win the max.
        masked_pos = dist.clone()
        masked_pos[pos_mask == 0] = float("-inf")
        hard_pos = masked_pos.max(dim=1).values

        # Hardest negative: closest different-label sample.
        # Non-negatives are pushed to +inf so they never win the min.
        masked_neg = dist.clone()
        masked_neg[neg_mask == 0] = float("inf")
        hard_neg = masked_neg.min(dim=1).values

        # Hinge: the positive must be closer than the negative by `margin`.
        per_anchor = F.relu(hard_pos - hard_neg + self.margin)

        # Average over usable anchors only (excluded anchors may hold ±inf).
        return per_anchor[usable].mean()
|
|
|
|
|
|
class CombinedLoss(nn.Module):
    """
    Combined loss function with weighted InfoNCE and Triplet losses.

    Can help stabilize training by combining the benefits of both losses.
    """

    def __init__(
        self,
        temperature: float = 0.07,
        triplet_margin: float = 0.3,
        infonce_weight: float = 1.0,
        triplet_weight: float = 0.5,
    ):
        super().__init__()
        # Both sub-losses see the same batch; only their weights differ.
        self.infonce = InfoNCELoss(temperature=temperature)
        self.triplet = TripletLoss(margin=triplet_margin)
        self.infonce_weight = infonce_weight
        self.triplet_weight = triplet_weight

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """Return the weighted sum of InfoNCE and triplet losses for the batch."""
        nce_term = self.infonce(embeddings, labels)
        tri_term = self.triplet(embeddings, labels)
        return self.infonce_weight * nce_term + self.triplet_weight * tri_term
|
|
|
|
|
|
def get_loss_function(
    loss_type: str = "infonce",
    temperature: float = 0.07,
    triplet_margin: float = 0.3,
) -> nn.Module:
    """
    Factory function to create loss function.

    Args:
        loss_type: One of "infonce", "supcon", "triplet", or "combined"
        temperature: Temperature for InfoNCE/SupCon
        triplet_margin: Margin for triplet loss

    Returns:
        Loss function module

    Raises:
        ValueError: If loss_type is not one of the supported names.
    """
    # Lazy factories so only the requested loss is constructed.
    factories = {
        "infonce": lambda: InfoNCELoss(temperature=temperature),
        "supcon": lambda: SupConLoss(temperature=temperature),
        "triplet": lambda: TripletLoss(margin=triplet_margin),
        "combined": lambda: CombinedLoss(
            temperature=temperature,
            triplet_margin=triplet_margin,
        ),
    }
    factory = factories.get(loss_type)
    if factory is None:
        raise ValueError(f"Unknown loss type: {loss_type}")
    return factory()