rogermt committed on
Commit
0422a6a
·
verified ·
1 Parent(s): 59533d2

v4: ARC-GEN validation + fitting, new analytical solvers (shift/mirror/crop/quad), s_flip opset fix, static profiler, submission.csv

Browse files
Files changed (1) hide show
  1. neurogolf_solver.py +1459 -1
neurogolf_solver.py CHANGED
@@ -1 +1,1459 @@
1
- placeholder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ ARC-AGI NeuroGolf Championship - Complete Solver v4
4
+ Format: [1,10,30,30] one-hot input/output, opset 10, IR version 10.
5
+
6
+ v4 CRITICAL FIXES:
7
+ - ARC-GEN data loaded and used for conv fitting (more data = better lstsq)
8
+ - ARC-GEN validation: models validated against train+test+arc-gen
9
+ - EXCLUDED tasks: {21, 55, 80, 184, 202, 366} skipped
10
+ - submission.csv generation for Kaggle
11
+ - s_flip fixed: GatherElements -> Gather (opset 10 compat)
12
+ - Static profiler: no onnx_tool dependency for cost estimation
13
+ - get_exs_for_fitting(): uses train+test+arc-gen for conv fitting
14
+
15
+ Solvers:
16
+ - Analytical: identity, constant, color_map, transpose, flip, rotate, tile, upscale,
17
+ concat, concat_enhanced, spatial_gather, varshape_spatial_gather,
18
+ diagonal_tile, kronecker
19
+ - Conv (fixed shape): Slice -> Conv -> ArgMax -> Equal+Cast -> Pad
20
+ - Conv (variable shape): Conv(30x30) -> ArgMax -> Equal+Cast -> Mul(mask)
21
+ - Conv (diff shape): Slice -> Conv -> Slice(crop) -> ArgMax -> Equal+Cast -> Pad
22
+
23
+ Usage:
24
+ python neurogolf_solver.py --data_dir ARC-AGI/data/training/ --output_dir submission
25
+ python neurogolf_solver.py --data_dir ARC-AGI/data/training/ --output_dir submission --conv_budget 60 --arcgen_dir ARC-GEN-100K/
26
+ """
27
+
28
+ import json, os, sys, math, time, argparse, csv, io, zipfile
29
+ import numpy as np
30
+ import onnx
31
+ from onnx import helper, TensorProto, numpy_helper
32
+ import onnxruntime as ort
33
+ from collections import Counter
34
+
35
# Optional dependency: the official scorer from neurogolf_utils.
# When importable we prefer it over the local static profiler (see score_network).
try:
    from neurogolf_utils import score_network as _score_network_official
    HAS_ONNX_TOOL = True
except ImportError:
    HAS_ONNX_TOOL = False

# Optional dependency: Weights & Biases logging; the solver runs fine without it
# (callers must check `wandb is not None` before use).
try:
    import wandb
except ImportError:
    wandb = None
45
+
46
# Canonical tensor geometry: batch=1, 10 one-hot colour channels, 30x30 grid.
BATCH, CH, GH, GW = 1, 10, 30, 30
GRID_SHAPE = [BATCH, CH, GH, GW]
DT = TensorProto.FLOAT          # all tensors are float32
IR = 10                         # ONNX IR version required by the competition
OPSET = [helper.make_opsetid("", 10)]  # opset 10 (no GatherElements etc.)

# Officially excluded tasks (score 0 regardless)
EXCLUDED_TASKS = {21, 55, 80, 184, 202, 366}

# Max ARC-GEN examples to use for validation (to keep runtime reasonable)
MAX_ARCGEN_VALIDATE = 30
# Max ARC-GEN examples for conv fitting (keep separate from validation!)
# NOTE: Conv fitting uses train+test only. ARC-GEN is for VALIDATION only.
# lstsq underdetermines with too many variable-size arc-gen examples.
MAX_ARCGEN_FIT = 0  # Don't use arc-gen for fitting — use for validation only
61
+
62
def get_providers():
    """Return the onnxruntime execution providers to use (CPU only)."""
    return ['CPUExecutionProvider']

ORT_PROVIDERS = get_providers()
66
+
67
+ # ============================================================
68
+ # LOAD / VALIDATE
69
+ # ============================================================
70
+
71
def load_tasks_dir(data_dir, arcgen_dir=None):
    """Load ARC-AGI task JSON files from *data_dir*, keyed 1..N by sorted filename.

    If *arcgen_dir* contains a same-named JSON file holding a list, that list is
    attached under the task's 'arc-gen' key; otherwise 'arc-gen' defaults to [].
    Returns {task_number: {'hex': <filename stem>, 'data': <task dict>}}.
    """
    tasks = {}
    names = sorted(n for n in os.listdir(data_dir) if n.endswith('.json'))
    for num, name in enumerate(names, start=1):
        with open(os.path.join(data_dir, name)) as src:
            payload = json.load(src)
        # Merge ARC-GEN examples when a matching file exists.
        if arcgen_dir:
            ag_path = os.path.join(arcgen_dir, name)
            if os.path.exists(ag_path):
                with open(ag_path) as src:
                    extra = json.load(src)
                if isinstance(extra, list):
                    payload['arc-gen'] = extra
        payload.setdefault('arc-gen', [])
        tasks[num] = {'hex': name.replace('.json', ''), 'data': payload}
    return tasks
89
+
90
def load_tasks_kaggle(data_dir):
    """Load Kaggle-format tasks task001.json .. task400.json from *data_dir*.

    Kaggle files already embed arc-gen examples; a missing key defaults to [].
    Missing task files are simply skipped.
    """
    tasks = {}
    for tn in range(1, 401):
        stem = f"task{tn:03d}"
        path = os.path.join(data_dir, stem + ".json")
        if not os.path.exists(path):
            continue
        with open(path) as fh:
            data = json.load(fh)
        data.setdefault('arc-gen', [])
        tasks[tn] = {'hex': stem, 'data': data}
    return tasks
102
+
103
def to_onehot(grid):
    """Encode a colour grid (list of rows of ints) as a [1,10,30,30] one-hot float32 tensor.

    Cells outside the 30x30 canvas or with colours outside 0..9 are silently dropped.
    """
    onehot = np.zeros((1, CH, GH, GW), dtype=np.float32)
    for r, row in enumerate(grid):
        if r >= GH:
            continue
        for c, colour in enumerate(row):
            if c < GW and 0 <= colour < CH:
                onehot[0, colour, r, c] = 1.0
    return onehot
110
+
111
def validate(path, td):
    """Validate an ONNX model file against ALL examples: train + test + arc-gen.

    This matches what Kaggle does for scoring: the binarised model output must
    reproduce every example exactly.  Returns True only if the session loads
    and every grid matches.

    Fix vs. previous version: bare ``except:`` narrowed to ``except Exception``
    so KeyboardInterrupt/SystemExit are no longer swallowed.
    """
    try:
        opts = ort.SessionOptions()
        opts.log_severity_level = 3  # errors only; suppress ORT warnings
        sess = ort.InferenceSession(path, sess_options=opts, providers=ORT_PROVIDERS)
    except Exception:
        # Model failed to load or is malformed -> cannot score.
        return False
    examples = td['train'] + td['test']
    # Include arc-gen examples (capped for speed)
    if 'arc-gen' in td:
        examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]
    for ex in examples:
        inp = to_onehot(ex['input'])
        exp = to_onehot(ex['output'])
        try:
            out = sess.run(['output'], {'input': inp})[0]
            # Binarise: any positive activation counts as a set cell.
            out = (out > 0.0).astype(np.float32)
        except Exception:
            return False
        if not np.array_equal(out, exp):
            return False
    return True
135
+
136
def validate_raw(raw_bytes, td):
    """Validate a serialized ONNX model (bytes) against ALL examples.

    Same contract as validate(), but avoids writing the model to disk.

    Fix vs. previous version: bare ``except:`` narrowed to ``except Exception``
    so KeyboardInterrupt/SystemExit are no longer swallowed.
    """
    try:
        opts = ort.SessionOptions()
        opts.log_severity_level = 3  # errors only
        sess = ort.InferenceSession(raw_bytes, sess_options=opts, providers=ORT_PROVIDERS)
    except Exception:
        return False
    examples = td['train'] + td['test']
    if 'arc-gen' in td:
        examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]
    for ex in examples:
        inp = to_onehot(ex['input'])
        exp = to_onehot(ex['output'])
        try:
            out = sess.run(['output'], {'input': inp})[0]
            out = (out > 0.0).astype(np.float32)
        except Exception:
            return False
        if not np.array_equal(out, exp):
            return False
    return True
158
+
159
+ # ============================================================
160
+ # STATIC PROFILER (no onnx_tool dependency)
161
+ # ============================================================
162
+
163
# Ops disallowed by the competition profiler; a model containing any scores 0.
BANNED_OPS = {'Loop', 'Scan', 'NonZero', 'Unique', 'If', 'Function'}
# Per-model file-size cap (a 1.44 MB "floppy").
MAX_FILESIZE = int(1.44 * 1024 * 1024)
165
+
166
def score_network(path):
    """Estimate the scoring cost of the ONNX model at *path*.

    Returns (macs, nbytes, params) matching the Kaggle cost model
    (cost = macs + memory + params).  Prefers the official scorer from
    neurogolf_utils when importable; otherwise uses the local static profiler.

    Fix vs. previous version: bare ``except:`` narrowed to ``except Exception``.
    """
    if HAS_ONNX_TOOL:
        try:
            return _score_network_official(path)
        except Exception:
            # Official scorer choked on this model -> fall through to fallback.
            pass
    return _static_profile(path)
175
+
176
def _static_profile(path):
    """Compute (macs, nbytes, params) for an ONNX model without onnx_tool.

    Returns (None, None, None) when the model cannot be loaded or contains a
    banned op.  MACs are only estimated for Conv nodes, assuming a full 30x30
    output plane (the competition's fixed grid).

    Fix vs. previous version: bare ``except:`` clauses narrowed to
    ``except Exception``.
    """
    try:
        model = onnx.load(path)
    except Exception:
        return None, None, None

    tensors = {}   # name -> numpy weight, for Conv MAC estimation
    params = 0
    nbytes = 0
    macs = 0

    # Initializers contribute to both parameter count and byte size.
    for init in model.graph.initializer:
        a = numpy_helper.to_array(init)
        tensors[init.name] = a
        params += a.size
        nbytes += a.nbytes

    for nd in model.graph.node:
        # Constant nodes embed weights too; count them like initializers.
        if nd.op_type == 'Constant':
            for attr in nd.attribute:
                if attr.t and attr.t.ByteSize() > 0:
                    try:
                        a = numpy_helper.to_array(attr.t)
                        if nd.output:
                            tensors[nd.output[0]] = a
                        params += a.size
                        nbytes += a.nbytes
                    except Exception:
                        pass

        if nd.op_type in BANNED_OPS:
            return None, None, None

        # Conv MACs: co*ci*kh*kw multiplies per output pixel over a 30x30 plane.
        if nd.op_type == 'Conv' and len(nd.input) >= 2 and nd.input[1] in tensors:
            w = tensors[nd.input[1]]
            if w.ndim == 4:
                co, ci, kh, kw = w.shape
                macs += co * ci * kh * kw * GH * GW

    return int(macs), int(nbytes), int(params)
217
+
218
def mk(nodes, inits=None):
    """Wrap *nodes* (and optional initializers) into an ONNX model with the
    fixed [1,10,30,30] 'input'/'output' interface, IR 10 and opset 10."""
    inp = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
    outp = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
    graph = helper.make_graph(nodes, "g", [inp], [outp], initializer=inits or [])
    return helper.make_model(graph, ir_version=IR, opset_imports=OPSET)
223
+
224
def get_exs(td):
    """Return (input, output) int64 ndarray pairs for all train+test examples."""
    pairs = []
    for ex in td['train'] + td['test']:
        pairs.append((np.array(ex['input'], dtype=np.int64),
                      np.array(ex['output'], dtype=np.int64)))
    return pairs
228
+
229
def get_exs_for_fitting(td):
    """Examples for conv fitting: train+test plus size-matched arc-gen pairs.

    Arc-gen pairs are only appended when every train/test input shares a single
    shape, and the arc-gen pair matches both that input shape and the first
    example's output shape — so lstsq gets more rows without changing the
    feature dimension.  At most 10 are added (diminishing returns beyond that).
    """
    base = [(np.array(ex['input'], dtype=np.int64),
             np.array(ex['output'], dtype=np.int64))
            for ex in td['train'] + td['test']]
    if not base:
        return base

    in_shapes = {pair[0].shape for pair in base}
    if len(in_shapes) != 1:
        return base  # variable-size task: arc-gen would change feature dims

    in_shape = in_shapes.pop()
    out_shape = base[0][1].shape

    extra = []
    for ex in td.get('arc-gen', []):
        gi = np.array(ex['input'], dtype=np.int64)
        go = np.array(ex['output'], dtype=np.int64)
        if gi.shape == in_shape and go.shape == out_shape:
            extra.append((gi, go))
        if len(extra) == 10:  # cap to avoid massive lstsq
            break
    return base + extra
257
+
258
def get_exs_for_fitting_variable(td):
    """Examples for variable-shape conv fitting (later embedded in 30x30).

    Besides train+test, appends up to 20 arc-gen pairs whose input and output
    shapes match each other and fit inside the 30x30 canvas.
    """
    pairs = [(np.array(ex['input'], dtype=np.int64),
              np.array(ex['output'], dtype=np.int64))
             for ex in td['train'] + td['test']]

    extra = []
    for ex in td.get('arc-gen', []):
        gi = np.array(ex['input'], dtype=np.int64)
        go = np.array(ex['output'], dtype=np.int64)
        if gi.shape == go.shape and gi.shape[0] <= 30 and gi.shape[1] <= 30:
            extra.append((gi, go))
        if len(extra) == 20:
            break
    return pairs + extra
274
+
275
def fixed_shapes(td):
    """Return ((IH, IW), (OH, OW)) when every example shares one input/output
    shape pair, else None."""
    seen = {(inp.shape, out.shape) for inp, out in get_exs(td)}
    return next(iter(seen)) if len(seen) == 1 else None
280
+
281
+ # ============================================================
282
+ # GATHER HELPERS
283
+ # ============================================================
284
+
285
def _build_gather_model(OH, OW, idx):
    """Build a Reshape->Gather->Reshape->Mul model realising a static pixel map.

    idx[oi, oj] = (sr, sc) names the input cell copied to output cell (oi, oj);
    cells outside the OHxOW region are zeroed by the mask.  Uses Gather
    (available since opset 1) rather than GatherElements (opset 11) so the
    model stays opset-10 compatible.
    """
    flat_idx = np.zeros((GH * GW,), dtype=np.int64)
    mask = np.zeros((1, 1, GH, GW), dtype=np.float32)
    for oi in range(OH):
        for oj in range(OW):
            sr, sc = idx[oi, oj, 0], idx[oi, oj, 1]
            flat_idx[oi * GW + oj] = sr * GW + sc   # flatten (sr, sc) -> 0..899
            mask[0, 0, oi, oj] = 1.0
    inits = [
        numpy_helper.from_array(np.array([1, 10, GH * GW], dtype=np.int64), 'fs'),
        numpy_helper.from_array(flat_idx, 'idx'),
        numpy_helper.from_array(np.array([1, 10, GH, GW], dtype=np.int64), 'os'),
        numpy_helper.from_array(mask, 'mask'),
    ]
    nodes = [
        helper.make_node('Reshape', ['input', 'fs'], ['flat']),
        helper.make_node('Gather', ['flat', 'idx'], ['g'], axis=2),
        helper.make_node('Reshape', ['g', 'os'], ['raw']),
        helper.make_node('Mul', ['raw', 'mask'], ['output']),
    ]
    return mk(nodes, inits)
307
+
308
def _build_gather_model_with_const(IH, IW, OH, OW, idx, cst):
    """Like _build_gather_model, but cells with idx[oi,oj,0] < 0 are filled
    from the constant colour map cst[oi,oj] (>= 0) instead of gathered.

    IH/IW are accepted for interface symmetry with the callers; the map itself
    is fully described by idx/cst over the OHxOW output region.
    """
    # Use Gather (opset 1) instead of GatherElements (opset 11)
    flat_idx = np.zeros((GH*GW,), dtype=np.int64)
    gather_mask = np.zeros((1,1,GH,GW), dtype=np.float32)
    const_oh = np.zeros((1,10,GH,GW), dtype=np.float32)  # one-hot constants
    for oi in range(OH):
        for oj in range(OW):
            if idx[oi,oj,0] >= 0:
                # Gathered cell: flatten (sr, sc) and unmask it.
                flat_idx[oi*GW+oj] = idx[oi,oj,0]*GW + idx[oi,oj,1]
                gather_mask[0,0,oi,oj] = 1.0
            elif cst[oi,oj] >= 0:
                # Constant cell: bake the colour directly into the added tensor.
                const_oh[0, cst[oi,oj], oi, oj] = 1.0
    has_const = np.any(const_oh > 0)
    inits = [
        numpy_helper.from_array(np.array([1,10,GH*GW], dtype=np.int64), 'fs'),
        numpy_helper.from_array(flat_idx, 'idx'),
        numpy_helper.from_array(np.array([1,10,GH,GW], dtype=np.int64), 'os'),
        numpy_helper.from_array(gather_mask, 'gmask'),
    ]
    nodes = [
        helper.make_node('Reshape', ['input','fs'], ['flat']),
        helper.make_node('Gather', ['flat','idx'], ['g'], axis=2),
        helper.make_node('Reshape', ['g','os'], ['raw']),
        helper.make_node('Mul', ['raw','gmask'], ['masked']),
    ]
    if has_const:
        # Add the baked-in constant one-hot plane on top of the masked gather.
        inits.append(numpy_helper.from_array(const_oh, 'cst'))
        nodes.append(helper.make_node('Add', ['masked','cst'], ['output']))
    else:
        # No constants: the Mul node itself becomes the output.
        nodes[-1] = helper.make_node('Mul', ['raw','gmask'], ['output'])
    return mk(nodes, inits)
339
+
340
+ # ============================================================
341
+ # ANALYTICAL SOLVERS
342
+ # ============================================================
343
+
344
def s_identity(td):
    """Solver: output equals input on every example -> single Identity node."""
    if any(ex['input'] != ex['output'] for ex in td['train'] + td['test']):
        return None
    return mk([helper.make_node('Identity', ['input'], ['output'])])
348
+
349
def s_color_map(td):
    """Solver: a consistent per-cell colour substitution -> 1x1 Conv
    whose weight is a 10x10 permutation-like matrix."""
    mapping = {}
    for ex in td['train'] + td['test']:
        inp, out = np.array(ex['input']), np.array(ex['output'])
        if inp.shape != out.shape:
            return None
        for src, dst in zip(inp.flat, out.flat):
            src, dst = int(src), int(dst)
            # setdefault returns the existing target if one was recorded.
            if mapping.setdefault(src, dst) != dst:
                return None  # inconsistent mapping across examples
    W = np.zeros((10, 10, 1, 1), dtype=np.float32)
    for colour in range(10):
        W[mapping.get(colour, colour), colour, 0, 0] = 1.0
    return mk([helper.make_node('Conv', ['input', 'W'], ['output'], kernel_shape=[1, 1])],
              [numpy_helper.from_array(W, 'W')])
363
+
364
def s_transpose(td):
    """Solver: output is the transposed input -> Transpose over the H/W axes."""
    for ex in td['train'] + td['test']:
        if not np.array_equal(np.array(ex['output']), np.array(ex['input']).T):
            return None
    return mk([helper.make_node('Transpose', ['input'], ['output'], perm=[0, 1, 3, 2])])
368
+
369
def s_flip(td):
    """Solver: vertical (flipud) or horizontal (fliplr) mirror of a
    same-shaped grid, realised as a static Gather map (opset-10 safe)."""
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if (IH, IW) != (OH, OW):
        return None
    for axis, flip_fn in ((0, np.flipud), (1, np.fliplr)):
        if not all(np.array_equal(out, flip_fn(inp)) for inp, out in exs):
            continue
        idx = np.zeros((OH, OW, 2), dtype=np.int64)
        for r in range(OH):
            for c in range(OW):
                idx[r, c] = [IH - 1 - r, c] if axis == 0 else [r, IW - 1 - c]
        return _build_gather_model(OH, OW, idx)
    return None
387
+
388
def s_rotate(td):
    """Solver: 90/180/270-degree rotation (np.rot90 with k=1..3) as a Gather map.

    Bug fix: the previous k=1 and k=3 source-index formulas used IH where IW
    belongs (and vice versa), producing wrong maps for non-square grids.
    Correct inverse maps (m is IHxIW):
        rot90(m, 1)[r, c] == m[c, IW-1-r]
        rot90(m, 2)[r, c] == m[IH-1-r, IW-1-c]
        rot90(m, 3)[r, c] == m[IH-1-c, r]
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    for k in (1, 2, 3):
        if not all(np.array_equal(out, np.rot90(inp, k)) for inp, out in exs):
            continue
        idx = np.zeros((OH, OW, 2), dtype=np.int64)
        for r in range(OH):
            for c in range(OW):
                if k == 1:
                    sr, sc = c, IW - 1 - r
                elif k == 2:
                    sr, sc = IH - 1 - r, IW - 1 - c
                else:  # k == 3
                    sr, sc = IH - 1 - c, r
                idx[r, c] = [sr, sc]
        return _build_gather_model(OH, OW, idx)
    return None
404
+
405
def s_spatial_gather(td):
    """Solver: every output cell is either sourced from one fixed input cell
    (consistent across all examples) or is a constant colour."""
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    exs = get_exs(td)
    idx = np.full((OH, OW, 2), -1, dtype=np.int64)
    cst = np.full((OH, OW), -1, dtype=np.int64)
    for oi in range(OH):
        for oj in range(OW):
            # Constant candidate: same colour at (oi, oj) in every example.
            colours = {int(out[oi, oj]) for _, out in exs}
            if len(colours) == 1:
                cst[oi, oj] = colours.pop()
            # Gather candidate: first input cell that always matches.
            source = None
            for ri in range(IH):
                for rj in range(IW):
                    if all(int(inp[ri, rj]) == int(out[oi, oj]) for inp, out in exs):
                        source = (ri, rj)
                        break
                if source is not None:
                    break
            if source is not None:
                idx[oi, oj] = source
            elif cst[oi, oj] < 0:
                return None  # neither sourced nor constant -> unsolvable here
    return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
424
+
425
def s_varshape_spatial_gather(td):
    """Spatial gather for variable-shape tasks via zero-padded 30x30 embedding.

    Fixed-shape tasks are left to s_spatial_gather; here every example is
    placed top-left in a 30x30 canvas and a single gather/constant map over
    the full canvas is searched for.
    """
    if fixed_shapes(td) is not None:
        return None  # fixed shapes handled by s_spatial_gather

    padded = []
    for inp, out in get_exs(td):
        big_in = np.zeros((30, 30), dtype=np.int64)
        big_out = np.zeros((30, 30), dtype=np.int64)
        big_in[:inp.shape[0], :inp.shape[1]] = inp
        big_out[:out.shape[0], :out.shape[1]] = out
        padded.append((big_in, big_out))

    idx = np.full((30, 30, 2), -1, dtype=np.int64)
    cst = np.full((30, 30), -1, dtype=np.int64)

    for oi in range(30):
        for oj in range(30):
            colours = {int(o[oi, oj]) for _, o in padded}
            if len(colours) == 1:
                cst[oi, oj] = colours.pop()
            source = None
            for ri in range(30):
                for rj in range(30):
                    if all(int(i[ri, rj]) == int(o[oi, oj]) for i, o in padded):
                        source = (ri, rj)
                        break
                if source is not None:
                    break
            if source is not None:
                idx[oi, oj] = source
            elif cst[oi, oj] < 0:
                return None

    return _build_gather_model_with_const(30, 30, 30, 30, idx, cst)
462
+
463
def s_tile(td):
    """Solver: output is the input repeated rH x rW times (np.tile pattern).

    Emits Slice (crop to IHxIW) -> Tile -> Pad (back onto the 30x30 canvas).
    """
    exs = get_exs(td)
    in_shapes = {inp.shape for inp, _ in exs}
    if len(in_shapes) != 1:
        return None
    IH, IW = in_shapes.pop()

    factors = set()
    for inp, out in exs:
        OH, OW = out.shape
        if OH % IH or OW % IW:
            return None
        rH, rW = OH // IH, OW // IW
        if rH < 1 or rW < 1 or (rH == 1 and rW == 1):
            return None
        factors.add((rH, rW))
    if len(factors) != 1:
        return None
    rH, rW = factors.pop()
    OH, OW = IH * rH, IW * rW
    if OH > 30 or OW > 30:
        return None
    if not all(np.array_equal(out, np.tile(inp, (rH, rW))) for inp, out in exs):
        return None

    pad_h, pad_w = 30 - OH, 30 - OW
    inits = [
        numpy_helper.from_array(np.array([0, 0, 0, 0], dtype=np.int64), 'st'),
        numpy_helper.from_array(np.array([1, 10, IH, IW], dtype=np.int64), 'en'),
        numpy_helper.from_array(np.array([1, 1, rH, rW], dtype=np.int64), 'rp'),
    ]
    nodes = [
        helper.make_node('Slice', ['input', 'st', 'en'], ['cr']),
        helper.make_node('Tile', ['cr', 'rp'], ['tl']),
        # Opset-10 Pad takes pads/value as attributes, not inputs.
        helper.make_node('Pad', ['tl'], ['output'], pads=[0, 0, 0, 0, 0, 0, pad_h, pad_w], value=0.0),
    ]
    return mk(nodes, inits)
493
+
494
def s_upscale(td):
    """Solver: nearest-neighbour upscale by integer factors (sH, sW) >= 2,
    realised as a static Gather map."""
    exs = get_exs(td)
    in_shapes = {inp.shape for inp, _ in exs}
    if len(in_shapes) != 1:
        return None
    IH, IW = in_shapes.pop()

    factors = set()
    for inp, out in exs:
        OH, OW = out.shape
        if OH % IH or OW % IW:
            return None
        sH, sW = OH // IH, OW // IW
        if sH < 2 or sW < 2:
            return None
        factors.add((sH, sW))
    if len(factors) != 1:
        return None
    sH, sW = factors.pop()
    OH, OW = IH * sH, IW * sW
    if OH > 30 or OW > 30:
        return None
    if not all(np.array_equal(out, np.repeat(np.repeat(inp, sH, 0), sW, 1))
               for inp, out in exs):
        return None

    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    for r in range(OH):
        for c in range(OW):
            idx[r, c] = [r // sH, c // sW]
    return _build_gather_model(OH, OW, idx)
517
+
518
def s_concat(td):
    """Solver: output is 2-4 copies of the input concatenated along one axis,
    each copy independently transformed by id/fliplr/flipud/rot180.

    Every transform combination is tried; the first full match is compiled
    into a static Gather index map.
    """
    from itertools import product as iproduct
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None: return None
    (IH,IW),(OH,OW) = sp
    # The four shape-preserving dihedral transforms (result stays IHxIW).
    transforms = [
        ('id', lambda x: x), ('fliplr', lambda x: np.fliplr(x)),
        ('flipud', lambda x: np.flipud(x)), ('rot180', lambda x: np.rot90(x, 2)),
    ]
    # Horizontal concatenation: same height, width a small multiple of IW.
    if OH == IH and OW % IW == 0 and OW > IW:
        n = OW // IW
        if 2 <= n <= 4:
            for combo in iproduct(range(4), repeat=n):
                if all(np.array_equal(out, np.concatenate([transforms[t][1](inp) for t in combo], axis=1))
                       for inp, out in exs):
                    # Translate the matched combo into per-pixel source indices.
                    idx = np.zeros((OH,OW,2), dtype=np.int64)
                    for oi in range(OH):
                        for oj in range(OW):
                            bj = oj // IW; lr, lc = oi, oj % IW  # block + local coords
                            t = transforms[combo[bj]][0]
                            if t == 'id': sr, sc = lr, lc
                            elif t == 'fliplr': sr, sc = lr, IW-1-lc
                            elif t == 'flipud': sr, sc = IH-1-lr, lc
                            elif t == 'rot180': sr, sc = IH-1-lr, IW-1-lc
                            idx[oi,oj] = [sr, sc]
                    return _build_gather_model(OH, OW, idx)
    # Vertical concatenation: same width, height a small multiple of IH.
    if OW == IW and OH % IH == 0 and OH > IH:
        n = OH // IH
        if 2 <= n <= 4:
            for combo in iproduct(range(4), repeat=n):
                if all(np.array_equal(out, np.concatenate([transforms[t][1](inp) for t in combo], axis=0))
                       for inp, out in exs):
                    idx = np.zeros((OH,OW,2), dtype=np.int64)
                    for oi in range(OH):
                        for oj in range(OW):
                            bi = oi // IH; lr, lc = oi % IH, oj  # block + local coords
                            t = transforms[combo[bi]][0]
                            if t == 'id': sr, sc = lr, lc
                            elif t == 'fliplr': sr, sc = lr, IW-1-lc
                            elif t == 'flipud': sr, sc = IH-1-lr, lc
                            elif t == 'rot180': sr, sc = IH-1-lr, IW-1-lc
                            idx[oi,oj] = [sr, sc]
                    return _build_gather_model(OH, OW, idx)
    return None
563
+
564
def s_concat_enhanced(td):
    """Enhanced concat with all 8 dihedral group transforms.

    The output must decompose into an rH x rW grid of IHxIW blocks, each block
    being one of the 8 symmetries of the input.  A final brute-force verify
    guards against any mistake in the per-transform inverse-index formulas.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None: return None
    (IH,IW),(OH,OW) = sp
    if IH == OH and IW == OW: return None

    # Need block decomposition
    if OH % IH != 0 or OW % IW != 0: return None
    rH, rW = OH // IH, OW // IW
    if rH * rW > 16 or rH * rW < 2: return None
    if OH > 30 or OW > 30: return None

    # All 8 symmetry transforms of the dihedral group.
    # NOTE: rot90/rot270/T/T_fliplr only keep shape (IH, IW) for square inputs;
    # the shape check below filters them out otherwise.
    transforms = [
        ('id', lambda x: x),
        ('fliplr', lambda x: np.fliplr(x)),
        ('flipud', lambda x: np.flipud(x)),
        ('rot180', lambda x: np.rot90(x, 2)),
        ('rot90', lambda x: np.rot90(x, 1)),
        ('rot270', lambda x: np.rot90(x, 3)),
        ('T', lambda x: x.T),
        ('T_fliplr', lambda x: np.fliplr(x.T)),
    ]

    # For each block, find which transform matches across ALL examples.
    block_transforms = {}
    for bi in range(rH):
        for bj in range(rW):
            found = None
            for tidx, (tname, tfn) in enumerate(transforms):
                ok = True
                for inp, out in exs:
                    block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
                    expected = tfn(inp)
                    if expected.shape != (IH, IW) or not np.array_equal(block, expected):
                        ok = False
                        break
                if ok:
                    found = (tidx, tname)
                    break
            if found is None:
                return None
            block_transforms[(bi, bj)] = found

    # Build the per-pixel source index map from each block's transform.
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    for bi in range(rH):
        for bj in range(rW):
            _, tname = block_transforms[(bi, bj)]
            for lr in range(IH):
                for lc in range(IW):
                    oi, oj = bi*IH + lr, bj*IW + lc
                    if tname == 'id': sr, sc = lr, lc
                    elif tname == 'fliplr': sr, sc = lr, IW-1-lc
                    elif tname == 'flipud': sr, sc = IH-1-lr, lc
                    elif tname == 'rot180': sr, sc = IH-1-lr, IW-1-lc
                    elif tname == 'rot90': sr, sc = IW-1-lc, lr
                    elif tname == 'rot270': sr, sc = lc, IH-1-lr
                    elif tname == 'T': sr, sc = lc, lr
                    elif tname == 'T_fliplr': sr, sc = IW-1-lc, lr
                    idx[oi, oj] = [sr, sc]

    # Verify the compiled map reproduces every example exactly.
    for inp, out in exs:
        reconstructed = np.zeros_like(out)
        for oi in range(OH):
            for oj in range(OW):
                reconstructed[oi,oj] = inp[idx[oi,oj,0], idx[oi,oj,1]]
        if not np.array_equal(reconstructed, out):
            return None

    return _build_gather_model(OH, OW, idx)
638
+
639
def s_input_driven_tile(td):
    """Each non-zero input pixel controls a block that's a copy of the input.

    NOTE(review): this solver currently only RECOGNISES the pattern and then
    returns None, because the mapping is input-dependent and cannot be
    expressed as a static gather (see trailing comments).
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None: return None
    (IH,IW),(OH,OW) = sp
    if OH % IH != 0 or OW % IW != 0: return None
    sH, sW = OH // IH, OW // IW
    if sH != IH or sW != IW: return None  # output must be (IH*IH) x (IW*IW)
    if OH > 30 or OW > 30: return None

    # Check the fractal pattern on every example.
    for inp, out in exs:
        for bi in range(IH):
            for bj in range(IW):
                block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
                if inp[bi, bj] != 0:
                    if not np.array_equal(block, inp):
                        return None
                else:
                    if not np.all(block == 0):
                        return None

    # Build gather model: each output pixel at (bi*IH+lr, bj*IW+lc) maps to
    # input[lr, lc] if input[bi, bj] != 0, else constant 0
    # Problem: whether block is active depends on input value, which varies.
    # This needs a different ONNX approach: can't use static gather.
    # But we CAN use: Tile input -> Mul by mask derived from input
    # Actually we need: for each (bi,bj) block position, multiply by inp[bi,bj] != 0
    # This is NOT static - it depends on input content.
    # Skip for now - spatial_gather can handle if block positions are fixed.
    return None
670
+
671
def s_kronecker(td):
    """Solver: output == kron(input, ones(sH, sW)) — i.e. nearest-neighbour
    upscaling — compiled into the same Gather map as s_upscale."""
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if OH % IH or OW % IW:
        return None
    sH, sW = OH // IH, OW // IW
    if sH < 2 or sW < 2 or OH > 30 or OW > 30:
        return None

    ones = np.ones((sH, sW), dtype=np.int64)
    if not all(np.array_equal(out, np.kron(inp, ones)) for inp, out in exs):
        return None

    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    for r in range(OH):
        for c in range(OW):
            idx[r, c] = [r // sH, c // sW]
    return _build_gather_model(OH, OW, idx)
693
+
694
def s_diagonal_tile(td):
    """Solver: copies of the input on the block diagonal, zeros elsewhere."""
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if OH % IH or OW % IW:
        return None
    rH, rW = OH // IH, OW // IW
    if rH != rW or rH < 2 or OH > 30 or OW > 30:
        return None

    # Every diagonal block must equal the input; every other block must be 0.
    for inp, out in exs:
        for bi in range(rH):
            for bj in range(rW):
                block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
                matches = np.array_equal(block, inp) if bi == bj else not np.any(block)
                if not matches:
                    return None

    # Diagonal blocks gather from the input; off-diagonal cells are constant 0.
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    cst = np.full((OH, OW), -1, dtype=np.int64)
    for bi in range(rH):
        for bj in range(rW):
            for lr in range(IH):
                for lc in range(IW):
                    oi, oj = bi*IH + lr, bj*IW + lc
                    if bi == bj:
                        idx[oi, oj] = [lr, lc]
                    else:
                        idx[oi, oj] = [-1, -1]
                        cst[oi, oj] = 0
    return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
731
+
732
def s_shift(td):
    """Detect a constant spatial shift of the grid (up to +/-5 in each axis).

    Shifted-out cells disappear; shifted-in cells become colour 0 (handled
    via the constant map: cst is all-zero, which _build_gather_model_with_const
    turns into a one-hot colour-0 plane for the unmapped cells).
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None: return None
    (IH, IW), (OH, OW) = sp
    if (IH, IW) != (OH, OW): return None
    for dr in range(-5, 6):
        for dc in range(-5, 6):
            if dr == 0 and dc == 0: continue
            ok = True
            for inp, out in exs:
                # Emulate the shift: destination window [r0:r1, c0:c1] receives
                # the source window starting at (max(0,-dr), max(0,-dc)).
                shifted = np.zeros_like(inp)
                r0, r1 = max(0, dr), min(IH, IH + dr)
                c0, c1 = max(0, dc), min(IW, IW + dc)
                if r1 > r0 and c1 > c0:
                    sr0, sc0 = max(0, -dr), max(0, -dc)
                    shifted[r0:r1, c0:c1] = inp[sr0:sr0+(r1-r0), sc0:sc0+(c1-c0)]
                if not np.array_equal(shifted, out):
                    ok = False; break
            if not ok: continue
            # Build gather index
            idx = np.zeros((OH, OW, 2), dtype=np.int64)
            cst = np.full((OH, OW), 0, dtype=np.int64)  # zeros for out-of-bounds
            for r in range(OH):
                for c in range(OW):
                    sr, sc = r - dr, c - dc
                    if 0 <= sr < IH and 0 <= sc < IW:
                        idx[r, c] = [sr, sc]
                    else:
                        idx[r, c] = [-1, -1]  # falls back to cst (colour 0)
            return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
    return None
765
+
766
def s_gravity(td):
    """Detect gravity-like compaction in one direction.

    NOTE(review): detection only — even when the pattern matches, this returns
    None, because the cell destinations depend on input content and cannot be
    expressed as a static gather (see trailing comments).
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None: return None
    (IH, IW), (OH, OW) = sp
    if (IH, IW) != (OH, OW): return None

    def _gravity(grid, direction):
        # Compact non-zero cells of each column (down/up) or row (left/right).
        r = np.zeros_like(grid); h, w = grid.shape
        if direction in ('down', 'up'):
            for c in range(w):
                nz = grid[:, c][grid[:, c] != 0]
                if direction == 'down': r[h-len(nz):h, c] = nz
                else: r[:len(nz), c] = nz
        else:
            for rr in range(h):
                nz = grid[rr, :][grid[rr, :] != 0]
                if direction == 'right': r[rr, w-len(nz):w] = nz
                else: r[rr, :len(nz)] = nz
        return r

    for d in ('down', 'up', 'left', 'right'):
        if all(np.array_equal(_gravity(inp, d), out) for inp, out in exs):
            # Gravity is input-dependent (positions depend on content)
            # Can't use static Gather — need Conv to learn it
            # But conv also can't learn arbitrary sorting...
            # Skip for now — this needs a specialized ONNX graph
            return None
    return None
796
+
797
def s_mirror_h(td):
    """Detect width-doubling horizontal mirroring: out = [input | hflip(input)]."""
    shapes = fixed_shapes(td)
    if shapes is None:
        return None
    (IH, IW), (OH, OW) = shapes
    if OH != IH or OW != 2 * IW or OW > 30:
        return None
    for grid_in, grid_out in get_exs(td):
        mirrored = np.concatenate([grid_in, np.flip(grid_in, 1)], 1)
        if not np.array_equal(mirrored, grid_out):
            return None
    # Gather map: left half copies column c, right half reflects to 2*IW-1-c.
    src_col = np.concatenate([np.arange(IW), np.arange(IW - 1, -1, -1)])
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    for r in range(OH):
        for c in range(OW):
            idx[r, c] = [r, src_col[c]]
    return _build_gather_model(OH, OW, idx)
815
+
816
def s_mirror_v(td):
    """Detect height-doubling vertical mirroring: out = input stacked over vflip(input)."""
    shapes = fixed_shapes(td)
    if shapes is None:
        return None
    (IH, IW), (OH, OW) = shapes
    if OW != IW or OH != 2 * IH or OH > 30:
        return None
    for grid_in, grid_out in get_exs(td):
        mirrored = np.concatenate([grid_in, np.flip(grid_in, 0)], 0)
        if not np.array_equal(mirrored, grid_out):
            return None
    # Gather map: top half copies row r, bottom half reflects to 2*IH-1-r.
    src_row = np.concatenate([np.arange(IH), np.arange(IH - 1, -1, -1)])
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    for r in range(OH):
        for c in range(OW):
            idx[r, c] = [src_row[r], c]
    return _build_gather_model(OH, OW, idx)
833
+
834
def s_quad_mirror(td):
    """Detect 2x2 mirrored tiling: [[I, hflip(I)], [vflip(I), rot180(I)]]."""
    shapes = fixed_shapes(td)
    if shapes is None:
        return None
    (IH, IW), (OH, OW) = shapes
    if OH != 2 * IH or OW != 2 * IW or OH > 30 or OW > 30:
        return None
    for grid_in, grid_out in get_exs(td):
        quad = np.block([
            [grid_in, np.flip(grid_in, 1)],
            [np.flip(grid_in, 0), np.flip(np.flip(grid_in, 0), 1)],
        ])
        if not np.array_equal(quad, grid_out):
            return None
    # Reflected gather map along both axes.
    src_row = np.concatenate([np.arange(IH), np.arange(IH - 1, -1, -1)])
    src_col = np.concatenate([np.arange(IW), np.arange(IW - 1, -1, -1)])
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    for r in range(OH):
        for c in range(OW):
            idx[r, c] = [src_row[r], src_col[c]]
    return _build_gather_model(OH, OW, idx)
855
+
856
def s_fixed_crop(td):
    """Detect that the output is a fixed-position subwindow of the input."""
    shapes = fixed_shapes(td)
    if shapes is None:
        return None
    (IH, IW), (OH, OW) = shapes
    # Output must be strictly smaller in at least one dimension.
    if OH > IH or OW > IW or (OH, OW) == (IH, IW):
        return None
    examples = get_exs(td)
    for top in range(IH - OH + 1):
        for left in range(IW - OW + 1):
            window_ok = all(
                np.array_equal(g_in[top:top+OH, left:left+OW], g_out)
                for g_in, g_out in examples
            )
            if not window_ok:
                continue
            idx = np.zeros((OH, OW, 2), dtype=np.int64)
            for r in range(OH):
                for c in range(OW):
                    idx[r, c] = [top + r, left + c]
            return _build_gather_model(OH, OW, idx)
    return None
872
+
873
def s_nonuniform_scale(td):
    """Detect nearest-neighbour upscaling with different row/column factors."""
    shapes = fixed_shapes(td)
    if shapes is None:
        return None
    (IH, IW), (OH, OW) = shapes
    examples = get_exs(td)
    # Uniform factors (2,2), (3,3), ... are handled by s_upscale.
    factor_pairs = [(1, 2), (2, 1), (1, 3), (3, 1), (2, 3), (3, 2),
                    (1, 4), (4, 1), (2, 4), (4, 2)]
    for fh, fw in factor_pairs:
        if (OH, OW) != (IH * fh, IW * fw) or OH > 30 or OW > 30:
            continue
        if not all(np.array_equal(np.repeat(np.repeat(g_in, fh, 0), fw, 1), g_out)
                   for g_in, g_out in examples):
            continue
        idx = np.zeros((OH, OW, 2), dtype=np.int64)
        for r in range(OH):
            for c in range(OW):
                idx[r, c] = [r // fh, c // fw]
        return _build_gather_model(OH, OW, idx)
    return None
889
+
890
def s_constant(td):
    """Detect a constant output: every training output is the same grid.

    Emits a model that ignores the input's values while staying connected
    to the graph input: Mul by the scalar 0 then ReduceSum collapses the
    input to an all-zero (1,1,1,1) tensor, which is broadcast-added to the
    precomputed one-hot constant.  Returns the ONNX model or None.
    """
    sp = fixed_shapes(td)
    if sp is None: return None
    exs = get_exs(td)
    outs = [out for _,out in exs]
    # All training outputs must be identical grids.
    if not all(np.array_equal(outs[0], o) for o in outs[1:]): return None
    # One-hot encode the constant output into the (1,10,30,30) layout.
    const = np.zeros((1,10,30,30), dtype=np.float32)
    for r, row in enumerate(outs[0]):
        for c, v in enumerate(row):
            const[0, int(v), r, c] = 1.0
    inits = [numpy_helper.from_array(np.array(0.0, dtype=np.float32), 'z'),
             numpy_helper.from_array(const, 'c')]
    nodes = [helper.make_node('Mul', ['input','z'], ['zd']),
             helper.make_node('ReduceSum', ['zd'], ['s'], axes=[1,2,3], keepdims=1),
             helper.make_node('Add', ['s','c'], ['output'])]
    return mk(nodes, inits)
906
+
907
+ # ============================================================
908
+ # CONV SOLVERS
909
+ # ============================================================
910
+
911
def add_onehot_block(nodes, inits, am_name, oh_name):
    """Append an Equal+Cast one-hot stage (OneHot op lacks a CUDA kernel).

    Broadcasts the int64 argmax map against the 10 class ids laid out on a
    (1,10,1,1) axis, then casts the boolean comparison to float32.
    """
    class_ids = np.arange(10, dtype=np.int64).reshape(1, 10, 1, 1)
    inits.append(numpy_helper.from_array(class_ids, 'classes'))
    nodes.extend([
        helper.make_node('Equal', [am_name, 'classes'], ['eq']),
        helper.make_node('Cast', ['eq'], [oh_name], to=TensorProto.FLOAT),
    ])
917
+
918
def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False):
    """Shared lstsq conv fitting. Returns (Wconv, B) or None.

    Fits a single 10->10-channel convolution of kernel size ``ks`` by
    linear least squares over one-hot encoded patches: each output pixel's
    color must be the argmax of the fitted linear map applied to its
    ks*ks receptive field.  ``use_full_30`` embeds each grid top-left in
    the fixed GH x GW canvas (variable-shape mode) instead of using the
    grid's own shape.  Returns conv weights (10,10,ks,ks float32) plus an
    optional bias vector, or None when the problem is too large or the
    fit is not exact on every training pixel.
    """
    pad = ks // 2
    feat = 10 * ks * ks + (1 if use_bias else 0)
    # Hard cap on the design-matrix width.
    if feat > 20000: return None

    patches, targets = [], []
    for inp_g, out_g in exs_raw:
        ih, iw = inp_g.shape
        if use_full_30:
            # Embed the grid in the full canvas before zero-padding for the kernel.
            oh_full = np.zeros((10, GH, GW), dtype=np.float64)
            for c in range(10): oh_full[c, :ih, :iw] = (inp_g == c)
            oh_pad = np.pad(oh_full, ((0,0),(pad,pad),(pad,pad)))
        else:
            oh_enc = np.zeros((10, ih, iw), dtype=np.float64)
            for c in range(10): oh_enc[c] = (inp_g == c)
            oh_pad = np.pad(oh_enc, ((0,0),(pad,pad),(pad,pad)))

        oh, ow = out_g.shape
        for r in range(oh):
            for c in range(ow):
                # Flattened receptive field for output pixel (r, c).
                p = oh_pad[:, r:r+ks, c:c+ks].flatten()
                if use_bias: p = np.append(p, 1.0)
                patches.append(p)
                targets.append(int(out_g[r, c]))

    n_patches = len(patches)
    # Guard against pathological lstsq sizes (memory/time blowup).
    if feat > 5000 and n_patches > 2000: return None

    P = np.array(patches, dtype=np.float64)
    T = np.array(targets, dtype=np.int64)
    T_oh = np.zeros((len(T), 10), dtype=np.float64)
    for i, t in enumerate(T): T_oh[i, t] = 1.0

    WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
    # Require an exact argmax match on every training pixel.
    if not np.array_equal(np.argmax(P @ WT, axis=1), T): return None

    if use_bias:
        # Last row of WT is the bias term appended to each patch.
        Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
        B = WT[-1].astype(np.float32)
    else:
        Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
        B = None
    return Wconv, B
962
+
963
def solve_conv_fixed(td, path, time_budget=30.0):
    """Fixed-shape conv: Slice -> Conv -> ArgMax -> Equal+Cast -> Pad.

    For tasks where every input/output pair shares one fixed shape, fits a
    single convolution via _lstsq_conv over increasing kernel sizes (first
    without, then with bias), assembles the ONNX graph, saves it to
    ``path`` and validates it.  Returns ('conv_fixed', model) on success,
    else None (also on time budget exhaustion).
    """
    exs = get_exs(td)
    for inp, out in exs:
        if inp.shape != out.shape: return None
    shapes = set(inp.shape for inp, _ in exs)
    if len(shapes) != 1: return None
    IH, IW = shapes.pop()

    # Use ARC-GEN data for better fitting
    fit_exs = get_exs_for_fitting(td)
    # Filter to same-shape, same IH/IW
    fit_exs = [(i,o) for i,o in fit_exs if i.shape == o.shape and i.shape == (IH, IW)]

    t_start = time.time()
    for use_bias in [False, True]:
        for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
            if time.time() - t_start > time_budget: return None
            result = _lstsq_conv(fit_exs, ks, use_bias, use_full_30=False)
            if result is None: continue
            Wconv, B = result
            pad = ks // 2
            pad_h, pad_w = GH - IH, GW - IW

            inits = [
                numpy_helper.from_array(np.array([0,0,0,0], dtype=np.int64), 'sl_st'),
                numpy_helper.from_array(np.array([1,10,IH,IW], dtype=np.int64), 'sl_en'),
                numpy_helper.from_array(Wconv, 'W'),
            ]
            conv_inputs = ['grid', 'W']
            if B is not None:
                inits.append(numpy_helper.from_array(B, 'B'))
                conv_inputs.append('B')

            nodes = [
                # Crop the 30x30 canvas down to the task's real IHxIW region.
                helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
                helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
                helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
            ]
            add_onehot_block(nodes, inits, 'am', 'oh_out')
            # Pad back out to the 30x30 canvas with all-zero channels.
            nodes.append(
                helper.make_node('Pad', ['oh_out'], ['output'],
                                 pads=[0,0,0,0,0,0,pad_h,pad_w], value=0.0)
            )

            model = mk(nodes, inits)
            onnx.save(model, path)
            if validate(path, td): return 'conv_fixed', model
    return None
1012
+
1013
def solve_conv_variable(td, path, time_budget=30.0):
    """Variable-shape conv: Conv(30x30) -> ArgMax -> Equal+Cast -> Mul(mask).

    For same-shape tasks whose grid size varies between examples: the conv
    runs over the whole 30x30 canvas, and the one-hot output is multiplied
    by ReduceSum(input, axis=1) — 1 on real input pixels, 0 on canvas
    padding — so padding stays all-zero as the format requires.
    Returns ('conv_var', model) on success, else None.
    """
    exs = get_exs(td)
    for inp, out in exs:
        if inp.shape != out.shape: return None

    # Use ARC-GEN data for better fitting (variable shape, embedded in 30x30)
    fit_exs = get_exs_for_fitting_variable(td)
    fit_exs = [(i,o) for i,o in fit_exs if i.shape == o.shape]

    t_start = time.time()
    for use_bias in [False, True]:
        for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
            if time.time() - t_start > time_budget: return None
            result = _lstsq_conv(fit_exs, ks, use_bias, use_full_30=True)
            if result is None: continue
            Wconv, B = result
            pad = ks // 2

            inits = [numpy_helper.from_array(Wconv, 'W')]
            conv_inputs = ['input', 'W']
            if B is not None:
                inits.append(numpy_helper.from_array(B, 'B'))
                conv_inputs.append('B')

            nodes = [
                # Mask = 1 where the input one-hot has any channel set.
                helper.make_node('ReduceSum', ['input'], ['mask'], axes=[1], keepdims=1),
                helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
                helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
            ]
            add_onehot_block(nodes, inits, 'am', 'oh_out')
            nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))

            model = mk(nodes, inits)
            onnx.save(model, path)
            if validate(path, td): return 'conv_var', model
    return None
1050
+
1051
def solve_conv_diffshape(td, path, time_budget=30.0):
    """Diff-shape conv for fixed io shapes where output is smaller.

    Fits a convolution on the full IHxIW input, then crops the conv output
    at a fixed offset (top-left aligned or centered) to the OHxOW output
    shape before the argmax/one-hot stage.  Returns ('conv_diff', model)
    on success, else None.
    """
    sp = fixed_shapes(td)
    if sp is None: return None
    (IH, IW), (OH, OW) = sp
    if IH == OH and IW == OW: return None
    if OH > IH or OW > IW: return None
    if OH > 30 or OW > 30: return None

    exs = get_exs(td)
    t_start = time.time()

    # Try top-left alignment first, then a centered crop window.
    for dr_off, dc_off in [(0, 0), ((IH-OH)//2, (IW-OW)//2)]:
        for use_bias in [False, True]:
            for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21]:
                if time.time() - t_start > time_budget: return None
                pad = ks // 2
                feat = 10 * ks * ks + (1 if use_bias else 0)
                if feat > 10000: continue

                # Collect one patch per output pixel (offset into the input).
                patches, targets = [], []
                valid = True
                for inp_g, out_g in exs:
                    oh_enc = np.zeros((10, IH, IW), dtype=np.float64)
                    for c in range(10): oh_enc[c] = (inp_g == c)
                    oh_pad = np.pad(oh_enc, ((0,0),(pad,pad),(pad,pad)))
                    for r in range(OH):
                        for c in range(OW):
                            sr, sc = r + dr_off, c + dc_off
                            if sr < 0 or sr >= IH or sc < 0 or sc >= IW:
                                # Crop window falls outside the input: offset invalid.
                                valid = False; break
                            p = oh_pad[:, sr:sr+ks, sc:sc+ks].flatten()
                            if use_bias: p = np.append(p, 1.0)
                            patches.append(p)
                            targets.append(int(out_g[r, c]))
                        if not valid: break
                    if not valid: break
                if not valid: continue

                n_patches = len(patches)
                # Guard against pathological lstsq sizes.
                if feat > 5000 and n_patches > 2000: continue

                P = np.array(patches, dtype=np.float64)
                T = np.array(targets, dtype=np.int64)
                T_oh = np.zeros((len(T), 10), dtype=np.float64)
                for i, t in enumerate(T): T_oh[i, t] = 1.0

                WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
                # Require an exact argmax fit on every training pixel.
                if not np.array_equal(np.argmax(P @ WT, axis=1), T): continue

                if use_bias:
                    Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
                    B = WT[-1].astype(np.float32)
                else:
                    Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
                    B = None

                pad_h, pad_w = GH - OH, GW - OW
                inits = [
                    numpy_helper.from_array(np.array([0,0,0,0], dtype=np.int64), 'sl_st'),
                    numpy_helper.from_array(np.array([1,10,IH,IW], dtype=np.int64), 'sl_en'),
                    numpy_helper.from_array(Wconv, 'W'),
                    numpy_helper.from_array(np.array([0,0,dr_off,dc_off], dtype=np.int64), 'cr_st'),
                    numpy_helper.from_array(np.array([1,10,dr_off+OH,dc_off+OW], dtype=np.int64), 'cr_en'),
                ]
                conv_inputs = ['grid', 'W']
                if B is not None:
                    inits.append(numpy_helper.from_array(B, 'B'))
                    conv_inputs.append('B')

                nodes = [
                    # Crop canvas to the input region, convolve, then crop the
                    # conv output to the (offset) OHxOW window.
                    helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
                    helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
                    helper.make_node('Slice', ['co','cr_st','cr_en'], ['co_crop']),
                    helper.make_node('ArgMax', ['co_crop'], ['am'], axis=1, keepdims=1),
                ]
                add_onehot_block(nodes, inits, 'am', 'oh_out')
                nodes.append(
                    helper.make_node('Pad', ['oh_out'], ['output'],
                                     pads=[0,0,0,0,0,0,pad_h,pad_w], value=0.0)
                )

                model = mk(nodes, inits)
                onnx.save(model, path)
                if validate(path, td): return 'conv_diff', model
    return None
1137
+
1138
def solve_conv_var_diff(td, path, time_budget=30.0):
    """Variable diff-shape conv: Conv(30x30) -> ArgMax -> Equal+Cast -> Mul(mask).

    Handles tasks whose output shape differs from the input but where the
    per-pixel mapping is still convolutional on the 30x30 canvas.

    Masking caveat: the NeuroGolf format requires ALL channels to be zero in
    padding, while the Equal+Cast one-hot stage always sets exactly one
    channel per pixel.  The only mask derivable from the input alone at
    inference time is ReduceSum(input, axis=1) — 1 on real input pixels,
    0 on padding — which is only valid when every output region fits inside
    its input region.  Tasks where any output extends beyond its input are
    therefore rejected up front.

    Returns ('conv_var_diff', model) on success, else None.
    """
    exs = get_exs(td)

    # The input-derived mask is only correct when output ⊆ input bounds.
    # This condition is invariant over the kernel-size search, so check it
    # once here instead of after every expensive lstsq fit.
    if not all(out_g.shape[0] <= inp_g.shape[0] and out_g.shape[1] <= inp_g.shape[1]
               for inp_g, out_g in exs):
        return None

    t_start = time.time()
    for use_bias in [False, True]:
        for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
            if time.time() - t_start > time_budget: return None

            pad = ks // 2
            feat = 10 * ks * ks + (1 if use_bias else 0)
            if feat > 20000: continue

            # Build the least-squares system over full-canvas one-hot patches.
            patches, targets = [], []
            for inp_g, out_g in exs:
                ih, iw = inp_g.shape
                oh, ow = out_g.shape
                oh_full = np.zeros((10, GH, GW), dtype=np.float64)
                for c in range(10): oh_full[c, :ih, :iw] = (inp_g == c)
                oh_pad = np.pad(oh_full, ((0,0),(pad,pad),(pad,pad)))

                for r in range(oh):
                    for c in range(ow):
                        p = oh_pad[:, r:r+ks, c:c+ks].flatten()
                        if use_bias: p = np.append(p, 1.0)
                        patches.append(p)
                        targets.append(int(out_g[r, c]))

            n_patches = len(patches)
            # Guard against pathological lstsq sizes.
            if feat > 5000 and n_patches > 2000: continue

            P = np.array(patches, dtype=np.float64)
            T = np.array(targets, dtype=np.int64)
            T_oh = np.zeros((len(T), 10), dtype=np.float64)
            for i, t in enumerate(T): T_oh[i, t] = 1.0

            try:
                WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
            except np.linalg.LinAlgError:
                # Singular / non-converging system for this kernel size.
                # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
                continue
            # Require an exact argmax fit on every training pixel.
            if not np.array_equal(np.argmax(P @ WT, axis=1), T): continue

            if use_bias:
                Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
                B = WT[-1].astype(np.float32)
            else:
                Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
                B = None

            inits = [numpy_helper.from_array(Wconv, 'W')]
            conv_inputs = ['input', 'W']
            if B is not None:
                inits.append(numpy_helper.from_array(B, 'B'))
                conv_inputs.append('B')

            nodes = [
                helper.make_node('ReduceSum', ['input'], ['mask'], axes=[1], keepdims=1),
                helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
                helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
            ]
            add_onehot_block(nodes, inits, 'am', 'oh_out')
            nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))

            model = mk(nodes, inits)
            onnx.save(model, path)
            if validate(path, td): return 'conv_var_diff', model
    return None
