Compare commits
2 Commits
91d1c9cd59
...
512f678310
| Author | SHA1 | Date | |
|---|---|---|---|
| 512f678310 | |||
| f598866d37 |
364
logo_detection_embeddings.py
Normal file
364
logo_detection_embeddings.py
Normal file
@ -0,0 +1,364 @@
|
||||
"""
|
||||
Logo detection using DETR for object detection and selectable embedding models for feature matching.
|
||||
|
||||
This module provides a class for detecting logos in images using:
|
||||
1. DETR (DEtection TRansformer) for initial logo region detection
|
||||
2. Selectable embedding model (CLIP, DINOv2, or SigLIP) for feature extraction and matching
|
||||
|
||||
Key features:
|
||||
- Multiple reference images per logo entry, averaged into a single embedding
|
||||
- Cache-aware: averaged embeddings are only recalculated when the filenames list changes
|
||||
- Supports local model directories with fallback to HuggingFace
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from transformers import (
|
||||
AutoImageProcessor,
|
||||
AutoModel,
|
||||
AutoProcessor,
|
||||
CLIPModel,
|
||||
CLIPProcessor,
|
||||
Dinov2Model,
|
||||
pipeline,
|
||||
)
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
class DetectLogosEmbeddings:
    """
    Logo detection class using DETR and a selectable embedding model.

    This class detects logos in images by:
    1. Using DETR to find potential logo regions (bounding boxes)
    2. Extracting embeddings for each detected region using the selected model
    3. Comparing embeddings with averaged reference logo embeddings for identification

    Supported embedding models:
    - clip: openai/clip-vit-large-patch14
    - dinov2: facebook/dinov2-base (recommended for visual similarity)
    - siglip: google/siglip-base-patch16-224
    """

    def __init__(
        self,
        logger,
        detr_model: str = "Pravallika6/detr-finetuned-logo-detection_v2",
        embedding_model_type: str = "dinov2",
        detr_threshold: float = 0.5,
    ):
        """
        Initialize DETR and embedding models.

        Args:
            logger: Logger instance for logging
            detr_model: HuggingFace model name or local path for DETR object detection
            embedding_model_type: One of "clip", "dinov2", or "siglip"
            detr_threshold: Confidence threshold for DETR detections (0-1)

        Raises:
            ValueError: If embedding_model_type is not a supported value.
        """
        self.logger = logger
        self.detr_threshold = detr_threshold
        self.embedding_model_type = embedding_model_type

        # Query CUDA availability once; derive both the torch device and the
        # integer device index expected by transformers' pipeline() API.
        cuda_available = torch.cuda.is_available()
        self.device_str = "cuda:0" if cuda_available else "cpu"
        self.device_index = 0 if cuda_available else -1
        self.device = torch.device(self.device_str)

        self.logger.info(
            f"Initializing DetectLogosEmbeddings on device: {self.device_str}, "
            f"embedding model: {embedding_model_type}"
        )

        # --- DETR model ---
        default_detr_dir = os.environ.get(
            "LOGO_DETR_MODEL_DIR", "models/logo_detection/detr"
        )
        detr_model_path = self._resolve_model_path(detr_model, default_detr_dir, "DETR")

        self.logger.info(f"Loading DETR model: {detr_model_path}")
        self.detr_pipe = pipeline(
            task="object-detection",
            model=detr_model_path,
            device=self.device_index,
            use_fast=True,
        )

        # --- Embedding model ---
        self._load_embedding_model(embedding_model_type)

        self.logger.info("DetectLogosEmbeddings initialization complete")

    def _load_embedding_model(self, model_type: str) -> None:
        """
        Load the selected embedding model and install `self._embed_fn`.

        The installed callable maps a PIL image to an L2-normalized
        feature tensor of shape (1, D) on `self.device`.

        Args:
            model_type: One of "clip", "dinov2", or "siglip"

        Raises:
            ValueError: If model_type is not recognized.
        """
        default_embedding_dir = os.environ.get(
            "LOGO_EMBEDDING_MODEL_DIR", f"models/logo_detection/{model_type}"
        )

        if model_type == "clip":
            model_name = "openai/clip-vit-large-patch14"
            model_path = self._resolve_model_path(
                model_name, default_embedding_dir, "CLIP"
            )
            self.logger.info(f"Loading CLIP model: {model_path}")
            self._clip_model = CLIPModel.from_pretrained(model_path).to(self.device)
            self._clip_processor = CLIPProcessor.from_pretrained(model_path)
            self._clip_model.eval()

            def embed_fn(pil_image):
                # CLIP exposes a dedicated image-feature head.
                inputs = self._clip_processor(
                    images=pil_image, return_tensors="pt"
                ).to(self.device)
                with torch.no_grad():
                    features = self._clip_model.get_image_features(**inputs)
                return F.normalize(features, dim=-1)

        elif model_type == "dinov2":
            model_name = "facebook/dinov2-base"
            model_path = self._resolve_model_path(
                model_name, default_embedding_dir, "DINOv2"
            )
            self.logger.info(f"Loading DINOv2 model: {model_path}")
            self._dinov2_model = Dinov2Model.from_pretrained(model_path).to(self.device)
            self._dinov2_processor = AutoImageProcessor.from_pretrained(model_path)
            self._dinov2_model.eval()

            def embed_fn(pil_image):
                inputs = self._dinov2_processor(
                    images=pil_image, return_tensors="pt"
                ).to(self.device)
                with torch.no_grad():
                    outputs = self._dinov2_model(**inputs)
                # DINOv2 has no pooled-feature head; use the CLS token embedding.
                features = outputs.last_hidden_state[:, 0, :]
                return F.normalize(features, dim=-1)

        elif model_type == "siglip":
            model_name = "google/siglip-base-patch16-224"
            model_path = self._resolve_model_path(
                model_name, default_embedding_dir, "SigLIP"
            )
            self.logger.info(f"Loading SigLIP model: {model_path}")
            self._siglip_model = AutoModel.from_pretrained(model_path).to(self.device)
            self._siglip_processor = AutoProcessor.from_pretrained(model_path)
            self._siglip_model.eval()

            def embed_fn(pil_image):
                inputs = self._siglip_processor(
                    images=pil_image, return_tensors="pt"
                ).to(self.device)
                with torch.no_grad():
                    features = self._siglip_model.get_image_features(**inputs)
                return F.normalize(features, dim=-1)

        else:
            raise ValueError(
                f"Unknown embedding model type: {model_type}. "
                f"Use 'clip', 'dinov2', or 'siglip'"
            )

        self._embed_fn = embed_fn

    def _resolve_model_path(
        self, model_name_or_path: str, default_local_dir: str, model_type: str
    ) -> str:
        """
        Resolve model path, checking for local models before using HuggingFace.

        Args:
            model_name_or_path: HuggingFace model name or absolute path
            default_local_dir: Default local directory to check
            model_type: Type of model (for logging)

        Returns:
            Resolved model path (local path or HuggingFace model name)
        """
        # If it's an absolute path, use it directly. Note that from_pretrained()
        # does NOT fall back to the HuggingFace Hub for an absolute path, so a
        # missing path is only warned about here and will fail downstream.
        if os.path.isabs(model_name_or_path):
            if os.path.exists(model_name_or_path):
                self.logger.info(
                    f"{model_type} model: Using local model at {model_name_or_path}"
                )
            else:
                self.logger.warning(
                    f"{model_type} model: Local path {model_name_or_path} does not exist; "
                    f"model loading will likely fail"
                )
            return model_name_or_path

        # Check if default local directory holds a valid model snapshot
        # (config.json is the minimal marker of a transformers checkpoint).
        if os.path.exists(default_local_dir):
            config_file = os.path.join(default_local_dir, "config.json")
            if os.path.exists(config_file):
                abs_path = os.path.abspath(default_local_dir)
                self.logger.info(
                    f"{model_type} model: Found local model at {abs_path}"
                )
                return abs_path
            else:
                self.logger.warning(
                    f"{model_type} model: Local directory {default_local_dir} exists but "
                    f"is not a valid model (missing config.json)"
                )

        # Fall back to the HuggingFace model name (downloaded on demand).
        self.logger.info(
            f"{model_type} model: No local model found, will download from HuggingFace: "
            f"{model_name_or_path}"
        )
        return model_name_or_path

    def detect(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Detect logos in an image and return bounding boxes with embeddings.

        Args:
            image: OpenCV image (BGR format, numpy array)

        Returns:
            List of dictionaries, each containing:
            - 'box': dict with 'xmin', 'ymin', 'xmax', 'ymax' (pixel coordinates)
            - 'score': DETR confidence score (float 0-1)
            - 'embedding': Feature embedding (torch.Tensor)
            - 'label': DETR predicted label (string)
        """
        # Convert OpenCV BGR to RGB PIL Image (the pipeline expects RGB).
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(image_rgb)

        # Run DETR detection
        predictions = self.detr_pipe(pil_image)

        # Filter by threshold and add embeddings
        detections = []
        for pred in predictions:
            score = pred.get("score", 0.0)
            if score < self.detr_threshold:
                continue

            box = pred.get("box", {})
            xmin = box.get("xmin", 0)
            ymin = box.get("ymin", 0)
            xmax = box.get("xmax", 0)
            ymax = box.get("ymax", 0)

            # Extract bounding box region
            bbox_crop = pil_image.crop((xmin, ymin, xmax, ymax))

            # Get embedding for this region
            embedding = self._embed_fn(bbox_crop)

            detections.append(
                {
                    "box": {"xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax},
                    "score": score,
                    "embedding": embedding,
                    "label": pred.get("label", "logo"),
                }
            )

        self.logger.debug(
            f"Detected {len(detections)} logos (threshold: {self.detr_threshold})"
        )
        return detections

    def get_embedding(self, image: np.ndarray) -> torch.Tensor:
        """
        Get embedding for a single reference logo image.

        Args:
            image: OpenCV image (BGR format, numpy array)

        Returns:
            Normalized feature embedding (torch.Tensor)
        """
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(image_rgb)
        return self._embed_fn(pil_image)

    def get_averaged_embedding(self, images: List[np.ndarray]) -> Optional[torch.Tensor]:
        """
        Compute averaged embedding from multiple reference logo images.

        Follows the averaging pattern from db_embeddings.py:
        1. Compute embedding for each image
        2. Stack and average across all images
        3. Re-normalize the averaged embedding

        Args:
            images: List of OpenCV images (BGR format, numpy arrays)

        Returns:
            Normalized averaged embedding (torch.Tensor, shape [1, D]),
            or None if no valid embeddings could be computed
        """
        embeddings = []
        for img in images:
            try:
                emb = self.get_embedding(img)
                embeddings.append(emb)
            except Exception as e:
                # Best-effort: a single bad reference image should not abort
                # the whole average; it is logged and skipped.
                self.logger.warning(f"Failed to compute embedding for reference image: {e}")

        if not embeddings:
            return None

        # Stack: (N, D), average: (1, D), re-normalize so cosine similarity
        # against detection embeddings stays well-defined.
        stacked = torch.cat(embeddings, dim=0)
        avg_emb = stacked.mean(dim=0, keepdim=True)
        avg_emb = F.normalize(avg_emb, dim=-1)

        self.logger.debug(
            f"Computed averaged embedding from {len(embeddings)} reference image(s)"
        )
        return avg_emb

    def compare_embeddings(
        self, embedding1: torch.Tensor, embedding2: torch.Tensor
    ) -> float:
        """
        Compute cosine similarity between two embeddings.

        Args:
            embedding1: First embedding (torch.Tensor)
            embedding2: Second embedding (torch.Tensor)

        Returns:
            Cosine similarity score (float, range: -1 to 1, typically 0 to 1)
        """
        # Ensure tensors are on the same device (cached embeddings may be CPU).
        if embedding1.device != embedding2.device:
            embedding2 = embedding2.to(embedding1.device)

        similarity = F.cosine_similarity(embedding1, embedding2, dim=-1)
        return similarity.item()

    @staticmethod
    def make_filenames_hash(filenames: List[str]) -> str:
        """
        Compute a deterministic hash of a filenames list.

        Used for cache invalidation - if the filenames list changes,
        the hash changes, triggering re-computation of averaged embeddings.
        The list is sorted first, so the hash is order-independent.

        Args:
            filenames: List of filename strings

        Returns:
            16-character hex hash string
        """
        canonical = json.dumps(sorted(filenames))
        return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16]
