anshdadhich committed on
Commit
064558f
·
verified ·
1 Parent(s): 3a13752

Add v8: adaptive phase + amplitude gate

Browse files
Files changed (1) hide show
  1. benchmark_v8.py +564 -0
benchmark_v8.py ADDED
@@ -0,0 +1,564 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ =============================================================================
4
+ BENCHMARK v8: ADAPTIVE PHASE + AMPLITUDE MODULATION
5
+ =============================================================================
6
+
7
+ v7 FAILED because: ω collapsed to a constant. Neural nets refuse to learn
8
+ frequency when adjusting weights is easier.
9
+
10
+ v8 FIX (from GPT's critique):
11
+ Don't learn frequency. Learn PHASE and AMPLITUDE instead.
12
+
13
+ val = W_val · x
14
+ per = sin(ω_fixed · W_per · x + φ(x)) # learned phase, fixed freq
15
+ α = sigmoid(W_gate · x) # learned amplitude gate
16
+ y = LN( val ⊙ (α ⊙ per + (1-α)) + res ) # smooth interpolation
17
+
18
+ Why this works:
19
+ - Phase gradient: d/dφ sin(ωx + φ) = cos(ωx + φ) — stable, bounded
20
+ - Frequency gradient: d/dω sin(ωx) = x·cos(ωx) — oscillatory, unstable
21
+ - Gate gradient: d/dα = (per - 1) — clean signal
22
+
23
+ + Entropy regularization: loss += λ·α(1-α) pushes gate away from 0.5
24
+
25
+ =============================================================================
26
+ """
27
+
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+ import numpy as np
32
+ import math
33
+ import json
34
+
35
+ SEEDS = [0, 1, 2]
36
+
37
def set_seed(s):
    """Seed both the torch and the NumPy global RNG with the value ``s``."""
    for seeder in (torch.manual_seed, np.random.seed):
        seeder(s)
40
+
41
+ # ============================================================================
42
+ # BASELINES (same as before)
43
+ # ============================================================================
44
+
45
class VanillaMLP(nn.Module):
    """Plain ReLU multilayer perceptron used as the non-periodic baseline.

    Stacks ``n_hidden`` ``Linear -> ReLU`` pairs of width ``hidden_dim`` and
    finishes with a linear projection to ``out_dim``.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, n_hidden):
        super().__init__()
        # widths[i] is the fan-in of hidden layer i; the last entry feeds the head.
        widths = [in_dim] + [hidden_dim] * n_hidden
        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        stack.append(nn.Linear(widths[-1], out_dim))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the sequential stack to ``x``."""
        return self.net(x)
56
+
57
class SinGLULayer(nn.Module):
    """Sine-gated linear unit: ``LN(Wo(sin(omega_0 * Wg x) * Wv x))``.

    ``Wg`` and ``Wv`` are bias-free projections to ``mid_dim``; ``Wo`` maps the
    gated product to ``out_dim`` and is followed by a LayerNorm.
    """

    def __init__(self, in_dim, out_dim, mid_dim, omega_0=30.0):
        super().__init__()
        self.Wg = nn.Linear(in_dim, mid_dim, bias=False)
        self.Wv = nn.Linear(in_dim, mid_dim, bias=False)
        self.Wo = nn.Linear(mid_dim, out_dim, bias=True)
        self.omega_0 = omega_0
        self.ln = nn.LayerNorm(out_dim)

        # The sine branch is initialised in a range shrunk by 1/omega_0.
        gate_bound = math.sqrt(6 / in_dim) / omega_0
        with torch.no_grad():
            self.Wg.weight.uniform_(-gate_bound, gate_bound)
            nn.init.xavier_uniform_(self.Wv.weight)
            nn.init.xavier_uniform_(self.Wo.weight)

    def forward(self, x):
        """Return the layer-normalised projection of the sine-gated values."""
        gate = torch.sin(self.omega_0 * self.Wg(x))
        value = self.Wv(x)
        return self.ln(self.Wo(gate * value))
71
+
72
class SinGLUNet(nn.Module):
    """Stack of ``n_hidden`` SinGLU layers followed by a linear output head.

    The gate/value width of every layer is ``max(2, int(hidden_dim * 2 / 3))``.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, n_hidden, omega_0=30.0):
        super().__init__()
        gate_width = max(2, int(hidden_dim * 2 / 3))
        # Fan-in of each hidden layer, plus the fan-in of the output head last.
        widths = [in_dim] + [hidden_dim] * n_hidden
        blocks = [SinGLULayer(fan_in, hidden_dim, gate_width, omega_0)
                  for fan_in in widths[:-1]]
        blocks.append(nn.Linear(widths[-1], out_dim))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through every layer in order."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out
