Add latest test detection method

This commit is contained in:
Rick McEwen
2026-03-31 11:51:26 -06:00
parent f598866d37
commit 512f678310

View File

@ -0,0 +1,364 @@
"""
Logo detection using DETR for object detection and selectable embedding models for feature matching.
This module provides a class for detecting logos in images using:
1. DETR (DEtection TRansformer) for initial logo region detection
2. Selectable embedding model (CLIP, DINOv2, or SigLIP) for feature extraction and matching
Key features:
- Multiple reference images per logo entry, averaged into a single embedding
- Cache-aware: averaged embeddings are only recalculated when the filenames list changes
- Supports local model directories with fallback to HuggingFace
"""
import hashlib
import json
import os
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import (
AutoImageProcessor,
AutoModel,
AutoProcessor,
CLIPModel,
CLIPProcessor,
Dinov2Model,
pipeline,
)
from typing import Any, Dict, List, Optional, Tuple
class DetectLogosEmbeddings:
"""
Logo detection class using DETR and a selectable embedding model.
This class detects logos in images by:
1. Using DETR to find potential logo regions (bounding boxes)
2. Extracting embeddings for each detected region using the selected model
3. Comparing embeddings with averaged reference logo embeddings for identification
Supported embedding models:
- clip: openai/clip-vit-large-patch14
- dinov2: facebook/dinov2-base (recommended for visual similarity)
- siglip: google/siglip-base-patch16-224
"""
def __init__(
self,
logger,
detr_model: str = "Pravallika6/detr-finetuned-logo-detection_v2",
embedding_model_type: str = "dinov2",
detr_threshold: float = 0.5,
):
"""
Initialize DETR and embedding models.
Args:
logger: Logger instance for logging
detr_model: HuggingFace model name or local path for DETR object detection
embedding_model_type: One of "clip", "dinov2", or "siglip"
detr_threshold: Confidence threshold for DETR detections (0-1)
"""
self.logger = logger
self.detr_threshold = detr_threshold
self.embedding_model_type = embedding_model_type
# Set device
self.device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
self.device_index = 0 if torch.cuda.is_available() else -1
self.device = torch.device(self.device_str)
self.logger.info(
f"Initializing DetectLogosEmbeddings on device: {self.device_str}, "
f"embedding model: {embedding_model_type}"
)
# --- DETR model ---
default_detr_dir = os.environ.get(
"LOGO_DETR_MODEL_DIR", "models/logo_detection/detr"
)
detr_model_path = self._resolve_model_path(detr_model, default_detr_dir, "DETR")
self.logger.info(f"Loading DETR model: {detr_model_path}")
self.detr_pipe = pipeline(
task="object-detection",
model=detr_model_path,
device=self.device_index,
use_fast=True,
)
# --- Embedding model ---
self._load_embedding_model(embedding_model_type)
self.logger.info("DetectLogosEmbeddings initialization complete")
def _load_embedding_model(self, model_type: str) -> None:
"""
Load the selected embedding model.
Args:
model_type: One of "clip", "dinov2", or "siglip"
"""
default_embedding_dir = os.environ.get(
"LOGO_EMBEDDING_MODEL_DIR", f"models/logo_detection/{model_type}"
)
if model_type == "clip":
model_name = "openai/clip-vit-large-patch14"
model_path = self._resolve_model_path(
model_name, default_embedding_dir, "CLIP"
)
self.logger.info(f"Loading CLIP model: {model_path}")
self._clip_model = CLIPModel.from_pretrained(model_path).to(self.device)
self._clip_processor = CLIPProcessor.from_pretrained(model_path)
self._clip_model.eval()
def embed_fn(pil_image):
inputs = self._clip_processor(
images=pil_image, return_tensors="pt"
).to(self.device)
with torch.no_grad():
features = self._clip_model.get_image_features(**inputs)
return F.normalize(features, dim=-1)
elif model_type == "dinov2":
model_name = "facebook/dinov2-base"
model_path = self._resolve_model_path(
model_name, default_embedding_dir, "DINOv2"
)
self.logger.info(f"Loading DINOv2 model: {model_path}")
self._dinov2_model = Dinov2Model.from_pretrained(model_path).to(self.device)
self._dinov2_processor = AutoImageProcessor.from_pretrained(model_path)
self._dinov2_model.eval()
def embed_fn(pil_image):
inputs = self._dinov2_processor(
images=pil_image, return_tensors="pt"
).to(self.device)
with torch.no_grad():
outputs = self._dinov2_model(**inputs)
# Use CLS token embedding
features = outputs.last_hidden_state[:, 0, :]
return F.normalize(features, dim=-1)
elif model_type == "siglip":
model_name = "google/siglip-base-patch16-224"
model_path = self._resolve_model_path(
model_name, default_embedding_dir, "SigLIP"
)
self.logger.info(f"Loading SigLIP model: {model_path}")
self._siglip_model = AutoModel.from_pretrained(model_path).to(self.device)
self._siglip_processor = AutoProcessor.from_pretrained(model_path)
self._siglip_model.eval()
def embed_fn(pil_image):
inputs = self._siglip_processor(
images=pil_image, return_tensors="pt"
).to(self.device)
with torch.no_grad():
features = self._siglip_model.get_image_features(**inputs)
return F.normalize(features, dim=-1)
else:
raise ValueError(
f"Unknown embedding model type: {model_type}. "
f"Use 'clip', 'dinov2', or 'siglip'"
)
self._embed_fn = embed_fn
def _resolve_model_path(
self, model_name_or_path: str, default_local_dir: str, model_type: str
) -> str:
"""
Resolve model path, checking for local models before using HuggingFace.
Args:
model_name_or_path: HuggingFace model name or absolute path
default_local_dir: Default local directory to check
model_type: Type of model (for logging)
Returns:
Resolved model path (local path or HuggingFace model name)
"""
# If it's an absolute path, use it directly
if os.path.isabs(model_name_or_path):
if os.path.exists(model_name_or_path):
self.logger.info(
f"{model_type} model: Using local model at {model_name_or_path}"
)
return model_name_or_path
else:
self.logger.warning(
f"{model_type} model: Local path {model_name_or_path} does not exist, "
f"falling back to HuggingFace"
)
return model_name_or_path
# Check if default local directory exists
if os.path.exists(default_local_dir):
config_file = os.path.join(default_local_dir, "config.json")
if os.path.exists(config_file):
abs_path = os.path.abspath(default_local_dir)
self.logger.info(
f"{model_type} model: Found local model at {abs_path}"
)
return abs_path
else:
self.logger.warning(
f"{model_type} model: Local directory {default_local_dir} exists but "
f"is not a valid model (missing config.json)"
)
# Use HuggingFace model name
self.logger.info(
f"{model_type} model: No local model found, will download from HuggingFace: "
f"{model_name_or_path}"
)
return model_name_or_path
def detect(self, image: np.ndarray) -> List[Dict[str, Any]]:
"""
Detect logos in an image and return bounding boxes with embeddings.
Args:
image: OpenCV image (BGR format, numpy array)
Returns:
List of dictionaries, each containing:
- 'box': dict with 'xmin', 'ymin', 'xmax', 'ymax' (pixel coordinates)
- 'score': DETR confidence score (float 0-1)
- 'embedding': Feature embedding (torch.Tensor)
- 'label': DETR predicted label (string)
"""
# Convert OpenCV BGR to RGB PIL Image
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(image_rgb)
# Run DETR detection
predictions = self.detr_pipe(pil_image)
# Filter by threshold and add embeddings
detections = []
for pred in predictions:
score = pred.get("score", 0.0)
if score < self.detr_threshold:
continue
box = pred.get("box", {})
xmin = box.get("xmin", 0)
ymin = box.get("ymin", 0)
xmax = box.get("xmax", 0)
ymax = box.get("ymax", 0)
# Extract bounding box region
bbox_crop = pil_image.crop((xmin, ymin, xmax, ymax))
# Get embedding for this region
embedding = self._embed_fn(bbox_crop)
detections.append(
{
"box": {"xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax},
"score": score,
"embedding": embedding,
"label": pred.get("label", "logo"),
}
)
self.logger.debug(
f"Detected {len(detections)} logos (threshold: {self.detr_threshold})"
)
return detections
def get_embedding(self, image: np.ndarray) -> torch.Tensor:
"""
Get embedding for a single reference logo image.
Args:
image: OpenCV image (BGR format, numpy array)
Returns:
Normalized feature embedding (torch.Tensor)
"""
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(image_rgb)
return self._embed_fn(pil_image)
def get_averaged_embedding(self, images: List[np.ndarray]) -> Optional[torch.Tensor]:
"""
Compute averaged embedding from multiple reference logo images.
Follows the averaging pattern from db_embeddings.py:
1. Compute embedding for each image
2. Stack and average across all images
3. Re-normalize the averaged embedding
Args:
images: List of OpenCV images (BGR format, numpy arrays)
Returns:
Normalized averaged embedding (torch.Tensor, shape [1, D]),
or None if no valid embeddings could be computed
"""
embeddings = []
for img in images:
try:
emb = self.get_embedding(img)
embeddings.append(emb)
except Exception as e:
self.logger.warning(f"Failed to compute embedding for reference image: {e}")
if not embeddings:
return None
# Stack: (N, D), average: (1, D), re-normalize
stacked = torch.cat(embeddings, dim=0)
avg_emb = stacked.mean(dim=0, keepdim=True)
avg_emb = F.normalize(avg_emb, dim=-1)
self.logger.debug(
f"Computed averaged embedding from {len(embeddings)} reference image(s)"
)
return avg_emb
def compare_embeddings(
self, embedding1: torch.Tensor, embedding2: torch.Tensor
) -> float:
"""
Compute cosine similarity between two embeddings.
Args:
embedding1: First embedding (torch.Tensor)
embedding2: Second embedding (torch.Tensor)
Returns:
Cosine similarity score (float, range: -1 to 1, typically 0 to 1)
"""
# Ensure tensors are on the same device
if embedding1.device != embedding2.device:
embedding2 = embedding2.to(embedding1.device)
similarity = F.cosine_similarity(embedding1, embedding2, dim=-1)
return similarity.item()
@staticmethod
def make_filenames_hash(filenames: List[str]) -> str:
"""
Compute a deterministic hash of a filenames list.
Used for cache invalidation — if the filenames list changes,
the hash changes, triggering re-computation of averaged embeddings.
Args:
filenames: List of filename strings
Returns:
16-character hex hash string
"""
canonical = json.dumps(sorted(filenames))
return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16]