#!/usr/bin/env python3
"""
Constellation Core V11
Hybrid Constellation Relay v2
================================
Fixes from v1:
- Split gates: fixed_gate (cold, -3.0) + dynamic_gate (warm, -1.0)
- Balanced: 8 fixed + 8 dynamic per patch
- Separate dynamic MLP before merge
- Proper causal intervention test for cross-token routing
- V-projection: dynamic anchors carry value information, not just position
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import time

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
HAS_FP8 = hasattr(torch, 'float8_e4m3fn')


def sync():
    """Synchronize CUDA for accurate wall-clock timing; no-op on CPU.

    Fix: the original called torch.cuda.synchronize() unconditionally,
    which raises on CPU-only machines even though DEVICE falls back to "cpu".
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def compute_cv(points, n_samples=1500, n_points=5):
    """Coefficient of variation of random 4-simplex volumes on the unit sphere.

    Samples `n_samples` groups of `n_points` rows from `points` (unit-normalized),
    computes each group's simplex volume via the Cayley-Menger determinant
    (for 5 points / a 4-simplex the normalizer is 2^4 * (4!)^2 = 9216), and
    returns std/mean of the volumes. Returns NaN when too few points or too
    few numerically valid (positive-determinant) samples survive.
    """
    N = points.shape[0]
    if N < n_points:
        return float('nan')
    points = F.normalize(points.to(DEVICE).float(), dim=-1)
    vols = []
    for _ in range(n_samples):
        # Sample only from the first 10k rows to bound permutation cost.
        idx = torch.randperm(min(N, 10000), device=DEVICE)[:n_points]
        pts = points[idx].unsqueeze(0)
        gram = torch.bmm(pts, pts.transpose(1, 2))
        norms = torch.diagonal(gram, dim1=1, dim2=2)
        # Pairwise squared distances; relu clips tiny negatives from roundoff.
        d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
        d2 = F.relu(d2)
        # Cayley-Menger matrix: border of ones, squared distances inside.
        cm = torch.zeros(1, 6, 6, device=DEVICE, dtype=torch.float32)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = d2
        v2 = -torch.linalg.det(cm) / 9216
        if v2[0].item() > 1e-20:
            vols.append(v2[0].sqrt().cpu())
    if len(vols) < 50:
        return float('nan')
    vt = torch.stack(vols)
    return (vt.std() / (vt.mean() + 1e-8)).item()


def eff_dim(x):
    """Participation-ratio effective dimension of (up to) the first 512 rows.

    Centers the data, takes singular values of the slice, and returns
    1 / sum(p_i^2) where p_i are the normalized singular values.
    """
    x_c = x - x.mean(0, keepdim=True)
    # NOTE: local renamed from `S` to avoid shadowing the module-level seq length.
    _, sv, _ = torch.linalg.svd(x_c[:512].float(), full_matrices=False)
    p = sv / sv.sum()
    return p.pow(2).sum().reciprocal().item()


def uniform_sphere(n, d):
    """Return n points drawn uniformly on the (d-1)-sphere."""
    return F.normalize(torch.randn(n, d), dim=-1)


