File size: 5,811 Bytes
8e44042
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#!/usr/bin/env python3
"""Gravity solver — unrolled bubble-sort via Conv + Where.

v5.2: Solves Task 78 (direction=up, bg=0, score 8.399).
Tries all 4 directions × 10 bg colors. Fixed-shape only.
"""

import numpy as np
from onnx import helper, numpy_helper
from ..onnx_helpers import mk, _make_int64_init, _build_pad_node, add_onehot_block
from ..data_loader import get_exs, fixed_shapes
from ..constants import GH, GW


def _gravity_np(grid, direction, bg_color=0):
    """Apply gravity in numpy for verification."""
    r = np.full_like(grid, bg_color)
    h, w = grid.shape
    if direction == 'down':
        for c in range(w):
            nz = grid[:, c][grid[:, c] != bg_color]
            r[h - len(nz):h, c] = nz
    elif direction == 'up':
        for c in range(w):
            nz = grid[:, c][grid[:, c] != bg_color]
            r[:len(nz), c] = nz
    elif direction == 'right':
        for rr in range(h):
            nz = grid[rr, :][grid[rr, :] != bg_color]
            r[rr, w - len(nz):w] = nz
    elif direction == 'left':
        for rr in range(h):
            nz = grid[rr, :][grid[rr, :] != bg_color]
            r[rr, :len(nz)] = nz
    return r


