Files
logo_test/test_logo_detection.py
Rick McEwen 197e007591 Add margin check to multi-ref matching to reduce false positives
The multi-ref matching method was missing a margin check against other
logos, causing excessive false positives. This fix adds:

- margin parameter to find_best_match_multi_ref() that requires the
  best logo's score to exceed the second-best by a minimum margin
- Test script now passes --margin to both matching methods
- Updated documentation to reflect margin applies to both methods

Also adds run_comparison_tests.sh to run all three matching methods
and compare results.
2025-12-31 11:23:47 -05:00

553 lines
19 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Test script for logo detection accuracy.
This script:
1. Randomly samples N reference logos from the database
2. Processes all test images through the DETR+CLIP pipeline
3. Compares detected logos against reference embeddings
4. Reports accuracy metrics (correct matches, false positives, missed detections)
Embeddings are cached to avoid reprocessing images.
"""
import argparse
import logging
import pickle
import random
import sqlite3
import sys
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
import cv2
import torch
from tqdm import tqdm
from logo_detection_detr import DetectLogosDETR
def setup_logging(verbose: bool = False) -> logging.Logger:
    """Set up root logging and return this module's logger.

    Args:
        verbose: When True, log at DEBUG level; otherwise INFO.

    Returns:
        The logger named after this module.
    """
    logging.basicConfig(
        level=logging.DEBUG if verbose else logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%H:%M:%S",
    )
    return logging.getLogger(__name__)
def load_image(image_path: Path) -> Optional[cv2.Mat]:
    """Load an image from disk using OpenCV.

    Args:
        image_path: Path to the image file.

    Returns:
        The decoded BGR image, or None when the file is missing or cannot
        be decoded. ``cv2.imread`` already returns None on failure, so the
        original explicit ``if img is None: return None`` round-trip was
        redundant and has been removed.
    """
    return cv2.imread(str(image_path))
class EmbeddingCache:
    """Simple pickle-backed cache mapping string keys to cached values.

    The cache is one pickled dict persisted at ``cache_path``. Values are
    usually CLIP embedding tensors ("ref:<file>" keys), but main() also
    stores whole detection lists ("det:<file>" keys), so ``put`` must
    accept arbitrary picklable values, not just tensors.
    """

    def __init__(self, cache_path: Path):
        """Create the cache and load any previously saved entries.

        Args:
            cache_path: File where the pickled cache dict is stored.
        """
        self.cache_path = cache_path
        # In-memory store; persisted to cache_path by save().
        self.cache: Dict[str, torch.Tensor] = {}
        self._load()

    def _load(self):
        """Load the cache from disk if it exists; start empty on any error."""
        if self.cache_path.exists():
            try:
                with open(self.cache_path, "rb") as f:
                    self.cache = pickle.load(f)
            except Exception:
                # Corrupt or incompatible cache file: discard and start fresh.
                self.cache = {}

    def save(self):
        """Persist the cache to disk, creating parent directories as needed."""
        self.cache_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.cache_path, "wb") as f:
            pickle.dump(self.cache, f)

    def get(self, key: str) -> Optional[torch.Tensor]:
        """Return the cached value for ``key``, or None if absent."""
        return self.cache.get(key)

    def put(self, key: str, embedding: torch.Tensor):
        """Store a value in the cache.

        Tensors are moved to CPU so the cache does not pin GPU memory.
        Non-tensor values (e.g. the detection lists cached by main()) are
        stored as-is. BUG FIX: the previous unconditional
        ``embedding.cpu()`` raised AttributeError whenever a plain list
        was cached.
        """
        if isinstance(embedding, torch.Tensor):
            embedding = embedding.cpu()
        self.cache[key] = embedding

    def __len__(self):
        """Number of cached entries."""
        return len(self.cache)
def get_ground_truth(db_path: Path) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]:
    """
    Load ground truth from database.

    Args:
        db_path: Path to the SQLite database produced by prepare_test_data.py.

    Returns:
        Tuple of:
        - Dict mapping test image filename to set of logo names it contains
        - Dict mapping logo name to set of test image filenames containing it
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        # Query to get test image -> logo names mapping
        cursor.execute("""
            SELECT ti.filename, ln.name
            FROM test_images ti
            JOIN reference_logos rl ON ti.id = rl.test_image_id
            JOIN logo_names ln ON rl.logo_name_id = ln.id
        """)
        image_to_logos: Dict[str, Set[str]] = {}
        logo_to_images: Dict[str, Set[str]] = {}
        for test_filename, logo_name in cursor.fetchall():
            # setdefault replaces the manual "if key not in dict" dance
            image_to_logos.setdefault(test_filename, set()).add(logo_name)
            logo_to_images.setdefault(logo_name, set()).add(test_filename)
    finally:
        # BUG FIX: previously the connection leaked if the query raised;
        # close it on every path.
        conn.close()
    return image_to_logos, logo_to_images
