"""Evaluation metrics for embedding quality."""

from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F


class EmbeddingEvaluator:
    """Evaluator for embedding quality metrics.

    Computes metrics that indicate how well the embeddings separate
    different logo classes.
    """

    def compute_metrics(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> Dict[str, float]:
        """Compute embedding quality metrics.

        Args:
            embeddings: [N, D] L2-normalized embeddings
            labels: [N] integer class labels

        Returns:
            Dict with metric names and values
        """
        device = embeddings.device
        batch_size = embeddings.shape[0]

        # With 0 or 1 samples there are no pairs to compare; report zeros.
        if batch_size <= 1:
            return {
                "mean_pos_sim": 0.0,
                "mean_neg_sim": 0.0,
                "separation": 0.0,
                "recall_at_1": 0.0,
                "recall_at_5": 0.0,
            }

        # Cosine similarity matrix (embeddings are assumed L2-normalized,
        # so the dot product is the cosine similarity).
        similarity = embeddings @ embeddings.T

        # Pair masks: positive = same label, negative = different label.
        labels_col = labels.unsqueeze(0)
        labels_row = labels.unsqueeze(1)
        positive_mask = (labels_row == labels_col).float()
        negative_mask = 1 - positive_mask

        # Remove the diagonal (self-pairs) from the positive mask.
        identity = torch.eye(batch_size, device=device)
        positive_mask = positive_mask - identity

        num_positives = positive_mask.sum()
        num_negatives = negative_mask.sum()

        # Mean positive similarity (excluding self-pairs).
        if num_positives > 0:
            pos_sims = (similarity * positive_mask).sum() / num_positives
            mean_pos_sim = pos_sims.item()
        else:
            mean_pos_sim = 0.0

        # Mean negative similarity.
        if num_negatives > 0:
            neg_sims = (similarity * negative_mask).sum() / num_negatives
            mean_neg_sim = neg_sims.item()
        else:
            mean_neg_sim = 0.0

        # Separation: gap between positive and negative similarity.
        # Larger is better (positives cluster, negatives spread apart).
        separation = mean_pos_sim - mean_neg_sim

        # Retrieval quality at k nearest neighbors.
        recall_at_1 = self._compute_recall_at_k(similarity, labels, k=1)
        recall_at_5 = self._compute_recall_at_k(similarity, labels, k=5)

        return {
            "mean_pos_sim": mean_pos_sim,
            "mean_neg_sim": mean_neg_sim,
            "separation": separation,
            "recall_at_1": recall_at_1,
            "recall_at_5": recall_at_5,
        }

    def _compute_recall_at_k(
        self,
        similarity: torch.Tensor,
        labels: torch.Tensor,
        k: int = 1,
    ) -> float:
        """Compute Recall@K for nearest neighbor retrieval.

        For each sample, check if the k nearest neighbors (excluding self)
        contain at least one sample with the same label.

        Args:
            similarity: [N, N] similarity matrix
            labels: [N] class labels
            k: Number of neighbors to consider

        Returns:
            Recall@K score (0 to 1)
        """
        batch_size = similarity.shape[0]
        if batch_size <= 1:
            return 0.0

        # Mask out self-similarity on a copy so the caller's matrix is
        # not mutated.
        similarity = similarity.clone()
        similarity.fill_diagonal_(float("-inf"))

        # Top-k neighbor indices per row; k is capped so we never request
        # more neighbors than exist.
        _, top_k_indices = similarity.topk(min(k, batch_size - 1), dim=1)

        # Vectorized hit test: [N, k] retrieved labels vs [N, 1] query labels.
        # A query is a hit if any retrieved neighbor shares its label.
        hits = (labels[top_k_indices] == labels.unsqueeze(1)).any(dim=1)

        # Integer count / int division keeps the exact Python-float result.
        return int(hits.sum().item()) / batch_size

    def compute_detailed_metrics(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
        label_names: Optional[List[str]] = None,
    ) -> Dict:
        """Compute detailed per-class metrics.

        Args:
            embeddings: [N, D] embeddings
            labels: [N] class labels
            label_names: Optional list of label names

        Returns:
            Dict with detailed metrics including per-class stats
        """
        basic_metrics = self.compute_metrics(embeddings, labels)

        # Per-class statistics.
        unique_labels = labels.unique()
        per_class_stats = {}
        similarity = embeddings @ embeddings.T

        for label in unique_labels:
            mask = labels == label
            class_embeddings = embeddings[mask]
            class_size = mask.sum().item()

            if class_size > 1:
                # Intra-class similarity, excluding the diagonal (self-pairs).
                class_sim = class_embeddings @ class_embeddings.T
                mask_diag = ~torch.eye(class_size, dtype=torch.bool, device=class_sim.device)
                intra_sim = class_sim[mask_diag].mean().item()
            else:
                # A singleton class is perfectly self-similar by convention.
                intra_sim = 1.0

            # Inter-class similarity (this class vs all other classes).
            other_mask = labels != label
            if other_mask.any():
                inter_sim = similarity[mask][:, other_mask].mean().item()
            else:
                inter_sim = 0.0

            class_name = label_names[label.item()] if label_names else str(label.item())
            per_class_stats[class_name] = {
                "size": class_size,
                "intra_class_sim": intra_sim,
                "inter_class_sim": inter_sim,
                "class_separation": intra_sim - inter_sim,
            }

        # Aggregate per-class stats.
        if per_class_stats:
            separations = [s["class_separation"] for s in per_class_stats.values()]
            min_separation = min(separations)
            max_separation = max(separations)
            # float() keeps the result dict free of np.float64 values.
            std_separation = float(np.std(separations))
        else:
            min_separation = max_separation = std_separation = 0.0

        return {
            **basic_metrics,
            "per_class": per_class_stats,
            "min_class_separation": min_separation,
            "max_class_separation": max_separation,
            "std_class_separation": std_separation,
        }


class SimilarityAnalyzer:
    """Analyze similarity distributions for debugging and tuning."""

    @staticmethod
    def analyze_similarity_distribution(
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> Dict[str, np.ndarray]:
        """Get similarity distributions for positive and negative pairs.

        Useful for choosing appropriate thresholds.

        Args:
            embeddings: [N, D] embeddings
            labels: [N] class labels

        Returns:
            Dict with 'positive_sims' and 'negative_sims' arrays
        """
        similarity = (embeddings @ embeddings.T).cpu().numpy()
        labels_np = labels.cpu().numpy()
        batch_size = len(labels_np)

        # Unique unordered pairs (i < j) via the strict upper triangle.
        # Row-major traversal matches an explicit nested-loop order.
        rows, cols = np.triu_indices(batch_size, k=1)
        pair_sims = similarity[rows, cols]
        same_label = labels_np[rows] == labels_np[cols]

        return {
            "positive_sims": pair_sims[same_label],
            "negative_sims": pair_sims[~same_label],
        }

    @staticmethod
    def find_hard_pairs(
        embeddings: torch.Tensor,
        labels: torch.Tensor,
        n_hard: int = 10,
    ) -> Tuple[List[Tuple[int, int, float]], List[Tuple[int, int, float]]]:
        """Find hardest positive and negative pairs.

        Hard positives: same label but low similarity
        Hard negatives: different label but high similarity

        Args:
            embeddings: [N, D] embeddings
            labels: [N] class labels
            n_hard: Number of hard pairs to return

        Returns:
            Tuple of (hard_positives, hard_negatives)
            Each is a list of (idx1, idx2, similarity) tuples
        """
        similarity = embeddings @ embeddings.T
        batch_size = len(labels)

        hard_positives = []  # Low similarity, same label
        hard_negatives = []  # High similarity, different label

        # Enumerate unique unordered pairs (i < j).
        for i in range(batch_size):
            for j in range(i + 1, batch_size):
                sim = similarity[i, j].item()
                if labels[i] == labels[j]:
                    hard_positives.append((i, j, sim))
                else:
                    hard_negatives.append((i, j, sim))

        # Hard positives: ascending similarity (lowest = hardest first).
        hard_positives.sort(key=lambda x: x[2])
        # Hard negatives: descending similarity (highest = hardest first).
        hard_negatives.sort(key=lambda x: -x[2])

        return hard_positives[:n_hard], hard_negatives[:n_hard]

    @staticmethod
    def compute_confusion_pairs(
        embeddings: torch.Tensor,
        labels: torch.Tensor,
        label_names: Optional[List[str]] = None,
        top_k: int = 10,
    ) -> List[Dict]:
        """Find pairs of classes that are most confused.

        Confusion is measured by cosine similarity between the L2-normalized
        class centroids; the highest-similarity pairs are returned first.

        Args:
            embeddings: [N, D] embeddings
            labels: [N] class labels
            label_names: Optional label names
            top_k: Number of confused pairs to return

        Returns:
            List of dicts with class pairs and their similarity
        """
        unique_labels = labels.unique()
        class_centroids = {}

        # Compute a normalized centroid per class so centroid dot products
        # are cosine similarities.
        for label in unique_labels:
            mask = labels == label
            centroid = embeddings[mask].mean(dim=0)
            centroid = F.normalize(centroid, dim=0)
            class_centroids[label.item()] = centroid

        # Pairwise centroid similarities over unique class pairs.
        confusions = []
        label_list = list(class_centroids.keys())

        for i, label1 in enumerate(label_list):
            for label2 in label_list[i + 1:]:
                sim = (class_centroids[label1] @ class_centroids[label2]).item()
                name1 = label_names[label1] if label_names else str(label1)
                name2 = label_names[label2] if label_names else str(label2)
                confusions.append({
                    "class1": name1,
                    "class2": name2,
                    "label1": label1,
                    "label2": label2,
                    "centroid_similarity": sim,
                })

        # Most confused (highest centroid similarity) first.
        confusions.sort(key=lambda x: -x["centroid_similarity"])

        return confusions[:top_k]