# ══════════════════════════════════════════════════════════════════
# HYBRID CONSTELLATION RELAY v2
# ══════════════════════════════════════════════════════════════════
class HybridRelay(nn.Module):
    """
    Fixed geometric anchors + dynamic cross-token anchors.
    Split processing paths with separate gates.

    Per patch (d=16):
      Fixed path:   A_f anchors × n_phases → fixed_mlp → fixed_out (d)
      Dynamic path: top-k Q·K selection → gather V → dynamic_mlp → dyn_out (d)
      Output:       fixed_gate * fixed_out + dyn_gate * dyn_out + (1-both) * identity
    """

    def __init__(
        self,
        input_dim,
        patch_dim=16,
        n_fixed=8,
        n_dynamic=8,
        n_phases=3,
        pw_hidden=32,
        fixed_gate_init=-3.0,  # sigmoid ≈ 0.047
        dyn_gate_init=-1.0,    # sigmoid ≈ 0.269
    ):
        super().__init__()
        assert input_dim % patch_dim == 0
        self.input_dim = input_dim
        self.patch_dim = patch_dim
        self.n_patches = input_dim // patch_dim
        self.n_fixed = n_fixed
        self.n_dynamic = n_dynamic
        self.n_phases = n_phases
        P, Af, k, d = self.n_patches, n_fixed, n_dynamic, patch_dim

        # ── Fixed constellation ──
        # `home` is the frozen reference constellation; `anchors` is the
        # learnable copy that may drift away from it.
        home = torch.empty(P, Af, d)
        nn.init.xavier_normal_(home.view(P * Af, d))
        home = F.normalize(home.view(P, Af, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())

        # Fixed path MLP: (phases * Af) → d
        fixed_tri_dim = n_phases * Af
        self.fixed_w1 = nn.Parameter(torch.empty(P, fixed_tri_dim, pw_hidden))
        self.fixed_b1 = nn.Parameter(torch.zeros(1, 1, P, pw_hidden))
        self.fixed_w2 = nn.Parameter(torch.empty(P, pw_hidden, d))
        self.fixed_b2 = nn.Parameter(torch.zeros(1, 1, P, d))
        for p in range(P):
            nn.init.xavier_normal_(self.fixed_w1.data[p])
            nn.init.xavier_normal_(self.fixed_w2.data[p])
        self.fixed_norm = nn.LayerNorm(d)

        # ── Dynamic cross-token path ──
        # Q, K for selection; V for information transfer.
        self.q_proj = nn.Parameter(torch.empty(P, d, d))
        self.k_proj = nn.Parameter(torch.empty(P, d, d))
        self.v_proj = nn.Parameter(torch.empty(P, d, d))
        for p in range(P):
            nn.init.xavier_normal_(self.q_proj.data[p])
            nn.init.xavier_normal_(self.k_proj.data[p])
            nn.init.xavier_normal_(self.v_proj.data[p])

        # Dynamic path MLP: (k * d) → d (reads gathered V values)
        dyn_input_dim = k * d
        self.dyn_w1 = nn.Parameter(torch.empty(P, dyn_input_dim, pw_hidden))
        self.dyn_b1 = nn.Parameter(torch.zeros(1, 1, P, pw_hidden))
        self.dyn_w2 = nn.Parameter(torch.empty(P, pw_hidden, d))
        self.dyn_b2 = nn.Parameter(torch.zeros(1, 1, P, d))
        for p in range(P):
            nn.init.xavier_normal_(self.dyn_w1.data[p])
            nn.init.xavier_normal_(self.dyn_w2.data[p])
        self.dyn_norm = nn.LayerNorm(d)

        # ── Split gates ── (one scalar gate per patch, per path)
        self.fixed_gate = nn.Parameter(torch.full((P,), fixed_gate_init))
        self.dyn_gate = nn.Parameter(torch.full((P,), dyn_gate_init))
        self.norm = nn.LayerNorm(input_dim)

    def drift(self):
        """Per-anchor angle (radians) between the learned anchors and `home`."""
        h = F.normalize(self.home, dim=-1)
        c = F.normalize(self.anchors, dim=-1)
        cos = (h * c).sum(dim=-1).clamp(-1 + 1e-7, 1 - 1e-7)
        return torch.acos(cos)

    def at_phase(self, t):
        """Slerp between home (t=0) and current anchors (t=1).

        The clamp in drift() keeps omega > 0 even when anchors == home,
        so the sin ratios stay well-defined at init.
        """
        h = F.normalize(self.home, dim=-1)
        c = F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        sin_omega = omega.sin().clamp(min=1e-7)
        return (torch.sin((1 - t) * omega) / sin_omega * h
                + torch.sin(t * omega) / sin_omega * c)

    def forward(self, x, return_diagnostics=False):
        """x: (B, S, D). Returns x + gated blend; optionally a diagnostics dict."""
        B, S, D = x.shape
        P, Af, k, d = self.n_patches, self.n_fixed, self.n_dynamic, self.patch_dim
        x_n = self.norm(x)
        patches = x_n.reshape(B, S, P, d)
        patches_n = F.normalize(patches, dim=-1)

        # ══════ FIXED PATH ══════
        # Triangulate each patch against the constellation at several slerp
        # phases; the concatenated (1 - cos) distances feed a per-patch MLP.
        phases = torch.linspace(0, 1, self.n_phases).tolist()
        fixed_tris = []
        for t in phases:
            anchors_t = F.normalize(self.at_phase(t), dim=-1)  # (P, Af, d)
            cos_f = torch.einsum('bspd,pad->bspa', patches_n, anchors_t)
            fixed_tris.append(1.0 - cos_f)
        fixed_tri = torch.cat(fixed_tris, dim=-1)  # (B, S, P, Af*phases)
        h_f = torch.einsum('bspt,pth->bsph', fixed_tri, self.fixed_w1) + self.fixed_b1
        h_f = F.gelu(h_f)
        fixed_out = torch.einsum('bsph,phd->bspd', h_f, self.fixed_w2) + self.fixed_b2
        fixed_out = self.fixed_norm(fixed_out)  # (B, S, P, d)

        # ══════ DYNAMIC PATH ══════
        # Q, K, V projections
        Q = F.normalize(torch.einsum('bspd,pde->bspe', patches_n, self.q_proj), dim=-1)
        K = F.normalize(torch.einsum('bspd,pde->bspe', patches_n, self.k_proj), dim=-1)
        V = torch.einsum('bspd,pde->bspe', patches, self.v_proj)  # V not normalized — carries magnitude

        # Relevance: Q_i · K_j → (B, P, S, S)
        relevance = torch.einsum('bspd,btpd->bpst', Q, K)
        # Mask self so a token never selects itself.
        self_mask = torch.eye(S, device=x.device, dtype=torch.bool)
        relevance = relevance.masked_fill(self_mask.unsqueeze(0).unsqueeze(0), -1e9)

        # Soft top-k: take softmax over keys, then gather top-k.
        # This makes gradients flow through the selection.
        rel_weights = relevance.softmax(dim=-1)  # (B, P, S, S)
        # Top-k indices for sparse gather
        _, topk_idx = relevance.topk(k, dim=-1)  # (B, P, S, k)
        # Gather top-k weights and re-normalize
        topk_weights = torch.gather(rel_weights, -1, topk_idx)  # (B, P, S, k)
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-8)

        # Gather top-k V vectors: V is (B, S, P, d) → need (B, P, S, d)
        V_perm = V.permute(0, 2, 1, 3)  # (B, P, S, d)
        # For each (b, p, s), gather V[b, p, topk_idx[b,p,s,:], :].
        # expand() creates views, so no (B,P,S,S,d) buffer is materialized;
        # only the gathered (B,P,S,k,d) output is allocated.
        topk_idx_v = topk_idx.unsqueeze(-1).expand(-1, -1, -1, -1, d)  # (B, P, S, k, d)
        V_expanded = V_perm.unsqueeze(2).expand(-1, -1, S, -1, -1)     # (B, P, S, S, d)
        topk_V = torch.gather(V_expanded, 3, topk_idx_v)               # (B, P, S, k, d)

        # Weighted sum of top-k values
        weighted_V = (topk_weights.unsqueeze(-1) * topk_V).reshape(B, P, S, k * d)
        # → (B, S, P, k*d)
        weighted_V = weighted_V.permute(0, 2, 1, 3)

        # Dynamic MLP
        h_d = torch.einsum('bspt,pth->bsph', weighted_V, self.dyn_w1) + self.dyn_b1
        h_d = F.gelu(h_d)
        dyn_out = torch.einsum('bsph,phd->bspd', h_d, self.dyn_w2) + self.dyn_b2
        dyn_out = self.dyn_norm(dyn_out)  # (B, S, P, d)

        # ══════ GATED MERGE ══════
        fg = self.fixed_gate.sigmoid().view(1, 1, P, 1)
        dg = self.dyn_gate.sigmoid().view(1, 1, P, 1)
        # Identity weight = 1 - fg - dg (can go negative if both gates high,
        # but sigmoid caps each at 1) — clamped to keep the blend convex-ish.
        identity_weight = (1.0 - fg - dg).clamp(min=0.0)
        blended = fg * fixed_out + dg * dyn_out + identity_weight * patches
        out = blended.reshape(B, S, D)
        result = x + out

        if return_diagnostics:
            drift = self.drift()
            diag = {
                'drift_mean': drift.mean().item(),
                'fixed_gate': self.fixed_gate.sigmoid().mean().item(),
                'dyn_gate': self.dyn_gate.sigmoid().mean().item(),
                'identity_weight': identity_weight.mean().item(),
                'topk_cos_mean': torch.gather(relevance, -1, topk_idx).mean().item(),
                'topk_cos_max': torch.gather(relevance, -1, topk_idx).max().item(),
            }
            return result, diag
        return result


