The multi-ref matching method was missing a margin check against other logos, causing excessive false positives. This fix adds: - margin parameter to find_best_match_multi_ref() that requires the best logo's score to exceed the second-best by a minimum margin - Test script now passes --margin to both matching methods - Updated documentation to reflect margin applies to both methods Also adds run_comparison_tests.sh to run all three matching methods and compare results.
553 lines
19 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Test script for logo detection accuracy.
|
|
|
|
This script:
|
|
1. Randomly samples N reference logos from the database
|
|
2. Processes all test images through the DETR+CLIP pipeline
|
|
3. Compares detected logos against reference embeddings
|
|
4. Reports accuracy metrics (correct matches, false positives, missed detections)
|
|
|
|
Embeddings are cached to avoid reprocessing images.
|
|
"""
|
|
|
|
import argparse
|
|
import logging
|
|
import pickle
|
|
import random
|
|
import sqlite3
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Set, Tuple
|
|
|
|
import cv2
|
|
import torch
|
|
from tqdm import tqdm
|
|
|
|
from logo_detection_detr import DetectLogosDETR
|
|
|
|
|
|
def setup_logging(verbose: bool = False) -> logging.Logger:
    """Initialise process-wide logging and return this module's logger.

    Args:
        verbose: When True, log at DEBUG; otherwise at INFO.

    Returns:
        The logger named after this module.
    """
    logging.basicConfig(
        level=logging.DEBUG if verbose else logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%H:%M:%S",
    )
    return logging.getLogger(__name__)
|
|
|
|
|
|
def load_image(image_path: Path) -> Optional[cv2.Mat]:
    """Read an image from disk with OpenCV.

    Returns None when the file is missing or cannot be decoded —
    cv2.imread already yields None in those cases, so its result
    can be returned directly.
    """
    return cv2.imread(str(image_path))
|
|
|
|
|
|
class EmbeddingCache:
    """File-backed pickle cache mapping string keys to CPU tensors."""

    def __init__(self, cache_path: Path):
        """Create the cache and load any previously saved entries.

        Args:
            cache_path: Pickle file used for persistence.
        """
        self.cache_path = cache_path
        self.cache: Dict[str, torch.Tensor] = {}
        self._load()

    def _load(self):
        """Populate the in-memory dict from disk, if a cache file exists."""
        if not self.cache_path.exists():
            return
        try:
            with open(self.cache_path, "rb") as fh:
                self.cache = pickle.load(fh)
        except Exception:
            # A corrupt or unreadable cache file is not fatal -- start empty.
            self.cache = {}

    def save(self):
        """Persist the cache to disk, creating parent directories as needed."""
        self.cache_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.cache_path, "wb") as fh:
            pickle.dump(self.cache, fh)

    def get(self, key: str) -> Optional[torch.Tensor]:
        """Return the cached embedding for ``key``, or None when absent."""
        return self.cache.get(key)

    def put(self, key: str, embedding: torch.Tensor):
        """Store ``embedding`` under ``key``.

        The tensor is moved to CPU so cached entries never pin GPU memory.
        """
        self.cache[key] = embedding.cpu()

    def __len__(self):
        """Return the number of cached embeddings."""
        return len(self.cache)
|
|
|
|
|
|
def get_ground_truth(db_path: Path) -> Tuple[Dict[str, Set[str]], Dict[str, Set[str]]]:
    """
    Load ground truth from database.

    Args:
        db_path: Path to the test-data mapping SQLite database.

    Returns:
        Tuple of:
        - Dict mapping test image filename to set of logo names it contains
        - Dict mapping logo name to set of test image filenames containing it
    """
    conn = sqlite3.connect(db_path)
    try:
        # Query to get test image -> logo names mapping.
        rows = conn.execute("""
            SELECT ti.filename, ln.name
            FROM test_images ti
            JOIN reference_logos rl ON ti.id = rl.test_image_id
            JOIN logo_names ln ON rl.logo_name_id = ln.id
        """).fetchall()
    finally:
        # Close even when the query raises; the original leaked the
        # connection on that path.
        conn.close()

    image_to_logos: Dict[str, Set[str]] = {}
    logo_to_images: Dict[str, Set[str]] = {}

    for test_filename, logo_name in rows:
        # setdefault avoids the membership-check-then-insert dance.
        image_to_logos.setdefault(test_filename, set()).add(logo_name)
        logo_to_images.setdefault(logo_name, set()).add(test_filename)

    return image_to_logos, logo_to_images
|
|
|
|
|
|
def sample_reference_logos(
    db_path: Path, num_logos: int, refs_per_logo: int = 1, seed: Optional[int] = None
) -> Dict[str, List[str]]:
    """
    Randomly sample reference logos from database with multiple refs per logo.

    Args:
        db_path: Path to database
        num_logos: Number of logos to sample
        refs_per_logo: Number of reference images per logo
        seed: Random seed for reproducibility

    Returns:
        Dict mapping logo_name to list of reference filenames
    """
    if seed is not None:
        random.seed(seed)

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # Fetch every (id, name) pair, then draw the requested number of logos
    # (or keep them all when fewer exist than requested).
    cursor.execute("SELECT id, name FROM logo_names")
    all_logo_names = cursor.fetchall()
    sampled_logos = (
        all_logo_names
        if num_logos >= len(all_logo_names)
        else random.sample(all_logo_names, num_logos)
    )

    result: Dict[str, List[str]] = {}
    for logo_id, logo_name in sampled_logos:
        cursor.execute(
            "SELECT filename FROM reference_logos WHERE logo_name_id = ?",
            (logo_id,),
        )
        all_refs = [row[0] for row in cursor.fetchall()]

        # Draw refs_per_logo references, or keep them all when fewer exist.
        if len(all_refs) > refs_per_logo:
            result[logo_name] = random.sample(all_refs, refs_per_logo)
        else:
            result[logo_name] = all_refs

    conn.close()
    return result
|
|
|
|
|
|
def get_test_images(db_path: Path) -> List[str]:
    """Return every test image filename recorded in the database."""
    conn = sqlite3.connect(db_path)
    rows = conn.execute("SELECT filename FROM test_images").fetchall()
    conn.close()
    # Each row is a 1-tuple; unpack to a flat list of filenames.
    return [filename for (filename,) in rows]
|
|
|
|
|
|
def main():
    """Run the end-to-end logo-detection accuracy test.

    Pipeline: parse CLI args -> sample reference logos from the SQLite
    database -> embed them -> run detection on sampled positive/negative
    test images -> match detections against references (margin or
    multi-ref method) -> report precision / recall / F1.
    """
    parser = argparse.ArgumentParser(
        description="Test logo detection accuracy against ground truth"
    )
    parser.add_argument(
        "-n", "--num-logos",
        type=int,
        default=10,
        help="Number of reference logos to sample (default: 10)",
    )
    parser.add_argument(
        "-t", "--threshold",
        type=float,
        default=0.7,
        help="CLIP similarity threshold for matching (default: 0.7)",
    )
    parser.add_argument(
        "-d", "--detr-threshold",
        type=float,
        default=0.5,
        help="DETR detection confidence threshold (default: 0.5)",
    )
    parser.add_argument(
        "-s", "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility",
    )
    parser.add_argument(
        "--positive-samples",
        type=int,
        default=5,
        help="Number of positive test images per logo (images containing the logo) (default: 5)",
    )
    parser.add_argument(
        "--negative-samples",
        type=int,
        default=20,
        help="Number of negative test images per logo (images NOT containing the logo) (default: 20)",
    )
    parser.add_argument(
        "--refs-per-logo",
        type=int,
        default=3,
        help="Number of reference images per logo for multi-ref matching (default: 3)",
    )
    parser.add_argument(
        "--margin",
        type=float,
        default=0.05,
        help="Required margin between best and second-best match (applies to both methods) (default: 0.05)",
    )
    parser.add_argument(
        "--matching-method",
        type=str,
        choices=["margin", "multi-ref"],
        default="margin",
        help="Matching method: 'margin' requires confidence margin over 2nd best, "
        "'multi-ref' aggregates scores across reference images (default: margin)",
    )
    parser.add_argument(
        "--min-matching-refs",
        type=int,
        default=1,
        help="For 'multi-ref' method: minimum references that must match above threshold (default: 1)",
    )
    parser.add_argument(
        "--use-max-similarity",
        action="store_true",
        help="For 'multi-ref' method: use max similarity instead of mean across references",
    )
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Enable verbose logging",
    )
    parser.add_argument(
        "--no-cache",
        action="store_true",
        help="Disable embedding cache",
    )
    parser.add_argument(
        "--clear-cache",
        action="store_true",
        help="Clear embedding cache before running",
    )

    args = parser.parse_args()
    logger = setup_logging(args.verbose)

    # Paths (all relative to this script's directory)
    base_dir = Path(__file__).resolve().parent
    db_path = base_dir / "test_data_mapping.db"
    reference_dir = base_dir / "reference_logos"
    test_images_dir = base_dir / "test_images"
    cache_path = base_dir / ".embedding_cache.pkl"

    # Verify database exists
    if not db_path.exists():
        logger.error(f"Database not found: {db_path}")
        logger.error("Run prepare_test_data.py first to create the database.")
        sys.exit(1)

    # Handle cache clearing
    if args.clear_cache and cache_path.exists():
        cache_path.unlink()
        logger.info("Cleared embedding cache")

    # Initialize embedding cache (None disables caching entirely)
    cache = EmbeddingCache(cache_path) if not args.no_cache else None
    if cache:
        logger.info(f"Loaded {len(cache)} cached embeddings")

    # Initialize detector
    logger.info("Initializing logo detector...")
    detector = DetectLogosDETR(
        logger=logger,
        detr_threshold=args.detr_threshold,
    )

    # Load ground truth (both mappings)
    logger.info("Loading ground truth from database...")
    image_to_logos, logo_to_images = get_ground_truth(db_path)
    all_test_images = set(image_to_logos.keys())
    logger.info(f"Loaded ground truth for {len(image_to_logos)} test images")

    # Sample reference logos (with multiple refs per logo)
    logger.info(f"Sampling {args.num_logos} reference logos with {args.refs_per_logo} refs each...")
    sampled_logos = sample_reference_logos(db_path, args.num_logos, args.refs_per_logo, args.seed)
    logger.info(f"Selected {len(sampled_logos)} reference logos")

    # Compute reference embeddings (multiple per logo for multi-ref matching)
    logger.info("Computing reference logo embeddings...")
    # Dict for multi-ref matching: logo_name -> list of embeddings
    multi_ref_embeddings: Dict[str, List[torch.Tensor]] = {}
    # List for margin-based matching: (logo_name, embedding) tuples.
    # Built in parallel with multi_ref_embeddings so both matching
    # methods see the exact same set of references.
    reference_embeddings: List[Tuple[str, torch.Tensor]] = []
    total_refs = 0

    for logo_name, ref_filenames in tqdm(sampled_logos.items(), desc="Reference logos"):
        multi_ref_embeddings[logo_name] = []

        for ref_filename in ref_filenames:
            ref_path = reference_dir / ref_filename

            if not ref_path.exists():
                logger.warning(f"Reference logo not found: {ref_path}")
                continue

            # Check cache (keys are namespaced "ref:" vs "det:" below)
            cache_key = f"ref:{ref_filename}"
            embedding = cache.get(cache_key) if cache else None

            if embedding is None:
                img = load_image(ref_path)
                if img is None:
                    logger.warning(f"Failed to load reference logo: {ref_path}")
                    continue
                embedding = detector.get_embedding(img)
                if cache:
                    cache.put(cache_key, embedding)

            multi_ref_embeddings[logo_name].append(embedding)
            reference_embeddings.append((logo_name, embedding))
            total_refs += 1

    logger.info(f"Computed {total_refs} embeddings for {len(sampled_logos)} logos")

    # Build test set: for each logo, sample positive and negative images
    logger.info(f"Sampling test images: {args.positive_samples} positive, {args.negative_samples} negative per logo...")
    test_image_set: Set[str] = set()
    test_image_expected: Dict[str, Set[str]] = {}  # image -> logos it should match

    # Use sampled_logos keys (unique logo names) instead of reference_embeddings
    # (which contains one entry per reference image, i.e. duplicates per logo).
    for logo_name in sampled_logos.keys():
        # Get positive images (contain this logo)
        positive_images = list(logo_to_images.get(logo_name, set()))
        if len(positive_images) > args.positive_samples:
            positive_images = random.sample(positive_images, args.positive_samples)

        # Get negative images (do NOT contain this logo)
        negative_pool = list(all_test_images - logo_to_images.get(logo_name, set()))
        if len(negative_pool) > args.negative_samples:
            negative_images = random.sample(negative_pool, args.negative_samples)
        else:
            negative_images = negative_pool

        # Add to test set. An image may appear as a positive for one logo
        # and a negative for another; the expected-logo sets accumulate.
        for img in positive_images:
            test_image_set.add(img)
            if img not in test_image_expected:
                test_image_expected[img] = set()
            test_image_expected[img].add(logo_name)

        for img in negative_images:
            test_image_set.add(img)
            if img not in test_image_expected:
                test_image_expected[img] = set()
            # Don't add logo_name - this is a negative sample

    test_images = list(test_image_set)
    logger.info(f"Selected {len(test_images)} unique test images")

    # Get set of reference logo names for quick lookup
    # NOTE(review): reference_logo_names is assigned but never read below.
    reference_logo_names = set(sampled_logos.keys())

    # Metrics
    true_positives = 0  # Correctly matched logos
    false_positives = 0  # Matched but wrong logo or no logo present
    false_negatives = 0  # Logo present but not detected/matched
    total_expected = 0  # Total logos we should have found

    # Detailed results for analysis
    results = []

    # Process test images
    for test_filename in tqdm(test_images, desc="Testing"):
        test_path = test_images_dir / test_filename

        if not test_path.exists():
            logger.warning(f"Test image not found: {test_path}")
            continue

        # Get expected logos for this image (from our sampled set)
        expected_logos = test_image_expected.get(test_filename, set())
        total_expected += len(expected_logos)

        # Check cache for detections
        cache_key = f"det:{test_filename}"
        cached_detections = cache.get(cache_key) if cache else None

        if cached_detections is not None:
            # Cached detections contain serialized box data and embeddings
            detections = cached_detections
        else:
            # Load and detect
            img = load_image(test_path)
            if img is None:
                logger.warning(f"Failed to load test image: {test_path}")
                continue

            detections = detector.detect(img)

            # Cache the detections
            if cache:
                cache.put(cache_key, detections)

        # Match detections against references using selected method
        matched_logos: Set[str] = set()
        for detection in detections:
            match = None
            similarity = None

            if args.matching_method == "margin":
                # Margin-based matching: requires margin over second-best
                match_result = detector.find_best_match_with_margin(
                    detection["embedding"],
                    reference_embeddings,
                    similarity_threshold=args.threshold,
                    margin=args.margin,
                )
                if match_result:
                    label, similarity = match_result
                    match = label
            else:  # multi-ref
                # Multi-ref matching: aggregates scores across reference images;
                # also enforces the same margin over the second-best logo.
                match_result = detector.find_best_match_multi_ref(
                    detection["embedding"],
                    multi_ref_embeddings,
                    similarity_threshold=args.threshold,
                    min_matching_refs=args.min_matching_refs,
                    use_mean_similarity=not args.use_max_similarity,
                    margin=args.margin,
                )
                if match_result:
                    # num_matching is unpacked but not used in the report.
                    label, similarity, num_matching = match_result
                    match = label

            if match:
                matched_logos.add(match)

                # Check if this is a correct match
                if match in expected_logos:
                    true_positives += 1
                else:
                    false_positives += 1

                results.append({
                    "test_image": test_filename,
                    "matched_logo": match,
                    "similarity": similarity,
                    "correct": match in expected_logos,
                })

        # Count missed detections (false negatives)
        missed = expected_logos - matched_logos
        false_negatives += len(missed)

        for missed_logo in missed:
            results.append({
                "test_image": test_filename,
                "matched_logo": None,
                "expected_logo": missed_logo,
                "similarity": None,
                "correct": False,
            })

    # Save cache
    if cache:
        cache.save()
        logger.info(f"Saved {len(cache)} embeddings to cache")

    # Calculate metrics. Recall uses total_expected (not TP + FN) as the
    # denominator; the two agree because every expected-but-unmatched logo
    # was counted as a false negative above.
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / total_expected if total_expected > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    # Print results
    print("\n" + "=" * 60)
    print("LOGO DETECTION TEST RESULTS")
    print("=" * 60)
    print(f"\nConfiguration:")
    print(f" Reference logos sampled: {len(sampled_logos)}")
    print(f" Refs per logo: {args.refs_per_logo}")
    print(f" Total reference embeddings:{total_refs}")
    print(f" Positive samples per logo: {args.positive_samples}")
    print(f" Negative samples per logo: {args.negative_samples}")
    print(f" Test images processed: {len(test_images)}")
    print(f" CLIP similarity threshold: {args.threshold}")
    print(f" DETR confidence threshold: {args.detr_threshold}")
    print(f" Matching method: {args.matching_method}")
    print(f" Matching margin: {args.margin}")
    if args.matching_method == "multi-ref":
        print(f" Min matching refs: {args.min_matching_refs}")
        print(f" Similarity aggregation: {'max' if args.use_max_similarity else 'mean'}")
    if args.seed is not None:
        print(f" Random seed: {args.seed}")

    print(f"\nMetrics:")
    print(f" True Positives (correct matches): {true_positives}")
    print(f" False Positives (wrong matches): {false_positives}")
    print(f" False Negatives (missed logos): {false_negatives}")
    print(f" Total expected matches: {total_expected}")

    print(f"\nScores:")
    print(f" Precision: {precision:.4f} ({precision*100:.1f}%)")
    print(f" Recall: {recall:.4f} ({recall*100:.1f}%)")
    print(f" F1 Score: {f1:.4f} ({f1*100:.1f}%)")

    # Show some example false positives
    false_positive_examples = [r for r in results if r.get("matched_logo") and not r["correct"]]
    if false_positive_examples:
        print(f"\nExample False Positives (first 5):")
        for r in false_positive_examples[:5]:
            print(f" - Image: {r['test_image']}")
            print(f" Matched: {r['matched_logo']} (similarity: {r['similarity']:.3f})")

    # Show reference logos used (unique names), truncated at 20
    unique_logos = sorted(sampled_logos.keys())
    print(f"\nReference logos used ({len(unique_logos)}):")
    for name in unique_logos[:20]:
        print(f" - {name}")
    if len(unique_logos) > 20:
        print(f" ... and {len(unique_logos) - 20} more")

    print("=" * 60)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main() |