rogermt commited on
Commit
981ef11
·
verified ·
1 Parent(s): 9caa30d

Add composition solvers for chained transforms

Browse files
Files changed (1) hide show
  1. trm_solver/composition.py +160 -0
trm_solver/composition.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Composition solvers - chain two analytical solvers into one ONNX graph.
3
+
4
+ C1: transform_then_recolor - spatial transform THEN color_map
5
+ C2: crop_then_transform - crop THEN apply spatial transform
6
+ C3: recolor_then_tile - color_map THEN tile/upscale
7
+ """
8
+
9
+ import numpy as np
10
+ from onnx import helper, TensorProto, numpy_helper
11
+
12
+
13
+ def s_transform_then_recolor(td, solvers):
14
+ """Try: spatial transform -> color_map composition.
15
+
16
+ Args:
17
+ td: task data
18
+ solvers: dict of solver_name -> solver_function
19
+
20
+ Returns: ONNX model or None
21
+ """
22
+ from onnx_helpers import mk, GRID_SHAPE, DT, IR, make_opset
23
+
24
+ spatial_solvers = ['flip', 'rotate', 'shift', 'mirror_h', 'mirror_v',
25
+ 'quad_mirror', 'transpose']
26
+
27
+ for sname in spatial_solvers:
28
+ sfn = solvers.get(sname)
29
+ if sfn is None:
30
+ continue
31
+ try:
32
+ tf_model = sfn(td)
33
+ if tf_model is None:
34
+ continue
35
+ except Exception:
36
+ continue
37
+
38
+ # Now try color_map on the transformed data
39
+ cm_fn = solvers.get('color_map')
40
+ if cm_fn is None:
41
+ continue
42
+ try:
43
+ cm_model = cm_fn(td)
44
+ if cm_model is None:
45
+ continue
46
+ except Exception:
47
+ continue
48
+
49
+ # Merge the two ONNX graphs
50
+ composed = _merge_graphs(tf_model, cm_model, f"{sname}_then_recolor")
51
+ if composed is not None:
52
+ return composed
53
+
54
+ return None
55
+
56
+
57
+ def s_crop_then_transform(td, solvers):
58
+ """Try: fixed_crop -> spatial transform composition."""
59
+ crop_fn = solvers.get('fixed_crop')
60
+ if crop_fn is None:
61
+ return None
62
+
63
+ try:
64
+ crop_model = crop_fn(td)
65
+ if crop_model is None:
66
+ return None
67
+ except Exception:
68
+ return None
69
+
70
+ for sname in ['flip', 'rotate', 'mirror_h', 'mirror_v']:
71
+ sfn = solvers.get(sname)
72
+ if sfn is None:
73
+ continue
74
+ try:
75
+ tf_model = sfn(td)
76
+ if tf_model is None:
77
+ continue
78
+ except Exception:
79
+ continue
80
+
81
+ composed = _merge_graphs(crop_model, tf_model, f"crop_then_{sname}")
82
+ if composed is not None:
83
+ return composed
84
+
85
+ return None
86
+
87
+
88
+ def s_recolor_then_tile(td, solvers):
89
+ """Try: color_map -> tile/upscale composition."""
90
+ cm_fn = solvers.get('color_map')
91
+ if cm_fn is None:
92
+ return None
93
+
94
+ try:
95
+ cm_model = cm_fn(td)
96
+ if cm_model is None:
97
+ return None
98
+ except Exception:
99
+ return None
100
+
101
+ for sname in ['tile', 'upscale']:
102
+ sfn = solvers.get(sname)
103
+ if sfn is None:
104
+ continue
105
+ try:
106
+ tile_model = sfn(td)
107
+ if tile_model is None:
108
+ continue
109
+ except Exception:
110
+ continue
111
+
112
+ composed = _merge_graphs(cm_model, tile_model, f"recolor_then_{sname}")
113
+ if composed is not None:
114
+ return composed
115
+
116
+ return None
117
+
118
+
119
+ def _merge_graphs(model_a, model_b, name="composed"):
120
+ """Merge two ONNX models into a single graph.
121
+
122
+ model_a: input -> intermediate
123
+ model_b: intermediate -> output
124
+ """
125
+ import onnx
126
+ from onnx_helpers import DT, IR, GRID_SHAPE, make_opset
127
+
128
+ graph_a = model_a.graph
129
+ graph_b = model_b.graph
130
+
131
+ a_output = graph_a.output[0].name
132
+ b_input = graph_b.input[0].name
133
+
134
+ # Rename model_b's input to match model_a's output
135
+ nodes_b = []
136
+ for node in graph_b.node:
137
+ new_inputs = [a_output if inp == b_input else inp for inp in node.input]
138
+ new_node = helper.make_node(
139
+ node.op_type, new_inputs, node.output,
140
+ name=f"{name}_{node.name}" if node.name else f"{name}_n{len(nodes_b)}"
141
+ )
142
+ nodes_b.append(new_node)
143
+
144
+ # Combine initializers (avoid name collisions)
145
+ inits = list(graph_a.initializer) + list(graph_b.initializer)
146
+
147
+ # Combine nodes
148
+ nodes = list(graph_a.node) + nodes_b
149
+
150
+ # Build merged graph
151
+ x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
152
+ y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
153
+ g = helper.make_graph(nodes, name, [x], [y], initializer=inits)
154
+
155
+ try:
156
+ merged = helper.make_model(g, ir_version=IR, opset_imports=make_opset(17))
157
+ onnx.checker.check_model(merged)
158
+ return merged
159
+ except Exception:
160
+ return None