rogermt commited on
Commit
4a06c47
·
verified ·
1 Parent(s): 2cde25d

v2: add --device cpu/cuda/auto flag, fix providers

Browse files
Files changed (1) hide show
  1. neurogolf_solver.py +67 -145
neurogolf_solver.py CHANGED
@@ -6,9 +6,7 @@ Pipeline: Slice -> Conv -> ArgMax -> OneHot -> Pad
6
 
7
  Usage:
8
  python neurogolf_solver.py --data_dir ARC-AGI/data/training/ --output_dir submission
9
-
10
- On Kaggle:
11
- python neurogolf_solver.py --data_dir /kaggle/input/competitions/neurogolf-2026/ --output_dir submission --kaggle
12
  """
13
 
14
  import json, os, sys, math, time, argparse
@@ -18,16 +16,22 @@ from onnx import helper, TensorProto, numpy_helper
18
  import onnxruntime as ort
19
  from collections import Counter
20
 
21
- # Constants
22
  BATCH, CH, GH, GW = 1, 10, 30, 30
23
  GRID_SHAPE = [BATCH, CH, GH, GW]
24
  DT = TensorProto.FLOAT
25
  IR = 10
26
  OPSET = [helper.make_opsetid("", 10)]
27
 
 
 
 
 
 
 
 
 
28
 
29
  def load_tasks_dir(data_dir):
30
- """Load tasks from directory of JSON files."""
31
  files = sorted(f for f in os.listdir(data_dir) if f.endswith('.json'))
32
  tasks = {}
33
  for i, f in enumerate(files):
@@ -35,9 +39,7 @@ def load_tasks_dir(data_dir):
35
  tasks[i+1] = {'hex': f.replace('.json',''), 'data': json.load(fh)}
36
  return tasks
37
 
38
-
39
  def load_tasks_kaggle(data_dir):
40
- """Load tasks from Kaggle competition format (task001.json etc.)."""
41
  tasks = {}
42
  for tn in range(1, 401):
43
  path = os.path.join(data_dir, f"task{tn:03d}.json")
@@ -46,7 +48,6 @@ def load_tasks_kaggle(data_dir):
46
  tasks[tn] = {'hex': f'task{tn:03d}', 'data': json.load(f)}
47
  return tasks
48
 
49
-
50
  def to_onehot(grid):
51
  arr = np.zeros((1, CH, GH, GW), dtype=np.float32)
52
  for r, row in enumerate(grid):
@@ -54,11 +55,9 @@ def to_onehot(grid):
54
  arr[0, v, r, c] = 1.0
55
  return arr
56
 
57
-
58
  def validate(path, td):
59
- """Validate ONNX model against all train+test examples."""
60
  try:
61
- sess = ort.InferenceSession(path, providers=['CPUExecutionProvider'])
62
  except:
63
  return False
64
  examples = td['train'] + td['test']
@@ -76,39 +75,32 @@ def validate(path, td):
76
  return False
77
  return True
78
 
79
-
80
  def mk(nodes, inits=None):
81
  x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
82
  y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
83
  g = helper.make_graph(nodes, "g", [x], [y], initializer=inits or [])
84
  return helper.make_model(g, ir_version=IR, opset_imports=OPSET)
85
 
86
-
87
  def get_exs(td):
88
  return [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
89
  for ex in td['train'] + td['test']]
90
 
91
-
92
  def fixed_shapes(td):
93
  shapes = set()
94
  for inp, out in get_exs(td):
95
  shapes.add((inp.shape, out.shape))
96
  return list(shapes)[0] if len(shapes) == 1 else None
97
 
98
-
99
  # ============================================================
100
  # SOLVERS
101
  # ============================================================
102
 
103
  def s_identity(td):
104
  for ex in td['train']+td['test']:
105
- if ex['input'] != ex['output']:
106
- return None
107
  return mk([helper.make_node('Identity', ['input'], ['output'])])
108
 
109
-
110
  def s_color_map(td):
111
- """1x1 conv implementing color permutation."""
112
  cm = {}
113
  for ex in td['train']+td['test']:
114
  inp, out = np.array(ex['input']), np.array(ex['output'])
@@ -123,17 +115,12 @@ def s_color_map(td):
123
  return mk([helper.make_node('Conv', ['input','W'], ['output'], kernel_shape=[1,1])],
124
  [numpy_helper.from_array(W, 'W')])
125
 
126
-
127
  def s_transpose(td):
128
- """Swap rows and columns."""
129
  for ex in td['train']+td['test']:
130
- if not np.array_equal(np.array(ex['output']), np.array(ex['input']).T):
131
- return None
132
  return mk([helper.make_node('Transpose', ['input'], ['output'], perm=[0,1,3,2])])
133
 
134
-
135
  def s_flip(td):
136
- """Flip vertically or horizontally using GatherElements."""
137
  exs = get_exs(td)
138
  sp = fixed_shapes(td)
139
  if sp is None: return None
@@ -143,29 +130,22 @@ def s_flip(td):
143
  if all(np.array_equal(out, flip_fn(inp)) for inp, out in exs):
144
  if axis == 0:
145
  idx = np.arange(GH).reshape(1,1,GH,1).repeat(CH,1).repeat(GW,3)
146
- for r in range(IH):
147
- idx[0,:,r,:] = IH - 1 - r
148
  else:
149
  idx = np.arange(GW).reshape(1,1,1,GW).repeat(CH,1).repeat(GH,2)
150
- for c in range(IW):
151
- idx[0,:,:,c] = IW - 1 - c
152
  ax = 2 if axis == 0 else 3
153
- return mk(
154
- [helper.make_node('GatherElements', ['input','idx'], ['output'], axis=ax)],
155
- [numpy_helper.from_array(idx.astype(np.int64), 'idx')]
156
- )
157
  return None
158
 
159
-
160
  def s_rotate(td):
161
- """Rotate 90/180/270 degrees."""
162
  exs = get_exs(td)
163
  sp = fixed_shapes(td)
164
  if sp is None: return None
165
  (IH,IW),(OH,OW) = sp
166
  for k in [1, 2, 3]:
167
- if not all(np.array_equal(out, np.rot90(inp, k)) for inp, out in exs):
168
- continue
169
  idx = np.zeros((OH,OW,2), dtype=np.int64)
170
  for r in range(OH):
171
  for c in range(OW):
@@ -176,9 +156,7 @@ def s_rotate(td):
176
  return _build_gather_model(OH, OW, idx)
177
  return None
178
 
179
-
180
  def s_spatial_gather(td):
181
- """Each output pixel copied from a fixed input pixel."""
182
  sp = fixed_shapes(td)
183
  if sp is None: return None
184
  (IH,IW),(OH,OW) = sp
@@ -188,23 +166,17 @@ def s_spatial_gather(td):
188
  for oi in range(OH):
189
  for oj in range(OW):
190
  vals = set(int(out[oi,oj]) for _,out in exs)
191
- if len(vals) == 1:
192
- cst[oi,oj] = vals.pop()
193
  found = False
194
  for ri in range(IH):
195
  for rj in range(IW):
196
  if all(int(inp[ri,rj]) == int(out[oi,oj]) for inp,out in exs):
197
- idx[oi,oj] = [ri, rj]
198
- found = True
199
- break
200
  if found: break
201
- if not found and cst[oi,oj] < 0:
202
- return None
203
  return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
204
 
205
-
206
  def s_tile(td):
207
- """Tile input NxM times."""
208
  exs = get_exs(td)
209
  in_shapes = set(inp.shape for inp,_ in exs)
210
  if len(in_shapes) != 1: return None
@@ -231,14 +203,11 @@ def s_tile(td):
231
  nodes = [
232
  helper.make_node('Slice', ['input','st','en'], ['cr']),
233
  helper.make_node('Tile', ['cr','rp'], ['tl']),
234
- helper.make_node('Pad', ['tl'], ['output'],
235
- pads=[0,0,0,0, 0,0,pad_h,pad_w], value=0.0),
236
  ]
237
  return mk(nodes, inits)
238
 
239
-
240
  def s_upscale(td):
241
- """Nearest-neighbor upscale by integer factor."""
242
  exs = get_exs(td)
243
  in_shapes = set(inp.shape for inp,_ in exs)
244
  if len(in_shapes) != 1: return None
@@ -255,27 +224,22 @@ def s_upscale(td):
255
  OH, OW = IH*sH, IW*sW
256
  if OH > 30 or OW > 30: return None
257
  for inp, out in exs:
258
- if not np.array_equal(out, np.repeat(np.repeat(inp, sH, 0), sW, 1)):
259
- return None
260
  idx = np.zeros((OH,OW,2), dtype=np.int64)
261
  for r in range(OH):
262
  for c in range(OW):
263
  idx[r,c] = [r//sH, c//sW]
264
  return _build_gather_model(OH, OW, idx)
265
 
266
-
267
  def s_concat(td):
268
- """Output = concat of transformed copies of input."""
269
  from itertools import product as iproduct
270
  exs = get_exs(td)
271
  sp = fixed_shapes(td)
272
  if sp is None: return None
273
  (IH,IW),(OH,OW) = sp
274
  transforms = [
275
- ('id', lambda x: x),
276
- ('fliplr', lambda x: np.fliplr(x)),
277
- ('flipud', lambda x: np.flipud(x)),
278
- ('rot180', lambda x: np.rot90(x, 2)),
279
  ]
280
  if OH == IH and OW % IW == 0 and OW > IW:
281
  n = OW // IW
@@ -286,8 +250,7 @@ def s_concat(td):
286
  idx = np.zeros((OH,OW,2), dtype=np.int64)
287
  for oi in range(OH):
288
  for oj in range(OW):
289
- bj = oj // IW
290
- lr, lc = oi, oj % IW
291
  t = transforms[combo[bj]][0]
292
  if t == 'id': sr, sc = lr, lc
293
  elif t == 'fliplr': sr, sc = lr, IW-1-lc
@@ -304,8 +267,7 @@ def s_concat(td):
304
  idx = np.zeros((OH,OW,2), dtype=np.int64)
305
  for oi in range(OH):
306
  for oj in range(OW):
307
- bi = oi // IH
308
- lr, lc = oi % IH, oj
309
  t = transforms[combo[bi]][0]
310
  if t == 'id': sr, sc = lr, lc
311
  elif t == 'fliplr': sr, sc = lr, IW-1-lc
@@ -315,9 +277,7 @@ def s_concat(td):
315
  return _build_gather_model(OH, OW, idx)
316
  return None
317
 
318
-
319
  def s_constant(td):
320
- """Output is always the same."""
321
  sp = fixed_shapes(td)
322
  if sp is None: return None
323
  exs = get_exs(td)
@@ -327,25 +287,18 @@ def s_constant(td):
327
  for r, row in enumerate(outs[0]):
328
  for c, v in enumerate(row):
329
  const[0, int(v), r, c] = 1.0
330
- inits = [
331
- numpy_helper.from_array(np.array(0.0, dtype=np.float32), 'z'),
332
- numpy_helper.from_array(const, 'c'),
333
- ]
334
- nodes = [
335
- helper.make_node('Mul', ['input','z'], ['zd']),
336
- helper.make_node('ReduceSum', ['zd'], ['s'], axes=[1,2,3], keepdims=1),
337
- helper.make_node('Add', ['s','c'], ['output']),
338
- ]
339
  return mk(nodes, inits)
340
 
341
-
342
  # ============================================================
343
- # CONV SOLVER (the main workhorse)
344
  # ============================================================
345
 
346
  def solve_conv(td, path, time_budget=30.0, try_bias=True):
347
- """Solve same-shape task with one-hot conv + ArgMax + OneHot.
348
- Returns model or None."""
349
  exs = get_exs(td)
350
  for inp, out in exs:
351
  if inp.shape != out.shape: return None
@@ -355,43 +308,36 @@ def solve_conv(td, path, time_budget=30.0, try_bias=True):
355
  t_start = time.time()
356
  for use_bias in ([False, True] if try_bias else [False]):
357
  for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
358
- if time.time() - t_start > time_budget:
359
- return None
360
  pad = ks // 2
361
  feat = 10 * ks * ks + (1 if use_bias else 0)
362
  n_grid = sum(inp.size for inp, _ in exs)
363
- if feat > 20000 or (feat > 5000 and n_grid > 2000):
364
- continue
365
  patches, targets = [], []
366
  for inp_g, out_g in exs:
367
  ih, iw = inp_g.shape
368
  oh_enc = np.zeros((10, ih, iw), dtype=np.float64)
369
- for c in range(10):
370
- oh_enc[c] = (inp_g == c)
371
  oh_pad = np.pad(oh_enc, ((0,0),(pad,pad),(pad,pad)))
372
  for r in range(ih):
373
  for c in range(iw):
374
  p = oh_pad[:, r:r+ks, c:c+ks].flatten()
375
- if use_bias:
376
- p = np.append(p, 1.0)
377
  patches.append(p)
378
  targets.append(int(out_g[r, c]))
379
  P = np.array(patches, dtype=np.float64)
380
  T = np.array(targets, dtype=np.int64)
381
  T_oh = np.zeros((len(T), 10), dtype=np.float64)
382
- for i, t in enumerate(T):
383
- T_oh[i, t] = 1.0
384
  WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
385
- if not np.array_equal(np.argmax(P @ WT, axis=1), T):
386
- continue
387
  if use_bias:
388
  Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
389
  B = WT[-1].astype(np.float32)
390
  else:
391
  Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
392
  B = None
393
- pad_h = GH - IH
394
- pad_w = GW - IW
395
  inits = [
396
  numpy_helper.from_array(np.array([0,0,0,0], dtype=np.int64), 'sl_st'),
397
  numpy_helper.from_array(np.array([1,10,IH,IW], dtype=np.int64), 'sl_en'),
@@ -405,32 +351,26 @@ def solve_conv(td, path, time_budget=30.0, try_bias=True):
405
  conv_inputs.append('B')
406
  nodes = [
407
  helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
408
- helper.make_node('Conv', conv_inputs, ['co'],
409
- kernel_shape=[ks,ks], pads=[pad]*4),
410
  helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=0),
411
  helper.make_node('OneHot', ['am','depth','ohvals'], ['oh_out'], axis=1),
412
- helper.make_node('Pad', ['oh_out'], ['output'],
413
- pads=[0,0,0,0, 0,0,pad_h,pad_w], value=0.0),
414
  ]
415
  model = mk(nodes, inits)
416
  onnx.save(model, path)
417
- if validate(path, td):
418
- return model
419
  return None
420
 
421
-
422
  # ============================================================
423
  # GATHER HELPERS
424
  # ============================================================
425
 
426
  def _build_gather_model(OH, OW, idx):
427
- """Build model from index array idx[OH,OW,2] -> (src_r, src_c)."""
428
  flat_idx = np.zeros((1,10,GH*GW), dtype=np.int64)
429
  mask = np.zeros((1,1,GH,GW), dtype=np.float32)
430
  for oi in range(OH):
431
  for oj in range(OW):
432
- flat = idx[oi,oj,0]*GW + idx[oi,oj,1]
433
- flat_idx[0,:,oi*GW+oj] = flat
434
  mask[0,0,oi,oj] = 1.0
435
  inits = [
436
  numpy_helper.from_array(np.array([1,10,GH*GW], dtype=np.int64), 'fs'),
@@ -446,17 +386,14 @@ def _build_gather_model(OH, OW, idx):
446
  ]
447
  return mk(nodes, inits)
448
 
449
-
450
  def _build_gather_model_with_const(IH, IW, OH, OW, idx, cst):
451
- """Build gather model with constant values for some positions."""
452
  flat_idx = np.zeros((1,10,GH*GW), dtype=np.int64)
453
  gather_mask = np.zeros((1,1,GH,GW), dtype=np.float32)
454
  const_oh = np.zeros((1,10,GH,GW), dtype=np.float32)
455
  for oi in range(OH):
456
  for oj in range(OW):
457
  if idx[oi,oj,0] >= 0:
458
- flat = idx[oi,oj,0]*GW + idx[oi,oj,1]
459
- flat_idx[0,:,oi*GW+oj] = flat
460
  gather_mask[0,0,oi,oj] = 1.0
461
  elif cst[oi,oj] >= 0:
462
  const_oh[0, cst[oi,oj], oi, oj] = 1.0
@@ -480,69 +417,56 @@ def _build_gather_model_with_const(IH, IW, OH, OW, idx, cst):
480
  nodes[-1] = helper.make_node('Mul', ['raw','gmask'], ['output'])
481
  return mk(nodes, inits)
482
 
483
-
484
  # ============================================================
485
- # MAIN SOLVER
486
  # ============================================================
487
 
488
  ANALYTICAL_SOLVERS = [
489
- ('identity', s_identity),
490
- ('constant', s_constant),
491
- ('color_map', s_color_map),
492
- ('transpose', s_transpose),
493
- ('flip', s_flip),
494
- ('rotate', s_rotate),
495
- ('tile', s_tile),
496
- ('upscale', s_upscale),
497
- ('concat', s_concat),
498
  ('spatial_gather', s_spatial_gather),
499
  ]
500
 
501
-
502
  def solve_task(tn, td, outdir, conv_budget=30.0):
503
- """Solve one task. Returns (solved, solver_name, file_size)."""
504
  os.makedirs(outdir, exist_ok=True)
505
  path = os.path.join(outdir, f"task{tn:03d}.onnx")
506
  for sname, sfn in ANALYTICAL_SOLVERS:
507
  try:
508
  model = sfn(td)
509
- if model is None:
510
- continue
511
  onnx.save(model, path)
512
- if validate(path, td):
513
- return True, sname, os.path.getsize(path)
514
- except:
515
- pass
516
  model = solve_conv(td, path, time_budget=conv_budget)
517
- if model is not None:
518
- return True, 'conv', os.path.getsize(path)
519
  return False, None, None
520
 
521
-
522
  def main():
523
  parser = argparse.ArgumentParser()
524
  parser.add_argument('--data_dir', default='ARC-AGI/data/training/')
525
  parser.add_argument('--output_dir', default='submission')
526
- parser.add_argument('--kaggle', action='store_true', help='Use Kaggle format')
527
- parser.add_argument('--conv_budget', type=float, default=30.0, help='Seconds per task for conv')
528
- parser.add_argument('--tasks', type=str, default='', help='Comma-separated task numbers to solve')
 
529
  args = parser.parse_args()
530
- if args.kaggle:
531
- tasks = load_tasks_kaggle(args.data_dir)
532
- else:
533
- tasks = load_tasks_dir(args.data_dir)
534
- if args.tasks:
535
- task_nums = [int(t) for t in args.tasks.split(',')]
536
- else:
537
- task_nums = sorted(tasks.keys())
 
538
  print(f"Loaded {len(tasks)} tasks, solving {len(task_nums)}")
539
  print(f"Conv budget: {args.conv_budget}s per task")
540
  print("=" * 70)
541
  t0 = time.time()
542
  results = {}
543
  for tn in task_nums:
544
- if tn not in tasks:
545
- continue
546
  td = tasks[tn]['data']
547
  ok, sname, sz = solve_task(tn, td, args.output_dir, args.conv_budget)
548
  if ok:
@@ -554,13 +478,11 @@ def main():
554
  print(f"\n{'='*70}")
555
  print(f"Solved: {len(results)}/{len(task_nums)} in {elapsed:.0f}s")
556
  sc = Counter(results.values())
557
- for s, c in sc.most_common():
558
- print(f" {s}: {c}")
559
  n_files = len([f for f in os.listdir(args.output_dir) if f.endswith('.onnx')])
560
  total_size = sum(os.path.getsize(os.path.join(args.output_dir, f))
561
  for f in os.listdir(args.output_dir) if f.endswith('.onnx'))
562
  print(f"\n{n_files} ONNX files, total {total_size/1024:.1f} KB")
563
 
564
-
565
  if __name__ == '__main__':
566
  main()
 
6
 
7
  Usage:
8
  python neurogolf_solver.py --data_dir ARC-AGI/data/training/ --output_dir submission
9
+ python neurogolf_solver.py --data_dir ARC-AGI/data/training/ --output_dir submission --device cuda --conv_budget 60
 
 
10
  """
