| |
| """Composition solvers - chain two analytical solvers into one ONNX graph. |
| |
| C1: transform_then_recolor - spatial transform THEN color_map |
| C2: crop_then_transform - crop THEN apply spatial transform |
| C3: recolor_then_tile - color_map THEN tile/upscale |
| """ |
|
|
| import numpy as np |
| from onnx import helper, TensorProto, numpy_helper |
|
|
|
|
def s_transform_then_recolor(td, solvers):
    """Try: spatial transform -> color_map composition.

    Args:
        td: task data, passed unchanged to every solver.
        solvers: dict of solver_name -> solver_function. Each solver is
            called as ``solver(td)`` and should return an ONNX model or
            None; a solver that raises is treated as "no model".

    Returns:
        The first merged ONNX model (spatial solvers are tried in the order
        listed below), or None if no composition could be built.
    """
    # NOTE(review): both the spatial solver and the color map are fitted on
    # the same ``td``. If solvers fit the full input->output mapping, any
    # pair where both succeed would each already solve the task alone —
    # confirm whether the color map should be fitted on transformed inputs.
    spatial_solvers = ('flip', 'rotate', 'shift', 'mirror_h', 'mirror_v',
                       'quad_mirror', 'transpose')

    # The color map does not depend on which spatial transform is tried, so
    # fit it once up front instead of once per spatial solver.
    cm_fn = solvers.get('color_map')
    if cm_fn is None:
        return None
    try:
        cm_model = cm_fn(td)
    except Exception:
        return None
    if cm_model is None:
        return None

    for sname in spatial_solvers:
        sfn = solvers.get(sname)
        if sfn is None:
            continue
        try:
            tf_model = sfn(td)
        except Exception:
            continue
        if tf_model is None:
            continue

        composed = _merge_graphs(tf_model, cm_model, f"{sname}_then_recolor")
        if composed is not None:
            return composed

    return None
|
|
|
|
def s_crop_then_transform(td, solvers):
    """Compose a fixed crop with a follow-up spatial transform.

    The ``fixed_crop`` solver runs first. Each candidate transform is then
    tried in turn, and the first pair that ``_merge_graphs`` accepts wins.

    Args:
        td: task data, forwarded unchanged to every solver.
        solvers: mapping of solver name -> solver callable returning an ONNX
            model or None.

    Returns:
        The merged ONNX model, or None if no pairing could be built.
    """
    # NOTE(review): both stages are fitted on the same ``td``; confirm this
    # is intended rather than fitting the transform on the cropped grids.

    def attempt(solver_name):
        # A missing solver, a None result and a raised exception all mean
        # "this solver produced nothing".
        fn = solvers.get(solver_name)
        if fn is None:
            return None
        try:
            return fn(td)
        except Exception:
            return None

    first_stage = attempt('fixed_crop')
    if first_stage is None:
        return None

    for transform in ('flip', 'rotate', 'mirror_h', 'mirror_v'):
        second_stage = attempt(transform)
        if second_stage is None:
            continue
        merged = _merge_graphs(first_stage, second_stage,
                               f"crop_then_{transform}")
        if merged is not None:
            return merged
    return None
|
|
|
|
def s_recolor_then_tile(td, solvers):
    """Compose a color map with a following tile/upscale step.

    Args:
        td: task data, handed as-is to each solver.
        solvers: dict of solver_name -> solver_function returning an ONNX
            model or None.

    Returns:
        The first merged model that ``_merge_graphs`` accepts, trying
        ``tile`` before ``upscale``; None when nothing composes.
    """
    # NOTE(review): both stages are fitted on the same ``td`` — confirm the
    # tile/upscale solver is not meant to see the recolored grids instead.
    recolor_fn = solvers.get('color_map')
    if recolor_fn is None:
        return None
    try:
        recolor_model = recolor_fn(td)
    except Exception:
        return None
    if recolor_model is None:
        return None

    for stage in ('tile', 'upscale'):
        stage_fn = solvers.get(stage)
        if stage_fn is None:
            continue
        try:
            stage_model = stage_fn(td)
        except Exception:
            continue
        if stage_model is not None:
            merged = _merge_graphs(recolor_model, stage_model,
                                   f"recolor_then_{stage}")
            if merged is not None:
                return merged
    return None
