rogermt commited on
Commit
b8e8fc8
·
verified ·
1 Parent(s): 03612fa

Revert to 50 arc-gen, remove color shuffling (200+shuffling caused regression 67→38)

Browse files
own-solver/neurogolf_solver/data_loader.py CHANGED
@@ -56,22 +56,8 @@ def get_exs(td):
56
  for ex in td['train'] + td['test']]
57
 
58
 
59
- def _color_shuffle(inp, out, rng):
60
- """Apply a random color permutation to both input and output.
61
- Color 0 (background) is always preserved."""
62
- perm = np.arange(10, dtype=np.int64)
63
- # Shuffle colors 1-9 (keep 0 fixed)
64
- non_bg = perm[1:].copy()
65
- rng.shuffle(non_bg)
66
- perm[1:] = non_bg
67
- inp_shuffled = perm[inp]
68
- out_shuffled = perm[out]
69
- return inp_shuffled, out_shuffled
70
-
71
-
72
  def get_exs_for_fitting(td):
73
- """Get examples for fitting with ARC-GEN augmentation (fixed shape).
74
- Uses up to 200 arc-gen examples + color-shuffled augmentations."""
75
  base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
76
  for ex in td['train'] + td['test']]
77
  if not base_exs:
@@ -86,20 +72,11 @@ def get_exs_for_fitting(td):
86
  out = np.array(ex['output'], dtype=np.int64)
87
  if inp.shape == base_shape and out.shape == base_exs[0][1].shape:
88
  ag_exs.append((inp, out))
89
- # Use up to 200 arc-gen examples
90
- fitting_exs = base_exs + ag_exs[:200]
91
- # Add color-shuffled augmentations of base examples (helps learn color-invariant rules)
92
- rng = np.random.RandomState(42)
93
- for inp, out in base_exs:
94
- for _ in range(3):
95
- inp_s, out_s = _color_shuffle(inp, out, rng)
96
- fitting_exs.append((inp_s, out_s))
97
- return fitting_exs
98
 
99
 
100
  def get_exs_for_fitting_variable(td):
101
- """Get examples for variable-shape fitting.
102
- Uses up to 200 arc-gen examples + color-shuffled augmentations."""
103
  base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
104
  for ex in td['train'] + td['test']]
105
  ag_exs = []
@@ -108,16 +85,7 @@ def get_exs_for_fitting_variable(td):
108
  out = np.array(ex['output'], dtype=np.int64)
109
  if inp.shape == out.shape and inp.shape[0] <= 30 and inp.shape[1] <= 30:
110
  ag_exs.append((inp, out))
111
- # Use up to 200 arc-gen examples
112
- fitting_exs = base_exs + ag_exs[:200]
113
- # Add color-shuffled augmentations of base examples
114
- rng = np.random.RandomState(42)
115
- for inp, out in base_exs:
116
- if inp.shape == out.shape:
117
- for _ in range(3):
118
- inp_s, out_s = _color_shuffle(inp, out, rng)
119
- fitting_exs.append((inp_s, out_s))
120
- return fitting_exs
121
 
122
 
123
  def fixed_shapes(td):
 
56
  for ex in td['train'] + td['test']]
57
 
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def get_exs_for_fitting(td):
60
+ """Get examples for fitting with ARC-GEN augmentation (fixed shape)."""
 
61
  base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
62
  for ex in td['train'] + td['test']]
63
  if not base_exs:
 
72
  out = np.array(ex['output'], dtype=np.int64)
73
  if inp.shape == base_shape and out.shape == base_exs[0][1].shape:
74
  ag_exs.append((inp, out))
75
+ return base_exs + ag_exs[:50]
 
 
 
 
 
 
 
 
76
 
77
 
78
  def get_exs_for_fitting_variable(td):
79
+ """Get examples for variable-shape fitting."""
 
80
  base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
81
  for ex in td['train'] + td['test']]
82
  ag_exs = []
 
85
  out = np.array(ex['output'], dtype=np.int64)
86
  if inp.shape == out.shape and inp.shape[0] <= 30 and inp.shape[1] <= 30:
87
  ag_exs.append((inp, out))
88
+ return base_exs + ag_exs[:50]
 
 
 
 
 
 
 
 
 
89
 
90
 
91
  def fixed_shapes(td):