""" Loss functions for contrastive learning on logo embeddings. """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional class InfoNCELoss(nn.Module): """ Normalized Temperature-scaled Cross Entropy Loss (InfoNCE). This is the contrastive loss used in CLIP training. It maximizes similarity between embeddings of the same logo class while minimizing similarity to embeddings of different classes. For a batch with N samples: - Each sample is an anchor - Positive pairs: samples with the same label - Negative pairs: samples with different labels The loss for each anchor is: -log(sum(exp(sim(anchor, pos)/temp)) / sum(exp(sim(anchor, all)/temp))) """ def __init__(self, temperature: float = 0.07): """ Initialize InfoNCE loss. Args: temperature: Scaling factor for similarities (0.05-0.1 typical). Lower temperature makes the distribution sharper. """ super().__init__() self.temperature = temperature def forward( self, embeddings: torch.Tensor, labels: torch.Tensor, ) -> torch.Tensor: """ Compute InfoNCE loss for a batch of embeddings. Args: embeddings: [N, D] L2-normalized embeddings labels: [N] integer logo class labels Returns: Scalar loss value """ device = embeddings.device batch_size = embeddings.shape[0] if batch_size <= 1: return torch.tensor(0.0, device=device, requires_grad=True) # Compute similarity matrix [N, N] # Since embeddings are L2-normalized, dot product = cosine similarity similarity = embeddings @ embeddings.T / self.temperature # Create positive mask: same label = 1, different = 0 labels_col = labels.unsqueeze(0) # [1, N] labels_row = labels.unsqueeze(1) # [N, 1] positive_mask = (labels_row == labels_col).float() # [N, N] # Remove self-similarity from positives (diagonal) identity = torch.eye(batch_size, device=device) positive_mask = positive_mask - identity # Count positives per anchor (avoid division by zero) num_positives = positive_mask.sum(dim=1) has_positives = num_positives > 0 # If no positives exist for any anchor, return zero loss if not has_positives.any(): return torch.tensor(0.0, device=device, requires_grad=True) # Mask out self-similarity with large negative value similarity = similarity - identity * 1e9 # Compute log-softmax over similarities log_softmax = F.log_softmax(similarity, dim=1) # Sum log probabilities of positive pairs positive_log_probs = (log_softmax * positive_mask).sum(dim=1) # Average over number of positives (only for anchors with positives) loss_per_anchor = torch.zeros(batch_size, device=device) loss_per_anchor[has_positives] = ( -positive_log_probs[has_positives] / num_positives[has_positives] ) return loss_per_anchor.mean() class SupConLoss(nn.Module): """ Supervised Contrastive Loss. Similar to InfoNCE but uses a different formulation that considers each positive pair separately rather than averaging. Reference: https://arxiv.org/abs/2004.11362 """ def __init__(self, temperature: float = 0.07): super().__init__() self.temperature = temperature def forward( self, embeddings: torch.Tensor, labels: torch.Tensor, ) -> torch.Tensor: """ Compute Supervised Contrastive loss. 
class SupConLoss(nn.Module):
    """
    Supervised Contrastive Loss.

    Takes the same positives/negatives view of the batch as InfoNCE, but
    follows the reference implementation: each positive pair contributes its
    own log-probability term, the terms are averaged per anchor, and the
    row-wise max similarity is subtracted for numerical stability.

    Reference: https://arxiv.org/abs/2004.11362
    """

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute Supervised Contrastive loss.

        Args:
            embeddings: [N, D] L2-normalized embeddings
            labels: [N] integer logo class labels

        Returns:
            Scalar loss value
        """
        device = embeddings.device
        batch_size = embeddings.shape[0]

        if batch_size <= 1:
            return torch.tensor(0.0, device=device, requires_grad=True)

        # Compute similarity matrix
        similarity = embeddings @ embeddings.T / self.temperature

        # Create masks
        labels_col = labels.unsqueeze(0)
        labels_row = labels.unsqueeze(1)
        positive_mask = (labels_row == labels_col).float()
        identity = torch.eye(batch_size, device=device)

        # Remove self from positives
        positive_mask = positive_mask - identity

        # Number of positives per anchor
        num_positives = positive_mask.sum(dim=1)
        has_positives = num_positives > 0

        if not has_positives.any():
            return torch.tensor(0.0, device=device, requires_grad=True)

        # For numerical stability, subtract each row's max similarity
        sim_max, _ = similarity.max(dim=1, keepdim=True)
        similarity = similarity - sim_max.detach()

        # exp(similarity) with self-pairs masked out
        exp_sim = torch.exp(similarity) * (1 - identity)

        # Denominator: log-sum-exp over all pairs except self
        log_prob = similarity - torch.log(exp_sim.sum(dim=1, keepdim=True) + 1e-8)

        # Mean log-probability over each anchor's positive pairs
        mean_log_prob_pos = (positive_mask * log_prob).sum(dim=1) / (
            num_positives + 1e-8
        )

        # Loss is the negative mean over anchors that have positives
        loss = -mean_log_prob_pos[has_positives].mean()

        return loss


class TripletLoss(nn.Module):
    """
    Triplet loss with online hard mining.

    For each anchor:
    - Hardest positive: most distant sample with the same label
    - Hardest negative: closest sample with a different label

    Loss = max(0, d(anchor, hardest_pos) - d(anchor, hardest_neg) + margin)

    This is an alternative to InfoNCE when batch sizes are small.
    """

    def __init__(self, margin: float = 0.3):
        """
        Initialize Triplet loss.

        Args:
            margin: Minimum required gap between positive and negative distances
        """
        super().__init__()
        self.margin = margin

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute triplet loss with online hard mining.

        Args:
            embeddings: [N, D] L2-normalized embeddings
            labels: [N] integer logo class labels

        Returns:
            Scalar loss value
        """
        device = embeddings.device
        batch_size = embeddings.shape[0]

        if batch_size <= 1:
            return torch.tensor(0.0, device=device, requires_grad=True)

        # Pairwise cosine distances: for L2-normalized vectors,
        # distance = 1 - dot_product
        similarity = embeddings @ embeddings.T
        distances = 1 - similarity

        # Create masks
        labels_col = labels.unsqueeze(0)
        labels_row = labels.unsqueeze(1)
        positive_mask = (labels_row == labels_col).float()
        negative_mask = 1 - positive_mask

        # Remove self from positives (diagonal)
        identity = torch.eye(batch_size, device=device)
        positive_mask = positive_mask - identity

        # Check which anchors have at least one positive and one negative
        has_positives = positive_mask.sum(dim=1) > 0
        has_negatives = negative_mask.sum(dim=1) > 0
        valid_anchors = has_positives & has_negatives

        if not valid_anchors.any():
            return torch.tensor(0.0, device=device, requires_grad=True)

        # Hardest positive per anchor: max distance among positives.
        # Non-positives are set to -inf so they cannot win the max.
        pos_distances = distances.clone()
        pos_distances[positive_mask == 0] = float("-inf")
        hardest_positive, _ = pos_distances.max(dim=1)

        # Hardest negative per anchor: min distance among negatives.
        # Non-negatives are set to +inf so they cannot win the min.
        neg_distances = distances.clone()
        neg_distances[negative_mask == 0] = float("inf")
        hardest_negative, _ = neg_distances.min(dim=1)

        # Hinge: the positive should be closer than the negative by `margin`
        triplet_loss = F.relu(hardest_positive - hardest_negative + self.margin)

        # Average over valid anchors only
        loss = triplet_loss[valid_anchors].mean()

        return loss
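
# Sanity-check sketch for TripletLoss (illustrative only): the four 2-D
# embeddings below are hand-picked assumptions forming two well-separated
# clusters, so the hardest positive sits much closer than the hardest
# negative and the hinge term should be exactly zero.
def _example_triplet() -> None:
    criterion = TripletLoss(margin=0.3)
    # Class 0 clusters near +x, class 1 near -x.
    embeddings = F.normalize(
        torch.tensor([[1.0, 0.05], [1.0, -0.05], [-1.0, 0.05], [-1.0, -0.05]]),
        dim=1,
    )
    labels = torch.tensor([0, 0, 1, 1])
    loss = criterion(embeddings, labels)
    # d(hardest_pos) ~= 0.005 and d(hardest_neg) ~= 2.0,
    # so relu(0.005 - 2.0 + 0.3) == 0.
    print(f"triplet: {loss.item():.4f}")  # expect 0.0000
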
class CombinedLoss(nn.Module):
    """
    Combined loss function with weighted InfoNCE and Triplet losses.

    Can help stabilize training by combining the benefits of both losses.
    """

    def __init__(
        self,
        temperature: float = 0.07,
        triplet_margin: float = 0.3,
        infonce_weight: float = 1.0,
        triplet_weight: float = 0.5,
    ):
        super().__init__()
        self.infonce = InfoNCELoss(temperature=temperature)
        self.triplet = TripletLoss(margin=triplet_margin)
        self.infonce_weight = infonce_weight
        self.triplet_weight = triplet_weight

    def forward(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> torch.Tensor:
        infonce_loss = self.infonce(embeddings, labels)
        triplet_loss = self.triplet(embeddings, labels)
        return (
            self.infonce_weight * infonce_loss
            + self.triplet_weight * triplet_loss
        )


def get_loss_function(
    loss_type: str = "infonce",
    temperature: float = 0.07,
    triplet_margin: float = 0.3,
) -> nn.Module:
    """
    Factory function to create a loss function.

    Args:
        loss_type: One of "infonce", "supcon", "triplet", or "combined"
        temperature: Temperature for InfoNCE/SupCon
        triplet_margin: Margin for triplet loss

    Returns:
        Loss function module
    """
    if loss_type == "infonce":
        return InfoNCELoss(temperature=temperature)
    elif loss_type == "supcon":
        return SupConLoss(temperature=temperature)
    elif loss_type == "triplet":
        return TripletLoss(margin=triplet_margin)
    elif loss_type == "combined":
        return CombinedLoss(
            temperature=temperature,
            triplet_margin=triplet_margin,
        )
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")
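
# Smoke-test sketch (illustrative assumptions: batch of 16, 64-D embeddings,
# 4 classes). Runs every loss the factory knows about on one random batch and
# checks that gradients flow; not part of the module's public surface.
if __name__ == "__main__":
    torch.manual_seed(0)
    raw = torch.randn(16, 64, requires_grad=True)
    embeddings = F.normalize(raw, dim=1)
    labels = torch.randint(0, 4, (16,))
    for name in ("infonce", "supcon", "triplet", "combined"):
        criterion = get_loss_function(loss_type=name)
        loss = criterion(embeddings, labels)
        loss.backward(retain_graph=True)  # same graph is reused across losses
        print(f"{name:>8s}: {loss.item():.4f}")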