SofiTesfay2010 committed on
Commit
5d94b1c
·
verified ·
1 Parent(s): b2bebdd

v0.3: core with save/load calibration + auto-tune

Browse files
Files changed (1) hide show
  1. aria_llm/core.py +147 -50
aria_llm/core.py CHANGED
@@ -1,21 +1,29 @@
1
  """
2
- ARIA Core Module v0.2
3
  ======================
4
 
5
- Key v0.2 changes:
6
- - Correction budget: at most max_corrections_per_step correctors fire per step.
7
- The highest-severity signal wins. This prevents corrector interference.
8
- - All correctors receive the global correction_scale from config.
9
- - Extended calibration phase: no corrections during first calibration_steps.
10
- - Reliability estimation only counts TRIGGERED signals (not every detection).
 
11
 
12
  Usage:
13
  from aria_llm import ARIA, ARIAConfig
14
- config = ARIAConfig(calibration_steps=20, sensitivity_k=2.5,
15
- max_corrections_per_step=1, correction_scale=0.1, verbose=True)
 
16
  aria = ARIA.attach(model, tokenizer, config=config)
17
  output = model.generate(input_ids, max_new_tokens=500)
18
- print(aria.report_text())
 
 
 
 
 
 
19
  aria.detach()
20
  """
21
 
@@ -25,6 +33,8 @@ from typing import Optional, Dict, List, Tuple, Any
25
  from collections import deque
26
  import time
27
  import json
 
 
28
 
29
  from aria_llm.config import ARIAConfig
