#!/usr/bin/env python3
"""
Prepare test data from LogoDet-3K dataset.

This script:
1. Scans LogoDet-3K for images and XML annotation files
2. Extracts cropped logos using bounding box data and saves to reference_logos/
3. Copies full images to test_images/ with unique filenames
4. Creates a SQLite database for storing mappings and verification
"""

import sqlite3
import shutil
import xml.etree.ElementTree as ET
from pathlib import Path

from PIL import Image
from tqdm import tqdm

# Single-pass translation table mapping every filesystem-unsafe character
# to "_" (same set the original chained .replace() calls covered).
_UNSAFE_CHARS = str.maketrans({c: "_" for c in '/\\ :*?"<>|'})


def parse_xml_annotation(xml_path: Path) -> dict:
    """Parse a Pascal VOC format XML annotation file.

    Args:
        xml_path: Path to the .xml annotation file.

    Returns:
        Dict with keys:
            "filename": image filename recorded in the annotation,
            "size": {"width": int, "height": int},
            "objects": list of {"name", "xmin", "ymin", "xmax", "ymax"}.

    Raises:
        ET.ParseError / AttributeError / ValueError if the XML is malformed
        or required nodes are missing (callers catch broadly).
    """
    tree = ET.parse(xml_path)
    root = tree.getroot()
    annotation = {
        "filename": root.find("filename").text,
        "size": {
            "width": int(root.find("size/width").text),
            "height": int(root.find("size/height").text),
        },
        "objects": [],
    }
    for obj in root.findall("object"):
        bbox = obj.find("bndbox")
        annotation["objects"].append({
            "name": obj.find("name").text,
            "xmin": int(bbox.find("xmin").text),
            "ymin": int(bbox.find("ymin").text),
            "xmax": int(bbox.find("xmax").text),
            "ymax": int(bbox.find("ymax").text),
        })
    return annotation


def sanitize_filename(name: str) -> str:
    """Convert logo name to a safe filename.

    Replaces path separators, whitespace, and Windows-reserved characters
    with "_" in a single translate pass.
    """
    return name.translate(_UNSAFE_CHARS)


def init_database(db_path: Path) -> sqlite3.Connection:
    """Initialize SQLite database with schema.

    Any existing database at db_path is deleted first so each run starts
    from a clean slate.

    Args:
        db_path: Location of the SQLite database file to (re)create.

    Returns:
        An open sqlite3.Connection with the schema committed.
    """
    # Remove existing database if present
    if db_path.exists():
        db_path.unlink()

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # Create tables
    cursor.executescript("""
        -- Test images table
        CREATE TABLE test_images (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            filename TEXT UNIQUE NOT NULL
        );

        -- Logo names table (unique brand/logo identifiers)
        CREATE TABLE logo_names (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT UNIQUE NOT NULL
        );

        -- Reference logos table with foreign keys
        CREATE TABLE reference_logos (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            filename TEXT UNIQUE NOT NULL,
            test_image_id INTEGER NOT NULL,
            logo_name_id INTEGER NOT NULL,
            FOREIGN KEY (test_image_id) REFERENCES test_images(id),
            FOREIGN KEY (logo_name_id) REFERENCES logo_names(id)
        );

        -- Statistics table for metadata
        CREATE TABLE statistics (
            key TEXT PRIMARY KEY,
            value INTEGER NOT NULL
        );

        -- Indexes for faster lookups
        CREATE INDEX idx_reference_logos_test_image ON reference_logos(test_image_id);
        CREATE INDEX idx_reference_logos_logo_name ON reference_logos(logo_name_id);
    """)
    conn.commit()
    return conn


