#!/usr/bin/env python3
"""Basic analytical solvers: identity, constant, color_map, transpose.

Each solver takes a task dict ``td`` and returns either a model built by
``mk`` (whose graph reads a tensor named ``'input'`` and writes one named
``'output'``) or ``None`` when the task does not fit that solver.

NOTE(review): the graphs below assume the input is a one-hot float tensor
laid out as (1, 10 colors, 30, 30) -- this matches the constant built in
``s_constant`` and the axis choices used here, but confirm against the
encoder that produces ``'input'``.
"""
import numpy as np
from onnx import helper, numpy_helper, TensorProto

from ..onnx_helpers import mk, _make_int64_init
from ..data_loader import get_exs, fixed_shapes


def s_identity(td):
    """Identity solver.

    Matches when every train and test example's output equals its input.

    Args:
        td: Task dict whose ``'train'`` and ``'test'`` lists hold examples
            with ``'input'`` and ``'output'`` grids.

    Returns:
        A model built by ``mk`` with a single ``Identity`` node
        (``input`` -> ``output``), or ``None`` if any example differs.
    """
    for ex in td['train'] + td['test']:
        # np.array_equal works for nested lists and ndarrays alike (a bare
        # `!=` on ndarrays would raise), and returns False on shape mismatch.
        if not np.array_equal(ex['input'], ex['output']):
            return None
    return mk([helper.make_node('Identity', ['input'], ['output'])])


def s_color_map(td):
    """Color mapping solver.

    Matches when every example's output has the same shape as its input and
    is obtained by sending each input color through one fixed
    color -> color mapping. Colors never seen in the examples are mapped to
    themselves.

    If the learned mapping is a permutation of the colors it touches, the
    model is a single ``Gather`` along axis 1 (presumably the color-channel
    axis); otherwise it is a 1x1 ``Conv`` whose 10x10 weight matrix routes
    each input channel to its target channel.

    Args:
        td: Task dict with ``'train'``/``'test'`` example lists.

    Returns:
        A model built by ``mk``, or ``None`` if an input/output shape
        differs, one input color maps to two different output colors, or
        any color lies outside 0..9.
    """
    cm = {}
    for ex in td['train'] + td['test']:
        inp, out = np.array(ex['input']), np.array(ex['output'])
        if inp.shape != out.shape:
            return None
        for iv, ov in zip(inp.flat, out.flat):
            iv, ov = int(iv), int(ov)
            # Both models index the 10 color channels directly; anything
            # else cannot be represented (and a negative index would
            # silently wrap around in numpy).
            if not (0 <= iv < 10 and 0 <= ov < 10):
                return None
            if cm.setdefault(iv, ov) != ov:
                return None  # the same input color maps to two outputs

    if set(cm) == set(cm.values()):
        # Bijection on the touched colors: output channel `dst` copies
        # input channel `src` whenever cm[src] == dst; others stay put.
        gather_ch = np.arange(10, dtype=np.int32)
        for src, dst in cm.items():
            gather_ch[dst] = src
        inits = [numpy_helper.from_array(gather_ch, 'gi')]
        nodes = [helper.make_node('Gather', ['input', 'gi'], ['output'], axis=1)]
        return mk(nodes, inits)

    # General (possibly many-to-one) mapping: W[out_ch, in_ch] = 1.
    W = np.zeros((10, 10, 1, 1), dtype=np.float32)
    for ic in range(10):
        W[cm.get(ic, ic), ic, 0, 0] = 1.0
    return mk(
        [helper.make_node('Conv', ['input', 'W'], ['output'], kernel_shape=[1, 1])],
        [numpy_helper.from_array(W, 'W')],
    )


def s_transpose(td):
    """Transpose solver.

    Matches when every train and test example's output is the transpose of
    its input grid.

    Args:
        td: Task dict with ``'train'``/``'test'`` example lists.

    Returns:
        A model built by ``mk`` with a single ``Transpose`` node swapping
        the last two axes (perm ``[0, 1, 3, 2]``), or ``None``.
    """
    for ex in td['train'] + td['test']:
        if not np.array_equal(np.array(ex['output']), np.array(ex['input']).T):
            return None
    return mk([helper.make_node('Transpose', ['input'], ['output'], perm=[0, 1, 3, 2])])


def s_constant(td):
    """Constant output solver using opset 17 ReduceSum.

    Matches when every example returned by ``get_exs`` has the same output
    grid. The emitted graph always produces that grid, one-hot encoded into
    a ``(1, 10, 30, 30)`` float32 constant.

    Args:
        td: Task dict understood by ``fixed_shapes`` and ``get_exs``.

    Returns:
        A model built by ``mk``, or ``None`` if ``fixed_shapes`` returns
        ``None``, there are no examples, the outputs differ, or the output
        grid does not fit 10 colors x 30 rows x 30 columns.
    """
    if fixed_shapes(td) is None:
        return None
    exs = get_exs(td)
    outs = [out for _, out in exs]
    if not outs:
        return None  # nothing to reproduce; outs[0] below would raise
    if not all(np.array_equal(outs[0], o) for o in outs[1:]):
        return None

    const = np.zeros((1, 10, 30, 30), dtype=np.float32)
    for r, row in enumerate(outs[0]):
        for c, v in enumerate(row):
            color = int(v)
            # Reject what the fixed-size constant cannot hold instead of
            # raising IndexError or, for negative colors, wrapping silently.
            if not (0 <= color < 10 and r < 30 and c < 30):
                return None
            const[0, color, r, c] = 1.0

    inits = [
        numpy_helper.from_array(np.array(0.0, dtype=np.float32), 'z'),
        numpy_helper.from_array(const, 'c'),
        _make_int64_init('rs_axes_cst', [1, 2, 3]),
    ]
    nodes = [
        # input * 0, summed over axes 1-3 with keepdims, is zero for finite
        # inputs; presumably this keeps the graph input wired to the output.
        # ReduceSum takes its axes as a second input (opset >= 13 form).
        helper.make_node('Mul', ['input', 'z'], ['zd']),
        helper.make_node('ReduceSum', ['zd', 'rs_axes_cst'], ['s'], keepdims=1),
        # Broadcast the zero tensor against the constant grid.
        helper.make_node('Add', ['s', 'c'], ['output']),
    ]
    return mk(nodes, inits)