ARC-AGI / trm_solver /composition.py
rogermt's picture
Add composition solvers for chained transforms
981ef11 verified
#!/usr/bin/env python3
"""Composition solvers - chain two analytical solvers into one ONNX graph.
C1: transform_then_recolor - spatial transform THEN color_map
C2: crop_then_transform - crop THEN apply spatial transform
C3: recolor_then_tile - color_map THEN tile/upscale
"""
import numpy as np
from onnx import helper, TensorProto, numpy_helper
def s_transform_then_recolor(td, solvers):
    """Try: spatial transform -> color_map composition.

    The ``color_map`` solver is fitted once. Each spatial solver is then
    fitted in the order listed in ``spatial_solvers``, and the first pair
    whose graphs merge successfully is returned.

    NOTE(review): both stages are fitted on the unmodified ``td``. The colour
    map is NOT learned on transform-applied inputs, so a pair is only
    produced when each solver returns a model for ``td`` on its own --
    confirm this is the intended semantics.

    Args:
        td: task data, passed unchanged to every solver.
        solvers: dict of solver_name -> solver_function. Each function takes
            ``td`` and returns an ONNX model or None.

    Returns:
        The merged ONNX model, or None if no combination applies.
    """
    spatial_solvers = ['flip', 'rotate', 'shift', 'mirror_h', 'mirror_v',
                       'quad_mirror', 'transpose']

    # The colour-map fit does not depend on which spatial solver is being
    # tried, so compute it once rather than once per candidate.
    cm_fn = solvers.get('color_map')
    if cm_fn is None:
        return None
    try:
        cm_model = cm_fn(td)
    except Exception:
        # Solvers are best-effort; a failure just means "does not apply".
        return None
    if cm_model is None:
        return None

    for sname in spatial_solvers:
        sfn = solvers.get(sname)
        if sfn is None:
            continue
        try:
            tf_model = sfn(td)
        except Exception:
            continue
        if tf_model is None:
            continue
        composed = _merge_graphs(tf_model, cm_model, f"{sname}_then_recolor")
        if composed is not None:
            return composed
    return None
def s_crop_then_transform(td, solvers):
    """Try: fixed_crop -> spatial transform composition.

    The ``fixed_crop`` solver is fitted first. If it yields a model, the
    spatial solvers 'flip', 'rotate', 'mirror_h' and 'mirror_v' are tried in
    that order. The first one whose graph merges with the crop graph wins.
    A solver that is missing, raises, or returns None is skipped.

    NOTE(review): every solver receives the unmodified ``td``, not cropped
    grids -- confirm this is intended.

    Args:
        td: task data, passed unchanged to every solver.
        solvers: dict of solver_name -> solver_function.

    Returns:
        The merged ONNX model, or None if no combination applies.
    """
    def fit(solver_name):
        # Run one named solver on td; None when absent, failing or inapplicable.
        solver = solvers.get(solver_name)
        if solver is None:
            return None
        try:
            return solver(td)
        except Exception:
            return None

    crop_model = fit('fixed_crop')
    if crop_model is None:
        return None

    for candidate in ('flip', 'rotate', 'mirror_h', 'mirror_v'):
        transform_model = fit(candidate)
        if transform_model is None:
            continue
        merged = _merge_graphs(crop_model, transform_model,
                               f"crop_then_{candidate}")
        if merged is not None:
            return merged
    return None
def s_recolor_then_tile(td, solvers):
    """Try: color_map -> tile/upscale composition.

    The ``color_map`` solver is fitted first. If it yields a model, 'tile'
    and then 'upscale' are tried. The first one whose graph merges with the
    colour-map graph is returned. A solver that is missing, raises, or
    returns None is skipped.

    NOTE(review): every solver receives the unmodified ``td``, not recoloured
    grids -- confirm this is intended.

    Args:
        td: task data, passed unchanged to every solver.
        solvers: dict of solver_name -> solver_function.

    Returns:
        The merged ONNX model, or None if no combination applies.
    """
    def fit(solver_name):
        # Run one named solver on td; None when absent, failing or inapplicable.
        solver = solvers.get(solver_name)
        if solver is None:
            return None
        try:
            return solver(td)
        except Exception:
            return None

    recolor_model = fit('color_map')
    if recolor_model is None:
        return None

    for candidate in ('tile', 'upscale'):
        resize_model = fit(candidate)
        if resize_model is None:
            continue
        merged = _merge_graphs(recolor_model, resize_model,
                               f"recolor_then_{candidate}")
        if merged is not None:
            return merged
    return None
