rogermt commited on
Commit
8e44042
·
verified ·
1 Parent(s): 6b72e40

Move own-solver/neurogolf_solver/solvers/gravity.py to own-solver/

Browse files
own-solver/neurogolf_solver/solvers/gravity.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Gravity solver — unrolled bubble-sort via Conv + Where.
3
+
4
+ v5.2: Solves Task 78 (direction=up, bg=0, score 8.399).
5
+ Tries all 4 directions × 10 bg colors. Fixed-shape only.
6
+ """
7
+
8
+ import numpy as np
9
+ from onnx import helper, numpy_helper
10
+ from ..onnx_helpers import mk, _make_int64_init, _build_pad_node, add_onehot_block
11
+ from ..data_loader import get_exs, fixed_shapes
12
+ from ..constants import GH, GW
13
+
14
+
15
+ def _gravity_np(grid, direction, bg_color=0):
16
+ """Apply gravity in numpy for verification."""
17
+ r = np.full_like(grid, bg_color)
18
+ h, w = grid.shape
19
+ if direction == 'down':
20
+ for c in range(w):
21
+ nz = grid[:, c][grid[:, c] != bg_color]
22
+ r[h - len(nz):h, c] = nz
23
+ elif direction == 'up':
24
+ for c in range(w):
25
+ nz = grid[:, c][grid[:, c] != bg_color]
26
+ r[:len(nz), c] = nz
27
+ elif direction == 'right':
28
+ for rr in range(h):
29
+ nz = grid[rr, :][grid[rr, :] != bg_color]
30
+ r[rr, w - len(nz):w] = nz
31
+ elif direction == 'left':
32
+ for rr in range(h):
33
+ nz = grid[rr, :][grid[rr, :] != bg_color]
34
+ r[rr, :len(nz)] = nz
35
+ return r
36
+
37
+
38
+ def _build_gravity_model(IH, IW, direction, bg_color=0):
39
+ """Build ONNX model for gravity via unrolled bubble-sort.
40
+
41
+ Each step compares adjacent cells and swaps if needed:
42
+ - If current cell is bg AND source neighbor is non-bg → fill with source
43
+ - If current cell is non-bg AND destination neighbor is bg → vacate to bg
44
+ After max(IH,IW) passes, all non-bg pixels settle in the gravity direction.
45
+ """
46
+ pad_h, pad_w = GH - IH, GW - IW
47
+ n_steps = max(IH, IW)
48
+
49
+ pull_above = np.zeros((10, 10, 3, 3), dtype=np.float32)
50
+ pull_below = np.zeros((10, 10, 3, 3), dtype=np.float32)
51
+ for ch in range(10):
52
+ if direction == 'down':
53
+ pull_above[ch, ch, 0, 1] = 1.0
54
+ pull_below[ch, ch, 2, 1] = 1.0
55
+ elif direction == 'up':
56
+ pull_above[ch, ch, 2, 1] = 1.0
57
+ pull_below[ch, ch, 0, 1] = 1.0
58
+ elif direction == 'right':
59
+ pull_above[ch, ch, 1, 0] = 1.0
60
+ pull_below[ch, ch, 1, 2] = 1.0
61
+ elif direction == 'left':
62
+ pull_above[ch, ch, 1, 2] = 1.0
63
+ pull_below[ch, ch, 1, 0] = 1.0
64
+
65
+ bg_sel = np.zeros((1, 10, 1, 1), dtype=np.float32)
66
+ bg_sel[0, bg_color, 0, 0] = 1.0
67
+ bg_oh = np.zeros((1, 10, 1, 1), dtype=np.float32)
68
+ bg_oh[0, bg_color, 0, 0] = 1.0
69
+
70
+ inits = [
71
+ _make_int64_init('sl_st', [0, 0, 0, 0]),
72
+ _make_int64_init('sl_en', [1, 10, IH, IW]),
73
+ numpy_helper.from_array(pull_above, 'pull_src'),
74
+ numpy_helper.from_array(pull_below, 'pull_dst'),
75
+ numpy_helper.from_array(bg_sel, 'bg_sel'),
76
+ numpy_helper.from_array(bg_oh, 'bg_oh'),
77
+ numpy_helper.from_array(np.float32(0.5), 'half'),
78
+ ]
79
+
80
+ nodes = [
81
+ helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['cur_0']),
82
+ ]
83
+
84
+ cur = 'cur_0'
85
+ for i in range(n_steps):
86
+ src = f'src_{i}'
87
+ nodes.append(helper.make_node('Conv', [cur, 'pull_src'], [src],
88
+ kernel_shape=[3, 3], pads=[1, 1, 1, 1]))
89
+
90
+ nodes.append(helper.make_node('Mul', [cur, 'bg_sel'], [f'cbg_{i}']))
91
+ inits.append(_make_int64_init(f'ax1_{i}', [1]))
92
+ nodes.append(helper.make_node('ReduceSum', [f'cbg_{i}', f'ax1_{i}'], [f'cbgsum_{i}'], keepdims=1))
93
+ nodes.append(helper.make_node('Greater', [f'cbgsum_{i}', 'half'], [f'cur_is_bg_{i}']))
94
+
95
+ nodes.append(helper.make_node('Mul', [src, 'bg_sel'], [f'sbg_{i}']))
96
+ inits.append(_make_int64_init(f'ax2_{i}', [1]))
97
+ nodes.append(helper.make_node('ReduceSum', [f'sbg_{i}', f'ax2_{i}'], [f'sbgsum_{i}'], keepdims=1))
98
+ nodes.append(helper.make_node('Not', [f'cur_is_bg_{i}'], [f'cur_not_bg_{i}']))
99
+
100
+ nodes.append(helper.make_node('Greater', [f'sbgsum_{i}', 'half'], [f'src_is_bg_{i}']))
101
+ nodes.append(helper.make_node('Not', [f'src_is_bg_{i}'], [f'src_not_bg_{i}']))
102
+ nodes.append(helper.make_node('And', [f'cur_is_bg_{i}', f'src_not_bg_{i}'], [f'fill_{i}']))
103
+
104
+ dst = f'dst_{i}'
105
+ nodes.append(helper.make_node('Conv', [cur, 'pull_dst'], [dst],
106
+ kernel_shape=[3, 3], pads=[1, 1, 1, 1]))
107
+ nodes.append(helper.make_node('Mul', [dst, 'bg_sel'], [f'dbg_{i}']))
108
+ inits.append(_make_int64_init(f'ax3_{i}', [1]))
109
+ nodes.append(helper.make_node('ReduceSum', [f'dbg_{i}', f'ax3_{i}'], [f'dbgsum_{i}'], keepdims=1))
110
+ nodes.append(helper.make_node('Greater', [f'dbgsum_{i}', 'half'], [f'dst_is_bg_{i}']))
111
+ nodes.append(helper.make_node('And', [f'cur_not_bg_{i}', f'dst_is_bg_{i}'], [f'vacate_{i}']))
112
+
113
+ nxt = f'cur_{i+1}'
114
+ nodes.append(helper.make_node('Where', [f'fill_{i}', src, cur], [f'tmp_{i}']))
115
+ nodes.append(helper.make_node('Where', [f'vacate_{i}', 'bg_oh', f'tmp_{i}'], [nxt]))
116
+ cur = nxt
117
+
118
+ nodes.append(helper.make_node('ArgMax', [cur], ['grav_am'], axis=1, keepdims=1))
119
+ add_onehot_block(nodes, inits, 'grav_am', 'grav_oh')
120
+ nodes.append(_build_pad_node('grav_oh', 'output', pad_h, pad_w, inits))
121
+ return mk(nodes, inits)
122
+
123
+
124
+ def s_gravity_unrolled(td):
125
+ """Gravity solver with unrolled Conv+Where steps.
126
+ Tries all 4 directions × bg colors 0-9."""
127
+ exs = get_exs(td)
128
+ sp = fixed_shapes(td)
129
+ if sp is None:
130
+ return None
131
+ (IH, IW), (OH, OW) = sp
132
+ if (IH, IW) != (OH, OW):
133
+ return None
134
+
135
+ for bg_color in range(10):
136
+ for direction in ('down', 'up', 'left', 'right'):
137
+ if all(np.array_equal(_gravity_np(inp, direction, bg_color), out)
138
+ for inp, out in exs):
139
+ return _build_gravity_model(IH, IW, direction, bg_color)
140
+ return None