11
 
12
  import json, os, sys, math, time, argparse
 
16
  import onnxruntime as ort
17
  from collections import Counter
18
 
 
19
  BATCH, CH, GH, GW = 1, 10, 30, 30
20
  GRID_SHAPE = [BATCH, CH, GH, GW]
21
  DT = TensorProto.FLOAT
22
  IR = 10
23
  OPSET = [helper.make_opsetid("", 10)]
24
 
25
+ def get_providers():
26
+ available = ort.get_available_providers()
27
+ if 'CUDAExecutionProvider' in available:
28
+ return ['CUDAExecutionProvider', 'CPUExecutionProvider']
29
+ return ['CPUExecutionProvider']
30
+
31
+ ORT_PROVIDERS = get_providers()
32
+ print(f"ONNX Runtime providers: {ORT_PROVIDERS}")
33
 
34
  def load_tasks_dir(data_dir):
 
35
  files = sorted(f for f in os.listdir(data_dir) if f.endswith('.json'))
36
  tasks = {}
37
  for i, f in enumerate(files):
 
39
  tasks[i+1] = {'hex': f.replace('.json',''), 'data': json.load(fh)}
40
  return tasks
41
 
 
42
  def load_tasks_kaggle(data_dir):
 
43
  tasks = {}
44
  for tn in range(1, 401):
