Add similarity distribution analysis for debugging embedding quality
- Add --similarity-details flag to test_logo_detection.py
- Track true positive, false positive, and missed detection similarities
- Compute distribution statistics (min, max, mean, stddev, percentiles)
- Analyze overlap between TP and FP distributions
- Suggest optimal threshold based on data
- Show per-detection breakdown with top-5 matches
- Create analyze_similarity_distribution.sh wrapper script
- Supports baseline, finetuned, or both models
- Saves output to similarity_analysis/ directory
This commit is contained in:
141
analyze_similarity_distribution.sh
Executable file
141
analyze_similarity_distribution.sh
Executable file
@ -0,0 +1,141 @@
|
|||||||
|
#!/bin/bash
#
# Analyze similarity distribution for baseline and fine-tuned models.
#
# This script runs test_logo_detection.py with --similarity-details to output
# detailed statistics about how the models score matches vs non-matches, and
# saves a copy of each run's output under similarity_analysis/.
#
# Usage:
#   ./analyze_similarity_distribution.sh
#   ./analyze_similarity_distribution.sh --model finetuned
#   ./analyze_similarity_distribution.sh --model baseline
#
# Every default below can also be overridden through an environment variable
# of the same name (NUM_LOGOS, SEED, THRESHOLD, REFS_PER_LOGO, MARGIN, MODEL).
#

# -e: abort on the first failing command.
# -u: treat unset variables as errors.
# -o pipefail: a pipeline fails if ANY stage fails. Without it, the exit
#   status of "python ... | tee" is tee's status, so a crashed test run would
#   be reported as a successful analysis.
set -euo pipefail

# Default parameters
NUM_LOGOS="${NUM_LOGOS:-50}"
SEED="${SEED:-42}"
THRESHOLD="${THRESHOLD:-0.75}"
REFS_PER_LOGO="${REFS_PER_LOGO:-3}"
MARGIN="${MARGIN:-0.05}"
MODEL="${MODEL:-both}"

# Model paths
BASELINE_MODEL="openai/clip-vit-large-patch14"
FINETUNED_MODEL="models/logo_detection/clip_finetuned"

# Output directory
OUTPUT_DIR="similarity_analysis"
TIMESTAMP=$(date +%Y%m%d_%H%M%S)

# Abort with a clear message when an option that takes a value is the last
# argument. $1 = option name, $2 = number of arguments still to be parsed
# (including the option itself).
require_value() {
    if [[ "$2" -lt 2 ]]; then
        echo "Error: option $1 requires a value" >&2
        exit 1
    fi
}

# Parse command line arguments
while [[ $# -gt 0 ]]; do
    case $1 in
        -n|--num-logos)
            require_value "$1" "$#"
            NUM_LOGOS="$2"
            shift 2
            ;;
        -s|--seed)
            require_value "$1" "$#"
            SEED="$2"
            shift 2
            ;;
        -t|--threshold)
            require_value "$1" "$#"
            THRESHOLD="$2"
            shift 2
            ;;
        --model)
            require_value "$1" "$#"
            MODEL="$2"
            shift 2
            ;;
        --finetuned-path)
            require_value "$1" "$#"
            FINETUNED_MODEL="$2"
            shift 2
            ;;
        -h|--help)
            echo "Usage: $0 [OPTIONS]"
            echo ""
            echo "Options:"
            echo " -n, --num-logos NUM Number of logos to test (default: 50)"
            echo " -s, --seed SEED Random seed (default: 42)"
            echo " -t, --threshold VAL Similarity threshold (default: 0.75)"
            echo " --model MODEL Which model: 'baseline', 'finetuned', or 'both' (default: both)"
            echo " --finetuned-path PATH Path to fine-tuned model"
            echo " -h, --help Show this help message"
            exit 0
            ;;
        *)
            echo "Unknown option: $1" >&2
            exit 1
            ;;
    esac
done

# Reject unknown model selections up front; otherwise neither branch below
# would run and the script would still report "Analysis complete!".
case "${MODEL}" in
    baseline|finetuned|both)
        ;;
    *)
        echo "Error: --model must be 'baseline', 'finetuned', or 'both' (got '${MODEL}')" >&2
        exit 1
        ;;
esac

# Create output directory
mkdir -p "${OUTPUT_DIR}"

echo "============================================================"
echo "SIMILARITY DISTRIBUTION ANALYSIS"
echo "============================================================"
echo ""
echo "Parameters:"
echo " Number of logos: ${NUM_LOGOS}"
echo " Random seed: ${SEED}"
echo " Threshold: ${THRESHOLD}"
echo " Refs per logo: ${REFS_PER_LOGO}"
echo " Margin: ${MARGIN}"
echo " Model: ${MODEL}"
echo ""

