Add embedding model selection and comparison test scripts
- Update DetectLogosDETR to support both CLIP and DINOv2 models
- Rename clip_model parameter to embedding_model
- Add model type detection for different embedding extraction
- DINOv2 uses CLS token, CLIP uses get_image_features()
- Add -e/--embedding-model argument to test_logo_detection.py
- Include model name in file output header
- Add run_threshold_tests.sh for testing various threshold/margin values
- Add run_model_comparison.sh for comparing CLIP vs DINOv2 models
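For quick reference, a minimal usage sketch of the renamed parameter (the module path `detect_logos_detr` and the `logging` setup are assumptions; the class, parameter name, and defaults are as in the diff below):

import logging

from detect_logos_detr import DetectLogosDETR  # module path assumed

logger = logging.getLogger("logo_detection")

# Default: CLIP ViT-Large; embeddings come from get_image_features()
clip_detector = DetectLogosDETR(logger=logger)

# DINOv2: _detect_model_type() sees "dino" in the name and routes to
# AutoModel/AutoImageProcessor with CLS-token pooling
dino_detector = DetectLogosDETR(
    logger=logger,
    embedding_model="facebook/dinov2-small",
)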
@@ -1,18 +1,22 @@
 """
-Logo detection using DETR for object detection and CLIP for feature matching.
+Logo detection using DETR for object detection and vision models for feature matching.
 
 This module provides a class for detecting logos in images using:
 1. DETR (DEtection TRansformer) for initial logo region detection
-2. CLIP (Contrastive Language-Image Pre-training) for feature extraction and matching
+2. Vision models (CLIP, DINOv2, etc.) for feature extraction and matching
 
 The class supports caching of embeddings for efficient reprocessing.
 The class automatically uses local models if available, otherwise falls back to HuggingFace.
+
+Supported embedding models:
+- CLIP models (openai/clip-vit-*): Text-image alignment, good general features
+- DINOv2 models (facebook/dinov2-*): Self-supervised, excellent for visual similarity
 """
 
 import os
 import torch
 import torch.nn.functional as F
-from transformers import pipeline, CLIPProcessor, CLIPModel
+from transformers import pipeline, CLIPProcessor, CLIPModel, AutoImageProcessor, AutoModel
 from PIL import Image
 import cv2
 import numpy as np
@@ -22,28 +26,31 @@ from typing import List, Tuple, Dict, Optional, Any
 
 class DetectLogosDETR:
     """
-    Logo detection class using DETR and CLIP models.
+    Logo detection class using DETR and vision embedding models.
 
     This class detects logos in images by:
     1. Using DETR to find potential logo regions (bounding boxes)
-    2. Extracting CLIP embeddings for each detected region
+    2. Extracting embeddings for each detected region (CLIP, DINOv2, etc.)
     3. Comparing embeddings with reference logos for identification
 
     The class automatically checks for local models before downloading from HuggingFace.
+
+    Supported embedding models:
+    - CLIP models (openai/clip-vit-*): Text-image alignment
+    - DINOv2 models (facebook/dinov2-*): Self-supervised visual features
     """
 
     def __init__(
         self,
         logger,
         detr_model: str = "Pravallika6/detr-finetuned-logo-detection_v2",
-        #clip_model: str = "openai/clip-vit-base-patch32",
-        clip_model: str = "openai/clip-vit-large-patch14",
+        embedding_model: str = "openai/clip-vit-large-patch14",
         detr_threshold: float = 0.5,
         min_box_size: int = 20,
         nms_iou_threshold: float = 0.5,
     ):
         """
-        Initialize DETR and CLIP models.
+        Initialize DETR and embedding models.
 
         The class will automatically check for local models in the default directories
         before downloading from HuggingFace. You can override this by providing absolute
@@ -52,7 +59,7 @@ class DetectLogosDETR:
         Args:
             logger: Logger instance for logging
             detr_model: HuggingFace model name or local path for DETR object detection
-            clip_model: HuggingFace model name or local path for CLIP embeddings
+            embedding_model: HuggingFace model name for embeddings (CLIP or DINOv2)
            detr_threshold: Confidence threshold for DETR detections (0-1)
            min_box_size: Minimum width/height in pixels for detected boxes (filters noise)
            nms_iou_threshold: IoU threshold for Non-Maximum Suppression
@@ -61,6 +68,7 @@ class DetectLogosDETR:
         self.detr_threshold = detr_threshold
         self.min_box_size = min_box_size
         self.nms_iou_threshold = nms_iou_threshold
+        self.embedding_model_name = embedding_model
 
         # Set device
         self.device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -71,7 +79,7 @@ class DetectLogosDETR:
 
         # Get default model directories from environment variables
         default_detr_dir = os.environ.get('LOGO_DETR_MODEL_DIR', 'models/logo_detection/detr')
-        default_clip_dir = os.environ.get('LOGO_CLIP_MODEL_DIR', 'models/logo_detection/clip')
+        default_embedding_dir = os.environ.get('LOGO_EMBEDDING_MODEL_DIR', 'models/logo_detection/embedding')
 
         # Resolve DETR model path (check local first, then use HuggingFace name)
         detr_model_path = self._resolve_model_path(
@@ -87,18 +95,35 @@ class DetectLogosDETR:
             use_fast=True,
         )
 
-        # Resolve CLIP model path (check local first, then use HuggingFace name)
-        clip_model_path = self._resolve_model_path(
-            clip_model, default_clip_dir, "CLIP"
+        # Resolve embedding model path
+        embedding_model_path = self._resolve_model_path(
+            embedding_model, default_embedding_dir, "Embedding"
         )
 
-        # Initialize CLIP model for feature extraction
-        self.logger.info(f"Loading CLIP model: {clip_model_path}")
-        self.clip_model = CLIPModel.from_pretrained(clip_model_path).to(self.device)
-        self.clip_processor = CLIPProcessor.from_pretrained(clip_model_path)
+        # Detect model type and initialize accordingly
+        self.model_type = self._detect_model_type(embedding_model)
+        self.logger.info(f"Loading {self.model_type} embedding model: {embedding_model_path}")
+
+        if self.model_type == "clip":
+            self.embedding_model = CLIPModel.from_pretrained(embedding_model_path).to(self.device)
+            self.embedding_processor = CLIPProcessor.from_pretrained(embedding_model_path)
+        else:  # dinov2 or other transformer models
+            self.embedding_model = AutoModel.from_pretrained(embedding_model_path).to(self.device)
+            self.embedding_processor = AutoImageProcessor.from_pretrained(embedding_model_path)
 
         self.logger.info("DetectLogosDETR initialization complete")
 
+    def _detect_model_type(self, model_name: str) -> str:
+        """Detect the type of embedding model based on name."""
+        model_name_lower = model_name.lower()
+        if "clip" in model_name_lower:
+            return "clip"
+        elif "dino" in model_name_lower:
+            return "dinov2"
+        else:
+            # Default to generic transformer for unknown models
+            return "transformer"
+
     def _resolve_model_path(
         self, model_name_or_path: str, default_local_dir: str, model_type: str
     ) -> str:
@@ -193,8 +218,8 @@ class DetectLogosDETR:
             # Extract bounding box region
             bbox_crop = pil_image.crop((xmin, ymin, xmax, ymax))
 
-            # Get CLIP embedding for this region
-            embedding = self._get_clip_embedding_pil(bbox_crop)
+            # Get embedding for this region
+            embedding = self._get_embedding_pil(bbox_crop)
 
             detections.append(
                 {
@@ -299,7 +324,7 @@ class DetectLogosDETR:
 
     def get_embedding(self, image: np.ndarray) -> torch.Tensor:
         """
-        Get CLIP embedding for a reference logo image.
+        Get embedding for a reference logo image.
 
         This method is used to compute embeddings for reference logos
         that will be compared against detected regions.
@@ -308,29 +333,43 @@ class DetectLogosDETR:
             image: OpenCV image (BGR format, numpy array)
 
         Returns:
-            Normalized CLIP feature embedding (torch.Tensor, shape: [1, 512])
+            Normalized feature embedding (torch.Tensor)
         """
         # Convert OpenCV BGR to RGB PIL Image
         image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         pil_image = Image.fromarray(image_rgb)
 
-        return self._get_clip_embedding_pil(pil_image)
+        return self._get_embedding_pil(pil_image)
 
-    def _get_clip_embedding_pil(self, pil_image: Image.Image) -> torch.Tensor:
+    def _get_embedding_pil(self, pil_image: Image.Image) -> torch.Tensor:
         """
-        Internal method to get CLIP embedding from PIL image.
+        Internal method to get embedding from PIL image.
 
+        Handles both CLIP and DINOv2 model types.
+
         Args:
             pil_image: PIL Image (RGB format)
 
         Returns:
-            Normalized CLIP feature embedding (torch.Tensor)
+            Normalized feature embedding (torch.Tensor)
         """
-        # Process image through CLIP
-        inputs = self.clip_processor(images=pil_image, return_tensors="pt").to(self.device)
+        # Process image through the embedding model
+        inputs = self.embedding_processor(images=pil_image, return_tensors="pt").to(self.device)
 
         with torch.no_grad():
-            features = self.clip_model.get_image_features(**inputs)
+            if self.model_type == "clip":
+                # CLIP has a dedicated method for image features
+                features = self.embedding_model.get_image_features(**inputs)
+            else:
+                # DINOv2 and other transformers use the CLS token or pooled output
+                outputs = self.embedding_model(**inputs)
+                # Use the CLS token (first token) from last hidden state
+                if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
+                    features = outputs.pooler_output
+                else:
+                    # Use CLS token from last_hidden_state
+                    features = outputs.last_hidden_state[:, 0, :]
 
         # Normalize for cosine similarity
         features = F.normalize(features, dim=-1)
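For context on the dispatch above, here is a standalone sketch of the two extraction paths (a sketch only: the placeholder image and the noted output shapes are assumptions; the model names are the ones used in the comparison script below):

import torch
import torch.nn.functional as F
from PIL import Image
from transformers import AutoImageProcessor, AutoModel, CLIPModel, CLIPProcessor

image = Image.new("RGB", (224, 224))  # placeholder input

# CLIP path: dedicated image-feature head
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
clip_proc = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
with torch.no_grad():
    feats = clip_model.get_image_features(**clip_proc(images=image, return_tensors="pt"))
clip_emb = F.normalize(feats, dim=-1)  # shape [1, 768] for ViT-L/14

# DINOv2 path: pooler_output when present, otherwise the CLS token
dino_model = AutoModel.from_pretrained("facebook/dinov2-small")
dino_proc = AutoImageProcessor.from_pretrained("facebook/dinov2-small")
with torch.no_grad():
    out = dino_model(**dino_proc(images=image, return_tensors="pt"))
    feats = (
        out.pooler_output
        if getattr(out, "pooler_output", None) is not None
        else out.last_hidden_state[:, 0, :]
    )
dino_emb = F.normalize(feats, dim=-1)  # shape [1, 384] for dinov2-small

Note the differing embedding dimensions: embeddings from different models are not comparable, which is why the comparison script below clears the cache between runs.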
run_model_comparison.sh (new executable file, 92 lines)
@@ -0,0 +1,92 @@
+#!/bin/bash
+#
+# Compare different embedding models for logo detection.
+# Tests CLIP vs DINOv2 models.
+#
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+OUTPUT_FILE="${SCRIPT_DIR}/model_comparison_results.txt"
+
+# Common parameters
+NUM_LOGOS=20
+REFS_PER_LOGO=10
+POSITIVE_SAMPLES=20
+NEGATIVE_SAMPLES=100
+MIN_MATCHING_REFS=3
+THRESHOLD=0.80
+MARGIN=0.10
+SEED=42
+
+# Clear output file and write header
+echo "Embedding Model Comparison Tests" > "$OUTPUT_FILE"
+echo "=================================" >> "$OUTPUT_FILE"
+echo "Date: $(date)" >> "$OUTPUT_FILE"
+echo "" >> "$OUTPUT_FILE"
+echo "Common Parameters:" >> "$OUTPUT_FILE"
+echo " Matching method: multi-ref (max)" >> "$OUTPUT_FILE"
+echo " Reference logos: $NUM_LOGOS" >> "$OUTPUT_FILE"
+echo " Refs per logo: $REFS_PER_LOGO" >> "$OUTPUT_FILE"
+echo " Positive samples: $POSITIVE_SAMPLES" >> "$OUTPUT_FILE"
+echo " Negative samples: $NEGATIVE_SAMPLES" >> "$OUTPUT_FILE"
+echo " Min matching refs: $MIN_MATCHING_REFS" >> "$OUTPUT_FILE"
+echo " Threshold: $THRESHOLD" >> "$OUTPUT_FILE"
+echo " Margin: $MARGIN" >> "$OUTPUT_FILE"
+echo " Seed: $SEED" >> "$OUTPUT_FILE"
+echo "" >> "$OUTPUT_FILE"
+
+echo "Running model comparison tests..."
+echo " Matching method: multi-ref (max)"
+echo " Reference logos: $NUM_LOGOS"
+echo " Threshold: $THRESHOLD"
+echo " Margin: $MARGIN"
+echo " Seed: $SEED"
+echo ""
+
+# IMPORTANT: Clear cache between model tests since embeddings are model-specific
+echo "NOTE: Cache will be cleared between model tests to ensure correct embeddings."
+echo ""
+
+# Test 1: CLIP ViT-Large (default)
+echo "=== Test 1: CLIP ViT-Large (openai/clip-vit-large-patch14) ==="
+uv run python "$SCRIPT_DIR/test_logo_detection.py" \
+    --num-logos $NUM_LOGOS \
+    --refs-per-logo $REFS_PER_LOGO \
+    --positive-samples $POSITIVE_SAMPLES \
+    --negative-samples $NEGATIVE_SAMPLES \
+    --matching-method multi-ref \
+    --min-matching-refs $MIN_MATCHING_REFS \
+    --use-max-similarity \
+    --threshold $THRESHOLD \
+    --margin $MARGIN \
+    --seed $SEED \
+    --embedding-model "openai/clip-vit-large-patch14" \
+    --clear-cache \
+    --output-file "$OUTPUT_FILE"
+
+echo ""
+
+# Test 2: DINOv2 Small
+echo "=== Test 2: DINOv2 Small (facebook/dinov2-small) ==="
+uv run python "$SCRIPT_DIR/test_logo_detection.py" \
+    --num-logos $NUM_LOGOS \
+    --refs-per-logo $REFS_PER_LOGO \
+    --positive-samples $POSITIVE_SAMPLES \
+    --negative-samples $NEGATIVE_SAMPLES \
+    --matching-method multi-ref \
+    --min-matching-refs $MIN_MATCHING_REFS \
+    --use-max-similarity \
+    --threshold $THRESHOLD \
+    --margin $MARGIN \
+    --seed $SEED \
+    --embedding-model "facebook/dinov2-small" \
+    --clear-cache \
+    --output-file "$OUTPUT_FILE"
+
+echo ""
+echo "Results saved to: $OUTPUT_FILE"
+echo ""
+echo "Note: You can also try other models:"
+echo " - facebook/dinov2-base"
+echo " - facebook/dinov2-large"
+echo " - openai/clip-vit-base-patch32"
+echo " - openai/clip-vit-large-patch14-336"
run_threshold_tests.sh (new executable file, 141 lines)
@@ -0,0 +1,141 @@
+#!/bin/bash
+#
+# Run logo detection tests with various threshold and margin settings.
+# Uses multi-ref (max) matching method for all tests.
+#
+
+SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+OUTPUT_FILE="${SCRIPT_DIR}/threshold_test_results.txt"
+
+# Common parameters
+NUM_LOGOS=20
+REFS_PER_LOGO=10
+POSITIVE_SAMPLES=20
+NEGATIVE_SAMPLES=100
+MIN_MATCHING_REFS=3
+SEED=42
+
+# Clear output file and write header
+echo "Threshold Optimization Tests" > "$OUTPUT_FILE"
+echo "=============================" >> "$OUTPUT_FILE"
+echo "Date: $(date)" >> "$OUTPUT_FILE"
+echo "" >> "$OUTPUT_FILE"
+echo "Common Parameters:" >> "$OUTPUT_FILE"
+echo " Matching method: multi-ref (max)" >> "$OUTPUT_FILE"
+echo " Reference logos: $NUM_LOGOS" >> "$OUTPUT_FILE"
+echo " Refs per logo: $REFS_PER_LOGO" >> "$OUTPUT_FILE"
+echo " Positive samples: $POSITIVE_SAMPLES" >> "$OUTPUT_FILE"
+echo " Negative samples: $NEGATIVE_SAMPLES" >> "$OUTPUT_FILE"
+echo " Min matching refs: $MIN_MATCHING_REFS" >> "$OUTPUT_FILE"
+echo " Seed: $SEED" >> "$OUTPUT_FILE"
+echo "" >> "$OUTPUT_FILE"
+
+echo "Running threshold optimization tests..."
+echo " Matching method: multi-ref (max)"
+echo " Reference logos: $NUM_LOGOS"
+echo " Refs per logo: $REFS_PER_LOGO"
+echo " Seed: $SEED"
+echo ""
+
+# Test 1: Default parameters (baseline)
+echo "=== Test 1: Default parameters (threshold=0.70, margin=0.05) ==="
+uv run python "$SCRIPT_DIR/test_logo_detection.py" \
+    --num-logos $NUM_LOGOS \
+    --refs-per-logo $REFS_PER_LOGO \
+    --positive-samples $POSITIVE_SAMPLES \
+    --negative-samples $NEGATIVE_SAMPLES \
+    --matching-method multi-ref \
+    --min-matching-refs $MIN_MATCHING_REFS \
+    --use-max-similarity \
+    --threshold 0.70 \
+    --margin 0.05 \
+    --seed $SEED \
+    --output-file "$OUTPUT_FILE"
+
+echo ""
+
+# Test 2: Higher threshold
+echo "=== Test 2: Higher threshold (threshold=0.80, margin=0.05) ==="
+uv run python "$SCRIPT_DIR/test_logo_detection.py" \
+    --num-logos $NUM_LOGOS \
+    --refs-per-logo $REFS_PER_LOGO \
+    --positive-samples $POSITIVE_SAMPLES \
+    --negative-samples $NEGATIVE_SAMPLES \
+    --matching-method multi-ref \
+    --min-matching-refs $MIN_MATCHING_REFS \
+    --use-max-similarity \
+    --threshold 0.80 \
+    --margin 0.05 \
+    --seed $SEED \
+    --output-file "$OUTPUT_FILE"
+
+echo ""
+
+# Test 3: Higher threshold + larger margin
+echo "=== Test 3: Higher threshold + larger margin (threshold=0.80, margin=0.10) ==="
+uv run python "$SCRIPT_DIR/test_logo_detection.py" \
+    --num-logos $NUM_LOGOS \
+    --refs-per-logo $REFS_PER_LOGO \
+    --positive-samples $POSITIVE_SAMPLES \
+    --negative-samples $NEGATIVE_SAMPLES \
+    --matching-method multi-ref \
+    --min-matching-refs $MIN_MATCHING_REFS \
+    --use-max-similarity \
+    --threshold 0.80 \
+    --margin 0.10 \
+    --seed $SEED \
+    --output-file "$OUTPUT_FILE"
+
+echo ""
+
+# Test 4: Very high threshold
+echo "=== Test 4: Very high threshold (threshold=0.85, margin=0.10) ==="
+uv run python "$SCRIPT_DIR/test_logo_detection.py" \
+    --num-logos $NUM_LOGOS \
+    --refs-per-logo $REFS_PER_LOGO \
+    --positive-samples $POSITIVE_SAMPLES \
+    --negative-samples $NEGATIVE_SAMPLES \
+    --matching-method multi-ref \
+    --min-matching-refs $MIN_MATCHING_REFS \
+    --use-max-similarity \
+    --threshold 0.85 \
+    --margin 0.10 \
+    --seed $SEED \
+    --output-file "$OUTPUT_FILE"
+
+echo ""
+
+# Test 5: Very high threshold + large margin
+echo "=== Test 5: Strict parameters (threshold=0.85, margin=0.15) ==="
+uv run python "$SCRIPT_DIR/test_logo_detection.py" \
+    --num-logos $NUM_LOGOS \
+    --refs-per-logo $REFS_PER_LOGO \
+    --positive-samples $POSITIVE_SAMPLES \
+    --negative-samples $NEGATIVE_SAMPLES \
+    --matching-method multi-ref \
+    --min-matching-refs $MIN_MATCHING_REFS \
+    --use-max-similarity \
+    --threshold 0.85 \
+    --margin 0.15 \
+    --seed $SEED \
+    --output-file "$OUTPUT_FILE"
+
+echo ""
+
+# Test 6: Maximum strictness
+echo "=== Test 6: Maximum strictness (threshold=0.90, margin=0.15) ==="
+uv run python "$SCRIPT_DIR/test_logo_detection.py" \
+    --num-logos $NUM_LOGOS \
+    --refs-per-logo $REFS_PER_LOGO \
+    --positive-samples $POSITIVE_SAMPLES \
+    --negative-samples $NEGATIVE_SAMPLES \
+    --matching-method multi-ref \
+    --min-matching-refs $MIN_MATCHING_REFS \
+    --use-max-similarity \
+    --threshold 0.90 \
+    --margin 0.15 \
+    --seed $SEED \
+    --output-file "$OUTPUT_FILE"
+
+echo ""
+echo "Results saved to: $OUTPUT_FILE"
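The scripts sweep --threshold and --margin jointly; the diff does not show how the test script applies them, but a plausible reading (an assumption, not the script's actual code) is a best-match decision rule like this sketch:

# Hypothetical decision rule; assumes "threshold" is the minimum cosine
# similarity and "margin" is the required gap over the second-best logo.
def accept_match(similarities: dict[str, float], threshold: float, margin: float) -> str | None:
    ranked = sorted(similarities.items(), key=lambda kv: kv[1], reverse=True)
    best_name, best_sim = ranked[0]
    runner_up = ranked[1][1] if len(ranked) > 1 else 0.0
    if best_sim >= threshold and (best_sim - runner_up) >= margin:
        return best_name
    return None

# Example: best=0.83 vs runner-up=0.70 passes threshold=0.80, margin=0.10
print(accept_match({"acme": 0.83, "globex": 0.70}, 0.80, 0.10))  # -> "acme"

Under this reading, raising the threshold trades recall for precision, while raising the margin suppresses ambiguous matches between similar logos.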
test_logo_detection.py
@@ -203,6 +203,13 @@ def main():
         default=0.5,
         help="DETR detection confidence threshold (default: 0.5)",
     )
+    parser.add_argument(
+        "-e", "--embedding-model",
+        type=str,
+        default="openai/clip-vit-large-patch14",
+        help="Embedding model for feature extraction (default: openai/clip-vit-large-patch14). "
+        "Supports CLIP models (openai/clip-*) and DINOv2 models (facebook/dinov2-*)",
+    )
     parser.add_argument(
         "-s", "--seed",
         type=int,
@@ -302,10 +309,11 @@ def main():
     logger.info(f"Loaded {len(cache)} cached embeddings")
 
     # Initialize detector
-    logger.info("Initializing logo detector...")
+    logger.info(f"Initializing logo detector with embedding model: {args.embedding_model}")
     detector = DetectLogosDETR(
         logger=logger,
         detr_threshold=args.detr_threshold,
+        embedding_model=args.embedding_model,
     )
 
     # Load ground truth (both mappings)
@@ -633,18 +641,20 @@ def write_results_to_file(
     lines = [
         "=" * 70,
         f"TEST: {args.matching_method.upper()} MATCHING",
+        f"Model: {args.embedding_model}",
         f"Method: {method_desc}",
         "=" * 70,
         f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
         "",
         "Configuration:",
+        f" Embedding model: {args.embedding_model}",
         f" Reference logos: {num_logos}",
         f" Refs per logo: {args.refs_per_logo}",
         f" Total reference embeddings:{total_refs}",
         f" Positive samples/logo: {args.positive_samples}",
         f" Negative samples/logo: {args.negative_samples}",
         f" Test images processed: {num_test_images}",
-        f" CLIP threshold: {args.threshold}",
+        f" Similarity threshold: {args.threshold}",
         f" DETR threshold: {args.detr_threshold}",
     ]
 