#!/usr/bin/env python3
"""Tiling, upscaling, mirror, concat, and spatial gather solvers.

Every public ``s_*`` function takes a task description ``td`` and returns
either the object produced by the project's model builders (``mk`` or
``_build_gather_model*``) or ``None`` when the task's examples do not follow
the pattern that solver recognises.

NOTE(review): ``get_exs(td)`` is used here as an iterable of ``(inp, out)``
pairs of 2-D arrays, and ``fixed_shapes(td)`` as returning either ``None`` or
``((IH, IW), (OH, OW))``; both are defined elsewhere -- confirm against
``data_loader``.
"""

import numpy as np
from onnx import helper
from itertools import product as iproduct

from ..onnx_helpers import mk, _make_int64_init, _build_pad_node
from ..data_loader import get_exs, fixed_shapes
from ..gather_helpers import _build_gather_model, _build_gather_model_with_const


# (name, array transform, cell map) triples for the 8 symmetries of a grid.
# The cell map answers: "output cell (r, c) of the transformed H x W grid is
# read from which input cell?".  Keeping the array transform and its cell map
# side by side prevents the two from drifting apart.  The first four entries
# preserve the grid shape; the last four swap the axes and therefore only
# match square grids.
_DIHEDRAL = [
    ('id', lambda x: x, lambda r, c, H, W: (r, c)),
    ('fliplr', np.fliplr, lambda r, c, H, W: (r, W - 1 - c)),
    ('flipud', np.flipud, lambda r, c, H, W: (H - 1 - r, c)),
    ('rot180', lambda x: np.rot90(x, 2),
     lambda r, c, H, W: (H - 1 - r, W - 1 - c)),
    # np.rot90(x, 1) is counter-clockwise: out[r, c] == x[c, W - 1 - r].
    ('rot90', lambda x: np.rot90(x, 1), lambda r, c, H, W: (c, W - 1 - r)),
    # np.rot90(x, 3) is clockwise: out[r, c] == x[H - 1 - c, r].
    ('rot270', lambda x: np.rot90(x, 3), lambda r, c, H, W: (H - 1 - c, r)),
    ('T', lambda x: x.T, lambda r, c, H, W: (c, r)),
    # Anti-transpose (reflection across the anti-diagonal).  fliplr(x.T) is
    # the same array as rot270, so it is not a distinct symmetry; the
    # anti-transpose is the actual eighth one.
    ('antiT', lambda x: np.rot90(x, 2).T,
     lambda r, c, H, W: (H - 1 - c, W - 1 - r)),
]


