diff --git a/logo_detection_detr.py b/logo_detection_detr.py index 7cbbd8d..06e42b6 100644 --- a/logo_detection_detr.py +++ b/logo_detection_detr.py @@ -23,6 +23,7 @@ import cv2 import numpy as np from pathlib import Path from typing import List, Tuple, Dict, Optional, Any +from difflib import SequenceMatcher class DetectLogosDETR: @@ -49,6 +50,7 @@ class DetectLogosDETR: detr_threshold: float = 0.5, min_box_size: int = 20, nms_iou_threshold: float = 0.5, + preprocess_mode: str = "default", ): """ Initialize DETR and embedding models. @@ -64,12 +66,17 @@ class DetectLogosDETR: detr_threshold: Confidence threshold for DETR detections (0-1) min_box_size: Minimum width/height in pixels for detected boxes (filters noise) nms_iou_threshold: IoU threshold for Non-Maximum Suppression + preprocess_mode: Image preprocessing mode for CLIP: + - "default": Use CLIP's default (resize shortest edge + center crop) + - "letterbox": Pad to square with black bars, preserving aspect ratio + - "stretch": Stretch to square (distorts aspect ratio) """ self.logger = logger self.detr_threshold = detr_threshold self.min_box_size = min_box_size self.nms_iou_threshold = nms_iou_threshold self.embedding_model_name = embedding_model + self.preprocess_mode = preprocess_mode # Set device self.device_str = "cuda:0" if torch.cuda.is_available() else "cpu" @@ -116,6 +123,8 @@ class DetectLogosDETR: self.embedding_model = AutoModel.from_pretrained(embedding_model_path).to(self.device) self.embedding_processor = AutoImageProcessor.from_pretrained(embedding_model_path) + if self.preprocess_mode != "default": + self.logger.info(f"Image preprocessing mode: {self.preprocess_mode}") self.logger.info("DetectLogosDETR initialization complete") def _detect_model_type(self, model_name: str) -> str: @@ -402,6 +411,46 @@ class DetectLogosDETR: return self._get_embedding_pil(pil_image) + def _preprocess_image(self, pil_image: Image.Image, target_size: int = 224) -> Image.Image: + """ + Preprocess image based on the configured preprocessing mode. + + Args: + pil_image: PIL Image (RGB format) + target_size: Target size for the square output (default 224 for CLIP) + + Returns: + Preprocessed PIL Image + """ + if self.preprocess_mode == "default": + # Let the processor handle it (resize shortest edge + center crop) + return pil_image + + width, height = pil_image.size + + if self.preprocess_mode == "letterbox": + # Pad to square with black bars, preserving aspect ratio + max_dim = max(width, height) + + # Create a black square canvas + new_image = Image.new("RGB", (max_dim, max_dim), (0, 0, 0)) + + # Paste the original image centered + paste_x = (max_dim - width) // 2 + paste_y = (max_dim - height) // 2 + new_image.paste(pil_image, (paste_x, paste_y)) + + # Resize to target size + return new_image.resize((target_size, target_size), Image.LANCZOS) + + elif self.preprocess_mode == "stretch": + # Stretch to square (distorts aspect ratio) + return pil_image.resize((target_size, target_size), Image.LANCZOS) + + else: + # Unknown mode, return original + return pil_image + def _get_embedding_pil(self, pil_image: Image.Image) -> torch.Tensor: """ Internal method to get embedding from PIL image. 
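# For intuition, a standalone sketch of the "letterbox" branch added above, assuming a
# hypothetical 200x50 logo crop (in the detector this crop would come from a DETR box):
from PIL import Image

crop = Image.new("RGB", (200, 50), (255, 255, 255))  # stand-in for a wide logo crop

# Pad to a square canvas with black bars, keeping the aspect ratio, then resize to 224x224.
max_dim = max(crop.size)
canvas = Image.new("RGB", (max_dim, max_dim), (0, 0, 0))
canvas.paste(crop, ((max_dim - crop.width) // 2, (max_dim - crop.height) // 2))
letterboxed = canvas.resize((224, 224), Image.LANCZOS)
print(letterboxed.size)  # (224, 224); "stretch" would instead resize 200x50 directly to 224x224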
@@ -414,6 +463,10 @@ class DetectLogosDETR: Returns: Normalized feature embedding (torch.Tensor) """ + # Apply preprocessing if configured + if self.preprocess_mode != "default": + pil_image = self._preprocess_image(pil_image) + # Process image through the embedding model inputs = self.embedding_processor(images=pil_image, return_tensors="pt").to(self.device) @@ -712,4 +765,311 @@ class DetectLogosDETR: f"(threshold: {similarity_threshold})" ) - return matched_detections \ No newline at end of file + return matched_detections + + # ========================================================================= + # Hybrid Text + CLIP Matching + # ========================================================================= + + def set_text_detector(self, text_detector) -> None: + """ + Set an optional text detector for hybrid matching. + + Args: + text_detector: Instance of DetectText class from text_recognition.py + """ + self.text_detector = text_detector + self.logger.info("Text detector enabled for hybrid matching") + + def extract_text(self, image: np.ndarray, min_confidence: float = 0.3) -> List[str]: + """ + Extract text from an image using the text detector. + + Args: + image: OpenCV image (BGR format) + min_confidence: Minimum OCR confidence to accept text + + Returns: + List of detected text strings (lowercased, stripped) + """ + if not hasattr(self, 'text_detector') or self.text_detector is None: + return [] + + try: + results, _ = self.text_detector.detect(image) + # Filter by confidence and normalize text + texts = [] + for text, confidence in results: + if confidence >= min_confidence: + # Normalize: lowercase, strip whitespace, remove special chars + normalized = text.lower().strip() + if len(normalized) >= 2: # Ignore single characters + texts.append(normalized) + return texts + except Exception as e: + self.logger.warning(f"Text extraction failed: {e}") + return [] + + def extract_text_pil(self, pil_image: Image.Image, min_confidence: float = 0.3) -> List[str]: + """ + Extract text from a PIL image. + + Args: + pil_image: PIL Image (RGB format) + min_confidence: Minimum OCR confidence + + Returns: + List of detected text strings + """ + # Convert PIL to OpenCV format + cv_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR) + return self.extract_text(cv_image, min_confidence) + + @staticmethod + def compute_text_similarity(text1_list: List[str], text2_list: List[str]) -> float: + """ + Compute fuzzy text similarity between two lists of text strings. + + Uses a combination of exact matches and fuzzy matching to handle + OCR variations like case differences, spacing, and minor errors. 
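# For a feel of the fuzzy scores this method builds on, difflib's SequenceMatcher rates
# OCR-style near-misses high and unrelated strings low (the brand strings below are
# made-up examples, not taken from the dataset):
from difflib import SequenceMatcher

print(SequenceMatcher(None, "netflix", "netfl1x").ratio())  # ~0.86: single-character OCR error
print(SequenceMatcher(None, "netflix", "adobe").ratio())    # low: unrelated strings score near zero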
+ + Args: + text1_list: List of text strings from first image + text2_list: List of text strings from second image + + Returns: + Similarity score between 0 and 1 + """ + if not text1_list or not text2_list: + return 0.0 + + # Combine all text into single strings for overall comparison + text1_combined = " ".join(sorted(text1_list)) + text2_combined = " ".join(sorted(text2_list)) + + # Method 1: Sequence matching on combined text + seq_similarity = SequenceMatcher(None, text1_combined, text2_combined).ratio() + + # Method 2: Token overlap (Jaccard-like) + # Split into tokens + tokens1 = set(text1_combined.split()) + tokens2 = set(text2_combined.split()) + + if tokens1 and tokens2: + intersection = len(tokens1 & tokens2) + union = len(tokens1 | tokens2) + token_similarity = intersection / union if union > 0 else 0 + else: + token_similarity = 0 + + # Method 3: Best pairwise match for each text in list1 + pairwise_scores = [] + for t1 in text1_list: + best_match = 0 + for t2 in text2_list: + score = SequenceMatcher(None, t1, t2).ratio() + best_match = max(best_match, score) + pairwise_scores.append(best_match) + + pairwise_similarity = sum(pairwise_scores) / len(pairwise_scores) if pairwise_scores else 0 + + # Combine methods (weighted average) + combined = (seq_similarity * 0.3 + token_similarity * 0.3 + pairwise_similarity * 0.4) + + return combined + + @staticmethod + def texts_match( + ref_texts: List[str], + det_texts: List[str], + threshold: float = 0.5 + ) -> Tuple[bool, float]: + """ + Determine if texts match above a threshold. + + Args: + ref_texts: Text from reference logo + det_texts: Text from detected region + threshold: Minimum similarity to consider a match + + Returns: + Tuple of (is_match, similarity_score) + """ + if not ref_texts: + # Reference has no text - can't match on text + return (False, 0.0) + + if not det_texts: + # Reference has text but detection doesn't - no text match + return (False, 0.0) + + similarity = DetectLogosDETR.compute_text_similarity(ref_texts, det_texts) + return (similarity >= threshold, similarity) + + def find_best_match_hybrid( + self, + detected_embedding: torch.Tensor, + detected_image: np.ndarray, + reference_data: Dict[str, Dict[str, Any]], + clip_threshold: float = 0.70, + clip_threshold_with_text: float = 0.60, + clip_threshold_text_mismatch: float = 0.80, + text_similarity_threshold: float = 0.5, + margin: float = 0.05, + use_mean_similarity: bool = False, + ) -> Optional[Tuple[str, float, Dict[str, Any]]]: + """ + Find best match using hybrid text + CLIP approach. 
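# To make the threshold strategy described next concrete, a hedged sketch using the two
# static helpers above; the brand strings and the 0.65 CLIP score are illustrative values,
# and the thresholds are the defaults of find_best_match_hybrid:
ref_texts = ["netflix"]      # OCR text from the reference logo (hypothetical)
det_texts = ["netfl1x"]      # OCR text from the detected crop (hypothetical)

matched, text_sim = DetectLogosDETR.texts_match(ref_texts, det_texts, threshold=0.5)

clip_score = 0.65            # hypothetical CLIP cosine similarity
if matched:
    threshold_used = 0.60    # clip_threshold_with_text: matching text lets weaker visual evidence through
else:
    threshold_used = 0.80    # clip_threshold_text_mismatch: reference has text the crop does not show
accepted = clip_score >= threshold_used
# Here the OCR near-miss still counts as a text match, so 0.65 clears the relaxed 0.60 bar;
# without text agreement the same score would fail the 0.80 bar.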
+ + Strategy: + - If reference has text AND detection has matching text: + → Use lower CLIP threshold (text provides additional confidence) + - If reference has text but detection doesn't match: + → Use higher CLIP threshold (need more visual confidence) + - If reference has no text: + → Use standard CLIP threshold + + Args: + detected_embedding: CLIP embedding from detected logo region + detected_image: OpenCV image of the detected region (for text extraction) + reference_data: Dict mapping logo name to: + { + 'embeddings': List[torch.Tensor], # CLIP embeddings + 'texts': List[str], # Extracted text from reference + } + clip_threshold: Standard CLIP threshold for no-text references + clip_threshold_with_text: Lower threshold when text matches + clip_threshold_text_mismatch: Higher threshold when text expected but missing + text_similarity_threshold: Threshold for text matching + margin: Required margin between best and second-best + use_mean_similarity: Use mean vs max for multi-ref aggregation + + Returns: + Tuple of (label, clip_similarity, match_info) or None + match_info contains: text_matched, text_similarity, threshold_used + """ + if not reference_data: + return None + + # Extract text from detected region + detected_texts = self.extract_text(detected_image) + + # Calculate scores for all logos + logo_scores = [] + + for label, ref_info in reference_data.items(): + ref_embeddings = ref_info.get('embeddings', []) + ref_texts = ref_info.get('texts', []) + + if not ref_embeddings: + continue + + # Calculate CLIP similarity + similarities = [] + for ref_emb in ref_embeddings: + sim = self.compare_embeddings(detected_embedding, ref_emb) + similarities.append(sim) + + if use_mean_similarity: + clip_score = sum(similarities) / len(similarities) + else: + clip_score = max(similarities) + + # Determine text match status and appropriate threshold + has_ref_text = len(ref_texts) > 0 + text_matched, text_sim = self.texts_match( + ref_texts, detected_texts, text_similarity_threshold + ) + + if has_ref_text: + if text_matched: + # Text matches - use lower threshold, boost confidence + threshold_used = clip_threshold_with_text + match_type = "text_match" + else: + # Reference has text but detection doesn't match + # Require higher CLIP threshold + threshold_used = clip_threshold_text_mismatch + match_type = "text_mismatch" + else: + # No text in reference - standard matching + threshold_used = clip_threshold + match_type = "no_text" + text_sim = 0.0 + + # Check if CLIP score meets the appropriate threshold + if clip_score >= threshold_used: + logo_scores.append({ + 'label': label, + 'clip_score': clip_score, + 'text_matched': text_matched, + 'text_similarity': text_sim, + 'threshold_used': threshold_used, + 'match_type': match_type, + 'has_ref_text': has_ref_text, + }) + + if not logo_scores: + return None + + # Sort by CLIP score descending + logo_scores.sort(key=lambda x: x['clip_score'], reverse=True) + + best = logo_scores[0] + + # Check margin against second-best + if margin > 0 and len(logo_scores) > 1: + second_best_score = logo_scores[1]['clip_score'] + if best['clip_score'] - second_best_score < margin: + return None + + match_info = { + 'text_matched': best['text_matched'], + 'text_similarity': best['text_similarity'], + 'threshold_used': best['threshold_used'], + 'match_type': best['match_type'], + 'has_ref_text': best['has_ref_text'], + 'detected_texts': detected_texts, + } + + return (best['label'], best['clip_score'], match_info) + + def prepare_reference_data_hybrid( + self, + 
reference_images: Dict[str, List[np.ndarray]], + text_min_confidence: float = 0.3, + ) -> Dict[str, Dict[str, Any]]: + """ + Prepare reference data for hybrid matching by computing embeddings and extracting text. + + Args: + reference_images: Dict mapping logo name to list of reference images (OpenCV BGR) + text_min_confidence: Minimum confidence for text extraction + + Returns: + Dict mapping logo name to {'embeddings': [...], 'texts': [...]} + """ + reference_data = {} + + for logo_name, images in reference_images.items(): + embeddings = [] + all_texts = set() + + for img in images: + # Compute CLIP embedding + emb = self.get_embedding(img) + embeddings.append(emb) + + # Extract text + texts = self.extract_text(img, text_min_confidence) + all_texts.update(texts) + + reference_data[logo_name] = { + 'embeddings': embeddings, + 'texts': list(all_texts), + } + + if all_texts: + self.logger.debug(f"Reference '{logo_name}' has text: {all_texts}") + + return reference_data \ No newline at end of file diff --git a/run_hybrid_test.sh b/run_hybrid_test.sh new file mode 100755 index 0000000..99843d9 --- /dev/null +++ b/run_hybrid_test.sh @@ -0,0 +1,168 @@ +#!/bin/bash +# +# Test the hybrid text+CLIP matching approach for logo detection. +# +# This approach uses text recognition to improve logo matching: +# - If reference logo has text and detection matches it: use lower CLIP threshold +# - If reference logo has text but detection doesn't match: use higher CLIP threshold +# - If reference logo has no text: use standard CLIP threshold +# +# Usage: +# ./run_hybrid_test.sh +# + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +OUTPUT_FILE="${SCRIPT_DIR}/test_results/hybrid_matching_results.txt" + +# Model - baseline CLIP +MODEL="openai/clip-vit-large-patch14" + +# Fixed parameters +NUM_LOGOS=20 +REFS_PER_LOGO=10 +POSITIVE_SAMPLES=20 +NEGATIVE_SAMPLES=100 +SEED=42 + +# Create output directory if needed +mkdir -p "${SCRIPT_DIR}/test_results" + +# Clear output file and write header +cat > "$OUTPUT_FILE" << EOF +Hybrid Text+CLIP Matching Test Results +====================================== +Date: $(date) + +Model: ${MODEL} + +Fixed Parameters: + Number of logo brands: ${NUM_LOGOS} + Refs per logo: ${REFS_PER_LOGO} + Positive samples/logo: ${POSITIVE_SAMPLES} + Negative samples/logo: ${NEGATIVE_SAMPLES} + Seed: ${SEED} + +EOF + +echo "Hybrid Text+CLIP Matching Test" +echo "===============================" +echo "Model: ${MODEL}" +echo "" + +# Test 1: Compare hybrid vs multi-ref baseline +echo "=== Test 1: Multi-ref baseline (for comparison) ===" +echo "" >> "$OUTPUT_FILE" +echo "=== BASELINE: Multi-ref (max) at threshold 0.70 ===" >> "$OUTPUT_FILE" + +uv run python "$SCRIPT_DIR/test_logo_detection.py" \ + --num-logos $NUM_LOGOS \ + --refs-per-logo $REFS_PER_LOGO \ + --positive-samples $POSITIVE_SAMPLES \ + --negative-samples $NEGATIVE_SAMPLES \ + --matching-method multi-ref \ + --min-matching-refs 1 \ + --use-max-similarity \ + --threshold 0.70 \ + --margin 0.05 \ + --seed $SEED \ + --embedding-model "$MODEL" \ + --output-file "$OUTPUT_FILE" \ + --no-cache + +echo "" + +# Test 2: Hybrid with default thresholds +echo "=== Test 2: Hybrid with default thresholds ===" +echo "" >> "$OUTPUT_FILE" +echo "=== HYBRID: default thresholds (0.70/0.60/0.80) ===" >> "$OUTPUT_FILE" + +uv run python "$SCRIPT_DIR/test_logo_detection.py" \ + --num-logos $NUM_LOGOS \ + --refs-per-logo $REFS_PER_LOGO \ + --positive-samples $POSITIVE_SAMPLES \ + --negative-samples $NEGATIVE_SAMPLES \ + --matching-method hybrid \ + 
--threshold 0.70 \ + --hybrid-text-threshold 0.60 \ + --hybrid-no-text-threshold 0.80 \ + --text-similarity-threshold 0.5 \ + --margin 0.05 \ + --seed $SEED \ + --embedding-model "$MODEL" \ + --output-file "$OUTPUT_FILE" \ + --no-cache + +echo "" + +# Test 3: Hybrid with more aggressive text bonus +echo "=== Test 3: Hybrid with lower text-match threshold ===" +echo "" >> "$OUTPUT_FILE" +echo "=== HYBRID: aggressive text bonus (0.70/0.55/0.80) ===" >> "$OUTPUT_FILE" + +uv run python "$SCRIPT_DIR/test_logo_detection.py" \ + --num-logos $NUM_LOGOS \ + --refs-per-logo $REFS_PER_LOGO \ + --positive-samples $POSITIVE_SAMPLES \ + --negative-samples $NEGATIVE_SAMPLES \ + --matching-method hybrid \ + --threshold 0.70 \ + --hybrid-text-threshold 0.55 \ + --hybrid-no-text-threshold 0.80 \ + --text-similarity-threshold 0.5 \ + --margin 0.05 \ + --seed $SEED \ + --embedding-model "$MODEL" \ + --output-file "$OUTPUT_FILE" \ + --no-cache + +echo "" + +# Test 4: Hybrid with stricter text mismatch penalty +echo "=== Test 4: Hybrid with stricter text mismatch penalty ===" +echo "" >> "$OUTPUT_FILE" +echo "=== HYBRID: strict mismatch (0.70/0.60/0.85) ===" >> "$OUTPUT_FILE" + +uv run python "$SCRIPT_DIR/test_logo_detection.py" \ + --num-logos $NUM_LOGOS \ + --refs-per-logo $REFS_PER_LOGO \ + --positive-samples $POSITIVE_SAMPLES \ + --negative-samples $NEGATIVE_SAMPLES \ + --matching-method hybrid \ + --threshold 0.70 \ + --hybrid-text-threshold 0.60 \ + --hybrid-no-text-threshold 0.85 \ + --text-similarity-threshold 0.5 \ + --margin 0.05 \ + --seed $SEED \ + --embedding-model "$MODEL" \ + --output-file "$OUTPUT_FILE" \ + --no-cache + +echo "" + +# Test 5: Hybrid with lower text similarity threshold (more lenient OCR matching) +echo "=== Test 5: Hybrid with lenient text matching ===" +echo "" >> "$OUTPUT_FILE" +echo "=== HYBRID: lenient text matching (text_sim=0.4) ===" >> "$OUTPUT_FILE" + +uv run python "$SCRIPT_DIR/test_logo_detection.py" \ + --num-logos $NUM_LOGOS \ + --refs-per-logo $REFS_PER_LOGO \ + --positive-samples $POSITIVE_SAMPLES \ + --negative-samples $NEGATIVE_SAMPLES \ + --matching-method hybrid \ + --threshold 0.70 \ + --hybrid-text-threshold 0.60 \ + --hybrid-no-text-threshold 0.80 \ + --text-similarity-threshold 0.4 \ + --margin 0.05 \ + --seed $SEED \ + --embedding-model "$MODEL" \ + --output-file "$OUTPUT_FILE" \ + --no-cache + +echo "" +echo "=======================================" +echo "Tests complete!" +echo "Results saved to: $OUTPUT_FILE" +echo "=======================================" diff --git a/run_preprocess_test.sh b/run_preprocess_test.sh new file mode 100755 index 0000000..ddcbf72 --- /dev/null +++ b/run_preprocess_test.sh @@ -0,0 +1,149 @@ +#!/bin/bash +# +# Test different image preprocessing modes to determine if they improve +# CLIP embedding accuracy for logo matching. 
+# +# Preprocessing modes tested: +# - default: CLIP's default (resize shortest edge + center crop) +# - letterbox: Pad to square with black bars, preserving aspect ratio +# - stretch: Stretch to square (distorts aspect ratio) +# +# Usage: +# ./run_preprocess_test.sh +# + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +OUTPUT_FILE="${SCRIPT_DIR}/test_results/preprocessing_comparison.txt" + +# Model - baseline CLIP (testing preprocessing effect on standard model) +MODEL="openai/clip-vit-large-patch14" + +# Fixed parameters (same as refs_per_logo test for comparability) +NUM_LOGOS=20 +REFS_PER_LOGO=10 +POSITIVE_SAMPLES=20 +NEGATIVE_SAMPLES=100 +MIN_MATCHING_REFS=1 +THRESHOLD=0.70 +MARGIN=0.05 +SEED=42 + +# Preprocessing modes to test +MODES="default letterbox stretch" + +# Create output directory if needed +mkdir -p "${SCRIPT_DIR}/test_results" + +# Clear output file and write header +cat > "$OUTPUT_FILE" << EOF +Image Preprocessing Comparison Test +==================================== +Date: $(date) + +Model: ${MODEL} +Method: multi-ref (max) + +Fixed Parameters: + Number of logo brands: ${NUM_LOGOS} + Refs per logo: ${REFS_PER_LOGO} + Similarity threshold: ${THRESHOLD} + Margin: ${MARGIN} + Min matching refs: ${MIN_MATCHING_REFS} + Positive samples/logo: ${POSITIVE_SAMPLES} + Negative samples/logo: ${NEGATIVE_SAMPLES} + Seed: ${SEED} + +Testing preprocessing modes: ${MODES} + +EOF + +echo "Image Preprocessing Comparison Test" +echo "====================================" +echo "Model: ${MODEL}" +echo "Testing preprocessing modes: ${MODES}" +echo "" + +# Results table header +echo "Results Summary:" >> "$OUTPUT_FILE" +echo "----------------" >> "$OUTPUT_FILE" +printf "%-12s %8s %8s %8s %8s %8s %8s\n" "Mode" "TP" "FP" "FN" "Prec" "Recall" "F1" >> "$OUTPUT_FILE" +echo "------------------------------------------------------------------------" >> "$OUTPUT_FILE" + +# Track best result +BEST_F1=0 +BEST_MODE="default" + +for MODE in ${MODES}; do + echo "=== Testing preprocess_mode=${MODE} ===" + + # Clear cache to ensure fresh embeddings with new preprocessing + rm -f "${SCRIPT_DIR}/.embedding_cache.pkl" + + # Run test and capture output + OUTPUT=$(uv run python "$SCRIPT_DIR/test_logo_detection.py" \ + --num-logos $NUM_LOGOS \ + --refs-per-logo $REFS_PER_LOGO \ + --positive-samples $POSITIVE_SAMPLES \ + --negative-samples $NEGATIVE_SAMPLES \ + --matching-method multi-ref \ + --min-matching-refs $MIN_MATCHING_REFS \ + --use-max-similarity \ + --threshold $THRESHOLD \ + --margin $MARGIN \ + --seed $SEED \ + --embedding-model "$MODEL" \ + --preprocess-mode "$MODE" \ + --no-cache \ + 2>&1) + + # Extract metrics + TP=$(echo "${OUTPUT}" | grep "True Positives" | grep -oE "[0-9]+" | head -1) + FP=$(echo "${OUTPUT}" | grep "False Positives" | grep -oE "[0-9]+" | head -1) + FN=$(echo "${OUTPUT}" | grep "False Negatives" | grep -oE "[0-9]+" | head -1) + PREC=$(echo "${OUTPUT}" | grep "Precision:" | grep -oE "[0-9]+\.[0-9]+%" | head -1) + RECALL=$(echo "${OUTPUT}" | grep "Recall:" | grep -oE "[0-9]+\.[0-9]+%" | head -1) + F1=$(echo "${OUTPUT}" | grep "F1 Score:" | grep -oE "[0-9]+\.[0-9]+%" | head -1) + + # Print to console + echo " TP: ${TP}, FP: ${FP}, FN: ${FN}" + echo " Precision: ${PREC}, Recall: ${RECALL}, F1: ${F1}" + echo "" + + # Add to results table + printf "%-12s %8s %8s %8s %8s %8s %8s\n" "${MODE}" "${TP}" "${FP}" "${FN}" "${PREC}" "${RECALL}" "${F1}" >> "$OUTPUT_FILE" + + # Track best F1 + F1_NUM=$(echo "${F1}" | tr -d '%') + if [ -n "$F1_NUM" ]; then + BETTER=$(echo "${F1_NUM} > 
${BEST_F1}" | bc -l 2>/dev/null || echo "0") + if [ "$BETTER" = "1" ]; then + BEST_F1="${F1_NUM}" + BEST_MODE="${MODE}" + fi + fi + + # Also append full output for this test + echo "" >> "$OUTPUT_FILE" + echo "======================================================================" >> "$OUTPUT_FILE" + echo "DETAILED RESULTS: preprocess_mode=${MODE}" >> "$OUTPUT_FILE" + echo "======================================================================" >> "$OUTPUT_FILE" + echo "${OUTPUT}" | grep -A 50 "Configuration:" | head -30 >> "$OUTPUT_FILE" + echo "" >> "$OUTPUT_FILE" +done + +# Summary +echo "------------------------------------------------------------------------" >> "$OUTPUT_FILE" +echo "" >> "$OUTPUT_FILE" +echo "BEST PREPROCESSING MODE: ${BEST_MODE} (F1 = ${BEST_F1}%)" >> "$OUTPUT_FILE" +echo "" >> "$OUTPUT_FILE" +echo "Notes:" >> "$OUTPUT_FILE" +echo " - default: CLIP's standard preprocessing (resize shortest edge + center crop)" >> "$OUTPUT_FILE" +echo " - letterbox: Pads image to square with black bars, preserving aspect ratio" >> "$OUTPUT_FILE" +echo " - stretch: Resizes image to square, distorting aspect ratio" >> "$OUTPUT_FILE" +echo "" >> "$OUTPUT_FILE" + +echo "=======================================" +echo "BEST: preprocess_mode=${BEST_MODE} (F1 = ${BEST_F1}%)" +echo "=======================================" +echo "" +echo "Results saved to: $OUTPUT_FILE" diff --git a/test_logo_detection.py b/test_logo_detection.py index 7326394..ec3580e 100755 --- a/test_logo_detection.py +++ b/test_logo_detection.py @@ -18,7 +18,7 @@ import random import sqlite3 import sys from pathlib import Path -from typing import Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple import cv2 import torch @@ -243,11 +243,12 @@ def main(): parser.add_argument( "--matching-method", type=str, - choices=["simple", "margin", "multi-ref"], + choices=["simple", "margin", "multi-ref", "hybrid"], default="margin", help="Matching method: 'simple' returns all matches above threshold, " "'margin' requires confidence margin over 2nd best, " - "'multi-ref' aggregates scores across reference images (default: margin)", + "'multi-ref' aggregates scores across reference images, " + "'hybrid' combines text recognition with CLIP (default: margin)", ) parser.add_argument( "--min-matching-refs", @@ -260,6 +261,25 @@ def main(): action="store_true", help="For 'multi-ref' method: use max similarity instead of mean across references", ) + # Hybrid method arguments + parser.add_argument( + "--hybrid-text-threshold", + type=float, + default=0.60, + help="For 'hybrid' method: CLIP threshold when text matches (default: 0.60)", + ) + parser.add_argument( + "--hybrid-no-text-threshold", + type=float, + default=0.80, + help="For 'hybrid' method: CLIP threshold when text expected but not found (default: 0.80)", + ) + parser.add_argument( + "--text-similarity-threshold", + type=float, + default=0.5, + help="For 'hybrid' method: minimum text similarity to consider a match (default: 0.5)", + ) parser.add_argument( "-v", "--verbose", action="store_true", @@ -286,6 +306,14 @@ def main(): default=None, help="Append results summary to this file (no progress output, just results)", ) + parser.add_argument( + "--preprocess-mode", + type=str, + choices=["default", "letterbox", "stretch"], + default="default", + help="Image preprocessing mode for CLIP: 'default' (resize+center crop), " + "'letterbox' (pad to square with black bars), 'stretch' (distort to square)", + ) args = parser.parse_args() logger = 
setup_logging(args.verbose) @@ -315,12 +343,23 @@ def main(): # Initialize detector logger.info(f"Initializing logo detector with embedding model: {args.embedding_model}") + if args.preprocess_mode != "default": + logger.info(f"Using preprocessing mode: {args.preprocess_mode}") detector = DetectLogosDETR( logger=logger, detr_threshold=args.detr_threshold, embedding_model=args.embedding_model, + preprocess_mode=args.preprocess_mode, ) + # Initialize text detector for hybrid method + text_detector = None + if args.matching_method == "hybrid": + logger.info("Initializing text detector for hybrid matching...") + from text_recognition import DetectText + text_detector = DetectText(logger=logger, threshold=0.3) + detector.set_text_detector(text_detector) + # Load ground truth (both mappings) logger.info("Loading ground truth from database...") image_to_logos, logo_to_images = get_ground_truth(db_path) @@ -338,10 +377,15 @@ def main(): multi_ref_embeddings: Dict[str, List[torch.Tensor]] = {} # List for margin-based matching: (logo_name, embedding) tuples reference_embeddings: List[Tuple[str, torch.Tensor]] = [] + # Dict for hybrid matching: logo_name -> {'embeddings': [...], 'texts': [...]} + hybrid_reference_data: Dict[str, Dict[str, Any]] = {} total_refs = 0 + logos_with_text = 0 for logo_name, ref_filenames in tqdm(sampled_logos.items(), desc="Reference logos"): multi_ref_embeddings[logo_name] = [] + if args.matching_method == "hybrid": + hybrid_reference_data[logo_name] = {'embeddings': [], 'texts': set()} for ref_filename in ref_filenames: ref_path = reference_dir / ref_filename @@ -354,11 +398,15 @@ def main(): cache_key = f"ref:{ref_filename}" embedding = cache.get(cache_key) if cache else None - if embedding is None: + # Load image if needed (for embedding or text extraction) + img = None + if embedding is None or args.matching_method == "hybrid": img = load_image(ref_path) if img is None: logger.warning(f"Failed to load reference logo: {ref_path}") continue + + if embedding is None: embedding = detector.get_embedding(img) if cache: cache.put(cache_key, embedding) @@ -367,7 +415,21 @@ def main(): reference_embeddings.append((logo_name, embedding)) total_refs += 1 + # Extract text for hybrid method + if args.matching_method == "hybrid" and img is not None: + hybrid_reference_data[logo_name]['embeddings'].append(embedding) + texts = detector.extract_text(img, min_confidence=0.3) + hybrid_reference_data[logo_name]['texts'].update(texts) + + # Convert text set to list for hybrid data + if args.matching_method == "hybrid": + hybrid_reference_data[logo_name]['texts'] = list(hybrid_reference_data[logo_name]['texts']) + if hybrid_reference_data[logo_name]['texts']: + logos_with_text += 1 + logger.info(f"Computed {total_refs} embeddings for {len(sampled_logos)} logos") + if args.matching_method == "hybrid": + logger.info(f"Extracted text from {logos_with_text}/{len(sampled_logos)} reference logos") # Build test set: for each logo, sample positive and negative images logger.info(f"Sampling test images: {args.positive_samples} positive, {args.negative_samples} negative per logo...") @@ -442,17 +504,26 @@ def main(): cache_key = f"det:{test_filename}" cached_detections = cache.get(cache_key) if cache else None + # For hybrid matching, we always need the original image for text extraction + test_img = None + if args.matching_method == "hybrid": + test_img = load_image(test_path) + if test_img is None: + logger.warning(f"Failed to load test image: {test_path}") + continue + if cached_detections is not 
None: # Cached detections contain serialized box data and embeddings detections = cached_detections else: # Load and detect - img = load_image(test_path) - if img is None: - logger.warning(f"Failed to load test image: {test_path}") - continue + if test_img is None: + test_img = load_image(test_path) + if test_img is None: + logger.warning(f"Failed to load test image: {test_path}") + continue - detections = detector.detect(img) + detections = detector.detect(test_img) # Cache the detections if cache: @@ -549,7 +620,7 @@ def main(): "correct": is_correct, }) - else: # multi-ref + elif args.matching_method == "multi-ref": # Multi-ref matching: aggregates scores across reference images match_result = detector.find_best_match_multi_ref( detection["embedding"], @@ -580,6 +651,50 @@ def main(): "correct": is_correct, }) + else: # hybrid + # Hybrid matching: combines text recognition with CLIP + # Extract crop from original image for text extraction + box = detection["box"] + crop = test_img[ + int(box["ymin"]):int(box["ymax"]), + int(box["xmin"]):int(box["xmax"]) + ] + + match_result = detector.find_best_match_hybrid( + detected_embedding=detection["embedding"], + detected_image=crop, + reference_data=hybrid_reference_data, + clip_threshold=args.threshold, + clip_threshold_with_text=args.hybrid_text_threshold, + clip_threshold_text_mismatch=args.hybrid_no_text_threshold, + text_similarity_threshold=args.text_similarity_threshold, + margin=args.margin, + use_mean_similarity=not args.use_max_similarity, + ) + if match_result: + label, similarity, match_info = match_result + matched_logos.add(label) + + is_correct = label in expected_logos + if is_correct: + true_positives += 1 + if args.similarity_details: + similarity_details["true_positive_sims"].append(similarity) + else: + false_positives += 1 + if args.similarity_details: + similarity_details["false_positive_sims"].append(similarity) + + results.append({ + "test_image": test_filename, + "matched_logo": label, + "similarity": similarity, + "correct": is_correct, + "text_matched": match_info.get("text_matched", False), + "text_similarity": match_info.get("text_similarity", 0), + "match_type": match_info.get("match_type", "unknown"), + }) + # Count missed detections (false negatives) missed = expected_logos - matched_logos false_negatives += len(missed) @@ -625,12 +740,18 @@ def main(): print(f" Test images processed: {len(test_images)}") print(f" CLIP similarity threshold: {args.threshold}") print(f" DETR confidence threshold: {args.detr_threshold}") + print(f" Preprocess mode: {args.preprocess_mode}") print(f" Matching method: {args.matching_method}") - if args.matching_method in ("margin", "multi-ref"): + if args.matching_method in ("margin", "multi-ref", "hybrid"): print(f" Matching margin: {args.margin}") if args.matching_method == "multi-ref": print(f" Min matching refs: {args.min_matching_refs}") print(f" Similarity aggregation: {'max' if args.use_max_similarity else 'mean'}") + if args.matching_method == "hybrid": + print(f" CLIP threshold (text match): {args.hybrid_text_threshold}") + print(f" CLIP threshold (no text): {args.hybrid_no_text_threshold}") + print(f" Text similarity threshold: {args.text_similarity_threshold}") + print(f" Refs with text: {logos_with_text}/{len(sampled_logos)}") if args.seed is not None: print(f" Random seed: {args.seed}") @@ -818,9 +939,14 @@ def write_results_to_file( method_desc = "Simple (all matches above threshold)" elif args.matching_method == "margin": method_desc = f"Margin-based 
(margin={args.margin})" - else: # multi-ref + elif args.matching_method == "multi-ref": agg = "max" if args.use_max_similarity else "mean" method_desc = f"Multi-ref ({agg}, min_refs={args.min_matching_refs}, margin={args.margin})" + else: # hybrid + method_desc = ( + f"Hybrid (text+CLIP, text_thresh={args.hybrid_text_threshold}, " + f"no_text_thresh={args.hybrid_no_text_threshold}, margin={args.margin})" + ) lines = [ "=" * 70, @@ -832,6 +958,7 @@ def write_results_to_file( "", "Configuration:", f" Embedding model: {args.embedding_model}", + f" Preprocess mode: {args.preprocess_mode}", f" Reference logos: {num_logos}", f" Refs per logo: {args.refs_per_logo}", f" Total reference embeddings:{total_refs}",
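# For reference, the Precision / Recall / F1 figures that the summary output prints and the
# shell scripts scrape follow the standard definitions; a minimal sketch with made-up counts:
def precision_recall_f1(tp: int, fp: int, fn: int) -> tuple:
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return precision, recall, f1

print(precision_recall_f1(tp=350, fp=40, fn=50))
# -> roughly (0.897, 0.875, 0.886) for these hypothetical counts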