Fix double LoRA application when loading fine-tuned model

The from_pretrained method was applying LoRA twice:
1. In the constructor, via the lora_r parameter
2. Again when loading the adapter with PeftModel.from_pretrained()

Now creates model with lora_r=0 and loads LoRA weights separately.
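
In outline, the corrected order looks like the sketch below. The load_vision_tower helper, the checkpoint path, and the "openai/clip-vit-large-patch14" base model are illustrative assumptions, not the repo's actual API; the real wrapper class and constructor arguments appear in the diff further down.

# Minimal sketch of the corrected loading order (placeholder names): build the
# vision tower with no LoRA injected, then attach the trained adapter exactly
# once via PEFT.
from pathlib import Path

from peft import PeftModel
from transformers import CLIPModel


def load_vision_tower(checkpoint_dir, base_model="openai/clip-vit-large-patch14"):
    # 1) Plain base vision tower -- equivalent to constructing with lora_r=0,
    #    so no fresh LoRA layers are created here.
    vision = CLIPModel.from_pretrained(base_model).vision_model

    # 2) Inject the trained adapter once, from the saved PEFT folder.
    lora_path = Path(checkpoint_dir) / "vision_lora"
    if lora_path.exists():
        vision = PeftModel.from_pretrained(vision, str(lora_path))
    return vision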

Note: the warning about "missing adapter keys" for layers 0-11 is expected,
since those layers are frozen and don't have LoRA adapters.
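One way to sanity-check this, building on the hypothetical helper above (the checkpoint path is a placeholder), is to list which encoder layers actually received adapter parameters after loading:

# Sketch: confirm that only the unfrozen layers carry LoRA parameters, so the
# "missing adapter keys" warning for layers 0-11 is expected, not a load error.
vision = load_vision_tower("path/to/checkpoint")  # placeholder path

lora_layers = sorted({
    int(name.split("encoder.layers.")[1].split(".")[0])
    for name, _ in vision.named_parameters()
    if "lora_" in name and "encoder.layers." in name
})
print("Encoder layers with LoRA adapters:", lora_layers)
# Expect only indices >= freeze_layers (e.g. 12-23 on a 24-layer tower).
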
commit 1bf9985def
parent e5482a2d9e
Author: Rick McEwen
Date:   2026-01-05 11:50:10 -05:00


@@ -250,33 +250,49 @@ class LogoFineTunedCLIP(nn.Module):
         # Load base CLIP model
         clip_model = CLIPModel.from_pretrained(base_model)
 
-        # Create model instance
-        model = cls(
-            vision_model=clip_model.vision_model,
-            lora_r=config.get("lora_r", 0),
-            lora_alpha=config.get("lora_alpha", 1),
-            freeze_layers=config.get("freeze_layers", 12),
-            add_projection_head=config.get("add_projection_head", True),
-            use_gradient_checkpointing=False,  # Not needed for inference
-        )
-        # Load weights
+        # Check if we need to load LoRA weights
         if config.get("peft_applied", False) and PEFT_AVAILABLE:
-            # Load LoRA weights
+            # Create model WITHOUT LoRA (lora_r=0) - we'll load LoRA weights separately
+            model = cls(
+                vision_model=clip_model.vision_model,
+                lora_r=0,  # Don't apply LoRA in constructor
+                lora_alpha=config.get("lora_alpha", 1),
+                freeze_layers=config.get("freeze_layers", 12),
+                add_projection_head=config.get("add_projection_head", True),
+                use_gradient_checkpointing=False,
+            )
+
+            # Load LoRA weights from checkpoint
             lora_path = model_path / "vision_lora"
             if lora_path.exists():
                 model.vision_model = PeftModel.from_pretrained(
                     model.vision_model, lora_path
                 )
+                model.peft_applied = True
+                model.lora_r = config.get("lora_r", 16)
 
             # Load projection head
             proj_path = model_path / "projection_head.bin"
             if proj_path.exists():
-                model.projection.load_state_dict(torch.load(proj_path))
+                model.projection.load_state_dict(
+                    torch.load(proj_path, map_location="cpu")
+                )
         else:
-            # Load full model state
+            # No LoRA - create model and load full state
+            model = cls(
+                vision_model=clip_model.vision_model,
+                lora_r=0,
+                lora_alpha=config.get("lora_alpha", 1),
+                freeze_layers=config.get("freeze_layers", 12),
+                add_projection_head=config.get("add_projection_head", True),
+                use_gradient_checkpointing=False,
+            )
+
             weights_path = model_path / "pytorch_model.bin"
             if weights_path.exists():
-                model.load_state_dict(torch.load(weights_path))
+                model.load_state_dict(
+                    torch.load(weights_path, map_location="cpu")
+                )
 
         if device is not None:
             model = model.to(device)