| |
| """Basic analytical solvers: identity, constant, color_map, transpose.""" |
|
|
| import numpy as np |
| from onnx import helper, numpy_helper, TensorProto |
| from ..onnx_helpers import mk, _make_int64_init |
| from ..data_loader import get_exs, fixed_shapes |
|
|
|
|
def s_identity(td):
    """Return an ``Identity`` model when every example's output equals its input.

    All train and test examples are checked. Grids are compared with
    ``np.array_equal`` so that both nested lists and NumPy arrays work.
    Plain ``!=`` on arrays yields an element-wise array whose truth value is
    ambiguous, which would raise instead of rejecting the task.

    Returns the model built by ``mk``, or ``None`` if any example differs.
    """
    for ex in td['train'] + td['test']:
        if not np.array_equal(ex['input'], ex['output']):
            return None
    return mk([helper.make_node('Identity', ['input'], ['output'])])
|
|
|
|
def s_color_map(td):
    """Solve tasks where each cell's colour is rewritten through one fixed map.

    Every train and test example must keep its grid shape. Each input colour
    must always become the same output colour.

    If the observed map is a bijection on the colours it mentions, the model
    is a channel ``Gather`` (axis 1). Otherwise it is a 1x1 ``Conv`` whose
    weights route each input channel to its mapped output channel. In both
    cases colours never seen in the inputs map to themselves.

    Returns the model built by ``mk``, or ``None`` if the task does not fit.
    """
    mapping = {}
    for example in td['train'] + td['test']:
        src_grid = np.array(example['input'])
        dst_grid = np.array(example['output'])
        if src_grid.shape != dst_grid.shape:
            return None
        for a, b in zip(src_grid.flat, dst_grid.flat):
            a = int(a)
            b = int(b)
            # setdefault records a first sighting; a mismatch means the same
            # input colour was sent to two different output colours.
            if mapping.setdefault(a, b) != b:
                return None

    if set(mapping) != set(mapping.values()):
        # General (non-bijective) map: W[out_ch, in_ch] = 1 for each routing.
        weights = np.zeros((10, 10, 1, 1), dtype=np.float32)
        for channel in range(10):
            weights[mapping.get(channel, channel), channel, 0, 0] = 1.0
        conv = helper.make_node('Conv', ['input', 'W'], ['output'], kernel_shape=[1, 1])
        return mk([conv], [numpy_helper.from_array(weights, 'W')])

    # Bijection: output channel b reads input channel a.
    indices = np.arange(10, dtype=np.int32)
    for a, b in mapping.items():
        if 0 <= a < 10 and 0 <= b < 10:
            indices[b] = a
    gather = helper.make_node('Gather', ['input', 'gi'], ['output'], axis=1)
    return mk([gather], [numpy_helper.from_array(indices, 'gi')])
|
|
|
|
def s_transpose(td):
    """Return a ``Transpose`` model when every output is its input transposed.

    All train and test examples are checked. The emitted node swaps the last
    two axes (``perm=[0, 1, 3, 2]``), which assumes a 4-D model tensor.

    Returns the model built by ``mk``, or ``None`` if any example fails.
    """
    examples = td['train'] + td['test']
    matches = all(
        np.array_equal(np.array(ex['output']), np.array(ex['input']).T)
        for ex in examples
    )
    if not matches:
        return None
    node = helper.make_node('Transpose', ['input'], ['output'], perm=[0, 1, 3, 2])
    return mk([node])
|
|
|
|
def s_constant(td):
    """Return a model that emits the same output grid for every example.

    The task must pass ``fixed_shapes`` and every example output from
    ``get_exs`` must be identical. That shared grid is baked into a
    ``(1, 10, 30, 30)`` float tensor, with one 1.0 per cell in the plane of
    its colour.

    The graph multiplies ``input`` by zero and reduces it over axes 1-3 with
    ``ReduceSum``. ``ReduceSum`` takes its axes as an input tensor, which is
    the opset 13+ form. Adding the baked constant to that zero gives the
    output. The presumed purpose is to keep ``input`` wired into the graph;
    this assumes a 4-D input.

    Returns the model built by ``mk``, or ``None`` if the task does not fit.
    """
    if fixed_shapes(td) is None:
        return None
    targets = [out for _, out in get_exs(td)]
    reference = targets[0]
    if any(not np.array_equal(reference, other) for other in targets[1:]):
        return None

    canvas = np.zeros((1, 10, 30, 30), dtype=np.float32)
    for r, row in enumerate(reference):
        for c, colour in enumerate(row):
            canvas[0, int(colour), r, c] = 1.0

    initializers = [
        numpy_helper.from_array(np.array(0.0, dtype=np.float32), 'z'),
        numpy_helper.from_array(canvas, 'c'),
        _make_int64_init('rs_axes_cst', [1, 2, 3]),
    ]
    graph = [
        helper.make_node('Mul', ['input', 'z'], ['zd']),
        helper.make_node('ReduceSum', ['zd', 'rs_axes_cst'], ['s'], keepdims=1),
        helper.make_node('Add', ['s', 'c'], ['output']),
    ]
    return mk(graph, initializers)
|
|