"""Test scripts and utilities for evaluating vision-language models on jersey number detection using a llama.cpp server."""
import json
import logging
import re
from typing import Any, Dict, Optional

import cv2
import numpy as np
# Load the default jersey detection prompt from disk; fall back to the baked-in
# copy below so the module still works when jersey_prompt.txt is not shipped.
try:
    # Explicit encoding: the platform default is not guaranteed to be UTF-8,
    # and a prompt file containing non-ASCII text would otherwise be mis-decoded.
    with open('jersey_prompt.txt', 'r', encoding='utf-8') as f:
        DEFAULT_JERSEY_PROMPT = f.read()
except FileNotFoundError:
    # Fallback prompt if file is not found
    DEFAULT_JERSEY_PROMPT = """You are an expert at detecting sports jerseys in images. Carefully examine the provided image and identify all visible sports jerseys.

CRITICAL INSTRUCTIONS:
1. ONLY detect jerseys that are CLEARLY VISIBLE in the image
2. ONLY include jersey numbers that you can ACTUALLY READ in the image
3. If you CANNOT see any jerseys, you MUST return {"jerseys": []}
4. DO NOT make up, imagine, or guess jersey numbers that aren't visible
5. DO NOT include jerseys if you cannot clearly see the number

RESPONSE FORMAT:
Respond ONLY with a valid JSON object. No explanations, no markdown, no extra text.

Use DOUBLE QUOTES (") for all JSON keys and string values.

The JSON must have a single key "jerseys" with an array of dictionaries.

Each dictionary must have exactly these three keys:
- "jersey_number": The number on the jersey (as a string, only if clearly visible)
- "jersey_color": The primary color of the jersey
- "number_color": The color of the number on the jersey

Example response for an image WITH visible jerseys:
{
    "jerseys": [
        {
            "jersey_number": "101",
            "jersey_color": "red",
            "number_color": "white"
        }
    ]
}

Example response for an image WITHOUT jerseys or with unclear numbers:
{"jerseys": []}

REMEMBER: Only include jerseys with numbers you can ACTUALLY SEE in the image. When in doubt, return empty array.

Now analyze the image and return the JSON object."""
|
|
|
|
|
|
class DetectJerseys:
    """Detect sports jerseys in images using a vision language model served by llama.cpp.

    Sends one image plus a detection prompt to the server and parses the model's
    JSON reply into ``{"jerseys": [...]}``. Any failure (transport error, bad
    JSON, missing choices) degrades to ``{"jerseys": []}`` rather than raising.
    """

    # Numbers used as examples in the prompt. If the model echoes one back it is
    # almost certainly hallucinating from the prompt text, so such entries are
    # dropped. All values are > 100 to avoid colliding with real jersey numbers.
    _HALLUCINATION_NUMBERS = frozenset({'101', '102', '103', '142', '199'})

    def __init__(self, llama_cpp_base_url: str = "http://192.168.1.34:8080", logger: Optional[logging.Logger] = None, prompt: Optional[str] = None):
        """
        Initialize the jersey detection class.

        Args:
            llama_cpp_base_url: Base URL for the llama.cpp server
            logger: Logger instance for logging messages
            prompt: Custom prompt to use for jersey detection (optional)

        Raises:
            ImportError: If the llama.cpp client helper cannot be imported.
        """
        self.logger = logger or logging.getLogger(__name__)
        self.prompt = prompt or DEFAULT_JERSEY_PROMPT

        # Import here to avoid circular dependencies
        try:
            from scan_utils.llama_cpp_client import LlamaCppClient
            self.client = LlamaCppClient(base_url=llama_cpp_base_url)
            self.logger.info("Jersey detection initialized with llama.cpp server at %s", llama_cpp_base_url)
        except ImportError as e:
            self.logger.error("Failed to import LlamaCppClient: %s", e)
            raise

    def detect(self, image: np.ndarray, temperature: float = 0.1) -> Dict[str, Any]:
        """
        Detect jerseys in an image using the vision language model.

        Args:
            image: OpenCV image (numpy array) to analyze
            temperature: Temperature value for the model (default: 0.1; kept low
                to reduce hallucinated detections)

        Returns:
            ``{"jerseys": [...]}`` where each entry has string-valued
            'jersey_number', 'jersey_color' and 'number_color' keys;
            ``{"jerseys": []}`` when nothing is detected or on any failure.
        """
        try:
            # Create multimodal message with image and prompt
            message = self.client.create_multimodal_message(
                role="user",
                content=self.prompt,
                images=[image]
            )

            # Send chat completion request
            response = self.client.chat_completion(
                messages=[message],
                temperature=temperature,
                max_tokens=1000
            )

            # Extract the response text; guard against an empty/missing choices list.
            choices = response.get('choices') if isinstance(response, dict) else None
            if not choices:
                self.logger.warning("Empty response from VLM")
                return {"jerseys": []}

            response_text = choices[0]['message']['content']
            # Log the raw response for debugging
            self.logger.debug("Raw VLM response: %s", response_text)

            # Parse JSON response (after removing any markdown fencing the
            # model may have added despite the prompt's instructions).
            try:
                result = json.loads(self._strip_code_fences(response_text))
            except json.JSONDecodeError as e:
                self.logger.error("Failed to parse JSON response: %s", e)
                self.logger.debug("Response text was: %s", response_text)
                return {"jerseys": []}

            jerseys = result.get('jerseys', []) if isinstance(result, dict) else []
            return {"jerseys": self._sanitize_jerseys(jerseys)}

        except Exception as e:
            # Broad catch by design: detection is best-effort and callers expect
            # an empty result rather than an exception.
            self.logger.error("Error during jersey detection: %s", e)
            return {"jerseys": []}

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """Remove a surrounding ``` / ```json markdown fence, if present."""
        cleaned = text.strip()
        cleaned = re.sub(r"^```(?:json)?\s*", "", cleaned)
        return re.sub(r"\s*```$", "", cleaned)

    def _sanitize_jerseys(self, jerseys: Any) -> list:
        """Normalize raw jersey dicts from the model.

        Coerces jersey numbers to str, fills missing fields with defaults, and
        drops entries whose number matches a prompt example (hallucination).
        """
        processed_jerseys = []
        for jersey in jerseys:
            if not isinstance(jersey, dict):
                # Malformed entry from the model; skip rather than crash.
                continue

            # Coerce to str: models sometimes return the number as a JSON int,
            # which would bypass the string-based hallucination filter below and
            # break downstream code expecting string numbers.
            jersey_number = str(jersey.get('jersey_number', ''))

            # Check for hallucination (model returning example numbers)
            if jersey_number in self._HALLUCINATION_NUMBERS:
                self.logger.warning("Possible hallucination detected - jersey number %s matches example pattern. Filtering out.", jersey_number)
                continue

            # Ensure all required fields are present
            processed_jerseys.append({
                'jersey_number': jersey_number,
                'jersey_color': jersey.get('jersey_color', ''),
                'number_color': jersey.get('number_color', 'unknown')  # Default to 'unknown' if missing
            })
        return processed_jerseys