#!/usr/bin/env python3
"""
Test script to measure VLM accuracy for jersey color detection.

Uses annotated test images where ground truth colors are encoded in filenames.
Compares VLM results against ground truth, measuring exact and similar color
matches. White is ignored in both ground truth and VLM results.

Filename format: "014 - orange_dark blue or purple.jpg"
  - Underscore separates distinct jersey colors
  - "or" separates acceptable alternatives for a single jersey

Usage:
    python test_accuracy.py [prompt_file]
"""

import json
import os
import re
import sys
import time
from collections import Counter
from pathlib import Path

# Make the sibling ``scan_utils`` package importable when run as a script.
# NOTE: ``cv2`` and ``LlamaCppClient`` are imported inside ``main()`` so the
# pure parsing/scoring helpers below can be imported (and unit-tested) on a
# machine that has neither OpenCV nor the server client installed.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

SERVER_URL = "http://agx:8080"
IMAGES_DIR = os.path.join(os.path.dirname(__file__), "basketball_jersery_color_test_files_annotated")
DEFAULT_PROMPT_FILE = os.path.join(os.path.dirname(__file__), "jersey_prompt.txt")
MAX_IMAGE_WIDTH = 768

# ---------------------------------------------------------------------------
# Color similarity – colors in the same family count as "similar" matches
# ---------------------------------------------------------------------------

COLOR_FAMILIES = {
    'blue': ['blue', 'dark blue', 'navy blue', 'navy', 'royal blue'],
    'light_blue': ['light blue', 'sky blue', 'baby blue', 'carolina blue', 'powder blue'],
    'red': ['red', 'scarlet', 'crimson'],
    'dark_red': ['maroon', 'burgundy', 'dark red', 'wine'],
    'green': ['green', 'dark green', 'forest green', 'kelly green'],
    'yellow': ['yellow', 'gold', 'golden'],
    'orange': ['orange', 'burnt orange'],
    'brown': ['brown', 'dark brown'],
    'purple': ['purple', 'violet'],
    'gray': ['gray', 'grey', 'silver', 'charcoal'],
    'black': ['black'],
    'teal': ['teal', 'turquoise', 'cyan', 'aqua'],
    'pink': ['pink', 'magenta', 'hot pink', 'rose'],
}

# Reverse index: color name -> family name.
_COLOR_TO_FAMILY = {
    _color: _family
    for _family, _members in COLOR_FAMILIES.items()
    for _color in _members
}


def colors_are_similar(color1: str, color2: str) -> bool:
    """Return True if two colors are equal or belong to the same color family.

    Colors that are not listed in ``COLOR_FAMILIES`` are only similar to
    themselves.
    """
    if color1 == color2:
        return True
    f1 = _COLOR_TO_FAMILY.get(color1)
    f2 = _COLOR_TO_FAMILY.get(color2)
    return bool(f1 and f2 and f1 == f2)


# ---------------------------------------------------------------------------
# Ground-truth parsing
# ---------------------------------------------------------------------------

def parse_ground_truth(filename: str) -> list[list[str]]:
    """Parse ground truth colors from an annotated filename.

    Returns a list of color groups. Each group is a list of acceptable
    alternatives (from "or" in the filename), lower-cased. White entries are
    removed, and groups left empty by that removal are dropped.

    Example:
        "014 - orange_dark blue or purple.jpg"
        -> [["orange"], ["dark blue", "purple"]]
    """
    name = Path(filename).stem
    # Strip number prefix ("014 - ", "029 -", etc.)
    name = re.sub(r'^\d+\s*-\s*', '', name)
    # Treat hyphens between colors as underscores (e.g. "yellow-black")
    name = name.replace('-', '_')

    color_groups = []
    for part in name.split('_'):
        # Lower-case before splitting so "Red OR Blue" is split like "red or blue".
        part = part.strip().lower()
        if not part:
            continue
        alternatives = [a.strip() for a in part.split(' or ')]
        alternatives = [a for a in alternatives if a and a != 'white']
        if alternatives:
            color_groups.append(alternatives)
    return color_groups


# ---------------------------------------------------------------------------
# Response cleaning
# ---------------------------------------------------------------------------

