rogermt committed on
Commit
2427dfd
·
verified ·
1 Parent(s): 7c05244

Update neurogolf_solver.py

Browse files
Files changed (1) hide show
  1. neurogolf_solver.py +488 -713
neurogolf_solver.py CHANGED
@@ -1,34 +1,31 @@
1
  #!/usr/bin/env python3
2
  """
3
- ARC-AGI NeuroGolf Championship - Complete Solver v5
4
- Format: [1,10,30,30] one-hot input/output, opset 17, IR version 10.
5
-
6
- v5 CHANGES:
7
- - Switched to opset 17 (Kaggle-compatible) for cheaper analytical solvers
8
- - Slice-based analytical solvers: rotation, flip, transpose (near-zero cost)
9
- - LOOCV Ridge tuning in _lstsq_conv with condition number check + SVD-based λ auto-tune
10
- - stride_tricks speedup for patch extraction
11
- - Composition detectors: rotation+color, flip+color, transpose+color
12
- - Channel reduction wrapper for tasks with <8 colors
13
- - ARC-GEN validation, EXCLUDED tasks skipped, submission.csv generation
14
 
15
  Solvers:
16
  - Analytical: identity, constant, color_map, transpose, flip, rotate, tile, upscale,
17
  concat, concat_enhanced, spatial_gather, varshape_spatial_gather,
18
- diagonal_tile, kronecker, shift, mirror_h, mirror_v, quad_mirror,
19
- fixed_crop, nonuniform_scale
20
- - Composition: rotate+color_map, flip+color_map, transpose+color_map
21
  - Conv (fixed shape): Slice -> Conv -> ArgMax -> Equal+Cast -> Pad
22
  - Conv (variable shape): Conv(30x30) -> ArgMax -> Equal+Cast -> Mul(mask)
23
  - Conv (diff shape): Slice -> Conv -> Slice(crop) -> ArgMax -> Equal+Cast -> Pad
24
- - Channel reduction: Conv1x1(10->N) -> transform -> Conv1x1(N->10)
25
 
26
  Usage:
27
  python neurogolf_solver.py --data_dir ARC-AGI/data/training/ --output_dir submission
28
  python neurogolf_solver.py --data_dir ARC-AGI/data/training/ --output_dir submission --conv_budget 60 --arcgen_dir ARC-GEN-100K/
29
  """
30
 
31
- import json, os, sys, math, time, argparse, csv, io, zipfile, warnings
32
  import numpy as np
33
  import onnx
34
  from onnx import helper, TensorProto, numpy_helper
@@ -50,17 +47,17 @@ BATCH, CH, GH, GW = 1, 10, 30, 30
50
  GRID_SHAPE = [BATCH, CH, GH, GW]
51
  DT = TensorProto.FLOAT
52
  IR = 10
53
- # v5: opset 17 for cheaper Slice-based transforms
54
- OPSET = [helper.make_opsetid("", 17)]
55
 
56
  # Officially excluded tasks (score 0 regardless)
57
  EXCLUDED_TASKS = {21, 55, 80, 184, 202, 366}
58
 
59
  # Max ARC-GEN examples to use for validation (to keep runtime reasonable)
60
  MAX_ARCGEN_VALIDATE = 30
61
-
62
- # Max ARC-GEN examples for conv fitting
63
- MAX_ARCGEN_FIT = 0
 
64
 
65
  def get_providers():
66
  return ['CPUExecutionProvider']
@@ -79,6 +76,7 @@ def load_tasks_dir(data_dir, arcgen_dir=None):
79
  with open(os.path.join(data_dir, f)) as fh:
80
  data = json.load(fh)
81
  hex_id = f.replace('.json','')
 
82
  if arcgen_dir and os.path.exists(os.path.join(arcgen_dir, f)):
83
  with open(os.path.join(arcgen_dir, f)) as fh:
84
  arcgen_examples = json.load(fh)
@@ -111,7 +109,8 @@ def to_onehot(grid):
111
  return arr
112
 
113
  def validate(path, td):
114
- """Validate model against ALL examples: train + test + arc-gen."""
 
115
  try:
116
  opts = ort.SessionOptions()
117
  opts.log_severity_level = 3
@@ -119,6 +118,7 @@ def validate(path, td):
119
  except:
120
  return False
121
  examples = td['train'] + td['test']
 
122
  if 'arc-gen' in td:
123
  examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]
124
  for ex in examples:
@@ -164,7 +164,8 @@ BANNED_OPS = {'Loop', 'Scan', 'NonZero', 'Unique', 'If', 'Function'}
164
  MAX_FILESIZE = int(1.44 * 1024 * 1024)
165
 
166
  def score_network(path):
167
- """Static profiler matching Kaggle scoring: cost = macs + memory + params."""
 
168
  if HAS_ONNX_TOOL:
169
  try:
170
  return _score_network_official(path)
@@ -226,19 +227,24 @@ def get_exs(td):
226
  for ex in td['train'] + td['test']]
227
 
228
  def get_exs_for_fitting(td):
229
- """Get examples for conv fitting. Uses train+test + arc-gen WHERE SIZES MATCH."""
 
 
 
230
  base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
231
  for ex in td['train'] + td['test']]
232
 
233
  if not base_exs:
234
  return base_exs
235
 
 
236
  base_shapes = {inp.shape for inp, _ in base_exs}
237
  if len(base_shapes) != 1:
238
- return base_exs
239
 
240
  base_shape = list(base_shapes)[0]
241
 
 
242
  ag_exs = []
243
  for ex in td.get('arc-gen', []):
244
  inp = np.array(ex['input'], dtype=np.int64)
@@ -246,13 +252,17 @@ def get_exs_for_fitting(td):
246
  if inp.shape == base_shape and out.shape == base_exs[0][1].shape:
247
  ag_exs.append((inp, out))
248
 
 
249
  return base_exs + ag_exs[:10]
250
 
251
  def get_exs_for_fitting_variable(td):
252
- """Get examples for variable-shape conv fitting."""
 
 
253
  base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
254
  for ex in td['train'] + td['test']]
255
 
 
256
  ag_exs = []
257
  for ex in td.get('arc-gen', []):
258
  inp = np.array(ex['input'], dtype=np.int64)
@@ -269,11 +279,12 @@ def fixed_shapes(td):
269
  return list(shapes)[0] if len(shapes) == 1 else None
270
 
271
  # ============================================================
272
- # GATHER HELPERS (opset 17 compatible)
273
  # ============================================================
274
 
275
  def _build_gather_model(OH, OW, idx):
276
- """Build Gather-based spatial remapping model."""
 
277
  flat_idx = np.zeros((GH*GW,), dtype=np.int64)
278
  mask = np.zeros((1,1,GH,GW), dtype=np.float32)
279
  for oi in range(OH):
@@ -295,7 +306,7 @@ def _build_gather_model(OH, OW, idx):
295
  return mk(nodes, inits)
296
 
297
  def _build_gather_model_with_const(IH, IW, OH, OW, idx, cst):
298
- """Build Gather model with constant fill for unmapped positions."""
299
  flat_idx = np.zeros((GH*GW,), dtype=np.int64)
300
  gather_mask = np.zeros((1,1,GH,GW), dtype=np.float32)
301
  const_oh = np.zeros((1,10,GH,GW), dtype=np.float32)
@@ -326,168 +337,6 @@ def _build_gather_model_with_const(IH, IW, OH, OW, idx, cst):
326
  nodes[-1] = helper.make_node('Mul', ['raw','gmask'], ['output'])
327
  return mk(nodes, inits)
328
 
329
- # ============================================================
330
- # SLICE-BASED ANALYTICAL SOLVERS (opset 17, ~0 cost)
331
- # ============================================================
332
-
333
- def _build_pad_nodes(input_name, IH, IW, output_name='output', pad_name='pads'):
334
- """Build Pad nodes to pad spatial dims to 30x30 (opset 17 with tensor pads).
335
- Returns (pad_inits, pad_node)."""
336
- pad_h, pad_w = GH - IH, GW - IW
337
- if pad_h > 0 or pad_w > 0:
338
- pads_arr = np.array([0, 0, 0, 0, 0, 0, pad_h, pad_w], dtype=np.int64)
339
- pad_inits = [numpy_helper.from_array(pads_arr, pad_name)]
340
- pad_node = helper.make_node('Pad', [input_name, pad_name], [output_name], mode='constant')
341
- return pad_inits, pad_node
342
- else:
343
- return [], helper.make_node('Identity', [input_name], [output_name])
344
-
345
- def _build_slice_flip_model(axis, IH, IW):
346
- """Build a Slice-based flip model using negative steps (opset 17).
347
- Extracts content, applies flip, pads back to 30x30.
348
- axis=0: vertical flip (reverse rows), axis=1: horizontal flip (reverse cols).
349
- """
350
- # Step 1: Extract content region [1,10,30,30] -> [1,10,IH,IW]
351
- ex_st = np.array([0,0,0,0], dtype=np.int64)
352
- ex_en = np.array([1,10,IH,IW], dtype=np.int64)
353
-
354
- # Step 2: Flip with negative step Slice
355
- if axis == 0:
356
- starts = np.array([IH-1], dtype=np.int64)
357
- ends = np.array([-IH-1], dtype=np.int64)
358
- axes = np.array([2], dtype=np.int64)
359
- steps = np.array([-1], dtype=np.int64)
360
- else:
361
- starts = np.array([IW-1], dtype=np.int64)
362
- ends = np.array([-IW-1], dtype=np.int64)
363
- axes = np.array([3], dtype=np.int64)
364
- steps = np.array([-1], dtype=np.int64)
365
-
366
- inits = [
367
- numpy_helper.from_array(ex_st, 'ex_st'),
368
- numpy_helper.from_array(ex_en, 'ex_en'),
369
- numpy_helper.from_array(starts, 'sl_st'),
370
- numpy_helper.from_array(ends, 'sl_en'),
371
- numpy_helper.from_array(axes, 'sl_ax'),
372
- numpy_helper.from_array(steps, 'sl_sp'),
373
- ]
374
- nodes = [
375
- helper.make_node('Slice', ['input','ex_st','ex_en'], ['content']),
376
- helper.make_node('Slice', ['content','sl_st','sl_en','sl_ax','sl_sp'], ['flipped']),
377
- ]
378
-
379
- # Step 3: Pad back to 30x30 if needed
380
- pad_inits, pad_node = _build_pad_nodes('flipped', IH, IW)
381
- inits.extend(pad_inits)
382
- nodes.append(pad_node)
383
-
384
- return mk(nodes, inits)
385
-
386
- def _build_slice_transpose_model(IH, IW):
387
- """Build a Transpose-based transpose model (perm=[0,1,3,2]).
388
- Extracts content, transposes, pads back to 30x30."""
389
- # Step 1: Extract content [1,10,30,30] -> [1,10,IH,IW]
390
- ex_st = np.array([0,0,0,0], dtype=np.int64)
391
- ex_en = np.array([1,10,IH,IW], dtype=np.int64)
392
-
393
- inits = [
394
- numpy_helper.from_array(ex_st, 'ex_st'),
395
- numpy_helper.from_array(ex_en, 'ex_en'),
396
- ]
397
- nodes = [
398
- helper.make_node('Slice', ['input','ex_st','ex_en'], ['content']),
399
- helper.make_node('Transpose', ['content'], ['transposed'], perm=[0,1,3,2]),
400
- ]
401
- # After transpose, shape is [1,10,IW,IH]. Need to pad to [1,10,30,30].
402
- pad_inits, pad_node = _build_pad_nodes('transposed', IW, IH)
403
- nodes.append(pad_node)
404
- return mk(nodes, inits + pad_inits)
405
-
406
- def _build_slice_rotate_model(k, IH, IW):
407
- """Build a rotation model using Transpose + Slice (opset 17).
408
- Extracts content, applies rotation, pads back to 30x30.
409
- Matches existing s_rotate behavior (np.rot90):
410
- k=1: 90° CCW = Transpose then vflip (reverse rows)
411
- k=2: 180° = hflip then vflip
412
- k=3: 270° CCW = Transpose then hflip (reverse cols)
413
- """
414
- # Step 1: Extract content [1,10,30,30] -> [1,10,IH,IW]
415
- ex_st = np.array([0,0,0,0], dtype=np.int64)
416
- ex_en = np.array([1,10,IH,IW], dtype=np.int64)
417
-
418
- inits = [
419
- numpy_helper.from_array(ex_st, 'ex_st'),
420
- numpy_helper.from_array(ex_en, 'ex_en'),
421
- ]
422
- nodes = [helper.make_node('Slice', ['input','ex_st','ex_en'], ['content'])]
423
- current = 'content'
424
-
425
- if k in (1, 3):
426
- # Transpose: [1,10,IH,IW] -> [1,10,IW,IH]
427
- nodes.append(helper.make_node('Transpose', [current], ['t'], perm=[0,1,3,2]))
428
- current = 't'
429
- new_IH, new_IW = IW, IH
430
- else:
431
- new_IH, new_IW = IH, IW
432
-
433
- # Apply flips with negative step Slice
434
- if k == 1:
435
- # vflip (reverse rows, axis=2) after transpose
436
- starts = np.array([new_IH-1], dtype=np.int64)
437
- ends = np.array([-new_IH-1], dtype=np.int64)
438
- axes = np.array([2], dtype=np.int64)
439
- steps = np.array([-1], dtype=np.int64)
440
- elif k == 2:
441
- # 180° = hflip then vflip
442
- starts_h = np.array([new_IW-1], dtype=np.int64)
443
- ends_h = np.array([-new_IW-1], dtype=np.int64)
444
- axes_h = np.array([3], dtype=np.int64)
445
- steps_h = np.array([-1], dtype=np.int64)
446
- inits.extend([
447
- numpy_helper.from_array(starts_h, 'st_h'),
448
- numpy_helper.from_array(ends_h, 'en_h'),
449
- numpy_helper.from_array(axes_h, 'ax_h'),
450
- numpy_helper.from_array(steps_h, 'sp_h'),
451
- ])
452
- nodes.append(helper.make_node('Slice', [current,'st_h','en_h','ax_h','sp_h'], ['fh']))
453
- current = 'fh'
454
- starts_v = np.array([new_IH-1], dtype=np.int64)
455
- ends_v = np.array([-new_IH-1], dtype=np.int64)
456
- axes_v = np.array([2], dtype=np.int64)
457
- steps_v = np.array([-1], dtype=np.int64)
458
- inits.extend([
459
- numpy_helper.from_array(starts_v, 'st_v'),
460
- numpy_helper.from_array(ends_v, 'en_v'),
461
- numpy_helper.from_array(axes_v, 'ax_v'),
462
- numpy_helper.from_array(steps_v, 'sp_v'),
463
- ])
464
- nodes.append(helper.make_node('Slice', [current,'st_v','en_v','ax_v','sp_v'], ['rot']))
465
- current = 'rot'
466
- pad_inits, pad_node = _build_pad_nodes(current, new_IH, new_IW)
467
- nodes.append(pad_node)
468
- return mk(nodes, inits + pad_inits)
469
- elif k == 3:
470
- # hflip (reverse cols, axis=3) after transpose
471
- starts = np.array([new_IW-1], dtype=np.int64)
472
- ends = np.array([-new_IW-1], dtype=np.int64)
473
- axes = np.array([3], dtype=np.int64)
474
- steps = np.array([-1], dtype=np.int64)
475
-
476
- inits.extend([
477
- numpy_helper.from_array(starts, 'sl_st'),
478
- numpy_helper.from_array(ends, 'sl_en'),
479
- numpy_helper.from_array(axes, 'sl_ax'),
480
- numpy_helper.from_array(steps, 'sl_sp'),
481
- ])
482
- nodes.append(helper.make_node('Slice', [current,'sl_st','sl_en','sl_ax','sl_sp'], ['rot']))
483
- current = 'rot'
484
-
485
- # Pad back to 30x30
486
- pad_inits, pad_node = _build_pad_nodes(current, new_IH, new_IW)
487
- nodes.append(pad_node)
488
-
489
- return mk(nodes, inits + pad_inits)
490
-
491
  # ============================================================
