#!/usr/bin/env python3
"""
v11: SinGLU + DISCIPLINED Phase

v10 problem: φ(x) = π·tanh(Wφ·x) is too powerful.
  Phase std ~0.3 rad destroys frequency stability on high-freq, memorization, OOD.

Two surgical fixes (from critique), and nothing else:
  1. Scale phase DOWN: φ = 0.1·tanh(Wφ·x), not π·tanh(Wφ·x)
  2. Tie phase to features: sin(ω·(Wg·x + φ)) not sin(ω·Wg·x + φ)
  No gate, no freq mod, no extra paths.

Note that because φ is added *inside* the ω·(...) product, the argument of the
sine is shifted by ω·φ, not by φ itself.
"""
import json
import math
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

SEEDS = [0, 1, 2]

# φ standard deviations per task, as hard-coded in the original script under
# the column heading "v10 was"; used only for the comparison table.
V10_PHI_STD = {
    'Complex Fn': .192, 'Nested Fn': .142, 'Spiral': .242,
    'Checker': .207, 'High-Freq': .321, 'Memorize': .206,
}


def set_seed(s):
    """Seed the torch and numpy global RNGs with ``s``."""
    torch.manual_seed(s)
    np.random.seed(s)


# ── Baselines ──
class VanillaMLP(nn.Module):
    """ReLU MLP: ``n`` blocks of Linear(·, h) + ReLU, then Linear(h, do)."""

    def __init__(self, di, do, h, n):
        super().__init__()
        blocks = []
        width = di
        for _ in range(n):
            blocks += [nn.Linear(width, h), nn.ReLU()]
            width = h
        blocks.append(nn.Linear(width, do))
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        return self.net(x)


class SinGLULayer(nn.Module):
    """Sine-gated linear unit: y = LN( Wo( sin(w0·Wg·x) ⊙ Wv·x ) )."""

    def __init__(self, di, do, mid, w0=30.):
        super().__init__()
        # Construction order is kept as in the original so that the RNG is
        # consumed in the same sequence for a given seed.
        self.Wg = nn.Linear(di, mid, bias=False)
        self.Wv = nn.Linear(di, mid, bias=False)
        self.Wo = nn.Linear(mid, do, bias=True)
        self.w0 = w0
        self.ln = nn.LayerNorm(do)
        with torch.no_grad():
            # Gate weights are drawn uniformly from ±sqrt(6/di)/w0.
            bound = math.sqrt(6 / di) / w0
            self.Wg.weight.uniform_(-bound, bound)
            nn.init.xavier_uniform_(self.Wv.weight)
            nn.init.xavier_uniform_(self.Wo.weight)

    def forward(self, x):
        gate = torch.sin(self.w0 * self.Wg(x))
        return self.ln(self.Wo(gate * self.Wv(x)))


class SinGLUNet(nn.Module):
    """Stack of ``n`` SinGLULayer blocks (width ``h``) and a Linear head."""

    def __init__(self, di, do, h, n, w0=30.):
        super().__init__()
        mid = max(2, int(h * 2 / 3))  # inner width is 2/3 of h, at least 2
        stack = []
        width = di
        for _ in range(n):
            stack.append(SinGLULayer(width, h, mid, w0))
            width = h
        stack.append(nn.Linear(width, do))
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


# ── v11: Disciplined Phase ──
class v11Layer(nn.Module):
    """
    SinGLU layer with a small, bounded, input-dependent phase.

    FIX 1: Phase is scaled by ``phase_scale`` (default 0.1), not by π.
    FIX 2: Phase is added in feature space, before multiplying by ω:

        φ(x) = phase_scale · tanh(Wφ·x + bφ)
        core = sin( ω · (Wg·x + φ(x)) )
        y    = LN( Wo( core ⊙ Wv·x ) )

    Wφ and bφ are zero-initialised, so φ(x) = 0 for every input at
    construction time.
    """

    def __init__(self, di, do, mid, w0=30., phase_scale=0.1):
        super().__init__()
        self.Wg = nn.Linear(di, mid, bias=False)
        self.Wv = nn.Linear(di, mid, bias=False)
        self.Wo = nn.Linear(mid, do, bias=True)
        self.Wphi = nn.Linear(di, mid, bias=True)  # phase (tied to feature space)
        self.w0 = w0
        self.ps = phase_scale
        self.ln = nn.LayerNorm(do)
        with torch.no_grad():
            bound = math.sqrt(6 / di) / w0
            self.Wg.weight.uniform_(-bound, bound)
            nn.init.xavier_uniform_(self.Wv.weight)
            nn.init.xavier_uniform_(self.Wo.weight)
            nn.init.zeros_(self.Wphi.weight)
            nn.init.zeros_(self.Wphi.bias)

    def forward(self, x):
        g = self.Wg(x)
        phi = self.ps * torch.tanh(self.Wphi(x))  # small, bounded
        core = torch.sin(self.w0 * (g + phi))     # phase IN feature space
        return self.ln(self.Wo(core * self.Wv(x)))

    def get_phi(self, x):
        """Return φ(x) for input ``x`` without tracking gradients."""
        with torch.no_grad():
            return self.ps * torch.tanh(self.Wphi(x))