85
+
86
class HybridLayer(nn.Module):
    """Sine-modulated layer with a learned per-unit phase and a residual path.

    Computes ``LN(W3(W1 x * sin(omega_0 * W2 x + phase)) + res(x))`` where
    ``res`` is the identity when ``in_dim == out_dim`` and a bias-free linear
    projection otherwise.
    """

    def __init__(self, in_dim, out_dim, mid_dim, omega_0=30.0):
        super().__init__()
        self.W1 = nn.Linear(in_dim, mid_dim, bias=False)
        self.W2 = nn.Linear(in_dim, mid_dim, bias=False)
        self.phase = nn.Parameter(torch.empty(mid_dim))
        self.W3 = nn.Linear(mid_dim, out_dim, bias=True)
        self.omega_0 = omega_0
        self.ln = nn.LayerNorm(out_dim)
        if in_dim != out_dim:
            self.res = nn.Linear(in_dim, out_dim, bias=False)
        else:
            self.res = nn.Identity()

        # The sine branch is initialised in a range shrunk by 1/omega_0.
        sine_bound = math.sqrt(6 / in_dim) / omega_0
        with torch.no_grad():
            nn.init.xavier_uniform_(self.W1.weight)
            self.W2.weight.uniform_(-sine_bound, sine_bound)
            self.phase.uniform_(-math.pi, math.pi)
            nn.init.xavier_uniform_(self.W3.weight)

    def forward(self, x):
        """Return the normalised sum of the modulated branch and the residual."""
        carrier = torch.sin(self.omega_0 * self.W2(x) + self.phase)
        modulated = self.W3(self.W1(x) * carrier)
        return self.ln(modulated + self.res(x))
