# Page-header residue (not part of the module): anugrah55 - "Update CERNenv Space" - commit 5f78183 (verified)
"""Plot before/after evaluation curves and reward breakdowns.
Reads two JSONL evaluation files (typically ``eval_pre_train.jsonl`` and
``eval_post_train.jsonl``) produced by ``training.evaluate`` and writes
publication-ready PNGs (Portable Network Graphics) under ``--out_dir``.
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any, Dict, List
def _load(path: str) -> List[Dict[str, Any]]:
    """Read a JSON Lines file and return one decoded value per non-blank line.

    Args:
        path: Location of the JSONL file to read.

    Returns:
        The decoded JSON values in file order. Lines that are empty or
        contain only whitespace are skipped.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON interchange text is UTF-8; do not depend on the platform's
    # locale-derived default encoding.
    with open(path, encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        return [json.loads(line) for line in stripped if line]
def _summarise(eps: List[Dict[str, Any]]) -> Dict[str, float]:
    """Aggregate a list of episode records into four summary statistics.

    Args:
        eps: Episode dicts. The keys read are ``cumulative_reward``,
            ``discovered``, ``correct_mass`` and ``correct_channel``.

    Returns:
        A dict with:
          * ``mean``: average cumulative reward; a missing, ``None`` or
            otherwise falsy reward counts as ``0.0``.
          * ``success_rate``: fraction of episodes with a truthy ``discovered``.
          * ``mass_acc``: fraction with a truthy ``correct_mass``.
          * ``channel_acc``: fraction with a truthy ``correct_channel``.
        Every value is ``0.0`` when ``eps`` is empty.
    """
    # Output metric name -> episode flag whose truthy fraction it reports.
    flag_for_metric = {
        "success_rate": "discovered",
        "mass_acc": "correct_mass",
        "channel_acc": "correct_channel",
    }

    if not eps:
        # Avoid dividing by zero; report an all-zero summary instead.
        return {"mean": 0.0, **{metric: 0.0 for metric in flag_for_metric}}

    total = len(eps)
    rewards = [float(e.get("cumulative_reward") or 0.0) for e in eps]
    summary: Dict[str, float] = {"mean": sum(rewards) / total}
    for metric, flag in flag_for_metric.items():
        summary[metric] = sum(1 for e in eps if e.get(flag)) / total
    return summary
def main() -> None:  # pragma: no cover
    """Command-line entry point: compare pre- and post-training evaluations.

    Parses ``--pre`` and ``--post`` (JSONL evaluation files, both required)
    and ``--out_dir`` (default ``training/plots``), then writes into the
    output directory:

      * ``reward_distribution.png``: overlaid histograms of per-episode
        cumulative reward.
      * ``metrics_summary.png``: grouped bars of the summary metrics.
      * ``metrics_summary.json``: the same summary metrics as JSON.

    Finally prints the contents of the output directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--pre", required=True)
    parser.add_argument("--post", required=True)
    parser.add_argument("--out_dir", default="training/plots")
    args = parser.parse_args()

    # Imported lazily so importing this module does not require matplotlib.
    # The non-interactive Agg backend is selected before pyplot is imported
    # so figures render straight to files.
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    pre = _load(args.pre)
    post = _load(args.post)
    pre_stats = _summarise(pre)
    post_stats = _summarise(post)

    out = Path(args.out_dir)
    out.mkdir(parents=True, exist_ok=True)

    _plot_reward_distribution(
        plt,
        _episode_rewards(pre),
        _episode_rewards(post),
        pre_stats["mean"],
        post_stats["mean"],
        out / "reward_distribution.png",
    )
    _plot_metrics_summary(plt, pre_stats, post_stats, out / "metrics_summary.png")

    with open(out / "metrics_summary.json", "w", encoding="utf-8") as f:
        json.dump({"pre": pre_stats, "post": post_stats}, f, indent=2)

    # Lists everything in the directory, including files from earlier runs.
    print("wrote:", list(out.glob("*")))


def _episode_rewards(eps: List[Dict[str, Any]]) -> List[float]:  # pragma: no cover
    """Return each episode's cumulative reward as a float (missing/None -> 0.0)."""
    return [float(e.get("cumulative_reward") or 0.0) for e in eps]


def _plot_reward_distribution(
    plt: Any,
    pre_rewards: List[float],
    post_rewards: List[float],
    pre_mean: float,
    post_mean: float,
    path: Path,
) -> None:  # pragma: no cover
    """Save overlaid pre/post reward histograms, with means in the legend, to ``path``."""
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.hist(pre_rewards, bins=15, alpha=0.5, label=f"pre (μ={pre_mean:+.2f})")
    ax.hist(post_rewards, bins=15, alpha=0.5, label=f"post (μ={post_mean:+.2f})")
    ax.set_xlabel("episode cumulative reward")
    ax.set_ylabel("episode count")
    ax.set_title("CERNenv reward distribution: pre vs post training")
    ax.legend()
    fig.tight_layout()
    fig.savefig(path, dpi=140)
    plt.close(fig)


def _plot_metrics_summary(
    plt: Any,
    pre_stats: Dict[str, float],
    post_stats: Dict[str, float],
    path: Path,
) -> None:  # pragma: no cover
    """Save a grouped bar chart of the pre/post summary metrics to ``path``."""
    metrics = ["mean", "success_rate", "mass_acc", "channel_acc"]
    positions = list(range(len(metrics)))
    fig, ax = plt.subplots(figsize=(7, 4))
    # Two 0.36-wide bars per metric, centred 0.18 either side of the tick.
    ax.bar([i - 0.18 for i in positions], [pre_stats[m] for m in metrics], width=0.36, label="pre")
    ax.bar([i + 0.18 for i in positions], [post_stats[m] for m in metrics], width=0.36, label="post")
    ax.set_xticks(positions)
    ax.set_xticklabels(metrics)
    ax.set_title("Mean reward & accuracy: pre vs post training")
    ax.legend()
    fig.tight_layout()
    fig.savefig(path, dpi=140)
    plt.close(fig)
if __name__ == "__main__":  # pragma: no cover
    # Run the command-line interface only when executed as a script,
    # not when the module is imported.
    main()