| """ |
| Dynamic mHC Β΅P Diagnostic Suite |
| ================================ |
| Implements all 6 diagnostics from the updated plan for the conditional |
| fixed-depth dynamic mHC Β΅P theorem. |
| |
| Matches DeepSeek V4 inference code exactly: |
| - hc_split_sinkhorn: softmax init + K-1 row/col normalization iterations |
| - hc_pre: flatten β RMS-normalize β F.linear β split β sigmoid/Sinkhorn |
| - hc_post: post * x + comb @ residual |
| |
| All diagnostics operate on synthetic data at multiple widths to test |
| width-transfer conditions. |
| |
| Usage: |
| python mhc_diagnostics.py |
| |
| Results saved to diagnostic_results.json |
| """ |
|
|
import torch
import torch.nn.functional as F
import numpy as np
import json
from dataclasses import dataclass
|
|
|
|
|
|
def sinkhorn_knopp(logits: torch.Tensor, K: int, eps: float = 1e-6) -> torch.Tensor:
    """
    Sinkhorn-Knopp normalization matching V4 kernel.py exactly.

    V4 procedure:
      1. comb = softmax(logits, dim=-1) + eps
      2. comb = comb / (comb.sum(dim=-2, keepdim=True) + eps)
      3. repeat K-1 times:
           comb = comb / (comb.sum(dim=-1, keepdim=True) + eps)
           comb = comb / (comb.sum(dim=-2, keepdim=True) + eps)
    """
    comb = F.softmax(logits, dim=-1) + eps
    comb = comb / (comb.sum(dim=-2, keepdim=True) + eps)
    for _ in range(K - 1):
        comb = comb / (comb.sum(dim=-1, keepdim=True) + eps)
        comb = comb / (comb.sum(dim=-2, keepdim=True) + eps)
    return comb
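

# Quick sanity check (illustrative helper, not part of the V4 code): after K
# alternating normalizations the output should be nearly doubly stochastic,
# which is exactly what diagnostic 1 quantifies below.
def _demo_sinkhorn():
    C = sinkhorn_knopp(torch.randn(4, 4), K=20)
    print("row-sum error:", (C.sum(dim=-1) - 1).abs().max().item())
    print("col-sum error:", (C.sum(dim=-2) - 1).abs().max().item())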
|
|
|
|
def sinkhorn_knopp_differentiable(logits: torch.Tensor, K: int, eps: float = 1e-6):
    """Alias of sinkhorn_knopp: the ops above are already differentiable, so
    autograd builds the graph whenever logits.requires_grad is True."""
    return sinkhorn_knopp(logits, K, eps)
|
|
|
|
def hc_split_sinkhorn_py(mixes, hc_scale, hc_base, hc_mult=4, sinkhorn_iters=20, eps=1e-6):
    """Pure PyTorch version of hc_split_sinkhorn from V4 kernel.py."""
    n = hc_mult
    pre_raw = mixes[..., :n]
    post_raw = mixes[..., n:2*n]
    comb_raw = mixes[..., 2*n:]

    pre_base = hc_base[:n]
    post_base = hc_base[n:2*n]
    comb_base = hc_base[2*n:]

    pre = torch.sigmoid(pre_raw * hc_scale[0] + pre_base) + eps
    post = 2.0 * torch.sigmoid(post_raw * hc_scale[1] + post_base)

    comb_logits = comb_raw * hc_scale[2] + comb_base
    comb_logits = comb_logits.unflatten(-1, (n, n))
    comb = sinkhorn_knopp(comb_logits, sinkhorn_iters, eps)

    return pre, post, comb
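

# Illustrative shape walkthrough (assumed toy sizes, not part of the V4 code):
# with n_s = 4 streams each position carries (2 + 4) * 4 = 24 raw mix values,
# split into 4 pre gates, 4 post gates, and a 4x4 comb matrix.
def _demo_split(n_s=4):
    mixes = torch.randn(2, 3, (2 + n_s) * n_s)  # (batch, seq, mix_hc)
    hc_scale = torch.ones(3) * 0.1
    hc_base = torch.zeros((2 + n_s) * n_s)
    pre, post, comb = hc_split_sinkhorn_py(mixes, hc_scale, hc_base, hc_mult=n_s)
    print(pre.shape, post.shape, comb.shape)  # (2, 3, 4) (2, 3, 4) (2, 3, 4, 4)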
|
|
|
|
def hc_pre_py(x, hc_fn, hc_scale, hc_base, hc_mult=4, sinkhorn_iters=20, eps=1e-6):
    """Matches Block.hc_pre from V4 model.py."""
    # x: (batch, seq, n_s, dim); mixes: (batch, seq, (2 + n_s) * n_s)
    x_flat = x.flatten(2).float()
    rsqrt = torch.rsqrt(x_flat.square().mean(-1, keepdim=True) + eps)
    mixes = F.linear(x_flat, hc_fn) * rsqrt
    pre, post, comb = hc_split_sinkhorn_py(mixes, hc_scale, hc_base, hc_mult, sinkhorn_iters, eps)
    # Pre-gate: weighted sum over the n_s residual streams -> (batch, seq, dim)
    y = torch.sum(pre.unsqueeze(-1) * x.float(), dim=2)
    return y, post, comb
|
|
|
|
|
|
@dataclass
class MHCConfig:
    dim: int = 256
    n_s: int = 4
    K: int = 20
    eps: float = 1e-6
    batch_size: int = 8
    seq_len: int = 32
    init_std: float = 0.02

    @property
    def hc_dim(self):
        return self.n_s * self.dim

    @property
    def mix_hc(self):
        # n_s pre gates + n_s post gates + n_s*n_s comb logits;
        # e.g. n_s = 4 gives (2 + 4) * 4 = 24 mix outputs.
        return (2 + self.n_s) * self.n_s
|
|
|
|
def init_mhc_params(cfg, device='cpu'):
    hc_fn = torch.randn(cfg.mix_hc, cfg.hc_dim, device=device) * cfg.init_std
    hc_base = torch.zeros(cfg.mix_hc, device=device)
    hc_scale = torch.ones(3, device=device) * 0.1
    return hc_fn, hc_scale, hc_base
|
|
|
|
def make_residual_state(cfg, device='cpu'):
    return torch.randn(cfg.batch_size, cfg.seq_len, cfg.n_s, cfg.dim, device=device)
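

# End-to-end usage sketch (illustrative only): initialize parameters at a small
# width and run the hc_pre path once on a synthetic residual state.
def _demo_hc_pre():
    cfg = MHCConfig(dim=64, batch_size=2, seq_len=8)
    hc_fn, hc_scale, hc_base = init_mhc_params(cfg)
    x = make_residual_state(cfg)
    y, post, comb = hc_pre_py(x, hc_fn, hc_scale, hc_base, cfg.n_s, cfg.K, cfg.eps)
    print(y.shape, post.shape, comb.shape)  # (2, 8, 64) (2, 8, 4) (2, 8, 4, 4)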
|
|
|
|
|
|
def diagnostic_1_sinkhorn_error(comb):
    """Diagnostic 1: worst-case row/col sum violation eps_K per comb matrix."""
    row_sums = comb.sum(dim=-1)
    row_violation = (row_sums - 1.0).abs().max(dim=-1)[0]
    col_sums = comb.sum(dim=-2)
    col_violation = (col_sums - 1.0).abs().max(dim=-1)[0]
    eps_K = torch.maximum(row_violation, col_violation)
    return {
        'eps_K_mean': eps_K.mean().item(),
        'eps_K_max': eps_K.max().item(),
    }
|
|
|
|
|
|
def diagnostic_2_spectral_norm(comb):
    """Diagnostic 2: spectral norm and condition number of the comb matrices."""
    flat = comb.reshape(-1, comb.shape[-2], comb.shape[-1])
    sv = torch.linalg.svdvals(flat)
    sigma_max = sv[:, 0]
    sigma_min = sv[:, -1]
    return {
        'spectral_norm_mean': sigma_max.mean().item(),
        'spectral_norm_max': sigma_max.max().item(),
        'condition_number_mean': (sigma_max / (sigma_min + 1e-10)).mean().item(),
    }
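

# For an exactly doubly stochastic C, Birkhoff's theorem gives ||C||_2 <= 1
# (C is a convex combination of permutation matrices), while C @ 1 = 1 forces
# ||C||_2 >= 1, so sigma_max = 1 exactly; deviations from 1 therefore track
# Sinkhorn convergence. Quick check (illustrative, not part of the suite):
def _demo_spectral():
    C = sinkhorn_knopp(torch.randn(4, 4), K=100)
    print(torch.linalg.svdvals(C)[0].item())  # ~1.0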
|
|
|
|
|
|
def gauge_projector(n_s, device='cpu'):
    """Orthogonal projector onto G^⊥ where G = {r·1ᵀ + 1·cᵀ}."""
    basis_vecs = []
    for i in range(n_s):
        v = torch.zeros(n_s, n_s, device=device)
        v[i, :] = 1.0
        basis_vecs.append(v.flatten())
    for j in range(n_s):
        v = torch.zeros(n_s, n_s, device=device)
        v[:, j] = 1.0
        basis_vecs.append(v.flatten())

    # The 2*n_s spanning vectors are rank-deficient (row span and column span
    # share the all-ones matrix), so QR with a rank cutoff keeps 2*n_s - 1.
    G = torch.stack(basis_vecs, dim=0).T
    Q, R = torch.linalg.qr(G.double())
    diag_R = R.diag().abs()
    rank = (diag_R > 1e-10).sum().item()
    Q = Q[:, :rank].float()
    P_G = Q @ Q.T
    P_perp = torch.eye(n_s * n_s, device=device) - P_G
    return P_perp
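

# Two sanity checks on the projector (illustrative): it should annihilate any
# gauge direction r·1ᵀ + 1·cᵀ, and its rank should be dim(G^⊥) = (n_s - 1)²,
# since dim(G) = 2·n_s - 1.
def _demo_gauge(n_s=4):
    P_perp = gauge_projector(n_s)
    g = (torch.randn(n_s, 1) + torch.randn(1, n_s)).flatten()  # r·1ᵀ + 1·cᵀ
    print((P_perp @ g).norm().item())                    # ~0
    print(int(torch.linalg.matrix_rank(P_perp).item()))  # (n_s - 1)² = 9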
|
|
|
|
def diagnostic_3_quotient_jacobian(n_s=4, K_values=(1, 2, 5, 10, 20, 50), n_samples=100, device='cpu'):
    """Diagnostic 3: singular values of the Sinkhorn Jacobian restricted to G^⊥."""
    P_perp = gauge_projector(n_s, device)
    d_perp = (n_s - 1) ** 2
    # Eigenvectors of P_perp with eigenvalue 1 form an orthonormal basis of G^⊥;
    # eigh sorts ascending, so they are the last d_perp columns.
    _, eigvecs = torch.linalg.eigh(P_perp)
    basis = eigvecs[:, -d_perp:]

    results = {}
    for K in K_values:
        sigma_maxs, sigma_mins, kappas = [], [], []
        for _ in range(n_samples):
            Z = torch.randn(n_s, n_s, device=device, requires_grad=True)
            C = sinkhorn_knopp_differentiable(Z, K)
            c_flat = C.flatten()

            # Row idx of J is d(C_idx)/d(Z), assembled one backward pass at a time.
            J = torch.zeros(n_s*n_s, n_s*n_s, device=device)
            for idx in range(n_s*n_s):
                if Z.grad is not None:
                    Z.grad.zero_()
                c_flat[idx].backward(retain_graph=True)
                J[idx] = Z.grad.flatten()

            J_quot = basis.T @ J @ basis
            sv = torch.linalg.svdvals(J_quot)
            sigma_maxs.append(sv[0].item())
            sigma_mins.append(sv[-1].item())
            kappas.append((sv[0] / (sv[-1] + 1e-10)).item())

        results[f'K={K}'] = {
            'sigma_max_mean': np.mean(sigma_maxs),
            'sigma_min_mean': np.mean(sigma_mins),
            'kappa_mean': np.mean(kappas),
        }
    return results
|
|
|
|
|
|
def diagnostic_4_dynamic_sensitivity(cfg, n_samples=50, device='cpu'):
    """Diagnostic 4: input sensitivities ||DC||·||x||, ||Dp||·||x||, ||Dq||·||x||."""
    n_s, d, K, eps = cfg.n_s, cfg.dim, cfg.K, cfg.eps
    hc_dim = n_s * d
    mix_hc = cfg.mix_hc

    dc_x_norms, dp_x_norms, dq_x_norms = [], [], []

    for _ in range(n_samples):
        hc_fn = torch.randn(mix_hc, hc_dim, device=device) * cfg.init_std
        hc_scale = torch.ones(3, device=device) * 0.1
        hc_base = torch.zeros(mix_hc, device=device)

        x = torch.randn(1, 1, n_s, d, device=device, requires_grad=True)
        x_norm = x.norm().item()

        x_flat = x.flatten(2).float()
        rsqrt = torch.rsqrt(x_flat.square().mean(-1, keepdim=True) + eps)
        mixes = F.linear(x_flat, hc_fn) * rsqrt

        # Sensitivity of the comb matrix: ||DC(x)||_2 * ||x||
        comb_raw = mixes[..., 2*n_s:]
        comb_logits = (comb_raw * hc_scale[2] + hc_base[2*n_s:]).unflatten(-1, (n_s, n_s))
        comb = sinkhorn_knopp_differentiable(comb_logits, K, eps)
        c_flat = comb.flatten()

        J_rows = []
        for idx in range(n_s*n_s):
            if x.grad is not None:
                x.grad.zero_()
            c_flat[idx].backward(retain_graph=True)
            J_rows.append(x.grad.flatten().clone())
        J_C = torch.stack(J_rows, dim=0)
        dc_x_norms.append(torch.linalg.norm(J_C, ord=2).item() * x_norm)

        # Sensitivity of the pre gates: ||Dp(x)||_2 * ||x||
        pre = torch.sigmoid(mixes[..., :n_s] * hc_scale[0] + hc_base[:n_s]) + eps
        p_flat = pre.flatten()
        jac_p = []
        for idx in range(n_s):
            if x.grad is not None:
                x.grad.zero_()
            p_flat[idx].backward(retain_graph=True)
            jac_p.append(x.grad.flatten().clone())
        J_p = torch.stack(jac_p, dim=0)
        dp_x_norms.append(torch.linalg.norm(J_p, ord=2).item() * x_norm)

        # Sensitivity of the post gates: ||Dq(x)||_2 * ||x||
        post = 2.0 * torch.sigmoid(mixes[..., n_s:2*n_s] * hc_scale[1] + hc_base[n_s:2*n_s])
        q_flat = post.flatten()
        jac_q = []
        for idx in range(n_s):
            if x.grad is not None:
                x.grad.zero_()
            q_flat[idx].backward(retain_graph=True)
            jac_q.append(x.grad.flatten().clone())
        J_q = torch.stack(jac_q, dim=0)
        dq_x_norms.append(torch.linalg.norm(J_q, ord=2).item() * x_norm)

    return {
        'dc_x_mean': np.mean(dc_x_norms),
        'dc_x_std': np.std(dc_x_norms),
        'dp_x_mean': np.mean(dp_x_norms),
        'dp_x_std': np.std(dp_x_norms),
        'dq_x_mean': np.mean(dq_x_norms),
        'dq_x_std': np.std(dq_x_norms),
    }
|
|
|
|
|
|
def diagnostic_5_logit_update(cfg, n_samples=50, lr=1e-3, device='cpu'):
    """Diagnostic 5: how much of a one-step SGD logit update lands in G^⊥."""
    n_s, d, K, eps = cfg.n_s, cfg.dim, cfg.K, cfg.eps
    hc_dim = n_s * d
    P_perp = gauge_projector(n_s, device)

    proj_norms, total_norms = [], []

    for _ in range(n_samples):
        W_C = torch.nn.Parameter(torch.randn(n_s*n_s, hc_dim, device=device) * cfg.init_std)
        s_C = torch.nn.Parameter(torch.tensor([0.1], device=device))
        b_C = torch.nn.Parameter(torch.zeros(n_s*n_s, device=device))

        x = torch.randn(1, 1, n_s, d, device=device)
        x_flat = x.flatten(2).float()
        rsqrt = torch.rsqrt(x_flat.square().mean(-1, keepdim=True) + eps)
        x_hat = (x_flat * rsqrt).squeeze(0).squeeze(0)

        Z = s_C * (W_C @ x_hat) + b_C
        C = sinkhorn_knopp_differentiable(Z.unflatten(0, (n_s, n_s)), K, eps)

        # Toy objective: pull the comb matrix toward the identity.
        loss = ((C - torch.eye(n_s, device=device))**2).sum()
        loss.backward()

        with torch.no_grad():
            W_new = W_C - lr * W_C.grad
            s_new = s_C - lr * s_C.grad
            b_new = b_C - lr * b_C.grad
            Z_new = s_new * (W_new @ x_hat) + b_new
            delta_Z = Z_new - Z.detach()
            delta_Z_perp = P_perp @ delta_Z
            proj_norms.append(delta_Z_perp.norm().item())
            total_norms.append(delta_Z.norm().item())

    return {
        'proj_perp_norm_mean': np.mean(proj_norms),
        'total_norm_mean': np.mean(total_norms),
        'perp_ratio_mean': np.mean([p/(t+1e-10) for p, t in zip(proj_norms, total_norms)]),
    }
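

# Why only the G^⊥ component matters (illustrative check, under the assumption
# of near-convergence): the Sinkhorn limit is invariant to gauge shifts
# Z -> Z + r·1ᵀ + 1·cᵀ, since these only rescale rows and columns of exp(Z).
def _demo_gauge_invariance(n_s=4):
    Z = torch.randn(n_s, n_s)
    shift = torch.randn(n_s, 1) + torch.randn(1, n_s)
    C1 = sinkhorn_knopp(Z, K=200)
    C2 = sinkhorn_knopp(Z + shift, K=200)
    print((C1 - C2).abs().max().item())  # ~0 up to eps effects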
|
|
|
|
|
|
def diagnostic_6_gate_stats(pre, post):
    """Diagnostic 6: gate statistics (pre in (0,1), post in (0,2))."""
    pre_flat = pre.flatten(0, -2)
    post_flat = post.flatten(0, -2)
    return {
        'pre_mean': pre_flat.mean().item(),
        'pre_l1_mean': pre_flat.sum(-1).mean().item(),
        'post_mean': post_flat.mean().item(),
        'post_linf_mean': post_flat.abs().max(-1)[0].mean().item(),
    }
|
|
|
|
|
|
def run_full_diagnostics(widths=(64, 128, 256, 512, 1024), n_s=4, K=20, device='cpu'):
    results = {}
    d3 = None

    for d in widths:
        print(f"\nWidth d = {d}")
        cfg = MHCConfig(dim=d, n_s=n_s, K=K, batch_size=4, seq_len=16)
        hc_fn, hc_scale, hc_base = init_mhc_params(cfg, device)
        x = make_residual_state(cfg, device)
        y, post, comb = hc_pre_py(x, hc_fn, hc_scale, hc_base, cfg.n_s, cfg.K, cfg.eps)

        d1 = diagnostic_1_sinkhorn_error(comb)
        d2 = diagnostic_2_spectral_norm(comb)
        # Diagnostic 3 acts on n_s x n_s matrices only, so it is width-independent
        # and needs to run just once.
        if d3 is None:
            d3 = diagnostic_3_quotient_jacobian(n_s, (1, 2, 5, 10, 20, 50), 100, device)
        d4 = diagnostic_4_dynamic_sensitivity(cfg, 30, device)
        d5 = diagnostic_5_logit_update(cfg, 30, device=device)

        x_flat = x.flatten(2).float()
        rsqrt_val = torch.rsqrt(x_flat.square().mean(-1, keepdim=True) + cfg.eps)
        mixes_val = F.linear(x_flat, hc_fn) * rsqrt_val
        pre_val, post_val, _ = hc_split_sinkhorn_py(mixes_val, hc_scale, hc_base, cfg.n_s, cfg.K, cfg.eps)
        d6 = diagnostic_6_gate_stats(pre_val, post_val)

        results[d] = {
            'sinkhorn_error': d1, 'spectral_norm': d2,
            'quotient_jacobian': d3, 'dynamic_sensitivity': d4,
            'logit_update': d5, 'gate_stats': d6,
        }

        print(f"  ε_K={d1['eps_K_mean']:.1e}, ||C||_2={d2['spectral_norm_mean']:.6f}")
        print(f"  ||DC||·||x||={d4['dc_x_mean']:.4f}, ||Dp||·||x||={d4['dp_x_mean']:.4f}")

    return results
|
|
|
|
if __name__ == '__main__':
    print("Dynamic mHC µP Diagnostics")
    print("=" * 60)
    results = run_full_diagnostics()

    # Persist results as promised in the module docstring; np.mean/np.std emit
    # np.float64 values, which json cannot serialize, so coerce via default=float.
    with open('diagnostic_results.json', 'w') as f:
        json.dump(results, f, indent=2, default=float)

    # Fit log ||DC(x)||·||x|| against log d to estimate the width-scaling exponent.
    widths = sorted(results.keys())
    dc_means = [results[d]['dynamic_sensitivity']['dc_x_mean'] for d in widths]
    log_d = np.log(widths)
    A = np.vstack([log_d, np.ones_like(log_d)]).T
    alpha = np.linalg.lstsq(A, np.log(dc_means), rcond=None)[0][0]

    print(f"\n||DC(x)||·||x|| scales as d^{alpha:.3f}")
    print("\nSummary:")
    print(f"{'Width':>8} {'||DC||·||x||':>14} {'||Dp||·||x||':>14} {'||C||_2':>10} {'ε_K':>10}")
    for d in widths:
        r = results[d]
        print(f"{d:>8} {r['dynamic_sensitivity']['dc_x_mean']:>14.4f} "
              f"{r['dynamic_sensitivity']['dp_x_mean']:>14.4f} "
              f"{r['spectral_norm']['spectral_norm_mean']:>10.6f} "
              f"{r['sinkhorn_error']['eps_K_mean']:>10.1e}")
|
|