Compare commits
2 Commits
91d1c9cd59
...
512f678310
| Author | SHA1 | Date | |
|---|---|---|---|
| 512f678310 | |||
| f598866d37 |
364
logo_detection_embeddings.py
Normal file
364
logo_detection_embeddings.py
Normal file
@ -0,0 +1,364 @@
|
|||||||
|
"""
|
||||||
|
Logo detection using DETR for object detection and selectable embedding models for feature matching.
|
||||||
|
|
||||||
|
This module provides a class for detecting logos in images using:
|
||||||
|
1. DETR (DEtection TRansformer) for initial logo region detection
|
||||||
|
2. Selectable embedding model (CLIP, DINOv2, or SigLIP) for feature extraction and matching
|
||||||
|
|
||||||
|
Key features:
|
||||||
|
- Multiple reference images per logo entry, averaged into a single embedding
|
||||||
|
- Cache-aware: averaged embeddings are only recalculated when the filenames list changes
|
||||||
|
- Supports local model directories with fallback to HuggingFace
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import (
|
||||||
|
AutoImageProcessor,
|
||||||
|
AutoModel,
|
||||||
|
AutoProcessor,
|
||||||
|
CLIPModel,
|
||||||
|
CLIPProcessor,
|
||||||
|
Dinov2Model,
|
||||||
|
pipeline,
|
||||||
|
)
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class DetectLogosEmbeddings:
    """
    Logo detection class using DETR and a selectable embedding model.

    This class detects logos in images by:
    1. Using DETR to find potential logo regions (bounding boxes)
    2. Extracting embeddings for each detected region using the selected model
    3. Comparing embeddings with averaged reference logo embeddings for identification

    Supported embedding models:
    - clip: openai/clip-vit-large-patch14
    - dinov2: facebook/dinov2-base (recommended for visual similarity)
    - siglip: google/siglip-base-patch16-224
    """

    # Values accepted for ``embedding_model_type``.
    _SUPPORTED_EMBEDDING_MODELS = ("clip", "dinov2", "siglip")

    def __init__(
        self,
        logger,
        detr_model: str = "Pravallika6/detr-finetuned-logo-detection_v2",
        embedding_model_type: str = "dinov2",
        detr_threshold: float = 0.5,
    ):
        """
        Initialize DETR and embedding models.

        Args:
            logger: Logger instance for logging
            detr_model: HuggingFace model name or local path for DETR object detection
            embedding_model_type: One of "clip", "dinov2", or "siglip"
            detr_threshold: Confidence threshold for DETR detections (0-1)

        Raises:
            ValueError: If ``embedding_model_type`` is not a supported value.
        """
        # Fail fast: reject a bad model type before the (slow) DETR load.
        if embedding_model_type not in self._SUPPORTED_EMBEDDING_MODELS:
            raise ValueError(
                f"Unknown embedding model type: {embedding_model_type}. "
                f"Use 'clip', 'dinov2', or 'siglip'"
            )

        self.logger = logger
        self.detr_threshold = detr_threshold
        self.embedding_model_type = embedding_model_type

        # Set device. The pipeline API takes an integer index (-1 means CPU),
        # while the embedding models take a torch.device.
        self.device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device_index = 0 if torch.cuda.is_available() else -1
        self.device = torch.device(self.device_str)

        self.logger.info(
            f"Initializing DetectLogosEmbeddings on device: {self.device_str}, "
            f"embedding model: {embedding_model_type}"
        )

        # --- DETR model ---
        default_detr_dir = os.environ.get(
            "LOGO_DETR_MODEL_DIR", "models/logo_detection/detr"
        )
        detr_model_path = self._resolve_model_path(detr_model, default_detr_dir, "DETR")

        self.logger.info(f"Loading DETR model: {detr_model_path}")
        self.detr_pipe = pipeline(
            task="object-detection",
            model=detr_model_path,
            device=self.device_index,
            use_fast=True,
        )

        # --- Embedding model ---
        self._load_embedding_model(embedding_model_type)

        self.logger.info("DetectLogosEmbeddings initialization complete")

    def _load_embedding_model(self, model_type: str) -> None:
        """
        Load the selected embedding model.

        Sets ``self._embed_fn`` to a callable that maps a PIL image to an
        L2-normalized feature tensor.

        Args:
            model_type: One of "clip", "dinov2", or "siglip"

        Raises:
            ValueError: If ``model_type`` is not a supported value.
        """
        default_embedding_dir = os.environ.get(
            "LOGO_EMBEDDING_MODEL_DIR", f"models/logo_detection/{model_type}"
        )

        if model_type == "clip":
            model_name = "openai/clip-vit-large-patch14"
            model_path = self._resolve_model_path(
                model_name, default_embedding_dir, "CLIP"
            )
            self.logger.info(f"Loading CLIP model: {model_path}")
            self._clip_model = CLIPModel.from_pretrained(model_path).to(self.device)
            self._clip_processor = CLIPProcessor.from_pretrained(model_path)
            self._clip_model.eval()

            def embed_fn(pil_image):
                inputs = self._clip_processor(
                    images=pil_image, return_tensors="pt"
                ).to(self.device)
                with torch.no_grad():
                    features = self._clip_model.get_image_features(**inputs)
                return F.normalize(features, dim=-1)

        elif model_type == "dinov2":
            model_name = "facebook/dinov2-base"
            model_path = self._resolve_model_path(
                model_name, default_embedding_dir, "DINOv2"
            )
            self.logger.info(f"Loading DINOv2 model: {model_path}")
            self._dinov2_model = Dinov2Model.from_pretrained(model_path).to(self.device)
            self._dinov2_processor = AutoImageProcessor.from_pretrained(model_path)
            self._dinov2_model.eval()

            def embed_fn(pil_image):
                inputs = self._dinov2_processor(
                    images=pil_image, return_tensors="pt"
                ).to(self.device)
                with torch.no_grad():
                    outputs = self._dinov2_model(**inputs)
                # Use CLS token embedding
                features = outputs.last_hidden_state[:, 0, :]
                return F.normalize(features, dim=-1)

        elif model_type == "siglip":
            model_name = "google/siglip-base-patch16-224"
            model_path = self._resolve_model_path(
                model_name, default_embedding_dir, "SigLIP"
            )
            self.logger.info(f"Loading SigLIP model: {model_path}")
            self._siglip_model = AutoModel.from_pretrained(model_path).to(self.device)
            self._siglip_processor = AutoProcessor.from_pretrained(model_path)
            self._siglip_model.eval()

            def embed_fn(pil_image):
                inputs = self._siglip_processor(
                    images=pil_image, return_tensors="pt"
                ).to(self.device)
                with torch.no_grad():
                    features = self._siglip_model.get_image_features(**inputs)
                return F.normalize(features, dim=-1)

        else:
            raise ValueError(
                f"Unknown embedding model type: {model_type}. "
                f"Use 'clip', 'dinov2', or 'siglip'"
            )

        self._embed_fn = embed_fn

    def _resolve_model_path(
        self, model_name_or_path: str, default_local_dir: str, model_type: str
    ) -> str:
        """
        Resolve model path, checking for local models before using HuggingFace.

        Args:
            model_name_or_path: HuggingFace model name or absolute path
            default_local_dir: Default local directory to check
            model_type: Type of model (for logging)

        Returns:
            Resolved model path (local path or HuggingFace model name)
        """
        # If it's an absolute path, use it directly
        if os.path.isabs(model_name_or_path):
            if os.path.exists(model_name_or_path):
                self.logger.info(
                    f"{model_type} model: Using local model at {model_name_or_path}"
                )
                return model_name_or_path
            # NOTE(review): the warning announces a HuggingFace fallback, but
            # the same (missing) path is returned unchanged because no hub
            # name is known here; the loader will receive that path as-is.
            self.logger.warning(
                f"{model_type} model: Local path {model_name_or_path} does not exist, "
                f"falling back to HuggingFace"
            )
            return model_name_or_path

        # Check if default local directory exists; it only counts as a model
        # directory when it contains a config.json.
        if os.path.exists(default_local_dir):
            config_file = os.path.join(default_local_dir, "config.json")
            if os.path.exists(config_file):
                abs_path = os.path.abspath(default_local_dir)
                self.logger.info(
                    f"{model_type} model: Found local model at {abs_path}"
                )
                return abs_path
            self.logger.warning(
                f"{model_type} model: Local directory {default_local_dir} exists but "
                f"is not a valid model (missing config.json)"
            )

        # Use HuggingFace model name
        self.logger.info(
            f"{model_type} model: No local model found, will download from HuggingFace: "
            f"{model_name_or_path}"
        )
        return model_name_or_path

    def detect(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Detect logos in an image and return bounding boxes with embeddings.

        Args:
            image: OpenCV image (BGR format, numpy array)

        Returns:
            List of dictionaries, each containing:
            - 'box': dict with 'xmin', 'ymin', 'xmax', 'ymax' (pixel coordinates)
            - 'score': DETR confidence score (float 0-1)
            - 'embedding': Feature embedding (torch.Tensor)
            - 'label': DETR predicted label (string)
        """
        # Convert OpenCV BGR to RGB PIL Image
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(image_rgb)

        # Run DETR detection. The threshold is forwarded explicitly because
        # the object-detection pipeline applies its own score cut-off
        # (documented default 0.5); without this, a detr_threshold below that
        # default would never take effect.
        predictions = self.detr_pipe(pil_image, threshold=self.detr_threshold)

        # Filter by threshold and add embeddings
        detections = []
        for pred in predictions:
            score = pred.get("score", 0.0)
            if score < self.detr_threshold:
                continue

            box = pred.get("box", {})
            xmin = box.get("xmin", 0)
            ymin = box.get("ymin", 0)
            xmax = box.get("xmax", 0)
            ymax = box.get("ymax", 0)

            # A box with no area yields an empty crop that cannot be embedded.
            if xmax <= xmin or ymax <= ymin:
                self.logger.debug(
                    f"Skipping degenerate box ({xmin}, {ymin}, {xmax}, {ymax})"
                )
                continue

            # Extract bounding box region
            bbox_crop = pil_image.crop((xmin, ymin, xmax, ymax))

            # Get embedding for this region
            embedding = self._embed_fn(bbox_crop)

            detections.append(
                {
                    "box": {"xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax},
                    "score": score,
                    "embedding": embedding,
                    "label": pred.get("label", "logo"),
                }
            )

        self.logger.debug(
            f"Detected {len(detections)} logos (threshold: {self.detr_threshold})"
        )
        return detections

    def get_embedding(self, image: np.ndarray) -> torch.Tensor:
        """
        Get embedding for a single reference logo image.

        Args:
            image: OpenCV image (BGR format, numpy array)

        Returns:
            Normalized feature embedding (torch.Tensor)
        """
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(image_rgb)
        return self._embed_fn(pil_image)

    def get_averaged_embedding(self, images: List[np.ndarray]) -> Optional[torch.Tensor]:
        """
        Compute averaged embedding from multiple reference logo images.

        Follows the averaging pattern from db_embeddings.py:
        1. Compute embedding for each image
        2. Stack and average across all images
        3. Re-normalize the averaged embedding

        Images whose embedding cannot be computed are skipped with a warning.

        Args:
            images: List of OpenCV images (BGR format, numpy arrays)

        Returns:
            Normalized averaged embedding (torch.Tensor, shape [1, D]),
            or None if no valid embeddings could be computed
        """
        embeddings = []
        for img in images:
            try:
                embeddings.append(self.get_embedding(img))
            except Exception as e:
                # Deliberately best-effort: one bad reference image must not
                # abort the whole average.
                self.logger.warning(f"Failed to compute embedding for reference image: {e}")

        if not embeddings:
            return None

        # Stack: (N, D), average: (1, D), re-normalize
        stacked = torch.cat(embeddings, dim=0)
        avg_emb = stacked.mean(dim=0, keepdim=True)
        avg_emb = F.normalize(avg_emb, dim=-1)

        self.logger.debug(
            f"Computed averaged embedding from {len(embeddings)} reference image(s)"
        )
        return avg_emb

    def compare_embeddings(
        self, embedding1: torch.Tensor, embedding2: torch.Tensor
    ) -> float:
        """
        Compute cosine similarity between two embeddings.

        Args:
            embedding1: First embedding (torch.Tensor)
            embedding2: Second embedding (torch.Tensor)

        Returns:
            Cosine similarity score (float, range: -1 to 1, typically 0 to 1)
        """
        # Ensure tensors are on the same device (e.g. a cached CPU reference
        # compared with a freshly computed GPU embedding).
        if embedding1.device != embedding2.device:
            embedding2 = embedding2.to(embedding1.device)

        similarity = F.cosine_similarity(embedding1, embedding2, dim=-1)
        return similarity.item()

    @staticmethod
    def make_filenames_hash(filenames: List[str]) -> str:
        """
        Compute a deterministic hash of a filenames list.

        Used for cache invalidation — if the filenames list changes,
        the hash changes, triggering re-computation of averaged embeddings.
        The list is sorted first, so ordering does not affect the hash.

        Args:
            filenames: List of filename strings

        Returns:
            16-character hex hash string
        """
        canonical = json.dumps(sorted(filenames))
        return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16]
