rogermt commited on
Commit
f9a1abc
·
verified ·
1 Parent(s): f1edb04

Increase arc-gen fitting data: 10→50 (fixed), 20→50 (variable) for better conv generalization

Browse files
own-solver/neurogolf_solver/data_loader.py CHANGED
@@ -57,7 +57,7 @@ def get_exs(td):
57
 
58
 
59
  def get_exs_for_fitting(td):
60
- """Get examples for fitting with ARC-GEN augmentation."""
61
  base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
62
  for ex in td['train'] + td['test']]
63
  if not base_exs:
@@ -72,7 +72,7 @@ def get_exs_for_fitting(td):
72
  out = np.array(ex['output'], dtype=np.int64)
73
  if inp.shape == base_shape and out.shape == base_exs[0][1].shape:
74
  ag_exs.append((inp, out))
75
- return base_exs + ag_exs[:10]
76
 
77
 
78
  def get_exs_for_fitting_variable(td):
@@ -85,7 +85,7 @@ def get_exs_for_fitting_variable(td):
85
  out = np.array(ex['output'], dtype=np.int64)
86
  if inp.shape == out.shape and inp.shape[0] <= 30 and inp.shape[1] <= 30:
87
  ag_exs.append((inp, out))
88
- return base_exs + ag_exs[:20]
89
 
90
 
91
  def fixed_shapes(td):
 
57
 
58
 
59
  def get_exs_for_fitting(td):
60
+ """Get examples for fitting with ARC-GEN augmentation (fixed shape)."""
61
  base_exs = [(np.array(ex['input'], dtype=np.int64), np.array(ex['output'], dtype=np.int64))
62
  for ex in td['train'] + td['test']]
63
  if not base_exs:
 
72
  out = np.array(ex['output'], dtype=np.int64)
73
  if inp.shape == base_shape and out.shape == base_exs[0][1].shape:
74
  ag_exs.append((inp, out))
75
+ return base_exs + ag_exs[:50]
76
 
77
 
78
  def get_exs_for_fitting_variable(td):
 
85
  out = np.array(ex['output'], dtype=np.int64)
86
  if inp.shape == out.shape and inp.shape[0] <= 30 and inp.shape[1] <= 30:
87
  ag_exs.append((inp, out))
88
+ return base_exs + ag_exs[:50]
89
 
90
 
91
  def fixed_shapes(td):