#!/usr/bin/env python3
"""
Constellation-Cantor Relay — O(S) Cross-Token Routing
=======================================================

Replaces attention entirely with triangulation-mediated hierarchical routing.
The author conjectures this is among the strongest routing mechanisms
currently available, pending further theoretical results; the benchmarks
below are the evidence offered for that claim.

Architecture:
    per-token:   constellation relay (triangulate → patchwork → gated residual)
    cross-token: Cantor router (hierarchical scatter/gather through anchor tree)

The triangulation profile IS the routing key. Tokens near the same anchor on
S^(d-1) share information at level 0. Anchor pairs share at level 1. Quads at
level 2. Full global at level log2(A).

Total cross-token cost: O(S × n_levels) = O(S × 5) for 16 anchors
(levels 0..log2(A), i.e. log2(16) + 1 = 5).
Total per-token cost: O(S × tri_dim × pw_hidden).
No attention anywhere. Fully O(S).

Benchmarks:
    1. Throughput: cantor-relay vs hybrid vs pure relay vs attention
    2. Cross-token causal intervention at scale
    3. Geometric preservation
    4. Trained task requiring cross-token routing
    5. The O(S²) wall: cantor vs attention at depth 8

The benchmarks only run when the file is executed as a script; importing the
module just defines the layers and helpers.
"""

import os

# NOTE(review): this forces synchronous kernel launches, which inflates the
# wall-clock numbers reported by the throughput tests. It is kept because the
# original script sets it; it must stay above the torch import to take effect.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import sys
import time
import gc
from collections import OrderedDict

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True


# ══════════════════════════════════════════════════════════════════
# ACTIVATIONS
# ══════════════════════════════════════════════════════════════════

class SquaredReLU(nn.Module):
    """Element-wise relu(x) ** 2."""

    def forward(self, x):
        return F.relu(x) ** 2


# ══════════════════════════════════════════════════════════════════
# CONSTELLATION RELAY — per-token geometric layer
# ══════════════════════════════════════════════════════════════════

class ConstellationRelay(nn.Module):
    """Per-token constellation triangulation + patchwork. O(S).

    Each token is layer-normed, split into ``dim // patch_dim`` patches, and
    every unit-normalised patch is compared (1 - cosine) against ``n_anchors``
    anchors at ``n_phases`` interpolation phases between the frozen ``home``
    buffer and the learned ``anchors`` parameter. The resulting profile is fed
    through an MLP and added back through a sigmoid gate.
    """

    def __init__(self, dim=256, patch_dim=16, n_anchors=16, n_phases=3):
        super().__init__()
        self.dim = dim
        self.patch_dim = patch_dim
        self.n_patches = dim // patch_dim
        self.n_anchors = n_anchors
        self.n_phases = n_phases
        P, A, d = self.n_patches, n_anchors, patch_dim

        self.ln = nn.LayerNorm(dim)

        # `home` is the frozen initial anchor set; `anchors` starts as a copy
        # of it and is the trainable version.
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())

        tri_dim = P * A * n_phases
        self.tri_dim = tri_dim
        pw_hidden = tri_dim * 2
        self.patchwork = nn.Sequential(
            nn.Linear(tri_dim, pw_hidden),
            SquaredReLU(),
            nn.LayerNorm(pw_hidden),
            nn.Linear(pw_hidden, dim),
        )
        # sigmoid(-3) ≈ 0.047: the layer starts close to the identity.
        self.gate = nn.Parameter(torch.tensor(-3.0))

    def drift(self):
        """Return the angle between each home anchor and its learned anchor, shape (P, A)."""
        h = F.normalize(self.home.float(), dim=-1)
        c = F.normalize(self.anchors.float(), dim=-1)
        # Clamp keeps acos away from ±1.
        return torch.acos((h * c).sum(-1).clamp(-1 + 1e-6, 1 - 1e-6))

    def at_phase(self, t):
        """Spherically interpolate from home (t=0) to learned anchors (t=1)."""
        h = F.normalize(self.home.float(), dim=-1)
        c = F.normalize(self.anchors.float(), dim=-1)
        omega = self.drift().unsqueeze(-1)
        so = omega.sin().clamp(min=1e-6)
        return torch.sin((1 - t) * omega) / so * h + torch.sin(t * omega) / so * c

    def triangulate(self, patches_n):
        """Return 1 - cos distances of every patch to every anchor at every phase.

        patches_n: (N, P, d) unit-normalised patches.
        Returns: (N, P * n_phases * A). Phases are concatenated on the last
        axis of a (N, P, A) tensor before flattening, so the flattened layout
        is ordered (patch, phase, anchor).
        """
        phases = torch.linspace(0, 1, self.n_phases, device=patches_n.device).tolist()
        tris = []
        for t in phases:
            at = F.normalize(self.at_phase(t), dim=-1).to(patches_n.dtype)
            tris.append(1.0 - torch.einsum('bpd,pad->bpa', patches_n, at))
        return torch.cat(tris, dim=-1).reshape(patches_n.shape[0], -1)

    def forward(self, x):
        """x: (B*S, D) or (B, S, D). Returns (out, tri) with matching leading dims."""
        is_seq = x.dim() == 3
        if is_seq:
            B, S, D = x.shape
            x_flat = x.reshape(B * S, D)
        else:
            x_flat = x
        residual = x_flat
        h = self.ln(x_flat)
        patches = h.reshape(-1, self.n_patches, self.patch_dim)
        patches_n = F.normalize(patches, dim=-1)
        tri = self.triangulate(patches_n)
        pw_out = self.patchwork(tri)
        g = self.gate.sigmoid()
        out = residual + g * pw_out
        if is_seq:
            return out.reshape(B, S, D), tri.reshape(B, S, -1)
        return out, tri

    def forward_no_tri(self, x):
        """Original forward without returning tri — for compatibility."""
        out, _ = self.forward(x)
        return out


