Implement contrastive learning with LoRA to fine-tune CLIP's vision encoder on LogoDet-3K dataset for improved logo embedding similarity. New training module (training/): - config.py: TrainingConfig dataclass with all hyperparameters - dataset.py: LogoContrastiveDataset with logo-level splits - model.py: LogoFineTunedCLIP wrapper with LoRA support - losses.py: InfoNCE, TripletLoss, SupConLoss implementations - trainer.py: Training loop with mixed precision and checkpointing - evaluation.py: EmbeddingEvaluator for validation metrics New scripts: - train_clip_logo.py: Main training entry point - export_model.py: Export to HuggingFace-compatible format Configurations: - configs/jetson_orin.yaml: Optimized for Jetson Orin AGX - configs/cloud_rtx4090.yaml: Optimized for 24GB cloud GPUs - configs/cloud_a100.yaml: Optimized for 80GB cloud GPUs Documentation: - CLIP_FINETUNING.md: Training guide and usage instructions - CLOUD_TRAINING.md: Cloud GPU recommendations and cost estimates Modified: - logo_detection_detr.py: Add fine-tuned model loading support - pyproject.toml: Add peft, pyyaml, torchvision dependencies
170 lines
5.1 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Export a trained CLIP model to HuggingFace-compatible format.
|
|
|
|
This script converts a training checkpoint to a format that can be
|
|
loaded by DetectLogosDETR for inference.
|
|
|
|
Usage:
|
|
uv run python export_model.py \
|
|
--checkpoint checkpoints/best.pt \
|
|
--output models/logo_detection/clip_finetuned
|
|
|
|
# With custom base model
|
|
uv run python export_model.py \
|
|
--checkpoint checkpoints/best.pt \
|
|
--output models/logo_detection/clip_finetuned \
|
|
--base-model openai/clip-vit-large-patch14
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
|
|
from training.config import TrainingConfig
|
|
from training.model import create_model, LogoFineTunedCLIP
|
|
|
|
|
|
def setup_logging() -> logging.Logger:
    """Configure process-wide logging and return this module's logger.

    Uses a timestamped ``[LEVEL]`` line format at INFO level; repeated calls
    are harmless because ``basicConfig`` only applies once.
    """
    log_format = "%(asctime)s [%(levelname)s] %(message)s"
    date_format = "%Y-%m-%d %H:%M:%S"
    logging.basicConfig(level=logging.INFO, format=log_format, datefmt=date_format)
    return logging.getLogger(__name__)
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define and evaluate the command-line interface for model export."""
    cli = argparse.ArgumentParser(
        description="Export trained CLIP model for inference",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Required paths: where to read the checkpoint, where to write the export.
    cli.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to training checkpoint (.pt file)",
    )
    cli.add_argument(
        "--output",
        type=str,
        required=True,
        help="Output directory for exported model",
    )

    # Optional knobs.
    cli.add_argument(
        "--base-model",
        type=str,
        default=None,
        help="Base CLIP model (reads from checkpoint config if not specified)",
    )
    cli.add_argument(
        "--merge-lora",
        action="store_true",
        help="Merge LoRA weights into base model (reduces inference overhead)",
    )

    return cli.parse_args()
|
|
|
|
|
|
def main():
|
|
args = parse_args()
|
|
logger = setup_logging()
|
|
|
|
logger.info("CLIP Model Export")
|
|
logger.info("=" * 60)
|
|
|
|
# Check checkpoint exists
|
|
checkpoint_path = Path(args.checkpoint)
|
|
if not checkpoint_path.exists():
|
|
logger.error(f"Checkpoint not found: {checkpoint_path}")
|
|
sys.exit(1)
|
|
|
|
# Load checkpoint
|
|
logger.info(f"Loading checkpoint: {checkpoint_path}")
|
|
checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
|
|
|
# Get config from checkpoint
|
|
if "config" in checkpoint:
|
|
config_dict = checkpoint["config"]
|
|
base_model = args.base_model or config_dict.get(
|
|
"base_model", "openai/clip-vit-large-patch14"
|
|
)
|
|
lora_r = config_dict.get("lora_r", 16)
|
|
lora_alpha = config_dict.get("lora_alpha", 32)
|
|
freeze_layers = config_dict.get("freeze_layers", 12)
|
|
else:
|
|
base_model = args.base_model or "openai/clip-vit-large-patch14"
|
|
lora_r = 16
|
|
lora_alpha = 32
|
|
freeze_layers = 12
|
|
|
|
logger.info(f"Base model: {base_model}")
|
|
logger.info(f"LoRA rank: {lora_r}")
|
|
logger.info(f"Freeze layers: {freeze_layers}")
|
|
|
|
# Create model with same architecture
|
|
logger.info("Creating model architecture...")
|
|
model, processor = create_model(
|
|
base_model=base_model,
|
|
lora_r=lora_r,
|
|
lora_alpha=lora_alpha,
|
|
freeze_layers=freeze_layers,
|
|
use_gradient_checkpointing=False, # Not needed for export
|
|
)
|
|
|
|
# Load weights
|
|
logger.info("Loading trained weights...")
|
|
model.load_state_dict(checkpoint["model_state_dict"])
|
|
|
|
# Merge LoRA if requested
|
|
if args.merge_lora and model.peft_applied:
|
|
try:
|
|
logger.info("Merging LoRA weights into base model...")
|
|
model.vision_model = model.vision_model.merge_and_unload()
|
|
model.peft_applied = False
|
|
model.lora_r = 0
|
|
logger.info("LoRA weights merged successfully")
|
|
except Exception as e:
|
|
logger.warning(f"Could not merge LoRA weights: {e}")
|
|
logger.warning("Exporting with separate LoRA weights")
|
|
|
|
# Create output directory
|
|
output_path = Path(args.output)
|
|
output_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Save model
|
|
logger.info(f"Exporting to: {output_path}")
|
|
model.save_pretrained(str(output_path))
|
|
|
|
# Save processor config for reference
|
|
processor.save_pretrained(str(output_path / "processor"))
|
|
|
|
# Save additional metadata
|
|
metadata = {
|
|
"base_model": base_model,
|
|
"source_checkpoint": str(checkpoint_path),
|
|
"training_epochs": checkpoint.get("epoch", -1) + 1,
|
|
"best_val_loss": checkpoint.get("best_val_loss", None),
|
|
"best_val_separation": checkpoint.get("best_val_separation", None),
|
|
"lora_merged": args.merge_lora and not model.peft_applied,
|
|
}
|
|
|
|
with open(output_path / "export_metadata.json", "w") as f:
|
|
json.dump(metadata, f, indent=2)
|
|
|
|
logger.info("\nExport complete!")
|
|
logger.info(f"Model saved to: {output_path}")
|
|
logger.info("\nTo use with DetectLogosDETR:")
|
|
logger.info(f" detector = DetectLogosDETR(embedding_model='{output_path}')")
|
|
logger.info("\nOr with test_logo_detection.py:")
|
|
logger.info(f" uv run python test_logo_detection.py -e {output_path}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|