45
  path = os.path.join(data_dir, f"task{tn:03d}.json")
 
48
  tasks[tn] = {'hex': f'task{tn:03d}', 'data': json.load(f)}
49
  return tasks
50
 
 
51
  def to_onehot(grid):
52
  arr = np.zeros((1, CH, GH, GW), dtype=np.float32)
53
  for r, row in enumerate(grid):
 
55
  arr[0, v, r, c] = 1.0
56
  return arr
57
 
 
58
  def validate(path, td):
 
59
  try:
60
+ sess = ort.InferenceSession(path, providers=ORT_PROVIDERS)
61
  except:
62
  return False
63
  examples = td['train'] + td['test']
 
75
  return False
76
  return True
77
 
 
78
  def mk(nodes, inits=None):
79
  x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
80
  y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
81
  g = helper.make_graph(nodes, "g", [x], [y], initializer=inits or [])
82
  return helper.make_model(g, ir_version=IR, opset_imports=OPSET)
83
 
 
84
  def get_exs(td):
85
  return [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
86
  for ex in td['train'] + td['test']]
87
 
 
88
  def fixed_shapes(td):
89
  shapes = set()
90
  for inp, out in get_exs(td):
91
  shapes.add((inp.shape, out.shape))
92
  return list(shapes)[0] if len(shapes) == 1 else None
