Move own-solver/neurogolf_solver/solvers/gravity.py to own-solver/
Browse files
own-solver/neurogolf_solver/solvers/gravity.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Gravity solver — unrolled bubble-sort via Conv + Where.
|
| 3 |
+
|
| 4 |
+
v5.2: Solves Task 78 (direction=up, bg=0, score 8.399).
|
| 5 |
+
Tries all 4 directions × 10 bg colors. Fixed-shape only.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from onnx import helper, numpy_helper
|
| 10 |
+
from ..onnx_helpers import mk, _make_int64_init, _build_pad_node, add_onehot_block
|
| 11 |
+
from ..data_loader import get_exs, fixed_shapes
|
| 12 |
+
from ..constants import GH, GW
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _gravity_np(grid, direction, bg_color=0):
|
| 16 |
+
"""Apply gravity in numpy for verification."""
|
| 17 |
+
r = np.full_like(grid, bg_color)
|
| 18 |
+
h, w = grid.shape
|
| 19 |
+
if direction == 'down':
|
| 20 |
+
for c in range(w):
|
| 21 |
+
nz = grid[:, c][grid[:, c] != bg_color]
|
| 22 |
+
r[h - len(nz):h, c] = nz
|
| 23 |
+
elif direction == 'up':
|
| 24 |
+
for c in range(w):
|
| 25 |
+
nz = grid[:, c][grid[:, c] != bg_color]
|
| 26 |
+
r[:len(nz), c] = nz
|
| 27 |
+
elif direction == 'right':
|
| 28 |
+
for rr in range(h):
|
| 29 |
+
nz = grid[rr, :][grid[rr, :] != bg_color]
|
| 30 |
+
r[rr, w - len(nz):w] = nz
|
| 31 |
+
elif direction == 'left':
|
| 32 |
+
for rr in range(h):
|
| 33 |
+
nz = grid[rr, :][grid[rr, :] != bg_color]
|
| 34 |
+
r[rr, :len(nz)] = nz
|
| 35 |
+
return r
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _build_gravity_model(IH, IW, direction, bg_color=0):
|
| 39 |
+
"""Build ONNX model for gravity via unrolled bubble-sort.
|
| 40 |
+
|
| 41 |
+
Each step compares adjacent cells and swaps if needed:
|
| 42 |
+
- If current cell is bg AND source neighbor is non-bg → fill with source
|
| 43 |
+
- If current cell is non-bg AND destination neighbor is bg → vacate to bg
|
| 44 |
+
After max(IH,IW) passes, all non-bg pixels settle in the gravity direction.
|
| 45 |
+
"""
|
| 46 |
+
pad_h, pad_w = GH - IH, GW - IW
|
| 47 |
+
n_steps = max(IH, IW)
|
| 48 |
+
|
| 49 |
+
pull_above = np.zeros((10, 10, 3, 3), dtype=np.float32)
|
| 50 |
+
pull_below = np.zeros((10, 10, 3, 3), dtype=np.float32)
|
| 51 |
+
for ch in range(10):
|
| 52 |
+
if direction == 'down':
|
| 53 |
+
pull_above[ch, ch, 0, 1] = 1.0
|
| 54 |
+
pull_below[ch, ch, 2, 1] = 1.0
|
| 55 |
+
elif direction == 'up':
|
| 56 |
+
pull_above[ch, ch, 2, 1] = 1.0
|
| 57 |
+
pull_below[ch, ch, 0, 1] = 1.0
|
| 58 |
+
elif direction == 'right':
|
| 59 |
+
pull_above[ch, ch, 1, 0] = 1.0
|
| 60 |
+
pull_below[ch, ch, 1, 2] = 1.0
|
| 61 |
+
elif direction == 'left':
|
| 62 |
+
pull_above[ch, ch, 1, 2] = 1.0
|
| 63 |
+
pull_below[ch, ch, 1, 0] = 1.0
|
| 64 |
+
|
| 65 |
+
bg_sel = np.zeros((1, 10, 1, 1), dtype=np.float32)
|
| 66 |
+
bg_sel[0, bg_color, 0, 0] = 1.0
|
| 67 |
+
bg_oh = np.zeros((1, 10, 1, 1), dtype=np.float32)
|
| 68 |
+
bg_oh[0, bg_color, 0, 0] = 1.0
|
| 69 |
+
|
| 70 |
+
inits = [
|
| 71 |
+
_make_int64_init('sl_st', [0, 0, 0, 0]),
|
| 72 |
+
_make_int64_init('sl_en', [1, 10, IH, IW]),
|
| 73 |
+
numpy_helper.from_array(pull_above, 'pull_src'),
|
| 74 |
+
numpy_helper.from_array(pull_below, 'pull_dst'),
|
| 75 |
+
numpy_helper.from_array(bg_sel, 'bg_sel'),
|
| 76 |
+
numpy_helper.from_array(bg_oh, 'bg_oh'),
|
| 77 |
+
numpy_helper.from_array(np.float32(0.5), 'half'),
|
| 78 |
+
]
|
| 79 |
+
|
| 80 |
+
nodes = [
|
| 81 |
+
helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['cur_0']),
|
| 82 |
+
]
|
| 83 |
+
|
| 84 |
+
cur = 'cur_0'
|
| 85 |
+
for i in range(n_steps):
|
| 86 |
+
src = f'src_{i}'
|
| 87 |
+
nodes.append(helper.make_node('Conv', [cur, 'pull_src'], [src],
|
| 88 |
+
kernel_shape=[3, 3], pads=[1, 1, 1, 1]))
|
| 89 |
+
|
| 90 |
+
nodes.append(helper.make_node('Mul', [cur, 'bg_sel'], [f'cbg_{i}']))
|
| 91 |
+
inits.append(_make_int64_init(f'ax1_{i}', [1]))
|
| 92 |
+
nodes.append(helper.make_node('ReduceSum', [f'cbg_{i}', f'ax1_{i}'], [f'cbgsum_{i}'], keepdims=1))
|
| 93 |
+
nodes.append(helper.make_node('Greater', [f'cbgsum_{i}', 'half'], [f'cur_is_bg_{i}']))
|
| 94 |
+
|
| 95 |
+
nodes.append(helper.make_node('Mul', [src, 'bg_sel'], [f'sbg_{i}']))
|
| 96 |
+
inits.append(_make_int64_init(f'ax2_{i}', [1]))
|
| 97 |
+
nodes.append(helper.make_node('ReduceSum', [f'sbg_{i}', f'ax2_{i}'], [f'sbgsum_{i}'], keepdims=1))
|
| 98 |
+
nodes.append(helper.make_node('Not', [f'cur_is_bg_{i}'], [f'cur_not_bg_{i}']))
|
| 99 |
+
|
| 100 |
+
nodes.append(helper.make_node('Greater', [f'sbgsum_{i}', 'half'], [f'src_is_bg_{i}']))
|
| 101 |
+
nodes.append(helper.make_node('Not', [f'src_is_bg_{i}'], [f'src_not_bg_{i}']))
|
| 102 |
+
nodes.append(helper.make_node('And', [f'cur_is_bg_{i}', f'src_not_bg_{i}'], [f'fill_{i}']))
|
| 103 |
+
|
| 104 |
+
dst = f'dst_{i}'
|
| 105 |
+
nodes.append(helper.make_node('Conv', [cur, 'pull_dst'], [dst],
|
| 106 |
+
kernel_shape=[3, 3], pads=[1, 1, 1, 1]))
|
| 107 |
+
nodes.append(helper.make_node('Mul', [dst, 'bg_sel'], [f'dbg_{i}']))
|
| 108 |
+
inits.append(_make_int64_init(f'ax3_{i}', [1]))
|
| 109 |
+
nodes.append(helper.make_node('ReduceSum', [f'dbg_{i}', f'ax3_{i}'], [f'dbgsum_{i}'], keepdims=1))
|
| 110 |
+
nodes.append(helper.make_node('Greater', [f'dbgsum_{i}', 'half'], [f'dst_is_bg_{i}']))
|
| 111 |
+
nodes.append(helper.make_node('And', [f'cur_not_bg_{i}', f'dst_is_bg_{i}'], [f'vacate_{i}']))
|
| 112 |
+
|
| 113 |
+
nxt = f'cur_{i+1}'
|
| 114 |
+
nodes.append(helper.make_node('Where', [f'fill_{i}', src, cur], [f'tmp_{i}']))
|
| 115 |
+
nodes.append(helper.make_node('Where', [f'vacate_{i}', 'bg_oh', f'tmp_{i}'], [nxt]))
|
| 116 |
+
cur = nxt
|
| 117 |
+
|
| 118 |
+
nodes.append(helper.make_node('ArgMax', [cur], ['grav_am'], axis=1, keepdims=1))
|
| 119 |
+
add_onehot_block(nodes, inits, 'grav_am', 'grav_oh')
|
| 120 |
+
nodes.append(_build_pad_node('grav_oh', 'output', pad_h, pad_w, inits))
|
| 121 |
+
return mk(nodes, inits)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def s_gravity_unrolled(td):
|
| 125 |
+
"""Gravity solver with unrolled Conv+Where steps.
|
| 126 |
+
Tries all 4 directions × bg colors 0-9."""
|
| 127 |
+
exs = get_exs(td)
|
| 128 |
+
sp = fixed_shapes(td)
|
| 129 |
+
if sp is None:
|
| 130 |
+
return None
|
| 131 |
+
(IH, IW), (OH, OW) = sp
|
| 132 |
+
if (IH, IW) != (OH, OW):
|
| 133 |
+
return None
|
| 134 |
+
|
| 135 |
+
for bg_color in range(10):
|
| 136 |
+
for direction in ('down', 'up', 'left', 'right'):
|
| 137 |
+
if all(np.array_equal(_gravity_np(inp, direction, bg_color), out)
|
| 138 |
+
for inp, out in exs):
|
| 139 |
+
return _build_gravity_model(IH, IW, direction, bg_color)
|
| 140 |
+
return None
|