rogermt commited on
Commit
ae0427a
·
verified ·
1 Parent(s): 0422a6a

v4.1: color_map via channel Gather for permutation maps (+15 pts), fix non-permutation fallback to Conv 1x1

Browse files
Files changed (1) hide show
  1. neurogolf_solver.py +20 -5
neurogolf_solver.py CHANGED
@@ -355,11 +355,26 @@ def s_color_map(td):
355
  iv, ov = int(iv), int(ov)
356
  if iv in cm and cm[iv] != ov: return None
357
  cm[iv] = ov
358
- W = np.zeros((10,10,1,1), dtype=np.float32)
359
- for ic in range(10):
360
- W[cm.get(ic,ic), ic, 0, 0] = 1.0
361
- return mk([helper.make_node('Conv', ['input','W'], ['output'], kernel_shape=[1,1])],
362
- [numpy_helper.from_array(W, 'W')])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
 
364
  def s_transpose(td):
365
  for ex in td['train']+td['test']:
 
355
  iv, ov = int(iv), int(ov)
356
  if iv in cm and cm[iv] != ov: return None
357
  cm[iv] = ov
358
+
359
+ # Check if it's a permutation (bijective + all mapped colors form a closed set)
360
+ is_permutation = (set(cm.keys()) == set(cm.values()))
361
+
362
+ if is_permutation:
363
+ # Use channel Gather — zero MACs, much cheaper
364
+ gather_ch = np.arange(10, dtype=np.int32)
365
+ for src, dst in cm.items():
366
+ if 0 <= src < 10 and 0 <= dst < 10:
367
+ gather_ch[dst] = src
368
+ inits = [numpy_helper.from_array(gather_ch, 'gi')]
369
+ nodes = [helper.make_node('Gather', ['input', 'gi'], ['output'], axis=1)]
370
+ return mk(nodes, inits)
371
+ else:
372
+ # Non-permutation: use Conv 1x1 (has MACs but handles any mapping)
373
+ W = np.zeros((10,10,1,1), dtype=np.float32)
374
+ for ic in range(10):
375
+ W[cm.get(ic,ic), ic, 0, 0] = 1.0
376
+ return mk([helper.make_node('Conv', ['input','W'], ['output'], kernel_shape=[1,1])],
377
+ [numpy_helper.from_array(W, 'W')])
378
 
379
  def s_transpose(td):
380
  for ex in td['train']+td['test']: