# Author: anshdadhich
# Commit message: Add v9: controlled freq + phase + gate
# Commit: 36a49b5 (verified)
#!/usr/bin/env python3
"""
v9: Controlled Frequency + Phase + Gate (the convergent design)
per = sin( ω(x) ⊙ W_per·x + φ(x) ) ω(x) = ω0·(1 + 0.1·tanh(W_ω·x))
y = LN( α(x) ⊙ per + (1-α(x)) ⊙ val + residual )
Key vs v7: ω is bounded (±10%), not free
Key vs v8: ω exists (not removed)
Key vs v8: paths are SEPARATED (not entangled as val*(α*per+(1-α)))
"""
import torch, torch.nn as nn, torch.nn.functional as F
import numpy as np, math, json
# Seeds every experiment configuration is repeated over.
SEEDS = [0, 1, 2]


def set_seed(s):
    """Seed the global torch RNG and then the global NumPy RNG with ``s``."""
    for seed_fn in (torch.manual_seed, np.random.seed):
        seed_fn(s)
# ── Baselines ──
class VanillaMLP(nn.Module):
    """Plain ReLU MLP baseline.

    Builds ``n`` hidden ``Linear -> ReLU`` stages of width ``h`` followed by a
    ``Linear`` head to ``d_out``. With ``n == 0`` the network is a single
    ``Linear(d_in, d_out)``.
    """

    def __init__(self, d_in, d_out, h, n):
        super().__init__()
        # widths[i] -> widths[i + 1] for every hidden stage; the head maps the
        # last width to d_out. Linear layers are created in the same order as
        # before, so seeded initialisation is unchanged.
        widths = [d_in] + [h] * n
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        modules.append(nn.Linear(widths[-1], d_out))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)
class SinGLULayer(nn.Module):
    """Sine-gated GLU layer: ``LN(Wo(sin(w0 * Wg x) * Wv x))``.

    Args:
        d_in: input width.
        d_out: output width (also the LayerNorm size).
        mid: width of the gated hidden product.
        w0: frequency multiplier applied inside the sine gate.

    ``Wg`` is drawn from U(-sqrt(6/d_in)/w0, sqrt(6/d_in)/w0); ``Wv`` and
    ``Wo`` use Xavier-uniform. ``Wg``/``Wv`` are bias-free, ``Wo`` has a bias.
    """

    def __init__(self, d_in, d_out, mid, w0=30.):
        super().__init__()
        # Registration / init order is kept fixed so seeded runs reproduce.
        self.Wg = nn.Linear(d_in, mid, bias=False)
        self.Wv = nn.Linear(d_in, mid, bias=False)
        self.Wo = nn.Linear(mid, d_out, bias=True)
        self.w0 = w0
        self.ln = nn.LayerNorm(d_out)
        limit = math.sqrt(6 / d_in) / w0
        with torch.no_grad():
            self.Wg.weight.uniform_(-limit, limit)
            for lin in (self.Wv, self.Wo):
                nn.init.xavier_uniform_(lin.weight)

    def forward(self, x):
        gate = torch.sin(self.w0 * self.Wg(x))
        hidden = gate * self.Wv(x)
        return self.ln(self.Wo(hidden))
class SinGLUNet(nn.Module):
    """Stack of ``n`` SinGLULayer blocks of width ``h`` plus a Linear head.

    Every block uses a gate width of ``max(2, int(h * 2 / 3))`` and the same
    ``w0``. With ``n == 0`` the network is a single ``Linear(d_in, d_out)``.
    """

    def __init__(self, d_in, d_out, h, n, w0=30.):
        super().__init__()
        mid = max(2, int(h * 2 / 3))
        widths = [d_in] + [h] * n
        blocks = [SinGLULayer(fan_in, h, mid, w0) for fan_in in widths[:-1]]
        blocks.append(nn.Linear(widths[-1], d_out))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        out = x
        for layer in self.layers:
            out = layer(out)
        return out
