#!/usr/bin/env python3
"""Gravity solver — unrolled bubble-sort via Conv + Where.
v5.2: Solves Task 78 (direction=up, bg=0, score 8.399).
Tries all 4 directions × 10 bg colors. Fixed-shape only.
"""
import numpy as np
from onnx import helper, numpy_helper
from ..onnx_helpers import mk, _make_int64_init, _build_pad_node, add_onehot_block
from ..data_loader import get_exs, fixed_shapes
from ..constants import GH, GW
def _gravity_np(grid, direction, bg_color=0):
    """Reference gravity in numpy, used to verify candidate settings.

    For 'down'/'up' each column is compacted, for 'right'/'left' each row:
    the cells that differ from ``bg_color`` keep their relative order and are
    packed against the edge named by ``direction``; every other cell of the
    result holds ``bg_color``. Any other ``direction`` value yields a grid
    filled entirely with ``bg_color``.
    """
    settled = np.full_like(grid, bg_color)
    n_rows, n_cols = grid.shape

    if direction in ('down', 'up'):
        for col in range(n_cols):
            column = grid[:, col]
            kept = column[column != bg_color]
            if direction == 'down':
                # Pack against the bottom edge.
                settled[n_rows - len(kept):n_rows, col] = kept
            else:
                # Pack against the top edge.
                settled[:len(kept), col] = kept
    elif direction in ('right', 'left'):
        for row in range(n_rows):
            line = grid[row, :]
            kept = line[line != bg_color]
            if direction == 'right':
                # Pack against the right edge.
                settled[row, n_cols - len(kept):n_cols] = kept
            else:
                # Pack against the left edge.
                settled[row, :len(kept)] = kept

    return settled