492
  # ANALYTICAL SOLVERS
493
  # ============================================================
@@ -497,8 +346,7 @@ def s_identity(td):
497
  if ex['input'] != ex['output']: return None
498
  return mk([helper.make_node('Identity', ['input'], ['output'])])
499
 
500
- def _get_color_map(td):
501
- """Extract color map if consistent across all examples, or None."""
502
  cm = {}
503
  for ex in td['train']+td['test']:
504
  inp, out = np.array(ex['input']), np.array(ex['output'])
@@ -507,14 +355,12 @@ def _get_color_map(td):
507
  iv, ov = int(iv), int(ov)
508
  if iv in cm and cm[iv] != ov: return None
509
  cm[iv] = ov
510
- return cm
511
-
512
- def _build_color_map_model(cm, is_permutation=None):
513
- """Build ONNX model for a color map."""
514
- if is_permutation is None:
515
- is_permutation = (set(cm.keys()) == set(cm.values()))
516
 
517
  if is_permutation:
 
518
  gather_ch = np.arange(10, dtype=np.int32)
519
  for src, dst in cm.items():
520
  if 0 <= src < 10 and 0 <= dst < 10:
@@ -523,25 +369,17 @@ def _build_color_map_model(cm, is_permutation=None):
523
  nodes = [helper.make_node('Gather', ['input', 'gi'], ['output'], axis=1)]
524
  return mk(nodes, inits)
525
  else:
 
526
  W = np.zeros((10,10,1,1), dtype=np.float32)
527
  for ic in range(10):
528
  W[cm.get(ic,ic), ic, 0, 0] = 1.0
529
  return mk([helper.make_node('Conv', ['input','W'], ['output'], kernel_shape=[1,1])],
530
  [numpy_helper.from_array(W, 'W')])
531
 
532
- def s_color_map(td):
533
- cm = _get_color_map(td)
534
- if cm is None: return None
535
- is_permutation = (set(cm.keys()) == set(cm.values()))
536
- return _build_color_map_model(cm, is_permutation)
537
-
538
  def s_transpose(td):
539
- exs = get_exs(td)
540
- sp = fixed_shapes(td)
541
- if sp is None: return None
542
- (IH,IW),(OH,OW) = sp
543
- if not all(np.array_equal(out, inp.T) for inp, out in exs): return None
544
- return _build_slice_transpose_model(IH, IW)
545
 
546
  def s_flip(td):
547
  exs = get_exs(td)
@@ -551,7 +389,15 @@ def s_flip(td):
551
  if (IH,IW) != (OH,OW): return None
552
  for axis, flip_fn in [(0, np.flipud), (1, np.fliplr)]:
553
  if all(np.array_equal(out, flip_fn(inp)) for inp, out in exs):
554
- return _build_slice_flip_model(axis, IH, IW)
 
 
 
 
 
 
 
 
555
  return None
556
 
557
  def s_rotate(td):
@@ -560,8 +406,15 @@ def s_rotate(td):
560
  if sp is None: return None
561
  (IH,IW),(OH,OW) = sp
562
  for k in [1, 2, 3]:
563
- if all(np.array_equal(out, np.rot90(inp, k)) for inp, out in exs):
564
- return _build_slice_rotate_model(k, IH, IW)
 
 
 
 
 
 
 
565
  return None
566
 
567
  def s_spatial_gather(td):
@@ -587,9 +440,10 @@ def s_spatial_gather(td):
587
  def s_varshape_spatial_gather(td):
588
  """Spatial gather that works for variable-shape tasks by embedding in 30x30."""
589
  sp = fixed_shapes(td)
590
- if sp is not None: return None
591
  exs = get_exs(td)
592
 
 
593
  exs_30 = []
594
  for inp, out in exs:
595
  ih, iw = inp.shape
@@ -645,14 +499,11 @@ def s_tile(td):
645
  numpy_helper.from_array(np.array([1,10,IH,IW], dtype=np.int64), 'en'),
646
  numpy_helper.from_array(np.array([1,1,rH,rW], dtype=np.int64), 'rp'),
647
  ]
648
- pads_arr = np.array([0, 0, 0, 0, 0, 0, pad_h, pad_w], dtype=np.int64)
649
- tile_pads = numpy_helper.from_array(pads_arr, 'tile_pads')
650
  nodes = [
651
  helper.make_node('Slice', ['input','st','en'], ['cr']),
652
  helper.make_node('Tile', ['cr','rp'], ['tl']),
653
- helper.make_node('Pad', ['tl', 'tile_pads'], ['output'], mode='constant'),
654
  ]
655
- inits.append(tile_pads)
656
  return mk(nodes, inits)
657
 
658
  def s_upscale(td):
@@ -732,18 +583,26 @@ def s_concat_enhanced(td):
732
  if sp is None: return None
733
  (IH,IW),(OH,OW) = sp
734
  if IH == OH and IW == OW: return None
 
 
735
  if OH % IH != 0 or OW % IW != 0: return None
736
  rH, rW = OH // IH, OW // IW
737
  if rH * rW > 16 or rH * rW < 2: return None
738
  if OH > 30 or OW > 30: return None
739
 
 
740
  transforms = [
741
- ('id', lambda x: x), ('fliplr', lambda x: np.fliplr(x)),
742
- ('flipud', lambda x: np.flipud(x)), ('rot180', lambda x: np.rot90(x, 2)),
743
- ('rot90', lambda x: np.rot90(x, 1)), ('rot270', lambda x: np.rot90(x, 3)),
744
- ('T', lambda x: x.T), ('T_fliplr', lambda x: np.fliplr(x.T)),
 
 
 
 
745
  ]
746
 
 
747
  block_transforms = {}
748
  for bi in range(rH):
749
  for bj in range(rW):
@@ -754,13 +613,16 @@ def s_concat_enhanced(td):
754
  block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
755
  expected = tfn(inp)
756
  if expected.shape != (IH, IW) or not np.array_equal(block, expected):
757
- ok = False; break
 
758
  if ok:
759
  found = (tidx, tname)
760
  break
761
- if found is None: return None
 
762
  block_transforms[(bi, bj)] = found
763
 
 
764
  idx = np.zeros((OH, OW, 2), dtype=np.int64)
765
  for bi in range(rH):
766
  for bj in range(rW):
@@ -778,16 +640,51 @@ def s_concat_enhanced(td):
778
  elif tname == 'T_fliplr': sr, sc = IW-1-lc, lr
779
  idx[oi, oj] = [sr, sc]
780
 
 
781
  for inp, out in exs:
782
  reconstructed = np.zeros_like(out)
783
  for oi in range(OH):
784
  for oj in range(OW):
785
  reconstructed[oi,oj] = inp[idx[oi,oj,0], idx[oi,oj,1]]
786
- if not np.array_equal(reconstructed, out): return None
 
787
 
788
  return _build_gather_model(OH, OW, idx)
789
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
790
  def s_kronecker(td):
 
791
  exs = get_exs(td)
792
  sp = fixed_shapes(td)
793
  if sp is None: return None
@@ -796,9 +693,13 @@ def s_kronecker(td):
796
  sH, sW = OH // IH, OW // IW
797
  if sH < 2 or sW < 2: return None
798
  if OH > 30 or OW > 30: return None
 
799
  for inp, out in exs:
800
  expected = np.kron(inp, np.ones((sH, sW), dtype=np.int64))
801
- if not np.array_equal(out, expected): return None
 
 
 
802
  idx = np.zeros((OH,OW,2), dtype=np.int64)
803
  for r in range(OH):
804
  for c in range(OW):
@@ -806,6 +707,7 @@ def s_kronecker(td):
806
  return _build_gather_model(OH, OW, idx)
807
 
808
  def s_diagonal_tile(td):
 
809
  exs = get_exs(td)
810
  sp = fixed_shapes(td)
811
  if sp is None: return None
@@ -814,14 +716,19 @@ def s_diagonal_tile(td):
814
  rH, rW = OH // IH, OW // IW
815
  if rH != rW or rH < 2: return None
816
  if OH > 30 or OW > 30: return None
 
817
  for inp, out in exs:
818
  for bi in range(rH):
819
  for bj in range(rW):
820
  block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
821
  if bi == bj:
822
- if not np.array_equal(block, inp): return None
 
823
  else:
824
- if not np.all(block == 0): return None
 
 
 
825
  idx = np.zeros((OH,OW,2), dtype=np.int64)
826
  cst = np.full((OH,OW), -1, dtype=np.int64)
827
  for bi in range(rH):
@@ -829,11 +736,16 @@ def s_diagonal_tile(td):
829
  for lr in range(IH):
830
  for lc in range(IW):
831
  oi, oj = bi*IH + lr, bj*IW + lc
832
- if bi == bj: idx[oi, oj] = [lr, lc]
833
- else: idx[oi, oj] = [-1, -1]; cst[oi, oj] = 0
 
 
 
 
834
  return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
835
 
836
  def s_shift(td):
 
837
  exs = get_exs(td)
838
  sp = fixed_shapes(td)
839
  if sp is None: return None
@@ -853,17 +765,52 @@ def s_shift(td):
853
  if not np.array_equal(shifted, out):
854
  ok = False; break
855
  if not ok: continue
 
856
  idx = np.zeros((OH, OW, 2), dtype=np.int64)
857
- cst = np.full((OH, OW), 0, dtype=np.int64)
858
  for r in range(OH):
859
  for c in range(OW):
860
  sr, sc = r - dr, c - dc
861
- if 0 <= sr < IH and 0 <= sc < IW: idx[r, c] = [sr, sc]
862
- else: idx[r, c] = [-1, -1]
 
 
863
  return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
864
  return None
865
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
866
  def s_mirror_h(td):
 
867
  exs = get_exs(td)
868
  sp = fixed_shapes(td)
869
  if sp is None: return None
@@ -873,6 +820,7 @@ def s_mirror_h(td):
873
  for inp, out in exs:
874
  expected = np.concatenate([inp, np.flip(inp, 1)], 1)
875
  if not np.array_equal(expected, out): return None
 
876
  idx = np.zeros((OH, OW, 2), dtype=np.int64)
877
  for r in range(OH):
878
  for c in range(OW):
@@ -881,6 +829,7 @@ def s_mirror_h(td):
881
  return _build_gather_model(OH, OW, idx)
882
 
883
  def s_mirror_v(td):
 
884
  exs = get_exs(td)
885
  sp = fixed_shapes(td)
886
  if sp is None: return None
@@ -898,6 +847,7 @@ def s_mirror_v(td):
898
  return _build_gather_model(OH, OW, idx)
899
 
900
  def s_quad_mirror(td):
 
901
  exs = get_exs(td)
902
  sp = fixed_shapes(td)
903
  if sp is None: return None
@@ -919,6 +869,7 @@ def s_quad_mirror(td):
919
  return _build_gather_model(OH, OW, idx)
920
 
921
  def s_fixed_crop(td):
 
922
  exs = get_exs(td)
