File size: 5,159 Bytes
b48dd06
 
 
 
 
 
056be36
b48dd06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
056be36
 
b48dd06
 
 
 
 
 
 
 
 
 
056be36
b48dd06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
056be36
 
 
b48dd06
 
 
 
056be36
b48dd06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import json
import logging
import os
import re
from typing import Any, Callable, Dict, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import wandb

from itt_solver import experiment_driver as ed

def _save_phi_image(phi: np.ndarray, path: str):
    """Render *phi* with the categorical 'tab20' colormap and save it to *path*.

    The output format follows matplotlib's ``savefig`` rules (normally the file
    extension, e.g. ``.png``). The figure is closed even if drawing or saving
    fails, so repeated calls never accumulate open pyplot figures.
    """
    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.imshow(phi, cmap='tab20')
        ax.axis('off')
        fig.savefig(path, bbox_inches='tight', dpi=150)
    finally:
        # Without this, an exception in imshow/savefig would leave the figure
        # registered with pyplot for the lifetime of the process.
        plt.close(fig)

_logger = logging.getLogger(__name__)

# File suffixes that experiment_driver.run_single appends to a run's base name
# (<base>_phi_best.npy, <base>_result.json, <base>_logs.json).
_RUN_FILE_SUFFIXES = ("_result.json", "_phi_best.npy", "_logs.json")

# Characters outside this set are replaced when building a W&B artifact name.
_ARTIFACT_NAME_INVALID = re.compile(r"[^A-Za-z0-9._-]")


def _latest_run_base(out_dir: str, prefix: str) -> Optional[str]:
    """Return the base name of the most recently written run-file group in *out_dir*.

    A file belongs to group ``base`` when its name is exactly ``base + suffix``
    for a suffix in ``_RUN_FILE_SUFFIXES`` and ``base`` starts with *prefix*.
    Groups are ranked by the newest modification time among their files (ties
    broken by name), so the files just written by ``run_single`` win even when
    another task's name merely shares the prefix (e.g. ``"a"`` vs ``"ab"``).

    Returns None when *out_dir* cannot be listed or holds no matching files.
    """
    try:
        entries = os.listdir(out_dir)
    except OSError:
        _logger.warning("Cannot list %r; no run files will be attached",
                        out_dir, exc_info=True)
        return None

    newest: Dict[str, float] = {}
    for fname in entries:
        suffix = next((s for s in _RUN_FILE_SUFFIXES if fname.endswith(s)), None)
        if suffix is None:
            continue
        base = fname[: -len(suffix)]
        if not base.startswith(prefix):
            continue
        try:
            mtime = os.path.getmtime(os.path.join(out_dir, fname))
        except OSError:
            continue  # removed between listdir() and stat(); ignore it
        newest[base] = max(mtime, newest.get(base, mtime))

    if not newest:
        return None
    return max(newest, key=lambda b: (newest[b], b))


def _artifact_name(task_name: Any) -> str:
    """Return ``"<task_name>_run"`` with characters W&B rejects in artifact names replaced by ``_``."""
    return _ARTIFACT_NAME_INVALID.sub("_", f"{task_name}_run")


def _best_effort(what: str, action: Callable[..., None], *args: Any) -> None:
    """Call ``action(*args)``; log a warning instead of raising so one failed W&B step can't abort the others."""
    try:
        action(*args)
    except Exception:  # deliberate boundary: W&B logging is best-effort
        _logger.warning("W&B: failed to log %s", what, exc_info=True)


def _log_metrics(run: Any, result: Dict[str, Any]) -> None:
    """Log the headline scalars from *result*; missing keys are logged as None."""
    run.log({key: result.get(key) for key in ("final_sigma", "time_s", "states_count")})


def _log_sigma_trace(run: Any, result: Dict[str, Any]) -> None:
    """Log ``result["sigma_trace"]`` as an iteration-vs-sigma line plot under ``sigma_trace``.

    A bare list passed to ``run.log`` is recorded as a single value, not charted
    against iteration, so the trace is wrapped in a table + line plot instead.
    Nothing is logged when the trace is missing or empty.
    """
    trace = result.get("sigma_trace")
    values = [] if trace is None else list(trace)  # list() also accepts ndarrays
    if not values:
        return
    table = wandb.Table(columns=["iteration", "sigma"],
                        data=[[i, s] for i, s in enumerate(values)])
    run.log({"sigma_trace": wandb.plot.line(table, "iteration", "sigma",
                                            title="sigma trace")})


