rogermt commited on
Commit
6b72e40
·
verified ·
1 Parent(s): bb1e82c

Move own-solver/neurogolf_solver/solvers/geometric.py to own-solver/

Browse files
own-solver/neurogolf_solver/solvers/geometric.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Geometric transformation solvers: flip, rotate, shift, crop."""
3
+
4
+ import numpy as np
5
+ from onnx import helper
6
+ from ..onnx_helpers import mk, _build_slice_crop, _build_slice_reverse, _build_pad_node
7
+ from ..data_loader import get_exs, fixed_shapes
8
+ from ..gather_helpers import _build_gather_model, _build_gather_model_with_const
9
+ from ..constants import GH, GW
10
+
11
+
12
+ def s_flip(td):
13
+ """Flip using Slice(step=-1) — zero MACs."""
14
+ exs = get_exs(td)
15
+ sp = fixed_shapes(td)
16
+ if sp is None:
17
+ return None
18
+ (IH, IW), (OH, OW) = sp
19
+ if (IH, IW) != (OH, OW):
20
+ return None
21
+ for axis, flip_fn in [(0, np.flipud), (1, np.fliplr)]:
22
+ if all(np.array_equal(out, flip_fn(inp)) for inp, out in exs):
23
+ onnx_axis = 2 if axis == 0 else 3
24
+ dim_size = IH if axis == 0 else IW
25
+ pad_h, pad_w = GH - IH, GW - IW
26
+ inits = []
27
+ nodes = []
28
+ nodes.append(_build_slice_crop('input', 'cropped', IH, IW, inits))
29
+ nodes.append(_build_slice_reverse('cropped', 'flipped', onnx_axis, dim_size, inits))
30
+ nodes.append(_build_pad_node('flipped', 'output', pad_h, pad_w, inits))
31
+ return mk(nodes, inits)
32
+ return None
33
+
34
+
35
+ def s_rotate(td):
36
+ """Rotate using Slice+Transpose — zero MACs for square grids and k=2.
37
+ Gather fallback for non-square k=1,3."""
38
+ exs = get_exs(td)
39
+ sp = fixed_shapes(td)
40
+ if sp is None:
41
+ return None
42
+ (IH, IW), (OH, OW) = sp
43
+ for k in [1, 2, 3]:
44
+ if not all(np.array_equal(out, np.rot90(inp, k)) for inp, out in exs):
45
+ continue
46
+ if k == 2:
47
+ pad_h, pad_w = GH - OH, GW - OW
48
+ inits = []
49
+ nodes = []
50
+ nodes.append(_build_slice_crop('input', 'cropped', IH, IW, inits))
51
+ nodes.append(_build_slice_reverse('cropped', 'flip_h', 2, IH, inits, suffix='_h'))
52
+ nodes.append(_build_slice_reverse('flip_h', 'rotated', 3, IW, inits, suffix='_w'))
53
+ nodes.append(_build_pad_node('rotated', 'output', pad_h, pad_w, inits))
54
+ return mk(nodes, inits)
55
+ elif k == 1 and IH == IW:
56
+ pad_h, pad_w = GH - IH, GW - IW
57
+ inits = []
58
+ nodes = []
59
+ nodes.append(_build_slice_crop('input', 'cropped', IH, IW, inits))
60
+ nodes.append(helper.make_node('Transpose', ['cropped'], ['transposed'], perm=[0, 1, 3, 2]))
61
+ nodes.append(_build_slice_reverse('transposed', 'rotated', 2, IH, inits))
62
+ nodes.append(_build_pad_node('rotated', 'output', pad_h, pad_w, inits))
63
+ return mk(nodes, inits)
64
+ elif k == 3 and IH == IW:
65
+ pad_h, pad_w = GH - IH, GW - IW
66
+ inits = []
67
+ nodes = []
68
+ nodes.append(_build_slice_crop('input', 'cropped', IH, IW, inits))
69
+ nodes.append(_build_slice_reverse('cropped', 'flipped', 2, IH, inits))
70
+ nodes.append(helper.make_node('Transpose', ['flipped'], ['rotated'], perm=[0, 1, 3, 2]))
71
+ nodes.append(_build_pad_node('rotated', 'output', pad_h, pad_w, inits))
72
+ return mk(nodes, inits)
73
+ else:
74
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
75
+ for r in range(OH):
76
+ for c in range(OW):
77
+ if k == 1:
78
+ sr, sc = c, IH - 1 - r
79
+ elif k == 3:
80
+ sr, sc = IW - 1 - c, r
81
+ idx[r, c] = [sr, sc]
82
+ return _build_gather_model(OH, OW, idx)
83
+ return None
84
+
85
+
86
+ def s_shift(td):
87
+ """Shift transformation solver."""
88
+ exs = get_exs(td)
89
+ sp = fixed_shapes(td)
90
+ if sp is None:
91
+ return None
92
+ (IH, IW), (OH, OW) = sp
93
+ if (IH, IW) != (OH, OW):
94
+ return None
95
+ for dr in range(-5, 6):
96
+ for dc in range(-5, 6):
97
+ if dr == 0 and dc == 0:
98
+ continue
99
+ ok = True
100
+ for inp, out in exs:
101
+ shifted = np.zeros_like(inp)
102
+ r0, r1 = max(0, dr), min(IH, IH + dr)
103
+ c0, c1 = max(0, dc), min(IW, IW + dc)
104
+ if r1 > r0 and c1 > c0:
105
+ sr0, sc0 = max(0, -dr), max(0, -dc)
106
+ shifted[r0:r1, c0:c1] = inp[sr0:sr0 + (r1 - r0), sc0:sc0 + (c1 - c0)]
107
+ if not np.array_equal(shifted, out):
108
+ ok = False
109
+ break
110
+ if not ok:
111
+ continue
112
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
113
+ cst = np.full((OH, OW), 0, dtype=np.int64)
114
+ for r in range(OH):
115
+ for c in range(OW):
116
+ sr, sc = r - dr, c - dc
117
+ if 0 <= sr < IH and 0 <= sc < IW:
118
+ idx[r, c] = [sr, sc]
119
+ else:
120
+ idx[r, c] = [-1, -1]
121
+ return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
122
+ return None
123
+
124
+
125
+ def s_fixed_crop(td):
126
+ """Fixed crop solver."""
127
+ exs = get_exs(td)
128
+ sp = fixed_shapes(td)
129
+ if sp is None:
130
+ return None
131
+ (IH, IW), (OH, OW) = sp
132
+ if OH > IH or OW > IW or (OH == IH and OW == IW):
133
+ return None
134
+ for r0 in range(IH - OH + 1):
135
+ for c0 in range(IW - OW + 1):
136
+ if all(np.array_equal(inp[r0:r0 + OH, c0:c0 + OW], out) for inp, out in exs):
137
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
138
+ for r in range(OH):
139
+ for c in range(OW):
140
+ idx[r, c] = [r0 + r, c0 + c]
141
+ return _build_gather_model(OH, OW, idx)
142
+ return None
143
+
144
+
145
+ def s_gravity(td):
146
+ """Detect gravity-like compaction (detection only, no ONNX model built)."""
147
+ exs = get_exs(td)
148
+ sp = fixed_shapes(td)
149
+ if sp is None:
150
+ return None
151
+ (IH, IW), (OH, OW) = sp
152
+ if (IH, IW) != (OH, OW):
153
+ return None
154
+
155
+ def _gravity(grid, direction):
156
+ r = np.zeros_like(grid)
157
+ h, w = grid.shape
158
+ if direction in ('down', 'up'):
159
+ for c in range(w):
160
+ nz = grid[:, c][grid[:, c] != 0]
161
+ if direction == 'down':
162
+ r[h - len(nz):h, c] = nz
163
+ else:
164
+ r[:len(nz), c] = nz
165
+ else:
166
+ for rr in range(h):
167
+ nz = grid[rr, :][grid[rr, :] != 0]
168
+ if direction == 'right':
169
+ r[rr, w - len(nz):w] = nz
170
+ else:
171
+ r[rr, :len(nz)] = nz
172
+ return r
173
+
174
+ for d in ('down', 'up', 'left', 'right'):
175
+ if all(np.array_equal(_gravity(inp, d), out) for inp, out in exs):
176
+ return None
177
+ return None