923
  sp = fixed_shapes(td)
924
  if sp is None: return None
@@ -935,6 +886,7 @@ def s_fixed_crop(td):
935
  return None
936
 
937
  def s_nonuniform_scale(td):
 
938
  exs = get_exs(td)
939
  sp = fixed_shapes(td)
940
  if sp is None: return None
@@ -967,289 +919,8 @@ def s_constant(td):
967
  helper.make_node('Add', ['s','c'], ['output'])]
968
  return mk(nodes, inits)
969
 
970
- def _attr_to_dict(attr_proto):
971
- """Convert ONNX AttributeProto to Python native type."""
972
- from onnx import AttributeProto
973
- if attr_proto.type == AttributeProto.INT:
974
- return attr_proto.i
975
- elif attr_proto.type == AttributeProto.INTS:
976
- return list(attr_proto.ints)
977
- elif attr_proto.type == AttributeProto.FLOAT:
978
- return attr_proto.f
979
- elif attr_proto.type == AttributeProto.FLOATS:
980
- return list(attr_proto.floats)
981
- elif attr_proto.type == AttributeProto.STRING:
982
- return attr_proto.s.decode('utf-8')
983
- elif attr_proto.type == AttributeProto.STRINGS:
984
- return [s.decode('utf-8') for s in attr_proto.strings]
985
- elif attr_proto.type == AttributeProto.TENSOR:
986
- return numpy_helper.to_array(attr_proto.t)
987
- else:
988
- return None
989
-
990
  # ============================================================
991
- # COMPOSITION DETECTORS (transform + color_map)
992
- # ============================================================
993
-
994
- def _apply_transform(inp, transform_name):
995
- """Apply a named transform to a numpy array."""
996
- if transform_name == 'id': return inp
997
- elif transform_name == 'fliplr': return np.fliplr(inp)
998
- elif transform_name == 'flipud': return np.flipud(inp)
999
- elif transform_name == 'rot90': return np.rot90(inp, 1)
1000
- elif transform_name == 'rot180': return np.rot90(inp, 2)
1001
- elif transform_name == 'rot270': return np.rot90(inp, 3)
1002
- elif transform_name == 'T': return inp.T
1003
- else: return inp
1004
-
1005
- def s_composition_rotate_color(td):
1006
- """Detect rotation + color_map composition."""
1007
- exs = get_exs(td)
1008
- sp = fixed_shapes(td)
1009
- if sp is None: return None
1010
- (IH,IW),(OH,OW) = sp
1011
- if (IH,IW) != (OH,OW): return None
1012
-
1013
- for k in [1, 2, 3]:
1014
- # Try each rotation, then check if consistent color_map remains
1015
- cm = {}
1016
- valid = True
1017
- for inp, out in exs:
1018
- rotated = np.rot90(inp, k)
1019
- if rotated.shape != out.shape: valid = False; break
1020
- for iv, ov in zip(rotated.flat, out.flat):
1021
- iv, ov = int(iv), int(ov)
1022
- if iv in cm and cm[iv] != ov: valid = False; break
1023
- cm[iv] = ov
1024
- if not valid: break
1025
- if not valid: continue
1026
-
1027
- # Build: rotate first (Slice-based), then color_map
1028
- rot_model = _build_slice_rotate_model(k, IH, IW)
1029
- # Extract nodes from rot_model, prepend to color_map
1030
- cm_model = _build_color_map_model(cm)
1031
-
1032
- # Combine: input -> rot_nodes -> color_map -> output
1033
- # We need to chain the graphs
1034
- combined_nodes = []
1035
- combined_inits = []
1036
-
1037
- # Add rotation nodes with renamed intermediates
1038
- for node in rot_model.graph.node:
1039
- if node.output[0] == 'output':
1040
- # Last node of rotation feeds into color map
1041
- new_node = helper.make_node(node.op_type, list(node.input), ['rot_out'],
1042
- **{attr.name: _attr_to_dict(attr) for attr in node.attribute})
1043
- else:
1044
- new_node = node
1045
- combined_nodes.append(new_node)
1046
-
1047
- for init in rot_model.graph.initializer:
1048
- combined_inits.append(init)
1049
-
1050
- # Add color map nodes with input = rot_out
1051
- for node in cm_model.graph.node:
1052
- if node.input[0] == 'input':
1053
- new_node = helper.make_node(node.op_type, ['rot_out'] + list(node.input[1:]), list(node.output),
1054
- **{attr.name: _attr_to_dict(attr) for attr in node.attribute})
1055
- else:
1056
- new_node = node
1057
- combined_nodes.append(new_node)
1058
-
1059
- for init in cm_model.graph.initializer:
1060
- combined_inits.append(init)
1061
-
1062
- return mk(combined_nodes, combined_inits)
1063
- return None
1064
-
1065
- def s_composition_flip_color(td):
1066
- """Detect flip + color_map composition."""
1067
- exs = get_exs(td)
1068
- sp = fixed_shapes(td)
1069
- if sp is None: return None
1070
- (IH,IW),(OH,OW) = sp
1071
- if (IH,IW) != (OH,OW): return None
1072
-
1073
- for axis, flip_fn in [(0, np.flipud), (1, np.fliplr)]:
1074
- cm = {}
1075
- valid = True
1076
- for inp, out in exs:
1077
- flipped = flip_fn(inp)
1078
- if flipped.shape != out.shape: valid = False; break
1079
- for iv, ov in zip(flipped.flat, out.flat):
1080
- iv, ov = int(iv), int(ov)
1081
- if iv in cm and cm[iv] != ov: valid = False; break
1082
- cm[iv] = ov
1083
- if not valid: break
1084
- if not valid: continue
1085
-
1086
- flip_model = _build_slice_flip_model(axis, IH, IW)
1087
- cm_model = _build_color_map_model(cm)
1088
-
1089
- combined_nodes = []
1090
- combined_inits = []
1091
-
1092
- for node in flip_model.graph.node:
1093
- if node.output[0] == 'output':
1094
- new_node = helper.make_node(node.op_type, list(node.input), ['flip_out'],
1095
- **{attr.name: _attr_to_dict(attr) for attr in node.attribute})
1096
- else:
1097
- new_node = node
1098
- combined_nodes.append(new_node)
1099
-
1100
- for init in flip_model.graph.initializer:
1101
- combined_inits.append(init)
1102
-
1103
- for node in cm_model.graph.node:
1104
- if node.input[0] == 'input':
1105
- new_node = helper.make_node(node.op_type, ['flip_out'] + list(node.input[1:]), list(node.output),
1106
- **{attr.name: _attr_to_dict(attr) for attr in node.attribute})
1107
- else:
1108
- new_node = node
1109
- combined_nodes.append(new_node)
1110
-
1111
- for init in cm_model.graph.initializer:
1112
- combined_inits.append(init)
1113
-
1114
- return mk(combined_nodes, combined_inits)
1115
- return None
1116
-
1117
- def s_composition_transpose_color(td):
1118
- """Detect transpose + color_map composition."""
1119
- exs = get_exs(td)
1120
- sp = fixed_shapes(td)
1121
- if sp is None: return None
1122
- (IH,IW),(OH,OW) = sp
1123
-
1124
- cm = {}
1125
- valid = True
1126
- for inp, out in exs:
1127
- transposed = inp.T
1128
- if transposed.shape != out.shape: valid = False; break
1129
- for iv, ov in zip(transposed.flat, out.flat):
1130
- iv, ov = int(iv), int(ov)
1131
- if iv in cm and cm[iv] != ov: valid = False; break
1132
- cm[iv] = ov
1133
- if not valid: break
1134
- if not valid: return None
1135
-
1136
- trans_model = _build_slice_transpose_model(IH, IW)
1137
- cm_model = _build_color_map_model(cm)
1138
-
1139
- combined_nodes = []
1140
- combined_inits = []
1141
-
1142
- for node in trans_model.graph.node:
1143
- if node.output[0] == 'output':
1144
- new_node = helper.make_node(node.op_type, list(node.input), ['trans_out'],
1145
- **{attr.name: _attr_to_dict(attr) for attr in node.attribute})
1146
- else:
1147
- new_node = node
1148
- combined_nodes.append(new_node)
1149
-
1150
- for init in trans_model.graph.initializer:
1151
- combined_inits.append(init)
1152
-
1153
- for node in cm_model.graph.node:
1154
- if node.input[0] == 'input':
1155
- new_node = helper.make_node(node.op_type, ['trans_out'] + list(node.input[1:]), list(node.output),
1156
- **{attr.name: _attr_to_dict(attr) for attr in node.attribute})
1157
- else:
1158
- new_node = node
1159
- combined_nodes.append(new_node)
1160
-
1161
- for init in cm_model.graph.initializer:
1162
- combined_inits.append(init)
1163
-
1164
- return mk(combined_nodes, combined_inits)
1165
-
1166
- # ============================================================
1167
- # CHANNEL REDUCTION WRAPPER
1168
- # ============================================================
1169
-
1170
- def _get_active_colors(td):
1171
- """Returns set of all colors appearing in inputs and outputs."""
1172
- colors = set()
1173
- for ex in td['train'] + td['test']:
1174
- for row in ex['input']:
1175
- colors.update(row)
1176
- for row in ex['output']:
1177
- colors.update(row)
1178
- return colors
1179
-
1180
- def _build_channel_reduced_model(inner_model, input_colors, output_colors):
1181
- """Wrap a model with channel reduction: Conv1x1(10->N) -> inner -> Conv1x1(N->10).
1182
- This saves MACs when N < 10."""
1183
- n_in = len(input_colors)
1184
- n_out = len(output_colors)
1185
-
1186
- # Maps from full 10 channels to reduced set
1187
- in_map = sorted(input_colors)
1188
- out_map = sorted(output_colors)
1189
-
1190
- # W_reduce: [n_in, 10, 1, 1] - maps 10 channels to n_in
1191
- W_reduce = np.zeros((n_in, 10, 1, 1), dtype=np.float32)
1192
- for i, c in enumerate(in_map):
1193
- W_reduce[i, c, 0, 0] = 1.0
1194
-
1195
- # W_expand: [10, n_out, 1, 1] - maps n_out channels back to 10
1196
- W_expand = np.zeros((10, n_out, 1, 1), dtype=np.float32)
1197
- for i, c in enumerate(out_map):
1198
- W_expand[c, i, 0, 0] = 1.0
1199
-
1200
- # Build the wrapped model
1201
- nodes = [
1202
- helper.make_node('Conv', ['input', 'W_reduce'], ['reduced'], kernel_shape=[1,1]),
1203
- ]
1204
- inits = [numpy_helper.from_array(W_reduce, 'W_reduce')]
1205
-
1206
- # Add inner model nodes with input='reduced' and output renamed
1207
- for node in inner_model.graph.node:
1208
- if node.input[0] == 'input':
1209
- new_inputs = ['reduced'] + list(node.input[1:])
1210
- else:
1211
- new_inputs = list(node.input)
1212
-
1213
- if node.output[0] == 'output':
1214
- new_outputs = ['inner_out']
1215
- else:
1216
- new_outputs = list(node.output)
1217
-
1218
- new_node = helper.make_node(node.op_type, new_inputs, new_outputs,
1219
- **{attr.name: _attr_to_dict(attr) for attr in node.attribute})
1220
- nodes.append(new_node)
1221
-
1222
- for init in inner_model.graph.initializer:
1223
- if init.name != 'W_reduce': # avoid conflict
1224
- inits.append(init)
1225
-
1226
- nodes.append(helper.make_node('Conv', ['inner_out', 'W_expand'], ['output'], kernel_shape=[1,1]))
1227
- inits.append(numpy_helper.from_array(W_expand, 'W_expand'))
1228
-
1229
- return mk(nodes, inits)
1230
-
1231
- def _try_channel_reduction(solver_fn, td):
1232
- """Try a solver with channel reduction wrapper if it reduces cost.
1233
- NOTE: Currently disabled for Gather-based models (spatial_gather, etc.)
1234
- as they hardcode channel=10 in Reshape operations."""
1235
- model = solver_fn(td)
1236
- if model is None: return None
1237
-
1238
- # DISABLED: Channel reduction breaks Gather-based models
1239
- # that reshape to [1,10,900]. Only applies to Conv-based models.
1240
- # colors = _get_active_colors(td)
1241
- # if len(colors) >= 8:
1242
- # return model
1243
- # try:
1244
- # wrapped = _build_channel_reduced_model(model, colors, colors)
1245
- # return wrapped
1246
- # except Exception:
1247
- # return model
1248
-
1249
- return model
1250
-
1251
- # ============================================================
1252
- # CONV SOLVERS WITH LOOCV RIDGE + STRIDE TRICKS
1253
  # ============================================================
1254
 
1255
  def add_onehot_block(nodes, inits, am_name, oh_name):
@@ -1259,79 +930,13 @@ def add_onehot_block(nodes, inits, am_name, oh_name):
1259
  nodes.append(helper.make_node('Equal', [am_name, 'classes'], ['eq']))
1260
  nodes.append(helper.make_node('Cast', ['eq'], [oh_name], to=TensorProto.FLOAT))
1261
 
