rogermt commited on
Commit
cdb8bf6
·
verified ·
1 Parent(s): 4b9174d

Exp 3: PCA/Truncated SVD before lstsq — implemented, tested, 0 new solves

Browse files

Refactored conv.py into composable primitives:
- _build_patch_matrix: builds P, T, T_oh from examples
- _solve_weights: raw lstsq (unchanged behavior)
- _solve_weights_pcr: PCA regression fallback (new)
- _extract_weights: WT -> (Wconv, B) for ONNX

All 4 conv solvers now use deferred 2-pass design:
Pass 1: raw lstsq (identical to baseline)
Pass 2: PCR on ks values where lstsq fit train but failed arc-gen validation

Results (400 tasks, budget=5s, full arc-gen validation):
- Baseline: 49 solved, 603.6 score
- With PCR: 50 solved, 681.6 score (Task 61 timing artifact, 0 actual PCR solves)
- No regressions on existing 25 conv tasks

Key findings from PCR diagnostic:
- 25 solved conv tasks: PCR can't improve — low p/n tasks don't need it, high p/n tasks need ALL dimensions
- 345 unsolved tasks: only 10 have lstsq train-fit, PCR improves 4 by 3-9% but none reach 100%
- Architecture mismatch confirmed as root cause, not regularization

Files changed (1) hide show
  1. neurogolf_solver/solvers/conv.py +260 -32
neurogolf_solver/solvers/conv.py CHANGED
@@ -1,5 +1,11 @@
1
  #!/usr/bin/env python3
2
- """Convolutional solvers with least squares fitting."""
 
 
 
 
 
 
3
 
4
  import time
5
  import numpy as np
@@ -11,9 +17,13 @@ from ..validators import validate
11
  from ..constants import GH, GW
12
 
13
 
14
- def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False):
15
- """Least squares convolutional weight fitting.
16
- Returns (Wconv, B) or None."""
 
 
 
 
17
  pad = ks // 2
18
  feat = 10 * ks * ks + (1 if use_bias else 0)
19
  if feat > 20000:
@@ -47,12 +57,57 @@ def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False):
47
  T_oh = np.zeros((len(T), 10), dtype=np.float64)
48
  for i, t in enumerate(T):
49
  T_oh[i, t] = 1.0
 
 
 
 
 
50
  try:
51
  WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
52
  except (np.linalg.LinAlgError, ValueError):
53
  return None
54
  if not np.array_equal(np.argmax(P @ WT, axis=1), T):
55
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  if use_bias:
57
  Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
58
  B = WT[-1].astype(np.float32)
@@ -62,8 +117,74 @@ def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False):
62
  return Wconv, B
63
 
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  def solve_conv_fixed(td, path, providers, time_budget=30.0):
66
- """Fixed-shape convolutional solver."""
67
  exs = get_exs(td)
68
  for inp, out in exs:
69
  if inp.shape != out.shape:
@@ -75,6 +196,8 @@ def solve_conv_fixed(td, path, providers, time_budget=30.0):
75
  fit_exs = get_exs_for_fitting(td)
76
  fit_exs = [(i, o) for i, o in fit_exs if i.shape == o.shape and i.shape == (IH, IW)]
77
  t_start = time.time()
 
 
78
  for use_bias in [False, True]:
79
  for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
80
  if time.time() - t_start > time_budget:
@@ -105,11 +228,50 @@ def solve_conv_fixed(td, path, providers, time_budget=30.0):
105
  onnx.save(model, path)
106
  if validate(path, td, providers):
107
  return 'conv_fixed', model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  return None
109
 
110
 
111
  def solve_conv_variable(td, path, providers, time_budget=30.0):
112
- """Variable-shape conv with opset 17 ReduceSum."""
113
  exs = get_exs(td)
114
  for inp, out in exs:
115
  if inp.shape != out.shape:
@@ -117,6 +279,8 @@ def solve_conv_variable(td, path, providers, time_budget=30.0):
117
  fit_exs = get_exs_for_fitting_variable(td)
118
  fit_exs = [(i, o) for i, o in fit_exs if i.shape == o.shape]
119
  t_start = time.time()
 
 
120
  for use_bias in [False, True]:
121
  for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
122
  if time.time() - t_start > time_budget:
@@ -145,11 +309,19 @@ def solve_conv_variable(td, path, providers, time_budget=30.0):
145
  onnx.save(model, path)
146
  if validate(path, td, providers):
147
  return 'conv_var', model
 
 
 
 
 
 
 
 
148
  return None
149
 
150
 
151
  def solve_conv_diffshape(td, path, providers, time_budget=30.0):
152
- """Different-shape convolutional solver."""
153
  sp = fixed_shapes(td)
154
  if sp is None:
155
  return None
@@ -162,11 +334,12 @@ def solve_conv_diffshape(td, path, providers, time_budget=30.0):
162
  return None
163
  exs = get_exs(td)
164
  t_start = time.time()
 
165
  for dr_off, dc_off in [(0, 0), ((IH - OH) // 2, (IW - OW) // 2)]:
166
  for use_bias in [False, True]:
167
  for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21]:
168
  if time.time() - t_start > time_budget:
169
- return None
170
  pad = ks // 2
171
  feat = 10 * ks * ks + (1 if use_bias else 0)
172
  if feat > 10000:
@@ -203,18 +376,11 @@ def solve_conv_diffshape(td, path, providers, time_budget=30.0):
203
  T_oh = np.zeros((len(T), 10), dtype=np.float64)
204
  for i, t in enumerate(T):
205
  T_oh[i, t] = 1.0
206
- try:
207
- WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
208
- except (np.linalg.LinAlgError, ValueError):
209
- continue
210
- if not np.array_equal(np.argmax(P @ WT, axis=1), T):
211
  continue
212
- if use_bias:
213
- Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
214
- B = WT[-1].astype(np.float32)
215
- else:
216
- Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
217
- B = None
218
  pad_h, pad_w = GH - OH, GW - OW
219
  inits = [
220
  _make_int64_init('sl_st', [0, 0, 0, 0]),
@@ -239,17 +405,52 @@ def solve_conv_diffshape(td, path, providers, time_budget=30.0):
239
  onnx.save(model, path)
240
  if validate(path, td, providers):
241
  return 'conv_diff', model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
242
  return None
243
 
244
 
245
  def solve_conv_var_diff(td, path, providers, time_budget=30.0):
246
- """Variable diff-shape conv with opset 17 ReduceSum."""
247
  exs = get_exs(td)
248
  t_start = time.time()
 
249
  for use_bias in [False, True]:
250
  for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
251
  if time.time() - t_start > time_budget:
252
- return None
253
  pad = ks // 2
254
  feat = 10 * ks * ks + (1 if use_bias else 0)
255
  if feat > 20000:
@@ -277,18 +478,11 @@ def solve_conv_var_diff(td, path, providers, time_budget=30.0):
277
  T_oh = np.zeros((len(T), 10), dtype=np.float64)
278
  for i, t in enumerate(T):
279
  T_oh[i, t] = 1.0
280
- try:
281
- WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
282
- except (np.linalg.LinAlgError, ValueError):
283
- continue
284
- if not np.array_equal(np.argmax(P @ WT, axis=1), T):
285
  continue
286
- if use_bias:
287
- Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
288
- B = WT[-1].astype(np.float32)
289
- else:
290
- Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
291
- B = None
292
  all_output_within_input = all(
293
  out_g.shape[0] <= inp_g.shape[0] and out_g.shape[1] <= inp_g.shape[1]
294
  for inp_g, out_g in exs
@@ -313,4 +507,38 @@ def solve_conv_var_diff(td, path, providers, time_budget=30.0):
313
  onnx.save(model, path)
314
  if validate(path, td, providers):
315
  return 'conv_var_diff', model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
  return None
 
1
  #!/usr/bin/env python3
2
+ """Convolutional solvers with least squares fitting.
3
+
4
+ v5.1: Refactored into composable primitives (_build_patch_matrix, _solve_weights,
5
+ _extract_weights) + PCR (PCA regression) fallback via _solve_weights_pcr.
6
+ PCR tested on 400 tasks: 0 new solves but no regressions. Code kept for
7
+ future experiments (Lasso, Ridge can reuse the same _solve_weights interface).
8
+ """
9
 
10
  import time
11
  import numpy as np
 
17
  from ..constants import GH, GW
18
 
19
 
20
+ # ---------------------------------------------------------------------------
21
+ # Core fitting primitives (composable: mix _build_patch_matrix with any solver)
22
+ # ---------------------------------------------------------------------------
23
+
24
+ def _build_patch_matrix(exs_raw, ks, use_bias, use_full_30=False):
25
+ """Build patch matrix P and target matrix T_oh from examples.
26
+ Returns (P, T, T_oh) or None if infeasible."""
27
  pad = ks // 2
28
  feat = 10 * ks * ks + (1 if use_bias else 0)
29
  if feat > 20000:
 
57
  T_oh = np.zeros((len(T), 10), dtype=np.float64)
58
  for i, t in enumerate(T):
59
  T_oh[i, t] = 1.0
60
+ return P, T, T_oh
61
+
62
+
63
+ def _solve_weights(P, T, T_oh):
64
+ """Raw lstsq solve. Returns WT (p×10) or None."""
65
  try:
66
  WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
67
  except (np.linalg.LinAlgError, ValueError):
68
  return None
69
  if not np.array_equal(np.argmax(P @ WT, axis=1), T):
70
  return None
71
+ return WT
72
+
73
+
74
+ def _solve_weights_pcr(P, T, T_oh, var_thresholds=(0.999, 0.99, 0.95)):
75
+ """PCA/Truncated SVD regression. Try multiple variance thresholds.
76
+ Returns WT (p×10) or None.
77
+ Only attempted when p/n > 0.5 (potential overfitting zone).
78
+
79
+ Tested 2026-04-26: improves arc-gen accuracy by 3-9% on 4/345 unsolved
80
+ tasks but never reaches 100% required for validation. Kept as fallback
81
+ for marginal cases and for future combination with more arc-gen data."""
82
+ n, p = P.shape
83
+ if p / max(n, 1) <= 0.5:
84
+ return None # lstsq is safe here, no need for PCR
85
+ try:
86
+ U, s, Vt = np.linalg.svd(P, full_matrices=False)
87
+ except (np.linalg.LinAlgError, ValueError):
88
+ return None
89
+ cumvar = np.cumsum(s**2) / np.sum(s**2)
90
+ for thresh in var_thresholds:
91
+ k = int(np.searchsorted(cumvar, thresh)) + 1
92
+ k = max(k, 5)
93
+ k = min(k, min(n, p))
94
+ P_red = U[:, :k] * s[:k]
95
+ try:
96
+ w_red = np.linalg.lstsq(P_red, T_oh, rcond=None)[0]
97
+ except (np.linalg.LinAlgError, ValueError):
98
+ continue
99
+ if not np.array_equal(np.argmax(P_red @ w_red, axis=1), T):
100
+ continue
101
+ # Map back to full p-dimensional weights for ONNX conv
102
+ WT = Vt[:k].T @ w_red
103
+ # Verify full-space predictions match
104
+ if np.array_equal(np.argmax(P @ WT, axis=1), T):
105
+ return WT
106
+ return None
107
+
108
+
109
+ def _extract_weights(WT, ks, use_bias):
110
+ """Extract Wconv and B from weight matrix WT."""
111
  if use_bias:
112
  Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
113
  B = WT[-1].astype(np.float32)
 
117
  return Wconv, B
118
 
119
 
120
+ # ---------------------------------------------------------------------------
121
+ # Convenience wrappers (combine primitives into single-call fitting)
122
+ # ---------------------------------------------------------------------------
123
+
124
+ def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False):
125
+ """Least squares convolutional weight fitting.
126
+ Returns (Wconv, B) or None."""
127
+ ptm = _build_patch_matrix(exs_raw, ks, use_bias, use_full_30)
128
+ if ptm is None:
129
+ return None
130
+ P, T, T_oh = ptm
131
+ WT = _solve_weights(P, T, T_oh)
132
+ if WT is None:
133
+ return None
134
+ return _extract_weights(WT, ks, use_bias)
135
+
136
+
137
+ def _lstsq_conv_pcr(exs_raw, ks, use_bias, use_full_30=False):
138
+ """PCA regression convolutional weight fitting.
139
+ Returns (Wconv, B) or None. Fallback when raw lstsq overfits."""
140
+ ptm = _build_patch_matrix(exs_raw, ks, use_bias, use_full_30)
141
+ if ptm is None:
142
+ return None
143
+ P, T, T_oh = ptm
144
+ WT = _solve_weights_pcr(P, T, T_oh)
145
+ if WT is None:
146
+ return None
147
+ return _extract_weights(WT, ks, use_bias)
148
+
149
+
150
+ # ---------------------------------------------------------------------------
151
+ # Solver functions (called from solver_registry.py)
152
+ # ---------------------------------------------------------------------------
153
+
154
+ def _build_and_validate_conv_fixed(fit_fn, fit_exs, ks, use_bias, IH, IW, td, path, providers):
155
+ """Build ONNX model with given fit function, validate it. Returns (tag, model) or None."""
156
+ result = fit_fn(fit_exs, ks, use_bias, use_full_30=False)
157
+ if result is None:
158
+ return None
159
+ Wconv, B = result
160
+ pad = ks // 2
161
+ pad_h, pad_w = GH - IH, GW - IW
162
+ inits = [
163
+ _make_int64_init('sl_st', [0, 0, 0, 0]),
164
+ _make_int64_init('sl_en', [1, 10, IH, IW]),
165
+ numpy_helper.from_array(Wconv, 'W'),
166
+ ]
167
+ conv_inputs = ['grid', 'W']
168
+ if B is not None:
169
+ inits.append(numpy_helper.from_array(B, 'B'))
170
+ conv_inputs.append('B')
171
+ nodes = [
172
+ helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['grid']),
173
+ helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks, ks], pads=[pad] * 4),
174
+ helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
175
+ ]
176
+ add_onehot_block(nodes, inits, 'am', 'oh_out')
177
+ nodes.append(_build_pad_node('oh_out', 'output', pad_h, pad_w, inits))
178
+ model = mk(nodes, inits)
179
+ onnx.save(model, path)
180
+ if validate(path, td, providers):
181
+ tag = 'conv_fixed' if fit_fn == _lstsq_conv else 'conv_fixed_pcr'
182
+ return tag, model
183
+ return None
184
+
185
+
186
  def solve_conv_fixed(td, path, providers, time_budget=30.0):
