v4.2: Add PyTorch learned conv solver (single+two-layer, multi-seed, ternary snap). Needs GPU to be practical - use on Kaggle with --conv_budget 60
Browse files- neurogolf_solver.py +220 -0
neurogolf_solver.py
CHANGED
|
@@ -1261,6 +1261,220 @@ def solve_conv_var_diff(td, path, time_budget=30.0):
|
|
| 1261 |
if validate(path, td): return 'conv_var_diff', model
|
| 1262 |
return None
|
| 1263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1264 |
# ============================================================
|
| 1265 |
# MAIN
|
| 1266 |
# ============================================================
|
|
@@ -1316,6 +1530,12 @@ def solve_task(tn, td, outdir, conv_budget=30.0):
|
|
| 1316 |
if result is not None:
|
| 1317 |
sname, model = result
|
| 1318 |
return True, sname, os.path.getsize(path), time.time() - t_start, path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1319 |
else:
|
| 1320 |
sp = fixed_shapes(td)
|
| 1321 |
if sp is not None:
|
|
|
|
| 1261 |
if validate(path, td): return 'conv_var_diff', model
|
| 1262 |
return None
|
| 1263 |
|
| 1264 |
+
# ============================================================
|
| 1265 |
+
# PYTORCH LEARNED CONV (gradient descent, multi-seed, ternary snap)
|
| 1266 |
+
# ============================================================
|
| 1267 |
+
|
| 1268 |
+
def _ternary_snap(w, eps=0.2):
|
| 1269 |
+
"""Snap weights to {-1, 0, 1} — smaller model, often still correct."""
|
| 1270 |
+
return np.where(w > eps, 1.0, np.where(w < -eps, -1.0, 0.0)).astype(np.float32)
|
| 1271 |
+
|
| 1272 |
+
def _build_conv_onnx_from_weights(W, ks, use_full_30=False, IH=None, IW=None):
    """Build ONNX conv model from numpy weight array W [10,10,ks,ks].

    Two graph layouts, chosen by ``use_full_30``:
      * variable shape: Conv over the full canvas, ArgMax, one-hot, then a
        Mul with a channel-sum mask of the input;
      * fixed shape (needs IH/IW): Slice the input down to IH x IW, Conv,
        ArgMax, one-hot, then Pad back out to GH x GW.
    """
    half = ks // 2  # SAME padding for an odd kernel

    if use_full_30:
        # Variable-shape path. The mask is the channel-wise sum of the input
        # (presumably 1 on real cells, 0 on padding — depends on to_onehot;
        # TODO confirm) and is multiplied onto the one-hot prediction.
        graph_inits = [numpy_helper.from_array(W, 'W')]
        graph_nodes = [
            helper.make_node('ReduceSum', ['input'], ['mask'], axes=[1], keepdims=1),
            helper.make_node('Conv', ['input', 'W'], ['co'], kernel_shape=[ks, ks], pads=[half] * 4),
            helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
        ]
        add_onehot_block(graph_nodes, graph_inits, 'am', 'oh_out')
        graph_nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
        return mk(graph_nodes, graph_inits)

    # Fixed-shape path: crop, convolve, then zero-pad back to the full canvas.
    grow_h = GH - IH
    grow_w = GW - IW
    graph_inits = [
        numpy_helper.from_array(np.array([0, 0, 0, 0], dtype=np.int64), 'sl_st'),
        numpy_helper.from_array(np.array([1, 10, IH, IW], dtype=np.int64), 'sl_en'),
        numpy_helper.from_array(W, 'W'),
    ]
    graph_nodes = [
        helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['grid']),
        helper.make_node('Conv', ['grid', 'W'], ['co'], kernel_shape=[ks, ks], pads=[half] * 4),
        helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
    ]
    add_onehot_block(graph_nodes, graph_inits, 'am', 'oh_out')
    # Pad only the trailing (end) sides of H and W back to GH x GW.
    # NOTE(review): attribute-form Pad is an old-opset signature — assumes
    # mk() targets an opset where this is valid; confirm against mk().
    graph_nodes.append(
        helper.make_node('Pad', ['oh_out'], ['output'],
                         pads=[0, 0, 0, 0, 0, 0, grow_h, grow_w], value=0.0)
    )
    return mk(graph_nodes, graph_inits)