# Common test arguments (shared by every model run)
TEST_ARGS=(
    -n "${NUM_LOGOS}"
    -s "${SEED}"
    -t "${THRESHOLD}"
    --refs-per-logo "${REFS_PER_LOGO}"
    --margin "${MARGIN}"
    --matching-method multi-ref
    --similarity-details
    --clear-cache
)

# Run the detection test for one model and tee its combined stdout/stderr
# into a timestamped file in OUTPUT_DIR.
#   $1 = short model name used in the output file name
#   $2 = model path or identifier passed to the test via -e
run_analysis() {
    local model_name="$1"
    local model_path="$2"
    local output_file="${OUTPUT_DIR}/${model_name}_similarity_${TIMESTAMP}.txt"

    echo "============================================================"
    echo "Analyzing: ${model_name}"
    echo "Model: ${model_path}"
    echo "Output: ${output_file}"
    echo "============================================================"
    echo ""

    # With pipefail set, a non-zero exit from the test aborts the script here.
    uv run python test_logo_detection.py \
        "${TEST_ARGS[@]}" \
        -e "${model_path}" \
        2>&1 | tee "${output_file}"

    echo ""
    echo "Results saved to: ${output_file}"
    echo ""
}

# Run analysis based on model selection
if [[ "${MODEL}" == "baseline" ]] || [[ "${MODEL}" == "both" ]]; then
    run_analysis "baseline" "${BASELINE_MODEL}"
fi

if [[ "${MODEL}" == "finetuned" ]] || [[ "${MODEL}" == "both" ]]; then
    # The fine-tuned model is checked as a local directory; a missing one is
    # skipped with a warning rather than treated as fatal.
    if [ ! -d "${FINETUNED_MODEL}" ]; then
        echo "Warning: Fine-tuned model not found at ${FINETUNED_MODEL}"
        echo "Skipping fine-tuned model analysis."
    else
        run_analysis "finetuned" "${FINETUNED_MODEL}"
    fi
fi

