Files
logo_test/prepare_test_data.py
Rick McEwen ddccf653d2 Initial commit: Logo detection test framework
Add DETR+CLIP based logo detection library and test framework:
- DetectLogosDETR class for logo detection and matching
- Test script with margin-based and multi-ref matching methods
- Data preparation script for test database
- Documentation for API usage and test methodology
2025-12-31 10:42:36 -05:00

322 lines
11 KiB
Python
Executable File

#!/usr/bin/env python3
"""
Prepare test data from LogoDet-3K dataset.
This script:
1. Scans LogoDet-3K for images and XML annotation files
2. Extracts cropped logos using bounding box data and saves to reference_logos/
3. Copies full images to test_images/ with unique filenames
4. Creates a SQLite database for storing mappings and verification
"""
import sqlite3
import shutil
import xml.etree.ElementTree as ET
from pathlib import Path
from PIL import Image
from tqdm import tqdm
def parse_xml_annotation(xml_path: Path) -> dict:
    """Parse a Pascal VOC format XML annotation file.

    Args:
        xml_path: Path to the annotation file.

    Returns:
        A dict with three keys:

        * ``"filename"`` - text of the ``<filename>`` element.
        * ``"size"`` - ``{"width": int, "height": int}`` from ``<size>``.
        * ``"objects"`` - one dict per ``<object>`` element holding ``name``
          and the integer ``xmin``/``ymin``/``xmax``/``ymax`` of its
          ``<bndbox>``.

    Raises:
        ValueError: If a required element is missing or empty, or a numeric
            field cannot be interpreted as a number.
        xml.etree.ElementTree.ParseError: If the file is not well-formed XML.
    """

    def required_text(node, path: str) -> str:
        # findtext() yields None for a missing element and "" for an empty one.
        text = node.findtext(path)
        if text is None or not text.strip():
            raise ValueError(f"{xml_path}: missing or empty <{path}> element")
        return text

    def required_int(node, path: str) -> int:
        text = required_text(node, path)
        try:
            return int(text)
        except ValueError:
            pass
        # Fall back for float-formatted values ("12.0"); the fractional part
        # is truncated.
        try:
            return int(float(text))
        except (ValueError, OverflowError) as err:
            raise ValueError(
                f"{xml_path}: <{path}> is not a number: {text!r}"
            ) from err

    root = ET.parse(xml_path).getroot()

    return {
        "filename": required_text(root, "filename"),
        "size": {
            "width": required_int(root, "size/width"),
            "height": required_int(root, "size/height"),
        },
        "objects": [
            {
                "name": required_text(obj, "name"),
                "xmin": required_int(obj, "bndbox/xmin"),
                "ymin": required_int(obj, "bndbox/ymin"),
                "xmax": required_int(obj, "bndbox/xmax"),
                "ymax": required_int(obj, "bndbox/ymax"),
            }
            for obj in root.findall("object")
        ],
    }
def sanitize_filename(name: str) -> str:
    """Return *name* with filesystem-hostile characters turned into underscores.

    The characters replaced are the forward slash, backslash, space, colon,
    asterisk, question mark, double quote, angle brackets and the pipe.
    All other characters are left untouched.
    """
    unsafe_chars = '/\\ :*?"<>|'
    # One translation pass maps every unsafe character to "_".
    underscore_table = str.maketrans(dict.fromkeys(unsafe_chars, "_"))
    return name.translate(underscore_table)
# Schema for the mapping database: test images, logo names, the reference
# logo crops that link the two, and a key/value statistics table.
_SCHEMA_SQL = """
-- Test images table
CREATE TABLE test_images (
id INTEGER PRIMARY KEY AUTOINCREMENT,
filename TEXT UNIQUE NOT NULL
);
-- Logo names table (unique brand/logo identifiers)
CREATE TABLE logo_names (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL
);
-- Reference logos table with foreign keys
CREATE TABLE reference_logos (
id INTEGER PRIMARY KEY AUTOINCREMENT,
filename TEXT UNIQUE NOT NULL,
test_image_id INTEGER NOT NULL,
logo_name_id INTEGER NOT NULL,
FOREIGN KEY (test_image_id) REFERENCES test_images(id),
FOREIGN KEY (logo_name_id) REFERENCES logo_names(id)
);
-- Statistics table for metadata
CREATE TABLE statistics (
key TEXT PRIMARY KEY,
value INTEGER NOT NULL
);
-- Indexes for faster lookups
CREATE INDEX idx_reference_logos_test_image ON reference_logos(test_image_id);
CREATE INDEX idx_reference_logos_logo_name ON reference_logos(logo_name_id);
"""


