#!/usr/bin/env python3
"""
v9: Controlled Frequency + Phase + Gate (the convergent design)

    per  = sin( ω(x) ⊙ W_per·x + φ(x) )
    ω(x) = ω0·(1 + 0.1·tanh(W_ω·x))      frequency, confined to ω0·(0.9, 1.1)
    φ(x) = π·tanh(W_φ·x)                 phase, confined to (-π, π)
    α(x) = sigmoid(W_α·x)                per-unit gate in (0, 1)
    val  = W_val·x                       linear path
    y    = LN( α(x) ⊙ per + (1-α(x)) ⊙ val + residual )

Key vs v7: ω is bounded (±10%), not free
Key vs v8: ω exists (not removed)
Key vs v8: paths are SEPARATED (not entangled as val*(α*per+(1-α)))

The script compares three model families (VanillaMLP, SinGLUNet, v9Net),
each sized to a per-task parameter budget, on several synthetic tasks plus an
out-of-distribution (OOD) regression test. It prints per-task tables, a
summary and v9 gate/phase/frequency diagnostics, and writes everything to
/app/results_v9.json.

NOTE(review): this file reached review with its line structure collapsed and
with a span of code missing (see the note inside find_h); it cannot run until
that span is restored.
"""
import torch, torch.nn as nn, torch.nn.functional as F
import numpy as np, math, json

# Every experiment is repeated once per seed; reported numbers are mean/std over these.
SEEDS = [0, 1, 2]

def set_seed(s):
    """Seed the torch and numpy global RNGs with s (Python's `random` is not seeded)."""
    torch.manual_seed(s); np.random.seed(s)

# ── Baselines ──
class VanillaMLP(nn.Module):
    """ReLU MLP baseline: n hidden (Linear -> ReLU) layers of width h, then Linear(h, d_out)."""
    def __init__(self, d_in, d_out, h, n):
        super().__init__()
        layers = []
        prev = d_in
        for _ in range(n): layers += [nn.Linear(prev, h), nn.ReLU()]; prev = h
        layers.append(nn.Linear(prev, d_out))
        self.net = nn.Sequential(*layers)
    def forward(self, x): return self.net(x)

class SinGLULayer(nn.Module):
    """Sine-gated GLU layer: y = LN( W_o( sin(w0 · W_g x) ⊙ W_v x ) ).

    W_g (no bias) is initialised uniform in ±sqrt(6/d_in)/w0 (the SIREN
    hidden-layer scheme); W_v (no bias) and W_o (with bias) use Xavier-uniform.
    """
    def __init__(self, d_in, d_out, mid, w0=30.):
        super().__init__()
        self.Wg = nn.Linear(d_in, mid, bias=False)  # sine projection
        self.Wv = nn.Linear(d_in, mid, bias=False)  # value projection
        self.Wo = nn.Linear(mid, d_out, bias=True)  # output projection
        self.w0 = w0; self.ln = nn.LayerNorm(d_out)
        with torch.no_grad():
            self.Wg.weight.uniform_(-math.sqrt(6/d_in)/w0, math.sqrt(6/d_in)/w0)
            nn.init.xavier_uniform_(self.Wv.weight)
            nn.init.xavier_uniform_(self.Wo.weight)
    def forward(self, x): return self.ln(self.Wo(torch.sin(self.w0*self.Wg(x))*self.Wv(x)))

class SinGLUNet(nn.Module):
    """n SinGLULayers of width h (inner width max(2, int(2h/3))) followed by a Linear(h, d_out) head."""
    def __init__(self, d_in, d_out, h, n, w0=30.):
        super().__init__()
        mid = max(2, int(h*2/3)); layers = []; prev = d_in
        for _ in range(n): layers.append(SinGLULayer(prev, h, mid, w0)); prev = h
        layers.append(nn.Linear(prev, d_out)); self.layers = nn.ModuleList(layers)
    def forward(self, x):
        for l in self.layers: x = l(x)
        return x