class v11Net(nn.Module):
    """Stack of ``n`` v11Layer blocks (width ``h``) and a Linear head."""

    def __init__(self, di, do, h, n, w0=30.):
        super().__init__()
        mid = max(2, int(h * 2 / 3))
        stack = []
        width = di
        for _ in range(n):
            stack.append(v11Layer(width, h, mid, w0))
            width = h
        stack.append(nn.Linear(width, do))
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def get_all_phi(self, x):
        """Return a list with φ of every v11Layer, each evaluated on that
        layer's own input while ``x`` is propagated through the network."""
        phis = []
        h = x
        for layer in self.layers:
            if isinstance(layer, v11Layer):
                phis.append(layer.get_phi(h))
            h = layer(h)
        return phis


# ── Utils ──
def np_(m):
    """Count the trainable parameters of module ``m``."""
    return sum(p.numel() for p in m.parameters() if p.requires_grad)


def fh(di, do, n, t, cls, **kw):
    """Binary-search a hidden width in [2, 512] whose parameter count for
    ``cls(di, do, width, n, **kw)`` is closest to the target ``t``.

    NOTE(review): only the text up to ``if abs(p-t)`` survived in the source.
    The comparison target, the lo/hi updates and the return value below are
    reconstructed from that prefix and from the call sites (which use the
    result as the hidden width) — confirm against the original script.
    """
    lo, hi, b = 2, 512, 2
    best_gap = abs(np_(cls(di, do, b, n, **kw)) - t)
    while lo <= hi:
        mid = (lo + hi) // 2
        p = np_(cls(di, do, mid, n, **kw))
        if abs(p - t) < best_gap:
            b, best_gap = mid, abs(p - t)
        if p < t:
            lo = mid + 1
        else:
            hi = mid - 1
    return b


# ── Training loops & remaining datasets ──
# NOTE(review): the source text between `if abs(p-t)` (inside fh) and
# `0).long()` was lost before this file reached review. main() references the
# following names, whose definitions are NOT present here and have deliberately
# NOT been re-invented:
#     tr_r, tr_c, d_cx, d_ne, d_sp, d_ch
# Restore them from the original script before running main(); until then
# main() raises NameError. Surviving fragment, verbatim:
#     if abs(p-t)0).long()


def _ood_target(x):
    """Target shared by d_ot, d_oe and the in-distribution OOD eval set."""
    return (torch.sin(3 * math.pi * x[:, 0]) * torch.cos(3 * math.pi * x[:, 1])
            + x[:, 0] * x[:, 1]).unsqueeze(1)


def d_hf(n=1000):
    """High-frequency 1-D regression on an evenly spaced grid over [-1, 1]."""
    x = torch.linspace(-1, 1, n).unsqueeze(1)
    return x, torch.sin(20 * x) + torch.sin(50 * x) + .5 * torch.sin(100 * x)


def d_mm(n=200):
    """Memorization task: Gaussian inputs (n, 8) and unrelated Gaussian
    targets (n, 4)."""
    x = torch.randn(n, 8)
    y = torch.randn(n, 4)
    return x, y


def d_ot(n=800):
    """OOD training set: inputs uniform on [-1, 1)²."""
    x = torch.rand(n, 2) * 2 - 1
    return x, _ood_target(x)


def d_oe(n=300):
    """OOD evaluation set: inputs uniform on [1, 2)²."""
    x = torch.rand(n, 2) + 1
    return x, _ood_target(x)


def pick_best(means, kind):
    """Return the key of ``means`` with the best score for a task of type
    ``kind``: highest for ``'c'`` (accuracy), lowest otherwise (MSE)."""
    chooser = max if kind == 'c' else min
    return chooser(means, key=means.get)


def _fill_w0(mk, w0):
    """Copy kwargs ``mk``, replacing every None placeholder with ``w0``."""
    return {k: (w0 if v is None else v) for k, v in mk.items()}