|
| 1307 |
+
|
| 1308 |
+
def _build_two_layer_conv_onnx(W1, W2, ks1, ks2, use_full_30=False, IH=None, IW=None):
    """Build ONNX two-layer conv: Conv→ReLU→Conv→ArgMax→Equal+Cast→Pad/Mul(mask).

    Same two layouts as the single-layer builder: a masked full-canvas graph
    when ``use_full_30`` is set, otherwise a Slice/Pad graph for a fixed
    IH x IW input padded back out to GH x GW.
    """
    half1 = ks1 // 2
    half2 = ks2 // 2

    if use_full_30:
        # Variable-shape path: mask the one-hot output with the input's
        # channel-sum so cells outside the grid stay zero.
        graph_inits = [
            numpy_helper.from_array(W1, 'W1'),
            numpy_helper.from_array(W2, 'W2'),
        ]
        graph_nodes = [
            helper.make_node('ReduceSum', ['input'], ['mask'], axes=[1], keepdims=1),
            helper.make_node('Conv', ['input', 'W1'], ['h1'], kernel_shape=[ks1, ks1], pads=[half1] * 4),
            helper.make_node('Relu', ['h1'], ['h1r']),
            helper.make_node('Conv', ['h1r', 'W2'], ['co'], kernel_shape=[ks2, ks2], pads=[half2] * 4),
            helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
        ]
        add_onehot_block(graph_nodes, graph_inits, 'am', 'oh_out')
        graph_nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
        return mk(graph_nodes, graph_inits)

    # Fixed-shape path: crop to IH x IW, run both conv layers, pad back.
    grow_h = GH - IH
    grow_w = GW - IW
    graph_inits = [
        numpy_helper.from_array(np.array([0, 0, 0, 0], dtype=np.int64), 'sl_st'),
        numpy_helper.from_array(np.array([1, 10, IH, IW], dtype=np.int64), 'sl_en'),
        numpy_helper.from_array(W1, 'W1'),
        numpy_helper.from_array(W2, 'W2'),
    ]
    graph_nodes = [
        helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['grid']),
        helper.make_node('Conv', ['grid', 'W1'], ['h1'], kernel_shape=[ks1, ks1], pads=[half1] * 4),
        helper.make_node('Relu', ['h1'], ['h1r']),
        helper.make_node('Conv', ['h1r', 'W2'], ['co'], kernel_shape=[ks2, ks2], pads=[half2] * 4),
        helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
    ]
    add_onehot_block(graph_nodes, graph_inits, 'am', 'oh_out')
    # Pad only the trailing sides of H and W back to the full canvas.
    graph_nodes.append(
        helper.make_node('Pad', ['oh_out'], ['output'],
                         pads=[0, 0, 0, 0, 0, 0, grow_h, grow_w], value=0.0)
    )
    return mk(graph_nodes, graph_inits)