30
  from aria_llm.detectors import (
@@ -66,16 +76,12 @@ class ARIAState:
66
 
67
 
68
  class ARIA:
69
- """Adaptive Reliability & Integrity Attachment v0.2.
70
 
71
  Hooks into a HuggingFace Transformers model to provide real-time
72
- detection and correction of four failure modes:
73
- 1. Compound Error Accumulation
74
- 2. Semantic Drift
75
- 3. Logic Looping
76
- 4. Median Trap (Lack of "Taste")
77
 
78
- v0.2: Calibration-first, budget-limited, statistically-grounded.
79
  """
80
 
81
  def __init__(self, model, tokenizer, config: Optional[ARIAConfig] = None):
@@ -114,6 +120,11 @@ class ARIA:
114
  self._last_median_signal: Optional[DetectionSignal] = None
115
  self._last_drift_signal: Optional[DetectionSignal] = None
116
  self._model_info = self._detect_architecture()
 
 
 
 
 
117
 
118
  @classmethod
119
  def attach(cls, model, tokenizer, config: Optional[ARIAConfig] = None) -> 'ARIA':
@@ -146,7 +157,101 @@ class ARIA:
146
  self._last_loop_signal = None
147
  self._last_median_signal = None
148
  self._last_drift_signal = None
 
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  def _can_correct(self) -> bool:
151
  return self._step_corrections_this_step < self.config.max_corrections_per_step
152
 
@@ -155,7 +260,7 @@ class ARIA:
155
 
156
  def _detect_architecture(self) -> Dict:
157
  info = {"arch": "unknown", "num_layers": 0, "hidden_dim": 0, "layers_attr": None}
158
- for attr in ["model.layers", "transformer.h", "gpt_neox.layers",
159
  "model.decoder.layers", "encoder.layer"]:
160
  parts = attr.split(".")
161
  obj = self.model
@@ -195,8 +300,6 @@ class ARIA:
195
  def _install_hooks(self):
196
  layers = self._get_layers_module()
197
  if layers is None:
198
- if self.config.verbose:
199
- print("[ARIA] Warning: Could not detect model layers. Logits-only mode.")
200
  self._install_output_hook()
201
  self._attached = True
202
  return
@@ -239,6 +342,12 @@ class ARIA:
239
  self._current_step_id = step_id
240
  self._step_corrections_this_step = 0
241
 
 
 
 
 
 
 
242
  drift_signal = self.drift_detector.detect(h)
243
  self._last_drift_signal = drift_signal
244
  self.state.record_signal(drift_signal)
@@ -246,12 +355,12 @@ class ARIA:
246
  candidates = []
247
  if drift_signal.triggered and self._can_correct():
248
  candidates.append(("goal_anchor", drift_signal.severity, "drift"))
249
- if (self._last_compound_signal is not None and
250
  self._last_compound_signal.triggered and self._can_correct()):
251
  candidates.append(("steering", self._last_compound_signal.severity, "compound"))
252
  else:
253
  self.steering_corrector.update_good_state(h)
254
- if (self._last_loop_signal is not None and
255
  self._last_loop_signal.triggered and self._can_correct()):
256
  candidates.append(("trajectory_diverger", self._last_loop_signal.severity, "loop"))
257
 
@@ -350,14 +459,9 @@ class ARIA:
350
 
351
  def report(self) -> Dict:
352
  n = self.state.step
353
- avg_r_with_aria = sum(self.state.effective_r) / len(self.state.effective_r) if self.state.effective_r else 1.0
354
- baseline_r_list = getattr(self.state, 'baseline_r', [])
355
- baseline_r = sum(baseline_r_list) / len(baseline_r_list) if baseline_r_list else 0.95
356
-
357
- import math
358
  n_steps = max(n, 1)
359
- p_s_baseline = baseline_r ** n_steps if baseline_r > 0 else 0
360
- p_s_aria = avg_r_with_aria ** n_steps if avg_r_with_aria > 0 else 0
361
 
362
  correction_counts = {}
363
  for c in self.state.corrections:
@@ -371,27 +475,24 @@ class ARIA:
371
 
372
  return {
373
  "summary": {
374
- "version": "0.2.0", "total_steps": n_steps,
375
  "calibration_steps": self.config.calibration_steps,
376
  "sensitivity_k": self.config.sensitivity_k,
377
  "correction_scale": self.config.correction_scale,
378
  "max_corrections_per_step": self.config.max_corrections_per_step,
379
- "baseline_R": round(baseline_r, 4), "aria_R": round(avg_r_with_aria, 4),
380
- "R_improvement": round(avg_r_with_aria - baseline_r, 4),
381
- "baseline_P_success": f"{p_s_baseline:.6e}",
382
- "aria_P_success": f"{p_s_aria:.6e}",
383
- "improvement_factor": round(p_s_aria / max(p_s_baseline, 1e-300), 2),
 
 
384
  "total_corrections": len(self.state.corrections),
385
- "total_signals_checked": len(self.state.signals),
386
  "elapsed_seconds": round(time.time() - self.state.start_time, 2),
387
  },
388
  "corrections_by_type": correction_counts,
389
  "signals_detected": signal_counts,
390
  "signals_triggered": trigger_counts,
391
- "reliability_curve": {
392
- "per_step_R": [round(r, 4) for r in self.state.effective_r[-50:]],
393
- "cumulative_R": [round(r, 6) for r in self.state.cumulative_r[-50:]],
394
- },
395
  "calibration_info": {
396
  "compound_error": {"mean": self.compound_detector.calibration.mean,
397
  "std": self.compound_detector.calibration.std,
@@ -410,12 +511,10 @@ class ARIA:
410
  r = self.report()
411
  s = r["summary"]
412
  lines = [
413
- "=" * 60, " ARIA v0.2 RELIABILITY REPORT", "=" * 60, "",
414
  f" Steps monitored: {s['total_steps']}",
415
- f" Calibration steps: {s['calibration_steps']}",
416
- f" Sensitivity (k): {s['sensitivity_k']}",
417
- f" Correction scale: {s['correction_scale']}",
418
- f" Max corrections/step: {s['max_corrections_per_step']}",
419
  f" Time elapsed: {s['elapsed_seconds']}s", "",
420
  " RELIABILITY (R per step):",
421
  f" Baseline (no ARIA): {s['baseline_R']}",
@@ -436,13 +535,11 @@ class ARIA:
436
  for name, count in r["signals_triggered"].items():
437
  total = r["signals_detected"].get(name, count)
438
  lines.append(f" {name}: {count}/{total} ({count/max(total,1)*100:.1f}% of checks)")
439
- lines += ["", " CALIBRATION BASELINES:"]
440
- for det_name, cal in r["calibration_info"].items():
441
- if cal["mean"] is not None:
442
- lines.append(f" {det_name}: mean={cal['mean']:.4f}, std={cal['std']:.4f}, threshold={cal['threshold']:.4f}")
443
  lines += ["", "=" * 60]
444
  return "\n".join(lines)
445
 
446
  def __repr__(self):
447
  status = "attached" if self._attached else "detached"
448
- return f"ARIA(status={status}, v=0.2, layers={len(self._hooks)} hooks, corrections={len(self.state.corrections)})"
 
 
 
1
  """
2
+ ARIA Core Module v0.3
3
  ======================
4
 
5
+ v0.3 changes:
6
+ - save_calibration() / load_calibration(): Persist calibration profiles as JSON.
7
+ Skip the calibration phase on subsequent runs with the same model.
8
+ - auto_tune_correction_scale(): After calibration, automatically set correction_scale
9
+ based on the observed signal variances. High-variance models get gentler corrections.
10
+ - Calibration profile includes model fingerprint (name + hidden_dim + num_layers)
11
+ for safety checking.
12
 
13
  Usage:
14
  from aria_llm import ARIA, ARIAConfig
15
+
16
+ # First run: calibrate and save
17
+ config = ARIAConfig(auto_tune_correction_scale=True, verbose=True)
18
  aria = ARIA.attach(model, tokenizer, config=config)
19
  output = model.generate(input_ids, max_new_tokens=500)
20
+ aria.save_calibration("profiles/my_model.json")
21
+ aria.detach()
22
+
23
+ # Subsequent runs: load profile (instant, no calibration needed)
24
+ aria = ARIA.attach(model, tokenizer, config=ARIAConfig(
25
+ calibration_profile_path="profiles/my_model.json"))
26
+ output = model.generate(...)
27
  aria.detach()
28
  """
29
 
 
33
  from collections import deque
34
  import time
35
  import json
36
+ import os
37
+ import hashlib
38
 
39
  from aria_llm.config import ARIAConfig
40
  from aria_llm.detectors import (
 
76
 
77
 
78
  class ARIA:
79
+ """Adaptive Reliability & Integrity Attachment v0.3.
80
 
81
  Hooks into a HuggingFace Transformers model to provide real-time
82
+ detection and correction of four failure modes.
 
 
 
 
83
 
84
+ v0.3: Calibration profiles + auto-tune correction_scale.
85
  """
86
 
87
  def __init__(self, model, tokenizer, config: Optional[ARIAConfig] = None):
 
120
  self._last_median_signal: Optional[DetectionSignal] = None
121
  self._last_drift_signal: Optional[DetectionSignal] = None
122
  self._model_info = self._detect_architecture()
123
+ self._calibration_loaded = False
124
+ self._auto_tuned = False
125
+
126
+ if self.config.calibration_profile_path:
127
+ self.load_calibration(self.config.calibration_profile_path)
128
 
129
  @classmethod
130
  def attach(cls, model, tokenizer, config: Optional[ARIAConfig] = None) -> 'ARIA':
 
157
  self._last_loop_signal = None
158
  self._last_median_signal = None
159
  self._last_drift_signal = None
160
+ self._auto_tuned = False
161
 
162
+ def _model_fingerprint(self) -> Dict:
163
+ model_config = getattr(self.model, "config", None)
164
+ name = getattr(model_config, "_name_or_path", "unknown") if model_config else "unknown"
165
+ return {
166
+ "model_name": name,
167
+ "num_layers": self._model_info["num_layers"],
168
+ "hidden_dim": self._model_info["hidden_dim"],
169
+ "fingerprint_hash": hashlib.md5(
170
+ f"{name}_{self._model_info['num_layers']}_{self._model_info['hidden_dim']}".encode()
171
+ ).hexdigest()[:12],
172
+ }
173
+
174
+ def save_calibration(self, path: str):
175
+ os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
176
+ profile = {
177
+ "aria_version": "0.3.0",
178
+ "saved_at": time.strftime("%Y-%m-%dT%H:%M:%S%z"),
179
+ "model": self._model_fingerprint(),
180
+ "config": {
181
+ "calibration_steps": self.config.calibration_steps,
182
+ "sensitivity_k": self.config.sensitivity_k,
183
+ "correction_scale": self.config.correction_scale,
184
+ "max_corrections_per_step": self.config.max_corrections_per_step,
185
+ "auto_tuned": self._auto_tuned,
186
+ },
187
+ "detectors": {},
188
+ }
189
+ profile["detectors"].update(self.compound_detector.export_calibration())
190
+ profile["detectors"].update(self.drift_detector.export_calibration())
191
+ profile["detectors"].update(self.loop_detector.export_calibration())
192
+ profile["detectors"].update(self.median_detector.export_calibration())
193
+ with open(path, "w") as f:
194
+ json.dump(profile, f, indent=2, default=str)
195
+ if self.config.verbose:
196
+ print(f"[ARIA] Calibration profile saved to {path}")
197
+ return profile
198
+
199
+ def load_calibration(self, path: str):
200
+ if not os.path.exists(path):
201
+ raise FileNotFoundError(f"Calibration profile not found: {path}")
202
+ with open(path, "r") as f:
203
+ profile = json.load(f)
204
+ saved_fp = profile.get("model", {})
205
+ current_fp = self._model_fingerprint()
206
+ if (saved_fp.get("num_layers") != current_fp["num_layers"] or
207
+ saved_fp.get("hidden_dim") != current_fp["hidden_dim"]):
208
+ raise ValueError(
209
+ f"Calibration profile mismatch! Saved for layers={saved_fp.get('num_layers')}, "
210
+ f"dim={saved_fp.get('hidden_dim')}. Current: layers={current_fp['num_layers']}, "
211
+ f"dim={current_fp['hidden_dim']}")
212
+ detectors = profile.get("detectors", {})
213
+ if "compound_error" in detectors:
214
+ self.compound_detector.load_calibration(detectors)
215
+ if "semantic_drift" in detectors:
216
+ self.drift_detector.load_calibration(detectors)
217
+ if "logic_loop" in detectors:
218
+ self.loop_detector.load_calibration(detectors)
219
+ if "median_trap" in detectors:
220
+ self.median_detector.load_calibration(detectors)
221
+ saved_config = profile.get("config", {})
222
+ if saved_config.get("auto_tuned") and "correction_scale" in saved_config:
223
+ self.config.correction_scale = saved_config["correction_scale"]
224
+ self._update_corrector_scales(self.config.correction_scale)
225
+ self._auto_tuned = True
226
+ self._calibration_loaded = True
227
+ if self.config.verbose:
228
+ print(f"[ARIA] Calibration profile loaded from {path}")
229
+
230
+ def auto_tune_correction_scale(self) -> float:
231
+ cvs = []
232
+ for cal in [self.compound_detector.calibration, self.drift_detector.calibration,
233
+ self.median_detector.top1_calibration, self.median_detector.inv_entropy_calibration]:
234
+ if cal.mean is not None and cal.std is not None and abs(cal.mean) > 1e-8:
235
+ cvs.append(cal.std / abs(cal.mean))
236
+ if not cvs:
237
+ return self.config.correction_scale
238
+ avg_cv = sum(cvs) / len(cvs)
239
+ new_scale = max(self.config.auto_tune_min_scale,
240
+ min(self.config.auto_tune_max_scale, 0.15 / (1.0 + avg_cv)))
241
+ old_scale = self.config.correction_scale
242
+ self.config.correction_scale = new_scale
243
+ self._update_corrector_scales(new_scale)
244
+ self._auto_tuned = True
245
+ if self.config.verbose:
246
+ print(f"[ARIA] Auto-tune correction_scale: {old_scale:.4f} -> {new_scale:.4f} (avg_cv={avg_cv:.3f})")
247
+ return new_scale
248
+
249
+ def _update_corrector_scales(self, scale: float):
250
+ self.steering_corrector.correction_scale = scale
251
+ self.goal_anchor.correction_scale = scale
252
+ self.trajectory_diverger.correction_scale = scale
253
+ self.taste_amplifier.correction_scale = scale
254
+
255
  def _can_correct(self) -> bool:
256
  return self._step_corrections_this_step < self.config.max_corrections_per_step
257
 
 
260
 
261
  def _detect_architecture(self) -> Dict:
262
  info = {"arch": "unknown", "num_layers": 0, "hidden_dim": 0, "layers_attr": None}
263
+ for attr in ["model.layers", "transformer.h", "gpt_neox.layers",
264
  "model.decoder.layers", "encoder.layer"]:
265
  parts = attr.split(".")
266
  obj = self.model
 
300
  def _install_hooks(self):
301
  layers = self._get_layers_module()
302
  if layers is None:
 
 
303
  self._install_output_hook()
304
  self._attached = True
305
  return
 
342
  self._current_step_id = step_id
343
  self._step_corrections_this_step = 0
344
 
345
+ # Auto-tune after calibration completes (once)
346
+ if (self.config.auto_tune_correction_scale and
347
+ not self._auto_tuned and not self._calibration_loaded and
348
+ step_id == self.config.calibration_steps + 1):
349
+ self.auto_tune_correction_scale()
350
+
351
  drift_signal = self.drift_detector.detect(h)
352
  self._last_drift_signal = drift_signal
353
  self.state.record_signal(drift_signal)
 
355
  candidates = []
356
  if drift_signal.triggered and self._can_correct():
357
  candidates.append(("goal_anchor", drift_signal.severity, "drift"))
358
+ if (self._last_compound_signal is not None and
359
  self._last_compound_signal.triggered and self._can_correct()):
360
  candidates.append(("steering", self._last_compound_signal.severity, "compound"))
361
  else:
362
  self.steering_corrector.update_good_state(h)
363
+ if (self._last_loop_signal is not None and
364
  self._last_loop_signal.triggered and self._can_correct()):
365
  candidates.append(("trajectory_diverger", self._last_loop_signal.severity, "loop"))
366
 
 
459
 
460
  def report(self) -> Dict:
461
  n = self.state.step
462
+ avg_r = sum(self.state.effective_r) / len(self.state.effective_r) if self.state.effective_r else 1.0
463
+ bl_r = sum(self.state.baseline_r) / len(self.state.baseline_r) if self.state.baseline_r else 0.95
 
 
 
464
  n_steps = max(n, 1)
 
 
465
 
466
  correction_counts = {}
467
  for c in self.state.corrections:
 
475
 
476
  return {
477
  "summary": {
478
+ "version": "0.3.0", "total_steps": n_steps,
479
  "calibration_steps": self.config.calibration_steps,
480
  "sensitivity_k": self.config.sensitivity_k,
481
  "correction_scale": self.config.correction_scale,
482
  "max_corrections_per_step": self.config.max_corrections_per_step,
483
+ "auto_tuned": self._auto_tuned,
484
+ "calibration_loaded": self._calibration_loaded,
485
+ "baseline_R": round(bl_r, 4), "aria_R": round(avg_r, 4),
486
+ "R_improvement": round(avg_r - bl_r, 4),
487
+ "baseline_P_success": f"{bl_r ** n_steps:.6e}",
488
+ "aria_P_success": f"{avg_r ** n_steps:.6e}",
489
+ "improvement_factor": round((avg_r ** n_steps) / max(bl_r ** n_steps, 1e-300), 2),
490
  "total_corrections": len(self.state.corrections),
 
491
  "elapsed_seconds": round(time.time() - self.state.start_time, 2),
492
  },
493
  "corrections_by_type": correction_counts,
494
  "signals_detected": signal_counts,
495
  "signals_triggered": trigger_counts,
 
 
 
 
496
  "calibration_info": {
497
  "compound_error": {"mean": self.compound_detector.calibration.mean,
498
  "std": self.compound_detector.calibration.std,
 
511
  r = self.report()
512
  s = r["summary"]
513
  lines = [
514
+ "=" * 60, " ARIA v0.3 RELIABILITY REPORT", "=" * 60, "",
515
  f" Steps monitored: {s['total_steps']}",
516
+ f" Correction scale: {s['correction_scale']}" + (" (auto-tuned)" if s['auto_tuned'] else ""),
517
+ f" Calibration loaded: {s['calibration_loaded']}",
 
 
518
  f" Time elapsed: {s['elapsed_seconds']}s", "",
519
  " RELIABILITY (R per step):",
520
  f" Baseline (no ARIA): {s['baseline_R']}",
 
535
  for name, count in r["signals_triggered"].items():
536
  total = r["signals_detected"].get(name, count)
537
  lines.append(f" {name}: {count}/{total} ({count/max(total,1)*100:.1f}% of checks)")
 
 
 
 
538
  lines += ["", "=" * 60]
539
  return "\n".join(lines)
540
 
541
  def __repr__(self):
542
  status = "attached" if self._attached else "detached"
543
+ loaded = " profile-loaded" if self._calibration_loaded else ""
544
+ tuned = " auto-tuned" if self._auto_tuned else ""
545
+ return f"ARIA(status={status}, v=0.3, layers={len(self._hooks)} hooks, corrections={len(self.state.corrections)}{loaded}{tuned})"