#!/usr/bin/env python3
"""Convolutional solvers with least squares fitting.

v5.1: Refactored into composable primitives (_build_patch_matrix,
_solve_weights, _extract_weights) + PCR (PCA regression) fallback via
_solve_weights_pcr. PCR tested on 400 tasks: 0 new solves but no regressions.
Code kept for future experiments (Lasso, Ridge can reuse the same
_solve_weights interface).
"""
import time

import numpy as np
import onnx
from onnx import helper, numpy_helper

from ..onnx_helpers import mk, _make_int64_init, _build_pad_node, add_onehot_block
from ..data_loader import get_exs, get_exs_for_fitting, get_exs_for_fitting_variable, fixed_shapes
from ..validators import validate
from ..constants import GH, GW

# Odd kernel sizes tried by the solvers. Odd so that pads = ks // 2 on every
# side keeps the conv output the same spatial size as its input.
_KERNEL_SIZES = (1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29)
_KERNEL_SIZES_DIFF = (1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21)


# ---------------------------------------------------------------------------
# Core fitting primitives (composable: mix _build_patch_matrix with any solver)
# ---------------------------------------------------------------------------

def _pcr_applicable(n_rows, n_cols):
    """Return True when PCR is worth trying for an ``n_rows x n_cols`` system.

    PCR is only attempted when features/samples > 0.5 (the potential
    overfitting zone); below that, raw lstsq is considered safe.
    """
    return n_cols / max(n_rows, 1) > 0.5


def _build_patch_matrix(exs_raw, ks, use_bias, use_full_30=False,
                        offset=(0, 0), max_feat=20000):
    """Build the patch matrix ``P`` and targets ``T`` / ``T_oh`` from examples.

    Each input grid is one-hot encoded into 10 colour planes (placed on a
    GH x GW canvas when ``use_full_30`` is true) and zero-padded by
    ``ks // 2``. For every output cell ``(r, c)``, the ``ks x ks`` window
    centred on input cell ``(r + offset[0], c + offset[1])`` is flattened
    (channel, row, col order) into one row of ``P``.

    Args:
        exs_raw: iterable of ``(input_grid, output_grid)`` 2-D integer arrays.
        ks: odd kernel size.
        use_bias: append a constant 1.0 column (bias term) to every row.
        use_full_30: encode onto the full GH x GW canvas instead of the
            input's own extent.
        offset: ``(row, col)`` shift from output coordinates to the window
            centre in the input (default ``(0, 0)``).
        max_feat: give up when the per-row feature count exceeds this.

    Returns:
        ``(P, T, T_oh)``: float64 ``(n, feat)``, int64 ``(n,)`` and float64
        ``(n, 10)``. Returns None when the system is too large, there are no
        output cells, or a window centre would fall outside the encoded input.
    """
    pad = ks // 2
    feat = 10 * ks * ks + (1 if use_bias else 0)
    if feat > max_feat:
        return None
    dr, dc = offset
    colours = np.arange(10).reshape(10, 1, 1)
    patches, targets = [], []
    for inp_g, out_g in exs_raw:
        ih, iw = inp_g.shape
        oh, ow = out_g.shape
        planes = (colours == inp_g[np.newaxis]).astype(np.float64)
        if use_full_30:
            canvas = np.zeros((10, GH, GW), dtype=np.float64)
            canvas[:, :ih, :iw] = planes
            planes = canvas
        src_h, src_w = planes.shape[1:]
        # Every window centre must lie inside the encoded input; otherwise the
        # slice below is truncated and rows end up with unequal lengths.
        if dr < 0 or dc < 0 or dr + oh > src_h or dc + ow > src_w:
            return None
        padded = np.pad(planes, ((0, 0), (pad, pad), (pad, pad)))
        for r in range(oh):
            sr = r + dr
            for c in range(ow):
                sc = c + dc
                p = padded[:, sr:sr + ks, sc:sc + ks].flatten()
                if use_bias:
                    p = np.append(p, 1.0)
                patches.append(p)
                targets.append(int(out_g[r, c]))
    if not patches:
        return None
    # Wide *and* tall systems are too slow to solve; skip them.
    if feat > 5000 and len(patches) > 2000:
        return None
    P = np.array(patches, dtype=np.float64)
    T = np.array(targets, dtype=np.int64)
    T_oh = np.eye(10, dtype=np.float64)[T]
    return P, T, T_oh


