#!/usr/bin/env python3
"""
Export a trained CLIP model to HuggingFace-compatible format.

This script converts a training checkpoint to a format that can be
loaded by DetectLogosDETR for inference.

Usage:
    uv run python export_model.py \
        --checkpoint checkpoints/best.pt \
        --output models/logo_detection/clip_finetuned

    # With custom base model
    uv run python export_model.py \
        --checkpoint checkpoints/best.pt \
        --output models/logo_detection/clip_finetuned \
        --base-model openai/clip-vit-large-patch14
"""

import argparse
import json
import logging
import sys
from pathlib import Path
from typing import Optional, Tuple

import torch

from training.config import TrainingConfig  # noqa: F401  (kept: part of package API surface)
from training.model import create_model, LogoFineTunedCLIP  # noqa: F401

# Single source of truth for the fallback base model (was duplicated in two branches).
DEFAULT_BASE_MODEL = "openai/clip-vit-large-patch14"


def setup_logging() -> logging.Logger:
    """Configure root logging (INFO, timestamped) and return this module's logger."""
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    return logging.getLogger(__name__)


def parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface for the export script."""
    parser = argparse.ArgumentParser(
        description="Export trained CLIP model for inference",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to training checkpoint (.pt file)",
    )
    parser.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output directory for exported model",
    )
    parser.add_argument(
        "--base-model",
        type=str,
        default=None,
        help="Base CLIP model (reads from checkpoint config if not specified)",
    )
    parser.add_argument(
        "--merge-lora",
        action="store_true",
        help="Merge LoRA weights into base model (reduces inference overhead)",
    )
    return parser.parse_args()


def _resolve_architecture(
    checkpoint: dict, base_model_override: Optional[str]
) -> Tuple[str, int, int, int]:
    """Read architecture hyperparameters from the checkpoint's saved config.

    Returns (base_model, lora_r, lora_alpha, freeze_layers), falling back to
    the defaults used at training time when the checkpoint has no config.
    An explicit --base-model always wins over the checkpoint value.
    """
    config_dict = checkpoint.get("config")
    if config_dict is not None:
        base_model = base_model_override or config_dict.get(
            "base_model", DEFAULT_BASE_MODEL
        )
        lora_r = config_dict.get("lora_r", 16)
        lora_alpha = config_dict.get("lora_alpha", 32)
        freeze_layers = config_dict.get("freeze_layers", 12)
    else:
        base_model = base_model_override or DEFAULT_BASE_MODEL
        lora_r = 16
        lora_alpha = 32
        freeze_layers = 12
    return base_model, lora_r, lora_alpha, freeze_layers


def _maybe_merge_lora(model, logger: logging.Logger) -> None:
    """Merge LoRA adapters into the base vision model in place, best-effort.

    On success, clears the model's PEFT flags so downstream export reflects
    the merged state. On failure, logs a warning and leaves the model with
    separate LoRA weights (export still proceeds).
    """
    try:
        logger.info("Merging LoRA weights into base model...")
        model.vision_model = model.vision_model.merge_and_unload()
        model.peft_applied = False
        model.lora_r = 0
        logger.info("LoRA weights merged successfully")
    except Exception as e:
        logger.warning(f"Could not merge LoRA weights: {e}")
        logger.warning("Exporting with separate LoRA weights")


def _json_float(value) -> Optional[float]:
    """Coerce a checkpoint metric to a JSON-serializable float (or None).

    Checkpoints may store metrics as 0-d tensors; json.dump would crash on
    those, so convert explicitly.
    """
    return None if value is None else float(value)


def _write_metadata(
    output_path: Path,
    base_model: str,
    checkpoint_path: Path,
    checkpoint: dict,
    lora_merged: bool,
) -> None:
    """Write export_metadata.json describing provenance of the exported model."""
    metadata = {
        "base_model": base_model,
        "source_checkpoint": str(checkpoint_path),
        # Stored epoch is 0-based; -1 sentinel yields 0 when absent.
        "training_epochs": checkpoint.get("epoch", -1) + 1,
        "best_val_loss": _json_float(checkpoint.get("best_val_loss", None)),
        "best_val_separation": _json_float(checkpoint.get("best_val_separation", None)),
        "lora_merged": lora_merged,
    }
    with open(output_path / "export_metadata.json", "w") as f:
        json.dump(metadata, f, indent=2)


def main():
    """Load a training checkpoint, rebuild the model, and export it for inference."""
    args = parse_args()
    logger = setup_logging()

    logger.info("CLIP Model Export")
    logger.info("=" * 60)

    # Check checkpoint exists
    checkpoint_path = Path(args.checkpoint)
    if not checkpoint_path.exists():
        logger.error(f"Checkpoint not found: {checkpoint_path}")
        sys.exit(1)

    # Load checkpoint. weights_only=False is required because the checkpoint
    # carries a plain config dict alongside tensors (PyTorch >= 2.6 defaults
    # to weights_only=True). Safe here: the checkpoint is trusted local output
    # of our own training run — never point this at untrusted files.
    logger.info(f"Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    base_model, lora_r, lora_alpha, freeze_layers = _resolve_architecture(
        checkpoint, args.base_model
    )

    logger.info(f"Base model: {base_model}")
    logger.info(f"LoRA rank: {lora_r}")
    logger.info(f"Freeze layers: {freeze_layers}")

    # Create model with same architecture as at training time.
    logger.info("Creating model architecture...")
    model, processor = create_model(
        base_model=base_model,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        freeze_layers=freeze_layers,
        use_gradient_checkpointing=False,  # Not needed for export
    )

    # Load trained weights
    logger.info("Loading trained weights...")
    model.load_state_dict(checkpoint["model_state_dict"])

    # Merge LoRA if requested (best-effort; falls back to separate weights)
    if args.merge_lora and model.peft_applied:
        _maybe_merge_lora(model, logger)

    # Create output directory
    output_path = Path(args.output)
    output_path.mkdir(parents=True, exist_ok=True)

    # Save model
    logger.info(f"Exporting to: {output_path}")
    model.save_pretrained(str(output_path))

    # Save processor config for reference
    processor.save_pretrained(str(output_path / "processor"))

    # lora_merged is True only when the merge was requested AND succeeded
    # (a failed merge leaves model.peft_applied set).
    _write_metadata(
        output_path,
        base_model,
        checkpoint_path,
        checkpoint,
        lora_merged=args.merge_lora and not model.peft_applied,
    )

    logger.info("\nExport complete!")
    logger.info(f"Model saved to: {output_path}")
    logger.info("\nTo use with DetectLogosDETR:")
    logger.info(f"  detector = DetectLogosDETR(embedding_model='{output_path}')")
    logger.info("\nOr with test_logo_detection.py:")
    logger.info(f"  uv run python test_logo_detection.py -e {output_path}")


if __name__ == "__main__":
    main()