# ══════════════════════════════════════════════════════════════════
# COMPARISON MODULES
# ══════════════════════════════════════════════════════════════════
class VanillaAttn(nn.Module):
    """Standard pre-norm multi-head self-attention block with residual."""

    def __init__(self, dim, n_heads=4):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """x: (B, S, D) → (B, S, D) with residual connection."""
        B, S, D = x.shape
        x_n = self.norm(x)
        qkv = self.qkv(x_n).reshape(B, S, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, S, D)
        return x + self.out_proj(out)


class PureRelay(nn.Module):
    """Fixed-constellation-only relay (v1 baseline): per-token, no cross-token
    routing. Each token is processed independently of all others, which is
    what Test 4 (causal intervention) exploits to separate it from the hybrid.
    """

    def __init__(self, input_dim, patch_dim=16, n_anchors=16, n_phases=3,
                 pw_hidden=32, gate_init=-3.0):
        super().__init__()
        assert input_dim % patch_dim == 0
        P = input_dim // patch_dim
        A, d = n_anchors, patch_dim
        self.input_dim, self.patch_dim, self.n_patches = input_dim, patch_dim, P
        self.n_anchors, self.n_phases = n_anchors, n_phases
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())
        tri_dim = n_phases * A
        self.pw_w1 = nn.Parameter(torch.empty(P, tri_dim, pw_hidden))
        self.pw_b1 = nn.Parameter(torch.zeros(1, 1, P, pw_hidden))
        self.pw_w2 = nn.Parameter(torch.empty(P, pw_hidden, d))
        self.pw_b2 = nn.Parameter(torch.zeros(1, 1, P, d))
        for p in range(P):
            nn.init.xavier_normal_(self.pw_w1.data[p])
            nn.init.xavier_normal_(self.pw_w2.data[p])
        self.pw_norm = nn.LayerNorm(d)
        self.gates = nn.Parameter(torch.full((P,), gate_init))
        self.norm = nn.LayerNorm(input_dim)

    def drift(self):
        """Per-anchor angle between learned anchors and the frozen home set."""
        h = F.normalize(self.home, dim=-1)
        c = F.normalize(self.anchors, dim=-1)
        return torch.acos((h * c).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7))

    def at_phase(self, t):
        """Slerp between home (t=0) and current anchors (t=1)."""
        h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        so = omega.sin().clamp(min=1e-7)
        return torch.sin((1-t)*omega)/so * h + torch.sin(t*omega)/so * c

    def forward(self, x):
        """x: (B, S, D) or (B, D); returns same shape as the 3-D input."""
        if x.dim() == 2:
            x = x.unsqueeze(1)
        B, S, D = x.shape
        P, A, d = self.n_patches, self.n_anchors, self.patch_dim
        # Fold batch and sequence together — tokens are fully independent here.
        patches = self.norm(x).reshape(B*S, P, d)
        patches_n = F.normalize(patches, dim=-1)
        tris = []
        for t in torch.linspace(0, 1, self.n_phases).tolist():
            at = F.normalize(self.at_phase(t), dim=-1)
            tris.append(1.0 - torch.einsum('bpd,pad->bpa', patches_n, at))
        tri = torch.cat(tris, dim=-1)
        h = F.gelu(torch.einsum('bpt,pth->bph', tri, self.pw_w1) + self.pw_b1.squeeze(1))
        pw = self.pw_norm(torch.einsum('bph,phd->bpd', h, self.pw_w2) + self.pw_b2.squeeze(1))
        g = self.gates.sigmoid().unsqueeze(0).unsqueeze(-1)
        out = (g * pw + (1-g) * patches).reshape(B, S, D)
        return x + out


