| |
| """Composition solvers — chain two analytical solvers into one ONNX graph. |
| |
| These handle tasks where a single transform cannot produce the output, |
| but two transforms in sequence can. E.g. rotate THEN color_map. |
| |
| Each composition builds one combined ONNX graph (no intermediate I/O). |
| """ |
|
|
| import numpy as np |
| from onnx import helper, TensorProto, numpy_helper |
| from ..data_loader import get_exs |
| from ..validators import validate |
| from ..constants import DT, IR, GRID_SHAPE, GH, GW |
| from ..config import make_opset |
| from .analytical import s_color_map |
| from .geometric import s_flip, s_rotate, s_shift, s_fixed_crop |
| from .tiling import s_tile, s_upscale, s_mirror_h, s_mirror_v, s_quad_mirror |
|
|
|
|
| def _run_solver_on_data(solver_fn, td): |
| """Try a solver, return the ONNX model if it passes validation on train pairs.""" |
| try: |
| model = solver_fn(td) |
| return model |
| except Exception: |
| return None |
|
|
|
|
def s_transform_then_recolor(td):
    """Try: spatial transform -> color_map, chained into one ONNX graph.

    Steps:
      1. Fit the color-map solver once on ``td`` (it does not depend on which
         transform runs first); give up if it does not apply.
      2. For each transform solver, fit it on ``td``; skip it if it returns
         None or raises.
      3. Merge the two graphs (input -> transform -> color_map -> output) and
         return the first merged model that passes ``validate`` on ``td``.

    Returns:
        The first validated composed ONNX model, or None.
    """
    import os
    import tempfile

    import onnx

    def _solves(model):
        """Return True if ``model`` passes ``validate`` on ``td``."""
        # validate() takes a file path; always remove the temp file afterwards.
        fd, path = tempfile.mkstemp(suffix='.onnx')
        os.close(fd)  # onnx.save reopens by path, so release our handle first
        try:
            onnx.save(model, path)
            # Assumes validate() returns a truthy value when the model solves td.
            return bool(validate(path, td, ['CPUExecutionProvider']))
        except Exception:
            # A model that cannot be saved or executed does not solve the task.
            return False
        finally:
            os.unlink(path)

    # NOTE(review): s_color_map is fitted on the raw (input, output) pairs, not on
    # transformed inputs, so a composition presumably only succeeds when the color
    # map is also consistent with the untransformed grids — confirm intent.
    cm_model = _run_solver_on_data(s_color_map, td)
    if cm_model is None:
        return None

    transforms = [s_flip, s_rotate, s_shift, s_mirror_h, s_mirror_v, s_quad_mirror]
    for tfn in transforms:
        tf_model = _run_solver_on_data(tfn, td)
        if tf_model is None:
            continue
        tf_name = getattr(tfn, '__name__', 'transform')
        composed = _merge_graphs(tf_model, cm_model, f"{tf_name}_then_recolor")
        if composed is not None and _solves(composed):
            return composed
    return None