|
||||
521
test_burnley_detection.py
Normal file
521
test_burnley_detection.py
Normal file
@ -0,0 +1,521 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for logo detection accuracy on Burnley test images.
|
||||
|
||||
Uses DetectLogosEmbeddings from logo_detection_embeddings.py to detect
|
||||
barnfield and vertu logos. Ground truth is determined by filename prefix:
|
||||
- "vertu_" → contains vertu logo
|
||||
- "barnfield_" → contains barnfield logo
|
||||
- "barnfield+vertu_" → contains both logos
|
||||
- anything else → no target logos
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import pickle
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from logo_detection_embeddings import DetectLogosEmbeddings
|
||||
|
||||
|
||||
def setup_logging(verbose: bool = False) -> logging.Logger:
    """Configure root logging and return this module's logger.

    Args:
        verbose: When True, log at DEBUG level; otherwise INFO.

    Returns:
        The logger named after this module.
    """
    logging.basicConfig(
        level=logging.DEBUG if verbose else logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%H:%M:%S",
    )
    return logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_image(image_path: Path) -> Optional[cv2.Mat]:
    """Read an image from disk with OpenCV.

    Returns the decoded BGR image, or None when the file cannot be read
    (cv2.imread already yields None on failure, so no extra check is needed).
    """
    return cv2.imread(str(image_path))
