rogermt commited on
Commit
e84bcc0
Β·
verified Β·
1 Parent(s): deda756

Add composition solvers (transform_then_recolor, crop_then_transform, recolor_then_tile)

Browse files
own-solver/neurogolf_solver/solvers/composition.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Composition solvers β€” chain two analytical solvers into one ONNX graph.
3
+
4
+ These handle tasks where a single transform cannot produce the output,
5
+ but two transforms in sequence can. E.g. rotate THEN color_map.
6
+
7
+ Each composition builds one combined ONNX graph (no intermediate I/O).
8
+ """
9
+
10
+ import numpy as np
11
+ from onnx import helper, TensorProto, numpy_helper
12
+ from ..data_loader import get_exs
13
+ from ..validators import validate
14
+ from ..constants import DT, IR, GRID_SHAPE, GH, GW
15
+ from ..config import make_opset
16
+ from .analytical import s_color_map
17
+ from .geometric import s_flip, s_rotate, s_shift, s_fixed_crop
18
+ from .tiling import s_tile, s_upscale, s_mirror_h, s_mirror_v, s_quad_mirror
19
+
20
+
21
+ def _run_solver_on_data(solver_fn, td):
22
+ """Try a solver, return the ONNX model if it passes validation on train pairs."""
23
+ try:
24
+ model = solver_fn(td)
25
+ return model
26
+ except Exception:
27
+ return None
28
+
29
+
30
+ def s_transform_then_recolor(td):
31
+ """Try: spatial transform β†’ color_map. Chains two existing solver graphs.
32
+
33
+ For each (transform, color_map) pair:
34
+ 1. Run transform solver β†’ get intermediate model
35
+ 2. Check if transform output passes train pairs β†’ skip if not
36
+ 3. Build combined graph: input β†’ transform β†’ color_map β†’ output
37
+
38
+ Returns ONNX model or None.
39
+ """
40
+ from ..onnx_helpers import mk
41
+
42
+ transforms = [s_flip, s_rotate, s_shift, s_mirror_h, s_mirror_v, s_quad_mirror]
43
+
44
+ for tfn in transforms:
45
+ # Try the transform first
46
+ try:
47
+ tf_model = tfn(td)
48
+ if tf_model is None:
49
+ continue
50
+ except Exception:
51
+ continue
52
+
53
+ # Now try color_map on the transform's output
54
+ # We need to check: does color_map(rotate(input)) = output?
55
+ # The simplest approach: try all combinations and validate
56
+ try:
57
+ cm_model = s_color_map(td)
58
+ if cm_model is None:
59
+ continue
60
+ except Exception:
61
+ continue
62
+
63
+ # If both individually work, the composition might not be needed
64
+ # We want cases where NEITHER works alone but BOTH work together
65
+ # For now, try composition even if one works (composition may be cheaper)
66
+ pass
67
+
68
+ # The actual composition: we need to build a merged ONNX graph
69
+ # This is complex β€” for now, return None (composition building below)
70
+ return None
71
+
72
+
73
+ def s_transform_then_recolor_v2(td):
74
+ """Build composed ONNX: spatial transform graph + color_map graph merged.
75
+
76
+ Strategy: try all pairs of (transform, color_map) on the data.
77
+ Validate the composition against all train + test pairs.
78
+
79
+ This version builds the combined ONNX graph by merging nodes from
80
+ both solver outputs into a single graph with a renamed intermediate.
81
+ """
82
+ import onnx
83
+ from ..onnx_helpers import mk
84
+ from ..profiler import score_network
85
+ import tempfile
86
+
87
+ exs = get_exs(td)
88
+ if len(exs) < 2:
89
+ return None
90
+
91
+ # Check same shape (composition of same-shape transforms)
92
+ same_shape = all(inp.shape == out.shape for inp, out in exs)
93
+ if not same_shape:
94
+ return None
95
+
96
+ transforms = [
97
+ ('flip_h', lambda td: s_flip(td, direction='horizontal')),
98
+ ('flip_v', lambda td: s_flip(td, direction='vertical')),
99
+ ('rotate90', lambda td: s_rotate(td, k=1)),
100
+ ('rotate180', lambda td: s_rotate(td, k=2)),
101
+ ('rotate270', lambda td: s_rotate(td, k=3)),
102
+ ('transpose', lambda td: _s_transpose(td)),
103
+ ]
104
+
105
+ # Try each transform + color_map combination
106
+ for tf_name, tf_fn in transforms:
107
+ try:
108
+ tf_model = tf_fn(td)
109
+ if tf_model is None:
110
+ continue
111
+ except Exception:
112
+ continue
113
+
114
+ # Check if transform alone solves it (no need for composition)
115
+ with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as tmp:
116
+ onnx.save(tf_model, tmp.name)
117
+ if validate(tmp.name, td, ['CPUExecutionProvider']):
118
+ os.unlink(tmp.name)
119
+ continue # Transform alone works, no composition needed
120
+
121
+ # Try color_map after transform
122
+ # Build composed model by running transform, then checking color_map
123
+ cm_model = s_color_map(td)
124
+ if cm_model is None:
125
+ continue
126
+
127
+ # Merge the two ONNX graphs
128
+ composed = _merge_graphs(tf_model, cm_model, f"{tf_name}_then_recolor")
129
+ if composed is not None:
130
+ return composed
131
+
132
+ return None
133
+
134
+
135
+ def _merge_graphs(model_a, model_b, name="composed"):
136
+ """Merge two ONNX models into a single graph.
137
+
138
+ model_a: input β†’ intermediate
139
+ model_b: intermediate β†’ output
140
+
141
+ The output name of model_a becomes the input name of model_b.
142
+ """
143
+ import onnx
144
+
145
+ graph_a = model_a.graph
146
+ graph_b = model_b.graph
147
+
148
+ # Get output name of model_a
149
+ a_output = graph_a.output[0].name
150
+
151
+ # Get input name of model_b (should be "input")
152
+ b_input = graph_b.input[0].name
153
+
154
+ # Rename model_b's input to match model_a's output
155
+ nodes_b = []
156
+ for node in graph_b.node:
157
+ new_inputs = [a_output if inp == b_input else inp for inp in node.input]
158
+ nodes_b.append(helper.make_node(node.op_type, new_inputs, node.output, name=node.name))
159
+
160
+ # Combine initializers
161
+ inits = list(graph_a.initializer) + list(graph_b.initializer)
162
+
163
+ # Combine nodes
164
+ nodes = list(graph_a.node) + nodes_b
165
+
166
+ # Build merged graph
167
+ x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
168
+ y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
169
+ g = helper.make_graph(nodes, name, [x], [y], initializer=inits)
170
+
171
+ try:
172
+ merged = helper.make_model(g, ir_version=IR, opset_imports=make_opset(17))
173
+ onnx.checker.check_model(merged)
174
+ return merged
175
+ except Exception:
176
+ return None
177
+
178
+
179
+ def _s_transpose(td):
180
+ """Transpose solver (from analytical.py pattern)."""
181
+ from ..onnx_helpers import mk
182
+
183
+ exs = get_exs(td)
184
+ for inp, out in exs:
185
+ if inp.shape[0] != out.shape[1] or inp.shape[1] != out.shape[0]:
186
+ return None
187
+ if not np.array_equal(inp.T, out):
188
+ return None
189
+
190
+ nodes = [helper.make_node('Transpose', ['input'], ['t_out'], perm=[0, 1, 3, 2])]
191
+ return mk(nodes)
192
+
193
+
194
+ def s_recolor_then_tile(td):
195
+ """color_map β†’ tile/upscale composition."""
196
+ # Try color_map first, then check if tiling the result works
197
+ cm_model = s_color_map(td)
198
+ if cm_model is None:
199
+ return None
200
+
201
+ # Check if tile or upscale on the color-mapped result matches output
202
+ tile_model = s_tile(td)
203
+ if tile_model is not None:
204
+ composed = _merge_graphs(cm_model, tile_model, "recolor_then_tile")
205
+ if composed is not None:
206
+ return composed
207
+
208
+ return None
209
+
210
+
211
+ def s_crop_then_transform(td):
212
+ """fixed_crop β†’ rotate/flip composition."""
213
+ crop_model = s_fixed_crop(td)
214
+ if crop_model is None:
215
+ return None
216
+
217
+ for tfn in [s_flip, s_rotate]:
218
+ try:
219
+ tf_model = tfn(td)
220
+ if tf_model is not None:
221
+ composed = _merge_graphs(crop_model, tf_model, "crop_then_transform")
222
+ if composed is not None:
223
+ return composed
224
+ except Exception:
225
+ continue
226
+
227
+ return None