#!/usr/bin/env python3
"""Composition solvers — chain two analytical solvers into one ONNX graph.

These handle tasks where a single transform cannot produce the output,
but two transforms in sequence can. E.g. rotate THEN color_map.

Each composition builds one combined ONNX graph (no intermediate I/O).
"""

import os
import tempfile

import numpy as np
import onnx
from onnx import helper, TensorProto, numpy_helper

from ..data_loader import get_exs
from ..validators import validate
from ..constants import DT, IR, GRID_SHAPE, GH, GW
from ..config import make_opset
from .analytical import s_color_map
from .geometric import s_flip, s_rotate, s_shift, s_fixed_crop
from .tiling import s_tile, s_upscale, s_mirror_h, s_mirror_v, s_quad_mirror

# Execution providers passed to ``validate`` for every candidate model.
_PROVIDERS = ['CPUExecutionProvider']

# Name of the tensor that links the first graph's output to the second
# graph's input inside a merged model.
_MID_NAME = "composed_mid"


def _run_solver_on_data(solver_fn, td):
    """Run ``solver_fn(td)`` and return its model, or None if it raises.

    This does NOT validate the returned model; it only shields the caller
    from solver exceptions. A solver that returns None yields None.
    """
    try:
        return solver_fn(td)
    except Exception:
        # Best-effort: a solver that crashes simply "does not apply".
        return None


def _passes_validation(model, td):
    """Return True if ``model`` passes ``validate`` on ``td``.

    ``validate`` takes a file path, so the model is written to a temporary
    ``.onnx`` file which is always removed afterwards. Any exception raised
    while saving or validating counts as a failure.
    """
    fd, path = tempfile.mkstemp(suffix='.onnx')
    os.close(fd)  # onnx.save opens the path itself; don't hold a second handle
    try:
        onnx.save(model, path)
        return bool(validate(path, td, _PROVIDERS))
    except Exception:
        return False
    finally:
        try:
            os.unlink(path)
        except OSError:
            pass


def s_transform_then_recolor(td):
    """Try: spatial transform → color_map.

    This used to be an unfinished stub that ran the transform and color_map
    solvers, discarded their models and always returned None. The graph
    composition it describes is implemented by
    ``s_transform_then_recolor_v2``, which this delegates to.

    Returns ONNX model or None.
    """
    return s_transform_then_recolor_v2(td)


def s_transform_then_recolor_v2(td):
    """Build composed ONNX: spatial transform graph + color_map graph merged.

    Strategy: for each candidate same-shape spatial transform, skip it if the
    transform alone already passes validation; otherwise merge it with the
    color_map graph and return the first merged model that passes
    ``validate`` on ``td``.

    Returns ONNX model or None.
    """
    exs = get_exs(td)
    if len(exs) < 2:
        return None

    # Only compositions of same-shape transforms are considered here.
    if not all(inp.shape == out.shape for inp, out in exs):
        return None

    # NOTE(review): s_color_map is fitted on the raw input→output pairs, not
    # on the transformed inputs, so this only finds tasks whose recolouring
    # is invariant under the transform. Fitting it on transformed inputs
    # would catch more tasks — TODO.
    # It does not depend on the transform, so compute it once up front.
    cm_model = _run_solver_on_data(s_color_map, td)
    if cm_model is None:
        return None

    transforms = [
        ('flip_h', lambda d: s_flip(d, direction='horizontal')),
        ('flip_v', lambda d: s_flip(d, direction='vertical')),
        ('rotate90', lambda d: s_rotate(d, k=1)),
        ('rotate180', lambda d: s_rotate(d, k=2)),
        ('rotate270', lambda d: s_rotate(d, k=3)),
        ('transpose', _s_transpose),
    ]

    for tf_name, tf_fn in transforms:
        tf_model = _run_solver_on_data(tf_fn, td)
        if tf_model is None:
            continue

        if _passes_validation(tf_model, td):
            continue  # Transform alone works, no composition needed

        composed = _merge_graphs(tf_model, cm_model, f"{tf_name}_then_recolor")
        # Only return a composition that actually solves the task.
        if composed is not None and _passes_validation(composed, td):
            return composed

    return None