|
||||
|
||||
|
||||
class EmbeddingCache:
    """Pickle-backed key/value cache persisted at a single file path."""

    def __init__(self, cache_path: Path):
        self.cache_path = cache_path
        self.cache: Dict[str, Any] = {}
        self._load()

    def _load(self):
        """Populate the in-memory dict from disk; unreadable files mean a cold start."""
        if not self.cache_path.exists():
            return
        try:
            with open(self.cache_path, "rb") as fh:
                self.cache = pickle.load(fh)
        except Exception:
            # A corrupt or incompatible cache file is treated as empty.
            self.cache = {}

    def save(self):
        """Persist the cache to disk, creating parent directories as needed."""
        self.cache_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.cache_path, "wb") as fh:
            pickle.dump(self.cache, fh)

    def get(self, key: str):
        """Return the cached value for *key*, or None if absent."""
        return self.cache.get(key)

    def put(self, key: str, value):
        """Store *value* under *key*; tensors are moved to CPU so the pickle is device-independent."""
        self.cache[key] = value.cpu() if isinstance(value, torch.Tensor) else value

    def __len__(self):
        return len(self.cache)
|
||||
|
||||
|
||||
def get_expected_logos(filename: str) -> Set[str]:
    """Derive the ground-truth logo set from a test image's filename prefix.

    Prefixes (case-insensitive): "barnfield+vertu_" -> both logos,
    "barnfield_" -> barnfield, "vertu_" -> vertu; anything else -> empty set.
    """
    lowered = filename.lower()
    prefix_table = (
        ("barnfield+vertu_", {"barnfield", "vertu"}),
        ("barnfield_", {"barnfield"}),
        ("vertu_", {"vertu"}),
    )
    for prefix, logos in prefix_table:
        if lowered.startswith(prefix):
            # Return a copy so callers can safely mutate the result.
            return set(logos)
    return set()