93
 
 
94
  # ============================================================
95
  # SOLVERS
96
  # ============================================================
97
 
98
  def s_identity(td):
99
  for ex in td['train']+td['test']:
100
+ if ex['input'] != ex['output']: return None
 
101
  return mk([helper.make_node('Identity', ['input'], ['output'])])
102
 
 
103
  def s_color_map(td):
 
104
  cm = {}
105
  for ex in td['train']+td['test']:
106
  inp, out = np.array(ex['input']), np.array(ex['output'])
 
115
  return mk([helper.make_node('Conv', ['input','W'], ['output'], kernel_shape=[1,1])],
116
  [numpy_helper.from_array(W, 'W')])
117
 
 
118
  def s_transpose(td):
 
119
  for ex in td['train']+td['test']:
120
+ if not np.array_equal(np.array(ex['output']), np.array(ex['input']).T): return None
 
121
  return mk([helper.make_node('Transpose', ['input'], ['output'], perm=[0,1,3,2])])
122
 
 
123
  def s_flip(td):
 
124
  exs = get_exs(td)
125
  sp = fixed_shapes(td)
126
  if sp is None: return None
 
130
  if all(np.array_equal(out, flip_fn(inp)) for inp, out in exs):
131
  if axis == 0:
132
  idx = np.arange(GH).reshape(1,1,GH,1).repeat(CH,1).repeat(GW,3)
