Fix ReduceSum axes for opset 17 (axes must be tensor input, not attribute)
Browse filesThree locations fixed:
- s_constant: ReduceSum axes=[1,2,3] → tensor input
- solve_conv_variable: ReduceSum axes=[1] → tensor input
- solve_conv_var_diff: ReduceSum axes=[1] → tensor input
Also fixes solve_conv_var_diff which was truncated in previous upload.
- neurogolf_solver.py +65 -235
neurogolf_solver.py
CHANGED
|
@@ -9,6 +9,7 @@ v5 CHANGES (from v4):
|
|
| 9 |
- s_rotate k=2: double Slice(step=-1) — 0 MACs (was ~165K)
|
| 10 |
- s_rotate k=1,3: Slice+Transpose for square grids (0 MACs), Gather fallback for non-square
|
| 11 |
- All Pad nodes: tensor-based pads input (opset 17 requirement)
|
|
|
|
| 12 |
- All other solvers unchanged from v4
|
| 13 |
|
| 14 |
Solvers:
|
|
@@ -50,12 +51,8 @@ OPSET = [helper.make_opsetid("", 17)]
|
|
| 50 |
|
| 51 |
INT64_MIN = int(np.iinfo(np.int64).min)
|
| 52 |
|
| 53 |
-
# Officially excluded tasks (score 0 regardless)
|
| 54 |
EXCLUDED_TASKS = {21, 55, 80, 184, 202, 366}
|
| 55 |
-
|
| 56 |
-
# Max ARC-GEN examples to use for validation (to keep runtime reasonable)
|
| 57 |
MAX_ARCGEN_VALIDATE = 30
|
| 58 |
-
# Max ARC-GEN examples for conv fitting (keep separate from validation!)
|
| 59 |
MAX_ARCGEN_FIT = 0
|
| 60 |
|
| 61 |
def get_providers():
|
|
@@ -68,7 +65,6 @@ ORT_PROVIDERS = get_providers()
|
|
| 68 |
# ============================================================
|
| 69 |
|
| 70 |
def load_tasks_dir(data_dir, arcgen_dir=None):
|
| 71 |
-
"""Load ARC-AGI tasks and optionally merge ARC-GEN data."""
|
| 72 |
files = sorted(f for f in os.listdir(data_dir) if f.endswith('.json'))
|
| 73 |
tasks = {}
|
| 74 |
for i, f in enumerate(files):
|
|
@@ -86,7 +82,6 @@ def load_tasks_dir(data_dir, arcgen_dir=None):
|
|
| 86 |
return tasks
|
| 87 |
|
| 88 |
def load_tasks_kaggle(data_dir):
|
| 89 |
-
"""Load Kaggle format tasks (already have arc-gen embedded)."""
|
| 90 |
tasks = {}
|
| 91 |
for tn in range(1, 401):
|
| 92 |
path = os.path.join(data_dir, f"task{tn:03d}.json")
|
|
@@ -107,7 +102,6 @@ def to_onehot(grid):
|
|
| 107 |
return arr
|
| 108 |
|
| 109 |
def validate(path, td):
|
| 110 |
-
"""Validate model against ALL examples: train + test + arc-gen."""
|
| 111 |
try:
|
| 112 |
opts = ort.SessionOptions()
|
| 113 |
opts.log_severity_level = 3
|
|
@@ -130,7 +124,6 @@ def validate(path, td):
|
|
| 130 |
return True
|
| 131 |
|
| 132 |
def validate_raw(raw_bytes, td):
|
| 133 |
-
"""Validate model from raw bytes against ALL examples."""
|
| 134 |
try:
|
| 135 |
opts = ort.SessionOptions()
|
| 136 |
opts.log_severity_level = 3
|
|
@@ -153,14 +146,13 @@ def validate_raw(raw_bytes, td):
|
|
| 153 |
return True
|
| 154 |
|
| 155 |
# ============================================================
|
| 156 |
-
# STATIC PROFILER
|
| 157 |
# ============================================================
|
| 158 |
|
| 159 |
BANNED_OPS = {'Loop', 'Scan', 'NonZero', 'Unique', 'If', 'Function'}
|
| 160 |
MAX_FILESIZE = int(1.44 * 1024 * 1024)
|
| 161 |
|
| 162 |
def score_network(path):
|
| 163 |
-
"""Static profiler matching Kaggle scoring: cost = macs + memory + params."""
|
| 164 |
if HAS_ONNX_TOOL:
|
| 165 |
try:
|
| 166 |
return _score_network_official(path)
|
|
@@ -169,23 +161,19 @@ def score_network(path):
|
|
| 169 |
return _static_profile(path)
|
| 170 |
|
| 171 |
def _static_profile(path):
|
| 172 |
-
"""Compute cost without onnx_tool: params + nbytes + macs."""
|
| 173 |
try:
|
| 174 |
model = onnx.load(path)
|
| 175 |
except:
|
| 176 |
return None, None, None
|
| 177 |
-
|
| 178 |
tensors = {}
|
| 179 |
params = 0
|
| 180 |
nbytes = 0
|
| 181 |
macs = 0
|
| 182 |
-
|
| 183 |
for init in model.graph.initializer:
|
| 184 |
a = numpy_helper.to_array(init)
|
| 185 |
tensors[init.name] = a
|
| 186 |
params += a.size
|
| 187 |
nbytes += a.nbytes
|
| 188 |
-
|
| 189 |
for nd in model.graph.node:
|
| 190 |
if nd.op_type == 'Constant':
|
| 191 |
for attr in nd.attribute:
|
|
@@ -198,16 +186,13 @@ def _static_profile(path):
|
|
| 198 |
nbytes += a.nbytes
|
| 199 |
except:
|
| 200 |
pass
|
| 201 |
-
|
| 202 |
if nd.op_type in BANNED_OPS:
|
| 203 |
return None, None, None
|
| 204 |
-
|
| 205 |
if nd.op_type == 'Conv' and len(nd.input) >= 2 and nd.input[1] in tensors:
|
| 206 |
w = tensors[nd.input[1]]
|
| 207 |
if w.ndim == 4:
|
| 208 |
co, ci, kh, kw = w.shape
|
| 209 |
macs += co * ci * kh * kw * GH * GW
|
| 210 |
-
|
| 211 |
return int(macs), int(nbytes), int(params)
|
| 212 |
|
| 213 |
# ============================================================
|
|
@@ -215,12 +200,10 @@ def _static_profile(path):
|
|
| 215 |
# ============================================================
|
| 216 |
|
| 217 |
def _make_int64_init(name, values):
|
| 218 |
-
"""Create an int64 tensor initializer from a list of values."""
|
| 219 |
return numpy_helper.from_array(np.array(values, dtype=np.int64), name)
|
| 220 |
|
| 221 |
def _build_pad_node(input_name, output_name, pad_h, pad_w, inits, suffix=''):
|
| 222 |
-
"""
|
| 223 |
-
Pads [0,0,0,0, 0,0,pad_h,pad_w] — only spatial end-padding."""
|
| 224 |
pads_name = f'pads{suffix}'
|
| 225 |
cv_name = f'pad_cv{suffix}'
|
| 226 |
pads_arr = np.array([0, 0, 0, 0, 0, 0, pad_h, pad_w], dtype=np.int64)
|
|
@@ -229,7 +212,7 @@ def _build_pad_node(input_name, output_name, pad_h, pad_w, inits, suffix=''):
|
|
| 229 |
return helper.make_node('Pad', [input_name, pads_name, cv_name], [output_name], mode='constant')
|
| 230 |
|
| 231 |
def _build_slice_crop(input_name, output_name, IH, IW, inits, suffix=''):
|
| 232 |
-
"""
|
| 233 |
st_name = f'crop_st{suffix}'
|
| 234 |
en_name = f'crop_en{suffix}'
|
| 235 |
inits.append(_make_int64_init(st_name, [0, 0, 0, 0]))
|
|
@@ -237,7 +220,7 @@ def _build_slice_crop(input_name, output_name, IH, IW, inits, suffix=''):
|
|
| 237 |
return helper.make_node('Slice', [input_name, st_name, en_name], [output_name])
|
| 238 |
|
| 239 |
def _build_slice_reverse(input_name, output_name, axis, dim_size, inits, suffix=''):
|
| 240 |
-
"""
|
| 241 |
st_name = f'rev_st{suffix}'
|
| 242 |
en_name = f'rev_en{suffix}'
|
| 243 |
ax_name = f'rev_ax{suffix}'
|
|
@@ -248,6 +231,12 @@ def _build_slice_reverse(input_name, output_name, axis, dim_size, inits, suffix=
|
|
| 248 |
inits.append(_make_int64_init(sp_name, [-1]))
|
| 249 |
return helper.make_node('Slice', [input_name, st_name, en_name, ax_name, sp_name], [output_name])
|
| 250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
def mk(nodes, inits=None):
|
| 252 |
x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
|
| 253 |
y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
|
|
@@ -255,45 +244,35 @@ def mk(nodes, inits=None):
|
|
| 255 |
return helper.make_model(g, ir_version=IR, opset_imports=OPSET)
|
| 256 |
|
| 257 |
def get_exs(td):
|
| 258 |
-
"""Get examples for analytical solvers (train+test only)."""
|
| 259 |
return [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
|
| 260 |
for ex in td['train'] + td['test']]
|
| 261 |
|
| 262 |
def get_exs_for_fitting(td):
|
| 263 |
-
"""Get examples for conv fitting. Uses train+test + arc-gen WHERE SIZES MATCH."""
|
| 264 |
base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
|
| 265 |
for ex in td['train'] + td['test']]
|
| 266 |
-
|
| 267 |
if not base_exs:
|
| 268 |
return base_exs
|
| 269 |
-
|
| 270 |
base_shapes = {inp.shape for inp, _ in base_exs}
|
| 271 |
if len(base_shapes) != 1:
|
| 272 |
return base_exs
|
| 273 |
-
|
| 274 |
base_shape = list(base_shapes)[0]
|
| 275 |
-
|
| 276 |
ag_exs = []
|
| 277 |
for ex in td.get('arc-gen', []):
|
| 278 |
inp = np.array(ex['input'], dtype=np.int64)
|
| 279 |
out = np.array(ex['output'], dtype=np.int64)
|
| 280 |
if inp.shape == base_shape and out.shape == base_exs[0][1].shape:
|
| 281 |
ag_exs.append((inp, out))
|
| 282 |
-
|
| 283 |
return base_exs + ag_exs[:10]
|
| 284 |
|
| 285 |
def get_exs_for_fitting_variable(td):
|
| 286 |
-
"""Get examples for variable-shape conv fitting."""
|
| 287 |
base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
|
| 288 |
for ex in td['train'] + td['test']]
|
| 289 |
-
|
| 290 |
ag_exs = []
|
| 291 |
for ex in td.get('arc-gen', []):
|
| 292 |
inp = np.array(ex['input'], dtype=np.int64)
|
| 293 |
out = np.array(ex['output'], dtype=np.int64)
|
| 294 |
if inp.shape == out.shape and inp.shape[0] <= 30 and inp.shape[1] <= 30:
|
| 295 |
ag_exs.append((inp, out))
|
| 296 |
-
|
| 297 |
return base_exs + ag_exs[:20]
|
| 298 |
|
| 299 |
def fixed_shapes(td):
|
|
@@ -303,11 +282,10 @@ def fixed_shapes(td):
|
|
| 303 |
return list(shapes)[0] if len(shapes) == 1 else None
|
| 304 |
|
| 305 |
# ============================================================
|
| 306 |
-
# GATHER HELPERS
|
| 307 |
# ============================================================
|
| 308 |
|
| 309 |
def _build_gather_model(OH, OW, idx):
|
| 310 |
-
"""Gather-based spatial remapping. Used for concat, spatial_gather, etc."""
|
| 311 |
flat_idx = np.zeros((GH*GW,), dtype=np.int64)
|
| 312 |
mask = np.zeros((1,1,GH,GW), dtype=np.float32)
|
| 313 |
for oi in range(OH):
|
|
@@ -329,7 +307,6 @@ def _build_gather_model(OH, OW, idx):
|
|
| 329 |
return mk(nodes, inits)
|
| 330 |
|
| 331 |
def _build_gather_model_with_const(IH, IW, OH, OW, idx, cst):
|
| 332 |
-
"""Gather-based spatial remapping with constant pixels."""
|
| 333 |
flat_idx = np.zeros((GH*GW,), dtype=np.int64)
|
| 334 |
gather_mask = np.zeros((1,1,GH,GW), dtype=np.float32)
|
| 335 |
const_oh = np.zeros((1,10,GH,GW), dtype=np.float32)
|
|
@@ -378,9 +355,7 @@ def s_color_map(td):
|
|
| 378 |
iv, ov = int(iv), int(ov)
|
| 379 |
if iv in cm and cm[iv] != ov: return None
|
| 380 |
cm[iv] = ov
|
| 381 |
-
|
| 382 |
is_permutation = (set(cm.keys()) == set(cm.values()))
|
| 383 |
-
|
| 384 |
if is_permutation:
|
| 385 |
gather_ch = np.arange(10, dtype=np.int32)
|
| 386 |
for src, dst in cm.items():
|
|
@@ -397,106 +372,71 @@ def s_color_map(td):
|
|
| 397 |
[numpy_helper.from_array(W, 'W')])
|
| 398 |
|
| 399 |
def s_transpose(td):
|
| 400 |
-
"""Transpose spatial dimensions. Already near-zero cost with Transpose node."""
|
| 401 |
for ex in td['train']+td['test']:
|
| 402 |
if not np.array_equal(np.array(ex['output']), np.array(ex['input']).T): return None
|
| 403 |
return mk([helper.make_node('Transpose', ['input'], ['output'], perm=[0,1,3,2])])
|
| 404 |
|
| 405 |
def s_flip(td):
|
| 406 |
-
"""Flip using Slice(step=-1) — zero MACs
|
| 407 |
exs = get_exs(td)
|
| 408 |
sp = fixed_shapes(td)
|
| 409 |
if sp is None: return None
|
| 410 |
(IH,IW),(OH,OW) = sp
|
| 411 |
if (IH,IW) != (OH,OW): return None
|
| 412 |
-
|
| 413 |
for axis, flip_fn in [(0, np.flipud), (1, np.fliplr)]:
|
| 414 |
if all(np.array_equal(out, flip_fn(inp)) for inp, out in exs):
|
| 415 |
-
# axis 0 = flipud = reverse dim 2 (H)
|
| 416 |
-
# axis 1 = fliplr = reverse dim 3 (W)
|
| 417 |
onnx_axis = 2 if axis == 0 else 3
|
| 418 |
dim_size = IH if axis == 0 else IW
|
| 419 |
pad_h, pad_w = GH - IH, GW - IW
|
| 420 |
-
|
| 421 |
inits = []
|
| 422 |
nodes = []
|
| 423 |
-
|
| 424 |
-
# Step 1: Crop input to [1,10,IH,IW]
|
| 425 |
nodes.append(_build_slice_crop('input', 'cropped', IH, IW, inits))
|
| 426 |
-
|
| 427 |
-
# Step 2: Reverse the target axis
|
| 428 |
nodes.append(_build_slice_reverse('cropped', 'flipped', onnx_axis, dim_size, inits))
|
| 429 |
-
|
| 430 |
-
# Step 3: Pad back to [1,10,30,30]
|
| 431 |
nodes.append(_build_pad_node('flipped', 'output', pad_h, pad_w, inits))
|
| 432 |
-
|
| 433 |
return mk(nodes, inits)
|
| 434 |
return None
|
| 435 |
|
| 436 |
def s_rotate(td):
|
| 437 |
-
"""Rotate using Slice+Transpose
|
| 438 |
-
|
| 439 |
exs = get_exs(td)
|
| 440 |
sp = fixed_shapes(td)
|
| 441 |
if sp is None: return None
|
| 442 |
(IH,IW),(OH,OW) = sp
|
| 443 |
-
|
| 444 |
for k in [1, 2, 3]:
|
| 445 |
if not all(np.array_equal(out, np.rot90(inp, k)) for inp, out in exs):
|
| 446 |
continue
|
| 447 |
-
|
| 448 |
if k == 2:
|
| 449 |
-
# 180° = flipud + fliplr — works for any shape
|
| 450 |
-
# output[r,c] = input[IH-1-r, IW-1-c]
|
| 451 |
pad_h, pad_w = GH - OH, GW - OW
|
| 452 |
inits = []
|
| 453 |
nodes = []
|
| 454 |
-
|
| 455 |
-
# Crop to [1,10,IH,IW]
|
| 456 |
nodes.append(_build_slice_crop('input', 'cropped', IH, IW, inits))
|
| 457 |
-
# Reverse axis 2 (H)
|
| 458 |
nodes.append(_build_slice_reverse('cropped', 'flip_h', 2, IH, inits, suffix='_h'))
|
| 459 |
-
# Reverse axis 3 (W)
|
| 460 |
nodes.append(_build_slice_reverse('flip_h', 'rotated', 3, IW, inits, suffix='_w'))
|
| 461 |
-
# Pad back
|
| 462 |
nodes.append(_build_pad_node('rotated', 'output', pad_h, pad_w, inits))
|
| 463 |
-
|
| 464 |
return mk(nodes, inits)
|
| 465 |
-
|
| 466 |
elif k == 1 and IH == IW:
|
| 467 |
-
# rot90 CCW
|
| 468 |
-
# output[r,c] = input[c, IH-1-r]
|
| 469 |
-
# Step 1: Transpose [0,1,3,2]: temp[r,c] = input[c,r]
|
| 470 |
-
# Step 2: Reverse axis 2: out[r,c] = temp[IH-1-r,c] = input[c,IH-1-r] ✓
|
| 471 |
pad_h, pad_w = GH - IH, GW - IW
|
| 472 |
inits = []
|
| 473 |
nodes = []
|
| 474 |
-
|
| 475 |
nodes.append(_build_slice_crop('input', 'cropped', IH, IW, inits))
|
| 476 |
nodes.append(helper.make_node('Transpose', ['cropped'], ['transposed'], perm=[0,1,3,2]))
|
| 477 |
nodes.append(_build_slice_reverse('transposed', 'rotated', 2, IH, inits))
|
| 478 |
nodes.append(_build_pad_node('rotated', 'output', pad_h, pad_w, inits))
|
| 479 |
-
|
| 480 |
return mk(nodes, inits)
|
| 481 |
-
|
| 482 |
elif k == 3 and IH == IW:
|
| 483 |
-
# rot270 CCW
|
| 484 |
-
# output[r,c] = input[IW-1-c, r]
|
| 485 |
-
# Step 1: Reverse axis 2: temp[r,c] = input[IH-1-r,c]
|
| 486 |
-
# Step 2: Transpose [0,1,3,2]: out[r,c] = temp[c,r] = input[IH-1-c,r] ✓ (IH=IW)
|
| 487 |
pad_h, pad_w = GH - IH, GW - IW
|
| 488 |
inits = []
|
| 489 |
nodes = []
|
| 490 |
-
|
| 491 |
nodes.append(_build_slice_crop('input', 'cropped', IH, IW, inits))
|
| 492 |
nodes.append(_build_slice_reverse('cropped', 'flipped', 2, IH, inits))
|
| 493 |
nodes.append(helper.make_node('Transpose', ['flipped'], ['rotated'], perm=[0,1,3,2]))
|
| 494 |
nodes.append(_build_pad_node('rotated', 'output', pad_h, pad_w, inits))
|
| 495 |
-
|
| 496 |
return mk(nodes, inits)
|
| 497 |
-
|
| 498 |
else:
|
| 499 |
-
# Non-square k=1 or k=3:
|
| 500 |
idx = np.zeros((OH,OW,2), dtype=np.int64)
|
| 501 |
for r in range(OH):
|
| 502 |
for c in range(OW):
|
|
@@ -527,11 +467,9 @@ def s_spatial_gather(td):
|
|
| 527 |
return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
|
| 528 |
|
| 529 |
def s_varshape_spatial_gather(td):
|
| 530 |
-
"""Spatial gather that works for variable-shape tasks by embedding in 30x30."""
|
| 531 |
sp = fixed_shapes(td)
|
| 532 |
if sp is not None: return None
|
| 533 |
exs = get_exs(td)
|
| 534 |
-
|
| 535 |
exs_30 = []
|
| 536 |
for inp, out in exs:
|
| 537 |
ih, iw = inp.shape
|
|
@@ -541,10 +479,8 @@ def s_varshape_spatial_gather(td):
|
|
| 541 |
inp30[:ih, :iw] = inp
|
| 542 |
out30[:oh, :ow] = out
|
| 543 |
exs_30.append((inp30, out30))
|
| 544 |
-
|
| 545 |
idx = np.full((30, 30, 2), -1, dtype=np.int64)
|
| 546 |
cst = np.full((30, 30), -1, dtype=np.int64)
|
| 547 |
-
|
| 548 |
for oi in range(30):
|
| 549 |
for oj in range(30):
|
| 550 |
vals = set(int(out30[oi, oj]) for _, out30 in exs_30)
|
|
@@ -560,7 +496,6 @@ def s_varshape_spatial_gather(td):
|
|
| 560 |
if found: break
|
| 561 |
if not found and cst[oi, oj] < 0:
|
| 562 |
return None
|
| 563 |
-
|
| 564 |
return _build_gather_model_with_const(30, 30, 30, 30, idx, cst)
|
| 565 |
|
| 566 |
def s_tile(td):
|
|
@@ -665,29 +600,21 @@ def s_concat(td):
|
|
| 665 |
return None
|
| 666 |
|
| 667 |
def s_concat_enhanced(td):
|
| 668 |
-
"""Enhanced concat with all 8 dihedral group transforms."""
|
| 669 |
exs = get_exs(td)
|
| 670 |
sp = fixed_shapes(td)
|
| 671 |
if sp is None: return None
|
| 672 |
(IH,IW),(OH,OW) = sp
|
| 673 |
if IH == OH and IW == OW: return None
|
| 674 |
-
|
| 675 |
if OH % IH != 0 or OW % IW != 0: return None
|
| 676 |
rH, rW = OH // IH, OW // IW
|
| 677 |
if rH * rW > 16 or rH * rW < 2: return None
|
| 678 |
if OH > 30 or OW > 30: return None
|
| 679 |
-
|
| 680 |
transforms = [
|
| 681 |
-
('id', lambda x: x),
|
| 682 |
-
('
|
| 683 |
-
('
|
| 684 |
-
('
|
| 685 |
-
('rot90', lambda x: np.rot90(x, 1)),
|
| 686 |
-
('rot270', lambda x: np.rot90(x, 3)),
|
| 687 |
-
('T', lambda x: x.T),
|
| 688 |
-
('T_fliplr', lambda x: np.fliplr(x.T)),
|
| 689 |
]
|
| 690 |
-
|
| 691 |
block_transforms = {}
|
| 692 |
for bi in range(rH):
|
| 693 |
for bj in range(rW):
|
|
@@ -698,15 +625,11 @@ def s_concat_enhanced(td):
|
|
| 698 |
block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
|
| 699 |
expected = tfn(inp)
|
| 700 |
if expected.shape != (IH, IW) or not np.array_equal(block, expected):
|
| 701 |
-
ok = False
|
| 702 |
-
break
|
| 703 |
if ok:
|
| 704 |
-
found = (tidx, tname)
|
| 705 |
-
|
| 706 |
-
if found is None:
|
| 707 |
-
return None
|
| 708 |
block_transforms[(bi, bj)] = found
|
| 709 |
-
|
| 710 |
idx = np.zeros((OH, OW, 2), dtype=np.int64)
|
| 711 |
for bi in range(rH):
|
| 712 |
for bj in range(rW):
|
|
@@ -723,19 +646,15 @@ def s_concat_enhanced(td):
|
|
| 723 |
elif tname == 'T': sr, sc = lc, lr
|
| 724 |
elif tname == 'T_fliplr': sr, sc = IW-1-lc, lr
|
| 725 |
idx[oi, oj] = [sr, sc]
|
| 726 |
-
|
| 727 |
for inp, out in exs:
|
| 728 |
reconstructed = np.zeros_like(out)
|
| 729 |
for oi in range(OH):
|
| 730 |
for oj in range(OW):
|
| 731 |
reconstructed[oi,oj] = inp[idx[oi,oj,0], idx[oi,oj,1]]
|
| 732 |
-
if not np.array_equal(reconstructed, out):
|
| 733 |
-
return None
|
| 734 |
-
|
| 735 |
return _build_gather_model(OH, OW, idx)
|
| 736 |
|
| 737 |
def s_input_driven_tile(td):
|
| 738 |
-
"""Each non-zero input pixel controls a block that's a copy of the input."""
|
| 739 |
exs = get_exs(td)
|
| 740 |
sp = fixed_shapes(td)
|
| 741 |
if sp is None: return None
|
|
@@ -744,21 +663,17 @@ def s_input_driven_tile(td):
|
|
| 744 |
sH, sW = OH // IH, OW // IW
|
| 745 |
if sH != IH or sW != IW: return None
|
| 746 |
if OH > 30 or OW > 30: return None
|
| 747 |
-
|
| 748 |
for inp, out in exs:
|
| 749 |
for bi in range(IH):
|
| 750 |
for bj in range(IW):
|
| 751 |
block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
|
| 752 |
if inp[bi, bj] != 0:
|
| 753 |
-
if not np.array_equal(block, inp):
|
| 754 |
-
return None
|
| 755 |
else:
|
| 756 |
-
if not np.all(block == 0):
|
| 757 |
-
return None
|
| 758 |
return None
|
| 759 |
|
| 760 |
def s_kronecker(td):
|
| 761 |
-
"""output = kron(input, ones(sH,sW)) — nearest-neighbor upscaling."""
|
| 762 |
exs = get_exs(td)
|
| 763 |
sp = fixed_shapes(td)
|
| 764 |
if sp is None: return None
|
|
@@ -767,12 +682,8 @@ def s_kronecker(td):
|
|
| 767 |
sH, sW = OH // IH, OW // IW
|
| 768 |
if sH < 2 or sW < 2: return None
|
| 769 |
if OH > 30 or OW > 30: return None
|
| 770 |
-
|
| 771 |
for inp, out in exs:
|
| 772 |
-
|
| 773 |
-
if not np.array_equal(out, expected):
|
| 774 |
-
return None
|
| 775 |
-
|
| 776 |
idx = np.zeros((OH,OW,2), dtype=np.int64)
|
| 777 |
for r in range(OH):
|
| 778 |
for c in range(OW):
|
|
@@ -780,7 +691,6 @@ def s_kronecker(td):
|
|
| 780 |
return _build_gather_model(OH, OW, idx)
|
| 781 |
|
| 782 |
def s_diagonal_tile(td):
|
| 783 |
-
"""Input placed along diagonal: block[i,i] = input, rest = 0."""
|
| 784 |
exs = get_exs(td)
|
| 785 |
sp = fixed_shapes(td)
|
| 786 |
if sp is None: return None
|
|
@@ -789,18 +699,14 @@ def s_diagonal_tile(td):
|
|
| 789 |
rH, rW = OH // IH, OW // IW
|
| 790 |
if rH != rW or rH < 2: return None
|
| 791 |
if OH > 30 or OW > 30: return None
|
| 792 |
-
|
| 793 |
for inp, out in exs:
|
| 794 |
for bi in range(rH):
|
| 795 |
for bj in range(rW):
|
| 796 |
block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
|
| 797 |
if bi == bj:
|
| 798 |
-
if not np.array_equal(block, inp):
|
| 799 |
-
return None
|
| 800 |
else:
|
| 801 |
-
if not np.all(block == 0):
|
| 802 |
-
return None
|
| 803 |
-
|
| 804 |
idx = np.zeros((OH,OW,2), dtype=np.int64)
|
| 805 |
cst = np.full((OH,OW), -1, dtype=np.int64)
|
| 806 |
for bi in range(rH):
|
|
@@ -813,11 +719,9 @@ def s_diagonal_tile(td):
|
|
| 813 |
else:
|
| 814 |
idx[oi, oj] = [-1, -1]
|
| 815 |
cst[oi, oj] = 0
|
| 816 |
-
|
| 817 |
return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
|
| 818 |
|
| 819 |
def s_shift(td):
|
| 820 |
-
"""Detect constant spatial shift of the grid."""
|
| 821 |
exs = get_exs(td)
|
| 822 |
sp = fixed_shapes(td)
|
| 823 |
if sp is None: return None
|
|
@@ -850,13 +754,11 @@ def s_shift(td):
|
|
| 850 |
return None
|
| 851 |
|
| 852 |
def s_gravity(td):
|
| 853 |
-
"""Detect gravity-like compaction in one direction."""
|
| 854 |
exs = get_exs(td)
|
| 855 |
sp = fixed_shapes(td)
|
| 856 |
if sp is None: return None
|
| 857 |
(IH, IW), (OH, OW) = sp
|
| 858 |
if (IH, IW) != (OH, OW): return None
|
| 859 |
-
|
| 860 |
def _gravity(grid, direction):
|
| 861 |
r = np.zeros_like(grid); h, w = grid.shape
|
| 862 |
if direction in ('down', 'up'):
|
|
@@ -870,14 +772,12 @@ def s_gravity(td):
|
|
| 870 |
if direction == 'right': r[rr, w-len(nz):w] = nz
|
| 871 |
else: r[rr, :len(nz)] = nz
|
| 872 |
return r
|
| 873 |
-
|
| 874 |
for d in ('down', 'up', 'left', 'right'):
|
| 875 |
if all(np.array_equal(_gravity(inp, d), out) for inp, out in exs):
|
| 876 |
return None
|
| 877 |
return None
|
| 878 |
|
| 879 |
def s_mirror_h(td):
|
| 880 |
-
"""Output = input | flip(input, horizontal), doubling width."""
|
| 881 |
exs = get_exs(td)
|
| 882 |
sp = fixed_shapes(td)
|
| 883 |
if sp is None: return None
|
|
@@ -885,8 +785,7 @@ def s_mirror_h(td):
|
|
| 885 |
if OH != IH or OW != 2 * IW: return None
|
| 886 |
if OW > 30: return None
|
| 887 |
for inp, out in exs:
|
| 888 |
-
|
| 889 |
-
if not np.array_equal(expected, out): return None
|
| 890 |
idx = np.zeros((OH, OW, 2), dtype=np.int64)
|
| 891 |
for r in range(OH):
|
| 892 |
for c in range(OW):
|
|
@@ -895,7 +794,6 @@ def s_mirror_h(td):
|
|
| 895 |
return _build_gather_model(OH, OW, idx)
|
| 896 |
|
| 897 |
def s_mirror_v(td):
|
| 898 |
-
"""Output = input over flip(input, vertical), doubling height."""
|
| 899 |
exs = get_exs(td)
|
| 900 |
sp = fixed_shapes(td)
|
| 901 |
if sp is None: return None
|
|
@@ -903,8 +801,7 @@ def s_mirror_v(td):
|
|
| 903 |
if OW != IW or OH != 2 * IH: return None
|
| 904 |
if OH > 30: return None
|
| 905 |
for inp, out in exs:
|
| 906 |
-
|
| 907 |
-
if not np.array_equal(expected, out): return None
|
| 908 |
idx = np.zeros((OH, OW, 2), dtype=np.int64)
|
| 909 |
for r in range(OH):
|
| 910 |
for c in range(OW):
|
|
@@ -913,7 +810,6 @@ def s_mirror_v(td):
|
|
| 913 |
return _build_gather_model(OH, OW, idx)
|
| 914 |
|
| 915 |
def s_quad_mirror(td):
|
| 916 |
-
"""Output = 2x2 block of input with h/v flips."""
|
| 917 |
exs = get_exs(td)
|
| 918 |
sp = fixed_shapes(td)
|
| 919 |
if sp is None: return None
|
|
@@ -921,10 +817,8 @@ def s_quad_mirror(td):
|
|
| 921 |
if OH != 2 * IH or OW != 2 * IW: return None
|
| 922 |
if OH > 30 or OW > 30: return None
|
| 923 |
for inp, out in exs:
|
| 924 |
-
expected = np.block([
|
| 925 |
-
|
| 926 |
-
[np.flip(inp, 0), np.flip(np.flip(inp, 0), 1)]
|
| 927 |
-
])
|
| 928 |
if not np.array_equal(expected, out): return None
|
| 929 |
idx = np.zeros((OH, OW, 2), dtype=np.int64)
|
| 930 |
for r in range(OH):
|
|
@@ -935,7 +829,6 @@ def s_quad_mirror(td):
|
|
| 935 |
return _build_gather_model(OH, OW, idx)
|
| 936 |
|
| 937 |
def s_fixed_crop(td):
|
| 938 |
-
"""Output = fixed subregion of input."""
|
| 939 |
exs = get_exs(td)
|
| 940 |
sp = fixed_shapes(td)
|
| 941 |
if sp is None: return None
|
|
@@ -952,7 +845,6 @@ def s_fixed_crop(td):
|
|
| 952 |
return None
|
| 953 |
|
| 954 |
def s_nonuniform_scale(td):
|
| 955 |
-
"""Output = input scaled by different factors in h and w."""
|
| 956 |
exs = get_exs(td)
|
| 957 |
sp = fixed_shapes(td)
|
| 958 |
if sp is None: return None
|
|
@@ -969,6 +861,7 @@ def s_nonuniform_scale(td):
|
|
| 969 |
return None
|
| 970 |
|
| 971 |
def s_constant(td):
|
|
|
|
| 972 |
sp = fixed_shapes(td)
|
| 973 |
if sp is None: return None
|
| 974 |
exs = get_exs(td)
|
|
@@ -978,11 +871,16 @@ def s_constant(td):
|
|
| 978 |
for r, row in enumerate(outs[0]):
|
| 979 |
for c, v in enumerate(row):
|
| 980 |
const[0, int(v), r, c] = 1.0
|
| 981 |
-
inits = [
|
| 982 |
-
|
| 983 |
-
|
| 984 |
-
|
| 985 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 986 |
return mk(nodes, inits)
|
| 987 |
|
| 988 |
# ============================================================
|
|
@@ -990,18 +888,15 @@ def s_constant(td):
|
|
| 990 |
# ============================================================
|
| 991 |
|
| 992 |
def add_onehot_block(nodes, inits, am_name, oh_name):
|
| 993 |
-
"""Equal + Cast one-hot encoding (replaces OneHot which lacks CUDA kernel)."""
|
| 994 |
classes = np.arange(10, dtype=np.int64).reshape(1, 10, 1, 1)
|
| 995 |
inits.append(numpy_helper.from_array(classes, 'classes'))
|
| 996 |
nodes.append(helper.make_node('Equal', [am_name, 'classes'], ['eq']))
|
| 997 |
nodes.append(helper.make_node('Cast', ['eq'], [oh_name], to=TensorProto.FLOAT))
|
| 998 |
|
| 999 |
def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False):
|
| 1000 |
-
"""Shared lstsq conv fitting. Returns (Wconv, B) or None."""
|
| 1001 |
pad = ks // 2
|
| 1002 |
feat = 10 * ks * ks + (1 if use_bias else 0)
|
| 1003 |
if feat > 20000: return None
|
| 1004 |
-
|
| 1005 |
patches, targets = [], []
|
| 1006 |
for inp_g, out_g in exs_raw:
|
| 1007 |
ih, iw = inp_g.shape
|
|
@@ -1013,7 +908,6 @@ def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False):
|
|
| 1013 |
oh_enc = np.zeros((10, ih, iw), dtype=np.float64)
|
| 1014 |
for c in range(10): oh_enc[c] = (inp_g == c)
|
| 1015 |
oh_pad = np.pad(oh_enc, ((0,0),(pad,pad),(pad,pad)))
|
| 1016 |
-
|
| 1017 |
oh, ow = out_g.shape
|
| 1018 |
for r in range(oh):
|
| 1019 |
for c in range(ow):
|
|
@@ -1021,18 +915,14 @@ def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False):
|
|
| 1021 |
if use_bias: p = np.append(p, 1.0)
|
| 1022 |
patches.append(p)
|
| 1023 |
targets.append(int(out_g[r, c]))
|
| 1024 |
-
|
| 1025 |
n_patches = len(patches)
|
| 1026 |
if feat > 5000 and n_patches > 2000: return None
|
| 1027 |
-
|
| 1028 |
P = np.array(patches, dtype=np.float64)
|
| 1029 |
T = np.array(targets, dtype=np.int64)
|
| 1030 |
T_oh = np.zeros((len(T), 10), dtype=np.float64)
|
| 1031 |
for i, t in enumerate(T): T_oh[i, t] = 1.0
|
| 1032 |
-
|
| 1033 |
WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
|
| 1034 |
if not np.array_equal(np.argmax(P @ WT, axis=1), T): return None
|
| 1035 |
-
|
| 1036 |
if use_bias:
|
| 1037 |
Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
|
| 1038 |
B = WT[-1].astype(np.float32)
|
|
@@ -1042,17 +932,14 @@ def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False):
|
|
| 1042 |
return Wconv, B
|
| 1043 |
|
| 1044 |
def solve_conv_fixed(td, path, time_budget=30.0):
|
| 1045 |
-
"""Fixed-shape conv: Slice -> Conv -> ArgMax -> Equal+Cast -> Pad."""
|
| 1046 |
exs = get_exs(td)
|
| 1047 |
for inp, out in exs:
|
| 1048 |
if inp.shape != out.shape: return None
|
| 1049 |
shapes = set(inp.shape for inp, _ in exs)
|
| 1050 |
if len(shapes) != 1: return None
|
| 1051 |
IH, IW = shapes.pop()
|
| 1052 |
-
|
| 1053 |
fit_exs = get_exs_for_fitting(td)
|
| 1054 |
fit_exs = [(i,o) for i,o in fit_exs if i.shape == o.shape and i.shape == (IH, IW)]
|
| 1055 |
-
|
| 1056 |
t_start = time.time()
|
| 1057 |
for use_bias in [False, True]:
|
| 1058 |
for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
|
|
@@ -1062,7 +949,6 @@ def solve_conv_fixed(td, path, time_budget=30.0):
|
|
| 1062 |
Wconv, B = result
|
| 1063 |
pad = ks // 2
|
| 1064 |
pad_h, pad_w = GH - IH, GW - IW
|
| 1065 |
-
|
| 1066 |
inits = [
|
| 1067 |
_make_int64_init('sl_st', [0,0,0,0]),
|
| 1068 |
_make_int64_init('sl_en', [1,10,IH,IW]),
|
|
@@ -1072,7 +958,6 @@ def solve_conv_fixed(td, path, time_budget=30.0):
|
|
| 1072 |
if B is not None:
|
| 1073 |
inits.append(numpy_helper.from_array(B, 'B'))
|
| 1074 |
conv_inputs.append('B')
|
| 1075 |
-
|
| 1076 |
nodes = [
|
| 1077 |
helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
|
| 1078 |
helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
|
|
@@ -1080,21 +965,18 @@ def solve_conv_fixed(td, path, time_budget=30.0):
|
|
| 1080 |
]
|
| 1081 |
add_onehot_block(nodes, inits, 'am', 'oh_out')
|
| 1082 |
nodes.append(_build_pad_node('oh_out', 'output', pad_h, pad_w, inits))
|
| 1083 |
-
|
| 1084 |
model = mk(nodes, inits)
|
| 1085 |
onnx.save(model, path)
|
| 1086 |
if validate(path, td): return 'conv_fixed', model
|
| 1087 |
return None
|
| 1088 |
|
| 1089 |
def solve_conv_variable(td, path, time_budget=30.0):
|
| 1090 |
-
"""Variable-shape conv
|
| 1091 |
exs = get_exs(td)
|
| 1092 |
for inp, out in exs:
|
| 1093 |
if inp.shape != out.shape: return None
|
| 1094 |
-
|
| 1095 |
fit_exs = get_exs_for_fitting_variable(td)
|
| 1096 |
fit_exs = [(i,o) for i,o in fit_exs if i.shape == o.shape]
|
| 1097 |
-
|
| 1098 |
t_start = time.time()
|
| 1099 |
for use_bias in [False, True]:
|
| 1100 |
for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
|
|
@@ -1103,38 +985,35 @@ def solve_conv_variable(td, path, time_budget=30.0):
|
|
| 1103 |
if result is None: continue
|
| 1104 |
Wconv, B = result
|
| 1105 |
pad = ks // 2
|
| 1106 |
-
|
| 1107 |
-
|
|
|
|
|
|
|
| 1108 |
conv_inputs = ['input', 'W']
|
| 1109 |
if B is not None:
|
| 1110 |
inits.append(numpy_helper.from_array(B, 'B'))
|
| 1111 |
conv_inputs.append('B')
|
| 1112 |
-
|
| 1113 |
nodes = [
|
| 1114 |
-
helper.make_node('ReduceSum', ['input'
|
| 1115 |
helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
|
| 1116 |
helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
|
| 1117 |
]
|
| 1118 |
add_onehot_block(nodes, inits, 'am', 'oh_out')
|
| 1119 |
nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
|
| 1120 |
-
|
| 1121 |
model = mk(nodes, inits)
|
| 1122 |
onnx.save(model, path)
|
| 1123 |
if validate(path, td): return 'conv_var', model
|
| 1124 |
return None
|
| 1125 |
|
| 1126 |
def solve_conv_diffshape(td, path, time_budget=30.0):
|
| 1127 |
-
"""Diff-shape conv for fixed io shapes where output is smaller."""
|
| 1128 |
sp = fixed_shapes(td)
|
| 1129 |
if sp is None: return None
|
| 1130 |
(IH, IW), (OH, OW) = sp
|
| 1131 |
if IH == OH and IW == OW: return None
|
| 1132 |
if OH > IH or OW > IW: return None
|
| 1133 |
if OH > 30 or OW > 30: return None
|
| 1134 |
-
|
| 1135 |
exs = get_exs(td)
|
| 1136 |
t_start = time.time()
|
| 1137 |
-
|
| 1138 |
for dr_off, dc_off in [(0, 0), ((IH-OH)//2, (IW-OW)//2)]:
|
| 1139 |
for use_bias in [False, True]:
|
| 1140 |
for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21]:
|
|
@@ -1142,7 +1021,6 @@ def solve_conv_diffshape(td, path, time_budget=30.0):
|
|
| 1142 |
pad = ks // 2
|
| 1143 |
feat = 10 * ks * ks + (1 if use_bias else 0)
|
| 1144 |
if feat > 10000: continue
|
| 1145 |
-
|
| 1146 |
patches, targets = [], []
|
| 1147 |
valid = True
|
| 1148 |
for inp_g, out_g in exs:
|
|
@@ -1161,25 +1039,20 @@ def solve_conv_diffshape(td, path, time_budget=30.0):
|
|
| 1161 |
if not valid: break
|
| 1162 |
if not valid: break
|
| 1163 |
if not valid: continue
|
| 1164 |
-
|
| 1165 |
n_patches = len(patches)
|
| 1166 |
if feat > 5000 and n_patches > 2000: continue
|
| 1167 |
-
|
| 1168 |
P = np.array(patches, dtype=np.float64)
|
| 1169 |
T = np.array(targets, dtype=np.int64)
|
| 1170 |
T_oh = np.zeros((len(T), 10), dtype=np.float64)
|
| 1171 |
for i, t in enumerate(T): T_oh[i, t] = 1.0
|
| 1172 |
-
|
| 1173 |
WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
|
| 1174 |
if not np.array_equal(np.argmax(P @ WT, axis=1), T): continue
|
| 1175 |
-
|
| 1176 |
if use_bias:
|
| 1177 |
Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
|
| 1178 |
B = WT[-1].astype(np.float32)
|
| 1179 |
else:
|
| 1180 |
Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
|
| 1181 |
B = None
|
| 1182 |
-
|
| 1183 |
pad_h, pad_w = GH - OH, GW - OW
|
| 1184 |
inits = [
|
| 1185 |
_make_int64_init('sl_st', [0,0,0,0]),
|
|
@@ -1192,7 +1065,6 @@ def solve_conv_diffshape(td, path, time_budget=30.0):
|
|
| 1192 |
if B is not None:
|
| 1193 |
inits.append(numpy_helper.from_array(B, 'B'))
|
| 1194 |
conv_inputs.append('B')
|
| 1195 |
-
|
| 1196 |
nodes = [
|
| 1197 |
helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
|
| 1198 |
helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
|
|
@@ -1201,25 +1073,21 @@ def solve_conv_diffshape(td, path, time_budget=30.0):
|
|
| 1201 |
]
|
| 1202 |
add_onehot_block(nodes, inits, 'am', 'oh_out')
|
| 1203 |
nodes.append(_build_pad_node('oh_out', 'output', pad_h, pad_w, inits))
|
| 1204 |
-
|
| 1205 |
model = mk(nodes, inits)
|
| 1206 |
onnx.save(model, path)
|
| 1207 |
if validate(path, td): return 'conv_diff', model
|
| 1208 |
return None
|
| 1209 |
|
| 1210 |
def solve_conv_var_diff(td, path, time_budget=30.0):
|
| 1211 |
-
"""Variable diff-shape conv
|
| 1212 |
exs = get_exs(td)
|
| 1213 |
-
|
| 1214 |
t_start = time.time()
|
| 1215 |
for use_bias in [False, True]:
|
| 1216 |
for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
|
| 1217 |
if time.time() - t_start > time_budget: return None
|
| 1218 |
-
|
| 1219 |
pad = ks // 2
|
| 1220 |
feat = 10 * ks * ks + (1 if use_bias else 0)
|
| 1221 |
if feat > 20000: continue
|
| 1222 |
-
|
| 1223 |
patches, targets = [], []
|
| 1224 |
for inp_g, out_g in exs:
|
| 1225 |
ih, iw = inp_g.shape
|
|
@@ -1227,56 +1095,49 @@ def solve_conv_var_diff(td, path, time_budget=30.0):
|
|
| 1227 |
oh_full = np.zeros((10, GH, GW), dtype=np.float64)
|
| 1228 |
for c in range(10): oh_full[c, :ih, :iw] = (inp_g == c)
|
| 1229 |
oh_pad = np.pad(oh_full, ((0,0),(pad,pad),(pad,pad)))
|
| 1230 |
-
|
| 1231 |
for r in range(oh):
|
| 1232 |
for c in range(ow):
|
| 1233 |
p = oh_pad[:, r:r+ks, c:c+ks].flatten()
|
| 1234 |
if use_bias: p = np.append(p, 1.0)
|
| 1235 |
patches.append(p)
|
| 1236 |
targets.append(int(out_g[r, c]))
|
| 1237 |
-
|
| 1238 |
n_patches = len(patches)
|
| 1239 |
if feat > 5000 and n_patches > 2000: continue
|
| 1240 |
-
|
| 1241 |
P = np.array(patches, dtype=np.float64)
|
| 1242 |
T = np.array(targets, dtype=np.int64)
|
| 1243 |
T_oh = np.zeros((len(T), 10), dtype=np.float64)
|
| 1244 |
for i, t in enumerate(T): T_oh[i, t] = 1.0
|
| 1245 |
-
|
| 1246 |
try:
|
| 1247 |
WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
|
| 1248 |
except:
|
| 1249 |
continue
|
| 1250 |
if not np.array_equal(np.argmax(P @ WT, axis=1), T): continue
|
| 1251 |
-
|
| 1252 |
if use_bias:
|
| 1253 |
Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
|
| 1254 |
B = WT[-1].astype(np.float32)
|
| 1255 |
else:
|
| 1256 |
Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
|
| 1257 |
B = None
|
| 1258 |
-
|
| 1259 |
-
# For tasks where output fits within input bounds, use input mask
|
| 1260 |
all_output_within_input = all(
|
| 1261 |
out_g.shape[0] <= inp_g.shape[0] and out_g.shape[1] <= inp_g.shape[1]
|
| 1262 |
for inp_g, out_g in exs
|
| 1263 |
)
|
| 1264 |
-
|
| 1265 |
if all_output_within_input:
|
| 1266 |
-
inits = [
|
|
|
|
|
|
|
|
|
|
| 1267 |
conv_inputs = ['input', 'W']
|
| 1268 |
if B is not None:
|
| 1269 |
inits.append(numpy_helper.from_array(B, 'B'))
|
| 1270 |
conv_inputs.append('B')
|
| 1271 |
-
|
| 1272 |
nodes = [
|
| 1273 |
-
helper.make_node('ReduceSum', ['input'
|
| 1274 |
helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
|
| 1275 |
helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
|
| 1276 |
]
|
| 1277 |
add_onehot_block(nodes, inits, 'am', 'oh_out')
|
| 1278 |
nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
|
| 1279 |
-
|
| 1280 |
model = mk(nodes, inits)
|
| 1281 |
onnx.save(model, path)
|
| 1282 |
if validate(path, td): return 'conv_var_diff', model
|
|
@@ -1310,10 +1171,7 @@ ANALYTICAL_SOLVERS = [
|
|
| 1310 |
]
|
| 1311 |
|
| 1312 |
def solve_task(tn, td, output_dir, conv_budget=30.0, verbose=True):
|
| 1313 |
-
"""Try all solvers on a task. Returns (solver_name, score) or None."""
|
| 1314 |
path = os.path.join(output_dir, f"task{tn:03d}.onnx")
|
| 1315 |
-
|
| 1316 |
-
# Try analytical solvers first (instant, arc-gen safe)
|
| 1317 |
for name, solver in ANALYTICAL_SOLVERS:
|
| 1318 |
try:
|
| 1319 |
model = solver(td)
|
|
@@ -1331,8 +1189,6 @@ def solve_task(tn, td, output_dir, conv_budget=30.0, verbose=True):
|
|
| 1331 |
return name, score
|
| 1332 |
else:
|
| 1333 |
if verbose: print(f" {name}: model built but FAILED validation")
|
| 1334 |
-
|
| 1335 |
-
# Try conv solvers
|
| 1336 |
conv_solvers = [
|
| 1337 |
('conv_fixed', solve_conv_fixed),
|
| 1338 |
('conv_variable', solve_conv_variable),
|
|
@@ -1354,35 +1210,28 @@ def solve_task(tn, td, output_dir, conv_budget=30.0, verbose=True):
|
|
| 1354 |
score = max(1.0, 25.0 - math.log(cost)) if cost > 0 else 25.0
|
| 1355 |
if verbose: print(f" {solver_type}: PASS cost={cost} score={score:.2f}")
|
| 1356 |
return solver_type, score
|
| 1357 |
-
|
| 1358 |
return None
|
| 1359 |
|
| 1360 |
def main():
|
| 1361 |
parser = argparse.ArgumentParser(description='NeuroGolf Solver v5')
|
| 1362 |
-
parser.add_argument('--data_dir', type=str, default=None
|
| 1363 |
-
parser.add_argument('--kaggle_dir', type=str, default=None
|
| 1364 |
-
parser.add_argument('--arcgen_dir', type=str, default=None
|
| 1365 |
-
parser.add_argument('--output_dir', type=str, default='submission'
|
| 1366 |
-
parser.add_argument('--conv_budget', type=float, default=30.0
|
| 1367 |
-
parser.add_argument('--task', type=int, default=None
|
| 1368 |
parser.add_argument('--verbose', action='store_true', default=True)
|
| 1369 |
parser.add_argument('--quiet', action='store_true', default=False)
|
| 1370 |
args = parser.parse_args()
|
| 1371 |
-
|
| 1372 |
if args.quiet:
|
| 1373 |
args.verbose = False
|
| 1374 |
-
|
| 1375 |
os.makedirs(args.output_dir, exist_ok=True)
|
| 1376 |
-
|
| 1377 |
-
# Load tasks
|
| 1378 |
if args.kaggle_dir:
|
| 1379 |
tasks = load_tasks_kaggle(args.kaggle_dir)
|
| 1380 |
elif args.data_dir:
|
| 1381 |
tasks = load_tasks_dir(args.data_dir, args.arcgen_dir)
|
| 1382 |
else:
|
| 1383 |
-
|
| 1384 |
-
for p in ['/kaggle/input/competitions/neurogolf-2026/',
|
| 1385 |
-
'ARC-AGI/data/training/']:
|
| 1386 |
if os.path.exists(p):
|
| 1387 |
if 'kaggle' in p:
|
| 1388 |
tasks = load_tasks_kaggle(p)
|
|
@@ -1392,15 +1241,11 @@ def main():
|
|
| 1392 |
else:
|
| 1393 |
print("ERROR: No data directory found. Use --data_dir or --kaggle_dir")
|
| 1394 |
sys.exit(1)
|
| 1395 |
-
|
| 1396 |
-
# Solve tasks
|
| 1397 |
results = {}
|
| 1398 |
total_score = 0.0
|
| 1399 |
solved = 0
|
| 1400 |
t_total = time.time()
|
| 1401 |
-
|
| 1402 |
task_nums = [args.task] if args.task else sorted(tasks.keys())
|
| 1403 |
-
|
| 1404 |
for tn in task_nums:
|
| 1405 |
if tn in EXCLUDED_TASKS:
|
| 1406 |
if args.verbose: print(f"Task {tn:3d}: EXCLUDED")
|
|
@@ -1408,44 +1253,32 @@ def main():
|
|
| 1408 |
if tn not in tasks:
|
| 1409 |
if args.verbose: print(f"Task {tn:3d}: NOT FOUND")
|
| 1410 |
continue
|
| 1411 |
-
|
| 1412 |
td = tasks[tn]['data']
|
| 1413 |
hex_id = tasks[tn]['hex']
|
| 1414 |
-
|
| 1415 |
if args.verbose: print(f"\nTask {tn:3d} ({hex_id}):")
|
| 1416 |
-
|
| 1417 |
result = solve_task(tn, td, args.output_dir, args.conv_budget, args.verbose)
|
| 1418 |
-
|
| 1419 |
if result is not None:
|
| 1420 |
solver_type, score = result
|
| 1421 |
results[tn] = {'solver': solver_type, 'score': score, 'hex': hex_id}
|
| 1422 |
total_score += score
|
| 1423 |
solved += 1
|
| 1424 |
else:
|
| 1425 |
-
# Unsolved tasks score 1.0 (minimum)
|
| 1426 |
total_score += 1.0
|
| 1427 |
if args.verbose: print(f" UNSOLVED")
|
| 1428 |
-
|
| 1429 |
-
# Summary
|
| 1430 |
elapsed = time.time() - t_total
|
| 1431 |
print(f"\n{'='*60}")
|
| 1432 |
print(f"RESULTS: {solved}/{len(task_nums)} tasks solved")
|
| 1433 |
print(f"Total score: {total_score:.1f}")
|
| 1434 |
print(f"Time: {elapsed:.1f}s")
|
| 1435 |
print(f"{'='*60}")
|
| 1436 |
-
|
| 1437 |
-
# Breakdown by solver type
|
| 1438 |
solver_counts = Counter(r['solver'] for r in results.values())
|
| 1439 |
solver_scores = {}
|
| 1440 |
for tn, r in results.items():
|
| 1441 |
st = r['solver']
|
| 1442 |
solver_scores[st] = solver_scores.get(st, 0) + r['score']
|
| 1443 |
-
|
| 1444 |
print("\nSolver breakdown:")
|
| 1445 |
for st in sorted(solver_counts.keys()):
|
| 1446 |
print(f" {st}: {solver_counts[st]} tasks, total score {solver_scores[st]:.1f}, avg {solver_scores[st]/solver_counts[st]:.2f}")
|
| 1447 |
-
|
| 1448 |
-
# Generate submission.csv
|
| 1449 |
csv_path = os.path.join(args.output_dir, 'submission.csv')
|
| 1450 |
with open(csv_path, 'w', newline='') as f:
|
| 1451 |
w = csv.writer(f)
|
|
@@ -1453,15 +1286,12 @@ def main():
|
|
| 1453 |
for tn in sorted(results.keys()):
|
| 1454 |
r = results[tn]
|
| 1455 |
w.writerow([tn, r['hex'], r['solver'], f"{r['score']:.3f}", f"task{tn:03d}.onnx"])
|
| 1456 |
-
|
| 1457 |
-
# Generate submission.zip
|
| 1458 |
zip_path = os.path.join(args.output_dir, 'submission.zip')
|
| 1459 |
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
|
| 1460 |
for tn in sorted(results.keys()):
|
| 1461 |
onnx_path = os.path.join(args.output_dir, f"task{tn:03d}.onnx")
|
| 1462 |
if os.path.exists(onnx_path):
|
| 1463 |
zf.write(onnx_path, f"task{tn:03d}.onnx")
|
| 1464 |
-
|
| 1465 |
print(f"\nSubmission files: {csv_path}, {zip_path}")
|
| 1466 |
print(f"Models in zip: {len(results)}")
|
| 1467 |
|
|
|
|
| 9 |
- s_rotate k=2: double Slice(step=-1) — 0 MACs (was ~165K)
|
| 10 |
- s_rotate k=1,3: Slice+Transpose for square grids (0 MACs), Gather fallback for non-square
|
| 11 |
- All Pad nodes: tensor-based pads input (opset 17 requirement)
|
| 12 |
+
- All ReduceSum nodes: axes as tensor input (opset 13+ requirement)
|
| 13 |
- All other solvers unchanged from v4
|
| 14 |
|
| 15 |
Solvers:
|
|
|
|
| 51 |
|
| 52 |
INT64_MIN = int(np.iinfo(np.int64).min)
|
| 53 |
|
|
|
|
| 54 |
EXCLUDED_TASKS = {21, 55, 80, 184, 202, 366}
|
|
|
|
|
|
|
| 55 |
MAX_ARCGEN_VALIDATE = 30
|
|
|
|
| 56 |
MAX_ARCGEN_FIT = 0
|
| 57 |
|
| 58 |
def get_providers():
|
|
|
|
| 65 |
# ============================================================
|
| 66 |
|
| 67 |
def load_tasks_dir(data_dir, arcgen_dir=None):
|
|
|
|
| 68 |
files = sorted(f for f in os.listdir(data_dir) if f.endswith('.json'))
|
| 69 |
tasks = {}
|
| 70 |
for i, f in enumerate(files):
|
|
|
|
| 82 |
return tasks
|
| 83 |
|
| 84 |
def load_tasks_kaggle(data_dir):
|
|
|
|
| 85 |
tasks = {}
|
| 86 |
for tn in range(1, 401):
|
| 87 |
path = os.path.join(data_dir, f"task{tn:03d}.json")
|
|
|
|
| 102 |
return arr
|
| 103 |
|
| 104 |
def validate(path, td):
|
|
|
|
| 105 |
try:
|
| 106 |
opts = ort.SessionOptions()
|
| 107 |
opts.log_severity_level = 3
|
|
|
|
| 124 |
return True
|
| 125 |
|
| 126 |
def validate_raw(raw_bytes, td):
|
|
|
|
| 127 |
try:
|
| 128 |
opts = ort.SessionOptions()
|
| 129 |
opts.log_severity_level = 3
|
|
|
|
| 146 |
return True
|
| 147 |
|
| 148 |
# ============================================================
|
| 149 |
+
# STATIC PROFILER
|
| 150 |
# ============================================================
|
| 151 |
|
| 152 |
BANNED_OPS = {'Loop', 'Scan', 'NonZero', 'Unique', 'If', 'Function'}
|
| 153 |
MAX_FILESIZE = int(1.44 * 1024 * 1024)
|
| 154 |
|
| 155 |
def score_network(path):
|
|
|
|
| 156 |
if HAS_ONNX_TOOL:
|
| 157 |
try:
|
| 158 |
return _score_network_official(path)
|
|
|
|
| 161 |
return _static_profile(path)
|
| 162 |
|
| 163 |
def _static_profile(path):
|
|
|
|
| 164 |
try:
|
| 165 |
model = onnx.load(path)
|
| 166 |
except:
|
| 167 |
return None, None, None
|
|
|
|
| 168 |
tensors = {}
|
| 169 |
params = 0
|
| 170 |
nbytes = 0
|
| 171 |
macs = 0
|
|
|
|
| 172 |
for init in model.graph.initializer:
|
| 173 |
a = numpy_helper.to_array(init)
|
| 174 |
tensors[init.name] = a
|
| 175 |
params += a.size
|
| 176 |
nbytes += a.nbytes
|
|
|
|
| 177 |
for nd in model.graph.node:
|
| 178 |
if nd.op_type == 'Constant':
|
| 179 |
for attr in nd.attribute:
|
|
|
|
| 186 |
nbytes += a.nbytes
|
| 187 |
except:
|
| 188 |
pass
|
|
|
|
| 189 |
if nd.op_type in BANNED_OPS:
|
| 190 |
return None, None, None
|
|
|
|
| 191 |
if nd.op_type == 'Conv' and len(nd.input) >= 2 and nd.input[1] in tensors:
|
| 192 |
w = tensors[nd.input[1]]
|
| 193 |
if w.ndim == 4:
|
| 194 |
co, ci, kh, kw = w.shape
|
| 195 |
macs += co * ci * kh * kw * GH * GW
|
|
|
|
| 196 |
return int(macs), int(nbytes), int(params)
|
| 197 |
|
| 198 |
# ============================================================
|
|
|
|
| 200 |
# ============================================================
|
| 201 |
|
| 202 |
def _make_int64_init(name, values):
    """Create a named int64 TensorProto initializer from *values*."""
    arr = np.asarray(values, dtype=np.int64)
    return numpy_helper.from_array(arr, name)
|
| 204 |
|
| 205 |
def _build_pad_node(input_name, output_name, pad_h, pad_w, inits, suffix=''):
|
| 206 |
+
"""Pad with tensor-based pads input (opset 11+)."""
|
|
|
|
| 207 |
pads_name = f'pads{suffix}'
|
| 208 |
cv_name = f'pad_cv{suffix}'
|
| 209 |
pads_arr = np.array([0, 0, 0, 0, 0, 0, pad_h, pad_w], dtype=np.int64)
|
|
|
|
| 212 |
return helper.make_node('Pad', [input_name, pads_name, cv_name], [output_name], mode='constant')
|
| 213 |
|
| 214 |
def _build_slice_crop(input_name, output_name, IH, IW, inits, suffix=''):
|
| 215 |
+
"""Slice to crop [1,10,30,30] to [1,10,IH,IW]."""
|
| 216 |
st_name = f'crop_st{suffix}'
|
| 217 |
en_name = f'crop_en{suffix}'
|
| 218 |
inits.append(_make_int64_init(st_name, [0, 0, 0, 0]))
|
|
|
|
| 220 |
return helper.make_node('Slice', [input_name, st_name, en_name], [output_name])
|
| 221 |
|
| 222 |
def _build_slice_reverse(input_name, output_name, axis, dim_size, inits, suffix=''):
|
| 223 |
+
"""Slice(step=-1) to reverse one axis. Zero MACs."""
|
| 224 |
st_name = f'rev_st{suffix}'
|
| 225 |
en_name = f'rev_en{suffix}'
|
| 226 |
ax_name = f'rev_ax{suffix}'
|
|
|
|
| 231 |
inits.append(_make_int64_init(sp_name, [-1]))
|
| 232 |
return helper.make_node('Slice', [input_name, st_name, en_name, ax_name, sp_name], [output_name])
|
| 233 |
|
| 234 |
+
def _build_reducesum(input_name, output_name, axes_list, inits, suffix=''):
    """Build a ReduceSum node with axes supplied as a tensor input (opset 13+).

    Opset 13 moved ReduceSum's axes from a node attribute to a second
    input tensor; the int64 axes initializer is appended to *inits*.
    keepdims stays 1 so the reduced dims are retained.
    """
    axes_name = 'rs_axes' + suffix
    inits.append(_make_int64_init(axes_name, axes_list))
    return helper.make_node('ReduceSum', [input_name, axes_name],
                            [output_name], keepdims=1)
|
| 239 |
+
|
| 240 |
def mk(nodes, inits=None):
|
| 241 |
x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
|
| 242 |
y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
|
|
|
|
| 244 |
return helper.make_model(g, ir_version=IR, opset_imports=OPSET)
|
| 245 |
|
| 246 |
def get_exs(td):
    """Return (input, output) int64 array pairs for every train+test example."""
    pairs = []
    for ex in td['train'] + td['test']:
        grid_in = np.array(ex['input'], dtype=np.int64)
        grid_out = np.array(ex['output'], dtype=np.int64)
        pairs.append((grid_in, grid_out))
    return pairs
|
| 249 |
|
| 250 |
def get_exs_for_fitting(td):
    """Return base train+test pairs plus up to 10 shape-matching arc-gen pairs.

    Arc-gen examples are appended only when all base inputs share one shape
    and the arc-gen pair matches both that input shape and the first base
    example's output shape.
    """
    def _pair(ex):
        return (np.array(ex['input'], dtype=np.int64),
                np.array(ex['output'], dtype=np.int64))

    base = [_pair(ex) for ex in td['train'] + td['test']]
    if not base:
        return base
    input_shapes = {inp.shape for inp, _ in base}
    if len(input_shapes) != 1:
        # Mixed input shapes: arc-gen augmentation is not applicable.
        return base
    want_in = next(iter(input_shapes))
    want_out = base[0][1].shape
    extra = []
    for ex in td.get('arc-gen', []):
        gi, go = _pair(ex)
        if gi.shape == want_in and go.shape == want_out:
            extra.append((gi, go))
    return base + extra[:10]
|
| 266 |
|
| 267 |
def get_exs_for_fitting_variable(td):
    """Return base pairs plus up to 20 arc-gen pairs with equal in/out shape.

    Arc-gen pairs are kept only when input and output shapes match and the
    grid fits inside 30x30.
    """
    base = [(np.array(ex['input'], dtype=np.int64),
             np.array(ex['output'], dtype=np.int64))
            for ex in td['train'] + td['test']]
    extra = []
    for ex in td.get('arc-gen', []):
        gi = np.array(ex['input'], dtype=np.int64)
        go = np.array(ex['output'], dtype=np.int64)
        if gi.shape != go.shape:
            continue
        if gi.shape[0] > 30 or gi.shape[1] > 30:
            continue
        extra.append((gi, go))
    return base + extra[:20]
|
| 277 |
|
| 278 |
def fixed_shapes(td):
|
|
|
|
| 282 |
return list(shapes)[0] if len(shapes) == 1 else None
|
| 283 |
|
| 284 |
# ============================================================
|
| 285 |
+
# GATHER HELPERS
|
| 286 |
# ============================================================
|
| 287 |
|
| 288 |
def _build_gather_model(OH, OW, idx):
|
|
|
|
| 289 |
flat_idx = np.zeros((GH*GW,), dtype=np.int64)
|
| 290 |
mask = np.zeros((1,1,GH,GW), dtype=np.float32)
|
| 291 |
for oi in range(OH):
|
|
|
|
| 307 |
return mk(nodes, inits)
|
| 308 |
|
| 309 |
def _build_gather_model_with_const(IH, IW, OH, OW, idx, cst):
|
|
|
|
| 310 |
flat_idx = np.zeros((GH*GW,), dtype=np.int64)
|
| 311 |
gather_mask = np.zeros((1,1,GH,GW), dtype=np.float32)
|
| 312 |
const_oh = np.zeros((1,10,GH,GW), dtype=np.float32)
|
|
|
|
| 355 |
iv, ov = int(iv), int(ov)
|
| 356 |
if iv in cm and cm[iv] != ov: return None
|
| 357 |
cm[iv] = ov
|
|
|
|
| 358 |
is_permutation = (set(cm.keys()) == set(cm.values()))
|
|
|
|
| 359 |
if is_permutation:
|
| 360 |
gather_ch = np.arange(10, dtype=np.int32)
|
| 361 |
for src, dst in cm.items():
|
|
|
|
| 372 |
[numpy_helper.from_array(W, 'W')])
|
| 373 |
|
| 374 |
def s_transpose(td):
    """Solver: every example's output equals the input's matrix transpose."""
    for ex in td['train'] + td['test']:
        expected = np.array(ex['input']).T
        if not np.array_equal(np.array(ex['output']), expected):
            return None
    # Swap the two spatial axes of the [1,10,H,W] grid tensor.
    node = helper.make_node('Transpose', ['input'], ['output'], perm=[0, 1, 3, 2])
    return mk([node])
|
| 378 |
|
| 379 |
def s_flip(td):
    """Solver: output is a vertical or horizontal flip of the input.

    Built from Slice(step=-1) — zero MACs. Applies only to fixed-shape,
    shape-preserving tasks.
    """
    shapes = fixed_shapes(td)
    if shapes is None:
        return None
    (IH, IW), (OH, OW) = shapes
    if (IH, IW) != (OH, OW):
        return None
    exs = get_exs(td)
    for axis, flip_fn in ((0, np.flipud), (1, np.fliplr)):
        if not all(np.array_equal(out, flip_fn(inp)) for inp, out in exs):
            continue
        # NCHW layout: grid axis 0 -> tensor axis 2 (H), axis 1 -> 3 (W).
        onnx_axis = 2 + axis
        dim_size = (IH, IW)[axis]
        inits, nodes = [], []
        nodes.append(_build_slice_crop('input', 'cropped', IH, IW, inits))
        nodes.append(_build_slice_reverse('cropped', 'flipped', onnx_axis, dim_size, inits))
        nodes.append(_build_pad_node('flipped', 'output', GH - IH, GW - IW, inits))
        return mk(nodes, inits)
    return None
|
| 398 |
|
| 399 |
def s_rotate(td):
|
| 400 |
+
"""Rotate using Slice+Transpose — zero MACs for square grids and k=2.
|
| 401 |
+
Gather fallback for non-square k=1,3."""
|
| 402 |
exs = get_exs(td)
|
| 403 |
sp = fixed_shapes(td)
|
| 404 |
if sp is None: return None
|
| 405 |
(IH,IW),(OH,OW) = sp
|
|
|
|
| 406 |
for k in [1, 2, 3]:
|
| 407 |
if not all(np.array_equal(out, np.rot90(inp, k)) for inp, out in exs):
|
| 408 |
continue
|
|
|
|
| 409 |
if k == 2:
|
|
|
|
|
|
|
| 410 |
pad_h, pad_w = GH - OH, GW - OW
|
| 411 |
inits = []
|
| 412 |
nodes = []
|
|
|
|
|
|
|
| 413 |
nodes.append(_build_slice_crop('input', 'cropped', IH, IW, inits))
|
|
|
|
| 414 |
nodes.append(_build_slice_reverse('cropped', 'flip_h', 2, IH, inits, suffix='_h'))
|
|
|
|
| 415 |
nodes.append(_build_slice_reverse('flip_h', 'rotated', 3, IW, inits, suffix='_w'))
|
|
|
|
| 416 |
nodes.append(_build_pad_node('rotated', 'output', pad_h, pad_w, inits))
|
|
|
|
| 417 |
return mk(nodes, inits)
|
|
|
|
| 418 |
elif k == 1 and IH == IW:
|
| 419 |
+
# rot90 CCW square: Transpose then flip axis 2
|
|
|
|
|
|
|
|
|
|
| 420 |
pad_h, pad_w = GH - IH, GW - IW
|
| 421 |
inits = []
|
| 422 |
nodes = []
|
|
|
|
| 423 |
nodes.append(_build_slice_crop('input', 'cropped', IH, IW, inits))
|
| 424 |
nodes.append(helper.make_node('Transpose', ['cropped'], ['transposed'], perm=[0,1,3,2]))
|
| 425 |
nodes.append(_build_slice_reverse('transposed', 'rotated', 2, IH, inits))
|
| 426 |
nodes.append(_build_pad_node('rotated', 'output', pad_h, pad_w, inits))
|
|
|
|
| 427 |
return mk(nodes, inits)
|
|
|
|
| 428 |
elif k == 3 and IH == IW:
|
| 429 |
+
# rot270 CCW square: flip axis 2 then Transpose
|
|
|
|
|
|
|
|
|
|
| 430 |
pad_h, pad_w = GH - IH, GW - IW
|
| 431 |
inits = []
|
| 432 |
nodes = []
|
|
|
|
| 433 |
nodes.append(_build_slice_crop('input', 'cropped', IH, IW, inits))
|
| 434 |
nodes.append(_build_slice_reverse('cropped', 'flipped', 2, IH, inits))
|
| 435 |
nodes.append(helper.make_node('Transpose', ['flipped'], ['rotated'], perm=[0,1,3,2]))
|
| 436 |
nodes.append(_build_pad_node('rotated', 'output', pad_h, pad_w, inits))
|
|
|
|
| 437 |
return mk(nodes, inits)
|
|
|
|
| 438 |
else:
|
| 439 |
+
# Non-square k=1 or k=3: Gather fallback
|
| 440 |
idx = np.zeros((OH,OW,2), dtype=np.int64)
|
| 441 |
for r in range(OH):
|
| 442 |
for c in range(OW):
|
|
|
|
| 467 |
return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
|
| 468 |
|
| 469 |
def s_varshape_spatial_gather(td):
|
|
|
|
| 470 |
sp = fixed_shapes(td)
|
| 471 |
if sp is not None: return None
|
| 472 |
exs = get_exs(td)
|
|
|
|
| 473 |
exs_30 = []
|
| 474 |
for inp, out in exs:
|
| 475 |
ih, iw = inp.shape
|
|
|
|
| 479 |
inp30[:ih, :iw] = inp
|
| 480 |
out30[:oh, :ow] = out
|
| 481 |
exs_30.append((inp30, out30))
|
|
|
|
| 482 |
idx = np.full((30, 30, 2), -1, dtype=np.int64)
|
| 483 |
cst = np.full((30, 30), -1, dtype=np.int64)
|
|
|
|
| 484 |
for oi in range(30):
|
| 485 |
for oj in range(30):
|
| 486 |
vals = set(int(out30[oi, oj]) for _, out30 in exs_30)
|
|
|
|
| 496 |
if found: break
|
| 497 |
if not found and cst[oi, oj] < 0:
|
| 498 |
return None
|
|
|
|
| 499 |
return _build_gather_model_with_const(30, 30, 30, 30, idx, cst)
|
| 500 |
|
| 501 |
def s_tile(td):
|
|
|
|
| 600 |
return None
|
| 601 |
|
| 602 |
def s_concat_enhanced(td):
|
|
|
|
| 603 |
exs = get_exs(td)
|
| 604 |
sp = fixed_shapes(td)
|
| 605 |
if sp is None: return None
|
| 606 |
(IH,IW),(OH,OW) = sp
|
| 607 |
if IH == OH and IW == OW: return None
|
|
|
|
| 608 |
if OH % IH != 0 or OW % IW != 0: return None
|
| 609 |
rH, rW = OH // IH, OW // IW
|
| 610 |
if rH * rW > 16 or rH * rW < 2: return None
|
| 611 |
if OH > 30 or OW > 30: return None
|
|
|
|
| 612 |
transforms = [
|
| 613 |
+
('id', lambda x: x), ('fliplr', lambda x: np.fliplr(x)),
|
| 614 |
+
('flipud', lambda x: np.flipud(x)), ('rot180', lambda x: np.rot90(x, 2)),
|
| 615 |
+
('rot90', lambda x: np.rot90(x, 1)), ('rot270', lambda x: np.rot90(x, 3)),
|
| 616 |
+
('T', lambda x: x.T), ('T_fliplr', lambda x: np.fliplr(x.T)),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 617 |
]
|
|
|
|
| 618 |
block_transforms = {}
|
| 619 |
for bi in range(rH):
|
| 620 |
for bj in range(rW):
|
|
|
|
| 625 |
block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
|
| 626 |
expected = tfn(inp)
|
| 627 |
if expected.shape != (IH, IW) or not np.array_equal(block, expected):
|
| 628 |
+
ok = False; break
|
|
|
|
| 629 |
if ok:
|
| 630 |
+
found = (tidx, tname); break
|
| 631 |
+
if found is None: return None
|
|
|
|
|
|
|
| 632 |
block_transforms[(bi, bj)] = found
|
|
|
|
| 633 |
idx = np.zeros((OH, OW, 2), dtype=np.int64)
|
| 634 |
for bi in range(rH):
|
| 635 |
for bj in range(rW):
|
|
|
|
| 646 |
elif tname == 'T': sr, sc = lc, lr
|
| 647 |
elif tname == 'T_fliplr': sr, sc = IW-1-lc, lr
|
| 648 |
idx[oi, oj] = [sr, sc]
|
|
|
|
| 649 |
for inp, out in exs:
|
| 650 |
reconstructed = np.zeros_like(out)
|
| 651 |
for oi in range(OH):
|
| 652 |
for oj in range(OW):
|
| 653 |
reconstructed[oi,oj] = inp[idx[oi,oj,0], idx[oi,oj,1]]
|
| 654 |
+
if not np.array_equal(reconstructed, out): return None
|
|
|
|
|
|
|
| 655 |
return _build_gather_model(OH, OW, idx)
|
| 656 |
|
| 657 |
def s_input_driven_tile(td):
|
|
|
|
| 658 |
exs = get_exs(td)
|
| 659 |
sp = fixed_shapes(td)
|
| 660 |
if sp is None: return None
|
|
|
|
| 663 |
sH, sW = OH // IH, OW // IW
|
| 664 |
if sH != IH or sW != IW: return None
|
| 665 |
if OH > 30 or OW > 30: return None
|
|
|
|
| 666 |
for inp, out in exs:
|
| 667 |
for bi in range(IH):
|
| 668 |
for bj in range(IW):
|
| 669 |
block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
|
| 670 |
if inp[bi, bj] != 0:
|
| 671 |
+
if not np.array_equal(block, inp): return None
|
|
|
|
| 672 |
else:
|
| 673 |
+
if not np.all(block == 0): return None
|
|
|
|
| 674 |
return None
|
| 675 |
|
| 676 |
def s_kronecker(td):
|
|
|
|
| 677 |
exs = get_exs(td)
|
| 678 |
sp = fixed_shapes(td)
|
| 679 |
if sp is None: return None
|
|
|
|
| 682 |
sH, sW = OH // IH, OW // IW
|
| 683 |
if sH < 2 or sW < 2: return None
|
| 684 |
if OH > 30 or OW > 30: return None
|
|
|
|
| 685 |
for inp, out in exs:
|
| 686 |
+
if not np.array_equal(out, np.kron(inp, np.ones((sH, sW), dtype=np.int64))): return None
|
|
|
|
|
|
|
|
|
|
| 687 |
idx = np.zeros((OH,OW,2), dtype=np.int64)
|
| 688 |
for r in range(OH):
|
| 689 |
for c in range(OW):
|
|
|
|
| 691 |
return _build_gather_model(OH, OW, idx)
|
| 692 |
|
| 693 |
def s_diagonal_tile(td):
|
|
|
|
| 694 |
exs = get_exs(td)
|
| 695 |
sp = fixed_shapes(td)
|
| 696 |
if sp is None: return None
|
|
|
|
| 699 |
rH, rW = OH // IH, OW // IW
|
| 700 |
if rH != rW or rH < 2: return None
|
| 701 |
if OH > 30 or OW > 30: return None
|
|
|
|
| 702 |
for inp, out in exs:
|
| 703 |
for bi in range(rH):
|
| 704 |
for bj in range(rW):
|
| 705 |
block = out[bi*IH:(bi+1)*IH, bj*IW:(bj+1)*IW]
|
| 706 |
if bi == bj:
|
| 707 |
+
if not np.array_equal(block, inp): return None
|
|
|
|
| 708 |
else:
|
| 709 |
+
if not np.all(block == 0): return None
|
|
|
|
|
|
|
| 710 |
idx = np.zeros((OH,OW,2), dtype=np.int64)
|
| 711 |
cst = np.full((OH,OW), -1, dtype=np.int64)
|
| 712 |
for bi in range(rH):
|
|
|
|
| 719 |
else:
|
| 720 |
idx[oi, oj] = [-1, -1]
|
| 721 |
cst[oi, oj] = 0
|
|
|
|
| 722 |
return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
|
| 723 |
|
| 724 |
def s_shift(td):
|
|
|
|
| 725 |
exs = get_exs(td)
|
| 726 |
sp = fixed_shapes(td)
|
| 727 |
if sp is None: return None
|
|
|
|
| 754 |
return None
|
| 755 |
|
| 756 |
def s_gravity(td):
|
|
|
|
| 757 |
exs = get_exs(td)
|
| 758 |
sp = fixed_shapes(td)
|
| 759 |
if sp is None: return None
|
| 760 |
(IH, IW), (OH, OW) = sp
|
| 761 |
if (IH, IW) != (OH, OW): return None
|
|
|
|
| 762 |
def _gravity(grid, direction):
|
| 763 |
r = np.zeros_like(grid); h, w = grid.shape
|
| 764 |
if direction in ('down', 'up'):
|
|
|
|
| 772 |
if direction == 'right': r[rr, w-len(nz):w] = nz
|
| 773 |
else: r[rr, :len(nz)] = nz
|
| 774 |
return r
|
|
|
|
| 775 |
for d in ('down', 'up', 'left', 'right'):
|
| 776 |
if all(np.array_equal(_gravity(inp, d), out) for inp, out in exs):
|
| 777 |
return None
|
| 778 |
return None
|
| 779 |
|
| 780 |
def s_mirror_h(td):
|
|
|
|
| 781 |
exs = get_exs(td)
|
| 782 |
sp = fixed_shapes(td)
|
| 783 |
if sp is None: return None
|
|
|
|
| 785 |
if OH != IH or OW != 2 * IW: return None
|
| 786 |
if OW > 30: return None
|
| 787 |
for inp, out in exs:
|
| 788 |
+
if not np.array_equal(np.concatenate([inp, np.flip(inp, 1)], 1), out): return None
|
|
|
|
| 789 |
idx = np.zeros((OH, OW, 2), dtype=np.int64)
|
| 790 |
for r in range(OH):
|
| 791 |
for c in range(OW):
|
|
|
|
| 794 |
return _build_gather_model(OH, OW, idx)
|
| 795 |
|
| 796 |
def s_mirror_v(td):
|
|
|
|
| 797 |
exs = get_exs(td)
|
| 798 |
sp = fixed_shapes(td)
|
| 799 |
if sp is None: return None
|
|
|
|
| 801 |
if OW != IW or OH != 2 * IH: return None
|
| 802 |
if OH > 30: return None
|
| 803 |
for inp, out in exs:
|
| 804 |
+
if not np.array_equal(np.concatenate([inp, np.flip(inp, 0)], 0), out): return None
|
|
|
|
| 805 |
idx = np.zeros((OH, OW, 2), dtype=np.int64)
|
| 806 |
for r in range(OH):
|
| 807 |
for c in range(OW):
|
|
|
|
| 810 |
return _build_gather_model(OH, OW, idx)
|
| 811 |
|
| 812 |
def s_quad_mirror(td):
|
|
|
|
| 813 |
exs = get_exs(td)
|
| 814 |
sp = fixed_shapes(td)
|
| 815 |
if sp is None: return None
|
|
|
|
| 817 |
if OH != 2 * IH or OW != 2 * IW: return None
|
| 818 |
if OH > 30 or OW > 30: return None
|
| 819 |
for inp, out in exs:
|
| 820 |
+
expected = np.block([[inp, np.flip(inp, 1)],
|
| 821 |
+
[np.flip(inp, 0), np.flip(np.flip(inp, 0), 1)]])
|
|
|
|
|
|
|
| 822 |
if not np.array_equal(expected, out): return None
|
| 823 |
idx = np.zeros((OH, OW, 2), dtype=np.int64)
|
| 824 |
for r in range(OH):
|
|
|
|
| 829 |
return _build_gather_model(OH, OW, idx)
|
| 830 |
|
| 831 |
def s_fixed_crop(td):
|
|
|
|
| 832 |
exs = get_exs(td)
|
| 833 |
sp = fixed_shapes(td)
|
| 834 |
if sp is None: return None
|
|
|
|
| 845 |
return None
|
| 846 |
|
| 847 |
def s_nonuniform_scale(td):
|
|
|
|
| 848 |
exs = get_exs(td)
|
| 849 |
sp = fixed_shapes(td)
|
| 850 |
if sp is None: return None
|
|
|
|
| 861 |
return None
|
| 862 |
|
| 863 |
def s_constant(td):
|
| 864 |
+
"""Constant output. Uses opset 17 ReduceSum with tensor axes input."""
|
| 865 |
sp = fixed_shapes(td)
|
| 866 |
if sp is None: return None
|
| 867 |
exs = get_exs(td)
|
|
|
|
| 871 |
for r, row in enumerate(outs[0]):
|
| 872 |
for c, v in enumerate(row):
|
| 873 |
const[0, int(v), r, c] = 1.0
|
| 874 |
+
inits = [
|
| 875 |
+
numpy_helper.from_array(np.array(0.0, dtype=np.float32), 'z'),
|
| 876 |
+
numpy_helper.from_array(const, 'c'),
|
| 877 |
+
_make_int64_init('rs_axes_cst', [1, 2, 3]),
|
| 878 |
+
]
|
| 879 |
+
nodes = [
|
| 880 |
+
helper.make_node('Mul', ['input','z'], ['zd']),
|
| 881 |
+
helper.make_node('ReduceSum', ['zd', 'rs_axes_cst'], ['s'], keepdims=1),
|
| 882 |
+
helper.make_node('Add', ['s','c'], ['output']),
|
| 883 |
+
]
|
| 884 |
return mk(nodes, inits)
|
| 885 |
|
| 886 |
# ============================================================
|
|
|
|
| 888 |
# ============================================================
|
| 889 |
|
| 890 |
def add_onehot_block(nodes, inits, am_name, oh_name):
    """Append nodes converting ArgMax indices (am_name) to a float one-hot (oh_name).

    Compares against a [1,10,1,1] class-index tensor via Equal, then casts
    the boolean result to float.
    """
    class_ids = np.arange(10, dtype=np.int64).reshape(1, 10, 1, 1)
    inits.append(numpy_helper.from_array(class_ids, 'classes'))
    nodes.append(helper.make_node('Equal', [am_name, 'classes'], ['eq']))
    nodes.append(helper.make_node('Cast', ['eq'], [oh_name], to=TensorProto.FLOAT))
|
| 895 |
|
| 896 |
def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False):
|
|
|
|
| 897 |
pad = ks // 2
|
| 898 |
feat = 10 * ks * ks + (1 if use_bias else 0)
|
| 899 |
if feat > 20000: return None
|
|
|
|
| 900 |
patches, targets = [], []
|
| 901 |
for inp_g, out_g in exs_raw:
|
| 902 |
ih, iw = inp_g.shape
|
|
|
|
| 908 |
oh_enc = np.zeros((10, ih, iw), dtype=np.float64)
|
| 909 |
for c in range(10): oh_enc[c] = (inp_g == c)
|
| 910 |
oh_pad = np.pad(oh_enc, ((0,0),(pad,pad),(pad,pad)))
|
|
|
|
| 911 |
oh, ow = out_g.shape
|
| 912 |
for r in range(oh):
|
| 913 |
for c in range(ow):
|
|
|
|
| 915 |
if use_bias: p = np.append(p, 1.0)
|
| 916 |
patches.append(p)
|
| 917 |
targets.append(int(out_g[r, c]))
|
|
|
|
| 918 |
n_patches = len(patches)
|
| 919 |
if feat > 5000 and n_patches > 2000: return None
|
|
|
|
| 920 |
P = np.array(patches, dtype=np.float64)
|
| 921 |
T = np.array(targets, dtype=np.int64)
|
| 922 |
T_oh = np.zeros((len(T), 10), dtype=np.float64)
|
| 923 |
for i, t in enumerate(T): T_oh[i, t] = 1.0
|
|
|
|
| 924 |
WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
|
| 925 |
if not np.array_equal(np.argmax(P @ WT, axis=1), T): return None
|
|
|
|
| 926 |
if use_bias:
|
| 927 |
Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
|
| 928 |
B = WT[-1].astype(np.float32)
|
|
|
|
| 932 |
return Wconv, B
|
| 933 |
|
| 934 |
def solve_conv_fixed(td, path, time_budget=30.0):
|
|
|
|
| 935 |
exs = get_exs(td)
|
| 936 |
for inp, out in exs:
|
| 937 |
if inp.shape != out.shape: return None
|
| 938 |
shapes = set(inp.shape for inp, _ in exs)
|
| 939 |
if len(shapes) != 1: return None
|
| 940 |
IH, IW = shapes.pop()
|
|
|
|
| 941 |
fit_exs = get_exs_for_fitting(td)
|
| 942 |
fit_exs = [(i,o) for i,o in fit_exs if i.shape == o.shape and i.shape == (IH, IW)]
|
|
|
|
| 943 |
t_start = time.time()
|
| 944 |
for use_bias in [False, True]:
|
| 945 |
for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
|
|
|
|
| 949 |
Wconv, B = result
|
| 950 |
pad = ks // 2
|
| 951 |
pad_h, pad_w = GH - IH, GW - IW
|
|
|
|
| 952 |
inits = [
|
| 953 |
_make_int64_init('sl_st', [0,0,0,0]),
|
| 954 |
_make_int64_init('sl_en', [1,10,IH,IW]),
|
|
|
|
| 958 |
if B is not None:
|
| 959 |
inits.append(numpy_helper.from_array(B, 'B'))
|
| 960 |
conv_inputs.append('B')
|
|
|
|
| 961 |
nodes = [
|
| 962 |
helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
|
| 963 |
helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
|
|
|
|
| 965 |
]
|
| 966 |
add_onehot_block(nodes, inits, 'am', 'oh_out')
|
| 967 |
nodes.append(_build_pad_node('oh_out', 'output', pad_h, pad_w, inits))
|
|
|
|
| 968 |
model = mk(nodes, inits)
|
| 969 |
onnx.save(model, path)
|
| 970 |
if validate(path, td): return 'conv_fixed', model
|
| 971 |
return None
|
| 972 |
|
| 973 |
def solve_conv_variable(td, path, time_budget=30.0):
|
| 974 |
+
"""Variable-shape conv with opset 17 ReduceSum (axes as tensor input)."""
|
| 975 |
exs = get_exs(td)
|
| 976 |
for inp, out in exs:
|
| 977 |
if inp.shape != out.shape: return None
|
|
|
|
| 978 |
fit_exs = get_exs_for_fitting_variable(td)
|
| 979 |
fit_exs = [(i,o) for i,o in fit_exs if i.shape == o.shape]
|
|
|
|
| 980 |
t_start = time.time()
|
| 981 |
for use_bias in [False, True]:
|
| 982 |
for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
|
|
|
|
| 985 |
if result is None: continue
|
| 986 |
Wconv, B = result
|
| 987 |
pad = ks // 2
|
| 988 |
+
inits = [
|
| 989 |
+
numpy_helper.from_array(Wconv, 'W'),
|
| 990 |
+
_make_int64_init('rs_axes_var', [1]),
|
| 991 |
+
]
|
| 992 |
conv_inputs = ['input', 'W']
|
| 993 |
if B is not None:
|
| 994 |
inits.append(numpy_helper.from_array(B, 'B'))
|
| 995 |
conv_inputs.append('B')
|
|
|
|
| 996 |
nodes = [
|
| 997 |
+
helper.make_node('ReduceSum', ['input', 'rs_axes_var'], ['mask'], keepdims=1),
|
| 998 |
helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
|
| 999 |
helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
|
| 1000 |
]
|
| 1001 |
add_onehot_block(nodes, inits, 'am', 'oh_out')
|
| 1002 |
nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
|
|
|
|
| 1003 |
model = mk(nodes, inits)
|
| 1004 |
onnx.save(model, path)
|
| 1005 |
if validate(path, td): return 'conv_var', model
|
| 1006 |
return None
|
| 1007 |
|
| 1008 |
def solve_conv_diffshape(td, path, time_budget=30.0):
|
|
|
|
| 1009 |
sp = fixed_shapes(td)
|
| 1010 |
if sp is None: return None
|
| 1011 |
(IH, IW), (OH, OW) = sp
|
| 1012 |
if IH == OH and IW == OW: return None
|
| 1013 |
if OH > IH or OW > IW: return None
|
| 1014 |
if OH > 30 or OW > 30: return None
|
|
|
|
| 1015 |
exs = get_exs(td)
|
| 1016 |
t_start = time.time()
|
|
|
|
| 1017 |
for dr_off, dc_off in [(0, 0), ((IH-OH)//2, (IW-OW)//2)]:
|
| 1018 |
for use_bias in [False, True]:
|
| 1019 |
for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21]:
|
|
|
|
| 1021 |
pad = ks // 2
|
| 1022 |
feat = 10 * ks * ks + (1 if use_bias else 0)
|
| 1023 |
if feat > 10000: continue
|
|
|
|
| 1024 |
patches, targets = [], []
|
| 1025 |
valid = True
|
| 1026 |
for inp_g, out_g in exs:
|
|
|
|
| 1039 |
if not valid: break
|
| 1040 |
if not valid: break
|
| 1041 |
if not valid: continue
|
|
|
|
| 1042 |
n_patches = len(patches)
|
| 1043 |
if feat > 5000 and n_patches > 2000: continue
|
|
|
|
| 1044 |
P = np.array(patches, dtype=np.float64)
|
| 1045 |
T = np.array(targets, dtype=np.int64)
|
| 1046 |
T_oh = np.zeros((len(T), 10), dtype=np.float64)
|
| 1047 |
for i, t in enumerate(T): T_oh[i, t] = 1.0
|
|
|
|
| 1048 |
WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
|
| 1049 |
if not np.array_equal(np.argmax(P @ WT, axis=1), T): continue
|
|
|
|
| 1050 |
if use_bias:
|
| 1051 |
Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
|
| 1052 |
B = WT[-1].astype(np.float32)
|
| 1053 |
else:
|
| 1054 |
Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
|
| 1055 |
B = None
|
|
|
|
| 1056 |
pad_h, pad_w = GH - OH, GW - OW
|
| 1057 |
inits = [
|
| 1058 |
_make_int64_init('sl_st', [0,0,0,0]),
|
|
|
|
| 1065 |
if B is not None:
|
| 1066 |
inits.append(numpy_helper.from_array(B, 'B'))
|
| 1067 |
conv_inputs.append('B')
|
|
|
|
| 1068 |
nodes = [
|
| 1069 |
helper.make_node('Slice', ['input','sl_st','sl_en'], ['grid']),
|
| 1070 |
helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
|
|
|
|
| 1073 |
]
|
| 1074 |
add_onehot_block(nodes, inits, 'am', 'oh_out')
|
| 1075 |
nodes.append(_build_pad_node('oh_out', 'output', pad_h, pad_w, inits))
|
|
|
|
| 1076 |
model = mk(nodes, inits)
|
| 1077 |
onnx.save(model, path)
|
| 1078 |
if validate(path, td): return 'conv_diff', model
|
| 1079 |
return None
|
| 1080 |
|
| 1081 |
def solve_conv_var_diff(td, path, time_budget=30.0):
|
| 1082 |
+
"""Variable diff-shape conv with opset 17 ReduceSum."""
|
| 1083 |
exs = get_exs(td)
|
|
|
|
| 1084 |
t_start = time.time()
|
| 1085 |
for use_bias in [False, True]:
|
| 1086 |
for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
|
| 1087 |
if time.time() - t_start > time_budget: return None
|
|
|
|
| 1088 |
pad = ks // 2
|
| 1089 |
feat = 10 * ks * ks + (1 if use_bias else 0)
|
| 1090 |
if feat > 20000: continue
|
|
|
|
| 1091 |
patches, targets = [], []
|
| 1092 |
for inp_g, out_g in exs:
|
| 1093 |
ih, iw = inp_g.shape
|
|
|
|
| 1095 |
oh_full = np.zeros((10, GH, GW), dtype=np.float64)
|
| 1096 |
for c in range(10): oh_full[c, :ih, :iw] = (inp_g == c)
|
| 1097 |
oh_pad = np.pad(oh_full, ((0,0),(pad,pad),(pad,pad)))
|
|
|
|
| 1098 |
for r in range(oh):
|
| 1099 |
for c in range(ow):
|
| 1100 |
p = oh_pad[:, r:r+ks, c:c+ks].flatten()
|
| 1101 |
if use_bias: p = np.append(p, 1.0)
|
| 1102 |
patches.append(p)
|
| 1103 |
targets.append(int(out_g[r, c]))
|
|
|
|
| 1104 |
n_patches = len(patches)
|
| 1105 |
if feat > 5000 and n_patches > 2000: continue
|
|
|
|
| 1106 |
P = np.array(patches, dtype=np.float64)
|
| 1107 |
T = np.array(targets, dtype=np.int64)
|
| 1108 |
T_oh = np.zeros((len(T), 10), dtype=np.float64)
|
| 1109 |
for i, t in enumerate(T): T_oh[i, t] = 1.0
|
|
|
|
| 1110 |
try:
|
| 1111 |
WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
|
| 1112 |
except:
|
| 1113 |
continue
|
| 1114 |
if not np.array_equal(np.argmax(P @ WT, axis=1), T): continue
|
|
|
|
| 1115 |
if use_bias:
|
| 1116 |
Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
|
| 1117 |
B = WT[-1].astype(np.float32)
|
| 1118 |
else:
|
| 1119 |
Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
|
| 1120 |
B = None
|
|
|
|
|
|
|
| 1121 |
all_output_within_input = all(
|
| 1122 |
out_g.shape[0] <= inp_g.shape[0] and out_g.shape[1] <= inp_g.shape[1]
|
| 1123 |
for inp_g, out_g in exs
|
| 1124 |
)
|
|
|
|
| 1125 |
if all_output_within_input:
|
| 1126 |
+
inits = [
|
| 1127 |
+
numpy_helper.from_array(Wconv, 'W'),
|
| 1128 |
+
_make_int64_init('rs_axes_vd', [1]),
|
| 1129 |
+
]
|
| 1130 |
conv_inputs = ['input', 'W']
|
| 1131 |
if B is not None:
|
| 1132 |
inits.append(numpy_helper.from_array(B, 'B'))
|
| 1133 |
conv_inputs.append('B')
|
|
|
|
| 1134 |
nodes = [
|
| 1135 |
+
helper.make_node('ReduceSum', ['input', 'rs_axes_vd'], ['mask'], keepdims=1),
|
| 1136 |
helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks,ks], pads=[pad]*4),
|
| 1137 |
helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
|
| 1138 |
]
|
| 1139 |
add_onehot_block(nodes, inits, 'am', 'oh_out')
|
| 1140 |
nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
|
|
|
|
| 1141 |
model = mk(nodes, inits)
|
| 1142 |
onnx.save(model, path)
|
| 1143 |
if validate(path, td): return 'conv_var_diff', model
|
|
|
|
| 1171 |
]
|
| 1172 |
|
| 1173 |
def solve_task(tn, td, output_dir, conv_budget=30.0, verbose=True):
|
|
|
|
| 1174 |
path = os.path.join(output_dir, f"task{tn:03d}.onnx")
|
|
|
|
|
|
|
| 1175 |
for name, solver in ANALYTICAL_SOLVERS:
|
| 1176 |
try:
|
| 1177 |
model = solver(td)
|
|
|
|
| 1189 |
return name, score
|
| 1190 |
else:
|
| 1191 |
if verbose: print(f" {name}: model built but FAILED validation")
|
|
|
|
|
|
|
| 1192 |
conv_solvers = [
|
| 1193 |
('conv_fixed', solve_conv_fixed),
|
| 1194 |
('conv_variable', solve_conv_variable),
|
|
|
|
| 1210 |
score = max(1.0, 25.0 - math.log(cost)) if cost > 0 else 25.0
|
| 1211 |
if verbose: print(f" {solver_type}: PASS cost={cost} score={score:.2f}")
|
| 1212 |
return solver_type, score
|
|
|
|
| 1213 |
return None
|
| 1214 |
|
| 1215 |
def main():
|
| 1216 |
parser = argparse.ArgumentParser(description='NeuroGolf Solver v5')
|
| 1217 |
+
parser.add_argument('--data_dir', type=str, default=None)
|
| 1218 |
+
parser.add_argument('--kaggle_dir', type=str, default=None)
|
| 1219 |
+
parser.add_argument('--arcgen_dir', type=str, default=None)
|
| 1220 |
+
parser.add_argument('--output_dir', type=str, default='submission')
|
| 1221 |
+
parser.add_argument('--conv_budget', type=float, default=30.0)
|
| 1222 |
+
parser.add_argument('--task', type=int, default=None)
|
| 1223 |
parser.add_argument('--verbose', action='store_true', default=True)
|
| 1224 |
parser.add_argument('--quiet', action='store_true', default=False)
|
| 1225 |
args = parser.parse_args()
|
|
|
|
| 1226 |
if args.quiet:
|
| 1227 |
args.verbose = False
|
|
|
|
| 1228 |
os.makedirs(args.output_dir, exist_ok=True)
|
|
|
|
|
|
|
| 1229 |
if args.kaggle_dir:
|
| 1230 |
tasks = load_tasks_kaggle(args.kaggle_dir)
|
| 1231 |
elif args.data_dir:
|
| 1232 |
tasks = load_tasks_dir(args.data_dir, args.arcgen_dir)
|
| 1233 |
else:
|
| 1234 |
+
for p in ['/kaggle/input/competitions/neurogolf-2026/', 'ARC-AGI/data/training/']:
|
|
|
|
|
|
|
| 1235 |
if os.path.exists(p):
|
| 1236 |
if 'kaggle' in p:
|
| 1237 |
tasks = load_tasks_kaggle(p)
|
|
|
|
| 1241 |
else:
|
| 1242 |
print("ERROR: No data directory found. Use --data_dir or --kaggle_dir")
|
| 1243 |
sys.exit(1)
|
|
|
|
|
|
|
| 1244 |
results = {}
|
| 1245 |
total_score = 0.0
|
| 1246 |
solved = 0
|
| 1247 |
t_total = time.time()
|
|
|
|
| 1248 |
task_nums = [args.task] if args.task else sorted(tasks.keys())
|
|
|
|
| 1249 |
for tn in task_nums:
|
| 1250 |
if tn in EXCLUDED_TASKS:
|
| 1251 |
if args.verbose: print(f"Task {tn:3d}: EXCLUDED")
|
|
|
|
| 1253 |
if tn not in tasks:
|
| 1254 |
if args.verbose: print(f"Task {tn:3d}: NOT FOUND")
|
| 1255 |
continue
|
|
|
|
| 1256 |
td = tasks[tn]['data']
|
| 1257 |
hex_id = tasks[tn]['hex']
|
|
|
|
| 1258 |
if args.verbose: print(f"\nTask {tn:3d} ({hex_id}):")
|
|
|
|
| 1259 |
result = solve_task(tn, td, args.output_dir, args.conv_budget, args.verbose)
|
|
|
|
| 1260 |
if result is not None:
|
| 1261 |
solver_type, score = result
|
| 1262 |
results[tn] = {'solver': solver_type, 'score': score, 'hex': hex_id}
|
| 1263 |
total_score += score
|
| 1264 |
solved += 1
|
| 1265 |
else:
|
|
|
|
| 1266 |
total_score += 1.0
|
| 1267 |
if args.verbose: print(f" UNSOLVED")
|
|
|
|
|
|
|
| 1268 |
elapsed = time.time() - t_total
|
| 1269 |
print(f"\n{'='*60}")
|
| 1270 |
print(f"RESULTS: {solved}/{len(task_nums)} tasks solved")
|
| 1271 |
print(f"Total score: {total_score:.1f}")
|
| 1272 |
print(f"Time: {elapsed:.1f}s")
|
| 1273 |
print(f"{'='*60}")
|
|
|
|
|
|
|
| 1274 |
solver_counts = Counter(r['solver'] for r in results.values())
|
| 1275 |
solver_scores = {}
|
| 1276 |
for tn, r in results.items():
|
| 1277 |
st = r['solver']
|
| 1278 |
solver_scores[st] = solver_scores.get(st, 0) + r['score']
|
|
|
|
| 1279 |
print("\nSolver breakdown:")
|
| 1280 |
for st in sorted(solver_counts.keys()):
|
| 1281 |
print(f" {st}: {solver_counts[st]} tasks, total score {solver_scores[st]:.1f}, avg {solver_scores[st]/solver_counts[st]:.2f}")
|
|
|
|
|
|
|
| 1282 |
csv_path = os.path.join(args.output_dir, 'submission.csv')
|
| 1283 |
with open(csv_path, 'w', newline='') as f:
|
| 1284 |
w = csv.writer(f)
|
|
|
|
| 1286 |
for tn in sorted(results.keys()):
|
| 1287 |
r = results[tn]
|
| 1288 |
w.writerow([tn, r['hex'], r['solver'], f"{r['score']:.3f}", f"task{tn:03d}.onnx"])
|
|
|
|
|
|
|
| 1289 |
zip_path = os.path.join(args.output_dir, 'submission.zip')
|
| 1290 |
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
|
| 1291 |
for tn in sorted(results.keys()):
|
| 1292 |
onnx_path = os.path.join(args.output_dir, f"task{tn:03d}.onnx")
|
| 1293 |
if os.path.exists(onnx_path):
|
| 1294 |
zf.write(onnx_path, f"task{tn:03d}.onnx")
|
|
|
|
| 1295 |
print(f"\nSubmission files: {csv_path}, {zip_path}")
|
| 1296 |
print(f"Models in zip: {len(results)}")
|
| 1297 |
|