| """ |
| MHA CV Relational Test β Prototype |
| Train a minimal embedding + MHA + classifier on 10 noise patterns. |
| Measure CV on embedding weights, Q/K/V projections, and attention output |
| across different head counts per embedding dimension. |
| |
| Hypothesis: head_dim (D / n_heads) determines CV of internal representations, |
| and the band-valid head_dims produce qualitatively different geometric behavior. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
|
|
|
|
| |
|
|
def cayley_menger_vol2(points):
    """Batched squared simplex volume via the Cayley-Menger determinant.

    points: (B, N, D) — each batch row holds the N vertices of one simplex.
    Returns a (B,) tensor of squared (N-1)-simplex volumes.
    """
    batch, n_verts, _ = points.shape
    # Pairwise squared distances from the Gram matrix; relu clamps tiny
    # negative values produced by floating-point cancellation.
    gram = torch.bmm(points, points.transpose(1, 2))
    sq_norms = torch.diagonal(gram, dim1=1, dim2=2)
    sq_dists = F.relu(sq_norms.unsqueeze(2) + sq_norms.unsqueeze(1) - 2 * gram)
    # Border the squared-distance matrix with a row/column of ones (zero corner).
    bordered = torch.zeros(batch, n_verts + 1, n_verts + 1,
                           device=points.device, dtype=points.dtype)
    bordered[:, 0, 1:] = 1.0
    bordered[:, 1:, 0] = 1.0
    bordered[:, 1:, 1:] = sq_dists
    order = n_verts - 1  # dimension of the simplex spanned by N points
    sign = (-1.0) ** (order + 1)
    scale = (2 ** order) * (math.factorial(order) ** 2)
    # Determinant in float32 for stability, cast back to the input dtype.
    det = torch.linalg.det(bordered.float()).to(points.dtype)
    return sign * det / scale
|
|
|
|
def cv_metric(weight, n_samples=300):
    """Coefficient of variation of random pentachoron volumes.

    weight: (N, D) matrix whose rows are treated as points in R^D.
    Samples n_samples random 5-point subsets, computes their simplex volumes,
    and returns std/mean over the numerically valid ones. Returns None when
    there are fewer than 5 rows or fewer than 10 valid simplices.
    """
    n_rows, _ = weight.shape
    if n_rows < 5:
        return None
    pool = min(n_rows, 512)  # cap the candidate-row pool for speed
    picks = []
    for _ in range(n_samples):
        picks.append(torch.randperm(pool, device=weight.device)[:5])
    idx = torch.stack(picks)
    simplices = weight[:pool][idx]        # (n_samples, 5, D)
    sq_vols = cayley_menger_vol2(simplices)
    keep = sq_vols > 1e-20                # drop degenerate / negative determinants
    if keep.sum() < 10:
        return None
    vols = sq_vols[keep].sqrt()
    return (vols.std() / (vols.mean() + 1e-8)).item()
|
|
|
|
| |
|
|
class MHAClassifier(nn.Module):
    """Token embedding + one multi-head attention layer + mean-pool classifier.

    Also exposes probes for per-head weight slices and per-head activations,
    used for CV measurement in this experiment.
    """

    def __init__(self, vocab, dim, n_heads, seq_len, n_classes):
        super().__init__()
        # NOTE: module creation order matters for seeded reproducibility.
        self.emb = nn.Embedding(vocab, dim)
        self.pos = nn.Parameter(torch.randn(1, seq_len, dim) * 0.02)
        self.mha = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, n_classes)

    def forward(self, x):
        """x: (B, S) int tokens -> (B, n_classes) logits."""
        hidden = self.emb(x) + self.pos
        attended, _ = self.mha(hidden, hidden, hidden)
        hidden = self.norm(hidden + attended)
        pooled = hidden.mean(dim=1)  # mean over sequence positions
        return self.head(pooled)

    @torch.no_grad()
    def forward_activations(self, x, n_heads):
        """Forward pass returning per-head Q/K/V activations and post-attn output.

        Returns a dict of (B*seq, head_dim) tensors for CV measurement, plus
        two full-width (B*seq, dim) entries ("act_emb", "act_post_full").
        """
        hidden = self.emb(x) + self.pos
        batch, seq, dim = hidden.shape
        head_dim = dim // n_heads

        # Recompute the packed Q/K/V projections exactly as
        # nn.MultiheadAttention does internally (single in_proj matmul).
        qkv = F.linear(hidden, self.mha.in_proj_weight, self.mha.in_proj_bias)
        q, k, v = qkv.chunk(3, dim=-1)
        q = q.view(batch, seq, n_heads, head_dim)
        k = k.view(batch, seq, n_heads, head_dim)
        v = v.view(batch, seq, n_heads, head_dim)

        attended, _ = self.mha(hidden, hidden, hidden)
        post = self.norm(hidden + attended)
        # Slice the post-attention residual stream into head_dim-wide chunks.
        post_heads = post.view(batch, seq, n_heads, head_dim)

        acts = {}
        for i in range(n_heads):
            acts[f"act_Q_h{i}"] = q[:, :, i, :].reshape(-1, head_dim)
            acts[f"act_K_h{i}"] = k[:, :, i, :].reshape(-1, head_dim)
            acts[f"act_V_h{i}"] = v[:, :, i, :].reshape(-1, head_dim)
            acts[f"act_post_h{i}"] = post_heads[:, :, i, :].reshape(-1, head_dim)

        acts["act_emb"] = hidden.reshape(-1, dim)
        acts["act_post_full"] = post.reshape(-1, dim)
        return acts

    def get_qkv_weights(self):
        """Extract Q, K, V projection weight matrices.

        The packed in_proj weight is (3*dim, dim), stacked Q-then-K-then-V.
        """
        packed = self.mha.in_proj_weight.detach()
        dim = packed.shape[1]
        return packed[:dim], packed[dim:2 * dim], packed[2 * dim:]

    def get_per_head_projections(self, n_heads):
        """Split Q/K/V weights into per-head chunks. Returns list of (head_dim, dim) per head."""
        q_w, k_w, v_w = self.get_qkv_weights()
        head_dim = q_w.shape[0] // n_heads
        spans = [slice(i * head_dim, (i + 1) * head_dim) for i in range(n_heads)]
        q_heads = [q_w[s] for s in spans]
        k_heads = [k_w[s] for s in spans]
        v_heads = [v_w[s] for s in spans]
        return q_heads, k_heads, v_heads
|
|
|
|
| |
|
|
def make_data(n_classes=10, samples_per_class=50, seq_len=8, vocab=256):
    """Create simple classification data. Each class has a base token pattern with noise.

    Returns (x, y): x is (n_classes*samples_per_class, seq_len) int tokens and
    y the matching class labels, both shuffled with a fixed seed.
    """
    torch.manual_seed(42)  # fixed seed -> identical data across experiments
    base_patterns = torch.randint(0, vocab, (n_classes, seq_len))

    xs, ys = [], []
    for label in range(n_classes):
        for _ in range(samples_per_class):
            sample = base_patterns[label].clone()
            # Corrupt ~25% of the positions with random replacement tokens.
            noise_at = torch.rand(seq_len) < 0.25
            sample[noise_at] = torch.randint(0, vocab, (noise_at.sum(),))
            xs.append(sample)
            ys.append(label)

    x = torch.stack(xs)
    y = torch.tensor(ys)
    shuffle = torch.randperm(len(x))
    return x[shuffle], y[shuffle]
|
|
|
|
| |
|
|
def measure_all_cv(model, n_heads, x=None):
    """Measure CV on all relevant weight matrices and (optionally) activations.

    model: an MHAClassifier; x: optional (B, S) token batch — when given,
    activation-space CVs are appended to the result dict.
    """
    results = {}

    # Embedding table.
    results["emb"] = cv_metric(model.emb.weight.detach())

    # Full Q/K/V projection matrices.
    q_w, k_w, v_w = model.get_qkv_weights()
    results["Q_full"] = cv_metric(q_w)
    results["K_full"] = cv_metric(k_w)
    results["V_full"] = cv_metric(v_w)

    # Per-head slices of the projections.
    q_heads, k_heads, v_heads = model.get_per_head_projections(n_heads)
    for i, (qh, kh, vh) in enumerate(zip(q_heads, k_heads, v_heads)):
        results[f"Q_h{i}"] = cv_metric(qh)
        results[f"K_h{i}"] = cv_metric(kh)
        results[f"V_h{i}"] = cv_metric(vh)

    # Attention output projection and classifier head.
    results["out_proj"] = cv_metric(model.mha.out_proj.weight.detach())
    results["cls_head"] = cv_metric(model.head.weight.detach())

    # Activation-space CVs, if an input batch was provided.
    if x is not None:
        model.eval()
        for name, act in model.forward_activations(x, n_heads).items():
            results[name] = cv_metric(act)

    return results
|
|
|
|
def fmt_cv(cv):
    """Render a CV value, flagging values inside the 0.13-0.30 band with '*'."""
    if cv is None:
        return " N/A "
    in_band = 0.13 < cv < 0.30  # exclusive bounds on both ends
    edge = "*" if in_band else " "
    return edge + format(cv, ".4f") + edge
|
|
|
|
| |
|
|
def run_experiment(dim, n_heads, vocab=256, seq_len=8, n_classes=10, epochs=50, lr=1e-3):
    """Train one MHAClassifier config and print CV metrics pre/mid/post training.

    Returns a dict with the config, the three CV snapshots, and final accuracy.
    """
    head_dim = dim // n_heads
    print(f"\n{'='*70}")
    print(f"D={dim} heads={n_heads} head_dim={head_dim}")
    print(f"{'='*70}")

    x, y = make_data(n_classes=n_classes, seq_len=seq_len, vocab=vocab)
    model = MHAClassifier(vocab, dim, n_heads, seq_len, n_classes)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    # Snapshot CVs before any training.
    print("\n [pre-train]")
    pre_cv = measure_all_cv(model, n_heads, x)
    for name, value in pre_cv.items():
        print(f" {name:16s}: {fmt_cv(value)}")

    mid_cv = None
    for epoch in range(1, epochs + 1):
        model.train()
        opt.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        opt.step()

        # Mid-training snapshot at the halfway epoch.
        if epoch == epochs // 2:
            model.eval()
            with torch.no_grad():
                acc = (model(x).argmax(-1) == y).float().mean().item()
            mid_cv = measure_all_cv(model, n_heads, x)
            print(f"\n [epoch {epoch}] loss={loss.item():.4f} acc={acc:.2%}")
            for name, value in mid_cv.items():
                print(f" {name:16s}: {fmt_cv(value)}")

    # Final snapshot with per-metric movement vs. the pre-training CV.
    model.eval()
    with torch.no_grad():
        acc = (model(x).argmax(-1) == y).float().mean().item()
    print(f"\n [post-train] loss={loss.item():.4f} acc={acc:.2%}")
    post_cv = measure_all_cv(model, n_heads, x)
    for name, value in post_cv.items():
        before = pre_cv.get(name)
        # NOTE(review): "Ξ" below looks like a garbled "Δ" (delta) — confirm
        # the intended glyph; kept byte-identical here.
        if value is not None and before is not None:
            delta = f" Ξ={value - before:+.4f}"
        else:
            delta = ""
        print(f" {name:16s}: {fmt_cv(value)}{delta}")

    return {
        "dim": dim, "n_heads": n_heads, "head_dim": head_dim,
        "pre": pre_cv, "mid": mid_cv, "post": post_cv, "acc": acc,
    }
|
|
|
|
| |
|
|
| if __name__ == "__main__": |
| print("MHA CV Relational Test β Prototype") |
| print("Band: 0.13 < CV < 0.30") |
|
|
| configs = [ |
| |
| (64, 1), |
| (64, 2), |
| (64, 4), |
| (64, 8), |
| |
| (128, 1), |
| (128, 2), |
| (128, 4), |
| (128, 8), |
| |
| (256, 1), |
| (256, 2), |
| (256, 4), |
| (256, 8), |
| ] |
|
|
| all_results = [] |
| for dim, n_heads in configs: |
| r = run_experiment(dim, n_heads) |
| all_results.append(r) |
|
|
| |
| print(f"\n\n{'='*70}") |
| print("SUMMARY: Post-training WEIGHT CV by head_dim") |
| print(f"{'='*70}") |
| print(f"{'D':>5} {'heads':>5} {'hdim':>5} | {'emb':>8} {'Q_full':>8} {'K_full':>8} {'V_full':>8} {'out':>8} | acc") |
| print("-" * 80) |
| for r in all_results: |
| p = r["post"] |
| print(f"{r['dim']:5d} {r['n_heads']:5d} {r['head_dim']:5d} | " |
| f"{fmt_cv(p.get('emb')):>8} {fmt_cv(p.get('Q_full')):>8} " |
| f"{fmt_cv(p.get('K_full')):>8} {fmt_cv(p.get('V_full')):>8} " |
| f"{fmt_cv(p.get('out_proj')):>8} | {r['acc']:.2%}") |
|
|
| |
| print(f"\n\n{'='*70}") |
| print("SUMMARY: Post-training ACTIVATION CV by head_dim") |
| print("(These measure the space where attention actually operates)") |
| print(f"{'='*70}") |
| print(f"{'D':>5} {'heads':>5} {'hdim':>5} | {'act_emb':>8} {'aQ_h0':>8} {'aK_h0':>8} {'aV_h0':>8} {'aPost0':>8} {'act_full':>8} | acc") |
| print("-" * 90) |
| for r in all_results: |
| p = r["post"] |
| print(f"{r['dim']:5d} {r['n_heads']:5d} {r['head_dim']:5d} | " |
| f"{fmt_cv(p.get('act_emb')):>8} " |
| f"{fmt_cv(p.get('act_Q_h0')):>8} {fmt_cv(p.get('act_K_h0')):>8} " |
| f"{fmt_cv(p.get('act_V_h0')):>8} {fmt_cv(p.get('act_post_h0')):>8} " |
| f"{fmt_cv(p.get('act_post_full')):>8} | {r['acc']:.2%}") |
|
|
| |
| print(f"\n\n{'='*70}") |
| print("SUMMARY: ACTIVATION CV movement (post - pre)") |
| print(f"{'='*70}") |
| print(f"{'D':>5} {'heads':>5} {'hdim':>5} | {'act_emb':>8} {'aQ_h0':>8} {'aK_h0':>8} {'aV_h0':>8} {'aPost0':>8} {'act_full':>8}") |
| print("-" * 80) |
| for r in all_results: |
| pre, post = r["pre"], r["post"] |
| def delta(k): |
| a, b = pre.get(k), post.get(k) |
| if a is not None and b is not None: |
| d = b - a |
| return f"{d:+.4f}" |
| return " N/A " |
| print(f"{r['dim']:5d} {r['n_heads']:5d} {r['head_dim']:5d} | " |
| f"{delta('act_emb'):>8} " |
| f"{delta('act_Q_h0'):>8} {delta('act_K_h0'):>8} " |
| f"{delta('act_V_h0'):>8} {delta('act_post_h0'):>8} " |
| f"{delta('act_post_full'):>8}") |