133
+ for r in range(IH): idx[0,:,r,:] = IH - 1 - r
 
134
  else:
135
  idx = np.arange(GW).reshape(1,1,1,GW).repeat(CH,1).repeat(GH,2)
136
+ for c in range(IW): idx[0,:,:,c] = IW - 1 - c
 
137
  ax = 2 if axis == 0 else 3
138
+ return mk([helper.make_node('GatherElements', ['input','idx'], ['output'], axis=ax)],
139
+ [numpy_helper.from_array(idx.astype(np.int64), 'idx')])
 
 
140
  return None
141
 
 
142
  def s_rotate(td):
 
143
  exs = get_exs(td)
144
  sp = fixed_shapes(td)
145
  if sp is None: return None
146
  (IH,IW),(OH,OW) = sp
147
  for k in [1, 2, 3]:
148
+ if not all(np.array_equal(out, np.rot90(inp, k)) for inp, out in exs): continue
 
149
  idx = np.zeros((OH,OW,2), dtype=np.int64)
150
  for r in range(OH):
151
  for c in range(OW):
 
156
  return _build_gather_model(OH, OW, idx)
157
  return None
158
 
 
159
  def s_spatial_gather(td):
 
160
  sp = fixed_shapes(td)
161
  if sp is None: return None
162
  (IH,IW),(OH,OW) = sp
 
