Compare commits
5 Commits
91d1c9cd59
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| f2ae80c9e5 | |||
| 8b67b50d19 | |||
| 5ce6265a90 | |||
| 512f678310 | |||
| f598866d37 |
27
README.md
27
README.md
@ -2,6 +2,33 @@
|
|||||||
|
|
||||||
A testing framework for evaluating logo detection accuracy using DETR (DEtection TRansformer) and CLIP (Contrastive Language-Image Pre-training) models.
|
A testing framework for evaluating logo detection accuracy using DETR (DEtection TRansformer) and CLIP (Contrastive Language-Image Pre-training) models.
|
||||||
|
|
||||||
|
## Burnley Test: Averaged Embeddings with DINOv2
|
||||||
|
|
||||||
|
A targeted test using `DetectLogosEmbeddings` to detect two specific logos (barnfield and vertu) in 516 Burnley match images. Reference embeddings are averaged across all images in each reference directory, and matching uses margin-based comparison (margin=0.05).
|
||||||
|
|
||||||
|
**Test command:**
|
||||||
|
```bash
|
||||||
|
uv run python test_burnley_detection.py -e dinov2 -t 0.7 --margin 0.05 --output-file results_average_embeddings.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
**Results (DINOv2, threshold 0.70, margin 0.05):**
|
||||||
|
|
||||||
|
| Metric | Value |
|
||||||
|
|--------|-------|
|
||||||
|
| True Positives | 28 |
|
||||||
|
| False Positives | 36 |
|
||||||
|
| False Negatives | 125 |
|
||||||
|
| Total Expected | 146 |
|
||||||
|
| **Precision** | **43.8%** |
|
||||||
|
| **Recall** | **19.2%** |
|
||||||
|
| **F1 Score** | **26.7%** |
|
||||||
|
|
||||||
|
Ground truth is derived from filename prefixes: `vertu_` (vertu logo), `barnfield_` (barnfield logo), `barnfield+vertu_` (both logos). Images without these prefixes are treated as negatives.
|
||||||
|
|
||||||
|
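Low recall suggests many logos go undetected by DETR or fall below the similarity threshold. The relatively low precision indicates DINOv2 averaged embeddings struggle to discriminate between the two logos in this domain. Note that true positives are counted once per matching detection, whereas false negatives are counted once per expected logo that was not matched in an image, so True Positives + False Negatives need not equal Total Expected; recall is computed as True Positives / Total Expected. Further tuning of thresholds, margin, and embedding model (e.g. CLIP or SigLIP) may improve results.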
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Recommended Settings
|
## Recommended Settings
|
||||||
|
|
||||||
Based on extensive testing with the LogoDet-3K dataset, these are the optimal settings:
|
Based on extensive testing with the LogoDet-3K dataset, these are the optimal settings:
|
||||||
|
|||||||
364
logo_detection_embeddings.py
Normal file
364
logo_detection_embeddings.py
Normal file
@ -0,0 +1,364 @@
|
|||||||
|
"""
|
||||||
|
Logo detection using DETR for object detection and selectable embedding models for feature matching.
|
||||||
|
|
||||||
|
This module provides a class for detecting logos in images using:
|
||||||
|
1. DETR (DEtection TRansformer) for initial logo region detection
|
||||||
|
2. Selectable embedding model (CLIP, DINOv2, or SigLIP) for feature extraction and matching
|
||||||
|
|
||||||
|
Key features:
|
||||||
|
- Multiple reference images per logo entry, averaged into a single embedding
|
||||||
|
- Cache-aware: averaged embeddings are only recalculated when the filenames list changes
|
||||||
|
- Supports local model directories with fallback to HuggingFace
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import (
|
||||||
|
AutoImageProcessor,
|
||||||
|
AutoModel,
|
||||||
|
AutoProcessor,
|
||||||
|
CLIPModel,
|
||||||
|
CLIPProcessor,
|
||||||
|
Dinov2Model,
|
||||||
|
pipeline,
|
||||||
|
)
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class DetectLogosEmbeddings:
    """
    Logo detection class using DETR and a selectable embedding model.

    This class detects logos in images by:
    1. Using DETR to find potential logo regions (bounding boxes)
    2. Extracting embeddings for each detected region using the selected model
    3. Comparing embeddings with averaged reference logo embeddings for identification

    Supported embedding models:
    - clip: openai/clip-vit-large-patch14
    - dinov2: facebook/dinov2-base (recommended for visual similarity)
    - siglip: google/siglip-base-patch16-224
    """

    # Values accepted for ``embedding_model_type``.
    SUPPORTED_EMBEDDING_MODELS = ("clip", "dinov2", "siglip")

    def __init__(
        self,
        logger,
        detr_model: str = "Pravallika6/detr-finetuned-logo-detection_v2",
        embedding_model_type: str = "dinov2",
        detr_threshold: float = 0.5,
    ):
        """
        Initialize DETR and embedding models.

        Args:
            logger: Logger instance for logging
            detr_model: HuggingFace model name or local path for DETR object detection
            embedding_model_type: One of "clip", "dinov2", or "siglip"
            detr_threshold: Confidence threshold for DETR detections (0-1)

        Raises:
            ValueError: If embedding_model_type is not a supported model type
        """
        # Reject an unsupported embedding type before the (slow) DETR load.
        if embedding_model_type not in self.SUPPORTED_EMBEDDING_MODELS:
            raise ValueError(
                f"Unknown embedding model type: {embedding_model_type}. "
                f"Use 'clip', 'dinov2', or 'siglip'"
            )

        self.logger = logger
        self.detr_threshold = detr_threshold
        self.embedding_model_type = embedding_model_type

        # Set device
        self.device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device_index = 0 if torch.cuda.is_available() else -1
        self.device = torch.device(self.device_str)

        self.logger.info(
            f"Initializing DetectLogosEmbeddings on device: {self.device_str}, "
            f"embedding model: {embedding_model_type}"
        )

        # --- DETR model ---
        default_detr_dir = os.environ.get(
            "LOGO_DETR_MODEL_DIR", "models/logo_detection/detr"
        )
        detr_model_path = self._resolve_model_path(detr_model, default_detr_dir, "DETR")

        self.logger.info(f"Loading DETR model: {detr_model_path}")
        self.detr_pipe = pipeline(
            task="object-detection",
            model=detr_model_path,
            device=self.device_index,
            use_fast=True,
        )

        # --- Embedding model ---
        self._load_embedding_model(embedding_model_type)

        self.logger.info("DetectLogosEmbeddings initialization complete")

    def _load_embedding_model(self, model_type: str) -> None:
        """
        Load the selected embedding model and install ``self._embed_fn``.

        ``self._embed_fn`` takes a PIL image and returns the L2-normalized
        feature tensor produced by the selected model.

        Args:
            model_type: One of "clip", "dinov2", or "siglip"

        Raises:
            ValueError: If model_type is not a supported model type
        """
        default_embedding_dir = os.environ.get(
            "LOGO_EMBEDDING_MODEL_DIR", f"models/logo_detection/{model_type}"
        )

        if model_type == "clip":
            model_name = "openai/clip-vit-large-patch14"
            model_path = self._resolve_model_path(
                model_name, default_embedding_dir, "CLIP"
            )
            self.logger.info(f"Loading CLIP model: {model_path}")
            self._clip_model = CLIPModel.from_pretrained(model_path).to(self.device)
            self._clip_processor = CLIPProcessor.from_pretrained(model_path)
            self._clip_model.eval()

            def embed_fn(pil_image):
                inputs = self._clip_processor(
                    images=pil_image, return_tensors="pt"
                ).to(self.device)
                with torch.no_grad():
                    features = self._clip_model.get_image_features(**inputs)
                return F.normalize(features, dim=-1)

        elif model_type == "dinov2":
            model_name = "facebook/dinov2-base"
            model_path = self._resolve_model_path(
                model_name, default_embedding_dir, "DINOv2"
            )
            self.logger.info(f"Loading DINOv2 model: {model_path}")
            self._dinov2_model = Dinov2Model.from_pretrained(model_path).to(self.device)
            self._dinov2_processor = AutoImageProcessor.from_pretrained(model_path)
            self._dinov2_model.eval()

            def embed_fn(pil_image):
                inputs = self._dinov2_processor(
                    images=pil_image, return_tensors="pt"
                ).to(self.device)
                with torch.no_grad():
                    outputs = self._dinov2_model(**inputs)
                # Use CLS token embedding
                features = outputs.last_hidden_state[:, 0, :]
                return F.normalize(features, dim=-1)

        elif model_type == "siglip":
            model_name = "google/siglip-base-patch16-224"
            model_path = self._resolve_model_path(
                model_name, default_embedding_dir, "SigLIP"
            )
            self.logger.info(f"Loading SigLIP model: {model_path}")
            self._siglip_model = AutoModel.from_pretrained(model_path).to(self.device)
            self._siglip_processor = AutoProcessor.from_pretrained(model_path)
            self._siglip_model.eval()

            def embed_fn(pil_image):
                inputs = self._siglip_processor(
                    images=pil_image, return_tensors="pt"
                ).to(self.device)
                with torch.no_grad():
                    features = self._siglip_model.get_image_features(**inputs)
                return F.normalize(features, dim=-1)

        else:
            raise ValueError(
                f"Unknown embedding model type: {model_type}. "
                f"Use 'clip', 'dinov2', or 'siglip'"
            )

        self._embed_fn = embed_fn

    def _resolve_model_path(
        self, model_name_or_path: str, default_local_dir: str, model_type: str
    ) -> str:
        """
        Resolve model path, checking for local models before using HuggingFace.

        Resolution order:
        1. An absolute path that exists is used as-is.
        2. Otherwise the default local directory is used when it contains a
           config.json.
        3. Otherwise model_name_or_path is returned unchanged (a HuggingFace
           model name, or a missing absolute path that the loader will reject).

        Args:
            model_name_or_path: HuggingFace model name or absolute path
            default_local_dir: Default local directory to check
            model_type: Type of model (for logging)

        Returns:
            Resolved model path (local path or HuggingFace model name)
        """
        is_absolute = os.path.isabs(model_name_or_path)

        if is_absolute:
            if os.path.exists(model_name_or_path):
                self.logger.info(
                    f"{model_type} model: Using local model at {model_name_or_path}"
                )
                return model_name_or_path
            # A missing absolute path is not a HuggingFace name, so the only
            # real fallback available is the default local directory below.
            self.logger.warning(
                f"{model_type} model: Local path {model_name_or_path} does not exist, "
                f"checking default local directory {default_local_dir}"
            )

        # Check if default local directory exists
        if os.path.exists(default_local_dir):
            config_file = os.path.join(default_local_dir, "config.json")
            if os.path.exists(config_file):
                abs_path = os.path.abspath(default_local_dir)
                self.logger.info(
                    f"{model_type} model: Found local model at {abs_path}"
                )
                return abs_path
            self.logger.warning(
                f"{model_type} model: Local directory {default_local_dir} exists but "
                f"is not a valid model (missing config.json)"
            )

        if is_absolute:
            # Nothing usable was found; return the path unchanged so the
            # loader reports the missing location itself.
            self.logger.warning(
                f"{model_type} model: No local model found for missing path "
                f"{model_name_or_path}; model loading is likely to fail"
            )
            return model_name_or_path

        # Use HuggingFace model name
        self.logger.info(
            f"{model_type} model: No local model found, will download from HuggingFace: "
            f"{model_name_or_path}"
        )
        return model_name_or_path

    @staticmethod
    def _clamp_box(
        box: Dict[str, Any], width: int, height: int
    ) -> Optional[Tuple[int, int, int, int]]:
        """
        Clamp a detection box to the image and reject empty boxes.

        Args:
            box: Dict with 'xmin', 'ymin', 'xmax', 'ymax'; missing keys count as 0
            width: Image width in pixels
            height: Image height in pixels

        Returns:
            (xmin, ymin, xmax, ymax) inside the image, or None when the
            clamped box has no area
        """
        xmin = max(0, min(int(box.get("xmin", 0)), width))
        ymin = max(0, min(int(box.get("ymin", 0)), height))
        xmax = max(0, min(int(box.get("xmax", 0)), width))
        ymax = max(0, min(int(box.get("ymax", 0)), height))
        if xmax <= xmin or ymax <= ymin:
            return None
        return xmin, ymin, xmax, ymax

    def detect(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Detect logos in an image and return bounding boxes with embeddings.

        Predictions scoring below the DETR threshold are dropped. Boxes are
        clamped to the image; a box with no area after clamping is skipped,
        because an empty crop cannot be embedded.

        Args:
            image: OpenCV image (BGR format, numpy array)

        Returns:
            List of dictionaries, each containing:
            - 'box': dict with 'xmin', 'ymin', 'xmax', 'ymax' (pixel coordinates)
            - 'score': DETR confidence score (float 0-1)
            - 'embedding': Feature embedding (torch.Tensor)
            - 'label': DETR predicted label (string)
        """
        # Convert OpenCV BGR to RGB PIL Image
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(image_rgb)
        width, height = pil_image.size

        # Run DETR detection
        predictions = self.detr_pipe(pil_image)

        # Filter by threshold and add embeddings
        detections = []
        for pred in predictions:
            score = pred.get("score", 0.0)
            if score < self.detr_threshold:
                continue

            clamped = self._clamp_box(pred.get("box", {}), width, height)
            if clamped is None:
                self.logger.debug(
                    f"Skipping detection with empty bounding box: {pred.get('box')}"
                )
                continue
            xmin, ymin, xmax, ymax = clamped

            # Extract bounding box region
            bbox_crop = pil_image.crop(clamped)

            # Get embedding for this region
            embedding = self._embed_fn(bbox_crop)

            detections.append(
                {
                    "box": {"xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax},
                    "score": score,
                    "embedding": embedding,
                    "label": pred.get("label", "logo"),
                }
            )

        self.logger.debug(
            f"Detected {len(detections)} logos (threshold: {self.detr_threshold})"
        )
        return detections

    def get_embedding(self, image: np.ndarray) -> torch.Tensor:
        """
        Get embedding for a single reference logo image.

        Args:
            image: OpenCV image (BGR format, numpy array)

        Returns:
            Normalized feature embedding (torch.Tensor)
        """
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(image_rgb)
        return self._embed_fn(pil_image)

    def get_averaged_embedding(self, images: List[np.ndarray]) -> Optional[torch.Tensor]:
        """
        Compute averaged embedding from multiple reference logo images.

        Follows the averaging pattern from db_embeddings.py:
        1. Compute embedding for each image
        2. Stack and average across all images
        3. Re-normalize the averaged embedding

        Images whose embedding fails are logged and left out of the average.

        Args:
            images: List of OpenCV images (BGR format, numpy arrays)

        Returns:
            Normalized averaged embedding (torch.Tensor, shape [1, D]),
            or None if no valid embeddings could be computed
        """
        embeddings = []
        for img in images:
            try:
                emb = self.get_embedding(img)
                embeddings.append(emb)
            except Exception as e:
                # Deliberately best-effort: one bad reference image must not
                # discard the rest of the reference set.
                self.logger.warning(f"Failed to compute embedding for reference image: {e}")

        if not embeddings:
            return None

        # Stack: (N, D), average: (1, D), re-normalize
        stacked = torch.cat(embeddings, dim=0)
        avg_emb = stacked.mean(dim=0, keepdim=True)
        avg_emb = F.normalize(avg_emb, dim=-1)

        self.logger.debug(
            f"Computed averaged embedding from {len(embeddings)} reference image(s)"
        )
        return avg_emb

    def compare_embeddings(
        self, embedding1: torch.Tensor, embedding2: torch.Tensor
    ) -> float:
        """
        Compute cosine similarity between two embeddings.

        Args:
            embedding1: First embedding (torch.Tensor)
            embedding2: Second embedding (torch.Tensor)

        Returns:
            Cosine similarity score (float, range: -1 to 1, typically 0 to 1)
        """
        # Ensure tensors are on the same device
        if embedding1.device != embedding2.device:
            embedding2 = embedding2.to(embedding1.device)

        similarity = F.cosine_similarity(embedding1, embedding2, dim=-1)
        # .item() requires a single similarity value, i.e. one embedding per side.
        return similarity.item()

    @staticmethod
    def make_filenames_hash(filenames: List[str]) -> str:
        """
        Compute a deterministic hash of a filenames list.

        Used for cache invalidation — if the filenames list changes,
        the hash changes, triggering re-computation of averaged embeddings.
        The list is sorted first, so the order of filenames does not matter.

        Args:
            filenames: List of filename strings

        Returns:
            16-character hex hash string
        """
        canonical = json.dumps(sorted(filenames))
        return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16]
|
||||||
52
results_average_embeddings.txt
Normal file
52
results_average_embeddings.txt
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
======================================================================
|
||||||
|
BURNLEY LOGO DETECTION TEST
|
||||||
|
Model: dinov2
|
||||||
|
Method: Margin-based (margin=0.05)
|
||||||
|
======================================================================
|
||||||
|
Date: 2026-03-31 11:45:03
|
||||||
|
|
||||||
|
Configuration:
|
||||||
|
Embedding model: dinov2
|
||||||
|
Similarity threshold: 0.7
|
||||||
|
DETR threshold: 0.5
|
||||||
|
Matching margin: 0.05
|
||||||
|
Test images processed: 516
|
||||||
|
Reference logos: barnfield, vertu
|
||||||
|
|
||||||
|
Results:
|
||||||
|
True Positives: 28
|
||||||
|
False Positives: 36
|
||||||
|
False Negatives: 125
|
||||||
|
Total Expected: 146
|
||||||
|
|
||||||
|
Scores:
|
||||||
|
Precision: 0.4375 (43.8%)
|
||||||
|
Recall: 0.1918 (19.2%)
|
||||||
|
F1 Score: 0.2667 (26.7%)
|
||||||
|
|
||||||
|
======================================================================
|
||||||
|
BURNLEY LOGO DETECTION TEST
|
||||||
|
Model: dinov2
|
||||||
|
Method: Margin-based (margin=0.05)
|
||||||
|
======================================================================
|
||||||
|
Date: 2026-03-31 12:29:32
|
||||||
|
|
||||||
|
Configuration:
|
||||||
|
Embedding model: dinov2
|
||||||
|
Similarity threshold: 0.7
|
||||||
|
DETR threshold: 0.5
|
||||||
|
Matching margin: 0.05
|
||||||
|
Test images processed: 516
|
||||||
|
Reference logos: barnfield, vertu
|
||||||
|
|
||||||
|
Results:
|
||||||
|
True Positives: 28
|
||||||
|
False Positives: 36
|
||||||
|
False Negatives: 125
|
||||||
|
Total Expected: 146
|
||||||
|
|
||||||
|
Scores:
|
||||||
|
Precision: 0.4375 (43.8%)
|
||||||
|
Recall: 0.1918 (19.2%)
|
||||||
|
F1 Score: 0.2667 (26.7%)
|
||||||
|
|
||||||
521
test_burnley_detection.py
Normal file
521
test_burnley_detection.py
Normal file
@ -0,0 +1,521 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test script for logo detection accuracy on Burnley test images.
|
||||||
|
|
||||||
|
Uses DetectLogosEmbeddings from logo_detection_embeddings.py to detect
|
||||||
|
barnfield and vertu logos. Ground truth is determined by filename prefix:
|
||||||
|
- "vertu_" → contains vertu logo
|
||||||
|
- "barnfield_" → contains barnfield logo
|
||||||
|
- "barnfield+vertu_" → contains both logos
|
||||||
|
- anything else → no target logos
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import pickle
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from logo_detection_embeddings import DetectLogosEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(verbose: bool = False) -> logging.Logger:
    """Configure root logging and return this module's logger.

    Args:
        verbose: When true the level is DEBUG, otherwise INFO.

    Returns:
        The logger named after this module.
    """
    logging.basicConfig(
        level=logging.DEBUG if verbose else logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%H:%M:%S",
    )
    module_logger = logging.getLogger(__name__)
    return module_logger
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(image_path: Path) -> Optional[cv2.Mat]:
    """Read an image from disk with OpenCV.

    Args:
        image_path: Location of the image file.

    Returns:
        Whatever ``cv2.imread`` produced for the path, or None when it
        produced None.
    """
    loaded = cv2.imread(str(image_path))
    return None if loaded is None else loaded
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingCache:
    """Simple file-based cache for embeddings.

    Entries live in an in-memory dict that is read from a pickle file on
    construction and written back by :meth:`save`.

    An instance is always truthy, even when it holds no entries, so callers
    can write ``if cache:`` to mean "caching is enabled". Without this the
    ``__len__`` below would make an empty cache falsy and nothing would ever
    be stored in it.
    """

    def __init__(self, cache_path: Path):
        """Create the cache and load any existing entries from cache_path."""
        self.cache_path = cache_path
        self.cache: Dict[str, Any] = {}
        self._load()

    def _load(self):
        """Load entries from disk; an unreadable file yields an empty cache."""
        if not self.cache_path.exists():
            return
        try:
            # NOTE(security): pickle can execute arbitrary code while loading.
            # Only point this cache at files written by this script.
            with open(self.cache_path, "rb") as f:
                loaded = pickle.load(f)
        except Exception as exc:
            # Best-effort: a corrupt or incompatible cache must not abort the
            # run, but the reason is reported instead of being swallowed.
            logging.getLogger(__name__).warning(
                "Ignoring unreadable embedding cache %s: %s", self.cache_path, exc
            )
            loaded = None
        # Anything other than a dict cannot serve get()/put().
        self.cache = loaded if isinstance(loaded, dict) else {}

    def save(self):
        """Write all entries to cache_path, creating parent directories."""
        self.cache_path.parent.mkdir(parents=True, exist_ok=True)
        # Write to a sibling file first so an interrupted save cannot leave a
        # truncated cache behind.
        tmp_path = self.cache_path.with_name(self.cache_path.name + ".tmp")
        with open(tmp_path, "wb") as f:
            pickle.dump(self.cache, f)
        tmp_path.replace(self.cache_path)

    def get(self, key: str):
        """Return the value stored under key, or None when it is absent."""
        return self.cache.get(key)

    def put(self, key: str, value):
        """Store value under key, moving any tensors inside it to the CPU."""
        self.cache[key] = self._to_cpu(value)

    @staticmethod
    def _to_cpu(value):
        """Return value with tensors (also inside dicts/lists/tuples) on CPU."""
        if isinstance(value, torch.Tensor):
            return value.cpu()
        if isinstance(value, dict):
            return {k: EmbeddingCache._to_cpu(v) for k, v in value.items()}
        if isinstance(value, list):
            return [EmbeddingCache._to_cpu(v) for v in value]
        if isinstance(value, tuple):
            return tuple(EmbeddingCache._to_cpu(v) for v in value)
        return value

    def __len__(self):
        """Number of cached entries."""
        return len(self.cache)

    def __bool__(self):
        """Always true: an empty cache is still an enabled cache."""
        return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_expected_logos(filename: str) -> Set[str]:
    """Map a test image filename to the set of logos it should contain.

    The match is on the lower-cased filename prefix; a filename with none of
    the known prefixes yields an empty set.
    """
    prefix_to_logos = (
        ("barnfield+vertu_", ("barnfield", "vertu")),
        ("barnfield_", ("barnfield",)),
        ("vertu_", ("vertu",)),
    )
    lowered = filename.lower()
    for prefix, logos in prefix_to_logos:
        if lowered.startswith(prefix):
            return set(logos)
    return set()
|
||||||
|
|
||||||
|
|
||||||
|
def load_reference_images(ref_dir: Path, logger: logging.Logger) -> List[cv2.Mat]:
    """Collect every loadable image found directly inside ref_dir.

    Entries are visited in sorted order and only .jpg, .jpeg, .png and .bmp
    files (suffix compared case-insensitively) are considered. A file that
    cannot be loaded is reported through logger and left out.
    """
    accepted_suffixes = (".jpg", ".jpeg", ".png", ".bmp")
    loaded_images = []
    for candidate in sorted(ref_dir.iterdir()):
        if candidate.suffix.lower() not in accepted_suffixes:
            continue
        picture = load_image(candidate)
        if picture is None:
            logger.warning(f"Failed to load reference image: {candidate}")
            continue
        loaded_images.append(picture)
    return loaded_images
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run the Burnley logo detection test and report precision/recall/F1.

    Steps:
    1. Parse command-line options.
    2. Build (or load from cache) one averaged reference embedding per logo.
    3. Run detection on every test image and match each detection to the
       best reference logo, requiring both the similarity threshold and a
       margin over the second-best logo.
    4. Print the metrics and, if requested, append them to an output file.

    Counting rules: a true or false positive is recorded per matching
    detection; a false negative is recorded per expected logo that no
    detection in the image matched.

    Exits with status 1 when a required directory is missing or a reference
    embedding cannot be computed.
    """
    parser = argparse.ArgumentParser(
        description="Test logo detection on Burnley test images using DetectLogosEmbeddings"
    )
    parser.add_argument(
        "-t", "--threshold",
        type=float,
        default=0.7,
        help="Similarity threshold for matching (default: 0.7)",
    )
    parser.add_argument(
        "-d", "--detr-threshold",
        type=float,
        default=0.5,
        help="DETR detection confidence threshold (default: 0.5)",
    )
    parser.add_argument(
        "-e", "--embedding-model",
        type=str,
        choices=["clip", "dinov2", "siglip"],
        default="dinov2",
        help="Embedding model type (default: dinov2)",
    )
    parser.add_argument(
        "--margin",
        type=float,
        default=0.05,
        help="Required margin between best and second-best match (default: 0.05)",
    )
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Enable verbose logging",
    )
    parser.add_argument(
        "--similarity-details",
        action="store_true",
        help="Output detailed similarity scores for each detection",
    )
    parser.add_argument(
        "--no-cache",
        action="store_true",
        help="Disable embedding cache",
    )
    parser.add_argument(
        "--clear-cache",
        action="store_true",
        help="Clear embedding cache before running",
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default=None,
        help="Append results summary to this file",
    )

    args = parser.parse_args()
    logger = setup_logging(args.verbose)

    image_suffixes = (".jpg", ".jpeg", ".png", ".bmp")

    # Paths
    base_dir = Path(__file__).resolve().parent
    test_images_dir = base_dir / "burnley_test_images"
    barnfield_ref_dir = base_dir / "barnfield_reference_images"
    vertu_ref_dir = base_dir / "vertu_reference_images"
    cache_path = base_dir / ".burnley_embedding_cache.pkl"

    # Verify directories exist
    for d, name in [(test_images_dir, "Test images"), (barnfield_ref_dir, "Barnfield refs"), (vertu_ref_dir, "Vertu refs")]:
        if not d.exists():
            logger.error(f"{name} directory not found: {d}")
            sys.exit(1)

    # Handle cache
    if args.clear_cache and cache_path.exists():
        cache_path.unlink()
        logger.info("Cleared embedding cache")

    # "cache is not None" is used throughout rather than truthiness: the
    # cache object defines __len__, so an empty cache must still count as
    # enabled or nothing would ever be stored in it.
    cache = EmbeddingCache(cache_path) if not args.no_cache else None
    if cache is not None:
        logger.info(f"Loaded {len(cache)} cached embeddings")

    # Initialize detector
    logger.info(f"Initializing detector with embedding model: {args.embedding_model}")
    detector = DetectLogosEmbeddings(
        logger=logger,
        detr_threshold=args.detr_threshold,
        embedding_model_type=args.embedding_model,
    )

    # Compute averaged reference embeddings
    logger.info("Computing reference embeddings...")

    reference_embeddings: Dict[str, torch.Tensor] = {}
    for logo_name, ref_dir in [("barnfield", barnfield_ref_dir), ("vertu", vertu_ref_dir)]:
        # Key the cached average on the reference file list so that adding or
        # removing reference images triggers a re-computation.
        ref_names = sorted(
            p.name for p in ref_dir.iterdir() if p.suffix.lower() in image_suffixes
        )
        names_hash = DetectLogosEmbeddings.make_filenames_hash(ref_names)
        cache_key = f"avg_ref:{logo_name}:{args.embedding_model}:{names_hash}"
        cached = cache.get(cache_key) if cache is not None else None

        if cached is not None:
            reference_embeddings[logo_name] = cached
            logger.info(f"Loaded cached averaged embedding for {logo_name}")
        else:
            ref_images = load_reference_images(ref_dir, logger)
            logger.info(f"Computing averaged embedding for {logo_name} from {len(ref_images)} images")
            avg_emb = detector.get_averaged_embedding(ref_images)
            if avg_emb is None:
                logger.error(f"Failed to compute embedding for {logo_name}")
                sys.exit(1)
            reference_embeddings[logo_name] = avg_emb
            if cache is not None:
                cache.put(cache_key, avg_emb)

    # Collect test images
    test_files = sorted([
        f.name for f in test_images_dir.iterdir()
        if f.suffix.lower() in image_suffixes
    ])
    logger.info(f"Found {len(test_files)} test images")

    # Metrics
    true_positives = 0
    false_positives = 0
    false_negatives = 0
    total_expected = 0
    results = []

    similarity_details = {
        "true_positive_sims": [],
        "false_positive_sims": [],
        "missed_best_sims": [],
        "detection_details": [],
    }

    # Process test images
    for test_filename in tqdm(test_files, desc="Testing"):
        test_path = test_images_dir / test_filename
        expected_logos = get_expected_logos(test_filename)
        total_expected += len(expected_logos)

        # Check cache for detections. The DETR threshold is part of the key
        # because cached detections were already filtered by it.
        det_cache_key = f"det:{test_filename}:{args.embedding_model}:{args.detr_threshold}"
        cached_detections = cache.get(det_cache_key) if cache is not None else None

        if cached_detections is not None:
            detections = cached_detections
        else:
            test_img = load_image(test_path)
            if test_img is None:
                logger.warning(f"Failed to load test image: {test_path}")
                continue
            detections = detector.detect(test_img)
            if cache is not None:
                cache.put(det_cache_key, detections)

        # Match each detection against reference embeddings with margin
        matched_logos: Set[str] = set()
        for det_idx, detection in enumerate(detections):
            # Compute similarity to each reference logo
            sims: Dict[str, float] = {}
            for logo_name, ref_emb in reference_embeddings.items():
                sims[logo_name] = detector.compare_embeddings(
                    detection["embedding"], ref_emb
                )

            sorted_sims = sorted(sims.items(), key=lambda x: -x[1])

            if args.similarity_details:
                similarity_details["detection_details"].append({
                    "image": test_filename,
                    "detection_idx": det_idx,
                    "expected_logos": list(expected_logos),
                    "similarities": sorted_sims,
                    "detr_score": detection.get("score", 0),
                })

            # Best match with margin check
            if not sorted_sims:
                continue

            best_name, best_sim = sorted_sims[0]
            if best_sim < args.threshold:
                continue

            # Check margin over second best
            if len(sorted_sims) > 1:
                second_sim = sorted_sims[1][1]
                if best_sim - second_sim < args.margin:
                    continue

            matched_logos.add(best_name)
            is_correct = best_name in expected_logos

            if is_correct:
                true_positives += 1
                if args.similarity_details:
                    similarity_details["true_positive_sims"].append(best_sim)
            else:
                false_positives += 1
                if args.similarity_details:
                    similarity_details["false_positive_sims"].append(best_sim)

            results.append({
                "test_image": test_filename,
                "matched_logo": best_name,
                "similarity": best_sim,
                "correct": is_correct,
            })

        # Count missed detections
        missed = expected_logos - matched_logos
        false_negatives += len(missed)

        for missed_logo in missed:
            if args.similarity_details and detections:
                ref_emb = reference_embeddings[missed_logo]
                # True maximum over the detections; cosine similarity can be
                # negative, so it must not be floored at 0.
                best_sim_for_missed = max(
                    detector.compare_embeddings(detection["embedding"], ref_emb)
                    for detection in detections
                )
                similarity_details["missed_best_sims"].append(best_sim_for_missed)

            results.append({
                "test_image": test_filename,
                "matched_logo": None,
                "expected_logo": missed_logo,
                "similarity": None,
                "correct": False,
            })

    # Save cache
    if cache is not None:
        cache.save()
        logger.info(f"Saved {len(cache)} embeddings to cache")

    # Calculate metrics
    precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
    recall = true_positives / total_expected if total_expected > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    # Print results
    print("\n" + "=" * 60)
    print("BURNLEY LOGO DETECTION TEST RESULTS")
    print("=" * 60)
    print("\nConfiguration:")
    print(f"  Embedding model: {args.embedding_model}")
    print(f"  Similarity threshold: {args.threshold}")
    print(f"  DETR confidence threshold: {args.detr_threshold}")
    print(f"  Matching margin: {args.margin}")
    print(f"  Test images processed: {len(test_files)}")
    print("  Reference logos: barnfield, vertu")

    print("\nMetrics:")
    print(f"  True Positives (correct matches): {true_positives}")
    print(f"  False Positives (wrong matches): {false_positives}")
    print(f"  False Negatives (missed logos): {false_negatives}")
    print(f"  Total expected matches: {total_expected}")

    print("\nScores:")
    print(f"  Precision: {precision:.4f} ({precision*100:.1f}%)")
    print(f"  Recall: {recall:.4f} ({recall*100:.1f}%)")
    print(f"  F1 Score: {f1:.4f} ({f1*100:.1f}%)")

    # Show false positive examples
    false_positive_examples = [r for r in results if r.get("matched_logo") and not r["correct"]]
    if false_positive_examples:
        print("\nExample False Positives (first 5):")
        for r in false_positive_examples[:5]:
            print(f"  - Image: {r['test_image']}")
            print(f"    Matched: {r['matched_logo']} (similarity: {r['similarity']:.3f})")

    # Show false negative examples
    false_negative_examples = [r for r in results if r.get("expected_logo")]
    if false_negative_examples:
        print("\nExample False Negatives (first 5):")
        for r in false_negative_examples[:5]:
            print(f"  - Image: {r['test_image']}")
            print(f"    Expected: {r['expected_logo']}")

    print("=" * 60)

    # Print similarity details if requested
    if args.similarity_details:
        print_similarity_details(similarity_details, args.threshold)

    # Write results to file if requested
    if args.output_file:
        write_results_to_file(
            output_path=Path(args.output_file),
            args=args,
            num_test_images=len(test_files),
            true_positives=true_positives,
            false_positives=false_positives,
            false_negatives=false_negatives,
            total_expected=total_expected,
            precision=precision,
            recall=recall,
            f1=f1,
        )
        print(f"\nResults appended to: {args.output_file}")
|
||||||
|
|
||||||
|
|
||||||
|
def print_similarity_details(details: dict, threshold: float):
    """Print a similarity distribution report to stdout.

    The report has three parts: summary statistics for each similarity
    bucket, an overlap analysis between the true-positive and
    false-positive ranges, and a dump of the first 20 per-detection records.

    Args:
        details: Mapping that must contain the keys ``"true_positive_sims"``,
            ``"false_positive_sims"`` and ``"missed_best_sims"`` (sequences
            of numeric similarities) and ``"detection_details"`` (a sequence
            of dicts with the keys ``"image"``, ``"expected_logos"``,
            ``"detr_score"`` and ``"similarities"``, the last being an
            iterable of ``(logo, similarity)`` pairs).
        threshold: Similarity threshold used to count values at/above and
            below it in each bucket.

    Raises:
        KeyError: If one of the required keys is missing from ``details``.
    """
    import statistics

    print("\n" + "=" * 60)
    print("SIMILARITY DISTRIBUTION ANALYSIS")
    print("=" * 60)

    def compute_stats(values, name):
        """Print min/max/mean/stdev/median and threshold counts for one bucket."""
        if not values:
            print(f"\n{name}: No data")
            return
        print(f"\n{name} (n={len(values)}):")
        print(f" Min: {min(values):.4f}")
        print(f" Max: {max(values):.4f}")
        print(f" Mean: {statistics.mean(values):.4f}")
        # statistics.stdev() needs at least two data points.
        if len(values) > 1:
            print(f" StdDev: {statistics.stdev(values):.4f}")
        print(f" Median: {statistics.median(values):.4f}")

        above = sum(1 for v in values if v >= threshold)
        below = sum(1 for v in values if v < threshold)
        print(f" Above threshold ({threshold}): {above} ({100*above/len(values):.1f}%)")
        print(f" Below threshold ({threshold}): {below} ({100*below/len(values):.1f}%)")

    compute_stats(details["true_positive_sims"], "TRUE POSITIVE similarities")
    compute_stats(details["false_positive_sims"], "FALSE POSITIVE similarities")
    compute_stats(details["missed_best_sims"], "MISSED LOGO best similarities")

    # Overlap analysis
    tp_sims = details["true_positive_sims"]
    fp_sims = details["false_positive_sims"]
    if tp_sims and fp_sims:
        print("\n" + "-" * 40)
        print("OVERLAP ANALYSIS:")
        tp_min, tp_max = min(tp_sims), max(tp_sims)
        fp_min, fp_max = min(fp_sims), max(fp_sims)
        print(f" True Positives range: [{tp_min:.4f}, {tp_max:.4f}]")
        print(f" False Positives range: [{fp_min:.4f}, {fp_max:.4f}]")

        overlap_min = max(tp_min, fp_min)
        overlap_max = min(tp_max, fp_max)
        # Inclusive comparison: two closed ranges that share even a single
        # value (e.g. identical TP and FP similarities) cannot be separated
        # by any threshold, so they must be reported as overlapping.
        if overlap_min <= overlap_max:
            print(f" OVERLAP REGION: [{overlap_min:.4f}, {overlap_max:.4f}]")
        else:
            print(" NO OVERLAP - distributions are separable!")

    # Sample detection details
    det_details = details["detection_details"]
    if det_details:
        print("\n" + "-" * 40)
        print(f"SAMPLE DETECTION DETAILS (first 20 of {len(det_details)}):")
        for i, det in enumerate(det_details[:20]):
            expected = det["expected_logos"]
            sims = det["similarities"]
            print(f"\n [{i+1}] Image: {det['image']}")
            print(f" Expected: {expected if expected else '(none)'}")
            print(f" DETR score: {det['detr_score']:.3f}")
            print(" Similarities:")
            for logo, sim in sims:
                # Flag the entries whose logo is one of the expected ones.
                marker = " <-- CORRECT" if logo in expected else ""
                print(f" {sim:.4f} {logo}{marker}")

    print("\n" + "=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
def write_results_to_file(
    output_path: Path,
    args,
    num_test_images: int,
    true_positives: int,
    false_positives: int,
    false_negatives: int,
    total_expected: int,
    precision: float,
    recall: float,
    f1: float,
):
    """Append a plain-text summary of one test run to ``output_path``.

    The file is opened in append mode, so repeated runs accumulate in the
    same file; each summary ends with a blank line that separates it from
    the next one.

    Args:
        output_path: Destination file (``Path`` or path string). Missing
            parent directories are created.
        args: Parsed command-line namespace; the attributes
            ``embedding_model``, ``margin``, ``threshold`` and
            ``detr_threshold`` are read.
        num_test_images: Number of test images that were processed.
        true_positives: Count of correct matches.
        false_positives: Count of wrong matches.
        false_negatives: Count of missed logos.
        total_expected: Total number of expected matches.
        precision: Precision as a fraction in [0, 1].
        recall: Recall as a fraction in [0, 1].
        f1: F1 score as a fraction in [0, 1].
    """
    from datetime import datetime

    output_path = Path(output_path)
    banner = "=" * 70

    lines = [
        banner,
        "BURNLEY LOGO DETECTION TEST",
        f"Model: {args.embedding_model}",
        f"Method: Margin-based (margin={args.margin})",
        banner,
        f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        "",
        "Configuration:",
        f" Embedding model: {args.embedding_model}",
        f" Similarity threshold: {args.threshold}",
        f" DETR threshold: {args.detr_threshold}",
        f" Matching margin: {args.margin}",
        f" Test images processed: {num_test_images}",
        " Reference logos: barnfield, vertu",
        "",
        "Results:",
        f" True Positives: {true_positives:>6}",
        f" False Positives: {false_positives:>6}",
        f" False Negatives: {false_negatives:>6}",
        f" Total Expected: {total_expected:>6}",
        "",
        "Scores:",
        f" Precision: {precision:.4f} ({precision*100:.1f}%)",
        f" Recall: {recall:.4f} ({recall*100:.1f}%)",
        f" F1 Score: {f1:.4f} ({f1*100:.1f}%)",
        "",
        "",
    ]

    # This runs at the very end of a long evaluation; create the target
    # directory instead of failing and losing the results.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Explicit encoding keeps the output independent of the platform locale.
    with open(output_path, "a", encoding="utf-8") as f:
        f.write("\n".join(lines))
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
||||||
Reference in New Issue
Block a user