def _namespace_graph(graph, new_input, new_output, prefix):
    """Copy ``graph``'s nodes and initializers with every tensor renamed.

    The graph's first input becomes ``new_input``. Its first output becomes
    ``new_output``. Empty names (omitted optional inputs) stay empty. Every
    other tensor name and every non-empty node name gets ``prefix``
    prepended. Node attributes and domains are preserved because nodes are
    deep-copied rather than rebuilt.

    Returns ``(nodes, initializers)``, or None if any node carries a subgraph
    attribute: names referenced inside subgraphs would not be renamed
    consistently.
    """
    graph_in = graph.input[0].name
    graph_out = graph.output[0].name
    subgraph_types = (onnx.AttributeProto.GRAPH, onnx.AttributeProto.GRAPHS)

    def rename(tensor):
        if not tensor:
            return tensor
        if tensor == graph_in:
            return new_input
        if tensor == graph_out:
            return new_output
        return prefix + tensor

    nodes = []
    for node in graph.node:
        if any(attr.type in subgraph_types for attr in node.attribute):
            return None
        copy = onnx.NodeProto()
        copy.CopyFrom(node)  # keeps attributes, domain, doc_string
        del copy.input[:]
        copy.input.extend(rename(t) for t in node.input)
        del copy.output[:]
        copy.output.extend(rename(t) for t in node.output)
        if copy.name:
            copy.name = prefix + copy.name
        nodes.append(copy)

    inits = []
    for init in graph.initializer:
        copy = onnx.TensorProto()
        copy.CopyFrom(init)
        copy.name = rename(init.name)
        inits.append(copy)

    return nodes, inits


def _merge_graphs(model_a, model_b, name="composed"):
    """Merge two ONNX models into a single graph computing ``b(a(input))``.

    model_a: input → intermediate
    model_b: intermediate → output

    Each model's tensors are namespaced separately, so the two graphs may
    reuse names such as "output". model_a's first input becomes "input".
    model_a's first output and model_b's first input are joined through an
    internal link tensor. model_b's first output becomes "output".

    Returns the merged model, or None if:
    - either graph lacks an input/output,
    - either graph uses subgraph attributes, or
    - the merged model fails ``onnx.checker``.
    """
    graph_a = model_a.graph
    graph_b = model_b.graph
    if not (graph_a.input and graph_a.output and graph_b.input and graph_b.output):
        return None

    part_a = _namespace_graph(graph_a, "input", _MID_NAME, "a__")
    part_b = _namespace_graph(graph_b, _MID_NAME, "output", "b__")
    if part_a is None or part_b is None:
        return None
    nodes_a, inits_a = part_a
    nodes_b, inits_b = part_b

    try:
        x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
        y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
        g = helper.make_graph(nodes_a + nodes_b, name, [x], [y],
                              initializer=inits_a + inits_b)
        merged = helper.make_model(g, ir_version=IR, opset_imports=make_opset(17))
        onnx.checker.check_model(merged)
        return merged
    except Exception:
        return None


def _s_transpose(td):
    """Transpose solver (from analytical.py pattern).

    Returns a model only when every example's output equals the transpose of
    its input; returns None otherwise, including when there are no examples.
    """
    from ..onnx_helpers import mk

    exs = get_exs(td)
    if not exs:
        return None
    for inp, out in exs:
        if inp.shape[0] != out.shape[1] or inp.shape[1] != out.shape[0]:
            return None
        if not np.array_equal(inp.T, out):
            return None

    # Swap the last two axes of the graph tensor.
    # NOTE(review): the node writes 't_out', not 'output' — confirm mk()
    # wires its last node's output to the graph output.
    nodes = [helper.make_node('Transpose', ['input'], ['t_out'], perm=[0, 1, 3, 2])]
    return mk(nodes)


def s_recolor_then_tile(td):
    """color_map → tile/upscale composition.

    Merges the color_map graph with the tile graph, then with the upscale
    graph. Returns the first merge that passes ``validate`` on ``td``, or
    None.
    """
    # NOTE(review): both solvers are fitted on the raw input→output pairs,
    # so this only succeeds when each fit holds independently of the other.
    cm_model = _run_solver_on_data(s_color_map, td)
    if cm_model is None:
        return None

    for tag, solver in (("tile", s_tile), ("upscale", s_upscale)):
        second = _run_solver_on_data(solver, td)
        if second is None:
            continue
        composed = _merge_graphs(cm_model, second, f"recolor_then_{tag}")
        if composed is not None and _passes_validation(composed, td):
            return composed
    return None


def s_crop_then_transform(td):
    """fixed_crop → rotate/flip composition.

    Returns the first crop+transform merge that passes ``validate`` on
    ``td``, or None.
    """
    crop_model = _run_solver_on_data(s_fixed_crop, td)
    if crop_model is None:
        return None

    for tfn in (s_flip, s_rotate):
        tf_model = _run_solver_on_data(tfn, td)
        if tf_model is None:
            continue
        composed = _merge_graphs(crop_model, tf_model, "crop_then_transform")
        if composed is not None and _passes_validation(composed, td):
            return composed
    return None