Add latest test detection method
This commit is contained in:
logo_detection_embeddings.py · 364 lines · Normal file

@@ -0,0 +1,364 @@
"""
Logo detection using DETR for object detection and selectable embedding models for feature matching.

This module provides a class for detecting logos in images using:
1. DETR (DEtection TRansformer) for initial logo region detection
2. Selectable embedding model (CLIP, DINOv2, or SigLIP) for feature extraction and matching

Key features:
- Multiple reference images per logo entry, averaged into a single embedding
- Cache-aware: averaged embeddings are only recalculated when the filenames list changes
- Supports local model directories with fallback to HuggingFace
"""

import hashlib
import json
import os
from typing import Any, Dict, List, Optional

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import (
    AutoImageProcessor,
    AutoModel,
    AutoProcessor,
    CLIPModel,
    CLIPProcessor,
    Dinov2Model,
    pipeline,
)


class DetectLogosEmbeddings:
    """
    Logo detection class using DETR and a selectable embedding model.

    This class detects logos in images by:
    1. Using DETR to find potential logo regions (bounding boxes)
    2. Extracting embeddings for each detected region using the selected model
    3. Comparing embeddings with averaged reference logo embeddings for identification

    Supported embedding models:
    - clip: openai/clip-vit-large-patch14
    - dinov2: facebook/dinov2-base (recommended for visual similarity)
    - siglip: google/siglip-base-patch16-224
    """

    def __init__(
        self,
        logger,
        detr_model: str = "Pravallika6/detr-finetuned-logo-detection_v2",
        embedding_model_type: str = "dinov2",
        detr_threshold: float = 0.5,
    ):
        """
        Initialize DETR and embedding models.

        Args:
            logger: Logger instance for logging
            detr_model: HuggingFace model name or local path for DETR object detection
            embedding_model_type: One of "clip", "dinov2", or "siglip"
            detr_threshold: Confidence threshold for DETR detections (0-1)
        """
        self.logger = logger
        self.detr_threshold = detr_threshold
        self.embedding_model_type = embedding_model_type

        # Set device
        self.device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device_index = 0 if torch.cuda.is_available() else -1
        self.device = torch.device(self.device_str)

        self.logger.info(
            f"Initializing DetectLogosEmbeddings on device: {self.device_str}, "
            f"embedding model: {embedding_model_type}"
        )

        # --- DETR model ---
        default_detr_dir = os.environ.get(
            "LOGO_DETR_MODEL_DIR", "models/logo_detection/detr"
        )
        detr_model_path = self._resolve_model_path(detr_model, default_detr_dir, "DETR")

        self.logger.info(f"Loading DETR model: {detr_model_path}")
        self.detr_pipe = pipeline(
            task="object-detection",
            model=detr_model_path,
            device=self.device_index,
            use_fast=True,
        )

        # --- Embedding model ---
        self._load_embedding_model(embedding_model_type)

        self.logger.info("DetectLogosEmbeddings initialization complete")

    def _load_embedding_model(self, model_type: str) -> None:
        """
        Load the selected embedding model.

        Args:
            model_type: One of "clip", "dinov2", or "siglip"
        """
        default_embedding_dir = os.environ.get(
            "LOGO_EMBEDDING_MODEL_DIR", f"models/logo_detection/{model_type}"
        )

        if model_type == "clip":
            model_name = "openai/clip-vit-large-patch14"
            model_path = self._resolve_model_path(
                model_name, default_embedding_dir, "CLIP"
            )
            self.logger.info(f"Loading CLIP model: {model_path}")
            self._clip_model = CLIPModel.from_pretrained(model_path).to(self.device)
            self._clip_processor = CLIPProcessor.from_pretrained(model_path)
            self._clip_model.eval()

            def embed_fn(pil_image):
                inputs = self._clip_processor(
                    images=pil_image, return_tensors="pt"
                ).to(self.device)
                with torch.no_grad():
                    features = self._clip_model.get_image_features(**inputs)
                return F.normalize(features, dim=-1)

        elif model_type == "dinov2":
            model_name = "facebook/dinov2-base"
            model_path = self._resolve_model_path(
                model_name, default_embedding_dir, "DINOv2"
            )
            self.logger.info(f"Loading DINOv2 model: {model_path}")
            self._dinov2_model = Dinov2Model.from_pretrained(model_path).to(self.device)
            self._dinov2_processor = AutoImageProcessor.from_pretrained(model_path)
            self._dinov2_model.eval()

            def embed_fn(pil_image):
                inputs = self._dinov2_processor(
                    images=pil_image, return_tensors="pt"
                ).to(self.device)
                with torch.no_grad():
                    outputs = self._dinov2_model(**inputs)
                # Use CLS token embedding
                features = outputs.last_hidden_state[:, 0, :]
                return F.normalize(features, dim=-1)

        elif model_type == "siglip":
            model_name = "google/siglip-base-patch16-224"
            model_path = self._resolve_model_path(
                model_name, default_embedding_dir, "SigLIP"
            )
            self.logger.info(f"Loading SigLIP model: {model_path}")
            self._siglip_model = AutoModel.from_pretrained(model_path).to(self.device)
            self._siglip_processor = AutoProcessor.from_pretrained(model_path)
            self._siglip_model.eval()

            def embed_fn(pil_image):
                inputs = self._siglip_processor(
                    images=pil_image, return_tensors="pt"
                ).to(self.device)
                with torch.no_grad():
                    features = self._siglip_model.get_image_features(**inputs)
                return F.normalize(features, dim=-1)

        else:
            raise ValueError(
                f"Unknown embedding model type: {model_type}. "
                f"Use 'clip', 'dinov2', or 'siglip'"
            )

        self._embed_fn = embed_fn
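
    # Every branch above installs the same contract: `self._embed_fn` maps a
    # PIL image to an L2-normalized tensor of shape [1, D] (D depends on the
    # model, e.g. 768 for facebook/dinov2-base). Illustrative sketch:
    #
    #   emb = self._embed_fn(Image.new("RGB", (224, 224)))
    #   assert emb.shape[0] == 1
    #   assert torch.allclose(emb.norm(dim=-1).cpu(), torch.ones(1), atol=1e-4)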

    def _resolve_model_path(
        self, model_name_or_path: str, default_local_dir: str, model_type: str
    ) -> str:
        """
        Resolve model path, checking for local models before using HuggingFace.

        Args:
            model_name_or_path: HuggingFace model name or absolute path
            default_local_dir: Default local directory to check
            model_type: Type of model (for logging)

        Returns:
            Resolved model path (local path or HuggingFace model name)
        """
        # If it's an absolute path, use it directly
        if os.path.isabs(model_name_or_path):
            if os.path.exists(model_name_or_path):
                self.logger.info(
                    f"{model_type} model: Using local model at {model_name_or_path}"
                )
                return model_name_or_path
            else:
                self.logger.warning(
                    f"{model_type} model: Local path {model_name_or_path} does not exist; "
                    f"returning it unchanged (loading will likely fail)"
                )
                return model_name_or_path

        # Check if default local directory exists
        if os.path.exists(default_local_dir):
            config_file = os.path.join(default_local_dir, "config.json")
            if os.path.exists(config_file):
                abs_path = os.path.abspath(default_local_dir)
                self.logger.info(
                    f"{model_type} model: Found local model at {abs_path}"
                )
                return abs_path
            else:
                self.logger.warning(
                    f"{model_type} model: Local directory {default_local_dir} exists but "
                    f"is not a valid model (missing config.json)"
                )

        # Use HuggingFace model name
        self.logger.info(
            f"{model_type} model: No local model found, will download from HuggingFace: "
            f"{model_name_or_path}"
        )
        return model_name_or_path
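
    # Resolution order, sketched (the /opt path is a hypothetical example):
    # absolute paths are returned as-is, then the default local directory wins
    # when it contains a config.json, otherwise the HuggingFace name is used.
    #
    #   self._resolve_model_path("/opt/models/detr", "models/logo_detection/detr", "DETR")
    #       -> "/opt/models/detr" (whether or not it exists on disk)
    #   self._resolve_model_path("facebook/dinov2-base", "models/logo_detection/dinov2", "DINOv2")
    #       -> absolute path of "models/logo_detection/dinov2" if it holds a config.json,
    #          otherwise "facebook/dinov2-base"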

    def detect(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Detect logos in an image and return bounding boxes with embeddings.

        Args:
            image: OpenCV image (BGR format, numpy array)

        Returns:
            List of dictionaries, each containing:
            - 'box': dict with 'xmin', 'ymin', 'xmax', 'ymax' (pixel coordinates)
            - 'score': DETR confidence score (float 0-1)
            - 'embedding': Feature embedding (torch.Tensor)
            - 'label': DETR predicted label (string)
        """
        # Convert OpenCV BGR to RGB PIL Image
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(image_rgb)

        # Run DETR detection
        predictions = self.detr_pipe(pil_image)

        # Filter by threshold and add embeddings
        detections = []
        for pred in predictions:
            score = pred.get("score", 0.0)
            if score < self.detr_threshold:
                continue

            box = pred.get("box", {})
            xmin = box.get("xmin", 0)
            ymin = box.get("ymin", 0)
            xmax = box.get("xmax", 0)
            ymax = box.get("ymax", 0)

            # Extract bounding box region
            bbox_crop = pil_image.crop((xmin, ymin, xmax, ymax))

            # Get embedding for this region
            embedding = self._embed_fn(bbox_crop)

            detections.append(
                {
                    "box": {"xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax},
                    "score": score,
                    "embedding": embedding,
                    "label": pred.get("label", "logo"),
                }
            )

        self.logger.debug(
            f"Detected {len(detections)} logos (threshold: {self.detr_threshold})"
        )
        return detections
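
    # Consuming detect() output, as a sketch: `frame` is any BGR image and
    # `ref_emb` a reference embedding; the 0.8 threshold is illustrative only.
    #
    #   for det in detector.detect(frame):
    #       sim = detector.compare_embeddings(det["embedding"], ref_emb)
    #       if sim > 0.8:
    #           box = det["box"]
    #           print(box["xmin"], box["ymin"], box["xmax"], box["ymax"], sim)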

    def get_embedding(self, image: np.ndarray) -> torch.Tensor:
        """
        Get embedding for a single reference logo image.

        Args:
            image: OpenCV image (BGR format, numpy array)

        Returns:
            Normalized feature embedding (torch.Tensor)
        """
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(image_rgb)
        return self._embed_fn(pil_image)

    def get_averaged_embedding(self, images: List[np.ndarray]) -> Optional[torch.Tensor]:
        """
        Compute averaged embedding from multiple reference logo images.

        Follows the averaging pattern from db_embeddings.py:
        1. Compute embedding for each image
        2. Stack and average across all images
        3. Re-normalize the averaged embedding

        Args:
            images: List of OpenCV images (BGR format, numpy arrays)

        Returns:
            Normalized averaged embedding (torch.Tensor, shape [1, D]),
            or None if no valid embeddings could be computed
        """
        embeddings = []
        for img in images:
            try:
                emb = self.get_embedding(img)
                embeddings.append(emb)
            except Exception as e:
                self.logger.warning(f"Failed to compute embedding for reference image: {e}")

        if not embeddings:
            return None

        # Stack: (N, D), average: (1, D), re-normalize
        stacked = torch.cat(embeddings, dim=0)
        avg_emb = stacked.mean(dim=0, keepdim=True)
        avg_emb = F.normalize(avg_emb, dim=-1)

        self.logger.debug(
            f"Computed averaged embedding from {len(embeddings)} reference image(s)"
        )
        return avg_emb
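
    # Averaging sketch: several crops of the same logo collapse into one [1, D]
    # prototype; the file paths are placeholders.
    #
    #   crops = [cv2.imread(p) for p in ["acme_ref1.png", "acme_ref2.png"]]
    #   proto = detector.get_averaged_embedding([c for c in crops if c is not None])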

    def compare_embeddings(
        self, embedding1: torch.Tensor, embedding2: torch.Tensor
    ) -> float:
        """
        Compute cosine similarity between two embeddings.

        Args:
            embedding1: First embedding (torch.Tensor)
            embedding2: Second embedding (torch.Tensor)

        Returns:
            Cosine similarity score (float, range: -1 to 1, typically 0 to 1)
        """
        # Ensure tensors are on the same device
        if embedding1.device != embedding2.device:
            embedding2 = embedding2.to(embedding1.device)

        similarity = F.cosine_similarity(embedding1, embedding2, dim=-1)
        return similarity.item()
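
    # Since both inputs are already L2-normalized, this cosine similarity is
    # just their dot product, and an embedding compared with itself scores ~1.0:
    #
    #   emb = detector.get_embedding(crop)  # `crop` is any BGR image
    #   assert abs(detector.compare_embeddings(emb, emb) - 1.0) < 1e-5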

    @staticmethod
    def make_filenames_hash(filenames: List[str]) -> str:
        """
        Compute a deterministic hash of a filenames list.

        Used for cache invalidation: if the filenames list changes,
        the hash changes, triggering re-computation of averaged embeddings.

        Args:
            filenames: List of filename strings

        Returns:
            16-character hex hash string
        """
        canonical = json.dumps(sorted(filenames))
        return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16]
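

if __name__ == "__main__":
    # Minimal end-to-end sketch under stated assumptions: "reference_logo.png"
    # and "frame.png" are placeholder paths, and nothing below is prescribed by
    # the class itself.
    import logging

    logging.basicConfig(level=logging.INFO)
    detector = DetectLogosEmbeddings(logging.getLogger("logo_detection"))

    reference = cv2.imread("reference_logo.png")
    frame = cv2.imread("frame.png")
    if reference is None or frame is None:
        raise SystemExit("Provide reference_logo.png and frame.png to run this demo")

    ref_emb = detector.get_averaged_embedding([reference])
    for det in detector.detect(frame):
        similarity = detector.compare_embeddings(det["embedding"], ref_emb)
        print(det["box"], f"score={det['score']:.2f}", f"similarity={similarity:.2f}")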