# ══════════════════════════════════════════════════════════════════
# CANTOR CONSTELLATION ROUTER — hierarchical cross-token, O(S)
# ══════════════════════════════════════════════════════════════════

class CantorConstellationRouter(nn.Module):
    """
    Hierarchical cross-token routing through the constellation anchor tree.

    The triangulation profile assigns each token to a region on S^(d-1).
    A binary tree over anchors defines the routing hierarchy:
        Level 0: A groups (per-anchor, local neighbors)
        Level 1: A/2 groups (anchor pairs, nearby interaction)
        Level 2: A/4 groups (quads, medium range)
        ...
        Level L: 1 group (global summary)

    At each level:
        1. Soft-assign tokens to groups via triangulation weights
        2. Weighted scatter: aggregate token representations per group
        3. Transform: per-level MLP on group summaries
        4. Weighted gather: distribute transformed summaries back to tokens
        5. Gated residual addition

    Cost: O(S × L × D) where L = log2(A) + 1 = 5 for A=16.
    Memory: O(S × D + A × D) — no S² term anywhere.
    """

    def __init__(self, dim=256, n_anchors=16, n_patches=16):
        super().__init__()
        self.dim = dim
        self.n_anchors = n_anchors
        self.n_patches = n_patches
        self.n_levels = int(math.log2(n_anchors)) + 1  # 5 for A=16

        # Level l merges adjacent anchors (in index order) into bins of
        # size 2^l; see forward(). Each level has its own LN, MLP and gate.
        self.level_mlps = nn.ModuleList()
        self.level_gates = nn.ParameterList()
        self.level_lns = nn.ModuleList()

        for _ in range(self.n_levels):
            self.level_mlps.append(nn.Sequential(
                nn.Linear(dim, dim * 2),
                SquaredReLU(),
                nn.Linear(dim * 2, dim),
            ))
            self.level_lns.append(nn.LayerNorm(dim))
            self.level_gates.append(nn.Parameter(torch.tensor(-3.0)))

        # NOTE(review): weight_proj is never called in this file — routing
        # weights come from compute_routing_weights(). It is kept so that
        # parameter counts and state_dict keys stay unchanged.
        self.weight_proj = nn.Linear(n_patches * n_anchors, n_anchors)

    def compute_routing_weights(self, tri, n_anchors):
        """
        Extract soft anchor assignment weights from triangulation profile.

        tri: (BS, tri_dim) — flattened triangulation as produced by
            ConstellationRelay.triangulate, laid out (patch, phase, anchor),
            so tri_dim = n_patches × n_phases × n_anchors.
        n_anchors: number of anchors A.
        Returns: (BS, n_anchors) — soft assignment weights (sum to 1)
        Raises: ValueError if tri_dim is not a positive multiple of
            n_patches × n_anchors.
        """
        BS = tri.shape[0]
        per_phase = self.n_patches * n_anchors
        n_phases = tri.shape[-1] // per_phase
        if n_phases < 1 or n_phases * per_phase != tri.shape[-1]:
            raise ValueError(
                f"tri width {tri.shape[-1]} is not a positive multiple of "
                f"n_patches * n_anchors = {per_phase}")

        # Phases are interleaved per patch, so phase 0 is NOT the first
        # n_patches * n_anchors columns; un-flatten and index it explicitly.
        phase0 = tri.reshape(BS, self.n_patches, n_phases, n_anchors)[:, :, 0, :]

        # Average over patches to get per-anchor proximity: (BS, A)
        dists = phase0.mean(dim=1)

        # Convert distances to weights: closer = higher weight.
        # dists are in [0, 2] (1 - cos), so proximity = 2 - dists
        proximity = (2.0 - dists).clamp(min=0)
        weights = F.softmax(proximity * 5.0, dim=-1)  # temperature-scaled
        return weights

    def forward(self, x, tri):
        """
        x: (B, S, D) token representations
        tri: (B, S, tri_dim) triangulation profiles from constellation

        Returns: (B, S, D) with cross-token information routed through anchor hierarchy
        """
        B, S, D = x.shape
        x_flat = x.reshape(B * S, D)
        tri_flat = tri.reshape(B * S, -1)

        # Compute soft routing weights: (BS, A)
        weights = self.compute_routing_weights(tri_flat, self.n_anchors)

        h = x_flat  # working copy

        for level in range(self.n_levels):
            group_size = 2 ** level
            n_groups = max(1, self.n_anchors // group_size)

            # Merge anchor weights into group weights:
            # (BS, A) → (BS, n_groups, group_size) → sum over group
            if n_groups > 1:
                group_weights = weights.reshape(B * S, n_groups, group_size).sum(dim=-1)
            else:
                group_weights = weights.sum(dim=-1, keepdim=True)  # (BS, 1)

            # Normalize group weights
            group_weights = group_weights / (group_weights.sum(dim=-1, keepdim=True) + 1e-8)

            # Grouping is per batch element, so go back to (B, S, ...).
            gw = group_weights.reshape(B, S, n_groups)  # (B, S, G)
            hh = h.reshape(B, S, D)                     # (B, S, D)

            # Weighted scatter: (B, G, S) @ (B, S, D) → (B, G, D)
            group_summary = torch.bmm(gw.transpose(1, 2), hh)

            # Normalize by total weight per group
            weight_sums = gw.sum(dim=1).unsqueeze(-1).clamp(min=1e-8)  # (B, G, 1)
            group_summary = group_summary / weight_sums

            # Transform
            gs_flat = group_summary.reshape(B * n_groups, D)
            gs_flat = self.level_lns[level](gs_flat)
            gs_transformed = self.level_mlps[level](gs_flat)
            gs_transformed = gs_transformed.reshape(B, n_groups, D)

            # Weighted gather: (B, S, G) @ (B, G, D) → (B, S, D)
            token_update = torch.bmm(gw, gs_transformed).reshape(B * S, D)

            # Gated residual
            g = self.level_gates[level].sigmoid()
            h = h + g * token_update

        return h.reshape(B, S, D)


# ══════════════════════════════════════════════════════════════════
# CONSTELLATION-CANTOR RELAY — FULL O(S) TRANSFORMER LAYER
# ══════════════════════════════════════════════════════════════════

class ConstellationCantorRelay(nn.Module):
    """
    Complete O(S) transformer layer. No attention.

    per-token: ConstellationRelay (triangulate → patchwork → gated residual)
    cross-token: CantorConstellationRouter (hierarchical scatter/gather through anchors)

    The triangulation from the per-token relay is reused as routing keys
    for the cross-token path — no redundant computation.
    """

    def __init__(self, dim=256, patch_dim=16, n_anchors=16, n_phases=3):
        super().__init__()
        self.relay = ConstellationRelay(
            dim=dim, patch_dim=patch_dim, n_anchors=n_anchors, n_phases=n_phases)
        self.router = CantorConstellationRouter(
            dim=dim, n_anchors=n_anchors, n_patches=dim // patch_dim)
        self.gate_relay = nn.Parameter(torch.tensor(-2.0))
        self.gate_router = nn.Parameter(torch.tensor(-2.0))

    def forward(self, x):
        """x: (B, S, D). Returns (B, S, D)."""
        # Per-token relay — returns output + triangulation
        relay_out, tri = self.relay(x)  # (B, S, D), (B, S, tri_dim)
        relay_delta = relay_out - x

        # Cross-token routing using triangulation as routing key.
        # The router sees the ORIGINAL x, not the relay output.
        routed = self.router(x, tri)  # (B, S, D)
        router_delta = routed - x

        # Gated combination (names avoid shadowing the imported `gc` module)
        g_relay = self.gate_relay.sigmoid()
        g_router = self.gate_router.sigmoid()
        return x + g_relay * relay_delta + g_router * router_delta


# ══════════════════════════════════════════════════════════════════
# COMPARISON ARCHITECTURES
# ══════════════════════════════════════════════════════════════════

class VanillaAttention(nn.Module):
    """Standard attention layer for comparison. O(S²)."""

    def __init__(self, dim=256, n_heads=4):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.ln = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """x: (B, S, D). Pre-LN self-attention with a residual connection."""
        B, S, D = x.shape
        h = self.ln(x)
        qkv = self.qkv(h).reshape(B, S, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.unbind(2)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        attn = F.scaled_dot_product_attention(q, k, v)
        return x + self.proj(attn.transpose(1, 2).reshape(B, S, D))


class HybridRelay(nn.Module):
    """Constellation relay + vanilla attention. For comparison."""

    def __init__(self, dim=256, n_heads=4):
        super().__init__()
        self.relay = ConstellationRelay(dim=dim)
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.qkv = nn.Linear(dim, 3 * dim)
        self.attn_proj = nn.Linear(dim, dim)
        self.attn_ln = nn.LayerNorm(dim)
        self.gate_relay = nn.Parameter(torch.tensor(-2.0))
        self.gate_attn = nn.Parameter(torch.tensor(-2.0))

    def forward(self, x):
        """x: (B, S, D). Gated sum of the relay delta and the attention output."""
        B, S, D = x.shape
        relay_out = self.relay.forward_no_tri(x)
        relay_delta = relay_out - x

        h = self.attn_ln(x)
        qkv = self.qkv(h).reshape(B, S, 3, self.n_heads, self.head_dim)
        q, k, v = qkv.unbind(2)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        attn = F.scaled_dot_product_attention(q, k, v)
        attn_out = self.attn_proj(attn.transpose(1, 2).reshape(B, S, D))

        gr = self.gate_relay.sigmoid()
        ga = self.gate_attn.sigmoid()
        return x + gr * relay_delta + ga * attn_out


class PureRelayLayer(nn.Module):
    """Relay-only, no cross-token. For comparison."""

    def __init__(self, dim=256):
        super().__init__()
        self.relay = ConstellationRelay(dim=dim)

    def forward(self, x):
        return self.relay.forward_no_tri(x)


# ══════════════════════════════════════════════════════════════════
# UTILITIES
# ══════════════════════════════════════════════════════════════════

def reset_vram():
    """Run the garbage collector and, on CUDA, empty the cache and reset peak stats."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()


def peak_mb():
    """Return peak allocated CUDA memory in MB (0.0 when CUDA is unavailable)."""
    if not torch.cuda.is_available():
        return 0.0
    return torch.cuda.max_memory_allocated() / 1e6


def _sync():
    """Wait for queued CUDA work so wall-clock timings are meaningful; no-op on CPU."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _note_failure(exc):
    """Report a caught RuntimeError on stderr unless it is an out-of-memory error.

    The benchmark tables label every caught failure as OOM; this keeps
    unrelated errors (shape bugs, unsupported dtypes, ...) from being hidden.
    """
    if isinstance(exc, torch.cuda.OutOfMemoryError):
        return
    if "out of memory" in str(exc).lower():
        return
    print(f"  [non-OOM failure: {type(exc).__name__}: {exc}]", file=sys.stderr)


def compute_cv(points, n_samples=500):
    """Coefficient of variation of random 4-simplex volumes among `points`.

    points: (N, d). Rows are unit-normalised, then `n_samples` random
    5-point subsets (drawn from the first min(N, 2000) rows) have their
    simplex volume computed via the Cayley-Menger determinant.
    Returns std/mean of the volumes, or NaN if N < 5 or fewer than 50
    subsets are non-degenerate.
    """
    N = points.shape[0]
    if N < 5:
        return float('nan')
    points = F.normalize(points.float(), dim=-1)
    vols = []
    for _ in range(n_samples):
        idx = torch.randperm(min(N, 2000), device=points.device)[:5]
        pts = points[idx].unsqueeze(0)
        gram = torch.bmm(pts, pts.transpose(1, 2))
        norms = torch.diagonal(gram, dim1=1, dim2=2)
        d2 = norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram
        d2 = F.relu(d2)
        cm = torch.zeros(1, 6, 6, device=points.device, dtype=torch.float32)
        cm[:, 0, 1:] = 1
        cm[:, 1:, 0] = 1
        cm[:, 1:, 1:] = d2
        # V² = -det(CM) / (2^4 · (4!)²) = -det(CM) / 9216 for a 4-simplex.
        v2 = -torch.linalg.det(cm) / 9216
        if v2[0].item() > 1e-20:
            vols.append(v2[0].sqrt().cpu())
    if len(vols) < 50:
        return float('nan')
    vt = torch.stack(vols)
    return (vt.std() / (vt.mean() + 1e-8)).item()


# ══════════════════════════════════════════════════════════════════
# BENCHMARK CONFIGURATION
# ══════════════════════════════════════════════════════════════════

D = 256
SEQ_LENGTHS = [64, 256, 1024, 4096, 16384, 32768, 65536, 131072]
N_LAYERS = 4
GEO_DEPTH = 8
GEO_S = 4096
S_TASK = 8
N_CLS = 10
N_SAMPLES = 4096
STEPS = 500
WALL_DEPTH = 8


class TaskModel(nn.Module):
    """Layer stack + flatten/linear pooling + classifier head, used by Test 4."""

    def __init__(self, stack):
        super().__init__()
        self.layers = stack
        self.pool = nn.Linear(D * S_TASK, D)
        self.head = nn.Linear(D, N_CLS)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return self.head(F.gelu(self.pool(x.reshape(x.shape[0], -1))))


# ══════════════════════════════════════════════════════════════════
# TEST 1: THROUGHPUT — ALL FOUR ARCHITECTURES
# ══════════════════════════════════════════════════════════════════

def run_test1_throughput():
    """Time one fp16 layer of each architecture at B=1 over SEQ_LENGTHS."""
    print(f"\n{'━'*80}")
    print("TEST 1: Throughput Scaling — 4 architectures, S=64 to 131K")
    print(" Single layer, B=1, fp16")
    print(f"{'━'*80}")

    print(f"\n {'S':>8} {'relay':>9} {'cantor':>9} {'hybrid':>9} {'attn':>9} "
          f"{'c/r':>6} {'c/a':>6} {'c_MB':>7}")

    def fmt(v):
        return f"{v:>8.2f}ms" if v < float('inf') else " OOM"

    for S in SEQ_LENGTHS:
        results = {}
        for name, make_layer in [
            ("relay", lambda: PureRelayLayer(D)),
            ("cantor", lambda: ConstellationCantorRelay(D)),
            ("hybrid", lambda: HybridRelay(D)),
            ("attn", lambda: VanillaAttention(D)),
        ]:
            try:
                reset_vram()
                layer = make_layer().to(DEVICE).half()
                x = F.normalize(torch.randn(1, S, D, device=DEVICE, dtype=torch.float16), dim=-1)

                # Warmup
                with torch.no_grad():
                    for _ in range(3):
                        _ = layer(x)
                _sync()

                t0 = time.perf_counter()
                with torch.no_grad():
                    for _ in range(10):
                        _ = layer(x)
                _sync()
                ms = (time.perf_counter() - t0) / 10 * 1000
                mb = peak_mb()
                results[name] = (ms, mb)
                del layer, x
                reset_vram()
            except (torch.cuda.OutOfMemoryError, RuntimeError) as exc:
                _note_failure(exc)
                results[name] = (float('inf'), float('inf'))
                reset_vram()

        r = results.get("relay", (0, 0))[0]
        c = results.get("cantor", (0, 0))[0]
        h = results.get("hybrid", (0, 0))[0]
        a = results.get("attn", (0, 0))[0]
        c_mb = results.get("cantor", (0, 0))[1]

        cr_ratio = f"{c/r:>5.1f}×" if r > 0 and c < float('inf') else " -"
        ca_ratio = f"{c/a:>5.1f}×" if a > 0 and a < float('inf') and c < float('inf') else " -"

        print(f" {S:>8} {fmt(r)} {fmt(c)} {fmt(h)} {fmt(a)} "
              f"{cr_ratio} {ca_ratio} {c_mb:>7.0f}")


# ══════════════════════════════════════════════════════════════════
# TEST 2: CROSS-TOKEN CAUSAL INTERVENTION — CANTOR vs HYBRID
# ══════════════════════════════════════════════════════════════════

def run_test2_intervention():
    """Replace token 0 and measure how much tokens S//2 and S-1 change."""
    print(f"\n{'━'*80}")
    print("TEST 2: Cross-Token Causal Intervention")
    print(" Modify token 0, measure effect on token S//2")
    print(" 4 layers deep. Compare: cantor relay vs hybrid vs pure relay")
    print(f"{'━'*80}")

    print(f"\n {'S':>8} {'arch':>10} {'Δ_mid':>10} {'Δ_last':>10} "
          f"{'cos_orig':>10} {'time_ms':>10}")

    for S in [64, 256, 1024, 4096, 16384]:
        for arch_name, make_stack in [
            ("cantor", lambda: nn.ModuleList([ConstellationCantorRelay(D) for _ in range(N_LAYERS)])),
            ("hybrid", lambda: nn.ModuleList([HybridRelay(D) for _ in range(N_LAYERS)])),
            ("relay", lambda: nn.ModuleList([PureRelayLayer(D) for _ in range(N_LAYERS)])),
        ]:
            try:
                reset_vram()
                torch.manual_seed(42)
                stack = make_stack().to(DEVICE).half()

                x = F.normalize(torch.randn(1, S, D, device=DEVICE, dtype=torch.float16), dim=-1)
                x_mod = x.clone()
                x_mod[:, 0] = F.normalize(torch.randn(1, D, device=DEVICE, dtype=torch.float16), dim=-1)

                _sync()
                t0 = time.perf_counter()
                with torch.no_grad():
                    h = x.clone()
                    h_mod = x_mod.clone()
                    for layer in stack:
                        h = layer(h)
                        h_mod = layer(h_mod)
                _sync()
                elapsed = (time.perf_counter() - t0) * 1000

                mid = S // 2
                delta_mid = (h[0, mid].float() - h_mod[0, mid].float()).norm().item()
                delta_last = (h[0, -1].float() - h_mod[0, -1].float()).norm().item()
                cos_orig = F.cosine_similarity(
                    x[0, mid:mid+1].float(), h[0, mid:mid+1].float()).item()

                print(f" {S:>8} {arch_name:>10} {delta_mid:>10.4f} {delta_last:>10.4f} "
                      f"{cos_orig:>10.4f} {elapsed:>10.1f}")

                del stack, x, x_mod, h, h_mod
                reset_vram()
            except (torch.cuda.OutOfMemoryError, RuntimeError) as exc:
                _note_failure(exc)
                print(f" {S:>8} {arch_name:>10} OOM")
                reset_vram()
        print()


# ══════════════════════════════════════════════════════════════════
# TEST 3: GEOMETRIC PRESERVATION WITH CROSS-TOKEN ROUTING
# ══════════════════════════════════════════════════════════════════

def run_test3_geometry():
    """Compare cosine-to-input, norm, CV, effective dim and self-similarity after GEO_DEPTH layers."""
    print(f"\n{'━'*80}")
    print("TEST 3: Geometric Preservation — does Cantor routing hurt geometry?")
    print(" 8 layers, S=4096. Compare cos_to_orig, CV, eff_dim.")
    print(f"{'━'*80}")

    print(f"\n {'arch':>10} {'cos_orig':>10} {'norm':>8} {'CV':>8} "
          f"{'eff_dim':>8} {'self_sim':>10}")

    for arch_name, make_stack in [
        ("relay", lambda: nn.ModuleList([PureRelayLayer(D) for _ in range(GEO_DEPTH)])),
        ("cantor", lambda: nn.ModuleList([ConstellationCantorRelay(D) for _ in range(GEO_DEPTH)])),
        ("hybrid", lambda: nn.ModuleList([HybridRelay(D) for _ in range(GEO_DEPTH)])),
        ("attn", lambda: nn.ModuleList([VanillaAttention(D) for _ in range(GEO_DEPTH)])),
    ]:
        try:
            reset_vram()
            torch.manual_seed(42)
            stack = make_stack().to(DEVICE).half()
            x = F.normalize(torch.randn(1, GEO_S, D, device=DEVICE, dtype=torch.float16), dim=-1)

            with torch.no_grad():
                h = x.clone()
                for layer in stack:
                    h = layer(h)

            # Statistics are computed on the first 512 tokens in fp32.
            x_s = x[0, :512].float()
            h_s = h[0, :512].float()
            cos = F.cosine_similarity(x_s, h_s).mean().item()
            norm = h_s.norm(dim=-1).mean().item()
            h_n = F.normalize(h_s, dim=-1)
            sim = h_n @ h_n.T
            mask = ~torch.eye(512, device=DEVICE, dtype=torch.bool)
            self_sim = sim[mask].mean().item()
            cv = compute_cv(h_n, 500)
            # Participation ratio of the singular values of the first 256 rows.
            _, S_vals, _ = torch.linalg.svd(h_n[:256], full_matrices=False)
            p = S_vals / S_vals.sum()
            ed = p.pow(2).sum().reciprocal().item()

            print(f" {arch_name:>10} {cos:>10.4f} {norm:>8.4f} {cv:>8.4f} "
                  f"{ed:>8.1f} {self_sim:>10.6f}")

            del stack, x, h
            reset_vram()
        except (torch.cuda.OutOfMemoryError, RuntimeError) as exc:
            _note_failure(exc)
            print(f" {arch_name:>10} OOM")
            reset_vram()


# ══════════════════════════════════════════════════════════════════
# TEST 4: TRAINED CROSS-TOKEN TASK — ALL ARCHITECTURES
# ══════════════════════════════════════════════════════════════════

def run_test4_trained_task():
    """Train each 4-layer stack on a label that depends on tokens 0 and 1 jointly."""
    print(f"\n{'━'*80}")
    print("TEST 4: Trained Cross-Token Task")
    print(" Label = (token_0_class + token_1_class) % 10")
    print(" Pure relay CANNOT solve this (zero cross-token info).")
    print(" 4 layers, 500 steps, S=8.")
    print(f"{'━'*80}")

    torch.manual_seed(777)
    keys_a = F.normalize(torch.randn(N_CLS, D, device=DEVICE), dim=-1)
    keys_b = F.normalize(torch.randn(N_CLS, D, device=DEVICE), dim=-1)

    task_x = F.normalize(torch.randn(N_SAMPLES, S_TASK, D, device=DEVICE), dim=-1).clone()
    label_a = torch.randint(0, N_CLS, (N_SAMPLES,), dtype=torch.long, device=DEVICE)
    label_b = torch.randint(0, N_CLS, (N_SAMPLES,), dtype=torch.long, device=DEVICE)
    # Tokens 0 and 1 carry noisy class keys; the remaining tokens are noise.
    task_x[:, 0] = keys_a[label_a] + torch.randn(N_SAMPLES, D, device=DEVICE) * 0.2
    task_x[:, 1] = keys_b[label_b] + torch.randn(N_SAMPLES, D, device=DEVICE) * 0.2
    task_x = F.normalize(task_x, dim=-1)
    task_y = ((label_a + label_b) % N_CLS).long()

    print(f"\n {'arch':>10} {'acc':>8} {'loss':>8} {'cross_Δ':>10} {'params':>10}")

    for arch_name, make_stack in [
        ("relay", lambda: nn.ModuleList([PureRelayLayer(D) for _ in range(4)])),
        ("cantor", lambda: nn.ModuleList([ConstellationCantorRelay(D) for _ in range(4)])),
        ("hybrid", lambda: nn.ModuleList([HybridRelay(D) for _ in range(4)])),
        ("attn", lambda: nn.ModuleList([VanillaAttention(D) for _ in range(4)])),
    ]:
        torch.manual_seed(42)

        model = TaskModel(make_stack()).to(DEVICE)
        n_params = sum(p.numel() for p in model.parameters())
        opt = torch.optim.Adam(model.parameters(), lr=3e-4)

        for step in range(STEPS):
            idx = torch.randint(0, N_SAMPLES, (128,))
            logits = model(task_x[idx])
            loss = F.cross_entropy(logits, task_y[idx])
            # Stop training (but still evaluate) if the loss diverges.
            if torch.isnan(loss) or torch.isinf(loss):
                break
            opt.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

        model.eval()
        with torch.no_grad():
            logits = model(task_x[:1024])
            acc = (logits.argmax(-1) == task_y[:1024]).float().mean().item()
            final_loss = F.cross_entropy(logits, task_y[:1024]).item()

            # Cross-token intervention: randomise token 0, watch token 1.
            h1 = task_x[:64].clone()
            for layer in model.layers:
                h1 = layer(h1)
            h2 = task_x[:64].clone()
            h2[:, 0] = F.normalize(torch.randn(64, D, device=DEVICE), dim=-1)
            for layer in model.layers:
                h2 = layer(h2)
            cross_delta = (h1[:, 1] - h2[:, 1]).norm(dim=-1).mean().item()

        print(f" {arch_name:>10} {acc:>8.1%} {final_loss:>8.4f} {cross_delta:>10.4f} {n_params:>10,}")

        del model
        reset_vram()


# ══════════════════════════════════════════════════════════════════
# TEST 5: THE O(S²) WALL — CANTOR vs ATTENTION at depth 8
# ══════════════════════════════════════════════════════════════════

def _time_deep_stack(layer_cls, S):
    """Time one forward pass of WALL_DEPTH fp16 `layer_cls` layers at length S.

    Returns (ms, mean cosine to input over the first 256 tokens, peak MB),
    or None if a RuntimeError (including OOM) was raised.
    """
    try:
        reset_vram()
        torch.manual_seed(42)
        stack = nn.ModuleList([
            layer_cls(D) for _ in range(WALL_DEPTH)
        ]).to(DEVICE).half()
        x = F.normalize(torch.randn(1, S, D, device=DEVICE, dtype=torch.float16), dim=-1)

        # Warm-up pass (not timed).
        with torch.no_grad():
            h = x.clone()
            for layer in stack:
                h = layer(h)
        _sync()

        t0 = time.perf_counter()
        with torch.no_grad():
            h = x.clone()
            for layer in stack:
                h = layer(h)
        _sync()
        ms = (time.perf_counter() - t0) * 1000
        mb = peak_mb()
        cos = F.cosine_similarity(x[0, :256].float(), h[0, :256].float()).mean().item()
        result = (ms, cos, mb)
        del x, h, stack
        reset_vram()
        return result
    except (torch.cuda.OutOfMemoryError, RuntimeError) as exc:
        _note_failure(exc)
        reset_vram()
        return None


def run_test5_wall():
    """Compare 8-layer Cantor and attention stacks until Cantor itself fails."""
    print(f"\n{'━'*80}")
    print("TEST 5: The O(S²) Wall — Cantor vs Attention, 8 layers deep")
    print(f"{'━'*80}")

    print(f"\n {'S':>8} {'cantor_ms':>10} {'attn_ms':>10} {'speedup':>8} "
          f"{'c_cos':>8} {'a_cos':>8} {'c_MB':>8} {'a_MB':>8}")

    for S in [1024, 4096, 8192, 16384, 32768, 65536, 131072]:
        c_result = _time_deep_stack(ConstellationCantorRelay, S)
        a_result = _time_deep_stack(VanillaAttention, S)

        c_str = f"{c_result[0]:>9.1f}ms" if c_result else " OOM"
        a_str = f"{a_result[0]:>9.1f}ms" if a_result else " OOM"
        sp = f"{a_result[0]/c_result[0]:>7.1f}×" if c_result and a_result else " -"
        cc = f"{c_result[1]:>8.4f}" if c_result else " ---"
        ac = f"{a_result[1]:>8.4f}" if a_result else " ---"
        cm = f"{c_result[2]:>8.0f}" if c_result else " OOM"
        am = f"{a_result[2]:>8.0f}" if a_result else " OOM"

        print(f" {S:>8} {c_str} {a_str} {sp} {cc} {ac} {cm} {am}")

        if c_result is None:
            print(f" → Cantor OOM at S={S}, stopping")
            break


# ══════════════════════════════════════════════════════════════════
# ENTRY POINT
# ══════════════════════════════════════════════════════════════════

def main():
    """Print the banner, run the five benchmarks in order, then the summary."""
    print("=" * 80)
    print("CONSTELLATION-CANTOR RELAY — O(S) CROSS-TOKEN ROUTING BENCHMARK")
    if torch.cuda.is_available():
        print(f" Device: {torch.cuda.get_device_name()}")
        print(f" VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        # The CUDA queries above raise without a GPU.
        print(" Device: cpu")
    print(f" Dimension: {D}")
    print("=" * 80)

    run_test1_throughput()
    run_test2_intervention()
    run_test3_geometry()
    run_test4_trained_task()
    run_test5_wall()

    # ══════════════════════════════════════════════════════════════
    # SUMMARY
    # ══════════════════════════════════════════════════════════════
    print(f"\n{'='*80}")
    print("CONSTELLATION-CANTOR RELAY — BENCHMARK COMPLETE")
    print(f"{'='*80}")
    print(f"""
 Architecture:
 per-token: constellation relay (triangulate → patchwork → gated residual)
 cross-token: cantor router (hierarchical scatter/gather through anchor tree)
 total: O(S) time, O(S) memory, no attention

 5 tests:
 T1: Throughput — relay vs cantor vs hybrid vs attention, S to 131K
 T2: Cross-token causal intervention — who routes strongest?
 T3: Geometric preservation — does cross-token routing hurt geometry?
 T4: Trained cross-token task — accuracy on interaction-dependent labels
 T5: O(S²) wall — cantor vs attention at 8 layers to OOM

 Key questions answered:
 • Is the cantor router faster than attention at all sequence lengths?
 • Does it provide meaningful cross-token interaction?
 • Does the routing hurt per-token geometric preservation?
 • Can it solve tasks that require cross-token information?
""")


if __name__ == "__main__":
    main()