anshdadhich commited on
Commit
1c52248
·
verified ·
1 Parent(s): f9b75ef

v12: signal-proportional phase — TIES SinGLU 3-3

Browse files
Files changed (1) hide show
  1. benchmark_v12.py +254 -0
benchmark_v12.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ v12: SinGLU + SIGNAL-PROPORTIONAL Phase
4
+
5
+ v10: φ independent of signal → too hot, destroys structure
6
+ v11: φ tiny fixed scale → too cold, does nothing
7
+ v12: φ proportional to signal → adapts automatically
8
+
9
+ g = Wg·x (features)
10
+ φ = tanh(Wφ·x) (modulation signal, bounded [-1,1])
11
+ core = sin(ω · g · (1 + 0.2·φ)) (signal-proportional phase warping)
12
+ y = LN( Wo( core ⊙ Wv·x ) )
13
+
14
+ When g is large: phase effect is large (where the signal is strong)
15
+ When g is small: phase effect is small (where signal is weak)
16
+ On OOD: g·φ tracks signal magnitude, doesn't add uncorrelated noise
17
+ """
18
+
19
+ import torch, torch.nn as nn, torch.nn.functional as F
20
+ import numpy as np, math, json
21
+
22
+ SEEDS = [0, 1, 2]
23
+ def set_seed(s): torch.manual_seed(s); np.random.seed(s)
24
+
25
+ class VanillaMLP(nn.Module):
26
+ def __init__(self,di,do,h,n):
27
+ super().__init__(); L=[]; p=di
28
+ for _ in range(n): L+=[nn.Linear(p,h),nn.ReLU()]; p=h
29
+ L.append(nn.Linear(p,do)); self.net=nn.Sequential(*L)
30
+ def forward(self,x): return self.net(x)
31
+
32
+ class SinGLULayer(nn.Module):
33
+ def __init__(self,di,do,mid,w0=30.):
34
+ super().__init__()
35
+ self.Wg=nn.Linear(di,mid,bias=False); self.Wv=nn.Linear(di,mid,bias=False)
36
+ self.Wo=nn.Linear(mid,do,bias=True); self.w0=w0; self.ln=nn.LayerNorm(do)
37
+ with torch.no_grad():
38
+ self.Wg.weight.uniform_(-math.sqrt(6/di)/w0,math.sqrt(6/di)/w0)
39
+ nn.init.xavier_uniform_(self.Wv.weight); nn.init.xavier_uniform_(self.Wo.weight)
40
+ def forward(self,x): return self.ln(self.Wo(torch.sin(self.w0*self.Wg(x))*self.Wv(x)))
41
+
42
+ class SinGLUNet(nn.Module):
43
+ def __init__(self,di,do,h,n,w0=30.):
44
+ super().__init__(); mid=max(2,int(h*2/3)); L=[]; p=di
45
+ for _ in range(n): L.append(SinGLULayer(p,h,mid,w0)); p=h
46
+ L.append(nn.Linear(p,do)); self.layers=nn.ModuleList(L)
47
+ def forward(self,x):
48
+ for l in self.layers: x=l(x)
49
+ return x
50
+
51
+ class v12Layer(nn.Module):
52
+ """
53
+ g = Wg·x
54
+ φ = tanh(Wφ·x)
55
+ core = sin(ω · g · (1 + 0.2·φ)) ← phase proportional to signal
56
+ y = LN(Wo(core ⊙ Wv·x))
57
+ """
58
+ def __init__(self, di, do, mid, w0=30., warp_scale=0.2):
59
+ super().__init__()
60
+ self.Wg=nn.Linear(di,mid,bias=False)
61
+ self.Wv=nn.Linear(di,mid,bias=False)
62
+ self.Wo=nn.Linear(mid,do,bias=True)
63
+ self.Wphi=nn.Linear(di,mid,bias=True)
64
+ self.w0=w0; self.ws=warp_scale; self.ln=nn.LayerNorm(do)
65
+ with torch.no_grad():
66
+ self.Wg.weight.uniform_(-math.sqrt(6/di)/w0,math.sqrt(6/di)/w0)
67
+ nn.init.xavier_uniform_(self.Wv.weight); nn.init.xavier_uniform_(self.Wo.weight)
68
+ nn.init.zeros_(self.Wphi.weight); nn.init.zeros_(self.Wphi.bias)
69
+
70
+ def forward(self,x):
71
+ g = self.Wg(x)
72
+ phi = torch.tanh(self.Wphi(x))
73
+ warped = g * (1. + self.ws * phi) # signal-proportional warping
74
+ core = torch.sin(self.w0 * warped)
75
+ return self.ln(self.Wo(core * self.Wv(x)))
76
+
77
+ def get_phi(self,x):
78
+ with torch.no_grad(): return torch.tanh(self.Wphi(x))
79
+
80
+ class v12Net(nn.Module):
81
+ def __init__(self,di,do,h,n,w0=30.):
82
+ super().__init__(); mid=max(2,int(h*2/3)); L=[]; p=di
83
+ for _ in range(n): L.append(v12Layer(p,h,mid,w0)); p=h
84
+ L.append(nn.Linear(p,do)); self.layers=nn.ModuleList(L)
85
+ def forward(self,x):
86
+ for l in self.layers: x=l(x)
87
+ return x
88
+ def get_all_phi(self,x):
89
+ P=[]; h=x
90
+ for l in self.layers:
91
+ if isinstance(l,v12Layer): P.append(l.get_phi(h)); h=l(h)
92
+ else: h=l(h)
93
+ return P
94
+
95
+ # Utils
96
+ def np_(m): return sum(p.numel() for p in m.parameters() if p.requires_grad)
97
+ def fh(di,do,n,t,cls,**kw):
98
+ lo,hi,b=2,512,2
99
+ while lo<=hi:
100
+ mid=(lo+hi)//2; p=np_(cls(di,do,mid,n,**kw))
101
+ if abs(p-t)<abs(np_(cls(di,do,b,n,**kw))-t): b=mid
102
+ if p<t: lo=mid+1
103
+ else: hi=mid-1
104
+ return b
105
+
106
+ def tr_r(m,xt,yt,xe,ye,ep,lr,bs=256):
107
+ o=torch.optim.Adam(m.parameters(),lr=lr); s=torch.optim.lr_scheduler.CosineAnnealingLR(o,T_max=ep)
108
+ best=float('inf'); n=len(xt)
109
+ for e in range(ep):
110
+ m.train(); p=torch.randperm(n)
111
+ for i in range(0,n,bs):
112
+ idx=p[i:i+bs]; loss=F.mse_loss(m(xt[idx]),yt[idx])
113
+ o.zero_grad(); loss.backward(); torch.nn.utils.clip_grad_norm_(m.parameters(),1.); o.step()
114
+ s.step()
115
+ if(e+1)%max(1,ep//10)==0:
116
+ m.eval()
117
+ with torch.no_grad(): best=min(best,F.mse_loss(m(xe),ye).item())
118
+ m.eval()
119
+ with torch.no_grad(): best=min(best,F.mse_loss(m(xe),ye).item())
120
+ return best
121
+
122
+ def tr_c(m,xt,yt,xe,ye,ep,lr,bs=256):
123
+ o=torch.optim.Adam(m.parameters(),lr=lr); s=torch.optim.lr_scheduler.CosineAnnealingLR(o,T_max=ep)
124
+ best=0; n=len(xt)
125
+ for e in range(ep):
126
+ m.train(); p=torch.randperm(n)
127
+ for i in range(0,n,bs):
128
+ idx=p[i:i+bs]; loss=F.cross_entropy(m(xt[idx]),yt[idx])
129
+ o.zero_grad(); loss.backward(); torch.nn.utils.clip_grad_norm_(m.parameters(),1.); o.step()
130
+ s.step()
131
+ if(e+1)%max(1,ep//10)==0:
132
+ m.eval()
133
+ with torch.no_grad(): best=max(best,(m(xe).argmax(1)==ye).float().mean().item())
134
+ m.eval()
135
+ with torch.no_grad(): best=max(best,(m(xe).argmax(1)==ye).float().mean().item())
136
+ return best
137
+
138
+ # Data
139
+ def d_cx(n=1000): x=torch.rand(n,4)*2-1; return x,torch.exp(torch.sin(x[:,0]**2+x[:,1]**2)+torch.sin(x[:,2]**2+x[:,3]**2)).unsqueeze(1)
140
+ def d_ne(n=1000): x=torch.rand(n,2)*2-1; return x,(torch.sin(math.pi*(x[:,0]**2+x[:,1]**2))*torch.cos(3*math.pi*x[:,0]*x[:,1])).unsqueeze(1)
141
+ def d_sp(n=1000):
142
+ t=torch.linspace(0,4*np.pi,n//2); r=torch.linspace(.3,2,n//2)
143
+ x=torch.cat([torch.stack([r*torch.cos(t),r*torch.sin(t)],1),torch.stack([r*torch.cos(t+np.pi),r*torch.sin(t+np.pi)],1)])+torch.randn(n,2)*.05
144
+ y=torch.cat([torch.zeros(n//2),torch.ones(n//2)]).long(); p=torch.randperm(n); return x[p],y[p]
145
+ def d_ch(n=1000): x=torch.rand(n,2)*2-1; return x,((torch.sin(3*math.pi*x[:,0])*torch.sin(3*math.pi*x[:,1]))>0).long()
146
+ def d_hf(n=1000): x=torch.linspace(-1,1,n).unsqueeze(1); return x,torch.sin(20*x)+torch.sin(50*x)+.5*torch.sin(100*x)
147
+ def d_mm(n=200): return torch.randn(n,8),torch.randn(n,4)
148
+ def d_ot(n=800): x=torch.rand(n,2)*2-1; return x,(torch.sin(3*math.pi*x[:,0])*torch.cos(3*math.pi*x[:,1])+x[:,0]*x[:,1]).unsqueeze(1)
149
+ def d_oe(n=300): x=torch.rand(n,2)+1; return x,(torch.sin(3*math.pi*x[:,0])*torch.cos(3*math.pi*x[:,1])+x[:,0]*x[:,1]).unsqueeze(1)
150
+
151
+ def main():
152
+ print("="*80)
153
+ print(" v12: SinGLU + SIGNAL-PROPORTIONAL Phase")
154
+ print(" core = sin(ω · g · (1 + 0.2·tanh(Wφ·x))) — phase scales with signal")
155
+ print("="*80)
156
+
157
+ N=3; Ms={'Vanilla':(VanillaMLP,{}),'SinGLU':(SinGLUNet,{'w0':None}),'v12':(v12Net,{'w0':None})}
158
+ tasks=[
159
+ ("Complex Fn","r",d_cx,4,1,5000,300,1e-3,30.,750),
160
+ ("Nested Fn","r",d_ne,2,1,3000,300,1e-3,20.,750),
161
+ ("Spiral","c",d_sp,2,2,3000,250,1e-3,15.,700),
162
+ ("Checker","c",d_ch,2,2,3000,250,1e-3,20.,700),
163
+ ("High-Freq","r",d_hf,1,1,8000,300,1e-3,60.,700),
164
+ ("Memorize","r",d_mm,8,4,5000,400,1e-3,10.,200),
165
+ ]
166
+ R={}; PH={}
167
+ for tn,tt,df,di,do,bud,ep,lr,w0,sp in tasks:
168
+ print(f"\n{'━'*80}\n {tn}\n{'━'*80}")
169
+ hs={mn:fh(di,do,N,bud,mc,**{k:(w0 if v is None else v) for k,v in mk.items()}) for mn,(mc,mk) in Ms.items()}
170
+ tr={}
171
+ for mn,(mc,mk) in Ms.items():
172
+ kw={k:(w0 if v is None else v) for k,v in mk.items()}; h=hs[mn]; sc=[]
173
+ for seed in SEEDS:
174
+ set_seed(seed); x,y=df()
175
+ if sp>=len(x): xt,yt,xe,ye=x,y,x,y
176
+ else: xt,yt,xe,ye=x[:sp],y[:sp],x[sp:],y[sp:]
177
+ set_seed(seed+100); mdl=mc(di,do,h,N,**kw)
178
+ s=tr_r(mdl,xt,yt,xe,ye,ep,lr) if tt=='r' else tr_c(mdl,xt,yt,xe,ye,ep,lr)
179
+ sc.append(s)
180
+ if mn=='v12' and seed==SEEDS[-1]:
181
+ mdl.eval()
182
+ with torch.no_grad():
183
+ pp=mdl.get_all_phi(xe[:100]); ap=torch.cat([p.flatten() for p in pp])
184
+ PH[tn]={'m':ap.mean().item(),'s':ap.std().item(),'mn':ap.min().item(),'mx':ap.max().item()}
185
+ p=np_(mc(di,do,h,N,**kw))
186
+ tr[mn]={'mean':np.mean(sc),'std':np.std(sc),'scores':sc,'params':p,'hidden':h}
187
+ ir=tt=='r'
188
+ best=min(tr,key=lambda k:tr[k]['mean']) if ir else max(tr,key=lambda k:tr[k]['mean'])
189
+ met="MSE ↓" if ir else "Acc ↑"
190
+ print(f"\n {'M':<8} {'H':>3} {'P':>6} {met+' (mean±std)':>26}")
191
+ print(f" {'─'*46}")
192
+ for mn,r in tr.items():
193
+ m,s=r['mean'],r['std']
194
+ ms=f"{m:.2e}±{s:.1e}" if(ir and m<.001) else(f"{m:.4f}±{s:.4f}" if ir else f"{m:.1%}±{s:.3f}")
195
+ print(f" {mn:<8} {r['hidden']:>3} {r['params']:>6,} {ms:>26}{' ★' if mn==best else ''}")
196
+ print(f" → {best}")
197
+ if tn in PH: d=PH[tn]; print(f" φ: mean={d['m']:.4f} std={d['s']:.4f} [{d['mn']:.3f},{d['mx']:.3f}]")
198
+ R[tn]=tr
199
+
200
+ # OOD
201
+ print(f"\n{'━'*80}\n OOD: [-1,1] → [1,2]\n{'━'*80}")
202
+ OD={}
203
+ for mn,(mc,mk) in Ms.items():
204
+ kw={k:(20. if v is None else v) for k,v in mk.items()}; h=fh(2,1,N,5000,mc,**kw); ids,ods=[],[]
205
+ for seed in SEEDS:
206
+ set_seed(seed); xtr,ytr=d_ot()
207
+ set_seed(seed+50); xi=torch.rand(200,2)*2-1; yi=(torch.sin(3*math.pi*xi[:,0])*torch.cos(3*math.pi*xi[:,1])+xi[:,0]*xi[:,1]).unsqueeze(1)
208
+ set_seed(seed+50); xo,yo=d_oe()
209
+ set_seed(seed+100); mdl=mc(2,1,h,N,**kw)
210
+ si=tr_r(mdl,xtr,ytr,xi,yi,300,1e-3); mdl.eval()
211
+ with torch.no_grad(): so=F.mse_loss(mdl(xo),yo).item()
212
+ ids.append(si); ods.append(so)
213
+ OD[mn]={'id':np.mean(ids),'ood':np.mean(ods),'deg':np.mean(ods)/max(np.mean(ids),1e-10),'is':np.std(ids),'os':np.std(ods)}
214
+ bo=min(OD,key=lambda k:OD[k]['ood'])
215
+ print(f"\n {'M':<8} {'ID':>12} {'OOD':>12} {'Deg':>7}")
216
+ print(f" {'─'*42}")
217
+ for mn,r in OD.items(): print(f" {mn:<8} {r['id']:>8.4f}±{r['is']:.3f} {r['ood']:>8.4f}±{r['os']:.3f} {r['deg']:>6.1f}x{' ★' if mn==bo else ''}")
218
+ print(f" → {bo}")
219
+ R['OOD']={mn:{'mean':r['ood'],'std':r['os']} for mn,r in OD.items()}
220
+
221
+ # Summary
222
+ print(f"\n{'='*80}\n FINAL SUMMARY\n{'='*80}")
223
+ wc={k:0 for k in Ms}
224
+ print(f"\n {'Task':<14}",end="")
225
+ for mn in Ms: print(f" {mn:>12}",end="")
226
+ print(f" {'W':>8}")
227
+ print(f" {'─'*50}")
228
+ for tn,t in R.items():
229
+ sc={k:v['mean'] for k,v in t.items()}; mx=max(sc.values())
230
+ ic=mx>.5 and mx<=1 and min(sc.values())>=0
231
+ if min(sc.values())<.001: ic=False
232
+ w=min(sc,key=sc.get) if(tn=='OOD' or not ic) else max(sc,key=sc.get)
233
+ wc[w]+=1
234
+ row=f" {tn:<14}"
235
+ for mn in Ms:
236
+ s=sc[mn]
237
+ if ic: row+=f" {s:>11.1%}"
238
+ elif s<.001: row+=f" {s:>11.2e}"
239
+ else: row+=f" {s:>11.4f}"
240
+ row+=f" ->{w}"; print(row)
241
+ print(f"\n {'─'*50}")
242
+ for mn,c in sorted(wc.items(),key=lambda x:-x[1]):
243
+ print(f" {mn:<8} {c} {'█'*c*4}")
244
+
245
+ sv={'tasks':{},'ood':{},'phi':PH}
246
+ for tn,t in R.items():
247
+ sv['tasks'][tn]={mn:{'mean':float(r['mean']),'std':float(r.get('std',0)),
248
+ 'scores':[float(s) for s in r.get('scores',[r['mean']])],
249
+ 'params':r.get('params',0),'hidden':r.get('hidden',0)} for mn,r in t.items()}
250
+ sv['ood']={mn:{k:float(v) if isinstance(v,(float,np.floating)) else v for k,v in r.items()} for mn,r in OD.items()}
251
+ with open('/app/results_v12.json','w') as f: json.dump(sv,f,indent=2,default=str)
252
+ print(f"\n Saved.")
253
+
254
+ if __name__=="__main__": main()