|
||||||
521
test_burnley_detection.py
Normal file
521
test_burnley_detection.py
Normal file
@ -0,0 +1,521 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test script for logo detection accuracy on Burnley test images.
|
||||||
|
|
||||||
|
Uses DetectLogosEmbeddings from logo_detection_embeddings.py to detect
|
||||||
|
barnfield and vertu logos. Ground truth is determined by filename prefix:
|
||||||
|
- "vertu_" → contains vertu logo
|
||||||
|
- "barnfield_" → contains barnfield logo
|
||||||
|
- "barnfield+vertu_" → contains both logos
|
||||||
|
- anything else → no target logos
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import pickle
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from logo_detection_embeddings import DetectLogosEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(verbose: bool = False) -> logging.Logger:
    """Configure root logging and return this module's logger.

    Args:
        verbose: When true the level is DEBUG, otherwise INFO.

    Returns:
        The logger named after this module.
    """
    if verbose:
        chosen_level = logging.DEBUG
    else:
        chosen_level = logging.INFO

    settings = {
        "level": chosen_level,
        "format": "%(asctime)s - %(levelname)s - %(message)s",
        "datefmt": "%H:%M:%S",
    }
    logging.basicConfig(**settings)
    return logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(image_path: Path) -> Optional[cv2.Mat]:
    """Read an image from disk with OpenCV.

    Args:
        image_path: Location of the image file.

    Returns:
        Whatever ``cv2.imread`` produced for the path, which is ``None``
        when the file could not be read.
    """
    # cv2.imread wants a plain string path and signals failure with None,
    # so its result can be handed back directly.
    return cv2.imread(str(image_path))
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingCache:
    """Simple file-based cache for embeddings.

    Entries live in an in-memory dict that is pickled to ``cache_path`` by
    :meth:`save`. Tensors are moved to CPU when stored, including tensors
    nested inside lists, tuples and dicts (such as detection results).

    SECURITY NOTE: the cache is read with ``pickle.load``, which can execute
    arbitrary code. Only point this at cache files you created yourself.
    """

    def __init__(self, cache_path: Path):
        """Create the cache and load any existing file at *cache_path*."""
        self.cache_path = cache_path
        self.cache: Dict[str, Any] = {}
        self._load()

    def _load(self):
        """Populate ``self.cache`` from disk; start empty if that fails."""
        if not self.cache_path.exists():
            return
        try:
            with open(self.cache_path, "rb") as f:
                loaded = pickle.load(f)
        except Exception as exc:
            # Best-effort: a corrupt or incompatible cache is discarded, but
            # say so instead of silently recomputing everything.
            logging.getLogger(__name__).warning(
                "Ignoring unreadable embedding cache %s: %s", self.cache_path, exc
            )
            loaded = None
        self.cache = loaded if isinstance(loaded, dict) else {}

    def save(self):
        """Write the cache to disk, creating parent directories as needed."""
        self.cache_path.parent.mkdir(parents=True, exist_ok=True)
        # Write to a sibling temp file and swap it in, so an interrupted run
        # cannot leave a truncated cache behind.
        tmp_path = self.cache_path.with_suffix(self.cache_path.suffix + ".tmp")
        with open(tmp_path, "wb") as f:
            pickle.dump(self.cache, f)
        tmp_path.replace(self.cache_path)

    def get(self, key: str):
        """Return the cached value for *key*, or ``None`` when absent."""
        return self.cache.get(key)

    @classmethod
    def _to_cpu(cls, value):
        """Return *value* with every contained tensor moved to CPU.

        Containers are rebuilt rather than mutated, so the caller's objects
        (and their device placement) are left untouched.
        """
        if isinstance(value, torch.Tensor):
            return value.cpu()
        if isinstance(value, dict):
            return {k: cls._to_cpu(v) for k, v in value.items()}
        if isinstance(value, list):
            return [cls._to_cpu(v) for v in value]
        if isinstance(value, tuple):
            return tuple(cls._to_cpu(v) for v in value)
        return value

    def put(self, key: str, value):
        """Store *value* under *key* with all tensors placed on CPU."""
        self.cache[key] = self._to_cpu(value)

    def __len__(self):
        """Number of cached entries."""
        return len(self.cache)

    def __bool__(self):
        """A cache object is always truthy, even when it holds no entries.

        Without this, ``__len__`` would make an empty cache falsy and
        ``if cache:`` checks would skip writing to a fresh cache.
        """
        return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_expected_logos(filename: str) -> Set[str]:
    """Map a test image filename to the set of logos it should contain.

    The comparison is case-insensitive and based purely on the filename
    prefix; names matching no known prefix yield an empty set.
    """
    lowered = filename.lower()

    # Checked in order: the combined prefix must come before "barnfield_".
    prefix_table = (
        ("barnfield+vertu_", ("barnfield", "vertu")),
        ("barnfield_", ("barnfield",)),
        ("vertu_", ("vertu",)),
    )
    for prefix, logos in prefix_table:
        if lowered.startswith(prefix):
            return set(logos)
    return set()
