Add CLIP fine-tuning pipeline for logo recognition
Implement contrastive learning with LoRA to fine-tune CLIP's vision encoder on LogoDet-3K dataset for improved logo embedding similarity. New training module (training/): - config.py: TrainingConfig dataclass with all hyperparameters - dataset.py: LogoContrastiveDataset with logo-level splits - model.py: LogoFineTunedCLIP wrapper with LoRA support - losses.py: InfoNCE, TripletLoss, SupConLoss implementations - trainer.py: Training loop with mixed precision and checkpointing - evaluation.py: EmbeddingEvaluator for validation metrics New scripts: - train_clip_logo.py: Main training entry point - export_model.py: Export to HuggingFace-compatible format Configurations: - configs/jetson_orin.yaml: Optimized for Jetson Orin AGX - configs/cloud_rtx4090.yaml: Optimized for 24GB cloud GPUs - configs/cloud_a100.yaml: Optimized for 80GB cloud GPUs Documentation: - CLIP_FINETUNING.md: Training guide and usage instructions - CLOUD_TRAINING.md: Cloud GPU recommendations and cost estimates Modified: - logo_detection_detr.py: Add fine-tuned model loading support - pyproject.toml: Add peft, pyyaml, torchvision dependencies
This commit is contained in:
169
export_model.py
Normal file
169
export_model.py
Normal file
@ -0,0 +1,169 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Export a trained CLIP model to HuggingFace-compatible format.
|
||||
|
||||
This script converts a training checkpoint to a format that can be
|
||||
loaded by DetectLogosDETR for inference.
|
||||
|
||||
Usage:
|
||||
uv run python export_model.py \
|
||||
--checkpoint checkpoints/best.pt \
|
||||
--output models/logo_detection/clip_finetuned
|
||||
|
||||
# With custom base model
|
||||
uv run python export_model.py \
|
||||
--checkpoint checkpoints/best.pt \
|
||||
--output models/logo_detection/clip_finetuned \
|
||||
--base-model openai/clip-vit-large-patch14
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from training.config import TrainingConfig
|
||||
from training.model import create_model, LogoFineTunedCLIP
|
||||
|
||||
|
||||
def setup_logging() -> logging.Logger:
    """Configure process-wide logging and return this module's logger.

    Uses INFO level with timestamped records so export progress is
    readable when the script runs from the command line.
    """
    log_format = "%(asctime)s [%(levelname)s] %(message)s"
    timestamp_format = "%Y-%m-%d %H:%M:%S"
    logging.basicConfig(level=logging.INFO, format=log_format, datefmt=timestamp_format)
    return logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse and return the export script's command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Export trained CLIP model for inference",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # (flag, add_argument kwargs) table keeps the CLI surface easy to scan.
    argument_specs = [
        (
            "--checkpoint",
            {
                "type": str,
                "required": True,
                "help": "Path to training checkpoint (.pt file)",
            },
        ),
        (
            "--output",
            {
                "type": str,
                "required": True,
                "help": "Output directory for exported model",
            },
        ),
        (
            "--base-model",
            {
                "type": str,
                "default": None,
                "help": "Base CLIP model (reads from checkpoint config if not specified)",
            },
        ),
        (
            "--merge-lora",
            {
                "action": "store_true",
                "help": "Merge LoRA weights into base model (reduces inference overhead)",
            },
        ),
    ]
    for flag, options in argument_specs:
        parser.add_argument(flag, **options)

    return parser.parse_args()
|
||||
|
||||
|
||||
# Fallback architecture values, used when a checkpoint predates the
# embedded training config.  Must match the training-time defaults in
# training/config.py — TODO confirm against TrainingConfig.
_DEFAULT_BASE_MODEL = "openai/clip-vit-large-patch14"
_DEFAULT_LORA_R = 16
_DEFAULT_LORA_ALPHA = 32
_DEFAULT_FREEZE_LAYERS = 12


def _resolve_architecture(checkpoint, cli_base_model):
    """Return (base_model, lora_r, lora_alpha, freeze_layers) for the checkpoint.

    Prefers the config dict stored in the checkpoint; a --base-model CLI
    override always wins; hard-coded defaults cover old checkpoints with
    no embedded config.
    """
    config_dict = checkpoint.get("config") or {}
    base_model = cli_base_model or config_dict.get("base_model", _DEFAULT_BASE_MODEL)
    lora_r = config_dict.get("lora_r", _DEFAULT_LORA_R)
    lora_alpha = config_dict.get("lora_alpha", _DEFAULT_LORA_ALPHA)
    freeze_layers = config_dict.get("freeze_layers", _DEFAULT_FREEZE_LAYERS)
    return base_model, lora_r, lora_alpha, freeze_layers


