Hybrid matching combines text recognition with CLIP similarity:
- If reference logo has text and detection matches: lower CLIP threshold
- If reference has text but detection doesn't match: higher threshold
- If reference has no text: standard threshold

Image preprocessing adds letterbox/stretch modes for CLIP input to preserve aspect ratio instead of center cropping.

New files:
- run_hybrid_test.sh: Test hybrid matching configurations
- run_preprocess_test.sh: Compare preprocessing modes

Changes to logo_detection_detr.py:
- Add preprocess_mode parameter (default/letterbox/stretch)
- Add set_text_detector() for hybrid matching
- Add extract_text() using EasyOCR
- Add compute_text_similarity() with fuzzy matching
- Add find_best_match_hybrid() with tiered thresholds

Changes to test_logo_detection.py:
- Add --matching-method hybrid option
- Add --preprocess-mode option
- Add hybrid threshold arguments
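For reviewers, a rough usage sketch of the hybrid path as exposed by this file's API. The logger setup, the DetectText constructor arguments, and all file paths below are illustrative placeholders, not part of this change:

    import logging

    import cv2

    from logo_detection_detr import DetectLogosDETR
    from text_recognition import DetectText  # OCR wrapper expected by set_text_detector()

    logger = logging.getLogger("logo_detection")

    detector = DetectLogosDETR(logger, preprocess_mode="letterbox")
    detector.set_text_detector(DetectText(logger))  # constructor signature assumed

    # One or more reference images (BGR numpy arrays) per logo; paths are placeholders.
    refs = {"acme": [cv2.imread("refs/acme_1.png"), cv2.imread("refs/acme_2.png")]}
    reference_data = detector.prepare_reference_data_hybrid(refs)

    frame = cv2.imread("frame.jpg")
    for det in detector.detect(frame):
        box = det["box"]
        crop = frame[box["ymin"]:box["ymax"], box["xmin"]:box["xmax"]]
        result = detector.find_best_match_hybrid(
            det["embedding"], crop, reference_data,
            clip_threshold=0.70,                # reference has no text
            clip_threshold_with_text=0.60,      # OCR text matched -> more lenient
            clip_threshold_text_mismatch=0.80,  # text expected but absent -> stricter
        )
        if result is not None:
            label, clip_score, info = result
            logger.info(f"{label}: clip={clip_score:.2f} ({info['match_type']})")

The tiered thresholds mirror the three cases listed above: text agreement relaxes the CLIP requirement, a text mismatch tightens it, and text-free references fall back to the standard threshold.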
1075 lines
40 KiB
Python
"""
|
|
Logo detection using DETR for object detection and vision models for feature matching.
|
|
|
|
This module provides a class for detecting logos in images using:
|
|
1. DETR (DEtection TRansformer) for initial logo region detection
|
|
2. Vision models (CLIP, DINOv2, etc.) for feature extraction and matching
|
|
|
|
The class supports caching of embeddings for efficient reprocessing.
|
|
The class automatically uses local models if available, otherwise falls back to HuggingFace.
|
|
|
|
Supported embedding models:
|
|
- CLIP models (openai/clip-vit-*): Text-image alignment, good general features
|
|
- DINOv2 models (facebook/dinov2-*): Self-supervised, excellent for visual similarity
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from transformers import pipeline, CLIPProcessor, CLIPModel, AutoImageProcessor, AutoModel
|
|
from PIL import Image
|
|
import cv2
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from typing import List, Tuple, Dict, Optional, Any
|
|
from difflib import SequenceMatcher
|
|
|
|
|
|
class DetectLogosDETR:
|
|
"""
|
|
Logo detection class using DETR and vision embedding models.
|
|
|
|
This class detects logos in images by:
|
|
1. Using DETR to find potential logo regions (bounding boxes)
|
|
2. Extracting embeddings for each detected region (CLIP, DINOv2, etc.)
|
|
3. Comparing embeddings with reference logos for identification
|
|
|
|
The class automatically checks for local models before downloading from HuggingFace.
|
|
|
|
Supported embedding models:
|
|
- CLIP models (openai/clip-vit-*): Text-image alignment
|
|
- DINOv2 models (facebook/dinov2-*): Self-supervised visual features
|
|
"""
|
|
|
|
    def __init__(
        self,
        logger,
        detr_model: str = "Pravallika6/detr-finetuned-logo-detection_v2",
        embedding_model: str = "openai/clip-vit-large-patch14",
        detr_threshold: float = 0.5,
        min_box_size: int = 20,
        nms_iou_threshold: float = 0.5,
        preprocess_mode: str = "default",
    ):
        """
        Initialize DETR and embedding models.

        The class will automatically check for local models in the default directories
        before downloading from HuggingFace. You can override this by providing absolute
        paths to local models.

        Args:
            logger: Logger instance for logging
            detr_model: HuggingFace model name or local path for DETR object detection
            embedding_model: HuggingFace model name for embeddings (CLIP or DINOv2)
            detr_threshold: Confidence threshold for DETR detections (0-1)
            min_box_size: Minimum width/height in pixels for detected boxes (filters noise)
            nms_iou_threshold: IoU threshold for Non-Maximum Suppression
            preprocess_mode: Image preprocessing mode for CLIP:
                - "default": Use CLIP's default (resize shortest edge + center crop)
                - "letterbox": Pad to square with black bars, preserving aspect ratio
                - "stretch": Stretch to square (distorts aspect ratio)
        """
        self.logger = logger
        self.detr_threshold = detr_threshold
        self.min_box_size = min_box_size
        self.nms_iou_threshold = nms_iou_threshold
        self.embedding_model_name = embedding_model
        self.preprocess_mode = preprocess_mode

        # Set device
        self.device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device_index = 0 if torch.cuda.is_available() else -1
        self.device = torch.device(self.device_str)

        self.logger.info(f"Initializing DetectLogosDETR on device: {self.device_str}")

        # Get default model directories from environment variables
        default_detr_dir = os.environ.get('LOGO_DETR_MODEL_DIR', 'models/logo_detection/detr')
        default_embedding_dir = os.environ.get('LOGO_EMBEDDING_MODEL_DIR', 'models/logo_detection/embedding')

        # Resolve DETR model path (check local first, then use HuggingFace name)
        detr_model_path = self._resolve_model_path(
            detr_model, default_detr_dir, "DETR"
        )

        # Initialize DETR pipeline for logo detection
        self.logger.info(f"Loading DETR model: {detr_model_path}")
        self.detr_pipe = pipeline(
            task="object-detection",
            model=detr_model_path,
            device=self.device_index,
            use_fast=True,
        )

        # Resolve embedding model path
        embedding_model_path = self._resolve_model_path(
            embedding_model, default_embedding_dir, "Embedding"
        )

        # Check if this is a fine-tuned model
        if self._is_finetuned_model(embedding_model_path):
            self._load_finetuned_embedding_model(embedding_model_path)
        else:
            # Detect model type and initialize accordingly
            self.model_type = self._detect_model_type(embedding_model)
            self.logger.info(f"Loading {self.model_type} embedding model: {embedding_model_path}")

            if self.model_type == "clip":
                self.embedding_model = CLIPModel.from_pretrained(embedding_model_path).to(self.device)
                self.embedding_processor = CLIPProcessor.from_pretrained(embedding_model_path)
            else:  # dinov2 or other transformer models
                self.embedding_model = AutoModel.from_pretrained(embedding_model_path).to(self.device)
                self.embedding_processor = AutoImageProcessor.from_pretrained(embedding_model_path)

        if self.preprocess_mode != "default":
            self.logger.info(f"Image preprocessing mode: {self.preprocess_mode}")
        self.logger.info("DetectLogosDETR initialization complete")

    def _detect_model_type(self, model_name: str) -> str:
        """Detect the type of embedding model based on name."""
        model_name_lower = model_name.lower()
        if "clip" in model_name_lower:
            return "clip"
        elif "dino" in model_name_lower:
            return "dinov2"
        else:
            # Default to generic transformer for unknown models
            return "transformer"

    def _is_finetuned_model(self, model_path: str) -> bool:
        """Check if a model path points to a fine-tuned CLIP model."""
        config_path = Path(model_path) / "config.json"
        if config_path.exists():
            try:
                with open(config_path, "r") as f:
                    config = json.load(f)
                return config.get("model_type") == "clip_logo_finetuned"
            except (json.JSONDecodeError, IOError):
                pass
        return False

    def _load_finetuned_embedding_model(self, model_path: str) -> None:
        """
        Load a fine-tuned CLIP model from the training module.

        Args:
            model_path: Path to the fine-tuned model directory
        """
        # Import the fine-tuned model class
        try:
            from training.model import LogoFineTunedCLIP
        except ImportError as e:
            self.logger.error(
                f"Cannot import training.model for fine-tuned model: {e}"
            )
            raise ImportError(
                "Fine-tuned model requires the training module. "
                "Ensure the training/ directory is in your Python path."
            ) from e

        # Load config
        config_path = Path(model_path) / "config.json"
        with open(config_path, "r") as f:
            config = json.load(f)

        base_model = config.get("base_model", "openai/clip-vit-large-patch14")

        self.logger.info(f"Loading fine-tuned CLIP model from: {model_path}")
        self.logger.info(f" Base model: {base_model}")

        # Load model using the from_pretrained method
        self.embedding_model = LogoFineTunedCLIP.from_pretrained(
            model_path,
            base_model=base_model,
            device=self.device,
        )
        self.embedding_model.eval()

        # Load processor from base model
        self.embedding_processor = CLIPProcessor.from_pretrained(base_model)

        # Set model type for embedding extraction
        self.model_type = "clip_finetuned"
        self.logger.info("Fine-tuned CLIP model loaded successfully")

    def _resolve_model_path(
        self, model_name_or_path: str, default_local_dir: str, model_type: str
    ) -> str:
        """
        Resolve model path, checking for local models before using HuggingFace.

        Args:
            model_name_or_path: HuggingFace model name or absolute path
            default_local_dir: Default local directory to check
            model_type: Type of model (for logging, e.g., "DETR" or "CLIP")

        Returns:
            Resolved model path (local path or HuggingFace model name)
        """
        # If it's an absolute path, use it directly
        if os.path.isabs(model_name_or_path):
            if os.path.exists(model_name_or_path):
                self.logger.info(
                    f"{model_type} model: Using local model at {model_name_or_path}"
                )
                return model_name_or_path
            else:
                self.logger.warning(
                    f"{model_type} model: Local path {model_name_or_path} does not exist, "
                    f"falling back to HuggingFace"
                )
                return model_name_or_path

        # Check if default local directory exists
        if os.path.exists(default_local_dir):
            # Verify it's a valid model directory (has config.json)
            config_file = os.path.join(default_local_dir, "config.json")
            if os.path.exists(config_file):
                abs_path = os.path.abspath(default_local_dir)
                self.logger.info(
                    f"{model_type} model: Found local model at {abs_path}"
                )
                return abs_path
            else:
                self.logger.warning(
                    f"{model_type} model: Local directory {default_local_dir} exists but "
                    f"is not a valid model (missing config.json)"
                )

        # Use HuggingFace model name
        self.logger.info(
            f"{model_type} model: No local model found, will download from HuggingFace: "
            f"{model_name_or_path}"
        )
        return model_name_or_path

    def detect(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Detect logos in an image and return bounding boxes with CLIP embeddings.

        Args:
            image: OpenCV image (BGR format, numpy array)

        Returns:
            List of dictionaries, each containing:
            - 'box': dict with 'xmin', 'ymin', 'xmax', 'ymax' (pixel coordinates)
            - 'score': DETR confidence score (float 0-1)
            - 'embedding': CLIP feature embedding (torch.Tensor)
            - 'label': DETR predicted label (string)
        """
        # Convert OpenCV BGR to RGB PIL Image
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(image_rgb)

        # Run DETR detection
        predictions = self.detr_pipe(pil_image)

        # Filter by threshold and size, then add CLIP embeddings
        detections = []
        for pred in predictions:
            score = pred.get("score", 0.0)
            if score < self.detr_threshold:
                continue

            box = pred.get("box", {})
            xmin = box.get("xmin", 0)
            ymin = box.get("ymin", 0)
            xmax = box.get("xmax", 0)
            ymax = box.get("ymax", 0)

            # Filter by minimum box size
            box_width = xmax - xmin
            box_height = ymax - ymin
            if box_width < self.min_box_size or box_height < self.min_box_size:
                continue

            # Extract bounding box region
            bbox_crop = pil_image.crop((xmin, ymin, xmax, ymax))

            # Get embedding for this region
            embedding = self._get_embedding_pil(bbox_crop)

            detections.append(
                {
                    "box": {"xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax},
                    "score": score,
                    "embedding": embedding,
                    "label": pred.get("label", "logo"),
                }
            )

        # Apply Non-Maximum Suppression to remove overlapping detections
        detections = self._apply_nms(detections, self.nms_iou_threshold)

        self.logger.debug(f"Detected {len(detections)} logos (threshold: {self.detr_threshold})")
        return detections

    def _apply_nms(self, predictions: List[Dict], iou_threshold: float) -> List[Dict]:
        """
        Apply Non-Maximum Suppression to remove overlapping detections.

        Args:
            predictions: List of prediction dictionaries with 'box' and 'score'
            iou_threshold: IoU threshold for considering boxes as overlapping

        Returns:
            Filtered list of predictions after NMS
        """
        if len(predictions) == 0:
            return []

        # Extract boxes and scores
        boxes = []
        scores = []
        for pred in predictions:
            box = pred.get("box", {})
            boxes.append([
                box.get("xmin", 0),
                box.get("ymin", 0),
                box.get("xmax", 0),
                box.get("ymax", 0)
            ])
            scores.append(pred.get("score", 0.0))

        # Convert to numpy arrays
        boxes = np.array(boxes, dtype=np.float32)
        scores = np.array(scores, dtype=np.float32)

        # Sort by scores (descending)
        sorted_indices = np.argsort(scores)[::-1]

        keep_indices = []
        while len(sorted_indices) > 0:
            # Keep the box with highest score
            current_idx = sorted_indices[0]
            keep_indices.append(current_idx)

            if len(sorted_indices) == 1:
                break

            # Calculate IoU with remaining boxes
            current_box = boxes[current_idx]
            remaining_boxes = boxes[sorted_indices[1:]]

            ious = self._calculate_iou_batch(current_box, remaining_boxes)

            # Keep only boxes with IoU below threshold
            mask = ious < iou_threshold
            sorted_indices = sorted_indices[1:][mask]

        # Return predictions for kept indices
        return [predictions[i] for i in keep_indices]

    def _calculate_iou_batch(self, box: np.ndarray, boxes: np.ndarray) -> np.ndarray:
        """
        Calculate IoU between one box and multiple boxes.

        Args:
            box: Single box [xmin, ymin, xmax, ymax]
            boxes: Multiple boxes [[xmin, ymin, xmax, ymax], ...]

        Returns:
            Array of IoU values
        """
        # Calculate intersection coordinates
        x1 = np.maximum(box[0], boxes[:, 0])
        y1 = np.maximum(box[1], boxes[:, 1])
        x2 = np.minimum(box[2], boxes[:, 2])
        y2 = np.minimum(box[3], boxes[:, 3])

        # Calculate intersection area
        intersection = np.maximum(0, x2 - x1) * np.maximum(0, y2 - y1)

        # Calculate union area
        box_area = (box[2] - box[0]) * (box[3] - box[1])
        boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
        union = box_area + boxes_area - intersection

        # Calculate IoU
        iou = intersection / (union + 1e-6)  # Add small epsilon to avoid division by zero

        return iou

    def get_embedding(self, image: np.ndarray) -> torch.Tensor:
        """
        Get embedding for a reference logo image.

        This method is used to compute embeddings for reference logos
        that will be compared against detected regions.

        Args:
            image: OpenCV image (BGR format, numpy array)

        Returns:
            Normalized feature embedding (torch.Tensor)
        """
        # Convert OpenCV BGR to RGB PIL Image
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(image_rgb)

        return self._get_embedding_pil(pil_image)

    def _preprocess_image(self, pil_image: Image.Image, target_size: int = 224) -> Image.Image:
        """
        Preprocess image based on the configured preprocessing mode.

        Args:
            pil_image: PIL Image (RGB format)
            target_size: Target size for the square output (default 224 for CLIP)

        Returns:
            Preprocessed PIL Image
        """
        if self.preprocess_mode == "default":
            # Let the processor handle it (resize shortest edge + center crop)
            return pil_image

        width, height = pil_image.size

        if self.preprocess_mode == "letterbox":
            # Pad to square with black bars, preserving aspect ratio
            max_dim = max(width, height)

            # Create a black square canvas
            new_image = Image.new("RGB", (max_dim, max_dim), (0, 0, 0))

            # Paste the original image centered
            paste_x = (max_dim - width) // 2
            paste_y = (max_dim - height) // 2
            new_image.paste(pil_image, (paste_x, paste_y))

            # Resize to target size
            return new_image.resize((target_size, target_size), Image.LANCZOS)

        elif self.preprocess_mode == "stretch":
            # Stretch to square (distorts aspect ratio)
            return pil_image.resize((target_size, target_size), Image.LANCZOS)

        else:
            # Unknown mode, return original
            return pil_image

    def _get_embedding_pil(self, pil_image: Image.Image) -> torch.Tensor:
        """
        Internal method to get embedding from PIL image.

        Handles CLIP, fine-tuned CLIP, and DINOv2 model types.

        Args:
            pil_image: PIL Image (RGB format)

        Returns:
            Normalized feature embedding (torch.Tensor)
        """
        # Apply preprocessing if configured
        if self.preprocess_mode != "default":
            pil_image = self._preprocess_image(pil_image)

        # Process image through the embedding model
        inputs = self.embedding_processor(images=pil_image, return_tensors="pt").to(self.device)

        with torch.no_grad():
            if self.model_type == "clip":
                # CLIP has a dedicated method for image features
                features = self.embedding_model.get_image_features(**inputs)
            elif self.model_type == "clip_finetuned":
                # Fine-tuned CLIP uses get_image_features or forward with pixel_values
                features = self.embedding_model.get_image_features(**inputs)
            else:
                # DINOv2 and other transformers use the CLS token or pooled output
                outputs = self.embedding_model(**inputs)
                # Use the CLS token (first token) from last hidden state
                if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
                    features = outputs.pooler_output
                else:
                    # Use CLS token from last_hidden_state
                    features = outputs.last_hidden_state[:, 0, :]

        # Normalize for cosine similarity (fine-tuned model already normalizes)
        if self.model_type != "clip_finetuned":
            features = F.normalize(features, dim=-1)

        return features

    def compare_embeddings(
        self, embedding1: torch.Tensor, embedding2: torch.Tensor
    ) -> float:
        """
        Compute cosine similarity between two CLIP embeddings.

        Args:
            embedding1: First CLIP embedding (torch.Tensor)
            embedding2: Second CLIP embedding (torch.Tensor)

        Returns:
            Cosine similarity score (float, range: -1 to 1, typically 0 to 1)
        """
        # Ensure tensors are on the same device
        if embedding1.device != embedding2.device:
            embedding2 = embedding2.to(embedding1.device)

        # Compute cosine similarity
        similarity = F.cosine_similarity(embedding1, embedding2, dim=-1)

        # Return as Python float
        return similarity.item()

    def find_best_match(
        self,
        detected_embedding: torch.Tensor,
        reference_embeddings: List[Tuple[str, torch.Tensor]],
        similarity_threshold: float = 0.7,
    ) -> Optional[Tuple[str, float]]:
        """
        Find the best matching reference logo for a detected embedding.

        Args:
            detected_embedding: CLIP embedding from detected logo region
            reference_embeddings: List of (label, embedding) tuples for reference logos
            similarity_threshold: Minimum similarity to consider a match (0-1)

        Returns:
            Tuple of (label, similarity) for best match, or None if no match above threshold
        """
        if not reference_embeddings:
            return None

        best_similarity = -1.0
        best_label = None

        for label, ref_embedding in reference_embeddings:
            similarity = self.compare_embeddings(detected_embedding, ref_embedding)

            if similarity > best_similarity:
                best_similarity = similarity
                best_label = label

        if best_similarity >= similarity_threshold:
            return (best_label, best_similarity)
        else:
            return None

    def find_all_matches(
        self,
        detected_embedding: torch.Tensor,
        reference_embeddings: List[Tuple[str, torch.Tensor]],
        similarity_threshold: float = 0.7,
    ) -> List[Tuple[str, float]]:
        """
        Find all matching reference logos above the similarity threshold.

        Unlike find_best_match, this returns ALL logos that have at least one
        reference above threshold. Each unique logo is returned once with its
        highest similarity score.

        Args:
            detected_embedding: CLIP embedding from detected logo region
            reference_embeddings: List of (label, embedding) tuples for reference logos
            similarity_threshold: Minimum similarity to consider a match (0-1)

        Returns:
            List of (label, similarity) tuples for all matches above threshold,
            sorted by similarity descending. Each logo appears at most once.
        """
        if not reference_embeddings:
            return []

        # Track best similarity for each logo
        logo_best_sim: Dict[str, float] = {}

        for label, ref_embedding in reference_embeddings:
            similarity = self.compare_embeddings(detected_embedding, ref_embedding)

            if similarity >= similarity_threshold:
                if label not in logo_best_sim or similarity > logo_best_sim[label]:
                    logo_best_sim[label] = similarity

        # Convert to list and sort by similarity descending
        matches = [(label, sim) for label, sim in logo_best_sim.items()]
        matches.sort(key=lambda x: x[1], reverse=True)

        return matches

    def find_best_match_multi_ref(
        self,
        detected_embedding: torch.Tensor,
        reference_embeddings: Dict[str, List[torch.Tensor]],
        similarity_threshold: float = 0.85,
        min_matching_refs: int = 1,
        use_mean_similarity: bool = True,
        margin: float = 0.0,
    ) -> Optional[Tuple[str, float, int]]:
        """
        Find the best matching reference logo using multiple reference embeddings per logo.

        This method improves accuracy by using multiple reference images for each logo
        and requiring consistency across references.

        Args:
            detected_embedding: CLIP embedding from detected logo region
            reference_embeddings: Dict mapping logo name to list of embeddings
            similarity_threshold: Minimum similarity to consider a match (0-1)
            min_matching_refs: Minimum number of references that must match above threshold
            use_mean_similarity: If True, use mean similarity across all refs; if False, use max
            margin: Required margin between best and second-best logo scores (0-1)

        Returns:
            Tuple of (label, similarity, num_matching_refs) for best match,
            or None if no match meets criteria
        """
        if not reference_embeddings:
            return None

        # Calculate scores for all logos that meet the min_matching_refs requirement
        logo_scores = []

        for label, ref_embedding_list in reference_embeddings.items():
            if not ref_embedding_list:
                continue

            # Calculate similarity to each reference embedding
            similarities = []
            for ref_embedding in ref_embedding_list:
                sim = self.compare_embeddings(detected_embedding, ref_embedding)
                similarities.append(sim)

            # Count how many references match above threshold
            num_matches = sum(1 for s in similarities if s >= similarity_threshold)

            # Calculate aggregate score
            if use_mean_similarity:
                score = sum(similarities) / len(similarities)
            else:
                score = max(similarities)

            # Only consider logos that meet the minimum matching refs requirement
            if num_matches >= min_matching_refs:
                logo_scores.append((label, score, num_matches))

        if not logo_scores:
            return None

        # Sort by score descending
        logo_scores.sort(key=lambda x: x[1], reverse=True)

        best_label, best_score, best_num_matches = logo_scores[0]

        # Check if best score meets threshold
        if best_score < similarity_threshold:
            return None

        # Check margin against second-best logo (if exists)
        if margin > 0 and len(logo_scores) > 1:
            second_best_score = logo_scores[1][1]
            if best_score - second_best_score < margin:
                return None  # Not confident enough

        return (best_label, best_score, best_num_matches)

    def find_best_match_with_margin(
        self,
        detected_embedding: torch.Tensor,
        reference_embeddings: List[Tuple[str, torch.Tensor]],
        similarity_threshold: float = 0.85,
        margin: float = 0.05,
    ) -> Optional[Tuple[str, float]]:
        """
        Find best match with a confidence margin over the second-best match.

        This reduces false positives by requiring the best match to be
        significantly better than alternatives.

        Args:
            detected_embedding: CLIP embedding from detected logo region
            reference_embeddings: List of (label, embedding) tuples for reference logos
            similarity_threshold: Minimum similarity to consider a match (0-1)
            margin: Required margin between best and second-best match

        Returns:
            Tuple of (label, similarity) for best match, or None if no confident match
        """
        if not reference_embeddings:
            return None

        # Calculate all similarities
        similarities = []
        for label, ref_embedding in reference_embeddings:
            sim = self.compare_embeddings(detected_embedding, ref_embedding)
            similarities.append((label, sim))

        # Sort by similarity descending
        similarities.sort(key=lambda x: x[1], reverse=True)

        best_label, best_sim = similarities[0]

        # Check if best is above threshold
        if best_sim < similarity_threshold:
            return None

        # Check margin against second best (if exists)
        if len(similarities) > 1:
            second_best_sim = similarities[1][1]
            if best_sim - second_best_sim < margin:
                return None  # Not confident enough

        return (best_label, best_sim)

    def detect_and_match(
        self,
        image: np.ndarray,
        reference_embeddings: List[Tuple[str, torch.Tensor]],
        similarity_threshold: float = 0.7,
    ) -> List[Dict[str, Any]]:
        """
        Detect logos and match them against reference embeddings in one step.

        This is a convenience method that combines detection and matching.

        Args:
            image: OpenCV image (BGR format, numpy array)
            reference_embeddings: List of (label, embedding) tuples for reference logos
            similarity_threshold: Minimum similarity to consider a match (0-1)

        Returns:
            List of matched detections, each containing:
            - 'box': bounding box coordinates
            - 'detr_score': DETR confidence score
            - 'clip_similarity': CLIP similarity score
            - 'label': matched reference logo label
        """
        # Detect all logos
        detections = self.detect(image)

        # Match each detection against references
        matched_detections = []
        for detection in detections:
            match_result = self.find_best_match(
                detection["embedding"], reference_embeddings, similarity_threshold
            )

            if match_result is not None:
                label, similarity = match_result
                matched_detections.append(
                    {
                        "box": detection["box"],
                        "detr_score": detection["score"],
                        "clip_similarity": similarity,
                        "label": label,
                    }
                )

        self.logger.debug(
            f"Matched {len(matched_detections)}/{len(detections)} detections "
            f"(threshold: {similarity_threshold})"
        )

        return matched_detections

    # =========================================================================
    # Hybrid Text + CLIP Matching
    # =========================================================================

    def set_text_detector(self, text_detector) -> None:
        """
        Set an optional text detector for hybrid matching.

        Args:
            text_detector: Instance of DetectText class from text_recognition.py
        """
        self.text_detector = text_detector
        self.logger.info("Text detector enabled for hybrid matching")

    def extract_text(self, image: np.ndarray, min_confidence: float = 0.3) -> List[str]:
        """
        Extract text from an image using the text detector.

        Args:
            image: OpenCV image (BGR format)
            min_confidence: Minimum OCR confidence to accept text

        Returns:
            List of detected text strings (lowercased, stripped)
        """
        if not hasattr(self, 'text_detector') or self.text_detector is None:
            return []

        try:
            results, _ = self.text_detector.detect(image)
            # Filter by confidence and normalize text
            texts = []
            for text, confidence in results:
                if confidence >= min_confidence:
                    # Normalize: lowercase, strip whitespace, remove special chars
                    normalized = text.lower().strip()
                    if len(normalized) >= 2:  # Ignore single characters
                        texts.append(normalized)
            return texts
        except Exception as e:
            self.logger.warning(f"Text extraction failed: {e}")
            return []

    def extract_text_pil(self, pil_image: Image.Image, min_confidence: float = 0.3) -> List[str]:
        """
        Extract text from a PIL image.

        Args:
            pil_image: PIL Image (RGB format)
            min_confidence: Minimum OCR confidence

        Returns:
            List of detected text strings
        """
        # Convert PIL to OpenCV format
        cv_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
        return self.extract_text(cv_image, min_confidence)

    @staticmethod
    def compute_text_similarity(text1_list: List[str], text2_list: List[str]) -> float:
        """
        Compute fuzzy text similarity between two lists of text strings.

        Uses a combination of exact matches and fuzzy matching to handle
        OCR variations like case differences, spacing, and minor errors.

        Args:
            text1_list: List of text strings from first image
            text2_list: List of text strings from second image

        Returns:
            Similarity score between 0 and 1
        """
        if not text1_list or not text2_list:
            return 0.0

        # Combine all text into single strings for overall comparison
        text1_combined = " ".join(sorted(text1_list))
        text2_combined = " ".join(sorted(text2_list))

        # Method 1: Sequence matching on combined text
        seq_similarity = SequenceMatcher(None, text1_combined, text2_combined).ratio()

        # Method 2: Token overlap (Jaccard-like)
        # Split into tokens
        tokens1 = set(text1_combined.split())
        tokens2 = set(text2_combined.split())

        if tokens1 and tokens2:
            intersection = len(tokens1 & tokens2)
            union = len(tokens1 | tokens2)
            token_similarity = intersection / union if union > 0 else 0
        else:
            token_similarity = 0

        # Method 3: Best pairwise match for each text in list1
        pairwise_scores = []
        for t1 in text1_list:
            best_match = 0
            for t2 in text2_list:
                score = SequenceMatcher(None, t1, t2).ratio()
                best_match = max(best_match, score)
            pairwise_scores.append(best_match)

        pairwise_similarity = sum(pairwise_scores) / len(pairwise_scores) if pairwise_scores else 0

        # Combine methods (weighted average)
        combined = (seq_similarity * 0.3 + token_similarity * 0.3 + pairwise_similarity * 0.4)

        return combined

    @staticmethod
    def texts_match(
        ref_texts: List[str],
        det_texts: List[str],
        threshold: float = 0.5
    ) -> Tuple[bool, float]:
        """
        Determine if texts match above a threshold.

        Args:
            ref_texts: Text from reference logo
            det_texts: Text from detected region
            threshold: Minimum similarity to consider a match

        Returns:
            Tuple of (is_match, similarity_score)
        """
        if not ref_texts:
            # Reference has no text - can't match on text
            return (False, 0.0)

        if not det_texts:
            # Reference has text but detection doesn't - no text match
            return (False, 0.0)

        similarity = DetectLogosDETR.compute_text_similarity(ref_texts, det_texts)
        return (similarity >= threshold, similarity)

    def find_best_match_hybrid(
        self,
        detected_embedding: torch.Tensor,
        detected_image: np.ndarray,
        reference_data: Dict[str, Dict[str, Any]],
        clip_threshold: float = 0.70,
        clip_threshold_with_text: float = 0.60,
        clip_threshold_text_mismatch: float = 0.80,
        text_similarity_threshold: float = 0.5,
        margin: float = 0.05,
        use_mean_similarity: bool = False,
    ) -> Optional[Tuple[str, float, Dict[str, Any]]]:
        """
        Find best match using hybrid text + CLIP approach.

        Strategy:
        - If reference has text AND detection has matching text:
            → Use lower CLIP threshold (text provides additional confidence)
        - If reference has text but detection doesn't match:
            → Use higher CLIP threshold (need more visual confidence)
        - If reference has no text:
            → Use standard CLIP threshold

        Args:
            detected_embedding: CLIP embedding from detected logo region
            detected_image: OpenCV image of the detected region (for text extraction)
            reference_data: Dict mapping logo name to:
                {
                    'embeddings': List[torch.Tensor],  # CLIP embeddings
                    'texts': List[str],  # Extracted text from reference
                }
            clip_threshold: Standard CLIP threshold for no-text references
            clip_threshold_with_text: Lower threshold when text matches
            clip_threshold_text_mismatch: Higher threshold when text expected but missing
            text_similarity_threshold: Threshold for text matching
            margin: Required margin between best and second-best
            use_mean_similarity: Use mean vs max for multi-ref aggregation

        Returns:
            Tuple of (label, clip_similarity, match_info) or None
            match_info contains: text_matched, text_similarity, threshold_used
        """
        if not reference_data:
            return None

        # Extract text from detected region
        detected_texts = self.extract_text(detected_image)

        # Calculate scores for all logos
        logo_scores = []

        for label, ref_info in reference_data.items():
            ref_embeddings = ref_info.get('embeddings', [])
            ref_texts = ref_info.get('texts', [])

            if not ref_embeddings:
                continue

            # Calculate CLIP similarity
            similarities = []
            for ref_emb in ref_embeddings:
                sim = self.compare_embeddings(detected_embedding, ref_emb)
                similarities.append(sim)

            if use_mean_similarity:
                clip_score = sum(similarities) / len(similarities)
            else:
                clip_score = max(similarities)

            # Determine text match status and appropriate threshold
            has_ref_text = len(ref_texts) > 0
            text_matched, text_sim = self.texts_match(
                ref_texts, detected_texts, text_similarity_threshold
            )

            if has_ref_text:
                if text_matched:
                    # Text matches - use lower threshold, boost confidence
                    threshold_used = clip_threshold_with_text
                    match_type = "text_match"
                else:
                    # Reference has text but detection doesn't match
                    # Require higher CLIP threshold
                    threshold_used = clip_threshold_text_mismatch
                    match_type = "text_mismatch"
            else:
                # No text in reference - standard matching
                threshold_used = clip_threshold
                match_type = "no_text"
                text_sim = 0.0

            # Check if CLIP score meets the appropriate threshold
            if clip_score >= threshold_used:
                logo_scores.append({
                    'label': label,
                    'clip_score': clip_score,
                    'text_matched': text_matched,
                    'text_similarity': text_sim,
                    'threshold_used': threshold_used,
                    'match_type': match_type,
                    'has_ref_text': has_ref_text,
                })

        if not logo_scores:
            return None

        # Sort by CLIP score descending
        logo_scores.sort(key=lambda x: x['clip_score'], reverse=True)

        best = logo_scores[0]

        # Check margin against second-best
        if margin > 0 and len(logo_scores) > 1:
            second_best_score = logo_scores[1]['clip_score']
            if best['clip_score'] - second_best_score < margin:
                return None

        match_info = {
            'text_matched': best['text_matched'],
            'text_similarity': best['text_similarity'],
            'threshold_used': best['threshold_used'],
            'match_type': best['match_type'],
            'has_ref_text': best['has_ref_text'],
            'detected_texts': detected_texts,
        }

        return (best['label'], best['clip_score'], match_info)

    def prepare_reference_data_hybrid(
        self,
        reference_images: Dict[str, List[np.ndarray]],
        text_min_confidence: float = 0.3,
    ) -> Dict[str, Dict[str, Any]]:
        """
        Prepare reference data for hybrid matching by computing embeddings and extracting text.

        Args:
            reference_images: Dict mapping logo name to list of reference images (OpenCV BGR)
            text_min_confidence: Minimum confidence for text extraction

        Returns:
            Dict mapping logo name to {'embeddings': [...], 'texts': [...]}
        """
        reference_data = {}

        for logo_name, images in reference_images.items():
            embeddings = []
            all_texts = set()

            for img in images:
                # Compute CLIP embedding
                emb = self.get_embedding(img)
                embeddings.append(emb)

                # Extract text
                texts = self.extract_text(img, text_min_confidence)
                all_texts.update(texts)

            reference_data[logo_name] = {
                'embeddings': embeddings,
                'texts': list(all_texts),
            }

            if all_texts:
                self.logger.debug(f"Reference '{logo_name}' has text: {all_texts}")

        return reference_data