rogermt commited on
Commit
bc87399
·
verified ·
1 Parent(s): 0611aba

Move own-solver/neurogolf_solver/solvers/analytical.py to own-solver/

Browse files
own-solver/neurogolf_solver/solvers/analytical.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Basic analytical solvers: identity, constant, color_map, transpose."""
3
+
4
+ import numpy as np
5
+ from onnx import helper, numpy_helper, TensorProto
6
+ from ..onnx_helpers import mk, _make_int64_init
7
+ from ..data_loader import get_exs, fixed_shapes
8
+
9
+
10
+ def s_identity(td):
11
+ """Identity solver."""
12
+ for ex in td['train'] + td['test']:
13
+ if ex['input'] != ex['output']:
14
+ return None
15
+ return mk([helper.make_node('Identity', ['input'], ['output'])])
16
+
17
+
18
+ def s_color_map(td):
19
+ """Color mapping solver."""
20
+ cm = {}
21
+ for ex in td['train'] + td['test']:
22
+ inp, out = np.array(ex['input']), np.array(ex['output'])
23
+ if inp.shape != out.shape:
24
+ return None
25
+ for iv, ov in zip(inp.flat, out.flat):
26
+ iv, ov = int(iv), int(ov)
27
+ if iv in cm and cm[iv] != ov:
28
+ return None
29
+ cm[iv] = ov
30
+ is_permutation = (set(cm.keys()) == set(cm.values()))
31
+ if is_permutation:
32
+ gather_ch = np.arange(10, dtype=np.int32)
33
+ for src, dst in cm.items():
34
+ if 0 <= src < 10 and 0 <= dst < 10:
35
+ gather_ch[dst] = src
36
+ inits = [numpy_helper.from_array(gather_ch, 'gi')]
37
+ nodes = [helper.make_node('Gather', ['input', 'gi'], ['output'], axis=1)]
38
+ return mk(nodes, inits)
39
+ else:
40
+ W = np.zeros((10, 10, 1, 1), dtype=np.float32)
41
+ for ic in range(10):
42
+ W[cm.get(ic, ic), ic, 0, 0] = 1.0
43
+ return mk([helper.make_node('Conv', ['input', 'W'], ['output'], kernel_shape=[1, 1])],
44
+ [numpy_helper.from_array(W, 'W')])
45
+
46
+
47
+ def s_transpose(td):
48
+ """Transpose solver."""
49
+ for ex in td['train'] + td['test']:
50
+ if not np.array_equal(np.array(ex['output']), np.array(ex['input']).T):
51
+ return None
52
+ return mk([helper.make_node('Transpose', ['input'], ['output'], perm=[0, 1, 3, 2])])
53
+
54
+
55
+ def s_constant(td):
56
+ """Constant output solver using opset 17 ReduceSum."""
57
+ sp = fixed_shapes(td)
58
+ if sp is None:
59
+ return None
60
+ exs = get_exs(td)
61
+ outs = [out for _, out in exs]
62
+ if not all(np.array_equal(outs[0], o) for o in outs[1:]):
63
+ return None
64
+ const = np.zeros((1, 10, 30, 30), dtype=np.float32)
65
+ for r, row in enumerate(outs[0]):
66
+ for c, v in enumerate(row):
67
+ const[0, int(v), r, c] = 1.0
68
+ inits = [
69
+ numpy_helper.from_array(np.array(0.0, dtype=np.float32), 'z'),
70
+ numpy_helper.from_array(const, 'c'),
71
+ _make_int64_init('rs_axes_cst', [1, 2, 3]),
72
+ ]
73
+ nodes = [
74
+ helper.make_node('Mul', ['input', 'z'], ['zd']),
75
+ helper.make_node('ReduceSum', ['zd', 'rs_axes_cst'], ['s'], keepdims=1),
76
+ helper.make_node('Add', ['s', 'c'], ['output']),
77
+ ]
78
+ return mk(nodes, inits)