rogermt committed on
Commit
59533d2
·
verified ·
1 Parent(s): c228c49

v4: ARC-GEN validation, new analytical solvers, static profiler, s_flip opset fix

Browse files
Files changed (1) hide show
  1. neurogolf_solver.py +1 -1076
neurogolf_solver.py CHANGED
@@ -1,1076 +1 @@
1
- #!/usr/bin/env python3
2
- """
3
- ARC-AGI NeuroGolf Championship - Complete Solver v3
4
- Format: [1,10,30,30] one-hot input/output, opset 10, IR version 10.
5
- Solvers:
6
- - Analytical: identity, constant, color_map, transpose, flip, rotate, tile, upscale,
7
- concat, concat_enhanced, spatial_gather, varshape_spatial_gather,
8
- input_driven_tile, diagonal_tile, kronecker
9
- - Conv (fixed shape): Slice -> Conv -> ArgMax -> Equal+Cast -> Pad
10
- - Conv (variable shape): Conv(30x30) -> ArgMax -> Equal+Cast -> Mul(mask)
11
- - Conv (diff shape): Slice -> Conv -> Slice(crop) -> ArgMax -> Equal+Cast -> Pad
12
- Results: ~305+/400 tasks solved (was 294/400 in v2)
13
- Usage:
14
- python neurogolf_solver.py --data_dir ARC-AGI/data/training/ --output_dir submission
15
- python neurogolf_solver.py --data_dir ARC-AGI/data/training/ --output_dir submission --conv_budget 60
16
- """
17
-
18
- import json, os, sys, math, time, argparse
19
- import numpy as np
20
- import onnx
21
- from onnx import helper, TensorProto, numpy_helper
22
- import onnxruntime as ort
23
- from collections import Counter
24
-
25
# Optional dependency: neurogolf_utils supplies the real scorer; fall back to
# a no-op stub so the solver still runs when the module is absent.
try:
    from neurogolf_utils import score_network
except ImportError:
    def score_network(path):
        # Stub scorer: returns a zeroed (score, nodes, bytes)-style triple.
        return 0, 0, 0

# Optional dependency: wandb logging; `wandb is None` disables it downstream.
try:
    import wandb
except ImportError:
    wandb = None
-
36
# Canonical grid tensor layout: [batch, color-channel, height, width].
BATCH, CH, GH, GW = 1, 10, 30, 30
GRID_SHAPE = [BATCH, CH, GH, GW]
DT = TensorProto.FLOAT
IR = 10  # ONNX IR version for all emitted models
OPSET = [helper.make_opsetid("", 10)]  # opset 10: ops introduced later (e.g. GatherElements) are unavailable

def get_providers():
    """Return the onnxruntime execution providers used for validation."""
    return ['CPUExecutionProvider']

ORT_PROVIDERS = get_providers()
46
-
47
- # ============================================================
48
- # LOAD / VALIDATE
49
- # ============================================================
50
-
51
def load_tasks_dir(data_dir):
    """Load every *.json task in data_dir, keyed 1..N in sorted filename order.

    Each entry maps to {'hex': <filename without extension>, 'data': <parsed json>}.
    """
    json_names = sorted(name for name in os.listdir(data_dir) if name.endswith('.json'))
    tasks = {}
    for number, name in enumerate(json_names, start=1):
        with open(os.path.join(data_dir, name)) as handle:
            tasks[number] = {'hex': name.replace('.json', ''), 'data': json.load(handle)}
    return tasks
-
59
def load_tasks_kaggle(data_dir):
    """Load Kaggle-style task001.json..task400.json files present in data_dir."""
    tasks = {}
    for task_num in range(1, 401):
        stem = f'task{task_num:03d}'
        task_path = os.path.join(data_dir, stem + '.json')
        if not os.path.exists(task_path):
            continue  # missing task numbers are simply skipped
        with open(task_path) as handle:
            tasks[task_num] = {'hex': stem, 'data': json.load(handle)}
    return tasks
-
68
def to_onehot(grid):
    """Encode a 2-D color grid as a [1,10,30,30] one-hot float32 tensor.

    Cells outside the grid's extent are left all-zero (the padding encoding).
    """
    encoded = np.zeros((1, CH, GH, GW), dtype=np.float32)
    for row_idx, row in enumerate(grid):
        for col_idx, color in enumerate(row):
            encoded[0, color, row_idx, col_idx] = 1.0
    return encoded
-
75
def validate(path, td):
    """Run the saved ONNX model at `path` against every example of task `td`.

    Checks train + test (+ 'arc-gen' when present). Model output is binarized
    with `> 0.0` before comparison so tiny float noise is tolerated. Returns
    True only if every example reproduces its expected one-hot grid.

    Fix: the two bare `except:` clauses also swallowed SystemExit and
    KeyboardInterrupt, making the batch run un-interruptible; narrowed to
    `except Exception`.
    """
    try:
        sess = ort.InferenceSession(path, providers=ORT_PROVIDERS)
    except Exception:
        # Model failed to load (bad opset, malformed graph, ...) -> invalid.
        return False
    examples = td['train'] + td['test']
    if 'arc-gen' in td:
        examples = examples + td['arc-gen']
    for ex in examples:
        inp = to_onehot(ex['input'])
        exp = to_onehot(ex['output'])
        try:
            out = sess.run(['output'], {'input': inp})[0]
            out = (out > 0.0).astype(np.float32)
        except Exception:
            # Runtime failure on this input (shape/type mismatch) -> invalid.
            return False
        if not np.array_equal(out, exp):
            return False
    return True
-
95
def mk(nodes, inits=None):
    """Wrap `nodes` (+ optional initializers) into a [1,10,30,30]->[1,10,30,30] model."""
    in_info = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
    out_info = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
    graph = helper.make_graph(nodes, "g", [in_info], [out_info], initializer=inits or [])
    return helper.make_model(graph, ir_version=IR, opset_imports=OPSET)
-
101
def get_exs(td):
    """Return all (input, output) pairs of a task as int64 numpy arrays."""
    pairs = []
    for ex in td['train'] + td['test']:
        pairs.append((np.array(ex['input'], dtype=np.int64),
                      np.array(ex['output'], dtype=np.int64)))
    return pairs
-
105
def fixed_shapes(td):
    """Return ((IH,IW),(OH,OW)) when every example shares one io-shape pair, else None."""
    seen = {(inp.shape, out.shape) for inp, out in get_exs(td)}
    if len(seen) != 1:
        return None
    return next(iter(seen))
