diff --git a/training/trainer.py b/training/trainer.py index e624283..07a8da6 100644 --- a/training/trainer.py +++ b/training/trainer.py @@ -169,16 +169,11 @@ class Trainer: "val_neg_sim": val_metrics["mean_neg_sim"], }) - # Checkpointing based on separation (primary) or loss (secondary) - improved = False + # Checkpointing based on separation (gap between pos and neg similarity) + # This is the key metric for contrastive learning quality if val_metrics["separation"] > self.best_val_separation + self.config.min_delta: self.best_val_separation = val_metrics["separation"] - improved = True - elif val_metrics["loss"] < self.best_val_loss - self.config.min_delta: - self.best_val_loss = val_metrics["loss"] - improved = True - - if improved: + self.best_val_loss = val_metrics["loss"] # Track for reference self.patience_counter = 0 self._save_checkpoint("best.pt") self.logger.info("New best model saved!")