anshdadhich committed (verified)
Commit 76d4be8 · 1 Parent(s): d90f9e7

Add v11: disciplined phase (0.1 scale, tied to features)

Files changed (1): benchmark_v11.py (+281, -0)
benchmark_v11.py ADDED
@@ -0,0 +1,281 @@
+ #!/usr/bin/env python3
+ """
+ v11: SinGLU + DISCIPLINED Phase
+
+ v10 problem: φ(x) = π·tanh(Wφ·x) is too powerful. A phase std of ~0.3 rad
+ destroys frequency stability on high-freq, memorization, and OOD tasks.
+
+ Three surgical fixes (from the critique):
+ 1. Scale the phase DOWN: φ = 0.1·tanh(Wφ·x), not a full ±π swing
+ 2. Tie the phase to features: sin(ω·(Wg·x + φ)), not sin(ω·Wg·x + φ)
+ 3. That's it. No gate, no frequency modulation, no extra paths.
+ """
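+
+ # Side by side (v10 form as described in the docstring above):
+ #   v10: core = sin(ω·Wg·x + π·tanh(Wφ·x))      (additive phase, swings up to ±π)
+ #   v11: core = sin(ω·(Wg·x + 0.1·tanh(Wφ·x)))  (small, bounded nudge in feature space)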
+
+ import torch, torch.nn as nn, torch.nn.functional as F
+ import numpy as np, math, json
+
+ SEEDS = [0, 1, 2]
+ def set_seed(s): torch.manual_seed(s); np.random.seed(s)
+
+ # ── Baselines ──
+
+ class VanillaMLP(nn.Module):
+     def __init__(self, di, do, h, n):
+         super().__init__()
+         L = []; p = di
+         for _ in range(n): L += [nn.Linear(p, h), nn.ReLU()]; p = h
+         L.append(nn.Linear(p, do)); self.net = nn.Sequential(*L)
+     def forward(self, x): return self.net(x)
+
+ class SinGLULayer(nn.Module):
+     def __init__(self, di, do, mid, w0=30.):
+         super().__init__()
+         self.Wg = nn.Linear(di, mid, bias=False); self.Wv = nn.Linear(di, mid, bias=False)
+         self.Wo = nn.Linear(mid, do, bias=True); self.w0 = w0; self.ln = nn.LayerNorm(do)
+         with torch.no_grad():
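+             # assumed SIREN-style rationale: uniform(±√(6/fan_in)) scaled by 1/w0
+             # keeps the pre-activations of sin(w0·Wg·x) in a well-conditioned range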
+             self.Wg.weight.uniform_(-math.sqrt(6/di)/w0, math.sqrt(6/di)/w0)
+             nn.init.xavier_uniform_(self.Wv.weight); nn.init.xavier_uniform_(self.Wo.weight)
+     def forward(self, x): return self.ln(self.Wo(torch.sin(self.w0*self.Wg(x))*self.Wv(x)))
+
+ class SinGLUNet(nn.Module):
+     def __init__(self, di, do, h, n, w0=30.):
+         super().__init__()
+         mid = max(2, int(h*2/3)); L = []; p = di
+         for _ in range(n): L.append(SinGLULayer(p, h, mid, w0)); p = h
+         L.append(nn.Linear(p, do)); self.layers = nn.ModuleList(L)
+     def forward(self, x):
+         for l in self.layers: x = l(x)
+         return x
+
+ # ── v11: Disciplined Phase ──
+
+ class v11Layer(nn.Module):
+     """
+     FIX 1: Scale phase to 0.1 (not a full ±π swing)
+     FIX 2: Phase tied to feature space: sin(ω·(Wg·x + α·tanh(Wφ·x))), α = phase_scale
+
+     core = sin( ω · (Wg·x + 0.1·tanh(Wφ·x)) )
+     y    = LN( Wo( core ⊙ Wv·x ) )
+     """
+     def __init__(self, di, do, mid, w0=30., phase_scale=0.1):
+         super().__init__()
+         self.Wg = nn.Linear(di, mid, bias=False)
+         self.Wv = nn.Linear(di, mid, bias=False)
+         self.Wo = nn.Linear(mid, do, bias=True)
+         self.Wphi = nn.Linear(di, mid, bias=True)  # phase (tied to feature space)
+         self.w0 = w0; self.ps = phase_scale; self.ln = nn.LayerNorm(do)
+         with torch.no_grad():
+             self.Wg.weight.uniform_(-math.sqrt(6/di)/w0, math.sqrt(6/di)/w0)
+             nn.init.xavier_uniform_(self.Wv.weight); nn.init.xavier_uniform_(self.Wo.weight)
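+             # Wφ is zero-initialized: the layer starts as plain SinGLU and only
+             # develops a phase where training pushes it to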
+             nn.init.zeros_(self.Wphi.weight); nn.init.zeros_(self.Wphi.bias)
+
+     def forward(self, x):
+         g = self.Wg(x)
+         phi = self.ps * torch.tanh(self.Wphi(x))   # small, bounded
+         core = torch.sin(self.w0 * (g + phi))      # phase IN feature space
+         return self.ln(self.Wo(core * self.Wv(x)))
+
+     def get_phi(self, x):
+         with torch.no_grad():
+             return self.ps * torch.tanh(self.Wphi(x))
+
+ class v11Net(nn.Module):
+     def __init__(self, di, do, h, n, w0=30.):
+         super().__init__()
+         mid = max(2, int(h*2/3)); L = []; p = di
+         for _ in range(n): L.append(v11Layer(p, h, mid, w0)); p = h
+         L.append(nn.Linear(p, do)); self.layers = nn.ModuleList(L)
+     def forward(self, x):
+         for l in self.layers: x = l(x)
+         return x
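+     # collect each layer's φ activations (consumed by the φ discipline check in main)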
+     def get_all_phi(self, x):
+         P = []; h = x
+         for l in self.layers:
+             if isinstance(l, v11Layer): P.append(l.get_phi(h)); h = l(h)
+             else: h = l(h)
+         return P
+
+ # ── Utils ──
+
+ def np_(m): return sum(p.numel() for p in m.parameters() if p.requires_grad)
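+ # fh: binary-search the hidden width whose parameter count lands closest to the budget t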
+ def fh(di, do, n, t, cls, **kw):
+     lo, hi, b = 2, 512, 2
+     while lo <= hi:
+         mid = (lo+hi)//2; p = np_(cls(di, do, mid, n, **kw))
+         if abs(p-t) < abs(np_(cls(di, do, b, n, **kw))-t): b = mid
+         if p < t: lo = mid+1
+         else: hi = mid-1
+     return b
+
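+ # tr_r: regression trainer; returns the best held-out MSE observed at periodic checkpoints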
+ def tr_r(m, xtr, ytr, xte, yte, ep, lr, bs=256):
+     o = torch.optim.Adam(m.parameters(), lr=lr); s = torch.optim.lr_scheduler.CosineAnnealingLR(o, T_max=ep)
+     best = float('inf'); n = len(xtr)
+     for e in range(ep):
+         m.train(); p = torch.randperm(n)
+         for i in range(0, n, bs):
+             idx = p[i:i+bs]; loss = F.mse_loss(m(xtr[idx]), ytr[idx])
+             o.zero_grad(); loss.backward(); torch.nn.utils.clip_grad_norm_(m.parameters(), 1.); o.step()
+         s.step()
+         if (e+1) % max(1, ep//10) == 0:
+             m.eval()
+             with torch.no_grad(): best = min(best, F.mse_loss(m(xte), yte).item())
+     m.eval()
+     with torch.no_grad(): best = min(best, F.mse_loss(m(xte), yte).item())
+     return best
+
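+ # tr_c: classification twin of tr_r; returns the best held-out accuracy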
+ def tr_c(m, xtr, ytr, xte, yte, ep, lr, bs=256):
+     o = torch.optim.Adam(m.parameters(), lr=lr); s = torch.optim.lr_scheduler.CosineAnnealingLR(o, T_max=ep)
+     best = 0; n = len(xtr)
+     for e in range(ep):
+         m.train(); p = torch.randperm(n)
+         for i in range(0, n, bs):
+             idx = p[i:i+bs]; loss = F.cross_entropy(m(xtr[idx]), ytr[idx])
+             o.zero_grad(); loss.backward(); torch.nn.utils.clip_grad_norm_(m.parameters(), 1.); o.step()
+         s.step()
+         if (e+1) % max(1, ep//10) == 0:
+             m.eval()
+             with torch.no_grad(): best = max(best, (m(xte).argmax(1) == yte).float().mean().item())
+     m.eval()
+     with torch.no_grad(): best = max(best, (m(xte).argmax(1) == yte).float().mean().item())
+     return best
+
+ # ── Data ──
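+ # d_cx: complex 4-D fn · d_ne: nested 2-D fn · d_sp: two spirals · d_ch: checkerboard
+ # d_hf: high-frequency 1-D signal · d_mm: random memorization · d_ot/d_oe: OOD train/eval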
+
+ def d_cx(n=1000):
+     x = torch.rand(n, 4)*2-1; return x, torch.exp(torch.sin(x[:,0]**2+x[:,1]**2)+torch.sin(x[:,2]**2+x[:,3]**2)).unsqueeze(1)
+ def d_ne(n=1000):
+     x = torch.rand(n, 2)*2-1; return x, (torch.sin(math.pi*(x[:,0]**2+x[:,1]**2))*torch.cos(3*math.pi*x[:,0]*x[:,1])).unsqueeze(1)
+ def d_sp(n=1000):
+     t = torch.linspace(0, 4*np.pi, n//2); r = torch.linspace(.3, 2, n//2)
+     x = torch.cat([torch.stack([r*torch.cos(t), r*torch.sin(t)], 1), torch.stack([r*torch.cos(t+np.pi), r*torch.sin(t+np.pi)], 1)])+torch.randn(n, 2)*.05
+     y = torch.cat([torch.zeros(n//2), torch.ones(n//2)]).long(); p = torch.randperm(n); return x[p], y[p]
+ def d_ch(n=1000):
+     x = torch.rand(n, 2)*2-1; return x, ((torch.sin(3*math.pi*x[:,0])*torch.sin(3*math.pi*x[:,1])) > 0).long()
+ def d_hf(n=1000):
+     x = torch.linspace(-1, 1, n).unsqueeze(1); return x, torch.sin(20*x)+torch.sin(50*x)+.5*torch.sin(100*x)
+ def d_mm(n=200): return torch.randn(n, 8), torch.randn(n, 4)
+ def d_ot(n=800):
+     x = torch.rand(n, 2)*2-1; return x, (torch.sin(3*math.pi*x[:,0])*torch.cos(3*math.pi*x[:,1])+x[:,0]*x[:,1]).unsqueeze(1)
+ def d_oe(n=300):
+     x = torch.rand(n, 2)+1; return x, (torch.sin(3*math.pi*x[:,0])*torch.cos(3*math.pi*x[:,1])+x[:,0]*x[:,1]).unsqueeze(1)
+
+ # ── Main ──
+
+ def main():
+     print("="*80)
+     print(" v11: SinGLU + DISCIPLINED Phase")
+     print(" φ scaled to 0.1, tied to feature space: sin(ω·(Wg·x + 0.1·tanh(Wφ·x)))")
+     print("="*80)
+
+     N = 3; Ms = {'Vanilla': (VanillaMLP, {}), 'SinGLU': (SinGLUNet, {'w0': None}), 'v11': (v11Net, {'w0': None})}
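+     # each task tuple: (name, 'r'/'c' task type, data fn, d_in, d_out,
+     #                   param budget, epochs, lr, w0, train/test split index)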
+     tasks = [
+         ("Complex Fn", "r", d_cx, 4, 1, 5000, 300, 1e-3, 30., 750),
+         ("Nested Fn",  "r", d_ne, 2, 1, 3000, 300, 1e-3, 20., 750),
+         ("Spiral",     "c", d_sp, 2, 2, 3000, 250, 1e-3, 15., 700),
+         ("Checker",    "c", d_ch, 2, 2, 3000, 250, 1e-3, 20., 700),
+         ("High-Freq",  "r", d_hf, 1, 1, 8000, 300, 1e-3, 60., 700),
+         ("Memorize",   "r", d_mm, 8, 4, 5000, 400, 1e-3, 10., 200),
+     ]
+
+     R = {}; PH = {}
+     for tn, tt, df, di, do, bud, ep, lr, w0, sp in tasks:
+         print(f"\n{'━'*80}\n {tn} | ~{bud:,}p\n{'━'*80}")
+         hs = {mn: fh(di, do, N, bud, mc, **{k: (w0 if v is None else v) for k, v in mk.items()}) for mn, (mc, mk) in Ms.items()}
+         tr = {}
+         for mn, (mc, mk) in Ms.items():
+             kw = {k: (w0 if v is None else v) for k, v in mk.items()}; h = hs[mn]; sc = []
+             for seed in SEEDS:
+                 set_seed(seed); x, y = df()
+                 if sp >= len(x): xt, yt, xe, ye = x, y, x, y
+                 else: xt, yt, xe, ye = x[:sp], y[:sp], x[sp:], y[sp:]
+                 set_seed(seed+100); mdl = mc(di, do, h, N, **kw)
+                 s = tr_r(mdl, xt, yt, xe, ye, ep, lr) if tt == 'r' else tr_c(mdl, xt, yt, xe, ye, ep, lr)
+                 sc.append(s)
+                 if mn == 'v11' and seed == SEEDS[-1]:
+                     mdl.eval()
+                     with torch.no_grad():
+                         pp = mdl.get_all_phi(xe[:100]); ap = torch.cat([p.flatten() for p in pp])
+                         PH[tn] = {'m': ap.mean().item(), 's': ap.std().item(), 'mn': ap.min().item(), 'mx': ap.max().item()}
+             p = np_(mc(di, do, h, N, **kw))
+             tr[mn] = {'mean': np.mean(sc), 'std': np.std(sc), 'scores': sc, 'params': p, 'hidden': h}
+
+         ir = tt == 'r'
+         best = min(tr, key=lambda k: tr[k]['mean']) if ir else max(tr, key=lambda k: tr[k]['mean'])
+         met = "MSE ↓" if ir else "Acc ↑"
+         print(f"\n {'M':<8} {'H':>3} {'P':>6} {met+' (mean±std)':>26}")
+         print(f" {'─'*46}")
+         for mn, r in tr.items():
+             m, s = r['mean'], r['std']
+             ms = f"{m:.2e}±{s:.1e}" if (ir and m < .001) else (f"{m:.4f}±{s:.4f}" if ir else f"{m:.1%}±{s:.3f}")
+             print(f" {mn:<8} {r['hidden']:>3} {r['params']:>6,} {ms:>26}{' ★' if mn == best else ''}")
+         print(f" → {best}")
+         if tn in PH: d = PH[tn]; print(f" φ: std={d['s']:.4f} range=[{d['mn']:.3f},{d['mx']:.3f}]")
+         R[tn] = tr
+
+     # OOD
+     print(f"\n{'━'*80}\n OOD: [-1,1] → [1,2]\n{'━'*80}")
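+     # train on [-1,1]², evaluate on [1,2]²; 'Deg' = mean OOD MSE / mean ID MSE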
+     OD = {}
+     for mn, (mc, mk) in Ms.items():
+         kw = {k: (20. if v is None else v) for k, v in mk.items()}; h = fh(2, 1, N, 5000, mc, **kw); ids, ods = [], []
+         for seed in SEEDS:
+             set_seed(seed); xtr, ytr = d_ot()
+             set_seed(seed+50); xi = torch.rand(200, 2)*2-1; yi = (torch.sin(3*math.pi*xi[:,0])*torch.cos(3*math.pi*xi[:,1])+xi[:,0]*xi[:,1]).unsqueeze(1)
+             set_seed(seed+50); xo, yo = d_oe()
+             set_seed(seed+100); mdl = mc(2, 1, h, N, **kw)
+             si = tr_r(mdl, xtr, ytr, xi, yi, 300, 1e-3); mdl.eval()
+             with torch.no_grad(): so = F.mse_loss(mdl(xo), yo).item()
+             ids.append(si); ods.append(so)
+         OD[mn] = {'id': np.mean(ids), 'ood': np.mean(ods), 'deg': np.mean(ods)/max(np.mean(ids), 1e-10), 'p': np_(mc(2, 1, h, N, **kw)), 'is': np.std(ids), 'os': np.std(ods)}
+
+     bo = min(OD, key=lambda k: OD[k]['ood'])
+     print(f"\n {'M':<8} {'ID':>12} {'OOD':>12} {'Deg':>7}")
+     print(f" {'─'*42}")
+     for mn, r in OD.items():
+         print(f" {mn:<8} {r['id']:>8.4f}±{r['is']:.3f} {r['ood']:>8.4f}±{r['os']:.3f} {r['deg']:>6.1f}x{' ★' if mn == bo else ''}")
+     print(f" → {bo}")
+     R['OOD'] = {mn: {'mean': r['ood'], 'std': r['os']} for mn, r in OD.items()}
+
+     # Summary
+     print(f"\n{'='*80}\n SUMMARY: v11 vs SinGLU vs Vanilla\n{'='*80}")
+     wc = {k: 0 for k in Ms}
+     print(f"\n {'Task':<14}", end="")
+     for mn in Ms: print(f" {mn:>12}", end="")
+     print(f" {'W':>8}")
+     print(f" {'─'*50}")
+     for tn, t in R.items():
+         sc = {k: v['mean'] for k, v in t.items()}; mx = max(sc.values())
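+         # ic heuristic: a score set in (0.5, 1] with no negatives is read as accuracy;
+         # anything with a sub-0.001 value is read as MSE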
+         ic = mx > .5 and mx <= 1 and min(sc.values()) >= 0
+         if min(sc.values()) < .001: ic = False
+         w = min(sc, key=sc.get) if (tn == 'OOD' or not ic) else max(sc, key=sc.get)
+         wc[w] += 1
+         row = f" {tn:<14}"
+         for mn in Ms:
+             s = sc[mn]
+             if ic: row += f" {s:>11.1%}"
+             elif s < .001: row += f" {s:>11.2e}"
+             else: row += f" {s:>11.4f}"
+         row += f" -> {w}"; print(row)
+     print(f"\n {'─'*50}")
+     for mn, c in sorted(wc.items(), key=lambda x: -x[1]):
+         print(f" {mn:<8} {c} {'█'*c*4}")
+
+     # Compare v10 vs v11 φ
+     print(f"\n φ DISCIPLINE CHECK:")
+     print(f" {'Task':<14} {'v11 φ std':>10} {'v10 was':>10} {'Change':>10}")
+     print(f" {'─'*46}")
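+     # φ std values hard-coded from the earlier v10 run, for comparison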
+     v10_stds = {'Complex Fn': .192, 'Nested Fn': .142, 'Spiral': .242, 'Checker': .207, 'High-Freq': .321, 'Memorize': .206}
+     for tn, d in PH.items():
+         v10s = v10_stds.get(tn, 0)
+         change = f"{d['s']/v10s:.1%}" if v10s > 0 else "N/A"
+         print(f" {tn:<14} {d['s']:>10.4f} {v10s:>10.3f} {change:>10}")
+
+     sv = {'tasks': {}, 'ood': {}, 'phi': PH}
+     for tn, t in R.items():
+         sv['tasks'][tn] = {mn: {'mean': float(r['mean']), 'std': float(r.get('std', 0)),
+                                 'scores': [float(s) for s in r.get('scores', [r['mean']])],
+                                 'params': r.get('params', 0), 'hidden': r.get('hidden', 0)} for mn, r in t.items()}
+     sv['ood'] = {mn: {k: float(v) if isinstance(v, (float, np.floating)) else v for k, v in r.items()} for mn, r in OD.items()}
+     with open('/app/results_v11.json', 'w') as f: json.dump(sv, f, indent=2, default=str)
+     print(f"\n Saved to /app/results_v11.json")
+
+ if __name__ == "__main__": main()