Initial commit: Logo detection test framework
Add DETR+CLIP based logo detection library and test framework: - DetectLogosDETR class for logo detection and matching - Test script with margin-based and multi-ref matching methods - Data preparation script for test database - Documentation for API usage and test methodology
This commit is contained in:
322
prepare_test_data.py
Executable file
322
prepare_test_data.py
Executable file
@ -0,0 +1,322 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Prepare test data from LogoDet-3K dataset.
|
||||
|
||||
This script:
|
||||
1. Scans LogoDet-3K for images and XML annotation files
|
||||
2. Extracts cropped logos using bounding box data and saves to reference_logos/
|
||||
3. Copies full images to test_images/ with unique filenames
|
||||
4. Creates a SQLite database for storing mappings and verification
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
import shutil
|
||||
import xml.etree.ElementTree as ET
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def parse_xml_annotation(xml_path: Path) -> dict:
    """Parse a Pascal VOC format XML annotation file.

    Returns a dict with the image ``filename``, its ``size`` (width/height),
    and a list of ``objects``, each carrying the logo ``name`` and an
    integer bounding box (xmin/ymin/xmax/ymax).
    """
    root = ET.parse(xml_path).getroot()

    def _object_entry(node):
        # Each <object> element holds one logo name plus its bounding box.
        box = node.find("bndbox")
        return {
            "name": node.find("name").text,
            "xmin": int(box.find("xmin").text),
            "ymin": int(box.find("ymin").text),
            "xmax": int(box.find("xmax").text),
            "ymax": int(box.find("ymax").text),
        }

    return {
        "filename": root.find("filename").text,
        "size": {
            "width": int(root.find("size/width").text),
            "height": int(root.find("size/height").text),
        },
        "objects": [_object_entry(node) for node in root.findall("object")],
    }
|
||||
|
||||
|
||||
def sanitize_filename(name: str) -> str:
    """Convert logo name to a safe filename.

    Every filesystem-hostile character (path separators, spaces, and the
    Windows-reserved punctuation set) is mapped to an underscore.
    """
    # One C-level pass via str.translate instead of chained .replace calls.
    unsafe = '/\\ :*?"<>|'
    return name.translate(str.maketrans(unsafe, "_" * len(unsafe)))
|
||||
|
||||
|
||||
def init_database(db_path: Path) -> sqlite3.Connection:
    """Initialize the SQLite database with the test-data schema.

    Any pre-existing database file at *db_path* is removed first so the
    schema is always created from scratch. Returns an open connection.
    """
    # missing_ok makes the cleanup idempotent (no exists() check needed).
    db_path.unlink(missing_ok=True)

    conn = sqlite3.connect(db_path)

    # Create all tables and indexes in a single script.
    conn.cursor().executescript("""
        -- Test images table
        CREATE TABLE test_images (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            filename TEXT UNIQUE NOT NULL
        );

        -- Logo names table (unique brand/logo identifiers)
        CREATE TABLE logo_names (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT UNIQUE NOT NULL
        );

        -- Reference logos table with foreign keys
        CREATE TABLE reference_logos (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            filename TEXT UNIQUE NOT NULL,
            test_image_id INTEGER NOT NULL,
            logo_name_id INTEGER NOT NULL,
            FOREIGN KEY (test_image_id) REFERENCES test_images(id),
            FOREIGN KEY (logo_name_id) REFERENCES logo_names(id)
        );

        -- Statistics table for metadata
        CREATE TABLE statistics (
            key TEXT PRIMARY KEY,
            value INTEGER NOT NULL
        );

        -- Indexes for faster lookups
        CREATE INDEX idx_reference_logos_test_image ON reference_logos(test_image_id);
        CREATE INDEX idx_reference_logos_logo_name ON reference_logos(logo_name_id);
    """)

    conn.commit()
    return conn
|
||||
|
||||
|
||||
def get_or_create_logo_name(cursor: sqlite3.Cursor, name: str) -> int:
    """Return the id of *name* in logo_names, inserting a new row if absent."""
    # Look up first; only insert when no row exists yet.
    found = cursor.execute(
        "SELECT id FROM logo_names WHERE name = ?", (name,)
    ).fetchone()
    if found is not None:
        return found[0]
    cursor.execute("INSERT INTO logo_names (name) VALUES (?)", (name,))
    return cursor.lastrowid
