Compare commits
8 Commits
49f982611a
...
master
| Author | SHA1 | Date | |
|---|---|---|---|
| f2ae80c9e5 | |||
| 8b67b50d19 | |||
| 5ce6265a90 | |||
| 512f678310 | |||
| f598866d37 | |||
| 91d1c9cd59 | |||
| ea6fcec9ce | |||
| f777b049a3 |
121
README.md
121
README.md
@ -2,6 +2,110 @@
|
||||
|
||||
A testing framework for evaluating logo detection accuracy using DETR (DEtection TRansformer) and CLIP (Contrastive Language-Image Pre-training) models.
|
||||
|
||||
## Burnley Test: Averaged Embeddings with DINOv2
|
||||
|
||||
A targeted test using `DetectLogosEmbeddings` to detect two specific logos (barnfield and vertu) in 516 Burnley match images. Reference embeddings are averaged across all images in each reference directory, and matching uses margin-based comparison (margin=0.05).
|
||||
|
||||
**Test command:**
|
||||
```bash
|
||||
uv run python test_burnley_detection.py -e dinov2 -t 0.7 --margin 0.05 --output-file results_average_embeddings.txt
|
||||
```
|
||||
|
||||
**Results (DINOv2, threshold 0.70, margin 0.05):**
|
||||
|
||||
| Metric | Value |
|
||||
|--------|-------|
|
||||
| True Positives | 28 |
|
||||
| False Positives | 36 |
|
||||
| False Negatives | 125 |
|
||||
| Total Expected | 146 |
|
||||
| **Precision** | **43.8%** |
|
||||
| **Recall** | **19.2%** |
|
||||
| **F1 Score** | **26.7%** |
|
||||
|
||||
Ground truth is derived from filename prefixes: `vertu_` (vertu logo), `barnfield_` (barnfield logo), `barnfield+vertu_` (both logos). Images without these prefixes are treated as negatives.
|
||||
|
||||
Low recall suggests many logos go undetected by DETR or fall below the similarity threshold. The relatively low precision indicates DINOv2 averaged embeddings struggle to discriminate between the two logos in this domain. Further tuning of thresholds, margin, and embedding model (e.g. CLIP or SigLIP) may improve results.
|
||||
|
||||
---
|
||||
|
||||
## Recommended Settings
|
||||
|
||||
Based on extensive testing with the LogoDet-3K dataset, these are the optimal settings:
|
||||
|
||||
| Parameter | Recommended Value | Notes |
|
||||
|-----------|-------------------|-------|
|
||||
| **Matching Method** | `multi-ref` | Best balance of precision and recall |
|
||||
| **Similarity Aggregation** | `max` (default) | Max outperforms mean aggregation |
|
||||
| **Embedding Model** | `openai/clip-vit-large-patch14` | Significantly outperforms DINOv2 |
|
||||
| **CLIP Threshold** | `0.70` | Good precision/recall balance |
|
||||
| **DETR Threshold** | `0.50` | Default detection confidence |
|
||||
| **Margin** | `0.05` | Reduces false positives |
|
||||
| **Refs per Logo** | `7-10` | More references = better accuracy |
|
||||
| **Preprocessing** | `default` | Best precision; letterbox/stretch hurt precision |
|
||||
|
||||
**Example command with recommended settings:**
|
||||
```bash
|
||||
uv run python test_logo_detection.py \
|
||||
--matching-method multi-ref \
|
||||
--refs-per-logo 10 \
|
||||
--threshold 0.70 \
|
||||
--margin 0.05 \
|
||||
--use-max-similarity
|
||||
```
|
||||
|
||||
### Performance Benchmarks
|
||||
|
||||
With recommended settings (multi-ref max, threshold 0.70, margin 0.05):
|
||||
|
||||
| Refs/Logo | Precision | Recall | F1 Score |
|
||||
|-----------|-----------|--------|----------|
|
||||
| 1 | 45.8% | 65.9% | 54.0% |
|
||||
| 3 | 40.5% | 72.4% | 51.9% |
|
||||
| 5 | 47.2% | 72.6% | 57.2% |
|
||||
| 7 | **51.0%** | **79.9%** | **62.3%** |
|
||||
| 10 | 50.2% | 81.6% | 62.1% |
|
||||
|
||||
**Key findings:**
|
||||
- More reference images per logo consistently improves recall
|
||||
- 7+ refs provides the best precision/recall balance
|
||||
- Diminishing returns beyond 10 refs
|
||||
|
||||
### Matching Method Comparison
|
||||
|
||||
| Method | Precision | Recall | F1 | Use Case |
|
||||
|--------|-----------|--------|-----|----------|
|
||||
| `simple` | 1.3% | 203%* | 2.5% | Not recommended (too many FPs) |
|
||||
| `margin` | 69.8% | 16.3% | 26.4% | High precision, low recall |
|
||||
| `multi-ref` (mean) | 51.8% | 63.1% | 56.9% | Balanced |
|
||||
| `multi-ref` (max) | **51.8%** | **75.3%** | **61.4%** | **Best overall** |
|
||||
|
||||
*Simple method returns all matches above threshold, causing many duplicates.
|
||||
|
||||
### Embedding Model Comparison
|
||||
|
||||
| Model | Precision | Recall | F1 | Recommendation |
|
||||
|-------|-----------|--------|-----|----------------|
|
||||
| `openai/clip-vit-large-patch14` | **49.1%** | **77.0%** | **59.9%** | **Recommended** |
|
||||
| `facebook/dinov2-small` | 22.4% | 42.8% | 29.5% | Not recommended |
|
||||
| `facebook/dinov2-large` | 32.2% | 28.5% | 30.2% | Not recommended |
|
||||
|
||||
CLIP significantly outperforms DINOv2 for logo matching tasks.
|
||||
|
||||
### Preprocessing Mode Comparison
|
||||
|
||||
| Mode | Precision | Recall | F1 | Notes |
|
||||
|------|-----------|--------|-----|-------|
|
||||
| `default` | **50.2%** | 81.6% | 62.1% | **Recommended** - best precision |
|
||||
| `letterbox` | 42.4% | 119%* | 62.6% | Higher recall but worse precision |
|
||||
| `stretch` | 34.5% | 113%* | 52.9% | Not recommended |
|
||||
|
||||
*Recall >100% indicates multiple detections per expected logo.
|
||||
|
||||
**Recommendation:** Use `default` preprocessing. While letterbox shows marginally higher F1, it has significantly worse precision (more false positives).
|
||||
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
This project provides tools to:
|
||||
@ -97,9 +201,9 @@ uv run python test_logo_detection.py -n 50 --seed 42
|
||||
| `--clear-cache` | False | Clear embedding cache before running |
|
||||
|
||||
**Matching Methods:**
|
||||
- `simple` - Returns all logos above threshold (baseline, most permissive)
|
||||
- `margin` - Requires margin over second-best match (reduces false positives)
|
||||
- `multi-ref` - Aggregates scores across multiple reference images per logo
|
||||
- `simple` - Returns all logos above threshold (not recommended - too many false positives)
|
||||
- `margin` - Requires margin over second-best match (high precision, low recall)
|
||||
- `multi-ref` - **Recommended.** Aggregates scores across multiple reference images per logo
|
||||
|
||||
See `--help` for all options.
|
||||
|
||||
@ -114,13 +218,18 @@ See `--help` for all options.
|
||||
|
||||
# Compare embedding models (CLIP vs DINOv2)
|
||||
./run_model_comparison.sh
|
||||
|
||||
# Test different refs-per-logo values
|
||||
./run_refs_per_logo_test.sh
|
||||
```
|
||||
|
||||
| Script | Purpose | Output File |
|
||||
|--------|---------|-------------|
|
||||
| `run_comparison_tests.sh` | Compare all 4 matching methods | `comparison_results.txt` |
|
||||
| `run_threshold_tests.sh` | Test threshold/margin combinations | `threshold_test_results.txt` |
|
||||
| `run_model_comparison.sh` | Compare CLIP vs DINOv2 models | `model_comparison_results.txt` |
|
||||
| `run_comparison_tests.sh` | Compare matching methods | `test_results/comparison_*.txt` |
|
||||
| `run_threshold_tests.sh` | Test threshold/margin combinations | `test_results/threshold_*.txt` |
|
||||
| `run_model_comparison.sh` | Compare CLIP vs DINOv2 models | `test_results/model_comparison_results.txt` |
|
||||
| `run_refs_per_logo_test.sh` | Test refs-per-logo values | `test_results/refs_per_logo_analysis.txt` |
|
||||
| `run_preprocess_test.sh` | Compare preprocessing modes | `test_results/preprocessing_comparison.txt` |
|
||||
|
||||
## Project Structure
|
||||
|
||||
|
||||
@ -23,7 +23,6 @@ import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Dict, Optional, Any
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
|
||||
class DetectLogosDETR:
|
||||
@ -765,311 +764,4 @@ class DetectLogosDETR:
|
||||
f"(threshold: {similarity_threshold})"
|
||||
)
|
||||
|
||||
return matched_detections
|
||||
|
||||
# =========================================================================
|
||||
# Hybrid Text + CLIP Matching
|
||||
# =========================================================================
|
||||
|
||||
def set_text_detector(self, text_detector) -> None:
|
||||
"""
|
||||
Set an optional text detector for hybrid matching.
|
||||
|
||||
Args:
|
||||
text_detector: Instance of DetectText class from text_recognition.py
|
||||
"""
|
||||
self.text_detector = text_detector
|
||||
self.logger.info("Text detector enabled for hybrid matching")
|
||||
|
||||
def extract_text(self, image: np.ndarray, min_confidence: float = 0.3) -> List[str]:
|
||||
"""
|
||||
Extract text from an image using the text detector.
|
||||
|
||||
Args:
|
||||
image: OpenCV image (BGR format)
|
||||
min_confidence: Minimum OCR confidence to accept text
|
||||
|
||||
Returns:
|
||||
List of detected text strings (lowercased, stripped)
|
||||
"""
|
||||
if not hasattr(self, 'text_detector') or self.text_detector is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
results, _ = self.text_detector.detect(image)
|
||||
# Filter by confidence and normalize text
|
||||
texts = []
|
||||
for text, confidence in results:
|
||||
if confidence >= min_confidence:
|
||||
# Normalize: lowercase, strip whitespace, remove special chars
|
||||
normalized = text.lower().strip()
|
||||
if len(normalized) >= 2: # Ignore single characters
|
||||
texts.append(normalized)
|
||||
return texts
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Text extraction failed: {e}")
|
||||
return []
|
||||
|
||||
def extract_text_pil(self, pil_image: Image.Image, min_confidence: float = 0.3) -> List[str]:
|
||||
"""
|
||||
Extract text from a PIL image.
|
||||
|
||||
Args:
|
||||
pil_image: PIL Image (RGB format)
|
||||
min_confidence: Minimum OCR confidence
|
||||
|
||||
Returns:
|
||||
List of detected text strings
|
||||
"""
|
||||
# Convert PIL to OpenCV format
|
||||
cv_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
|
||||
return self.extract_text(cv_image, min_confidence)
|
||||
|
||||
@staticmethod
|
||||
def compute_text_similarity(text1_list: List[str], text2_list: List[str]) -> float:
|
||||
"""
|
||||
Compute fuzzy text similarity between two lists of text strings.
|
||||
|
||||
Uses a combination of exact matches and fuzzy matching to handle
|
||||
OCR variations like case differences, spacing, and minor errors.
|
||||
|
||||
Args:
|
||||
text1_list: List of text strings from first image
|
||||
text2_list: List of text strings from second image
|
||||
|
||||
Returns:
|
||||
Similarity score between 0 and 1
|
||||
"""
|
||||
if not text1_list or not text2_list:
|
||||
return 0.0
|
||||
|
||||
# Combine all text into single strings for overall comparison
|
||||
text1_combined = " ".join(sorted(text1_list))
|
||||
text2_combined = " ".join(sorted(text2_list))
|
||||
|
||||
# Method 1: Sequence matching on combined text
|
||||
seq_similarity = SequenceMatcher(None, text1_combined, text2_combined).ratio()
|
||||
|
||||
# Method 2: Token overlap (Jaccard-like)
|
||||
# Split into tokens
|
||||
tokens1 = set(text1_combined.split())
|
||||
tokens2 = set(text2_combined.split())
|
||||
|
||||
if tokens1 and tokens2:
|
||||
intersection = len(tokens1 & tokens2)
|
||||
union = len(tokens1 | tokens2)
|
||||
token_similarity = intersection / union if union > 0 else 0
|
||||
else:
|
||||
token_similarity = 0
|
||||
|
||||
# Method 3: Best pairwise match for each text in list1
|
||||
pairwise_scores = []
|
||||
for t1 in text1_list:
|
||||
best_match = 0
|
||||
for t2 in text2_list:
|
||||
score = SequenceMatcher(None, t1, t2).ratio()
|
||||
best_match = max(best_match, score)
|
||||
pairwise_scores.append(best_match)
|
||||
|
||||
pairwise_similarity = sum(pairwise_scores) / len(pairwise_scores) if pairwise_scores else 0
|
||||
|
||||
# Combine methods (weighted average)
|
||||
combined = (seq_similarity * 0.3 + token_similarity * 0.3 + pairwise_similarity * 0.4)
|
||||
|
||||
return combined
|
||||
|
||||
@staticmethod
|
||||
def texts_match(
|
||||
ref_texts: List[str],
|
||||
det_texts: List[str],
|
||||
threshold: float = 0.5
|
||||
) -> Tuple[bool, float]:
|
||||
"""
|
||||
Determine if texts match above a threshold.
|
||||
|
||||
Args:
|
||||
ref_texts: Text from reference logo
|
||||
det_texts: Text from detected region
|
||||
threshold: Minimum similarity to consider a match
|
||||
|
||||
Returns:
|
||||
Tuple of (is_match, similarity_score)
|
||||
"""
|
||||
if not ref_texts:
|
||||
# Reference has no text - can't match on text
|
||||
return (False, 0.0)
|
||||
|
||||
if not det_texts:
|
||||
# Reference has text but detection doesn't - no text match
|
||||
return (False, 0.0)
|
||||
|
||||
similarity = DetectLogosDETR.compute_text_similarity(ref_texts, det_texts)
|
||||
return (similarity >= threshold, similarity)
|
||||
|
||||
def find_best_match_hybrid(
|
||||
self,
|
||||
detected_embedding: torch.Tensor,
|
||||
detected_image: np.ndarray,
|
||||
reference_data: Dict[str, Dict[str, Any]],
|
||||
clip_threshold: float = 0.70,
|
||||
clip_threshold_with_text: float = 0.60,
|
||||
clip_threshold_text_mismatch: float = 0.80,
|
||||
text_similarity_threshold: float = 0.5,
|
||||
margin: float = 0.05,
|
||||
use_mean_similarity: bool = False,
|
||||
) -> Optional[Tuple[str, float, Dict[str, Any]]]:
|
||||
"""
|
||||
Find best match using hybrid text + CLIP approach.
|
||||
|
||||
Strategy:
|
||||
- If reference has text AND detection has matching text:
|
||||
→ Use lower CLIP threshold (text provides additional confidence)
|
||||
- If reference has text but detection doesn't match:
|
||||
→ Use higher CLIP threshold (need more visual confidence)
|
||||
- If reference has no text:
|
||||
→ Use standard CLIP threshold
|
||||
|
||||
Args:
|
||||
detected_embedding: CLIP embedding from detected logo region
|
||||
detected_image: OpenCV image of the detected region (for text extraction)
|
||||
reference_data: Dict mapping logo name to:
|
||||
{
|
||||
'embeddings': List[torch.Tensor], # CLIP embeddings
|
||||
'texts': List[str], # Extracted text from reference
|
||||
}
|
||||
clip_threshold: Standard CLIP threshold for no-text references
|
||||
clip_threshold_with_text: Lower threshold when text matches
|
||||
clip_threshold_text_mismatch: Higher threshold when text expected but missing
|
||||
text_similarity_threshold: Threshold for text matching
|
||||
margin: Required margin between best and second-best
|
||||
use_mean_similarity: Use mean vs max for multi-ref aggregation
|
||||
|
||||
Returns:
|
||||
Tuple of (label, clip_similarity, match_info) or None
|
||||
match_info contains: text_matched, text_similarity, threshold_used
|
||||
"""
|
||||
if not reference_data:
|
||||
return None
|
||||
|
||||
# Extract text from detected region
|
||||
detected_texts = self.extract_text(detected_image)
|
||||
|
||||
# Calculate scores for all logos
|
||||
logo_scores = []
|
||||
|
||||
for label, ref_info in reference_data.items():
|
||||
ref_embeddings = ref_info.get('embeddings', [])
|
||||
ref_texts = ref_info.get('texts', [])
|
||||
|
||||
if not ref_embeddings:
|
||||
continue
|
||||
|
||||
# Calculate CLIP similarity
|
||||
similarities = []
|
||||
for ref_emb in ref_embeddings:
|
||||
sim = self.compare_embeddings(detected_embedding, ref_emb)
|
||||
similarities.append(sim)
|
||||
|
||||
if use_mean_similarity:
|
||||
clip_score = sum(similarities) / len(similarities)
|
||||
else:
|
||||
clip_score = max(similarities)
|
||||
|
||||
# Determine text match status and appropriate threshold
|
||||
has_ref_text = len(ref_texts) > 0
|
||||
text_matched, text_sim = self.texts_match(
|
||||
ref_texts, detected_texts, text_similarity_threshold
|
||||
)
|
||||
|
||||
if has_ref_text:
|
||||
if text_matched:
|
||||
# Text matches - use lower threshold, boost confidence
|
||||
threshold_used = clip_threshold_with_text
|
||||
match_type = "text_match"
|
||||
else:
|
||||
# Reference has text but detection doesn't match
|
||||
# Require higher CLIP threshold
|
||||
threshold_used = clip_threshold_text_mismatch
|
||||
match_type = "text_mismatch"
|
||||
else:
|
||||
# No text in reference - standard matching
|
||||
threshold_used = clip_threshold
|
||||
match_type = "no_text"
|
||||
text_sim = 0.0
|
||||
|
||||
# Check if CLIP score meets the appropriate threshold
|
||||
if clip_score >= threshold_used:
|
||||
logo_scores.append({
|
||||
'label': label,
|
||||
'clip_score': clip_score,
|
||||
'text_matched': text_matched,
|
||||
'text_similarity': text_sim,
|
||||
'threshold_used': threshold_used,
|
||||
'match_type': match_type,
|
||||
'has_ref_text': has_ref_text,
|
||||
})
|
||||
|
||||
if not logo_scores:
|
||||
return None
|
||||
|
||||
# Sort by CLIP score descending
|
||||
logo_scores.sort(key=lambda x: x['clip_score'], reverse=True)
|
||||
|
||||
best = logo_scores[0]
|
||||
|
||||
# Check margin against second-best
|
||||
if margin > 0 and len(logo_scores) > 1:
|
||||
second_best_score = logo_scores[1]['clip_score']
|
||||
if best['clip_score'] - second_best_score < margin:
|
||||
return None
|
||||
|
||||
match_info = {
|
||||
'text_matched': best['text_matched'],
|
||||
'text_similarity': best['text_similarity'],
|
||||
'threshold_used': best['threshold_used'],
|
||||
'match_type': best['match_type'],
|
||||
'has_ref_text': best['has_ref_text'],
|
||||
'detected_texts': detected_texts,
|
||||
}
|
||||
|
||||
return (best['label'], best['clip_score'], match_info)
|
||||
|
||||
def prepare_reference_data_hybrid(
|
||||
self,
|
||||
reference_images: Dict[str, List[np.ndarray]],
|
||||
text_min_confidence: float = 0.3,
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Prepare reference data for hybrid matching by computing embeddings and extracting text.
|
||||
|
||||
Args:
|
||||
reference_images: Dict mapping logo name to list of reference images (OpenCV BGR)
|
||||
text_min_confidence: Minimum confidence for text extraction
|
||||
|
||||
Returns:
|
||||
Dict mapping logo name to {'embeddings': [...], 'texts': [...]}
|
||||
"""
|
||||
reference_data = {}
|
||||
|
||||
for logo_name, images in reference_images.items():
|
||||
embeddings = []
|
||||
all_texts = set()
|
||||
|
||||
for img in images:
|
||||
# Compute CLIP embedding
|
||||
emb = self.get_embedding(img)
|
||||
embeddings.append(emb)
|
||||
|
||||
# Extract text
|
||||
texts = self.extract_text(img, text_min_confidence)
|
||||
all_texts.update(texts)
|
||||
|
||||
reference_data[logo_name] = {
|
||||
'embeddings': embeddings,
|
||||
'texts': list(all_texts),
|
||||
}
|
||||
|
||||
if all_texts:
|
||||
self.logger.debug(f"Reference '{logo_name}' has text: {all_texts}")
|
||||
|
||||
return reference_data
|
||||
return matched_detections
|
||||
364
logo_detection_embeddings.py
Normal file
364
logo_detection_embeddings.py
Normal file
@ -0,0 +1,364 @@
|
||||
"""
|
||||
Logo detection using DETR for object detection and selectable embedding models for feature matching.
|
||||
|
||||
This module provides a class for detecting logos in images using:
|
||||
1. DETR (DEtection TRansformer) for initial logo region detection
|
||||
2. Selectable embedding model (CLIP, DINOv2, or SigLIP) for feature extraction and matching
|
||||
|
||||
Key features:
|
||||
- Multiple reference images per logo entry, averaged into a single embedding
|
||||
- Cache-aware: averaged embeddings are only recalculated when the filenames list changes
|
||||
- Supports local model directories with fallback to HuggingFace
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from transformers import (
|
||||
AutoImageProcessor,
|
||||
AutoModel,
|
||||
AutoProcessor,
|
||||
CLIPModel,
|
||||
CLIPProcessor,
|
||||
Dinov2Model,
|
||||
pipeline,
|
||||
)
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
class DetectLogosEmbeddings:
|
||||
"""
|
||||
Logo detection class using DETR and a selectable embedding model.
|
||||
|
||||
This class detects logos in images by:
|
||||
1. Using DETR to find potential logo regions (bounding boxes)
|
||||
2. Extracting embeddings for each detected region using the selected model
|
||||
3. Comparing embeddings with averaged reference logo embeddings for identification
|
||||
|
||||
Supported embedding models:
|
||||
- clip: openai/clip-vit-large-patch14
|
||||
- dinov2: facebook/dinov2-base (recommended for visual similarity)
|
||||
- siglip: google/siglip-base-patch16-224
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
logger,
|
||||
detr_model: str = "Pravallika6/detr-finetuned-logo-detection_v2",
|
||||
embedding_model_type: str = "dinov2",
|
||||
detr_threshold: float = 0.5,
|
||||
):
|
||||
"""
|
||||
Initialize DETR and embedding models.
|
||||
|
||||
Args:
|
||||
logger: Logger instance for logging
|
||||
detr_model: HuggingFace model name or local path for DETR object detection
|
||||
embedding_model_type: One of "clip", "dinov2", or "siglip"
|
||||
detr_threshold: Confidence threshold for DETR detections (0-1)
|
||||
"""
|
||||
self.logger = logger
|
||||
self.detr_threshold = detr_threshold
|
||||
self.embedding_model_type = embedding_model_type
|
||||
|
||||
# Set device
|
||||
self.device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
self.device_index = 0 if torch.cuda.is_available() else -1
|
||||
self.device = torch.device(self.device_str)
|
||||
|
||||
self.logger.info(
|
||||
f"Initializing DetectLogosEmbeddings on device: {self.device_str}, "
|
||||
f"embedding model: {embedding_model_type}"
|
||||
)
|
||||
|
||||
# --- DETR model ---
|
||||
default_detr_dir = os.environ.get(
|
||||
"LOGO_DETR_MODEL_DIR", "models/logo_detection/detr"
|
||||
)
|
||||
detr_model_path = self._resolve_model_path(detr_model, default_detr_dir, "DETR")
|
||||
|
||||
self.logger.info(f"Loading DETR model: {detr_model_path}")
|
||||
self.detr_pipe = pipeline(
|
||||
task="object-detection",
|
||||
model=detr_model_path,
|
||||
device=self.device_index,
|
||||
use_fast=True,
|
||||
)
|
||||
|
||||
# --- Embedding model ---
|
||||
self._load_embedding_model(embedding_model_type)
|
||||
|
||||
self.logger.info("DetectLogosEmbeddings initialization complete")
|
||||
|
||||
def _load_embedding_model(self, model_type: str) -> None:
|
||||
"""
|
||||
Load the selected embedding model.
|
||||
|
||||
Args:
|
||||
model_type: One of "clip", "dinov2", or "siglip"
|
||||
"""
|
||||
default_embedding_dir = os.environ.get(
|
||||
"LOGO_EMBEDDING_MODEL_DIR", f"models/logo_detection/{model_type}"
|
||||
)
|
||||
|
||||
if model_type == "clip":
|
||||
model_name = "openai/clip-vit-large-patch14"
|
||||
model_path = self._resolve_model_path(
|
||||
model_name, default_embedding_dir, "CLIP"
|
||||
)
|
||||
self.logger.info(f"Loading CLIP model: {model_path}")
|
||||
self._clip_model = CLIPModel.from_pretrained(model_path).to(self.device)
|
||||
self._clip_processor = CLIPProcessor.from_pretrained(model_path)
|
||||
self._clip_model.eval()
|
||||
|
||||
def embed_fn(pil_image):
|
||||
inputs = self._clip_processor(
|
||||
images=pil_image, return_tensors="pt"
|
||||
).to(self.device)
|
||||
with torch.no_grad():
|
||||
features = self._clip_model.get_image_features(**inputs)
|
||||
return F.normalize(features, dim=-1)
|
||||
|
||||
elif model_type == "dinov2":
|
||||
model_name = "facebook/dinov2-base"
|
||||
model_path = self._resolve_model_path(
|
||||
model_name, default_embedding_dir, "DINOv2"
|
||||
)
|
||||
self.logger.info(f"Loading DINOv2 model: {model_path}")
|
||||
self._dinov2_model = Dinov2Model.from_pretrained(model_path).to(self.device)
|
||||
self._dinov2_processor = AutoImageProcessor.from_pretrained(model_path)
|
||||
self._dinov2_model.eval()
|
||||
|
||||
def embed_fn(pil_image):
|
||||
inputs = self._dinov2_processor(
|
||||
images=pil_image, return_tensors="pt"
|
||||
).to(self.device)
|
||||
with torch.no_grad():
|
||||
outputs = self._dinov2_model(**inputs)
|
||||
# Use CLS token embedding
|
||||
features = outputs.last_hidden_state[:, 0, :]
|
||||
return F.normalize(features, dim=-1)
|
||||
|
||||
elif model_type == "siglip":
|
||||
model_name = "google/siglip-base-patch16-224"
|
||||
model_path = self._resolve_model_path(
|
||||
model_name, default_embedding_dir, "SigLIP"
|
||||
)
|
||||
self.logger.info(f"Loading SigLIP model: {model_path}")
|
||||
self._siglip_model = AutoModel.from_pretrained(model_path).to(self.device)
|
||||
self._siglip_processor = AutoProcessor.from_pretrained(model_path)
|
||||
self._siglip_model.eval()
|
||||
|
||||
def embed_fn(pil_image):
|
||||
inputs = self._siglip_processor(
|
||||
images=pil_image, return_tensors="pt"
|
||||
).to(self.device)
|
||||
with torch.no_grad():
|
||||
features = self._siglip_model.get_image_features(**inputs)
|
||||
return F.normalize(features, dim=-1)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown embedding model type: {model_type}. "
|
||||
f"Use 'clip', 'dinov2', or 'siglip'"
|
||||
)
|
||||
|
||||
self._embed_fn = embed_fn
|
||||
|
||||
def _resolve_model_path(
|
||||
self, model_name_or_path: str, default_local_dir: str, model_type: str
|
||||
) -> str:
|
||||
"""
|
||||
Resolve model path, checking for local models before using HuggingFace.
|
||||
|
||||
Args:
|
||||
model_name_or_path: HuggingFace model name or absolute path
|
||||
default_local_dir: Default local directory to check
|
||||
model_type: Type of model (for logging)
|
||||
|
||||
Returns:
|
||||
Resolved model path (local path or HuggingFace model name)
|
||||
"""
|
||||
# If it's an absolute path, use it directly
|
||||
if os.path.isabs(model_name_or_path):
|
||||
if os.path.exists(model_name_or_path):
|
||||
self.logger.info(
|
||||
f"{model_type} model: Using local model at {model_name_or_path}"
|
||||
)
|
||||
return model_name_or_path
|
||||
else:
|
||||
self.logger.warning(
|
||||
f"{model_type} model: Local path {model_name_or_path} does not exist, "
|
||||
f"falling back to HuggingFace"
|
||||
)
|
||||
return model_name_or_path
|
||||
|
||||
# Check if default local directory exists
|
||||
if os.path.exists(default_local_dir):
|
||||
config_file = os.path.join(default_local_dir, "config.json")
|
||||
if os.path.exists(config_file):
|
||||
abs_path = os.path.abspath(default_local_dir)
|
||||
self.logger.info(
|
||||
f"{model_type} model: Found local model at {abs_path}"
|
||||
)
|
||||
return abs_path
|
||||
else:
|
||||
self.logger.warning(
|
||||
f"{model_type} model: Local directory {default_local_dir} exists but "
|
||||
f"is not a valid model (missing config.json)"
|
||||
)
|
||||
|
||||
# Use HuggingFace model name
|
||||
self.logger.info(
|
||||
f"{model_type} model: No local model found, will download from HuggingFace: "
|
||||
f"{model_name_or_path}"
|
||||
)
|
||||
return model_name_or_path
|
||||
|
||||
def detect(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Detect logos in an image and return bounding boxes with embeddings.
|
||||
|
||||
Args:
|
||||
image: OpenCV image (BGR format, numpy array)
|
||||
|
||||
Returns:
|
||||
List of dictionaries, each containing:
|
||||
- 'box': dict with 'xmin', 'ymin', 'xmax', 'ymax' (pixel coordinates)
|
||||
- 'score': DETR confidence score (float 0-1)
|
||||
- 'embedding': Feature embedding (torch.Tensor)
|
||||
- 'label': DETR predicted label (string)
|
||||
"""
|
||||
# Convert OpenCV BGR to RGB PIL Image
|
||||
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
pil_image = Image.fromarray(image_rgb)
|
||||
|
||||
# Run DETR detection
|
||||
predictions = self.detr_pipe(pil_image)
|
||||
|
||||
# Filter by threshold and add embeddings
|
||||
detections = []
|
||||
for pred in predictions:
|
||||
score = pred.get("score", 0.0)
|
||||
if score < self.detr_threshold:
|
||||
continue
|
||||
|
||||
box = pred.get("box", {})
|
||||
xmin = box.get("xmin", 0)
|
||||
ymin = box.get("ymin", 0)
|
||||
xmax = box.get("xmax", 0)
|
||||
ymax = box.get("ymax", 0)
|
||||
|
||||
# Extract bounding box region
|
||||
bbox_crop = pil_image.crop((xmin, ymin, xmax, ymax))
|
||||
|
||||
# Get embedding for this region
|
||||
embedding = self._embed_fn(bbox_crop)
|
||||
|
||||
detections.append(
|
||||
{
|
||||
"box": {"xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax},
|
||||
"score": score,
|
||||
"embedding": embedding,
|
||||
"label": pred.get("label", "logo"),
|
||||
}
|
||||
)
|
||||
|
||||
self.logger.debug(
|
||||
f"Detected {len(detections)} logos (threshold: {self.detr_threshold})"
|
||||
)
|
||||
return detections
|
||||
|
||||
def get_embedding(self, image: np.ndarray) -> torch.Tensor:
|
||||
"""
|
||||
Get embedding for a single reference logo image.
|
||||
|
||||
Args:
|
||||
image: OpenCV image (BGR format, numpy array)
|
||||
|
||||
Returns:
|
||||
Normalized feature embedding (torch.Tensor)
|
||||
"""
|
||||
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
pil_image = Image.fromarray(image_rgb)
|
||||
return self._embed_fn(pil_image)
|
||||
|
||||
def get_averaged_embedding(self, images: List[np.ndarray]) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Compute averaged embedding from multiple reference logo images.
|
||||
|
||||
Follows the averaging pattern from db_embeddings.py:
|
||||
1. Compute embedding for each image
|
||||
2. Stack and average across all images
|
||||
3. Re-normalize the averaged embedding
|
||||
|
||||
Args:
|
||||
images: List of OpenCV images (BGR format, numpy arrays)
|
||||
|
||||
Returns:
|
||||
Normalized averaged embedding (torch.Tensor, shape [1, D]),
|
||||
or None if no valid embeddings could be computed
|
||||
"""
|
||||
embeddings = []
|
||||
for img in images:
|
||||
try:
|
||||
emb = self.get_embedding(img)
|
||||
embeddings.append(emb)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to compute embedding for reference image: {e}")
|
||||
|
||||
if not embeddings:
|
||||
return None
|
||||
|
||||
# Stack: (N, D), average: (1, D), re-normalize
|
||||
stacked = torch.cat(embeddings, dim=0)
|
||||
avg_emb = stacked.mean(dim=0, keepdim=True)
|
||||
avg_emb = F.normalize(avg_emb, dim=-1)
|
||||
|
||||
self.logger.debug(
|
||||
f"Computed averaged embedding from {len(embeddings)} reference image(s)"
|
||||
)
|
||||
return avg_emb
|
||||
|
||||
def compare_embeddings(
|
||||
self, embedding1: torch.Tensor, embedding2: torch.Tensor
|
||||
) -> float:
|
||||
"""
|
||||
Compute cosine similarity between two embeddings.
|
||||
|
||||
Args:
|
||||
embedding1: First embedding (torch.Tensor)
|
||||
embedding2: Second embedding (torch.Tensor)
|
||||
|
||||
Returns:
|
||||
Cosine similarity score (float, range: -1 to 1, typically 0 to 1)
|
||||
"""
|
||||
# Ensure tensors are on the same device
|
||||
if embedding1.device != embedding2.device:
|
||||
embedding2 = embedding2.to(embedding1.device)
|
||||
|
||||
similarity = F.cosine_similarity(embedding1, embedding2, dim=-1)
|
||||
return similarity.item()
|
||||
|
||||
@staticmethod
|
||||
def make_filenames_hash(filenames: List[str]) -> str:
|
||||
"""
|
||||
Compute a deterministic hash of a filenames list.
|
||||
|
||||
Used for cache invalidation — if the filenames list changes,
|
||||
the hash changes, triggering re-computation of averaged embeddings.
|
||||
|
||||
Args:
|
||||
filenames: List of filename strings
|
||||
|
||||
Returns:
|
||||
16-character hex hash string
|
||||
"""
|
||||
canonical = json.dumps(sorted(filenames))
|
||||
return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16]
|
||||
52
results_average_embeddings.txt
Normal file
52
results_average_embeddings.txt
Normal file
@ -0,0 +1,52 @@
|
||||
======================================================================
|
||||
BURNLEY LOGO DETECTION TEST
|
||||
Model: dinov2
|
||||
Method: Margin-based (margin=0.05)
|
||||
======================================================================
|
||||
Date: 2026-03-31 11:45:03
|
||||
|
||||
Configuration:
|
||||
Embedding model: dinov2
|
||||
Similarity threshold: 0.7
|
||||
DETR threshold: 0.5
|
||||
Matching margin: 0.05
|
||||
Test images processed: 516
|
||||
Reference logos: barnfield, vertu
|
||||
|
||||
Results:
|
||||
True Positives: 28
|
||||
False Positives: 36
|
||||
False Negatives: 125
|
||||
Total Expected: 146
|
||||
|
||||
Scores:
|
||||
Precision: 0.4375 (43.8%)
|
||||
Recall: 0.1918 (19.2%)
|
||||
F1 Score: 0.2667 (26.7%)
|
||||
|
||||
======================================================================
|
||||
BURNLEY LOGO DETECTION TEST
|
||||
Model: dinov2
|
||||
Method: Margin-based (margin=0.05)
|
||||
======================================================================
|
||||
Date: 2026-03-31 12:29:32
|
||||
|
||||
Configuration:
|
||||
Embedding model: dinov2
|
||||
Similarity threshold: 0.7
|
||||
DETR threshold: 0.5
|
||||
Matching margin: 0.05
|
||||
Test images processed: 516
|
||||
Reference logos: barnfield, vertu
|
||||
|
||||
Results:
|
||||
True Positives: 28
|
||||
False Positives: 36
|
||||
False Negatives: 125
|
||||
Total Expected: 146
|
||||
|
||||
Scores:
|
||||
Precision: 0.4375 (43.8%)
|
||||
Recall: 0.1918 (19.2%)
|
||||
F1 Score: 0.2667 (26.7%)
|
||||
|
||||
@ -1,168 +0,0 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Test the hybrid text+CLIP matching approach for logo detection.
|
||||
#
|
||||
# This approach uses text recognition to improve logo matching:
|
||||
# - If reference logo has text and detection matches it: use lower CLIP threshold
|
||||
# - If reference logo has text but detection doesn't match: use higher CLIP threshold
|
||||
# - If reference logo has no text: use standard CLIP threshold
|
||||
#
|
||||
# Usage:
|
||||
# ./run_hybrid_test.sh
|
||||
#
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
OUTPUT_FILE="${SCRIPT_DIR}/test_results/hybrid_matching_results.txt"
|
||||
|
||||
# Model - baseline CLIP
|
||||
MODEL="openai/clip-vit-large-patch14"
|
||||
|
||||
# Fixed parameters
|
||||
NUM_LOGOS=20
|
||||
REFS_PER_LOGO=10
|
||||
POSITIVE_SAMPLES=20
|
||||
NEGATIVE_SAMPLES=100
|
||||
SEED=42
|
||||
|
||||
# Create output directory if needed
|
||||
mkdir -p "${SCRIPT_DIR}/test_results"
|
||||
|
||||
# Clear output file and write header
|
||||
cat > "$OUTPUT_FILE" << EOF
|
||||
Hybrid Text+CLIP Matching Test Results
|
||||
======================================
|
||||
Date: $(date)
|
||||
|
||||
Model: ${MODEL}
|
||||
|
||||
Fixed Parameters:
|
||||
Number of logo brands: ${NUM_LOGOS}
|
||||
Refs per logo: ${REFS_PER_LOGO}
|
||||
Positive samples/logo: ${POSITIVE_SAMPLES}
|
||||
Negative samples/logo: ${NEGATIVE_SAMPLES}
|
||||
Seed: ${SEED}
|
||||
|
||||
EOF
|
||||
|
||||
echo "Hybrid Text+CLIP Matching Test"
|
||||
echo "==============================="
|
||||
echo "Model: ${MODEL}"
|
||||
echo ""
|
||||
|
||||
# Test 1: Compare hybrid vs multi-ref baseline
|
||||
echo "=== Test 1: Multi-ref baseline (for comparison) ==="
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
echo "=== BASELINE: Multi-ref (max) at threshold 0.70 ===" >> "$OUTPUT_FILE"
|
||||
|
||||
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
||||
--num-logos $NUM_LOGOS \
|
||||
--refs-per-logo $REFS_PER_LOGO \
|
||||
--positive-samples $POSITIVE_SAMPLES \
|
||||
--negative-samples $NEGATIVE_SAMPLES \
|
||||
--matching-method multi-ref \
|
||||
--min-matching-refs 1 \
|
||||
--use-max-similarity \
|
||||
--threshold 0.70 \
|
||||
--margin 0.05 \
|
||||
--seed $SEED \
|
||||
--embedding-model "$MODEL" \
|
||||
--output-file "$OUTPUT_FILE" \
|
||||
--no-cache
|
||||
|
||||
echo ""
|
||||
|
||||
# Test 2: Hybrid with default thresholds
|
||||
echo "=== Test 2: Hybrid with default thresholds ==="
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
echo "=== HYBRID: default thresholds (0.70/0.60/0.80) ===" >> "$OUTPUT_FILE"
|
||||
|
||||
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
||||
--num-logos $NUM_LOGOS \
|
||||
--refs-per-logo $REFS_PER_LOGO \
|
||||
--positive-samples $POSITIVE_SAMPLES \
|
||||
--negative-samples $NEGATIVE_SAMPLES \
|
||||
--matching-method hybrid \
|
||||
--threshold 0.70 \
|
||||
--hybrid-text-threshold 0.60 \
|
||||
--hybrid-no-text-threshold 0.80 \
|
||||
--text-similarity-threshold 0.5 \
|
||||
--margin 0.05 \
|
||||
--seed $SEED \
|
||||
--embedding-model "$MODEL" \
|
||||
--output-file "$OUTPUT_FILE" \
|
||||
--no-cache
|
||||
|
||||
echo ""
|
||||
|
||||
# Test 3: Hybrid with more aggressive text bonus
|
||||
echo "=== Test 3: Hybrid with lower text-match threshold ==="
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
echo "=== HYBRID: aggressive text bonus (0.70/0.55/0.80) ===" >> "$OUTPUT_FILE"
|
||||
|
||||
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
||||
--num-logos $NUM_LOGOS \
|
||||
--refs-per-logo $REFS_PER_LOGO \
|
||||
--positive-samples $POSITIVE_SAMPLES \
|
||||
--negative-samples $NEGATIVE_SAMPLES \
|
||||
--matching-method hybrid \
|
||||
--threshold 0.70 \
|
||||
--hybrid-text-threshold 0.55 \
|
||||
--hybrid-no-text-threshold 0.80 \
|
||||
--text-similarity-threshold 0.5 \
|
||||
--margin 0.05 \
|
||||
--seed $SEED \
|
||||
--embedding-model "$MODEL" \
|
||||
--output-file "$OUTPUT_FILE" \
|
||||
--no-cache
|
||||
|
||||
echo ""
|
||||
|
||||
# Test 4: Hybrid with stricter text mismatch penalty
|
||||
echo "=== Test 4: Hybrid with stricter text mismatch penalty ==="
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
echo "=== HYBRID: strict mismatch (0.70/0.60/0.85) ===" >> "$OUTPUT_FILE"
|
||||
|
||||
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
||||
--num-logos $NUM_LOGOS \
|
||||
--refs-per-logo $REFS_PER_LOGO \
|
||||
--positive-samples $POSITIVE_SAMPLES \
|
||||
--negative-samples $NEGATIVE_SAMPLES \
|
||||
--matching-method hybrid \
|
||||
--threshold 0.70 \
|
||||
--hybrid-text-threshold 0.60 \
|
||||
--hybrid-no-text-threshold 0.85 \
|
||||
--text-similarity-threshold 0.5 \
|
||||
--margin 0.05 \
|
||||
--seed $SEED \
|
||||
--embedding-model "$MODEL" \
|
||||
--output-file "$OUTPUT_FILE" \
|
||||
--no-cache
|
||||
|
||||
echo ""
|
||||
|
||||
# Test 5: Hybrid with lower text similarity threshold (more lenient OCR matching)
|
||||
echo "=== Test 5: Hybrid with lenient text matching ==="
|
||||
echo "" >> "$OUTPUT_FILE"
|
||||
echo "=== HYBRID: lenient text matching (text_sim=0.4) ===" >> "$OUTPUT_FILE"
|
||||
|
||||
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
||||
--num-logos $NUM_LOGOS \
|
||||
--refs-per-logo $REFS_PER_LOGO \
|
||||
--positive-samples $POSITIVE_SAMPLES \
|
||||
--negative-samples $NEGATIVE_SAMPLES \
|
||||
--matching-method hybrid \
|
||||
--threshold 0.70 \
|
||||
--hybrid-text-threshold 0.60 \
|
||||
--hybrid-no-text-threshold 0.80 \
|
||||
--text-similarity-threshold 0.4 \
|
||||
--margin 0.05 \
|
||||
--seed $SEED \
|
||||
--embedding-model "$MODEL" \
|
||||
--output-file "$OUTPUT_FILE" \
|
||||
--no-cache
|
||||
|
||||
echo ""
|
||||
echo "======================================="
|
||||
echo "Tests complete!"
|
||||
echo "Results saved to: $OUTPUT_FILE"
|
||||
echo "======================================="
|
||||
521
test_burnley_detection.py
Normal file
521
test_burnley_detection.py
Normal file
@ -0,0 +1,521 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for logo detection accuracy on Burnley test images.
|
||||
|
||||
Uses DetectLogosEmbeddings from logo_detection_embeddings.py to detect
|
||||
barnfield and vertu logos. Ground truth is determined by filename prefix:
|
||||
- "vertu_" → contains vertu logo
|
||||
- "barnfield_" → contains barnfield logo
|
||||
- "barnfield+vertu_" → contains both logos
|
||||
- anything else → no target logos
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import pickle
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from logo_detection_embeddings import DetectLogosEmbeddings
|
||||
|
||||
|
||||
def setup_logging(verbose: bool = False) -> logging.Logger:
|
||||
"""Configure logging."""
|
||||
level = logging.DEBUG if verbose else logging.INFO
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_image(image_path: Path) -> Optional[cv2.Mat]:
|
||||
"""Load an image using OpenCV."""
|
||||
img = cv2.imread(str(image_path))
|
||||
if img is None:
|
||||
return None
|
||||
return img
|
||||
|
||||
|
||||
class EmbeddingCache:
|
||||
"""Simple file-based cache for embeddings."""
|
||||
|
||||
def __init__(self, cache_path: Path):
|
||||
self.cache_path = cache_path
|
||||
self.cache: Dict[str, Any] = {}
|
||||
self._load()
|
||||
|
||||
def _load(self):
|
||||
if self.cache_path.exists():
|
||||
try:
|
||||
with open(self.cache_path, "rb") as f:
|
||||
self.cache = pickle.load(f)
|
||||
except Exception:
|
||||
self.cache = {}
|
||||
|
||||
def save(self):
|
||||
self.cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(self.cache_path, "wb") as f:
|
||||
pickle.dump(self.cache, f)
|
||||
|
||||
def get(self, key: str):
|
||||
return self.cache.get(key)
|
||||
|
||||
def put(self, key: str, value):
|
||||
if isinstance(value, torch.Tensor):
|
||||
self.cache[key] = value.cpu()
|
||||
else:
|
||||
self.cache[key] = value
|
||||
|
||||
def __len__(self):
|
||||
return len(self.cache)
|
||||
|
||||
|
||||
def get_expected_logos(filename: str) -> Set[str]:
|
||||
"""Determine expected logos from filename prefix."""
|
||||
name = filename.lower()
|
||||
if name.startswith("barnfield+vertu_"):
|
||||
return {"barnfield", "vertu"}
|
||||
elif name.startswith("barnfield_"):
|
||||
return {"barnfield"}
|
||||
elif name.startswith("vertu_"):
|
||||
return {"vertu"}
|
||||
return set()
|
||||
|
||||
|
||||
def load_reference_images(ref_dir: Path, logger: logging.Logger) -> List[cv2.Mat]:
|
||||
"""Load all images from a reference directory."""
|
||||
images = []
|
||||
for path in sorted(ref_dir.iterdir()):
|
||||
if path.suffix.lower() in (".jpg", ".jpeg", ".png", ".bmp"):
|
||||
img = load_image(path)
|
||||
if img is not None:
|
||||
images.append(img)
|
||||
else:
|
||||
logger.warning(f"Failed to load reference image: {path}")
|
||||
return images
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Test logo detection on Burnley test images using DetectLogosEmbeddings"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t", "--threshold",
|
||||
type=float,
|
||||
default=0.7,
|
||||
help="Similarity threshold for matching (default: 0.7)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d", "--detr-threshold",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="DETR detection confidence threshold (default: 0.5)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e", "--embedding-model",
|
||||
type=str,
|
||||
choices=["clip", "dinov2", "siglip"],
|
||||
default="dinov2",
|
||||
help="Embedding model type (default: dinov2)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--margin",
|
||||
type=float,
|
||||
default=0.05,
|
||||
help="Required margin between best and second-best match (default: 0.05)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v", "--verbose",
|
||||
action="store_true",
|
||||
help="Enable verbose logging",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--similarity-details",
|
||||
action="store_true",
|
||||
help="Output detailed similarity scores for each detection",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-cache",
|
||||
action="store_true",
|
||||
help="Disable embedding cache",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clear-cache",
|
||||
action="store_true",
|
||||
help="Clear embedding cache before running",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Append results summary to this file",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
logger = setup_logging(args.verbose)
|
||||
|
||||
# Paths
|
||||
base_dir = Path(__file__).resolve().parent
|
||||
test_images_dir = base_dir / "burnley_test_images"
|
||||
barnfield_ref_dir = base_dir / "barnfield_reference_images"
|
||||
vertu_ref_dir = base_dir / "vertu_reference_images"
|
||||
cache_path = base_dir / ".burnley_embedding_cache.pkl"
|
||||
|
||||
# Verify directories exist
|
||||
for d, name in [(test_images_dir, "Test images"), (barnfield_ref_dir, "Barnfield refs"), (vertu_ref_dir, "Vertu refs")]:
|
||||
if not d.exists():
|
||||
logger.error(f"{name} directory not found: {d}")
|
||||
sys.exit(1)
|
||||
|
||||
# Handle cache
|
||||
if args.clear_cache and cache_path.exists():
|
||||
cache_path.unlink()
|
||||
logger.info("Cleared embedding cache")
|
||||
|
||||
cache = EmbeddingCache(cache_path) if not args.no_cache else None
|
||||
if cache:
|
||||
logger.info(f"Loaded {len(cache)} cached embeddings")
|
||||
|
||||
# Initialize detector
|
||||
logger.info(f"Initializing detector with embedding model: {args.embedding_model}")
|
||||
detector = DetectLogosEmbeddings(
|
||||
logger=logger,
|
||||
detr_threshold=args.detr_threshold,
|
||||
embedding_model_type=args.embedding_model,
|
||||
)
|
||||
|
||||
# Compute averaged reference embeddings
|
||||
logger.info("Computing reference embeddings...")
|
||||
|
||||
reference_embeddings: Dict[str, torch.Tensor] = {}
|
||||
for logo_name, ref_dir in [("barnfield", barnfield_ref_dir), ("vertu", vertu_ref_dir)]:
|
||||
cache_key = f"avg_ref:{logo_name}:{args.embedding_model}"
|
||||
cached = cache.get(cache_key) if cache else None
|
||||
|
||||
if cached is not None:
|
||||
reference_embeddings[logo_name] = cached
|
||||
logger.info(f"Loaded cached averaged embedding for {logo_name}")
|
||||
else:
|
||||
ref_images = load_reference_images(ref_dir, logger)
|
||||
logger.info(f"Computing averaged embedding for {logo_name} from {len(ref_images)} images")
|
||||
avg_emb = detector.get_averaged_embedding(ref_images)
|
||||
if avg_emb is None:
|
||||
logger.error(f"Failed to compute embedding for {logo_name}")
|
||||
sys.exit(1)
|
||||
reference_embeddings[logo_name] = avg_emb
|
||||
if cache:
|
||||
cache.put(cache_key, avg_emb)
|
||||
|
||||
# Collect test images
|
||||
test_files = sorted([
|
||||
f.name for f in test_images_dir.iterdir()
|
||||
if f.suffix.lower() in (".jpg", ".jpeg", ".png", ".bmp")
|
||||
])
|
||||
logger.info(f"Found {len(test_files)} test images")
|
||||
|
||||
# Metrics
|
||||
true_positives = 0
|
||||
false_positives = 0
|
||||
false_negatives = 0
|
||||
total_expected = 0
|
||||
results = []
|
||||
|
||||
similarity_details = {
|
||||
"true_positive_sims": [],
|
||||
"false_positive_sims": [],
|
||||
"missed_best_sims": [],
|
||||
"detection_details": [],
|
||||
}
|
||||
|
||||
# Process test images
|
||||
for test_filename in tqdm(test_files, desc="Testing"):
|
||||
test_path = test_images_dir / test_filename
|
||||
expected_logos = get_expected_logos(test_filename)
|
||||
total_expected += len(expected_logos)
|
||||
|
||||
# Check cache for detections
|
||||
det_cache_key = f"det:{test_filename}:{args.embedding_model}"
|
||||
cached_detections = cache.get(det_cache_key) if cache else None
|
||||
|
||||
if cached_detections is not None:
|
||||
detections = cached_detections
|
||||
else:
|
||||
test_img = load_image(test_path)
|
||||
if test_img is None:
|
||||
logger.warning(f"Failed to load test image: {test_path}")
|
||||
continue
|
||||
detections = detector.detect(test_img)
|
||||
if cache:
|
||||
cache.put(det_cache_key, detections)
|
||||
|
||||
# Match each detection against reference embeddings with margin
|
||||
matched_logos: Set[str] = set()
|
||||
for det_idx, detection in enumerate(detections):
|
||||
# Compute similarity to each reference logo
|
||||
sims: Dict[str, float] = {}
|
||||
for logo_name, ref_emb in reference_embeddings.items():
|
||||
sims[logo_name] = detector.compare_embeddings(
|
||||
detection["embedding"], ref_emb
|
||||
)
|
||||
|
||||
sorted_sims = sorted(sims.items(), key=lambda x: -x[1])
|
||||
|
||||
if args.similarity_details:
|
||||
similarity_details["detection_details"].append({
|
||||
"image": test_filename,
|
||||
"detection_idx": det_idx,
|
||||
"expected_logos": list(expected_logos),
|
||||
"similarities": sorted_sims,
|
||||
"detr_score": detection.get("score", 0),
|
||||
})
|
||||
|
||||
# Best match with margin check
|
||||
if not sorted_sims:
|
||||
continue
|
||||
|
||||
best_name, best_sim = sorted_sims[0]
|
||||
if best_sim < args.threshold:
|
||||
continue
|
||||
|
||||
# Check margin over second best
|
||||
if len(sorted_sims) > 1:
|
||||
second_sim = sorted_sims[1][1]
|
||||
if best_sim - second_sim < args.margin:
|
||||
continue
|
||||
|
||||
matched_logos.add(best_name)
|
||||
is_correct = best_name in expected_logos
|
||||
|
||||
if is_correct:
|
||||
true_positives += 1
|
||||
if args.similarity_details:
|
||||
similarity_details["true_positive_sims"].append(best_sim)
|
||||
else:
|
||||
false_positives += 1
|
||||
if args.similarity_details:
|
||||
similarity_details["false_positive_sims"].append(best_sim)
|
||||
|
||||
results.append({
|
||||
"test_image": test_filename,
|
||||
"matched_logo": best_name,
|
||||
"similarity": best_sim,
|
||||
"correct": is_correct,
|
||||
})
|
||||
|
||||
# Count missed detections
|
||||
missed = expected_logos - matched_logos
|
||||
false_negatives += len(missed)
|
||||
|
||||
for missed_logo in missed:
|
||||
if args.similarity_details and detections:
|
||||
best_sim_for_missed = 0
|
||||
ref_emb = reference_embeddings[missed_logo]
|
||||
for detection in detections:
|
||||
sim = detector.compare_embeddings(detection["embedding"], ref_emb)
|
||||
best_sim_for_missed = max(best_sim_for_missed, sim)
|
||||
similarity_details["missed_best_sims"].append(best_sim_for_missed)
|
||||
|
||||
results.append({
|
||||
"test_image": test_filename,
|
||||
"matched_logo": None,
|
||||
"expected_logo": missed_logo,
|
||||
"similarity": None,
|
||||
"correct": False,
|
||||
})
|
||||
|
||||
# Save cache
|
||||
if cache:
|
||||
cache.save()
|
||||
logger.info(f"Saved {len(cache)} embeddings to cache")
|
||||
|
||||
# Calculate metrics
|
||||
precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
|
||||
recall = true_positives / total_expected if total_expected > 0 else 0
|
||||
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
|
||||
|
||||
# Print results
|
||||
print("\n" + "=" * 60)
|
||||
print("BURNLEY LOGO DETECTION TEST RESULTS")
|
||||
print("=" * 60)
|
||||
print(f"\nConfiguration:")
|
||||
print(f" Embedding model: {args.embedding_model}")
|
||||
print(f" Similarity threshold: {args.threshold}")
|
||||
print(f" DETR confidence threshold: {args.detr_threshold}")
|
||||
print(f" Matching margin: {args.margin}")
|
||||
print(f" Test images processed: {len(test_files)}")
|
||||
print(f" Reference logos: barnfield, vertu")
|
||||
|
||||
print(f"\nMetrics:")
|
||||
print(f" True Positives (correct matches): {true_positives}")
|
||||
print(f" False Positives (wrong matches): {false_positives}")
|
||||
print(f" False Negatives (missed logos): {false_negatives}")
|
||||
print(f" Total expected matches: {total_expected}")
|
||||
|
||||
print(f"\nScores:")
|
||||
print(f" Precision: {precision:.4f} ({precision*100:.1f}%)")
|
||||
print(f" Recall: {recall:.4f} ({recall*100:.1f}%)")
|
||||
print(f" F1 Score: {f1:.4f} ({f1*100:.1f}%)")
|
||||
|
||||
# Show false positive examples
|
||||
false_positive_examples = [r for r in results if r.get("matched_logo") and not r["correct"]]
|
||||
if false_positive_examples:
|
||||
print(f"\nExample False Positives (first 5):")
|
||||
for r in false_positive_examples[:5]:
|
||||
print(f" - Image: {r['test_image']}")
|
||||
print(f" Matched: {r['matched_logo']} (similarity: {r['similarity']:.3f})")
|
||||
|
||||
# Show false negative examples
|
||||
false_negative_examples = [r for r in results if r.get("expected_logo")]
|
||||
if false_negative_examples:
|
||||
print(f"\nExample False Negatives (first 5):")
|
||||
for r in false_negative_examples[:5]:
|
||||
print(f" - Image: {r['test_image']}")
|
||||
print(f" Expected: {r['expected_logo']}")
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
# Print similarity details if requested
|
||||
if args.similarity_details:
|
||||
print_similarity_details(similarity_details, args.threshold)
|
||||
|
||||
# Write results to file if requested
|
||||
if args.output_file:
|
||||
write_results_to_file(
|
||||
output_path=Path(args.output_file),
|
||||
args=args,
|
||||
num_test_images=len(test_files),
|
||||
true_positives=true_positives,
|
||||
false_positives=false_positives,
|
||||
false_negatives=false_negatives,
|
||||
total_expected=total_expected,
|
||||
precision=precision,
|
||||
recall=recall,
|
||||
f1=f1,
|
||||
)
|
||||
print(f"\nResults appended to: {args.output_file}")
|
||||
|
||||
|
||||
def print_similarity_details(details: dict, threshold: float):
|
||||
"""Print detailed similarity distribution analysis."""
|
||||
import statistics
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("SIMILARITY DISTRIBUTION ANALYSIS")
|
||||
print("=" * 60)
|
||||
|
||||
def compute_stats(values, name):
|
||||
if not values:
|
||||
print(f"\n{name}: No data")
|
||||
return
|
||||
print(f"\n{name} (n={len(values)}):")
|
||||
print(f" Min: {min(values):.4f}")
|
||||
print(f" Max: {max(values):.4f}")
|
||||
print(f" Mean: {statistics.mean(values):.4f}")
|
||||
if len(values) > 1:
|
||||
print(f" StdDev: {statistics.stdev(values):.4f}")
|
||||
print(f" Median: {statistics.median(values):.4f}")
|
||||
|
||||
above = sum(1 for v in values if v >= threshold)
|
||||
below = sum(1 for v in values if v < threshold)
|
||||
print(f" Above threshold ({threshold}): {above} ({100*above/len(values):.1f}%)")
|
||||
print(f" Below threshold ({threshold}): {below} ({100*below/len(values):.1f}%)")
|
||||
|
||||
compute_stats(details["true_positive_sims"], "TRUE POSITIVE similarities")
|
||||
compute_stats(details["false_positive_sims"], "FALSE POSITIVE similarities")
|
||||
compute_stats(details["missed_best_sims"], "MISSED LOGO best similarities")
|
||||
|
||||
# Overlap analysis
|
||||
tp_sims = details["true_positive_sims"]
|
||||
fp_sims = details["false_positive_sims"]
|
||||
if tp_sims and fp_sims:
|
||||
print("\n" + "-" * 40)
|
||||
print("OVERLAP ANALYSIS:")
|
||||
tp_min, tp_max = min(tp_sims), max(tp_sims)
|
||||
fp_min, fp_max = min(fp_sims), max(fp_sims)
|
||||
print(f" True Positives range: [{tp_min:.4f}, {tp_max:.4f}]")
|
||||
print(f" False Positives range: [{fp_min:.4f}, {fp_max:.4f}]")
|
||||
|
||||
overlap_min = max(tp_min, fp_min)
|
||||
overlap_max = min(tp_max, fp_max)
|
||||
if overlap_min < overlap_max:
|
||||
print(f" OVERLAP REGION: [{overlap_min:.4f}, {overlap_max:.4f}]")
|
||||
else:
|
||||
print(" NO OVERLAP - distributions are separable!")
|
||||
|
||||
# Sample detection details
|
||||
det_details = details["detection_details"]
|
||||
if det_details:
|
||||
print("\n" + "-" * 40)
|
||||
print(f"SAMPLE DETECTION DETAILS (first 20 of {len(det_details)}):")
|
||||
for i, det in enumerate(det_details[:20]):
|
||||
expected = det["expected_logos"]
|
||||
sims = det["similarities"]
|
||||
print(f"\n [{i+1}] Image: {det['image']}")
|
||||
print(f" Expected: {expected if expected else '(none)'}")
|
||||
print(f" DETR score: {det['detr_score']:.3f}")
|
||||
print(f" Similarities:")
|
||||
for logo, sim in sims:
|
||||
marker = " <-- CORRECT" if logo in expected else ""
|
||||
print(f" {sim:.4f} {logo}{marker}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
|
||||
|
||||
def write_results_to_file(
|
||||
output_path: Path,
|
||||
args,
|
||||
num_test_images: int,
|
||||
true_positives: int,
|
||||
false_positives: int,
|
||||
false_negatives: int,
|
||||
total_expected: int,
|
||||
precision: float,
|
||||
recall: float,
|
||||
f1: float,
|
||||
):
|
||||
"""Write results summary to file."""
|
||||
from datetime import datetime
|
||||
|
||||
lines = [
|
||||
"=" * 70,
|
||||
"BURNLEY LOGO DETECTION TEST",
|
||||
f"Model: {args.embedding_model}",
|
||||
f"Method: Margin-based (margin={args.margin})",
|
||||
"=" * 70,
|
||||
f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
|
||||
"",
|
||||
"Configuration:",
|
||||
f" Embedding model: {args.embedding_model}",
|
||||
f" Similarity threshold: {args.threshold}",
|
||||
f" DETR threshold: {args.detr_threshold}",
|
||||
f" Matching margin: {args.margin}",
|
||||
f" Test images processed: {num_test_images}",
|
||||
f" Reference logos: barnfield, vertu",
|
||||
"",
|
||||
"Results:",
|
||||
f" True Positives: {true_positives:>6}",
|
||||
f" False Positives: {false_positives:>6}",
|
||||
f" False Negatives: {false_negatives:>6}",
|
||||
f" Total Expected: {total_expected:>6}",
|
||||
"",
|
||||
"Scores:",
|
||||
f" Precision: {precision:.4f} ({precision*100:.1f}%)",
|
||||
f" Recall: {recall:.4f} ({recall*100:.1f}%)",
|
||||
f" F1 Score: {f1:.4f} ({f1*100:.1f}%)",
|
||||
"",
|
||||
"",
|
||||
]
|
||||
|
||||
with open(output_path, "a") as f:
|
||||
f.write("\n".join(lines))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -243,12 +243,11 @@ def main():
|
||||
parser.add_argument(
|
||||
"--matching-method",
|
||||
type=str,
|
||||
choices=["simple", "margin", "multi-ref", "hybrid"],
|
||||
choices=["simple", "margin", "multi-ref"],
|
||||
default="margin",
|
||||
help="Matching method: 'simple' returns all matches above threshold, "
|
||||
"'margin' requires confidence margin over 2nd best, "
|
||||
"'multi-ref' aggregates scores across reference images, "
|
||||
"'hybrid' combines text recognition with CLIP (default: margin)",
|
||||
"'multi-ref' aggregates scores across reference images (default: margin)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min-matching-refs",
|
||||
@ -261,25 +260,6 @@ def main():
|
||||
action="store_true",
|
||||
help="For 'multi-ref' method: use max similarity instead of mean across references",
|
||||
)
|
||||
# Hybrid method arguments
|
||||
parser.add_argument(
|
||||
"--hybrid-text-threshold",
|
||||
type=float,
|
||||
default=0.60,
|
||||
help="For 'hybrid' method: CLIP threshold when text matches (default: 0.60)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hybrid-no-text-threshold",
|
||||
type=float,
|
||||
default=0.80,
|
||||
help="For 'hybrid' method: CLIP threshold when text expected but not found (default: 0.80)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text-similarity-threshold",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="For 'hybrid' method: minimum text similarity to consider a match (default: 0.5)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v", "--verbose",
|
||||
action="store_true",
|
||||
@ -352,14 +332,6 @@ def main():
|
||||
preprocess_mode=args.preprocess_mode,
|
||||
)
|
||||
|
||||
# Initialize text detector for hybrid method
|
||||
text_detector = None
|
||||
if args.matching_method == "hybrid":
|
||||
logger.info("Initializing text detector for hybrid matching...")
|
||||
from text_recognition import DetectText
|
||||
text_detector = DetectText(logger=logger, threshold=0.3)
|
||||
detector.set_text_detector(text_detector)
|
||||
|
||||
# Load ground truth (both mappings)
|
||||
logger.info("Loading ground truth from database...")
|
||||
image_to_logos, logo_to_images = get_ground_truth(db_path)
|
||||
@ -377,15 +349,10 @@ def main():
|
||||
multi_ref_embeddings: Dict[str, List[torch.Tensor]] = {}
|
||||
# List for margin-based matching: (logo_name, embedding) tuples
|
||||
reference_embeddings: List[Tuple[str, torch.Tensor]] = []
|
||||
# Dict for hybrid matching: logo_name -> {'embeddings': [...], 'texts': [...]}
|
||||
hybrid_reference_data: Dict[str, Dict[str, Any]] = {}
|
||||
total_refs = 0
|
||||
logos_with_text = 0
|
||||
|
||||
for logo_name, ref_filenames in tqdm(sampled_logos.items(), desc="Reference logos"):
|
||||
multi_ref_embeddings[logo_name] = []
|
||||
if args.matching_method == "hybrid":
|
||||
hybrid_reference_data[logo_name] = {'embeddings': [], 'texts': set()}
|
||||
|
||||
for ref_filename in ref_filenames:
|
||||
ref_path = reference_dir / ref_filename
|
||||
@ -398,15 +365,12 @@ def main():
|
||||
cache_key = f"ref:{ref_filename}"
|
||||
embedding = cache.get(cache_key) if cache else None
|
||||
|
||||
# Load image if needed (for embedding or text extraction)
|
||||
img = None
|
||||
if embedding is None or args.matching_method == "hybrid":
|
||||
# Load image if needed for embedding
|
||||
if embedding is None:
|
||||
img = load_image(ref_path)
|
||||
if img is None:
|
||||
logger.warning(f"Failed to load reference logo: {ref_path}")
|
||||
continue
|
||||
|
||||
if embedding is None:
|
||||
embedding = detector.get_embedding(img)
|
||||
if cache:
|
||||
cache.put(cache_key, embedding)
|
||||
@ -415,21 +379,7 @@ def main():
|
||||
reference_embeddings.append((logo_name, embedding))
|
||||
total_refs += 1
|
||||
|
||||
# Extract text for hybrid method
|
||||
if args.matching_method == "hybrid" and img is not None:
|
||||
hybrid_reference_data[logo_name]['embeddings'].append(embedding)
|
||||
texts = detector.extract_text(img, min_confidence=0.3)
|
||||
hybrid_reference_data[logo_name]['texts'].update(texts)
|
||||
|
||||
# Convert text set to list for hybrid data
|
||||
if args.matching_method == "hybrid":
|
||||
hybrid_reference_data[logo_name]['texts'] = list(hybrid_reference_data[logo_name]['texts'])
|
||||
if hybrid_reference_data[logo_name]['texts']:
|
||||
logos_with_text += 1
|
||||
|
||||
logger.info(f"Computed {total_refs} embeddings for {len(sampled_logos)} logos")
|
||||
if args.matching_method == "hybrid":
|
||||
logger.info(f"Extracted text from {logos_with_text}/{len(sampled_logos)} reference logos")
|
||||
|
||||
# Build test set: for each logo, sample positive and negative images
|
||||
logger.info(f"Sampling test images: {args.positive_samples} positive, {args.negative_samples} negative per logo...")
|
||||
@ -504,14 +454,7 @@ def main():
|
||||
cache_key = f"det:{test_filename}"
|
||||
cached_detections = cache.get(cache_key) if cache else None
|
||||
|
||||
# For hybrid matching, we always need the original image for text extraction
|
||||
test_img = None
|
||||
if args.matching_method == "hybrid":
|
||||
test_img = load_image(test_path)
|
||||
if test_img is None:
|
||||
logger.warning(f"Failed to load test image: {test_path}")
|
||||
continue
|
||||
|
||||
if cached_detections is not None:
|
||||
# Cached detections contain serialized box data and embeddings
|
||||
detections = cached_detections
|
||||
@ -651,50 +594,6 @@ def main():
|
||||
"correct": is_correct,
|
||||
})
|
||||
|
||||
else: # hybrid
|
||||
# Hybrid matching: combines text recognition with CLIP
|
||||
# Extract crop from original image for text extraction
|
||||
box = detection["box"]
|
||||
crop = test_img[
|
||||
int(box["ymin"]):int(box["ymax"]),
|
||||
int(box["xmin"]):int(box["xmax"])
|
||||
]
|
||||
|
||||
match_result = detector.find_best_match_hybrid(
|
||||
detected_embedding=detection["embedding"],
|
||||
detected_image=crop,
|
||||
reference_data=hybrid_reference_data,
|
||||
clip_threshold=args.threshold,
|
||||
clip_threshold_with_text=args.hybrid_text_threshold,
|
||||
clip_threshold_text_mismatch=args.hybrid_no_text_threshold,
|
||||
text_similarity_threshold=args.text_similarity_threshold,
|
||||
margin=args.margin,
|
||||
use_mean_similarity=not args.use_max_similarity,
|
||||
)
|
||||
if match_result:
|
||||
label, similarity, match_info = match_result
|
||||
matched_logos.add(label)
|
||||
|
||||
is_correct = label in expected_logos
|
||||
if is_correct:
|
||||
true_positives += 1
|
||||
if args.similarity_details:
|
||||
similarity_details["true_positive_sims"].append(similarity)
|
||||
else:
|
||||
false_positives += 1
|
||||
if args.similarity_details:
|
||||
similarity_details["false_positive_sims"].append(similarity)
|
||||
|
||||
results.append({
|
||||
"test_image": test_filename,
|
||||
"matched_logo": label,
|
||||
"similarity": similarity,
|
||||
"correct": is_correct,
|
||||
"text_matched": match_info.get("text_matched", False),
|
||||
"text_similarity": match_info.get("text_similarity", 0),
|
||||
"match_type": match_info.get("match_type", "unknown"),
|
||||
})
|
||||
|
||||
# Count missed detections (false negatives)
|
||||
missed = expected_logos - matched_logos
|
||||
false_negatives += len(missed)
|
||||
@ -742,16 +641,11 @@ def main():
|
||||
print(f" DETR confidence threshold: {args.detr_threshold}")
|
||||
print(f" Preprocess mode: {args.preprocess_mode}")
|
||||
print(f" Matching method: {args.matching_method}")
|
||||
if args.matching_method in ("margin", "multi-ref", "hybrid"):
|
||||
if args.matching_method in ("margin", "multi-ref"):
|
||||
print(f" Matching margin: {args.margin}")
|
||||
if args.matching_method == "multi-ref":
|
||||
print(f" Min matching refs: {args.min_matching_refs}")
|
||||
print(f" Similarity aggregation: {'max' if args.use_max_similarity else 'mean'}")
|
||||
if args.matching_method == "hybrid":
|
||||
print(f" CLIP threshold (text match): {args.hybrid_text_threshold}")
|
||||
print(f" CLIP threshold (no text): {args.hybrid_no_text_threshold}")
|
||||
print(f" Text similarity threshold: {args.text_similarity_threshold}")
|
||||
print(f" Refs with text: {logos_with_text}/{len(sampled_logos)}")
|
||||
if args.seed is not None:
|
||||
print(f" Random seed: {args.seed}")
|
||||
|
||||
@ -939,14 +833,9 @@ def write_results_to_file(
|
||||
method_desc = "Simple (all matches above threshold)"
|
||||
elif args.matching_method == "margin":
|
||||
method_desc = f"Margin-based (margin={args.margin})"
|
||||
elif args.matching_method == "multi-ref":
|
||||
else: # multi-ref
|
||||
agg = "max" if args.use_max_similarity else "mean"
|
||||
method_desc = f"Multi-ref ({agg}, min_refs={args.min_matching_refs}, margin={args.margin})"
|
||||
else: # hybrid
|
||||
method_desc = (
|
||||
f"Hybrid (text+CLIP, text_thresh={args.hybrid_text_threshold}, "
|
||||
f"no_text_thresh={args.hybrid_no_text_threshold}, margin={args.margin})"
|
||||
)
|
||||
|
||||
lines = [
|
||||
"=" * 70,
|
||||
|
||||
Reference in New Issue
Block a user