rogermt commited on
Commit
3849c9c
·
verified ·
1 Parent(s): ac61883

Add Wave 2 solvers: 17 new analytical solvers (overlay, bbox_crop, row/col_mode_fill, fill_bg, pad_align, multi_stamp, diagonal_flip, invert_colors, majority_fill, border_extract, interior_fill, repeat_row/col, swap_colors, max_pool, crop_paste)

Browse files
own-solver/neurogolf_solver/solvers/wave2.py ADDED
@@ -0,0 +1,863 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Wave 2 analytical solvers — new patterns not covered by existing solvers.
3
+
4
+ New solvers:
5
+ - s_overlay_constant: input + fixed constant overlay
6
+ - s_bbox_crop: crop to bounding box of non-bg pixels
7
+ - s_row_mode_fill: each row filled with its mode color
8
+ - s_col_mode_fill: each column filled with its mode color
9
+ - s_fill_bg_with_mode: bg pixels → dominant non-bg color
10
+ - s_pad_align: input pasted into larger canvas at fixed offset
11
+ - s_multi_stamp: union of shifted copies
12
+ - s_diagonal_flip: anti-diagonal transpose
13
+ - s_invert_colors: color inversion (max_val - c)
14
+ - s_majority_color_fill: solid fill with least common input color
15
+ - s_border_extract: keep only border pixels
16
+ - s_interior_fill: keep only interior pixels
17
+ - s_repeat_row: repeat one row vertically
18
+ - s_repeat_col: repeat one column horizontally
19
+ - s_swap_two_colors: swap exactly two colors
20
+ - s_max_pool_downsample: max-pool style downsampling
21
+ """
22
+
23
+ import numpy as np
24
+ from onnx import helper, numpy_helper, TensorProto
25
+ from ..onnx_helpers import mk, _make_int64_init, _build_pad_node, add_onehot_block
26
+ from ..data_loader import get_exs, fixed_shapes
27
+ from ..gather_helpers import _build_gather_model, _build_gather_model_with_const
28
+ from ..constants import GH, GW
29
+
30
+
31
+ # =============================================================================
32
+ # SOLVER: overlay_constant
33
+ # =============================================================================
34
+
35
+ def s_overlay_constant(td):
36
+ """Output = input with specific fixed positions overwritten by constant colors."""
37
+ exs = get_exs(td)
38
+ sp = fixed_shapes(td)
39
+ if sp is None:
40
+ return None
41
+ (IH, IW), (OH, OW) = sp
42
+ if (IH, IW) != (OH, OW):
43
+ return None
44
+
45
+ diff_positions = None
46
+ const_values = {}
47
+
48
+ for inp, out in exs:
49
+ current_diffs = set()
50
+ for r in range(IH):
51
+ for c in range(IW):
52
+ if inp[r, c] != out[r, c]:
53
+ current_diffs.add((r, c))
54
+ key = (r, c)
55
+ if key in const_values:
56
+ if const_values[key] != int(out[r, c]):
57
+ return None
58
+ else:
59
+ const_values[key] = int(out[r, c])
60
+
61
+ if diff_positions is None:
62
+ diff_positions = current_diffs
63
+ elif diff_positions != current_diffs:
64
+ return None
65
+
66
+ if not diff_positions or len(diff_positions) == 0:
67
+ return None
68
+ if len(diff_positions) > IH * IW * 0.8:
69
+ return None
70
+
71
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
72
+ cst = np.full((OH, OW), -1, dtype=np.int64)
73
+
74
+ for r in range(OH):
75
+ for c in range(OW):
76
+ if (r, c) in const_values:
77
+ idx[r, c] = [-1, -1]
78
+ cst[r, c] = const_values[(r, c)]
79
+ else:
80
+ idx[r, c] = [r, c]
81
+
82
+ return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
83
+
84
+
85
+ # =============================================================================
86
+ # SOLVER: bbox_crop
87
+ # =============================================================================
88
+
89
+ def s_bbox_crop(td):
90
+ """Crop input to the bounding box of non-background pixels."""
91
+ exs = get_exs(td)
92
+ sp = fixed_shapes(td)
93
+ if sp is None:
94
+ return None
95
+ (IH, IW), (OH, OW) = sp
96
+ if OH >= IH and OW >= IW:
97
+ return None
98
+
99
+ for bg_color in range(10):
100
+ ok = True
101
+ for inp, out in exs:
102
+ rows = np.any(inp != bg_color, axis=1)
103
+ cols = np.any(inp != bg_color, axis=0)
104
+ if not np.any(rows) or not np.any(cols):
105
+ ok = False
106
+ break
107
+ r_min, r_max = np.where(rows)[0][[0, -1]]
108
+ c_min, c_max = np.where(cols)[0][[0, -1]]
109
+ cropped = inp[r_min:r_max+1, c_min:c_max+1]
110
+ if cropped.shape != out.shape or not np.array_equal(cropped, out):
111
+ ok = False
112
+ break
113
+ if ok:
114
+ inp0 = exs[0][0]
115
+ rows = np.any(inp0 != bg_color, axis=1)
116
+ cols = np.any(inp0 != bg_color, axis=0)
117
+ r_min = int(np.where(rows)[0][0])
118
+ c_min = int(np.where(cols)[0][0])
119
+
120
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
121
+ for r in range(OH):
122
+ for c in range(OW):
123
+ idx[r, c] = [r + r_min, c + c_min]
124
+ return _build_gather_model(OH, OW, idx)
125
+
126
+ return None
127
+
128
+
129
+ # =============================================================================
130
+ # SOLVER: row_mode_fill
131
+ # =============================================================================
132
+
133
+ def s_row_mode_fill(td):
134
+ """Each row of output = solid fill of the mode color of that input row."""
135
+ exs = get_exs(td)
136
+ sp = fixed_shapes(td)
137
+ if sp is None:
138
+ return None
139
+ (IH, IW), (OH, OW) = sp
140
+ if (IH, IW) != (OH, OW):
141
+ return None
142
+
143
+ for inp, out in exs:
144
+ for r in range(IH):
145
+ counts = np.bincount(inp[r, :], minlength=10)
146
+ mode_color = int(np.argmax(counts))
147
+ if not np.all(out[r, :] == mode_color):
148
+ return None
149
+
150
+ pad_h, pad_w = GH - IH, GW - IW
151
+ inits = [
152
+ _make_int64_init('sl_st', [0, 0, 0, 0]),
153
+ _make_int64_init('sl_en', [1, 10, IH, IW]),
154
+ _make_int64_init('rs_ax', [3]),
155
+ _make_int64_init('tile_rp', [1, 1, 1, IW]),
156
+ ]
157
+ nodes = [
158
+ helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['cropped']),
159
+ helper.make_node('ReduceSum', ['cropped', 'rs_ax'], ['row_sums'], keepdims=1),
160
+ helper.make_node('ArgMax', ['row_sums'], ['row_modes'], axis=1, keepdims=1),
161
+ helper.make_node('Tile', ['row_modes', 'tile_rp'], ['tiled']),
162
+ ]
163
+ add_onehot_block(nodes, inits, 'tiled', 'oh_out')
164
+ nodes.append(_build_pad_node('oh_out', 'output', pad_h, pad_w, inits))
165
+ return mk(nodes, inits)
166
+
167
+
168
+ # =============================================================================
169
+ # SOLVER: col_mode_fill
170
+ # =============================================================================
171
+
172
+ def s_col_mode_fill(td):
173
+ """Each column of output = solid fill of the mode color of that input column."""
174
+ exs = get_exs(td)
175
+ sp = fixed_shapes(td)
176
+ if sp is None:
177
+ return None
178
+ (IH, IW), (OH, OW) = sp
179
+ if (IH, IW) != (OH, OW):
180
+ return None
181
+
182
+ for inp, out in exs:
183
+ for c in range(IW):
184
+ counts = np.bincount(inp[:, c], minlength=10)
185
+ mode_color = int(np.argmax(counts))
186
+ if not np.all(out[:, c] == mode_color):
187
+ return None
188
+
189
+ pad_h, pad_w = GH - IH, GW - IW
190
+ inits = [
191
+ _make_int64_init('sl_st', [0, 0, 0, 0]),
192
+ _make_int64_init('sl_en', [1, 10, IH, IW]),
193
+ _make_int64_init('rs_ax', [2]),
194
+ _make_int64_init('tile_rp', [1, 1, IH, 1]),
195
+ ]
196
+ nodes = [
197
+ helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['cropped']),
198
+ helper.make_node('ReduceSum', ['cropped', 'rs_ax'], ['col_sums'], keepdims=1),
199
+ helper.make_node('ArgMax', ['col_sums'], ['col_modes'], axis=1, keepdims=1),
200
+ helper.make_node('Tile', ['col_modes', 'tile_rp'], ['tiled']),
201
+ ]
202
+ add_onehot_block(nodes, inits, 'tiled', 'oh_out')
203
+ nodes.append(_build_pad_node('oh_out', 'output', pad_h, pad_w, inits))
204
+ return mk(nodes, inits)
205
+
206
+
207
+ # =============================================================================
208
+ # SOLVER: fill_bg_with_mode
209
+ # =============================================================================
210
+
211
+ def s_fill_bg_with_mode(td):
212
+ """Background pixels replaced by the most common non-background color."""
213
+ exs = get_exs(td)
214
+ sp = fixed_shapes(td)
215
+ if sp is None:
216
+ return None
217
+ (IH, IW), (OH, OW) = sp
218
+ if (IH, IW) != (OH, OW):
219
+ return None
220
+
221
+ for bg_color in range(10):
222
+ ok = True
223
+ for inp, out in exs:
224
+ non_bg = inp[inp != bg_color]
225
+ if len(non_bg) == 0:
226
+ ok = False
227
+ break
228
+ counts = np.bincount(non_bg, minlength=10)
229
+ mode_color = int(np.argmax(counts))
230
+ expected = inp.copy()
231
+ expected[inp == bg_color] = mode_color
232
+ if not np.array_equal(expected, out):
233
+ ok = False
234
+ break
235
+
236
+ if ok:
237
+ pad_h, pad_w = GH - IH, GW - IW
238
+ non_bg_mask = np.ones((1, 10, 1, 1), dtype=np.float32)
239
+ non_bg_mask[0, bg_color, 0, 0] = 0.0
240
+
241
+ inits = [
242
+ _make_int64_init('sl_st', [0, 0, 0, 0]),
243
+ _make_int64_init('sl_en', [1, 10, IH, IW]),
244
+ numpy_helper.from_array(non_bg_mask, 'non_bg_mask'),
245
+ _make_int64_init('rs_ax', [2, 3]),
246
+ _make_int64_init('tile_rp', [1, 1, IH, IW]),
247
+ numpy_helper.from_array(np.float32(0.5), 'half'),
248
+ _make_int64_init('bg_st', [0, bg_color, 0, 0]),
249
+ _make_int64_init('bg_en', [1, bg_color + 1, IH, IW]),
250
+ ]
251
+ nodes = [
252
+ helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['cropped']),
253
+ helper.make_node('Mul', ['cropped', 'non_bg_mask'], ['non_bg_only']),
254
+ helper.make_node('ReduceSum', ['non_bg_only', 'rs_ax'], ['color_counts'], keepdims=1),
255
+ helper.make_node('ArgMax', ['color_counts'], ['mode_idx'], axis=1, keepdims=1),
256
+ helper.make_node('Tile', ['mode_idx', 'tile_rp'], ['mode_tiled']),
257
+ ]
258
+ add_onehot_block(nodes, inits, 'mode_tiled', 'mode_oh')
259
+ nodes.extend([
260
+ helper.make_node('Slice', ['cropped', 'bg_st', 'bg_en'], ['bg_ch']),
261
+ helper.make_node('Greater', ['bg_ch', 'half'], ['is_bg']),
262
+ helper.make_node('Where', ['is_bg', 'mode_oh', 'cropped'], ['filled']),
263
+ ])
264
+ nodes.append(_build_pad_node('filled', 'output', pad_h, pad_w, inits))
265
+ return mk(nodes, inits)
266
+
267
+ return None
268
+
269
+
270
+ # =============================================================================
271
+ # SOLVER: pad_align
272
+ # =============================================================================
273
+
274
+ def s_pad_align(td):
275
+ """Input pasted into larger canvas at fixed offset, rest = bg_color."""
276
+ exs = get_exs(td)
277
+ sp = fixed_shapes(td)
278
+ if sp is None:
279
+ return None
280
+ (IH, IW), (OH, OW) = sp
281
+ if OH <= IH and OW <= IW:
282
+ return None
283
+ if OH > 30 or OW > 30:
284
+ return None
285
+
286
+ for bg_color in range(10):
287
+ for r0 in range(OH - IH + 1):
288
+ for c0 in range(OW - IW + 1):
289
+ ok = True
290
+ for inp, out in exs:
291
+ expected = np.full((OH, OW), bg_color, dtype=np.int64)
292
+ expected[r0:r0+IH, c0:c0+IW] = inp
293
+ if not np.array_equal(expected, out):
294
+ ok = False
295
+ break
296
+ if ok:
297
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
298
+ cst = np.full((OH, OW), -1, dtype=np.int64)
299
+ for r in range(OH):
300
+ for c in range(OW):
301
+ if r0 <= r < r0 + IH and c0 <= c < c0 + IW:
302
+ idx[r, c] = [r - r0, c - c0]
303
+ else:
304
+ idx[r, c] = [-1, -1]
305
+ cst[r, c] = bg_color
306
+ return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
307
+ return None
308
+
309
+
310
+ # =============================================================================
311
+ # SOLVER: multi_stamp
312
+ # =============================================================================
313
+
314
+ def s_multi_stamp(td):
315
+ """Output = union of multiple shifted copies of input at fixed offsets."""
316
+ exs = get_exs(td)
317
+ sp = fixed_shapes(td)
318
+ if sp is None:
319
+ return None
320
+ (IH, IW), (OH, OW) = sp
321
+ if (IH, IW) != (OH, OW):
322
+ return None
323
+
324
+ def apply_stamps(inp, offsets, bg=0):
325
+ result = np.full((IH, IW), bg, dtype=np.int64)
326
+ for dr, dc in offsets:
327
+ for r in range(IH):
328
+ for c in range(IW):
329
+ sr, sc = r - dr, c - dc
330
+ if 0 <= sr < IH and 0 <= sc < IW and inp[sr, sc] != bg:
331
+ result[r, c] = inp[sr, sc]
332
+ return result
333
+
334
+ for bg in range(2):
335
+ small_range = range(-min(IH, 8), min(IH, 8) + 1)
336
+ for dr1 in small_range:
337
+ for dc1 in small_range:
338
+ if dr1 == 0 and dc1 == 0:
339
+ continue
340
+ offsets = [(0, 0), (dr1, dc1)]
341
+ if all(np.array_equal(apply_stamps(inp, offsets, bg), out)
342
+ for inp, out in exs):
343
+ return _build_stamp_model(IH, IW, offsets, bg)
344
+ return None
345
+
346
+
347
+ def _build_stamp_model(IH, IW, offsets, bg):
348
+ """Build ONNX model for multi-stamp."""
349
+ OH, OW = IH, IW
350
+ pad_h, pad_w = GH - IH, GW - IW
351
+ n_stamps = len(offsets)
352
+
353
+ inits = [
354
+ _make_int64_init('sl_st', [0, 0, 0, 0]),
355
+ _make_int64_init('sl_en', [1, 10, IH, IW]),
356
+ ]
357
+ nodes = [
358
+ helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['cropped']),
359
+ ]
360
+
361
+ stamp_names = []
362
+ for i, (dr, dc) in enumerate(offsets):
363
+ flat_idx = np.zeros((IH * IW,), dtype=np.int64)
364
+ mask = np.zeros((1, 1, IH, IW), dtype=np.float32)
365
+
366
+ for r in range(IH):
367
+ for c in range(IW):
368
+ sr, sc = r - dr, c - dc
369
+ if 0 <= sr < IH and 0 <= sc < IW:
370
+ flat_idx[r * IW + c] = sr * IW + sc
371
+ mask[0, 0, r, c] = 1.0
372
+
373
+ idx_name = f'sidx_{i}'
374
+ mask_name = f'smask_{i}'
375
+ inits.append(numpy_helper.from_array(flat_idx, idx_name))
376
+ inits.append(numpy_helper.from_array(mask, mask_name))
377
+ inits.append(_make_int64_init(f'sfs_{i}', [1, 10, IH * IW]))
378
+ inits.append(_make_int64_init(f'sos_{i}', [1, 10, IH, IW]))
379
+
380
+ nodes.extend([
381
+ helper.make_node('Reshape', ['cropped', f'sfs_{i}'], [f'sflat_{i}']),
382
+ helper.make_node('Gather', [f'sflat_{i}', idx_name], [f'sg_{i}'], axis=2),
383
+ helper.make_node('Reshape', [f'sg_{i}', f'sos_{i}'], [f'sraw_{i}']),
384
+ helper.make_node('Mul', [f'sraw_{i}', mask_name], [f'stamp_{i}']),
385
+ ])
386
+ stamp_names.append(f'stamp_{i}')
387
+
388
+ # Sum all stamps
389
+ if len(stamp_names) == 2:
390
+ nodes.append(helper.make_node('Add', stamp_names, ['sum_stamps']))
391
+ else:
392
+ nodes.append(helper.make_node('Add', [stamp_names[0], stamp_names[1]], ['sum_01']))
393
+ cur = 'sum_01'
394
+ for i in range(2, len(stamp_names)):
395
+ nxt = f'sum_0{i}' if i < len(stamp_names) - 1 else 'sum_stamps'
396
+ nodes.append(helper.make_node('Add', [cur, stamp_names[i]], [nxt]))
397
+ cur = nxt
398
+
399
+ # Weight channels so non-bg colors always win over bg
400
+ ch_weights = np.arange(10, dtype=np.float32).reshape(1, 10, 1, 1)
401
+ inits.append(numpy_helper.from_array(ch_weights, 'ch_w'))
402
+ nodes.append(helper.make_node('Mul', ['sum_stamps', 'ch_w'], ['weighted']))
403
+ nodes.append(helper.make_node('ArgMax', ['weighted'], ['am'], axis=1, keepdims=1))
404
+ add_onehot_block(nodes, inits, 'am', 'oh_out')
405
+ nodes.append(_build_pad_node('oh_out', 'output', pad_h, pad_w, inits))
406
+ return mk(nodes, inits)
407
+
408
+
409
+ # =============================================================================
410
+ # SOLVER: diagonal_flip
411
+ # =============================================================================
412
+
413
+ def s_diagonal_flip(td):
414
+ """Flip along anti-diagonal: out[r,c] = inp[IW-1-c, IH-1-r]."""
415
+ exs = get_exs(td)
416
+ sp = fixed_shapes(td)
417
+ if sp is None:
418
+ return None
419
+ (IH, IW), (OH, OW) = sp
420
+ if OH != IW or OW != IH:
421
+ return None
422
+
423
+ ok = True
424
+ for inp, out in exs:
425
+ expected = np.zeros((OH, OW), dtype=np.int64)
426
+ for r in range(OH):
427
+ for c in range(OW):
428
+ expected[r, c] = inp[IW - 1 - c, IH - 1 - r]
429
+ if not np.array_equal(expected, out):
430
+ ok = False
431
+ break
432
+
433
+ if ok:
434
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
435
+ for r in range(OH):
436
+ for c in range(OW):
437
+ idx[r, c] = [IW - 1 - c, IH - 1 - r]
438
+ return _build_gather_model(OH, OW, idx)
439
+
440
+ return None
441
+
442
+
443
+ # =============================================================================
444
+ # SOLVER: invert_colors
445
+ # =============================================================================
446
+
447
+ def s_invert_colors(td):
448
+ """Color inversion: out = max_val - inp."""
449
+ exs = get_exs(td)
450
+ sp = fixed_shapes(td)
451
+ if sp is None:
452
+ return None
453
+ (IH, IW), (OH, OW) = sp
454
+ if (IH, IW) != (OH, OW):
455
+ return None
456
+
457
+ for max_val in range(1, 10):
458
+ ok = True
459
+ for inp, out in exs:
460
+ expected = max_val - inp
461
+ if np.any(expected < 0) or np.any(expected >= 10):
462
+ ok = False
463
+ break
464
+ if not np.array_equal(out, expected):
465
+ ok = False
466
+ break
467
+ if ok:
468
+ W = np.zeros((10, 10, 1, 1), dtype=np.float32)
469
+ for c in range(10):
470
+ dst = max_val - c
471
+ if 0 <= dst < 10:
472
+ W[dst, c, 0, 0] = 1.0
473
+ inits = [numpy_helper.from_array(W, 'W')]
474
+ nodes = [helper.make_node('Conv', ['input', 'W'], ['output'], kernel_shape=[1, 1])]
475
+ return mk(nodes, inits)
476
+
477
+ return None
478
+
479
+
480
+ # =============================================================================
481
+ # SOLVER: majority_color_fill (solid fill with least common color)
482
+ # =============================================================================
483
+
484
+ def s_majority_color_fill(td):
485
+ """Output = solid grid filled with the least common present color of input."""
486
+ exs = get_exs(td)
487
+ sp = fixed_shapes(td)
488
+ if sp is None:
489
+ return None
490
+ (IH, IW), (OH, OW) = sp
491
+
492
+ # Check if output is always a solid single color
493
+ for _, out in exs:
494
+ if not np.all(out == out[0, 0]):
495
+ return None
496
+
497
+ # Check if color is the LEAST common present color
498
+ ok = True
499
+ for inp, out in exs:
500
+ out_color = int(out[0, 0])
501
+ counts = np.bincount(inp.flatten(), minlength=10)
502
+ present = np.where(counts > 0)[0]
503
+ if len(present) < 2:
504
+ ok = False
505
+ break
506
+ min_count_color = present[np.argmin(counts[present])]
507
+ if min_count_color != out_color:
508
+ ok = False
509
+ break
510
+
511
+ if ok and len(exs) >= 2:
512
+ pad_h, pad_w = GH - OH, GW - OW
513
+ inits = [
514
+ _make_int64_init('sl_st', [0, 0, 0, 0]),
515
+ _make_int64_init('sl_en', [1, 10, IH, IW]),
516
+ _make_int64_init('rs_ax', [2, 3]),
517
+ _make_int64_init('tile_rp', [1, 1, OH, OW]),
518
+ numpy_helper.from_array(np.float32(-1.0), 'neg_one'),
519
+ numpy_helper.from_array(np.float32(0.5), 'half'),
520
+ ]
521
+ nodes = [
522
+ helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['cropped']),
523
+ helper.make_node('ReduceSum', ['cropped', 'rs_ax'], ['counts'], keepdims=1),
524
+ # Negate so argmax gives argmin
525
+ helper.make_node('Mul', ['counts', 'neg_one'], ['neg_counts']),
526
+ # Mask out zero-count channels (so they don't win)
527
+ helper.make_node('Greater', ['counts', 'half'], ['present']),
528
+ helper.make_node('Cast', ['present'], ['present_f'], to=TensorProto.FLOAT),
529
+ helper.make_node('Mul', ['neg_counts', 'present_f'], ['masked_neg']),
530
+ helper.make_node('ArgMax', ['masked_neg'], ['min_idx'], axis=1, keepdims=1),
531
+ helper.make_node('Tile', ['min_idx', 'tile_rp'], ['tiled']),
532
+ ]
533
+ add_onehot_block(nodes, inits, 'tiled', 'oh_out')
534
+ nodes.append(_build_pad_node('oh_out', 'output', pad_h, pad_w, inits))
535
+ return mk(nodes, inits)
536
+
537
+ return None
538
+
539
+
540
+ # =============================================================================
541
+ # SOLVER: border_extract
542
+ # =============================================================================
543
+
544
+ def s_border_extract(td):
545
+ """Output = only the border pixels of input, interior → bg."""
546
+ exs = get_exs(td)
547
+ sp = fixed_shapes(td)
548
+ if sp is None:
549
+ return None
550
+ (IH, IW), (OH, OW) = sp
551
+ if (IH, IW) != (OH, OW):
552
+ return None
553
+ if IH < 3 or IW < 3:
554
+ return None
555
+
556
+ for bg in range(10):
557
+ ok = True
558
+ for inp, out in exs:
559
+ expected = np.full((IH, IW), bg, dtype=np.int64)
560
+ expected[0, :] = inp[0, :]
561
+ expected[-1, :] = inp[-1, :]
562
+ expected[:, 0] = inp[:, 0]
563
+ expected[:, -1] = inp[:, -1]
564
+ if not np.array_equal(expected, out):
565
+ ok = False
566
+ break
567
+ if ok:
568
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
569
+ cst = np.full((OH, OW), -1, dtype=np.int64)
570
+ for r in range(OH):
571
+ for c in range(OW):
572
+ if r == 0 or r == IH - 1 or c == 0 or c == IW - 1:
573
+ idx[r, c] = [r, c]
574
+ else:
575
+ idx[r, c] = [-1, -1]
576
+ cst[r, c] = bg
577
+ return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
578
+ return None
579
+
580
+
581
+ # =============================================================================
582
+ # SOLVER: interior_fill
583
+ # =============================================================================
584
+
585
+ def s_interior_fill(td):
586
+ """Keep only interior pixels, border → bg. Same size output."""
587
+ exs = get_exs(td)
588
+ sp = fixed_shapes(td)
589
+ if sp is None:
590
+ return None
591
+ (IH, IW), (OH, OW) = sp
592
+ if (IH, IW) != (OH, OW):
593
+ return None
594
+ if IH < 3 or IW < 3:
595
+ return None
596
+
597
+ for b in range(1, min(IH, IW) // 2):
598
+ for bg in range(10):
599
+ ok = True
600
+ for inp, out in exs:
601
+ expected = np.full((IH, IW), bg, dtype=np.int64)
602
+ expected[b:IH-b, b:IW-b] = inp[b:IH-b, b:IW-b]
603
+ if not np.array_equal(expected, out):
604
+ ok = False
605
+ break
606
+ if ok:
607
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
608
+ cst = np.full((OH, OW), -1, dtype=np.int64)
609
+ for r in range(OH):
610
+ for c in range(OW):
611
+ if b <= r < IH - b and b <= c < IW - b:
612
+ idx[r, c] = [r, c]
613
+ else:
614
+ idx[r, c] = [-1, -1]
615
+ cst[r, c] = bg
616
+ return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
617
+ return None
618
+
619
+
620
+ # =============================================================================
621
+ # SOLVER: repeat_row
622
+ # =============================================================================
623
+
624
+ def s_repeat_row(td):
625
+ """Output = one specific row of input repeated vertically to fill grid."""
626
+ exs = get_exs(td)
627
+ sp = fixed_shapes(td)
628
+ if sp is None:
629
+ return None
630
+ (IH, IW), (OH, OW) = sp
631
+ if IW != OW:
632
+ return None
633
+
634
+ for src_row in range(IH):
635
+ ok = True
636
+ for inp, out in exs:
637
+ expected = np.tile(inp[src_row:src_row+1, :], (OH, 1))
638
+ if not np.array_equal(expected, out):
639
+ ok = False
640
+ break
641
+ if ok:
642
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
643
+ for r in range(OH):
644
+ for c in range(OW):
645
+ idx[r, c] = [src_row, c]
646
+ return _build_gather_model(OH, OW, idx)
647
+ return None
648
+
649
+
650
+ # =============================================================================
651
+ # SOLVER: repeat_col
652
+ # =============================================================================
653
+
654
+ def s_repeat_col(td):
655
+ """Output = one specific column of input repeated horizontally to fill grid."""
656
+ exs = get_exs(td)
657
+ sp = fixed_shapes(td)
658
+ if sp is None:
659
+ return None
660
+ (IH, IW), (OH, OW) = sp
661
+ if IH != OH:
662
+ return None
663
+
664
+ for src_col in range(IW):
665
+ ok = True
666
+ for inp, out in exs:
667
+ expected = np.tile(inp[:, src_col:src_col+1], (1, OW))
668
+ if not np.array_equal(expected, out):
669
+ ok = False
670
+ break
671
+ if ok:
672
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
673
+ for r in range(OH):
674
+ for c in range(OW):
675
+ idx[r, c] = [r, src_col]
676
+ return _build_gather_model(OH, OW, idx)
677
+ return None
678
+
679
+
680
+ # =============================================================================
681
+ # SOLVER: swap_two_colors
682
+ # =============================================================================
683
+
684
+ def s_swap_two_colors(td):
685
+ """Swap exactly two colors A and B, all other colors unchanged."""
686
+ exs = get_exs(td)
687
+ sp = fixed_shapes(td)
688
+ if sp is None:
689
+ return None
690
+ (IH, IW), (OH, OW) = sp
691
+ if (IH, IW) != (OH, OW):
692
+ return None
693
+
694
+ swaps = {}
695
+ for inp, out in exs:
696
+ for iv, ov in zip(inp.flat, out.flat):
697
+ iv, ov = int(iv), int(ov)
698
+ if iv != ov:
699
+ if iv in swaps:
700
+ if swaps[iv] != ov:
701
+ return None
702
+ else:
703
+ swaps[iv] = ov
704
+
705
+ if len(swaps) != 2:
706
+ return None
707
+
708
+ items = list(swaps.items())
709
+ if items[0][1] != items[1][0] or items[1][1] != items[0][0]:
710
+ return None
711
+
712
+ a, b = items[0][0], items[0][1]
713
+ for inp, out in exs:
714
+ expected = inp.copy()
715
+ expected[inp == a] = b
716
+ expected[inp == b] = a
717
+ if not np.array_equal(expected, out):
718
+ return None
719
+
720
+ gather_ch = np.arange(10, dtype=np.int32)
721
+ gather_ch[a] = b
722
+ gather_ch[b] = a
723
+ inits = [numpy_helper.from_array(gather_ch, 'gi')]
724
+ nodes = [helper.make_node('Gather', ['input', 'gi'], ['output'], axis=1)]
725
+ return mk(nodes, inits)
726
+
727
+
728
+ # =============================================================================
729
+ # SOLVER: max_pool_downsample
730
+ # =============================================================================
731
+
732
+ def s_max_pool_downsample(td):
733
+ """Non-overlapping max-pool: output = max color value in each block."""
734
+ exs = get_exs(td)
735
+ sp = fixed_shapes(td)
736
+ if sp is None:
737
+ return None
738
+ (IH, IW), (OH, OW) = sp
739
+ if OH >= IH or OW >= IW:
740
+ return None
741
+ if IH % OH != 0 or IW % OW != 0:
742
+ return None
743
+
744
+ bh, bw = IH // OH, IW // OW
745
+ if bh < 2 and bw < 2:
746
+ return None
747
+
748
+ # Check if simple max works
749
+ ok = True
750
+ for inp, out in exs:
751
+ for r in range(OH):
752
+ for c in range(OW):
753
+ block = inp[r*bh:(r+1)*bh, c*bw:(c+1)*bw]
754
+ if out[r, c] != int(np.max(block)):
755
+ ok = False
756
+ break
757
+ if not ok:
758
+ break
759
+ if not ok:
760
+ break
761
+
762
+ if ok:
763
+ pad_h, pad_w = GH - OH, GW - OW
764
+ ch_weights = np.arange(10, dtype=np.float32).reshape(1, 10, 1, 1)
765
+ inits = [
766
+ _make_int64_init('sl_st', [0, 0, 0, 0]),
767
+ _make_int64_init('sl_en', [1, 10, IH, IW]),
768
+ numpy_helper.from_array(ch_weights, 'ch_w'),
769
+ ]
770
+ nodes = [
771
+ helper.make_node('Slice', ['input', 'sl_st', 'sl_en'], ['cropped']),
772
+ helper.make_node('MaxPool', ['cropped'], ['pooled'],
773
+ kernel_shape=[bh, bw], strides=[bh, bw]),
774
+ helper.make_node('Mul', ['pooled', 'ch_w'], ['weighted']),
775
+ helper.make_node('ArgMax', ['weighted'], ['am'], axis=1, keepdims=1),
776
+ ]
777
+ add_onehot_block(nodes, inits, 'am', 'oh_out')
778
+ nodes.append(_build_pad_node('oh_out', 'output', pad_h, pad_w, inits))
779
+ return mk(nodes, inits)
780
+
781
+ return None
782
+
783
+
784
+ # =============================================================================
785
+ # SOLVER: crop_paste (limited search for speed)
786
+ # =============================================================================
787
+
788
+ def s_crop_paste(td):
789
+ """Crop from one fixed position, paste at another. Limited search."""
790
+ exs = get_exs(td)
791
+ sp = fixed_shapes(td)
792
+ if sp is None:
793
+ return None
794
+ (IH, IW), (OH, OW) = sp
795
+ if (IH, IW) != (OH, OW):
796
+ return None
797
+ if IH < 3 or IW < 3:
798
+ return None
799
+
800
+ for bg in range(2):
801
+ out0 = exs[0][1]
802
+ rows = np.any(out0 != bg, axis=1)
803
+ cols = np.any(out0 != bg, axis=0)
804
+ if not np.any(rows) or not np.any(cols):
805
+ continue
806
+ dr = int(np.where(rows)[0][0])
807
+ dc = int(np.where(cols)[0][0])
808
+ dr_end = int(np.where(rows)[0][-1]) + 1
809
+ dc_end = int(np.where(cols)[0][-1]) + 1
810
+ ch = dr_end - dr
811
+ cw = dc_end - dc
812
+
813
+ if ch < 2 or cw < 2 or ch >= IH or cw >= IW:
814
+ continue
815
+
816
+ for sr in range(IH - ch + 1):
817
+ for sc in range(IW - cw + 1):
818
+ if sr == dr and sc == dc:
819
+ continue
820
+ ok = True
821
+ for inp, out in exs:
822
+ expected = np.full((OH, OW), bg, dtype=np.int64)
823
+ expected[dr:dr+ch, dc:dc+cw] = inp[sr:sr+ch, sc:sc+cw]
824
+ if not np.array_equal(expected, out):
825
+ ok = False
826
+ break
827
+ if ok:
828
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
829
+ cst = np.full((OH, OW), -1, dtype=np.int64)
830
+ for r in range(OH):
831
+ for c in range(OW):
832
+ if dr <= r < dr + ch and dc <= c < dc + cw:
833
+ idx[r, c] = [sr + (r - dr), sc + (c - dc)]
834
+ else:
835
+ idx[r, c] = [-1, -1]
836
+ cst[r, c] = bg
837
+ return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
838
+ return None
839
+
840
+
841
+ # =============================================================================
842
+ # Collect all Wave 2 solvers for registration
843
+ # =============================================================================
844
+
845
+ WAVE2_SOLVERS = [
846
+ ('overlay_constant', s_overlay_constant),
847
+ ('bbox_crop', s_bbox_crop),
848
+ ('row_mode_fill', s_row_mode_fill),
849
+ ('col_mode_fill', s_col_mode_fill),
850
+ ('fill_bg_with_mode', s_fill_bg_with_mode),
851
+ ('pad_align', s_pad_align),
852
+ ('multi_stamp', s_multi_stamp),
853
+ ('diagonal_flip', s_diagonal_flip),
854
+ ('invert_colors', s_invert_colors),
855
+ ('majority_color_fill', s_majority_color_fill),
856
+ ('border_extract', s_border_extract),
857
+ ('interior_fill', s_interior_fill),
858
+ ('repeat_row', s_repeat_row),
859
+ ('repeat_col', s_repeat_col),
860
+ ('swap_two_colors', s_swap_two_colors),
861
+ ('max_pool_downsample', s_max_pool_downsample),
862
+ ('crop_paste', s_crop_paste),
863
+ ]