Add CLIP fine-tuning pipeline for logo recognition
Implement contrastive learning with LoRA to fine-tune CLIP's vision encoder
on the LogoDet-3K dataset for improved logo embedding similarity.

New training module (training/):
- config.py: TrainingConfig dataclass with all hyperparameters
- dataset.py: LogoContrastiveDataset with logo-level splits
- model.py: LogoFineTunedCLIP wrapper with LoRA support
- losses.py: InfoNCE, TripletLoss, SupConLoss implementations
- trainer.py: Training loop with mixed precision and checkpointing
- evaluation.py: EmbeddingEvaluator for validation metrics

New scripts:
- train_clip_logo.py: Main training entry point
- export_model.py: Export to HuggingFace-compatible format

Configurations:
- configs/jetson_orin.yaml: Optimized for Jetson Orin AGX
- configs/cloud_rtx4090.yaml: Optimized for 24GB cloud GPUs
- configs/cloud_a100.yaml: Optimized for 80GB cloud GPUs

Documentation:
- CLIP_FINETUNING.md: Training guide and usage instructions
- CLOUD_TRAINING.md: Cloud GPU recommendations and cost estimates

Modified:
- logo_detection_detr.py: Add fine-tuned model loading support
- pyproject.toml: Add peft, pyyaml, torchvision dependencies
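Note: this diff shows only the entry-point script; training/losses.py is referenced above but not included. For orientation, a temperature-scaled InfoNCE loss over paired logo embeddings might look like the sketch below. The class name, signature, and the symmetric CLIP-style formulation are assumptions, not the committed code:

import torch
import torch.nn.functional as F


class InfoNCELoss(torch.nn.Module):
    """Contrastive loss over a batch of embedding pairs (sketch, not the committed code)."""

    def __init__(self, temperature: float = 0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, anchors: torch.Tensor, positives: torch.Tensor) -> torch.Tensor:
        # Normalize so the dot product is cosine similarity.
        anchors = F.normalize(anchors, dim=-1)
        positives = F.normalize(positives, dim=-1)

        # Pairwise similarity matrix; the diagonal holds the matching pairs.
        logits = anchors @ positives.t() / self.temperature
        targets = torch.arange(logits.size(0), device=logits.device)

        # Symmetric cross-entropy, as in the original CLIP objective.
        return 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))

The diagonal of the similarity matrix holds the matching pairs, so the loss pulls views of the same logo together while pushing the other logos in the batch apart.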
train_clip_logo.py (new file, +309 lines)
#!/usr/bin/env python3
"""
Fine-tune CLIP vision encoder for logo recognition.

This script trains a CLIP model using contrastive learning on the LogoDet-3K
dataset to improve logo embedding quality for similarity-based matching.

Usage:
    # Train with YAML config
    uv run python train_clip_logo.py --config configs/jetson_orin.yaml

    # Train with command-line overrides
    uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
        --learning-rate 5e-6 --max-epochs 30

    # Resume from checkpoint
    uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
        --resume checkpoints/epoch_10.pt
"""

import argparse
import logging
import random
import sys
from pathlib import Path

import numpy as np
import torch

from training.config import TrainingConfig
from training.dataset import create_dataloaders
from training.model import create_model
from training.trainer import Trainer


def setup_logging(log_level: str = "INFO") -> logging.Logger:
    """Configure logging."""
    logging.basicConfig(
        level=getattr(logging, log_level.upper()),
        format="%(asctime)s [%(levelname)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    return logging.getLogger(__name__)


def set_seed(seed: int) -> None:
    """Set random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Fine-tune CLIP for logo recognition",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Config file
    parser.add_argument(
        "--config",
        type=str,
        help="Path to YAML configuration file",
    )

    # Dataset paths
    parser.add_argument(
        "--dataset-dir",
        type=str,
        help="Path to LogoDet-3K dataset",
    )
    parser.add_argument(
        "--reference-dir",
        type=str,
        help="Path to reference logos directory",
    )
    parser.add_argument(
        "--db-path",
        type=str,
        help="Path to SQLite database",
    )

    # Model
    parser.add_argument(
        "--base-model",
        type=str,
        help="Base CLIP model name or path",
    )
    parser.add_argument(
        "--lora-r",
        type=int,
        help="LoRA rank (0 to disable)",
    )
    parser.add_argument(
        "--freeze-layers",
        type=int,
        help="Number of transformer layers to freeze",
    )

    # Training
    parser.add_argument(
        "--batch-size",
        type=int,
        help="Batch size",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        help="Learning rate",
    )
    parser.add_argument(
        "--max-epochs",
        type=int,
        help="Maximum number of epochs",
    )
    parser.add_argument(
        "--gradient-accumulation-steps",
        type=int,
        help="Gradient accumulation steps",
    )

    # Loss
    parser.add_argument(
        "--temperature",
        type=float,
        help="Temperature for InfoNCE loss",
    )
    parser.add_argument(
        "--loss-type",
        choices=["infonce", "supcon", "triplet", "combined"],
        help="Loss function type",
    )

    # Checkpointing
    parser.add_argument(
        "--checkpoint-dir",
        type=str,
        help="Directory for checkpoints",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        help="Directory for final model output",
    )
    parser.add_argument(
        "--resume",
        type=str,
        help="Path to checkpoint to resume from",
    )

    # Other
    parser.add_argument(
        "--seed",
        type=int,
        help="Random seed",
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging level",
    )
    parser.add_argument(
        "--no-mixed-precision",
        action="store_true",
        help="Disable mixed precision training",
    )

    return parser.parse_args()


def main():
    """Main training entry point."""
    args = parse_args()

    # Setup logging
    logger = setup_logging(args.log_level)
    logger.info("CLIP Logo Fine-Tuning")
    logger.info("=" * 60)

    # Load or create configuration
    if args.config:
        logger.info(f"Loading config from: {args.config}")
        config = TrainingConfig.from_yaml(args.config)
    else:
        logger.info("Using default configuration")
        config = TrainingConfig()

    # Apply command-line overrides
    override_fields = [
        "dataset_dir", "reference_dir", "db_path", "base_model",
        "lora_r", "freeze_layers", "batch_size", "learning_rate",
        "max_epochs", "gradient_accumulation_steps", "temperature",
        "loss_type", "checkpoint_dir", "output_dir", "seed",
    ]
    for field in override_fields:
        # argparse stores "--lora-r" as args.lora_r, so the field name maps directly.
        arg_value = getattr(args, field, None)
        if arg_value is not None:
            setattr(config, field, arg_value)
            logger.info(f"Override: {field} = {arg_value}")

    if args.no_mixed_precision:
        config.mixed_precision = False
        logger.info("Override: mixed_precision = False")

    # Validate configuration
    warnings = config.validate()
    for warning in warnings:
        logger.warning(f"Config warning: {warning}")

    # Set random seed
    set_seed(config.seed)
    logger.info(f"Random seed: {config.seed}")

    # Check paths exist
    db_path = Path(config.db_path)
    ref_dir = Path(config.reference_dir)

    if not db_path.exists():
        logger.error(f"Database not found: {db_path}")
        logger.error("Run prepare_test_data.py first to create the database.")
        sys.exit(1)

    if not ref_dir.exists():
        logger.error(f"Reference directory not found: {ref_dir}")
        logger.error("Run prepare_test_data.py first to extract reference logos.")
        sys.exit(1)

    # Create model
    logger.info(f"Creating model from: {config.base_model}")
    model, processor = create_model(
        base_model=config.base_model,
        lora_r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        freeze_layers=config.freeze_layers,
        use_gradient_checkpointing=config.use_gradient_checkpointing,
    )

    # Create dataloaders
    logger.info("Creating dataloaders...")
    train_loader, val_loader, test_loader = create_dataloaders(
        db_path=str(config.db_path),
        reference_dir=str(config.reference_dir),
        batch_size=config.batch_size,
        logos_per_batch=config.logos_per_batch,
        samples_per_logo=config.samples_per_logo,
        num_workers=config.num_workers,
        train_split=config.train_split,
        val_split=config.val_split,
        test_split=config.test_split,
        seed=config.seed,
        augmentation_strength=config.augmentation_strength,
    )

    # Create trainer
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        logger=logger,
    )

    # Resume from checkpoint if specified
    if args.resume:
        resume_path = Path(args.resume)
        if resume_path.exists():
            logger.info(f"Resuming from: {resume_path}")
            # Set checkpoint dir to the resume path's parent
            if resume_path.is_file():
                config.checkpoint_dir = str(resume_path.parent)
            trainer.load_checkpoint(resume_path.name)
        else:
            logger.warning(f"Resume checkpoint not found: {resume_path}")

    # Train
    logger.info("\nStarting training...")
    final_metrics = trainer.train()

    logger.info("\nTraining complete!")
    logger.info(f"  Best val loss: {final_metrics['best_val_loss']:.4f}")
    logger.info(f"  Best separation: {final_metrics['best_val_separation']:.4f}")
    logger.info(f"  Total epochs: {final_metrics['total_epochs']}")
    logger.info(f"  Total time: {final_metrics['total_time_minutes']:.1f} minutes")

    # Export model
    output_path = trainer.export_model()
    logger.info(f"\nModel exported to: {output_path}")

    # Print next steps
    logger.info("\n" + "=" * 60)
    logger.info("Next steps:")
    logger.info("1. Test the fine-tuned model:")
    logger.info("   uv run python test_logo_detection.py -n 50 \\")
    logger.info(f"     -e {output_path} --matching-method multi-ref")
    logger.info("")
    logger.info("2. Compare with baseline:")
    logger.info("   uv run python test_logo_detection.py -n 50 \\")
    logger.info("     -e openai/clip-vit-large-patch14 --matching-method multi-ref")


if __name__ == "__main__":
    main()
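For context on the configuration plumbing: main() calls TrainingConfig.from_yaml(...) and config.validate(), neither of which appears in this diff. Below is a minimal sketch of what training/config.py could look like, assuming a flat YAML-to-dataclass mapping. The field names are taken from the script above; the defaults and validation heuristics are illustrative guesses:

# Sketch of training/config.py -- field names come from the script above;
# defaults and validation heuristics are illustrative assumptions.
from dataclasses import dataclass, fields

import yaml


@dataclass
class TrainingConfig:
    # Paths
    dataset_dir: str = "data/LogoDet-3K"
    reference_dir: str = "data/reference_logos"
    db_path: str = "data/logos.db"
    # Model
    base_model: str = "openai/clip-vit-large-patch14"
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.1
    freeze_layers: int = 0
    use_gradient_checkpointing: bool = False
    # Optimization
    batch_size: int = 32
    learning_rate: float = 1e-5
    max_epochs: int = 20
    gradient_accumulation_steps: int = 1
    mixed_precision: bool = True
    # Loss and batch composition
    loss_type: str = "infonce"
    temperature: float = 0.07
    logos_per_batch: int = 8
    samples_per_logo: int = 4
    # Data splits and loading
    train_split: float = 0.8
    val_split: float = 0.1
    test_split: float = 0.1
    num_workers: int = 4
    augmentation_strength: float = 0.5
    # Output
    checkpoint_dir: str = "checkpoints"
    output_dir: str = "output"
    seed: int = 42

    @classmethod
    def from_yaml(cls, path: str) -> "TrainingConfig":
        # Reject unknown keys so stale YAML fields fail loudly instead of silently.
        with open(path) as f:
            raw = yaml.safe_load(f) or {}
        known = {f.name for f in fields(cls)}
        unknown = set(raw) - known
        if unknown:
            raise ValueError(f"Unknown config keys: {sorted(unknown)}")
        return cls(**raw)

    def validate(self) -> list[str]:
        # Return warnings rather than raising, matching how main() consumes them.
        warnings = []
        if abs(self.train_split + self.val_split + self.test_split - 1.0) > 1e-6:
            warnings.append("train/val/test splits do not sum to 1.0")
        if self.lora_r == 0 and self.freeze_layers == 0:
            warnings.append("full fine-tuning without LoRA may overfit a small dataset")
        return warnings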
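Similarly, create_model(...) from training/model.py is only called here, never defined. A hedged sketch of how the LoRA wiring could work with peft and transformers, assuming adapters on the attention projections (the target_modules choice and the freezing logic are assumptions, not the committed code):

# Sketch of training/model.py's create_model -- only the signature is grounded
# in the script above; module targeting and freezing are assumptions.
from peft import LoraConfig, get_peft_model
from transformers import CLIPModel, CLIPProcessor


def create_model(
    base_model: str,
    lora_r: int,
    lora_alpha: int,
    lora_dropout: float,
    freeze_layers: int,
    use_gradient_checkpointing: bool,
):
    model = CLIPModel.from_pretrained(base_model)
    processor = CLIPProcessor.from_pretrained(base_model)

    if use_gradient_checkpointing:
        model.gradient_checkpointing_enable()

    # Freeze the earliest vision transformer layers, if requested.
    for layer in model.vision_model.encoder.layers[:freeze_layers]:
        for p in layer.parameters():
            p.requires_grad = False

    if lora_r > 0:
        # Inject low-rank adapters into the attention projections. Note this
        # name-based match hits both towers; restricting adapters to the vision
        # encoder would need a module-name regex instead.
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            target_modules=["q_proj", "v_proj"],
        )
        model = get_peft_model(model, lora_config)

    return model, processor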