def _build_gravity_model(IH, IW, direction, bg_color=0):
    """Build an ONNX model that applies gravity via unrolled neighbour swaps.

    The graph slices ``input`` to ``[1, 10, IH, IW]`` and then repeats one
    step per cell along the gravity axis. In every step each cell is updated
    simultaneously:

    - fill:   the cell is background and the neighbour on the source side
              holds a non-background pixel -> the cell takes that pixel;
    - vacate: the cell is non-background and the neighbour on the destination
              side is background -> the cell becomes background.

    The two rules are mutually exclusive for a cell and always fire as a
    matched pair on adjacent cells, so every step moves each unblocked pixel
    by exactly one cell. A line of length L is fully settled after L - 1
    such steps, so L steps are emitted.

    Args:
        IH: Grid height used for the slice and, for 'up'/'down', step count.
        IW: Grid width used for the slice and, for 'left'/'right', step count.
        direction: One of 'down', 'up', 'left', 'right'.
        bg_color: Channel index (0-9) treated as background.

    Returns:
        Whatever ``mk(nodes, inits)`` returns for the assembled graph.

    Raises:
        ValueError: If ``direction`` is not one of the four supported values.

    NOTE(review): assumes ``input`` is one-hot over 10 colour channels and
    that ``_build_pad_node`` pads the result back to GH x GW - confirm
    against the helpers, which are not visible from this block.
    """
    # (row, col) position of the single 1.0 tap inside each 3x3 kernel:
    # first the neighbour a pixel arrives from, then the one it moves towards.
    taps = {
        'down': ((0, 1), (2, 1)),
        'up': ((2, 1), (0, 1)),
        'right': ((1, 0), (1, 2)),
        'left': ((1, 2), (1, 0)),
    }
    if direction not in taps:
        raise ValueError(f"unknown gravity direction: {direction!r}")
    (src_r, src_c), (dst_r, dst_c) = taps[direction]

    pad_h, pad_w = GH - IH, GW - IW
    # Pixels only travel along the gravity axis, so its length bounds the
    # number of passes; the other dimension is irrelevant.
    n_steps = IH if direction in ('down', 'up') else IW

    # Per-channel identity kernels that shift the grid by one cell.
    pull_src = np.zeros((10, 10, 3, 3), dtype=np.float32)
    pull_dst = np.zeros((10, 10, 3, 3), dtype=np.float32)
    for ch in range(10):
        pull_src[ch, ch, src_r, src_c] = 1.0
        pull_dst[ch, ch, dst_r, dst_c] = 1.0

    # bg_sel doubles as the channel selector for "is background" tests and as
    # the one-hot value written into vacated cells.
    bg_sel = np.zeros((1, 10, 1, 1), dtype=np.float32)
    bg_sel[0, bg_color, 0, 0] = 1.0
    # Selects every channel except the background one.
    nonbg_sel = 1.0 - bg_sel

    inits = [
        _make_int64_init('sl_st', [0, 0, 0, 0]),
        _make_int64_init('sl_en', [1, 10, IH, IW]),
        # One shared channel-axis tensor for every ReduceSum in the graph.
        _make_int64_init('grav_ch_axis', [1]),
        numpy_helper.from_array(pull_src, 'pull_src'),
        numpy_helper.from_array(pull_dst, 'pull_dst'),
        numpy_helper.from_array(bg_sel, 'bg_sel'),
        numpy_helper.from_array(nonbg_sel, 'nonbg_sel'),
        numpy_helper.from_array(np.float32(0.5), 'half'),
    ]
    nodes = [
        helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['cur_0']),
    ]

    def _channel_test(tensor, selector, out_name):
        """Emit nodes setting ``out_name`` where the selected channels sum > 0.5."""
        nodes.append(helper.make_node('Mul', [tensor, selector], [f'{out_name}_m']))
        nodes.append(helper.make_node('ReduceSum', [f'{out_name}_m', 'grav_ch_axis'],
                                      [f'{out_name}_s'], keepdims=1))
        nodes.append(helper.make_node('Greater', [f'{out_name}_s', 'half'], [out_name]))
        return out_name

    cur = 'cur_0'
    for i in range(n_steps):
        src = f'src_{i}'
        dst = f'dst_{i}'
        nxt = f'cur_{i+1}'

        # Neighbour on the source side, shifted onto the current cell.
        nodes.append(helper.make_node('Conv', [cur, 'pull_src'], [src],
                                      kernel_shape=[3, 3], pads=[1, 1, 1, 1]))
        cur_is_bg = _channel_test(cur, 'bg_sel', f'cur_is_bg_{i}')
        # Test for an actual non-background pixel instead of "not background":
        # taps outside the grid read as all-zero vectors from the Conv padding,
        # which are not background either and would otherwise be pulled in as
        # empty pixels that ArgMax later turns into colour 0.
        src_has_px = _channel_test(src, 'nonbg_sel', f'src_has_px_{i}')
        nodes.append(helper.make_node('And', [cur_is_bg, src_has_px], [f'fill_{i}']))

        # Neighbour on the destination side; outside the grid it is all-zero,
        # hence never background, so the border acts as a wall.
        nodes.append(helper.make_node('Not', [cur_is_bg], [f'cur_not_bg_{i}']))
        nodes.append(helper.make_node('Conv', [cur, 'pull_dst'], [dst],
                                      kernel_shape=[3, 3], pads=[1, 1, 1, 1]))
        dst_is_bg = _channel_test(dst, 'bg_sel', f'dst_is_bg_{i}')
        nodes.append(helper.make_node('And', [f'cur_not_bg_{i}', dst_is_bg], [f'vacate_{i}']))

        # Apply both rules; the masks and bg_sel broadcast over the channels.
        nodes.append(helper.make_node('Where', [f'fill_{i}', src, cur], [f'tmp_{i}']))
        nodes.append(helper.make_node('Where', [f'vacate_{i}', 'bg_sel', f'tmp_{i}'], [nxt]))
        cur = nxt

    nodes.append(helper.make_node('ArgMax', [cur], ['grav_am'], axis=1, keepdims=1))
    add_onehot_block(nodes, inits, 'grav_am', 'grav_oh')
    nodes.append(_build_pad_node('grav_oh', 'output', pad_h, pad_w, inits))
    return mk(nodes, inits)
def s_gravity_unrolled(td):
    """Search for a gravity setting that explains every example of ``td``.

    Returns None when ``fixed_shapes(td)`` gives None or when the input and
    output shapes differ. Otherwise background colours 0-9 are tried in
    order, each with the directions 'down', 'up', 'left', 'right'; for the
    first pair whose numpy gravity reproduces every example output, the model
    from ``_build_gravity_model`` is returned. None if no pair matches.
    """
    examples = get_exs(td)
    shapes = fixed_shapes(td)
    if shapes is None:
        return None

    (in_h, in_w), (out_h, out_w) = shapes
    if in_h != out_h or in_w != out_w:
        return None

    def _explains_all(direction, bg_color):
        """True when gravity with these settings maps every input to its output."""
        for grid_in, grid_out in examples:
            predicted = _gravity_np(grid_in, direction, bg_color)
            if not np.array_equal(predicted, grid_out):
                return False
        return True

    # Background colour varies slowest, matching the original search order.
    candidates = [
        (bg_color, direction)
        for bg_color in range(10)
        for direction in ('down', 'up', 'left', 'right')
    ]
    for bg_color, direction in candidates:
        if _explains_all(direction, bg_color):
            return _build_gravity_model(in_h, in_w, direction, bg_color)
    return None
|