rogermt commited on
Commit
e0a217d
·
verified ·
1 Parent(s): e0d2eb1

Wire 14 new transforms (object + fill + connect + compress) into default_atomic_factory

Browse files
Files changed (1) hide show
  1. itt_solver/experiment_driver.py +25 -16
itt_solver/experiment_driver.py CHANGED
@@ -24,13 +24,7 @@ def param_grid(grid_dict):
24
  yield dict(zip(keys, combo))
25
 
26
  def run_single(task, atomic_library, params, out_dir):
27
- """
28
- Run one experiment and save artifacts to out_dir.
29
- This function does NOT call W&B itself to avoid recursion; call the external
30
- uploader (itt_solver.wandb_runner.run_and_log_wandb) after this returns if desired.
31
- """
32
  os.makedirs(out_dir, exist_ok=True)
33
-
34
  phi_in = initialize_potential(task['input'])
35
  phi_target = initialize_potential(task['target'])
36
  start = time.time()
@@ -54,7 +48,6 @@ def run_single(task, atomic_library, params, out_dir):
54
  'transform': repr(T_best),
55
  'states_count': len(states),
56
  }
57
- # save best field and logs
58
  ts = datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
59
  base = f"{task.get('name','task')}_{ts}"
60
  np.save(os.path.join(out_dir, base + "_phi_best.npy"), phi_best)
@@ -62,7 +55,6 @@ def run_single(task, atomic_library, params, out_dir):
62
  json.dump(result, f, indent=2)
63
  with open(os.path.join(out_dir, base + "_logs.json"), "w") as f:
64
  json.dump(logs, f, default=str)
65
-
66
  return result
67
 
68
  def sweep(tasks, atomic_library_factory, grid, out_dir="experiments", max_runs=None):
@@ -97,13 +89,12 @@ def sweep(tasks, atomic_library_factory, grid, out_dir="experiments", max_runs=N
97
  def default_atomic_factory(params, task):
98
  """Build the default atomic library for a task.
99
 
100
- Includes the original transforms plus the new Kronecker / mirror / upscale
101
- family so the beam can express a much wider range of ARC patterns.
102
  """
103
  import itt_solver.transforms as tr
104
  from itt_solver.solver_core import tile_transform
105
 
106
- # Capture target_shape by value to avoid late-binding closure bug
107
  target_h, target_w = task['target_shape'][0], task['target_shape'][1]
108
 
109
  libs = []
@@ -115,7 +106,7 @@ def default_atomic_factory(params, task):
115
  libs.append(tr.tile_to_target_shifted(shift=(1, 1), tile_factor=3))
116
  libs.append(tr.FillEnclosedHarmonic())
117
 
118
- # --- Kronecker / self-similar (covers 007bbfb7 family) ---
119
  libs.append(tr.KroneckerSelfSimilar())
120
  libs.append(tr.KroneckerSelfSimilarInv())
121
 
@@ -124,7 +115,7 @@ def default_atomic_factory(params, task):
124
  libs.append(tr.MirrorTileV())
125
  libs.append(tr.MirrorTile4Way())
126
 
127
- # --- upscale (2x and 3x) ---
128
  libs.append(tr.Upscale(2))
129
  libs.append(tr.Upscale(3))
130
 
@@ -136,7 +127,25 @@ def default_atomic_factory(params, task):
136
  libs.append(tr.Transpose())
137
  libs.append(tr.CropToContent())
138
 
139
- # --- symmetry (optional) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  if params.get('use_symmetry', True):
141
  libs.append(tr.Rotate(1))
142
  libs.append(tr.Rotate(2))
@@ -144,12 +153,12 @@ def default_atomic_factory(params, task):
144
  libs.append(tr.Reflect('h'))
145
  libs.append(tr.Reflect('v'))
146
 
147
- # --- gravity (optional) ---
148
  if params.get('use_gravity', False):
149
  libs.append(tr.GravityDown())
150
  libs.append(tr.GravityUp())
151
 
152
- # --- color ops (optional, for multi-color tasks) ---
153
  if params.get('use_color_ops', False):
154
  libs.append(tr.InvertColors())
155
 
 
24
  yield dict(zip(keys, combo))
25
 
26
  def run_single(task, atomic_library, params, out_dir):
 
 
 
 
 
27
  os.makedirs(out_dir, exist_ok=True)
 
28
  phi_in = initialize_potential(task['input'])
29
  phi_target = initialize_potential(task['target'])
30
  start = time.time()
 
48
  'transform': repr(T_best),
49
  'states_count': len(states),
50
  }
 
