Add accuracy test framework, prompts, results, and analysis reports
Includes accuracy test scripts for Qwen (local) and Gemini (cloud API), three prompt variants (original, capstone, constrained), test results from all runs, and two analysis reports with an HTML presentation version.
This commit is contained in:
576
test_accuracy_gemini.py
Normal file
576
test_accuracy_gemini.py
Normal file
@ -0,0 +1,576 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to measure Gemini VLM accuracy for jersey color detection.
|
||||
|
||||
Uses annotated test images where ground truth colors are encoded in filenames.
|
||||
Compares Gemini results against ground truth, measuring exact and similar color
|
||||
matches. White is ignored in both ground truth and VLM results.
|
||||
|
||||
Filename format: "014 - orange_dark blue or purple.jpg"
|
||||
- Underscore separates distinct jersey colors
|
||||
- "or" separates acceptable alternatives for a single jersey
|
||||
|
||||
Usage:
|
||||
python test_accuracy_gemini.py [prompt_file]
|
||||
"""
|
||||
|
||||
import base64
|
||||
import concurrent.futures
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import requests
|
||||
|
||||
# Gemini model and endpoint used for all requests.
GEMINI_MODEL = "gemini-3-flash-preview"
API_URL = f"https://generativelanguage.googleapis.com/v1beta/models/{GEMINI_MODEL}:generateContent"

# Test fixtures: annotated images (ground truth encoded in filenames),
# default prompt text, and the API key file (one key, plain text).
IMAGES_DIR = os.path.join(os.path.dirname(__file__), "basketball_jersery_color_test_files_annotated")
DEFAULT_PROMPT_FILE = os.path.join(os.path.dirname(__file__), "jersey_prompt.txt")
API_KEY_FILE = os.path.join(os.path.dirname(__file__), "gemini_api_key.txt")
# Images wider than this are downscaled before upload to cut payload size.
MAX_IMAGE_WIDTH = 768
# JPEG re-encode quality for uploaded images (0-100).
JPEG_QUALITY = 85
# Number of simultaneous API requests in phase 2.
CONCURRENT_WORKERS = 8
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Color similarity – colors in the same family count as "similar" matches
|
||||
# ---------------------------------------------------------------------------
|
||||
COLOR_FAMILIES = {
|
||||
'blue': ['blue', 'dark blue', 'navy blue', 'navy', 'royal blue'],
|
||||
'light_blue': ['light blue', 'sky blue', 'baby blue', 'carolina blue', 'powder blue'],
|
||||
'red': ['red', 'scarlet', 'crimson'],
|
||||
'dark_red': ['maroon', 'burgundy', 'dark red', 'wine'],
|
||||
'green': ['green', 'dark green', 'forest green', 'kelly green'],
|
||||
'yellow': ['yellow', 'gold', 'golden'],
|
||||
'orange': ['orange', 'burnt orange'],
|
||||
'brown': ['brown', 'dark brown'],
|
||||
'purple': ['purple', 'violet'],
|
||||
'gray': ['gray', 'grey', 'silver', 'charcoal'],
|
||||
'black': ['black'],
|
||||
'teal': ['teal', 'turquoise', 'cyan', 'aqua'],
|
||||
'pink': ['pink', 'magenta', 'hot pink', 'rose'],
|
||||
}
|
||||
|
||||
_COLOR_TO_FAMILY = {}
|
||||
for _family, _members in COLOR_FAMILIES.items():
|
||||
for _color in _members:
|
||||
_COLOR_TO_FAMILY[_color] = _family
|
||||
|
||||
|
||||
def colors_are_similar(color1: str, color2: str) -> bool:
    """Return True if two colors belong to the same color family.

    Identical strings are always similar; otherwise both colors must be
    known (present in _COLOR_TO_FAMILY) and map to the same family.
    """
    if color1 == color2:
        return True
    family_a = _COLOR_TO_FAMILY.get(color1)
    family_b = _COLOR_TO_FAMILY.get(color2)
    # Unknown colors (family is None) never match anything but themselves.
    return family_a is not None and family_a == family_b
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ground-truth parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
def parse_ground_truth(filename: str) -> list[list[str]]:
    """Parse ground truth colors from an annotated filename.

    Returns a list of color groups. Each group is a list of acceptable
    alternatives (from "or" in the filename). White entries are removed.

    Example: "014 - orange_dark blue or purple.jpg"
        -> [["orange"], ["dark blue", "purple"]]
    """
    stem = Path(filename).stem
    # Drop the leading numeric prefix ("014 - ", "029 -", etc.).
    stem = re.sub(r'^\d+\s*-\s*', '', stem)
    # Hyphens between colors behave like underscores (e.g. "yellow-black").
    stem = stem.replace('-', '_')

    groups: list[list[str]] = []
    for segment in stem.split('_'):
        segment = segment.strip()
        if not segment:
            continue
        alternatives = [alt.strip().lower() for alt in segment.split(' or ')]
        alternatives = [alt for alt in alternatives if alt and alt != 'white']
        if alternatives:
            groups.append(alternatives)
    return groups
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Response cleaning & salvage
|
||||
# ---------------------------------------------------------------------------
|
||||
def clean_response(text: str) -> str:
    """Remove think tags and markdown code blocks from model output.

    Strips <think>...</think> sections (and any stray tags), then either
    extracts the contents of the first fenced code block or, when no
    complete fence exists, removes the fence markers themselves.
    """
    without_think = re.sub(r'<think>.*?</think>', '', text,
                           flags=re.DOTALL | re.IGNORECASE)
    without_think = re.sub(r'</?think>', '', without_think, flags=re.IGNORECASE)

    fenced = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', without_think,
                       flags=re.DOTALL | re.IGNORECASE)
    if fenced is not None:
        result = fenced.group(1)
    else:
        # No complete fence: drop any dangling fence markers.
        result = re.sub(r'```(?:json)?', '', without_think, flags=re.IGNORECASE)
    return result.strip()
|
||||
|
||||
|
||||
def salvage_jerseys(text: str) -> list[dict]:
    """Extract complete jersey objects from truncated JSON using regex.

    Used when json.loads fails (e.g. the model's output was cut off):
    any fully-formed {"jersey_number": ..., "jersey_color": ...,
    "number_color": ...} object is recovered.
    """
    jersey_re = re.compile(
        r'\{\s*'
        r'"jersey_number"\s*:\s*"[^"]*"\s*,\s*'
        r'"jersey_color"\s*:\s*"([^"]*)"\s*,\s*'
        r'"number_color"\s*:\s*"([^"]*)"\s*'
        r'\}',
        re.DOTALL,
    )
    # findall yields (jersey_color, number_color) tuples, one per object.
    return [
        {'jersey_color': jersey_color, 'number_color': number_color}
        for jersey_color, number_color in jersey_re.findall(text)
    ]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Gemini API helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
def load_api_key() -> str:
    """Read the Gemini API key from API_KEY_FILE, stripping whitespace."""
    return Path(API_KEY_FILE).read_text().strip()
|
||||
|
||||
|
||||
def encode_image(image_path: str) -> tuple[str, str]:
    """Read an image file, resize if wider than MAX_IMAGE_WIDTH, and return
    (base64_data, mime_type).

    When cv2 can decode the file, the image is re-encoded as PNG (for .png
    inputs) or JPEG (everything else) and the returned mime type describes
    the bytes actually produced. If cv2 cannot decode the file, the raw
    bytes are sent unchanged with a mime type guessed from the extension.

    Fix over previous version: non-PNG inputs such as .webp/.bmp/.tiff are
    re-encoded as JPEG, so reporting the extension-derived mime type
    (e.g. "image/webp") mislabeled the payload; the mime type now matches
    the encoded bytes.
    """
    ext = Path(image_path).suffix.lower()
    mime_map = {
        '.jpg': 'image/jpeg',
        '.jpeg': 'image/jpeg',
        '.png': 'image/png',
        '.webp': 'image/webp',
        '.bmp': 'image/bmp',
        '.tiff': 'image/tiff',
    }

    image = cv2.imread(image_path)
    if image is None:
        # cv2 could not decode the file; fall back to raw bytes with the
        # extension-derived mime type (best effort).
        with open(image_path, 'rb') as f:
            data = base64.b64encode(f.read()).decode('utf-8')
        return data, mime_map.get(ext, 'image/jpeg')

    h, w = image.shape[:2]
    if w > MAX_IMAGE_WIDTH:
        # Downscale proportionally; INTER_AREA is preferred for shrinking.
        scale = MAX_IMAGE_WIDTH / w
        image = cv2.resize(image, (MAX_IMAGE_WIDTH, int(h * scale)),
                           interpolation=cv2.INTER_AREA)

    if ext == '.png':
        _, buf = cv2.imencode('.png', image)
        mime_type = 'image/png'
    else:
        # Everything non-PNG is re-encoded as JPEG at the configured quality.
        _, buf = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, JPEG_QUALITY])
        mime_type = 'image/jpeg'
    return base64.b64encode(buf).decode('utf-8'), mime_type
|
||||
|
||||
|
||||
# Total attempts per API call (initial request + retries).
MAX_RETRIES = 3
# Seconds to sleep before retry attempt i (indexed by attempt number).
RETRY_BACKOFF = [2, 5, 10]
|
||||
|
||||
|
||||
def call_gemini(session: requests.Session, api_key: str, image_data: str,
                mime_type: str, prompt: str) -> dict:
    """Send pre-encoded image + prompt to the Gemini API and return the raw
    JSON response body.

    Server-side (5xx) failures are retried up to MAX_RETRIES total attempts
    with the sleep schedule in RETRY_BACKOFF. Any other error status — or a
    5xx on the final attempt — raises requests.HTTPError.

    Fix over previous version: the duplicated raise_for_status()/return
    after the retry loop was unreachable (every final iteration either
    returns or raises inside the loop) and has been removed.
    """
    payload = {
        "contents": [{
            "parts": [
                {
                    "inline_data": {
                        "mime_type": mime_type,
                        "data": image_data,
                    }
                },
                {
                    "text": prompt,
                }
            ]
        }],
        "generationConfig": {
            "temperature": 0.1,
            "maxOutputTokens": 8192,
            "responseMimeType": "application/json",
        }
    }

    for attempt in range(MAX_RETRIES):
        response = session.post(
            API_URL,
            headers={
                "x-goog-api-key": api_key,
                "Content-Type": "application/json",
            },
            json=payload,
        )

        # Retry only transient server errors, and only while attempts remain.
        if response.status_code >= 500 and attempt < MAX_RETRIES - 1:
            time.sleep(RETRY_BACKOFF[attempt])
            continue

        response.raise_for_status()
        return response.json()

    # Unreachable: the final loop iteration always returns or raises.
    raise RuntimeError("retry loop exited without returning")
|
||||
|
||||
|
||||
def _api_worker(session: requests.Session, api_key: str, image_data: str,
                mime_type: str, prompt: str) -> dict:
    """Wrapper that captures timing and exceptions for concurrent execution.

    Never raises: returns {'resp', 'elapsed', 'error'} where exactly one of
    'resp'/'error' is non-None, so as_completed() consumers need no
    try/except around future.result().
    """
    started = time.time()
    try:
        response = call_gemini(session, api_key, image_data, mime_type, prompt)
    except Exception as exc:  # deliberate catch-all: error is reported to caller
        return {'resp': None, 'elapsed': time.time() - started, 'error': exc}
    return {'resp': response, 'elapsed': time.time() - started, 'error': None}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scoring
|
||||
# ---------------------------------------------------------------------------
|
||||
def score_image(gt_groups: list[list[str]], vlm_colors: set[str]) -> dict:
    """Compare VLM detected colors against ground truth color groups.

    Recall  = how many GT color groups were found in VLM output.
    Precision = how many VLM colors match something in the GT.

    Exact matches are checked first; remaining groups fall back to
    same-family ("similar") matches via colors_are_similar().
    """
    recall_exact = 0
    recall_similar = 0
    recall_missed: list[list[str]] = []
    confusions: list[tuple[list[str], str]] = []

    for group in gt_groups:
        # Exact match: any alternative appears verbatim in the VLM set.
        if any(alt in vlm_colors for alt in group):
            recall_exact += 1
            continue
        # Otherwise take the first same-family match, if any.
        similar_hit = next(
            (vc for alt in group for vc in vlm_colors
             if colors_are_similar(alt, vc)),
            None,
        )
        if similar_hit:
            recall_similar += 1
            confusions.append((group, similar_hit))
        else:
            recall_missed.append(group)

    # Precision: classify each VLM color against the flattened GT alternatives.
    gt_alternatives = [alt for group in gt_groups for alt in group]
    precision_exact = 0
    precision_similar = 0
    precision_extra: list[str] = []
    for vc in vlm_colors:
        if vc in gt_alternatives:
            precision_exact += 1
        elif any(colors_are_similar(vc, alt) for alt in gt_alternatives):
            precision_similar += 1
        else:
            precision_extra.append(vc)

    return {
        'gt_count': len(gt_groups),
        'vlm_count': len(vlm_colors),
        'recall_exact': recall_exact,
        'recall_similar': recall_similar,
        'recall_missed': recall_missed,
        'precision_exact': precision_exact,
        'precision_similar': precision_similar,
        'precision_extra': precision_extra,
        'confusions': confusions,
    }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
def pct(n: int, d: int) -> str:
    """Format n/d as a one-decimal percentage string; "N/A" when d is 0."""
    if not d:
        return "N/A"
    return f"{100 * n / d:.1f}%"
|
||||
|
||||
|
||||
def extract_vlm_colors(jerseys: list[dict]) -> set[str]:
    """Return unique jersey colors from VLM output, ignoring white.

    Colors are lower-cased and whitespace-stripped; empty strings and
    missing 'jersey_color' keys are skipped.
    """
    normalized = (j.get('jersey_color', '').strip().lower() for j in jerseys)
    return {color for color in normalized if color and color != 'white'}
|
||||
|
||||
|
||||
def parse_response(result: dict) -> tuple[list[dict], set[str]]:
    """Parse a Gemini response into (jerseys, vlm_colors).

    On JSON parse failure (e.g. truncated output), falls back to salvaging
    complete jersey objects from the raw text via regex.
    """
    raw_text = result['resp']['candidates'][0]['content']['parts'][0]['text']
    cleaned = clean_response(raw_text)
    try:
        parsed = json.loads(cleaned)
        jerseys = parsed.get('jerseys', [])
    except json.JSONDecodeError:
        # Truncated/malformed JSON — scrape whatever complete objects exist.
        jerseys = salvage_jerseys(raw_text)
    return jerseys, extract_vlm_colors(jerseys)
|
||||
|
||||
|
||||
def score_and_format(gt_groups, vlm_colors, scores):
    """Build a verdict tag and status-line parts from scoring results.

    Returns (tag, status_parts) where tag is PASS (all GT groups found,
    no extra colors), PARTIAL (some found), or FAIL (none found).
    Note: gt_groups and vlm_colors are accepted for call-site symmetry;
    only `scores` is consulted.
    """
    parts = []
    if scores['recall_exact']:
        parts.append(f"exact:{scores['recall_exact']}")
    if scores['recall_similar']:
        parts.append(f"similar:{scores['recall_similar']}")
    if scores['recall_missed']:
        missed = ",".join("|".join(g) for g in scores['recall_missed'])
        parts.append(f"MISS:{missed}")
    if scores['precision_extra']:
        parts.append(f"extra:{','.join(scores['precision_extra'])}")

    found = scores['recall_exact'] + scores['recall_similar']
    if found == scores['gt_count'] and not scores['precision_extra']:
        tag = "PASS"
    elif found > 0:
        tag = "PARTIAL"
    else:
        tag = "FAIL"
    return tag, parts
|
||||
|
||||
|
||||
def print_summary(model_name, total_gt, total_vlm, total_recall_exact,
                  total_recall_similar, total_recall_missed,
                  total_precision_exact, total_precision_similar,
                  total_precision_extra, confusion_counter, missed_counter,
                  extra_counter, per_image_results, image_count, errors,
                  total_time):
    """Print the full accuracy summary report.

    Sections: overall counts/timing, recall and precision breakdowns
    (exact vs similar via pct()), similar-match confusions, most-missed
    GT colors, most common extra VLM colors, per-image verdict counts,
    and a listing of FAIL images with their missed/extra colors.
    All arguments are pre-aggregated by the caller; this function only
    formats and prints.
    """
    # Header and overall run statistics.
    print()
    print("=" * 80)
    print(f"ACCURACY SUMMARY ({model_name})")
    print("=" * 80)
    print(f"Images processed: {image_count}")
    print(f"Errors: {errors}")
    # max(image_count, 1) guards the average against a zero-image run.
    print(f"Total time: {total_time:.1f}s ({total_time / max(image_count, 1):.1f}s avg)")
    print()
    print(f"Ground truth colors: {total_gt} (excluding white)")
    print(f"VLM unique colors: {total_vlm} (excluding white)")
    print()

    # Recall: out of all ground-truth color groups, how many were found.
    print("--- Recall (did VLM find each ground truth color?) ---")
    print(f" Exact match: {total_recall_exact:4d} / {total_gt} ({pct(total_recall_exact, total_gt)})")
    print(f" Similar match: {total_recall_similar:4d} / {total_gt} ({pct(total_recall_similar, total_gt)})")
    recall_total = total_recall_exact + total_recall_similar
    print(f" Total found: {recall_total:4d} / {total_gt} ({pct(recall_total, total_gt)})")
    print(f" Missed: {total_recall_missed:4d} / {total_gt} ({pct(total_recall_missed, total_gt)})")
    print()

    # Precision: out of all VLM-reported colors, how many were correct.
    print("--- Precision (are VLM colors correct?) ---")
    print(f" Exact match: {total_precision_exact:4d} / {total_vlm} ({pct(total_precision_exact, total_vlm)})")
    print(f" Similar match: {total_precision_similar:4d} / {total_vlm} ({pct(total_precision_similar, total_vlm)})")
    prec_total = total_precision_exact + total_precision_similar
    print(f" Total correct: {prec_total:4d} / {total_vlm} ({pct(prec_total, total_vlm)})")
    print(f" Extra/wrong: {total_precision_extra:4d} / {total_vlm} ({pct(total_precision_extra, total_vlm)})")

    # Which expected colors were reported as a same-family neighbor.
    if confusion_counter:
        print()
        print("--- Similar-Match Confusions (expected -> got) ---")
        for (expected, got), count in confusion_counter.most_common():
            print(f" {expected:30s} -> {got:20s} x{count}")

    # Histogram (capped bar width of 40) of GT colors the VLM never found.
    if missed_counter:
        print()
        print("--- Most Missed Ground Truth Colors ---")
        for color, count in missed_counter.most_common(20):
            bar = "#" * min(count, 40)
            print(f" {color:30s} {count:3d} {bar}")

    # Histogram of VLM colors with no GT counterpart (hallucinations).
    if extra_counter:
        print()
        print("--- Most Common Extra/Wrong VLM Colors ---")
        for color, count in extra_counter.most_common(20):
            bar = "#" * min(count, 40)
            print(f" {color:30s} {count:3d} {bar}")

    # PASS/PARTIAL/FAIL counts across scored images.
    if per_image_results:
        tags = Counter(r['tag'] for r in per_image_results)
        print()
        print("--- Per-Image Verdict ---")
        for tag in ['PASS', 'PARTIAL', 'FAIL']:
            print(f" {tag:10s} {tags.get(tag, 0):4d}")

    # Detail every fully-failed image for manual inspection.
    failed = [r for r in per_image_results if r['tag'] == 'FAIL']
    if failed:
        print()
        print(f"--- Failed Images ({len(failed)}) ---")
        for r in failed:
            scores = r['scores']
            missed_strs = ["|".join(g) for g in scores['recall_missed']]
            print(f" {r['file']}")
            print(f" missed: {', '.join(missed_strs)}")
            if scores['precision_extra']:
                print(f" extra: {', '.join(scores['precision_extra'])}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
def main():
    """Run the full accuracy benchmark.

    Phases:
      1. Pre-encode every test image (resize + base64).
      2. Fan all API calls out over a thread pool.
      3. Score each response against the filename-encoded ground truth,
         then print the summary report.
    """
    # Optional CLI argument selects an alternate prompt file.
    prompt_file = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_PROMPT_FILE
    api_key = load_api_key()

    with open(prompt_file, 'r') as f:
        prompt = f.read()

    # Collect test images in deterministic (sorted) order.
    valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}
    image_files = sorted([
        p for p in Path(IMAGES_DIR).iterdir()
        if p.suffix.lower() in valid_extensions
    ])

    print(f"Model: {GEMINI_MODEL}")
    print(f"Images to process: {len(image_files)}")
    print(f"Concurrency: {CONCURRENT_WORKERS} workers")
    print(f"Prompt: {prompt_file} ({len(prompt)} chars)")
    print("=" * 80)

    # ------------------------------------------------------------------
    # Phase 1: Pre-encode all images (done serially, before any API call)
    # ------------------------------------------------------------------
    print("Pre-encoding images ... ", end="", flush=True)
    t_enc = time.time()
    encoded_images = []
    for image_path in image_files:
        encoded_images.append(encode_image(str(image_path)))
    print(f"{len(encoded_images)} images in {time.time() - t_enc:.1f}s")

    # ------------------------------------------------------------------
    # Phase 2: Submit all API calls concurrently
    # ------------------------------------------------------------------
    # One shared Session gives connection reuse across worker threads.
    session = requests.Session()
    start_all = time.time()

    print(f"Sending API requests ... ", flush=True)
    # Results are stored by index so scoring order matches image order.
    api_results = [None] * len(image_files)
    with concurrent.futures.ThreadPoolExecutor(max_workers=CONCURRENT_WORKERS) as executor:
        future_to_idx = {}
        for i, (image_data, mime_type) in enumerate(encoded_images):
            future = executor.submit(
                _api_worker, session, api_key, image_data, mime_type, prompt,
            )
            future_to_idx[future] = i

        completed = 0
        for future in concurrent.futures.as_completed(future_to_idx):
            idx = future_to_idx[future]
            # _api_worker never raises; errors come back inside the dict.
            api_results[idx] = future.result()
            completed += 1
            print(f"\r {completed}/{len(image_files)} API calls completed", end="", flush=True)

    api_time = time.time() - start_all
    print(f" ({api_time:.1f}s total)")
    print("=" * 80)

    # ------------------------------------------------------------------
    # Phase 3: Score results in order
    # ------------------------------------------------------------------
    # Running totals accumulated across all scored images.
    total_gt = 0
    total_vlm = 0
    total_recall_exact = 0
    total_recall_similar = 0
    total_recall_missed = 0
    total_precision_exact = 0
    total_precision_similar = 0
    total_precision_extra = 0
    errors = 0

    confusion_counter = Counter()
    missed_counter = Counter()
    extra_counter = Counter()
    per_image_results = []

    for i, (image_path, result) in enumerate(zip(image_files, api_results), 1):
        gt_groups = parse_ground_truth(image_path.name)
        gt_display = ", ".join("|".join(g) for g in gt_groups) if gt_groups else "(none)"
        print(f"\n[{i}/{len(image_files)}] {image_path.name}")
        print(f" GT: [{gt_display}]")

        # API-level failure for this image: report and skip scoring.
        if result['error'] is not None:
            e = result['error']
            if isinstance(e, requests.exceptions.HTTPError):
                print(f" HTTP ERROR: {e}")
            else:
                print(f" ERROR: {e}")
            errors += 1
            continue

        elapsed = result['elapsed']

        try:
            jerseys, vlm_colors = parse_response(result)

            vlm_display = ", ".join(sorted(vlm_colors)) if vlm_colors else "(none)"
            print(f" VLM: [{vlm_display}] ({len(jerseys)} jersey(s), {elapsed:.1f}s)")

            # White-only images have no scoreable ground truth.
            if not gt_groups:
                print(" -- no ground truth colors (white-only), skipping scoring")
                continue

            scores = score_image(gt_groups, vlm_colors)
            total_gt += scores['gt_count']
            total_vlm += scores['vlm_count']
            total_recall_exact += scores['recall_exact']
            total_recall_similar += scores['recall_similar']
            total_recall_missed += len(scores['recall_missed'])
            total_precision_exact += scores['precision_exact']
            total_precision_similar += scores['precision_similar']
            total_precision_extra += len(scores['precision_extra'])

            # Aggregate per-color diagnostics for the summary report.
            for group, got in scores['confusions']:
                confusion_counter[("|".join(group), got)] += 1
            for group in scores['recall_missed']:
                missed_counter["|".join(group)] += 1
            for ec in scores['precision_extra']:
                extra_counter[ec] += 1

            tag, status_parts = score_and_format(gt_groups, vlm_colors, scores)
            print(f" {tag} {', '.join(status_parts)}")

            per_image_results.append({
                'file': image_path.name,
                'tag': tag,
                'scores': scores,
            })

        except Exception as e:
            # Unexpected response shape or parse failure: count and move on.
            print(f" PARSE ERROR: {e}")
            errors += 1

    total_time = time.time() - start_all

    print_summary(
        GEMINI_MODEL, total_gt, total_vlm, total_recall_exact,
        total_recall_similar, total_recall_missed, total_precision_exact,
        total_precision_similar, total_precision_extra, confusion_counter,
        missed_counter, extra_counter, per_image_results, len(image_files),
        errors, total_time,
    )
|
||||
|
||||
|
||||
# Script entry point.
if __name__ == '__main__':
    main()
|
||||
Reference in New Issue
Block a user