stevenkhan commited on
Commit
5d6e14c
·
verified ·
1 Parent(s): 098b42d

Upload clashcr/core/evaluator.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. clashcr/core/evaluator.py +163 -0
clashcr/core/evaluator.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Offline evaluator: replay recordings and report metrics.
2
+
3
+ Metrics:
4
+ - precision, recall
5
+ - false positives per minute
6
+ - missed events
7
+ - timing error (seconds)
8
+ - confusion matrix
9
+ """
10
+ from __future__ import annotations
11
+
12
+ import csv
13
+ import json
14
+ import logging
15
+ from dataclasses import dataclass, field
16
+ from pathlib import Path
17
+ from typing import Dict, List, Optional, Tuple
18
+
19
+ import numpy as np
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ @dataclass
25
+ class EvalResult:
26
+ precision: float
27
+ recall: float
28
+ f1: float
29
+ false_positives_per_minute: float
30
+ missed_events: int
31
+ mean_timing_error_seconds: float
32
+ median_timing_error_seconds: float
33
+ confusion_matrix: Dict[str, Dict[str, int]] = field(default_factory=dict)
34
+ total_predictions: int = 0
35
+ total_labels: int = 0
36
+ correct: int = 0
37
+ false_positives: int = 0
38
+ false_negatives: int = 0
39
+
40
+
41
+ class OfflineEvaluator:
42
+ """Evaluate predictions against ground-truth labels from a recording."""
43
+
44
+ def __init__(self, timing_tolerance_seconds: float = 3.0):
45
+ self.timing_tolerance = timing_tolerance_seconds
46
+
47
+ def load_labels(self, labels_path: str) -> List[dict]:
48
+ rows = []
49
+ with open(labels_path, "r", encoding="utf-8") as f:
50
+ reader = csv.DictReader(f)
51
+ for row in reader:
52
+ rows.append({
53
+ "timestamp": float(row["timestamp"]),
54
+ "frame_idx": int(row.get("frame_idx", 0)),
55
+ "side": row["side"],
56
+ "card_key": row["card_key"],
57
+ "confidence": float(row.get("confidence", 1.0)),
58
+ "note": row.get("manual_note", ""),
59
+ })
60
+ return rows
61
+
62
+ def load_predictions(self, predictions_path: str) -> List[dict]:
63
+ rows = []
64
+ with open(predictions_path, "r", encoding="utf-8") as f:
65
+ reader = csv.DictReader(f)
66
+ for row in reader:
67
+ rows.append({
68
+ "timestamp": float(row["timestamp"]),
69
+ "frame_idx": int(row.get("frame_idx", 0)),
70
+ "side": row["side"],
71
+ "card_key": row["card_key"],
72
+ "confidence": float(row.get("confidence", 0.0)),
73
+ "evidence": row.get("evidence", ""),
74
+ "resolver_reason": row.get("resolver_reason", ""),
75
+ })
76
+ return rows
77
+
78
+ def evaluate(self, labels: List[dict], predictions: List[dict],
79
+ recording_duration_seconds: Optional[float] = None) -> EvalResult:
80
+ # Filter to opponent-side only
81
+ labels = [r for r in labels if r["side"] == "opponent"]
82
+ predictions = [r for r in predictions if r["side"] == "opponent"]
83
+
84
+ total_labels = len(labels)
85
+ total_predictions = len(predictions)
86
+
87
+ matched_labels = set()
88
+ matched_preds = set()
89
+ timing_errors = []
90
+ confusion: Dict[str, Dict[str, int]] = {}
91
+
92
+ for li, lab in enumerate(labels):
93
+ best_pi = -1
94
+ best_err = float("inf")
95
+ for pi, pred in enumerate(predictions):
96
+ if pi in matched_preds:
97
+ continue
98
+ err = abs(pred["timestamp"] - lab["timestamp"])
99
+ if err <= self.timing_tolerance and err < best_err:
100
+ best_err = err
101
+ best_pi = pi
102
+
103
+ if best_pi >= 0:
104
+ matched_labels.add(li)
105
+ matched_preds.add(best_pi)
106
+ timing_errors.append(best_err)
107
+ pred_card = predictions[best_pi]["card_key"]
108
+ true_card = lab["card_key"]
109
+ confusion.setdefault(true_card, {}).setdefault(pred_card, 0)
110
+ confusion[true_card][pred_card] += 1
111
+
112
+ correct = sum(
113
+ 1 for li in matched_labels
114
+ if labels[li]["card_key"] == predictions[
115
+ next(pi for pi in matched_preds if abs(predictions[pi]["timestamp"] - labels[li]["timestamp"]) <= self.timing_tolerance)
116
+ ]["card_key"]
117
+ )
118
+ # Simpler correct count from confusion diagonal
119
+ correct = sum(confusion.get(c, {}).get(c, 0) for c in confusion)
120
+
121
+ false_positives = total_predictions - len(matched_preds)
122
+ false_negatives = total_labels - len(matched_labels)
123
+
124
+ precision = correct / total_predictions if total_predictions > 0 else 0.0
125
+ recall = correct / total_labels if total_labels > 0 else 0.0
126
+ f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
127
+
128
+ duration = recording_duration_seconds or (labels[-1]["timestamp"] - labels[0]["timestamp"]) if labels else 1.0
129
+ fp_per_min = (false_positives / duration) * 60.0 if duration > 0 else 0.0
130
+
131
+ mean_timing = float(np.mean(timing_errors)) if timing_errors else 0.0
132
+ median_timing = float(np.median(timing_errors)) if timing_errors else 0.0
133
+
134
+ return EvalResult(
135
+ precision=precision,
136
+ recall=recall,
137
+ f1=f1,
138
+ false_positives_per_minute=fp_per_min,
139
+ missed_events=false_negatives,
140
+ mean_timing_error_seconds=mean_timing,
141
+ median_timing_error_seconds=median_timing,
142
+ confusion_matrix=confusion,
143
+ total_predictions=total_predictions,
144
+ total_labels=total_labels,
145
+ correct=correct,
146
+ false_positives=false_positives,
147
+ false_negatives=false_negatives,
148
+ )
149
+
150
+ def evaluate_recording_dir(self, recording_dir: str) -> EvalResult:
151
+ rec = Path(recording_dir)
152
+ labels = self.load_labels(str(rec / "labels.csv"))
153
+ predictions = self.load_predictions(str(rec / "predictions.csv"))
154
+ meta_path = rec / "metadata.jsonl"
155
+ duration = None
156
+ if meta_path.exists():
157
+ with open(meta_path, "r") as f:
158
+ lines = f.readlines()
159
+ if lines:
160
+ first = json.loads(lines[0])
161
+ last = json.loads(lines[-1])
162
+ duration = last["timestamp"] - first["timestamp"]
163
+ return self.evaluate(labels, predictions, duration)