|
||||
|
||||
|
||||
def load_reference_images(ref_dir: Path, logger: logging.Logger) -> List[cv2.Mat]:
    """Load every supported image file from *ref_dir*, sorted by path.

    Unreadable files are logged as warnings and skipped; non-image
    suffixes are ignored entirely.
    """
    supported = (".jpg", ".jpeg", ".png", ".bmp")
    loaded: List[cv2.Mat] = []
    for path in sorted(ref_dir.iterdir()):
        if path.suffix.lower() not in supported:
            continue
        img = load_image(path)
        if img is None:
            logger.warning(f"Failed to load reference image: {path}")
        else:
            loaded.append(img)
    return loaded
|
||||
|
||||
|
||||
def main():
    """Run the Burnley logo-detection accuracy test end to end.

    Parses CLI options, loads the detector and averaged reference
    embeddings (cache-aware), scores every test image against the
    references with a threshold + margin rule, then prints and
    optionally appends a precision/recall/F1 summary.
    """
    parser = argparse.ArgumentParser(
        description="Test logo detection on Burnley test images using DetectLogosEmbeddings"
    )
    parser.add_argument(
        "-t", "--threshold",
        type=float,
        default=0.7,
        help="Similarity threshold for matching (default: 0.7)",
    )
    parser.add_argument(
        "-d", "--detr-threshold",
        type=float,
        default=0.5,
        help="DETR detection confidence threshold (default: 0.5)",
    )
    parser.add_argument(
        "-e", "--embedding-model",
        type=str,
        choices=["clip", "dinov2", "siglip"],
        default="dinov2",
        help="Embedding model type (default: dinov2)",
    )
    parser.add_argument(
        "--margin",
        type=float,
        default=0.05,
        help="Required margin between best and second-best match (default: 0.05)",
    )
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Enable verbose logging",
    )
    parser.add_argument(
        "--similarity-details",
        action="store_true",
        help="Output detailed similarity scores for each detection",
    )
    parser.add_argument(
        "--no-cache",
        action="store_true",
        help="Disable embedding cache",
    )
    parser.add_argument(
        "--clear-cache",
        action="store_true",
        help="Clear embedding cache before running",
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default=None,
        help="Append results summary to this file",
    )

    args = parser.parse_args()
    logger = setup_logging(args.verbose)

    # Paths — all fixtures live next to this script.
    base_dir = Path(__file__).resolve().parent
    test_images_dir = base_dir / "burnley_test_images"
    barnfield_ref_dir = base_dir / "barnfield_reference_images"
    vertu_ref_dir = base_dir / "vertu_reference_images"
    cache_path = base_dir / ".burnley_embedding_cache.pkl"

    # Verify directories exist before doing any expensive model loading.
    for d, name in [(test_images_dir, "Test images"), (barnfield_ref_dir, "Barnfield refs"), (vertu_ref_dir, "Vertu refs")]:
        if not d.exists():
            logger.error(f"{name} directory not found: {d}")
            sys.exit(1)

    # Handle cache
    if args.clear_cache and cache_path.exists():
        cache_path.unlink()
        logger.info("Cleared embedding cache")

    # cache is None when caching is disabled; every use below guards on that.
    cache = EmbeddingCache(cache_path) if not args.no_cache else None
    if cache:
        logger.info(f"Loaded {len(cache)} cached embeddings")

    # Initialize detector (loads DETR + the selected embedding model).
    logger.info(f"Initializing detector with embedding model: {args.embedding_model}")
    detector = DetectLogosEmbeddings(
        logger=logger,
        detr_threshold=args.detr_threshold,
        embedding_model_type=args.embedding_model,
    )

    # Compute averaged reference embeddings (one per logo), reusing the
    # cache when a matching model-specific entry exists.
    logger.info("Computing reference embeddings...")

    reference_embeddings: Dict[str, torch.Tensor] = {}
    for logo_name, ref_dir in [("barnfield", barnfield_ref_dir), ("vertu", vertu_ref_dir)]:
        # Cache key includes the embedding model so switching models
        # does not reuse incompatible vectors.
        cache_key = f"avg_ref:{logo_name}:{args.embedding_model}"
        cached = cache.get(cache_key) if cache else None

        if cached is not None:
            reference_embeddings[logo_name] = cached
            logger.info(f"Loaded cached averaged embedding for {logo_name}")
        else:
            ref_images = load_reference_images(ref_dir, logger)
            logger.info(f"Computing averaged embedding for {logo_name} from {len(ref_images)} images")
            avg_emb = detector.get_averaged_embedding(ref_images)
            if avg_emb is None:
                logger.error(f"Failed to compute embedding for {logo_name}")
                sys.exit(1)
            reference_embeddings[logo_name] = avg_emb
            if cache:
                cache.put(cache_key, avg_emb)

    # Collect test images (filenames only; loaded lazily per iteration).
    test_files = sorted([
        f.name for f in test_images_dir.iterdir()
        if f.suffix.lower() in (".jpg", ".jpeg", ".png", ".bmp")
    ])
    logger.info(f"Found {len(test_files)} test images")

    # Metrics accumulators
    true_positives = 0
    false_positives = 0
    false_negatives = 0
    total_expected = 0
    results = []

    # Only populated when --similarity-details is set.
    similarity_details = {
        "true_positive_sims": [],
        "false_positive_sims": [],
        "missed_best_sims": [],
        "detection_details": [],
    }

    # Process test images
    for test_filename in tqdm(test_files, desc="Testing"):
        test_path = test_images_dir / test_filename
        expected_logos = get_expected_logos(test_filename)
        total_expected += len(expected_logos)

        # Check cache for detections (keyed by filename + embedding model).
        det_cache_key = f"det:{test_filename}:{args.embedding_model}"
        cached_detections = cache.get(det_cache_key) if cache else None

        if cached_detections is not None:
            detections = cached_detections
        else:
            test_img = load_image(test_path)
            if test_img is None:
                logger.warning(f"Failed to load test image: {test_path}")
                continue
            detections = detector.detect(test_img)
            if cache:
                cache.put(det_cache_key, detections)

        # Match each detection against reference embeddings with margin
        matched_logos: Set[str] = set()
        for det_idx, detection in enumerate(detections):
            # Compute similarity to each reference logo
            sims: Dict[str, float] = {}
            for logo_name, ref_emb in reference_embeddings.items():
                sims[logo_name] = detector.compare_embeddings(
                    detection["embedding"], ref_emb
                )

            # Sort descending by similarity.
            sorted_sims = sorted(sims.items(), key=lambda x: -x[1])

            if args.similarity_details:
                similarity_details["detection_details"].append({
                    "image": test_filename,
                    "detection_idx": det_idx,
                    "expected_logos": list(expected_logos),
                    "similarities": sorted_sims,
                    "detr_score": detection.get("score", 0),
                })

            # Best match with margin check
            if not sorted_sims:
                continue

            best_name, best_sim = sorted_sims[0]
            if best_sim < args.threshold:
                continue

            # Require the best match to beat the runner-up by args.margin,
            # rejecting ambiguous detections.
            if len(sorted_sims) > 1:
                second_sim = sorted_sims[1][1]
                if best_sim - second_sim < args.margin:
                    continue

            matched_logos.add(best_name)
            is_correct = best_name in expected_logos

            if is_correct:
                true_positives += 1
                if args.similarity_details:
                    similarity_details["true_positive_sims"].append(best_sim)
            else:
                false_positives += 1
                if args.similarity_details:
                    similarity_details["false_positive_sims"].append(best_sim)

            results.append({
                "test_image": test_filename,
                "matched_logo": best_name,
                "similarity": best_sim,
                "correct": is_correct,
            })

        # Count missed detections (expected logos never matched in this image).
        missed = expected_logos - matched_logos
        false_negatives += len(missed)

        for missed_logo in missed:
            if args.similarity_details and detections:
                # Record the closest any detection got to the missed logo,
                # to show how far below threshold/margin it fell.
                best_sim_for_missed = 0
                ref_emb = reference_embeddings[missed_logo]
                for detection in detections:
                    sim = detector.compare_embeddings(detection["embedding"], ref_emb)
                    best_sim_for_missed = max(best_sim_for_missed, sim)
                similarity_details["missed_best_sims"].append(best_sim_for_missed)

            results.append({
                "test_image": test_filename,
                "matched_logo": None,
                "expected_logo": missed_logo,
                "similarity": None,
                "correct": False,
            })

    # Save cache
    if cache:
        cache.save()
        logger.info(f"Saved {len(cache)} embeddings to cache")

    # Calculate metrics (guarding each denominator against zero).
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / total_expected if total_expected > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    # Print results
    print("\n" + "=" * 60)
    print("BURNLEY LOGO DETECTION TEST RESULTS")
    print("=" * 60)
    print(f"\nConfiguration:")
    print(f"  Embedding model: {args.embedding_model}")
    print(f"  Similarity threshold: {args.threshold}")
    print(f"  DETR confidence threshold: {args.detr_threshold}")
    print(f"  Matching margin: {args.margin}")
    print(f"  Test images processed: {len(test_files)}")
    print(f"  Reference logos: barnfield, vertu")

    print(f"\nMetrics:")
    print(f"  True Positives (correct matches): {true_positives}")
    print(f"  False Positives (wrong matches): {false_positives}")
    print(f"  False Negatives (missed logos): {false_negatives}")
    print(f"  Total expected matches: {total_expected}")

    print(f"\nScores:")
    print(f"  Precision: {precision:.4f} ({precision*100:.1f}%)")
    print(f"  Recall: {recall:.4f} ({recall*100:.1f}%)")
    print(f"  F1 Score: {f1:.4f} ({f1*100:.1f}%)")

    # Show false positive examples
    false_positive_examples = [r for r in results if r.get("matched_logo") and not r["correct"]]
    if false_positive_examples:
        print(f"\nExample False Positives (first 5):")
        for r in false_positive_examples[:5]:
            print(f"  - Image: {r['test_image']}")
            print(f"    Matched: {r['matched_logo']} (similarity: {r['similarity']:.3f})")

    # Show false negative examples
    false_negative_examples = [r for r in results if r.get("expected_logo")]
    if false_negative_examples:
        print(f"\nExample False Negatives (first 5):")
        for r in false_negative_examples[:5]:
            print(f"  - Image: {r['test_image']}")
            print(f"    Expected: {r['expected_logo']}")

    print("=" * 60)

    # Print similarity details if requested
    if args.similarity_details:
        print_similarity_details(similarity_details, args.threshold)

    # Write results to file if requested
    if args.output_file:
        write_results_to_file(
            output_path=Path(args.output_file),
            args=args,
            num_test_images=len(test_files),
            true_positives=true_positives,
            false_positives=false_positives,
            false_negatives=false_negatives,
            total_expected=total_expected,
            precision=precision,
            recall=recall,
            f1=f1,
        )
        print(f"\nResults appended to: {args.output_file}")
