#!/usr/bin/env python3
"""
=============================================================================
BENCHMARK v8: ADAPTIVE PHASE + AMPLITUDE MODULATION
=============================================================================
v7 FAILED because: ω collapsed to a constant. Neural nets refuse to learn
frequency when reshaping the other weights is the easier descent direction.

v8 FIX (from GPT's critique):
Don't learn frequency. Learn PHASE and AMPLITUDE instead.

    val = W_val · x
    per = sin(ω_fixed · W_per · x + φ(x))      # learned phase, fixed frequency
    α   = sigmoid(W_gate · x)                  # learned amplitude gate
    y   = LN( val ⊙ (α ⊙ per + (1-α)) + res )  # smooth interpolation

Why this works:
- Phase gradient:     ∂/∂φ sin(ωx + φ) = cos(ωx + φ) — stable, bounded
- Frequency gradient: ∂/∂ω sin(ωx)     = x·cos(ωx)   — oscillatory, unstable
- Gate gradient:      ∂y/∂α = val ⊙ (per − 1)        — clean signal
+ Entropy regularization: loss += λ·mean(α(1-α)) pushes the gate away from 0.5
=============================================================================
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import json

SEEDS = [0, 1, 2]

def set_seed(s):
    torch.manual_seed(s)
    np.random.seed(s)
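
# Illustrative sketch, not part of the benchmark: `_demo_grad_scales` is a
# hypothetical helper that makes the docstring's gradient argument concrete.
# Here x stands in for the sin pre-activation (W_per·x in the layer), which
# is not confined to [-1, 1].
def _demo_grad_scales(omega=30.0):
    x = torch.linspace(-10, 10, 2001)
    g_phi = torch.cos(omega * x)        # d/dφ sin(ωx + φ) at φ=0: |·| ≤ 1 always
    g_omega = x * torch.cos(omega * x)  # d/dω sin(ωx): grows with |x|, oscillates
    return g_phi.abs().max().item(), g_omega.abs().max().item()  # ≈ 1 vs ≈ 10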

# ============================================================================
# BASELINES (same as before)
# ============================================================================
class VanillaMLP(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim, n_hidden):
        super().__init__()
        layers = []
        prev = in_dim
        for _ in range(n_hidden):
            layers.extend([nn.Linear(prev, hidden_dim), nn.ReLU()])
            prev = hidden_dim
        layers.append(nn.Linear(prev, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class SinGLULayer(nn.Module):
    def __init__(self, in_dim, out_dim, mid_dim, omega_0=30.0):
        super().__init__()
        self.Wg = nn.Linear(in_dim, mid_dim, bias=False)
        self.Wv = nn.Linear(in_dim, mid_dim, bias=False)
        self.Wo = nn.Linear(mid_dim, out_dim, bias=True)
        self.omega_0 = omega_0
        self.ln = nn.LayerNorm(out_dim)
        with torch.no_grad():
            # SIREN-style bound keeps omega_0 * Wg(x) at O(1) scale at init.
            self.Wg.weight.uniform_(-math.sqrt(6 / in_dim) / omega_0, math.sqrt(6 / in_dim) / omega_0)
            nn.init.xavier_uniform_(self.Wv.weight)
            nn.init.xavier_uniform_(self.Wo.weight)

    def forward(self, x):
        return self.ln(self.Wo(torch.sin(self.omega_0 * self.Wg(x)) * self.Wv(x)))
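
# Illustration only (`_demo_singlu_preact_scale` is a hypothetical helper,
# never called by the benchmark): with the SIREN-style bound on Wg, the sin
# pre-activation ω0·Wg(x) stays at O(1) scale for inputs in [-1, 1] rather
# than O(ω0), so the layer starts in a tame frequency regime.
def _demo_singlu_preact_scale(in_dim=4, mid_dim=16, omega_0=30.0):
    layer = SinGLULayer(in_dim, out_dim=8, mid_dim=mid_dim, omega_0=omega_0)
    x = torch.rand(1024, in_dim) * 2 - 1
    with torch.no_grad():
        preact = omega_0 * layer.Wg(x)
    return preact.std().item()  # roughly 0.8 here, not ~30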

class SinGLUNet(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim, n_hidden, omega_0=30.0):
        super().__init__()
        mid = max(2, int(hidden_dim * 2 / 3))
        layers = []
        prev = in_dim
        for _ in range(n_hidden):
            layers.append(SinGLULayer(prev, hidden_dim, mid, omega_0))
            prev = hidden_dim
        layers.append(nn.Linear(prev, out_dim))
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for l in self.layers:
            x = l(x)
        return x

class HybridLayer(nn.Module):
    def __init__(self, in_dim, out_dim, mid_dim, omega_0=30.0):
        super().__init__()
        self.W1 = nn.Linear(in_dim, mid_dim, bias=False)
        self.W2 = nn.Linear(in_dim, mid_dim, bias=False)
        self.phase = nn.Parameter(torch.empty(mid_dim))
        self.W3 = nn.Linear(mid_dim, out_dim, bias=True)
        self.omega_0 = omega_0
        self.ln = nn.LayerNorm(out_dim)
        self.res = nn.Linear(in_dim, out_dim, bias=False) if in_dim != out_dim else nn.Identity()
        with torch.no_grad():
            nn.init.xavier_uniform_(self.W1.weight)
            self.W2.weight.uniform_(-math.sqrt(6 / in_dim) / omega_0, math.sqrt(6 / in_dim) / omega_0)
            self.phase.uniform_(-math.pi, math.pi)
            nn.init.xavier_uniform_(self.W3.weight)

    def forward(self, x):
        return self.ln(self.W3(self.W1(x) * torch.sin(self.omega_0 * self.W2(x) + self.phase)) + self.res(x))

class HybridNet(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim, n_hidden, omega_0=30.0):
        super().__init__()
        mid = max(2, int(hidden_dim * 0.55))
        layers = []
        prev = in_dim
        for _ in range(n_hidden):
            layers.append(HybridLayer(prev, hidden_dim, mid, omega_0))
            prev = hidden_dim
        layers.append(nn.Linear(prev, out_dim))
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for l in self.layers:
            x = l(x)
        return x

# ============================================================================
# v8: ADAPTIVE PHASE + AMPLITUDE GATE
# ============================================================================
class AdaptivePhaseLayer(nn.Module):
    """
    val = W_val · x
    per = sin(ω · W_per · x + φ(x))   ← learned phase (NOT frequency)
    α   = sigmoid(W_gate · x)         ← amplitude gate
    y   = LN( val ⊙ (α ⊙ per + (1-α)) + residual )

    Phase is easy to optimize (gradient = cos, bounded).
    Gate polarizes with entropy regularization.
    Explicit linear fallback when α → 0.
    """
    def __init__(self, in_dim, out_dim, omega_0=30.0, rank=None):
        super().__init__()
        r = rank or max(2, min(in_dim // 4, 8))
        self.W_val = nn.Linear(in_dim, out_dim, bias=True)
        self.W_per = nn.Linear(in_dim, out_dim, bias=False)
        # Phase predictor: low-rank, bounded by tanh
        self.phi_down = nn.Linear(in_dim, r, bias=False)
        self.phi_up = nn.Linear(r, out_dim, bias=True)
        # Amplitude gate: low-rank
        self.gate_down = nn.Linear(in_dim, r, bias=False)
        self.gate_up = nn.Linear(r, out_dim, bias=True)
        self.omega_0 = omega_0
        self.ln = nn.LayerNorm(out_dim)
        self.res = nn.Linear(in_dim, out_dim, bias=False) if in_dim != out_dim else nn.Identity()
        with torch.no_grad():
            nn.init.xavier_uniform_(self.W_val.weight)
            bound = math.sqrt(6.0 / in_dim) / omega_0
            self.W_per.weight.uniform_(-bound, bound)
            # Phase: start at 0 (no shift initially)
            nn.init.xavier_uniform_(self.phi_down.weight)
            nn.init.zeros_(self.phi_up.weight)
            nn.init.zeros_(self.phi_up.bias)
            # Gate: start at 0 → sigmoid(0) = 0.5 (balanced)
            nn.init.xavier_uniform_(self.gate_down.weight)
            nn.init.zeros_(self.gate_up.weight)
            nn.init.zeros_(self.gate_up.bias)

    def forward(self, x):
        val = self.W_val(x)
        per_in = self.W_per(x)
        # Input-dependent phase shift (bounded by tanh to [-π, π])
        phi = math.pi * torch.tanh(self.phi_up(self.phi_down(x)))
        per = torch.sin(self.omega_0 * per_in + phi)
        # Amplitude gate (how much periodic vs linear)
        alpha = torch.sigmoid(self.gate_up(self.gate_down(x)))
        # Smooth interpolation: α=1 → full periodic, α=0 → just val
        mixed = val * (alpha * per + (1 - alpha))
        return self.ln(mixed + self.res(x))

    def get_diagnostics(self, x):
        with torch.no_grad():
            phi = math.pi * torch.tanh(self.phi_up(self.phi_down(x)))
            alpha = torch.sigmoid(self.gate_up(self.gate_down(x)))
        return alpha, phi
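
# Minimal usage sketch (`_demo_adaptive_phase_layer` is a hypothetical helper,
# never called by the benchmark): one layer on a random batch, checking that
# the diagnostics come back per-example and per-feature.
def _demo_adaptive_phase_layer():
    layer = AdaptivePhaseLayer(in_dim=4, out_dim=16, omega_0=30.0)
    x = torch.randn(32, 4)
    y = layer(x)                           # (32, 16), LayerNorm'd output
    alpha, phi = layer.get_diagnostics(x)  # gate in [0,1], phase in [-π, π]
    assert y.shape == alpha.shape == phi.shape == (32, 16)
    return y, alpha, phi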

class AdaptivePhaseNet(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim, n_hidden, omega_0=30.0):
        super().__init__()
        layers = []
        prev = in_dim
        for _ in range(n_hidden):
            layers.append(AdaptivePhaseLayer(prev, hidden_dim, omega_0))
            prev = hidden_dim
        layers.append(nn.Linear(prev, out_dim))
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for l in self.layers:
            x = l(x)
        return x

    def get_all_diagnostics(self, x):
        alphas, phis = [], []
        h = x
        for l in self.layers:
            if isinstance(l, AdaptivePhaseLayer):
                a, p = l.get_diagnostics(h)
                alphas.append(a)
                phis.append(p)
            h = l(h)
        return alphas, phis

    def entropy_reg(self, x):
        """Push α away from 0.5 — encourage polarization"""
        total = 0
        h = x
        for l in self.layers:
            if isinstance(l, AdaptivePhaseLayer):
                alpha = torch.sigmoid(l.gate_up(l.gate_down(h)))
                total = total + (alpha * (1 - alpha)).mean()
            h = l(h)
        return total
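
# Illustration only (`_demo_entropy_penalty` is a hypothetical helper): the
# penalty α(1-α) is maximal at α = 0.5 and vanishes at α ∈ {0, 1}, so
# minimizing it pushes every gate toward a hard linear-vs-periodic decision.
def _demo_entropy_penalty():
    alpha = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])
    return alpha * (1 - alpha)  # tensor([0.0000, 0.1875, 0.2500, 0.1875, 0.0000])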

# ============================================================================
# UTILS
# ============================================================================
def count_params(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

def find_hidden(in_d, out_d, n_h, target_p, model_cls, **kw):
    """Binary-search the hidden width whose param count is closest to target_p."""
    lo, hi, best_h = 2, 512, 2
    while lo <= hi:
        mid = (lo + hi) // 2
        p = count_params(model_cls(in_d, out_d, mid, n_h, **kw))
        if abs(p - target_p) < abs(count_params(model_cls(in_d, out_d, best_h, n_h, **kw)) - target_p):
            best_h = mid
        if p < target_p:
            lo = mid + 1
        else:
            hi = mid - 1
    return best_h
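
# Usage sketch (illustration only; nothing here is called by the benchmark):
# every architecture is sized to the same parameter budget so comparisons
# measure the layer design, not capacity.
def _demo_budget_matching(target=5000):
    h = find_hidden(4, 1, 3, target, VanillaMLP)
    return h, count_params(VanillaMLP(4, 1, h, 3))  # param count lands near `target`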

def train_reg(model, xtr, ytr, xte, yte, epochs, lr, entropy_lambda=1e-4, bs=256):
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    best = float('inf')
    use_entropy = isinstance(model, AdaptivePhaseNet) and entropy_lambda > 0
    n = len(xtr)
    for ep in range(epochs):
        model.train()
        perm = torch.randperm(n)
        for i in range(0, n, bs):
            idx = perm[i:i + bs]
            bx, by = xtr[idx], ytr[idx]
            loss = F.mse_loss(model(bx), by)
            if use_entropy:
                loss = loss + entropy_lambda * model.entropy_reg(bx)
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        sch.step()
        if (ep + 1) % max(1, epochs // 10) == 0:
            model.eval()
            with torch.no_grad():
                best = min(best, F.mse_loss(model(xte), yte).item())
    model.eval()
    with torch.no_grad():
        best = min(best, F.mse_loss(model(xte), yte).item())
    return best

def train_clf(model, xtr, ytr, xte, yte, epochs, lr, entropy_lambda=1e-4, bs=256):
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    best = 0
    use_entropy = isinstance(model, AdaptivePhaseNet) and entropy_lambda > 0
    n = len(xtr)
    for ep in range(epochs):
        model.train()
        perm = torch.randperm(n)
        for i in range(0, n, bs):
            idx = perm[i:i + bs]
            bx, by = xtr[idx], ytr[idx]
            loss = F.cross_entropy(model(bx), by)
            if use_entropy:
                loss = loss + entropy_lambda * model.entropy_reg(bx)
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        sch.step()
        if (ep + 1) % max(1, epochs // 10) == 0:
            model.eval()
            with torch.no_grad():
                best = max(best, (model(xte).argmax(1) == yte).float().mean().item())
    model.eval()
    with torch.no_grad():
        best = max(best, (model(xte).argmax(1) == yte).float().mean().item())
    return best

# ============================================================================
# DATA
# ============================================================================
def data_complex(n=1000):
    x = torch.rand(n, 4) * 2 - 1
    y = torch.exp(torch.sin(x[:, 0]**2 + x[:, 1]**2) + torch.sin(x[:, 2]**2 + x[:, 3]**2))
    return x, y.unsqueeze(1)

def data_nested(n=1000):
    x = torch.rand(n, 2) * 2 - 1
    y = torch.sin(math.pi * (x[:, 0]**2 + x[:, 1]**2)) * torch.cos(3 * math.pi * x[:, 0] * x[:, 1])
    return x, y.unsqueeze(1)

def data_spiral(n=1000):
    t = torch.linspace(0, 4 * np.pi, n // 2)
    r = torch.linspace(0.3, 2, n // 2)
    x1 = torch.stack([r * torch.cos(t), r * torch.sin(t)], 1)
    x2 = torch.stack([r * torch.cos(t + np.pi), r * torch.sin(t + np.pi)], 1)
    x = torch.cat([x1, x2]) + torch.randn(n, 2) * 0.05
    y = torch.cat([torch.zeros(n // 2), torch.ones(n // 2)]).long()
    p = torch.randperm(n)
    return x[p], y[p]

def data_checker(n=1000):
    x = torch.rand(n, 2) * 2 - 1
    y = ((torch.sin(3 * math.pi * x[:, 0]) * torch.sin(3 * math.pi * x[:, 1])) > 0).long()
    return x, y

def data_highfreq(n=1000):
    x = torch.linspace(-1, 1, n).unsqueeze(1)
    return x, torch.sin(20 * x) + torch.sin(50 * x) + 0.5 * torch.sin(100 * x)

def data_memorize(n=200):
    return torch.randn(n, 8), torch.randn(n, 4)

def data_ood_train(n=800):
    x = torch.rand(n, 2) * 2 - 1
    y = torch.sin(3 * math.pi * x[:, 0]) * torch.cos(3 * math.pi * x[:, 1]) + x[:, 0] * x[:, 1]
    return x, y.unsqueeze(1)

def data_ood_test(n=300):
    x = torch.rand(n, 2) + 1
    y = torch.sin(3 * math.pi * x[:, 0]) * torch.cos(3 * math.pi * x[:, 1]) + x[:, 0] * x[:, 1]
    return x, y.unsqueeze(1)
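
# Illustration only (`_demo_ood_domains` is a hypothetical helper): the OOD
# pair evaluates the same target function on disjoint domains, train/ID on
# [-1,1]^2 and OOD test on [1,2]^2, so degradation measures extrapolation
# rather than a change of task.
def _demo_ood_domains():
    xtr, _ = data_ood_train()
    xood, _ = data_ood_test()
    assert xtr.min() >= -1.0 and xtr.max() <= 1.0
    assert xood.min() >= 1.0 and xood.max() <= 2.0
    return xtr.shape, xood.shape  # torch.Size([800, 2]), torch.Size([300, 2])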

# ============================================================================
# MAIN
# ============================================================================
def main():
    print("=" * 80)
    print(" BENCHMARK v8: ADAPTIVE PHASE + AMPLITUDE GATE")
    print(" Learn PHASE φ(x) and GATE α(x), NOT frequency ω")
    print(" + entropy regularization to prevent α collapse at 0.5")
    print("=" * 80)
    N_H = 3
    models = {
        'Vanilla':  (VanillaMLP, {}),
        'SinGLU':   (SinGLUNet, {'omega_0': None}),
        'Hybrid':   (HybridNet, {'omega_0': None}),
        'v8:Phase': (AdaptivePhaseNet, {'omega_0': None}),
    }
    # (name, type, data_fn, in_dim, out_dim, param budget, epochs, lr, omega_0, train/test split)
    tasks = [
        ("Complex Fn (4D)", "reg", data_complex,  4, 1, 5000, 300, 1e-3, 30.0, 750),
        ("Nested Fn (2D)",  "reg", data_nested,   2, 1, 3000, 300, 1e-3, 20.0, 750),
        ("Spiral",          "clf", data_spiral,   2, 2, 3000, 250, 1e-3, 15.0, 700),
        ("Checkerboard",    "clf", data_checker,  2, 2, 3000, 250, 1e-3, 20.0, 700),
        ("High-Freq",       "reg", data_highfreq, 1, 1, 8000, 300, 1e-3, 60.0, 700),
        ("Memorization",    "reg", data_memorize, 8, 4, 5000, 400, 1e-3, 10.0, 200),
    ]
    all_results = {}
    diag_data = {}
    for tname, ttype, dfn, ind, outd, budget, epochs, lr, omega, split in tasks:
        print(f"\n{'━' * 80}")
        print(f" {tname} | budget ~{budget:,}")
        print(f"{'━' * 80}")
        hdims = {}
        for mn, (mc, mk) in models.items():
            kw = {k: (omega if v is None else v) for k, v in mk.items()}
            hdims[mn] = find_hidden(ind, outd, N_H, budget, mc, **kw)
        task_res = {}
        for mn, (mc, mk) in models.items():
            kw = {k: (omega if v is None else v) for k, v in mk.items()}
            h = hdims[mn]
            scores = []
            for seed in SEEDS:
                set_seed(seed)
                x, y = dfn()
                if split >= len(x):
                    # Memorization: train set == test set by design
                    xtr, ytr, xte, yte = x, y, x, y
                else:
                    xtr, ytr, xte, yte = x[:split], y[:split], x[split:], y[split:]
                set_seed(seed + 100)
                model = mc(ind, outd, h, N_H, **kw)
                if ttype == 'reg':
                    s = train_reg(model, xtr, ytr, xte, yte, epochs, lr)
                else:
                    s = train_clf(model, xtr, ytr, xte, yte, epochs, lr)
                scores.append(s)
                # Diagnostics for v8 (last seed)
                if mn == 'v8:Phase' and seed == SEEDS[-1]:
                    model.eval()
                    with torch.no_grad():
                        alphas, phis = model.get_all_diagnostics(xte[:100])
                        all_a = torch.cat([a.flatten() for a in alphas])
                        all_p = torch.cat([p.flatten() for p in phis])
                        diag_data[tname] = {
                            'alpha_mean': all_a.mean().item(),
                            'alpha_std': all_a.std().item(),
                            'alpha_pct_low': (all_a < 0.3).float().mean().item(),
                            'alpha_pct_high': (all_a > 0.7).float().mean().item(),
                            'phi_mean': all_p.mean().item(),
                            'phi_std': all_p.std().item(),
                        }
            p = count_params(mc(ind, outd, h, N_H, **kw))
            task_res[mn] = {'mean': np.mean(scores), 'std': np.std(scores),
                            'scores': scores, 'params': p, 'hidden': h}
        is_reg = ttype == 'reg'
        if is_reg:
            best_mn = min(task_res, key=lambda k: task_res[k]['mean'])
        else:
            best_mn = max(task_res, key=lambda k: task_res[k]['mean'])
        metric = "MSE ↓" if is_reg else "Acc ↑"
        print(f"\n {'Model':<12} {'H':>4} {'Params':>7} {metric + ' (mean±std)':>28}")
        print(f" {'─' * 56}")
        for mn, r in task_res.items():
            m, s = r['mean'], r['std']
            ms = f"{m:.2e}±{s:.1e}" if (is_reg and m < 0.001) else (f"{m:.4f}±{s:.4f}" if is_reg else f"{m:.1%}±{s:.3f}")
            print(f" {mn:<12} {r['hidden']:>4} {r['params']:>7,} {ms:>28}{' ★' if mn == best_mn else ''}")
        print(f" → Winner: {best_mn}")
        if tname in diag_data:
            d = diag_data[tname]
            print(f" → v8 α: mean={d['alpha_mean']:.3f} std={d['alpha_std']:.3f}"
                  f" | {d['alpha_pct_low']:.0%} linear {d['alpha_pct_high']:.0%} periodic")
            print(f" → v8 φ: mean={d['phi_mean']:.3f} std={d['phi_std']:.3f}")
        all_results[tname] = task_res
    # OOD
    print(f"\n{'━' * 80}")
    print(" OOD: Train [-1,1] → Test [1,2]")
    print(" Does α shift toward linear on OOD?")
    print(f"{'━' * 80}")
    ood_res = {}
    ood_diag = {}
    for mn, (mc, mk) in models.items():
        kw = {k: (20.0 if v is None else v) for k, v in mk.items()}
        h = find_hidden(2, 1, N_H, 5000, mc, **kw)
        id_sc, ood_sc = [], []
        for seed in SEEDS:
            set_seed(seed)
            xtr, ytr = data_ood_train()
            set_seed(seed + 50)
            xid = torch.rand(200, 2) * 2 - 1
            yid = (torch.sin(3 * math.pi * xid[:, 0]) * torch.cos(3 * math.pi * xid[:, 1]) + xid[:, 0] * xid[:, 1]).unsqueeze(1)
            set_seed(seed + 50)
            xood, yood = data_ood_test()
            set_seed(seed + 100)
            model = mc(2, 1, h, N_H, **kw)
            s_id = train_reg(model, xtr, ytr, xid, yid, 300, 1e-3)
            model.eval()
            with torch.no_grad():
                s_ood = F.mse_loss(model(xood), yood).item()
            id_sc.append(s_id)
            ood_sc.append(s_ood)
            if mn == 'v8:Phase' and seed == SEEDS[-1]:
                model.eval()
                with torch.no_grad():
                    a_id, _ = model.get_all_diagnostics(xid[:100])
                    a_ood, _ = model.get_all_diagnostics(xood[:100])
                    ood_diag = {
                        'id_alpha': torch.cat([a.flatten() for a in a_id]).mean().item(),
                        'ood_alpha': torch.cat([a.flatten() for a in a_ood]).mean().item(),
                    }
        p = count_params(mc(2, 1, h, N_H, **kw))
        ood_res[mn] = {'id': np.mean(id_sc), 'ood': np.mean(ood_sc), 'params': p,
                       'deg': np.mean(ood_sc) / max(np.mean(id_sc), 1e-10),
                       'id_std': np.std(id_sc), 'ood_std': np.std(ood_sc)}
    best_ood = min(ood_res, key=lambda k: ood_res[k]['ood'])
    print(f"\n {'Model':<12} {'ID MSE':>14} {'OOD MSE':>14} {'Degrad.':>9}")
    print(f" {'─' * 52}")
    for mn, r in ood_res.items():
        mark = " ★" if mn == best_ood else ""
        print(f" {mn:<12} {r['id']:>9.4f}±{r['id_std']:.3f} {r['ood']:>9.4f}±{r['ood_std']:.3f} {r['deg']:>8.1f}x{mark}")
    print(f" → Best OOD: {best_ood}")
    if ood_diag:
        shift = ood_diag['ood_alpha'] - ood_diag['id_alpha']
        print("\n v8 α SHIFT on OOD:")
        print(f" ID: α = {ood_diag['id_alpha']:.4f}")
        print(f" OOD: α = {ood_diag['ood_alpha']:.4f}")
        if shift < -0.03:
            print(f" → α DROPPED by {abs(shift):.4f} → periodic reduced on OOD ✅")
        elif shift > 0.03:
            print(f" → α INCREASED by {shift:.4f} → MORE periodic on OOD ❌")
        else:
            print(f" → α shift = {shift:+.4f} (minimal)")
    all_results['OOD'] = {mn: {'mean': r['ood'], 'std': r['ood_std']} for mn, r in ood_res.items()}
    # GRAND SUMMARY
    print(f"\n{'=' * 80}")
    print(" GRAND SUMMARY")
    print(f"{'=' * 80}")
    win_counts = {k: 0 for k in models}
    print(f"\n {'Task':<20}", end="")
    for mn in models:
        print(f" {mn:>12}", end="")
    print(f" {'Winner':>10}")
    print(f" {'─' * 72}")
    for tname, tr in all_results.items():
        scores = {k: v['mean'] for k, v in tr.items()}
        max_s = max(scores.values())
        # Heuristic: accuracies live in (0.5, 1.0]; near-zero values are MSEs.
        is_clf = max_s > 0.5 and max_s <= 1.0 and min(scores.values()) >= 0
        if min(scores.values()) < 0.001:
            is_clf = False
        if tname == 'OOD':
            winner = min(scores, key=scores.get)
        elif is_clf:
            winner = max(scores, key=scores.get)
        else:
            winner = min(scores, key=scores.get)
        win_counts[winner] += 1
        row = f" {tname:<20}"
        for mn in models:
            s = scores[mn]
            if is_clf:
                row += f" {s:>11.1%}"
            elif s < 0.001:
                row += f" {s:>11.2e}"
            else:
                row += f" {s:>11.4f}"
        row += f" {'->' + winner:>10}"
        print(row)
    print(f"\n {'─' * 72}")
    for mn, c in sorted(win_counts.items(), key=lambda x: -x[1]):
        print(f" {mn:<14} {c} wins {'█' * c * 3}")
    # DIAGNOSTICS SUMMARY
    print(f"\n{'━' * 80}")
    print(" v8 DIAGNOSTICS: Did phase & gate actually learn?")
    print(f"{'━' * 80}")
    print(f"\n {'Task':<22} {'α mean':>7} {'α std':>7} {'%Lin':>6} {'%Per':>6} {'φ std':>7}")
    print(f" {'─' * 58}")
    for tname, d in diag_data.items():
        print(f" {tname:<22} {d['alpha_mean']:>7.3f} {d['alpha_std']:>7.3f}"
              f" {d['alpha_pct_low']:>5.0%} {d['alpha_pct_high']:>5.0%} {d['phi_std']:>7.3f}")
    print("""
╔════════════════════════════════════════════════════════════════════════════╗
║ v8 VERDICT: ADAPTIVE PHASE + AMPLITUDE GATE                                ║
║                                                                            ║
║ Key questions:                                                             ║
║ 1. Did α polarize (not stuck at 0.5)? Check α_std and %Lin/%Per            ║
║ 2. Did φ vary per input? Check φ_std > 0                                   ║
║ 3. Did α shift on OOD? Check α shift above                                 ║
║ 4. Did it beat SinGLU? Check win counts                                    ║
╚════════════════════════════════════════════════════════════════════════════╝
""")
    save = {'tasks': {}, 'ood': {}, 'diagnostics': diag_data, 'ood_diag': ood_diag}
    for tname, tr in all_results.items():
        save['tasks'][tname] = {mn: {'mean': float(r['mean']), 'std': float(r.get('std', 0)),
                                     'scores': [float(s) for s in r.get('scores', [r['mean']])],
                                     'params': r.get('params', 0), 'hidden': r.get('hidden', 0)}
                                for mn, r in tr.items()}
    save['ood'] = {mn: {k: float(v) if isinstance(v, (float, np.floating)) else v
                        for k, v in r.items()} for mn, r in ood_res.items()}
    with open('/app/results_v8.json', 'w') as f:
        json.dump(save, f, indent=2, default=str)
    print(" Saved to /app/results_v8.json")

if __name__ == "__main__":
    main()