# Complete reasoning blocks, in both the ASCII and the "◁think▷" spelling.
_THINK_BLOCK_RE = re.compile(r'<think>.*?</think>', re.DOTALL | re.IGNORECASE)
_THINK_BLOCK_ALT_RE = re.compile(r'\u25c1think\u25b7.*?\u25c1/think\u25b7', re.DOTALL)
# Stray opening/closing tags left behind by unbalanced output.
_THINK_TAG_RE = re.compile(r'</?think>', re.IGNORECASE)
_THINK_TAG_ALT_RE = re.compile(r'\u25c1/?think\u25b7', re.IGNORECASE)
_JSON_FENCE_RE = re.compile(r'```(?:json)?\s*\n?(.*?)\n?```', re.DOTALL | re.IGNORECASE)
_FENCE_MARKER_RE = re.compile(r'```(?:json)?', re.IGNORECASE)


def clean_response(text: str) -> str:
    """Remove think tags and markdown code blocks from model output.

    Complete think blocks are dropped together with their content; any
    leftover unbalanced think tag is removed on its own. If a fenced code
    block is present, only the content of the first one is returned;
    otherwise stray fence markers are stripped. The result is whitespace
    trimmed.
    """
    cleaned = _THINK_BLOCK_RE.sub('', text)
    cleaned = _THINK_BLOCK_ALT_RE.sub('', cleaned)
    cleaned = _THINK_TAG_RE.sub('', cleaned)
    cleaned = _THINK_TAG_ALT_RE.sub('', cleaned)

    json_block = _JSON_FENCE_RE.search(cleaned)
    if json_block:
        cleaned = json_block.group(1)
    else:
        cleaned = _FENCE_MARKER_RE.sub('', cleaned)
    return cleaned.strip()


# ---------------------------------------------------------------------------
# Scoring
# ---------------------------------------------------------------------------

def score_image(gt_groups: list[list[str]], vlm_colors: set[str]) -> dict:
    """Compare VLM detected colors against ground truth color groups.

    Recall = how many GT color groups were found in VLM output
    Precision = how many VLM colors match something in the GT

    Returns a dict with the keys ``gt_count``, ``vlm_count``,
    ``recall_exact``, ``recall_similar``, ``recall_missed`` (list of groups),
    ``precision_exact``, ``precision_similar``, ``precision_extra`` (list of
    colors) and ``confusions`` (list of ``(group, vlm_color)`` pairs).
    """
    # Iterate VLM colors in sorted order so the reported confusion partner and
    # the order of ``precision_extra`` do not depend on set iteration order.
    ordered_vlm = sorted(vlm_colors)

    recall_exact = 0
    recall_similar = 0
    recall_missed = []
    confusions = []

    for group in gt_groups:
        # Try exact match first
        if any(alt in vlm_colors for alt in group):
            recall_exact += 1
            continue

        # Try similar match: first (alternative, VLM color) pair in the same family
        matched_vlm = next(
            (vc for alt in group for vc in ordered_vlm if colors_are_similar(alt, vc)),
            None,
        )
        if matched_vlm:
            recall_similar += 1
            confusions.append((group, matched_vlm))
        else:
            recall_missed.append(group)

    # Precision: check each VLM color against GT
    all_gt_alts = {alt for group in gt_groups for alt in group}
    precision_exact = 0
    precision_similar = 0
    precision_extra = []

    for vc in ordered_vlm:
        if vc in all_gt_alts:
            precision_exact += 1
        elif any(colors_are_similar(vc, gt) for gt in all_gt_alts):
            precision_similar += 1
        else:
            precision_extra.append(vc)

    return {
        'gt_count': len(gt_groups),
        'vlm_count': len(vlm_colors),
        'recall_exact': recall_exact,
        'recall_similar': recall_similar,
        'recall_missed': recall_missed,
        'precision_exact': precision_exact,
        'precision_similar': precision_similar,
        'precision_extra': precision_extra,
        'confusions': confusions,
    }


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def pct(n: int, d: int) -> str:
    """Format ``n / d`` as a percentage with one decimal, or "N/A" when ``d`` is 0."""
    return f"{100 * n / d:.1f}%" if d else "N/A"


