rogermt committed on
Commit
014e988
·
verified ·
1 Parent(s): bc87399

Move own-solver/neurogolf_solver/solvers/conv.py to own-solver/

Browse files
own-solver/neurogolf_solver/solvers/conv.py ADDED
@@ -0,0 +1,544 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Convolutional solvers with least squares fitting.
3
+
4
+ v5.1: Refactored into composable primitives (_build_patch_matrix, _solve_weights,
5
+ _extract_weights) + PCR (PCA regression) fallback via _solve_weights_pcr.
6
+ PCR tested on 400 tasks: 0 new solves but no regressions. Code kept for
7
+ future experiments (Lasso, Ridge can reuse the same _solve_weights interface).
8
+ """
9
+
10
+ import time
11
+ import numpy as np
12
+ import onnx
13
+ from onnx import helper, numpy_helper
14
+ from ..onnx_helpers import mk, _make_int64_init, _build_pad_node, add_onehot_block
15
+ from ..data_loader import get_exs, get_exs_for_fitting, get_exs_for_fitting_variable, fixed_shapes
16
+ from ..validators import validate
17
+ from ..constants import GH, GW
18
+
19
+
20
+ # ---------------------------------------------------------------------------
21
+ # Core fitting primitives (composable: mix _build_patch_matrix with any solver)
22
+ # ---------------------------------------------------------------------------
23
+
24
def _build_patch_matrix(exs_raw, ks, use_bias, use_full_30=False):
    """Build the patch matrix P and targets for a ks x ks conv fit.

    Each input grid is one-hot encoded over 10 colour channels, zero padded
    by ``ks // 2`` on both spatial axes, and one flattened ``10 * ks * ks``
    patch is taken for every output cell ``(r, c)``.  With ``use_bias`` a
    constant 1.0 column is appended to every patch.

    Args:
        exs_raw: iterable of ``(inp_g, out_g)`` pairs of 2-D integer arrays.
        ks: kernel size (the callers in this file pass odd values).
        use_bias: append a bias column of ones to P.
        use_full_30: embed the one-hot input in the top-left corner of a
            ``GH x GW`` canvas before padding instead of using the raw
            input extent.

    Returns:
        ``(P, T, T_oh)`` where P holds one patch per row, T the integer
        target colours and T_oh their 10-way one-hot encoding; or None when
        the problem is considered infeasible (too many features, too many
        patches for a large feature count, or an output grid that does not
        fit inside the encoded input grid).
    """
    pad = ks // 2
    feat = 10 * ks * ks + (1 if use_bias else 0)
    if feat > 20000:
        return None
    exs_raw = list(exs_raw)
    # The number of patches is fixed by the output shapes, so the size limit
    # can be enforced before any (potentially huge) patch list is built.
    n_patches = sum(out_g.shape[0] * out_g.shape[1] for _, out_g in exs_raw)
    if feat > 5000 and n_patches > 2000:
        return None
    colours = np.arange(10).reshape(10, 1, 1)
    patches, targets = [], []
    for inp_g, out_g in exs_raw:
        ih, iw = inp_g.shape
        oh, ow = out_g.shape
        onehot = (inp_g[np.newaxis, :, :] == colours).astype(np.float64)
        if use_full_30:
            grid = np.zeros((10, GH, GW), dtype=np.float64)
            grid[:, :ih, :iw] = onehot
        else:
            grid = onehot
        # An output cell outside the encoded grid would give a truncated
        # (ragged) patch; such a task cannot be fitted by this model.
        if oh > grid.shape[1] or ow > grid.shape[2]:
            return None
        oh_pad = np.pad(grid, ((0, 0), (pad, pad), (pad, pad)))
        for r in range(oh):
            for c in range(ow):
                p = oh_pad[:, r:r + ks, c:c + ks].flatten()
                if use_bias:
                    p = np.append(p, 1.0)
                patches.append(p)
                targets.append(int(out_g[r, c]))
    P = np.array(patches, dtype=np.float64)
    T = np.array(targets, dtype=np.int64)
    T_oh = np.zeros((len(T), 10), dtype=np.float64)
    T_oh[np.arange(len(T)), T] = 1.0
    return P, T, T_oh
61
+
62
+
63
def _solve_weights(P, T, T_oh):
    """Fit P @ WT ~= T_oh by ordinary least squares.

    Returns the weight matrix WT (one row per column of P, 10 columns) only
    if the row-wise argmax of ``P @ WT`` reproduces every target in T;
    returns None when the solve raises or any training target is missed.
    """
    try:
        solution = np.linalg.lstsq(P, T_oh, rcond=None)
    except (np.linalg.LinAlgError, ValueError):
        return None
    WT = solution[0]
    predicted = np.argmax(P @ WT, axis=1)
    if np.array_equal(predicted, T):
        return WT
    return None
72
+
73
+
74
def _solve_weights_pcr(P, T, T_oh, var_thresholds=(0.999, 0.99, 0.95)):
    """Principal-component (truncated SVD) regression fallback.

    Only attempted when ``p / n > 0.5`` for an ``n x p`` matrix P; below
    that ratio the function returns None immediately.  For each variance
    threshold the leading ``k`` singular directions are kept (``k`` is
    clamped to at least 5 and at most ``min(n, p)``), a least-squares fit is
    done in the reduced space and the weights are mapped back to the full
    ``p``-dimensional space.

    Args:
        P: 2-D patch matrix.
        T: integer targets, one per row of P.
        T_oh: one-hot targets with 10 columns.
        var_thresholds: cumulative explained-variance thresholds, tried in
            the given order.

    Returns:
        WT (p x 10) for the first threshold whose reduced-space AND
        full-space argmax predictions both reproduce T, else None.

    Author's note kept from the original (tested 2026-04-26): improves
    arc-gen accuracy by 3-9% on 4/345 unsolved tasks but never reaches the
    100% required for validation; retained as a fallback.
    """
    n, p = P.shape
    if p / max(n, 1) <= 0.5:
        return None  # lstsq is safe here, no need for PCR
    try:
        U, s, Vt = np.linalg.svd(P, full_matrices=False)
    except (np.linalg.LinAlgError, ValueError):
        return None
    energy = s ** 2
    total = energy.sum()
    if total > 0:
        cumvar = np.cumsum(energy) / total
    else:
        # All-zero P: avoid a 0/0 NaN (and RuntimeWarning).  Every component
        # explains "everything", which selects the smallest rank below just
        # as the NaN array used to.
        cumvar = np.ones_like(energy)
    max_rank = min(n, p)
    tried_ranks = set()
    for thresh in var_thresholds:
        k = int(np.searchsorted(cumvar, thresh)) + 1
        k = min(max(k, 5), max_rank)
        # Several thresholds often clamp to the same rank; the fit is
        # deterministic, so repeating it cannot change the outcome.
        if k in tried_ranks:
            continue
        tried_ranks.add(k)
        P_red = U[:, :k] * s[:k]
        try:
            w_red = np.linalg.lstsq(P_red, T_oh, rcond=None)[0]
        except (np.linalg.LinAlgError, ValueError):
            continue
        if not np.array_equal(np.argmax(P_red @ w_red, axis=1), T):
            continue
        # Map back to full p-dimensional weights for the ONNX conv.
        WT = Vt[:k].T @ w_red
        # The full-space predictions must match as well.
        if np.array_equal(np.argmax(P @ WT, axis=1), T):
            return WT
    return None
107
+
108
+
109
def _extract_weights(WT, ks, use_bias):
    """Split a fitted weight matrix into conv kernel and bias.

    WT has 10 columns (one per output channel).  With ``use_bias`` its last
    row is the bias and the remaining rows are the kernel; otherwise all
    rows are kernel.  The kernel rows are transposed and reshaped to
    ``(10, 10, ks, ks)``.

    Returns ``(Wconv, B)`` as float32 arrays; B is None without a bias.
    """
    if use_bias:
        kernel_rows, bias_row = WT[:-1], WT[-1]
    else:
        kernel_rows, bias_row = WT, None
    Wconv = kernel_rows.T.reshape(10, 10, ks, ks).astype(np.float32)
    B = None if bias_row is None else bias_row.astype(np.float32)
    return Wconv, B
118
+
119
+
120
+ # ---------------------------------------------------------------------------
121
+ # Convenience wrappers (combine primitives into single-call fitting)
122
+ # ---------------------------------------------------------------------------
123
+
124
def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False):
    """Fit conv weights by plain least squares.

    Chains ``_build_patch_matrix`` -> ``_solve_weights`` ->
    ``_extract_weights``.  Returns ``(Wconv, B)``, or None as soon as either
    of the first two stages returns None.
    """
    matrices = _build_patch_matrix(exs_raw, ks, use_bias, use_full_30)
    if matrices is not None:
        P, T, T_oh = matrices
        weights = _solve_weights(P, T, T_oh)
        if weights is not None:
            return _extract_weights(weights, ks, use_bias)
    return None
135
+
136
+
137
def _lstsq_conv_pcr(exs_raw, ks, use_bias, use_full_30=False):
    """Fit conv weights by PCA regression (fallback when raw lstsq overfits).

    Chains ``_build_patch_matrix`` -> ``_solve_weights_pcr`` ->
    ``_extract_weights``.  Returns ``(Wconv, B)``, or None as soon as either
    of the first two stages returns None.
    """
    matrices = _build_patch_matrix(exs_raw, ks, use_bias, use_full_30)
    if matrices is not None:
        P, T, T_oh = matrices
        weights = _solve_weights_pcr(P, T, T_oh)
        if weights is not None:
            return _extract_weights(weights, ks, use_bias)
    return None
148
+
149
+
150
+ # ---------------------------------------------------------------------------
151
+ # Solver functions (called from solver_registry.py)
152
+ # ---------------------------------------------------------------------------
153
+
154
def _build_and_validate_conv_fixed(fit_fn, fit_exs, ks, use_bias, IH, IW, td, path, providers):
    """Fit a fixed-shape conv with ``fit_fn``, export it to ONNX and validate.

    ``fit_fn`` is called as ``fit_fn(fit_exs, ks, use_bias, use_full_30=False)``
    and must return ``(Wconv, B)`` or None.  The exported graph slices
    ``input`` to ``[1, 10, IH, IW]``, applies a ks x ks Conv padded by
    ``ks // 2`` on every side and takes ArgMax over axis 1; the result is
    passed through ``add_onehot_block`` and ``_build_pad_node`` (project
    helpers -- presumably one-hot re-encoding and padding back to GH x GW;
    confirm against onnx_helpers).

    The model is saved to ``path`` before validation, so the file is
    overwritten even when validation fails.

    Returns ``(tag, model)`` if ``validate(path, td, providers)`` is truthy,
    with tag 'conv_fixed' when ``fit_fn`` is ``_lstsq_conv`` and
    'conv_fixed_pcr' otherwise; else None.
    """
    fitted = fit_fn(fit_exs, ks, use_bias, use_full_30=False)
    if fitted is None:
        return None
    Wconv, B = fitted
    half = ks // 2
    conv_inputs = ['grid', 'W']
    inits = [
        _make_int64_init('sl_st', [0, 0, 0, 0]),
        _make_int64_init('sl_en', [1, 10, IH, IW]),
        numpy_helper.from_array(Wconv, 'W'),
    ]
    if B is not None:
        inits.append(numpy_helper.from_array(B, 'B'))
        conv_inputs.append('B')
    nodes = [
        helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['grid']),
        helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks, ks],
                         pads=[half, half, half, half]),
        helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
    ]
    add_onehot_block(nodes, inits, 'am', 'oh_out')
    nodes.append(_build_pad_node('oh_out', 'output', GH - IH, GW - IW, inits))
    model = mk(nodes, inits)
    onnx.save(model, path)
    if not validate(path, td, providers):
        return None
    if fit_fn == _lstsq_conv:
        return 'conv_fixed', model
    return 'conv_fixed_pcr', model
184
+
185
+
186
def solve_conv_fixed(td, path, providers, time_budget=30.0):
    """Fixed-shape conv solver: raw least squares first, PCR as second pass.

    Applies only when every example from ``get_exs(td)`` has identical input
    and output shapes and all inputs share one shape ``(IH, IW)``.  Odd
    kernel sizes 1..29 are tried without and then with a bias.  Each
    successful fit is exported to ``path`` and validated; the first model
    that validates is returned.  Configurations that fitted the training
    patches but failed validation are retried with PCR.

    Args:
        td: task data, passed to the project loaders and to ``validate``.
        path: file the candidate ONNX model is written to (overwritten for
            every candidate).
        providers: forwarded to ``validate``.
        time_budget: wall-clock seconds after which the solver gives up.

    Returns:
        ``('conv_fixed', model)``, the result of the PCR helper, or None.
    """
    exs = get_exs(td)
    if any(inp.shape != out.shape for inp, out in exs):
        return None
    distinct_shapes = {inp.shape for inp, _ in exs}
    if len(distinct_shapes) != 1:
        return None
    IH, IW = distinct_shapes.pop()
    fit_exs = [
        (i, o) for i, o in get_exs_for_fitting(td)
        if i.shape == o.shape and i.shape == (IH, IW)
    ]
    started = time.time()
    # Pass 1: raw lstsq.  Remember (ks, use_bias) pairs whose fit reproduced
    # the training patches but whose model failed validation.
    pcr_candidates = []
    for use_bias in (False, True):
        for ks in range(1, 30, 2):
            if time.time() - started > time_budget:
                return None
            fitted = _lstsq_conv(fit_exs, ks, use_bias, use_full_30=False)
            if fitted is None:
                continue
            Wconv, B = fitted
            half = ks // 2
            conv_inputs = ['grid', 'W']
            inits = [
                _make_int64_init('sl_st', [0, 0, 0, 0]),
                _make_int64_init('sl_en', [1, 10, IH, IW]),
                numpy_helper.from_array(Wconv, 'W'),
            ]
            if B is not None:
                inits.append(numpy_helper.from_array(B, 'B'))
                conv_inputs.append('B')
            nodes = [
                helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['grid']),
                helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks, ks],
                                 pads=[half, half, half, half]),
                helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
            ]
            add_onehot_block(nodes, inits, 'am', 'oh_out')
            nodes.append(_build_pad_node('oh_out', 'output', GH - IH, GW - IW, inits))
            model = mk(nodes, inits)
            onnx.save(model, path)
            if validate(path, td, providers):
                return 'conv_fixed', model
            pcr_candidates.append((ks, use_bias))
    # Pass 2: PCR on the failed configurations (only while time remains).
    for ks, use_bias in pcr_candidates:
        if time.time() - started > time_budget:
            return None
        outcome = _build_and_validate_conv_fixed(
            _lstsq_conv_pcr, fit_exs, ks, use_bias, IH, IW, td, path, providers)
        if outcome is not None:
            return outcome
    return None
241
+
242
+
243
def _build_and_validate_conv_var(fit_fn, fit_exs, ks, use_bias, td, path, providers):
    """Fit a variable-shape conv with ``fit_fn``, export it to ONNX and validate.

    ``fit_fn`` is called as ``fit_fn(fit_exs, ks, use_bias, use_full_30=True)``
    and must return ``(Wconv, B)`` or None.  The exported graph applies a
    ks x ks Conv (padded by ``ks // 2`` on every side) directly to ``input``,
    takes ArgMax over axis 1, passes the result through ``add_onehot_block``
    (project helper) and multiplies it by ``mask``, the ReduceSum of
    ``input`` over axis 1 with keepdims.

    The model is saved to ``path`` before validation, so the file is
    overwritten even when validation fails.

    Returns ``(tag, model)`` if ``validate(path, td, providers)`` is truthy,
    with tag 'conv_var' when ``fit_fn`` is ``_lstsq_conv`` and
    'conv_var_pcr' otherwise; else None.
    """
    fitted = fit_fn(fit_exs, ks, use_bias, use_full_30=True)
    if fitted is None:
        return None
    Wconv, B = fitted
    half = ks // 2
    conv_inputs = ['input', 'W']
    inits = [
        numpy_helper.from_array(Wconv, 'W'),
        _make_int64_init('rs_axes_var', [1]),
    ]
    if B is not None:
        inits.append(numpy_helper.from_array(B, 'B'))
        conv_inputs.append('B')
    nodes = [
        helper.make_node('ReduceSum', ['input', 'rs_axes_var'], ['mask'], keepdims=1),
        helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks, ks],
                         pads=[half, half, half, half]),
        helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
    ]
    add_onehot_block(nodes, inits, 'am', 'oh_out')
    nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
    model = mk(nodes, inits)
    onnx.save(model, path)
    if not validate(path, td, providers):
        return None
    if fit_fn == _lstsq_conv:
        return 'conv_var', model
    return 'conv_var_pcr', model
271
+
272
+
273
def solve_conv_variable(td, path, providers, time_budget=30.0):
    """Variable-shape conv solver: raw least squares first, PCR as second pass.

    Applies only when every example from ``get_exs(td)`` has an output of the
    same shape as its input (shapes may differ between examples).  Odd
    kernel sizes 1..29 are tried without and then with a bias, fitting on
    the full GH x GW canvas (``use_full_30=True``).  Each successful fit is
    exported to ``path`` and validated; configurations that fitted but
    failed validation are retried with PCR.

    Args:
        td: task data, passed to the project loaders and to ``validate``.
        path: file the candidate ONNX model is written to (overwritten for
            every candidate).
        providers: forwarded to ``validate``.
        time_budget: wall-clock seconds after which the solver gives up.

    Returns:
        ``('conv_var', model)``, the result of the PCR helper, or None.
    """
    exs = get_exs(td)
    if any(inp.shape != out.shape for inp, out in exs):
        return None
    fit_exs = [(i, o) for i, o in get_exs_for_fitting_variable(td) if i.shape == o.shape]
    started = time.time()
    # Pass 1: raw lstsq.
    pcr_candidates = []
    for use_bias in (False, True):
        for ks in range(1, 30, 2):
            if time.time() - started > time_budget:
                return None
            fitted = _lstsq_conv(fit_exs, ks, use_bias, use_full_30=True)
            if fitted is None:
                continue
            Wconv, B = fitted
            half = ks // 2
            conv_inputs = ['input', 'W']
            inits = [
                numpy_helper.from_array(Wconv, 'W'),
                _make_int64_init('rs_axes_var', [1]),
            ]
            if B is not None:
                inits.append(numpy_helper.from_array(B, 'B'))
                conv_inputs.append('B')
            nodes = [
                helper.make_node('ReduceSum', ['input', 'rs_axes_var'], ['mask'], keepdims=1),
                helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks, ks],
                                 pads=[half, half, half, half]),
                helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
            ]
            add_onehot_block(nodes, inits, 'am', 'oh_out')
            nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
            model = mk(nodes, inits)
            onnx.save(model, path)
            if validate(path, td, providers):
                return 'conv_var', model
            pcr_candidates.append((ks, use_bias))
    # Pass 2: PCR on the configurations that fitted but failed validation.
    for ks, use_bias in pcr_candidates:
        if time.time() - started > time_budget:
            return None
        outcome = _build_and_validate_conv_var(
            _lstsq_conv_pcr, fit_exs, ks, use_bias, td, path, providers)
        if outcome is not None:
            return outcome
    return None
321
+
322
+
323
def solve_conv_diffshape(td, path, providers, time_budget=30.0):
    """Different-shape conv solver: raw least squares first, PCR second.

    Applies when ``fixed_shapes(td)`` reports one input shape ``(IH, IW)``
    and one output shape ``(OH, OW)`` with the output no larger than the
    input on either axis, not identical to it, and at most 30 x 30.  The
    output is modelled as a crop of a ks x ks convolution of the input,
    taken either at the top-left corner or centred.

    The exported graph slices ``input`` to ``[1, 10, IH, IW]``, applies the
    Conv (padded by ``ks // 2``), slices the crop window, takes ArgMax over
    axis 1 and hands the result to ``add_onehot_block`` / ``_build_pad_node``
    (project helpers).

    Args:
        td: task data, passed to the project loaders and to ``validate``.
        path: file the candidate ONNX model is written to (overwritten for
            every candidate).
        providers: forwarded to ``validate``.
        time_budget: wall-clock seconds after which the solver gives up.

    Returns:
        ``('conv_diff', model)``, ``('conv_diff_pcr', model)`` or None.
    """
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if IH == OH and IW == OW:
        return None
    if OH > IH or OW > IW:
        return None
    if OH > 30 or OW > 30:
        return None
    exs = get_exs(td)
    t_start = time.time()

    def _export_and_validate(WT, ks, use_bias, dr_off, dc_off, tag):
        """Write the cropped-conv model for WT to path; (tag, model) if valid."""
        Wconv, B = _extract_weights(WT, ks, use_bias)
        # The padding must come from THIS configuration's kernel size.  The
        # PCR pass previously reused whatever `pad` the last pass-1
        # iteration left behind, giving a Conv of the wrong geometry.
        pad = ks // 2
        pad_h, pad_w = GH - OH, GW - OW
        inits = [
            _make_int64_init('sl_st', [0, 0, 0, 0]),
            _make_int64_init('sl_en', [1, 10, IH, IW]),
            numpy_helper.from_array(Wconv, 'W'),
            _make_int64_init('cr_st', [0, 0, dr_off, dc_off]),
            _make_int64_init('cr_en', [1, 10, dr_off + OH, dc_off + OW]),
        ]
        conv_inputs = ['grid', 'W']
        if B is not None:
            inits.append(numpy_helper.from_array(B, 'B'))
            conv_inputs.append('B')
        nodes = [
            helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['grid']),
            helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks, ks], pads=[pad] * 4),
            helper.make_node('Slice', ['co', 'cr_st', 'cr_en'], ['co_crop']),
            helper.make_node('ArgMax', ['co_crop'], ['am'], axis=1, keepdims=1),
        ]
        add_onehot_block(nodes, inits, 'am', 'oh_out')
        nodes.append(_build_pad_node('oh_out', 'output', pad_h, pad_w, inits))
        model = mk(nodes, inits)
        onnx.save(model, path)
        if validate(path, td, providers):
            return tag, model
        return None

    # One-hot encode every input once; it does not depend on ks, bias or offset.
    colours = np.arange(10).reshape(10, 1, 1)
    encoded = []
    for inp_g, out_g in exs:
        oh_enc = np.zeros((10, IH, IW), dtype=np.float64)
        oh_enc[...] = (inp_g[np.newaxis] == colours)
        encoded.append((oh_enc, out_g))
    n_patches = len(encoded) * OH * OW

    # Crop offsets: top-left, plus centred when that is a different window.
    offsets = [(0, 0)]
    centred = ((IH - OH) // 2, (IW - OW) // 2)
    if centred != (0, 0):
        offsets.append(centred)

    failed_configs = []  # (P, T, T_oh, ks, use_bias, dr_off, dc_off) for PCR retry
    for dr_off, dc_off in offsets:
        for use_bias in (False, True):
            for ks in range(1, 22, 2):
                if time.time() - t_start > time_budget:
                    # Pass 2 would stop on the same check, so give up now.
                    return None
                feat = 10 * ks * ks + (1 if use_bias else 0)
                if feat > 10000:
                    continue
                if feat > 5000 and n_patches > 2000:
                    continue
                pad = ks // 2
                patches, targets = [], []
                for oh_enc, out_g in encoded:
                    oh_pad = np.pad(oh_enc, ((0, 0), (pad, pad), (pad, pad)))
                    for r in range(OH):
                        for c in range(OW):
                            # Both offsets keep (sr, sc) inside the input
                            # because OH <= IH and OW <= IW were checked above.
                            sr, sc = r + dr_off, c + dc_off
                            p = oh_pad[:, sr:sr + ks, sc:sc + ks].flatten()
                            if use_bias:
                                p = np.append(p, 1.0)
                            patches.append(p)
                            targets.append(int(out_g[r, c]))
                P = np.array(patches, dtype=np.float64)
                T = np.array(targets, dtype=np.int64)
                T_oh = np.zeros((len(T), 10), dtype=np.float64)
                T_oh[np.arange(len(T)), T] = 1.0
                # Pass 1: raw lstsq
                WT = _solve_weights(P, T, T_oh)
                if WT is None:
                    continue
                outcome = _export_and_validate(WT, ks, use_bias, dr_off, dc_off, 'conv_diff')
                if outcome is not None:
                    return outcome
                # Fitted the training patches but failed validation: PCR retry.
                failed_configs.append((P, T, T_oh, ks, use_bias, dr_off, dc_off))
    # Pass 2: PCR on failed configs
    for P, T, T_oh, ks, use_bias, dr_off, dc_off in failed_configs:
        if time.time() - t_start > time_budget:
            return None
        WT = _solve_weights_pcr(P, T, T_oh)
        if WT is None:
            continue
        outcome = _export_and_validate(WT, ks, use_bias, dr_off, dc_off, 'conv_diff_pcr')
        if outcome is not None:
            return outcome
    return None
443
+
444
+
445
def solve_conv_var_diff(td, path, providers, time_budget=30.0):
    """Variable diff-shape conv solver: raw least squares first, PCR second.

    Fits a ks x ks convolution on the full GH x GW one-hot canvas so that
    its argmax reproduces every output cell, for odd kernel sizes 1..29
    without and then with a bias.  A model is only ever exported when every
    example's output is no larger than its input on both axes, so the solver
    returns None immediately when that does not hold.

    The exported graph applies the Conv (padded by ``ks // 2``) directly to
    ``input``, takes ArgMax over axis 1, passes the result through
    ``add_onehot_block`` (project helper) and multiplies it by ``mask``, the
    ReduceSum of ``input`` over axis 1 with keepdims.

    Args:
        td: task data, passed to ``get_exs`` and to ``validate``.
        path: file the candidate ONNX model is written to (overwritten for
            every candidate).
        providers: forwarded to ``validate``.
        time_budget: wall-clock seconds after which the solver gives up.

    Returns:
        ``('conv_var_diff', model)``, ``('conv_var_diff_pcr', model)`` or None.
    """
    exs = get_exs(td)
    t_start = time.time()
    # Loop-invariant precondition for exporting any model: checking it first
    # avoids running every fit only to discard the result.
    all_output_within_input = all(
        out_g.shape[0] <= inp_g.shape[0] and out_g.shape[1] <= inp_g.shape[1]
        for inp_g, out_g in exs
    )
    if not all_output_within_input:
        return None

    def _export_and_validate(WT, ks, use_bias, tag):
        """Write the masked-conv model for WT to path; (tag, model) if valid."""
        Wconv, B = _extract_weights(WT, ks, use_bias)
        # The padding must come from THIS configuration's kernel size.  The
        # PCR pass previously reused whatever `pad` the last pass-1
        # iteration left behind, giving a Conv of the wrong geometry.
        pad = ks // 2
        inits = [
            numpy_helper.from_array(Wconv, 'W'),
            _make_int64_init('rs_axes_vd', [1]),
        ]
        conv_inputs = ['input', 'W']
        if B is not None:
            inits.append(numpy_helper.from_array(B, 'B'))
            conv_inputs.append('B')
        nodes = [
            helper.make_node('ReduceSum', ['input', 'rs_axes_vd'], ['mask'], keepdims=1),
            helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks, ks], pads=[pad] * 4),
            helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
        ]
        add_onehot_block(nodes, inits, 'am', 'oh_out')
        nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
        model = mk(nodes, inits)
        onnx.save(model, path)
        if validate(path, td, providers):
            return tag, model
        return None

    # One-hot encode every input onto the GH x GW canvas once; the encoding
    # does not depend on the kernel size or the bias flag.
    colours = np.arange(10).reshape(10, 1, 1)
    encoded = []
    n_patches = 0
    for inp_g, out_g in exs:
        ih, iw = inp_g.shape
        oh_full = np.zeros((10, GH, GW), dtype=np.float64)
        oh_full[:, :ih, :iw] = (inp_g[np.newaxis] == colours)
        encoded.append((oh_full, out_g))
        n_patches += out_g.shape[0] * out_g.shape[1]

    failed_configs = []  # (P, T, T_oh, ks, use_bias) for PCR retry
    for use_bias in (False, True):
        for ks in range(1, 30, 2):
            if time.time() - t_start > time_budget:
                # Pass 2 would stop on the same check, so give up now.
                return None
            feat = 10 * ks * ks + (1 if use_bias else 0)
            if feat > 20000:
                continue
            if feat > 5000 and n_patches > 2000:
                continue
            pad = ks // 2
            patches, targets = [], []
            for oh_full, out_g in encoded:
                oh, ow = out_g.shape
                oh_pad = np.pad(oh_full, ((0, 0), (pad, pad), (pad, pad)))
                for r in range(oh):
                    for c in range(ow):
                        p = oh_pad[:, r:r + ks, c:c + ks].flatten()
                        if use_bias:
                            p = np.append(p, 1.0)
                        patches.append(p)
                        targets.append(int(out_g[r, c]))
            P = np.array(patches, dtype=np.float64)
            T = np.array(targets, dtype=np.int64)
            T_oh = np.zeros((len(T), 10), dtype=np.float64)
            T_oh[np.arange(len(T)), T] = 1.0
            # Pass 1: raw lstsq
            WT = _solve_weights(P, T, T_oh)
            if WT is None:
                continue
            outcome = _export_and_validate(WT, ks, use_bias, 'conv_var_diff')
            if outcome is not None:
                return outcome
            # Fitted the training patches but failed validation: PCR retry.
            failed_configs.append((P, T, T_oh, ks, use_bias))
    # Pass 2: PCR on failed configs
    for P, T, T_oh, ks, use_bias in failed_configs:
        if time.time() - t_start > time_budget:
            return None
        WT = _solve_weights_pcr(P, T, T_oh)
        if WT is None:
            continue
        outcome = _export_and_validate(WT, ks, use_bias, 'conv_var_diff_pcr')
        if outcome is not None:
            return outcome
    return None