| |
| """Mode fill solver — output = solid fill of most common input color. |
| |
| v5.2: Solves Task 129 (score 19.451). |
| Uses runtime ReduceSum→ArgMax→Expand for variable mode across inputs. |
| Falls through to s_constant when mode is fixed across all examples. |
| """ |
|
|
| import numpy as np |
| from onnx import helper, numpy_helper, TensorProto |
| from ..onnx_helpers import mk, _make_int64_init, _build_pad_node |
| from ..data_loader import get_exs, fixed_shapes |
| from ..constants import GH, GW |
|
|
|
|
def s_mode_fill(td):
    """Build an ONNX graph that fills the output with the input's modal color.

    The solver applies only when every example's output has the same shape
    as its input and consists solely of the input's most frequent value.
    Ties resolve to the smallest value: ``np.unique`` returns sorted values
    and ``np.argmax`` picks the first maximum, which agrees with ONNX
    ``ArgMax``'s default first-index rule used at runtime.

    Args:
        td: Task data, passed through to ``get_exs`` and ``fixed_shapes``.

    Returns:
        The result of ``mk(nodes, inits)`` when the pattern matches, else
        ``None``. ``None`` is returned when any example violates the
        pattern, when there are no examples, when an input grid is empty,
        when the mode is identical across all examples (the module
        docstring leaves that case to ``s_constant``), or when
        ``fixed_shapes`` reports no fixed shape / differing in/out shapes.
    """
    exs = get_exs(td)

    # Single pass: validate each example and collect its mode. Doing both
    # in one loop avoids recomputing np.unique per example and stays correct
    # even if ``exs`` can only be iterated once.
    modes = set()
    for inp, out in exs:
        if inp.shape != out.shape:
            return None
        vals, counts = np.unique(inp, return_counts=True)
        if counts.size == 0:
            # Empty grid has no mode; np.argmax would raise on it.
            return None
        mode = vals[np.argmax(counts)]
        if not np.all(out == mode):
            return None
        modes.add(mode)

    # One distinct mode: a constant fill suffices, handled elsewhere.
    # Zero modes: there were no examples, so nothing supports this solver.
    if len(modes) <= 1:
        return None

    sp = fixed_shapes(td)
    if sp is None:
        return None
    (IH, IW), (OH, OW) = sp
    if (IH, IW) != (OH, OW):
        return None

    # Amount handed to the pad helper to grow IH x IW up to GH x GW.
    pad_h, pad_w = GH - IH, GW - IW

    # NOTE(review): the slice bounds and the (1, 10, 1, 1) class vector
    # suggest the graph input is a one-hot tensor with 10 color channels on
    # axis 1 — confirm against the encoder.
    inits = [
        _make_int64_init('sl_st', [0, 0, 0, 0]),
        _make_int64_init('sl_en', [1, 10, IH, IW]),
        _make_int64_init('rs_axes_mode', [2, 3]),
        numpy_helper.from_array(
            np.arange(10, dtype=np.int64).reshape(1, 10, 1, 1), 'classes'
        ),
    ]

    nodes = [
        # Crop to the real grid so anything outside IH x IW is not counted.
        helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['cropped']),
        # Sum over the spatial axes -> one total per channel.
        helper.make_node(
            'ReduceSum', ['cropped', 'rs_axes_mode'], ['hist'], keepdims=1
        ),
        # Index of the largest per-channel total.
        helper.make_node('ArgMax', ['hist'], ['mode_idx'], axis=1, keepdims=1),
        # Turn the index back into a one-hot vector along the channel axis.
        helper.make_node('Equal', ['mode_idx', 'classes'], ['eq']),
        helper.make_node('Cast', ['eq'], ['mode_oh'], to=TensorProto.FLOAT),
        # 'sl_en' ([1, 10, IH, IW]) doubles as the Expand target shape.
        helper.make_node('Expand', ['mode_oh', 'sl_en'], ['expanded']),
    ]
    # presumably pads 'expanded' back out to the full GH x GW grid — verify
    # against _build_pad_node (it also receives ``inits`` and may append to it).
    nodes.append(_build_pad_node('expanded', 'output', pad_h, pad_w, inits))
    return mk(nodes, inits)
|
|