"""
diagnose_model.py — "Predicting How Transformers Attend" Diagnostic Tool
=========================================================================
Single-command characterization of any causal LM via power-law attention decay.
Measures:
γ (gamma) — attention decay exponent A(d) ∝ d^{-γ}
T_attn = 1/γ — attention temperature
Phase — A (deconfined / RoPE), B (confined / AbsPE), or Hagedorn crossover (ALiBi phase C is not emitted by this script)
Z, U, S, F — thermodynamic potentials (partition function, energy, entropy, free energy)
C_V, χ — heat capacity, susceptibility
D_90 — context depth capturing 90% of Z (KV compression estimate)
ΔH_90 — holographic quality loss at D_90
KL_grammar — attention grammar anomaly (deviation from power-law prior)
θ_eff — effective RoPE base (Padé diagnostic)
γ_pred — theoretical Padé prediction γ_Padé(θ, T) (supersedes the shorthand C/ln(θ), C = ln(10000) ≈ 9.2103)
Usage:
python diagnose_model.py --model EleutherAI/pythia-70m
python diagnose_model.py --model meta-llama/Meta-Llama-3-8B --local /path/to/weights --load_in_4bit
python diagnose_model.py --model Qwen/Qwen2.5-7B --theta 1000000 --N 1000
python diagnose_model.py --model EleutherAI/pythia-70m --fast # quick mode, 3 distances
Output:
Prints diagnostic table to stdout.
Saves JSON to ./diagnose_results/{model_short}.json
"""
import sys
sys.stdout.reconfigure(encoding='utf-8')
import argparse
import json
import math
import random
import time
from pathlib import Path
import numpy as np
# ── Constants ──────────────────────────────────────────────────────────────────
C_THEORY = math.log(10000) # 9.2103 — γ × ln(θ) = C for standard RoPE
DISTANCES_FULL = [10, 20, 30, 50, 100, 200, 500, 1000, 2000]
DISTANCES_FAST = [10, 50, 200, 1000]
N_PROMPTS = 30 # prompts per distance (fast mode)
N_PROMPTS_FULL = 80 # prompts per distance (full mode)
SEEDS = [42, 123, 7] # only SEEDS[0] is currently used per run
THETA_KNOWN = {
"EleutherAI/pythia-14m": 10_000,
"EleutherAI/pythia-31m": 10_000,
"EleutherAI/pythia-70m": 10_000,
"EleutherAI/pythia-160m": 10_000,
"EleutherAI/pythia-410m": 10_000,
"EleutherAI/pythia-1b": 10_000,
"EleutherAI/pythia-1.4b": 10_000,
"EleutherAI/pythia-2.8b": 10_000,
"mistralai/Mistral-7B-v0.1": 10_000,
"tiiuae/falcon-7b": 10_000,
"microsoft/phi-2": 10_000,
"meta-llama/Llama-2-7b-hf": 10_000,
"google/gemma-2-9b-it": 10_000,
"EleutherAI/gpt-j-6B": 10_000,
"meta-llama/Meta-Llama-3-8B": 500_000,
"Qwen/Qwen2.5-7B": 1_000_000,
"mistralai/Mistral-Nemo-Instruct-2407": 1_000_000,
"codellama/CodeLlama-13b-Instruct-hf": 1_000_000,
}
OUTPUT_DIR = Path("./diagnose_results")
# ── Thermodynamic functions ────────────────────────────────────────────────────
# Euler-Mascheroni constant — needed for accurate H_N approximation at γ=1.
EULER_GAMMA = 0.5772156649015329
def partition_Z(gamma: float, N: int) -> float:
"""Z(γ, N) = sum_{d=1}^N d^{-γ}.
γ=1: H_N ~ log N + γ_E + 1/(2N) − ... [Euler-Mascheroni asymptotic]
γ≠1: integral approximation + d=1 boundary.
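    Example: γ=1, N=1000 → H_1000 ≈ ln(1000) + 0.5772 ≈ 7.485.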
"""
if abs(gamma - 1.0) < 1e-5:
return math.log(N) + EULER_GAMMA # was math.log(N+0.5), missing γ_E
return (N ** (1 - gamma) - 1) / (1 - gamma) + 1
def mean_log_d(gamma: float, N: int) -> float:
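    """U(γ, N) = E[log d] under p(d) ∝ d^{-γ}, via the continuum integral
    ∫_1^N x^{-γ} log x dx (closed form from integration by parts)."""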
Z = partition_Z(gamma, N)
if Z <= 0:
return 0.0
if abs(gamma - 1.0) < 1e-5:
integral = math.log(N) ** 2 / 2
else:
g1 = 1.0 - gamma
integral = N ** g1 * (math.log(N) / g1 - 1 / g1 ** 2) + 1 / g1 ** 2
return integral / Z
def entropy_S(gamma: float, N: int) -> float:
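    """Gibbs entropy of p(d) = d^{-γ}/Z: S = log Z + γ·E[log d] (with β ≡ γ)."""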
return math.log(partition_Z(gamma, N)) + gamma * mean_log_d(gamma, N)
def free_energy_F(gamma: float, N: int) -> float:
"""Helmholtz free energy: F = -T·log(Z) = -log(Z)/γ (T_attn = 1/γ).
Was: -log(Z) [β·F = log-partition convention; ambiguous when reported as F].
Now: -log(Z)/γ [physical F, consistent with U = -∂(log Z)/∂γ and S = (U − F)/T].
"""
Z = max(partition_Z(gamma, N), 1e-30)
return -math.log(Z) / max(gamma, 1e-9)
def heat_capacity_Cv(gamma: float, N: int, delta: float = 1e-4) -> float:
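    """C_V = ∂U/∂T with T = 1/γ, hence C_V = −γ²·∂U/∂γ (central difference in γ)."""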
if gamma <= delta or gamma >= 20:
return float("nan")
dU = (mean_log_d(gamma + delta, N) - mean_log_d(gamma - delta, N)) / (2 * delta)
return -gamma ** 2 * dU
def D_f_closed(gamma: float, f: float, N: int) -> int:
"""KV compression window — DISCRETE truth (exact for the sum).
Smallest D such that ∑_{d=1}^D d^{-γ} / ∑_{d=1}^N d^{-γ} ≥ f.
The paper's "exact continuous formula"
D_f = [(1−f) + f·N^(1−γ)]^{1/(1−γ)} (and the γ=1 limit N^f)
is a CONTINUUM INTEGRAL APPROXIMATION that diverges from the discrete
sum by 5–50% in Phase B (γ>1), where the agent serves users.
Since N is bounded by context window (≤ ~10⁶), direct summation is
O(N) and fast (<10 ms). We use it for accuracy.
"""
if N <= 0:
return 1
    if gamma <= 0.0:
        return N # ill-defined for γ ≤ 0; safe upper bound
# Direct discrete cumulative
weights = [d ** (-gamma) for d in range(1, N + 1)]
total = sum(weights)
if total <= 0 or not math.isfinite(total):
# Fall back to continuum closed form (rare numerical edge case)
return _D_f_closed_continuum(gamma, f, N)
target = f * total
cum = 0.0
for d, w in enumerate(weights, start=1):
cum += w
if cum >= target:
return d
return N
def _D_f_closed_continuum(gamma: float, f: float, N: int) -> int:
"""Continuum closed form (paper Theorem 7.1) — asymptotic, kept as fallback."""
if abs(gamma - 1.0) < 1e-9:
return max(1, min(N, int(round(N ** f))))
one_minus_g = 1.0 - gamma
base = (1 - f) + f * (N ** one_minus_g)
if base <= 0:
return 1
try:
d_f = base ** (1.0 / one_minus_g)
except (OverflowError, ValueError):
return N
if not math.isfinite(d_f):
return N
return max(1, min(N, int(round(d_f))))
def delta_H(theta: float, Df: int, N: int) -> float:
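    """Holographic quality loss ΔH = log[(θ + D_f/√2) / (θ + N/√2)] from
    truncating the KV cache at D_f instead of keeping all N positions."""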
sqrt2 = math.sqrt(2)
return math.log((theta + Df / sqrt2) / (theta + N / sqrt2))
def theta_eff_pade(theta: float, T: float) -> float:
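    """Effective RoPE base from the Padé diagnostic: θ_eff = θ + T/√2."""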
return theta + T / math.sqrt(2)
def phase_label(gamma: float) -> str:
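    """Map γ to a phase; the ±0.05 band around γ = 1 is the Hagedorn crossover."""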
if gamma < 0.95:
return "A — deconfined (RoPE/long)"
if gamma > 1.05:
return "B — confined (AbsPE/short)"
return "Hagedorn (crossover γ≈1)"
def kl_divergence(p: np.ndarray, q: np.ndarray) -> float:
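    """KL(p‖q) = Σ p·log(p/q), after normalizing both inputs to probability vectors."""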
p = p / p.sum()
q = q / q.sum()
eps = 1e-12
mask = p > eps
return float(np.sum(p[mask] * np.log(p[mask] / (q[mask] + eps))))
# ── Attention measurement ──────────────────────────────────────────────────────
def set_seed(seed: int):
random.seed(seed)
np.random.seed(seed)
try:
import torch
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
except ImportError:
pass
def measure_attn_distance(model, tokenizer, distance: int, n_prompts: int,
seed: int, device: str, vocab_high: int) -> float:
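    """Mean attention weight from the final position to the token `distance`
    steps back, averaged over all heads and layers and over n_prompts
    random-token prompts (tokens drawn uniformly from [1000, vocab_high])."""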
import torch
set_seed(seed)
rng = random.Random(seed)
seq_len = distance + 50
target_pos = seq_len - distance - 1
last_pos = seq_len - 1
vocab_low = 1000
attn_values = []
model.eval()
with torch.no_grad():
for _ in range(n_prompts):
tokens = [rng.randint(vocab_low, vocab_high) for _ in range(seq_len)]
input_ids = torch.tensor([tokens], dtype=torch.long).to(device)
try:
out = model(input_ids, output_attentions=True, return_dict=True)
except Exception:
continue
if out.attentions is None:
raise RuntimeError(
"output_attentions returned None. "
"Try loading with attn_implementation='eager'."
)
vals = []
for layer_attn in out.attentions:
w = layer_attn[0, :, last_pos, target_pos].float().cpu().numpy()
finite = w[np.isfinite(w)]
if len(finite):
vals.append(float(np.mean(finite)))
if vals:
attn_values.append(float(np.mean(vals)))
return float(np.mean(attn_values)) if attn_values else float("nan")
def fit_power_law(distances: list, means: list) -> dict:
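    """Least-squares fit of log A(d) = log_A − γ·log d over finite, positive
    means; returns γ, log_A, and R² of the log-log regression."""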
d = np.array(distances, dtype=float)
m = np.array(means, dtype=float)
mask = np.isfinite(m) & (m > 0)
if mask.sum() < 2:
return {"gamma": float("nan"), "log_A": 0.0, "R2": 0.0}
log_d = np.log(d[mask])
log_m = np.log(m[mask])
X = np.stack([np.ones(mask.sum()), -log_d], axis=1)
coeffs, *_ = np.linalg.lstsq(X, log_m, rcond=None)
log_A, gamma = float(coeffs[0]), float(coeffs[1])
pred = log_A - gamma * log_d
ss_res = float(np.sum((log_m - pred) ** 2))
ss_tot = float(np.sum((log_m - np.mean(log_m)) ** 2))
R2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0
return {"gamma": gamma, "log_A": log_A, "R2": round(R2, 6)}
# ── Attention Grammar anomaly ──────────────────────────────────────────────────
def grammar_kl(attn_by_d: dict, gamma: float, log_A: float) -> float:
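    """KL between the observed attention profile and the fitted power-law
    prior A·d^{-γ}, both restricted to the measured distance grid."""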
dists = sorted(attn_by_d.keys())
p_obs = np.array([attn_by_d[d] for d in dists], dtype=float)
p_obs = np.maximum(p_obs, 1e-30)
p_obs /= p_obs.sum()
A = math.exp(log_A)
p_prior = np.array([A * d ** (-gamma) for d in dists], dtype=float)
p_prior = np.maximum(p_prior, 1e-30)
p_prior /= p_prior.sum()
return kl_divergence(p_obs, p_prior)
# ── Main diagnostic ───────────────────────────────────────────────────────────
def run_diagnostic(args) -> dict:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
model_name = args.model
theta_nom = args.theta or THETA_KNOWN.get(model_name, 10_000)
print(f"\n{'='*65}")
print(f"TRANSFORMER THERMODYNAMICS DIAGNOSTIC")
print(f"{'='*65}")
print(f" Model : {model_name}")
print(f" theta_nom : {theta_nom:,}")
print(f" N : {args.N}")
print(f" Mode : {'fast' if args.fast else 'full'}")
print()
# ── Load model ──────────────────────────────────────────────────────
local_path = args.local or model_name
print(f"Loading model from: {local_path} ...")
t0 = time.time()
load_kwargs = dict(
trust_remote_code=True,
attn_implementation="eager",
)
device = "cuda" if (not args.cpu and torch.cuda.is_available()) else "cpu"
if args.load_in_4bit and device == "cuda":
try:
from transformers import BitsAndBytesConfig
load_kwargs["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
)
load_kwargs["device_map"] = "auto"
except ImportError:
print(" [warn] bitsandbytes not available; loading in float32")
elif device == "cpu":
load_kwargs["dtype"] = torch.float32
tokenizer = AutoTokenizer.from_pretrained(local_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(local_path, **load_kwargs)
if device == "cpu":
model = model.to("cpu")
model.eval()
print(f" Loaded in {time.time()-t0:.1f}s device={device}")
vocab_high = min(tokenizer.vocab_size - 1, 49_000)
distances = DISTANCES_FAST if args.fast else DISTANCES_FULL
n_prompts = N_PROMPTS if args.fast else N_PROMPTS_FULL
N = args.N
# ── Measure attention by distance ────────────────────────────────────
print(f"\nMeasuring attention decay at {len(distances)} distances × {n_prompts} prompts ...")
attn_by_d = {}
for dist in distances:
if dist > N:
continue
t1 = time.time()
mean_val = measure_attn_distance(
model, tokenizer, dist, n_prompts, SEEDS[0], device, vocab_high
)
attn_by_d[dist] = mean_val
print(f" d={dist:5d} attn={mean_val:.6f} ({time.time()-t1:.1f}s)")
# ── Fit power law ────────────────────────────────────────────────────
valid_d = [d for d, v in attn_by_d.items() if math.isfinite(v) and v > 0]
valid_v = [attn_by_d[d] for d in valid_d]
fit = fit_power_law(valid_d, valid_v)
gamma = fit["gamma"]
log_A = fit["log_A"]
R2 = fit["R2"]
    if not math.isfinite(gamma) or gamma <= 0:
        print("\n[ERROR] Power-law fit failed (γ non-finite or ≤ 0): too few valid distances or no measurable decay.")
return {}
# ── Thermodynamics ───────────────────────────────────────────────────
Z = partition_Z(gamma, N)
U = mean_log_d(gamma, N)
S = entropy_S(gamma, N)
F = free_energy_F(gamma, N)
Cv = heat_capacity_Cv(gamma, N)
chi = 1.0 / abs(gamma - 1.0) if abs(gamma - 1.0) > 1e-4 else 1e6
xi = 1.0 / abs(math.log(gamma)) if abs(math.log(gamma)) > 1e-10 else 1e6
T_attn = 1.0 / gamma
D90 = D_f_closed(gamma, 0.90, N)
dH90 = delta_H(theta_nom, D90, N)
theta_eff = theta_eff_pade(theta_nom, float(N))
# Theoretical γ prediction — γ_Padé(θ, T_eval) (paper §3.3, supersedes
# the earlier shorthand γ ≈ C/lnθ which assumed T = 10000).
if theta_nom > 0:
        T_for_pred = max(attn_by_d) if attn_by_d else N # largest actually-measured distance
z_sqrt2 = T_for_pred * math.sqrt(2)
gamma_pred = (2 * theta_nom - z_sqrt2) / (2 * theta_nom + z_sqrt2)
else:
gamma_pred = None
# Attention grammar KL
kl_ag = grammar_kl(attn_by_d, gamma, log_A)
# Phase
phase = phase_label(gamma)
# ── Report ───────────────────────────────────────────────────────────
print(f"\n{'='*65}")
print(f"RESULTS")
print(f"{'='*65}")
print(f" γ (gamma) = {gamma:.4f} [R²={R2:.4f}]")
if gamma_pred is not None:
delta_g = gamma - gamma_pred
print(f" γ_Padé(θ,T) = {gamma_pred:.4f} Δγ = {delta_g:+.4f}")
print(f" Phase : {phase}")
print(f" T_attn = 1/γ = {T_attn:.4f}")
print()
print(f" Thermodynamics (N={N}):")
print(f" Z (partition) = {Z:.4f}")
print(f" U = E[log d] = {U:.4f}")
print(f" S (entropy) = {S:.4f}")
print(f" F (free ener) = {F:.4f}")
cv_str = f"{Cv:.4f}" if math.isfinite(Cv) else "N/A"
print(f" C_V (heat cap)= {cv_str}")
chi_str = f"{chi:.2f}" if chi < 1e5 else "∞ (near Hagedorn)"
print(f" χ (suscept.) = {chi_str}")
xi_str = f"{xi:.2f}" if xi < 1e5 else "∞"
print(f" ξ (corr. len) = {xi_str}")
print()
print(f" KV Compression (f=0.90):")
print(f" D_90 = {D90} tokens ({D90/N*100:.1f}% of N={N})")
print(f" dH_90 = {dH90:.4f} nats")
print()
print(f" RoPE Diagnostic:")
print(f" theta_nom = {theta_nom:,}")
print(f" theta_eff_Pade = {theta_eff:.1f}")
print()
print(f" Attention Grammar:")
print(f" KL(obs||prior) = {kl_ag:.4f} ", end="")
if kl_ag > 0.05:
print("[HIGH — non-power-law circuits present]")
elif kl_ag > 0.01:
print("[MODERATE — some circuit deviation]")
else:
print("[LOW — pure positional attention]")
print(f"\n γ interpretation:")
if gamma < 0.7:
print(f" Very long-range attention (large θ, LLaMA-3/Qwen2.5 class)")
elif gamma < 0.95:
print(f" Long-range attention (standard RoPE, Phase A)")
elif gamma < 1.05:
print(f" Hagedorn crossover — attention at phase boundary")
elif gamma < 1.3:
print(f" Short-range attention (AbsPE or short context training)")
else:
print(f" Highly local attention (possible SWA or very short context)")
# ── Save ─────────────────────────────────────────────────────────────
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
short = model_name.replace("/", "--")
result = {
"model": model_name,
"theta_nom": theta_nom,
"N": N,
"fast_mode": args.fast,
"fit_power_law": fit,
"gamma": gamma,
"gamma_pred": gamma_pred,
"delta_gamma": (gamma - gamma_pred) if gamma_pred else None,
"phase": phase,
"T_attn": T_attn,
"Z": Z, "U": U, "S": S, "F": F, "Cv": Cv,
"chi": chi, "xi": xi,
"D90": D90,
"D90_frac": D90 / N,
"delta_H_90": dH90,
"theta_eff_pade": theta_eff,
"kl_grammar": kl_ag,
"attn_by_distance": {str(d): v for d, v in attn_by_d.items()},
}
out_path = OUTPUT_DIR / f"{short}.json"
out_path.write_text(json.dumps(result, indent=2, default=float), encoding="utf-8")
print(f"\n Saved: {out_path}")
print(f"{'='*65}\n")
return result
def main():
parser = argparse.ArgumentParser(
description="Predicting How Transformers Attend — diagnostic for any causal LM"
)
parser.add_argument("--model", required=True,
help="HuggingFace model ID (e.g. EleutherAI/pythia-70m)")
parser.add_argument("--local", default=None,
help="Local path to model weights (if not downloading)")
parser.add_argument("--theta", type=int, default=None,
help="RoPE θ (auto-detected for known models)")
parser.add_argument("--N", type=int, default=2000,
help="Context length N for thermodynamic calculations (default 2000)")
parser.add_argument("--fast", action="store_true",
help="Fast mode: fewer distances and prompts (~5 min on CPU)")
parser.add_argument("--load_in_4bit", action="store_true",
help="Load model in 4-bit quantization (requires bitsandbytes)")
parser.add_argument("--cpu", action="store_true",
help="Force CPU even if CUDA available")
args = parser.parse_args()
try:
run_diagnostic(args)
except KeyboardInterrupt:
print("\n[interrupted]")
except Exception as e:
print(f"\n[ERROR] {e}")
raise
if __name__ == "__main__":
main()