Files
jersey_test/test_color_variety.py
Rick McEwen 435033ea07 Add color variety and hex specificity test scripts with report
- test_color_variety.py: named-color test for local llama.cpp VLM
- test_color_variety_gemini.py: named-color test for Gemini 3 Flash API
- test_hex_color_specificity.py: hex color specificity test for Gemini
- test_hex_color_specificity_llama.py: hex color specificity test for local VLM
- jersey_prompt_hex_color.txt: prompt requesting hex color codes
- COLOR_TEST_REPORT.md: analysis report comparing 3 models across 5 tests
- color_test_results.md: raw test output from all runs
2026-02-24 11:30:41 -07:00

151 lines
5.1 KiB
Python

#!/usr/bin/env python3
"""
Test script to discover the variety of colors a VLM returns for jersey detection.
Submits all test images to the VLM and tallies every unique jersey_color and
number_color value, producing a summary of the model's color vocabulary.
Usage:
python test_color_variety.py
"""
import json
import os
import re
import sys
import time
from collections import Counter
from pathlib import Path
import cv2
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from scan_utils.llama_cpp_client import LlamaCppClient
# Base URL of the llama.cpp server that hosts the vision-language model.
SERVER_URL = "http://agx:8080"

# Test inputs are resolved relative to the directory containing this script.
# NOTE(review): "jersery" looks like a typo, but it must match the folder name
# on disk -- confirm before renaming either one.
_SCRIPT_DIR = os.path.dirname(__file__)
IMAGES_DIR = os.path.join(_SCRIPT_DIR, "basketball_jersery_color_test_files")
PROMPT_FILE = os.path.join(_SCRIPT_DIR, "jersey_prompt.txt")
def clean_response(text: str) -> str:
    """Strip reasoning sections and markdown code fences from model output.

    Reasoning is recognised in two marker syntaxes: ``<think>...</think>``
    (case-insensitive) and ``\u25c1think\u25b7...\u25c1/think\u25b7``.

    Some reasoning models emit only the closing marker, because the opening
    one is injected by the chat template. If a closing marker is still present
    after complete sections have been removed, everything up to and including
    the last such marker is discarded as reasoning text.

    If the remaining text contains a fenced code block (optionally tagged
    ``json``), only the contents of the first such block are kept. Otherwise
    any stray fence markers are removed.

    Args:
        text: Raw assistant message content.

    Returns:
        The cleaned text with surrounding whitespace stripped.
    """
    # Remove complete reasoning sections in either marker syntax.
    cleaned = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL | re.IGNORECASE)
    cleaned = re.sub(r'\u25c1think\u25b7.*?\u25c1/think\u25b7', '', cleaned, flags=re.DOTALL)
    # An unmatched closing marker means the reasoning had no opening marker in
    # the output: keep only what follows the last closing marker. re.split
    # returns [cleaned] unchanged when there is no match.
    cleaned = re.split(r'</think>|\u25c1/think\u25b7', cleaned, flags=re.IGNORECASE)[-1]
    # Drop any remaining unmatched opening markers.
    cleaned = re.sub(r'<think>|\u25c1think\u25b7', '', cleaned, flags=re.IGNORECASE)
    json_block = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', cleaned, flags=re.DOTALL | re.IGNORECASE)
    if json_block:
        cleaned = json_block.group(1)
    else:
        # Unterminated or stray fences: remove the markers and keep the text.
        cleaned = re.sub(r'```(?:json)?', '', cleaned, flags=re.IGNORECASE)
    return cleaned.strip()
# Extensions passed to cv2.imread; other regular files are reported as skipped.
_VALID_EXTENSIONS = frozenset({'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'})
# Longest histogram bar drawn in the summary tables, in '#' characters.
_MAX_BAR_WIDTH = 50


def _collect_images(images_dir):
    """Return ``(image_paths, skipped_names)`` for the regular files in *images_dir*.

    ``image_paths`` holds the files whose extension is in ``_VALID_EXTENSIONS``,
    sorted by path. ``skipped_names`` holds the names of the other files, sorted.
    Sub-directories are ignored. The directory is listed only once.
    """
    files = [p for p in Path(images_dir).iterdir() if p.is_file()]
    image_paths = sorted(p for p in files if p.suffix.lower() in _VALID_EXTENSIONS)
    skipped_names = sorted(p.name for p in files if p.suffix.lower() not in _VALID_EXTENSIONS)
    return image_paths, skipped_names


def _normalize_color(value) -> str:
    """Return *value* stripped and lower-cased; ``None`` becomes ``''``.

    Non-string values (e.g. a number emitted by the model) are converted with
    ``str()`` instead of raising ``AttributeError``.
    """
    return '' if value is None else str(value).strip().lower()