1262
def _extract_patches_strided(oh_pad, ks, out_shape):
    """Extract all ks x ks sliding-window patches as rows of a 2-D matrix.

    oh_pad: [C, H+2p, W+2p] padded one-hot array
    ks: kernel size
    out_shape: (OH, OW) output shape
    Returns: patches array [OH*OW, C*ks*ks]; row ``r*OW + c`` holds
        ``oh_pad[:, r:r+ks, c:c+ks]`` flattened in (C, ks, ks) order.
    Raises: ValueError when the OH x OW grid of windows does not fit inside
        ``oh_pad``. The previous raw ``as_strided`` view silently read past
        the end of the buffer in that case.
    """
    C, Hp, Wp = oh_pad.shape
    OH, OW = out_shape

    # The last window starts at (OH-1, OW-1) and spans ks cells per axis.
    if OH - 1 + ks > Hp or OW - 1 + ks > Wp:
        raise ValueError(
            f"window grid {OH}x{OW} with kernel {ks} exceeds padded input {Hp}x{Wp}"
        )

    # Bounds-checked sliding windows: [C, Hp-ks+1, Wp-ks+1, ks, ks].
    windows = np.lib.stride_tricks.sliding_window_view(oh_pad, (ks, ks), axis=(1, 2))
    windows = windows[:, :OH, :OW]

    # Reorder to [OH, OW, C, ks, ks], then flatten to [OH*OW, C*ks*ks] (copies).
    return windows.transpose(1, 2, 0, 3, 4).reshape(OH * OW, C * ks * ks)
1289
-
1290
def _effective_rank(P):
    """Compute effective rank r(Σ) = Tr(Σ) / ‖Σ‖ of the column covariance of P.

    P: 2-D array with one observation per row and one feature per column.
    Returns: float ratio of the sum to the maximum of the covariance
        eigenvalues that exceed 1e-12; 0.0 when no eigenvalue exceeds it.
    """
    # np.cov squeezes a single-feature result down to a 0-d array, which
    # eigvalsh rejects; promote it back to a 1x1 matrix.
    Sigma = np.atleast_2d(np.cov(P, rowvar=False))
    evals = np.linalg.eigvalsh(Sigma)
    evals = evals[evals > 1e-12]
    if evals.size == 0:
        return 0.0
    return float(np.sum(evals) / np.max(evals))
1297
-
1298
def _tune_ridge_loocv(P, T_oh, lambdas):
    """Find best λ using efficient LOOCV via Hat Matrix diagonal (SVD shortcut).
    Cawley & Talbot (2010), JMLR.

    P: 2-D design matrix, one sample per row.
    T_oh: 2-D target matrix with the same number of rows as P.
    lambdas: iterable of candidate ridge penalties.
    Returns: the candidate with the lowest mean squared leave-one-out error,
        or None when the SVD fails or no candidate yields a finite error
        (including an empty ``lambdas``).
    """
    try:
        U, s, _ = np.linalg.svd(P, full_matrices=False)
    except (np.linalg.LinAlgError, ValueError):
        return None

    # Loop-invariant pieces: only the shrinkage factors d depend on λ.
    s2 = s ** 2
    U2 = U ** 2
    UtT = U.T @ T_oh

    best_lambda, min_err = None, float('inf')
    for lam in lambdas:
        d = s2 / (s2 + lam)
        y_hat = (U * d) @ UtT
        # Ridge hat matrix diagonal: h_ii = Σ_j U_ij^2 * s_j^2 / (s_j^2 + λ)
        h_ii = np.sum(U2 * d[np.newaxis, :], axis=1)

        # LOOCV shortcut: error_i = (y_i - ŷ_i) / (1 - h_ii)
        denom = 1 - h_ii
        # Guard against division by ~0 when a sample is fit exactly.
        denom = np.where(np.abs(denom) < 1e-10, 1e-10, denom)
        errors = (T_oh - y_hat) / denom[:, np.newaxis]
        mse = np.mean(errors ** 2)

        # A NaN mse never compares less-than, so it is skipped here.
        if mse < min_err:
            min_err, best_lambda = mse, lam

    return best_lambda
1326
-
1327
- def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False, use_ridge=True):
1328
- """Shared lstsq conv fitting with optional LOOCV Ridge tuning.
1329
- Returns (Wconv, B) or None."""
1330
  pad = ks // 2
1331
  feat = 10 * ks * ks + (1 if use_bias else 0)
1332
  if feat > 20000: return None
1333
 
1334
- patches_list, targets = [], []
1335
  for inp_g, out_g in exs_raw:
1336
  ih, iw = inp_g.shape
1337
  if use_full_30:
@@ -1344,65 +949,22 @@ def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False, use_ridge=True):
1344
  oh_pad = np.pad(oh_enc, ((0,0),(pad,pad),(pad,pad)))
1345
 
1346
  oh, ow = out_g.shape
1347
-
1348
- # Try stride_tricks for speedup
1349
- try:
1350
- patches = _extract_patches_strided(oh_pad, ks, (oh, ow))
1351
- if use_bias:
1352
- bias_col = np.ones((patches.shape[0], 1), dtype=np.float64)
1353
- patches = np.concatenate([patches, bias_col], axis=1)
1354
- patches_list.append(patches)
1355
- targets.append(out_g.flatten())
1356
- except Exception:
1357
- # Fallback to loop-based extraction
1358
- for r in range(oh):
1359
- for c in range(ow):
1360
- p = oh_pad[:, r:r+ks, c:c+ks].flatten()
1361
- if use_bias: p = np.append(p, 1.0)
1362
- patches_list.append(p)
1363
- targets.append(int(out_g[r, c]))
1364
-
1365
- if len(patches_list) > 0 and isinstance(patches_list[0], np.ndarray) and patches_list[0].ndim == 2:
1366
- P = np.concatenate(patches_list, axis=0)
1367
- T = np.concatenate(targets)
1368
- else:
1369
- P = np.array(patches_list, dtype=np.float64)
1370
- T = np.array(targets, dtype=np.int64)
1371
-
1372
- n_patches = P.shape[0]
1373
  if feat > 5000 and n_patches > 2000: return None
1374
 
 
 
1375
  T_oh = np.zeros((len(T), 10), dtype=np.float64)
1376
  for i, t in enumerate(T): T_oh[i, t] = 1.0
1377
 
1378
- # Quick condition number estimate using norm ratio (cheaper than full SVD)
1379
- # Only skip if clearly pathological; otherwise try lstsq
1380
- cond_estimate = None
1381
- try:
1382
- # Use 2-norm estimate: cond ≈ ||P||_2 * ||P^+||_2 ≈ max_singular / min_singular
1383
- # We approximate with norm ratios for speed
1384
- p_norm = np.linalg.norm(P, 2)
1385
- if p_norm > 0:
1386
- # Estimate using power method approximation or just try lstsq
1387
- pass # Don't waste time on condition number - lstsq will handle it
1388
- except Exception:
1389
- pass
1390
-
1391
- if use_ridge and n_patches <= feat * 1.5:
1392
- # Use LOOCV Ridge tuning when system is underdetermined or near interpolation threshold
1393
- lambdas = np.logspace(-4, 2, 10)
1394
- best_lam = _tune_ridge_loocv(P, T_oh, lambdas)
1395
- if best_lam is not None:
1396
- # Ridge solve: (P^T P + λI)^-1 P^T T
1397
- try:
1398
- WT = np.linalg.solve(P.T @ P + best_lam * np.eye(P.shape[1]), P.T @ T_oh)
1399
- except Exception:
1400
- WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
1401
- else:
1402
- WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
1403
- else:
1404
- WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
1405
-
1406
  if not np.array_equal(np.argmax(P @ WT, axis=1), T): return None
1407
 
1408
  if use_bias:
@@ -1413,21 +975,6 @@ def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False, use_ridge=True):
1413
  B = None
1414
  return Wconv, B
1415
 
1416
- # ============================================================
1417
- # CONV SOLVER WRAPPERS
1418
- # ============================================================
1419
-
1420
def _get_ks_for_budget(time_budget):
    """Pick the odd kernel sizes worth trying for a given time budget.

    Budgets below 5 stop at kernel size 5, below 10 at 9, below 20 at 17;
    anything else goes up to 29. Always returns a fresh list starting at 1.
    """
    if time_budget < 5:
        largest = 5
    elif time_budget < 10:
        largest = 9
    elif time_budget < 20:
        largest = 17
    else:
        largest = 29
    # Odd sizes 1, 3, ..., largest.
    return list(range(1, largest + 1, 2))
1430
-
1431
  def solve_conv_fixed(td, path, time_budget=30.0):
1432
  """Fixed-shape conv: Slice -> Conv -> ArgMax -> Equal+Cast -> Pad."""
1433
  exs = get_exs(td)
@@ -1437,12 +984,14 @@ def solve_conv_fixed(td, path, time_budget=30.0):
1437
  if len(shapes) != 1: return None
1438
  IH, IW = shapes.pop()
1439
 
 
1440
  fit_exs = get_exs_for_fitting(td)
 
1441
  fit_exs = [(i,o) for i,o in fit_exs if i.shape == o.shape and i.shape == (IH, IW)]
1442
 
1443
  t_start = time.time()
1444
  for use_bias in [False, True]:
1445
- for ks in _get_ks_for_budget(time_budget):
1446
  if time.time() - t_start > time_budget: return None
1447
  result = _lstsq_conv(fit_exs, ks, use_bias, use_full_30=False)
1448
  if result is None: continue
@@ -1466,10 +1015,9 @@ def solve_conv_fixed(td, path, time_budget=30.0):
1466
  helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
1467
  ]
1468
  add_onehot_block(nodes, inits, 'am', 'oh_out')
1469
- cf_pads = numpy_helper.from_array(np.array([0,0,0,0,0,0,pad_h,pad_w], dtype=np.int64), 'cf_pads')
1470
- inits.append(cf_pads)
1471
  nodes.append(
1472
- helper.make_node('Pad', ['oh_out', 'cf_pads'], ['output'], mode='constant')
 
1473
  )
1474
 
1475
  model = mk(nodes, inits)
@@ -1483,6 +1031,7 @@ def solve_conv_variable(td, path, time_budget=30.0):
1483
  for inp, out in exs:
1484
  if inp.shape != out.shape: return None
1485
 
 
1486
  fit_exs = get_exs_for_fitting_variable(td)
1487
  fit_exs = [(i,o) for i,o in fit_exs if i.shape == o.shape]
1488
 
@@ -1591,10 +1140,9 @@ def solve_conv_diffshape(td, path, time_budget=30.0):
1591
  helper.make_node('ArgMax', ['co_crop'], ['am'], axis=1, keepdims=1),
1592
  ]
1593
  add_onehot_block(nodes, inits, 'am', 'oh_out')
1594
- diff_pads = numpy_helper.from_array(np.array([0,0,0,0,0,0,pad_h,pad_w], dtype=np.int64), 'diff_pads')
1595
- inits.append(diff_pads)
1596
  nodes.append(
1597
- helper.make_node('Pad', ['oh_out', 'diff_pads'], ['output'], mode='constant')
 
1598
  )
1599
 
1600
  model = mk(nodes, inits)
@@ -1603,7 +1151,8 @@ def solve_conv_diffshape(td, path, time_budget=30.0):
1603
  return None
1604
 
1605
  def solve_conv_var_diff(td, path, time_budget=30.0):
1606
- """Variable diff-shape conv."""
 
1607
  exs = get_exs(td)
1608
 
1609
  t_start = time.time()
@@ -1651,13 +1200,47 @@ def solve_conv_var_diff(td, path, time_budget=30.0):
1651
  Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
1652
  B = None
1653
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1654
  all_output_within_input = all(
1655
  out_g.shape[0] <= inp_g.shape[0] and out_g.shape[1] <= inp_g.shape[1]
1656
  for inp_g, out_g in exs
1657
  )
1658
 
1659
  if not all_output_within_input:
1660
- continue
1661
 
1662
  inits = [numpy_helper.from_array(Wconv, 'W')]
1663
  conv_inputs = ['input', 'W']
@@ -1679,25 +1262,230 @@ def solve_conv_var_diff(td, path, time_budget=30.0):
1679
  return None
1680
 
1681
  # ============================================================
1682
- # MAIN SOLVER PIPELINE
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1683
  # ============================================================
1684
 
