Files
logo_test/test_logo_detection.py
Rick McEwen ea6fcec9ce Remove hybrid text+CLIP matching approach
The hybrid approach combined OCR text recognition with CLIP embeddings
to improve logo matching accuracy. After extensive testing, the approach
was abandoned because:

1. OCR quality on small logo crops is unreliable
2. Text filtering rejected correct matches as often as wrong ones
3. Best hybrid result (57.1% precision) was similar to baseline (55.1%)
4. Recall dropped significantly (52.6% vs 59.6%)
5. Added complexity (EasyOCR dependency, extra parameters) wasn't justified

Removed:
- Hybrid matching methods from DetectLogosDETR class
- Text extraction and similarity methods
- Hybrid test scripts and text_recognition.py module
- Hybrid-related CLI arguments from test_logo_detection.py

The baseline multi-ref matching with 0.70 threshold remains the
recommended approach for logo detection.
2026-01-08 12:48:39 -05:00

886 lines
33 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Test script for logo detection accuracy.
This script:
1. Randomly samples N reference logos from the database
2. Processes all test images through the DETR+CLIP pipeline
3. Compares detected logos against reference embeddings
4. Reports accuracy metrics (correct matches, false positives, missed detections)
Embeddings are cached to avoid reprocessing images.
"""
import argparse
import logging
import pickle
import random
import sqlite3
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
import cv2
import torch
from tqdm import tqdm
from logo_detection_detr import DetectLogosDETR
def setup_logging(verbose: bool = False) -> logging.Logger:
    """Configure root logging via ``logging.basicConfig`` and return a logger.

    Args:
        verbose: If True the configured level is DEBUG, otherwise INFO.

    Returns:
        The logger named after this module.
    """
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%H:%M:%S",
        level=(logging.DEBUG if verbose else logging.INFO),
    )
    return logging.getLogger(__name__)
def load_image(image_path: Path) -> Optional[cv2.Mat]:
    """Read an image from disk with OpenCV.

    Args:
        image_path: Location of the image file.

    Returns:
        The decoded image, or None when OpenCV cannot read the file
        (``cv2.imread`` reports failure by returning None).
    """
    return cv2.imread(str(image_path))
class EmbeddingCache:
    """Pickle-backed key/value cache for embeddings and detection results.

    Entries are held in memory in ``self.cache`` and written to ``cache_path``
    only when :meth:`save` is called. The file is read back with
    ``pickle.load``, so it must be a trusted, locally generated file; never
    point this at data from an untrusted source.
    """

    def __init__(self, cache_path: Path):
        self.cache_path = cache_path
        self.cache: Dict[str, Any] = {}
        self._load()

    def _load(self):
        """Populate the cache from disk; start empty if the file is missing or unreadable."""
        if not self.cache_path.exists():
            return
        try:
            with open(self.cache_path, "rb") as f:
                loaded = pickle.load(f)
        except Exception as exc:
            # Best-effort: pickle.load can raise many exception types for a
            # truncated or incompatible file. The cache is simply rebuilt,
            # but the failure is no longer silent.
            logging.getLogger(__name__).warning(
                "Ignoring unreadable embedding cache %s: %s", self.cache_path, exc
            )
            return
        if isinstance(loaded, dict):
            self.cache = loaded
        else:
            logging.getLogger(__name__).warning(
                "Ignoring embedding cache %s: expected a dict, got %s",
                self.cache_path,
                type(loaded).__name__,
            )

    def save(self):
        """Write the cache to disk via a temporary file and a rename."""
        self.cache_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.cache_path.with_name(self.cache_path.name + ".tmp")
        with open(tmp_path, "wb") as f:
            pickle.dump(self.cache, f)
        # Renaming over the target means an interrupted dump never leaves a
        # truncated cache file in place of the previous good one.
        tmp_path.replace(self.cache_path)

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None if absent."""
        return self.cache.get(key)

    def put(self, key: str, embedding: Any):
        """Store *embedding* under *key*, moving any tensors it contains to CPU.

        *embedding* may be a single tensor or a nested list/tuple/dict
        structure (e.g. a list of detection dicts holding tensors). Tensors
        are copied to CPU to save GPU memory; other values are stored as-is.
        """
        self.cache[key] = self._to_cpu(embedding)

    @staticmethod
    def _to_cpu(value: Any) -> Any:
        """Recursively return a copy of *value* with every ``.cpu()``-able object moved to CPU."""
        cpu = getattr(value, "cpu", None)
        if callable(cpu):
            return cpu()
        if isinstance(value, dict):
            return {k: EmbeddingCache._to_cpu(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            converted = [EmbeddingCache._to_cpu(v) for v in value]
            if isinstance(value, list):
                return converted
            if hasattr(value, "_fields"):  # namedtuple: positional constructor
                return type(value)(*converted)
            return tuple(converted)
        return value

    def __len__(self):
        return len(self.cache)

    def __bool__(self):
        # Always truthy, even when empty, so `if cache:` means "caching is
        # enabled" rather than "cache has entries". Without this, __len__
        # made a fresh cache falsy and it was never filled or saved.
        return True
def get_ground_truth(db_path: Path) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]:
    """
    Load ground truth from database.

    Each row of ``reference_logos`` links a logo name to the test image it
    belongs to; both directions of that relation are returned.

    Args:
        db_path: Path to the SQLite mapping database.

    Returns:
        Tuple of:
        - Dict mapping test image filename to set of logo names it contains
        - Dict mapping logo name to set of test image filenames containing it
        Test images with no reference_logos rows do not appear in either dict.
    """
    query = """
        SELECT ti.filename, ln.name
        FROM test_images ti
        JOIN reference_logos rl ON ti.id = rl.test_image_id
        JOIN logo_names ln ON rl.logo_name_id = ln.id
    """
    image_to_logos: Dict[str, Set[str]] = {}
    logo_to_images: Dict[str, Set[str]] = {}
    conn = sqlite3.connect(db_path)
    try:
        for test_filename, logo_name in conn.execute(query):
            image_to_logos.setdefault(test_filename, set()).add(logo_name)
            logo_to_images.setdefault(logo_name, set()).add(test_filename)
    finally:
        # Close even if the query fails (sqlite3's connection context
        # manager only commits/rolls back; it does not close).
        conn.close()
    return image_to_logos, logo_to_images
def sample_reference_logos(
    db_path: Path, num_logos: int, refs_per_logo: int = 1, seed: Optional[int] = None
) -> Dict[str, List[str]]:
    """
    Randomly sample reference logos from database with multiple refs per logo.

    Args:
        db_path: Path to database
        num_logos: Number of logos to sample; all logos are used when this is
            >= the number of logos in the database
        refs_per_logo: Number of reference images per logo (all are used
            when fewer are available)
        seed: Random seed for reproducibility. Seeds the *global* ``random``
            module, so later ``random`` calls by the caller are reproducible too.

    Returns:
        Dict mapping logo_name to list of reference filenames (possibly empty
        for a logo with no reference_logos rows)
    """
    if seed is not None:
        random.seed(seed)
    result: Dict[str, List[str]] = {}
    conn = sqlite3.connect(db_path)
    try:
        # ORDER BY: SQLite row order is undefined without it, and random.sample
        # over an unordered population is not reproducible even with a seed.
        all_logo_names = conn.execute(
            "SELECT id, name FROM logo_names ORDER BY id"
        ).fetchall()
        if num_logos >= len(all_logo_names):
            sampled_logos = all_logo_names
        else:
            sampled_logos = random.sample(all_logo_names, num_logos)
        for logo_id, logo_name in sampled_logos:
            all_refs = [
                row[0]
                for row in conn.execute(
                    "SELECT filename FROM reference_logos WHERE logo_name_id = ? ORDER BY filename",
                    (logo_id,),
                )
            ]
            if len(all_refs) > refs_per_logo:
                result[logo_name] = random.sample(all_refs, refs_per_logo)
            else:
                result[logo_name] = all_refs
    finally:
        conn.close()
    return result
def get_test_images(db_path: Path) -> List[str]:
    """Return all test image filenames from the database, sorted by filename.

    Sorting gives a deterministic order; SQLite's row order is undefined
    without ORDER BY.
    """
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute("SELECT filename FROM test_images ORDER BY filename")
        return [filename for (filename,) in rows]
    finally:
        conn.close()
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser used by main()."""
    parser = argparse.ArgumentParser(
        description="Test logo detection accuracy against ground truth"
    )
    parser.add_argument(
        "-n", "--num-logos",
        type=int,
        default=10,
        help="Number of reference logos to sample (default: 10)",
    )
    parser.add_argument(
        "-t", "--threshold",
        type=float,
        default=0.7,
        help="CLIP similarity threshold for matching (default: 0.7)",
    )
    parser.add_argument(
        "-d", "--detr-threshold",
        type=float,
        default=0.5,
        help="DETR detection confidence threshold (default: 0.5)",
    )
    parser.add_argument(
        "-e", "--embedding-model",
        type=str,
        default="openai/clip-vit-large-patch14",
        help="Embedding model for feature extraction (default: openai/clip-vit-large-patch14). "
        "Supports CLIP models (openai/clip-*) and DINOv2 models (facebook/dinov2-*)",
    )
    parser.add_argument(
        "-s", "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility",
    )
    parser.add_argument(
        "--positive-samples",
        type=int,
        default=5,
        help="Number of positive test images per logo (images containing the logo) (default: 5)",
    )
    parser.add_argument(
        "--negative-samples",
        type=int,
        default=20,
        help="Number of negative test images per logo (images NOT containing the logo) (default: 20)",
    )
    parser.add_argument(
        "--refs-per-logo",
        type=int,
        default=3,
        help="Number of reference images per logo for multi-ref matching (default: 3)",
    )
    parser.add_argument(
        "--margin",
        type=float,
        default=0.05,
        help="Required margin between best and second-best match "
        "(applies to 'margin' and 'multi-ref' methods) (default: 0.05)",
    )
    parser.add_argument(
        "--matching-method",
        type=str,
        choices=["simple", "margin", "multi-ref"],
        default="margin",
        help="Matching method: 'simple' returns all matches above threshold, "
        "'margin' requires confidence margin over 2nd best, "
        "'multi-ref' aggregates scores across reference images (default: margin)",
    )
    parser.add_argument(
        "--min-matching-refs",
        type=int,
        default=1,
        help="For 'multi-ref' method: minimum references that must match above threshold (default: 1)",
    )
    parser.add_argument(
        "--use-max-similarity",
        action="store_true",
        help="For 'multi-ref' method: use max similarity instead of mean across references",
    )
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Enable verbose logging",
    )
    parser.add_argument(
        "--similarity-details",
        action="store_true",
        help="Output detailed similarity scores for each detection (for analyzing score distributions)",
    )
    parser.add_argument(
        "--no-cache",
        action="store_true",
        help="Disable embedding cache",
    )
    parser.add_argument(
        "--clear-cache",
        action="store_true",
        help="Clear embedding cache before running",
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default=None,
        help="Append results summary to this file (no progress output, just results)",
    )
    parser.add_argument(
        "--preprocess-mode",
        type=str,
        choices=["default", "letterbox", "stretch"],
        default="default",
        help="Image preprocessing mode for CLIP: 'default' (resize+center crop), "
        "'letterbox' (pad to square with black bars), 'stretch' (distort to square)",
    )
    return parser


def _compute_reference_embeddings(
    sampled_logos: Dict[str, List[str]],
    reference_dir: Path,
    detector: Any,
    cache: Any,
    cache_tag: str,
    logger: logging.Logger,
) -> Tuple[Dict[str, List[torch.Tensor]], List[Tuple[str, torch.Tensor]]]:
    """Embed every sampled reference image, using the cache when possible.

    Returns:
        (multi_ref_embeddings, reference_embeddings): a dict logo_name -> list
        of embeddings (for multi-ref matching) and a flat list of
        (logo_name, embedding) pairs (for simple/margin matching).
    """
    multi_ref_embeddings: Dict[str, List[torch.Tensor]] = {}
    reference_embeddings: List[Tuple[str, torch.Tensor]] = []
    for logo_name, ref_filenames in tqdm(sampled_logos.items(), desc="Reference logos"):
        logo_embeddings: List[torch.Tensor] = []
        multi_ref_embeddings[logo_name] = logo_embeddings
        for ref_filename in ref_filenames:
            ref_path = reference_dir / ref_filename
            if not ref_path.exists():
                logger.warning(f"Reference logo not found: {ref_path}")
                continue
            cache_key = f"ref:{cache_tag}:{ref_filename}"
            embedding = cache.get(cache_key) if cache else None
            if embedding is None:
                img = load_image(ref_path)
                if img is None:
                    logger.warning(f"Failed to load reference logo: {ref_path}")
                    continue
                embedding = detector.get_embedding(img)
                if cache:
                    cache.put(cache_key, embedding)
            logo_embeddings.append(embedding)
            reference_embeddings.append((logo_name, embedding))
        if not logo_embeddings:
            # Such a logo can never be matched; its positives all become misses.
            logger.warning(f"No usable reference images for logo: {logo_name}")
    return multi_ref_embeddings, reference_embeddings


def _sample_test_images(
    sampled_logos: Dict[str, List[str]],
    image_to_logos: Dict[str, Set[str]],
    logo_to_images: Dict[str, Set[str]],
    all_test_images: Set[str],
    positive_samples: int,
    negative_samples: int,
) -> Dict[str, Set[str]]:
    """Pick positive and negative test images for each sampled logo.

    Returns:
        Dict (in selection order) mapping each chosen test image to the set
        of *sampled* logos the ground truth says it contains. This is the full
        ground truth, not only the logo it was drawn for. An image drawn as a
        negative for logo A may still contain sampled logo B, and a correct
        detection of B there must not be scored as a false positive.
    """
    selected: Dict[str, None] = {}  # insertion-ordered set
    for logo_name in sampled_logos:
        containing = logo_to_images.get(logo_name, set())
        # Sort before sampling: iteration order of a set of str depends on
        # per-process hash randomization, so sampling straight from a set is
        # not reproducible even with --seed.
        positive_images = sorted(containing)
        if len(positive_images) > positive_samples:
            positive_images = random.sample(positive_images, positive_samples)
        negative_pool = sorted(all_test_images - containing)
        if len(negative_pool) > negative_samples:
            negative_images = random.sample(negative_pool, negative_samples)
        else:
            negative_images = negative_pool
        selected.update(dict.fromkeys(positive_images))
        selected.update(dict.fromkeys(negative_images))
    reference_logo_names = set(sampled_logos)
    return {
        img: image_to_logos.get(img, set()) & reference_logo_names
        for img in selected
    }


def _record_similarity_details(
    similarity_details: Dict[str, list],
    detector: Any,
    detection: Dict[str, Any],
    det_idx: int,
    test_filename: str,
    expected_logos: Set[str],
    multi_ref_embeddings: Dict[str, List[torch.Tensor]],
    use_max_similarity: bool,
) -> None:
    """Append one detection's similarity scores to *similarity_details* (--similarity-details)."""
    aggregated: Dict[str, Any] = {}
    for logo_name, ref_emb_list in multi_ref_embeddings.items():
        sims = [
            detector.compare_embeddings(detection["embedding"], ref_emb)
            for ref_emb in ref_emb_list
        ]
        if not sims:
            aggregated[logo_name] = 0
        elif use_max_similarity:
            aggregated[logo_name] = max(sims)
        else:
            aggregated[logo_name] = sum(sims) / len(sims)
        bucket = "all_positive_sims" if logo_name in expected_logos else "all_negative_sims"
        similarity_details[bucket].extend(sims)
    top_matches = sorted(aggregated.items(), key=lambda item: -item[1])
    similarity_details["detection_details"].append({
        "image": test_filename,
        "detection_idx": det_idx,
        "expected_logos": list(expected_logos),
        "top_5_matches": top_matches[:5],
        "detr_score": detection.get("score", 0),
    })


def _match_detection(
    detector: Any,
    embedding: Any,
    args: argparse.Namespace,
    reference_embeddings: List[Tuple[str, torch.Tensor]],
    multi_ref_embeddings: Dict[str, List[torch.Tensor]],
) -> List[Tuple[str, Any]]:
    """Return the (label, similarity) matches for one detection using args.matching_method."""
    method = args.matching_method
    if method == "simple":
        # Every reference above threshold counts as a match.
        return [
            (label, similarity)
            for label, similarity in detector.find_all_matches(
                embedding,
                reference_embeddings,
                similarity_threshold=args.threshold,
            )
        ]
    if method == "margin":
        # Best match only, and only if it beats the runner-up by the margin.
        match_result = detector.find_best_match_with_margin(
            embedding,
            reference_embeddings,
            similarity_threshold=args.threshold,
            margin=args.margin,
        )
        if not match_result:
            return []
        label, similarity = match_result
        return [(label, similarity)]
    if method == "multi-ref":
        # Scores aggregated across each logo's reference images.
        match_result = detector.find_best_match_multi_ref(
            embedding,
            multi_ref_embeddings,
            similarity_threshold=args.threshold,
            min_matching_refs=args.min_matching_refs,
            use_mean_similarity=not args.use_max_similarity,
            margin=args.margin,
        )
        if not match_result:
            return []
        label, similarity, _num_matching = match_result
        return [(label, similarity)]
    raise ValueError(f"Unknown matching method: {method}")


def main():
    """Run the logo-detection accuracy test end to end.

    Parses CLI arguments, samples reference logos and test images from
    ``test_data_mapping.db``, runs the detector on each test image, matches
    detections with the selected method, then prints metrics (and optionally
    appends them to ``--output-file``).

    Metric definitions:
      * precision: correct matches / all matches, counted per detection;
      * recall: expected (image, logo) pairs matched at least once, divided
        by all expected pairs. Counting per-detection TPs here instead could
        push recall above 1.0 when several boxes match the same logo.
      * expected logos of an image: every sampled logo the ground truth says
        it contains.
    Results are therefore not directly comparable with runs made before these
    definitions were fixed.
    """
    args = _build_arg_parser().parse_args()
    logger = setup_logging(args.verbose)

    # Paths
    base_dir = Path(__file__).resolve().parent
    db_path = base_dir / "test_data_mapping.db"
    reference_dir = base_dir / "reference_logos"
    test_images_dir = base_dir / "test_images"
    cache_path = base_dir / ".embedding_cache.pkl"

    if not db_path.exists():
        logger.error(f"Database not found: {db_path}")
        logger.error("Run prepare_test_data.py first to create the database.")
        sys.exit(1)

    if args.clear_cache and cache_path.exists():
        cache_path.unlink()
        logger.info("Cleared embedding cache")

    cache = EmbeddingCache(cache_path) if not args.no_cache else None
    if cache:
        logger.info(f"Loaded {len(cache)} cached embeddings")

    # Cache keys must identify the configuration that produced each entry.
    # Embeddings from different models or preprocessing modes are not
    # comparable (they can even differ in dimension), and detections also
    # depend on the DETR threshold. Untagged keys silently reused stale
    # entries after changing --embedding-model / --preprocess-mode /
    # --detr-threshold.
    ref_cache_tag = f"{args.embedding_model}|{args.preprocess_mode}"
    det_cache_tag = f"{ref_cache_tag}|detr={args.detr_threshold}"

    logger.info(f"Initializing logo detector with embedding model: {args.embedding_model}")
    if args.preprocess_mode != "default":
        logger.info(f"Using preprocessing mode: {args.preprocess_mode}")
    detector = DetectLogosDETR(
        logger=logger,
        detr_threshold=args.detr_threshold,
        embedding_model=args.embedding_model,
        preprocess_mode=args.preprocess_mode,
    )

    logger.info("Loading ground truth from database...")
    image_to_logos, logo_to_images = get_ground_truth(db_path)
    all_test_images = set(image_to_logos.keys())
    logger.info(f"Loaded ground truth for {len(image_to_logos)} test images")

    logger.info(f"Sampling {args.num_logos} reference logos with {args.refs_per_logo} refs each...")
    sampled_logos = sample_reference_logos(db_path, args.num_logos, args.refs_per_logo, args.seed)
    logger.info(f"Selected {len(sampled_logos)} reference logos")

    logger.info("Computing reference logo embeddings...")
    multi_ref_embeddings, reference_embeddings = _compute_reference_embeddings(
        sampled_logos, reference_dir, detector, cache, ref_cache_tag, logger
    )
    total_refs = len(reference_embeddings)
    logger.info(f"Computed {total_refs} embeddings for {len(sampled_logos)} logos")

    logger.info(f"Sampling test images: {args.positive_samples} positive, {args.negative_samples} negative per logo...")
    test_image_expected = _sample_test_images(
        sampled_logos,
        image_to_logos,
        logo_to_images,
        all_test_images,
        args.positive_samples,
        args.negative_samples,
    )
    test_images = list(test_image_expected)
    logger.info(f"Selected {len(test_images)} unique test images")

    # Metrics
    true_positives = 0  # Correct matches (per detection)
    false_positives = 0  # Matched a logo the image does not contain
    false_negatives = 0  # Expected logo never matched in its image
    total_expected = 0  # Expected (image, logo) pairs over evaluated images

    results = []
    similarity_details = {
        "true_positive_sims": [],  # Similarities for correct matches
        "false_positive_sims": [],  # Similarities for wrong matches
        "missed_best_sims": [],  # Best similarity for logos that should have matched but didn't
        "all_positive_sims": [],  # All similarities between detected regions and correct logos
        "all_negative_sims": [],  # All similarities between detected regions and wrong logos
        "detection_details": [],  # Per-detection breakdown
    }

    for test_filename in tqdm(test_images, desc="Testing"):
        test_path = test_images_dir / test_filename
        if not test_path.exists():
            logger.warning(f"Test image not found: {test_path}")
            continue
        expected_logos = test_image_expected.get(test_filename, set())

        det_cache_key = f"det:{det_cache_tag}:{test_filename}"
        # NOTE(review): cached entries hold CPU tensors while fresh ones stay
        # wherever the detector produced them; this assumes
        # detector.compare_embeddings tolerates mixed devices. Confirm.
        detections = cache.get(det_cache_key) if cache else None
        if detections is None:
            test_img = load_image(test_path)
            if test_img is None:
                logger.warning(f"Failed to load test image: {test_path}")
                continue
            detections = detector.detect(test_img)
            if cache:
                cache.put(det_cache_key, detections)

        # Count expectations only for images actually evaluated, so that
        # (found + missed) == total_expected.
        total_expected += len(expected_logos)

        matched_logos: Set[str] = set()
        for det_idx, detection in enumerate(detections):
            if args.similarity_details:
                _record_similarity_details(
                    similarity_details,
                    detector,
                    detection,
                    det_idx,
                    test_filename,
                    expected_logos,
                    multi_ref_embeddings,
                    args.use_max_similarity,
                )
            matches = _match_detection(
                detector,
                detection["embedding"],
                args,
                reference_embeddings,
                multi_ref_embeddings,
            )
            for label, similarity in matches:
                matched_logos.add(label)
                is_correct = label in expected_logos
                if is_correct:
                    true_positives += 1
                else:
                    false_positives += 1
                if args.similarity_details:
                    bucket = "true_positive_sims" if is_correct else "false_positive_sims"
                    similarity_details[bucket].append(similarity)
                results.append({
                    "test_image": test_filename,
                    "matched_logo": label,
                    "similarity": similarity,
                    "correct": is_correct,
                })

        missed = expected_logos - matched_logos
        false_negatives += len(missed)
        for missed_logo in missed:
            # Track best similarity for missed logos (if we have detections)
            if args.similarity_details and detections:
                best_sim_for_missed = 0
                for detection in detections:
                    for ref_emb in multi_ref_embeddings.get(missed_logo, []):
                        sim = detector.compare_embeddings(detection["embedding"], ref_emb)
                        best_sim_for_missed = max(best_sim_for_missed, sim)
                similarity_details["missed_best_sims"].append(best_sim_for_missed)
            results.append({
                "test_image": test_filename,
                "matched_logo": None,
                "expected_logo": missed_logo,
                "similarity": None,
                "correct": False,
            })

    if cache:
        cache.save()
        logger.info(f"Saved {len(cache)} embeddings to cache")

    # Calculate metrics
    found_expected = total_expected - false_negatives
    num_matches = true_positives + false_positives
    precision = true_positives / num_matches if num_matches > 0 else 0
    recall = found_expected / total_expected if total_expected > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    # Print results
    print("\n" + "=" * 60)
    print("LOGO DETECTION TEST RESULTS")
    print("=" * 60)
    print(f"\nConfiguration:")
    print(f" Reference logos sampled: {len(sampled_logos)}")
    print(f" Refs per logo: {args.refs_per_logo}")
    print(f" Total reference embeddings:{total_refs}")
    print(f" Positive samples per logo: {args.positive_samples}")
    print(f" Negative samples per logo: {args.negative_samples}")
    print(f" Test images processed: {len(test_images)}")
    print(f" CLIP similarity threshold: {args.threshold}")
    print(f" DETR confidence threshold: {args.detr_threshold}")
    print(f" Preprocess mode: {args.preprocess_mode}")
    print(f" Matching method: {args.matching_method}")
    if args.matching_method in ("margin", "multi-ref"):
        print(f" Matching margin: {args.margin}")
    if args.matching_method == "multi-ref":
        print(f" Min matching refs: {args.min_matching_refs}")
        print(f" Similarity aggregation: {'max' if args.use_max_similarity else 'mean'}")
    if args.seed is not None:
        print(f" Random seed: {args.seed}")
    print(f"\nMetrics:")
    print(f" True Positives (correct matches): {true_positives}")
    print(f" False Positives (wrong matches): {false_positives}")
    print(f" False Negatives (missed logos): {false_negatives}")
    print(f" Total expected matches: {total_expected}")
    print(f" Expected logos found: {found_expected}")
    print(f"\nScores:")
    print(f" Precision: {precision:.4f} ({precision*100:.1f}%)")
    print(f" Recall: {recall:.4f} ({recall*100:.1f}%)")
    print(f" F1 Score: {f1:.4f} ({f1*100:.1f}%)")

    false_positive_examples = [r for r in results if r.get("matched_logo") and not r["correct"]]
    if false_positive_examples:
        print(f"\nExample False Positives (first 5):")
        for r in false_positive_examples[:5]:
            print(f" - Image: {r['test_image']}")
            print(f" Matched: {r['matched_logo']} (similarity: {r['similarity']:.3f})")

    unique_logos = sorted(sampled_logos.keys())
    print(f"\nReference logos used ({len(unique_logos)}):")
    for name in unique_logos[:20]:
        print(f" - {name}")
    if len(unique_logos) > 20:
        print(f" ... and {len(unique_logos) - 20} more")
    print("=" * 60)

    if args.similarity_details:
        print_similarity_details(similarity_details, args.threshold)

    if args.output_file:
        write_results_to_file(
            output_path=Path(args.output_file),
            args=args,
            num_logos=len(sampled_logos),
            total_refs=total_refs,
            num_test_images=len(test_images),
            true_positives=true_positives,
            false_positives=false_positives,
            false_negatives=false_negatives,
            total_expected=total_expected,
            precision=precision,
            recall=recall,
            f1=f1,
        )
        print(f"\nResults appended to: {args.output_file}")
def print_similarity_details(details: dict, threshold: float):
    """Print a distribution analysis of collected similarity scores.

    Args:
        details: Mapping with list values under "true_positive_sims",
            "false_positive_sims", "missed_best_sims", "all_positive_sims",
            "all_negative_sims" (numeric scores) and "detection_details"
            (dicts with "image", "expected_logos", "top_5_matches" and
            "detr_score" keys).
        threshold: Similarity threshold of the run. It is used for the
            above/below counts and as the fallback suggested threshold.
    """
    import bisect
    import statistics

    print("\n" + "=" * 60)
    print("SIMILARITY DISTRIBUTION ANALYSIS")
    print("=" * 60)

    def compute_stats(values, name):
        """Print summary statistics for one list of scores (or 'No data')."""
        if not values:
            print(f"\n{name}: No data")
            return
        print(f"\n{name} (n={len(values)}):")
        print(f" Min: {min(values):.4f}")
        print(f" Max: {max(values):.4f}")
        print(f" Mean: {statistics.mean(values):.4f}")
        if len(values) > 1:
            # statistics.stdev needs at least two data points.
            print(f" StdDev: {statistics.stdev(values):.4f}")
        print(f" Median: {statistics.median(values):.4f}")
        # Index-based percentiles; small samples fall back to min/max.
        sorted_vals = sorted(values)
        n = len(sorted_vals)
        p10 = sorted_vals[int(n * 0.10)] if n > 10 else sorted_vals[0]
        p25 = sorted_vals[int(n * 0.25)] if n > 4 else sorted_vals[0]
        p75 = sorted_vals[int(n * 0.75)] if n > 4 else sorted_vals[-1]
        p90 = sorted_vals[int(n * 0.90)] if n > 10 else sorted_vals[-1]
        print(f" P10: {p10:.4f}")
        print(f" P25: {p25:.4f}")
        print(f" P75: {p75:.4f}")
        print(f" P90: {p90:.4f}")
        above = sum(1 for v in values if v >= threshold)
        below = sum(1 for v in values if v < threshold)
        print(f" Above threshold ({threshold}): {above} ({100*above/len(values):.1f}%)")
        print(f" Below threshold ({threshold}): {below} ({100*below/len(values):.1f}%)")

    compute_stats(details["true_positive_sims"], "TRUE POSITIVE similarities (correct matches)")
    compute_stats(details["false_positive_sims"], "FALSE POSITIVE similarities (wrong matches)")
    compute_stats(details["missed_best_sims"], "MISSED LOGO best similarities (false negatives)")
    compute_stats(details["all_positive_sims"], "ALL similarities to CORRECT logos (per-ref)")
    compute_stats(details["all_negative_sims"], "ALL similarities to WRONG logos (per-ref)")

    tp_sims = details["true_positive_sims"]
    fp_sims = details["false_positive_sims"]
    if tp_sims and fp_sims:
        print("\n" + "-" * 40)
        print("OVERLAP ANALYSIS:")
        tp_min, tp_max = min(tp_sims), max(tp_sims)
        fp_min, fp_max = min(fp_sims), max(fp_sims)
        print(f" True Positives range: [{tp_min:.4f}, {tp_max:.4f}]")
        print(f" False Positives range: [{fp_min:.4f}, {fp_max:.4f}]")
        overlap_min = max(tp_min, fp_min)
        overlap_max = min(tp_max, fp_max)
        # "<=": ranges that touch at one value are NOT separable by a ">="
        # threshold; that value would admit both a TP and an FP.
        if overlap_min <= overlap_max:
            print(f" OVERLAP REGION: [{overlap_min:.4f}, {overlap_max:.4f}]")
            tp_in_overlap = sum(1 for v in tp_sims if overlap_min <= v <= overlap_max)
            fp_in_overlap = sum(1 for v in fp_sims if overlap_min <= v <= overlap_max)
            print(f" TPs in overlap: {tp_in_overlap} ({100*tp_in_overlap/len(tp_sims):.1f}%)")
            print(f" FPs in overlap: {fp_in_overlap} ({100*fp_in_overlap/len(fp_sims):.1f}%)")
        else:
            print(" NO OVERLAP - distributions are separable!")

        # Suggest a threshold maximizing F1 over the observed match scores.
        # Recall here is relative to the TP scores seen, not to all expected
        # logos (misses never reached this list).
        sorted_tp = sorted(tp_sims)
        sorted_fp = sorted(fp_sims)
        total_tp = len(sorted_tp)
        total_fp = len(sorted_fp)
        best_thresh = threshold
        best_f1 = 0
        # Ascending candidates plus strict ">" keep the lowest threshold that
        # reaches the best F1. Binary search makes each count O(log n).
        for thresh in sorted(tp_sims + fp_sims):
            tp_above = total_tp - bisect.bisect_left(sorted_tp, thresh)
            fp_above = total_fp - bisect.bisect_left(sorted_fp, thresh)
            prec = tp_above / (tp_above + fp_above) if (tp_above + fp_above) > 0 else 0
            rec = tp_above / total_tp if total_tp > 0 else 0
            f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0
            if f1 > best_f1:
                best_f1 = f1
                best_thresh = thresh
        print(f"\n SUGGESTED OPTIMAL THRESHOLD: {best_thresh:.4f}")
        print(f" (would give F1 = {best_f1:.4f} on this data)")

    det_details = details["detection_details"]
    if det_details:
        print("\n" + "-" * 40)
        print(f"SAMPLE DETECTION DETAILS (first 20 of {len(det_details)}):")
        for i, det in enumerate(det_details[:20]):
            expected = det["expected_logos"]
            top5 = det["top_5_matches"]
            print(f"\n [{i+1}] Image: {det['image']}")
            print(f" Expected: {expected if expected else '(none)'}")
            print(f" DETR score: {det['detr_score']:.3f}")
            print(f" Top 5 matches:")
            for logo, sim in top5:
                marker = " <-- CORRECT" if logo in expected else ""
                print(f" {sim:.4f} {logo}{marker}")
    print("\n" + "=" * 60)
def write_results_to_file(
    output_path: Path,
    args,
    num_logos: int,
    total_refs: int,
    num_test_images: int,
    true_positives: int,
    false_positives: int,
    false_negatives: int,
    total_expected: int,
    precision: float,
    recall: float,
    f1: float,
):
    """Append a results summary for one run, with a descriptive header, to *output_path*.

    Args:
        output_path: File to append to; created (with parent dirs) if missing.
        args: Parsed CLI namespace. Reads matching_method, margin,
            use_max_similarity, min_matching_refs, embedding_model,
            preprocess_mode, refs_per_logo, positive_samples,
            negative_samples, threshold, detr_threshold and seed.
        num_logos .. f1: Run statistics written verbatim into the summary.
    """
    from datetime import datetime

    if args.matching_method == "simple":
        method_desc = "Simple (all matches above threshold)"
    elif args.matching_method == "margin":
        method_desc = f"Margin-based (margin={args.margin})"
    else:  # multi-ref
        agg = "max" if args.use_max_similarity else "mean"
        method_desc = f"Multi-ref ({agg}, min_refs={args.min_matching_refs}, margin={args.margin})"

    lines = [
        "=" * 70,
        f"TEST: {args.matching_method.upper()} MATCHING",
        f"Model: {args.embedding_model}",
        f"Method: {method_desc}",
        "=" * 70,
        f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        "",
        "Configuration:",
        f" Embedding model: {args.embedding_model}",
        f" Preprocess mode: {args.preprocess_mode}",
        f" Reference logos: {num_logos}",
        f" Refs per logo: {args.refs_per_logo}",
        f" Total reference embeddings:{total_refs}",
        f" Positive samples/logo: {args.positive_samples}",
        f" Negative samples/logo: {args.negative_samples}",
        f" Test images processed: {num_test_images}",
        f" Similarity threshold: {args.threshold}",
        f" DETR threshold: {args.detr_threshold}",
    ]
    if args.seed is not None:
        lines.append(f" Random seed: {args.seed}")
    lines.extend([
        "",
        "Results:",
        f" True Positives: {true_positives:>6}",
        f" False Positives: {false_positives:>6}",
        f" False Negatives: {false_negatives:>6}",
        f" Total Expected: {total_expected:>6}",
        "",
        "Scores:",
        f" Precision: {precision:.4f} ({precision*100:.1f}%)",
        f" Recall: {recall:.4f} ({recall*100:.1f}%)",
        f" F1 Score: {f1:.4f} ({f1*100:.1f}%)",
        "",
        "",  # two trailing empties -> blank line separating appended runs
    ])

    # Allow --output-file to point into a directory that does not exist yet.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit UTF-8: the model name is user-supplied and the platform's
    # default locale encoding may not be able to represent it.
    with open(output_path, "a", encoding="utf-8") as f:
        f.write("\n".join(lines))
# Run the accuracy test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()