|
|
|
|
def s_transform_then_recolor_v2(td):
    """Build a composed ONNX model: same-shape spatial transform, then color_map.

    Strategy:
      1. Require at least two examples, each with matching input/output shapes.
      2. Fit the color-map solver once (it does not depend on the transform).
      3. For each transform (flips, rotations, transpose): fit it; skip it if it
         already solves ``td`` on its own (no recolor stage needed), otherwise
         merge it with the color map via ``_merge_graphs``.
      4. Return the first merged model that passes ``validate`` on ``td``.

    Returns:
        The first validated composed ONNX model, or None.
    """
    import os
    import tempfile

    import onnx

    exs = get_exs(td)
    if len(exs) < 2:
        return None
    # Only shape-preserving tasks are considered by this composition.
    if any(inp.shape != out.shape for inp, out in exs):
        return None

    def _solves(model):
        """Return True if ``model`` passes ``validate`` on ``td``."""
        # validate() takes a file path; the temp file is removed on every path
        # (the previous version leaked it whenever validation failed).
        fd, path = tempfile.mkstemp(suffix='.onnx')
        os.close(fd)  # onnx.save reopens by path, so release our handle first
        try:
            onnx.save(model, path)
            # Assumes validate() returns a truthy value when the model solves td.
            return bool(validate(path, td, ['CPUExecutionProvider']))
        except Exception:
            # A model that cannot be saved or executed does not solve the task.
            return False
        finally:
            os.unlink(path)

    # NOTE(review): s_color_map is fitted on the raw (input, output) pairs, not on
    # transformed inputs, so a composition presumably only succeeds when the color
    # map is also consistent with the untransformed grids — confirm intent.
    cm_model = _run_solver_on_data(s_color_map, td)
    if cm_model is None:
        return None

    transforms = [
        ('flip_h', lambda d: s_flip(d, direction='horizontal')),
        ('flip_v', lambda d: s_flip(d, direction='vertical')),
        ('rotate90', lambda d: s_rotate(d, k=1)),
        ('rotate180', lambda d: s_rotate(d, k=2)),
        ('rotate270', lambda d: s_rotate(d, k=3)),
        ('transpose', _s_transpose),
    ]

    for tf_name, tf_fn in transforms:
        tf_model = _run_solver_on_data(tf_fn, td)
        if tf_model is None:
            continue
        # A transform that already solves the task alone needs no recolor stage.
        if _solves(tf_model):
            continue
        composed = _merge_graphs(tf_model, cm_model, f"{tf_name}_then_recolor")
        if composed is not None and _solves(composed):
            return composed

    return None
|
|
|
|
def _merge_graphs(model_a, model_b, name="composed"):
    """Chain two single-input/single-output ONNX models into one graph.

    model_a: input -> intermediate
    model_b: intermediate -> output

    model_a's first output is wired into model_b's first input through a fresh
    tensor named "intermediate". Every other tensor and node name is prefixed
    per source graph ("a__" / "b__"), so the two graphs cannot clash (e.g. both
    writing a tensor called "output", or sharing an initializer name). The
    merged graph's input and output are always named "input" and "output".

    Nodes are copied as full NodeProtos, so attributes (perm, axis, ...) and
    domains are preserved.

    Limitation: tensor names referenced from inside subgraph attributes
    (If/Loop/Scan bodies) are not renamed.

    Args:
        model_a: First-stage ONNX ModelProto.
        model_b: Second-stage ONNX ModelProto.
        name: Name given to the merged graph.

    Returns:
        The merged ModelProto if it passes ``onnx.checker.check_model``,
        otherwise None (also None if either graph has no input or output).
    """
    import onnx

    graph_a = model_a.graph
    graph_b = model_b.graph
    if not (graph_a.input and graph_a.output and graph_b.input and graph_b.output):
        return None

    a_in = graph_a.input[0].name
    a_out = graph_a.output[0].name
    b_in = graph_b.input[0].name
    b_out = graph_b.output[0].name

    # If model_a simply forwards its input, the intermediate *is* the input.
    mid = "input" if a_out == a_in else "intermediate"

    # Fixed names for the boundary tensors; everything else gets a prefix.
    # Prefixed names start with "a__"/"b__", so they can never collide with
    # "input", "output" or "intermediate".
    fixed_a = {a_in: "input", a_out: mid}
    extra_nodes = []
    if b_in == b_out:
        # model_b passes its input straight through: emit an explicit Identity
        # so that "output" is actually produced.
        fixed_b = {b_in: mid}
        extra_nodes.append(
            helper.make_node("Identity", [mid], ["output"], name=f"{name}__passthrough"))
    else:
        fixed_b = {b_in: mid, b_out: "output"}

    def _renamer(prefix, fixed):
        def rename(tensor):
            if not tensor:  # "" marks an omitted optional input/output
                return tensor
            return fixed.get(tensor, prefix + tensor)
        return rename

    def _copy_nodes(graph, rename, prefix):
        copied = []
        for node in graph.node:
            new = onnx.NodeProto()
            new.CopyFrom(node)  # keeps attributes and domain intact
            new.ClearField("input")
            new.input.extend(rename(t) for t in node.input)
            new.ClearField("output")
            new.output.extend(rename(t) for t in node.output)
            if node.name:
                new.name = prefix + node.name
            copied.append(new)
        return copied

    def _copy_inits(graph, rename):
        copied = []
        for init in graph.initializer:
            new = onnx.TensorProto()
            new.CopyFrom(init)
            new.name = rename(init.name)
            copied.append(new)
        return copied

    rename_a = _renamer("a__", fixed_a)
    rename_b = _renamer("b__", fixed_b)
    nodes = _copy_nodes(graph_a, rename_a, "a__") + _copy_nodes(graph_b, rename_b, "b__") + extra_nodes
    inits = _copy_inits(graph_a, rename_a) + _copy_inits(graph_b, rename_b)

    x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
    y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
    g = helper.make_graph(nodes, name, [x], [y], initializer=inits)

    try:
        merged = helper.make_model(g, ir_version=IR, opset_imports=make_opset(17))
        onnx.checker.check_model(merged)
    except Exception:
        # An invalid merge simply means this composition is not usable.
        return None
    return merged