|
|
|
|
def _merge_graphs(model_a, model_b, name="composed"):
    """Merge two ONNX models into a single graph.

    model_a: input -> intermediate
    model_b: intermediate -> output

    The merged graph exposes one input named ``"input"`` and one output named
    ``"output"``, both declared with ``DT`` / ``GRID_SHAPE`` from
    ``onnx_helpers``. Only the first graph input/output of each model is
    wired; any additional inputs/outputs are not connected.

    Every tensor name of each model (node outputs, initializers) is rewritten
    with a per-model prefix so the two graphs cannot collide. For example,
    both models naming their final tensor ``"output"`` would otherwise break
    ONNX's single-static-assignment rule. Nodes are deep-copied, so
    attributes (``axis``, ``perm``, ``to``, ...) and ``domain`` are
    preserved.

    Args:
        model_a: first-stage ``onnx.ModelProto``.
        model_b: second-stage ``onnx.ModelProto``.
        name: graph name; also used as the prefix for renamed tensors/nodes.

    Returns:
        The merged ``ModelProto`` if it builds and passes
        ``onnx.checker.check_model``, otherwise None.
    """
    import onnx
    from onnx_helpers import DT, IR, GRID_SHAPE, make_opset

    graph_a = model_a.graph
    graph_b = model_b.graph
    # Both stages need at least one input and one output to be chained.
    if not (graph_a.input and graph_a.output
            and graph_b.input and graph_b.output):
        return None

    a_in, a_out = graph_a.input[0].name, graph_a.output[0].name
    b_in, b_out = graph_b.input[0].name, graph_b.output[0].name

    # Tensor carrying A's result into B. If A is a pass-through (its output
    # is its input), B reads the merged graph input directly.
    mid = "input" if a_out == a_in else f"{name}__mid"

    def make_renamer(prefix, wiring):
        # Map one sub-graph's names into the merged namespace: boundary names
        # listed in ``wiring`` are replaced, everything else gets ``prefix``.
        def rename(tensor):
            if not tensor:  # "" marks an omitted optional input; keep it
                return tensor
            if tensor in wiring:
                return wiring[tensor]
            return prefix + tensor
        return rename

    # Dict order matters when input and output share a name (pass-through):
    # the later entry wins, which keeps each side consistent with ``mid``.
    rename_a = make_renamer(f"{name}__a__", {a_out: mid, a_in: "input"})
    rename_b = make_renamer(f"{name}__b__", {b_out: "output", b_in: mid})

    def clone_nodes(graph, rename, tag):
        # Deep-copy instead of helper.make_node so attributes and domain
        # survive. NOTE: names referenced inside subgraph attributes
        # (If/Loop bodies) are not rewritten.
        cloned = []
        for idx, node in enumerate(graph.node):
            new_node = onnx.NodeProto()
            new_node.CopyFrom(node)
            new_inputs = [rename(t) for t in node.input]
            new_outputs = [rename(t) for t in node.output]
            del new_node.input[:]
            new_node.input.extend(new_inputs)
            del new_node.output[:]
            new_node.output.extend(new_outputs)
            suffix = node.name or f"n{idx}"
            new_node.name = f"{name}_{tag}_{suffix}"
            cloned.append(new_node)
        return cloned

    def clone_initializers(graph, rename):
        # Copy constants under their renamed (collision-free) names.
        cloned = []
        for init in graph.initializer:
            tensor = onnx.TensorProto()
            tensor.CopyFrom(init)
            tensor.name = rename(init.name)
            cloned.append(tensor)
        return cloned

    nodes = clone_nodes(graph_a, rename_a, "a") + clone_nodes(graph_b, rename_b, "b")
    if b_out == b_in:
        # B is a pass-through, so nothing in it produces "output"; forward
        # the intermediate tensor explicitly.
        nodes.append(helper.make_node("Identity", [mid], ["output"],
                                      name=f"{name}_passthrough"))
    inits = clone_initializers(graph_a, rename_a) + clone_initializers(graph_b, rename_b)

    x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
    y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)

    try:
        g = helper.make_graph(nodes, name, [x], [y], initializer=inits)
        merged = helper.make_model(g, ir_version=IR, opset_imports=make_opset(17))
        onnx.checker.check_model(merged)
    except Exception:
        # Best-effort composition: any model ONNX rejects is "no solution".
        return None
    return merged
|
|