| |
| """ |
| v15: Dual-Phase Decomposition (low-freq structure + high-freq detail) |
| + KILLER EXPERIMENT: Train on low freq, test on high freq |
| |
| low = sin(ω·g + β·φ) ← structure |
| high = sin(2ω·g + γ·φ) ← detail |
| core = low ⊙ (1 + α·high) ← AM modulation |
| |
| + Freq Generalization: train sin(2πx), test sin(10πx) |
| + Mixed Freq: train sin(2πx)+sin(4πx), test sin(2πx)+sin(20πx) |
| """ |
|
|
| import torch, torch.nn as nn, torch.nn.functional as F |
| import numpy as np, math, json |
|
|
# Seeds used for every repeated run in the experiments below.
SEEDS = [0, 1, 2]


def set_seed(s):
    """Seed the global PyTorch and NumPy random generators with ``s``."""
    torch.manual_seed(s)
    np.random.seed(s)
|
|
| |
|
|
class VanillaMLP(nn.Module):
    """Plain ReLU perceptron: ``n`` Linear+ReLU stages of width ``h``, then a linear head.

    Args:
        di: input feature count.
        do: output feature count.
        h: hidden width.
        n: number of hidden Linear+ReLU stages.
    """

    def __init__(self, di, do, h, n):
        super().__init__()
        widths = [di] + [h] * n
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(nn.ReLU())
        # The head reads from the last hidden width (or straight from the input when n == 0).
        stages.append(nn.Linear(widths[-1], do))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
|
|
class SinGLULayer(nn.Module):
    """Sine-gated linear unit: ``LayerNorm(Wo(sin(w0 * Wg x) * Wv x))``.

    Args:
        di: input feature count.
        do: output feature count.
        mid: width of the gate/value projections.
        w0: frequency multiplier applied to the gate pre-activation.
    """

    def __init__(self, di, do, mid, w0=30.):
        super().__init__()
        self.Wg = nn.Linear(di, mid, bias=False)
        self.Wv = nn.Linear(di, mid, bias=False)
        self.Wo = nn.Linear(mid, do, bias=True)
        self.w0 = w0
        self.ln = nn.LayerNorm(do)
        # Gate weights are drawn uniformly from +/- sqrt(6/di)/w0.
        gate_bound = math.sqrt(6 / di) / w0
        with torch.no_grad():
            self.Wg.weight.uniform_(-gate_bound, gate_bound)
            nn.init.xavier_uniform_(self.Wv.weight)
            nn.init.xavier_uniform_(self.Wo.weight)

    def forward(self, x):
        gate = torch.sin(self.w0 * self.Wg(x))
        value = self.Wv(x)
        return self.ln(self.Wo(gate * value))
|
|
class SinGLUNet(nn.Module):
    """Stack of ``n`` SinGLULayer blocks of width ``h`` followed by a linear read-out.

    The inner gate/value width of every block is ``max(2, int(h * 2 / 3))``.
    """

    def __init__(self, di, do, h, n, w0=30.):
        super().__init__()
        mid = max(2, int(h * 2 / 3))
        # Only the first block sees the raw input width; the rest chain at width h.
        blocks = [SinGLULayer(di if i == 0 else h, h, mid, w0) for i in range(n)]
        blocks.append(nn.Linear(h if n > 0 else di, do))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        out = x
        for stage in self.layers:
            out = stage(out)
        return out