def _build_gravity_model(IH, IW, direction, bg_color=0):
    """Build an ONNX model that applies gravity via an unrolled bubble sort.

    The model slices the ``IH x IW`` region out of ``input`` (assumed to be a
    one-hot, 10-channel grid -- not checked here). It then runs
    ``max(IH, IW)`` identical steps. In each step every adjacent
    (non-bg, bg) pair along the gravity axis swaps simultaneously:

    - a bg cell whose *source* neighbour (the cell gravity pulls from) holds
      a non-bg colour takes that neighbour's value ("fill");
    - a non-bg cell whose *destination* neighbour is bg becomes bg ("vacate").

    Neighbours are read with single-tap 3x3 Convs using 1-pixel zero padding,
    so cells outside the grid read as all-zero vectors. Padding must count as
    neither bg nor a colour. The source test therefore checks that some
    non-bg channel is set, rather than that the bg channel is clear. With the
    latter check, bg cells on the source edge were filled with zero vectors,
    which ArgMax decodes as colour 0. That is wrong whenever
    ``bg_color != 0``.

    Args:
        IH, IW: height and width of the fixed-size task grid.
        direction: one of ``'down'``, ``'up'``, ``'left'``, ``'right'``.
        bg_color: colour index treated as empty space.

    Returns:
        The result of ``mk(nodes, inits)`` (presumably an ONNX model proto).

    Raises:
        ValueError: if ``direction`` is not one of the four supported values.
    """
    # Kernel tap (ky, kx) for the source neighbour and for the destination
    # neighbour. ONNX Conv is a cross-correlation, so with 1-pixel padding
    # tap (ky, kx) reads input[y + ky - 1, x + kx - 1].
    taps = {
        'down': ((0, 1), (2, 1)),
        'up': ((2, 1), (0, 1)),
        'right': ((1, 0), (1, 2)),
        'left': ((1, 2), (1, 0)),
    }
    if direction not in taps:
        raise ValueError(f'unknown gravity direction: {direction!r}')
    (src_ky, src_kx), (dst_ky, dst_kx) = taps[direction]

    pad_h, pad_w = GH - IH, GW - IW
    # Each step shifts a pixel by at most one cell; this many steps lets every
    # non-bg pixel settle along either axis.
    n_steps = max(IH, IW)

    # Per-channel shift kernels: output channel c copies input channel c of
    # the neighbour picked by the tap.
    diag = np.arange(10)
    pull_src = np.zeros((10, 10, 3, 3), dtype=np.float32)
    pull_src[diag, diag, src_ky, src_kx] = 1.0
    pull_dst = np.zeros((10, 10, 3, 3), dtype=np.float32)
    pull_dst[diag, diag, dst_ky, dst_kx] = 1.0

    # bg_oh is both the "bg channel" selector and the value written into
    # vacated cells; nonbg_sel selects every other colour channel.
    bg_oh = np.zeros((1, 10, 1, 1), dtype=np.float32)
    bg_oh[0, bg_color, 0, 0] = 1.0
    nonbg_sel = 1.0 - bg_oh

    inits = [
        _make_int64_init('sl_st', [0, 0, 0, 0]),
        _make_int64_init('sl_en', [1, 10, IH, IW]),
        # One shared channel axis for every ReduceSum (was 3 copies per step).
        _make_int64_init('grav_ch_axis', [1]),
        numpy_helper.from_array(pull_src, 'pull_src'),
        numpy_helper.from_array(pull_dst, 'pull_dst'),
        numpy_helper.from_array(bg_oh, 'bg_oh'),
        numpy_helper.from_array(nonbg_sel, 'grav_nonbg'),
        numpy_helper.from_array(np.float32(0.5), 'half'),
    ]

    nodes = [
        helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['cur_0']),
    ]

    def add_mask_test(tensor, mask, out):
        """Append nodes setting boolean ``out`` = (sum_c tensor * mask > 0.5)."""
        nodes.append(helper.make_node('Mul', [tensor, mask], [f'{out}_m']))
        nodes.append(helper.make_node('ReduceSum', [f'{out}_m', 'grav_ch_axis'],
                                      [f'{out}_s'], keepdims=1))
        nodes.append(helper.make_node('Greater', [f'{out}_s', 'half'], [out]))

    cur = 'cur_0'
    for i in range(n_steps):
        src, dst, nxt = f'src_{i}', f'dst_{i}', f'cur_{i + 1}'
        nodes.append(helper.make_node('Conv', [cur, 'pull_src'], [src],
                                      kernel_shape=[3, 3], pads=[1, 1, 1, 1]))
        nodes.append(helper.make_node('Conv', [cur, 'pull_dst'], [dst],
                                      kernel_shape=[3, 3], pads=[1, 1, 1, 1]))

        add_mask_test(cur, 'bg_oh', f'cur_is_bg_{i}')
        nodes.append(helper.make_node('Not', [f'cur_is_bg_{i}'], [f'cur_not_bg_{i}']))
        # Positive colour test: zero padding is NOT a colour to pull in.
        add_mask_test(src, 'grav_nonbg', f'src_not_bg_{i}')
        # Padding has no bg channel set, so the far edge never vacates.
        add_mask_test(dst, 'bg_oh', f'dst_is_bg_{i}')

        nodes.append(helper.make_node('And', [f'cur_is_bg_{i}', f'src_not_bg_{i}'], [f'fill_{i}']))
        nodes.append(helper.make_node('And', [f'cur_not_bg_{i}', f'dst_is_bg_{i}'], [f'vacate_{i}']))

        # fill and vacate are mutually exclusive per cell (bg vs non-bg).
        nodes.append(helper.make_node('Where', [f'fill_{i}', src, cur], [f'tmp_{i}']))
        nodes.append(helper.make_node('Where', [f'vacate_{i}', 'bg_oh', f'tmp_{i}'], [nxt]))
        cur = nxt

    nodes.append(helper.make_node('ArgMax', [cur], ['grav_am'], axis=1, keepdims=1))
    add_onehot_block(nodes, inits, 'grav_am', 'grav_oh')
    nodes.append(_build_pad_node('grav_oh', 'output', pad_h, pad_w, inits))
    return mk(nodes, inits)


def s_gravity_unrolled(td):
    """Gravity solver with unrolled Conv+Where steps.

    Tries all 4 directions x bg colors 0-9 and returns the model for the
    first (bg_color, direction) pair whose numpy reference reproduces every
    example.

    Returns None in these cases:
    - the task is not fixed-shape;
    - the input and output shapes differ;
    - there are no examples;
    - no pair matches.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if (IH, IW) != (OH, OW):
        return None

    # _gravity_np needs ndarray inputs (.shape); accept nested lists too.
    pairs = [(np.asarray(inp), out) for inp, out in exs]
    if not pairs:
        # all() over nothing is True: without this guard the first candidate
        # would "solve" a task that has no evidence at all.
        return None

    for bg_color in range(10):
        for direction in ('down', 'up', 'left', 'right'):
            if all(np.array_equal(_gravity_np(inp, direction, bg_color), out)
                   for inp, out in pairs):
                return _build_gravity_model(IH, IW, direction, bg_color)
    return None