File size: 7,479 Bytes
e84bcc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
#!/usr/bin/env python3
"""Composition solvers: chain two analytical solvers into one ONNX graph.

These handle tasks where a single transform cannot produce the output,
but two transforms in sequence can. E.g. rotate THEN color_map.

Each composition builds one combined ONNX graph (no intermediate I/O).
"""

import numpy as np
from onnx import helper, TensorProto, numpy_helper
from ..data_loader import get_exs
from ..validators import validate
from ..constants import DT, IR, GRID_SHAPE, GH, GW
from ..config import make_opset
from .analytical import s_color_map
from .geometric import s_flip, s_rotate, s_shift, s_fixed_crop
from .tiling import s_tile, s_upscale, s_mirror_h, s_mirror_v, s_quad_mirror


def _run_solver_on_data(solver_fn, td):
    """Try a solver, return the ONNX model if it passes validation on train pairs."""
    try:
        model = solver_fn(td)
        return model
    except Exception:
        return None


def s_transform_then_recolor(td):
    """Placeholder for the spatial-transform -> color_map composition.

    Not implemented: this always returns None. An earlier draft ran every
    transform solver (flip, rotate, shift, mirrors) and ``s_color_map`` on
    ``td``. It then discarded all of the results without building a composed
    graph, so it spent solver time for no effect.

    ``td`` is accepted, and ignored, so the signature matches the other
    ``s_*`` solvers.

    Returns:
        None, always.
    """
    # TODO: build the merged "input -> transform -> color_map -> output" graph
    # here once the composition strategy is settled.
    return None


def s_transform_then_recolor_v2(td):
    """Compose a fixed spatial transform with a color_map into one ONNX graph.

    Each candidate transform (horizontal/vertical flip, rotation by 90/180/270
    degrees, transpose) is handled in turn:

    1. Build the transform solver's model. Skip the candidate if the solver
       returns None or raises.
    2. Skip the candidate if the transform model alone passes ``validate``,
       because no composition is needed.
    3. Build the ``s_color_map`` model (skip on None or an exception) and merge
       the two graphs with ``_merge_graphs``: transform first, then color_map.
    4. Return the merged model only if it also passes ``validate``.

    Only tasks with at least two examples, all of which keep their shape, are
    considered.

    Returns:
        The first composed ONNX model that passes validation, or None.
    """
    exs = get_exs(td)
    if len(exs) < 2:
        return None

    # Only same-shape transforms are composed here.
    if not all(inp.shape == out.shape for inp, out in exs):
        return None

    # NOTE(review): assumes s_flip / s_rotate accept these keyword arguments;
    # a TypeError would be swallowed below and the candidate skipped.
    transforms = [
        ('flip_h', lambda td: s_flip(td, direction='horizontal')),
        ('flip_v', lambda td: s_flip(td, direction='vertical')),
        ('rotate90', lambda td: s_rotate(td, k=1)),
        ('rotate180', lambda td: s_rotate(td, k=2)),
        ('rotate270', lambda td: s_rotate(td, k=3)),
        ('transpose', _s_transpose),
    ]

    # NOTE(review): each solver below is called on the raw td, so it
    # presumably fits input -> output directly rather than the transformed
    # input. A task that needs both steps may therefore rarely reach the
    # merge. Validating the merged model at least keeps a wrong composition
    # from being returned.
    for tf_name, tf_fn in transforms:
        try:
            tf_model = tf_fn(td)
        except Exception:
            continue
        if tf_model is None:
            continue

        if _passes_validation(tf_model, td):
            continue  # the transform alone solves the task

        try:
            cm_model = s_color_map(td)
        except Exception:
            continue
        if cm_model is None:
            continue

        composed = _merge_graphs(tf_model, cm_model, f"{tf_name}_then_recolor")
        if composed is not None and _passes_validation(composed, td):
            return composed

    return None


def _passes_validation(model, td):
    """Save ``model`` to a temporary .onnx file and run ``validate`` on it.

    The temporary file is always removed. Returns True only if ``validate``
    returns a truthy result. A model that cannot be saved, loaded or run
    (any Exception) counts as failing.
    """
    import os
    import tempfile

    import onnx

    # Create and close the file first so onnx.save can reopen the path on
    # every platform (an open NamedTemporaryFile cannot be reopened on Windows).
    fd, path = tempfile.mkstemp(suffix='.onnx')
    os.close(fd)
    try:
        onnx.save(model, path)
        return bool(validate(path, td, ['CPUExecutionProvider']))
    except Exception:
        return False
    finally:
        os.unlink(path)