-
111
- # ============================================================
112
- # GATHER HELPERS
113
- # ============================================================
114
-
115
def _build_gather_model(OH, OW, idx):
    """Build a static spatial-permutation model via Reshape + Gather.

    Uses Gather (opset 1) instead of GatherElements (opset 11): the spatial
    dims are flattened to [1,10,900], gathered along axis 2 with a precomputed
    flat index, reshaped back, and masked so everything outside the OH x OW
    window is zero. idx[oi, oj] = (src_row, src_col) for each output cell.
    """
    flat_idx = np.zeros((GH * GW,), dtype=np.int64)
    mask = np.zeros((1, 1, GH, GW), dtype=np.float32)
    for oi in range(OH):
        for oj in range(OW):
            src_r, src_c = idx[oi, oj, 0], idx[oi, oj, 1]
            flat_idx[oi * GW + oj] = src_r * GW + src_c
            mask[0, 0, oi, oj] = 1.0
    inits = [
        numpy_helper.from_array(np.array([1, 10, GH * GW], dtype=np.int64), 'fs'),
        numpy_helper.from_array(flat_idx, 'idx'),
        numpy_helper.from_array(np.array([1, 10, GH, GW], dtype=np.int64), 'os'),
        numpy_helper.from_array(mask, 'mask'),
    ]
    nodes = [
        helper.make_node('Reshape', ['input', 'fs'], ['flat']),
        helper.make_node('Gather', ['flat', 'idx'], ['g'], axis=2),
        helper.make_node('Reshape', ['g', 'os'], ['raw']),
        helper.make_node('Mul', ['raw', 'mask'], ['output']),
    ]
    return mk(nodes, inits)
-
138
def _build_gather_model_with_const(IH, IW, OH, OW, idx, cst):
    """Spatial gather with per-cell constant fallback.

    idx[oi,oj] = (src_row, src_col) gathers from the input; a negative row
    means "no source" and cst[oi,oj] (if >= 0) supplies a constant color
    instead. Cells with neither stay zero (padding encoding).
    """
    # Use Gather (opset 1) instead of GatherElements (opset 11)
    flat_idx = np.zeros((GH*GW,), dtype=np.int64)
    gather_mask = np.zeros((1,1,GH,GW), dtype=np.float32)
    const_oh = np.zeros((1,10,GH,GW), dtype=np.float32)
    for oi in range(OH):
        for oj in range(OW):
            if idx[oi,oj,0] >= 0:
                # Gathered cell: flatten (row, col) into the 900-wide index.
                flat_idx[oi*GW+oj] = idx[oi,oj,0]*GW + idx[oi,oj,1]
                gather_mask[0,0,oi,oj] = 1.0
            elif cst[oi,oj] >= 0:
                # Constant cell: bake the one-hot color directly.
                const_oh[0, cst[oi,oj], oi, oj] = 1.0
    has_const = np.any(const_oh > 0)
    inits = [
        numpy_helper.from_array(np.array([1,10,GH*GW], dtype=np.int64), 'fs'),
        numpy_helper.from_array(flat_idx, 'idx'),
        numpy_helper.from_array(np.array([1,10,GH,GW], dtype=np.int64), 'os'),
        numpy_helper.from_array(gather_mask, 'gmask'),
    ]
    nodes = [
        helper.make_node('Reshape', ['input','fs'], ['flat']),
        helper.make_node('Gather', ['flat','idx'], ['g'], axis=2),
        helper.make_node('Reshape', ['g','os'], ['raw']),
        helper.make_node('Mul', ['raw','gmask'], ['masked']),
    ]
    if has_const:
        # Add the constant one-hot on top of the masked gather.
        inits.append(numpy_helper.from_array(const_oh, 'cst'))
        nodes.append(helper.make_node('Add', ['masked','cst'], ['output']))
    else:
        # No constants: rewrite the final Mul to produce 'output' directly.
        nodes[-1] = helper.make_node('Mul', ['raw','gmask'], ['output'])
    return mk(nodes, inits)
-
170
- # ============================================================
171
- # ANALYTICAL SOLVERS
172
- # ============================================================
173
-
174
def s_identity(td):
    """Solver: output equals input on every example -> single Identity node."""
    if any(ex['input'] != ex['output'] for ex in td['train'] + td['test']):
        return None
    return mk([helper.make_node('Identity', ['input'], ['output'])])
-
179
def s_color_map(td):
    """Solver: a consistent per-color recoloring, realized as a 1x1 Conv.

    Fails if any input color maps to two different output colors or shapes differ.
    """
    mapping = {}
    for ex in td['train'] + td['test']:
        src = np.array(ex['input'])
        dst = np.array(ex['output'])
        if src.shape != dst.shape:
            return None
        for in_color, out_color in zip(src.flat, dst.flat):
            in_color, out_color = int(in_color), int(out_color)
            # setdefault: record first sighting, reject any later conflict.
            if mapping.setdefault(in_color, out_color) != out_color:
                return None
    weight = np.zeros((10, 10, 1, 1), dtype=np.float32)
    for color in range(10):
        # Unseen colors map to themselves.
        weight[mapping.get(color, color), color, 0, 0] = 1.0
    conv = helper.make_node('Conv', ['input', 'W'], ['output'], kernel_shape=[1, 1])
    return mk([conv], [numpy_helper.from_array(weight, 'W')])
-
194
def s_transpose(td):
    """Solver: every output is the matrix transpose of its input."""
    examples = td['train'] + td['test']
    if all(np.array_equal(np.array(ex['output']), np.array(ex['input']).T)
           for ex in examples):
        # Swap the two spatial axes; batch/channel stay in place.
        return mk([helper.make_node('Transpose', ['input'], ['output'], perm=[0, 1, 3, 2])])
    return None
-
199
def s_flip(td):
    """Solver: output is a vertical (flipud) or horizontal (fliplr) mirror.

    Fix (opset): the previous implementation emitted a GatherElements node,
    which only exists from opset 11 — the models here declare opset 10, so
    onnxruntime refused to load them and this solver could never validate.
    The flip is now built as a static spatial gather through the shared
    opset-10-safe Reshape+Gather path (_build_gather_model).
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if (IH, IW) != (OH, OW):
        return None
    for flip_fn, src_of in (
        (np.flipud, lambda r, c: (IH - 1 - r, c)),   # vertical mirror
        (np.fliplr, lambda r, c: (r, IW - 1 - c)),   # horizontal mirror
    ):
        if all(np.array_equal(out, flip_fn(inp)) for inp, out in exs):
            idx = np.zeros((OH, OW, 2), dtype=np.int64)
            for r in range(OH):
                for c in range(OW):
                    idx[r, c] = src_of(r, c)
            return _build_gather_model(OH, OW, idx)
    return None
-
218
def s_rotate(td):
    """Solver: output is np.rot90(input, k) for some k in {1, 2, 3}.

    Fix: the source-index formulas for k=1 and k=3 mixed up IH and IW, which
    produced wrong gather maps for non-square grids:
      np.rot90(m, 1)[r, c] == m[c, IW-1-r]   (was using IH-1-r)
      np.rot90(m, 3)[r, c] == m[IH-1-c, r]   (was using IW-1-c)
    Square grids (IH == IW) were unaffected.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    for k in (1, 2, 3):
        if not all(np.array_equal(out, np.rot90(inp, k)) for inp, out in exs):
            continue
        idx = np.zeros((OH, OW, 2), dtype=np.int64)
        for r in range(OH):
            for c in range(OW):
                if k == 1:
                    sr, sc = c, IW - 1 - r
                elif k == 2:
                    sr, sc = IH - 1 - r, IW - 1 - c
                else:  # k == 3
                    sr, sc = IH - 1 - c, r
                idx[r, c] = [sr, sc]
        return _build_gather_model(OH, OW, idx)
    return None