1685
  ANALYTICAL_SOLVERS = [
1686
- ('identity', s_identity),
1687
- ('constant', s_constant),
1688
- ('color_map', s_color_map),
1689
- ('transpose', s_transpose),
1690
- ('flip', s_flip),
1691
- ('rotate', s_rotate),
1692
- ('tile', s_tile),
1693
- ('upscale', s_upscale),
1694
- ('kronecker', s_kronecker),
1695
  ('nonuniform_scale', s_nonuniform_scale),
1696
- ('mirror_h', s_mirror_h),
1697
- ('mirror_v', s_mirror_v),
1698
- ('quad_mirror', s_quad_mirror),
1699
- ('concat', s_concat),
1700
- ('concat_enhanced', s_concat_enhanced),
1701
  ('diagonal_tile', s_diagonal_tile),
1702
  ('fixed_crop', s_fixed_crop),
1703
  ('spatial_gather', s_spatial_gather),
@@ -1705,46 +1493,26 @@ ANALYTICAL_SOLVERS = [
1705
  ('varshape_spatial_gather', s_varshape_spatial_gather),
1706
  ]
1707
 
1708
- COMPOSITION_SOLVERS = [
1709
- ('rotate_color', s_composition_rotate_color),
1710
- ('flip_color', s_composition_flip_color),
1711
- ('transpose_color', s_composition_transpose_color),
1712
- ]
1713
-
1714
- def solve_task(tn, td, outdir, conv_budget=30.0, use_channel_reduction=True):
1715
  t_start = time.time()
1716
  os.makedirs(outdir, exist_ok=True)
1717
  path = os.path.join(outdir, f"task{tn:03d}.onnx")
1718
 
 
1719
  if tn in EXCLUDED_TASKS:
1720
  return False, 'excluded', None, time.time() - t_start, path
1721
 
1722
  # 1. Try analytical solvers (fast, tiny models)
1723
  for sname, sfn in ANALYTICAL_SOLVERS:
1724
- try:
1725
- if use_channel_reduction and sname in ('transpose', 'flip', 'rotate', 'mirror_h', 'mirror_v', 'quad_mirror', 'shift', 'spatial_gather', 'varshape_spatial_gather'):
1726
- model = _try_channel_reduction(sfn, td)
1727
- else:
1728
- model = sfn(td)
1729
- if model is None: continue
1730
- onnx.save(model, path)
1731
- if validate(path, td):
1732
- return True, sname, os.path.getsize(path), time.time() - t_start, path
1733
- except Exception as e:
1734
- pass
1735
-
1736
- # 2. Try composition solvers
1737
- for sname, sfn in COMPOSITION_SOLVERS:
1738
  try:
1739
  model = sfn(td)
1740
  if model is None: continue
1741
  onnx.save(model, path)
1742
- if validate(path, td):
1743
  return True, sname, os.path.getsize(path), time.time() - t_start, path
1744
- except Exception:
1745
- pass
1746
 
1747
- # 3. Determine task shape category and try conv solvers
1748
  exs = get_exs(td)
1749
  same_shape = all(inp.shape == out.shape for inp, out in exs)
1750
  shapes = set(inp.shape for inp, _ in exs)
@@ -1762,6 +1530,12 @@ def solve_task(tn, td, outdir, conv_budget=30.0, use_channel_reduction=True):
1762
  if result is not None:
1763
  sname, model = result
1764
  return True, sname, os.path.getsize(path), time.time() - t_start, path
 
 
 
 
 
 
1765
  else:
1766
  sp = fixed_shapes(td)
1767
  if sp is not None:
@@ -1772,6 +1546,7 @@ def solve_task(tn, td, outdir, conv_budget=30.0, use_channel_reduction=True):
1772
  sname, model = result
1773
  return True, sname, os.path.getsize(path), time.time() - t_start, path
1774
 
 
1775
  result = solve_conv_var_diff(td, path, time_budget=conv_time)
1776
  if result is not None:
1777
  sname, model = result
@@ -1779,7 +1554,7 @@ def solve_task(tn, td, outdir, conv_budget=30.0, use_channel_reduction=True):
1779
 
1780
  return False, None, None, time.time() - t_start, path
1781
 
1782
- def run_tasks(task_nums, tasks, output_dir, conv_budget, use_wandb, use_channel_reduction=True):
1783
  results = {}
1784
  costs_dict = {}
1785
  total_score = 0
@@ -1791,7 +1566,7 @@ def run_tasks(task_nums, tasks, output_dir, conv_budget, use_wandb, use_channel_
1791
  continue
1792
 
1793
  td = tasks[tn]['data']
1794
- ok, sname, sz, t_task, model_path = solve_task(tn, td, output_dir, conv_budget, use_channel_reduction)
1795
 
1796
  if ok:
1797
  macs, memory, params = score_network(model_path)
@@ -1831,7 +1606,6 @@ def main():
1831
  parser.add_argument('--tasks', type=str, default='')
1832
  parser.add_argument('--device', type=str, default='auto', choices=['auto','cpu','cuda'])
1833
  parser.add_argument('--use_wandb', action='store_true')
1834
- parser.add_argument('--no_channel_reduction', action='store_true', help='Disable channel reduction wrapper')
1835
  args = parser.parse_args()
1836
  global ORT_PROVIDERS
1837
  config = {
@@ -1849,7 +1623,6 @@ def main():
1849
 
1850
  ort.set_default_logger_severity(3)
1851
  print(f"Using providers: {ORT_PROVIDERS}")
1852
- print(f"OPSET: 17 (v5)")
1853
 
1854
  if args.kaggle:
1855
  tasks = load_tasks_kaggle(args.data_dir)
@@ -1857,6 +1630,7 @@ def main():
1857
  arcgen = args.arcgen_dir if args.arcgen_dir else None
1858
  tasks = load_tasks_dir(args.data_dir, arcgen_dir=arcgen)
1859
 
 
1860
  total_arcgen = sum(len(t['data'].get('arc-gen', [])) for t in tasks.values())
1861
  print(f"Loaded {len(tasks)} tasks ({total_arcgen} ARC-GEN examples)")
1862
  print(f"Excluded tasks: {sorted(EXCLUDED_TASKS)}")
@@ -1865,17 +1639,14 @@ def main():
1865
  active_tasks = [t for t in task_nums if t not in EXCLUDED_TASKS]
1866
  print(f"Solving {len(active_tasks)} active tasks (skipping {len(task_nums) - len(active_tasks)} excluded)")
1867
  print(f"Conv budget: {args.conv_budget}s per task")
1868
- print(f"Channel reduction: {'enabled' if not args.no_channel_reduction else 'disabled'}")
1869
  print("=" * 70)
1870
  t0 = time.time()
1871
 
1872
- use_ch_red = not args.no_channel_reduction
1873
-
1874
  if args.use_wandb and wandb is not None:
1875
  with wandb.init(project="neurogolf", name="solver_run", config=config):
1876
- results, costs_dict, total_score = run_tasks(task_nums, tasks, args.output_dir, args.conv_budget, use_wandb=True, use_channel_reduction=use_ch_red)
1877
  else:
1878
- results, costs_dict, total_score = run_tasks(task_nums, tasks, args.output_dir, args.conv_budget, use_wandb=False, use_channel_reduction=use_ch_red)
1879
 
1880
  elapsed = time.time() - t0
1881
  print(f"\n{'='*70}")
@@ -1884,11 +1655,13 @@ def main():
1884
  sc = Counter(solver_names)
1885
  for s, c in sc.most_common(): print(f" {s}: {c}")
1886
 
 
1887
  outdir = args.output_dir
1888
  n_files = len([f for f in os.listdir(outdir) if f.endswith('.onnx')])
1889
  total_size = sum(os.path.getsize(os.path.join(outdir, f))
1890
  for f in os.listdir(outdir) if f.endswith('.onnx'))
1891
 
 
1892
  zip_path = os.path.join(os.path.dirname(outdir) or '.', 'submission.zip')
1893
  buf = io.BytesIO()
1894
  with zipfile.ZipFile(buf, 'w', zipfile.ZIP_DEFLATED) as zf:
@@ -1900,6 +1673,7 @@ def main():
1900
  f.write(zip_bytes)
1901
  zip_size = len(zip_bytes)
1902
 
 
1903
  csv_path = os.path.join(os.path.dirname(outdir) or '.', 'submission.csv')
1904
  with open(csv_path, 'w', newline='') as f:
1905
  w = csv.writer(f)
@@ -1907,6 +1681,7 @@ def main():
1907
  for tn in sorted(costs_dict.keys()):
1908
  w.writerow([f'task{tn:03d}', costs_dict[tn]])
1909
 
 
1910
  unsolved_count = len(active_tasks) - len(results)
1911
  est_lb = total_score + unsolved_count * 1.0
1912
 
@@ -1916,4 +1691,4 @@ def main():
1916
  print(f"Written: {zip_path} | {csv_path}")
1917
 
1918
  if __name__ == '__main__':
1919
- main()
 
1
  #!/usr/bin/env python3
2
  """
3
+ ARC-AGI NeuroGolf Championship - Complete Solver v4
4
+ Format: [1,10,30,30] one-hot input/output, opset 10, IR version 10.
5
+
6
+ v4 CRITICAL FIXES:
7
+ - ARC-GEN data loaded and used for conv fitting (more data = better lstsq)
8
+ - ARC-GEN validation: models validated against train+test+arc-gen
9
+ - EXCLUDED tasks: {21, 55, 80, 184, 202, 366} skipped
10
+ - submission.csv generation for Kaggle
11
+ - s_flip fixed: GatherElements -> Gather (opset 10 compat)
12
+ - Static profiler: no onnx_tool dependency for cost estimation
13
+ - get_exs_for_fitting(): uses train+test+arc-gen for conv fitting
14
 
15
  Solvers:
16
  - Analytical: identity, constant, color_map, transpose, flip, rotate, tile, upscale,
17
  concat, concat_enhanced, spatial_gather, varshape_spatial_gather,
18
+ diagonal_tile, kronecker
 
 
19
  - Conv (fixed shape): Slice -> Conv -> ArgMax -> Equal+Cast -> Pad
20
  - Conv (variable shape): Conv(30x30) -> ArgMax -> Equal+Cast -> Mul(mask)
21
  - Conv (diff shape): Slice -> Conv -> Slice(crop) -> ArgMax -> Equal+Cast -> Pad
 
22
 
23
  Usage:
24
  python neurogolf_solver.py --data_dir ARC-AGI/data/training/ --output_dir submission
25
  python neurogolf_solver.py --data_dir ARC-AGI/data/training/ --output_dir submission --conv_budget 60 --arcgen_dir ARC-GEN-100K/
26
  """
27
 
28
+ import json, os, sys, math, time, argparse, csv, io, zipfile
29
  import numpy as np
30
  import onnx
31
  from onnx import helper, TensorProto, numpy_helper
 
47
  GRID_SHAPE = [BATCH, CH, GH, GW]
48
  DT = TensorProto.FLOAT
49
  IR = 10
50
+ OPSET = [helper.make_opsetid("", 10)]
 
51
 
52
  # Officially excluded tasks (score 0 regardless)
53
  EXCLUDED_TASKS = {21, 55, 80, 184, 202, 366}
54
 
55
  # Max ARC-GEN examples to use for validation (to keep runtime reasonable)
56
  MAX_ARCGEN_VALIDATE = 30
57
+ # Max ARC-GEN examples for conv fitting (keep separate from validation!)
58
+ # NOTE: Conv fitting uses train+test only. ARC-GEN is for VALIDATION only.
59
+ # lstsq underdetermines with too many variable-size arc-gen examples.
60
+ MAX_ARCGEN_FIT = 0 # Don't use arc-gen for fitting — use for validation only
61
 
62
  def get_providers():
63
  return ['CPUExecutionProvider']
 
76
  with open(os.path.join(data_dir, f)) as fh:
77
  data = json.load(fh)
78
  hex_id = f.replace('.json','')
79
+ # Load ARC-GEN data if available
80
  if arcgen_dir and os.path.exists(os.path.join(arcgen_dir, f)):
81
  with open(os.path.join(arcgen_dir, f)) as fh:
82
  arcgen_examples = json.load(fh)
 
109
  return arr
110
 
111
  def validate(path, td):
112
+ """Validate model against ALL examples: train + test + arc-gen.
113
+ This matches what Kaggle does for scoring."""
114
  try:
115
  opts = ort.SessionOptions()
116
  opts.log_severity_level = 3
 
118
  except:
119
  return False
120
  examples = td['train'] + td['test']
121
+ # Include arc-gen examples (capped for speed)
122
  if 'arc-gen' in td:
123
  examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]
124
  for ex in examples:
 
164
  MAX_FILESIZE = int(1.44 * 1024 * 1024)
165
 
166
  def score_network(path):
167
+ """Static profiler matching Kaggle scoring: cost = macs + memory + params.
168
+ Falls back to official neurogolf_utils if available."""
169
  if HAS_ONNX_TOOL:
170
  try:
171
  return _score_network_official(path)
 
227
  for ex in td['train'] + td['test']]
228
 
229
  def get_exs_for_fitting(td):
230
+ """Get examples for conv fitting. Uses train+test + arc-gen WHERE SIZES MATCH.
231
+ For fixed-size tasks, arc-gen examples have the same grid size,
232
+ so they provide more data points for lstsq without changing the feature dimension.
233
+ For variable-size tasks, only use train+test (arc-gen varies too much)."""
234
  base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
235
  for ex in td['train'] + td['test']]
236
 
237
  if not base_exs:
238
  return base_exs
239
 
240
+ # Check if all base examples have same input shape
241
  base_shapes = {inp.shape for inp, _ in base_exs}
242
  if len(base_shapes) != 1:
243
+ return base_exs # Variable sizes — don't add arc-gen
244
 
245
  base_shape = list(base_shapes)[0]
246
 
247
+ # Add arc-gen examples that match the base shape
248
  ag_exs = []
249
  for ex in td.get('arc-gen', []):
250
  inp = np.array(ex['input'], dtype=np.int64)
 
252
  if inp.shape == base_shape and out.shape == base_exs[0][1].shape:
253
  ag_exs.append((inp, out))
254
 
255
+ # Cap to avoid massive lstsq (diminishing returns after ~10)
256
  return base_exs + ag_exs[:10]
257
 
258
  def get_exs_for_fitting_variable(td):
259
+ """Get examples for variable-shape conv fitting.
260
+ For variable-shape tasks, arc-gen examples may have different sizes per example
261
+ but since we embed in 30x30 anyway, we can safely include them."""
262
  base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
263
  for ex in td['train'] + td['test']]
264
 
265
+ # For variable shape, include arc-gen examples (they get embedded in 30x30)
266
  ag_exs = []
267
  for ex in td.get('arc-gen', []):
268
  inp = np.array(ex['input'], dtype=np.int64)
 
279
  return list(shapes)[0] if len(shapes) == 1 else None
280
 
281
  # ============================================================
282
+ # GATHER HELPERS
283
  # ============================================================
284
 
285
  def _build_gather_model(OH, OW, idx):
