| |
| """ |
| Constellation Core V11 |
| Hybrid Constellation Relay v2 |
| ================================ |
| Fixes from v1: |
| - Split gates: fixed_gate (cold, -3.0) + dynamic_gate (warm, -1.0) |
| - Balanced: 8 fixed + 8 dynamic per patch |
| - Separate dynamic MLP before merge |
| - Proper causal intervention test for cross-token routing |
| - V-projection: dynamic anchors carry value information, not just position |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import math |
| import time |
|
|
# Run on GPU when available; all tensors/modules below are moved to DEVICE explicitly.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)  # reproducible anchor init and synthetic test data
# Allow TF32 matmuls/convolutions on Ampere+ GPUs: faster, slightly lower precision.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


# Feature probe for FP8 dtype support. NOTE(review): currently unused in this file.
HAS_FP8 = hasattr(torch, 'float8_e4m3fn')
|
|
|
|
def compute_cv(points, n_samples=1500, n_points=5):
    """Coefficient of variation of random simplex volumes over `points`.

    Repeatedly samples `n_points` (=5) of the normalized points, computes the
    4-simplex volume via the Cayley-Menger determinant, and returns std/mean
    of the collected volumes. Lower CV indicates a more uniform point spread.
    Returns NaN when there are too few points or too few non-degenerate samples.
    NOTE: the 6x6 Cayley-Menger matrix and the 9216 divisor are specific to
    n_points=5 (a 4-simplex); other values of `n_points` are not supported.
    """
    N = points.shape[0]
    if N < n_points: return float('nan')
    points = F.normalize(points.to(DEVICE).float(), dim=-1)
    vols = []
    for _ in range(n_samples):
        # Sampling is capped to the first 10k points to bound randperm cost.
        idx = torch.randperm(min(N, 10000), device=DEVICE)[:n_points]
        pts = points[idx].unsqueeze(0)
        gram = torch.bmm(pts, pts.transpose(1, 2))
        norms = torch.diagonal(gram, dim1=1, dim2=2)
        # Squared pairwise distances from the Gram matrix; relu clips tiny negatives.
        d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
        d2 = F.relu(d2)
        # Cayley-Menger matrix: bordered by ones, zero corner, distance block.
        cm = torch.zeros(1, 6, 6, device=DEVICE, dtype=torch.float32)
        cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2
        # V^2 = -det(CM) / (2^4 * (4!)^2) = -det(CM) / 9216 for a 4-simplex.
        v2 = -torch.linalg.det(cm) / 9216
        if v2[0].item() > 1e-20:  # skip degenerate (near-flat) simplices
            vols.append(v2[0].sqrt().cpu())
    if len(vols) < 50: return float('nan')
    vt = torch.stack(vols)
    return (vt.std() / (vt.mean() + 1e-8)).item()
|
|
|
|
def eff_dim(x):
    """Participation-ratio effective dimensionality of the rows of `x`.

    Centers the data, takes the singular values of (at most) the first 512
    rows, and returns 1 / sum(p_i^2), where p_i are the singular values
    normalized to sum to one.
    """
    centered = x - x.mean(dim=0, keepdim=True)
    spectrum = torch.linalg.svdvals(centered[:512].float())
    weights = spectrum / spectrum.sum()
    return 1.0 / (weights ** 2).sum().item()
|
|
|
|
def uniform_sphere(n, d):
    """Draw `n` points uniformly on the unit (d-1)-sphere via normalized Gaussians."""
    samples = torch.randn(n, d)
    return samples / samples.norm(dim=-1, keepdim=True)
|
|
|
|
| |
| |
| |
|
|
class HybridRelay(nn.Module):
    """
    Fixed geometric anchors + dynamic cross-token anchors.
    Split processing paths with separate gates.

    Per patch (d=16):
      Fixed path:   A_f anchors x n_phases  ->  fixed_mlp  ->  fixed_out (d)
      Dynamic path: top-k Q.K selection -> gather V -> dynamic_mlp -> dyn_out (d)
      Output: fixed_gate * fixed_out + dyn_gate * dyn_out + (1-both) * identity
    """
    def __init__(
        self,
        input_dim,
        patch_dim=16,
        n_fixed=8,
        n_dynamic=8,
        n_phases=3,
        pw_hidden=32,
        fixed_gate_init=-3.0,   # sigmoid(-3.0) ~ 0.047: fixed path starts cold
        dyn_gate_init=-1.0,     # sigmoid(-1.0) ~ 0.269: dynamic path starts warm
    ):
        super().__init__()
        assert input_dim % patch_dim == 0
        self.input_dim = input_dim
        self.patch_dim = patch_dim
        self.n_patches = input_dim // patch_dim
        self.n_fixed = n_fixed
        self.n_dynamic = n_dynamic
        self.n_phases = n_phases

        P, Af, k, d = self.n_patches, n_fixed, n_dynamic, patch_dim

        # Fixed anchors: 'home' is a frozen copy of the init; 'anchors' is trainable.
        # Their angular separation ("drift") parameterizes the slerp in at_phase().
        home = torch.empty(P, Af, d)
        nn.init.xavier_normal_(home.view(P * Af, d))
        home = F.normalize(home.view(P, Af, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())

        # Fixed-path MLP: per-patch weights; input is the n_phases * Af vector of
        # cosine distances from a patch to every anchor at every slerp phase.
        fixed_tri_dim = n_phases * Af
        self.fixed_w1 = nn.Parameter(torch.empty(P, fixed_tri_dim, pw_hidden))
        self.fixed_b1 = nn.Parameter(torch.zeros(1, 1, P, pw_hidden))
        self.fixed_w2 = nn.Parameter(torch.empty(P, pw_hidden, d))
        self.fixed_b2 = nn.Parameter(torch.zeros(1, 1, P, d))
        for p in range(P):
            nn.init.xavier_normal_(self.fixed_w1.data[p])
            nn.init.xavier_normal_(self.fixed_w2.data[p])
        self.fixed_norm = nn.LayerNorm(d)

        # Dynamic-path projections: per-patch Q/K for token selection, plus a
        # V projection so routed anchors carry value content, not just position.
        self.q_proj = nn.Parameter(torch.empty(P, d, d))
        self.k_proj = nn.Parameter(torch.empty(P, d, d))
        self.v_proj = nn.Parameter(torch.empty(P, d, d))
        for p in range(P):
            nn.init.xavier_normal_(self.q_proj.data[p])
            nn.init.xavier_normal_(self.k_proj.data[p])
            nn.init.xavier_normal_(self.v_proj.data[p])

        # Dynamic-path MLP: consumes the k gathered-and-weighted V vectors, flattened.
        dyn_input_dim = k * d
        self.dyn_w1 = nn.Parameter(torch.empty(P, dyn_input_dim, pw_hidden))
        self.dyn_b1 = nn.Parameter(torch.zeros(1, 1, P, pw_hidden))
        self.dyn_w2 = nn.Parameter(torch.empty(P, pw_hidden, d))
        self.dyn_b2 = nn.Parameter(torch.zeros(1, 1, P, d))
        for p in range(P):
            nn.init.xavier_normal_(self.dyn_w1.data[p])
            nn.init.xavier_normal_(self.dyn_w2.data[p])
        self.dyn_norm = nn.LayerNorm(d)

        # One scalar gate per patch and per path (passed through sigmoid in forward).
        self.fixed_gate = nn.Parameter(torch.full((P,), fixed_gate_init))
        self.dyn_gate = nn.Parameter(torch.full((P,), dyn_gate_init))

        self.norm = nn.LayerNorm(input_dim)

    def drift(self):
        """Angle (radians) between each trainable anchor and its frozen home, (P, Af)."""
        h = F.normalize(self.home, dim=-1)
        c = F.normalize(self.anchors, dim=-1)
        # Clamp keeps acos differentiable and finite at |cos| -> 1.
        cos = (h * c).sum(dim=-1).clamp(-1 + 1e-7, 1 - 1e-7)
        return torch.acos(cos)

    def at_phase(self, t):
        """Slerp between home (t=0) and current anchors (t=1) along the great circle."""
        h = F.normalize(self.home, dim=-1)
        c = F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        sin_omega = omega.sin().clamp(min=1e-7)  # avoid 0/0 when drift is ~0
        return (torch.sin((1 - t) * omega) / sin_omega * h +
                torch.sin(t * omega) / sin_omega * c)

    def forward(self, x, return_diagnostics=False):
        """x: (B, S, D). Returns x + gated update; optionally a diagnostics dict."""
        B, S, D = x.shape
        P, Af, k, d = self.n_patches, self.n_fixed, self.n_dynamic, self.patch_dim

        x_n = self.norm(x)
        patches = x_n.reshape(B, S, P, d)          # split width into P patches of size d
        patches_n = F.normalize(patches, dim=-1)   # unit vectors for cosine geometry

        # ---- Fixed path: anchor distances at several slerp phases -> per-patch MLP ----
        phases = torch.linspace(0, 1, self.n_phases).tolist()
        fixed_tris = []
        for t in phases:
            anchors_t = F.normalize(self.at_phase(t), dim=-1)
            cos_f = torch.einsum('bspd,pad->bspa', patches_n, anchors_t)
            fixed_tris.append(1.0 - cos_f)         # cosine distance to each anchor
        fixed_tri = torch.cat(fixed_tris, dim=-1)  # (B, S, P, n_phases*Af)

        h_f = torch.einsum('bspt,pth->bsph', fixed_tri, self.fixed_w1) + self.fixed_b1
        h_f = F.gelu(h_f)
        fixed_out = torch.einsum('bsph,phd->bspd', h_f, self.fixed_w2) + self.fixed_b2
        fixed_out = self.fixed_norm(fixed_out)

        # ---- Dynamic path: per-patch Q/K/V over tokens, top-k routing ----
        Q = F.normalize(torch.einsum('bspd,pde->bspe', patches_n, self.q_proj), dim=-1)
        K = F.normalize(torch.einsum('bspd,pde->bspe', patches_n, self.k_proj), dim=-1)
        V = torch.einsum('bspd,pde->bspe', patches, self.v_proj)

        # Cosine relevance between every token pair, per patch: (B, P, S, S).
        relevance = torch.einsum('bspd,btpd->bpst', Q, K)

        # Mask self-matches so a token must route information from OTHER tokens.
        # NOTE: assumes k < S, otherwise topk would pick masked (-1e9) entries.
        self_mask = torch.eye(S, device=x.device, dtype=torch.bool)
        relevance = relevance.masked_fill(self_mask.unsqueeze(0).unsqueeze(0), -1e9)

        # Softmax over all tokens, then keep only the top-k weights and renormalize.
        rel_weights = relevance.softmax(dim=-1)

        _, topk_idx = relevance.topk(k, dim=-1)    # (B, P, S, k)

        topk_weights = torch.gather(rel_weights, -1, topk_idx)
        topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + 1e-8)

        # Gather the V vectors of the selected tokens: result (B, P, S, k, d).
        V_perm = V.permute(0, 2, 1, 3)
        topk_idx_v = topk_idx.unsqueeze(-1).expand(-1, -1, -1, -1, d)
        V_expanded = V_perm.unsqueeze(2).expand(-1, -1, S, -1, -1)
        topk_V = torch.gather(V_expanded, 3, topk_idx_v)

        # Weight and flatten the k routed vectors into the dynamic MLP input.
        weighted_V = (topk_weights.unsqueeze(-1) * topk_V).reshape(B, P, S, k * d)
        weighted_V = weighted_V.permute(0, 2, 1, 3)  # back to (B, S, P, k*d)

        h_d = torch.einsum('bspt,pth->bsph', weighted_V, self.dyn_w1) + self.dyn_b1
        h_d = F.gelu(h_d)
        dyn_out = torch.einsum('bsph,phd->bspd', h_d, self.dyn_w2) + self.dyn_b2
        dyn_out = self.dyn_norm(dyn_out)

        # ---- Merge: per-patch sigmoid gates; leftover mass keeps the identity. ----
        fg = self.fixed_gate.sigmoid().view(1, 1, P, 1)
        dg = self.dyn_gate.sigmoid().view(1, 1, P, 1)
        # Clamp so the identity weight never goes negative when fg + dg > 1.
        identity_weight = (1.0 - fg - dg).clamp(min=0.0)

        blended = fg * fixed_out + dg * dyn_out + identity_weight * patches
        out = blended.reshape(B, S, D)
        result = x + out  # residual connection

        if return_diagnostics:
            drift = self.drift()
            diag = {
                'drift_mean': drift.mean().item(),
                'fixed_gate': self.fixed_gate.sigmoid().mean().item(),
                'dyn_gate': self.dyn_gate.sigmoid().mean().item(),
                'identity_weight': identity_weight.mean().item(),
                # Cosine stats of the selected (top-k) token pairs.
                'topk_cos_mean': torch.gather(relevance, -1, topk_idx).mean().item(),
                'topk_cos_max': torch.gather(relevance, -1, topk_idx).max().item(),
            }
            return result, diag
        return result
|
|
|
|
| |
| |
| |
|
|
class VanillaAttn(nn.Module):
    """Standard multi-head self-attention block with pre-norm and a residual add."""

    def __init__(self, dim, n_heads=4):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """x: (B, S, D) -> (B, S, D), softmax attention plus skip connection."""
        batch, seq, width = x.shape
        normed = self.norm(x)
        # Single fused projection, then split into per-head Q, K, V.
        projected = self.qkv(normed).reshape(batch, seq, 3, self.n_heads, self.head_dim)
        q, k, v = projected.permute(2, 0, 3, 1, 4).unbind(0)
        scores = torch.matmul(q, k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        weights = scores.softmax(dim=-1)
        context = torch.matmul(weights, v).transpose(1, 2).reshape(batch, seq, width)
        return x + self.out_proj(context)
|
|
|
|
class PureRelay(nn.Module):
    """Fixed-anchor-only relay (no cross-token path).

    Per-patch cosine distances to slerped anchors feed a per-patch MLP, which is
    sigmoid-gated against the identity. Serves as the ablation baseline for
    HybridRelay: any cross-token effect seen here would indicate leakage.
    """
    def __init__(self, input_dim, patch_dim=16, n_anchors=16, n_phases=3,
                 pw_hidden=32, gate_init=-3.0):
        super().__init__()
        assert input_dim % patch_dim == 0
        P = input_dim // patch_dim
        A, d = n_anchors, patch_dim
        self.input_dim, self.patch_dim, self.n_patches = input_dim, patch_dim, P
        self.n_anchors, self.n_phases = n_anchors, n_phases

        # 'home' is the frozen init; 'anchors' trains away from it ("drift" = angle).
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())
        tri_dim = n_phases * A
        # Per-patch two-layer MLP over the phase-stacked anchor distances.
        self.pw_w1 = nn.Parameter(torch.empty(P, tri_dim, pw_hidden))
        self.pw_b1 = nn.Parameter(torch.zeros(1, 1, P, pw_hidden))
        self.pw_w2 = nn.Parameter(torch.empty(P, pw_hidden, d))
        self.pw_b2 = nn.Parameter(torch.zeros(1, 1, P, d))
        for p in range(P):
            nn.init.xavier_normal_(self.pw_w1.data[p])
            nn.init.xavier_normal_(self.pw_w2.data[p])
        self.pw_norm = nn.LayerNorm(d)
        self.gates = nn.Parameter(torch.full((P,), gate_init))  # per-patch blend gate
        self.norm = nn.LayerNorm(input_dim)

    def drift(self):
        """Per-anchor angle (radians) between trainable anchors and frozen home."""
        h = F.normalize(self.home, dim=-1)
        c = F.normalize(self.anchors, dim=-1)
        return torch.acos((h * c).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7))

    def at_phase(self, t):
        """Slerp between home (t=0) and current anchors (t=1)."""
        h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        so = omega.sin().clamp(min=1e-7)  # guard against 0/0 at zero drift
        return torch.sin((1-t)*omega)/so * h + torch.sin(t*omega)/so * c

    def forward(self, x):
        # Accept (B, D) inputs by inserting a singleton sequence dimension.
        if x.dim() == 2: x = x.unsqueeze(1)
        B, S, D = x.shape
        P, A, d = self.n_patches, self.n_anchors, self.patch_dim
        # Tokens are processed independently: fold batch and sequence together.
        patches = self.norm(x).reshape(B*S, P, d)
        patches_n = F.normalize(patches, dim=-1)
        tris = []
        for t in torch.linspace(0, 1, self.n_phases).tolist():
            at = F.normalize(self.at_phase(t), dim=-1)
            tris.append(1.0 - torch.einsum('bpd,pad->bpa', patches_n, at))  # cosine distance
        tri = torch.cat(tris, dim=-1)  # (B*S, P, n_phases*A)
        h = F.gelu(torch.einsum('bpt,pth->bph', tri, self.pw_w1) + self.pw_b1.squeeze(1))
        pw = self.pw_norm(torch.einsum('bph,phd->bpd', h, self.pw_w2) + self.pw_b2.squeeze(1))
        g = self.gates.sigmoid().unsqueeze(0).unsqueeze(-1)
        out = (g * pw + (1-g) * patches).reshape(B, S, D)
        return x + out  # residual connection
|
|
|
|
| |
| |
| |
|
|
# Benchmark configuration: batch size, sequence length, model width, CV sample count.
B = 4
S = 256
D = 128
N_CV = 1500


print("=" * 90)
print("HYBRID CONSTELLATION RELAY v2 β SPLIT GATES + CAUSAL TEST")
print(f" B={B}, S={S}, D={D} = {D//16}p Γ 16d")
print(f" Fixed: 8 anchors Γ 3 phases | Dynamic: 8 top-k with V-projection")
print(f" Device: {DEVICE}")
print("=" * 90)


# Model factories so each test can build fresh, identically-configured instances.
configs = {
    'vanilla_attn': lambda: VanillaAttn(D, 8).to(DEVICE),
    'pure_relay': lambda: PureRelay(D, 16, 16, 3, 32).to(DEVICE),
    'hybrid_v2': lambda: HybridRelay(D, 16, 8, 8, 3, 32).to(DEVICE),
}
|
|
|
|
| |
# TEST 1: one forward pass per architecture on random data; report parameter
# count, simplex-volume CV of the normalized output, and cosine to the input.
print(f"\n{'β'*90}")
print("TEST 1: Single pass")
print(f"{'β'*90}")


x = torch.randn(B, S, D, device=DEVICE)
x_flat_n = F.normalize(x.reshape(B*S, D), dim=-1)
cv_base = compute_cv(x_flat_n, N_CV)  # CV of the untouched input, for reference
print(f" Baseline CV: {cv_base:.4f}")
print(f" {'arch':>15} {'params':>8} {'CV_n':>8} {'cos_orig':>10}")


for name, builder in configs.items():
    m = builder()
    np_ = sum(p.numel() for p in m.parameters())
    with torch.no_grad():
        out = m(x)
    out_n = F.normalize(out.reshape(B*S, D), dim=-1)
    cv = compute_cv(out_n, N_CV)
    # Mean cosine between each output token and its input token (drift from identity).
    cos = (x_flat_n * out_n).sum(-1).mean().item()
    print(f" {name:>15} {np_:>8,} {cv:>8.4f} {cos:>10.6f}")


# Gate diagnostics from a fresh (untrained) hybrid instance.
hybrid_diag = HybridRelay(D, 16, 8, 8, 3, 32).to(DEVICE)
with torch.no_grad():
    _, diag = hybrid_diag(x, return_diagnostics=True)
print(f"\n Hybrid gates: fixed={diag['fixed_gate']:.4f} dyn={diag['dyn_gate']:.4f} "
      f"identity={diag['identity_weight']:.4f}")
|
|
|
|
| |
| print(f"\n{'β'*90}") |
| print("TEST 2: Depth 16") |
| print(f"{'β'*90}") |
|
|
| x = torch.randn(B, S, D, device=DEVICE) |
| x_flat_n = F.normalize(x.reshape(B*S, D), dim=-1) |
| checks = [1, 2, 4, 8, 12, 16] |
|
|
| for name, builder in configs.items(): |
| print(f"\n {name}:") |
| print(f" {'d':>4} {'CV_n':>8} {'cos':>10} {'eff_d':>8}") |
| stack = nn.ModuleList([builder() for _ in range(16)]) |
| z = x.clone() |
| for i, layer in enumerate(stack): |
| with torch.no_grad(): z = layer(z) |
| if (i+1) in checks: |
| zn = F.normalize(z.reshape(B*S, D), dim=-1) |
| print(f" {i+1:>4} {compute_cv(zn, N_CV):>8.4f} " |
| f"{(x_flat_n * zn).sum(-1).mean().item():>10.6f} " |
| f"{eff_dim(z.reshape(B*S, D)):>8.1f}") |
|
|
|
|
| |
| print(f"\n{'β'*90}") |
| print("TEST 3: Interleaved attn β hybrid β attn β hybrid") |
| print(f"{'β'*90}") |
|
|
| x = torch.randn(B, S, D, device=DEVICE) |
| x_flat_n = F.normalize(x.reshape(B*S, D), dim=-1) |
|
|
| attn_l = nn.ModuleList([VanillaAttn(D, 8).to(DEVICE) for _ in range(8)]) |
| hyb_l = nn.ModuleList([HybridRelay(D, 16, 8, 8, 3, 32).to(DEVICE) for _ in range(8)]) |
|
|
| print(f" {'step':>4} {'type':>8} {'CV_n':>8} {'cos':>10} {'eff_d':>8}") |
| z = x.clone() |
| step = 0 |
| for i in range(8): |
| with torch.no_grad(): z = attn_l[i](z) |
| step += 1 |
| if step in checks: |
| zn = F.normalize(z.reshape(B*S, D), dim=-1) |
| print(f" {step:>4} {'attn':>8} {compute_cv(zn, N_CV):>8.4f} " |
| f"{(x_flat_n * zn).sum(-1).mean().item():>10.6f} " |
| f"{eff_dim(z.reshape(B*S, D)):>8.1f}") |
| with torch.no_grad(): z = hyb_l[i](z) |
| step += 1 |
| if step in checks: |
| zn = F.normalize(z.reshape(B*S, D), dim=-1) |
| print(f" {step:>4} {'hybrid':>8} {compute_cv(zn, N_CV):>8.4f} " |
| f"{(x_flat_n * zn).sum(-1).mean().item():>10.6f} " |
| f"{eff_dim(z.reshape(B*S, D)):>8.1f}") |
|
|
|
|
| |
| print(f"\n{'β'*90}") |
| print("TEST 4: Causal intervention β does changing token 0 affect other tokens?") |
| print(f" Run same sequence twice, swap only token 0. Measure Ξ on tokens 1-31.") |
| print(f"{'β'*90}") |
|
|
| S_test = 32 |
| x_a = torch.randn(1, S_test, D, device=DEVICE) |
| x_b = x_a.clone() |
| x_b[:, 0] = torch.randn(1, D, device=DEVICE) |
|
|
| print(f" Token 0 cosine between A and B: " |
| f"{F.cosine_similarity(x_a[:, 0], x_b[:, 0]).item():.4f}") |
| print(f" Tokens 1-31 identical: " |
| f"{(x_a[:, 1:] == x_b[:, 1:]).all().item()}") |
|
|
| print(f"\n {'arch':>15} {'other_Ξ_norm':>12} {'other_Ξ_cos':>12} {'t0_Ξ_norm':>10}") |
|
|
| for name, builder in configs.items(): |
| m = builder() |
| with torch.no_grad(): |
| out_a = m(x_a) |
| out_b = m(x_b) |
|
|
| |
| delta_others = (out_a[:, 1:] - out_b[:, 1:]) |
| other_norm = delta_others.norm(dim=-1).mean().item() |
| |
| cos_others = F.cosine_similarity( |
| out_a[:, 1:].reshape(-1, D), |
| out_b[:, 1:].reshape(-1, D)).mean().item() |
| |
| t0_norm = (out_a[:, 0] - out_b[:, 0]).norm().item() |
|
|
| print(f" {name:>15} {other_norm:>12.6f} {1-cos_others:>12.8f} {t0_norm:>10.4f}") |
|
|
|
|
| |
| print(f"\n After 4 stacked layers:") |
| print(f" {'arch':>15} {'other_Ξ_norm':>12} {'other_Ξ_cos':>12}") |
|
|
| for name, builder in configs.items(): |
| layers = nn.ModuleList([builder() for _ in range(4)]) |
| with torch.no_grad(): |
| za, zb = x_a.clone(), x_b.clone() |
| for layer in layers: |
| za = layer(za) |
| zb = layer(zb) |
|
|
| delta = (za[:, 1:] - zb[:, 1:]) |
| other_norm = delta.norm(dim=-1).mean().item() |
| cos_others = F.cosine_similarity( |
| za[:, 1:].reshape(-1, D), |
| zb[:, 1:].reshape(-1, D)).mean().item() |
| print(f" {name:>15} {other_norm:>12.6f} {1-cos_others:>12.8f}") |
|
|
|
|
| |
| print(f"\n{'β'*90}") |
| print("TEST 5: Throughput") |
| print(f"{'β'*90}") |
|
|
| x_bench = torch.randn(B, S, D, device=DEVICE) |
| print(f" {'arch':>15} {'ms':>8} {'params':>10}") |
|
|
| for name, builder in configs.items(): |
| m = builder() |
| np_ = sum(p.numel() for p in m.parameters()) |
| for _ in range(5): |
| with torch.no_grad(): _ = m(x_bench) |
| torch.cuda.synchronize() |
| t0 = time.time() |
| for _ in range(100): |
| with torch.no_grad(): _ = m(x_bench) |
| torch.cuda.synchronize() |
| ms = (time.time() - t0) / 100 * 1000 |
| print(f" {name:>15} {ms:>8.2f} {np_:>10,}") |
|
|
|
|
| |
| print(f"\n{'β'*90}") |
| print("TEST 6: Sequence length scaling") |
| print(f"{'β'*90}") |
|
|
| print(f" {'S':>6} {'hybrid_ms':>10} {'attn_ms':>10} {'ratio':>8}") |
| for sl in [64, 128, 256, 512, 1024]: |
| xs = torch.randn(2, sl, D, device=DEVICE) |
| h_m = HybridRelay(D, 16, 8, 8, 3, 32).to(DEVICE) |
| a_m = VanillaAttn(D, 8).to(DEVICE) |
| with torch.no_grad(): _ = h_m(xs); _ = a_m(xs) |
| torch.cuda.synchronize() |
|
|
| t0 = time.time() |
| for _ in range(50): |
| with torch.no_grad(): _ = h_m(xs) |
| torch.cuda.synchronize() |
| h_ms = (time.time() - t0) / 50 * 1000 |
|
|
| t0 = time.time() |
| for _ in range(50): |
| with torch.no_grad(): _ = a_m(xs) |
| torch.cuda.synchronize() |
| a_ms = (time.time() - t0) / 50 * 1000 |
| print(f" {sl:>6} {h_ms:>10.2f} {a_ms:>10.2f} {h_ms/a_ms:>8.2f}Γ") |
|
|
|
|
| |
| |
| |
|
|
# Final summary banner. Static text: the quoted gate values (0.047 / 0.269)
# are the sigmoid of the init constants, not measurements from this run.
print(f"\n{'='*90}")
print("SUMMARY")
print(f"{'='*90}")
print(f"""
Hybrid v2 architecture per patch:
Fixed: 8 anchors Γ 3 phases β MLP β fixed_out (gate β 0.047)
Dynamic: top-8 QΒ·K β gather V β weighted sum β MLP β dyn_out (gate β 0.269)
Output: fg*fixed + dg*dynamic + (1-fg-dg)*identity + skip

GPT's challenge:
β Selective interaction β QΒ·K top-k selection
β Conditional transformation β separate MLPs for fixed/dynamic
β Information routing β V-projection carries information through geometric channel

Key test: Causal intervention (Test 4)
If other_Ξ_norm > 0 for hybrid but β 0 for pure_relay,
cross-token routing is proven.
""")
print(f"{'='*90}")
print("DONE")
print(f"{'='*90}")