|
| 1347 |
+
|
| 1348 |
+
def solve_pytorch_conv(td, path, time_budget=30.0):
    """PyTorch gradient descent conv solver. Tries single-layer then two-layer.
    Multi-seed training with ternary weight snapping for smaller models.
    Validates against arc-gen before accepting.

    Args:
        td: task dict with 'train' and 'test' lists of {'input','output'} grids.
        path: filesystem path candidate ONNX models are written to (reused for
            every candidate; the winning model is re-saved at the end).
        time_budget: soft wall-clock limit in seconds, checked between runs.

    Returns:
        (solver_name, onnx_model) for the smallest validated model, or None
        when torch is unavailable, shapes don't match, or nothing validates.
    """
    try:
        import torch
        import torch.nn as nn
        import copy as _copy
    except ImportError:
        # torch not installed — this solver silently opts out.
        return None

    exs = get_exs(td)
    same_shape = all(inp.shape == out.shape for inp, out in exs)
    if not same_shape:
        return None  # Only handle same-shape for now

    shapes = set(inp.shape for inp, _ in exs)
    fixed_in = len(shapes) == 1  # True when every example shares one input shape

    # Prepare tensors.
    # NOTE(review): trains on train+test pairs together — presumably intended
    # for this competition's setup (test outputs are available); confirm.
    all_pairs = td['train'] + td['test']
    inp_list = [to_onehot(p['input'])[0] for p in all_pairs]
    out_list = [to_onehot(p['output'])[0] for p in all_pairs]
    inp_t = torch.tensor(np.stack(inp_list), dtype=torch.float32)
    out_t = torch.tensor(np.stack(out_list), dtype=torch.float32)

    if fixed_in:
        IH, IW = list(shapes)[0]
        # Train on cropped region
        inp_t = inp_t[:, :, :IH, :IW]
        out_t = out_t[:, :, :IH, :IW]

    t_start = time.time()
    best_result = None  # (solver_name, model, byte_size) of smallest valid model

    # Phase 1: Single-layer conv (multiple kernel sizes and seeds)
    for ks in [1, 3, 5, 7]:
        if time.time() - t_start > time_budget * 0.6:
            break  # keep the remaining ~40% of the budget for phase 2
        pad = ks // 2
        for seed in [0, 7, 42]:
            if time.time() - t_start > time_budget * 0.6:
                break
            torch.manual_seed(seed)
            conv = nn.Conv2d(CH, CH, kernel_size=ks, padding=pad, bias=False)
            if seed == 0:
                # First attempt starts from all-zero weights (deterministic
                # baseline); the other seeds use torch's default init.
                nn.init.zeros_(conv.weight)
            opt = torch.optim.Adam(conv.parameters(), lr=0.03)
            best_loss, best_state = float('inf'), None
            for step in range(3000):
                opt.zero_grad()
                pred = conv(inp_t)
                loss = nn.functional.mse_loss(pred, out_t)
                loss.backward()
                opt.step()
                if loss.item() < best_loss:
                    best_loss = loss.item()
                    # Snapshot the best-so-far weights; deepcopy because
                    # state_dict tensors are live views of the parameters.
                    best_state = _copy.deepcopy(conv.state_dict())
                if best_loss < 1e-8:
                    break  # effectively an exact fit — stop early
            if best_state is None:
                continue
            conv.load_state_dict(best_state)
            w = conv.weight.detach().numpy()

            # Try continuous weights, then ternary-snapped
            for w_cand in [w, _ternary_snap(w)]:
                use_full = not fixed_in
                model = _build_conv_onnx_from_weights(
                    w_cand, ks, use_full_30=use_full,
                    IH=IH if fixed_in else None,
                    IW=IW if fixed_in else None
                )
                onnx.save(model, path)
                if validate(path, td):
                    # Keep only the smallest validated model across all tries.
                    sz = os.path.getsize(path)
                    if best_result is None or sz < best_result[2]:
                        best_result = ('pt_conv', model, sz)

    # Phase 2: Two-layer conv (Conv→ReLU→Conv)
    for ks1, ks2, hidden in [(3, 1, CH), (5, 1, CH), (3, 3, CH)]:
        if time.time() - t_start > time_budget:
            break
        for seed in [0, 7]:
            if time.time() - t_start > time_budget:
                break
            torch.manual_seed(seed)
            net = nn.Sequential(
                nn.Conv2d(CH, hidden, kernel_size=ks1, padding=ks1//2, bias=False),
                nn.ReLU(),
                nn.Conv2d(hidden, CH, kernel_size=ks2, padding=ks2//2, bias=False),
            )
            opt = torch.optim.Adam(net.parameters(), lr=0.01)
            best_loss, best_state = float('inf'), None
            for step in range(2500):
                opt.zero_grad()
                pred = net(inp_t)
                loss = nn.functional.mse_loss(pred, out_t)
                loss.backward()
                opt.step()
                if loss.item() < best_loss:
                    best_loss = loss.item()
                    best_state = _copy.deepcopy(net.state_dict())
                if best_loss < 1e-8:
                    break
            if best_state is None:
                continue
            net.load_state_dict(best_state)
            w1 = net[0].weight.detach().numpy()
            w2 = net[2].weight.detach().numpy()

            # Continuous weights first, then both layers ternary-snapped.
            for w1c, w2c in [(w1, w2), (_ternary_snap(w1), _ternary_snap(w2))]:
                use_full = not fixed_in
                model = _build_two_layer_conv_onnx(
                    w1c, w2c, ks1, ks2, use_full_30=use_full,
                    IH=IH if fixed_in else None,
                    IW=IW if fixed_in else None
                )
                onnx.save(model, path)
                if validate(path, td):
                    sz = os.path.getsize(path)
                    if best_result is None or sz < best_result[2]:
                        best_result = ('pt_conv2', model, sz)

    if best_result is not None:
        # Re-save the winner: path may currently hold a later, failed candidate.
        sname, model, _ = best_result
        onnx.save(model, path)
        return sname, model
    return None
|
| 1477 |
+
|
| 1478 |
# ============================================================
|
| 1479 |
# MAIN
|
| 1480 |
# ============================================================
|
|
|
|
| 1530 |
if result is not None:
|
| 1531 |
sname, model = result
|
| 1532 |
return True, sname, os.path.getsize(path), time.time() - t_start, path
|
| 1533 |
+
# 3. PyTorch learned conv as fallback for same-shape tasks
|
| 1534 |
+
remaining = max(1, conv_time - (time.time() - t_start))
|
| 1535 |
+
result = solve_pytorch_conv(td, path, time_budget=remaining)
|
| 1536 |
+
if result is not None:
|
| 1537 |
+
sname, model = result
|
| 1538 |
+
return True, sname, os.path.getsize(path), time.time() - t_start, path
|
| 1539 |
else:
|
| 1540 |
sp = fixed_shapes(td)
|
| 1541 |
if sp is not None:
|