Compare commits
2 Commits
197e007591
...
41c75356d9
| Author | SHA1 | Date | |
|---|---|---|---|
| 41c75356d9 | |||
| 41bc0c701f |
@ -394,6 +394,47 @@ class DetectLogosDETR:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def find_all_matches(
|
||||||
|
self,
|
||||||
|
detected_embedding: torch.Tensor,
|
||||||
|
reference_embeddings: List[Tuple[str, torch.Tensor]],
|
||||||
|
similarity_threshold: float = 0.7,
|
||||||
|
) -> List[Tuple[str, float]]:
|
||||||
|
"""
|
||||||
|
Find all matching reference logos above the similarity threshold.
|
||||||
|
|
||||||
|
Unlike find_best_match, this returns ALL logos that have at least one
|
||||||
|
reference above threshold. Each unique logo is returned once with its
|
||||||
|
highest similarity score.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
detected_embedding: CLIP embedding from detected logo region
|
||||||
|
reference_embeddings: List of (label, embedding) tuples for reference logos
|
||||||
|
similarity_threshold: Minimum similarity to consider a match (0-1)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of (label, similarity) tuples for all matches above threshold,
|
||||||
|
sorted by similarity descending. Each logo appears at most once.
|
||||||
|
"""
|
||||||
|
if not reference_embeddings:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Track best similarity for each logo
|
||||||
|
logo_best_sim: Dict[str, float] = {}
|
||||||
|
|
||||||
|
for label, ref_embedding in reference_embeddings:
|
||||||
|
similarity = self.compare_embeddings(detected_embedding, ref_embedding)
|
||||||
|
|
||||||
|
if similarity >= similarity_threshold:
|
||||||
|
if label not in logo_best_sim or similarity > logo_best_sim[label]:
|
||||||
|
logo_best_sim[label] = similarity
|
||||||
|
|
||||||
|
# Convert to list and sort by similarity descending
|
||||||
|
matches = [(label, sim) for label, sim in logo_best_sim.items()]
|
||||||
|
matches.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
return matches
|
||||||
|
|
||||||
def find_best_match_multi_ref(
|
def find_best_match_multi_ref(
|
||||||
self,
|
self,
|
||||||
detected_embedding: torch.Tensor,
|
detected_embedding: torch.Tensor,
|
||||||
|
|||||||
@ -78,6 +78,41 @@ match = detector.find_best_match(
|
|||||||
**Returns:**
|
**Returns:**
|
||||||
- Tuple of (label, similarity) for best match, or None if no match above threshold
|
- Tuple of (label, similarity) for best match, or None if no match above threshold
|
||||||
|
|
||||||
|
#### `find_all_matches()` - Find all matching reference logos
|
||||||
|
|
||||||
|
Returns ALL logos that have at least one reference above the similarity threshold. Each unique logo appears once with its highest similarity score.
|
||||||
|
|
||||||
|
```python
|
||||||
|
matches = detector.find_all_matches(
|
||||||
|
detected_embedding,
|
||||||
|
reference_embeddings,
|
||||||
|
similarity_threshold=0.7
|
||||||
|
)
|
||||||
|
# Returns: [(label1, similarity1), (label2, similarity2), ...]
|
||||||
|
```
|
||||||
|
|
||||||
|
**Parameters:**
|
||||||
|
- `detected_embedding`: CLIP embedding from detected logo region
|
||||||
|
- `reference_embeddings`: List of (label, embedding) tuples for reference logos
|
||||||
|
- `similarity_threshold`: Minimum similarity to consider a match (0-1, default: 0.7)
|
||||||
|
|
||||||
|
**Returns:**
|
||||||
|
- List of (label, similarity) tuples for all matches above threshold, sorted by similarity descending
|
||||||
|
- Each logo appears at most once (with its highest matching reference)
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```python
|
||||||
|
# Get all logos that match a detection
|
||||||
|
all_matches = detector.find_all_matches(
|
||||||
|
detection["embedding"],
|
||||||
|
reference_embeddings,
|
||||||
|
similarity_threshold=0.7
|
||||||
|
)
|
||||||
|
|
||||||
|
for label, similarity in all_matches:
|
||||||
|
print(f"Matched: {label} (similarity: {similarity:.3f})")
|
||||||
|
```
|
||||||
|
|
||||||
#### `detect_and_match()` - One-step detection and matching
|
#### `detect_and_match()` - One-step detection and matching
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
|||||||
@ -39,8 +39,8 @@ The system uses a two-stage pipeline:
|
|||||||
|
|
||||||
| Parameter | Default | Description |
|
| Parameter | Default | Description |
|
||||||
|-----------|---------|-------------|
|
|-----------|---------|-------------|
|
||||||
| `--matching-method` | margin | Matching method: `margin` or `multi-ref` |
|
| `--matching-method` | margin | Matching method: `simple`, `margin`, or `multi-ref` |
|
||||||
| `--margin` | 0.05 | Required margin between best and second-best match (applies to both methods) |
|
| `--margin` | 0.05 | Required margin between best and second-best match (applies to `margin` and `multi-ref`) |
|
||||||
|
|
||||||
#### Multi-Ref Method Parameters (when `--matching-method multi-ref`)
|
#### Multi-Ref Method Parameters (when `--matching-method multi-ref`)
|
||||||
|
|
||||||
@ -193,11 +193,11 @@ This ensures cosine similarity is computed correctly and scores fall in the rang
|
|||||||
|
|
||||||
| Method | Test Script Option | Key Feature |
|
| Method | Test Script Option | Key Feature |
|
||||||
|--------|-------------------|-------------|
|
|--------|-------------------|-------------|
|
||||||
| `find_best_match` | N/A (library only) | Returns highest similarity above threshold |
|
| `find_all_matches` | `--matching-method simple` | Returns ALL logos above threshold (baseline, most permissive) |
|
||||||
| `find_best_match_with_margin` | `--matching-method margin` | Requires margin over second-best match |
|
| `find_best_match_with_margin` | `--matching-method margin` | Requires margin over second-best match |
|
||||||
| `find_best_match_multi_ref` | `--matching-method multi-ref` | Aggregates scores across reference images |
|
| `find_best_match_multi_ref` | `--matching-method multi-ref` | Aggregates scores across reference images |
|
||||||
|
|
||||||
The test script supports both `margin` and `multi-ref` matching methods via the `--matching-method` parameter.
|
The test script supports `simple`, `margin`, and `multi-ref` matching methods via the `--matching-method` parameter.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@ -242,13 +242,14 @@ Input Image
|
|||||||
▼
|
▼
|
||||||
┌─────────────────────────────────────┐
|
┌─────────────────────────────────────┐
|
||||||
│ Matching (selectable method) │
|
│ Matching (selectable method) │
|
||||||
│ ┌───────────────┬────────────────┐ │
|
│ ┌─────────┬─────────┬────────────┐ │
|
||||||
│ │ margin │ multi-ref │ │
|
│ │ simple │ margin │ multi-ref │ │
|
||||||
│ ├───────────────┼────────────────┤ │
|
│ ├─────────┼─────────┼────────────┤ │
|
||||||
│ │ Require margin│ Aggregate │ │
|
│ │ All │ Require │ Aggregate │ │
|
||||||
│ │ over 2nd best │ across refs │ │
|
│ │ matches │ margin │ across │ │
|
||||||
│ │ match │ (mean or max) │ │
|
│ │ above │ over │ refs │ │
|
||||||
│ └───────────────┴────────────────┘ │
|
│ │ thresh │ 2nd best│ (mean/max) │ │
|
||||||
|
│ └─────────┴─────────┴────────────┘ │
|
||||||
└─────────────────────────────────────┘
|
└─────────────────────────────────────┘
|
||||||
│
|
│
|
||||||
▼
|
▼
|
||||||
@ -259,6 +260,15 @@ Matched Logo Labels
|
|||||||
|
|
||||||
## Tuning Recommendations
|
## Tuning Recommendations
|
||||||
|
|
||||||
|
### For Simple Matching (`--matching-method simple`)
|
||||||
|
|
||||||
|
| Goal | Adjustments |
|
||||||
|
|------|-------------|
|
||||||
|
| **Reduce false positives** | Increase `--threshold` (only tuning option for simple method) |
|
||||||
|
| **Reduce false negatives** | Decrease `--threshold` |
|
||||||
|
|
||||||
|
Note: Simple matching is primarily used as a baseline. For production use, consider `margin` or `multi-ref`.
|
||||||
|
|
||||||
### For Margin-Based Matching (`--matching-method margin`)
|
### For Margin-Based Matching (`--matching-method margin`)
|
||||||
|
|
||||||
| Goal | Adjustments |
|
| Goal | Adjustments |
|
||||||
@ -287,6 +297,9 @@ Matched Logo Labels
|
|||||||
## Example Usage
|
## Example Usage
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
# Simple matching (baseline - all matches above threshold)
|
||||||
|
python test_logo_detection.py -n 20 --matching-method simple --threshold 0.70
|
||||||
|
|
||||||
# Default margin-based matching
|
# Default margin-based matching
|
||||||
python test_logo_detection.py -n 20 --threshold 0.75 --margin 0.05
|
python test_logo_detection.py -n 20 --threshold 0.75 --margin 0.05
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
#
|
#
|
||||||
# Run logo detection tests with all three matching methods and save results.
|
# Run logo detection tests with all four matching methods and save results.
|
||||||
#
|
#
|
||||||
|
|
||||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
@ -16,10 +16,19 @@ MIN_MATCHING_REFS=3
|
|||||||
# Use a fixed seed for reproducibility across methods
|
# Use a fixed seed for reproducibility across methods
|
||||||
SEED=42
|
SEED=42
|
||||||
|
|
||||||
|
# Clear output file and write header
|
||||||
echo "Logo Detection Comparison Tests" > "$OUTPUT_FILE"
|
echo "Logo Detection Comparison Tests" > "$OUTPUT_FILE"
|
||||||
echo "================================" >> "$OUTPUT_FILE"
|
echo "================================" >> "$OUTPUT_FILE"
|
||||||
echo "Date: $(date)" >> "$OUTPUT_FILE"
|
echo "Date: $(date)" >> "$OUTPUT_FILE"
|
||||||
echo "" >> "$OUTPUT_FILE"
|
echo "" >> "$OUTPUT_FILE"
|
||||||
|
echo "Common Parameters:" >> "$OUTPUT_FILE"
|
||||||
|
echo " Reference logos: $NUM_LOGOS" >> "$OUTPUT_FILE"
|
||||||
|
echo " Refs per logo: $REFS_PER_LOGO" >> "$OUTPUT_FILE"
|
||||||
|
echo " Positive samples: $POSITIVE_SAMPLES" >> "$OUTPUT_FILE"
|
||||||
|
echo " Negative samples: $NEGATIVE_SAMPLES" >> "$OUTPUT_FILE"
|
||||||
|
echo " Min matching refs: $MIN_MATCHING_REFS" >> "$OUTPUT_FILE"
|
||||||
|
echo " Seed: $SEED" >> "$OUTPUT_FILE"
|
||||||
|
echo "" >> "$OUTPUT_FILE"
|
||||||
|
|
||||||
echo "Running tests with:"
|
echo "Running tests with:"
|
||||||
echo " Reference logos: $NUM_LOGOS"
|
echo " Reference logos: $NUM_LOGOS"
|
||||||
@ -30,8 +39,21 @@ echo " Min matching refs: $MIN_MATCHING_REFS"
|
|||||||
echo " Seed: $SEED"
|
echo " Seed: $SEED"
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Test 1: Margin-based matching
|
# Test 1: Simple matching (baseline - all matches above threshold)
|
||||||
echo "=== Test 1: Margin-based matching ===" | tee -a "$OUTPUT_FILE"
|
echo "=== Test 1: Simple matching (baseline) ==="
|
||||||
|
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
||||||
|
--num-logos $NUM_LOGOS \
|
||||||
|
--refs-per-logo $REFS_PER_LOGO \
|
||||||
|
--positive-samples $POSITIVE_SAMPLES \
|
||||||
|
--negative-samples $NEGATIVE_SAMPLES \
|
||||||
|
--matching-method simple \
|
||||||
|
--seed $SEED \
|
||||||
|
--output-file "$OUTPUT_FILE"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# Test 2: Margin-based matching
|
||||||
|
echo "=== Test 2: Margin-based matching ==="
|
||||||
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
||||||
--num-logos $NUM_LOGOS \
|
--num-logos $NUM_LOGOS \
|
||||||
--refs-per-logo $REFS_PER_LOGO \
|
--refs-per-logo $REFS_PER_LOGO \
|
||||||
@ -39,13 +61,12 @@ uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
|||||||
--negative-samples $NEGATIVE_SAMPLES \
|
--negative-samples $NEGATIVE_SAMPLES \
|
||||||
--matching-method margin \
|
--matching-method margin \
|
||||||
--seed $SEED \
|
--seed $SEED \
|
||||||
2>&1 | tee -a "$OUTPUT_FILE"
|
--output-file "$OUTPUT_FILE"
|
||||||
|
|
||||||
echo "" >> "$OUTPUT_FILE"
|
echo ""
|
||||||
echo "" >> "$OUTPUT_FILE"
|
|
||||||
|
|
||||||
# Test 2: Multi-ref with mean similarity
|
# Test 3: Multi-ref with mean similarity
|
||||||
echo "=== Test 2: Multi-ref matching (mean similarity) ===" | tee -a "$OUTPUT_FILE"
|
echo "=== Test 3: Multi-ref matching (mean similarity) ==="
|
||||||
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
||||||
--num-logos $NUM_LOGOS \
|
--num-logos $NUM_LOGOS \
|
||||||
--refs-per-logo $REFS_PER_LOGO \
|
--refs-per-logo $REFS_PER_LOGO \
|
||||||
@ -54,13 +75,12 @@ uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
|||||||
--matching-method multi-ref \
|
--matching-method multi-ref \
|
||||||
--min-matching-refs $MIN_MATCHING_REFS \
|
--min-matching-refs $MIN_MATCHING_REFS \
|
||||||
--seed $SEED \
|
--seed $SEED \
|
||||||
2>&1 | tee -a "$OUTPUT_FILE"
|
--output-file "$OUTPUT_FILE"
|
||||||
|
|
||||||
echo "" >> "$OUTPUT_FILE"
|
echo ""
|
||||||
echo "" >> "$OUTPUT_FILE"
|
|
||||||
|
|
||||||
# Test 3: Multi-ref with max similarity
|
# Test 4: Multi-ref with max similarity
|
||||||
echo "=== Test 3: Multi-ref matching (max similarity) ===" | tee -a "$OUTPUT_FILE"
|
echo "=== Test 4: Multi-ref matching (max similarity) ==="
|
||||||
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
||||||
--num-logos $NUM_LOGOS \
|
--num-logos $NUM_LOGOS \
|
||||||
--refs-per-logo $REFS_PER_LOGO \
|
--refs-per-logo $REFS_PER_LOGO \
|
||||||
@ -70,7 +90,7 @@ uv run python "$SCRIPT_DIR/test_logo_detection.py" \
|
|||||||
--min-matching-refs $MIN_MATCHING_REFS \
|
--min-matching-refs $MIN_MATCHING_REFS \
|
||||||
--use-max-similarity \
|
--use-max-similarity \
|
||||||
--seed $SEED \
|
--seed $SEED \
|
||||||
2>&1 | tee -a "$OUTPUT_FILE"
|
--output-file "$OUTPUT_FILE"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "Results saved to: $OUTPUT_FILE"
|
echo "Results saved to: $OUTPUT_FILE"
|
||||||
@ -236,9 +236,10 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--matching-method",
|
"--matching-method",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["margin", "multi-ref"],
|
choices=["simple", "margin", "multi-ref"],
|
||||||
default="margin",
|
default="margin",
|
||||||
help="Matching method: 'margin' requires confidence margin over 2nd best, "
|
help="Matching method: 'simple' returns all matches above threshold, "
|
||||||
|
"'margin' requires confidence margin over 2nd best, "
|
||||||
"'multi-ref' aggregates scores across reference images (default: margin)",
|
"'multi-ref' aggregates scores across reference images (default: margin)",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -267,6 +268,12 @@ def main():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Clear embedding cache before running",
|
help="Clear embedding cache before running",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--output-file",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Append results summary to this file (no progress output, just results)",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
logger = setup_logging(args.verbose)
|
logger = setup_logging(args.verbose)
|
||||||
@ -431,10 +438,30 @@ def main():
|
|||||||
# Match detections against references using selected method
|
# Match detections against references using selected method
|
||||||
matched_logos: Set[str] = set()
|
matched_logos: Set[str] = set()
|
||||||
for detection in detections:
|
for detection in detections:
|
||||||
match = None
|
if args.matching_method == "simple":
|
||||||
similarity = None
|
# Simple matching: return ALL logos above threshold
|
||||||
|
all_matches = detector.find_all_matches(
|
||||||
|
detection["embedding"],
|
||||||
|
reference_embeddings,
|
||||||
|
similarity_threshold=args.threshold,
|
||||||
|
)
|
||||||
|
for label, similarity in all_matches:
|
||||||
|
matched_logos.add(label)
|
||||||
|
|
||||||
if args.matching_method == "margin":
|
# Check if this is a correct match
|
||||||
|
if label in expected_logos:
|
||||||
|
true_positives += 1
|
||||||
|
else:
|
||||||
|
false_positives += 1
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"test_image": test_filename,
|
||||||
|
"matched_logo": label,
|
||||||
|
"similarity": similarity,
|
||||||
|
"correct": label in expected_logos,
|
||||||
|
})
|
||||||
|
|
||||||
|
elif args.matching_method == "margin":
|
||||||
# Margin-based matching: requires margin over second-best
|
# Margin-based matching: requires margin over second-best
|
||||||
match_result = detector.find_best_match_with_margin(
|
match_result = detector.find_best_match_with_margin(
|
||||||
detection["embedding"],
|
detection["embedding"],
|
||||||
@ -444,7 +471,20 @@ def main():
|
|||||||
)
|
)
|
||||||
if match_result:
|
if match_result:
|
||||||
label, similarity = match_result
|
label, similarity = match_result
|
||||||
match = label
|
matched_logos.add(label)
|
||||||
|
|
||||||
|
if label in expected_logos:
|
||||||
|
true_positives += 1
|
||||||
|
else:
|
||||||
|
false_positives += 1
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"test_image": test_filename,
|
||||||
|
"matched_logo": label,
|
||||||
|
"similarity": similarity,
|
||||||
|
"correct": label in expected_logos,
|
||||||
|
})
|
||||||
|
|
||||||
else: # multi-ref
|
else: # multi-ref
|
||||||
# Multi-ref matching: aggregates scores across reference images
|
# Multi-ref matching: aggregates scores across reference images
|
||||||
match_result = detector.find_best_match_multi_ref(
|
match_result = detector.find_best_match_multi_ref(
|
||||||
@ -457,22 +497,18 @@ def main():
|
|||||||
)
|
)
|
||||||
if match_result:
|
if match_result:
|
||||||
label, similarity, num_matching = match_result
|
label, similarity, num_matching = match_result
|
||||||
match = label
|
matched_logos.add(label)
|
||||||
|
|
||||||
if match:
|
if label in expected_logos:
|
||||||
matched_logos.add(match)
|
|
||||||
|
|
||||||
# Check if this is a correct match
|
|
||||||
if match in expected_logos:
|
|
||||||
true_positives += 1
|
true_positives += 1
|
||||||
else:
|
else:
|
||||||
false_positives += 1
|
false_positives += 1
|
||||||
|
|
||||||
results.append({
|
results.append({
|
||||||
"test_image": test_filename,
|
"test_image": test_filename,
|
||||||
"matched_logo": match,
|
"matched_logo": label,
|
||||||
"similarity": similarity,
|
"similarity": similarity,
|
||||||
"correct": match in expected_logos,
|
"correct": label in expected_logos,
|
||||||
})
|
})
|
||||||
|
|
||||||
# Count missed detections (false negatives)
|
# Count missed detections (false negatives)
|
||||||
@ -512,6 +548,7 @@ def main():
|
|||||||
print(f" CLIP similarity threshold: {args.threshold}")
|
print(f" CLIP similarity threshold: {args.threshold}")
|
||||||
print(f" DETR confidence threshold: {args.detr_threshold}")
|
print(f" DETR confidence threshold: {args.detr_threshold}")
|
||||||
print(f" Matching method: {args.matching_method}")
|
print(f" Matching method: {args.matching_method}")
|
||||||
|
if args.matching_method in ("margin", "multi-ref"):
|
||||||
print(f" Matching margin: {args.margin}")
|
print(f" Matching margin: {args.margin}")
|
||||||
if args.matching_method == "multi-ref":
|
if args.matching_method == "multi-ref":
|
||||||
print(f" Min matching refs: {args.min_matching_refs}")
|
print(f" Min matching refs: {args.min_matching_refs}")
|
||||||
@ -548,6 +585,92 @@ def main():
|
|||||||
|
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Write results to file if requested
|
||||||
|
if args.output_file:
|
||||||
|
write_results_to_file(
|
||||||
|
output_path=Path(args.output_file),
|
||||||
|
args=args,
|
||||||
|
num_logos=len(sampled_logos),
|
||||||
|
total_refs=total_refs,
|
||||||
|
num_test_images=len(test_images),
|
||||||
|
true_positives=true_positives,
|
||||||
|
false_positives=false_positives,
|
||||||
|
false_negatives=false_negatives,
|
||||||
|
total_expected=total_expected,
|
||||||
|
precision=precision,
|
||||||
|
recall=recall,
|
||||||
|
f1=f1,
|
||||||
|
)
|
||||||
|
print(f"\nResults appended to: {args.output_file}")
|
||||||
|
|
||||||
|
|
||||||
|
def write_results_to_file(
|
||||||
|
output_path: Path,
|
||||||
|
args,
|
||||||
|
num_logos: int,
|
||||||
|
total_refs: int,
|
||||||
|
num_test_images: int,
|
||||||
|
true_positives: int,
|
||||||
|
false_positives: int,
|
||||||
|
false_negatives: int,
|
||||||
|
total_expected: int,
|
||||||
|
precision: float,
|
||||||
|
recall: float,
|
||||||
|
f1: float,
|
||||||
|
):
|
||||||
|
"""Write results summary to file with detailed header."""
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Build method description for header
|
||||||
|
if args.matching_method == "simple":
|
||||||
|
method_desc = "Simple (all matches above threshold)"
|
||||||
|
elif args.matching_method == "margin":
|
||||||
|
method_desc = f"Margin-based (margin={args.margin})"
|
||||||
|
else: # multi-ref
|
||||||
|
agg = "max" if args.use_max_similarity else "mean"
|
||||||
|
method_desc = f"Multi-ref ({agg}, min_refs={args.min_matching_refs}, margin={args.margin})"
|
||||||
|
|
||||||
|
lines = [
|
||||||
|
"=" * 70,
|
||||||
|
f"TEST: {args.matching_method.upper()} MATCHING",
|
||||||
|
f"Method: {method_desc}",
|
||||||
|
"=" * 70,
|
||||||
|
f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
|
||||||
|
"",
|
||||||
|
"Configuration:",
|
||||||
|
f" Reference logos: {num_logos}",
|
||||||
|
f" Refs per logo: {args.refs_per_logo}",
|
||||||
|
f" Total reference embeddings:{total_refs}",
|
||||||
|
f" Positive samples/logo: {args.positive_samples}",
|
||||||
|
f" Negative samples/logo: {args.negative_samples}",
|
||||||
|
f" Test images processed: {num_test_images}",
|
||||||
|
f" CLIP threshold: {args.threshold}",
|
||||||
|
f" DETR threshold: {args.detr_threshold}",
|
||||||
|
]
|
||||||
|
|
||||||
|
if args.seed is not None:
|
||||||
|
lines.append(f" Random seed: {args.seed}")
|
||||||
|
|
||||||
|
lines.extend([
|
||||||
|
"",
|
||||||
|
"Results:",
|
||||||
|
f" True Positives: {true_positives:>6}",
|
||||||
|
f" False Positives: {false_positives:>6}",
|
||||||
|
f" False Negatives: {false_negatives:>6}",
|
||||||
|
f" Total Expected: {total_expected:>6}",
|
||||||
|
"",
|
||||||
|
"Scores:",
|
||||||
|
f" Precision: {precision:.4f} ({precision*100:.1f}%)",
|
||||||
|
f" Recall: {recall:.4f} ({recall*100:.1f}%)",
|
||||||
|
f" F1 Score: {f1:.4f} ({f1*100:.1f}%)",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
])
|
||||||
|
|
||||||
|
# Append to file
|
||||||
|
with open(output_path, "a") as f:
|
||||||
|
f.write("\n".join(lines))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
Reference in New Issue
Block a user