1248
+
1249
+ # ============================================================
1250
+ # MAIN
1251
+ # ============================================================
1252
+
1253
# Ordered registry of analytical (closed-form) solvers, tried before any
# conv fitting.  Each entry is (name, fn) where fn(td) returns an ONNX
# model or None.
ANALYTICAL_SOLVERS = [
    ('identity', s_identity), ('constant', s_constant), ('color_map', s_color_map),
    ('transpose', s_transpose), ('flip', s_flip), ('rotate', s_rotate),
    ('tile', s_tile), ('upscale', s_upscale), ('kronecker', s_kronecker),
    ('nonuniform_scale', s_nonuniform_scale),
    ('mirror_h', s_mirror_h), ('mirror_v', s_mirror_v), ('quad_mirror', s_quad_mirror),
    ('concat', s_concat), ('concat_enhanced', s_concat_enhanced),
    ('diagonal_tile', s_diagonal_tile),
    ('fixed_crop', s_fixed_crop),
    ('spatial_gather', s_spatial_gather),
    ('shift', s_shift),
    ('varshape_spatial_gather', s_varshape_spatial_gather),
]
1266
+
1267
def solve_task(tn, td, outdir, conv_budget=30.0):
    """Solve one task: analytical solvers first, then conv fitters.

    Writes the winning model to ``outdir/taskNNN.onnx``.

    Returns:
        (ok, solver_name, onnx_size_bytes, elapsed_sec, model_path);
        size is None when unsolved or excluded.
    """
    t_start = time.time()
    os.makedirs(outdir, exist_ok=True)
    path = os.path.join(outdir, f"task{tn:03d}.onnx")

    # Skip excluded tasks
    if tn in EXCLUDED_TASKS:
        return False, 'excluded', None, time.time() - t_start, path

    # 1. Try analytical solvers (fast, tiny models)
    for sname, sfn in ANALYTICAL_SOLVERS:
        try:
            model = sfn(td)
            if model is None: continue
            onnx.save(model, path)
            if validate(path, td):
                return True, sname, os.path.getsize(path), time.time() - t_start, path
        except Exception:
            # Solvers are best-effort probes; any failure means "no match".
            # (Was a bare `except:`, which also swallowed KeyboardInterrupt
            # and made long runs unkillable with Ctrl-C.)
            pass

    # 2. Determine task shape category and try conv solvers
    exs = get_exs(td)
    same_shape = all(inp.shape == out.shape for inp, out in exs)
    shapes = set(inp.shape for inp, _ in exs)
    fixed_in = len(shapes) == 1

    conv_time = conv_budget

    if same_shape:
        if fixed_in:
            result = solve_conv_fixed(td, path, time_budget=conv_time/2)
            if result is not None:
                sname, model = result
                return True, sname, os.path.getsize(path), time.time() - t_start, path
        result = solve_conv_variable(td, path, time_budget=conv_time)
        if result is not None:
            sname, model = result
            return True, sname, os.path.getsize(path), time.time() - t_start, path
    else:
        sp = fixed_shapes(td)
        if sp is not None:
            (IH,IW),(OH,OW) = sp
            if OH <= IH and OW <= IW:
                result = solve_conv_diffshape(td, path, time_budget=conv_time)
                if result is not None:
                    sname, model = result
                    return True, sname, os.path.getsize(path), time.time() - t_start, path

        # Try variable diff-shape conv (output within input bounds)
        result = solve_conv_var_diff(td, path, time_budget=conv_time)
        if result is not None:
            sname, model = result
            return True, sname, os.path.getsize(path), time.time() - t_start, path

    return False, None, None, time.time() - t_start, path
