import json
import logging
import os
import re
from typing import Any, Dict, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import wandb

from itt_solver import experiment_driver as ed
def _save_phi_image(phi: np.ndarray, path: str, cmap: str = "tab20") -> None:
    """Render ``phi`` with ``imshow`` and save it to ``path``.

    The figure is closed in a ``finally`` block, so a failure while drawing or
    saving does not leave an open matplotlib figure behind.

    Args:
        phi: Array to display. 0-D/1-D input is promoted to 2-D with
            ``np.atleast_2d`` (``imshow`` rejects 1-D data); 2-D and 3-D
            arrays are passed through unchanged.
        path: Destination file. matplotlib infers the format from the
            extension (callers in this module use ".png").
        cmap: matplotlib colormap name; defaults to the previous fixed 'tab20'.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.imshow(np.atleast_2d(phi), cmap=cmap)
        ax.axis("off")
        fig.savefig(path, bbox_inches="tight", dpi=150)
    finally:
        # Always release the figure; otherwise repeated failures accumulate
        # open figures in pyplot's global figure manager.
        plt.close(fig)
logger = logging.getLogger(__name__)

# Suffixes of the files experiment_driver.run_single is expected to write next
# to a shared "<base>" name, mapped to the entry name used inside the W&B
# artifact. Insertion order is the order files are added to the artifact.
_RUN_FILE_SUFFIXES = {
    "_phi_best.npy": "phi_best.npy",
    "_result.json": "result.json",
    "_logs.json": "logs.json",
}


def _find_latest_run_base(out_dir: str, task_name: str) -> Optional[str]:
    """Return the newest "<base>" in ``out_dir`` that run_single wrote for ``task_name``.

    A file counts when its name starts with ``task_name`` and the next character
    is not alphanumeric, so task "task1" does not pick up "task10_*" files.
    It must also end with one of ``_RUN_FILE_SUFFIXES``. Only that trailing
    suffix is stripped to form the base, so task names that themselves contain
    "_logs" or "_result" are handled. When several bases match, the one owning
    the most recently modified file wins (ties broken by name).

    NOTE(review): assumes run_single separates the task name from the rest of
    the base with a non-alphanumeric character (e.g. "_"); confirm against
    experiment_driver.

    Returns:
        The base name, or None if ``out_dir`` cannot be listed or nothing matches.
    """
    try:
        names = os.listdir(out_dir)
    except OSError:
        return None

    newest: Dict[str, float] = {}
    for fname in names:
        if not fname.startswith(task_name):
            continue
        if fname[len(task_name):len(task_name) + 1].isalnum():
            continue  # belongs to a different task whose name merely extends ours
        suffix = next((s for s in _RUN_FILE_SUFFIXES if fname.endswith(s)), None)
        if suffix is None:
            continue
        base = fname[: -len(suffix)]
        try:
            mtime = os.path.getmtime(os.path.join(out_dir, fname))
        except OSError:
            continue  # vanished between listdir and stat
        newest[base] = max(mtime, newest.get(base, mtime))

    if not newest:
        return None
    return max(newest, key=lambda b: (newest[b], b))


def _sanitize_artifact_name(name: str) -> str:
    """Replace characters W&B does not accept in artifact names with '_'.

    W&B artifact names may contain only letters, digits, '-', '_' and '.'.
    """
    return re.sub(r"[^A-Za-z0-9_.-]", "_", name)


def run_and_log_wandb(task: Dict[str, Any],
                      atomic_library,
                      params: Dict[str, Any],
                      out_dir: str = "experiments",
                      wandb_project: str = "itt_solver",
                      wandb_entity: Optional[str] = None,
                      resume: Union[bool, str] = "allow"):
    """Run one experiment with ``experiment_driver.run_single`` and log it to W&B.

    The experiment runs locally first; run_single writes its output files under
    ``out_dir``. A W&B run is then opened with ``wandb.init(...)`` used as a
    context manager. Each item below is logged on a best-effort basis: a
    failure is reported as a warning through this module's logger and does not
    stop the remaining items.

    * Scalars ``final_sigma``, ``time_s`` and ``states_count`` from ``result``.
      Keys that are missing or None are skipped.
    * ``sigma_trace`` from ``result``, logged as a single list-valued entry.
    * An artifact of type ``itt_run`` containing whichever of phi_best.npy,
      result.json and logs.json were found for the newest run of this task.
    * ``phi_best_image``, a PNG preview rendered from phi_best.npy. The
      temporary PNG is removed afterwards, even if logging failed.

    Args:
        task: Dict in the format used by experiment_driver (per the original
            docstring: keys 'name', 'input', 'target', 'target_shape'). Only
            ``task["name"]`` (default "task") is read here.
        atomic_library: List of Transform objects, passed straight to run_single.
        params: Hyperparameter dict passed to run_single and logged as the W&B config.
        out_dir: Directory where run_single writes result/logs/phi_best files.
        wandb_project: W&B project name.
        wandb_entity: Optional W&B entity (team/user).
        resume: Forwarded unchanged to ``wandb.init(resume=...)``. It governs
            run resumption; it does NOT enable anonymous logging.

    Returns:
        The dict returned by ``run_single``, unchanged, for local analysis.

    Raises:
        Whatever ``run_single`` or ``wandb.init`` raise; these are not caught.
    """
    # Run locally first; run_single writes its output files under out_dir.
    result = ed.run_single(task, atomic_library, params, out_dir)

    task_name = task.get("name", "task")
    # run_single does not hand back its timestamped base name, so recover it
    # from the newest matching files on disk.
    base = _find_latest_run_base(out_dir, task_name)
    run_files: Dict[str, str] = {}  # artifact entry name -> local path
    if base is not None:
        for suffix, entry_name in _RUN_FILE_SUFFIXES.items():
            path = os.path.join(out_dir, base + suffix)
            if os.path.exists(path):
                run_files[entry_name] = path

    run_id = wandb.util.generate_id()
    with wandb.init(project=wandb_project,
                    entity=wandb_entity,
                    config=params,
                    name=f"{task_name}_{run_id}",
                    reinit=True,
                    resume=resume) as run:
        # Scalar summary metrics (skip absent values rather than log nulls).
        try:
            scalars = {key: result.get(key)
                       for key in ("final_sigma", "time_s", "states_count")}
            scalars = {key: val for key, val in scalars.items() if val is not None}
            if scalars:
                run.log(scalars)
        except Exception:
            logger.warning("W&B: failed to log scalar metrics", exc_info=True)

        # The whole trace is logged as one list-valued entry.
        try:
            sigma_trace = result.get("sigma_trace")
            if sigma_trace is not None:
                run.log({"sigma_trace": sigma_trace})
        except Exception:
            logger.warning("W&B: failed to log sigma_trace", exc_info=True)

        # Attach whichever run files exist. Skip entirely when none were found
        # rather than publishing an empty artifact version.
        if run_files:
            try:
                art = wandb.Artifact(_sanitize_artifact_name(f"{task_name}_run"),
                                     type="itt_run")
                for entry_name, path in run_files.items():
                    art.add_file(path, name=entry_name)
                run.log_artifact(art)
            except Exception:
                logger.warning("W&B: failed to log run artifact", exc_info=True)

        # Small image preview of phi_best.
        phi_path = run_files.get("phi_best.npy")
        if phi_path is not None:
            preview_path = os.path.join(out_dir, base + "_phi_preview.png")
            try:
                _save_phi_image(np.load(phi_path), preview_path)
                run.log({"phi_best_image": wandb.Image(preview_path)})
            except Exception:
                logger.warning("W&B: failed to log phi_best preview", exc_info=True)
            finally:
                # Remove the temporary PNG even when rendering/logging failed.
                try:
                    os.remove(preview_path)
                except OSError:
                    pass  # never created, or already gone: nothing to clean up

    return result
# Example usage (call from a notebook cell):
# from itt_solver.wandb_runner import run_and_log_wandb
# res = run_and_log_wandb(task, atomic_library, params, out_dir="experiments",
#                         wandb_project="itt_solver", resume="allow")
# print("W&B logged run, final sigma:", res.get("final_sigma"))