"""Plot before/after evaluation curves and reward breakdowns.



Reads two JSONL evaluation files (typically ``eval_pre_train.jsonl`` and

``eval_post_train.jsonl``) produced by ``training.evaluate`` and writes

publication-ready PNGs (Portable Network Graphics) under ``--out_dir``.

"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any, Dict, List


def _load(path: str) -> List[Dict[str, Any]]:
    """Read one JSON object per non-empty line from a JSONL file."""
    eps: List[Dict[str, Any]] = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                eps.append(json.loads(line))
    return eps


def _summarise(eps: List[Dict[str, Any]]) -> Dict[str, float]:
    """Aggregate mean reward and success/accuracy rates over episodes."""
    if not eps:
        return {"mean": 0.0, "success_rate": 0.0, "mass_acc": 0.0, "channel_acc": 0.0}
    rewards = [float(e.get("cumulative_reward") or 0.0) for e in eps]
    return {
        "mean": sum(rewards) / len(rewards),
        "success_rate": sum(1 for e in eps if e.get("discovered")) / len(eps),
        "mass_acc": sum(1 for e in eps if e.get("correct_mass")) / len(eps),
        "channel_acc": sum(1 for e in eps if e.get("correct_channel")) / len(eps),
    }


def main() -> None:  # pragma: no cover
    parser = argparse.ArgumentParser()
    parser.add_argument("--pre", required=True)
    parser.add_argument("--post", required=True)
    parser.add_argument("--out_dir", default="training/plots")
    args = parser.parse_args()

    # Import lazily and force the non-interactive Agg backend so the script
    # runs headless (e.g. on a training cluster with no display).
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    pre = _load(args.pre)
    post = _load(args.post)
    pre_stats = _summarise(pre)
    post_stats = _summarise(post)

    out = Path(args.out_dir)
    out.mkdir(parents=True, exist_ok=True)

    pre_rewards = [float(e.get("cumulative_reward") or 0.0) for e in pre]
    post_rewards = [float(e.get("cumulative_reward") or 0.0) for e in post]

    fig, ax = plt.subplots(figsize=(7, 4))
    ax.hist(pre_rewards, bins=15, alpha=0.5, label=f"pre (μ={pre_stats['mean']:+.2f})")
    ax.hist(post_rewards, bins=15, alpha=0.5, label=f"post (μ={post_stats['mean']:+.2f})")
    ax.set_xlabel("episode cumulative reward")
    ax.set_ylabel("episode count")
    ax.set_title("CERNenv reward distribution: pre vs post training")
    ax.legend()
    fig.tight_layout()
    fig.savefig(out / "reward_distribution.png", dpi=140)
    plt.close(fig)

    metrics = ["mean", "success_rate", "mass_acc", "channel_acc"]
    pre_vals = [pre_stats[m] for m in metrics]
    post_vals = [post_stats[m] for m in metrics]
    x = list(range(len(metrics)))
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.bar([i - 0.18 for i in x], pre_vals, width=0.36, label="pre")
    ax.bar([i + 0.18 for i in x], post_vals, width=0.36, label="post")
    ax.set_xticks(x)
    ax.set_xticklabels(metrics)
    ax.set_title("Mean reward & accuracy: pre vs post training")
    ax.legend()
    fig.tight_layout()
    fig.savefig(out / "metrics_summary.png", dpi=140)
    plt.close(fig)

    with open(out / "metrics_summary.json", "w") as f:
        json.dump({"pre": pre_stats, "post": post_stats}, f, indent=2)

    print("wrote:", list(out.glob("*")))


if __name__ == "__main__":  # pragma: no cover
    main()
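A minimal sketch of the episode-record shape this script consumes. The field names (``cumulative_reward`` plus the three boolean outcome flags) are taken from ``_summarise`` above; the values are illustrative only, and the aggregation is restated inline so the sketch runs standalone:

```python
import json

# Hypothetical episode records with the fields _load/_summarise read;
# values here are made up for illustration.
records = [
    {"cumulative_reward": 2.0, "discovered": True, "correct_mass": True, "correct_channel": False},
    {"cumulative_reward": -1.0, "discovered": False, "correct_mass": False, "correct_channel": False},
]

# One JSON object per line, as the JSONL loader expects.
jsonl = "\n".join(json.dumps(r) for r in records)
eps = [json.loads(line) for line in jsonl.splitlines() if line.strip()]

# The same aggregation _summarise performs: mean reward plus fraction of
# episodes with each boolean flag set.
rewards = [float(e.get("cumulative_reward") or 0.0) for e in eps]
stats = {
    "mean": sum(rewards) / len(rewards),
    "success_rate": sum(1 for e in eps if e.get("discovered")) / len(eps),
    "mass_acc": sum(1 for e in eps if e.get("correct_mass")) / len(eps),
    "channel_acc": sum(1 for e in eps if e.get("correct_channel")) / len(eps),
}
print(stats)  # {'mean': 0.5, 'success_rate': 0.5, 'mass_acc': 0.5, 'channel_acc': 0.0}
```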