# ── v9: THE CONVERGENT DESIGN ──
class v9Layer(nn.Module):
    """
    Controlled freq + phase + gate + separated paths + residual.

        ω(x) = ω0 · (1 + 0.1·tanh(W_ω·x))      bounded ±10% around ω0
        φ(x) = π · tanh(W_φ·x)                 bounded to (-π, π)
        per  = sin(ω(x) ⊙ W_per·x + φ(x))      full periodic
        val  = W_val·x                         full linear
        α(x) = sigmoid(W_α·x)                  gate
        y    = LN( α⊙per + (1-α)⊙val + res )   SEPARATED paths

    ω, φ and α are each a d_out-wide Linear of the input, so they act
    element-wise per output unit. W_per has no bias; the offset inside sin
    comes from φ(x). res is a bias-free Linear when d_in != d_out, otherwise
    the identity. W_ω, W_φ and W_α are zero-initialised, so at init
    ω = ω0, φ = 0 and α = 0.5 everywhere.
    """
    def __init__(self, d_in, d_out, w0=30.):
        super().__init__()
        self.W_val = nn.Linear(d_in, d_out, bias=True)    # linear path
        self.W_per = nn.Linear(d_in, d_out, bias=False)   # periodic input
        self.W_omega = nn.Linear(d_in, d_out, bias=True)  # frequency mod
        self.W_phi = nn.Linear(d_in, d_out, bias=True)    # phase
        self.W_alpha = nn.Linear(d_in, d_out, bias=True)  # gate
        self.w0 = w0
        self.ln = nn.LayerNorm(d_out)
        # Residual: only project when the width changes.
        self.res = nn.Linear(d_in, d_out, bias=False) if d_in != d_out else nn.Identity()
        with torch.no_grad():
            nn.init.xavier_uniform_(self.W_val.weight)
            # Same ±sqrt(6/d_in)/w0 scheme as SinGLULayer.Wg; the /w0 offsets the
            # ω ≈ w0 factor that multiplies W_per(x) in forward.
            b = math.sqrt(6./d_in)/w0
            self.W_per.weight.uniform_(-b, b)
            # ω: start at ω0 (tanh(0)=0 → ω=ω0)
            nn.init.zeros_(self.W_omega.weight); nn.init.zeros_(self.W_omega.bias)
            # φ: start at 0
            nn.init.zeros_(self.W_phi.weight); nn.init.zeros_(self.W_phi.bias)
            # α: start at 0.5 (sigmoid(0))
            nn.init.zeros_(self.W_alpha.weight); nn.init.zeros_(self.W_alpha.bias)
    def forward(self, x):
        """Return LN(α⊙per + (1-α)⊙val + res(x)); the last dimension of the output is d_out."""
        val = self.W_val(x)
        omega = self.w0 * (1. + 0.1 * torch.tanh(self.W_omega(x)))  # within w0·(0.9, 1.1)
        phi = math.pi * torch.tanh(self.W_phi(x))                  # within (-π, π)
        per = torch.sin(omega * self.W_per(x) + phi)
        alpha = torch.sigmoid(self.W_alpha(x))
        # SEPARATED: α picks between per and val, not val*(α*per+(1-α))
        return self.ln(alpha * per + (1. - alpha) * val + self.res(x))
    def get_diag(self, x):
        """Return (alpha, phi, omega) for input x, computed without gradients.

        Recomputes the same expressions as forward. Note the return order
        (α, φ, ω) differs from the order in which they are computed.
        """
        with torch.no_grad():
            omega = self.w0 * (1. + 0.1 * torch.tanh(self.W_omega(x)))
            phi = math.pi * torch.tanh(self.W_phi(x))
            alpha = torch.sigmoid(self.W_alpha(x))
            return alpha, phi, omega

class v9Net(nn.Module):
    """n v9Layers of width h followed by a Linear(h, d_out) head."""
    def __init__(self, d_in, d_out, h, n, w0=30.):
        super().__init__()
        layers = []; prev = d_in
        for _ in range(n): layers.append(v9Layer(prev, h, w0)); prev = h
        layers.append(nn.Linear(prev, d_out)); self.layers = nn.ModuleList(layers)
    def forward(self, x):
        for l in self.layers: x = l(x)
        return x
    def get_all_diag(self, x):
        """Run x through the net, collecting each v9Layer's (α, φ, ω) on that layer's input.

        Returns three lists (alphas, phis, omegas) with one tensor per v9Layer.
        Only get_diag itself runs under no_grad; callers wrap this method in
        torch.no_grad() (as main does) so the forward passes build no graph.
        """
        alphas, phis, omegas = [], [], []
        h = x
        for l in self.layers:
            if isinstance(l, v9Layer):
                a,p,o = l.get_diag(h); alphas.append(a); phis.append(p); omegas.append(o)
                h = l(h)
            else:
                h = l(h)
        return alphas, phis, omegas
    def gate_reg(self, x):
        """Return the sum over v9Layers of mean((α - 0.5)²), with α evaluated on each layer's input.

        Intended as a gate-polarization term (push α away from 0.5; main's
        banner advertises "λ=5e-4 polarization").
        NOTE(review): minimising this value pulls α *toward* 0.5. It only
        polarizes if the training loss subtracts it (e.g. loss - λ·gate_reg).
        The training loop is not present in this file as received, so confirm
        the sign there. Returns the int 0 if the net contains no v9Layer.
        """
        total = 0; h = x
        for l in self.layers:
            if isinstance(l, v9Layer):
                a = torch.sigmoid(l.W_alpha(h))
                total = total + ((a - 0.5)**2).mean()
                h = l(h)
            else:
                h = l(h)
        return total

