Add hybrid text+CLIP matching and image preprocessing
Hybrid matching combines text recognition with CLIP similarity:
- If the reference logo has text and the detection matches it: lower CLIP threshold
- If the reference logo has text but the detection doesn't match it: higher CLIP threshold
- If the reference logo has no text: standard CLIP threshold

Image preprocessing adds letterbox and stretch modes for CLIP input so the whole image is kept instead of being center-cropped (letterbox also preserves the aspect ratio; stretch distorts it).

New files:
- run_hybrid_test.sh: Test hybrid matching configurations
- run_preprocess_test.sh: Compare preprocessing modes

Changes to logo_detection_detr.py:
- Add preprocess_mode parameter (default/letterbox/stretch)
- Add set_text_detector() for hybrid matching
- Add extract_text() using EasyOCR
- Add compute_text_similarity() with fuzzy matching
- Add find_best_match_hybrid() with tiered thresholds

Changes to test_logo_detection.py:
- Add --matching-method hybrid option
- Add --preprocess-mode option
- Add hybrid threshold arguments
This commit is contained in:
@ -23,6 +23,7 @@ import cv2
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Tuple, Dict, Optional, Any
|
from typing import List, Tuple, Dict, Optional, Any
|
||||||
|
from difflib import SequenceMatcher
|
||||||
|
|
||||||
|
|
||||||
class DetectLogosDETR:
|
class DetectLogosDETR:
|
||||||
@ -49,6 +50,7 @@ class DetectLogosDETR:
|
|||||||
detr_threshold: float = 0.5,
|
detr_threshold: float = 0.5,
|
||||||
min_box_size: int = 20,
|
min_box_size: int = 20,
|
||||||
nms_iou_threshold: float = 0.5,
|
nms_iou_threshold: float = 0.5,
|
||||||
|
preprocess_mode: str = "default",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize DETR and embedding models.
|
Initialize DETR and embedding models.
|
||||||
@ -64,12 +66,17 @@ class DetectLogosDETR:
|
|||||||
detr_threshold: Confidence threshold for DETR detections (0-1)
|
detr_threshold: Confidence threshold for DETR detections (0-1)
|
||||||
min_box_size: Minimum width/height in pixels for detected boxes (filters noise)
|
min_box_size: Minimum width/height in pixels for detected boxes (filters noise)
|
||||||
nms_iou_threshold: IoU threshold for Non-Maximum Suppression
|
nms_iou_threshold: IoU threshold for Non-Maximum Suppression
|
||||||
|
preprocess_mode: Image preprocessing mode for CLIP:
|
||||||
|
- "default": Use CLIP's default (resize shortest edge + center crop)
|
||||||
|
- "letterbox": Pad to square with black bars, preserving aspect ratio
|
||||||
|
- "stretch": Stretch to square (distorts aspect ratio)
|
||||||
"""
|
"""
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.detr_threshold = detr_threshold
|
self.detr_threshold = detr_threshold
|
||||||
self.min_box_size = min_box_size
|
self.min_box_size = min_box_size
|
||||||
self.nms_iou_threshold = nms_iou_threshold
|
self.nms_iou_threshold = nms_iou_threshold
|
||||||
self.embedding_model_name = embedding_model
|
self.embedding_model_name = embedding_model
|
||||||
|
self.preprocess_mode = preprocess_mode
|
||||||
|
|
||||||
# Set device
|
# Set device
|
||||||
self.device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
|
self.device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||||
@ -116,6 +123,8 @@ class DetectLogosDETR:
|
|||||||
self.embedding_model = AutoModel.from_pretrained(embedding_model_path).to(self.device)
|
self.embedding_model = AutoModel.from_pretrained(embedding_model_path).to(self.device)
|
||||||
self.embedding_processor = AutoImageProcessor.from_pretrained(embedding_model_path)
|
self.embedding_processor = AutoImageProcessor.from_pretrained(embedding_model_path)
|
||||||
|
|
||||||
|
if self.preprocess_mode != "default":
|
||||||
|
self.logger.info(f"Image preprocessing mode: {self.preprocess_mode}")
|
||||||
self.logger.info("DetectLogosDETR initialization complete")
|
self.logger.info("DetectLogosDETR initialization complete")
|
||||||
|
|
||||||
def _detect_model_type(self, model_name: str) -> str:
|
def _detect_model_type(self, model_name: str) -> str:
|
||||||
@ -402,6 +411,46 @@ class DetectLogosDETR:
|
|||||||
|
|
||||||
return self._get_embedding_pil(pil_image)
|
return self._get_embedding_pil(pil_image)
|
||||||
|
|
||||||
|
def _preprocess_image(self, pil_image: Image.Image, target_size: int = 224) -> Image.Image:
    """
    Reshape an image according to ``self.preprocess_mode`` before embedding.

    Modes:
        - "default": the image is handed back untouched so that the
          embedding processor applies its own resize/crop.
        - "letterbox": the image is centered on a black square whose side
          equals its longest edge, then scaled to the target size.
        - "stretch": the image is scaled straight to the target size,
          ignoring its aspect ratio.
        - any other value: the image is handed back untouched.

    Args:
        pil_image: PIL Image (RGB format)
        target_size: Side length in pixels of the square output
            (default 224 for CLIP)

    Returns:
        The preprocessed PIL Image (or the input itself when no
        preprocessing applies)
    """
    mode = self.preprocess_mode
    if mode == "default":
        return pil_image

    src_w, src_h = pil_image.size
    square = (target_size, target_size)

    if mode == "letterbox":
        # Black canvas as large as the longest edge; the source is pasted
        # in the middle so the padding is split evenly on both sides.
        side = max(src_w, src_h)
        canvas = Image.new("RGB", (side, side), (0, 0, 0))
        offset = ((side - src_w) // 2, (side - src_h) // 2)
        canvas.paste(pil_image, offset)
        return canvas.resize(square, Image.LANCZOS)

    if mode == "stretch":
        return pil_image.resize(square, Image.LANCZOS)

    # Unrecognised mode: fall back to the unmodified image.
    return pil_image
|
||||||
|
|
||||||
def _get_embedding_pil(self, pil_image: Image.Image) -> torch.Tensor:
|
def _get_embedding_pil(self, pil_image: Image.Image) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Internal method to get embedding from PIL image.
|
Internal method to get embedding from PIL image.
|
||||||
@ -414,6 +463,10 @@ class DetectLogosDETR:
|
|||||||
Returns:
|
Returns:
|
||||||
Normalized feature embedding (torch.Tensor)
|
Normalized feature embedding (torch.Tensor)
|
||||||
"""
|
"""
|
||||||
|
# Apply preprocessing if configured
|
||||||
|
if self.preprocess_mode != "default":
|
||||||
|
pil_image = self._preprocess_image(pil_image)
|
||||||
|
|
||||||
# Process image through the embedding model
|
# Process image through the embedding model
|
||||||
inputs = self.embedding_processor(images=pil_image, return_tensors="pt").to(self.device)
|
inputs = self.embedding_processor(images=pil_image, return_tensors="pt").to(self.device)
|
||||||
|
|
||||||
@ -713,3 +766,310 @@ class DetectLogosDETR:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return matched_detections
|
return matched_detections
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Hybrid Text + CLIP Matching
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def set_text_detector(self, text_detector) -> None:
    """
    Set (or clear) the optional text detector used for hybrid matching.

    Args:
        text_detector: Instance of DetectText class from text_recognition.py,
            or None to turn text extraction off again. extract_text()
            returns an empty list while no detector is set.
    """
    self.text_detector = text_detector
    # Report the actual state: passing None disables text extraction, so
    # announcing "enabled" in that case would be misleading.
    if text_detector is None:
        self.logger.info("Text detector disabled for hybrid matching")
    else:
        self.logger.info("Text detector enabled for hybrid matching")
|
||||||
|
|
||||||
|
def extract_text(self, image: np.ndarray, min_confidence: float = 0.3) -> List[str]:
    """
    Run the configured text detector on an image and return cleaned strings.

    The detector's ``detect(image)`` call is unpacked as a pair whose first
    element yields ``(text, confidence)`` tuples. Entries below
    ``min_confidence`` are dropped; the rest are lowercased and stripped of
    surrounding whitespace, and anything shorter than two characters is
    discarded.

    Args:
        image: OpenCV image (BGR format)
        min_confidence: Minimum OCR confidence to accept text

    Returns:
        List of detected text strings (lowercased, stripped). Empty when no
        text detector is set or when detection raises.
    """
    detector = getattr(self, 'text_detector', None)
    if detector is None:
        return []

    try:
        ocr_results, _ = detector.detect(image)
        confident = (
            raw_text.lower().strip()
            for raw_text, score in ocr_results
            if score >= min_confidence
        )
        # Single characters are ignored.
        return [candidate for candidate in confident if len(candidate) >= 2]
    except Exception as e:
        # Best-effort: a failing OCR backend must not break matching.
        self.logger.warning(f"Text extraction failed: {e}")
        return []
|
||||||
|
|
||||||
|
def extract_text_pil(self, pil_image: Image.Image, min_confidence: float = 0.3) -> List[str]:
    """
    Extract text from a PIL image by delegating to extract_text().

    Args:
        pil_image: PIL Image (RGB format)
        min_confidence: Minimum OCR confidence

    Returns:
        List of detected text strings
    """
    # extract_text() takes an OpenCV-style BGR array, so swap the channels.
    rgb_pixels = np.array(pil_image)
    bgr_pixels = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
    return self.extract_text(bgr_pixels, min_confidence)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compute_text_similarity(text1_list: List[str], text2_list: List[str]) -> float:
|
||||||
|
"""
|
||||||
|
Compute fuzzy text similarity between two lists of text strings.
|
||||||
|
|
||||||
|
Uses a combination of exact matches and fuzzy matching to handle
|
||||||
|
OCR variations like case differences, spacing, and minor errors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text1_list: List of text strings from first image
|
||||||
|
text2_list: List of text strings from second image
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Similarity score between 0 and 1
|
||||||
|
"""
|
||||||
|
if not text1_list or not text2_list:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Combine all text into single strings for overall comparison
|
||||||
|
text1_combined = " ".join(sorted(text1_list))
|
||||||
|
text2_combined = " ".join(sorted(text2_list))
|
||||||
|
|
||||||
|
# Method 1: Sequence matching on combined text
|
||||||
|
seq_similarity = SequenceMatcher(None, text1_combined, text2_combined).ratio()
|
||||||
|
|
||||||
|
# Method 2: Token overlap (Jaccard-like)
|
||||||
|
# Split into tokens
|
||||||
|
tokens1 = set(text1_combined.split())
|
||||||
|
tokens2 = set(text2_combined.split())
|
||||||
|
|
||||||
|
if tokens1 and tokens2:
|
||||||
|
intersection = len(tokens1 & tokens2)
|
||||||
|
union = len(tokens1 | tokens2)
|
||||||
|
token_similarity = intersection / union if union > 0 else 0
|
||||||
|
else:
|
||||||
|
token_similarity = 0
|
||||||
|
|
||||||
|
# Method 3: Best pairwise match for each text in list1
|
||||||
|
pairwise_scores = []
|
||||||
|
for t1 in text1_list:
|
||||||
|
best_match = 0
|
||||||
|
for t2 in text2_list:
|
||||||
|
score = SequenceMatcher(None, t1, t2).ratio()
|
||||||
|
best_match = max(best_match, score)
|
||||||
|
pairwise_scores.append(best_match)
|
||||||
|
|
||||||
|
pairwise_similarity = sum(pairwise_scores) / len(pairwise_scores) if pairwise_scores else 0
|
||||||
|
|
||||||
|
# Combine methods (weighted average)
|
||||||
|
combined = (seq_similarity * 0.3 + token_similarity * 0.3 + pairwise_similarity * 0.4)
|
||||||
|
|
||||||
|
return combined
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def texts_match(
|
||||||
|
ref_texts: List[str],
|
||||||
|
det_texts: List[str],
|
||||||
|
threshold: float = 0.5
|
||||||
|
) -> Tuple[bool, float]:
|
||||||
|
"""
|
||||||
|
Determine if texts match above a threshold.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ref_texts: Text from reference logo
|
||||||
|
det_texts: Text from detected region
|
||||||
|
threshold: Minimum similarity to consider a match
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (is_match, similarity_score)
|
||||||
|
"""
|
||||||
|
if not ref_texts:
|
||||||
|
# Reference has no text - can't match on text
|
||||||
|
return (False, 0.0)
|
||||||
|
|
||||||
|
if not det_texts:
|
||||||
|
# Reference has text but detection doesn't - no text match
|
||||||
|
return (False, 0.0)
|
||||||
|
|
||||||
|
similarity = DetectLogosDETR.compute_text_similarity(ref_texts, det_texts)
|
||||||
|
return (similarity >= threshold, similarity)
|
||||||
|
|
||||||
|
def find_best_match_hybrid(
    self,
    detected_embedding: torch.Tensor,
    detected_image: np.ndarray,
    reference_data: Dict[str, Dict[str, Any]],
    clip_threshold: float = 0.70,
    clip_threshold_with_text: float = 0.60,
    clip_threshold_text_mismatch: float = 0.80,
    text_similarity_threshold: float = 0.5,
    margin: float = 0.05,
    use_mean_similarity: bool = False,
) -> Optional[Tuple[str, float, Dict[str, Any]]]:
    """
    Pick the best reference logo for a detection using text-aware CLIP thresholds.

    Each reference logo is scored by CLIP similarity, and the threshold that
    score must reach depends on the text situation:
        - reference has text and the detected text matches it:
          ``clip_threshold_with_text`` (match type "text_match")
        - reference has text but the detected text does not match:
          ``clip_threshold_text_mismatch`` (match type "text_mismatch")
        - reference has no text: ``clip_threshold`` (match type "no_text")

    Logos that reach their threshold are ranked by CLIP score. The top one
    is returned unless a runner-up is within ``margin`` of it.

    Args:
        detected_embedding: CLIP embedding from detected logo region
        detected_image: OpenCV image of the detected region (for text extraction)
        reference_data: Dict mapping logo name to:
            {
                'embeddings': List[torch.Tensor],  # CLIP embeddings
                'texts': List[str],                # Extracted text from reference
            }
            Logos without embeddings are skipped.
        clip_threshold: Standard CLIP threshold for no-text references
        clip_threshold_with_text: Lower threshold when text matches
        clip_threshold_text_mismatch: Higher threshold when text expected but missing
        text_similarity_threshold: Threshold for text matching
        margin: Required margin between best and second-best (0 disables the check)
        use_mean_similarity: Use mean vs max for multi-ref aggregation

    Returns:
        Tuple of (label, clip_similarity, match_info) or None.
        match_info contains: text_matched, text_similarity, threshold_used,
        match_type, has_ref_text, detected_texts
    """
    if not reference_data:
        return None

    # Text is extracted from the detected region once and reused for every logo.
    detected_texts = self.extract_text(detected_image)

    candidates = []
    for label, ref_info in reference_data.items():
        ref_embeddings = ref_info.get('embeddings', [])
        ref_texts = ref_info.get('texts', [])
        if not ref_embeddings:
            continue

        # CLIP similarity against every reference image, then aggregate.
        sims = [
            self.compare_embeddings(detected_embedding, ref_emb)
            for ref_emb in ref_embeddings
        ]
        clip_score = sum(sims) / len(sims) if use_mean_similarity else max(sims)

        has_ref_text = len(ref_texts) > 0
        text_matched, text_sim = self.texts_match(
            ref_texts, detected_texts, text_similarity_threshold
        )

        # Choose the tier this logo has to clear.
        if not has_ref_text:
            threshold_used, match_type = clip_threshold, "no_text"
            text_sim = 0.0
        elif text_matched:
            threshold_used, match_type = clip_threshold_with_text, "text_match"
        else:
            threshold_used, match_type = clip_threshold_text_mismatch, "text_mismatch"

        if clip_score >= threshold_used:
            candidates.append({
                'label': label,
                'clip_score': clip_score,
                'text_matched': text_matched,
                'text_similarity': text_sim,
                'threshold_used': threshold_used,
                'match_type': match_type,
                'has_ref_text': has_ref_text,
            })

    if not candidates:
        return None

    # Highest CLIP score first.
    ranked = sorted(candidates, key=lambda entry: entry['clip_score'], reverse=True)
    best = ranked[0]

    # Ambiguous result: the runner-up is too close to the winner.
    if (
        margin > 0
        and len(ranked) > 1
        and best['clip_score'] - ranked[1]['clip_score'] < margin
    ):
        return None

    match_info = {
        key: best[key]
        for key in (
            'text_matched',
            'text_similarity',
            'threshold_used',
            'match_type',
            'has_ref_text',
        )
    }
    match_info['detected_texts'] = detected_texts

    return (best['label'], best['clip_score'], match_info)
|
||||||
|
|
||||||
|
def prepare_reference_data_hybrid(
    self,
    reference_images: Dict[str, List[np.ndarray]],
    text_min_confidence: float = 0.3,
) -> Dict[str, Dict[str, Any]]:
    """
    Prepare reference data for hybrid matching by computing embeddings and extracting text.

    For every logo, each reference image contributes one embedding (via
    get_embedding) and whatever text extract_text finds in it. Text found
    across a logo's images is de-duplicated.

    Args:
        reference_images: Dict mapping logo name to list of reference images (OpenCV BGR)
        text_min_confidence: Minimum confidence for text extraction

    Returns:
        Dict mapping logo name to {'embeddings': [...], 'texts': [...]}.
        'embeddings' follows the order of the input images; 'texts' is
        sorted alphabetically.
    """
    reference_data = {}

    for logo_name, images in reference_images.items():
        embeddings = []
        all_texts = set()

        for img in images:
            embeddings.append(self.get_embedding(img))
            all_texts.update(self.extract_text(img, text_min_confidence))

        reference_data[logo_name] = {
            'embeddings': embeddings,
            # Set iteration order for strings varies between interpreter
            # runs (hash randomization); sorting keeps the stored reference
            # data reproducible.
            'texts': sorted(all_texts),
        }

        if all_texts:
            self.logger.debug(f"Reference '{logo_name}' has text: {all_texts}")

    return reference_data
|
||||||
168
run_hybrid_test.sh
Executable file
168
run_hybrid_test.sh
Executable file
@ -0,0 +1,168 @@
|
|||||||
|
#!/bin/bash
#
# Test the hybrid text+CLIP matching approach for logo detection.
#
# This approach uses text recognition to improve logo matching:
# - If reference logo has text and detection matches it: use lower CLIP threshold
# - If reference logo has text but detection doesn't match: use higher CLIP threshold
# - If reference logo has no text: use standard CLIP threshold
#
# Usage:
# ./run_hybrid_test.sh
#
# Exit status:
# 0 if every test run succeeded, 1 if at least one run failed.
#

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
OUTPUT_FILE="${SCRIPT_DIR}/test_results/hybrid_matching_results.txt"

# Model - baseline CLIP
MODEL="openai/clip-vit-large-patch14"

# Fixed parameters
NUM_LOGOS=20
REFS_PER_LOGO=10
POSITIVE_SAMPLES=20
NEGATIVE_SAMPLES=100
SEED=42

# Number of test runs that exited with a non-zero status
FAILED_RUNS=0

# run_test CONSOLE_TITLE FILE_TITLE [ARGS...]
# Announce one test on the console, write its section title to the results
# file, then run test_logo_detection.py with the shared parameters plus ARGS.
# A failing run is counted and reported but does not stop the remaining tests.
run_test() {
    local console_title="$1"
    local file_title="$2"
    shift 2

    echo "=== ${console_title} ==="
    echo "" >> "$OUTPUT_FILE"
    echo "=== ${file_title} ===" >> "$OUTPUT_FILE"

    if ! uv run python "$SCRIPT_DIR/test_logo_detection.py" \
        --num-logos "$NUM_LOGOS" \
        --refs-per-logo "$REFS_PER_LOGO" \
        --positive-samples "$POSITIVE_SAMPLES" \
        --negative-samples "$NEGATIVE_SAMPLES" \
        "$@" \
        --threshold 0.70 \
        --margin 0.05 \
        --seed "$SEED" \
        --embedding-model "$MODEL" \
        --output-file "$OUTPUT_FILE" \
        --no-cache; then
        echo "WARNING: test run failed: ${console_title}" >&2
        FAILED_RUNS=$((FAILED_RUNS + 1))
    fi

    echo ""
}

# run_hybrid CONSOLE_TITLE FILE_TITLE TEXT_MATCH_THRESHOLD NO_TEXT_THRESHOLD TEXT_SIM_THRESHOLD
# Run one hybrid-matching configuration; the three numbers are passed as
# --hybrid-text-threshold, --hybrid-no-text-threshold and
# --text-similarity-threshold respectively.
run_hybrid() {
    run_test "$1" "$2" \
        --matching-method hybrid \
        --hybrid-text-threshold "$3" \
        --hybrid-no-text-threshold "$4" \
        --text-similarity-threshold "$5"
}

# Create output directory if needed
mkdir -p "${SCRIPT_DIR}/test_results"

# Clear output file and write header
cat > "$OUTPUT_FILE" << EOF
Hybrid Text+CLIP Matching Test Results
======================================
Date: $(date)

Model: ${MODEL}

Fixed Parameters:
Number of logo brands: ${NUM_LOGOS}
Refs per logo: ${REFS_PER_LOGO}
Positive samples/logo: ${POSITIVE_SAMPLES}
Negative samples/logo: ${NEGATIVE_SAMPLES}
Seed: ${SEED}

EOF

echo "Hybrid Text+CLIP Matching Test"
echo "==============================="
echo "Model: ${MODEL}"
echo ""

# Test 1: Compare hybrid vs multi-ref baseline
run_test "Test 1: Multi-ref baseline (for comparison)" \
    "BASELINE: Multi-ref (max) at threshold 0.70" \
    --matching-method multi-ref \
    --min-matching-refs 1 \
    --use-max-similarity

# Test 2: Hybrid with default thresholds
run_hybrid "Test 2: Hybrid with default thresholds" \
    "HYBRID: default thresholds (0.70/0.60/0.80)" \
    0.60 0.80 0.5

# Test 3: Hybrid with more aggressive text bonus
run_hybrid "Test 3: Hybrid with lower text-match threshold" \
    "HYBRID: aggressive text bonus (0.70/0.55/0.80)" \
    0.55 0.80 0.5

# Test 4: Hybrid with stricter text mismatch penalty
run_hybrid "Test 4: Hybrid with stricter text mismatch penalty" \
    "HYBRID: strict mismatch (0.70/0.60/0.85)" \
    0.60 0.85 0.5

# Test 5: Hybrid with lower text similarity threshold (more lenient OCR matching)
run_hybrid "Test 5: Hybrid with lenient text matching" \
    "HYBRID: lenient text matching (text_sim=0.4)" \
    0.60 0.80 0.4

echo "======================================="
if [ "$FAILED_RUNS" -eq 0 ]; then
    echo "Tests complete!"
else
    echo "Tests complete with ${FAILED_RUNS} failed run(s)."
fi
echo "Results saved to: $OUTPUT_FILE"
echo "======================================="

# Let callers (CI, wrapper scripts) see that something went wrong.
if [ "$FAILED_RUNS" -ne 0 ]; then
    exit 1
fi
|
||||||
149
run_preprocess_test.sh
Executable file
149
run_preprocess_test.sh
Executable file
@ -0,0 +1,149 @@
|
|||||||
|
#!/bin/bash
#
# Test different image preprocessing modes to determine if they improve
# CLIP embedding accuracy for logo matching.
#
# Preprocessing modes tested:
# - default: CLIP's default (resize shortest edge + center crop)
# - letterbox: Pad to square with black bars, preserving aspect ratio
# - stretch: Stretch to square (distorts aspect ratio)
#
# The results file contains the summary table first, followed by the
# detailed output of each run.
#
# Usage:
# ./run_preprocess_test.sh
#

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
OUTPUT_FILE="${SCRIPT_DIR}/test_results/preprocessing_comparison.txt"

# Model - baseline CLIP (testing preprocessing effect on standard model)
MODEL="openai/clip-vit-large-patch14"

# Fixed parameters (same as refs_per_logo test for comparability)
NUM_LOGOS=20
REFS_PER_LOGO=10
POSITIVE_SAMPLES=20
NEGATIVE_SAMPLES=100
MIN_MATCHING_REFS=1
THRESHOLD=0.70
MARGIN=0.05
SEED=42

# Preprocessing modes to test (space-separated; split on whitespace below)
MODES="default letterbox stretch"

# Create output directory if needed
mkdir -p "${SCRIPT_DIR}/test_results"

# Clear output file and write header
cat > "$OUTPUT_FILE" << EOF
Image Preprocessing Comparison Test
====================================
Date: $(date)

Model: ${MODEL}
Method: multi-ref (max)

Fixed Parameters:
Number of logo brands: ${NUM_LOGOS}
Refs per logo: ${REFS_PER_LOGO}
Similarity threshold: ${THRESHOLD}
Margin: ${MARGIN}
Min matching refs: ${MIN_MATCHING_REFS}
Positive samples/logo: ${POSITIVE_SAMPLES}
Negative samples/logo: ${NEGATIVE_SAMPLES}
Seed: ${SEED}

Testing preprocessing modes: ${MODES}

EOF

echo "Image Preprocessing Comparison Test"
echo "===================================="
echo "Model: ${MODEL}"
echo "Testing preprocessing modes: ${MODES}"
echo ""

# Results table header
echo "Results Summary:" >> "$OUTPUT_FILE"
echo "----------------" >> "$OUTPUT_FILE"
printf "%-12s %8s %8s %8s %8s %8s %8s\n" "Mode" "TP" "FP" "FN" "Prec" "Recall" "F1" >> "$OUTPUT_FILE"
echo "------------------------------------------------------------------------" >> "$OUTPUT_FILE"

# Track best result. BEST_MODE stays empty until an F1 score has actually
# been parsed, so a run with no usable output is not reported as a winner.
BEST_F1=-1
BEST_MODE=""

# Detailed per-mode output is buffered here and written after the summary,
# so the table rows stay contiguous in the results file.
DETAILS=""

for MODE in ${MODES}; do
    echo "=== Testing preprocess_mode=${MODE} ==="

    # Clear cache to ensure fresh embeddings with new preprocessing
    rm -f "${SCRIPT_DIR}/.embedding_cache.pkl"

    # Run test and capture output (stdout and stderr together)
    if ! OUTPUT=$(uv run python "$SCRIPT_DIR/test_logo_detection.py" \
        --num-logos "$NUM_LOGOS" \
        --refs-per-logo "$REFS_PER_LOGO" \
        --positive-samples "$POSITIVE_SAMPLES" \
        --negative-samples "$NEGATIVE_SAMPLES" \
        --matching-method multi-ref \
        --min-matching-refs "$MIN_MATCHING_REFS" \
        --use-max-similarity \
        --threshold "$THRESHOLD" \
        --margin "$MARGIN" \
        --seed "$SEED" \
        --embedding-model "$MODEL" \
        --preprocess-mode "$MODE" \
        --no-cache \
        2>&1); then
        echo " WARNING: test run for preprocess_mode=${MODE} exited with an error" >&2
    fi

    # Extract metrics (first number / percentage on the matching line)
    TP=$(printf '%s\n' "${OUTPUT}" | grep "True Positives" | grep -oE "[0-9]+" | head -1)
    FP=$(printf '%s\n' "${OUTPUT}" | grep "False Positives" | grep -oE "[0-9]+" | head -1)
    FN=$(printf '%s\n' "${OUTPUT}" | grep "False Negatives" | grep -oE "[0-9]+" | head -1)
    PREC=$(printf '%s\n' "${OUTPUT}" | grep "Precision:" | grep -oE "[0-9]+\.[0-9]+%" | head -1)
    RECALL=$(printf '%s\n' "${OUTPUT}" | grep "Recall:" | grep -oE "[0-9]+\.[0-9]+%" | head -1)
    F1=$(printf '%s\n' "${OUTPUT}" | grep "F1 Score:" | grep -oE "[0-9]+\.[0-9]+%" | head -1)

    # Print to console
    echo " TP: ${TP}, FP: ${FP}, FN: ${FN}"
    echo " Precision: ${PREC}, Recall: ${RECALL}, F1: ${F1}"
    echo ""

    # Add to results table
    printf "%-12s %8s %8s %8s %8s %8s %8s\n" "${MODE}" "${TP}" "${FP}" "${FN}" "${PREC}" "${RECALL}" "${F1}" >> "$OUTPUT_FILE"

    # Track best F1. awk does the floating-point comparison, so the script
    # does not depend on bc being installed.
    F1_NUM=$(printf '%s' "${F1}" | tr -d '%')
    if [ -n "$F1_NUM" ]; then
        if awk -v cur="$F1_NUM" -v best="$BEST_F1" 'BEGIN { exit !(cur > best) }'; then
            BEST_F1="${F1_NUM}"
            BEST_MODE="${MODE}"
        fi
    fi

    # Buffer the detailed output for this test
    DETAILS+=$'\n'
    DETAILS+="======================================================================"$'\n'
    DETAILS+="DETAILED RESULTS: preprocess_mode=${MODE}"$'\n'
    DETAILS+="======================================================================"$'\n'
    DETAILS+="$(printf '%s\n' "${OUTPUT}" | grep -A 50 "Configuration:" | head -30)"$'\n'
    DETAILS+=$'\n'
done

# Summary
if [ -n "$BEST_MODE" ]; then
    BEST_SUMMARY="${BEST_MODE} (F1 = ${BEST_F1}%)"
else
    BEST_SUMMARY="none (no F1 score could be parsed from any run)"
fi

echo "------------------------------------------------------------------------" >> "$OUTPUT_FILE"
echo "" >> "$OUTPUT_FILE"
echo "BEST PREPROCESSING MODE: ${BEST_SUMMARY}" >> "$OUTPUT_FILE"
echo "" >> "$OUTPUT_FILE"
echo "Notes:" >> "$OUTPUT_FILE"
echo " - default: CLIP's standard preprocessing (resize shortest edge + center crop)" >> "$OUTPUT_FILE"
echo " - letterbox: Pads image to square with black bars, preserving aspect ratio" >> "$OUTPUT_FILE"
echo " - stretch: Resizes image to square, distorting aspect ratio" >> "$OUTPUT_FILE"
echo "" >> "$OUTPUT_FILE"

# Detailed per-mode output goes after the summary
printf '%s' "$DETAILS" >> "$OUTPUT_FILE"

echo "======================================="
echo "BEST: preprocess_mode=${BEST_SUMMARY}"
echo "======================================="
echo ""
echo "Results saved to: $OUTPUT_FILE"
|
||||||
@ -18,7 +18,7 @@ import random
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Set, Tuple
|
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import torch
|
import torch
|
||||||
@ -243,11 +243,12 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--matching-method",
|
"--matching-method",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["simple", "margin", "multi-ref"],
|
choices=["simple", "margin", "multi-ref", "hybrid"],
|
||||||
default="margin",
|
default="margin",
|
||||||
help="Matching method: 'simple' returns all matches above threshold, "
|
help="Matching method: 'simple' returns all matches above threshold, "
|
||||||
"'margin' requires confidence margin over 2nd best, "
|
"'margin' requires confidence margin over 2nd best, "
|
||||||
"'multi-ref' aggregates scores across reference images (default: margin)",
|
"'multi-ref' aggregates scores across reference images, "
|
||||||
|
"'hybrid' combines text recognition with CLIP (default: margin)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--min-matching-refs",
|
"--min-matching-refs",
|
||||||
@ -260,6 +261,25 @@ def main():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="For 'multi-ref' method: use max similarity instead of mean across references",
|
help="For 'multi-ref' method: use max similarity instead of mean across references",
|
||||||
)
|
)
|
||||||
|
# Hybrid method arguments
|
||||||
|
parser.add_argument(
|
||||||
|
"--hybrid-text-threshold",
|
||||||
|
type=float,
|
||||||
|
default=0.60,
|
||||||
|
help="For 'hybrid' method: CLIP threshold when text matches (default: 0.60)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hybrid-no-text-threshold",
|
||||||
|
type=float,
|
||||||
|
default=0.80,
|
||||||
|
help="For 'hybrid' method: CLIP threshold when text expected but not found (default: 0.80)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--text-similarity-threshold",
|
||||||
|
type=float,
|
||||||
|
default=0.5,
|
||||||
|
help="For 'hybrid' method: minimum text similarity to consider a match (default: 0.5)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-v", "--verbose",
|
"-v", "--verbose",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@ -286,6 +306,14 @@ def main():
|
|||||||
default=None,
|
default=None,
|
||||||
help="Append results summary to this file (no progress output, just results)",
|
help="Append results summary to this file (no progress output, just results)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--preprocess-mode",
|
||||||
|
type=str,
|
||||||
|
choices=["default", "letterbox", "stretch"],
|
||||||
|
default="default",
|
||||||
|
help="Image preprocessing mode for CLIP: 'default' (resize+center crop), "
|
||||||
|
"'letterbox' (pad to square with black bars), 'stretch' (distort to square)",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
logger = setup_logging(args.verbose)
|
logger = setup_logging(args.verbose)
|
||||||
@ -315,12 +343,23 @@ def main():
|
|||||||
|
|
||||||
# Initialize detector
|
# Initialize detector
|
||||||
logger.info(f"Initializing logo detector with embedding model: {args.embedding_model}")
|
logger.info(f"Initializing logo detector with embedding model: {args.embedding_model}")
|
||||||
|
if args.preprocess_mode != "default":
|
||||||
|
logger.info(f"Using preprocessing mode: {args.preprocess_mode}")
|
||||||
detector = DetectLogosDETR(
|
detector = DetectLogosDETR(
|
||||||
logger=logger,
|
logger=logger,
|
||||||
detr_threshold=args.detr_threshold,
|
detr_threshold=args.detr_threshold,
|
||||||
embedding_model=args.embedding_model,
|
embedding_model=args.embedding_model,
|
||||||
|
preprocess_mode=args.preprocess_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Initialize text detector for hybrid method
|
||||||
|
text_detector = None
|
||||||
|
if args.matching_method == "hybrid":
|
||||||
|
logger.info("Initializing text detector for hybrid matching...")
|
||||||
|
from text_recognition import DetectText
|
||||||
|
text_detector = DetectText(logger=logger, threshold=0.3)
|
||||||
|
detector.set_text_detector(text_detector)
|
||||||
|
|
||||||
# Load ground truth (both mappings)
|
# Load ground truth (both mappings)
|
||||||
logger.info("Loading ground truth from database...")
|
logger.info("Loading ground truth from database...")
|
||||||
image_to_logos, logo_to_images = get_ground_truth(db_path)
|
image_to_logos, logo_to_images = get_ground_truth(db_path)
|
||||||
@ -338,10 +377,15 @@ def main():
|
|||||||
multi_ref_embeddings: Dict[str, List[torch.Tensor]] = {}
|
multi_ref_embeddings: Dict[str, List[torch.Tensor]] = {}
|
||||||
# List for margin-based matching: (logo_name, embedding) tuples
|
# List for margin-based matching: (logo_name, embedding) tuples
|
||||||
reference_embeddings: List[Tuple[str, torch.Tensor]] = []
|
reference_embeddings: List[Tuple[str, torch.Tensor]] = []
|
||||||
|
# Dict for hybrid matching: logo_name -> {'embeddings': [...], 'texts': [...]}
|
||||||
|
hybrid_reference_data: Dict[str, Dict[str, Any]] = {}
|
||||||
total_refs = 0
|
total_refs = 0
|
||||||
|
logos_with_text = 0
|
||||||
|
|
||||||
for logo_name, ref_filenames in tqdm(sampled_logos.items(), desc="Reference logos"):
|
for logo_name, ref_filenames in tqdm(sampled_logos.items(), desc="Reference logos"):
|
||||||
multi_ref_embeddings[logo_name] = []
|
multi_ref_embeddings[logo_name] = []
|
||||||
|
if args.matching_method == "hybrid":
|
||||||
|
hybrid_reference_data[logo_name] = {'embeddings': [], 'texts': set()}
|
||||||
|
|
||||||
for ref_filename in ref_filenames:
|
for ref_filename in ref_filenames:
|
||||||
ref_path = reference_dir / ref_filename
|
ref_path = reference_dir / ref_filename
|
||||||
@ -354,11 +398,15 @@ def main():
|
|||||||
cache_key = f"ref:{ref_filename}"
|
cache_key = f"ref:{ref_filename}"
|
||||||
embedding = cache.get(cache_key) if cache else None
|
embedding = cache.get(cache_key) if cache else None
|
||||||
|
|
||||||
if embedding is None:
|
# Load image if needed (for embedding or text extraction)
|
||||||
|
img = None
|
||||||
|
if embedding is None or args.matching_method == "hybrid":
|
||||||
img = load_image(ref_path)
|
img = load_image(ref_path)
|
||||||
if img is None:
|
if img is None:
|
||||||
logger.warning(f"Failed to load reference logo: {ref_path}")
|
logger.warning(f"Failed to load reference logo: {ref_path}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if embedding is None:
|
||||||
embedding = detector.get_embedding(img)
|
embedding = detector.get_embedding(img)
|
||||||
if cache:
|
if cache:
|
||||||
cache.put(cache_key, embedding)
|
cache.put(cache_key, embedding)
|
||||||
@ -367,7 +415,21 @@ def main():
|
|||||||
reference_embeddings.append((logo_name, embedding))
|
reference_embeddings.append((logo_name, embedding))
|
||||||
total_refs += 1
|
total_refs += 1
|
||||||
|
|
||||||
|
# Extract text for hybrid method
|
||||||
|
if args.matching_method == "hybrid" and img is not None:
|
||||||
|
hybrid_reference_data[logo_name]['embeddings'].append(embedding)
|
||||||
|
texts = detector.extract_text(img, min_confidence=0.3)
|
||||||
|
hybrid_reference_data[logo_name]['texts'].update(texts)
|
||||||
|
|
||||||
|
# Convert text set to list for hybrid data
|
||||||
|
if args.matching_method == "hybrid":
|
||||||
|
hybrid_reference_data[logo_name]['texts'] = list(hybrid_reference_data[logo_name]['texts'])
|
||||||
|
if hybrid_reference_data[logo_name]['texts']:
|
||||||
|
logos_with_text += 1
|
||||||
|
|
||||||
logger.info(f"Computed {total_refs} embeddings for {len(sampled_logos)} logos")
|
logger.info(f"Computed {total_refs} embeddings for {len(sampled_logos)} logos")
|
||||||
|
if args.matching_method == "hybrid":
|
||||||
|
logger.info(f"Extracted text from {logos_with_text}/{len(sampled_logos)} reference logos")
|
||||||
|
|
||||||
# Build test set: for each logo, sample positive and negative images
|
# Build test set: for each logo, sample positive and negative images
|
||||||
logger.info(f"Sampling test images: {args.positive_samples} positive, {args.negative_samples} negative per logo...")
|
logger.info(f"Sampling test images: {args.positive_samples} positive, {args.negative_samples} negative per logo...")
|
||||||
@ -442,17 +504,26 @@ def main():
|
|||||||
cache_key = f"det:{test_filename}"
|
cache_key = f"det:{test_filename}"
|
||||||
cached_detections = cache.get(cache_key) if cache else None
|
cached_detections = cache.get(cache_key) if cache else None
|
||||||
|
|
||||||
|
# For hybrid matching, we always need the original image for text extraction
|
||||||
|
test_img = None
|
||||||
|
if args.matching_method == "hybrid":
|
||||||
|
test_img = load_image(test_path)
|
||||||
|
if test_img is None:
|
||||||
|
logger.warning(f"Failed to load test image: {test_path}")
|
||||||
|
continue
|
||||||
|
|
||||||
if cached_detections is not None:
|
if cached_detections is not None:
|
||||||
# Cached detections contain serialized box data and embeddings
|
# Cached detections contain serialized box data and embeddings
|
||||||
detections = cached_detections
|
detections = cached_detections
|
||||||
else:
|
else:
|
||||||
# Load and detect
|
# Load and detect
|
||||||
img = load_image(test_path)
|
if test_img is None:
|
||||||
if img is None:
|
test_img = load_image(test_path)
|
||||||
logger.warning(f"Failed to load test image: {test_path}")
|
if test_img is None:
|
||||||
continue
|
logger.warning(f"Failed to load test image: {test_path}")
|
||||||
|
continue
|
||||||
|
|
||||||
detections = detector.detect(img)
|
detections = detector.detect(test_img)
|
||||||
|
|
||||||
# Cache the detections
|
# Cache the detections
|
||||||
if cache:
|
if cache:
|
||||||
@ -549,7 +620,7 @@ def main():
|
|||||||
"correct": is_correct,
|
"correct": is_correct,
|
||||||
})
|
})
|
||||||
|
|
||||||
else: # multi-ref
|
elif args.matching_method == "multi-ref":
|
||||||
# Multi-ref matching: aggregates scores across reference images
|
# Multi-ref matching: aggregates scores across reference images
|
||||||
match_result = detector.find_best_match_multi_ref(
|
match_result = detector.find_best_match_multi_ref(
|
||||||
detection["embedding"],
|
detection["embedding"],
|
||||||
@ -580,6 +651,50 @@ def main():
|
|||||||
"correct": is_correct,
|
"correct": is_correct,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
else: # hybrid
|
||||||
|
# Hybrid matching: combines text recognition with CLIP
|
||||||
|
# Extract crop from original image for text extraction
|
||||||
|
box = detection["box"]
|
||||||
|
crop = test_img[
|
||||||
|
int(box["ymin"]):int(box["ymax"]),
|
||||||
|
int(box["xmin"]):int(box["xmax"])
|
||||||
|
]
|
||||||
|
|
||||||
|
match_result = detector.find_best_match_hybrid(
|
||||||
|
detected_embedding=detection["embedding"],
|
||||||
|
detected_image=crop,
|
||||||
|
reference_data=hybrid_reference_data,
|
||||||
|
clip_threshold=args.threshold,
|
||||||
|
clip_threshold_with_text=args.hybrid_text_threshold,
|
||||||
|
clip_threshold_text_mismatch=args.hybrid_no_text_threshold,
|
||||||
|
text_similarity_threshold=args.text_similarity_threshold,
|
||||||
|
margin=args.margin,
|
||||||
|
use_mean_similarity=not args.use_max_similarity,
|
||||||
|
)
|
||||||
|
if match_result:
|
||||||
|
label, similarity, match_info = match_result
|
||||||
|
matched_logos.add(label)
|
||||||
|
|
||||||
|
is_correct = label in expected_logos
|
||||||
|
if is_correct:
|
||||||
|
true_positives += 1
|
||||||
|
if args.similarity_details:
|
||||||
|
similarity_details["true_positive_sims"].append(similarity)
|
||||||
|
else:
|
||||||
|
false_positives += 1
|
||||||
|
if args.similarity_details:
|
||||||
|
similarity_details["false_positive_sims"].append(similarity)
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"test_image": test_filename,
|
||||||
|
"matched_logo": label,
|
||||||
|
"similarity": similarity,
|
||||||
|
"correct": is_correct,
|
||||||
|
"text_matched": match_info.get("text_matched", False),
|
||||||
|
"text_similarity": match_info.get("text_similarity", 0),
|
||||||
|
"match_type": match_info.get("match_type", "unknown"),
|
||||||
|
})
|
||||||
|
|
||||||
# Count missed detections (false negatives)
|
# Count missed detections (false negatives)
|
||||||
missed = expected_logos - matched_logos
|
missed = expected_logos - matched_logos
|
||||||
false_negatives += len(missed)
|
false_negatives += len(missed)
|
||||||
@ -625,12 +740,18 @@ def main():
|
|||||||
print(f" Test images processed: {len(test_images)}")
|
print(f" Test images processed: {len(test_images)}")
|
||||||
print(f" CLIP similarity threshold: {args.threshold}")
|
print(f" CLIP similarity threshold: {args.threshold}")
|
||||||
print(f" DETR confidence threshold: {args.detr_threshold}")
|
print(f" DETR confidence threshold: {args.detr_threshold}")
|
||||||
|
print(f" Preprocess mode: {args.preprocess_mode}")
|
||||||
print(f" Matching method: {args.matching_method}")
|
print(f" Matching method: {args.matching_method}")
|
||||||
if args.matching_method in ("margin", "multi-ref"):
|
if args.matching_method in ("margin", "multi-ref", "hybrid"):
|
||||||
print(f" Matching margin: {args.margin}")
|
print(f" Matching margin: {args.margin}")
|
||||||
if args.matching_method == "multi-ref":
|
if args.matching_method == "multi-ref":
|
||||||
print(f" Min matching refs: {args.min_matching_refs}")
|
print(f" Min matching refs: {args.min_matching_refs}")
|
||||||
print(f" Similarity aggregation: {'max' if args.use_max_similarity else 'mean'}")
|
print(f" Similarity aggregation: {'max' if args.use_max_similarity else 'mean'}")
|
||||||
|
if args.matching_method == "hybrid":
|
||||||
|
print(f" CLIP threshold (text match): {args.hybrid_text_threshold}")
|
||||||
|
print(f" CLIP threshold (no text): {args.hybrid_no_text_threshold}")
|
||||||
|
print(f" Text similarity threshold: {args.text_similarity_threshold}")
|
||||||
|
print(f" Refs with text: {logos_with_text}/{len(sampled_logos)}")
|
||||||
if args.seed is not None:
|
if args.seed is not None:
|
||||||
print(f" Random seed: {args.seed}")
|
print(f" Random seed: {args.seed}")
|
||||||
|
|
||||||
@ -818,9 +939,14 @@ def write_results_to_file(
|
|||||||
method_desc = "Simple (all matches above threshold)"
|
method_desc = "Simple (all matches above threshold)"
|
||||||
elif args.matching_method == "margin":
|
elif args.matching_method == "margin":
|
||||||
method_desc = f"Margin-based (margin={args.margin})"
|
method_desc = f"Margin-based (margin={args.margin})"
|
||||||
else: # multi-ref
|
elif args.matching_method == "multi-ref":
|
||||||
agg = "max" if args.use_max_similarity else "mean"
|
agg = "max" if args.use_max_similarity else "mean"
|
||||||
method_desc = f"Multi-ref ({agg}, min_refs={args.min_matching_refs}, margin={args.margin})"
|
method_desc = f"Multi-ref ({agg}, min_refs={args.min_matching_refs}, margin={args.margin})"
|
||||||
|
else: # hybrid
|
||||||
|
method_desc = (
|
||||||
|
f"Hybrid (text+CLIP, text_thresh={args.hybrid_text_threshold}, "
|
||||||
|
f"no_text_thresh={args.hybrid_no_text_threshold}, margin={args.margin})"
|
||||||
|
)
|
||||||
|
|
||||||
lines = [
|
lines = [
|
||||||
"=" * 70,
|
"=" * 70,
|
||||||
@ -832,6 +958,7 @@ def write_results_to_file(
|
|||||||
"",
|
"",
|
||||||
"Configuration:",
|
"Configuration:",
|
||||||
f" Embedding model: {args.embedding_model}",
|
f" Embedding model: {args.embedding_model}",
|
||||||
|
f" Preprocess mode: {args.preprocess_mode}",
|
||||||
f" Reference logos: {num_logos}",
|
f" Reference logos: {num_logos}",
|
||||||
f" Refs per logo: {args.refs_per_logo}",
|
f" Refs per logo: {args.refs_per_logo}",
|
||||||
f" Total reference embeddings:{total_refs}",
|
f" Total reference embeddings:{total_refs}",
|
||||||
|
|||||||
Reference in New Issue
Block a user