# aria-llm / aria_llm / core.py
"""
ARIA Core Module v0.3
======================
v0.3 changes:
- save_calibration() / load_calibration(): Persist calibration profiles as JSON.
Skip the calibration phase on subsequent runs with the same model.
- auto_tune_correction_scale(): After calibration, automatically set correction_scale
based on the observed signal variances. High-variance models get gentler corrections.
- Calibration profile includes model fingerprint (name + hidden_dim + num_layers)
for safety checking.
Usage:
from aria_llm import ARIA, ARIAConfig
# First run: calibrate and save
config = ARIAConfig(auto_tune_correction_scale=True, verbose=True)
aria = ARIA.attach(model, tokenizer, config=config)
output = model.generate(input_ids, max_new_tokens=500)
aria.save_calibration("profiles/my_model.json")
aria.detach()
# Subsequent runs: load profile (instant, no calibration needed)
aria = ARIA.attach(model, tokenizer, config=ARIAConfig(
calibration_profile_path="profiles/my_model.json"))
output = model.generate(...)
aria.detach()
"""
import torch
import torch.nn.functional as F
from typing import Optional, Dict, List, Tuple, Any
from collections import deque
import time
import json
import os
import hashlib
from aria_llm.config import ARIAConfig
from aria_llm.detectors import (
CompoundErrorDetector, SemanticDriftDetector,
LogicLoopDetector, MedianTrapDetector, DetectionSignal,
)
from aria_llm.correctors import (
SteeringCorrector, GoalAnchor, TrajectoryDiverger, TasteAmplifier,
)
class ARIAState:
"""Runtime state for a single generation pass."""
def __init__(self):
self.step = 0
self.signals: List[Dict] = []
self.corrections: List[Dict] = []
self.start_time = time.time()
self.effective_r: List[float] = []
self.cumulative_r: List[float] = []
self.baseline_r: List[float] = []
def record_signal(self, signal: DetectionSignal):
self.signals.append({
"step": self.step, "name": signal.name, "severity": signal.severity,
"triggered": signal.triggered, "raw_value": signal.raw_value,
"metadata": signal.metadata,
})
def record_correction(self, name: str, strength: float):
self.corrections.append({"step": self.step, "corrector": name, "strength": strength})
def record_reliability(self, r: float):
self.effective_r.append(r)
if self.cumulative_r:
self.cumulative_r.append(self.cumulative_r[-1] * r)
else:
self.cumulative_r.append(r)
class ARIA:
"""Adaptive Reliability & Integrity Attachment v0.3.
Hooks into a HuggingFace Transformers model to provide real-time
detection and correction of four failure modes.
v0.3: Calibration profiles + auto-tune correction_scale.
"""
def __init__(self, model, tokenizer, config: Optional[ARIAConfig] = None):
self.model = model
self.tokenizer = tokenizer
self.config = config or ARIAConfig()
cs, sk, cscale = self.config.calibration_steps, self.config.sensitivity_k, self.config.correction_scale
self.compound_detector = CompoundErrorDetector(calibration_steps=cs, sensitivity_k=sk,
window=self.config.instability_window, lam=self.config.instability_lambda,
fallback_threshold=self.config.compound_error_threshold)
self.drift_detector = SemanticDriftDetector(calibration_steps=cs, sensitivity_k=sk,
window=self.config.drift_window, fallback_threshold=self.config.drift_threshold)
self.loop_detector = LogicLoopDetector(calibration_steps=cs, sensitivity_k=sk,
window=self.config.loop_window, similarity_threshold=self.config.loop_similarity_threshold,
entropy_var_threshold=self.config.loop_entropy_variance_threshold,
max_breaks=self.config.max_loop_breaks)
self.median_detector = MedianTrapDetector(calibration_steps=cs, sensitivity_k=sk,
temperature_boost=self.config.taste_temperature_boost, novelty_bonus=self.config.novelty_bonus)
self.steering_corrector = SteeringCorrector(alpha=self.config.steering_alpha, correction_scale=cscale)
self.goal_anchor = GoalAnchor(drift_threshold=self.config.drift_threshold,
correction_strength=0.2, reanchor_interval=self.config.goal_reanchor_interval, correction_scale=cscale)
self.trajectory_diverger = TrajectoryDiverger(divergence_strength=0.5,
max_breaks=self.config.max_loop_breaks, correction_scale=cscale)
self.taste_amplifier = TasteAmplifier(temperature_boost=self.config.taste_temperature_boost,
novelty_bonus=self.config.novelty_bonus, correction_scale=cscale)
self._hooks: List = []
self._attached = False
self.state = ARIAState()
self._step_corrections_this_step = 0
self._current_step_id = -1
self._last_compound_signal: Optional[DetectionSignal] = None
self._last_loop_signal: Optional[DetectionSignal] = None
self._last_median_signal: Optional[DetectionSignal] = None
self._last_drift_signal: Optional[DetectionSignal] = None
self._model_info = self._detect_architecture()
self._calibration_loaded = False
self._auto_tuned = False
if self.config.calibration_profile_path:
self.load_calibration(self.config.calibration_profile_path)
@classmethod
def attach(cls, model, tokenizer, config: Optional[ARIAConfig] = None) -> 'ARIA':
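        """Create an ARIA instance for ``model`` and install its forward hooks."""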
aria = cls(model, tokenizer, config)
aria._install_hooks()
return aria
def detach(self):
for hook in self._hooks:
hook.remove()
self._hooks.clear()
self._attached = False
if self.config.verbose:
print(f"[ARIA] Detached. Applied {len(self.state.corrections)} corrections "
f"over {self.state.step} steps.")
def reset(self):
self.compound_detector.reset()
self.drift_detector.reset()
self.loop_detector.reset()
self.median_detector.reset()
self.steering_corrector.reset()
self.goal_anchor.reset()
self.trajectory_diverger.reset()
self.taste_amplifier.reset()
self.state = ARIAState()
self._step_corrections_this_step = 0
self._current_step_id = -1
self._last_compound_signal = None
self._last_loop_signal = None
self._last_median_signal = None
self._last_drift_signal = None
self._auto_tuned = False
def _model_fingerprint(self) -> Dict:
model_config = getattr(self.model, "config", None)
name = getattr(model_config, "_name_or_path", "unknown") if model_config else "unknown"
return {
"model_name": name,
"num_layers": self._model_info["num_layers"],
"hidden_dim": self._model_info["hidden_dim"],
"fingerprint_hash": hashlib.md5(
f"{name}_{self._model_info['num_layers']}_{self._model_info['hidden_dim']}".encode()
).hexdigest()[:12],
}
def save_calibration(self, path: str):
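        """Persist the current detector calibration, a config snapshot, and the
        model fingerprint to ``path`` as JSON. Returns the profile dict."""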
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
profile = {
"aria_version": "0.3.0",
"saved_at": time.strftime("%Y-%m-%dT%H:%M:%S%z"),
"model": self._model_fingerprint(),
"config": {
"calibration_steps": self.config.calibration_steps,
"sensitivity_k": self.config.sensitivity_k,
"correction_scale": self.config.correction_scale,
"max_corrections_per_step": self.config.max_corrections_per_step,
"auto_tuned": self._auto_tuned,
},
"detectors": {},
}
profile["detectors"].update(self.compound_detector.export_calibration())
profile["detectors"].update(self.drift_detector.export_calibration())
profile["detectors"].update(self.loop_detector.export_calibration())
profile["detectors"].update(self.median_detector.export_calibration())
with open(path, "w") as f:
json.dump(profile, f, indent=2, default=str)
if self.config.verbose:
print(f"[ARIA] Calibration profile saved to {path}")
return profile
def load_calibration(self, path: str):
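        """Load a profile written by save_calibration(). Raises FileNotFoundError
        if the file is missing and ValueError if the saved model shape
        (num_layers / hidden_dim) does not match the attached model. Restores an
        auto-tuned correction_scale when the profile recorded one."""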
if not os.path.exists(path):
raise FileNotFoundError(f"Calibration profile not found: {path}")
with open(path, "r") as f:
profile = json.load(f)
saved_fp = profile.get("model", {})
current_fp = self._model_fingerprint()
if (saved_fp.get("num_layers") != current_fp["num_layers"] or
saved_fp.get("hidden_dim") != current_fp["hidden_dim"]):
raise ValueError(
f"Calibration profile mismatch! Saved for layers={saved_fp.get('num_layers')}, "
f"dim={saved_fp.get('hidden_dim')}. Current: layers={current_fp['num_layers']}, "
f"dim={current_fp['hidden_dim']}")
detectors = profile.get("detectors", {})
if "compound_error" in detectors:
self.compound_detector.load_calibration(detectors)
if "semantic_drift" in detectors:
self.drift_detector.load_calibration(detectors)
if "logic_loop" in detectors:
self.loop_detector.load_calibration(detectors)
if "median_trap" in detectors:
self.median_detector.load_calibration(detectors)
saved_config = profile.get("config", {})
if saved_config.get("auto_tuned") and "correction_scale" in saved_config:
self.config.correction_scale = saved_config["correction_scale"]
self._update_corrector_scales(self.config.correction_scale)
self._auto_tuned = True
self._calibration_loaded = True
if self.config.verbose:
print(f"[ARIA] Calibration profile loaded from {path}")
def auto_tune_correction_scale(self) -> float:
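        """Derive correction_scale from the coefficient of variation of the
        calibrated detector statistics and push it to all correctors.
        Returns the resulting correction_scale."""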
cvs = []
for cal in [self.compound_detector.calibration, self.drift_detector.calibration,
self.median_detector.top1_calibration, self.median_detector.inv_entropy_calibration]:
if cal.mean is not None and cal.std is not None and abs(cal.mean) > 1e-8:
cvs.append(cal.std / abs(cal.mean))
if not cvs:
return self.config.correction_scale
avg_cv = sum(cvs) / len(cvs)
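        # Noisier calibrations (higher average coefficient of variation) get a
        # gentler scale: 0.15 / (1 + avg_cv), clamped to the configured
        # [auto_tune_min_scale, auto_tune_max_scale] range.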
new_scale = max(self.config.auto_tune_min_scale,
min(self.config.auto_tune_max_scale, 0.15 / (1.0 + avg_cv)))
old_scale = self.config.correction_scale
self.config.correction_scale = new_scale
self._update_corrector_scales(new_scale)
self._auto_tuned = True
if self.config.verbose:
print(f"[ARIA] Auto-tune correction_scale: {old_scale:.4f} -> {new_scale:.4f} (avg_cv={avg_cv:.3f})")
return new_scale
def _update_corrector_scales(self, scale: float):
self.steering_corrector.correction_scale = scale
self.goal_anchor.correction_scale = scale
self.trajectory_diverger.correction_scale = scale
self.taste_amplifier.correction_scale = scale
def _can_correct(self) -> bool:
return self._step_corrections_this_step < self.config.max_corrections_per_step
def _use_correction_budget(self):
self._step_corrections_this_step += 1
def _detect_architecture(self) -> Dict:
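        """Probe common HF Transformers layer-stack paths (model.layers,
        transformer.h, ...) and the model config for the hidden dimension.
        Returns num_layers, hidden_dim, and the attribute path to the layers."""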
info = {"arch": "unknown", "num_layers": 0, "hidden_dim": 0, "layers_attr": None}
for attr in ["model.layers", "transformer.h", "gpt_neox.layers",
"model.decoder.layers", "encoder.layer"]:
parts = attr.split(".")
obj = self.model
try:
for part in parts:
obj = getattr(obj, part)
info["layers_attr"] = attr
info["num_layers"] = len(obj)
break
except AttributeError:
continue
model_config = getattr(self.model, "config", None)
if model_config:
for dim_attr in ["hidden_size", "d_model", "n_embd"]:
if hasattr(model_config, dim_attr):
info["hidden_dim"] = getattr(model_config, dim_attr)
break
if self.config.verbose:
print(f"[ARIA] Detected architecture: {info}")
return info
def _get_target_layers(self) -> list:
if self.config.steering_layers is not None:
return self.config.steering_layers
n = self._model_info["num_layers"]
return [n // 2] if n > 0 else []
def _get_layers_module(self):
if self._model_info["layers_attr"] is None:
return None
parts = self._model_info["layers_attr"].split(".")
obj = self.model
for part in parts:
obj = getattr(obj, part)
return obj
def _install_hooks(self):
layers = self._get_layers_module()
if layers is None:
self._install_output_hook()
self._attached = True
return
target_indices = self._get_target_layers()
for idx in target_indices:
if idx < len(layers):
self._hooks.append(layers[idx].register_forward_hook(self._make_layer_hook(idx)))
self._install_output_hook()
self._attached = True
if self.config.verbose:
print(f"[ARIA] Installed {len(self._hooks)} hooks on layers {target_indices}")
def _install_output_hook(self):
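        """Register the logit hook on the first available output head
        (lm_head, output, or cls); a model with none of these is left unhooked."""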
for attr in ["lm_head", "output", "cls"]:
if hasattr(self.model, attr):
self._hooks.append(getattr(self.model, attr).register_forward_hook(self._logit_hook))
return
def _make_layer_hook(self, layer_idx: int):
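        """Build a forward hook for decoder layer ``layer_idx``: runs semantic-drift
        detection on the last-token hidden state, applies hidden-state correctors
        within the per-step budget, and records per-step reliability."""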
def hook_fn(module, input, output):
if not self.config.enabled:
return output
if isinstance(output, tuple):
hidden_states = output[0]
else:
hidden_states = output
if hidden_states is None or not isinstance(hidden_states, torch.Tensor):
return output
if hidden_states.dim() == 3:
last_hidden = hidden_states[:, -1, :]
elif hidden_states.dim() == 2:
last_hidden = hidden_states[-1:]
else:
return output
h = last_hidden[0] if last_hidden.dim() > 1 else last_hidden
self.state.step += 1
step_id = self.state.step
if step_id != self._current_step_id:
self._current_step_id = step_id
self._step_corrections_this_step = 0
# Auto-tune after calibration completes (once)
if (self.config.auto_tune_correction_scale and
not self._auto_tuned and not self._calibration_loaded and
step_id == self.config.calibration_steps + 1):
self.auto_tune_correction_scale()
drift_signal = self.drift_detector.detect(h)
self._last_drift_signal = drift_signal
self.state.record_signal(drift_signal)
candidates = []
if drift_signal.triggered and self._can_correct():
candidates.append(("goal_anchor", drift_signal.severity, "drift"))
if (self._last_compound_signal is not None and
self._last_compound_signal.triggered and self._can_correct()):
candidates.append(("steering", self._last_compound_signal.severity, "compound"))
else:
self.steering_corrector.update_good_state(h)
if (self._last_loop_signal is not None and
self._last_loop_signal.triggered and self._can_correct()):
candidates.append(("trajectory_diverger", self._last_loop_signal.severity, "loop"))
corrected_hidden = hidden_states
if candidates:
candidates.sort(key=lambda x: x[1], reverse=True)
for corrector_name, severity, signal_type in candidates:
if not self._can_correct():
break
if hidden_states.dim() != 3:
continue
if corrector_name == "goal_anchor":
h_c = self.goal_anchor.correct(corrected_hidden[:, -1, :], severity)
elif corrector_name == "steering":
h_c = self.steering_corrector.correct(corrected_hidden[:, -1, :], severity)
elif corrector_name == "trajectory_diverger":
h_c = self.trajectory_diverger.correct(corrected_hidden[:, -1, :], severity)
else:
continue
corrected_hidden = corrected_hidden.clone()
corrected_hidden[:, -1, :] = h_c
self.state.record_correction(corrector_name, severity)
self._use_correction_budget()
triggered_severities = []
if drift_signal.triggered:
triggered_severities.append(drift_signal.severity)
if self._last_compound_signal is not None and self._last_compound_signal.triggered:
triggered_severities.append(self._last_compound_signal.severity)
if self._last_loop_signal is not None and self._last_loop_signal.triggered:
triggered_severities.append(self._last_loop_signal.severity)
if self._last_median_signal is not None and self._last_median_signal.triggered:
triggered_severities.append(self._last_median_signal.severity)
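            # Per-step reliability model: R_step = 1 - 0.15 * severity, using the
            # max severity over the triggered detectors. If a correction was applied
            # this step, the effective severity is discounted by
            # (1 - 3 * correction_scale), floored at zero.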
if triggered_severities:
max_severity = max(triggered_severities)
baseline_step_r = 1.0 - (max_severity * 0.15)
correction_applied = self._step_corrections_this_step > 0
if correction_applied:
effective_severity = max_severity * max(0.0, 1.0 - self.config.correction_scale * 3.0)
else:
effective_severity = max_severity
aria_step_r = 1.0 - (effective_severity * 0.15)
else:
baseline_step_r = 1.0
aria_step_r = 1.0
self.state.record_reliability(aria_step_r)
self.state.baseline_r.append(baseline_step_r)
if isinstance(output, tuple):
return (corrected_hidden,) + output[1:]
return corrected_hidden
return hook_fn
def _logit_hook(self, module, input, output):
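        """Forward hook on the output head: runs compound-error, logic-loop, and
        median-trap detection on the last-token logits, and lets the taste
        amplifier rewrite those logits when the median-trap signal triggers."""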
if not self.config.enabled:
return output
logits = output[0] if isinstance(output, tuple) else output
if not isinstance(logits, torch.Tensor):
return output
if logits.dim() == 3:
last_logits = logits[:, -1, :]
elif logits.dim() == 2:
last_logits = logits
else:
return output
l = last_logits[0]
compound_signal = self.compound_detector.detect(l)
self._last_compound_signal = compound_signal
self.state.record_signal(compound_signal)
loop_signal = self.loop_detector.detect(l, l)
self._last_loop_signal = loop_signal
self.state.record_signal(loop_signal)
median_signal = self.median_detector.detect(l)
self._last_median_signal = median_signal
self.state.record_signal(median_signal)
if median_signal.triggered and self._can_correct():
corrected_logits = self.taste_amplifier.correct_logits(l, median_signal.severity)
if logits.dim() == 3:
logits = logits.clone()
logits[0, -1, :] = corrected_logits
elif logits.dim() == 2:
logits = logits.clone()
logits[0] = corrected_logits
self.state.record_correction("taste_amplifier", median_signal.severity)
self._use_correction_budget()
if isinstance(output, tuple):
return (logits,) + output[1:]
return logits
return output
def report(self) -> Dict:
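        """Build a structured reliability report: summary metrics, corrections by
        type, signal/trigger counts, calibration stats, and model info."""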
n = self.state.step
avg_r = sum(self.state.effective_r) / len(self.state.effective_r) if self.state.effective_r else 1.0
bl_r = sum(self.state.baseline_r) / len(self.state.baseline_r) if self.state.baseline_r else 0.95
n_steps = max(n, 1)
correction_counts = {}
for c in self.state.corrections:
correction_counts[c["corrector"]] = correction_counts.get(c["corrector"], 0) + 1
signal_counts = {}
trigger_counts = {}
for s in self.state.signals:
signal_counts[s["name"]] = signal_counts.get(s["name"], 0) + 1
if s["triggered"]:
trigger_counts[s["name"]] = trigger_counts.get(s["name"], 0) + 1
return {
"summary": {
"version": "0.3.0", "total_steps": n_steps,
"calibration_steps": self.config.calibration_steps,
"sensitivity_k": self.config.sensitivity_k,
"correction_scale": self.config.correction_scale,
"max_corrections_per_step": self.config.max_corrections_per_step,
"auto_tuned": self._auto_tuned,
"calibration_loaded": self._calibration_loaded,
"baseline_R": round(bl_r, 4), "aria_R": round(avg_r, 4),
"R_improvement": round(avg_r - bl_r, 4),
"baseline_P_success": f"{bl_r ** n_steps:.6e}",
"aria_P_success": f"{avg_r ** n_steps:.6e}",
"improvement_factor": round((avg_r ** n_steps) / max(bl_r ** n_steps, 1e-300), 2),
"total_corrections": len(self.state.corrections),
"elapsed_seconds": round(time.time() - self.state.start_time, 2),
},
"corrections_by_type": correction_counts,
"signals_detected": signal_counts,
"signals_triggered": trigger_counts,
"calibration_info": {
"compound_error": {"mean": self.compound_detector.calibration.mean,
"std": self.compound_detector.calibration.std,
"threshold": self.compound_detector.calibration.threshold},
"semantic_drift": {"mean": self.drift_detector.calibration.mean,
"std": self.drift_detector.calibration.std,
"threshold": self.drift_detector.calibration.threshold},
"median_trap_top1": {"mean": self.median_detector.top1_calibration.mean,
"std": self.median_detector.top1_calibration.std,
"threshold": self.median_detector.top1_calibration.threshold},
},
"model_info": self._model_info,
}
def report_text(self) -> str:
r = self.report()
s = r["summary"]
lines = [
"=" * 60, " ARIA v0.3 RELIABILITY REPORT", "=" * 60, "",
f" Steps monitored: {s['total_steps']}",
f" Correction scale: {s['correction_scale']}" + (" (auto-tuned)" if s['auto_tuned'] else ""),
f" Calibration loaded: {s['calibration_loaded']}",
f" Time elapsed: {s['elapsed_seconds']}s", "",
" RELIABILITY (R per step):",
f" Baseline (no ARIA): {s['baseline_R']}",
f" With ARIA: {s['aria_R']}",
f" Improvement: {'+' if s['R_improvement'] >= 0 else ''}{s['R_improvement']}", "",
" SUCCESS PROBABILITY (P_s = R^n):",
f" Baseline: {s['baseline_P_success']}",
f" With ARIA: {s['aria_P_success']}",
f" Improvement factor: {s['improvement_factor']}x", "",
f" Total corrections: {s['total_corrections']}",
]
if r["corrections_by_type"]:
lines += ["", " CORRECTIONS APPLIED:"]
for name, count in r["corrections_by_type"].items():
lines.append(f" {name}: {count}")
if r["signals_triggered"]:
lines += ["", " FAILURE MODES DETECTED:"]
for name, count in r["signals_triggered"].items():
total = r["signals_detected"].get(name, count)
lines.append(f" {name}: {count}/{total} ({count/max(total,1)*100:.1f}% of checks)")
lines += ["", "=" * 60]
return "\n".join(lines)
def __repr__(self):
status = "attached" if self._attached else "detached"
loaded = " profile-loaded" if self._calibration_loaded else ""
tuned = " auto-tuned" if self._auto_tuned else ""
return f"ARIA(status={status}, v=0.3, layers={len(self._hooks)} hooks, corrections={len(self.state.corrections)}{loaded}{tuned})"