def _log_run_artifact(run: Any, artifact_name: str,
                      files: Dict[str, Optional[str]]) -> None:
    """Attach the existing files (artifact-internal name -> local path) as one ``itt_run`` artifact.

    Does nothing when none of the paths exist, rather than logging an empty artifact.
    """
    present = {arc: path for arc, path in files.items()
               if path is not None and os.path.isfile(path)}
    if not present:
        return
    art = wandb.Artifact(artifact_name, type="itt_run")
    for arc_name, path in present.items():
        art.add_file(path, name=arc_name)
    run.log_artifact(art)


def _log_phi_preview(run: Any, phi_path: str, png_path: str) -> None:
    """Render the saved phi array to *png_path*, log it as ``phi_best_image``, then delete the PNG."""
    phi = np.load(phi_path)
    try:
        _save_phi_image(phi, png_path)
        run.log({"phi_best_image": wandb.Image(png_path)})
    finally:
        # The PNG is only a temporary upload source; remove it even if logging failed.
        try:
            os.remove(png_path)
        except OSError:
            pass


def run_and_log_wandb(task: Dict[str, Any],
                      atomic_library,
                      params: Dict[str, Any],
                      out_dir: str = "experiments",
                      wandb_project: str = "itt_solver",
                      wandb_entity: Optional[str] = None,
                      resume: Union[bool, str] = "allow",
                      anonymous: Optional[str] = None):
    """
    Run one experiment with experiment_driver.run_single and log it to W&B.

    The experiment runs locally first. Then a W&B run is opened with
    ``with wandb.init(...) as run:``, and the following are recorded on a
    best-effort basis. If one step fails, a warning is logged and the
    remaining steps still run.

    - scalars ``final_sigma``, ``time_s``, ``states_count`` from the result dict;
    - ``sigma_trace`` as an iteration-vs-sigma line plot;
    - an ``itt_run`` artifact holding whichever of phi_best.npy / result.json /
      logs.json run_single left in *out_dir* (newest group for this task);
    - ``phi_best_image``, a PNG preview of phi_best.

    Args:
        task: dict with keys 'name', 'input', 'target', 'target_shape' (the
            format experiment_driver uses). 'name' (default "task") prefixes
            the W&B run name and is used to locate run_single's output files.
        atomic_library: list of Transform objects for the run.
        params: hyperparameters passed to run_single and logged as the W&B config.
        out_dir: local directory run_single writes its files into.
        wandb_project: W&B project name.
        wandb_entity: optional W&B entity (team/user).
        resume: forwarded as-is to ``wandb.init(resume=...)``.
        anonymous: if not None, forwarded as ``wandb.init(anonymous=...)``
            (e.g. "allow" to log without an API key). NOTE(review): newer
            wandb releases may deprecate or ignore this setting; verify.

    Returns:
        The result dict returned by run_single, unchanged.
    """
    # Run the experiment locally first; run_single writes its files under out_dir.
    result = ed.run_single(task, atomic_library, params, out_dir)

    task_name = task.get("name", "task")
    base = _latest_run_base(out_dir, str(task_name))
    if base is None:
        phi_path = result_path = logs_path = None
    else:
        phi_path = os.path.join(out_dir, base + "_phi_best.npy")
        result_path = os.path.join(out_dir, base + "_result.json")
        logs_path = os.path.join(out_dir, base + "_logs.json")

    # Unique, W&B-safe suffix for the run's display name.
    run_id = wandb.util.generate_id()

    # Only forward `anonymous` when the caller set it, so the default call is
    # identical to calling wandb.init without that argument.
    extra_init: Dict[str, Any] = {}
    if anonymous is not None:
        extra_init["anonymous"] = anonymous

    with wandb.init(project=wandb_project,
                    entity=wandb_entity,
                    config=params,
                    name=f"{task_name}_{run_id}",
                    reinit=True,
                    resume=resume,
                    **extra_init) as run:
        _best_effort("scalar metrics", _log_metrics, run, result)
        _best_effort("sigma trace", _log_sigma_trace, run, result)
        _best_effort("run artifact", _log_run_artifact, run,
                     _artifact_name(task_name),
                     {"phi_best.npy": phi_path,
                      "result.json": result_path,
                      "logs.json": logs_path})
        if phi_path is not None and os.path.isfile(phi_path):
            _best_effort("phi preview", _log_phi_preview, run, phi_path,
                         os.path.join(out_dir, base + "_phi_preview.png"))

    # Return the local result dict for further local analysis
    return result


# Example usage (call from a notebook cell):
# from itt_solver.wandb_runner import run_and_log_wandb
# res = run_and_log_wandb(task, atomic_library, params, out_dir="experiments",
#                         wandb_project="itt_solver", anonymous="allow")
# print("W&B logged run, final sigma:", res.get("final_sigma"))