def _query_jerseys(client, prompt: str, image):
    """Send one image to the VLM and return ``(jerseys, elapsed_seconds)``.

    ``jerseys`` contains only the object (dict) entries of the reply's
    ``"jerseys"`` list. A missing or null ``"jerseys"`` yields ``[]``.

    Raises:
        ValueError: the reply is not valid JSON (``json.JSONDecodeError``) or
            is not a JSON object.
        KeyError, IndexError, TypeError: the response does not have the
            ``choices[0].message.content`` shape, or the content is not text.
        Any exception raised by the client itself propagates unchanged.
    """
    message = client.create_multimodal_message(role="user", content=prompt, images=[image])
    t0 = time.time()
    response = client.chat_completion(messages=[message], temperature=0.1, max_tokens=1000)
    elapsed = time.time() - t0
    response_text = response['choices'][0]['message']['content']
    result = json.loads(clean_response(response_text))
    if not isinstance(result, dict):
        raise ValueError(f"expected a JSON object, got {type(result).__name__}")
    jerseys = result.get('jerseys') or []
    return [j for j in jerseys if isinstance(j, dict)], elapsed


def _print_color_table(title: str, counter: Counter) -> None:
    """Print *counter* most-common first, with a '#' bar capped at _MAX_BAR_WIDTH."""
    print(f"\n--- {title} ({len(counter)} unique) ---")
    for color, count in counter.most_common():
        bar = "#" * min(count, _MAX_BAR_WIDTH)
        print(f" {color:25s} {count:4d} {bar}")


def main():
    """Run every test image through the VLM and print a color-vocabulary report.

    Reads the prompt from PROMPT_FILE and sends each supported image in
    IMAGES_DIR to the llama.cpp server at SERVER_URL. It then tallies the
    normalized ``jersey_color`` and ``number_color`` values of every jersey the
    model reports.

    Images that fail to load, replies that cannot be parsed, and client
    failures are counted as errors. They contribute nothing to the tallies.
    """
    with open(PROMPT_FILE, 'r', encoding='utf-8') as f:
        prompt = f.read()

    image_files, skipped = _collect_images(IMAGES_DIR)

    print(f"Images to process: {len(image_files)}")
    if skipped:
        print(f"Skipping {len(skipped)} unsupported files: {', '.join(skipped)}")
    print(f"Server: {SERVER_URL}")
    print(f"Prompt: {PROMPT_FILE} ({len(prompt)} chars)")
    print("=" * 70)

    client = LlamaCppClient(base_url=SERVER_URL)
    jersey_color_counter = Counter()
    number_color_counter = Counter()
    total_jerseys = 0
    errors = 0
    start_all = time.time()

    for i, image_path in enumerate(image_files, 1):
        print(f"[{i}/{len(image_files)}] {image_path.name} ... ", end="", flush=True)
        image = cv2.imread(str(image_path))
        if image is None:
            print("SKIP (failed to load)")
            errors += 1
            continue
        try:
            jerseys, elapsed = _query_jerseys(client, prompt, image)
        except (ValueError, KeyError, IndexError, TypeError) as e:
            print(f"PARSE ERROR: {e}")
            errors += 1
            continue
        except Exception as e:  # server/network failure: record it and keep going
            print(f"ERROR: {e}")
            errors += 1
            continue

        # Tally only after the whole reply parsed, so a failure never leaves
        # partial counts in the counters.
        colors_found = []
        for j in jerseys:
            jc = _normalize_color(j.get('jersey_color'))
            nc = _normalize_color(j.get('number_color'))
            if jc:
                jersey_color_counter[jc] += 1
            if nc:
                number_color_counter[nc] += 1
            colors_found.append(f"{jc}/{nc}")
        total_jerseys += len(jerseys)
        print(f"{len(jerseys)} jersey(s) in {elapsed:.1f}s {', '.join(colors_found) if colors_found else '(none)'}")

    total_time = time.time() - start_all
    # Guard the average against an empty image directory.
    avg_time = total_time / len(image_files) if image_files else 0.0

    # --- Summary ---
    print()
    print("=" * 70)
    print("COLOR VARIETY SUMMARY")
    print("=" * 70)
    print(f"Images processed: {len(image_files)}")
    print(f"Total jerseys detected: {total_jerseys}")
    print(f"Errors: {errors}")
    print(f"Total time: {total_time:.1f}s ({avg_time:.1f}s avg)")

    _print_color_table("Jersey Colors", jersey_color_counter)
    _print_color_table("Number Colors", number_color_counter)

    # Combined unique palette across both fields.
    all_colors = sorted(set(jersey_color_counter) | set(number_color_counter))
    print(f"\n--- Combined Color Palette ({len(all_colors)} unique values) ---")
    for color in all_colors:
        jc = jersey_color_counter.get(color, 0)
        nc = number_color_counter.get(color, 0)
        print(f" {color:25s} jersey:{jc:3d} number:{nc:3d}")


if __name__ == '__main__':
    main()