rogermt committed on
Commit
b06267e
·
verified ·
1 Parent(s): ff5c300

Delete neurogolf_solver.py

Browse files
Files changed (1) hide show
  1. neurogolf_solver.py +0 -1364
neurogolf_solver.py DELETED
@@ -1,1364 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- ARC-AGI NeuroGolf Championship - Complete Solver v5
4
- Format: [1,10,30,30] one-hot input/output, opset 17, IR version 8.
5
- v5 CHANGES (from v4):
6
- - Opset 10 → 17, IR 10 → 8
7
- - s_flip: Slice(step=-1) replaces Gather — 0 MACs (was ~165K)
8
- - s_rotate k=2: double Slice(step=-1) — 0 MACs (was ~165K)
9
- - s_rotate k=1,3: Slice+Transpose for square grids (0 MACs), Gather fallback for non-square
10
- - All Pad nodes: tensor-based pads input (opset 17 requirement)
11
- - All ReduceSum nodes: axes as tensor input (opset 13+ requirement)
12
- - All other solvers unchanged from v4
13
- Solvers:
14
- - Analytical: identity, constant, color_map, transpose, flip, rotate, tile, upscale,
15
- concat, concat_enhanced, spatial_gather, varshape_spatial_gather,
16
- diagonal_tile, kronecker
17
- - Conv (fixed shape): Slice -> Conv -> ArgMax -> Equal+Cast -> Pad
18
- - Conv (variable shape): Conv(30x30) -> ArgMax -> Equal+Cast -> Mul(mask)
19
- - Conv (diff shape): Slice -> Conv -> Slice(crop) -> ArgMax -> Equal+Cast -> Pad
20
- Usage:
21
- python neurogolf_solver.py --data_dir ARC-AGI/data/training/ --output_dir submission
22
- python neurogolf_solver.py --data_dir ARC-AGI/data/training/ --output_dir submission --conv_budget 60 --arcgen_dir ARC-GEN-100K/
23
- """
24
-
25
- import json, os, sys, math, time, argparse, csv, io, zipfile
26
- import numpy as np
27
- import onnx
28
- from onnx import helper, TensorProto, numpy_helper
29
- import onnxruntime as ort
30
- from collections import Counter
31
-
32
- try:
33
- from neurogolf_utils import score_network as _score_network_official
34
- HAS_ONNX_TOOL = True
35
- except ImportError:
36
- HAS_ONNX_TOOL = False
37
-
38
- try:
39
- import wandb
40
- except ImportError:
41
- wandb = None
42
-
43
# Fixed competition tensor layout: batch=1, 10 color channels, 30x30 grid.
BATCH, CH, GH, GW = 1, 10, 30, 30
GRID_SHAPE = [BATCH, CH, GH, GW]
DT = TensorProto.FLOAT
IR = 8  # ONNX IR version required by the contest format
OPSET = [helper.make_opsetid("", 17)]

# Sentinel "end" value for Slice with step=-1: slicing to INT64_MIN
# includes index 0 (an end of -1 would stop one element short).
INT64_MIN = int(np.iinfo(np.int64).min)

EXCLUDED_TASKS = {21, 55, 80, 184, 202, 366}  # task numbers skipped entirely
MAX_ARCGEN_VALIDATE = 30  # cap on arc-gen examples checked during validation
MAX_ARCGEN_FIT = 0        # cap on arc-gen examples used while fitting
54
-
55
def get_providers():
    """Return the ONNX Runtime execution providers to use (CPU only)."""
    providers = ['CPUExecutionProvider']
    return providers
57
-
58
# Resolved once at import time; every InferenceSession shares this list.
ORT_PROVIDERS = get_providers()
59
-
60
- # ============================================================
61
- # LOAD / VALIDATE
62
- # ============================================================
63
-
64
def load_tasks_dir(data_dir, arcgen_dir=None):
    """Load every ``*.json`` task in *data_dir*, keyed by 1-based position.

    When *arcgen_dir* contains a same-named file holding a JSON list, it is
    attached as the task's ``'arc-gen'`` examples; otherwise ``'arc-gen'``
    defaults to an empty list.
    """
    tasks = {}
    names = sorted(n for n in os.listdir(data_dir) if n.endswith('.json'))
    for pos, name in enumerate(names, start=1):
        with open(os.path.join(data_dir, name)) as fh:
            data = json.load(fh)
        if arcgen_dir:
            extra_path = os.path.join(arcgen_dir, name)
            if os.path.exists(extra_path):
                with open(extra_path) as fh:
                    extra = json.load(fh)
                # Only a JSON list is a valid arc-gen payload.
                if isinstance(extra, list):
                    data['arc-gen'] = extra
        data.setdefault('arc-gen', [])
        tasks[pos] = {'hex': name.replace('.json', ''), 'data': data}
    return tasks
80
-
81
def load_tasks_kaggle(data_dir):
    """Load Kaggle-layout tasks ``task001.json``..``task400.json``.

    Missing files are skipped; every loaded task is guaranteed an
    ``'arc-gen'`` list (empty when absent).
    """
    tasks = {}
    for tn in range(1, 401):
        path = os.path.join(data_dir, f"task{tn:03d}.json")
        if not os.path.exists(path):
            continue
        with open(path) as f:
            data = json.load(f)
        data.setdefault('arc-gen', [])
        tasks[tn] = {'hex': f'task{tn:03d}', 'data': data}
    return tasks
92
-
93
def to_onehot(grid):
    """Encode a 2-D list of color ids as a [1, CH, GH, GW] one-hot float32 array.

    Cells beyond the 30x30 canvas or with values outside 0..CH-1 are silently
    ignored.
    """
    arr = np.zeros((1, CH, GH, GW), dtype=np.float32)
    for r, row in enumerate(grid):
        if r >= GH:
            continue
        for c, v in enumerate(row):
            if c < GW and 0 <= v < CH:
                arr[0, v, r, c] = 1.0
    return arr
100
-
101
def validate(path, td):
    """Check the ONNX model file at *path* against task data *td*.

    Runs the model on every train, test, and (up to MAX_ARCGEN_VALIDATE)
    arc-gen example; the raw output is thresholded at 0 before comparison.
    Returns True only when every example matches exactly; any session or
    inference failure counts as a mismatch.
    """
    try:
        opts = ort.SessionOptions()
        opts.log_severity_level = 3  # errors only; silence warnings
        sess = ort.InferenceSession(path, sess_options=opts, providers=ORT_PROVIDERS)
    except Exception:  # was a bare except: would also swallow KeyboardInterrupt
        return False
    examples = td['train'] + td['test']
    if 'arc-gen' in td:
        examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]
    for ex in examples:
        inp = to_onehot(ex['input'])
        exp = to_onehot(ex['output'])
        try:
            out = sess.run(['output'], {'input': inp})[0]
            out = (out > 0.0).astype(np.float32)  # binarize before exact compare
        except Exception:
            return False
        if not np.array_equal(out, exp):
            return False
    return True
122
-
123
def validate_raw(raw_bytes, td):
    """Like ``validate`` but loads the model from in-memory bytes.

    Kept as a separate function (rather than sharing code with ``validate``)
    so each stands alone; the logic is intentionally identical apart from the
    session source. Returns True only when every example matches exactly.
    """
    try:
        opts = ort.SessionOptions()
        opts.log_severity_level = 3  # errors only; silence warnings
        sess = ort.InferenceSession(raw_bytes, sess_options=opts, providers=ORT_PROVIDERS)
    except Exception:  # was a bare except: would also swallow KeyboardInterrupt
        return False
    examples = td['train'] + td['test']
    if 'arc-gen' in td:
        examples = examples + td['arc-gen'][:MAX_ARCGEN_VALIDATE]
    for ex in examples:
        inp = to_onehot(ex['input'])
        exp = to_onehot(ex['output'])
        try:
            out = sess.run(['output'], {'input': inp})[0]
            out = (out > 0.0).astype(np.float32)  # binarize before exact compare
        except Exception:
            return False
        if not np.array_equal(out, exp):
            return False
    return True
144
-
145
- # ============================================================
146
- # STATIC PROFILER
147
- # ============================================================
148
-
149
# Op types the static profiler rejects outright (dynamic control flow / shape).
BANNED_OPS = {'Loop', 'Scan', 'NonZero', 'Unique', 'If', 'Function'}
# Maximum allowed model file size in bytes (1.44 * 1024 * 1024).
MAX_FILESIZE = int(1.44 * 1024 * 1024)
151
-
152
def score_network(path):
    """Score an ONNX file, returning (macs, weight_bytes, param_count).

    Prefers the official ``neurogolf_utils`` scorer when importable; on any
    scorer failure (or when unavailable) falls back to the local static
    profiler.
    """
    if HAS_ONNX_TOOL:
        try:
            return _score_network_official(path)
        except Exception:  # was a bare except: fall back on any scorer error
            pass
    return _static_profile(path)
159
-
160
def _static_profile(path):
    """Fallback static profiler: return (macs, weight_bytes, param_count).

    Returns (None, None, None) when the model cannot be loaded or contains a
    banned op. MACs are counted only for Conv nodes whose 4-D weights are
    known initializers/Constants, assuming a full 30x30 output map.
    """
    try:
        model = onnx.load(path)
    except Exception:  # was a bare except: unreadable/corrupt file
        return None, None, None
    tensors = {}
    params = 0
    nbytes = 0
    macs = 0
    for init in model.graph.initializer:
        a = numpy_helper.to_array(init)
        tensors[init.name] = a
        params += a.size
        nbytes += a.nbytes
    for nd in model.graph.node:
        if nd.op_type == 'Constant':
            # Constant-node tensors also count toward parameters/bytes.
            for attr in nd.attribute:
                if attr.t and attr.t.ByteSize() > 0:
                    try:
                        a = numpy_helper.to_array(attr.t)
                        if nd.output:
                            tensors[nd.output[0]] = a
                        params += a.size
                        nbytes += a.nbytes
                    except Exception:  # was a bare except: skip non-tensor payloads
                        pass
        if nd.op_type in BANNED_OPS:
            return None, None, None
        if nd.op_type == 'Conv' and len(nd.input) >= 2 and nd.input[1] in tensors:
            w = tensors[nd.input[1]]
            if w.ndim == 4:
                co, ci, kh, kw = w.shape
                macs += co * ci * kh * kw * GH * GW
    return int(macs), int(nbytes), int(params)
194
-
195
- # ============================================================
196
- # OPSET 17 HELPERS
197
- # ============================================================
198
-
199
def _make_int64_init(name, values):
    """Build an int64 initializer tensor named *name* from *values*."""
    arr = np.array(values, dtype=np.int64)
    return numpy_helper.from_array(arr, name)
201
-
202
def _build_pad_node(input_name, output_name, pad_h, pad_w, inits, suffix=''):
    """Pad with tensor-based pads input (opset 11+).

    Zero-pads the bottom by *pad_h* rows and the right by *pad_w* columns;
    the pads/constant-value initializers are appended to *inits*. *suffix*
    keeps initializer names unique when the helper is used more than once.
    """
    pads_name = f'pads{suffix}'
    cv_name = f'pad_cv{suffix}'
    # Pads layout: [N,C,H,W begins, then N,C,H,W ends] — only H/W ends are nonzero.
    pads_arr = np.array([0, 0, 0, 0, 0, 0, pad_h, pad_w], dtype=np.int64)
    inits.append(numpy_helper.from_array(pads_arr, pads_name))
    inits.append(numpy_helper.from_array(np.array(0.0, dtype=np.float32), cv_name))
    return helper.make_node('Pad', [input_name, pads_name, cv_name], [output_name], mode='constant')
210
-
211
def _build_slice_crop(input_name, output_name, IH, IW, inits, suffix=''):
    """Slice to crop [1,10,30,30] to [1,10,IH,IW].

    Start/end initializers are appended to *inits*; *suffix* keeps their
    names unique across multiple uses.
    """
    st_name = f'crop_st{suffix}'
    en_name = f'crop_en{suffix}'
    inits.append(_make_int64_init(st_name, [0, 0, 0, 0]))
    inits.append(_make_int64_init(en_name, [1, 10, IH, IW]))
    return helper.make_node('Slice', [input_name, st_name, en_name], [output_name])
218
-
219
def _build_slice_reverse(input_name, output_name, axis, dim_size, inits, suffix=''):
    """Slice(step=-1) to reverse one axis. Zero MACs.

    Starts at dim_size-1 and ends at INT64_MIN so index 0 is included
    (an end of -1 would stop one element early).
    """
    st_name = f'rev_st{suffix}'
    en_name = f'rev_en{suffix}'
    ax_name = f'rev_ax{suffix}'
    sp_name = f'rev_sp{suffix}'
    inits.append(_make_int64_init(st_name, [dim_size - 1]))
    inits.append(_make_int64_init(en_name, [INT64_MIN]))
    inits.append(_make_int64_init(ax_name, [axis]))
    inits.append(_make_int64_init(sp_name, [-1]))
    return helper.make_node('Slice', [input_name, st_name, en_name, ax_name, sp_name], [output_name])
230
-
231
def _build_reducesum(input_name, output_name, axes_list, inits, suffix=''):
    """ReduceSum with axes as tensor input (opset 13+). keepdims=1.

    The axes initializer is appended to *inits*; *suffix* keeps its name
    unique across multiple uses.
    """
    axes_name = f'rs_axes{suffix}'
    inits.append(_make_int64_init(axes_name, axes_list))
    return helper.make_node('ReduceSum', [input_name, axes_name], [output_name], keepdims=1)
236
-
237
def mk(nodes, inits=None):
    """Wrap *nodes* (plus optional initializers) into a finished model.

    Every model shares the fixed contract: float 'input'/'output' of shape
    [1,10,30,30], IR version 8, opset 17.
    """
    x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
    y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
    g = helper.make_graph(nodes, "g", [x], [y], initializer=inits or [])
    return helper.make_model(g, ir_version=IR, opset_imports=OPSET)
242
-
243
def get_exs(td):
    """Return (input, output) int64 array pairs for all train+test examples."""
    pairs = []
    for ex in td['train'] + td['test']:
        pairs.append((np.array(ex['input'], dtype=np.int64),
                      np.array(ex['output'], dtype=np.int64)))
    return pairs
246
-
247
def get_exs_for_fitting(td):
    """Train+test pairs, extended with up to 10 shape-matching arc-gen pairs.

    Arc-gen examples are appended only when every base input shares a single
    shape; a candidate must match both that input shape and the first base
    output's shape.
    """
    base_exs = [(np.array(ex['input'], dtype=np.int64),
                 np.array(ex['output'], dtype=np.int64))
                for ex in td['train'] + td['test']]
    if not base_exs:
        return base_exs
    input_shapes = {inp.shape for inp, _ in base_exs}
    if len(input_shapes) != 1:
        return base_exs
    in_shape = next(iter(input_shapes))
    out_shape = base_exs[0][1].shape
    extra = []
    for ex in td.get('arc-gen', []):
        inp = np.array(ex['input'], dtype=np.int64)
        out = np.array(ex['output'], dtype=np.int64)
        if inp.shape == in_shape and out.shape == out_shape:
            extra.append((inp, out))
    return base_exs + extra[:10]
263
-
264
def get_exs_for_fitting_variable(td):
    """Train+test pairs plus up to 20 arc-gen pairs with input==output shape.

    Arc-gen candidates must also fit on the 30x30 canvas.
    """
    base_exs = [(np.array(ex['input'], dtype=np.int64),
                 np.array(ex['output'], dtype=np.int64))
                for ex in td['train'] + td['test']]
    extra = []
    for ex in td.get('arc-gen', []):
        inp = np.array(ex['input'], dtype=np.int64)
        out = np.array(ex['output'], dtype=np.int64)
        h, w = inp.shape
        if inp.shape == out.shape and h <= 30 and w <= 30:
            extra.append((inp, out))
    return base_exs + extra[:20]
274
-
275
def fixed_shapes(td):
    """Return the unique ((IH,IW),(OH,OW)) shape pair, or None if shapes vary."""
    seen = {(inp.shape, out.shape) for inp, out in get_exs(td)}
    if len(seen) == 1:
        return next(iter(seen))
    return None
280
-
281
- # ============================================================
282
- # GATHER HELPERS
283
- # ============================================================
284
-
285
def _build_gather_model(OH, OW, idx):
    """Build a pure-gather model: output[oi,oj] = input[idx[oi,oj,0], idx[oi,oj,1]].

    *idx* has shape (OH, OW, 2) of source (row, col) pairs. The gather runs
    over a flattened 900-wide spatial axis; cells outside the OHxOW region
    are zeroed by the mask. Zero MACs: Reshape/Gather/Reshape/Mul only.
    """
    flat_idx = np.zeros((GH*GW,), dtype=np.int64)
    mask = np.zeros((1,1,GH,GW), dtype=np.float32)
    for oi in range(OH):
        for oj in range(OW):
            # Flatten the 2-D source coordinate into the gather axis.
            flat_idx[oi*GW+oj] = idx[oi,oj,0]*GW + idx[oi,oj,1]
            mask[0,0,oi,oj] = 1.0
    inits = [
        numpy_helper.from_array(np.array([1,10,GH*GW], dtype=np.int64), 'fs'),
        numpy_helper.from_array(flat_idx, 'idx'),
        numpy_helper.from_array(np.array([1,10,GH,GW], dtype=np.int64), 'os'),
        numpy_helper.from_array(mask, 'mask'),
    ]
    nodes = [
        helper.make_node('Reshape', ['input','fs'], ['flat']),
        helper.make_node('Gather', ['flat','idx'], ['g'], axis=2),
        helper.make_node('Reshape', ['g','os'], ['raw']),
        helper.make_node('Mul', ['raw','mask'], ['output']),
    ]
    return mk(nodes, inits)
305
-
306
def _build_gather_model_with_const(IH, IW, OH, OW, idx, cst):
    """Gather model where some output cells are fixed constants instead.

    idx[oi,oj] >= 0 means copy from that input cell; otherwise, when
    cst[oi,oj] >= 0, the cell is set to that constant color via a one-hot
    Add. Cells with neither stay zero (masked out).
    """
    flat_idx = np.zeros((GH*GW,), dtype=np.int64)
    gather_mask = np.zeros((1,1,GH,GW), dtype=np.float32)
    const_oh = np.zeros((1,10,GH,GW), dtype=np.float32)
    for oi in range(OH):
        for oj in range(OW):
            if idx[oi,oj,0] >= 0:
                flat_idx[oi*GW+oj] = idx[oi,oj,0]*GW + idx[oi,oj,1]
                gather_mask[0,0,oi,oj] = 1.0
            elif cst[oi,oj] >= 0:
                const_oh[0, cst[oi,oj], oi, oj] = 1.0
    has_const = np.any(const_oh > 0)
    inits = [
        numpy_helper.from_array(np.array([1,10,GH*GW], dtype=np.int64), 'fs'),
        numpy_helper.from_array(flat_idx, 'idx'),
        numpy_helper.from_array(np.array([1,10,GH,GW], dtype=np.int64), 'os'),
        numpy_helper.from_array(gather_mask, 'gmask'),
    ]
    nodes = [
        helper.make_node('Reshape', ['input','fs'], ['flat']),
        helper.make_node('Gather', ['flat','idx'], ['g'], axis=2),
        helper.make_node('Reshape', ['g','os'], ['raw']),
        helper.make_node('Mul', ['raw','gmask'], ['masked']),
    ]
    if has_const:
        inits.append(numpy_helper.from_array(const_oh, 'cst'))
        nodes.append(helper.make_node('Add', ['masked','cst'], ['output']))
    else:
        # No constants needed: rewire the final Mul to write 'output' directly.
        nodes[-1] = helper.make_node('Mul', ['raw','gmask'], ['output'])
    return mk(nodes, inits)
336
-
337
- # ============================================================
338
- # ANALYTICAL SOLVERS
339
- # ============================================================
340
-
341
def s_identity(td):
    """Solver: output equals input on every example -> single Identity node."""
    examples = td['train'] + td['test']
    if any(ex['input'] != ex['output'] for ex in examples):
        return None
    return mk([helper.make_node('Identity', ['input'], ['output'])])
345
-
346
def s_color_map(td):
    """Solver: a single, consistent cell-wise color relabeling.

    Learns one color->color map across all examples (same-shape grids only).
    A permutation map becomes a zero-MAC channel Gather; otherwise a 1x1 Conv
    merges source channels into their targets.
    """
    cm = {}
    for ex in td['train']+td['test']:
        inp, out = np.array(ex['input']), np.array(ex['output'])
        if inp.shape != out.shape: return None
        for iv, ov in zip(inp.flat, out.flat):
            iv, ov = int(iv), int(ov)
            if iv in cm and cm[iv] != ov: return None  # contradictory mapping
            cm[iv] = ov
    is_permutation = (set(cm.keys()) == set(cm.values()))
    if is_permutation:
        # Inverse map for Gather: gather_ch[dst] = src; unseen colors stay identity.
        gather_ch = np.arange(10, dtype=np.int32)
        for src, dst in cm.items():
            if 0 <= src < 10 and 0 <= dst < 10:
                gather_ch[dst] = src
        inits = [numpy_helper.from_array(gather_ch, 'gi')]
        nodes = [helper.make_node('Gather', ['input', 'gi'], ['output'], axis=1)]
        return mk(nodes, inits)
    else:
        # Non-permutation: route input channel ic into channel cm[ic] via 1x1 Conv.
        W = np.zeros((10,10,1,1), dtype=np.float32)
        for ic in range(10):
            W[cm.get(ic,ic), ic, 0, 0] = 1.0
        return mk([helper.make_node('Conv', ['input','W'], ['output'], kernel_shape=[1,1])],
                  [numpy_helper.from_array(W, 'W')])
370
-
371
def s_transpose(td):
    """Solver: output is the transpose of input -> Transpose over H/W axes."""
    for ex in td['train'] + td['test']:
        inp = np.array(ex['input'])
        out = np.array(ex['output'])
        if not np.array_equal(out, inp.T):
            return None
    return mk([helper.make_node('Transpose', ['input'], ['output'], perm=[0,1,3,2])])
375
-
376
def s_flip(td):
    """Flip using Slice(step=-1) — zero MACs.

    Fixed-shape tasks with equal input/output shapes only. Pipeline:
    crop to IHxIW -> reverse one spatial axis -> zero-pad back to 30x30.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None: return None
    (IH,IW),(OH,OW) = sp
    if (IH,IW) != (OH,OW): return None
    for axis, flip_fn in [(0, np.flipud), (1, np.fliplr)]:
        if all(np.array_equal(out, flip_fn(inp)) for inp, out in exs):
            onnx_axis = 2 if axis == 0 else 3  # NCHW layout: H=axis 2, W=axis 3
            dim_size = IH if axis == 0 else IW
            pad_h, pad_w = GH - IH, GW - IW
            inits = []
            nodes = []
            nodes.append(_build_slice_crop('input', 'cropped', IH, IW, inits))
            nodes.append(_build_slice_reverse('cropped', 'flipped', onnx_axis, dim_size, inits))
            nodes.append(_build_pad_node('flipped', 'output', pad_h, pad_w, inits))
            return mk(nodes, inits)
    return None
