Add Burnley logo detection test using DetectLogosEmbeddings

Test script for barnfield and vertu logo detection on Burnley test
images. Uses averaged reference embeddings and margin-based matching.
Ground truth derived from filename prefixes.
This commit is contained in:
Rick McEwen
2026-03-31 11:49:11 -06:00
parent 91d1c9cd59
commit f598866d37

521
test_burnley_detection.py Normal file
View File

@ -0,0 +1,521 @@
#!/usr/bin/env python3
"""
Test script for logo detection accuracy on Burnley test images.
Uses DetectLogosEmbeddings from logo_detection_embeddings.py to detect
barnfield and vertu logos. Ground truth is determined by filename prefix:
- "vertu_" → contains vertu logo
- "barnfield_" → contains barnfield logo
- "barnfield+vertu_" → contains both logos
- anything else → no target logos
"""
import argparse
import logging
import pickle
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple
import cv2
import torch
from tqdm import tqdm
from logo_detection_embeddings import DetectLogosEmbeddings
def setup_logging(verbose: bool = False) -> logging.Logger:
"""Configure logging."""
level = logging.DEBUG if verbose else logging.INFO
logging.basicConfig(
level=level,
format="%(asctime)s - %(levelname)s - %(message)s",
datefmt="%H:%M:%S",
)
return logging.getLogger(__name__)
def load_image(image_path: Path) -> Optional[cv2.Mat]:
"""Load an image using OpenCV."""
img = cv2.imread(str(image_path))
if img is None:
return None
return img
class EmbeddingCache:
"""Simple file-based cache for embeddings."""
def __init__(self, cache_path: Path):
self.cache_path = cache_path
self.cache: Dict[str, Any] = {}
self._load()
def _load(self):
if self.cache_path.exists():
try:
with open(self.cache_path, "rb") as f:
self.cache = pickle.load(f)
except Exception:
self.cache = {}
def save(self):
self.cache_path.parent.mkdir(parents=True, exist_ok=True)
with open(self.cache_path, "wb") as f:
pickle.dump(self.cache, f)
def get(self, key: str):
return self.cache.get(key)
def put(self, key: str, value):
if isinstance(value, torch.Tensor):
self.cache[key] = value.cpu()
else:
self.cache[key] = value
def __len__(self):
return len(self.cache)
def get_expected_logos(filename: str) -> Set[str]:
"""Determine expected logos from filename prefix."""
name = filename.lower()
if name.startswith("barnfield+vertu_"):
return {"barnfield", "vertu"}
elif name.startswith("barnfield_"):
return {"barnfield"}
elif name.startswith("vertu_"):
return {"vertu"}
return set()
def load_reference_images(ref_dir: Path, logger: logging.Logger) -> List[cv2.Mat]:
"""Load all images from a reference directory."""
images = []
for path in sorted(ref_dir.iterdir()):
if path.suffix.lower() in (".jpg", ".jpeg", ".png", ".bmp"):
img = load_image(path)
if img is not None:
images.append(img)
else:
logger.warning(f"Failed to load reference image: {path}")
return images
def main():
parser = argparse.ArgumentParser(
description="Test logo detection on Burnley test images using DetectLogosEmbeddings"
)
parser.add_argument(
"-t", "--threshold",
type=float,
default=0.7,
help="Similarity threshold for matching (default: 0.7)",
)
parser.add_argument(
"-d", "--detr-threshold",
type=float,
default=0.5,
help="DETR detection confidence threshold (default: 0.5)",
)
parser.add_argument(
"-e", "--embedding-model",
type=str,
choices=["clip", "dinov2", "siglip"],
default="dinov2",
help="Embedding model type (default: dinov2)",
)
parser.add_argument(
"--margin",
type=float,
default=0.05,
help="Required margin between best and second-best match (default: 0.05)",
)
parser.add_argument(
"-v", "--verbose",
action="store_true",
help="Enable verbose logging",
)
parser.add_argument(
"--similarity-details",
action="store_true",
help="Output detailed similarity scores for each detection",
)
parser.add_argument(
"--no-cache",
action="store_true",
help="Disable embedding cache",
)
parser.add_argument(
"--clear-cache",
action="store_true",
help="Clear embedding cache before running",
)
parser.add_argument(
"--output-file",
type=str,
default=None,
help="Append results summary to this file",
)
args = parser.parse_args()
logger = setup_logging(args.verbose)
# Paths
base_dir = Path(__file__).resolve().parent
test_images_dir = base_dir / "burnley_test_images"
barnfield_ref_dir = base_dir / "barnfield_reference_images"
vertu_ref_dir = base_dir / "vertu_reference_images"
cache_path = base_dir / ".burnley_embedding_cache.pkl"
# Verify directories exist
for d, name in [(test_images_dir, "Test images"), (barnfield_ref_dir, "Barnfield refs"), (vertu_ref_dir, "Vertu refs")]:
if not d.exists():
logger.error(f"{name} directory not found: {d}")
sys.exit(1)
# Handle cache
if args.clear_cache and cache_path.exists():
cache_path.unlink()
logger.info("Cleared embedding cache")
cache = EmbeddingCache(cache_path) if not args.no_cache else None
if cache:
logger.info(f"Loaded {len(cache)} cached embeddings")
# Initialize detector
logger.info(f"Initializing detector with embedding model: {args.embedding_model}")
detector = DetectLogosEmbeddings(
logger=logger,
detr_threshold=args.detr_threshold,
embedding_model_type=args.embedding_model,
)
# Compute averaged reference embeddings
logger.info("Computing reference embeddings...")
reference_embeddings: Dict[str, torch.Tensor] = {}
for logo_name, ref_dir in [("barnfield", barnfield_ref_dir), ("vertu", vertu_ref_dir)]:
cache_key = f"avg_ref:{logo_name}:{args.embedding_model}"
cached = cache.get(cache_key) if cache else None
if cached is not None:
reference_embeddings[logo_name] = cached
logger.info(f"Loaded cached averaged embedding for {logo_name}")
else:
ref_images = load_reference_images(ref_dir, logger)
logger.info(f"Computing averaged embedding for {logo_name} from {len(ref_images)} images")
avg_emb = detector.get_averaged_embedding(ref_images)
if avg_emb is None:
logger.error(f"Failed to compute embedding for {logo_name}")
sys.exit(1)
reference_embeddings[logo_name] = avg_emb
if cache:
cache.put(cache_key, avg_emb)
# Collect test images
test_files = sorted([
f.name for f in test_images_dir.iterdir()
if f.suffix.lower() in (".jpg", ".jpeg", ".png", ".bmp")
])
logger.info(f"Found {len(test_files)} test images")
# Metrics
true_positives = 0
false_positives = 0
false_negatives = 0
total_expected = 0
results = []
similarity_details = {
"true_positive_sims": [],
"false_positive_sims": [],
"missed_best_sims": [],
"detection_details": [],
}
# Process test images
for test_filename in tqdm(test_files, desc="Testing"):
test_path = test_images_dir / test_filename
expected_logos = get_expected_logos(test_filename)
total_expected += len(expected_logos)
# Check cache for detections
det_cache_key = f"det:{test_filename}:{args.embedding_model}"
cached_detections = cache.get(det_cache_key) if cache else None
if cached_detections is not None:
detections = cached_detections
else:
test_img = load_image(test_path)
if test_img is None:
logger.warning(f"Failed to load test image: {test_path}")
continue
detections = detector.detect(test_img)
if cache:
cache.put(det_cache_key, detections)
# Match each detection against reference embeddings with margin
matched_logos: Set[str] = set()
for det_idx, detection in enumerate(detections):
# Compute similarity to each reference logo
sims: Dict[str, float] = {}
for logo_name, ref_emb in reference_embeddings.items():
sims[logo_name] = detector.compare_embeddings(
detection["embedding"], ref_emb
)
sorted_sims = sorted(sims.items(), key=lambda x: -x[1])
if args.similarity_details:
similarity_details["detection_details"].append({
"image": test_filename,
"detection_idx": det_idx,
"expected_logos": list(expected_logos),
"similarities": sorted_sims,
"detr_score": detection.get("score", 0),
})
# Best match with margin check
if not sorted_sims:
continue
best_name, best_sim = sorted_sims[0]
if best_sim < args.threshold:
continue
# Check margin over second best
if len(sorted_sims) > 1:
second_sim = sorted_sims[1][1]
if best_sim - second_sim < args.margin:
continue
matched_logos.add(best_name)
is_correct = best_name in expected_logos
if is_correct:
true_positives += 1
if args.similarity_details:
similarity_details["true_positive_sims"].append(best_sim)
else:
false_positives += 1
if args.similarity_details:
similarity_details["false_positive_sims"].append(best_sim)
results.append({
"test_image": test_filename,
"matched_logo": best_name,
"similarity": best_sim,
"correct": is_correct,
})
# Count missed detections
missed = expected_logos - matched_logos
false_negatives += len(missed)
for missed_logo in missed:
if args.similarity_details and detections:
best_sim_for_missed = 0
ref_emb = reference_embeddings[missed_logo]
for detection in detections:
sim = detector.compare_embeddings(detection["embedding"], ref_emb)
best_sim_for_missed = max(best_sim_for_missed, sim)
similarity_details["missed_best_sims"].append(best_sim_for_missed)
results.append({
"test_image": test_filename,
"matched_logo": None,
"expected_logo": missed_logo,
"similarity": None,
"correct": False,
})
# Save cache
if cache:
cache.save()
logger.info(f"Saved {len(cache)} embeddings to cache")
# Calculate metrics
precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
recall = true_positives / total_expected if total_expected > 0 else 0
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
# Print results
print("\n" + "=" * 60)
print("BURNLEY LOGO DETECTION TEST RESULTS")
print("=" * 60)
print(f"\nConfiguration:")
print(f" Embedding model: {args.embedding_model}")
print(f" Similarity threshold: {args.threshold}")
print(f" DETR confidence threshold: {args.detr_threshold}")
print(f" Matching margin: {args.margin}")
print(f" Test images processed: {len(test_files)}")
print(f" Reference logos: barnfield, vertu")
print(f"\nMetrics:")
print(f" True Positives (correct matches): {true_positives}")
print(f" False Positives (wrong matches): {false_positives}")
print(f" False Negatives (missed logos): {false_negatives}")
print(f" Total expected matches: {total_expected}")
print(f"\nScores:")
print(f" Precision: {precision:.4f} ({precision*100:.1f}%)")
print(f" Recall: {recall:.4f} ({recall*100:.1f}%)")
print(f" F1 Score: {f1:.4f} ({f1*100:.1f}%)")
# Show false positive examples
false_positive_examples = [r for r in results if r.get("matched_logo") and not r["correct"]]
if false_positive_examples:
print(f"\nExample False Positives (first 5):")
for r in false_positive_examples[:5]:
print(f" - Image: {r['test_image']}")
print(f" Matched: {r['matched_logo']} (similarity: {r['similarity']:.3f})")
# Show false negative examples
false_negative_examples = [r for r in results if r.get("expected_logo")]
if false_negative_examples:
print(f"\nExample False Negatives (first 5):")
for r in false_negative_examples[:5]:
print(f" - Image: {r['test_image']}")
print(f" Expected: {r['expected_logo']}")
print("=" * 60)
# Print similarity details if requested
if args.similarity_details:
print_similarity_details(similarity_details, args.threshold)
# Write results to file if requested
if args.output_file:
write_results_to_file(
output_path=Path(args.output_file),
args=args,
num_test_images=len(test_files),
true_positives=true_positives,
false_positives=false_positives,
false_negatives=false_negatives,
total_expected=total_expected,
precision=precision,
recall=recall,
f1=f1,
)
print(f"\nResults appended to: {args.output_file}")
def print_similarity_details(details: dict, threshold: float):
"""Print detailed similarity distribution analysis."""
import statistics
print("\n" + "=" * 60)
print("SIMILARITY DISTRIBUTION ANALYSIS")
print("=" * 60)
def compute_stats(values, name):
if not values:
print(f"\n{name}: No data")
return
print(f"\n{name} (n={len(values)}):")
print(f" Min: {min(values):.4f}")
print(f" Max: {max(values):.4f}")
print(f" Mean: {statistics.mean(values):.4f}")
if len(values) > 1:
print(f" StdDev: {statistics.stdev(values):.4f}")
print(f" Median: {statistics.median(values):.4f}")
above = sum(1 for v in values if v >= threshold)
below = sum(1 for v in values if v < threshold)
print(f" Above threshold ({threshold}): {above} ({100*above/len(values):.1f}%)")
print(f" Below threshold ({threshold}): {below} ({100*below/len(values):.1f}%)")
compute_stats(details["true_positive_sims"], "TRUE POSITIVE similarities")
compute_stats(details["false_positive_sims"], "FALSE POSITIVE similarities")
compute_stats(details["missed_best_sims"], "MISSED LOGO best similarities")
# Overlap analysis
tp_sims = details["true_positive_sims"]
fp_sims = details["false_positive_sims"]
if tp_sims and fp_sims:
print("\n" + "-" * 40)
print("OVERLAP ANALYSIS:")
tp_min, tp_max = min(tp_sims), max(tp_sims)
fp_min, fp_max = min(fp_sims), max(fp_sims)
print(f" True Positives range: [{tp_min:.4f}, {tp_max:.4f}]")
print(f" False Positives range: [{fp_min:.4f}, {fp_max:.4f}]")
overlap_min = max(tp_min, fp_min)
overlap_max = min(tp_max, fp_max)
if overlap_min < overlap_max:
print(f" OVERLAP REGION: [{overlap_min:.4f}, {overlap_max:.4f}]")
else:
print(" NO OVERLAP - distributions are separable!")
# Sample detection details
det_details = details["detection_details"]
if det_details:
print("\n" + "-" * 40)
print(f"SAMPLE DETECTION DETAILS (first 20 of {len(det_details)}):")
for i, det in enumerate(det_details[:20]):
expected = det["expected_logos"]
sims = det["similarities"]
print(f"\n [{i+1}] Image: {det['image']}")
print(f" Expected: {expected if expected else '(none)'}")
print(f" DETR score: {det['detr_score']:.3f}")
print(f" Similarities:")
for logo, sim in sims:
marker = " <-- CORRECT" if logo in expected else ""
print(f" {sim:.4f} {logo}{marker}")
print("\n" + "=" * 60)
def write_results_to_file(
output_path: Path,
args,
num_test_images: int,
true_positives: int,
false_positives: int,
false_negatives: int,
total_expected: int,
precision: float,
recall: float,
f1: float,
):
"""Write results summary to file."""
from datetime import datetime
lines = [
"=" * 70,
"BURNLEY LOGO DETECTION TEST",
f"Model: {args.embedding_model}",
f"Method: Margin-based (margin={args.margin})",
"=" * 70,
f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
"",
"Configuration:",
f" Embedding model: {args.embedding_model}",
f" Similarity threshold: {args.threshold}",
f" DETR threshold: {args.detr_threshold}",
f" Matching margin: {args.margin}",
f" Test images processed: {num_test_images}",
f" Reference logos: barnfield, vertu",
"",
"Results:",
f" True Positives: {true_positives:>6}",
f" False Positives: {false_positives:>6}",
f" False Negatives: {false_negatives:>6}",
f" Total Expected: {total_expected:>6}",
"",
"Scores:",
f" Precision: {precision:.4f} ({precision*100:.1f}%)",
f" Recall: {recall:.4f} ({recall*100:.1f}%)",
f" F1 Score: {f1:.4f} ({f1*100:.1f}%)",
"",
"",
]
with open(output_path, "a") as f:
f.write("\n".join(lines))
if __name__ == "__main__":
main()