v4.1: color_map via channel Gather for permutation maps (+15 pts), fix non-permutation fallback to Conv 1x1
Browse files- neurogolf_solver.py +20 -5
neurogolf_solver.py
CHANGED
|
@@ -355,11 +355,26 @@ def s_color_map(td):
|
|
| 355 |
iv, ov = int(iv), int(ov)
|
| 356 |
if iv in cm and cm[iv] != ov: return None
|
| 357 |
cm[iv] = ov
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
|
| 364 |
def s_transpose(td):
|
| 365 |
for ex in td['train']+td['test']:
|
|
|
|
| 355 |
iv, ov = int(iv), int(ov)
|
| 356 |
if iv in cm and cm[iv] != ov: return None
|
| 357 |
cm[iv] = ov
|
| 358 |
+
|
| 359 |
+
# Check if it's a permutation (bijective + all mapped colors form a closed set)
|
| 360 |
+
is_permutation = (set(cm.keys()) == set(cm.values()))
|
| 361 |
+
|
| 362 |
+
if is_permutation:
|
| 363 |
+
# Use channel Gather — zero MACs, much cheaper
|
| 364 |
+
gather_ch = np.arange(10, dtype=np.int32)
|
| 365 |
+
for src, dst in cm.items():
|
| 366 |
+
if 0 <= src < 10 and 0 <= dst < 10:
|
| 367 |
+
gather_ch[dst] = src
|
| 368 |
+
inits = [numpy_helper.from_array(gather_ch, 'gi')]
|
| 369 |
+
nodes = [helper.make_node('Gather', ['input', 'gi'], ['output'], axis=1)]
|
| 370 |
+
return mk(nodes, inits)
|
| 371 |
+
else:
|
| 372 |
+
# Non-permutation: use Conv 1x1 (has MACs but handles any mapping)
|
| 373 |
+
W = np.zeros((10,10,1,1), dtype=np.float32)
|
| 374 |
+
for ic in range(10):
|
| 375 |
+
W[cm.get(ic,ic), ic, 0, 0] = 1.0
|
| 376 |
+
return mk([helper.make_node('Conv', ['input','W'], ['output'], kernel_shape=[1,1])],
|
| 377 |
+
[numpy_helper.from_array(W, 'W')])
|
| 378 |
|
| 379 |
def s_transpose(td):
|
| 380 |
for ex in td['train']+td['test']:
|