def _merge_graphs(model_a, model_b, name="composed"):
    """Chain two single-input/single-output ONNX models into one graph.

    The merged graph computes ``model_b(model_a(input))``:

    * ``model_a``'s first graph input becomes the merged ``"input"``, and its
      first graph output becomes an intermediate tensor.
    * ``model_b``'s first graph input is rewired to that intermediate tensor,
      and its first graph output becomes the merged ``"output"``.
    * Every other tensor name, initializer name and node name from
      ``model_b`` gets a prefix. Without it, names could collide with
      ``model_a``'s; for example, both models may name their result
      "output", which would break ONNX's single-assignment rule.

    Nodes are copied whole, so op attributes (e.g. Transpose ``perm``) and
    domains are preserved.

    Limitations: extra graph inputs beyond the first, and names referenced
    from inside subgraph attributes (If/Loop bodies), are not rewired. Such
    models will fail the checker and yield None.

    Args:
        model_a: first-stage model (input -> intermediate).
        model_b: second-stage model (intermediate -> output).
        name: graph name; also used to derive the intermediate/prefix names.

    Returns:
        The merged ModelProto if it passes ``onnx.checker``, else None.
    """
    import onnx

    graph_a = model_a.graph
    graph_b = model_b.graph

    a_in = graph_a.input[0].name
    a_out = graph_a.output[0].name
    b_in = graph_b.input[0].name
    b_out = graph_b.output[0].name

    # Tensor that carries model_a's result into model_b. If model_a simply
    # forwards its input, the merged input itself is that tensor.
    mid = "input" if a_out == a_in else f"{name}__mid"
    b_prefix = f"{name}__b__"

    def rename_a(tensor):
        if tensor == a_in:
            return "input"
        if tensor == a_out:
            return mid
        return tensor

    def rename_b(tensor):
        if not tensor:  # "" marks an omitted optional input/output
            return tensor
        if tensor == b_in:
            return mid
        if tensor == b_out:
            return "output"
        return b_prefix + tensor

    def copy_nodes(graph, rename, node_prefix):
        # Copy each node whole (attributes, domain, doc string) and only
        # rewrite its tensor names.
        copied = []
        for node in graph.node:
            new = onnx.NodeProto()
            new.CopyFrom(node)
            del new.input[:]
            new.input.extend(rename(t) for t in node.input)
            del new.output[:]
            new.output.extend(rename(t) for t in node.output)
            if node.name:
                new.name = node_prefix + node.name
            copied.append(new)
        return copied

    def copy_initializers(graph, rename):
        copied = []
        for init in graph.initializer:
            new = onnx.TensorProto()
            new.CopyFrom(init)
            new.name = rename(init.name)
            copied.append(new)
        return copied

    nodes = copy_nodes(graph_a, rename_a, "") + copy_nodes(graph_b, rename_b, b_prefix)
    inits = copy_initializers(graph_a, rename_a) + copy_initializers(graph_b, rename_b)

    # Build merged graph
    x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
    y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
    g = helper.make_graph(nodes, name, [x], [y], initializer=inits)

    try:
        merged = helper.make_model(g, ir_version=IR, opset_imports=make_opset(17))
        onnx.checker.check_model(merged)
        return merged
    except Exception:
        # Best-effort: an invalid composition simply means "no model".
        return None


def _s_transpose(td):
    """Build a Transpose model if every example output is its input transposed.

    Returns None when there are no examples, or when any example output
    differs from ``inp.T``. Otherwise it returns ``mk`` applied to a single
    ONNX Transpose node with ``perm=[0, 1, 3, 2]``. That node swaps the last
    two axes and assumes a 4-D graph tensor.
    """
    from ..onnx_helpers import mk

    exs = get_exs(td)
    # With no examples the check below would pass vacuously.
    if not exs:
        return None
    # np.array_equal is False on any shape mismatch, so it also enforces the
    # (H, W) -> (W, H) shape requirement.
    if not all(np.array_equal(inp.T, out) for inp, out in exs):
        return None

    # NOTE(review): the node writes 't_out'; confirm mk() exposes that tensor
    # as the graph output (_merge_graphs declares its output as "output").
    nodes = [helper.make_node('Transpose', ['input'], ['t_out'], perm=[0, 1, 3, 2])]
    return mk(nodes)


def s_recolor_then_tile(td):
    """Compose color_map followed by tile or upscale into one ONNX graph.

    Builds the ``s_color_map`` model, then tries ``s_tile`` and ``s_upscale``
    in that order. The first spatial model that is produced is merged after
    the color map with ``_merge_graphs``. A solver that raises is treated
    like one that returns None.

    NOTE(review): every solver here is fitted on the raw ``td`` pairs, and
    the merged model is returned without running ``validate`` on it; confirm
    that callers validate the result.

    Returns:
        The merged ONNX model, or None.
    """
    try:
        cm_model = s_color_map(td)
    except Exception:
        return None
    if cm_model is None:
        return None

    for graph_name, spatial_fn in (
        ("recolor_then_tile", s_tile),
        ("recolor_then_upscale", s_upscale),
    ):
        try:
            spatial_model = spatial_fn(td)
        except Exception:
            continue
        if spatial_model is None:
            continue
        composed = _merge_graphs(cm_model, spatial_model, graph_name)
        if composed is not None:
            return composed

    return None


def s_crop_then_transform(td):
    """Compose a fixed crop with a flip or rotation into one ONNX graph.

    The ``s_fixed_crop`` model is built first; if it is None, nothing is
    composed. Then ``s_flip`` and ``s_rotate`` are tried in that order. The
    first one whose model merges successfully after the crop (via
    ``_merge_graphs``) is returned. An exception from either of those
    solvers, or from the merge, skips to the next candidate.

    Returns:
        The merged ONNX model, or None.
    """
    base = s_fixed_crop(td)
    if base is None:
        return None

    for solver in (s_flip, s_rotate):
        try:
            candidate = solver(td)
            if candidate is None:
                continue
            merged = _merge_graphs(base, candidate, "crop_then_transform")
        except Exception:
            continue
        if merged is not None:
            return merged

    return None