# ── Utils ──
# Number of trainable (requires_grad) parameters in module m.
def nparams(m): return sum(p.numel() for p in m.parameters() if p.requires_grad)

def find_h(di, do, n, target, cls, **kw):
    """Search hidden widths in [2, 512] for cls(di, do, h, n, **kw) whose
    trainable-parameter count is close to `target`; main uses the result as
    the model's hidden width.

    NOTE(review): only the first lines of this function survived (see below).
    """
    lo,hi,best = 2,512,2
    while lo<=hi:
        mid=(lo+hi)//2; p=nparams(cls(di,do,mid,n,**kw))
        # ---------------------------------------------------------------------
        # NOTE(review): SOURCE IS DAMAGED HERE. Everything between the "<" that
        # followed `abs(p-target)` and a later ">" was lost (it looks like
        # markup/HTML-tag stripping). Restore the span from version control.
        # Judging from main, it contained at least:
        #   * the rest of find_h (presumably: update `best`, move lo/hi,
        #     `return best`);
        #   * train_reg(model, xtr, ytr, xte, yte, ep, lr) and train_clf with
        #     the same positional signature. Each returns one numeric score
        #     that main appends to a list and averages;
        #   * the task generators d_complex, d_nested, d_spiral, d_checker.
        #     Each is called with no arguments and returns (x, y).
        # The surviving tail `...0).long()` / `return x,y` is presumably the end
        # of d_checker (integer class labels).
        # ---------------------------------------------------------------------
        if abs(p-target)0).long()
    return x,y

def d_highfreq(n=1000):
    # 1-D target sin(20x) + sin(50x) + 0.5·sin(100x) on an ORDERED grid x in [-1, 1], shape (n, 1).
    # NOTE(review): x is sorted, and main splits it unshuffled at sp=700. Training
    # therefore covers x < ~0.40 and testing covers x > ~0.40, which makes the
    # "High-Freq" test an extrapolation test rather than an interpolation test.
    x=torch.linspace(-1,1,n).unsqueeze(1); return x,torch.sin(20*x)+torch.sin(50*x)+.5*torch.sin(100*x)

def d_mem(n=200):
    # Pure memorisation: random inputs (n, 8) with unrelated random targets (n, 4).
    # main uses sp=200 >= n, so the train and test sets are identical.
    return torch.randn(n,8),torch.randn(n,4)

def d_ood_tr(n=800):
    # OOD training set: x uniform in [-1, 1)², y = sin(3πx₀)·cos(3πx₁) + x₀·x₁, returned as (n, 1).
    x=torch.rand(n,2)*2-1; y=torch.sin(3*math.pi*x[:,0])*torch.cos(3*math.pi*x[:,1])+x[:,0]*x[:,1]
    return x,y.unsqueeze(1)

def d_ood_te(n=300):
    # OOD test set: same target function, with x uniform in [1, 2)², outside the training range.
    x=torch.rand(n,2)+1; y=torch.sin(3*math.pi*x[:,0])*torch.cos(3*math.pi*x[:,1])+x[:,0]*x[:,1]
    return x,y.unsqueeze(1)