166
  for oi in range(OH):
167
  for oj in range(OW):
168
  vals = set(int(out[oi,oj]) for _,out in exs)
169
+ if len(vals) == 1: cst[oi,oj] = vals.pop()
 
170
  found = False
171
  for ri in range(IH):
172
  for rj in range(IW):
173
  if all(int(inp[ri,rj]) == int(out[oi,oj]) for inp,out in exs):
174
+ idx[oi,oj] = [ri, rj]; found = True; break
 
 
175
  if found: break
176
+ if not found and cst[oi,oj] < 0: return None
 
177
  return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
178
 
 
179
  def s_tile(td):
 
180
  exs = get_exs(td)
181
  in_shapes = set(inp.shape for inp,_ in exs)
182
  if len(in_shapes) != 1: return None
 
203
  nodes = [
204
  helper.make_node('Slice', ['input','st','en'], ['cr']),
205
  helper.make_node('Tile', ['cr','rp'], ['tl']),
206
+ helper.make_node('Pad', ['tl'], ['output'], pads=[0,0,0,0,0,0,pad_h,pad_w], value=0.0),
 
207
  ]
208
  return mk(nodes, inits)
209
 
 
210
  def s_upscale(td):
 
211
  exs = get_exs(td)
212
  in_shapes = set(inp.shape for inp,_ in exs)
213
  if len(in_shapes) != 1: return None
 
224
  OH, OW = IH*sH, IW*sW
225
  if OH > 30 or OW > 30: return None
226
  for inp, out in exs:
227
+ if not np.array_equal(out, np.repeat(np.repeat(inp, sH, 0), sW, 1)): return None
 
228
  idx = np.zeros((OH,OW,2), dtype=np.int64)
229
  for r in range(OH):
230
  for c in range(OW):
