""" Dynamic mHC µP Diagnostic Suite ================================ Implements all 6 diagnostics from the updated plan for the conditional fixed-depth dynamic mHC µP theorem. Matches DeepSeek V4 inference code exactly: - hc_split_sinkhorn: softmax init + K-1 row/col normalization iterations - hc_pre: flatten → RMS-normalize → F.linear → split → sigmoid/Sinkhorn - hc_post: post * x + comb @ residual All diagnostics operate on synthetic data at multiple widths to test width-transfer conditions. Usage: python mhc_diagnostics.py Results saved to diagnostic_results.json """ import torch import torch.nn.functional as F import numpy as np import json from dataclasses import dataclass from typing import Dict, List import math # ─────────────────────────────────────────────────────── # §0 Pure-PyTorch mHC matching V4 inference/model.py # ─────────────────────────────────────────────────────── def sinkhorn_knopp(logits: torch.Tensor, K: int, eps: float = 1e-6) -> torch.Tensor: """ Sinkhorn-Knopp normalization matching V4 kernel.py exactly. V4 procedure: 1. comb = softmax(logits, dim=-1) + eps 2. comb = comb / (comb.sum(dim=-2) + eps) 3. repeat K-1 times: comb = comb / (comb.sum(dim=-1, keepdim=True) + eps) comb = comb / (comb.sum(dim=-2, keepdim=True) + eps) """ comb = F.softmax(logits, dim=-1) + eps comb = comb / (comb.sum(dim=-2, keepdim=True) + eps) for _ in range(K - 1): comb = comb / (comb.sum(dim=-1, keepdim=True) + eps) comb = comb / (comb.sum(dim=-2, keepdim=True) + eps) return comb def sinkhorn_knopp_differentiable(logits: torch.Tensor, K: int, eps: float = 1e-6): """Same as sinkhorn_knopp but ensures autograd graph is built.""" comb = F.softmax(logits, dim=-1) + eps comb = comb / (comb.sum(dim=-2, keepdim=True) + eps) for _ in range(K - 1): comb = comb / (comb.sum(dim=-1, keepdim=True) + eps) comb = comb / (comb.sum(dim=-2, keepdim=True) + eps) return comb def hc_split_sinkhorn_py(mixes, hc_scale, hc_base, hc_mult=4, sinkhorn_iters=20, eps=1e-6): """Pure PyTorch version of hc_split_sinkhorn from V4 kernel.py.""" n = hc_mult pre_raw = mixes[..., :n] post_raw = mixes[..., n:2*n] comb_raw = mixes[..., 2*n:] pre_base = hc_base[:n] post_base = hc_base[n:2*n] comb_base = hc_base[2*n:] pre = torch.sigmoid(pre_raw * hc_scale[0] + pre_base) + eps post = 2.0 * torch.sigmoid(post_raw * hc_scale[1] + post_base) comb_logits = comb_raw * hc_scale[2] + comb_base comb_logits = comb_logits.unflatten(-1, (n, n)) comb = sinkhorn_knopp(comb_logits, sinkhorn_iters, eps) return pre, post, comb def hc_pre_py(x, hc_fn, hc_scale, hc_base, hc_mult=4, sinkhorn_iters=20, eps=1e-6): """Matches Block.hc_pre from V4 model.py.""" shape = x.shape x_flat = x.flatten(2).float() rsqrt = torch.rsqrt(x_flat.square().mean(-1, keepdim=True) + eps) mixes = F.linear(x_flat, hc_fn) * rsqrt pre, post, comb = hc_split_sinkhorn_py(mixes, hc_scale, hc_base, hc_mult, sinkhorn_iters, eps) y = torch.sum(pre.unsqueeze(-1) * x.float(), dim=2) return y, post, comb # ─────────────────────────────────────────────────────── # §1 Configuration # ─────────────────────────────────────────────────────── @dataclass class MHCConfig: dim: int = 256 n_s: int = 4 K: int = 20 eps: float = 1e-6 batch_size: int = 8 seq_len: int = 32 init_std: float = 0.02 @property def hc_dim(self): return self.n_s * self.dim @property def mix_hc(self): return (2 + self.n_s) * self.n_s def init_mhc_params(cfg, device='cpu'): n_s, d, mix_hc, hc_dim = cfg.n_s, cfg.dim, cfg.mix_hc, cfg.hc_dim hc_fn = torch.randn(mix_hc, hc_dim, device=device) * cfg.init_std hc_base = 
# ───────────────────────────────────────────────────────
# §1  Configuration
# ───────────────────────────────────────────────────────

@dataclass
class MHCConfig:
    dim: int = 256
    n_s: int = 4
    K: int = 20
    eps: float = 1e-6
    batch_size: int = 8
    seq_len: int = 32
    init_std: float = 0.02

    @property
    def hc_dim(self):
        return self.n_s * self.dim

    @property
    def mix_hc(self):
        return (2 + self.n_s) * self.n_s


def init_mhc_params(cfg, device='cpu'):
    n_s, d, mix_hc, hc_dim = cfg.n_s, cfg.dim, cfg.mix_hc, cfg.hc_dim
    hc_fn = torch.randn(mix_hc, hc_dim, device=device) * cfg.init_std
    hc_base = torch.zeros(mix_hc, device=device)
    hc_scale = torch.ones(3, device=device) * 0.1
    return hc_fn, hc_scale, hc_base


def make_residual_state(cfg, device='cpu'):
    return torch.randn(cfg.batch_size, cfg.seq_len, cfg.n_s, cfg.dim, device=device)

# ───────────────────────────────────────────────────────
# §2  Diagnostic 1: Finite-K Sinkhorn error
# ───────────────────────────────────────────────────────

def diagnostic_1_sinkhorn_error(comb):
    # ε_K: worst-case deviation of row/column sums from exact double stochasticity.
    row_sums = comb.sum(dim=-1)
    row_violation = (row_sums - 1.0).abs().max(dim=-1)[0]
    col_sums = comb.sum(dim=-2)
    col_violation = (col_sums - 1.0).abs().max(dim=-1)[0]
    eps_K = torch.maximum(row_violation, col_violation)
    return {
        'eps_K_mean': eps_K.mean().item(),
        'eps_K_max': eps_K.max().item(),
    }

# ───────────────────────────────────────────────────────
# §3  Diagnostic 2: Spectral norm
# ───────────────────────────────────────────────────────

def diagnostic_2_spectral_norm(comb):
    flat = comb.reshape(-1, comb.shape[-2], comb.shape[-1])
    sv = torch.linalg.svdvals(flat)
    sigma_max = sv[:, 0]
    sigma_min = sv[:, -1]
    return {
        'spectral_norm_mean': sigma_max.mean().item(),
        'spectral_norm_max': sigma_max.max().item(),
        'condition_number_mean': (sigma_max / (sigma_min + 1e-10)).mean().item(),
    }

# ───────────────────────────────────────────────────────
# §4  Diagnostic 3: Quotient Jacobian
# ───────────────────────────────────────────────────────

def gauge_projector(n_s, device='cpu'):
    """Orthogonal projector onto G^⊥ where G = {r·1ᵀ + 1·cᵀ}."""
    basis_vecs = []
    for i in range(n_s):
        v = torch.zeros(n_s, n_s, device=device)
        v[i, :] = 1.0
        basis_vecs.append(v.flatten())
    for j in range(n_s):
        v = torch.zeros(n_s, n_s, device=device)
        v[:, j] = 1.0
        basis_vecs.append(v.flatten())
    G = torch.stack(basis_vecs, dim=0).T
    # G has rank 2*n_s - 1: the row-constant and column-constant patterns
    # overlap in the all-ones matrix. Detect the rank from the QR diagonal.
    Q, R = torch.linalg.qr(G.double())
    diag_R = R.diag().abs()
    rank = (diag_R > 1e-10).sum().item()
    Q = Q[:, :rank].float()
    P_G = Q @ Q.T
    P_perp = torch.eye(n_s * n_s, device=device) - P_G
    return P_perp


def diagnostic_3_quotient_jacobian(n_s=4, K_values=(1, 2, 5, 10, 20, 50), n_samples=100, device='cpu'):
    P_perp = gauge_projector(n_s, device)
    d_perp = (n_s - 1) ** 2
    # Eigenvectors with eigenvalue 1 span G^⊥; eigh sorts ascending, so take the last d_perp.
    _, eigvecs = torch.linalg.eigh(P_perp)
    basis = eigvecs[:, -d_perp:]
    results = {}
    for K in K_values:
        sigma_maxs, sigma_mins, kappas = [], [], []
        for _ in range(n_samples):
            Z = torch.randn(n_s, n_s, device=device, requires_grad=True)
            C = sinkhorn_knopp_differentiable(Z, K)
            c_flat = C.flatten()
            # Assemble the full Jacobian dC/dZ row by row via backward passes.
            J = torch.zeros(n_s * n_s, n_s * n_s, device=device)
            for idx in range(n_s * n_s):
                if Z.grad is not None:
                    Z.grad.zero_()
                c_flat[idx].backward(retain_graph=True)
                J[idx] = Z.grad.flatten()
            # Restrict to the gauge quotient G^⊥ on both sides.
            J_quot = basis.T @ J @ basis
            sv = torch.linalg.svdvals(J_quot)
            sigma_maxs.append(sv[0].item())
            sigma_mins.append(sv[-1].item())
            kappas.append((sv[0] / (sv[-1] + 1e-10)).item())
        results[f'K={K}'] = {
            'sigma_max_mean': np.mean(sigma_maxs),
            'sigma_min_mean': np.mean(sigma_mins),
            'kappa_mean': np.mean(kappas),
        }
    return results
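
# Illustrative check (a sketch under the stated gauge structure, not part of
# the suite): G = {r·1ᵀ + 1·cᵀ} has dimension 2*n_s - 1, so P_perp should be
# a symmetric, idempotent projector of rank (n_s - 1)^2.
def _sanity_check_gauge_projector(n_s: int = 4) -> None:
    P = gauge_projector(n_s)
    assert torch.allclose(P @ P, P, atol=1e-5)        # idempotent
    assert torch.allclose(P, P.T, atol=1e-5)          # symmetric
    assert round(P.trace().item()) == (n_s - 1) ** 2  # rank = trace for projectors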
# ───────────────────────────────────────────────────────
# §5  Diagnostic 4: Dynamic sensitivity
# ───────────────────────────────────────────────────────

def diagnostic_4_dynamic_sensitivity(cfg, n_samples=50, device='cpu'):
    n_s, d, K, eps = cfg.n_s, cfg.dim, cfg.K, cfg.eps
    hc_dim = n_s * d
    mix_hc = cfg.mix_hc
    dc_x_norms, dp_x_norms, dq_x_norms = [], [], []
    for _ in range(n_samples):
        hc_fn = torch.randn(mix_hc, hc_dim, device=device) * cfg.init_std
        hc_scale = torch.ones(3, device=device) * 0.1
        hc_base = torch.zeros(mix_hc, device=device)
        x = torch.randn(1, 1, n_s, d, device=device, requires_grad=True)
        x_norm = x.norm().item()
        x_flat = x.flatten(2).float()
        rsqrt = torch.rsqrt(x_flat.square().mean(-1, keepdim=True) + eps)
        mixes = F.linear(x_flat, hc_fn) * rsqrt

        # Comb Jacobian: ||D_x C|| · ||x||
        comb_raw = mixes[..., 2*n_s:]
        comb_logits = (comb_raw * hc_scale[2] + hc_base[2*n_s:]).unflatten(-1, (n_s, n_s))
        comb = sinkhorn_knopp_differentiable(comb_logits, K, eps)
        c_flat = comb.flatten()
        J_rows = []
        for idx in range(n_s * n_s):
            if x.grad is not None:
                x.grad.zero_()
            c_flat[idx].backward(retain_graph=True)
            J_rows.append(x.grad.flatten().clone())
        J_C = torch.stack(J_rows, dim=0)
        dc_x_norms.append(torch.linalg.norm(J_C, ord=2).item() * x_norm)

        # Pre-gate Jacobian: ||D_x p|| · ||x||
        pre = torch.sigmoid(mixes[..., :n_s] * hc_scale[0] + hc_base[:n_s]) + eps
        p_flat = pre.flatten()
        jac_p = []
        for idx in range(n_s):
            if x.grad is not None:
                x.grad.zero_()
            p_flat[idx].backward(retain_graph=True)
            jac_p.append(x.grad.flatten().clone())
        J_p = torch.stack(jac_p, dim=0)
        dp_x_norms.append(torch.linalg.norm(J_p, ord=2).item() * x_norm)

        # Post-gate Jacobian: ||D_x q|| · ||x||
        post = 2.0 * torch.sigmoid(mixes[..., n_s:2*n_s] * hc_scale[1] + hc_base[n_s:2*n_s])
        q_flat = post.flatten()
        jac_q = []
        for idx in range(n_s):
            if x.grad is not None:
                x.grad.zero_()
            q_flat[idx].backward(retain_graph=True)
            jac_q.append(x.grad.flatten().clone())
        J_q = torch.stack(jac_q, dim=0)
        dq_x_norms.append(torch.linalg.norm(J_q, ord=2).item() * x_norm)
    return {
        'dc_x_mean': np.mean(dc_x_norms), 'dc_x_std': np.std(dc_x_norms),
        'dp_x_mean': np.mean(dp_x_norms), 'dp_x_std': np.std(dp_x_norms),
        'dq_x_mean': np.mean(dq_x_norms), 'dq_x_std': np.std(dq_x_norms),
    }

# ───────────────────────────────────────────────────────
# §6  Diagnostic 5: Logit update scale
# ───────────────────────────────────────────────────────

def diagnostic_5_logit_update(cfg, n_samples=50, lr=1e-3, device='cpu'):
    n_s, d, K, eps = cfg.n_s, cfg.dim, cfg.K, cfg.eps
    hc_dim = n_s * d
    P_perp = gauge_projector(n_s, device)
    proj_norms, total_norms = [], []
    for _ in range(n_samples):
        W_C = torch.nn.Parameter(torch.randn(n_s * n_s, hc_dim, device=device) * cfg.init_std)
        s_C = torch.nn.Parameter(torch.tensor([0.1], device=device))
        b_C = torch.nn.Parameter(torch.zeros(n_s * n_s, device=device))
        x = torch.randn(1, 1, n_s, d, device=device)
        x_flat = x.flatten(2).float()
        rsqrt = torch.rsqrt(x_flat.square().mean(-1, keepdim=True) + eps)
        x_hat = (x_flat * rsqrt).squeeze(0).squeeze(0)
        Z = s_C * (W_C @ x_hat) + b_C
        C = sinkhorn_knopp_differentiable(Z.unflatten(0, (n_s, n_s)), K, eps)
        # Surrogate loss pulling C toward the identity, to generate a gradient.
        loss = ((C - torch.eye(n_s, device=device)) ** 2).sum()
        loss.backward()
        with torch.no_grad():
            # One SGD step, then measure the induced logit displacement.
            W_new = W_C - lr * W_C.grad
            s_new = s_C - lr * s_C.grad
            b_new = b_C - lr * b_C.grad
            Z_new = s_new * (W_new @ x_hat) + b_new
            delta_Z = Z_new - Z.detach()
            # Component of the update outside the gauge subspace G.
            delta_Z_perp = P_perp @ delta_Z
            proj_norms.append(delta_Z_perp.norm().item())
            total_norms.append(delta_Z.norm().item())
    return {
        'proj_perp_norm_mean': np.mean(proj_norms),
        'total_norm_mean': np.mean(total_norms),
        'perp_ratio_mean': np.mean([p / (t + 1e-10) for p, t in zip(proj_norms, total_norms)]),
    }

# ───────────────────────────────────────────────────────
# §7  Diagnostic 6: Gate statistics
# ───────────────────────────────────────────────────────

def diagnostic_6_gate_stats(pre, post):
    pre_flat = pre.flatten(0, -2)
    post_flat = post.flatten(0, -2)
    return {
        'pre_mean': pre_flat.mean().item(),
        'pre_l1_mean': pre_flat.sum(-1).mean().item(),
        'post_mean': post_flat.mean().item(),
        'post_linf_mean': post_flat.abs().max(-1)[0].mean().item(),
    }
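
# Illustrative end-to-end usage sketch at a single width (an assumed
# workflow; it mirrors the per-width steps run_full_diagnostics performs
# below, and the helper name is hypothetical):
def _demo_single_width(d: int = 128) -> None:
    cfg = MHCConfig(dim=d)
    hc_fn, hc_scale, hc_base = init_mhc_params(cfg)
    x = make_residual_state(cfg)
    _, _, comb = hc_pre_py(x, hc_fn, hc_scale, hc_base, cfg.n_s, cfg.K, cfg.eps)
    print('eps_K:', diagnostic_1_sinkhorn_error(comb))
    print('||C||_2:', diagnostic_2_spectral_norm(comb))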
# ───────────────────────────────────────────────────────
# §8  Main runner
# ───────────────────────────────────────────────────────

def run_full_diagnostics(widths=(64, 128, 256, 512, 1024), n_s=4, K=20, device='cpu'):
    results = {}
    d3 = None
    for d in widths:
        print(f"\nWidth d = {d}")
        cfg = MHCConfig(dim=d, n_s=n_s, K=K, batch_size=4, seq_len=16)
        hc_fn, hc_scale, hc_base = init_mhc_params(cfg, device)
        x = make_residual_state(cfg, device)
        y, post, comb = hc_pre_py(x, hc_fn, hc_scale, hc_base, cfg.n_s, cfg.K, cfg.eps)
        d1 = diagnostic_1_sinkhorn_error(comb)
        d2 = diagnostic_2_spectral_norm(comb)
        # Diagnostic 3 is width-independent: compute it once and reuse.
        if d3 is None:
            d3 = diagnostic_3_quotient_jacobian(n_s, (1, 2, 5, 10, 20, 50), 100, device)
        d4 = diagnostic_4_dynamic_sensitivity(cfg, 30, device)
        d5 = diagnostic_5_logit_update(cfg, 30, device=device)
        # Recompute the split to obtain pre (hc_pre_py does not return it).
        x_flat = x.flatten(2).float()
        rsqrt_val = torch.rsqrt(x_flat.square().mean(-1, keepdim=True) + cfg.eps)
        mixes_val = F.linear(x_flat, hc_fn) * rsqrt_val
        pre_val, post_val, _ = hc_split_sinkhorn_py(mixes_val, hc_scale, hc_base, cfg.n_s, cfg.K, cfg.eps)
        d6 = diagnostic_6_gate_stats(pre_val, post_val)
        results[d] = {
            'sinkhorn_error': d1,
            'spectral_norm': d2,
            'quotient_jacobian': d3,
            'dynamic_sensitivity': d4,
            'logit_update': d5,
            'gate_stats': d6,
        }
        print(f"  ε_K={d1['eps_K_mean']:.1e}, ||C||_2={d2['spectral_norm_mean']:.6f}")
        print(f"  ||DC||·||x||={d4['dc_x_mean']:.4f}, ||Dp||·||x||={d4['dp_x_mean']:.4f}")
    return results


if __name__ == '__main__':
    print("Dynamic mHC µP Diagnostics")
    print("=" * 60)
    results = run_full_diagnostics()

    # Scaling analysis: fit log(||DC||·||x||) against log(d) to estimate the
    # width-scaling exponent alpha.
    widths = sorted(results.keys())
    dc_means = [results[d]['dynamic_sensitivity']['dc_x_mean'] for d in widths]
    log_d = np.log(widths)
    A = np.vstack([log_d, np.ones_like(log_d)]).T
    alpha = np.linalg.lstsq(A, np.log(dc_means), rcond=None)[0][0]
    print(f"\n||DC(x)||·||x|| scales as d^{alpha:.3f}")

    print("\nSummary:")
    print(f"{'Width':>8} {'||DC||·||x||':>14} {'||Dp||·||x||':>14} {'||C||_2':>10} {'ε_K':>10}")
    for d in widths:
        r = results[d]
        print(f"{d:>8} {r['dynamic_sensitivity']['dc_x_mean']:>14.4f} "
              f"{r['dynamic_sensitivity']['dp_x_mean']:>14.4f} "
              f"{r['spectral_norm']['spectral_norm_mean']:>10.6f} "
              f"{r['sinkhorn_error']['eps_K_mean']:>10.1e}")

    # Persist results as promised in the module docstring (integer width
    # keys are serialized as JSON strings).
    with open('diagnostic_results.json', 'w') as f:
        json.dump(results, f, indent=2)
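
# Hypothetical follow-up sketch (assumed, not part of the suite): reload the
# saved JSON and re-inspect the width sweep; note that integer width keys
# round-trip as strings on disk.
#
#   with open('diagnostic_results.json') as f:
#       saved = json.load(f)
#   for w, r in sorted(saved.items(), key=lambda kv: int(kv[0])):
#       print(w, r['sinkhorn_error']['eps_K_mean'])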