Implement contrastive learning with LoRA to fine-tune CLIP's vision encoder on LogoDet-3K dataset for improved logo embedding similarity. New training module (training/): - config.py: TrainingConfig dataclass with all hyperparameters - dataset.py: LogoContrastiveDataset with logo-level splits - model.py: LogoFineTunedCLIP wrapper with LoRA support - losses.py: InfoNCE, TripletLoss, SupConLoss implementations - trainer.py: Training loop with mixed precision and checkpointing - evaluation.py: EmbeddingEvaluator for validation metrics New scripts: - train_clip_logo.py: Main training entry point - export_model.py: Export to HuggingFace-compatible format Configurations: - configs/jetson_orin.yaml: Optimized for Jetson Orin AGX - configs/cloud_rtx4090.yaml: Optimized for 24GB cloud GPUs - configs/cloud_a100.yaml: Optimized for 80GB cloud GPUs Documentation: - CLIP_FINETUNING.md: Training guide and usage instructions - CLOUD_TRAINING.md: Cloud GPU recommendations and cost estimates Modified: - logo_detection_detr.py: Add fine-tuned model loading support - pyproject.toml: Add peft, pyyaml, torchvision dependencies
340 lines
10 KiB
Python
340 lines
10 KiB
Python
"""
|
|
Evaluation metrics for embedding quality.
|
|
"""
|
|
|
|
from typing import Dict, List, Optional, Tuple
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
|
|
|
|
class EmbeddingEvaluator:
    """
    Evaluator for embedding quality metrics.

    Computes metrics that indicate how well the embeddings
    separate different logo classes.
    """

    def compute_metrics(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> Dict[str, float]:
        """
        Compute embedding quality metrics.

        Args:
            embeddings: [N, D] L2-normalized embeddings
            labels: [N] integer class labels

        Returns:
            Dict with metric names and values
        """
        device = embeddings.device
        batch_size = embeddings.shape[0]

        # With fewer than 2 samples there are no pairs to evaluate.
        if batch_size <= 1:
            return {
                "mean_pos_sim": 0.0,
                "mean_neg_sim": 0.0,
                "separation": 0.0,
                "recall_at_1": 0.0,
                "recall_at_5": 0.0,
            }

        # Pairwise dot products; cosine similarity given L2-normalized inputs.
        similarity = embeddings @ embeddings.T

        # Pair masks: positive = same label, negative = different label.
        labels_col = labels.unsqueeze(0)
        labels_row = labels.unsqueeze(1)
        positive_mask = (labels_row == labels_col).float()
        negative_mask = 1 - positive_mask

        # Exclude self-pairs from the positive mask: the diagonal is always
        # same-label but carries no information.
        identity = torch.eye(batch_size, device=device)
        positive_mask = positive_mask - identity

        num_positives = positive_mask.sum()
        num_negatives = negative_mask.sum()

        # Mean similarity over positive pairs (excluding self).
        if num_positives > 0:
            mean_pos_sim = ((similarity * positive_mask).sum() / num_positives).item()
        else:
            mean_pos_sim = 0.0

        # Mean similarity over negative pairs.
        if num_negatives > 0:
            mean_neg_sim = ((similarity * negative_mask).sum() / num_negatives).item()
        else:
            mean_neg_sim = 0.0

        # Separation: gap between positive and negative similarity.
        separation = mean_pos_sim - mean_neg_sim

        # Recall@K metrics for nearest-neighbor retrieval.
        recall_at_1 = self._compute_recall_at_k(similarity, labels, k=1)
        recall_at_5 = self._compute_recall_at_k(similarity, labels, k=5)

        return {
            "mean_pos_sim": mean_pos_sim,
            "mean_neg_sim": mean_neg_sim,
            "separation": separation,
            "recall_at_1": recall_at_1,
            "recall_at_5": recall_at_5,
        }

    def _compute_recall_at_k(
        self,
        similarity: torch.Tensor,
        labels: torch.Tensor,
        k: int = 1,
    ) -> float:
        """
        Compute Recall@K for nearest neighbor retrieval.

        For each sample, check if the k nearest neighbors (excluding self)
        contain at least one sample with the same label.

        Args:
            similarity: [N, N] similarity matrix
            labels: [N] class labels
            k: Number of neighbors to consider

        Returns:
            Recall@K score (0 to 1)
        """
        batch_size = similarity.shape[0]
        if batch_size <= 1:
            return 0.0

        # Mask out self-similarity so a sample cannot retrieve itself.
        similarity = similarity.clone()
        similarity.fill_diagonal_(float("-inf"))

        # Top-k neighbor indices; k is capped at the number of other samples.
        _, top_k_indices = similarity.topk(min(k, batch_size - 1), dim=1)

        # Vectorized hit test (replaces a per-sample Python loop): for each
        # query, does any retrieved neighbor share its label?
        retrieved_labels = labels[top_k_indices]  # [N, k]
        hits = (retrieved_labels == labels.unsqueeze(1)).any(dim=1)
        return hits.float().mean().item()

    def compute_detailed_metrics(
        self,
        embeddings: torch.Tensor,
        labels: torch.Tensor,
        label_names: Optional[List[str]] = None,
    ) -> Dict:
        """
        Compute detailed per-class metrics.

        Args:
            embeddings: [N, D] embeddings
            labels: [N] class labels
            label_names: Optional list of label names; assumed to be
                indexable by the integer label values (TODO confirm labels
                are contiguous indices into this list)

        Returns:
            Dict with detailed metrics including per-class stats
        """
        basic_metrics = self.compute_metrics(embeddings, labels)

        # Per-class statistics.
        unique_labels = labels.unique()
        per_class_stats = {}

        similarity = embeddings @ embeddings.T

        for label in unique_labels:
            mask = labels == label
            class_embeddings = embeddings[mask]
            class_size = mask.sum().item()

            if class_size > 1:
                # Mean intra-class similarity, excluding the diagonal
                # (self-similarity is uninformative).
                class_sim = class_embeddings @ class_embeddings.T
                mask_diag = ~torch.eye(class_size, dtype=torch.bool, device=class_sim.device)
                intra_sim = class_sim[mask_diag].mean().item()
            else:
                # Singleton class: define intra-class similarity as perfect.
                intra_sim = 1.0

            # Mean similarity from this class's samples to all other classes.
            other_mask = labels != label
            if other_mask.any():
                inter_sim = similarity[mask][:, other_mask].mean().item()
            else:
                inter_sim = 0.0

            class_name = label_names[label.item()] if label_names else str(label.item())
            per_class_stats[class_name] = {
                "size": class_size,
                "intra_class_sim": intra_sim,
                "inter_class_sim": inter_sim,
                "class_separation": intra_sim - inter_sim,
            }

        # Aggregate per-class separations.  float() avoids leaking an
        # np.float64 into an otherwise all-Python-float result dict.
        if per_class_stats:
            separations = [s["class_separation"] for s in per_class_stats.values()]
            min_separation = min(separations)
            max_separation = max(separations)
            std_separation = float(np.std(separations))
        else:
            min_separation = max_separation = std_separation = 0.0

        return {
            **basic_metrics,
            "per_class": per_class_stats,
            "min_class_separation": min_separation,
            "max_class_separation": max_separation,
            "std_class_separation": std_separation,
        }
|
|
|
|
|
|
class SimilarityAnalyzer:
    """
    Analyze similarity distributions for debugging and tuning.
    """

    @staticmethod
    def analyze_similarity_distribution(
        embeddings: torch.Tensor,
        labels: torch.Tensor,
    ) -> Dict[str, np.ndarray]:
        """
        Get similarity distributions for positive and negative pairs.

        Useful for choosing appropriate thresholds.

        Args:
            embeddings: [N, D] embeddings
            labels: [N] class labels

        Returns:
            Dict with 'positive_sims' and 'negative_sims' arrays
        """
        similarity = (embeddings @ embeddings.T).cpu().numpy()
        labels_np = labels.cpu().numpy()

        batch_size = len(labels_np)

        # Vectorized replacement for the O(N^2) Python pair loop: the strict
        # upper triangle visits each unordered pair exactly once, in the same
        # row-major (i, j) order the original nested loop produced.
        rows, cols = np.triu_indices(batch_size, k=1)
        pair_sims = similarity[rows, cols]
        same_label = labels_np[rows] == labels_np[cols]

        return {
            "positive_sims": pair_sims[same_label],
            "negative_sims": pair_sims[~same_label],
        }

    @staticmethod
    def find_hard_pairs(
        embeddings: torch.Tensor,
        labels: torch.Tensor,
        n_hard: int = 10,
    ) -> Tuple[List[Tuple[int, int, float]], List[Tuple[int, int, float]]]:
        """
        Find hardest positive and negative pairs.

        Hard positives: same label but low similarity
        Hard negatives: different label but high similarity

        Args:
            embeddings: [N, D] embeddings
            labels: [N] class labels
            n_hard: Number of hard pairs to return

        Returns:
            Tuple of (hard_positives, hard_negatives)
            Each is a list of (idx1, idx2, similarity) tuples
        """
        similarity = embeddings @ embeddings.T
        batch_size = len(labels)

        hard_positives = []  # Same label; ranked by lowest similarity below
        hard_negatives = []  # Different label; ranked by highest similarity below

        # Enumerate each unordered pair exactly once (i < j).
        for i in range(batch_size):
            for j in range(i + 1, batch_size):
                sim = similarity[i, j].item()
                if labels[i] == labels[j]:
                    hard_positives.append((i, j, sim))
                else:
                    hard_negatives.append((i, j, sim))

        # Hard positives: ascending similarity (least similar same-label first).
        hard_positives.sort(key=lambda x: x[2])

        # Hard negatives: descending similarity (most similar cross-label first).
        hard_negatives.sort(key=lambda x: -x[2])

        return hard_positives[:n_hard], hard_negatives[:n_hard]

    @staticmethod
    def compute_confusion_pairs(
        embeddings: torch.Tensor,
        labels: torch.Tensor,
        label_names: Optional[List[str]] = None,
        top_k: int = 10,
    ) -> List[Dict]:
        """
        Find pairs of classes that are most confused (highest cross-class similarity).

        Args:
            embeddings: [N, D] embeddings
            labels: [N] class labels
            label_names: Optional label names; assumed indexable by the
                integer label values (TODO confirm labels are contiguous)
            top_k: Number of confused pairs to return

        Returns:
            List of dicts with class pairs and their similarity
        """
        unique_labels = labels.unique()
        class_centroids = {}

        # Compute an L2-normalized centroid per class so that centroid dot
        # products are cosine similarities.
        for label in unique_labels:
            mask = labels == label
            centroid = embeddings[mask].mean(dim=0)
            centroid = F.normalize(centroid, dim=0)
            class_centroids[label.item()] = centroid

        # Pairwise centroid similarities over every unordered class pair.
        confusions = []
        label_list = list(class_centroids.keys())

        for i, label1 in enumerate(label_list):
            for label2 in label_list[i + 1:]:
                sim = (class_centroids[label1] @ class_centroids[label2]).item()
                name1 = label_names[label1] if label_names else str(label1)
                name2 = label_names[label2] if label_names else str(label2)
                confusions.append({
                    "class1": name1,
                    "class2": name2,
                    "label1": label1,
                    "label2": label2,
                    "centroid_similarity": sim,
                })

        # Most confusable class pairs (highest centroid similarity) first.
        confusions.sort(key=lambda x: -x["centroid_similarity"])

        return confusions[:top_k]
|