| |
| """Tiling, upscaling, mirror, concat, and spatial gather solvers.""" |
|
|
| import numpy as np |
| from onnx import helper |
| from itertools import product as iproduct |
| from ..onnx_helpers import mk, _make_int64_init, _build_pad_node |
| from ..data_loader import get_exs, fixed_shapes |
| from ..gather_helpers import _build_gather_model, _build_gather_model_with_const |
|
|
|
|
def s_tile(td):
    """Tiling solver.

    Accepts tasks where every example shares one input shape (IH, IW) and the
    output is ``np.tile(input, (rH, rW))`` for a single factor pair other than
    (1, 1), with the tiled result at most 30x30.

    The ONNX graph slices ``input`` from [0, 0, 0, 0] to [1, 10, IH, IW],
    repeats that region ``[1, 1, rH, rW]`` times with Tile, and then adds a
    pad node for the remaining ``30 - OH`` rows / ``30 - OW`` columns via
    ``_build_pad_node``. Returns the model from ``mk`` or None.
    """
    exs = get_exs(td)
    shapes = {inp.shape for inp, _ in exs}
    if len(shapes) != 1:
        return None
    (IH, IW), = shapes

    # Every example must imply the same integer repeat counts.
    factors = set()
    for _, out in exs:
        OH, OW = out.shape
        if OH % IH or OW % IW:
            return None
        rH, rW = OH // IH, OW // IW
        if min(rH, rW) < 1 or (rH, rW) == (1, 1):
            return None
        factors.add((rH, rW))
    if len(factors) != 1:
        return None
    (rH, rW), = factors

    OH, OW = IH * rH, IW * rW
    if max(OH, OW) > 30:
        return None
    if any(not np.array_equal(out, np.tile(inp, (rH, rW))) for inp, out in exs):
        return None

    inits = [
        _make_int64_init('st', [0, 0, 0, 0]),
        _make_int64_init('en', [1, 10, IH, IW]),
        _make_int64_init('rp', [1, 1, rH, rW]),
    ]
    # _build_pad_node receives `inits` so it can register its own initializers.
    nodes = [
        helper.make_node('Slice', ['input', 'st', 'en'], ['cr']),
        helper.make_node('Tile', ['cr', 'rp'], ['tl']),
        _build_pad_node('tl', 'output', 30 - OH, 30 - OW, inits),
    ]
    return mk(nodes, inits)