286
+ # Use Gather (opset 1) instead of GatherElements (opset 11)
287
+ # Flatten spatial: [1,10,900] -> Gather(axis=2, indices=[900]) -> [1,10,900]
288
  flat_idx = np.zeros((GH*GW,), dtype=np.int64)
289
  mask = np.zeros((1,1,GH,GW), dtype=np.float32)
290
  for oi in range(OH):
 
306
  return mk(nodes, inits)
307
 
308
  def _build_gather_model_with_const(IH, IW, OH, OW, idx, cst):
309
+ # Use Gather (opset 1) instead of GatherElements (opset 11)
310
  flat_idx = np.zeros((GH*GW,), dtype=np.int64)
311
  gather_mask = np.zeros((1,1,GH,GW), dtype=np.float32)
312
  const_oh = np.zeros((1,10,GH,GW), dtype=np.float32)
 
337
  nodes[-1] = helper.make_node('Mul', ['raw','gmask'], ['output'])
338
  return mk(nodes, inits)
339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  # ============================================================
341
  # ANALYTICAL SOLVERS
342
  # ============================================================
 
346
  if ex['input'] != ex['output']: return None
347
  return mk([helper.make_node('Identity', ['input'], ['output'])])
348
 
349
+ def s_color_map(td):
 
350
  cm = {}
351
  for ex in td['train']+td['test']:
352
  inp, out = np.array(ex['input']), np.array(ex['output'])
 
355
  iv, ov = int(iv), int(ov)
356
  if iv in cm and cm[iv] != ov: return None
357
  cm[iv] = ov
358
+
359
+ # Check if it's a permutation (bijective + all mapped colors form a closed set)
360
+ is_permutation = (set(cm.keys()) == set(cm.values()))
 
 
 
361
 
362
  if is_permutation:
363
+ # Use channel Gather — zero MACs, much cheaper
364
  gather_ch = np.arange(10, dtype=np.int32)
365
  for src, dst in cm.items():
366
  if 0 <= src < 10 and 0 <= dst < 10:
 
369
  nodes = [helper.make_node('Gather', ['input', 'gi'], ['output'], axis=1)]
370
  return mk(nodes, inits)
371
  else:
372
+ # Non-permutation: use Conv 1x1 (has MACs but handles any mapping)
373
  W = np.zeros((10,10,1,1), dtype=np.float32)
374
  for ic in range(10):
375
  W[cm.get(ic,ic), ic, 0, 0] = 1.0
376
  return mk([helper.make_node('Conv', ['input','W'], ['output'], kernel_shape=[1,1])],
377
  [numpy_helper.from_array(W, 'W')])
378
 
 
 
 
 
 
 
379
def s_transpose(td):
    """Solve tasks whose output grid is the plain transpose of the input.

    Returns None unless every train and test example satisfies
    ``output == input.T``; otherwise returns a model made of a single
    Transpose node that swaps the last two axes (perm=[0,1,3,2]).
    """
    examples = td['train'] + td['test']
    is_transpose = all(
        np.array_equal(np.array(ex['output']), np.array(ex['input']).T)
        for ex in examples
    )
    if not is_transpose:
        return None
    return mk([helper.make_node('Transpose', ['input'], ['output'], perm=[0,1,3,2])])
 
 
 
383
 
384
  def s_flip(td):
385
  exs = get_exs(td)
 
389
  if (IH,IW) != (OH,OW): return None
390
  for axis, flip_fn in [(0, np.flipud), (1, np.fliplr)]:
391
  if all(np.array_equal(out, flip_fn(inp)) for inp, out in exs):
392
+ # Build gather index map (using Gather, opset 1 compatible)
393
+ idx = np.zeros((OH,OW,2), dtype=np.int64)
394
+ for r in range(OH):
395
+ for c in range(OW):
396
+ if axis == 0:
397
+ idx[r,c] = [IH-1-r, c]
398
+ else:
399
+ idx[r,c] = [r, IW-1-c]
400
+ return _build_gather_model(OH, OW, idx)
401
  return None
402
 
403
  def s_rotate(td):
 
406
  if sp is None: return None
407
  (IH,IW),(OH,OW) = sp
408
  for k in [1, 2, 3]:
409
+ if not all(np.array_equal(out, np.rot90(inp, k)) for inp, out in exs): continue
410
+ idx = np.zeros((OH,OW,2), dtype=np.int64)
411
+ for r in range(OH):
412
+ for c in range(OW):
413
+ if k == 1: sr, sc = c, IH-1-r
414
+ elif k == 2: sr, sc = IH-1-r, IW-1-c
415
+ elif k == 3: sr, sc = IW-1-c, r
416
+ idx[r,c] = [sr, sc]
417
+ return _build_gather_model(OH, OW, idx)
418
  return None
419
 
420
  def s_spatial_gather(td):
 
440
  def s_varshape_spatial_gather(td):
441
  """Spatial gather that works for variable-shape tasks by embedding in 30x30."""
442
  sp = fixed_shapes(td)
443
+ if sp is not None: return None # fixed shapes handled by s_spatial_gather
444
  exs = get_exs(td)
445
 
446
+ # Embed all examples in 30x30
447
  exs_30 = []
448
  for inp, out in exs:
449
  ih, iw = inp.shape
 
499
  numpy_helper.from_array(np.array([1,10,IH,IW], dtype=np.int64), 'en'),
500
  numpy_helper.from_array(np.array([1,1,rH,rW], dtype=np.int64), 'rp'),
501
  ]
 
 
502
  nodes = [
503
  helper.make_node('Slice', ['input','st','en'], ['cr']),
504
  helper.make_node('Tile', ['cr','rp'], ['tl']),
505
+ helper.make_node('Pad', ['tl'], ['output'], pads=[0,0,0,0,0,0,pad_h,pad_w], value=0.0),
506
  ]
 
507
  return mk(nodes, inits)
508
 
509
  def s_upscale(td):
 
583
  if sp is None: return None
584
  (IH,IW),(OH,OW) = sp
585
  if IH == OH and IW == OW: return None
586
+
587
+ # Need block decomposition
588
  if OH % IH != 0 or OW % IW != 0: return None
589
  rH, rW = OH // IH, OW // IW
590
  if rH * rW > 16 or rH * rW < 2: return None
591
  if OH > 30 or OW > 30: return None
592
 
593
+ # All 8 symmetry transforms of the dihedral group
594
  transforms = [
595
+ ('id', lambda x: x),
596
+ ('fliplr', lambda x: np.fliplr(x)),
597
+ ('flipud', lambda x: np.flipud(x)),
598
+ ('rot180', lambda x: np.rot90(x, 2)),
599
+ ('rot90', lambda x: np.rot90(x, 1)),
600
+ ('rot270', lambda x: np.rot90(x, 3)),
601
+ ('T', lambda x: x.T),
602
+ ('T_fliplr', lambda x: np.fliplr(x.T)),
603
  ]
604
 
605
+ # For each block, find which transform matches
606
  block_transforms = {}
607
  for bi in range(rH):
608
  for bj in range(rW):
 
613
  block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
614
  expected = tfn(inp)
615
  if expected.shape != (IH, IW) or not np.array_equal(block, expected):
616
+ ok = False
617
+ break
618
  if ok:
619
  found = (tidx, tname)
620
  break
621
+ if found is None:
622
+ return None
623
  block_transforms[(bi, bj)] = found
624
 
625
+ # Build index map
626
  idx = np.zeros((OH, OW, 2), dtype=np.int64)
627
  for bi in range(rH):
628
  for bj in range(rW):
 
640
  elif tname == 'T_fliplr': sr, sc = IW-1-lc, lr
641
  idx[oi, oj] = [sr, sc]
642
 
643
+ # Verify
644
  for inp, out in exs:
645
  reconstructed = np.zeros_like(out)
646
  for oi in range(OH):
647
  for oj in range(OW):
648
  reconstructed[oi,oj] = inp[idx[oi,oj,0], idx[oi,oj,1]]
649
+ if not np.array_equal(reconstructed, out):
650
+ return None
651
 
652
  return _build_gather_model(OH, OW, idx)
653
 
654
def s_input_driven_tile(td):
    """Each non-zero input pixel controls a block that's a copy of the input.

    Detection only: the pattern is checked against every example, but no
    model is built, so this solver returns None on every path.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if OH % IH != 0 or OW % IW != 0:
        return None
    # The output must be an IH x IW grid of blocks, each the size of the input.
    if (OH // IH, OW // IW) != (IH, IW):
        return None
    if OH > 30 or OW > 30:
        return None

    for inp, out in exs:
        for bi, bj in np.ndindex(IH, IW):
            block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
            if inp[bi, bj] != 0:
                # Active cell: its block must reproduce the whole input.
                matches = np.array_equal(block, inp)
            else:
                # Inactive cell: its block must be empty.
                matches = np.all(block == 0)
            if not matches:
                return None

    # Whether a block is active depends on the input's content, so a static
    # gather index cannot express it. It would need something like
    # Tile(input) multiplied by a mask derived from the input itself.
    # Skipped for now; spatial_gather can handle the case where block
    # positions are fixed.
    return None
685
+
686
  def s_kronecker(td):
687
+ """output = kron(input, ones(sH,sW)) — nearest-neighbor upscaling."""
688
  exs = get_exs(td)
689
  sp = fixed_shapes(td)
690
  if sp is None: return None
 
693
  sH, sW = OH // IH, OW // IW
694
  if sH < 2 or sW < 2: return None
695
  if OH > 30 or OW > 30: return None
696
+
697
  for inp, out in exs:
698
  expected = np.kron(inp, np.ones((sH, sW), dtype=np.int64))
699
+ if not np.array_equal(out, expected):
700
+ return None
701
+
702
+ # This is identical to upscale - build gather index
703
  idx = np.zeros((OH,OW,2), dtype=np.int64)
704
  for r in range(OH):
705
  for c in range(OW):
 
707
  return _build_gather_model(OH, OW, idx)
708
 
709
  def s_diagonal_tile(td):
710
+ """Input placed along diagonal: block[i,i] = input, rest = 0."""
711
  exs = get_exs(td)
712
  sp = fixed_shapes(td)
713
  if sp is None: return None
 
716
  rH, rW = OH // IH, OW // IW
717
  if rH != rW or rH < 2: return None
718
  if OH > 30 or OW > 30: return None
719
+
720
  for inp, out in exs:
721
  for bi in range(rH):
722
  for bj in range(rW):
723
  block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
724
  if bi == bj:
725
+ if not np.array_equal(block, inp):
726
+ return None
727
  else:
728
+ if not np.all(block == 0):
729
+ return None
730
+
731
+ # Build: diagonal blocks map to input, off-diagonal are constant 0
732
  idx = np.zeros((OH,OW,2), dtype=np.int64)
733
  cst = np.full((OH,OW), -1, dtype=np.int64)
734
  for bi in range(rH):
 
736
  for lr in range(IH):
737
  for lc in range(IW):
738
  oi, oj = bi*IH + lr, bj*IW + lc
739
+ if bi == bj:
740
+ idx[oi, oj] = [lr, lc]
741
+ else:
742
+ idx[oi, oj] = [-1, -1]
743
+ cst[oi, oj] = 0
744
+
745
  return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
746
 
747
  def s_shift(td):
748
+ """Detect constant spatial shift of the grid."""
749
  exs = get_exs(td)
750
  sp = fixed_shapes(td)
751
  if sp is None: return None
 
765
  if not np.array_equal(shifted, out):
766
  ok = False; break
767
  if not ok: continue
768
+ # Build gather index
769
  idx = np.zeros((OH, OW, 2), dtype=np.int64)
770
+ cst = np.full((OH, OW), 0, dtype=np.int64) # zeros for out-of-bounds
771
  for r in range(OH):
772
  for c in range(OW):
773
  sr, sc = r - dr, c - dc
774
+ if 0 <= sr < IH and 0 <= sc < IW:
775
+ idx[r, c] = [sr, sc]
776
+ else:
777
+ idx[r, c] = [-1, -1]
778
  return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
779
  return None
780
 
781
def s_gravity(td):
    """Recognise "gravity" tasks (non-zero cells slide to one edge).

    The output must have the same fixed shape as the input. For each of the
    four directions the examples are checked against a reference compaction.

    NOTE: this is currently a detector only. Where cells land depends on the
    grid content, so a static gather index cannot express it, and no dedicated
    ONNX graph exists yet. The function therefore returns None in every case,
    whether or not a direction matched.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    in_shape, out_shape = sp
    if tuple(in_shape) != tuple(out_shape):
        return None

    def _gravity(grid, direction):
        """Return a copy of `grid` with non-zero cells packed towards `direction`."""
        settled = np.zeros_like(grid)
        vertical = direction in ('down', 'up')
        toward_end = direction in ('down', 'right')
        # Work line by line: columns for vertical gravity, rows otherwise.
        # `.T` is a view, so writing into `dst` lines fills `settled` in place.
        src = grid.T if vertical else grid
        dst = settled.T if vertical else settled
        for line_in, line_out in zip(src, dst):
            kept = line_in[line_in != 0]
            if toward_end:
                line_out[line_out.size - kept.size:] = kept
            else:
                line_out[:kept.size] = kept
        return settled

    for direction in ('down', 'up', 'left', 'right'):
        matches = True
        for inp, out in exs:
            if not np.array_equal(_gravity(inp, direction), out):
                matches = False
                break
        if matches:
            # Pattern recognised, but it needs a specialised ONNX graph
            # (neither a static Gather nor a conv can express the sorting).
            return None
    return None
811
+
812
  def s_mirror_h(td):