-
235
def s_spatial_gather(td):
    """Solver: each output cell either copies one fixed input cell or is constant.

    For every output position, search for an input position whose value agrees
    with the output across all examples; otherwise fall back to a constant
    color when the output value never varies. Fails if neither applies.
    """
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    exs = get_exs(td)
    idx = np.full((OH, OW, 2), -1, dtype=np.int64)
    cst = np.full((OH, OW), -1, dtype=np.int64)
    for oi in range(OH):
        for oj in range(OW):
            out_vals = {int(out[oi, oj]) for _, out in exs}
            if len(out_vals) == 1:
                cst[oi, oj] = out_vals.pop()
            # First input cell (row-major) that matches this output cell everywhere.
            source = next(
                ((ri, rj) for ri in range(IH) for rj in range(IW)
                 if all(int(inp[ri, rj]) == int(out[oi, oj]) for inp, out in exs)),
                None)
            if source is not None:
                idx[oi, oj] = list(source)
            elif cst[oi, oj] < 0:
                return None
    return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
-
255
def s_varshape_spatial_gather(td):
    """Spatial gather for variable-shape tasks, via embedding every grid in 30x30.

    Fixed-shape tasks are left to s_spatial_gather; here each input/output pair
    is zero-padded to the full canvas first, then the same per-cell source /
    constant search runs over all 30x30 positions.
    """
    if fixed_shapes(td) is not None:
        return None  # fixed shapes handled by s_spatial_gather
    embedded = []
    for inp, out in get_exs(td):
        inp_pad = np.zeros((30, 30), dtype=np.int64)
        out_pad = np.zeros((30, 30), dtype=np.int64)
        inp_pad[:inp.shape[0], :inp.shape[1]] = inp
        out_pad[:out.shape[0], :out.shape[1]] = out
        embedded.append((inp_pad, out_pad))

    idx = np.full((30, 30, 2), -1, dtype=np.int64)
    cst = np.full((30, 30), -1, dtype=np.int64)
    for oi in range(30):
        for oj in range(30):
            out_vals = {int(o[oi, oj]) for _, o in embedded}
            if len(out_vals) == 1:
                cst[oi, oj] = out_vals.pop()
            source = next(
                ((ri, rj) for ri in range(30) for rj in range(30)
                 if all(int(i[ri, rj]) == int(o[oi, oj]) for i, o in embedded)),
                None)
            if source is not None:
                idx[oi, oj] = list(source)
            elif cst[oi, oj] < 0:
                return None

    return _build_gather_model_with_const(30, 30, 30, 30, idx, cst)
-
293
def s_tile(td):
    """Solver: output = np.tile(input, (rH, rW)) with one shared repeat factor.

    Built as Slice (crop to IHxIW) -> Tile -> Pad back to the 30x30 canvas.
    """
    exs = get_exs(td)
    input_shapes = {inp.shape for inp, _ in exs}
    if len(input_shapes) != 1:
        return None
    IH, IW = input_shapes.pop()
    repeat_factors = set()
    for inp, out in exs:
        OH, OW = out.shape
        if OH % IH or OW % IW:
            return None
        rH, rW = OH // IH, OW // IW
        if rH < 1 or rW < 1 or (rH == 1 and rW == 1):
            return None  # must actually repeat something
        repeat_factors.add((rH, rW))
    if len(repeat_factors) != 1:
        return None
    rH, rW = repeat_factors.pop()
    OH, OW = IH * rH, IW * rW
    if OH > 30 or OW > 30:
        return None
    if not all(np.array_equal(out, np.tile(inp, (rH, rW))) for inp, out in exs):
        return None
    pad_h, pad_w = 30 - OH, 30 - OW
    inits = [
        numpy_helper.from_array(np.array([0, 0, 0, 0], dtype=np.int64), 'st'),
        numpy_helper.from_array(np.array([1, 10, IH, IW], dtype=np.int64), 'en'),
        numpy_helper.from_array(np.array([1, 1, rH, rW], dtype=np.int64), 'rp'),
    ]
    nodes = [
        helper.make_node('Slice', ['input', 'st', 'en'], ['cr']),
        helper.make_node('Tile', ['cr', 'rp'], ['tl']),
        helper.make_node('Pad', ['tl'], ['output'],
                         pads=[0, 0, 0, 0, 0, 0, pad_h, pad_w], value=0.0),
    ]
    return mk(nodes, inits)
