#!/usr/bin/env python3
"""
Test script for logo detection accuracy.

This script:
1. Randomly samples N reference logos from the database
2. Processes a sample of test images through the DETR+CLIP pipeline
3. Compares detected logos against reference embeddings
4. Reports accuracy metrics (correct matches, false positives, missed detections)

Reference embeddings and per-image detections are cached to avoid
reprocessing images.
"""

import argparse
import logging
import pickle
import random
import sqlite3
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple

import cv2
import torch
from tqdm import tqdm

from logo_detection_detr import DetectLogosDETR


def setup_logging(verbose: bool = False) -> logging.Logger:
    """Configure root logging and return this module's logger."""
    level = logging.DEBUG if verbose else logging.INFO
    logging.basicConfig(
        level=level,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%H:%M:%S",
    )
    return logging.getLogger(__name__)


def load_image(image_path: Path) -> Optional[cv2.Mat]:
    """Load an image using OpenCV; return None if it cannot be read."""
    img = cv2.imread(str(image_path))
    if img is None:
        return None
    return img


def _to_cpu(value: Any) -> Any:
    """Return *value* with every contained torch tensor moved to CPU.

    Recurses into dicts, lists and plain tuples, so it works both for a single
    embedding tensor and for a detection list (the loop in ``main`` reads
    ``detection["embedding"]`` from each item). Non-tensor leaves are returned
    unchanged.
    """
    if isinstance(value, torch.Tensor):
        return value.cpu()
    if isinstance(value, dict):
        return {key: _to_cpu(item) for key, item in value.items()}
    if isinstance(value, list):
        return [_to_cpu(item) for item in value]
    if type(value) is tuple:  # exact tuple only; namedtuples are left as-is
        return tuple(_to_cpu(item) for item in value)
    return value


class EmbeddingCache:
    """Simple pickle-file cache for embeddings and detection results.

    Values are moved to CPU before being stored, so the file can be reloaded
    without the original device.

    NOTE: the file is read with ``pickle``, which can execute arbitrary code.
    Only point this at a cache file this script wrote itself.

    NOTE: ``__len__`` makes an *empty* cache falsy; callers must test
    ``cache is not None`` rather than ``if cache:``.
    """

    def __init__(self, cache_path: Path):
        self.cache_path = cache_path
        self.cache: Dict[str, Any] = {}
        self._load()

    def _load(self):
        """Load cache from disk if it exists; start empty if it is unreadable."""
        if not self.cache_path.exists():
            return
        try:
            with open(self.cache_path, "rb") as f:
                loaded = pickle.load(f)
        except Exception as exc:  # corrupt/incompatible cache: rebuild it
            logging.getLogger(__name__).warning(
                "Ignoring unreadable embedding cache %s: %s", self.cache_path, exc
            )
            loaded = {}
        self.cache = loaded if isinstance(loaded, dict) else {}

    def save(self):
        """Save cache to disk, creating the parent directory if needed."""
        self.cache_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.cache_path, "wb") as f:
            pickle.dump(self.cache, f)

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None if absent."""
        return self.cache.get(key)

    def put(self, key: str, embedding: Any):
        """Store a tensor, or a structure containing tensors, on CPU.

        Moving to CPU saves GPU memory and keeps the pickle device-independent.
        """
        self.cache[key] = _to_cpu(embedding)

    def __len__(self):
        return len(self.cache)


def get_ground_truth(db_path: Path) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]:
    """
    Load ground truth from database.

    Returns:
        Tuple of:
        - Dict mapping test image filename to set of logo names it contains
        - Dict mapping logo name to set of test image filenames containing it
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        # Query to get test image -> logo names mapping
        cursor.execute("""
            SELECT ti.filename, ln.name
            FROM test_images ti
            JOIN reference_logos rl ON ti.id = rl.test_image_id
            JOIN logo_names ln ON rl.logo_name_id = ln.id
        """)
        rows = cursor.fetchall()
    finally:
        conn.close()

    image_to_logos: Dict[str, Set[str]] = {}
    logo_to_images: Dict[str, Set[str]] = {}
    for test_filename, logo_name in rows:
        image_to_logos.setdefault(test_filename, set()).add(logo_name)
        logo_to_images.setdefault(logo_name, set()).add(test_filename)
    return image_to_logos, logo_to_images


