| |
| """Gravity solver — unrolled bubble-sort via Conv + Where. |
| |
| v5.2: Solves Task 78 (direction=up, bg=0, score 8.399). |
| Tries all 4 directions × 10 bg colors. Fixed-shape only. |
| """ |
|
|
| import numpy as np |
| from onnx import helper, numpy_helper |
| from ..onnx_helpers import mk, _make_int64_init, _build_pad_node, add_onehot_block |
| from ..data_loader import get_exs, fixed_shapes |
| from ..constants import GH, GW |
|
|
|
|
def _gravity_np(grid, direction, bg_color=0):
    """Reference numpy gravity, used to verify candidate parameters.

    Every cell whose value differs from ``bg_color`` is packed, keeping its
    relative order, against the edge named by ``direction``. Packing happens
    within each column for 'up'/'down' and within each row for 'left'/'right'.
    All remaining cells are ``bg_color``. An unrecognised ``direction`` yields
    a grid filled entirely with ``bg_color``. ``grid`` is not modified; the
    result has the same shape and dtype.
    """
    result = np.full_like(grid, bg_color)
    height, width = grid.shape

    if direction in ('down', 'up'):
        # Transposed views: each "lane" is a column, and writes into
        # ``lanes_out`` go straight through to ``result``.
        lanes_in, lanes_out, span = grid.T, result.T, height
    elif direction in ('right', 'left'):
        lanes_in, lanes_out, span = grid, result, width
    else:
        return result

    to_far_end = direction in ('down', 'right')
    for lane, target in zip(lanes_in, lanes_out):
        kept = lane[lane != bg_color]
        if to_far_end:
            target[span - len(kept):span] = kept
        else:
            target[:len(kept)] = kept
    return result
