# ARC-AGI / itt_solver/experiment_driver.py
# Last committed by: rogermt
# Commit e0a217d (verified): "Wire 14 new transforms (object + fill + connect + compress) into default_atomic_factory"
import csv
import os
import time
import json
import importlib
import numpy as np
from itertools import product
from datetime import datetime
from .solver_core import initialize_potential
from .beam_logging import beam_minimize_with_log
def reload_modules():
    """Reload the itt_solver submodules so edited source takes effect without a restart.

    Reloads ``solver_core``, ``beam_logging``, ``transforms``, ``gates`` and
    ``layer_minus_one`` (in that order). It then re-binds the two names this
    module imported with ``from ... import`` at load time:
    ``initialize_potential`` and ``beam_minimize_with_log``.

    ``importlib.reload`` re-executes a module inside its existing module
    object. Names previously copied out of it by ``from`` imports still
    point at the old function objects. Without the re-binding below,
    ``run_single`` would keep calling pre-reload code.
    """
    global initialize_potential, beam_minimize_with_log
    import itt_solver.solver_core as sc
    import itt_solver.beam_logging as bl
    import itt_solver.transforms as tr
    import itt_solver.gates as gates
    import itt_solver.layer_minus_one as l1

    for module in (sc, bl, tr, gates, l1):
        importlib.reload(module)

    # Refresh this module's `from`-imported names (see docstring).
    # NOTE(review): assumes this file lives in the `itt_solver` package, so the
    # relative `.solver_core` / `.beam_logging` imports resolve to these modules.
    initialize_potential = sc.initialize_potential
    beam_minimize_with_log = bl.beam_minimize_with_log
def param_grid(grid_dict):
    """Yield one ``{name: value}`` dict for every point in the Cartesian product of ``grid_dict``.

    ``grid_dict`` maps each parameter name to an iterable of candidate
    values. Points are produced in ``itertools.product`` order over the keys
    in iteration order, so the last key varies fastest. An empty grid yields
    exactly one empty dict.
    """
    names = list(grid_dict)
    axes = [grid_dict[name] for name in names]
    for point in product(*axes):
        yield {name: value for name, value in zip(names, point)}