def sample_reference_logos(
    db_path: Path, num_logos: int, refs_per_logo: int = 1, seed: Optional[int] = None
) -> Dict[str, List[str]]:
    """
    Randomly sample reference logos from database with multiple refs per logo.

    Args:
        db_path: Path to database
        num_logos: Number of logos to sample
        refs_per_logo: Number of reference images per logo
        seed: Random seed for reproducibility (seeds the global ``random``
            module, so later sampling in this process is reproducible too)

    Returns:
        Dict mapping logo_name to list of reference filenames
    """
    if seed is not None:
        random.seed(seed)

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # ORDER BY keeps the population order stable, so a fixed seed always
        # draws the same sample.
        cursor.execute("SELECT id, name FROM logo_names ORDER BY id")
        all_logo_names = cursor.fetchall()

        if num_logos >= len(all_logo_names):
            sampled_logos = all_logo_names
        else:
            sampled_logos = random.sample(all_logo_names, num_logos)

        # For each sampled logo, get multiple reference files
        result: Dict[str, List[str]] = {}
        for logo_id, logo_name in sampled_logos:
            cursor.execute(
                "SELECT filename FROM reference_logos WHERE logo_name_id = ? "
                "ORDER BY filename",
                (logo_id,),
            )
            all_refs = [row[0] for row in cursor.fetchall()]

            # Sample refs_per_logo references (or all if fewer available)
            if len(all_refs) > refs_per_logo:
                selected_refs = random.sample(all_refs, refs_per_logo)
            else:
                selected_refs = all_refs

            result[logo_name] = selected_refs
    finally:
        conn.close()
    return result


def get_test_images(db_path: Path) -> List[str]:
    """Get all test image filenames from database, in filename order."""
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT filename FROM test_images ORDER BY filename")
        return [row[0] for row in cursor.fetchall()]
    finally:
        conn.close()


