Add CLIP fine-tuning pipeline for logo recognition
Implement contrastive learning with LoRA to fine-tune CLIP's vision encoder on LogoDet-3K dataset for improved logo embedding similarity. New training module (training/): - config.py: TrainingConfig dataclass with all hyperparameters - dataset.py: LogoContrastiveDataset with logo-level splits - model.py: LogoFineTunedCLIP wrapper with LoRA support - losses.py: InfoNCE, TripletLoss, SupConLoss implementations - trainer.py: Training loop with mixed precision and checkpointing - evaluation.py: EmbeddingEvaluator for validation metrics New scripts: - train_clip_logo.py: Main training entry point - export_model.py: Export to HuggingFace-compatible format Configurations: - configs/jetson_orin.yaml: Optimized for Jetson Orin AGX - configs/cloud_rtx4090.yaml: Optimized for 24GB cloud GPUs - configs/cloud_a100.yaml: Optimized for 80GB cloud GPUs Documentation: - CLIP_FINETUNING.md: Training guide and usage instructions - CLOUD_TRAINING.md: Cloud GPU recommendations and cost estimates Modified: - logo_detection_detr.py: Add fine-tuned model loading support - pyproject.toml: Add peft, pyyaml, torchvision dependencies
This commit is contained in:
335
training/model.py
Normal file
335
training/model.py
Normal file
@ -0,0 +1,335 @@
|
||||
"""
|
||||
Fine-tunable CLIP model wrapper with LoRA support.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
# Check if peft is available for LoRA
|
||||
try:
    from peft import LoraConfig, PeftModel, get_peft_model
except ImportError:
    # peft is optional: without it, LoRA fine-tuning is disabled and the
    # names are stubbed out so attribute access stays well-defined.
    PEFT_AVAILABLE = False
    LoraConfig = get_peft_model = PeftModel = None
else:
    PEFT_AVAILABLE = True
|
||||
|
||||
|
||||
class LogoFineTunedCLIP(nn.Module):
    """
    CLIP vision encoder fine-tuned for logo similarity.

    Preserves the embedding interface expected by DetectLogosDETR:
    - Same embedding dimensionality (768 for ViT-L/14)
    - L2-normalized outputs
    - Works with the existing get_image_features() pattern

    Supports:
    - LoRA adapters for memory-efficient fine-tuning (requires peft)
    - Freezing lower transformer layers for transfer learning
    - Gradient checkpointing for memory optimization
    """

    def __init__(
        self,
        vision_model: nn.Module,
        lora_r: int = 16,
        lora_alpha: int = 32,
        lora_dropout: float = 0.1,
        freeze_layers: int = 12,
        use_gradient_checkpointing: bool = True,
        add_projection_head: bool = True,
    ):
        """
        Initialize the fine-tunable CLIP wrapper.

        Args:
            vision_model: CLIP vision model (CLIPVisionModel)
            lora_r: Rank of LoRA low-rank matrices (0 to disable)
            lora_alpha: LoRA scaling factor
            lora_dropout: Dropout for LoRA layers
            freeze_layers: Number of transformer layers to freeze (from bottom)
            use_gradient_checkpointing: Enable gradient checkpointing
            add_projection_head: Add trainable projection head
        """
        super().__init__()

        self.vision_model = vision_model
        self.embedding_dim = vision_model.config.hidden_size
        self.freeze_layers = freeze_layers
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha

        # Enable gradient checkpointing for memory efficiency (no-op when
        # the backbone does not expose the hook).
        if use_gradient_checkpointing:
            if hasattr(self.vision_model, "gradient_checkpointing_enable"):
                self.vision_model.gradient_checkpointing_enable()

        # Freeze embeddings and the bottom `freeze_layers` encoder layers.
        self._freeze_layers(freeze_layers)

        # Apply LoRA to attention projections in the upper (unfrozen) blocks.
        self.peft_applied = False
        if PEFT_AVAILABLE and lora_r > 0:
            self._apply_lora(lora_r, lora_alpha, lora_dropout)
            self.peft_applied = True
        elif lora_r > 0 and not PEFT_AVAILABLE:
            print(
                "Warning: peft not installed. LoRA disabled. "
                "Install with: pip install peft"
            )

        # Optional trainable projection head on top of the (mostly frozen) trunk.
        self.add_projection_head = add_projection_head
        if add_projection_head:
            self.projection = nn.Sequential(
                nn.Linear(self.embedding_dim, self.embedding_dim),
                nn.LayerNorm(self.embedding_dim),
            )
        else:
            self.projection = nn.Identity()

    def _freeze_layers(self, num_layers: int) -> None:
        """Freeze the first `num_layers` transformer layers and the embeddings."""
        if num_layers <= 0:
            return

        # Patch/position embeddings never need gradients in transfer mode.
        if hasattr(self.vision_model, "embeddings"):
            for param in self.vision_model.embeddings.parameters():
                param.requires_grad = False

        # Freeze the specified number of encoder layers, counted from the bottom.
        if hasattr(self.vision_model, "encoder"):
            for i, layer in enumerate(self.vision_model.encoder.layers):
                if i < num_layers:
                    for param in layer.parameters():
                        param.requires_grad = False

    def _apply_lora(
        self,
        r: int,
        alpha: int,
        dropout: float,
    ) -> None:
        """Apply LoRA adapters to the q/v attention projections."""
        if not PEFT_AVAILABLE:
            return

        # Configure LoRA for the vision transformer.
        lora_config = LoraConfig(
            r=r,
            lora_alpha=alpha,
            lora_dropout=dropout,
            target_modules=["q_proj", "v_proj"],
            bias="none",
            modules_to_save=[],  # Don't save any full modules
        )

        self.vision_model = get_peft_model(self.vision_model, lora_config)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Extract normalized embeddings for logo images.

        Args:
            pixel_values: [batch, 3, 224, 224] preprocessed images

        Returns:
            embeddings: [batch, embedding_dim] L2-normalized
        """
        outputs = self.vision_model(pixel_values=pixel_values)

        # Prefer the pooled CLS representation; fall back to the raw CLS
        # token when the backbone exposes no pooler.
        if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None:
            features = outputs.pooler_output
        else:
            features = outputs.last_hidden_state[:, 0, :]

        features = self.projection(features)

        # L2-normalize so downstream cosine similarity is a plain dot product.
        features = F.normalize(features, dim=-1)

        return features

    def get_image_features(self, **kwargs) -> torch.Tensor:
        """
        Compatibility method matching CLIP's interface.

        Used by DetectLogosDETR._get_embedding_pil(); expects a
        `pixel_values` keyword argument.
        """
        return self.forward(kwargs["pixel_values"])

    def get_trainable_parameters(self) -> List[torch.nn.Parameter]:
        """Return the list of parameters with requires_grad=True."""
        return [p for p in self.parameters() if p.requires_grad]

    def get_parameter_count(self) -> Dict[str, int]:
        """Return counts of trainable, frozen, and total parameters."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {
            "total": total,
            "trainable": trainable,
            "frozen": total - trainable,
            "trainable_percent": 100 * trainable / total if total > 0 else 0,
        }

    def save_pretrained(self, output_dir: str) -> None:
        """
        Save model in HuggingFace-compatible format.

        Args:
            output_dir: Directory to save model files
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        if self.peft_applied and PEFT_AVAILABLE:
            # LoRA adapters and projection head are saved separately so the
            # large frozen base model is never duplicated on disk.
            self.vision_model.save_pretrained(output_path / "vision_lora")
            torch.save(
                self.projection.state_dict(),
                output_path / "projection_head.bin",
            )
        else:
            # Save full model state.
            torch.save(self.state_dict(), output_path / "pytorch_model.bin")

        # Persist the hyperparameters from_pretrained() needs to rebuild
        # an architecturally identical wrapper.
        config = {
            "model_type": "clip_logo_finetuned",
            "embedding_dim": self.embedding_dim,
            "lora_r": self.lora_r,
            "lora_alpha": self.lora_alpha,
            "freeze_layers": self.freeze_layers,
            "add_projection_head": self.add_projection_head,
            "peft_applied": self.peft_applied,
        }

        with open(output_path / "config.json", "w") as f:
            json.dump(config, f, indent=2)

    @classmethod
    def from_pretrained(
        cls,
        model_path: str,
        base_model: str = "openai/clip-vit-large-patch14",
        device: Optional[torch.device] = None,
    ) -> "LogoFineTunedCLIP":
        """
        Load a fine-tuned model from saved weights.

        Args:
            model_path: Path to saved model directory
            base_model: Base CLIP model name (for architecture)
            device: Device to load model on

        Returns:
            Loaded LogoFineTunedCLIP model
        """
        model_path = Path(model_path)

        # Load config.
        with open(model_path / "config.json", "r") as f:
            config = json.load(f)

        # Load base CLIP model for the architecture.
        clip_model = CLIPModel.from_pretrained(base_model)

        peft_applied = config.get("peft_applied", False) and PEFT_AVAILABLE

        # BUGFIX: when trained LoRA weights exist, the wrapper must be built
        # WITHOUT fresh adapters (lora_r=0). Instantiating with lora_r > 0
        # would apply new, untrained adapters in __init__ and then
        # PeftModel.from_pretrained below would wrap the backbone a second
        # time, corrupting the loaded model.
        model = cls(
            vision_model=clip_model.vision_model,
            lora_r=0 if peft_applied else config.get("lora_r", 0),
            lora_alpha=config.get("lora_alpha", 1),
            freeze_layers=config.get("freeze_layers", 12),
            add_projection_head=config.get("add_projection_head", True),
            use_gradient_checkpointing=False,  # Not needed for inference
        )

        if peft_applied:
            # Attach the trained LoRA adapters.
            lora_path = model_path / "vision_lora"
            if lora_path.exists():
                model.vision_model = PeftModel.from_pretrained(
                    model.vision_model, lora_path
                )
                model.peft_applied = True
            # Restore recorded LoRA hyperparameters so a later
            # save_pretrained() round-trips the config faithfully.
            model.lora_r = config.get("lora_r", 0)
            model.lora_alpha = config.get("lora_alpha", 1)
            # Load projection head; map_location keeps GPU-saved
            # checkpoints loadable on CPU-only hosts.
            proj_path = model_path / "projection_head.bin"
            if proj_path.exists():
                model.projection.load_state_dict(
                    torch.load(proj_path, map_location="cpu")
                )
        else:
            # Load full model state (CPU first; moved to `device` below).
            weights_path = model_path / "pytorch_model.bin"
            if weights_path.exists():
                model.load_state_dict(
                    torch.load(weights_path, map_location="cpu")
                )

        if device is not None:
            model = model.to(device)

        return model
