rogermt commited on
Commit
ebfc1c9
·
verified ·
1 Parent(s): 771c280

Upload own-solver/neurogolf_solver/gather_helpers.py

Browse files
own-solver/neurogolf_solver/gather_helpers.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Gather-based model building utilities."""
3
+
4
+ import numpy as np
5
+ from onnx import numpy_helper, helper
6
+ from .onnx_helpers import mk
7
+ from .constants import GH, GW
8
+
9
+
10
+ def _build_gather_model(OH, OW, idx):
11
+ """Build gather model from index mapping."""
12
+ flat_idx = np.zeros((GH * GW,), dtype=np.int64)
13
+ mask = np.zeros((1, 1, GH, GW), dtype=np.float32)
14
+ for oi in range(OH):
15
+ for oj in range(OW):
16
+ flat_idx[oi * GW + oj] = idx[oi, oj, 0] * GW + idx[oi, oj, 1]
17
+ mask[0, 0, oi, oj] = 1.0
18
+ inits = [
19
+ numpy_helper.from_array(np.array([1, 10, GH * GW], dtype=np.int64), 'fs'),
20
+ numpy_helper.from_array(flat_idx, 'idx'),
21
+ numpy_helper.from_array(np.array([1, 10, GH, GW], dtype=np.int64), 'os'),
22
+ numpy_helper.from_array(mask, 'mask'),
23
+ ]
24
+ nodes = [
25
+ helper.make_node('Reshape', ['input', 'fs'], ['flat']),
26
+ helper.make_node('Gather', ['flat', 'idx'], ['g'], axis=2),
27
+ helper.make_node('Reshape', ['g', 'os'], ['raw']),
28
+ helper.make_node('Mul', ['raw', 'mask'], ['output']),
29
+ ]
30
+ return mk(nodes, inits)
31
+
32
+
33
+ def _build_gather_model_with_const(IH, IW, OH, OW, idx, cst):
34
+ """Build gather model with constant values."""
35
+ flat_idx = np.zeros((GH * GW,), dtype=np.int64)
36
+ gather_mask = np.zeros((1, 1, GH, GW), dtype=np.float32)
37
+ const_oh = np.zeros((1, 10, GH, GW), dtype=np.float32)
38
+ for oi in range(OH):
39
+ for oj in range(OW):
40
+ if idx[oi, oj, 0] >= 0:
41
+ flat_idx[oi * GW + oj] = idx[oi, oj, 0] * GW + idx[oi, oj, 1]
42
+ gather_mask[0, 0, oi, oj] = 1.0
43
+ elif cst[oi, oj] >= 0:
44
+ const_oh[0, cst[oi, oj], oi, oj] = 1.0
45
+ has_const = np.any(const_oh > 0)
46
+ inits = [
47
+ numpy_helper.from_array(np.array([1, 10, GH * GW], dtype=np.int64), 'fs'),
48
+ numpy_helper.from_array(flat_idx, 'idx'),
49
+ numpy_helper.from_array(np.array([1, 10, GH, GW], dtype=np.int64), 'os'),
50
+ numpy_helper.from_array(gather_mask, 'gmask'),
51
+ ]
52
+ nodes = [
53
+ helper.make_node('Reshape', ['input', 'fs'], ['flat']),
54
+ helper.make_node('Gather', ['flat', 'idx'], ['g'], axis=2),
55
+ helper.make_node('Reshape', ['g', 'os'], ['raw']),
56
+ helper.make_node('Mul', ['raw', 'gmask'], ['masked']),
57
+ ]
58
+ if has_const:
59
+ inits.append(numpy_helper.from_array(const_oh, 'cst'))
60
+ nodes.append(helper.make_node('Add', ['masked', 'cst'], ['output']))
61
+ else:
62
+ nodes[-1] = helper.make_node('Mul', ['raw', 'gmask'], ['output'])
63
+ return mk(nodes, inits)