395
-
396
def s_rotate(td):
    """Rotate using Slice+Transpose — zero MACs for square grids and k=2.
    Gather fallback for non-square k=1,3.

    k follows np.rot90: number of counter-clockwise quarter turns.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None: return None
    (IH,IW),(OH,OW) = sp
    for k in [1, 2, 3]:
        if not all(np.array_equal(out, np.rot90(inp, k)) for inp, out in exs):
            continue
        if k == 2:
            # 180 degrees: reverse H then reverse W (works for any shape).
            pad_h, pad_w = GH - OH, GW - OW
            inits = []
            nodes = []
            nodes.append(_build_slice_crop('input', 'cropped', IH, IW, inits))
            nodes.append(_build_slice_reverse('cropped', 'flip_h', 2, IH, inits, suffix='_h'))
            nodes.append(_build_slice_reverse('flip_h', 'rotated', 3, IW, inits, suffix='_w'))
            nodes.append(_build_pad_node('rotated', 'output', pad_h, pad_w, inits))
            return mk(nodes, inits)
        elif k == 1 and IH == IW:
            # rot90 CCW square: Transpose then flip axis 2
            pad_h, pad_w = GH - IH, GW - IW
            inits = []
            nodes = []
            nodes.append(_build_slice_crop('input', 'cropped', IH, IW, inits))
            nodes.append(helper.make_node('Transpose', ['cropped'], ['transposed'], perm=[0,1,3,2]))
            nodes.append(_build_slice_reverse('transposed', 'rotated', 2, IH, inits))
            nodes.append(_build_pad_node('rotated', 'output', pad_h, pad_w, inits))
            return mk(nodes, inits)
        elif k == 3 and IH == IW:
            # rot270 CCW square: flip axis 2 then Transpose
            pad_h, pad_w = GH - IH, GW - IW
            inits = []
            nodes = []
            nodes.append(_build_slice_crop('input', 'cropped', IH, IW, inits))
            nodes.append(_build_slice_reverse('cropped', 'flipped', 2, IH, inits))
            nodes.append(helper.make_node('Transpose', ['flipped'], ['rotated'], perm=[0,1,3,2]))
            nodes.append(_build_pad_node('rotated', 'output', pad_h, pad_w, inits))
            return mk(nodes, inits)
        else:
            # Non-square k=1 or k=3: Gather fallback
            idx = np.zeros((OH,OW,2), dtype=np.int64)
            for r in range(OH):
                for c in range(OW):
                    if k == 1: sr, sc = c, IH-1-r
                    elif k == 3: sr, sc = IW-1-c, r
                    idx[r,c] = [sr, sc]
            return _build_gather_model(OH, OW, idx)
    return None
445
-
446
def s_spatial_gather(td):
    """Solver: each output cell copies one fixed input cell or a constant.

    For every output position, scan for an input position whose value agrees
    on all examples (first match wins); otherwise fall back to a constant
    when the output value never varies. Fails if neither exists for a cell.
    """
    sp = fixed_shapes(td)
    if sp is None: return None
    (IH,IW),(OH,OW) = sp
    exs = get_exs(td)
    idx = np.full((OH,OW,2), -1, dtype=np.int64)
    cst = np.full((OH,OW), -1, dtype=np.int64)
    for oi in range(OH):
        for oj in range(OW):
            vals = set(int(out[oi,oj]) for _,out in exs)
            if len(vals) == 1: cst[oi,oj] = vals.pop()
            found = False
            for ri in range(IH):
                for rj in range(IW):
                    if all(int(inp[ri,rj]) == int(out[oi,oj]) for inp,out in exs):
                        idx[oi,oj] = [ri, rj]; found = True; break
                if found: break
            if not found and cst[oi,oj] < 0: return None
    return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
465
-
466
def s_varshape_spatial_gather(td):
    """Variable-shape variant of s_spatial_gather.

    Only applies when shapes vary across examples. Each grid is zero-padded
    onto the 30x30 canvas first, then the same per-cell copy/constant search
    runs over the full canvas.
    """
    sp = fixed_shapes(td)
    if sp is not None: return None  # fixed-shape tasks use s_spatial_gather
    exs = get_exs(td)
    exs_30 = []
    for inp, out in exs:
        ih, iw = inp.shape
        oh, ow = out.shape
        inp30 = np.zeros((30, 30), dtype=np.int64)
        out30 = np.zeros((30, 30), dtype=np.int64)
        inp30[:ih, :iw] = inp
        out30[:oh, :ow] = out
        exs_30.append((inp30, out30))
    idx = np.full((30, 30, 2), -1, dtype=np.int64)
    cst = np.full((30, 30), -1, dtype=np.int64)
    for oi in range(30):
        for oj in range(30):
            vals = set(int(out30[oi, oj]) for _, out30 in exs_30)
            if len(vals) == 1:
                cst[oi, oj] = vals.pop()
            found = False
            for ri in range(30):
                for rj in range(30):
                    if all(int(inp30[ri, rj]) == int(out30[oi, oj]) for inp30, out30 in exs_30):
                        idx[oi, oj] = [ri, rj]
                        found = True
                        break
                if found: break
            if not found and cst[oi, oj] < 0:
                return None
    return _build_gather_model_with_const(30, 30, 30, 30, idx, cst)
497
-
498
def s_tile(td):
    """Solver: output tiles the input rH x rW times -> Slice/Tile/Pad.

    All examples must share one input shape and one tiling factor, with the
    tiled result fitting on the 30x30 canvas.
    """
    exs = get_exs(td)
    in_shapes = set(inp.shape for inp,_ in exs)
    if len(in_shapes) != 1: return None
    IH, IW = in_shapes.pop()
    tiles = set()
    for inp, out in exs:
        OH, OW = out.shape
        if OH % IH or OW % IW: return None
        rH, rW = OH//IH, OW//IW
        # 1x1 "tiling" is the identity — leave that to other solvers.
        if rH < 1 or rW < 1 or (rH==1 and rW==1): return None
        tiles.add((rH, rW))
    if len(tiles) != 1: return None
    rH, rW = tiles.pop()
    OH, OW = IH*rH, IW*rW
    if OH > 30 or OW > 30: return None
    for inp, out in exs:
        if not np.array_equal(out, np.tile(inp, (rH, rW))): return None
    pad_h, pad_w = 30-OH, 30-OW
    inits = [
        _make_int64_init('st', [0,0,0,0]),
        _make_int64_init('en', [1,10,IH,IW]),
        _make_int64_init('rp', [1,1,rH,rW]),
    ]
    nodes = [
        helper.make_node('Slice', ['input','st','en'], ['cr']),
        helper.make_node('Tile', ['cr','rp'], ['tl']),
    ]
    nodes.append(_build_pad_node('tl', 'output', pad_h, pad_w, inits))
    return mk(nodes, inits)
528
-
529
def s_upscale(td):
    """Solver: integer upscale by pixel repetition (sH, sW >= 2) -> gather.

    All examples must share one input shape and one scale factor; each output
    pixel copies input[r//sH, c//sW].
    """
    exs = get_exs(td)
    in_shapes = set(inp.shape for inp,_ in exs)
    if len(in_shapes) != 1: return None
    IH, IW = in_shapes.pop()
    scales = set()
    for inp, out in exs:
        OH, OW = out.shape
        if OH % IH or OW % IW: return None
        sH, sW = OH//IH, OW//IW
        if sH < 2 or sW < 2: return None
        scales.add((sH, sW))
    if len(scales) != 1: return None
    sH, sW = scales.pop()
    OH, OW = IH*sH, IW*sW
    if OH > 30 or OW > 30: return None
    for inp, out in exs:
        if not np.array_equal(out, np.repeat(np.repeat(inp, sH, 0), sW, 1)): return None
    idx = np.zeros((OH,OW,2), dtype=np.int64)
    for r in range(OH):
        for c in range(OW):
            idx[r,c] = [r//sH, c//sW]
    return _build_gather_model(OH, OW, idx)
552
-
553
def s_concat(td):
    """Solver: output is 2-4 transformed copies of the input concatenated
    along one axis (horizontal or vertical).

    Tries every combination of {id, fliplr, flipud, rot180} per slot; on a
    match, compiles the layout into one gather index map.
    """
    from itertools import product as iproduct
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None: return None
    (IH,IW),(OH,OW) = sp
    transforms = [
        ('id', lambda x: x), ('fliplr', lambda x: np.fliplr(x)),
        ('flipud', lambda x: np.flipud(x)), ('rot180', lambda x: np.rot90(x, 2)),
    ]
    # Horizontal concatenation: same height, width a multiple of IW.
    if OH == IH and OW % IW == 0 and OW > IW:
        n = OW // IW
        if 2 <= n <= 4:
            for combo in iproduct(range(4), repeat=n):
                if all(np.array_equal(out, np.concatenate([transforms[t][1](inp) for t in combo], axis=1))
                       for inp, out in exs):
                    idx = np.zeros((OH,OW,2), dtype=np.int64)
                    for oi in range(OH):
                        for oj in range(OW):
                            bj = oj // IW; lr, lc = oi, oj % IW  # block + local coords
                            t = transforms[combo[bj]][0]
                            if t == 'id': sr, sc = lr, lc
                            elif t == 'fliplr': sr, sc = lr, IW-1-lc
                            elif t == 'flipud': sr, sc = IH-1-lr, lc
                            elif t == 'rot180': sr, sc = IH-1-lr, IW-1-lc
                            idx[oi,oj] = [sr, sc]
                    return _build_gather_model(OH, OW, idx)
    # Vertical concatenation: same width, height a multiple of IH.
    if OW == IW and OH % IH == 0 and OH > IH:
        n = OH // IH
        if 2 <= n <= 4:
            for combo in iproduct(range(4), repeat=n):
                if all(np.array_equal(out, np.concatenate([transforms[t][1](inp) for t in combo], axis=0))
                       for inp, out in exs):
                    idx = np.zeros((OH,OW,2), dtype=np.int64)
                    for oi in range(OH):
                        for oj in range(OW):
                            bi = oi // IH; lr, lc = oi % IH, oj  # block + local coords
                            t = transforms[combo[bi]][0]
                            if t == 'id': sr, sc = lr, lc
                            elif t == 'fliplr': sr, sc = lr, IW-1-lc
                            elif t == 'flipud': sr, sc = IH-1-lr, lc
                            elif t == 'rot180': sr, sc = IH-1-lr, IW-1-lc
                            idx[oi,oj] = [sr, sc]
                    return _build_gather_model(OH, OW, idx)
    return None
598
-
599
def s_concat_enhanced(td):
    """Solver: output is an rH x rW grid of independently transformed copies.

    Each block gets its own transform from {id, fliplr, flipud, rot180,
    rot90, rot270, T, T_fliplr} (rotations only fit square inputs, enforced
    by the shape check). A final full-grid re-check guards the gather map.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None: return None
    (IH,IW),(OH,OW) = sp
    if IH == OH and IW == OW: return None
    if OH % IH != 0 or OW % IW != 0: return None
    rH, rW = OH // IH, OW // IW
    if rH * rW > 16 or rH * rW < 2: return None
    if OH > 30 or OW > 30: return None
    transforms = [
        ('id', lambda x: x), ('fliplr', lambda x: np.fliplr(x)),
        ('flipud', lambda x: np.flipud(x)), ('rot180', lambda x: np.rot90(x, 2)),
        ('rot90', lambda x: np.rot90(x, 1)), ('rot270', lambda x: np.rot90(x, 3)),
        ('T', lambda x: x.T), ('T_fliplr', lambda x: np.fliplr(x.T)),
    ]
    # Find one transform per block position that holds across all examples.
    block_transforms = {}
    for bi in range(rH):
        for bj in range(rW):
            found = None
            for tidx, (tname, tfn) in enumerate(transforms):
                ok = True
                for inp, out in exs:
                    block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
                    expected = tfn(inp)
                    if expected.shape != (IH, IW) or not np.array_equal(block, expected):
                        ok = False; break
                if ok:
                    found = (tidx, tname); break
            if found is None: return None
            block_transforms[(bi, bj)] = found
    # Compile the per-block transforms into a single source-index map.
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    for bi in range(rH):
        for bj in range(rW):
            _, tname = block_transforms[(bi, bj)]
            for lr in range(IH):
                for lc in range(IW):
                    oi, oj = bi*IH + lr, bj*IW + lc
                    if tname == 'id': sr, sc = lr, lc
                    elif tname == 'fliplr': sr, sc = lr, IW-1-lc
                    elif tname == 'flipud': sr, sc = IH-1-lr, lc
                    elif tname == 'rot180': sr, sc = IH-1-lr, IW-1-lc
                    elif tname == 'rot90': sr, sc = IW-1-lc, lr
                    elif tname == 'rot270': sr, sc = lc, IH-1-lr
                    elif tname == 'T': sr, sc = lc, lr
                    elif tname == 'T_fliplr': sr, sc = IW-1-lc, lr
                    idx[oi, oj] = [sr, sc]
    # Sanity re-check: replay the gather map against every example.
    for inp, out in exs:
        reconstructed = np.zeros_like(out)
        for oi in range(OH):
            for oj in range(OW):
                reconstructed[oi,oj] = inp[idx[oi,oj,0], idx[oi,oj,1]]
        if not np.array_equal(reconstructed, out): return None
    return _build_gather_model(OH, OW, idx)
653
-
654
def s_input_driven_tile(td):
    """Pattern check: fractal tiling (copy of input where input cell != 0).

    NOTE(review): even when every example matches, this returns None — no
    model emission exists. Presumably disabled on purpose (the pattern is
    input-dependent and hard to express statically); confirm before relying
    on it.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None: return None
    (IH,IW),(OH,OW) = sp
    if OH % IH != 0 or OW % IW != 0: return None
    sH, sW = OH // IH, OW // IW
    if sH != IH or sW != IW: return None  # output must be an IHxIW grid of blocks
    if OH > 30 or OW > 30: return None
    for inp, out in exs:
        for bi in range(IH):
            for bj in range(IW):
                block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
                if inp[bi, bj] != 0:
                    if not np.array_equal(block, inp): return None
                else:
                    if not np.all(block == 0): return None
    return None
672
-
673
def s_kronecker(td):
    """Solver: output == kron(input, ones(sH, sW)) -> gather model.

    Equivalent to pixel-repetition upscaling; each output pixel copies
    input[r//sH, c//sW].
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None: return None
    (IH,IW),(OH,OW) = sp
    if OH % IH != 0 or OW % IW != 0: return None
    sH, sW = OH // IH, OW // IW
    if sH < 2 or sW < 2: return None
    if OH > 30 or OW > 30: return None
    for inp, out in exs:
        if not np.array_equal(out, np.kron(inp, np.ones((sH, sW), dtype=np.int64))): return None
    idx = np.zeros((OH,OW,2), dtype=np.int64)
    for r in range(OH):
        for c in range(OW):
            idx[r,c] = [r//sH, c//sW]
    return _build_gather_model(OH, OW, idx)
689
-
690
def s_diagonal_tile(td):
    """Solver: input copied along the block diagonal, zeros elsewhere.

    Output is an rxr grid of IHxIW blocks; diagonal blocks equal the input,
    off-diagonal blocks are all zero (encoded as constant color 0).
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None: return None
    (IH,IW),(OH,OW) = sp
    if OH % IH != 0 or OW % IW != 0: return None
    rH, rW = OH // IH, OW // IW
    if rH != rW or rH < 2: return None  # square block grid only
    if OH > 30 or OW > 30: return None
    for inp, out in exs:
        for bi in range(rH):
            for bj in range(rW):
                block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
                if bi == bj:
                    if not np.array_equal(block, inp): return None
                else:
                    if not np.all(block == 0): return None
    idx = np.zeros((OH,OW,2), dtype=np.int64)
    cst = np.full((OH,OW), -1, dtype=np.int64)
    for bi in range(rH):
        for bj in range(rW):
            for lr in range(IH):
                for lc in range(IW):
                    oi, oj = bi*IH + lr, bj*IW + lc
                    if bi == bj:
                        idx[oi, oj] = [lr, lc]
                    else:
                        idx[oi, oj] = [-1, -1]  # no gather source: constant 0
                        cst[oi, oj] = 0
    return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
720
-
721
def s_shift(td):
    """Solver: whole-grid translation by (dr, dc), each within [-5, 5].

    Cells shifted in from outside become 0; implemented as a gather with
    constant-0 fallback for uncovered positions.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None: return None
    (IH, IW), (OH, OW) = sp
    if (IH, IW) != (OH, OW): return None
    for dr in range(-5, 6):
        for dc in range(-5, 6):
            if dr == 0 and dc == 0: continue  # identity handled elsewhere
            ok = True
            for inp, out in exs:
                shifted = np.zeros_like(inp)
                # Destination window [r0:r1, c0:c1] receives the shifted source.
                r0, r1 = max(0, dr), min(IH, IH + dr)
                c0, c1 = max(0, dc), min(IW, IW + dc)
                if r1 > r0 and c1 > c0:
                    sr0, sc0 = max(0, -dr), max(0, -dc)
                    shifted[r0:r1, c0:c1] = inp[sr0:sr0+(r1-r0), sc0:sc0+(c1-c0)]
                if not np.array_equal(shifted, out):
                    ok = False; break
            if not ok: continue
            idx = np.zeros((OH, OW, 2), dtype=np.int64)
            cst = np.full((OH, OW), 0, dtype=np.int64)  # out-of-range -> color 0
            for r in range(OH):
                for c in range(OW):
                    sr, sc = r - dr, c - dc
                    if 0 <= sr < IH and 0 <= sc < IW:
                        idx[r, c] = [sr, sc]
                    else:
                        idx[r, c] = [-1, -1]
            return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
    return None
752
-
753
def s_gravity(td):
    """Pattern check: nonzero cells fall toward one edge (gravity).

    NOTE(review): the success branch also returns None — no model is ever
    emitted. Presumably disabled on purpose (gravity depends on input values,
    so it cannot be a fixed gather); confirm before relying on it.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None: return None
    (IH, IW), (OH, OW) = sp
    if (IH, IW) != (OH, OW): return None
    def _gravity(grid, direction):
        # Collapse nonzero values toward the given edge, column- or row-wise.
        r = np.zeros_like(grid); h, w = grid.shape
        if direction in ('down', 'up'):
            for c in range(w):
                nz = grid[:, c][grid[:, c] != 0]
                if direction == 'down': r[h-len(nz):h, c] = nz
                else: r[:len(nz), c] = nz
        else:
            for rr in range(h):
                nz = grid[rr, :][grid[rr, :] != 0]
                if direction == 'right': r[rr, w-len(nz):w] = nz
                else: r[rr, :len(nz)] = nz
        return r
    for d in ('down', 'up', 'left', 'right'):
        if all(np.array_equal(_gravity(inp, d), out) for inp, out in exs):
            return None
    return None
776
-
777
def s_mirror_h(td):
    """Solver: output = [input | fliplr(input)] side by side -> gather model."""
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if OH != IH or OW != 2 * IW or OW > 30:
        return None
    for inp, out in get_exs(td):
        mirrored = np.concatenate([inp, np.flip(inp, 1)], 1)
        if not np.array_equal(mirrored, out):
            return None
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    for r in range(OH):
        for c in range(OW):
            # Right half mirrors the left half around the seam.
            src_c = c if c < IW else 2 * IW - 1 - c
            idx[r, c] = [r, src_c]
    return _build_gather_model(OH, OW, idx)
792
-
793
def s_mirror_v(td):
    """Solver: output = input stacked above flipud(input) -> gather model."""
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if OW != IW or OH != 2 * IH or OH > 30:
        return None
    for inp, out in get_exs(td):
        mirrored = np.concatenate([inp, np.flip(inp, 0)], 0)
        if not np.array_equal(mirrored, out):
            return None
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    for r in range(OH):
        # Bottom half mirrors the top half around the seam.
        src_r = r if r < IH else 2 * IH - 1 - r
        for c in range(OW):
            idx[r, c] = [src_r, c]
    return _build_gather_model(OH, OW, idx)
808
-
809
def s_quad_mirror(td):
    """Solver: 2x2 mirrored quadrants [[A, A_lr], [A_ud, A_180]] -> gather."""
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if OH != 2 * IH or OW != 2 * IW or OH > 30 or OW > 30:
        return None
    for inp, out in get_exs(td):
        top = np.concatenate([inp, np.flip(inp, 1)], 1)
        bottom = np.concatenate([np.flip(inp, 0), np.flip(np.flip(inp, 0), 1)], 1)
        if not np.array_equal(np.concatenate([top, bottom], 0), out):
            return None
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    for r in range(OH):
        src_r = r if r < IH else 2 * IH - 1 - r
        for c in range(OW):
            src_c = c if c < IW else 2 * IW - 1 - c
            idx[r, c] = [src_r, src_c]
    return _build_gather_model(OH, OW, idx)
827
-
828
def s_fixed_crop(td):
    """Solver: output is a fixed-offset OHxOW window of the input -> gather."""
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    # Must be a strict crop: no larger and not the identity.
    if OH > IH or OW > IW or (OH == IH and OW == IW):
        return None
    exs = get_exs(td)
    for r0 in range(IH - OH + 1):
        for c0 in range(IW - OW + 1):
            if not all(np.array_equal(inp[r0:r0+OH, c0:c0+OW], out) for inp, out in exs):
                continue
            idx = np.zeros((OH, OW, 2), dtype=np.int64)
            for r in range(OH):
                for c in range(OW):
                    idx[r, c] = [r0 + r, c0 + c]
            return _build_gather_model(OH, OW, idx)
    return None
843
-
844
def s_nonuniform_scale(td):
    """Solver: repeat rows by fh and columns by fw for a small factor set."""
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    exs = get_exs(td)
    candidate_factors = [(1,2),(2,1),(1,3),(3,1),(2,3),(3,2),(1,4),(4,1),(2,4),(4,2)]
    for fh, fw in candidate_factors:
        if OH != IH * fh or OW != IW * fw or OH > 30 or OW > 30:
            continue
        if not all(np.array_equal(np.repeat(np.repeat(inp, fh, 0), fw, 1), out)
                   for inp, out in exs):
            continue
        idx = np.zeros((OH, OW, 2), dtype=np.int64)
        for r in range(OH):
            for c in range(OW):
                idx[r, c] = [r // fh, c // fw]
        return _build_gather_model(OH, OW, idx)
    return None
859
-
860
def s_constant(td):
    """Constant output: all training examples share one output grid.

    Emits a zero-MAC graph (input * 0 -> ReduceSum -> Add const) using the
    opset 17 ReduceSum form with axes supplied as a tensor input.
    """
    if fixed_shapes(td) is None:
        return None
    outputs = [o for _, o in get_exs(td)]
    first = outputs[0]
    if any(not np.array_equal(first, o) for o in outputs[1:]):
        return None
    # One-hot encode the shared output grid into the [1,10,30,30] layout.
    const = np.zeros((1, 10, 30, 30), dtype=np.float32)
    rr, cc = np.indices(first.shape)
    const[0, first.astype(np.int64), rr, cc] = 1.0
    inits = [
        numpy_helper.from_array(np.array(0.0, dtype=np.float32), 'z'),
        numpy_helper.from_array(const, 'c'),
        _make_int64_init('rs_axes_cst', [1, 2, 3]),
    ]
    nodes = [
        helper.make_node('Mul', ['input', 'z'], ['zd']),
        helper.make_node('ReduceSum', ['zd', 'rs_axes_cst'], ['s'], keepdims=1),
        helper.make_node('Add', ['s', 'c'], ['output']),
    ]
    return mk(nodes, inits)
882
-
883
- # ============================================================
884
- # CONV SOLVERS
885
- # ============================================================
886
-
887
def add_onehot_block(nodes, inits, am_name, oh_name):
    """Append Equal+Cast nodes turning an ArgMax map back into a one-hot tensor.

    Mutates `nodes` and `inits` in place: adds a 'classes' initializer with
    ids 0..9 shaped for channel broadcast, compares `am_name` against it, and
    casts the boolean result to float as `oh_name`.
    """
    class_ids = np.arange(10, dtype=np.int64).reshape(1, 10, 1, 1)
    inits.append(numpy_helper.from_array(class_ids, 'classes'))
    nodes.extend([
        helper.make_node('Equal', [am_name, 'classes'], ['eq']),
        helper.make_node('Cast', ['eq'], [oh_name], to=TensorProto.FLOAT),
    ])
892
-
893
- def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False):
894
- pad = ks // 2
895
- feat = 10 * ks * ks + (1 if use_bias else 0)
896
- if feat > 20000: return None
897
- patches, targets = [], []
898
- for inp_g, out_g in exs_raw:
899
- ih, iw = inp_g.shape
900
- if use_full_30:
901
- oh_full = np.zeros((10, GH, GW), dtype=np.float64)
902
- for c in range(10): oh_full[c, :ih, :iw] = (inp_g == c)
903
- oh_pad = np.pad(oh_full, ((0,0),(pad,pad),(pad,pad)))
904
- else:
905
- oh_enc = np.zeros((10, ih, iw), dtype=np.float64)
906
- for c in range(10): oh_enc[c] = (inp_g == c)
907
- oh_pad = np.pad(oh_enc, ((0,0),(pad,pad),(pad,pad)))
908
- oh, ow = out_g.shape
909
- for r in range(oh):
910
- for c in range(ow):
911
- p = oh_pad[:, r:r+ks, c:c+ks].flatten()
912
- if use_bias: p = np.append(p, 1.0)
913
- patches.append(p)
914
- targets.append(int(out_g[r, c]))
915
- n_patches = len(patches)
916
- if feat > 5000 and n_patches > 2000: return None
917
- P = np.array(patches, dtype=np.float64)
918
- T = np.array(targets, dtype=np.int64)
919
- T_oh = np.zeros((len(T), 10), dtype=np.float64)
920
- for i, t in enumerate(T): T_oh[i, t] = 1.0
921
- WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
922
- if not np.array_equal(np.argmax(P @ WT, axis=1), T): return None
923
- if use_bias:
924
- Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
925
- B = WT[-1].astype(np.float32)
926
- else:
927
- Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
928
- B = None
929
- return Wconv, B
930
-
931
def solve_conv_fixed(td, path, time_budget=30.0):
    """Same-shape conv solver for tasks whose inputs all share one fixed shape.

    Graph: Slice (crop input to IHxIW) -> Conv -> ArgMax -> Equal+Cast
    one-hot -> Pad back to the 30x30 canvas. Saves the model to `path` and
    returns ('conv_fixed', model) once it validates on the task; returns None
    if no (kernel size, bias) combination fits within `time_budget` seconds.
    """
    exs = get_exs(td)
    # Require input.shape == output.shape for every example...
    for inp, out in exs:
        if inp.shape != out.shape: return None
    # ...and a single common input shape across examples.
    shapes = set(inp.shape for inp, _ in exs)
    if len(shapes) != 1: return None
    IH, IW = shapes.pop()
    # Fit on the (possibly augmented) fitting set, restricted to the shape.
    fit_exs = get_exs_for_fitting(td)
    fit_exs = [(i,o) for i,o in fit_exs if i.shape == o.shape and i.shape == (IH, IW)]
    t_start = time.time()
    # Search bias-free models first (cheaper), escalating odd kernel sizes.
    for use_bias in [False, True]:
        for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
            if time.time() - t_start > time_budget: return None
            result = _lstsq_conv(fit_exs, ks, use_bias, use_full_30=False)
            if result is None: continue
            Wconv, B = result
            pad = ks // 2
            pad_h, pad_w = GH - IH, GW - IW
            inits = [
                _make_int64_init('sl_st', [0,0,0,0]),
                _make_int64_init('sl_en', [1,10,IH,IW]),
                numpy_helper.from_array(Wconv, 'W'),
            ]
            conv_inputs = ['grid', 'W']
            if B is not None:
                inits.append(numpy_helper.from_array(B, 'B'))
                conv_inputs.append('B')
            nodes = [
                helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
                helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
                helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
            ]
            # ArgMax class map -> one-hot, then pad back to the 30x30 canvas.
            add_onehot_block(nodes, inits, 'am', 'oh_out')
            nodes.append(_build_pad_node('oh_out', 'output', pad_h, pad_w, inits))
            model = mk(nodes, inits)
            onnx.save(model, path)
            # Only accept a model that actually validates on the full task.
            if validate(path, td): return 'conv_fixed', model
    return None
969
-
970
def solve_conv_variable(td, path, time_budget=30.0):
    """Variable-shape conv with opset 17 ReduceSum (axes as tensor input).

    For same-shape tasks whose input sizes vary across examples: fits the conv
    over the full 30x30 canvas (use_full_30=True) and masks the one-hot output
    by the input occupancy (ReduceSum over the channel axis), so padding cells
    stay zero. Returns ('conv_var', model) on a validated fit, else None.
    """
    exs = get_exs(td)
    # Every example must map input to an equally-shaped output.
    for inp, out in exs:
        if inp.shape != out.shape: return None
    fit_exs = get_exs_for_fitting_variable(td)
    fit_exs = [(i,o) for i,o in fit_exs if i.shape == o.shape]
    t_start = time.time()
    # Bias-free first, escalating odd kernel sizes until the budget runs out.
    for use_bias in [False, True]:
        for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
            if time.time() - t_start > time_budget: return None
            result = _lstsq_conv(fit_exs, ks, use_bias, use_full_30=True)
            if result is None: continue
            Wconv, B = result
            pad = ks // 2
            inits = [
                numpy_helper.from_array(Wconv, 'W'),
                _make_int64_init('rs_axes_var', [1]),
            ]
            conv_inputs = ['input', 'W']
            if B is not None:
                inits.append(numpy_helper.from_array(B, 'B'))
                conv_inputs.append('B')
            nodes = [
                # Occupancy mask: sum over the 10 color channels is 1 where a
                # cell is populated, 0 in the padding region.
                helper.make_node('ReduceSum', ['input', 'rs_axes_var'], ['mask'], keepdims=1),
                helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
                helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
            ]
            add_onehot_block(nodes, inits, 'am', 'oh_out')
            # Zero out predictions outside the (variable-size) input region.
            nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
            model = mk(nodes, inits)
            onnx.save(model, path)
            if validate(path, td): return 'conv_var', model
    return None
1004
-
1005
def solve_conv_diffshape(td, path, time_budget=30.0):
    """Fixed-shape conv solver for tasks whose output is smaller than the input.

    Graph: Slice (crop to IHxIW) -> Conv -> Slice (crop the conv output to
    OHxOW at a fixed offset) -> ArgMax -> one-hot -> Pad back to 30x30.
    Tries two crop offsets: top-left (0,0) and centered. Returns
    ('conv_diff', model) on a validated fit, else None.
    """
    sp = fixed_shapes(td)
    if sp is None: return None
    (IH, IW), (OH, OW) = sp
    # Only strictly-smaller outputs that fit inside the input are handled.
    if IH == OH and IW == OW: return None
    if OH > IH or OW > IW: return None
    if OH > 30 or OW > 30: return None
    exs = get_exs(td)
    t_start = time.time()
    # Candidate crop offsets: top-left anchored, then centered.
    for dr_off, dc_off in [(0, 0), ((IH-OH)//2, (IW-OW)//2)]:
        for use_bias in [False, True]:
            for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21]:
                if time.time() - t_start > time_budget: return None
                pad = ks // 2
                feat = 10 * ks * ks + (1 if use_bias else 0)
                if feat > 10000: continue
                patches, targets = [], []
                valid = True
                for inp_g, out_g in exs:
                    # One-hot encode the input, then zero-pad for the conv halo.
                    oh_enc = np.zeros((10, IH, IW), dtype=np.float64)
                    for c in range(10): oh_enc[c] = (inp_g == c)
                    oh_pad = np.pad(oh_enc, ((0,0),(pad,pad),(pad,pad)))
                    for r in range(OH):
                        for c in range(OW):
                            # Source pixel in input coordinates at this offset.
                            sr, sc = r + dr_off, c + dc_off
                            if sr < 0 or sr >= IH or sc < 0 or sc >= IW:
                                valid = False; break
                            p = oh_pad[:, sr:sr+ks, sc:sc+ks].flatten()
                            if use_bias: p = np.append(p, 1.0)
                            patches.append(p)
                            targets.append(int(out_g[r, c]))
                        if not valid: break
                    if not valid: break
                if not valid: continue
                n_patches = len(patches)
                # Skip pathologically large least-squares systems.
                if feat > 5000 and n_patches > 2000: continue
                P = np.array(patches, dtype=np.float64)
                T = np.array(targets, dtype=np.int64)
                T_oh = np.zeros((len(T), 10), dtype=np.float64)
                for i, t in enumerate(T): T_oh[i, t] = 1.0
                WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
                # The fit must classify every training pixel exactly.
                if not np.array_equal(np.argmax(P @ WT, axis=1), T): continue
                if use_bias:
                    Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
                    B = WT[-1].astype(np.float32)
                else:
                    Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
                    B = None
                pad_h, pad_w = GH - OH, GW - OW
                inits = [
                    _make_int64_init('sl_st', [0,0,0,0]),
                    _make_int64_init('sl_en', [1,10,IH,IW]),
                    numpy_helper.from_array(Wconv, 'W'),
                    # Crop window of the conv output, in input coordinates.
                    _make_int64_init('cr_st', [0,0,dr_off,dc_off]),
                    _make_int64_init('cr_en', [1,10,dr_off+OH,dc_off+OW]),
                ]
                conv_inputs = ['grid', 'W']
                if B is not None:
                    inits.append(numpy_helper.from_array(B, 'B'))
                    conv_inputs.append('B')
                nodes = [
                    helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
                    helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
                    helper.make_node('Slice', ['co','cr_st','cr_en'], ['co_crop']),
                    helper.make_node('ArgMax', ['co_crop'], ['am'], axis=1, keepdims=1),
                ]
                add_onehot_block(nodes, inits, 'am', 'oh_out')
                nodes.append(_build_pad_node('oh_out', 'output', pad_h, pad_w, inits))
                model = mk(nodes, inits)
                onnx.save(model, path)
                if validate(path, td): return 'conv_diff', model
    return None
1077
-
1078
def solve_conv_var_diff(td, path, time_budget=30.0):
    """Variable diff-shape conv with opset 17 ReduceSum.

    For tasks whose output shape differs from the input shape per example:
    fits one conv over the full 30x30 canvas and masks the one-hot output by
    the input occupancy (ReduceSum over channels). That masking only makes
    sense when every output grid fits within its input grid. Returns
    ('conv_var_diff', model) on a validated fit, else None.
    """
    exs = get_exs(td)
    # The output mask is derived from the input, so every output must fit
    # inside its input. This is loop-invariant: decide it once up front
    # (previously it was re-checked inside the innermost loop after the
    # expensive lstsq fit).
    if not all(out_g.shape[0] <= inp_g.shape[0] and out_g.shape[1] <= inp_g.shape[1]
               for inp_g, out_g in exs):
        return None
    t_start = time.time()
    # Bias-free first, escalating odd kernel sizes within the time budget.
    for use_bias in [False, True]:
        for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
            if time.time() - t_start > time_budget: return None
            pad = ks // 2
            feat = 10 * ks * ks + (1 if use_bias else 0)
            if feat > 20000: continue
            patches, targets = [], []
            for inp_g, out_g in exs:
                ih, iw = inp_g.shape
                oh, ow = out_g.shape
                # One-hot encode the input on the full canvas, then pad for
                # the conv halo so edge patches read zeros.
                oh_full = np.zeros((10, GH, GW), dtype=np.float64)
                for c in range(10): oh_full[c, :ih, :iw] = (inp_g == c)
                oh_pad = np.pad(oh_full, ((0,0),(pad,pad),(pad,pad)))
                for r in range(oh):
                    for c in range(ow):
                        p = oh_pad[:, r:r+ks, c:c+ks].flatten()
                        if use_bias: p = np.append(p, 1.0)
                        patches.append(p)
                        targets.append(int(out_g[r, c]))
            n_patches = len(patches)
            # Skip pathologically large least-squares systems.
            if feat > 5000 and n_patches > 2000: continue
            P = np.array(patches, dtype=np.float64)
            T = np.array(targets, dtype=np.int64)
            T_oh = np.zeros((len(T), 10), dtype=np.float64)
            for i, t in enumerate(T): T_oh[i, t] = 1.0
            try:
                WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
            except np.linalg.LinAlgError:
                # Ill-conditioned/non-converging system: try the next kernel
                # size. (Was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit.)
                continue
            # The fit must classify every training pixel exactly.
            if not np.array_equal(np.argmax(P @ WT, axis=1), T): continue
            if use_bias:
                Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
                B = WT[-1].astype(np.float32)
            else:
                Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
                B = None
            inits = [
                numpy_helper.from_array(Wconv, 'W'),
                _make_int64_init('rs_axes_vd', [1]),
            ]
            conv_inputs = ['input', 'W']
            if B is not None:
                inits.append(numpy_helper.from_array(B, 'B'))
                conv_inputs.append('B')
            nodes = [
                # Occupancy mask: channel sum is 1 on populated cells, 0 in
                # the padding region.
                helper.make_node('ReduceSum', ['input', 'rs_axes_vd'], ['mask'], keepdims=1),
                helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
                helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
            ]
            add_onehot_block(nodes, inits, 'am', 'oh_out')
            nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
            model = mk(nodes, inits)
            onnx.save(model, path)
            if validate(path, td): return 'conv_var_diff', model
    return None
1142
-
1143
- # ============================================================
1144
- # MAIN SOLVER
1145
- # ============================================================
1146
-
1147
# Analytical solvers tried in registration order by solve_task before any
# conv fitting. Each entry is (name, fn); fn(td) returns an ONNX ModelProto
# when the transformation reproduces every example, or None otherwise.
ANALYTICAL_SOLVERS = [
    ('identity', s_identity),
    ('constant', s_constant),
    ('color_map', s_color_map),
    ('transpose', s_transpose),
    ('flip', s_flip),
    ('rotate', s_rotate),
    ('shift', s_shift),
    ('tile', s_tile),
    ('upscale', s_upscale),
    ('kronecker', s_kronecker),
    ('nonuniform_scale', s_nonuniform_scale),
    ('mirror_h', s_mirror_h),
    ('mirror_v', s_mirror_v),
    ('quad_mirror', s_quad_mirror),
    ('concat', s_concat),
    ('concat_enhanced', s_concat_enhanced),
    ('diagonal_tile', s_diagonal_tile),
    ('fixed_crop', s_fixed_crop),
    ('spatial_gather', s_spatial_gather),
    ('varshape_spatial_gather', s_varshape_spatial_gather),
]
1169
-
1170
-
1171
def solve_task(tn, td, outdir, conv_budget=30.0):
    """Solve one task: analytical solvers first, then conv fitting by category.

    Writes the winning model to `<outdir>/taskNNN.onnx` and returns a tuple
    (ok, solver_name, onnx_size_bytes, elapsed_sec, model_path). Excluded
    tasks return (False, 'excluded', None, ...); unsolved tasks return
    (False, None, None, ...).
    """
    t_start = time.time()
    os.makedirs(outdir, exist_ok=True)
    path = os.path.join(outdir, f"task{tn:03d}.onnx")

    # Skip excluded tasks
    if tn in EXCLUDED_TASKS:
        return False, 'excluded', None, time.time() - t_start, path

    # 1. Try analytical solvers (fast, tiny models)
    for sname, sfn in ANALYTICAL_SOLVERS:
        try:
            model = sfn(td)
            if model is None: continue
            onnx.save(model, path)
            if validate(path, td):
                return True, sname, os.path.getsize(path), time.time() - t_start, path
        except Exception:
            # Best-effort: a crashing analytical solver just falls through to
            # the next one. (Was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit.)
            pass

    # 2. Determine task shape category and try conv solvers
    exs = get_exs(td)
    same_shape = all(inp.shape == out.shape for inp, out in exs)
    shapes = set(inp.shape for inp, _ in exs)
    fixed_in = len(shapes) == 1

    conv_time = conv_budget

    if same_shape:
        # Fixed-shape conv gets half the budget; the variable-shape conv can
        # still use the full budget afterwards.
        if fixed_in:
            result = solve_conv_fixed(td, path, time_budget=conv_time/2)
            if result is not None:
                sname, model = result
                return True, sname, os.path.getsize(path), time.time() - t_start, path
        result = solve_conv_variable(td, path, time_budget=conv_time)
        if result is not None:
            sname, model = result
            return True, sname, os.path.getsize(path), time.time() - t_start, path
    else:
        sp = fixed_shapes(td)
        if sp is not None:
            (IH,IW),(OH,OW) = sp
            if OH <= IH and OW <= IW:
                result = solve_conv_diffshape(td, path, time_budget=conv_time)
                if result is not None:
                    sname, model = result
                    return True, sname, os.path.getsize(path), time.time() - t_start, path

        # Try variable diff-shape conv (output within input bounds)
        result = solve_conv_var_diff(td, path, time_budget=conv_time)
        if result is not None:
            sname, model = result
            return True, sname, os.path.getsize(path), time.time() - t_start, path

    return False, None, None, time.time() - t_start, path
1225
-
1226
def run_tasks(task_nums, tasks, output_dir, conv_budget, use_wandb):
    """Solve each requested task, print per-task results, and aggregate scores.

    Returns (results, costs_dict, total_score) where results maps task id ->
    (solver_name, elapsed_sec, onnx_size_bytes) for solved tasks only,
    costs_dict maps task id -> macs + memory + params, and total_score is the
    sum of per-task scores max(1, 25 - log(cost)).
    """
    results = {}
    costs_dict = {}
    total_score = 0
    for tn in task_nums:
        if tn not in tasks:
            continue
        if tn in EXCLUDED_TASKS:
            print(f"Task {tn:3d}: EXCLUDED (officially)")
            continue

        td = tasks[tn]['data']
        ok, sname, sz, t_task, model_path = solve_task(tn, td, output_dir, conv_budget)

        if ok:
            # score_network may fail to measure; fall back to zero cost.
            macs, memory, params = score_network(model_path)
            if macs is None:
                macs, memory, params = 0, 0, 0
            cost = macs + memory + params
            # Competition scoring: lower cost -> higher score, floored at 1.
            score = max(1.0, 25.0 - math.log(max(1, cost)))
            total_score += score

            results[tn] = (sname, t_task, sz)
            costs_dict[tn] = cost
            print(f"Task {tn:3d}: {sname:25s} {score:7.3f} {cost:>12} {t_task:7.3f}s ({sz:>8,} bytes)")
        else:
            print(f"Task {tn:3d}: UNSOLVED {t_task:7.3f}s")
            cost = 0

        # wandb is None when the import failed at module load time.
        if use_wandb and wandb is not None:
            wandb.log({
                "task_id": tn,
                "solver": sname if ok else "unsolved",
                "onnx_bytes": sz if ok else 0,
                "task_time_sec": t_task,
                "cost": cost,
                "score": score if ok else 0,
            })

    return results, costs_dict, total_score
1266
-
1267
-
1268
def main():
    """CLI entry point: parse args, solve all tasks, write submission files.

    Produces one ONNX model per solved task in --output_dir, plus a
    submission.zip and submission.csv next to it, and prints a score estimate.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', default='ARC-AGI/data/training/')
    parser.add_argument('--arcgen_dir', default='', help='Path to ARC-GEN-100K/ directory')
    parser.add_argument('--output_dir', default='/kaggle/working/submission')
    parser.add_argument('--kaggle', action='store_true')
    parser.add_argument('--conv_budget', type=float, default=30.0)
    parser.add_argument('--tasks', type=str, default='')
    parser.add_argument('--device', type=str, default='auto', choices=['auto','cpu','cuda'])
    parser.add_argument('--use_wandb', action='store_true')
    args = parser.parse_args()
    global ORT_PROVIDERS
    # Run configuration, also logged to wandb when enabled.
    config = {
        "device": args.device,
        "conv_budget": args.conv_budget,
        "data_dir": args.data_dir,
        "arcgen_dir": args.arcgen_dir,
        "tasks": args.tasks,
    }

    # 'auto' leaves the module-level ORT_PROVIDERS default untouched.
    if args.device == 'cuda':
        ORT_PROVIDERS = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    elif args.device == 'cpu':
        ORT_PROVIDERS = ['CPUExecutionProvider']

    ort.set_default_logger_severity(3)
    print(f"Using providers: {ORT_PROVIDERS}")

    if args.kaggle:
        tasks = load_tasks_kaggle(args.data_dir)
    else:
        arcgen = args.arcgen_dir if args.arcgen_dir else None
        tasks = load_tasks_dir(args.data_dir, arcgen_dir=arcgen)

    # Count arc-gen examples
    total_arcgen = sum(len(t['data'].get('arc-gen', [])) for t in tasks.values())
    print(f"Loaded {len(tasks)} tasks ({total_arcgen} ARC-GEN examples)")
    print(f"Excluded tasks: {sorted(EXCLUDED_TASKS)}")

    # --tasks takes a comma-separated list of ids; default is all tasks.
    task_nums = [int(t) for t in args.tasks.split(',')] if args.tasks else sorted(tasks.keys())
    active_tasks = [t for t in task_nums if t not in EXCLUDED_TASKS]
    print(f"Solving {len(active_tasks)} active tasks (skipping {len(task_nums) - len(active_tasks)} excluded)")
    print(f"Conv budget: {args.conv_budget}s per task")
    print("=" * 70)
    t0 = time.time()

    if args.use_wandb and wandb is not None:
        with wandb.init(project="neurogolf", name="solver_run", config=config):
            results, costs_dict, total_score = run_tasks(task_nums, tasks, args.output_dir, args.conv_budget, use_wandb=True)
    else:
        results, costs_dict, total_score = run_tasks(task_nums, tasks, args.output_dir, args.conv_budget, use_wandb=False)

    elapsed = time.time() - t0
    print(f"\n{'='*70}")
    print(f"Solved: {len(results)}/{len(active_tasks)} active tasks in {elapsed:.0f}s")
    # Per-solver usage histogram.
    solver_names = [v[0] for v in results.values()]
    sc = Counter(solver_names)
    for s, c in sc.most_common(): print(f" {s}: {c}")

    # Generate submission
    outdir = args.output_dir
    n_files = len([f for f in os.listdir(outdir) if f.endswith('.onnx')])
    total_size = sum(os.path.getsize(os.path.join(outdir, f))
                     for f in os.listdir(outdir) if f.endswith('.onnx'))

    # Create submission.zip
    zip_path = os.path.join(os.path.dirname(outdir) or '/kaggle/working/', 'submission.zip')
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, 'w', zipfile.ZIP_DEFLATED) as zf:
        for f in sorted(os.listdir(outdir)):
            if f.endswith('.onnx'):
                zf.write(os.path.join(outdir, f), f)
    zip_bytes = buf.getvalue()
    with open(zip_path, 'wb') as f:
        f.write(zip_bytes)
    zip_size = len(zip_bytes)

    # Create submission.csv
    csv_path = os.path.join(os.path.dirname(outdir) or '.', 'submission.csv')
    with open(csv_path, 'w', newline='') as f:
        w = csv.writer(f)
        w.writerow(['task_id', 'total_cost'])
        for tn in sorted(costs_dict.keys()):
            w.writerow([f'task{tn:03d}', costs_dict[tn]])

    # Estimate LB score: solved tasks get their score, unsolved get 1.0
    unsolved_count = len(active_tasks) - len(results)
    est_lb = total_score + unsolved_count * 1.0

    print(f"\n{n_files} ONNX files, {total_size/1024:.1f} KB uncompressed")
    print(f"ZIP size: {zip_size/1024:.1f} KB / {MAX_FILESIZE/1024:.0f} KB limit {'OK' if zip_size <= MAX_FILESIZE else 'OVER!'}")
    print(f"Estimated LB score: {est_lb:.1f} (solved: {total_score:.1f} + unsolved: {unsolved_count}×1.0)")
    print(f"Written: {zip_path} | {csv_path}")
1361
-
1362
if __name__ == '__main__':
    # Script entry point; see the module docstring for usage examples.
    main()
1364
-