rogermt commited on
Commit
f1edb04
·
verified ·
1 Parent(s): 32b1e49

Fix self_tile: use ConvTranspose for block-upscale (onnx_tool profiles ConvTranspose correctly)

Browse files
own-solver/neurogolf_solver/solvers/wave3.py CHANGED
@@ -1,16 +1,9 @@
1
  #!/usr/bin/env python3
2
  """Wave 3: Split-and-merge + Self-tile solvers.
3
 
4
- - s_split_and_merge: split grid at separator, merge halves via OR/AND/XOR
5
- - s_self_tile: output = input tiled by its own non-bg mask
6
- - s_self_tile_majority: output = input tiled by majority-color mask
7
 
8
- NOTE: Uses Reshape+Tile for block-upscaling instead of Resize (onnx_tool can't profile Resize).
9
-
10
- Validated:
11
- - e98196ab (OR, horizontal split) — 262/262
12
- - 007bbfb7 (self-tile non-bg mask) — 262/262
13
- - c3e719e8 (self-tile majority mask) — 262/262
14
  """
15
 
16
  import numpy as np
@@ -20,28 +13,17 @@ from ..data_loader import get_exs, fixed_shapes
20
  from ..constants import GH, GW
21
 
22
 
23
- # =============================================================================
24
- # SELF-TILE SOLVER (non-bg mask)
25
- # =============================================================================
26
-
27
  def s_self_tile(td):
28
- """Self-tiling: output = input tiled by its own non-bg mask.
29
-
30
- Each non-bg pixel in input becomes a full copy of the input.
31
- Each bg pixel becomes a block of zeros.
32
- Output size = IH*IH x IW*IW.
33
- """
34
  exs = get_exs(td)
35
  sp = fixed_shapes(td)
36
  if sp is None:
37
  return None
38
  (IH, IW), (OH, OW) = sp
39
-
40
  if OH != IH * IH or OW != IW * IW:
41
  return None
42
  if OH > 30 or OW > 30:
43
  return None
44
-
45
  for inp, out in exs:
46
  expected = np.zeros((OH, OW), dtype=np.int64)
47
  for r in range(IH):
@@ -51,63 +33,44 @@ def s_self_tile(td):
51
  if not np.array_equal(expected, out):
52
  return None
53
 
54
- non_bg_ch_mask = np.ones((1, 10, 1, 1), dtype=np.float32)
55
- non_bg_ch_mask[0, 0, 0, 0] = 0.0
 
56
  pad_h, pad_w = GH - OH, GW - OW
57
-
58
  inits = [
59
  _make_int64_init('sl_st', [0, 0, 0, 0]),
60
  _make_int64_init('sl_en', [1, 10, IH, IW]),
61
- numpy_helper.from_array(non_bg_ch_mask, 'nbg_ch'),
62
  _make_int64_init('rs_ax', [1]),
63
- # Block-upscale via Reshape+Tile (no Resize needed)
64
- _make_int64_init('shape_6d', [1, 1, IH, 1, IW, 1]),
65
- _make_int64_init('tile_6d', [1, 1, 1, IH, 1, IW]),
66
- _make_int64_init('shape_4d', [1, 1, OH, OW]),
67
  _make_int64_init('tile_rp', [1, 1, IH, IW]),
68
- numpy_helper.from_array(np.array([0, 0, 0, 0, 0, 0, pad_h, pad_w], dtype=np.int64), 'pads'),
69
  numpy_helper.from_array(np.array(0.0, dtype=np.float32), 'pad_cv'),
70
  ]
71
-
72
  nodes = [
73
  helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['cropped']),
74
  helper.make_node('Mul', ['cropped', 'nbg_ch'], ['non_bg_chs']),
75
  helper.make_node('ReduceSum', ['non_bg_chs', 'rs_ax'], ['mask_small'], keepdims=1),
76
- # Block-upscale: [1,1,IH,IW] [1,1,IH,1,IW,1] → tile → [1,1,IH,IH,IW,IW] → reshape [1,1,OH,OW]
77
- helper.make_node('Reshape', ['mask_small', 'shape_6d'], ['mask_6d']),
78
- helper.make_node('Tile', ['mask_6d', 'tile_6d'], ['mask_tiled']),
79
- helper.make_node('Reshape', ['mask_tiled', 'shape_4d'], ['mask_big']),
80
- # Tile input
81
  helper.make_node('Tile', ['cropped', 'tile_rp'], ['tiled']),
82
  helper.make_node('Mul', ['tiled', 'mask_big'], ['result']),
83
  helper.make_node('Pad', ['result', 'pads', 'pad_cv'], ['output'], mode='constant'),
84
  ]
85
-
86
  return mk(nodes, inits)
87
 
