Move own-solver/neurogolf_solver/solvers/analytical.py to own-solver/
Browse files
own-solver/neurogolf_solver/solvers/analytical.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Basic analytical solvers: identity, constant, color_map, transpose."""
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from onnx import helper, numpy_helper, TensorProto
|
| 6 |
+
from ..onnx_helpers import mk, _make_int64_init
|
| 7 |
+
from ..data_loader import get_exs, fixed_shapes
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def s_identity(td):
|
| 11 |
+
"""Identity solver."""
|
| 12 |
+
for ex in td['train'] + td['test']:
|
| 13 |
+
if ex['input'] != ex['output']:
|
| 14 |
+
return None
|
| 15 |
+
return mk([helper.make_node('Identity', ['input'], ['output'])])
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def s_color_map(td):
|
| 19 |
+
"""Color mapping solver."""
|
| 20 |
+
cm = {}
|
| 21 |
+
for ex in td['train'] + td['test']:
|
| 22 |
+
inp, out = np.array(ex['input']), np.array(ex['output'])
|
| 23 |
+
if inp.shape != out.shape:
|
| 24 |
+
return None
|
| 25 |
+
for iv, ov in zip(inp.flat, out.flat):
|
| 26 |
+
iv, ov = int(iv), int(ov)
|
| 27 |
+
if iv in cm and cm[iv] != ov:
|
| 28 |
+
return None
|
| 29 |
+
cm[iv] = ov
|
| 30 |
+
is_permutation = (set(cm.keys()) == set(cm.values()))
|
| 31 |
+
if is_permutation:
|
| 32 |
+
gather_ch = np.arange(10, dtype=np.int32)
|
| 33 |
+
for src, dst in cm.items():
|
| 34 |
+
if 0 <= src < 10 and 0 <= dst < 10:
|
| 35 |
+
gather_ch[dst] = src
|
| 36 |
+
inits = [numpy_helper.from_array(gather_ch, 'gi')]
|
| 37 |
+
nodes = [helper.make_node('Gather', ['input', 'gi'], ['output'], axis=1)]
|
| 38 |
+
return mk(nodes, inits)
|
| 39 |
+
else:
|
| 40 |
+
W = np.zeros((10, 10, 1, 1), dtype=np.float32)
|
| 41 |
+
for ic in range(10):
|
| 42 |
+
W[cm.get(ic, ic), ic, 0, 0] = 1.0
|
| 43 |
+
return mk([helper.make_node('Conv', ['input', 'W'], ['output'], kernel_shape=[1, 1])],
|
| 44 |
+
[numpy_helper.from_array(W, 'W')])
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def s_transpose(td):
|
| 48 |
+
"""Transpose solver."""
|
| 49 |
+
for ex in td['train'] + td['test']:
|
| 50 |
+
if not np.array_equal(np.array(ex['output']), np.array(ex['input']).T):
|
| 51 |
+
return None
|
| 52 |
+
return mk([helper.make_node('Transpose', ['input'], ['output'], perm=[0, 1, 3, 2])])
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def s_constant(td):
|
| 56 |
+
"""Constant output solver using opset 17 ReduceSum."""
|
| 57 |
+
sp = fixed_shapes(td)
|
| 58 |
+
if sp is None:
|
| 59 |
+
return None
|
| 60 |
+
exs = get_exs(td)
|
| 61 |
+
outs = [out for _, out in exs]
|
| 62 |
+
if not all(np.array_equal(outs[0], o) for o in outs[1:]):
|
| 63 |
+
return None
|
| 64 |
+
const = np.zeros((1, 10, 30, 30), dtype=np.float32)
|
| 65 |
+
for r, row in enumerate(outs[0]):
|
| 66 |
+
for c, v in enumerate(row):
|
| 67 |
+
const[0, int(v), r, c] = 1.0
|
| 68 |
+
inits = [
|
| 69 |
+
numpy_helper.from_array(np.array(0.0, dtype=np.float32), 'z'),
|
| 70 |
+
numpy_helper.from_array(const, 'c'),
|
| 71 |
+
_make_int64_init('rs_axes_cst', [1, 2, 3]),
|
| 72 |
+
]
|
| 73 |
+
nodes = [
|
| 74 |
+
helper.make_node('Mul', ['input', 'z'], ['zd']),
|
| 75 |
+
helper.make_node('ReduceSum', ['zd', 'rs_axes_cst'], ['s'], keepdims=1),
|
| 76 |
+
helper.make_node('Add', ['s', 'c'], ['output']),
|
| 77 |
+
]
|
| 78 |
+
return mk(nodes, inits)
|