"""
ARIA Core Module v0.3
=====================

v0.3 changes:
  - save_calibration() / load_calibration(): Persist calibration profiles as
    JSON, so the calibration phase can be skipped on subsequent runs with the
    same model.
  - auto_tune_correction_scale(): After calibration, automatically set
    correction_scale from the observed signal variances. High-variance models
    get gentler corrections.
  - Calibration profiles include a model fingerprint (name + hidden_dim +
    num_layers) for safety checking.

Usage:
    from aria_llm import ARIA, ARIAConfig

    # First run: calibrate and save
    config = ARIAConfig(auto_tune_correction_scale=True, verbose=True)
    aria = ARIA.attach(model, tokenizer, config=config)
    output = model.generate(input_ids, max_new_tokens=500)
    aria.save_calibration("profiles/my_model.json")
    aria.detach()

    # Subsequent runs: load profile (instant, no calibration needed)
    aria = ARIA.attach(model, tokenizer, config=ARIAConfig(
        calibration_profile_path="profiles/my_model.json"))
    output = model.generate(...)
    aria.detach()
"""
import torch
import torch.nn.functional as F
from typing import Optional, Dict, List, Tuple, Any
from collections import deque
import time
import json
import os
import hashlib

from aria_llm.config import ARIAConfig
from aria_llm.detectors import (
    CompoundErrorDetector,
    SemanticDriftDetector,
    LogicLoopDetector,
    MedianTrapDetector,
    DetectionSignal,
)
from aria_llm.correctors import (
    SteeringCorrector,
    GoalAnchor,
    TrajectoryDiverger,
    TasteAmplifier,
)


class ARIAState:
    """Runtime state for a single generation pass.

    Holds the step counter, every detector signal and correction applied,
    and the per-step reliability estimates used by ``ARIA.report()``.
    """

    def __init__(self):
        self.step = 0
        self.signals: List[Dict] = []
        self.corrections: List[Dict] = []
        self.start_time = time.time()
        # Per-step reliability with ARIA, its running product, and the
        # matching per-step estimate without ARIA's corrections.
        self.effective_r: List[float] = []
        self.cumulative_r: List[float] = []
        self.baseline_r: List[float] = []

    def record_signal(self, signal: DetectionSignal):
        """Append a snapshot of ``signal`` tagged with the current step."""
        self.signals.append({
            "step": self.step,
            "name": signal.name,
            "severity": signal.severity,
            "triggered": signal.triggered,
            "raw_value": signal.raw_value,
            "metadata": signal.metadata,
        })

    def record_correction(self, name: str, strength: float):
        """Log that corrector ``name`` fired with ``strength`` at this step."""
        self.corrections.append({"step": self.step, "corrector": name, "strength": strength})

    def record_reliability(self, r: float):
        """Record per-step reliability ``r`` and extend the cumulative product."""
        self.effective_r.append(r)
        if self.cumulative_r:
            self.cumulative_r.append(self.cumulative_r[-1] * r)
        else:
            self.cumulative_r.append(r)


