#!/usr/bin/env python3
"""Composition solvers - chain two analytical solvers into one ONNX graph.

C1: transform_then_recolor - spatial transform THEN color_map
C2: crop_then_transform    - crop THEN apply spatial transform
C3: recolor_then_tile      - color_map THEN tile/upscale
"""

import logging

import numpy as np
from onnx import NodeProto, TensorProto, helper, numpy_helper

logger = logging.getLogger(__name__)


def _run_solver(solvers, sname, td):
    """Run the solver registered under *sname* on *td*.

    Returns whatever the solver returns, or None when the solver is not
    registered or raises.  Solvers are treated as best-effort plug-ins, so
    any exception is swallowed on purpose; it is logged at debug level so a
    failure is still discoverable.
    """
    sfn = solvers.get(sname)
    if sfn is None:
        return None
    try:
        return sfn(td)
    except Exception:
        logger.debug("solver %r raised; skipping it", sname, exc_info=True)
        return None


def s_transform_then_recolor(td, solvers):
    """Try: spatial transform -> color_map composition.

    Args:
        td: task data, passed unchanged to every solver
        solvers: dict of solver_name -> solver_function

    Returns:
        ONNX model or None
    """
    spatial_solvers = ('flip', 'rotate', 'shift', 'mirror_h', 'mirror_v',
                       'quad_mirror', 'transpose')

    # The color_map solver does not depend on which spatial solver matched,
    # so it is evaluated lazily and at most once instead of once per
    # iteration.
    # NOTE(review): color_map is fitted on the ORIGINAL td, not on the output
    # of the spatial transform - confirm that this is what callers expect.
    cm_model = None
    cm_attempted = False

    for sname in spatial_solvers:
        tf_model = _run_solver(solvers, sname, td)
        if tf_model is None:
            continue

        if not cm_attempted:
            cm_model = _run_solver(solvers, 'color_map', td)
            cm_attempted = True
        if cm_model is None:
            # No color map available: no composition can succeed.
            return None

        composed = _merge_graphs(tf_model, cm_model, f"{sname}_then_recolor")
        if composed is not None:
            return composed

    return None


def s_crop_then_transform(td, solvers):
    """Try: fixed_crop -> spatial transform composition.

    Args:
        td: task data, passed unchanged to every solver
        solvers: dict of solver_name -> solver_function

    Returns:
        ONNX model or None
    """
    crop_model = _run_solver(solvers, 'fixed_crop', td)
    if crop_model is None:
        return None

    for sname in ('flip', 'rotate', 'mirror_h', 'mirror_v'):
        tf_model = _run_solver(solvers, sname, td)
        if tf_model is None:
            continue
        composed = _merge_graphs(crop_model, tf_model, f"crop_then_{sname}")
        if composed is not None:
            return composed

    return None


def s_recolor_then_tile(td, solvers):
    """Try: color_map -> tile/upscale composition.

    Args:
        td: task data, passed unchanged to every solver
        solvers: dict of solver_name -> solver_function

    Returns:
        ONNX model or None
    """
    cm_model = _run_solver(solvers, 'color_map', td)
    if cm_model is None:
        return None

    for sname in ('tile', 'upscale'):
        tile_model = _run_solver(solvers, sname, td)
        if tile_model is None:
            continue
        composed = _merge_graphs(cm_model, tile_model, f"recolor_then_{sname}")
        if composed is not None:
            return composed

    return None


def _first_runtime_input(graph):
    """Return the name of the first graph input that is not an initializer.

    Some models list their initializers among the graph inputs; those are
    constants, not the tensor fed at run time.  Returns None when the graph
    has no such input.
    """
    init_names = {init.name for init in graph.initializer}
    for value_info in graph.input:
        if value_info.name not in init_names:
            return value_info.name
    return None


def _renamed(tensor_name, mapping, prefix):
    """Map *tensor_name* into the merged graph's namespace.

    Names listed in *mapping* get their explicit replacement, every other
    name is prefixed.  The empty string is left alone because ONNX uses it
    to mark an omitted optional input/output.
    """
    if not tensor_name:
        return tensor_name
    return mapping.get(tensor_name, prefix + tensor_name)


def _copy_nodes(graph, mapping, prefix):
    """Copy the nodes of *graph* with every tensor name remapped.

    Each node is cloned as a whole protobuf message so that attributes,
    domain and doc string survive; only the input/output names and the node
    name are rewritten.
    """
    clones = []
    for idx, node in enumerate(graph.node):
        clone = NodeProto()
        clone.CopyFrom(node)
        new_inputs = [_renamed(n, mapping, prefix) for n in node.input]
        new_outputs = [_renamed(n, mapping, prefix) for n in node.output]
        del clone.input[:]
        clone.input.extend(new_inputs)
        del clone.output[:]
        clone.output.extend(new_outputs)
        clone.name = prefix + (node.name if node.name else f"n{idx}")
        clones.append(clone)
    return clones


def _copy_initializers(graph, mapping, prefix):
    """Copy the initializers of *graph* with their names remapped."""
    clones = []
    for init in graph.initializer:
        clone = TensorProto()
        clone.CopyFrom(init)
        clone.name = _renamed(init.name, mapping, prefix)
        clones.append(clone)
    return clones


def _merge_graphs(model_a, model_b, name="composed"):
    """Merge two ONNX models into a single graph.

    model_a: input -> intermediate
    model_b: intermediate -> output

    The first runtime input of model_a becomes the merged graph input
    "input", the first output of model_b becomes the merged graph output
    "output", and the first output of model_a is wired to the first runtime
    input of model_b.  All other tensors, initializers and node names are
    prefixed per source model, so the two graphs cannot collide even when
    both use the same internal names.

    Only the first input/output of each model is connected.  Tensor names
    referenced from inside subgraph attributes (If/Loop/Scan bodies) are not
    rewritten; such a model would be rejected by the checker below.

    Args:
        model_a: ONNX model applied first
        model_b: ONNX model applied to model_a's output
        name: graph name, also used as the prefix for renamed tensors

    Returns:
        The merged, checker-validated ONNX model, or None if either model
        lacks a usable input/output or the merged graph fails validation.
    """
    import onnx
    from onnx_helpers import DT, IR, GRID_SHAPE, make_opset

    graph_a = model_a.graph
    graph_b = model_b.graph

    a_input = _first_runtime_input(graph_a)
    b_input = _first_runtime_input(graph_b)
    if a_input is None or b_input is None:
        return None
    if not graph_a.output or not graph_b.output:
        return None
    a_output = graph_a.output[0].name
    b_output = graph_b.output[0].name

    # Single tensor name that carries model_a's result into model_b.
    link = f"{name}_link"
    mapping_a = {a_input: "input", a_output: link}
    mapping_b = {b_input: link, b_output: "output"}
    prefix_a = f"{name}_a_"
    prefix_b = f"{name}_b_"

    nodes = (_copy_nodes(graph_a, mapping_a, prefix_a)
             + _copy_nodes(graph_b, mapping_b, prefix_b))
    inits = (_copy_initializers(graph_a, mapping_a, prefix_a)
             + _copy_initializers(graph_b, mapping_b, prefix_b))

    # Build merged graph
    x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
    y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
    g = helper.make_graph(nodes, name, [x], [y], initializer=inits)

    try:
        merged = helper.make_model(g, ir_version=IR, opset_imports=make_opset(17))
        onnx.checker.check_model(merged)
        return merged
    except Exception:
        # An invalid composition is an expected outcome, not an error.
        logger.debug("merged graph %r failed validation", name, exc_info=True)
        return None