Compare commits
8 Commits
44e8b6ae7d
...
55abb1217c
| Author | SHA1 | Date | |
|---|---|---|---|
| 55abb1217c | |||
| 14a1bda3fa | |||
| 32bfefc022 | |||
| f74d4b6981 | |||
| 6685af72d9 | |||
| 1bf9985def | |||
| e5482a2d9e | |||
| 99e5781c91 |
@ -114,9 +114,12 @@ min_delta: 0.001
|
|||||||
|
|
||||||
### Test Fine-Tuned Model
|
### Test Fine-Tuned Model
|
||||||
|
|
||||||
|
**Important**: The fine-tuned model requires a higher threshold (0.82) than baseline (0.75).
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uv run python test_logo_detection.py -n 50 \
|
uv run python test_logo_detection.py -n 50 \
|
||||||
-e models/logo_detection/clip_finetuned \
|
-e models/logo_detection/clip_finetuned \
|
||||||
|
-t 0.82 \
|
||||||
--matching-method multi-ref \
|
--matching-method multi-ref \
|
||||||
--seed 42
|
--seed 42
|
||||||
```
|
```
|
||||||
@ -124,26 +127,58 @@ uv run python test_logo_detection.py -n 50 \
|
|||||||
### Compare with Baseline
|
### Compare with Baseline
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Baseline CLIP
|
# Baseline CLIP (threshold 0.75)
|
||||||
uv run python test_logo_detection.py -n 50 \
|
uv run python test_logo_detection.py -n 50 \
|
||||||
-e openai/clip-vit-large-patch14 \
|
-e openai/clip-vit-large-patch14 \
|
||||||
|
-t 0.75 \
|
||||||
--matching-method multi-ref \
|
--matching-method multi-ref \
|
||||||
--seed 42
|
--seed 42
|
||||||
|
|
||||||
# Fine-tuned model
|
# Fine-tuned model (threshold 0.82)
|
||||||
uv run python test_logo_detection.py -n 50 \
|
uv run python test_logo_detection.py -n 50 \
|
||||||
-e models/logo_detection/clip_finetuned \
|
-e models/logo_detection/clip_finetuned \
|
||||||
|
-t 0.82 \
|
||||||
--matching-method multi-ref \
|
--matching-method multi-ref \
|
||||||
--seed 42
|
--seed 42
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Threshold Selection
|
||||||
|
|
||||||
|
The fine-tuned model requires a **higher similarity threshold** than baseline CLIP. Contrastive learning pushed non-matching logo similarities much lower, which changed the overall score distribution, so the threshold tuned for the baseline (0.75) is no longer optimal for the fine-tuned model.
|
||||||
|
|
||||||
|
#### Similarity Distribution Analysis
|
||||||
|
|
||||||
|
| Metric | Baseline | Fine-tuned |
|
||||||
|
|--------|----------|------------|
|
||||||
|
| Wrong logos mean similarity | 0.66 | **0.44** |
|
||||||
|
| Wrong logos above 0.75 | 23.2% | **0.6%** |
|
||||||
|
| Correct logos mean similarity | 0.75 | 0.64 |
|
||||||
|
| Optimal threshold | 0.756 | **0.819** |
|
||||||
|
| F1 at optimal threshold | 67.1% | **71.9%** |
|
||||||
|
|
||||||
|
**Key insight**: The fine-tuned model dramatically reduced similarities to wrong logos (mean dropped from 0.66 to 0.44). At threshold 0.75 it therefore already rejects far more non-matches than the baseline (only 0.6% of wrong logos score above 0.75, versus 23.2%). The false positives that remain bunch up just above 0.75, so a higher threshold (about 0.82) is needed to exclude them.
|
||||||
|
|
||||||
|
#### Analyze Similarity Distribution
|
||||||
|
|
||||||
|
To find the optimal threshold for your model:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run detailed similarity analysis
|
||||||
|
./analyze_similarity_distribution.sh --model finetuned
|
||||||
|
|
||||||
|
# Or analyze both models
|
||||||
|
./analyze_similarity_distribution.sh --model both
|
||||||
|
```
|
||||||
|
|
||||||
|
This outputs distribution statistics and suggests an optimal threshold based on the data.
|
||||||
|
|
||||||
### Expected Metrics
|
### Expected Metrics
|
||||||
|
|
||||||
| Metric | Baseline CLIP | Target (Fine-tuned) |
|
| Metric | Baseline (t=0.75) | Fine-tuned (t=0.82) |
|
||||||
|--------|---------------|---------------------|
|
|--------|-------------------|---------------------|
|
||||||
| Precision | ~49% | >70% |
|
| Precision | ~49% | >65% |
|
||||||
| Recall | ~77% | >75% |
|
| Recall | ~77% | >70% |
|
||||||
| F1 Score | ~60% | >72% |
|
| F1 Score | ~60% | >70% |
|
||||||
|
|
||||||
Training metrics to monitor:
|
Training metrics to monitor:
|
||||||
- Mean positive similarity: target > 0.85
|
- Mean positive similarity: target > 0.85
|
||||||
|
|||||||
141
analyze_similarity_distribution.sh
Executable file
141
analyze_similarity_distribution.sh
Executable file
@ -0,0 +1,141 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# Analyze similarity distribution for baseline and fine-tuned models.
|
||||||
|
#
|
||||||
|
# This script runs the test with --similarity-details to output detailed
|
||||||
|
# statistics about how the models score matches vs non-matches.
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# ./analyze_similarity_distribution.sh
|
||||||
|
# ./analyze_similarity_distribution.sh --model finetuned
|
||||||
|
# ./analyze_similarity_distribution.sh --model baseline
|
||||||
|
#
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Default parameters
|
||||||
|
NUM_LOGOS="${NUM_LOGOS:-50}"
|
||||||
|
SEED="${SEED:-42}"
|
||||||
|
THRESHOLD="${THRESHOLD:-0.75}"
|
||||||
|
REFS_PER_LOGO="${REFS_PER_LOGO:-3}"
|
||||||
|
MARGIN="${MARGIN:-0.05}"
|
||||||
|
MODEL="${MODEL:-both}"
|
||||||
|
|
||||||
|
# Model paths
|
||||||
|
BASELINE_MODEL="openai/clip-vit-large-patch14"
|
||||||
|
FINETUNED_MODEL="models/logo_detection/clip_finetuned"
|
||||||
|
|
||||||
|
# Output directory
|
||||||
|
OUTPUT_DIR="similarity_analysis"
|
||||||
|
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||||
|
|
||||||
|
# Parse command line arguments
|
||||||
|
while [[ $# -gt 0 ]]; do
|
||||||
|
case $1 in
|
||||||
|
-n|--num-logos)
|
||||||
|
NUM_LOGOS="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
-s|--seed)
|
||||||
|
SEED="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
-t|--threshold)
|
||||||
|
THRESHOLD="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--model)
|
||||||
|
MODEL="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--finetuned-path)
|
||||||
|
FINETUNED_MODEL="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
-h|--help)
|
||||||
|
echo "Usage: $0 [OPTIONS]"
|
||||||
|
echo ""
|
||||||
|
echo "Options:"
|
||||||
|
echo " -n, --num-logos NUM Number of logos to test (default: 50)"
|
||||||
|
echo " -s, --seed SEED Random seed (default: 42)"
|
||||||
|
echo " -t, --threshold VAL Similarity threshold (default: 0.75)"
|
||||||
|
echo " --model MODEL Which model: 'baseline', 'finetuned', or 'both' (default: both)"
|
||||||
|
echo " --finetuned-path PATH Path to fine-tuned model"
|
||||||
|
echo " -h, --help Show this help message"
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unknown option: $1"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
# Create output directory
|
||||||
|
mkdir -p "${OUTPUT_DIR}"
|
||||||
|
|
||||||
|
echo "============================================================"
|
||||||
|
echo "SIMILARITY DISTRIBUTION ANALYSIS"
|
||||||
|
echo "============================================================"
|
||||||
|
echo ""
|
||||||
|
echo "Parameters:"
|
||||||
|
echo " Number of logos: ${NUM_LOGOS}"
|
||||||
|
echo " Random seed: ${SEED}"
|
||||||
|
echo " Threshold: ${THRESHOLD}"
|
||||||
|
echo " Refs per logo: ${REFS_PER_LOGO}"
|
||||||
|
echo " Margin: ${MARGIN}"
|
||||||
|
echo " Model: ${MODEL}"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Common test arguments
|
||||||
|
TEST_ARGS=(
|
||||||
|
-n "${NUM_LOGOS}"
|
||||||
|
-s "${SEED}"
|
||||||
|
-t "${THRESHOLD}"
|
||||||
|
--refs-per-logo "${REFS_PER_LOGO}"
|
||||||
|
--margin "${MARGIN}"
|
||||||
|
--matching-method multi-ref
|
||||||
|
--similarity-details
|
||||||
|
--clear-cache
|
||||||
|
)
|
||||||
|
|
||||||
|
run_analysis() {
|
||||||
|
local model_name="$1"
|
||||||
|
local model_path="$2"
|
||||||
|
local output_file="${OUTPUT_DIR}/${model_name}_similarity_${TIMESTAMP}.txt"
|
||||||
|
|
||||||
|
echo "============================================================"
|
||||||
|
echo "Analyzing: ${model_name}"
|
||||||
|
echo "Model: ${model_path}"
|
||||||
|
echo "Output: ${output_file}"
|
||||||
|
echo "============================================================"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
uv run python test_logo_detection.py \
|
||||||
|
"${TEST_ARGS[@]}" \
|
||||||
|
-e "${model_path}" \
|
||||||
|
2>&1 | tee "${output_file}"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Results saved to: ${output_file}"
|
||||||
|
echo ""
|
||||||
|
}
|
||||||
|
|
||||||
|
# Run analysis based on model selection
|
||||||
|
if [[ "${MODEL}" == "baseline" ]] || [[ "${MODEL}" == "both" ]]; then
|
||||||
|
run_analysis "baseline" "${BASELINE_MODEL}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "${MODEL}" == "finetuned" ]] || [[ "${MODEL}" == "both" ]]; then
|
||||||
|
if [ ! -d "${FINETUNED_MODEL}" ]; then
|
||||||
|
echo "Warning: Fine-tuned model not found at ${FINETUNED_MODEL}"
|
||||||
|
echo "Skipping fine-tuned model analysis."
|
||||||
|
else
|
||||||
|
run_analysis "finetuned" "${FINETUNED_MODEL}"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "============================================================"
|
||||||
|
echo "Analysis complete!"
|
||||||
|
echo "Results saved to: ${OUTPUT_DIR}/"
|
||||||
|
echo "============================================================"
|
||||||
191
compare_finetuned_vs_baseline.sh
Executable file
191
compare_finetuned_vs_baseline.sh
Executable file
@ -0,0 +1,191 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# Compare fine-tuned CLIP model against baseline CLIP for logo recognition.
|
||||||
|
#
|
||||||
|
# This script runs the same test suite on both models and outputs results
|
||||||
|
# for easy comparison.
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# ./compare_finetuned_vs_baseline.sh
|
||||||
|
# ./compare_finetuned_vs_baseline.sh --num-logos 100
|
||||||
|
#
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Default parameters
|
||||||
|
NUM_LOGOS="${NUM_LOGOS:-50}"
|
||||||
|
SEED="${SEED:-42}"
|
||||||
|
THRESHOLD="${THRESHOLD:-0.7}"
|
||||||
|
DETR_THRESHOLD="${DETR_THRESHOLD:-0.5}"
|
||||||
|
REFS_PER_LOGO="${REFS_PER_LOGO:-3}"
|
||||||
|
MARGIN="${MARGIN:-0.05}"
|
||||||
|
POSITIVE_SAMPLES="${POSITIVE_SAMPLES:-5}"
|
||||||
|
NEGATIVE_SAMPLES="${NEGATIVE_SAMPLES:-20}"
|
||||||
|
|
||||||
|
# Model paths
|
||||||
|
BASELINE_MODEL="openai/clip-vit-large-patch14"
|
||||||
|
FINETUNED_MODEL="models/logo_detection/clip_finetuned"
|
||||||
|
|
||||||
|
# Output files
|
||||||
|
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||||
|
OUTPUT_DIR="comparison_results"
|
||||||
|
BASELINE_OUTPUT="${OUTPUT_DIR}/baseline_${TIMESTAMP}.txt"
|
||||||
|
FINETUNED_OUTPUT="${OUTPUT_DIR}/finetuned_${TIMESTAMP}.txt"
|
||||||
|
SUMMARY_OUTPUT="${OUTPUT_DIR}/comparison_summary_${TIMESTAMP}.txt"
|
||||||
|
|
||||||
|
# Parse command line arguments
|
||||||
|
while [[ $# -gt 0 ]]; do
|
||||||
|
case $1 in
|
||||||
|
-n|--num-logos)
|
||||||
|
NUM_LOGOS="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
-s|--seed)
|
||||||
|
SEED="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
-t|--threshold)
|
||||||
|
THRESHOLD="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--refs-per-logo)
|
||||||
|
REFS_PER_LOGO="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--margin)
|
||||||
|
MARGIN="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--finetuned-model)
|
||||||
|
FINETUNED_MODEL="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
-h|--help)
|
||||||
|
echo "Usage: $0 [OPTIONS]"
|
||||||
|
echo ""
|
||||||
|
echo "Options:"
|
||||||
|
echo " -n, --num-logos NUM Number of logos to test (default: 50)"
|
||||||
|
echo " -s, --seed SEED Random seed for reproducibility (default: 42)"
|
||||||
|
echo " -t, --threshold VAL Similarity threshold (default: 0.7)"
|
||||||
|
echo " --refs-per-logo NUM Reference images per logo (default: 3)"
|
||||||
|
echo " --margin VAL Margin for matching (default: 0.05)"
|
||||||
|
echo " --finetuned-model PATH Path to fine-tuned model"
|
||||||
|
echo " -h, --help Show this help message"
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unknown option: $1"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
# Create output directory
|
||||||
|
mkdir -p "${OUTPUT_DIR}"
|
||||||
|
|
||||||
|
# Check if fine-tuned model exists
|
||||||
|
if [ ! -d "${FINETUNED_MODEL}" ]; then
|
||||||
|
echo "Error: Fine-tuned model not found at ${FINETUNED_MODEL}"
|
||||||
|
echo "Please train the model first using: uv run python train_clip_logo.py --config configs/jetson_orin.yaml"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "============================================================"
|
||||||
|
echo "CLIP Logo Recognition: Fine-tuned vs Baseline Comparison"
|
||||||
|
echo "============================================================"
|
||||||
|
echo ""
|
||||||
|
echo "Parameters:"
|
||||||
|
echo " Number of logos: ${NUM_LOGOS}"
|
||||||
|
echo " Random seed: ${SEED}"
|
||||||
|
echo " Threshold: ${THRESHOLD}"
|
||||||
|
echo " DETR threshold: ${DETR_THRESHOLD}"
|
||||||
|
echo " Refs per logo: ${REFS_PER_LOGO}"
|
||||||
|
echo " Margin: ${MARGIN}"
|
||||||
|
echo " Positive samples: ${POSITIVE_SAMPLES}"
|
||||||
|
echo " Negative samples: ${NEGATIVE_SAMPLES}"
|
||||||
|
echo ""
|
||||||
|
echo "Models:"
|
||||||
|
echo " Baseline: ${BASELINE_MODEL}"
|
||||||
|
echo " Fine-tuned: ${FINETUNED_MODEL}"
|
||||||
|
echo ""
|
||||||
|
echo "Output:"
|
||||||
|
echo " Baseline results: ${BASELINE_OUTPUT}"
|
||||||
|
echo " Fine-tuned results: ${FINETUNED_OUTPUT}"
|
||||||
|
echo " Summary: ${SUMMARY_OUTPUT}"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Common test arguments
|
||||||
|
TEST_ARGS=(
|
||||||
|
-n "${NUM_LOGOS}"
|
||||||
|
-s "${SEED}"
|
||||||
|
-t "${THRESHOLD}"
|
||||||
|
-d "${DETR_THRESHOLD}"
|
||||||
|
--refs-per-logo "${REFS_PER_LOGO}"
|
||||||
|
--margin "${MARGIN}"
|
||||||
|
--positive-samples "${POSITIVE_SAMPLES}"
|
||||||
|
--negative-samples "${NEGATIVE_SAMPLES}"
|
||||||
|
--matching-method multi-ref
|
||||||
|
--clear-cache
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run baseline test
|
||||||
|
echo "============================================================"
|
||||||
|
echo "Testing BASELINE model: ${BASELINE_MODEL}"
|
||||||
|
echo "============================================================"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
uv run python test_logo_detection.py \
|
||||||
|
"${TEST_ARGS[@]}" \
|
||||||
|
-e "${BASELINE_MODEL}" \
|
||||||
|
2>&1 | tee "${BASELINE_OUTPUT}"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Baseline results saved to: ${BASELINE_OUTPUT}"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Run fine-tuned test
|
||||||
|
echo "============================================================"
|
||||||
|
echo "Testing FINE-TUNED model: ${FINETUNED_MODEL}"
|
||||||
|
echo "============================================================"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
uv run python test_logo_detection.py \
|
||||||
|
"${TEST_ARGS[@]}" \
|
||||||
|
-e "${FINETUNED_MODEL}" \
|
||||||
|
2>&1 | tee "${FINETUNED_OUTPUT}"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "Fine-tuned results saved to: ${FINETUNED_OUTPUT}"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Extract and compare key metrics
|
||||||
|
echo "============================================================"
|
||||||
|
echo "COMPARISON SUMMARY"
|
||||||
|
echo "============================================================" | tee "${SUMMARY_OUTPUT}"
|
||||||
|
echo "" | tee -a "${SUMMARY_OUTPUT}"
|
||||||
|
echo "Test Parameters:" | tee -a "${SUMMARY_OUTPUT}"
|
||||||
|
echo " Logos: ${NUM_LOGOS}, Seed: ${SEED}, Threshold: ${THRESHOLD}" | tee -a "${SUMMARY_OUTPUT}"
|
||||||
|
echo " Method: multi-ref, Refs/logo: ${REFS_PER_LOGO}, Margin: ${MARGIN}" | tee -a "${SUMMARY_OUTPUT}"
|
||||||
|
echo "" | tee -a "${SUMMARY_OUTPUT}"
|
||||||
|
|
||||||
|
echo "BASELINE (${BASELINE_MODEL}):" | tee -a "${SUMMARY_OUTPUT}"
|
||||||
|
grep -E "(Precision|Recall|F1 Score|True Positives|False Positives|False Negatives)" "${BASELINE_OUTPUT}" | head -6 | tee -a "${SUMMARY_OUTPUT}"
|
||||||
|
echo "" | tee -a "${SUMMARY_OUTPUT}"
|
||||||
|
|
||||||
|
echo "FINE-TUNED (${FINETUNED_MODEL}):" | tee -a "${SUMMARY_OUTPUT}"
|
||||||
|
grep -E "(Precision|Recall|F1 Score|True Positives|False Positives|False Negatives)" "${FINETUNED_OUTPUT}" | head -6 | tee -a "${SUMMARY_OUTPUT}"
|
||||||
|
echo "" | tee -a "${SUMMARY_OUTPUT}"
|
||||||
|
|
||||||
|
# Extract F1 scores for quick comparison
|
||||||
|
BASELINE_F1=$(grep "F1 Score" "${BASELINE_OUTPUT}" | head -1 | grep -oE "[0-9]+\.[0-9]+%" | head -1 || echo "N/A")
|
||||||
|
FINETUNED_F1=$(grep "F1 Score" "${FINETUNED_OUTPUT}" | head -1 | grep -oE "[0-9]+\.[0-9]+%" | head -1 || echo "N/A")
|
||||||
|
|
||||||
|
echo "------------------------------------------------------------" | tee -a "${SUMMARY_OUTPUT}"
|
||||||
|
echo "F1 SCORE COMPARISON:" | tee -a "${SUMMARY_OUTPUT}"
|
||||||
|
echo " Baseline: ${BASELINE_F1}" | tee -a "${SUMMARY_OUTPUT}"
|
||||||
|
echo " Fine-tuned: ${FINETUNED_F1}" | tee -a "${SUMMARY_OUTPUT}"
|
||||||
|
echo "------------------------------------------------------------" | tee -a "${SUMMARY_OUTPUT}"
|
||||||
|
echo "" | tee -a "${SUMMARY_OUTPUT}"
|
||||||
|
echo "Full results saved to: ${OUTPUT_DIR}/" | tee -a "${SUMMARY_OUTPUT}"
|
||||||
|
echo ""
|
||||||
|
echo "Done!"
|
||||||
70
configs/cloud_rtx4090_image_split.yaml
Normal file
70
configs/cloud_rtx4090_image_split.yaml
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
# Training configuration for RTX 4090 (24GB VRAM) with IMAGE-LEVEL splits
|
||||||
|
#
|
||||||
|
# Combines RTX 4090 hardware optimizations with image-level splitting and
|
||||||
|
# gentler contrastive learning for better generalization.
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# python train_clip_logo.py --config configs/cloud_rtx4090_image_split.yaml
|
||||||
|
#
|
||||||
|
# Estimated training time: 5-7 hours (more epochs than logo-level)
|
||||||
|
# Estimated cost on RunPod: ~$4
|
||||||
|
|
||||||
|
# Base model
|
||||||
|
base_model: "openai/clip-vit-large-patch14"
|
||||||
|
|
||||||
|
# Dataset paths
|
||||||
|
dataset_dir: "LogoDet-3K"
|
||||||
|
reference_dir: "reference_logos"
|
||||||
|
db_path: "test_data_mapping.db"
|
||||||
|
|
||||||
|
# Data split configuration - IMAGE LEVEL
|
||||||
|
# Each logo brand will have images in all splits, allowing the model
|
||||||
|
# to see some examples of each brand during training.
|
||||||
|
split_level: "image"
|
||||||
|
train_split: 0.7
|
||||||
|
val_split: 0.15
|
||||||
|
test_split: 0.15
|
||||||
|
|
||||||
|
# Larger batches for faster training on 24GB VRAM
|
||||||
|
batch_size: 32
|
||||||
|
logos_per_batch: 32
|
||||||
|
samples_per_logo: 4
|
||||||
|
gradient_accumulation_steps: 4 # Effective batch = 128
|
||||||
|
num_workers: 8
|
||||||
|
|
||||||
|
# Model architecture
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0.1
|
||||||
|
freeze_layers: 12
|
||||||
|
use_gradient_checkpointing: true
|
||||||
|
|
||||||
|
# Training - GENTLER settings for better generalization
|
||||||
|
learning_rate: 5.0e-6 # Reduced from 1e-5
|
||||||
|
weight_decay: 0.01
|
||||||
|
warmup_steps: 500
|
||||||
|
max_epochs: 30 # More epochs with slower learning
|
||||||
|
mixed_precision: true
|
||||||
|
|
||||||
|
# Loss - HIGHER temperature for softer contrastive learning
|
||||||
|
temperature: 0.15 # Increased from 0.07
|
||||||
|
loss_type: "infonce"
|
||||||
|
triplet_margin: 0.2 # Reduced from 0.3
|
||||||
|
|
||||||
|
# Early stopping - more patience with gentler learning
|
||||||
|
patience: 7
|
||||||
|
min_delta: 0.001
|
||||||
|
|
||||||
|
# Output - separate directory for image-split model
|
||||||
|
checkpoint_dir: "checkpoints_image_split"
|
||||||
|
output_dir: "models/logo_detection/clip_finetuned_image_split"
|
||||||
|
save_every_n_epochs: 2 # Save frequently for cloud
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
log_every_n_steps: 10
|
||||||
|
eval_every_n_epochs: 1
|
||||||
|
|
||||||
|
seed: 42
|
||||||
|
use_hard_negatives: false
|
||||||
|
use_augmentation: true
|
||||||
|
augmentation_strength: "medium"
|
||||||
78
configs/image_level_splits.yaml
Normal file
78
configs/image_level_splits.yaml
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
# Training configuration with IMAGE-LEVEL splits
|
||||||
|
#
|
||||||
|
# Unlike logo-level splits where test logos are completely unseen brands,
|
||||||
|
# image-level splits allow the model to see some images from each brand
|
||||||
|
# during training. This is less rigorous but more representative of
|
||||||
|
# real-world use where you have reference images for logos you want to detect.
|
||||||
|
#
|
||||||
|
# Also uses gentler contrastive learning settings to prevent over-separation.
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# uv run python train_clip_logo.py --config configs/image_level_splits.yaml
|
||||||
|
|
||||||
|
# Base model
|
||||||
|
base_model: "openai/clip-vit-large-patch14"
|
||||||
|
|
||||||
|
# Dataset paths (relative to project root)
|
||||||
|
dataset_dir: "LogoDet-3K"
|
||||||
|
reference_dir: "reference_logos"
|
||||||
|
db_path: "test_data_mapping.db"
|
||||||
|
|
||||||
|
# Data split configuration
|
||||||
|
# split_level: "image" means images are split, not logo brands
|
||||||
|
# This allows test set to contain images from brands seen during training
|
||||||
|
split_level: "image"
|
||||||
|
train_split: 0.7
|
||||||
|
val_split: 0.15
|
||||||
|
test_split: 0.15
|
||||||
|
|
||||||
|
# Batch construction
|
||||||
|
batch_size: 16
|
||||||
|
logos_per_batch: 32
|
||||||
|
samples_per_logo: 4
|
||||||
|
gradient_accumulation_steps: 8
|
||||||
|
num_workers: 4
|
||||||
|
|
||||||
|
# Model architecture - same as before
|
||||||
|
lora_r: 16
|
||||||
|
lora_alpha: 32
|
||||||
|
lora_dropout: 0.1
|
||||||
|
freeze_layers: 12
|
||||||
|
use_gradient_checkpointing: true
|
||||||
|
|
||||||
|
# Training hyperparameters - GENTLER settings
|
||||||
|
learning_rate: 5.0e-6 # Reduced from 1e-5
|
||||||
|
weight_decay: 0.01
|
||||||
|
warmup_steps: 500
|
||||||
|
max_epochs: 30 # More epochs with slower learning
|
||||||
|
mixed_precision: true
|
||||||
|
|
||||||
|
# Loss function - HIGHER temperature for softer contrastive learning
|
||||||
|
temperature: 0.15 # Increased from 0.07
|
||||||
|
loss_type: "infonce"
|
||||||
|
triplet_margin: 0.2 # Reduced from 0.3
|
||||||
|
|
||||||
|
# Early stopping
|
||||||
|
patience: 7 # More patience with gentler learning
|
||||||
|
min_delta: 0.001
|
||||||
|
|
||||||
|
# Checkpoints and output
|
||||||
|
checkpoint_dir: "checkpoints_image_split"
|
||||||
|
output_dir: "models/logo_detection/clip_finetuned_image_split"
|
||||||
|
save_every_n_epochs: 5
|
||||||
|
|
||||||
|
# Logging
|
||||||
|
log_every_n_steps: 10
|
||||||
|
eval_every_n_epochs: 1
|
||||||
|
|
||||||
|
# Reproducibility
|
||||||
|
seed: 42
|
||||||
|
|
||||||
|
# Hard negative mining
|
||||||
|
use_hard_negatives: false
|
||||||
|
hard_negative_start_epoch: 10
|
||||||
|
hard_negatives_per_logo: 10
|
||||||
|
|
||||||
|
# Data augmentation
|
||||||
|
use_augmentation: true
|
||||||
|
augmentation_strength: "medium"
|
||||||
168
find_optimal_threshold.sh
Executable file
168
find_optimal_threshold.sh
Executable file
@ -0,0 +1,168 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#
|
||||||
|
# Find optimal similarity threshold for logo detection.
|
||||||
|
#
|
||||||
|
# Tests a range of thresholds and outputs precision/recall/F1 for each.
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# ./find_optimal_threshold.sh
|
||||||
|
# ./find_optimal_threshold.sh --model finetuned
|
||||||
|
# ./find_optimal_threshold.sh --model baseline
|
||||||
|
# ./find_optimal_threshold.sh --thresholds "0.70 0.75 0.80 0.85"
|
||||||
|
#
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Default parameters
|
||||||
|
NUM_LOGOS="${NUM_LOGOS:-50}"
|
||||||
|
SEED="${SEED:-42}"
|
||||||
|
REFS_PER_LOGO="${REFS_PER_LOGO:-3}"
|
||||||
|
MARGIN="${MARGIN:-0.05}"
|
||||||
|
MODEL="${MODEL:-finetuned}"
|
||||||
|
USE_MAX_SIM="${USE_MAX_SIM:-false}"
|
||||||
|
|
||||||
|
# Default thresholds to test
|
||||||
|
THRESHOLDS="${THRESHOLDS:-0.70 0.72 0.74 0.76 0.78 0.80 0.82 0.84 0.86}"
|
||||||
|
|
||||||
|
# Model paths
|
||||||
|
BASELINE_MODEL="openai/clip-vit-large-patch14"
|
||||||
|
FINETUNED_MODEL="models/logo_detection/clip_finetuned"
|
||||||
|
|
||||||
|
# Output
|
||||||
|
OUTPUT_DIR="threshold_analysis"
|
||||||
|
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||||
|
|
||||||
|
# Parse command line arguments
|
||||||
|
while [[ $# -gt 0 ]]; do
|
||||||
|
case $1 in
|
||||||
|
-n|--num-logos)
|
||||||
|
NUM_LOGOS="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
-s|--seed)
|
||||||
|
SEED="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--model)
|
||||||
|
MODEL="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--thresholds)
|
||||||
|
THRESHOLDS="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--finetuned-path)
|
||||||
|
FINETUNED_MODEL="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--use-max-similarity)
|
||||||
|
USE_MAX_SIM="true"
|
||||||
|
shift
|
||||||
|
;;
|
||||||
|
-h|--help)
|
||||||
|
echo "Usage: $0 [OPTIONS]"
|
||||||
|
echo ""
|
||||||
|
echo "Options:"
|
||||||
|
echo " -n, --num-logos NUM Number of logos to test (default: 50)"
|
||||||
|
echo " -s, --seed SEED Random seed (default: 42)"
|
||||||
|
echo " --model MODEL Which model: 'baseline' or 'finetuned' (default: finetuned)"
|
||||||
|
echo " --thresholds \"T1 T2 ...\" Space-separated thresholds to test"
|
||||||
|
echo " --finetuned-path PATH Path to fine-tuned model"
|
||||||
|
echo " --use-max-similarity Use max instead of mean for multi-ref aggregation"
|
||||||
|
echo " -h, --help Show this help message"
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unknown option: $1"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
# Select model path
|
||||||
|
if [[ "${MODEL}" == "baseline" ]]; then
|
||||||
|
MODEL_PATH="${BASELINE_MODEL}"
|
||||||
|
else
|
||||||
|
MODEL_PATH="${FINETUNED_MODEL}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if fine-tuned model exists
|
||||||
|
if [[ "${MODEL}" == "finetuned" ]] && [ ! -d "${FINETUNED_MODEL}" ]; then
|
||||||
|
echo "Error: Fine-tuned model not found at ${FINETUNED_MODEL}"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Create output directory
|
||||||
|
mkdir -p "${OUTPUT_DIR}"
|
||||||
|
OUTPUT_FILE="${OUTPUT_DIR}/${MODEL}_thresholds_${TIMESTAMP}.txt"
|
||||||
|
|
||||||
|
echo "============================================================"
|
||||||
|
echo "THRESHOLD OPTIMIZATION"
|
||||||
|
echo "============================================================"
|
||||||
|
echo ""
|
||||||
|
echo "Model: ${MODEL} (${MODEL_PATH})"
|
||||||
|
echo "Thresholds: ${THRESHOLDS}"
|
||||||
|
echo "Logos: ${NUM_LOGOS}"
|
||||||
|
echo "Seed: ${SEED}"
|
||||||
|
echo "Max sim: ${USE_MAX_SIM}"
|
||||||
|
echo "Output: ${OUTPUT_FILE}"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Header for results
|
||||||
|
echo "============================================================" | tee "${OUTPUT_FILE}"
|
||||||
|
echo "THRESHOLD OPTIMIZATION RESULTS" | tee -a "${OUTPUT_FILE}"
|
||||||
|
echo "Model: ${MODEL} (${MODEL_PATH})" | tee -a "${OUTPUT_FILE}"
|
||||||
|
echo "============================================================" | tee -a "${OUTPUT_FILE}"
|
||||||
|
echo "" | tee -a "${OUTPUT_FILE}"
|
||||||
|
printf "%-10s %8s %8s %8s %8s %8s %8s\n" "Threshold" "TP" "FP" "FN" "Prec" "Recall" "F1" | tee -a "${OUTPUT_FILE}"
|
||||||
|
echo "--------------------------------------------------------------------" | tee -a "${OUTPUT_FILE}"
|
||||||
|
|
||||||
|
# Track best F1
|
||||||
|
BEST_F1=0
|
||||||
|
BEST_THRESHOLD=""
|
||||||
|
|
||||||
|
# Build extra args
|
||||||
|
EXTRA_ARGS=""
|
||||||
|
if [[ "${USE_MAX_SIM}" == "true" ]]; then
|
||||||
|
EXTRA_ARGS="--use-max-similarity"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Test each threshold
|
||||||
|
for THRESHOLD in ${THRESHOLDS}; do
|
||||||
|
# Run test and capture output
|
||||||
|
OUTPUT=$(uv run python test_logo_detection.py \
|
||||||
|
-n "${NUM_LOGOS}" \
|
||||||
|
-s "${SEED}" \
|
||||||
|
-t "${THRESHOLD}" \
|
||||||
|
--refs-per-logo "${REFS_PER_LOGO}" \
|
||||||
|
--margin "${MARGIN}" \
|
||||||
|
--matching-method multi-ref \
|
||||||
|
-e "${MODEL_PATH}" \
|
||||||
|
${EXTRA_ARGS} \
|
||||||
|
2>/dev/null)
|
||||||
|
|
||||||
|
# Extract metrics
|
||||||
|
TP=$(echo "${OUTPUT}" | grep "True Positives" | grep -oE "[0-9]+" | head -1)
|
||||||
|
FP=$(echo "${OUTPUT}" | grep "False Positives" | grep -oE "[0-9]+" | head -1)
|
||||||
|
FN=$(echo "${OUTPUT}" | grep "False Negatives" | grep -oE "[0-9]+" | head -1)
|
||||||
|
PREC=$(echo "${OUTPUT}" | grep "Precision:" | grep -oE "[0-9]+\.[0-9]+%" | head -1)
|
||||||
|
RECALL=$(echo "${OUTPUT}" | grep "Recall:" | grep -oE "[0-9]+\.[0-9]+%" | head -1)
|
||||||
|
F1=$(echo "${OUTPUT}" | grep "F1 Score:" | grep -oE "[0-9]+\.[0-9]+%" | head -1)
|
||||||
|
|
||||||
|
# Print row
|
||||||
|
printf "%-10s %8s %8s %8s %8s %8s %8s\n" "${THRESHOLD}" "${TP}" "${FP}" "${FN}" "${PREC}" "${RECALL}" "${F1}" | tee -a "${OUTPUT_FILE}"
|
||||||
|
|
||||||
|
# Track best F1
|
||||||
|
F1_NUM=$(echo "${F1}" | tr -d '%')
|
||||||
|
BEST_NUM=$(echo "${BEST_F1}" | tr -d '%')
|
||||||
|
if (( $(echo "${F1_NUM} > ${BEST_NUM}" | bc -l) )); then
|
||||||
|
BEST_F1="${F1}"
|
||||||
|
BEST_THRESHOLD="${THRESHOLD}"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
echo "--------------------------------------------------------------------" | tee -a "${OUTPUT_FILE}"
|
||||||
|
echo "" | tee -a "${OUTPUT_FILE}"
|
||||||
|
echo "BEST THRESHOLD: ${BEST_THRESHOLD} (F1 = ${BEST_F1})" | tee -a "${OUTPUT_FILE}"
|
||||||
|
echo "" | tee -a "${OUTPUT_FILE}"
|
||||||
|
echo "Results saved to: ${OUTPUT_FILE}"
|
||||||
@ -265,6 +265,11 @@ def main():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable verbose logging",
|
help="Enable verbose logging",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--similarity-details",
|
||||||
|
action="store_true",
|
||||||
|
help="Output detailed similarity scores for each detection (for analyzing score distributions)",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no-cache",
|
"--no-cache",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@ -411,6 +416,16 @@ def main():
|
|||||||
# Detailed results for analysis
|
# Detailed results for analysis
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
|
# Similarity distribution tracking (for --similarity-details)
|
||||||
|
similarity_details = {
|
||||||
|
"true_positive_sims": [], # Similarities for correct matches
|
||||||
|
"false_positive_sims": [], # Similarities for wrong matches
|
||||||
|
"missed_best_sims": [], # Best similarity for logos that should have matched but didn't
|
||||||
|
"all_positive_sims": [], # All similarities between detected regions and correct logos
|
||||||
|
"all_negative_sims": [], # All similarities between detected regions and wrong logos
|
||||||
|
"detection_details": [], # Per-detection breakdown
|
||||||
|
}
|
||||||
|
|
||||||
# Process test images
|
# Process test images
|
||||||
for test_filename in tqdm(test_images, desc="Testing"):
|
for test_filename in tqdm(test_images, desc="Testing"):
|
||||||
test_path = test_images_dir / test_filename
|
test_path = test_images_dir / test_filename
|
||||||
@ -445,7 +460,38 @@ def main():
|
|||||||
|
|
||||||
# Match detections against references using selected method
|
# Match detections against references using selected method
|
||||||
matched_logos: Set[str] = set()
|
matched_logos: Set[str] = set()
|
||||||
for detection in detections:
|
for det_idx, detection in enumerate(detections):
|
||||||
|
# Compute similarities to all reference logos for detailed analysis
|
||||||
|
if args.similarity_details:
|
||||||
|
all_sims = {}
|
||||||
|
for logo_name, ref_emb_list in multi_ref_embeddings.items():
|
||||||
|
sims = []
|
||||||
|
for ref_emb in ref_emb_list:
|
||||||
|
sim = detector.compare_embeddings(detection["embedding"], ref_emb)
|
||||||
|
sims.append(sim)
|
||||||
|
# Use mean or max based on setting
|
||||||
|
if args.use_max_similarity:
|
||||||
|
all_sims[logo_name] = max(sims) if sims else 0
|
||||||
|
else:
|
||||||
|
all_sims[logo_name] = sum(sims) / len(sims) if sims else 0
|
||||||
|
|
||||||
|
# Track positive vs negative similarities
|
||||||
|
for sim in sims:
|
||||||
|
if logo_name in expected_logos:
|
||||||
|
similarity_details["all_positive_sims"].append(sim)
|
||||||
|
else:
|
||||||
|
similarity_details["all_negative_sims"].append(sim)
|
||||||
|
|
||||||
|
# Store detection details
|
||||||
|
sorted_sims = sorted(all_sims.items(), key=lambda x: -x[1])
|
||||||
|
similarity_details["detection_details"].append({
|
||||||
|
"image": test_filename,
|
||||||
|
"detection_idx": det_idx,
|
||||||
|
"expected_logos": list(expected_logos),
|
||||||
|
"top_5_matches": sorted_sims[:5],
|
||||||
|
"detr_score": detection.get("score", 0),
|
||||||
|
})
|
||||||
|
|
||||||
if args.matching_method == "simple":
|
if args.matching_method == "simple":
|
||||||
# Simple matching: return ALL logos above threshold
|
# Simple matching: return ALL logos above threshold
|
||||||
all_matches = detector.find_all_matches(
|
all_matches = detector.find_all_matches(
|
||||||
@ -457,16 +503,21 @@ def main():
|
|||||||
matched_logos.add(label)
|
matched_logos.add(label)
|
||||||
|
|
||||||
# Check if this is a correct match
|
# Check if this is a correct match
|
||||||
if label in expected_logos:
|
is_correct = label in expected_logos
|
||||||
|
if is_correct:
|
||||||
true_positives += 1
|
true_positives += 1
|
||||||
|
if args.similarity_details:
|
||||||
|
similarity_details["true_positive_sims"].append(similarity)
|
||||||
else:
|
else:
|
||||||
false_positives += 1
|
false_positives += 1
|
||||||
|
if args.similarity_details:
|
||||||
|
similarity_details["false_positive_sims"].append(similarity)
|
||||||
|
|
||||||
results.append({
|
results.append({
|
||||||
"test_image": test_filename,
|
"test_image": test_filename,
|
||||||
"matched_logo": label,
|
"matched_logo": label,
|
||||||
"similarity": similarity,
|
"similarity": similarity,
|
||||||
"correct": label in expected_logos,
|
"correct": is_correct,
|
||||||
})
|
})
|
||||||
|
|
||||||
elif args.matching_method == "margin":
|
elif args.matching_method == "margin":
|
||||||
@ -481,16 +532,21 @@ def main():
|
|||||||
label, similarity = match_result
|
label, similarity = match_result
|
||||||
matched_logos.add(label)
|
matched_logos.add(label)
|
||||||
|
|
||||||
if label in expected_logos:
|
is_correct = label in expected_logos
|
||||||
|
if is_correct:
|
||||||
true_positives += 1
|
true_positives += 1
|
||||||
|
if args.similarity_details:
|
||||||
|
similarity_details["true_positive_sims"].append(similarity)
|
||||||
else:
|
else:
|
||||||
false_positives += 1
|
false_positives += 1
|
||||||
|
if args.similarity_details:
|
||||||
|
similarity_details["false_positive_sims"].append(similarity)
|
||||||
|
|
||||||
results.append({
|
results.append({
|
||||||
"test_image": test_filename,
|
"test_image": test_filename,
|
||||||
"matched_logo": label,
|
"matched_logo": label,
|
||||||
"similarity": similarity,
|
"similarity": similarity,
|
||||||
"correct": label in expected_logos,
|
"correct": is_correct,
|
||||||
})
|
})
|
||||||
|
|
||||||
else: # multi-ref
|
else: # multi-ref
|
||||||
@ -507,16 +563,21 @@ def main():
|
|||||||
label, similarity, num_matching = match_result
|
label, similarity, num_matching = match_result
|
||||||
matched_logos.add(label)
|
matched_logos.add(label)
|
||||||
|
|
||||||
if label in expected_logos:
|
is_correct = label in expected_logos
|
||||||
|
if is_correct:
|
||||||
true_positives += 1
|
true_positives += 1
|
||||||
|
if args.similarity_details:
|
||||||
|
similarity_details["true_positive_sims"].append(similarity)
|
||||||
else:
|
else:
|
||||||
false_positives += 1
|
false_positives += 1
|
||||||
|
if args.similarity_details:
|
||||||
|
similarity_details["false_positive_sims"].append(similarity)
|
||||||
|
|
||||||
results.append({
|
results.append({
|
||||||
"test_image": test_filename,
|
"test_image": test_filename,
|
||||||
"matched_logo": label,
|
"matched_logo": label,
|
||||||
"similarity": similarity,
|
"similarity": similarity,
|
||||||
"correct": label in expected_logos,
|
"correct": is_correct,
|
||||||
})
|
})
|
||||||
|
|
||||||
# Count missed detections (false negatives)
|
# Count missed detections (false negatives)
|
||||||
@ -524,6 +585,15 @@ def main():
|
|||||||
false_negatives += len(missed)
|
false_negatives += len(missed)
|
||||||
|
|
||||||
for missed_logo in missed:
|
for missed_logo in missed:
|
||||||
|
# Track best similarity for missed logos (if we have detections)
|
||||||
|
if args.similarity_details and detections:
|
||||||
|
best_sim_for_missed = 0
|
||||||
|
for detection in detections:
|
||||||
|
for ref_emb in multi_ref_embeddings.get(missed_logo, []):
|
||||||
|
sim = detector.compare_embeddings(detection["embedding"], ref_emb)
|
||||||
|
best_sim_for_missed = max(best_sim_for_missed, sim)
|
||||||
|
similarity_details["missed_best_sims"].append(best_sim_for_missed)
|
||||||
|
|
||||||
results.append({
|
results.append({
|
||||||
"test_image": test_filename,
|
"test_image": test_filename,
|
||||||
"matched_logo": None,
|
"matched_logo": None,
|
||||||
@ -593,6 +663,10 @@ def main():
|
|||||||
|
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Print similarity distribution details if requested
|
||||||
|
if args.similarity_details:
|
||||||
|
print_similarity_details(similarity_details, args.threshold)
|
||||||
|
|
||||||
# Write results to file if requested
|
# Write results to file if requested
|
||||||
if args.output_file:
|
if args.output_file:
|
||||||
write_results_to_file(
|
write_results_to_file(
|
||||||
@ -612,6 +686,116 @@ def main():
|
|||||||
print(f"\nResults appended to: {args.output_file}")
|
print(f"\nResults appended to: {args.output_file}")
|
||||||
|
|
||||||
|
|
||||||
|
def print_similarity_details(details: dict, threshold: float) -> None:
    """Print a similarity distribution report to stdout.

    Args:
        details: Mapping with the keys ``"true_positive_sims"``,
            ``"false_positive_sims"``, ``"missed_best_sims"``,
            ``"all_positive_sims"``, ``"all_negative_sims"`` (lists of
            similarity scores) and ``"detection_details"`` (list of dicts
            with ``"image"``, ``"expected_logos"``, ``"top_5_matches"`` and
            ``"detr_score"``).
        threshold: Similarity threshold used for the run; every list is
            summarised as counts above/below this value.

    The report has three parts: per-list summary statistics, an overlap
    analysis between true-positive and false-positive scores (including a
    suggested threshold), and the first 20 per-detection breakdowns.
    """
    import statistics

    print("\n" + "=" * 60)
    print("SIMILARITY DISTRIBUTION ANALYSIS")
    print("=" * 60)

    def compute_stats(values, name):
        """Print min/max/mean/median, rough percentiles and threshold counts."""
        if not values:
            print(f"\n{name}: No data")
            return
        count = len(values)
        print(f"\n{name} (n={count}):")
        print(f" Min: {min(values):.4f}")
        print(f" Max: {max(values):.4f}")
        print(f" Mean: {statistics.mean(values):.4f}")
        if count > 1:
            # stdev needs at least two data points.
            print(f" StdDev: {statistics.stdev(values):.4f}")
        print(f" Median: {statistics.median(values):.4f}")

        # Index-based percentiles; small samples fall back to min/max.
        ordered = sorted(values)
        p10 = ordered[int(count * 0.10)] if count > 10 else ordered[0]
        p25 = ordered[int(count * 0.25)] if count > 4 else ordered[0]
        p75 = ordered[int(count * 0.75)] if count > 4 else ordered[-1]
        p90 = ordered[int(count * 0.90)] if count > 10 else ordered[-1]
        print(f" P10: {p10:.4f}")
        print(f" P25: {p25:.4f}")
        print(f" P75: {p75:.4f}")
        print(f" P90: {p90:.4f}")

        # Count above/below threshold
        above = sum(1 for v in values if v >= threshold)
        below = sum(1 for v in values if v < threshold)
        print(f" Above threshold ({threshold}): {above} ({100*above/count:.1f}%)")
        print(f" Below threshold ({threshold}): {below} ({100*below/count:.1f}%)")

    # Print distribution stats
    compute_stats(details["true_positive_sims"], "TRUE POSITIVE similarities (correct matches)")
    compute_stats(details["false_positive_sims"], "FALSE POSITIVE similarities (wrong matches)")
    compute_stats(details["missed_best_sims"], "MISSED LOGO best similarities (false negatives)")
    compute_stats(details["all_positive_sims"], "ALL similarities to CORRECT logos (per-ref)")
    compute_stats(details["all_negative_sims"], "ALL similarities to WRONG logos (per-ref)")

    # Overlap analysis (only meaningful when both kinds of match exist)
    tp_sims = details["true_positive_sims"]
    fp_sims = details["false_positive_sims"]
    if tp_sims and fp_sims:
        print("\n" + "-" * 40)
        print("OVERLAP ANALYSIS:")
        tp_min, tp_max = min(tp_sims), max(tp_sims)
        fp_min, fp_max = min(fp_sims), max(fp_sims)
        print(f" True Positives range: [{tp_min:.4f}, {tp_max:.4f}]")
        print(f" False Positives range: [{fp_min:.4f}, {fp_max:.4f}]")

        # The two ranges overlap when the larger minimum lies below the
        # smaller maximum.
        overlap_min = max(tp_min, fp_min)
        overlap_max = min(tp_max, fp_max)
        if overlap_min < overlap_max:
            print(f" OVERLAP REGION: [{overlap_min:.4f}, {overlap_max:.4f}]")
            tp_in_overlap = sum(1 for v in tp_sims if overlap_min <= v <= overlap_max)
            fp_in_overlap = sum(1 for v in fp_sims if overlap_min <= v <= overlap_max)
            print(f" TPs in overlap: {tp_in_overlap} ({100*tp_in_overlap/len(tp_sims):.1f}%)")
            print(f" FPs in overlap: {fp_in_overlap} ({100*fp_in_overlap/len(fp_sims):.1f}%)")
        else:
            print(" NO OVERLAP - distributions are separable!")

        # Suggest optimal threshold
        best_thresh, best_f1 = _best_f1_threshold(tp_sims, fp_sims, threshold)
        print(f"\n SUGGESTED OPTIMAL THRESHOLD: {best_thresh:.4f}")
        print(f" (would give F1 = {best_f1:.4f} on this data)")

    # Print sample detection details
    det_details = details["detection_details"]
    if det_details:
        print("\n" + "-" * 40)
        print(f"SAMPLE DETECTION DETAILS (first 20 of {len(det_details)}):")
        for i, det in enumerate(det_details[:20]):
            expected = det["expected_logos"]
            top5 = det["top_5_matches"]
            print(f"\n [{i+1}] Image: {det['image']}")
            print(f" Expected: {expected if expected else '(none)'}")
            print(f" DETR score: {det['detr_score']:.3f}")
            print(" Top 5 matches:")
            for logo, sim in top5:
                marker = " <-- CORRECT" if logo in expected else ""
                print(f" {sim:.4f} {logo}{marker}")

    print("\n" + "=" * 60)


def _best_f1_threshold(tp_sims, fp_sims, fallback):
    """Return ``(threshold, f1)`` maximising F1 over the observed scores.

    Every distinct score in ``tp_sims`` or ``fp_sims`` is tried as a cut-off
    in ascending order; the lowest cut-off reaching the best F1 wins.
    ``fallback`` is returned as the threshold when no candidate yields a
    positive F1.

    Recall is measured against ``len(tp_sims)`` only, i.e. against the
    matches that were recorded, not against logos that were missed entirely.
    """
    from bisect import bisect_left

    tp_sorted = sorted(tp_sims)
    fp_sorted = sorted(fp_sims)
    total_tp = len(tp_sorted)
    total_fp = len(fp_sorted)

    best_thresh = fallback
    best_f1 = 0
    unset = object()
    previous = unset
    for thresh in sorted(tp_sorted + fp_sorted):
        # Equal scores give equal counts, so evaluate each value once.
        if previous is not unset and thresh == previous:
            continue
        previous = thresh
        # bisect_left gives the number of scores strictly below thresh, so
        # the remainder is the number of scores >= thresh.
        tp_above = total_tp - bisect_left(tp_sorted, thresh)
        fp_above = total_fp - bisect_left(fp_sorted, thresh)
        kept = tp_above + fp_above
        prec = tp_above / kept if kept > 0 else 0
        rec = tp_above / total_tp if total_tp > 0 else 0
        f1 = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0
        if f1 > best_f1:
            best_f1 = f1
            best_thresh = thresh
    return best_thresh, best_f1
|
||||||
|
|
||||||
|
|
||||||
def write_results_to_file(
|
def write_results_to_file(
|
||||||
output_path: Path,
|
output_path: Path,
|
||||||
args,
|
args,
|
||||||
|
|||||||
@ -256,6 +256,7 @@ def main():
|
|||||||
test_split=config.test_split,
|
test_split=config.test_split,
|
||||||
seed=config.seed,
|
seed=config.seed,
|
||||||
augmentation_strength=config.augmentation_strength,
|
augmentation_strength=config.augmentation_strength,
|
||||||
|
split_level=getattr(config, 'split_level', 'logo'),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create trainer
|
# Create trainer
|
||||||
|
|||||||
@ -20,7 +20,8 @@ class TrainingConfig:
|
|||||||
reference_dir: str = "reference_logos"
|
reference_dir: str = "reference_logos"
|
||||||
db_path: str = "test_data_mapping.db"
|
db_path: str = "test_data_mapping.db"
|
||||||
|
|
||||||
# Data split ratios
|
# Data split configuration
|
||||||
|
split_level: str = "logo" # "logo" for brand-level, "image" for image-level
|
||||||
train_split: float = 0.7
|
train_split: float = 0.7
|
||||||
val_split: float = 0.15
|
val_split: float = 0.15
|
||||||
test_split: float = 0.15
|
test_split: float = 0.15
|
||||||
|
|||||||
@ -84,7 +84,7 @@ class LogoDataset:
|
|||||||
"""
|
"""
|
||||||
Manages logo data from the SQLite database.
|
Manages logo data from the SQLite database.
|
||||||
|
|
||||||
Handles loading logo-to-image mappings and splitting by logo brand.
|
Handles loading logo-to-image mappings and splitting by logo brand or image.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -95,19 +95,57 @@ class LogoDataset:
|
|||||||
val_split: float = 0.15,
|
val_split: float = 0.15,
|
||||||
test_split: float = 0.15,
|
test_split: float = 0.15,
|
||||||
seed: int = 42,
|
seed: int = 42,
|
||||||
|
split_level: str = "logo",
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Initialize the logo dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path to SQLite database
|
||||||
|
reference_dir: Directory containing reference logo images
|
||||||
|
train_split: Fraction for training
|
||||||
|
val_split: Fraction for validation
|
||||||
|
test_split: Fraction for testing
|
||||||
|
seed: Random seed for reproducibility
|
||||||
|
split_level: "logo" for brand-level splits (test on unseen brands),
|
||||||
|
"image" for image-level splits (test on unseen images
|
||||||
|
from seen brands)
|
||||||
|
"""
|
||||||
self.db_path = Path(db_path)
|
self.db_path = Path(db_path)
|
||||||
self.reference_dir = Path(reference_dir)
|
self.reference_dir = Path(reference_dir)
|
||||||
self.seed = seed
|
self.seed = seed
|
||||||
|
self.split_level = split_level
|
||||||
|
|
||||||
# Load logo-to-images mapping from database
|
# Load logo-to-images mapping from database
|
||||||
self.logo_to_images = self._load_logo_mappings()
|
self.logo_to_images = self._load_logo_mappings()
|
||||||
self.all_logos = list(self.logo_to_images.keys())
|
self.all_logos = list(self.logo_to_images.keys())
|
||||||
|
|
||||||
# Create logo-level splits
|
if split_level == "logo":
|
||||||
|
# Logo-level splits: test logos are completely unseen brands
|
||||||
self.train_logos, self.val_logos, self.test_logos = self._split_logos(
|
self.train_logos, self.val_logos, self.test_logos = self._split_logos(
|
||||||
train_split, val_split, test_split
|
train_split, val_split, test_split
|
||||||
)
|
)
|
||||||
|
# For logo-level splits, each split has its own logos
|
||||||
|
self.train_logo_to_images = {
|
||||||
|
l: self.logo_to_images[l] for l in self.train_logos
|
||||||
|
}
|
||||||
|
self.val_logo_to_images = {
|
||||||
|
l: self.logo_to_images[l] for l in self.val_logos
|
||||||
|
}
|
||||||
|
self.test_logo_to_images = {
|
||||||
|
l: self.logo_to_images[l] for l in self.test_logos
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Image-level splits: all logos present in all splits, different images
|
||||||
|
(
|
||||||
|
self.train_logo_to_images,
|
||||||
|
self.val_logo_to_images,
|
||||||
|
self.test_logo_to_images,
|
||||||
|
) = self._split_images(train_split, val_split, test_split)
|
||||||
|
# All logos are in all splits
|
||||||
|
self.train_logos = list(self.train_logo_to_images.keys())
|
||||||
|
self.val_logos = list(self.val_logo_to_images.keys())
|
||||||
|
self.test_logos = list(self.test_logo_to_images.keys())
|
||||||
|
|
||||||
def _load_logo_mappings(self) -> Dict[str, List[Path]]:
|
def _load_logo_mappings(self) -> Dict[str, List[Path]]:
|
||||||
"""Load logo name to image paths mapping from database."""
|
"""Load logo name to image paths mapping from database."""
|
||||||
@ -151,21 +189,74 @@ class LogoDataset:
|
|||||||
|
|
||||||
return train_logos, val_logos, test_logos
|
return train_logos, val_logos, test_logos
|
||||||
|
|
||||||
def get_split_info(self) -> Dict[str, int]:
|
def _split_images(
|
||||||
|
self,
|
||||||
|
train_split: float,
|
||||||
|
val_split: float,
|
||||||
|
test_split: float,
|
||||||
|
) -> Tuple[Dict[str, List[Path]], Dict[str, List[Path]], Dict[str, List[Path]]]:
|
||||||
|
"""
|
||||||
|
Split images within each logo brand for train/val/test.
|
||||||
|
|
||||||
|
Each logo brand will have images in all splits, allowing the model
|
||||||
|
to see some examples of each brand during training.
|
||||||
|
"""
|
||||||
|
random.seed(self.seed)
|
||||||
|
|
||||||
|
train_logo_to_images: Dict[str, List[Path]] = {}
|
||||||
|
val_logo_to_images: Dict[str, List[Path]] = {}
|
||||||
|
test_logo_to_images: Dict[str, List[Path]] = {}
|
||||||
|
|
||||||
|
for logo, images in self.logo_to_images.items():
|
||||||
|
# Shuffle images for this logo
|
||||||
|
shuffled_images = images.copy()
|
||||||
|
random.shuffle(shuffled_images)
|
||||||
|
|
||||||
|
n = len(shuffled_images)
|
||||||
|
if n == 1:
|
||||||
|
# Only one image: put in train only
|
||||||
|
train_logo_to_images[logo] = shuffled_images
|
||||||
|
continue
|
||||||
|
elif n == 2:
|
||||||
|
# Two images: one train, one val
|
||||||
|
train_logo_to_images[logo] = [shuffled_images[0]]
|
||||||
|
val_logo_to_images[logo] = [shuffled_images[1]]
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Normal split for 3+ images
|
||||||
|
train_end = max(1, int(n * train_split))
|
||||||
|
val_end = train_end + max(1, int(n * val_split))
|
||||||
|
|
||||||
|
train_images = shuffled_images[:train_end]
|
||||||
|
val_images = shuffled_images[train_end:val_end]
|
||||||
|
test_images = shuffled_images[val_end:]
|
||||||
|
|
||||||
|
# Ensure at least one image in train
|
||||||
|
if train_images:
|
||||||
|
train_logo_to_images[logo] = train_images
|
||||||
|
if val_images:
|
||||||
|
val_logo_to_images[logo] = val_images
|
||||||
|
if test_images:
|
||||||
|
test_logo_to_images[logo] = test_images
|
||||||
|
|
||||||
|
return train_logo_to_images, val_logo_to_images, test_logo_to_images
|
||||||
|
|
||||||
|
def get_split_info(self) -> Dict[str, object]:
    """Return information about the splits.

    Returns:
        A dict with ``"split_level"`` (the configured split level string),
        ``"total_logos"``, the number of logos per split
        (``"train_logos"``, ``"val_logos"``, ``"test_logos"``) and the
        number of images per split (``"train_images"``, ``"val_images"``,
        ``"test_images"``). Image counts come from the per-split
        logo-to-images mappings.
    """

    def count_images(mapping) -> int:
        # Total number of images across all logos of one split.
        return sum(len(imgs) for imgs in mapping.values())

    return {
        "split_level": self.split_level,
        "total_logos": len(self.all_logos),
        "train_logos": len(self.train_logos),
        "val_logos": len(self.val_logos),
        "test_logos": len(self.test_logos),
        "train_images": count_images(self.train_logo_to_images),
        "val_images": count_images(self.val_logo_to_images),
        "test_images": count_images(self.test_logo_to_images),
    }
|
||||||
|
|
||||||
@ -205,29 +296,33 @@ class LogoContrastiveDataset(Dataset):
|
|||||||
self.transform = transform
|
self.transform = transform
|
||||||
self.batches_per_epoch = batches_per_epoch
|
self.batches_per_epoch = batches_per_epoch
|
||||||
|
|
||||||
# Get logos for this split
|
# Get logos and their images for this split
|
||||||
|
# This respects both logo-level and image-level splits
|
||||||
if split == "train":
|
if split == "train":
|
||||||
self.logos = logo_data.train_logos
|
self.logos = logo_data.train_logos
|
||||||
|
self.logo_to_images = logo_data.train_logo_to_images
|
||||||
elif split == "val":
|
elif split == "val":
|
||||||
self.logos = logo_data.val_logos
|
self.logos = logo_data.val_logos
|
||||||
|
self.logo_to_images = logo_data.val_logo_to_images
|
||||||
else:
|
else:
|
||||||
self.logos = logo_data.test_logos
|
self.logos = logo_data.test_logos
|
||||||
|
self.logo_to_images = logo_data.test_logo_to_images
|
||||||
|
|
||||||
# Filter logos with enough samples
|
# Filter logos with enough samples for this split
|
||||||
self.valid_logos = [
|
self.valid_logos = [
|
||||||
logo for logo in self.logos
|
logo for logo in self.logos
|
||||||
if len(logo_data.logo_to_images[logo]) >= samples_per_logo
|
if logo in self.logo_to_images and len(self.logo_to_images[logo]) >= samples_per_logo
|
||||||
]
|
]
|
||||||
|
|
||||||
# For logos with fewer samples, we'll use with replacement
|
# For logos with fewer samples, we'll use with replacement
|
||||||
self.logos_needing_replacement = [
|
self.logos_needing_replacement = [
|
||||||
logo for logo in self.logos
|
logo for logo in self.logos
|
||||||
if len(logo_data.logo_to_images[logo]) < samples_per_logo
|
if logo in self.logo_to_images and len(self.logo_to_images[logo]) < samples_per_logo
|
||||||
]
|
]
|
||||||
|
|
||||||
# Create label mapping
|
# Create label mapping (use all logos from the full dataset for consistent labels)
|
||||||
self.logo_to_label = {
|
self.logo_to_label = {
|
||||||
logo: idx for idx, logo in enumerate(self.logos)
|
logo: idx for idx, logo in enumerate(logo_data.all_logos)
|
||||||
}
|
}
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
@ -244,12 +339,13 @@ class LogoContrastiveDataset(Dataset):
|
|||||||
images = []
|
images = []
|
||||||
labels = []
|
labels = []
|
||||||
|
|
||||||
# Sample K logos for this batch
|
# Sample K logos for this batch (only from logos that have images in this split)
|
||||||
k = min(self.logos_per_batch, len(self.logos))
|
available_logos = [l for l in self.logos if l in self.logo_to_images]
|
||||||
batch_logos = random.sample(self.logos, k)
|
k = min(self.logos_per_batch, len(available_logos))
|
||||||
|
batch_logos = random.sample(available_logos, k)
|
||||||
|
|
||||||
for logo in batch_logos:
|
for logo in batch_logos:
|
||||||
logo_images = self.logo_data.logo_to_images[logo]
|
logo_images = self.logo_to_images[logo]
|
||||||
|
|
||||||
# Sample M images for this logo
|
# Sample M images for this logo
|
||||||
if len(logo_images) >= self.samples_per_logo:
|
if len(logo_images) >= self.samples_per_logo:
|
||||||
@ -353,6 +449,7 @@ def create_dataloaders(
|
|||||||
seed: int = 42,
|
seed: int = 42,
|
||||||
augmentation_strength: str = "medium",
|
augmentation_strength: str = "medium",
|
||||||
batches_per_epoch: int = 1000,
|
batches_per_epoch: int = 1000,
|
||||||
|
split_level: str = "logo",
|
||||||
) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
|
) -> Tuple[DataLoader, DataLoader, Optional[DataLoader]]:
|
||||||
"""
|
"""
|
||||||
Create train, validation, and optionally test dataloaders.
|
Create train, validation, and optionally test dataloaders.
|
||||||
@ -370,6 +467,7 @@ def create_dataloaders(
|
|||||||
seed: Random seed
|
seed: Random seed
|
||||||
augmentation_strength: "light", "medium", or "strong"
|
augmentation_strength: "light", "medium", or "strong"
|
||||||
batches_per_epoch: Number of batches per training epoch
|
batches_per_epoch: Number of batches per training epoch
|
||||||
|
split_level: "logo" for brand-level splits, "image" for image-level splits
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (train_loader, val_loader, test_loader)
|
Tuple of (train_loader, val_loader, test_loader)
|
||||||
@ -382,11 +480,13 @@ def create_dataloaders(
|
|||||||
val_split=val_split,
|
val_split=val_split,
|
||||||
test_split=test_split,
|
test_split=test_split,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
split_level=split_level,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Print split info
|
# Print split info
|
||||||
split_info = logo_data.get_split_info()
|
split_info = logo_data.get_split_info()
|
||||||
print(f"Dataset loaded:")
|
print(f"Dataset loaded:")
|
||||||
|
print(f" Split level: {split_info['split_level']}")
|
||||||
print(f" Total logos: {split_info['total_logos']}")
|
print(f" Total logos: {split_info['total_logos']}")
|
||||||
print(f" Train: {split_info['train_logos']} logos, {split_info['train_images']} images")
|
print(f" Train: {split_info['train_logos']} logos, {split_info['train_images']} images")
|
||||||
print(f" Val: {split_info['val_logos']} logos, {split_info['val_images']} images")
|
print(f" Val: {split_info['val_logos']} logos, {split_info['val_images']} images")
|
||||||
|
|||||||
@ -250,33 +250,49 @@ class LogoFineTunedCLIP(nn.Module):
|
|||||||
# Load base CLIP model
|
# Load base CLIP model
|
||||||
clip_model = CLIPModel.from_pretrained(base_model)
|
clip_model = CLIPModel.from_pretrained(base_model)
|
||||||
|
|
||||||
# Create model instance
|
# Check if we need to load LoRA weights
|
||||||
|
if config.get("peft_applied", False) and PEFT_AVAILABLE:
|
||||||
|
# Create model WITHOUT LoRA (lora_r=0) - we'll load LoRA weights separately
|
||||||
model = cls(
|
model = cls(
|
||||||
vision_model=clip_model.vision_model,
|
vision_model=clip_model.vision_model,
|
||||||
lora_r=config.get("lora_r", 0),
|
lora_r=0, # Don't apply LoRA in constructor
|
||||||
lora_alpha=config.get("lora_alpha", 1),
|
lora_alpha=config.get("lora_alpha", 1),
|
||||||
freeze_layers=config.get("freeze_layers", 12),
|
freeze_layers=config.get("freeze_layers", 12),
|
||||||
add_projection_head=config.get("add_projection_head", True),
|
add_projection_head=config.get("add_projection_head", True),
|
||||||
use_gradient_checkpointing=False, # Not needed for inference
|
use_gradient_checkpointing=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load weights
|
# Load LoRA weights from checkpoint
|
||||||
if config.get("peft_applied", False) and PEFT_AVAILABLE:
|
|
||||||
# Load LoRA weights
|
|
||||||
lora_path = model_path / "vision_lora"
|
lora_path = model_path / "vision_lora"
|
||||||
if lora_path.exists():
|
if lora_path.exists():
|
||||||
model.vision_model = PeftModel.from_pretrained(
|
model.vision_model = PeftModel.from_pretrained(
|
||||||
model.vision_model, lora_path
|
model.vision_model, lora_path
|
||||||
)
|
)
|
||||||
|
model.peft_applied = True
|
||||||
|
model.lora_r = config.get("lora_r", 16)
|
||||||
|
|
||||||
# Load projection head
|
# Load projection head
|
||||||
proj_path = model_path / "projection_head.bin"
|
proj_path = model_path / "projection_head.bin"
|
||||||
if proj_path.exists():
|
if proj_path.exists():
|
||||||
model.projection.load_state_dict(torch.load(proj_path))
|
model.projection.load_state_dict(
|
||||||
|
torch.load(proj_path, map_location="cpu")
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Load full model state
|
# No LoRA - create model and load full state
|
||||||
|
model = cls(
|
||||||
|
vision_model=clip_model.vision_model,
|
||||||
|
lora_r=0,
|
||||||
|
lora_alpha=config.get("lora_alpha", 1),
|
||||||
|
freeze_layers=config.get("freeze_layers", 12),
|
||||||
|
add_projection_head=config.get("add_projection_head", True),
|
||||||
|
use_gradient_checkpointing=False,
|
||||||
|
)
|
||||||
|
|
||||||
weights_path = model_path / "pytorch_model.bin"
|
weights_path = model_path / "pytorch_model.bin"
|
||||||
if weights_path.exists():
|
if weights_path.exists():
|
||||||
model.load_state_dict(torch.load(weights_path))
|
model.load_state_dict(
|
||||||
|
torch.load(weights_path, map_location="cpu")
|
||||||
|
)
|
||||||
|
|
||||||
if device is not None:
|
if device is not None:
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|||||||
@ -169,16 +169,11 @@ class Trainer:
|
|||||||
"val_neg_sim": val_metrics["mean_neg_sim"],
|
"val_neg_sim": val_metrics["mean_neg_sim"],
|
||||||
})
|
})
|
||||||
|
|
||||||
# Checkpointing based on separation (primary) or loss (secondary)
|
# Checkpointing based on separation (gap between pos and neg similarity)
|
||||||
improved = False
|
# This is the key metric for contrastive learning quality
|
||||||
if val_metrics["separation"] > self.best_val_separation + self.config.min_delta:
|
if val_metrics["separation"] > self.best_val_separation + self.config.min_delta:
|
||||||
self.best_val_separation = val_metrics["separation"]
|
self.best_val_separation = val_metrics["separation"]
|
||||||
improved = True
|
self.best_val_loss = val_metrics["loss"] # Track for reference
|
||||||
elif val_metrics["loss"] < self.best_val_loss - self.config.min_delta:
|
|
||||||
self.best_val_loss = val_metrics["loss"]
|
|
||||||
improved = True
|
|
||||||
|
|
||||||
if improved:
|
|
||||||
self.patience_counter = 0
|
self.patience_counter = 0
|
||||||
self._save_checkpoint("best.pt")
|
self._save_checkpoint("best.pt")
|
||||||
self.logger.info("New best model saved!")
|
self.logger.info("New best model saved!")
|
||||||
|
|||||||
Reference in New Issue
Block a user