|
||||||
|
|
||||||
|
|
||||||
|
def load_reference_images(ref_dir: Path, logger: logging.Logger) -> List[cv2.Mat]:
    """Load every readable image found directly inside *ref_dir*.

    Entries are visited in sorted order and only those with a .jpg, .jpeg,
    .png or .bmp suffix (case-insensitive) are considered. Files that fail
    to load are reported through *logger* and left out of the result.
    """
    accepted_suffixes = (".jpg", ".jpeg", ".png", ".bmp")
    candidates = [
        entry
        for entry in sorted(ref_dir.iterdir())
        if entry.suffix.lower() in accepted_suffixes
    ]

    loaded = []
    for candidate in candidates:
        picture = load_image(candidate)
        if picture is None:
            logger.warning(f"Failed to load reference image: {candidate}")
            continue
        loaded.append(picture)
    return loaded
|
||||||
|
|
||||||
|
|
||||||
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the Burnley detection test."""
    parser = argparse.ArgumentParser(
        description="Test logo detection on Burnley test images using DetectLogosEmbeddings"
    )
    parser.add_argument(
        "-t", "--threshold",
        type=float,
        default=0.7,
        help="Similarity threshold for matching (default: 0.7)",
    )
    parser.add_argument(
        "-d", "--detr-threshold",
        type=float,
        default=0.5,
        help="DETR detection confidence threshold (default: 0.5)",
    )
    parser.add_argument(
        "-e", "--embedding-model",
        type=str,
        choices=["clip", "dinov2", "siglip"],
        default="dinov2",
        help="Embedding model type (default: dinov2)",
    )
    parser.add_argument(
        "--margin",
        type=float,
        default=0.05,
        help="Required margin between best and second-best match (default: 0.05)",
    )
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Enable verbose logging",
    )
    parser.add_argument(
        "--similarity-details",
        action="store_true",
        help="Output detailed similarity scores for each detection",
    )
    parser.add_argument(
        "--no-cache",
        action="store_true",
        help="Disable embedding cache",
    )
    parser.add_argument(
        "--clear-cache",
        action="store_true",
        help="Clear embedding cache before running",
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default=None,
        help="Append results summary to this file",
    )
    return parser


def _list_image_names(directory: Path) -> List[str]:
    """Return the sorted names of image files directly inside *directory*."""
    return sorted(
        entry.name
        for entry in directory.iterdir()
        if entry.suffix.lower() in (".jpg", ".jpeg", ".png", ".bmp")
    )


def _match_image_logos(detector, detections, reference_embeddings, threshold, margin):
    """Match one image's detections against the reference embeddings.

    A detection matches its most similar reference logo when that similarity
    reaches *threshold* and beats the runner-up by at least *margin*.

    Returns:
        A pair ``(matched, ranked_per_detection)``: ``matched`` maps each
        matched logo name to the highest similarity that matched it, and
        ``ranked_per_detection`` holds, per detection, the (logo, similarity)
        pairs sorted by descending similarity.
    """
    matched: Dict[str, float] = {}
    ranked_per_detection: List[List[Tuple[str, float]]] = []

    for detection in detections:
        sims = {
            logo_name: detector.compare_embeddings(detection["embedding"], ref_emb)
            for logo_name, ref_emb in reference_embeddings.items()
        }
        ranked = sorted(sims.items(), key=lambda item: -item[1])
        ranked_per_detection.append(ranked)

        if not ranked:
            continue
        best_name, best_sim = ranked[0]
        if best_sim < threshold:
            continue
        # Check margin over second best
        if len(ranked) > 1 and best_sim - ranked[1][1] < margin:
            continue

        if best_sim > matched.get(best_name, float("-inf")):
            matched[best_name] = best_sim

    return matched, ranked_per_detection


def main():
    """Run the Burnley logo detection accuracy test from the command line.

    Metrics are computed at image level, matching the image-level ground
    truth derived from filenames: each (image, logo) pair contributes at most
    one true positive, false positive or false negative, so recall cannot
    exceed 1.0 when one logo is detected several times in an image.
    """
    args = _build_arg_parser().parse_args()
    logger = setup_logging(args.verbose)

    # Paths
    base_dir = Path(__file__).resolve().parent
    test_images_dir = base_dir / "burnley_test_images"
    barnfield_ref_dir = base_dir / "barnfield_reference_images"
    vertu_ref_dir = base_dir / "vertu_reference_images"
    cache_path = base_dir / ".burnley_embedding_cache.pkl"

    # Verify directories exist
    required_dirs = [
        (test_images_dir, "Test images"),
        (barnfield_ref_dir, "Barnfield refs"),
        (vertu_ref_dir, "Vertu refs"),
    ]
    for d, name in required_dirs:
        if not d.exists():
            logger.error(f"{name} directory not found: {d}")
            sys.exit(1)

    # Handle cache
    if args.clear_cache and cache_path.exists():
        cache_path.unlink()
        logger.info("Cleared embedding cache")

    cache = None if args.no_cache else EmbeddingCache(cache_path)
    # Always compare against None: EmbeddingCache defines __len__, so an
    # empty cache is falsy and a plain `if cache:` would skip every put()
    # and save() on a first run.
    use_cache = cache is not None
    if use_cache:
        logger.info(f"Loaded {len(cache)} cached embeddings")

    # Initialize detector
    logger.info(f"Initializing detector with embedding model: {args.embedding_model}")
    detector = DetectLogosEmbeddings(
        logger=logger,
        detr_threshold=args.detr_threshold,
        embedding_model_type=args.embedding_model,
    )

    # Compute averaged reference embeddings
    logger.info("Computing reference embeddings...")

    reference_embeddings: Dict[str, torch.Tensor] = {}
    for logo_name, ref_dir in [("barnfield", barnfield_ref_dir), ("vertu", vertu_ref_dir)]:
        # Key on the reference file list so that adding or removing
        # reference images invalidates the cached average.
        files_hash = DetectLogosEmbeddings.make_filenames_hash(_list_image_names(ref_dir))
        cache_key = f"avg_ref:{logo_name}:{args.embedding_model}:{files_hash}"
        cached = cache.get(cache_key) if use_cache else None

        if cached is not None:
            reference_embeddings[logo_name] = cached
            logger.info(f"Loaded cached averaged embedding for {logo_name}")
            continue

        ref_images = load_reference_images(ref_dir, logger)
        logger.info(f"Computing averaged embedding for {logo_name} from {len(ref_images)} images")
        avg_emb = detector.get_averaged_embedding(ref_images)
        if avg_emb is None:
            logger.error(f"Failed to compute embedding for {logo_name}")
            sys.exit(1)
        reference_embeddings[logo_name] = avg_emb
        if use_cache:
            cache.put(cache_key, avg_emb)

    # Collect test images
    test_files = _list_image_names(test_images_dir)
    logger.info(f"Found {len(test_files)} test images")

    # Metrics
    true_positives = 0
    false_positives = 0
    false_negatives = 0
    total_expected = 0
    processed_images = 0
    results = []

    similarity_details = {
        "true_positive_sims": [],
        "false_positive_sims": [],
        "missed_best_sims": [],
        "detection_details": [],
    }

    # Process test images
    for test_filename in tqdm(test_files, desc="Testing"):
        test_path = test_images_dir / test_filename
        expected_logos = get_expected_logos(test_filename)

        # Check cache for detections. The DETR threshold is part of the key
        # because it decides which detections exist at all.
        det_cache_key = f"det:{test_filename}:{args.embedding_model}:{args.detr_threshold}"
        detections = cache.get(det_cache_key) if use_cache else None

        if detections is None:
            test_img = load_image(test_path)
            if test_img is None:
                logger.warning(f"Failed to load test image: {test_path}")
                continue
            detections = detector.detect(test_img)
            if use_cache:
                cache.put(det_cache_key, detections)

        # Only images that were actually evaluated count towards the totals.
        processed_images += 1
        total_expected += len(expected_logos)

        # Match each detection against reference embeddings with margin
        matched, ranked_per_detection = _match_image_logos(
            detector, detections, reference_embeddings, args.threshold, args.margin
        )

        if args.similarity_details:
            for det_idx, (detection, ranked) in enumerate(zip(detections, ranked_per_detection)):
                similarity_details["detection_details"].append({
                    "image": test_filename,
                    "detection_idx": det_idx,
                    "expected_logos": list(expected_logos),
                    "similarities": ranked,
                    "detr_score": detection.get("score", 0),
                })

        for best_name, best_sim in matched.items():
            is_correct = best_name in expected_logos
            if is_correct:
                true_positives += 1
                bucket = "true_positive_sims"
            else:
                false_positives += 1
                bucket = "false_positive_sims"
            if args.similarity_details:
                similarity_details[bucket].append(best_sim)

            results.append({
                "test_image": test_filename,
                "matched_logo": best_name,
                "similarity": best_sim,
                "correct": is_correct,
            })

        # Count missed detections
        missed = expected_logos - set(matched)
        false_negatives += len(missed)

        for missed_logo in sorted(missed):
            if args.similarity_details and detections:
                best_sim_for_missed = 0
                ref_emb = reference_embeddings[missed_logo]
                for detection in detections:
                    sim = detector.compare_embeddings(detection["embedding"], ref_emb)
                    best_sim_for_missed = max(best_sim_for_missed, sim)
                similarity_details["missed_best_sims"].append(best_sim_for_missed)

            results.append({
                "test_image": test_filename,
                "matched_logo": None,
                "expected_logo": missed_logo,
                "similarity": None,
                "correct": False,
            })

    # Save cache
    if use_cache:
        cache.save()
        logger.info(f"Saved {len(cache)} embeddings to cache")

    # Calculate metrics
    predicted = true_positives + false_positives
    precision = true_positives / predicted if predicted > 0 else 0
    recall = true_positives / total_expected if total_expected > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    # Print results
    print("\n" + "=" * 60)
    print("BURNLEY LOGO DETECTION TEST RESULTS")
    print("=" * 60)
    print("\nConfiguration:")
    print(f" Embedding model: {args.embedding_model}")
    print(f" Similarity threshold: {args.threshold}")
    print(f" DETR confidence threshold: {args.detr_threshold}")
    print(f" Matching margin: {args.margin}")
    print(f" Test images processed: {processed_images}")
    print(" Reference logos: barnfield, vertu")

    print("\nMetrics:")
    print(f" True Positives (correct matches): {true_positives}")
    print(f" False Positives (wrong matches): {false_positives}")
    print(f" False Negatives (missed logos): {false_negatives}")
    print(f" Total expected matches: {total_expected}")

    print("\nScores:")
    print(f" Precision: {precision:.4f} ({precision*100:.1f}%)")
    print(f" Recall: {recall:.4f} ({recall*100:.1f}%)")
    print(f" F1 Score: {f1:.4f} ({f1*100:.1f}%)")

    # Show false positive examples
    false_positive_examples = [r for r in results if r.get("matched_logo") and not r["correct"]]
    if false_positive_examples:
        print("\nExample False Positives (first 5):")
        for r in false_positive_examples[:5]:
            print(f" - Image: {r['test_image']}")
            print(f" Matched: {r['matched_logo']} (similarity: {r['similarity']:.3f})")

    # Show false negative examples
    false_negative_examples = [r for r in results if r.get("expected_logo")]
    if false_negative_examples:
        print("\nExample False Negatives (first 5):")
        for r in false_negative_examples[:5]:
            print(f" - Image: {r['test_image']}")
            print(f" Expected: {r['expected_logo']}")

    print("=" * 60)

    # Print similarity details if requested
    if args.similarity_details:
        print_similarity_details(similarity_details, args.threshold)

    # Write results to file if requested
    if args.output_file:
        write_results_to_file(
            output_path=Path(args.output_file),
            args=args,
            num_test_images=processed_images,
            true_positives=true_positives,
            false_positives=false_positives,
            false_negatives=false_negatives,
            total_expected=total_expected,
            precision=precision,
            recall=recall,
            f1=f1,
        )
        print(f"\nResults appended to: {args.output_file}")
|
||||||
|
|
||||||
|
|
||||||
|
def print_similarity_details(details: dict, threshold: float):
    """Print a similarity distribution report to stdout.

    Args:
        details: Dict with the lists "true_positive_sims",
            "false_positive_sims", "missed_best_sims" and
            "detection_details".
        threshold: Similarity threshold used to split each list into
            above/below counts.
    """
    import statistics

    banner = "=" * 60
    rule = "-" * 40

    def summarize(label, samples):
        # Descriptive statistics for one group of similarity scores.
        if not samples:
            print(f"\n{label}: No data")
            return

        count = len(samples)
        print(f"\n{label} (n={count}):")
        print(f" Min: {min(samples):.4f}")
        print(f" Max: {max(samples):.4f}")
        print(f" Mean: {statistics.mean(samples):.4f}")
        if count > 1:
            # The sample standard deviation needs at least two points.
            print(f" StdDev: {statistics.stdev(samples):.4f}")
        print(f" Median: {statistics.median(samples):.4f}")

        n_above = sum(1 for score in samples if score >= threshold)
        n_below = sum(1 for score in samples if score < threshold)
        print(f" Above threshold ({threshold}): {n_above} ({100*n_above/count:.1f}%)")
        print(f" Below threshold ({threshold}): {n_below} ({100*n_below/count:.1f}%)")

    print("\n" + banner)
    print("SIMILARITY DISTRIBUTION ANALYSIS")
    print(banner)

    hits = details["true_positive_sims"]
    wrong = details["false_positive_sims"]

    summarize(hits, "TRUE POSITIVE similarities") if False else summarize("TRUE POSITIVE similarities", hits)
    summarize("FALSE POSITIVE similarities", wrong)
    summarize("MISSED LOGO best similarities", details["missed_best_sims"])

    # Overlap analysis: do the correct and wrong score ranges intersect?
    if hits and wrong:
        print("\n" + rule)
        print("OVERLAP ANALYSIS:")
        hit_lo, hit_hi = min(hits), max(hits)
        wrong_lo, wrong_hi = min(wrong), max(wrong)
        print(f" True Positives range: [{hit_lo:.4f}, {hit_hi:.4f}]")
        print(f" False Positives range: [{wrong_lo:.4f}, {wrong_hi:.4f}]")

        shared_lo = max(hit_lo, wrong_lo)
        shared_hi = min(hit_hi, wrong_hi)
        if shared_lo < shared_hi:
            print(f" OVERLAP REGION: [{shared_lo:.4f}, {shared_hi:.4f}]")
        else:
            print(" NO OVERLAP - distributions are separable!")

    # Sample detection details
    per_detection = details["detection_details"]
    if per_detection:
        print("\n" + rule)
        print(f"SAMPLE DETECTION DETAILS (first 20 of {len(per_detection)}):")
        for position, entry in enumerate(per_detection[:20], start=1):
            wanted = entry["expected_logos"]
            print(f"\n [{position}] Image: {entry['image']}")
            print(f" Expected: {wanted or '(none)'}")
            print(f" DETR score: {entry['detr_score']:.3f}")
            print(" Similarities:")
            for logo, score in entry["similarities"]:
                suffix = " <-- CORRECT" if logo in wanted else ""
                print(f" {score:.4f} {logo}{suffix}")

    print("\n" + banner)
|
||||||
|
|
||||||
|
|
||||||
|
def write_results_to_file(
    output_path: Path,
    args,
    num_test_images: int,
    true_positives: int,
    false_positives: int,
    false_negatives: int,
    total_expected: int,
    precision: float,
    recall: float,
    f1: float,
):
    """Append a results summary to *output_path*.

    The file is opened in append mode with UTF-8 encoding, and missing
    parent directories are created first so that a path such as
    ``results/run.txt`` works on a fresh checkout.

    Args:
        output_path: File to append the summary to.
        args: Parsed command-line namespace; reads ``embedding_model``,
            ``margin``, ``threshold`` and ``detr_threshold``.
        num_test_images: Number of test images processed.
        true_positives: Count of correct matches.
        false_positives: Count of wrong matches.
        false_negatives: Count of missed logos.
        total_expected: Total number of expected logo matches.
        precision: Precision in the range 0-1.
        recall: Recall in the range 0-1.
        f1: F1 score in the range 0-1.
    """
    from datetime import datetime

    lines = [
        "=" * 70,
        "BURNLEY LOGO DETECTION TEST",
        f"Model: {args.embedding_model}",
        f"Method: Margin-based (margin={args.margin})",
        "=" * 70,
        f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        "",
        "Configuration:",
        f" Embedding model: {args.embedding_model}",
        f" Similarity threshold: {args.threshold}",
        f" DETR threshold: {args.detr_threshold}",
        f" Matching margin: {args.margin}",
        f" Test images processed: {num_test_images}",
        " Reference logos: barnfield, vertu",
        "",
        "Results:",
        f" True Positives: {true_positives:>6}",
        f" False Positives: {false_positives:>6}",
        f" False Negatives: {false_negatives:>6}",
        f" Total Expected: {total_expected:>6}",
        "",
        "Scores:",
        f" Precision: {precision:.4f} ({precision*100:.1f}%)",
        f" Recall: {recall:.4f} ({recall*100:.1f}%)",
        f" F1 Score: {f1:.4f} ({f1*100:.1f}%)",
        # The two empty entries end the record with a blank line, which
        # separates it from the next appended run.
        "",
        "",
    ]

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # Explicit encoding keeps the file identical across platforms.
    with open(output_path, "a", encoding="utf-8") as f:
        f.write("\n".join(lines))
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the test only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
||||||
Reference in New Issue
Block a user