# anshdadhich's picture
# Add v6: adaptive neuron benchmark
# 3511526 verified
#!/usr/bin/env python3
"""
=============================================================================
BENCHMARK v6: THE ADAPTIVE NEURON
=============================================================================
The idea: Don't pick one neuron type. Let each neuron LEARN whether to be
linear or periodic based on the input.
α(x) = sigmoid(W_route · x) # per-input routing
linear_path = W_lin · x # standard linear
periodic_path = (W_gate · x) ⊙ sin(ω·W_per · x) # multiplicative periodic
y = α ⊙ periodic_path + (1-α) ⊙ linear_path # blend
On periodic tasks → α learns to go high → gets the 162× gains
On OOD inputs → α learns to go low → falls back to safe linear behavior
On mixed tasks → different neurons specialize differently
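Extra cost per layer: a low-rank routing head (route_V, route_U), a low-rank periodic delta (U, V) and a phase vector. With the default rank max(2, in_dim // 4) this overhead is comparable to W_lin itself, so every model's hidden width is re-sized to the same parameter budget.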
We test against: Vanilla, SinGLU, Hybrid, and report:
- Performance (3 seeds, mean±std)
- α distribution (what did the routing learn?)
- OOD generalization (does it auto-fallback?)
- Gradient norms (is it stable?) — not reported yet; training currently only clips gradient norms to 1.0
=============================================================================
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import time
import json
# Target device. NOTE(review): DEVICE is not referenced anywhere else in this script.
DEVICE = 'cpu'
# Every model/task combination is trained once per seed; results are reported
# as mean ± std across these runs.
SEEDS = list(range(3))
def set_seed(s):
    """Seed the global torch and numpy RNGs with ``s`` so runs are reproducible."""
    for seeder in (torch.manual_seed, np.random.seed):
        seeder(s)
# ============================================================================
# ALL ARCHITECTURES
# ============================================================================
class VanillaMLP(nn.Module):
    """Plain ReLU multilayer perceptron used as the baseline.

    Architecture: ``n_hidden`` repetitions of ``Linear -> ReLU`` with width
    ``hidden_dim``, followed by a final ``Linear`` projecting to ``out_dim``.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, n_hidden):
        super().__init__()
        widths = [in_dim] + [hidden_dim] * n_hidden
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        modules.append(nn.Linear(widths[-1], out_dim))
        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)
class SinGLULayer(nn.Module):
    """Sine-gated linear unit.

    Computes ``LayerNorm(W_out(sin(omega_0 * W_gate x) * W_val x))``.
    ``W_gate`` uses a SIREN-style uniform init with bound
    ``sqrt(6 / in_dim) / omega_0``; ``W_val`` and ``W_out`` use Xavier-uniform.
    """

    def __init__(self, in_dim, out_dim, mid_dim, omega_0=30.0):
        super().__init__()
        self.W_gate = nn.Linear(in_dim, mid_dim, bias=False)
        self.W_val = nn.Linear(in_dim, mid_dim, bias=False)
        self.W_out = nn.Linear(mid_dim, out_dim, bias=True)
        self.omega_0 = omega_0
        self.ln = nn.LayerNorm(out_dim)
        self._reset_parameters(in_dim)

    def _reset_parameters(self, in_dim):
        """Re-initialise the three projections (SIREN gate, Xavier value/output)."""
        siren_bound = math.sqrt(6.0 / in_dim) / self.omega_0
        with torch.no_grad():
            self.W_gate.weight.uniform_(-siren_bound, siren_bound)
            nn.init.xavier_uniform_(self.W_val.weight)
            nn.init.xavier_uniform_(self.W_out.weight)

    def forward(self, x):
        gate = torch.sin(self.omega_0 * self.W_gate(x))
        gated = gate * self.W_val(x)
        return self.ln(self.W_out(gated))
class SinGLUNet(nn.Module):
    """``n_hidden`` SinGLULayer blocks of width ``hidden_dim`` plus a linear head.

    Each block's inner (gated) width is ``max(2, int(hidden_dim * 2 / 3))``.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, n_hidden, omega_0=30.0):
        super().__init__()
        inner = max(2, int(hidden_dim * 2 / 3))
        dims = [in_dim] + [hidden_dim] * n_hidden
        stack = [
            SinGLULayer(d_in, d_out, inner, omega_0)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        ]
        stack.append(nn.Linear(dims[-1], out_dim))
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        h = x
        for block in self.layers:
            h = block(h)
        return h
class HybridLayer(nn.Module):
    """Multiplicative linear x periodic unit with a residual connection.

    Computes ``LayerNorm(W3(W1 x * sin(omega_0 * W2 x + phase)) + R x)`` where
    ``R`` is a bias-free linear projection when ``in_dim != out_dim`` and the
    identity otherwise. ``W2`` gets a SIREN-style init, ``phase`` is uniform in
    [-pi, pi], and ``W1``/``W3`` are Xavier-uniform.
    """

    def __init__(self, in_dim, out_dim, mid_dim, omega_0=30.0):
        super().__init__()
        self.W1 = nn.Linear(in_dim, mid_dim, bias=False)
        self.W2 = nn.Linear(in_dim, mid_dim, bias=False)
        self.phase = nn.Parameter(torch.empty(mid_dim))
        self.W3 = nn.Linear(mid_dim, out_dim, bias=True)
        self.omega_0 = omega_0
        self.ln = nn.LayerNorm(out_dim)
        if in_dim == out_dim:
            self.residual = nn.Identity()
        else:
            self.residual = nn.Linear(in_dim, out_dim, bias=False)
        siren_bound = math.sqrt(6.0 / in_dim) / omega_0
        with torch.no_grad():
            nn.init.xavier_uniform_(self.W1.weight)
            self.W2.weight.uniform_(-siren_bound, siren_bound)
            self.phase.uniform_(-math.pi, math.pi)
            nn.init.xavier_uniform_(self.W3.weight)

    def forward(self, x):
        carrier = torch.sin(self.omega_0 * self.W2(x) + self.phase)
        mixed = self.W3(self.W1(x) * carrier)
        return self.ln(mixed + self.residual(x))
class HybridNet(nn.Module):
    """``n_hidden`` HybridLayer blocks of width ``hidden_dim`` plus a linear head.

    Each block's inner width is ``max(2, int(hidden_dim * 0.55))``.
    """

    def __init__(self, in_dim, out_dim, hidden_dim, n_hidden, omega_0=30.0):
        super().__init__()
        inner = max(2, int(hidden_dim * 0.55))
        dims = [in_dim] + [hidden_dim] * n_hidden
        stack = [
            HybridLayer(d_in, d_out, inner, omega_0)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        ]
        stack.append(nn.Linear(dims[-1], out_dim))
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        h = x
        for block in self.layers:
            h = block(h)
        return h
# ============================================================================
# THE ADAPTIVE NEURON
# ============================================================================
class AdaptiveNeuronLayer(nn.Module):
    """Neuron layer that learns, per input and per unit, how periodic to be.

    Forward pass, with ``z = W_lin x + b_lin`` computed once and shared:

        alpha(x)      = sigmoid(route_U(route_V x))           # routing in (0, 1)
        linear_path   = ReLU(z)
        periodic_path = z * sin(omega_0 * (z + U V x) + phase) # multiplicative periodic
        y             = LayerNorm(alpha * periodic_path + (1 - alpha) * linear_path)

    Parameter-efficient design: instead of separate full-rank matrices for the
    routing gate and the periodic branch, both use rank-``r`` factorisations,
    where ``r = route_rank`` or, if that is falsy, ``max(2, in_dim // 4)``:

    - ``W_lin`` (with bias): shared backbone. It feeds the linear path and is
      also the "value" and main frequency projection of the periodic path.
    - ``U(V x)``: low-rank delta that lets the periodic branch's frequency
      projection deviate from ``W_lin`` (effective weight ``W_lin.weight + U @ V``).
    - ``route_U(route_V x)`` (bias on ``route_U``): low-rank routing head.
    - ``phase``: per-unit phase offset inside the sine.

    Args:
        in_dim: input feature size.
        out_dim: output feature size (number of adaptive units).
        omega_0: frequency scale applied inside the sine.
        route_rank: rank ``r`` shared by the routing head and the periodic delta.
    """

    def __init__(self, in_dim, out_dim, omega_0=30.0, route_rank=None):
        super().__init__()
        r = route_rank or max(2, in_dim // 4)  # rank of both low-rank factorisations
        # Main linear path (also the shared backbone of the periodic path)
        self.W_lin = nn.Linear(in_dim, out_dim, bias=True)
        # Low-rank delta for the periodic branch: periodic_W = W_lin.weight + U @ V
        self.U = nn.Linear(r, out_dim, bias=False)
        self.V = nn.Linear(in_dim, r, bias=False)
        self.phase = nn.Parameter(torch.empty(out_dim))
        # Routing: small learned gate, alpha = sigmoid(route_U(route_V x))
        self.route_V = nn.Linear(in_dim, r, bias=False)
        self.route_U = nn.Linear(r, out_dim, bias=True)
        self.omega_0 = omega_0
        self.ln = nn.LayerNorm(out_dim)
        with torch.no_grad():
            nn.init.xavier_uniform_(self.W_lin.weight)
            # Periodic delta: SIREN-style scale.
            # NOTE(review): the bound uses in_dim although U's fan-in is r, and at
            # init the sine argument is dominated by the Xavier-initialised W_lin
            # term, so this scaling barely affects the initial frequencies.
            bound = math.sqrt(6.0 / in_dim) / omega_0
            self.U.weight.uniform_(-bound, bound)
            self.V.weight.uniform_(-1.0 / in_dim, 1.0 / in_dim)
            self.phase.uniform_(-math.pi, math.pi)
            # Routing starts exactly balanced: zero weight and bias give
            # sigmoid(0) = 0.5 for every input. route_V gets no gradient until
            # route_U's weight moves away from zero.
            nn.init.zeros_(self.route_U.weight)
            nn.init.zeros_(self.route_U.bias)
            self.route_V.weight.uniform_(-0.01, 0.01)

    def _route(self, x):
        """Return the routing weights alpha(x) = sigmoid(route_U(route_V x))."""
        return torch.sigmoid(self.route_U(self.route_V(x)))

    def forward(self, x):
        # Shared pre-activation; the original code evaluated W_lin(x) twice.
        z = self.W_lin(x)
        lin = F.relu(z)
        # Periodic path: frequency projection is W_lin plus the low-rank delta.
        per = z * torch.sin(self.omega_0 * (z + self.U(self.V(x))) + self.phase)
        alpha = self._route(x)  # input-dependent, per unit
        return self.ln(alpha * per + (1 - alpha) * lin)

    def get_alpha(self, x):
        """For analysis: return the routing values without tracking gradients."""
        with torch.no_grad():
            return self._route(x)
class AdaptiveNet(nn.Module):
    """``n_hidden`` AdaptiveNeuronLayer blocks of width ``hidden_dim`` plus a linear head."""

    def __init__(self, in_dim, out_dim, hidden_dim, n_hidden, omega_0=30.0):
        super().__init__()
        layers = []
        prev = in_dim
        for _ in range(n_hidden):
            layers.append(AdaptiveNeuronLayer(prev, hidden_dim, omega_0=omega_0))
            prev = hidden_dim
        layers.append(nn.Linear(prev, out_dim))
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def get_all_alphas(self, x):
        """Return the routing values of every adaptive layer, in forward order.

        Each entry is the alpha tensor computed on that layer's actual input
        during a forward pass of ``x``. The whole traversal runs under
        ``torch.no_grad()``, so no autograd graph is built even when the caller
        did not disable gradients itself.
        """
        alphas = []
        h = x
        with torch.no_grad():
            for layer in self.layers:
                if isinstance(layer, AdaptiveNeuronLayer):
                    alphas.append(layer.get_alpha(h))
                h = layer(h)
        return alphas
# ============================================================================
# UTILS
# ============================================================================
def count_params(m):
    """Return how many scalar parameters of module ``m`` are trainable."""
    total = 0
    for param in m.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def find_hidden(in_d, out_d, n_h, target_p, model_cls, **kw):
    """Return the hidden width in [2, 512] whose model size is closest to ``target_p``.

    Binary-searches ``hidden_dim`` for ``model_cls(in_d, out_d, hidden_dim, n_h, **kw)``,
    assuming the trainable-parameter count grows monotonically with the width.
    Among the probed widths, the one with the smallest ``|params - target_p|``
    wins. Width 2 is the initial candidate, and ties keep the earlier candidate.

    Each probe builds one throwaway model (this consumes the global RNG). The
    current best's gap is cached rather than rebuilding its model on every
    iteration, as the previous version did.
    """
    def n_params(width):
        return count_params(model_cls(in_d, out_d, width, n_h, **kw))

    lo, hi = 2, 512
    best_h, best_gap = 2, abs(n_params(2) - target_p)
    while lo <= hi:
        mid = (lo + hi) // 2
        p = n_params(mid)
        gap = abs(p - target_p)
        if gap < best_gap:
            best_h, best_gap = mid, gap
        if p < target_p:
            lo = mid + 1
        else:
            hi = mid - 1
    return best_h
def train_reg(model, xtr, ytr, xte, yte, epochs, lr, bs=256):
    """Train ``model`` on an MSE objective and return the lowest test MSE observed.

    Optimisation: Adam with cosine-annealed learning rate (``T_max=epochs``),
    mini-batches of ``bs`` drawn from a fresh permutation each epoch, and
    gradient-norm clipping at 1.0.

    The test set is evaluated every ``max(1, epochs // 10)`` epochs and once
    more after training, and the minimum over those checkpoints is returned.
    This is a best-on-test selection, i.e. an optimistic estimate. The model
    is left in eval mode.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    eval_every = max(1, epochs // 10)
    n_train = len(xtr)

    def test_mse():
        model.eval()
        with torch.no_grad():
            return F.mse_loss(model(xte), yte).item()

    best_mse = float('inf')
    for epoch in range(1, epochs + 1):
        model.train()
        order = torch.randperm(n_train)
        for start in range(0, n_train, bs):
            batch = order[start:start + bs]
            loss = F.mse_loss(model(xtr[batch]), ytr[batch])
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        scheduler.step()
        if epoch % eval_every == 0:
            best_mse = min(best_mse, test_mse())
    return min(best_mse, test_mse())
def train_clf(model, xtr, ytr, xte, yte, epochs, lr, bs=256):
    """Train ``model`` with cross-entropy and return the highest test accuracy observed.

    Optimisation matches ``train_reg``: Adam, cosine-annealed learning rate
    (``T_max=epochs``), per-epoch shuffled mini-batches of ``bs`` and
    gradient-norm clipping at 1.0.

    Accuracy (argmax over dim 1 vs ``yte``) is measured every
    ``max(1, epochs // 10)`` epochs and once after training, and the maximum
    is returned. This is an optimistic, best-on-test estimate. The model is
    left in eval mode.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    eval_every = max(1, epochs // 10)
    n_train = len(xtr)

    def test_accuracy():
        model.eval()
        with torch.no_grad():
            predictions = model(xte).argmax(1)
            return (predictions == yte).float().mean().item()

    best_acc = 0
    for epoch in range(1, epochs + 1):
        model.train()
        order = torch.randperm(n_train)
        for start in range(0, n_train, bs):
            batch = order[start:start + bs]
            loss = F.cross_entropy(model(xtr[batch]), ytr[batch])
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        scheduler.step()
        if epoch % eval_every == 0:
            best_acc = max(best_acc, test_accuracy())
    return max(best_acc, test_accuracy())
# ============================================================================
# DATA
# ============================================================================
def data_complex(n=1000):
    """4-D regression: y = exp(sin(x0² + x1²) + sin(x2² + x3²)).

    Inputs are uniform in [-1, 1); returns ``(x, y)`` with shapes (n, 4) and (n, 1).
    """
    x = torch.rand(n, 4) * 2 - 1
    sq = x ** 2
    target = torch.exp(torch.sin(sq[:, 0] + sq[:, 1]) + torch.sin(sq[:, 2] + sq[:, 3]))
    return x, target.unsqueeze(1)
def data_nested(n=1000):
    """2-D regression: y = sin(π (x0² + x1²)) · cos(3π x0 x1).

    Inputs are uniform in [-1, 1); returns ``(x, y)`` with shapes (n, 2) and (n, 1).
    """
    pts = torch.rand(n, 2) * 2 - 1
    u, v = pts[:, 0], pts[:, 1]
    radial = torch.sin(math.pi * (u**2 + v**2))
    cross = torch.cos(3 * math.pi * u * v)
    return pts, (radial * cross).unsqueeze(1)
def data_spiral(n=1000):
    """Two interleaved noisy spirals for binary classification.

    Each arm has ``n // 2`` points on radii 0.3 → 2 over two turns. The second
    arm is the first rotated by π. Gaussian noise (std 0.05) is added, and the
    samples are shuffled. Returns ``(x, y)`` with ``x`` of shape (2*(n//2), 2)
    and long labels ``y`` (0 for the first arm, 1 for the second).

    For odd ``n`` this returns ``n - 1`` samples. Previously the noise and the
    permutation were sized by ``n`` and the call crashed. For even ``n`` the
    output, including RNG usage, is unchanged.
    """
    half = n // 2
    total = 2 * half
    t = torch.linspace(0, 4 * np.pi, half)
    r = torch.linspace(0.3, 2, half)
    arm0 = torch.stack([r * torch.cos(t), r * torch.sin(t)], 1)
    arm1 = torch.stack([r * torch.cos(t + np.pi), r * torch.sin(t + np.pi)], 1)
    x = torch.cat([arm0, arm1]) + torch.randn(total, 2) * 0.05
    y = torch.cat([torch.zeros(half), torch.ones(half)]).long()
    perm = torch.randperm(total)
    return x[perm], y[perm]
def data_checker(n=1000, freq=3):
    """Checkerboard classification: label 1 where sin(fπ x0)·sin(fπ x1) > 0.

    Inputs are uniform in [-1, 1); returns ``(x, y)`` with ``x`` of shape (n, 2)
    and long labels of shape (n,).
    """
    pts = torch.rand(n, 2) * 2 - 1
    k = freq * math.pi
    product = torch.sin(k * pts[:, 0]) * torch.sin(k * pts[:, 1])
    return pts, (product > 0).long()
def data_highfreq(n=1000):
    """1-D regression on an evenly spaced grid over [-1, 1] (no randomness).

    y = sin(20x) + sin(50x) + 0.5·sin(100x); returns tensors of shape (n, 1).
    """
    grid = torch.linspace(-1, 1, n)[:, None]
    low = torch.sin(20 * grid)
    mid = torch.sin(50 * grid)
    high = 0.5 * torch.sin(100 * grid)
    return grid, low + mid + high
def data_memorize(n=200):
    """Pure memorisation task: random N(0, 1) inputs (n, 8) paired with random targets (n, 4)."""
    inputs = torch.randn(n, 8)
    targets = torch.randn(n, 4)
    return inputs, targets
def data_ood_train(n=800):
    """In-distribution OOD-task data: inputs uniform in [-1, 1)^2.

    Target: y = sin(3π x0)·cos(3π x1) + x0·x1, returned with shape (n, 1).
    """
    pts = torch.rand(n, 2) * 2 - 1
    a, b = pts[:, 0], pts[:, 1]
    target = torch.sin(3 * math.pi * a) * torch.cos(3 * math.pi * b) + a * b
    return pts, target.unsqueeze(1)
def data_ood_test(n=300):
    """Out-of-distribution OOD-task data: inputs uniform in [1, 2)^2.

    Same target as ``data_ood_train``: y = sin(3π x0)·cos(3π x1) + x0·x1, shape (n, 1).
    """
    pts = torch.rand(n, 2) + 1
    a, b = pts[:, 0], pts[:, 1]
    target = torch.sin(3 * math.pi * a) * torch.cos(3 * math.pi * b) + a * b
    return pts, target.unsqueeze(1)
# ============================================================================
# MAIN
# ============================================================================
def main(results_path='/app/results_v6.json'):
    """Run the full v6 benchmark, print a report and save the numbers as JSON.

    Stages:
      1. For each synthetic task, size every architecture to the task's
         parameter budget (``find_hidden``), train it once per seed in
         ``SEEDS`` and report mean ± std. For the Adaptive model, also record
         the mean/std of its routing values alpha on up to 100 test inputs
         (last seed only).
      2. OOD test: train on [-1, 1]^2, measure MSE on [1, 2]^2, and compare the
         Adaptive model's mean alpha on in-distribution vs OOD inputs.
      3. Print the alpha analysis table, a grand summary with per-task
         winners, and win counts.
      4. Write everything to ``results_path``.

    Args:
        results_path: where the JSON results are written. Its parent directory
            is created if missing. The default is the path the script has
            always used.
    """
    import os  # only needed to create the results directory

    print("="*80)
    print(" BENCHMARK v6: THE ADAPTIVE NEURON")
    print(" α(x)·periodic + (1-α(x))·linear — learns WHEN to use each")
    print(" 3 seeds | vs Vanilla, SinGLU, Hybrid | + OOD + α analysis")
    print("="*80)
    N_H = 3
    # {name: (model class, extra kwargs)}. An omega_0 of None is a placeholder
    # replaced by the task's own omega (or by 20.0 in the OOD test).
    models = {
        'Vanilla': (VanillaMLP, {}),
        'SinGLU': (SinGLUNet, {'omega_0': None}),
        'Hybrid': (HybridNet, {'omega_0': None}),
        'Adaptive': (AdaptiveNet, {'omega_0': None}),
    }
    # (name, type, data fn, in_dim, out_dim, param budget, epochs, lr, omega_0,
    #  train/test split index; if it is >= the dataset size, train and test
    #  use the same data)
    tasks = [
        ("Complex Fn (4D)", "reg", data_complex, 4,1, 5000, 300, 1e-3, 30.0, 750),
        ("Nested Fn (2D)", "reg", data_nested, 2,1, 3000, 300, 1e-3, 20.0, 750),
        ("Spiral", "clf", data_spiral, 2,2, 3000, 250, 1e-3, 15.0, 700),
        ("Checkerboard", "clf", data_checker, 2,2, 3000, 250, 1e-3, 20.0, 700),
        ("High-Freq", "reg", data_highfreq, 1,1, 8000, 300, 1e-3, 60.0, 700),
        ("Memorization", "reg", data_memorize, 8,4, 5000, 400, 1e-3, 10.0, 200),
    ]
    all_results = {}
    alpha_analysis = {}
    # 'reg' / 'clf' for every row of all_results. The grand summary uses it to
    # decide whether the lowest (MSE) or the highest (accuracy) score wins.
    task_kinds = {}
    for tname, ttype, dfn, ind, outd, budget, epochs, lr, omega, split in tasks:
        task_kinds[tname] = ttype
        print(f"\n{'━'*80}")
        print(f" {tname} | budget ~{budget:,} | {len(SEEDS)} seeds")
        print(f"{'━'*80}")
        # Find hidden dims
        hdims = {}
        for mn, (mc, mk) in models.items():
            kw = {k: (omega if v is None else v) for k, v in mk.items()}
            hdims[mn] = find_hidden(ind, outd, N_H, budget, mc, **kw)
        task_res = {}
        task_alphas = []
        for mn, (mc, mk) in models.items():
            kw = {k: (omega if v is None else v) for k, v in mk.items()}
            h = hdims[mn]
            scores = []
            for seed in SEEDS:
                set_seed(seed)
                x, y = dfn()
                if split >= len(x):
                    xtr, ytr, xte, yte = x, y, x, y
                else:
                    xtr, ytr, xte, yte = x[:split], y[:split], x[split:], y[split:]
                set_seed(seed + 100)
                model = mc(ind, outd, h, N_H, **kw)
                if ttype == 'reg':
                    s = train_reg(model, xtr, ytr, xte, yte, epochs, lr)
                else:
                    s = train_clf(model, xtr, ytr, xte, yte, epochs, lr)
                scores.append(s)
                # Collect α values from Adaptive model (last seed only)
                if mn == 'Adaptive' and seed == SEEDS[-1] and isinstance(model, AdaptiveNet):
                    model.eval()
                    with torch.no_grad():
                        alphas = model.get_all_alphas(xte[:100])
                    mean_alpha = torch.cat([a.mean(dim=0) for a in alphas]).mean().item()
                    std_alpha = torch.cat([a.std(dim=0) for a in alphas]).mean().item()
                    task_alphas = [mean_alpha, std_alpha]
            # Same architecture/width as every seed's model, so count the last one.
            p = count_params(model)
            task_res[mn] = {
                'mean': np.mean(scores), 'std': np.std(scores),
                'scores': scores, 'params': p, 'hidden': h
            }
        alpha_analysis[tname] = task_alphas
        # Print
        is_reg = ttype == 'reg'
        metric = "MSE ↓" if is_reg else "Acc ↑"
        if is_reg:
            best_mn = min(task_res, key=lambda k: task_res[k]['mean'])
        else:
            best_mn = max(task_res, key=lambda k: task_res[k]['mean'])
        print(f"\n {'Model':<12} {'H':>4} {'Params':>7} {metric+' (mean±std)':>28}")
        print(f" {'─'*56}")
        for mn, r in task_res.items():
            m, s = r['mean'], r['std']
            if is_reg:
                if m < 0.001:
                    ms = f"{m:.2e}±{s:.1e}"
                else:
                    ms = f"{m:.4f}±{s:.4f}"
            else:
                ms = f"{m:.1%}±{s:.3f}"
            mark = " ★" if mn == best_mn else ""
            print(f" {mn:<12} {r['hidden']:>4} {r['params']:>7,} {ms:>28}{mark}")
        print(f" → Winner: {best_mn}")
        if task_alphas:
            print(f" → Adaptive α: mean={task_alphas[0]:.3f}, std={task_alphas[1]:.3f}"
                  f" ({'mostly periodic' if task_alphas[0] > 0.6 else 'mostly linear' if task_alphas[0] < 0.4 else 'balanced'})")
        all_results[tname] = task_res
    # ================================================================
    # OOD TEST
    # ================================================================
    print(f"\n{'━'*80}")
    print(f" OOD: Train [-1,1] → Test [1,2]")
    print(f" f(x1,x2) = sin(3π·x1)·cos(3π·x2) + x1·x2")
    print(f" Does Adaptive auto-fallback to linear on OOD?")
    print(f"{'━'*80}")
    budget_ood = 5000
    ood_res = {}
    ood_alphas = {}
    for mn, (mc, mk) in models.items():
        kw = {k: (20.0 if v is None else v) for k, v in mk.items()}
        h = find_hidden(2, 1, N_H, budget_ood, mc, **kw)
        id_scores, ood_scores = [], []
        for seed in SEEDS:
            set_seed(seed)
            xtr, ytr = data_ood_train()
            set_seed(seed + 50)
            xid = torch.rand(200,2)*2-1
            yid = (torch.sin(3*math.pi*xid[:,0]) * torch.cos(3*math.pi*xid[:,1]) + xid[:,0]*xid[:,1]).unsqueeze(1)
            # NOTE(review): re-seeding with the same value makes the OOD inputs
            # come from the same random stream as the ID inputs, so the two
            # test sets are paired rather than independent samples.
            set_seed(seed + 50)
            xood, yood = data_ood_test()
            set_seed(seed + 100)
            model = mc(2, 1, h, N_H, **kw)
            # s_id is train_reg's best ID MSE over its evaluation checkpoints,
            # while s_ood is measured on the final model only, so the
            # "degradation" ratio below may be somewhat overstated.
            s_id = train_reg(model, xtr, ytr, xid, yid, 300, 1e-3)
            model.eval()
            with torch.no_grad():
                s_ood = F.mse_loss(model(xood), yood).item()
            id_scores.append(s_id)
            ood_scores.append(s_ood)
            # Get α on ID vs OOD for adaptive
            if mn == 'Adaptive' and seed == SEEDS[-1] and isinstance(model, AdaptiveNet):
                model.eval()
                with torch.no_grad():
                    a_id = model.get_all_alphas(xid[:100])
                    a_ood = model.get_all_alphas(xood[:100])
                ood_alphas['ID'] = torch.cat([a.mean(dim=0) for a in a_id]).mean().item()
                ood_alphas['OOD'] = torch.cat([a.mean(dim=0) for a in a_ood]).mean().item()
        p = count_params(model)
        ood_res[mn] = {
            'id_mean': np.mean(id_scores), 'id_std': np.std(id_scores),
            'ood_mean': np.mean(ood_scores), 'ood_std': np.std(ood_scores),
            'params': p,
            'degradation': np.mean(ood_scores) / max(np.mean(id_scores), 1e-10),
        }
    best_ood_mn = min(ood_res, key=lambda k: ood_res[k]['ood_mean'])
    print(f"\n {'Model':<12} {'Params':>7} {'ID MSE':>14} {'OOD MSE':>14} {'Degrad.':>9}")
    print(f" {'─'*58}")
    for mn, r in ood_res.items():
        mark = " ★" if mn == best_ood_mn else ""
        print(f" {mn:<12} {r['params']:>7,} {r['id_mean']:>9.4f}±{r['id_std']:.3f}"
              f" {r['ood_mean']:>9.4f}±{r['ood_std']:.3f} {r['degradation']:>8.1f}x{mark}")
    print(f" → Best OOD: {best_ood_mn}")
    if ood_alphas:
        # 'ID' and 'OOD' are always set together above.
        print(f"\n ADAPTIVE α SHIFT (the key test):")
        print(f" α on in-distribution data: {ood_alphas['ID']:.3f}")
        print(f" α on out-of-distribution: {ood_alphas['OOD']:.3f}")
        shift = ood_alphas['OOD'] - ood_alphas['ID']
        if shift < -0.05:
            print(f" → α DROPPED by {abs(shift):.3f} on OOD → model learned to reduce periodic! ✅")
        elif shift > 0.05:
            print(f" → α INCREASED by {shift:.3f} on OOD → model did NOT learn fallback ❌")
        else:
            print(f" → α roughly stable (shift={shift:+.3f}) → routing not input-sensitive here")
    all_results['OOD'] = {mn: {'mean': r['ood_mean'], 'std': r['ood_std']}
                          for mn, r in ood_res.items()}
    task_kinds['OOD'] = 'reg'
    # ================================================================
    # α ANALYSIS SUMMARY
    # ================================================================
    print(f"\n{'━'*80}")
    print(f" WHAT DID α LEARN? (Routing analysis per task)")
    print(f"{'━'*80}")
    print(f"\n {'Task':<22} {'Mean α':>8} {'Std α':>8} {'Interpretation':<30}")
    print(f" {'─'*70}")
    for tname, aa in alpha_analysis.items():
        if aa:
            interp = ("→ mostly PERIODIC" if aa[0] > 0.6
                      else "→ mostly LINEAR" if aa[0] < 0.4
                      else "→ BALANCED mix")
            print(f" {tname:<22} {aa[0]:>8.3f} {aa[1]:>8.3f} {interp:<30}")
    if ood_alphas:
        id_val = ood_alphas.get('ID', 0)
        ood_val = ood_alphas.get('OOD', 0)
        print(f" {'OOD (out-of-dist)':<22} {ood_val:>8.3f} {'--':>8} "
              f"shifted from {id_val:.3f}")
    # ================================================================
    # GRAND SUMMARY
    # ================================================================
    print(f"\n{'='*80}")
    print(f" GRAND SUMMARY")
    print(f"{'='*80}")
    win_counts = {k: 0 for k in models}
    print(f"\n {'Task':<20}", end="")
    for mn in models:
        print(f" {mn:>12}", end="")
    print(f" {'Winner':>10}")
    print(f" {'─'*72}")
    for tname, tr in all_results.items():
        scores = {k: v['mean'] for k, v in tr.items()}
        # Use the recorded task type. The old heuristic guessed it from the
        # score range and treated a regression task whose MSEs all fell in
        # (0.5, 1.0] as accuracy, crowning the model with the HIGHEST error.
        is_clf = task_kinds.get(tname) == 'clf'
        winner = (max if is_clf else min)(scores, key=scores.get)
        win_counts[winner] += 1
        row = f" {tname:<20}"
        for mn in models:
            s = scores[mn]
            # Width 12 matches the header columns above.
            if is_clf:
                row += f" {s:>12.1%}"
            elif s < 0.001:
                row += f" {s:>12.2e}"
            else:
                row += f" {s:>12.4f}"
        row += f" {'→'+winner:>10}"
        print(row)
    print(f"\n {'─'*72}")
    print(f" WIN COUNTS:")
    for mn, c in sorted(win_counts.items(), key=lambda x: -x[1]):
        print(f" {mn:<14} {c} wins {'█'*c*3}")
    print(f"""
╔════════════════════════════════════════════════════════════════════════════╗
║ THE ADAPTIVE NEURON: VERDICT ║
║ ║
║ The question: can a neuron that LEARNS when to be periodic vs linear ║
║ get the best of both worlds? ║
║ ║
║ Check the α values above — they tell the whole story: ║
║ • If α is high on periodic tasks and low on linear tasks → it works ║
║ • If α is ~0.5 everywhere → routing didn't learn, it's just averaging ║
║ • If α shifts down on OOD → it learned to auto-fallback → that's huge ║
╚════════════════════════════════════════════════════════════════════════════╝
""")
    # Save
    save = {'main_tasks': {}, 'ood': {}, 'alpha_analysis': {}, 'ood_alphas': ood_alphas}
    for tname, tr in all_results.items():
        save['main_tasks'][tname] = {
            mn: {'mean': float(r['mean']), 'std': float(r.get('std', 0)),
                 'scores': [float(s) for s in r.get('scores', [r['mean']])],
                 'params': r.get('params', 0), 'hidden': r.get('hidden', 0)}
            for mn, r in tr.items()
        }
    save['alpha_analysis'] = dict(alpha_analysis)
    save['ood'] = {mn: {k: float(v) if isinstance(v, (float, np.floating)) else v
                        for k, v in r.items()} for mn, r in ood_res.items()}
    results_dir = os.path.dirname(results_path)
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)
    with open(results_path, 'w') as f:
        json.dump(save, f, indent=2, default=str)
    print(f" Results saved to {results_path}")
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not on import.
    main()