Add hybrid text+CLIP matching and image preprocessing

Hybrid matching combines text recognition with CLIP similarity:
- If reference logo has text and detection matches: lower CLIP threshold
- If reference has text but detection doesn't match: higher threshold
- If reference has no text: standard threshold

Image preprocessing adds letterbox/stretch modes for CLIP input to
preserve aspect ratio instead of center cropping.

New files:
- run_hybrid_test.sh: Test hybrid matching configurations
- run_preprocess_test.sh: Compare preprocessing modes

Changes to logo_detection_detr.py:
- Add preprocess_mode parameter (default/letterbox/stretch)
- Add set_text_detector() for hybrid matching
- Add extract_text() using EasyOCR
- Add compute_text_similarity() with fuzzy matching
- Add find_best_match_hybrid() with tiered thresholds

Changes to test_logo_detection.py:
- Add --matching-method hybrid option
- Add --preprocess-mode option
- Add hybrid threshold arguments
This commit is contained in:
Rick McEwen
2026-01-07 15:09:09 -05:00
parent 78f46f04bf
commit 49f982611a
4 changed files with 817 additions and 13 deletions

View File

@@ -18,7 +18,7 @@ import random
import sqlite3
import sys
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple
import cv2
import torch
@@ -243,11 +243,12 @@ def main():
parser.add_argument(
"--matching-method",
type=str,
choices=["simple", "margin", "multi-ref"],
choices=["simple", "margin", "multi-ref", "hybrid"],
default="margin",
help="Matching method: 'simple' returns all matches above threshold, "
"'margin' requires confidence margin over 2nd best, "
"'multi-ref' aggregates scores across reference images (default: margin)",
"'multi-ref' aggregates scores across reference images, "
"'hybrid' combines text recognition with CLIP (default: margin)",
)
parser.add_argument(
"--min-matching-refs",
@@ -260,6 +261,25 @@ def main():
action="store_true",
help="For 'multi-ref' method: use max similarity instead of mean across references",
)
# Hybrid method arguments
parser.add_argument(
"--hybrid-text-threshold",
type=float,
default=0.60,
help="For 'hybrid' method: CLIP threshold when text matches (default: 0.60)",
)
parser.add_argument(
"--hybrid-no-text-threshold",
type=float,
default=0.80,
help="For 'hybrid' method: CLIP threshold when text expected but not found (default: 0.80)",
)
parser.add_argument(
"--text-similarity-threshold",
type=float,
default=0.5,
help="For 'hybrid' method: minimum text similarity to consider a match (default: 0.5)",
)
parser.add_argument(
"-v", "--verbose",
action="store_true",
@@ -286,6 +306,14 @@ def main():
default=None,
help="Append results summary to this file (no progress output, just results)",
)
parser.add_argument(
"--preprocess-mode",
type=str,
choices=["default", "letterbox", "stretch"],
default="default",
help="Image preprocessing mode for CLIP: 'default' (resize+center crop), "
"'letterbox' (pad to square with black bars), 'stretch' (distort to square)",
)
args = parser.parse_args()
logger = setup_logging(args.verbose)
@@ -315,12 +343,23 @@ def main():
# Initialize detector
logger.info(f"Initializing logo detector with embedding model: {args.embedding_model}")
if args.preprocess_mode != "default":
logger.info(f"Using preprocessing mode: {args.preprocess_mode}")
detector = DetectLogosDETR(
logger=logger,
detr_threshold=args.detr_threshold,
embedding_model=args.embedding_model,
preprocess_mode=args.preprocess_mode,
)
# Initialize text detector for hybrid method
text_detector = None
if args.matching_method == "hybrid":
logger.info("Initializing text detector for hybrid matching...")
from text_recognition import DetectText
text_detector = DetectText(logger=logger, threshold=0.3)
detector.set_text_detector(text_detector)
# Load ground truth (both mappings)
logger.info("Loading ground truth from database...")
image_to_logos, logo_to_images = get_ground_truth(db_path)
@@ -338,10 +377,15 @@ def main():
multi_ref_embeddings: Dict[str, List[torch.Tensor]] = {}
# List for margin-based matching: (logo_name, embedding) tuples
reference_embeddings: List[Tuple[str, torch.Tensor]] = []
# Dict for hybrid matching: logo_name -> {'embeddings': [...], 'texts': [...]}
hybrid_reference_data: Dict[str, Dict[str, Any]] = {}
total_refs = 0
logos_with_text = 0
for logo_name, ref_filenames in tqdm(sampled_logos.items(), desc="Reference logos"):
multi_ref_embeddings[logo_name] = []
if args.matching_method == "hybrid":
hybrid_reference_data[logo_name] = {'embeddings': [], 'texts': set()}
for ref_filename in ref_filenames:
ref_path = reference_dir / ref_filename
@@ -354,11 +398,15 @@ def main():
cache_key = f"ref:{ref_filename}"
embedding = cache.get(cache_key) if cache else None
if embedding is None:
# Load image if needed (for embedding or text extraction)
img = None
if embedding is None or args.matching_method == "hybrid":
img = load_image(ref_path)
if img is None:
logger.warning(f"Failed to load reference logo: {ref_path}")
continue
if embedding is None:
embedding = detector.get_embedding(img)
if cache:
cache.put(cache_key, embedding)
@@ -367,7 +415,21 @@ def main():
reference_embeddings.append((logo_name, embedding))
total_refs += 1
# Extract text for hybrid method
if args.matching_method == "hybrid" and img is not None:
hybrid_reference_data[logo_name]['embeddings'].append(embedding)
texts = detector.extract_text(img, min_confidence=0.3)
hybrid_reference_data[logo_name]['texts'].update(texts)
# Convert text set to list for hybrid data
if args.matching_method == "hybrid":
hybrid_reference_data[logo_name]['texts'] = list(hybrid_reference_data[logo_name]['texts'])
if hybrid_reference_data[logo_name]['texts']:
logos_with_text += 1
logger.info(f"Computed {total_refs} embeddings for {len(sampled_logos)} logos")
if args.matching_method == "hybrid":
logger.info(f"Extracted text from {logos_with_text}/{len(sampled_logos)} reference logos")
# Build test set: for each logo, sample positive and negative images
logger.info(f"Sampling test images: {args.positive_samples} positive, {args.negative_samples} negative per logo...")
@@ -442,17 +504,26 @@ def main():
cache_key = f"det:{test_filename}"
cached_detections = cache.get(cache_key) if cache else None
# For hybrid matching, we always need the original image for text extraction
test_img = None
if args.matching_method == "hybrid":
test_img = load_image(test_path)
if test_img is None:
logger.warning(f"Failed to load test image: {test_path}")
continue
if cached_detections is not None:
# Cached detections contain serialized box data and embeddings
detections = cached_detections
else:
# Load and detect
img = load_image(test_path)
if img is None:
logger.warning(f"Failed to load test image: {test_path}")
continue
if test_img is None:
test_img = load_image(test_path)
if test_img is None:
logger.warning(f"Failed to load test image: {test_path}")
continue
detections = detector.detect(img)
detections = detector.detect(test_img)
# Cache the detections
if cache:
@@ -549,7 +620,7 @@ def main():
"correct": is_correct,
})
else: # multi-ref
elif args.matching_method == "multi-ref":
# Multi-ref matching: aggregates scores across reference images
match_result = detector.find_best_match_multi_ref(
detection["embedding"],
@@ -580,6 +651,50 @@ def main():
"correct": is_correct,
})
else: # hybrid
# Hybrid matching: combines text recognition with CLIP
# Extract crop from original image for text extraction
box = detection["box"]
crop = test_img[
int(box["ymin"]):int(box["ymax"]),
int(box["xmin"]):int(box["xmax"])
]
match_result = detector.find_best_match_hybrid(
detected_embedding=detection["embedding"],
detected_image=crop,
reference_data=hybrid_reference_data,
clip_threshold=args.threshold,
clip_threshold_with_text=args.hybrid_text_threshold,
clip_threshold_text_mismatch=args.hybrid_no_text_threshold,
text_similarity_threshold=args.text_similarity_threshold,
margin=args.margin,
use_mean_similarity=not args.use_max_similarity,
)
if match_result:
label, similarity, match_info = match_result
matched_logos.add(label)
is_correct = label in expected_logos
if is_correct:
true_positives += 1
if args.similarity_details:
similarity_details["true_positive_sims"].append(similarity)
else:
false_positives += 1
if args.similarity_details:
similarity_details["false_positive_sims"].append(similarity)
results.append({
"test_image": test_filename,
"matched_logo": label,
"similarity": similarity,
"correct": is_correct,
"text_matched": match_info.get("text_matched", False),
"text_similarity": match_info.get("text_similarity", 0),
"match_type": match_info.get("match_type", "unknown"),
})
# Count missed detections (false negatives)
missed = expected_logos - matched_logos
false_negatives += len(missed)
@@ -625,12 +740,18 @@ def main():
print(f" Test images processed: {len(test_images)}")
print(f" CLIP similarity threshold: {args.threshold}")
print(f" DETR confidence threshold: {args.detr_threshold}")
print(f" Preprocess mode: {args.preprocess_mode}")
print(f" Matching method: {args.matching_method}")
if args.matching_method in ("margin", "multi-ref"):
if args.matching_method in ("margin", "multi-ref", "hybrid"):
print(f" Matching margin: {args.margin}")
if args.matching_method == "multi-ref":
print(f" Min matching refs: {args.min_matching_refs}")
print(f" Similarity aggregation: {'max' if args.use_max_similarity else 'mean'}")
if args.matching_method == "hybrid":
print(f" CLIP threshold (text match): {args.hybrid_text_threshold}")
print(f" CLIP threshold (no text): {args.hybrid_no_text_threshold}")
print(f" Text similarity threshold: {args.text_similarity_threshold}")
print(f" Refs with text: {logos_with_text}/{len(sampled_logos)}")
if args.seed is not None:
print(f" Random seed: {args.seed}")
@@ -818,9 +939,14 @@ def write_results_to_file(
method_desc = "Simple (all matches above threshold)"
elif args.matching_method == "margin":
method_desc = f"Margin-based (margin={args.margin})"
else: # multi-ref
elif args.matching_method == "multi-ref":
agg = "max" if args.use_max_similarity else "mean"
method_desc = f"Multi-ref ({agg}, min_refs={args.min_matching_refs}, margin={args.margin})"
else: # hybrid
method_desc = (
f"Hybrid (text+CLIP, text_thresh={args.hybrid_text_threshold}, "
f"no_text_thresh={args.hybrid_no_text_threshold}, margin={args.margin})"
)
lines = [
"=" * 70,
@@ -832,6 +958,7 @@ def write_results_to_file(
"",
"Configuration:",
f" Embedding model: {args.embedding_model}",
f" Preprocess mode: {args.preprocess_mode}",
f" Reference logos: {num_logos}",
f" Refs per logo: {args.refs_per_logo}",
f" Total reference embeddings:{total_refs}",