def _solve_weights(P, T, T_oh):
    """Fit ``P @ WT ~ T_oh`` by ordinary least squares.

    Returns WT (p x 10) only if argmax of the fitted scores reproduces every
    training target in ``T`` exactly; returns None otherwise or if lstsq fails.
    """
    try:
        WT = np.linalg.lstsq(P, T_oh, rcond=None)[0]
    except (np.linalg.LinAlgError, ValueError):
        return None
    if not np.array_equal(np.argmax(P @ WT, axis=1), T):
        return None
    return WT


def _solve_weights_pcr(P, T, T_oh, var_thresholds=(0.999, 0.99, 0.95)):
    """PCA / truncated-SVD regression. Try multiple variance thresholds.

    Returns WT (p x 10) or None. Only attempted when p/n > 0.5 (potential
    overfitting zone); otherwise returns None immediately.

    For each threshold, keeps the smallest number ``k`` of singular directions
    whose cumulative energy reaches the threshold (clamped to
    ``[5, min(n, p)]``). It solves lstsq in that reduced space and maps the
    weights back to the full p-dimensional space. A solution is returned only
    if argmax reproduces every training target in both spaces.

    Tested 2026-04-26: improves arc-gen accuracy by 3-9% on 4/345 unsolved
    tasks but never reaches 100% required for validation. Kept as fallback for
    marginal cases and for future combination with more arc-gen data.
    """
    n, p = P.shape
    if not _pcr_applicable(n, p):
        return None  # lstsq is safe here, no need for PCR
    try:
        U, s, Vt = np.linalg.svd(P, full_matrices=False)
    except (np.linalg.LinAlgError, ValueError):
        return None
    energy = s ** 2
    total = energy.sum()
    if total <= 0:
        return None  # all-zero patch matrix: nothing to project onto
    cumvar = np.cumsum(energy) / total
    tried = set()
    for thresh in var_thresholds:
        k = int(np.searchsorted(cumvar, thresh)) + 1
        k = min(max(k, 5), min(n, p))
        if k in tried:
            continue  # this truncation level has already failed
        tried.add(k)
        P_red = U[:, :k] * s[:k]
        try:
            w_red = np.linalg.lstsq(P_red, T_oh, rcond=None)[0]
        except (np.linalg.LinAlgError, ValueError):
            continue
        if not np.array_equal(np.argmax(P_red @ w_red, axis=1), T):
            continue
        # Map back to full p-dimensional weights for the ONNX conv.
        WT = Vt[:k].T @ w_red
        # Verify full-space predictions match.
        if np.array_equal(np.argmax(P @ WT, axis=1), T):
            return WT
    return None


def _extract_weights(WT, ks, use_bias):
    """Reshape WT (p x 10) into conv weights and optional bias.

    Returns ``(Wconv, B)``: Wconv is float32 ``(10, 10, ks, ks)`` (out channel,
    in channel, kh, kw, matching the patch flatten order). B is float32
    ``(10,)`` taken from the last row of WT when ``use_bias``, else None.
    """
    if use_bias:
        Wconv = WT[:-1].T.reshape(10, 10, ks, ks).astype(np.float32)
        B = WT[-1].astype(np.float32)
    else:
        Wconv = WT.T.reshape(10, 10, ks, ks).astype(np.float32)
        B = None
    return Wconv, B


# ---------------------------------------------------------------------------
# Convenience wrappers (combine primitives into single-call fitting)
# ---------------------------------------------------------------------------

def _fit_conv(solver, exs_raw, ks, use_bias, use_full_30):
    """Build the patch matrix, solve it with ``solver``; return (Wconv, B) or None."""
    ptm = _build_patch_matrix(exs_raw, ks, use_bias, use_full_30)
    if ptm is None:
        return None
    WT = solver(*ptm)
    if WT is None:
        return None
    return _extract_weights(WT, ks, use_bias)


def _lstsq_conv(exs_raw, ks, use_bias, use_full_30=False):
    """Least squares convolutional weight fitting. Returns (Wconv, B) or None."""
    return _fit_conv(_solve_weights, exs_raw, ks, use_bias, use_full_30)


