# geolip-constellation-core / constelation_v11_cross_token_preservation.py
# (Hugging Face file-page residue captured with the source — "AbstractPhil's
#  picture", "Update constelation_v11_cross_token_preservation.py",
#  "786b9ac verified" — commented out so the module parses.)
#!/usr/bin/env python3
"""
Constellation Core V11
Hybrid Constellation Relay v2
================================
Fixes from v1:
- Split gates: fixed_gate (cold, -3.0) + dynamic_gate (warm, -1.0)
- Balanced: 8 fixed + 8 dynamic per patch
- Separate dynamic MLP before merge
- Proper causal intervention test for cross-token routing
- V-projection: dynamic anchors carry value information, not just position
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import time
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
HAS_FP8 = hasattr(torch, 'float8_e4m3fn')
def compute_cv(points, n_samples=1500, n_points=5, device=None):
    """Coefficient of variation (std/mean) of random simplex volumes.

    Draws `n_samples` random `n_points`-subsets of `points` (sampling is
    capped to the first 10k rows), computes each subset's 4-simplex volume
    via the Cayley-Menger determinant, and returns the CV of the volumes.

    Args:
        points: (N, d) tensor; rows are L2-normalized before sampling.
        n_samples: number of random simplices to draw.
        n_points: points per simplex. NOTE: the Cayley-Menger matrix below
            is sized for exactly 5 points (6x6); other values will break.
        device: torch device for the computation; defaults to the module
            global DEVICE when None (backward compatible).

    Returns:
        CV as a float, or NaN when N < n_points or fewer than 50 volumes
        exceed the numerical floor.
    """
    if device is None:
        device = DEVICE
    N = points.shape[0]
    if N < n_points:
        return float('nan')
    points = F.normalize(points.to(device).float(), dim=-1)
    vols = []
    for _ in range(n_samples):
        # Sample n_points distinct indices from (at most) the first 10k rows.
        idx = torch.randperm(min(N, 10000), device=device)[:n_points]
        pts = points[idx].unsqueeze(0)
        gram = torch.bmm(pts, pts.transpose(1, 2))
        norms = torch.diagonal(gram, dim1=1, dim2=2)
        # Squared pairwise distances from the Gram matrix.
        d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
        d2 = F.relu(d2)  # clamp tiny negatives caused by rounding
        # Bordered Cayley-Menger matrix for 5 points (6x6).
        cm = torch.zeros(1, 6, 6, device=device, dtype=torch.float32)
        cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
        # For a 4-simplex: V^2 = -det(CM) / (2^4 * (4!)^2) = -det / 9216.
        v2 = -torch.linalg.det(cm) / 9216
        if v2[0].item() > 1e-20:
            vols.append(v2[0].sqrt().cpu())
    if len(vols) < 50:
        return float('nan')
    vt = torch.stack(vols)
    return (vt.std() / (vt.mean() + 1e-8)).item()
def eff_dim(x):
    """Participation-ratio effective dimensionality of the rows of x.

    Centers the rows with the full-data mean, takes singular values of (at
    most) the first 512 centered rows, and returns 1 / sum(p_i^2) where the
    p_i are the singular values normalized to sum to one.
    """
    centered = (x - x.mean(0, keepdim=True))[:512].float()
    sigma = torch.linalg.svd(centered, full_matrices=False)[1]
    weights = sigma / sigma.sum()
    return weights.pow(2).sum().reciprocal().item()
def uniform_sphere(n, d):
    """Return n points drawn uniformly on the unit (d-1)-sphere.

    Uses the standard trick: i.i.d. Gaussian samples projected onto the
    sphere are uniformly distributed.
    """
    samples = torch.randn(n, d)
    return F.normalize(samples, dim=-1)
# ══════════════════════════════════════════════════════════════════
# HYBRID CONSTELLATION RELAY v2
# ══════════════════════════════════════════════════════════════════
class HybridRelay(nn.Module):
    """Fixed geometric anchors + dynamic cross-token anchors.

    Two processing paths with separate learned gates, applied independently
    to each patch of width `patch_dim`:

      Fixed path:   cosine triangulation against `n_fixed` slerp-interpolated
                    anchors at `n_phases` phases -> per-patch MLP -> fixed_out
      Dynamic path: top-k Q.K token selection -> gather V -> weighted,
                    concatenated values -> per-patch MLP -> dyn_out

    Per-patch output (added residually to the input):
      fixed_gate * fixed_out + dyn_gate * dyn_out
        + clamp(1 - fixed_gate - dyn_gate, 0) * identity
    """
    def __init__(
        self,
        input_dim,
        patch_dim=16,
        n_fixed=8,
        n_dynamic=8,
        n_phases=3,
        pw_hidden=32,
        fixed_gate_init=-3.0,  # sigmoid ≈ 0.047 — fixed path starts nearly closed
        dyn_gate_init=-1.0,    # sigmoid ≈ 0.269 — dynamic path starts warmer
    ):
        super().__init__()
        assert input_dim % patch_dim == 0
        self.input_dim = input_dim
        self.patch_dim = patch_dim
        self.n_patches = input_dim // patch_dim
        self.n_fixed = n_fixed
        self.n_dynamic = n_dynamic
        self.n_phases = n_phases
        P, Af, k, d = self.n_patches, n_fixed, n_dynamic, patch_dim
        # ── Fixed constellation ──
        # 'home' is the frozen reference constellation (buffer); 'anchors' is
        # the trainable copy that may drift away from it during training.
        home = torch.empty(P, Af, d)
        nn.init.xavier_normal_(home.view(P * Af, d))
        home = F.normalize(home.view(P, Af, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())
        # Fixed path MLP: (phases * Af) → d, one weight slab per patch
        fixed_tri_dim = n_phases * Af
        self.fixed_w1 = nn.Parameter(torch.empty(P, fixed_tri_dim, pw_hidden))
        self.fixed_b1 = nn.Parameter(torch.zeros(1, 1, P, pw_hidden))
        self.fixed_w2 = nn.Parameter(torch.empty(P, pw_hidden, d))
        self.fixed_b2 = nn.Parameter(torch.zeros(1, 1, P, d))
        for p in range(P):
            nn.init.xavier_normal_(self.fixed_w1.data[p])
            nn.init.xavier_normal_(self.fixed_w2.data[p])
        self.fixed_norm = nn.LayerNorm(d)
        # ── Dynamic cross-token path ──
        # Q, K for selection; V for information transfer
        self.q_proj = nn.Parameter(torch.empty(P, d, d))
        self.k_proj = nn.Parameter(torch.empty(P, d, d))
        self.v_proj = nn.Parameter(torch.empty(P, d, d))
        for p in range(P):
            nn.init.xavier_normal_(self.q_proj.data[p])
            nn.init.xavier_normal_(self.k_proj.data[p])
            nn.init.xavier_normal_(self.v_proj.data[p])
        # Dynamic path MLP: (k * d) → d — reads the k gathered V vectors
        # concatenated along the feature axis, not summed.
        dyn_input_dim = k * d
        self.dyn_w1 = nn.Parameter(torch.empty(P, dyn_input_dim, pw_hidden))
        self.dyn_b1 = nn.Parameter(torch.zeros(1, 1, P, pw_hidden))
        self.dyn_w2 = nn.Parameter(torch.empty(P, pw_hidden, d))
        self.dyn_b2 = nn.Parameter(torch.zeros(1, 1, P, d))
        for p in range(P):
            nn.init.xavier_normal_(self.dyn_w1.data[p])
            nn.init.xavier_normal_(self.dyn_w2.data[p])
        self.dyn_norm = nn.LayerNorm(d)
        # ── Split gates ── one scalar logit per patch, per path
        self.fixed_gate = nn.Parameter(torch.full((P,), fixed_gate_init))
        self.dyn_gate = nn.Parameter(torch.full((P,), dyn_gate_init))
        self.norm = nn.LayerNorm(input_dim)

    def drift(self):
        """Per-anchor angle (radians) between current anchors and 'home'."""
        h = F.normalize(self.home, dim=-1)
        c = F.normalize(self.anchors, dim=-1)
        # Clamp keeps acos finite/differentiable at the ±1 boundaries.
        cos = (h * c).sum(dim=-1).clamp(-1 + 1e-7, 1 - 1e-7)
        return torch.acos(cos)

    def at_phase(self, t):
        """Slerp between home (t=0) and current anchors (t=1).

        The clamp in drift() keeps omega bounded away from zero, so the
        sin ratio degrades gracefully toward a lerp when anchors ≈ home.
        """
        h = F.normalize(self.home, dim=-1)
        c = F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        sin_omega = omega.sin().clamp(min=1e-7)
        return (torch.sin((1 - t) * omega) / sin_omega * h +
                torch.sin(t * omega) / sin_omega * c)

    def forward(self, x, return_diagnostics=False):
        """x: (B, S, D); returns (B, S, D), plus a diagnostics dict if asked."""
        B, S, D = x.shape
        P, Af, k, d = self.n_patches, self.n_fixed, self.n_dynamic, self.patch_dim
        x_n = self.norm(x)
        patches = x_n.reshape(B, S, P, d)
        patches_n = F.normalize(patches, dim=-1)
        # ══════ FIXED PATH ══════
        phases = torch.linspace(0, 1, self.n_phases).tolist()
        fixed_tris = []
        for t in phases:
            anchors_t = F.normalize(self.at_phase(t), dim=-1)  # (P, Af, d)
            cos_f = torch.einsum('bspd,pad->bspa', patches_n, anchors_t)
            fixed_tris.append(1.0 - cos_f)  # distance-like feature in [0, 2]
        fixed_tri = torch.cat(fixed_tris, dim=-1)  # (B, S, P, Af*phases)
        h_f = torch.einsum('bspt,pth->bsph', fixed_tri, self.fixed_w1) + self.fixed_b1
        h_f = F.gelu(h_f)
        fixed_out = torch.einsum('bsph,phd->bspd', h_f, self.fixed_w2) + self.fixed_b2
        fixed_out = self.fixed_norm(fixed_out)  # (B, S, P, d)
        # ══════ DYNAMIC PATH ══════
        # Q, K, V projections
        Q = F.normalize(torch.einsum('bspd,pde->bspe', patches_n, self.q_proj), dim=-1)
        K = F.normalize(torch.einsum('bspd,pde->bspe', patches_n, self.k_proj), dim=-1)
        V = torch.einsum('bspd,pde->bspe', patches, self.v_proj)  # V not normalized — carries magnitude
        # Relevance: Q_i · K_j → (B, P, S, S)
        relevance = torch.einsum('bspd,btpd->bpst', Q, K)
        # Mask self so a token cannot select itself
        self_mask = torch.eye(S, device=x.device, dtype=torch.bool)
        relevance = relevance.masked_fill(self_mask.unsqueeze(0).unsqueeze(0), -1e9)
        # Soft top-k: take softmax over keys, then gather top-k.
        # This makes gradients flow through the selection weights.
        rel_weights = relevance.softmax(dim=-1)  # (B, P, S, S)
        # Top-k indices for sparse gather
        _, topk_idx = relevance.topk(k, dim=-1)  # (B, P, S, k)
        # Gather top-k weights and re-normalize so they sum to 1
        topk_weights = torch.gather(rel_weights, -1, topk_idx)  # (B, P, S, k)
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-8)
        # Gather top-k V vectors: V is (B, S, P, d) → need (B, P, S, d)
        V_perm = V.permute(0, 2, 1, 3)  # (B, P, S, d)
        # For each (b, p, s), gather V[b, p, topk_idx[b,p,s,:], :]
        topk_idx_v = topk_idx.unsqueeze(-1).expand(-1, -1, -1, -1, d)  # (B, P, S, k, d)
        V_expanded = V_perm.unsqueeze(2).expand(-1, -1, S, -1, -1)  # (B, P, S, S, d)
        topk_V = torch.gather(V_expanded, 3, topk_idx_v)  # (B, P, S, k, d)
        # Weight each selected value, then CONCATENATE the k weighted vectors
        # along the feature axis (reshape to k*d, matching dyn_input_dim) —
        # the dynamic MLP sees all k values, not their sum.
        weighted_V = (topk_weights.unsqueeze(-1) * topk_V).reshape(B, P, S, k * d)
        # → (B, S, P, k*d)
        weighted_V = weighted_V.permute(0, 2, 1, 3)
        # Dynamic MLP
        h_d = torch.einsum('bspt,pth->bsph', weighted_V, self.dyn_w1) + self.dyn_b1
        h_d = F.gelu(h_d)
        dyn_out = torch.einsum('bsph,phd->bspd', h_d, self.dyn_w2) + self.dyn_b2
        dyn_out = self.dyn_norm(dyn_out)  # (B, S, P, d)
        # ══════ GATED MERGE ══════
        fg = self.fixed_gate.sigmoid().view(1, 1, P, 1)
        dg = self.dyn_gate.sigmoid().view(1, 1, P, 1)
        # Identity weight = 1 - fg - dg (can go negative if both gates high, but sigmoid caps each at 1)
        identity_weight = (1.0 - fg - dg).clamp(min=0.0)
        blended = fg * fixed_out + dg * dyn_out + identity_weight * patches
        out = blended.reshape(B, S, D)
        result = x + out
        if return_diagnostics:
            drift = self.drift()
            diag = {
                'drift_mean': drift.mean().item(),
                'fixed_gate': self.fixed_gate.sigmoid().mean().item(),
                'dyn_gate': self.dyn_gate.sigmoid().mean().item(),
                'identity_weight': identity_weight.mean().item(),
                # Raw (masked) relevance scores at the selected indices.
                'topk_cos_mean': torch.gather(relevance, -1, topk_idx).mean().item(),
                'topk_cos_max': torch.gather(relevance, -1, topk_idx).max().item(),
            }
            return result, diag
        return result
# ══════════════════════════════════════════════════════════════════
# COMPARISON MODULES
# ══════════════════════════════════════════════════════════════════
class VanillaAttn(nn.Module):
    """Standard pre-norm multi-head self-attention with a residual skip.

    No biases, no dropout, no masking — used as the baseline against the
    relay modules.
    """
    def __init__(self, dim, n_heads=4):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """x: (B, S, D) -> (B, S, D), residual added."""
        B, S, D = x.shape
        normed = self.norm(x)
        # Fused projection, split into (3, B, H, S, head_dim).
        proj = self.qkv(normed).reshape(B, S, 3, self.n_heads, self.head_dim)
        proj = proj.permute(2, 0, 3, 1, 4)
        q, k, v = proj[0], proj[1], proj[2]
        scores = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        weights = scores.softmax(dim=-1)
        context = (weights @ v).transpose(1, 2).reshape(B, S, D)
        return x + self.out_proj(context)
class PureRelay(nn.Module):
    """Fixed-constellation relay baseline (no cross-token path).

    Each patch is triangulated against a learned anchor constellation at
    several slerp phases between the frozen 'home' positions and the
    current (trainable) anchors; a per-patch MLP maps the triangulation
    features back to patch space, blended with the input through a
    per-patch sigmoid gate and a residual skip.
    """
    def __init__(self, input_dim, patch_dim=16, n_anchors=16, n_phases=3,
                 pw_hidden=32, gate_init=-3.0):
        super().__init__()
        assert input_dim % patch_dim == 0
        n_patches = input_dim // patch_dim
        self.input_dim, self.patch_dim, self.n_patches = input_dim, patch_dim, n_patches
        self.n_anchors, self.n_phases = n_anchors, n_phases
        # Frozen reference constellation (buffer) + trainable copy.
        home = torch.empty(n_patches, n_anchors, patch_dim)
        nn.init.xavier_normal_(home.view(n_patches * n_anchors, patch_dim))
        home = F.normalize(home.view(n_patches, n_anchors, patch_dim), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())
        # Per-patch two-layer MLP over (phases * anchors) features.
        tri_dim = n_phases * n_anchors
        self.pw_w1 = nn.Parameter(torch.empty(n_patches, tri_dim, pw_hidden))
        self.pw_b1 = nn.Parameter(torch.zeros(1, 1, n_patches, pw_hidden))
        self.pw_w2 = nn.Parameter(torch.empty(n_patches, pw_hidden, patch_dim))
        self.pw_b2 = nn.Parameter(torch.zeros(1, 1, n_patches, patch_dim))
        for patch in range(n_patches):
            nn.init.xavier_normal_(self.pw_w1.data[patch])
            nn.init.xavier_normal_(self.pw_w2.data[patch])
        self.pw_norm = nn.LayerNorm(patch_dim)
        self.gates = nn.Parameter(torch.full((n_patches,), gate_init))
        self.norm = nn.LayerNorm(input_dim)

    def drift(self):
        """Angle (radians) between each anchor and its home position."""
        home_n = F.normalize(self.home, dim=-1)
        cur_n = F.normalize(self.anchors, dim=-1)
        return torch.acos((home_n * cur_n).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7))

    def at_phase(self, t):
        """Slerp between home (t=0) and current anchors (t=1)."""
        home_n, cur_n = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        so = omega.sin().clamp(min=1e-7)
        return torch.sin((1 - t) * omega) / so * home_n + torch.sin(t * omega) / so * cur_n

    def forward(self, x):
        # Accept (B, D) by treating it as a length-1 sequence.
        if x.dim() == 2:
            x = x.unsqueeze(1)
        B, S, D = x.shape
        P, A, d = self.n_patches, self.n_anchors, self.patch_dim
        patches = self.norm(x).reshape(B * S, P, d)
        patches_n = F.normalize(patches, dim=-1)
        # One (1 - cosine) triangulation per phase, concatenated.
        tris = [
            1.0 - torch.einsum('bpd,pad->bpa', patches_n,
                               F.normalize(self.at_phase(t), dim=-1))
            for t in torch.linspace(0, 1, self.n_phases).tolist()
        ]
        tri = torch.cat(tris, dim=-1)
        hidden = F.gelu(torch.einsum('bpt,pth->bph', tri, self.pw_w1) + self.pw_b1.squeeze(1))
        pw = self.pw_norm(torch.einsum('bph,phd->bpd', hidden, self.pw_w2) + self.pw_b2.squeeze(1))
        g = self.gates.sigmoid().unsqueeze(0).unsqueeze(-1)
        out = (g * pw + (1 - g) * patches).reshape(B, S, D)
        return x + out
# ══════════════════════════════════════════════════════════════════
# TESTS
# ══════════════════════════════════════════════════════════════════
# ── Test-harness configuration ──
B = 4       # batch size
S = 256     # sequence length
D = 128     # model width = 8 patches of 16 dims
N_CV = 1500  # simplex-sampling budget per CV measurement
print("=" * 90)
print("HYBRID CONSTELLATION RELAY v2 β€” SPLIT GATES + CAUSAL TEST")
print(f" B={B}, S={S}, D={D} = {D//16}p Γ— 16d")
print(f" Fixed: 8 anchors Γ— 3 phases | Dynamic: 8 top-k with V-projection")
print(f" Device: {DEVICE}")
print("=" * 90)
# Builders (not instances) so every test can construct fresh, independently
# initialized models of each architecture.
configs = {
    'vanilla_attn': lambda: VanillaAttn(D, 8).to(DEVICE),
    'pure_relay': lambda: PureRelay(D, 16, 16, 3, 32).to(DEVICE),
    'hybrid_v2': lambda: HybridRelay(D, 16, 8, 8, 3, 32).to(DEVICE),
}
# ── TEST 1: Single pass ──
print(f"\n{'━'*90}")
print("TEST 1: Single pass")
print(f"{'━'*90}")
x = torch.randn(B, S, D, device=DEVICE)
x_flat_n = F.normalize(x.reshape(B*S, D), dim=-1)
# Baseline simplex-volume CV of the raw input directions.
cv_base = compute_cv(x_flat_n, N_CV)
print(f" Baseline CV: {cv_base:.4f}")
print(f" {'arch':>15} {'params':>8} {'CV_n':>8} {'cos_orig':>10}")
for name, builder in configs.items():
    m = builder()
    np_ = sum(p.numel() for p in m.parameters())
    with torch.no_grad():
        out = m(x)
    out_n = F.normalize(out.reshape(B*S, D), dim=-1)
    cv = compute_cv(out_n, N_CV)
    # Mean per-token cosine between output and input directions — how far
    # a single layer moved the representation.
    cos = (x_flat_n * out_n).sum(-1).mean().item()
    print(f" {name:>15} {np_:>8,} {cv:>8.4f} {cos:>10.6f}")
# Hybrid diagnostics (fresh instance, so gates are at their init values)
hybrid_diag = HybridRelay(D, 16, 8, 8, 3, 32).to(DEVICE)
with torch.no_grad():
    _, diag = hybrid_diag(x, return_diagnostics=True)
print(f"\n Hybrid gates: fixed={diag['fixed_gate']:.4f} dyn={diag['dyn_gate']:.4f} "
      f"identity={diag['identity_weight']:.4f}")
# ── TEST 2: Depth sweep ──
print(f"\n{'━'*90}")
print("TEST 2: Depth 16")
print(f"{'━'*90}")
x = torch.randn(B, S, D, device=DEVICE)
x_flat_n = F.normalize(x.reshape(B*S, D), dim=-1)
# Depths at which CV / cosine / effective dimensionality are reported.
checks = [1, 2, 4, 8, 12, 16]
for name, builder in configs.items():
    print(f"\n {name}:")
    print(f" {'d':>4} {'CV_n':>8} {'cos':>10} {'eff_d':>8}")
    # 16 independently initialized layers of the same architecture.
    stack = nn.ModuleList([builder() for _ in range(16)])
    z = x.clone()
    for i, layer in enumerate(stack):
        with torch.no_grad(): z = layer(z)
        if (i+1) in checks:
            zn = F.normalize(z.reshape(B*S, D), dim=-1)
            print(f" {i+1:>4} {compute_cv(zn, N_CV):>8.4f} "
                  f"{(x_flat_n * zn).sum(-1).mean().item():>10.6f} "
                  f"{eff_dim(z.reshape(B*S, D)):>8.1f}")
# ── TEST 3: Interleaved attention / hybrid stack ──
print(f"\n{'━'*90}")
print("TEST 3: Interleaved attn β†’ hybrid β†’ attn β†’ hybrid")
print(f"{'━'*90}")
x = torch.randn(B, S, D, device=DEVICE)
x_flat_n = F.normalize(x.reshape(B*S, D), dim=-1)
attn_l = nn.ModuleList([VanillaAttn(D, 8).to(DEVICE) for _ in range(8)])
hyb_l = nn.ModuleList([HybridRelay(D, 16, 8, 8, 3, 32).to(DEVICE) for _ in range(8)])
print(f" {'step':>4} {'type':>8} {'CV_n':>8} {'cos':>10} {'eff_d':>8}")
z = x.clone()
step = 0
# Alternate attention and hybrid layers; report at the same depths ('checks').
for i in range(8):
    with torch.no_grad(): z = attn_l[i](z)
    step += 1
    if step in checks:
        zn = F.normalize(z.reshape(B*S, D), dim=-1)
        print(f" {step:>4} {'attn':>8} {compute_cv(zn, N_CV):>8.4f} "
              f"{(x_flat_n * zn).sum(-1).mean().item():>10.6f} "
              f"{eff_dim(z.reshape(B*S, D)):>8.1f}")
    with torch.no_grad(): z = hyb_l[i](z)
    step += 1
    if step in checks:
        zn = F.normalize(z.reshape(B*S, D), dim=-1)
        print(f" {step:>4} {'hybrid':>8} {compute_cv(zn, N_CV):>8.4f} "
              f"{(x_flat_n * zn).sum(-1).mean().item():>10.6f} "
              f"{eff_dim(z.reshape(B*S, D)):>8.1f}")
# ── TEST 4: CAUSAL INTERVENTION — the real cross-token routing test ──
print(f"\n{'━'*90}")
print("TEST 4: Causal intervention β€” does changing token 0 affect other tokens?")
print(f" Run same sequence twice, swap only token 0. Measure Ξ” on tokens 1-31.")
print(f"{'━'*90}")
S_test = 32
x_a = torch.randn(1, S_test, D, device=DEVICE)
x_b = x_a.clone()
x_b[:, 0] = torch.randn(1, D, device=DEVICE)  # only token 0 differs
print(f" Token 0 cosine between A and B: "
      f"{F.cosine_similarity(x_a[:, 0], x_b[:, 0]).item():.4f}")
print(f" Tokens 1-31 identical: "
      f"{(x_a[:, 1:] == x_b[:, 1:]).all().item()}")
print(f"\n {'arch':>15} {'other_Ξ”_norm':>12} {'other_Ξ”_cos':>12} {'t0_Ξ”_norm':>10}")
for name, builder in configs.items():
    m = builder()
    with torch.no_grad():
        out_a = m(x_a)
        out_b = m(x_b)
    # How much did tokens 1-31 change? Nonzero implies cross-token flow.
    delta_others = (out_a[:, 1:] - out_b[:, 1:])
    other_norm = delta_others.norm(dim=-1).mean().item()
    # Cosine change for other tokens
    cos_others = F.cosine_similarity(
        out_a[:, 1:].reshape(-1, D),
        out_b[:, 1:].reshape(-1, D)).mean().item()
    # Token 0 change (sanity — should be large for all architectures)
    t0_norm = (out_a[:, 0] - out_b[:, 0]).norm().item()
    print(f" {name:>15} {other_norm:>12.6f} {1-cos_others:>12.8f} {t0_norm:>10.4f}")
# Run multiple layers to amplify the routing signal
print(f"\n After 4 stacked layers:")
print(f" {'arch':>15} {'other_Ξ”_norm':>12} {'other_Ξ”_cos':>12}")
for name, builder in configs.items():
    layers = nn.ModuleList([builder() for _ in range(4)])
    with torch.no_grad():
        za, zb = x_a.clone(), x_b.clone()
        for layer in layers:
            za = layer(za)
            zb = layer(zb)
    delta = (za[:, 1:] - zb[:, 1:])
    other_norm = delta.norm(dim=-1).mean().item()
    cos_others = F.cosine_similarity(
        za[:, 1:].reshape(-1, D),
        zb[:, 1:].reshape(-1, D)).mean().item()
    print(f" {name:>15} {other_norm:>12.6f} {1-cos_others:>12.8f}")
# ── TEST 5: Throughput ──
print(f"\n{'━'*90}")
print("TEST 5: Throughput")
print(f"{'━'*90}")

def _sync():
    # FIX: torch.cuda.synchronize() raises on CPU-only builds even though
    # DEVICE falls back to "cpu" above — guard so the benchmark also runs
    # (with less precise timing) without a GPU.
    if DEVICE == "cuda":
        torch.cuda.synchronize()

x_bench = torch.randn(B, S, D, device=DEVICE)
print(f" {'arch':>15} {'ms':>8} {'params':>10}")
for name, builder in configs.items():
    m = builder()
    np_ = sum(p.numel() for p in m.parameters())
    # Warmup iterations before timing.
    for _ in range(5):
        with torch.no_grad(): _ = m(x_bench)
    _sync()
    t0 = time.time()
    for _ in range(100):
        with torch.no_grad(): _ = m(x_bench)
    _sync()
    ms = (time.time() - t0) / 100 * 1000
    print(f" {name:>15} {ms:>8.2f} {np_:>10,}")
# ── TEST 6: Sequence scaling ──
print(f"\n{'━'*90}")
print("TEST 6: Sequence length scaling")
print(f"{'━'*90}")
print(f" {'S':>6} {'hybrid_ms':>10} {'attn_ms':>10} {'ratio':>8}")
for sl in [64, 128, 256, 512, 1024]:
    xs = torch.randn(2, sl, D, device=DEVICE)
    h_m = HybridRelay(D, 16, 8, 8, 3, 32).to(DEVICE)
    a_m = VanillaAttn(D, 8).to(DEVICE)
    # Warmup both modules once before timing.
    with torch.no_grad(): _ = h_m(xs); _ = a_m(xs)
    # FIX: torch.cuda.synchronize() raises on CPU-only builds even though
    # DEVICE falls back to "cpu" — guard every sync call.
    if DEVICE == "cuda":
        torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(50):
        with torch.no_grad(): _ = h_m(xs)
    if DEVICE == "cuda":
        torch.cuda.synchronize()
    h_ms = (time.time() - t0) / 50 * 1000
    t0 = time.time()
    for _ in range(50):
        with torch.no_grad(): _ = a_m(xs)
    if DEVICE == "cuda":
        torch.cuda.synchronize()
    a_ms = (time.time() - t0) / 50 * 1000
    print(f" {sl:>6} {h_ms:>10.2f} {a_ms:>10.2f} {h_ms/a_ms:>8.2f}Γ—")
# ══════════════════════════════════════════════════════════════════
# SUMMARY
# ══════════════════════════════════════════════════════════════════
print(f"\n{'='*90}")
print("SUMMARY")
print(f"{'='*90}")
# Free-form report; the triple-quoted runtime string is reproduced verbatim.
print(f"""
Hybrid v2 architecture per patch:
Fixed: 8 anchors Γ— 3 phases β†’ MLP β†’ fixed_out (gate β‰ˆ 0.047)
Dynamic: top-8 QΒ·K β†’ gather V β†’ weighted sum β†’ MLP β†’ dyn_out (gate β‰ˆ 0.269)
Output: fg*fixed + dg*dynamic + (1-fg-dg)*identity + skip
GPT's challenge:
βœ“ Selective interaction β€” QΒ·K top-k selection
βœ“ Conditional transformation β€” separate MLPs for fixed/dynamic
βœ“ Information routing β€” V-projection carries information through geometric channel
Key test: Causal intervention (Test 4)
If other_Ξ”_norm > 0 for hybrid but β‰ˆ 0 for pure_relay,
cross-token routing is proven.
""")
print(f"{'='*90}")
print("DONE")
print(f"{'='*90}")