|
|
| |
class v10Layer(nn.Module):
    """Sine-gated unit with an input-dependent phase shift.

    Computes ``LayerNorm(Wo(sin(w0 * Wg x + phi) * Wv x))`` where
    ``phi = pi * tanh(Wphi x)``. ``Wphi`` starts at zero, so the phase is
    zero for every input at initialization.
    """

    def __init__(self, di, do, mid, w0=30.):
        super().__init__()
        self.Wg = nn.Linear(di, mid, bias=False)
        self.Wv = nn.Linear(di, mid, bias=False)
        self.Wo = nn.Linear(mid, do, bias=True)
        self.Wphi = nn.Linear(di, mid, bias=True)
        self.w0 = w0
        self.ln = nn.LayerNorm(do)
        gate_bound = math.sqrt(6 / di) / w0
        with torch.no_grad():
            self.Wg.weight.uniform_(-gate_bound, gate_bound)
            nn.init.xavier_uniform_(self.Wv.weight)
            nn.init.xavier_uniform_(self.Wo.weight)
            nn.init.zeros_(self.Wphi.weight)
            nn.init.zeros_(self.Wphi.bias)

    def forward(self, x):
        phase = math.pi * torch.tanh(self.Wphi(x))
        gate = torch.sin(self.w0 * self.Wg(x) + phase)
        return self.ln(self.Wo(gate * self.Wv(x)))
|
|
class v10Net(nn.Module):
    """Stack of ``n`` v10Layer blocks of width ``h`` followed by a linear read-out.

    The inner gate/value width of every block is ``max(2, int(h * 2 / 3))``.
    """

    def __init__(self, di, do, h, n, w0=30.):
        super().__init__()
        mid = max(2, int(h * 2 / 3))
        # Only the first block sees the raw input width; the rest chain at width h.
        blocks = [v10Layer(di if i == 0 else h, h, mid, w0) for i in range(n)]
        blocks.append(nn.Linear(h if n > 0 else di, do))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        out = x
        for stage in self.layers:
            out = stage(out)
        return out
|
|
| |
|
|
class v15Layer(nn.Module):
    """Dual-phase sine-gated layer.

    Two sine channels share the gate projection ``g = Wg x`` and a bounded
    phase ``phi = tanh(Wphi x)``::

        low  = sin(w0 * g + beta * phi)
        high = sin(2 * w0 * g + gamma * phi)
        core = low * (1 + alpha * high)
        y    = LayerNorm(Wo(core * Wv x))

    ``Wphi`` starts at zero, so ``phi`` is zero for every input at
    initialization.
    """

    def __init__(self, di, do, mid, w0=30., beta=0.05, gamma=0.05, alpha=0.3):
        super().__init__()
        self.Wg = nn.Linear(di, mid, bias=False)
        self.Wv = nn.Linear(di, mid, bias=False)
        self.Wo = nn.Linear(mid, do, bias=True)
        self.Wphi = nn.Linear(di, mid, bias=True)
        self.w0 = w0
        self.beta = beta
        self.gamma = gamma
        self.alpha = alpha
        self.ln = nn.LayerNorm(do)
        gate_bound = math.sqrt(6 / di) / w0
        with torch.no_grad():
            self.Wg.weight.uniform_(-gate_bound, gate_bound)
            nn.init.xavier_uniform_(self.Wv.weight)
            nn.init.xavier_uniform_(self.Wo.weight)
            nn.init.zeros_(self.Wphi.weight)
            nn.init.zeros_(self.Wphi.bias)

    def forward(self, x):
        gate_pre = self.Wg(x)
        phase = torch.tanh(self.Wphi(x))
        low = torch.sin(self.w0 * gate_pre + self.beta * phase)
        high = torch.sin(2 * self.w0 * gate_pre + self.gamma * phase)
        core = low * (1. + self.alpha * high)
        return self.ln(self.Wo(core * self.Wv(x)))

    def get_stats(self, x):
        """Return mean (``phi_m``) and std (``phi_s``) of ``tanh(Wphi x)`` as floats."""
        with torch.no_grad():
            phase = torch.tanh(self.Wphi(x))
            summary = {'phi_m': phase.mean().item(), 'phi_s': phase.std().item()}
        return summary