231
  idx[r,c] = [r//sH, c//sW]
232
  return _build_gather_model(OH, OW, idx)
233
 
 
234
  def s_concat(td):
 
235
  from itertools import product as iproduct
236
  exs = get_exs(td)
237
  sp = fixed_shapes(td)
238
  if sp is None: return None
239
  (IH,IW),(OH,OW) = sp
240
  transforms = [
241
+ ('id', lambda x: x), ('fliplr', lambda x: np.fliplr(x)),
242
+ ('flipud', lambda x: np.flipud(x)), ('rot180', lambda x: np.rot90(x, 2)),
 
 
243
  ]
244
  if OH == IH and OW % IW == 0 and OW > IW:
245
  n = OW // IW
 
250
  idx = np.zeros((OH,OW,2), dtype=np.int64)
251
  for oi in range(OH):
252
  for oj in range(OW):
253
+ bj = oj // IW; lr, lc = oi, oj % IW
 
254
  t = transforms[combo[bj]][0]
255
  if t == 'id': sr, sc = lr, lc
256
  elif t == 'fliplr': sr, sc = lr, IW-1-lc
 
267
  idx = np.zeros((OH,OW,2), dtype=np.int64)
268
  for oi in range(OH):
269
  for oj in range(OW):
270
+ bi = oi // IH; lr, lc = oi % IH, oj
 
271
  t = transforms[combo[bi]][0]
272
  if t == 'id': sr, sc = lr, lc
273
  elif t == 'fliplr': sr, sc = lr, IW-1-lc
 
277
  return _build_gather_model(OH, OW, idx)
278
  return None
279
 
 
280
  def s_constant(td):
 
281
  sp = fixed_shapes(td)
282
  if sp is None: return None
283
  exs = get_exs(td)
 
287
  for r, row in enumerate(outs[0]):
288
  for c, v in enumerate(row):
289
  const[0, int(v), r, c] = 1.0
290
+ inits = [numpy_helper.from_array(np.array(0.0, dtype=np.float32), 'z'),
291
+ numpy_helper.from_array(const, 'c')]
292
+ nodes = [helper.make_node('Mul', ['input','z'], ['zd']),
293
+ helper.make_node('ReduceSum', ['zd'], ['s'], axes=[1,2,3], keepdims=1),
294
+ helper.make_node('Add', ['s','c'], ['output'])]
 
 
 
 
295
  return mk(nodes, inits)
296
 
 
297
  # ============================================================
298
+ # CONV SOLVER
299
  # ============================================================
300
 
301
  def solve_conv(td, path, time_budget=30.0, try_bias=True):
 
 
302
  exs = get_exs(td)
303
  for inp, out in exs:
304
  if inp.shape != out.shape: return None
 
308
  t_start = time.time()
309
  for use_bias in ([False, True] if try_bias else [False]):
310
  for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
311
+ if time.time() - t_start > time_budget: return None
 
312
  pad = ks // 2
313
  feat = 10 * ks * ks + (1 if use_bias else 0)
314
  n_grid = sum(inp.size for inp, _ in exs)
315
+ if feat > 20000 or (feat > 5000 and n_grid > 2000): continue
 
316
  patches, targets = [], []
317
  for inp_g, out_g in exs:
318
  ih, iw = inp_g.shape
319
  oh_enc = np.zeros((10, ih, iw), dtype=np.float64)
320
+ for c in range(10): oh_enc[c] = (inp_g == c)
 
321
  oh_pad = np.pad(oh_enc, ((0,0),(pad,pad),(pad,pad)))
322
  for r in range(ih):
323
  for c in range(iw):
324
  p = oh_pad[:, r:r+ks, c:c+ks].flatten()
325
+ if use_bias: p = np.append(p, 1.0)
 
326
  patches.append(p)
327
  targets.append(int(out_g[r, c]))
328
  P = np.array(patches, dtype=np.float64)
329
  T = np.array(targets, dtype=np.int64)
330
  T_oh = np.zeros((len(T), 10), dtype=np.float64)
331
+ for i, t in enumerate(T): T_oh[i, t] = 1.0
 
332
  WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
333
+ if not np.array_equal(np.argmax(P @ WT, axis=1), T): continue
 
334
  if use_bias:
335
  Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
336
  B = WT[-1].astype(np.float32)
337
  else:
338
  Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
339
  B = None
340
+ pad_h, pad_w = GH - IH, GW - IW
 
341
  inits = [
342
  numpy_helper.from_array(np.array([0,0,0,0], dtype=np.int64), 'sl_st'),
343
  numpy_helper.from_array(np.array([1,10,IH,IW], dtype=np.int64), 'sl_en'),
 
351
  conv_inputs.append('B')
352
  nodes = [
353
  helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
354
+ helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
 
355
  helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=0),
356
  helper.make_node('OneHot', ['am','depth','ohvals'], ['oh_out'], axis=1),
357
+ helper.make_node('Pad', ['oh_out'], ['output'], pads=[0,0,0,0,0,0,pad_h,pad_w], value=0.0),
 
358
  ]
359
  model = mk(nodes, inits)
360
  onnx.save(model, path)
361
+ if validate(path, td): return model
 
362
  return None
363
 
 
364
  # ============================================================
365
  # GATHER HELPERS
366
  # ============================================================
367
 
368
  def _build_gather_model(OH, OW, idx):
 
369
  flat_idx = np.zeros((1,10,GH*GW), dtype=np.int64)
370
  mask = np.zeros((1,1,GH,GW), dtype=np.float32)
371
  for oi in range(OH):
372
  for oj in range(OW):
373
+ flat_idx[0,:,oi*GW+oj] = idx[oi,oj,0]*GW + idx[oi,oj,1]
 
374
  mask[0,0,oi,oj] = 1.0
375
  inits = [
376
  numpy_helper.from_array(np.array([1,10,GH*GW], dtype=np.int64), 'fs'),
 
386
  ]
387
  return mk(nodes, inits)
388
 
 
389
  def _build_gather_model_with_const(IH, IW, OH, OW, idx, cst):
 
390
  flat_idx = np.zeros((1,10,GH*GW), dtype=np.int64)
391
  gather_mask = np.zeros((1,1,GH,GW), dtype=np.float32)
392
  const_oh = np.zeros((1,10,GH,GW), dtype=np.float32)
393
  for oi in range(OH):
394
  for oj in range(OW):
395
  if idx[oi,oj,0] >= 0:
396
+ flat_idx[0,:,oi*GW+oj] = idx[oi,oj,0]*GW + idx[oi,oj,1]
 
397
  gather_mask[0,0,oi,oj] = 1.0
398
  elif cst[oi,oj] >= 0:
399
  const_oh[0, cst[oi,oj], oi, oj] = 1.0
 
417
  nodes[-1] = helper.make_node('Mul', ['raw','gmask'], ['output'])
418
  return mk(nodes, inits)
419
 
 
420
  # ============================================================
421
+ # MAIN
422
  # ============================================================
423
 
424
  ANALYTICAL_SOLVERS = [
425
+ ('identity', s_identity), ('constant', s_constant), ('color_map', s_color_map),
426
+ ('transpose', s_transpose), ('flip', s_flip), ('rotate', s_rotate),
427
+ ('tile', s_tile), ('upscale', s_upscale), ('concat', s_concat),
 
 
 
 
 
 
428
  ('spatial_gather', s_spatial_gather),
429
  ]
430
 
 
431
  def solve_task(tn, td, outdir, conv_budget=30.0):
 
432
  os.makedirs(outdir, exist_ok=True)
433
  path = os.path.join(outdir, f"task{tn:03d}.onnx")
434
  for sname, sfn in ANALYTICAL_SOLVERS:
435
  try:
436
  model = sfn(td)
437
+ if model is None: continue
 
438
  onnx.save(model, path)
439
+ if validate(path, td): return True, sname, os.path.getsize(path)
440
+ except: pass
 
 
441
  model = solve_conv(td, path, time_budget=conv_budget)
442
+ if model is not None: return True, 'conv', os.path.getsize(path)
 
443
  return False, None, None
444
 
 
445
  def main():
446
  parser = argparse.ArgumentParser()
447
  parser.add_argument('--data_dir', default='ARC-AGI/data/training/')
448
  parser.add_argument('--output_dir', default='submission')
449
+ parser.add_argument('--kaggle', action='store_true')
450
+ parser.add_argument('--conv_budget', type=float, default=30.0)
451
+ parser.add_argument('--tasks', type=str, default='')
452
+ parser.add_argument('--device', type=str, default='auto', choices=['auto','cpu','cuda'])
453
  args = parser.parse_args()
454
+ global ORT_PROVIDERS
455
+ if args.device == 'cuda':
456
+ ORT_PROVIDERS = ['CUDAExecutionProvider', 'CPUExecutionProvider']
457
+ elif args.device == 'cpu':
458
+ ORT_PROVIDERS = ['CPUExecutionProvider']
459
+ print(f"Using providers: {ORT_PROVIDERS}")
460
+ if args.kaggle: tasks = load_tasks_kaggle(args.data_dir)
461
+ else: tasks = load_tasks_dir(args.data_dir)
462
+ task_nums = [int(t) for t in args.tasks.split(',')] if args.tasks else sorted(tasks.keys())
463
  print(f"Loaded {len(tasks)} tasks, solving {len(task_nums)}")
464
  print(f"Conv budget: {args.conv_budget}s per task")
465
  print("=" * 70)
466
  t0 = time.time()
467
  results = {}
468
  for tn in task_nums:
469
+ if tn not in tasks: continue
 
470
  td = tasks[tn]['data']
471
  ok, sname, sz = solve_task(tn, td, args.output_dir, args.conv_budget)
472
  if ok:
 
478
  print(f"\n{'='*70}")
479
  print(f"Solved: {len(results)}/{len(task_nums)} in {elapsed:.0f}s")
480
  sc = Counter(results.values())
481
+ for s, c in sc.most_common(): print(f" {s}: {c}")
 
482
  n_files = len([f for f in os.listdir(args.output_dir) if f.endswith('.onnx')])
483
  total_size = sum(os.path.getsize(os.path.join(args.output_dir, f))
484
  for f in os.listdir(args.output_dir) if f.endswith('.onnx'))
485
  print(f"\n{n_files} ONNX files, total {total_size/1024:.1f} KB")
486
 
 
487
  if __name__ == '__main__':
488
  main()