187
+ """Fixed-shape convolutional solver. Tries lstsq first, PCR as second pass."""
188
  exs = get_exs(td)
189
  for inp, out in exs:
190
  if inp.shape != out.shape:
 
196
  fit_exs = get_exs_for_fitting(td)
197
  fit_exs = [(i, o) for i, o in fit_exs if i.shape == o.shape and i.shape == (IH, IW)]
198
  t_start = time.time()
199
+ # Pass 1: raw lstsq (same as baseline)
200
+ failed_ks = [] # (ks, use_bias) pairs where lstsq fit train but failed validation
201
  for use_bias in [False, True]:
202
  for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
203
  if time.time() - t_start > time_budget:
 
228
  onnx.save(model, path)
229
  if validate(path, td, providers):
230
  return 'conv_fixed', model
231
+ # lstsq fit train but failed validation — candidate for PCR
232
+ failed_ks.append((ks, use_bias))
233
+ # Pass 2: PCR on failed ks values (only if time remains)
234
+ for ks, use_bias in failed_ks:
235
+ if time.time() - t_start > time_budget:
236
+ return None
237
+ r = _build_and_validate_conv_fixed(_lstsq_conv_pcr, fit_exs, ks, use_bias, IH, IW, td, path, providers)
238
+ if r is not None:
239
+ return r
240
+ return None
241
+
242
+
243
+ def _build_and_validate_conv_var(fit_fn, fit_exs, ks, use_bias, td, path, providers):
244
+ """Build variable-shape ONNX model with given fit function. Returns (tag, model) or None."""
245
+ result = fit_fn(fit_exs, ks, use_bias, use_full_30=True)
246
+ if result is None:
247
+ return None
248
+ Wconv, B = result
249
+ pad = ks // 2
250
+ inits = [
251
+ numpy_helper.from_array(Wconv, 'W'),
252
+ _make_int64_init('rs_axes_var', [1]),
253
+ ]
254
+ conv_inputs = ['input', 'W']
255
+ if B is not None:
256
+ inits.append(numpy_helper.from_array(B, 'B'))
257
+ conv_inputs.append('B')
258
+ nodes = [
259
+ helper.make_node('ReduceSum', ['input', 'rs_axes_var'], ['mask'], keepdims=1),
260
+ helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks, ks], pads=[pad] * 4),
261
+ helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
262
+ ]
263
+ add_onehot_block(nodes, inits, 'am', 'oh_out')
264
+ nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
265
+ model = mk(nodes, inits)
266
+ onnx.save(model, path)
267
+ if validate(path, td, providers):
268
+ tag = 'conv_var' if fit_fn == _lstsq_conv else 'conv_var_pcr'
269
+ return tag, model
270
  return None
