# ARC-AGI / itt_solver / wandb_runner.py
# rogermt's picture
# Fix #4: Fix wandb generate_id crash — use string directly instead of int(id, 36)
# 056be36 verified
import json
import logging
import os
from typing import Any, Dict, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import wandb

from itt_solver import experiment_driver as ed
def _save_phi_image(phi: np.ndarray, path: str) -> None:
    """Render ``phi`` with the 'tab20' colormap and save the figure to ``path``.

    Args:
        phi: Array handed directly to ``Axes.imshow``.
        path: Destination file. Matplotlib infers the output format from the
            extension; the caller in this module passes a ``.png`` path.

    Raises:
        Whatever matplotlib raises while drawing or writing the file. The
        figure is closed before the exception propagates.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.imshow(phi, cmap='tab20')
        ax.axis('off')
        fig.savefig(path, bbox_inches='tight', dpi=150)
    finally:
        # Always release the figure: pyplot keeps every open figure alive in
        # its global registry, so skipping this on error would leak it.
        plt.close(fig)
_LOG = logging.getLogger(__name__)

# File-name suffixes that run_single is expected to append to its per-run base
# name. The text after the leading underscore is reused as the name of the
# file inside the W&B artifact.
_ARTIFACT_SUFFIXES = ("_phi_best.npy", "_result.json", "_logs.json")


def _latest_artifact_base(out_dir: str, prefix: str) -> Optional[str]:
    """Return the lexicographically last run base name in ``out_dir``, or None.

    A base name is a file name that starts with ``prefix`` and ends with one
    of ``_ARTIFACT_SUFFIXES``, with that suffix removed. Only the trailing
    suffix is stripped, so task names that themselves contain ``_logs``,
    ``_result`` or ``_phi_best`` are left intact.
    """
    try:
        names = os.listdir(out_dir)
    except OSError:
        # Missing or unreadable directory: treat as "nothing to attach".
        return None

    bases = set()
    for fname in names:
        # NOTE(review): a task whose name is a prefix of another task's name
        # also matches that task's files; confirm run_single's naming scheme
        # before tightening this filter.
        if not fname.startswith(prefix):
            continue
        for suffix in _ARTIFACT_SUFFIXES:
            if fname.endswith(suffix):
                bases.add(fname[:-len(suffix)])
                break
    return max(bases) if bases else None


def run_and_log_wandb(task: Dict[str, Any],
                      atomic_library,
                      params: Dict[str, Any],
                      out_dir: str = "experiments",
                      wandb_project: str = "itt_solver",
                      wandb_entity: Optional[str] = None,
                      resume: Union[bool, str] = "allow"):
    """Run one experiment via ``experiment_driver.run_single`` and log it to W&B.

    The experiment runs first; afterwards a W&B run is opened with
    ``with wandb.init(...) as run`` and the scalar metrics, the sigma trace,
    the files found in ``out_dir`` and a preview image of ``phi_best`` are
    logged. Each logging step is best-effort: a failure is reported through
    the module logger and the remaining steps still run.

    Args:
        task: Task dict passed to ``run_single``. Its ``'name'`` entry
            (default ``'task'``) is used to locate output files and to name
            the W&B run and artifact.
        atomic_library: Passed through to ``run_single`` unchanged.
        params: Hyperparameter dict passed to ``run_single`` and recorded as
            the W&B run config.
        out_dir: Directory handed to ``run_single`` and then scanned for
            ``<base>_phi_best.npy``, ``<base>_result.json`` and
            ``<base>_logs.json``.
        wandb_project: W&B project name.
        wandb_entity: Optional W&B entity (team/user).
        resume: Forwarded unchanged to ``wandb.init(resume=...)``. It selects
            W&B's run-resumption behaviour and is unrelated to anonymous
            login; see the ``wandb.init`` documentation for accepted values.

    Returns:
        The value returned by ``run_single``, for further local analysis.

    Raises:
        Exceptions raised by ``run_single`` or by ``wandb.init`` propagate to
        the caller; failures in the individual logging steps do not.
    """
    # Run the experiment locally first (this is expected to write files
    # under out_dir).
    result = ed.run_single(task, atomic_library, params, out_dir)

    task_name = task.get('name', 'task')

    # Pick the run whose base name sorts last among this task's files and
    # collect whichever of its expected files exist on disk.
    base = _latest_artifact_base(out_dir, task_name)
    artifact_files = {}
    if base is not None:
        for suffix in _ARTIFACT_SUFFIXES:
            candidate = os.path.join(out_dir, base + suffix)
            if os.path.exists(candidate):
                artifact_files[suffix[1:]] = candidate
    phi_path = artifact_files.get("phi_best.npy")

    # Random suffix keeps display names distinct across repeated runs of the
    # same task.
    run_id = wandb.util.generate_id()

    with wandb.init(project=wandb_project,
                    entity=wandb_entity,
                    config=params,
                    name=f"{task_name}_{run_id}",
                    reinit=True,
                    resume=resume) as run:
        # Scalar metrics.
        try:
            run.log({
                "final_sigma": result.get("final_sigma"),
                "time_s": result.get("time_s"),
                "states_count": result.get("states_count")
            })
        except Exception:
            _LOG.warning("Could not log scalar metrics for %s",
                         task_name, exc_info=True)

        # Sigma trace, logged as a single value exactly as produced.
        try:
            run.log({"sigma_trace": result.get("sigma_trace", [])})
        except Exception:
            _LOG.warning("Could not log sigma_trace for %s",
                         task_name, exc_info=True)

        # Attach the run's files as one artifact; skip the upload entirely
        # when none of them exist.
        if artifact_files:
            try:
                art = wandb.Artifact(f"{task_name}_run", type="itt_run")
                for art_name, file_path in artifact_files.items():
                    art.add_file(file_path, name=art_name)
                run.log_artifact(art)
            except Exception:
                _LOG.warning("Could not log artifact for %s",
                             task_name, exc_info=True)

        # Small image preview of phi_best.
        if phi_path is not None:
            tmp_png = os.path.join(out_dir, base + "_phi_preview.png")
            try:
                _save_phi_image(np.load(phi_path), tmp_png)
                run.log({"phi_best_image": wandb.Image(tmp_png)})
            except Exception:
                _LOG.warning("Could not log phi preview for %s",
                             task_name, exc_info=True)
            finally:
                # Remove the temporary PNG whether or not logging succeeded.
                try:
                    os.remove(tmp_png)
                except OSError:
                    # Never created or already gone; nothing to clean up.
                    pass

    # Return the local result for further local analysis.
    return result
# Example usage (call from a notebook cell):
# from itt_solver.wandb_runner import run_and_log_wandb
# res = run_and_log_wandb(task, atomic_library, params, out_dir="experiments",
#                         wandb_project="itt_solver", resume="allow")
# print("W&B logged run, final sigma:", res.get("final_sigma"))