rogermt commited on
Commit
0316872
·
verified ·
1 Parent(s): 260c341

Upload neurogolf_solver.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. neurogolf_solver.py +1919 -1
neurogolf_solver.py CHANGED
@@ -1 +1,1919 @@
1
- FILE_CONTENT_PLACEHOLDER
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ ARC-AGI NeuroGolf Championship - Complete Solver v5
4
+ Format: [1,10,30,30] one-hot input/output, opset 17, IR version 10.
5
+
6
+ v5 CHANGES:
7
+ - Switched to opset 17 (Kaggle-compatible) for cheaper analytical solvers
8
+ - Slice-based analytical solvers: rotation, flip, transpose (near-zero cost)
9
+ - LOOCV Ridge tuning in _lstsq_conv with condition number check + SVD-based λ auto-tune
10
+ - stride_tricks speedup for patch extraction
11
+ - Composition detectors: rotation+color, flip+color, transpose+color
12
+ - Channel reduction wrapper for tasks with <8 colors
13
+ - ARC-GEN validation, EXCLUDED tasks skipped, submission.csv generation
14
+
15
+ Solvers:
16
+ - Analytical: identity, constant, color_map, transpose, flip, rotate, tile, upscale,
17
+ concat, concat_enhanced, spatial_gather, varshape_spatial_gather,
18
+ diagonal_tile, kronecker, shift, mirror_h, mirror_v, quad_mirror,
19
+ fixed_crop, nonuniform_scale
20
+ - Composition: rotate+color_map, flip+color_map, transpose+color_map
21
+ - Conv (fixed shape): Slice -> Conv -> ArgMax -> Equal+Cast -> Pad
22
+ - Conv (variable shape): Conv(30x30) -> ArgMax -> Equal+Cast -> Mul(mask)
23
+ - Conv (diff shape): Slice -> Conv -> Slice(crop) -> ArgMax -> Equal+Cast -> Pad
24
+ - Channel reduction: Conv1x1(10->N) -> transform -> Conv1x1(N->10)
25
+
26
+ Usage:
27
+ python neurogolf_solver.py --data_dir ARC-AGI/data/training/ --output_dir submission
28
+ python neurogolf_solver.py --data_dir ARC-AGI/data/training/ --output_dir submission --conv_budget 60 --arcgen_dir ARC-GEN-100K/
29
+ """
30
+
31
+ import json, os, sys, math, time, argparse, csv, io, zipfile, warnings
32
+ import numpy as np
33
+ import onnx
34
+ from onnx import helper, TensorProto, numpy_helper
35
+ import onnxruntime as ort
36
+ from collections import Counter
37
+
38
+ try:
39
+ from neurogolf_utils import score_network as _score_network_official
40
+ HAS_ONNX_TOOL = True
41
+ except ImportError:
42
+ HAS_ONNX_TOOL = False
43
+
44
+ try:
45
+ import wandb
46
+ except ImportError:
47
+ wandb = None
48
+
49
+ BATCH, CH, GH, GW = 1, 10, 30, 30
50
+ GRID_SHAPE = [BATCH, CH, GH, GW]
51
+ DT = TensorProto.FLOAT
52
+ IR = 10
53
+ # v5: opset 17 for cheaper Slice-based transforms
54
+ OPSET = [helper.make_opsetid("", 17)]
55
+
56
+ # Officially excluded tasks (score 0 regardless)
57
+ EXCLUDED_TASKS = {21, 55, 80, 184, 202, 366}
58
+
59
+ # Max ARC-GEN examples to use for validation (to keep runtime reasonable)
60
+ MAX_ARCGEN_VALIDATE = 30
61
+
62
+ # Max ARC-GEN examples for conv fitting
63
+ MAX_ARCGEN_FIT = 0
64
+
65
+ def get_providers():
66
+ return ['CPUExecutionProvider']
67
+
68
+ ORT_PROVIDERS = get_providers()
69
+
70
+ # ============================================================
71
+ # LOAD / VALIDATE
72
+ # ============================================================
73
+
74
+ def load_tasks_dir(data_dir, arcgen_dir=None):
75
+ """Load ARC-AGI tasks and optionally merge ARC-GEN data."""
76
+ files = sorted(f for f in os.listdir(data_dir) if f.endswith('.json'))
77
+ tasks = {}
78
+ for i, f in enumerate(files):
79
+ with open(os.path.join(data_dir, f)) as fh:
80
+ data = json.load(fh)
81
+ hex_id = f.replace('.json','')
82
+ if arcgen_dir and os.path.exists(os.path.join(arcgen_dir, f)):
83
+ with open(os.path.join(arcgen_dir, f)) as fh:
84
+ arcgen_examples = json.load(fh)
85
+ if isinstance(arcgen_examples, list):
86
+ data['arc-gen'] = arcgen_examples
87
+ if 'arc-gen' not in data:
88
+ data['arc-gen'] = []
89
+ tasks[i+1] = {'hex': hex_id, 'data': data}
90
+ return tasks
91
+
92
+ def load_tasks_kaggle(data_dir):
93
+ """Load Kaggle format tasks (already have arc-gen embedded)."""
94
+ tasks = {}
95
+ for tn in range(1, 401):
96
+ path = os.path.join(data_dir, f"task{tn:03d}.json")
97
+ if os.path.exists(path):
98
+ with open(path) as f:
99
+ data = json.load(f)
100
+ if 'arc-gen' not in data:
101
+ data['arc-gen'] = []
102
+ tasks[tn] = {'hex': f'task{tn:03d}', 'data': data}
103
+ return tasks
104
+
105
+ def to_onehot(grid):
106
+ arr = np.zeros((1, CH, GH, GW), dtype=np.float32)
107
+ for r, row in enumerate(grid):
108
+ for c, v in enumerate(row):
109
+ if r < GH and c < GW and 0 <= v < CH:
110
+ arr[0, v, r, c] = 1.0
111
+ return arr
112
+
113
+ def validate(path, td):
114
+ """Validate model against ALL examples: train + test + arc-gen."""
115
+ try:
116
+ opts = ort.SessionOptions()
117
+ opts.log_severity_level = 3
118
+ sess = ort.InferenceSession(path, sess_options=opts, providers=ORT_PROVIDERS)
119
+ except:
120
+ return False
121
+ examples = td['train'] + td['test']
122
+ if 'arc-gen' in td:
123
+ examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]
124
+ for ex in examples:
125
+ inp = to_onehot(ex['input'])
126
+ exp = to_onehot(ex['output'])
127
+ try:
128
+ out = sess.run(['output'], {'input': inp})[0]
129
+ out = (out > 0.0).astype(np.float32)
130
+ except:
131
+ return False
132
+ if not np.array_equal(out, exp):
133
+ return False
134
+ return True
135
+
136
+ def validate_raw(raw_bytes, td):
137
+ """Validate model from raw bytes against ALL examples."""
138
+ try:
139
+ opts = ort.SessionOptions()
140
+ opts.log_severity_level = 3
141
+ sess = ort.InferenceSession(raw_bytes, sess_options=opts, providers=ORT_PROVIDERS)
142
+ except:
143
+ return False
144
+ examples = td['train'] + td['test']
145
+ if 'arc-gen' in td:
146
+ examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]
147
+ for ex in examples:
148
+ inp = to_onehot(ex['input'])
149
+ exp = to_onehot(ex['output'])
150
+ try:
151
+ out = sess.run(['output'], {'input': inp})[0]
152
+ out = (out > 0.0).astype(np.float32)
153
+ except:
154
+ return False
155
+ if not np.array_equal(out, exp):
156
+ return False
157
+ return True
158
+
159
+ # ============================================================
160
+ # STATIC PROFILER (no onnx_tool dependency)
161
+ # ============================================================
162
+
163
+ BANNED_OPS = {'Loop', 'Scan', 'NonZero', 'Unique', 'If', 'Function'}
164
+ MAX_FILESIZE = int(1.44 * 1024 * 1024)
165
+
166
+ def score_network(path):
167
+ """Static profiler matching Kaggle scoring: cost = macs + memory + params."""
168
+ if HAS_ONNX_TOOL:
169
+ try:
170
+ return _score_network_official(path)
171
+ except:
172
+ pass
173
+ return _static_profile(path)
174
+
175
+ def _static_profile(path):
176
+ """Compute cost without onnx_tool: params + nbytes + macs."""
177
+ try:
178
+ model = onnx.load(path)
179
+ except:
180
+ return None, None, None
181
+
182
+ tensors = {}
183
+ params = 0
184
+ nbytes = 0
185
+ macs = 0
186
+
187
+ for init in model.graph.initializer:
188
+ a = numpy_helper.to_array(init)
189
+ tensors[init.name] = a
190
+ params += a.size
191
+ nbytes += a.nbytes
192
+
193
+ for nd in model.graph.node:
194
+ if nd.op_type == 'Constant':
195
+ for attr in nd.attribute:
196
+ if attr.t and attr.t.ByteSize() > 0:
197
+ try:
198
+ a = numpy_helper.to_array(attr.t)
199
+ if nd.output:
200
+ tensors[nd.output[0]] = a
201
+ params += a.size
202
+ nbytes += a.nbytes
203
+ except:
204
+ pass
205
+
206
+ if nd.op_type in BANNED_OPS:
207
+ return None, None, None
208
+
209
+ if nd.op_type == 'Conv' and len(nd.input) >= 2 and nd.input[1] in tensors:
210
+ w = tensors[nd.input[1]]
211
+ if w.ndim == 4:
212
+ co, ci, kh, kw = w.shape
213
+ macs += co * ci * kh * kw * GH * GW
214
+
215
+ return int(macs), int(nbytes), int(params)
216
+
217
+ def mk(nodes, inits=None):
218
+ x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
219
+ y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
220
+ g = helper.make_graph(nodes, "g", [x], [y], initializer=inits or [])
221
+ return helper.make_model(g, ir_version=IR, opset_imports=OPSET)
222
+
223
+ def get_exs(td):
224
+ """Get examples for analytical solvers (train+test only)."""
225
+ return [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
226
+ for ex in td['train'] + td['test']]
227
+
228
+ def get_exs_for_fitting(td):
229
+ """Get examples for conv fitting. Uses train+test + arc-gen WHERE SIZES MATCH."""
230
+ base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
231
+ for ex in td['train'] + td['test']]
232
+
233
+ if not base_exs:
234
+ return base_exs
235
+
236
+ base_shapes = {inp.shape for inp, _ in base_exs}
237
+ if len(base_shapes) != 1:
238
+ return base_exs
239
+
240
+ base_shape = list(base_shapes)[0]
241
+
242
+ ag_exs = []
243
+ for ex in td.get('arc-gen', []):
244
+ inp = np.array(ex['input'], dtype=np.int64)
245
+ out = np.array(ex['output'], dtype=np.int64)
246
+ if inp.shape == base_shape and out.shape == base_exs[0][1].shape:
247
+ ag_exs.append((inp, out))
248
+
249
+ return base_exs + ag_exs[:10]
250
+
251
+ def get_exs_for_fitting_variable(td):
252
+ """Get examples for variable-shape conv fitting."""
253
+ base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
254
+ for ex in td['train'] + td['test']]
255
+
256
+ ag_exs = []
257
+ for ex in td.get('arc-gen', []):
258
+ inp = np.array(ex['input'], dtype=np.int64)
259
+ out = np.array(ex['output'], dtype=np.int64)
260
+ if inp.shape == out.shape and inp.shape[0] <= 30 and inp.shape[1] <= 30:
261
+ ag_exs.append((inp, out))
262
+
263
+ return base_exs + ag_exs[:20]
264
+
265
+ def fixed_shapes(td):
266
+ shapes = set()
267
+ for inp, out in get_exs(td):
268
+ shapes.add((inp.shape, out.shape))
269
+ return list(shapes)[0] if len(shapes) == 1 else None
270
+
271
+ # ============================================================
272
+ # GATHER HELPERS (opset 17 compatible)
273
+ # ============================================================
274
+
275
+ def _build_gather_model(OH, OW, idx):
276
+ """Build Gather-based spatial remapping model."""
277
+ flat_idx = np.zeros((GH*GW,), dtype=np.int64)
278
+ mask = np.zeros((1,1,GH,GW), dtype=np.float32)
279
+ for oi in range(OH):
280
+ for oj in range(OW):
281
+ flat_idx[oi*GW+oj] = idx[oi,oj,0]*GW + idx[oi,oj,1]
282
+ mask[0,0,oi,oj] = 1.0
283
+ inits = [
284
+ numpy_helper.from_array(np.array([1,10,GH*GW], dtype=np.int64), 'fs'),
285
+ numpy_helper.from_array(flat_idx, 'idx'),
286
+ numpy_helper.from_array(np.array([1,10,GH,GW], dtype=np.int64), 'os'),
287
+ numpy_helper.from_array(mask, 'mask'),
288
+ ]
289
+ nodes = [
290
+ helper.make_node('Reshape', ['input','fs'], ['flat']),
291
+ helper.make_node('Gather', ['flat','idx'], ['g'], axis=2),
292
+ helper.make_node('Reshape', ['g','os'], ['raw']),
293
+ helper.make_node('Mul', ['raw','mask'], ['output']),
294
+ ]
295
+ return mk(nodes, inits)
296
+
297
+ def _build_gather_model_with_const(IH, IW, OH, OW, idx, cst):
298
+ """Build Gather model with constant fill for unmapped positions."""
299
+ flat_idx = np.zeros((GH*GW,), dtype=np.int64)
300
+ gather_mask = np.zeros((1,1,GH,GW), dtype=np.float32)
301
+ const_oh = np.zeros((1,10,GH,GW), dtype=np.float32)
302
+ for oi in range(OH):
303
+ for oj in range(OW):
304
+ if idx[oi,oj,0] >= 0:
305
+ flat_idx[oi*GW+oj] = idx[oi,oj,0]*GW + idx[oi,oj,1]
306
+ gather_mask[0,0,oi,oj] = 1.0
307
+ elif cst[oi,oj] >= 0:
308
+ const_oh[0, cst[oi,oj], oi, oj] = 1.0
309
+ has_const = np.any(const_oh > 0)
310
+ inits = [
311
+ numpy_helper.from_array(np.array([1,10,GH*GW], dtype=np.int64), 'fs'),
312
+ numpy_helper.from_array(flat_idx, 'idx'),
313
+ numpy_helper.from_array(np.array([1,10,GH,GW], dtype=np.int64), 'os'),
314
+ numpy_helper.from_array(gather_mask, 'gmask'),
315
+ ]
316
+ nodes = [
317
+ helper.make_node('Reshape', ['input','fs'], ['flat']),
318
+ helper.make_node('Gather', ['flat','idx'], ['g'], axis=2),
319
+ helper.make_node('Reshape', ['g','os'], ['raw']),
320
+ helper.make_node('Mul', ['raw','gmask'], ['masked']),
321
+ ]
322
+ if has_const:
323
+ inits.append(numpy_helper.from_array(const_oh, 'cst'))
324
+ nodes.append(helper.make_node('Add', ['masked','cst'], ['output']))
325
+ else:
326
+ nodes[-1] = helper.make_node('Mul', ['raw','gmask'], ['output'])
327
+ return mk(nodes, inits)
328
+
329
+ # ============================================================
330
+ # SLICE-BASED ANALYTICAL SOLVERS (opset 17, ~0 cost)
331
+ # ============================================================
332
+
333
+ def _build_pad_nodes(input_name, IH, IW, output_name='output', pad_name='pads'):
334
+ """Build Pad nodes to pad spatial dims to 30x30 (opset 17 with tensor pads).
335
+ Returns (pad_inits, pad_node)."""
336
+ pad_h, pad_w = GH - IH, GW - IW
337
+ if pad_h > 0 or pad_w > 0:
338
+ pads_arr = np.array([0, 0, 0, 0, 0, 0, pad_h, pad_w], dtype=np.int64)
339
+ pad_inits = [numpy_helper.from_array(pads_arr, pad_name)]
340
+ pad_node = helper.make_node('Pad', [input_name, pad_name], [output_name], mode='constant')
341
+ return pad_inits, pad_node
342
+ else:
343
+ return [], helper.make_node('Identity', [input_name], [output_name])
344
+
345
+ def _build_slice_flip_model(axis, IH, IW):
346
+ """Build a Slice-based flip model using negative steps (opset 17).
347
+ Extracts content, applies flip, pads back to 30x30.
348
+ axis=0: vertical flip (reverse rows), axis=1: horizontal flip (reverse cols).
349
+ """
350
+ # Step 1: Extract content region [1,10,30,30] -> [1,10,IH,IW]
351
+ ex_st = np.array([0,0,0,0], dtype=np.int64)
352
+ ex_en = np.array([1,10,IH,IW], dtype=np.int64)
353
+
354
+ # Step 2: Flip with negative step Slice
355
+ if axis == 0:
356
+ starts = np.array([IH-1], dtype=np.int64)
357
+ ends = np.array([-IH-1], dtype=np.int64)
358
+ axes = np.array([2], dtype=np.int64)
359
+ steps = np.array([-1], dtype=np.int64)
360
+ else:
361
+ starts = np.array([IW-1], dtype=np.int64)
362
+ ends = np.array([-IW-1], dtype=np.int64)
363
+ axes = np.array([3], dtype=np.int64)
364
+ steps = np.array([-1], dtype=np.int64)
365
+
366
+ inits = [
367
+ numpy_helper.from_array(ex_st, 'ex_st'),
368
+ numpy_helper.from_array(ex_en, 'ex_en'),
369
+ numpy_helper.from_array(starts, 'sl_st'),
370
+ numpy_helper.from_array(ends, 'sl_en'),
371
+ numpy_helper.from_array(axes, 'sl_ax'),
372
+ numpy_helper.from_array(steps, 'sl_sp'),
373
+ ]
374
+ nodes = [
375
+ helper.make_node('Slice', ['input','ex_st','ex_en'], ['content']),
376
+ helper.make_node('Slice', ['content','sl_st','sl_en','sl_ax','sl_sp'], ['flipped']),
377
+ ]
378
+
379
+ # Step 3: Pad back to 30x30 if needed
380
+ pad_inits, pad_node = _build_pad_nodes('flipped', IH, IW)
381
+ inits.extend(pad_inits)
382
+ nodes.append(pad_node)
383
+
384
+ return mk(nodes, inits)
385
+
386
+ def _build_slice_transpose_model(IH, IW):
387
+ """Build a Transpose-based transpose model (perm=[0,1,3,2]).
388
+ Extracts content, transposes, pads back to 30x30."""
389
+ # Step 1: Extract content [1,10,30,30] -> [1,10,IH,IW]
390
+ ex_st = np.array([0,0,0,0], dtype=np.int64)
391
+ ex_en = np.array([1,10,IH,IW], dtype=np.int64)
392
+
393
+ inits = [
394
+ numpy_helper.from_array(ex_st, 'ex_st'),
395
+ numpy_helper.from_array(ex_en, 'ex_en'),
396
+ ]
397
+ nodes = [
398
+ helper.make_node('Slice', ['input','ex_st','ex_en'], ['content']),
399
+ helper.make_node('Transpose', ['content'], ['transposed'], perm=[0,1,3,2]),
400
+ ]
401
+ # After transpose, shape is [1,10,IW,IH]. Need to pad to [1,10,30,30].
402
+ pad_inits, pad_node = _build_pad_nodes('transposed', IW, IH)
403
+ nodes.append(pad_node)
404
+ return mk(nodes, inits + pad_inits)
405
+
406
+ def _build_slice_rotate_model(k, IH, IW):
407
+ """Build a rotation model using Transpose + Slice (opset 17).
408
+ Extracts content, applies rotation, pads back to 30x30.
409
+ Matches existing s_rotate behavior (np.rot90):
410
+ k=1: 90° CCW = Transpose then vflip (reverse rows)
411
+ k=2: 180° = hflip then vflip
412
+ k=3: 270° CCW = Transpose then hflip (reverse cols)
413
+ """
414
+ # Step 1: Extract content [1,10,30,30] -> [1,10,IH,IW]
415
+ ex_st = np.array([0,0,0,0], dtype=np.int64)
416
+ ex_en = np.array([1,10,IH,IW], dtype=np.int64)
417
+
418
+ inits = [
419
+ numpy_helper.from_array(ex_st, 'ex_st'),
420
+ numpy_helper.from_array(ex_en, 'ex_en'),
421
+ ]
422
+ nodes = [helper.make_node('Slice', ['input','ex_st','ex_en'], ['content'])]
423
+ current = 'content'
424
+
425
+ if k in (1, 3):
426
+ # Transpose: [1,10,IH,IW] -> [1,10,IW,IH]
427
+ nodes.append(helper.make_node('Transpose', [current], ['t'], perm=[0,1,3,2]))
428
+ current = 't'
429
+ new_IH, new_IW = IW, IH
430
+ else:
431
+ new_IH, new_IW = IH, IW
432
+
433
+ # Apply flips with negative step Slice
434
+ if k == 1:
435
+ # vflip (reverse rows, axis=2) after transpose
436
+ starts = np.array([new_IH-1], dtype=np.int64)
437
+ ends = np.array([-new_IH-1], dtype=np.int64)
438
+ axes = np.array([2], dtype=np.int64)
439
+ steps = np.array([-1], dtype=np.int64)
440
+ elif k == 2:
441
+ # 180° = hflip then vflip
442
+ starts_h = np.array([new_IW-1], dtype=np.int64)
443
+ ends_h = np.array([-new_IW-1], dtype=np.int64)
444
+ axes_h = np.array([3], dtype=np.int64)
445
+ steps_h = np.array([-1], dtype=np.int64)
446
+ inits.extend([
447
+ numpy_helper.from_array(starts_h, 'st_h'),
448
+ numpy_helper.from_array(ends_h, 'en_h'),
449
+ numpy_helper.from_array(axes_h, 'ax_h'),
450
+ numpy_helper.from_array(steps_h, 'sp_h'),
451
+ ])
452
+ nodes.append(helper.make_node('Slice', [current,'st_h','en_h','ax_h','sp_h'], ['fh']))
453
+ current = 'fh'
454
+ starts_v = np.array([new_IH-1], dtype=np.int64)
455
+ ends_v = np.array([-new_IH-1], dtype=np.int64)
456
+ axes_v = np.array([2], dtype=np.int64)
457
+ steps_v = np.array([-1], dtype=np.int64)
458
+ inits.extend([
459
+ numpy_helper.from_array(starts_v, 'st_v'),
460
+ numpy_helper.from_array(ends_v, 'en_v'),
461
+ numpy_helper.from_array(axes_v, 'ax_v'),
462
+ numpy_helper.from_array(steps_v, 'sp_v'),
463
+ ])
464
+ nodes.append(helper.make_node('Slice', [current,'st_v','en_v','ax_v','sp_v'], ['rot']))
465
+ current = 'rot'
466
+ pad_inits, pad_node = _build_pad_nodes(current, new_IH, new_IW)
467
+ nodes.append(pad_node)
468
+ return mk(nodes, inits + pad_inits)
469
+ elif k == 3:
470
+ # hflip (reverse cols, axis=3) after transpose
471
+ starts = np.array([new_IW-1], dtype=np.int64)
472
+ ends = np.array([-new_IW-1], dtype=np.int64)
473
+ axes = np.array([3], dtype=np.int64)
474
+ steps = np.array([-1], dtype=np.int64)
475
+
476
+ inits.extend([
477
+ numpy_helper.from_array(starts, 'sl_st'),
478
+ numpy_helper.from_array(ends, 'sl_en'),
479
+ numpy_helper.from_array(axes, 'sl_ax'),
480
+ numpy_helper.from_array(steps, 'sl_sp'),
481
+ ])
482
+ nodes.append(helper.make_node('Slice', [current,'sl_st','sl_en','sl_ax','sl_sp'], ['rot']))
483
+ current = 'rot'
484
+
485
+ # Pad back to 30x30
486
+ pad_inits, pad_node = _build_pad_nodes(current, new_IH, new_IW)
487
+ nodes.append(pad_node)
488
+
489
+ return mk(nodes, inits + pad_inits)
490
+
491
+ # ============================================================
492
+ # ANALYTICAL SOLVERS
493
+ # ============================================================
494
+
495
+ def s_identity(td):
496
+ for ex in td['train']+td['test']:
497
+ if ex['input'] != ex['output']: return None
498
+ return mk([helper.make_node('Identity', ['input'], ['output'])])
499
+
500
+ def _get_color_map(td):
501
+ """Extract color map if consistent across all examples, or None."""
502
+ cm = {}
503
+ for ex in td['train']+td['test']:
504
+ inp, out = np.array(ex['input']), np.array(ex['output'])
505
+ if inp.shape != out.shape: return None
506
+ for iv, ov in zip(inp.flat, out.flat):
507
+ iv, ov = int(iv), int(ov)
508
+ if iv in cm and cm[iv] != ov: return None
509
+ cm[iv] = ov
510
+ return cm
511
+
512
+ def _build_color_map_model(cm, is_permutation=None):
513
+ """Build ONNX model for a color map."""
514
+ if is_permutation is None:
515
+ is_permutation = (set(cm.keys()) == set(cm.values()))
516
+
517
+ if is_permutation:
518
+ gather_ch = np.arange(10, dtype=np.int32)
519
+ for src, dst in cm.items():
520
+ if 0 <= src < 10 and 0 <= dst < 10:
521
+ gather_ch[dst] = src
522
+ inits = [numpy_helper.from_array(gather_ch, 'gi')]
523
+ nodes = [helper.make_node('Gather', ['input', 'gi'], ['output'], axis=1)]
524
+ return mk(nodes, inits)
525
+ else:
526
+ W = np.zeros((10,10,1,1), dtype=np.float32)
527
+ for ic in range(10):
528
+ W[cm.get(ic,ic), ic, 0, 0] = 1.0
529
+ return mk([helper.make_node('Conv', ['input','W'], ['output'], kernel_shape=[1,1])],
530
+ [numpy_helper.from_array(W, 'W')])
531
+
532
+ def s_color_map(td):
533
+ cm = _get_color_map(td)
534
+ if cm is None: return None
535
+ is_permutation = (set(cm.keys()) == set(cm.values()))
536
+ return _build_color_map_model(cm, is_permutation)
537
+
538
+ def s_transpose(td):
539
+ exs = get_exs(td)
540
+ sp = fixed_shapes(td)
541
+ if sp is None: return None
542
+ (IH,IW),(OH,OW) = sp
543
+ if not all(np.array_equal(out, inp.T) for inp, out in exs): return None
544
+ return _build_slice_transpose_model(IH, IW)
545
+
546
+ def s_flip(td):
547
+ exs = get_exs(td)
548
+ sp = fixed_shapes(td)
549
+ if sp is None: return None
550
+ (IH,IW),(OH,OW) = sp
551
+ if (IH,IW) != (OH,OW): return None
552
+ for axis, flip_fn in [(0, np.flipud), (1, np.fliplr)]:
553
+ if all(np.array_equal(out, flip_fn(inp)) for inp, out in exs):
554
+ return _build_slice_flip_model(axis, IH, IW)
555
+ return None
556
+
557
+ def s_rotate(td):
558
+ exs = get_exs(td)
559
+ sp = fixed_shapes(td)
560
+ if sp is None: return None
561
+ (IH,IW),(OH,OW) = sp
562
+ for k in [1, 2, 3]:
563
+ if all(np.array_equal(out, np.rot90(inp, k)) for inp, out in exs):
564
+ return _build_slice_rotate_model(k, IH, IW)
565
+ return None
566
+
567
+ def s_spatial_gather(td):
568
+ sp = fixed_shapes(td)
569
+ if sp is None: return None
570
+ (IH,IW),(OH,OW) = sp
571
+ exs = get_exs(td)
572
+ idx = np.full((OH,OW,2), -1, dtype=np.int64)
573
+ cst = np.full((OH,OW), -1, dtype=np.int64)
574
+ for oi in range(OH):
575
+ for oj in range(OW):
576
+ vals = set(int(out[oi,oj]) for _,out in exs)
577
+ if len(vals) == 1: cst[oi,oj] = vals.pop()
578
+ found = False
579
+ for ri in range(IH):
580
+ for rj in range(IW):
581
+ if all(int(inp[ri,rj]) == int(out[oi,oj]) for inp,out in exs):
582
+ idx[oi,oj] = [ri, rj]; found = True; break
583
+ if found: break
584
+ if not found and cst[oi,oj] < 0: return None
585
+ return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
586
+
587
+ def s_varshape_spatial_gather(td):
588
+ """Spatial gather that works for variable-shape tasks by embedding in 30x30."""
589
+ sp = fixed_shapes(td)
590
+ if sp is not None: return None
591
+ exs = get_exs(td)
592
+
593
+ exs_30 = []
594
+ for inp, out in exs:
595
+ ih, iw = inp.shape
596
+ oh, ow = out.shape
597
+ inp30 = np.zeros((30, 30), dtype=np.int64)
598
+ out30 = np.zeros((30, 30), dtype=np.int64)
599
+ inp30[:ih, :iw] = inp
600
+ out30[:oh, :ow] = out
601
+ exs_30.append((inp30, out30))
602
+
603
+ idx = np.full((30, 30, 2), -1, dtype=np.int64)
604
+ cst = np.full((30, 30), -1, dtype=np.int64)
605
+
606
+ for oi in range(30):
607
+ for oj in range(30):
608
+ vals = set(int(out30[oi, oj]) for _, out30 in exs_30)
609
+ if len(vals) == 1:
610
+ cst[oi, oj] = vals.pop()
611
+ found = False
612
+ for ri in range(30):
613
+ for rj in range(30):
614
+ if all(int(inp30[ri, rj]) == int(out30[oi, oj]) for inp30, out30 in exs_30):
615
+ idx[oi, oj] = [ri, rj]
616
+ found = True
617
+ break
618
+ if found: break
619
+ if not found and cst[oi, oj] < 0:
620
+ return None
621
+
622
+ return _build_gather_model_with_const(30, 30, 30, 30, idx, cst)
623
+
624
+ def s_tile(td):
625
+ exs = get_exs(td)
626
+ in_shapes = set(inp.shape for inp,_ in exs)
627
+ if len(in_shapes) != 1: return None
628
+ IH, IW = in_shapes.pop()
629
+ tiles = set()
630
+ for inp, out in exs:
631
+ OH, OW = out.shape
632
+ if OH % IH or OW % IW: return None
633
+ rH, rW = OH//IH, OW//IW
634
+ if rH < 1 or rW < 1 or (rH==1 and rW==1): return None
635
+ tiles.add((rH, rW))
636
+ if len(tiles) != 1: return None
637
+ rH, rW = tiles.pop()
638
+ OH, OW = IH*rH, IW*rW
639
+ if OH > 30 or OW > 30: return None
640
+ for inp, out in exs:
641
+ if not np.array_equal(out, np.tile(inp, (rH, rW))): return None
642
+ pad_h, pad_w = 30-OH, 30-OW
643
+ inits = [
644
+ numpy_helper.from_array(np.array([0,0,0,0], dtype=np.int64), 'st'),
645
+ numpy_helper.from_array(np.array([1,10,IH,IW], dtype=np.int64), 'en'),
646
+ numpy_helper.from_array(np.array([1,1,rH,rW], dtype=np.int64), 'rp'),
647
+ ]
648
+ pads_arr = np.array([0, 0, 0, 0, 0, 0, pad_h, pad_w], dtype=np.int64)
649
+ tile_pads = numpy_helper.from_array(pads_arr, 'tile_pads')
650
+ nodes = [
651
+ helper.make_node('Slice', ['input','st','en'], ['cr']),
652
+ helper.make_node('Tile', ['cr','rp'], ['tl']),
653
+ helper.make_node('Pad', ['tl', 'tile_pads'], ['output'], mode='constant'),
654
+ ]
655
+ inits.append(tile_pads)
656
+ return mk(nodes, inits)
657
+
658
+ def s_upscale(td):
659
+ exs = get_exs(td)
660
+ in_shapes = set(inp.shape for inp,_ in exs)
661
+ if len(in_shapes) != 1: return None
662
+ IH, IW = in_shapes.pop()
663
+ scales = set()
664
+ for inp, out in exs:
665
+ OH, OW = out.shape
666
+ if OH % IH or OW % IW: return None
667
+ sH, sW = OH//IH, OW//IW
668
+ if sH < 2 or sW < 2: return None
669
+ scales.add((sH, sW))
670
+ if len(scales) != 1: return None
671
+ sH, sW = scales.pop()
672
+ OH, OW = IH*sH, IW*sW
673
+ if OH > 30 or OW > 30: return None
674
+ for inp, out in exs:
675
+ if not np.array_equal(out, np.repeat(np.repeat(inp, sH, 0), sW, 1)): return None
676
+ idx = np.zeros((OH,OW,2), dtype=np.int64)
677
+ for r in range(OH):
678
+ for c in range(OW):
679
+ idx[r,c] = [r//sH, c//sW]
680
+ return _build_gather_model(OH, OW, idx)
681
+
682
+ def s_concat(td):
683
+ from itertools import product as iproduct
684
+ exs = get_exs(td)
685
+ sp = fixed_shapes(td)
686
+ if sp is None: return None
687
+ (IH,IW),(OH,OW) = sp
688
+ transforms = [
689
+ ('id', lambda x: x), ('fliplr', lambda x: np.fliplr(x)),
690
+ ('flipud', lambda x: np.flipud(x)), ('rot180', lambda x: np.rot90(x, 2)),
691
+ ]
692
+ if OH == IH and OW % IW == 0 and OW > IW:
693
+ n = OW // IW
694
+ if 2 <= n <= 4:
695
+ for combo in iproduct(range(4), repeat=n):
696
+ if all(np.array_equal(out, np.concatenate([transforms[t][1](inp) for t in combo], axis=1))
697
+ for inp, out in exs):
698
+ idx = np.zeros((OH,OW,2), dtype=np.int64)
699
+ for oi in range(OH):
700
+ for oj in range(OW):
701
+ bj = oj // IW; lr, lc = oi, oj % IW
702
+ t = transforms[combo[bj]][0]
703
+ if t == 'id': sr, sc = lr, lc
704
+ elif t == 'fliplr': sr, sc = lr, IW-1-lc
705
+ elif t == 'flipud': sr, sc = IH-1-lr, lc
706
+ elif t == 'rot180': sr, sc = IH-1-lr, IW-1-lc
707
+ idx[oi,oj] = [sr, sc]
708
+ return _build_gather_model(OH, OW, idx)
709
+ if OW == IW and OH % IH == 0 and OH > IH:
710
+ n = OH // IH
711
+ if 2 <= n <= 4:
712
+ for combo in iproduct(range(4), repeat=n):
713
+ if all(np.array_equal(out, np.concatenate([transforms[t][1](inp) for t in combo], axis=0))
714
+ for inp, out in exs):
715
+ idx = np.zeros((OH,OW,2), dtype=np.int64)
716
+ for oi in range(OH):
717
+ for oj in range(OW):
718
+ bi = oi // IH; lr, lc = oi % IH, oj
719
+ t = transforms[combo[bi]][0]
720
+ if t == 'id': sr, sc = lr, lc
721
+ elif t == 'fliplr': sr, sc = lr, IW-1-lc
722
+ elif t == 'flipud': sr, sc = IH-1-lr, lc
723
+ elif t == 'rot180': sr, sc = IH-1-lr, IW-1-lc
724
+ idx[oi,oj] = [sr, sc]
725
+ return _build_gather_model(OH, OW, idx)
726
+ return None
727
+
728
+ def s_concat_enhanced(td):
729
+ """Enhanced concat with all 8 dihedral group transforms."""
730
+ exs = get_exs(td)
731
+ sp = fixed_shapes(td)
732
+ if sp is None: return None
733
+ (IH,IW),(OH,OW) = sp
734
+ if IH == OH and IW == OW: return None
735
+ if OH % IH != 0 or OW % IW != 0: return None
736
+ rH, rW = OH // IH, OW // IW
737
+ if rH * rW > 16 or rH * rW < 2: return None
738
+ if OH > 30 or OW > 30: return None
739
+
740
+ transforms = [
741
+ ('id', lambda x: x), ('fliplr', lambda x: np.fliplr(x)),
742
+ ('flipud', lambda x: np.flipud(x)), ('rot180', lambda x: np.rot90(x, 2)),
743
+ ('rot90', lambda x: np.rot90(x, 1)), ('rot270', lambda x: np.rot90(x, 3)),
744
+ ('T', lambda x: x.T), ('T_fliplr', lambda x: np.fliplr(x.T)),
745
+ ]
746
+
747
+ block_transforms = {}
748
+ for bi in range(rH):
749
+ for bj in range(rW):
750
+ found = None
751
+ for tidx, (tname, tfn) in enumerate(transforms):
752
+ ok = True
753
+ for inp, out in exs:
754
+ block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
755
+ expected = tfn(inp)
756
+ if expected.shape != (IH, IW) or not np.array_equal(block, expected):
757
+ ok = False; break
758
+ if ok:
759
+ found = (tidx, tname)
760
+ break
761
+ if found is None: return None
762
+ block_transforms[(bi, bj)] = found
763
+
764
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
765
+ for bi in range(rH):
766
+ for bj in range(rW):
767
+ _, tname = block_transforms[(bi, bj)]
768
+ for lr in range(IH):
769
+ for lc in range(IW):
770
+ oi, oj = bi*IH + lr, bj*IW + lc
771
+ if tname == 'id': sr, sc = lr, lc
772
+ elif tname == 'fliplr': sr, sc = lr, IW-1-lc
773
+ elif tname == 'flipud': sr, sc = IH-1-lr, lc
774
+ elif tname == 'rot180': sr, sc = IH-1-lr, IW-1-lc
775
+ elif tname == 'rot90': sr, sc = IW-1-lc, lr
776
+ elif tname == 'rot270': sr, sc = lc, IH-1-lr
777
+ elif tname == 'T': sr, sc = lc, lr
778
+ elif tname == 'T_fliplr': sr, sc = IW-1-lc, lr
779
+ idx[oi, oj] = [sr, sc]
780
+
781
+ for inp, out in exs:
782
+ reconstructed = np.zeros_like(out)
783
+ for oi in range(OH):
784
+ for oj in range(OW):
785
+ reconstructed[oi,oj] = inp[idx[oi,oj,0], idx[oi,oj,1]]
786
+ if not np.array_equal(reconstructed, out): return None
787
+
788
+ return _build_gather_model(OH, OW, idx)
789
+
790
+ def s_kronecker(td):
791
+ exs = get_exs(td)
792
+ sp = fixed_shapes(td)
793
+ if sp is None: return None
794
+ (IH,IW),(OH,OW) = sp
795
+ if OH % IH != 0 or OW % IW != 0: return None
796
+ sH, sW = OH // IH, OW // IW
797
+ if sH < 2 or sW < 2: return None
798
+ if OH > 30 or OW > 30: return None
799
+ for inp, out in exs:
800
+ expected = np.kron(inp, np.ones((sH, sW), dtype=np.int64))
801
+ if not np.array_equal(out, expected): return None
802
+ idx = np.zeros((OH,OW,2), dtype=np.int64)
803
+ for r in range(OH):
804
+ for c in range(OW):
805
+ idx[r,c] = [r//sH, c//sW]
806
+ return _build_gather_model(OH, OW, idx)
807
+
808
+ def s_diagonal_tile(td):
809
+ exs = get_exs(td)
810
+ sp = fixed_shapes(td)
811
+ if sp is None: return None
812
+ (IH,IW),(OH,OW) = sp
813
+ if OH % IH != 0 or OW % IW != 0: return None
814
+ rH, rW = OH // IH, OW // IW
815
+ if rH != rW or rH < 2: return None
816
+ if OH > 30 or OW > 30: return None
817
+ for inp, out in exs:
818
+ for bi in range(rH):
819
+ for bj in range(rW):
820
+ block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
821
+ if bi == bj:
822
+ if not np.array_equal(block, inp): return None
823
+ else:
824
+ if not np.all(block == 0): return None
825
+ idx = np.zeros((OH,OW,2), dtype=np.int64)
826
+ cst = np.full((OH,OW), -1, dtype=np.int64)
827
+ for bi in range(rH):
828
+ for bj in range(rW):
829
+ for lr in range(IH):
830
+ for lc in range(IW):
831
+ oi, oj = bi*IH + lr, bj*IW + lc
832
+ if bi == bj: idx[oi, oj] = [lr, lc]
833
+ else: idx[oi, oj] = [-1, -1]; cst[oi, oj] = 0
834
+ return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
835
+
836
+ def s_shift(td):
837
+ exs = get_exs(td)
838
+ sp = fixed_shapes(td)
839
+ if sp is None: return None
840
+ (IH, IW), (OH, OW) = sp
841
+ if (IH, IW) != (OH, OW): return None
842
+ for dr in range(-5, 6):
843
+ for dc in range(-5, 6):
844
+ if dr == 0 and dc == 0: continue
845
+ ok = True
846
+ for inp, out in exs:
847
+ shifted = np.zeros_like(inp)
848
+ r0, r1 = max(0, dr), min(IH, IH + dr)
849
+ c0, c1 = max(0, dc), min(IW, IW + dc)
850
+ if r1 > r0 and c1 > c0:
851
+ sr0, sc0 = max(0, -dr), max(0, -dc)
852
+ shifted[r0:r1, c0:c1] = inp[sr0:sr0+(r1-r0), sc0:sc0+(c1-c0)]
853
+ if not np.array_equal(shifted, out):
854
+ ok = False; break
855
+ if not ok: continue
856
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
857
+ cst = np.full((OH, OW), 0, dtype=np.int64)
858
+ for r in range(OH):
859
+ for c in range(OW):
860
+ sr, sc = r - dr, c - dc
861
+ if 0 <= sr < IH and 0 <= sc < IW: idx[r, c] = [sr, sc]
862
+ else: idx[r, c] = [-1, -1]
863
+ return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
864
+ return None
865
+
866
+ def s_mirror_h(td):
867
+ exs = get_exs(td)
868
+ sp = fixed_shapes(td)
869
+ if sp is None: return None
870
+ (IH, IW), (OH, OW) = sp
871
+ if OH != IH or OW != 2 * IW: return None
872
+ if OW > 30: return None
873
+ for inp, out in exs:
874
+ expected = np.concatenate([inp, np.flip(inp, 1)], 1)
875
+ if not np.array_equal(expected, out): return None
876
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
877
+ for r in range(OH):
878
+ for c in range(OW):
879
+ sc = c if c < IW else 2*IW - 1 - c
880
+ idx[r, c] = [r, sc]
881
+ return _build_gather_model(OH, OW, idx)
882
+
883
+ def s_mirror_v(td):
884
+ exs = get_exs(td)
885
+ sp = fixed_shapes(td)
886
+ if sp is None: return None
887
+ (IH, IW), (OH, OW) = sp
888
+ if OW != IW or OH != 2 * IH: return None
889
+ if OH > 30: return None
890
+ for inp, out in exs:
891
+ expected = np.concatenate([inp, np.flip(inp, 0)], 0)
892
+ if not np.array_equal(expected, out): return None
893
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
894
+ for r in range(OH):
895
+ for c in range(OW):
896
+ sr = r if r < IH else 2*IH - 1 - r
897
+ idx[r, c] = [sr, c]
898
+ return _build_gather_model(OH, OW, idx)
899
+
900
+ def s_quad_mirror(td):
901
+ exs = get_exs(td)
902
+ sp = fixed_shapes(td)
903
+ if sp is None: return None
904
+ (IH, IW), (OH, OW) = sp
905
+ if OH != 2 * IH or OW != 2 * IW: return None
906
+ if OH > 30 or OW > 30: return None
907
+ for inp, out in exs:
908
+ expected = np.block([
909
+ [inp, np.flip(inp, 1)],
910
+ [np.flip(inp, 0), np.flip(np.flip(inp, 0), 1)]
911
+ ])
912
+ if not np.array_equal(expected, out): return None
913
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
914
+ for r in range(OH):
915
+ for c in range(OW):
916
+ sr = r if r < IH else 2*IH - 1 - r
917
+ sc = c if c < IW else 2*IW - 1 - c
918
+ idx[r, c] = [sr, sc]
919
+ return _build_gather_model(OH, OW, idx)
920
+
921
+ def s_fixed_crop(td):
922
+ exs = get_exs(td)
923
+ sp = fixed_shapes(td)
924
+ if sp is None: return None
925
+ (IH, IW), (OH, OW) = sp
926
+ if OH > IH or OW > IW or (OH == IH and OW == IW): return None
927
+ for r0 in range(IH - OH + 1):
928
+ for c0 in range(IW - OW + 1):
929
+ if all(np.array_equal(inp[r0:r0+OH, c0:c0+OW], out) for inp, out in exs):
930
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
931
+ for r in range(OH):
932
+ for c in range(OW):
933
+ idx[r, c] = [r0 + r, c0 + c]
934
+ return _build_gather_model(OH, OW, idx)
935
+ return None
936
+
937
+ def s_nonuniform_scale(td):
938
+ exs = get_exs(td)
939
+ sp = fixed_shapes(td)
940
+ if sp is None: return None
941
+ (IH, IW), (OH, OW) = sp
942
+ for fh, fw in [(1,2),(2,1),(1,3),(3,1),(2,3),(3,2),(1,4),(4,1),(2,4),(4,2)]:
943
+ if OH != IH*fh or OW != IW*fw: continue
944
+ if OH > 30 or OW > 30: continue
945
+ if all(np.array_equal(np.repeat(np.repeat(inp, fh, 0), fw, 1), out) for inp, out in exs):
946
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
947
+ for r in range(OH):
948
+ for c in range(OW):
949
+ idx[r, c] = [r//fh, c//fw]
950
+ return _build_gather_model(OH, OW, idx)
951
+ return None
952
+
953
+ def s_constant(td):
954
+ sp = fixed_shapes(td)
955
+ if sp is None: return None
956
+ exs = get_exs(td)
957
+ outs = [out for _,out in exs]
958
+ if not all(np.array_equal(outs[0], o) for o in outs[1:]): return None
959
+ const = np.zeros((1,10,30,30), dtype=np.float32)
960
+ for r, row in enumerate(outs[0]):
961
+ for c, v in enumerate(row):
962
+ const[0, int(v), r, c] = 1.0
963
+ inits = [numpy_helper.from_array(np.array(0.0, dtype=np.float32), 'z'),
964
+ numpy_helper.from_array(const, 'c')]
965
+ nodes = [helper.make_node('Mul', ['input','z'], ['zd']),
966
+ helper.make_node('ReduceSum', ['zd'], ['s'], axes=[1,2,3], keepdims=1),
967
+ helper.make_node('Add', ['s','c'], ['output'])]
968
+ return mk(nodes, inits)
969
+
970
+ def _attr_to_dict(attr_proto):
971
+ """Convert ONNX AttributeProto to Python native type."""
972
+ from onnx import AttributeProto
973
+ if attr_proto.type == AttributeProto.INT:
974
+ return attr_proto.i
975
+ elif attr_proto.type == AttributeProto.INTS:
976
+ return list(attr_proto.ints)
977
+ elif attr_proto.type == AttributeProto.FLOAT:
978
+ return attr_proto.f
979
+ elif attr_proto.type == AttributeProto.FLOATS:
980
+ return list(attr_proto.floats)
981
+ elif attr_proto.type == AttributeProto.STRING:
982
+ return attr_proto.s.decode('utf-8')
983
+ elif attr_proto.type == AttributeProto.STRINGS:
984
+ return [s.decode('utf-8') for s in attr_proto.strings]
985
+ elif attr_proto.type == AttributeProto.TENSOR:
986
+ return numpy_helper.to_array(attr_proto.t)
987
+ else:
988
+ return None
989
+
990
+ # ============================================================
991
+ # COMPOSITION DETECTORS (transform + color_map)
992
+ # ============================================================
993
+
994
+ def _apply_transform(inp, transform_name):
995
+ """Apply a named transform to a numpy array."""
996
+ if transform_name == 'id': return inp
997
+ elif transform_name == 'fliplr': return np.fliplr(inp)
998
+ elif transform_name == 'flipud': return np.flipud(inp)
999
+ elif transform_name == 'rot90': return np.rot90(inp, 1)
1000
+ elif transform_name == 'rot180': return np.rot90(inp, 2)
1001
+ elif transform_name == 'rot270': return np.rot90(inp, 3)
1002
+ elif transform_name == 'T': return inp.T
1003
+ else: return inp
1004
+
1005
+ def s_composition_rotate_color(td):
1006
+ """Detect rotation + color_map composition."""
1007
+ exs = get_exs(td)
1008
+ sp = fixed_shapes(td)
1009
+ if sp is None: return None
1010
+ (IH,IW),(OH,OW) = sp
1011
+ if (IH,IW) != (OH,OW): return None
1012
+
1013
+ for k in [1, 2, 3]:
1014
+ # Try each rotation, then check if consistent color_map remains
1015
+ cm = {}
1016
+ valid = True
1017
+ for inp, out in exs:
1018
+ rotated = np.rot90(inp, k)
1019
+ if rotated.shape != out.shape: valid = False; break
1020
+ for iv, ov in zip(rotated.flat, out.flat):
1021
+ iv, ov = int(iv), int(ov)
1022
+ if iv in cm and cm[iv] != ov: valid = False; break
1023
+ cm[iv] = ov
1024
+ if not valid: break
1025
+ if not valid: continue
1026
+
1027
+ # Build: rotate first (Slice-based), then color_map
1028
+ rot_model = _build_slice_rotate_model(k, IH, IW)
1029
+ # Extract nodes from rot_model, prepend to color_map
1030
+ cm_model = _build_color_map_model(cm)
1031
+
1032
+ # Combine: input -> rot_nodes -> color_map -> output
1033
+ # We need to chain the graphs
1034
+ combined_nodes = []
1035
+ combined_inits = []
1036
+
1037
+ # Add rotation nodes with renamed intermediates
1038
+ for node in rot_model.graph.node:
1039
+ if node.output[0] == 'output':
1040
+ # Last node of rotation feeds into color map
1041
+ new_node = helper.make_node(node.op_type, list(node.input), ['rot_out'],
1042
+ **{attr.name: _attr_to_dict(attr) for attr in node.attribute})
1043
+ else:
1044
+ new_node = node
1045
+ combined_nodes.append(new_node)
1046
+
1047
+ for init in rot_model.graph.initializer:
1048
+ combined_inits.append(init)
1049
+
1050
+ # Add color map nodes with input = rot_out
1051
+ for node in cm_model.graph.node:
1052
+ if node.input[0] == 'input':
1053
+ new_node = helper.make_node(node.op_type, ['rot_out'] + list(node.input[1:]), list(node.output),
1054
+ **{attr.name: _attr_to_dict(attr) for attr in node.attribute})
1055
+ else:
1056
+ new_node = node
1057
+ combined_nodes.append(new_node)
1058
+
1059
+ for init in cm_model.graph.initializer:
1060
+ combined_inits.append(init)
1061
+
1062
+ return mk(combined_nodes, combined_inits)
1063
+ return None
1064
+
1065
+ def s_composition_flip_color(td):
1066
+ """Detect flip + color_map composition."""
1067
+ exs = get_exs(td)
1068
+ sp = fixed_shapes(td)
1069
+ if sp is None: return None
1070
+ (IH,IW),(OH,OW) = sp
1071
+ if (IH,IW) != (OH,OW): return None
1072
+
1073
+ for axis, flip_fn in [(0, np.flipud), (1, np.fliplr)]:
1074
+ cm = {}
1075
+ valid = True
1076
+ for inp, out in exs:
1077
+ flipped = flip_fn(inp)
1078
+ if flipped.shape != out.shape: valid = False; break
1079
+ for iv, ov in zip(flipped.flat, out.flat):
1080
+ iv, ov = int(iv), int(ov)
1081
+ if iv in cm and cm[iv] != ov: valid = False; break
1082
+ cm[iv] = ov
1083
+ if not valid: break
1084
+ if not valid: continue
1085
+
1086
+ flip_model = _build_slice_flip_model(axis, IH, IW)
1087
+ cm_model = _build_color_map_model(cm)
1088
+
1089
+ combined_nodes = []
1090
+ combined_inits = []
1091
+
1092
+ for node in flip_model.graph.node:
1093
+ if node.output[0] == 'output':
1094
+ new_node = helper.make_node(node.op_type, list(node.input), ['flip_out'],
1095
+ **{attr.name: _attr_to_dict(attr) for attr in node.attribute})
1096
+ else:
1097
+ new_node = node
1098
+ combined_nodes.append(new_node)
1099
+
1100
+ for init in flip_model.graph.initializer:
1101
+ combined_inits.append(init)
1102
+
1103
+ for node in cm_model.graph.node:
1104
+ if node.input[0] == 'input':
1105
+ new_node = helper.make_node(node.op_type, ['flip_out'] + list(node.input[1:]), list(node.output),
1106
+ **{attr.name: _attr_to_dict(attr) for attr in node.attribute})
1107
+ else:
1108
+ new_node = node
1109
+ combined_nodes.append(new_node)
1110
+
1111
+ for init in cm_model.graph.initializer:
1112
+ combined_inits.append(init)
1113
+
1114
+ return mk(combined_nodes, combined_inits)
1115
+ return None
1116
+
1117
+ def s_composition_transpose_color(td):
1118
+ """Detect transpose + color_map composition."""
1119
+ exs = get_exs(td)
1120
+ sp = fixed_shapes(td)
1121
+ if sp is None: return None
1122
+ (IH,IW),(OH,OW) = sp
1123
+
1124
+ cm = {}
1125
+ valid = True
1126
+ for inp, out in exs:
1127
+ transposed = inp.T
1128
+ if transposed.shape != out.shape: valid = False; break
1129
+ for iv, ov in zip(transposed.flat, out.flat):
1130
+ iv, ov = int(iv), int(ov)
1131
+ if iv in cm and cm[iv] != ov: valid = False; break
1132
+ cm[iv] = ov
1133
+ if not valid: break
1134
+ if not valid: return None
1135
+
1136
+ trans_model = _build_slice_transpose_model(IH, IW)
1137
+ cm_model = _build_color_map_model(cm)
1138
+
1139
+ combined_nodes = []
1140
+ combined_inits = []
1141
+
1142
+ for node in trans_model.graph.node:
1143
+ if node.output[0] == 'output':
1144
+ new_node = helper.make_node(node.op_type, list(node.input), ['trans_out'],
1145
+ **{attr.name: _attr_to_dict(attr) for attr in node.attribute})
1146
+ else:
1147
+ new_node = node
1148
+ combined_nodes.append(new_node)
1149
+
1150
+ for init in trans_model.graph.initializer:
1151
+ combined_inits.append(init)
1152
+
1153
+ for node in cm_model.graph.node:
1154
+ if node.input[0] == 'input':
1155
+ new_node = helper.make_node(node.op_type, ['trans_out'] + list(node.input[1:]), list(node.output),
1156
+ **{attr.name: _attr_to_dict(attr) for attr in node.attribute})
1157
+ else:
1158
+ new_node = node
1159
+ combined_nodes.append(new_node)
1160
+
1161
+ for init in cm_model.graph.initializer:
1162
+ combined_inits.append(init)
1163
+
1164
+ return mk(combined_nodes, combined_inits)
1165
+
1166
+ # ============================================================
1167
+ # CHANNEL REDUCTION WRAPPER
1168
+ # ============================================================
1169
+
1170
+ def _get_active_colors(td):
1171
+ """Returns set of all colors appearing in inputs and outputs."""
1172
+ colors = set()
1173
+ for ex in td['train'] + td['test']:
1174
+ for row in ex['input']:
1175
+ colors.update(row)
1176
+ for row in ex['output']:
1177
+ colors.update(row)
1178
+ return colors
1179
+
1180
+ def _build_channel_reduced_model(inner_model, input_colors, output_colors):
1181
+ """Wrap a model with channel reduction: Conv1x1(10->N) -> inner -> Conv1x1(N->10).
1182
+ This saves MACs when N < 10."""
1183
+ n_in = len(input_colors)
1184
+ n_out = len(output_colors)
1185
+
1186
+ # Maps from full 10 channels to reduced set
1187
+ in_map = sorted(input_colors)
1188
+ out_map = sorted(output_colors)
1189
+
1190
+ # W_reduce: [n_in, 10, 1, 1] - maps 10 channels to n_in
1191
+ W_reduce = np.zeros((n_in, 10, 1, 1), dtype=np.float32)
1192
+ for i, c in enumerate(in_map):
1193
+ W_reduce[i, c, 0, 0] = 1.0
1194
+
1195
+ # W_expand: [10, n_out, 1, 1] - maps n_out channels back to 10
1196
+ W_expand = np.zeros((10, n_out, 1, 1), dtype=np.float32)
1197
+ for i, c in enumerate(out_map):
1198
+ W_expand[c, i, 0, 0] = 1.0
1199
+
1200
+ # Build the wrapped model
1201
+ nodes = [
1202
+ helper.make_node('Conv', ['input', 'W_reduce'], ['reduced'], kernel_shape=[1,1]),
1203
+ ]
1204
+ inits = [numpy_helper.from_array(W_reduce, 'W_reduce')]
1205
+
1206
+ # Add inner model nodes with input='reduced' and output renamed
1207
+ for node in inner_model.graph.node:
1208
+ if node.input[0] == 'input':
1209
+ new_inputs = ['reduced'] + list(node.input[1:])
1210
+ else:
1211
+ new_inputs = list(node.input)
1212
+
1213
+ if node.output[0] == 'output':
1214
+ new_outputs = ['inner_out']
1215
+ else:
1216
+ new_outputs = list(node.output)
1217
+
1218
+ new_node = helper.make_node(node.op_type, new_inputs, new_outputs,
1219
+ **{attr.name: _attr_to_dict(attr) for attr in node.attribute})
1220
+ nodes.append(new_node)
1221
+
1222
+ for init in inner_model.graph.initializer:
1223
+ if init.name != 'W_reduce': # avoid conflict
1224
+ inits.append(init)
1225
+
1226
+ nodes.append(helper.make_node('Conv', ['inner_out', 'W_expand'], ['output'], kernel_shape=[1,1]))
1227
+ inits.append(numpy_helper.from_array(W_expand, 'W_expand'))
1228
+
1229
+ return mk(nodes, inits)
1230
+
1231
+ def _try_channel_reduction(solver_fn, td):
1232
+ """Try a solver with channel reduction wrapper if it reduces cost.
1233
+ NOTE: Currently disabled for Gather-based models (spatial_gather, etc.)
1234
+ as they hardcode channel=10 in Reshape operations."""
1235
+ model = solver_fn(td)
1236
+ if model is None: return None
1237
+
1238
+ # DISABLED: Channel reduction breaks Gather-based models
1239
+ # that reshape to [1,10,900]. Only applies to Conv-based models.
1240
+ # colors = _get_active_colors(td)
1241
+ # if len(colors) >= 8:
1242
+ # return model
1243
+ # try:
1244
+ # wrapped = _build_channel_reduced_model(model, colors, colors)
1245
+ # return wrapped
1246
+ # except Exception:
1247
+ # return model
1248
+
1249
+ return model
1250
+
1251
+ # ============================================================
1252
+ # CONV SOLVERS WITH LOOCV RIDGE + STRIDE TRICKS
1253
+ # ============================================================
1254
+
1255
+ def add_onehot_block(nodes, inits, am_name, oh_name):
1256
+ """Equal + Cast one-hot encoding (replaces OneHot which lacks CUDA kernel)."""
1257
+ classes = np.arange(10, dtype=np.int64).reshape(1, 10, 1, 1)
1258
+ inits.append(numpy_helper.from_array(classes, 'classes'))
1259
+ nodes.append(helper.make_node('Equal', [am_name, 'classes'], ['eq']))
1260
+ nodes.append(helper.make_node('Cast', ['eq'], [oh_name], to=TensorProto.FLOAT))
1261
+
1262
+ def _extract_patches_strided(oh_pad, ks, out_shape):
1263
+ """Extract patches using stride_tricks for speedup.
1264
+ oh_pad: [C, H+2p, W+2p] padded one-hot array
1265
+ ks: kernel size
1266
+ out_shape: (OH, OW) output shape
1267
+ Returns: patches array [OH*OW, C*ks*ks]
1268
+ """
1269
+ C, Hp, Wp = oh_pad.shape
1270
+ OH, OW = out_shape
1271
+
1272
+ # Use as_strided to create sliding window view over padded array
1273
+ stride_c = oh_pad.strides[0]
1274
+ stride_h = oh_pad.strides[1]
1275
+ stride_w = oh_pad.strides[2]
1276
+
1277
+ # Ensure base covers all needed elements: up to (OH-1+ks, OW-1+ks)
1278
+ needed_h = min(OH - 1 + ks, Hp)
1279
+ needed_w = min(OW - 1 + ks, Wp)
1280
+ base = oh_pad[:, :needed_h, :needed_w]
1281
+
1282
+ # Shape: [OH, OW, C, ks, ks]
1283
+ shape = (OH, OW, C, ks, ks)
1284
+ strides = (stride_h, stride_w, stride_c, stride_h, stride_w)
1285
+
1286
+ patches_view = np.lib.stride_tricks.as_strided(base, shape=shape, strides=strides)
1287
+ # Reshape to [OH*OW, C*ks*ks]
1288
+ return patches_view.reshape(OH * OW, C * ks * ks)
1289
+
1290
+ def _effective_rank(P):
1291
+ """Compute effective rank r(Σ) = Tr(Σ) / ‖Σ‖."""
1292
+ Sigma = np.cov(P, rowvar=False)
1293
+ evals = np.linalg.eigvalsh(Sigma)
1294
+ evals = evals[evals > 1e-12]
1295
+ if len(evals) == 0: return 0
1296
+ return np.sum(evals) / np.max(evals)
1297
+
1298
+ def _tune_ridge_loocv(P, T_oh, lambdas):
1299
+ """Find best λ using efficient LOOCV via Hat Matrix diagonal (SVD shortcut).
1300
+ Cawley & Talbot (2010), JMLR.
1301
+ """
1302
+ n, p = P.shape
1303
+ try:
1304
+ U, s, Vt = np.linalg.svd(P, full_matrices=False)
1305
+ except Exception:
1306
+ return None
1307
+
1308
+ best_lambda, min_err = None, float('inf')
1309
+
1310
+ for lam in lambdas:
1311
+ d = (s**2) / (s**2 + lam)
1312
+ y_hat = (U * d) @ (U.T @ T_oh)
1313
+ # Ridge hat matrix diagonal: h_ii = Σ_j U_ij^2 * s_j^2 / (s_j^2 + λ)
1314
+ h_ii = np.sum((U**2) * d[np.newaxis, :], axis=1)
1315
+
1316
+ # LOOCV shortcut: error_i = (y_i - ŷ_i) / (1 - h_ii)
1317
+ denom = 1 - h_ii
1318
+ denom = np.where(np.abs(denom) < 1e-10, 1e-10, denom)
1319
+ errors = (T_oh - y_hat) / denom[:, np.newaxis]
1320
+ mse = np.mean(errors**2)
1321
+
1322
+ if mse < min_err:
1323
+ min_err, best_lambda = mse, lam
1324
+
1325
+ return best_lambda
1326
+
1327
+ def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False, use_ridge=True):
1328
+ """Shared lstsq conv fitting with optional LOOCV Ridge tuning.
1329
+ Returns (Wconv, B) or None."""
1330
+ pad = ks // 2
1331
+ feat = 10 * ks * ks + (1 if use_bias else 0)
1332
+ if feat > 20000: return None
1333
+
1334
+ patches_list, targets = [], []
1335
+ for inp_g, out_g in exs_raw:
1336
+ ih, iw = inp_g.shape
1337
+ if use_full_30:
1338
+ oh_full = np.zeros((10, GH, GW), dtype=np.float64)
1339
+ for c in range(10): oh_full[c, :ih, :iw] = (inp_g == c)
1340
+ oh_pad = np.pad(oh_full, ((0,0),(pad,pad),(pad,pad)))
1341
+ else:
1342
+ oh_enc = np.zeros((10, ih, iw), dtype=np.float64)
1343
+ for c in range(10): oh_enc[c] = (inp_g == c)
1344
+ oh_pad = np.pad(oh_enc, ((0,0),(pad,pad),(pad,pad)))
1345
+
1346
+ oh, ow = out_g.shape
1347
+
1348
+ # Try stride_tricks for speedup
1349
+ try:
1350
+ patches = _extract_patches_strided(oh_pad, ks, (oh, ow))
1351
+ if use_bias:
1352
+ bias_col = np.ones((patches.shape[0], 1), dtype=np.float64)
1353
+ patches = np.concatenate([patches, bias_col], axis=1)
1354
+ patches_list.append(patches)
1355
+ targets.append(out_g.flatten())
1356
+ except Exception:
1357
+ # Fallback to loop-based extraction
1358
+ for r in range(oh):
1359
+ for c in range(ow):
1360
+ p = oh_pad[:, r:r+ks, c:c+ks].flatten()
1361
+ if use_bias: p = np.append(p, 1.0)
1362
+ patches_list.append(p)
1363
+ targets.append(int(out_g[r, c]))
1364
+
1365
+ if len(patches_list) > 0 and isinstance(patches_list[0], np.ndarray) and patches_list[0].ndim == 2:
1366
+ P = np.concatenate(patches_list, axis=0)
1367
+ T = np.concatenate(targets)
1368
+ else:
1369
+ P = np.array(patches_list, dtype=np.float64)
1370
+ T = np.array(targets, dtype=np.int64)
1371
+
1372
+ n_patches = P.shape[0]
1373
+ if feat > 5000 and n_patches > 2000: return None
1374
+
1375
+ T_oh = np.zeros((len(T), 10), dtype=np.float64)
1376
+ for i, t in enumerate(T): T_oh[i, t] = 1.0
1377
+
1378
+ # Quick condition number estimate using norm ratio (cheaper than full SVD)
1379
+ # Only skip if clearly pathological; otherwise try lstsq
1380
+ cond_estimate = None
1381
+ try:
1382
+ # Use 2-norm estimate: cond ≈ ||P||_2 * ||P^+||_2 ≈ max_singular / min_singular
1383
+ # We approximate with norm ratios for speed
1384
+ p_norm = np.linalg.norm(P, 2)
1385
+ if p_norm > 0:
1386
+ # Estimate using power method approximation or just try lstsq
1387
+ pass # Don't waste time on condition number - lstsq will handle it
1388
+ except Exception:
1389
+ pass
1390
+
1391
+ if use_ridge and n_patches <= feat * 1.5:
1392
+ # Use LOOCV Ridge tuning when system is underdetermined or near interpolation threshold
1393
+ lambdas = np.logspace(-4, 2, 10)
1394
+ best_lam = _tune_ridge_loocv(P, T_oh, lambdas)
1395
+ if best_lam is not None:
1396
+ # Ridge solve: (P^T P + λI)^-1 P^T T
1397
+ try:
1398
+ WT = np.linalg.solve(P.T @ P + best_lam * np.eye(P.shape[1]), P.T @ T_oh)
1399
+ except Exception:
1400
+ WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
1401
+ else:
1402
+ WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
1403
+ else:
1404
+ WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
1405
+
1406
+ if not np.array_equal(np.argmax(P @ WT, axis=1), T): return None
1407
+
1408
+ if use_bias:
1409
+ Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
1410
+ B = WT[-1].astype(np.float32)
1411
+ else:
1412
+ Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
1413
+ B = None
1414
+ return Wconv, B
1415
+
1416
+ # ============================================================
1417
+ # CONV SOLVER WRAPPERS
1418
+ # ============================================================
1419
+
1420
+ def _get_ks_for_budget(time_budget):
1421
+ """Return kernel sizes to try based on time budget."""
1422
+ if time_budget < 5:
1423
+ return [1, 3, 5]
1424
+ elif time_budget < 10:
1425
+ return [1, 3, 5, 7, 9]
1426
+ elif time_budget < 20:
1427
+ return [1, 3, 5, 7, 9, 11, 13, 15, 17]
1428
+ else:
1429
+ return [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]
1430
+
1431
+ def solve_conv_fixed(td, path, time_budget=30.0):
1432
+ """Fixed-shape conv: Slice -> Conv -> ArgMax -> Equal+Cast -> Pad."""
1433
+ exs = get_exs(td)
1434
+ for inp, out in exs:
1435
+ if inp.shape != out.shape: return None
1436
+ shapes = set(inp.shape for inp, _ in exs)
1437
+ if len(shapes) != 1: return None
1438
+ IH, IW = shapes.pop()
1439
+
1440
+ fit_exs = get_exs_for_fitting(td)
1441
+ fit_exs = [(i,o) for i,o in fit_exs if i.shape == o.shape and i.shape == (IH, IW)]
1442
+
1443
+ t_start = time.time()
1444
+ for use_bias in [False, True]:
1445
+ for ks in _get_ks_for_budget(time_budget):
1446
+ if time.time() - t_start > time_budget: return None
1447
+ result = _lstsq_conv(fit_exs, ks, use_bias, use_full_30=False)
1448
+ if result is None: continue
1449
+ Wconv, B = result
1450
+ pad = ks // 2
1451
+ pad_h, pad_w = GH - IH, GW - IW
1452
+
1453
+ inits = [
1454
+ numpy_helper.from_array(np.array([0,0,0,0], dtype=np.int64), 'sl_st'),
1455
+ numpy_helper.from_array(np.array([1,10,IH,IW], dtype=np.int64), 'sl_en'),
1456
+ numpy_helper.from_array(Wconv, 'W'),
1457
+ ]
1458
+ conv_inputs = ['grid', 'W']
1459
+ if B is not None:
1460
+ inits.append(numpy_helper.from_array(B, 'B'))
1461
+ conv_inputs.append('B')
1462
+
1463
+ nodes = [
1464
+ helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
1465
+ helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
1466
+ helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
1467
+ ]
1468
+ add_onehot_block(nodes, inits, 'am', 'oh_out')
1469
+ cf_pads = numpy_helper.from_array(np.array([0,0,0,0,0,0,pad_h,pad_w], dtype=np.int64), 'cf_pads')
1470
+ inits.append(cf_pads)
1471
+ nodes.append(
1472
+ helper.make_node('Pad', ['oh_out', 'cf_pads'], ['output'], mode='constant')
1473
+ )
1474
+
1475
+ model = mk(nodes, inits)
1476
+ onnx.save(model, path)
1477
+ if validate(path, td): return 'conv_fixed', model
1478
+ return None
1479
+
1480
+ def solve_conv_variable(td, path, time_budget=30.0):
1481
+ """Variable-shape conv: Conv(30x30) -> ArgMax -> Equal+Cast -> Mul(mask)."""
1482
+ exs = get_exs(td)
1483
+ for inp, out in exs:
1484
+ if inp.shape != out.shape: return None
1485
+
1486
+ fit_exs = get_exs_for_fitting_variable(td)
1487
+ fit_exs = [(i,o) for i,o in fit_exs if i.shape == o.shape]
1488
+
1489
+ t_start = time.time()
1490
+ for use_bias in [False, True]:
1491
+ for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
1492
+ if time.time() - t_start > time_budget: return None
1493
+ result = _lstsq_conv(fit_exs, ks, use_bias, use_full_30=True)
1494
+ if result is None: continue
1495
+ Wconv, B = result
1496
+ pad = ks // 2
1497
+
1498
+ inits = [numpy_helper.from_array(Wconv, 'W')]
1499
+ conv_inputs = ['input', 'W']
1500
+ if B is not None:
1501
+ inits.append(numpy_helper.from_array(B, 'B'))
1502
+ conv_inputs.append('B')
1503
+
1504
+ nodes = [
1505
+ helper.make_node('ReduceSum', ['input'], ['mask'], axes=[1], keepdims=1),
1506
+ helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
1507
+ helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
1508
+ ]
1509
+ add_onehot_block(nodes, inits, 'am', 'oh_out')
1510
+ nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
1511
+
1512
+ model = mk(nodes, inits)
1513
+ onnx.save(model, path)
1514
+ if validate(path, td): return 'conv_var', model
1515
+ return None
1516
+
1517
+ def solve_conv_diffshape(td, path, time_budget=30.0):
1518
+ """Diff-shape conv for fixed io shapes where output is smaller."""
1519
+ sp = fixed_shapes(td)
1520
+ if sp is None: return None
1521
+ (IH, IW), (OH, OW) = sp
1522
+ if IH == OH and IW == OW: return None
1523
+ if OH > IH or OW > IW: return None
1524
+ if OH > 30 or OW > 30: return None
1525
+
1526
+ exs = get_exs(td)
1527
+ t_start = time.time()
1528
+
1529
+ for dr_off, dc_off in [(0, 0), ((IH-OH)//2, (IW-OW)//2)]:
1530
+ for use_bias in [False, True]:
1531
+ for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21]:
1532
+ if time.time() - t_start > time_budget: return None
1533
+ pad = ks // 2
1534
+ feat = 10 * ks * ks + (1 if use_bias else 0)
1535
+ if feat > 10000: continue
1536
+
1537
+ patches, targets = [], []
1538
+ valid = True
1539
+ for inp_g, out_g in exs:
1540
+ oh_enc = np.zeros((10, IH, IW), dtype=np.float64)
1541
+ for c in range(10): oh_enc[c] = (inp_g == c)
1542
+ oh_pad = np.pad(oh_enc, ((0,0),(pad,pad),(pad,pad)))
1543
+ for r in range(OH):
1544
+ for c in range(OW):
1545
+ sr, sc = r + dr_off, c + dc_off
1546
+ if sr < 0 or sr >= IH or sc < 0 or sc >= IW:
1547
+ valid = False; break
1548
+ p = oh_pad[:, sr:sr+ks, sc:sc+ks].flatten()
1549
+ if use_bias: p = np.append(p, 1.0)
1550
+ patches.append(p)
1551
+ targets.append(int(out_g[r, c]))
1552
+ if not valid: break
1553
+ if not valid: break
1554
+ if not valid: continue
1555
+
1556
+ n_patches = len(patches)
1557
+ if feat > 5000 and n_patches > 2000: continue
1558
+
1559
+ P = np.array(patches, dtype=np.float64)
1560
+ T = np.array(targets, dtype=np.int64)
1561
+ T_oh = np.zeros((len(T), 10), dtype=np.float64)
1562
+ for i, t in enumerate(T): T_oh[i, t] = 1.0
1563
+
1564
+ WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
1565
+ if not np.array_equal(np.argmax(P @ WT, axis=1), T): continue
1566
+
1567
+ if use_bias:
1568
+ Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
1569
+ B = WT[-1].astype(np.float32)
1570
+ else:
1571
+ Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
1572
+ B = None
1573
+
1574
+ pad_h, pad_w = GH - OH, GW - OW
1575
+ inits = [
1576
+ numpy_helper.from_array(np.array([0,0,0,0], dtype=np.int64), 'sl_st'),
1577
+ numpy_helper.from_array(np.array([1,10,IH,IW], dtype=np.int64), 'sl_en'),
1578
+ numpy_helper.from_array(Wconv, 'W'),
1579
+ numpy_helper.from_array(np.array([0,0,dr_off,dc_off], dtype=np.int64), 'cr_st'),
1580
+ numpy_helper.from_array(np.array([1,10,dr_off+OH,dc_off+OW], dtype=np.int64), 'cr_en'),
1581
+ ]
1582
+ conv_inputs = ['grid', 'W']
1583
+ if B is not None:
1584
+ inits.append(numpy_helper.from_array(B, 'B'))
1585
+ conv_inputs.append('B')
1586
+
1587
+ nodes = [
1588
+ helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
1589
+ helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
1590
+ helper.make_node('Slice', ['co','cr_st','cr_en'], ['co_crop']),
1591
+ helper.make_node('ArgMax', ['co_crop'], ['am'], axis=1, keepdims=1),
1592
+ ]
1593
+ add_onehot_block(nodes, inits, 'am', 'oh_out')
1594
+ diff_pads = numpy_helper.from_array(np.array([0,0,0,0,0,0,pad_h,pad_w], dtype=np.int64), 'diff_pads')
1595
+ inits.append(diff_pads)
1596
+ nodes.append(
1597
+ helper.make_node('Pad', ['oh_out', 'diff_pads'], ['output'], mode='constant')
1598
+ )
1599
+
1600
+ model = mk(nodes, inits)
1601
+ onnx.save(model, path)
1602
+ if validate(path, td): return 'conv_diff', model
1603
+ return None
1604
+
1605
+ def solve_conv_var_diff(td, path, time_budget=30.0):
1606
+ """Variable diff-shape conv."""
1607
+ exs = get_exs(td)
1608
+
1609
+ t_start = time.time()
1610
+ for use_bias in [False, True]:
1611
+ for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
1612
+ if time.time() - t_start > time_budget: return None
1613
+
1614
+ pad = ks // 2
1615
+ feat = 10 * ks * ks + (1 if use_bias else 0)
1616
+ if feat > 20000: continue
1617
+
1618
+ patches, targets = [], []
1619
+ for inp_g, out_g in exs:
1620
+ ih, iw = inp_g.shape
1621
+ oh, ow = out_g.shape
1622
+ oh_full = np.zeros((10, GH, GW), dtype=np.float64)
1623
+ for c in range(10): oh_full[c, :ih, :iw] = (inp_g == c)
1624
+ oh_pad = np.pad(oh_full, ((0,0),(pad,pad),(pad,pad)))
1625
+
1626
+ for r in range(oh):
1627
+ for c in range(ow):
1628
+ p = oh_pad[:, r:r+ks, c:c+ks].flatten()
1629
+ if use_bias: p = np.append(p, 1.0)
1630
+ patches.append(p)
1631
+ targets.append(int(out_g[r, c]))
1632
+
1633
+ n_patches = len(patches)
1634
+ if feat > 5000 and n_patches > 2000: continue
1635
+
1636
+ P = np.array(patches, dtype=np.float64)
1637
+ T = np.array(targets, dtype=np.int64)
1638
+ T_oh = np.zeros((len(T), 10), dtype=np.float64)
1639
+ for i, t in enumerate(T): T_oh[i, t] = 1.0
1640
+
1641
+ try:
1642
+ WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
1643
+ except:
1644
+ continue
1645
+ if not np.array_equal(np.argmax(P @ WT, axis=1), T): continue
1646
+
1647
+ if use_bias:
1648
+ Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
1649
+ B = WT[-1].astype(np.float32)
1650
+ else:
1651
+ Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
1652
+ B = None
1653
+
1654
+ all_output_within_input = all(
1655
+ out_g.shape[0] <= inp_g.shape[0] and out_g.shape[1] <= inp_g.shape[1]
1656
+ for inp_g, out_g in exs
1657
+ )
1658
+
1659
+ if not all_output_within_input:
1660
+ continue
1661
+
1662
+ inits = [numpy_helper.from_array(Wconv, 'W')]
1663
+ conv_inputs = ['input', 'W']
1664
+ if B is not None:
1665
+ inits.append(numpy_helper.from_array(B, 'B'))
1666
+ conv_inputs.append('B')
1667
+
1668
+ nodes = [
1669
+ helper.make_node('ReduceSum', ['input'], ['mask'], axes=[1], keepdims=1),
1670
+ helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
1671
+ helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
1672
+ ]
1673
+ add_onehot_block(nodes, inits, 'am', 'oh_out')
1674
+ nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
1675
+
1676
+ model = mk(nodes, inits)
1677
+ onnx.save(model, path)
1678
+ if validate(path, td): return 'conv_var_diff', model
1679
+ return None
1680
+
1681
+ # ============================================================
1682
+ # MAIN SOLVER PIPELINE
1683
+ # ============================================================
1684
+
1685
+ ANALYTICAL_SOLVERS = [
1686
+ ('identity', s_identity),
1687
+ ('constant', s_constant),
1688
+ ('color_map', s_color_map),
1689
+ ('transpose', s_transpose),
1690
+ ('flip', s_flip),
1691
+ ('rotate', s_rotate),
1692
+ ('tile', s_tile),
1693
+ ('upscale', s_upscale),
1694
+ ('kronecker', s_kronecker),
1695
+ ('nonuniform_scale', s_nonuniform_scale),
1696
+ ('mirror_h', s_mirror_h),
1697
+ ('mirror_v', s_mirror_v),
1698
+ ('quad_mirror', s_quad_mirror),
1699
+ ('concat', s_concat),
1700
+ ('concat_enhanced', s_concat_enhanced),
1701
+ ('diagonal_tile', s_diagonal_tile),
1702
+ ('fixed_crop', s_fixed_crop),
1703
+ ('spatial_gather', s_spatial_gather),
1704
+ ('shift', s_shift),
1705
+ ('varshape_spatial_gather', s_varshape_spatial_gather),
1706
+ ]
1707
+
1708
+ COMPOSITION_SOLVERS = [
1709
+ ('rotate_color', s_composition_rotate_color),
1710
+ ('flip_color', s_composition_flip_color),
1711
+ ('transpose_color', s_composition_transpose_color),
1712
+ ]
1713
+
1714
+ def solve_task(tn, td, outdir, conv_budget=30.0, use_channel_reduction=True):
1715
+ t_start = time.time()
1716
+ os.makedirs(outdir, exist_ok=True)
1717
+ path = os.path.join(outdir, f"task{tn:03d}.onnx")
1718
+
1719
+ if tn in EXCLUDED_TASKS:
1720
+ return False, 'excluded', None, time.time() - t_start, path
1721
+
1722
+ # 1. Try analytical solvers (fast, tiny models)
1723
+ for sname, sfn in ANALYTICAL_SOLVERS:
1724
+ try:
1725
+ if use_channel_reduction and sname in ('transpose', 'flip', 'rotate', 'mirror_h', 'mirror_v', 'quad_mirror', 'shift', 'spatial_gather', 'varshape_spatial_gather'):
1726
+ model = _try_channel_reduction(sfn, td)
1727
+ else:
1728
+ model = sfn(td)
1729
+ if model is None: continue
1730
+ onnx.save(model, path)
1731
+ if validate(path, td):
1732
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
1733
+ except Exception as e:
1734
+ pass
1735
+
1736
+ # 2. Try composition solvers
1737
+ for sname, sfn in COMPOSITION_SOLVERS:
1738
+ try:
1739
+ model = sfn(td)
1740
+ if model is None: continue
1741
+ onnx.save(model, path)
1742
+ if validate(path, td):
1743
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
1744
+ except Exception:
1745
+ pass
1746
+
1747
+ # 3. Determine task shape category and try conv solvers
1748
+ exs = get_exs(td)
1749
+ same_shape = all(inp.shape == out.shape for inp, out in exs)
1750
+ shapes = set(inp.shape for inp, _ in exs)
1751
+ fixed_in = len(shapes) == 1
1752
+
1753
+ conv_time = conv_budget
1754
+
1755
+ if same_shape:
1756
+ if fixed_in:
1757
+ result = solve_conv_fixed(td, path, time_budget=conv_time/2)
1758
+ if result is not None:
1759
+ sname, model = result
1760
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
1761
+ result = solve_conv_variable(td, path, time_budget=conv_time)
1762
+ if result is not None:
1763
+ sname, model = result
1764
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
1765
+ else:
1766
+ sp = fixed_shapes(td)
1767
+ if sp is not None:
1768
+ (IH,IW),(OH,OW) = sp
1769
+ if OH <= IH and OW <= IW:
1770
+ result = solve_conv_diffshape(td, path, time_budget=conv_time)
1771
+ if result is not None:
1772
+ sname, model = result
1773
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
1774
+
1775
+ result = solve_conv_var_diff(td, path, time_budget=conv_time)
1776
+ if result is not None:
1777
+ sname, model = result
1778
+ return True, sname, os.path.getsize(path), time.time() - t_start, path
1779
+
1780
+ return False, None, None, time.time() - t_start, path
1781
+
1782
+ def run_tasks(task_nums, tasks, output_dir, conv_budget, use_wandb, use_channel_reduction=True):
1783
+ results = {}
1784
+ costs_dict = {}
1785
+ total_score = 0
1786
+ for tn in task_nums:
1787
+ if tn not in tasks:
1788
+ continue
1789
+ if tn in EXCLUDED_TASKS:
1790
+ print(f"Task {tn:3d}: EXCLUDED (officially)")
1791
+ continue
1792
+
1793
+ td = tasks[tn]['data']
1794
+ ok, sname, sz, t_task, model_path = solve_task(tn, td, output_dir, conv_budget, use_channel_reduction)
1795
+
1796
+ if ok:
1797
+ macs, memory, params = score_network(model_path)
1798
+ if macs is None:
1799
+ macs, memory, params = 0, 0, 0
1800
+ cost = macs + memory + params
1801
+ score = max(1.0, 25.0 - math.log(max(1, cost)))
1802
+ total_score += score
1803
+
1804
+ results[tn] = (sname, t_task, sz)
1805
+ costs_dict[tn] = cost
1806
+ print(f"Task {tn:3d}: {sname:25s} {score:7.3f} {cost:>12} {t_task:7.3f}s ({sz:>8,} bytes)")
1807
+ else:
1808
+ print(f"Task {tn:3d}: UNSOLVED {t_task:7.3f}s")
1809
+ cost = 0
1810
+
1811
+ if use_wandb and wandb is not None:
1812
+ wandb.log({
1813
+ "task_id": tn,
1814
+ "solver": sname if ok else "unsolved",
1815
+ "onnx_bytes": sz if ok else 0,
1816
+ "task_time_sec": t_task,
1817
+ "cost": cost,
1818
+ "score": score if ok else 0,
1819
+ })
1820
+
1821
+ return results, costs_dict, total_score
1822
+
1823
+
1824
+ def main():
1825
+ parser = argparse.ArgumentParser()
1826
+ parser.add_argument('--data_dir', default='ARC-AGI/data/training/')
1827
+ parser.add_argument('--arcgen_dir', default='', help='Path to ARC-GEN-100K/ directory')
1828
+ parser.add_argument('--output_dir', default='submission')
1829
+ parser.add_argument('--kaggle', action='store_true')
1830
+ parser.add_argument('--conv_budget', type=float, default=30.0)
1831
+ parser.add_argument('--tasks', type=str, default='')
1832
+ parser.add_argument('--device', type=str, default='auto', choices=['auto','cpu','cuda'])
1833
+ parser.add_argument('--use_wandb', action='store_true')
1834
+ parser.add_argument('--no_channel_reduction', action='store_true', help='Disable channel reduction wrapper')
1835
+ args = parser.parse_args()
1836
+ global ORT_PROVIDERS
1837
+ config = {
1838
+ "device": args.device,
1839
+ "conv_budget": args.conv_budget,
1840
+ "data_dir": args.data_dir,
1841
+ "arcgen_dir": args.arcgen_dir,
1842
+ "tasks": args.tasks,
1843
+ }
1844
+
1845
+ if args.device == 'cuda':
1846
+ ORT_PROVIDERS = ['CUDAExecutionProvider', 'CPUExecutionProvider']
1847
+ elif args.device == 'cpu':
1848
+ ORT_PROVIDERS = ['CPUExecutionProvider']
1849
+
1850
+ ort.set_default_logger_severity(3)
1851
+ print(f"Using providers: {ORT_PROVIDERS}")
1852
+ print(f"OPSET: 17 (v5)")
1853
+
1854
+ if args.kaggle:
1855
+ tasks = load_tasks_kaggle(args.data_dir)
1856
+ else:
1857
+ arcgen = args.arcgen_dir if args.arcgen_dir else None
1858
+ tasks = load_tasks_dir(args.data_dir, arcgen_dir=arcgen)
1859
+
1860
+ total_arcgen = sum(len(t['data'].get('arc-gen', [])) for t in tasks.values())
1861
+ print(f"Loaded {len(tasks)} tasks ({total_arcgen} ARC-GEN examples)")
1862
+ print(f"Excluded tasks: {sorted(EXCLUDED_TASKS)}")
1863
+
1864
+ task_nums = [int(t) for t in args.tasks.split(',')] if args.tasks else sorted(tasks.keys())
1865
+ active_tasks = [t for t in task_nums if t not in EXCLUDED_TASKS]
1866
+ print(f"Solving {len(active_tasks)} active tasks (skipping {len(task_nums) - len(active_tasks)} excluded)")
1867
+ print(f"Conv budget: {args.conv_budget}s per task")
1868
+ print(f"Channel reduction: {'enabled' if not args.no_channel_reduction else 'disabled'}")
1869
+ print("=" * 70)
1870
+ t0 = time.time()
1871
+
1872
+ use_ch_red = not args.no_channel_reduction
1873
+
1874
+ if args.use_wandb and wandb is not None:
1875
+ with wandb.init(project="neurogolf", name="solver_run", config=config):
1876
+ results, costs_dict, total_score = run_tasks(task_nums, tasks, args.output_dir, args.conv_budget, use_wandb=True, use_channel_reduction=use_ch_red)
1877
+ else:
1878
+ results, costs_dict, total_score = run_tasks(task_nums, tasks, args.output_dir, args.conv_budget, use_wandb=False, use_channel_reduction=use_ch_red)
1879
+
1880
+ elapsed = time.time() - t0
1881
+ print(f"\n{'='*70}")
1882
+ print(f"Solved: {len(results)}/{len(active_tasks)} active tasks in {elapsed:.0f}s")
1883
+ solver_names = [v[0] for v in results.values()]
1884
+ sc = Counter(solver_names)
1885
+ for s, c in sc.most_common(): print(f" {s}: {c}")
1886
+
1887
+ outdir = args.output_dir
1888
+ n_files = len([f for f in os.listdir(outdir) if f.endswith('.onnx')])
1889
+ total_size = sum(os.path.getsize(os.path.join(outdir, f))
1890
+ for f in os.listdir(outdir) if f.endswith('.onnx'))
1891
+
1892
+ zip_path = os.path.join(os.path.dirname(outdir) or '.', 'submission.zip')
1893
+ buf = io.BytesIO()
1894
+ with zipfile.ZipFile(buf, 'w', zipfile.ZIP_DEFLATED) as zf:
1895
+ for f in sorted(os.listdir(outdir)):
1896
+ if f.endswith('.onnx'):
1897
+ zf.write(os.path.join(outdir, f), f)
1898
+ zip_bytes = buf.getvalue()
1899
+ with open(zip_path, 'wb') as f:
1900
+ f.write(zip_bytes)
1901
+ zip_size = len(zip_bytes)
1902
+
1903
+ csv_path = os.path.join(os.path.dirname(outdir) or '.', 'submission.csv')
1904
+ with open(csv_path, 'w', newline='') as f:
1905
+ w = csv.writer(f)
1906
+ w.writerow(['task_id', 'total_cost'])
1907
+ for tn in sorted(costs_dict.keys()):
1908
+ w.writerow([f'task{tn:03d}', costs_dict[tn]])
1909
+
1910
+ unsolved_count = len(active_tasks) - len(results)
1911
+ est_lb = total_score + unsolved_count * 1.0
1912
+
1913
+ print(f"\n{n_files} ONNX files, {total_size/1024:.1f} KB uncompressed")
1914
+ print(f"ZIP size: {zip_size/1024:.1f} KB / {MAX_FILESIZE/1024:.0f} KB limit {'OK' if zip_size <= MAX_FILESIZE else 'OVER!'}")
1915
+ print(f"Estimated LB score: {est_lb:.1f} (solved: {total_score:.1f} + unsolved: {unsolved_count}×1.0)")
1916
+ print(f"Written: {zip_path} | {csv_path}")
1917
+
1918
+ if __name__ == '__main__':
1919
+ main()