def _lstsq_conv_pcr(exs_raw, ks, use_bias, use_full_30=False):
    """PCA regression convolutional weight fitting. Returns (Wconv, B) or None.

    Fallback when raw lstsq overfits.
    """
    return _fit_conv(_solve_weights_pcr, exs_raw, ks, use_bias, use_full_30)


# ---------------------------------------------------------------------------
# ONNX model builders
# ---------------------------------------------------------------------------

def _make_fixed_conv_model(Wconv, B, ks, IH, IW, crop=None):
    """Build the fixed-shape conv model.

    Graph: Slice input to [1, 10, IH, IW] -> Conv (pads = ks // 2) ->
    [optional Slice crop] -> ArgMax over channels -> add_onehot_block ->
    _build_pad_node with pads (GH - OH, GW - OW).

    Args:
        crop: None, or ``(dr_off, dc_off, OH, OW)`` to keep only the
            ``OH x OW`` window of conv scores starting at ``(dr_off, dc_off)``.
            Without a crop the output size is ``(IH, IW)``.
    """
    pad = ks // 2
    inits = [
        _make_int64_init('sl_st', [0, 0, 0, 0]),
        _make_int64_init('sl_en', [1, 10, IH, IW]),
        numpy_helper.from_array(Wconv, 'W'),
    ]
    if crop is None:
        OH, OW = IH, IW
    else:
        dr_off, dc_off, OH, OW = crop
        inits.append(_make_int64_init('cr_st', [0, 0, dr_off, dc_off]))
        inits.append(_make_int64_init('cr_en', [1, 10, dr_off + OH, dc_off + OW]))
    conv_inputs = ['grid', 'W']
    if B is not None:
        inits.append(numpy_helper.from_array(B, 'B'))
        conv_inputs.append('B')
    nodes = [
        helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['grid']),
        helper.make_node('Conv', conv_inputs, ['co'],
                         kernel_shape=[ks, ks], pads=[pad] * 4),
    ]
    scores = 'co'
    if crop is not None:
        nodes.append(helper.make_node('Slice', ['co', 'cr_st', 'cr_en'], ['co_crop']))
        scores = 'co_crop'
    nodes.append(helper.make_node('ArgMax', [scores], ['am'], axis=1, keepdims=1))
    add_onehot_block(nodes, inits, 'am', 'oh_out')
    nodes.append(_build_pad_node('oh_out', 'output', GH - OH, GW - OW, inits))
    return mk(nodes, inits)


def _make_masked_conv_model(Wconv, B, ks, axes_name):
    """Build the variable-shape conv model operating on the full input canvas.

    Graph: Conv (pads = ks // 2) on the whole input -> ArgMax over channels ->
    add_onehot_block -> Mul by ``mask``, where mask is the input summed over
    the channel axis (keepdims). The mask is presumably 1 inside the grid and
    0 outside for one-hot input (TODO confirm the input encoding).

    ``axes_name`` is the initializer name used for the ReduceSum axes tensor.
    """
    pad = ks // 2
    inits = [
        numpy_helper.from_array(Wconv, 'W'),
        _make_int64_init(axes_name, [1]),
    ]
    conv_inputs = ['input', 'W']
    if B is not None:
        inits.append(numpy_helper.from_array(B, 'B'))
        conv_inputs.append('B')
    nodes = [
        helper.make_node('ReduceSum', ['input', axes_name], ['mask'], keepdims=1),
        helper.make_node('Conv', conv_inputs, ['co'],
                         kernel_shape=[ks, ks], pads=[pad] * 4),
        helper.make_node('ArgMax', ['co'], ['am'], axis=1, keepdims=1),
    ]
    add_onehot_block(nodes, inits, 'am', 'oh_out')
    nodes.append(helper.make_node('Mul', ['oh_out', 'mask'], ['output']))
    return mk(nodes, inits)


def _save_and_validate(model, td, path, providers):
    """Save ``model`` to ``path`` and return ``validate(path, td, providers)``."""
    onnx.save(model, path)
    return validate(path, td, providers)


