"""
MHA CV Relational Test — Prototype

Train a minimal embedding + MHA + classifier on 10 noise patterns.
Measure CV on embedding weights, Q/K/V projections, and attention output
across different head counts per embedding dimension.

Hypothesis: head_dim (D / n_heads) determines CV of internal representations,
and the band-valid head_dims produce qualitatively different geometric behavior.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


# ── CM primitives ──

def cayley_menger_vol2(points):
    """Return squared simplex volumes from the Cayley-Menger determinant.

    points: (B, N, D) — B simplices with N vertices each.
    Returns a (B,) tensor of squared (N-1)-dimensional volumes. Values can be
    zero or slightly negative for degenerate / ill-conditioned simplices, so
    callers are expected to filter before taking a square root.
    """
    B, N, _ = points.shape
    gram = torch.bmm(points, points.transpose(1, 2))
    norms = torch.diagonal(gram, dim1=1, dim2=2)
    # Squared pairwise distances; relu clamps negative round-off to zero.
    d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)

    # Bordered distance matrix: first row/column of ones, zero corner.
    cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=points.dtype)
    cm[:, 0, 1:] = 1.0
    cm[:, 1:, 0] = 1.0
    cm[:, 1:, 1:] = d2

    k = N - 1
    sign = (-1.0) ** (k + 1)
    fact = math.factorial(k)
    # The determinant is evaluated in float32 and cast back to the input dtype.
    det = torch.linalg.det(cm.float()).to(points.dtype)
    return sign * det / ((2 ** k) * (fact ** 2))


def cv_metric(weight, n_samples=300):
    """CV of pentachoron volumes.

    weight: (N, D)

    Draws ``n_samples`` random 5-row subsets from the first ``min(N, 512)``
    rows, computes each subset's simplex volume, and returns std / mean of the
    volumes as a Python float. Returns None when there are fewer than 5 rows
    or fewer than 10 samples with a usable (positive) squared volume.
    Sampling uses the global torch RNG.
    """
    V, D = weight.shape
    if V < 5:
        return None
    pool = min(V, 512)
    indices = torch.stack([
        torch.randperm(pool, device=weight.device)[:5]
        for _ in range(n_samples)
    ])
    pts = weight[:pool][indices]
    vol2 = cayley_menger_vol2(pts)
    valid = vol2 > 1e-20
    if valid.sum() < 10:
        return None
    vols = vol2[valid].sqrt()
    return (vols.std() / (vols.mean() + 1e-8)).item()


# ── Minimal model ──

class MHAClassifier(nn.Module):
    """Embedding + learned positions + one self-attention layer + linear head."""

    def __init__(self, vocab, dim, n_heads, seq_len, n_classes):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.pos = nn.Parameter(torch.randn(1, seq_len, dim) * 0.02)
        self.mha = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, n_classes)

    def forward(self, x):
        """Return class logits for token indices ``x`` of shape (B, seq_len)."""
        h = self.emb(x) + self.pos
        attn_out, _ = self.mha(h, h, h)
        h = self.norm(h + attn_out)
        # pool over sequence
        h = h.mean(dim=1)
        return self.head(h)

    @torch.no_grad()
    def forward_activations(self, x, n_heads):
        """Forward pass returning per-head Q/K/V activations and post-attn output.

        Returns dict of (B*seq, head_dim) tensors for CV measurement, keyed
        ``act_{Q,K,V,post}_h{i}``, plus the full-width ``act_emb`` and
        ``act_post_full`` entries of shape (B*seq, D).
        """
        h = self.emb(x) + self.pos  # (B, seq, D)
        B, S, D = h.shape
        head_dim = D // n_heads

        # Manually compute Q, K, V from in_proj
        w = self.mha.in_proj_weight
        b = self.mha.in_proj_bias
        qkv = F.linear(h, w, b)  # (B, seq, 3*D)
        q, k, v = qkv.chunk(3, dim=-1)  # each (B, seq, D)

        # Reshape to per-head: (B, seq, n_heads, head_dim)
        q = q.view(B, S, n_heads, head_dim)
        k = k.view(B, S, n_heads, head_dim)
        v = v.view(B, S, n_heads, head_dim)

        # Compute attention output
        attn_out, _ = self.mha(h, h, h)
        post_attn = self.norm(h + attn_out)  # (B, seq, D)

        # Post-attn per head view
        post_heads = post_attn.view(B, S, n_heads, head_dim)

        acts = {}
        for i in range(n_heads):
            acts[f"act_Q_h{i}"] = q[:, :, i, :].reshape(-1, head_dim)
            acts[f"act_K_h{i}"] = k[:, :, i, :].reshape(-1, head_dim)
            acts[f"act_V_h{i}"] = v[:, :, i, :].reshape(-1, head_dim)
            acts[f"act_post_h{i}"] = post_heads[:, :, i, :].reshape(-1, head_dim)

        # Also full-dim activations
        acts["act_emb"] = h.reshape(-1, D)
        acts["act_post_full"] = post_attn.reshape(-1, D)
        return acts

    def get_qkv_weights(self):
        """Extract Q, K, V projection weight matrices."""
        # nn.MultiheadAttention packs Q, K, V into in_proj_weight: (3*dim, dim)
        w = self.mha.in_proj_weight.detach()
        d = w.shape[1]
        q_w = w[:d]        # (dim, dim)
        k_w = w[d:2 * d]   # (dim, dim)
        v_w = w[2 * d:]    # (dim, dim)
        return q_w, k_w, v_w

    def get_per_head_projections(self, n_heads):
        """Split Q/K/V weights into per-head chunks.

        Returns list of (head_dim, dim) per head.
        """
        q_w, k_w, v_w = self.get_qkv_weights()
        d = q_w.shape[0]
        head_dim = d // n_heads
        q_heads = [q_w[i * head_dim:(i + 1) * head_dim] for i in range(n_heads)]
        k_heads = [k_w[i * head_dim:(i + 1) * head_dim] for i in range(n_heads)]
        v_heads = [v_w[i * head_dim:(i + 1) * head_dim] for i in range(n_heads)]
        return q_heads, k_heads, v_heads


# ── Data: 10 noise patterns with perturbations ──

def make_data(n_classes=10, samples_per_class=50, seq_len=8, vocab=256):
    """Create simple classification data.

    Each class has a base token pattern with noise. Returns shuffled
    ``(x, y)`` with ``x`` of shape (n_classes * samples_per_class, seq_len)
    and ``y`` the matching class indices.

    NOTE: this reseeds the *global* torch RNG with 42, so everything drawn
    afterwards (model init, CV sampling) is also deterministic.
    """
    torch.manual_seed(42)
    # Base patterns: each class gets a fixed token sequence
    base_patterns = torch.randint(0, vocab, (n_classes, seq_len))
    all_x, all_y = [], []
    for cls in range(n_classes):
        for _ in range(samples_per_class):
            pattern = base_patterns[cls].clone()
            # Perturb ~25% of positions
            mask = torch.rand(seq_len) < 0.25
            pattern[mask] = torch.randint(0, vocab, (int(mask.sum()),))
            all_x.append(pattern)
            all_y.append(cls)
    x = torch.stack(all_x)
    y = torch.tensor(all_y)
    perm = torch.randperm(len(x))
    return x[perm], y[perm]


# ── CV measurement suite ──

def measure_all_cv(model, n_heads, x=None):
    """Measure CV on all relevant weight matrices and activations.

    Returns a dict mapping a measurement name to its CV (float) or None when
    the matrix is too small / degenerate for ``cv_metric``. Activation entries
    are only produced when ``x`` is given. The model's train/eval mode is
    restored before returning.
    """
    results = {}

    # Embedding weights
    results["emb"] = cv_metric(model.emb.weight.detach())

    # Full Q, K, V projection matrices (dim × dim)
    q_w, k_w, v_w = model.get_qkv_weights()
    results["Q_full"] = cv_metric(q_w)
    results["K_full"] = cv_metric(k_w)
    results["V_full"] = cv_metric(v_w)

    # Per-head projections (head_dim × dim) — CV measured on head_dim rows
    q_heads, k_heads, v_heads = model.get_per_head_projections(n_heads)
    for i in range(n_heads):
        results[f"Q_h{i}"] = cv_metric(q_heads[i])
        results[f"K_h{i}"] = cv_metric(k_heads[i])
        results[f"V_h{i}"] = cv_metric(v_heads[i])

    # Output projection
    results["out_proj"] = cv_metric(model.mha.out_proj.weight.detach())

    # Classifier head
    results["cls_head"] = cv_metric(model.head.weight.detach())

    # Activations — the space where attention actually operates
    if x is not None:
        # Measure in eval mode, but do not leak that mode change to the caller.
        was_training = model.training
        model.eval()
        try:
            acts = model.forward_activations(x, n_heads)
        finally:
            model.train(was_training)
        for name, tensor in acts.items():
            results[name] = cv_metric(tensor)

    return results


def fmt_cv(cv):
    """Format a CV value to 4 decimals; values inside the band get ``*`` marks."""
    if cv is None:
        return " N/A "
    band = "*" if 0.13 < cv < 0.30 else " "
    return f"{band}{cv:.4f}{band}"


# ── Training + measurement loop ──

def _evaluate(model, x, y):
    """Return (loss, accuracy) on (x, y), computed in eval mode without grads."""
    model.eval()
    with torch.no_grad():
        logits = model(x)
        loss = F.cross_entropy(logits, y).item()
        acc = (logits.argmax(-1) == y).float().mean().item()
    return loss, acc


def run_experiment(dim, n_heads, vocab=256, seq_len=8, n_classes=10,
                   epochs=50, lr=1e-3):
    """Train one (dim, n_heads) configuration and report CV before/during/after.

    Performs full-batch Adam training for ``epochs`` steps, measuring CV
    before training, at epoch ``epochs // 2`` and after training.

    Returns a dict with keys ``dim``, ``n_heads``, ``head_dim``, ``pre``,
    ``mid`` (None if the midpoint epoch is never reached), ``post`` and
    ``acc`` (final training-set accuracy).
    """
    head_dim = dim // n_heads
    print(f"\n{'='*70}")
    print(f"D={dim} heads={n_heads} head_dim={head_dim}")
    print(f"{'='*70}")

    x, y = make_data(n_classes=n_classes, seq_len=seq_len, vocab=vocab)
    model = MHAClassifier(vocab, dim, n_heads, seq_len, n_classes)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    # Pre-training CV
    print(f"\n [pre-train]")
    pre_cv = measure_all_cv(model, n_heads, x)
    for k, v in pre_cv.items():
        print(f" {k:16s}: {fmt_cv(v)}")

    # Training
    mid_cv = None
    for epoch in range(1, epochs + 1):
        model.train()
        opt.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        opt.step()

        if epoch == epochs // 2:
            # Report loss and accuracy from the same (post-step) weights.
            mid_loss, acc = _evaluate(model, x, y)
            mid_cv = measure_all_cv(model, n_heads, x)
            print(f"\n [epoch {epoch}] loss={mid_loss:.4f} acc={acc:.2%}")
            for k, v in mid_cv.items():
                print(f" {k:16s}: {fmt_cv(v)}")

    # Post-training CV. Evaluating here (instead of reusing the last training
    # loss) keeps loss/acc consistent and also works when epochs == 0.
    final_loss, acc = _evaluate(model, x, y)
    print(f"\n [post-train] loss={final_loss:.4f} acc={acc:.2%}")
    post_cv = measure_all_cv(model, n_heads, x)
    for k, v in post_cv.items():
        pre = pre_cv.get(k)
        delta = ""
        if v is not None and pre is not None:
            d = v - pre
            delta = f" Δ={d:+.4f}"
        print(f" {k:16s}: {fmt_cv(v)}{delta}")

    return {
        "dim": dim,
        "n_heads": n_heads,
        "head_dim": head_dim,
        "pre": pre_cv,
        "mid": mid_cv,
        "post": post_cv,
        "acc": acc,
    }


# ── Main ──

def _fmt_delta(pre, post, key):
    """Format post[key] - pre[key] with sign, or N/A if either side is missing."""
    a, b = pre.get(key), post.get(key)
    if a is not None and b is not None:
        d = b - a
        return f"{d:+.4f}"
    return " N/A "


def main():
    """Run the full (dim, n_heads) sweep and print the summary tables."""
    print("MHA CV Relational Test — Prototype")
    print("Band: 0.13 < CV < 0.30")

    configs = [
        # D=64: head_dims 64, 32, 16, 8
        (64, 1), (64, 2), (64, 4), (64, 8),
        # D=128: head_dims 128, 64, 32, 16
        (128, 1), (128, 2), (128, 4), (128, 8),
        # D=256: head_dims 256, 128, 64, 32
        (256, 1), (256, 2), (256, 4), (256, 8),
    ]

    all_results = [run_experiment(dim, n_heads) for dim, n_heads in configs]

    # Summary — Weights
    print(f"\n\n{'='*70}")
    print("SUMMARY: Post-training WEIGHT CV by head_dim")
    print(f"{'='*70}")
    print(f"{'D':>5} {'heads':>5} {'hdim':>5} | {'emb':>8} {'Q_full':>8} {'K_full':>8} {'V_full':>8} {'out':>8} | acc")
    print("-" * 80)
    for r in all_results:
        p = r["post"]
        print(f"{r['dim']:5d} {r['n_heads']:5d} {r['head_dim']:5d} | "
              f"{fmt_cv(p.get('emb')):>8} {fmt_cv(p.get('Q_full')):>8} "
              f"{fmt_cv(p.get('K_full')):>8} {fmt_cv(p.get('V_full')):>8} "
              f"{fmt_cv(p.get('out_proj')):>8} | {r['acc']:.2%}")

    # Summary — Activations (the real test)
    print(f"\n\n{'='*70}")
    print("SUMMARY: Post-training ACTIVATION CV by head_dim")
    print("(These measure the space where attention actually operates)")
    print(f"{'='*70}")
    print(f"{'D':>5} {'heads':>5} {'hdim':>5} | {'act_emb':>8} {'aQ_h0':>8} {'aK_h0':>8} {'aV_h0':>8} {'aPost0':>8} {'act_full':>8} | acc")
    print("-" * 90)
    for r in all_results:
        p = r["post"]
        print(f"{r['dim']:5d} {r['n_heads']:5d} {r['head_dim']:5d} | "
              f"{fmt_cv(p.get('act_emb')):>8} "
              f"{fmt_cv(p.get('act_Q_h0')):>8} {fmt_cv(p.get('act_K_h0')):>8} "
              f"{fmt_cv(p.get('act_V_h0')):>8} {fmt_cv(p.get('act_post_h0')):>8} "
              f"{fmt_cv(p.get('act_post_full')):>8} | {r['acc']:.2%}")

    # Summary — Activation CV delta (pre→post)
    print(f"\n\n{'='*70}")
    print("SUMMARY: ACTIVATION CV movement (post - pre)")
    print(f"{'='*70}")
    print(f"{'D':>5} {'heads':>5} {'hdim':>5} | {'act_emb':>8} {'aQ_h0':>8} {'aK_h0':>8} {'aV_h0':>8} {'aPost0':>8} {'act_full':>8}")
    print("-" * 80)
    for r in all_results:
        pre, post = r["pre"], r["post"]
        print(f"{r['dim']:5d} {r['n_heads']:5d} {r['head_dim']:5d} | "
              f"{_fmt_delta(pre, post, 'act_emb'):>8} "
              f"{_fmt_delta(pre, post, 'act_Q_h0'):>8} {_fmt_delta(pre, post, 'act_K_h0'):>8} "
              f"{_fmt_delta(pre, post, 'act_V_h0'):>8} {_fmt_delta(pre, post, 'act_post_h0'):>8} "
              f"{_fmt_delta(pre, post, 'act_post_full'):>8}")


if __name__ == "__main__":
    main()