88
 
89
- # =============================================================================
90
- # SELF-TILE SOLVER (majority color mask)
91
- # =============================================================================
92
-
93
  def s_self_tile_majority(td):
94
- """Self-tiling with majority-color mask.
95
-
96
- Each pixel matching the majority color becomes a full copy of input.
97
- Other pixels become blocks of zeros.
98
- Output size = IH*IH x IW*IW.
99
- """
100
  exs = get_exs(td)
101
  sp = fixed_shapes(td)
102
  if sp is None:
103
  return None
104
  (IH, IW), (OH, OW) = sp
105
-
106
  if OH != IH * IH or OW != IW * IW:
107
  return None
108
  if OH > 30 or OW > 30:
109
  return None
110
-
111
  for inp, out in exs:
112
  counts = np.bincount(inp.flatten(), minlength=10)
113
  majority = int(np.argmax(counts))
@@ -119,55 +82,41 @@ def s_self_tile_majority(td):
119
  if not np.array_equal(expected, out):
120
  return None
121
 
 
122
  pad_h, pad_w = GH - OH, GW - OW
123
-
124
  inits = [
125
  _make_int64_init('sl_st', [0, 0, 0, 0]),
126
  _make_int64_init('sl_en', [1, 10, IH, IW]),
127
  _make_int64_init('rs_ax_spatial', [2, 3]),
128
  _make_int64_init('rs_ax_channel', [1]),
129
  numpy_helper.from_array(np.arange(10, dtype=np.int64).reshape(1, 10, 1, 1), 'classes'),
130
- # Block-upscale shapes
131
- _make_int64_init('shape_6d', [1, 1, IH, 1, IW, 1]),
132
- _make_int64_init('tile_6d', [1, 1, 1, IH, 1, IW]),
133
- _make_int64_init('shape_4d', [1, 1, OH, OW]),
134
  _make_int64_init('tile_rp', [1, 1, IH, IW]),
135
- numpy_helper.from_array(np.array([0, 0, 0, 0, 0, 0, pad_h, pad_w], dtype=np.int64), 'pads'),
136
  numpy_helper.from_array(np.array(0.0, dtype=np.float32), 'pad_cv'),
137
  ]
138
-
139
  nodes = [
140
  helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['cropped']),
141
- # Find majority color
142
  helper.make_node('ReduceSum', ['cropped', 'rs_ax_spatial'], ['ch_counts'], keepdims=1),
143
  helper.make_node('ArgMax', ['ch_counts'], ['maj_idx'], axis=1, keepdims=1),
144
- # One-hot majority
145
  helper.make_node('Equal', ['maj_idx', 'classes'], ['maj_eq']),
146
  helper.make_node('Cast', ['maj_eq'], ['maj_oh'], to=TensorProto.FLOAT),
147
- # Extract majority channel as mask
148
  helper.make_node('Mul', ['cropped', 'maj_oh'], ['maj_channel']),
149
  helper.make_node('ReduceSum', ['maj_channel', 'rs_ax_channel'], ['mask_small'], keepdims=1),
150
- # Block-upscale mask (Reshape+Tile, no Resize)
151
- helper.make_node('Reshape', ['mask_small', 'shape_6d'], ['mask_6d']),
152
- helper.make_node('Tile', ['mask_6d', 'tile_6d'], ['mask_tiled']),
153
- helper.make_node('Reshape', ['mask_tiled', 'shape_4d'], ['mask_big']),
154
- # Tile input
155
  helper.make_node('Tile', ['cropped', 'tile_rp'], ['tiled']),
156
- # Multiply
157
  helper.make_node('Mul', ['tiled', 'mask_big'], ['result']),
158
- # Pad
159
  helper.make_node('Pad', ['result', 'pads', 'pad_cv'], ['output'], mode='constant'),
160
  ]
161
-
162
  return mk(nodes, inits)
163
 
164
 
165
  # =============================================================================
166
- # SPLIT-AND-MERGE SOLVERS
167
  # =============================================================================
168
 
169
  def _build_split_or_model(IH, IW, OH, OW, direction, sep_pos):
170
- """Slice two halves + Add + channel-weighted ArgMax."""
171
  ch_weights = np.arange(10, dtype=np.float32).reshape(1, 10, 1, 1)
172
  pad_h, pad_w = GH - OH, GW - OW
173
  classes = np.arange(10, dtype=np.int64).reshape(1, 10, 1, 1)
@@ -195,7 +144,6 @@ def _build_split_or_model(IH, IW, OH, OW, direction, sep_pos):
195
 
196
 
197
  def _build_split_and_model(IH, IW, OH, OW, direction, sep_pos, out_color):
198
- """AND of non-bg masks from two halves."""
199
  pad_h, pad_w = GH - OH, GW - OW
