"""Plot before/after evaluation curves and reward breakdowns. Reads two JSONL evaluation files (typically ``eval_pre_train.jsonl`` and ``eval_post_train.jsonl``) produced by ``training.evaluate`` and writes publication-ready PNGs (Portable Network Graphics) under ``--out_dir``. """ from __future__ import annotations import argparse import json from pathlib import Path from typing import Any, Dict, List def _load(path: str) -> List[Dict[str, Any]]: eps = [] with open(path) as f: for line in f: line = line.strip() if line: eps.append(json.loads(line)) return eps def _summarise(eps: List[Dict[str, Any]]) -> Dict[str, float]: if not eps: return {"mean": 0.0, "success_rate": 0.0, "mass_acc": 0.0, "channel_acc": 0.0} rewards = [float(e.get("cumulative_reward") or 0.0) for e in eps] return { "mean": sum(rewards) / len(rewards), "success_rate": sum(1 for e in eps if e.get("discovered")) / len(eps), "mass_acc": sum(1 for e in eps if e.get("correct_mass")) / len(eps), "channel_acc": sum(1 for e in eps if e.get("correct_channel")) / len(eps), } def main() -> None: # pragma: no cover parser = argparse.ArgumentParser() parser.add_argument("--pre", required=True) parser.add_argument("--post", required=True) parser.add_argument("--out_dir", default="training/plots") args = parser.parse_args() import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt pre = _load(args.pre) post = _load(args.post) pre_stats = _summarise(pre) post_stats = _summarise(post) out = Path(args.out_dir) out.mkdir(parents=True, exist_ok=True) pre_rewards = [float(e.get("cumulative_reward") or 0.0) for e in pre] post_rewards = [float(e.get("cumulative_reward") or 0.0) for e in post] fig, ax = plt.subplots(figsize=(7, 4)) ax.hist(pre_rewards, bins=15, alpha=0.5, label=f"pre (μ={pre_stats['mean']:+.2f})") ax.hist(post_rewards, bins=15, alpha=0.5, label=f"post (μ={post_stats['mean']:+.2f})") ax.set_xlabel("episode cumulative reward") ax.set_ylabel("episode count") ax.set_title("CERNenv reward distribution: pre vs post training") ax.legend() fig.tight_layout() fig.savefig(out / "reward_distribution.png", dpi=140) plt.close(fig) metrics = ["mean", "success_rate", "mass_acc", "channel_acc"] pre_vals = [pre_stats[m] for m in metrics] post_vals = [post_stats[m] for m in metrics] x = list(range(len(metrics))) fig, ax = plt.subplots(figsize=(7, 4)) ax.bar([i - 0.18 for i in x], pre_vals, width=0.36, label="pre") ax.bar([i + 0.18 for i in x], post_vals, width=0.36, label="post") ax.set_xticks(x) ax.set_xticklabels(metrics) ax.set_title("Mean reward & accuracy: pre vs post training") ax.legend() fig.tight_layout() fig.savefig(out / "metrics_summary.png", dpi=140) plt.close(fig) with open(out / "metrics_summary.json", "w") as f: json.dump({"pre": pre_stats, "post": post_stats}, f, indent=2) print("wrote:", list(out.glob("*"))) if __name__ == "__main__": # pragma: no cover main()