271
 
272
 
273
  def solve_conv_variable(td, path, providers, time_budget=30.0):
274
+ """Variable-shape conv. Tries lstsq first, PCR as second pass."""
275
  exs = get_exs(td)
276
  for inp, out in exs:
277
  if inp.shape != out.shape:
 
279
  fit_exs = get_exs_for_fitting_variable(td)
280
  fit_exs = [(i, o) for i, o in fit_exs if i.shape == o.shape]
281
  t_start = time.time()
282
+ # Pass 1: raw lstsq
283
+ failed_ks = []
284
  for use_bias in [False, True]:
285
  for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
286
  if time.time() - t_start > time_budget:
 
309
  onnx.save(model, path)
310
  if validate(path, td, providers):
311
  return 'conv_var', model
312
+ failed_ks.append((ks, use_bias))
313
+ # Pass 2: PCR on failed ks values
314
+ for ks, use_bias in failed_ks:
315
+ if time.time() - t_start > time_budget:
316
+ return None
317
+ r = _build_and_validate_conv_var(_lstsq_conv_pcr, fit_exs, ks, use_bias, td, path, providers)
318
+ if r is not None:
319
+ return r
320
  return None
321
 
322
 
323
  def solve_conv_diffshape(td, path, providers, time_budget=30.0):