|
||||
|
||||
|
||||
def create_model(
    base_model: str = "openai/clip-vit-large-patch14",
    lora_r: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.1,
    freeze_layers: int = 12,
    use_gradient_checkpointing: bool = True,
    device: Optional[torch.device] = None,
) -> Tuple[LogoFineTunedCLIP, CLIPProcessor]:
    """
    Create a fine-tunable CLIP model and processor.

    Args:
        base_model: HuggingFace model name or path
        lora_r: LoRA rank (0 to disable)
        lora_alpha: LoRA scaling factor
        lora_dropout: LoRA dropout
        freeze_layers: Number of layers to freeze
        use_gradient_checkpointing: Enable gradient checkpointing
        device: Device to load model on

    Returns:
        Tuple of (model, processor)
    """
    # Load base CLIP model and its matching preprocessor.
    clip_model = CLIPModel.from_pretrained(base_model)
    processor = CLIPProcessor.from_pretrained(base_model)

    # Wrap only the vision tower; the text tower is not needed for
    # logo-embedding fine-tuning.
    model = LogoFineTunedCLIP(
        vision_model=clip_model.vision_model,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        freeze_layers=freeze_layers,
        use_gradient_checkpointing=use_gradient_checkpointing,
    )

    if device is not None:
        model = model.to(device)

    # Print parameter info so training logs record the trainable footprint.
    param_info = model.get_parameter_count()
    print("Model created:")
    print(f"  Total parameters: {param_info['total']:,}")
    print(f"  Trainable: {param_info['trainable']:,} ({param_info['trainable_percent']:.2f}%)")
    print(f"  Frozen: {param_info['frozen']:,}")

    return model, processor
|
||||
Reference in New Issue
Block a user