#!/usr/bin/env python3
"""
=============================================================================
BENCHMARK v3: RichNeuron (Mult × Periodic + Residual) vs Vanilla MLP
=============================================================================
Strictly matched param budgets. Single run per task (for speed on CPU).
7 diverse tasks covering regression, classification, memorization, and high-frequency fitting.
RichNeuron layer: y = LayerNorm( (W1·x) ⊙ sin(ω·(W2·x + b)) + W1·x )
- W1 creates linear features (like standard)
- W2 + sin() creates periodic features
- ⊙ (element-wise multiply) creates CROSS-TERMS between them
- +W1·x residual prevents scalar collapse
- LayerNorm stabilizes across depth
Run: pip install torch numpy && python benchmark.py (torchvision optional, enables the MNIST task)
=============================================================================
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import time
import json
DEVICE = 'cpu'  # informational only: everything below runs on the default (CPU) device
def set_seed(s=42):
    torch.manual_seed(s)
    np.random.seed(s)
# ============================================================================
# ARCHITECTURES
# ============================================================================
class RichNeuronLayer(nn.Module):
    """
    y = LayerNorm( (W1·x) ⊙ sin(ω · (W2·x + b)) + W1·x )
    Multiplicative interaction between linear and periodic branches.
    The residual (+W1·x) prevents scalar collapse.
    LayerNorm stabilizes magnitude across depth.
    """
    def __init__(self, in_dim, out_dim, omega_0=30.0):
        super().__init__()
        self.W1 = nn.Linear(in_dim, out_dim, bias=False)
        self.W2 = nn.Linear(in_dim, out_dim, bias=True)
        self.omega_0 = omega_0
        self.ln = nn.LayerNorm(out_dim)
        with torch.no_grad():
            nn.init.xavier_uniform_(self.W1.weight)
            # SIREN-style init: keeps ω·W2·x in a well-behaved range.
            bound = math.sqrt(6.0 / in_dim) / omega_0
            self.W2.weight.uniform_(-bound, bound)
            self.W2.bias.uniform_(-math.pi, math.pi)

    def forward(self, x):
        linear = self.W1(x)
        periodic = torch.sin(self.omega_0 * self.W2(x))  # = sin(ω·(W2·x + b))
        return self.ln(linear * periodic + linear)
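
# Minimal sanity-check sketch (illustrative only; the benchmark never calls
# this, and the helper name below is our own, not part of the original code).
def _demo_richneuron_layer():
    layer = RichNeuronLayer(in_dim=4, out_dim=16)
    out = layer(torch.randn(8, 4))
    assert out.shape == (8, 16)
    # Param count: W1 (4*16, no bias) + W2 (4*16 + 16) + LayerNorm (2*16).
    assert sum(p.numel() for p in layer.parameters()) == 2 * 4 * 16 + 3 * 16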
class VanillaMLP(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim, n_hidden):
        super().__init__()
        layers = []
        prev = in_dim
        for _ in range(n_hidden):
            layers.extend([nn.Linear(prev, hidden_dim), nn.ReLU()])
            prev = hidden_dim
        layers.append(nn.Linear(prev, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class RichNet(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim, n_hidden, omega_0=30.0):
        super().__init__()
        layers = []
        prev = in_dim
        for _ in range(n_hidden):
            layers.append(RichNeuronLayer(prev, hidden_dim, omega_0))
            prev = hidden_dim
        layers.append(nn.Linear(prev, out_dim))
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
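
# End-to-end forward sketch (illustrative only, never invoked): a 3-block
# RichNet regressor maps (batch, 2) -> (batch, 1).
def _demo_richnet_forward():
    net = RichNet(in_dim=2, out_dim=1, hidden_dim=32, n_hidden=3)
    pred = net(torch.rand(16, 2) * 2 - 1)
    assert pred.shape == (16, 1)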
def count_params(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)
def find_hidden(in_d, out_d, n_h, target_p, model_cls, **kw):
    """Binary search for the hidden dim whose param count is closest to target_p."""
    lo, hi = 2, 1024
    best_h = 2
    best_p = count_params(model_cls(in_d, out_d, best_h, n_h, **kw))
    while lo <= hi:
        mid = (lo + hi) // 2
        p = count_params(model_cls(in_d, out_d, mid, n_h, **kw))
        if abs(p - target_p) < abs(best_p - target_p):
            best_h, best_p = mid, p
        if p < target_p:
            lo = mid + 1
        else:
            hi = mid - 1
    return best_h
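
# Usage sketch for the budget matching (illustrative; exact widths depend on
# the budget, and this helper name is ours). A RichNeuronLayer holds roughly
# twice the weights of a same-width Linear+ReLU block, so the search assigns
# RichNet a smaller hidden dim for the same parameter budget.
def _demo_param_matching(budget=8000):
    h_v = find_hidden(4, 1, 3, budget, VanillaMLP)
    h_r = find_hidden(4, 1, 3, budget, RichNet, omega_0=30.0)
    print("vanilla:", h_v, count_params(VanillaMLP(4, 1, h_v, 3)))
    print("rich:   ", h_r, count_params(RichNet(4, 1, h_r, 3)))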
# ============================================================================
# TRAINING (mini-batch for speed)
# ============================================================================
def train_regression(model, x_tr, y_tr, x_te, y_te, epochs, lr, bs=256):
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    best = float('inf')
    n = len(x_tr)
    for ep in range(epochs):
        model.train()
        perm = torch.randperm(n)
        for i in range(0, n, bs):
            idx = perm[i:i+bs]
            loss = F.mse_loss(model(x_tr[idx]), y_tr[idx])
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        sch.step()
        if (ep + 1) % max(1, epochs // 10) == 0:
            model.eval()
            with torch.no_grad():
                tl = F.mse_loss(model(x_te), y_te).item()
            best = min(best, tl)
    model.eval()
    with torch.no_grad():
        best = min(best, F.mse_loss(model(x_te), y_te).item())
    return best
def train_classification(model, x_tr, y_tr, x_te, y_te, epochs, lr, bs=256):
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    best = 0.0
    n = len(x_tr)
    for ep in range(epochs):
        model.train()
        perm = torch.randperm(n)
        for i in range(0, n, bs):
            idx = perm[i:i+bs]
            loss = F.cross_entropy(model(x_tr[idx]), y_tr[idx])
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        sch.step()
        if (ep + 1) % max(1, epochs // 10) == 0:
            model.eval()
            with torch.no_grad():
                acc = (model(x_te).argmax(1) == y_te).float().mean().item()
            best = max(best, acc)
    model.eval()
    with torch.no_grad():
        best = max(best, (model(x_te).argmax(1) == y_te).float().mean().item())
    return best
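
# Note on the reported metric: both trainers return the best test score seen
# at any checkpoint (oracle early stopping). That is optimistic in absolute
# terms, but it is applied identically to both architectures, so the
# head-to-head comparison stays fair.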
# ============================================================================
# DATA
# ============================================================================
def data_complex(n=2000):
    x = torch.rand(n, 4) * 2 - 1
    y = torch.exp(torch.sin(x[:, 0]**2 + x[:, 1]**2) + torch.sin(x[:, 2]**2 + x[:, 3]**2))
    return x, y.unsqueeze(1)

def data_nested(n=2000):
    x = torch.rand(n, 2) * 2 - 1
    y = torch.sin(math.pi * (x[:, 0]**2 + x[:, 1]**2)) * torch.cos(3 * math.pi * x[:, 0] * x[:, 1])
    return x, y.unsqueeze(1)
def data_spiral(n=1500):
    t = torch.linspace(0, 4 * np.pi, n // 2)
    r = torch.linspace(0.3, 2, n // 2)
    x1 = torch.stack([r * torch.cos(t), r * torch.sin(t)], 1)
    x2 = torch.stack([r * torch.cos(t + np.pi), r * torch.sin(t + np.pi)], 1)
    x = torch.cat([x1, x2]) + torch.randn(n, 2) * 0.05
    y = torch.cat([torch.zeros(n // 2), torch.ones(n // 2)]).long()
    p = torch.randperm(n)
    return x[p], y[p]

def data_checker(n=2000, freq=3):
    x = torch.rand(n, 2) * 2 - 1
    y = ((torch.sin(freq * math.pi * x[:, 0]) * torch.sin(freq * math.pi * x[:, 1])) > 0).long()
    return x, y
def data_highfreq(n=1500):
    # x is sorted, so main()'s sequential split trains on roughly [-1, 0.33]
    # and tests on the rest: this task also measures extrapolation.
    x = torch.linspace(-1, 1, n).unsqueeze(1)
    y = torch.sin(20 * x) + torch.sin(50 * x) + 0.5 * torch.sin(100 * x)
    return x, y

def data_memorize(n=200, kd=8, vd=4):
    # Random key->value pairs; main() uses train == test here, so the task
    # measures pure memorization capacity.
    return torch.randn(n, kd), torch.randn(n, vd)
def data_mnist_or_synth():
    try:
        import torchvision
        import torchvision.transforms as T
        tr = torchvision.datasets.MNIST('./data', True, T.ToTensor(), download=True)
        te = torchvision.datasets.MNIST('./data', False, T.ToTensor(), download=True)
        return (tr.data[:3000].float().view(-1, 784) / 255., tr.targets[:3000],
                te.data[:500].float().view(-1, 784) / 255., te.targets[:500], "MNIST", 784)
    except Exception:
        # torchvision missing or download failed: fall back to a synthetic
        # 10-class Gaussian-cluster dataset.
        d = 64
        centers = torch.randn(10, d)
        def make(n):
            y = torch.randint(0, 10, (n,))
            x = torch.randn(n, d) * 0.5
            for i in range(n):
                x[i] += centers[y[i]]
            return x, y
        tx, ty = make(2000)
        ex, ey = make(400)
        return tx, ty, ex, ey, "Synth-10class", d
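
# Shape sketch for the synthetic generators (illustrative, never invoked):
# regression targets are (n, d_out) floats, classification targets are (n,)
# int64 labels.
def _demo_data_shapes():
    x, y = data_complex()
    assert x.shape == (2000, 4) and y.shape == (2000, 1)
    x, y = data_spiral()
    assert x.shape == (1500, 2) and y.dtype == torch.long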
# ============================================================================
# MAIN
# ============================================================================
def main():
    print("=" * 78)
    print(" BENCHMARK: RichNeuron vs Vanilla MLP")
    print(" RichNeuron = (W1·x) ⊙ sin(ω·(W2·x+b)) + W1·x [Mult×Periodic+Skip]")
    print(" Fair comparison: SAME parameter budget for both")
    print("=" * 78)
    N_HIDDEN = 3
    results = {}
    # (name, task type, data fn, in_dim, out_dim, param budget, epochs, lr, omega_0, train split)
    tasks = [
        ("Complex Compositional Fn", "regression", data_complex, 4, 1, 8000, 1500, 1e-3, 30.0, 1500),
        ("Nested Nonlinear Fn", "regression", data_nested, 2, 1, 4000, 1500, 1e-3, 20.0, 1500),
        ("Two-Spiral Classification", "classification", data_spiral, 2, 2, 4000, 1000, 1e-3, 15.0, 1000),
        ("Checkerboard Pattern", "classification", data_checker, 2, 2, 4000, 1000, 1e-3, 20.0, 1500),
        ("High-Frequency Signal", "regression", data_highfreq, 1, 1, 10000, 2000, 1e-3, 60.0, 1000),
        ("Knowledge Memorization", "regression", data_memorize, 8, 4, 6000, 3000, 1e-3, 10.0, 200),
    ]
    for name, ttype, datafn, ind, outd, budget, epochs, lr, omega, split in tasks:
        print(f"\n{'─'*78}")
        print(f" {name}")
        print(f" Type: {ttype} | Params: ~{budget:,} | Epochs: {epochs}")
        print(f"{'─'*78}")
        h_v = find_hidden(ind, outd, N_HIDDEN, budget, VanillaMLP)
        h_r = find_hidden(ind, outd, N_HIDDEN, budget, RichNet, omega_0=omega)
        set_seed()
        mv = VanillaMLP(ind, outd, h_v, N_HIDDEN)
        mr = RichNet(ind, outd, h_r, N_HIDDEN, omega)
        vp, rp = count_params(mv), count_params(mr)
        print(f" Vanilla: hidden={h_v:>4}, params={vp:>6,}")
        print(f" Rich:    hidden={h_r:>4}, params={rp:>6,}")
        set_seed()
        x, y = datafn()
        if split >= len(x):
            xtr, ytr, xte, yte = x, y, x, y
        else:
            xtr, ytr = x[:split], y[:split]
            xte, yte = x[split:], y[split:]
        set_seed(123)
        mv = VanillaMLP(ind, outd, h_v, N_HIDDEN)
        t0 = time.time()
        if ttype == 'regression':
            vs = train_regression(mv, xtr, ytr, xte, yte, epochs, lr)
        else:
            vs = train_classification(mv, xtr, ytr, xte, yte, epochs, lr)
        vt = time.time() - t0
        set_seed(123)
        mr = RichNet(ind, outd, h_r, N_HIDDEN, omega)
        t0 = time.time()
        if ttype == 'regression':
            rs = train_regression(mr, xtr, ytr, xte, yte, epochs, lr)
        else:
            rs = train_classification(mr, xtr, ytr, xte, yte, epochs, lr)
        rt = time.time() - t0
        if ttype == 'regression':
            winner = 'rich' if rs < vs else 'vanilla'
            vs_str, rs_str = f"{vs:.6f}", f"{rs:.6f}"
            metric = "MSE ↓"
        else:
            winner = 'rich' if rs > vs else 'vanilla'
            vs_str, rs_str = f"{vs:.1%}", f"{rs:.1%}"
            metric = "Acc ↑"
        w = "🟢 RichNeuron" if winner == 'rich' else "⚪ Vanilla"
        print(f"\n {metric:<20} Vanilla: {vs_str:>12} Rich: {rs_str:>12}  {w}")
        print(f" Time (s)             Vanilla: {vt:>11.1f}s Rich: {rt:>11.1f}s")
        results[name] = {'v': vs, 'r': rs, 'vp': vp, 'rp': rp,
                         'vt': vt, 'rt': rt, 'winner': winner, 'type': ttype}
    # ----- MNIST -----
    print(f"\n{'─'*78}")
    print(" MNIST / Structured Classification")
    print(f"{'─'*78}")
    set_seed()
    txr, tyr, txe, tye, dsn, ind = data_mnist_or_synth()
    budget = 30000
    h_v = find_hidden(ind, 10, N_HIDDEN, budget, VanillaMLP)
    h_r = find_hidden(ind, 10, N_HIDDEN, budget, RichNet, omega_0=10.0)
    set_seed(123)
    mv = VanillaMLP(ind, 10, h_v, N_HIDDEN)
    vp = count_params(mv)
    vs = train_classification(mv, txr, tyr, txe, tye, 500, 1e-3)
    set_seed(123)
    mr = RichNet(ind, 10, h_r, N_HIDDEN, 10.0)
    rp = count_params(mr)
    rs = train_classification(mr, txr, tyr, txe, tye, 500, 1e-3)
    winner = 'rich' if rs > vs else 'vanilla'
    w = "🟢 RichNeuron" if winner == 'rich' else "⚪ Vanilla"
    print(f" {dsn}: Vanilla({vp:,}p)={vs:.1%} Rich({rp:,}p)={rs:.1%}  {w}")
    results[dsn] = {'v': vs, 'r': rs, 'vp': vp, 'rp': rp, 'winner': winner, 'type': 'classification'}
    # ============================================================
    # GRAND SUMMARY
    # ============================================================
    print("\n" + "=" * 78)
    print(" GRAND SUMMARY")
    print("=" * 78)
    rich_w = sum(1 for r in results.values() if r['winner'] == 'rich')
    van_w = sum(1 for r in results.values() if r['winner'] == 'vanilla')
    print(f"\n {'Task':<35} {'Params':>12} {'Vanilla':>12} {'Rich':>12} {'Winner':>14}")
    print(f" {'─'*85}")
    for name, r in results.items():
        ps = f"{r['vp']}/{r['rp']}"
        if r['type'] == 'regression':
            vs = f"{r['v']:.6f}"
            rs = f"{r['r']:.6f}"
        else:
            vs = f"{r['v']:.1%}"
            rs = f"{r['r']:.1%}"
        w = "🟢 Rich" if r['winner'] == 'rich' else "⚪ Vanilla"
        print(f" {name:<35} {ps:>12} {vs:>12} {rs:>12} {w:>14}")
    print(f"\n {'─'*85}")
    print(f" 🏆 FINAL SCORE: RichNeuron {rich_w} vs Vanilla MLP {van_w}")
    print(f" {'─'*85}")
    # Serialize to JSON, casting any numpy floats to plain Python floats.
    with open('results.json', 'w') as f:
        json.dump({k: {kk: float(vv) if isinstance(vv, (float, np.floating)) else vv
                       for kk, vv in v.items()} for k, v in results.items()}, f, indent=2)
    print("\n Results saved to results.json")
if __name__ == "__main__":
    main()