def sample_reference_logos(
    db_path: Path, num_logos: int, refs_per_logo: int = 1, seed: Optional[int] = None
) -> Dict[str, List[str]]:
    """
    Randomly sample reference logos from database with multiple refs per logo.

    Seeds the global ``random`` module when ``seed`` is given, then draws
    the logo sample followed by one reference sample per logo (the call
    order is fixed so a given seed is reproducible).

    Args:
        db_path: Path to database
        num_logos: Number of logos to sample
        refs_per_logo: Number of reference images per logo
        seed: Random seed for reproducibility

    Returns:
        Dict mapping logo_name to list of reference filenames
    """
    if seed is not None:
        random.seed(seed)
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    # Fetch every known logo (id, name) pair.
    cursor.execute("SELECT id, name FROM logo_names")
    candidates = cursor.fetchall()
    # Take the whole table when the request covers it; otherwise sample.
    if num_logos >= len(candidates):
        chosen = candidates
    else:
        chosen = random.sample(candidates, num_logos)
    sampled: Dict[str, List[str]] = {}
    for logo_id, logo_name in chosen:
        cursor.execute(
            "SELECT filename FROM reference_logos WHERE logo_name_id = ?",
            (logo_id,)
        )
        available = [r[0] for r in cursor.fetchall()]
        # Keep everything when there are too few refs to sample from.
        if len(available) > refs_per_logo:
            sampled[logo_name] = random.sample(available, refs_per_logo)
        else:
            sampled[logo_name] = available
    conn.close()
    return sampled
def get_test_images(db_path: Path) -> List[str]:
    """Return every test-image filename recorded in the database."""
    conn = sqlite3.connect(db_path)
    rows = conn.execute("SELECT filename FROM test_images").fetchall()
    conn.close()
    return [filename for (filename,) in rows]