|
|
class v15Net(nn.Module):
    """Stack of ``n`` v15Layer blocks of width ``h`` followed by a linear read-out.

    The inner gate/value width of every block is ``max(2, int(h * 2 / 3))``;
    each block uses v15Layer's default ``beta``/``gamma``/``alpha``.
    """

    def __init__(self, di, do, h, n, w0=30.):
        super().__init__()
        mid = max(2, int(h * 2 / 3))
        # Only the first block sees the raw input width; the rest chain at width h.
        blocks = [v15Layer(di if i == 0 else h, h, mid, w0) for i in range(n)]
        blocks.append(nn.Linear(h if n > 0 else di, do))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        out = x
        for stage in self.layers:
            out = stage(out)
        return out
|
|
| |
|
|
def np_(m):
    """Return the number of trainable (``requires_grad``) parameters in module ``m``."""
    return sum(p.numel() for p in m.parameters() if p.requires_grad)


def fh(di, do, n, t, cls, **kw):
    """Find the hidden width in [2, 512] whose parameter count is closest to ``t``.

    Binary-searches the width, building ``cls(di, do, width, n, **kw)`` at each
    probe and counting its trainable parameters. The search direction assumes
    the count does not decrease as the width grows.

    Args:
        di, do: input / output feature counts forwarded to ``cls``.
        n: depth argument forwarded to ``cls``.
        t: target trainable-parameter budget.
        cls: model constructor called as ``cls(di, do, width, n, **kw)``.
        **kw: extra keyword arguments forwarded to ``cls``.

    Returns:
        The probed width with the smallest absolute gap to ``t``. On ties the
        earliest probe wins; width 2 is the starting candidate.
    """
    lo, hi = 2, 512
    best = lo
    # Cache the best gap so the incumbent model is not rebuilt on every probe.
    best_gap = abs(np_(cls(di, do, best, n, **kw)) - t)
    while lo <= hi:
        mid = (lo + hi) // 2
        count = np_(cls(di, do, mid, n, **kw))
        gap = abs(count - t)
        if gap < best_gap:
            best, best_gap = mid, gap
        if count < t:
            lo = mid + 1
        else:
            hi = mid - 1
    return best
