#!/usr/bin/env python3
r"""
Fine-tune CLIP vision encoder for logo recognition.

This script trains a CLIP model using contrastive learning on the LogoDet-3K
dataset to improve logo embedding quality for similarity-based matching.

Usage:
    # Train with YAML config
    uv run python train_clip_logo.py --config configs/jetson_orin.yaml

    # Train with command-line overrides
    uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
        --learning-rate 5e-6 --max-epochs 30

    # Resume from checkpoint
    uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
        --resume checkpoints/epoch_10.pt
"""

import argparse
import logging
import random
import sys
from pathlib import Path

import numpy as np
import torch

from training.config import TrainingConfig
from training.dataset import create_dataloaders
from training.model import create_model
from training.trainer import Trainer


def setup_logging(log_level: str = "INFO") -> logging.Logger:
    """Configure logging."""
    logging.basicConfig(
        level=getattr(logging, log_level.upper()),
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    return logging.getLogger(__name__)


def set_seed(seed: int) -> None:
    """Set random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Fine-tune CLIP for logo recognition",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Config file
    parser.add_argument(
        "--config",
        type=str,
        help="Path to YAML configuration file",
    )

    # Dataset paths
    parser.add_argument(
        "--dataset-dir",
        type=str,
        help="Path to LogoDet-3K dataset",
    )
    parser.add_argument(
        "--reference-dir",
        type=str,
        help="Path to reference logos directory",
    )
    parser.add_argument(
        "--db-path",
        type=str,
        help="Path to SQLite database",
    )

    # Model
    parser.add_argument(
        "--base-model",
        type=str,
        help="Base CLIP model name or path",
    )
    parser.add_argument(
        "--lora-r",
        type=int,
        help="LoRA rank (0 to disable)",
    )
    parser.add_argument(
        "--freeze-layers",
        type=int,
        help="Number of transformer layers to freeze",
    )

    # Training
    parser.add_argument(
        "--batch-size",
        type=int,
        help="Batch size",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        help="Learning rate",
    )
    parser.add_argument(
        "--max-epochs",
        type=int,
        help="Maximum number of epochs",
    )
    parser.add_argument(
        "--gradient-accumulation-steps",
        type=int,
        help="Gradient accumulation steps",
    )

    # Loss
    parser.add_argument(
        "--temperature",
        type=float,
        help="Temperature for InfoNCE loss",
    )
    parser.add_argument(
        "--loss-type",
        choices=["infonce", "supcon", "triplet", "combined"],
        help="Loss function type",
    )

    # Checkpointing
    parser.add_argument(
        "--checkpoint-dir",
        type=str,
        help="Directory for checkpoints",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        help="Directory for final model output",
    )
    parser.add_argument(
        "--resume",
        type=str,
        help="Path to checkpoint to resume from",
    )

    # Other
    parser.add_argument(
        "--seed",
        type=int,
        help="Random seed",
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging level",
    )
    parser.add_argument(
        "--no-mixed-precision",
        action="store_true",
        help="Disable mixed precision training",
    )

    return parser.parse_args()


def main():
    """Main training entry point."""
    args = parse_args()

    # Setup logging
    logger = setup_logging(args.log_level)
    logger.info("CLIP Logo Fine-Tuning")
    logger.info("=" * 60)

    # Load or create configuration
    if args.config:
        logger.info(f"Loading config from: {args.config}")
        config = TrainingConfig.from_yaml(args.config)
    else:
        logger.info("Using default configuration")
        config = TrainingConfig()

    # Apply command-line overrides (argparse stores --foo-bar as args.foo_bar,
    # so the config field name doubles as the args attribute name)
    override_fields = [
        "dataset_dir",
        "reference_dir",
        "db_path",
        "base_model",
        "lora_r",
        "freeze_layers",
        "batch_size",
        "learning_rate",
        "max_epochs",
        "gradient_accumulation_steps",
        "temperature",
        "loss_type",
        "checkpoint_dir",
        "output_dir",
        "seed",
    ]
    for field in override_fields:
        arg_value = getattr(args, field, None)
        if arg_value is not None:
            setattr(config, field, arg_value)
            logger.info(f"Override: {field} = {arg_value}")

    if args.no_mixed_precision:
        config.mixed_precision = False
        logger.info("Override: mixed_precision = False")

    # Validate configuration
    warnings = config.validate()
    for warning in warnings:
        logger.warning(f"Config warning: {warning}")

    # Set random seed
    set_seed(config.seed)
    logger.info(f"Random seed: {config.seed}")

    # Check paths exist
    db_path = Path(config.db_path)
    ref_dir = Path(config.reference_dir)
    if not db_path.exists():
        logger.error(f"Database not found: {db_path}")
        logger.error("Run prepare_test_data.py first to create the database.")
        sys.exit(1)
    if not ref_dir.exists():
        logger.error(f"Reference directory not found: {ref_dir}")
        logger.error("Run prepare_test_data.py first to extract reference logos.")
        sys.exit(1)

    # Create model
    logger.info(f"Creating model from: {config.base_model}")
    model, processor = create_model(
        base_model=config.base_model,
        lora_r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        freeze_layers=config.freeze_layers,
        use_gradient_checkpointing=config.use_gradient_checkpointing,
    )

    # Create dataloaders (test_loader is unused here; final evaluation
    # runs separately via test_logo_detection.py)
    logger.info("Creating dataloaders...")
    train_loader, val_loader, test_loader = create_dataloaders(
        db_path=str(config.db_path),
        reference_dir=str(config.reference_dir),
        batch_size=config.batch_size,
        logos_per_batch=config.logos_per_batch,
        samples_per_logo=config.samples_per_logo,
        num_workers=config.num_workers,
        train_split=config.train_split,
        val_split=config.val_split,
        test_split=config.test_split,
        seed=config.seed,
        augmentation_strength=config.augmentation_strength,
    )

    # Create trainer
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        logger=logger,
    )

    # Resume from checkpoint if specified
    if args.resume:
        resume_path = Path(args.resume)
        if resume_path.exists():
            logger.info(f"Resuming from: {resume_path}")
            # Set checkpoint dir to resume path's parent
            if resume_path.is_file():
                config.checkpoint_dir = str(resume_path.parent)
                trainer.load_checkpoint(resume_path.name)
        else:
            logger.warning(f"Resume checkpoint not found: {resume_path}")

    # Train
    logger.info("\nStarting training...")
    final_metrics = trainer.train()

    logger.info("\nTraining complete!")
    logger.info(f"  Best val loss: {final_metrics['best_val_loss']:.4f}")
    logger.info(f"  Best separation: {final_metrics['best_val_separation']:.4f}")
    logger.info(f"  Total epochs: {final_metrics['total_epochs']}")
    logger.info(f"  Total time: {final_metrics['total_time_minutes']:.1f} minutes")

    # Export model
    output_path = trainer.export_model()
    logger.info(f"\nModel exported to: {output_path}")

    # Print next steps
    logger.info("\n" + "=" * 60)
    logger.info("Next steps:")
    logger.info("1. Test the fine-tuned model:")
    logger.info("   uv run python test_logo_detection.py -n 50 \\")
    logger.info(f"   -e {output_path} --matching-method multi-ref")
    logger.info("")
    logger.info("2. Compare with baseline:")
    logger.info("   uv run python test_logo_detection.py -n 50 \\")
    logger.info("   -e openai/clip-vit-large-patch14 --matching-method multi-ref")


if __name__ == "__main__":
    main()
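

# ---------------------------------------------------------------------------
# For reference, a minimal YAML config sketch. The field names below mirror
# the TrainingConfig attributes this script reads; the values are illustrative
# placeholders, not the tuned settings from configs/jetson_orin.yaml.
#
#   dataset_dir: data/LogoDet-3K
#   reference_dir: data/reference_logos
#   db_path: data/logos.db
#   base_model: openai/clip-vit-large-patch14
#   lora_r: 16                  # 0 disables LoRA
#   lora_alpha: 32
#   lora_dropout: 0.1
#   freeze_layers: 0
#   use_gradient_checkpointing: true
#   batch_size: 32              # logos_per_batch * samples_per_logo
#   logos_per_batch: 8
#   samples_per_logo: 4
#   num_workers: 4
#   train_split: 0.8
#   val_split: 0.1
#   test_split: 0.1
#   learning_rate: 1.0e-5
#   max_epochs: 20
#   gradient_accumulation_steps: 1
#   temperature: 0.07
#   loss_type: infonce          # one of: infonce, supcon, triplet, combined
#   mixed_precision: true
#   augmentation_strength: 0.5
#   checkpoint_dir: checkpoints
#   output_dir: models/clip-logo
#   seed: 42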