Add hybrid text+CLIP matching and image preprocessing
Hybrid matching combines text recognition with CLIP similarity:
- If reference logo has text and detection matches: lower CLIP threshold
- If reference has text but detection doesn't match: higher threshold
- If reference has no text: standard threshold

Image preprocessing adds letterbox/stretch modes for CLIP input to
preserve aspect ratio instead of center cropping.

New files:
- run_hybrid_test.sh: Test hybrid matching configurations
- run_preprocess_test.sh: Compare preprocessing modes

Changes to logo_detection_detr.py:
- Add preprocess_mode parameter (default/letterbox/stretch)
- Add set_text_detector() for hybrid matching
- Add extract_text() using EasyOCR
- Add compute_text_similarity() with fuzzy matching
- Add find_best_match_hybrid() with tiered thresholds

Changes to test_logo_detection.py:
- Add --matching-method hybrid option
- Add --preprocess-mode option
- Add hybrid threshold arguments
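The three tiers map onto the default thresholds introduced in find_best_match_hybrid() below (0.70 standard, 0.60 when OCR text corroborates the match, 0.80 when expected text is missing). A minimal sketch of that selection logic, using those defaults purely for illustration:

def pick_clip_threshold(ref_has_text: bool, text_matched: bool) -> float:
    # Defaults taken from find_best_match_hybrid() in this commit; illustrative only.
    if not ref_has_text:
        return 0.70  # standard threshold
    if text_matched:
        return 0.60  # OCR text agrees, so accept weaker visual similarity
    return 0.80      # text expected but not matched, so demand stronger visual similarity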
@@ -23,6 +23,7 @@ import cv2
import numpy as np
from pathlib import Path
from typing import List, Tuple, Dict, Optional, Any
from difflib import SequenceMatcher


class DetectLogosDETR:

@@ -49,6 +50,7 @@ class DetectLogosDETR:
        detr_threshold: float = 0.5,
        min_box_size: int = 20,
        nms_iou_threshold: float = 0.5,
        preprocess_mode: str = "default",
    ):
        """
        Initialize DETR and embedding models.

@@ -64,12 +66,17 @@ class DetectLogosDETR:
            detr_threshold: Confidence threshold for DETR detections (0-1)
            min_box_size: Minimum width/height in pixels for detected boxes (filters noise)
            nms_iou_threshold: IoU threshold for Non-Maximum Suppression
            preprocess_mode: Image preprocessing mode for CLIP:
                - "default": Use CLIP's default (resize shortest edge + center crop)
                - "letterbox": Pad to square with black bars, preserving aspect ratio
                - "stretch": Stretch to square (distorts aspect ratio)
        """
        self.logger = logger
        self.detr_threshold = detr_threshold
        self.min_box_size = min_box_size
        self.nms_iou_threshold = nms_iou_threshold
        self.embedding_model_name = embedding_model
        self.preprocess_mode = preprocess_mode

        # Set device
        self.device_str = "cuda:0" if torch.cuda.is_available() else "cpu"

@@ -116,6 +123,8 @@ class DetectLogosDETR:
        self.embedding_model = AutoModel.from_pretrained(embedding_model_path).to(self.device)
        self.embedding_processor = AutoImageProcessor.from_pretrained(embedding_model_path)

        if self.preprocess_mode != "default":
            self.logger.info(f"Image preprocessing mode: {self.preprocess_mode}")
        self.logger.info("DetectLogosDETR initialization complete")

    def _detect_model_type(self, model_name: str) -> str:

@@ -402,6 +411,46 @@ class DetectLogosDETR:

        return self._get_embedding_pil(pil_image)

    def _preprocess_image(self, pil_image: Image.Image, target_size: int = 224) -> Image.Image:
        """
        Preprocess image based on the configured preprocessing mode.

        Args:
            pil_image: PIL Image (RGB format)
            target_size: Target size for the square output (default 224 for CLIP)

        Returns:
            Preprocessed PIL Image
        """
        if self.preprocess_mode == "default":
            # Let the processor handle it (resize shortest edge + center crop)
            return pil_image

        width, height = pil_image.size

        if self.preprocess_mode == "letterbox":
            # Pad to square with black bars, preserving aspect ratio
            max_dim = max(width, height)

            # Create a black square canvas
            new_image = Image.new("RGB", (max_dim, max_dim), (0, 0, 0))

            # Paste the original image centered
            paste_x = (max_dim - width) // 2
            paste_y = (max_dim - height) // 2
            new_image.paste(pil_image, (paste_x, paste_y))

            # Resize to target size
            return new_image.resize((target_size, target_size), Image.LANCZOS)

        elif self.preprocess_mode == "stretch":
            # Stretch to square (distorts aspect ratio)
            return pil_image.resize((target_size, target_size), Image.LANCZOS)

        else:
            # Unknown mode, return original
            return pil_image
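
# Illustrative sketch (not part of the method above): what "letterbox" does to a wide
# 320x180 logo crop compared with "stretch", using the same PIL calls as _preprocess_image.
from PIL import Image

wide_crop = Image.new("RGB", (320, 180))            # stand-in for a detected logo region

# letterbox: pad to a 320x320 canvas (paste offset x=0, y=(320-180)//2=70),
# then resize to 224x224 -- the logo keeps its aspect ratio, with black bars top and bottom.
canvas = Image.new("RGB", (320, 320), (0, 0, 0))
canvas.paste(wide_crop, (0, 70))
letterboxed = canvas.resize((224, 224), Image.LANCZOS)

# stretch: resize straight to 224x224, distorting the 16:9 crop into a square.
stretched = wide_crop.resize((224, 224), Image.LANCZOS)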

    def _get_embedding_pil(self, pil_image: Image.Image) -> torch.Tensor:
        """
        Internal method to get embedding from PIL image.

@@ -414,6 +463,10 @@ class DetectLogosDETR:
        Returns:
            Normalized feature embedding (torch.Tensor)
        """
        # Apply preprocessing if configured
        if self.preprocess_mode != "default":
            pil_image = self._preprocess_image(pil_image)

        # Process image through the embedding model
        inputs = self.embedding_processor(images=pil_image, return_tensors="pt").to(self.device)

@@ -712,4 +765,311 @@ class DetectLogosDETR:
            f"(threshold: {similarity_threshold})"
        )

        return matched_detections

    # =========================================================================
    # Hybrid Text + CLIP Matching
    # =========================================================================

    def set_text_detector(self, text_detector) -> None:
        """
        Set an optional text detector for hybrid matching.

        Args:
            text_detector: Instance of DetectText class from text_recognition.py
        """
        self.text_detector = text_detector
        self.logger.info("Text detector enabled for hybrid matching")

    def extract_text(self, image: np.ndarray, min_confidence: float = 0.3) -> List[str]:
        """
        Extract text from an image using the text detector.

        Args:
            image: OpenCV image (BGR format)
            min_confidence: Minimum OCR confidence to accept text

        Returns:
            List of detected text strings (lowercased, stripped)
        """
        if not hasattr(self, 'text_detector') or self.text_detector is None:
            return []

        try:
            results, _ = self.text_detector.detect(image)
            # Filter by confidence and normalize text
            texts = []
            for text, confidence in results:
                if confidence >= min_confidence:
                    # Normalize: lowercase, strip whitespace, remove special chars
                    normalized = text.lower().strip()
                    if len(normalized) >= 2:  # Ignore single characters
                        texts.append(normalized)
            return texts
        except Exception as e:
            self.logger.warning(f"Text extraction failed: {e}")
            return []

    def extract_text_pil(self, pil_image: Image.Image, min_confidence: float = 0.3) -> List[str]:
        """
        Extract text from a PIL image.

        Args:
            pil_image: PIL Image (RGB format)
            min_confidence: Minimum OCR confidence

        Returns:
            List of detected text strings
        """
        # Convert PIL to OpenCV format
        cv_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
        return self.extract_text(cv_image, min_confidence)

    @staticmethod
    def compute_text_similarity(text1_list: List[str], text2_list: List[str]) -> float:
        """
        Compute fuzzy text similarity between two lists of text strings.

        Uses a combination of exact matches and fuzzy matching to handle
        OCR variations like case differences, spacing, and minor errors.

        Args:
            text1_list: List of text strings from first image
            text2_list: List of text strings from second image

        Returns:
            Similarity score between 0 and 1
        """
        if not text1_list or not text2_list:
            return 0.0

        # Combine all text into single strings for overall comparison
        text1_combined = " ".join(sorted(text1_list))
        text2_combined = " ".join(sorted(text2_list))

        # Method 1: Sequence matching on combined text
        seq_similarity = SequenceMatcher(None, text1_combined, text2_combined).ratio()

        # Method 2: Token overlap (Jaccard-like)
        # Split into tokens
        tokens1 = set(text1_combined.split())
        tokens2 = set(text2_combined.split())

        if tokens1 and tokens2:
            intersection = len(tokens1 & tokens2)
            union = len(tokens1 | tokens2)
            token_similarity = intersection / union if union > 0 else 0
        else:
            token_similarity = 0

        # Method 3: Best pairwise match for each text in list1
        pairwise_scores = []
        for t1 in text1_list:
            best_match = 0
            for t2 in text2_list:
                score = SequenceMatcher(None, t1, t2).ratio()
                best_match = max(best_match, score)
            pairwise_scores.append(best_match)

        pairwise_similarity = sum(pairwise_scores) / len(pairwise_scores) if pairwise_scores else 0

        # Combine methods (weighted average)
        combined = (seq_similarity * 0.3 + token_similarity * 0.3 + pairwise_similarity * 0.4)

        return combined
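
# Illustrative sketch: how the blended score behaves on typical OCR variation.
# The strings are made up; compute_text_similarity is the static method defined above.
ocr_ref = ["acme motors"]
ocr_det = ["acme", "m0tors"]        # OCR split the words and misread an "o" as "0"
score = DetectLogosDETR.compute_text_similarity(ocr_ref, ocr_det)
# Sequence similarity is high, token overlap is partial and the best pairwise match is
# moderate, so the weighted blend lands roughly around 0.6, while unrelated strings stay near 0.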

    @staticmethod
    def texts_match(
        ref_texts: List[str],
        det_texts: List[str],
        threshold: float = 0.5
    ) -> Tuple[bool, float]:
        """
        Determine if texts match above a threshold.

        Args:
            ref_texts: Text from reference logo
            det_texts: Text from detected region
            threshold: Minimum similarity to consider a match

        Returns:
            Tuple of (is_match, similarity_score)
        """
        if not ref_texts:
            # Reference has no text - can't match on text
            return (False, 0.0)

        if not det_texts:
            # Reference has text but detection doesn't - no text match
            return (False, 0.0)

        similarity = DetectLogosDETR.compute_text_similarity(ref_texts, det_texts)
        return (similarity >= threshold, similarity)
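
# Illustrative sketch of the three cases handled above (texts_match is a staticmethod;
# "acme" is a made-up logo name):
assert DetectLogosDETR.texts_match([], ["acme"]) == (False, 0.0)     # reference has no text
assert DetectLogosDETR.texts_match(["acme"], []) == (False, 0.0)     # detection found no text
matched, sim = DetectLogosDETR.texts_match(["acme"], ["acme"], threshold=0.5)
# Identical text yields sim == 1.0, so matched is True.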

    def find_best_match_hybrid(
        self,
        detected_embedding: torch.Tensor,
        detected_image: np.ndarray,
        reference_data: Dict[str, Dict[str, Any]],
        clip_threshold: float = 0.70,
        clip_threshold_with_text: float = 0.60,
        clip_threshold_text_mismatch: float = 0.80,
        text_similarity_threshold: float = 0.5,
        margin: float = 0.05,
        use_mean_similarity: bool = False,
    ) -> Optional[Tuple[str, float, Dict[str, Any]]]:
        """
        Find best match using hybrid text + CLIP approach.

        Strategy:
        - If reference has text AND detection has matching text:
          → Use lower CLIP threshold (text provides additional confidence)
        - If reference has text but detection doesn't match:
          → Use higher CLIP threshold (need more visual confidence)
        - If reference has no text:
          → Use standard CLIP threshold

        Args:
            detected_embedding: CLIP embedding from detected logo region
            detected_image: OpenCV image of the detected region (for text extraction)
            reference_data: Dict mapping logo name to:
                {
                    'embeddings': List[torch.Tensor],  # CLIP embeddings
                    'texts': List[str],                # Extracted text from reference
                }
            clip_threshold: Standard CLIP threshold for no-text references
            clip_threshold_with_text: Lower threshold when text matches
            clip_threshold_text_mismatch: Higher threshold when text expected but missing
            text_similarity_threshold: Threshold for text matching
            margin: Required margin between best and second-best
            use_mean_similarity: Use mean vs max for multi-ref aggregation

        Returns:
            Tuple of (label, clip_similarity, match_info) or None
            match_info contains: text_matched, text_similarity, threshold_used
        """
        if not reference_data:
            return None

        # Extract text from detected region
        detected_texts = self.extract_text(detected_image)

        # Calculate scores for all logos
        logo_scores = []

        for label, ref_info in reference_data.items():
            ref_embeddings = ref_info.get('embeddings', [])
            ref_texts = ref_info.get('texts', [])

            if not ref_embeddings:
                continue

            # Calculate CLIP similarity
            similarities = []
            for ref_emb in ref_embeddings:
                sim = self.compare_embeddings(detected_embedding, ref_emb)
                similarities.append(sim)

            if use_mean_similarity:
                clip_score = sum(similarities) / len(similarities)
            else:
                clip_score = max(similarities)

            # Determine text match status and appropriate threshold
            has_ref_text = len(ref_texts) > 0
            text_matched, text_sim = self.texts_match(
                ref_texts, detected_texts, text_similarity_threshold
            )

            if has_ref_text:
                if text_matched:
                    # Text matches - use lower threshold, boost confidence
                    threshold_used = clip_threshold_with_text
                    match_type = "text_match"
                else:
                    # Reference has text but detection doesn't match
                    # Require higher CLIP threshold
                    threshold_used = clip_threshold_text_mismatch
                    match_type = "text_mismatch"
            else:
                # No text in reference - standard matching
                threshold_used = clip_threshold
                match_type = "no_text"
                text_sim = 0.0

            # Check if CLIP score meets the appropriate threshold
            if clip_score >= threshold_used:
                logo_scores.append({
                    'label': label,
                    'clip_score': clip_score,
                    'text_matched': text_matched,
                    'text_similarity': text_sim,
                    'threshold_used': threshold_used,
                    'match_type': match_type,
                    'has_ref_text': has_ref_text,
                })

        if not logo_scores:
            return None

        # Sort by CLIP score descending
        logo_scores.sort(key=lambda x: x['clip_score'], reverse=True)

        best = logo_scores[0]

        # Check margin against second-best
        if margin > 0 and len(logo_scores) > 1:
            second_best_score = logo_scores[1]['clip_score']
            if best['clip_score'] - second_best_score < margin:
                return None

        match_info = {
            'text_matched': best['text_matched'],
            'text_similarity': best['text_similarity'],
            'threshold_used': best['threshold_used'],
            'match_type': best['match_type'],
            'has_ref_text': best['has_ref_text'],
            'detected_texts': detected_texts,
        }

        return (best['label'], best['clip_score'], match_info)

    def prepare_reference_data_hybrid(
        self,
        reference_images: Dict[str, List[np.ndarray]],
        text_min_confidence: float = 0.3,
    ) -> Dict[str, Dict[str, Any]]:
        """
        Prepare reference data for hybrid matching by computing embeddings and extracting text.

        Args:
            reference_images: Dict mapping logo name to list of reference images (OpenCV BGR)
            text_min_confidence: Minimum confidence for text extraction

        Returns:
            Dict mapping logo name to {'embeddings': [...], 'texts': [...]}
        """
        reference_data = {}

        for logo_name, images in reference_images.items():
            embeddings = []
            all_texts = set()

            for img in images:
                # Compute CLIP embedding
                emb = self.get_embedding(img)
                embeddings.append(emb)

                # Extract text
                texts = self.extract_text(img, text_min_confidence)
                all_texts.update(texts)

            reference_data[logo_name] = {
                'embeddings': embeddings,
                'texts': list(all_texts),
            }

            if all_texts:
                self.logger.debug(f"Reference '{logo_name}' has text: {all_texts}")

        return reference_data
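
# Illustrative end-to-end sketch of the hybrid flow added in this commit. The DetectText
# import follows the docstring above ("DetectText class from text_recognition.py"); the
# image paths are placeholders, and both DetectText() and the remaining DetectLogosDETR
# constructor arguments are assumed to have usable defaults.
import cv2
from logo_detection_detr import DetectLogosDETR
from text_recognition import DetectText

detector = DetectLogosDETR(preprocess_mode="letterbox")   # other args left at their defaults
detector.set_text_detector(DetectText())                  # enables the text tier of the matcher

# 1) Build reference data once: CLIP embeddings plus OCR text per logo.
reference_images = {"acme": [cv2.imread("refs/acme_1.png"), cv2.imread("refs/acme_2.png")]}
reference_data = detector.prepare_reference_data_hybrid(reference_images)

# 2) For each detected logo crop, embed it and run the tiered matcher.
crop = cv2.imread("frames/crop_0.png")                    # stand-in for a DETR-detected region
embedding = detector.get_embedding(crop)
match = detector.find_best_match_hybrid(embedding, crop, reference_data)
if match is not None:
    label, clip_score, info = match
    print(label, clip_score, info["match_type"], info["threshold_used"])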