Implement contrastive learning with LoRA to fine-tune CLIP's vision encoder on LogoDet-3K dataset for improved logo embedding similarity. New training module (training/): - config.py: TrainingConfig dataclass with all hyperparameters - dataset.py: LogoContrastiveDataset with logo-level splits - model.py: LogoFineTunedCLIP wrapper with LoRA support - losses.py: InfoNCE, TripletLoss, SupConLoss implementations - trainer.py: Training loop with mixed precision and checkpointing - evaluation.py: EmbeddingEvaluator for validation metrics New scripts: - train_clip_logo.py: Main training entry point - export_model.py: Export to HuggingFace-compatible format Configurations: - configs/jetson_orin.yaml: Optimized for Jetson Orin AGX - configs/cloud_rtx4090.yaml: Optimized for 24GB cloud GPUs - configs/cloud_a100.yaml: Optimized for 80GB cloud GPUs Documentation: - CLIP_FINETUNING.md: Training guide and usage instructions - CLOUD_TRAINING.md: Cloud GPU recommendations and cost estimates Modified: - logo_detection_detr.py: Add fine-tuned model loading support - pyproject.toml: Add peft, pyyaml, torchvision dependencies
310 lines
8.7 KiB
Python
310 lines
8.7 KiB
Python
#!/usr/bin/env python3
"""
Fine-tune CLIP vision encoder for logo recognition.

This script trains a CLIP model using contrastive learning on the LogoDet-3K
dataset to improve logo embedding quality for similarity-based matching.

Usage:
    # Train with YAML config
    uv run python train_clip_logo.py --config configs/jetson_orin.yaml

    # Train with command-line overrides
    uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
        --learning-rate 5e-6 --max-epochs 30

    # Resume from checkpoint
    uv run python train_clip_logo.py --config configs/jetson_orin.yaml \
        --resume checkpoints/epoch_10.pt
"""
|
|
|
|
import argparse
|
|
import logging
|
|
import random
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from training.config import TrainingConfig
|
|
from training.dataset import create_dataloaders
|
|
from training.model import create_model
|
|
from training.trainer import Trainer
|
|
|
|
|
|
def setup_logging(log_level: str = "INFO") -> logging.Logger:
    """Configure root logging and return a logger for this module.

    Args:
        log_level: Case-insensitive level name (e.g. "info", "DEBUG").

    Returns:
        A logger named after this module.
    """
    level = getattr(logging, log_level.upper())
    log_format = "%(asctime)s [%(levelname)s] %(message)s"
    logging.basicConfig(level=level, format=log_format, datefmt="%Y-%m-%d %H:%M:%S")
    return logging.getLogger(__name__)
