""" Fine-tunable CLIP model wrapper with LoRA support. """ import json from pathlib import Path from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from transformers import CLIPModel, CLIPProcessor # Check if peft is available for LoRA try: from peft import LoraConfig, get_peft_model, PeftModel PEFT_AVAILABLE = True except ImportError: PEFT_AVAILABLE = False LoraConfig = None get_peft_model = None PeftModel = None class LogoFineTunedCLIP(nn.Module): """ CLIP vision encoder fine-tuned for logo similarity. Preserves embedding interface for compatibility with DetectLogosDETR: - Same embedding dimensionality (768 for ViT-L/14) - L2 normalized outputs - Works with existing get_image_features() pattern Supports: - LoRA for memory-efficient fine-tuning - Layer freezing for transfer learning - Gradient checkpointing for memory optimization """ def __init__( self, vision_model: nn.Module, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.1, freeze_layers: int = 12, use_gradient_checkpointing: bool = True, add_projection_head: bool = True, ): """ Initialize the fine-tunable CLIP wrapper. Args: vision_model: CLIP vision model (CLIPVisionModel) lora_r: Rank of LoRA low-rank matrices (0 to disable) lora_alpha: LoRA scaling factor lora_dropout: Dropout for LoRA layers freeze_layers: Number of transformer layers to freeze (from bottom) use_gradient_checkpointing: Enable gradient checkpointing add_projection_head: Add trainable projection head """ super().__init__() self.vision_model = vision_model self.embedding_dim = vision_model.config.hidden_size self.freeze_layers = freeze_layers self.lora_r = lora_r self.lora_alpha = lora_alpha # Enable gradient checkpointing for memory efficiency if use_gradient_checkpointing: if hasattr(self.vision_model, "gradient_checkpointing_enable"): self.vision_model.gradient_checkpointing_enable() # Freeze lower layers self._freeze_layers(freeze_layers) # Apply LoRA to attention layers in upper blocks self.peft_applied = False if PEFT_AVAILABLE and lora_r > 0: self._apply_lora(lora_r, lora_alpha, lora_dropout) self.peft_applied = True elif lora_r > 0 and not PEFT_AVAILABLE: print( "Warning: peft not installed. LoRA disabled. 
" "Install with: pip install peft" ) # Optional projection head for fine-tuning self.add_projection_head = add_projection_head if add_projection_head: self.projection = nn.Sequential( nn.Linear(self.embedding_dim, self.embedding_dim), nn.LayerNorm(self.embedding_dim), ) else: self.projection = nn.Identity() def _freeze_layers(self, num_layers: int) -> None: """Freeze the first N transformer layers and embeddings.""" if num_layers <= 0: return # Freeze embeddings if hasattr(self.vision_model, "embeddings"): for param in self.vision_model.embeddings.parameters(): param.requires_grad = False # Freeze specified number of encoder layers if hasattr(self.vision_model, "encoder"): for i, layer in enumerate(self.vision_model.encoder.layers): if i < num_layers: for param in layer.parameters(): param.requires_grad = False def _apply_lora( self, r: int, alpha: int, dropout: float, ) -> None: """Apply LoRA adapters to attention layers.""" if not PEFT_AVAILABLE: return # Configure LoRA for vision transformer lora_config = LoraConfig( r=r, lora_alpha=alpha, lora_dropout=dropout, target_modules=["q_proj", "v_proj"], bias="none", modules_to_save=[], # Don't save any full modules ) self.vision_model = get_peft_model(self.vision_model, lora_config) def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: """ Extract normalized embeddings for logo images. Args: pixel_values: [batch, 3, 224, 224] preprocessed images Returns: embeddings: [batch, embedding_dim] L2-normalized """ # Get vision features outputs = self.vision_model(pixel_values=pixel_values) # Use pooler output (CLS token projection) if available if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None: features = outputs.pooler_output else: # Fall back to CLS token from last hidden state features = outputs.last_hidden_state[:, 0, :] # Apply projection head features = self.projection(features) # L2 normalize for cosine similarity features = F.normalize(features, dim=-1) return features def get_image_features(self, **kwargs) -> torch.Tensor: """ Compatibility method matching CLIP's interface. Used by DetectLogosDETR._get_embedding_pil(). """ return self.forward(kwargs["pixel_values"]) def get_trainable_parameters(self) -> List[torch.nn.Parameter]: """Return list of trainable parameters.""" return [p for p in self.parameters() if p.requires_grad] def get_parameter_count(self) -> Dict[str, int]: """Return count of trainable and total parameters.""" total = sum(p.numel() for p in self.parameters()) trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) return { "total": total, "trainable": trainable, "frozen": total - trainable, "trainable_percent": 100 * trainable / total if total > 0 else 0, } def save_pretrained(self, output_dir: str) -> None: """ Save model in HuggingFace-compatible format. 

        Args:
            output_dir: Directory to save model files
        """
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Save model weights
        if self.peft_applied and PEFT_AVAILABLE:
            # Save LoRA weights separately
            self.vision_model.save_pretrained(output_path / "vision_lora")

            # Save the projection head
            torch.save(
                self.projection.state_dict(),
                output_path / "projection_head.bin",
            )
        else:
            # Save the full model state
            torch.save(self.state_dict(), output_path / "pytorch_model.bin")

        # Save config
        config = {
            "model_type": "clip_logo_finetuned",
            "embedding_dim": self.embedding_dim,
            "lora_r": self.lora_r,
            "lora_alpha": self.lora_alpha,
            "freeze_layers": self.freeze_layers,
            "add_projection_head": self.add_projection_head,
            "peft_applied": self.peft_applied,
        }
        with open(output_path / "config.json", "w") as f:
            json.dump(config, f, indent=2)

    @classmethod
    def from_pretrained(
        cls,
        model_path: str,
        base_model: str = "openai/clip-vit-large-patch14",
        device: Optional[torch.device] = None,
    ) -> "LogoFineTunedCLIP":
        """
        Load a fine-tuned model from saved weights.

        Args:
            model_path: Path to the saved model directory
            base_model: Base CLIP model name (for the architecture)
            device: Device to load the model on

        Returns:
            Loaded LogoFineTunedCLIP model
        """
        model_path = Path(model_path)

        # Load config
        with open(model_path / "config.json", "r") as f:
            config = json.load(f)

        # Load the base CLIP model
        clip_model = CLIPModel.from_pretrained(base_model)

        peft_applied = config.get("peft_applied", False) and PEFT_AVAILABLE

        # Create the model instance. When the checkpoint ships a LoRA adapter,
        # build without LoRA here (lora_r=0) so PeftModel.from_pretrained can
        # attach the saved adapter instead of injecting fresh, untrained ones.
        model = cls(
            vision_model=clip_model.vision_model,
            lora_r=0 if peft_applied else config.get("lora_r", 0),
            lora_alpha=config.get("lora_alpha", 1),
            freeze_layers=config.get("freeze_layers", 12),
            add_projection_head=config.get("add_projection_head", True),
            use_gradient_checkpointing=False,  # Not needed for inference
        )

        # Load weights
        if peft_applied:
            # Load LoRA weights
            lora_path = model_path / "vision_lora"
            if lora_path.exists():
                model.vision_model = PeftModel.from_pretrained(
                    model.vision_model, lora_path
                )
                model.peft_applied = True
                model.lora_r = config.get("lora_r", 0)
                model.lora_alpha = config.get("lora_alpha", 1)

            # Load the projection head
            proj_path = model_path / "projection_head.bin"
            if proj_path.exists():
                model.projection.load_state_dict(torch.load(proj_path))
        else:
            # Load the full model state
            weights_path = model_path / "pytorch_model.bin"
            if weights_path.exists():
                model.load_state_dict(torch.load(weights_path))

        if device is not None:
            model = model.to(device)

        return model


def create_model(
    base_model: str = "openai/clip-vit-large-patch14",
    lora_r: int = 16,
    lora_alpha: int = 32,
    lora_dropout: float = 0.1,
    freeze_layers: int = 12,
    use_gradient_checkpointing: bool = True,
    device: Optional[torch.device] = None,
) -> Tuple[LogoFineTunedCLIP, CLIPProcessor]:
    """
    Create a fine-tunable CLIP model and processor.

    Args:
        base_model: HuggingFace model name or path
        lora_r: LoRA rank (0 to disable)
        lora_alpha: LoRA scaling factor
        lora_dropout: LoRA dropout
        freeze_layers: Number of layers to freeze
        use_gradient_checkpointing: Enable gradient checkpointing
        device: Device to load the model on

    Returns:
        Tuple of (model, processor)
    """
    # Load the base CLIP model
    clip_model = CLIPModel.from_pretrained(base_model)
    processor = CLIPProcessor.from_pretrained(base_model)

    # Create the fine-tunable wrapper
    model = LogoFineTunedCLIP(
        vision_model=clip_model.vision_model,
        lora_r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        freeze_layers=freeze_layers,
        use_gradient_checkpointing=use_gradient_checkpointing,
    )

    if device is not None:
        model = model.to(device)

    # Print parameter info
    param_info = model.get_parameter_count()
    print("Model created:")
    print(f"  Total parameters: {param_info['total']:,}")
    print(
        f"  Trainable: {param_info['trainable']:,} "
        f"({param_info['trainable_percent']:.2f}%)"
    )
    print(f"  Frozen: {param_info['frozen']:,}")

    return model, processor
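

# --- Usage sketch -----------------------------------------------------------
# A minimal sketch of the intended workflow, assuming the default base
# checkpoint "openai/clip-vit-large-patch14" can be downloaded. The random
# tensor below stands in for real logo crops purely to illustrate the
# embedding interface (real callers would pass processor(images=...)
# ["pixel_values"]), and "./logo_clip_checkpoint" is just an example path.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create the fine-tunable wrapper and its matching processor.
    model, processor = create_model(
        base_model="openai/clip-vit-large-patch14",
        lora_r=16,
        freeze_layers=12,
        device=device,
    )

    # Embed a dummy batch of preprocessed images.
    model.eval()
    dummy = torch.randn(2, 3, 224, 224, device=device)
    with torch.no_grad():
        embeddings = model(dummy)
    print(f"Embedding shape: {tuple(embeddings.shape)}")  # (2, embedding_dim)

    # Round-trip save/load in the format expected by from_pretrained().
    model.save_pretrained("./logo_clip_checkpoint")
    reloaded = LogoFineTunedCLIP.from_pretrained(
        "./logo_clip_checkpoint",
        base_model="openai/clip-vit-large-patch14",
        device=device,
    )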