325 lines
11 KiB
Python
Executable File
325 lines
11 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Prepare test data from LogoDet-3K dataset.
|
|
|
|
This script:
|
|
1. Scans LogoDet-3K for images and XML annotation files
|
|
2. Extracts cropped logos using bounding box data and saves to reference_logos/
|
|
3. Copies full images to test_images/ with unique filenames
|
|
4. Creates a SQLite database for storing mappings and verification
|
|
"""
|
|
|
|
import sqlite3
|
|
import shutil
|
|
import xml.etree.ElementTree as ET
|
|
from pathlib import Path
|
|
from PIL import Image
|
|
from tqdm import tqdm
|
|
|
|
|
|
def parse_xml_annotation(xml_path: Path) -> dict:
    """Parse a Pascal VOC format XML annotation file.

    Args:
        xml_path: Location of the ``.xml`` annotation file.

    Returns:
        A dict with the keys ``"filename"`` (str), ``"size"`` (dict with int
        ``"width"`` and ``"height"``) and ``"objects"`` (list of dicts in
        document order, each holding ``"name"`` plus int ``"xmin"``,
        ``"ymin"``, ``"xmax"`` and ``"ymax"``).

    Raises:
        xml.etree.ElementTree.ParseError: If the file is not well-formed XML.
        ValueError: If a required element is missing or empty, or a numeric
            field does not contain a number.
    """

    def required_text(node, path: str) -> str:
        # findtext() gives None for a missing element and "" for an empty one;
        # report both by element name instead of failing later on ``None``.
        text = node.findtext(path)
        if text is None or not text.strip():
            raise ValueError(f"{xml_path}: missing or empty <{path}> element")
        return text.strip()

    def required_int(node, path: str) -> int:
        raw = required_text(node, path)
        try:
            return int(raw)
        except ValueError:
            # Accept values written as floats (e.g. "12.0"); any fractional
            # part is truncated toward zero.
            return int(float(raw))

    root = ET.parse(xml_path).getroot()

    return {
        "filename": required_text(root, "filename"),
        "size": {
            "width": required_int(root, "size/width"),
            "height": required_int(root, "size/height"),
        },
        "objects": [
            {
                "name": required_text(obj, "name"),
                "xmin": required_int(obj, "bndbox/xmin"),
                "ymin": required_int(obj, "bndbox/ymin"),
                "xmax": required_int(obj, "bndbox/xmax"),
                "ymax": required_int(obj, "bndbox/ymax"),
            }
            for obj in root.findall("object")
        ],
    }
|
|
|
|
|
|
def sanitize_filename(name: str) -> str:
    """Return *name* with filename-hostile characters replaced by underscores."""
    # Path separators, the space, and the characters : * ? " < > | each map to
    # a single "_"; every other character passes through untouched.
    unsafe_chars = '/\\ :*?"<>|'
    replacement_table = str.maketrans(unsafe_chars, "_" * len(unsafe_chars))
    return name.translate(replacement_table)
|
|
|
|
|
|
def init_database(db_path: Path) -> sqlite3.Connection:
    """Create a fresh SQLite database at *db_path* and return its connection.

    A file already present at *db_path* is deleted first, so each call starts
    from an empty schema. The returned connection is left open for the caller.
    """
    schema = """
        -- Test images table
        CREATE TABLE test_images (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            filename TEXT UNIQUE NOT NULL
        );

        -- Logo names table (unique brand/logo identifiers)
        CREATE TABLE logo_names (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            name TEXT UNIQUE NOT NULL
        );

        -- Reference logos table with foreign keys
        CREATE TABLE reference_logos (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            filename TEXT UNIQUE NOT NULL,
            test_image_id INTEGER NOT NULL,
            logo_name_id INTEGER NOT NULL,
            FOREIGN KEY (test_image_id) REFERENCES test_images(id),
            FOREIGN KEY (logo_name_id) REFERENCES logo_names(id)
        );

        -- Statistics table for metadata
        CREATE TABLE statistics (
            key TEXT PRIMARY KEY,
            value INTEGER NOT NULL
        );

        -- Indexes for faster lookups
        CREATE INDEX idx_reference_logos_test_image ON reference_logos(test_image_id);
        CREATE INDEX idx_reference_logos_logo_name ON reference_logos(logo_name_id);
    """

    # Start from scratch: drop whatever an earlier run left behind.
    if db_path.exists():
        db_path.unlink()

    connection = sqlite3.connect(db_path)
    connection.cursor().executescript(schema)
    connection.commit()
    return connection
|
|
|
|
|
|
def get_or_create_logo_name(cursor: sqlite3.Cursor, name: str) -> int:
    """Return the ``logo_names`` id for *name*, inserting a row if none exists."""
    existing = cursor.execute(
        "SELECT id FROM logo_names WHERE name = ?", (name,)
    ).fetchone()

    if not existing:
        # First sighting of this name: add it and hand back the new row id.
        cursor.execute("INSERT INTO logo_names (name) VALUES (?)", (name,))
        return cursor.lastrowid

    return existing[0]
|
|
|
|
|
|
def _locate_image(xml_path: Path, image_filename: str):
    """Return the image path for an annotation, or None if no file is found."""
    annotated = xml_path.parent / image_filename
    if annotated.exists():
        return annotated

    # Fall back to a file named after the XML file itself, trying common extensions.
    for ext in [".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG"]:
        candidate = xml_path.parent / (xml_path.stem + ext)
        if candidate.exists():
            return candidate
    return None


def _reserve_unique_name(used: dict, key: tuple, stem: str, suffix: str) -> str:
    """Pick a name not yet recorded under *key* in *used* and record it."""
    taken = used.setdefault(key, set())
    candidate = f"{stem}{suffix}"
    counter = 1
    while candidate in taken:
        candidate = f"{stem}_{counter}{suffix}"
        counter += 1
    taken.add(candidate)
    return candidate


def _clamp_bbox(obj: dict, width: int, height: int):
    """Return (xmin, ymin, xmax, ymax) clamped to the image, or None if empty."""
    xmin, ymin = obj["xmin"], obj["ymin"]
    xmax, ymax = obj["xmax"], obj["ymax"]

    # Reject boxes that are degenerate as annotated.
    if xmin >= xmax or ymin >= ymax:
        return None

    xmin = max(0, min(xmin, width - 1))
    ymin = max(0, min(ymin, height - 1))
    xmax = max(1, min(xmax, width))
    ymax = max(1, min(ymax, height))

    # Clamping can collapse a box that lay outside the image.
    if xmin >= xmax or ymin >= ymax:
        return None
    return xmin, ymin, xmax, ymax


def main():
    """Build reference logos, test images and the mapping database.

    All paths are resolved relative to this script's directory:
    annotations are read from ``LogoDet-3K/``, full images are copied to
    ``test_images/<category>/<brand>/``, cropped logos are written as PNG to
    ``reference_logos/<category>/<logo_name>/``, and the mappings plus run
    statistics are stored in ``test_data_mapping.db`` (recreated on each run).
    A summary is printed at the end.

    Annotation files that fail to parse, images that cannot be found or
    opened, and invalid bounding boxes are reported or counted and skipped
    rather than aborting the run.
    """
    # Use script directory as base path for portability
    base_dir = Path(__file__).parent.resolve()

    dataset_dir = base_dir / "LogoDet-3K"
    reference_dir = base_dir / "reference_logos"
    test_images_dir = base_dir / "test_images"
    db_path = base_dir / "test_data_mapping.db"

    reference_dir.mkdir(exist_ok=True)
    test_images_dir.mkdir(exist_ok=True)

    print(f"Initializing database at {db_path}...")
    conn = init_database(db_path)

    # try/finally so the connection is released even if processing aborts.
    try:
        cursor = conn.cursor()

        print("Scanning for XML annotation files...")
        xml_files = list(dataset_dir.rglob("*.xml"))
        print(f"Found {len(xml_files)} annotation files")

        # Names already handed out, keyed by the (category, subdirectory) tuple.
        used_test_filenames = {}
        used_ref_filenames = {}

        stats = {
            "images_processed": 0,
            "logos_extracted": 0,
            "skipped_missing_image": 0,
            "skipped_invalid_bbox": 0,
        }

        print("\nProcessing annotations...")
        for xml_path in tqdm(xml_files, desc="Processing", unit="file"):
            try:
                annotation = parse_xml_annotation(xml_path)
            except Exception as e:
                tqdm.write(f"Error parsing {xml_path}: {e}")
                continue

            image_path = _locate_image(xml_path, annotation["filename"])
            if image_path is None:
                stats["skipped_missing_image"] += 1
                continue

            # The first two path components below the dataset root name the
            # category and brand used for the output layout.
            rel_path = xml_path.relative_to(dataset_dir)
            category = rel_path.parts[0] if len(rel_path.parts) > 0 else "unknown"
            brand = rel_path.parts[1] if len(rel_path.parts) > 1 else "unknown"

            safe_category = sanitize_filename(category)
            safe_brand = sanitize_filename(brand)
            base_name = image_path.stem
            ext = image_path.suffix

            test_subdir = test_images_dir / safe_category / safe_brand
            test_subdir.mkdir(parents=True, exist_ok=True)

            test_basename = _reserve_unique_name(
                used_test_filenames, (safe_category, safe_brand), base_name, ext
            )

            # Store relative path from test_images_dir for database
            test_filename = f"{safe_category}/{safe_brand}/{test_basename}"

            shutil.copy2(image_path, test_subdir / test_basename)
            stats["images_processed"] += 1

            cursor.execute(
                "INSERT INTO test_images (filename) VALUES (?)",
                (test_filename,)
            )
            test_image_id = cursor.lastrowid

            try:
                img = Image.open(image_path)
            except Exception as e:
                tqdm.write(f"Error loading {image_path}: {e}")
                continue

            # Image.open is lazy; the context manager guarantees the file
            # handle is closed even when no crop ever forces a full load.
            with img:
                img_width, img_height = img.size

                for obj_idx, obj in enumerate(annotation["objects"]):
                    logo_name = obj["name"]

                    box = _clamp_bbox(obj, img_width, img_height)
                    if box is None:
                        stats["skipped_invalid_bbox"] += 1
                        continue

                    try:
                        logo_crop = img.crop(box)
                    except Exception as e:
                        tqdm.write(f"Error cropping {image_path}: {e}")
                        stats["skipped_invalid_bbox"] += 1
                        continue

                    safe_logo_name = sanitize_filename(logo_name)
                    ref_subdir = reference_dir / safe_category / safe_logo_name
                    ref_subdir.mkdir(parents=True, exist_ok=True)

                    ref_basename = _reserve_unique_name(
                        used_ref_filenames,
                        (safe_category, safe_logo_name),
                        f"{base_name}_{obj_idx}",
                        ".png",
                    )

                    # Store relative path from reference_dir for database
                    ref_filename = f"{safe_category}/{safe_logo_name}/{ref_basename}"

                    ref_path = ref_subdir / ref_basename
                    try:
                        logo_crop.save(ref_path, "PNG")
                    except Exception as e:
                        tqdm.write(f"Error saving {ref_path}: {e}")
                        continue

                    stats["logos_extracted"] += 1

                    logo_name_id = get_or_create_logo_name(cursor, logo_name)
                    cursor.execute(
                        "INSERT INTO reference_logos (filename, test_image_id, logo_name_id) VALUES (?, ?, ?)",
                        (ref_filename, test_image_id, logo_name_id)
                    )

        cursor.execute("SELECT COUNT(*) FROM logo_names")
        unique_logo_names = cursor.fetchone()[0]

        statistics_data = [
            ("total_test_images", stats["images_processed"]),
            ("total_reference_logos", stats["logos_extracted"]),
            ("unique_logo_names", unique_logo_names),
            ("skipped_missing_image", stats["skipped_missing_image"]),
            ("skipped_invalid_bbox", stats["skipped_invalid_bbox"]),
        ]
        cursor.executemany(
            "INSERT INTO statistics (key, value) VALUES (?, ?)",
            statistics_data
        )

        conn.commit()
    finally:
        conn.close()

    # Print summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"Test images created: {stats['images_processed']:,}")
    print(f"Reference logos created: {stats['logos_extracted']:,}")
    print(f"Unique logo names: {unique_logo_names:,}")
    print(f"Skipped (missing image): {stats['skipped_missing_image']:,}")
    print(f"Skipped (invalid bbox): {stats['skipped_invalid_bbox']:,}")
    print(f"\nDatabase saved to: {db_path}")
    print(f"Reference logos: {reference_dir}")
    print(f"Test images: {test_images_dir}")
|
|
|
|
|
|
# Run the preparation pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()