200
  non_bg_mask = np.ones((1,10,1,1), dtype=np.float32); non_bg_mask[0,0,0,0] = 0.0
201
  out_oh = np.zeros((1,10,1,1), dtype=np.float32); out_oh[0,out_color,0,0] = 1.0
@@ -228,7 +176,6 @@ def _build_split_and_model(IH, IW, OH, OW, direction, sep_pos, out_color):
228
 
229
 
230
  def _build_split_xor_model(IH, IW, OH, OW, direction, sep_pos, out_color):
231
- """XOR of non-bg masks from two halves."""
232
  pad_h, pad_w = GH - OH, GW - OW
233
  non_bg_mask = np.ones((1,10,1,1), dtype=np.float32); non_bg_mask[0,0,0,0] = 0.0
234
  out_oh = np.zeros((1,10,1,1), dtype=np.float32); out_oh[0,out_color,0,0] = 1.0
@@ -323,12 +270,11 @@ def s_split_and_merge(td):
323
  model = _build_split_xor_model(IH, IW, OH, OW, 'vertical', sep_col, oc)
324
  if _validate_model(model, exs, OH, OW):
325
  return model
326
-
327
  return None
328
 
329
 
330
  def _validate_model(model, exs, OH, OW):
331
- """Quick validation: run model on all examples."""
332
  import onnxruntime as ort
333
  import tempfile, os, onnx
334
  try:
 
1
  #!/usr/bin/env python3
2
  """Wave 3: Split-and-merge + Self-tile solvers.
3
 
4
+ Uses ConvTranspose(stride) for block-upscaling masks (onnx_tool compatible).
 
 
5
 
6
+ Validated: 007bbfb7 (262/262), c3e719e8 (262/262), e98196ab (262/262)
 
 
 
 
 
7
  """
8
 
9
  import numpy as np
 
13
  from ..constants import GH, GW
14
 
15
 
 
 
 
 
16
  def s_self_tile(td):
17
+ """Self-tiling: each non-bg pixel full copy of input. Output = IH^2 x IW^2."""
 
 
 
 
 
18
  exs = get_exs(td)
19
  sp = fixed_shapes(td)
20
  if sp is None:
21
  return None
22
  (IH, IW), (OH, OW) = sp
 
23
  if OH != IH * IH or OW != IW * IW:
24
  return None
25
  if OH > 30 or OW > 30:
26
  return None
 
27
  for inp, out in exs:
28
  expected = np.zeros((OH, OW), dtype=np.int64)
29
  for r in range(IH):
 
33
  if not np.array_equal(expected, out):
34
  return None
35
 
36
+ non_bg_ch = np.ones((1, 10, 1, 1), dtype=np.float32)
37
+ non_bg_ch[0, 0, 0, 0] = 0.0
38
+ up_k = np.ones((1, 1, IH, IW), dtype=np.float32)
39
  pad_h, pad_w = GH - OH, GW - OW
 
40
  inits = [
41
  _make_int64_init('sl_st', [0, 0, 0, 0]),
42
  _make_int64_init('sl_en', [1, 10, IH, IW]),
43
+ numpy_helper.from_array(non_bg_ch, 'nbg_ch'),
44
  _make_int64_init('rs_ax', [1]),
45
+ numpy_helper.from_array(up_k, 'up_k'),
 
 
 
46
  _make_int64_init('tile_rp', [1, 1, IH, IW]),
47
+ numpy_helper.from_array(np.array([0,0,0,0,0,0,pad_h,pad_w], dtype=np.int64), 'pads'),
48
  numpy_helper.from_array(np.array(0.0, dtype=np.float32), 'pad_cv'),
49
  ]
 
50
  nodes = [
51
  helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['cropped']),
52
  helper.make_node('Mul', ['cropped', 'nbg_ch'], ['non_bg_chs']),
53
  helper.make_node('ReduceSum', ['non_bg_chs', 'rs_ax'], ['mask_small'], keepdims=1),
54
+ helper.make_node('ConvTranspose', ['mask_small', 'up_k'], ['mask_big'],
55
+ kernel_shape=[IH, IW], strides=[IH, IW]),
 
 
 
56
  helper.make_node('Tile', ['cropped', 'tile_rp'], ['tiled']),
57
  helper.make_node('Mul', ['tiled', 'mask_big'], ['result']),
58
  helper.make_node('Pad', ['result', 'pads', 'pad_cv'], ['output'], mode='constant'),
59
  ]
 
60
  return mk(nodes, inits)
61
 
62
 
 
 
 
 
63
  def s_self_tile_majority(td):
64
+ """Self-tiling with majority-color mask. Output = IH^2 x IW^2."""
 
 
 
 
 
