| """ |
| ARIA Core Module v0.3 |
| ====================== |
| |
| v0.3 changes: |
| - save_calibration() / load_calibration(): Persist calibration profiles as JSON. |
| Skip the calibration phase on subsequent runs with the same model. |
| - auto_tune_correction_scale(): After calibration, automatically set correction_scale |
| based on the observed signal variances. High-variance models get gentler corrections. |
| - Calibration profile includes model fingerprint (name + hidden_dim + num_layers) |
| for safety checking. |
| |
| Usage: |
| from aria_llm import ARIA, ARIAConfig |
| |
| # First run: calibrate and save |
| config = ARIAConfig(auto_tune_correction_scale=True, verbose=True) |
| aria = ARIA.attach(model, tokenizer, config=config) |
| output = model.generate(input_ids, max_new_tokens=500) |
| aria.save_calibration("profiles/my_model.json") |
| aria.detach() |
| |
| # Subsequent runs: load profile (instant, no calibration needed) |
| aria = ARIA.attach(model, tokenizer, config=ARIAConfig( |
| calibration_profile_path="profiles/my_model.json")) |
| output = model.generate(...) |
| aria.detach() |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| from typing import Optional, Dict, List, Tuple, Any |
| from collections import deque |
| import time |
| import json |
| import os |
| import hashlib |
|
|
| from aria_llm.config import ARIAConfig |
| from aria_llm.detectors import ( |
| CompoundErrorDetector, SemanticDriftDetector, |
| LogicLoopDetector, MedianTrapDetector, DetectionSignal, |
| ) |
| from aria_llm.correctors import ( |
| SteeringCorrector, GoalAnchor, TrajectoryDiverger, TasteAmplifier, |
| ) |
|
|
|
|
class ARIAState:
    """Mutable bookkeeping for one generation pass.

    Holds the step counter, a log of every detector signal and applied
    correction, and the per-step / cumulative reliability estimates that
    ARIA.report() later summarizes.
    """

    def __init__(self):
        # Step counter; advanced externally by the layer hook.
        self.step = 0
        self.signals: List[Dict] = []
        self.corrections: List[Dict] = []
        self.start_time = time.time()
        self.effective_r: List[float] = []
        self.cumulative_r: List[float] = []
        self.baseline_r: List[float] = []

    def record_signal(self, signal: DetectionSignal):
        """Append a flattened snapshot of *signal* tagged with the current step."""
        entry = {
            "step": self.step,
            "name": signal.name,
            "severity": signal.severity,
            "triggered": signal.triggered,
            "raw_value": signal.raw_value,
            "metadata": signal.metadata,
        }
        self.signals.append(entry)

    def record_correction(self, name: str, strength: float):
        """Log that corrector *name* fired at *strength* on the current step."""
        entry = {"step": self.step, "corrector": name, "strength": strength}
        self.corrections.append(entry)

    def record_reliability(self, r: float):
        """Track the per-step reliability *r* and its running product."""
        self.effective_r.append(r)
        previous = self.cumulative_r[-1] if self.cumulative_r else 1.0
        self.cumulative_r.append(previous * r)
|
|
|
|
class ARIA:
    """Adaptive Reliability & Integrity Attachment v0.3.

    Hooks into a HuggingFace Transformers model to provide real-time
    detection and correction of four failure modes (compound error,
    semantic drift, logic loops, median traps) via forward hooks on
    decoder layers and on the output head.

    v0.3: Calibration profiles + auto-tune correction_scale.
    """

    def __init__(self, model, tokenizer, config: Optional[ARIAConfig] = None):
        """Wire up detectors, correctors and runtime state.

        If ``config.calibration_profile_path`` is set, that profile is loaded
        immediately (which may raise -- see :meth:`load_calibration`).
        """
        self.model = model
        self.tokenizer = tokenizer
        self.config = config or ARIAConfig()
        # Shorthands for the three settings shared by every detector/corrector.
        cs, sk, cscale = self.config.calibration_steps, self.config.sensitivity_k, self.config.correction_scale

        # One detector per failure mode (implemented in aria_llm.detectors).
        self.compound_detector = CompoundErrorDetector(calibration_steps=cs, sensitivity_k=sk,
            window=self.config.instability_window, lam=self.config.instability_lambda,
            fallback_threshold=self.config.compound_error_threshold)
        self.drift_detector = SemanticDriftDetector(calibration_steps=cs, sensitivity_k=sk,
            window=self.config.drift_window, fallback_threshold=self.config.drift_threshold)
        self.loop_detector = LogicLoopDetector(calibration_steps=cs, sensitivity_k=sk,
            window=self.config.loop_window, similarity_threshold=self.config.loop_similarity_threshold,
            entropy_var_threshold=self.config.loop_entropy_variance_threshold,
            max_breaks=self.config.max_loop_breaks)
        self.median_detector = MedianTrapDetector(calibration_steps=cs, sensitivity_k=sk,
            temperature_boost=self.config.taste_temperature_boost, novelty_bonus=self.config.novelty_bonus)

        # Matching correctors; all share correction_scale so auto-tuning can
        # adjust them together via _update_corrector_scales().
        self.steering_corrector = SteeringCorrector(alpha=self.config.steering_alpha, correction_scale=cscale)
        self.goal_anchor = GoalAnchor(drift_threshold=self.config.drift_threshold,
            correction_strength=0.2, reanchor_interval=self.config.goal_reanchor_interval, correction_scale=cscale)
        self.trajectory_diverger = TrajectoryDiverger(divergence_strength=0.5,
            max_breaks=self.config.max_loop_breaks, correction_scale=cscale)
        self.taste_amplifier = TasteAmplifier(temperature_boost=self.config.taste_temperature_boost,
            novelty_bonus=self.config.novelty_bonus, correction_scale=cscale)

        self._hooks: List = []  # live hook handles; removed in detach()
        self._attached = False
        self.state = ARIAState()
        # Per-step correction budget bookkeeping (reset when the step advances).
        self._step_corrections_this_step = 0
        self._current_step_id = -1
        # Latest signal from each detector. Logit-side signals are produced in
        # _logit_hook and consumed by the layer hook on the NEXT forward pass.
        self._last_compound_signal: Optional[DetectionSignal] = None
        self._last_loop_signal: Optional[DetectionSignal] = None
        self._last_median_signal: Optional[DetectionSignal] = None
        self._last_drift_signal: Optional[DetectionSignal] = None
        self._model_info = self._detect_architecture()
        self._calibration_loaded = False
        self._auto_tuned = False

        if self.config.calibration_profile_path:
            self.load_calibration(self.config.calibration_profile_path)

    @classmethod
    def attach(cls, model, tokenizer, config: Optional[ARIAConfig] = None) -> 'ARIA':
        """Construct an ARIA instance and install its hooks in one call."""
        aria = cls(model, tokenizer, config)
        aria._install_hooks()
        return aria

    def detach(self):
        """Remove every installed hook, restoring the model's original behavior."""
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()
        self._attached = False
        if self.config.verbose:
            print(f"[ARIA] Detached. Applied {len(self.state.corrections)} corrections "
                  f"over {self.state.step} steps.")

    def reset(self):
        """Reset detectors, correctors and runtime state for a fresh pass.

        ``_calibration_loaded`` is deliberately left untouched so a loaded
        profile survives resets; auto-tuning is re-armed.
        """
        self.compound_detector.reset()
        self.drift_detector.reset()
        self.loop_detector.reset()
        self.median_detector.reset()
        self.steering_corrector.reset()
        self.goal_anchor.reset()
        self.trajectory_diverger.reset()
        self.taste_amplifier.reset()
        self.state = ARIAState()
        self._step_corrections_this_step = 0
        self._current_step_id = -1
        self._last_compound_signal = None
        self._last_loop_signal = None
        self._last_median_signal = None
        self._last_drift_signal = None
        self._auto_tuned = False

    def _model_fingerprint(self) -> Dict:
        """Describe the attached model by name, layer count and hidden dim.

        The MD5-derived hash is a short identifier only, not a security
        measure; load_calibration() compares layers/dim directly.
        """
        model_config = getattr(self.model, "config", None)
        name = getattr(model_config, "_name_or_path", "unknown") if model_config else "unknown"
        return {
            "model_name": name,
            "num_layers": self._model_info["num_layers"],
            "hidden_dim": self._model_info["hidden_dim"],
            "fingerprint_hash": hashlib.md5(
                f"{name}_{self._model_info['num_layers']}_{self._model_info['hidden_dim']}".encode()
            ).hexdigest()[:12],
        }

    def save_calibration(self, path: str):
        """Write the current calibration state to *path* as JSON.

        Creates the parent directory if necessary and returns the profile
        dict that was written.
        """
        os.makedirs(os.path.dirname(path) if os.path.dirname(path) else ".", exist_ok=True)
        profile = {
            "aria_version": "0.3.0",
            "saved_at": time.strftime("%Y-%m-%dT%H:%M:%S%z"),
            "model": self._model_fingerprint(),
            "config": {
                "calibration_steps": self.config.calibration_steps,
                "sensitivity_k": self.config.sensitivity_k,
                "correction_scale": self.config.correction_scale,
                "max_corrections_per_step": self.config.max_corrections_per_step,
                "auto_tuned": self._auto_tuned,
            },
            "detectors": {},
        }
        # Each detector exports a dict keyed by its own name (e.g.
        # "compound_error") -- the same keys load_calibration() checks for.
        profile["detectors"].update(self.compound_detector.export_calibration())
        profile["detectors"].update(self.drift_detector.export_calibration())
        profile["detectors"].update(self.loop_detector.export_calibration())
        profile["detectors"].update(self.median_detector.export_calibration())
        with open(path, "w") as f:
            # default=str stringifies any non-JSON-serializable values.
            json.dump(profile, f, indent=2, default=str)
        if self.config.verbose:
            print(f"[ARIA] Calibration profile saved to {path}")
        return profile

    def load_calibration(self, path: str):
        """Load and apply a profile written by save_calibration().

        Raises:
            FileNotFoundError: if *path* does not exist.
            ValueError: if the profile was calibrated for a model with a
                different layer count or hidden dimension.

        If the saved profile was auto-tuned, its correction_scale is adopted
        and propagated to all correctors.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f"Calibration profile not found: {path}")
        with open(path, "r") as f:
            profile = json.load(f)
        # Safety check: refuse profiles from structurally different models.
        saved_fp = profile.get("model", {})
        current_fp = self._model_fingerprint()
        if (saved_fp.get("num_layers") != current_fp["num_layers"] or
                saved_fp.get("hidden_dim") != current_fp["hidden_dim"]):
            raise ValueError(
                f"Calibration profile mismatch! Saved for layers={saved_fp.get('num_layers')}, "
                f"dim={saved_fp.get('hidden_dim')}. Current: layers={current_fp['num_layers']}, "
                f"dim={current_fp['hidden_dim']}")
        # Each detector only receives its section if present, so partial
        # profiles load cleanly.
        detectors = profile.get("detectors", {})
        if "compound_error" in detectors:
            self.compound_detector.load_calibration(detectors)
        if "semantic_drift" in detectors:
            self.drift_detector.load_calibration(detectors)
        if "logic_loop" in detectors:
            self.loop_detector.load_calibration(detectors)
        if "median_trap" in detectors:
            self.median_detector.load_calibration(detectors)
        saved_config = profile.get("config", {})
        if saved_config.get("auto_tuned") and "correction_scale" in saved_config:
            self.config.correction_scale = saved_config["correction_scale"]
            self._update_corrector_scales(self.config.correction_scale)
            self._auto_tuned = True
        self._calibration_loaded = True
        if self.config.verbose:
            print(f"[ARIA] Calibration profile loaded from {path}")

    def auto_tune_correction_scale(self) -> float:
        """Derive correction_scale from calibration variability and apply it.

        Computes the coefficient of variation (std / |mean|) of each
        calibrated signal, then maps the average CV through
        ``0.15 / (1 + avg_cv)``, clamped to
        [auto_tune_min_scale, auto_tune_max_scale] -- so noisier models get
        gentler corrections. Returns the new scale, or the current one
        unchanged when no usable calibration stats exist.
        """
        cvs = []
        for cal in [self.compound_detector.calibration, self.drift_detector.calibration,
                    self.median_detector.top1_calibration, self.median_detector.inv_entropy_calibration]:
            # Skip uncalibrated signals and near-zero means (CV would blow up).
            if cal.mean is not None and cal.std is not None and abs(cal.mean) > 1e-8:
                cvs.append(cal.std / abs(cal.mean))
        if not cvs:
            return self.config.correction_scale
        avg_cv = sum(cvs) / len(cvs)
        new_scale = max(self.config.auto_tune_min_scale,
                        min(self.config.auto_tune_max_scale, 0.15 / (1.0 + avg_cv)))
        old_scale = self.config.correction_scale
        self.config.correction_scale = new_scale
        self._update_corrector_scales(new_scale)
        self._auto_tuned = True
        if self.config.verbose:
            print(f"[ARIA] Auto-tune correction_scale: {old_scale:.4f} -> {new_scale:.4f} (avg_cv={avg_cv:.3f})")
        return new_scale

    def _update_corrector_scales(self, scale: float):
        """Propagate a new correction_scale to all four correctors."""
        self.steering_corrector.correction_scale = scale
        self.goal_anchor.correction_scale = scale
        self.trajectory_diverger.correction_scale = scale
        self.taste_amplifier.correction_scale = scale

    def _can_correct(self) -> bool:
        """True while this step's correction budget is not yet exhausted."""
        return self._step_corrections_this_step < self.config.max_corrections_per_step

    def _use_correction_budget(self):
        """Consume one unit of the per-step correction budget."""
        self._step_corrections_this_step += 1

    def _detect_architecture(self) -> Dict:
        """Probe the model for its decoder-layer container and hidden size.

        Tries a fixed list of well-known attribute paths (Llama-style,
        GPT-2-style, GPT-NeoX, OPT, BERT-style). If none match, num_layers
        stays 0 and _install_hooks() falls back to the output hook only.
        Note: "arch" is initialized to "unknown" and never refined here.
        """
        info = {"arch": "unknown", "num_layers": 0, "hidden_dim": 0, "layers_attr": None}
        for attr in ["model.layers", "transformer.h", "gpt_neox.layers",
                     "model.decoder.layers", "encoder.layer"]:
            parts = attr.split(".")
            obj = self.model
            try:
                # Walk the dotted path; AttributeError means "try next style".
                for part in parts:
                    obj = getattr(obj, part)
                info["layers_attr"] = attr
                info["num_layers"] = len(obj)
                break
            except AttributeError:
                continue
        model_config = getattr(self.model, "config", None)
        if model_config:
            # Common hidden-size attribute names across HF config classes.
            for dim_attr in ["hidden_size", "d_model", "n_embd"]:
                if hasattr(model_config, dim_attr):
                    info["hidden_dim"] = getattr(model_config, dim_attr)
                    break
        if self.config.verbose:
            print(f"[ARIA] Detected architecture: {info}")
        return info

    def _get_target_layers(self) -> list:
        """Layer indices to hook: the configured list, else just the middle layer."""
        if self.config.steering_layers is not None:
            return self.config.steering_layers
        n = self._model_info["num_layers"]
        return [n // 2] if n > 0 else []

    def _get_layers_module(self):
        """Resolve the detected layers attribute path to the live module, or None."""
        if self._model_info["layers_attr"] is None:
            return None
        parts = self._model_info["layers_attr"].split(".")
        obj = self.model
        for part in parts:
            obj = getattr(obj, part)
        return obj

    def _install_hooks(self):
        """Install forward hooks on the target layers and on the output head."""
        layers = self._get_layers_module()
        if layers is None:
            # Unknown architecture: monitor/correct logits only.
            self._install_output_hook()
            self._attached = True
            return
        target_indices = self._get_target_layers()
        for idx in target_indices:
            # Out-of-range configured indices are silently skipped.
            if idx < len(layers):
                self._hooks.append(layers[idx].register_forward_hook(self._make_layer_hook(idx)))
        self._install_output_hook()
        self._attached = True
        if self.config.verbose:
            print(f"[ARIA] Installed {len(self._hooks)} hooks on layers {target_indices}")

    def _install_output_hook(self):
        """Hook the first available output head (lm_head / output / cls).

        Installs nothing (silently) if the model exposes none of these.
        """
        for attr in ["lm_head", "output", "cls"]:
            if hasattr(self.model, attr):
                self._hooks.append(getattr(self.model, attr).register_forward_hook(self._logit_hook))
                return

    def _make_layer_hook(self, layer_idx: int):
        """Build the forward hook closure for decoder layer *layer_idx*.

        The hook advances the step counter, runs drift detection on the last
        position's hidden state, applies hidden-state corrections within the
        per-step budget, and records per-step reliability estimates.
        """
        def hook_fn(module, input, output):
            if not self.config.enabled:
                return output
            # Transformer layers may return (hidden_states, ...) tuples.
            if isinstance(output, tuple):
                hidden_states = output[0]
            else:
                hidden_states = output
            if hidden_states is None or not isinstance(hidden_states, torch.Tensor):
                return output
            # Reduce to the final position's hidden vector.
            # NOTE(review): assumes 3-D is (batch, seq, dim) -- confirm per arch.
            if hidden_states.dim() == 3:
                last_hidden = hidden_states[:, -1, :]
            elif hidden_states.dim() == 2:
                last_hidden = hidden_states[-1:]
            else:
                return output
            # Only the first batch element is monitored.
            h = last_hidden[0] if last_hidden.dim() > 1 else last_hidden

            # Advance the step counter. NOTE(review): every hooked layer
            # increments this, so with multiple steering layers one token
            # spans several "steps" -- confirm that is intended.
            self.state.step += 1
            step_id = self.state.step
            if step_id != self._current_step_id:
                self._current_step_id = step_id
                self._step_corrections_this_step = 0

            # Auto-tune exactly once, right after the calibration window
            # closes, unless a loaded profile already supplied a tuned scale.
            if (self.config.auto_tune_correction_scale and
                not self._auto_tuned and not self._calibration_loaded and
                step_id == self.config.calibration_steps + 1):
                self.auto_tune_correction_scale()

            drift_signal = self.drift_detector.detect(h)
            self._last_drift_signal = drift_signal
            self.state.record_signal(drift_signal)

            # Collect correction candidates as (corrector, severity, cause).
            # Compound/loop signals come from the previous _logit_hook call.
            candidates = []
            if drift_signal.triggered and self._can_correct():
                candidates.append(("goal_anchor", drift_signal.severity, "drift"))
            if (self._last_compound_signal is not None and
                self._last_compound_signal.triggered and self._can_correct()):
                candidates.append(("steering", self._last_compound_signal.severity, "compound"))
            else:
                # No compound error flagged (or no budget left): treat the
                # current hidden state as a good reference for steering.
                # NOTE(review): the budget-exhausted case also lands here --
                # confirm that anchoring then is intended.
                self.steering_corrector.update_good_state(h)
            if (self._last_loop_signal is not None and
                self._last_loop_signal.triggered and self._can_correct()):
                candidates.append(("trajectory_diverger", self._last_loop_signal.severity, "loop"))

            corrected_hidden = hidden_states
            if candidates:
                # Most severe issue first; the budget may cut off the rest.
                candidates.sort(key=lambda x: x[1], reverse=True)
                for corrector_name, severity, signal_type in candidates:
                    if not self._can_correct():
                        break
                    # Corrections only apply to (batch, seq, dim) outputs.
                    if hidden_states.dim() != 3:
                        continue
                    if corrector_name == "goal_anchor":
                        h_c = self.goal_anchor.correct(corrected_hidden[:, -1, :], severity)
                    elif corrector_name == "steering":
                        h_c = self.steering_corrector.correct(corrected_hidden[:, -1, :], severity)
                    elif corrector_name == "trajectory_diverger":
                        h_c = self.trajectory_diverger.correct(corrected_hidden[:, -1, :], severity)
                    else:
                        continue
                    # Clone before writing so the layer's original activations
                    # are not mutated in place; only the last position changes.
                    corrected_hidden = corrected_hidden.clone()
                    corrected_hidden[:, -1, :] = h_c
                    self.state.record_correction(corrector_name, severity)
                    self._use_correction_budget()

            # Reliability accounting (heuristic model): the worst triggered
            # severity costs 0.15 reliability per unit; a correction damps the
            # severity by up to correction_scale * 3.
            triggered_severities = []
            if drift_signal.triggered:
                triggered_severities.append(drift_signal.severity)
            if self._last_compound_signal is not None and self._last_compound_signal.triggered:
                triggered_severities.append(self._last_compound_signal.severity)
            if self._last_loop_signal is not None and self._last_loop_signal.triggered:
                triggered_severities.append(self._last_loop_signal.severity)
            if self._last_median_signal is not None and self._last_median_signal.triggered:
                triggered_severities.append(self._last_median_signal.severity)

            if triggered_severities:
                max_severity = max(triggered_severities)
                baseline_step_r = 1.0 - (max_severity * 0.15)
                correction_applied = self._step_corrections_this_step > 0
                if correction_applied:
                    effective_severity = max_severity * max(0.0, 1.0 - self.config.correction_scale * 3.0)
                else:
                    effective_severity = max_severity
                aria_step_r = 1.0 - (effective_severity * 0.15)
            else:
                baseline_step_r = 1.0
                aria_step_r = 1.0

            self.state.record_reliability(aria_step_r)
            self.state.baseline_r.append(baseline_step_r)

            # Re-wrap to match whatever shape the layer originally returned.
            if isinstance(output, tuple):
                return (corrected_hidden,) + output[1:]
            return corrected_hidden
        return hook_fn

    def _logit_hook(self, module, input, output):
        """Forward hook on the output head.

        Runs the three logit-side detectors (compound error, logic loop,
        median trap), caches their signals for the next layer-hook pass, and
        applies the taste amplifier directly to the logits when the median
        trap triggers within the correction budget.
        """
        if not self.config.enabled:
            return output
        logits = output[0] if isinstance(output, tuple) else output
        if not isinstance(logits, torch.Tensor):
            return output
        if logits.dim() == 3:
            last_logits = logits[:, -1, :]
        elif logits.dim() == 2:
            last_logits = logits
        else:
            return output
        # Only the first batch element is monitored/corrected.
        l = last_logits[0]

        compound_signal = self.compound_detector.detect(l)
        self._last_compound_signal = compound_signal
        self.state.record_signal(compound_signal)

        # NOTE(review): both positional args receive the same logits vector
        # here -- confirm against LogicLoopDetector.detect's signature.
        loop_signal = self.loop_detector.detect(l, l)
        self._last_loop_signal = loop_signal
        self.state.record_signal(loop_signal)

        median_signal = self.median_detector.detect(l)
        self._last_median_signal = median_signal
        self.state.record_signal(median_signal)

        if median_signal.triggered and self._can_correct():
            corrected_logits = self.taste_amplifier.correct_logits(l, median_signal.severity)
            # Clone before writing so upstream tensors are left untouched.
            if logits.dim() == 3:
                logits = logits.clone()
                logits[0, -1, :] = corrected_logits
            elif logits.dim() == 2:
                logits = logits.clone()
                logits[0] = corrected_logits
            self.state.record_correction("taste_amplifier", median_signal.severity)
            self._use_correction_budget()
            if isinstance(output, tuple):
                return (logits,) + output[1:]
            return logits
        return output

    def report(self) -> Dict:
        """Build a structured summary of the monitored run.

        Includes average per-step reliability with and without ARIA, the
        modeled end-to-end success probabilities (R^n), correction/signal
        tallies, current calibration stats, and the detected architecture.
        """
        n = self.state.step
        # Defaults when no steps were recorded: 1.0 effective, 0.95 baseline.
        avg_r = sum(self.state.effective_r) / len(self.state.effective_r) if self.state.effective_r else 1.0
        bl_r = sum(self.state.baseline_r) / len(self.state.baseline_r) if self.state.baseline_r else 0.95
        n_steps = max(n, 1)

        correction_counts = {}
        for c in self.state.corrections:
            correction_counts[c["corrector"]] = correction_counts.get(c["corrector"], 0) + 1
        signal_counts = {}
        trigger_counts = {}
        for s in self.state.signals:
            signal_counts[s["name"]] = signal_counts.get(s["name"], 0) + 1
            if s["triggered"]:
                trigger_counts[s["name"]] = trigger_counts.get(s["name"], 0) + 1

        return {
            "summary": {
                "version": "0.3.0", "total_steps": n_steps,
                "calibration_steps": self.config.calibration_steps,
                "sensitivity_k": self.config.sensitivity_k,
                "correction_scale": self.config.correction_scale,
                "max_corrections_per_step": self.config.max_corrections_per_step,
                "auto_tuned": self._auto_tuned,
                "calibration_loaded": self._calibration_loaded,
                "baseline_R": round(bl_r, 4), "aria_R": round(avg_r, 4),
                "R_improvement": round(avg_r - bl_r, 4),
                "baseline_P_success": f"{bl_r ** n_steps:.6e}",
                "aria_P_success": f"{avg_r ** n_steps:.6e}",
                # Floor the denominator to avoid division-by-zero on underflow.
                "improvement_factor": round((avg_r ** n_steps) / max(bl_r ** n_steps, 1e-300), 2),
                "total_corrections": len(self.state.corrections),
                "elapsed_seconds": round(time.time() - self.state.start_time, 2),
            },
            "corrections_by_type": correction_counts,
            "signals_detected": signal_counts,
            "signals_triggered": trigger_counts,
            "calibration_info": {
                "compound_error": {"mean": self.compound_detector.calibration.mean,
                                   "std": self.compound_detector.calibration.std,
                                   "threshold": self.compound_detector.calibration.threshold},
                "semantic_drift": {"mean": self.drift_detector.calibration.mean,
                                   "std": self.drift_detector.calibration.std,
                                   "threshold": self.drift_detector.calibration.threshold},
                "median_trap_top1": {"mean": self.median_detector.top1_calibration.mean,
                                     "std": self.median_detector.top1_calibration.std,
                                     "threshold": self.median_detector.top1_calibration.threshold},
            },
            "model_info": self._model_info,
        }

    def report_text(self) -> str:
        """Render report() as a human-readable multi-line string."""
        r = self.report()
        s = r["summary"]
        lines = [
            "=" * 60, " ARIA v0.3 RELIABILITY REPORT", "=" * 60, "",
            f" Steps monitored: {s['total_steps']}",
            f" Correction scale: {s['correction_scale']}" + (" (auto-tuned)" if s['auto_tuned'] else ""),
            f" Calibration loaded: {s['calibration_loaded']}",
            f" Time elapsed: {s['elapsed_seconds']}s", "",
            " RELIABILITY (R per step):",
            f" Baseline (no ARIA): {s['baseline_R']}",
            f" With ARIA: {s['aria_R']}",
            f" Improvement: {'+' if s['R_improvement'] >= 0 else ''}{s['R_improvement']}", "",
            " SUCCESS PROBABILITY (P_s = R^n):",
            f" Baseline: {s['baseline_P_success']}",
            f" With ARIA: {s['aria_P_success']}",
            f" Improvement factor: {s['improvement_factor']}x", "",
            f" Total corrections: {s['total_corrections']}",
        ]
        if r["corrections_by_type"]:
            lines += ["", " CORRECTIONS APPLIED:"]
            for name, count in r["corrections_by_type"].items():
                lines.append(f" {name}: {count}")
        if r["signals_triggered"]:
            lines += ["", " FAILURE MODES DETECTED:"]
            for name, count in r["signals_triggered"].items():
                total = r["signals_detected"].get(name, count)
                lines.append(f" {name}: {count}/{total} ({count/max(total,1)*100:.1f}% of checks)")
        lines += ["", "=" * 60]
        return "\n".join(lines)

    def __repr__(self):
        """Compact one-line status string for debugging."""
        status = "attached" if self._attached else "detached"
        loaded = " profile-loaded" if self._calibration_loaded else ""
        tuned = " auto-tuned" if self._auto_tuned else ""
        return f"ARIA(status={status}, v=0.3, layers={len(self._hooks)} hooks, corrections={len(self.state.corrections)}{loaded}{tuned})"
|
|