anshdadhich committed on
Commit
16b1a41
·
verified ·
1 Parent(s): c673cd5

v13: aligned phase + correlation analysis

Browse files
Files changed (1) hide show
  1. benchmark_v13.py +301 -0
benchmark_v13.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ v13: ALIGNED PHASE — phase along signal axes, not independent
4
+
5
+ v10: sin(ω·g + π·tanh(Wφ·x)) ← additive but independent → chaotic
6
+ v12: sin(ω · g·(1+0.2·φ)) ← frequency modulation → drift
7
+ v13: sin(ω·g + 0.1·g·tanh(Wφ·x)) ← additive phase, aligned to g → stable
8
+
9
+ The key: φ ∝ g. Phase only shifts where signal exists.
10
+ No frequency drift (ω stays fixed). No independent noise.
11
+ """
12
+
13
+ import torch, torch.nn as nn, torch.nn.functional as F
14
+ import numpy as np, math, json
15
+
16
+ SEEDS = [0, 1, 2]
17
def set_seed(s):
    """Seed both the torch and the numpy global random generators with ``s``."""
    torch.manual_seed(s)
    np.random.seed(s)
18
+
19
class VanillaMLP(nn.Module):
    """ReLU MLP baseline.

    Builds ``n`` stages of ``Linear -> ReLU`` (the first maps ``di -> h``,
    the rest ``h -> h``) followed by a final ``Linear`` to ``do`` outputs.
    With ``n == 0`` the network is a single ``Linear(di, do)``.
    """

    def __init__(self, di, do, h, n):
        super().__init__()
        # widths[i] is the input width of stage i; widths[-1] feeds the head.
        widths = [di] + [h] * n
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.ReLU())
        stack.append(nn.Linear(widths[-1], do))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.net(x)
25
+
26
class SinGLULayer(nn.Module):
    """Sine-gated linear unit: ``LN(Wo(sin(w0 * Wg x) * Wv x))``.

    ``Wg`` and ``Wv`` are bias-free projections ``di -> mid``; ``Wo`` maps
    ``mid -> do`` with a bias and is followed by a LayerNorm over ``do``.
    """

    def __init__(self, di, do, mid, w0=30.):
        super().__init__()
        self.Wg = nn.Linear(di, mid, bias=False)
        self.Wv = nn.Linear(di, mid, bias=False)
        self.Wo = nn.Linear(mid, do, bias=True)
        self.w0 = w0
        self.ln = nn.LayerNorm(do)
        # Gate weights are drawn uniformly from +-sqrt(6/di)/w0, so the
        # argument of the sine is scaled back up by w0 in forward().
        bound = math.sqrt(6 / di) / w0
        with torch.no_grad():
            self.Wg.weight.uniform_(-bound, bound)
            nn.init.xavier_uniform_(self.Wv.weight)
            nn.init.xavier_uniform_(self.Wo.weight)

    def forward(self, x):
        """Return the layer-normalised, sine-gated projection of ``x``."""
        gate = torch.sin(self.w0 * self.Wg(x))
        value = self.Wv(x)
        return self.ln(self.Wo(gate * value))
35
+
36
class SinGLUNet(nn.Module):
    """Stack of ``n`` SinGLULayer blocks followed by a linear output head.

    The gate/value width of every block is ``max(2, int(h * 2 / 3))``.
    """

    def __init__(self, di, do, h, n, w0=30.):
        super().__init__()
        mid = max(2, int(h * 2 / 3))
        # in_dims[i] is the input width of block i; in_dims[n] feeds the head.
        in_dims = [di] + [h] * n
        blocks = [SinGLULayer(width, h, mid, w0) for width in in_dims[:n]]
        blocks.append(nn.Linear(in_dims[n], do))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Run ``x`` through every block and the output head in order."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out
44
+
45
+ # v10 (free phase) for comparison
46
class v10Layer(nn.Module):
    """v10 "free phase" layer: ``LN(Wo(sin(w0 * Wg x + pi * tanh(Wphi x)) * Wv x))``.

    The phase projection ``Wphi`` (weight and bias) starts at zero, so the
    phase term is zero at initialisation.
    """

    def __init__(self, di, do, mid, w0=30.):
        super().__init__()
        self.Wg = nn.Linear(di, mid, bias=False)
        self.Wv = nn.Linear(di, mid, bias=False)
        self.Wo = nn.Linear(mid, do, bias=True)
        self.Wphi = nn.Linear(di, mid, bias=True)
        self.w0 = w0
        self.ln = nn.LayerNorm(do)
        bound = math.sqrt(6 / di) / w0
        with torch.no_grad():
            self.Wg.weight.uniform_(-bound, bound)
            nn.init.xavier_uniform_(self.Wv.weight)
            nn.init.xavier_uniform_(self.Wo.weight)
            nn.init.zeros_(self.Wphi.weight)
            nn.init.zeros_(self.Wphi.bias)

    def forward(self, x):
        """Return the layer output for ``x``."""
        # Phase is bounded to (-pi, pi) by the tanh and does not depend on g.
        phase = math.pi * torch.tanh(self.Wphi(x))
        gate = torch.sin(self.w0 * self.Wg(x) + phase)
        return self.ln(self.Wo(gate * self.Wv(x)))
59
+
60
class v10Net(nn.Module):
    """Stack of ``n`` v10Layer blocks followed by a linear output head.

    The gate/value width of every block is ``max(2, int(h * 2 / 3))``.
    """

    def __init__(self, di, do, h, n, w0=30.):
        super().__init__()
        mid = max(2, int(h * 2 / 3))
        # in_dims[i] is the input width of block i; in_dims[n] feeds the head.
        in_dims = [di] + [h] * n
        blocks = [v10Layer(width, h, mid, w0) for width in in_dims[:n]]
        blocks.append(nn.Linear(in_dims[n], do))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Run ``x`` through every block and the output head in order."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out
68
+
69
+ # v12 (FM) for comparison
70
class v12Layer(nn.Module):
    """v12 "FM" layer: ``LN(Wo(sin(w0 * g * (1 + 0.2 * tanh(Wphi x))) * Wv x))``.

    Here ``g = Wg x``. ``Wphi`` (weight and bias) starts at zero, so the
    multiplicative modulation factor is exactly 1 at initialisation.
    """

    def __init__(self, di, do, mid, w0=30.):
        super().__init__()
        self.Wg = nn.Linear(di, mid, bias=False)
        self.Wv = nn.Linear(di, mid, bias=False)
        self.Wo = nn.Linear(mid, do, bias=True)
        self.Wphi = nn.Linear(di, mid, bias=True)
        self.w0 = w0
        self.ln = nn.LayerNorm(do)
        bound = math.sqrt(6 / di) / w0
        with torch.no_grad():
            self.Wg.weight.uniform_(-bound, bound)
            nn.init.xavier_uniform_(self.Wv.weight)
            nn.init.xavier_uniform_(self.Wo.weight)
            nn.init.zeros_(self.Wphi.weight)
            nn.init.zeros_(self.Wphi.bias)

    def forward(self, x):
        """Return the layer output for ``x``."""
        g = self.Wg(x)
        modulation = torch.tanh(self.Wphi(x))
        # The sine argument is scaled by a factor in (0.8, 1.2).
        gate = torch.sin(self.w0 * g * (1. + 0.2 * modulation))
        return self.ln(self.Wo(gate * self.Wv(x)))
83
+
84
class v12Net(nn.Module):
    """Stack of ``n`` v12Layer blocks followed by a linear output head.

    The gate/value width of every block is ``max(2, int(h * 2 / 3))``.
    """

    def __init__(self, di, do, h, n, w0=30.):
        super().__init__()
        mid = max(2, int(h * 2 / 3))
        # in_dims[i] is the input width of block i; in_dims[n] feeds the head.
        in_dims = [di] + [h] * n
        blocks = [v12Layer(width, h, mid, w0) for width in in_dims[:n]]
        blocks.append(nn.Linear(in_dims[n], do))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Run ``x`` through every block and the output head in order."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out
92
+
93
+ # v13: ALIGNED PHASE
94
class v13Layer(nn.Module):
    """
    v13 "aligned phase" layer.

        g    = Wg·x
        φ    = alpha · g · tanh(Wφ·x)      <- phase scaled by the signal g
        core = sin(ω·g + φ)                <- additive phase term
        y    = LN(Wo(core ⊙ Wv·x))

    Wφ (weight and bias) is zero-initialised, so φ is exactly zero for a
    freshly constructed layer.
    """

    def __init__(self, di, do, mid, w0=30., alpha=0.1):
        super().__init__()
        self.Wg = nn.Linear(di, mid, bias=False)
        self.Wv = nn.Linear(di, mid, bias=False)
        self.Wo = nn.Linear(mid, do, bias=True)
        self.Wphi = nn.Linear(di, mid, bias=True)
        self.w0 = w0
        self.a = alpha
        self.ln = nn.LayerNorm(do)
        bound = math.sqrt(6 / di) / w0
        with torch.no_grad():
            self.Wg.weight.uniform_(-bound, bound)
            nn.init.xavier_uniform_(self.Wv.weight)
            nn.init.xavier_uniform_(self.Wo.weight)
            nn.init.zeros_(self.Wphi.weight)
            nn.init.zeros_(self.Wphi.bias)

    def forward(self, x):
        """Return ``LN(Wo(sin(w0*g + phi) * Wv x))`` for input ``x``."""
        g = self.Wg(x)
        phi = self.a * g * torch.tanh(self.Wphi(x))  # aligned to signal
        core = torch.sin(self.w0 * g + phi)          # additive phase
        return self.ln(self.Wo(core * self.Wv(x)))

    def get_corr(self, x):
        """Return the Pearson correlation between g and phi as a 0-d tensor.

        All elements of ``g`` and ``phi`` (every sample and every unit) are
        pooled into one vector each before correlating, so this is a single
        global coefficient rather than a per-unit average.

        The result is always a tensor, including the degenerate cases (empty
        input, or zero variance in ``g`` or ``phi`` such as right after
        initialisation when ``phi`` is identically zero), where it is 0.
        Callers may therefore call ``.item()`` unconditionally.
        """
        with torch.no_grad():
            g = self.Wg(x)
            phi = self.a * g * torch.tanh(self.Wphi(x))
            gf = g.flatten()
            pf = phi.flatten()
            if gf.numel() == 0:
                return gf.new_zeros(())
            gc = gf - gf.mean()
            pc = pf - pf.mean()
            # Covariance and both standard deviations use the same 1/n
            # normalisation, so perfectly correlated data yields 1 rather
            # than (n-1)/n.
            g_std = torch.sqrt((gc * gc).mean())
            p_std = torch.sqrt((pc * pc).mean())
            denom = g_std * p_std
            if denom == 0:
                return gf.new_zeros(())
            return (gc * pc).mean() / (denom + 1e-8)
125
+
126
class v13Net(nn.Module):
    """Stack of ``n`` v13Layer blocks followed by a linear output head.

    The gate/value width of every block is ``max(2, int(h * 2 / 3))``.
    """

    def __init__(self, di, do, h, n, w0=30.):
        super().__init__()
        mid = max(2, int(h * 2 / 3))
        blocks = []
        width = di
        for _ in range(n):
            blocks.append(v13Layer(width, h, mid, w0))
            width = h
        blocks.append(nn.Linear(width, do))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Run ``x`` through every block and the output head in order."""
        for layer in self.layers:
            x = layer(x)
        return x

    def get_corrs(self, x):
        """Return one g/phi correlation (Python float) per v13Layer.

        Each layer is probed with the activations produced by the layers
        before it. ``float()`` is used instead of ``.item()`` so the method
        works whether ``get_corr`` hands back a tensor or a plain number,
        and the whole pass runs without building an autograd graph.
        """
        corrs = []
        h = x
        with torch.no_grad():
            for layer in self.layers:
                if isinstance(layer, v13Layer):
                    corrs.append(float(layer.get_corr(h)))
                h = layer(h)
        return corrs
