Add margin check to multi-ref matching to reduce false positives
The multi-ref matching method was missing a margin check against other logos, causing excessive false positives. This fix makes the following changes:
- Adds a margin parameter to find_best_match_multi_ref() that requires the best logo's score to exceed the second-best logo's score by a minimum margin.
- Updates the test script to pass --margin to both matching methods.
- Updates the documentation to reflect that the margin applies to both methods.

Also adds run_comparison_tests.sh to run all three matching methods and compare results.
This commit is contained in:
@ -401,6 +401,7 @@ class DetectLogosDETR:
|
||||
similarity_threshold: float = 0.85,
|
||||
min_matching_refs: int = 1,
|
||||
use_mean_similarity: bool = True,
|
||||
margin: float = 0.0,
|
||||
) -> Optional[Tuple[str, float, int]]:
|
||||
"""
|
||||
Find the best matching reference logo using multiple reference embeddings per logo.
|
||||
@ -414,6 +415,7 @@ class DetectLogosDETR:
|
||||
similarity_threshold: Minimum similarity to consider a match (0-1)
|
||||
min_matching_refs: Minimum number of references that must match above threshold
|
||||
use_mean_similarity: If True, use mean similarity across all refs; if False, use max
|
||||
margin: Required margin between best and second-best logo scores (0-1)
|
||||
|
||||
Returns:
|
||||
Tuple of (label, similarity, num_matching_refs) for best match,
|
||||
@ -422,9 +424,8 @@ class DetectLogosDETR:
|
||||
if not reference_embeddings:
|
||||
return None
|
||||
|
||||
best_score = -1.0
|
||||
best_label = None
|
||||
best_num_matches = 0
|
||||
# Calculate scores for all logos that meet the min_matching_refs requirement
|
||||
logo_scores = []
|
||||
|
||||
for label, ref_embedding_list in reference_embeddings.items():
|
||||
if not ref_embedding_list:
|
||||
@ -445,17 +446,30 @@ class DetectLogosDETR:
|
||||
else:
|
||||
score = max(similarities)
|
||||
|
||||
# Check if this logo meets the minimum matching refs requirement
|
||||
if num_matches >= min_matching_refs and score > best_score:
|
||||
best_score = score
|
||||
best_label = label
|
||||
best_num_matches = num_matches
|
||||
# Only consider logos that meet the minimum matching refs requirement
|
||||
if num_matches >= min_matching_refs:
|
||||
logo_scores.append((label, score, num_matches))
|
||||
|
||||
if best_label is not None and best_score >= similarity_threshold:
|
||||
return (best_label, best_score, best_num_matches)
|
||||
else:
|
||||
if not logo_scores:
|
||||
return None
|
||||
|
||||
# Sort by score descending
|
||||
logo_scores.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
best_label, best_score, best_num_matches = logo_scores[0]
|
||||
|
||||
# Check if best score meets threshold
|
||||
if best_score < similarity_threshold:
|
||||
return None
|
||||
|
||||
# Check margin against second-best logo (if exists)
|
||||
if margin > 0 and len(logo_scores) > 1:
|
||||
second_best_score = logo_scores[1][1]
|
||||
if best_score - second_best_score < margin:
|
||||
return None # Not confident enough
|
||||
|
||||
return (best_label, best_score, best_num_matches)
|
||||
|
||||
def find_best_match_with_margin(
|
||||
self,
|
||||
detected_embedding: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user