From 94db5bd40b77a53be492f640803925b6d02694e9 Mon Sep 17 00:00:00 2001 From: Rick McEwen Date: Fri, 2 Jan 2026 12:05:27 -0500 Subject: [PATCH] Add embedding model selection and comparison test scripts - Update DetectLogosDETR to support both CLIP and DINOv2 models - Rename clip_model parameter to embedding_model - Add model type detection for different embedding extraction - DINOv2 uses CLS token, CLIP uses get_image_features() - Add -e/--embedding-model argument to test_logo_detection.py - Include model name in file output header - Add run_threshold_tests.sh for testing various threshold/margin values - Add run_model_comparison.sh for comparing CLIP vs DINOv2 models --- logo_detection_detr.py | 95 +++++++++++++++++++-------- run_model_comparison.sh | 92 ++++++++++++++++++++++++++ run_threshold_tests.sh | 141 ++++++++++++++++++++++++++++++++++++++++ test_logo_detection.py | 14 +++- 4 files changed, 312 insertions(+), 30 deletions(-) create mode 100755 run_model_comparison.sh create mode 100755 run_threshold_tests.sh diff --git a/logo_detection_detr.py b/logo_detection_detr.py index fa14087..19fc4e6 100644 --- a/logo_detection_detr.py +++ b/logo_detection_detr.py @@ -1,18 +1,22 @@ """ -Logo detection using DETR for object detection and CLIP for feature matching. +Logo detection using DETR for object detection and vision models for feature matching. This module provides a class for detecting logos in images using: 1. DETR (DEtection TRansformer) for initial logo region detection -2. CLIP (Contrastive Language-Image Pre-training) for feature extraction and matching +2. Vision models (CLIP, DINOv2, etc.) for feature extraction and matching The class supports caching of embeddings for efficient reprocessing. The class automatically uses local models if available, otherwise falls back to HuggingFace. + +Supported embedding models: +- CLIP models (openai/clip-vit-*): Text-image alignment, good general features +- DINOv2 models (facebook/dinov2-*): Self-supervised, excellent for visual similarity """ import os import torch import torch.nn.functional as F -from transformers import pipeline, CLIPProcessor, CLIPModel +from transformers import pipeline, CLIPProcessor, CLIPModel, AutoImageProcessor, AutoModel from PIL import Image import cv2 import numpy as np @@ -22,28 +26,31 @@ from typing import List, Tuple, Dict, Optional, Any class DetectLogosDETR: """ - Logo detection class using DETR and CLIP models. + Logo detection class using DETR and vision embedding models. This class detects logos in images by: 1. Using DETR to find potential logo regions (bounding boxes) - 2. Extracting CLIP embeddings for each detected region + 2. Extracting embeddings for each detected region (CLIP, DINOv2, etc.) 3. Comparing embeddings with reference logos for identification The class automatically checks for local models before downloading from HuggingFace. + + Supported embedding models: + - CLIP models (openai/clip-vit-*): Text-image alignment + - DINOv2 models (facebook/dinov2-*): Self-supervised visual features """ def __init__( self, logger, detr_model: str = "Pravallika6/detr-finetuned-logo-detection_v2", - #clip_model: str = "openai/clip-vit-base-patch32", - clip_model: str = "openai/clip-vit-large-patch14", + embedding_model: str = "openai/clip-vit-large-patch14", detr_threshold: float = 0.5, min_box_size: int = 20, nms_iou_threshold: float = 0.5, ): """ - Initialize DETR and CLIP models. + Initialize DETR and embedding models. The class will automatically check for local models in the default directories before downloading from HuggingFace. You can override this by providing absolute @@ -52,7 +59,7 @@ class DetectLogosDETR: Args: logger: Logger instance for logging detr_model: HuggingFace model name or local path for DETR object detection - clip_model: HuggingFace model name or local path for CLIP embeddings + embedding_model: HuggingFace model name for embeddings (CLIP or DINOv2) detr_threshold: Confidence threshold for DETR detections (0-1) min_box_size: Minimum width/height in pixels for detected boxes (filters noise) nms_iou_threshold: IoU threshold for Non-Maximum Suppression @@ -61,6 +68,7 @@ class DetectLogosDETR: self.detr_threshold = detr_threshold self.min_box_size = min_box_size self.nms_iou_threshold = nms_iou_threshold + self.embedding_model_name = embedding_model # Set device self.device_str = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -71,7 +79,7 @@ class DetectLogosDETR: # Get default model directories from environment variables default_detr_dir = os.environ.get('LOGO_DETR_MODEL_DIR', 'models/logo_detection/detr') - default_clip_dir = os.environ.get('LOGO_CLIP_MODEL_DIR', 'models/logo_detection/clip') + default_embedding_dir = os.environ.get('LOGO_EMBEDDING_MODEL_DIR', 'models/logo_detection/embedding') # Resolve DETR model path (check local first, then use HuggingFace name) detr_model_path = self._resolve_model_path( @@ -87,18 +95,35 @@ class DetectLogosDETR: use_fast=True, ) - # Resolve CLIP model path (check local first, then use HuggingFace name) - clip_model_path = self._resolve_model_path( - clip_model, default_clip_dir, "CLIP" + # Resolve embedding model path + embedding_model_path = self._resolve_model_path( + embedding_model, default_embedding_dir, "Embedding" ) - # Initialize CLIP model for feature extraction - self.logger.info(f"Loading CLIP model: {clip_model_path}") - self.clip_model = CLIPModel.from_pretrained(clip_model_path).to(self.device) - self.clip_processor = CLIPProcessor.from_pretrained(clip_model_path) + # Detect model type and initialize accordingly + self.model_type = self._detect_model_type(embedding_model) + self.logger.info(f"Loading {self.model_type} embedding model: {embedding_model_path}") + + if self.model_type == "clip": + self.embedding_model = CLIPModel.from_pretrained(embedding_model_path).to(self.device) + self.embedding_processor = CLIPProcessor.from_pretrained(embedding_model_path) + else: # dinov2 or other transformer models + self.embedding_model = AutoModel.from_pretrained(embedding_model_path).to(self.device) + self.embedding_processor = AutoImageProcessor.from_pretrained(embedding_model_path) self.logger.info("DetectLogosDETR initialization complete") + def _detect_model_type(self, model_name: str) -> str: + """Detect the type of embedding model based on name.""" + model_name_lower = model_name.lower() + if "clip" in model_name_lower: + return "clip" + elif "dino" in model_name_lower: + return "dinov2" + else: + # Default to generic transformer for unknown models + return "transformer" + def _resolve_model_path( self, model_name_or_path: str, default_local_dir: str, model_type: str ) -> str: @@ -193,8 +218,8 @@ class DetectLogosDETR: # Extract bounding box region bbox_crop = pil_image.crop((xmin, ymin, xmax, ymax)) - # Get CLIP embedding for this region - embedding = self._get_clip_embedding_pil(bbox_crop) + # Get embedding for this region + embedding = self._get_embedding_pil(bbox_crop) detections.append( { @@ -299,7 +324,7 @@ class DetectLogosDETR: def get_embedding(self, image: np.ndarray) -> torch.Tensor: """ - Get CLIP embedding for a reference logo image. + Get embedding for a reference logo image. This method is used to compute embeddings for reference logos that will be compared against detected regions. @@ -308,29 +333,43 @@ class DetectLogosDETR: image: OpenCV image (BGR format, numpy array) Returns: - Normalized CLIP feature embedding (torch.Tensor, shape: [1, 512]) + Normalized feature embedding (torch.Tensor) """ # Convert OpenCV BGR to RGB PIL Image image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) pil_image = Image.fromarray(image_rgb) - return self._get_clip_embedding_pil(pil_image) + return self._get_embedding_pil(pil_image) - def _get_clip_embedding_pil(self, pil_image: Image.Image) -> torch.Tensor: + def _get_embedding_pil(self, pil_image: Image.Image) -> torch.Tensor: """ - Internal method to get CLIP embedding from PIL image. + Internal method to get embedding from PIL image. + + Handles both CLIP and DINOv2 model types. Args: pil_image: PIL Image (RGB format) Returns: - Normalized CLIP feature embedding (torch.Tensor) + Normalized feature embedding (torch.Tensor) """ - # Process image through CLIP - inputs = self.clip_processor(images=pil_image, return_tensors="pt").to(self.device) + # Process image through the embedding model + inputs = self.embedding_processor(images=pil_image, return_tensors="pt").to(self.device) with torch.no_grad(): - features = self.clip_model.get_image_features(**inputs) + if self.model_type == "clip": + # CLIP has a dedicated method for image features + features = self.embedding_model.get_image_features(**inputs) + else: + # DINOv2 and other transformers use the CLS token or pooled output + outputs = self.embedding_model(**inputs) + # Use the CLS token (first token) from last hidden state + if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None: + features = outputs.pooler_output + else: + # Use CLS token from last_hidden_state + features = outputs.last_hidden_state[:, 0, :] + # Normalize for cosine similarity features = F.normalize(features, dim=-1) diff --git a/run_model_comparison.sh b/run_model_comparison.sh new file mode 100755 index 0000000..5d831dc --- /dev/null +++ b/run_model_comparison.sh @@ -0,0 +1,92 @@ +#!/bin/bash +# +# Compare different embedding models for logo detection. +# Tests CLIP vs DINOv2 models. +# + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +OUTPUT_FILE="${SCRIPT_DIR}/model_comparison_results.txt" + +# Common parameters +NUM_LOGOS=20 +REFS_PER_LOGO=10 +POSITIVE_SAMPLES=20 +NEGATIVE_SAMPLES=100 +MIN_MATCHING_REFS=3 +THRESHOLD=0.80 +MARGIN=0.10 +SEED=42 + +# Clear output file and write header +echo "Embedding Model Comparison Tests" > "$OUTPUT_FILE" +echo "=================================" >> "$OUTPUT_FILE" +echo "Date: $(date)" >> "$OUTPUT_FILE" +echo "" >> "$OUTPUT_FILE" +echo "Common Parameters:" >> "$OUTPUT_FILE" +echo " Matching method: multi-ref (max)" >> "$OUTPUT_FILE" +echo " Reference logos: $NUM_LOGOS" >> "$OUTPUT_FILE" +echo " Refs per logo: $REFS_PER_LOGO" >> "$OUTPUT_FILE" +echo " Positive samples: $POSITIVE_SAMPLES" >> "$OUTPUT_FILE" +echo " Negative samples: $NEGATIVE_SAMPLES" >> "$OUTPUT_FILE" +echo " Min matching refs: $MIN_MATCHING_REFS" >> "$OUTPUT_FILE" +echo " Threshold: $THRESHOLD" >> "$OUTPUT_FILE" +echo " Margin: $MARGIN" >> "$OUTPUT_FILE" +echo " Seed: $SEED" >> "$OUTPUT_FILE" +echo "" >> "$OUTPUT_FILE" + +echo "Running model comparison tests..." +echo " Matching method: multi-ref (max)" +echo " Reference logos: $NUM_LOGOS" +echo " Threshold: $THRESHOLD" +echo " Margin: $MARGIN" +echo " Seed: $SEED" +echo "" + +# IMPORTANT: Clear cache between model tests since embeddings are model-specific +echo "NOTE: Cache will be cleared between model tests to ensure correct embeddings." +echo "" + +# Test 1: CLIP ViT-Large (default) +echo "=== Test 1: CLIP ViT-Large (openai/clip-vit-large-patch14) ===" +uv run python "$SCRIPT_DIR/test_logo_detection.py" \ + --num-logos $NUM_LOGOS \ + --refs-per-logo $REFS_PER_LOGO \ + --positive-samples $POSITIVE_SAMPLES \ + --negative-samples $NEGATIVE_SAMPLES \ + --matching-method multi-ref \ + --min-matching-refs $MIN_MATCHING_REFS \ + --use-max-similarity \ + --threshold $THRESHOLD \ + --margin $MARGIN \ + --seed $SEED \ + --embedding-model "openai/clip-vit-large-patch14" \ + --clear-cache \ + --output-file "$OUTPUT_FILE" + +echo "" + +# Test 2: DINOv2 Small +echo "=== Test 2: DINOv2 Small (facebook/dinov2-small) ===" +uv run python "$SCRIPT_DIR/test_logo_detection.py" \ + --num-logos $NUM_LOGOS \ + --refs-per-logo $REFS_PER_LOGO \ + --positive-samples $POSITIVE_SAMPLES \ + --negative-samples $NEGATIVE_SAMPLES \ + --matching-method multi-ref \ + --min-matching-refs $MIN_MATCHING_REFS \ + --use-max-similarity \ + --threshold $THRESHOLD \ + --margin $MARGIN \ + --seed $SEED \ + --embedding-model "facebook/dinov2-small" \ + --clear-cache \ + --output-file "$OUTPUT_FILE" + +echo "" +echo "Results saved to: $OUTPUT_FILE" +echo "" +echo "Note: You can also try other models:" +echo " - facebook/dinov2-base" +echo " - facebook/dinov2-large" +echo " - openai/clip-vit-base-patch32" +echo " - openai/clip-vit-large-patch14-336" \ No newline at end of file diff --git a/run_threshold_tests.sh b/run_threshold_tests.sh new file mode 100755 index 0000000..db8e26f --- /dev/null +++ b/run_threshold_tests.sh @@ -0,0 +1,141 @@ +#!/bin/bash +# +# Run logo detection tests with various threshold and margin settings. +# Uses multi-ref (max) matching method for all tests. +# + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +OUTPUT_FILE="${SCRIPT_DIR}/threshold_test_results.txt" + +# Common parameters +NUM_LOGOS=20 +REFS_PER_LOGO=10 +POSITIVE_SAMPLES=20 +NEGATIVE_SAMPLES=100 +MIN_MATCHING_REFS=3 +SEED=42 + +# Clear output file and write header +echo "Threshold Optimization Tests" > "$OUTPUT_FILE" +echo "=============================" >> "$OUTPUT_FILE" +echo "Date: $(date)" >> "$OUTPUT_FILE" +echo "" >> "$OUTPUT_FILE" +echo "Common Parameters:" >> "$OUTPUT_FILE" +echo " Matching method: multi-ref (max)" >> "$OUTPUT_FILE" +echo " Reference logos: $NUM_LOGOS" >> "$OUTPUT_FILE" +echo " Refs per logo: $REFS_PER_LOGO" >> "$OUTPUT_FILE" +echo " Positive samples: $POSITIVE_SAMPLES" >> "$OUTPUT_FILE" +echo " Negative samples: $NEGATIVE_SAMPLES" >> "$OUTPUT_FILE" +echo " Min matching refs: $MIN_MATCHING_REFS" >> "$OUTPUT_FILE" +echo " Seed: $SEED" >> "$OUTPUT_FILE" +echo "" >> "$OUTPUT_FILE" + +echo "Running threshold optimization tests..." +echo " Matching method: multi-ref (max)" +echo " Reference logos: $NUM_LOGOS" +echo " Refs per logo: $REFS_PER_LOGO" +echo " Seed: $SEED" +echo "" + +# Test 1: Default parameters (baseline) +echo "=== Test 1: Default parameters (threshold=0.70, margin=0.05) ===" +uv run python "$SCRIPT_DIR/test_logo_detection.py" \ + --num-logos $NUM_LOGOS \ + --refs-per-logo $REFS_PER_LOGO \ + --positive-samples $POSITIVE_SAMPLES \ + --negative-samples $NEGATIVE_SAMPLES \ + --matching-method multi-ref \ + --min-matching-refs $MIN_MATCHING_REFS \ + --use-max-similarity \ + --threshold 0.70 \ + --margin 0.05 \ + --seed $SEED \ + --output-file "$OUTPUT_FILE" + +echo "" + +# Test 2: Higher threshold +echo "=== Test 2: Higher threshold (threshold=0.80, margin=0.05) ===" +uv run python "$SCRIPT_DIR/test_logo_detection.py" \ + --num-logos $NUM_LOGOS \ + --refs-per-logo $REFS_PER_LOGO \ + --positive-samples $POSITIVE_SAMPLES \ + --negative-samples $NEGATIVE_SAMPLES \ + --matching-method multi-ref \ + --min-matching-refs $MIN_MATCHING_REFS \ + --use-max-similarity \ + --threshold 0.80 \ + --margin 0.05 \ + --seed $SEED \ + --output-file "$OUTPUT_FILE" + +echo "" + +# Test 3: Higher threshold + larger margin +echo "=== Test 3: Higher threshold + larger margin (threshold=0.80, margin=0.10) ===" +uv run python "$SCRIPT_DIR/test_logo_detection.py" \ + --num-logos $NUM_LOGOS \ + --refs-per-logo $REFS_PER_LOGO \ + --positive-samples $POSITIVE_SAMPLES \ + --negative-samples $NEGATIVE_SAMPLES \ + --matching-method multi-ref \ + --min-matching-refs $MIN_MATCHING_REFS \ + --use-max-similarity \ + --threshold 0.80 \ + --margin 0.10 \ + --seed $SEED \ + --output-file "$OUTPUT_FILE" + +echo "" + +# Test 4: Very high threshold +echo "=== Test 4: Very high threshold (threshold=0.85, margin=0.10) ===" +uv run python "$SCRIPT_DIR/test_logo_detection.py" \ + --num-logos $NUM_LOGOS \ + --refs-per-logo $REFS_PER_LOGO \ + --positive-samples $POSITIVE_SAMPLES \ + --negative-samples $NEGATIVE_SAMPLES \ + --matching-method multi-ref \ + --min-matching-refs $MIN_MATCHING_REFS \ + --use-max-similarity \ + --threshold 0.85 \ + --margin 0.10 \ + --seed $SEED \ + --output-file "$OUTPUT_FILE" + +echo "" + +# Test 5: Very high threshold + large margin +echo "=== Test 5: Strict parameters (threshold=0.85, margin=0.15) ===" +uv run python "$SCRIPT_DIR/test_logo_detection.py" \ + --num-logos $NUM_LOGOS \ + --refs-per-logo $REFS_PER_LOGO \ + --positive-samples $POSITIVE_SAMPLES \ + --negative-samples $NEGATIVE_SAMPLES \ + --matching-method multi-ref \ + --min-matching-refs $MIN_MATCHING_REFS \ + --use-max-similarity \ + --threshold 0.85 \ + --margin 0.15 \ + --seed $SEED \ + --output-file "$OUTPUT_FILE" + +echo "" + +# Test 6: Maximum strictness +echo "=== Test 6: Maximum strictness (threshold=0.90, margin=0.15) ===" +uv run python "$SCRIPT_DIR/test_logo_detection.py" \ + --num-logos $NUM_LOGOS \ + --refs-per-logo $REFS_PER_LOGO \ + --positive-samples $POSITIVE_SAMPLES \ + --negative-samples $NEGATIVE_SAMPLES \ + --matching-method multi-ref \ + --min-matching-refs $MIN_MATCHING_REFS \ + --use-max-similarity \ + --threshold 0.90 \ + --margin 0.15 \ + --seed $SEED \ + --output-file "$OUTPUT_FILE" + +echo "" +echo "Results saved to: $OUTPUT_FILE" \ No newline at end of file diff --git a/test_logo_detection.py b/test_logo_detection.py index a94f980..4938855 100755 --- a/test_logo_detection.py +++ b/test_logo_detection.py @@ -203,6 +203,13 @@ def main(): default=0.5, help="DETR detection confidence threshold (default: 0.5)", ) + parser.add_argument( + "-e", "--embedding-model", + type=str, + default="openai/clip-vit-large-patch14", + help="Embedding model for feature extraction (default: openai/clip-vit-large-patch14). " + "Supports CLIP models (openai/clip-*) and DINOv2 models (facebook/dinov2-*)", + ) parser.add_argument( "-s", "--seed", type=int, @@ -302,10 +309,11 @@ def main(): logger.info(f"Loaded {len(cache)} cached embeddings") # Initialize detector - logger.info("Initializing logo detector...") + logger.info(f"Initializing logo detector with embedding model: {args.embedding_model}") detector = DetectLogosDETR( logger=logger, detr_threshold=args.detr_threshold, + embedding_model=args.embedding_model, ) # Load ground truth (both mappings) @@ -633,18 +641,20 @@ def write_results_to_file( lines = [ "=" * 70, f"TEST: {args.matching_method.upper()} MATCHING", + f"Model: {args.embedding_model}", f"Method: {method_desc}", "=" * 70, f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", "", "Configuration:", + f" Embedding model: {args.embedding_model}", f" Reference logos: {num_logos}", f" Refs per logo: {args.refs_per_logo}", f" Total reference embeddings:{total_refs}", f" Positive samples/logo: {args.positive_samples}", f" Negative samples/logo: {args.negative_samples}", f" Test images processed: {num_test_images}", - f" CLIP threshold: {args.threshold}", + f" Similarity threshold: {args.threshold}", f" DETR threshold: {args.detr_threshold}", ]