|
||||
|
||||
|
||||
def _unique_name(stem: str, suffix: str, taken: set) -> str:
    """Return ``stem+suffix``, appending ``_1``, ``_2``, ... until the name
    is not in *taken*; the chosen name is added to *taken* before returning."""
    candidate = f"{stem}{suffix}"
    counter = 1
    while candidate in taken:
        candidate = f"{stem}_{counter}{suffix}"
        counter += 1
    taken.add(candidate)
    return candidate


def main():
    """Build the logo test dataset from LogoDet-3K.

    For every Pascal VOC annotation found under the dataset root: copy the
    full image into test_images/<category>/<brand>/, crop each annotated
    logo into reference_logos/<category>/<logo_name>/, and record every
    image/logo mapping plus run statistics in a SQLite database.
    """
    # Paths (hard-coded for this one-off preparation script).
    dataset_dir = Path("/data/dev.python/logo_test/LogoDet-3K")
    reference_dir = Path("/data/dev.python/logo_test/reference_logos")
    test_images_dir = Path("/data/dev.python/logo_test/test_images")
    db_path = Path("/data/dev.python/logo_test/test_data_mapping.db")

    # Ensure output directories exist
    reference_dir.mkdir(exist_ok=True)
    test_images_dir.mkdir(exist_ok=True)

    # Initialize database
    print(f"Initializing database at {db_path}...")
    conn = init_database(db_path)
    cursor = conn.cursor()

    # Find all XML files
    print("Scanning for XML annotation files...")
    xml_files = list(dataset_dir.rglob("*.xml"))
    print(f"Found {len(xml_files)} annotation files")

    # Track unique filenames to avoid conflicts (keyed by subdirectory tuple)
    used_test_filenames = {}
    used_ref_filenames = {}

    # Counters for progress / final summary.
    stats = {
        "images_processed": 0,
        "logos_extracted": 0,
        "skipped_missing_image": 0,
        "skipped_invalid_bbox": 0,
    }

    # Process each XML file
    print("\nProcessing annotations...")
    for xml_path in tqdm(xml_files, desc="Processing", unit="file"):
        try:
            annotation = parse_xml_annotation(xml_path)
        except Exception as e:
            tqdm.write(f"Error parsing {xml_path}: {e}")
            continue

        # Find corresponding image file. The <filename> tag is not always
        # reliable, so fall back to the XML stem with common extensions.
        image_path = xml_path.parent / annotation["filename"]
        if not image_path.exists():
            for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
                alt_path = xml_path.parent / (xml_path.stem + ext)
                if alt_path.exists():
                    image_path = alt_path
                    break

        if not image_path.exists():
            stats["skipped_missing_image"] += 1
            continue

        # Derive category/brand from the dataset layout
        # (LogoDet-3K/<category>/<brand>/...) to build unique output paths.
        rel_path = xml_path.relative_to(dataset_dir)
        category = rel_path.parts[0] if len(rel_path.parts) > 0 else "unknown"
        brand = rel_path.parts[1] if len(rel_path.parts) > 1 else "unknown"

        safe_category = sanitize_filename(category)
        safe_brand = sanitize_filename(brand)
        base_name = image_path.stem
        ext = image_path.suffix

        # Create subdirectory structure: category/brand/
        test_subdir = test_images_dir / safe_category / safe_brand
        test_subdir.mkdir(parents=True, exist_ok=True)

        # Handle duplicate basenames within the subdirectory.
        test_taken = used_test_filenames.setdefault((safe_category, safe_brand), set())
        test_basename = _unique_name(base_name, ext, test_taken)

        # Store relative path from test_images_dir for database
        test_filename = f"{safe_category}/{safe_brand}/{test_basename}"

        # Copy full image to test_images
        test_image_path = test_subdir / test_basename
        shutil.copy2(image_path, test_image_path)
        stats["images_processed"] += 1

        # Insert test image into database
        cursor.execute(
            "INSERT INTO test_images (filename) VALUES (?)",
            (test_filename,)
        )
        test_image_id = cursor.lastrowid

        # Load image for cropping
        try:
            source_img = Image.open(image_path)
        except Exception as e:
            tqdm.write(f"Error loading {image_path}: {e}")
            continue

        # Context manager closes the image file each iteration — the
        # original leaked one open file handle per processed image.
        with source_img as img:
            img_width, img_height = img.size

            # Process each object/logo in the image
            for obj_idx, obj in enumerate(annotation["objects"]):
                logo_name = obj["name"]
                xmin, ymin = obj["xmin"], obj["ymin"]
                xmax, ymax = obj["xmax"], obj["ymax"]

                # Validate bounding box
                if xmin >= xmax or ymin >= ymax:
                    stats["skipped_invalid_bbox"] += 1
                    continue

                # Clamp to image bounds
                xmin = max(0, min(xmin, img_width - 1))
                ymin = max(0, min(ymin, img_height - 1))
                xmax = max(1, min(xmax, img_width))
                ymax = max(1, min(ymax, img_height))

                # Clamping can collapse a box that sat outside the image.
                if xmin >= xmax or ymin >= ymax:
                    stats["skipped_invalid_bbox"] += 1
                    continue

                # Crop logo region
                try:
                    logo_crop = img.crop((xmin, ymin, xmax, ymax))
                except Exception as e:
                    tqdm.write(f"Error cropping {image_path}: {e}")
                    stats["skipped_invalid_bbox"] += 1
                    continue

                # Reference logos live under category/logo_name/.
                safe_logo_name = sanitize_filename(logo_name)
                ref_subdir = reference_dir / safe_category / safe_logo_name
                ref_subdir.mkdir(parents=True, exist_ok=True)

                # Handle duplicate basenames within the subdirectory.
                ref_taken = used_ref_filenames.setdefault(
                    (safe_category, safe_logo_name), set())
                ref_basename = _unique_name(f"{base_name}_{obj_idx}", ".png", ref_taken)

                # Store relative path from reference_dir for database
                ref_filename = f"{safe_category}/{safe_logo_name}/{ref_basename}"

                # Save cropped logo
                ref_path = ref_subdir / ref_basename
                try:
                    logo_crop.save(ref_path, "PNG")
                except Exception as e:
                    tqdm.write(f"Error saving {ref_path}: {e}")
                    continue

                stats["logos_extracted"] += 1

                # Get or create logo_name entry
                logo_name_id = get_or_create_logo_name(cursor, logo_name)

                # Insert reference logo into database
                cursor.execute(
                    "INSERT INTO reference_logos (filename, test_image_id, logo_name_id) VALUES (?, ?, ?)",
                    (ref_filename, test_image_id, logo_name_id)
                )

    # Get unique logo names count
    cursor.execute("SELECT COUNT(*) FROM logo_names")
    unique_logo_names = cursor.fetchone()[0]

    # Save statistics to database
    statistics_data = [
        ("total_test_images", stats["images_processed"]),
        ("total_reference_logos", stats["logos_extracted"]),
        ("unique_logo_names", unique_logo_names),
        ("skipped_missing_image", stats["skipped_missing_image"]),
        ("skipped_invalid_bbox", stats["skipped_invalid_bbox"]),
    ]
    cursor.executemany(
        "INSERT INTO statistics (key, value) VALUES (?, ?)",
        statistics_data
    )

    # Commit and close database
    conn.commit()
    conn.close()

    # Print summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"Test images created: {stats['images_processed']:,}")
    print(f"Reference logos created: {stats['logos_extracted']:,}")
    print(f"Unique logo names: {unique_logo_names:,}")
    print(f"Skipped (missing image): {stats['skipped_missing_image']:,}")
    print(f"Skipped (invalid bbox): {stats['skipped_invalid_bbox']:,}")
    print(f"\nDatabase saved to: {db_path}")
    print(f"Reference logos: {reference_dir}")
    print(f"Test images: {test_images_dir}")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user