#!/usr/bin/env python3 """ ============================================================================= BENCHMARK v3: RichNeuron (Mult × Periodic + Residual) vs Vanilla MLP ============================================================================= Strictly matched param budgets. Single run per task (for speed on CPU). 7 diverse tasks covering regression, classification, memorization, frequency. RichNeuron layer: y = LayerNorm( (W1·x) ⊙ sin(ω·W2·x+b) + W1·x ) - W1 creates linear features (like standard) - W2 + sin() creates periodic features - ⊙ (element-wise multiply) creates CROSS-TERMS between them - +W1·x residual prevents scalar collapse - LayerNorm stabilizes across depth Run: pip install torch numpy && python benchmark.py ============================================================================= """ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import math import time import json DEVICE = 'cpu' def set_seed(s=42): torch.manual_seed(s) np.random.seed(s) # ============================================================================ # ARCHITECTURES # ============================================================================ class RichNeuronLayer(nn.Module): """ y = LayerNorm( (W1·x) ⊙ sin(ω · W2·x + b) + W1·x ) Multiplicative interaction between linear and periodic branches. The residual (+W1·x) prevents scalar collapse. LayerNorm stabilizes magnitude across depth. """ def __init__(self, in_dim, out_dim, omega_0=30.0): super().__init__() self.W1 = nn.Linear(in_dim, out_dim, bias=False) self.W2 = nn.Linear(in_dim, out_dim, bias=True) self.omega_0 = omega_0 self.ln = nn.LayerNorm(out_dim) with torch.no_grad(): nn.init.xavier_uniform_(self.W1.weight) bound = math.sqrt(6.0 / in_dim) / omega_0 self.W2.weight.uniform_(-bound, bound) self.W2.bias.uniform_(-math.pi, math.pi) def forward(self, x): linear = self.W1(x) periodic = torch.sin(self.omega_0 * self.W2(x)) return self.ln(linear * periodic + linear) class VanillaMLP(nn.Module): def __init__(self, in_dim, out_dim, hidden_dim, n_hidden): super().__init__() layers = [] prev = in_dim for _ in range(n_hidden): layers.extend([nn.Linear(prev, hidden_dim), nn.ReLU()]) prev = hidden_dim layers.append(nn.Linear(prev, out_dim)) self.net = nn.Sequential(*layers) def forward(self, x): return self.net(x) class RichNet(nn.Module): def __init__(self, in_dim, out_dim, hidden_dim, n_hidden, omega_0=30.0): super().__init__() layers = [] prev = in_dim for _ in range(n_hidden): layers.append(RichNeuronLayer(prev, hidden_dim, omega_0)) prev = hidden_dim layers.append(nn.Linear(prev, out_dim)) self.layers = nn.ModuleList(layers) def forward(self, x): for l in self.layers: x = l(x) return x def count_params(m): return sum(p.numel() for p in m.parameters() if p.requires_grad) def find_hidden(in_d, out_d, n_h, target_p, model_cls, **kw): """Binary search for hidden dim matching target param count.""" lo, hi, best_h = 2, 1024, 2 while lo <= hi: mid = (lo + hi) // 2 m = model_cls(in_d, out_d, mid, n_h, **kw) p = count_params(m) if abs(p - target_p) < abs(count_params(model_cls(in_d, out_d, best_h, n_h, **kw)) - target_p): best_h = mid if p < target_p: lo = mid + 1 else: hi = mid - 1 return best_h # ============================================================================ # TRAINING (mini-batch for speed) # ============================================================================ def train_regression(model, x_tr, y_tr, x_te, y_te, epochs, lr, bs=256): opt = torch.optim.Adam(model.parameters(), lr=lr) sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) best = float('inf') n = len(x_tr) for ep in range(epochs): model.train() perm = torch.randperm(n) for i in range(0, n, bs): idx = perm[i:i+bs] loss = F.mse_loss(model(x_tr[idx]), y_tr[idx]) opt.zero_grad(); loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step() sch.step() if (ep+1) % max(1, epochs//10) == 0: model.eval() with torch.no_grad(): tl = F.mse_loss(model(x_te), y_te).item() best = min(best, tl) model.eval() with torch.no_grad(): best = min(best, F.mse_loss(model(x_te), y_te).item()) return best def train_classification(model, x_tr, y_tr, x_te, y_te, epochs, lr, bs=256): opt = torch.optim.Adam(model.parameters(), lr=lr) sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) best = 0 n = len(x_tr) for ep in range(epochs): model.train() perm = torch.randperm(n) for i in range(0, n, bs): idx = perm[i:i+bs] loss = F.cross_entropy(model(x_tr[idx]), y_tr[idx]) opt.zero_grad(); loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step() sch.step() if (ep+1) % max(1, epochs//10) == 0: model.eval() with torch.no_grad(): acc = (model(x_te).argmax(1) == y_te).float().mean().item() best = max(best, acc) model.eval() with torch.no_grad(): best = max(best, (model(x_te).argmax(1) == y_te).float().mean().item()) return best # ============================================================================ # DATA # ============================================================================ def data_complex(n=2000): x = torch.rand(n, 4)*2-1 y = torch.exp(torch.sin(x[:,0]**2+x[:,1]**2)+torch.sin(x[:,2]**2+x[:,3]**2)) return x, y.unsqueeze(1) def data_nested(n=2000): x = torch.rand(n, 2)*2-1 y = torch.sin(math.pi*(x[:,0]**2+x[:,1]**2))*torch.cos(3*math.pi*x[:,0]*x[:,1]) return x, y.unsqueeze(1) def data_spiral(n=1500): t = torch.linspace(0, 4*np.pi, n//2) r = torch.linspace(0.3, 2, n//2) x1 = torch.stack([r*torch.cos(t), r*torch.sin(t)], 1) x2 = torch.stack([r*torch.cos(t+np.pi), r*torch.sin(t+np.pi)], 1) x = torch.cat([x1,x2]) + torch.randn(n,2)*0.05 y = torch.cat([torch.zeros(n//2), torch.ones(n//2)]).long() p = torch.randperm(n); return x[p], y[p] def data_checker(n=2000, freq=3): x = torch.rand(n,2)*2-1 y = ((torch.sin(freq*math.pi*x[:,0])*torch.sin(freq*math.pi*x[:,1])) > 0).long() return x, y def data_highfreq(n=1500): x = torch.linspace(-1,1,n).unsqueeze(1) y = torch.sin(20*x)+torch.sin(50*x)+0.5*torch.sin(100*x) return x, y def data_memorize(n=200, kd=8, vd=4): return torch.randn(n, kd), torch.randn(n, vd) def data_mnist_or_synth(): try: import torchvision, torchvision.transforms as T tr = torchvision.datasets.MNIST('./data',True,T.ToTensor(),download=True) te = torchvision.datasets.MNIST('./data',False,T.ToTensor(),download=True) return (tr.data[:3000].float().view(-1,784)/255., tr.targets[:3000], te.data[:500].float().view(-1,784)/255., te.targets[:500], "MNIST", 784) except: d = 64; centers = torch.randn(10, d) def make(n): y = torch.randint(0,10,(n,)) x = torch.randn(n, d)*0.5 for i in range(n): x[i] += centers[y[i]] return x, y tx, ty = make(2000); ex, ey = make(400) return tx, ty, ex, ey, "Synth-10class", d # ============================================================================ # MAIN # ============================================================================ def main(): print("="*78) print(" BENCHMARK: RichNeuron vs Vanilla MLP") print(" RichNeuron = (W1·x) ⊙ sin(ω·W2·x+b) + W1·x [Mult×Periodic+Skip]") print(" Fair comparison: SAME parameter budget for both") print("="*78) N_HIDDEN = 3 results = {} tasks = [ ("Complex Compositional Fn", "regression", data_complex, 4, 1, 8000, 1500, 1e-3, 30.0, 1500), ("Nested Nonlinear Fn", "regression", data_nested, 2, 1, 4000, 1500, 1e-3, 20.0, 1500), ("Two-Spiral Classification", "classification", data_spiral, 2, 2, 4000, 1000, 1e-3, 15.0, 1000), ("Checkerboard Pattern", "classification", data_checker, 2, 2, 4000, 1000, 1e-3, 20.0, 1500), ("High-Frequency Signal", "regression", data_highfreq, 1, 1, 10000, 2000, 1e-3, 60.0, 1000), ("Knowledge Memorization", "regression", data_memorize, 8, 4, 6000, 3000, 1e-3, 10.0, 200), ] for name, ttype, datafn, ind, outd, budget, epochs, lr, omega, split in tasks: print(f"\n{'─'*78}") print(f" {name}") print(f" Type: {ttype} | Params: ~{budget:,} | Epochs: {epochs}") print(f"{'─'*78}") h_v = find_hidden(ind, outd, N_HIDDEN, budget, VanillaMLP) h_r = find_hidden(ind, outd, N_HIDDEN, budget, RichNet, omega_0=omega) set_seed() mv = VanillaMLP(ind, outd, h_v, N_HIDDEN) mr = RichNet(ind, outd, h_r, N_HIDDEN, omega) vp, rp = count_params(mv), count_params(mr) print(f" Vanilla: hidden={h_v:>4}, params={vp:>6,}") print(f" Rich: hidden={h_r:>4}, params={rp:>6,}") set_seed() x, y = datafn() if split >= len(x): xtr, ytr, xte, yte = x, y, x, y else: xtr, ytr = x[:split], y[:split] xte, yte = x[split:], y[split:] set_seed(123) mv = VanillaMLP(ind, outd, h_v, N_HIDDEN) t0 = time.time() if ttype == 'regression': vs = train_regression(mv, xtr, ytr, xte, yte, epochs, lr) else: vs = train_classification(mv, xtr, ytr, xte, yte, epochs, lr) vt = time.time() - t0 set_seed(123) mr = RichNet(ind, outd, h_r, N_HIDDEN, omega) t0 = time.time() if ttype == 'regression': rs = train_regression(mr, xtr, ytr, xte, yte, epochs, lr) else: rs = train_classification(mr, xtr, ytr, xte, yte, epochs, lr) rt = time.time() - t0 if ttype == 'regression': winner = 'rich' if rs < vs else 'vanilla' vs_str, rs_str = f"{vs:.6f}", f"{rs:.6f}" metric = "MSE ↓" else: winner = 'rich' if rs > vs else 'vanilla' vs_str, rs_str = f"{vs:.1%}", f"{rs:.1%}" metric = "Acc ↑" w = "🟢 RichNeuron" if winner == 'rich' else "⚪ Vanilla" print(f"\n {metric:<20} Vanilla: {vs_str:>12} Rich: {rs_str:>12} → {w}") print(f" Time (s) Vanilla: {vt:>11.1f}s Rich: {rt:>11.1f}s") results[name] = {'v': vs, 'r': rs, 'vp': vp, 'rp': rp, 'vt': vt, 'rt': rt, 'winner': winner, 'type': ttype} # ----- MNIST ----- print(f"\n{'─'*78}") print(f" MNIST / Structured Classification") print(f"{'─'*78}") set_seed() txr, tyr, txe, tye, dsn, ind = data_mnist_or_synth() budget = 30000 h_v = find_hidden(ind, 10, N_HIDDEN, budget, VanillaMLP) h_r = find_hidden(ind, 10, N_HIDDEN, budget, RichNet, omega_0=10.0) set_seed(123) mv = VanillaMLP(ind, 10, h_v, N_HIDDEN) vp = count_params(mv) vs = train_classification(mv, txr, tyr, txe, tye, 500, 1e-3) set_seed(123) mr = RichNet(ind, 10, h_r, N_HIDDEN, 10.0) rp = count_params(mr) rs = train_classification(mr, txr, tyr, txe, tye, 500, 1e-3) winner = 'rich' if rs > vs else 'vanilla' w = "🟢 RichNeuron" if winner == 'rich' else "⚪ Vanilla" print(f" {dsn}: Vanilla({vp:,}p)={vs:.1%} Rich({rp:,}p)={rs:.1%} → {w}") results[dsn] = {'v': vs, 'r': rs, 'vp': vp, 'rp': rp, 'winner': winner, 'type': 'classification'} # ============================================================ # GRAND SUMMARY # ============================================================ print("\n" + "="*78) print(" GRAND SUMMARY") print("="*78) rich_w = sum(1 for r in results.values() if r['winner'] == 'rich') van_w = sum(1 for r in results.values() if r['winner'] == 'vanilla') print(f"\n {'Task':<35} {'Params':>12} {'Vanilla':>12} {'Rich':>12} {'Winner':>14}") print(f" {'─'*85}") for name, r in results.items(): ps = f"{r['vp']}/{r['rp']}" if r['type'] == 'regression': vs = f"{r['v']:.6f}" rs = f"{r['r']:.6f}" else: vs = f"{r['v']:.1%}" rs = f"{r['r']:.1%}" w = "🟢 Rich" if r['winner'] == 'rich' else "⚪ Vanilla" print(f" {name:<35} {ps:>12} {vs:>12} {rs:>12} {w:>14}") print(f"\n {'─'*85}") print(f" 🏆 FINAL SCORE: RichNeuron {rich_w} vs Vanilla MLP {van_w}") print(f" {'─'*85}") with open('results.json', 'w') as f: json.dump({k: {kk: float(vv) if isinstance(vv, (float, np.floating)) else vv for kk, vv in v.items()} for k, v in results.items()}, f, indent=2) print("\n Results saved to results.json") if __name__ == "__main__": main()