| """Plot before/after evaluation curves and reward breakdowns.
|
|
|
| Reads two JSONL evaluation files (typically ``eval_pre_train.jsonl`` and
|
| ``eval_post_train.jsonl``) produced by ``training.evaluate`` and writes
|
| publication-ready PNGs (Portable Network Graphics) under ``--out_dir``.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import argparse
|
| import json
|
| from pathlib import Path
|
| from typing import Any, Dict, List
|
|
|
|
|
| def _load(path: str) -> List[Dict[str, Any]]:
|
| eps = []
|
| with open(path) as f:
|
| for line in f:
|
| line = line.strip()
|
| if line:
|
| eps.append(json.loads(line))
|
| return eps
|
|
|
|
|
| def _summarise(eps: List[Dict[str, Any]]) -> Dict[str, float]:
|
| if not eps:
|
| return {"mean": 0.0, "success_rate": 0.0, "mass_acc": 0.0, "channel_acc": 0.0}
|
| rewards = [float(e.get("cumulative_reward") or 0.0) for e in eps]
|
| return {
|
| "mean": sum(rewards) / len(rewards),
|
| "success_rate": sum(1 for e in eps if e.get("discovered")) / len(eps),
|
| "mass_acc": sum(1 for e in eps if e.get("correct_mass")) / len(eps),
|
| "channel_acc": sum(1 for e in eps if e.get("correct_channel")) / len(eps),
|
| }
|
|
|
|
|
| def main() -> None:
|
| parser = argparse.ArgumentParser()
|
| parser.add_argument("--pre", required=True)
|
| parser.add_argument("--post", required=True)
|
| parser.add_argument("--out_dir", default="training/plots")
|
| args = parser.parse_args()
|
|
|
| import matplotlib
|
| matplotlib.use("Agg")
|
| import matplotlib.pyplot as plt
|
|
|
| pre = _load(args.pre)
|
| post = _load(args.post)
|
| pre_stats = _summarise(pre)
|
| post_stats = _summarise(post)
|
|
|
| out = Path(args.out_dir)
|
| out.mkdir(parents=True, exist_ok=True)
|
|
|
| pre_rewards = [float(e.get("cumulative_reward") or 0.0) for e in pre]
|
| post_rewards = [float(e.get("cumulative_reward") or 0.0) for e in post]
|
|
|
| fig, ax = plt.subplots(figsize=(7, 4))
|
| ax.hist(pre_rewards, bins=15, alpha=0.5, label=f"pre (μ={pre_stats['mean']:+.2f})")
|
| ax.hist(post_rewards, bins=15, alpha=0.5, label=f"post (μ={post_stats['mean']:+.2f})")
|
| ax.set_xlabel("episode cumulative reward")
|
| ax.set_ylabel("episode count")
|
| ax.set_title("CERNenv reward distribution: pre vs post training")
|
| ax.legend()
|
| fig.tight_layout()
|
| fig.savefig(out / "reward_distribution.png", dpi=140)
|
| plt.close(fig)
|
|
|
| metrics = ["mean", "success_rate", "mass_acc", "channel_acc"]
|
| pre_vals = [pre_stats[m] for m in metrics]
|
| post_vals = [post_stats[m] for m in metrics]
|
| x = list(range(len(metrics)))
|
| fig, ax = plt.subplots(figsize=(7, 4))
|
| ax.bar([i - 0.18 for i in x], pre_vals, width=0.36, label="pre")
|
| ax.bar([i + 0.18 for i in x], post_vals, width=0.36, label="post")
|
| ax.set_xticks(x)
|
| ax.set_xticklabels(metrics)
|
| ax.set_title("Mean reward & accuracy: pre vs post training")
|
| ax.legend()
|
| fig.tight_layout()
|
| fig.savefig(out / "metrics_summary.png", dpi=140)
|
| plt.close(fig)
|
|
|
| with open(out / "metrics_summary.json", "w") as f:
|
| json.dump({"pre": pre_stats, "post": post_stats}, f, indent=2)
|
|
|
| print("wrote:", list(out.glob("*")))
|
|
|
|
|
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|
|