Add hybrid text+CLIP matching and image preprocessing
Hybrid matching combines text recognition with CLIP similarity:
- If the reference logo has text and the detection matches it: lower CLIP threshold
- If the reference has text but the detection does not match it: higher threshold
- If the reference has no text: standard threshold

Image preprocessing adds letterbox/stretch modes for CLIP input to preserve
aspect ratio instead of center cropping.

New files:
- run_hybrid_test.sh: Test hybrid matching configurations
- run_preprocess_test.sh: Compare preprocessing modes

Changes to logo_detection_detr.py:
- Add preprocess_mode parameter (default/letterbox/stretch)
- Add set_text_detector() for hybrid matching
- Add extract_text() using EasyOCR
- Add compute_text_similarity() with fuzzy matching
- Add find_best_match_hybrid() with tiered thresholds

Changes to test_logo_detection.py:
- Add --matching-method hybrid option
- Add --preprocess-mode option
- Add hybrid threshold arguments
This commit is contained in:
@ -18,7 +18,7 @@ import random
|
||||
import sqlite3
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
@ -243,11 +243,12 @@ def main():
|
||||
parser.add_argument(
|
||||
"--matching-method",
|
||||
type=str,
|
||||
choices=["simple", "margin", "multi-ref"],
|
||||
choices=["simple", "margin", "multi-ref", "hybrid"],
|
||||
default="margin",
|
||||
help="Matching method: 'simple' returns all matches above threshold, "
|
||||
"'margin' requires confidence margin over 2nd best, "
|
||||
"'multi-ref' aggregates scores across reference images (default: margin)",
|
||||
"'multi-ref' aggregates scores across reference images, "
|
||||
"'hybrid' combines text recognition with CLIP (default: margin)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min-matching-refs",
|
||||
@ -260,6 +261,25 @@ def main():
|
||||
action="store_true",
|
||||
help="For 'multi-ref' method: use max similarity instead of mean across references",
|
||||
)
|
||||
# Hybrid method arguments
|
||||
parser.add_argument(
|
||||
"--hybrid-text-threshold",
|
||||
type=float,
|
||||
default=0.60,
|
||||
help="For 'hybrid' method: CLIP threshold when text matches (default: 0.60)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hybrid-no-text-threshold",
|
||||
type=float,
|
||||
default=0.80,
|
||||
help="For 'hybrid' method: CLIP threshold when text expected but not found (default: 0.80)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text-similarity-threshold",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="For 'hybrid' method: minimum text similarity to consider a match (default: 0.5)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v", "--verbose",
|
||||
action="store_true",
|
||||
@ -286,6 +306,14 @@ def main():
|
||||
default=None,
|
||||
help="Append results summary to this file (no progress output, just results)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preprocess-mode",
|
||||
type=str,
|
||||
choices=["default", "letterbox", "stretch"],
|
||||
default="default",
|
||||
help="Image preprocessing mode for CLIP: 'default' (resize+center crop), "
|
||||
"'letterbox' (pad to square with black bars), 'stretch' (distort to square)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
logger = setup_logging(args.verbose)
|
||||
@ -315,12 +343,23 @@ def main():
|
||||
|
||||
# Initialize detector
|
||||
logger.info(f"Initializing logo detector with embedding model: {args.embedding_model}")
|
||||
if args.preprocess_mode != "default":
|
||||
logger.info(f"Using preprocessing mode: {args.preprocess_mode}")
|
||||
detector = DetectLogosDETR(
|
||||
logger=logger,
|
||||
detr_threshold=args.detr_threshold,
|
||||
embedding_model=args.embedding_model,
|
||||
preprocess_mode=args.preprocess_mode,
|
||||
)
|
||||
|
||||
# Initialize text detector for hybrid method
|
||||
text_detector = None
|
||||
if args.matching_method == "hybrid":
|
||||
logger.info("Initializing text detector for hybrid matching...")
|
||||
from text_recognition import DetectText
|
||||
text_detector = DetectText(logger=logger, threshold=0.3)
|
||||
detector.set_text_detector(text_detector)
|
||||
|
||||
# Load ground truth (both mappings)
|
||||
logger.info("Loading ground truth from database...")
|
||||
image_to_logos, logo_to_images = get_ground_truth(db_path)
|
||||
@ -338,10 +377,15 @@ def main():
|
||||
multi_ref_embeddings: Dict[str, List[torch.Tensor]] = {}
|
||||
# List for margin-based matching: (logo_name, embedding) tuples
|
||||
reference_embeddings: List[Tuple[str, torch.Tensor]] = []
|
||||
# Dict for hybrid matching: logo_name -> {'embeddings': [...], 'texts': [...]}
|
||||
hybrid_reference_data: Dict[str, Dict[str, Any]] = {}
|
||||
total_refs = 0
|
||||
logos_with_text = 0
|
||||
|
||||
for logo_name, ref_filenames in tqdm(sampled_logos.items(), desc="Reference logos"):
|
||||
multi_ref_embeddings[logo_name] = []
|
||||
if args.matching_method == "hybrid":
|
||||
hybrid_reference_data[logo_name] = {'embeddings': [], 'texts': set()}
|
||||
|
||||
for ref_filename in ref_filenames:
|
||||
ref_path = reference_dir / ref_filename
|
||||
@ -354,11 +398,15 @@ def main():
|
||||
cache_key = f"ref:{ref_filename}"
|
||||
embedding = cache.get(cache_key) if cache else None
|
||||
|
||||
if embedding is None:
|
||||
# Load image if needed (for embedding or text extraction)
|
||||
img = None
|
||||
if embedding is None or args.matching_method == "hybrid":
|
||||
img = load_image(ref_path)
|
||||
if img is None:
|
||||
logger.warning(f"Failed to load reference logo: {ref_path}")
|
||||
continue
|
||||
|
||||
if embedding is None:
|
||||
embedding = detector.get_embedding(img)
|
||||
if cache:
|
||||
cache.put(cache_key, embedding)
|
||||
@ -367,7 +415,21 @@ def main():
|
||||
reference_embeddings.append((logo_name, embedding))
|
||||
total_refs += 1
|
||||
|
||||
# Extract text for hybrid method
|
||||
if args.matching_method == "hybrid" and img is not None:
|
||||
hybrid_reference_data[logo_name]['embeddings'].append(embedding)
|
||||
texts = detector.extract_text(img, min_confidence=0.3)
|
||||
hybrid_reference_data[logo_name]['texts'].update(texts)
|
||||
|
||||
# Convert text set to list for hybrid data
|
||||
if args.matching_method == "hybrid":
|
||||
hybrid_reference_data[logo_name]['texts'] = list(hybrid_reference_data[logo_name]['texts'])
|
||||
if hybrid_reference_data[logo_name]['texts']:
|
||||
logos_with_text += 1
|
||||
|
||||
logger.info(f"Computed {total_refs} embeddings for {len(sampled_logos)} logos")
|
||||
if args.matching_method == "hybrid":
|
||||
logger.info(f"Extracted text from {logos_with_text}/{len(sampled_logos)} reference logos")
|
||||
|
||||
# Build test set: for each logo, sample positive and negative images
|
||||
logger.info(f"Sampling test images: {args.positive_samples} positive, {args.negative_samples} negative per logo...")
|
||||
@ -442,17 +504,26 @@ def main():
|
||||
cache_key = f"det:{test_filename}"
|
||||
cached_detections = cache.get(cache_key) if cache else None
|
||||
|
||||
# For hybrid matching, we always need the original image for text extraction
|
||||
test_img = None
|
||||
if args.matching_method == "hybrid":
|
||||
test_img = load_image(test_path)
|
||||
if test_img is None:
|
||||
logger.warning(f"Failed to load test image: {test_path}")
|
||||
continue
|
||||
|
||||
if cached_detections is not None:
|
||||
# Cached detections contain serialized box data and embeddings
|
||||
detections = cached_detections
|
||||
else:
|
||||
# Load and detect
|
||||
img = load_image(test_path)
|
||||
if img is None:
|
||||
logger.warning(f"Failed to load test image: {test_path}")
|
||||
continue
|
||||
if test_img is None:
|
||||
test_img = load_image(test_path)
|
||||
if test_img is None:
|
||||
logger.warning(f"Failed to load test image: {test_path}")
|
||||
continue
|
||||
|
||||
detections = detector.detect(img)
|
||||
detections = detector.detect(test_img)
|
||||
|
||||
# Cache the detections
|
||||
if cache:
|
||||
@ -549,7 +620,7 @@ def main():
|
||||
"correct": is_correct,
|
||||
})
|
||||
|
||||
else: # multi-ref
|
||||
elif args.matching_method == "multi-ref":
|
||||
# Multi-ref matching: aggregates scores across reference images
|
||||
match_result = detector.find_best_match_multi_ref(
|
||||
detection["embedding"],
|
||||
@ -580,6 +651,50 @@ def main():
|
||||
"correct": is_correct,
|
||||
})
|
||||
|
||||
else: # hybrid
|
||||
# Hybrid matching: combines text recognition with CLIP
|
||||
# Extract crop from original image for text extraction
|
||||
box = detection["box"]
|
||||
crop = test_img[
|
||||
int(box["ymin"]):int(box["ymax"]),
|
||||
int(box["xmin"]):int(box["xmax"])
|
||||
]
|
||||
|
||||
match_result = detector.find_best_match_hybrid(
|
||||
detected_embedding=detection["embedding"],
|
||||
detected_image=crop,
|
||||
reference_data=hybrid_reference_data,
|
||||
clip_threshold=args.threshold,
|
||||
clip_threshold_with_text=args.hybrid_text_threshold,
|
||||
clip_threshold_text_mismatch=args.hybrid_no_text_threshold,
|
||||
text_similarity_threshold=args.text_similarity_threshold,
|
||||
margin=args.margin,
|
||||
use_mean_similarity=not args.use_max_similarity,
|
||||
)
|
||||
if match_result:
|
||||
label, similarity, match_info = match_result
|
||||
matched_logos.add(label)
|
||||
|
||||
is_correct = label in expected_logos
|
||||
if is_correct:
|
||||
true_positives += 1
|
||||
if args.similarity_details:
|
||||
similarity_details["true_positive_sims"].append(similarity)
|
||||
else:
|
||||
false_positives += 1
|
||||
if args.similarity_details:
|
||||
similarity_details["false_positive_sims"].append(similarity)
|
||||
|
||||
results.append({
|
||||
"test_image": test_filename,
|
||||
"matched_logo": label,
|
||||
"similarity": similarity,
|
||||
"correct": is_correct,
|
||||
"text_matched": match_info.get("text_matched", False),
|
||||
"text_similarity": match_info.get("text_similarity", 0),
|
||||
"match_type": match_info.get("match_type", "unknown"),
|
||||
})
|
||||
|
||||
# Count missed detections (false negatives)
|
||||
missed = expected_logos - matched_logos
|
||||
false_negatives += len(missed)
|
||||
@ -625,12 +740,18 @@ def main():
|
||||
print(f" Test images processed: {len(test_images)}")
|
||||
print(f" CLIP similarity threshold: {args.threshold}")
|
||||
print(f" DETR confidence threshold: {args.detr_threshold}")
|
||||
print(f" Preprocess mode: {args.preprocess_mode}")
|
||||
print(f" Matching method: {args.matching_method}")
|
||||
if args.matching_method in ("margin", "multi-ref"):
|
||||
if args.matching_method in ("margin", "multi-ref", "hybrid"):
|
||||
print(f" Matching margin: {args.margin}")
|
||||
if args.matching_method == "multi-ref":
|
||||
print(f" Min matching refs: {args.min_matching_refs}")
|
||||
print(f" Similarity aggregation: {'max' if args.use_max_similarity else 'mean'}")
|
||||
if args.matching_method == "hybrid":
|
||||
print(f" CLIP threshold (text match): {args.hybrid_text_threshold}")
|
||||
print(f" CLIP threshold (no text): {args.hybrid_no_text_threshold}")
|
||||
print(f" Text similarity threshold: {args.text_similarity_threshold}")
|
||||
print(f" Refs with text: {logos_with_text}/{len(sampled_logos)}")
|
||||
if args.seed is not None:
|
||||
print(f" Random seed: {args.seed}")
|
||||
|
||||
@ -818,9 +939,14 @@ def write_results_to_file(
|
||||
method_desc = "Simple (all matches above threshold)"
|
||||
elif args.matching_method == "margin":
|
||||
method_desc = f"Margin-based (margin={args.margin})"
|
||||
else: # multi-ref
|
||||
elif args.matching_method == "multi-ref":
|
||||
agg = "max" if args.use_max_similarity else "mean"
|
||||
method_desc = f"Multi-ref ({agg}, min_refs={args.min_matching_refs}, margin={args.margin})"
|
||||
else: # hybrid
|
||||
method_desc = (
|
||||
f"Hybrid (text+CLIP, text_thresh={args.hybrid_text_threshold}, "
|
||||
f"no_text_thresh={args.hybrid_no_text_threshold}, margin={args.margin})"
|
||||
)
|
||||
|
||||
lines = [
|
||||
"=" * 70,
|
||||
@ -832,6 +958,7 @@ def write_results_to_file(
|
||||
"",
|
||||
"Configuration:",
|
||||
f" Embedding model: {args.embedding_model}",
|
||||
f" Preprocess mode: {args.preprocess_mode}",
|
||||
f" Reference logos: {num_logos}",
|
||||
f" Refs per logo: {args.refs_per_logo}",
|
||||
f" Total reference embeddings:{total_refs}",
|
||||
|
||||
Reference in New Issue
Block a user