def _two_pass_conv_search(configs, build_ptm, build_model, td, path, providers,
                          time_budget, tag):
    """Shared search loop used by every conv solver.

    Pass 1 fits each config with raw lstsq and validates the resulting model.
    Pass 2 retries configs with PCR. It only retries configs whose lstsq fit
    reproduced the training targets but failed validation, and for which PCR
    is applicable. The patch matrix is rebuilt in pass 2 instead of being kept
    in memory, because each one can be hundreds of MB.

    Args:
        configs: sequence of tuples whose first two items are ``(ks, use_bias)``.
        build_ptm: ``cfg -> (P, T, T_oh) or None``.
        build_model: ``(cfg, Wconv, B) -> onnx model``.
        tag: result tag for pass 1; pass 2 uses ``tag + '_pcr'``.

    Returns:
        ``(tag, model)`` for the first model that validates. Returns None if
        none validates or once ``time_budget`` seconds have elapsed.
    """
    t_start = time.time()

    def out_of_time():
        return time.time() - t_start > time_budget

    pcr_retry = []
    # Pass 1: raw lstsq.
    for cfg in configs:
        if out_of_time():
            return None
        ptm = build_ptm(cfg)
        if ptm is None:
            continue
        WT = _solve_weights(*ptm)
        if WT is None:
            continue
        model = build_model(cfg, *_extract_weights(WT, cfg[0], cfg[1]))
        if _save_and_validate(model, td, path, providers):
            return tag, model
        # Fit train but failed validation: PCR candidate (if it would run).
        if _pcr_applicable(*ptm[0].shape):
            pcr_retry.append(cfg)

    # Pass 2: PCR on configs that fit train but failed validation.
    for cfg in pcr_retry:
        if out_of_time():
            return None
        ptm = build_ptm(cfg)
        if ptm is None:
            continue
        WT = _solve_weights_pcr(*ptm)
        if WT is None:
            continue
        model = build_model(cfg, *_extract_weights(WT, cfg[0], cfg[1]))
        if _save_and_validate(model, td, path, providers):
            return tag + '_pcr', model
    return None


# ---------------------------------------------------------------------------
# Solver functions (called from solver_registry.py)
# ---------------------------------------------------------------------------

def _build_and_validate_conv_fixed(fit_fn, fit_exs, ks, use_bias, IH, IW,
                                   td, path, providers):
    """Build a fixed-shape ONNX model with ``fit_fn`` and validate it.

    Returns ``(tag, model)`` or None. The tag is 'conv_fixed' for
    ``_lstsq_conv`` and 'conv_fixed_pcr' for any other fit function.
    """
    result = fit_fn(fit_exs, ks, use_bias, use_full_30=False)
    if result is None:
        return None
    Wconv, B = result
    model = _make_fixed_conv_model(Wconv, B, ks, IH, IW)
    if _save_and_validate(model, td, path, providers):
        tag = 'conv_fixed' if fit_fn is _lstsq_conv else 'conv_fixed_pcr'
        return tag, model
    return None


def solve_conv_fixed(td, path, providers, time_budget=30.0):
    """Fixed-shape convolutional solver. Tries lstsq first, PCR as second pass.

    Applies only when every example has input shape == output shape and all
    inputs share a single shape (IH, IW).
    """
    exs = get_exs(td)
    if any(inp.shape != out.shape for inp, out in exs):
        return None
    shapes = {inp.shape for inp, _ in exs}
    if len(shapes) != 1:
        return None
    IH, IW = shapes.pop()
    fit_exs = [(i, o) for i, o in get_exs_for_fitting(td)
               if i.shape == o.shape == (IH, IW)]

    def build_ptm(cfg):
        ks, use_bias = cfg
        return _build_patch_matrix(fit_exs, ks, use_bias, use_full_30=False)

    def build_model(cfg, Wconv, B):
        return _make_fixed_conv_model(Wconv, B, cfg[0], IH, IW)

    configs = [(ks, use_bias) for use_bias in (False, True) for ks in _KERNEL_SIZES]
    return _two_pass_conv_search(configs, build_ptm, build_model,
                                 td, path, providers, time_budget, 'conv_fixed')


def _build_and_validate_conv_var(fit_fn, fit_exs, ks, use_bias, td, path, providers):
    """Build a variable-shape ONNX model with ``fit_fn`` and validate it.

    Returns ``(tag, model)`` or None. The tag is 'conv_var' for
    ``_lstsq_conv`` and 'conv_var_pcr' for any other fit function.
    """
    result = fit_fn(fit_exs, ks, use_bias, use_full_30=True)
    if result is None:
        return None
    Wconv, B = result
    model = _make_masked_conv_model(Wconv, B, ks, 'rs_axes_var')
    if _save_and_validate(model, td, path, providers):
        tag = 'conv_var' if fit_fn is _lstsq_conv else 'conv_var_pcr'
        return tag, model
    return None