def print_summary(total_gt, total_vlm,
                  total_recall_exact, total_recall_similar, total_recall_missed,
                  total_precision_exact, total_precision_similar, total_precision_extra,
                  confusion_counter, missed_counter, extra_counter,
                  per_image_results, image_count, errors, total_time):
    """Print the full accuracy summary report.

    The ``total_*`` arguments are counts summed over all scored images, the
    three counters map confusion pairs / missed groups / extra colors to how
    often they occurred, and ``per_image_results`` is a list of dicts with the
    keys ``file``, ``tag`` and ``scores`` (as returned by ``score_image``).
    """
    print()
    print("=" * 80)
    print("ACCURACY SUMMARY")
    print("=" * 80)
    print(f"Images processed: {image_count}")
    print(f"Errors: {errors}")
    # max(..., 1) avoids a division by zero when no image was found.
    print(f"Total time: {total_time:.1f}s ({total_time / max(image_count, 1):.1f}s avg)")
    print()
    print(f"Ground truth colors: {total_gt} (excluding white)")
    print(f"VLM unique colors: {total_vlm} (excluding white)")

    print()
    print("--- Recall (did VLM find each ground truth color?) ---")
    print(f" Exact match: {total_recall_exact:4d} / {total_gt} ({pct(total_recall_exact, total_gt)})")
    print(f" Similar match: {total_recall_similar:4d} / {total_gt} ({pct(total_recall_similar, total_gt)})")
    recall_total = total_recall_exact + total_recall_similar
    print(f" Total found: {recall_total:4d} / {total_gt} ({pct(recall_total, total_gt)})")
    print(f" Missed: {total_recall_missed:4d} / {total_gt} ({pct(total_recall_missed, total_gt)})")

    print()
    print("--- Precision (are VLM colors correct?) ---")
    print(f" Exact match: {total_precision_exact:4d} / {total_vlm} ({pct(total_precision_exact, total_vlm)})")
    print(f" Similar match: {total_precision_similar:4d} / {total_vlm} ({pct(total_precision_similar, total_vlm)})")
    prec_total = total_precision_exact + total_precision_similar
    print(f" Total correct: {prec_total:4d} / {total_vlm} ({pct(prec_total, total_vlm)})")
    print(f" Extra/wrong: {total_precision_extra:4d} / {total_vlm} ({pct(total_precision_extra, total_vlm)})")

    if confusion_counter:
        print()
        print("--- Similar-Match Confusions (expected -> got) ---")
        for (expected, got), count in confusion_counter.most_common():
            print(f" {expected:30s} -> {got:20s} x{count}")

    if missed_counter:
        print()
        print("--- Most Missed Ground Truth Colors ---")
        for color, count in missed_counter.most_common(20):
            bar = "#" * min(count, 40)
            print(f" {color:30s} {count:3d} {bar}")

    if extra_counter:
        print()
        print("--- Most Common Extra/Wrong VLM Colors ---")
        for color, count in extra_counter.most_common(20):
            bar = "#" * min(count, 40)
            print(f" {color:30s} {count:3d} {bar}")

    if per_image_results:
        tags = Counter(r['tag'] for r in per_image_results)
        print()
        print("--- Per-Image Verdict ---")
        for tag in ['PASS', 'PARTIAL', 'FAIL']:
            print(f" {tag:10s} {tags.get(tag, 0):4d}")

        failed = [r for r in per_image_results if r['tag'] == 'FAIL']
        if failed:
            print()
            print(f"--- Failed Images ({len(failed)}) ---")
            for r in failed:
                scores = r['scores']
                missed_strs = ["|".join(g) for g in scores['recall_missed']]
                print(f" {r['file']}")
                print(f" missed: {', '.join(missed_strs)}")
                if scores['precision_extra']:
                    print(f" extra: {', '.join(scores['precision_extra'])}")


def _extract_vlm_colors(result) -> tuple[set[str], int]:
    """Return (unique non-white jersey colors, number of jersey entries) from parsed JSON."""
    if not isinstance(result, dict):
        raise TypeError(f"expected a JSON object, got {type(result).__name__}")
    jerseys = result.get('jerseys') or []
    colors = set()
    for jersey in jerseys:
        if not isinstance(jersey, dict):
            continue
        # ``or ''`` tolerates an explicit JSON null for the color.
        color = str(jersey.get('jersey_color') or '').strip().lower()
        if color and color != 'white':
            colors.add(color)
    return colors, len(jerseys)


def _verdict(scores: dict) -> str:
    """Classify one image as PASS (all found, no extras), PARTIAL (some found) or FAIL."""
    found = scores['recall_exact'] + scores['recall_similar']
    if found == scores['gt_count'] and not scores['precision_extra']:
        return "PASS"
    if found > 0:
        return "PARTIAL"
    return "FAIL"


