import json
import logging
import os
import re
from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import wandb

from itt_solver import experiment_driver as ed

logger = logging.getLogger(__name__)

# File suffixes that experiment_driver.run_single appends to a per-run base
# name; base + suffix is how the local artifact paths are reconstructed.
_RUN_FILE_SUFFIXES = ("_phi_best.npy", "_result.json", "_logs.json")

# W&B artifact names may only contain alphanumerics, '-', '_' and '.'.
_ARTIFACT_NAME_INVALID = re.compile(r"[^A-Za-z0-9_.-]")


def _save_phi_image(phi: np.ndarray, path: str):
    """Save a small visualization of phi to path (PNG).

    phi is drawn with matplotlib ``imshow`` using the 'tab20' colormap, with
    axes hidden, on a 4x4-inch figure saved at 150 dpi.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.imshow(phi, cmap='tab20')
        ax.axis('off')
        fig.savefig(path, bbox_inches='tight', dpi=150)
    finally:
        # Close even when imshow/savefig raises so open figures do not pile
        # up across repeated calls (e.g. in a long-lived notebook kernel).
        plt.close(fig)


def _find_latest_base(out_dir: str, prefix: str) -> Optional[str]:
    """Return the newest ``<base>`` in out_dir that has run_single output files.

    Only files that start with ``prefix`` and END with one of
    ``_RUN_FILE_SUFFIXES`` are considered. The suffix is stripped exactly
    once, so task names containing e.g. "_logs" are not mangled. Bases are
    ranked by the newest modification time among their files, not
    lexicographically. Assuming run_single wrote its files during this call,
    its set wins even when another task's name also starts with ``prefix``.
    Returns None if out_dir cannot be listed or contains no matching file.
    """
    try:
        names = os.listdir(out_dir)
    except OSError:
        return None

    newest: Dict[str, float] = {}
    for fname in names:
        if not fname.startswith(prefix):
            continue
        suffix = next((s for s in _RUN_FILE_SUFFIXES if fname.endswith(s)), None)
        if suffix is None:
            continue
        try:
            mtime = os.path.getmtime(os.path.join(out_dir, fname))
        except OSError:
            # File vanished between listdir and stat; ignore it.
            continue
        base = fname[: -len(suffix)]
        newest[base] = max(mtime, newest.get(base, mtime))

    if not newest:
        return None
    # Equal mtimes fall back to the name so the choice is deterministic.
    return max(newest, key=lambda b: (newest[b], b))


def _artifact_name(task_name: str) -> str:
    """Build a W&B-legal artifact name for task_name (invalid chars -> '-')."""
    return f"{_ARTIFACT_NAME_INVALID.sub('-', task_name)}_run"


def _log_scalars(run: Any, result: Dict[str, Any]) -> None:
    """Log the headline scalar metrics, skipping any run_single did not report."""
    metrics = {}
    for key in ("final_sigma", "time_s", "states_count"):
        value = result.get(key)
        if value is not None:
            metrics[key] = value
    if not metrics:
        return
    try:
        run.log(metrics)
    except Exception:
        logger.warning("W&B: failed to log scalar metrics", exc_info=True)


def _log_sigma_trace(run: Any, trace: Optional[Sequence[Any]]) -> None:
    """Log sigma_trace point by point against an "iteration" x-axis."""
    # len() rather than truthiness so a numpy array trace does not raise.
    if trace is None or len(trace) == 0:
        return
    try:
        # A single run.log of the whole list does not yield an ordered line
        # chart; one log call per point with a custom step metric does.
        run.define_metric("iteration")
        run.define_metric("sigma_trace", step_metric="iteration")
        for i, sigma in enumerate(trace):
            run.log({"iteration": i, "sigma_trace": sigma})
    except Exception:
        logger.warning("W&B: failed to log sigma_trace", exc_info=True)


def _log_artifacts(run: Any, task_name: str,
                   files: Iterable[Tuple[Optional[str], str]]) -> None:
    """Upload the existing (local_path, name_in_artifact) files as one artifact."""
    present = [(path, name) for path, name in files
               if path and os.path.exists(path)]
    if not present:
        # Nothing to upload; avoid creating an empty artifact version.
        return
    try:
        art = wandb.Artifact(_artifact_name(task_name), type="itt_run")
        for path, name in present:
            art.add_file(path, name=name)
        run.log_artifact(art)
    except Exception:
        logger.warning("W&B: failed to log artifacts", exc_info=True)


def _log_phi_preview(run: Any, phi_path: Optional[str],
                     preview_path: Optional[str]) -> None:
    """Render phi_best to a temporary PNG, log it as an image, then delete it."""
    if not phi_path or not preview_path or not os.path.exists(phi_path):
        return
    try:
        _save_phi_image(np.load(phi_path), preview_path)
        run.log({"phi_best_image": wandb.Image(preview_path)})
    except Exception:
        logger.warning("W&B: failed to log phi_best preview", exc_info=True)
    finally:
        # Remove the temporary PNG whether or not logging succeeded; it may
        # not exist if rendering failed, which is fine.
        try:
            os.remove(preview_path)
        except OSError:
            pass


def run_and_log_wandb(task: Dict[str, Any],
                      atomic_library,
                      params: Dict[str, Any],
                      out_dir: str = "experiments",
                      wandb_project: str = "itt_solver",
                      wandb_entity: Optional[str] = None,
                      resume: Union[bool, str] = "allow",
                      anonymous: Optional[str] = None):
    """Run one experiment with experiment_driver.run_single and log it to W&B.

    The experiment runs locally first. Its outputs are then logged inside a
    ``with wandb.init(...) as run:`` block, so the run is finished when the
    block exits. Each logging step is best-effort: a failure is reported as a
    warning on this module's logger and does not abort the call.

    Args:
        task: dict with keys 'name', 'input', 'target', 'target_shape' (same
            format used by experiment_driver). 'name' (default 'task') is also
            used to locate output files and to name the run and artifact.
        atomic_library: list of Transform objects for the run.
        params: hyperparameters passed to run_single and logged as W&B config.
        out_dir: local directory where run_single writes
            ``<base>_result.json``, ``<base>_logs.json`` and
            ``<base>_phi_best.npy``. The newest such set for this task is
            uploaded.
        wandb_project: W&B project name.
        wandb_entity: optional W&B entity (team/user).
        resume: forwarded to ``wandb.init(resume=...)``. With "allow", W&B
            resumes an existing run only if a matching run id is supplied
            (e.g. via WANDB_RUN_ID); otherwise it starts a new run. This does
            NOT control anonymous logging; use ``anonymous`` for that.
        anonymous: optional value forwarded to ``wandb.init(anonymous=...)``
            (e.g. "allow"). It is omitted entirely when None, so W&B's default
            applies. NOTE(review): recent W&B releases may deprecate this
            setting; confirm against the installed version.

    Returns:
        The result dict returned by run_single, for further local analysis.
    """
    # Run the experiment locally first (this writes files under out_dir).
    result = ed.run_single(task, atomic_library, params, out_dir)

    task_name = task.get('name', 'task')
    base = _find_latest_base(out_dir, task_name)
    if base is None:
        logger.warning("No run_single output files for task %r found in %s; "
                       "artifacts and preview will not be logged.",
                       task_name, out_dir)
        phi_path = result_path = logs_path = preview_path = None
    else:
        stem = os.path.join(out_dir, base)
        phi_path = stem + "_phi_best.npy"
        result_path = stem + "_result.json"
        logs_path = stem + "_logs.json"
        preview_path = stem + "_phi_preview.png"

    # Random suffix keeps run names unique across repeated runs of one task.
    run_id = wandb.util.generate_id()
    init_kwargs: Dict[str, Any] = dict(
        project=wandb_project,
        entity=wandb_entity,
        config=params,
        name=f"{task_name}_{run_id}",
        reinit=True,
        resume=resume,
    )
    if anonymous is not None:
        init_kwargs["anonymous"] = anonymous

    with wandb.init(**init_kwargs) as run:
        _log_scalars(run, result)
        _log_sigma_trace(run, result.get("sigma_trace"))
        _log_artifacts(run, task_name, [
            (phi_path, "phi_best.npy"),
            (result_path, "result.json"),
            (logs_path, "logs.json"),
        ])
        _log_phi_preview(run, phi_path, preview_path)

    # Return the local result dict for further local analysis.
    return result


# Example usage (call from a notebook cell):
# from itt_solver.wandb_runner import run_and_log_wandb
# res = run_and_log_wandb(task, atomic_library, params, out_dir="experiments",
#                         wandb_project="itt_solver", anonymous="allow")
# print("W&B logged run, final sigma:", res.get("final_sigma"))