- test_color_variety.py: named-color test for local llama.cpp VLM
- test_color_variety_gemini.py: named-color test for Gemini 3 Flash API
- test_hex_color_specificity.py: hex color specificity test for Gemini
- test_hex_color_specificity_llama.py: hex color specificity test for local VLM
- jersey_prompt_hex_color.txt: prompt requesting hex color codes
- COLOR_TEST_REPORT.md: analysis report comparing 3 models across 5 tests
- color_test_results.md: raw test output from all runs
151 lines
5.1 KiB
Python
151 lines
5.1 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Test script to discover the variety of colors a VLM returns for jersey detection.
|
|
|
|
Submits all test images to the VLM and tallies every unique jersey_color and
|
|
number_color value, producing a summary of the model's color vocabulary.
|
|
|
|
Usage:
|
|
python test_color_variety.py
|
|
"""
|
|
|
|
import json
|
|
import os
|
|
import re
|
|
import sys
|
|
import time
|
|
from collections import Counter
|
|
from pathlib import Path
|
|
|
|
import cv2
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
|
from scan_utils.llama_cpp_client import LlamaCppClient
|
|
|
|
# Base URL of the llama.cpp server that hosts the vision-language model.
SERVER_URL = "http://agx:8080"

# Absolute directory containing this script. Resolving it once with abspath
# keeps the data paths below valid even when __file__ is relative or the
# working directory differs from the script location.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# Directory holding the test images submitted to the model.
IMAGES_DIR = os.path.join(_SCRIPT_DIR, "basketball_jersery_color_test_files")

# Text file containing the prompt sent alongside every image.
PROMPT_FILE = os.path.join(_SCRIPT_DIR, "jersey_prompt.txt")
|
|
|
|
|
|
def clean_response(text: str) -> str:
    """Strip reasoning tags and Markdown code fences from raw model output.

    Processing order:
      1. Remove complete ``<think>...</think>`` sections (and the variant
         delimited by the U+25C1 / U+25B7 triangle characters), then any
         leftover unpaired opening or closing tags.
      2. If a fenced code block (optionally tagged ``json``) is present,
         keep only the contents of the first one.
      3. Otherwise drop any stray fence markers.

    Returns the remaining text with surrounding whitespace removed.
    """
    # (pattern, flags) pairs applied in sequence; paired tags first so their
    # enclosed content goes away, then orphaned tags on their own.
    think_patterns = (
        (r'<think>.*?</think>', re.DOTALL | re.IGNORECASE),
        (r'\u25c1think\u25b7.*?\u25c1/think\u25b7', re.DOTALL),
        (r'</?think>', re.IGNORECASE),
        (r'\u25c1/?think\u25b7', re.IGNORECASE),
    )

    stripped = text
    for pattern, flags in think_patterns:
        stripped = re.sub(pattern, '', stripped, flags=flags)

    fenced = re.search(
        r'```(?:json)?\s*\n?(.*?)\n?```',
        stripped,
        flags=re.DOTALL | re.IGNORECASE,
    )
    if fenced is None:
        # No complete fenced block: just discard any dangling fence markers.
        return re.sub(r'```(?:json)?', '', stripped, flags=re.IGNORECASE).strip()

    return fenced.group(1).strip()
|
|
|
|
|
|
def _normalize_color(value):
    """Return a colour value as a stripped, lower-cased string ('' for None)."""
    if value is None:
        return ''
    return str(value).strip().lower()


def main():
    """Send every test image to the VLM and summarise the colours it reports.

    Reads the prompt from PROMPT_FILE, submits each supported image found in
    IMAGES_DIR to the server at SERVER_URL, parses the JSON reply, and tallies
    the ``jersey_color`` and ``number_color`` values of every entry in the
    reply's ``jerseys`` list. Progress is printed per image and a summary of
    the colour vocabulary is printed at the end.

    A failure on one image (load, request, or parse) is reported and counted
    in the error total; the run then continues with the next image.
    """
    # Load prompt
    with open(PROMPT_FILE, 'r', encoding='utf-8') as f:
        prompt = f.read()

    # Gather image files (extensions OpenCV can handle). The directory is
    # listed once and only regular files are considered.
    valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
    entries = [p for p in Path(IMAGES_DIR).iterdir() if p.is_file()]
    image_files = sorted(
        p for p in entries if p.suffix.lower() in valid_extensions
    )
    skipped = sorted(
        p.name for p in entries if p.suffix.lower() not in valid_extensions
    )

    print(f"Images to process: {len(image_files)}")
    if skipped:
        print(f"Skipping {len(skipped)} unsupported files: {', '.join(skipped)}")
    print(f"Server: {SERVER_URL}")
    print(f"Prompt: {PROMPT_FILE} ({len(prompt)} chars)")
    print("=" * 70)

    client = LlamaCppClient(base_url=SERVER_URL)

    jersey_color_counter = Counter()
    number_color_counter = Counter()
    total_jerseys = 0
    errors = 0
    start_all = time.time()

    for i, image_path in enumerate(image_files, 1):
        print(f"[{i}/{len(image_files)}] {image_path.name} ... ", end="", flush=True)

        image = cv2.imread(str(image_path))
        if image is None:
            print("SKIP (failed to load)")
            errors += 1
            continue

        try:
            message = client.create_multimodal_message(role="user", content=prompt, images=[image])

            t0 = time.time()
            response = client.chat_completion(messages=[message], temperature=0.1, max_tokens=1000)
            elapsed = time.time() - t0

            response_text = response['choices'][0]['message']['content']
            cleaned = clean_response(response_text)
            result = json.loads(cleaned)
            # Treat an explicit null the same as a missing list.
            jerseys = result.get('jerseys') or []

            # Normalise every entry before touching the counters so that a
            # malformed entry cannot leave this image half-tallied.
            color_pairs = [
                (
                    _normalize_color(j.get('jersey_color')),
                    _normalize_color(j.get('number_color')),
                )
                for j in jerseys
            ]

            colors_found = []
            for jc, nc in color_pairs:
                if jc:
                    jersey_color_counter[jc] += 1
                if nc:
                    number_color_counter[nc] += 1
                colors_found.append(f"{jc}/{nc}")
            total_jerseys += len(color_pairs)

            print(f"{len(jerseys)} jersey(s) in {elapsed:.1f}s {', '.join(colors_found) if colors_found else '(none)'}")

        except (json.JSONDecodeError, KeyError, IndexError) as e:
            print(f"PARSE ERROR: {e}")
            errors += 1
        except Exception as e:
            # Per-image boundary: report the failure and keep going.
            print(f"ERROR: {e}")
            errors += 1

    total_time = time.time() - start_all
    # Avoid dividing by zero when the directory held no supported images.
    avg_time = total_time / len(image_files) if image_files else 0.0

    # --- Summary ---
    print()
    print("=" * 70)
    print("COLOR VARIETY SUMMARY")
    print("=" * 70)
    print(f"Images processed: {len(image_files)}")
    print(f"Total jerseys detected: {total_jerseys}")
    print(f"Errors: {errors}")
    print(f"Total time: {total_time:.1f}s ({avg_time:.1f}s avg)")

    print(f"\n--- Jersey Colors ({len(jersey_color_counter)} unique) ---")
    for color, count in jersey_color_counter.most_common():
        bar = "#" * min(count, 50)  # cap the histogram bar at 50 characters
        print(f" {color:25s} {count:4d} {bar}")

    print(f"\n--- Number Colors ({len(number_color_counter)} unique) ---")
    for color, count in number_color_counter.most_common():
        bar = "#" * min(count, 50)
        print(f" {color:25s} {count:4d} {bar}")

    # Combined unique palette
    all_colors = sorted(set(jersey_color_counter) | set(number_color_counter))
    print(f"\n--- Combined Color Palette ({len(all_colors)} unique values) ---")
    for color in all_colors:
        jc = jersey_color_counter.get(color, 0)
        nc = number_color_counter.get(color, 0)
        print(f" {color:25s} jersey:{jc:3d} number:{nc:3d}")


if __name__ == '__main__':
    main()