def _status_line(scores: dict) -> str:
    """Build the comma-separated per-image status details shown after the verdict."""
    parts = []
    if scores['recall_exact']:
        parts.append(f"exact:{scores['recall_exact']}")
    if scores['recall_similar']:
        parts.append(f"similar:{scores['recall_similar']}")
    if scores['recall_missed']:
        missed_strs = ["|".join(g) for g in scores['recall_missed']]
        parts.append(f"MISS:{','.join(missed_strs)}")
    if scores['precision_extra']:
        parts.append(f"extra:{','.join(scores['precision_extra'])}")
    return ', '.join(parts)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main():
    """Run every annotated image through the VLM and print an accuracy report.

    The prompt is read from the file given as the first command-line argument,
    or from ``DEFAULT_PROMPT_FILE``. Per-image failures are counted and
    reported; they do not abort the run.
    """
    # Imported here (first, so a missing dependency fails immediately) rather
    # than at module level; see the note next to ``sys.path.insert`` above.
    import cv2
    from scan_utils.llama_cpp_client import LlamaCppClient

    prompt_file = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_PROMPT_FILE
    with open(prompt_file, 'r', encoding='utf-8') as f:
        prompt = f.read()

    valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
    image_files = sorted(
        p for p in Path(IMAGES_DIR).iterdir()
        if p.suffix.lower() in valid_extensions
    )

    print(f"Images to process: {len(image_files)}")
    print(f"Server: {SERVER_URL}")
    print(f"Prompt: {prompt_file} ({len(prompt)} chars)")
    print("=" * 80)

    client = LlamaCppClient(base_url=SERVER_URL)

    # Accumulators
    totals = Counter()
    errors = 0
    confusion_counter = Counter()
    missed_counter = Counter()
    extra_counter = Counter()
    per_image_results = []
    start_all = time.perf_counter()

    for i, image_path in enumerate(image_files, 1):
        gt_groups = parse_ground_truth(image_path.name)
        gt_display = ", ".join("|".join(g) for g in gt_groups) if gt_groups else "(none)"

        print(f"\n[{i}/{len(image_files)}] {image_path.name}")
        print(f" GT: [{gt_display}]")

        image = cv2.imread(str(image_path))
        if image is None:
            print(" SKIP (failed to load)")
            errors += 1
            continue

        h, w = image.shape[:2]
        if w > MAX_IMAGE_WIDTH:
            scale = MAX_IMAGE_WIDTH / w
            # Clamp to 1 px so a very wide, very short image cannot produce a
            # zero-height target size.
            new_height = max(1, int(h * scale))
            image = cv2.resize(image, (MAX_IMAGE_WIDTH, new_height),
                               interpolation=cv2.INTER_AREA)

        message = client.create_multimodal_message(role="user", content=prompt, images=[image])

        try:
            t0 = time.perf_counter()
            response = client.chat_completion(messages=[message], temperature=0.1, max_tokens=1000)
            elapsed = time.perf_counter() - t0

            response_text = response['choices'][0]['message']['content']
            result = json.loads(clean_response(response_text))

            # Unique VLM jersey colors, ignoring white
            vlm_colors, jersey_count = _extract_vlm_colors(result)

            vlm_display = ", ".join(sorted(vlm_colors)) if vlm_colors else "(none)"
            print(f" VLM: [{vlm_display}] ({jersey_count} jersey(s), {elapsed:.1f}s)")

            if not gt_groups:
                print(" -- no ground truth colors (white-only), skipping scoring")
                continue

            scores = score_image(gt_groups, vlm_colors)

            totals['gt'] += scores['gt_count']
            totals['vlm'] += scores['vlm_count']
            totals['recall_exact'] += scores['recall_exact']
            totals['recall_similar'] += scores['recall_similar']
            totals['recall_missed'] += len(scores['recall_missed'])
            totals['precision_exact'] += scores['precision_exact']
            totals['precision_similar'] += scores['precision_similar']
            totals['precision_extra'] += len(scores['precision_extra'])

            for group, got in scores['confusions']:
                confusion_counter[("|".join(group), got)] += 1
            for group in scores['recall_missed']:
                missed_counter["|".join(group)] += 1
            for ec in scores['precision_extra']:
                extra_counter[ec] += 1

            # Status line
            tag = _verdict(scores)
            print(f" {tag} {_status_line(scores)}")

            per_image_results.append({
                'file': image_path.name,
                'tag': tag,
                'scores': scores,
            })

        except (json.JSONDecodeError, KeyError, IndexError, TypeError) as e:
            print(f" PARSE ERROR: {e}")
            errors += 1
        except Exception as e:
            # Deliberate best-effort boundary: one bad image or request must
            # not abort the whole evaluation run.
            print(f" ERROR: {e}")
            errors += 1

    total_time = time.perf_counter() - start_all

    print_summary(
        totals['gt'], totals['vlm'],
        totals['recall_exact'], totals['recall_similar'], totals['recall_missed'],
        totals['precision_exact'], totals['precision_similar'], totals['precision_extra'],
        confusion_counter, missed_counter, extra_counter,
        per_image_results, len(image_files), errors, total_time,
    )


if __name__ == '__main__':
    main()