rogermt's picture
Move own-solver/neurogolf_solver/solvers/tiling.py to own-solver/
1f0695e verified
#!/usr/bin/env python3
"""Tiling, upscaling, mirror, concat, and spatial gather solvers."""
import numpy as np
from onnx import helper
from itertools import product as iproduct
from ..onnx_helpers import mk, _make_int64_init, _build_pad_node
from ..data_loader import get_exs, fixed_shapes
from ..gather_helpers import _build_gather_model, _build_gather_model_with_const
def s_tile(td):
    """Tiling solver.

    Applies when all example inputs share one shape and every output equals
    ``np.tile(inp, (rH, rW))`` for a single repeat pair ``(rH, rW)`` other
    than ``(1, 1)``, with the tiled grid no larger than 30 in either
    dimension.

    Returns the result of ``mk`` for a Slice -> Tile -> pad node chain, or
    ``None`` when the pattern does not apply.
    """
    examples = get_exs(td)

    # Every example must use the same input shape.
    input_shapes = {grid.shape for grid, _ in examples}
    if len(input_shapes) != 1:
        return None
    IH, IW = input_shapes.pop()

    # Collect the (rows, cols) repeat factor implied by each output shape.
    repeat_factors = set()
    for _, target in examples:
        t_h, t_w = target.shape
        if t_h % IH != 0 or t_w % IW != 0:
            return None
        reps = (t_h // IH, t_w // IW)
        if min(reps) < 1 or reps == (1, 1):
            return None
        repeat_factors.add(reps)
    if len(repeat_factors) != 1:
        return None

    rH, rW = repeat_factors.pop()
    OH, OW = IH * rH, IW * rW
    if max(OH, OW) > 30:
        return None

    # The shapes agree; now the contents must be an exact tiling as well.
    if any(not np.array_equal(target, np.tile(grid, (rH, rW)))
           for grid, target in examples):
        return None

    inits = [
        _make_int64_init('st', [0, 0, 0, 0]),
        _make_int64_init('en', [1, 10, IH, IW]),
        _make_int64_init('rp', [1, 1, rH, rW]),
    ]
    # Crop the leading IH x IW window, repeat it, then pad by the remainder
    # up to 30 (presumably the model's fixed canvas size -- confirm against
    # _build_pad_node).
    nodes = [
        helper.make_node('Slice', ['input', 'st', 'en'], ['cr']),
        helper.make_node('Tile', ['cr', 'rp'], ['tl']),
        _build_pad_node('tl', 'output', 30 - OH, 30 - OW, inits),
    ]
    return mk(nodes, inits)
def s_upscale(td):
    """Upscaling solver.

    Applies when all inputs share one shape and every output is the input
    with each cell repeated ``sH`` times vertically and ``sW`` times
    horizontally (both at least 2, identical across examples), with the
    result no larger than 30 in either dimension.

    Returns ``_build_gather_model(OH, OW, idx)`` where ``idx[r, c]`` is the
    input ``[row, col]`` feeding output cell ``(r, c)``, or ``None``.
    """
    examples = get_exs(td)

    input_shapes = {src.shape for src, _ in examples}
    if len(input_shapes) != 1:
        return None
    IH, IW = input_shapes.pop()

    seen_factors = set()
    for _, dst in examples:
        d_h, d_w = dst.shape
        if d_h % IH != 0 or d_w % IW != 0:
            return None
        factor = (d_h // IH, d_w // IW)
        if min(factor) < 2:
            return None
        seen_factors.add(factor)
    if len(seen_factors) != 1:
        return None

    sH, sW = seen_factors.pop()
    OH, OW = IH * sH, IW * sW
    if max(OH, OW) > 30:
        return None

    for src, dst in examples:
        enlarged = np.repeat(src, sH, axis=0)
        enlarged = np.repeat(enlarged, sW, axis=1)
        if not np.array_equal(dst, enlarged):
            return None

    # Output cell (r, c) reads input cell (r // sH, c // sW).
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    idx[:, :, 0] = (np.arange(OH) // sH)[:, np.newaxis]
    idx[:, :, 1] = (np.arange(OW) // sW)[np.newaxis, :]
    return _build_gather_model(OH, OW, idx)
def s_kronecker(td):
    """Kronecker product solver.

    Uses the ``((IH, IW), (OH, OW))`` pair from ``fixed_shapes``. Applies
    when the output shape is an integer multiple (at least 2 per axis) of the
    input shape, fits within 30 x 30, and every output equals
    ``np.kron(inp, ones((sH, sW)))``.

    Returns ``_build_gather_model(OH, OW, idx)`` or ``None``.
    """
    examples = get_exs(td)
    shapes = fixed_shapes(td)
    if shapes is None:
        return None
    (IH, IW), (OH, OW) = shapes

    if OH % IH or OW % IW:
        return None
    sH, sW = OH // IH, OW // IW
    if min(sH, sW) < 2 or max(OH, OW) > 30:
        return None

    # Each input cell expands into an sH x sW block of the same value.
    cell_block = np.ones((sH, sW), dtype=np.int64)
    for inp, out in examples:
        if not np.array_equal(out, np.kron(inp, cell_block)):
            return None

    # Output cell (r, c) reads input cell (r // sH, c // sW).
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    idx[:, :, 0] = (np.arange(OH) // sH)[:, np.newaxis]
    idx[:, :, 1] = (np.arange(OW) // sW)[np.newaxis, :]
    return _build_gather_model(OH, OW, idx)
def s_nonuniform_scale(td):
    """Non-uniform scaling solver.

    Tries a fixed list of (vertical, horizontal) repeat factors whose two
    components differ. The first pair that matches the output shape, keeps
    the output within 30 x 30 and reproduces every example is used.

    Returns ``_build_gather_model(OH, OW, idx)`` or ``None`` if no candidate
    pair fits.
    """
    examples = get_exs(td)
    shapes = fixed_shapes(td)
    if shapes is None:
        return None
    (IH, IW), (OH, OW) = shapes

    candidate_factors = (
        (1, 2), (2, 1), (1, 3), (3, 1), (2, 3),
        (3, 2), (1, 4), (4, 1), (2, 4), (4, 2),
    )
    for fh, fw in candidate_factors:
        shape_matches = OH == IH * fh and OW == IW * fw
        if not shape_matches or OH > 30 or OW > 30:
            continue

        reproduced = True
        for inp, out in examples:
            stretched = np.repeat(np.repeat(inp, fh, axis=0), fw, axis=1)
            if not np.array_equal(stretched, out):
                reproduced = False
                break
        if not reproduced:
            continue

        # Output cell (r, c) reads input cell (r // fh, c // fw).
        idx = np.zeros((OH, OW, 2), dtype=np.int64)
        idx[:, :, 0] = (np.arange(OH) // fh)[:, np.newaxis]
        idx[:, :, 1] = (np.arange(OW) // fw)[np.newaxis, :]
        return _build_gather_model(OH, OW, idx)
    return None
def s_diagonal_tile(td):
    """Diagonal tiling solver.

    Applies when the output is an ``n x n`` arrangement (``n >= 2``) of
    input-sized blocks, no larger than 30 x 30, in which every diagonal
    block equals the input and every off-diagonal block is all zeros.

    Returns ``_build_gather_model_with_const(IH, IW, OH, OW, idx, cst)``.
    Diagonal cells carry their input coordinates in ``idx`` with ``cst`` set
    to -1; off-diagonal cells carry ``[-1, -1]`` in ``idx`` with ``cst`` set
    to 0. Returns ``None`` when the pattern does not apply.
    """
    examples = get_exs(td)
    shapes = fixed_shapes(td)
    if shapes is None:
        return None
    (IH, IW), (OH, OW) = shapes

    if OH % IH or OW % IW:
        return None
    rH, rW = OH // IH, OW // IW
    if rH != rW or rH < 2:
        return None
    if max(OH, OW) > 30:
        return None

    def block_span(bi, bj):
        # Row/column slices covering block (bi, bj) of the output grid.
        return (slice(bi * IH, (bi + 1) * IH), slice(bj * IW, (bj + 1) * IW))

    for inp, out in examples:
        for bi in range(rH):
            for bj in range(rW):
                block = out[block_span(bi, bj)]
                if bi == bj:
                    block_ok = np.array_equal(block, inp)
                else:
                    block_ok = np.all(block == 0)
                if not block_ok:
                    return None

    # Start from the off-diagonal state everywhere, then overwrite the
    # diagonal blocks with "copy the input cell at the same local position".
    idx = np.full((OH, OW, 2), -1, dtype=np.int64)
    cst = np.zeros((OH, OW), dtype=np.int64)
    local_coords = np.stack(np.indices((IH, IW)), axis=-1)
    for b in range(rH):
        span = block_span(b, b)
        idx[span] = local_coords
        cst[span] = -1
    return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
def s_mirror_h(td):
    """Horizontal mirror solver.

    Applies when the output has the input's height and twice its width (at
    most 30 wide) and equals the input followed on the right by its
    left-right flip.

    Returns ``_build_gather_model(OH, OW, idx)`` or ``None``.
    """
    examples = get_exs(td)
    shapes = fixed_shapes(td)
    if shapes is None:
        return None
    (IH, IW), (OH, OW) = shapes

    if not (OH == IH and OW == 2 * IW):
        return None
    if OW > 30:
        return None

    for inp, out in examples:
        mirrored = np.concatenate((inp, np.flip(inp, axis=1)), axis=1)
        if not np.array_equal(mirrored, out):
            return None

    # Columns in the right half read the reflected input column.
    out_cols = np.arange(OW)
    src_cols = np.where(out_cols < IW, out_cols, 2 * IW - 1 - out_cols)
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    idx[:, :, 0] = np.arange(OH)[:, np.newaxis]
    idx[:, :, 1] = src_cols[np.newaxis, :]
    return _build_gather_model(OH, OW, idx)
def s_mirror_v(td):
    """Vertical mirror solver.

    Applies when the output has the input's width and twice its height (at
    most 30 tall) and equals the input followed below by its up-down flip.

    Returns ``_build_gather_model(OH, OW, idx)`` or ``None``.
    """
    examples = get_exs(td)
    shapes = fixed_shapes(td)
    if shapes is None:
        return None
    (IH, IW), (OH, OW) = shapes

    if not (OW == IW and OH == 2 * IH):
        return None
    if OH > 30:
        return None

    for inp, out in examples:
        mirrored = np.concatenate((inp, np.flip(inp, axis=0)), axis=0)
        if not np.array_equal(mirrored, out):
            return None

    # Rows in the bottom half read the reflected input row.
    out_rows = np.arange(OH)
    src_rows = np.where(out_rows < IH, out_rows, 2 * IH - 1 - out_rows)
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    idx[:, :, 0] = src_rows[:, np.newaxis]
    idx[:, :, 1] = np.arange(OW)[np.newaxis, :]
    return _build_gather_model(OH, OW, idx)
def s_quad_mirror(td):
    """Quad mirror solver.

    Applies when the output is twice the input in both dimensions (at most
    30 x 30) and consists of the input in the top-left quadrant, its
    left-right flip top-right, its up-down flip bottom-left, and the doubly
    flipped input bottom-right.

    Returns ``_build_gather_model(OH, OW, idx)`` or ``None``.
    """
    examples = get_exs(td)
    shapes = fixed_shapes(td)
    if shapes is None:
        return None
    (IH, IW), (OH, OW) = shapes

    if OH != 2 * IH or OW != 2 * IW:
        return None
    if max(OH, OW) > 30:
        return None

    for inp, out in examples:
        # Mirror to the right first, then mirror that whole strip downwards.
        top = np.concatenate((inp, np.flip(inp, axis=1)), axis=1)
        expected = np.concatenate((top, np.flip(top, axis=0)), axis=0)
        if not np.array_equal(expected, out):
            return None

    # Reflect both coordinates independently back into the input grid.
    out_rows = np.arange(OH)
    out_cols = np.arange(OW)
    src_rows = np.where(out_rows < IH, out_rows, 2 * IH - 1 - out_rows)
    src_cols = np.where(out_cols < IW, out_cols, 2 * IW - 1 - out_cols)
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    idx[:, :, 0] = src_rows[:, np.newaxis]
    idx[:, :, 1] = src_cols[np.newaxis, :]
    return _build_gather_model(OH, OW, idx)
def s_concat(td):
    """Concatenation solver with transformations.

    Looks for an output made of 2 to 4 input-sized pieces laid side by side
    (same height) or stacked (same width). Each piece is the input under one
    of: identity, left-right flip, up-down flip, 180-degree rotation. Piece
    assignments are tried in ``itertools.product`` order and the first one
    that reproduces every example is used.

    Returns ``_build_gather_model(OH, OW, idx)`` or ``None``.
    """
    examples = get_exs(td)
    shapes = fixed_shapes(td)
    if shapes is None:
        return None
    (IH, IW), (OH, OW) = shapes

    # Each entry pairs a grid transform with the mapping from a local output
    # coordinate (r, c) inside a piece to the input coordinate it reads.
    ops = (
        (lambda g: g, lambda r, c: (r, c)),
        (lambda g: np.fliplr(g), lambda r, c: (r, IW - 1 - c)),
        (lambda g: np.flipud(g), lambda r, c: (IH - 1 - r, c)),
        (lambda g: np.rot90(g, 2), lambda r, c: (IH - 1 - r, IW - 1 - c)),
    )

    def gather_index(choice, axis):
        # Build the (OH, OW, 2) source-coordinate table for one assignment.
        table = np.zeros((OH, OW, 2), dtype=np.int64)
        for oi in range(OH):
            for oj in range(OW):
                if axis == 1:
                    piece, lr, lc = oj // IW, oi, oj % IW
                else:
                    piece, lr, lc = oi // IH, oi % IH, oj
                table[oi, oj] = ops[choice[piece]][1](lr, lc)
        return table

    def reproduces(choice, axis):
        # True when concatenating the chosen transforms matches every example.
        return all(
            np.array_equal(
                out,
                np.concatenate([ops[t][0](inp) for t in choice], axis=axis),
            )
            for inp, out in examples
        )

    # axis=1: pieces side by side; axis=0: pieces stacked vertically.
    for axis in (1, 0):
        if axis == 1:
            applicable = OH == IH and OW % IW == 0 and OW > IW
        else:
            applicable = OW == IW and OH % IH == 0 and OH > IH
        if not applicable:
            continue
        n = OW // IW if axis == 1 else OH // IH
        if not 2 <= n <= 4:
            continue
        for combo in iproduct(range(4), repeat=n):
            if reproduces(combo, axis):
                return _build_gather_model(OH, OW, gather_index(combo, axis))
    return None
def s_concat_enhanced(td):
    """Enhanced concatenation with all 8 dihedral transforms.

    The output is treated as an ``rH x rW`` grid of input-sized blocks
    (between 2 and 16 blocks, output at most 30 x 30). For every block, the
    first transform in the list below that reproduces that block in every
    example is selected. Transforms that swap the axes can only match when
    the input is square, because their result must have shape ``(IH, IW)``.

    The source-index table is derived by applying the selected transform to
    coordinate grids rather than from hand-written formulas, so the table
    always agrees with the numpy transform that was matched. Hand-written
    formulas previously had the ``rot90`` and ``rot270`` cases swapped, which
    made the final verification reject valid rotated tasks.

    Returns ``_build_gather_model(OH, OW, idx)`` where ``idx[r, c]`` holds the
    input ``[row, col]`` copied to output cell ``(r, c)``, or ``None`` when
    the pattern does not apply.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if IH == OH and IW == OW:
        return None
    if OH % IH != 0 or OW % IW != 0:
        return None
    rH, rW = OH // IH, OW // IW
    if rH * rW > 16 or rH * rW < 2:
        return None
    if OH > 30 or OW > 30:
        return None

    # The eight elements of the dihedral group of the square. Order matters:
    # the first matching transform wins for each block.
    transforms = (
        lambda g: g,                 # identity
        np.fliplr,                   # mirror left-right
        np.flipud,                   # mirror up-down
        lambda g: np.rot90(g, 2),    # rotate 180 degrees
        lambda g: np.rot90(g, 1),    # rotate 90 degrees counter-clockwise
        lambda g: np.rot90(g, 3),    # rotate 90 degrees clockwise
        lambda g: g.T,               # transpose (main diagonal)
        lambda g: np.rot90(g, 2).T,  # anti-transpose (anti-diagonal)
    )

    def fits(tfn, rows, cols):
        # True when tfn(input) equals the given output block in every example.
        for inp, out in exs:
            expected = tfn(inp)
            if expected.shape != (IH, IW):
                return False
            if not np.array_equal(out[rows, cols], expected):
                return False
        return True

    # Transforming the coordinate grids exactly like the data tells us, for
    # each cell of the transformed block, which input cell landed there.
    src_rows, src_cols = np.indices((IH, IW))

    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    for bi in range(rH):
        rows = slice(bi * IH, (bi + 1) * IH)
        for bj in range(rW):
            cols = slice(bj * IW, (bj + 1) * IW)
            chosen = next((t for t in transforms if fits(t, rows, cols)), None)
            if chosen is None:
                return None
            idx[rows, cols, 0] = chosen(src_rows)
            idx[rows, cols, 1] = chosen(src_cols)

    # Safety net: the gather table must reproduce every example exactly.
    for inp, out in exs:
        reconstructed = np.asarray(inp)[idx[..., 0], idx[..., 1]]
        if not np.array_equal(reconstructed, out):
            return None
    return _build_gather_model(OH, OW, idx)
def s_spatial_gather(td):
    """Spatial gather solver.

    For every output cell, records a constant when the cell has the same
    value in all examples, and searches the input in row-major order for the
    first cell whose value equals the output cell in every example. Gives up
    (returns ``None``) when a cell has neither a source cell nor a
    non-negative constant.

    Returns ``_build_gather_model_with_const(IH, IW, OH, OW, idx, cst)``,
    where -1 entries in ``idx`` / ``cst`` mean "no source" / "no constant".
    """
    shapes = fixed_shapes(td)
    if shapes is None:
        return None
    (IH, IW), (OH, OW) = shapes
    examples = get_exs(td)

    idx = np.full((OH, OW, 2), -1, dtype=np.int64)
    cst = np.full((OH, OW), -1, dtype=np.int64)

    # Candidate source cells in the order they are tried (row-major).
    input_cells = [(ri, rj) for ri in range(IH) for rj in range(IW)]

    for oi in range(OH):
        for oj in range(OW):
            wanted = [int(out[oi, oj]) for _, out in examples]
            if len(set(wanted)) == 1:
                cst[oi, oj] = wanted[0]

            source = next(
                (
                    (ri, rj)
                    for ri, rj in input_cells
                    if all(int(inp[ri, rj]) == value
                           for (inp, _), value in zip(examples, wanted))
                ),
                None,
            )
            if source is not None:
                idx[oi, oj] = source
            elif cst[oi, oj] < 0:
                return None
    return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
def s_varshape_spatial_gather(td):
    """Variable shape spatial gather solver.

    Runs only when ``fixed_shapes(td)`` is ``None``. Every example input and
    output is copied into the top-left corner of a zero-filled 30 x 30
    canvas. Then, for each canvas cell of the output, a constant is recorded
    if the cell agrees across examples, and the first input canvas cell (in
    row-major order) matching the output cell in every example is recorded
    as its source. Returns ``None`` when a cell has neither a source nor a
    non-negative constant.

    Returns ``_build_gather_model_with_const(30, 30, 30, 30, idx, cst)``.
    """
    if fixed_shapes(td) is not None:
        return None

    side = 30
    padded_inputs, padded_outputs = [], []
    for inp, out in get_exs(td):
        ih, iw = inp.shape
        oh, ow = out.shape
        canvas_in = np.zeros((side, side), dtype=np.int64)
        canvas_out = np.zeros((side, side), dtype=np.int64)
        canvas_in[:ih, :iw] = inp
        canvas_out[:oh, :ow] = out
        padded_inputs.append(canvas_in)
        padded_outputs.append(canvas_out)

    # Stack to (n_examples, 30, 30); the reshape keeps the array 3-D even
    # when there are no examples.
    ins = np.array(padded_inputs, dtype=np.int64).reshape(-1, side, side)
    outs = np.array(padded_outputs, dtype=np.int64).reshape(-1, side, side)

    idx = np.full((side, side, 2), -1, dtype=np.int64)
    cst = np.full((side, side), -1, dtype=np.int64)
    for oi in range(side):
        for oj in range(side):
            wanted = outs[:, oi, oj]
            if wanted.size and np.all(wanted == wanted[0]):
                cst[oi, oj] = wanted[0]

            # Input cells that equal this output cell in every example.
            hits = np.all(ins == wanted[:, np.newaxis, np.newaxis], axis=0).ravel()
            if hits.any():
                # argmax returns the first True, i.e. the row-major first hit.
                idx[oi, oj] = divmod(int(hits.argmax()), side)
            elif cst[oi, oj] < 0:
                return None
    return _build_gather_model_with_const(side, side, side, side, idx, cst)