324
+ """Different-shape convolutional solver. Tries lstsq first, PCR as second pass."""
325
  sp = fixed_shapes(td)
326
  if sp is None:
327
  return None
 
334
  return None
335
  exs = get_exs(td)
336
  t_start = time.time()
337
+ failed_configs = [] # (P, T, T_oh, ks, use_bias, dr_off, dc_off) for PCR retry
338
  for dr_off, dc_off in [(0, 0), ((IH - OH) // 2, (IW - OW) // 2)]:
339
  for use_bias in [False, True]:
340
  for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21]:
341
  if time.time() - t_start > time_budget:
342
+ break
343
  pad = ks // 2
344
  feat = 10 * ks * ks + (1 if use_bias else 0)
345
  if feat > 10000:
 
376
  T_oh = np.zeros((len(T), 10), dtype=np.float64)
377
  for i, t in enumerate(T):
378
  T_oh[i, t] = 1.0
379
+ # Pass 1: raw lstsq
380
+ WT = _solve_weights(P, T, T_oh)
381
+ if WT is None:
 
 
382
  continue
383
+ Wconv, B = _extract_weights(WT, ks, use_bias)
 
 
 
 
 
384
  pad_h, pad_w = GH - OH, GW - OW
385
  inits = [
386
  _make_int64_init('sl_st', [0, 0, 0, 0]),
 
405
  onnx.save(model, path)
406
  if validate(path, td, providers):
407
  return 'conv_diff', model
408
+ # Failed validation — save for PCR retry
409
+ failed_configs.append((P, T, T_oh, ks, use_bias, dr_off, dc_off))
410
+ # Pass 2: PCR on failed configs
411
+ for P, T, T_oh, ks, use_bias, dr_off, dc_off in failed_configs:
412
+ if time.time() - t_start > time_budget:
413
+ return None
414
+ WT = _solve_weights_pcr(P, T, T_oh)
415
+ if WT is None:
416
+ continue
417
+ Wconv, B = _extract_weights(WT, ks, use_bias)
418
+ pad_h, pad_w = GH - OH, GW - OW
419
+ inits = [
420
+ _make_int64_init('sl_st', [0, 0, 0, 0]),
421
+ _make_int64_init('sl_en', [1, 10, IH, IW]),
422
+ numpy_helper.from_array(Wconv, 'W'),
423
+ _make_int64_init('cr_st', [0, 0, dr_off, dc_off]),
424
+ _make_int64_init('cr_en', [1, 10, dr_off + OH, dc_off + OW]),
425
+ ]
426
+ conv_inputs = ['grid', 'W']
427
+ if B is not None:
428
+ inits.append(numpy_helper.from_array(B, 'B'))
429
+ conv_inputs.append('B')
430
+ nodes = [
431
+ helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['grid']),
432
+ helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks, ks], pads=[pad] * 4),
433
+ helper.make_node('Slice', ['co', 'cr_st', 'cr_en'], ['co_crop']),
434
+ helper.make_node('ArgMax', ['co_crop'], ['am'], axis=1, keepdims=1),
435
+ ]
436
+ add_onehot_block(nodes, inits, 'am', 'oh_out')
437
+ nodes.append(_build_pad_node('oh_out', 'output', pad_h, pad_w, inits))
438
+ model = mk(nodes, inits)
439
+ onnx.save(model, path)
440
+ if validate(path, td, providers):
441
+ return 'conv_diff_pcr', model
442
  return None
