rogermt commited on
Commit
36815b6
·
verified ·
1 Parent(s): 30356d0

v5 refactor: add solvers/tiling.py

Browse files
Files changed (1) hide show
  1. neurogolf_solver/solvers/tiling.py +429 -0
neurogolf_solver/solvers/tiling.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Tiling, upscaling, mirror, concat, and spatial gather solvers."""
3
+
4
+ import numpy as np
5
+ from onnx import helper
6
+ from itertools import product as iproduct
7
+ from ..onnx_helpers import mk, _make_int64_init, _build_pad_node
8
+ from ..data_loader import get_exs, fixed_shapes
9
+ from ..gather_helpers import _build_gather_model, _build_gather_model_with_const
10
+
11
+
12
+ def s_tile(td):
13
+ """Tiling solver."""
14
+ exs = get_exs(td)
15
+ in_shapes = set(inp.shape for inp, _ in exs)
16
+ if len(in_shapes) != 1:
17
+ return None
18
+ IH, IW = in_shapes.pop()
19
+ tiles = set()
20
+ for inp, out in exs:
21
+ OH, OW = out.shape
22
+ if OH % IH or OW % IW:
23
+ return None
24
+ rH, rW = OH // IH, OW // IW
25
+ if rH < 1 or rW < 1 or (rH == 1 and rW == 1):
26
+ return None
27
+ tiles.add((rH, rW))
28
+ if len(tiles) != 1:
29
+ return None
30
+ rH, rW = tiles.pop()
31
+ OH, OW = IH * rH, IW * rW
32
+ if OH > 30 or OW > 30:
33
+ return None
34
+ for inp, out in exs:
35
+ if not np.array_equal(out, np.tile(inp, (rH, rW))):
36
+ return None
37
+ pad_h, pad_w = 30 - OH, 30 - OW
38
+ inits = [
39
+ _make_int64_init('st', [0, 0, 0, 0]),
40
+ _make_int64_init('en', [1, 10, IH, IW]),
41
+ _make_int64_init('rp', [1, 1, rH, rW]),
42
+ ]
43
+ nodes = [
44
+ helper.make_node('Slice', ['input', 'st', 'en'], ['cr']),
45
+ helper.make_node('Tile', ['cr', 'rp'], ['tl']),
46
+ ]
47
+ nodes.append(_build_pad_node('tl', 'output', pad_h, pad_w, inits))
48
+ return mk(nodes, inits)
49
+
50
+
51
+ def s_upscale(td):
52
+ """Upscaling solver."""
53
+ exs = get_exs(td)
54
+ in_shapes = set(inp.shape for inp, _ in exs)
55
+ if len(in_shapes) != 1:
56
+ return None
57
+ IH, IW = in_shapes.pop()
58
+ scales = set()
59
+ for inp, out in exs:
60
+ OH, OW = out.shape
61
+ if OH % IH or OW % IW:
62
+ return None
63
+ sH, sW = OH // IH, OW // IW
64
+ if sH < 2 or sW < 2:
65
+ return None
66
+ scales.add((sH, sW))
67
+ if len(scales) != 1:
68
+ return None
69
+ sH, sW = scales.pop()
70
+ OH, OW = IH * sH, IW * sW
71
+ if OH > 30 or OW > 30:
72
+ return None
73
+ for inp, out in exs:
74
+ if not np.array_equal(out, np.repeat(np.repeat(inp, sH, 0), sW, 1)):
75
+ return None
76
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
77
+ for r in range(OH):
78
+ for c in range(OW):
79
+ idx[r, c] = [r // sH, c // sW]
80
+ return _build_gather_model(OH, OW, idx)
81
+
82
+
83
+ def s_kronecker(td):
84
+ """Kronecker product solver."""
85
+ exs = get_exs(td)
86
+ sp = fixed_shapes(td)
87
+ if sp is None:
88
+ return None
89
+ (IH, IW), (OH, OW) = sp
90
+ if OH % IH != 0 or OW % IW != 0:
91
+ return None
92
+ sH, sW = OH // IH, OW // IW
93
+ if sH < 2 or sW < 2:
94
+ return None
95
+ if OH > 30 or OW > 30:
96
+ return None
97
+ for inp, out in exs:
98
+ if not np.array_equal(out, np.kron(inp, np.ones((sH, sW), dtype=np.int64))):
99
+ return None
100
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
101
+ for r in range(OH):
102
+ for c in range(OW):
103
+ idx[r, c] = [r // sH, c // sW]
104
+ return _build_gather_model(OH, OW, idx)
105
+
106
+
107
+ def s_nonuniform_scale(td):
108
+ """Non-uniform scaling solver."""
109
+ exs = get_exs(td)
110
+ sp = fixed_shapes(td)
111
+ if sp is None:
112
+ return None
113
+ (IH, IW), (OH, OW) = sp
114
+ for fh, fw in [(1, 2), (2, 1), (1, 3), (3, 1), (2, 3), (3, 2), (1, 4), (4, 1), (2, 4), (4, 2)]:
115
+ if OH != IH * fh or OW != IW * fw:
116
+ continue
117
+ if OH > 30 or OW > 30:
118
+ continue
119
+ if all(np.array_equal(np.repeat(np.repeat(inp, fh, 0), fw, 1), out) for inp, out in exs):
120
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
121
+ for r in range(OH):
122
+ for c in range(OW):
123
+ idx[r, c] = [r // fh, c // fw]
124
+ return _build_gather_model(OH, OW, idx)
125
+ return None
126
+
127
+
128
+ def s_diagonal_tile(td):
129
+ """Diagonal tiling solver."""
130
+ exs = get_exs(td)
131
+ sp = fixed_shapes(td)
132
+ if sp is None:
133
+ return None
134
+ (IH, IW), (OH, OW) = sp
135
+ if OH % IH != 0 or OW % IW != 0:
136
+ return None
137
+ rH, rW = OH // IH, OW // IW
138
+ if rH != rW or rH < 2:
139
+ return None
140
+ if OH > 30 or OW > 30:
141
+ return None
142
+ for inp, out in exs:
143
+ for bi in range(rH):
144
+ for bj in range(rW):
145
+ block = out[bi * IH:(bi + 1) * IH, bj * IW:(bj + 1) * IW]
146
+ if bi == bj:
147
+ if not np.array_equal(block, inp):
148
+ return None
149
+ else:
150
+ if not np.all(block == 0):
151
+ return None
152
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
153
+ cst = np.full((OH, OW), -1, dtype=np.int64)
154
+ for bi in range(rH):
155
+ for bj in range(rW):
156
+ for lr in range(IH):
157
+ for lc in range(IW):
158
+ oi, oj = bi * IH + lr, bj * IW + lc
159
+ if bi == bj:
160
+ idx[oi, oj] = [lr, lc]
161
+ else:
162
+ idx[oi, oj] = [-1, -1]
163
+ cst[oi, oj] = 0
164
+ return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
165
+
166
+
167
+ def s_mirror_h(td):
168
+ """Horizontal mirror solver."""
169
+ exs = get_exs(td)
170
+ sp = fixed_shapes(td)
171
+ if sp is None:
172
+ return None
173
+ (IH, IW), (OH, OW) = sp
174
+ if OH != IH or OW != 2 * IW:
175
+ return None
176
+ if OW > 30:
177
+ return None
178
+ for inp, out in exs:
179
+ if not np.array_equal(np.concatenate([inp, np.flip(inp, 1)], 1), out):
180
+ return None
181
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
182
+ for r in range(OH):
183
+ for c in range(OW):
184
+ sc = c if c < IW else 2 * IW - 1 - c
185
+ idx[r, c] = [r, sc]
186
+ return _build_gather_model(OH, OW, idx)
187
+
188
+
189
+ def s_mirror_v(td):
190
+ """Vertical mirror solver."""
191
+ exs = get_exs(td)
192
+ sp = fixed_shapes(td)
193
+ if sp is None:
194
+ return None
195
+ (IH, IW), (OH, OW) = sp
196
+ if OW != IW or OH != 2 * IH:
197
+ return None
198
+ if OH > 30:
199
+ return None
200
+ for inp, out in exs:
201
+ if not np.array_equal(np.concatenate([inp, np.flip(inp, 0)], 0), out):
202
+ return None
203
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
204
+ for r in range(OH):
205
+ for c in range(OW):
206
+ sr = r if r < IH else 2 * IH - 1 - r
207
+ idx[r, c] = [sr, c]
208
+ return _build_gather_model(OH, OW, idx)
209
+
210
+
211
+ def s_quad_mirror(td):
212
+ """Quad mirror solver."""
213
+ exs = get_exs(td)
214
+ sp = fixed_shapes(td)
215
+ if sp is None:
216
+ return None
217
+ (IH, IW), (OH, OW) = sp
218
+ if OH != 2 * IH or OW != 2 * IW:
219
+ return None
220
+ if OH > 30 or OW > 30:
221
+ return None
222
+ for inp, out in exs:
223
+ expected = np.block([[inp, np.flip(inp, 1)],
224
+ [np.flip(inp, 0), np.flip(np.flip(inp, 0), 1)]])
225
+ if not np.array_equal(expected, out):
226
+ return None
227
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
228
+ for r in range(OH):
229
+ for c in range(OW):
230
+ sr = r if r < IH else 2 * IH - 1 - r
231
+ sc = c if c < IW else 2 * IW - 1 - c
232
+ idx[r, c] = [sr, sc]
233
+ return _build_gather_model(OH, OW, idx)
234
+
235
+
236
+ def s_concat(td):
237
+ """Concatenation solver with transformations."""
238
+ exs = get_exs(td)
239
+ sp = fixed_shapes(td)
240
+ if sp is None:
241
+ return None
242
+ (IH, IW), (OH, OW) = sp
243
+ transforms = [
244
+ ('id', lambda x: x), ('fliplr', lambda x: np.fliplr(x)),
245
+ ('flipud', lambda x: np.flipud(x)), ('rot180', lambda x: np.rot90(x, 2)),
246
+ ]
247
+ if OH == IH and OW % IW == 0 and OW > IW:
248
+ n = OW // IW
249
+ if 2 <= n <= 4:
250
+ for combo in iproduct(range(4), repeat=n):
251
+ if all(np.array_equal(out, np.concatenate([transforms[t][1](inp) for t in combo], axis=1))
252
+ for inp, out in exs):
253
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
254
+ for oi in range(OH):
255
+ for oj in range(OW):
256
+ bj = oj // IW
257
+ lr, lc = oi, oj % IW
258
+ t = transforms[combo[bj]][0]
259
+ if t == 'id':
260
+ sr, sc = lr, lc
261
+ elif t == 'fliplr':
262
+ sr, sc = lr, IW - 1 - lc
263
+ elif t == 'flipud':
264
+ sr, sc = IH - 1 - lr, lc
265
+ elif t == 'rot180':
266
+ sr, sc = IH - 1 - lr, IW - 1 - lc
267
+ idx[oi, oj] = [sr, sc]
268
+ return _build_gather_model(OH, OW, idx)
269
+ if OW == IW and OH % IH == 0 and OH > IH:
270
+ n = OH // IH
271
+ if 2 <= n <= 4:
272
+ for combo in iproduct(range(4), repeat=n):
273
+ if all(np.array_equal(out, np.concatenate([transforms[t][1](inp) for t in combo], axis=0))
274
+ for inp, out in exs):
275
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
276
+ for oi in range(OH):
277
+ for oj in range(OW):
278
+ bi = oi // IH
279
+ lr, lc = oi % IH, oj
280
+ t = transforms[combo[bi]][0]
281
+ if t == 'id':
282
+ sr, sc = lr, lc
283
+ elif t == 'fliplr':
284
+ sr, sc = lr, IW - 1 - lc
285
+ elif t == 'flipud':
286
+ sr, sc = IH - 1 - lr, lc
287
+ elif t == 'rot180':
288
+ sr, sc = IH - 1 - lr, IW - 1 - lc
289
+ idx[oi, oj] = [sr, sc]
290
+ return _build_gather_model(OH, OW, idx)
291
+ return None
292
+
293
+
294
+ def s_concat_enhanced(td):
295
+ """Enhanced concatenation with all 8 dihedral transforms."""
296
+ exs = get_exs(td)
297
+ sp = fixed_shapes(td)
298
+ if sp is None:
299
+ return None
300
+ (IH, IW), (OH, OW) = sp
301
+ if IH == OH and IW == OW:
302
+ return None
303
+ if OH % IH != 0 or OW % IW != 0:
304
+ return None
305
+ rH, rW = OH // IH, OW // IW
306
+ if rH * rW > 16 or rH * rW < 2:
307
+ return None
308
+ if OH > 30 or OW > 30:
309
+ return None
310
+ transforms = [
311
+ ('id', lambda x: x), ('fliplr', lambda x: np.fliplr(x)),
312
+ ('flipud', lambda x: np.flipud(x)), ('rot180', lambda x: np.rot90(x, 2)),
313
+ ('rot90', lambda x: np.rot90(x, 1)), ('rot270', lambda x: np.rot90(x, 3)),
314
+ ('T', lambda x: x.T), ('T_fliplr', lambda x: np.fliplr(x.T)),
315
+ ]
316
+ block_transforms = {}
317
+ for bi in range(rH):
318
+ for bj in range(rW):
319
+ found = None
320
+ for tidx, (tname, tfn) in enumerate(transforms):
321
+ ok = True
322
+ for inp, out in exs:
323
+ block = out[bi * IH:(bi + 1) * IH, bj * IW:(bj + 1) * IW]
324
+ expected = tfn(inp)
325
+ if expected.shape != (IH, IW) or not np.array_equal(block, expected):
326
+ ok = False
327
+ break
328
+ if ok:
329
+ found = (tidx, tname)
330
+ break
331
+ if found is None:
332
+ return None
333
+ block_transforms[(bi, bj)] = found
334
+ idx = np.zeros((OH, OW, 2), dtype=np.int64)
335
+ for bi in range(rH):
336
+ for bj in range(rW):
337
+ _, tname = block_transforms[(bi, bj)]
338
+ for lr in range(IH):
339
+ for lc in range(IW):
340
+ oi, oj = bi * IH + lr, bj * IW + lc
341
+ if tname == 'id':
342
+ sr, sc = lr, lc
343
+ elif tname == 'fliplr':
344
+ sr, sc = lr, IW - 1 - lc
345
+ elif tname == 'flipud':
346
+ sr, sc = IH - 1 - lr, lc
347
+ elif tname == 'rot180':
348
+ sr, sc = IH - 1 - lr, IW - 1 - lc
349
+ elif tname == 'rot90':
350
+ sr, sc = IW - 1 - lc, lr
351
+ elif tname == 'rot270':
352
+ sr, sc = lc, IH - 1 - lr
353
+ elif tname == 'T':
354
+ sr, sc = lc, lr
355
+ elif tname == 'T_fliplr':
356
+ sr, sc = IW - 1 - lc, lr
357
+ idx[oi, oj] = [sr, sc]
358
+ for inp, out in exs:
359
+ reconstructed = np.zeros_like(out)
360
+ for oi in range(OH):
361
+ for oj in range(OW):
362
+ reconstructed[oi, oj] = inp[idx[oi, oj, 0], idx[oi, oj, 1]]
363
+ if not np.array_equal(reconstructed, out):
364
+ return None
365
+ return _build_gather_model(OH, OW, idx)
366
+
367
+
368
+ def s_spatial_gather(td):
369
+ """Spatial gather solver."""
370
+ sp = fixed_shapes(td)
371
+ if sp is None:
372
+ return None
373
+ (IH, IW), (OH, OW) = sp
374
+ exs = get_exs(td)
375
+ idx = np.full((OH, OW, 2), -1, dtype=np.int64)
376
+ cst = np.full((OH, OW), -1, dtype=np.int64)
377
+ for oi in range(OH):
378
+ for oj in range(OW):
379
+ vals = set(int(out[oi, oj]) for _, out in exs)
380
+ if len(vals) == 1:
381
+ cst[oi, oj] = vals.pop()
382
+ found = False
383
+ for ri in range(IH):
384
+ for rj in range(IW):
385
+ if all(int(inp[ri, rj]) == int(out[oi, oj]) for inp, out in exs):
386
+ idx[oi, oj] = [ri, rj]
387
+ found = True
388
+ break
389
+ if found:
390
+ break
391
+ if not found and cst[oi, oj] < 0:
392
+ return None
393
+ return _build_gather_model_with_const(IH, IW, OH, OW, idx, cst)
394
+
395
+
396
+ def s_varshape_spatial_gather(td):
397
+ """Variable shape spatial gather solver."""
398
+ sp = fixed_shapes(td)
399
+ if sp is not None:
400
+ return None
401
+ exs = get_exs(td)
402
+ exs_30 = []
403
+ for inp, out in exs:
404
+ ih, iw = inp.shape
405
+ oh, ow = out.shape
406
+ inp30 = np.zeros((30, 30), dtype=np.int64)
407
+ out30 = np.zeros((30, 30), dtype=np.int64)
408
+ inp30[:ih, :iw] = inp
409
+ out30[:oh, :ow] = out
410
+ exs_30.append((inp30, out30))
411
+ idx = np.full((30, 30, 2), -1, dtype=np.int64)
412
+ cst = np.full((30, 30), -1, dtype=np.int64)
413
+ for oi in range(30):
414
+ for oj in range(30):
415
+ vals = set(int(out30[oi, oj]) for _, out30 in exs_30)
416
+ if len(vals) == 1:
417
+ cst[oi, oj] = vals.pop()
418
+ found = False
419
+ for ri in range(30):
420
+ for rj in range(30):
421
+ if all(int(inp30[ri, rj]) == int(out30[oi, oj]) for inp30, out30 in exs_30):
422
+ idx[oi, oj] = [ri, rj]
423
+ found = True
424
+ break
425
+ if found:
426
+ break
427
+ if not found and cst[oi, oj] < 0:
428
+ return None
429
+ return _build_gather_model_with_const(30, 30, 30, 30, idx, cst)