Add color variety and hex specificity test scripts with report
- test_color_variety.py: named-color test for local llama.cpp VLM - test_color_variety_gemini.py: named-color test for Gemini 3 Flash API - test_hex_color_specificity.py: hex color specificity test for Gemini - test_hex_color_specificity_llama.py: hex color specificity test for local VLM - jersey_prompt_hex_color.txt: prompt requesting hex color codes - COLOR_TEST_REPORT.md: analysis report comparing 3 models across 5 tests - color_test_results.md: raw test output from all runs
This commit is contained in:
151
test_color_variety.py
Normal file
151
test_color_variety.py
Normal file
@ -0,0 +1,151 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to discover the variety of colors a VLM returns for jersey detection.
|
||||
|
||||
Submits all test images to the VLM and tallies every unique jersey_color and
|
||||
number_color value, producing a summary of the model's color vocabulary.
|
||||
|
||||
Usage:
|
||||
python test_color_variety.py
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from scan_utils.llama_cpp_client import LlamaCppClient
|
||||
|
||||
SERVER_URL = "http://agx:8080"
|
||||
IMAGES_DIR = os.path.join(os.path.dirname(__file__), "basketball_jersery_color_test_files")
|
||||
PROMPT_FILE = os.path.join(os.path.dirname(__file__), "jersey_prompt.txt")
|
||||
|
||||
|
||||
def clean_response(text: str) -> str:
|
||||
"""Remove think tags and markdown code blocks from model output."""
|
||||
cleaned = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL | re.IGNORECASE)
|
||||
cleaned = re.sub(r'\u25c1think\u25b7.*?\u25c1/think\u25b7', '', cleaned, flags=re.DOTALL)
|
||||
cleaned = re.sub(r'</?think>', '', cleaned, flags=re.IGNORECASE)
|
||||
cleaned = re.sub(r'\u25c1/?think\u25b7', '', cleaned, flags=re.IGNORECASE)
|
||||
|
||||
json_block = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', cleaned, flags=re.DOTALL | re.IGNORECASE)
|
||||
if json_block:
|
||||
cleaned = json_block.group(1)
|
||||
else:
|
||||
cleaned = re.sub(r'```(?:json)?', '', cleaned, flags=re.IGNORECASE)
|
||||
|
||||
return cleaned.strip()
|
||||
|
||||
|
||||
def main():
|
||||
# Load prompt
|
||||
with open(PROMPT_FILE, 'r') as f:
|
||||
prompt = f.read()
|
||||
|
||||
# Gather image files (extensions OpenCV can handle)
|
||||
valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
|
||||
image_files = sorted([
|
||||
p for p in Path(IMAGES_DIR).iterdir()
|
||||
if p.suffix.lower() in valid_extensions
|
||||
])
|
||||
|
||||
skipped = sorted([
|
||||
p.name for p in Path(IMAGES_DIR).iterdir()
|
||||
if p.is_file() and p.suffix.lower() not in valid_extensions
|
||||
])
|
||||
|
||||
print(f"Images to process: {len(image_files)}")
|
||||
if skipped:
|
||||
print(f"Skipping {len(skipped)} unsupported files: {', '.join(skipped)}")
|
||||
print(f"Server: {SERVER_URL}")
|
||||
print(f"Prompt: {PROMPT_FILE} ({len(prompt)} chars)")
|
||||
print("=" * 70)
|
||||
|
||||
client = LlamaCppClient(base_url=SERVER_URL)
|
||||
|
||||
jersey_color_counter = Counter()
|
||||
number_color_counter = Counter()
|
||||
total_jerseys = 0
|
||||
errors = 0
|
||||
start_all = time.time()
|
||||
|
||||
for i, image_path in enumerate(image_files, 1):
|
||||
print(f"[{i}/{len(image_files)}] {image_path.name} ... ", end="", flush=True)
|
||||
|
||||
image = cv2.imread(str(image_path))
|
||||
if image is None:
|
||||
print("SKIP (failed to load)")
|
||||
errors += 1
|
||||
continue
|
||||
|
||||
message = client.create_multimodal_message(role="user", content=prompt, images=[image])
|
||||
|
||||
try:
|
||||
t0 = time.time()
|
||||
response = client.chat_completion(messages=[message], temperature=0.1, max_tokens=1000)
|
||||
elapsed = time.time() - t0
|
||||
|
||||
response_text = response['choices'][0]['message']['content']
|
||||
cleaned = clean_response(response_text)
|
||||
result = json.loads(cleaned)
|
||||
jerseys = result.get('jerseys', [])
|
||||
|
||||
colors_found = []
|
||||
for j in jerseys:
|
||||
jc = j.get('jersey_color', '').strip().lower()
|
||||
nc = j.get('number_color', '').strip().lower()
|
||||
if jc:
|
||||
jersey_color_counter[jc] += 1
|
||||
if nc:
|
||||
number_color_counter[nc] += 1
|
||||
colors_found.append(f"{jc}/{nc}")
|
||||
total_jerseys += 1
|
||||
|
||||
print(f"{len(jerseys)} jersey(s) in {elapsed:.1f}s {', '.join(colors_found) if colors_found else '(none)'}")
|
||||
|
||||
except (json.JSONDecodeError, KeyError, IndexError) as e:
|
||||
print(f"PARSE ERROR: {e}")
|
||||
errors += 1
|
||||
except Exception as e:
|
||||
print(f"ERROR: {e}")
|
||||
errors += 1
|
||||
|
||||
total_time = time.time() - start_all
|
||||
|
||||
# --- Summary ---
|
||||
print()
|
||||
print("=" * 70)
|
||||
print("COLOR VARIETY SUMMARY")
|
||||
print("=" * 70)
|
||||
print(f"Images processed: {len(image_files)}")
|
||||
print(f"Total jerseys detected: {total_jerseys}")
|
||||
print(f"Errors: {errors}")
|
||||
print(f"Total time: {total_time:.1f}s ({total_time / len(image_files):.1f}s avg)")
|
||||
|
||||
print(f"\n--- Jersey Colors ({len(jersey_color_counter)} unique) ---")
|
||||
for color, count in jersey_color_counter.most_common():
|
||||
bar = "#" * min(count, 50)
|
||||
print(f" {color:25s} {count:4d} {bar}")
|
||||
|
||||
print(f"\n--- Number Colors ({len(number_color_counter)} unique) ---")
|
||||
for color, count in number_color_counter.most_common():
|
||||
bar = "#" * min(count, 50)
|
||||
print(f" {color:25s} {count:4d} {bar}")
|
||||
|
||||
# Combined unique palette
|
||||
all_colors = sorted(set(jersey_color_counter.keys()) | set(number_color_counter.keys()))
|
||||
print(f"\n--- Combined Color Palette ({len(all_colors)} unique values) ---")
|
||||
for color in all_colors:
|
||||
jc = jersey_color_counter.get(color, 0)
|
||||
nc = number_color_counter.get(color, 0)
|
||||
print(f" {color:25s} jersey:{jc:3d} number:{nc:3d}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user