# ── Main ──
def main():
    """Run all benchmarks, print per-task tables and a summary, and save JSON.

    For each task in `tasks`, every model family is sized by find_h to the
    task's parameter budget and trained once per seed in SEEDS. It is scored
    on x[sp:], or on the training data itself when sp >= the dataset size.
    An OOD regression test then trains on [-1,1)² and evaluates on [1,2)².
    All results are written to /app/results_v9.json.
    """
    print("="*80)
    print(" v9: CONTROLLED FREQ + PHASE + GATE (separated paths)")
    print(" ω(x) = ω0·(1+0.1·tanh(W_ω·x)) | φ(x) = π·tanh(W_φ·x)")
    print(" y = LN( α⊙per + (1-α)⊙val + res ) | λ=5e-4 polarization")
    print("="*80)
    N=3  # depth passed as `n` to every model constructor
    # name -> (class, extra kwargs). A None value is replaced by the task's w0.
    models = {
        'Vanilla': (VanillaMLP, {}),
        'SinGLU': (SinGLUNet, {'w0':None}),
        'v9': (v9Net, {'w0':None}),
    }
    # (name, "reg"/"clf", data fn, d_in, d_out, param budget, ep (presumably epochs), lr, w0, split index sp)
    tasks = [
        ("Complex Fn", "reg", d_complex, 4,1, 5000, 300, 1e-3, 30., 750),
        ("Nested Fn", "reg", d_nested, 2,1, 3000, 300, 1e-3, 20., 750),
        ("Spiral", "clf", d_spiral, 2,2, 3000, 250, 1e-3, 15., 700),
        ("Checkerboard","clf", d_checker, 2,2, 3000, 250, 1e-3, 20., 700),
        ("High-Freq", "reg", d_highfreq, 1,1, 8000, 300, 1e-3, 60., 700),
        ("Memorize", "reg", d_mem, 8,4, 5000, 400, 1e-3, 10., 200),
    ]
    all_res = {}; diag = {}
    for tn,tt,df,di,do,bud,ep,lr,w0,sp in tasks:
        print(f"\n{'━'*80}\n {tn} | ~{bud:,} params\n{'━'*80}")
        # Choose each model's hidden width so its parameter count is near the budget.
        hs={}
        for mn,(mc,mk) in models.items():
            kw={k:(w0 if v is None else v) for k,v in mk.items()}
            hs[mn]=find_h(di,do,N,bud,mc,**kw)
        tr={}
        for mn,(mc,mk) in models.items():
            kw={k:(w0 if v is None else v) for k,v in mk.items()}
            h=hs[mn]; scores=[]
            for seed in SEEDS:
                # Data generation is seeded with `seed`; model init with `seed+100`.
                set_seed(seed); x,y=df()
                # sp >= len(x) means "evaluate on the training data" (used by Memorize).
                # NOTE(review): the split is not shuffled. For d_highfreq (a sorted
                # linspace) the test set is the right-hand end of the interval.
                if sp>=len(x): xtr,ytr,xte,yte=x,y,x,y
                else: xtr,ytr,xte,yte=x[:sp],y[:sp],x[sp:],y[sp:]
                set_seed(seed+100); mdl=mc(di,do,h,N,**kw)
                if tt=='reg': s=train_reg(mdl,xtr,ytr,xte,yte,ep,lr)
                else: s=train_clf(mdl,xtr,ytr,xte,yte,ep,lr)
                scores.append(s)
                # v9 diagnostics: last seed only, on the first 100 evaluation points.
                if mn=='v9' and seed==SEEDS[-1]:
                    mdl.eval()
                    with torch.no_grad():
                        als,phs,oms=mdl.get_all_diag(xte[:100])
                    aa=torch.cat([a.flatten() for a in als])
                    pp=torch.cat([p.flatten() for p in phs])
                    oo=torch.cat([o.flatten() for o in oms])
                    # a_lo / a_hi: fraction of gate values < 0.3 (mostly linear path)
                    # and > 0.7 (mostly periodic path).
                    diag[tn]={
                        'a_m':aa.mean().item(),'a_s':aa.std().item(),
                        'a_lo':(aa<.3).float().mean().item(),'a_hi':(aa>.7).float().mean().item(),
                        'p_s':pp.std().item(),'p_m':pp.mean().item(),
                        'o_m':oo.mean().item(),'o_s':oo.std().item(),
                        'o_min':oo.min().item(),'o_max':oo.max().item(),
                    }
            # Builds a fresh model only to count its parameters (same h and kwargs as trained).
            p=nparams(mc(di,do,h,N,**kw))
            tr[mn]={'mean':np.mean(scores),'std':np.std(scores),'scores':scores,'params':p,'hidden':h}
        # Regression: lowest mean wins. Classification: highest mean wins.
        is_reg=tt=='reg'
        if is_reg: best=min(tr,key=lambda k:tr[k]['mean'])
        else: best=max(tr,key=lambda k:tr[k]['mean'])
        met="MSE ↓" if is_reg else "Acc ↑"
        print(f"\n {'Model':<10} {'H':>4} {'P':>6} {met+' (mean±std)':>26}")
        print(f" {'─'*50}")
        for mn,r in tr.items():
            m,s=r['mean'],r['std']
            # Tiny MSEs in scientific notation; accuracies as percentages.
            ms=f"{m:.2e}±{s:.1e}" if(is_reg and m<.001) else(f"{m:.4f}±{s:.4f}" if is_reg else f"{m:.1%}±{s:.3f}")
            print(f" {mn:<10} {r['hidden']:>4} {r['params']:>6,} {ms:>26}{' ★' if mn==best else ''}")
        print(f" → {best}")
        if tn in diag:
            d=diag[tn]
            print(f" α: {d['a_m']:.3f}±{d['a_s']:.3f} ({d['a_lo']:.0%} lin, {d['a_hi']:.0%} per)")
            print(f" φ: std={d['p_s']:.3f}")
            print(f" ω: {d['o_m']:.1f}±{d['o_s']:.2f} [{d['o_min']:.1f},{d['o_max']:.1f}]")
        all_res[tn]=tr
    # OOD
    # Train on d_ood_tr ([-1,1)²). Report MSE on a fresh in-range set (ID) and on
    # [1,2)² (OOD). w0 is fixed at 20 for both sine models here.
    print(f"\n{'━'*80}\n OOD: [-1,1] → [1,2]\n{'━'*80}")
    ood_r={}; ood_d={}
    for mn,(mc,mk) in models.items():
        kw={k:(20. if v is None else v) for k,v in mk.items()}
        h=find_h(2,1,N,5000,mc,**kw); ids,oods=[],[]
        for seed in SEEDS:
            set_seed(seed); xtr,ytr=d_ood_tr()
            # NOTE(review): xid and, below, xoo are both drawn right after
            # set_seed(seed+50). The OOD inputs therefore likely reuse the same
            # uniform draws as xid (an affine image of it) rather than an
            # independent sample. A distinct seed for xoo would avoid this.
            set_seed(seed+50); xid=torch.rand(200,2)*2-1
            yid=(torch.sin(3*math.pi*xid[:,0])*torch.cos(3*math.pi*xid[:,1])+xid[:,0]*xid[:,1]).unsqueeze(1)
            set_seed(seed+50); xoo,yoo=d_ood_te()
            set_seed(seed+100); mdl=mc(2,1,h,N,**kw)
            # si is whatever train_reg returns for the (xid, yid) evaluation set, presumably ID MSE.
            si=train_reg(mdl,xtr,ytr,xid,yid,300,1e-3)
            mdl.eval()
            with torch.no_grad(): so=F.mse_loss(mdl(xoo),yoo).item()
            ids.append(si); oods.append(so)
            if mn=='v9' and seed==SEEDS[-1]:
                mdl.eval()
                with torch.no_grad():
                    ai,_,oi=mdl.get_all_diag(xid[:100])
                    ao,_,oo2=mdl.get_all_diag(xoo[:100])
                # Mean gate value and mean frequency on ID vs OOD inputs.
                ood_d={
                    'id_a':torch.cat([a.flatten() for a in ai]).mean().item(),
                    'ood_a':torch.cat([a.flatten() for a in ao]).mean().item(),
                    'id_o':torch.cat([o.flatten() for o in oi]).mean().item(),
                    'ood_o':torch.cat([o.flatten() for o in oo2]).mean().item(),
                }
        p=nparams(mc(2,1,h,N,**kw))
        # deg = mean OOD MSE / mean ID MSE (degradation factor; denominator floored at 1e-10).
        ood_r[mn]={'id':np.mean(ids),'ood':np.mean(oods),'p':p,
                   'deg':np.mean(oods)/max(np.mean(ids),1e-10),
                   'is':np.std(ids),'os':np.std(oods)}
    bo=min(ood_r,key=lambda k:ood_r[k]['ood'])
    print(f"\n {'Model':<10} {'ID MSE':>14} {'OOD MSE':>14} {'Deg':>8}")
    print(f" {'─'*50}")
    for mn,r in ood_r.items(): print(f" {mn:<10} {r['id']:>9.4f}±{r['is']:.3f} {r['ood']:>9.4f}±{r['os']:.3f} {r['deg']:>7.1f}x{' ★' if mn==bo else ''}")
    print(f" → {bo}")
    if ood_d:
        print(f"\n v9 on OOD:")
        print(f" α: ID={ood_d['id_a']:.4f} → OOD={ood_d['ood_a']:.4f} (shift={ood_d['ood_a']-ood_d['id_a']:+.4f})")
        print(f" ω: ID={ood_d['id_o']:.2f} → OOD={ood_d['ood_o']:.2f} (shift={ood_d['ood_o']-ood_d['id_o']:+.2f})")
    all_res['OOD']={mn:{'mean':r['ood'],'std':r['os']} for mn,r in ood_r.items()}
    # Summary
    print(f"\n{'='*80}\n SUMMARY\n{'='*80}")
    wc={k:0 for k in models}  # win count per model
    print(f"\n {'Task':<18}",end="")
    for mn in models: print(f" {mn:>12}",end="")
    print(f" {'W':>8}")
    print(f" {'─'*56}")
    for tn,t in all_res.items():
        sc={k:v['mean'] for k,v in t.items()}
        # Guess "classification" from the score range alone: every score in
        # [0.001, 1] and the max > 0.5. Classification -> highest wins;
        # otherwise (and always for OOD) lowest wins.
        # NOTE(review): this heuristic can mislabel tasks. A regression task whose
        # mean MSEs all lie in [0.001, 1] with max > 0.5 is treated as
        # classification: the *worst* MSE "wins" and is printed as a percentage.
        # OOD MSEs in that range would also print as percentages. A classification
        # task where every model's mean accuracy is <= 0.5 is treated as
        # regression, so the *worst* accuracy wins. Recording the task type `tt`
        # alongside all_res and deciding from it would be robust.
        mx=max(sc.values()); is_c=mx>.5 and mx<=1 and min(sc.values())>=0
        if min(sc.values())<.001: is_c=False
        w=min(sc,key=sc.get) if (tn=='OOD' or not is_c) else max(sc,key=sc.get)
        wc[w]+=1
        row=f" {tn:<18}"
        for mn in models:
            s=sc[mn]
            if is_c: row+=f" {s:>11.1%}"
            elif s<.001: row+=f" {s:>11.2e}"
            else: row+=f" {s:>11.4f}"
        row+=f" {'->'+w:>8}"; print(row)
    print(f"\n {'─'*56}")
    for mn,c in sorted(wc.items(),key=lambda x:-x[1]): print(f" {mn:<10} {c} wins {'█'*c*4}")
    # Diag summary
    print(f"\n v9 DIAGNOSTICS:")
    print(f" {'Task':<18} {'α':>7} {'α_std':>7} {'%L':>5} {'%P':>5} {'φ_std':>7} {'ω':>7} {'ω_std':>7} {'ω range':>14}")
    print(f" {'─'*80}")
    for tn,d in diag.items():
        print(f" {tn:<18} {d['a_m']:>7.3f} {d['a_s']:>7.3f} {d['a_lo']:>4.0%} {d['a_hi']:>4.0%}"
              f" {d['p_s']:>7.3f} {d['o_m']:>7.1f} {d['o_s']:>7.3f} [{d['o_min']:.1f},{d['o_max']:.1f}]")
    # Persist results. The .get defaults cover the 'OOD' entry of all_res,
    # which only has 'mean' and 'std'.
    sv={'tasks':{},'ood':{},'diag':diag,'ood_diag':ood_d}
    for tn,t in all_res.items():
        sv['tasks'][tn]={mn:{'mean':float(r['mean']),'std':float(r.get('std',0)),
                             'scores':[float(s) for s in r.get('scores',[r['mean']])],
                             'params':r.get('params',0),'hidden':r.get('hidden',0)} for mn,r in t.items()}
    sv['ood']={mn:{k:float(v) if isinstance(v,(float,np.floating)) else v for k,v in r.items()} for mn,r in ood_r.items()}
    # NOTE: default=str silently stringifies any value json cannot encode natively.
    with open('/app/results_v9.json','w') as f: json.dump(sv,f,indent=2,default=str)
    print("\n Saved to /app/results_v9.json")

if __name__=="__main__": main()