"""
Dynamic mHC Β΅P Diagnostic Suite
================================
Implements all 6 diagnostics from the updated plan for the conditional
fixed-depth dynamic mHC Β΅P theorem.
Matches DeepSeek V4 inference code exactly:
- hc_split_sinkhorn: softmax init + K-1 row/col normalization iterations
- hc_pre: flatten β†’ RMS-normalize β†’ F.linear β†’ split β†’ sigmoid/Sinkhorn
- hc_post: post * x + comb @ residual
All diagnostics operate on synthetic data at multiple widths to test
width-transfer conditions.
Usage:
python mhc_diagnostics.py
Results saved to diagnostic_results.json
"""
import torch
import torch.nn.functional as F
import numpy as np
import json
from dataclasses import dataclass
# ───────────────────────────────────────────────────────
# §0 Pure-PyTorch mHC matching V4 inference/model.py
# ───────────────────────────────────────────────────────
def sinkhorn_knopp(logits: torch.Tensor, K: int, eps: float = 1e-6) -> torch.Tensor:
"""
Sinkhorn-Knopp normalization matching V4 kernel.py exactly.
V4 procedure:
1. comb = softmax(logits, dim=-1) + eps
    2. comb = comb / (comb.sum(dim=-2, keepdim=True) + eps)
3. repeat K-1 times:
comb = comb / (comb.sum(dim=-1, keepdim=True) + eps)
comb = comb / (comb.sum(dim=-2, keepdim=True) + eps)
"""
comb = F.softmax(logits, dim=-1) + eps
comb = comb / (comb.sum(dim=-2, keepdim=True) + eps)
for _ in range(K - 1):
comb = comb / (comb.sum(dim=-1, keepdim=True) + eps)
comb = comb / (comb.sum(dim=-2, keepdim=True) + eps)
return comb
# sinkhorn_knopp applies no detach/no_grad, so it already builds the autograd
# graph; keep a named alias for call sites where differentiability is the point.
sinkhorn_knopp_differentiable = sinkhorn_knopp
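# Quick usage sketch (illustrative helper, not part of the V4 code): for random
# logits, the worst row/column-sum deviation from 1 should shrink as K grows.
def _check_sinkhorn_convergence(n_s: int = 4, K_grid=(1, 5, 20)):
    logits = torch.randn(n_s, n_s)
    for K in K_grid:
        comb = sinkhorn_knopp(logits, K)
        row_err = (comb.sum(dim=-1) - 1.0).abs().max().item()
        col_err = (comb.sum(dim=-2) - 1.0).abs().max().item()
        print(f"K={K:3d}  row_err={row_err:.2e}  col_err={col_err:.2e}")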
def hc_split_sinkhorn_py(mixes, hc_scale, hc_base, hc_mult=4, sinkhorn_iters=20, eps=1e-6):
"""Pure PyTorch version of hc_split_sinkhorn from V4 kernel.py."""
    n = hc_mult
    # Feature layout per token: [pre gates (n) | post gates (n) | comb logits (n*n)]
    pre_raw = mixes[..., :n]
    post_raw = mixes[..., n:2*n]
    comb_raw = mixes[..., 2*n:]
pre_base = hc_base[:n]
post_base = hc_base[n:2*n]
comb_base = hc_base[2*n:]
pre = torch.sigmoid(pre_raw * hc_scale[0] + pre_base) + eps
post = 2.0 * torch.sigmoid(post_raw * hc_scale[1] + post_base)
comb_logits = comb_raw * hc_scale[2] + comb_base
comb_logits = comb_logits.unflatten(-1, (n, n))
comb = sinkhorn_knopp(comb_logits, sinkhorn_iters, eps)
return pre, post, comb
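# Illustrative shape check (the helper and the (B, T) batch shape are
# assumptions): confirms how the mixing features split into gates and logits.
def _check_hc_split_shapes(B: int = 2, T: int = 3, n_s: int = 4):
    mixes = torch.randn(B, T, (2 + n_s) * n_s)
    hc_scale = torch.ones(3) * 0.1
    hc_base = torch.zeros((2 + n_s) * n_s)
    pre, post, comb = hc_split_sinkhorn_py(mixes, hc_scale, hc_base, hc_mult=n_s)
    assert pre.shape == (B, T, n_s) and post.shape == (B, T, n_s)
    assert comb.shape == (B, T, n_s, n_s)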
def hc_pre_py(x, hc_fn, hc_scale, hc_base, hc_mult=4, sinkhorn_iters=20, eps=1e-6):
    """Matches Block.hc_pre from V4 model.py: RMS-normalize the flattened
    residual streams, project to mixing features, then gate the streams."""
    x_flat = x.flatten(2).float()                        # (B, T, n_s*d)
    rsqrt = torch.rsqrt(x_flat.square().mean(-1, keepdim=True) + eps)
    mixes = F.linear(x_flat, hc_fn) * rsqrt              # (B, T, mix_hc)
    pre, post, comb = hc_split_sinkhorn_py(mixes, hc_scale, hc_base, hc_mult, sinkhorn_iters, eps)
    y = torch.sum(pre.unsqueeze(-1) * x.float(), dim=2)  # gated sum over streams
    return y, post, comb
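# The header lists hc_post (post * x + comb @ residual) but the diagnostics
# only need hc_pre; a minimal sketch of that combine step, with shapes assumed
# here rather than taken from V4 model.py:
def hc_post_py(y, residual, post, comb):
    """Sketch of the hc_post combine step:
      y        (B, T, d)        block output
      residual (B, T, n_s, d)   residual streams
      post     (B, T, n_s)      output gates
      comb     (B, T, n_s, n_s) doubly-stochastic mixing matrix
    Returns updated residual streams of shape (B, T, n_s, d)."""
    mixed = torch.einsum('btij,btjd->btid', comb, residual.float())
    return post.unsqueeze(-1) * y.unsqueeze(2).float() + mixed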
# ───────────────────────────────────────────────────────
# §1 Configuration
# ───────────────────────────────────────────────────────
@dataclass
class MHCConfig:
dim: int = 256
n_s: int = 4
K: int = 20
eps: float = 1e-6
batch_size: int = 8
seq_len: int = 32
init_std: float = 0.02
@property
def hc_dim(self):
return self.n_s * self.dim
@property
def mix_hc(self):
return (2 + self.n_s) * self.n_s
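# Example of the sizes the defaults imply (dim=256, n_s=4): hc_dim = 4 * 256
# = 1024 input features and mix_hc = (2 + 4) * 4 = 24 mixing features per
# token: n_s pre gates, n_s post gates, and n_s^2 comb logits.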
def init_mhc_params(cfg, device='cpu'):
    mix_hc, hc_dim = cfg.mix_hc, cfg.hc_dim
hc_fn = torch.randn(mix_hc, hc_dim, device=device) * cfg.init_std
hc_base = torch.zeros(mix_hc, device=device)
hc_scale = torch.ones(3, device=device) * 0.1
return hc_fn, hc_scale, hc_base
def make_residual_state(cfg, device='cpu'):
return torch.randn(cfg.batch_size, cfg.seq_len, cfg.n_s, cfg.dim, device=device)
# ───────────────────────────────────────────────────────
# §2 Diagnostic 1: Finite-K Sinkhorn error
# ───────────────────────────────────────────────────────
def diagnostic_1_sinkhorn_error(comb):
    """Finite-K residual ε_K: worst absolute deviation of any row or column
    sum of comb from 1."""
row_sums = comb.sum(dim=-1)
row_violation = (row_sums - 1.0).abs().max(dim=-1)[0]
col_sums = comb.sum(dim=-2)
col_violation = (col_sums - 1.0).abs().max(dim=-1)[0]
eps_K = torch.maximum(row_violation, col_violation)
return {
'eps_K_mean': eps_K.mean().item(),
'eps_K_max': eps_K.max().item(),
}
# ───────────────────────────────────────────────────────
# Β§3 Diagnostic 2: Spectral norm
# ───────────────────────────────────────────────────────
def diagnostic_2_spectral_norm(comb):
    """Spectral norm of comb. An exactly doubly stochastic matrix has
    ||C||_2 = 1, so deviations from 1 track the finite-K error."""
flat = comb.reshape(-1, comb.shape[-2], comb.shape[-1])
sv = torch.linalg.svdvals(flat)
sigma_max = sv[:, 0]
sigma_min = sv[:, -1]
return {
'spectral_norm_mean': sigma_max.mean().item(),
'spectral_norm_max': sigma_max.max().item(),
'condition_number_mean': (sigma_max / (sigma_min + 1e-10)).mean().item(),
}
# ───────────────────────────────────────────────────────
# Β§4 Diagnostic 3: Quotient Jacobian
# ───────────────────────────────────────────────────────
def gauge_projector(n_s, device='cpu'):
"""Orthogonal projector onto G^βŠ₯ where G = {rΒ·1α΅€ + 1Β·cα΅€}."""
basis_vecs = []
for i in range(n_s):
v = torch.zeros(n_s, n_s, device=device)
v[i, :] = 1.0
basis_vecs.append(v.flatten())
for j in range(n_s):
v = torch.zeros(n_s, n_s, device=device)
v[:, j] = 1.0
basis_vecs.append(v.flatten())
G = torch.stack(basis_vecs, dim=0).T
Q, R = torch.linalg.qr(G.double())
diag_R = R.diag().abs()
rank = (diag_R > 1e-10).sum().item()
Q = Q[:, :rank].float()
P_G = Q @ Q.T
P_perp = torch.eye(n_s * n_s, device=device) - P_G
return P_perp
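# Self-check sketch (illustrative helper): P_perp should annihilate any gauge
# direction r·1ᵀ + 1·cᵀ, and its trace equals the quotient dimension (n_s-1)².
def _check_gauge_projector(n_s: int = 4):
    P_perp = gauge_projector(n_s)
    r, c = torch.randn(n_s), torch.randn(n_s)
    g = (r[:, None] + c[None, :]).flatten()  # a gauge direction, flattened
    assert (P_perp @ g).norm().item() < 1e-4
    assert abs(P_perp.trace().item() - (n_s - 1) ** 2) < 1e-3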
def diagnostic_3_quotient_jacobian(n_s=4, K_values=(1, 2, 5, 10, 20, 50), n_samples=100, device='cpu'):
    """Singular values of the Sinkhorn Jacobian dC/dZ restricted to the
    gauge-orthogonal (quotient) subspace, across iteration counts K."""
P_perp = gauge_projector(n_s, device)
d_perp = (n_s - 1) ** 2
    # eigh sorts eigenvalues ascending; the last (n_s-1)^2 eigenvectors have
    # eigenvalue 1 and form an orthonormal basis of the quotient subspace.
    eigvals, eigvecs = torch.linalg.eigh(P_perp)
    basis = eigvecs[:, -d_perp:]
results = {}
for K in K_values:
sigma_maxs, sigma_mins, kappas = [], [], []
for _ in range(n_samples):
Z = torch.randn(n_s, n_s, device=device, requires_grad=True)
C = sinkhorn_knopp_differentiable(Z, K)
c_flat = C.flatten()
J = torch.zeros(n_s*n_s, n_s*n_s, device=device)
for idx in range(n_s*n_s):
if Z.grad is not None:
Z.grad.zero_()
c_flat[idx].backward(retain_graph=True)
J[idx] = Z.grad.flatten()
J_quot = basis.T @ J @ basis
sv = torch.linalg.svdvals(J_quot)
sigma_maxs.append(sv[0].item())
sigma_mins.append(sv[-1].item())
kappas.append((sv[0] / (sv[-1] + 1e-10)).item())
results[f'K={K}'] = {
'sigma_max_mean': np.mean(sigma_maxs),
'sigma_min_mean': np.mean(sigma_mins),
'kappa_mean': np.mean(kappas),
}
return results
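# Terser alternative sketch (not used by the suite): torch.autograd.functional
# .jacobian computes the same matrix without the per-entry backward loop.
def _quotient_jacobian_fast(Z, K, basis):
    J = torch.autograd.functional.jacobian(
        lambda z: sinkhorn_knopp_differentiable(z, K).flatten(), Z
    ).reshape(Z.numel(), Z.numel())
    return basis.T @ J @ basis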
# ───────────────────────────────────────────────────────
# Β§5 Diagnostic 4: Dynamic sensitivity
# ───────────────────────────────────────────────────────
def diagnostic_4_dynamic_sensitivity(cfg, n_samples=50, device='cpu'):
    """Estimate ||D_x comb||·||x||, ||D_x pre||·||x||, ||D_x post||·||x|| on
    random inputs: the input-sensitivity terms of the dynamic mHC map."""
n_s, d, K, eps = cfg.n_s, cfg.dim, cfg.K, cfg.eps
hc_dim = n_s * d
mix_hc = cfg.mix_hc
dc_x_norms, dp_x_norms, dq_x_norms = [], [], []
for _ in range(n_samples):
hc_fn = torch.randn(mix_hc, hc_dim, device=device) * cfg.init_std
hc_scale = torch.ones(3, device=device) * 0.1
hc_base = torch.zeros(mix_hc, device=device)
x = torch.randn(1, 1, n_s, d, device=device, requires_grad=True)
x_norm = x.norm().item()
x_flat = x.flatten(2).float()
rsqrt = torch.rsqrt(x_flat.square().mean(-1, keepdim=True) + eps)
mixes = F.linear(x_flat, hc_fn) * rsqrt
# Comb Jacobian
comb_raw = mixes[..., 2*n_s:]
comb_logits = (comb_raw * hc_scale[2] + hc_base[2*n_s:]).unflatten(-1, (n_s, n_s))
comb = sinkhorn_knopp_differentiable(comb_logits, K, eps)
c_flat = comb.flatten()
J_rows = []
for idx in range(n_s*n_s):
if x.grad is not None:
x.grad.zero_()
c_flat[idx].backward(retain_graph=True)
J_rows.append(x.grad.flatten().clone())
J_C = torch.stack(J_rows, dim=0)
dc_x_norms.append(torch.linalg.norm(J_C, ord=2).item() * x_norm)
# Pre Jacobian
pre = torch.sigmoid(mixes[..., :n_s] * hc_scale[0] + hc_base[:n_s]) + eps
p_flat = pre.flatten()
jac_p = []
for idx in range(n_s):
if x.grad is not None:
x.grad.zero_()
p_flat[idx].backward(retain_graph=True)
jac_p.append(x.grad.flatten().clone())
J_p = torch.stack(jac_p, dim=0)
dp_x_norms.append(torch.linalg.norm(J_p, ord=2).item() * x_norm)
# Post Jacobian
post = 2.0 * torch.sigmoid(mixes[..., n_s:2*n_s] * hc_scale[1] + hc_base[n_s:2*n_s])
q_flat = post.flatten()
jac_q = []
for idx in range(n_s):
if x.grad is not None:
x.grad.zero_()
q_flat[idx].backward(retain_graph=True)
jac_q.append(x.grad.flatten().clone())
J_q = torch.stack(jac_q, dim=0)
dq_x_norms.append(torch.linalg.norm(J_q, ord=2).item() * x_norm)
return {
'dc_x_mean': np.mean(dc_x_norms),
'dc_x_std': np.std(dc_x_norms),
'dp_x_mean': np.mean(dp_x_norms),
'dp_x_std': np.std(dp_x_norms),
'dq_x_mean': np.mean(dq_x_norms),
'dq_x_std': np.std(dq_x_norms),
}
# ───────────────────────────────────────────────────────
# Β§6 Diagnostic 5: Logit update scale
# ───────────────────────────────────────────────────────
def diagnostic_5_logit_update(cfg, n_samples=50, lr=1e-3, device='cpu'):
    """Take one SGD step on a toy loss and measure how much of the induced
    logit update ΔZ lies in the gauge-orthogonal subspace G^⊥."""
n_s, d, K, eps = cfg.n_s, cfg.dim, cfg.K, cfg.eps
hc_dim = n_s * d
P_perp = gauge_projector(n_s, device)
proj_norms, total_norms = [], []
for _ in range(n_samples):
W_C = torch.nn.Parameter(torch.randn(n_s*n_s, hc_dim, device=device) * cfg.init_std)
s_C = torch.nn.Parameter(torch.tensor([0.1], device=device))
b_C = torch.nn.Parameter(torch.zeros(n_s*n_s, device=device))
x = torch.randn(1, 1, n_s, d, device=device)
x_flat = x.flatten(2).float()
rsqrt = torch.rsqrt(x_flat.square().mean(-1, keepdim=True) + eps)
x_hat = (x_flat * rsqrt).squeeze(0).squeeze(0)
Z = s_C * (W_C @ x_hat) + b_C
C = sinkhorn_knopp_differentiable(Z.unflatten(0, (n_s, n_s)), K, eps)
loss = ((C - torch.eye(n_s, device=device))**2).sum()
loss.backward()
with torch.no_grad():
W_new = W_C - lr * W_C.grad
s_new = s_C - lr * s_C.grad
b_new = b_C - lr * b_C.grad
Z_new = s_new * (W_new @ x_hat) + b_new
delta_Z = Z_new - Z.detach()
delta_Z_perp = P_perp @ delta_Z
proj_norms.append(delta_Z_perp.norm().item())
total_norms.append(delta_Z.norm().item())
return {
'proj_perp_norm_mean': np.mean(proj_norms),
'total_norm_mean': np.mean(total_norms),
'perp_ratio_mean': np.mean([p/(t+1e-10) for p,t in zip(proj_norms, total_norms)]),
}
# ───────────────────────────────────────────────────────
# Β§7 Diagnostic 6: Gate statistics
# ───────────────────────────────────────────────────────
def diagnostic_6_gate_stats(pre, post):
    """Summary statistics of the gates: mean and L1 norm of pre, mean and
    L∞ norm of post."""
pre_flat = pre.flatten(0, -2)
post_flat = post.flatten(0, -2)
return {
'pre_mean': pre_flat.mean().item(),
'pre_l1_mean': pre_flat.sum(-1).mean().item(),
'post_mean': post_flat.mean().item(),
'post_linf_mean': post_flat.abs().max(-1)[0].mean().item(),
}
# ───────────────────────────────────────────────────────
# §8 Main runner
# ───────────────────────────────────────────────────────
def run_full_diagnostics(widths=(64, 128, 256, 512, 1024), n_s=4, K=20, device='cpu'):
    """Run all six diagnostics at each width. Diagnostic 3 is width-independent,
    so it is computed once and shared across widths."""
results = {}
d3 = None
for d in widths:
print(f"\nWidth d = {d}")
cfg = MHCConfig(dim=d, n_s=n_s, K=K, batch_size=4, seq_len=16)
hc_fn, hc_scale, hc_base = init_mhc_params(cfg, device)
x = make_residual_state(cfg, device)
        y, post, comb = hc_pre_py(x, hc_fn, hc_scale, hc_base, hc_mult=cfg.n_s, sinkhorn_iters=cfg.K, eps=cfg.eps)
d1 = diagnostic_1_sinkhorn_error(comb)
d2 = diagnostic_2_spectral_norm(comb)
if d3 is None:
d3 = diagnostic_3_quotient_jacobian(n_s, [1,2,5,10,20,50], 100, device)
d4 = diagnostic_4_dynamic_sensitivity(cfg, 30, device)
d5 = diagnostic_5_logit_update(cfg, 30, device=device)
x_flat = x.flatten(2).float()
rsqrt_val = torch.rsqrt(x_flat.square().mean(-1, keepdim=True) + cfg.eps)
mixes_val = F.linear(x_flat, hc_fn) * rsqrt_val
        pre_val, post_val, _ = hc_split_sinkhorn_py(mixes_val, hc_scale, hc_base, hc_mult=cfg.n_s, sinkhorn_iters=cfg.K, eps=cfg.eps)
d6 = diagnostic_6_gate_stats(pre_val, post_val)
results[d] = {
'sinkhorn_error': d1, 'spectral_norm': d2,
'quotient_jacobian': d3, 'dynamic_sensitivity': d4,
'logit_update': d5, 'gate_stats': d6,
}
print(f" Ξ΅_K={d1['eps_K_mean']:.1e}, ||C||_2={d2['spectral_norm_mean']:.6f}")
print(f" ||DC||Β·||x||={d4['dc_x_mean']:.4f}, ||Dp||Β·||x||={d4['dp_x_mean']:.4f}")
return results
if __name__ == '__main__':
print("Dynamic mHC Β΅P Diagnostics")
print("=" * 60)
results = run_full_diagnostics()
# Scaling analysis
widths = sorted(results.keys())
dc_means = [results[d]['dynamic_sensitivity']['dc_x_mean'] for d in widths]
log_d = np.log(widths)
A = np.vstack([log_d, np.ones_like(log_d)]).T
alpha = np.linalg.lstsq(A, np.log(dc_means), rcond=None)[0][0]
print(f"\n||DC(x)||Β·||x|| scales as d^{alpha:.3f}")
print(f"\nSummary:")
print(f"{'Width':>8} {'||DC||Β·||x||':>14} {'||Dp||Β·||x||':>14} {'||C||_2':>10} {'Ξ΅_K':>10}")
for d in widths:
r = results[d]
print(f"{d:>8} {r['dynamic_sensitivity']['dc_x_mean']:>14.4f} "
f"{r['dynamic_sensitivity']['dp_x_mean']:>14.4f} "
f"{r['spectral_norm']['spectral_norm_mean']:>10.6f} "
f"{r['sinkhorn_error']['eps_K_mean']:>10.1e}")