SofiTesfay2010 committed on
Commit
f712b7c
·
verified ·
1 Parent(s): ce46efa

v0.3: detectors with export/load calibration

Browse files
Files changed (1) hide show
  1. aria_llm/detectors.py +82 -16
aria_llm/detectors.py CHANGED
@@ -1,13 +1,10 @@
1
  """
2
- ARIA Detectors v0.2
3
  ====================
4
 
5
- v0.2 changes:
6
- - All detectors now have a calibration phase that collects N steps of normal
7
- model behavior and computes mean + std statistics.
8
- - Triggering is based on mean + k*std (configurable sensitivity).
9
- - No detector fires during calibration.
10
- - MedianTrapDetector completely rewritten to use calibrated baselines.
11
 
12
  Grounded in:
13
  - Dynamic Instability Signal (arxiv:2602.02863): JSD + entropy
@@ -39,7 +36,10 @@ class DetectionSignal:
39
 
40
 
41
  class _CalibrationBuffer:
42
- """Shared calibration logic: collect samples, compute mean + std, derive threshold."""
 
 
 
43
 
44
  def __init__(self, calibration_steps: int, sensitivity_k: float):
45
  self.calibration_steps = calibration_steps
@@ -79,6 +79,31 @@ class _CalibrationBuffer:
79
  severity = min(1.0, excess / (self.sensitivity_k * scale))
80
  return severity
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  def reset(self):
83
  self.samples.clear()
84
  self.mean = None
@@ -88,8 +113,7 @@ class _CalibrationBuffer:
88
 
89
 
90
  class CompoundErrorDetector:
91
- """Detects compound error accumulation via Dynamic Instability Signal (arxiv:2602.02863).
92
- v0.2: Uses calibration buffer. Only triggers when instability exceeds mean + k*std."""
93
 
94
  def __init__(self, calibration_steps: int = 20, sensitivity_k: float = 2.5,
95
  window: int = 10, lam: float = 1.0, fallback_threshold: float = 0.7):
@@ -104,6 +128,12 @@ class CompoundErrorDetector:
104
  self.prev_probs = None
105
  self.instability_history.clear()
106
  self.calibration.reset()
 
 
 
 
 
 
107
 
108
  def _jsd(self, p: torch.Tensor, q: torch.Tensor) -> float:
109
  p = p.float().clamp(min=1e-8)
@@ -166,8 +196,7 @@ class CompoundErrorDetector:
166
 
167
 
168
  class SemanticDriftDetector:
169
- """Detects semantic drift by tracking cosine distance from goal anchor.
170
- v0.2: Uses calibration buffer for cosine distance distribution."""
171
 
172
  def __init__(self, calibration_steps: int = 20, sensitivity_k: float = 2.5,
173
  window: int = 20, fallback_threshold: float = 0.3):
@@ -181,6 +210,12 @@ class SemanticDriftDetector:
181
  self.goal_anchor = None
182
  self.distance_history.clear()
183
  self.calibration.reset()
 
 
 
 
 
 
184
 
185
  def set_goal_anchor(self, hidden_state: torch.Tensor):
186
  self.goal_anchor = hidden_state.float().detach().clone()
@@ -231,8 +266,7 @@ class SemanticDriftDetector:
231
 
232
 
233
  class LogicLoopDetector:
234
- """Detects logic looping via entropy variance collapse + trajectory fingerprinting.
235
- v0.2: Calibrates entropy variance baseline."""
236
 
237
  def __init__(self, calibration_steps: int = 20, sensitivity_k: float = 2.5,
238
  window: int = 15, similarity_threshold: float = 0.92,
@@ -247,6 +281,7 @@ class LogicLoopDetector:
247
  self.step = 0
248
  self.var_samples = []
249
  self.calibration_steps = calibration_steps
 
250
  self.var_mean: Optional[float] = None
251
  self.var_std: Optional[float] = None
252
  self.var_threshold: Optional[float] = None
@@ -262,6 +297,24 @@ class LogicLoopDetector:
262
  self.var_std = None
263
  self.var_threshold = None
264
  self.sim_calibration.reset()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265
 
266
  def _compute_fingerprint(self, states: List[torch.Tensor]) -> torch.Tensor:
267
  if not states:
@@ -331,8 +384,7 @@ class LogicLoopDetector:
331
 
332
 
333
  class MedianTrapDetector:
334
- """Detects when the model is producing statistically average outputs.
335
- v0.2: Completely rewritten to use calibrated baselines instead of absolute formula."""
336
 
337
  def __init__(self, calibration_steps: int = 20, sensitivity_k: float = 2.5,
338
  temperature_boost: float = 1.15, novelty_bonus: float = 0.05):
@@ -348,6 +400,20 @@ class MedianTrapDetector:
348
  self.step = 0
349
  self.top1_calibration.reset()
350
  self.inv_entropy_calibration.reset()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
351
 
352
  def detect(self, logits: torch.Tensor) -> DetectionSignal:
353
  self.step += 1
 
1
  """
2
+ ARIA Detectors v0.3
3
  ====================