def _scale_indices(OH, OW, fh, fw):
    """Return an (OH, OW, 2) int64 array with entry [r, c] == [r // fh, c // fw]."""
    rows, cols = np.indices((OH, OW), dtype=np.int64)
    return np.stack([rows // fh, cols // fw], axis=-1)


def _mirror_index(i, n):
    """Map coordinate ``i`` of a length-``2n`` mirrored axis to its source in ``[0, n)``."""
    return i if i < n else 2 * n - 1 - i


def _concat_indices(exs, transforms, IH, IW, OH, OW, axis):
    """Find a 2..4-way concatenation of transformed inputs along ``axis``.

    Tries every assignment of ``transforms`` to the blocks (in
    ``itertools.product`` order) and returns the (OH, OW, 2) gather index
    array of the first assignment that reproduces every example output, or
    ``None`` when no assignment matches or the block count is outside 2..4.
    """
    n = OW // IW if axis == 1 else OH // IH
    if not 2 <= n <= 4:
        return None
    for combo in iproduct(range(len(transforms)), repeat=n):
        if not all(
            np.array_equal(
                out,
                np.concatenate([transforms[t][1](inp) for t in combo], axis=axis),
            )
            for inp, out in exs
        ):
            continue
        idx = np.zeros((OH, OW, 2), dtype=np.int64)
        for oi in range(OH):
            for oj in range(OW):
                if axis == 1:
                    blk, lr, lc = oj // IW, oi, oj % IW
                else:
                    blk, lr, lc = oi // IH, oi % IH, oj
                idx[oi, oj] = transforms[combo[blk]][2](lr, lc, IH, IW)
        return idx
    return None


def _find_gather(exs, IH, IW, OH, OW):
    """Search, per output cell, for a constant value and/or a source input cell.

    For each output cell, ``cst`` receives the value shared by all examples
    (or stays -1), and ``idx`` receives the first input cell in row-major
    order whose value equals that output cell in every example (or stays
    [-1, -1]).  Returns ``(idx, cst)``, or ``None`` as soon as a cell has
    neither a source cell nor a non-negative constant.
    """
    idx = np.full((OH, OW, 2), -1, dtype=np.int64)
    cst = np.full((OH, OW), -1, dtype=np.int64)
    for oi in range(OH):
        for oj in range(OW):
            vals = set(int(out[oi, oj]) for _, out in exs)
            if len(vals) == 1:
                cst[oi, oj] = vals.pop()
            src = next(
                (
                    (ri, rj)
                    for ri in range(IH)
                    for rj in range(IW)
                    if all(int(inp[ri, rj]) == int(out[oi, oj])
                           for inp, out in exs)
                ),
                None,
            )
            if src is not None:
                idx[oi, oj] = src
            elif cst[oi, oj] < 0:
                return None
    return idx, cst


def s_tile(td):
    """Tiling solver.

    Matches tasks where every input has the same shape and every output is
    ``np.tile(inp, (rH, rW))`` for one common ``(rH, rW) != (1, 1)``, with the
    tiled result no larger than 30 x 30.

    Returns the result of ``mk`` for a Slice -> Tile -> pad graph, or ``None``
    when the pattern does not hold.
    """
    exs = get_exs(td)
    in_shapes = set(inp.shape for inp, _ in exs)
    if len(in_shapes) != 1:
        return None
    IH, IW = in_shapes.pop()

    tiles = set()
    for inp, out in exs:
        OH, OW = out.shape
        if OH % IH or OW % IW:
            return None
        rH, rW = OH // IH, OW // IW
        if rH < 1 or rW < 1 or (rH == 1 and rW == 1):
            return None
        tiles.add((rH, rW))
    if len(tiles) != 1:
        return None
    rH, rW = tiles.pop()

    OH, OW = IH * rH, IW * rW
    if OH > 30 or OW > 30:
        return None
    for inp, out in exs:
        if not np.array_equal(out, np.tile(inp, (rH, rW))):
            return None

    # The tiled result is padded back up to 30 x 30.
    pad_h, pad_w = 30 - OH, 30 - OW
    # NOTE(review): the slice end [1, 10, IH, IW] suggests 'input' is laid out
    # as (batch, 10 channels, rows, cols) -- confirm against the model spec.
    inits = [
        _make_int64_init('st', [0, 0, 0, 0]),
        _make_int64_init('en', [1, 10, IH, IW]),
        _make_int64_init('rp', [1, 1, rH, rW]),
    ]
    nodes = [
        helper.make_node('Slice', ['input', 'st', 'en'], ['cr']),
        helper.make_node('Tile', ['cr', 'rp'], ['tl']),
    ]
    nodes.append(_build_pad_node('tl', 'output', pad_h, pad_w, inits))
    return mk(nodes, inits)


def s_upscale(td):
    """Upscaling solver.

    Matches tasks where all inputs share one shape and every output repeats
    each input cell ``sH`` times vertically and ``sW`` times horizontally
    (both >= 2, identical across examples, result at most 30 x 30).

    Returns the result of ``_build_gather_model`` or ``None``.
    """
    exs = get_exs(td)
    in_shapes = set(inp.shape for inp, _ in exs)
    if len(in_shapes) != 1:
        return None
    IH, IW = in_shapes.pop()

    scales = set()
    for inp, out in exs:
        OH, OW = out.shape
        if OH % IH or OW % IW:
            return None
        sH, sW = OH // IH, OW // IW
        if sH < 2 or sW < 2:
            return None
        scales.add((sH, sW))
    if len(scales) != 1:
        return None
    sH, sW = scales.pop()

    OH, OW = IH * sH, IW * sW
    if OH > 30 or OW > 30:
        return None
    for inp, out in exs:
        if not np.array_equal(out, np.repeat(np.repeat(inp, sH, 0), sW, 1)):
            return None
    return _build_gather_model(OH, OW, _scale_indices(OH, OW, sH, sW))


def s_kronecker(td):
    """Kronecker product solver.

    Matches fixed-shape tasks whose output equals
    ``np.kron(inp, ones((sH, sW)))`` with ``sH, sW >= 2`` and an output of at
    most 30 x 30.

    Returns the result of ``_build_gather_model`` or ``None``.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if OH % IH != 0 or OW % IW != 0:
        return None
    sH, sW = OH // IH, OW // IW
    if sH < 2 or sW < 2:
        return None
    if OH > 30 or OW > 30:
        return None
    ones = np.ones((sH, sW), dtype=np.int64)
    for inp, out in exs:
        if not np.array_equal(out, np.kron(inp, ones)):
            return None
    return _build_gather_model(OH, OW, _scale_indices(OH, OW, sH, sW))


def s_nonuniform_scale(td):
    """Non-uniform scaling solver.

    Tries a fixed list of (row factor, column factor) pairs in which the two
    factors differ, and accepts the first pair for which every output is the
    input with rows repeated ``fh`` times and columns repeated ``fw`` times.

    Returns the result of ``_build_gather_model`` or ``None``.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    factors = [(1, 2), (2, 1), (1, 3), (3, 1), (2, 3),
               (3, 2), (1, 4), (4, 1), (2, 4), (4, 2)]
    for fh, fw in factors:
        if OH != IH * fh or OW != IW * fw:
            continue
        if OH > 30 or OW > 30:
            continue
        if all(np.array_equal(np.repeat(np.repeat(inp, fh, 0), fw, 1), out)
               for inp, out in exs):
            return _build_gather_model(OH, OW, _scale_indices(OH, OW, fh, fw))
    return None


def s_diagonal_tile(td):
    """Diagonal tiling solver.

    Matches fixed-shape tasks whose output is an ``r x r`` grid of blocks
    (``r >= 2``) where every diagonal block equals the input and every
    off-diagonal block is all zeros.

    Returns the result of ``_build_gather_model_with_const`` or ``None``.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if OH % IH != 0 or OW % IW != 0:
        return None
    rH, rW = OH // IH, OW // IW
    if rH != rW or rH < 2:
        return None
    if OH > 30 or OW > 30:
        return None

    for inp, out in exs:
        for bi in range(rH):
            for bj in range(rW):
                block = out[bi * IH:(bi + 1) * IH, bj * IW:(bj + 1) * IW]
                if bi == bj:
                    if not np.array_equal(block, inp):
                        return None
                elif not np.all(block == 0):
                    return None

    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    cst = np.full((OH, OW), -1, dtype=np.int64)
    for bi in range(rH):
        for bj in range(rW):
            for lr in range(IH):
                for lc in range(IW):
                    oi, oj = bi * IH + lr, bj * IW + lc
                    if bi == bj:
                        idx[oi, oj] = [lr, lc]
                    else:
                        # Off-diagonal cells have no source; they carry the
                        # constant 0 instead.
                        idx[oi, oj] = [-1, -1]
                        cst[oi, oj] = 0
    return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)


def s_mirror_h(td):
    """Horizontal mirror solver.

    Matches fixed-shape tasks whose output is the input followed, to its
    right, by its left-right flip (output width at most 30).

    Returns the result of ``_build_gather_model`` or ``None``.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if OH != IH or OW != 2 * IW:
        return None
    if OW > 30:
        return None
    for inp, out in exs:
        if not np.array_equal(np.concatenate([inp, np.flip(inp, 1)], 1), out):
            return None
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    for r in range(OH):
        for c in range(OW):
            idx[r, c] = [r, _mirror_index(c, IW)]
    return _build_gather_model(OH, OW, idx)


def s_mirror_v(td):
    """Vertical mirror solver.

    Matches fixed-shape tasks whose output is the input followed, below it,
    by its up-down flip (output height at most 30).

    Returns the result of ``_build_gather_model`` or ``None``.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if OW != IW or OH != 2 * IH:
        return None
    if OH > 30:
        return None
    for inp, out in exs:
        if not np.array_equal(np.concatenate([inp, np.flip(inp, 0)], 0), out):
            return None
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    for r in range(OH):
        for c in range(OW):
            idx[r, c] = [_mirror_index(r, IH), c]
    return _build_gather_model(OH, OW, idx)


def s_quad_mirror(td):
    """Quad mirror solver.

    Matches fixed-shape tasks whose output is the 2 x 2 block arrangement
    ``[[inp, fliplr], [flipud, rot180]]`` (at most 30 x 30).

    Returns the result of ``_build_gather_model`` or ``None``.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if OH != 2 * IH or OW != 2 * IW:
        return None
    if OH > 30 or OW > 30:
        return None
    for inp, out in exs:
        expected = np.block([
            [inp, np.flip(inp, 1)],
            [np.flip(inp, 0), np.flip(np.flip(inp, 0), 1)],
        ])
        if not np.array_equal(expected, out):
            return None
    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    for r in range(OH):
        for c in range(OW):
            idx[r, c] = [_mirror_index(r, IH), _mirror_index(c, IW)]
    return _build_gather_model(OH, OW, idx)


def s_concat(td):
    """Concatenation solver with transformations.

    Matches fixed-shape tasks whose output is 2 to 4 copies of the input laid
    side by side (or stacked vertically), each copy independently being the
    identity, a left-right flip, an up-down flip or a 180-degree rotation.

    Returns the result of ``_build_gather_model`` or ``None``.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    # Only the four shape-preserving symmetries are considered here.
    transforms = _DIHEDRAL[:4]

    if OH == IH and OW % IW == 0 and OW > IW:
        idx = _concat_indices(exs, transforms, IH, IW, OH, OW, axis=1)
        if idx is not None:
            return _build_gather_model(OH, OW, idx)
    if OW == IW and OH % IH == 0 and OH > IH:
        idx = _concat_indices(exs, transforms, IH, IW, OH, OW, axis=0)
        if idx is not None:
            return _build_gather_model(OH, OW, idx)
    return None


def s_concat_enhanced(td):
    """Enhanced concatenation with all 8 dihedral transforms.

    Matches fixed-shape tasks whose output is an ``rH x rW`` grid of blocks
    (2 to 16 blocks, at most 30 x 30), each block being one of the eight
    symmetries of the input.  The axis-swapping symmetries can only match
    when the input is square, because the transformed block must have the
    input's shape.

    For every block the first matching transform in ``_DIHEDRAL`` order is
    used.  The resulting index map is re-checked against every example
    before a model is built.

    Returns the result of ``_build_gather_model`` or ``None``.
    """
    exs = get_exs(td)
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if IH == OH and IW == OW:
        return None
    if OH % IH != 0 or OW % IW != 0:
        return None
    rH, rW = OH // IH, OW // IW
    if rH * rW > 16 or rH * rW < 2:
        return None
    if OH > 30 or OW > 30:
        return None

    # Pick, per block, the cell map of the first transform that reproduces
    # that block in every example.
    block_maps = {}
    for bi in range(rH):
        for bj in range(rW):
            chosen = None
            for _tname, tfn, cell_map in _DIHEDRAL:
                ok = True
                for inp, out in exs:
                    block = out[bi * IH:(bi + 1) * IH, bj * IW:(bj + 1) * IW]
                    expected = tfn(inp)
                    if (expected.shape != (IH, IW)
                            or not np.array_equal(block, expected)):
                        ok = False
                        break
                if ok:
                    chosen = cell_map
                    break
            if chosen is None:
                return None
            block_maps[(bi, bj)] = chosen

    idx = np.zeros((OH, OW, 2), dtype=np.int64)
    for (bi, bj), cell_map in block_maps.items():
        for lr in range(IH):
            for lc in range(IW):
                idx[bi * IH + lr, bj * IW + lc] = cell_map(lr, lc, IH, IW)

    # Safety net: the gather described by idx must reproduce every output.
    for inp, out in exs:
        reconstructed = np.zeros_like(out)
        reconstructed[...] = inp[idx[..., 0], idx[..., 1]]
        if not np.array_equal(reconstructed, out):
            return None
    return _build_gather_model(OH, OW, idx)


def s_spatial_gather(td):
    """Spatial gather solver.

    For fixed-shape tasks, explains every output cell either as a copy of one
    fixed input cell (same cell in every example) or as a value that is
    constant across examples.

    Returns the result of ``_build_gather_model_with_const`` or ``None`` when
    some output cell can be explained by neither.
    """
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    exs = get_exs(td)
    found = _find_gather(exs, IH, IW, OH, OW)
    if found is None:
        return None
    idx, cst = found
    return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)


def s_varshape_spatial_gather(td):
    """Variable shape spatial gather solver.

    Only applies when ``fixed_shapes(td)`` is ``None``.  Every input and
    output is embedded in the top-left corner of a zero-filled 30 x 30 grid,
    and the same per-cell search as ``s_spatial_gather`` is run on the padded
    grids.

    Returns the result of ``_build_gather_model_with_const`` or ``None``.
    """
    sp = fixed_shapes(td)
    if sp is not None:
        return None
    exs = get_exs(td)

    exs_30 = []
    for inp, out in exs:
        ih, iw = inp.shape
        oh, ow = out.shape
        inp30 = np.zeros((30, 30), dtype=np.int64)
        out30 = np.zeros((30, 30), dtype=np.int64)
        inp30[:ih, :iw] = inp
        out30[:oh, :ow] = out
        exs_30.append((inp30, out30))

    found = _find_gather(exs_30, 30, 30, 30, 30)
    if found is None:
        return None
    idx, cst = found
    return _build_gather_model_with_const(30, 30, 30, 30, idx, cst)