65
  exs = get_exs(td)
66
  sp = fixed_shapes(td)
67
  if sp is None:
68
  return None
69
  (IH, IW), (OH, OW) = sp
 
70
  if OH != IH * IH or OW != IW * IW:
71
  return None
72
  if OH > 30 or OW > 30:
73
  return None
 
74
  for inp, out in exs:
75
  counts = np.bincount(inp.flatten(), minlength=10)
76
  majority = int(np.argmax(counts))
 
82
  if not np.array_equal(expected, out):
83
  return None
84
 
85
+ up_k = np.ones((1, 1, IH, IW), dtype=np.float32)
86
  pad_h, pad_w = GH - OH, GW - OW
 
87
  inits = [
88
  _make_int64_init('sl_st', [0, 0, 0, 0]),
89
  _make_int64_init('sl_en', [1, 10, IH, IW]),
90
  _make_int64_init('rs_ax_spatial', [2, 3]),
91
  _make_int64_init('rs_ax_channel', [1]),
92
  numpy_helper.from_array(np.arange(10, dtype=np.int64).reshape(1, 10, 1, 1), 'classes'),
93
+ numpy_helper.from_array(up_k, 'up_k'),
 
 
 
94
  _make_int64_init('tile_rp', [1, 1, IH, IW]),
95
+ numpy_helper.from_array(np.array([0,0,0,0,0,0,pad_h,pad_w], dtype=np.int64), 'pads'),
96
  numpy_helper.from_array(np.array(0.0, dtype=np.float32), 'pad_cv'),
97
  ]
 
98
  nodes = [
99
  helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['cropped']),
 
100
  helper.make_node('ReduceSum', ['cropped', 'rs_ax_spatial'], ['ch_counts'], keepdims=1),
101
  helper.make_node('ArgMax', ['ch_counts'], ['maj_idx'], axis=1, keepdims=1),
 
102
  helper.make_node('Equal', ['maj_idx', 'classes'], ['maj_eq']),
103
  helper.make_node('Cast', ['maj_eq'], ['maj_oh'], to=TensorProto.FLOAT),
 
104
  helper.make_node('Mul', ['cropped', 'maj_oh'], ['maj_channel']),
105
  helper.make_node('ReduceSum', ['maj_channel', 'rs_ax_channel'], ['mask_small'], keepdims=1),
106
+ helper.make_node('ConvTranspose', ['mask_small', 'up_k'], ['mask_big'],
107
+ kernel_shape=[IH, IW], strides=[IH, IW]),
 
 
 
108
  helper.make_node('Tile', ['cropped', 'tile_rp'], ['tiled']),
 
109
  helper.make_node('Mul', ['tiled', 'mask_big'], ['result']),
 
110
  helper.make_node('Pad', ['result', 'pads', 'pad_cv'], ['output'], mode='constant'),
111
  ]
 
112
  return mk(nodes, inits)
113
 
114
 
115
  # =============================================================================
116
+ # SPLIT-AND-MERGE
117
  # =============================================================================
118
 
119
  def _build_split_or_model(IH, IW, OH, OW, direction, sep_pos):
 
120
  ch_weights = np.arange(10, dtype=np.float32).reshape(1, 10, 1, 1)
121
  pad_h, pad_w = GH - OH, GW - OW
122
  classes = np.arange(10, dtype=np.int64).reshape(1, 10, 1, 1)
 
144
 
145
 
146
  def _build_split_and_model(IH, IW, OH, OW, direction, sep_pos, out_color):
 
147
  pad_h, pad_w = GH - OH, GW - OW
148
  non_bg_mask = np.ones((1,10,1,1), dtype=np.float32); non_bg_mask[0,0,0,0] = 0.0
149
  out_oh = np.zeros((1,10,1,1), dtype=np.float32); out_oh[0,out_color,0,0] = 1.0
 
176
 
177
 
178
  def _build_split_xor_model(IH, IW, OH, OW, direction, sep_pos, out_color):
 
179
  pad_h, pad_w = GH - OH, GW - OW
180
  non_bg_mask = np.ones((1,10,1,1), dtype=np.float32); non_bg_mask[0,0,0,0] = 0.0
181
  out_oh = np.zeros((1,10,1,1), dtype=np.float32); out_oh[0,out_color,0,0] = 1.0
 
270
  model = _build_split_xor_model(IH, IW, OH, OW, 'vertical', sep_col, oc)
271
  if _validate_model(model, exs, OH, OW):
272
  return model
 
273
  return None
274
 
275
 
276
  def _validate_model(model, exs, OH, OW):
277
+ """Quick validation on train+test examples."""
278
  import onnxruntime as ort
279
  import tempfile, os, onnx
280
  try: