File size: 4,523 Bytes
981ef11 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 | #!/usr/bin/env python3
"""Composition solvers - chain two analytical solvers into one ONNX graph.
C1: transform_then_recolor - spatial transform THEN color_map
C2: crop_then_transform - crop THEN apply spatial transform
C3: recolor_then_tile - color_map THEN tile/upscale
"""
import numpy as np
from onnx import helper, TensorProto, numpy_helper
def s_transform_then_recolor(td, solvers):
    """Try composing a spatial transform followed by a color map.

    The ``color_map`` solver is run once; each spatial solver is then tried
    in a fixed order, and the first pair that ``_merge_graphs`` accepts is
    returned.

    Args:
        td: task data, passed unchanged to every solver.
        solvers: mapping of solver name -> callable taking ``td`` and
            returning an ONNX model or None.

    Returns:
        The merged ONNX model, or None when ``color_map`` is missing or
        fails, or when no spatial solver yields a mergeable model.
    """
    def _attempt(solver_name):
        # A missing solver, a solver that raises, and a solver that returns
        # None all mean "this stage is not available".
        fn = solvers.get(solver_name)
        if fn is None:
            return None
        try:
            return fn(td)
        except Exception:
            return None

    # The color-map stage does not depend on which spatial solver is tried,
    # so solve it once up front rather than once per loop iteration.
    # NOTE(review): color_map is fitted against the same, untransformed td
    # as the spatial solver, not against the transform's output - confirm
    # this is the intended composition semantics.
    cm_model = _attempt('color_map')
    if cm_model is None:
        return None

    spatial_solvers = ('flip', 'rotate', 'shift', 'mirror_h', 'mirror_v',
                       'quad_mirror', 'transpose')
    for sname in spatial_solvers:
        tf_model = _attempt(sname)
        if tf_model is None:
            continue
        # Chain: spatial transform first, recolor second.
        composed = _merge_graphs(tf_model, cm_model, f"{sname}_then_recolor")
        if composed is not None:
            return composed
    return None
def s_crop_then_transform(td, solvers):
    """Try composing ``fixed_crop`` followed by a spatial transform.

    Args:
        td: task data, passed unchanged to every solver.
        solvers: mapping of solver name -> callable taking ``td`` and
            returning an ONNX model or None.

    Returns:
        The first merged model accepted by ``_merge_graphs``, or None when
        the crop stage is unavailable or no transform can be merged.
    """
    def _attempt(key):
        # Absent solver, raising solver and None result are all treated as
        # "stage not available".
        candidate = solvers.get(key)
        if candidate is None:
            return None
        try:
            return candidate(td)
        except Exception:
            return None

    cropped = _attempt('fixed_crop')
    if cropped is None:
        return None

    # Transforms are tried in this fixed order; first successful merge wins.
    for kind in ('flip', 'rotate', 'mirror_h', 'mirror_v'):
        transformed = _attempt(kind)
        if transformed is None:
            continue
        merged = _merge_graphs(cropped, transformed, f"crop_then_{kind}")
        if merged is not None:
            return merged
    return None
def s_recolor_then_tile(td, solvers):
    """Try composing ``color_map`` followed by ``tile`` or ``upscale``.

    Args:
        td: task data, passed unchanged to every solver.
        solvers: mapping of solver name -> callable taking ``td`` and
            returning an ONNX model or None.

    Returns:
        The first merged model accepted by ``_merge_graphs``, or None when
        the recolor stage is unavailable or no second stage can be merged.
    """
    recolor = solvers.get('color_map')
    if recolor is None:
        return None
    try:
        recolored = recolor(td)
    except Exception:
        return None
    if recolored is None:
        return None

    # 'tile' is tried before 'upscale'; the first successful merge wins.
    for stage in ('tile', 'upscale'):
        second = solvers.get(stage)
        if second is None:
            continue
        try:
            enlarged = second(td)
        except Exception:
            continue
        else:
            if enlarged is None:
                continue
        result = _merge_graphs(recolored, enlarged, f"recolor_then_{stage}")
        if result is not None:
            return result
    return None
def _merge_graphs(model_a, model_b, name="composed"):
    """Merge two ONNX models into a single sequential graph.

    model_a maps the merged graph's "input" to an intermediate tensor and
    model_b maps that intermediate tensor to the merged graph's "output".
    Only the first graph input and first graph output of each model are
    wired.

    Every tensor and initializer name is prefixed per source graph so the
    two graphs cannot collide (e.g. when both models call their tensors
    "input" and "output"). Nodes are copied whole, so attributes and domain
    are preserved.

    Args:
        model_a: ONNX model applied first.
        model_b: ONNX model applied to model_a's first output.
        name: name of the merged graph, also used as the renaming prefix.

    Returns:
        The merged ONNX model, or None if either graph lacks an input or
        output, contains subgraph attributes, or the merged model fails to
        build or pass ``onnx.checker.check_model``.
    """
    import onnx
    from onnx_helpers import DT, IR, GRID_SHAPE, make_opset

    graph_a = model_a.graph
    graph_b = model_b.graph
    if not (graph_a.input and graph_a.output
            and graph_b.input and graph_b.output):
        return None

    # Names referenced inside nested subgraphs (If/Loop/Scan bodies) are not
    # rewritten below, so refuse such models rather than emit a broken graph.
    subgraph_types = (onnx.AttributeProto.GRAPH, onnx.AttributeProto.GRAPHS)
    for node in list(graph_a.node) + list(graph_b.node):
        if any(attr.type in subgraph_types for attr in node.attribute):
            return None

    link = f"{name}_link"  # tensor joining model_a's output to model_b's input
    fixed_a = {graph_a.input[0].name: "input", graph_a.output[0].name: link}
    fixed_b = {graph_b.input[0].name: link, graph_b.output[0].name: "output"}

    def _relabel(graph, fixed, tag):
        """Return copies of graph's nodes and initializers with unique names."""
        def rename(tensor_name):
            if not tensor_name:
                # An empty name marks an omitted optional input; keep it.
                return tensor_name
            return fixed.get(tensor_name, f"{name}_{tag}_{tensor_name}")

        nodes = []
        for idx, node in enumerate(graph.node):
            clone = onnx.NodeProto()
            clone.CopyFrom(node)  # keeps op_type, domain and all attributes
            del clone.input[:]
            clone.input.extend(rename(n) for n in node.input)
            del clone.output[:]
            clone.output.extend(rename(n) for n in node.output)
            label = node.name or f"n{idx}"
            clone.name = f"{name}_{tag}_{label}"
            nodes.append(clone)

        inits = []
        for init in graph.initializer:
            clone = onnx.TensorProto()
            clone.CopyFrom(init)
            clone.name = rename(init.name)
            inits.append(clone)
        return nodes, inits

    nodes_a, inits_a = _relabel(graph_a, fixed_a, "a")
    nodes_b, inits_b = _relabel(graph_b, fixed_b, "b")

    # Build merged graph with the fixed external interface.
    x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
    y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
    try:
        g = helper.make_graph(nodes_a + nodes_b, name, [x], [y],
                              initializer=inits_a + inits_b)
        merged = helper.make_model(g, ir_version=IR,
                                   opset_imports=make_opset(17))
        onnx.checker.check_model(merged)
    except Exception:
        # Best-effort: an invalid composition just means "no solution here".
        return None
    return merged
|