-
324
def s_upscale(td):
    """Solver: output repeats each input pixel into an sH x sW block (both >= 2)."""
    exs = get_exs(td)
    input_shapes = {inp.shape for inp, _ in exs}
    if len(input_shapes) != 1:
        return None
    IH, IW = input_shapes.pop()
    factors = set()
    for inp, out in exs:
        OH, OW = out.shape
        if OH % IH or OW % IW:
            return None
        sH, sW = OH // IH, OW // IW
        if sH < 2 or sW < 2:
            return None
        factors.add((sH, sW))
    if len(factors) != 1:
        return None
    sH, sW = factors.pop()
    OH, OW = IH * sH, IW * sW
    if OH > 30 or OW > 30:
        return None
    if not all(np.array_equal(out, np.repeat(np.repeat(inp, sH, 0), sW, 1))
               for inp, out in exs):
        return None
    # Every output pixel reads the source cell of its block.
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    for r in range(OH):
        for c in range(OW):
            idx[r, c] = [r // sH, c // sW]
    return _build_gather_model(OH, OW, idx)
-
348
def s_concat(td):
    """Solver: output is 2-4 copies of the input, each optionally flipped or
    rotated 180, concatenated along one axis.

    Refactor: the horizontal and vertical cases previously duplicated ~30
    lines each; they now share one parameterized search loop (horizontal
    tried first, as before). Behavior is unchanged.
    """
    from itertools import product as iproduct
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    transforms = [
        ('id', lambda x: x), ('fliplr', lambda x: np.fliplr(x)),
        ('flipud', lambda x: np.flipud(x)), ('rot180', lambda x: np.rot90(x, 2)),
    ]

    def src_cell(tname, lr, lc):
        # Source input cell for local block coords (lr, lc) under transform tname.
        if tname == 'id':
            return lr, lc
        if tname == 'fliplr':
            return lr, IW - 1 - lc
        if tname == 'flipud':
            return IH - 1 - lr, lc
        return IH - 1 - lr, IW - 1 - lc  # rot180

    for axis in (1, 0):  # horizontal first, then vertical (original order)
        if axis == 1:
            if not (OH == IH and OW % IW == 0 and OW > IW):
                continue
            n = OW // IW
        else:
            if not (OW == IW and OH % IH == 0 and OH > IH):
                continue
            n = OH // IH
        if not (2 <= n <= 4):
            continue
        for combo in iproduct(range(4), repeat=n):
            if not all(np.array_equal(
                           out,
                           np.concatenate([transforms[t][1](inp) for t in combo], axis=axis))
                       for inp, out in exs):
                continue
            idx = np.zeros((OH, OW, 2), dtype=np.int64)
            for oi in range(OH):
                for oj in range(OW):
                    block = oj // IW if axis == 1 else oi // IH
                    lr, lc = (oi, oj % IW) if axis == 1 else (oi % IH, oj)
                    idx[oi, oj] = list(src_cell(transforms[combo[block]][0], lr, lc))
            return _build_gather_model(OH, OW, idx)
    return None
-
394
def s_concat_enhanced(td):
    """Enhanced concat: output is an rH x rW grid of blocks, each one of the
    8 dihedral transforms of the input.

    Fix: the gather-index formulas for 'rot90' and 'rot270' were swapped.
    np.rot90(m, 1)[i, j] == m[j, W-1-i], so rot90 needs sr, sc = lc, IW-1-lr,
    and rot270 needs sr, sc = IH-1-lc, lr. Previously any task whose block
    decomposition used a true 90/270-degree rotation built a wrong index map
    and was then rejected by the final verification pass (a missed solve, not
    a wrong answer). The defensive verify loop is kept as a safety net.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None: return None
    (IH, IW), (OH, OW) = sp
    if IH == OH and IW == OW: return None

    # Output must decompose into whole input-sized blocks.
    if OH % IH != 0 or OW % IW != 0: return None
    rH, rW = OH // IH, OW // IW
    if rH * rW > 16 or rH * rW < 2: return None
    if OH > 30 or OW > 30: return None

    # All 8 symmetry transforms of the dihedral group.
    transforms = [
        ('id', lambda x: x),
        ('fliplr', lambda x: np.fliplr(x)),
        ('flipud', lambda x: np.flipud(x)),
        ('rot180', lambda x: np.rot90(x, 2)),
        ('rot90', lambda x: np.rot90(x, 1)),
        ('rot270', lambda x: np.rot90(x, 3)),
        ('T', lambda x: x.T),
        ('T_fliplr', lambda x: np.fliplr(x.T)),
    ]

    # For each block position, find one transform that matches every example.
    block_transforms = {}
    for bi in range(rH):
        for bj in range(rW):
            found = None
            for tidx, (tname, tfn) in enumerate(transforms):
                ok = True
                for inp, out in exs:
                    block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
                    expected = tfn(inp)
                    # Shape check rejects transposing transforms on non-square input.
                    if expected.shape != (IH, IW) or not np.array_equal(block, expected):
                        ok = False
                        break
                if ok:
                    found = (tidx, tname)
                    break
            if found is None:
                return None
            block_transforms[(bi, bj)] = found

    # Build the static gather index map.
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    for bi in range(rH):
        for bj in range(rW):
            _, tname = block_transforms[(bi, bj)]
            for lr in range(IH):
                for lc in range(IW):
                    oi, oj = bi*IH + lr, bj*IW + lc
                    if tname == 'id': sr, sc = lr, lc
                    elif tname == 'fliplr': sr, sc = lr, IW-1-lc
                    elif tname == 'flipud': sr, sc = IH-1-lr, lc
                    elif tname == 'rot180': sr, sc = IH-1-lr, IW-1-lc
                    elif tname == 'rot90': sr, sc = lc, IW-1-lr      # fixed (was rot270's map)
                    elif tname == 'rot270': sr, sc = IH-1-lc, lr     # fixed (was rot90's map)
                    elif tname == 'T': sr, sc = lc, lr
                    elif tname == 'T_fliplr': sr, sc = IW-1-lc, lr
                    idx[oi, oj] = [sr, sc]

    # Verify the index map reproduces every example exactly.
    for inp, out in exs:
        reconstructed = np.zeros_like(out)
        for oi in range(OH):
            for oj in range(OW):
                reconstructed[oi, oj] = inp[idx[oi, oj, 0], idx[oi, oj, 1]]
        if not np.array_equal(reconstructed, out):
            return None

    return _build_gather_model(OH, OW, idx)
-
469
def s_input_driven_tile(td):
    """Each non-zero input pixel controls a block that's a copy of the input.

    Detection-only stub: the pattern is recognized but currently never built,
    so this solver always returns None (see trailing note).
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None: return None
    (IH,IW),(OH,OW) = sp
    if OH % IH != 0 or OW % IW != 0: return None
    sH, sW = OH // IH, OW // IW
    if sH != IH or sW != IW: return None  # fractal case only: block size == input size
    if OH > 30 or OW > 30: return None

    # Verify the fractal structure on every example.
    for inp, out in exs:
        for bi in range(IH):
            for bj in range(IW):
                block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
                if inp[bi, bj] != 0:
                    if not np.array_equal(block, inp):
                        return None
                else:
                    if not np.all(block == 0):
                        return None

    # Build gather model: each output pixel at (bi*IH+lr, bj*IW+lc) maps to
    # input[lr, lc] if input[bi, bj] != 0, else constant 0
    # Problem: whether block is active depends on input value, which varies.
    # This needs a different ONNX approach: can't use static gather.
    # But we CAN use: Tile input -> Mul by mask derived from input
    # Actually we need: for each (bi,bj) block position, multiply by inp[bi,bj] != 0
    # This is NOT static - it depends on input content.
    # Skip for now - spatial_gather can handle if block positions are fixed.
    return None
-
501
def s_kronecker(td):
    """Solver: output = np.kron(input, ones((sH, sW))) — nearest-neighbor upscale."""
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if OH % IH != 0 or OW % IW != 0:
        return None
    sH, sW = OH // IH, OW // IW
    if sH < 2 or sW < 2:
        return None
    if OH > 30 or OW > 30:
        return None
    ones_block = np.ones((sH, sW), dtype=np.int64)
    if not all(np.array_equal(out, np.kron(inp, ones_block)) for inp, out in exs):
        return None
    # Same gather map as s_upscale: each output pixel reads its block's source.
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    for r in range(OH):
        for c in range(OW):
            idx[r, c] = [r // sH, c // sW]
    return _build_gather_model(OH, OW, idx)
-
524
def s_diagonal_tile(td):
    """Solver: input copied along the block diagonal, all other blocks zero."""
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if OH % IH != 0 or OW % IW != 0:
        return None
    rH, rW = OH // IH, OW // IW
    if rH != rW or rH < 2:
        return None
    if OH > 30 or OW > 30:
        return None

    # Verify: diagonal blocks equal the input, off-diagonal blocks are zero.
    for inp, out in exs:
        for bi in range(rH):
            for bj in range(rW):
                block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
                block_ok = (np.array_equal(block, inp) if bi == bj
                            else np.all(block == 0))
                if not block_ok:
                    return None

    # Diagonal blocks gather from the input; off-diagonal cells are constant 0.
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    cst = np.full((OH, OW), -1, dtype=np.int64)
    for bi in range(rH):
        for bj in range(rW):
            for lr in range(IH):
                for lc in range(IW):
                    oi, oj = bi*IH + lr, bj*IW + lc
                    if bi == bj:
                        idx[oi, oj] = [lr, lc]
                    else:
                        idx[oi, oj] = [-1, -1]
                        cst[oi, oj] = 0
    return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
-
562
def s_constant(td):
    """Solver: every example has the same output grid -> emit it as a constant."""
    if fixed_shapes(td) is None:
        return None
    outputs = [out for _, out in get_exs(td)]
    first = outputs[0]
    if any(not np.array_equal(first, other) for other in outputs[1:]):
        return None
    const = np.zeros((1, 10, 30, 30), dtype=np.float32)
    for r, row in enumerate(first):
        for c, v in enumerate(row):
            const[0, int(v), r, c] = 1.0
    # Multiply input by 0 and reduce it so the graph still consumes 'input',
    # then add the precomputed constant one-hot grid on top.
    inits = [numpy_helper.from_array(np.array(0.0, dtype=np.float32), 'z'),
             numpy_helper.from_array(const, 'c')]
    nodes = [helper.make_node('Mul', ['input', 'z'], ['zd']),
             helper.make_node('ReduceSum', ['zd'], ['s'], axes=[1, 2, 3], keepdims=1),
             helper.make_node('Add', ['s', 'c'], ['output'])]
    return mk(nodes, inits)
-
579
- # ============================================================
580
- # CONV SOLVERS
581
- # ============================================================
582
-
583
def add_onehot_block(nodes, inits, am_name, oh_name):
    """Append an Equal+Cast one-hot stage (replaces OneHot, which lacks a CUDA kernel).

    Compares the ArgMax result `am_name` against class ids 0..9 broadcast over
    the channel axis, then casts the boolean mask to float as `oh_name`.
    """
    class_ids = np.arange(10, dtype=np.int64).reshape(1, 10, 1, 1)
    inits.append(numpy_helper.from_array(class_ids, 'classes'))
    nodes.append(helper.make_node('Equal', [am_name, 'classes'], ['eq']))
    nodes.append(helper.make_node('Cast', ['eq'], [oh_name], to=TensorProto.FLOAT))
-
590
def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False):
    """Shared lstsq conv fitting. Returns (Wconv, B) or None.

    Fits one ks x ks, 10-channel-in/10-channel-out convolution (optional bias)
    by least squares so that argmax over output channels reproduces every
    output pixel exactly; returns None when the fit is inexact or too large.
    """
    pad = ks // 2
    # One feature per (channel, kernel position), plus optional bias column.
    feat = 10 * ks * ks + (1 if use_bias else 0)
    if feat > 20000: return None

    patches, targets = [], []
    for inp_g, out_g in exs_raw:
        ih, iw = inp_g.shape
        if use_full_30:
            # Variable-shape mode: embed in the full 30x30 canvas so the fit
            # sees the same zero padding the runtime model will see.
            oh_full = np.zeros((10, GH, GW), dtype=np.float64)
            for c in range(10): oh_full[c, :ih, :iw] = (inp_g == c)
            oh_pad = np.pad(oh_full, ((0,0),(pad,pad),(pad,pad)))
        else:
            oh_enc = np.zeros((10, ih, iw), dtype=np.float64)
            for c in range(10): oh_enc[c] = (inp_g == c)
            oh_pad = np.pad(oh_enc, ((0,0),(pad,pad),(pad,pad)))

        oh, ow = out_g.shape
        for r in range(oh):
            for c in range(ow):
                p = oh_pad[:, r:r+ks, c:c+ks].flatten()
                if use_bias: p = np.append(p, 1.0)
                patches.append(p)
                targets.append(int(out_g[r, c]))

    n_patches = len(patches)
    # Guard against pathological lstsq problem sizes.
    if feat > 5000 and n_patches > 2000: return None

    P = np.array(patches, dtype=np.float64)
    T = np.array(targets, dtype=np.int64)
    T_oh = np.zeros((len(T), 10), dtype=np.float64)
    for i, t in enumerate(T): T_oh[i, t] = 1.0

    WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
    # Require exact argmax reproduction on every training pixel.
    if not np.array_equal(np.argmax(P @ WT, axis=1), T): return None

    if use_bias:
        # Last lstsq row is the bias; the rest reshape into Conv weights.
        Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
        B = WT[-1].astype(np.float32)
    else:
        Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
        B = None
    return Wconv, B
-
635
def solve_conv_fixed(td, path, time_budget=30.0):
    """Fixed-shape conv: Slice -> Conv -> ArgMax -> Equal+Cast -> Pad.

    Searches kernel sizes 1..29 (odd), without then with bias, under a wall-
    clock budget; saves the first candidate model to `path` and returns
    ('conv_fixed', model) only if it validates on all examples.
    """
    exs = get_exs(td)
    for inp, out in exs:
        if inp.shape != out.shape: return None
    shapes = set(inp.shape for inp, _ in exs)
    if len(shapes) != 1: return None
    IH, IW = shapes.pop()

    t_start = time.time()
    for use_bias in [False, True]:
        for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
            if time.time() - t_start > time_budget: return None
            result = _lstsq_conv(exs, ks, use_bias, use_full_30=False)
            if result is None: continue
            Wconv, B = result
            pad = ks // 2
            # Pad the IHxIW result back out to the 30x30 canvas.
            pad_h, pad_w = GH - IH, GW - IW

            inits = [
                numpy_helper.from_array(np.array([0,0,0,0], dtype=np.int64), 'sl_st'),
                numpy_helper.from_array(np.array([1,10,IH,IW], dtype=np.int64), 'sl_en'),
                numpy_helper.from_array(Wconv, 'W'),
            ]
            conv_inputs = ['grid', 'W']
            if B is not None:
                inits.append(numpy_helper.from_array(B, 'B'))
                conv_inputs.append('B')

            nodes = [
                helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
                helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
                helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
            ]
            add_onehot_block(nodes, inits, 'am', 'oh_out')
            nodes.append(
                helper.make_node('Pad', ['oh_out'], ['output'],
                                 pads=[0,0,0,0,0,0,pad_h,pad_w], value=0.0)
            )

            model = mk(nodes, inits)
            onnx.save(model, path)
            if validate(path, td): return 'conv_fixed', model
    return None
-
680
def solve_conv_variable(td, path, time_budget=30.0):
    """Variable-shape conv: Conv(30x30) -> ArgMax -> Equal+Cast -> Mul(mask).

    Same-shape tasks with varying grid sizes: the conv runs on the full 30x30
    canvas and the output is re-masked by ReduceSum(input) over channels,
    which is 1 on real input pixels and 0 in the padding region.
    """
    exs = get_exs(td)
    for inp, out in exs:
        if inp.shape != out.shape: return None

    t_start = time.time()
    for use_bias in [False, True]:
        for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
            if time.time() - t_start > time_budget: return None
            result = _lstsq_conv(exs, ks, use_bias, use_full_30=True)
            if result is None: continue
            Wconv, B = result
            pad = ks // 2

            inits = [numpy_helper.from_array(Wconv, 'W')]
            conv_inputs = ['input', 'W']
            if B is not None:
                inits.append(numpy_helper.from_array(B, 'B'))
                conv_inputs.append('B')

            nodes = [
                # Per-pixel validity mask: one-hot channels sum to 1 on real pixels.
                helper.make_node('ReduceSum', ['input'], ['mask'], axes=[1], keepdims=1),
                helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
                helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
            ]
            add_onehot_block(nodes, inits, 'am', 'oh_out')
            nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))

            model = mk(nodes, inits)
            onnx.save(model, path)
            if validate(path, td): return 'conv_var', model
    return None