813
+ """Output = input | flip(input, horizontal), doubling width."""
814
  exs = get_exs(td)
815
  sp = fixed_shapes(td)
816
  if sp is None: return None
 
820
  for inp, out in exs:
821
  expected = np.concatenate([inp, np.flip(inp, 1)], 1)
822
  if not np.array_equal(expected, out): return None
823
+ # Build gather index
824
  idx = np.zeros((OH, OW, 2), dtype=np.int64)
825
  for r in range(OH):
826
  for c in range(OW):
 
829
  return _build_gather_model(OH, OW, idx)
830
 
831
  def s_mirror_v(td):
832
+ """Output = input over flip(input, vertical), doubling height."""
833
  exs = get_exs(td)
834
  sp = fixed_shapes(td)
835
  if sp is None: return None
 
847
  return _build_gather_model(OH, OW, idx)
848
 
849
  def s_quad_mirror(td):
850
+ """Output = 2x2 block of input with h/v flips."""
851
  exs = get_exs(td)
852
  sp = fixed_shapes(td)
853
  if sp is None: return None
 
869
  return _build_gather_model(OH, OW, idx)
870
 
871
  def s_fixed_crop(td):
872
+ """Output = fixed subregion of input."""
873
  exs = get_exs(td)
874
  sp = fixed_shapes(td)
875
  if sp is None: return None
 
886
  return None
887
 
888
  def s_nonuniform_scale(td):
889
+ """Output = input scaled by different factors in h and w."""
890
  exs = get_exs(td)
891
  sp = fixed_shapes(td)
892
  if sp is None: return None
 
919
  helper.make_node('Add', ['s','c'], ['output'])]
920
  return mk(nodes, inits)
921
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
922
  # ============================================================
923
+ # CONV SOLVERS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
924
  # ============================================================
925
 
926
  def add_onehot_block(nodes, inits, am_name, oh_name):
 
930
  nodes.append(helper.make_node('Equal', [am_name, 'classes'], ['eq']))
931
  nodes.append(helper.make_node('Cast', ['eq'], [oh_name], to=TensorProto.FLOAT))
932
 