|
|
|
|
def s_upscale(td):
    """Upscaling solver.

    Requires every example to share one input shape and one integer scale
    ``(sH, sW)`` with both factors >= 2 (each input cell becomes an
    ``sH x sW`` block), and the scaled output must fit in 30x30. The gather
    index map sends output cell (r, c) to input cell (r // sH, c // sW).
    Returns the model from ``_build_gather_model`` or None.
    """
    exs = get_exs(td)
    input_shapes = {inp.shape for inp, _ in exs}
    if len(input_shapes) != 1:
        return None
    (IH, IW), = input_shapes

    scale_pairs = set()
    for _, out in exs:
        OH, OW = out.shape
        if OH % IH or OW % IW:
            return None
        pair = (OH // IH, OW // IW)
        if min(pair) < 2:
            return None
        scale_pairs.add(pair)
    if len(scale_pairs) != 1:
        return None
    (sH, sW), = scale_pairs

    OH, OW = IH * sH, IW * sW
    if OH > 30 or OW > 30:
        return None
    for inp, out in exs:
        expanded = np.repeat(np.repeat(inp, sH, axis=0), sW, axis=1)
        if not np.array_equal(out, expanded):
            return None

    src_r, src_c = np.meshgrid(np.arange(OH) // sH, np.arange(OW) // sW, indexing='ij')
    idx = np.stack([src_r, src_c], axis=-1).astype(np.int64)
    return _build_gather_model(OH, OW, idx)
|
|
|
|
def s_kronecker(td):
    """Kronecker product solver.

    For fixed input/output shapes, accepts tasks where every output equals
    ``np.kron(input, ones((sH, sW)))`` with both factors at least 2, so each
    input cell becomes an ``sH x sW`` block of its value. The output must be
    at most 30x30. Returns the model from ``_build_gather_model`` or None.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp

    sH, rem_h = divmod(OH, IH)
    sW, rem_w = divmod(OW, IW)
    if rem_h or rem_w or min(sH, sW) < 2 or max(OH, OW) > 30:
        return None

    block = np.ones((sH, sW), dtype=np.int64)
    if any(not np.array_equal(out, np.kron(inp, block)) for inp, out in exs):
        return None

    # Row index depends only on r, column index only on c: fill by broadcasting.
    idx = np.empty((OH, OW, 2), dtype=np.int64)
    idx[..., 0] = (np.arange(OH) // sH)[:, None]
    idx[..., 1] = (np.arange(OW) // sW)[None, :]
    return _build_gather_model(OH, OW, idx)
|
|
|
|
def s_nonuniform_scale(td):
    """Non-uniform scaling solver.

    For fixed input/output shapes, accepts per-axis integer upscaling where
    the (row, column) factor pair is one of a fixed whitelist. Each input
    cell becomes an ``fh x fw`` block, and the output must be at most 30x30.
    Returns the model from ``_build_gather_model`` or None.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp

    candidates = ((1, 2), (2, 1), (1, 3), (3, 1), (2, 3),
                  (3, 2), (1, 4), (4, 1), (2, 4), (4, 2))
    # For non-empty grids the shapes pin down the factors, so at most one
    # candidate can satisfy both equalities.
    factors = next(((fh, fw) for fh, fw in candidates
                    if OH == IH * fh and OW == IW * fw), None)
    if factors is None or OH > 30 or OW > 30:
        return None
    fh, fw = factors

    for inp, out in exs:
        if not np.array_equal(np.repeat(np.repeat(inp, fh, 0), fw, 1), out):
            return None

    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    idx[..., 0] = (np.arange(OH) // fh)[:, None]
    idx[..., 1] = (np.arange(OW) // fw)[None, :]
    return _build_gather_model(OH, OW, idx)
|
|
|
|
def s_diagonal_tile(td):
    """Diagonal tiling solver.

    For fixed shapes, accepts outputs that form an ``r x r`` grid (r >= 2) of
    input-sized blocks. The input is copied onto the main-diagonal blocks,
    every other block is all zeros, and the output is at most 30x30.

    Diagonal cells get a gather index into the input and constant -1.
    Off-diagonal cells get gather index [-1, -1] and constant 0. Presumably
    this means "emit 0" in ``_build_gather_model_with_const`` -- confirm there.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if OH % IH or OW % IW:
        return None
    reps_h, reps_w = OH // IH, OW // IW
    if reps_h != reps_w or reps_h < 2 or OH > 30 or OW > 30:
        return None

    for inp, out in exs:
        for bi, bj in iproduct(range(reps_h), range(reps_w)):
            tile = out[bi * IH:(bi + 1) * IH, bj * IW:(bj + 1) * IW]
            ok = np.array_equal(tile, inp) if bi == bj else not np.any(tile != 0)
            if not ok:
                return None

    rows, cols = np.indices((OH, OW))
    on_diag = (rows // IH) == (cols // IW)
    idx = np.empty((OH, OW, 2), dtype=np.int64)
    idx[..., 0] = np.where(on_diag, rows % IH, -1)
    idx[..., 1] = np.where(on_diag, cols % IW, -1)
    cst = np.where(on_diag, -1, 0).astype(np.int64)
    return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
|
|
|
|
def s_mirror_h(td):
    """Horizontal mirror solver.

    For fixed shapes, matches outputs that are the input followed by its
    left-right mirror image: same height, double width, width at most 30.
    Returns the model from ``_build_gather_model`` or None.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if (OH, OW) != (IH, 2 * IW) or OW > 30:
        return None

    for inp, out in exs:
        if not np.array_equal(np.hstack([inp, inp[:, ::-1]]), out):
            return None

    # Source columns run 0..IW-1 and then back down IW-1..0.
    ramp = np.arange(IW)
    src_cols = np.concatenate([ramp, ramp[::-1]])
    row_grid, col_grid = np.meshgrid(np.arange(OH), src_cols, indexing='ij')
    idx = np.stack([row_grid, col_grid], axis=-1).astype(np.int64)
    return _build_gather_model(OH, OW, idx)
|
|
|
|
def s_mirror_v(td):
    """Vertical mirror solver.

    For fixed shapes, matches outputs that are the input stacked above its
    up-down mirror image: same width, double height, height at most 30.
    Returns the model from ``_build_gather_model`` or None.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if OW != IW or OH != 2 * IH or OH > 30:
        return None

    for inp, out in exs:
        if not np.array_equal(np.vstack([inp, inp[::-1]]), out):
            return None

    # Source rows run 0..IH-1 and then back down IH-1..0.
    down = np.arange(IH)
    src_rows = np.concatenate([down, down[::-1]])
    idx = np.empty((OH, OW, 2), dtype=np.int64)
    idx[..., 0] = src_rows[:, None]
    idx[..., 1] = np.arange(OW)[None, :]
    return _build_gather_model(OH, OW, idx)
|
|
|
|
def s_quad_mirror(td):
    """Quad mirror solver.

    For fixed shapes, matches outputs made of four quadrants: the input,
    its left-right mirror, its up-down mirror, and its 180-degree rotation.
    The output is double size along both axes and at most 30x30. Returns the
    model from ``_build_gather_model`` or None.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if (OH, OW) != (2 * IH, 2 * IW) or OH > 30 or OW > 30:
        return None

    for inp, out in exs:
        top = np.hstack([inp, inp[:, ::-1]])
        # The bottom half is the top half read upside down.
        if not np.array_equal(np.vstack([top, top[::-1]]), out):
            return None

    def mirrored_ramp(n):
        """Return [0, 1, ..., n-1, n-1, ..., 1, 0]."""
        ramp = np.arange(n)
        return np.concatenate([ramp, ramp[::-1]])

    row_grid, col_grid = np.meshgrid(mirrored_ramp(IH), mirrored_ramp(IW), indexing='ij')
    idx = np.stack([row_grid, col_grid], axis=-1).astype(np.int64)
    return _build_gather_model(OH, OW, idx)
|
|
|
|
def s_concat(td):
    """Concatenation solver with transformations.

    For fixed shapes, matches outputs formed by concatenating 2-4 input-sized
    copies along one axis. Copies sit side by side when the heights agree and
    are stacked when the widths agree. Each copy is independently the input,
    its left-right flip, its up-down flip, or its 180-degree rotation.
    Combinations are tried in ``itertools.product`` order, and the first one
    consistent with every example is compiled into a gather model. Returns the
    model from ``_build_gather_model`` or None.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp

    # (rows reversed?, columns reversed?, numpy transform); list order fixes
    # the search order of the combinations.
    variants = [
        (False, False, lambda x: x),
        (False, True, lambda x: np.fliplr(x)),
        (True, False, lambda x: np.flipud(x)),
        (True, True, lambda x: np.rot90(x, 2)),
    ]

    if OH == IH and OW % IW == 0 and OW > IW:
        axis, copies = 1, OW // IW
    elif OW == IW and OH % IH == 0 and OH > IH:
        axis, copies = 0, OH // IH
    else:
        return None
    if not 2 <= copies <= 4:
        return None

    for combo in iproduct(range(len(variants)), repeat=copies):
        if all(
            np.array_equal(out, np.concatenate([variants[k][2](inp) for k in combo], axis=axis))
            for inp, out in exs
        ):
            break
    else:
        return None

    rows, cols = np.meshgrid(np.arange(OH), np.arange(OW), indexing='ij')
    block = cols // IW if axis == 1 else rows // IH
    local_r, local_c = rows % IH, cols % IW
    flip_r = np.array([variants[k][0] for k in combo])[block]
    flip_c = np.array([variants[k][1] for k in combo])[block]
    idx = np.stack([
        np.where(flip_r, IH - 1 - local_r, local_r),
        np.where(flip_c, IW - 1 - local_c, local_c),
    ], axis=-1).astype(np.int64)
    return _build_gather_model(OH, OW, idx)
|
|
|
|
def s_concat_enhanced(td):
    """Enhanced concatenation with all 8 dihedral transforms.

    For fixed shapes, the output is treated as an ``rH x rW`` grid of
    input-sized blocks: 2 to 16 blocks, output at most 30x30, and not the same
    shape as the input. Each block must equal the same dihedral transform of
    the input in every example. The first matching transform in
    ``transforms`` order wins. Transforms that swap the axes (90/270-degree
    rotations and both transposes) can only match square inputs.

    Returns the model from ``_build_gather_model``, or None when the task
    does not fit this pattern.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if IH == OH and IW == OW:
        return None
    if OH % IH != 0 or OW % IW != 0:
        return None
    rH, rW = OH // IH, OW // IW
    if rH * rW > 16 or rH * rW < 2:
        return None
    if OH > 30 or OW > 30:
        return None

    # The eight symmetries of the square. np.fliplr(x.T) equals
    # np.rot90(x, 3), so the anti-diagonal transpose is spelled out as
    # np.rot90(x, 2).T to keep all eight distinct.
    transforms = [
        ('id', lambda x: x), ('fliplr', lambda x: np.fliplr(x)),
        ('flipud', lambda x: np.flipud(x)), ('rot180', lambda x: np.rot90(x, 2)),
        ('rot90', lambda x: np.rot90(x, 1)), ('rot270', lambda x: np.rot90(x, 3)),
        ('T', lambda x: x.T), ('antiT', lambda x: np.rot90(x, 2).T),
    ]

    def block_matches(tfn, rs, cs):
        """Return True if output block (rs, cs) equals tfn(input) in every example."""
        for inp, out in exs:
            expected = tfn(inp)
            if expected.shape != (IH, IW) or not np.array_equal(out[rs, cs], expected):
                return False
        return True

    # Each transform only permutes cells. Applying it to the grids of row and
    # column numbers therefore gives, for every output cell, the input cell it
    # was copied from. The gather indices then agree with the matched
    # transform by construction, with no hand-derived formulas.
    src_rows, src_cols = np.indices((IH, IW))
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    for bi in range(rH):
        for bj in range(rW):
            rs = slice(bi * IH, (bi + 1) * IH)
            cs = slice(bj * IW, (bj + 1) * IW)
            tfn = next((fn for _, fn in transforms if block_matches(fn, rs, cs)), None)
            if tfn is None:
                return None
            idx[rs, cs, 0] = tfn(src_rows)
            idx[rs, cs, 1] = tfn(src_cols)

    # Safety net: the gather map must reproduce every training output exactly.
    for inp, out in exs:
        if not np.array_equal(inp[idx[..., 0], idx[..., 1]], out):
            return None
    return _build_gather_model(OH, OW, idx)
|
|
|
|
def s_spatial_gather(td):
    """Spatial gather solver.

    For fixed input/output shapes, learns a rule for each output cell:
      * if the cell holds the same value in every training output, that value
        is stored in ``cst`` (otherwise ``cst`` keeps -1);
      * the first input cell in row-major order whose value equals the output
        cell's value in every example is stored in ``idx`` (otherwise -1).
    Returns None if an output cell has neither a constant nor a consistent
    source. How ``_build_gather_model_with_const`` resolves cells that have
    both is not visible here.
    """
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    exs = get_exs(td)
    idx = np.full((OH, OW, 2), -1, dtype=np.int64)
    cst = np.full((OH, OW), -1, dtype=np.int64)
    in_cells = list(iproduct(range(IH), range(IW)))

    for oi, oj in iproduct(range(OH), range(OW)):
        targets = [int(out[oi, oj]) for _, out in exs]
        distinct = set(targets)
        if len(distinct) == 1:
            cst[oi, oj] = distinct.pop()
        src = next(
            ((ri, rj) for ri, rj in in_cells
             if all(int(inp[ri, rj]) == t for (inp, _), t in zip(exs, targets))),
            None,
        )
        if src is not None:
            idx[oi, oj] = src
        elif cst[oi, oj] < 0:
            return None
    return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
|
|
|
|
def s_varshape_spatial_gather(td):
    """Variable shape spatial gather solver.

    Only runs when ``fixed_shapes`` returns a value (presumably None when the
    examples differ in shape). Every input and output is zero-padded into the
    top-left of a 30x30 canvas. Then the same per-cell rule as the
    fixed-shape spatial gather is learned on the canvases: a constant when the
    cell agrees across outputs, and the first row-major input cell that
    matches in every example. Returns None if some cell has neither.

    NOTE(review): padding uses 0, which is also a real colour value. An output
    cell beyond a grid's true size can therefore be matched to an in-bounds
    input cell holding 0, or receive constant 0. Whether that is harmless
    depends on how ``_build_gather_model_with_const`` and its caller treat
    cells outside the true output -- confirm.
    """
    if fixed_shapes(td) is not None:
        return None
    size = 30

    padded = []
    for inp, out in get_exs(td):
        ih, iw = inp.shape
        oh, ow = out.shape
        canvas_in = np.zeros((size, size), dtype=np.int64)
        canvas_out = np.zeros((size, size), dtype=np.int64)
        canvas_in[:ih, :iw] = inp
        canvas_out[:oh, :ow] = out
        padded.append((canvas_in, canvas_out))

    idx = np.full((size, size, 2), -1, dtype=np.int64)
    cst = np.full((size, size), -1, dtype=np.int64)
    cells = list(iproduct(range(size), range(size)))

    for oi, oj in cells:
        targets = [int(o[oi, oj]) for _, o in padded]
        distinct = set(targets)
        if len(distinct) == 1:
            cst[oi, oj] = distinct.pop()
        src = next(
            ((ri, rj) for ri, rj in cells
             if all(int(i[ri, rj]) == t for (i, _), t in zip(padded, targets))),
            None,
        )
        if src is not None:
            idx[oi, oj] = src
        elif cst[oi, oj] < 0:
            return None
    return _build_gather_model_with_const(size, size, size, size, idx, cst)
|
|