rogermt committed
Commit 99c34bc · verified · 1 Parent(s): 92d1187

v4.2: Add PyTorch learned conv solver (single+two-layer, multi-seed, ternary snap). Needs GPU to be practical - use on Kaggle with --conv_budget 60
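The CLI wiring for --conv_budget is not part of this diff. As a rough, hypothetical illustration of how the 60-second budget reaches the new solver, a driver might call solve_task (whose signature appears in the second hunk below) directly; file names and the output directory here are placeholders, not from the repo:

    import json

    from neurogolf_solver import solve_task  # assumption: the repo file is importable as a module

    with open("some_arc_task.json") as f:    # placeholder ARC task JSON, not from this commit
        td = json.load(f)

    # solve_task's signature comes from the second hunk; its success path there returns
    # (True, solver_name, size_bytes, elapsed_seconds, onnx_path). Failure shape is not shown here.
    result = solve_task("some_arc_task", td, "out_models", conv_budget=60.0)
    print(result)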

Files changed (1)
  1. neurogolf_solver.py +220 -0
neurogolf_solver.py CHANGED
@@ -1261,6 +1261,220 @@ def solve_conv_var_diff(td, path, time_budget=30.0):
     if validate(path, td): return 'conv_var_diff', model
     return None

+# ============================================================
+# PYTORCH LEARNED CONV (gradient descent, multi-seed, ternary snap)
+# ============================================================
+
+def _ternary_snap(w, eps=0.2):
+    """Snap weights to {-1, 0, 1} — smaller model, often still correct."""
+    return np.where(w > eps, 1.0, np.where(w < -eps, -1.0, 0.0)).astype(np.float32)
+
+def _build_conv_onnx_from_weights(W, ks, use_full_30=False, IH=None, IW=None):
+    """Build ONNX conv model from numpy weight array W [10,10,ks,ks].
+    For fixed-shape: Slice→Conv→ArgMax→Equal+Cast→Pad
+    For variable/full30: Conv→ArgMax→Equal+Cast→Mul(mask)"""
+    pad = ks // 2
+    if use_full_30:
+        # Variable shape: full 30x30 conv with mask
+        inits = [numpy_helper.from_array(W, 'W')]
+        nodes = [
+            helper.make_node('ReduceSum', ['input'], ['mask'], axes=[1], keepdims=1),
+            helper.make_node('Conv', ['input', 'W'], ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
+            helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
+        ]
+        add_onehot_block(nodes, inits, 'am', 'oh_out')
+        nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
+        return mk(nodes, inits)
+    else:
+        # Fixed shape: slice, conv, pad
+        pad_h, pad_w = GH - IH, GW - IW
+        inits = [
+            numpy_helper.from_array(np.array([0,0,0,0], dtype=np.int64), 'sl_st'),
+            numpy_helper.from_array(np.array([1,10,IH,IW], dtype=np.int64), 'sl_en'),
+            numpy_helper.from_array(W, 'W'),
+        ]
+        nodes = [
+            helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
+            helper.make_node('Conv', ['grid', 'W'], ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
+            helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
+        ]
+        add_onehot_block(nodes, inits, 'am', 'oh_out')
+        nodes.append(
+            helper.make_node('Pad', ['oh_out'], ['output'],
+                             pads=[0,0,0,0,0,0,pad_h,pad_w], value=0.0)
+        )
+        return mk(nodes, inits)
+
+def _build_two_layer_conv_onnx(W1, W2, ks1, ks2, use_full_30=False, IH=None, IW=None):
+    """Build ONNX two-layer conv: Conv→ReLU→Conv→ArgMax→Equal+Cast→Pad/Mul(mask)."""
+    pad1, pad2 = ks1 // 2, ks2 // 2
+    if use_full_30:
+        inits = [
+            numpy_helper.from_array(W1, 'W1'),
+            numpy_helper.from_array(W2, 'W2'),
+        ]
+        nodes = [
+            helper.make_node('ReduceSum', ['input'], ['mask'], axes=[1], keepdims=1),
+            helper.make_node('Conv', ['input', 'W1'], ['h1'], kernel_shape=[ks1,ks1], pads=[pad1]*4),
+            helper.make_node('Relu', ['h1'], ['h1r']),
+            helper.make_node('Conv', ['h1r', 'W2'], ['co'], kernel_shape=[ks2,ks2], pads=[pad2]*4),
+            helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
+        ]
+        add_onehot_block(nodes, inits, 'am', 'oh_out')
+        nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
+        return mk(nodes, inits)
+    else:
+        pad_h, pad_w = GH - IH, GW - IW
+        inits = [
+            numpy_helper.from_array(np.array([0,0,0,0], dtype=np.int64), 'sl_st'),
+            numpy_helper.from_array(np.array([1,10,IH,IW], dtype=np.int64), 'sl_en'),
+            numpy_helper.from_array(W1, 'W1'),
+            numpy_helper.from_array(W2, 'W2'),
+        ]
+        nodes = [
+            helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
+            helper.make_node('Conv', ['grid', 'W1'], ['h1'], kernel_shape=[ks1,ks1], pads=[pad1]*4),
+            helper.make_node('Relu', ['h1'], ['h1r']),
+            helper.make_node('Conv', ['h1r', 'W2'], ['co'], kernel_shape=[ks2,ks2], pads=[pad2]*4),
+            helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
+        ]
+        add_onehot_block(nodes, inits, 'am', 'oh_out')
+        nodes.append(
+            helper.make_node('Pad', ['oh_out'], ['output'],
+                             pads=[0,0,0,0,0,0,pad_h,pad_w], value=0.0)
+        )
+        return mk(nodes, inits)
+
+def solve_pytorch_conv(td, path, time_budget=30.0):
+    """PyTorch gradient descent conv solver. Tries single-layer then two-layer.
+    Multi-seed training with ternary weight snapping for smaller models.
+    Validates against arc-gen before accepting."""
+    try:
+        import torch
+        import torch.nn as nn
+        import copy as _copy
+    except ImportError:
+        return None
+
+    exs = get_exs(td)
+    same_shape = all(inp.shape == out.shape for inp, out in exs)
+    if not same_shape:
+        return None  # Only handle same-shape for now
+
+    shapes = set(inp.shape for inp, _ in exs)
+    fixed_in = len(shapes) == 1
+
+    # Prepare tensors
+    all_pairs = td['train'] + td['test']
+    inp_list = [to_onehot(p['input'])[0] for p in all_pairs]
+    out_list = [to_onehot(p['output'])[0] for p in all_pairs]
+    inp_t = torch.tensor(np.stack(inp_list), dtype=torch.float32)
+    out_t = torch.tensor(np.stack(out_list), dtype=torch.float32)
+
+    if fixed_in:
+        IH, IW = list(shapes)[0]
+        # Train on cropped region
+        inp_t = inp_t[:, :, :IH, :IW]
+        out_t = out_t[:, :, :IH, :IW]
+
+    t_start = time.time()
+    best_result = None
+
+    # Phase 1: Single-layer conv (multiple kernel sizes and seeds)
+    for ks in [1, 3, 5, 7]:
+        if time.time() - t_start > time_budget * 0.6:
+            break
+        pad = ks // 2
+        for seed in [0, 7, 42]:
+            if time.time() - t_start > time_budget * 0.6:
+                break
+            torch.manual_seed(seed)
+            conv = nn.Conv2d(CH, CH, kernel_size=ks, padding=pad, bias=False)
+            if seed == 0:
+                nn.init.zeros_(conv.weight)
+            opt = torch.optim.Adam(conv.parameters(), lr=0.03)
+            best_loss, best_state = float('inf'), None
+            for step in range(3000):
+                opt.zero_grad()
+                pred = conv(inp_t)
+                loss = nn.functional.mse_loss(pred, out_t)
+                loss.backward()
+                opt.step()
+                if loss.item() < best_loss:
+                    best_loss = loss.item()
+                    best_state = _copy.deepcopy(conv.state_dict())
+                if best_loss < 1e-8:
+                    break
+            if best_state is None:
+                continue
+            conv.load_state_dict(best_state)
+            w = conv.weight.detach().numpy()
+
+            # Try continuous weights, then ternary-snapped
+            for w_cand in [w, _ternary_snap(w)]:
+                use_full = not fixed_in
+                model = _build_conv_onnx_from_weights(
+                    w_cand, ks, use_full_30=use_full,
+                    IH=IH if fixed_in else None,
+                    IW=IW if fixed_in else None
+                )
+                onnx.save(model, path)
+                if validate(path, td):
+                    sz = os.path.getsize(path)
+                    if best_result is None or sz < best_result[2]:
+                        best_result = ('pt_conv', model, sz)
+
+    # Phase 2: Two-layer conv (Conv→ReLU→Conv)
+    for ks1, ks2, hidden in [(3, 1, CH), (5, 1, CH), (3, 3, CH)]:
+        if time.time() - t_start > time_budget:
+            break
+        for seed in [0, 7]:
+            if time.time() - t_start > time_budget:
+                break
+            torch.manual_seed(seed)
+            net = nn.Sequential(
+                nn.Conv2d(CH, hidden, kernel_size=ks1, padding=ks1//2, bias=False),
+                nn.ReLU(),
+                nn.Conv2d(hidden, CH, kernel_size=ks2, padding=ks2//2, bias=False),
+            )
+            opt = torch.optim.Adam(net.parameters(), lr=0.01)
+            best_loss, best_state = float('inf'), None
+            for step in range(2500):
+                opt.zero_grad()
+                pred = net(inp_t)
+                loss = nn.functional.mse_loss(pred, out_t)
+                loss.backward()
+                opt.step()
+                if loss.item() < best_loss:
+                    best_loss = loss.item()
+                    best_state = _copy.deepcopy(net.state_dict())
+                if best_loss < 1e-8:
+                    break
+            if best_state is None:
+                continue
+            net.load_state_dict(best_state)
+            w1 = net[0].weight.detach().numpy()
+            w2 = net[2].weight.detach().numpy()
+
+            for w1c, w2c in [(w1, w2), (_ternary_snap(w1), _ternary_snap(w2))]:
+                use_full = not fixed_in
+                model = _build_two_layer_conv_onnx(
+                    w1c, w2c, ks1, ks2, use_full_30=use_full,
+                    IH=IH if fixed_in else None,
+                    IW=IW if fixed_in else None
+                )
+                onnx.save(model, path)
+                if validate(path, td):
+                    sz = os.path.getsize(path)
+                    if best_result is None or sz < best_result[2]:
+                        best_result = ('pt_conv2', model, sz)
+
+    if best_result is not None:
+        sname, model, _ = best_result
+        onnx.save(model, path)
+        return sname, model
+    return None
+
 # ============================================================
 # MAIN
 # ============================================================
@@ -1316,6 +1530,12 @@ def solve_task(tn, td, outdir, conv_budget=30.0):
         if result is not None:
            sname, model = result
            return True, sname, os.path.getsize(path), time.time() - t_start, path
+        # 3. PyTorch learned conv as fallback for same-shape tasks
+        remaining = max(1, conv_time - (time.time() - t_start))
+        result = solve_pytorch_conv(td, path, time_budget=remaining)
+        if result is not None:
+            sname, model = result
+            return True, sname, os.path.getsize(path), time.time() - t_start, path
     else:
         sp = fixed_shapes(td)
         if sp is not None:
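For quick experiments outside solve_task, the new solver can in principle be driven on its own. The following is a minimal sketch, not part of this commit, assuming the helpers referenced above (get_exs, to_onehot, validate, mk, add_onehot_block, CH, GH, GW) live earlier in neurogolf_solver.py, torch and onnx are installed, and the task dict uses the usual ARC train/test layout; file names are placeholders:

    import json

    import neurogolf_solver as ng  # assumption: the repo file is importable as a module

    with open("some_arc_task.json") as f:   # placeholder task file
        td = json.load(f)

    # solve_pytorch_conv returns (solver_name, onnx_model) when a trained conv validates
    # within the time budget, or None otherwise (including when torch is not installed).
    res = ng.solve_pytorch_conv(td, "candidate.onnx", time_budget=60.0)
    if res is None:
        print("no learned-conv model validated within the budget")
    else:
        sname, model = res
        print("accepted solver:", sname, "-> candidate.onnx")

As the code above shows, each trained kernel is tried both with its continuous weights and with the ternary-snapped version, and whichever validating ONNX file is smallest is kept, which is why a task may come back as either 'pt_conv' or 'pt_conv2' with varying file sizes.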