def solve_conv_variable(td, path, providers, time_budget=30.0):
    """Variable-shape conv. Tries lstsq first, PCR as second pass.

    Applies only when every example has input shape == output shape; inputs
    may differ in size between examples.
    """
    exs = get_exs(td)
    if any(inp.shape != out.shape for inp, out in exs):
        return None
    fit_exs = [(i, o) for i, o in get_exs_for_fitting_variable(td)
               if i.shape == o.shape]

    def build_ptm(cfg):
        ks, use_bias = cfg
        return _build_patch_matrix(fit_exs, ks, use_bias, use_full_30=True)

    def build_model(cfg, Wconv, B):
        return _make_masked_conv_model(Wconv, B, cfg[0], 'rs_axes_var')

    configs = [(ks, use_bias) for use_bias in (False, True) for ks in _KERNEL_SIZES]
    return _two_pass_conv_search(configs, build_ptm, build_model,
                                 td, path, providers, time_budget, 'conv_var')


def solve_conv_diffshape(td, path, providers, time_budget=30.0):
    """Different-shape convolutional solver. Tries lstsq first, PCR as second pass.

    Requires fixed shapes (IH, IW) -> (OH, OW) with a strictly different,
    not larger output. Conv scores are cropped either at the top-left corner
    or at a centred offset.
    """
    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if IH == OH and IW == OW:
        return None
    if OH > IH or OW > IW:
        return None
    # The cropped output is padded up to GH x GW, so it must fit inside it.
    if OH > GH or OW > GW:
        return None
    # Keep only examples that match the fixed shapes the model is built for.
    exs = [(i, o) for i, o in get_exs(td)
           if i.shape == (IH, IW) and o.shape == (OH, OW)]

    offsets = [(0, 0)]
    centred = ((IH - OH) // 2, (IW - OW) // 2)
    if centred != (0, 0):  # avoid running the identical search twice
        offsets.append(centred)

    def build_ptm(cfg):
        ks, use_bias, dr_off, dc_off = cfg
        return _build_patch_matrix(exs, ks, use_bias, use_full_30=False,
                                   offset=(dr_off, dc_off), max_feat=10000)

    def build_model(cfg, Wconv, B):
        ks, _, dr_off, dc_off = cfg
        return _make_fixed_conv_model(Wconv, B, ks, IH, IW,
                                      crop=(dr_off, dc_off, OH, OW))

    configs = [(ks, use_bias, dr_off, dc_off)
               for dr_off, dc_off in offsets
               for use_bias in (False, True)
               for ks in _KERNEL_SIZES_DIFF]
    return _two_pass_conv_search(configs, build_ptm, build_model,
                                 td, path, providers, time_budget, 'conv_diff')


def solve_conv_var_diff(td, path, providers, time_budget=30.0):
    """Variable diff-shape conv. Tries lstsq first, PCR as second pass.

    Applies only when every output fits within its input (per dimension).
    NOTE(review): the output mask is derived from the *input* footprint
    (ReduceSum over input channels). When an output is smaller than its
    input, the model still emits one-hot cells over the input's extent.
    Confirm whether validate() accepts that; otherwise this solver can only
    succeed when input and output shapes match.
    """
    exs = get_exs(td)
    # Loop-invariant precondition: without it no model is ever built.
    if not all(out_g.shape[0] <= inp_g.shape[0] and out_g.shape[1] <= inp_g.shape[1]
               for inp_g, out_g in exs):
        return None

    def build_ptm(cfg):
        ks, use_bias = cfg
        return _build_patch_matrix(exs, ks, use_bias, use_full_30=True)

    def build_model(cfg, Wconv, B):
        return _make_masked_conv_model(Wconv, B, cfg[0], 'rs_axes_vd')

    configs = [(ks, use_bias) for use_bias in (False, True) for ks in _KERNEL_SIZES]
    return _two_pass_conv_search(configs, build_ptm, build_model,
                                 td, path, providers, time_budget, 'conv_var_diff')