|
|
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed every RNG source (Python, NumPy, PyTorch) for reproducibility."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # CUDA generators are separate from the CPU one; seed them when present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
|
|
"""Parse command-line arguments."""
|
|
parser = argparse.ArgumentParser(
|
|
description="Fine-tune CLIP for logo recognition",
|
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
|
)
|
|
|
|
# Config file
|
|
parser.add_argument(
|
|
"--config",
|
|
type=str,
|
|
help="Path to YAML configuration file",
|
|
)
|
|
|
|
# Dataset paths
|
|
parser.add_argument(
|
|
"--dataset-dir",
|
|
type=str,
|
|
help="Path to LogoDet-3K dataset",
|
|
)
|
|
parser.add_argument(
|
|
"--reference-dir",
|
|
type=str,
|
|
help="Path to reference logos directory",
|
|
)
|
|
parser.add_argument(
|
|
"--db-path",
|
|
type=str,
|
|
help="Path to SQLite database",
|
|
)
|
|
|
|
# Model
|
|
parser.add_argument(
|
|
"--base-model",
|
|
type=str,
|
|
help="Base CLIP model name or path",
|
|
)
|
|
parser.add_argument(
|
|
"--lora-r",
|
|
type=int,
|
|
help="LoRA rank (0 to disable)",
|
|
)
|
|
parser.add_argument(
|
|
"--freeze-layers",
|
|
type=int,
|
|
help="Number of transformer layers to freeze",
|
|
)
|
|
|
|
# Training
|
|
parser.add_argument(
|
|
"--batch-size",
|
|
type=int,
|
|
help="Batch size",
|
|
)
|
|
parser.add_argument(
|
|
"--learning-rate",
|
|
type=float,
|
|
help="Learning rate",
|
|
)
|
|
parser.add_argument(
|
|
"--max-epochs",
|
|
type=int,
|
|
help="Maximum number of epochs",
|
|
)
|
|
parser.add_argument(
|
|
"--gradient-accumulation-steps",
|
|
type=int,
|
|
help="Gradient accumulation steps",
|
|
)
|
|
|
|
# Loss
|
|
parser.add_argument(
|
|
"--temperature",
|
|
type=float,
|
|
help="Temperature for InfoNCE loss",
|
|
)
|
|
parser.add_argument(
|
|
"--loss-type",
|
|
choices=["infonce", "supcon", "triplet", "combined"],
|
|
help="Loss function type",
|
|
)
|
|
|
|
# Checkpointing
|
|
parser.add_argument(
|
|
"--checkpoint-dir",
|
|
type=str,
|
|
help="Directory for checkpoints",
|
|
)
|
|
parser.add_argument(
|
|
"--output-dir",
|
|
type=str,
|
|
help="Directory for final model output",
|
|
)
|
|
parser.add_argument(
|
|
"--resume",
|
|
type=str,
|
|
help="Path to checkpoint to resume from",
|
|
)
|
|
|
|
# Other
|
|
parser.add_argument(
|
|
"--seed",
|
|
type=int,
|
|
help="Random seed",
|
|
)
|
|
parser.add_argument(
|
|
"--log-level",
|
|
type=str,
|
|
default="INFO",
|
|
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
|
help="Logging level",
|
|
)
|
|
parser.add_argument(
|
|
"--no-mixed-precision",
|
|
action="store_true",
|
|
help="Disable mixed precision training",
|
|
)
|
|
|
|
return parser.parse_args()
|
|
|
|
|
|
def main():
    """Main training entry point.

    Workflow: parse CLI args -> load/override config -> validate -> seed RNGs
    -> verify prepared data exists -> build model, dataloaders and trainer ->
    optionally resume from a checkpoint -> train -> export the final model.
    """
    args = parse_args()

    # Setup logging
    logger = setup_logging(args.log_level)
    logger.info("CLIP Logo Fine-Tuning")
    logger.info("=" * 60)

    # Load config from YAML when given, otherwise fall back to defaults.
    if args.config:
        logger.info("Loading config from: %s", args.config)
        config = TrainingConfig.from_yaml(args.config)
    else:
        logger.info("Using default configuration")
        config = TrainingConfig()

    # Command-line flags override the config file. Only flags the user
    # actually supplied are applied: unset argparse options are None.
    override_fields = [
        "dataset_dir", "reference_dir", "db_path", "base_model",
        "lora_r", "freeze_layers", "batch_size", "learning_rate",
        "max_epochs", "gradient_accumulation_steps", "temperature",
        "loss_type", "checkpoint_dir", "output_dir", "seed",
    ]
    for field in override_fields:
        # argparse stores "--batch-size" as attribute "batch_size", so the
        # config field name is already the correct attribute name.
        arg_value = getattr(args, field, None)
        if arg_value is not None:
            setattr(config, field, arg_value)
            logger.info("Override: %s = %s", field, arg_value)

    # store_true flag: only meaningful when set, so handled separately.
    if args.no_mixed_precision:
        config.mixed_precision = False
        logger.info("Override: mixed_precision = False")

    # Surface configuration problems before spending GPU time.
    for warning in config.validate():
        logger.warning("Config warning: %s", warning)

    # Set random seed
    set_seed(config.seed)
    logger.info("Random seed: %s", config.seed)

    # Fail fast if the prepared dataset artifacts are missing.
    db_path = Path(config.db_path)
    ref_dir = Path(config.reference_dir)

    if not db_path.exists():
        logger.error("Database not found: %s", db_path)
        logger.error("Run prepare_test_data.py first to create the database.")
        sys.exit(1)

    if not ref_dir.exists():
        logger.error("Reference directory not found: %s", ref_dir)
        logger.error("Run prepare_test_data.py first to extract reference logos.")
        sys.exit(1)

    # Create model
    logger.info("Creating model from: %s", config.base_model)
    model, processor = create_model(
        base_model=config.base_model,
        lora_r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        freeze_layers=config.freeze_layers,
        use_gradient_checkpointing=config.use_gradient_checkpointing,
    )

    # Create dataloaders
    logger.info("Creating dataloaders...")
    train_loader, val_loader, test_loader = create_dataloaders(
        db_path=str(config.db_path),
        reference_dir=str(config.reference_dir),
        batch_size=config.batch_size,
        logos_per_batch=config.logos_per_batch,
        samples_per_logo=config.samples_per_logo,
        num_workers=config.num_workers,
        train_split=config.train_split,
        val_split=config.val_split,
        test_split=config.test_split,
        seed=config.seed,
        augmentation_strength=config.augmentation_strength,
    )

    # Create trainer
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
        logger=logger,
    )

    # Resume from checkpoint if specified
    if args.resume:
        resume_path = Path(args.resume)
        if resume_path.exists():
            logger.info("Resuming from: %s", resume_path)
            # Trainer resolves checkpoint names relative to checkpoint_dir,
            # so point it at the resume file's parent directory first.
            if resume_path.is_file():
                config.checkpoint_dir = str(resume_path.parent)
            trainer.load_checkpoint(resume_path.name)
        else:
            logger.warning("Resume checkpoint not found: %s", resume_path)

    # Train
    logger.info("\nStarting training...")
    final_metrics = trainer.train()

    logger.info("\nTraining complete!")
    logger.info(" Best val loss: %.4f", final_metrics["best_val_loss"])
    logger.info(" Best separation: %.4f", final_metrics["best_val_separation"])
    logger.info(" Total epochs: %s", final_metrics["total_epochs"])
    logger.info(" Total time: %.1f minutes", final_metrics["total_time_minutes"])

    # Export model
    output_path = trainer.export_model()
    logger.info("\nModel exported to: %s", output_path)

    # Print next steps
    logger.info("\n" + "=" * 60)
    logger.info("Next steps:")
    logger.info("1. Test the fine-tuned model:")
    logger.info(" uv run python test_logo_detection.py -n 50 \\")
    logger.info(" -e %s --matching-method multi-ref", output_path)
    logger.info("")
    logger.info("2. Compare with baseline:")
    logger.info(" uv run python test_logo_detection.py -n 50 \\")
    logger.info(" -e openai/clip-vit-large-patch14 --matching-method multi-ref")
|
|
# Script entry point: delegate to main() so importing this module has no
# side effects.
if __name__ == "__main__":
    main()
|