#!/usr/bin/env python3
# Provenance (copied from the hosting page): uploaded by rogermt, commit bc87399 (verified),
# "Move own-solver/neurogolf_solver/solvers/analytical.py to own-solver/".
"""Basic analytical solvers: identity, constant, color_map, transpose."""
import numpy as np
from onnx import helper, numpy_helper, TensorProto
from ..onnx_helpers import mk, _make_int64_init
from ..data_loader import get_exs, fixed_shapes
def s_identity(td):
    """Build an Identity model when every example's output equals its input.

    Args:
        td: Task dict with ``'train'`` and ``'test'`` example lists; each
            example holds ``'input'`` and ``'output'`` grids (nested lists or
            numpy arrays).

    Returns:
        ``None`` if any example's output differs from its input; otherwise the
        model built by ``mk`` from a single ONNX ``Identity`` node.
    """
    for ex in td['train'] + td['test']:
        # np.array_equal handles nested lists *and* numpy arrays (including
        # mismatched shapes). A plain ``!=`` on same-shape arrays yields an
        # element-wise mask whose truth value is ambiguous and raises.
        if not np.array_equal(ex['input'], ex['output']):
            return None
    return mk([helper.make_node('Identity', ['input'], ['output'])])
def s_color_map(td):
    """Fit one per-cell colour substitution shared by every example.

    Every input grid must have the same shape as its output grid, and each
    input colour must always become the same output colour.

    Args:
        td: Task dict with ``'train'`` and ``'test'`` example lists; each
            example carries ``'input'`` and ``'output'`` grids.

    Returns:
        ``None`` when no consistent colour map exists. Otherwise a model from
        ``mk``:
        - a channel ``Gather`` when the observed mapping permutes its own
          colour set;
        - else a 1x1 ``Conv`` whose weights route each input channel to its
          mapped output channel.
        In both cases, colours never observed map to themselves.
    """
    mapping = {}
    for example in td['train'] + td['test']:
        src_grid = np.array(example['input'])
        dst_grid = np.array(example['output'])
        if src_grid.shape != dst_grid.shape:
            return None
        for a, b in zip(src_grid.flat, dst_grid.flat):
            a = int(a)
            b = int(b)
            # An unseen colour defaults to ``b``, so only a conflicting
            # earlier assignment triggers the rejection.
            if mapping.get(a, b) != b:
                return None
            mapping[a] = b

    if set(mapping) != set(mapping.values()):
        # Not a permutation: a 1x1 convolution sends input channel c to
        # output channel mapping.get(c, c).
        weights = np.zeros((10, 10, 1, 1), dtype=np.float32)
        for channel in range(10):
            weights[mapping.get(channel, channel), channel, 0, 0] = 1.0
        conv = helper.make_node('Conv', ['input', 'W'], ['output'], kernel_shape=[1, 1])
        return mk([conv], [numpy_helper.from_array(weights, 'W')])

    # Permutation: output channel ``dst`` reads input channel ``src``, so the
    # Gather index table stores ``src`` at position ``dst``.
    indices = np.arange(10, dtype=np.int32)
    valid = range(10)
    for src, dst in mapping.items():
        if src in valid and dst in valid:
            indices[dst] = src
    gather = helper.make_node('Gather', ['input', 'gi'], ['output'], axis=1)
    return mk([gather], [numpy_helper.from_array(indices, 'gi')])
def s_transpose(td):
    """Build a transpose model when every output is its input transposed.

    Args:
        td: Task dict with ``'train'`` and ``'test'`` example lists; each
            example carries ``'input'`` and ``'output'`` grids.

    Returns:
        ``None`` if some example's output is not the transpose of its input.
        Otherwise the model from ``mk`` holding a single ``Transpose`` node
        that swaps the last two axes (``perm=[0, 1, 3, 2]``).
    """
    examples = td['train'] + td['test']
    all_transposed = all(
        np.array_equal(np.array(ex['output']), np.array(ex['input']).T)
        for ex in examples
    )
    if not all_transposed:
        return None
    swap_last_two = helper.make_node('Transpose', ['input'], ['output'], perm=[0, 1, 3, 2])
    return mk([swap_last_two])
def s_constant(td):
    """Build a model emitting the same one-hot grid regardless of input.

    Applies only when ``fixed_shapes(td)`` is not ``None`` and every output
    returned by ``get_exs(td)`` is identical.

    Graph construction (uses the opset-17 ``ReduceSum`` form, with axes
    supplied as an input tensor):
    - ``input`` is multiplied by a scalar zero.
    - The product is reduced with ``ReduceSum`` over axes 1-3 with
      ``keepdims=1``.
    - A constant ``(1, 10, 30, 30)`` one-hot tensor of the shared output is
      added to the result.
    Presumably the zeroed path exists so the output stays connected to the
    graph input (NOTE(review): confirm against the ``mk`` contract).

    Args:
        td: Task dict consumed by ``fixed_shapes`` and ``get_exs``.

    Returns:
        The model from ``mk``, or ``None`` when the task does not qualify.
    """
    if fixed_shapes(td) is None:
        return None
    outputs = [out for _inp, out in get_exs(td)]
    reference = outputs[0]
    if any(not np.array_equal(reference, other) for other in outputs[1:]):
        return None

    # One-hot encode the shared output grid: channel = colour, then row, col.
    one_hot = np.zeros((1, 10, 30, 30), dtype=np.float32)
    cells = ((r, c, int(v)) for r, row in enumerate(reference) for c, v in enumerate(row))
    for r, c, color in cells:
        one_hot[0, color, r, c] = 1.0

    zero = numpy_helper.from_array(np.array(0.0, dtype=np.float32), 'z')
    const = numpy_helper.from_array(one_hot, 'c')
    axes = _make_int64_init('rs_axes_cst', [1, 2, 3])
    graph_nodes = [
        helper.make_node('Mul', ['input', 'z'], ['zd']),
        helper.make_node('ReduceSum', ['zd', 'rs_axes_cst'], ['s'], keepdims=1),
        helper.make_node('Add', ['s', 'c'], ['output']),
    ]
    return mk(graph_nodes, [zero, const, axes])