140
+
141
+ # Utils
142
def np_(m):
    """Return the number of trainable scalar parameters in module ``m``."""
    total = 0
    for param in m.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
143
def fh(di, do, n, t, cls, **kw):
    """Find the hidden width in [2, 512] whose parameter count is closest to ``t``.

    Binary-searches the width ``h`` passed to ``cls(di, do, h, n, **kw)``,
    steering by whether the candidate's trainable parameter count is below
    the target, and remembers the candidate with the smallest absolute gap
    (the earlier candidate wins ties).

    The gap of the current best width is cached rather than re-instantiating
    that model on every step, which halves the number of models built. The
    selected width is unchanged; only the amount of global RNG state consumed
    by model construction differs.
    """
    lo, hi = 2, 512
    b = lo
    best_gap = abs(np_(cls(di, do, b, n, **kw)) - t)
    while lo <= hi:
        mid = (lo + hi) // 2
        p = np_(cls(di, do, mid, n, **kw))
        gap = abs(p - t)
        if gap < best_gap:
            b, best_gap = mid, gap
        if p < t:
            lo = mid + 1
        else:
            hi = mid - 1
    return b
151
def tr_r(m, xt, yt, xe, ye, ep, lr, bs=256):
    """Train ``m`` for regression and return the lowest evaluation MSE seen.

    Uses Adam with a cosine-annealed learning rate (``T_max=ep``), shuffled
    mini-batches of size ``bs``, MSE loss and gradient-norm clipping at 1.
    The model is evaluated on ``(xe, ye)`` every ``max(1, ep // 10)`` epochs
    and once more after training; the minimum of those values is returned.
    The model is left in eval mode.
    """
    optimizer = torch.optim.Adam(m.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=ep)
    best = float('inf')
    n = len(xt)
    eval_every = max(1, ep // 10)

    def eval_mse():
        m.eval()
        with torch.no_grad():
            return F.mse_loss(m(xe), ye).item()

    for epoch in range(1, ep + 1):
        m.train()
        order = torch.randperm(n)
        for start in range(0, n, bs):
            batch = order[start:start + bs]
            loss = F.mse_loss(m(xt[batch]), yt[batch])
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(m.parameters(), 1.)
            optimizer.step()
        scheduler.step()
        if epoch % eval_every == 0:
            best = min(best, eval_mse())

    # Final evaluation, performed regardless of the checkpoint schedule.
    return min(best, eval_mse())
166
def tr_c(m, xt, yt, xe, ye, ep, lr, bs=256):
    """Train ``m`` for classification and return the highest evaluation accuracy seen.

    Uses Adam with a cosine-annealed learning rate (``T_max=ep``), shuffled
    mini-batches of size ``bs``, cross-entropy loss and gradient-norm
    clipping at 1. Accuracy on ``(xe, ye)`` is measured every
    ``max(1, ep // 10)`` epochs and once more after training; the maximum of
    those values is returned. The model is left in eval mode.
    """
    optimizer = torch.optim.Adam(m.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=ep)
    best = 0
    n = len(xt)
    eval_every = max(1, ep // 10)

    def eval_accuracy():
        m.eval()
        with torch.no_grad():
            hits = m(xe).argmax(1) == ye
            return hits.float().mean().item()

    for epoch in range(1, ep + 1):
        m.train()
        order = torch.randperm(n)
        for start in range(0, n, bs):
            batch = order[start:start + bs]
            loss = F.cross_entropy(m(xt[batch]), yt[batch])
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(m.parameters(), 1.)
            optimizer.step()
        scheduler.step()
        if epoch % eval_every == 0:
            best = max(best, eval_accuracy())

    # Final evaluation, performed regardless of the checkpoint schedule.
    return max(best, eval_accuracy())
181
+
182
+ # Data
183
def d_cx(n=1000):
    """"Complex" regression data: x ~ U[-1, 1)^4, y = exp(sin(x0²+x1²) + sin(x2²+x3²)).

    Returns ``(x, y)`` with shapes ``(n, 4)`` and ``(n, 1)``.
    """
    x = torch.rand(n, 4) * 2 - 1
    first_pair = torch.sin(x[:, 0] ** 2 + x[:, 1] ** 2)
    second_pair = torch.sin(x[:, 2] ** 2 + x[:, 3] ** 2)
    y = torch.exp(first_pair + second_pair).unsqueeze(1)
    return x, y
184
def d_ne(n=1000):
    """"Nested" regression data: x ~ U[-1, 1)^2, y = sin(π(x0²+x1²)) · cos(3π·x0·x1).

    Returns ``(x, y)`` with shapes ``(n, 2)`` and ``(n, 1)``.
    """
    x = torch.rand(n, 2) * 2 - 1
    radial = torch.sin(math.pi * (x[:, 0] ** 2 + x[:, 1] ** 2))
    cross = torch.cos(3 * math.pi * x[:, 0] * x[:, 1])
    y = (radial * cross).unsqueeze(1)
    return x, y
185
def d_sp(n=1000):
    """Two-arm spiral classification data with Gaussian noise (std 0.05).

    Arm 0 follows ``(r cos t, r sin t)`` and arm 1 the same curve rotated by
    π, with ``t`` in [0, 4π] and ``r`` in [0.3, 2]. Returns shuffled
    ``(x, y)`` with shapes ``(n, 2)`` and ``(n,)``; ``y`` is int64 with
    label 0 for arm 0 and 1 for arm 1.

    For odd ``n`` arm 0 receives the extra sample so that exactly ``n``
    points are returned (previously the noise tensor had ``n`` rows but only
    ``2 * (n // 2)`` points existed, which raised a shape error). Even ``n``
    produces the same samples as before.
    """
    n_second = n // 2
    n_first = n - n_second
    arms = []
    for count, shift in ((n_first, 0.0), (n_second, np.pi)):
        t = torch.linspace(0, 4 * np.pi, count)
        r = torch.linspace(.3, 2, count)
        arms.append(torch.stack([r * torch.cos(t + shift), r * torch.sin(t + shift)], 1))
    x = torch.cat(arms) + torch.randn(n, 2) * .05
    y = torch.cat([torch.zeros(n_first), torch.ones(n_second)]).long()
    p = torch.randperm(n)
    return x[p], y[p]
189
def d_ch(n=1000):
    """Checkerboard classification data on U[-1, 1)^2.

    The label is 1 where ``sin(3π·x0) · sin(3π·x1) > 0`` and 0 otherwise.
    Returns ``(x, y)`` with shapes ``(n, 2)`` and ``(n,)``; ``y`` is int64.
    """
    x = torch.rand(n, 2) * 2 - 1
    pattern = torch.sin(3 * math.pi * x[:, 0]) * torch.sin(3 * math.pi * x[:, 1])
    labels = (pattern > 0).long()
    return x, labels
190
def d_hf(n=1000):
    """High-frequency 1-D regression data: y = sin(20x) + sin(50x) + 0.5·sin(100x).

    ``x`` is an evenly spaced grid on [-1, 1] of shape ``(n, 1)``; ``y`` has
    the same shape. Samples are returned in ascending ``x`` order (no
    shuffling), so a caller that splits by index gets contiguous regions.
    """
    x = torch.linspace(-1, 1, n).unsqueeze(1)
    y = torch.sin(20 * x) + torch.sin(50 * x) + .5 * torch.sin(100 * x)
    return x, y
191
def d_mm(n=200):
    """Memorisation data: independent standard-normal inputs ``(n, 8)`` and targets ``(n, 4)``."""
    inputs = torch.randn(n, 8)
    targets = torch.randn(n, 4)
    return inputs, targets
192
def d_ot(n=800):
    """Regression data on U[-1, 1)^2 with y = sin(3π·x0)·cos(3π·x1) + x0·x1.

    Returns ``(x, y)`` with shapes ``(n, 2)`` and ``(n, 1)``.
    """
    x = torch.rand(n, 2) * 2 - 1
    wave = torch.sin(3 * math.pi * x[:, 0]) * torch.cos(3 * math.pi * x[:, 1])
    y = (wave + x[:, 0] * x[:, 1]).unsqueeze(1)
    return x, y
193
def d_oe(n=300):
    """Regression data on U[1, 2)^2 with y = sin(3π·x0)·cos(3π·x1) + x0·x1.

    Returns ``(x, y)`` with shapes ``(n, 2)`` and ``(n, 1)``.
    """
    x = torch.rand(n, 2) + 1
    wave = torch.sin(3 * math.pi * x[:, 0]) * torch.cos(3 * math.pi * x[:, 1])
    y = (wave + x[:, 0] * x[:, 1]).unsqueeze(1)
    return x, y
194
+
195
def main():
    """Run the v13 benchmark, print the result tables and save a JSON report.

    For every task each model is sized with ``fh`` to the task's parameter
    budget, trained once per seed in ``SEEDS`` and scored with ``tr_r``
    (regression, lower is better) or ``tr_c`` (classification, higher is
    better). For the v13 model trained with the last seed the per-layer
    g/phi correlations are recorded on the first 100 evaluation samples.
    An OOD experiment then trains on ``d_ot`` and evaluates on ``d_oe``.
    Results are written to ``/app/results_v13.json``.

    The summary decides "higher is better" from the task kind recorded in
    the task table, not from the magnitude of the scores, so a regression
    task whose MSE falls in (0.5, 1] is still ranked by minimum and printed
    as a number rather than a percentage.
    """
    print("="*80)
    print(" v13: ALIGNED PHASE | sin(ω·g + 0.1·g·tanh(Wφ·x))")
    print(" + corr(g,φ) analysis | vs SinGLU, v10(free), v12(FM)")
    print("="*80)

    N = 3  # number of hidden blocks in every model
    # A kwarg value of None is a placeholder filled with the task's w0.
    Ms = {'Vanilla': (VanillaMLP, {}), 'SinGLU': (SinGLUNet, {'w0': None}),
          'v10:free': (v10Net, {'w0': None}), 'v12:FM': (v12Net, {'w0': None}), 'v13': (v13Net, {'w0': None})}

    def resolve(mk, w0):
        # Substitute the task-specific w0 for every None placeholder.
        return {k: (w0 if v is None else v) for k, v in mk.items()}

    # (name, kind 'r'/'c', data fn, in dim, out dim, param budget,
    #  epochs, lr, w0, number of leading samples used for training)
    tasks = [
        ("Complex", "r", d_cx, 4, 1, 5000, 300, 1e-3, 30., 750),
        ("Nested", "r", d_ne, 2, 1, 3000, 300, 1e-3, 20., 750),
        ("Spiral", "c", d_sp, 2, 2, 3000, 250, 1e-3, 15., 700),
        ("Checker", "c", d_ch, 2, 2, 3000, 250, 1e-3, 20., 700),
        ("HiFreq", "r", d_hf, 1, 1, 8000, 300, 1e-3, 60., 700),
        ("Memorize", "r", d_mm, 8, 4, 5000, 400, 1e-3, 10., 200),
    ]
    R = {}
    CORR = {}
    kinds = {}  # task name -> 'r' or 'c', consulted by the summary
    for tn, tt, df, di, do, bud, ep, lr, w0, sp in tasks:
        print(f"\n{'━'*80}\n {tn}\n{'━'*80}")
        kinds[tn] = tt
        ir = tt == 'r'
        hs = {mn: fh(di, do, N, bud, mc, **resolve(mk, w0)) for mn, (mc, mk) in Ms.items()}
        tr = {}
        for mn, (mc, mk) in Ms.items():
            kw = resolve(mk, w0)
            h = hs[mn]
            sc = []
            for seed in SEEDS:
                set_seed(seed)
                x, y = df()
                # When the split covers the whole set, train and evaluate
                # on the same data (used by the memorisation task).
                if sp >= len(x):
                    xt, yt, xe, ye = x, y, x, y
                else:
                    xt, yt, xe, ye = x[:sp], y[:sp], x[sp:], y[sp:]
                set_seed(seed + 100)
                mdl = mc(di, do, h, N, **kw)
                if ir:
                    s = tr_r(mdl, xt, yt, xe, ye, ep, lr)
                else:
                    s = tr_c(mdl, xt, yt, xe, ye, ep, lr)
                sc.append(s)
                if mn == 'v13' and seed == SEEDS[-1]:
                    mdl.eval()
                    CORR[tn] = mdl.get_corrs(xe[:100])
            p = np_(mc(di, do, h, N, **kw))
            tr[mn] = {'mean': np.mean(sc), 'std': np.std(sc), 'scores': sc, 'params': p, 'hidden': h}
        if ir:
            best = min(tr, key=lambda k: tr[k]['mean'])
        else:
            best = max(tr, key=lambda k: tr[k]['mean'])
        met = "MSE ↓" if ir else "Acc ↑"
        print(f"\n {'M':<10} {'H':>3} {'P':>6} {met:>24}")
        print(f" {'─'*46}")
        for mn, r in tr.items():
            m = r['mean']
            s = r['std']
            if ir and m < .001:
                ms = f"{m:.2e}±{s:.1e}"
            elif ir:
                ms = f"{m:.4f}±{s:.4f}"
            else:
                ms = f"{m:.1%}±{s:.3f}"
            print(f" {mn:<10} {r['hidden']:>3} {r['params']:>6,} {ms:>24}{' ★' if mn==best else ''}")
        print(f" → {best}")
        if tn in CORR:
            print(f" v13 corr(g,φ) per layer: {['%.3f'%c for c in CORR[tn]]}")
        R[tn] = tr

    # OOD
    print(f"\n{'━'*80}\n OOD\n{'━'*80}")
    OD = {}
    for mn, (mc, mk) in Ms.items():
        kw = resolve(mk, 20.)
        h = fh(2, 1, N, 5000, mc, **kw)
        ids, ods = [], []
        for seed in SEEDS:
            set_seed(seed)
            xtr, ytr = d_ot()
            # In-distribution hold-out: same generator as the training set.
            set_seed(seed + 50)
            xi, yi = d_ot(200)
            set_seed(seed + 50)
            xo, yo = d_oe()
            set_seed(seed + 100)
            mdl = mc(2, 1, h, N, **kw)
            si = tr_r(mdl, xtr, ytr, xi, yi, 300, 1e-3)
            mdl.eval()
            with torch.no_grad():
                so = F.mse_loss(mdl(xo), yo).item()
            ids.append(si)
            ods.append(so)
        OD[mn] = {'id': np.mean(ids), 'ood': np.mean(ods), 'deg': np.mean(ods) / max(np.mean(ids), 1e-10),
                  'is': np.std(ids), 'os': np.std(ods)}
    bo = min(OD, key=lambda k: OD[k]['ood'])
    print(f"\n {'M':<10} {'ID':>12} {'OOD':>12} {'Deg':>7}")
    print(f" {'─'*44}")
    for mn, r in OD.items():
        print(f" {mn:<10} {r['id']:>8.4f}±{r['is']:.3f} {r['ood']:>8.4f}±{r['os']:.3f} {r['deg']:>6.1f}x{' ★' if mn==bo else ''}")
    R['OOD'] = {mn: {'mean': r['ood'], 'std': r['os']} for mn, r in OD.items()}

    # Summary
    print(f"\n{'='*80}\n SUMMARY\n{'='*80}")
    wc = {k: 0 for k in Ms}
    print(f"\n {'Task':<10}", end="")
    for mn in Ms:
        print(f" {mn:>10}", end="")
    print(f" {'W':>8}")
    print(f" {'─'*68}")
    for tn, t in R.items():
        sc = {k: v['mean'] for k, v in t.items()}
        # 'OOD' is not in kinds and is therefore treated as regression.
        ic = kinds.get(tn) == 'c'
        w = max(sc, key=sc.get) if ic else min(sc, key=sc.get)
        wc[w] += 1
        row = f" {tn:<10}"
        for mn in Ms:
            s = sc[mn]
            if ic:
                row += f" {s:>9.1%}"
            elif s < .001:
                row += f" {s:>9.2e}"
            else:
                row += f" {s:>9.4f}"
        row += f" ->{w}"
        print(row)
    print(f"\n {'─'*68}")
    for mn, c in sorted(wc.items(), key=lambda x: -x[1]):
        print(f" {mn:<10} {c} {'█'*c*3}")

    print(f"\n corr(g,φ) — does v13 phase align with signal?")
    for tn, cs in CORR.items():
        print(f" {tn:<10} layers: {['%.3f'%c for c in cs]} avg={np.mean(cs):.3f}")

    sv = {'tasks': {}, 'ood': {}, 'corr': CORR}
    for tn, t in R.items():
        # The 'OOD' entry only carries mean/std, hence the .get defaults.
        sv['tasks'][tn] = {mn: {'mean': float(r['mean']), 'std': float(r.get('std', 0)),
                                'scores': [float(s) for s in r.get('scores', [r['mean']])],
                                'params': r.get('params', 0), 'hidden': r.get('hidden', 0)} for mn, r in t.items()}
    sv['ood'] = {mn: {k: float(v) if isinstance(v, (float, np.floating)) else v for k, v in r.items()} for mn, r in OD.items()}
    # NOTE(review): the output path is hard-coded; the write fails if /app
    # does not exist or is not writable.
    with open('/app/results_v13.json', 'w') as f:
        json.dump(sv, f, indent=2, default=str)
    print(f"\n Saved.")
300
+
301
+ if __name__=="__main__": main()