- Add --similarity-details flag to test_logo_detection.py
- Track true positive, false positive, and missed detection similarities
- Compute distribution statistics (min, max, mean, stddev, percentiles)
- Analyze overlap between TP and FP distributions
- Suggest optimal threshold based on data
- Show per-detection breakdown with top-5 matches
- Create analyze_similarity_distribution.sh wrapper script
  - Supports baseline, finetuned, or both models
  - Saves output to similarity_analysis/ directory
#!/usr/bin/env python3
"""
Test script for logo detection accuracy.

This script:
1. Randomly samples N reference logos from the database
2. Samples positive and negative test images per logo and processes them
   through the DETR + CLIP/DINOv2 embedding pipeline
3. Compares detected logos against reference embeddings
4. Reports accuracy metrics (correct matches, false positives, missed detections)

Embeddings are cached to avoid reprocessing images.
"""
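
# Example invocation (a sketch; the flags are defined in main() below):
#   ./test_logo_detection.py -n 20 --refs-per-logo 3 --matching-method multi-ref \
#       --threshold 0.75 --similarity-details --seed 42
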
import argparse
import logging
import pickle
import random
import sqlite3
import sys
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple

import cv2
import torch
from tqdm import tqdm

from logo_detection_detr import DetectLogosDETR


def setup_logging(verbose: bool = False) -> logging.Logger:
    """Configure logging."""
    level = logging.DEBUG if verbose else logging.INFO
    logging.basicConfig(
        level=level,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%H:%M:%S",
    )
    return logging.getLogger(__name__)


def load_image(image_path: Path) -> Optional[cv2.Mat]:
    """Load an image using OpenCV."""
    img = cv2.imread(str(image_path))
    if img is None:
        return None
    return img


class EmbeddingCache:
    """Simple file-based cache for embeddings and detection results."""

    def __init__(self, cache_path: Path):
        self.cache_path = cache_path
        self.cache: Dict[str, object] = {}
        self._load()

    def _load(self):
        """Load cache from disk if it exists."""
        if self.cache_path.exists():
            try:
                with open(self.cache_path, "rb") as f:
                    self.cache = pickle.load(f)
            except Exception:
                self.cache = {}

    def save(self):
        """Save cache to disk."""
        self.cache_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.cache_path, "wb") as f:
            pickle.dump(self.cache, f)

    def get(self, key: str):
        """Get a cached value (an embedding tensor or a detection list)."""
        return self.cache.get(key)

    def put(self, key: str, value):
        """Store a value in the cache."""
        # Move tensors to CPU to save GPU memory. Non-tensor values (main()
        # also caches detection lists under "det:" keys, which would crash the
        # old unconditional .cpu() call) are stored as-is.
        if isinstance(value, torch.Tensor):
            value = value.cpu()
        self.cache[key] = value

    def __len__(self):
        return len(self.cache)
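

# Minimal usage sketch for EmbeddingCache (illustrative; the key and helper
# below are hypothetical):
#   cache = EmbeddingCache(Path(".embedding_cache.pkl"))
#   emb = cache.get("ref:example.png")
#   if emb is None:
#       emb = compute_embedding_somehow()  # hypothetical
#       cache.put("ref:example.png", emb)
#   cache.save()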


def get_ground_truth(db_path: Path) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]:
    """
    Load ground truth from database.

    Returns:
        Tuple of:
        - Dict mapping test image filename to set of logo names it contains
        - Dict mapping logo name to set of test image filenames containing it
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # Query to get test image -> logo names mapping
    cursor.execute("""
        SELECT ti.filename, ln.name
        FROM test_images ti
        JOIN reference_logos rl ON ti.id = rl.test_image_id
        JOIN logo_names ln ON rl.logo_name_id = ln.id
    """)

    image_to_logos: Dict[str, Set[str]] = {}
    logo_to_images: Dict[str, Set[str]] = {}

    for row in cursor.fetchall():
        test_filename, logo_name = row
        if test_filename not in image_to_logos:
            image_to_logos[test_filename] = set()
        image_to_logos[test_filename].add(logo_name)

        if logo_name not in logo_to_images:
            logo_to_images[logo_name] = set()
        logo_to_images[logo_name].add(test_filename)

    conn.close()
    return image_to_logos, logo_to_images
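

# Shapes of the mappings returned by get_ground_truth(), with hypothetical
# filenames and logo names:
#   image_to_logos: {"img_0001.jpg": {"acme", "globex"}, ...}
#   logo_to_images: {"acme": {"img_0001.jpg", "img_0042.jpg"}, ...}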


def sample_reference_logos(
    db_path: Path, num_logos: int, refs_per_logo: int = 1, seed: Optional[int] = None
) -> Dict[str, List[str]]:
    """
    Randomly sample reference logos from database with multiple refs per logo.

    Args:
        db_path: Path to database
        num_logos: Number of logos to sample
        refs_per_logo: Number of reference images per logo
        seed: Random seed for reproducibility

    Returns:
        Dict mapping logo_name to list of reference filenames
    """
    if seed is not None:
        random.seed(seed)

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # Get all unique logo names
    cursor.execute("SELECT id, name FROM logo_names")
    all_logo_names = cursor.fetchall()

    # Sample logos
    if num_logos >= len(all_logo_names):
        sampled_logos = all_logo_names
    else:
        sampled_logos = random.sample(all_logo_names, num_logos)

    # For each sampled logo, get multiple reference files
    result: Dict[str, List[str]] = {}
    for logo_id, logo_name in sampled_logos:
        cursor.execute(
            "SELECT filename FROM reference_logos WHERE logo_name_id = ?",
            (logo_id,)
        )
        all_refs = [row[0] for row in cursor.fetchall()]

        # Sample refs_per_logo references (or all if fewer available)
        if len(all_refs) > refs_per_logo:
            selected_refs = random.sample(all_refs, refs_per_logo)
        else:
            selected_refs = all_refs

        result[logo_name] = selected_refs

    conn.close()
    return result
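

# Note: sample_reference_logos() seeds Python's global `random` module, so a
# fixed --seed also makes the positive/negative test-image sampling in main()
# reproducible.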


def get_test_images(db_path: Path) -> List[str]:
    """Get all test image filenames from database."""
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("SELECT filename FROM test_images")
    filenames = [row[0] for row in cursor.fetchall()]
    conn.close()
    return filenames


def main():
    parser = argparse.ArgumentParser(
        description="Test logo detection accuracy against ground truth"
    )
    parser.add_argument(
        "-n", "--num-logos",
        type=int,
        default=10,
        help="Number of reference logos to sample (default: 10)",
    )
    parser.add_argument(
        "-t", "--threshold",
        type=float,
        default=0.7,
        help="CLIP similarity threshold for matching (default: 0.7)",
    )
    parser.add_argument(
        "-d", "--detr-threshold",
        type=float,
        default=0.5,
        help="DETR detection confidence threshold (default: 0.5)",
    )
    parser.add_argument(
        "-e", "--embedding-model",
        type=str,
        default="openai/clip-vit-large-patch14",
        help="Embedding model for feature extraction (default: openai/clip-vit-large-patch14). "
             "Supports CLIP models (openai/clip-*) and DINOv2 models (facebook/dinov2-*)",
    )
    parser.add_argument(
        "-s", "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility",
    )
    parser.add_argument(
        "--positive-samples",
        type=int,
        default=5,
        help="Number of positive test images per logo (images containing the logo) (default: 5)",
    )
    parser.add_argument(
        "--negative-samples",
        type=int,
        default=20,
        help="Number of negative test images per logo (images NOT containing the logo) (default: 20)",
    )
    parser.add_argument(
        "--refs-per-logo",
        type=int,
        default=3,
        help="Number of reference images per logo for multi-ref matching (default: 3)",
    )
    parser.add_argument(
        "--margin",
        type=float,
        default=0.05,
        help="Required margin between best and second-best match "
             "(applies to the 'margin' and 'multi-ref' methods) (default: 0.05)",
    )
    parser.add_argument(
        "--matching-method",
        type=str,
        choices=["simple", "margin", "multi-ref"],
        default="margin",
        help="Matching method: 'simple' returns all matches above threshold, "
             "'margin' requires confidence margin over 2nd best, "
             "'multi-ref' aggregates scores across reference images (default: margin)",
    )
    parser.add_argument(
        "--min-matching-refs",
        type=int,
        default=1,
        help="For 'multi-ref' method: minimum references that must match above threshold (default: 1)",
    )
    parser.add_argument(
        "--use-max-similarity",
        action="store_true",
        help="For 'multi-ref' method: use max similarity instead of mean across references",
    )
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Enable verbose logging",
    )
    parser.add_argument(
        "--similarity-details",
        action="store_true",
        help="Output detailed similarity scores for each detection (for analyzing score distributions)",
    )
    parser.add_argument(
        "--no-cache",
        action="store_true",
        help="Disable embedding cache",
    )
    parser.add_argument(
        "--clear-cache",
        action="store_true",
        help="Clear embedding cache before running",
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default=None,
        help="Append results summary to this file (no progress output, just results)",
    )

    args = parser.parse_args()
    logger = setup_logging(args.verbose)

    # Paths
    base_dir = Path(__file__).resolve().parent
    db_path = base_dir / "test_data_mapping.db"
    reference_dir = base_dir / "reference_logos"
    test_images_dir = base_dir / "test_images"
    cache_path = base_dir / ".embedding_cache.pkl"

    # Verify database exists
    if not db_path.exists():
        logger.error(f"Database not found: {db_path}")
        logger.error("Run prepare_test_data.py first to create the database.")
        sys.exit(1)

    # Handle cache clearing
    if args.clear_cache and cache_path.exists():
        cache_path.unlink()
        logger.info("Cleared embedding cache")

    # Initialize embedding cache
    cache = EmbeddingCache(cache_path) if not args.no_cache else None
    if cache:
        logger.info(f"Loaded {len(cache)} cached entries")

    # Initialize detector
    logger.info(f"Initializing logo detector with embedding model: {args.embedding_model}")
    detector = DetectLogosDETR(
        logger=logger,
        detr_threshold=args.detr_threshold,
        embedding_model=args.embedding_model,
    )

    # Load ground truth (both mappings)
    logger.info("Loading ground truth from database...")
    image_to_logos, logo_to_images = get_ground_truth(db_path)
    all_test_images = set(image_to_logos.keys())
    logger.info(f"Loaded ground truth for {len(image_to_logos)} test images")

    # Sample reference logos (with multiple refs per logo)
    logger.info(f"Sampling {args.num_logos} reference logos with {args.refs_per_logo} refs each...")
    sampled_logos = sample_reference_logos(db_path, args.num_logos, args.refs_per_logo, args.seed)
    logger.info(f"Selected {len(sampled_logos)} reference logos")

    # Compute reference embeddings (multiple per logo for multi-ref matching)
    logger.info("Computing reference logo embeddings...")
    # Dict for multi-ref matching: logo_name -> list of embeddings
    multi_ref_embeddings: Dict[str, List[torch.Tensor]] = {}
    # List for margin-based matching: (logo_name, embedding) tuples
    reference_embeddings: List[Tuple[str, torch.Tensor]] = []
    total_refs = 0

    for logo_name, ref_filenames in tqdm(sampled_logos.items(), desc="Reference logos"):
        multi_ref_embeddings[logo_name] = []

        for ref_filename in ref_filenames:
            ref_path = reference_dir / ref_filename

            if not ref_path.exists():
                logger.warning(f"Reference logo not found: {ref_path}")
                continue

            # Check cache
            cache_key = f"ref:{ref_filename}"
            embedding = cache.get(cache_key) if cache else None

            if embedding is None:
                img = load_image(ref_path)
                if img is None:
                    logger.warning(f"Failed to load reference logo: {ref_path}")
                    continue
                embedding = detector.get_embedding(img)
                if cache:
                    cache.put(cache_key, embedding)

            multi_ref_embeddings[logo_name].append(embedding)
            reference_embeddings.append((logo_name, embedding))
            total_refs += 1

    logger.info(f"Computed {total_refs} embeddings for {len(sampled_logos)} logos")

    # Build test set: for each logo, sample positive and negative images
    logger.info(f"Sampling test images: {args.positive_samples} positive, {args.negative_samples} negative per logo...")
    test_image_set: Set[str] = set()
    test_image_expected: Dict[str, Set[str]] = {}  # image -> logos it should match

    # Use sampled_logos keys (unique logo names) instead of reference_embeddings
    for logo_name in sampled_logos.keys():
        # Get positive images (contain this logo)
        positive_images = list(logo_to_images.get(logo_name, set()))
        if len(positive_images) > args.positive_samples:
            positive_images = random.sample(positive_images, args.positive_samples)

        # Get negative images (do NOT contain this logo)
        negative_pool = list(all_test_images - logo_to_images.get(logo_name, set()))
        if len(negative_pool) > args.negative_samples:
            negative_images = random.sample(negative_pool, args.negative_samples)
        else:
            negative_images = negative_pool

        # Add to test set
        for img in positive_images:
            test_image_set.add(img)
            if img not in test_image_expected:
                test_image_expected[img] = set()
            test_image_expected[img].add(logo_name)

        for img in negative_images:
            test_image_set.add(img)
            if img not in test_image_expected:
                test_image_expected[img] = set()
            # Don't add logo_name - this is a negative sample

    test_images = list(test_image_set)
    logger.info(f"Selected {len(test_images)} unique test images")

    # Get set of reference logo names for quick lookup
    reference_logo_names = set(sampled_logos.keys())

    # Metrics
    true_positives = 0   # Correctly matched logos
    false_positives = 0  # Matched but wrong logo or no logo present
    false_negatives = 0  # Logo present but not detected/matched
    total_expected = 0   # Total logos we should have found

    # Detailed results for analysis
    results = []

    # Similarity distribution tracking (for --similarity-details)
    similarity_details = {
        "true_positive_sims": [],   # Similarities for correct matches
        "false_positive_sims": [],  # Similarities for wrong matches
        "missed_best_sims": [],     # Best similarity for logos that should have matched but didn't
        "all_positive_sims": [],    # All similarities between detected regions and correct logos
        "all_negative_sims": [],    # All similarities between detected regions and wrong logos
        "detection_details": [],    # Per-detection breakdown
    }

    # Process test images
    for test_filename in tqdm(test_images, desc="Testing"):
        test_path = test_images_dir / test_filename

        if not test_path.exists():
            logger.warning(f"Test image not found: {test_path}")
            continue

        # Get expected logos for this image (from our sampled set)
        expected_logos = test_image_expected.get(test_filename, set())
        total_expected += len(expected_logos)

        # Check cache for detections
        cache_key = f"det:{test_filename}"
        cached_detections = cache.get(cache_key) if cache else None

        if cached_detections is not None:
            # Cached detections contain serialized box data and embeddings
            detections = cached_detections
        else:
            # Load and detect
            img = load_image(test_path)
            if img is None:
                logger.warning(f"Failed to load test image: {test_path}")
                continue

            detections = detector.detect(img)

            # Cache the detections
            if cache:
                cache.put(cache_key, detections)

        # Match detections against references using selected method
        matched_logos: Set[str] = set()
        for det_idx, detection in enumerate(detections):
            # Compute similarities to all reference logos for detailed analysis
            if args.similarity_details:
                all_sims = {}
                for logo_name, ref_emb_list in multi_ref_embeddings.items():
                    sims = []
                    for ref_emb in ref_emb_list:
                        sim = detector.compare_embeddings(detection["embedding"], ref_emb)
                        sims.append(sim)
                    # Use mean or max based on setting
                    if args.use_max_similarity:
                        all_sims[logo_name] = max(sims) if sims else 0
                    else:
                        all_sims[logo_name] = sum(sims) / len(sims) if sims else 0

                    # Track positive vs negative similarities
                    for sim in sims:
                        if logo_name in expected_logos:
                            similarity_details["all_positive_sims"].append(sim)
                        else:
                            similarity_details["all_negative_sims"].append(sim)

                # Store detection details
                sorted_sims = sorted(all_sims.items(), key=lambda x: -x[1])
                similarity_details["detection_details"].append({
                    "image": test_filename,
                    "detection_idx": det_idx,
                    "expected_logos": list(expected_logos),
                    "top_5_matches": sorted_sims[:5],
                    "detr_score": detection.get("score", 0),
                })
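
            # Dispatch on the matching strategy. As a sketch of the semantics
            # assumed here (the implementations live in DetectLogosDETR):
            # "simple" accepts every reference with similarity >= threshold;
            # "margin" accepts only the best match, and only if
            # best_sim - second_best_sim >= margin; "multi-ref" aggregates
            # mean (or max) similarity across each logo's references and
            # requires at least --min-matching-refs of them above threshold.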
            if args.matching_method == "simple":
                # Simple matching: return ALL logos above threshold
                all_matches = detector.find_all_matches(
                    detection["embedding"],
                    reference_embeddings,
                    similarity_threshold=args.threshold,
                )
                for label, similarity in all_matches:
                    matched_logos.add(label)

                    # Check if this is a correct match
                    is_correct = label in expected_logos
                    if is_correct:
                        true_positives += 1
                        if args.similarity_details:
                            similarity_details["true_positive_sims"].append(similarity)
                    else:
                        false_positives += 1
                        if args.similarity_details:
                            similarity_details["false_positive_sims"].append(similarity)

                    results.append({
                        "test_image": test_filename,
                        "matched_logo": label,
                        "similarity": similarity,
                        "correct": is_correct,
                    })

            elif args.matching_method == "margin":
                # Margin-based matching: requires margin over second-best
                match_result = detector.find_best_match_with_margin(
                    detection["embedding"],
                    reference_embeddings,
                    similarity_threshold=args.threshold,
                    margin=args.margin,
                )
                if match_result:
                    label, similarity = match_result
                    matched_logos.add(label)

                    is_correct = label in expected_logos
                    if is_correct:
                        true_positives += 1
                        if args.similarity_details:
                            similarity_details["true_positive_sims"].append(similarity)
                    else:
                        false_positives += 1
                        if args.similarity_details:
                            similarity_details["false_positive_sims"].append(similarity)

                    results.append({
                        "test_image": test_filename,
                        "matched_logo": label,
                        "similarity": similarity,
                        "correct": is_correct,
                    })

            else:  # multi-ref
                # Multi-ref matching: aggregates scores across reference images
                match_result = detector.find_best_match_multi_ref(
                    detection["embedding"],
                    multi_ref_embeddings,
                    similarity_threshold=args.threshold,
                    min_matching_refs=args.min_matching_refs,
                    use_mean_similarity=not args.use_max_similarity,
                    margin=args.margin,
                )
                if match_result:
                    label, similarity, num_matching = match_result
                    matched_logos.add(label)

                    is_correct = label in expected_logos
                    if is_correct:
                        true_positives += 1
                        if args.similarity_details:
                            similarity_details["true_positive_sims"].append(similarity)
                    else:
                        false_positives += 1
                        if args.similarity_details:
                            similarity_details["false_positive_sims"].append(similarity)

                    results.append({
                        "test_image": test_filename,
                        "matched_logo": label,
                        "similarity": similarity,
                        "correct": is_correct,
                    })

        # Count missed detections (false negatives)
        missed = expected_logos - matched_logos
        false_negatives += len(missed)

        for missed_logo in missed:
            # Track best similarity for missed logos (if we have detections)
            if args.similarity_details and detections:
                best_sim_for_missed = 0
                for detection in detections:
                    for ref_emb in multi_ref_embeddings.get(missed_logo, []):
                        sim = detector.compare_embeddings(detection["embedding"], ref_emb)
                        best_sim_for_missed = max(best_sim_for_missed, sim)
                similarity_details["missed_best_sims"].append(best_sim_for_missed)

            results.append({
                "test_image": test_filename,
                "matched_logo": None,
                "expected_logo": missed_logo,
                "similarity": None,
                "correct": False,
            })

    # Save cache
    if cache:
        cache.save()
        logger.info(f"Saved {len(cache)} cache entries")

    # Calculate metrics
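    # Note on counting: true_positives and false_positives are incremented once
    # per matching *detection*, while total_expected counts expected logos per
    # image, so multiple detections matching the same logo in one image can
    # inflate the scores relative to a strict per-logo count.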
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / total_expected if total_expected > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    # Print results
    print("\n" + "=" * 60)
    print("LOGO DETECTION TEST RESULTS")
    print("=" * 60)
    print("\nConfiguration:")
    print(f"  Reference logos sampled:    {len(sampled_logos)}")
    print(f"  Refs per logo:              {args.refs_per_logo}")
    print(f"  Total reference embeddings: {total_refs}")
    print(f"  Positive samples per logo:  {args.positive_samples}")
    print(f"  Negative samples per logo:  {args.negative_samples}")
    print(f"  Test images processed:      {len(test_images)}")
    print(f"  CLIP similarity threshold:  {args.threshold}")
    print(f"  DETR confidence threshold:  {args.detr_threshold}")
    print(f"  Matching method:            {args.matching_method}")
    if args.matching_method in ("margin", "multi-ref"):
        print(f"  Matching margin:            {args.margin}")
    if args.matching_method == "multi-ref":
        print(f"  Min matching refs:          {args.min_matching_refs}")
        print(f"  Similarity aggregation:     {'max' if args.use_max_similarity else 'mean'}")
    if args.seed is not None:
        print(f"  Random seed:                {args.seed}")

    print("\nMetrics:")
    print(f"  True Positives (correct matches): {true_positives}")
    print(f"  False Positives (wrong matches):  {false_positives}")
    print(f"  False Negatives (missed logos):   {false_negatives}")
    print(f"  Total expected matches:           {total_expected}")

    print("\nScores:")
    print(f"  Precision: {precision:.4f} ({precision*100:.1f}%)")
    print(f"  Recall:    {recall:.4f} ({recall*100:.1f}%)")
    print(f"  F1 Score:  {f1:.4f} ({f1*100:.1f}%)")

    # Show some example false positives
    false_positive_examples = [r for r in results if r.get("matched_logo") and not r["correct"]]
    if false_positive_examples:
        print("\nExample False Positives (first 5):")
        for r in false_positive_examples[:5]:
            print(f"  - Image: {r['test_image']}")
            print(f"    Matched: {r['matched_logo']} (similarity: {r['similarity']:.3f})")

    # Show reference logos used (unique names)
    unique_logos = sorted(sampled_logos.keys())
    print(f"\nReference logos used ({len(unique_logos)}):")
    for name in unique_logos[:20]:
        print(f"  - {name}")
    if len(unique_logos) > 20:
        print(f"  ... and {len(unique_logos) - 20} more")

    print("=" * 60)

    # Print similarity distribution details if requested
    if args.similarity_details:
        print_similarity_details(similarity_details, args.threshold)

    # Write results to file if requested
    if args.output_file:
        write_results_to_file(
            output_path=Path(args.output_file),
            args=args,
            num_logos=len(sampled_logos),
            total_refs=total_refs,
            num_test_images=len(test_images),
            true_positives=true_positives,
            false_positives=false_positives,
            false_negatives=false_negatives,
            total_expected=total_expected,
            precision=precision,
            recall=recall,
            f1=f1,
        )
        print(f"\nResults appended to: {args.output_file}")


def print_similarity_details(details: dict, threshold: float):
    """Print detailed similarity distribution analysis."""
    import statistics

    print("\n" + "=" * 60)
    print("SIMILARITY DISTRIBUTION ANALYSIS")
    print("=" * 60)

    # Helper to compute stats
    def compute_stats(values, name):
        if not values:
            print(f"\n{name}: No data")
            return
        print(f"\n{name} (n={len(values)}):")
        print(f"  Min:    {min(values):.4f}")
        print(f"  Max:    {max(values):.4f}")
        print(f"  Mean:   {statistics.mean(values):.4f}")
        if len(values) > 1:
            print(f"  StdDev: {statistics.stdev(values):.4f}")
        print(f"  Median: {statistics.median(values):.4f}")
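
        # Percentiles below use simple nearest-rank indexing into the sorted
        # values (no interpolation); rough but adequate for this analysis.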
        # Percentiles
        sorted_vals = sorted(values)
        n = len(sorted_vals)
        p10 = sorted_vals[int(n * 0.10)] if n > 10 else sorted_vals[0]
        p25 = sorted_vals[int(n * 0.25)] if n > 4 else sorted_vals[0]
        p75 = sorted_vals[int(n * 0.75)] if n > 4 else sorted_vals[-1]
        p90 = sorted_vals[int(n * 0.90)] if n > 10 else sorted_vals[-1]
        print(f"  P10:    {p10:.4f}")
        print(f"  P25:    {p25:.4f}")
        print(f"  P75:    {p75:.4f}")
        print(f"  P90:    {p90:.4f}")

        # Count above/below threshold
        above = sum(1 for v in values if v >= threshold)
        below = sum(1 for v in values if v < threshold)
        print(f"  Above threshold ({threshold}): {above} ({100*above/len(values):.1f}%)")
        print(f"  Below threshold ({threshold}): {below} ({100*below/len(values):.1f}%)")

    # Print distribution stats
    compute_stats(details["true_positive_sims"], "TRUE POSITIVE similarities (correct matches)")
    compute_stats(details["false_positive_sims"], "FALSE POSITIVE similarities (wrong matches)")
    compute_stats(details["missed_best_sims"], "MISSED LOGO best similarities (false negatives)")
    compute_stats(details["all_positive_sims"], "ALL similarities to CORRECT logos (per-ref)")
    compute_stats(details["all_negative_sims"], "ALL similarities to WRONG logos (per-ref)")

    # Overlap analysis
    tp_sims = details["true_positive_sims"]
    fp_sims = details["false_positive_sims"]
    if tp_sims and fp_sims:
        print("\n" + "-" * 40)
        print("OVERLAP ANALYSIS:")
        tp_min, tp_max = min(tp_sims), max(tp_sims)
        fp_min, fp_max = min(fp_sims), max(fp_sims)
        print(f"  True Positives range:  [{tp_min:.4f}, {tp_max:.4f}]")
        print(f"  False Positives range: [{fp_min:.4f}, {fp_max:.4f}]")

        # Check overlap
        overlap_min = max(tp_min, fp_min)
        overlap_max = min(tp_max, fp_max)
        if overlap_min < overlap_max:
            print(f"  OVERLAP REGION: [{overlap_min:.4f}, {overlap_max:.4f}]")
            tp_in_overlap = sum(1 for v in tp_sims if overlap_min <= v <= overlap_max)
            fp_in_overlap = sum(1 for v in fp_sims if overlap_min <= v <= overlap_max)
            print(f"  TPs in overlap: {tp_in_overlap} ({100*tp_in_overlap/len(tp_sims):.1f}%)")
            print(f"  FPs in overlap: {fp_in_overlap} ({100*fp_in_overlap/len(fp_sims):.1f}%)")
        else:
            print("  NO OVERLAP - distributions are separable!")

        # Suggest optimal threshold
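        # The sweep below tries every observed similarity as a candidate
        # threshold and keeps the one maximizing F1 on the collected TP/FP
        # scores. It is O(n^2) in the number of scores, which is fine at this
        # scale; note the suggestion is fit to this run's data, so it should
        # be validated on a fresh sample.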
        all_points = [(s, "tp") for s in tp_sims] + [(s, "fp") for s in fp_sims]
        all_points.sort()
        best_thresh = threshold
        best_f1 = 0
        total_tp = len(tp_sims)
        total_fp = len(fp_sims)

        for thresh in [p[0] for p in all_points]:
            # At this threshold:
            tp_above = sum(1 for s in tp_sims if s >= thresh)
            fp_above = sum(1 for s in fp_sims if s >= thresh)
            prec = tp_above / (tp_above + fp_above) if (tp_above + fp_above) > 0 else 0
            rec = tp_above / total_tp if total_tp > 0 else 0
            f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0
            if f1 > best_f1:
                best_f1 = f1
                best_thresh = thresh

        print(f"\n  SUGGESTED OPTIMAL THRESHOLD: {best_thresh:.4f}")
        print(f"  (would give F1 = {best_f1:.4f} on this data)")

    # Print sample detection details
    det_details = details["detection_details"]
    if det_details:
        print("\n" + "-" * 40)
        print(f"SAMPLE DETECTION DETAILS (first 20 of {len(det_details)}):")
        for i, det in enumerate(det_details[:20]):
            expected = det["expected_logos"]
            top5 = det["top_5_matches"]
            print(f"\n  [{i+1}] Image: {det['image']}")
            print(f"      Expected: {expected if expected else '(none)'}")
            print(f"      DETR score: {det['detr_score']:.3f}")
            print("      Top 5 matches:")
            for logo, sim in top5:
                marker = " <-- CORRECT" if logo in expected else ""
                print(f"        {sim:.4f}  {logo}{marker}")

    print("\n" + "=" * 60)


def write_results_to_file(
    output_path: Path,
    args,
    num_logos: int,
    total_refs: int,
    num_test_images: int,
    true_positives: int,
    false_positives: int,
    false_negatives: int,
    total_expected: int,
    precision: float,
    recall: float,
    f1: float,
):
    """Write results summary to file with detailed header."""
    from datetime import datetime

    # Build method description for header
    if args.matching_method == "simple":
        method_desc = "Simple (all matches above threshold)"
    elif args.matching_method == "margin":
        method_desc = f"Margin-based (margin={args.margin})"
    else:  # multi-ref
        agg = "max" if args.use_max_similarity else "mean"
        method_desc = f"Multi-ref ({agg}, min_refs={args.min_matching_refs}, margin={args.margin})"

    lines = [
        "=" * 70,
        f"TEST: {args.matching_method.upper()} MATCHING",
        f"Model: {args.embedding_model}",
        f"Method: {method_desc}",
        "=" * 70,
        f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        "",
        "Configuration:",
        f"  Embedding model:            {args.embedding_model}",
        f"  Reference logos:            {num_logos}",
        f"  Refs per logo:              {args.refs_per_logo}",
        f"  Total reference embeddings: {total_refs}",
        f"  Positive samples/logo:      {args.positive_samples}",
        f"  Negative samples/logo:      {args.negative_samples}",
        f"  Test images processed:      {num_test_images}",
        f"  Similarity threshold:       {args.threshold}",
        f"  DETR threshold:             {args.detr_threshold}",
    ]

    if args.seed is not None:
        lines.append(f"  Random seed:                {args.seed}")

    lines.extend([
        "",
        "Results:",
        f"  True Positives:  {true_positives:>6}",
        f"  False Positives: {false_positives:>6}",
        f"  False Negatives: {false_negatives:>6}",
        f"  Total Expected:  {total_expected:>6}",
        "",
        "Scores:",
        f"  Precision: {precision:.4f} ({precision*100:.1f}%)",
        f"  Recall:    {recall:.4f} ({recall*100:.1f}%)",
        f"  F1 Score:  {f1:.4f} ({f1*100:.1f}%)",
        "",
        "",
    ])

    # Append to file
    with open(output_path, "a") as f:
        f.write("\n".join(lines))


if __name__ == "__main__":
    main()