def _run_tasks(models, tasks, depth):
    """Train every model on every task; return (results, φ stats, task kinds)."""
    R, PH, kinds = {}, {}, {}
    for tn, tt, df, di, do, bud, ep, lr, w0, sp in tasks:
        print(f"\n{'━'*80}\n {tn} | ~{bud:,}p\n{'━'*80}")
        # Hidden widths are matched to the parameter budget before any
        # training starts, for all models.
        hs = {mn: fh(di, do, depth, bud, mc, **_fill_w0(mk, w0))
              for mn, (mc, mk) in models.items()}
        tr = {}
        for mn, (mc, mk) in models.items():
            kw = _fill_w0(mk, w0)
            h = hs[mn]
            sc = []
            for seed in SEEDS:
                set_seed(seed)
                x, y = df()
                # NOTE(review): the split is positional. For d_hf the inputs
                # are an ordered linspace, so the held-out points are all at
                # the upper end of the range — confirm that is intended.
                if sp >= len(x):
                    xt, yt, xe, ye = x, y, x, y  # evaluate on the training set
                else:
                    xt, yt, xe, ye = x[:sp], y[:sp], x[sp:], y[sp:]
                set_seed(seed + 100)
                mdl = mc(di, do, h, depth, **kw)
                if tt == 'r':
                    s = tr_r(mdl, xt, yt, xe, ye, ep, lr)
                else:
                    s = tr_c(mdl, xt, yt, xe, ye, ep, lr)
                sc.append(s)
                # φ statistics are recorded for the last seed's v11 model only.
                if mn == 'v11' and seed == SEEDS[-1]:
                    mdl.eval()
                    with torch.no_grad():
                        pp = mdl.get_all_phi(xe[:100])
                        ap = torch.cat([p.flatten() for p in pp])
                        PH[tn] = {'m': ap.mean().item(), 's': ap.std().item(),
                                  'mn': ap.min().item(), 'mx': ap.max().item()}
            p = np_(mc(di, do, h, depth, **kw))
            tr[mn] = {'mean': np.mean(sc), 'std': np.std(sc), 'scores': sc,
                      'params': p, 'hidden': h}

        ir = tt == 'r'
        best = pick_best({k: v['mean'] for k, v in tr.items()}, tt)
        met = "MSE ↓" if ir else "Acc ↑"
        print(f"\n {'M':<8} {'H':>3} {'P':>6} {met+' (mean±std)':>26}")
        print(f" {'─'*46}")
        for mn, r in tr.items():
            m, s = r['mean'], r['std']
            if ir and m < .001:
                ms = f"{m:.2e}±{s:.1e}"
            elif ir:
                ms = f"{m:.4f}±{s:.4f}"
            else:
                ms = f"{m:.1%}±{s:.3f}"
            print(f" {mn:<8} {r['hidden']:>3} {r['params']:>6,} {ms:>26}{' ★' if mn==best else ''}")
        print(f" → {best}")
        if tn in PH:
            d = PH[tn]
            print(f" φ: std={d['s']:.4f} range=[{d['mn']:.3f},{d['mx']:.3f}]")
        R[tn] = tr
        kinds[tn] = tt
    return R, PH, kinds


def _run_ood(models, depth):
    """Train on [-1,1)², evaluate in-distribution and on [1,2)²; return stats."""
    print(f"\n{'━'*80}\n OOD: [-1,1] → [1,2]\n{'━'*80}")
    OD = {}
    for mn, (mc, mk) in models.items():
        kw = _fill_w0(mk, 20.)
        h = fh(2, 1, depth, 5000, mc, **kw)
        ids, ods = [], []
        for seed in SEEDS:
            set_seed(seed)
            xtr, ytr = d_ot()
            set_seed(seed + 50)
            xi = torch.rand(200, 2) * 2 - 1
            yi = _ood_target(xi)
            set_seed(seed + 50)
            xo, yo = d_oe()
            set_seed(seed + 100)
            mdl = mc(2, 1, h, depth, **kw)
            si = tr_r(mdl, xtr, ytr, xi, yi, 300, 1e-3)
            mdl.eval()
            with torch.no_grad():
                so = F.mse_loss(mdl(xo), yo).item()
            ids.append(si)
            ods.append(so)
        OD[mn] = {'id': np.mean(ids), 'ood': np.mean(ods),
                  'deg': np.mean(ods) / max(np.mean(ids), 1e-10),
                  'p': np_(mc(2, 1, h, depth, **kw)),
                  'is': np.std(ids), 'os': np.std(ods)}
    bo = min(OD, key=lambda k: OD[k]['ood'])
    print(f"\n {'M':<8} {'ID':>12} {'OOD':>12} {'Deg':>7}")
    print(f" {'─'*42}")
    for mn, r in OD.items():
        print(f" {mn:<8} {r['id']:>8.4f}±{r['is']:.3f} {r['ood']:>8.4f}±{r['os']:.3f} {r['deg']:>6.1f}x{' ★' if mn==bo else ''}")
    print(f" → {bo}")
    return OD