def get_or_create_logo_name(cursor: sqlite3.Cursor, name: str) -> int:
    """Get existing logo_name id or create new one.

    Args:
        cursor: Open cursor on the test-data database.
        name: Raw (unsanitized) logo/brand name.

    Returns:
        The integer primary key of the logo_names row for *name*.
    """
    cursor.execute("SELECT id FROM logo_names WHERE name = ?", (name,))
    row = cursor.fetchone()
    if row:
        return row[0]
    cursor.execute("INSERT INTO logo_names (name) VALUES (?)", (name,))
    return cursor.lastrowid


def main():
    """Scan LogoDet-3K, extract logo crops, and record mappings in SQLite."""
    # Use script directory as base path for portability
    base_dir = Path(__file__).parent.resolve()

    # Paths relative to script location
    dataset_dir = base_dir / "LogoDet-3K"
    reference_dir = base_dir / "reference_logos"
    test_images_dir = base_dir / "test_images"
    db_path = base_dir / "test_data_mapping.db"

    # Ensure output directories exist
    reference_dir.mkdir(exist_ok=True)
    test_images_dir.mkdir(exist_ok=True)

    # Initialize database
    print(f"Initializing database at {db_path}...")
    conn = init_database(db_path)
    cursor = conn.cursor()

    # Find all XML files
    print("Scanning for XML annotation files...")
    xml_files = list(dataset_dir.rglob("*.xml"))
    print(f"Found {len(xml_files)} annotation files")

    # Track unique filenames to avoid conflicts (keyed by subdirectory tuple)
    used_test_filenames = {}
    used_ref_filenames = {}

    # Counters for progress
    stats = {
        "images_processed": 0,
        "logos_extracted": 0,
        "skipped_missing_image": 0,
        "skipped_invalid_bbox": 0,
    }

    # Process each XML file
    print("\nProcessing annotations...")
    for xml_path in tqdm(xml_files, desc="Processing", unit="file"):
        try:
            annotation = parse_xml_annotation(xml_path)
        except Exception as e:
            tqdm.write(f"Error parsing {xml_path}: {e}")
            continue

        # Find corresponding image file
        image_filename = annotation["filename"]
        image_path = xml_path.parent / image_filename

        if not image_path.exists():
            # Try common extensions (annotation "filename" is not always accurate)
            for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
                alt_path = xml_path.parent / (xml_path.stem + ext)
                if alt_path.exists():
                    image_path = alt_path
                    break

        if not image_path.exists():
            stats["skipped_missing_image"] += 1
            continue

        # Generate unique test image filename
        # Use category/brand/original_name format to avoid conflicts
        rel_path = xml_path.relative_to(dataset_dir)
        category = rel_path.parts[0] if len(rel_path.parts) > 0 else "unknown"
        brand = rel_path.parts[1] if len(rel_path.parts) > 1 else "unknown"
        safe_category = sanitize_filename(category)
        safe_brand = sanitize_filename(brand)

        base_name = image_path.stem
        ext = image_path.suffix

        # Create subdirectory structure: category/brand/
        test_subdir = test_images_dir / safe_category / safe_brand
        test_subdir.mkdir(parents=True, exist_ok=True)

        test_basename = f"{base_name}{ext}"

        # Handle duplicates within subdirectory
        counter = 1
        while test_basename in used_test_filenames.get((safe_category, safe_brand), set()):
            test_basename = f"{base_name}_{counter}{ext}"
            counter += 1
        used_test_filenames.setdefault((safe_category, safe_brand), set()).add(test_basename)

        # Store relative path from test_images_dir for database
        test_filename = f"{safe_category}/{safe_brand}/{test_basename}"

        # Copy full image to test_images
        test_image_path = test_subdir / test_basename
        shutil.copy2(image_path, test_image_path)
        stats["images_processed"] += 1

        # Insert test image into database
        cursor.execute(
            "INSERT INTO test_images (filename) VALUES (?)",
            (test_filename,)
        )
        test_image_id = cursor.lastrowid

        # Load image for cropping
        try:
            img = Image.open(image_path)
        except Exception as e:
            tqdm.write(f"Error loading {image_path}: {e}")
            continue

        # Context manager guarantees the file handle is released per image
        # (previously leaked — one open handle per processed image).
        with img:
            img_width, img_height = img.size

            # Process each object/logo in the image
            for obj_idx, obj in enumerate(annotation["objects"]):
                logo_name = obj["name"]
                xmin, ymin = obj["xmin"], obj["ymin"]
                xmax, ymax = obj["xmax"], obj["ymax"]

                # Validate bounding box
                if xmin >= xmax or ymin >= ymax:
                    stats["skipped_invalid_bbox"] += 1
                    continue

                # Clamp to image bounds
                xmin = max(0, min(xmin, img_width - 1))
                ymin = max(0, min(ymin, img_height - 1))
                xmax = max(1, min(xmax, img_width))
                ymax = max(1, min(ymax, img_height))

                # Re-check: clamping can collapse a box that hugged the border
                if xmin >= xmax or ymin >= ymax:
                    stats["skipped_invalid_bbox"] += 1
                    continue

                # Crop logo region
                try:
                    logo_crop = img.crop((xmin, ymin, xmax, ymax))
                except Exception as e:
                    tqdm.write(f"Error cropping {image_path}: {e}")
                    stats["skipped_invalid_bbox"] += 1
                    continue

                # Generate reference logo filename with subdirectory structure: category/logo_name/
                safe_logo_name = sanitize_filename(logo_name)
                ref_subdir = reference_dir / safe_category / safe_logo_name
                ref_subdir.mkdir(parents=True, exist_ok=True)

                ref_basename = f"{base_name}_{obj_idx}.png"

                # Handle duplicates within subdirectory
                counter = 1
                while ref_basename in used_ref_filenames.get((safe_category, safe_logo_name), set()):
                    ref_basename = f"{base_name}_{obj_idx}_{counter}.png"
                    counter += 1
                used_ref_filenames.setdefault((safe_category, safe_logo_name), set()).add(ref_basename)

                # Store relative path from reference_dir for database
                ref_filename = f"{safe_category}/{safe_logo_name}/{ref_basename}"

                # Save cropped logo
                ref_path = ref_subdir / ref_basename
                try:
                    logo_crop.save(ref_path, "PNG")
                except Exception as e:
                    tqdm.write(f"Error saving {ref_path}: {e}")
                    continue

                stats["logos_extracted"] += 1

                # Get or create logo_name entry
                logo_name_id = get_or_create_logo_name(cursor, logo_name)

                # Insert reference logo into database
                cursor.execute(
                    "INSERT INTO reference_logos (filename, test_image_id, logo_name_id) VALUES (?, ?, ?)",
                    (ref_filename, test_image_id, logo_name_id)
                )

    # Get unique logo names count
    cursor.execute("SELECT COUNT(*) FROM logo_names")
    unique_logo_names = cursor.fetchone()[0]

    # Save statistics to database
    statistics_data = [
        ("total_test_images", stats["images_processed"]),
        ("total_reference_logos", stats["logos_extracted"]),
        ("unique_logo_names", unique_logo_names),
        ("skipped_missing_image", stats["skipped_missing_image"]),
        ("skipped_invalid_bbox", stats["skipped_invalid_bbox"]),
    ]
    cursor.executemany(
        "INSERT INTO statistics (key, value) VALUES (?, ?)",
        statistics_data
    )

    # Commit and close database
    conn.commit()
    conn.close()

    # Print summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"Test images created: {stats['images_processed']:,}")
    print(f"Reference logos created: {stats['logos_extracted']:,}")
    print(f"Unique logo names: {unique_logo_names:,}")
    print(f"Skipped (missing image): {stats['skipped_missing_image']:,}")
    print(f"Skipped (invalid bbox): {stats['skipped_invalid_bbox']:,}")
    print(f"\nDatabase saved to: {db_path}")
    print(f"Reference logos: {reference_dir}")
    print(f"Test images: {test_images_dir}")


if __name__ == "__main__":
    main()