def main() -> None:
    """Run the end-to-end logo detection accuracy test.

    Steps: parse CLI options, sample reference logos from the database,
    embed them (with caching), build a positive/negative test-image set
    per logo, match DETR detections against references via the selected
    matching method, then print precision / recall / F1.
    """
    parser = argparse.ArgumentParser(
        description="Test logo detection accuracy against ground truth"
    )
    parser.add_argument(
        "-n", "--num-logos",
        type=int,
        default=10,
        help="Number of reference logos to sample (default: 10)",
    )
    parser.add_argument(
        "-t", "--threshold",
        type=float,
        default=0.7,
        help="CLIP similarity threshold for matching (default: 0.7)",
    )
    parser.add_argument(
        "-d", "--detr-threshold",
        type=float,
        default=0.5,
        help="DETR detection confidence threshold (default: 0.5)",
    )
    parser.add_argument(
        "-s", "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility",
    )
    parser.add_argument(
        "--positive-samples",
        type=int,
        default=5,
        help="Number of positive test images per logo (images containing the logo) (default: 5)",
    )
    parser.add_argument(
        "--negative-samples",
        type=int,
        default=20,
        help="Number of negative test images per logo (images NOT containing the logo) (default: 20)",
    )
    parser.add_argument(
        "--refs-per-logo",
        type=int,
        default=3,
        help="Number of reference images per logo for multi-ref matching (default: 3)",
    )
    parser.add_argument(
        "--margin",
        type=float,
        default=0.05,
        help="Required margin between best and second-best match (applies to both methods) (default: 0.05)",
    )
    parser.add_argument(
        "--matching-method",
        type=str,
        choices=["margin", "multi-ref"],
        default="margin",
        help="Matching method: 'margin' requires confidence margin over 2nd best, "
             "'multi-ref' aggregates scores across reference images (default: margin)",
    )
    parser.add_argument(
        "--min-matching-refs",
        type=int,
        default=1,
        help="For 'multi-ref' method: minimum references that must match above threshold (default: 1)",
    )
    parser.add_argument(
        "--use-max-similarity",
        action="store_true",
        help="For 'multi-ref' method: use max similarity instead of mean across references",
    )
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Enable verbose logging",
    )
    parser.add_argument(
        "--no-cache",
        action="store_true",
        help="Disable embedding cache",
    )
    parser.add_argument(
        "--clear-cache",
        action="store_true",
        help="Clear embedding cache before running",
    )
    args = parser.parse_args()
    logger = setup_logging(args.verbose)
    # Paths — everything lives next to this script
    base_dir = Path(__file__).resolve().parent
    db_path = base_dir / "test_data_mapping.db"
    reference_dir = base_dir / "reference_logos"
    test_images_dir = base_dir / "test_images"
    cache_path = base_dir / ".embedding_cache.pkl"
    # Verify database exists
    if not db_path.exists():
        logger.error(f"Database not found: {db_path}")
        logger.error("Run prepare_test_data.py first to create the database.")
        sys.exit(1)
    # Handle cache clearing
    if args.clear_cache and cache_path.exists():
        cache_path.unlink()
        logger.info("Cleared embedding cache")
    # Initialize embedding cache (None disables all caching below)
    cache = EmbeddingCache(cache_path) if not args.no_cache else None
    if cache:
        logger.info(f"Loaded {len(cache)} cached embeddings")
    # Initialize detector
    logger.info("Initializing logo detector...")
    detector = DetectLogosDETR(
        logger=logger,
        detr_threshold=args.detr_threshold,
    )
    # Load ground truth (both mappings)
    logger.info("Loading ground truth from database...")
    image_to_logos, logo_to_images = get_ground_truth(db_path)
    all_test_images = set(image_to_logos.keys())
    logger.info(f"Loaded ground truth for {len(image_to_logos)} test images")
    # Sample reference logos (with multiple refs per logo).
    # NOTE: this call also seeds the global `random` module when --seed is
    # given, which makes the test-image sampling further below reproducible.
    logger.info(f"Sampling {args.num_logos} reference logos with {args.refs_per_logo} refs each...")
    sampled_logos = sample_reference_logos(db_path, args.num_logos, args.refs_per_logo, args.seed)
    logger.info(f"Selected {len(sampled_logos)} reference logos")
    # Compute reference embeddings (multiple per logo for multi-ref matching)
    logger.info("Computing reference logo embeddings...")
    # Dict for multi-ref matching: logo_name -> list of embeddings
    multi_ref_embeddings: Dict[str, List[torch.Tensor]] = {}
    # List for margin-based matching: (logo_name, embedding) tuples
    reference_embeddings: List[Tuple[str, torch.Tensor]] = []
    total_refs = 0
    for logo_name, ref_filenames in tqdm(sampled_logos.items(), desc="Reference logos"):
        multi_ref_embeddings[logo_name] = []
        for ref_filename in ref_filenames:
            ref_path = reference_dir / ref_filename
            if not ref_path.exists():
                logger.warning(f"Reference logo not found: {ref_path}")
                continue
            # Check cache before running the embedding model
            cache_key = f"ref:{ref_filename}"
            embedding = cache.get(cache_key) if cache else None
            if embedding is None:
                img = load_image(ref_path)
                if img is None:
                    logger.warning(f"Failed to load reference logo: {ref_path}")
                    continue
                embedding = detector.get_embedding(img)
                if cache:
                    cache.put(cache_key, embedding)
            multi_ref_embeddings[logo_name].append(embedding)
            reference_embeddings.append((logo_name, embedding))
            total_refs += 1
    logger.info(f"Computed {total_refs} embeddings for {len(sampled_logos)} logos")
    # Build test set: for each logo, sample positive and negative images
    logger.info(f"Sampling test images: {args.positive_samples} positive, {args.negative_samples} negative per logo...")
    test_image_set: Set[str] = set()
    test_image_expected: Dict[str, Set[str]] = {}  # image -> logos it should match
    # Use sampled_logos keys (unique logo names) instead of reference_embeddings
    for logo_name in sampled_logos.keys():
        # Get positive images (contain this logo)
        positive_images = list(logo_to_images.get(logo_name, set()))
        if len(positive_images) > args.positive_samples:
            positive_images = random.sample(positive_images, args.positive_samples)
        # Get negative images (do NOT contain this logo)
        negative_pool = list(all_test_images - logo_to_images.get(logo_name, set()))
        if len(negative_pool) > args.negative_samples:
            negative_images = random.sample(negative_pool, args.negative_samples)
        else:
            negative_images = negative_pool
        # Add to test set; an image can be positive for one logo and a
        # negative sample for another, so expectations accumulate per image
        for img in positive_images:
            test_image_set.add(img)
            if img not in test_image_expected:
                test_image_expected[img] = set()
            test_image_expected[img].add(logo_name)
        for img in negative_images:
            test_image_set.add(img)
            if img not in test_image_expected:
                test_image_expected[img] = set()
            # Don't add logo_name - this is a negative sample
    test_images = list(test_image_set)
    logger.info(f"Selected {len(test_images)} unique test images")
    # Get set of reference logo names for quick lookup
    reference_logo_names = set(sampled_logos.keys())
    # Metrics
    true_positives = 0  # Correctly matched logos
    false_positives = 0  # Matched but wrong logo or no logo present
    false_negatives = 0  # Logo present but not detected/matched
    total_expected = 0  # Total logos we should have found
    # Detailed results for analysis
    results = []
    # Process test images
    for test_filename in tqdm(test_images, desc="Testing"):
        test_path = test_images_dir / test_filename
        if not test_path.exists():
            logger.warning(f"Test image not found: {test_path}")
            continue
        # Get expected logos for this image (from our sampled set)
        expected_logos = test_image_expected.get(test_filename, set())
        total_expected += len(expected_logos)
        # Check cache for detections.
        # NOTE(review): the key ignores args.detr_threshold, so detections
        # cached under a different threshold are silently reused — run with
        # --clear-cache after changing detection thresholds. Confirm intended.
        cache_key = f"det:{test_filename}"
        cached_detections = cache.get(cache_key) if cache else None
        if cached_detections is not None:
            # Cached detections contain serialized box data and embeddings
            detections = cached_detections
        else:
            # Load and detect
            img = load_image(test_path)
            if img is None:
                logger.warning(f"Failed to load test image: {test_path}")
                continue
            detections = detector.detect(img)
            # Cache the detections.
            # NOTE(review): detections is a list of dicts, while
            # EmbeddingCache.put calls .cpu() on its value — verify it
            # accepts non-tensor values.
            if cache:
                cache.put(cache_key, detections)
        # Match detections against references using selected method
        matched_logos: Set[str] = set()
        for detection in detections:
            match = None
            similarity = None
            if args.matching_method == "margin":
                # Margin-based matching: requires margin over second-best
                match_result = detector.find_best_match_with_margin(
                    detection["embedding"],
                    reference_embeddings,
                    similarity_threshold=args.threshold,
                    margin=args.margin,
                )
                if match_result:
                    label, similarity = match_result
                    match = label
            else:  # multi-ref
                # Multi-ref matching: aggregates scores across reference images
                match_result = detector.find_best_match_multi_ref(
                    detection["embedding"],
                    multi_ref_embeddings,
                    similarity_threshold=args.threshold,
                    min_matching_refs=args.min_matching_refs,
                    use_mean_similarity=not args.use_max_similarity,
                    margin=args.margin,
                )
                if match_result:
                    label, similarity, num_matching = match_result
                    match = label
            if match:
                matched_logos.add(match)
                # Check if this is a correct match
                if match in expected_logos:
                    true_positives += 1
                else:
                    false_positives += 1
                results.append({
                    "test_image": test_filename,
                    "matched_logo": match,
                    "similarity": similarity,
                    "correct": match in expected_logos,
                })
        # Count missed detections (false negatives)
        missed = expected_logos - matched_logos
        false_negatives += len(missed)
        for missed_logo in missed:
            results.append({
                "test_image": test_filename,
                "matched_logo": None,
                "expected_logo": missed_logo,
                "similarity": None,
                "correct": False,
            })
    # Save cache
    if cache:
        cache.save()
        logger.info(f"Saved {len(cache)} embeddings to cache")
    # Calculate metrics (guarding each denominator against zero)
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / total_expected if total_expected > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    # Print results
    print("\n" + "=" * 60)
    print("LOGO DETECTION TEST RESULTS")
    print("=" * 60)
    print(f"\nConfiguration:")
    print(f" Reference logos sampled: {len(sampled_logos)}")
    print(f" Refs per logo: {args.refs_per_logo}")
    print(f" Total reference embeddings:{total_refs}")
    print(f" Positive samples per logo: {args.positive_samples}")
    print(f" Negative samples per logo: {args.negative_samples}")
    print(f" Test images processed: {len(test_images)}")
    print(f" CLIP similarity threshold: {args.threshold}")
    print(f" DETR confidence threshold: {args.detr_threshold}")
    print(f" Matching method: {args.matching_method}")
    print(f" Matching margin: {args.margin}")
    if args.matching_method == "multi-ref":
        print(f" Min matching refs: {args.min_matching_refs}")
        print(f" Similarity aggregation: {'max' if args.use_max_similarity else 'mean'}")
    if args.seed is not None:
        print(f" Random seed: {args.seed}")
    print(f"\nMetrics:")
    print(f" True Positives (correct matches): {true_positives}")
    print(f" False Positives (wrong matches): {false_positives}")
    print(f" False Negatives (missed logos): {false_negatives}")
    print(f" Total expected matches: {total_expected}")
    print(f"\nScores:")
    print(f" Precision: {precision:.4f} ({precision*100:.1f}%)")
    print(f" Recall: {recall:.4f} ({recall*100:.1f}%)")
    print(f" F1 Score: {f1:.4f} ({f1*100:.1f}%)")
    # Show some example false positives
    false_positive_examples = [r for r in results if r.get("matched_logo") and not r["correct"]]
    if false_positive_examples:
        print(f"\nExample False Positives (first 5):")
        for r in false_positive_examples[:5]:
            print(f" - Image: {r['test_image']}")
            print(f" Matched: {r['matched_logo']} (similarity: {r['similarity']:.3f})")
    # Show reference logos used (unique names)
    unique_logos = sorted(sampled_logos.keys())
    print(f"\nReference logos used ({len(unique_logos)}):")
    for name in unique_logos[:20]:
        print(f" - {name}")
    if len(unique_logos) > 20:
        print(f" ... and {len(unique_logos) - 20} more")
    print("=" * 60)


if __name__ == "__main__":
    main()