# ── v9: THE CONVERGENT DESIGN ──
class v9Layer(nn.Module):
    """
    One v9 block: gated blend of a periodic and a linear path, plus residual.

        omega(x) = w0 * (1 + 0.1 * tanh(W_omega x))   within +-10% of w0
        phi(x)   = pi * tanh(W_phi x)                  in (-pi, pi)
        per      = sin(omega(x) * W_per x + phi(x))
        val      = W_val x
        alpha(x) = sigmoid(W_alpha x)
        y        = LN(alpha * per + (1 - alpha) * val + res(x))

    ``res`` is a bias-free Linear when ``d_in != d_out`` and an Identity
    otherwise. ``W_omega``, ``W_phi`` and ``W_alpha`` start at zero, so at
    initialisation omega == w0, phi == 0 and alpha == 0.5.
    """

    def __init__(self, d_in, d_out, w0=30.):
        super().__init__()
        # Registration order is deliberate: it fixes RNG consumption during
        # construction and the order of parameters().
        self.W_val = nn.Linear(d_in, d_out, bias=True)    # linear path
        self.W_per = nn.Linear(d_in, d_out, bias=False)   # periodic input
        self.W_omega = nn.Linear(d_in, d_out, bias=True)  # frequency modulation
        self.W_phi = nn.Linear(d_in, d_out, bias=True)    # phase
        self.W_alpha = nn.Linear(d_in, d_out, bias=True)  # gate
        self.w0 = w0
        self.ln = nn.LayerNorm(d_out)
        if d_in != d_out:
            self.res = nn.Linear(d_in, d_out, bias=False)
        else:
            self.res = nn.Identity()
        # Same +-sqrt(6/d_in)/w0 bound that SinGLULayer uses for its sine gate.
        bound = math.sqrt(6. / d_in) / w0
        with torch.no_grad():
            nn.init.xavier_uniform_(self.W_val.weight)
            self.W_per.weight.uniform_(-bound, bound)
            # Zero control heads: tanh(0) = 0 and sigmoid(0) = 0.5.
            for head in (self.W_omega, self.W_phi, self.W_alpha):
                nn.init.zeros_(head.weight)
                nn.init.zeros_(head.bias)

    def _controls(self, x):
        """Return the input-dependent ``(omega, phi, alpha)`` for ``x``."""
        omega = self.w0 * (1. + 0.1 * torch.tanh(self.W_omega(x)))
        phi = math.pi * torch.tanh(self.W_phi(x))
        alpha = torch.sigmoid(self.W_alpha(x))
        return omega, phi, alpha

    def forward(self, x):
        omega, phi, alpha = self._controls(x)
        periodic = torch.sin(omega * self.W_per(x) + phi)
        linear = self.W_val(x)
        # Paths stay separate: alpha selects between them instead of scaling
        # the linear path.
        mixed = alpha * periodic + (1. - alpha) * linear
        return self.ln(mixed + self.res(x))

    def get_diag(self, x):
        """Gradient-free ``(alpha, phi, omega)`` for inspection."""
        with torch.no_grad():
            omega, phi, alpha = self._controls(x)
        return alpha, phi, omega
class v9Net(nn.Module):
    """Stack of ``n`` v9Layer blocks of width ``h`` followed by a Linear head.

    Args:
        d_in: input width.
        d_out: output width.
        h: hidden width of every v9Layer.
        n: number of v9Layer blocks.
        w0: base frequency passed to every v9Layer.
    """

    def __init__(self, d_in, d_out, h, n, w0=30.):
        super().__init__()
        layers = []
        prev = d_in
        for _ in range(n):
            layers.append(v9Layer(prev, h, w0))
            prev = h
        layers.append(nn.Linear(prev, d_out))
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for l in self.layers:
            x = l(x)
        return x

    def get_all_diag(self, x):
        """Collect per-layer ``(alphas, phis, omegas)`` lists.

        ``x`` is propagated through the whole stack. Each v9Layer's
        diagnostics are taken on the activation that layer actually receives.
        Each list has one tensor per v9Layer.
        """
        alphas, phis, omegas = [], [], []
        h = x
        for l in self.layers:
            if isinstance(l, v9Layer):
                a, p, o = l.get_diag(h)
                alphas.append(a)
                phis.append(p)
                omegas.append(o)
            h = l(h)
        return alphas, phis, omegas

    def gate_reg(self, x):
        """Gate-polarisation penalty: mean of alpha * (1 - alpha) per v9Layer, summed.

        The training loops *add* ``lam * gate_reg(x)`` to the loss being
        minimised. alpha * (1 - alpha) == 0.25 - (alpha - 0.5)**2 is
        non-negative, largest at alpha = 0.5 and zero at alpha in {0, 1}.
        Minimising it therefore pushes gates away from the centre, toward a
        hard periodic/linear choice.

        The previous term, (alpha - 0.5)**2, is minimised *at* 0.5. Adding it
        pulled every gate toward the centre, the opposite of the intended
        polarisation. The new term's gradient is exactly the negation of the
        old one's.

        Runs its own forward pass through all layers to obtain each gate's
        input.
        """
        total = 0
        h = x
        for l in self.layers:
            if isinstance(l, v9Layer):
                a = torch.sigmoid(l.W_alpha(h))
                total = total + (a * (1. - a)).mean()
            h = l(h)
        return total