|
|
|
|
def _s_transpose(td):
    """Transpose solver: matches tasks where every output is its input transposed.

    Returns:
        A model built by ``mk`` from a single Transpose node (perm=[0, 1, 3, 2],
        i.e. swapping the last two axes), or None if there are no examples or
        any example's output is not exactly its input transposed.
    """
    from ..onnx_helpers import mk

    exs = get_exs(td)
    # With no examples the check below would succeed vacuously.
    if not exs:
        return None
    # np.array_equal returns False on a shape mismatch, so this also enforces
    # the (h, w) -> (w, h) shape relationship.
    if not all(np.array_equal(inp.T, out) for inp, out in exs):
        return None

    # Presumably the grid tensor is laid out as (batch, channel, H, W), so
    # perm=[0, 1, 3, 2] swaps H and W — TODO confirm against GRID_SHAPE.
    # NOTE(review): the node writes 't_out'; presumably mk() wires that to the
    # graph output — verify.
    nodes = [helper.make_node('Transpose', ['input'], ['t_out'], perm=[0, 1, 3, 2])]
    return mk(nodes)
|
|
|
|
def s_recolor_then_tile(td):
    """color_map -> tile/upscale composition.

    Fits a color map on ``td``, then tries the tile solver and, if that yields
    no usable composition, the upscale solver. Returns the first successful
    merge (input -> color_map -> tile/upscale -> output).

    Solver exceptions are treated as "not applicable", matching the other
    composition solvers.

    Returns:
        The composed ONNX model, or None if no combination could be built.
    """
    cm_model = _run_solver_on_data(s_color_map, td)
    if cm_model is None:
        return None

    for tag, expand_fn in (("tile", s_tile), ("upscale", s_upscale)):
        expand_model = _run_solver_on_data(expand_fn, td)
        if expand_model is None:
            continue
        composed = _merge_graphs(cm_model, expand_model, f"recolor_then_{tag}")
        if composed is not None:
            return composed

    return None
|
|
|
|
def s_crop_then_transform(td):
    """fixed_crop -> flip/rotate composition.

    Fits the fixed-crop solver on ``td``, then tries the flip and rotate
    solvers in turn, returning the first successful merge
    (input -> crop -> flip/rotate -> output).

    Failures of any solver (crop included) are treated as "not applicable"
    rather than propagated, consistent with the other composition solvers.

    Returns:
        The composed ONNX model, or None if no combination could be built.
    """
    crop_model = _run_solver_on_data(s_fixed_crop, td)
    if crop_model is None:
        return None

    for tfn in (s_flip, s_rotate):
        try:
            tf_model = tfn(td)
            if tf_model is None:
                continue
            composed = _merge_graphs(crop_model, tf_model, "crop_then_transform")
        except Exception:
            # A failing transform fit or merge just means this pairing doesn't apply.
            continue
        if composed is not None:
            return composed

    return None
|
|