"""
v9: Controlled Frequency + Phase + Gate (the convergent design)

    per  = sin( ω(x) ⊙ W_per·x + φ(x) )
    ω(x) = ω0·(1 + 0.1·tanh(W_ω·x))
    y    = LN( α(x) ⊙ per + (1-α(x)) ⊙ val + residual )

Key vs v7: ω is bounded (±10%), not free
Key vs v8: ω exists (not removed)
Key vs v8: paths are SEPARATED (not entangled as val*(α*per+(1-α)))
"""
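# Illustration (not part of the experiment): because tanh ∈ (-1, 1), ω(x) can
# only drift within (0.9·ω0, 1.1·ω0) and φ(x) within (-π, π), whatever W_ω and
# W_φ learn. Here t stands in for tanh(W_ω·x):
_w0 = 30.
assert all(0.9*_w0 < _w0*(1 + 0.1*t) < 1.1*_w0 for t in (-0.999, 0., 0.999))
del _w0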
|
|
import torch, torch.nn as nn, torch.nn.functional as F
import numpy as np, math, json


SEEDS = [0, 1, 2]
def set_seed(s): torch.manual_seed(s); np.random.seed(s)
|
|
|
|
class VanillaMLP(nn.Module):
    def __init__(self, d_in, d_out, h, n):
        super().__init__()
        layers = []
        prev = d_in
        for _ in range(n):
            layers += [nn.Linear(prev, h), nn.ReLU()]; prev = h
        layers.append(nn.Linear(prev, d_out))
        self.net = nn.Sequential(*layers)
    def forward(self, x): return self.net(x)
|
|
class SinGLULayer(nn.Module):
    def __init__(self, d_in, d_out, mid, w0=30.):
        super().__init__()
        self.Wg = nn.Linear(d_in, mid, bias=False)
        self.Wv = nn.Linear(d_in, mid, bias=False)
        self.Wo = nn.Linear(mid, d_out, bias=True)
        self.w0 = w0; self.ln = nn.LayerNorm(d_out)
        with torch.no_grad():
            # SIREN-style init: the 1/w0 scale keeps sin(w0·Wg·x) pre-activations well-scaled
            self.Wg.weight.uniform_(-math.sqrt(6/d_in)/w0, math.sqrt(6/d_in)/w0)
            nn.init.xavier_uniform_(self.Wv.weight)
            nn.init.xavier_uniform_(self.Wo.weight)
    def forward(self, x):
        return self.ln(self.Wo(torch.sin(self.w0*self.Wg(x))*self.Wv(x)))
|
|
class SinGLUNet(nn.Module):
    def __init__(self, d_in, d_out, h, n, w0=30.):
        super().__init__()
        mid = max(2, int(h*2/3)); layers = []; prev = d_in
        for _ in range(n): layers.append(SinGLULayer(prev, h, mid, w0)); prev = h
        layers.append(nn.Linear(prev, d_out)); self.layers = nn.ModuleList(layers)
    def forward(self, x):
        for l in self.layers: x = l(x)
        return x
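# Illustrative shape check (not part of the experiment): a 3-layer SinGLU net
# maps (batch, d_in) -> (batch, d_out). mid = 2h/3 follows the common GLU
# convention for comparable parameter counts; budgets are matched exactly by
# find_h below anyway.
_net = SinGLUNet(2, 1, 32, 3, w0=20.)
assert _net(torch.randn(5, 2)).shape == (5, 1)
del _net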
|
|
|
|
class v9Layer(nn.Module):
    """
    Controlled freq + phase + gate + separated paths + residual.

    ω(x) = ω0 · (1 + 0.1·tanh(W_ω·x))     bounded ±10%
    φ(x) = π · tanh(W_φ·x)                bounded [-π, π]
    per  = sin(ω(x) ⊙ W_per·x + φ(x))     full periodic
    val  = W_val·x                        full linear
    α(x) = sigmoid(W_α·x)                 gate
    y    = LN( α⊙per + (1-α)⊙val + res )  SEPARATED paths
    """
    def __init__(self, d_in, d_out, w0=30.):
        super().__init__()
        self.W_val = nn.Linear(d_in, d_out, bias=True)
        self.W_per = nn.Linear(d_in, d_out, bias=False)
        self.W_omega = nn.Linear(d_in, d_out, bias=True)
        self.W_phi = nn.Linear(d_in, d_out, bias=True)
        self.W_alpha = nn.Linear(d_in, d_out, bias=True)
        self.w0 = w0
        self.ln = nn.LayerNorm(d_out)
        self.res = nn.Linear(d_in, d_out, bias=False) if d_in != d_out else nn.Identity()

        with torch.no_grad():
            nn.init.xavier_uniform_(self.W_val.weight)
            # SIREN-style scale for the periodic path
            b = math.sqrt(6./d_in)/w0
            self.W_per.weight.uniform_(-b, b)
            # zero-init the modulators: the layer starts with ω=ω0, φ=0, α=0.5
            nn.init.zeros_(self.W_omega.weight); nn.init.zeros_(self.W_omega.bias)
            nn.init.zeros_(self.W_phi.weight); nn.init.zeros_(self.W_phi.bias)
            nn.init.zeros_(self.W_alpha.weight); nn.init.zeros_(self.W_alpha.bias)
    def forward(self, x):
        val = self.W_val(x)
        omega = self.w0 * (1. + 0.1 * torch.tanh(self.W_omega(x)))
        phi = math.pi * torch.tanh(self.W_phi(x))
        per = torch.sin(omega * self.W_per(x) + phi)
        alpha = torch.sigmoid(self.W_alpha(x))

        return self.ln(alpha * per + (1. - alpha) * val + self.res(x))

    def get_diag(self, x):
        with torch.no_grad():
            omega = self.w0 * (1. + 0.1 * torch.tanh(self.W_omega(x)))
            phi = math.pi * torch.tanh(self.W_phi(x))
            alpha = torch.sigmoid(self.W_alpha(x))
            return alpha, phi, omega
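# Illustrative sanity check (not part of the experiment): with the zero-inits
# above, a fresh v9Layer starts as an even per/val mix at its nominal frequency
# with no phase shift, i.e. α = 0.5, φ = 0, ω = ω0 for every input.
_layer = v9Layer(3, 5)
_a, _p, _o = _layer.get_diag(torch.randn(2, 3))
assert torch.allclose(_a, torch.full_like(_a, .5)) and torch.allclose(_p, torch.zeros_like(_p))
assert torch.allclose(_o, torch.full_like(_o, 30.))
del _layer, _a, _p, _o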
|
|
|
|
class v9Net(nn.Module):
    def __init__(self, d_in, d_out, h, n, w0=30.):
        super().__init__()
        layers = []; prev = d_in
        for _ in range(n): layers.append(v9Layer(prev, h, w0)); prev = h
        layers.append(nn.Linear(prev, d_out)); self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for l in self.layers: x = l(x)
        return x

    def get_all_diag(self, x):
        alphas, phis, omegas = [], [], []
        h = x
        for l in self.layers:
            if isinstance(l, v9Layer):
                a,p,o = l.get_diag(h); alphas.append(a); phis.append(p); omegas.append(o)
                h = l(h)
            else: h = l(h)
        return alphas, phis, omegas

    def gate_reg(self, x):
        """Gate polarization: α(1-α) vanishes at α∈{0,1} and peaks at α=0.5,
        so penalizing it pushes gates away from the center."""
        total = 0; h = x
        for l in self.layers:
            if isinstance(l, v9Layer):
                a = torch.sigmoid(l.W_alpha(h))
                total = total + (a * (1. - a)).mean()
                h = l(h)
            else: h = l(h)
        return total
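# Illustration (not part of the experiment): α(1-α) is 0 at the poles and
# maximal (0.25) at α = 0.5, so adding λ·gate_reg to the loss drives gates
# toward 0 or 1.
_a = torch.tensor([0., .5, 1.])
assert torch.allclose(_a*(1-_a), torch.tensor([0., .25, 0.]))
del _a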
|
|
|
|
def nparams(m): return sum(p.numel() for p in m.parameters() if p.requires_grad)
|
|
def find_h(di, do, n, target, cls, **kw):
    """Binary-search the hidden width whose parameter count is closest to `target`."""
    lo,hi,best = 2,512,2
    while lo<=hi:
        mid=(lo+hi)//2; p=nparams(cls(di,do,mid,n,**kw))
        if abs(p-target)<abs(nparams(cls(di,do,best,n,**kw))-target): best=mid
        if p<target: lo=mid+1
        else: hi=mid-1
    return best
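# Illustrative usage (not part of the experiment): pick the hidden width whose
# parameter count is nearest a 5000-parameter budget, so every model in a
# comparison runs at (approximately) equal capacity.
_h = find_h(2, 1, 3, 5000, VanillaMLP)
assert abs(nparams(VanillaMLP(2, 1, _h, 3)) - 5000) < 200
del _h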
|
|
def train_reg(m, xtr,ytr,xte,yte, ep, lr, lam=5e-4, bs=256):
    """Train with MSE (plus λ·gate polarization for v9Net); return the best
    test MSE observed at ~10 checkpoints and after the final epoch."""
    opt=torch.optim.Adam(m.parameters(),lr=lr)
    sch=torch.optim.lr_scheduler.CosineAnnealingLR(opt,T_max=ep)
    best=float('inf'); n=len(xtr); use_reg=isinstance(m,v9Net)
    for e in range(ep):
        m.train(); perm=torch.randperm(n)
        for i in range(0,n,bs):
            idx=perm[i:i+bs]; bx,by=xtr[idx],ytr[idx]
            loss=F.mse_loss(m(bx),by)
            if use_reg: loss=loss+lam*m.gate_reg(bx)
            opt.zero_grad(); loss.backward()
            torch.nn.utils.clip_grad_norm_(m.parameters(),1.0); opt.step()
        sch.step()
        if (e+1)%max(1,ep//10)==0:
            m.eval()
            with torch.no_grad(): best=min(best,F.mse_loss(m(xte),yte).item())
    m.eval()
    with torch.no_grad(): best=min(best,F.mse_loss(m(xte),yte).item())
    return best
|
|
def train_clf(m, xtr,ytr,xte,yte, ep, lr, lam=5e-4, bs=256):
    """Train with cross-entropy (plus λ·gate polarization for v9Net); return
    the best test accuracy observed at ~10 checkpoints and after the final epoch."""
    opt=torch.optim.Adam(m.parameters(),lr=lr)
    sch=torch.optim.lr_scheduler.CosineAnnealingLR(opt,T_max=ep)
    best=0; n=len(xtr); use_reg=isinstance(m,v9Net)
    for e in range(ep):
        m.train(); perm=torch.randperm(n)
        for i in range(0,n,bs):
            idx=perm[i:i+bs]; bx,by=xtr[idx],ytr[idx]
            loss=F.cross_entropy(m(bx),by)
            if use_reg: loss=loss+lam*m.gate_reg(bx)
            opt.zero_grad(); loss.backward()
            torch.nn.utils.clip_grad_norm_(m.parameters(),1.0); opt.step()
        sch.step()
        if (e+1)%max(1,ep//10)==0:
            m.eval()
            with torch.no_grad(): best=max(best,(m(xte).argmax(1)==yte).float().mean().item())
    m.eval()
    with torch.no_grad(): best=max(best,(m(xte).argmax(1)==yte).float().mean().item())
    return best
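# Illustrative usage sketch (not part of the experiment; d_highfreq is defined
# below). Both trainers add λ·gate_reg only for v9Net and report the best
# held-out score seen during and after training:
#   x, y = d_highfreq(400)
#   mse = train_reg(VanillaMLP(1, 1, 32, 3), x[:300], y[:300], x[300:], y[300:], ep=20, lr=1e-3)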
|
|
|
|
def d_complex(n=1000):
    x=torch.rand(n,4)*2-1; y=torch.exp(torch.sin(x[:,0]**2+x[:,1]**2)+torch.sin(x[:,2]**2+x[:,3]**2))
    return x,y.unsqueeze(1)
def d_nested(n=1000):
    x=torch.rand(n,2)*2-1; y=torch.sin(math.pi*(x[:,0]**2+x[:,1]**2))*torch.cos(3*math.pi*x[:,0]*x[:,1])
    return x,y.unsqueeze(1)
def d_spiral(n=1000):
    t=torch.linspace(0,4*np.pi,n//2); r=torch.linspace(.3,2,n//2)
    x1=torch.stack([r*torch.cos(t),r*torch.sin(t)],1)
    x2=torch.stack([r*torch.cos(t+np.pi),r*torch.sin(t+np.pi)],1)
    x=torch.cat([x1,x2])+torch.randn(n,2)*.05
    y=torch.cat([torch.zeros(n//2),torch.ones(n//2)]).long()
    p=torch.randperm(n); return x[p],y[p]
def d_checker(n=1000):
    x=torch.rand(n,2)*2-1; y=((torch.sin(3*math.pi*x[:,0])*torch.sin(3*math.pi*x[:,1]))>0).long()
    return x,y
def d_highfreq(n=1000):
    x=torch.linspace(-1,1,n).unsqueeze(1); return x,torch.sin(20*x)+torch.sin(50*x)+.5*torch.sin(100*x)
def d_mem(n=200): return torch.randn(n,8),torch.randn(n,4)  # pure-noise memorization
def d_ood_tr(n=800):
    x=torch.rand(n,2)*2-1; y=torch.sin(3*math.pi*x[:,0])*torch.cos(3*math.pi*x[:,1])+x[:,0]*x[:,1]
    return x,y.unsqueeze(1)
def d_ood_te(n=300):  # same function as d_ood_tr, sampled on the shifted square [1,2]²
    x=torch.rand(n,2)+1; y=torch.sin(3*math.pi*x[:,0])*torch.cos(3*math.pi*x[:,1])+x[:,0]*x[:,1]
    return x,y.unsqueeze(1)
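# Illustrative shape check (not part of the experiment): regression tasks
# return 2-D float targets (e.g. (n, 1)); classification tasks return 1-D
# long labels of shape (n,).
_x, _y = d_checker(8)
assert _x.shape == (8, 2) and _y.shape == (8,) and _y.dtype == torch.long
del _x, _y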
|
|
|
|
def main():
    print("="*80)
    print(" v9: CONTROLLED FREQ + PHASE + GATE (separated paths)")
    print(" ω(x) = ω0·(1+0.1·tanh(W_ω·x)) | φ(x) = π·tanh(W_φ·x)")
    print(" y = LN( α⊙per + (1-α)⊙val + res ) | λ=5e-4 polarization")
    print("="*80)

    N=3
    models = {
        'Vanilla': (VanillaMLP, {}),
        'SinGLU':  (SinGLUNet, {'w0':None}),  # None → replaced by the per-task w0
        'v9':      (v9Net, {'w0':None}),
    }

    # (name, type, data_fn, d_in, d_out, param budget, epochs, lr, w0, train split)
    tasks = [
        ("Complex Fn",   "reg", d_complex,  4,1, 5000, 300, 1e-3, 30., 750),
        ("Nested Fn",    "reg", d_nested,   2,1, 3000, 300, 1e-3, 20., 750),
        ("Spiral",       "clf", d_spiral,   2,2, 3000, 250, 1e-3, 15., 700),
        ("Checkerboard", "clf", d_checker,  2,2, 3000, 250, 1e-3, 20., 700),
        ("High-Freq",    "reg", d_highfreq, 1,1, 8000, 300, 1e-3, 60., 700),
        ("Memorize",     "reg", d_mem,      8,4, 5000, 400, 1e-3, 10., 200),
    ]

    all_res = {}; diag = {}

    for tn,tt,df,di,do,bud,ep,lr,w0,sp in tasks:
        print(f"\n{'━'*80}\n {tn} | ~{bud:,} params\n{'━'*80}")

        # size every model to the same parameter budget
        hs={}
        for mn,(mc,mk) in models.items():
            kw={k:(w0 if v is None else v) for k,v in mk.items()}
            hs[mn]=find_h(di,do,N,bud,mc,**kw)

        tr={}
        for mn,(mc,mk) in models.items():
            kw={k:(w0 if v is None else v) for k,v in mk.items()}
            h=hs[mn]; scores=[]
            for seed in SEEDS:
                set_seed(seed); x,y=df()
                if sp>=len(x): xtr,ytr,xte,yte=x,y,x,y  # memorization: test = train
                else: xtr,ytr,xte,yte=x[:sp],y[:sp],x[sp:],y[sp:]
                set_seed(seed+100); mdl=mc(di,do,h,N,**kw)
                if tt=='reg': s=train_reg(mdl,xtr,ytr,xte,yte,ep,lr)
                else: s=train_clf(mdl,xtr,ytr,xte,yte,ep,lr)
                scores.append(s)
                if mn=='v9' and seed==SEEDS[-1]:
                    mdl.eval()
                    with torch.no_grad():
                        als,phs,oms=mdl.get_all_diag(xte[:100])
                        aa=torch.cat([a.flatten() for a in als])
                        pp=torch.cat([p.flatten() for p in phs])
                        oo=torch.cat([o.flatten() for o in oms])
                    diag[tn]={
                        'a_m':aa.mean().item(),'a_s':aa.std().item(),
                        'a_lo':(aa<.3).float().mean().item(),'a_hi':(aa>.7).float().mean().item(),
                        'p_s':pp.std().item(),'p_m':pp.mean().item(),
                        'o_m':oo.mean().item(),'o_s':oo.std().item(),
                        'o_min':oo.min().item(),'o_max':oo.max().item(),
                    }
            p=nparams(mc(di,do,h,N,**kw))
            tr[mn]={'mean':np.mean(scores),'std':np.std(scores),'scores':scores,'params':p,'hidden':h}

        is_reg=tt=='reg'
        if is_reg: best=min(tr,key=lambda k:tr[k]['mean'])
        else: best=max(tr,key=lambda k:tr[k]['mean'])
        met="MSE ↓" if is_reg else "Acc ↑"

        print(f"\n {'Model':<10} {'H':>4} {'P':>6} {met+' (mean±std)':>26}")
        print(f" {'─'*50}")
        for mn,r in tr.items():
            m,s=r['mean'],r['std']
            ms=f"{m:.2e}±{s:.1e}" if(is_reg and m<.001) else(f"{m:.4f}±{s:.4f}" if is_reg else f"{m:.1%}±{s:.3f}")
            print(f" {mn:<10} {r['hidden']:>4} {r['params']:>6,} {ms:>26}{' ★' if mn==best else ''}")
        print(f" → {best}")

        if tn in diag:
            d=diag[tn]
            print(f" α: {d['a_m']:.3f}±{d['a_s']:.3f} ({d['a_lo']:.0%} lin, {d['a_hi']:.0%} per)")
            print(f" φ: std={d['p_s']:.3f}")
            print(f" ω: {d['o_m']:.1f}±{d['o_s']:.2f} [{d['o_min']:.1f},{d['o_max']:.1f}]")

        all_res[tn]=tr

    print(f"\n{'━'*80}\n OOD: [-1,1] → [1,2]\n{'━'*80}")
    ood_r={}; ood_d={}
    for mn,(mc,mk) in models.items():
        kw={k:(20. if v is None else v) for k,v in mk.items()}
        h=find_h(2,1,N,5000,mc,**kw); ids,oods=[],[]
        for seed in SEEDS:
            set_seed(seed); xtr,ytr=d_ood_tr()
            set_seed(seed+50); xid=torch.rand(200,2)*2-1
            yid=(torch.sin(3*math.pi*xid[:,0])*torch.cos(3*math.pi*xid[:,1])+xid[:,0]*xid[:,1]).unsqueeze(1)
            set_seed(seed+50); xoo,yoo=d_ood_te()
            set_seed(seed+100); mdl=mc(2,1,h,N,**kw)
            si=train_reg(mdl,xtr,ytr,xid,yid,300,1e-3)
            mdl.eval()
            with torch.no_grad(): so=F.mse_loss(mdl(xoo),yoo).item()
            ids.append(si); oods.append(so)
            if mn=='v9' and seed==SEEDS[-1]:
                mdl.eval()
                with torch.no_grad():
                    ai,_,oi=mdl.get_all_diag(xid[:100])
                    ao,_,oo2=mdl.get_all_diag(xoo[:100])
                ood_d={
                    'id_a':torch.cat([a.flatten() for a in ai]).mean().item(),
                    'ood_a':torch.cat([a.flatten() for a in ao]).mean().item(),
                    'id_o':torch.cat([o.flatten() for o in oi]).mean().item(),
                    'ood_o':torch.cat([o.flatten() for o in oo2]).mean().item(),
                }
        p=nparams(mc(2,1,h,N,**kw))
        ood_r[mn]={'id':np.mean(ids),'ood':np.mean(oods),'p':p,
                   'deg':np.mean(oods)/max(np.mean(ids),1e-10),
                   'is':np.std(ids),'os':np.std(oods)}

    bo=min(ood_r,key=lambda k:ood_r[k]['ood'])
    print(f"\n {'Model':<10} {'ID MSE':>14} {'OOD MSE':>14} {'Deg':>8}")
    print(f" {'─'*50}")
    for mn,r in ood_r.items():
        print(f" {mn:<10} {r['id']:>9.4f}±{r['is']:.3f} {r['ood']:>9.4f}±{r['os']:.3f} {r['deg']:>7.1f}x{' ★' if mn==bo else ''}")
    print(f" → {bo}")

    if ood_d:
        print(f"\n v9 on OOD:")
        print(f"  α: ID={ood_d['id_a']:.4f} → OOD={ood_d['ood_a']:.4f} (shift={ood_d['ood_a']-ood_d['id_a']:+.4f})")
        print(f"  ω: ID={ood_d['id_o']:.2f} → OOD={ood_d['ood_o']:.2f} (shift={ood_d['ood_o']-ood_d['id_o']:+.2f})")

    all_res['OOD']={mn:{'mean':r['ood'],'std':r['os']} for mn,r in ood_r.items()}

    print(f"\n{'='*80}\n SUMMARY\n{'='*80}")
    wc={k:0 for k in models}
    print(f"\n {'Task':<18}",end="")
    for mn in models: print(f" {mn:>12}",end="")
    print(f" {'W':>8}")
    print(f" {'─'*56}")
    for tn,t in all_res.items():
        sc={k:v['mean'] for k,v in t.items()}
        # heuristic: treat scores in (0.5, 1] with no near-zero entries as accuracies
        mx=max(sc.values()); is_c=mx>.5 and mx<=1 and min(sc.values())>=0
        if min(sc.values())<.001: is_c=False
        w=min(sc,key=sc.get) if (tn=='OOD' or not is_c) else max(sc,key=sc.get)
        wc[w]+=1
        row=f" {tn:<18}"
        for mn in models:
            s=sc[mn]
            if is_c: row+=f" {s:>11.1%}"
            elif s<.001: row+=f" {s:>11.2e}"
            else: row+=f" {s:>11.4f}"
        row+=f" {'->'+w:>8}"; print(row)

    print(f"\n {'─'*56}")
    for mn,c in sorted(wc.items(),key=lambda x:-x[1]):
        print(f" {mn:<10} {c} wins {'█'*c*4}")

    print(f"\n v9 DIAGNOSTICS:")
    print(f" {'Task':<18} {'α':>7} {'α_std':>7} {'%L':>5} {'%P':>5} {'φ_std':>7} {'ω':>7} {'ω_std':>7} {'ω range':>14}")
    print(f" {'─'*80}")
    for tn,d in diag.items():
        print(f" {tn:<18} {d['a_m']:>7.3f} {d['a_s']:>7.3f} {d['a_lo']:>4.0%} {d['a_hi']:>4.0%}"
              f" {d['p_s']:>7.3f} {d['o_m']:>7.1f} {d['o_s']:>7.3f} [{d['o_min']:.1f},{d['o_max']:.1f}]")

    sv={'tasks':{},'ood':{},'diag':diag,'ood_diag':ood_d}
    for tn,t in all_res.items():
        sv['tasks'][tn]={mn:{'mean':float(r['mean']),'std':float(r.get('std',0)),
                             'scores':[float(s) for s in r.get('scores',[r['mean']])],
                             'params':r.get('params',0),'hidden':r.get('hidden',0)} for mn,r in t.items()}
    sv['ood']={mn:{k:float(v) if isinstance(v,(float,np.floating)) else v for k,v in r.items()} for mn,r in ood_r.items()}
    with open('/app/results_v9.json','w') as f: json.dump(sv,f,indent=2,default=str)
    print("\n Saved to /app/results_v9.json")


if __name__=="__main__": main()
|
|