933
+ def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False):
934
+ """Shared lstsq conv fitting. Returns (Wconv, B) or None."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
935
  pad = ks // 2
936
  feat = 10 * ks * ks + (1 if use_bias else 0)
937
  if feat > 20000: return None
938
 
939
+ patches, targets = [], []
940
  for inp_g, out_g in exs_raw:
941
  ih, iw = inp_g.shape
942
  if use_full_30:
 
949
  oh_pad = np.pad(oh_enc, ((0,0),(pad,pad),(pad,pad)))
950
 
951
  oh, ow = out_g.shape
952
+ for r in range(oh):
953
+ for c in range(ow):
954
+ p = oh_pad[:, r:r+ks, c:c+ks].flatten()
955
+ if use_bias: p = np.append(p, 1.0)
956
+ patches.append(p)
957
+ targets.append(int(out_g[r, c]))
958
+
959
+ n_patches = len(patches)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
960
  if feat > 5000 and n_patches > 2000: return None
961
 
962
+ P = np.array(patches, dtype=np.float64)
963
+ T = np.array(targets, dtype=np.int64)
964
  T_oh = np.zeros((len(T), 10), dtype=np.float64)
965
  for i, t in enumerate(T): T_oh[i, t] = 1.0
966
 
967
+ WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
968
  if not np.array_equal(np.argmax(P @ WT, axis=1), T): return None
969
 
970
  if use_bias:
 
975
  B = None
976
  return Wconv, B
977
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
978
  def solve_conv_fixed(td, path, time_budget=30.0):
979
  """Fixed-shape conv: Slice -> Conv -> ArgMax -> Equal+Cast -> Pad."""
980
  exs = get_exs(td)
 
984
  if len(shapes) != 1: return None
985
  IH, IW = shapes.pop()
986
 
987
+ # Use ARC-GEN data for better fitting
988
  fit_exs = get_exs_for_fitting(td)
989
+ # Filter to same-shape, same IH/IW
990
  fit_exs = [(i,o) for i,o in fit_exs if i.shape == o.shape and i.shape == (IH, IW)]
991
 
992
  t_start = time.time()
993
  for use_bias in [False, True]:
994
+ for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
995
  if time.time() - t_start > time_budget: return None
996
  result = _lstsq_conv(fit_exs, ks, use_bias, use_full_30=False)
997
  if result is None: continue
 
1015
  helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
1016
  ]
1017
  add_onehot_block(nodes, inits, 'am', 'oh_out')
 
 
1018
  nodes.append(
1019
+ helper.make_node('Pad', ['oh_out'], ['output'],
1020
+ pads=[0,0,0,0,0,0,pad_h,pad_w], value=0.0)
1021
  )
1022
 
1023
  model = mk(nodes, inits)
 
1031
  for inp, out in exs:
1032
  if inp.shape != out.shape: return None
1033
 
1034
+ # Use ARC-GEN data for better fitting (variable shape, embedded in 30x30)
1035
  fit_exs = get_exs_for_fitting_variable(td)
1036
  fit_exs = [(i,o) for i,o in fit_exs if i.shape == o.shape]
1037
 
 
1140
  helper.make_node('ArgMax', ['co_crop'], ['am'], axis=1, keepdims=1),
1141
  ]
1142
  add_onehot_block(nodes, inits, 'am', 'oh_out')
 
 
1143
  nodes.append(
1144
+ helper.make_node('Pad', ['oh_out'], ['output'],
1145
+ pads=[0,0,0,0,0,0,pad_h,pad_w], value=0.0)
1146
  )
1147
 
1148
  model = mk(nodes, inits)
 
1151
  return None
1152
 
1153
  def solve_conv_var_diff(td, path, time_budget=30.0):
1154
+ """Variable diff-shape conv: Conv(30x30) -> ArgMax -> Equal+Cast -> Mul(output_mask).
1155
+ Works when output shape differs from input but mapping is convolutional on 30x30 grid."""
1156
  exs = get_exs(td)
1157
 
1158
  t_start = time.time()
 
1200
  Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
1201
  B = None
1202
 
1203
+ # Use ReduceSum of output channels as mask (sum across channels == 1 for valid pixels)
1204
+ # But we don't know the output mask at inference time from input alone...
1205
+ # We need a way to derive the output mask from the input.
1206
+ # For same-shape: mask = ReduceSum(input, axis=1) works
1207
+ # For diff-shape: we need to compute the output mask differently
1208
+ #
1209
+ # Approach: Conv output at valid positions should have max > threshold,
1210
+ # and at padding positions max ≈ 0. Use the ArgMax+OneHot and then
1211
+ # mask with ReduceSum(input) which is 1 at input positions but 0 at padding.
1212
+ # BUT output may be LARGER than input...
1213
+ #
1214
+ # Alternative: just use Conv -> ArgMax -> Equal+Cast -> Mul(input_mask_expanded)
1215
+ # where input_mask covers the output region too.
1216
+ # This won't work if output extends beyond input region.
1217
+ #
1218
+ # Simplest correct approach: let the conv produce valid one-hot everywhere,
1219
+ # then the padding region should naturally produce channel-0 output.
1220
+ # Since padding is all-zero input, conv output there = bias only.
1221
+ # If no bias, conv output = 0 for all channels -> argmax gives channel 0 -> onehot gives [1,0,...,0]
1222
+ # which equals the padding encoding (channel 0 = 1 in padding).
1223
+ # Wait - that's WRONG for the NeuroGolf format. In the padding region, ALL channels should be 0.
1224
+ # The one-hot encoding has channel[color]=1, but padding = ALL zeros.
1225
+ #
1226
+ # So we NEED a mask. But for diff-shape, what mask?
1227
+ # If output is always top-left aligned and we know max output size...
1228
+ # We can't statically determine the output mask from the input.
1229
+ #
1230
+ # However: we can try the ReduceSum approach anyway — if conv naturally
1231
+ # produces channel-0 dominant output in padding, then:
1232
+ # mask = ReduceSum(input, axis=1) gives 1 for input pixels, 0 for padding
1233
+ # If output region ⊆ input region, this works.
1234
+ # If output region > input region... we need the output's ReduceSum instead.
1235
+
1236
+ # For tasks where output fits within input bounds, use input mask
1237
  all_output_within_input = all(
1238
  out_g.shape[0] <= inp_g.shape[0] and out_g.shape[1] <= inp_g.shape[1]
1239
  for inp_g, out_g in exs
1240
  )
1241
 
1242
  if not all_output_within_input:
1243
+ continue # Skip tasks where output extends beyond input
1244
 
1245
  inits = [numpy_helper.from_array(Wconv, 'W')]
1246
  conv_inputs = ['input', 'W']
 
1262
  return None
1263
 
1264
  # ============================================================
1265
+ # PYTORCH LEARNED CONV (gradient descent, multi-seed, ternary snap)
1266
+ # ============================================================
1267
+
1268
+ def _ternary_snap(w, eps=0.2):
1269
+ """Snap weights to {-1, 0, 1} — smaller model, often still correct."""
1270
+ return np.where(w > eps, 1.0, np.where(w < -eps, -1.0, 0.0)).astype(np.float32)
1271
+
1272
def _build_conv_onnx_from_weights(W, ks, use_full_30=False, IH=None, IW=None):
    """Wrap a single conv weight tensor into a complete ONNX model.

    `W` is stored as initializer 'W' and applied by one Conv node with a
    ks x ks kernel and ks // 2 padding on every side. The conv output is
    reduced with ArgMax over axis 1 and expanded back to one-hot by
    `add_onehot_block`.

    use_full_30=True:  Conv reads 'input' directly; the one-hot result is
                       multiplied by ReduceSum(input, axis 1) as a mask.
    use_full_30=False: 'input' is first sliced to [1, 10, IH, IW]; the one-hot
                       result is padded with zeros by GH - IH and GW - IW at
                       the end of the last two axes. IH and IW must be given.

    Returns whatever `mk(nodes, inits)` builds.
    """
    half = ks // 2

    def conv_and_argmax(src):
        # Conv -> ArgMax pair shared by both graph layouts.
        return [
            helper.make_node('Conv', [src, 'W'], ['co'],
                             kernel_shape=[ks, ks], pads=[half] * 4),
            helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
        ]

    if use_full_30:
        inits = [numpy_helper.from_array(W, 'W')]
        nodes = [helper.make_node('ReduceSum', ['input'], ['mask'],
                                  axes=[1], keepdims=1)]
        nodes.extend(conv_and_argmax('input'))
        tail = helper.make_node('Mul', ['oh_out', 'mask'], ['output'])
    else:
        bottom, right = GH - IH, GW - IW
        inits = [
            numpy_helper.from_array(np.array([0, 0, 0, 0], dtype=np.int64), 'sl_st'),
            numpy_helper.from_array(np.array([1, 10, IH, IW], dtype=np.int64), 'sl_en'),
            numpy_helper.from_array(W, 'W'),
        ]
        nodes = [helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['grid'])]
        nodes.extend(conv_and_argmax('grid'))
        tail = helper.make_node('Pad', ['oh_out'], ['output'],
                                pads=[0, 0, 0, 0, 0, 0, bottom, right], value=0.0)

    # One-hot expansion goes between ArgMax and the final Mul/Pad node.
    add_onehot_block(nodes, inits, 'am', 'oh_out')
    nodes.append(tail)
    return mk(nodes, inits)
1307
+
1308
def _build_two_layer_conv_onnx(W1, W2, ks1, ks2, use_full_30=False, IH=None, IW=None):
    """Wrap two conv weight tensors into a Conv -> Relu -> Conv ONNX model.

    `W1` and `W2` are stored as initializers 'W1' and 'W2'. Each Conv uses a
    square kernel (ks1 / ks2) with ks // 2 padding on every side. The second
    conv's output goes through ArgMax over axis 1 and `add_onehot_block`.

    use_full_30=True:  the net reads 'input' directly and the one-hot result
                       is multiplied by ReduceSum(input, axis 1) as a mask.
    use_full_30=False: 'input' is first sliced to [1, 10, IH, IW]; the one-hot
                       result is zero-padded by GH - IH and GW - IW at the end
                       of the last two axes. IH and IW must be given.

    Returns whatever `mk(nodes, inits)` builds.
    """
    half1, half2 = ks1 // 2, ks2 // 2

    def trunk(src):
        # Conv -> Relu -> Conv -> ArgMax, identical in both layouts.
        return [
            helper.make_node('Conv', [src, 'W1'], ['h1'],
                             kernel_shape=[ks1, ks1], pads=[half1] * 4),
            helper.make_node('Relu', ['h1'], ['h1r']),
            helper.make_node('Conv', ['h1r', 'W2'], ['co'],
                             kernel_shape=[ks2, ks2], pads=[half2] * 4),
            helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
        ]

    if use_full_30:
        inits = [
            numpy_helper.from_array(W1, 'W1'),
            numpy_helper.from_array(W2, 'W2'),
        ]
        nodes = [helper.make_node('ReduceSum', ['input'], ['mask'],
                                  axes=[1], keepdims=1)]
        nodes.extend(trunk('input'))
        tail = helper.make_node('Mul', ['oh_out', 'mask'], ['output'])
    else:
        bottom, right = GH - IH, GW - IW
        inits = [
            numpy_helper.from_array(np.array([0, 0, 0, 0], dtype=np.int64), 'sl_st'),
            numpy_helper.from_array(np.array([1, 10, IH, IW], dtype=np.int64), 'sl_en'),
            numpy_helper.from_array(W1, 'W1'),
            numpy_helper.from_array(W2, 'W2'),
        ]
        nodes = [helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['grid'])]
        nodes.extend(trunk('grid'))
        tail = helper.make_node('Pad', ['oh_out'], ['output'],
                                pads=[0, 0, 0, 0, 0, 0, bottom, right], value=0.0)

    # One-hot expansion goes between ArgMax and the final Mul/Pad node.
    add_onehot_block(nodes, inits, 'am', 'oh_out')
    nodes.append(tail)
    return mk(nodes, inits)
1347
+
1348
def solve_pytorch_conv(td, path, time_budget=30.0):
    """Fit a small conv net by gradient descent and export it as ONNX.

    Only tasks where every example's output has the same shape as its input
    are handled. Phase 1 trains a single bias-free Conv2d for several kernel
    sizes and seeds; phase 2 trains Conv -> ReLU -> Conv. Each trained network
    is exported twice (raw weights and ternary-snapped weights), written to
    `path`, and kept only if `validate(path, td)` accepts it. Among accepted
    candidates the one with the smallest file size wins.

    Args:
        td: task dict; the 'train' and 'test' lists of {'input', 'output'}
            pairs are used as training data.
        path: file the ONNX candidates (and the final winner) are written to.
        time_budget: wall-clock seconds. Phase 1 may use up to 60% of it and
            phase 2 the rest; the limit is also enforced inside the training
            loops, so a single long fit cannot overrun it by much.

    Returns:
        (solver_name, model), with solver_name 'pt_conv' or 'pt_conv2' and the
        model saved at `path`; or None if torch is not installed, shapes
        differ, or no candidate validated.

    NOTE: when None is returned after at least one candidate was tried,
    `path` still holds the last rejected candidate.
    """
    try:
        import torch
        import torch.nn as nn
    except ImportError:
        return None

    exs = get_exs(td)
    if any(inp.shape != out.shape for inp, out in exs):
        return None  # Only same-shape tasks are handled for now

    shapes = {inp.shape for inp, _ in exs}
    fixed_in = len(shapes) == 1
    use_full = not fixed_in

    # Prepare tensors from the one-hot encoded train + test pairs.
    all_pairs = td['train'] + td['test']
    inp_t = torch.tensor(np.stack([to_onehot(p['input'])[0] for p in all_pairs]),
                         dtype=torch.float32)
    out_t = torch.tensor(np.stack([to_onehot(p['output'])[0] for p in all_pairs]),
                         dtype=torch.float32)

    IH = IW = None
    if fixed_in:
        IH, IW = next(iter(shapes))
        # Train on the cropped region only; the exported graph slices/pads.
        inp_t = inp_t[:, :, :IH, :IW]
        out_t = out_t[:, :, :IH, :IW]

    t_start = time.time()
    best_result = None

    def _fit(module, lr, max_steps, deadline):
        """Minimise MSE(module(inp_t), out_t) with Adam.

        Leaves `module` loaded with the weights that produced the lowest
        observed loss. Returns False if no usable loss was ever observed.
        """
        opt = torch.optim.Adam(module.parameters(), lr=lr)
        best_loss, best_state = float('inf'), None
        for _ in range(max_steps):
            opt.zero_grad()
            loss = nn.functional.mse_loss(module(inp_t), out_t)
            loss_val = loss.item()
            if loss_val < best_loss:
                best_loss = loss_val
                # Snapshot BEFORE the optimizer step, so the stored weights
                # are exactly the ones that produced `best_loss`.
                best_state = {k: v.detach().clone()
                              for k, v in module.state_dict().items()}
                if best_loss < 1e-8:
                    break
            loss.backward()
            opt.step()
            if time.time() > deadline:
                break  # honour the time budget inside the training loop too
        if best_state is None:
            return False
        module.load_state_dict(best_state)
        return True

    def _consider(model, sname):
        """Save `model` to `path`; remember it if it validates and is the smallest so far."""
        nonlocal best_result
        onnx.save(model, path)
        if validate(path, td):
            sz = os.path.getsize(path)
            if best_result is None or sz < best_result[2]:
                best_result = (sname, model, sz)

    # Phase 1: single-layer conv (multiple kernel sizes and seeds)
    phase1_deadline = t_start + time_budget * 0.6
    for ks in (1, 3, 5, 7):
        if time.time() > phase1_deadline:
            break
        for seed in (0, 7, 42):
            if time.time() > phase1_deadline:
                break
            torch.manual_seed(seed)
            conv = nn.Conv2d(CH, CH, kernel_size=ks, padding=ks // 2, bias=False)
            if seed == 0:
                nn.init.zeros_(conv.weight)  # seed 0 starts from all-zero weights
            if not _fit(conv, lr=0.03, max_steps=3000, deadline=phase1_deadline):
                continue
            w = conv.weight.detach().numpy()

            # Try continuous weights, then ternary-snapped
            for w_cand in (w, _ternary_snap(w)):
                _consider(
                    _build_conv_onnx_from_weights(
                        w_cand, ks, use_full_30=use_full, IH=IH, IW=IW),
                    'pt_conv',
                )

    # Phase 2: two-layer conv (Conv -> ReLU -> Conv)
    phase2_deadline = t_start + time_budget
    for ks1, ks2, hidden in ((3, 1, CH), (5, 1, CH), (3, 3, CH)):
        if time.time() > phase2_deadline:
            break
        for seed in (0, 7):
            if time.time() > phase2_deadline:
                break
            torch.manual_seed(seed)
            net = nn.Sequential(
                nn.Conv2d(CH, hidden, kernel_size=ks1, padding=ks1 // 2, bias=False),
                nn.ReLU(),
                nn.Conv2d(hidden, CH, kernel_size=ks2, padding=ks2 // 2, bias=False),
            )
            if not _fit(net, lr=0.01, max_steps=2500, deadline=phase2_deadline):
                continue
            w1 = net[0].weight.detach().numpy()
            w2 = net[2].weight.detach().numpy()

            for w1c, w2c in ((w1, w2), (_ternary_snap(w1), _ternary_snap(w2))):
                _consider(
                    _build_two_layer_conv_onnx(
                        w1c, w2c, ks1, ks2, use_full_30=use_full, IH=IH, IW=IW),
                    'pt_conv2',
                )

    if best_result is None:
        return None
    sname, model, _ = best_result
    # `path` currently holds the last candidate tried; rewrite the winner.
    onnx.save(model, path)
    return sname, model
1477
+
1478
+ # ============================================================
1479
+ # MAIN
1480
  # ============================================================
1481
 
1482
  ANALYTICAL_SOLVERS = [
1483
+ ('identity', s_identity), ('constant', s_constant), ('color_map', s_color_map),
1484
+ ('transpose', s_transpose), ('flip', s_flip), ('rotate', s_rotate),
1485
+ ('tile', s_tile), ('upscale', s_upscale), ('kronecker', s_kronecker),
 
 
 
 
 
 
1486
  ('nonuniform_scale', s_nonuniform_scale),
1487
+ ('mirror_h', s_mirror_h), ('mirror_v', s_mirror_v), ('quad_mirror', s_quad_mirror),
1488
+ ('concat', s_concat), ('concat_enhanced', s_concat_enhanced),
 
 
 
1489
  ('diagonal_tile', s_diagonal_tile),
1490
  ('fixed_crop', s_fixed_crop),
1491
  ('spatial_gather', s_spatial_gather),
 
1493
  ('varshape_spatial_gather', s_varshape_spatial_gather),
1494
  ]
1495
 
1496
+ def solve_task(tn, td, outdir, conv_budget=30.0):
 
 
 
 
 
 
1497
  t_start = time.time()
1498
  os.makedirs(outdir, exist_ok=True)
1499
  path = os.path.join(outdir, f"task{tn:03d}.onnx")
1500
 
1501
+ # Skip excluded tasks
1502
  if tn in EXCLUDED_TASKS:
1503
  return False, 'excluded', None, time.time() - t_start, path
1504
 
1505
  # 1. Try analytical solvers (fast, tiny models)
1506
  for sname, sfn in ANALYTICAL_SOLVERS:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1507
  try:
1508
  model = sfn(td)
1509
  if model is None: continue
1510
  onnx.save(model, path)
1511
+ if validate(path, td):
1512
  return True, sname, os.path.getsize(path), time.time() - t_start, path
1513
+ except: pass
 
1514
 
1515
+ # 2. Determine task shape category and try conv solvers
1516
  exs = get_exs(td)
1517
  same_shape = all(inp.shape == out.shape for inp, out in exs)
1518
  shapes = set(inp.shape for inp, _ in exs)
 
1530
  if result is not None:
1531
  sname, model = result
1532
  return True, sname, os.path.getsize(path), time.time() - t_start, path
1533
+ # 3. PyTorch learned conv as fallback for same-shape tasks
1534
+ remaining = max(1, conv_time - (time.time() - t_start))
1535
+ result = solve_pytorch_conv(td, path, time_budget=remaining)
1536
+ if result is not None:
1537
+ sname, model = result
1538
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
1539
  else:
1540
  sp = fixed_shapes(td)
1541
  if sp is not None:
 
1546
  sname, model = result
1547
  return True, sname, os.path.getsize(path), time.time() - t_start, path
1548
 
1549
+ # Try variable diff-shape conv (output within input bounds)
1550
  result = solve_conv_var_diff(td, path, time_budget=conv_time)
1551
  if result is not None:
1552
  sname, model = result
 
1554
 
1555
  return False, None, None, time.time() - t_start, path
1556
 
1557
+ def run_tasks(task_nums, tasks, output_dir, conv_budget, use_wandb):
1558
  results = {}
1559
  costs_dict = {}
1560
  total_score = 0
 
1566
  continue
1567
 
1568
  td = tasks[tn]['data']
1569
+ ok, sname, sz, t_task, model_path = solve_task(tn, td, output_dir, conv_budget)
1570
 
1571
  if ok:
1572
  macs, memory, params = score_network(model_path)
 
1606
  parser.add_argument('--tasks', type=str, default='')
1607
  parser.add_argument('--device', type=str, default='auto', choices=['auto','cpu','cuda'])
1608
  parser.add_argument('--use_wandb', action='store_true')
 
1609
  args = parser.parse_args()
1610
  global ORT_PROVIDERS
1611
  config = {
 
1623
 
1624
  ort.set_default_logger_severity(3)
1625
  print(f"Using providers: {ORT_PROVIDERS}")
 
1626
 
1627
  if args.kaggle:
1628
  tasks = load_tasks_kaggle(args.data_dir)
 
1630
  arcgen = args.arcgen_dir if args.arcgen_dir else None
1631
  tasks = load_tasks_dir(args.data_dir, arcgen_dir=arcgen)
1632
 
1633
+ # Count arc-gen examples
1634
  total_arcgen = sum(len(t['data'].get('arc-gen', [])) for t in tasks.values())
1635
  print(f"Loaded {len(tasks)} tasks ({total_arcgen} ARC-GEN examples)")
1636
  print(f"Excluded tasks: {sorted(EXCLUDED_TASKS)}")
 
1639
  active_tasks = [t for t in task_nums if t not in EXCLUDED_TASKS]
1640
  print(f"Solving {len(active_tasks)} active tasks (skipping {len(task_nums) - len(active_tasks)} excluded)")
1641
  print(f"Conv budget: {args.conv_budget}s per task")
 
1642
  print("=" * 70)
1643
  t0 = time.time()
1644
 
 
 
1645
  if args.use_wandb and wandb is not None:
1646
  with wandb.init(project="neurogolf", name="solver_run", config=config):
1647
+ results, costs_dict, total_score = run_tasks(task_nums, tasks, args.output_dir, args.conv_budget, use_wandb=True)
1648
  else:
1649
+ results, costs_dict, total_score = run_tasks(task_nums, tasks, args.output_dir, args.conv_budget, use_wandb=False)
1650
 
1651
  elapsed = time.time() - t0
1652
  print(f"\n{'='*70}")
 
1655
  sc = Counter(solver_names)
1656
  for s, c in sc.most_common(): print(f" {s}: {c}")
1657
 
1658
+ # Generate submission
1659
  outdir = args.output_dir
1660
  n_files = len([f for f in os.listdir(outdir) if f.endswith('.onnx')])
1661
  total_size = sum(os.path.getsize(os.path.join(outdir, f))
1662
  for f in os.listdir(outdir) if f.endswith('.onnx'))
1663
 
1664
+ # Create submission.zip
1665
  zip_path = os.path.join(os.path.dirname(outdir) or '.', 'submission.zip')
1666
  buf = io.BytesIO()
1667
  with zipfile.ZipFile(buf, 'w', zipfile.ZIP_DEFLATED) as zf:
 
1673
  f.write(zip_bytes)
1674
  zip_size = len(zip_bytes)
1675
 
1676
+ # Create submission.csv
1677
  csv_path = os.path.join(os.path.dirname(outdir) or '.', 'submission.csv')
1678
  with open(csv_path, 'w', newline='') as f:
1679
  w = csv.writer(f)
 
1681
  for tn in sorted(costs_dict.keys()):
1682
  w.writerow([f'task{tn:03d}', costs_dict[tn]])
1683
 
1684
+ # Estimate LB score: solved tasks get their score, unsolved get 1.0
1685
  unsolved_count = len(active_tasks) - len(results)
1686
  est_lb = total_score + unsolved_count * 1.0
1687
 
 
1691
  print(f"Written: {zip_path} | {csv_path}")
1692
 
1693
  if __name__ == '__main__':
1694
+ main()