File size: 4,523 Bytes
981ef11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#!/usr/bin/env python3
"""Composition solvers - chain two analytical solvers into one ONNX graph.

C1: transform_then_recolor - spatial transform THEN color_map
C2: crop_then_transform - crop THEN apply spatial transform
C3: recolor_then_tile - color_map THEN tile/upscale
"""

import numpy as np
from onnx import helper, TensorProto, numpy_helper


def s_transform_then_recolor(td, solvers):
    """Try: spatial transform -> color_map composition.

    Spatial solvers are tried in a fixed order. The first one that yields a
    model is merged (via ``_merge_graphs``) with the ``color_map`` model, and
    the first merge that succeeds is returned.

    Args:
        td: task data, passed unchanged to every solver.
        solvers: dict of solver_name -> solver_function. Each solver is
            called as ``fn(td)`` and returns an ONNX model or None; an
            exception from a solver is treated like None.

    Returns: ONNX model or None
    """
    spatial_solvers = ('flip', 'rotate', 'shift', 'mirror_h', 'mirror_v',
                       'quad_mirror', 'transpose')

    cm_fn = solvers.get('color_map')
    if cm_fn is None:
        # Without a second stage no composition is possible.
        return None

    # The color_map model depends only on ``td``, so solve it at most once,
    # and only after some spatial solver has succeeded.
    cm_model = None

    for sname in spatial_solvers:
        sfn = solvers.get(sname)
        if sfn is None:
            continue
        try:
            tf_model = sfn(td)
        except Exception:
            continue
        if tf_model is None:
            continue

        if cm_model is None:
            # NOTE(review): color_map is solved on the original ``td``, not on
            # spatially transformed inputs, so each stage presumably has to
            # fit ``td`` on its own -- TODO confirm this is the intended
            # composition semantics.
            try:
                cm_model = cm_fn(td)
            except Exception:
                return None
            if cm_model is None:
                # Same ``td`` every time, so no later spatial solver can help.
                return None

        composed = _merge_graphs(tf_model, cm_model, f"{sname}_then_recolor")
        if composed is not None:
            return composed

    return None


def s_crop_then_transform(td, solvers):
    """Try: fixed_crop -> spatial transform composition.

    The ``fixed_crop`` model is solved first; if it exists, each of flip,
    rotate, mirror_h and mirror_v is tried in turn and merged after the crop.
    A solver that is absent, raises, or returns None is skipped. Returns the
    first successfully merged ONNX model, or None.
    """
    def attempt(solver_name):
        # Run one solver best-effort: missing / failing / empty -> None.
        fn = solvers.get(solver_name)
        if fn is None:
            return None
        try:
            return fn(td)
        except Exception:
            return None

    crop_model = attempt('fixed_crop')
    if crop_model is None:
        return None

    for tf_name in ('flip', 'rotate', 'mirror_h', 'mirror_v'):
        tf_model = attempt(tf_name)
        if tf_model is None:
            continue
        merged = _merge_graphs(crop_model, tf_model, f"crop_then_{tf_name}")
        if merged is not None:
            return merged
    return None


def s_recolor_then_tile(td, solvers):
    """Try: color_map -> tile/upscale composition.

    Solves ``color_map`` first and, when it produces a model, tries ``tile``
    then ``upscale`` as the second stage. Solvers that are missing, raise, or
    return None are skipped. Returns the first merged ONNX model, or None.
    """
    def solve(key):
        # Best-effort solver invocation; any failure collapses to None.
        fn = solvers.get(key)
        if fn is None:
            return None
        try:
            return fn(td)
        except Exception:
            return None

    recolor = solve('color_map')
    if recolor is None:
        return None

    for stage in ('tile', 'upscale'):
        second = solve(stage)
        if second is not None:
            merged = _merge_graphs(recolor, second, f"recolor_then_{stage}")
            if merged is not None:
                return merged
    return None


def _merge_graphs(model_a, model_b, name="composed"):
    """Merge two ONNX models into a single graph.

    model_a: input -> intermediate
    model_b: intermediate -> output

    Only ``graph.input[0]`` / ``graph.output[0]`` of each model are wired;
    any further graph inputs/outputs are ignored. Every tensor name of each
    sub-graph (node inputs/outputs and initializers) is rewritten into its
    own namespace, so identical names in the two models (e.g. both using
    ``"input"``/``"output"`` or the same initializer name) cannot collide.
    Nodes and initializers are deep-copied, so node attributes and domains
    are preserved.

    NOTE(review): names referenced from inside subgraph attributes
    (If/Loop/Scan bodies) are not rewritten.

    Args:
        model_a: first-stage ONNX ModelProto.
        model_b: second-stage ONNX ModelProto.
        name: graph name; also used as prefix for renamed tensors and nodes.

    Returns:
        Merged ModelProto with graph input ``"input"`` and output
        ``"output"`` (opset 17) if it passes ``onnx.checker.check_model``,
        otherwise None.
    """
    import onnx
    from onnx_helpers import DT, IR, GRID_SHAPE, make_opset

    graph_a = model_a.graph
    graph_b = model_b.graph

    a_in, a_out = graph_a.input[0].name, graph_a.output[0].name
    b_in, b_out = graph_b.input[0].name, graph_b.output[0].name
    # If model_a passes its input straight through, the intermediate *is*
    # the merged graph's input.
    mid = "input" if a_in == a_out else f"{name}_mid"

    def rename_a(tensor):
        # "" denotes an omitted optional input and must stay empty.
        if not tensor:
            return tensor
        if tensor == a_in:
            return "input"
        if tensor == a_out:
            return mid
        return f"{name}_a_{tensor}"

    def rename_b(tensor):
        if not tensor:
            return tensor
        if tensor == b_in:
            return mid
        if tensor == b_out:
            return "output"
        return f"{name}_b_{tensor}"

    def copy_nodes(graph, rename, tag):
        # Deep copy keeps attributes/domain; only tensor names and the node
        # name are rewritten.
        copied = []
        for idx, node in enumerate(graph.node):
            new_node = onnx.NodeProto()
            new_node.CopyFrom(node)
            new_node.ClearField("input")
            new_node.ClearField("output")
            new_node.input.extend(rename(t) for t in node.input)
            new_node.output.extend(rename(t) for t in node.output)
            new_node.name = f"{name}_{tag}_{node.name or f'n{idx}'}"
            copied.append(new_node)
        return copied

    def copy_inits(graph, rename):
        copied = []
        for init in graph.initializer:
            tensor = onnx.TensorProto()
            tensor.CopyFrom(init)
            tensor.name = rename(init.name)
            copied.append(tensor)
        return copied

    nodes = copy_nodes(graph_a, rename_a, "a") + copy_nodes(graph_b, rename_b, "b")
    inits = copy_inits(graph_a, rename_a) + copy_inits(graph_b, rename_b)

    # Build merged graph
    x = helper.make_tensor_value_info("input", DT, GRID_SHAPE)
    y = helper.make_tensor_value_info("output", DT, GRID_SHAPE)
    g = helper.make_graph(nodes, name, [x], [y], initializer=inits)

    try:
        merged = helper.make_model(g, ir_version=IR, opset_imports=make_opset(17))
        onnx.checker.check_model(merged)
        return merged
    except Exception:
        # Invalid compositions are expected; callers treat None as "no model".
        return None