File size: 7,479 Bytes
e84bcc0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 | #!/usr/bin/env python3
"""Composition solvers β chain two analytical solvers into one ONNX graph.
These handle tasks where a single transform cannot produce the output,
but two transforms in sequence can. E.g. rotate THEN color_map.
Each composition builds one combined ONNX graph (no intermediate I/O).
"""
import numpy as np
from onnx import helper, TensorProto, numpy_helper
from ..data_loader import get_exs
from ..validators import validate
from ..constants import DT, IR, GRID_SHAPE, GH, GW
from ..config import make_opset
from .analytical import s_color_map
from .geometric import s_flip, s_rotate, s_shift, s_fixed_crop
from .tiling import s_tile, s_upscale, s_mirror_h, s_mirror_v, s_quad_mirror
def _run_solver_on_data(solver_fn, td):
"""Try a solver, return the ONNX model if it passes validation on train pairs."""
try:
model = solver_fn(td)
return model
except Exception:
return None
def s_transform_then_recolor(td):
    """Placeholder for spatial-transform -> color_map composition; always None.

    An earlier draft ran every transform solver and ``s_color_map`` on ``td``.
    It never combined their results, so it always returned None while paying
    for every solver call. The composition search is attempted in
    ``s_transform_then_recolor_v2``. This entry point is kept only so existing
    callers keep working.

    Args:
        td: task data (unused).

    Returns:
        None, always.
    """
    return None
def s_transform_then_recolor_v2(td):
    """Try spatial transform -> color_map, merged into a single ONNX graph.

    Procedure:
      1. Bail out unless there are at least two examples and every example
         keeps its grid shape (only same-shape transforms are tried).
      2. For each candidate transform (flips, rotations, transpose), build its
         model. Skip it if it cannot be built, or if it already passes
         ``validate`` on its own (no composition needed).
      3. Build the color_map model once, merge transform -> color_map with
         ``_merge_graphs``, and return the merged model only if it passes
         ``validate``.

    NOTE(review): ``s_color_map`` is fitted on the raw input->output pairs,
    not on the transformed inputs. A merged candidate is therefore trusted only
    after it passes ``validate`` itself.

    Returns:
        The first composed ONNX model that validates, or None.
    """
    import os
    import tempfile

    import onnx

    def _passes(model):
        # validate() takes a file path. Write to a closed temp file (safe on
        # every OS) and always delete it, whatever validate() does.
        fd, path = tempfile.mkstemp(suffix='.onnx')
        os.close(fd)
        try:
            onnx.save(model, path)
            return bool(validate(path, td, ['CPUExecutionProvider']))
        finally:
            os.unlink(path)

    exs = get_exs(td)
    if len(exs) < 2:
        return None
    # Only same-shape compositions are considered here.
    if not all(inp.shape == out.shape for inp, out in exs):
        return None

    # NOTE(review): confirm s_flip/s_rotate accept these keyword arguments.
    # A TypeError is swallowed below and the candidate is silently skipped.
    transforms = [
        ('flip_h', lambda td: s_flip(td, direction='horizontal')),
        ('flip_v', lambda td: s_flip(td, direction='vertical')),
        ('rotate90', lambda td: s_rotate(td, k=1)),
        ('rotate180', lambda td: s_rotate(td, k=2)),
        ('rotate270', lambda td: s_rotate(td, k=3)),
        ('transpose', lambda td: _s_transpose(td)),
    ]

    cm_model = None
    cm_built = False
    for tf_name, tf_fn in transforms:
        try:
            tf_model = tf_fn(td)
        except Exception:
            continue
        if tf_model is None:
            continue
        if _passes(tf_model):
            continue  # the transform alone solves the task

        # The color_map model does not depend on the transform: build it once.
        if not cm_built:
            cm_model = s_color_map(td)
            cm_built = True
        if cm_model is None:
            return None  # no later transform can be composed either

        composed = _merge_graphs(tf_model, cm_model, f"{tf_name}_then_recolor")
        if composed is not None and _passes(composed):
            return composed
    return None