def init_database(db_path: Path) -> sqlite3.Connection:
    """Create a fresh SQLite database at *db_path* and return the connection.

    Any file already present at *db_path* is deleted first, then the
    ``test_images``, ``logo_names``, ``reference_logos`` and ``statistics``
    tables (plus two lookup indexes on ``reference_logos``) are created and
    committed. The caller owns the returned connection and must close it.
    """
    # Start from a clean slate rather than appending to an earlier run.
    if db_path.exists():
        db_path.unlink()

    connection = sqlite3.connect(db_path)
    connection.cursor().executescript(_SCHEMA_SQL)
    connection.commit()
    return connection
def get_or_create_logo_name(cursor: sqlite3.Cursor, name: str) -> int:
    """Return the ``logo_names.id`` for *name*, inserting a row if none exists.

    The insert is issued on *cursor* and is not committed here; committing is
    left to the caller.
    """
    existing = cursor.execute(
        "SELECT id FROM logo_names WHERE name = ?", (name,)
    ).fetchone()
    if existing is not None:
        return existing[0]

    # First time this name is seen: register it and report the new row id.
    cursor.execute("INSERT INTO logo_names (name) VALUES (?)", (name,))
    return cursor.lastrowid
def _reserve_unique_name(taken: set, stem: str, suffix: str) -> str:
    """Return an unused ``stem[_N]suffix`` name and record it in *taken*."""
    candidate = f"{stem}{suffix}"
    counter = 1
    while candidate in taken:
        candidate = f"{stem}_{counter}{suffix}"
        counter += 1
    taken.add(candidate)
    return candidate


