Spaces:
Sleeping
Sleeping
| """Plot before/after evaluation curves and reward breakdowns. | |
| Reads two JSONL evaluation files (typically ``eval_pre_train.jsonl`` and | |
| ``eval_post_train.jsonl``) produced by ``training.evaluate`` and writes | |
| publication-ready PNGs (Portable Network Graphics) under ``--out_dir``. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
| def _load(path: str) -> List[Dict[str, Any]]: | |
| eps = [] | |
| with open(path) as f: | |
| for line in f: | |
| line = line.strip() | |
| if line: | |
| eps.append(json.loads(line)) | |
| return eps | |
| def _summarise(eps: List[Dict[str, Any]]) -> Dict[str, float]: | |
| if not eps: | |
| return {"mean": 0.0, "success_rate": 0.0, "mass_acc": 0.0, "channel_acc": 0.0} | |
| rewards = [float(e.get("cumulative_reward") or 0.0) for e in eps] | |
| return { | |
| "mean": sum(rewards) / len(rewards), | |
| "success_rate": sum(1 for e in eps if e.get("discovered")) / len(eps), | |
| "mass_acc": sum(1 for e in eps if e.get("correct_mass")) / len(eps), | |
| "channel_acc": sum(1 for e in eps if e.get("correct_channel")) / len(eps), | |
| } | |
def _shared_bin_edges(*samples: List[float], bins: int = 15) -> List[float]:
    """Return ``bins + 1`` equal-width histogram edges spanning all *samples*.

    Using one set of edges for every series lets overlaid histograms be
    compared bar-for-bar.
    """
    values = [v for sample in samples for v in sample]
    # With no data, fall back to the unit range (what numpy.histogram uses).
    lo, hi = (min(values), max(values)) if values else (0.0, 1.0)
    if lo == hi:
        # All values identical: widen the range so the single bar has width.
        lo, hi = lo - 0.5, hi + 0.5
    step = (hi - lo) / bins
    return [lo + i * step for i in range(bins)] + [hi]


def main() -> None:  # pragma: no cover
    """Plot pre- vs post-training evaluation results.

    Command-line arguments:
        --pre:     JSONL evaluation file recorded before training.
        --post:    JSONL evaluation file recorded after training.
        --out_dir: Output directory, created if missing
                   (default ``training/plots``).

    Writes ``reward_distribution.png``, ``metrics_summary.png`` and
    ``metrics_summary.json`` into ``--out_dir`` and prints their paths.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pre", required=True, help="pre-training eval JSONL")
    parser.add_argument("--post", required=True, help="post-training eval JSONL")
    parser.add_argument("--out_dir", default="training/plots", help="output directory")
    args = parser.parse_args()

    # Import lazily and select the non-interactive Agg backend before pyplot
    # is imported, so plotting works without a display.
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    pre = _load(args.pre)
    post = _load(args.post)
    pre_stats = _summarise(pre)
    post_stats = _summarise(post)

    out = Path(args.out_dir)
    out.mkdir(parents=True, exist_ok=True)
    written: List[Path] = []

    # --- Reward distribution histogram -----------------------------------
    pre_rewards = [float(e.get("cumulative_reward") or 0.0) for e in pre]
    post_rewards = [float(e.get("cumulative_reward") or 0.0) for e in post]
    # Shared edges: an independent ``bins=15`` per series would bin each one
    # over its own min/max, so the overlaid bars would not line up.
    edges = _shared_bin_edges(pre_rewards, post_rewards, bins=15)
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.hist(pre_rewards, bins=edges, alpha=0.5, label=f"pre (μ={pre_stats['mean']:+.2f})")
    ax.hist(post_rewards, bins=edges, alpha=0.5, label=f"post (μ={post_stats['mean']:+.2f})")
    ax.set_xlabel("episode cumulative reward")
    ax.set_ylabel("episode count")
    ax.set_title("CERNenv reward distribution: pre vs post training")
    ax.legend()
    fig.tight_layout()
    dist_path = out / "reward_distribution.png"
    fig.savefig(dist_path, dpi=140)
    plt.close(fig)
    written.append(dist_path)

    # --- Grouped bar chart of summary metrics ----------------------------
    metrics = ["mean", "success_rate", "mass_acc", "channel_acc"]
    pre_vals = [pre_stats[m] for m in metrics]
    post_vals = [post_stats[m] for m in metrics]
    x = list(range(len(metrics)))
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.bar([i - 0.18 for i in x], pre_vals, width=0.36, label="pre")
    ax.bar([i + 0.18 for i in x], post_vals, width=0.36, label="post")
    # The mean reward can be negative; a zero line keeps its sign readable.
    ax.axhline(0.0, color="black", linewidth=0.8)
    ax.set_xticks(x)
    ax.set_xticklabels(metrics)
    ax.set_title("Mean reward & accuracy: pre vs post training")
    ax.legend()
    fig.tight_layout()
    bars_path = out / "metrics_summary.png"
    fig.savefig(bars_path, dpi=140)
    plt.close(fig)
    written.append(bars_path)

    # --- Machine-readable summary ----------------------------------------
    summary_path = out / "metrics_summary.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump({"pre": pre_stats, "post": post_stats}, f, indent=2)
    written.append(summary_path)

    # Report only what this run produced, not every file already in out_dir.
    print("wrote:", [str(p) for p in written])
if __name__ == "__main__":  # pragma: no cover
    # Entry point when executed as a script; importing the module has no
    # side effects beyond the definitions above.
    main()