geolip-deep-embedding-analysis / cv_sweep_mha_testing.py
AbstractPhil's picture
Create cv_sweep_mha_testing.py
fc29ddb verified
"""
MHA CV Relational Test — Prototype
Train a minimal embedding + MHA + classifier on 10 noise patterns.
Measure CV on embedding weights, Q/K/V projections, and attention output
across different head counts per embedding dimension.
Hypothesis: head_dim (D / n_heads) determines the CV of internal representations,
and head_dims whose CV falls inside the band (0.13 < CV < 0.30) produce
qualitatively different geometric behavior.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# ── CM primitives ──
def cayley_menger_vol2(points):
    """Squared volume of each simplex in a batch via the Cayley–Menger determinant.

    Args:
        points: tensor of shape (B, N, D) — B simplices, each spanned by N
            points in D-dimensional space (a simplex of dimension k = N - 1).

    Returns:
        Tensor of shape (B,) holding the squared k-volume of each simplex, in
        the dtype of ``points``. Degenerate simplices give values near zero
        (tiny negative values are possible from round-off).
    """
    B, N, _ = points.shape
    # Pairwise squared distances from the Gram matrix: |a-b|^2 = |a|^2 + |b|^2 - 2 a.b
    gram = torch.bmm(points, points.transpose(1, 2))
    norms = torch.diagonal(gram, dim1=1, dim2=2)
    # relu clamps small negative values produced by floating-point cancellation.
    d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)

    # Bordered distance matrix: first row/column are ones, corner is zero.
    cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=points.dtype)
    cm[:, 0, 1:] = 1.0
    cm[:, 1:, 0] = 1.0
    cm[:, 1:, 1:] = d2

    k = N - 1
    sign = (-1.0) ** (k + 1)
    fact = math.factorial(k)
    # Upcast low-precision inputs (fp16/bf16) to float32 for the determinant,
    # but never downcast: float64 inputs keep full float64 precision.
    det_dtype = torch.promote_types(cm.dtype, torch.float32)
    det = torch.linalg.det(cm.to(det_dtype)).to(points.dtype)
    return sign * det / ((2 ** k) * (fact ** 2))
def cv_metric(weight, n_samples=300, max_rows=512):
    """Coefficient of variation (std / mean) of random pentachoron volumes.

    Draws ``n_samples`` random 5-point subsets from the rows of ``weight``,
    treats each subset as a 4-simplex (pentachoron) and returns the CV of the
    resulting volumes.

    Args:
        weight: 2-D tensor of shape (V, D); each row is one point.
        n_samples: number of random 5-point subsets to draw.
        max_rows: subsets are drawn only from the first ``min(V, max_rows)``
            rows (default 512).

    Returns:
        The CV as a Python float, or None when there are fewer than 5 rows or
        fewer than 10 non-degenerate simplices.
    """
    n_rows, _ = weight.shape  # unpacking enforces a 2-D input
    if n_rows < 5:
        return None
    pool = min(n_rows, max_rows)
    # Each row of argsort(rand) is a uniform random permutation; its first 5
    # entries form a uniform 5-subset without replacement. This is the same
    # distribution as randperm(pool)[:5], but drawn in one batched call.
    indices = torch.rand(n_samples, pool, device=weight.device).argsort(dim=1)[:, :5]
    pts = weight[:pool][indices]  # (n_samples, 5, D)
    vol2 = cayley_menger_vol2(pts)
    # Drop degenerate (or round-off negative) simplices before the sqrt.
    valid = vol2 > 1e-20
    if valid.sum() < 10:
        return None
    vols = vol2[valid].sqrt()
    return (vols.std() / (vols.mean() + 1e-8)).item()
# ── Minimal model ──
class MHAClassifier(nn.Module):
    """Minimal sequence classifier used to probe CV geometry.

    Pipeline: token embedding + learned positional embedding -> one
    ``nn.MultiheadAttention`` self-attention layer with a residual connection
    and LayerNorm -> mean over the sequence -> linear classification head.

    Args:
        vocab: number of distinct token ids.
        dim: model width D, split across ``n_heads`` attention heads.
        n_heads: number of attention heads.
        seq_len: length of the learned positional table (longest sequence).
        n_classes: number of output logits.
    """

    def __init__(self, vocab, dim, n_heads, seq_len, n_classes):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.pos = nn.Parameter(torch.randn(1, seq_len, dim) * 0.02)
        self.mha = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, n_classes)

    def _embed(self, x):
        """Token + positional embeddings: (B, S) ids -> (B, S, D), for S <= seq_len."""
        return self.emb(x) + self.pos[:, : x.shape[1]]

    def _attend(self, h):
        """Self-attention, residual add and LayerNorm: (B, S, D) -> (B, S, D)."""
        # The attention map is never used, so skip computing it.
        attn_out, _ = self.mha(h, h, h, need_weights=False)
        return self.norm(h + attn_out)

    def forward(self, x):
        """Return class logits of shape (B, n_classes) for token ids x of shape (B, S)."""
        h = self._attend(self._embed(x))
        return self.head(h.mean(dim=1))  # mean-pool over the sequence

    @torch.no_grad()
    def forward_activations(self, x, n_heads=None):
        """Recompute the forward pass and return intermediate activations.

        Args:
            x: (B, S) token ids.
            n_heads: number of equal channel chunks to split D into; defaults
                to the attention layer's own head count.

        Returns:
            Dict of 2-D tensors, inserted in this order:
              - for each head i: ``act_Q_h{i}``, ``act_K_h{i}``, ``act_V_h{i}``
                (head i's slice of the input projections) and
                ``act_post_h{i}``, each of shape (B*S, head_dim);
              - ``act_emb``: token + positional embeddings, (B*S, D);
              - ``act_post_full``: LayerNorm(h + attention output), (B*S, D).

        Note:
            ``act_post_h{i}`` is the i-th contiguous channel slice of the
            post-LayerNorm output. The attention output has already passed
            through ``out_proj``, which mixes heads, so this slice is not
            head i's own attention output.
        """
        if n_heads is None:
            n_heads = self.mha.num_heads
        h = self._embed(x)  # (B, S, D)
        B, S, D = h.shape
        head_dim = D // n_heads

        # Recompute the packed Q/K/V input projection by hand to expose it.
        qkv = F.linear(h, self.mha.in_proj_weight, self.mha.in_proj_bias)  # (B, S, 3*D)
        q, k, v = (t.view(B, S, n_heads, head_dim) for t in qkv.chunk(3, dim=-1))

        post_attn = self._attend(h)  # (B, S, D)
        post_heads = post_attn.view(B, S, n_heads, head_dim)

        acts = {}
        for i in range(n_heads):
            acts[f"act_Q_h{i}"] = q[:, :, i, :].reshape(-1, head_dim)
            acts[f"act_K_h{i}"] = k[:, :, i, :].reshape(-1, head_dim)
            acts[f"act_V_h{i}"] = v[:, :, i, :].reshape(-1, head_dim)
            acts[f"act_post_h{i}"] = post_heads[:, :, i, :].reshape(-1, head_dim)
        # Full-width activations.
        acts["act_emb"] = h.reshape(-1, D)
        acts["act_post_full"] = post_attn.reshape(-1, D)
        return acts

    def get_qkv_weights(self):
        """Return detached (Q, K, V) projection weights, each of shape (D, D).

        ``nn.MultiheadAttention`` packs them row-wise, in Q, K, V order, into
        ``in_proj_weight`` of shape (3*D, D).
        """
        q_w, k_w, v_w = self.mha.in_proj_weight.detach().chunk(3, dim=0)
        return q_w, k_w, v_w

    def get_per_head_projections(self, n_heads=None):
        """Split the Q/K/V weights into per-head row blocks.

        Args:
            n_heads: number of heads to split into; defaults to the attention
                layer's own head count.

        Returns:
            Tuple ``(q_heads, k_heads, v_heads)``. Each is a list of n_heads
            tensors of shape (head_dim, D).
        """
        if n_heads is None:
            n_heads = self.mha.num_heads
        q_w, k_w, v_w = self.get_qkv_weights()
        head_dim = q_w.shape[0] // n_heads

        def split(w):
            return [w[i * head_dim:(i + 1) * head_dim] for i in range(n_heads)]

        return split(q_w), split(k_w), split(v_w)
# ── Data: 10 noise patterns with perturbations ──
def make_data(n_classes=10, samples_per_class=50, seq_len=8, vocab=256, seed=42, perturb_prob=0.25):
    """Create a toy sequence-classification dataset.

    Each class gets a fixed random base token sequence. Every sample is a copy
    of its class's base pattern in which each position is independently
    resampled (uniformly from ``[0, vocab)``) with probability
    ``perturb_prob``. A resampled token may coincide with the original one.

    Note:
        When ``seed`` is not None this reseeds the *global* torch RNG, so any
        random draws made after this call are reproducible as well. Pass
        ``seed=None`` to leave the global RNG untouched.

    Returns:
        ``(x, y)``: x is an int64 tensor of shape
        (n_classes * samples_per_class, seq_len) holding token ids; y holds the
        matching int64 class labels. Rows are shuffled.
    """
    if seed is not None:
        torch.manual_seed(seed)
    # Base patterns: each class gets a fixed token sequence.
    base_patterns = torch.randint(0, vocab, (n_classes, seq_len))
    all_x, all_y = [], []
    for cls in range(n_classes):
        for _ in range(samples_per_class):
            pattern = base_patterns[cls].clone()
            mask = torch.rand(seq_len) < perturb_prob
            pattern[mask] = torch.randint(0, vocab, (int(mask.sum()),))
            all_x.append(pattern)
            all_y.append(cls)
    if not all_x:
        # torch.stack([]) would raise; return well-typed empty tensors instead.
        return torch.empty(0, seq_len, dtype=torch.long), torch.empty(0, dtype=torch.long)
    x = torch.stack(all_x)
    y = torch.tensor(all_y)
    perm = torch.randperm(len(x))
    return x[perm], y[perm]
# ── CV measurement suite ──
def measure_all_cv(model, n_heads, x=None):
    """Measure CV on the model's weight matrices and, optionally, its activations.

    Args:
        model: an ``MHAClassifier``. Uses its ``emb``, ``mha`` and ``head``
            modules and its ``get_qkv_weights``, ``get_per_head_projections``
            and ``forward_activations`` helpers.
        n_heads: number of heads used to split the Q/K/V rows and activations.
        x: optional (B, S) batch of token ids. When given, activation CVs are
            measured too, with the model temporarily switched to eval mode.
            The caller's train/eval mode is restored afterwards.

    Returns:
        Dict mapping a measurement name to the result of ``cv_metric`` (a
        float, or None when it could not be computed). Weight entries come
        first, followed by activation entries.
    """
    results = {}

    # Embedding table: one point per token id.
    results["emb"] = cv_metric(model.emb.weight.detach())

    # Full Q, K, V projection matrices (D × D).
    q_w, k_w, v_w = model.get_qkv_weights()
    results["Q_full"] = cv_metric(q_w)
    results["K_full"] = cv_metric(k_w)
    results["V_full"] = cv_metric(v_w)

    # Per-head projections (head_dim × D) — CV measured over the head_dim rows.
    per_head = zip(*model.get_per_head_projections(n_heads))
    for i, (q_h, k_h, v_h) in enumerate(per_head):
        results[f"Q_h{i}"] = cv_metric(q_h)
        results[f"K_h{i}"] = cv_metric(k_h)
        results[f"V_h{i}"] = cv_metric(v_h)

    results["out_proj"] = cv_metric(model.mha.out_proj.weight.detach())
    results["cls_head"] = cv_metric(model.head.weight.detach())

    # Activations — the space where attention actually operates.
    if x is not None:
        was_training = model.training
        model.eval()
        try:
            acts = model.forward_activations(x, n_heads)
        finally:
            # Restore the caller's mode instead of leaking eval mode.
            model.train(was_training)
        for name, tensor in acts.items():
            results[name] = cv_metric(tensor)

    return results
def fmt_cv(cv):
    """Render a CV value for the report tables.

    None becomes the " N/A " placeholder. Any other value is printed with four
    decimals and wrapped in '*' markers when it lies strictly inside the band
    (0.13, 0.30), or in single spaces otherwise.
    """
    if cv is None:
        return " N/A "
    in_band = 0.13 < cv < 0.30
    marker = "*" if in_band else " "
    return "".join((marker, format(cv, ".4f"), marker))
# ── Training + measurement loop ──
def _print_cv(cv, ref=None):
    """Print one line per CV entry; with ``ref``, append the change relative to it."""
    for k, v in cv.items():
        delta = ""
        if ref is not None:
            pre = ref.get(k)
            if v is not None and pre is not None:
                d = v - pre
                delta = f" Δ={d:+.4f}"
        print(f" {k:16s}: {fmt_cv(v)}{delta}")


def run_experiment(dim, n_heads, vocab=256, seq_len=8, n_classes=10, epochs=50, lr=1e-3):
    """Train one MHAClassifier configuration and report CV before, midway and after.

    The model is trained full-batch with Adam for ``epochs`` steps on the toy
    dataset from ``make_data``. CV is measured (and printed) before training,
    at epoch ``epochs // 2``, and after training, together with the post-minus-pre
    change per entry.

    Returns:
        Dict with keys "dim", "n_heads", "head_dim", "pre", "mid" (None if the
        midpoint epoch was never reached), "post", "acc" (final training-set
        accuracy) and "loss" (final training-set cross-entropy, eval mode).
    """
    head_dim = dim // n_heads
    print(f"\n{'='*70}")
    print(f"D={dim} heads={n_heads} head_dim={head_dim}")
    print(f"{'='*70}")

    x, y = make_data(n_classes=n_classes, seq_len=seq_len, vocab=vocab)
    model = MHAClassifier(vocab, dim, n_heads, seq_len, n_classes)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    # Pre-training CV
    print("\n [pre-train]")
    pre_cv = measure_all_cv(model, n_heads, x)
    _print_cv(pre_cv)

    # Training (full batch)
    mid_cv = None
    for epoch in range(1, epochs + 1):
        model.train()
        opt.zero_grad()
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        opt.step()
        if epoch == epochs // 2:
            model.eval()
            with torch.no_grad():
                acc = (model(x).argmax(-1) == y).float().mean().item()
            mid_cv = measure_all_cv(model, n_heads, x)
            print(f"\n [epoch {epoch}] loss={loss.item():.4f} acc={acc:.2%}")
            _print_cv(mid_cv)

    # Post-training: evaluate the final weights. The last training loss was
    # computed before the final optimizer step, so it is recomputed here
    # (this also works when epochs == 0).
    model.eval()
    with torch.no_grad():
        logits = model(x)
        final_loss = F.cross_entropy(logits, y).item()
        acc = (logits.argmax(-1) == y).float().mean().item()
    print(f"\n [post-train] loss={final_loss:.4f} acc={acc:.2%}")
    post_cv = measure_all_cv(model, n_heads, x)
    _print_cv(post_cv, ref=pre_cv)

    return {
        "dim": dim, "n_heads": n_heads, "head_dim": head_dim,
        "pre": pre_cv, "mid": mid_cv, "post": post_cv, "acc": acc,
        "loss": final_loss,
    }
# ── Main ──
# Column keys shown in the summary tables (in display order).
_WEIGHT_KEYS = ("emb", "Q_full", "K_full", "V_full", "out_proj")
_ACT_KEYS = ("act_emb", "act_Q_h0", "act_K_h0", "act_V_h0", "act_post_h0", "act_post_full")


def _row_prefix(r):
    """Fixed-width 'D heads hdim |' prefix shared by every summary row."""
    return f"{r['dim']:5d} {r['n_heads']:5d} {r['head_dim']:5d} | "


def _cv_delta(pre, post, key):
    """Signed post - pre CV change for ``key``, or " N/A " if either side is missing."""
    a, b = pre.get(key), post.get(key)
    if a is None or b is None:
        return " N/A "
    return f"{b - a:+.4f}"


if __name__ == "__main__":
    print("MHA CV Relational Test — Prototype")
    print("Band: 0.13 < CV < 0.30")

    # Sweep 1, 2, 4 and 8 heads per width, so head_dim = D / n_heads covers
    # D=64 -> 64..8, D=128 -> 128..16, D=256 -> 256..32.
    configs = [(dim, n_heads) for dim in (64, 128, 256) for n_heads in (1, 2, 4, 8)]

    all_results = []
    for dim, n_heads in configs:
        all_results.append(run_experiment(dim, n_heads))

    # Summary — Weights
    print(f"\n\n{'='*70}")
    print("SUMMARY: Post-training WEIGHT CV by head_dim")
    print(f"{'='*70}")
    print(f"{'D':>5} {'heads':>5} {'hdim':>5} | {'emb':>8} {'Q_full':>8} {'K_full':>8} {'V_full':>8} {'out':>8} | acc")
    print("-" * 80)
    for r in all_results:
        cells = " ".join(f"{fmt_cv(r['post'].get(k)):>8}" for k in _WEIGHT_KEYS)
        print(f"{_row_prefix(r)}{cells} | {r['acc']:.2%}")

    # Summary — Activations (the real test)
    print(f"\n\n{'='*70}")
    print("SUMMARY: Post-training ACTIVATION CV by head_dim")
    print("(These measure the space where attention actually operates)")
    print(f"{'='*70}")
    print(f"{'D':>5} {'heads':>5} {'hdim':>5} | {'act_emb':>8} {'aQ_h0':>8} {'aK_h0':>8} {'aV_h0':>8} {'aPost0':>8} {'act_full':>8} | acc")
    print("-" * 90)
    for r in all_results:
        cells = " ".join(f"{fmt_cv(r['post'].get(k)):>8}" for k in _ACT_KEYS)
        print(f"{_row_prefix(r)}{cells} | {r['acc']:.2%}")

    # Summary — Activation CV delta (pre→post)
    print(f"\n\n{'='*70}")
    print("SUMMARY: ACTIVATION CV movement (post - pre)")
    print(f"{'='*70}")
    print(f"{'D':>5} {'heads':>5} {'hdim':>5} | {'act_emb':>8} {'aQ_h0':>8} {'aK_h0':>8} {'aV_h0':>8} {'aPost0':>8} {'act_full':>8}")
    print("-" * 80)
    for r in all_results:
        cells = " ".join(f"{_cv_delta(r['pre'], r['post'], k):>8}" for k in _ACT_KEYS)
        print(f"{_row_prefix(r)}{cells}")