Files
jersey_test/test_accuracy_gemini.py
Rick McEwen 5405d7f7dc Add accuracy test framework, prompts, results, and analysis reports
Includes accuracy test scripts for Qwen (local) and Gemini (cloud API),
three prompt variants (original, capstone, constrained), test results
from all runs, and two analysis reports with an HTML presentation version.
2026-03-03 18:44:49 -07:00

577 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Test script to measure Gemini VLM accuracy for jersey color detection.
Uses annotated test images where ground truth colors are encoded in filenames.
Compares Gemini results against ground truth, measuring exact and similar color
matches. White is ignored in both ground truth and VLM results.
Filename format: "014 - orange_dark blue or purple.jpg"
- Underscore separates distinct jersey colors
- "or" separates acceptable alternatives for a single jersey
Usage:
python test_accuracy_gemini.py [prompt_file]
"""
import base64
import concurrent.futures
import json
import os
import re
import sys
import time
from collections import Counter
from pathlib import Path
import cv2
import requests
# Gemini model and REST endpoint used for all requests.
GEMINI_MODEL = "gemini-3-flash-preview"
API_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent"
# Directory of annotated test images (ground truth encoded in filenames).
# NOTE(review): "jersery" is a typo, but it must match the on-disk directory name.
IMAGES_DIR = os.path.join(os.path.dirname(__file__), "basketball_jersery_color_test_files_annotated")
DEFAULT_PROMPT_FILE = os.path.join(os.path.dirname(__file__), "jersey_prompt.txt")
API_KEY_FILE = os.path.join(os.path.dirname(__file__), "gemini_api_key.txt")
# Images wider than this are downscaled before upload to shrink the payload.
MAX_IMAGE_WIDTH = 768
JPEG_QUALITY = 85
# Number of simultaneous API requests in phase 2.
CONCURRENT_WORKERS = 8
# ---------------------------------------------------------------------------
# Color similarity: colors in the same family count as "similar" matches.
# ---------------------------------------------------------------------------
COLOR_FAMILIES = {
    'blue': ['blue', 'dark blue', 'navy blue', 'navy', 'royal blue'],
    'light_blue': ['light blue', 'sky blue', 'baby blue', 'carolina blue', 'powder blue'],
    'red': ['red', 'scarlet', 'crimson'],
    'dark_red': ['maroon', 'burgundy', 'dark red', 'wine'],
    'green': ['green', 'dark green', 'forest green', 'kelly green'],
    'yellow': ['yellow', 'gold', 'golden'],
    'orange': ['orange', 'burnt orange'],
    'brown': ['brown', 'dark brown'],
    'purple': ['purple', 'violet'],
    'gray': ['gray', 'grey', 'silver', 'charcoal'],
    'black': ['black'],
    'teal': ['teal', 'turquoise', 'cyan', 'aqua'],
    'pink': ['pink', 'magenta', 'hot pink', 'rose'],
}

# Inverted lookup table: color name -> family key.
_COLOR_TO_FAMILY = {
    color: family
    for family, members in COLOR_FAMILIES.items()
    for color in members
}


def colors_are_similar(color1: str, color2: str) -> bool:
    """Return True if two colors belong to the same color family."""
    if color1 == color2:
        return True
    family1 = _COLOR_TO_FAMILY.get(color1)
    # Unknown colors (not in any family) never match anything but themselves.
    return family1 is not None and family1 == _COLOR_TO_FAMILY.get(color2)
# ---------------------------------------------------------------------------
# Ground-truth parsing
# ---------------------------------------------------------------------------
def parse_ground_truth(filename: str) -> list[list[str]]:
    """Parse ground truth colors from an annotated filename.

    Returns a list of color groups.  Each group is a list of acceptable
    alternatives (from "or" in the filename).  White entries are removed.

    Example: "014 - orange_dark blue or purple.jpg"
        -> [["orange"], ["dark blue", "purple"]]
    """
    stem = Path(filename).stem
    # Strip number prefix ("014 - ", "029 -", etc.)
    stem = re.sub(r'^\d+\s*-\s*', '', stem)
    # Treat hyphens between colors as underscores (e.g. "yellow-black")
    stem = stem.replace('-', '_')

    groups: list[list[str]] = []
    for segment in stem.split('_'):
        segment = segment.strip()
        if not segment:
            continue
        # "or" separates acceptable alternatives for one jersey.
        options = [opt.strip().lower() for opt in segment.split(' or ')]
        options = [opt for opt in options if opt and opt != 'white']
        if options:
            groups.append(options)
    return groups
# ---------------------------------------------------------------------------
# Response cleaning & salvage
# ---------------------------------------------------------------------------
def clean_response(text: str) -> str:
    """Remove think tags and markdown code blocks from model output."""
    # Drop <think>...</think> reasoning spans, then any stray unmatched tags.
    stripped = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL | re.IGNORECASE)
    stripped = re.sub(r'</?think>', '', stripped, flags=re.IGNORECASE)
    # Prefer the interior of a fenced ```json block when one is present.
    fence = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', stripped, flags=re.DOTALL | re.IGNORECASE)
    if fence is not None:
        return fence.group(1).strip()
    # No complete fence: just remove any dangling fence markers.
    return re.sub(r'```(?:json)?', '', stripped, flags=re.IGNORECASE).strip()
def salvage_jerseys(text: str) -> list[dict]:
    """Extract complete jersey objects from truncated JSON using regex."""
    # Matches one fully-formed jersey object; partial trailing objects in a
    # truncated response simply do not match and are skipped.
    pattern = re.compile(
        r'\{\s*'
        r'"jersey_number"\s*:\s*"[^"]*"\s*,\s*'
        r'"jersey_color"\s*:\s*"([^"]*)"\s*,\s*'
        r'"number_color"\s*:\s*"([^"]*)"\s*'
        r'\}',
        re.DOTALL,
    )
    return [
        {'jersey_color': jersey_color, 'number_color': number_color}
        for jersey_color, number_color in pattern.findall(text)
    ]
# ---------------------------------------------------------------------------
# Gemini API helpers
# ---------------------------------------------------------------------------
def load_api_key() -> str:
    """Read the Gemini API key from API_KEY_FILE, trimming surrounding whitespace."""
    return Path(API_KEY_FILE).read_text().strip()
def encode_image(image_path: str) -> tuple[str, str]:
    """Read an image file, resize if wider than MAX_IMAGE_WIDTH, and return (base64_data, mime_type).

    When OpenCV can decode the file it is re-encoded before upload: PNG
    inputs stay PNG, every other format becomes JPEG (quality JPEG_QUALITY),
    and the returned mime type matches the re-encoded bytes.  If OpenCV
    cannot decode (or re-encode) the file, the raw file bytes are returned
    with a mime type guessed from the extension.
    """
    ext = Path(image_path).suffix.lower()
    mime_map = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.webp': 'image/webp',
        '.bmp': 'image/bmp',
        '.tiff': 'image/tiff',
    }
    mime_type = mime_map.get(ext, 'image/jpeg')
    image = cv2.imread(image_path)
    if image is not None:
        h, w = image.shape[:2]
        if w > MAX_IMAGE_WIDTH:
            # Downscale to cut payload size; INTER_AREA is suited to shrinking.
            scale = MAX_IMAGE_WIDTH / w
            image = cv2.resize(image, (MAX_IMAGE_WIDTH, int(h * scale)), interpolation=cv2.INTER_AREA)
        if ext == '.png':
            ok, buf = cv2.imencode('.png', image)
        else:
            ok, buf = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, JPEG_QUALITY])
            if ok:
                # BUG FIX: non-PNG inputs (.webp/.bmp/.tiff) are re-encoded as
                # JPEG here, so the declared mime type must be image/jpeg.  The
                # old code kept the extension-based mime type, mislabeling the
                # payload sent to the API.
                mime_type = 'image/jpeg'
        # BUG FIX: the old code ignored imencode's success flag; on failure we
        # now fall through to sending the raw file bytes instead of a bogus buffer.
        if ok:
            return base64.b64encode(buf).decode('utf-8'), mime_type
    # cv2 could not decode (or re-encode) the file: send raw bytes as-is.
    with open(image_path, 'rb') as f:
        data = base64.b64encode(f.read()).decode('utf-8')
    return data, mime_type
# Retry policy for transient (HTTP 5xx) failures.
MAX_RETRIES = 3
RETRY_BACKOFF = [2, 5, 10]  # seconds to sleep before retry attempts 2, 3, ...


def call_gemini(session: requests.Session, api_key: str, image_data: str,
                mime_type: str, prompt: str) -> dict:
    """Send pre-encoded image + prompt to the Gemini API and return the raw response.

    Retries on HTTP 5xx responses up to MAX_RETRIES attempts, sleeping
    RETRY_BACKOFF[attempt] seconds between tries.  Raises
    requests.exceptions.HTTPError for any final non-2xx response.
    """
    payload = {
        "contents": [{
            "parts": [
                {
                    "inline_data": {
                        "mime_type": mime_type,
                        "data": image_data,
                    }
                },
                {
                    "text": prompt,
                }
            ]
        }],
        "generationConfig": {
            "temperature": 0.1,
            "maxOutputTokens": 8192,
            "responseMimeType": "application/json",
        }
    }
    headers = {
        "x-goog-api-key": api_key,
        "Content-Type": "application/json",
    }
    for attempt in range(MAX_RETRIES):
        response = session.post(API_URL, headers=headers, json=payload)
        # Server errors are retried unless this was the last attempt;
        # client errors (4xx) raise immediately via raise_for_status().
        if response.status_code >= 500 and attempt < MAX_RETRIES - 1:
            time.sleep(RETRY_BACKOFF[attempt])
            continue
        # BUG FIX: the original duplicated raise_for_status()/return after
        # this loop, but that code was unreachable -- the final attempt
        # always returns or raises inside the loop.  The dead tail is removed.
        response.raise_for_status()
        return response.json()
def _api_worker(session: requests.Session, api_key: str, image_data: str,
                mime_type: str, prompt: str) -> dict:
    """Wrapper that captures timing and exceptions for concurrent execution."""
    started = time.time()
    try:
        response = call_gemini(session, api_key, image_data, mime_type, prompt)
    except Exception as exc:
        # Never let a worker raise: the error travels back in the result dict.
        return {'resp': None, 'elapsed': time.time() - started, 'error': exc}
    return {'resp': response, 'elapsed': time.time() - started, 'error': None}
# ---------------------------------------------------------------------------
# Scoring
# ---------------------------------------------------------------------------
def score_image(gt_groups: list[list[str]], vlm_colors: set[str]) -> dict:
    """Compare VLM detected colors against ground truth color groups.

    Recall = how many GT color groups were found in VLM output
    Precision = how many VLM colors match something in the GT
    """
    recall_exact = 0
    recall_similar = 0
    recall_missed: list[list[str]] = []
    confusions: list[tuple[list[str], str]] = []

    for group in gt_groups:
        # Exact match: any alternative appears verbatim in the VLM output.
        if any(alt in vlm_colors for alt in group):
            recall_exact += 1
            continue
        # Similar match: first VLM color in the same family as an alternative.
        similar_hit = next(
            (vc for alt in group for vc in vlm_colors if colors_are_similar(alt, vc)),
            None,
        )
        if similar_hit is not None:
            recall_similar += 1
            confusions.append((group, similar_hit))
        else:
            recall_missed.append(group)

    # Precision: classify each VLM color against the flattened GT alternatives.
    flat_gt = [alt for group in gt_groups for alt in group]
    precision_exact = 0
    precision_similar = 0
    precision_extra: list[str] = []
    for vc in vlm_colors:
        if vc in flat_gt:
            precision_exact += 1
        elif any(colors_are_similar(vc, gt) for gt in flat_gt):
            precision_similar += 1
        else:
            precision_extra.append(vc)

    return {
        'gt_count': len(gt_groups),
        'vlm_count': len(vlm_colors),
        'recall_exact': recall_exact,
        'recall_similar': recall_similar,
        'recall_missed': recall_missed,
        'precision_exact': precision_exact,
        'precision_similar': precision_similar,
        'precision_extra': precision_extra,
        'confusions': confusions,
    }
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def pct(n: int, d: int) -> str:
    """Format n/d as a one-decimal percentage; 'N/A' when the denominator is zero."""
    if not d:
        return "N/A"
    return f"{100 * n / d:.1f}%"
def extract_vlm_colors(jerseys: list[dict]) -> set[str]:
    """Return unique jersey colors from VLM output, ignoring white."""
    normalized = (j.get('jersey_color', '').strip().lower() for j in jerseys)
    return {color for color in normalized if color and color != 'white'}
def parse_response(result: dict) -> tuple[list[dict], set[str]]:
    """Parse a Gemini response into jerseys list and vlm_colors set.

    On JSON parse failure, attempts to salvage jersey objects from truncated
    output. Returns (jerseys, vlm_colors).
    """
    raw_text = result['resp']['candidates'][0]['content']['parts'][0]['text']
    try:
        parsed = json.loads(clean_response(raw_text))
    except json.JSONDecodeError:
        # Truncated/malformed JSON: pull out whatever complete objects exist.
        jerseys = salvage_jerseys(raw_text)
    else:
        jerseys = parsed.get('jerseys', [])
    return jerseys, extract_vlm_colors(jerseys)
def score_and_format(gt_groups, vlm_colors, scores):
    """Build a status line and tag from scoring results."""
    parts = []
    if scores['recall_exact']:
        parts.append(f"exact:{scores['recall_exact']}")
    if scores['recall_similar']:
        parts.append(f"similar:{scores['recall_similar']}")
    if scores['recall_missed']:
        joined_missed = ",".join("|".join(g) for g in scores['recall_missed'])
        parts.append(f"MISS:{joined_missed}")
    if scores['precision_extra']:
        parts.append(f"extra:{','.join(scores['precision_extra'])}")

    # PASS: every GT group found (exact or similar) and nothing extra;
    # PARTIAL: at least one group found; FAIL: nothing found.
    found = scores['recall_exact'] + scores['recall_similar']
    if found == scores['gt_count'] and not scores['precision_extra']:
        tag = "PASS"
    elif found > 0:
        tag = "PARTIAL"
    else:
        tag = "FAIL"
    return tag, parts
def print_summary(model_name, total_gt, total_vlm, total_recall_exact,
                  total_recall_similar, total_recall_missed,
                  total_precision_exact, total_precision_similar,
                  total_precision_extra, confusion_counter, missed_counter,
                  extra_counter, per_image_results, image_count, errors,
                  total_time):
    """Print the full accuracy summary report.

    Totals are aggregated across all scored images.  The three counters are
    collections.Counter instances keyed by color strings (or by
    (expected, got) tuples for confusions); per_image_results is a list of
    dicts with 'file', 'tag' (PASS/PARTIAL/FAIL) and 'scores' keys.
    """
    # Header and run-level stats.
    print()
    print("=" * 80)
    print(f"ACCURACY SUMMARY ({model_name})")
    print("=" * 80)
    print(f"Images processed: {image_count}")
    print(f"Errors: {errors}")
    print(f"Total time: {total_time:.1f}s ({total_time / max(image_count, 1):.1f}s avg)")
    print()
    print(f"Ground truth colors: {total_gt} (excluding white)")
    print(f"VLM unique colors: {total_vlm} (excluding white)")
    print()
    # Recall section: fraction of GT color groups the VLM recovered.
    print("--- Recall (did VLM find each ground truth color?) ---")
    print(f" Exact match: {total_recall_exact:4d} / {total_gt} ({pct(total_recall_exact, total_gt)})")
    print(f" Similar match: {total_recall_similar:4d} / {total_gt} ({pct(total_recall_similar, total_gt)})")
    recall_total = total_recall_exact + total_recall_similar
    print(f" Total found: {recall_total:4d} / {total_gt} ({pct(recall_total, total_gt)})")
    print(f" Missed: {total_recall_missed:4d} / {total_gt} ({pct(total_recall_missed, total_gt)})")
    print()
    # Precision section: fraction of VLM colors that were actually in the GT.
    print("--- Precision (are VLM colors correct?) ---")
    print(f" Exact match: {total_precision_exact:4d} / {total_vlm} ({pct(total_precision_exact, total_vlm)})")
    print(f" Similar match: {total_precision_similar:4d} / {total_vlm} ({pct(total_precision_similar, total_vlm)})")
    prec_total = total_precision_exact + total_precision_similar
    print(f" Total correct: {prec_total:4d} / {total_vlm} ({pct(prec_total, total_vlm)})")
    print(f" Extra/wrong: {total_precision_extra:4d} / {total_vlm} ({pct(total_precision_extra, total_vlm)})")
    if confusion_counter:
        print()
        print("--- Similar-Match Confusions (expected -> got) ---")
        for (expected, got), count in confusion_counter.most_common():
            print(f" {expected:30s} -> {got:20s} x{count}")
    if missed_counter:
        print()
        print("--- Most Missed Ground Truth Colors ---")
        for color, count in missed_counter.most_common(20):
            # Simple text histogram, capped at 40 characters.
            bar = "#" * min(count, 40)
            print(f" {color:30s} {count:3d} {bar}")
    if extra_counter:
        print()
        print("--- Most Common Extra/Wrong VLM Colors ---")
        for color, count in extra_counter.most_common(20):
            bar = "#" * min(count, 40)
            print(f" {color:30s} {count:3d} {bar}")
    if per_image_results:
        # PASS/PARTIAL/FAIL tally across all scored images.
        tags = Counter(r['tag'] for r in per_image_results)
        print()
        print("--- Per-Image Verdict ---")
        for tag in ['PASS', 'PARTIAL', 'FAIL']:
            print(f" {tag:10s} {tags.get(tag, 0):4d}")
        failed = [r for r in per_image_results if r['tag'] == 'FAIL']
        if failed:
            print()
            print(f"--- Failed Images ({len(failed)}) ---")
            for r in failed:
                scores = r['scores']
                missed_strs = ["|".join(g) for g in scores['recall_missed']]
                print(f" {r['file']}")
                print(f" missed: {', '.join(missed_strs)}")
                if scores['precision_extra']:
                    print(f" extra: {', '.join(scores['precision_extra'])}")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Run the accuracy test: encode images, fan out API calls, score results."""
    # Optional CLI argument overrides the default prompt file.
    prompt_file = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_PROMPT_FILE
    api_key = load_api_key()
    with open(prompt_file, 'r') as f:
        prompt = f.read()
    valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
    image_files = sorted([
        p for p in Path(IMAGES_DIR).iterdir()
        if p.suffix.lower() in valid_extensions
    ])
    print(f"Model: {GEMINI_MODEL}")
    print(f"Images to process: {len(image_files)}")
    print(f"Concurrency: {CONCURRENT_WORKERS} workers")
    print(f"Prompt: {prompt_file} ({len(prompt)} chars)")
    print("=" * 80)
    # ------------------------------------------------------------------
    # Phase 1: Pre-encode all images
    # ------------------------------------------------------------------
    print("Pre-encoding images ... ", end="", flush=True)
    t_enc = time.time()
    encoded_images = []
    for image_path in image_files:
        encoded_images.append(encode_image(str(image_path)))
    print(f"{len(encoded_images)} images in {time.time() - t_enc:.1f}s")
    # ------------------------------------------------------------------
    # Phase 2: Submit all API calls concurrently
    # ------------------------------------------------------------------
    session = requests.Session()
    start_all = time.time()
    print(f"Sending API requests ... ", flush=True)
    # Results are stored by submission index so phase 3 scores in filename order.
    api_results = [None] * len(image_files)
    with concurrent.futures.ThreadPoolExecutor(max_workers=CONCURRENT_WORKERS) as executor:
        future_to_idx = {}
        for i, (image_data, mime_type) in enumerate(encoded_images):
            future = executor.submit(
                _api_worker, session, api_key, image_data, mime_type, prompt,
            )
            future_to_idx[future] = i
        completed = 0
        for future in concurrent.futures.as_completed(future_to_idx):
            idx = future_to_idx[future]
            # _api_worker never raises: errors come back inside the dict.
            api_results[idx] = future.result()
            completed += 1
            print(f"\r {completed}/{len(image_files)} API calls completed", end="", flush=True)
    api_time = time.time() - start_all
    print(f" ({api_time:.1f}s total)")
    print("=" * 80)
    # ------------------------------------------------------------------
    # Phase 3: Score results in order
    # ------------------------------------------------------------------
    total_gt = 0
    total_vlm = 0
    total_recall_exact = 0
    total_recall_similar = 0
    total_recall_missed = 0
    total_precision_exact = 0
    total_precision_similar = 0
    total_precision_extra = 0
    errors = 0
    confusion_counter = Counter()
    missed_counter = Counter()
    extra_counter = Counter()
    per_image_results = []
    for i, (image_path, result) in enumerate(zip(image_files, api_results), 1):
        gt_groups = parse_ground_truth(image_path.name)
        gt_display = ", ".join("|".join(g) for g in gt_groups) if gt_groups else "(none)"
        print(f"\n[{i}/{len(image_files)}] {image_path.name}")
        print(f" GT: [{gt_display}]")
        if result['error'] is not None:
            e = result['error']
            if isinstance(e, requests.exceptions.HTTPError):
                print(f" HTTP ERROR: {e}")
            else:
                print(f" ERROR: {e}")
            errors += 1
            continue
        elapsed = result['elapsed']
        try:
            jerseys, vlm_colors = parse_response(result)
            vlm_display = ", ".join(sorted(vlm_colors)) if vlm_colors else "(none)"
            print(f" VLM: [{vlm_display}] ({len(jerseys)} jersey(s), {elapsed:.1f}s)")
            if not gt_groups:
                # White-only images carry no scorable ground truth.
                print(" -- no ground truth colors (white-only), skipping scoring")
                continue
            scores = score_image(gt_groups, vlm_colors)
            # Accumulate totals and counters for the summary report.
            total_gt += scores['gt_count']
            total_vlm += scores['vlm_count']
            total_recall_exact += scores['recall_exact']
            total_recall_similar += scores['recall_similar']
            total_recall_missed += len(scores['recall_missed'])
            total_precision_exact += scores['precision_exact']
            total_precision_similar += scores['precision_similar']
            total_precision_extra += len(scores['precision_extra'])
            for group, got in scores['confusions']:
                confusion_counter[("|".join(group), got)] += 1
            for group in scores['recall_missed']:
                missed_counter["|".join(group)] += 1
            for ec in scores['precision_extra']:
                extra_counter[ec] += 1
            tag, status_parts = score_and_format(gt_groups, vlm_colors, scores)
            print(f" {tag} {', '.join(status_parts)}")
            per_image_results.append({
                'file': image_path.name,
                'tag': tag,
                'scores': scores,
            })
        except Exception as e:
            print(f" PARSE ERROR: {e}")
            errors += 1
    total_time = time.time() - start_all
    print_summary(
        GEMINI_MODEL, total_gt, total_vlm, total_recall_exact,
        total_recall_similar, total_recall_missed, total_precision_exact,
        total_precision_similar, total_precision_extra, confusion_counter,
        missed_counter, extra_counter, per_image_results, len(image_files),
        errors, total_time,
    )
# Script entry point.
if __name__ == '__main__':
    main()