def run_single(task, atomic_library, params, out_dir):
    """Run one beam-search minimisation for ``task`` and save its artefacts.

    Parameters
    ----------
    task : dict
        Must provide ``'input'`` and ``'target'``; both are passed to
        ``initialize_potential``. ``'name'`` is optional and defaults to
        ``'task'``.
    atomic_library : list
        Transforms passed to ``beam_minimize_with_log``.
    params : dict
        Search settings. Any missing key falls back to the default used below:
        ``beam_width=4``, ``max_depth=3``, ``lock_coeff=0.01``,
        ``max_fraction=0.5``, ``allowed_symbols=list(range(10))``,
        ``enable_layer_minus_one=False`` and ``boundary_source='target'``.
    out_dir : str
        Directory for the output files. It is created if missing.

    Returns
    -------
    dict
        Summary with these keys:

        - ``task_name``
        - ``params``
        - ``final_sigma`` (``None`` when no sigma was recorded)
        - ``sigma_trace``
        - ``time_s``
        - ``transform`` (repr of the best transform)
        - ``states_count``

        The same dict is written to ``<base>_result.json``, alongside
        ``<base>_phi_best.npy`` and ``<base>_logs.json``. ``<base>`` is
        ``<task_name>_<UTC timestamp including microseconds>``.
    """
    from datetime import timezone

    os.makedirs(out_dir, exist_ok=True)
    task_name = task.get('name', 'task')
    phi_in = initialize_potential(task['input'])
    phi_target = initialize_potential(task['target'])

    # perf_counter is monotonic, so wall-clock adjustments cannot skew the timing.
    start = time.perf_counter()
    T_best, phi_best, states, sigmas, logs = beam_minimize_with_log(
        phi_in, phi_target, atomic_library,
        beam_width=params.get('beam_width', 4),
        max_depth=params.get('max_depth', 3),
        lock_coeff=params.get('lock_coeff', 0.01),
        max_fraction=params.get('max_fraction', 0.5),
        allowed_symbols=params.get('allowed_symbols', list(range(10))),
        enable_layer_minus_one=params.get('enable_layer_minus_one', False),
        boundary_source=params.get('boundary_source', 'target'),
    )
    elapsed = time.perf_counter() - start

    # Convert once. Testing the resulting list works whether `sigmas` is a list
    # or a NumPy array; bare truthiness of a multi-element array raises ValueError.
    sigma_trace = [float(s) for s in sigmas]
    result = {
        'task_name': task_name,
        'params': params,
        'final_sigma': sigma_trace[-1] if sigma_trace else None,
        'sigma_trace': sigma_trace,
        'time_s': elapsed,
        'transform': repr(T_best),
        'states_count': len(states),
    }

    # Microseconds keep filenames unique when the same task is run more than
    # once within a second (e.g. consecutive grid points in sweep()).
    ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S%fZ")
    stem = os.path.join(out_dir, f"{task_name}_{ts}")
    np.save(stem + "_phi_best.npy", phi_best)
    with open(stem + "_result.json", "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)
    with open(stem + "_logs.json", "w", encoding="utf-8") as f:
        # `logs` may hold non-JSON types; fall back to their str() form.
        json.dump(logs, f, default=str)
    return result
def sweep(tasks, atomic_library_factory, grid, out_dir="experiments", max_runs=None):
    """Run every task under every parameter combination in ``grid``.

    Solver modules are reloaded first (via ``reload_modules``) so that source
    edits are picked up. Each run goes through ``run_single``. One summary row
    per run is appended to ``<out_dir>/results.csv``, and the file is flushed
    after every row so an interrupted sweep still leaves usable results.

    Parameters
    ----------
    tasks : iterable of dict
        Each task is passed unchanged to ``atomic_library_factory`` and
        ``run_single``. It is iterated once per grid point, so it must be
        re-iterable (e.g. a list).
    atomic_library_factory : callable
        Called as ``factory(params, task)`` and must return the atomic
        library for that run.
    grid : dict
        Maps each parameter name to the values to try (see ``param_grid``).
    out_dir : str
        Output directory, created if missing.
    max_runs : int or None
        Maximum number of runs. ``None`` means no limit; ``0`` runs nothing.

    Returns
    -------
    None
    """
    os.makedirs(out_dir, exist_ok=True)
    reload_modules()
    csv_path = os.path.join(out_dir, "results.csv")
    # A new file or an existing-but-empty file both need a header row.
    needs_header = not os.path.isfile(csv_path) or os.path.getsize(csv_path) == 0
    fieldnames = ["task_name", "params", "final_sigma", "time_s", "transform", "sigma_trace"]
    runs = 0
    with open(csv_path, "a", newline="", encoding="utf-8") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()
        for params in param_grid(grid):
            for task in tasks:
                # `is not None` so that max_runs=0 means "no runs", not "unlimited".
                if max_runs is not None and runs >= max_runs:
                    return
                atomic_library = atomic_library_factory(params, task)
                res = run_single(task, atomic_library, params, out_dir)
                writer.writerow({
                    "task_name": res['task_name'],
                    # Structured fields are JSON-encoded so each fits in one CSV cell.
                    "params": json.dumps(res['params']),
                    "final_sigma": res['final_sigma'],
                    "time_s": res['time_s'],
                    "transform": res['transform'],
                    "sigma_trace": json.dumps(res['sigma_trace']),
                })
                csvfile.flush()
                runs += 1
    return
def default_atomic_factory(params, task):
    """Build the default atomic library of transforms for ``task``.

    28 transforms are always included:

    - tiling (to target shape, shifted)
    - harmonic enclosed fill
    - Kronecker self-similar and its inverse
    - mirror tiling (H / V / 4-way)
    - upscale x2 / x3
    - stack H / V
    - transpose and crop-to-content
    - object extraction / keep / sort
    - fill, connect, compress, black-line removal, proximity colouring
      and border drawing

    Optional groups are switched on by ``params``:

    - ``use_symmetry`` (default True) adds 5 ops: three rotations and two
      reflections.
    - ``use_gravity`` (default False) adds 2 gravity ops.
    - ``use_color_ops`` (default False) adds colour inversion.

    With default params the library holds 33 transforms; with every group
    enabled it holds 36. The order of the returned list matches the order
    in which transforms are added below.

    The target shape is read from ``task['target_shape']``. When that key
    is absent or None, it is taken from the shape of ``task['target']``.
    """
    import itt_solver.transforms as tr
    from itt_solver.solver_core import tile_transform

    target_shape = task.get('target_shape')
    if target_shape is None:
        # run_single already requires task['target'], so derive the shape from it.
        target_shape = np.shape(task['target'])
    target_h, target_w = target_shape[0], target_shape[1]

    libs = [
        # --- core tiling ---
        # Default args bind the shape now, avoiding late-binding surprises.
        tr.Transform(
            lambda p, _h=target_h, _w=target_w: tile_transform(p, (_h, _w)),
            "tile_to_target"),
        tr.tile_to_target_shifted(shift=(1, 1), tile_factor=3),
        # --- harmonic fill ---
        tr.FillEnclosedHarmonic(),
        # --- Kronecker / self-similar ---
        tr.KroneckerSelfSimilar(),
        tr.KroneckerSelfSimilarInv(),
        # --- mirror / kaleidoscope ---
        tr.MirrorTileH(),
        tr.MirrorTileV(),
        tr.MirrorTile4Way(),
        # --- upscale ---
        tr.Upscale(2),
        tr.Upscale(3),
        # --- stacking ---
        tr.StackH(3),
        tr.StackV(3),
        # --- structural ---
        tr.Transpose(),
        tr.CropToContent(),
        # --- object extraction ---
        tr.ExtractLargestObject(),
        tr.ExtractSmallestObject(),
        tr.ExtractUniqueObject(),
        tr.ExtractMostCommonObject(),
        tr.KeepLargestObject(),
        tr.KeepSmallestObject(),
        tr.SortObjectsBySize(),
        # --- fill / connect / compress ---
        tr.FillInterior(),
        tr.ConnectSameColorH(),
        tr.ConnectSameColorV(),
        tr.CompressGrid(),
        tr.RemoveBlackLines(),
        tr.ColorByProximity(),
        tr.DrawBorder(),
    ]
    # --- symmetry (on by default) ---
    if params.get('use_symmetry', True):
        libs.extend([
            tr.Rotate(1),
            tr.Rotate(2),
            tr.Rotate(3),
            tr.Reflect('h'),
            tr.Reflect('v'),
        ])
    # --- gravity (off by default) ---
    if params.get('use_gravity', False):
        libs.extend([tr.GravityDown(), tr.GravityUp()])
    # --- color ops (off by default) ---
    if params.get('use_color_ops', False):
        libs.append(tr.InvertColors())
    return libs