-
714
def solve_conv_diffshape(td, path, time_budget=30.0):
    """Diff-shape conv for fixed io shapes where output is smaller.

    Pipeline: Slice -> Conv -> Slice(crop at a fixed offset) -> ArgMax ->
    Equal+Cast -> Pad. Two crop offsets are tried: top-left aligned (0,0)
    and centered. The lstsq fit mirrors _lstsq_conv but extracts patches at
    the offset crop positions, so it is kept inline here.
    """
    sp = fixed_shapes(td)
    if sp is None: return None
    (IH, IW), (OH, OW) = sp
    if IH == OH and IW == OW: return None
    if OH > IH or OW > IW: return None
    if OH > 30 or OW > 30: return None

    exs = get_exs(td)
    t_start = time.time()

    for dr_off, dc_off in [(0, 0), ((IH-OH)//2, (IW-OW)//2)]:
        for use_bias in [False, True]:
            for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21]:
                if time.time() - t_start > time_budget: return None
                pad = ks // 2
                feat = 10 * ks * ks + (1 if use_bias else 0)
                if feat > 10000: continue

                # Collect one patch per output pixel, sampled at the offset
                # position in the padded one-hot input.
                patches, targets = [], []
                valid = True
                for inp_g, out_g in exs:
                    oh_enc = np.zeros((10, IH, IW), dtype=np.float64)
                    for c in range(10): oh_enc[c] = (inp_g == c)
                    oh_pad = np.pad(oh_enc, ((0,0),(pad,pad),(pad,pad)))
                    for r in range(OH):
                        for c in range(OW):
                            sr, sc = r + dr_off, c + dc_off
                            if sr < 0 or sr >= IH or sc < 0 or sc >= IW:
                                valid = False; break
                            p = oh_pad[:, sr:sr+ks, sc:sc+ks].flatten()
                            if use_bias: p = np.append(p, 1.0)
                            patches.append(p)
                            targets.append(int(out_g[r, c]))
                        if not valid: break
                    if not valid: break
                if not valid: continue

                n_patches = len(patches)
                # Guard against pathological lstsq problem sizes.
                if feat > 5000 and n_patches > 2000: continue

                P = np.array(patches, dtype=np.float64)
                T = np.array(targets, dtype=np.int64)
                T_oh = np.zeros((len(T), 10), dtype=np.float64)
                for i, t in enumerate(T): T_oh[i, t] = 1.0

                WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
                # Require exact argmax reproduction on every training pixel.
                if not np.array_equal(np.argmax(P @ WT, axis=1), T): continue

                if use_bias:
                    Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
                    B = WT[-1].astype(np.float32)
                else:
                    Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
                    B = None

                pad_h, pad_w = GH - OH, GW - OW
                inits = [
                    numpy_helper.from_array(np.array([0,0,0,0], dtype=np.int64), 'sl_st'),
                    numpy_helper.from_array(np.array([1,10,IH,IW], dtype=np.int64), 'sl_en'),
                    numpy_helper.from_array(Wconv, 'W'),
                    numpy_helper.from_array(np.array([0,0,dr_off,dc_off], dtype=np.int64), 'cr_st'),
                    numpy_helper.from_array(np.array([1,10,dr_off+OH,dc_off+OW], dtype=np.int64), 'cr_en'),
                ]
                conv_inputs = ['grid', 'W']
                if B is not None:
                    inits.append(numpy_helper.from_array(B, 'B'))
                    conv_inputs.append('B')

                nodes = [
                    helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
                    helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
                    helper.make_node('Slice', ['co','cr_st','cr_en'], ['co_crop']),
                    helper.make_node('ArgMax', ['co_crop'], ['am'], axis=1, keepdims=1),
                ]
                add_onehot_block(nodes, inits, 'am', 'oh_out')
                nodes.append(
                    helper.make_node('Pad', ['oh_out'], ['output'],
                                     pads=[0,0,0,0,0,0,pad_h,pad_w], value=0.0)
                )

                model = mk(nodes, inits)
                onnx.save(model, path)
                if validate(path, td): return 'conv_diff', model
    return None
-
801
- def solve_conv_var_diff(td, path, time_budget=30.0):
802
- """Variable diff-shape conv: Conv(30x30) -> ArgMax -> Equal+Cast -> Mul(output_mask).
803
- Works when output shape differs from input but mapping is convolutional on 30x30 grid."""
804
- exs = get_exs(td)
805
-
806
- t_start = time.time()
807
- for use_bias in [False, True]:
808
- for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
809
- if time.time() - t_start > time_budget: return None
810
-
811
- pad = ks // 2
812
- feat = 10 * ks * ks + (1 if use_bias else 0)
813
- if feat > 20000: continue
814
-
815
- patches, targets = [], []
816
- for inp_g, out_g in exs:
817
- ih, iw = inp_g.shape
818
- oh, ow = out_g.shape
819
- oh_full = np.zeros((10, GH, GW), dtype=np.float64)
820
- for c in range(10): oh_full[c, :ih, :iw] = (inp_g == c)
821
- oh_pad = np.pad(oh_full, ((0,0),(pad,pad),(pad,pad)))
822
-
823
- for r in range(oh):
824
- for c in range(ow):
825
- p = oh_pad[:, r:r+ks, c:c+ks].flatten()
826
- if use_bias: p = np.append(p, 1.0)
827
- patches.append(p)
828
- targets.append(int(out_g[r, c]))
829
-
830
- n_patches = len(patches)
831
- if feat > 5000 and n_patches > 2000: continue
832
-
833
- P = np.array(patches, dtype=np.float64)
834
- T = np.array(targets, dtype=np.int64)
835
- T_oh = np.zeros((len(T), 10), dtype=np.float64)
836
- for i, t in enumerate(T): T_oh[i, t] = 1.0
837
-
838
- try:
839
- WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
840
- except:
841
- continue
842
- if not np.array_equal(np.argmax(P @ WT, axis=1), T): continue
843
-
844
- if use_bias:
845
- Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
846
- B = WT[-1].astype(np.float32)
847
- else:
848
- Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
849
- B = None
850
-
851
- # Use ReduceSum of output channels as mask (sum across channels == 1 for valid pixels)
852
- # But we don't know the output mask at inference time from input alone...
853
- # We need a way to derive the output mask from the input.
854
- # For same-shape: mask = ReduceSum(input, axis=1) works
855
- # For diff-shape: we need to compute the output mask differently
856
- #
857
- # Approach: Conv output at valid positions should have max > threshold,
858
- # and at padding positions max ≈ 0. Use the ArgMax+OneHot and then
859
- # mask with ReduceSum(input) which is 1 at input positions but 0 at padding.
860
- # BUT output may be LARGER than input...
861
- #
862
- # Alternative: just use Conv -> ArgMax -> Equal+Cast -> Mul(input_mask_expanded)
863
- # where input_mask covers the output region too.
864
- # This won't work if output extends beyond input region.
865
- #
866
- # Simplest correct approach: let the conv produce valid one-hot everywhere,
867
- # then the padding region should naturally produce channel-0 output.
868
- # Since padding is all-zero input, conv output there = bias only.
869
- # If no bias, conv output = 0 for all channels -> argmax gives channel 0 -> onehot gives [1,0,...,0]
870
- # which equals the padding encoding (channel 0 = 1 in padding).
871
- # Wait - that's WRONG for the NeuroGolf format. In the padding region, ALL channels should be 0.
872
- # The one-hot encoding has channel[color]=1, but padding = ALL zeros.
873
- #
874
- # So we NEED a mask. But for diff-shape, what mask?
875
- # If output is always top-left aligned and we know max output size...
876
- # We can't statically determine the output mask from the input.
877
- #
878
- # However: we can try the ReduceSum approach anyway — if conv naturally
879
- # produces channel-0 dominant output in padding, then:
880
- # mask = ReduceSum(input, axis=1) gives 1 for input pixels, 0 for padding
881
- # If output region ⊆ input region, this works.
882
- # If output region > input region... we need the output's ReduceSum instead.
883
-
884
- # For tasks where output fits within input bounds, use input mask
885
- all_output_within_input = all(
886
- out_g.shape[0] <= inp_g.shape[0] and out_g.shape[1] <= inp_g.shape[1]
887
- for inp_g, out_g in exs
888
- )
889
-
890
- if not all_output_within_input:
891
- continue # Skip tasks where output extends beyond input
892
-
893
- inits = [numpy_helper.from_array(Wconv, 'W')]
894
- conv_inputs = ['input', 'W']
895
- if B is not None:
896
- inits.append(numpy_helper.from_array(B, 'B'))
897
- conv_inputs.append('B')
898
-
899
- nodes = [
900
- helper.make_node('ReduceSum', ['input'], ['mask'], axes=[1], keepdims=1),
901
- helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
902
- helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
903
- ]
904
- add_onehot_block(nodes, inits, 'am', 'oh_out')
905
- nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
906
-
907
- model = mk(nodes, inits)
908
- onnx.save(model, path)
909
- if validate(path, td): return 'conv_var_diff', model
910
- return None
911
-
912
- # ============================================================
913
- # MAIN
914
- # ============================================================
915
-
916
# Ordered registry of (name, solver_fn) pairs tried by solve_task() before
# any conv-based search. Each solver takes the raw task data and returns an
# ONNX ModelProto, or None when the task does not fit its pattern. The list
# order is the try order: the first solver whose saved model validates wins.
ANALYTICAL_SOLVERS = [
    ('identity', s_identity), ('constant', s_constant), ('color_map', s_color_map),
    ('transpose', s_transpose), ('flip', s_flip), ('rotate', s_rotate),
    ('tile', s_tile), ('upscale', s_upscale), ('kronecker', s_kronecker),
    ('concat', s_concat), ('concat_enhanced', s_concat_enhanced),
    ('diagonal_tile', s_diagonal_tile),
    ('spatial_gather', s_spatial_gather),
    ('varshape_spatial_gather', s_varshape_spatial_gather),
]
925
-
926
def solve_task(tn, td, outdir, conv_budget=30.0):
    """Attempt to solve one task, writing the winning ONNX model to *outdir*.

    Strategy: try every analytical solver first (fast, tiny models), then
    fall back to the conv solver matching the task's shape category
    (same-shape fixed input, same-shape variable input, fixed diff-shape,
    variable diff-shape with output inside input bounds).

    Args:
        tn: task number; only used to build the output filename.
        td: task data dict with 'train' / 'test' example lists.
        outdir: directory receiving ``taskNNN.onnx`` (created if missing).
        conv_budget: wall-clock seconds allotted to conv search per task.

    Returns:
        ``(solved, solver_name, file_size_bytes, elapsed_sec, model_path)``;
        ``solver_name`` and ``file_size_bytes`` are None when unsolved.
    """
    t_start = time.time()
    os.makedirs(outdir, exist_ok=True)
    path = os.path.join(outdir, f"task{tn:03d}.onnx")

    def _success(sname):
        # Single place that builds the "solved" return tuple.
        return True, sname, os.path.getsize(path), time.time() - t_start, path

    # 1. Analytical solvers: cheap to try, produce the smallest models.
    for sname, sfn in ANALYTICAL_SOLVERS:
        try:
            model = sfn(td)
            if model is None:
                continue
            onnx.save(model, path)
            if validate(path, td):
                return _success(sname)
        except Exception:
            # Best-effort: a solver crashing on an incompatible task is
            # expected. `except Exception` (not bare `except:`) so that
            # KeyboardInterrupt/SystemExit still abort the run.
            pass

    # 2. Pick a conv solver based on the task's shape category.
    exs = get_exs(td)
    same_shape = all(inp.shape == out.shape for inp, out in exs)
    fixed_in = len({inp.shape for inp, _ in exs}) == 1

    if same_shape:
        if fixed_in:
            # Fixed-shape conv gets half the budget so the variable-shape
            # fallback below still has time to run.
            result = solve_conv_fixed(td, path, time_budget=conv_budget / 2)
            if result is not None:
                return _success(result[0])
        result = solve_conv_variable(td, path, time_budget=conv_budget)
        if result is not None:
            return _success(result[0])
    else:
        sp = fixed_shapes(td)
        if sp is not None:
            (IH, IW), (OH, OW) = sp
            if OH <= IH and OW <= IW:
                result = solve_conv_diffshape(td, path, time_budget=conv_budget)
                if result is not None:
                    return _success(result[0])

        # Variable diff-shape conv (only valid when every output fits
        # within its input's bounds; the solver checks that itself).
        result = solve_conv_var_diff(td, path, time_budget=conv_budget)
        if result is not None:
            return _success(result[0])

    return False, None, None, time.time() - t_start, path
976
-
977
def run_tasks(task_nums, tasks, output_dir, conv_budget, use_wandb):
    """Solve each requested task, print per-task results, optionally log to wandb.

    Args:
        task_nums: iterable of task numbers; ids absent from *tasks* are skipped.
        tasks: mapping of task number -> {'data': task_dict, ...}.
        output_dir: directory that receives the ONNX files.
        conv_budget: per-task conv search budget in seconds.
        use_wandb: when True (and wandb is importable), log per-task metrics.

    Returns:
        ``(results, total_score)`` where *results* maps each solved task
        number to ``(solver_name, elapsed_sec, onnx_bytes)``.
    """
    results = {}
    total_score = 0.0
    for tn in task_nums:
        if tn not in tasks:
            continue

        td = tasks[tn]['data']
        ok, sname, sz, t_task, model_path = solve_task(tn, td, output_dir, conv_budget)

        # Reset all per-task metrics every iteration. The previous version
        # only reset `cost` on failure, so an unsolved task logged the
        # *previous* task's `score` to wandb.
        macs = memory = params = cost = 0
        score = 0.0

        if ok:
            try:
                macs, memory, params = score_network(model_path)
                if macs is None:
                    macs, memory, params = 0, 0, 0
            except Exception:
                macs, memory, params = 0, 0, 0
            cost = macs + memory + params
            # Local estimate of the championship score: cheaper nets score higher.
            score = max(1, 25 - math.log(max(1, cost)))
            total_score += score

            results[tn] = (sname, t_task, sz)
            print(f"Task {tn:3d}: {sname:25s} {score:7.3f} {cost:>12} {t_task:7.3f}s ({sz:>8,} bytes)")
        else:
            print(f"Task {tn:3d}: UNSOLVED {t_task:7.3f}s")

        if use_wandb and wandb is not None:
            wandb.log({
                "task_id": tn,
                "solver": sname if ok else "unsolved",
                "onnx_bytes": sz if ok else 0,
                "task_time_sec": t_task,
                "macs": macs,
                "memory": memory,
                "params": params,
                "cost": cost,
                "score": score,
            })

    return results, total_score
1020
-
1021
-
1022
def main():
    """CLI entry point: parse arguments, load tasks, and run the solver."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', default='ARC-AGI/data/training/')
    parser.add_argument('--output_dir', default='submission')
    parser.add_argument('--kaggle', action='store_true')
    parser.add_argument('--conv_budget', type=float, default=30.0)
    parser.add_argument('--tasks', type=str, default='')
    parser.add_argument('--device', type=str, default='auto', choices=['auto', 'cpu', 'cuda'])
    parser.add_argument('--use_wandb', action='store_true')
    args = parser.parse_args()

    # ORT provider selection is a module-level global read by the validators.
    global ORT_PROVIDERS
    if args.device == 'cuda':
        ORT_PROVIDERS = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    elif args.device == 'cpu':
        ORT_PROVIDERS = ['CPUExecutionProvider']
    # 'auto' keeps the module default from get_providers().
    print(f"Using providers: {ORT_PROVIDERS}")

    tasks = load_tasks_kaggle(args.data_dir) if args.kaggle else load_tasks_dir(args.data_dir)
    task_nums = [int(t) for t in args.tasks.split(',')] if args.tasks else sorted(tasks.keys())
    print(f"Loaded {len(tasks)} tasks, solving {len(task_nums)}")
    print(f"Conv budget: {args.conv_budget}s per task")
    print("=" * 70)

    # Run config mirrored to wandb when logging is enabled.
    config = {
        "device": args.device,
        "conv_budget": args.conv_budget,
        "data_dir": args.data_dir,
        "tasks": args.tasks,
    }

    t0 = time.time()
    if args.use_wandb and wandb is not None:
        with wandb.init(project="neurogolf", name="solver_run", config=config):
            results, total_score = run_tasks(task_nums, tasks, args.output_dir,
                                             args.conv_budget, use_wandb=True)
    else:
        results, total_score = run_tasks(task_nums, tasks, args.output_dir,
                                         args.conv_budget, use_wandb=False)
    elapsed = time.time() - t0

    print(f"\n{'='*70}")
    print(f"Solved: {len(results)}/{len(task_nums)} in {elapsed:.0f}s")
    for s, c in Counter(v[0] for v in results.values()).most_common():
        print(f" {s}: {c}")

    # Single directory scan for both the file count and the size total
    # (previously the directory was listed twice).
    onnx_paths = [os.path.join(args.output_dir, f)
                  for f in os.listdir(args.output_dir) if f.endswith('.onnx')]
    total_size = sum(os.path.getsize(p) for p in onnx_paths)
    print(f"\n{len(onnx_paths)} ONNX files, Total local estimated score: {total_score:.1f} total {total_size/1024:.1f} KB")
1074
-
1075
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
 
1
+ placeholder