def main(base_dir: Path = Path("/data/dev.python/logo_test")) -> None:
    """Build the test image set, reference logo crops and mapping database.

    For every ``*.xml`` annotation found under ``<base_dir>/LogoDet-3K`` the
    matching image is copied to ``test_images/<category>/<brand>/``, each
    annotated bounding box is cropped and saved as a PNG under
    ``reference_logos/<category>/<logo_name>/``, and both are recorded in
    ``test_data_mapping.db``. Summary counters are stored in the database's
    ``statistics`` table and printed at the end.

    Args:
        base_dir: Directory that holds the ``LogoDet-3K`` dataset and receives
            the generated output. Defaults to the project's standard location.
    """
    # Paths
    base_dir = Path(base_dir)
    dataset_dir = base_dir / "LogoDet-3K"
    reference_dir = base_dir / "reference_logos"
    test_images_dir = base_dir / "test_images"
    db_path = base_dir / "test_data_mapping.db"

    # Ensure output directories exist (including any missing parents)
    reference_dir.mkdir(parents=True, exist_ok=True)
    test_images_dir.mkdir(parents=True, exist_ok=True)

    # Initialize database
    print(f"Initializing database at {db_path}...")
    conn = init_database(db_path)
    try:
        cursor = conn.cursor()

        # Find all XML files. Sorted so that the "_N" suffixes handed out for
        # duplicate names are the same on every run.
        print("Scanning for XML annotation files...")
        xml_files = sorted(dataset_dir.rglob("*.xml"))
        print(f"Found {len(xml_files)} annotation files")

        # Track unique filenames to avoid conflicts (keyed by subdirectory tuple)
        used_test_filenames = {}
        used_ref_filenames = {}

        # Counters for progress
        stats = {
            "images_processed": 0,
            "logos_extracted": 0,
            "skipped_missing_image": 0,
            "skipped_invalid_bbox": 0,
        }

        # Process each XML file
        print("\nProcessing annotations...")
        for xml_path in tqdm(xml_files, desc="Processing", unit="file"):
            try:
                annotation = parse_xml_annotation(xml_path)
            except Exception as e:
                tqdm.write(f"Error parsing {xml_path}: {e}")
                continue

            # Find corresponding image file. An empty filename resolves to the
            # directory itself, which is_file() rejects, so the extension
            # search below still gets a chance.
            image_path = xml_path.parent / (annotation["filename"] or "")
            if not image_path.is_file():
                # Try common extensions
                for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
                    alt_path = xml_path.parent / (xml_path.stem + ext)
                    if alt_path.is_file():
                        image_path = alt_path
                        break
            if not image_path.is_file():
                stats["skipped_missing_image"] += 1
                continue

            # Derive category/brand from the directories above the XML file.
            # Only the parent's parts are used so the file name itself is
            # never mistaken for a brand or category.
            parent_parts = xml_path.relative_to(dataset_dir).parent.parts
            category = parent_parts[0] if len(parent_parts) > 0 else "unknown"
            brand = parent_parts[1] if len(parent_parts) > 1 else "unknown"
            safe_category = sanitize_filename(category)
            safe_brand = sanitize_filename(brand)
            base_name = image_path.stem
            ext = image_path.suffix

            # Create subdirectory structure: category/brand/
            test_subdir = test_images_dir / safe_category / safe_brand
            test_subdir.mkdir(parents=True, exist_ok=True)

            # Handle duplicates within subdirectory
            test_basename = _reserve_unique_name(
                used_test_filenames.setdefault((safe_category, safe_brand), set()),
                base_name,
                ext,
            )

            # Store relative path from test_images_dir for database
            test_filename = f"{safe_category}/{safe_brand}/{test_basename}"

            # Copy full image to test_images
            shutil.copy2(image_path, test_subdir / test_basename)
            stats["images_processed"] += 1

            # Insert test image into database
            cursor.execute(
                "INSERT INTO test_images (filename) VALUES (?)",
                (test_filename,)
            )
            test_image_id = cursor.lastrowid

            # Load image for cropping
            try:
                img = Image.open(image_path)
            except Exception as e:
                tqdm.write(f"Error loading {image_path}: {e}")
                continue

            # The context manager releases the underlying file handle once
            # all crops for this image are done.
            with img:
                img_width, img_height = img.size

                # Process each object/logo in the image
                for obj_idx, obj in enumerate(annotation["objects"]):
                    logo_name = obj["name"]
                    xmin, ymin = obj["xmin"], obj["ymin"]
                    xmax, ymax = obj["xmax"], obj["ymax"]

                    # Validate bounding box
                    if xmin >= xmax or ymin >= ymax:
                        stats["skipped_invalid_bbox"] += 1
                        continue

                    # Clamp to image bounds
                    xmin = max(0, min(xmin, img_width - 1))
                    ymin = max(0, min(ymin, img_height - 1))
                    xmax = max(1, min(xmax, img_width))
                    ymax = max(1, min(ymax, img_height))
                    if xmin >= xmax or ymin >= ymax:
                        stats["skipped_invalid_bbox"] += 1
                        continue

                    # Crop logo region
                    try:
                        logo_crop = img.crop((xmin, ymin, xmax, ymax))
                    except Exception as e:
                        tqdm.write(f"Error cropping {image_path}: {e}")
                        stats["skipped_invalid_bbox"] += 1
                        continue

                    # Reference logos live under category/logo_name/
                    safe_logo_name = sanitize_filename(logo_name)
                    ref_subdir = reference_dir / safe_category / safe_logo_name
                    ref_subdir.mkdir(parents=True, exist_ok=True)

                    # Handle duplicates within subdirectory
                    ref_basename = _reserve_unique_name(
                        used_ref_filenames.setdefault(
                            (safe_category, safe_logo_name), set()
                        ),
                        f"{base_name}_{obj_idx}",
                        ".png",
                    )

                    # Store relative path from reference_dir for database
                    ref_filename = f"{safe_category}/{safe_logo_name}/{ref_basename}"

                    # Save cropped logo
                    ref_path = ref_subdir / ref_basename
                    try:
                        logo_crop.save(ref_path, "PNG")
                    except Exception as e:
                        tqdm.write(f"Error saving {ref_path}: {e}")
                        continue
                    stats["logos_extracted"] += 1

                    # Get or create logo_name entry
                    logo_name_id = get_or_create_logo_name(cursor, logo_name)

                    # Insert reference logo into database
                    cursor.execute(
                        "INSERT INTO reference_logos (filename, test_image_id, logo_name_id) VALUES (?, ?, ?)",
                        (ref_filename, test_image_id, logo_name_id)
                    )

        # Get unique logo names count
        cursor.execute("SELECT COUNT(*) FROM logo_names")
        unique_logo_names = cursor.fetchone()[0]

        # Save statistics to database
        statistics_data = [
            ("total_test_images", stats["images_processed"]),
            ("total_reference_logos", stats["logos_extracted"]),
            ("unique_logo_names", unique_logo_names),
            ("skipped_missing_image", stats["skipped_missing_image"]),
            ("skipped_invalid_bbox", stats["skipped_invalid_bbox"]),
        ]
        cursor.executemany(
            "INSERT INTO statistics (key, value) VALUES (?, ?)",
            statistics_data
        )
        conn.commit()
    finally:
        # Always release the database handle, even if processing failed.
        conn.close()

    # Print summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"Test images created: {stats['images_processed']:,}")
    print(f"Reference logos created: {stats['logos_extracted']:,}")
    print(f"Unique logo names: {unique_logo_names:,}")
    print(f"Skipped (missing image): {stats['skipped_missing_image']:,}")
    print(f"Skipped (invalid bbox): {stats['skipped_invalid_bbox']:,}")
    print(f"\nDatabase saved to: {db_path}")
    print(f"Reference logos: {reference_dir}")
    print(f"Test images: {test_images_dir}")


# Run the preparation only when executed as a script, not on import.
if __name__ == "__main__":
    main()