#!/usr/bin/env python3
"""
Test script to discover the variety of colors a VLM returns for jersey detection.

Submits all test images to the VLM and tallies every unique jersey_color and
number_color value, producing a summary of the model's color vocabulary.

Usage:
    python test_color_variety.py
"""

import json
import os
import re
import sys
import time
from collections import Counter
from pathlib import Path

import cv2

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from scan_utils.llama_cpp_client import LlamaCppClient  # noqa: E402  (needs the sys.path tweak above)

SERVER_URL = "http://agx:8080"
IMAGES_DIR = os.path.join(os.path.dirname(__file__), "basketball_jersery_color_test_files")
PROMPT_FILE = os.path.join(os.path.dirname(__file__), "jersey_prompt.txt")


def clean_response(text: str) -> str:
    """Remove think tags and markdown code blocks from model output.

    Handles both ``<think>...</think>`` and ``\u25c1think\u25b7...\u25c1/think\u25b7``
    reasoning delimiters, drops any unpaired tags left behind, and then returns
    the body of the first fenced code block if present (otherwise the text with
    bare fence markers removed).

    Args:
        text: Raw assistant message content.

    Returns:
        The cleaned text with surrounding whitespace stripped.
    """
    # Remove complete reasoning blocks first so their contents never reach json.loads.
    cleaned = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL | re.IGNORECASE)
    cleaned = re.sub(r'\u25c1think\u25b7.*?\u25c1/think\u25b7', '', cleaned, flags=re.DOTALL)
    # Then strip any unpaired opening/closing tags that remain.
    cleaned = re.sub(r'</?think>', '', cleaned, flags=re.IGNORECASE)
    cleaned = re.sub(r'\u25c1/?think\u25b7', '', cleaned, flags=re.IGNORECASE)

    json_block = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', cleaned, flags=re.DOTALL | re.IGNORECASE)
    if json_block:
        cleaned = json_block.group(1)
    else:
        cleaned = re.sub(r'```(?:json)?', '', cleaned, flags=re.IGNORECASE)
    return cleaned.strip()


def _gather_images(images_dir):
    """Split regular files in *images_dir* into (supported image paths, skipped file names).

    Both lists are sorted. "Supported" means an extension OpenCV can load.
    """
    valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
    image_files = []
    skipped = []
    # Single directory scan; directories are ignored entirely.
    for path in sorted(Path(images_dir).iterdir()):
        if not path.is_file():
            continue
        if path.suffix.lower() in valid_extensions:
            image_files.append(path)
        else:
            skipped.append(path.name)
    return image_files, skipped


def _normalize_color(value) -> str:
    """Return *value* as a trimmed, lower-cased string; None/missing becomes ''."""
    if value is None:
        return ''
    return str(value).strip().lower()


def _extract_colors(result) -> list:
    """Return (jersey_color, number_color) pairs from a parsed model response.

    Expects an object shaped like
    ``{"jerseys": [{"jersey_color": ..., "number_color": ...}, ...]}``.
    A missing or null ``jerseys`` value yields no pairs; list entries that are
    not objects are skipped. Colors are normalized via ``_normalize_color``.

    Raises:
        TypeError: if *result* is not an object or ``jerseys`` is not a list.
    """
    if not isinstance(result, dict):
        raise TypeError(f"expected a JSON object, got {type(result).__name__}")
    jerseys = result.get('jerseys') or []
    if not isinstance(jerseys, list):
        raise TypeError(f"'jerseys' should be a list, got {type(jerseys).__name__}")
    return [
        (_normalize_color(j.get('jersey_color')), _normalize_color(j.get('number_color')))
        for j in jerseys
        if isinstance(j, dict)
    ]


def _print_histogram(title, counter, max_bar=50):
    """Print *counter* most-common-first with a '#' bar capped at *max_bar* chars."""
    print(f"\n--- {title} ---")
    for color, count in counter.most_common():
        bar = "#" * min(count, max_bar)
        print(f" {color:25s} {count:4d} {bar}")


def _print_summary(image_count, total_jerseys, errors, total_time,
                   jersey_color_counter, number_color_counter):
    """Print run statistics, per-field color histograms, and the combined palette."""
    print()
    print("=" * 70)
    print("COLOR VARIETY SUMMARY")
    print("=" * 70)
    print(f"Images processed: {image_count}")
    print(f"Total jerseys detected: {total_jerseys}")
    print(f"Errors: {errors}")
    # Guard against an empty image directory (would otherwise divide by zero).
    avg = total_time / image_count if image_count else 0.0
    print(f"Total time: {total_time:.1f}s ({avg:.1f}s avg)")

    _print_histogram(f"Jersey Colors ({len(jersey_color_counter)} unique)", jersey_color_counter)
    _print_histogram(f"Number Colors ({len(number_color_counter)} unique)", number_color_counter)

    # Combined unique palette
    all_colors = sorted(jersey_color_counter.keys() | number_color_counter.keys())
    print(f"\n--- Combined Color Palette ({len(all_colors)} unique values) ---")
    for color in all_colors:
        # Counter returns 0 for absent keys.
        print(f" {color:25s} jersey:{jersey_color_counter[color]:3d} number:{number_color_counter[color]:3d}")


def main():
    """Send every test image to the VLM and print a summary of the colors it reports."""
    # Load prompt
    with open(PROMPT_FILE, 'r', encoding='utf-8') as f:
        prompt = f.read()

    image_files, skipped = _gather_images(IMAGES_DIR)

    print(f"Images to process: {len(image_files)}")
    if skipped:
        print(f"Skipping {len(skipped)} unsupported files: {', '.join(skipped)}")
    print(f"Server: {SERVER_URL}")
    print(f"Prompt: {PROMPT_FILE} ({len(prompt)} chars)")
    print("=" * 70)

    client = LlamaCppClient(base_url=SERVER_URL)
    jersey_color_counter = Counter()
    number_color_counter = Counter()
    total_jerseys = 0
    errors = 0
    start_all = time.time()

    for i, image_path in enumerate(image_files, 1):
        print(f"[{i}/{len(image_files)}] {image_path.name} ... ", end="", flush=True)

        image = cv2.imread(str(image_path))
        if image is None:
            print("SKIP (failed to load)")
            errors += 1
            continue

        # Request stage: any client/transport failure is reported and the image skipped.
        try:
            message = client.create_multimodal_message(role="user", content=prompt, images=[image])
            t0 = time.time()
            response = client.chat_completion(messages=[message], temperature=0.1, max_tokens=1000)
            elapsed = time.time() - t0
        except Exception as e:
            print(f"ERROR: {e}")
            errors += 1
            continue

        # Parse stage: malformed or unexpectedly shaped output.
        # json.JSONDecodeError is a subclass of ValueError.
        try:
            response_text = response['choices'][0]['message']['content']
            pairs = _extract_colors(json.loads(clean_response(response_text)))
        except (ValueError, KeyError, IndexError, TypeError) as e:
            print(f"PARSE ERROR: {e}")
            errors += 1
            continue

        # Empty strings mean the model omitted that field; don't tally them.
        jersey_color_counter.update(jc for jc, _ in pairs if jc)
        number_color_counter.update(nc for _, nc in pairs if nc)
        total_jerseys += len(pairs)

        colors_found = [f"{jc}/{nc}" for jc, nc in pairs]
        print(f"{len(pairs)} jersey(s) in {elapsed:.1f}s {', '.join(colors_found) if colors_found else '(none)'}")

    total_time = time.time() - start_all
    _print_summary(len(image_files), total_jerseys, errors, total_time,
                   jersey_color_counter, number_color_counter)


if __name__ == '__main__':
    main()