Add accuracy test framework, prompts, results, and analysis reports
Includes accuracy test scripts for Qwen (local) and Gemini (cloud API), three prompt variants (original, capstone, constrained), test results from all runs, and two analysis reports with an HTML presentation version.
test_accuracy.py (new file, 402 lines)
@@ -0,0 +1,402 @@
#!/usr/bin/env python3
"""
Test script to measure VLM accuracy for jersey color detection.

Uses annotated test images where ground truth colors are encoded in filenames.
Compares VLM results against ground truth, measuring exact and similar color matches.
White is ignored in both ground truth and VLM results.

Filename format: "014 - orange_dark blue or purple.jpg"
- Underscore separates distinct jersey colors
- "or" separates acceptable alternatives for a single jersey

Usage:
    python test_accuracy.py [prompt_file]
"""

import json
import os
import re
import sys
import time
from collections import Counter
from pathlib import Path

import cv2

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from scan_utils.llama_cpp_client import LlamaCppClient

SERVER_URL = "http://agx:8080"
IMAGES_DIR = os.path.join(os.path.dirname(__file__), "basketball_jersery_color_test_files_annotated")
DEFAULT_PROMPT_FILE = os.path.join(os.path.dirname(__file__), "jersey_prompt.txt")
MAX_IMAGE_WIDTH = 768

# ---------------------------------------------------------------------------
# Color similarity – colors in the same family count as "similar" matches
# ---------------------------------------------------------------------------
COLOR_FAMILIES = {
    'blue': ['blue', 'dark blue', 'navy blue', 'navy', 'royal blue'],
    'light_blue': ['light blue', 'sky blue', 'baby blue', 'carolina blue', 'powder blue'],
    'red': ['red', 'scarlet', 'crimson'],
    'dark_red': ['maroon', 'burgundy', 'dark red', 'wine'],
    'green': ['green', 'dark green', 'forest green', 'kelly green'],
    'yellow': ['yellow', 'gold', 'golden'],
    'orange': ['orange', 'burnt orange'],
    'brown': ['brown', 'dark brown'],
    'purple': ['purple', 'violet'],
    'gray': ['gray', 'grey', 'silver', 'charcoal'],
    'black': ['black'],
    'teal': ['teal', 'turquoise', 'cyan', 'aqua'],
    'pink': ['pink', 'magenta', 'hot pink', 'rose'],
}

_COLOR_TO_FAMILY = {}
for _family, _members in COLOR_FAMILIES.items():
    for _color in _members:
        _COLOR_TO_FAMILY[_color] = _family


def colors_are_similar(color1: str, color2: str) -> bool:
    """Return True if two colors belong to the same color family."""
    if color1 == color2:
        return True
    f1 = _COLOR_TO_FAMILY.get(color1)
    f2 = _COLOR_TO_FAMILY.get(color2)
    return bool(f1 and f2 and f1 == f2)
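
# Illustrative checks (comments only, derived from the family table above):
#   colors_are_similar('navy', 'royal blue') -> True   (both map to 'blue')
#   colors_are_similar('red', 'pink')        -> False  (different families)
#   colors_are_similar('salmon', 'salmon')   -> True   (exact match wins even
#                                                       for unlisted colors)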


# ---------------------------------------------------------------------------
# Ground-truth parsing
# ---------------------------------------------------------------------------
def parse_ground_truth(filename: str) -> list[list[str]]:
    """Parse ground truth colors from an annotated filename.

    Returns a list of color groups. Each group is a list of acceptable
    alternatives (from "or" in the filename). White entries are removed.

    Example: "014 - orange_dark blue or purple.jpg"
        -> [["orange"], ["dark blue", "purple"]]
    """
    name = Path(filename).stem
    # Strip number prefix ("014 - ", "029 -", etc.)
    name = re.sub(r'^\d+\s*-\s*', '', name)
    # Treat hyphens between colors as underscores (e.g. "yellow-black")
    name = name.replace('-', '_')

    color_groups = []
    for part in name.split('_'):
        part = part.strip()
        if not part:
            continue
        alternatives = [a.strip().lower() for a in part.split(' or ')]
        alternatives = [a for a in alternatives if a and a != 'white']
        if alternatives:
            color_groups.append(alternatives)
    return color_groups
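
# Illustrative parses (comments only; the last two filenames are hypothetical,
# chosen to show the hyphen and white rules):
#   "014 - orange_dark blue or purple.jpg" -> [['orange'], ['dark blue', 'purple']]
#   "031 - yellow-black.jpg"               -> [['yellow'], ['black']]
#   "007 - white.jpg"                      -> []  (white-only images score nothing)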


# ---------------------------------------------------------------------------
# Response cleaning
# ---------------------------------------------------------------------------
def clean_response(text: str) -> str:
    """Remove think tags and markdown code blocks from model output."""
    cleaned = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL | re.IGNORECASE)
    cleaned = re.sub(r'\u25c1think\u25b7.*?\u25c1/think\u25b7', '', cleaned, flags=re.DOTALL)
    cleaned = re.sub(r'</?think>', '', cleaned, flags=re.IGNORECASE)
    cleaned = re.sub(r'\u25c1/?think\u25b7', '', cleaned, flags=re.IGNORECASE)

    json_block = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', cleaned, flags=re.DOTALL | re.IGNORECASE)
    if json_block:
        cleaned = json_block.group(1)
    else:
        cleaned = re.sub(r'```(?:json)?', '', cleaned, flags=re.IGNORECASE)

    return cleaned.strip()
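
# Illustrative cleanup (comment only): a reasoning-model reply such as
#   '<think>checking jerseys...</think>```json\n{"jerseys": []}\n```'
# reduces to '{"jerseys": []}', ready for json.loads().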


# ---------------------------------------------------------------------------
# Scoring
# ---------------------------------------------------------------------------
def score_image(gt_groups: list[list[str]], vlm_colors: set[str]) -> dict:
    """Compare VLM detected colors against ground truth color groups.

    Recall = how many GT color groups were found in VLM output
    Precision = how many VLM colors match something in the GT
    """
    recall_exact = 0
    recall_similar = 0
    recall_missed = []
    confusions = []

    for group in gt_groups:
        # Try exact match first
        if any(alt in vlm_colors for alt in group):
            recall_exact += 1
            continue
        # Try similar match
        matched_vlm = None
        for alt in group:
            for vc in vlm_colors:
                if colors_are_similar(alt, vc):
                    matched_vlm = vc
                    break
            if matched_vlm:
                break
        if matched_vlm:
            recall_similar += 1
            confusions.append((group, matched_vlm))
        else:
            recall_missed.append(group)

    # Precision: check each VLM color against GT
    all_gt_alts = [alt for group in gt_groups for alt in group]
    precision_exact = 0
    precision_similar = 0
    precision_extra = []
    for vc in vlm_colors:
        if vc in all_gt_alts:
            precision_exact += 1
        elif any(colors_are_similar(vc, gt) for gt in all_gt_alts):
            precision_similar += 1
        else:
            precision_extra.append(vc)

    return {
        'gt_count': len(gt_groups),
        'vlm_count': len(vlm_colors),
        'recall_exact': recall_exact,
        'recall_similar': recall_similar,
        'recall_missed': recall_missed,
        'precision_exact': precision_exact,
        'precision_similar': precision_similar,
        'precision_extra': precision_extra,
        'confusions': confusions,
    }
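
# Worked example (comments only): for GT [['orange'], ['dark blue', 'purple']]
# and VLM colors {'orange', 'navy'}:
#   - 'orange' matches exactly              -> recall_exact = 1, precision_exact = 1
#   - 'navy' matches 'dark blue' by family  -> recall_similar = 1, precision_similar = 1
#   - nothing is missed and nothing is extra, so main() would tag this image PASS.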


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def pct(n: int, d: int) -> str:
    return f"{100 * n / d:.1f}%" if d else "N/A"


def print_summary(total_gt, total_vlm, total_recall_exact, total_recall_similar,
                  total_recall_missed, total_precision_exact, total_precision_similar,
                  total_precision_extra, confusion_counter, missed_counter,
                  extra_counter, per_image_results, image_count, errors, total_time):
    """Print the full accuracy summary report."""
    print()
    print("=" * 80)
    print("ACCURACY SUMMARY")
    print("=" * 80)
    print(f"Images processed: {image_count}")
    print(f"Errors: {errors}")
    print(f"Total time: {total_time:.1f}s ({total_time / max(image_count, 1):.1f}s avg)")
    print()
    print(f"Ground truth colors: {total_gt} (excluding white)")
    print(f"VLM unique colors: {total_vlm} (excluding white)")
    print()

    print("--- Recall (did VLM find each ground truth color?) ---")
    print(f"  Exact match:   {total_recall_exact:4d} / {total_gt} ({pct(total_recall_exact, total_gt)})")
    print(f"  Similar match: {total_recall_similar:4d} / {total_gt} ({pct(total_recall_similar, total_gt)})")
    recall_total = total_recall_exact + total_recall_similar
    print(f"  Total found:   {recall_total:4d} / {total_gt} ({pct(recall_total, total_gt)})")
    print(f"  Missed:        {total_recall_missed:4d} / {total_gt} ({pct(total_recall_missed, total_gt)})")
    print()

    print("--- Precision (are VLM colors correct?) ---")
    print(f"  Exact match:   {total_precision_exact:4d} / {total_vlm} ({pct(total_precision_exact, total_vlm)})")
    print(f"  Similar match: {total_precision_similar:4d} / {total_vlm} ({pct(total_precision_similar, total_vlm)})")
    prec_total = total_precision_exact + total_precision_similar
    print(f"  Total correct: {prec_total:4d} / {total_vlm} ({pct(prec_total, total_vlm)})")
    print(f"  Extra/wrong:   {total_precision_extra:4d} / {total_vlm} ({pct(total_precision_extra, total_vlm)})")

    if confusion_counter:
        print()
        print("--- Similar-Match Confusions (expected -> got) ---")
        for (expected, got), count in confusion_counter.most_common():
            print(f"  {expected:30s} -> {got:20s} x{count}")

    if missed_counter:
        print()
        print("--- Most Missed Ground Truth Colors ---")
        for color, count in missed_counter.most_common(20):
            bar = "#" * min(count, 40)
            print(f"  {color:30s} {count:3d} {bar}")

    if extra_counter:
        print()
        print("--- Most Common Extra/Wrong VLM Colors ---")
        for color, count in extra_counter.most_common(20):
            bar = "#" * min(count, 40)
            print(f"  {color:30s} {count:3d} {bar}")

    if per_image_results:
        tags = Counter(r['tag'] for r in per_image_results)
        print()
        print("--- Per-Image Verdict ---")
        for tag in ['PASS', 'PARTIAL', 'FAIL']:
            print(f"  {tag:10s} {tags.get(tag, 0):4d}")

    failed = [r for r in per_image_results if r['tag'] == 'FAIL']
    if failed:
        print()
        print(f"--- Failed Images ({len(failed)}) ---")
        for r in failed:
            scores = r['scores']
            missed_strs = ["|".join(g) for g in scores['recall_missed']]
            print(f"  {r['file']}")
            print(f"    missed: {', '.join(missed_strs)}")
            if scores['precision_extra']:
                print(f"    extra:  {', '.join(scores['precision_extra'])}")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    prompt_file = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_PROMPT_FILE

    with open(prompt_file, 'r') as f:
        prompt = f.read()

    valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
    image_files = sorted([
        p for p in Path(IMAGES_DIR).iterdir()
        if p.suffix.lower() in valid_extensions
    ])

    print(f"Images to process: {len(image_files)}")
    print(f"Server: {SERVER_URL}")
    print(f"Prompt: {prompt_file} ({len(prompt)} chars)")
    print("=" * 80)

    client = LlamaCppClient(base_url=SERVER_URL)

    # Accumulators
    total_gt = 0
    total_vlm = 0
    total_recall_exact = 0
    total_recall_similar = 0
    total_recall_missed = 0
    total_precision_exact = 0
    total_precision_similar = 0
    total_precision_extra = 0
    errors = 0
    start_all = time.time()

    confusion_counter = Counter()
    missed_counter = Counter()
    extra_counter = Counter()
    per_image_results = []

    for i, image_path in enumerate(image_files, 1):
        gt_groups = parse_ground_truth(image_path.name)
        gt_display = ", ".join("|".join(g) for g in gt_groups) if gt_groups else "(none)"
        print(f"\n[{i}/{len(image_files)}] {image_path.name}")
        print(f"  GT:  [{gt_display}]")

        image = cv2.imread(str(image_path))
        if image is None:
            print("  SKIP (failed to load)")
            errors += 1
            continue

        h, w = image.shape[:2]
        if w > MAX_IMAGE_WIDTH:
            scale = MAX_IMAGE_WIDTH / w
            image = cv2.resize(image, (MAX_IMAGE_WIDTH, int(h * scale)), interpolation=cv2.INTER_AREA)

        message = client.create_multimodal_message(role="user", content=prompt, images=[image])

        try:
            t0 = time.time()
            response = client.chat_completion(messages=[message], temperature=0.1, max_tokens=1000)
            elapsed = time.time() - t0

            response_text = response['choices'][0]['message']['content']
            cleaned = clean_response(response_text)
            result = json.loads(cleaned)
            jerseys = result.get('jerseys', [])

            # Unique VLM jersey colors, ignoring white
            vlm_colors = set()
            for j in jerseys:
                jc = j.get('jersey_color', '').strip().lower()
                if jc and jc != 'white':
                    vlm_colors.add(jc)

            vlm_display = ", ".join(sorted(vlm_colors)) if vlm_colors else "(none)"
            print(f"  VLM: [{vlm_display}] ({len(jerseys)} jersey(s), {elapsed:.1f}s)")

            if not gt_groups:
                print("  -- no ground truth colors (white-only), skipping scoring")
                continue

            scores = score_image(gt_groups, vlm_colors)
            total_gt += scores['gt_count']
            total_vlm += scores['vlm_count']
            total_recall_exact += scores['recall_exact']
            total_recall_similar += scores['recall_similar']
            total_recall_missed += len(scores['recall_missed'])
            total_precision_exact += scores['precision_exact']
            total_precision_similar += scores['precision_similar']
            total_precision_extra += len(scores['precision_extra'])

            for group, got in scores['confusions']:
                confusion_counter[("|".join(group), got)] += 1
            for group in scores['recall_missed']:
                missed_counter["|".join(group)] += 1
            for ec in scores['precision_extra']:
                extra_counter[ec] += 1

            # Status line
            status_parts = []
            if scores['recall_exact']:
                status_parts.append(f"exact:{scores['recall_exact']}")
            if scores['recall_similar']:
                status_parts.append(f"similar:{scores['recall_similar']}")
            if scores['recall_missed']:
                missed_strs = ["|".join(g) for g in scores['recall_missed']]
                status_parts.append(f"MISS:{','.join(missed_strs)}")
            if scores['precision_extra']:
                status_parts.append(f"extra:{','.join(scores['precision_extra'])}")

            all_found = (scores['recall_exact'] + scores['recall_similar']) == scores['gt_count']
            no_extra = not scores['precision_extra']
            if all_found and no_extra:
                tag = "PASS"
            elif scores['recall_exact'] + scores['recall_similar'] > 0:
                tag = "PARTIAL"
            else:
                tag = "FAIL"
            print(f"  {tag} {', '.join(status_parts)}")

            per_image_results.append({
                'file': image_path.name,
                'tag': tag,
                'scores': scores,
            })

        except (json.JSONDecodeError, KeyError, IndexError) as e:
            print(f"  PARSE ERROR: {e}")
            errors += 1
        except Exception as e:
            print(f"  ERROR: {e}")
            errors += 1

    total_time = time.time() - start_all

    print_summary(
        total_gt, total_vlm, total_recall_exact, total_recall_similar,
        total_recall_missed, total_precision_exact, total_precision_similar,
        total_precision_extra, confusion_counter, missed_counter,
        extra_counter, per_image_results, len(image_files), errors, total_time,
    )


if __name__ == '__main__':
    main()