|
|
|
|
def _build_gravity_model(IH, IW, direction, bg_color=0):
    """Build an ONNX model for gravity via an unrolled, parallel bubble sort.

    Each unrolled step looks at two shifted copies of the current grid. Both
    are produced by single-tap 3x3 Convs with zero padding:

    - ``src``: the upstream neighbour a pixel falls *from*;
    - ``dst``: the downstream neighbour a pixel falls *into*.

    A cell is filled from ``src`` when it is bg and ``src`` holds a real
    (non-bg) colour. A cell is vacated to bg when it is non-bg and ``dst`` is
    bg. Both rules test the same adjacent pair, so every move updates both of
    its cells in the same step.

    Args:
        IH, IW: fixed grid height/width. The input is sliced to ``IH x IW``
            and the result is padded back out via ``_build_pad_node`` using
            ``GH - IH`` / ``GW - IW``.
        direction: 'down', 'up', 'left' or 'right'.
        bg_color: channel index treated as empty background.

    Returns:
        The model assembled by ``mk(nodes, inits)``.

    Raises:
        ValueError: if ``direction`` is not one of the four supported values.
    """
    # (src_tap, dst_tap) positions inside the 3x3 kernel. Conv is a
    # cross-correlation with pads=1, so tap (ky, kx) reads cell
    # (y + ky - 1, x + kx - 1).
    taps = {
        'down': ((0, 1), (2, 1)),   # src = above, dst = below
        'up': ((2, 1), (0, 1)),     # src = below, dst = above
        'right': ((1, 0), (1, 2)),  # src = left,  dst = right
        'left': ((1, 2), (1, 0)),   # src = right, dst = left
    }
    if direction not in taps:
        raise ValueError(f'unsupported gravity direction: {direction!r}')
    (sy, sx), (dy, dx) = taps[direction]

    pad_h, pad_w = GH - IH, GW - IW
    # Pixels only travel along one axis. With parallel updates, a line of
    # length L settles in at most L - 1 steps, and extra steps are no-ops
    # once settled. The length of the gravity axis is therefore enough.
    n_steps = IH if direction in ('down', 'up') else IW

    diag = np.arange(10)
    pull_src = np.zeros((10, 10, 3, 3), dtype=np.float32)
    pull_src[diag, diag, sy, sx] = 1.0  # per-channel shift, no colour mixing
    pull_dst = np.zeros((10, 10, 3, 3), dtype=np.float32)
    pull_dst[diag, diag, dy, dx] = 1.0

    bg_oh = np.zeros((1, 10, 1, 1), dtype=np.float32)
    bg_oh[0, bg_color, 0, 0] = 1.0
    # Selects every channel except bg; used to detect a real source pixel.
    # Conv's zero padding yields an all-zero vector past the border. That
    # vector is neither bg nor a colour and must never be copied into the
    # grid, where ArgMax would decode it as colour 0.
    nonbg_sel = (1.0 - bg_oh).astype(np.float32)

    inits = [
        _make_int64_init('sl_st', [0, 0, 0, 0]),
        _make_int64_init('sl_en', [1, 10, IH, IW]),
        _make_int64_init('ax_ch', [1]),  # channel axis, shared by every ReduceSum
        numpy_helper.from_array(pull_src, 'pull_src'),
        numpy_helper.from_array(pull_dst, 'pull_dst'),
        numpy_helper.from_array(bg_oh, 'bg_oh'),
        numpy_helper.from_array(nonbg_sel, 'nonbg_sel'),
        numpy_helper.from_array(np.float32(0.5), 'half'),
    ]

    nodes = [helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['cur_0'])]

    def shift(x, kernel, out):
        """Append a zero-padded 3x3 Conv that shifts ``x`` by one cell."""
        nodes.append(helper.make_node('Conv', [x, kernel], [out],
                                      kernel_shape=[3, 3], pads=[1, 1, 1, 1]))

    def channel_mask(x, sel, out):
        """Append nodes computing bool ``sum_c(x * sel) > 0.5`` as ``out``."""
        nodes.append(helper.make_node('Mul', [x, sel], [f'{out}_m']))
        nodes.append(helper.make_node('ReduceSum', [f'{out}_m', 'ax_ch'],
                                      [f'{out}_s'], keepdims=1))
        nodes.append(helper.make_node('Greater', [f'{out}_s', 'half'], [out]))

    cur = 'cur_0'
    for i in range(n_steps):
        src, dst, nxt = f'src_{i}', f'dst_{i}', f'cur_{i + 1}'
        shift(cur, 'pull_src', src)
        shift(cur, 'pull_dst', dst)
        channel_mask(cur, 'bg_oh', f'cur_is_bg_{i}')
        # Non-bg channel sum, not Not(src_is_bg): padding must not fill.
        channel_mask(src, 'nonbg_sel', f'src_not_bg_{i}')
        # Padding is all-zero, so it is never "bg": nothing vacates off the edge.
        channel_mask(dst, 'bg_oh', f'dst_is_bg_{i}')
        nodes += [
            helper.make_node('Not', [f'cur_is_bg_{i}'], [f'cur_not_bg_{i}']),
            helper.make_node('And', [f'cur_is_bg_{i}', f'src_not_bg_{i}'],
                             [f'fill_{i}']),
            helper.make_node('And', [f'cur_not_bg_{i}', f'dst_is_bg_{i}'],
                             [f'vacate_{i}']),
            helper.make_node('Where', [f'fill_{i}', src, cur], [f'tmp_{i}']),
            helper.make_node('Where', [f'vacate_{i}', 'bg_oh', f'tmp_{i}'],
                             [nxt]),
        ]
        cur = nxt

    # Channel ArgMax, re-encoded by add_onehot_block, then padded by
    # _build_pad_node into the graph output.
    nodes.append(helper.make_node('ArgMax', [cur], ['grav_am'], axis=1, keepdims=1))
    add_onehot_block(nodes, inits, 'grav_am', 'grav_oh')
    nodes.append(_build_pad_node('grav_oh', 'output', pad_h, pad_w, inits))
    return mk(nodes, inits)
|
|
|
|
def s_gravity_unrolled(td):
    """Gravity solver with unrolled Conv+Where steps.

    Applies only to tasks with a single fixed input shape equal to the output
    shape. It tries bg colours 0-9 (outer loop) and directions
    down/up/left/right (inner loop). It returns the model for the first
    combination whose numpy reference reproduces every example. It returns
    None when the task is not fixed-shape, has no examples, or no combination
    fits.
    """
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if (IH, IW) != (OH, OW):
        return None

    exs = get_exs(td)
    # Without examples, ``all(...)`` below is vacuously True, which would
    # "solve" the task with an arbitrary bg=0 / 'down' model.
    if not exs:
        return None

    for bg_color in range(10):
        for direction in ('down', 'up', 'left', 'right'):
            if all(np.array_equal(_gravity_np(inp, direction, bg_color), out)
                   for inp, out in exs):
                return _build_gravity_model(IH, IW, direction, bg_color)
    return None
|
|