# Source: commit e84bcc0 (verified) by rogermt:
# "Add composition solvers (transform_then_recolor, crop_then_transform, recolor_then_tile)"
#!/usr/bin/env python3
"""Composition solvers — chain two analytical solvers into one ONNX graph.
These handle tasks where a single transform cannot produce the output,
but two transforms in sequence can. E.g. rotate THEN color_map.
Each composition builds one combined ONNX graph (no intermediate I/O).
"""
import numpy as np
from onnx import helper, TensorProto, numpy_helper
from ..data_loader import get_exs
from ..validators import validate
from ..constants import DT, IR, GRID_SHAPE, GH, GW
from ..config import make_opset
from .analytical import s_color_map
from .geometric import s_flip, s_rotate, s_shift, s_fixed_crop
from .tiling import s_tile, s_upscale, s_mirror_h, s_mirror_v, s_quad_mirror
def _run_solver_on_data(solver_fn, td):
"""Try a solver, return the ONNX model if it passes validation on train pairs."""
try:
model = solver_fn(td)
return model
except Exception:
return None
def s_transform_then_recolor(td):
    """Placeholder for ``spatial transform -> color_map`` composition; always None.

    Intended algorithm, for each (transform, color_map) pair:
      1. run the transform solver to get an intermediate model;
      2. skip the pair if the transform alone already explains the train pairs;
      3. build one combined graph: input -> transform -> color_map -> output.

    Graph merging is not implemented in this function, so it returns None
    unconditionally and performs no solver work (running the solvers only to
    discard their results would cost time without changing the outcome).
    ``s_transform_then_recolor_v2`` in this module is the variant that builds
    and returns a composed graph.

    Args:
        td: task data (same object the other solvers receive); unused.

    Returns:
        None.
    """
    return None
def s_transform_then_recolor_v2(td):
    """Try ``spatial transform -> color_map`` as one composed ONNX graph.

    For each candidate transform (horizontal/vertical flip, 90/180/270 degree
    rotation, transpose):
      * run the transform solver on ``td``; skip it if it fails or returns None;
      * skip it if the transform model alone already passes ``validate``
        (no composition needed);
      * otherwise merge it with the ``s_color_map`` model and return the merged
        model if, and only if, it passes ``validate`` on ``td``.

    Args:
        td: task data accepted by ``get_exs`` and the individual solvers.

    Returns:
        The first composed model that validates, or None.
    """
    import os
    import tempfile

    import onnx

    def _passes(model):
        # ``validate`` takes a file path, so round-trip through a temp file.
        # The fd is closed before saving (portable to Windows) and the file is
        # always removed, whether validation passes, fails or raises.
        fd, path = tempfile.mkstemp(suffix='.onnx')
        os.close(fd)
        try:
            onnx.save(model, path)
            return bool(validate(path, td, ['CPUExecutionProvider']))
        finally:
            os.unlink(path)

    exs = get_exs(td)
    if len(exs) < 2:
        return None
    # Only same-shape tasks are considered (composition of same-shape transforms).
    if not all(inp.shape == out.shape for inp, out in exs):
        return None

    transforms = [
        ('flip_h', lambda td: s_flip(td, direction='horizontal')),
        ('flip_v', lambda td: s_flip(td, direction='vertical')),
        ('rotate90', lambda td: s_rotate(td, k=1)),
        ('rotate180', lambda td: s_rotate(td, k=2)),
        ('rotate270', lambda td: s_rotate(td, k=3)),
        ('transpose', lambda td: _s_transpose(td)),
    ]

    cm_model = None  # color_map model, computed lazily and at most once
    for tf_name, tf_fn in transforms:
        try:
            tf_model = tf_fn(td)
        except Exception:
            continue
        if tf_model is None:
            continue
        if _passes(tf_model):
            continue  # Transform alone works, no composition needed

        if cm_model is None:
            # s_color_map only sees ``td``, so its answer cannot change between
            # iterations (assumed deterministic); without it nothing can compose.
            try:
                cm_model = s_color_map(td)
            except Exception:
                return None
            if cm_model is None:
                return None

        composed = _merge_graphs(tf_model, cm_model, f"{tf_name}_then_recolor")
        # A merge that passes onnx.checker is not necessarily a correct solution.
        if composed is not None and _passes(composed):
            return composed
    return None
