"""
Logo detection using DETR for object detection and selectable embedding models
for feature matching.

This module provides a class for detecting logos in images using:
1. DETR (DEtection TRansformer) for initial logo region detection
2. Selectable embedding model (CLIP, DINOv2, or SigLIP) for feature
   extraction and matching

Key features:
- Multiple reference images per logo entry, averaged into a single embedding
- Cache-aware: averaged embeddings are only recalculated when the filenames
  list changes
- Supports local model directories with fallback to HuggingFace
"""

import hashlib
import json
import os
from typing import Any, Dict, List, Optional

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import (
    AutoImageProcessor,
    AutoModel,
    AutoProcessor,
    CLIPModel,
    CLIPProcessor,
    Dinov2Model,
    pipeline,
)


class DetectLogosEmbeddings:
    """
    Logo detection class using DETR and a selectable embedding model.

    This class detects logos in images by:
    1. Using DETR to find potential logo regions (bounding boxes)
    2. Extracting embeddings for each detected region using the selected model
    3. Comparing embeddings with averaged reference logo embeddings for
       identification

    Supported embedding models:
    - clip: openai/clip-vit-large-patch14
    - dinov2: facebook/dinov2-base (recommended for visual similarity)
    - siglip: google/siglip-base-patch16-224
    """

    def __init__(
        self,
        logger,
        detr_model: str = "Pravallika6/detr-finetuned-logo-detection_v2",
        embedding_model_type: str = "dinov2",
        detr_threshold: float = 0.5,
    ):
        """
        Initialize DETR and embedding models.

        Args:
            logger: Logger instance for logging
            detr_model: HuggingFace model name or local path for DETR object
                detection
            embedding_model_type: One of "clip", "dinov2", or "siglip"
            detr_threshold: Confidence threshold for DETR detections (0-1)
        """
        self.logger = logger
        self.detr_threshold = detr_threshold
        self.embedding_model_type = embedding_model_type

        # Resolve the compute device once; the pipeline API wants an integer
        # index (-1 for CPU) while the embedding models want a torch.device.
        cuda_available = torch.cuda.is_available()
        self.device_str = "cuda:0" if cuda_available else "cpu"
        self.device_index = 0 if cuda_available else -1
        self.device = torch.device(self.device_str)

        self.logger.info(
            f"Initializing DetectLogosEmbeddings on device: {self.device_str}, "
            f"embedding model: {embedding_model_type}"
        )

        # --- DETR model ---
        default_detr_dir = os.environ.get(
            "LOGO_DETR_MODEL_DIR", "models/logo_detection/detr"
        )
        detr_model_path = self._resolve_model_path(detr_model, default_detr_dir, "DETR")

        self.logger.info(f"Loading DETR model: {detr_model_path}")
        self.detr_pipe = pipeline(
            task="object-detection",
            model=detr_model_path,
            device=self.device_index,
            use_fast=True,
        )

        # --- Embedding model ---
        # Raises ValueError on an unknown embedding_model_type, so bad
        # configuration fails here rather than at first detection.
        self._load_embedding_model(embedding_model_type)

        self.logger.info("DetectLogosEmbeddings initialization complete")

    def _load_embedding_model(self, model_type: str) -> None:
        """
        Load the selected embedding model and bind ``self._embed_fn``.

        ``self._embed_fn`` maps a PIL image to an L2-normalized embedding
        tensor of shape (1, D) on ``self.device``.

        Args:
            model_type: One of "clip", "dinov2", or "siglip"

        Raises:
            ValueError: If ``model_type`` is not a supported value.
        """
        default_embedding_dir = os.environ.get(
            "LOGO_EMBEDDING_MODEL_DIR", f"models/logo_detection/{model_type}"
        )

        if model_type == "clip":
            model_name = "openai/clip-vit-large-patch14"
            model_path = self._resolve_model_path(
                model_name, default_embedding_dir, "CLIP"
            )
            self.logger.info(f"Loading CLIP model: {model_path}")
            self._clip_model = CLIPModel.from_pretrained(model_path).to(self.device)
            self._clip_processor = CLIPProcessor.from_pretrained(model_path)
            self._clip_model.eval()

            def embed_fn(pil_image):
                inputs = self._clip_processor(
                    images=pil_image, return_tensors="pt"
                ).to(self.device)
                with torch.no_grad():
                    features = self._clip_model.get_image_features(**inputs)
                return F.normalize(features, dim=-1)

        elif model_type == "dinov2":
            model_name = "facebook/dinov2-base"
            model_path = self._resolve_model_path(
                model_name, default_embedding_dir, "DINOv2"
            )
            self.logger.info(f"Loading DINOv2 model: {model_path}")
            self._dinov2_model = Dinov2Model.from_pretrained(model_path).to(self.device)
            self._dinov2_processor = AutoImageProcessor.from_pretrained(model_path)
            self._dinov2_model.eval()

            def embed_fn(pil_image):
                inputs = self._dinov2_processor(
                    images=pil_image, return_tensors="pt"
                ).to(self.device)
                with torch.no_grad():
                    outputs = self._dinov2_model(**inputs)
                # Use CLS token embedding as the global image descriptor
                features = outputs.last_hidden_state[:, 0, :]
                return F.normalize(features, dim=-1)

        elif model_type == "siglip":
            model_name = "google/siglip-base-patch16-224"
            model_path = self._resolve_model_path(
                model_name, default_embedding_dir, "SigLIP"
            )
            self.logger.info(f"Loading SigLIP model: {model_path}")
            self._siglip_model = AutoModel.from_pretrained(model_path).to(self.device)
            self._siglip_processor = AutoProcessor.from_pretrained(model_path)
            self._siglip_model.eval()

            def embed_fn(pil_image):
                inputs = self._siglip_processor(
                    images=pil_image, return_tensors="pt"
                ).to(self.device)
                with torch.no_grad():
                    features = self._siglip_model.get_image_features(**inputs)
                return F.normalize(features, dim=-1)

        else:
            raise ValueError(
                f"Unknown embedding model type: {model_type}. "
                f"Use 'clip', 'dinov2', or 'siglip'"
            )

        self._embed_fn = embed_fn

    def _resolve_model_path(
        self, model_name_or_path: str, default_local_dir: str, model_type: str
    ) -> str:
        """
        Resolve model path, checking for local models before using HuggingFace.

        Args:
            model_name_or_path: HuggingFace model name or absolute path
            default_local_dir: Default local directory to check
            model_type: Type of model (for logging)

        Returns:
            Resolved model path (local path or HuggingFace model name)
        """
        # If it's an absolute path, use it directly.  There is no HuggingFace
        # name to fall back to in this branch, so a missing path is returned
        # as-is and will fail at from_pretrained — warn accordingly.
        if os.path.isabs(model_name_or_path):
            if os.path.exists(model_name_or_path):
                self.logger.info(
                    f"{model_type} model: Using local model at {model_name_or_path}"
                )
            else:
                self.logger.warning(
                    f"{model_type} model: Local path {model_name_or_path} does not "
                    f"exist; loading will likely fail (no HuggingFace fallback "
                    f"name available for an absolute path)"
                )
            return model_name_or_path

        # Check if default local directory exists and looks like a saved model
        if os.path.exists(default_local_dir):
            config_file = os.path.join(default_local_dir, "config.json")
            if os.path.exists(config_file):
                abs_path = os.path.abspath(default_local_dir)
                self.logger.info(
                    f"{model_type} model: Found local model at {abs_path}"
                )
                return abs_path
            else:
                self.logger.warning(
                    f"{model_type} model: Local directory {default_local_dir} exists but "
                    f"is not a valid model (missing config.json)"
                )

        # Use HuggingFace model name
        self.logger.info(
            f"{model_type} model: No local model found, will download from HuggingFace: "
            f"{model_name_or_path}"
        )
        return model_name_or_path

    @staticmethod
    def _to_pil_rgb(image: np.ndarray) -> Image.Image:
        """Convert an OpenCV BGR image (numpy array) to an RGB PIL Image."""
        return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    def detect(self, image: np.ndarray) -> List[Dict[str, Any]]:
        """
        Detect logos in an image and return bounding boxes with embeddings.

        Args:
            image: OpenCV image (BGR format, numpy array)

        Returns:
            List of dictionaries, each containing:
            - 'box': dict with 'xmin', 'ymin', 'xmax', 'ymax' (pixel coordinates)
            - 'score': DETR confidence score (float 0-1)
            - 'embedding': Feature embedding (torch.Tensor)
            - 'label': DETR predicted label (string)
        """
        pil_image = self._to_pil_rgb(image)

        # Run DETR detection
        predictions = self.detr_pipe(pil_image)

        # Filter by threshold and add embeddings
        detections = []
        for pred in predictions:
            score = pred.get("score", 0.0)
            if score < self.detr_threshold:
                continue

            box = pred.get("box", {})
            xmin = box.get("xmin", 0)
            ymin = box.get("ymin", 0)
            xmax = box.get("xmax", 0)
            ymax = box.get("ymax", 0)

            # Extract bounding box region and embed it
            bbox_crop = pil_image.crop((xmin, ymin, xmax, ymax))
            embedding = self._embed_fn(bbox_crop)

            detections.append(
                {
                    "box": {"xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax},
                    "score": score,
                    "embedding": embedding,
                    "label": pred.get("label", "logo"),
                }
            )

        self.logger.debug(
            f"Detected {len(detections)} logos (threshold: {self.detr_threshold})"
        )
        return detections

    def get_embedding(self, image: np.ndarray) -> torch.Tensor:
        """
        Get embedding for a single reference logo image.

        Args:
            image: OpenCV image (BGR format, numpy array)

        Returns:
            Normalized feature embedding (torch.Tensor)
        """
        return self._embed_fn(self._to_pil_rgb(image))

    def get_averaged_embedding(self, images: List[np.ndarray]) -> Optional[torch.Tensor]:
        """
        Compute averaged embedding from multiple reference logo images.

        Follows the averaging pattern from db_embeddings.py:
        1. Compute embedding for each image
        2. Stack and average across all images
        3. Re-normalize the averaged embedding

        Args:
            images: List of OpenCV images (BGR format, numpy arrays)

        Returns:
            Normalized averaged embedding (torch.Tensor, shape [1, D]),
            or None if no valid embeddings could be computed
        """
        embeddings = []
        for img in images:
            try:
                embeddings.append(self.get_embedding(img))
            except Exception as e:
                # Best-effort: skip unreadable/broken reference images but
                # keep going so one bad file doesn't sink the whole entry.
                self.logger.warning(
                    f"Failed to compute embedding for reference image: {e}"
                )

        if not embeddings:
            return None

        # Stack: (N, D), average: (1, D), re-normalize
        stacked = torch.cat(embeddings, dim=0)
        avg_emb = stacked.mean(dim=0, keepdim=True)
        avg_emb = F.normalize(avg_emb, dim=-1)

        self.logger.debug(
            f"Computed averaged embedding from {len(embeddings)} reference image(s)"
        )
        return avg_emb

    def compare_embeddings(
        self, embedding1: torch.Tensor, embedding2: torch.Tensor
    ) -> float:
        """
        Compute cosine similarity between two embeddings.

        Args:
            embedding1: First embedding (torch.Tensor)
            embedding2: Second embedding (torch.Tensor)

        Returns:
            Cosine similarity score (float, range: -1 to 1, typically 0 to 1)
        """
        # Ensure tensors are on the same device before comparing
        if embedding1.device != embedding2.device:
            embedding2 = embedding2.to(embedding1.device)

        similarity = F.cosine_similarity(embedding1, embedding2, dim=-1)
        return similarity.item()

    @staticmethod
    def make_filenames_hash(filenames: List[str]) -> str:
        """
        Compute a deterministic hash of a filenames list.

        Used for cache invalidation — if the filenames list changes,
        the hash changes, triggering re-computation of averaged embeddings.
        Sorting makes the hash order-insensitive.

        Args:
            filenames: List of filename strings

        Returns:
            16-character hex hash string
        """
        canonical = json.dumps(sorted(filenames))
        return hashlib.sha256(canonical.encode("utf-8")).hexdigest()[:16]