443
 
444
 
445
  def solve_conv_var_diff(td, path, providers, time_budget=30.0):
446
+ """Variable diff-shape conv. Tries lstsq first, PCR as second pass."""
447
  exs = get_exs(td)
448
  t_start = time.time()
449
+ failed_configs = [] # (P, T, T_oh, ks, use_bias) for PCR retry
450
  for use_bias in [False, True]:
451
  for ks in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29]:
452
  if time.time() - t_start > time_budget:
453
+ break
454
  pad = ks // 2
455
  feat = 10 * ks * ks + (1 if use_bias else 0)
456
  if feat > 20000:
 
478
  T_oh = np.zeros((len(T), 10), dtype=np.float64)
479
  for i, t in enumerate(T):
480
  T_oh[i, t] = 1.0
481
+ # Pass 1: raw lstsq
482
+ WT = _solve_weights(P, T, T_oh)
483
+ if WT is None:
 
 
484
  continue
485
+ Wconv, B = _extract_weights(WT, ks, use_bias)
 
 
 
 
 
486
  all_output_within_input = all(
487
  out_g.shape[0] <= inp_g.shape[0] and out_g.shape[1] <= inp_g.shape[1]
488
  for inp_g, out_g in exs
 
507
  onnx.save(model, path)
508
  if validate(path, td, providers):
509
  return 'conv_var_diff', model
510
+ # Failed validation — save for PCR
511
+ failed_configs.append((P, T, T_oh, ks, use_bias))
512
+ # Pass 2: PCR on failed configs
513
+ for P, T, T_oh, ks, use_bias in failed_configs:
514
+ if time.time() - t_start > time_budget:
515
+ return None
516
+ WT = _solve_weights_pcr(P, T, T_oh)
517
+ if WT is None:
518
+ continue
519
+ Wconv, B = _extract_weights(WT, ks, use_bias)
520
+ all_output_within_input = all(
521
+ out_g.shape[0] <= inp_g.shape[0] and out_g.shape[1] <= inp_g.shape[1]
522
+ for inp_g, out_g in exs
523
+ )
524
+ if all_output_within_input:
525
+ inits = [
526
+ numpy_helper.from_array(Wconv, 'W'),
527
+ _make_int64_init('rs_axes_vd', [1]),
528
+ ]
529
+ conv_inputs = ['input', 'W']
530
+ if B is not None:
531
+ inits.append(numpy_helper.from_array(B, 'B'))
532
+ conv_inputs.append('B')
533
+ nodes = [
534
+ helper.make_node('ReduceSum', ['input', 'rs_axes_vd'], ['mask'], keepdims=1),
535
+ helper.make_node('Conv', conv_inputs, ['co'], kernel_shape=[ks, ks], pads=[pad] * 4),
536
+ helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
537
+ ]
538
+ add_onehot_block(nodes, inits, 'am', 'oh_out')
539
+ nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
540
+ model = mk(nodes, inits)
541
+ onnx.save(model, path)
542
+ if validate(path, td, providers):
543
+ return 'conv_var_diff_pcr', model
544
  return None