From 1bf9985defe4cfdeb8c44642297182dc26785fe1 Mon Sep 17 00:00:00 2001
From: Rick McEwen
Date: Mon, 5 Jan 2026 11:50:10 -0500
Subject: [PATCH] Fix double LoRA application when loading fine-tuned model

The from_pretrained method was applying LoRA twice:
1. In the constructor via lora_r parameter
2. When loading with PeftModel.from_pretrained()

Now creates model with lora_r=0 and loads LoRA weights separately.

Note: Warning about "missing adapter keys" for layers 0-11 is expected
since those layers are frozen and don't have LoRA adapters.
---
 training/model.py | 46 +++++++++++++++++++++++++++++++---------------
 1 file changed, 31 insertions(+), 15 deletions(-)

diff --git a/training/model.py b/training/model.py
index 90a30b9..ae846be 100644
--- a/training/model.py
+++ b/training/model.py
@@ -250,33 +250,49 @@ class LogoFineTunedCLIP(nn.Module):
         # Load base CLIP model
         clip_model = CLIPModel.from_pretrained(base_model)
 
-        # Create model instance
-        model = cls(
-            vision_model=clip_model.vision_model,
-            lora_r=config.get("lora_r", 0),
-            lora_alpha=config.get("lora_alpha", 1),
-            freeze_layers=config.get("freeze_layers", 12),
-            add_projection_head=config.get("add_projection_head", True),
-            use_gradient_checkpointing=False,  # Not needed for inference
-        )
-
-        # Load weights
+        # Check if we need to load LoRA weights
         if config.get("peft_applied", False) and PEFT_AVAILABLE:
-            # Load LoRA weights
+            # Create model WITHOUT LoRA (lora_r=0) - we'll load LoRA weights separately
+            model = cls(
+                vision_model=clip_model.vision_model,
+                lora_r=0,  # Don't apply LoRA in constructor
+                lora_alpha=config.get("lora_alpha", 1),
+                freeze_layers=config.get("freeze_layers", 12),
+                add_projection_head=config.get("add_projection_head", True),
+                use_gradient_checkpointing=False,
+            )
+
+            # Load LoRA weights from checkpoint
             lora_path = model_path / "vision_lora"
             if lora_path.exists():
                 model.vision_model = PeftModel.from_pretrained(
                     model.vision_model, lora_path
                 )
+                model.peft_applied = True
+                model.lora_r = config.get("lora_r", 16)
+
             # Load projection head
             proj_path = model_path / "projection_head.bin"
             if proj_path.exists():
-                model.projection.load_state_dict(torch.load(proj_path))
+                model.projection.load_state_dict(
+                    torch.load(proj_path, map_location="cpu")
+                )
         else:
-            # Load full model state
+            # No LoRA - create model and load full state
+            model = cls(
+                vision_model=clip_model.vision_model,
+                lora_r=0,
+                lora_alpha=config.get("lora_alpha", 1),
+                freeze_layers=config.get("freeze_layers", 12),
+                add_projection_head=config.get("add_projection_head", True),
+                use_gradient_checkpointing=False,
+            )
+
             weights_path = model_path / "pytorch_model.bin"
             if weights_path.exists():
-                model.load_state_dict(torch.load(weights_path))
+                model.load_state_dict(
+                    torch.load(weights_path, map_location="cpu")
+                )
 
         if device is not None:
             model = model.to(device)
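
For reviewers, a minimal sketch of how the fix could be sanity-checked after loading a
checkpoint. The checkpoint path is a placeholder, and it assumes from_pretrained takes the
checkpoint directory as its first argument; the nested-prefix check relies on the fact that
wrapping a PeftModel a second time duplicates the "base_model.model" prefix in parameter names.

    # Hypothetical post-load check (placeholder path and assumed from_pretrained signature).
    from training.model import LogoFineTunedCLIP

    model = LogoFineTunedCLIP.from_pretrained("outputs/logo-clip")

    # Collect LoRA parameter names from the PEFT-wrapped vision model.
    lora_names = [n for n, _ in model.vision_model.named_parameters() if "lora_" in n]
    print(f"loaded {len(lora_names)} LoRA parameter tensors")

    # A double application would nest one PeftModel inside another, so the
    # "base_model.model" prefix would appear more than once per parameter name.
    assert all(n.count("base_model.model") <= 1 for n in lora_names), "LoRA applied twice"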