def _merge_graphs(model_a, model_b, name="composed"):
    """Merge two ONNX models into a single graph computing ``b(a(input))``.

    model_a: input -> intermediate
    model_b: intermediate -> output

    Wiring:
      - model_a's first graph input is fed from the merged graph's
        ``"input"`` tensor.
      - model_a's first graph output is connected to model_b's first graph
        input.
      - model_b's first graph output becomes the merged graph's ``"output"``
        tensor.

    Every other tensor name and node name is prefixed per source model.
    Names shared by both graphs (e.g. both using ``"input"``/``"output"``
    for their I/O, or reusing an initializer name) therefore cannot collide.
    Nodes are copied whole, so their attributes (``perm``, ``axis``, ...)
    and domain are preserved.

    Args:
        model_a: ONNX ``ModelProto`` applied first.
        model_b: ONNX ``ModelProto`` applied to model_a's result.
        name: name of the merged graph; also used as the prefix for renamed
            tensors and nodes.

    Returns:
        The merged ``ModelProto`` if ``onnx.checker.check_model`` accepts it,
        otherwise None (including when either graph lacks an input/output or
        contains subgraph attributes).
    """
    import onnx
    from onnx_helpers import DT, IR, GRID_SHAPE, make_opset

    graph_a = model_a.graph
    graph_b = model_b.graph
    # Each side needs at least one graph input and output to be wired up.
    if not (graph_a.input and graph_a.output
            and graph_b.input and graph_b.output):
        return None

    # Subgraph attributes (If/Loop/Scan bodies) can refer to outer-scope
    # tensors by name, and those references are not renamed here. Refuse
    # such graphs rather than emit a silently broken model.
    subgraph_kinds = (onnx.AttributeProto.GRAPH, onnx.AttributeProto.GRAPHS)
    if any(attr.type in subgraph_kinds
           for graph in (graph_a, graph_b)
           for node in graph.node
           for attr in node.attribute):
        return None

    # NOTE(review): assumes input[0] is the data input (not an initializer
    # listed among the inputs, as older IR versions allow) -- confirm.
    a_in, a_out = graph_a.input[0].name, graph_a.output[0].name
    b_in, b_out = graph_b.input[0].name, graph_b.output[0].name
    # Tensor linking the two stages. If model_a passes its input straight
    # through (output name == input name), the link is "input" itself.
    mid = "input" if a_out == a_in else f"{name}__mid"

    nodes_a, inits_a = _rename_graph_tensors(
        graph_a, f"{name}__a_", {a_in: "input", a_out: mid})
    nodes_b, inits_b = _rename_graph_tensors(
        graph_b, f"{name}__b_", {b_in: mid, b_out: "output"})

    x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
    y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
    g = helper.make_graph(nodes_a + nodes_b, name, [x], [y],
                          initializer=inits_a + inits_b)
    try:
        merged = helper.make_model(g, ir_version=IR,
                                   opset_imports=make_opset(17))
        onnx.checker.check_model(merged)
    except Exception:
        # Best effort: callers treat None as "this pair does not compose".
        return None
    return merged


def _rename_graph_tensors(graph, prefix, io_map):
    """Copy ``graph``'s nodes and initializers with tensor names rewritten.

    Names present in ``io_map`` are replaced by their mapped value. Every
    other non-empty tensor name, and every node name, gets ``prefix``
    prepended. Empty names (ONNX's marker for an omitted optional input) are
    left as-is. The source graph is not modified.

    Returns:
        ``(nodes, initializers)`` as lists of new protobuf messages.
    """
    import onnx

    def rename(tensor):
        if not tensor:
            return tensor
        return io_map.get(tensor, prefix + tensor)

    nodes = []
    for idx, node in enumerate(graph.node):
        new_node = onnx.NodeProto()
        new_node.CopyFrom(node)  # keeps op_type, domain and attributes
        del new_node.input[:]
        new_node.input.extend([rename(t) for t in node.input])
        del new_node.output[:]
        new_node.output.extend([rename(t) for t in node.output])
        new_node.name = prefix + (node.name or f"n{idx}")
        nodes.append(new_node)

    inits = []
    for init in graph.initializer:
        new_init = onnx.TensorProto()
        new_init.CopyFrom(init)
        new_init.name = rename(init.name)
        inits.append(new_init)
    return nodes, inits