Add latest test detection method
This commit is contained in:
364
logo_detection_embeddings.py
Normal file
364
logo_detection_embeddings.py
Normal file
@ -0,0 +1,364 @@
|
|||||||
|
"""
|
||||||
|
Logo detection using DETR for object detection and selectable embedding models for feature matching.
|
||||||
|
|
||||||
|
This module provides a class for detecting logos in images using:
|
||||||
|
1. DETR (DEtection TRansformer) for initial logo region detection
|
||||||
|
2. Selectable embedding model (CLIP, DINOv2, or SigLIP) for feature extraction and matching
|
||||||
|
|
||||||
|
Key features:
|
||||||
|
- Multiple reference images per logo entry, averaged into a single embedding
|
||||||
|
- Cache-aware: averaged embeddings are only recalculated when the filenames list changes
|
||||||
|
- Supports local model directories with fallback to HuggingFace
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from PIL import Image
|
||||||
|
from transformers import (
|
||||||
|
AutoImageProcessor,
|
||||||
|
AutoModel,
|
||||||
|
AutoProcessor,
|
||||||
|
CLIPModel,
|
||||||
|
CLIPProcessor,
|
||||||
|
Dinov2Model,
|
||||||
|
pipeline,
|
||||||
|
)
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
class DetectLogosEmbeddings:
    """
    Logo detection class using DETR and a selectable embedding model.

    This class detects logos in images by:
    1. Using DETR to find potential logo regions (bounding boxes)
    2. Extracting embeddings for each detected region using the selected model
    3. Comparing embeddings with averaged reference logo embeddings for identification

    Supported embedding models:
    - clip: openai/clip-vit-large-patch14
    - dinov2: facebook/dinov2-base (recommended for visual similarity)
    - siglip: google/siglip-base-patch16-224
    """

    def __init__(
        self,
        logger,
        detr_model: str = "Pravallika6/detr-finetuned-logo-detection_v2",
        embedding_model_type: str = "dinov2",
        detr_threshold: float = 0.5,
    ):
        """
        Initialize DETR and embedding models.

        Args:
            logger: Logger instance for logging
            detr_model: HuggingFace model name or local path for DETR object detection
            embedding_model_type: One of "clip", "dinov2", or "siglip"
            detr_threshold: Confidence threshold for DETR detections (0-1)
        """
        self.logger = logger
        self.detr_threshold = detr_threshold
        self.embedding_model_type = embedding_model_type

        # Set device: transformers pipelines take an integer index (-1 = CPU),
        # while the embedding models take a torch.device — keep both forms.
        self.device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device_index = 0 if torch.cuda.is_available() else -1
        self.device = torch.device(self.device_str)

        self.logger.info(
            f"Initializing DetectLogosEmbeddings on device: {self.device_str}, "
            f"embedding model: {embedding_model_type}"
        )

        # --- DETR model ---
        default_detr_dir = os.environ.get(
            "LOGO_DETR_MODEL_DIR", "models/logo_detection/detr"
        )
        detr_model_path = self._resolve_model_path(detr_model, default_detr_dir, "DETR")

        self.logger.info(f"Loading DETR model: {detr_model_path}")
        self.detr_pipe = pipeline(
            task="object-detection",
            model=detr_model_path,
            device=self.device_index,
            use_fast=True,
        )

        # --- Embedding model ---
        self._load_embedding_model(embedding_model_type)

        self.logger.info("DetectLogosEmbeddings initialization complete")

    def _load_embedding_model(self, model_type: str) -> None:
        """
        Load the selected embedding model and install ``self._embed_fn``.

        Each branch stores the model/processor on a model-specific private
        attribute and defines an ``embed_fn(pil_image) -> torch.Tensor`` that
        returns an L2-normalized embedding of shape (1, D).

        Args:
            model_type: One of "clip", "dinov2", or "siglip"

        Raises:
            ValueError: If ``model_type`` is not a supported value.
        """
        default_embedding_dir = os.environ.get(
            "LOGO_EMBEDDING_MODEL_DIR", f"models/logo_detection/{model_type}"
        )

        if model_type == "clip":
            model_name = "openai/clip-vit-large-patch14"
            model_path = self._resolve_model_path(
                model_name, default_embedding_dir, "CLIP"
            )
            self.logger.info(f"Loading CLIP model: {model_path}")
            self._clip_model = CLIPModel.from_pretrained(model_path).to(self.device)
            self._clip_processor = CLIPProcessor.from_pretrained(model_path)
            self._clip_model.eval()

            def embed_fn(pil_image):
                inputs = self._clip_processor(
                    images=pil_image, return_tensors="pt"
                ).to(self.device)
                with torch.no_grad():
                    features = self._clip_model.get_image_features(**inputs)
                return F.normalize(features, dim=-1)

        elif model_type == "dinov2":
            model_name = "facebook/dinov2-base"
            model_path = self._resolve_model_path(
                model_name, default_embedding_dir, "DINOv2"
            )
            self.logger.info(f"Loading DINOv2 model: {model_path}")
            self._dinov2_model = Dinov2Model.from_pretrained(model_path).to(self.device)
            self._dinov2_processor = AutoImageProcessor.from_pretrained(model_path)
            self._dinov2_model.eval()

            def embed_fn(pil_image):
                inputs = self._dinov2_processor(
                    images=pil_image, return_tensors="pt"
                ).to(self.device)
                with torch.no_grad():
                    outputs = self._dinov2_model(**inputs)
                # Use CLS token embedding (DINOv2 has no pooled image-feature head)
                features = outputs.last_hidden_state[:, 0, :]
                return F.normalize(features, dim=-1)

        elif model_type == "siglip":
            model_name = "google/siglip-base-patch16-224"
            model_path = self._resolve_model_path(
                model_name, default_embedding_dir, "SigLIP"
            )
            self.logger.info(f"Loading SigLIP model: {model_path}")
            self._siglip_model = AutoModel.from_pretrained(model_path).to(self.device)
            self._siglip_processor = AutoProcessor.from_pretrained(model_path)
            self._siglip_model.eval()

            def embed_fn(pil_image):
                inputs = self._siglip_processor(
                    images=pil_image, return_tensors="pt"
                ).to(self.device)
                with torch.no_grad():
                    features = self._siglip_model.get_image_features(**inputs)
                return F.normalize(features, dim=-1)

        else:
            raise ValueError(
                f"Unknown embedding model type: {model_type}. "
                f"Use 'clip', 'dinov2', or 'siglip'"
            )

        self._embed_fn = embed_fn

    def _resolve_model_path(
        self, model_name_or_path: str, default_local_dir: str, model_type: str
    ) -> str:
        """
        Resolve model path, checking for local models before using HuggingFace.

        Args:
            model_name_or_path: HuggingFace model name or absolute path
            default_local_dir: Default local directory to check
            model_type: Type of model (for logging)

        Returns:
            Resolved model path (local path or HuggingFace model name)
        """
        # If it's an absolute path, use it directly
        if os.path.isabs(model_name_or_path):
            if os.path.exists(model_name_or_path):
                self.logger.info(
                    f"{model_type} model: Using local model at {model_name_or_path}"
                )
                return model_name_or_path
            else:
                # BUGFIX: the previous message claimed "falling back to
                # HuggingFace", but an absolute path is returned unchanged and
                # no fallback occurs — from_pretrained will fail on it. The
                # warning now states what actually happens.
                self.logger.warning(
                    f"{model_type} model: Local path {model_name_or_path} does not "
                    f"exist; returning it as-is, model loading will likely fail"
                )
                return model_name_or_path

        # Check if default local directory exists and looks like a saved model
        if os.path.exists(default_local_dir):
            config_file = os.path.join(default_local_dir, "config.json")
            if os.path.exists(config_file):
                abs_path = os.path.abspath(default_local_dir)
                self.logger.info(
                    f"{model_type} model: Found local model at {abs_path}"
                )
                return abs_path
            else:
                self.logger.warning(
                    f"{model_type} model: Local directory {default_local_dir} exists but "
                    f"is not a valid model (missing config.json)"
                )

        # Use HuggingFace model name
        self.logger.info(
            f"{model_type} model: No local model found, will download from HuggingFace: "
            f"{model_name_or_path}"
        )
        return model_name_or_path

    def detect(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Detect logos in an image and return bounding boxes with embeddings.

        Args:
            image: OpenCV image (BGR format, numpy array)

        Returns:
            List of dictionaries, each containing:
            - 'box': dict with 'xmin', 'ymin', 'xmax', 'ymax' (pixel coordinates)
            - 'score': DETR confidence score (float 0-1)
            - 'embedding': Feature embedding (torch.Tensor)
            - 'label': DETR predicted label (string)
        """
        # Convert OpenCV BGR to RGB PIL Image (DETR/PIL expect RGB)
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(image_rgb)

        # Run DETR detection
        predictions = self.detr_pipe(pil_image)

        # Filter by threshold and add embeddings
        detections = []
        for pred in predictions:
            score = pred.get("score", 0.0)
            if score < self.detr_threshold:
                continue

            box = pred.get("box", {})
            xmin = box.get("xmin", 0)
            ymin = box.get("ymin", 0)
            xmax = box.get("xmax", 0)
            ymax = box.get("ymax", 0)

            # Extract bounding box region
            bbox_crop = pil_image.crop((xmin, ymin, xmax, ymax))

            # Get embedding for this region
            embedding = self._embed_fn(bbox_crop)

            detections.append(
                {
                    "box": {"xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax},
                    "score": score,
                    "embedding": embedding,
                    "label": pred.get("label", "logo"),
                }
            )

        self.logger.debug(
            f"Detected {len(detections)} logos (threshold: {self.detr_threshold})"
        )
        return detections

    def get_embedding(self, image: np.ndarray) -> torch.Tensor:
        """
        Get embedding for a single reference logo image.

        Args:
            image: OpenCV image (BGR format, numpy array)

        Returns:
            Normalized feature embedding (torch.Tensor)
        """
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(image_rgb)
        return self._embed_fn(pil_image)

    def get_averaged_embedding(self, images: List[np.ndarray]) -> Optional[torch.Tensor]:
        """
        Compute averaged embedding from multiple reference logo images.

        Follows the averaging pattern from db_embeddings.py:
        1. Compute embedding for each image
        2. Stack and average across all images
        3. Re-normalize the averaged embedding

        Images that fail to embed are skipped with a warning rather than
        aborting the whole batch.

        Args:
            images: List of OpenCV images (BGR format, numpy arrays)

        Returns:
            Normalized averaged embedding (torch.Tensor, shape [1, D]),
            or None if no valid embeddings could be computed
        """
        embeddings = []
        for img in images:
            try:
                emb = self.get_embedding(img)
                embeddings.append(emb)
            except Exception as e:
                self.logger.warning(f"Failed to compute embedding for reference image: {e}")

        if not embeddings:
            return None

        # Stack: (N, D), average: (1, D), re-normalize
        stacked = torch.cat(embeddings, dim=0)
        avg_emb = stacked.mean(dim=0, keepdim=True)
        avg_emb = F.normalize(avg_emb, dim=-1)

        self.logger.debug(
            f"Computed averaged embedding from {len(embeddings)} reference image(s)"
        )
        return avg_emb

    def compare_embeddings(
        self, embedding1: torch.Tensor, embedding2: torch.Tensor
    ) -> float:
        """
        Compute cosine similarity between two embeddings.

        Args:
            embedding1: First embedding (torch.Tensor)
            embedding2: Second embedding (torch.Tensor)

        Returns:
            Cosine similarity score (float, range: -1 to 1, typically 0 to 1)
        """
        # Ensure tensors are on the same device
        if embedding1.device != embedding2.device:
            embedding2 = embedding2.to(embedding1.device)

        similarity = F.cosine_similarity(embedding1, embedding2, dim=-1)
        return similarity.item()

    @staticmethod
    def make_filenames_hash(filenames: List[str]) -> str:
        """
        Compute a deterministic hash of a filenames list.

        Used for cache invalidation — if the filenames list changes,
        the hash changes, triggering re-computation of averaged embeddings.
        The list is sorted first, so the hash is order-independent.

        Args:
            filenames: List of filename strings

        Returns:
            16-character hex hash string
        """
        canonical = json.dumps(sorted(filenames))
        return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16]
|
||||||
Reference in New Issue
Block a user