Add hybrid text+CLIP matching and image preprocessing

Hybrid matching combines text recognition with CLIP similarity, picking a
tiered CLIP threshold per reference (see the sketch below):
- If the reference logo has text and the detection's text matches it: use a
  lower CLIP threshold
- If the reference has text but the detection's text doesn't match: use a
  higher threshold
- If the reference has no text: use the standard threshold
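As a minimal sketch of that tiering (threshold defaults are this commit's;
the helper name select_clip_threshold is illustrative, not in the diff):

    def select_clip_threshold(ref_texts, text_matched,
                              clip_threshold=0.70,
                              clip_threshold_with_text=0.60,
                              clip_threshold_text_mismatch=0.80):
        # Mirrors the threshold tiering in find_best_match_hybrid below.
        if not ref_texts:
            return clip_threshold                # reference has no text
        if text_matched:
            return clip_threshold_with_text      # OCR agreement adds confidence
        return clip_threshold_text_mismatch      # text expected but not matched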

Image preprocessing adds letterbox/stretch modes for CLIP input as
alternatives to the default center crop; letterbox preserves aspect ratio
by padding to a square, stretch resizes to a square and distorts it.
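For illustration, how a hypothetical 200x100 crop flows through the two new
modes (mirroring _preprocess_image in the diff below):

    from PIL import Image

    crop = Image.new("RGB", (200, 100))
    # letterbox: center on a square black canvas, then resize to CLIP's input
    canvas = Image.new("RGB", (200, 200), (0, 0, 0))
    canvas.paste(crop, ((200 - 200) // 2, (200 - 100) // 2))  # pasted at (0, 50)
    letterboxed = canvas.resize((224, 224), Image.LANCZOS)    # 2:1 ratio kept
    # stretch: resize directly, distorting the 2:1 aspect ratio
    stretched = crop.resize((224, 224), Image.LANCZOS)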

New files:
- run_hybrid_test.sh: Test hybrid matching configurations
- run_preprocess_test.sh: Compare preprocessing modes

Changes to logo_detection_detr.py:
- Add preprocess_mode parameter (default/letterbox/stretch)
- Add set_text_detector() for hybrid matching
- Add extract_text() using EasyOCR
- Add compute_text_similarity() with fuzzy matching
- Add find_best_match_hybrid() with tiered thresholds

Changes to test_logo_detection.py:
- Add --matching-method hybrid option
- Add --preprocess-mode option
- Add hybrid threshold arguments
Author: Rick McEwen
Date:   2026-01-07 15:09:09 -05:00
Commit: 49f982611a (parent: 78f46f04bf)
4 changed files with 817 additions and 13 deletions

logo_detection_detr.py

@@ -23,6 +23,7 @@ import cv2
 import numpy as np
 from pathlib import Path
 from typing import List, Tuple, Dict, Optional, Any
+from difflib import SequenceMatcher

 class DetectLogosDETR:
@@ -49,6 +50,7 @@ class DetectLogosDETR:
         detr_threshold: float = 0.5,
         min_box_size: int = 20,
         nms_iou_threshold: float = 0.5,
+        preprocess_mode: str = "default",
     ):
         """
         Initialize DETR and embedding models.
@@ -64,12 +66,17 @@ class DetectLogosDETR:
             detr_threshold: Confidence threshold for DETR detections (0-1)
             min_box_size: Minimum width/height in pixels for detected boxes (filters noise)
             nms_iou_threshold: IoU threshold for Non-Maximum Suppression
+            preprocess_mode: Image preprocessing mode for CLIP:
+                - "default": Use CLIP's default (resize shortest edge + center crop)
+                - "letterbox": Pad to square with black bars, preserving aspect ratio
+                - "stretch": Stretch to square (distorts aspect ratio)
         """
         self.logger = logger
         self.detr_threshold = detr_threshold
         self.min_box_size = min_box_size
         self.nms_iou_threshold = nms_iou_threshold
         self.embedding_model_name = embedding_model
+        self.preprocess_mode = preprocess_mode

         # Set device
         self.device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -116,6 +123,8 @@ class DetectLogosDETR:
         self.embedding_model = AutoModel.from_pretrained(embedding_model_path).to(self.device)
         self.embedding_processor = AutoImageProcessor.from_pretrained(embedding_model_path)
+        if self.preprocess_mode != "default":
+            self.logger.info(f"Image preprocessing mode: {self.preprocess_mode}")

         self.logger.info("DetectLogosDETR initialization complete")

     def _detect_model_type(self, model_name: str) -> str:
@@ -402,6 +411,46 @@ class DetectLogosDETR:
         return self._get_embedding_pil(pil_image)

+    def _preprocess_image(self, pil_image: Image.Image, target_size: int = 224) -> Image.Image:
+        """
+        Preprocess image based on the configured preprocessing mode.
+
+        Args:
+            pil_image: PIL Image (RGB format)
+            target_size: Target size for the square output (default 224 for CLIP)
+
+        Returns:
+            Preprocessed PIL Image
+        """
+        if self.preprocess_mode == "default":
+            # Let the processor handle it (resize shortest edge + center crop)
+            return pil_image
+
+        width, height = pil_image.size
+
+        if self.preprocess_mode == "letterbox":
+            # Pad to square with black bars, preserving aspect ratio
+            max_dim = max(width, height)
+            # Create a black square canvas
+            new_image = Image.new("RGB", (max_dim, max_dim), (0, 0, 0))
+            # Paste the original image centered
+            paste_x = (max_dim - width) // 2
+            paste_y = (max_dim - height) // 2
+            new_image.paste(pil_image, (paste_x, paste_y))
+            # Resize to target size
+            return new_image.resize((target_size, target_size), Image.LANCZOS)
+        elif self.preprocess_mode == "stretch":
+            # Stretch to square (distorts aspect ratio)
+            return pil_image.resize((target_size, target_size), Image.LANCZOS)
+        else:
+            # Unknown mode, return original
+            return pil_image
+
     def _get_embedding_pil(self, pil_image: Image.Image) -> torch.Tensor:
         """
         Internal method to get embedding from PIL image.
@@ -414,6 +463,10 @@ class DetectLogosDETR:
         Returns:
             Normalized feature embedding (torch.Tensor)
         """
+        # Apply preprocessing if configured
+        if self.preprocess_mode != "default":
+            pil_image = self._preprocess_image(pil_image)
+
         # Process image through the embedding model
         inputs = self.embedding_processor(images=pil_image, return_tensors="pt").to(self.device)
@@ -713,3 +766,310 @@ class DetectLogosDETR:
             )

         return matched_detections
+
+    # =========================================================================
+    # Hybrid Text + CLIP Matching
+    # =========================================================================
+
+    def set_text_detector(self, text_detector) -> None:
+        """
+        Set an optional text detector for hybrid matching.
+
+        Args:
+            text_detector: Instance of DetectText class from text_recognition.py
+        """
+        self.text_detector = text_detector
+        self.logger.info("Text detector enabled for hybrid matching")
+    def extract_text(self, image: np.ndarray, min_confidence: float = 0.3) -> List[str]:
+        """
+        Extract text from an image using the text detector.
+
+        Args:
+            image: OpenCV image (BGR format)
+            min_confidence: Minimum OCR confidence to accept text
+
+        Returns:
+            List of detected text strings (lowercased, stripped)
+        """
+        if not hasattr(self, 'text_detector') or self.text_detector is None:
+            return []
+
+        try:
+            results, _ = self.text_detector.detect(image)
+            # Filter by confidence and normalize text
+            texts = []
+            for text, confidence in results:
+                if confidence >= min_confidence:
+                    # Normalize: lowercase and strip whitespace
+                    normalized = text.lower().strip()
+                    if len(normalized) >= 2:  # Ignore single characters
+                        texts.append(normalized)
+            return texts
+        except Exception as e:
+            self.logger.warning(f"Text extraction failed: {e}")
+            return []
+    def extract_text_pil(self, pil_image: Image.Image, min_confidence: float = 0.3) -> List[str]:
+        """
+        Extract text from a PIL image.
+
+        Args:
+            pil_image: PIL Image (RGB format)
+            min_confidence: Minimum OCR confidence
+
+        Returns:
+            List of detected text strings
+        """
+        # Convert PIL to OpenCV format
+        cv_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
+        return self.extract_text(cv_image, min_confidence)
+    @staticmethod
+    def compute_text_similarity(text1_list: List[str], text2_list: List[str]) -> float:
+        """
+        Compute fuzzy text similarity between two lists of text strings.
+
+        Uses a combination of exact matches and fuzzy matching to handle
+        OCR variations like case differences, spacing, and minor errors.
+
+        Args:
+            text1_list: List of text strings from first image
+            text2_list: List of text strings from second image
+
+        Returns:
+            Similarity score between 0 and 1
+        """
+        if not text1_list or not text2_list:
+            return 0.0
+
+        # Combine all text into single strings for overall comparison
+        text1_combined = " ".join(sorted(text1_list))
+        text2_combined = " ".join(sorted(text2_list))
+
+        # Method 1: Sequence matching on combined text
+        seq_similarity = SequenceMatcher(None, text1_combined, text2_combined).ratio()
+
+        # Method 2: Token overlap (Jaccard-like)
+        tokens1 = set(text1_combined.split())
+        tokens2 = set(text2_combined.split())
+        if tokens1 and tokens2:
+            intersection = len(tokens1 & tokens2)
+            union = len(tokens1 | tokens2)
+            token_similarity = intersection / union if union > 0 else 0
+        else:
+            token_similarity = 0
+
+        # Method 3: Best pairwise match for each text in list1
+        pairwise_scores = []
+        for t1 in text1_list:
+            best_match = 0
+            for t2 in text2_list:
+                score = SequenceMatcher(None, t1, t2).ratio()
+                best_match = max(best_match, score)
+            pairwise_scores.append(best_match)
+        pairwise_similarity = sum(pairwise_scores) / len(pairwise_scores) if pairwise_scores else 0
+
+        # Combine methods (weighted average)
+        combined = (seq_similarity * 0.3 + token_similarity * 0.3 + pairwise_similarity * 0.4)
+        return combined
+    @staticmethod
+    def texts_match(
+        ref_texts: List[str],
+        det_texts: List[str],
+        threshold: float = 0.5
+    ) -> Tuple[bool, float]:
+        """
+        Determine if texts match above a threshold.
+
+        Args:
+            ref_texts: Text from reference logo
+            det_texts: Text from detected region
+            threshold: Minimum similarity to consider a match
+
+        Returns:
+            Tuple of (is_match, similarity_score)
+        """
+        if not ref_texts:
+            # Reference has no text - can't match on text
+            return (False, 0.0)
+        if not det_texts:
+            # Reference has text but detection doesn't - no text match
+            return (False, 0.0)
+
+        similarity = DetectLogosDETR.compute_text_similarity(ref_texts, det_texts)
+        return (similarity >= threshold, similarity)
+    def find_best_match_hybrid(
+        self,
+        detected_embedding: torch.Tensor,
+        detected_image: np.ndarray,
+        reference_data: Dict[str, Dict[str, Any]],
+        clip_threshold: float = 0.70,
+        clip_threshold_with_text: float = 0.60,
+        clip_threshold_text_mismatch: float = 0.80,
+        text_similarity_threshold: float = 0.5,
+        margin: float = 0.05,
+        use_mean_similarity: bool = False,
+    ) -> Optional[Tuple[str, float, Dict[str, Any]]]:
+        """
+        Find best match using hybrid text + CLIP approach.
+
+        Strategy:
+        - If reference has text AND detection has matching text:
+            → Use lower CLIP threshold (text provides additional confidence)
+        - If reference has text but detection doesn't match:
+            → Use higher CLIP threshold (need more visual confidence)
+        - If reference has no text:
+            → Use standard CLIP threshold
+
+        Args:
+            detected_embedding: CLIP embedding from detected logo region
+            detected_image: OpenCV image of the detected region (for text extraction)
+            reference_data: Dict mapping logo name to:
+                {
+                    'embeddings': List[torch.Tensor],  # CLIP embeddings
+                    'texts': List[str],                # Extracted text from reference
+                }
+            clip_threshold: Standard CLIP threshold for no-text references
+            clip_threshold_with_text: Lower threshold when text matches
+            clip_threshold_text_mismatch: Higher threshold when text expected but missing
+            text_similarity_threshold: Threshold for text matching
+            margin: Required margin between best and second-best
+            use_mean_similarity: Use mean vs max for multi-ref aggregation
+
+        Returns:
+            Tuple of (label, clip_similarity, match_info) or None
+            match_info contains: text_matched, text_similarity, threshold_used
+        """
+        if not reference_data:
+            return None
+
+        # Extract text from detected region
+        detected_texts = self.extract_text(detected_image)
+
+        # Calculate scores for all logos
+        logo_scores = []
+        for label, ref_info in reference_data.items():
+            ref_embeddings = ref_info.get('embeddings', [])
+            ref_texts = ref_info.get('texts', [])
+            if not ref_embeddings:
+                continue
+
+            # Calculate CLIP similarity
+            similarities = []
+            for ref_emb in ref_embeddings:
+                sim = self.compare_embeddings(detected_embedding, ref_emb)
+                similarities.append(sim)
+            if use_mean_similarity:
+                clip_score = sum(similarities) / len(similarities)
+            else:
+                clip_score = max(similarities)
+
+            # Determine text match status and appropriate threshold
+            has_ref_text = len(ref_texts) > 0
+            text_matched, text_sim = self.texts_match(
+                ref_texts, detected_texts, text_similarity_threshold
+            )
+
+            if has_ref_text:
+                if text_matched:
+                    # Text matches - use lower threshold, boost confidence
+                    threshold_used = clip_threshold_with_text
+                    match_type = "text_match"
+                else:
+                    # Reference has text but detection doesn't match -
+                    # require higher CLIP threshold
+                    threshold_used = clip_threshold_text_mismatch
+                    match_type = "text_mismatch"
+            else:
+                # No text in reference - standard matching
+                threshold_used = clip_threshold
+                match_type = "no_text"
+                text_sim = 0.0
+
+            # Check if CLIP score meets the appropriate threshold
+            if clip_score >= threshold_used:
+                logo_scores.append({
+                    'label': label,
+                    'clip_score': clip_score,
+                    'text_matched': text_matched,
+                    'text_similarity': text_sim,
+                    'threshold_used': threshold_used,
+                    'match_type': match_type,
+                    'has_ref_text': has_ref_text,
+                })
+
+        if not logo_scores:
+            return None
+
+        # Sort by CLIP score descending
+        logo_scores.sort(key=lambda x: x['clip_score'], reverse=True)
+        best = logo_scores[0]
+
+        # Check margin against second-best
+        if margin > 0 and len(logo_scores) > 1:
+            second_best_score = logo_scores[1]['clip_score']
+            if best['clip_score'] - second_best_score < margin:
+                return None
+
+        match_info = {
+            'text_matched': best['text_matched'],
+            'text_similarity': best['text_similarity'],
+            'threshold_used': best['threshold_used'],
+            'match_type': best['match_type'],
+            'has_ref_text': best['has_ref_text'],
+            'detected_texts': detected_texts,
+        }
+        return (best['label'], best['clip_score'], match_info)
+    def prepare_reference_data_hybrid(
+        self,
+        reference_images: Dict[str, List[np.ndarray]],
+        text_min_confidence: float = 0.3,
+    ) -> Dict[str, Dict[str, Any]]:
+        """
+        Prepare reference data for hybrid matching by computing embeddings and extracting text.
+
+        Args:
+            reference_images: Dict mapping logo name to list of reference images (OpenCV BGR)
+            text_min_confidence: Minimum confidence for text extraction
+
+        Returns:
+            Dict mapping logo name to {'embeddings': [...], 'texts': [...]}
+        """
+        reference_data = {}
+        for logo_name, images in reference_images.items():
+            embeddings = []
+            all_texts = set()
+            for img in images:
+                # Compute CLIP embedding
+                emb = self.get_embedding(img)
+                embeddings.append(emb)
+                # Extract text
+                texts = self.extract_text(img, text_min_confidence)
+                all_texts.update(texts)
+            reference_data[logo_name] = {
+                'embeddings': embeddings,
+                'texts': list(all_texts),
+            }
+            if all_texts:
+                self.logger.debug(f"Reference '{logo_name}' has text: {all_texts}")
+        return reference_data

run_hybrid_test.sh (new executable file, 168 lines)

@@ -0,0 +1,168 @@
#!/bin/bash
#
# Test the hybrid text+CLIP matching approach for logo detection.
#
# This approach uses text recognition to improve logo matching:
# - If reference logo has text and detection matches it: use lower CLIP threshold
# - If reference logo has text but detection doesn't match: use higher CLIP threshold
# - If reference logo has no text: use standard CLIP threshold
#
# Usage:
# ./run_hybrid_test.sh
#
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
OUTPUT_FILE="${SCRIPT_DIR}/test_results/hybrid_matching_results.txt"
# Model - baseline CLIP
MODEL="openai/clip-vit-large-patch14"
# Fixed parameters
NUM_LOGOS=20
REFS_PER_LOGO=10
POSITIVE_SAMPLES=20
NEGATIVE_SAMPLES=100
SEED=42
# Create output directory if needed
mkdir -p "${SCRIPT_DIR}/test_results"
# Clear output file and write header
cat > "$OUTPUT_FILE" << EOF
Hybrid Text+CLIP Matching Test Results
======================================
Date: $(date)
Model: ${MODEL}
Fixed Parameters:
Number of logo brands: ${NUM_LOGOS}
Refs per logo: ${REFS_PER_LOGO}
Positive samples/logo: ${POSITIVE_SAMPLES}
Negative samples/logo: ${NEGATIVE_SAMPLES}
Seed: ${SEED}
EOF
echo "Hybrid Text+CLIP Matching Test"
echo "==============================="
echo "Model: ${MODEL}"
echo ""
# Test 1: Compare hybrid vs multi-ref baseline
echo "=== Test 1: Multi-ref baseline (for comparison) ==="
echo "" >> "$OUTPUT_FILE"
echo "=== BASELINE: Multi-ref (max) at threshold 0.70 ===" >> "$OUTPUT_FILE"
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
--num-logos $NUM_LOGOS \
--refs-per-logo $REFS_PER_LOGO \
--positive-samples $POSITIVE_SAMPLES \
--negative-samples $NEGATIVE_SAMPLES \
--matching-method multi-ref \
--min-matching-refs 1 \
--use-max-similarity \
--threshold 0.70 \
--margin 0.05 \
--seed $SEED \
--embedding-model "$MODEL" \
--output-file "$OUTPUT_FILE" \
--no-cache
echo ""
# Test 2: Hybrid with default thresholds
echo "=== Test 2: Hybrid with default thresholds ==="
echo "" >> "$OUTPUT_FILE"
echo "=== HYBRID: default thresholds (0.70/0.60/0.80) ===" >> "$OUTPUT_FILE"
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
--num-logos $NUM_LOGOS \
--refs-per-logo $REFS_PER_LOGO \
--positive-samples $POSITIVE_SAMPLES \
--negative-samples $NEGATIVE_SAMPLES \
--matching-method hybrid \
--threshold 0.70 \
--hybrid-text-threshold 0.60 \
--hybrid-no-text-threshold 0.80 \
--text-similarity-threshold 0.5 \
--margin 0.05 \
--seed $SEED \
--embedding-model "$MODEL" \
--output-file "$OUTPUT_FILE" \
--no-cache
echo ""
# Test 3: Hybrid with more aggressive text bonus
echo "=== Test 3: Hybrid with lower text-match threshold ==="
echo "" >> "$OUTPUT_FILE"
echo "=== HYBRID: aggressive text bonus (0.70/0.55/0.80) ===" >> "$OUTPUT_FILE"
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
--num-logos $NUM_LOGOS \
--refs-per-logo $REFS_PER_LOGO \
--positive-samples $POSITIVE_SAMPLES \
--negative-samples $NEGATIVE_SAMPLES \
--matching-method hybrid \
--threshold 0.70 \
--hybrid-text-threshold 0.55 \
--hybrid-no-text-threshold 0.80 \
--text-similarity-threshold 0.5 \
--margin 0.05 \
--seed $SEED \
--embedding-model "$MODEL" \
--output-file "$OUTPUT_FILE" \
--no-cache
echo ""
# Test 4: Hybrid with stricter text mismatch penalty
echo "=== Test 4: Hybrid with stricter text mismatch penalty ==="
echo "" >> "$OUTPUT_FILE"
echo "=== HYBRID: strict mismatch (0.70/0.60/0.85) ===" >> "$OUTPUT_FILE"
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
--num-logos $NUM_LOGOS \
--refs-per-logo $REFS_PER_LOGO \
--positive-samples $POSITIVE_SAMPLES \
--negative-samples $NEGATIVE_SAMPLES \
--matching-method hybrid \
--threshold 0.70 \
--hybrid-text-threshold 0.60 \
--hybrid-no-text-threshold 0.85 \
--text-similarity-threshold 0.5 \
--margin 0.05 \
--seed $SEED \
--embedding-model "$MODEL" \
--output-file "$OUTPUT_FILE" \
--no-cache
echo ""
# Test 5: Hybrid with lower text similarity threshold (more lenient OCR matching)
echo "=== Test 5: Hybrid with lenient text matching ==="
echo "" >> "$OUTPUT_FILE"
echo "=== HYBRID: lenient text matching (text_sim=0.4) ===" >> "$OUTPUT_FILE"
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
--num-logos $NUM_LOGOS \
--refs-per-logo $REFS_PER_LOGO \
--positive-samples $POSITIVE_SAMPLES \
--negative-samples $NEGATIVE_SAMPLES \
--matching-method hybrid \
--threshold 0.70 \
--hybrid-text-threshold 0.60 \
--hybrid-no-text-threshold 0.80 \
--text-similarity-threshold 0.4 \
--margin 0.05 \
--seed $SEED \
--embedding-model "$MODEL" \
--output-file "$OUTPUT_FILE" \
--no-cache
echo ""
echo "======================================="
echo "Tests complete!"
echo "Results saved to: $OUTPUT_FILE"
echo "======================================="

run_preprocess_test.sh (new executable file, 149 lines)

@@ -0,0 +1,149 @@
#!/bin/bash
#
# Test different image preprocessing modes to determine if they improve
# CLIP embedding accuracy for logo matching.
#
# Preprocessing modes tested:
# - default: CLIP's default (resize shortest edge + center crop)
# - letterbox: Pad to square with black bars, preserving aspect ratio
# - stretch: Stretch to square (distorts aspect ratio)
#
# Usage:
# ./run_preprocess_test.sh
#
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
OUTPUT_FILE="${SCRIPT_DIR}/test_results/preprocessing_comparison.txt"
# Model - baseline CLIP (testing preprocessing effect on standard model)
MODEL="openai/clip-vit-large-patch14"
# Fixed parameters (same as refs_per_logo test for comparability)
NUM_LOGOS=20
REFS_PER_LOGO=10
POSITIVE_SAMPLES=20
NEGATIVE_SAMPLES=100
MIN_MATCHING_REFS=1
THRESHOLD=0.70
MARGIN=0.05
SEED=42
# Preprocessing modes to test
MODES="default letterbox stretch"
# Create output directory if needed
mkdir -p "${SCRIPT_DIR}/test_results"
# Clear output file and write header
cat > "$OUTPUT_FILE" << EOF
Image Preprocessing Comparison Test
====================================
Date: $(date)
Model: ${MODEL}
Method: multi-ref (max)
Fixed Parameters:
Number of logo brands: ${NUM_LOGOS}
Refs per logo: ${REFS_PER_LOGO}
Similarity threshold: ${THRESHOLD}
Margin: ${MARGIN}
Min matching refs: ${MIN_MATCHING_REFS}
Positive samples/logo: ${POSITIVE_SAMPLES}
Negative samples/logo: ${NEGATIVE_SAMPLES}
Seed: ${SEED}
Testing preprocessing modes: ${MODES}
EOF
echo "Image Preprocessing Comparison Test"
echo "===================================="
echo "Model: ${MODEL}"
echo "Testing preprocessing modes: ${MODES}"
echo ""
# Results table header
echo "Results Summary:" >> "$OUTPUT_FILE"
echo "----------------" >> "$OUTPUT_FILE"
printf "%-12s %8s %8s %8s %8s %8s %8s\n" "Mode" "TP" "FP" "FN" "Prec" "Recall" "F1" >> "$OUTPUT_FILE"
echo "------------------------------------------------------------------------" >> "$OUTPUT_FILE"
# Track best result
BEST_F1=0
BEST_MODE="default"
for MODE in ${MODES}; do
echo "=== Testing preprocess_mode=${MODE} ==="
# Clear cache to ensure fresh embeddings with new preprocessing
rm -f "${SCRIPT_DIR}/.embedding_cache.pkl"
# Run test and capture output
OUTPUT=$(uv run python "$SCRIPT_DIR/test_logo_detection.py" \
--num-logos $NUM_LOGOS \
--refs-per-logo $REFS_PER_LOGO \
--positive-samples $POSITIVE_SAMPLES \
--negative-samples $NEGATIVE_SAMPLES \
--matching-method multi-ref \
--min-matching-refs $MIN_MATCHING_REFS \
--use-max-similarity \
--threshold $THRESHOLD \
--margin $MARGIN \
--seed $SEED \
--embedding-model "$MODEL" \
--preprocess-mode "$MODE" \
--no-cache \
2>&1)
# Extract metrics
TP=$(echo "${OUTPUT}" | grep "True Positives" | grep -oE "[0-9]+" | head -1)
FP=$(echo "${OUTPUT}" | grep "False Positives" | grep -oE "[0-9]+" | head -1)
FN=$(echo "${OUTPUT}" | grep "False Negatives" | grep -oE "[0-9]+" | head -1)
PREC=$(echo "${OUTPUT}" | grep "Precision:" | grep -oE "[0-9]+\.[0-9]+%" | head -1)
RECALL=$(echo "${OUTPUT}" | grep "Recall:" | grep -oE "[0-9]+\.[0-9]+%" | head -1)
F1=$(echo "${OUTPUT}" | grep "F1 Score:" | grep -oE "[0-9]+\.[0-9]+%" | head -1)
# Print to console
echo " TP: ${TP}, FP: ${FP}, FN: ${FN}"
echo " Precision: ${PREC}, Recall: ${RECALL}, F1: ${F1}"
echo ""
# Add to results table
printf "%-12s %8s %8s %8s %8s %8s %8s\n" "${MODE}" "${TP}" "${FP}" "${FN}" "${PREC}" "${RECALL}" "${F1}" >> "$OUTPUT_FILE"
# Track best F1
F1_NUM=$(echo "${F1}" | tr -d '%')
if [ -n "$F1_NUM" ]; then
BETTER=$(echo "${F1_NUM} > ${BEST_F1}" | bc -l 2>/dev/null || echo "0")
if [ "$BETTER" = "1" ]; then
BEST_F1="${F1_NUM}"
BEST_MODE="${MODE}"
fi
fi
# Also append full output for this test
echo "" >> "$OUTPUT_FILE"
echo "======================================================================" >> "$OUTPUT_FILE"
echo "DETAILED RESULTS: preprocess_mode=${MODE}" >> "$OUTPUT_FILE"
echo "======================================================================" >> "$OUTPUT_FILE"
echo "${OUTPUT}" | grep -A 50 "Configuration:" | head -30 >> "$OUTPUT_FILE"
echo "" >> "$OUTPUT_FILE"
done
# Summary
echo "------------------------------------------------------------------------" >> "$OUTPUT_FILE"
echo "" >> "$OUTPUT_FILE"
echo "BEST PREPROCESSING MODE: ${BEST_MODE} (F1 = ${BEST_F1}%)" >> "$OUTPUT_FILE"
echo "" >> "$OUTPUT_FILE"
echo "Notes:" >> "$OUTPUT_FILE"
echo " - default: CLIP's standard preprocessing (resize shortest edge + center crop)" >> "$OUTPUT_FILE"
echo " - letterbox: Pads image to square with black bars, preserving aspect ratio" >> "$OUTPUT_FILE"
echo " - stretch: Resizes image to square, distorting aspect ratio" >> "$OUTPUT_FILE"
echo "" >> "$OUTPUT_FILE"
echo "======================================="
echo "BEST: preprocess_mode=${BEST_MODE} (F1 = ${BEST_F1}%)"
echo "======================================="
echo ""
echo "Results saved to: $OUTPUT_FILE"

test_logo_detection.py

@@ -18,7 +18,7 @@ import random
 import sqlite3
 import sys
 from pathlib import Path
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple

 import cv2
 import torch
@@ -243,11 +243,12 @@ def main():
     parser.add_argument(
         "--matching-method",
         type=str,
-        choices=["simple", "margin", "multi-ref"],
+        choices=["simple", "margin", "multi-ref", "hybrid"],
         default="margin",
         help="Matching method: 'simple' returns all matches above threshold, "
         "'margin' requires confidence margin over 2nd best, "
-        "'multi-ref' aggregates scores across reference images (default: margin)",
+        "'multi-ref' aggregates scores across reference images, "
+        "'hybrid' combines text recognition with CLIP (default: margin)",
     )
     parser.add_argument(
         "--min-matching-refs",
@@ -260,6 +261,25 @@ def main():
         action="store_true",
         help="For 'multi-ref' method: use max similarity instead of mean across references",
     )
+    # Hybrid method arguments
+    parser.add_argument(
+        "--hybrid-text-threshold",
+        type=float,
+        default=0.60,
+        help="For 'hybrid' method: CLIP threshold when text matches (default: 0.60)",
+    )
+    parser.add_argument(
+        "--hybrid-no-text-threshold",
+        type=float,
+        default=0.80,
+        help="For 'hybrid' method: CLIP threshold when text expected but not found (default: 0.80)",
+    )
+    parser.add_argument(
+        "--text-similarity-threshold",
+        type=float,
+        default=0.5,
+        help="For 'hybrid' method: minimum text similarity to consider a match (default: 0.5)",
+    )
     parser.add_argument(
         "-v", "--verbose",
         action="store_true",
@@ -286,6 +306,14 @@ def main():
         default=None,
         help="Append results summary to this file (no progress output, just results)",
     )
+    parser.add_argument(
+        "--preprocess-mode",
+        type=str,
+        choices=["default", "letterbox", "stretch"],
+        default="default",
+        help="Image preprocessing mode for CLIP: 'default' (resize+center crop), "
+        "'letterbox' (pad to square with black bars), 'stretch' (distort to square)",
+    )

     args = parser.parse_args()
     logger = setup_logging(args.verbose)
@@ -315,12 +343,23 @@ def main():
     # Initialize detector
     logger.info(f"Initializing logo detector with embedding model: {args.embedding_model}")
+    if args.preprocess_mode != "default":
+        logger.info(f"Using preprocessing mode: {args.preprocess_mode}")
+
     detector = DetectLogosDETR(
         logger=logger,
         detr_threshold=args.detr_threshold,
         embedding_model=args.embedding_model,
+        preprocess_mode=args.preprocess_mode,
     )

+    # Initialize text detector for hybrid method
+    text_detector = None
+    if args.matching_method == "hybrid":
+        logger.info("Initializing text detector for hybrid matching...")
+        from text_recognition import DetectText
+        text_detector = DetectText(logger=logger, threshold=0.3)
+        detector.set_text_detector(text_detector)
+
     # Load ground truth (both mappings)
     logger.info("Loading ground truth from database...")
     image_to_logos, logo_to_images = get_ground_truth(db_path)
@@ -338,10 +377,15 @@ def main():
     multi_ref_embeddings: Dict[str, List[torch.Tensor]] = {}
     # List for margin-based matching: (logo_name, embedding) tuples
     reference_embeddings: List[Tuple[str, torch.Tensor]] = []
+    # Dict for hybrid matching: logo_name -> {'embeddings': [...], 'texts': [...]}
+    hybrid_reference_data: Dict[str, Dict[str, Any]] = {}

     total_refs = 0
+    logos_with_text = 0
     for logo_name, ref_filenames in tqdm(sampled_logos.items(), desc="Reference logos"):
         multi_ref_embeddings[logo_name] = []
+        if args.matching_method == "hybrid":
+            hybrid_reference_data[logo_name] = {'embeddings': [], 'texts': set()}

         for ref_filename in ref_filenames:
             ref_path = reference_dir / ref_filename
@@ -354,11 +398,15 @@ def main():
             cache_key = f"ref:{ref_filename}"
             embedding = cache.get(cache_key) if cache else None

-            if embedding is None:
+            # Load image if needed (for embedding or text extraction)
+            img = None
+            if embedding is None or args.matching_method == "hybrid":
                 img = load_image(ref_path)
                 if img is None:
                     logger.warning(f"Failed to load reference logo: {ref_path}")
                     continue

+            if embedding is None:
                 embedding = detector.get_embedding(img)
                 if cache:
                     cache.put(cache_key, embedding)
@@ -367,7 +415,21 @@ def main():
                 reference_embeddings.append((logo_name, embedding))
                 total_refs += 1

+            # Extract text for hybrid method
+            if args.matching_method == "hybrid" and img is not None:
+                hybrid_reference_data[logo_name]['embeddings'].append(embedding)
+                texts = detector.extract_text(img, min_confidence=0.3)
+                hybrid_reference_data[logo_name]['texts'].update(texts)
+
+        # Convert text set to list for hybrid data
+        if args.matching_method == "hybrid":
+            hybrid_reference_data[logo_name]['texts'] = list(hybrid_reference_data[logo_name]['texts'])
+            if hybrid_reference_data[logo_name]['texts']:
+                logos_with_text += 1
+
     logger.info(f"Computed {total_refs} embeddings for {len(sampled_logos)} logos")
+    if args.matching_method == "hybrid":
+        logger.info(f"Extracted text from {logos_with_text}/{len(sampled_logos)} reference logos")

     # Build test set: for each logo, sample positive and negative images
     logger.info(f"Sampling test images: {args.positive_samples} positive, {args.negative_samples} negative per logo...")
@@ -442,17 +504,26 @@ def main():
        cache_key = f"det:{test_filename}"
        cached_detections = cache.get(cache_key) if cache else None

+        # For hybrid matching, we always need the original image for text extraction
+        test_img = None
+        if args.matching_method == "hybrid":
+            test_img = load_image(test_path)
+            if test_img is None:
+                logger.warning(f"Failed to load test image: {test_path}")
+                continue
+
        if cached_detections is not None:
            # Cached detections contain serialized box data and embeddings
            detections = cached_detections
        else:
            # Load and detect
-           img = load_image(test_path)
-           if img is None:
-               logger.warning(f"Failed to load test image: {test_path}")
-               continue
+           if test_img is None:
+               test_img = load_image(test_path)
+               if test_img is None:
+                   logger.warning(f"Failed to load test image: {test_path}")
+                   continue

-           detections = detector.detect(img)
+           detections = detector.detect(test_img)

            # Cache the detections
            if cache:
@@ -549,7 +620,7 @@ def main():
                        "correct": is_correct,
                    })

-           else:  # multi-ref
+           elif args.matching_method == "multi-ref":
                # Multi-ref matching: aggregates scores across reference images
                match_result = detector.find_best_match_multi_ref(
                    detection["embedding"],
@@ -580,6 +651,50 @@ def main():
                        "correct": is_correct,
                    })

+           else:  # hybrid
+               # Hybrid matching: combines text recognition with CLIP
+               # Extract crop from original image for text extraction
+               box = detection["box"]
+               crop = test_img[
+                   int(box["ymin"]):int(box["ymax"]),
+                   int(box["xmin"]):int(box["xmax"])
+               ]
+
+               match_result = detector.find_best_match_hybrid(
+                   detected_embedding=detection["embedding"],
+                   detected_image=crop,
+                   reference_data=hybrid_reference_data,
+                   clip_threshold=args.threshold,
+                   clip_threshold_with_text=args.hybrid_text_threshold,
+                   clip_threshold_text_mismatch=args.hybrid_no_text_threshold,
+                   text_similarity_threshold=args.text_similarity_threshold,
+                   margin=args.margin,
+                   use_mean_similarity=not args.use_max_similarity,
+               )
+
+               if match_result:
+                   label, similarity, match_info = match_result
+                   matched_logos.add(label)
+
+                   is_correct = label in expected_logos
+                   if is_correct:
+                       true_positives += 1
+                       if args.similarity_details:
+                           similarity_details["true_positive_sims"].append(similarity)
+                   else:
+                       false_positives += 1
+                       if args.similarity_details:
+                           similarity_details["false_positive_sims"].append(similarity)
+
+                   results.append({
+                       "test_image": test_filename,
+                       "matched_logo": label,
+                       "similarity": similarity,
+                       "correct": is_correct,
+                       "text_matched": match_info.get("text_matched", False),
+                       "text_similarity": match_info.get("text_similarity", 0),
+                       "match_type": match_info.get("match_type", "unknown"),
+                   })
+
        # Count missed detections (false negatives)
        missed = expected_logos - matched_logos
        false_negatives += len(missed)
@@ -625,12 +740,18 @@ def main():
     print(f"  Test images processed: {len(test_images)}")
     print(f"  CLIP similarity threshold: {args.threshold}")
     print(f"  DETR confidence threshold: {args.detr_threshold}")
+    print(f"  Preprocess mode: {args.preprocess_mode}")
     print(f"  Matching method: {args.matching_method}")
-    if args.matching_method in ("margin", "multi-ref"):
+    if args.matching_method in ("margin", "multi-ref", "hybrid"):
         print(f"  Matching margin: {args.margin}")
     if args.matching_method == "multi-ref":
         print(f"  Min matching refs: {args.min_matching_refs}")
         print(f"  Similarity aggregation: {'max' if args.use_max_similarity else 'mean'}")
+    if args.matching_method == "hybrid":
+        print(f"  CLIP threshold (text match): {args.hybrid_text_threshold}")
+        print(f"  CLIP threshold (no text): {args.hybrid_no_text_threshold}")
+        print(f"  Text similarity threshold: {args.text_similarity_threshold}")
+        print(f"  Refs with text: {logos_with_text}/{len(sampled_logos)}")
     if args.seed is not None:
         print(f"  Random seed: {args.seed}")
@@ -818,9 +939,14 @@ def write_results_to_file(
         method_desc = "Simple (all matches above threshold)"
     elif args.matching_method == "margin":
         method_desc = f"Margin-based (margin={args.margin})"
-    else:  # multi-ref
+    elif args.matching_method == "multi-ref":
         agg = "max" if args.use_max_similarity else "mean"
         method_desc = f"Multi-ref ({agg}, min_refs={args.min_matching_refs}, margin={args.margin})"
+    else:  # hybrid
+        method_desc = (
+            f"Hybrid (text+CLIP, text_thresh={args.hybrid_text_threshold}, "
+            f"no_text_thresh={args.hybrid_no_text_threshold}, margin={args.margin})"
+        )

     lines = [
         "=" * 70,
@@ -832,6 +958,7 @@ def write_results_to_file(
         "",
         "Configuration:",
         f"  Embedding model: {args.embedding_model}",
+        f"  Preprocess mode: {args.preprocess_mode}",
         f"  Reference logos: {num_logos}",
         f"  Refs per logo: {args.refs_per_logo}",
         f"  Total reference embeddings:{total_refs}",