rogermt commited on
Commit
260c341
·
verified ·
1 Parent(s): 760f72a

Upload neurogolf_solver.py

Browse files
Files changed (1) hide show
  1. neurogolf_solver.py +1 -1694
neurogolf_solver.py CHANGED
@@ -1,1694 +1 @@
1
- #!/usr/bin/env python3
2
- """
3
- ARC-AGI NeuroGolf Championship - Complete Solver v4
4
- Format: [1,10,30,30] one-hot input/output, opset 10, IR version 10.
5
-
6
- v4 CRITICAL FIXES:
7
- - ARC-GEN data loaded and used for conv fitting (more data = better lstsq)
8
- - ARC-GEN validation: models validated against train+test+arc-gen
9
- - EXCLUDED tasks: {21, 55, 80, 184, 202, 366} skipped
10
- - submission.csv generation for Kaggle
11
- - s_flip fixed: GatherElements -> Gather (opset 10 compat)
12
- - Static profiler: no onnx_tool dependency for cost estimation
13
- - get_exs_for_fitting(): uses train+test+arc-gen for conv fitting
14
-
15
- Solvers:
16
- - Analytical: identity, constant, color_map, transpose, flip, rotate, tile, upscale,
17
- concat, concat_enhanced, spatial_gather, varshape_spatial_gather,
18
- diagonal_tile, kronecker
19
- - Conv (fixed shape): Slice -> Conv -> ArgMax -> Equal+Cast -> Pad
20
- - Conv (variable shape): Conv(30x30) -> ArgMax -> Equal+Cast -> Mul(mask)
21
- - Conv (diff shape): Slice -> Conv -> Slice(crop) -> ArgMax -> Equal+Cast -> Pad
22
-
23
- Usage:
24
- python neurogolf_solver.py --data_dir ARC-AGI/data/training/ --output_dir submission
25
- python neurogolf_solver.py --data_dir ARC-AGI/data/training/ --output_dir submission --conv_budget 60 --arcgen_dir ARC-GEN-100K/
26
- """
27
-
28
- import json, os, sys, math, time, argparse, csv, io, zipfile
29
- import numpy as np
30
- import onnx
31
- from onnx import helper, TensorProto, numpy_helper
32
- import onnxruntime as ort
33
- from collections import Counter
34
-
35
- try:
36
- from neurogolf_utils import score_network as _score_network_official
37
- HAS_ONNX_TOOL = True
38
- except ImportError:
39
- HAS_ONNX_TOOL = False
40
-
41
- try:
42
- import wandb
43
- except ImportError:
44
- wandb = None
45
-
46
- BATCH, CH, GH, GW = 1, 10, 30, 30
47
- GRID_SHAPE = [BATCH, CH, GH, GW]
48
- DT = TensorProto.FLOAT
49
- IR = 10
50
- OPSET = [helper.make_opsetid("", 10)]
51
-
52
- # Officially excluded tasks (score 0 regardless)
53
- EXCLUDED_TASKS = {21, 55, 80, 184, 202, 366}
54
-
55
- # Max ARC-GEN examples to use for validation (to keep runtime reasonable)
56
- MAX_ARCGEN_VALIDATE = 30
57
- # Max ARC-GEN examples for conv fitting (keep separate from validation!)
58
- # NOTE: Conv fitting uses train+test only. ARC-GEN is for VALIDATION only.
59
- # lstsq underdetermines with too many variable-size arc-gen examples.
60
- MAX_ARCGEN_FIT = 0 # Don't use arc-gen for fitting — use for validation only
61
-
62
- def get_providers():
63
- return ['CPUExecutionProvider']
64
-
65
- ORT_PROVIDERS = get_providers()
66
-
67
- # ============================================================
68
- # LOAD / VALIDATE
69
- # ============================================================
70
-
71
- def load_tasks_dir(data_dir, arcgen_dir=None):
72
- """Load ARC-AGI tasks and optionally merge ARC-GEN data."""
73
- files = sorted(f for f in os.listdir(data_dir) if f.endswith('.json'))
74
- tasks = {}
75
- for i, f in enumerate(files):
76
- with open(os.path.join(data_dir, f)) as fh:
77
- data = json.load(fh)
78
- hex_id = f.replace('.json','')
79
- # Load ARC-GEN data if available
80
- if arcgen_dir and os.path.exists(os.path.join(arcgen_dir, f)):
81
- with open(os.path.join(arcgen_dir, f)) as fh:
82
- arcgen_examples = json.load(fh)
83
- if isinstance(arcgen_examples, list):
84
- data['arc-gen'] = arcgen_examples
85
- if 'arc-gen' not in data:
86
- data['arc-gen'] = []
87
- tasks[i+1] = {'hex': hex_id, 'data': data}
88
- return tasks
89
-
90
- def load_tasks_kaggle(data_dir):
91
- """Load Kaggle format tasks (already have arc-gen embedded)."""
92
- tasks = {}
93
- for tn in range(1, 401):
94
- path = os.path.join(data_dir, f"task{tn:03d}.json")
95
- if os.path.exists(path):
96
- with open(path) as f:
97
- data = json.load(f)
98
- if 'arc-gen' not in data:
99
- data['arc-gen'] = []
100
- tasks[tn] = {'hex': f'task{tn:03d}', 'data': data}
101
- return tasks
102
-
103
- def to_onehot(grid):
104
- arr = np.zeros((1, CH, GH, GW), dtype=np.float32)
105
- for r, row in enumerate(grid):
106
- for c, v in enumerate(row):
107
- if r < GH and c < GW and 0 <= v < CH:
108
- arr[0, v, r, c] = 1.0
109
- return arr
110
-
111
- def validate(path, td):
112
- """Validate model against ALL examples: train + test + arc-gen.
113
- This matches what Kaggle does for scoring."""
114
- try:
115
- opts = ort.SessionOptions()
116
- opts.log_severity_level = 3
117
- sess = ort.InferenceSession(path, sess_options=opts, providers=ORT_PROVIDERS)
118
- except:
119
- return False
120
- examples = td['train'] + td['test']
121
- # Include arc-gen examples (capped for speed)
122
- if 'arc-gen' in td:
123
- examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]
124
- for ex in examples:
125
- inp = to_onehot(ex['input'])
126
- exp = to_onehot(ex['output'])
127
- try:
128
- out = sess.run(['output'], {'input': inp})[0]
129
- out = (out > 0.0).astype(np.float32)
130
- except:
131
- return False
132
- if not np.array_equal(out, exp):
133
- return False
134
- return True
135
-
136
- def validate_raw(raw_bytes, td):
137
- """Validate model from raw bytes against ALL examples."""
138
- try:
139
- opts = ort.SessionOptions()
140
- opts.log_severity_level = 3
141
- sess = ort.InferenceSession(raw_bytes, sess_options=opts, providers=ORT_PROVIDERS)
142
- except:
143
- return False
144
- examples = td['train'] + td['test']
145
- if 'arc-gen' in td:
146
- examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]
147
- for ex in examples:
148
- inp = to_onehot(ex['input'])
149
- exp = to_onehot(ex['output'])
150
- try:
151
- out = sess.run(['output'], {'input': inp})[0]
152
- out = (out > 0.0).astype(np.float32)
153
- except:
154
- return False
155
- if not np.array_equal(out, exp):
156
- return False
157
- return True
158
-
159
- # ============================================================
160
- # STATIC PROFILER (no onnx_tool dependency)
161
- # ============================================================
162
-
163
- BANNED_OPS = {'Loop', 'Scan', 'NonZero', 'Unique', 'If', 'Function'}
164
- MAX_FILESIZE = int(1.44 * 1024 * 1024)
165
-
166
- def score_network(path):
167
- """Static profiler matching Kaggle scoring: cost = macs + memory + params.
168
- Falls back to official neurogolf_utils if available."""
169
- if HAS_ONNX_TOOL:
170
- try:
171
- return _score_network_official(path)
172
- except:
173
- pass
174
- return _static_profile(path)
175
-
176
- def _static_profile(path):
177
- """Compute cost without onnx_tool: params + nbytes + macs."""
178
- try:
179
- model = onnx.load(path)
180
- except:
181
- return None, None, None
182
-
183
- tensors = {}
184
- params = 0
185
- nbytes = 0
186
- macs = 0
187
-
188
- for init in model.graph.initializer:
189
- a = numpy_helper.to_array(init)
190
- tensors[init.name] = a
191
- params += a.size
192
- nbytes += a.nbytes
193
-
194
- for nd in model.graph.node:
195
- if nd.op_type == 'Constant':
196
- for attr in nd.attribute:
197
- if attr.t and attr.t.ByteSize() > 0:
198
- try:
199
- a = numpy_helper.to_array(attr.t)
200
- if nd.output:
201
- tensors[nd.output[0]] = a
202
- params += a.size
203
- nbytes += a.nbytes
204
- except:
205
- pass
206
-
207
- if nd.op_type in BANNED_OPS:
208
- return None, None, None
209
-
210
- if nd.op_type == 'Conv' and len(nd.input) >= 2 and nd.input[1] in tensors:
211
- w = tensors[nd.input[1]]
212
- if w.ndim == 4:
213
- co, ci, kh, kw = w.shape
214
- macs += co * ci * kh * kw * GH * GW
215
-
216
- return int(macs), int(nbytes), int(params)
217
-
218
- def mk(nodes, inits=None):
219
- x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
220
- y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
221
- g = helper.make_graph(nodes, "g", [x], [y], initializer=inits or [])
222
- return helper.make_model(g, ir_version=IR, opset_imports=OPSET)
223
-
224
- def get_exs(td):
225
- """Get examples for analytical solvers (train+test only)."""
226
- return [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
227
- for ex in td['train'] + td['test']]
228
-
229
- def get_exs_for_fitting(td):
230
- """Get examples for conv fitting. Uses train+test + arc-gen WHERE SIZES MATCH.
231
- For fixed-size tasks, arc-gen examples have the same grid size,
232
- so they provide more data points for lstsq without changing the feature dimension.
233
- For variable-size tasks, only use train+test (arc-gen varies too much)."""
234
- base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
235
- for ex in td['train'] + td['test']]
236
-
237
- if not base_exs:
238
- return base_exs
239
-
240
- # Check if all base examples have same input shape
241
- base_shapes = {inp.shape for inp, _ in base_exs}
242
- if len(base_shapes) != 1:
243
- return base_exs # Variable sizes — don't add arc-gen
244
-
245
- base_shape = list(base_shapes)[0]
246
-
247
- # Add arc-gen examples that match the base shape
248
- ag_exs = []
249
- for ex in td.get('arc-gen', []):
250
- inp = np.array(ex['input'], dtype=np.int64)
251
- out = np.array(ex['output'], dtype=np.int64)
252
- if inp.shape == base_shape and out.shape == base_exs[0][1].shape:
253
- ag_exs.append((inp, out))
254
-
255
- # Cap to avoid massive lstsq (diminishing returns after ~10)
256
- return base_exs + ag_exs[:10]
257
-
258
- def get_exs_for_fitting_variable(td):
259
- """Get examples for variable-shape conv fitting.
260
- For variable-shape tasks, arc-gen examples may have different sizes per example
261
- but since we embed in 30x30 anyway, we can safely include them."""
262
- base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
263
- for ex in td['train'] + td['test']]
264
-
265
- # For variable shape, include arc-gen examples (they get embedded in 30x30)
266
- ag_exs = []
267
- for ex in td.get('arc-gen', []):
268
- inp = np.array(ex['input'], dtype=np.int64)
269
- out = np.array(ex['output'], dtype=np.int64)
270
- if inp.shape == out.shape and inp.shape[0] <= 30 and inp.shape[1] <= 30:
271
- ag_exs.append((inp, out))
272
-
273
- return base_exs + ag_exs[:20]
274
-
275
- def fixed_shapes(td):
276
- shapes = set()
277
- for inp, out in get_exs(td):
278
- shapes.add((inp.shape, out.shape))
279
- return list(shapes)[0] if len(shapes) == 1 else None
280
-
281
- # ============================================================
282
- # GATHER HELPERS
283
- # ============================================================
284
-
285
- def _build_gather_model(OH, OW, idx):
286
- # Use Gather (opset 1) instead of GatherElements (opset 11)
287
- # Flatten spatial: [1,10,900] -> Gather(axis=2, indices=[900]) -> [1,10,900]
288
- flat_idx = np.zeros((GH*GW,), dtype=np.int64)
289
- mask = np.zeros((1,1,GH,GW), dtype=np.float32)
290
- for oi in range(OH):
291
- for oj in range(OW):
292
- flat_idx[oi*GW+oj] = idx[oi,oj,0]*GW + idx[oi,oj,1]
293
- mask[0,0,oi,oj] = 1.0
294
- inits = [
295
- numpy_helper.from_array(np.array([1,10,GH*GW], dtype=np.int64), 'fs'),
296
- numpy_helper.from_array(flat_idx, 'idx'),
297
- numpy_helper.from_array(np.array([1,10,GH,GW], dtype=np.int64), 'os'),
298
- numpy_helper.from_array(mask, 'mask'),
299
- ]
300
- nodes = [
301
- helper.make_node('Reshape', ['input','fs'], ['flat']),
302
- helper.make_node('Gather', ['flat','idx'], ['g'], axis=2),
303
- helper.make_node('Reshape', ['g','os'], ['raw']),
304
- helper.make_node('Mul', ['raw','mask'], ['output']),
305
- ]
306
- return mk(nodes, inits)
307
-
308
- def _build_gather_model_with_const(IH, IW, OH, OW, idx, cst):
309
- # Use Gather (opset 1) instead of GatherElements (opset 11)
310
- flat_idx = np.zeros((GH*GW,), dtype=np.int64)
311
- gather_mask = np.zeros((1,1,GH,GW), dtype=np.float32)
312
- const_oh = np.zeros((1,10,GH,GW), dtype=np.float32)
313
- for oi in range(OH):
314
- for oj in range(OW):
315
- if idx[oi,oj,0] >= 0:
316
- flat_idx[oi*GW+oj] = idx[oi,oj,0]*GW + idx[oi,oj,1]
317
- gather_mask[0,0,oi,oj] = 1.0
318
- elif cst[oi,oj] >= 0:
319
- const_oh[0, cst[oi,oj], oi, oj] = 1.0
320
- has_const = np.any(const_oh > 0)
321
- inits = [
322
- numpy_helper.from_array(np.array([1,10,GH*GW], dtype=np.int64), 'fs'),
323
- numpy_helper.from_array(flat_idx, 'idx'),
324
- numpy_helper.from_array(np.array([1,10,GH,GW], dtype=np.int64), 'os'),
325
- numpy_helper.from_array(gather_mask, 'gmask'),
326
- ]
327
- nodes = [
328
- helper.make_node('Reshape', ['input','fs'], ['flat']),
329
- helper.make_node('Gather', ['flat','idx'], ['g'], axis=2),
330
- helper.make_node('Reshape', ['g','os'], ['raw']),
331
- helper.make_node('Mul', ['raw','gmask'], ['masked']),
332
- ]
333
- if has_const:
334
- inits.append(numpy_helper.from_array(const_oh, 'cst'))
335
- nodes.append(helper.make_node('Add', ['masked','cst'], ['output']))
336
- else:
337
- nodes[-1] = helper.make_node('Mul', ['raw','gmask'], ['output'])
338
- return mk(nodes, inits)
339
-
340
- # ============================================================
341
- # ANALYTICAL SOLVERS
342
- # ============================================================
343
-
344
- def s_identity(td):
345
- for ex in td['train']+td['test']:
346
- if ex['input'] != ex['output']: return None
347
- return mk([helper.make_node('Identity', ['input'], ['output'])])
348
-
349
- def s_color_map(td):
350
- cm = {}
351
- for ex in td['train']+td['test']:
352
- inp, out = np.array(ex['input']), np.array(ex['output'])
353
- if inp.shape != out.shape: return None
354
- for iv, ov in zip(inp.flat, out.flat):
355
- iv, ov = int(iv), int(ov)
356
- if iv in cm and cm[iv] != ov: return None
357
- cm[iv] = ov
358
-
359
- # Check if it's a permutation (bijective + all mapped colors form a closed set)
360
- is_permutation = (set(cm.keys()) == set(cm.values()))
361
-
362
- if is_permutation:
363
- # Use channel Gather — zero MACs, much cheaper
364
- gather_ch = np.arange(10, dtype=np.int32)
365
- for src, dst in cm.items():
366
- if 0 <= src < 10 and 0 <= dst < 10:
367
- gather_ch[dst] = src
368
- inits = [numpy_helper.from_array(gather_ch, 'gi')]
369
- nodes = [helper.make_node('Gather', ['input', 'gi'], ['output'], axis=1)]
370
- return mk(nodes, inits)
371
- else:
372
- # Non-permutation: use Conv 1x1 (has MACs but handles any mapping)
373
- W = np.zeros((10,10,1,1), dtype=np.float32)
374
- for ic in range(10):
375
- W[cm.get(ic,ic), ic, 0, 0] = 1.0
376
- return mk([helper.make_node('Conv', ['input','W'], ['output'], kernel_shape=[1,1])],
377
- [numpy_helper.from_array(W, 'W')])
378
-
379
- def s_transpose(td):
380
- for ex in td['train']+td['test']:
381
- if not np.array_equal(np.array(ex['output']), np.array(ex['input']).T): return None
382
- return mk([helper.make_node('Transpose', ['input'], ['output'], perm=[0,1,3,2])])
383
-
384
- def s_flip(td):
385
- exs = get_exs(td)
386
- sp = fixed_shapes(td)
387
- if sp is None: return None
388
- (IH,IW),(OH,OW) = sp
389
- if (IH,IW) != (OH,OW): return None
390
- for axis, flip_fn in [(0, np.flipud), (1, np.fliplr)]:
391
- if all(np.array_equal(out, flip_fn(inp)) for inp, out in exs):
392
- # Build gather index map (using Gather, opset 1 compatible)
393
- idx = np.zeros((OH,OW,2), dtype=np.int64)
394
- for r in range(OH):
395
- for c in range(OW):
396
- if axis == 0:
397
- idx[r,c] = [IH-1-r, c]
398
- else:
399
- idx[r,c] = [r, IW-1-c]
400
- return _build_gather_model(OH, OW, idx)
401
- return None
402
-
403
- def s_rotate(td):
404
- exs = get_exs(td)
405
- sp = fixed_shapes(td)
406
- if sp is None: return None
407
- (IH,IW),(OH,OW) = sp
408
- for k in [1, 2, 3]:
409
- if not all(np.array_equal(out, np.rot90(inp, k)) for inp, out in exs): continue
410
- idx = np.zeros((OH,OW,2), dtype=np.int64)
411
- for r in range(OH):
412
- for c in range(OW):
413
- if k == 1: sr, sc = c, IH-1-r
414
- elif k == 2: sr, sc = IH-1-r, IW-1-c
415
- elif k == 3: sr, sc = IW-1-c, r
416
- idx[r,c] = [sr, sc]
417
- return _build_gather_model(OH, OW, idx)
418
- return None
419
-
420
- def s_spatial_gather(td):
421
- sp = fixed_shapes(td)
422
- if sp is None: return None
423
- (IH,IW),(OH,OW) = sp
424
- exs = get_exs(td)
425
- idx = np.full((OH,OW,2), -1, dtype=np.int64)
426
- cst = np.full((OH,OW), -1, dtype=np.int64)
427
- for oi in range(OH):
428
- for oj in range(OW):
429
- vals = set(int(out[oi,oj]) for _,out in exs)
430
- if len(vals) == 1: cst[oi,oj] = vals.pop()
431
- found = False
432
- for ri in range(IH):
433
- for rj in range(IW):
434
- if all(int(inp[ri,rj]) == int(out[oi,oj]) for inp,out in exs):
435
- idx[oi,oj] = [ri, rj]; found = True; break
436
- if found: break
437
- if not found and cst[oi,oj] < 0: return None
438
- return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
439
-
440
- def s_varshape_spatial_gather(td):
441
- """Spatial gather that works for variable-shape tasks by embedding in 30x30."""
442
- sp = fixed_shapes(td)
443
- if sp is not None: return None # fixed shapes handled by s_spatial_gather
444
- exs = get_exs(td)
445
-
446
- # Embed all examples in 30x30
447
- exs_30 = []
448
- for inp, out in exs:
449
- ih, iw = inp.shape
450
- oh, ow = out.shape
451
- inp30 = np.zeros((30, 30), dtype=np.int64)
452
- out30 = np.zeros((30, 30), dtype=np.int64)
453
- inp30[:ih, :iw] = inp
454
- out30[:oh, :ow] = out
455
- exs_30.append((inp30, out30))
456
-
457
- idx = np.full((30, 30, 2), -1, dtype=np.int64)
458
- cst = np.full((30, 30), -1, dtype=np.int64)
459
-
460
- for oi in range(30):
461
- for oj in range(30):
462
- vals = set(int(out30[oi, oj]) for _, out30 in exs_30)
463
- if len(vals) == 1:
464
- cst[oi, oj] = vals.pop()
465
- found = False
466
- for ri in range(30):
467
- for rj in range(30):
468
- if all(int(inp30[ri, rj]) == int(out30[oi, oj]) for inp30, out30 in exs_30):
469
- idx[oi, oj] = [ri, rj]
470
- found = True
471
- break
472
- if found: break
473
- if not found and cst[oi, oj] < 0:
474
- return None
475
-
476
- return _build_gather_model_with_const(30, 30, 30, 30, idx, cst)
477
-
478
- def s_tile(td):
479
- exs = get_exs(td)
480
- in_shapes = set(inp.shape for inp,_ in exs)
481
- if len(in_shapes) != 1: return None
482
- IH, IW = in_shapes.pop()
483
- tiles = set()
484
- for inp, out in exs:
485
- OH, OW = out.shape
486
- if OH % IH or OW % IW: return None
487
- rH, rW = OH//IH, OW//IW
488
- if rH < 1 or rW < 1 or (rH==1 and rW==1): return None
489
- tiles.add((rH, rW))
490
- if len(tiles) != 1: return None
491
- rH, rW = tiles.pop()
492
- OH, OW = IH*rH, IW*rW
493
- if OH > 30 or OW > 30: return None
494
- for inp, out in exs:
495
- if not np.array_equal(out, np.tile(inp, (rH, rW))): return None
496
- pad_h, pad_w = 30-OH, 30-OW
497
- inits = [
498
- numpy_helper.from_array(np.array([0,0,0,0], dtype=np.int64), 'st'),
499
- numpy_helper.from_array(np.array([1,10,IH,IW], dtype=np.int64), 'en'),
500
- numpy_helper.from_array(np.array([1,1,rH,rW], dtype=np.int64), 'rp'),
501
- ]
502
- nodes = [
503
- helper.make_node('Slice', ['input','st','en'], ['cr']),
504
- helper.make_node('Tile', ['cr','rp'], ['tl']),
505
- helper.make_node('Pad', ['tl'], ['output'], pads=[0,0,0,0,0,0,pad_h,pad_w], value=0.0),
506
- ]
507
- return mk(nodes, inits)
508
-
509
- def s_upscale(td):
510
- exs = get_exs(td)
511
- in_shapes = set(inp.shape for inp,_ in exs)
512
- if len(in_shapes) != 1: return None
513
- IH, IW = in_shapes.pop()
514
- scales = set()
515
- for inp, out in exs:
516
- OH, OW = out.shape
517
- if OH % IH or OW % IW: return None
518
- sH, sW = OH//IH, OW//IW
519
- if sH < 2 or sW < 2: return None
520
- scales.add((sH, sW))
521
- if len(scales) != 1: return None
522
- sH, sW = scales.pop()
523
- OH, OW = IH*sH, IW*sW
524
- if OH > 30 or OW > 30: return None
525
- for inp, out in exs:
526
- if not np.array_equal(out, np.repeat(np.repeat(inp, sH, 0), sW, 1)): return None
527
- idx = np.zeros((OH,OW,2), dtype=np.int64)
528
- for r in range(OH):
529
- for c in range(OW):
530
- idx[r,c] = [r//sH, c//sW]
531
- return _build_gather_model(OH, OW, idx)
532
-
533
- def s_concat(td):
534
- from itertools import product as iproduct
535
- exs = get_exs(td)
536
- sp = fixed_shapes(td)
537
- if sp is None: return None
538
- (IH,IW),(OH,OW) = sp
539
- transforms = [
540
- ('id', lambda x: x), ('fliplr', lambda x: np.fliplr(x)),
541
- ('flipud', lambda x: np.flipud(x)), ('rot180', lambda x: np.rot90(x, 2)),
542
- ]
543
- if OH == IH and OW % IW == 0 and OW > IW:
544
- n = OW // IW
545
- if 2 <= n <= 4:
546
- for combo in iproduct(range(4), repeat=n):
547
- if all(np.array_equal(out, np.concatenate([transforms[t][1](inp) for t in combo], axis=1))
548
- for inp, out in exs):
549
- idx = np.zeros((OH,OW,2), dtype=np.int64)
550
- for oi in range(OH):
551
- for oj in range(OW):
552
- bj = oj // IW; lr, lc = oi, oj % IW
553
- t = transforms[combo[bj]][0]
554
- if t == 'id': sr, sc = lr, lc
555
- elif t == 'fliplr': sr, sc = lr, IW-1-lc
556
- elif t == 'flipud': sr, sc = IH-1-lr, lc
557
- elif t == 'rot180': sr, sc = IH-1-lr, IW-1-lc
558
- idx[oi,oj] = [sr, sc]
559
- return _build_gather_model(OH, OW, idx)
560
- if OW == IW and OH % IH == 0 and OH > IH:
561
- n = OH // IH
562
- if 2 <= n <= 4:
563
- for combo in iproduct(range(4), repeat=n):
564
- if all(np.array_equal(out, np.concatenate([transforms[t][1](inp) for t in combo], axis=0))
565
- for inp, out in exs):
566
- idx = np.zeros((OH,OW,2), dtype=np.int64)
567
- for oi in range(OH):
568
- for oj in range(OW):
569
- bi = oi // IH; lr, lc = oi % IH, oj
570
- t = transforms[combo[bi]][0]
571
- if t == 'id': sr, sc = lr, lc
572
- elif t == 'fliplr': sr, sc = lr, IW-1-lc
573
- elif t == 'flipud': sr, sc = IH-1-lr, lc
574
- elif t == 'rot180': sr, sc = IH-1-lr, IW-1-lc
575
- idx[oi,oj] = [sr, sc]
576
- return _build_gather_model(OH, OW, idx)
577
- return None
578
-
579
- def s_concat_enhanced(td):
580
- """Enhanced concat with all 8 dihedral group transforms."""
581
- exs = get_exs(td)
582
- sp = fixed_shapes(td)
583
- if sp is None: return None
584
- (IH,IW),(OH,OW) = sp
585
- if IH == OH and IW == OW: return None
586
-
587
- # Need block decomposition
588
- if OH % IH != 0 or OW % IW != 0: return None
589
- rH, rW = OH // IH, OW // IW
590
- if rH * rW > 16 or rH * rW < 2: return None
591
- if OH > 30 or OW > 30: return None
592
-
593
- # All 8 symmetry transforms of the dihedral group
594
- transforms = [
595
- ('id', lambda x: x),
596
- ('fliplr', lambda x: np.fliplr(x)),
597
- ('flipud', lambda x: np.flipud(x)),
598
- ('rot180', lambda x: np.rot90(x, 2)),
599
- ('rot90', lambda x: np.rot90(x, 1)),
600
- ('rot270', lambda x: np.rot90(x, 3)),
601
- ('T', lambda x: x.T),
602
- ('T_fliplr', lambda x: np.fliplr(x.T)),
603
- ]
604
-
605
- # For each block, find which transform matches
606
- block_transforms = {}
607
- for bi in range(rH):
608
- for bj in range(rW):
609
- found = None
610
- for tidx, (tname, tfn) in enumerate(transforms):
611
- ok = True
612
- for inp, out in exs:
613
- block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
614
- expected = tfn(inp)
615
- if expected.shape != (IH, IW) or not np.array_equal(block, expected):
616
- ok = False
617
- break
618
- if ok:
619
- found = (tidx, tname)
620
- break
621
- if found is None:
622
- return None
623
- block_transforms[(bi, bj)] = found
624
-
625
- # Build index map
626
- idx = np.zeros((OH, OW, 2), dtype=np.int64)
627
- for bi in range(rH):
628
- for bj in range(rW):
629
- _, tname = block_transforms[(bi, bj)]
630
- for lr in range(IH):
631
- for lc in range(IW):
632
- oi, oj = bi*IH + lr, bj*IW + lc
633
- if tname == 'id': sr, sc = lr, lc
634
- elif tname == 'fliplr': sr, sc = lr, IW-1-lc
635
- elif tname == 'flipud': sr, sc = IH-1-lr, lc
636
- elif tname == 'rot180': sr, sc = IH-1-lr, IW-1-lc
637
- elif tname == 'rot90': sr, sc = IW-1-lc, lr
638
- elif tname == 'rot270': sr, sc = lc, IH-1-lr
639
- elif tname == 'T': sr, sc = lc, lr
640
- elif tname == 'T_fliplr': sr, sc = IW-1-lc, lr
641
- idx[oi, oj] = [sr, sc]
642
-
643
- # Verify
644
- for inp, out in exs:
645
- reconstructed = np.zeros_like(out)
646
- for oi in range(OH):
647
- for oj in range(OW):
648
- reconstructed[oi,oj] = inp[idx[oi,oj,0], idx[oi,oj,1]]
649
- if not np.array_equal(reconstructed, out):
650
- return None
651
-
652
- return _build_gather_model(OH, OW, idx)
653
-
654
- def s_input_driven_tile(td):
655
- """Each non-zero input pixel controls a block that's a copy of the input."""
656
- exs = get_exs(td)
657
- sp = fixed_shapes(td)
658
- if sp is None: return None
659
- (IH,IW),(OH,OW) = sp
660
- if OH % IH != 0 or OW % IW != 0: return None
661
- sH, sW = OH // IH, OW // IW
662
- if sH != IH or sW != IW: return None
663
- if OH > 30 or OW > 30: return None
664
-
665
- for inp, out in exs:
666
- for bi in range(IH):
667
- for bj in range(IW):
668
- block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
669
- if inp[bi, bj] != 0:
670
- if not np.array_equal(block, inp):
671
- return None
672
- else:
673
- if not np.all(block == 0):
674
- return None
675
-
676
- # Build gather model: each output pixel at (bi*IH+lr, bj*IW+lc) maps to
677
- # input[lr, lc] if input[bi, bj] != 0, else constant 0
678
- # Problem: whether block is active depends on input value, which varies.
679
- # This needs a different ONNX approach: can't use static gather.
680
- # But we CAN use: Tile input -> Mul by mask derived from input
681
- # Actually we need: for each (bi,bj) block position, multiply by inp[bi,bj] != 0
682
- # This is NOT static - it depends on input content.
683
- # Skip for now - spatial_gather can handle if block positions are fixed.
684
- return None
685
-
686
- def s_kronecker(td):
687
- """output = kron(input, ones(sH,sW)) — nearest-neighbor upscaling."""
688
- exs = get_exs(td)
689
- sp = fixed_shapes(td)
690
- if sp is None: return None
691
- (IH,IW),(OH,OW) = sp
692
- if OH % IH != 0 or OW % IW != 0: return None
693
- sH, sW = OH // IH, OW // IW
694
- if sH < 2 or sW < 2: return None
695
- if OH > 30 or OW > 30: return None
696
-
697
- for inp, out in exs:
698
- expected = np.kron(inp, np.ones((sH, sW), dtype=np.int64))
699
- if not np.array_equal(out, expected):
700
- return None
701
-
702
- # This is identical to upscale - build gather index
703
- idx = np.zeros((OH,OW,2), dtype=np.int64)
704
- for r in range(OH):
705
- for c in range(OW):
706
- idx[r,c] = [r//sH, c//sW]
707
- return _build_gather_model(OH, OW, idx)
708
-
709
- def s_diagonal_tile(td):
710
- """Input placed along diagonal: block[i,i] = input, rest = 0."""
711
- exs = get_exs(td)
712
- sp = fixed_shapes(td)
713
- if sp is None: return None
714
- (IH,IW),(OH,OW) = sp
715
- if OH % IH != 0 or OW % IW != 0: return None
716
- rH, rW = OH // IH, OW // IW
717
- if rH != rW or rH < 2: return None
718
- if OH > 30 or OW > 30: return None
719
-
720
- for inp, out in exs:
721
- for bi in range(rH):
722
- for bj in range(rW):
723
- block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
724
- if bi == bj:
725
- if not np.array_equal(block, inp):
726
- return None
727
- else:
728
- if not np.all(block == 0):
729
- return None
730
-
731
- # Build: diagonal blocks map to input, off-diagonal are constant 0
732
- idx = np.zeros((OH,OW,2), dtype=np.int64)
733
- cst = np.full((OH,OW), -1, dtype=np.int64)
734
- for bi in range(rH):
735
- for bj in range(rW):
736
- for lr in range(IH):
737
- for lc in range(IW):
738
- oi, oj = bi*IH + lr, bj*IW + lc
739
- if bi == bj:
740
- idx[oi, oj] = [lr, lc]
741
- else:
742
- idx[oi, oj] = [-1, -1]
743
- cst[oi, oj] = 0
744
-
745
- return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
746
-
747
- def s_shift(td):
748
- """Detect constant spatial shift of the grid."""
749
- exs = get_exs(td)
750
- sp = fixed_shapes(td)
751
- if sp is None: return None
752
- (IH, IW), (OH, OW) = sp
753
- if (IH, IW) != (OH, OW): return None
754
- for dr in range(-5, 6):
755
- for dc in range(-5, 6):
756
- if dr == 0 and dc == 0: continue
757
- ok = True
758
- for inp, out in exs:
759
- shifted = np.zeros_like(inp)
760
- r0, r1 = max(0, dr), min(IH, IH + dr)
761
- c0, c1 = max(0, dc), min(IW, IW + dc)
762
- if r1 > r0 and c1 > c0:
763
- sr0, sc0 = max(0, -dr), max(0, -dc)
764
- shifted[r0:r1, c0:c1] = inp[sr0:sr0+(r1-r0), sc0:sc0+(c1-c0)]
765
- if not np.array_equal(shifted, out):
766
- ok = False; break
767
- if not ok: continue
768
- # Build gather index
769
- idx = np.zeros((OH, OW, 2), dtype=np.int64)
770
- cst = np.full((OH, OW), 0, dtype=np.int64) # zeros for out-of-bounds
771
- for r in range(OH):
772
- for c in range(OW):
773
- sr, sc = r - dr, c - dc
774
- if 0 <= sr < IH and 0 <= sc < IW:
775
- idx[r, c] = [sr, sc]
776
- else:
777
- idx[r, c] = [-1, -1]
778
- return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
779
- return None
780
-
781
- def s_gravity(td):
782
- """Detect gravity-like compaction in one direction."""
783
- exs = get_exs(td)
784
- sp = fixed_shapes(td)
785
- if sp is None: return None
786
- (IH, IW), (OH, OW) = sp
787
- if (IH, IW) != (OH, OW): return None
788
-
789
- def _gravity(grid, direction):
790
- r = np.zeros_like(grid); h, w = grid.shape
791
- if direction in ('down', 'up'):
792
- for c in range(w):
793
- nz = grid[:, c][grid[:, c] != 0]
794
- if direction == 'down': r[h-len(nz):h, c] = nz
795
- else: r[:len(nz), c] = nz
796
- else:
797
- for rr in range(h):
798
- nz = grid[rr, :][grid[rr, :] != 0]
799
- if direction == 'right': r[rr, w-len(nz):w] = nz
800
- else: r[rr, :len(nz)] = nz
801
- return r
802
-
803
- for d in ('down', 'up', 'left', 'right'):
804
- if all(np.array_equal(_gravity(inp, d), out) for inp, out in exs):
805
- # Gravity is input-dependent (positions depend on content)
806
- # Can't use static Gather — need Conv to learn it
807
- # But conv also can't learn arbitrary sorting...
808
- # Skip for now — this needs a specialized ONNX graph
809
- return None
810
- return None
811
-
812
- def s_mirror_h(td):
813
- """Output = input | flip(input, horizontal), doubling width."""
814
- exs = get_exs(td)
815
- sp = fixed_shapes(td)
816
- if sp is None: return None
817
- (IH, IW), (OH, OW) = sp
818
- if OH != IH or OW != 2 * IW: return None
819
- if OW > 30: return None
820
- for inp, out in exs:
821
- expected = np.concatenate([inp, np.flip(inp, 1)], 1)
822
- if not np.array_equal(expected, out): return None
823
- # Build gather index
824
- idx = np.zeros((OH, OW, 2), dtype=np.int64)
825
- for r in range(OH):
826
- for c in range(OW):
827
- sc = c if c < IW else 2*IW - 1 - c
828
- idx[r, c] = [r, sc]
829
- return _build_gather_model(OH, OW, idx)
830
-
831
- def s_mirror_v(td):
832
- """Output = input over flip(input, vertical), doubling height."""
833
- exs = get_exs(td)
834
- sp = fixed_shapes(td)
835
- if sp is None: return None
836
- (IH, IW), (OH, OW) = sp
837
- if OW != IW or OH != 2 * IH: return None
838
- if OH > 30: return None
839
- for inp, out in exs:
840
- expected = np.concatenate([inp, np.flip(inp, 0)], 0)
841
- if not np.array_equal(expected, out): return None
842
- idx = np.zeros((OH, OW, 2), dtype=np.int64)
843
- for r in range(OH):
844
- for c in range(OW):
845
- sr = r if r < IH else 2*IH - 1 - r
846
- idx[r, c] = [sr, c]
847
- return _build_gather_model(OH, OW, idx)
848
-
849
- def s_quad_mirror(td):
850
- """Output = 2x2 block of input with h/v flips."""
851
- exs = get_exs(td)
852
- sp = fixed_shapes(td)
853
- if sp is None: return None
854
- (IH, IW), (OH, OW) = sp
855
- if OH != 2 * IH or OW != 2 * IW: return None
856
- if OH > 30 or OW > 30: return None
857
- for inp, out in exs:
858
- expected = np.block([
859
- [inp, np.flip(inp, 1)],
860
- [np.flip(inp, 0), np.flip(np.flip(inp, 0), 1)]
861
- ])
862
- if not np.array_equal(expected, out): return None
863
- idx = np.zeros((OH, OW, 2), dtype=np.int64)
864
- for r in range(OH):
865
- for c in range(OW):
866
- sr = r if r < IH else 2*IH - 1 - r
867
- sc = c if c < IW else 2*IW - 1 - c
868
- idx[r, c] = [sr, sc]
869
- return _build_gather_model(OH, OW, idx)
870
-
871
- def s_fixed_crop(td):
872
- """Output = fixed subregion of input."""
873
- exs = get_exs(td)
874
- sp = fixed_shapes(td)
875
- if sp is None: return None
876
- (IH, IW), (OH, OW) = sp
877
- if OH > IH or OW > IW or (OH == IH and OW == IW): return None
878
- for r0 in range(IH - OH + 1):
879
- for c0 in range(IW - OW + 1):
880
- if all(np.array_equal(inp[r0:r0+OH, c0:c0+OW], out) for inp, out in exs):
881
- idx = np.zeros((OH, OW, 2), dtype=np.int64)
882
- for r in range(OH):
883
- for c in range(OW):
884
- idx[r, c] = [r0 + r, c0 + c]
885
- return _build_gather_model(OH, OW, idx)
886
- return None
887
-
888
- def s_nonuniform_scale(td):
889
- """Output = input scaled by different factors in h and w."""
890
- exs = get_exs(td)
891
- sp = fixed_shapes(td)
892
- if sp is None: return None
893
- (IH, IW), (OH, OW) = sp
894
- for fh, fw in [(1,2),(2,1),(1,3),(3,1),(2,3),(3,2),(1,4),(4,1),(2,4),(4,2)]:
895
- if OH != IH*fh or OW != IW*fw: continue
896
- if OH > 30 or OW > 30: continue
897
- if all(np.array_equal(np.repeat(np.repeat(inp, fh, 0), fw, 1), out) for inp, out in exs):
898
- idx = np.zeros((OH, OW, 2), dtype=np.int64)
899
- for r in range(OH):
900
- for c in range(OW):
901
- idx[r, c] = [r//fh, c//fw]
902
- return _build_gather_model(OH, OW, idx)
903
- return None
904
-
905
- def s_constant(td):
906
- sp = fixed_shapes(td)
907
- if sp is None: return None
908
- exs = get_exs(td)
909
- outs = [out for _,out in exs]
910
- if not all(np.array_equal(outs[0], o) for o in outs[1:]): return None
911
- const = np.zeros((1,10,30,30), dtype=np.float32)
912
- for r, row in enumerate(outs[0]):
913
- for c, v in enumerate(row):
914
- const[0, int(v), r, c] = 1.0
915
- inits = [numpy_helper.from_array(np.array(0.0, dtype=np.float32), 'z'),
916
- numpy_helper.from_array(const, 'c')]
917
- nodes = [helper.make_node('Mul', ['input','z'], ['zd']),
918
- helper.make_node('ReduceSum', ['zd'], ['s'], axes=[1,2,3], keepdims=1),
919
- helper.make_node('Add', ['s','c'], ['output'])]
920
- return mk(nodes, inits)
921
-
922
- # ============================================================
923
- # CONV SOLVERS
924
- # ============================================================
925
-
926
- def add_onehot_block(nodes, inits, am_name, oh_name):
927
- """Equal + Cast one-hot encoding (replaces OneHot which lacks CUDA kernel)."""
928
- classes = np.arange(10, dtype=np.int64).reshape(1, 10, 1, 1)
929
- inits.append(numpy_helper.from_array(classes, 'classes'))
930
- nodes.append(helper.make_node('Equal', [am_name, 'classes'], ['eq']))
931
- nodes.append(helper.make_node('Cast', ['eq'], [oh_name], to=TensorProto.FLOAT))
932
-
933
- def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False):
934
- """Shared lstsq conv fitting. Returns (Wconv, B) or None."""
935
- pad = ks // 2
936
- feat = 10 * ks * ks + (1 if use_bias else 0)
937
- if feat > 20000: return None
938
-
939
- patches, targets = [], []
940
- for inp_g, out_g in exs_raw:
941
- ih, iw = inp_g.shape
942
- if use_full_30:
943
- oh_full = np.zeros((10, GH, GW), dtype=np.float64)
944
- for c in range(10): oh_full[c, :ih, :iw] = (inp_g == c)
945
- oh_pad = np.pad(oh_full, ((0,0),(pad,pad),(pad,pad)))
946
- else:
947
- oh_enc = np.zeros((10, ih, iw), dtype=np.float64)
948
- for c in range(10): oh_enc[c] = (inp_g == c)
949
- oh_pad = np.pad(oh_enc, ((0,0),(pad,pad),(pad,pad)))
950
-
951
- oh, ow = out_g.shape
952
- for r in range(oh):
953
- for c in range(ow):
954
- p = oh_pad[:, r:r+ks, c:c+ks].flatten()
955
- if use_bias: p = np.append(p, 1.0)
956
- patches.append(p)
957
- targets.append(int(out_g[r, c]))
958
-
959
- n_patches = len(patches)
960
- if feat > 5000 and n_patches > 2000: return None
961
-
962
- P = np.array(patches, dtype=np.float64)
963
- T = np.array(targets, dtype=np.int64)
964
- T_oh = np.zeros((len(T), 10), dtype=np.float64)
965
- for i, t in enumerate(T): T_oh[i, t] = 1.0
966
-
967
- WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
968
- if not np.array_equal(np.argmax(P @ WT, axis=1), T): return None
969
-
970
- if use_bias:
971
- Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
972
- B = WT[-1].astype(np.float32)
973
- else:
974
- Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
975
- B = None
976
- return Wconv, B
977
-
978
- def solve_conv_fixed(td, path, time_budget=30.0):
979
- """Fixed-shape conv: Slice -> Conv -> ArgMax -> Equal+Cast -> Pad."""
980
- exs = get_exs(td)
981
- for inp, out in exs:
982
- if inp.shape != out.shape: return None
983
- shapes = set(inp.shape for inp, _ in exs)
984
- if len(shapes) != 1: return None
985
- IH, IW = shapes.pop()
986
-
987
- # Use ARC-GEN data for better fitting
988
- fit_exs = get_exs_for_fitting(td)
989
- # Filter to same-shape, same IH/IW
990
- fit_exs = [(i,o) for i,o in fit_exs if i.shape == o.shape and i.shape == (IH, IW)]
991
-
992
- t_start = time.time()
993
- for use_bias in [False, True]:
994
- for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
995
- if time.time() - t_start > time_budget: return None
996
- result = _lstsq_conv(fit_exs, ks, use_bias, use_full_30=False)
997
- if result is None: continue
998
- Wconv, B = result
999
- pad = ks // 2
1000
- pad_h, pad_w = GH - IH, GW - IW
1001
-
1002
- inits = [
1003
- numpy_helper.from_array(np.array([0,0,0,0], dtype=np.int64), 'sl_st'),
1004
- numpy_helper.from_array(np.array([1,10,IH,IW], dtype=np.int64), 'sl_en'),
1005
- numpy_helper.from_array(Wconv, 'W'),
1006
- ]
1007
- conv_inputs = ['grid', 'W']
1008
- if B is not None:
1009
- inits.append(numpy_helper.from_array(B, 'B'))
1010
- conv_inputs.append('B')
1011
-
1012
- nodes = [
1013
- helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
1014
- helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
1015
- helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
1016
- ]
1017
- add_onehot_block(nodes, inits, 'am', 'oh_out')
1018
- nodes.append(
1019
- helper.make_node('Pad', ['oh_out'], ['output'],
1020
- pads=[0,0,0,0,0,0,pad_h,pad_w], value=0.0)
1021
- )
1022
-
1023
- model = mk(nodes, inits)
1024
- onnx.save(model, path)
1025
- if validate(path, td): return 'conv_fixed', model
1026
- return None
1027
-
1028
- def solve_conv_variable(td, path, time_budget=30.0):
1029
- """Variable-shape conv: Conv(30x30) -> ArgMax -> Equal+Cast -> Mul(mask)."""
1030
- exs = get_exs(td)
1031
- for inp, out in exs:
1032
- if inp.shape != out.shape: return None
1033
-
1034
- # Use ARC-GEN data for better fitting (variable shape, embedded in 30x30)
1035
- fit_exs = get_exs_for_fitting_variable(td)
1036
- fit_exs = [(i,o) for i,o in fit_exs if i.shape == o.shape]
1037
-
1038
- t_start = time.time()
1039
- for use_bias in [False, True]:
1040
- for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
1041
- if time.time() - t_start > time_budget: return None
1042
- result = _lstsq_conv(fit_exs, ks, use_bias, use_full_30=True)
1043
- if result is None: continue
1044
- Wconv, B = result
1045
- pad = ks // 2
1046
-
1047
- inits = [numpy_helper.from_array(Wconv, 'W')]
1048
- conv_inputs = ['input', 'W']
1049
- if B is not None:
1050
- inits.append(numpy_helper.from_array(B, 'B'))
1051
- conv_inputs.append('B')
1052
-
1053
- nodes = [
1054
- helper.make_node('ReduceSum', ['input'], ['mask'], axes=[1], keepdims=1),
1055
- helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
1056
- helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
1057
- ]
1058
- add_onehot_block(nodes, inits, 'am', 'oh_out')
1059
- nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
1060
-
1061
- model = mk(nodes, inits)
1062
- onnx.save(model, path)
1063
- if validate(path, td): return 'conv_var', model
1064
- return None
1065
-
1066
- def solve_conv_diffshape(td, path, time_budget=30.0):
1067
- """Diff-shape conv for fixed io shapes where output is smaller."""
1068
- sp = fixed_shapes(td)
1069
- if sp is None: return None
1070
- (IH, IW), (OH, OW) = sp
1071
- if IH == OH and IW == OW: return None
1072
- if OH > IH or OW > IW: return None
1073
- if OH > 30 or OW > 30: return None
1074
-
1075
- exs = get_exs(td)
1076
- t_start = time.time()
1077
-
1078
- for dr_off, dc_off in [(0, 0), ((IH-OH)//2, (IW-OW)//2)]:
1079
- for use_bias in [False, True]:
1080
- for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21]:
1081
- if time.time() - t_start > time_budget: return None
1082
- pad = ks // 2
1083
- feat = 10 * ks * ks + (1 if use_bias else 0)
1084
- if feat > 10000: continue
1085
-
1086
- patches, targets = [], []
1087
- valid = True
1088
- for inp_g, out_g in exs:
1089
- oh_enc = np.zeros((10, IH, IW), dtype=np.float64)
1090
- for c in range(10): oh_enc[c] = (inp_g == c)
1091
- oh_pad = np.pad(oh_enc, ((0,0),(pad,pad),(pad,pad)))
1092
- for r in range(OH):
1093
- for c in range(OW):
1094
- sr, sc = r + dr_off, c + dc_off
1095
- if sr < 0 or sr >= IH or sc < 0 or sc >= IW:
1096
- valid = False; break
1097
- p = oh_pad[:, sr:sr+ks, sc:sc+ks].flatten()
1098
- if use_bias: p = np.append(p, 1.0)
1099
- patches.append(p)
1100
- targets.append(int(out_g[r, c]))
1101
- if not valid: break
1102
- if not valid: break
1103
- if not valid: continue
1104
-
1105
- n_patches = len(patches)
1106
- if feat > 5000 and n_patches > 2000: continue
1107
-
1108
- P = np.array(patches, dtype=np.float64)
1109
- T = np.array(targets, dtype=np.int64)
1110
- T_oh = np.zeros((len(T), 10), dtype=np.float64)
1111
- for i, t in enumerate(T): T_oh[i, t] = 1.0
1112
-
1113
- WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
1114
- if not np.array_equal(np.argmax(P @ WT, axis=1), T): continue
1115
-
1116
- if use_bias:
1117
- Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
1118
- B = WT[-1].astype(np.float32)
1119
- else:
1120
- Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
1121
- B = None
1122
-
1123
- pad_h, pad_w = GH - OH, GW - OW
1124
- inits = [
1125
- numpy_helper.from_array(np.array([0,0,0,0], dtype=np.int64), 'sl_st'),
1126
- numpy_helper.from_array(np.array([1,10,IH,IW], dtype=np.int64), 'sl_en'),
1127
- numpy_helper.from_array(Wconv, 'W'),
1128
- numpy_helper.from_array(np.array([0,0,dr_off,dc_off], dtype=np.int64), 'cr_st'),
1129
- numpy_helper.from_array(np.array([1,10,dr_off+OH,dc_off+OW], dtype=np.int64), 'cr_en'),
1130
- ]
1131
- conv_inputs = ['grid', 'W']
1132
- if B is not None:
1133
- inits.append(numpy_helper.from_array(B, 'B'))
1134
- conv_inputs.append('B')
1135
-
1136
- nodes = [
1137
- helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
1138
- helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
1139
- helper.make_node('Slice', ['co','cr_st','cr_en'], ['co_crop']),
1140
- helper.make_node('ArgMax', ['co_crop'], ['am'], axis=1, keepdims=1),
1141
- ]
1142
- add_onehot_block(nodes, inits, 'am', 'oh_out')
1143
- nodes.append(
1144
- helper.make_node('Pad', ['oh_out'], ['output'],
1145
- pads=[0,0,0,0,0,0,pad_h,pad_w], value=0.0)
1146
- )
1147
-
1148
- model = mk(nodes, inits)
1149
- onnx.save(model, path)
1150
- if validate(path, td): return 'conv_diff', model
1151
- return None
1152
-
1153
- def solve_conv_var_diff(td, path, time_budget=30.0):
1154
- """Variable diff-shape conv: Conv(30x30) -> ArgMax -> Equal+Cast -> Mul(output_mask).
1155
- Works when output shape differs from input but mapping is convolutional on 30x30 grid."""
1156
- exs = get_exs(td)
1157
-
1158
- t_start = time.time()
1159
- for use_bias in [False, True]:
1160
- for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
1161
- if time.time() - t_start > time_budget: return None
1162
-
1163
- pad = ks // 2
1164
- feat = 10 * ks * ks + (1 if use_bias else 0)
1165
- if feat > 20000: continue
1166
-
1167
- patches, targets = [], []
1168
- for inp_g, out_g in exs:
1169
- ih, iw = inp_g.shape
1170
- oh, ow = out_g.shape
1171
- oh_full = np.zeros((10, GH, GW), dtype=np.float64)
1172
- for c in range(10): oh_full[c, :ih, :iw] = (inp_g == c)
1173
- oh_pad = np.pad(oh_full, ((0,0),(pad,pad),(pad,pad)))
1174
-
1175
- for r in range(oh):
1176
- for c in range(ow):
1177
- p = oh_pad[:, r:r+ks, c:c+ks].flatten()
1178
- if use_bias: p = np.append(p, 1.0)
1179
- patches.append(p)
1180
- targets.append(int(out_g[r, c]))
1181
-
1182
- n_patches = len(patches)
1183
- if feat > 5000 and n_patches > 2000: continue
1184
-
1185
- P = np.array(patches, dtype=np.float64)
1186
- T = np.array(targets, dtype=np.int64)
1187
- T_oh = np.zeros((len(T), 10), dtype=np.float64)
1188
- for i, t in enumerate(T): T_oh[i, t] = 1.0
1189
-
1190
- try:
1191
- WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
1192
- except:
1193
- continue
1194
- if not np.array_equal(np.argmax(P @ WT, axis=1), T): continue
1195
-
1196
- if use_bias:
1197
- Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
1198
- B = WT[-1].astype(np.float32)
1199
- else:
1200
- Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
1201
- B = None
1202
-
1203
- # Use ReduceSum of output channels as mask (sum across channels == 1 for valid pixels)
1204
- # But we don't know the output mask at inference time from input alone...
1205
- # We need a way to derive the output mask from the input.
1206
- # For same-shape: mask = ReduceSum(input, axis=1) works
1207
- # For diff-shape: we need to compute the output mask differently
1208
- #
1209
- # Approach: Conv output at valid positions should have max > threshold,
1210
- # and at padding positions max ≈ 0. Use the ArgMax+OneHot and then
1211
- # mask with ReduceSum(input) which is 1 at input positions but 0 at padding.
1212
- # BUT output may be LARGER than input...
1213
- #
1214
- # Alternative: just use Conv -> ArgMax -> Equal+Cast -> Mul(input_mask_expanded)
1215
- # where input_mask covers the output region too.
1216
- # This won't work if output extends beyond input region.
1217
- #
1218
- # Simplest correct approach: let the conv produce valid one-hot everywhere,
1219
- # then the padding region should naturally produce channel-0 output.
1220
- # Since padding is all-zero input, conv output there = bias only.
1221
- # If no bias, conv output = 0 for all channels -> argmax gives channel 0 -> onehot gives [1,0,...,0]
1222
- # which equals the padding encoding (channel 0 = 1 in padding).
1223
- # Wait - that's WRONG for the NeuroGolf format. In the padding region, ALL channels should be 0.
1224
- # The one-hot encoding has channel[color]=1, but padding = ALL zeros.
1225
- #
1226
- # So we NEED a mask. But for diff-shape, what mask?
1227
- # If output is always top-left aligned and we know max output size...
1228
- # We can't statically determine the output mask from the input.
1229
- #
1230
- # However: we can try the ReduceSum approach anyway — if conv naturally
1231
- # produces channel-0 dominant output in padding, then:
1232
- # mask = ReduceSum(input, axis=1) gives 1 for input pixels, 0 for padding
1233
- # If output region ⊆ input region, this works.
1234
- # If output region > input region... we need the output's ReduceSum instead.
1235
-
1236
- # For tasks where output fits within input bounds, use input mask
1237
- all_output_within_input = all(
1238
- out_g.shape[0] <= inp_g.shape[0] and out_g.shape[1] <= inp_g.shape[1]
1239
- for inp_g, out_g in exs
1240
- )
1241
-
1242
- if not all_output_within_input:
1243
- continue # Skip tasks where output extends beyond input
1244
-
1245
- inits = [numpy_helper.from_array(Wconv, 'W')]
1246
- conv_inputs = ['input', 'W']
1247
- if B is not None:
1248
- inits.append(numpy_helper.from_array(B, 'B'))
1249
- conv_inputs.append('B')
1250
-
1251
- nodes = [
1252
- helper.make_node('ReduceSum', ['input'], ['mask'], axes=[1], keepdims=1),
1253
- helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
1254
- helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
1255
- ]
1256
- add_onehot_block(nodes, inits, 'am', 'oh_out')
1257
- nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
1258
-
1259
- model = mk(nodes, inits)
1260
- onnx.save(model, path)
1261
- if validate(path, td): return 'conv_var_diff', model
1262
- return None
1263
-
1264
- # ============================================================
1265
- # PYTORCH LEARNED CONV (gradient descent, multi-seed, ternary snap)
1266
- # ============================================================
1267
-
1268
- def _ternary_snap(w, eps=0.2):
1269
- """Snap weights to {-1, 0, 1} — smaller model, often still correct."""
1270
- return np.where(w > eps, 1.0, np.where(w < -eps, -1.0, 0.0)).astype(np.float32)
1271
-
1272
- def _build_conv_onnx_from_weights(W, ks, use_full_30=False, IH=None, IW=None):
1273
- """Build ONNX conv model from numpy weight array W [10,10,ks,ks].
1274
- For fixed-shape: Slice→Conv→ArgMax→Equal+Cast→Pad
1275
- For variable/full30: Conv→ArgMax→Equal+Cast→Mul(mask)"""
1276
- pad = ks // 2
1277
- if use_full_30:
1278
- # Variable shape: full 30x30 conv with mask
1279
- inits = [numpy_helper.from_array(W, 'W')]
1280
- nodes = [
1281
- helper.make_node('ReduceSum', ['input'], ['mask'], axes=[1], keepdims=1),
1282
- helper.make_node('Conv', ['input', 'W'], ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
1283
- helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
1284
- ]
1285
- add_onehot_block(nodes, inits, 'am', 'oh_out')
1286
- nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
1287
- return mk(nodes, inits)
1288
- else:
1289
- # Fixed shape: slice, conv, pad
1290
- pad_h, pad_w = GH - IH, GW - IW
1291
- inits = [
1292
- numpy_helper.from_array(np.array([0,0,0,0], dtype=np.int64), 'sl_st'),
1293
- numpy_helper.from_array(np.array([1,10,IH,IW], dtype=np.int64), 'sl_en'),
1294
- numpy_helper.from_array(W, 'W'),
1295
- ]
1296
- nodes = [
1297
- helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
1298
- helper.make_node('Conv', ['grid', 'W'], ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
1299
- helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
1300
- ]
1301
- add_onehot_block(nodes, inits, 'am', 'oh_out')
1302
- nodes.append(
1303
- helper.make_node('Pad', ['oh_out'], ['output'],
1304
- pads=[0,0,0,0,0,0,pad_h,pad_w], value=0.0)
1305
- )
1306
- return mk(nodes, inits)
1307
-
1308
- def _build_two_layer_conv_onnx(W1, W2, ks1, ks2, use_full_30=False, IH=None, IW=None):
1309
- """Build ONNX two-layer conv: Conv→ReLU→Conv→ArgMax→Equal+Cast→Pad/Mul(mask)."""
1310
- pad1, pad2 = ks1 // 2, ks2 // 2
1311
- if use_full_30:
1312
- inits = [
1313
- numpy_helper.from_array(W1, 'W1'),
1314
- numpy_helper.from_array(W2, 'W2'),
1315
- ]
1316
- nodes = [
1317
- helper.make_node('ReduceSum', ['input'], ['mask'], axes=[1], keepdims=1),
1318
- helper.make_node('Conv', ['input', 'W1'], ['h1'], kernel_shape=[ks1,ks1], pads=[pad1]*4),
1319
- helper.make_node('Relu', ['h1'], ['h1r']),
1320
- helper.make_node('Conv', ['h1r', 'W2'], ['co'], kernel_shape=[ks2,ks2], pads=[pad2]*4),
1321
- helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
1322
- ]
1323
- add_onehot_block(nodes, inits, 'am', 'oh_out')
1324
- nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
1325
- return mk(nodes, inits)
1326
- else:
1327
- pad_h, pad_w = GH - IH, GW - IW
1328
- inits = [
1329
- numpy_helper.from_array(np.array([0,0,0,0], dtype=np.int64), 'sl_st'),
1330
- numpy_helper.from_array(np.array([1,10,IH,IW], dtype=np.int64), 'sl_en'),
1331
- numpy_helper.from_array(W1, 'W1'),
1332
- numpy_helper.from_array(W2, 'W2'),
1333
- ]
1334
- nodes = [
1335
- helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
1336
- helper.make_node('Conv', ['grid', 'W1'], ['h1'], kernel_shape=[ks1,ks1], pads=[pad1]*4),
1337
- helper.make_node('Relu', ['h1'], ['h1r']),
1338
- helper.make_node('Conv', ['h1r', 'W2'], ['co'], kernel_shape=[ks2,ks2], pads=[pad2]*4),
1339
- helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
1340
- ]
1341
- add_onehot_block(nodes, inits, 'am', 'oh_out')
1342
- nodes.append(
1343
- helper.make_node('Pad', ['oh_out'], ['output'],
1344
- pads=[0,0,0,0,0,0,pad_h,pad_w], value=0.0)
1345
- )
1346
- return mk(nodes, inits)
1347
-
1348
- def solve_pytorch_conv(td, path, time_budget=30.0):
1349
- """PyTorch gradient descent conv solver. Tries single-layer then two-layer.
1350
- Multi-seed training with ternary weight snapping for smaller models.
1351
- Validates against arc-gen before accepting."""
1352
- try:
1353
- import torch
1354
- import torch.nn as nn
1355
- import copy as _copy
1356
- except ImportError:
1357
- return None
1358
-
1359
- exs = get_exs(td)
1360
- same_shape = all(inp.shape == out.shape for inp, out in exs)
1361
- if not same_shape:
1362
- return None # Only handle same-shape for now
1363
-
1364
- shapes = set(inp.shape for inp, _ in exs)
1365
- fixed_in = len(shapes) == 1
1366
-
1367
- # Prepare tensors
1368
- all_pairs = td['train'] + td['test']
1369
- inp_list = [to_onehot(p['input'])[0] for p in all_pairs]
1370
- out_list = [to_onehot(p['output'])[0] for p in all_pairs]
1371
- inp_t = torch.tensor(np.stack(inp_list), dtype=torch.float32)
1372
- out_t = torch.tensor(np.stack(out_list), dtype=torch.float32)
1373
-
1374
- if fixed_in:
1375
- IH, IW = list(shapes)[0]
1376
- # Train on cropped region
1377
- inp_t = inp_t[:, :, :IH, :IW]
1378
- out_t = out_t[:, :, :IH, :IW]
1379
-
1380
- t_start = time.time()
1381
- best_result = None
1382
-
1383
- # Phase 1: Single-layer conv (multiple kernel sizes and seeds)
1384
- for ks in [1, 3, 5, 7]:
1385
- if time.time() - t_start > time_budget * 0.6:
1386
- break
1387
- pad = ks // 2
1388
- for seed in [0, 7, 42]:
1389
- if time.time() - t_start > time_budget * 0.6:
1390
- break
1391
- torch.manual_seed(seed)
1392
- conv = nn.Conv2d(CH, CH, kernel_size=ks, padding=pad, bias=False)
1393
- if seed == 0:
1394
- nn.init.zeros_(conv.weight)
1395
- opt = torch.optim.Adam(conv.parameters(), lr=0.03)
1396
- best_loss, best_state = float('inf'), None
1397
- for step in range(3000):
1398
- opt.zero_grad()
1399
- pred = conv(inp_t)
1400
- loss = nn.functional.mse_loss(pred, out_t)
1401
- loss.backward()
1402
- opt.step()
1403
- if loss.item() < best_loss:
1404
- best_loss = loss.item()
1405
- best_state = _copy.deepcopy(conv.state_dict())
1406
- if best_loss < 1e-8:
1407
- break
1408
- if best_state is None:
1409
- continue
1410
- conv.load_state_dict(best_state)
1411
- w = conv.weight.detach().numpy()
1412
-
1413
- # Try continuous weights, then ternary-snapped
1414
- for w_cand in [w, _ternary_snap(w)]:
1415
- use_full = not fixed_in
1416
- model = _build_conv_onnx_from_weights(
1417
- w_cand, ks, use_full_30=use_full,
1418
- IH=IH if fixed_in else None,
1419
- IW=IW if fixed_in else None
1420
- )
1421
- onnx.save(model, path)
1422
- if validate(path, td):
1423
- sz = os.path.getsize(path)
1424
- if best_result is None or sz < best_result[2]:
1425
- best_result = ('pt_conv', model, sz)
1426
-
1427
- # Phase 2: Two-layer conv (Conv→ReLU→Conv)
1428
- for ks1, ks2, hidden in [(3, 1, CH), (5, 1, CH), (3, 3, CH)]:
1429
- if time.time() - t_start > time_budget:
1430
- break
1431
- for seed in [0, 7]:
1432
- if time.time() - t_start > time_budget:
1433
- break
1434
- torch.manual_seed(seed)
1435
- net = nn.Sequential(
1436
- nn.Conv2d(CH, hidden, kernel_size=ks1, padding=ks1//2, bias=False),
1437
- nn.ReLU(),
1438
- nn.Conv2d(hidden, CH, kernel_size=ks2, padding=ks2//2, bias=False),
1439
- )
1440
- opt = torch.optim.Adam(net.parameters(), lr=0.01)
1441
- best_loss, best_state = float('inf'), None
1442
- for step in range(2500):
1443
- opt.zero_grad()
1444
- pred = net(inp_t)
1445
- loss = nn.functional.mse_loss(pred, out_t)
1446
- loss.backward()
1447
- opt.step()
1448
- if loss.item() < best_loss:
1449
- best_loss = loss.item()
1450
- best_state = _copy.deepcopy(net.state_dict())
1451
- if best_loss < 1e-8:
1452
- break
1453
- if best_state is None:
1454
- continue
1455
- net.load_state_dict(best_state)
1456
- w1 = net[0].weight.detach().numpy()
1457
- w2 = net[2].weight.detach().numpy()
1458
-
1459
- for w1c, w2c in [(w1, w2), (_ternary_snap(w1), _ternary_snap(w2))]:
1460
- use_full = not fixed_in
1461
- model = _build_two_layer_conv_onnx(
1462
- w1c, w2c, ks1, ks2, use_full_30=use_full,
1463
- IH=IH if fixed_in else None,
1464
- IW=IW if fixed_in else None
1465
- )
1466
- onnx.save(model, path)
1467
- if validate(path, td):
1468
- sz = os.path.getsize(path)
1469
- if best_result is None or sz < best_result[2]:
1470
- best_result = ('pt_conv2', model, sz)
1471
-
1472
- if best_result is not None:
1473
- sname, model, _ = best_result
1474
- onnx.save(model, path)
1475
- return sname, model
1476
- return None
1477
-
1478
- # ============================================================
1479
- # MAIN
1480
- # ============================================================
1481
-
1482
- ANALYTICAL_SOLVERS = [
1483
- ('identity', s_identity), ('constant', s_constant), ('color_map', s_color_map),
1484
- ('transpose', s_transpose), ('flip', s_flip), ('rotate', s_rotate),
1485
- ('tile', s_tile), ('upscale', s_upscale), ('kronecker', s_kronecker),
1486
- ('nonuniform_scale', s_nonuniform_scale),
1487
- ('mirror_h', s_mirror_h), ('mirror_v', s_mirror_v), ('quad_mirror', s_quad_mirror),
1488
- ('concat', s_concat), ('concat_enhanced', s_concat_enhanced),
1489
- ('diagonal_tile', s_diagonal_tile),
1490
- ('fixed_crop', s_fixed_crop),
1491
- ('spatial_gather', s_spatial_gather),
1492
- ('shift', s_shift),
1493
- ('varshape_spatial_gather', s_varshape_spatial_gather),
1494
- ]
1495
-
1496
- def solve_task(tn, td, outdir, conv_budget=30.0):
1497
- t_start = time.time()
1498
- os.makedirs(outdir, exist_ok=True)
1499
- path = os.path.join(outdir, f"task{tn:03d}.onnx")
1500
-
1501
- # Skip excluded tasks
1502
- if tn in EXCLUDED_TASKS:
1503
- return False, 'excluded', None, time.time() - t_start, path
1504
-
1505
- # 1. Try analytical solvers (fast, tiny models)
1506
- for sname, sfn in ANALYTICAL_SOLVERS:
1507
- try:
1508
- model = sfn(td)
1509
- if model is None: continue
1510
- onnx.save(model, path)
1511
- if validate(path, td):
1512
- return True, sname, os.path.getsize(path), time.time() - t_start, path
1513
- except: pass
1514
-
1515
- # 2. Determine task shape category and try conv solvers
1516
- exs = get_exs(td)
1517
- same_shape = all(inp.shape == out.shape for inp, out in exs)
1518
- shapes = set(inp.shape for inp, _ in exs)
1519
- fixed_in = len(shapes) == 1
1520
-
1521
- conv_time = conv_budget
1522
-
1523
- if same_shape:
1524
- if fixed_in:
1525
- result = solve_conv_fixed(td, path, time_budget=conv_time/2)
1526
- if result is not None:
1527
- sname, model = result
1528
- return True, sname, os.path.getsize(path), time.time() - t_start, path
1529
- result = solve_conv_variable(td, path, time_budget=conv_time)
1530
- if result is not None:
1531
- sname, model = result
1532
- return True, sname, os.path.getsize(path), time.time() - t_start, path
1533
- # 3. PyTorch learned conv as fallback for same-shape tasks
1534
- remaining = max(1, conv_time - (time.time() - t_start))
1535
- result = solve_pytorch_conv(td, path, time_budget=remaining)
1536
- if result is not None:
1537
- sname, model = result
1538
- return True, sname, os.path.getsize(path), time.time() - t_start, path
1539
- else:
1540
- sp = fixed_shapes(td)
1541
- if sp is not None:
1542
- (IH,IW),(OH,OW) = sp
1543
- if OH <= IH and OW <= IW:
1544
- result = solve_conv_diffshape(td, path, time_budget=conv_time)
1545
- if result is not None:
1546
- sname, model = result
1547
- return True, sname, os.path.getsize(path), time.time() - t_start, path
1548
-
1549
- # Try variable diff-shape conv (output within input bounds)
1550
- result = solve_conv_var_diff(td, path, time_budget=conv_time)
1551
- if result is not None:
1552
- sname, model = result
1553
- return True, sname, os.path.getsize(path), time.time() - t_start, path
1554
-
1555
- return False, None, None, time.time() - t_start, path
1556
-
1557
- def run_tasks(task_nums, tasks, output_dir, conv_budget, use_wandb):
1558
- results = {}
1559
- costs_dict = {}
1560
- total_score = 0
1561
- for tn in task_nums:
1562
- if tn not in tasks:
1563
- continue
1564
- if tn in EXCLUDED_TASKS:
1565
- print(f"Task {tn:3d}: EXCLUDED (officially)")
1566
- continue
1567
-
1568
- td = tasks[tn]['data']
1569
- ok, sname, sz, t_task, model_path = solve_task(tn, td, output_dir, conv_budget)
1570
-
1571
- if ok:
1572
- macs, memory, params = score_network(model_path)
1573
- if macs is None:
1574
- macs, memory, params = 0, 0, 0
1575
- cost = macs + memory + params
1576
- score = max(1.0, 25.0 - math.log(max(1, cost)))
1577
- total_score += score
1578
-
1579
- results[tn] = (sname, t_task, sz)
1580
- costs_dict[tn] = cost
1581
- print(f"Task {tn:3d}: {sname:25s} {score:7.3f} {cost:>12} {t_task:7.3f}s ({sz:>8,} bytes)")
1582
- else:
1583
- print(f"Task {tn:3d}: UNSOLVED {t_task:7.3f}s")
1584
- cost = 0
1585
-
1586
- if use_wandb and wandb is not None:
1587
- wandb.log({
1588
- "task_id": tn,
1589
- "solver": sname if ok else "unsolved",
1590
- "onnx_bytes": sz if ok else 0,
1591
- "task_time_sec": t_task,
1592
- "cost": cost,
1593
- "score": score if ok else 0,
1594
- })
1595
-
1596
- return results, costs_dict, total_score
1597
-
1598
-
1599
- def main():
1600
- parser = argparse.ArgumentParser()
1601
- parser.add_argument('--data_dir', default='ARC-AGI/data/training/')
1602
- parser.add_argument('--arcgen_dir', default='', help='Path to ARC-GEN-100K/ directory')
1603
- parser.add_argument('--output_dir', default='submission')
1604
- parser.add_argument('--kaggle', action='store_true')
1605
- parser.add_argument('--conv_budget', type=float, default=30.0)
1606
- parser.add_argument('--tasks', type=str, default='')
1607
- parser.add_argument('--device', type=str, default='auto', choices=['auto','cpu','cuda'])
1608
- parser.add_argument('--use_wandb', action='store_true')
1609
- args = parser.parse_args()
1610
- global ORT_PROVIDERS
1611
- config = {
1612
- "device": args.device,
1613
- "conv_budget": args.conv_budget,
1614
- "data_dir": args.data_dir,
1615
- "arcgen_dir": args.arcgen_dir,
1616
- "tasks": args.tasks,
1617
- }
1618
-
1619
- if args.device == 'cuda':
1620
- ORT_PROVIDERS = ['CUDAExecutionProvider', 'CPUExecutionProvider']
1621
- elif args.device == 'cpu':
1622
- ORT_PROVIDERS = ['CPUExecutionProvider']
1623
-
1624
- ort.set_default_logger_severity(3)
1625
- print(f"Using providers: {ORT_PROVIDERS}")
1626
-
1627
- if args.kaggle:
1628
- tasks = load_tasks_kaggle(args.data_dir)
1629
- else:
1630
- arcgen = args.arcgen_dir if args.arcgen_dir else None
1631
- tasks = load_tasks_dir(args.data_dir, arcgen_dir=arcgen)
1632
-
1633
- # Count arc-gen examples
1634
- total_arcgen = sum(len(t['data'].get('arc-gen', [])) for t in tasks.values())
1635
- print(f"Loaded {len(tasks)} tasks ({total_arcgen} ARC-GEN examples)")
1636
- print(f"Excluded tasks: {sorted(EXCLUDED_TASKS)}")
1637
-
1638
- task_nums = [int(t) for t in args.tasks.split(',')] if args.tasks else sorted(tasks.keys())
1639
- active_tasks = [t for t in task_nums if t not in EXCLUDED_TASKS]
1640
- print(f"Solving {len(active_tasks)} active tasks (skipping {len(task_nums) - len(active_tasks)} excluded)")
1641
- print(f"Conv budget: {args.conv_budget}s per task")
1642
- print("=" * 70)
1643
- t0 = time.time()
1644
-
1645
- if args.use_wandb and wandb is not None:
1646
- with wandb.init(project="neurogolf", name="solver_run", config=config):
1647
- results, costs_dict, total_score = run_tasks(task_nums, tasks, args.output_dir, args.conv_budget, use_wandb=True)
1648
- else:
1649
- results, costs_dict, total_score = run_tasks(task_nums, tasks, args.output_dir, args.conv_budget, use_wandb=False)
1650
-
1651
- elapsed = time.time() - t0
1652
- print(f"\n{'='*70}")
1653
- print(f"Solved: {len(results)}/{len(active_tasks)} active tasks in {elapsed:.0f}s")
1654
- solver_names = [v[0] for v in results.values()]
1655
- sc = Counter(solver_names)
1656
- for s, c in sc.most_common(): print(f" {s}: {c}")
1657
-
1658
- # Generate submission
1659
- outdir = args.output_dir
1660
- n_files = len([f for f in os.listdir(outdir) if f.endswith('.onnx')])
1661
- total_size = sum(os.path.getsize(os.path.join(outdir, f))
1662
- for f in os.listdir(outdir) if f.endswith('.onnx'))
1663
-
1664
- # Create submission.zip
1665
- zip_path = os.path.join(os.path.dirname(outdir) or '.', 'submission.zip')
1666
- buf = io.BytesIO()
1667
- with zipfile.ZipFile(buf, 'w', zipfile.ZIP_DEFLATED) as zf:
1668
- for f in sorted(os.listdir(outdir)):
1669
- if f.endswith('.onnx'):
1670
- zf.write(os.path.join(outdir, f), f)
1671
- zip_bytes = buf.getvalue()
1672
- with open(zip_path, 'wb') as f:
1673
- f.write(zip_bytes)
1674
- zip_size = len(zip_bytes)
1675
-
1676
- # Create submission.csv
1677
- csv_path = os.path.join(os.path.dirname(outdir) or '.', 'submission.csv')
1678
- with open(csv_path, 'w', newline='') as f:
1679
- w = csv.writer(f)
1680
- w.writerow(['task_id', 'total_cost'])
1681
- for tn in sorted(costs_dict.keys()):
1682
- w.writerow([f'task{tn:03d}', costs_dict[tn]])
1683
-
1684
- # Estimate LB score: solved tasks get their score, unsolved get 1.0
1685
- unsolved_count = len(active_tasks) - len(results)
1686
- est_lb = total_score + unsolved_count * 1.0
1687
-
1688
- print(f"\n{n_files} ONNX files, {total_size/1024:.1f} KB uncompressed")
1689
- print(f"ZIP size: {zip_size/1024:.1f} KB / {MAX_FILESIZE/1024:.0f} KB limit {'OK' if zip_size <= MAX_FILESIZE else 'OVER!'}")
1690
- print(f"Estimated LB score: {est_lb:.1f} (solved: {total_score:.1f} + unsolved: {unsolved_count}×1.0)")
1691
- print(f"Written: {zip_path} | {csv_path}")
1692
-
1693
- if __name__ == '__main__':
1694
- main()
 
1
+ FILE_CONTENT_PLACEHOLDER