# ══════════════════════════════════════════════════════════════════
# TESTS
# ══════════════════════════════════════════════════════════════════
B = 4
S = 256
D = 128
N_CV = 1500

configs = {
    'vanilla_attn': lambda: VanillaAttn(D, 8).to(DEVICE),
    'pure_relay': lambda: PureRelay(D, 16, 16, 3, 32).to(DEVICE),
    'hybrid_v2': lambda: HybridRelay(D, 16, 8, 8, 3, 32).to(DEVICE),
}


def main():
    """Run the full benchmark/diagnostic suite (Tests 1-6 + summary).

    Fix: wrapped in main() behind an __main__ guard so importing this module
    (e.g. to reuse HybridRelay) no longer triggers minutes of benchmarking.
    """
    print("=" * 90)
    print("HYBRID CONSTELLATION RELAY v2 — SPLIT GATES + CAUSAL TEST")
    print(f" B={B}, S={S}, D={D} = {D//16}p × 16d")
    print(f" Fixed: 8 anchors × 3 phases | Dynamic: 8 top-k with V-projection")
    print(f" Device: {DEVICE}")
    print("=" * 90)

    # ── TEST 1: Single pass ──
    print(f"\n{'━'*90}")
    print("TEST 1: Single pass")
    print(f"{'━'*90}")
    x = torch.randn(B, S, D, device=DEVICE)
    x_flat_n = F.normalize(x.reshape(B*S, D), dim=-1)
    cv_base = compute_cv(x_flat_n, N_CV)
    print(f" Baseline CV: {cv_base:.4f}")
    print(f" {'arch':>15} {'params':>8} {'CV_n':>8} {'cos_orig':>10}")
    for name, builder in configs.items():
        m = builder()
        np_ = sum(p.numel() for p in m.parameters())
        with torch.no_grad():
            out = m(x)
        out_n = F.normalize(out.reshape(B*S, D), dim=-1)
        cv = compute_cv(out_n, N_CV)
        cos = (x_flat_n * out_n).sum(-1).mean().item()
        print(f" {name:>15} {np_:>8,} {cv:>8.4f} {cos:>10.6f}")

    # Hybrid diagnostics
    hybrid_diag = HybridRelay(D, 16, 8, 8, 3, 32).to(DEVICE)
    with torch.no_grad():
        _, diag = hybrid_diag(x, return_diagnostics=True)
    print(f"\n Hybrid gates: fixed={diag['fixed_gate']:.4f} dyn={diag['dyn_gate']:.4f} "
          f"identity={diag['identity_weight']:.4f}")

    # ── TEST 2: Depth sweep ──
    print(f"\n{'━'*90}")
    print("TEST 2: Depth 16")
    print(f"{'━'*90}")
    x = torch.randn(B, S, D, device=DEVICE)
    x_flat_n = F.normalize(x.reshape(B*S, D), dim=-1)
    checks = [1, 2, 4, 8, 12, 16]
    for name, builder in configs.items():
        print(f"\n {name}:")
        print(f" {'d':>4} {'CV_n':>8} {'cos':>10} {'eff_d':>8}")
        stack = nn.ModuleList([builder() for _ in range(16)])
        z = x.clone()
        for i, layer in enumerate(stack):
            with torch.no_grad():
                z = layer(z)
            if (i+1) in checks:
                zn = F.normalize(z.reshape(B*S, D), dim=-1)
                print(f" {i+1:>4} {compute_cv(zn, N_CV):>8.4f} "
                      f"{(x_flat_n * zn).sum(-1).mean().item():>10.6f} "
                      f"{eff_dim(z.reshape(B*S, D)):>8.1f}")

    # ── TEST 3: Interleaved ──
    print(f"\n{'━'*90}")
    print("TEST 3: Interleaved attn → hybrid → attn → hybrid")
    print(f"{'━'*90}")
    x = torch.randn(B, S, D, device=DEVICE)
    x_flat_n = F.normalize(x.reshape(B*S, D), dim=-1)
    attn_l = nn.ModuleList([VanillaAttn(D, 8).to(DEVICE) for _ in range(8)])
    hyb_l = nn.ModuleList([HybridRelay(D, 16, 8, 8, 3, 32).to(DEVICE) for _ in range(8)])
    print(f" {'step':>4} {'type':>8} {'CV_n':>8} {'cos':>10} {'eff_d':>8}")
    z = x.clone()
    step = 0
    for i in range(8):
        with torch.no_grad():
            z = attn_l[i](z)
        step += 1
        if step in checks:
            zn = F.normalize(z.reshape(B*S, D), dim=-1)
            print(f" {step:>4} {'attn':>8} {compute_cv(zn, N_CV):>8.4f} "
                  f"{(x_flat_n * zn).sum(-1).mean().item():>10.6f} "
                  f"{eff_dim(z.reshape(B*S, D)):>8.1f}")
        with torch.no_grad():
            z = hyb_l[i](z)
        step += 1
        if step in checks:
            zn = F.normalize(z.reshape(B*S, D), dim=-1)
            print(f" {step:>4} {'hybrid':>8} {compute_cv(zn, N_CV):>8.4f} "
                  f"{(x_flat_n * zn).sum(-1).mean().item():>10.6f} "
                  f"{eff_dim(z.reshape(B*S, D)):>8.1f}")

    # ── TEST 4: CAUSAL INTERVENTION — the real cross-token routing test ──
    print(f"\n{'━'*90}")
    print("TEST 4: Causal intervention — does changing token 0 affect other tokens?")
    print(f" Run same sequence twice, swap only token 0. Measure Δ on tokens 1-31.")
    print(f"{'━'*90}")
    S_test = 32
    x_a = torch.randn(1, S_test, D, device=DEVICE)
    x_b = x_a.clone()
    x_b[:, 0] = torch.randn(1, D, device=DEVICE)  # only token 0 differs
    print(f" Token 0 cosine between A and B: "
          f"{F.cosine_similarity(x_a[:, 0], x_b[:, 0]).item():.4f}")
    print(f" Tokens 1-31 identical: "
          f"{(x_a[:, 1:] == x_b[:, 1:]).all().item()}")
    print(f"\n {'arch':>15} {'other_Δ_norm':>12} {'other_Δ_cos':>12} {'t0_Δ_norm':>10}")
    for name, builder in configs.items():
        m = builder()
        with torch.no_grad():
            out_a = m(x_a)
            out_b = m(x_b)
        # How much did tokens 1-31 change?
        delta_others = (out_a[:, 1:] - out_b[:, 1:])
        other_norm = delta_others.norm(dim=-1).mean().item()
        # Cosine change for other tokens
        cos_others = F.cosine_similarity(
            out_a[:, 1:].reshape(-1, D), out_b[:, 1:].reshape(-1, D)).mean().item()
        # Token 0 change (sanity — should be large for all)
        t0_norm = (out_a[:, 0] - out_b[:, 0]).norm().item()
        print(f" {name:>15} {other_norm:>12.6f} {1-cos_others:>12.8f} {t0_norm:>10.4f}")

    # Run multiple layers to amplify routing signal
    print(f"\n After 4 stacked layers:")
    print(f" {'arch':>15} {'other_Δ_norm':>12} {'other_Δ_cos':>12}")
    for name, builder in configs.items():
        layers = nn.ModuleList([builder() for _ in range(4)])
        with torch.no_grad():
            za, zb = x_a.clone(), x_b.clone()
            for layer in layers:
                za = layer(za)
                zb = layer(zb)
        delta = (za[:, 1:] - zb[:, 1:])
        other_norm = delta.norm(dim=-1).mean().item()
        cos_others = F.cosine_similarity(
            za[:, 1:].reshape(-1, D), zb[:, 1:].reshape(-1, D)).mean().item()
        print(f" {name:>15} {other_norm:>12.6f} {1-cos_others:>12.8f}")

    # ── TEST 5: Throughput ──
    print(f"\n{'━'*90}")
    print("TEST 5: Throughput")
    print(f"{'━'*90}")
    x_bench = torch.randn(B, S, D, device=DEVICE)
    print(f" {'arch':>15} {'ms':>8} {'params':>10}")
    for name, builder in configs.items():
        m = builder()
        np_ = sum(p.numel() for p in m.parameters())
        # Warmup before timing.
        for _ in range(5):
            with torch.no_grad():
                _ = m(x_bench)
        sync()
        t0 = time.time()
        for _ in range(100):
            with torch.no_grad():
                _ = m(x_bench)
        sync()
        ms = (time.time() - t0) / 100 * 1000
        print(f" {name:>15} {ms:>8.2f} {np_:>10,}")

    # ── TEST 6: Sequence scaling ──
    print(f"\n{'━'*90}")
    print("TEST 6: Sequence length scaling")
    print(f"{'━'*90}")
    print(f" {'S':>6} {'hybrid_ms':>10} {'attn_ms':>10} {'ratio':>8}")
    for sl in [64, 128, 256, 512, 1024]:
        xs = torch.randn(2, sl, D, device=DEVICE)
        h_m = HybridRelay(D, 16, 8, 8, 3, 32).to(DEVICE)
        a_m = VanillaAttn(D, 8).to(DEVICE)
        # Warmup both models once.
        with torch.no_grad():
            _ = h_m(xs)
            _ = a_m(xs)
        sync()
        t0 = time.time()
        for _ in range(50):
            with torch.no_grad():
                _ = h_m(xs)
        sync()
        h_ms = (time.time() - t0) / 50 * 1000
        t0 = time.time()
        for _ in range(50):
            with torch.no_grad():
                _ = a_m(xs)
        sync()
        a_ms = (time.time() - t0) / 50 * 1000
        print(f" {sl:>6} {h_ms:>10.2f} {a_ms:>10.2f} {h_ms/a_ms:>8.2f}×")

    # ══════════════════════════════════════════════════════════════════
    # SUMMARY
    # ══════════════════════════════════════════════════════════════════
    print(f"\n{'='*90}")
    print("SUMMARY")
    print(f"{'='*90}")
    print(f"""
 Hybrid v2 architecture per patch:
 Fixed: 8 anchors × 3 phases → MLP → fixed_out (gate ≈ 0.047)
 Dynamic: top-8 Q·K → gather V → weighted sum → MLP → dyn_out (gate ≈ 0.269)
 Output: fg*fixed + dg*dynamic + (1-fg-dg)*identity + skip

 GPT's challenge:
 ✓ Selective interaction — Q·K top-k selection
 ✓ Conditional transformation — separate MLPs for fixed/dynamic
 ✓ Information routing — V-projection carries information through geometric channel

 Key test: Causal intervention (Test 4)
 If other_Δ_norm > 0 for hybrid but ≈ 0 for pure_relay,
 cross-token routing is proven.
""")
    print(f"{'='*90}")
    print("DONE")
    print(f"{'='*90}")


if __name__ == "__main__":
    main()