"""
=============================================================================
BENCHMARK v3: RichNeuron (Mult × Periodic + Residual) vs Vanilla MLP
=============================================================================
Parameter budgets matched per task (hidden widths found by binary search).
Single run per task (for speed on CPU). 7 diverse tasks covering regression,
classification, memorization, and frequency fitting.

RichNeuron layer: y = LayerNorm( (W1·x) ⊙ sin(ω·(W2·x + b)) + W1·x )
  - W1 creates linear features (a standard linear map)
  - W2 + sin() creates periodic features
  - ⊙ (element-wise multiply) creates CROSS-TERMS between them
    (sketched in the comment below this docstring)
  - the +W1·x residual keeps a signal path open when the sine gate is near 0
  - LayerNorm stabilizes magnitudes across depth

Run: pip install torch numpy && python benchmark.py
(torchvision is optional; without it the MNIST task falls back to synthetic data)
=============================================================================
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import time
import json

DEVICE = 'cpu'  # CPU-only benchmark; torch tensors default to CPU anyway


def set_seed(s=42):
    torch.manual_seed(s)
    np.random.seed(s)


class RichNeuronLayer(nn.Module):
    """
    y = LayerNorm( (W1·x) ⊙ sin(ω·(W2·x + b)) + W1·x )

    Multiplicative interaction between a linear branch and a periodic branch.
    The residual (+W1·x) keeps a signal path open when the sine gate is near
    zero. LayerNorm stabilizes magnitudes across depth.
    """
    def __init__(self, in_dim, out_dim, omega_0=30.0):
        super().__init__()
        self.W1 = nn.Linear(in_dim, out_dim, bias=False)
        self.W2 = nn.Linear(in_dim, out_dim, bias=True)
        self.omega_0 = omega_0
        self.ln = nn.LayerNorm(out_dim)

        with torch.no_grad():
            nn.init.xavier_uniform_(self.W1.weight)
            # SIREN-style init: weights scaled by 1/omega_0 so that
            # omega_0 * (W2·x + b) starts in a well-behaved range.
            bound = math.sqrt(6.0 / in_dim) / omega_0
            self.W2.weight.uniform_(-bound, bound)
            self.W2.bias.uniform_(-math.pi, math.pi)

    def forward(self, x):
        linear = self.W1(x)
        periodic = torch.sin(self.omega_0 * self.W2(x))  # sin(ω·(W2·x + b))
        return self.ln(linear * periodic + linear)
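
# Parameter accounting for the layer above (quick arithmetic, for intuition):
# with in_dim = out_dim = h, one RichNeuronLayer holds 2h² + h (W1, W2, bias)
# plus 2h (LayerNorm) parameters versus h² + h for a vanilla Linear layer, so
# at a matched budget the Rich hidden width lands near the vanilla width / √2.
#
#   >>> count_params(RichNeuronLayer(64, 64))   # count_params defined below
#   8384                                        # = 2·64² + 64 + 2·64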


class VanillaMLP(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim, n_hidden):
        super().__init__()
        layers = []
        prev = in_dim
        for _ in range(n_hidden):
            layers.extend([nn.Linear(prev, hidden_dim), nn.ReLU()])
            prev = hidden_dim
        layers.append(nn.Linear(prev, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class RichNet(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim, n_hidden, omega_0=30.0):
        super().__init__()
        layers = []
        prev = in_dim
        for _ in range(n_hidden):
            layers.append(RichNeuronLayer(prev, hidden_dim, omega_0))
            prev = hidden_dim
        layers.append(nn.Linear(prev, out_dim))
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


def count_params(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)


def find_hidden(in_d, out_d, n_h, target_p, model_cls, **kw):
    """Binary-search the hidden width whose param count is closest to target_p.

    Valid because the parameter count grows monotonically with hidden width.
    """
    lo, hi = 2, 1024
    best_h, best_diff = 2, float('inf')
    while lo <= hi:
        mid = (lo + hi) // 2
        p = count_params(model_cls(in_d, out_d, mid, n_h, **kw))
        if abs(p - target_p) < best_diff:
            best_h, best_diff = mid, abs(p - target_p)
        if p < target_p:
            lo = mid + 1
        else:
            hi = mid - 1
    return best_h
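
# Illustrative usage (a sketch; the exact width depends on the architectures
# defined above):
#
#   >>> h = find_hidden(4, 1, 3, 8000, VanillaMLP)
#   >>> count_params(VanillaMLP(4, 1, h, 3))   # ≈ 8000 by construction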


def train_regression(model, x_tr, y_tr, x_te, y_te, epochs, lr, bs=256):
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    best = float('inf')
    n = len(x_tr)

    for ep in range(epochs):
        model.train()
        perm = torch.randperm(n)
        for i in range(0, n, bs):
            idx = perm[i:i+bs]
            loss = F.mse_loss(model(x_tr[idx]), y_tr[idx])
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        sch.step()

        # Evaluate roughly 10 times over the run and keep the best test MSE.
        if (ep + 1) % max(1, epochs // 10) == 0:
            model.eval()
            with torch.no_grad():
                best = min(best, F.mse_loss(model(x_te), y_te).item())

    model.eval()
    with torch.no_grad():
        best = min(best, F.mse_loss(model(x_te), y_te).item())
    return best
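
# Note: both training loops report the best test metric seen at any checkpoint
# rather than the final one; this is optimistic, but it is applied identically
# to both models, so the head-to-head comparison stays fair.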


def train_classification(model, x_tr, y_tr, x_te, y_te, epochs, lr, bs=256):
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    best = 0.0
    n = len(x_tr)

    for ep in range(epochs):
        model.train()
        perm = torch.randperm(n)
        for i in range(0, n, bs):
            idx = perm[i:i+bs]
            loss = F.cross_entropy(model(x_tr[idx]), y_tr[idx])
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        sch.step()

        if (ep + 1) % max(1, epochs // 10) == 0:
            model.eval()
            with torch.no_grad():
                acc = (model(x_te).argmax(1) == y_te).float().mean().item()
            best = max(best, acc)

    model.eval()
    with torch.no_grad():
        best = max(best, (model(x_te).argmax(1) == y_te).float().mean().item())
    return best


def data_complex(n=2000):
    """4-D compositional target: exp(sin(x0² + x1²) + sin(x2² + x3²))."""
    x = torch.rand(n, 4) * 2 - 1
    y = torch.exp(torch.sin(x[:, 0]**2 + x[:, 1]**2) +
                  torch.sin(x[:, 2]**2 + x[:, 3]**2))
    return x, y.unsqueeze(1)


def data_nested(n=2000):
    """Nested nonlinearity: sin(π·(x0² + x1²)) · cos(3π·x0·x1)."""
    x = torch.rand(n, 2) * 2 - 1
    y = (torch.sin(math.pi * (x[:, 0]**2 + x[:, 1]**2)) *
         torch.cos(3 * math.pi * x[:, 0] * x[:, 1]))
    return x, y.unsqueeze(1)


def data_spiral(n=1500):
    """Two interleaved spirals with Gaussian jitter (assumes n is even)."""
    t = torch.linspace(0, 4 * np.pi, n // 2)
    r = torch.linspace(0.3, 2, n // 2)
    x1 = torch.stack([r * torch.cos(t), r * torch.sin(t)], 1)
    x2 = torch.stack([r * torch.cos(t + np.pi), r * torch.sin(t + np.pi)], 1)
    x = torch.cat([x1, x2]) + torch.randn(n, 2) * 0.05
    y = torch.cat([torch.zeros(n // 2), torch.ones(n // 2)]).long()
    p = torch.randperm(n)
    return x[p], y[p]


def data_checker(n=2000, freq=3):
    """Checkerboard labels on [-1, 1]² at spatial frequency freq."""
    x = torch.rand(n, 2) * 2 - 1
    y = ((torch.sin(freq * math.pi * x[:, 0]) *
          torch.sin(freq * math.pi * x[:, 1])) > 0).long()
    return x, y


def data_highfreq(n=1500):
    """1-D multi-frequency signal: sin(20x) + sin(50x) + 0.5·sin(100x)."""
    x = torch.linspace(-1, 1, n).unsqueeze(1)
    y = torch.sin(20 * x) + torch.sin(50 * x) + 0.5 * torch.sin(100 * x)
    # Shuffle so the ordered train/test split in main() interpolates within
    # the sampled range instead of extrapolating past it.
    p = torch.randperm(n)
    return x[p], y[p]
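
# Sampling-density check for the task above: the fastest component sin(100x)
# has period 2π/100 ≈ 0.063, so 1500 uniform samples on [-1, 1] give roughly
# 47 samples per period, comfortably above the Nyquist limit. The task tests
# fitting capacity, not aliasing.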


def data_memorize(n=200, kd=8, vd=4):
    """Random key→value pairs; main() trains and evaluates on the same set."""
    return torch.randn(n, kd), torch.randn(n, vd)


def data_mnist_or_synth():
    """MNIST subset if torchvision is available, else 10 Gaussian clusters."""
    try:
        import torchvision
        import torchvision.transforms as T
        tr = torchvision.datasets.MNIST('./data', True, T.ToTensor(), download=True)
        te = torchvision.datasets.MNIST('./data', False, T.ToTensor(), download=True)
        return (tr.data[:3000].float().view(-1, 784) / 255., tr.targets[:3000],
                te.data[:500].float().view(-1, 784) / 255., te.targets[:500],
                "MNIST", 784)
    except Exception:
        # Fallback: 10 Gaussian clusters in 64-D, one center per class.
        d = 64
        centers = torch.randn(10, d)

        def make(n):
            y = torch.randint(0, 10, (n,))
            x = torch.randn(n, d) * 0.5 + centers[y]
            return x, y

        tx, ty = make(2000)
        ex, ey = make(400)
        return tx, ty, ex, ey, "Synth-10class", d


def main():
    print("=" * 78)
    print("  BENCHMARK: RichNeuron vs Vanilla MLP")
    print("  RichNeuron = (W1·x) ⊙ sin(ω·(W2·x+b)) + W1·x  [Mult×Periodic+Skip]")
    print("  Fair comparison: SAME parameter budget for both")
    print("=" * 78)

    N_HIDDEN = 3
    results = {}

    # Each task: (name, type, data fn, in_dim, out_dim, param budget,
    #             epochs, lr, omega_0, train/test split index).
    tasks = [
        ("Complex Compositional Fn",  "regression",     data_complex,   4, 1,  8000, 1500, 1e-3, 30.0, 1500),
        ("Nested Nonlinear Fn",       "regression",     data_nested,    2, 1,  4000, 1500, 1e-3, 20.0, 1500),
        ("Two-Spiral Classification", "classification", data_spiral,    2, 2,  4000, 1000, 1e-3, 15.0, 1000),
        ("Checkerboard Pattern",      "classification", data_checker,   2, 2,  4000, 1000, 1e-3, 20.0, 1500),
        ("High-Frequency Signal",     "regression",     data_highfreq,  1, 1, 10000, 2000, 1e-3, 60.0, 1000),
        ("Knowledge Memorization",    "regression",     data_memorize,  8, 4,  6000, 3000, 1e-3, 10.0,  200),
    ]

    for name, ttype, datafn, ind, outd, budget, epochs, lr, omega, split in tasks:
        print(f"\n{'─'*78}")
        print(f"  {name}")
        print(f"  Type: {ttype} | Params: ~{budget:,} | Epochs: {epochs}")
        print(f"{'─'*78}")

        # Match both architectures to the same parameter budget.
        h_v = find_hidden(ind, outd, N_HIDDEN, budget, VanillaMLP)
        h_r = find_hidden(ind, outd, N_HIDDEN, budget, RichNet, omega_0=omega)

        # These instances are only used for parameter counting; fresh ones
        # are built with a fixed seed for training below.
        set_seed()
        mv = VanillaMLP(ind, outd, h_v, N_HIDDEN)
        mr = RichNet(ind, outd, h_r, N_HIDDEN, omega)
        vp, rp = count_params(mv), count_params(mr)

        print(f"  Vanilla: hidden={h_v:>4}, params={vp:>6,}")
        print(f"  Rich:    hidden={h_r:>4}, params={rp:>6,}")

        set_seed()
        x, y = datafn()
        if split >= len(x):
            # Memorization task: train and evaluate on the same data.
            xtr, ytr, xte, yte = x, y, x, y
        else:
            xtr, ytr = x[:split], y[:split]
            xte, yte = x[split:], y[split:]

        set_seed(123)
        mv = VanillaMLP(ind, outd, h_v, N_HIDDEN)
        t0 = time.time()
        if ttype == 'regression':
            vs = train_regression(mv, xtr, ytr, xte, yte, epochs, lr)
        else:
            vs = train_classification(mv, xtr, ytr, xte, yte, epochs, lr)
        vt = time.time() - t0

        set_seed(123)
        mr = RichNet(ind, outd, h_r, N_HIDDEN, omega)
        t0 = time.time()
        if ttype == 'regression':
            rs = train_regression(mr, xtr, ytr, xte, yte, epochs, lr)
        else:
            rs = train_classification(mr, xtr, ytr, xte, yte, epochs, lr)
        rt = time.time() - t0

        if ttype == 'regression':
            winner = 'rich' if rs < vs else 'vanilla'
            vs_str, rs_str = f"{vs:.6f}", f"{rs:.6f}"
            metric = "MSE ↓"
        else:
            winner = 'rich' if rs > vs else 'vanilla'
            vs_str, rs_str = f"{vs:.1%}", f"{rs:.1%}"
            metric = "Acc ↑"

        w = "🟢 RichNeuron" if winner == 'rich' else "⚪ Vanilla"

        print(f"\n  {metric:<20} Vanilla: {vs_str:>12}   Rich: {rs_str:>12}   → {w}")
        print(f"  {'Time (s)':<20} Vanilla: {vt:>11.1f}s   Rich: {rt:>11.1f}s")

        results[name] = {'v': vs, 'r': rs, 'vp': vp, 'rp': rp,
                         'vt': vt, 'rt': rt, 'winner': winner, 'type': ttype}

    print(f"\n{'─'*78}")
    print(f"  MNIST / Structured Classification")
    print(f"{'─'*78}")

    set_seed()
    txr, tyr, txe, tye, dsn, ind = data_mnist_or_synth()
    budget = 30000
    h_v = find_hidden(ind, 10, N_HIDDEN, budget, VanillaMLP)
    h_r = find_hidden(ind, 10, N_HIDDEN, budget, RichNet, omega_0=10.0)

    set_seed(123)
    mv = VanillaMLP(ind, 10, h_v, N_HIDDEN)
    vp = count_params(mv)
    vs = train_classification(mv, txr, tyr, txe, tye, 500, 1e-3)

    set_seed(123)
    mr = RichNet(ind, 10, h_r, N_HIDDEN, 10.0)
    rp = count_params(mr)
    rs = train_classification(mr, txr, tyr, txe, tye, 500, 1e-3)

    winner = 'rich' if rs > vs else 'vanilla'
    w = "🟢 RichNeuron" if winner == 'rich' else "⚪ Vanilla"
    print(f"  {dsn}: Vanilla({vp:,}p)={vs:.1%}   Rich({rp:,}p)={rs:.1%}   → {w}")
    results[dsn] = {'v': vs, 'r': rs, 'vp': vp, 'rp': rp,
                    'winner': winner, 'type': 'classification'}

    print("\n" + "=" * 78)
    print("  GRAND SUMMARY")
    print("=" * 78)

    rich_w = sum(1 for r in results.values() if r['winner'] == 'rich')
    van_w = sum(1 for r in results.values() if r['winner'] == 'vanilla')

    print(f"\n  {'Task':<35} {'Params':>12} {'Vanilla':>12} {'Rich':>12} {'Winner':>14}")
    print(f"  {'─'*85}")
    for name, r in results.items():
        ps = f"{r['vp']}/{r['rp']}"
        if r['type'] == 'regression':
            vs = f"{r['v']:.6f}"
            rs = f"{r['r']:.6f}"
        else:
            vs = f"{r['v']:.1%}"
            rs = f"{r['r']:.1%}"
        w = "🟢 Rich" if r['winner'] == 'rich' else "⚪ Vanilla"
        print(f"  {name:<35} {ps:>12} {vs:>12} {rs:>12} {w:>14}")

    print(f"\n  {'─'*85}")
    print(f"  🏆 FINAL SCORE: RichNeuron {rich_w} vs Vanilla MLP {van_w}")
    print(f"  {'─'*85}")

    # Cast numpy floats to plain floats so json.dump can serialize them.
    with open('results.json', 'w') as f:
        json.dump({k: {kk: float(vv) if isinstance(vv, (float, np.floating)) else vv
                       for kk, vv in v.items()} for k, v in results.items()},
                  f, indent=2)
    print("\n  Results saved to results.json")


if __name__ == "__main__":
    main()