# ── Utils ──
def nparams(m):
    """Return the number of trainable scalars in ``m`` (parameters with requires_grad)."""
    trainable = filter(lambda prm: prm.requires_grad, m.parameters())
    return sum(prm.numel() for prm in trainable)
def find_h(di, do, n, target, cls, **kw):
    """Bisect the hidden width in [2, 512] whose model size is closest to ``target``.

    Builds ``cls(di, do, h, n, **kw)`` for probed widths ``h`` and counts its
    trainable parameters. The bisection assumes that count grows
    monotonically with ``h``. Returns the probed width with the smallest
    ``|params - target|``; on ties the earlier-probed width is kept.

    The best width's gap is cached. The previous version rebuilt the best
    model on every step just to recount its parameters. The result is
    unchanged, with roughly half as many model constructions.
    """
    def gap(h):
        return abs(nparams(cls(di, do, h, n, **kw)) - target)

    lo, hi = 2, 512
    best, best_gap = 2, gap(2)
    while lo <= hi:
        mid = (lo + hi) // 2
        p = nparams(cls(di, do, mid, n, **kw))
        if abs(p - target) < best_gap:
            best, best_gap = mid, abs(p - target)
        if p < target:
            lo = mid + 1
        else:
            hi = mid - 1
    return best
def train_reg(m, xtr, ytr, xte, yte, ep, lr, lam=5e-4, bs=256):
    """Train ``m`` on MSE and return the lowest test MSE observed.

    Optimisation setup:
    - Adam with ``lr`` and cosine annealing over ``ep`` epochs (the scheduler
      steps once per epoch).
    - Shuffled minibatches of size ``bs``, with gradient-norm clipping at 1.0.
    - Evaluation on ``(xte, yte)`` every ``max(1, ep // 10)`` epochs and once
      more after training.

    If the model defines a callable ``gate_reg(x)`` (v9Net does),
    ``lam * m.gate_reg(batch)`` is added to every minibatch loss. This is a
    duck-typed hook rather than ``isinstance(m, v9Net)``, so the helper no
    longer depends on one specific class.

    NOTE(review): the return value is the best of several *test-set*
    evaluations, which amounts to checkpoint selection on the test set. It
    is optimistic compared with selecting on a separate validation split.

    Returns:
        float: minimum test MSE across evaluations.
    """
    opt = torch.optim.Adam(m.parameters(), lr=lr)
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=ep)
    n = len(xtr)
    reg = getattr(m, 'gate_reg', None)
    use_reg = callable(reg)
    eval_every = max(1, ep // 10)

    def evaluate():
        m.eval()
        with torch.no_grad():
            return F.mse_loss(m(xte), yte).item()

    best = float('inf')
    for e in range(ep):
        m.train()
        perm = torch.randperm(n)
        for i in range(0, n, bs):
            idx = perm[i:i + bs]
            bx, by = xtr[idx], ytr[idx]
            loss = F.mse_loss(m(bx), by)
            if use_reg:
                loss = loss + lam * reg(bx)
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(m.parameters(), 1.0)
            opt.step()
        sch.step()
        if (e + 1) % eval_every == 0:
            best = min(best, evaluate())
    return min(best, evaluate())
def train_clf(m, xtr, ytr, xte, yte, ep, lr, lam=5e-4, bs=256):
    """Train ``m`` with cross-entropy and return the best test accuracy observed.

    Optimisation setup:
    - Adam with ``lr`` and cosine annealing over ``ep`` epochs (the scheduler
      steps once per epoch).
    - Shuffled minibatches of size ``bs``, with gradient-norm clipping at 1.0.
    - Accuracy (argmax over dim 1 vs ``yte``) evaluated every
      ``max(1, ep // 10)`` epochs and once more after training.

    If the model defines a callable ``gate_reg(x)`` (v9Net does),
    ``lam * m.gate_reg(batch)`` is added to every minibatch loss. This is a
    duck-typed hook rather than ``isinstance(m, v9Net)``.

    NOTE(review): the return value is the best of several *test-set*
    evaluations, which amounts to checkpoint selection on the test set.

    Returns:
        float: maximum test accuracy in [0, 1] across evaluations.
    """
    opt = torch.optim.Adam(m.parameters(), lr=lr)
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=ep)
    n = len(xtr)
    reg = getattr(m, 'gate_reg', None)
    use_reg = callable(reg)
    eval_every = max(1, ep // 10)

    def evaluate():
        m.eval()
        with torch.no_grad():
            return (m(xte).argmax(1) == yte).float().mean().item()

    best = 0
    for e in range(ep):
        m.train()
        perm = torch.randperm(n)
        for i in range(0, n, bs):
            idx = perm[i:i + bs]
            bx, by = xtr[idx], ytr[idx]
            loss = F.cross_entropy(m(bx), by)
            if use_reg:
                loss = loss + lam * reg(bx)
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(m.parameters(), 1.0)
            opt.step()
        sch.step()
        if (e + 1) % eval_every == 0:
            best = max(best, evaluate())
    return max(best, evaluate())
# ── Data ──
def d_complex(n=1000):
    """4-D regression: y = exp(sin(x0^2 + x1^2) + sin(x2^2 + x3^2)).

    Inputs are drawn from U[-1, 1)^4. Returns ``x`` of shape (n, 4) and ``y``
    of shape (n, 1).
    """
    x = torch.rand(n, 4) * 2 - 1
    r_a = x[:, 0] ** 2 + x[:, 1] ** 2
    r_b = x[:, 2] ** 2 + x[:, 3] ** 2
    target = torch.exp(torch.sin(r_a) + torch.sin(r_b))
    return x, target.unsqueeze(1)
def d_nested(n=1000):
    """2-D regression: y = sin(pi * (x0^2 + x1^2)) * cos(3 * pi * x0 * x1).

    Inputs are drawn from U[-1, 1)^2. Returns ``x`` of shape (n, 2) and ``y``
    of shape (n, 1).
    """
    x = torch.rand(n, 2) * 2 - 1
    radial = torch.sin(math.pi * (x[:, 0] ** 2 + x[:, 1] ** 2))
    cross = torch.cos(3 * math.pi * x[:, 0] * x[:, 1])
    return x, (radial * cross).unsqueeze(1)
def d_spiral(n=1000):
    """Two interleaved spiral arms for 2-D binary classification.

    Arm 0 (label 0) gets ``n - n // 2`` points and arm 1 (label 1), rotated
    by pi, gets ``n // 2``. Both arms sweep angle t over [0, 4*pi] and
    radius over [0.3, 2]. Gaussian noise (std 0.05) is added and rows are
    shuffled.

    Returns ``x`` of shape (n, 2) and long labels ``y`` of shape (n,).

    Previously both arms had ``n // 2`` points while the noise and the
    permutation were sized ``n``. For odd ``n`` that produced a shape
    mismatch or an out-of-range index. For even ``n`` the output (including
    RNG consumption) is unchanged.
    """
    n0, n1 = n - n // 2, n // 2
    t0, r0 = torch.linspace(0, 4*np.pi, n0), torch.linspace(.3, 2, n0)
    t1, r1 = torch.linspace(0, 4*np.pi, n1), torch.linspace(.3, 2, n1)
    arm0 = torch.stack([r0*torch.cos(t0), r0*torch.sin(t0)], 1)
    arm1 = torch.stack([r1*torch.cos(t1+np.pi), r1*torch.sin(t1+np.pi)], 1)
    x = torch.cat([arm0, arm1]) + torch.randn(n, 2)*.05
    y = torch.cat([torch.zeros(n0), torch.ones(n1)]).long()
    p = torch.randperm(n)
    return x[p], y[p]
def d_checker(n=1000):
    """2-D checkerboard classification on U[-1, 1)^2.

    The label is 1 where sin(3*pi*x0) * sin(3*pi*x1) > 0 and 0 otherwise.
    Returns ``x`` of shape (n, 2) and long labels of shape (n,).
    """
    x = torch.rand(n, 2) * 2 - 1
    wave = torch.sin(3 * math.pi * x[:, 0]) * torch.sin(3 * math.pi * x[:, 1])
    labels = (wave > 0).long()
    return x, labels
def d_highfreq(n=1000, shuffle=True):
    """1-D regression: y = sin(20x) + sin(50x) + 0.5*sin(100x) on linspace(-1, 1, n).

    Returns ``x`` and ``y``, each of shape (n, 1).

    ``main`` splits every dataset positionally: the first ``sp`` rows train
    and the rest test. With the sorted grid, that split made the High-Freq
    "test" set the contiguous right end of the interval (an extrapolation
    test), unlike every other task, whose rows are i.i.d. or explicitly
    shuffled. Rows are therefore permuted by default. Pass
    ``shuffle=False`` to keep the sorted grid.
    """
    x = torch.linspace(-1, 1, n).unsqueeze(1)
    y = torch.sin(20*x) + torch.sin(50*x) + .5*torch.sin(100*x)
    if shuffle:
        p = torch.randperm(n)
        x, y = x[p], y[p]
    return x, y
def d_mem(n=200):
    """Random memorisation task: standard-normal inputs (n, 8) and targets (n, 4)."""
    inputs = torch.randn(n, 8)
    targets = torch.randn(n, 4)
    return inputs, targets
def d_ood_tr(n=800):
    """In-distribution training set for the OOD experiment.

    Inputs are drawn from U[-1, 1)^2 and the target is
    y = sin(3*pi*x0) * cos(3*pi*x1) + x0*x1, returned with shape (n, 1).
    """
    x = torch.rand(n, 2) * 2 - 1
    x0, x1 = x[:, 0], x[:, 1]
    y = torch.sin(3 * math.pi * x0) * torch.cos(3 * math.pi * x1) + x0 * x1
    return x, y.unsqueeze(1)
def d_ood_te(n=300):
    """Out-of-distribution test set: the same target as ``d_ood_tr``, with inputs from U[1, 2)^2.

    Returns ``x`` of shape (n, 2) and ``y`` of shape (n, 1).
    """
    x = torch.rand(n, 2) + 1
    x0, x1 = x[:, 0], x[:, 1]
    y = torch.sin(3 * math.pi * x0) * torch.cos(3 * math.pi * x1) + x0 * x1
    return x, y.unsqueeze(1)
# ── Main ──
def _resolve_kwargs(spec, w0):
    """Copy a model-kwargs spec, substituting ``w0`` for every ``None`` value.

    In the ``models`` table of ``main``, ``None`` means "use the current
    task's base frequency".
    """
    return {k: (w0 if v is None else v) for k, v in spec.items()}


def main(out_path='/app/results_v9.json'):
    """Run the full v9 benchmark and write the results as JSON to ``out_path``.

    Pipeline:
    1. For each task in ``tasks``, size each model in ``models`` with
       ``find_h`` to roughly the task's parameter budget.
    2. Train each model once per seed in ``SEEDS`` and report the mean/std
       of the returned score: test MSE for ``'reg'`` tasks, test accuracy
       for ``'clf'`` tasks.
    3. Record v9 gate/phase/frequency statistics from the last seed's model.
    4. Run an OOD regression experiment (train on [-1, 1]^2, evaluate on
       [1, 2]^2).
    5. Print a win-count summary and a diagnostics table, then dump
       everything to JSON.

    Args:
        out_path: destination of the JSON results. Defaults to the path
            that was previously hard-coded.
    """
    print("="*80)
    print(" v9: CONTROLLED FREQ + PHASE + GATE (separated paths)")
    print(" ω(x) = ω0·(1+0.1·tanh(W_ω·x)) | φ(x) = π·tanh(W_φ·x)")
    print(" y = LN( α⊙per + (1-α)⊙val + res ) | λ=5e-4 polarization")
    print("="*80)
    N = 3  # number of hidden blocks in every network
    # name -> (class, kwargs). A None kwarg value is replaced by the task's w0.
    models = {
        'Vanilla': (VanillaMLP, {}),
        'SinGLU': (SinGLUNet, {'w0': None}),
        'v9': (v9Net, {'w0': None}),
    }
    # (name, kind, data_fn, d_in, d_out, param_budget, epochs, lr, w0, split)
    # Rows [:split] train and [split:] test. If split >= len(data), the model
    # trains and tests on the full set.
    tasks = [
        ("Complex Fn", "reg", d_complex, 4, 1, 5000, 300, 1e-3, 30., 750),
        ("Nested Fn", "reg", d_nested, 2, 1, 3000, 300, 1e-3, 20., 750),
        ("Spiral", "clf", d_spiral, 2, 2, 3000, 250, 1e-3, 15., 700),
        ("Checkerboard", "clf", d_checker, 2, 2, 3000, 250, 1e-3, 20., 700),
        ("High-Freq", "reg", d_highfreq, 1, 1, 8000, 300, 1e-3, 60., 700),
        ("Memorize", "reg", d_mem, 8, 4, 5000, 400, 1e-3, 10., 200),
    ]
    all_res = {}
    diag = {}
    # Task kind ('reg' / 'clf') per results row. The summary uses it to decide
    # whether a higher or a lower score wins.
    task_kind = {}

    # ── Benchmark tasks ──
    for tn, tt, df, di, do, bud, ep, lr, w0, sp in tasks:
        task_kind[tn] = tt
        print(f"\n{'━'*80}\n {tn} | ~{bud:,} params\n{'━'*80}")
        hs = {mn: find_h(di, do, N, bud, mc, **_resolve_kwargs(mk, w0))
              for mn, (mc, mk) in models.items()}
        tr = {}
        for mn, (mc, mk) in models.items():
            kw = _resolve_kwargs(mk, w0)
            h = hs[mn]
            scores = []
            for seed in SEEDS:
                # Data and weight init use separate, fixed seed streams so
                # every model sees identical data for a given seed.
                set_seed(seed)
                x, y = df()
                if sp >= len(x):
                    xtr, ytr, xte, yte = x, y, x, y
                else:
                    xtr, ytr, xte, yte = x[:sp], y[:sp], x[sp:], y[sp:]
                set_seed(seed + 100)
                mdl = mc(di, do, h, N, **kw)
                if tt == 'reg':
                    s = train_reg(mdl, xtr, ytr, xte, yte, ep, lr)
                else:
                    s = train_clf(mdl, xtr, ytr, xte, yte, ep, lr)
                scores.append(s)
                if mn == 'v9' and seed == SEEDS[-1]:
                    # Gate / phase / frequency statistics on up to 100 test rows.
                    mdl.eval()
                    with torch.no_grad():
                        als, phs, oms = mdl.get_all_diag(xte[:100])
                        aa = torch.cat([a.flatten() for a in als])
                        pp = torch.cat([p.flatten() for p in phs])
                        oo = torch.cat([o.flatten() for o in oms])
                        diag[tn] = {
                            'a_m': aa.mean().item(), 'a_s': aa.std().item(),
                            'a_lo': (aa < .3).float().mean().item(), 'a_hi': (aa > .7).float().mean().item(),
                            'p_s': pp.std().item(), 'p_m': pp.mean().item(),
                            'o_m': oo.mean().item(), 'o_s': oo.std().item(),
                            'o_min': oo.min().item(), 'o_max': oo.max().item(),
                        }
            p = nparams(mc(di, do, h, N, **kw))
            tr[mn] = {'mean': np.mean(scores), 'std': np.std(scores), 'scores': scores, 'params': p, 'hidden': h}
        is_reg = tt == 'reg'
        if is_reg:
            best = min(tr, key=lambda k: tr[k]['mean'])
        else:
            best = max(tr, key=lambda k: tr[k]['mean'])
        met = "MSE ↓" if is_reg else "Acc ↑"
        print(f"\n {'Model':<10} {'H':>4} {'P':>6} {met+' (mean±std)':>26}")
        print(f" {'─'*50}")
        for mn, r in tr.items():
            m, s = r['mean'], r['std']
            ms = f"{m:.2e}±{s:.1e}" if (is_reg and m < .001) else (f"{m:.4f}±{s:.4f}" if is_reg else f"{m:.1%}±{s:.3f}")
            print(f" {mn:<10} {r['hidden']:>4} {r['params']:>6,} {ms:>26}{' ★' if mn == best else ''}")
        print(f" → {best}")
        if tn in diag:
            d = diag[tn]
            print(f" α: {d['a_m']:.3f}±{d['a_s']:.3f} ({d['a_lo']:.0%} lin, {d['a_hi']:.0%} per)")
            print(f" φ: std={d['p_s']:.3f}")
            print(f" ω: {d['o_m']:.1f}±{d['o_s']:.2f} [{d['o_min']:.1f},{d['o_max']:.1f}]")
        all_res[tn] = tr

    # ── OOD: fit on [-1,1]^2, probe ID on fresh [-1,1]^2 points and OOD on [1,2]^2 ──
    print(f"\n{'━'*80}\n OOD: [-1,1] → [1,2]\n{'━'*80}")
    ood_r = {}
    ood_d = {}
    for mn, (mc, mk) in models.items():
        kw = _resolve_kwargs(mk, 20.)
        h = find_h(2, 1, N, 5000, mc, **kw)
        ids, oods = [], []
        for seed in SEEDS:
            set_seed(seed)
            xtr, ytr = d_ood_tr()
            set_seed(seed + 50)
            xid = torch.rand(200, 2) * 2 - 1
            yid = (torch.sin(3*math.pi*xid[:, 0])*torch.cos(3*math.pi*xid[:, 1]) + xid[:, 0]*xid[:, 1]).unsqueeze(1)
            # NOTE(review): the OOD draw reuses seed + 50, the same seed as the
            # ID probe above. Both sets therefore come from the same RNG stream
            # instead of independent draws. Left unchanged because a different
            # seed would change the reported OOD numbers; consider a distinct
            # seed.
            set_seed(seed + 50)
            xoo, yoo = d_ood_te()
            set_seed(seed + 100)
            mdl = mc(2, 1, h, N, **kw)
            si = train_reg(mdl, xtr, ytr, xid, yid, 300, 1e-3)
            mdl.eval()
            with torch.no_grad():
                so = F.mse_loss(mdl(xoo), yoo).item()
            ids.append(si)
            oods.append(so)
            if mn == 'v9' and seed == SEEDS[-1]:
                mdl.eval()
                with torch.no_grad():
                    ai, _, oi = mdl.get_all_diag(xid[:100])
                    ao, _, oo2 = mdl.get_all_diag(xoo[:100])
                    ood_d = {
                        'id_a': torch.cat([a.flatten() for a in ai]).mean().item(),
                        'ood_a': torch.cat([a.flatten() for a in ao]).mean().item(),
                        'id_o': torch.cat([o.flatten() for o in oi]).mean().item(),
                        'ood_o': torch.cat([o.flatten() for o in oo2]).mean().item(),
                    }
        p = nparams(mc(2, 1, h, N, **kw))
        ood_r[mn] = {'id': np.mean(ids), 'ood': np.mean(oods), 'p': p,
                     'deg': np.mean(oods) / max(np.mean(ids), 1e-10),
                     'is': np.std(ids), 'os': np.std(oods)}
    bo = min(ood_r, key=lambda k: ood_r[k]['ood'])
    print(f"\n {'Model':<10} {'ID MSE':>14} {'OOD MSE':>14} {'Deg':>8}")
    print(f" {'─'*50}")
    for mn, r in ood_r.items():
        print(f" {mn:<10} {r['id']:>9.4f}±{r['is']:.3f} {r['ood']:>9.4f}±{r['os']:.3f} {r['deg']:>7.1f}x{' ★' if mn == bo else ''}")
    print(f" → {bo}")
    if ood_d:
        print(f"\n v9 on OOD:")
        print(f" α: ID={ood_d['id_a']:.4f} → OOD={ood_d['ood_a']:.4f} (shift={ood_d['ood_a']-ood_d['id_a']:+.4f})")
        print(f" ω: ID={ood_d['id_o']:.2f} → OOD={ood_d['ood_o']:.2f} (shift={ood_d['ood_o']-ood_d['id_o']:+.2f})")
    all_res['OOD'] = {mn: {'mean': r['ood'], 'std': r['os']} for mn, r in ood_r.items()}
    task_kind['OOD'] = 'reg'

    # ── Summary ──
    print(f"\n{'='*80}\n SUMMARY\n{'='*80}")
    wc = {k: 0 for k in models}
    print(f"\n {'Task':<18}", end="")
    for mn in models:
        print(f" {mn:>12}", end="")
    print(f" {'W':>8}")
    print(f" {'─'*56}")
    for tn, t in all_res.items():
        sc = {k: v['mean'] for k, v in t.items()}
        # Decide direction from the known task kind, not from score magnitudes.
        # Otherwise a regression row whose MSEs all lie in [0.001, 1] with the
        # largest above 0.5 looks like accuracies: the worst model wins and the
        # MSE is printed as a percentage.
        is_c = task_kind.get(tn) == 'clf'
        w = max(sc, key=sc.get) if is_c else min(sc, key=sc.get)
        wc[w] += 1
        row = f" {tn:<18}"
        for mn in models:
            s = sc[mn]
            if is_c:
                row += f" {s:>11.1%}"
            elif s < .001:
                row += f" {s:>11.2e}"
            else:
                row += f" {s:>11.4f}"
        row += f" {'->'+w:>8}"
        print(row)
    print(f"\n {'─'*56}")
    for mn, c in sorted(wc.items(), key=lambda x: -x[1]):
        print(f" {mn:<10} {c} wins {'█'*c*4}")

    # ── Diagnostics summary ──
    print(f"\n v9 DIAGNOSTICS:")
    print(f" {'Task':<18} {'α':>7} {'α_std':>7} {'%L':>5} {'%P':>5} {'φ_std':>7} {'ω':>7} {'ω_std':>7} {'ω range':>14}")
    print(f" {'─'*80}")
    for tn, d in diag.items():
        print(f" {tn:<18} {d['a_m']:>7.3f} {d['a_s']:>7.3f} {d['a_lo']:>4.0%} {d['a_hi']:>4.0%}"
              f" {d['p_s']:>7.3f} {d['o_m']:>7.1f} {d['o_s']:>7.3f} [{d['o_min']:.1f},{d['o_max']:.1f}]")

    # ── Persist ──
    sv = {'tasks': {}, 'ood': {}, 'diag': diag, 'ood_diag': ood_d}
    for tn, t in all_res.items():
        sv['tasks'][tn] = {mn: {'mean': float(r['mean']), 'std': float(r.get('std', 0)),
                               'scores': [float(s) for s in r.get('scores', [r['mean']])],
                               'params': r.get('params', 0), 'hidden': r.get('hidden', 0)} for mn, r in t.items()}
    sv['ood'] = {mn: {k: float(v) if isinstance(v, (float, np.floating)) else v for k, v in r.items()} for mn, r in ood_r.items()}
    with open(out_path, 'w') as f:
        json.dump(sv, f, indent=2, default=str)
    print(f"\n Saved to {out_path}")
# Script entry point: run the whole benchmark only when executed directly.
if __name__ == "__main__":
    main()