4
 
5
+ v0.3 changes:
6
+ - _CalibrationBuffer gains export_state() / load_state() for profile persistence
7
+ - All detectors gain export_calibration() / load_calibration() methods
 
 
 
8
 
9
  Grounded in:
10
  - Dynamic Instability Signal (arxiv:2602.02863): JSD + entropy
 
36
 
37
 
38
  class _CalibrationBuffer:
39
+ """Shared calibration logic: collect samples, compute mean + std, derive threshold.
40
+
41
+ v0.3: Added export_state() / load_state() for calibration profile persistence.
42
+ """
43
 
44
  def __init__(self, calibration_steps: int, sensitivity_k: float):
45
  self.calibration_steps = calibration_steps
 
79
  severity = min(1.0, excess / (self.sensitivity_k * scale))
80
  return severity
81
 
82
def export_state(self) -> Dict:
    """Export the calibration state as a plain dict for persistence.

    Returns:
        A dict with the keys ``mean``, ``std``, ``threshold``,
        ``sensitivity_k``, ``calibration_steps``, ``n_samples`` (number of
        buffered samples) and ``samples_summary`` (``min``, ``max`` and
        ``median`` of the buffered samples). The three summary values are
        ``None`` when no samples are buffered.
    """
    # Sort once and derive min / max / median from the ordered copy instead
    # of scanning the samples three separate times.
    ordered = sorted(self.samples)
    count = len(ordered)
    if count:
        mid = count // 2
        if count % 2:
            median = ordered[mid]
        else:
            # An even-sized sample has no single middle element; average the
            # two central values rather than reporting the upper one.
            median = (ordered[mid - 1] + ordered[mid]) / 2
        summary = {"min": ordered[0], "max": ordered[-1], "median": median}
    else:
        summary = {"min": None, "max": None, "median": None}

    return {
        "mean": self.mean,
        "std": self.std,
        "threshold": self.threshold,
        "sensitivity_k": self.sensitivity_k,
        "calibration_steps": self.calibration_steps,
        "n_samples": count,
        "samples_summary": summary,
    }
97
+
98
def load_state(self, state: Dict):
    """Load calibration state from a saved profile. Skips calibration phase.

    Args:
        state: A dict shaped like the output of ``export_state()``. The keys
            ``mean``, ``std`` and ``threshold`` are required;
            ``sensitivity_k`` is optional and falls back to the current value.

    Raises:
        KeyError: If ``mean``, ``std`` or ``threshold`` is missing.
        ValueError: If any of those statistics is ``None``, i.e. the profile
            was exported before calibration had produced them.
    """
    # Read and validate everything before touching ``self`` so that a bad
    # profile never leaves the buffer half-loaded.
    stats = {key: state[key] for key in ("mean", "std", "threshold")}
    unset = [key for key, value in stats.items() if value is None]
    if unset:
        raise ValueError(
            "Cannot load an uncalibrated profile; missing statistics: "
            + ", ".join(unset)
        )

    self.mean = stats["mean"]
    self.std = stats["std"]
    self.threshold = stats["threshold"]
    self.sensitivity_k = state.get("sensitivity_k", self.sensitivity_k)
    # Advance the step counter to the end of the calibration window; this is
    # what skips the calibration phase. How the counter is consumed lies
    # outside this block.
    self.step = self.calibration_steps
    self.samples = []
106
+
107
  def reset(self):
108
  self.samples.clear()
109
  self.mean = None
 
113
 
114
 
115
  class CompoundErrorDetector:
116
+ """Detects compound error accumulation via Dynamic Instability Signal (arxiv:2602.02863)."""
 
117
 
118
  def __init__(self, calibration_steps: int = 20, sensitivity_k: float = 2.5,
119
  window: int = 10, lam: float = 1.0, fallback_threshold: float = 0.7):
 
128
  self.prev_probs = None
129
  self.instability_history.clear()
130
  self.calibration.reset()
131
+
132
def export_calibration(self) -> Dict:
    """Return this detector's calibration profile under the ``"compound_error"`` key."""
    profile = self.calibration.export_state()
    return {"compound_error": profile}
134
+
135
def load_calibration(self, state: Dict):
    """Restore the calibration buffer from the ``"compound_error"`` entry of *state*.

    Raises:
        KeyError: If *state* has no ``"compound_error"`` entry.
    """
    profile = state["compound_error"]
    self.calibration.load_state(profile)
137
 
138
  def _jsd(self, p: torch.Tensor, q: torch.Tensor) -> float:
139
  p = p.float().clamp(min=1e-8)
 
196
 
197
 
198
  class SemanticDriftDetector:
199
+ """Detects semantic drift by tracking cosine distance from goal anchor."""
 
200
 
201
  def __init__(self, calibration_steps: int = 20, sensitivity_k: float = 2.5,
202
  window: int = 20, fallback_threshold: float = 0.3):
 
210
  self.goal_anchor = None
211
  self.distance_history.clear()
212
  self.calibration.reset()
