#!/usr/bin/env python3
"""
v10: SinGLU + Phase(x). Nothing else.
One extra matrix Wφ for input-dependent phase shift. That's it.
"""
import torch, torch.nn as nn, torch.nn.functional as F
import numpy as np, math, json
SEEDS = [0, 1, 2]
def set_seed(s): torch.manual_seed(s); np.random.seed(s)
# ── SinGLU baseline ──
class SinGLULayer(nn.Module):
    def __init__(self, d_in, d_out, mid, w0=30.):
        super().__init__()
        self.Wg=nn.Linear(d_in,mid,bias=False)
        self.Wv=nn.Linear(d_in,mid,bias=False)
        self.Wo=nn.Linear(mid,d_out,bias=True)
        self.w0=w0; self.ln=nn.LayerNorm(d_out)
        with torch.no_grad():
            self.Wg.weight.uniform_(-math.sqrt(6/d_in)/w0,math.sqrt(6/d_in)/w0)
            nn.init.xavier_uniform_(self.Wv.weight)
            nn.init.xavier_uniform_(self.Wo.weight)
    def forward(self,x):
        return self.ln(self.Wo(torch.sin(self.w0*self.Wg(x))*self.Wv(x)))
class SinGLUNet(nn.Module):
    def __init__(self,di,do,h,n,w0=30.):
        super().__init__()
        mid=max(2,int(h*2/3)); layers=[]; prev=di
        for _ in range(n): layers.append(SinGLULayer(prev,h,mid,w0)); prev=h
        layers.append(nn.Linear(prev,do)); self.layers=nn.ModuleList(layers)
    def forward(self,x):
        for l in self.layers: x=l(x)
        return x
# ── Vanilla ──
class VanillaMLP(nn.Module):
    def __init__(self,di,do,h,n):
        super().__init__()
        layers=[]; prev=di
        for _ in range(n): layers+=[nn.Linear(prev,h),nn.ReLU()]; prev=h
        layers.append(nn.Linear(prev,do)); self.net=nn.Sequential(*layers)
    def forward(self,x): return self.net(x)
# ── v10: SinGLU + Phase. ONE extra matrix. ──
class v10Layer(nn.Module):
    """
    core = sin(w0 · Wg·x + φ(x))   where   φ(x) = π·tanh(Wφ·x)
    y = LN( Wo( core ⊙ Wv·x ) )
    vs SinGLU: identical except +Wφ for phase.
    Wφ starts at 0 → v10 = SinGLU at initialization.
    """
    def __init__(self, di, do, mid, w0=30.):
        super().__init__()
        self.Wg=nn.Linear(di,mid,bias=False)
        self.Wv=nn.Linear(di,mid,bias=False)
        self.Wo=nn.Linear(mid,do,bias=True)
        self.Wphi=nn.Linear(di,mid,bias=True)  # THE ONLY ADDITION
        self.w0=w0; self.ln=nn.LayerNorm(do)
        with torch.no_grad():
            self.Wg.weight.uniform_(-math.sqrt(6/di)/w0,math.sqrt(6/di)/w0)
            nn.init.xavier_uniform_(self.Wv.weight)
            nn.init.xavier_uniform_(self.Wo.weight)
            nn.init.zeros_(self.Wphi.weight)  # start = SinGLU exactly
            nn.init.zeros_(self.Wphi.bias)
    def forward(self,x):
        phi=math.pi*torch.tanh(self.Wphi(x))
        core=torch.sin(self.w0*self.Wg(x)+phi)
        return self.ln(self.Wo(core*self.Wv(x)))
    def get_phi(self,x):
        with torch.no_grad():
            return math.pi*torch.tanh(self.Wphi(x))
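# Optional sanity check (a sketch, not called anywhere in this script): because Wφ and its
# bias start at zero, a freshly built v10Layer should match a SinGLULayer that shares the
# same Wg/Wv/Wo weights. The helper name and tolerance below are illustrative, not part of v10.
def _check_v10_matches_singlu_at_init(di=4, do=8, mid=6, w0=30.):
    set_seed(0)
    base=SinGLULayer(di,do,mid,w0); cand=v10Layer(di,do,mid,w0)
    with torch.no_grad():  # copy the shared weights so only the (zeroed) phase path could differ
        cand.Wg.weight.copy_(base.Wg.weight); cand.Wv.weight.copy_(base.Wv.weight)
        cand.Wo.weight.copy_(base.Wo.weight); cand.Wo.bias.copy_(base.Wo.bias)
    x=torch.randn(16,di)
    assert torch.allclose(base(x),cand(x),atol=1e-6), "v10 should reduce to SinGLU when Wφ=0"
    return True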
class v10Net(nn.Module):
    def __init__(self,di,do,h,n,w0=30.):
        super().__init__()
        mid=max(2,int(h*2/3)); layers=[]; prev=di
        for _ in range(n): layers.append(v10Layer(prev,h,mid,w0)); prev=h
        layers.append(nn.Linear(prev,do)); self.layers=nn.ModuleList(layers)
    def forward(self,x):
        for l in self.layers: x=l(x)
        return x
    def get_all_phi(self,x):
        phis=[]; h=x
        for l in self.layers:
            if isinstance(l,v10Layer): phis.append(l.get_phi(h)); h=l(h)
            else: h=l(h)
        return phis
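# Small probing sketch (illustrative helper, not used by main): summarize the per-layer phase
# shifts of a trained v10Net via get_all_phi, the same statistics the φ-analysis in main() reports.
def _phi_summary(mdl, x):
    mdl.eval()
    with torch.no_grad():
        flat=torch.cat([p.flatten() for p in mdl.get_all_phi(x)])
    return {'mean':flat.mean().item(),'std':flat.std().item(),
            'min':flat.min().item(),'max':flat.max().item()}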
# ── Utils ──
def np_(m): return sum(p.numel() for p in m.parameters() if p.requires_grad)
def fh(di,do,n,t,cls,**kw):
    # binary-search the hidden width so cls(di,do,h,n) lands closest to the target param count t
    lo,hi,b=2,512,2
    while lo<=hi:
        mid=(lo+hi)//2; p=np_(cls(di,do,mid,n,**kw))
        if abs(p-t)<abs(np_(cls(di,do,b,n,**kw))-t): b=mid
        if p<t: lo=mid+1
        else: hi=mid-1
    return b
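# Illustrative budget-matching sketch (not executed by main): fh picks, per architecture, the
# hidden width whose parameter count lands closest to a shared budget, so the comparisons below
# are parameter-matched. The budget/dims here are placeholders, not the real task settings.
def _example_budget_match(budget=5000, di=2, do=1, n_layers=3, w0=20.):
    h_van=fh(di,do,n_layers,budget,VanillaMLP)
    h_sin=fh(di,do,n_layers,budget,SinGLUNet,w0=w0)
    h_v10=fh(di,do,n_layers,budget,v10Net,w0=w0)
    return {'Vanilla':(h_van,np_(VanillaMLP(di,do,h_van,n_layers))),
            'SinGLU':(h_sin,np_(SinGLUNet(di,do,h_sin,n_layers,w0=w0))),
            'v10':(h_v10,np_(v10Net(di,do,h_v10,n_layers,w0=w0)))}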
def tr_reg(m,xtr,ytr,xte,yte,ep,lr,bs=256):
    o=torch.optim.Adam(m.parameters(),lr=lr)
    s=torch.optim.lr_scheduler.CosineAnnealingLR(o,T_max=ep)
    best=float('inf'); n=len(xtr)
    for e in range(ep):
        m.train(); p=torch.randperm(n)
        for i in range(0,n,bs):
            idx=p[i:i+bs]; loss=F.mse_loss(m(xtr[idx]),ytr[idx])
            o.zero_grad(); loss.backward()
            torch.nn.utils.clip_grad_norm_(m.parameters(),1.); o.step()
        s.step()
        if (e+1)%max(1,ep//10)==0:  # checkpoint test MSE roughly every 10% of training, keep the best
            m.eval()
            with torch.no_grad(): best=min(best,F.mse_loss(m(xte),yte).item())
    m.eval()
    with torch.no_grad(): best=min(best,F.mse_loss(m(xte),yte).item())
    return best
def tr_clf(m,xtr,ytr,xte,yte,ep,lr,bs=256):
    o=torch.optim.Adam(m.parameters(),lr=lr)
    s=torch.optim.lr_scheduler.CosineAnnealingLR(o,T_max=ep)
    best=0; n=len(xtr)
    for e in range(ep):
        m.train(); p=torch.randperm(n)
        for i in range(0,n,bs):
            idx=p[i:i+bs]; loss=F.cross_entropy(m(xtr[idx]),ytr[idx])
            o.zero_grad(); loss.backward()
            torch.nn.utils.clip_grad_norm_(m.parameters(),1.); o.step()
        s.step()
        if (e+1)%max(1,ep//10)==0:
            m.eval()
            with torch.no_grad(): best=max(best,(m(xte).argmax(1)==yte).float().mean().item())
    m.eval()
    with torch.no_grad(): best=max(best,(m(xte).argmax(1)==yte).float().mean().item())
    return best
# ── Data ──
def d_complex(n=1000):
    x=torch.rand(n,4)*2-1; return x,torch.exp(torch.sin(x[:,0]**2+x[:,1]**2)+torch.sin(x[:,2]**2+x[:,3]**2)).unsqueeze(1)
def d_nested(n=1000):
    x=torch.rand(n,2)*2-1; return x,(torch.sin(math.pi*(x[:,0]**2+x[:,1]**2))*torch.cos(3*math.pi*x[:,0]*x[:,1])).unsqueeze(1)
def d_spiral(n=1000):
    t=torch.linspace(0,4*np.pi,n//2); r=torch.linspace(.3,2,n//2)
    x=torch.cat([torch.stack([r*torch.cos(t),r*torch.sin(t)],1),torch.stack([r*torch.cos(t+np.pi),r*torch.sin(t+np.pi)],1)])+torch.randn(n,2)*.05
    y=torch.cat([torch.zeros(n//2),torch.ones(n//2)]).long(); p=torch.randperm(n); return x[p],y[p]
def d_check(n=1000):
    x=torch.rand(n,2)*2-1; return x,((torch.sin(3*math.pi*x[:,0])*torch.sin(3*math.pi*x[:,1]))>0).long()
def d_hf(n=1000):
    x=torch.linspace(-1,1,n).unsqueeze(1); return x,torch.sin(20*x)+torch.sin(50*x)+.5*torch.sin(100*x)
def d_mem(n=200): return torch.randn(n,8),torch.randn(n,4)
def d_ood_tr(n=800):
    x=torch.rand(n,2)*2-1; return x,(torch.sin(3*math.pi*x[:,0])*torch.cos(3*math.pi*x[:,1])+x[:,0]*x[:,1]).unsqueeze(1)
def d_ood_te(n=300):
    x=torch.rand(n,2)+1; return x,(torch.sin(3*math.pi*x[:,0])*torch.cos(3*math.pi*x[:,1])+x[:,0]*x[:,1]).unsqueeze(1)
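# End-to-end sketch of one run (optional helper, not invoked by main): one seed of the High-Freq
# task — budget-matched width via fh, the script's 700/300 split of d_hf, and tr_reg with the same
# optimizer settings main() uses. The reduced epoch count is only to keep the sketch cheap.
def _example_single_run(epochs=30):
    set_seed(0); x,y=d_hf()
    xtr,ytr,xte,yte=x[:700],y[:700],x[700:],y[700:]
    h=fh(1,1,3,8000,v10Net,w0=60.)
    set_seed(100); mdl=v10Net(1,1,h,3,w0=60.)
    return tr_reg(mdl,xtr,ytr,xte,yte,epochs,1e-3)  # best test MSE seen during training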
# ── Main ──
def main():
    print("="*80)
    print(" v10: SinGLU + Phase(x). Nothing else.")
    print(" One extra Wφ matrix. Starts as pure SinGLU (Wφ=0).")
    print("="*80)
    N=3
    Ms={'Vanilla':(VanillaMLP,{}),'SinGLU':(SinGLUNet,{'w0':None}),'v10':(v10Net,{'w0':None})}
    # task tuples: (name, type, data fn, d_in, d_out, param budget, epochs, lr, w0, train split)
    tasks=[
        ("Complex Fn","reg",d_complex,4,1,5000,300,1e-3,30.,750),
        ("Nested Fn","reg",d_nested,2,1,3000,300,1e-3,20.,750),
        ("Spiral","clf",d_spiral,2,2,3000,250,1e-3,15.,700),
        ("Checkerboard","clf",d_check,2,2,3000,250,1e-3,20.,700),
        ("High-Freq","reg",d_hf,1,1,8000,300,1e-3,60.,700),
        ("Memorize","reg",d_mem,8,4,5000,400,1e-3,10.,200),
    ]
    R={}; PHI={}
    for tn,tt,df,di,do,bud,ep,lr,w0,sp in tasks:
        print(f"\n{'━'*80}\n {tn} | ~{bud:,}p\n{'━'*80}")
        hs={mn:fh(di,do,N,bud,mc,**{k:(w0 if v is None else v) for k,v in mk.items()}) for mn,(mc,mk) in Ms.items()}
        tr={}
        for mn,(mc,mk) in Ms.items():
            kw={k:(w0 if v is None else v) for k,v in mk.items()}; h=hs[mn]; sc=[]
            for seed in SEEDS:
                set_seed(seed); x,y=df()
                if sp>=len(x): xtr,ytr,xte,yte=x,y,x,y
                else: xtr,ytr,xte,yte=x[:sp],y[:sp],x[sp:],y[sp:]
                set_seed(seed+100); mdl=mc(di,do,h,N,**kw)
                s=tr_reg(mdl,xtr,ytr,xte,yte,ep,lr) if tt=='reg' else tr_clf(mdl,xtr,ytr,xte,yte,ep,lr)
                sc.append(s)
                if mn=='v10' and seed==SEEDS[-1]:
                    mdl.eval()
                    with torch.no_grad():
                        pp=mdl.get_all_phi(xte[:100])
                        ap=torch.cat([p.flatten() for p in pp])
                        PHI[tn]={'m':ap.mean().item(),'s':ap.std().item(),'mn':ap.min().item(),'mx':ap.max().item()}
            p=np_(mc(di,do,h,N,**kw))
            tr[mn]={'mean':np.mean(sc),'std':np.std(sc),'scores':sc,'params':p,'hidden':h}
        ir=tt=='reg'
        best=min(tr,key=lambda k:tr[k]['mean']) if ir else max(tr,key=lambda k:tr[k]['mean'])
        met="MSE ↓" if ir else "Acc ↑"
        print(f"\n {'Model':<10} {'Hid':>3} {'Params':>6} {met+' (mean±std)':>26}")
        print(f" {'─'*48}")
        for mn,r in tr.items():
            m,s=r['mean'],r['std']
            ms=f"{m:.2e}±{s:.1e}" if (ir and m<.001) else (f"{m:.4f}±{s:.4f}" if ir else f"{m:.1%}±{s:.3f}")
            print(f" {mn:<10} {r['hidden']:>3} {r['params']:>6,} {ms:>26}{' ★' if mn==best else ''}")
        print(f" → {best}")
        if tn in PHI:
            d=PHI[tn]; print(f" v10 φ: mean={d['m']:.3f} std={d['s']:.3f} range=[{d['mn']:.2f},{d['mx']:.2f}]")
        R[tn]=tr
    # OOD
    print(f"\n{'━'*80}\n OOD: [-1,1] → [1,2]\n{'━'*80}")
    OOD={}; od_phi={}
    for mn,(mc,mk) in Ms.items():
        kw={k:(20. if v is None else v) for k,v in mk.items()}
        h=fh(2,1,N,5000,mc,**kw); ids,ods=[],[]
        for seed in SEEDS:
            set_seed(seed); xtr,ytr=d_ood_tr()
            set_seed(seed+50); xi=torch.rand(200,2)*2-1; yi=(torch.sin(3*math.pi*xi[:,0])*torch.cos(3*math.pi*xi[:,1])+xi[:,0]*xi[:,1]).unsqueeze(1)
            set_seed(seed+50); xo,yo=d_ood_te()
            set_seed(seed+100); mdl=mc(2,1,h,N,**kw)
            si=tr_reg(mdl,xtr,ytr,xi,yi,300,1e-3)
            mdl.eval()
            with torch.no_grad(): so=F.mse_loss(mdl(xo),yo).item()
            ids.append(si); ods.append(so)
            if mn=='v10' and seed==SEEDS[-1]:
                mdl.eval()
                with torch.no_grad():
                    pi=mdl.get_all_phi(xi[:100]); po=mdl.get_all_phi(xo[:100])
                    od_phi={'id':torch.cat([p.flatten() for p in pi]).mean().item(),
                            'ood':torch.cat([p.flatten() for p in po]).mean().item(),
                            'id_s':torch.cat([p.flatten() for p in pi]).std().item(),
                            'ood_s':torch.cat([p.flatten() for p in po]).std().item()}
        OOD[mn]={'id':np.mean(ids),'ood':np.mean(ods),'deg':np.mean(ods)/max(np.mean(ids),1e-10),
                 'is':np.std(ids),'os':np.std(ods),'p':np_(mc(2,1,h,N,**kw))}
    bo=min(OOD,key=lambda k:OOD[k]['ood'])
    print(f"\n {'Model':<10} {'ID':>12} {'OOD':>12} {'Deg':>7}")
    print(f" {'─'*44}")
    for mn,r in OOD.items():
        print(f" {mn:<10} {r['id']:>8.4f}±{r['is']:.3f} {r['ood']:>8.4f}±{r['os']:.3f} {r['deg']:>6.1f}x{' ★' if mn==bo else ''}")
    print(f" → {bo}")
    if od_phi:
print(f" v10 φ shift: ID={od_phi['id']:.4f}{od_phi['id_s']:.3f}) → OOD={od_phi['ood']:.4f}{od_phi['ood_s']:.3f})")
    R['OOD']={mn:{'mean':r['ood'],'std':r['os']} for mn,r in OOD.items()}
    # Summary
    print(f"\n{'='*80}\n SUMMARY\n{'='*80}")
    wc={k:0 for k in Ms}
    print(f"\n {'Task':<16}",end="")
    for mn in Ms: print(f" {mn:>12}",end="")
    print(f" {'Winner':>8}")
    print(f" {'─'*52}")
    for tn,t in R.items():
        sc={k:v['mean'] for k,v in t.items()}
        mx=max(sc.values()); ic=mx>.5 and mx<=1 and min(sc.values())>=0  # heuristic: treat values in (0.5,1] as accuracies
        if min(sc.values())<.001: ic=False
        w=min(sc,key=sc.get) if (tn=='OOD' or not ic) else max(sc,key=sc.get)
        wc[w]+=1
        row=f" {tn:<16}"
        for mn in Ms:
            s=sc[mn]
            if ic: row+=f" {s:>11.1%}"
            elif s<.001: row+=f" {s:>11.2e}"
            else: row+=f" {s:>11.4f}"
        row+=f" -> {w}"; print(row)
    print(f"\n {'─'*52}")
    for mn,c in sorted(wc.items(),key=lambda x:-x[1]):
        print(f" {mn:<10} {c} {'█'*c*4}")
    print(f"\n φ ANALYSIS (did phase learn something useful?):")
    for tn,d in PHI.items():
        status="ACTIVE" if d['s']>.1 else "weak" if d['s']>.01 else "DEAD"
        print(f" {tn:<16} std={d['s']:.3f} range=[{d['mn']:.2f},{d['mx']:.2f}] {status}")
    sv={'tasks':{},'ood':{},'phi':PHI,'ood_phi':od_phi}
    for tn,t in R.items():
        sv['tasks'][tn]={mn:{'mean':float(r['mean']),'std':float(r.get('std',0)),
                             'scores':[float(s) for s in r.get('scores',[r['mean']])],
                             'params':r.get('params',0),'hidden':r.get('hidden',0)} for mn,r in t.items()}
    sv['ood']={mn:{k:float(v) if isinstance(v,(float,np.floating)) else v for k,v in r.items()} for mn,r in OOD.items()}
    with open('/app/results_v10.json','w') as f: json.dump(sv,f,indent=2,default=str)
    print(f"\n Saved to /app/results_v10.json")
if __name__=="__main__": main()