|
||||
|
||||
|
||||
def print_similarity_details(details: dict, threshold: float):
    """Print detailed similarity distribution analysis.

    Args:
        details: Dict with keys "true_positive_sims", "false_positive_sims",
            "missed_best_sims" (lists of floats) and "detection_details"
            (list of per-detection dicts), as built by main().
        threshold: Similarity threshold used for the above/below breakdown.
    """
    import statistics

    print("\n" + "=" * 60)
    print("SIMILARITY DISTRIBUTION ANALYSIS")
    print("=" * 60)

    def compute_stats(values, name):
        # Print min/max/mean/median (stdev only when defined, i.e. n > 1)
        # plus counts above/below the matching threshold.
        if not values:
            print(f"\n{name}: No data")
            return
        print(f"\n{name} (n={len(values)}):")
        print(f"  Min:    {min(values):.4f}")
        print(f"  Max:    {max(values):.4f}")
        print(f"  Mean:   {statistics.mean(values):.4f}")
        if len(values) > 1:
            print(f"  StdDev: {statistics.stdev(values):.4f}")
        print(f"  Median: {statistics.median(values):.4f}")

        above = sum(1 for v in values if v >= threshold)
        below = sum(1 for v in values if v < threshold)
        print(f"  Above threshold ({threshold}): {above} ({100*above/len(values):.1f}%)")
        print(f"  Below threshold ({threshold}): {below} ({100*below/len(values):.1f}%)")

    compute_stats(details["true_positive_sims"], "TRUE POSITIVE similarities")
    compute_stats(details["false_positive_sims"], "FALSE POSITIVE similarities")
    compute_stats(details["missed_best_sims"], "MISSED LOGO best similarities")

    # Overlap analysis: if TP and FP similarity ranges do not overlap,
    # a threshold exists that separates them perfectly.
    tp_sims = details["true_positive_sims"]
    fp_sims = details["false_positive_sims"]
    if tp_sims and fp_sims:
        print("\n" + "-" * 40)
        print("OVERLAP ANALYSIS:")
        tp_min, tp_max = min(tp_sims), max(tp_sims)
        fp_min, fp_max = min(fp_sims), max(fp_sims)
        print(f"  True Positives range:  [{tp_min:.4f}, {tp_max:.4f}]")
        print(f"  False Positives range: [{fp_min:.4f}, {fp_max:.4f}]")

        overlap_min = max(tp_min, fp_min)
        overlap_max = min(tp_max, fp_max)
        if overlap_min < overlap_max:
            print(f"  OVERLAP REGION: [{overlap_min:.4f}, {overlap_max:.4f}]")
        else:
            print("  NO OVERLAP - distributions are separable!")

    # Sample detection details (capped at 20 for readability).
    det_details = details["detection_details"]
    if det_details:
        print("\n" + "-" * 40)
        print(f"SAMPLE DETECTION DETAILS (first 20 of {len(det_details)}):")
        for i, det in enumerate(det_details[:20]):
            expected = det["expected_logos"]
            sims = det["similarities"]
            print(f"\n  [{i+1}] Image: {det['image']}")
            print(f"      Expected: {expected if expected else '(none)'}")
            print(f"      DETR score: {det['detr_score']:.3f}")
            print(f"      Similarities:")
            for logo, sim in sims:
                marker = " <-- CORRECT" if logo in expected else ""
                print(f"        {sim:.4f} {logo}{marker}")

    print("\n" + "=" * 60)
