rogermt commited on
Commit
3bbfe2f
·
verified ·
1 Parent(s): 4fd4c3e

Upload complete neurogolf solver v1

Browse files
Files changed (1) hide show
  1. neurogolf_solver.py +565 -2
neurogolf_solver.py CHANGED
@@ -1,3 +1,566 @@
1
  #!/usr/bin/env python3
2
- # See full file at: https://huggingface.co/rogermt/neurogolf-solver
3
- # Upload the actual file content from /app/neurogolf_solver.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  #!/usr/bin/env python3
2
+ """
3
+ ARC-AGI NeuroGolf Championship - Complete Solver
4
+ Format: [1,10,30,30] one-hot input/output, opset 10, IR version 10.
5
+ Pipeline: Slice -> Conv -> ArgMax -> OneHot -> Pad
6
+
7
+ Usage:
8
+ python neurogolf_solver.py --data_dir ARC-AGI/data/training/ --output_dir submission
9
+
10
+ On Kaggle:
11
+ python neurogolf_solver.py --data_dir /kaggle/input/competitions/neurogolf-2026/ --output_dir submission --kaggle
12
+ """
13
+
14
+ import json, os, sys, math, time, argparse
15
+ import numpy as np
16
+ import onnx
17
+ from onnx import helper, TensorProto, numpy_helper
18
+ import onnxruntime as ort
19
+ from collections import Counter
20
+
21
+ # Constants
22
+ BATCH, CH, GH, GW = 1, 10, 30, 30
23
+ GRID_SHAPE = [BATCH, CH, GH, GW]
24
+ DT = TensorProto.FLOAT
25
+ IR = 10
26
+ OPSET = [helper.make_opsetid("", 10)]
27
+
28
+
29
+ def load_tasks_dir(data_dir):
30
+ """Load tasks from directory of JSON files."""
31
+ files = sorted(f for f in os.listdir(data_dir) if f.endswith('.json'))
32
+ tasks = {}
33
+ for i, f in enumerate(files):
34
+ with open(os.path.join(data_dir, f)) as fh:
35
+ tasks[i+1] = {'hex': f.replace('.json',''), 'data': json.load(fh)}
36
+ return tasks
37
+
38
+
39
+ def load_tasks_kaggle(data_dir):
40
+ """Load tasks from Kaggle competition format (task001.json etc.)."""
41
+ tasks = {}
42
+ for tn in range(1, 401):
43
+ path = os.path.join(data_dir, f"task{tn:03d}.json")
44
+ if os.path.exists(path):
45
+ with open(path) as f:
46
+ tasks[tn] = {'hex': f'task{tn:03d}', 'data': json.load(f)}
47
+ return tasks
48
+
49
+
50
+ def to_onehot(grid):
51
+ arr = np.zeros((1, CH, GH, GW), dtype=np.float32)
52
+ for r, row in enumerate(grid):
53
+ for c, v in enumerate(row):
54
+ arr[0, v, r, c] = 1.0
55
+ return arr
56
+
57
+
58
+ def validate(path, td):
59
+ """Validate ONNX model against all train+test examples."""
60
+ try:
61
+ sess = ort.InferenceSession(path, providers=['CPUExecutionProvider'])
62
+ except:
63
+ return False
64
+ examples = td['train'] + td['test']
65
+ if 'arc-gen' in td:
66
+ examples = examples + td['arc-gen']
67
+ for ex in examples:
68
+ inp = to_onehot(ex['input'])
69
+ exp = to_onehot(ex['output'])
70
+ try:
71
+ out = sess.run(['output'], {'input': inp})[0]
72
+ out = (out > 0.0).astype(np.float32)
73
+ except:
74
+ return False
75
+ if not np.array_equal(out, exp):
76
+ return False
77
+ return True
78
+
79
+
80
+ def mk(nodes, inits=None):
81
+ x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
82
+ y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
83
+ g = helper.make_graph(nodes, "g", [x], [y], initializer=inits or [])
84
+ return helper.make_model(g, ir_version=IR, opset_imports=OPSET)
85
+
86
+
87
+ def get_exs(td):
88
+ return [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
89
+ for ex in td['train'] + td['test']]
90
+
91
+
92
+ def fixed_shapes(td):
93
+ shapes = set()
94
+ for inp, out in get_exs(td):
95
+ shapes.add((inp.shape, out.shape))
96
+ return list(shapes)[0] if len(shapes) == 1 else None
97
+
98
+
99
+ # ============================================================
100
+ # SOLVERS
101
+ # ============================================================
102
+
103
+ def s_identity(td):
104
+ for ex in td['train']+td['test']:
105
+ if ex['input'] != ex['output']:
106
+ return None
107
+ return mk([helper.make_node('Identity', ['input'], ['output'])])
108
+
109
+
110
+ def s_color_map(td):
111
+ """1x1 conv implementing color permutation."""
112
+ cm = {}
113
+ for ex in td['train']+td['test']:
114
+ inp, out = np.array(ex['input']), np.array(ex['output'])
115
+ if inp.shape != out.shape: return None
116
+ for iv, ov in zip(inp.flat, out.flat):
117
+ iv, ov = int(iv), int(ov)
118
+ if iv in cm and cm[iv] != ov: return None
119
+ cm[iv] = ov
120
+ W = np.zeros((10,10,1,1), dtype=np.float32)
121
+ for ic in range(10):
122
+ W[cm.get(ic,ic), ic, 0, 0] = 1.0
123
+ return mk([helper.make_node('Conv', ['input','W'], ['output'], kernel_shape=[1,1])],
124
+ [numpy_helper.from_array(W, 'W')])
125
+
126
+
127
+ def s_transpose(td):
128
+ """Swap rows and columns."""
129
+ for ex in td['train']+td['test']:
130
+ if not np.array_equal(np.array(ex['output']), np.array(ex['input']).T):
131
+ return None
132
+ return mk([helper.make_node('Transpose', ['input'], ['output'], perm=[0,1,3,2])])
133
+
134
+
135
+ def s_flip(td):
136
+ """Flip vertically or horizontally using GatherElements."""
137
+ exs = get_exs(td)
138
+ sp = fixed_shapes(td)
139
+ if sp is None: return None
140
+ (IH,IW),(OH,OW) = sp
141
+ if (IH,IW) != (OH,OW): return None
142
+ for axis, flip_fn in [(0, np.flipud), (1, np.fliplr)]:
143
+ if all(np.array_equal(out, flip_fn(inp)) for inp, out in exs):
144
+ if axis == 0:
145
+ idx = np.arange(GH).reshape(1,1,GH,1).repeat(CH,1).repeat(GW,3)
146
+ for r in range(IH):
147
+ idx[0,:,r,:] = IH - 1 - r
148
+ else:
149
+ idx = np.arange(GW).reshape(1,1,1,GW).repeat(CH,1).repeat(GH,2)
150
+ for c in range(IW):
151
+ idx[0,:,:,c] = IW - 1 - c
152
+ ax = 2 if axis == 0 else 3
153
+ return mk(
154
+ [helper.make_node('GatherElements', ['input','idx'], ['output'], axis=ax)],
155
+ [numpy_helper.from_array(idx.astype(np.int64), 'idx')]
156
+ )
157
+ return None
158
+
159
+
160
+ def s_rotate(td):
161
+ """Rotate 90/180/270 degrees."""
162
+ exs = get_exs(td)
163
+ sp = fixed_shapes(td)
164
+ if sp is None: return None
165
+ (IH,IW),(OH,OW) = sp
166
+ for k in [1, 2, 3]:
167
+ if not all(np.array_equal(out, np.rot90(inp, k)) for inp, out in exs):
168
+ continue
169
+ idx = np.zeros((OH,OW,2), dtype=np.int64)
170
+ for r in range(OH):
171
+ for c in range(OW):
172
+ if k == 1: sr, sc = c, IH-1-r
173
+ elif k == 2: sr, sc = IH-1-r, IW-1-c
174
+ elif k == 3: sr, sc = IW-1-c, r
175
+ idx[r,c] = [sr, sc]
176
+ return _build_gather_model(OH, OW, idx)
177
+ return None
178
+
179
+
180
+ def s_spatial_gather(td):
181
+ """Each output pixel copied from a fixed input pixel."""
182
+ sp = fixed_shapes(td)
183
+ if sp is None: return None
184
+ (IH,IW),(OH,OW) = sp
185
+ exs = get_exs(td)
186
+ idx = np.full((OH,OW,2), -1, dtype=np.int64)
187
+ cst = np.full((OH,OW), -1, dtype=np.int64)
188
+ for oi in range(OH):
189
+ for oj in range(OW):
190
+ vals = set(int(out[oi,oj]) for _,out in exs)
191
+ if len(vals) == 1:
192
+ cst[oi,oj] = vals.pop()
193
+ found = False
194
+ for ri in range(IH):
195
+ for rj in range(IW):
196
+ if all(int(inp[ri,rj]) == int(out[oi,oj]) for inp,out in exs):
197
+ idx[oi,oj] = [ri, rj]
198
+ found = True
199
+ break
200
+ if found: break
201
+ if not found and cst[oi,oj] < 0:
202
+ return None
203
+ return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
204
+
205
+
206
+ def s_tile(td):
207
+ """Tile input NxM times."""
208
+ exs = get_exs(td)
209
+ in_shapes = set(inp.shape for inp,_ in exs)
210
+ if len(in_shapes) != 1: return None
211
+ IH, IW = in_shapes.pop()
212
+ tiles = set()
213
+ for inp, out in exs:
214
+ OH, OW = out.shape
215
+ if OH % IH or OW % IW: return None
216
+ rH, rW = OH//IH, OW//IW
217
+ if rH < 1 or rW < 1 or (rH==1 and rW==1): return None
218
+ tiles.add((rH, rW))
219
+ if len(tiles) != 1: return None
220
+ rH, rW = tiles.pop()
221
+ OH, OW = IH*rH, IW*rW
222
+ if OH > 30 or OW > 30: return None
223
+ for inp, out in exs:
224
+ if not np.array_equal(out, np.tile(inp, (rH, rW))): return None
225
+ pad_h, pad_w = 30-OH, 30-OW
226
+ inits = [
227
+ numpy_helper.from_array(np.array([0,0,0,0], dtype=np.int64), 'st'),
228
+ numpy_helper.from_array(np.array([1,10,IH,IW], dtype=np.int64), 'en'),
229
+ numpy_helper.from_array(np.array([1,1,rH,rW], dtype=np.int64), 'rp'),
230
+ ]
231
+ nodes = [
232
+ helper.make_node('Slice', ['input','st','en'], ['cr']),
233
+ helper.make_node('Tile', ['cr','rp'], ['tl']),
234
+ helper.make_node('Pad', ['tl'], ['output'],
235
+ pads=[0,0,0,0, 0,0,pad_h,pad_w], value=0.0),
236
+ ]
237
+ return mk(nodes, inits)
238
+
239
+
240
+ def s_upscale(td):
241
+ """Nearest-neighbor upscale by integer factor."""
242
+ exs = get_exs(td)
243
+ in_shapes = set(inp.shape for inp,_ in exs)
244
+ if len(in_shapes) != 1: return None
245
+ IH, IW = in_shapes.pop()
246
+ scales = set()
247
+ for inp, out in exs:
248
+ OH, OW = out.shape
249
+ if OH % IH or OW % IW: return None
250
+ sH, sW = OH//IH, OW//IW
251
+ if sH < 2 or sW < 2: return None
252
+ scales.add((sH, sW))
253
+ if len(scales) != 1: return None
254
+ sH, sW = scales.pop()
255
+ OH, OW = IH*sH, IW*sW
256
+ if OH > 30 or OW > 30: return None
257
+ for inp, out in exs:
258
+ if not np.array_equal(out, np.repeat(np.repeat(inp, sH, 0), sW, 1)):
259
+ return None
260
+ idx = np.zeros((OH,OW,2), dtype=np.int64)
261
+ for r in range(OH):
262
+ for c in range(OW):
263
+ idx[r,c] = [r//sH, c//sW]
264
+ return _build_gather_model(OH, OW, idx)
265
+
266
+
267
+ def s_concat(td):
268
+ """Output = concat of transformed copies of input."""
269
+ from itertools import product as iproduct
270
+ exs = get_exs(td)
271
+ sp = fixed_shapes(td)
272
+ if sp is None: return None
273
+ (IH,IW),(OH,OW) = sp
274
+ transforms = [
275
+ ('id', lambda x: x),
276
+ ('fliplr', lambda x: np.fliplr(x)),
277
+ ('flipud', lambda x: np.flipud(x)),
278
+ ('rot180', lambda x: np.rot90(x, 2)),
279
+ ]
280
+ if OH == IH and OW % IW == 0 and OW > IW:
281
+ n = OW // IW
282
+ if 2 <= n <= 4:
283
+ for combo in iproduct(range(4), repeat=n):
284
+ if all(np.array_equal(out, np.concatenate([transforms[t][1](inp) for t in combo], axis=1))
285
+ for inp, out in exs):
286
+ idx = np.zeros((OH,OW,2), dtype=np.int64)
287
+ for oi in range(OH):
288
+ for oj in range(OW):
289
+ bj = oj // IW
290
+ lr, lc = oi, oj % IW
291
+ t = transforms[combo[bj]][0]
292
+ if t == 'id': sr, sc = lr, lc
293
+ elif t == 'fliplr': sr, sc = lr, IW-1-lc
294
+ elif t == 'flipud': sr, sc = IH-1-lr, lc
295
+ elif t == 'rot180': sr, sc = IH-1-lr, IW-1-lc
296
+ idx[oi,oj] = [sr, sc]
297
+ return _build_gather_model(OH, OW, idx)
298
+ if OW == IW and OH % IH == 0 and OH > IH:
299
+ n = OH // IH
300
+ if 2 <= n <= 4:
301
+ for combo in iproduct(range(4), repeat=n):
302
+ if all(np.array_equal(out, np.concatenate([transforms[t][1](inp) for t in combo], axis=0))
303
+ for inp, out in exs):
304
+ idx = np.zeros((OH,OW,2), dtype=np.int64)
305
+ for oi in range(OH):
306
+ for oj in range(OW):
307
+ bi = oi // IH
308
+ lr, lc = oi % IH, oj
309
+ t = transforms[combo[bi]][0]
310
+ if t == 'id': sr, sc = lr, lc
311
+ elif t == 'fliplr': sr, sc = lr, IW-1-lc
312
+ elif t == 'flipud': sr, sc = IH-1-lr, lc
313
+ elif t == 'rot180': sr, sc = IH-1-lr, IW-1-lc
314
+ idx[oi,oj] = [sr, sc]
315
+ return _build_gather_model(OH, OW, idx)
316
+ return None
317
+
318
+
319
+ def s_constant(td):
320
+ """Output is always the same."""
321
+ sp = fixed_shapes(td)
322
+ if sp is None: return None
323
+ exs = get_exs(td)
324
+ outs = [out for _,out in exs]
325
+ if not all(np.array_equal(outs[0], o) for o in outs[1:]): return None
326
+ const = np.zeros((1,10,30,30), dtype=np.float32)
327
+ for r, row in enumerate(outs[0]):
328
+ for c, v in enumerate(row):
329
+ const[0, int(v), r, c] = 1.0
330
+ inits = [
331
+ numpy_helper.from_array(np.array(0.0, dtype=np.float32), 'z'),
332
+ numpy_helper.from_array(const, 'c'),
333
+ ]
334
+ nodes = [
335
+ helper.make_node('Mul', ['input','z'], ['zd']),
336
+ helper.make_node('ReduceSum', ['zd'], ['s'], axes=[1,2,3], keepdims=1),
337
+ helper.make_node('Add', ['s','c'], ['output']),
338
+ ]
339
+ return mk(nodes, inits)
340
+
341
+
342
+ # ============================================================
343
+ # CONV SOLVER (the main workhorse)
344
+ # ============================================================
345
+
346
+ def solve_conv(td, path, time_budget=30.0, try_bias=True):
347
+ """Solve same-shape task with one-hot conv + ArgMax + OneHot.
348
+ Returns model or None."""
349
+ exs = get_exs(td)
350
+ for inp, out in exs:
351
+ if inp.shape != out.shape: return None
352
+ shapes = set(inp.shape for inp, _ in exs)
353
+ if len(shapes) != 1: return None
354
+ IH, IW = shapes.pop()
355
+ t_start = time.time()
356
+ for use_bias in ([False, True] if try_bias else [False]):
357
+ for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
358
+ if time.time() - t_start > time_budget:
359
+ return None
360
+ pad = ks // 2
361
+ feat = 10 * ks * ks + (1 if use_bias else 0)
362
+ n_grid = sum(inp.size for inp, _ in exs)
363
+ if feat > 20000 or (feat > 5000 and n_grid > 2000):
364
+ continue
365
+ patches, targets = [], []
366
+ for inp_g, out_g in exs:
367
+ ih, iw = inp_g.shape
368
+ oh_enc = np.zeros((10, ih, iw), dtype=np.float64)
369
+ for c in range(10):
370
+ oh_enc[c] = (inp_g == c)
371
+ oh_pad = np.pad(oh_enc, ((0,0),(pad,pad),(pad,pad)))
372
+ for r in range(ih):
373
+ for c in range(iw):
374
+ p = oh_pad[:, r:r+ks, c:c+ks].flatten()
375
+ if use_bias:
376
+ p = np.append(p, 1.0)
377
+ patches.append(p)
378
+ targets.append(int(out_g[r, c]))
379
+ P = np.array(patches, dtype=np.float64)
380
+ T = np.array(targets, dtype=np.int64)
381
+ T_oh = np.zeros((len(T), 10), dtype=np.float64)
382
+ for i, t in enumerate(T):
383
+ T_oh[i, t] = 1.0
384
+ WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
385
+ if not np.array_equal(np.argmax(P @ WT, axis=1), T):
386
+ continue
387
+ if use_bias:
388
+ Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
389
+ B = WT[-1].astype(np.float32)
390
+ else:
391
+ Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
392
+ B = None
393
+ pad_h = GH - IH
394
+ pad_w = GW - IW
395
+ inits = [
396
+ numpy_helper.from_array(np.array([0,0,0,0], dtype=np.int64), 'sl_st'),
397
+ numpy_helper.from_array(np.array([1,10,IH,IW], dtype=np.int64), 'sl_en'),
398
+ numpy_helper.from_array(Wconv, 'W'),
399
+ numpy_helper.from_array(np.array(10, dtype=np.int64), 'depth'),
400
+ numpy_helper.from_array(np.array([0.0, 1.0], dtype=np.float32), 'ohvals'),
401
+ ]
402
+ conv_inputs = ['grid', 'W']
403
+ if B is not None:
404
+ inits.append(numpy_helper.from_array(B, 'B'))
405
+ conv_inputs.append('B')
406
+ nodes = [
407
+ helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
408
+ helper.make_node('Conv', conv_inputs, ['co'],
409
+ kernel_shape=[ks,ks], pads=[pad]*4),
410
+ helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=0),
411
+ helper.make_node('OneHot', ['am','depth','ohvals'], ['oh_out'], axis=1),
412
+ helper.make_node('Pad', ['oh_out'], ['output'],
413
+ pads=[0,0,0,0, 0,0,pad_h,pad_w], value=0.0),
414
+ ]
415
+ model = mk(nodes, inits)
416
+ onnx.save(model, path)
417
+ if validate(path, td):
418
+ return model
419
+ return None
420
+
421
+
422
+ # ============================================================
423
+ # GATHER HELPERS
424
+ # ============================================================
425
+
426
+ def _build_gather_model(OH, OW, idx):
427
+ """Build model from index array idx[OH,OW,2] -> (src_r, src_c)."""
428
+ flat_idx = np.zeros((1,10,GH*GW), dtype=np.int64)
429
+ mask = np.zeros((1,1,GH,GW), dtype=np.float32)
430
+ for oi in range(OH):
431
+ for oj in range(OW):
432
+ flat = idx[oi,oj,0]*GW + idx[oi,oj,1]
433
+ flat_idx[0,:,oi*GW+oj] = flat
434
+ mask[0,0,oi,oj] = 1.0
435
+ inits = [
436
+ numpy_helper.from_array(np.array([1,10,GH*GW], dtype=np.int64), 'fs'),
437
+ numpy_helper.from_array(flat_idx, 'idx'),
438
+ numpy_helper.from_array(np.array([1,10,GH,GW], dtype=np.int64), 'os'),
439
+ numpy_helper.from_array(mask, 'mask'),
440
+ ]
441
+ nodes = [
442
+ helper.make_node('Reshape', ['input','fs'], ['flat']),
443
+ helper.make_node('GatherElements', ['flat','idx'], ['g'], axis=2),
444
+ helper.make_node('Reshape', ['g','os'], ['raw']),
445
+ helper.make_node('Mul', ['raw','mask'], ['output']),
446
+ ]
447
+ return mk(nodes, inits)
448
+
449
+
450
+ def _build_gather_model_with_const(IH, IW, OH, OW, idx, cst):
451
+ """Build gather model with constant values for some positions."""
452
+ flat_idx = np.zeros((1,10,GH*GW), dtype=np.int64)
453
+ gather_mask = np.zeros((1,1,GH,GW), dtype=np.float32)
454
+ const_oh = np.zeros((1,10,GH,GW), dtype=np.float32)
455
+ for oi in range(OH):
456
+ for oj in range(OW):
457
+ if idx[oi,oj,0] >= 0:
458
+ flat = idx[oi,oj,0]*GW + idx[oi,oj,1]
459
+ flat_idx[0,:,oi*GW+oj] = flat
460
+ gather_mask[0,0,oi,oj] = 1.0
461
+ elif cst[oi,oj] >= 0:
462
+ const_oh[0, cst[oi,oj], oi, oj] = 1.0
463
+ has_const = np.any(const_oh > 0)
464
+ inits = [
465
+ numpy_helper.from_array(np.array([1,10,GH*GW], dtype=np.int64), 'fs'),
466
+ numpy_helper.from_array(flat_idx, 'idx'),
467
+ numpy_helper.from_array(np.array([1,10,GH,GW], dtype=np.int64), 'os'),
468
+ numpy_helper.from_array(gather_mask, 'gmask'),
469
+ ]
470
+ nodes = [
471
+ helper.make_node('Reshape', ['input','fs'], ['flat']),
472
+ helper.make_node('GatherElements', ['flat','idx'], ['g'], axis=2),
473
+ helper.make_node('Reshape', ['g','os'], ['raw']),
474
+ helper.make_node('Mul', ['raw','gmask'], ['masked']),
475
+ ]
476
+ if has_const:
477
+ inits.append(numpy_helper.from_array(const_oh, 'cst'))
478
+ nodes.append(helper.make_node('Add', ['masked','cst'], ['output']))
479
+ else:
480
+ nodes[-1] = helper.make_node('Mul', ['raw','gmask'], ['output'])
481
+ return mk(nodes, inits)
482
+
483
+
484
+ # ============================================================
485
+ # MAIN SOLVER
486
+ # ============================================================
487
+
488
+ ANALYTICAL_SOLVERS = [
489
+ ('identity', s_identity),
490
+ ('constant', s_constant),
491
+ ('color_map', s_color_map),
492
+ ('transpose', s_transpose),
493
+ ('flip', s_flip),
494
+ ('rotate', s_rotate),
495
+ ('tile', s_tile),
496
+ ('upscale', s_upscale),
497
+ ('concat', s_concat),
498
+ ('spatial_gather', s_spatial_gather),
499
+ ]
500
+
501
+
502
+ def solve_task(tn, td, outdir, conv_budget=30.0):
503
+ """Solve one task. Returns (solved, solver_name, file_size)."""
504
+ os.makedirs(outdir, exist_ok=True)
505
+ path = os.path.join(outdir, f"task{tn:03d}.onnx")
506
+ for sname, sfn in ANALYTICAL_SOLVERS:
507
+ try:
508
+ model = sfn(td)
509
+ if model is None:
510
+ continue
511
+ onnx.save(model, path)
512
+ if validate(path, td):
513
+ return True, sname, os.path.getsize(path)
514
+ except:
515
+ pass
516
+ model = solve_conv(td, path, time_budget=conv_budget)
517
+ if model is not None:
518
+ return True, 'conv', os.path.getsize(path)
519
+ return False, None, None
520
+
521
+
522
+ def main():
523
+ parser = argparse.ArgumentParser()
524
+ parser.add_argument('--data_dir', default='ARC-AGI/data/training/')
525
+ parser.add_argument('--output_dir', default='submission')
526
+ parser.add_argument('--kaggle', action='store_true', help='Use Kaggle format')
527
+ parser.add_argument('--conv_budget', type=float, default=30.0, help='Seconds per task for conv')
528
+ parser.add_argument('--tasks', type=str, default='', help='Comma-separated task numbers to solve')
529
+ args = parser.parse_args()
530
+ if args.kaggle:
531
+ tasks = load_tasks_kaggle(args.data_dir)
532
+ else:
533
+ tasks = load_tasks_dir(args.data_dir)
534
+ if args.tasks:
535
+ task_nums = [int(t) for t in args.tasks.split(',')]
536
+ else:
537
+ task_nums = sorted(tasks.keys())
538
+ print(f"Loaded {len(tasks)} tasks, solving {len(task_nums)}")
539
+ print(f"Conv budget: {args.conv_budget}s per task")
540
+ print("=" * 70)
541
+ t0 = time.time()
542
+ results = {}
543
+ for tn in task_nums:
544
+ if tn not in tasks:
545
+ continue
546
+ td = tasks[tn]['data']
547
+ ok, sname, sz = solve_task(tn, td, args.output_dir, args.conv_budget)
548
+ if ok:
549
+ results[tn] = sname
550
+ print(f"Task {tn:3d}: {sname:20s} ({sz:>8,} bytes)")
551
+ else:
552
+ print(f"Task {tn:3d}: UNSOLVED")
553
+ elapsed = time.time() - t0
554
+ print(f"\n{'='*70}")
555
+ print(f"Solved: {len(results)}/{len(task_nums)} in {elapsed:.0f}s")
556
+ sc = Counter(results.values())
557
+ for s, c in sc.most_common():
558
+ print(f" {s}: {c}")
559
+ n_files = len([f for f in os.listdir(args.output_dir) if f.endswith('.onnx')])
560
+ total_size = sum(os.path.getsize(os.path.join(args.output_dir, f))
561
+ for f in os.listdir(args.output_dir) if f.endswith('.onnx'))
562
+ print(f"\n{n_files} ONNX files, total {total_size/1024:.1f} KB")
563
+
564
+
565
+ if __name__ == '__main__':
566
+ main()