def _print_summary(R, kinds, names):
    """Print one row per task with every model's mean score and the winner."""
    print(f"\n{'='*80}\n SUMMARY: v11 vs SinGLU vs Vanilla\n{'='*80}")
    wc = {k: 0 for k in names}
    print(f"\n {'Task':<14}", end="")
    for mn in names:
        print(f" {mn:>12}", end="")
    print(f" {'W':>8}")
    print(f" {'─'*50}")
    for tn, t in R.items():
        sc = {k: v['mean'] for k, v in t.items()}
        # Use the declared task type. Guessing it from the score values
        # mislabels regression tasks whose MSE happens to lie in (0.5, 1]
        # (and classification tasks at or below 50% accuracy), which crowned
        # the worst model as the winner.
        ic = kinds[tn] == 'c'
        w = pick_best(sc, kinds[tn])
        wc[w] += 1
        row = f" {tn:<14}"
        for mn in names:
            s = sc[mn]
            if ic:
                row += f" {s:>11.1%}"
            elif s < .001:
                row += f" {s:>11.2e}"
            else:
                row += f" {s:>11.4f}"
        row += f" ->{'':>1}{w}"
        print(row)
    print(f"\n {'─'*50}")
    for mn, c in sorted(wc.items(), key=lambda x: -x[1]):
        print(f" {mn:<8} {c} {'█'*c*4}")


def _print_phi_check(PH):
    """Print v11 φ std per task next to the recorded v10 value."""
    print("\n φ DISCIPLINE CHECK:")
    print(f" {'Task':<14} {'v11 φ std':>10} {'v10 was':>10} {'Change':>10}")
    print(f" {'─'*46}")
    for tn, d in PH.items():
        v10s = V10_PHI_STD.get(tn, 0)
        change = f"{d['s']/v10s:.1%}" if v10s > 0 else "N/A"
        print(f" {tn:<14} {d['s']:>10.4f} {v10s:>10.3f} {change:>10}")


def _save_results(R, OD, PH, out_path):
    """Write all results as JSON to ``out_path``; return the path written."""
    sv = {'tasks': {}, 'ood': {}, 'phi': PH}
    for tn, t in R.items():
        sv['tasks'][tn] = {
            mn: {'mean': float(r['mean']), 'std': float(r.get('std', 0)),
                 'scores': [float(s) for s in r.get('scores', [r['mean']])],
                 'params': r.get('params', 0), 'hidden': r.get('hidden', 0)}
            for mn, r in t.items()}
    sv['ood'] = {mn: {k: float(v) if isinstance(v, (float, np.floating)) else v
                      for k, v in r.items()}
                 for mn, r in OD.items()}
    target = Path(out_path)
    try:
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, 'w', encoding='utf-8') as f:
            json.dump(sv, f, indent=2, default=str)
    except OSError:
        # The whole benchmark has already run by now; do not lose the results
        # just because the target directory is missing or not writable.
        target = Path(target.name)
        with open(target, 'w', encoding='utf-8') as f:
            json.dump(sv, f, indent=2, default=str)
    return target


# ── Main ──
def main(out_path='/app/results_v11.json'):
    """Run all benchmarks, print the tables and save results as JSON.

    Args:
        out_path: Destination of the JSON results. If it cannot be written,
            a file of the same name in the current directory is used instead.
    """
    print("="*80)
    print(" v11: SinGLU + DISCIPLINED Phase")
    print(" φ scaled to 0.1, tied to feature space: sin(ω·(Wg·x + 0.1·tanh(Wφ·x)))")
    print("="*80)
    N = 3
    # A None kwarg is a placeholder that is filled with the task's w0.
    Ms = {'Vanilla': (VanillaMLP, {}),
          'SinGLU': (SinGLUNet, {'w0': None}),
          'v11': (v11Net, {'w0': None})}
    # (name, type r/c, data fn, in dim, out dim, param budget, epochs, lr,
    #  w0, split index)
    tasks = [
        ("Complex Fn", "r", d_cx, 4, 1, 5000, 300, 1e-3, 30., 750),
        ("Nested Fn", "r", d_ne, 2, 1, 3000, 300, 1e-3, 20., 750),
        ("Spiral", "c", d_sp, 2, 2, 3000, 250, 1e-3, 15., 700),
        ("Checker", "c", d_ch, 2, 2, 3000, 250, 1e-3, 20., 700),
        ("High-Freq", "r", d_hf, 1, 1, 8000, 300, 1e-3, 60., 700),
        ("Memorize", "r", d_mm, 8, 4, 5000, 400, 1e-3, 10., 200),
    ]
    R, PH, kinds = _run_tasks(Ms, tasks, N)

    # OOD
    OD = _run_ood(Ms, N)
    R['OOD'] = {mn: {'mean': r['ood'], 'std': r['os']} for mn, r in OD.items()}
    kinds['OOD'] = 'r'  # OOD rows hold MSE values

    # Summary
    _print_summary(R, kinds, list(Ms))

    # Compare v10 vs v11 φ
    _print_phi_check(PH)

    written = _save_results(R, OD, PH, out_path)
    print(f"\n Saved to {written}")


if __name__ == "__main__":
    main()