213
+
214
def export_calibration(self) -> Dict:
    """Return this detector's calibration profile under the ``"semantic_drift"`` key."""
    profile = self.calibration.export_state()
    return {"semantic_drift": profile}
216
+
217
def load_calibration(self, state: Dict):
    """Restore the calibration buffer from the ``"semantic_drift"`` entry of *state*.

    Raises:
        KeyError: If *state* has no ``"semantic_drift"`` entry.
    """
    profile = state["semantic_drift"]
    self.calibration.load_state(profile)
219
 
220
  def set_goal_anchor(self, hidden_state: torch.Tensor):
221
  self.goal_anchor = hidden_state.float().detach().clone()
 
266
 
267
 
268
  class LogicLoopDetector:
269
+ """Detects logic looping via entropy variance collapse + trajectory fingerprinting."""
 
270
 
271
  def __init__(self, calibration_steps: int = 20, sensitivity_k: float = 2.5,
272
  window: int = 15, similarity_threshold: float = 0.92,
 
281
  self.step = 0
282
  self.var_samples = []
283
  self.calibration_steps = calibration_steps
284
+ self.sensitivity_k = sensitivity_k
285
  self.var_mean: Optional[float] = None
286
  self.var_std: Optional[float] = None
287
  self.var_threshold: Optional[float] = None
 
297
  self.var_std = None
298
  self.var_threshold = None
299
  self.sim_calibration.reset()
300
+
301
def export_calibration(self) -> Dict:
    """Return the loop detector's calibration profile under the ``"logic_loop"`` key.

    The profile combines the exported state of ``sim_calibration`` with the
    detector's own ``var_mean``, ``var_std`` and ``var_threshold`` values.
    """
    payload = {"sim_calibration": self.sim_calibration.export_state()}
    payload["var_mean"] = self.var_mean
    payload["var_std"] = self.var_std
    payload["var_threshold"] = self.var_threshold
    return {"logic_loop": payload}
310
+
311
def load_calibration(self, state: Dict):
    """Restore the loop detector's calibration from a saved profile.

    Args:
        state: A dict with a ``"logic_loop"`` entry that holds
            ``sim_calibration``, ``var_mean``, ``var_std`` and
            ``var_threshold``, as produced by ``export_calibration()``.

    Raises:
        KeyError: If ``"logic_loop"`` or any of its four entries is missing.
            Nothing on the detector is modified in that case.
    """
    loop_state = state["logic_loop"]
    # Pull every required entry out first: a missing key must fail before
    # any state is mutated, otherwise the detector ends up half-restored.
    sim_state = loop_state["sim_calibration"]
    var_mean = loop_state["var_mean"]
    var_std = loop_state["var_std"]
    var_threshold = loop_state["var_threshold"]

    self.sim_calibration.load_state(sim_state)
    self.var_mean = var_mean
    self.var_std = var_std
    self.var_threshold = var_threshold
    # Move the detector's own step counter to the end of the calibration
    # window, matching the restored (already calibrated) statistics.
    self.step = self.calibration_steps
318
 
319
  def _compute_fingerprint(self, states: List[torch.Tensor]) -> torch.Tensor:
320
  if not states:
 
384
 
385
 
386
  class MedianTrapDetector:
387
+ """Detects when the model is producing statistically average outputs."""
 
388
 
389
  def __init__(self, calibration_steps: int = 20, sensitivity_k: float = 2.5,
390
  temperature_boost: float = 1.15, novelty_bonus: float = 0.05):
 
400
  self.step = 0
401
  self.top1_calibration.reset()
402
  self.inv_entropy_calibration.reset()
403
+
404
def export_calibration(self) -> Dict:
    """Return the median-trap calibration profile under the ``"median_trap"`` key.

    The profile holds the exported state of the ``top1`` and ``inv_entropy``
    calibration buffers.
    """
    top1_state = self.top1_calibration.export_state()
    inv_entropy_state = self.inv_entropy_calibration.export_state()
    return {"median_trap": {"top1": top1_state, "inv_entropy": inv_entropy_state}}
411
+
412
def load_calibration(self, state: Dict):
    """Restore both median-trap calibration buffers from a saved profile.

    Args:
        state: A dict with a ``"median_trap"`` entry that holds ``top1`` and
            ``inv_entropy`` buffer states, as produced by
            ``export_calibration()``.

    Raises:
        KeyError: If ``"median_trap"``, ``"top1"`` or ``"inv_entropy"`` is
            missing. Neither buffer is modified in that case.
    """
    mt = state["median_trap"]
    # Look up both sub-profiles before loading either, so a missing entry
    # cannot leave one buffer restored and the other still uncalibrated.
    top1_state = mt["top1"]
    inv_entropy_state = mt["inv_entropy"]

    self.top1_calibration.load_state(top1_state)
    self.inv_entropy_calibration.load_state(inv_entropy_state)
    # Move the detector's own step counter to the end of the calibration
    # window defined by the top-1 buffer.
    self.step = self.top1_calibration.calibration_steps
417
 
418
  def detect(self, logits: torch.Tensor) -> DetectionSignal:
419
  self.step += 1