1321
+
1322
def run_tasks(task_nums, tasks, output_dir, conv_budget, use_wandb):
    """Solve each requested task, print per-task results, and aggregate.

    Returns (results, costs_dict, total_score) where ``results`` maps
    task id -> (solver_name, elapsed_sec, onnx_bytes) for solved tasks
    only and ``costs_dict`` maps task id -> macs+memory+params cost.
    """
    results = {}
    costs_dict = {}
    total_score = 0
    for tn in task_nums:
        if tn not in tasks:
            continue
        if tn in EXCLUDED_TASKS:
            print(f"Task {tn:3d}: EXCLUDED (officially)")
            continue

        td = tasks[tn]['data']
        ok, sname, sz, t_task, model_path = solve_task(tn, td, output_dir, conv_budget)

        if ok:
            macs, memory, params = score_network(model_path)
            if macs is None:
                # Profiler failed; treat as free so the task still scores.
                macs, memory, params = 0, 0, 0
            cost = macs + memory + params
            # Competition metric: lower cost -> higher score, floored at 1.0.
            score = max(1.0, 25.0 - math.log(max(1, cost)))
            total_score += score

            results[tn] = (sname, t_task, sz)
            costs_dict[tn] = cost
            print(f"Task {tn:3d}: {sname:25s} {score:7.3f} {cost:>12} {t_task:7.3f}s ({sz:>8,} bytes)")
        else:
            print(f"Task {tn:3d}: UNSOLVED {t_task:7.3f}s")
            cost = 0

        if use_wandb and wandb is not None:
            # NOTE: `score` is only bound on the solved path; the conditional
            # expressions below never evaluate it when ok is False.
            wandb.log({
                "task_id": tn,
                "solver": sname if ok else "unsolved",
                "onnx_bytes": sz if ok else 0,
                "task_time_sec": t_task,
                "cost": cost,
                "score": score if ok else 0,
            })

    return results, costs_dict, total_score