def _parse_args() -> argparse.Namespace:
    """Parse and validate command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Test logo detection accuracy against ground truth"
    )
    parser.add_argument(
        "-n",
        "--num-logos",
        type=int,
        default=10,
        help="Number of reference logos to sample (default: 10)",
    )
    parser.add_argument(
        "-t",
        "--threshold",
        type=float,
        default=0.7,
        help="CLIP similarity threshold for matching (default: 0.7)",
    )
    parser.add_argument(
        "-d",
        "--detr-threshold",
        type=float,
        default=0.5,
        help="DETR detection confidence threshold (default: 0.5)",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility",
    )
    parser.add_argument(
        "--positive-samples",
        type=int,
        default=5,
        help="Number of positive test images per logo (images containing the logo) (default: 5)",
    )
    parser.add_argument(
        "--negative-samples",
        type=int,
        default=20,
        help="Number of negative test images per logo (images NOT containing the logo) (default: 20)",
    )
    parser.add_argument(
        "--refs-per-logo",
        type=int,
        default=3,
        help="Number of reference images per logo for multi-ref matching (default: 3)",
    )
    parser.add_argument(
        "--margin",
        type=float,
        default=0.05,
        help="Required margin between best and second-best match for 'margin' method (default: 0.05)",
    )
    parser.add_argument(
        "--matching-method",
        type=str,
        choices=["margin", "multi-ref"],
        default="margin",
        help="Matching method: 'margin' requires confidence margin over 2nd best, "
        "'multi-ref' aggregates scores across reference images (default: margin)",
    )
    parser.add_argument(
        "--min-matching-refs",
        type=int,
        default=1,
        help="For 'multi-ref' method: minimum references that must match above threshold (default: 1)",
    )
    parser.add_argument(
        "--use-max-similarity",
        action="store_true",
        help="For 'multi-ref' method: use max similarity instead of mean across references",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Enable verbose logging",
    )
    parser.add_argument(
        "--no-cache",
        action="store_true",
        help="Disable embedding cache",
    )
    parser.add_argument(
        "--clear-cache",
        action="store_true",
        help="Clear embedding cache before running",
    )
    args = parser.parse_args()

    # These counts feed random.sample(), which raises on negative sizes; fail
    # early with a usage message instead of a traceback mid-run.
    for option in ("num_logos", "positive_samples", "negative_samples", "refs_per_logo"):
        if getattr(args, option) < 0:
            parser.error(f"--{option.replace('_', '-')} must be non-negative")
    return args


def _compute_reference_embeddings(
    detector: DetectLogosDETR,
    sampled_logos: Dict[str, List[str]],
    reference_dir: Path,
    cache: Optional[EmbeddingCache],
    logger: logging.Logger,
) -> Tuple[Dict[str, List[torch.Tensor]], List[Tuple[str, torch.Tensor]], int]:
    """Embed every sampled reference image, using the cache when possible.

    Returns ``(multi_ref_embeddings, reference_embeddings, total_refs)``:
    a logo_name -> embeddings dict (for 'multi-ref' matching), a flat list of
    (logo_name, embedding) pairs (for 'margin' matching), and the number of
    embeddings produced. All embeddings are on CPU, so cached and fresh values
    share a device.
    """
    multi_ref_embeddings: Dict[str, List[torch.Tensor]] = {}
    reference_embeddings: List[Tuple[str, torch.Tensor]] = []

    for logo_name, ref_filenames in tqdm(sampled_logos.items(), desc="Reference logos"):
        logo_embeddings = multi_ref_embeddings.setdefault(logo_name, [])
        for ref_filename in ref_filenames:
            ref_path = reference_dir / ref_filename
            if not ref_path.exists():
                logger.warning(f"Reference logo not found: {ref_path}")
                continue

            cache_key = f"ref:{ref_filename}"
            embedding = cache.get(cache_key) if cache is not None else None
            if embedding is None:
                img = load_image(ref_path)
                if img is None:
                    logger.warning(f"Failed to load reference logo: {ref_path}")
                    continue
                embedding = _to_cpu(detector.get_embedding(img))
                if cache is not None:
                    cache.put(cache_key, embedding)

            logo_embeddings.append(embedding)
            reference_embeddings.append((logo_name, embedding))

        if not logo_embeddings:
            logger.warning(f"No usable reference images for logo: {logo_name}")

    return multi_ref_embeddings, reference_embeddings, len(reference_embeddings)


def _build_test_set(
    sampled_logo_names: List[str],
    image_to_logos: Dict[str, Set[str]],
    logo_to_images: Dict[str, Set[str]],
    positive_samples: int,
    negative_samples: int,
) -> Tuple[List[str], Dict[str, Set[str]]]:
    """Sample positive/negative test images per logo and derive expectations.

    Returns the sorted list of selected test images and, for each one, the set
    of *sampled* logos it contains according to ground truth. The expectation
    comes from ground truth, not from which logo the image was drawn for. An
    image drawn as a negative for logo A may still contain sampled logo B, and
    detecting B there is correct.
    """
    all_test_images = set(image_to_logos)
    selected: Set[str] = set()

    for logo_name in sampled_logo_names:
        containing = logo_to_images.get(logo_name, set())

        # Sort before sampling: set iteration order of strings depends on
        # PYTHONHASHSEED, which would defeat --seed reproducibility.
        positive_images = sorted(containing)
        if len(positive_images) > positive_samples:
            positive_images = random.sample(positive_images, positive_samples)

        negative_pool = sorted(all_test_images - containing)
        if len(negative_pool) > negative_samples:
            negative_images = random.sample(negative_pool, negative_samples)
        else:
            negative_images = negative_pool

        selected.update(positive_images)
        selected.update(negative_images)

    reference_logo_names = set(sampled_logo_names)
    expected = {
        img: image_to_logos.get(img, set()) & reference_logo_names for img in selected
    }
    return sorted(selected), expected


def _get_detections(
    detector: DetectLogosDETR,
    test_path: Path,
    test_filename: str,
    cache: Optional[EmbeddingCache],
    logger: logging.Logger,
) -> Optional[Any]:
    """Return detections for one test image (cached or freshly computed).

    Returns None if the image cannot be loaded. Fresh detections are moved to
    CPU so they share a device with the reference embeddings.
    """
    cache_key = f"det:{test_filename}"
    cached_detections = cache.get(cache_key) if cache is not None else None
    if cached_detections is not None:
        return cached_detections

    img = load_image(test_path)
    if img is None:
        logger.warning(f"Failed to load test image: {test_path}")
        return None

    detections = _to_cpu(detector.detect(img))
    if cache is not None:
        cache.put(cache_key, detections)
    return detections


def _match_detection(
    detector: DetectLogosDETR,
    embedding: Any,
    args: argparse.Namespace,
    reference_embeddings: List[Tuple[str, torch.Tensor]],
    multi_ref_embeddings: Dict[str, List[torch.Tensor]],
) -> Tuple[Optional[str], Optional[float]]:
    """Match one detection embedding with the configured method.

    Returns ``(label, similarity)``, or ``(None, None)`` when nothing matched.
    """
    if args.matching_method == "margin":
        # Margin-based matching: requires margin over second-best
        match_result = detector.find_best_match_with_margin(
            embedding,
            reference_embeddings,
            similarity_threshold=args.threshold,
            margin=args.margin,
        )
        if match_result:
            label, similarity = match_result
            return label, similarity
        return None, None

    # multi-ref: aggregates scores across reference images
    match_result = detector.find_best_match_multi_ref(
        embedding,
        multi_ref_embeddings,
        similarity_threshold=args.threshold,
        min_matching_refs=args.min_matching_refs,
        use_mean_similarity=not args.use_max_similarity,
    )
    if match_result:
        label, similarity, _num_matching = match_result
        return label, similarity
    return None, None


def _print_report(
    args: argparse.Namespace,
    sampled_logos: Dict[str, List[str]],
    total_refs: int,
    num_test_images: int,
    true_positives: int,
    false_positives: int,
    false_negatives: int,
    total_expected: int,
    results: List[Dict[str, Any]],
) -> None:
    """Compute precision/recall/F1 and print the configuration and results."""
    predicted = true_positives + false_positives
    precision = true_positives / predicted if predicted > 0 else 0
    recall = true_positives / total_expected if total_expected > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    print("\n" + "=" * 60)
    print("LOGO DETECTION TEST RESULTS")
    print("=" * 60)
    print("\nConfiguration:")
    print(f" Reference logos sampled: {len(sampled_logos)}")
    print(f" Refs per logo: {args.refs_per_logo}")
    print(f" Total reference embeddings:{total_refs}")
    print(f" Positive samples per logo: {args.positive_samples}")
    print(f" Negative samples per logo: {args.negative_samples}")
    print(f" Test images processed: {num_test_images}")
    print(f" CLIP similarity threshold: {args.threshold}")
    print(f" DETR confidence threshold: {args.detr_threshold}")
    print(f" Matching method: {args.matching_method}")
    if args.matching_method == "margin":
        print(f" Matching margin: {args.margin}")
    else:  # multi-ref
        print(f" Min matching refs: {args.min_matching_refs}")
        print(f" Similarity aggregation: {'max' if args.use_max_similarity else 'mean'}")
    if args.seed is not None:
        print(f" Random seed: {args.seed}")

    print("\nMetrics:")
    print(f" True Positives (correct matches): {true_positives}")
    print(f" False Positives (wrong matches): {false_positives}")
    print(f" False Negatives (missed logos): {false_negatives}")
    print(f" Total expected matches: {total_expected}")

    print("\nScores:")
    print(f" Precision: {precision:.4f} ({precision*100:.1f}%)")
    print(f" Recall: {recall:.4f} ({recall*100:.1f}%)")
    print(f" F1 Score: {f1:.4f} ({f1*100:.1f}%)")

    # Show some example false positives
    false_positive_examples = [
        r for r in results if r.get("matched_logo") and not r["correct"]
    ]
    if false_positive_examples:
        print("\nExample False Positives (first 5):")
        for r in false_positive_examples[:5]:
            similarity = r["similarity"]
            similarity_text = f"{similarity:.3f}" if similarity is not None else "n/a"
            print(f" - Image: {r['test_image']}")
            print(f" Matched: {r['matched_logo']} (similarity: {similarity_text})")

    # Show reference logos used (unique names)
    unique_logos = sorted(sampled_logos.keys())
    print(f"\nReference logos used ({len(unique_logos)}):")
    for name in unique_logos[:20]:
        print(f" - {name}")
    if len(unique_logos) > 20:
        print(f" ... and {len(unique_logos) - 20} more")
    print("=" * 60)


def main():
    """Run the sampling, detection, matching and reporting pipeline."""
    args = _parse_args()
    logger = setup_logging(args.verbose)

    # Paths
    base_dir = Path(__file__).resolve().parent
    db_path = base_dir / "test_data_mapping.db"
    reference_dir = base_dir / "reference_logos"
    test_images_dir = base_dir / "test_images"
    cache_path = base_dir / ".embedding_cache.pkl"

    # Verify database exists
    if not db_path.exists():
        logger.error(f"Database not found: {db_path}")
        logger.error("Run prepare_test_data.py first to create the database.")
        sys.exit(1)

    # Handle cache clearing
    if args.clear_cache and cache_path.exists():
        cache_path.unlink()
        logger.info("Cleared embedding cache")

    # Compare with None explicitly: EmbeddingCache defines __len__, so an
    # empty (cold) cache is falsy and `if cache:` would never populate it.
    cache: Optional[EmbeddingCache] = None if args.no_cache else EmbeddingCache(cache_path)
    if cache is not None:
        logger.info(f"Loaded {len(cache)} cached embeddings")

    # Initialize detector
    logger.info("Initializing logo detector...")
    detector = DetectLogosDETR(
        logger=logger,
        detr_threshold=args.detr_threshold,
    )

    # Load ground truth (both mappings)
    logger.info("Loading ground truth from database...")
    image_to_logos, logo_to_images = get_ground_truth(db_path)
    logger.info(f"Loaded ground truth for {len(image_to_logos)} test images")

    # Sample reference logos (with multiple refs per logo)
    logger.info(f"Sampling {args.num_logos} reference logos with {args.refs_per_logo} refs each...")
    sampled_logos = sample_reference_logos(db_path, args.num_logos, args.refs_per_logo, args.seed)
    logger.info(f"Selected {len(sampled_logos)} reference logos")

    logger.info("Computing reference logo embeddings...")
    multi_ref_embeddings, reference_embeddings, total_refs = _compute_reference_embeddings(
        detector, sampled_logos, reference_dir, cache, logger
    )
    logger.info(f"Computed {total_refs} embeddings for {len(sampled_logos)} logos")

    logger.info(f"Sampling test images: {args.positive_samples} positive, {args.negative_samples} negative per logo...")
    test_images, test_image_expected = _build_test_set(
        list(sampled_logos.keys()),
        image_to_logos,
        logo_to_images,
        args.positive_samples,
        args.negative_samples,
    )
    logger.info(f"Selected {len(test_images)} unique test images")

    # Metrics are scored per (image, logo) pair, so TP + FN == total_expected.
    # Duplicate detections of the same logo in one image count once.
    true_positives = 0  # Expected logos that were matched
    false_positives = 0  # Matched logos the image does not contain
    false_negatives = 0  # Expected logos that were never matched
    total_expected = 0  # Expected logos across successfully processed images

    # Detailed results for analysis (one entry per matching detection + misses)
    results: List[Dict[str, Any]] = []

    for test_filename in tqdm(test_images, desc="Testing"):
        test_path = test_images_dir / test_filename
        if not test_path.exists():
            logger.warning(f"Test image not found: {test_path}")
            continue

        detections = _get_detections(detector, test_path, test_filename, cache, logger)
        if detections is None:
            # Unreadable image: excluded from every metric, like a missing file.
            continue

        expected_logos = test_image_expected.get(test_filename, set())
        total_expected += len(expected_logos)

        matched_logos: Set[str] = set()
        for detection in detections:
            match, similarity = _match_detection(
                detector,
                detection["embedding"],
                args,
                reference_embeddings,
                multi_ref_embeddings,
            )
            if not match:
                continue
            matched_logos.add(match)
            results.append({
                "test_image": test_filename,
                "matched_logo": match,
                "similarity": similarity,
                "correct": match in expected_logos,
            })

        true_positives += len(matched_logos & expected_logos)
        false_positives += len(matched_logos - expected_logos)

        missed = expected_logos - matched_logos
        false_negatives += len(missed)
        for missed_logo in sorted(missed):
            results.append({
                "test_image": test_filename,
                "matched_logo": None,
                "expected_logo": missed_logo,
                "similarity": None,
                "correct": False,
            })

    # Save cache
    if cache is not None:
        cache.save()
        logger.info(f"Saved {len(cache)} embeddings to cache")

    _print_report(
        args,
        sampled_logos,
        total_refs,
        len(test_images),
        true_positives,
        false_positives,
        false_negatives,
        total_expected,
        results,
    )


if __name__ == "__main__":
    main()