Add accuracy test framework, prompts, results, and analysis reports

Includes accuracy test scripts for Qwen (local) and Gemini (cloud API),
three prompt variants (original, capstone, constrained), test results
from all runs, and two analysis reports with an HTML presentation version.
This commit is contained in:
2026-03-03 18:44:49 -07:00
parent 435033ea07
commit 5405d7f7dc
13 changed files with 8561 additions and 0 deletions

402
test_accuracy.py Normal file
View File

@ -0,0 +1,402 @@
#!/usr/bin/env python3
"""
Test script to measure VLM accuracy for jersey color detection.
Uses annotated test images where ground truth colors are encoded in filenames.
Compares VLM results against ground truth, measuring exact and similar color matches.
White is ignored in both ground truth and VLM results.
Filename format: "014 - orange_dark blue or purple.jpg"
- Underscore separates distinct jersey colors
- "or" separates acceptable alternatives for a single jersey
Usage:
python test_accuracy.py [prompt_file]
"""
import json
import os
import re
import sys
import time
from collections import Counter
from pathlib import Path
import cv2
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from scan_utils.llama_cpp_client import LlamaCppClient
SERVER_URL = "http://agx:8080"
IMAGES_DIR = os.path.join(os.path.dirname(__file__), "basketball_jersery_color_test_files_annotated")
DEFAULT_PROMPT_FILE = os.path.join(os.path.dirname(__file__), "jersey_prompt.txt")
MAX_IMAGE_WIDTH = 768
# ---------------------------------------------------------------------------
# Color similarity colors in the same family count as "similar" matches
# ---------------------------------------------------------------------------
COLOR_FAMILIES = {
'blue': ['blue', 'dark blue', 'navy blue', 'navy', 'royal blue'],
'light_blue': ['light blue', 'sky blue', 'baby blue', 'carolina blue', 'powder blue'],
'red': ['red', 'scarlet', 'crimson'],
'dark_red': ['maroon', 'burgundy', 'dark red', 'wine'],
'green': ['green', 'dark green', 'forest green', 'kelly green'],
'yellow': ['yellow', 'gold', 'golden'],
'orange': ['orange', 'burnt orange'],
'brown': ['brown', 'dark brown'],
'purple': ['purple', 'violet'],
'gray': ['gray', 'grey', 'silver', 'charcoal'],
'black': ['black'],
'teal': ['teal', 'turquoise', 'cyan', 'aqua'],
'pink': ['pink', 'magenta', 'hot pink', 'rose'],
}
_COLOR_TO_FAMILY = {}
for _family, _members in COLOR_FAMILIES.items():
for _color in _members:
_COLOR_TO_FAMILY[_color] = _family
def colors_are_similar(color1: str, color2: str) -> bool:
"""Return True if two colors belong to the same color family."""
if color1 == color2:
return True
f1 = _COLOR_TO_FAMILY.get(color1)
f2 = _COLOR_TO_FAMILY.get(color2)
return bool(f1 and f2 and f1 == f2)
# ---------------------------------------------------------------------------
# Ground-truth parsing
# ---------------------------------------------------------------------------
def parse_ground_truth(filename: str) -> list[list[str]]:
"""Parse ground truth colors from an annotated filename.
Returns a list of color groups. Each group is a list of acceptable
alternatives (from "or" in the filename). White entries are removed.
Example: "014 - orange_dark blue or purple.jpg"
-> [["orange"], ["dark blue", "purple"]]
"""
name = Path(filename).stem
# Strip number prefix ("014 - ", "029 -", etc.)
name = re.sub(r'^\d+\s*-\s*', '', name)
# Treat hyphens between colors as underscores (e.g. "yellow-black")
name = name.replace('-', '_')
color_groups = []
for part in name.split('_'):
part = part.strip()
if not part:
continue
alternatives = [a.strip().lower() for a in part.split(' or ')]
alternatives = [a for a in alternatives if a and a != 'white']
if alternatives:
color_groups.append(alternatives)
return color_groups
# ---------------------------------------------------------------------------
# Response cleaning
# ---------------------------------------------------------------------------
def clean_response(text: str) -> str:
"""Remove think tags and markdown code blocks from model output."""
cleaned = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL | re.IGNORECASE)
cleaned = re.sub(r'\u25c1think\u25b7.*?\u25c1/think\u25b7', '', cleaned, flags=re.DOTALL)
cleaned = re.sub(r'</?think>', '', cleaned, flags=re.IGNORECASE)
cleaned = re.sub(r'\u25c1/?think\u25b7', '', cleaned, flags=re.IGNORECASE)
json_block = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', cleaned, flags=re.DOTALL | re.IGNORECASE)
if json_block:
cleaned = json_block.group(1)
else:
cleaned = re.sub(r'```(?:json)?', '', cleaned, flags=re.IGNORECASE)
return cleaned.strip()
# ---------------------------------------------------------------------------
# Scoring
# ---------------------------------------------------------------------------
def score_image(gt_groups: list[list[str]], vlm_colors: set[str]) -> dict:
"""Compare VLM detected colors against ground truth color groups.
Recall = how many GT color groups were found in VLM output
Precision = how many VLM colors match something in the GT
"""
recall_exact = 0
recall_similar = 0
recall_missed = []
confusions = []
for group in gt_groups:
# Try exact match first
if any(alt in vlm_colors for alt in group):
recall_exact += 1
continue
# Try similar match
matched_vlm = None
for alt in group:
for vc in vlm_colors:
if colors_are_similar(alt, vc):
matched_vlm = vc
break
if matched_vlm:
break
if matched_vlm:
recall_similar += 1
confusions.append((group, matched_vlm))
else:
recall_missed.append(group)
# Precision: check each VLM color against GT
all_gt_alts = [alt for group in gt_groups for alt in group]
precision_exact = 0
precision_similar = 0
precision_extra = []
for vc in vlm_colors:
if vc in all_gt_alts:
precision_exact += 1
elif any(colors_are_similar(vc, gt) for gt in all_gt_alts):
precision_similar += 1
else:
precision_extra.append(vc)
return {
'gt_count': len(gt_groups),
'vlm_count': len(vlm_colors),
'recall_exact': recall_exact,
'recall_similar': recall_similar,
'recall_missed': recall_missed,
'precision_exact': precision_exact,
'precision_similar': precision_similar,
'precision_extra': precision_extra,
'confusions': confusions,
}
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def pct(n: int, d: int) -> str:
return f"{100 * n / d:.1f}%" if d else "N/A"
def print_summary(total_gt, total_vlm, total_recall_exact, total_recall_similar,
total_recall_missed, total_precision_exact, total_precision_similar,
total_precision_extra, confusion_counter, missed_counter,
extra_counter, per_image_results, image_count, errors, total_time):
"""Print the full accuracy summary report."""
print()
print("=" * 80)
print("ACCURACY SUMMARY")
print("=" * 80)
print(f"Images processed: {image_count}")
print(f"Errors: {errors}")
print(f"Total time: {total_time:.1f}s ({total_time / max(image_count, 1):.1f}s avg)")
print()
print(f"Ground truth colors: {total_gt} (excluding white)")
print(f"VLM unique colors: {total_vlm} (excluding white)")
print()
print("--- Recall (did VLM find each ground truth color?) ---")
print(f" Exact match: {total_recall_exact:4d} / {total_gt} ({pct(total_recall_exact, total_gt)})")
print(f" Similar match: {total_recall_similar:4d} / {total_gt} ({pct(total_recall_similar, total_gt)})")
recall_total = total_recall_exact + total_recall_similar
print(f" Total found: {recall_total:4d} / {total_gt} ({pct(recall_total, total_gt)})")
print(f" Missed: {total_recall_missed:4d} / {total_gt} ({pct(total_recall_missed, total_gt)})")
print()
print("--- Precision (are VLM colors correct?) ---")
print(f" Exact match: {total_precision_exact:4d} / {total_vlm} ({pct(total_precision_exact, total_vlm)})")
print(f" Similar match: {total_precision_similar:4d} / {total_vlm} ({pct(total_precision_similar, total_vlm)})")
prec_total = total_precision_exact + total_precision_similar
print(f" Total correct: {prec_total:4d} / {total_vlm} ({pct(prec_total, total_vlm)})")
print(f" Extra/wrong: {total_precision_extra:4d} / {total_vlm} ({pct(total_precision_extra, total_vlm)})")
if confusion_counter:
print()
print("--- Similar-Match Confusions (expected -> got) ---")
for (expected, got), count in confusion_counter.most_common():
print(f" {expected:30s} -> {got:20s} x{count}")
if missed_counter:
print()
print("--- Most Missed Ground Truth Colors ---")
for color, count in missed_counter.most_common(20):
bar = "#" * min(count, 40)
print(f" {color:30s} {count:3d} {bar}")
if extra_counter:
print()
print("--- Most Common Extra/Wrong VLM Colors ---")
for color, count in extra_counter.most_common(20):
bar = "#" * min(count, 40)
print(f" {color:30s} {count:3d} {bar}")
if per_image_results:
tags = Counter(r['tag'] for r in per_image_results)
print()
print("--- Per-Image Verdict ---")
for tag in ['PASS', 'PARTIAL', 'FAIL']:
print(f" {tag:10s} {tags.get(tag, 0):4d}")
failed = [r for r in per_image_results if r['tag'] == 'FAIL']
if failed:
print()
print(f"--- Failed Images ({len(failed)}) ---")
for r in failed:
scores = r['scores']
missed_strs = ["|".join(g) for g in scores['recall_missed']]
print(f" {r['file']}")
print(f" missed: {', '.join(missed_strs)}")
if scores['precision_extra']:
print(f" extra: {', '.join(scores['precision_extra'])}")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
prompt_file = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_PROMPT_FILE
with open(prompt_file, 'r') as f:
prompt = f.read()
valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
image_files = sorted([
p for p in Path(IMAGES_DIR).iterdir()
if p.suffix.lower() in valid_extensions
])
print(f"Images to process: {len(image_files)}")
print(f"Server: {SERVER_URL}")
print(f"Prompt: {prompt_file} ({len(prompt)} chars)")
print("=" * 80)
client = LlamaCppClient(base_url=SERVER_URL)
# Accumulators
total_gt = 0
total_vlm = 0
total_recall_exact = 0
total_recall_similar = 0
total_recall_missed = 0
total_precision_exact = 0
total_precision_similar = 0
total_precision_extra = 0
errors = 0
start_all = time.time()
confusion_counter = Counter()
missed_counter = Counter()
extra_counter = Counter()
per_image_results = []
for i, image_path in enumerate(image_files, 1):
gt_groups = parse_ground_truth(image_path.name)
gt_display = ", ".join("|".join(g) for g in gt_groups) if gt_groups else "(none)"
print(f"\n[{i}/{len(image_files)}] {image_path.name}")
print(f" GT: [{gt_display}]")
image = cv2.imread(str(image_path))
if image is None:
print(" SKIP (failed to load)")
errors += 1
continue
h, w = image.shape[:2]
if w > MAX_IMAGE_WIDTH:
scale = MAX_IMAGE_WIDTH / w
image = cv2.resize(image, (MAX_IMAGE_WIDTH, int(h * scale)), interpolation=cv2.INTER_AREA)
message = client.create_multimodal_message(role="user", content=prompt, images=[image])
try:
t0 = time.time()
response = client.chat_completion(messages=[message], temperature=0.1, max_tokens=1000)
elapsed = time.time() - t0
response_text = response['choices'][0]['message']['content']
cleaned = clean_response(response_text)
result = json.loads(cleaned)
jerseys = result.get('jerseys', [])
# Unique VLM jersey colors, ignoring white
vlm_colors = set()
for j in jerseys:
jc = j.get('jersey_color', '').strip().lower()
if jc and jc != 'white':
vlm_colors.add(jc)
vlm_display = ", ".join(sorted(vlm_colors)) if vlm_colors else "(none)"
print(f" VLM: [{vlm_display}] ({len(jerseys)} jersey(s), {elapsed:.1f}s)")
if not gt_groups:
print(" -- no ground truth colors (white-only), skipping scoring")
continue
scores = score_image(gt_groups, vlm_colors)
total_gt += scores['gt_count']
total_vlm += scores['vlm_count']
total_recall_exact += scores['recall_exact']
total_recall_similar += scores['recall_similar']
total_recall_missed += len(scores['recall_missed'])
total_precision_exact += scores['precision_exact']
total_precision_similar += scores['precision_similar']
total_precision_extra += len(scores['precision_extra'])
for group, got in scores['confusions']:
confusion_counter[("|".join(group), got)] += 1
for group in scores['recall_missed']:
missed_counter["|".join(group)] += 1
for ec in scores['precision_extra']:
extra_counter[ec] += 1
# Status line
status_parts = []
if scores['recall_exact']:
status_parts.append(f"exact:{scores['recall_exact']}")
if scores['recall_similar']:
status_parts.append(f"similar:{scores['recall_similar']}")
if scores['recall_missed']:
missed_strs = ["|".join(g) for g in scores['recall_missed']]
status_parts.append(f"MISS:{','.join(missed_strs)}")
if scores['precision_extra']:
status_parts.append(f"extra:{','.join(scores['precision_extra'])}")
all_found = (scores['recall_exact'] + scores['recall_similar']) == scores['gt_count']
no_extra = not scores['precision_extra']
if all_found and no_extra:
tag = "PASS"
elif scores['recall_exact'] + scores['recall_similar'] > 0:
tag = "PARTIAL"
else:
tag = "FAIL"
print(f" {tag} {', '.join(status_parts)}")
per_image_results.append({
'file': image_path.name,
'tag': tag,
'scores': scores,
})
except (json.JSONDecodeError, KeyError, IndexError) as e:
print(f" PARSE ERROR: {e}")
errors += 1
except Exception as e:
print(f" ERROR: {e}")
errors += 1
total_time = time.time() - start_all
print_summary(
total_gt, total_vlm, total_recall_exact, total_recall_similar,
total_recall_missed, total_precision_exact, total_precision_similar,
total_precision_extra, confusion_counter, missed_counter,
extra_counter, per_image_results, len(image_files), errors, total_time,
)
if __name__ == '__main__':
main()