1362
+
1363
+
1364
def main():
    """CLI entry point: parse args, load tasks, solve, emit submission files.

    Produces one ONNX model per solved task in ``--output_dir``, plus
    ``submission.zip`` (all models) and ``submission.csv`` (per-task cost)
    next to that directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', default='ARC-AGI/data/training/')
    parser.add_argument('--arcgen_dir', default='', help='Path to ARC-GEN-100K/ directory')
    parser.add_argument('--output_dir', default='submission')
    parser.add_argument('--kaggle', action='store_true')
    parser.add_argument('--conv_budget', type=float, default=30.0)
    parser.add_argument('--tasks', type=str, default='')
    parser.add_argument('--device', type=str, default='auto', choices=['auto','cpu','cuda'])
    parser.add_argument('--use_wandb', action='store_true')
    args = parser.parse_args()
    global ORT_PROVIDERS
    config = {
        "device": args.device,
        "conv_budget": args.conv_budget,
        "data_dir": args.data_dir,
        "arcgen_dir": args.arcgen_dir,
        "tasks": args.tasks,
    }

    # Select onnxruntime providers ('auto' keeps the module-level default).
    if args.device == 'cuda':
        ORT_PROVIDERS = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    elif args.device == 'cpu':
        ORT_PROVIDERS = ['CPUExecutionProvider']

    ort.set_default_logger_severity(3)
    print(f"Using providers: {ORT_PROVIDERS}")

    if args.kaggle:
        tasks = load_tasks_kaggle(args.data_dir)
    else:
        arcgen = args.arcgen_dir if args.arcgen_dir else None
        tasks = load_tasks_dir(args.data_dir, arcgen_dir=arcgen)

    # Count arc-gen examples
    total_arcgen = sum(len(t['data'].get('arc-gen', [])) for t in tasks.values())
    print(f"Loaded {len(tasks)} tasks ({total_arcgen} ARC-GEN examples)")
    print(f"Excluded tasks: {sorted(EXCLUDED_TASKS)}")

    task_nums = [int(t) for t in args.tasks.split(',')] if args.tasks else sorted(tasks.keys())
    active_tasks = [t for t in task_nums if t not in EXCLUDED_TASKS]
    print(f"Solving {len(active_tasks)} active tasks (skipping {len(task_nums) - len(active_tasks)} excluded)")
    print(f"Conv budget: {args.conv_budget}s per task")
    print("=" * 70)
    t0 = time.time()

    # wandb run context is optional; same solving path either way.
    if args.use_wandb and wandb is not None:
        with wandb.init(project="neurogolf", name="solver_run", config=config):
            results, costs_dict, total_score = run_tasks(task_nums, tasks, args.output_dir, args.conv_budget, use_wandb=True)
    else:
        results, costs_dict, total_score = run_tasks(task_nums, tasks, args.output_dir, args.conv_budget, use_wandb=False)

    elapsed = time.time() - t0
    print(f"\n{'='*70}")
    print(f"Solved: {len(results)}/{len(active_tasks)} active tasks in {elapsed:.0f}s")
    solver_names = [v[0] for v in results.values()]
    sc = Counter(solver_names)
    for s, c in sc.most_common(): print(f" {s}: {c}")

    # Generate submission
    outdir = args.output_dir
    n_files = len([f for f in os.listdir(outdir) if f.endswith('.onnx')])
    total_size = sum(os.path.getsize(os.path.join(outdir, f))
                     for f in os.listdir(outdir) if f.endswith('.onnx'))

    # Create submission.zip
    zip_path = os.path.join(os.path.dirname(outdir) or '.', 'submission.zip')
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, 'w', zipfile.ZIP_DEFLATED) as zf:
        for f in sorted(os.listdir(outdir)):
            if f.endswith('.onnx'):
                zf.write(os.path.join(outdir, f), f)
    zip_bytes = buf.getvalue()
    with open(zip_path, 'wb') as f:
        f.write(zip_bytes)
    zip_size = len(zip_bytes)

    # Create submission.csv
    csv_path = os.path.join(os.path.dirname(outdir) or '.', 'submission.csv')
    with open(csv_path, 'w', newline='') as f:
        w = csv.writer(f)
        w.writerow(['task_id', 'total_cost'])
        for tn in sorted(costs_dict.keys()):
            w.writerow([f'task{tn:03d}', costs_dict[tn]])

    # Estimate LB score: solved tasks get their score, unsolved get 1.0
    unsolved_count = len(active_tasks) - len(results)
    est_lb = total_score + unsolved_count * 1.0

    print(f"\n{n_files} ONNX files, {total_size/1024:.1f} KB uncompressed")
    print(f"ZIP size: {zip_size/1024:.1f} KB / {MAX_FILESIZE/1024:.0f} KB limit {'OK' if zip_size <= MAX_FILESIZE else 'OVER!'}")
    print(f"Estimated LB score: {est_lb:.1f} (solved: {total_score:.1f} + unsolved: {unsolved_count}×1.0)")
    print(f"Written: {zip_path} | {csv_path}")
1457
+
1458
# Script entry point.
if __name__ == '__main__':
    main()