Move own-solver/neurogolf_solver/solvers/geometric.py to own-solver/
Browse files
own-solver/neurogolf_solver/solvers/geometric.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Geometric transformation solvers: flip, rotate, shift, crop."""
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from onnx import helper
|
| 6 |
+
from ..onnx_helpers import mk, _build_slice_crop, _build_slice_reverse, _build_pad_node
|
| 7 |
+
from ..data_loader import get_exs, fixed_shapes
|
| 8 |
+
from ..gather_helpers import _build_gather_model, _build_gather_model_with_const
|
| 9 |
+
from ..constants import GH, GW
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def s_flip(td):
|
| 13 |
+
"""Flip using Slice(step=-1) — zero MACs."""
|
| 14 |
+
exs = get_exs(td)
|
| 15 |
+
sp = fixed_shapes(td)
|
| 16 |
+
if sp is None:
|
| 17 |
+
return None
|
| 18 |
+
(IH, IW), (OH, OW) = sp
|
| 19 |
+
if (IH, IW) != (OH, OW):
|
| 20 |
+
return None
|
| 21 |
+
for axis, flip_fn in [(0, np.flipud), (1, np.fliplr)]:
|
| 22 |
+
if all(np.array_equal(out, flip_fn(inp)) for inp, out in exs):
|
| 23 |
+
onnx_axis = 2 if axis == 0 else 3
|
| 24 |
+
dim_size = IH if axis == 0 else IW
|
| 25 |
+
pad_h, pad_w = GH - IH, GW - IW
|
| 26 |
+
inits = []
|
| 27 |
+
nodes = []
|
| 28 |
+
nodes.append(_build_slice_crop('input', 'cropped', IH, IW, inits))
|
| 29 |
+
nodes.append(_build_slice_reverse('cropped', 'flipped', onnx_axis, dim_size, inits))
|
| 30 |
+
nodes.append(_build_pad_node('flipped', 'output', pad_h, pad_w, inits))
|
| 31 |
+
return mk(nodes, inits)
|
| 32 |
+
return None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def s_rotate(td):
|
| 36 |
+
"""Rotate using Slice+Transpose — zero MACs for square grids and k=2.
|
| 37 |
+
Gather fallback for non-square k=1,3."""
|
| 38 |
+
exs = get_exs(td)
|
| 39 |
+
sp = fixed_shapes(td)
|
| 40 |
+
if sp is None:
|
| 41 |
+
return None
|
| 42 |
+
(IH, IW), (OH, OW) = sp
|
| 43 |
+
for k in [1, 2, 3]:
|
| 44 |
+
if not all(np.array_equal(out, np.rot90(inp, k)) for inp, out in exs):
|
| 45 |
+
continue
|
| 46 |
+
if k == 2:
|
| 47 |
+
pad_h, pad_w = GH - OH, GW - OW
|
| 48 |
+
inits = []
|
| 49 |
+
nodes = []
|
| 50 |
+
nodes.append(_build_slice_crop('input', 'cropped', IH, IW, inits))
|
| 51 |
+
nodes.append(_build_slice_reverse('cropped', 'flip_h', 2, IH, inits, suffix='_h'))
|
| 52 |
+
nodes.append(_build_slice_reverse('flip_h', 'rotated', 3, IW, inits, suffix='_w'))
|
| 53 |
+
nodes.append(_build_pad_node('rotated', 'output', pad_h, pad_w, inits))
|
| 54 |
+
return mk(nodes, inits)
|
| 55 |
+
elif k == 1 and IH == IW:
|
| 56 |
+
pad_h, pad_w = GH - IH, GW - IW
|
| 57 |
+
inits = []
|
| 58 |
+
nodes = []
|
| 59 |
+
nodes.append(_build_slice_crop('input', 'cropped', IH, IW, inits))
|
| 60 |
+
nodes.append(helper.make_node('Transpose', ['cropped'], ['transposed'], perm=[0, 1, 3, 2]))
|
| 61 |
+
nodes.append(_build_slice_reverse('transposed', 'rotated', 2, IH, inits))
|
| 62 |
+
nodes.append(_build_pad_node('rotated', 'output', pad_h, pad_w, inits))
|
| 63 |
+
return mk(nodes, inits)
|
| 64 |
+
elif k == 3 and IH == IW:
|
| 65 |
+
pad_h, pad_w = GH - IH, GW - IW
|
| 66 |
+
inits = []
|
| 67 |
+
nodes = []
|
| 68 |
+
nodes.append(_build_slice_crop('input', 'cropped', IH, IW, inits))
|
| 69 |
+
nodes.append(_build_slice_reverse('cropped', 'flipped', 2, IH, inits))
|
| 70 |
+
nodes.append(helper.make_node('Transpose', ['flipped'], ['rotated'], perm=[0, 1, 3, 2]))
|
| 71 |
+
nodes.append(_build_pad_node('rotated', 'output', pad_h, pad_w, inits))
|
| 72 |
+
return mk(nodes, inits)
|
| 73 |
+
else:
|
| 74 |
+
idx = np.zeros((OH, OW, 2), dtype=np.int64)
|
| 75 |
+
for r in range(OH):
|
| 76 |
+
for c in range(OW):
|
| 77 |
+
if k == 1:
|
| 78 |
+
sr, sc = c, IH - 1 - r
|
| 79 |
+
elif k == 3:
|
| 80 |
+
sr, sc = IW - 1 - c, r
|
| 81 |
+
idx[r, c] = [sr, sc]
|
| 82 |
+
return _build_gather_model(OH, OW, idx)
|
| 83 |
+
return None
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def s_shift(td):
|
| 87 |
+
"""Shift transformation solver."""
|
| 88 |
+
exs = get_exs(td)
|
| 89 |
+
sp = fixed_shapes(td)
|
| 90 |
+
if sp is None:
|
| 91 |
+
return None
|
| 92 |
+
(IH, IW), (OH, OW) = sp
|
| 93 |
+
if (IH, IW) != (OH, OW):
|
| 94 |
+
return None
|
| 95 |
+
for dr in range(-5, 6):
|
| 96 |
+
for dc in range(-5, 6):
|
| 97 |
+
if dr == 0 and dc == 0:
|
| 98 |
+
continue
|
| 99 |
+
ok = True
|
| 100 |
+
for inp, out in exs:
|
| 101 |
+
shifted = np.zeros_like(inp)
|
| 102 |
+
r0, r1 = max(0, dr), min(IH, IH + dr)
|
| 103 |
+
c0, c1 = max(0, dc), min(IW, IW + dc)
|
| 104 |
+
if r1 > r0 and c1 > c0:
|
| 105 |
+
sr0, sc0 = max(0, -dr), max(0, -dc)
|
| 106 |
+
shifted[r0:r1, c0:c1] = inp[sr0:sr0 + (r1 - r0), sc0:sc0 + (c1 - c0)]
|
| 107 |
+
if not np.array_equal(shifted, out):
|
| 108 |
+
ok = False
|
| 109 |
+
break
|
| 110 |
+
if not ok:
|
| 111 |
+
continue
|
| 112 |
+
idx = np.zeros((OH, OW, 2), dtype=np.int64)
|
| 113 |
+
cst = np.full((OH, OW), 0, dtype=np.int64)
|
| 114 |
+
for r in range(OH):
|
| 115 |
+
for c in range(OW):
|
| 116 |
+
sr, sc = r - dr, c - dc
|
| 117 |
+
if 0 <= sr < IH and 0 <= sc < IW:
|
| 118 |
+
idx[r, c] = [sr, sc]
|
| 119 |
+
else:
|
| 120 |
+
idx[r, c] = [-1, -1]
|
| 121 |
+
return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
|
| 122 |
+
return None
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def s_fixed_crop(td):
|
| 126 |
+
"""Fixed crop solver."""
|
| 127 |
+
exs = get_exs(td)
|
| 128 |
+
sp = fixed_shapes(td)
|
| 129 |
+
if sp is None:
|
| 130 |
+
return None
|
| 131 |
+
(IH, IW), (OH, OW) = sp
|
| 132 |
+
if OH > IH or OW > IW or (OH == IH and OW == IW):
|
| 133 |
+
return None
|
| 134 |
+
for r0 in range(IH - OH + 1):
|
| 135 |
+
for c0 in range(IW - OW + 1):
|
| 136 |
+
if all(np.array_equal(inp[r0:r0 + OH, c0:c0 + OW], out) for inp, out in exs):
|
| 137 |
+
idx = np.zeros((OH, OW, 2), dtype=np.int64)
|
| 138 |
+
for r in range(OH):
|
| 139 |
+
for c in range(OW):
|
| 140 |
+
idx[r, c] = [r0 + r, c0 + c]
|
| 141 |
+
return _build_gather_model(OH, OW, idx)
|
| 142 |
+
return None
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def s_gravity(td):
|
| 146 |
+
"""Detect gravity-like compaction (detection only, no ONNX model built)."""
|
| 147 |
+
exs = get_exs(td)
|
| 148 |
+
sp = fixed_shapes(td)
|
| 149 |
+
if sp is None:
|
| 150 |
+
return None
|
| 151 |
+
(IH, IW), (OH, OW) = sp
|
| 152 |
+
if (IH, IW) != (OH, OW):
|
| 153 |
+
return None
|
| 154 |
+
|
| 155 |
+
def _gravity(grid, direction):
|
| 156 |
+
r = np.zeros_like(grid)
|
| 157 |
+
h, w = grid.shape
|
| 158 |
+
if direction in ('down', 'up'):
|
| 159 |
+
for c in range(w):
|
| 160 |
+
nz = grid[:, c][grid[:, c] != 0]
|
| 161 |
+
if direction == 'down':
|
| 162 |
+
r[h - len(nz):h, c] = nz
|
| 163 |
+
else:
|
| 164 |
+
r[:len(nz), c] = nz
|
| 165 |
+
else:
|
| 166 |
+
for rr in range(h):
|
| 167 |
+
nz = grid[rr, :][grid[rr, :] != 0]
|
| 168 |
+
if direction == 'right':
|
| 169 |
+
r[rr, w - len(nz):w] = nz
|
| 170 |
+
else:
|
| 171 |
+
r[rr, :len(nz)] = nz
|
| 172 |
+
return r
|
| 173 |
+
|
| 174 |
+
for d in ('down', 'up', 'left', 'right'):
|
| 175 |
+
if all(np.array_equal(_gravity(inp, d), out) for inp, out in exs):
|
| 176 |
+
return None
|
| 177 |
+
return None
|