|
|
def tr_r(m, xt, yt, xe, ye, ep, lr, bs=256):
    """Train ``m`` for regression and return the lowest evaluation MSE observed.

    Uses Adam with a cosine-annealed learning rate (stepped once per epoch),
    shuffled mini-batches of size ``bs`` and gradient-norm clipping at 1.0.
    The model is evaluated on ``(xe, ye)`` every ``max(1, ep // 10)`` epochs
    and once more after the last epoch; the minimum of those MSEs is returned.
    The model is left in eval mode.
    """
    optimizer = torch.optim.Adam(m.parameters(), lr=lr)
    schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=ep)
    n_train = len(xt)
    eval_every = max(1, ep // 10)

    def eval_mse():
        m.eval()
        with torch.no_grad():
            return F.mse_loss(m(xe), ye).item()

    best = float('inf')
    for epoch in range(1, ep + 1):
        m.train()
        order = torch.randperm(n_train)
        for start in range(0, n_train, bs):
            batch = order[start:start + bs]
            optimizer.zero_grad()
            loss = F.mse_loss(m(xt[batch]), yt[batch])
            loss.backward()
            torch.nn.utils.clip_grad_norm_(m.parameters(), 1.)
            optimizer.step()
        schedule.step()
        if epoch % eval_every == 0:
            best = min(best, eval_mse())
    return min(best, eval_mse())
|
|
def tr_c(m, xt, yt, xe, ye, ep, lr, bs=256):
    """Train ``m`` for classification and return the highest evaluation accuracy observed.

    Uses Adam with a cosine-annealed learning rate (stepped once per epoch),
    shuffled mini-batches of size ``bs``, cross-entropy loss and gradient-norm
    clipping at 1.0. Accuracy on ``(xe, ye)`` is measured every
    ``max(1, ep // 10)`` epochs and once more after the last epoch; the
    maximum is returned. The model is left in eval mode.
    """
    optimizer = torch.optim.Adam(m.parameters(), lr=lr)
    schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=ep)
    n_train = len(xt)
    eval_every = max(1, ep // 10)

    def eval_accuracy():
        m.eval()
        with torch.no_grad():
            hits = m(xe).argmax(1) == ye
            return hits.float().mean().item()

    best = 0
    for epoch in range(1, ep + 1):
        m.train()
        order = torch.randperm(n_train)
        for start in range(0, n_train, bs):
            batch = order[start:start + bs]
            optimizer.zero_grad()
            loss = F.cross_entropy(m(xt[batch]), yt[batch])
            loss.backward()
            torch.nn.utils.clip_grad_norm_(m.parameters(), 1.)
            optimizer.step()
        schedule.step()
        if epoch % eval_every == 0:
            best = max(best, eval_accuracy())
    return max(best, eval_accuracy())
|
|
| |
|
|
def d_cx(n=1000):
    """Sample ``n`` points uniformly in [-1, 1)^4 with target exp(sin(x0²+x1²) + sin(x2²+x3²)).

    Returns ``(x, y)`` with shapes ``(n, 4)`` and ``(n, 1)``.
    """
    x = torch.rand(n, 4) * 2 - 1
    first_pair = torch.sin(x[:, 0] ** 2 + x[:, 1] ** 2)
    second_pair = torch.sin(x[:, 2] ** 2 + x[:, 3] ** 2)
    return x, torch.exp(first_pair + second_pair).unsqueeze(1)
def d_ne(n=1000):
    """Sample ``n`` points uniformly in [-1, 1)^2 with target sin(pi*(x0²+x1²)) * cos(3*pi*x0*x1).

    Returns ``(x, y)`` with shapes ``(n, 2)`` and ``(n, 1)``.
    """
    x = torch.rand(n, 2) * 2 - 1
    radial = torch.sin(math.pi * (x[:, 0] ** 2 + x[:, 1] ** 2))
    cross = torch.cos(3 * math.pi * x[:, 0] * x[:, 1])
    return x, (radial * cross).unsqueeze(1)
def d_sp(n=1000):
    """Build a noisy two-arm spiral classification set with exactly ``n`` points.

    Each arm sweeps the angle 0..4*pi while the radius grows from 0.3 to 2;
    the second arm is the first rotated by pi. Gaussian noise with standard
    deviation 0.05 is added and the rows are shuffled.

    For odd ``n`` the first arm (label 0) receives the extra point, so the
    arm rows always match the ``n`` noise rows and the length-``n``
    permutation. For even ``n`` both arms have ``n // 2`` points.

    Returns:
        ``(x, y)``: ``x`` of shape ``(n, 2)`` and integer labels ``y`` of
        shape ``(n,)``.
    """
    n_second = n // 2
    n_first = n - n_second

    def arm(count, rotation):
        angle = torch.linspace(0, 4 * math.pi, count)
        radius = torch.linspace(.3, 2, count)
        turned = angle + rotation
        return torch.stack([radius * torch.cos(turned), radius * torch.sin(turned)], 1)

    clean = torch.cat([arm(n_first, 0.), arm(n_second, math.pi)])
    x = clean + torch.randn(n, 2) * .05
    y = torch.cat([torch.zeros(n_first), torch.ones(n_second)]).long()
    p = torch.randperm(n)
    return x[p], y[p]
def d_ch(n=1000):
    """Sample ``n`` points uniformly in [-1, 1)^2 labelled by the sign of sin(3*pi*x0) * sin(3*pi*x1).

    Returns ``(x, y)`` where ``y`` is an integer tensor of shape ``(n,)``
    holding 1 where the product is positive and 0 otherwise.
    """
    x = torch.rand(n, 2) * 2 - 1
    pattern = torch.sin(3 * math.pi * x[:, 0]) * torch.sin(3 * math.pi * x[:, 1])
    labels = (pattern > 0).long()
    return x, labels
def d_hf(n=1000):
    """Return ``n`` evenly spaced points on [-1, 1] with target sin(20x) + sin(50x) + 0.5*sin(100x).

    Both tensors have shape ``(n, 1)``.
    """
    x = torch.linspace(-1, 1, n).unsqueeze(1)
    signal = torch.sin(20 * x) + torch.sin(50 * x) + .5 * torch.sin(100 * x)
    return x, signal
def d_mm(n=200):
    """Return ``n`` standard-normal inputs ``(n, 8)`` paired with independent standard-normal targets ``(n, 4)``."""
    inputs = torch.randn(n, 8)
    targets = torch.randn(n, 4)
    return inputs, targets
|
|
| |
def d_freq_train(n=1000):
    """Return ``n`` evenly spaced points on [-1, 1] with target sin(2*pi*x); both of shape ``(n, 1)``."""
    grid = torch.linspace(-1, 1, n).unsqueeze(1)
    target = torch.sin(2 * math.pi * grid)
    return grid, target
def d_freq_test(n=1000):
    """Return ``n`` evenly spaced points on [-1, 1] with target sin(10*pi*x); both of shape ``(n, 1)``."""
    grid = torch.linspace(-1, 1, n).unsqueeze(1)
    target = torch.sin(10 * math.pi * grid)
    return grid, target
def d_mixed_train(n=1000):
    """Return ``n`` evenly spaced points on [-1, 1] with target sin(2*pi*x) + sin(4*pi*x); both ``(n, 1)``."""
    grid = torch.linspace(-1, 1, n).unsqueeze(1)
    base = torch.sin(2 * math.pi * grid)
    overtone = torch.sin(4 * math.pi * grid)
    return grid, base + overtone
def d_mixed_test(n=1000):
    """Return ``n`` evenly spaced points on [-1, 1] with target sin(2*pi*x) + sin(20*pi*x); both ``(n, 1)``."""
    grid = torch.linspace(-1, 1, n).unsqueeze(1)
    base = torch.sin(2 * math.pi * grid)
    detail = torch.sin(20 * math.pi * grid)
    return grid, base + detail
|
|
| |
def d_ot(n=800):
    """Sample ``n`` points uniformly in [-1, 1)^2 with target sin(3*pi*x0)*cos(3*pi*x1) + x0*x1.

    Returns ``(x, y)`` with shapes ``(n, 2)`` and ``(n, 1)``.
    """
    x = torch.rand(n, 2) * 2 - 1
    a, b = x[:, 0], x[:, 1]
    target = torch.sin(3 * math.pi * a) * torch.cos(3 * math.pi * b) + a * b
    return x, target.unsqueeze(1)
def d_oe(n=300):
    """Sample ``n`` points uniformly in [1, 2)^2 with target sin(3*pi*x0)*cos(3*pi*x1) + x0*x1.

    Returns ``(x, y)`` with shapes ``(n, 2)`` and ``(n, 1)``.
    """
    x = torch.rand(n, 2) + 1
    a, b = x[:, 0], x[:, 1]
    target = torch.sin(3 * math.pi * a) * torch.cos(3 * math.pi * b) + a * b
    return x, target.unsqueeze(1)
|
|
| |
|
|
def _fill_w0(spec, w0):
    """Return a copy of a kwargs spec with every ``None`` placeholder replaced by ``w0``."""
    return {key: (w0 if value is None else value) for key, value in spec.items()}


def _winner(scores, higher_is_better):
    """Return the key of the best score; ties go to the first key in dict order."""
    pick = max if higher_is_better else min
    return pick(scores, key=scores.get)


def _transfer_experiment(models, depth, budget, make_train, x_test, y_test, with_size):
    """Train every model on one 1-D signal and score it on a different target.

    For each seed the first 800 training points are fitted and the remaining
    points are the held-out set passed to ``tr_r``. The value stored under
    ``'train'`` is therefore the mean best held-out MSE returned by ``tr_r``,
    and ``'test'`` is the mean MSE against ``(x_test, y_test)``.
    """
    results = {}
    for name, (cls, spec) in models.items():
        kw = _fill_w0(spec, 30.)
        h = fh(1, 1, depth, budget, cls, **kw)
        held_out, shifted = [], []
        for seed in SEEDS:
            set_seed(seed)
            x, y = make_train()
            xt, yt, xv, yv = x[:800], y[:800], x[800:], y[800:]
            set_seed(seed + 100)
            mdl = cls(1, 1, h, depth, **kw)
            held_out.append(tr_r(mdl, xt, yt, xv, yv, 300, 1e-3))
            mdl.eval()
            with torch.no_grad():
                shifted.append(F.mse_loss(mdl(x_test), y_test).item())
        entry = {'train': np.mean(held_out), 'test': np.mean(shifted), 'test_std': np.std(shifted)}
        if with_size:
            entry['params'] = np_(cls(1, 1, h, depth, **kw))
            entry['hidden'] = h
        results[name] = entry
    return results


def main(out_path='/app/results_v15.json'):
    """Run every benchmark, print the result tables and save them as JSON.

    Sections, in order: six parameter-matched benchmark tasks, the two
    frequency-transfer experiments, the out-of-distribution test and a grand
    summary with a win count per model.

    The summary takes each row's metric direction (accuracy: higher is
    better; MSE: lower is better) from the task definition. Guessing it from
    the score magnitudes mislabels any regression row whose MSEs all lie
    between 0.001 and 1 with a maximum above 0.5, which picks the worst model
    as the winner.

    Args:
        out_path: destination of the JSON results file. Missing parent
            directories are created.
    """
    from pathlib import Path

    print("="*80)
    print(" v15: DUAL-PHASE DECOMPOSITION + KILLER EXPERIMENT")
    print(" low=sin(ωg+βφ), high=sin(2ωg+γφ), core=low⊙(1+α·high)")
    print(" + Freq Gen: train sin(2πx) → test sin(10πx)")
    print("="*80)

    N = 3
    # A ``None`` value marks a kwarg that is filled with the per-task w0.
    Ms = {'Vanilla': (VanillaMLP, {}), 'SinGLU': (SinGLUNet, {'w0': None}),
          'v10': (v10Net, {'w0': None}), 'v15': (v15Net, {'w0': None})}

    # (name, kind, data fn, in dim, out dim, param budget, epochs, lr, w0, train split)
    tasks = [
        ("Complex", "r", d_cx, 4, 1, 5000, 300, 1e-3, 30., 750),
        ("Nested", "r", d_ne, 2, 1, 3000, 300, 1e-3, 20., 750),
        ("Spiral", "c", d_sp, 2, 2, 3000, 250, 1e-3, 15., 700),
        ("Checker", "c", d_ch, 2, 2, 3000, 250, 1e-3, 20., 700),
        ("HiFreq", "r", d_hf, 1, 1, 8000, 300, 1e-3, 60., 700),
        ("Memorize", "r", d_mm, 8, 4, 5000, 400, 1e-3, 10., 200),
    ]

    R = {}
    maximize = {}  # summary row name -> True when a higher score is better
    for tn, tt, df, di, do, bud, ep, lr, w0, sp in tasks:
        print(f"\n{'━'*80}\n {tn}\n{'━'*80}")
        hs = {mn: fh(di, do, N, bud, mc, **_fill_w0(mk, w0)) for mn, (mc, mk) in Ms.items()}
        tr = {}
        for mn, (mc, mk) in Ms.items():
            kw = _fill_w0(mk, w0)
            h = hs[mn]
            sc = []
            for seed in SEEDS:
                set_seed(seed)
                x, y = df()
                if sp >= len(x):
                    # Split point covers the whole set: train and evaluate on the same data.
                    xt, yt, xe, ye = x, y, x, y
                else:
                    xt, yt, xe, ye = x[:sp], y[:sp], x[sp:], y[sp:]
                set_seed(seed + 100)
                mdl = mc(di, do, h, N, **kw)
                s = tr_r(mdl, xt, yt, xe, ye, ep, lr) if tt == 'r' else tr_c(mdl, xt, yt, xe, ye, ep, lr)
                sc.append(s)
            p = np_(mc(di, do, h, N, **kw))
            tr[mn] = {'mean': np.mean(sc), 'std': np.std(sc), 'scores': sc, 'params': p, 'hidden': h}
        ir = tt == 'r'
        best = _winner({k: v['mean'] for k, v in tr.items()}, not ir)
        met = "MSE ↓" if ir else "Acc ↑"
        print(f"\n {'M':<8} {'H':>3} {'P':>6} {met:>24}")
        print(f" {'─'*44}")
        for mn, r in tr.items():
            m, s = r['mean'], r['std']
            if ir and m < .001:
                ms = f"{m:.2e}±{s:.1e}"
            elif ir:
                ms = f"{m:.4f}±{s:.4f}"
            else:
                ms = f"{m:.1%}±{s:.3f}"
            print(f" {mn:<8} {r['hidden']:>3} {r['params']:>6,} {ms:>24}{' ★' if mn==best else ''}")
        print(f" → {best}")
        R[tn] = tr
        maximize[tn] = not ir

    print(f"\n{'━'*80}")
    print(f" 🔥 KILLER EXPERIMENT 1: Frequency Generalization")
    print(f" Train: sin(2πx) → Test: sin(10πx)")
    print(f" Can the model generalize to unseen frequencies?")
    print(f"{'━'*80}")

    bud_k = 4000
    xte_k, yte_k = d_freq_test()
    fg_res = _transfer_experiment(Ms, N, bud_k, d_freq_train, xte_k, yte_k, with_size=True)

    # NOTE: the "Train MSE" column shows tr_r's best held-out MSE, not a training-set loss.
    print(f"\n {'M':<8} {'H':>3} {'Train MSE':>12} {'Test MSE (10πx)':>18} {'Gap':>8}")
    print(f" {'─'*52}")
    best_fg = min(fg_res, key=lambda k: fg_res[k]['test'])
    for mn, r in fg_res.items():
        gap = r['test'] / max(r['train'], 1e-10)
        print(f" {mn:<8} {r['hidden']:>3} {r['train']:>12.6f} {r['test']:>12.4f}±{r['test_std']:.3f} {gap:>7.1f}x{' ★' if mn==best_fg else ''}")
    print(f" → Best freq generalization: {best_fg}")

    print(f"\n{'━'*80}")
    print(f" 🔥 KILLER EXPERIMENT 2: Mixed Frequency Decomposition")
    print(f" Train: sin(2πx)+sin(4πx) → Test: sin(2πx)+sin(20πx)")
    print(f" Can the model decompose and generalize frequency components?")
    print(f"{'━'*80}")

    xte_m, yte_m = d_mixed_test()
    mf_res = _transfer_experiment(Ms, N, bud_k, d_mixed_train, xte_m, yte_m, with_size=False)

    print(f"\n {'M':<8} {'Train MSE':>12} {'Test MSE (20πx)':>18} {'Gap':>8}")
    print(f" {'─'*44}")
    best_mf = min(mf_res, key=lambda k: mf_res[k]['test'])
    for mn, r in mf_res.items():
        gap = r['test'] / max(r['train'], 1e-10)
        print(f" {mn:<8} {r['train']:>12.6f} {r['test']:>12.4f}±{r['test_std']:.3f} {gap:>7.1f}x{' ★' if mn==best_mf else ''}")
    print(f" → Best mixed freq: {best_mf}")

    print(f"\n{'━'*80}\n OOD: [-1,1] → [1,2]\n{'━'*80}")
    OD = {}
    for mn, (mc, mk) in Ms.items():
        kw = _fill_w0(mk, 20.)
        h = fh(2, 1, N, 5000, mc, **kw)
        ids, ods = [], []
        for seed in SEEDS:
            set_seed(seed)
            xtr, ytr = d_ot()
            # In-distribution check set: 200 fresh draws from the training distribution.
            set_seed(seed + 50)
            xi, yi = d_ot(200)
            set_seed(seed + 50)
            xo, yo = d_oe()
            set_seed(seed + 100)
            mdl = mc(2, 1, h, N, **kw)
            si = tr_r(mdl, xtr, ytr, xi, yi, 300, 1e-3)
            mdl.eval()
            with torch.no_grad():
                so = F.mse_loss(mdl(xo), yo).item()
            ids.append(si)
            ods.append(so)
        OD[mn] = {'id': np.mean(ids), 'ood': np.mean(ods), 'deg': np.mean(ods) / max(np.mean(ids), 1e-10)}
    bo = min(OD, key=lambda k: OD[k]['ood'])
    print(f"\n {'M':<8} {'ID':>10} {'OOD':>10} {'Deg':>7}")
    print(f" {'─'*38}")
    for mn, r in OD.items():
        print(f" {mn:<8} {r['id']:>10.4f} {r['ood']:>10.4f} {r['deg']:>6.1f}x{' ★' if mn==bo else ''}")

    print(f"\n{'='*80}")
    print(f" GRAND SUMMARY: v15 + KILLER EXPERIMENTS")
    print(f"{'='*80}")

    # The three extra rows are all MSE-based, so lower is better.
    R['OOD'] = {mn: {'mean': r['ood']} for mn, r in OD.items()}
    R['FreqGen'] = {mn: {'mean': r['test']} for mn, r in fg_res.items()}
    R['MixedFreq'] = {mn: {'mean': r['test']} for mn, r in mf_res.items()}
    for extra in ('OOD', 'FreqGen', 'MixedFreq'):
        maximize[extra] = False

    wc = {k: 0 for k in Ms}
    print(f"\n {'Task':<14}", end="")
    for mn in Ms:
        print(f" {mn:>10}", end="")
    print(f" {'W':>8}")
    print(f" {'─'*60}")

    for tn, t in R.items():
        sc = {k: v['mean'] for k, v in t.items()}
        ic = maximize[tn]
        w = _winner(sc, ic)
        wc[w] += 1
        row = f" {tn:<14}"
        for mn in Ms:
            s = sc[mn]
            if ic:
                row += f" {s:>9.1%}"
            elif s < .001:
                row += f" {s:>9.2e}"
            else:
                row += f" {s:>9.4f}"
        row += f" ->{w}"
        print(row)

    print(f"\n {'─'*60}")
    for mn, c in sorted(wc.items(), key=lambda x: -x[1]):
        print(f" {mn:<8} {c} wins {'█'*c*3}")

    sv = {'tasks': {}, 'freq_gen': fg_res, 'mixed_freq': mf_res,
          'ood': {mn: {k: float(v) if isinstance(v, (float, np.floating)) else v for k, v in r.items()}
                  for mn, r in OD.items()}}
    for tn, t in R.items():
        sv['tasks'][tn] = {mn: {'mean': float(r['mean']), 'std': float(r.get('std', 0)),
                                'params': r.get('params', 0), 'hidden': r.get('hidden', 0)}
                           for mn, r in t.items()}
    target = Path(out_path)
    # Create the output directory so a long run is not lost to a missing folder.
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open('w', encoding='utf-8') as f:
        json.dump(sv, f, indent=2, default=str)
    print(f"\n Saved.")
|
|
# Run the full experiment suite only when executed as a script.
if __name__ == "__main__":
    main()
|
|