51
  ts = datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
52
  base = f"{task.get('name','task')}_{ts}"
53
  np.save(os.path.join(out_dir, base + "_phi_best.npy"), phi_best)
 
55
  json.dump(result, f, indent=2)
56
  with open(os.path.join(out_dir, base + "_logs.json"), "w") as f:
57
  json.dump(logs, f, default=str)
 
58
  return result
59
 
60
  def sweep(tasks, atomic_library_factory, grid, out_dir="experiments", max_runs=None):
 
89
  def default_atomic_factory(params, task):
90
  """Build the default atomic library for a task.
91
 
92
+ 33 transforms: tiling, Kronecker, mirror, upscale, stack, object extraction,
93
+ fill, connect, compress, proximity, border, symmetry, gravity, color ops.
94
  """
95
  import itt_solver.transforms as tr
96
  from itt_solver.solver_core import tile_transform
97
 
 
98
  target_h, target_w = task['target_shape'][0], task['target_shape'][1]
99
 
100
  libs = []
 
106
  libs.append(tr.tile_to_target_shifted(shift=(1, 1), tile_factor=3))
107
  libs.append(tr.FillEnclosedHarmonic())
108
 
109
+ # --- Kronecker / self-similar ---
110
  libs.append(tr.KroneckerSelfSimilar())
111
  libs.append(tr.KroneckerSelfSimilarInv())
112
 
 
115
  libs.append(tr.MirrorTileV())
116
  libs.append(tr.MirrorTile4Way())
117
 
118
+ # --- upscale ---
119
  libs.append(tr.Upscale(2))
120
  libs.append(tr.Upscale(3))
121
 
 
127
  libs.append(tr.Transpose())
128
  libs.append(tr.CropToContent())
129
 
130
+ # --- object extraction ---
131
+ libs.append(tr.ExtractLargestObject())
132
+ libs.append(tr.ExtractSmallestObject())
133
+ libs.append(tr.ExtractUniqueObject())
134
+ libs.append(tr.ExtractMostCommonObject())
135
+ libs.append(tr.KeepLargestObject())
136
+ libs.append(tr.KeepSmallestObject())
137
+ libs.append(tr.SortObjectsBySize())
138
+
139
+ # --- fill / connect / compress ---
140
+ libs.append(tr.FillInterior())
141
+ libs.append(tr.ConnectSameColorH())
142
+ libs.append(tr.ConnectSameColorV())
143
+ libs.append(tr.CompressGrid())
144
+ libs.append(tr.RemoveBlackLines())
145
+ libs.append(tr.ColorByProximity())
146
+ libs.append(tr.DrawBorder())
147
+
148
+ # --- symmetry ---
149
  if params.get('use_symmetry', True):
150
  libs.append(tr.Rotate(1))
151
  libs.append(tr.Rotate(2))
 
153
  libs.append(tr.Reflect('h'))
154
  libs.append(tr.Reflect('v'))
155
 
156
+ # --- gravity ---
157
  if params.get('use_gravity', False):
158
  libs.append(tr.GravityDown())
159
  libs.append(tr.GravityUp())
160
 
161
+ # --- color ops ---
162
  if params.get('use_color_ops', False):
163
  libs.append(tr.InvertColors())
164