logo_test/training/evaluation.py
Rick McEwen 44e8b6ae7d Add CLIP fine-tuning pipeline for logo recognition
Implement contrastive learning with LoRA to fine-tune CLIP's vision
encoder on the LogoDet-3K dataset for improved logo embedding similarity.

New training module (training/):
- config.py: TrainingConfig dataclass with all hyperparameters
- dataset.py: LogoContrastiveDataset with logo-level splits
- model.py: LogoFineTunedCLIP wrapper with LoRA support
- losses.py: InfoNCE, TripletLoss, SupConLoss implementations
- trainer.py: Training loop with mixed precision and checkpointing
- evaluation.py: EmbeddingEvaluator for validation metrics

New scripts:
- train_clip_logo.py: Main training entry point
- export_model.py: Export to HuggingFace-compatible format

Configurations:
- configs/jetson_orin.yaml: Optimized for the Jetson AGX Orin
- configs/cloud_rtx4090.yaml: Optimized for 24GB cloud GPUs
- configs/cloud_a100.yaml: Optimized for 80GB cloud GPUs

Documentation:
- CLIP_FINETUNING.md: Training guide and usage instructions
- CLOUD_TRAINING.md: Cloud GPU recommendations and cost estimates

Modified:
- logo_detection_detr.py: Add fine-tuned model loading support
- pyproject.toml: Add peft, pyyaml, torchvision dependencies
2026-01-04 13:45:25 -05:00


"""
Evaluation metrics for embedding quality.
"""
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
import numpy as np
class EmbeddingEvaluator:
"""
Evaluator for embedding quality metrics.
Computes metrics that indicate how well the embeddings
separate different logo classes.
"""
def compute_metrics(
self,
embeddings: torch.Tensor,
labels: torch.Tensor,
) -> Dict[str, float]:
"""
Compute embedding quality metrics.
Args:
embeddings: [N, D] L2-normalized embeddings
labels: [N] integer class labels
Returns:
Dict with metric names and values
"""
device = embeddings.device
batch_size = embeddings.shape[0]
if batch_size <= 1:
return {
"mean_pos_sim": 0.0,
"mean_neg_sim": 0.0,
"separation": 0.0,
"recall_at_1": 0.0,
"recall_at_5": 0.0,
}
# Compute similarity matrix
similarity = embeddings @ embeddings.T
# Create masks
labels_col = labels.unsqueeze(0)
labels_row = labels.unsqueeze(1)
positive_mask = (labels_row == labels_col).float()
negative_mask = 1 - positive_mask
# Remove diagonal from positive mask
identity = torch.eye(batch_size, device=device)
positive_mask = positive_mask - identity
# Count pairs
num_positives = positive_mask.sum()
num_negatives = negative_mask.sum()
# Mean positive similarity (excluding self)
if num_positives > 0:
pos_sims = (similarity * positive_mask).sum() / num_positives
mean_pos_sim = pos_sims.item()
else:
mean_pos_sim = 0.0
# Mean negative similarity
if num_negatives > 0:
neg_sims = (similarity * negative_mask).sum() / num_negatives
mean_neg_sim = neg_sims.item()
else:
mean_neg_sim = 0.0
# Separation: gap between positive and negative similarity
separation = mean_pos_sim - mean_neg_sim
# Recall@K metrics
recall_at_1 = self._compute_recall_at_k(similarity, labels, k=1)
recall_at_5 = self._compute_recall_at_k(similarity, labels, k=5)
return {
"mean_pos_sim": mean_pos_sim,
"mean_neg_sim": mean_neg_sim,
"separation": separation,
"recall_at_1": recall_at_1,
"recall_at_5": recall_at_5,
}
    def _compute_recall_at_k(
        self,
        similarity: torch.Tensor,
        labels: torch.Tensor,
        k: int = 1,
    ) -> float:
        """
        Compute Recall@K for nearest neighbor retrieval.

        For each sample, check if the k nearest neighbors (excluding self)
        contain at least one sample with the same label.

        Args:
            similarity: [N, N] similarity matrix
            labels: [N] class labels
            k: Number of neighbors to consider

        Returns:
            Recall@K score (0 to 1)
        """
        batch_size = similarity.shape[0]
        if batch_size <= 1:
            return 0.0

        # Mask out self-similarity
        similarity = similarity.clone()
        similarity.fill_diagonal_(float("-inf"))

        # Get top-k indices
        _, top_k_indices = similarity.topk(min(k, batch_size - 1), dim=1)

        # Check if any of the top-k neighbors share the query's label
        correct = 0
        for i in range(batch_size):
            query_label = labels[i]
            retrieved_labels = labels[top_k_indices[i]]
            if (retrieved_labels == query_label).any():
                correct += 1
        return correct / batch_size
    def compute_detailed_metrics(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
        label_names: Optional[List[str]] = None,
    ) -> Dict:
        """
        Compute detailed per-class metrics.

        Args:
            embeddings: [N, D] embeddings
            labels: [N] class labels
            label_names: Optional list of label names

        Returns:
            Dict with detailed metrics including per-class stats
        """
        basic_metrics = self.compute_metrics(embeddings, labels)

        # Per-class statistics
        unique_labels = labels.unique()
        per_class_stats = {}
        similarity = embeddings @ embeddings.T

        for label in unique_labels:
            mask = labels == label
            class_embeddings = embeddings[mask]
            class_size = mask.sum().item()

            if class_size > 1:
                # Intra-class similarity, excluding the diagonal (self-pairs)
                class_sim = class_embeddings @ class_embeddings.T
                mask_diag = ~torch.eye(class_size, dtype=torch.bool, device=class_sim.device)
                intra_sim = class_sim[mask_diag].mean().item()
            else:
                intra_sim = 1.0

            # Inter-class similarity (to other classes)
            other_mask = labels != label
            if other_mask.any():
                inter_sim = similarity[mask][:, other_mask].mean().item()
            else:
                inter_sim = 0.0

            class_name = label_names[label.item()] if label_names else str(label.item())
            per_class_stats[class_name] = {
                "size": class_size,
                "intra_class_sim": intra_sim,
                "inter_class_sim": inter_sim,
                "class_separation": intra_sim - inter_sim,
            }

        # Aggregate per-class stats
        if per_class_stats:
            separations = [s["class_separation"] for s in per_class_stats.values()]
            min_separation = min(separations)
            max_separation = max(separations)
            std_separation = np.std(separations)
        else:
            min_separation = max_separation = std_separation = 0.0

        return {
            **basic_metrics,
            "per_class": per_class_stats,
            "min_class_separation": min_separation,
            "max_class_separation": max_separation,
            "std_class_separation": std_separation,
        }
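
# A minimal usage sketch (illustrative only; the tensors below are
# hypothetical, not from the repo). compute_metrics expects L2-normalized
# embeddings, so normalize first if the model does not already do so:
#
#     evaluator = EmbeddingEvaluator()
#     emb = F.normalize(torch.randn(64, 512), dim=1)
#     labels = torch.randint(0, 8, (64,))
#     metrics = evaluator.compute_metrics(emb, labels)
#     # "separation" and "recall_at_1" should rise as fine-tuning improves
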
class SimilarityAnalyzer:
    """
    Analyze similarity distributions for debugging and tuning.
    """

    @staticmethod
    def analyze_similarity_distribution(
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> Dict[str, np.ndarray]:
        """
        Get similarity distributions for positive and negative pairs.

        Useful for choosing appropriate thresholds.

        Args:
            embeddings: [N, D] embeddings
            labels: [N] class labels

        Returns:
            Dict with 'positive_sims' and 'negative_sims' arrays
        """
        similarity = (embeddings @ embeddings.T).cpu().numpy()
        labels_np = labels.cpu().numpy()
        batch_size = len(labels_np)

        positive_sims = []
        negative_sims = []
        for i in range(batch_size):
            for j in range(i + 1, batch_size):
                if labels_np[i] == labels_np[j]:
                    positive_sims.append(similarity[i, j])
                else:
                    negative_sims.append(similarity[i, j])

        return {
            "positive_sims": np.array(positive_sims),
            "negative_sims": np.array(negative_sims),
        }
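
    # Example (a sketch, not prescribed anywhere in the repo): one simple way
    # to derive a matching threshold from the two distributions is the
    # midpoint of their means:
    #
    #     dists = SimilarityAnalyzer.analyze_similarity_distribution(emb, labels)
    #     threshold = (dists["positive_sims"].mean()
    #                  + dists["negative_sims"].mean()) / 2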

    @staticmethod
    def find_hard_pairs(
        embeddings: torch.Tensor,
        labels: torch.Tensor,
        n_hard: int = 10,
    ) -> Tuple[List[Tuple[int, int, float]], List[Tuple[int, int, float]]]:
        """
        Find hardest positive and negative pairs.

        Hard positives: same label but low similarity.
        Hard negatives: different label but high similarity.

        Args:
            embeddings: [N, D] embeddings
            labels: [N] class labels
            n_hard: Number of hard pairs to return

        Returns:
            Tuple of (hard_positives, hard_negatives).
            Each is a list of (idx1, idx2, similarity) tuples.
        """
        similarity = embeddings @ embeddings.T
        batch_size = len(labels)

        hard_positives = []  # Low similarity, same label
        hard_negatives = []  # High similarity, different label
        for i in range(batch_size):
            for j in range(i + 1, batch_size):
                sim = similarity[i, j].item()
                if labels[i] == labels[j]:
                    hard_positives.append((i, j, sim))
                else:
                    hard_negatives.append((i, j, sim))

        # Sort: hard positives by ascending similarity (lowest first)
        hard_positives.sort(key=lambda x: x[2])
        # Sort: hard negatives by descending similarity (highest first)
        hard_negatives.sort(key=lambda x: -x[2])

        return hard_positives[:n_hard], hard_negatives[:n_hard]
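
    # Example (hypothetical usage): hard negatives surfaced here can seed
    # qualitative review or offline triplet mining:
    #
    #     _, hard_negs = SimilarityAnalyzer.find_hard_pairs(emb, labels, n_hard=5)
    #     for i, j, sim in hard_negs:
    #         print(f"samples {i} and {j} are confusable (sim={sim:.3f})")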

    @staticmethod
    def compute_confusion_pairs(
        embeddings: torch.Tensor,
        labels: torch.Tensor,
        label_names: Optional[List[str]] = None,
        top_k: int = 10,
    ) -> List[Dict]:
        """
        Find pairs of classes that are most confused (highest cross-class
        similarity).

        Args:
            embeddings: [N, D] embeddings
            labels: [N] class labels
            label_names: Optional label names
            top_k: Number of confused pairs to return

        Returns:
            List of dicts with class pairs and their similarity
        """
        unique_labels = labels.unique()
        class_centroids = {}

        # Compute a normalized centroid per class
        for label in unique_labels:
            mask = labels == label
            centroid = embeddings[mask].mean(dim=0)
            centroid = F.normalize(centroid, dim=0)
            class_centroids[label.item()] = centroid

        # Compute pairwise centroid similarities
        confusions = []
        label_list = list(class_centroids.keys())
        for i, label1 in enumerate(label_list):
            for label2 in label_list[i + 1:]:
                sim = (class_centroids[label1] @ class_centroids[label2]).item()
                name1 = label_names[label1] if label_names else str(label1)
                name2 = label_names[label2] if label_names else str(label2)
                confusions.append({
                    "class1": name1,
                    "class2": name2,
                    "label1": label1,
                    "label2": label2,
                    "centroid_similarity": sim,
                })

        # Sort by similarity (highest first)
        confusions.sort(key=lambda x: -x["centroid_similarity"])
        return confusions[:top_k]
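

if __name__ == "__main__":
    # Quick smoke test on random data -- a minimal sketch, not part of the
    # training pipeline; real evaluation runs on validation embeddings from
    # the fine-tuned CLIP encoder.
    torch.manual_seed(0)
    emb = F.normalize(torch.randn(32, 128), dim=1)
    labels = torch.randint(0, 4, (32,))

    evaluator = EmbeddingEvaluator()
    print(evaluator.compute_metrics(emb, labels))

    confusions = SimilarityAnalyzer.compute_confusion_pairs(emb, labels, top_k=3)
    for pair in confusions:
        print(pair["class1"], pair["class2"], f"{pair['centroid_similarity']:.3f}")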