def _merge_graphs(model_a, model_b, name="composed"):
    """Merge two single-input/single-output ONNX models: ``model_b(model_a(x))``.

    Wiring:
      * model_a's first graph input is renamed to the merged input ``"input"``.
      * model_a's first graph output becomes an intermediate tensor, and
        model_b's first graph input is renamed to that same tensor.
      * model_b's first graph output is renamed to the merged output ``"output"``.

    Every other tensor, initializer and node name is prefixed per source graph
    (``<name>__a__`` / ``<name>__b__``). Both models commonly use names such as
    ``"input"``/``"output"``; without the prefix two nodes would produce the
    same tensor and the checker would reject the graph. Nodes are copied whole,
    so op domains and attributes (e.g. Transpose ``perm``) are preserved.

    NOTE(review): names referenced inside control-flow subgraphs (If/Loop
    bodies) are not renamed. Models using those may fail the checker.

    Args:
        model_a: first-stage ONNX model.
        model_b: second-stage ONNX model.
        name: merged graph name, also used to build the prefixes and the
            intermediate tensor name.

    Returns:
        The merged model if ``onnx.checker.check_model`` accepts it, else None.
    """
    import onnx

    graph_a = model_a.graph
    graph_b = model_b.graph
    mid = f"{name}__mid"
    prefix_a = f"{name}__a__"
    prefix_b = f"{name}__b__"

    def _renamer(io_map, prefix):
        # Graph I/O names go through io_map; every other name gets `prefix`.
        def rename(tensor_name):
            if not tensor_name:
                # '' marks an omitted optional input and must stay empty.
                return tensor_name
            return io_map.get(tensor_name, prefix + tensor_name)
        return rename

    def _copy_nodes(graph, rename, prefix):
        # Copy each node whole (keeps op_type, domain, attributes), then rename I/O.
        copied = []
        for node in graph.node:
            new_node = onnx.NodeProto()
            new_node.CopyFrom(node)
            del new_node.input[:]
            new_node.input.extend(rename(t) for t in node.input)
            del new_node.output[:]
            new_node.output.extend(rename(t) for t in node.output)
            if node.name:
                new_node.name = prefix + node.name
            copied.append(new_node)
        return copied

    def _copy_initializers(graph, rename):
        # Initializers are tensors too and must follow the same renaming.
        copied = []
        for tensor in graph.initializer:
            new_tensor = onnx.TensorProto()
            new_tensor.CopyFrom(tensor)
            new_tensor.name = rename(tensor.name)
            copied.append(new_tensor)
        return copied

    rename_a = _renamer(
        {graph_a.input[0].name: "input", graph_a.output[0].name: mid}, prefix_a)
    rename_b = _renamer(
        {graph_b.input[0].name: mid, graph_b.output[0].name: "output"}, prefix_b)

    nodes = (_copy_nodes(graph_a, rename_a, prefix_a)
             + _copy_nodes(graph_b, rename_b, prefix_b))
    inits = (_copy_initializers(graph_a, rename_a)
             + _copy_initializers(graph_b, rename_b))

    x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
    y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
    graph = helper.make_graph(nodes, name, [x], [y], initializer=inits)
    try:
        merged = helper.make_model(graph, ir_version=IR, opset_imports=make_opset(17))
        onnx.checker.check_model(merged)
    except Exception:
        # Model construction or the checker rejected the merged graph.
        return None
    return merged
def _s_transpose(td):
    """Build a Transpose model if every example output is its input transposed.

    Each (input, output) pair from ``get_exs(td)`` must have swapped 2-D
    dimensions and satisfy ``output == input.T``. Otherwise None is returned.
    The emitted node swaps the last two axes (perm=[0, 1, 3, 2]).
    """
    from ..onnx_helpers import mk

    def _is_transposed(src, dst):
        # Check the swapped shape first; array_equal runs only if it matches.
        dims_swapped = src.shape[0] == dst.shape[1] and src.shape[1] == dst.shape[0]
        return dims_swapped and np.array_equal(src.T, dst)

    if not all(_is_transposed(src, dst) for src, dst in get_exs(td)):
        return None
    # NOTE(review): the node writes 't_out', not 'output'. Confirm mk() wires
    # the last node's output to the graph output.
    transpose_node = helper.make_node('Transpose', ['input'], ['t_out'], perm=[0, 1, 3, 2])
    return mk([transpose_node])
def s_recolor_then_tile(td):
    """Try color_map -> tile/upscale, merged into a single ONNX graph.

    Builds the color_map model, then for each second stage (``s_tile``, then
    ``s_upscale``) builds that model and merges the two with
    ``_merge_graphs``. It returns the first merged model that passes
    ``validate``.

    NOTE(review): both stages are fitted on the raw input->output pairs, not
    on the color-mapped intermediate. The merged graph is therefore validated
    before being returned rather than trusted by construction.

    Returns:
        A composed ONNX model, or None if color_map cannot be built or no
        composition validates.
    """
    import os
    import tempfile

    import onnx

    def _passes(model):
        # validate() takes a file path. Use a closed temp file and always remove it.
        fd, path = tempfile.mkstemp(suffix='.onnx')
        os.close(fd)
        try:
            onnx.save(model, path)
            return bool(validate(path, td, ['CPUExecutionProvider']))
        finally:
            os.unlink(path)

    cm_model = s_color_map(td)
    if cm_model is None:
        return None

    for stage_name, stage_fn in (('tile', s_tile), ('upscale', s_upscale)):
        # A failing second-stage solver just means "try the next one".
        try:
            stage_model = stage_fn(td)
        except Exception:
            continue
        if stage_model is None:
            continue
        composed = _merge_graphs(cm_model, stage_model, f"recolor_then_{stage_name}")
        if composed is not None and _passes(composed):
            return composed
    return None
def s_crop_then_transform(td):
    """Try fixed_crop -> flip/rotate, merged into a single ONNX graph.

    Builds the fixed-crop model, then for each transform solver (``s_flip``,
    then ``s_rotate``) merges crop -> transform with ``_merge_graphs``. It
    returns the first merged model that passes ``validate``.

    NOTE(review): the transform solvers are fitted on the raw input->output
    pairs, not on the cropped intermediate. The merged graph is therefore
    validated before being returned.

    Returns:
        A composed ONNX model, or None if the crop cannot be built or no
        composition validates.
    """
    import os
    import tempfile

    import onnx

    def _passes(model):
        # validate() takes a file path. Use a closed temp file and always remove it.
        fd, path = tempfile.mkstemp(suffix='.onnx')
        os.close(fd)
        try:
            onnx.save(model, path)
            return bool(validate(path, td, ['CPUExecutionProvider']))
        finally:
            os.unlink(path)

    crop_model = s_fixed_crop(td)
    if crop_model is None:
        return None

    for tfn in (s_flip, s_rotate):
        # A failing transform solver just means "try the next one".
        try:
            tf_model = tfn(td)
        except Exception:
            continue
        if tf_model is None:
            continue
        composed = _merge_graphs(crop_model, tf_model, "crop_then_transform")
        if composed is not None and _passes(composed):
            return composed
    return None
|