103
+
104
class HybridNet(nn.Module):
    """Stack of ``n_hidden`` Hybrid layers followed by a linear output head.

    The inner width of every layer is ``max(2, int(hidden_dim * 0.55))``.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, n_hidden, omega_0=30.0):
        super().__init__()
        inner_width = max(2, int(hidden_dim * 0.55))
        # Fan-in of each hidden layer, plus the fan-in of the output head last.
        widths = [in_dim] + [hidden_dim] * n_hidden
        blocks = [HybridLayer(fan_in, hidden_dim, inner_width, omega_0)
                  for fan_in in widths[:-1]]
        blocks.append(nn.Linear(widths[-1], out_dim))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through every layer in order."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out
117
+
118
+ # ============================================================================
119
+ # v8: ADAPTIVE PHASE + AMPLITUDE GATE
120
+ # ============================================================================
121
+
122
class AdaptivePhaseLayer(nn.Module):
    """Periodic layer with an input-dependent phase and amplitude gate.

        val = W_val · x
        per = sin(omega_0 · W_per · x + phi(x))      <- learned phase, fixed frequency
        alpha = sigmoid(gate(x))                     <- amplitude gate
        y   = LN( val * (alpha * per + (1 - alpha)) + residual )

    ``phi`` and the gate are both low-rank (``in_dim -> r -> out_dim``). When
    ``alpha`` goes to 0 the periodic factor becomes 1 and the layer reduces to
    ``LN(val + residual)``. The up-projections of both heads start at zero, so
    a freshly built layer has ``phi == 0`` and ``alpha == 0.5`` everywhere.

    Args:
        in_dim: input feature size.
        out_dim: output feature size.
        omega_0: fixed frequency multiplier applied to ``W_per x``.
        rank: bottleneck size of the phase and gate heads; a falsy value
            selects ``max(2, min(in_dim // 4, 8))``.
    """

    def __init__(self, in_dim, out_dim, omega_0=30.0, rank=None):
        super().__init__()
        bottleneck = rank or max(2, min(in_dim // 4, 8))

        self.W_val = nn.Linear(in_dim, out_dim, bias=True)
        self.W_per = nn.Linear(in_dim, out_dim, bias=False)

        # Low-rank phase predictor (squashed by tanh in the forward pass).
        self.phi_down = nn.Linear(in_dim, bottleneck, bias=False)
        self.phi_up = nn.Linear(bottleneck, out_dim, bias=True)

        # Low-rank amplitude gate.
        self.gate_down = nn.Linear(in_dim, bottleneck, bias=False)
        self.gate_up = nn.Linear(bottleneck, out_dim, bias=True)

        self.omega_0 = omega_0
        self.ln = nn.LayerNorm(out_dim)
        if in_dim != out_dim:
            self.res = nn.Linear(in_dim, out_dim, bias=False)
        else:
            self.res = nn.Identity()

        per_bound = math.sqrt(6.0 / in_dim) / omega_0
        with torch.no_grad():
            nn.init.xavier_uniform_(self.W_val.weight)
            self.W_per.weight.uniform_(-per_bound, per_bound)
            # Zeroed up-projection: the phase shift is 0 at initialisation.
            nn.init.xavier_uniform_(self.phi_down.weight)
            nn.init.zeros_(self.phi_up.weight)
            nn.init.zeros_(self.phi_up.bias)
            # Zeroed up-projection: sigmoid(0) = 0.5, a balanced gate.
            nn.init.xavier_uniform_(self.gate_down.weight)
            nn.init.zeros_(self.gate_up.weight)
            nn.init.zeros_(self.gate_up.bias)

    def _phase(self, x):
        """Input-dependent phase shift, bounded to (-pi, pi) by tanh."""
        return math.pi * torch.tanh(self.phi_up(self.phi_down(x)))

    def _gate(self, x):
        """Amplitude gate in (0, 1): the share of periodic versus linear signal."""
        return torch.sigmoid(self.gate_up(self.gate_down(x)))

    def forward(self, x):
        """Blend the periodic and linear paths, add the residual, normalise."""
        val = self.W_val(x)
        carrier_in = self.W_per(x)

        phi = self._phase(x)
        per = torch.sin(self.omega_0 * carrier_in + phi)

        alpha = self._gate(x)

        # alpha = 1 -> val * per (fully periodic); alpha = 0 -> val alone.
        blended = val * (alpha * per + (1 - alpha))
        return self.ln(blended + self.res(x))

    def get_diagnostics(self, x):
        """Return ``(alpha, phi)`` for input ``x`` without tracking gradients."""
        with torch.no_grad():
            phi = self._phase(x)
            alpha = self._gate(x)
        return alpha, phi
185
+
186
+
187
class AdaptivePhaseNet(nn.Module):
    """Stack of ``n_hidden`` AdaptivePhaseLayer blocks plus a linear head."""

    def __init__(self, in_dim, out_dim, hidden_dim, n_hidden, omega_0=30.0):
        super().__init__()
        # Fan-in of each hidden layer, plus the fan-in of the output head last.
        widths = [in_dim] + [hidden_dim] * n_hidden
        blocks = [AdaptivePhaseLayer(fan_in, hidden_dim, omega_0)
                  for fan_in in widths[:-1]]
        blocks.append(nn.Linear(widths[-1], out_dim))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed ``x`` through every layer in order."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out

    def get_all_diagnostics(self, x):
        """Collect per-layer ``(alpha, phi)`` tensors while propagating ``x``.

        Returns two lists with one entry per AdaptivePhaseLayer, in layer
        order; each entry is computed on the activations entering that layer.
        """
        alphas, phis = [], []
        hidden = x
        for layer in self.layers:
            if isinstance(layer, AdaptivePhaseLayer):
                alpha, phi = layer.get_diagnostics(hidden)
                alphas.append(alpha)
                phis.append(phi)
            hidden = layer(hidden)
        return alphas, phis

    def entropy_reg(self, x):
        """Gate polarisation penalty: sum over layers of ``mean(alpha * (1 - alpha))``.

        The term peaks at ``alpha == 0.5``, so minimising it pushes the gates
        away from 0.5. Returns the integer ``0`` when the network contains no
        AdaptivePhaseLayer.
        """
        penalty = 0
        hidden = x
        for layer in self.layers:
            if isinstance(layer, AdaptivePhaseLayer):
                gate = torch.sigmoid(layer.gate_up(layer.gate_down(hidden)))
                penalty = penalty + (gate * (1 - gate)).mean()
            hidden = layer(hidden)
        return penalty
224
+
225
+ # ============================================================================
226
+ # UTILS
227
+ # ============================================================================
228
+
229
def count_params(m):
    """Return the number of trainable scalars (``requires_grad``) in module ``m``."""
    trainable = 0
    for param in m.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable
231
+
232
def find_hidden(in_d, out_d, n_h, target_p, model_cls, **kw):
    """Binary-search the hidden width whose parameter count is closest to a budget.

    Args:
        in_d: model input dimension.
        out_d: model output dimension.
        n_h: number of hidden layers passed to ``model_cls``.
        target_p: desired number of trainable parameters.
        model_cls: callable ``model_cls(in_d, out_d, hidden, n_h, **kw)``
            returning a module.
        **kw: extra keyword arguments forwarded to ``model_cls``.

    Returns:
        The width in ``[2, 512]`` visited by the search whose parameter count
        is nearest to ``target_p``. The search assumes the parameter count
        does not decrease as the width grows.
    """
    def params_at(width):
        return count_params(model_cls(in_d, out_d, width, n_h, **kw))

    lo, hi = 2, 512
    best_h = lo
    # Remember the best error instead of rebuilding the best model on every
    # iteration just to recount its parameters.
    best_err = abs(params_at(best_h) - target_p)
    while lo <= hi:
        mid = (lo + hi) // 2
        p = params_at(mid)
        err = abs(p - target_p)
        if err < best_err:
            best_h, best_err = mid, err
        if p < target_p:
            lo = mid + 1
        else:
            hi = mid - 1
    return best_h
242
+
243
def train_reg(model, xtr, ytr, xte, yte, epochs, lr, entropy_lambda=1e-4, bs=256):
    """Train ``model`` with MSE loss and return the best test MSE observed.

    Uses Adam with a cosine-annealed learning rate and gradient-norm clipping
    at 1.0. For an AdaptivePhaseNet with ``entropy_lambda > 0`` the loss also
    includes ``entropy_lambda * model.entropy_reg(batch)``.

    Args:
        model: module to train in place.
        xtr, ytr: training inputs and targets.
        xte, yte: evaluation inputs and targets.
        epochs: number of passes over the training set.
        lr: initial Adam learning rate.
        entropy_lambda: weight of the gate polarisation penalty.
        bs: mini-batch size.

    Returns:
        The lowest test MSE seen at the periodic checkpoints (every
        ``max(1, epochs // 10)`` epochs) and after the final epoch.
        NOTE(review): the minimum is selected on the test split itself.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    use_entropy = isinstance(model, AdaptivePhaseNet) and entropy_lambda > 0
    eval_every = max(1, epochs // 10)
    n = len(xtr)

    def test_mse():
        model.eval()
        with torch.no_grad():
            return F.mse_loss(model(xte), yte).item()

    best = float('inf')
    for ep in range(epochs):
        model.train()
        perm = torch.randperm(n)
        for start in range(0, n, bs):
            idx = perm[start:start + bs]
            bx, by = xtr[idx], ytr[idx]
            loss = F.mse_loss(model(bx), by)
            if use_entropy:
                loss = loss + entropy_lambda * model.entropy_reg(bx)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        scheduler.step()
        if (ep + 1) % eval_every == 0:
            best = min(best, test_mse())
    return min(best, test_mse())
270
+
271
def train_clf(model, xtr, ytr, xte, yte, epochs, lr, entropy_lambda=1e-4, bs=256):
    """Train ``model`` with cross-entropy and return the best test accuracy observed.

    Uses Adam with a cosine-annealed learning rate and gradient-norm clipping
    at 1.0. For an AdaptivePhaseNet with ``entropy_lambda > 0`` the loss also
    includes ``entropy_lambda * model.entropy_reg(batch)``.

    Args:
        model: module to train in place; its output is treated as class logits.
        xtr, ytr: training inputs and integer class labels.
        xte, yte: evaluation inputs and integer class labels.
        epochs: number of passes over the training set.
        lr: initial Adam learning rate.
        entropy_lambda: weight of the gate polarisation penalty.
        bs: mini-batch size.

    Returns:
        The highest test accuracy seen at the periodic checkpoints (every
        ``max(1, epochs // 10)`` epochs) and after the final epoch.
        NOTE(review): the maximum is selected on the test split itself.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    use_entropy = isinstance(model, AdaptivePhaseNet) and entropy_lambda > 0
    eval_every = max(1, epochs // 10)
    n = len(xtr)

    def test_accuracy():
        model.eval()
        with torch.no_grad():
            hits = model(xte).argmax(1) == yte
            return hits.float().mean().item()

    best = 0
    for ep in range(epochs):
        model.train()
        perm = torch.randperm(n)
        for start in range(0, n, bs):
            idx = perm[start:start + bs]
            bx, by = xtr[idx], ytr[idx]
            loss = F.cross_entropy(model(bx), by)
            if use_entropy:
                loss = loss + entropy_lambda * model.entropy_reg(bx)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        scheduler.step()
        if (ep + 1) % eval_every == 0:
            best = max(best, test_accuracy())
    return max(best, test_accuracy())
298
+
299
+ # ============================================================================
300
+ # DATA
301
+ # ============================================================================
302
+
303
def data_complex(n=1000):
    """Sample ``n`` points uniformly in [-1, 1)^4 with a nested exp/sin target.

    Target: ``exp(sin(x0^2 + x1^2) + sin(x2^2 + x3^2))``, returned as a column.
    """
    x = torch.rand(n, 4) * 2 - 1
    first_pair = torch.sin(x[:, 0] ** 2 + x[:, 1] ** 2)
    second_pair = torch.sin(x[:, 2] ** 2 + x[:, 3] ** 2)
    target = torch.exp(first_pair + second_pair)
    return x, target.unsqueeze(1)
307
+
308
def data_nested(n=1000):
    """Sample ``n`` points uniformly in [-1, 1)^2 with a radial-times-cross target.

    Target: ``sin(pi * (x0^2 + x1^2)) * cos(3 * pi * x0 * x1)``, as a column.
    """
    x = torch.rand(n, 2) * 2 - 1
    u, v = x[:, 0], x[:, 1]
    radial = torch.sin(math.pi * (u ** 2 + v ** 2))
    cross = torch.cos(3 * math.pi * u * v)
    return x, (radial * cross).unsqueeze(1)
312
+
313
def data_spiral(n=1000):
    """Build a shuffled two-arm spiral classification set.

    Each arm has ``n // 2`` points with angle in [0, 4*pi] and radius in
    [0.3, 2]; the second arm is the first rotated by pi. Gaussian noise with
    standard deviation 0.05 is added to the coordinates.

    Args:
        n: requested number of samples. The arms are equal-sized, so an odd
            ``n`` yields ``n - 1`` samples (previously an odd ``n`` raised a
            shape error because noise and permutation were sized with ``n``).

    Returns:
        ``(x, y)`` with ``x`` of shape ``(2 * (n // 2), 2)`` and ``y`` an int64
        vector of labels (0 for the first arm, 1 for the second).
    """
    half = n // 2
    total = 2 * half  # actual sample count; equals n for even n
    t = torch.linspace(0, 4 * np.pi, half)
    r = torch.linspace(0.3, 2, half)
    arm_a = torch.stack([r * torch.cos(t), r * torch.sin(t)], 1)
    arm_b = torch.stack([r * torch.cos(t + np.pi), r * torch.sin(t + np.pi)], 1)
    x = torch.cat([arm_a, arm_b]) + torch.randn(total, 2) * 0.05
    y = torch.cat([torch.zeros(half), torch.ones(half)]).long()
    p = torch.randperm(total)
    return x[p], y[p]
320
+
321
def data_checker(n=1000):
    """Sample ``n`` points uniformly in [-1, 1)^2 with checkerboard labels.

    A point is labelled 1 when ``sin(3*pi*x0) * sin(3*pi*x1) > 0``, else 0.
    """
    x = torch.rand(n, 2) * 2 - 1
    wave_0 = torch.sin(3 * math.pi * x[:, 0])
    wave_1 = torch.sin(3 * math.pi * x[:, 1])
    labels = (wave_0 * wave_1) > 0
    return x, labels.long()
325
+
326
def data_highfreq(n=1000):
    """Return ``n`` evenly spaced points on [-1, 1] and a three-tone sine target.

    Target: ``sin(20x) + sin(50x) + 0.5 * sin(100x)``; both outputs are columns.
    """
    grid = torch.linspace(-1, 1, n).unsqueeze(1)
    low = torch.sin(20 * grid)
    mid = torch.sin(50 * grid)
    high = 0.5 * torch.sin(100 * grid)
    return grid, low + mid + high
329
+
330
def data_memorize(n=200):
    """Return ``n`` standard-normal inputs (8-D) and unrelated normal targets (4-D)."""
    inputs = torch.randn(n, 8)
    targets = torch.randn(n, 4)
    return inputs, targets
332
+
333
def data_ood_train(n=800):
    """Sample ``n`` in-distribution points uniformly in [-1, 1)^2.

    Target: ``sin(3*pi*x0) * cos(3*pi*x1) + x0 * x1``, returned as a column.
    """
    x = torch.rand(n, 2) * 2 - 1
    u, v = x[:, 0], x[:, 1]
    periodic = torch.sin(3 * math.pi * u) * torch.cos(3 * math.pi * v)
    bilinear = u * v
    return x, (periodic + bilinear).unsqueeze(1)
337
+
338
def data_ood_test(n=300):
    """Sample ``n`` out-of-distribution points uniformly in [1, 2)^2.

    Target: ``sin(3*pi*x0) * cos(3*pi*x1) + x0 * x1``, returned as a column.
    """
    x = torch.rand(n, 2) + 1
    u, v = x[:, 0], x[:, 1]
    periodic = torch.sin(3 * math.pi * u) * torch.cos(3 * math.pi * v)
    bilinear = u * v
    return x, (periodic + bilinear).unsqueeze(1)
342
+
343
+ # ============================================================================
344
+ # MAIN
345
+ # ============================================================================
346
+
347
def main(out_path='/app/results_v8.json'):
    """Run the v8 benchmark suite, print the report and save the results as JSON.

    For every task each model gets a hidden width matched to the task's
    parameter budget and is trained once per seed in ``SEEDS``. Gate/phase
    diagnostics are recorded for the 'v8:Phase' model on the last seed. An
    out-of-distribution regression test (train on [-1, 1), test on [1, 2))
    follows, then a win-count summary.

    Args:
        out_path: destination of the JSON results file. If it cannot be
            written, the file is saved under the same name in the current
            working directory instead of discarding the finished run.
    """
    from pathlib import Path  # local import keeps the file-level import block untouched

    print("=" * 80)
    print(" BENCHMARK v8: ADAPTIVE PHASE + AMPLITUDE GATE")
    print(" Learn PHASE φ(x) and GATE α(x), NOT frequency ω")
    print(" + entropy regularization to prevent α collapse at 0.5")
    print("=" * 80)

    N_H = 3
    # A value of None in the kwargs template is replaced by the task's omega.
    models = {
        'Vanilla': (VanillaMLP, {}),
        'SinGLU': (SinGLUNet, {'omega_0': None}),
        'Hybrid': (HybridNet, {'omega_0': None}),
        'v8:Phase': (AdaptivePhaseNet, {'omega_0': None}),
    }

    def resolve_kwargs(template, omega):
        return {k: (omega if v is None else v) for k, v in template.items()}

    def flatten_all(tensors):
        return torch.cat([t.flatten() for t in tensors])

    # (name, type, data fn, in_dim, out_dim, param budget, epochs, lr, omega, train split)
    tasks = [
        ("Complex Fn (4D)", "reg", data_complex, 4, 1, 5000, 300, 1e-3, 30.0, 750),
        ("Nested Fn (2D)", "reg", data_nested, 2, 1, 3000, 300, 1e-3, 20.0, 750),
        ("Spiral", "clf", data_spiral, 2, 2, 3000, 250, 1e-3, 15.0, 700),
        ("Checkerboard", "clf", data_checker, 2, 2, 3000, 250, 1e-3, 20.0, 700),
        ("High-Freq", "reg", data_highfreq, 1, 1, 8000, 300, 1e-3, 60.0, 700),
        ("Memorization", "reg", data_memorize, 8, 4, 5000, 400, 1e-3, 10.0, 200),
    ]

    all_results = {}
    diag_data = {}
    # Metric direction per result row; the summary must not guess it from
    # score magnitudes (an MSE in (0.5, 1] looks exactly like an accuracy).
    task_types = {}

    for tname, ttype, dfn, ind, outd, budget, epochs, lr, omega, split in tasks:
        task_types[tname] = ttype
        is_reg = ttype == 'reg'
        trainer = train_reg if is_reg else train_clf

        print(f"\n{'━'*80}")
        print(f" {tname} | budget ~{budget:,}")
        print(f"{'━'*80}")

        task_res = {}
        for mn, (mc, mk) in models.items():
            kw = resolve_kwargs(mk, omega)
            h = find_hidden(ind, outd, N_H, budget, mc, **kw)
            scores = []
            for seed in SEEDS:
                set_seed(seed)
                x, y = dfn()
                if split >= len(x):
                    # No held-out part available: evaluate on the training data.
                    xtr, ytr, xte, yte = x, y, x, y
                else:
                    xtr, ytr, xte, yte = x[:split], y[:split], x[split:], y[split:]
                set_seed(seed + 100)
                model = mc(ind, outd, h, N_H, **kw)
                scores.append(trainer(model, xtr, ytr, xte, yte, epochs, lr))

                # Diagnostics for v8 (last seed only)
                if mn == 'v8:Phase' and seed == SEEDS[-1]:
                    model.eval()
                    with torch.no_grad():
                        alphas, phis = model.get_all_diagnostics(xte[:100])
                    all_a = flatten_all(alphas)
                    all_p = flatten_all(phis)
                    diag_data[tname] = {
                        'alpha_mean': all_a.mean().item(),
                        'alpha_std': all_a.std().item(),
                        'alpha_pct_low': (all_a < 0.3).float().mean().item(),
                        'alpha_pct_high': (all_a > 0.7).float().mean().item(),
                        'phi_mean': all_p.mean().item(),
                        'phi_std': all_p.std().item(),
                    }

            p = count_params(mc(ind, outd, h, N_H, **kw))
            task_res[mn] = {'mean': np.mean(scores), 'std': np.std(scores),
                            'scores': scores, 'params': p, 'hidden': h}

        pick = min if is_reg else max
        best_mn = pick(task_res, key=lambda k: task_res[k]['mean'])
        metric = "MSE ↓" if is_reg else "Acc ↑"

        print(f"\n {'Model':<12} {'H':>4} {'Params':>7} {metric+' (mean±std)':>28}")
        print(f" {'─'*56}")
        for mn, r in task_res.items():
            m, s = r['mean'], r['std']
            if is_reg and m < 0.001:
                ms = f"{m:.2e}±{s:.1e}"
            elif is_reg:
                ms = f"{m:.4f}±{s:.4f}"
            else:
                ms = f"{m:.1%}±{s:.3f}"
            print(f" {mn:<12} {r['hidden']:>4} {r['params']:>7,} {ms:>28}{' ★' if mn==best_mn else ''}")
        print(f" → Winner: {best_mn}")

        if tname in diag_data:
            d = diag_data[tname]
            print(f" → v8 α: mean={d['alpha_mean']:.3f} std={d['alpha_std']:.3f}"
                  f" | {d['alpha_pct_low']:.0%} linear {d['alpha_pct_high']:.0%} periodic")
            print(f" → v8 φ: mean={d['phi_mean']:.3f} std={d['phi_std']:.3f}")

        all_results[tname] = task_res

    # OOD
    print(f"\n{'━'*80}")
    print(f" OOD: Train [-1,1] → Test [1,2]")
    print(f" Does α shift toward linear on OOD?")
    print(f"{'━'*80}")

    ood_res = {}
    ood_diag = {}
    for mn, (mc, mk) in models.items():
        kw = resolve_kwargs(mk, 20.0)
        h = find_hidden(2, 1, N_H, 5000, mc, **kw)
        id_sc, ood_sc = [], []
        for seed in SEEDS:
            set_seed(seed)
            xtr, ytr = data_ood_train()
            # In-distribution evaluation set: same target function as the
            # training data, drawn with a different seed.
            set_seed(seed + 50)
            xid = torch.rand(200, 2) * 2 - 1
            yid = (torch.sin(3 * math.pi * xid[:, 0]) * torch.cos(3 * math.pi * xid[:, 1])
                   + xid[:, 0] * xid[:, 1]).unsqueeze(1)
            set_seed(seed + 50)
            xood, yood = data_ood_test()
            set_seed(seed + 100)
            model = mc(2, 1, h, N_H, **kw)
            s_id = train_reg(model, xtr, ytr, xid, yid, 300, 1e-3)
            model.eval()
            with torch.no_grad():
                s_ood = F.mse_loss(model(xood), yood).item()
            id_sc.append(s_id)
            ood_sc.append(s_ood)
            if mn == 'v8:Phase' and seed == SEEDS[-1]:
                model.eval()
                with torch.no_grad():
                    a_id, _ = model.get_all_diagnostics(xid[:100])
                    a_ood, _ = model.get_all_diagnostics(xood[:100])
                ood_diag = {
                    'id_alpha': flatten_all(a_id).mean().item(),
                    'ood_alpha': flatten_all(a_ood).mean().item(),
                }
        p = count_params(mc(2, 1, h, N_H, **kw))
        ood_res[mn] = {'id': np.mean(id_sc), 'ood': np.mean(ood_sc), 'params': p,
                       'deg': np.mean(ood_sc) / max(np.mean(id_sc), 1e-10),
                       'id_std': np.std(id_sc), 'ood_std': np.std(ood_sc)}

    best_ood = min(ood_res, key=lambda k: ood_res[k]['ood'])
    print(f"\n {'Model':<12} {'ID MSE':>14} {'OOD MSE':>14} {'Degrad.':>9}")
    print(f" {'─'*52}")
    for mn, r in ood_res.items():
        mark = " ★" if mn == best_ood else ""
        print(f" {mn:<12} {r['id']:>9.4f}±{r['id_std']:.3f} {r['ood']:>9.4f}±{r['ood_std']:.3f} {r['deg']:>8.1f}x{mark}")
    print(f" → Best OOD: {best_ood}")

    if ood_diag:
        shift = ood_diag['ood_alpha'] - ood_diag['id_alpha']
        print(f"\n v8 α SHIFT on OOD:")
        print(f" ID: α = {ood_diag['id_alpha']:.4f}")
        print(f" OOD: α = {ood_diag['ood_alpha']:.4f}")
        if shift < -0.03:
            print(f" → α DROPPED by {abs(shift):.4f} → periodic reduced on OOD ✅")
        elif shift > 0.03:
            print(f" → α INCREASED by {shift:.4f} → MORE periodic on OOD ❌")
        else:
            print(f" → α shift = {shift:+.4f} (minimal)")

    all_results['OOD'] = {mn: {'mean': r['ood'], 'std': r['ood_std']} for mn, r in ood_res.items()}
    task_types['OOD'] = 'reg'  # the OOD row holds MSE values: lower is better

    # GRAND SUMMARY
    print(f"\n{'='*80}")
    print(f" GRAND SUMMARY")
    print(f"{'='*80}")

    win_counts = {k: 0 for k in models}
    print(f"\n {'Task':<20}", end="")
    for mn in models:
        print(f" {mn:>12}", end="")
    print(f" {'Winner':>10}")
    print(f" {'─'*72}")

    for tname, tr in all_results.items():
        scores = {k: v['mean'] for k, v in tr.items()}
        is_clf = task_types[tname] == 'clf'
        winner = (max if is_clf else min)(scores, key=scores.get)
        win_counts[winner] += 1
        row = f" {tname:<20}"
        for mn in models:
            s = scores[mn]
            if is_clf:
                row += f" {s:>11.1%}"
            elif s < 0.001:
                row += f" {s:>11.2e}"
            else:
                row += f" {s:>11.4f}"
        row += f" {'->'+winner:>10}"
        print(row)

    print(f"\n {'─'*72}")
    for mn, c in sorted(win_counts.items(), key=lambda x: -x[1]):
        print(f" {mn:<14} {c} wins {'█'*c*3}")

    # DIAGNOSTICS SUMMARY
    print(f"\n{'━'*80}")
    print(f" v8 DIAGNOSTICS: Did phase & gate actually learn?")
    print(f"{'━'*80}")
    print(f"\n {'Task':<22} {'α mean':>7} {'α std':>7} {'%Lin':>6} {'%Per':>6} {'φ std':>7}")
    print(f" {'─'*58}")
    for tname, d in diag_data.items():
        print(f" {tname:<22} {d['alpha_mean']:>7.3f} {d['alpha_std']:>7.3f}"
              f" {d['alpha_pct_low']:>5.0%} {d['alpha_pct_high']:>5.0%} {d['phi_std']:>7.3f}")

    # Verdict banner, drawn programmatically so the frame always lines up.
    verdict_lines = [
        "v8 VERDICT: ADAPTIVE PHASE + AMPLITUDE GATE",
        "",
        "Key questions:",
        "1. Did α polarize (not stuck at 0.5)? Check α_std and %Lin/%Per",
        "2. Did φ vary per input? Check φ_std > 0",
        "3. Did α shift on OOD? Check α shift above",
        "4. Did it beat SinGLU? Check win counts",
    ]
    box_width = 76
    print()
    print("╔" + "═" * box_width + "╗")
    for line in verdict_lines:
        print("║ " + line.ljust(box_width - 2) + " ║")
    print("╚" + "═" * box_width + "╝")
    print()

    save = {'tasks': {}, 'ood': {}, 'diagnostics': diag_data, 'ood_diag': ood_diag}
    for tname, tr in all_results.items():
        save['tasks'][tname] = {
            mn: {'mean': float(r['mean']), 'std': float(r.get('std', 0)),
                 'scores': [float(s) for s in r.get('scores', [r['mean']])],
                 'params': r.get('params', 0), 'hidden': r.get('hidden', 0)}
            for mn, r in tr.items()
        }
    save['ood'] = {mn: {k: float(v) if isinstance(v, (float, np.floating)) else v
                        for k, v in r.items()}
                   for mn, r in ood_res.items()}

    target = Path(out_path)
    try:
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, 'w', encoding='utf-8') as f:
            json.dump(save, f, indent=2, default=str)
    except OSError as err:
        # Do not lose a finished benchmark run because the target is unwritable.
        fallback = Path(target.name)
        print(f" Could not write {target} ({err}); falling back to {fallback}")
        with open(fallback, 'w', encoding='utf-8') as f:
            json.dump(save, f, indent=2, default=str)
        target = fallback
    print(f" Saved to {target}")
562
+
563
+ if __name__ == "__main__":
564
+ main()