echo "============================================================"
echo "Analysis complete!"
echo "Results saved to: ${OUTPUT_DIR}/"
echo "============================================================"
|
||||||
@ -265,6 +265,11 @@ def main():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable verbose logging",
|
help="Enable verbose logging",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--similarity-details",
|
||||||
|
action="store_true",
|
||||||
|
help="Output detailed similarity scores for each detection (for analyzing score distributions)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no-cache",
|
"--no-cache",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@ -411,6 +416,16 @@ def main():
|
|||||||
# Detailed results for analysis
|
# Detailed results for analysis
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
|
# Similarity distribution tracking (for --similarity-details)
|
||||||
|
similarity_details = {
|
||||||
|
"true_positive_sims": [], # Similarities for correct matches
|
||||||
|
"false_positive_sims": [], # Similarities for wrong matches
|
||||||
|
"missed_best_sims": [], # Best similarity for logos that should have matched but didn't
|
||||||
|
"all_positive_sims": [], # All similarities between detected regions and correct logos
|
||||||
|
"all_negative_sims": [], # All similarities between detected regions and wrong logos
|
||||||
|
"detection_details": [], # Per-detection breakdown
|
||||||
|
}
|
||||||
|
|
||||||
# Process test images
|
# Process test images
|
||||||
for test_filename in tqdm(test_images, desc="Testing"):
|
for test_filename in tqdm(test_images, desc="Testing"):
|
||||||
test_path = test_images_dir / test_filename
|
test_path = test_images_dir / test_filename
|
||||||
@ -445,7 +460,38 @@ def main():
|
|||||||
|
|
||||||
# Match detections against references using selected method
|
# Match detections against references using selected method
|
||||||
matched_logos: Set[str] = set()
|
matched_logos: Set[str] = set()
|
||||||
for detection in detections:
|
for det_idx, detection in enumerate(detections):
|
||||||
|
# Compute similarities to all reference logos for detailed analysis
|
||||||
|
if args.similarity_details:
|
||||||
|
all_sims = {}
|
||||||
|
for logo_name, ref_emb_list in multi_ref_embeddings.items():
|
||||||
|
sims = []
|
||||||
|
for ref_emb in ref_emb_list:
|
||||||
|
sim = detector.compare_embeddings(detection["embedding"], ref_emb)
|
||||||
|
sims.append(sim)
|
||||||
|
# Use mean or max based on setting
|
||||||
|
if args.use_max_similarity:
|
||||||
|
all_sims[logo_name] = max(sims) if sims else 0
|
||||||
|
else:
|
||||||
|
all_sims[logo_name] = sum(sims) / len(sims) if sims else 0
|
||||||
|
|
||||||
|
# Track positive vs negative similarities
|
||||||
|
for sim in sims:
|
||||||
|
if logo_name in expected_logos:
|
||||||
|
similarity_details["all_positive_sims"].append(sim)
|
||||||
|
else:
|
||||||
|
similarity_details["all_negative_sims"].append(sim)
|
||||||
|
|
||||||
|
# Store detection details
|
||||||
|
sorted_sims = sorted(all_sims.items(), key=lambda x: -x[1])
|
||||||
|
similarity_details["detection_details"].append({
|
||||||
|
"image": test_filename,
|
||||||
|
"detection_idx": det_idx,
|
||||||
|
"expected_logos": list(expected_logos),
|
||||||
|
"top_5_matches": sorted_sims[:5],
|
||||||
|
"detr_score": detection.get("score", 0),
|
||||||
|
})
|
||||||
|
|
||||||
if args.matching_method == "simple":
|
if args.matching_method == "simple":
|
||||||
# Simple matching: return ALL logos above threshold
|
# Simple matching: return ALL logos above threshold
|
||||||
all_matches = detector.find_all_matches(
|
all_matches = detector.find_all_matches(
|
||||||
@ -457,16 +503,21 @@ def main():
|
|||||||
matched_logos.add(label)
|
matched_logos.add(label)
|
||||||
|
|
||||||
# Check if this is a correct match
|
# Check if this is a correct match
|
||||||
if label in expected_logos:
|
is_correct = label in expected_logos
|
||||||
|
if is_correct:
|
||||||
true_positives += 1
|
true_positives += 1
|
||||||
|
if args.similarity_details:
|
||||||
|
similarity_details["true_positive_sims"].append(similarity)
|
||||||
else:
|
else:
|
||||||
false_positives += 1
|
false_positives += 1
|
||||||
|
if args.similarity_details:
|
||||||
|
similarity_details["false_positive_sims"].append(similarity)
|
||||||
|
|
||||||
results.append({
|
results.append({
|
||||||
"test_image": test_filename,
|
"test_image": test_filename,
|
||||||
"matched_logo": label,
|
"matched_logo": label,
|
||||||
"similarity": similarity,
|
"similarity": similarity,
|
||||||
"correct": label in expected_logos,
|
"correct": is_correct,
|
||||||
})
|
})
|
||||||
|
|
||||||
elif args.matching_method == "margin":
|
elif args.matching_method == "margin":
|
||||||
@ -481,16 +532,21 @@ def main():
|
|||||||
label, similarity = match_result
|
label, similarity = match_result
|
||||||
matched_logos.add(label)
|
matched_logos.add(label)
|
||||||
|
|
||||||
if label in expected_logos:
|
is_correct = label in expected_logos
|
||||||
|
if is_correct:
|
||||||
true_positives += 1
|
true_positives += 1
|
||||||
|
if args.similarity_details:
|
||||||
|
similarity_details["true_positive_sims"].append(similarity)
|
||||||
else:
|
else:
|
||||||
false_positives += 1
|
false_positives += 1
|
||||||
|
if args.similarity_details:
|
||||||
|
similarity_details["false_positive_sims"].append(similarity)
|
||||||
|
|
||||||
results.append({
|
results.append({
|
||||||
"test_image": test_filename,
|
"test_image": test_filename,
|
||||||
"matched_logo": label,
|
"matched_logo": label,
|
||||||
"similarity": similarity,
|
"similarity": similarity,
|
||||||
"correct": label in expected_logos,
|
"correct": is_correct,
|
||||||
})
|
})
|
||||||
|
|
||||||
else: # multi-ref
|
else: # multi-ref
|
||||||
@ -507,16 +563,21 @@ def main():
|
|||||||
label, similarity, num_matching = match_result
|
label, similarity, num_matching = match_result
|
||||||
matched_logos.add(label)
|
matched_logos.add(label)
|
||||||
|
|
||||||
if label in expected_logos:
|
is_correct = label in expected_logos
|
||||||
|
if is_correct:
|
||||||
true_positives += 1
|
true_positives += 1
|
||||||
|
if args.similarity_details:
|
||||||
|
similarity_details["true_positive_sims"].append(similarity)
|
||||||
else:
|
else:
|
||||||
false_positives += 1
|
false_positives += 1
|
||||||
|
if args.similarity_details:
|
||||||
|
similarity_details["false_positive_sims"].append(similarity)
|
||||||
|
|
||||||
results.append({
|
results.append({
|
||||||
"test_image": test_filename,
|
"test_image": test_filename,
|
||||||
"matched_logo": label,
|
"matched_logo": label,
|
||||||
"similarity": similarity,
|
"similarity": similarity,
|
||||||
"correct": label in expected_logos,
|
"correct": is_correct,
|
||||||
})
|
})
|
||||||
|
|
||||||
# Count missed detections (false negatives)
|
# Count missed detections (false negatives)
|
||||||
@ -524,6 +585,15 @@ def main():
|
|||||||
false_negatives += len(missed)
|
false_negatives += len(missed)
|
||||||
|
|
||||||
for missed_logo in missed:
|
for missed_logo in missed:
|
||||||
|
# Track best similarity for missed logos (if we have detections)
|
||||||
|
if args.similarity_details and detections:
|
||||||
|
best_sim_for_missed = 0
|
||||||
|
for detection in detections:
|
||||||
|
for ref_emb in multi_ref_embeddings.get(missed_logo, []):
|
||||||
|
sim = detector.compare_embeddings(detection["embedding"], ref_emb)
|
||||||
|
best_sim_for_missed = max(best_sim_for_missed, sim)
|
||||||
|
similarity_details["missed_best_sims"].append(best_sim_for_missed)
|
||||||
|
|
||||||
results.append({
|
results.append({
|
||||||
"test_image": test_filename,
|
"test_image": test_filename,
|
||||||
"matched_logo": None,
|
"matched_logo": None,
|
||||||
@ -593,6 +663,10 @@ def main():
|
|||||||
|
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Print similarity distribution details if requested
|
||||||
|
if args.similarity_details:
|
||||||
|
print_similarity_details(similarity_details, args.threshold)
|
||||||
|
|
||||||
# Write results to file if requested
|
# Write results to file if requested
|
||||||
if args.output_file:
|
if args.output_file:
|
||||||
write_results_to_file(
|
write_results_to_file(
|
||||||
@ -612,6 +686,116 @@ def main():
|
|||||||
print(f"\nResults appended to: {args.output_file}")
|
print(f"\nResults appended to: {args.output_file}")
|
||||||
|
|
||||||
|
|
||||||
|
def print_similarity_details(details: dict, threshold: float):
    """Print detailed similarity distribution analysis.

    Writes a human-readable report to stdout consisting of:

    * summary statistics for each similarity list in ``details``,
    * an overlap analysis between true-positive and false-positive scores,
    * a suggested threshold (the observed score that maximises F1), and
    * the top-5 matches for the first 20 recorded detections.

    Args:
        details: Mapping with the list-valued keys ``true_positive_sims``,
            ``false_positive_sims``, ``missed_best_sims``,
            ``all_positive_sims``, ``all_negative_sims`` and
            ``detection_details``. Each entry of ``detection_details`` is a
            dict with ``image``, ``expected_logos``, ``top_5_matches``
            (``(logo, similarity)`` pairs) and ``detr_score``.
        threshold: Similarity threshold used for the above/below counts and
            as the fallback suggestion when no candidate improves on F1 = 0.
    """
    import statistics

    print("\n" + "=" * 60)
    print("SIMILARITY DISTRIBUTION ANALYSIS")
    print("=" * 60)

    # Helper to compute and print stats for one list of similarities
    def compute_stats(values, name):
        if not values:
            print(f"\n{name}: No data")
            return
        print(f"\n{name} (n={len(values)}):")
        print(f" Min: {min(values):.4f}")
        print(f" Max: {max(values):.4f}")
        print(f" Mean: {statistics.mean(values):.4f}")
        # Sample standard deviation is undefined for a single value.
        if len(values) > 1:
            print(f" StdDev: {statistics.stdev(values):.4f}")
        print(f" Median: {statistics.median(values):.4f}")

        # Percentiles: index-based picks from the sorted values; for small
        # samples they fall back to the min (P10/P25) or max (P75/P90).
        sorted_vals = sorted(values)
        n = len(sorted_vals)
        p10 = sorted_vals[int(n * 0.10)] if n > 10 else sorted_vals[0]
        p25 = sorted_vals[int(n * 0.25)] if n > 4 else sorted_vals[0]
        p75 = sorted_vals[int(n * 0.75)] if n > 4 else sorted_vals[-1]
        p90 = sorted_vals[int(n * 0.90)] if n > 10 else sorted_vals[-1]
        print(f" P10: {p10:.4f}")
        print(f" P25: {p25:.4f}")
        print(f" P75: {p75:.4f}")
        print(f" P90: {p90:.4f}")

        # Count above/below threshold
        above = sum(1 for v in values if v >= threshold)
        below = len(values) - above
        print(f" Above threshold ({threshold}): {above} ({100*above/len(values):.1f}%)")
        print(f" Below threshold ({threshold}): {below} ({100*below/len(values):.1f}%)")

    # Print distribution stats
    compute_stats(details["true_positive_sims"], "TRUE POSITIVE similarities (correct matches)")
    compute_stats(details["false_positive_sims"], "FALSE POSITIVE similarities (wrong matches)")
    compute_stats(details["missed_best_sims"], "MISSED LOGO best similarities (false negatives)")
    compute_stats(details["all_positive_sims"], "ALL similarities to CORRECT logos (per-ref)")
    compute_stats(details["all_negative_sims"], "ALL similarities to WRONG logos (per-ref)")

    # Overlap analysis
    tp_sims = details["true_positive_sims"]
    fp_sims = details["false_positive_sims"]
    if tp_sims and fp_sims:
        print("\n" + "-" * 40)
        print("OVERLAP ANALYSIS:")
        tp_min, tp_max = min(tp_sims), max(tp_sims)
        fp_min, fp_max = min(fp_sims), max(fp_sims)
        print(f" True Positives range: [{tp_min:.4f}, {tp_max:.4f}]")
        print(f" False Positives range: [{fp_min:.4f}, {fp_max:.4f}]")

        # Check overlap. The comparison is inclusive: when the two ranges
        # touch at a single value, a TP and an FP share that exact score, so
        # no threshold separates them and this must count as overlap.
        overlap_min = max(tp_min, fp_min)
        overlap_max = min(tp_max, fp_max)
        if overlap_min <= overlap_max:
            print(f" OVERLAP REGION: [{overlap_min:.4f}, {overlap_max:.4f}]")
            tp_in_overlap = sum(1 for v in tp_sims if overlap_min <= v <= overlap_max)
            fp_in_overlap = sum(1 for v in fp_sims if overlap_min <= v <= overlap_max)
            print(f" TPs in overlap: {tp_in_overlap} ({100*tp_in_overlap/len(tp_sims):.1f}%)")
            print(f" FPs in overlap: {fp_in_overlap} ({100*fp_in_overlap/len(fp_sims):.1f}%)")
        else:
            print(" NO OVERLAP - distributions are separable!")

        # Suggest optimal threshold: try every distinct observed score, in
        # ascending order, and keep the first one with the highest F1.
        # Recall here is relative to the recorded true positives only.
        best_thresh = threshold
        best_f1 = 0
        total_tp = len(tp_sims)

        for thresh in sorted({*tp_sims, *fp_sims}):
            # At this threshold:
            tp_above = sum(1 for s in tp_sims if s >= thresh)
            fp_above = sum(1 for s in fp_sims if s >= thresh)
            kept = tp_above + fp_above
            prec = tp_above / kept if kept > 0 else 0
            rec = tp_above / total_tp if total_tp > 0 else 0
            f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0
            if f1 > best_f1:
                best_f1 = f1
                best_thresh = thresh

        print(f"\n SUGGESTED OPTIMAL THRESHOLD: {best_thresh:.4f}")
        print(f" (would give F1 = {best_f1:.4f} on this data)")

    # Print sample detection details
    det_details = details["detection_details"]
    if det_details:
        print("\n" + "-" * 40)
        print(f"SAMPLE DETECTION DETAILS (first 20 of {len(det_details)}):")
        for i, det in enumerate(det_details[:20]):
            expected = det["expected_logos"]
            top5 = det["top_5_matches"]
            print(f"\n [{i+1}] Image: {det['image']}")
            print(f" Expected: {expected if expected else '(none)'}")
            print(f" DETR score: {det['detr_score']:.3f}")
            print(" Top 5 matches:")
            for logo, sim in top5:
                marker = " <-- CORRECT" if logo in expected else ""
                print(f" {sim:.4f} {logo}{marker}")

    print("\n" + "=" * 60)
|
||||||
|
|
||||||
|
|
||||||
def write_results_to_file(
|
def write_results_to_file(
|
||||||
output_path: Path,
|
output_path: Path,
|
||||||
args,
|
args,
|
||||||
|
|||||||
Reference in New Issue
Block a user