def _merge_graphs(model_a, model_b, name="composed"):
    """Merge two single-input/single-output ONNX models into one graph.

    The merged graph computes ``model_b(model_a(input))``:
      * model_a's first graph input is renamed to ``"input"``;
      * model_a's first graph output is wired into model_b's first graph input;
      * model_b's first graph output is renamed to ``"output"``.

    Every other tensor, initializer and node name is prefixed with ``"a/"`` or
    ``"b/"``. This prevents collisions such as both graphs producing a tensor
    called ``"output"`` or sharing an initializer name. Nodes and initializers
    are deep-copied, so node attributes (``perm``, ``axis``, ...) and domains
    are preserved.

    NOTE(review): names referenced from inside control-flow subgraphs
    (If/Loop bodies) are not renamed; such nodes are assumed not to occur here.

    Args:
        model_a: first model (input -> intermediate).
        model_b: second model (intermediate -> output).
        name: name given to the merged graph.

    Returns:
        The merged ModelProto, or None if it fails ``onnx.checker``.
    """
    import onnx

    def _renamed(graph, prefix, io_map):
        """Return copies of graph's nodes/initializers with tensor names remapped."""
        def rn(tensor):
            if not tensor:  # '' marks an omitted optional input; keep it
                return tensor
            return io_map.get(tensor, prefix + tensor)

        nodes = []
        for node in graph.node:
            new = onnx.NodeProto()
            new.CopyFrom(node)  # keeps attributes, domain, doc_string
            new.ClearField('input')
            new.input.extend([rn(t) for t in node.input])
            new.ClearField('output')
            new.output.extend([rn(t) for t in node.output])
            if node.name:
                new.name = prefix + node.name
            nodes.append(new)

        inits = []
        for init in graph.initializer:
            new = onnx.TensorProto()
            new.CopyFrom(init)
            new.name = rn(init.name)
            inits.append(new)
        return nodes, inits

    graph_a = model_a.graph
    graph_b = model_b.graph
    a_in, a_out = graph_a.input[0].name, graph_a.output[0].name
    b_in, b_out = graph_b.input[0].name, graph_b.output[0].name

    map_a = {a_in: "input"}
    # Name of the tensor handed from graph A to graph B ("input" if A is a
    # pass-through whose output is its input).
    mid = map_a.get(a_out, "a/" + a_out)

    map_b = {b_in: mid}
    if b_out != b_in:
        map_b[b_out] = "output"

    nodes_a, inits_a = _renamed(graph_a, "a/", map_a)
    nodes_b, inits_b = _renamed(graph_b, "b/", map_b)
    nodes = nodes_a + nodes_b
    if b_out == b_in:
        # Graph B is a pass-through; connect the intermediate to "output".
        nodes.append(helper.make_node('Identity', [mid], ['output'], name=f"{name}/out"))

    x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
    y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
    g = helper.make_graph(nodes, name, [x], [y], initializer=inits_a + inits_b)
    try:
        merged = helper.make_model(g, ir_version=IR, opset_imports=make_opset(17))
        onnx.checker.check_model(merged)
        return merged
    except Exception:
        # Best-effort: an invalid merge just means this composition is unusable.
        return None
def _s_transpose(td):
    """Return a Transpose model if every example output is its input transposed.

    The emitted node swaps the last two axes (``perm=[0, 1, 3, 2]``) of the
    rank-4 model input and writes tensor ``'t_out'``; ``mk`` builds the model.

    Args:
        td: task data accepted by ``get_exs``.

    Returns:
        The model, or None if there are no examples or any example differs.
    """
    from ..onnx_helpers import mk
    exs = get_exs(td)
    # With no examples nothing supports the transpose hypothesis.
    # np.array_equal is False on a shape mismatch, so it also covers the
    # (H, W) -> (W, H) shape check.
    if not exs or not all(np.array_equal(inp.T, out) for inp, out in exs):
        return None
    nodes = [helper.make_node('Transpose', ['input'], ['t_out'], perm=[0, 1, 3, 2])]
    return mk(nodes)
def s_recolor_then_tile(td):
    """Try ``color_map -> tile/upscale`` as one composed ONNX graph.

    Runs ``s_color_map`` on ``td``. Then, for each of ``s_tile`` and
    ``s_upscale``, it merges the color-map graph in front of the spatial
    graph. It returns the first merged model that passes ``validate`` on ``td``.
    Solvers that raise are treated as not applicable.

    Args:
        td: task data accepted by the individual solvers.

    Returns:
        The composed model, or None.
    """
    import os
    import tempfile

    import onnx

    def _passes(model):
        # ``validate`` takes a file path; the temp file is always removed.
        fd, path = tempfile.mkstemp(suffix='.onnx')
        os.close(fd)
        try:
            onnx.save(model, path)
            return bool(validate(path, td, ['CPUExecutionProvider']))
        finally:
            os.unlink(path)

    try:
        cm_model = s_color_map(td)
    except Exception:
        return None
    if cm_model is None:
        return None

    for graph_name, spatial_fn in (("recolor_then_tile", s_tile),
                                   ("recolor_then_upscale", s_upscale)):
        try:
            spatial_model = spatial_fn(td)
        except Exception:
            continue
        if spatial_model is None:
            continue
        composed = _merge_graphs(cm_model, spatial_model, graph_name)
        # Only return a composition that actually reproduces the examples.
        if composed is not None and _passes(composed):
            return composed
    return None
def s_crop_then_transform(td):
    """Try ``fixed_crop -> rotate/flip`` as one composed ONNX graph.

    Runs ``s_fixed_crop`` on ``td``. Then, for each of ``s_flip`` and
    ``s_rotate``, it merges the crop graph in front of the transform graph.
    It returns the first merged model that passes ``validate`` on ``td``.
    Solvers that raise are treated as not applicable.

    Args:
        td: task data accepted by the individual solvers.

    Returns:
        The composed model, or None.
    """
    import os
    import tempfile

    import onnx

    def _passes(model):
        # ``validate`` takes a file path; the temp file is always removed.
        fd, path = tempfile.mkstemp(suffix='.onnx')
        os.close(fd)
        try:
            onnx.save(model, path)
            return bool(validate(path, td, ['CPUExecutionProvider']))
        finally:
            os.unlink(path)

    try:
        crop_model = s_fixed_crop(td)
    except Exception:
        return None
    if crop_model is None:
        return None

    for tfn in (s_flip, s_rotate):
        try:
            tf_model = tfn(td)
            if tf_model is None:
                continue
            composed = _merge_graphs(crop_model, tf_model, "crop_then_transform")
        except Exception:
            continue
        # Only return a composition that actually reproduces the examples.
        if composed is not None and _passes(composed):
            return composed
    return None