rogermt committed on
Commit
cd4624b
·
verified ·
1 Parent(s): 1b065e6

v5 refactor: add solvers/conv.py (with lstsq crash fix)

Browse files
Files changed (1) hide show
  1. neurogolf_solver/solvers/conv.py +316 -0
neurogolf_solver/solvers/conv.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Convolutional solvers with least squares fitting."""
3
+
4
+ import time
5
+ import numpy as np
6
+ import onnx
7
+ from onnx import helper, numpy_helper
8
+ from ..onnx_helpers import mk, _make_int64_init, _build_pad_node, add_onehot_block
9
+ from ..data_loader import get_exs, get_exs_for_fitting, get_exs_for_fitting_variable, fixed_shapes
10
+ from ..validators import validate
11
+ from ..constants import GH, GW
12
+
13
+
14
+ def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False):
15
+ """Least squares convolutional weight fitting.
16
+ Returns (Wconv, B) or None."""
17
+ pad = ks // 2
18
+ feat = 10 * ks * ks + (1 if use_bias else 0)
19
+ if feat > 20000:
20
+ return None
21
+ patches, targets = [], []
22
+ for inp_g, out_g in exs_raw:
23
+ ih, iw = inp_g.shape
24
+ if use_full_30:
25
+ oh_full = np.zeros((10, GH, GW), dtype=np.float64)
26
+ for c in range(10):
27
+ oh_full[c, :ih, :iw] = (inp_g == c)
28
+ oh_pad = np.pad(oh_full, ((0, 0), (pad, pad), (pad, pad)))
29
+ else:
30
+ oh_enc = np.zeros((10, ih, iw), dtype=np.float64)
31
+ for c in range(10):
32
+ oh_enc[c] = (inp_g == c)
33
+ oh_pad = np.pad(oh_enc, ((0, 0), (pad, pad), (pad, pad)))
34
+ oh, ow = out_g.shape
35
+ for r in range(oh):
36
+ for c in range(ow):
37
+ p = oh_pad[:, r:r + ks, c:c + ks].flatten()
38
+ if use_bias:
39
+ p = np.append(p, 1.0)
40
+ patches.append(p)
41
+ targets.append(int(out_g[r, c]))
42
+ n_patches = len(patches)
43
+ if feat > 5000 and n_patches > 2000:
44
+ return None
45
+ P = np.array(patches, dtype=np.float64)
46
+ T = np.array(targets, dtype=np.int64)
47
+ T_oh = np.zeros((len(T), 10), dtype=np.float64)
48
+ for i, t in enumerate(T):
49
+ T_oh[i, t] = 1.0
50
+ try:
51
+ WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
52
+ except (np.linalg.LinAlgError, ValueError):
53
+ return None
54
+ if not np.array_equal(np.argmax(P @ WT, axis=1), T):
55
+ return None
56
+ if use_bias:
57
+ Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
58
+ B = WT[-1].astype(np.float32)
59
+ else:
60
+ Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
61
+ B = None
62
+ return Wconv, B
63
+
64
+
65
def solve_conv_fixed(td, path, providers, time_budget=30.0):
    """Fixed-shape convolutional solver.

    Applies only when every example maps a grid onto a grid of the same,
    single (IH, IW) shape.  Searches odd kernel sizes, with and without a
    bias term, for an exact least-squares convolution; on a hit, emits it
    as an ONNX graph (Slice -> Conv -> ArgMax -> one-hot -> pad), saves it
    to `path` and validates it.  Returns ('conv_fixed', model) on success,
    otherwise None.
    """
    examples = get_exs(td)
    if any(i.shape != o.shape for i, o in examples):
        return None
    distinct = {i.shape for i, _ in examples}
    if len(distinct) != 1:
        return None
    IH, IW = distinct.pop()
    candidates = [
        (i, o) for i, o in get_exs_for_fitting(td)
        if i.shape == o.shape == (IH, IW)
    ]
    deadline = time.time() + time_budget
    for use_bias in (False, True):
        for ks in range(1, 30, 2):  # odd kernels 1..29
            if time.time() > deadline:
                return None
            fitted = _lstsq_conv(candidates, ks, use_bias, use_full_30=False)
            if fitted is None:
                continue
            Wconv, B = fitted
            pad = ks // 2
            pad_h, pad_w = GH - IH, GW - IW
            inits = [
                _make_int64_init('sl_st', [0, 0, 0, 0]),
                _make_int64_init('sl_en', [1, 10, IH, IW]),
                numpy_helper.from_array(Wconv, 'W'),
            ]
            conv_in = ['grid', 'W']
            if B is not None:
                inits.append(numpy_helper.from_array(B, 'B'))
                conv_in.append('B')
            nodes = [
                helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['grid']),
                helper.make_node('Conv', conv_in, ['co'],
                                 kernel_shape=[ks, ks], pads=[pad] * 4),
                helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
            ]
            add_onehot_block(nodes, inits, 'am', 'oh_out')
            nodes.append(_build_pad_node('oh_out', 'output', pad_h, pad_w, inits))
            model = mk(nodes, inits)
            onnx.save(model, path)
            if validate(path, td, providers):
                return 'conv_fixed', model
    return None
109
+
110
+
111
def solve_conv_variable(td, path, providers, time_budget=30.0):
    """Variable-shape conv with opset 17 ReduceSum.

    For same-shape tasks whose grid size varies between examples: fits on
    the full GH x GW canvas, then multiplies the one-hot prediction by the
    channel-sum of the input ('mask') so cells outside the actual grid
    stay zero.  Returns ('conv_var', model) on success, otherwise None.
    """
    if any(i.shape != o.shape for i, o in get_exs(td)):
        return None
    candidates = [
        (i, o) for i, o in get_exs_for_fitting_variable(td)
        if i.shape == o.shape
    ]
    deadline = time.time() + time_budget
    for use_bias in (False, True):
        for ks in range(1, 30, 2):  # odd kernels 1..29
            if time.time() > deadline:
                return None
            fitted = _lstsq_conv(candidates, ks, use_bias, use_full_30=True)
            if fitted is None:
                continue
            Wconv, B = fitted
            pad = ks // 2
            inits = [
                numpy_helper.from_array(Wconv, 'W'),
                _make_int64_init('rs_axes_var', [1]),
            ]
            conv_in = ['input', 'W']
            if B is not None:
                inits.append(numpy_helper.from_array(B, 'B'))
                conv_in.append('B')
            nodes = [
                helper.make_node('ReduceSum', ['input', 'rs_axes_var'],
                                 ['mask'], keepdims=1),
                helper.make_node('Conv', conv_in, ['co'],
                                 kernel_shape=[ks, ks], pads=[pad] * 4),
                helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
            ]
            add_onehot_block(nodes, inits, 'am', 'oh_out')
            nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
            model = mk(nodes, inits)
            onnx.save(model, path)
            if validate(path, td, providers):
                return 'conv_var', model
    return None
149
+
150
+
151
def _lstsq_conv_offset(exs, ks, use_bias, IH, IW, OH, OW, dr_off, dc_off):
    """Fit an exact lstsq convolution mapping the input window centred at
    (r + dr_off, c + dc_off) to output cell (r, c).

    Returns (Wconv, B) with Wconv float32 of shape (10, 10, ks, ks) and B
    a length-10 float32 vector or None (no bias); returns None when any
    offset window centre falls outside the input or the fit is inexact.
    """
    pad = ks // 2
    colors = np.arange(10)
    patches, targets = [], []
    for inp_g, out_g in exs:
        # One-hot encode the input: channel c is the mask (inp_g == c).
        oh_enc = (inp_g[None] == colors[:, None, None]).astype(np.float64)
        oh_pad = np.pad(oh_enc, ((0, 0), (pad, pad), (pad, pad)))
        for r in range(OH):
            for c in range(OW):
                sr, sc = r + dr_off, c + dc_off
                if sr < 0 or sr >= IH or sc < 0 or sc >= IW:
                    # Window centre outside the input: this offset cannot
                    # work for any example, abandon the fit.
                    return None
                p = oh_pad[:, sr:sr + ks, sc:sc + ks].flatten()
                if use_bias:
                    p = np.append(p, 1.0)
                patches.append(p)
                targets.append(int(out_g[r, c]))
    feat = 10 * ks * ks + (1 if use_bias else 0)
    if feat > 5000 and len(patches) > 2000:  # too big to solve quickly
        return None
    P = np.array(patches, dtype=np.float64)
    T = np.array(targets, dtype=np.int64)
    T_oh = np.zeros((len(T), 10), dtype=np.float64)
    T_oh[np.arange(len(T)), T] = 1.0
    try:
        WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
    except (np.linalg.LinAlgError, ValueError):
        return None
    # Accept only an exact fit: argmax must reproduce every training cell.
    if not np.array_equal(np.argmax(P @ WT, axis=1), T):
        return None
    if use_bias:
        return (WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32),
                WT[-1].astype(np.float32))
    return WT.T.reshape(10, 10, ks, ks).astype(np.float32), None


def solve_conv_diffshape(td, path, providers, time_budget=30.0):
    """Different-shape convolutional solver.

    Applies when all examples share one input shape (IH, IW) and one
    strictly smaller output shape (OH, OW) that fits inside the input.
    Tries two crop offsets (top-left and centred) and searches odd kernel
    sizes for an exact least-squares convolution; the fitting itself is
    delegated to _lstsq_conv_offset.  On a hit, emits an ONNX graph
    (Slice -> Conv -> crop Slice -> ArgMax -> one-hot -> pad), saves it to
    `path` and validates it.  Returns ('conv_diff', model) on success,
    otherwise None.
    """
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if IH == OH and IW == OW:
        return None  # same-shape tasks belong to solve_conv_fixed
    if OH > IH or OW > IW:
        return None  # output must fit inside the input
    if OH > 30 or OW > 30:
        return None
    exs = get_exs(td)
    t_start = time.time()
    for dr_off, dc_off in [(0, 0), ((IH - OH) // 2, (IW - OW) // 2)]:
        for use_bias in [False, True]:
            for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21]:
                if time.time() - t_start > time_budget:
                    return None
                feat = 10 * ks * ks + (1 if use_bias else 0)
                if feat > 10000:
                    continue
                result = _lstsq_conv_offset(exs, ks, use_bias,
                                            IH, IW, OH, OW, dr_off, dc_off)
                if result is None:
                    continue
                Wconv, B = result
                pad = ks // 2
                pad_h, pad_w = GH - OH, GW - OW
                inits = [
                    _make_int64_init('sl_st', [0, 0, 0, 0]),
                    _make_int64_init('sl_en', [1, 10, IH, IW]),
                    numpy_helper.from_array(Wconv, 'W'),
                    _make_int64_init('cr_st', [0, 0, dr_off, dc_off]),
                    _make_int64_init('cr_en', [1, 10, dr_off + OH, dc_off + OW]),
                ]
                conv_inputs = ['grid', 'W']
                if B is not None:
                    inits.append(numpy_helper.from_array(B, 'B'))
                    conv_inputs.append('B')
                nodes = [
                    helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['grid']),
                    helper.make_node('Conv', conv_inputs, ['co'],
                                     kernel_shape=[ks, ks], pads=[pad] * 4),
                    helper.make_node('Slice', ['co', 'cr_st', 'cr_en'], ['co_crop']),
                    helper.make_node('ArgMax', ['co_crop'], ['am'], axis=1, keepdims=1),
                ]
                add_onehot_block(nodes, inits, 'am', 'oh_out')
                nodes.append(_build_pad_node('oh_out', 'output', pad_h, pad_w, inits))
                model = mk(nodes, inits)
                onnx.save(model, path)
                if validate(path, td, providers):
                    return 'conv_diff', model
    return None
243
+
244
+
245
def solve_conv_var_diff(td, path, providers, time_budget=30.0):
    """Variable diff-shape conv with opset 17 ReduceSum.

    Handles tasks where the output shape differs per example but each
    output fits inside its input.  The fitting (one-hot encoding on the
    full GH x GW canvas + exact lstsq) is delegated to _lstsq_conv rather
    than duplicated inline; the one-hot prediction is multiplied by the
    channel-sum of the input ('mask') so cells outside the grid stay
    zero.  Returns ('conv_var_diff', model) on success, otherwise None.
    """
    exs = get_exs(td)
    # This condition is invariant over the whole search: if any output is
    # larger than its input, the masked-conv graph can never be emitted,
    # so bail out up front instead of refitting for every kernel size.
    if not all(o.shape[0] <= i.shape[0] and o.shape[1] <= i.shape[1]
               for i, o in exs):
        return None
    t_start = time.time()
    for use_bias in [False, True]:
        for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
            if time.time() - t_start > time_budget:
                return None
            result = _lstsq_conv(exs, ks, use_bias, use_full_30=True)
            if result is None:
                continue
            Wconv, B = result
            pad = ks // 2
            inits = [
                numpy_helper.from_array(Wconv, 'W'),
                _make_int64_init('rs_axes_vd', [1]),
            ]
            conv_inputs = ['input', 'W']
            if B is not None:
                inits.append(numpy_helper.from_array(B, 'B'))
                conv_inputs.append('B')
            nodes = [
                helper.make_node('ReduceSum', ['input', 'rs_axes_vd'],
                                 ['mask'], keepdims=1),
                helper.make_node('Conv', conv_inputs, ['co'],
                                 kernel_shape=[ks, ks], pads=[pad] * 4),
                helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
            ]
            add_onehot_block(nodes, inits, 'am', 'oh_out')
            nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
            model = mk(nodes, inits)
            onnx.save(model, path)
            if validate(path, td, providers):
                return 'conv_var_diff', model
    return None