class ARIA:
    """Adaptive Reliability & Integrity Attachment v0.3.

    Hooks into a HuggingFace Transformers model to provide real-time
    detection and correction of four failure modes (compound error,
    semantic drift, logic loops, median trap).

    v0.3: Calibration profiles + auto-tune correction_scale.
    """

    def __init__(self, model, tokenizer, config: Optional[ARIAConfig] = None):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config or ARIAConfig()

        cs = self.config.calibration_steps
        sk = self.config.sensitivity_k
        cscale = self.config.correction_scale

        # Detectors
        self.compound_detector = CompoundErrorDetector(
            calibration_steps=cs, sensitivity_k=sk,
            window=self.config.instability_window,
            lam=self.config.instability_lambda,
            fallback_threshold=self.config.compound_error_threshold)
        self.drift_detector = SemanticDriftDetector(
            calibration_steps=cs, sensitivity_k=sk,
            window=self.config.drift_window,
            fallback_threshold=self.config.drift_threshold)
        self.loop_detector = LogicLoopDetector(
            calibration_steps=cs, sensitivity_k=sk,
            window=self.config.loop_window,
            similarity_threshold=self.config.loop_similarity_threshold,
            entropy_var_threshold=self.config.loop_entropy_variance_threshold,
            max_breaks=self.config.max_loop_breaks)
        self.median_detector = MedianTrapDetector(
            calibration_steps=cs, sensitivity_k=sk,
            temperature_boost=self.config.taste_temperature_boost,
            novelty_bonus=self.config.novelty_bonus)

        # Correctors
        self.steering_corrector = SteeringCorrector(
            alpha=self.config.steering_alpha, correction_scale=cscale)
        self.goal_anchor = GoalAnchor(
            drift_threshold=self.config.drift_threshold,
            correction_strength=0.2,
            reanchor_interval=self.config.goal_reanchor_interval,
            correction_scale=cscale)
        self.trajectory_diverger = TrajectoryDiverger(
            divergence_strength=0.5,
            max_breaks=self.config.max_loop_breaks,
            correction_scale=cscale)
        self.taste_amplifier = TasteAmplifier(
            temperature_boost=self.config.taste_temperature_boost,
            novelty_bonus=self.config.novelty_bonus,
            correction_scale=cscale)

        self._hooks: List = []
        # Number of decoder-layer hooks installed. When zero, the logit hook
        # is the only per-step callback and must drive step bookkeeping.
        self._num_layer_hooks = 0
        self._attached = False
        self.state = ARIAState()
        self._step_corrections_this_step = 0
        self._current_step_id = -1
        self._last_compound_signal: Optional[DetectionSignal] = None
        self._last_loop_signal: Optional[DetectionSignal] = None
        self._last_median_signal: Optional[DetectionSignal] = None
        self._last_drift_signal: Optional[DetectionSignal] = None
        self._model_info = self._detect_architecture()
        self._calibration_loaded = False
        self._auto_tuned = False

        if self.config.calibration_profile_path:
            self.load_calibration(self.config.calibration_profile_path)

    @classmethod
    def attach(cls, model, tokenizer, config: Optional[ARIAConfig] = None) -> 'ARIA':
        """Create an ARIA instance for ``model`` and install its forward hooks."""
        aria = cls(model, tokenizer, config)
        aria._install_hooks()
        return aria

    def detach(self):
        """Remove every installed hook; runtime state is kept for reporting."""
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()
        self._num_layer_hooks = 0
        self._attached = False
        if self.config.verbose:
            print(f"[ARIA] Detached. Applied {len(self.state.corrections)} corrections "
                  f"over {self.state.step} steps.")

    def reset(self):
        """Reset detectors, correctors and per-generation state.

        Hooks stay installed. ``_calibration_loaded`` is not cleared, but
        ``_auto_tuned`` is, while ``config.correction_scale`` keeps its
        current value.
        """
        self.compound_detector.reset()
        self.drift_detector.reset()
        self.loop_detector.reset()
        self.median_detector.reset()
        self.steering_corrector.reset()
        self.goal_anchor.reset()
        self.trajectory_diverger.reset()
        self.taste_amplifier.reset()
        self.state = ARIAState()
        self._step_corrections_this_step = 0
        self._current_step_id = -1
        self._last_compound_signal = None
        self._last_loop_signal = None
        self._last_median_signal = None
        self._last_drift_signal = None
        self._auto_tuned = False

    def _model_fingerprint(self) -> Dict:
        """Describe the model by name, layer count and hidden size.

        ``fingerprint_hash`` is a short MD5 digest of those three values. It is
        an identifier only, not a security measure.
        """
        model_config = getattr(self.model, "config", None)
        name = (getattr(model_config, "_name_or_path", "unknown")
                if model_config is not None else "unknown")
        num_layers = self._model_info["num_layers"]
        hidden_dim = self._model_info["hidden_dim"]
        return {
            "model_name": name,
            "num_layers": num_layers,
            "hidden_dim": hidden_dim,
            "fingerprint_hash": hashlib.md5(
                f"{name}_{num_layers}_{hidden_dim}".encode()
            ).hexdigest()[:12],
        }

    def save_calibration(self, path: str) -> Dict:
        """Write the current calibration profile to ``path`` as JSON.

        Parent directories are created as needed.

        Returns:
            The profile dictionary that was written.
        """
        directory = os.path.dirname(path)
        os.makedirs(directory or ".", exist_ok=True)
        profile = {
            "aria_version": "0.3.0",
            "saved_at": time.strftime("%Y-%m-%dT%H:%M:%S%z"),
            "model": self._model_fingerprint(),
            "config": {
                "calibration_steps": self.config.calibration_steps,
                "sensitivity_k": self.config.sensitivity_k,
                "correction_scale": self.config.correction_scale,
                "max_corrections_per_step": self.config.max_corrections_per_step,
                "auto_tuned": self._auto_tuned,
            },
            "detectors": {},
        }
        for detector in (self.compound_detector, self.drift_detector,
                         self.loop_detector, self.median_detector):
            profile["detectors"].update(detector.export_calibration())

        # default=str: anything json cannot encode natively is stringified.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(profile, f, indent=2, default=str)
        if self.config.verbose:
            print(f"[ARIA] Calibration profile saved to {path}")
        return profile

    def load_calibration(self, path: str) -> None:
        """Load a calibration profile written by ``save_calibration``.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
            ValueError: if the profile's layer count or hidden size differs
                from the current model's.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f"Calibration profile not found: {path}")
        with open(path, "r", encoding="utf-8") as f:
            profile = json.load(f)

        saved_fp = profile.get("model", {})
        current_fp = self._model_fingerprint()
        if (saved_fp.get("num_layers") != current_fp["num_layers"]
                or saved_fp.get("hidden_dim") != current_fp["hidden_dim"]):
            raise ValueError(
                f"Calibration profile mismatch! Saved for layers={saved_fp.get('num_layers')}, "
                f"dim={saved_fp.get('hidden_dim')}. Current: layers={current_fp['num_layers']}, "
                f"dim={current_fp['hidden_dim']}")
        # Name differences are not fatal (checkpoints can be re-saved under a
        # new path), but surface them since the name is part of the fingerprint.
        saved_name = saved_fp.get("model_name")
        if (self.config.verbose and saved_name is not None
                and saved_name != current_fp["model_name"]):
            print(f"[ARIA] Warning: calibration profile was saved for model "
                  f"'{saved_name}', current model is '{current_fp['model_name']}'.")

        # Each detector reads its own entry from the shared detectors dict.
        detectors = profile.get("detectors", {})
        if "compound_error" in detectors:
            self.compound_detector.load_calibration(detectors)
        if "semantic_drift" in detectors:
            self.drift_detector.load_calibration(detectors)
        if "logic_loop" in detectors:
            self.loop_detector.load_calibration(detectors)
        if "median_trap" in detectors:
            self.median_detector.load_calibration(detectors)

        # Only an auto-tuned correction_scale is restored from the profile.
        saved_config = profile.get("config", {})
        if saved_config.get("auto_tuned") and "correction_scale" in saved_config:
            self.config.correction_scale = saved_config["correction_scale"]
            self._update_corrector_scales(self.config.correction_scale)
            self._auto_tuned = True

        self._calibration_loaded = True
        if self.config.verbose:
            print(f"[ARIA] Calibration profile loaded from {path}")

    def auto_tune_correction_scale(self) -> float:
        """Set ``correction_scale`` from the detectors' coefficients of variation.

        Uses ``scale = 0.15 / (1 + mean CV)``, clamped to
        ``[auto_tune_min_scale, auto_tune_max_scale]``. Calibrations with a
        missing mean/std or a near-zero mean are skipped. If none qualify,
        the current scale is returned unchanged.
        """
        cvs = []
        for cal in (self.compound_detector.calibration,
                    self.drift_detector.calibration,
                    self.median_detector.top1_calibration,
                    self.median_detector.inv_entropy_calibration):
            if cal.mean is not None and cal.std is not None and abs(cal.mean) > 1e-8:
                cvs.append(cal.std / abs(cal.mean))
        if not cvs:
            return self.config.correction_scale

        avg_cv = sum(cvs) / len(cvs)
        new_scale = max(self.config.auto_tune_min_scale,
                        min(self.config.auto_tune_max_scale, 0.15 / (1.0 + avg_cv)))
        old_scale = self.config.correction_scale
        self.config.correction_scale = new_scale
        self._update_corrector_scales(new_scale)
        self._auto_tuned = True
        if self.config.verbose:
            print(f"[ARIA] Auto-tune correction_scale: {old_scale:.4f} -> {new_scale:.4f} (avg_cv={avg_cv:.3f})")
        return new_scale

    def _update_corrector_scales(self, scale: float):
        """Propagate ``scale`` to every corrector."""
        self.steering_corrector.correction_scale = scale
        self.goal_anchor.correction_scale = scale
        self.trajectory_diverger.correction_scale = scale
        self.taste_amplifier.correction_scale = scale

    def _can_correct(self) -> bool:
        """True while this step's correction budget is not exhausted."""
        return self._step_corrections_this_step < self.config.max_corrections_per_step

    def _use_correction_budget(self):
        """Consume one unit of this step's correction budget."""
        self._step_corrections_this_step += 1

    def _begin_step(self) -> int:
        """Advance ``state.step`` and reset the per-step correction budget.

        Returns:
            The new step id.
        """
        self.state.step += 1
        step_id = self.state.step
        if step_id != self._current_step_id:
            self._current_step_id = step_id
            self._step_corrections_this_step = 0
        return step_id

    def _detect_architecture(self) -> Dict:
        """Locate the decoder layer list and hidden size on ``self.model``.

        Tries common attribute paths in order. The first one that resolves to
        a sized container wins. Fields stay at their defaults
        (``num_layers=0``, ``layers_attr=None``) when none match.
        """
        info = {"arch": "unknown", "num_layers": 0, "hidden_dim": 0, "layers_attr": None}
        for attr in ["model.layers", "transformer.h", "gpt_neox.layers",
                     "model.decoder.layers", "encoder.layer"]:
            obj = self.model
            try:
                for part in attr.split("."):
                    obj = getattr(obj, part)
                num_layers = len(obj)
            except (AttributeError, TypeError):
                # Missing attribute, or present but not a sized layer list.
                continue
            info["layers_attr"] = attr
            info["num_layers"] = num_layers
            break

        model_config = getattr(self.model, "config", None)
        if model_config is not None:
            for dim_attr in ["hidden_size", "d_model", "n_embd"]:
                if hasattr(model_config, dim_attr):
                    info["hidden_dim"] = getattr(model_config, dim_attr)
                    break

        if self.config.verbose:
            print(f"[ARIA] Detected architecture: {info}")
        return info

    def _get_target_layers(self) -> list:
        """Return ``config.steering_layers`` or, by default, the middle layer."""
        if self.config.steering_layers is not None:
            return self.config.steering_layers
        n = self._model_info["num_layers"]
        return [n // 2] if n > 0 else []

    def _get_layers_module(self):
        """Resolve the detected layer container, or None if none was found."""
        if self._model_info["layers_attr"] is None:
            return None
        obj = self.model
        for part in self._model_info["layers_attr"].split("."):
            obj = getattr(obj, part)
        return obj

    def _install_hooks(self):
        """Install layer hooks on the target layers plus the output-logit hook."""
        layers = self._get_layers_module()
        if layers is None:
            self._install_output_hook()
            self._attached = True
            return

        target_indices = self._get_target_layers()
        for idx in target_indices:
            if idx < len(layers):
                self._hooks.append(layers[idx].register_forward_hook(self._make_layer_hook(idx)))
                self._num_layer_hooks += 1
        self._install_output_hook()
        self._attached = True
        if self.config.verbose:
            print(f"[ARIA] Installed {len(self._hooks)} hooks on layers {target_indices}")

    def _install_output_hook(self):
        """Hook the first present output head among lm_head / output / cls."""
        for attr in ["lm_head", "output", "cls"]:
            if hasattr(self.model, attr):
                self._hooks.append(getattr(self.model, attr).register_forward_hook(self._logit_hook))
                return

    def _make_layer_hook(self, layer_idx: int):
        """Build the forward hook for decoder layer ``layer_idx``.

        Each call of the hook:
          1. advances the step and resets the correction budget;
          2. runs drift detection on the last position's hidden state;
          3. applies hidden-state corrections, most severe first, within budget;
          4. records heuristic reliability estimates.

        NOTE(review): every hook invocation advances ``state.step``. With more
        than one entry in ``steering_layers``, "step" therefore counts layer
        calls rather than tokens, and all layers feed one drift detector.
        Confirm this is intended.
        """
        def hook_fn(module, input, output):
            if not self.config.enabled:
                return output
            hidden_states = output[0] if isinstance(output, tuple) else output
            if hidden_states is None or not isinstance(hidden_states, torch.Tensor):
                return output

            if hidden_states.dim() == 3:
                last_hidden = hidden_states[:, -1, :]
            elif hidden_states.dim() == 2:
                last_hidden = hidden_states[-1:]
            else:
                return output
            # Only the first batch row is monitored.
            h = last_hidden[0] if last_hidden.dim() > 1 else last_hidden

            step_id = self._begin_step()

            # Auto-tune after calibration completes (once)
            if (self.config.auto_tune_correction_scale
                    and not self._auto_tuned
                    and not self._calibration_loaded
                    and step_id == self.config.calibration_steps + 1):
                self.auto_tune_correction_scale()

            drift_signal = self.drift_detector.detect(h)
            self._last_drift_signal = drift_signal
            self.state.record_signal(drift_signal)

            # Compound/loop signals are set only by _logit_hook. Presumably the
            # output head runs after the decoder layers, so these come from the
            # previous forward pass — confirm for the target architecture.
            candidates = []
            if drift_signal.triggered and self._can_correct():
                candidates.append(("goal_anchor", drift_signal.severity, "drift"))
            if (self._last_compound_signal is not None
                    and self._last_compound_signal.triggered
                    and self._can_correct()):
                candidates.append(("steering", self._last_compound_signal.severity, "compound"))
            else:
                self.steering_corrector.update_good_state(h)
            if (self._last_loop_signal is not None
                    and self._last_loop_signal.triggered
                    and self._can_correct()):
                candidates.append(("trajectory_diverger", self._last_loop_signal.severity, "loop"))

            correctors = {
                "goal_anchor": self.goal_anchor,
                "steering": self.steering_corrector,
                "trajectory_diverger": self.trajectory_diverger,
            }
            corrected_hidden = hidden_states
            if candidates:
                candidates.sort(key=lambda x: x[1], reverse=True)
                for corrector_name, severity, _signal_type in candidates:
                    if not self._can_correct():
                        break
                    # Hidden-state corrections are applied only to 3-D outputs.
                    if hidden_states.dim() != 3:
                        continue
                    corrector = correctors.get(corrector_name)
                    if corrector is None:
                        continue
                    h_c = corrector.correct(corrected_hidden[:, -1, :], severity)
                    # Clone rather than write into the module's output tensor.
                    corrected_hidden = corrected_hidden.clone()
                    corrected_hidden[:, -1, :] = h_c
                    self.state.record_correction(corrector_name, severity)
                    self._use_correction_budget()

            # Heuristic reliability model: per-step R = 1 - 0.15 * worst severity.
            # When a correction was applied, severity is damped by
            # (1 - 3 * correction_scale), floored at 0.
            triggered_severities = []
            if drift_signal.triggered:
                triggered_severities.append(drift_signal.severity)
            for prev in (self._last_compound_signal, self._last_loop_signal,
                         self._last_median_signal):
                if prev is not None and prev.triggered:
                    triggered_severities.append(prev.severity)

            if triggered_severities:
                max_severity = max(triggered_severities)
                baseline_step_r = 1.0 - (max_severity * 0.15)
                if self._step_corrections_this_step > 0:
                    effective_severity = max_severity * max(0.0, 1.0 - self.config.correction_scale * 3.0)
                else:
                    effective_severity = max_severity
                aria_step_r = 1.0 - (effective_severity * 0.15)
            else:
                baseline_step_r = 1.0
                aria_step_r = 1.0
            self.state.record_reliability(aria_step_r)
            self.state.baseline_r.append(baseline_step_r)

            if isinstance(output, tuple):
                return (corrected_hidden,) + output[1:]
            return corrected_hidden
        return hook_fn

    def _logit_hook(self, module, input, output):
        """Forward hook on the output head.

        Runs the compound, loop and median-trap detectors on the first batch
        row's last-position logits. When the median trap triggers and budget
        remains, the taste amplifier rewrites those logits.

        If no decoder-layer hooks are installed, this hook also advances the
        step. Otherwise the correction budget would never reset and taste
        corrections would stop permanently after the first few.
        """
        if not self.config.enabled:
            return output
        logits = output[0] if isinstance(output, tuple) else output
        if not isinstance(logits, torch.Tensor):
            return output
        if logits.dim() == 3:
            last_logits = logits[:, -1, :]
        elif logits.dim() == 2:
            last_logits = logits
        else:
            return output
        row_logits = last_logits[0]

        if self._num_layer_hooks == 0:
            self._begin_step()

        compound_signal = self.compound_detector.detect(row_logits)
        self._last_compound_signal = compound_signal
        self.state.record_signal(compound_signal)

        loop_signal = self.loop_detector.detect(row_logits, row_logits)
        self._last_loop_signal = loop_signal
        self.state.record_signal(loop_signal)

        median_signal = self.median_detector.detect(row_logits)
        self._last_median_signal = median_signal
        self.state.record_signal(median_signal)

        if median_signal.triggered and self._can_correct():
            corrected_logits = self.taste_amplifier.correct_logits(row_logits, median_signal.severity)
            logits = logits.clone()
            if logits.dim() == 3:
                logits[0, -1, :] = corrected_logits
            else:
                logits[0] = corrected_logits
            self.state.record_correction("taste_amplifier", median_signal.severity)
            self._use_correction_budget()
            if isinstance(output, tuple):
                return (logits,) + output[1:]
            return logits
        return output

    def report(self) -> Dict:
        """Summarize reliability, corrections, signals and calibration.

        ``P_success`` is modeled as ``R ** n`` over the monitored steps.
        """
        n = self.state.step
        avg_r = (sum(self.state.effective_r) / len(self.state.effective_r)
                 if self.state.effective_r else 1.0)
        # effective_r and baseline_r are appended together, so with no data
        # both default to 1.0. A lower baseline default would report an
        # improvement that never happened.
        bl_r = (sum(self.state.baseline_r) / len(self.state.baseline_r)
                if self.state.baseline_r else 1.0)
        n_steps = max(n, 1)

        correction_counts: Dict[str, int] = {}
        for c in self.state.corrections:
            correction_counts[c["corrector"]] = correction_counts.get(c["corrector"], 0) + 1

        signal_counts: Dict[str, int] = {}
        trigger_counts: Dict[str, int] = {}
        for s in self.state.signals:
            signal_counts[s["name"]] = signal_counts.get(s["name"], 0) + 1
            if s["triggered"]:
                trigger_counts[s["name"]] = trigger_counts.get(s["name"], 0) + 1

        return {
            "summary": {
                "version": "0.3.0",
                "total_steps": n_steps,
                "calibration_steps": self.config.calibration_steps,
                "sensitivity_k": self.config.sensitivity_k,
                "correction_scale": self.config.correction_scale,
                "max_corrections_per_step": self.config.max_corrections_per_step,
                "auto_tuned": self._auto_tuned,
                "calibration_loaded": self._calibration_loaded,
                "baseline_R": round(bl_r, 4),
                "aria_R": round(avg_r, 4),
                "R_improvement": round(avg_r - bl_r, 4),
                "baseline_P_success": f"{bl_r ** n_steps:.6e}",
                "aria_P_success": f"{avg_r ** n_steps:.6e}",
                # The denominator floor avoids division by zero on underflow.
                "improvement_factor": round((avg_r ** n_steps) / max(bl_r ** n_steps, 1e-300), 2),
                "total_corrections": len(self.state.corrections),
                "elapsed_seconds": round(time.time() - self.state.start_time, 2),
            },
            "corrections_by_type": correction_counts,
            "signals_detected": signal_counts,
            "signals_triggered": trigger_counts,
            "calibration_info": {
                "compound_error": {
                    "mean": self.compound_detector.calibration.mean,
                    "std": self.compound_detector.calibration.std,
                    "threshold": self.compound_detector.calibration.threshold},
                "semantic_drift": {
                    "mean": self.drift_detector.calibration.mean,
                    "std": self.drift_detector.calibration.std,
                    "threshold": self.drift_detector.calibration.threshold},
                "median_trap_top1": {
                    "mean": self.median_detector.top1_calibration.mean,
                    "std": self.median_detector.top1_calibration.std,
                    "threshold": self.median_detector.top1_calibration.threshold},
            },
            "model_info": self._model_info,
        }

    def report_text(self) -> str:
        """Render ``report()`` as a human-readable text block."""
        r = self.report()
        s = r["summary"]
        lines = [
            "=" * 60,
            " ARIA v0.3 RELIABILITY REPORT",
            "=" * 60,
            "",
            f" Steps monitored: {s['total_steps']}",
            f" Correction scale: {s['correction_scale']}" + (" (auto-tuned)" if s['auto_tuned'] else ""),
            f" Calibration loaded: {s['calibration_loaded']}",
            f" Time elapsed: {s['elapsed_seconds']}s",
            "",
            " RELIABILITY (R per step):",
            f" Baseline (no ARIA): {s['baseline_R']}",
            f" With ARIA: {s['aria_R']}",
            f" Improvement: {'+' if s['R_improvement'] >= 0 else ''}{s['R_improvement']}",
            "",
            " SUCCESS PROBABILITY (P_s = R^n):",
            f" Baseline: {s['baseline_P_success']}",
            f" With ARIA: {s['aria_P_success']}",
            f" Improvement factor: {s['improvement_factor']}x",
            "",
            f" Total corrections: {s['total_corrections']}",
        ]
        if r["corrections_by_type"]:
            lines += ["", " CORRECTIONS APPLIED:"]
            for name, count in r["corrections_by_type"].items():
                lines.append(f" {name}: {count}")
        if r["signals_triggered"]:
            lines += ["", " FAILURE MODES DETECTED:"]
            for name, count in r["signals_triggered"].items():
                total = r["signals_detected"].get(name, count)
                lines.append(f" {name}: {count}/{total} ({count/max(total,1)*100:.1f}% of checks)")
        lines += ["", "=" * 60]
        return "\n".join(lines)

    def __repr__(self):
        status = "attached" if self._attached else "detached"
        loaded = " profile-loaded" if self._calibration_loaded else ""
        tuned = " auto-tuned" if self._auto_tuned else ""
        return f"ARIA(status={status}, v=0.3, layers={len(self._hooks)} hooks, corrections={len(self.state.corrections)}{loaded}{tuned})"