def _merge_lora_weights(model, logger):
    """Best-effort merge of LoRA adapters into the base vision tower.

    On success, clears the model's PEFT flags so downstream code treats it
    as a plain model.  On failure, warns and leaves the adapters separate
    rather than aborting the export.
    """
    try:
        logger.info("Merging LoRA weights into base model...")
        model.vision_model = model.vision_model.merge_and_unload()
        model.peft_applied = False
        model.lora_r = 0
        logger.info("LoRA weights merged successfully")
    except Exception as e:
        # Deliberate best-effort: an un-merged export is still usable.
        logger.warning(f"Could not merge LoRA weights: {e}")
        logger.warning("Exporting with separate LoRA weights")


def _write_metadata(output_path, checkpoint, checkpoint_path, base_model, lora_merged):
    """Write export provenance to export_metadata.json next to the weights."""
    metadata = {
        "base_model": base_model,
        "source_checkpoint": str(checkpoint_path),
        # "epoch" is stored zero-based; the -1 sentinel maps to 0 epochs.
        "training_epochs": checkpoint.get("epoch", -1) + 1,
        "best_val_loss": checkpoint.get("best_val_loss", None),
        "best_val_separation": checkpoint.get("best_val_separation", None),
        "lora_merged": lora_merged,
    }
    with open(output_path / "export_metadata.json", "w") as f:
        json.dump(metadata, f, indent=2)


def main():
    """Load a training checkpoint, rebuild the model, and export it.

    Steps: parse args, load the checkpoint on CPU, reconstruct the
    architecture from the embedded config (or defaults), load the trained
    weights, optionally merge LoRA adapters, then save model, processor,
    and provenance metadata to the output directory.  Exits with status 1
    if the checkpoint file does not exist.
    """
    args = parse_args()
    logger = setup_logging()

    logger.info("CLIP Model Export")
    logger.info("=" * 60)

    # Check checkpoint exists before doing any heavy work.
    checkpoint_path = Path(args.checkpoint)
    if not checkpoint_path.exists():
        logger.error(f"Checkpoint not found: {checkpoint_path}")
        sys.exit(1)

    logger.info(f"Loading checkpoint: {checkpoint_path}")
    # weights_only=False is required: the checkpoint carries a plain config
    # dict alongside the tensors, which PyTorch >= 2.6 rejects under its new
    # weights_only=True default.  torch.load unpickles arbitrary objects in
    # this mode — only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    base_model, lora_r, lora_alpha, freeze_layers = _resolve_architecture(
        checkpoint, args.base_model
    )
    logger.info(f"Base model: {base_model}")
    logger.info(f"LoRA rank: {lora_r}")
    logger.info(f"Freeze layers: {freeze_layers}")

    # Recreate the exact training-time architecture so the state dict fits.
    logger.info("Creating model architecture...")
    model, processor = create_model(
        base_model=base_model,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        freeze_layers=freeze_layers,
        use_gradient_checkpointing=False,  # Not needed for export
    )

    logger.info("Loading trained weights...")
    model.load_state_dict(checkpoint["model_state_dict"])

    # Merge LoRA if requested (skipped when no adapters are attached).
    if args.merge_lora and model.peft_applied:
        _merge_lora_weights(model, logger)

    output_path = Path(args.output)
    output_path.mkdir(parents=True, exist_ok=True)

    logger.info(f"Exporting to: {output_path}")
    model.save_pretrained(str(output_path))

    # Save processor config for reference alongside the model weights.
    processor.save_pretrained(str(output_path / "processor"))

    _write_metadata(
        output_path,
        checkpoint,
        checkpoint_path,
        base_model,
        # True only if a merge was requested AND actually succeeded.
        args.merge_lora and not model.peft_applied,
    )

    logger.info("\nExport complete!")
    logger.info(f"Model saved to: {output_path}")
    logger.info("\nTo use with DetectLogosDETR:")
    logger.info(f"  detector = DetectLogosDETR(embedding_model='{output_path}')")
    logger.info("\nOr with test_logo_detection.py:")
    logger.info(f"  uv run python test_logo_detection.py -e {output_path}")
|
||||
|
||||
|
||||
# Script entry point: run the export only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user