|
||||
|
||||
|
||||
def write_results_to_file(
    output_path: Path,
    args,
    num_test_images: int,
    true_positives: int,
    false_positives: int,
    false_negatives: int,
    total_expected: int,
    precision: float,
    recall: float,
    f1: float,
):
    """Append a formatted results summary block to *output_path*.

    Args:
        output_path: File to append to (created if missing).
        args: Parsed CLI namespace; embedding_model / margin / threshold /
            detr_threshold are read for the configuration section.
        num_test_images: Number of test images processed.
        true_positives / false_positives / false_negatives: Match counts.
        total_expected: Total number of expected logo matches.
        precision / recall / f1: Final scores.
    """
    from datetime import datetime

    lines = [
        "=" * 70,
        "BURNLEY LOGO DETECTION TEST",
        f"Model: {args.embedding_model}",
        f"Method: Margin-based (margin={args.margin})",
        "=" * 70,
        f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        "",
        "Configuration:",
        f"  Embedding model: {args.embedding_model}",
        f"  Similarity threshold: {args.threshold}",
        f"  DETR threshold: {args.detr_threshold}",
        f"  Matching margin: {args.margin}",
        f"  Test images processed: {num_test_images}",
        f"  Reference logos: barnfield, vertu",
        "",
        "Results:",
        f"  True Positives:  {true_positives:>6}",
        f"  False Positives: {false_positives:>6}",
        f"  False Negatives: {false_negatives:>6}",
        f"  Total Expected:  {total_expected:>6}",
        "",
        "Scores:",
        f"  Precision: {precision:.4f} ({precision*100:.1f}%)",
        f"  Recall:    {recall:.4f} ({recall*100:.1f}%)",
        f"  F1 Score:  {f1:.4f} ({f1*100:.1f}%)",
        "",
        "",
    ]

    # Append mode so repeated runs accumulate a history in one file.
    with open(output_path, "a") as f:
        f.write("\n".join(lines))
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user