"""Aggregate pre/post evaluator JSONL into a before/after evidence pack. Reads ``evidence/pre_eval.jsonl`` + ``evidence/post_eval.jsonl`` (one JSONL row per episode, written by ``training/evaluate.py``) and emits the side-by-side artifacts the trainer dashboard surfaces: * ``evidence/before_after_metrics.json`` — ``{pre, post, delta}``, with the headline metrics every reviewer asks for first. * ``evidence/reward_components.csv`` — one row per (tag, episode) with all 8 component columns + total, ready to pivot in any spreadsheet. * ``evidence/before_after_summary.png`` — grouped bar chart of pre vs post for each headline metric. * ``evidence/reward_distribution.png`` — overlapping histograms of cumulative reward, pre vs post. * ``evidence/reward_components.png`` — grouped bar chart of mean reward components, pre vs post. * ``evidence/sample_trajectories.md`` — pretty-printed worst-3 pre + best-3 post episodes, with action sequences and per-step rewards. CLI:: python -m training.summarize \\ --evidence_dir evidence \\ --pre pre_eval.jsonl --post post_eval.jsonl """ from __future__ import annotations import argparse import csv import json import logging import sys from pathlib import Path from typing import Any, Dict, List, Optional, Sequence logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s") logger = logging.getLogger("training.summarize") # Mirrors ``training/evaluate.py::_REWARD_COMPONENTS`` (cannot import # directly because this module must run with only stdlib + matplotlib # in case the heavy ML deps are not installed at summarisation time). 
_REWARD_COMPONENTS = (
    "evidence_coverage",
    "decision_accuracy",
    "credit_efficiency",
    "reasoning_coherence",
    "novelty",
    "penalty",
    "shaping",
    "terminal",
)

# Keys of the per-tag aggregate dict, in the order they are reported.
_HEADLINE_METRICS = (
    "n",
    "mean_reward",
    "median_reward",
    "success_rate",
    "decision_accuracy",
    "evidence_coverage",
)


def _parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse CLI arguments; every output filename is relative to ``--evidence_dir``."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--evidence_dir", default="evidence")
    parser.add_argument("--pre", default="pre_eval.jsonl",
                        help="Filename of the pre-train JSONL inside evidence_dir.")
    parser.add_argument("--post", default="post_eval.jsonl",
                        help="Filename of the post-train JSONL inside evidence_dir.")
    parser.add_argument("--metrics_out", default="before_after_metrics.json")
    parser.add_argument("--components_csv", default="reward_components.csv")
    parser.add_argument("--summary_png", default="before_after_summary.png")
    parser.add_argument("--distribution_png", default="reward_distribution.png")
    parser.add_argument("--components_png", default="reward_components.png")
    parser.add_argument("--samples_md", default="sample_trajectories.md")
    parser.add_argument("--n_samples", type=int, default=3,
                        help="Number of worst-pre + best-post episodes to dump.")
    return parser.parse_args(argv)


def _as_float(value: Any, default: float = 0.0) -> float:
    """Coerce *value* to ``float``; return *default* for ``None`` or non-numeric input.

    ``dict.get(key, 0.0)`` only falls back when the key is absent, so a JSON
    ``null`` would otherwise reach ``float(None)`` and raise ``TypeError``.
    """
    if value is None:
        return default
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _reward(row: Dict[str, Any]) -> float:
    """Return the episode's cumulative reward (0.0 when missing or invalid)."""
    return _as_float(row.get("cumulative_reward"))


def _components(row: Dict[str, Any]) -> Dict[str, Any]:
    """Return the row's ``reward_components_total`` mapping, or ``{}`` if it is not a dict."""
    comps = row.get("reward_components_total")
    return comps if isinstance(comps, dict) else {}


def _md_cell(value: Any) -> str:
    """Render *value* so it cannot break a markdown table row.

    Pipes are backslash-escaped and line breaks are flattened to spaces.
    """
    text = str(value).replace("|", "\\|")
    return text.replace("\r", " ").replace("\n", " ")


def _load_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Load one JSON object per non-blank line of *path*.

    Returns an empty list (with a warning) when the file does not exist.
    Unparseable lines and lines whose JSON value is not an object are skipped
    with a warning, so downstream code can rely on every row being a dict.
    """
    if not path.exists():
        logger.warning("eval file missing: %s", path)
        return []
    rows: List[Dict[str, Any]] = []
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                logger.warning("skipping unparseable JSONL line in %s: %s", path, exc)
                continue
            if not isinstance(record, dict):
                logger.warning(
                    "skipping non-object JSONL line %d in %s (got %s)",
                    lineno, path, type(record).__name__,
                )
                continue
            rows.append(record)
    return rows


def _success(row: Dict[str, Any]) -> bool:
    """Reviewer-facing 'success' (submitted AND right AND positive reward)."""
    return bool(
        row.get("submitted")
        and row.get("submitted_decision") == row.get("correct_decision")
        and _reward(row) > 0.0
    )


def _compute_aggregate(rows: Sequence[Dict[str, Any]]) -> Dict[str, float]:
    """Compute the headline metrics for one tag.

    Returns a dict keyed by ``_HEADLINE_METRICS``. With no rows every metric
    is 0.0 and ``n`` is 0. Missing or null numeric fields count as 0.0.
    """
    n = len(rows)
    if n == 0:
        # Built without the ``|`` dict operator so this also runs on Python < 3.9.
        empty: Dict[str, float] = {k: 0.0 for k in _HEADLINE_METRICS}
        empty["n"] = 0
        return empty
    rewards = [_reward(r) for r in rows]
    rewards_sorted = sorted(rewards)
    mid = n // 2
    if n % 2:
        median = rewards_sorted[mid]
    else:
        median = 0.5 * (rewards_sorted[mid - 1] + rewards_sorted[mid])
    return {
        "n": n,
        "mean_reward": sum(rewards) / n,
        "median_reward": median,
        "success_rate": sum(1 for r in rows if _success(r)) / n,
        "decision_accuracy": sum(_as_float(r.get("decision_accuracy")) for r in rows) / n,
        "evidence_coverage": sum(_as_float(r.get("evidence_coverage")) for r in rows) / n,
    }


def _delta(pre: Dict[str, float], post: Dict[str, float]) -> Dict[str, float]:
    """Return ``post - pre`` for every headline metric (0.0 when a value is unusable)."""
    out: Dict[str, float] = {}
    for k in _HEADLINE_METRICS:
        try:
            out[k] = float(post.get(k, 0.0)) - float(pre.get(k, 0.0))
        except (TypeError, ValueError):
            out[k] = 0.0
    return out


def _write_metrics_json(
    *,
    pre: Dict[str, float],
    post: Dict[str, float],
    delta: Dict[str, float],
    path: Path,
) -> None:
    """Write ``{"pre", "post", "delta"}`` as indented JSON to *path*."""
    payload = {"pre": pre, "post": post, "delta": delta}
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    logger.info("wrote %s", path)


def _write_components_csv(
    *,
    pre_rows: Sequence[Dict[str, Any]],
    post_rows: Sequence[Dict[str, Any]],
    path: Path,
) -> None:
    """One row per (tag, episode) with all 8 components + total.

    ``total`` is the sum of the listed components, which is reported next to
    (and may differ from) the row's own ``cumulative_reward``.
    """
    fields = ["tag", "episode", "scenario", "cumulative_reward", *_REWARD_COMPONENTS, "total"]
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fields)
        writer.writeheader()
        for tag_rows, tag in ((pre_rows, "pre"), (post_rows, "post")):
            for r in tag_rows:
                comps = _components(r)
                values = {c: _as_float(comps.get(c)) for c in _REWARD_COMPONENTS}
                row = {
                    "tag": tag,
                    "episode": r.get("episode"),
                    "scenario": r.get("scenario"),
                    "cumulative_reward": round(_reward(r), 6),
                    "total": round(sum(values.values()), 6),
                }
                for c, v in values.items():
                    row[c] = round(v, 6)
                writer.writerow(row)
    logger.info("wrote %s", path)


def _try_matplotlib():
    """Returns (plt, np) or (None, None) when matplotlib is unavailable."""
    try:
        import matplotlib  # type: ignore

        # Non-interactive backend: figures are only ever saved to disk.
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt  # type: ignore
    except Exception as exc:  # pragma: no cover - plotting best-effort
        logger.warning("matplotlib unavailable (%s); skipping plots", exc)
        return None, None
    try:
        import numpy as np  # type: ignore
    except Exception:
        np = None  # type: ignore
    return plt, np


def _write_summary_png(
    *,
    pre: Dict[str, float],
    post: Dict[str, float],
    path: Path,
) -> None:
    """Save a grouped bar chart of pre vs post headline metrics; no-op without matplotlib."""
    plt, np = _try_matplotlib()
    if plt is None:
        return
    metrics = [
        ("mean_reward", "Mean reward"),
        ("median_reward", "Median reward"),
        ("success_rate", "Success rate"),
        ("decision_accuracy", "Decision accuracy"),
        ("evidence_coverage", "Evidence coverage"),
    ]
    labels = [lbl for _, lbl in metrics]
    pre_vals = [float(pre.get(k, 0.0)) for k, _ in metrics]
    post_vals = [float(post.get(k, 0.0)) for k, _ in metrics]
    n = len(metrics)
    x = list(range(n))
    width = 0.36
    fig, ax = plt.subplots(figsize=(9, 5))
    bar_pre = ax.bar([xi - width / 2 for xi in x], pre_vals, width=width,
                     label=f"pre (n={int(pre.get('n', 0))})", color="#94a3b8")
    bar_post = ax.bar([xi + width / 2 for xi in x], post_vals, width=width,
                      label=f"post (n={int(post.get('n', 0))})", color="#1d4ed8")
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=14, ha="right")
    ax.set_ylabel("metric value")
    ax.set_title("DrugEnv before vs after — headline metrics")
    ax.grid(alpha=0.25, axis="y")
    ax.axhline(0, color="black", lw=0.8)
    # Value labels sit above positive bars and below negative ones.
    for bar_group in (bar_pre, bar_post):
        for rect in bar_group:
            h = rect.get_height()
            ax.text(rect.get_x() + rect.get_width() / 2, h, f"{h:.2f}",
                    ha="center", va="bottom" if h >= 0 else "top", fontsize=8)
    ax.legend(loc="best")
    fig.tight_layout()
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path, dpi=150)
    plt.close(fig)
    logger.info("wrote %s", path)


def _write_distribution_png(
    *,
    pre_rows: Sequence[Dict[str, Any]],
    post_rows: Sequence[Dict[str, Any]],
    path: Path,
) -> None:
    """Save overlapping histograms of cumulative reward, pre vs post.

    Both histograms use the same bin range so their bars line up and can be
    compared directly. No-op without matplotlib or when there are no rewards.
    """
    plt, _np = _try_matplotlib()
    if plt is None:
        return
    pre_rewards = [_reward(r) for r in pre_rows]
    post_rewards = [_reward(r) for r in post_rows]
    if not pre_rewards and not post_rewards:
        logger.warning("no rewards to plot — skipping reward_distribution.png")
        return
    fig, ax = plt.subplots(figsize=(9, 5))
    bins = 12
    # Shared range => identical bin edges for both tags.
    combined = pre_rewards + post_rewards
    bin_range = (min(combined), max(combined))
    if pre_rewards:
        ax.hist(pre_rewards, bins=bins, range=bin_range, color="#94a3b8", alpha=0.55,
                label=f"pre (n={len(pre_rewards)}, μ={sum(pre_rewards) / len(pre_rewards):.2f})")
    if post_rewards:
        ax.hist(post_rewards, bins=bins, range=bin_range, color="#1d4ed8", alpha=0.55,
                label=f"post (n={len(post_rewards)}, μ={sum(post_rewards) / len(post_rewards):.2f})")
    ax.set_xlabel("cumulative reward (per episode)")
    ax.set_ylabel("episode count")
    ax.set_title("DrugEnv per-episode reward distribution — pre vs post")
    ax.grid(alpha=0.25, axis="y")
    ax.legend(loc="best")
    fig.tight_layout()
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path, dpi=150)
    plt.close(fig)
    logger.info("wrote %s", path)


def _write_components_png(
    *,
    pre_rows: Sequence[Dict[str, Any]],
    post_rows: Sequence[Dict[str, Any]],
    path: Path,
) -> None:
    """Save a grouped bar chart of mean reward components, pre vs post."""
    plt, _np = _try_matplotlib()
    if plt is None:
        return

    def _means(rows: Sequence[Dict[str, Any]]) -> List[float]:
        # Mean of each component over all rows; absent components count as 0.0.
        if not rows:
            return [0.0] * len(_REWARD_COMPONENTS)
        out: List[float] = []
        for c in _REWARD_COMPONENTS:
            vals = [_as_float(_components(r).get(c)) for r in rows]
            out.append(sum(vals) / len(vals))
        return out

    pre_means = _means(pre_rows)
    post_means = _means(post_rows)
    n = len(_REWARD_COMPONENTS)
    x = list(range(n))
    width = 0.36
    fig, ax = plt.subplots(figsize=(11, 5.5))
    ax.bar([xi - width / 2 for xi in x], pre_means, width=width,
           color="#94a3b8", label=f"pre (n={len(pre_rows)})")
    ax.bar([xi + width / 2 for xi in x], post_means, width=width,
           color="#1d4ed8", label=f"post (n={len(post_rows)})")
    ax.set_xticks(x)
    ax.set_xticklabels(list(_REWARD_COMPONENTS), rotation=18, ha="right")
    ax.set_ylabel("per-episode component sum (mean)")
    ax.set_title("DrugEnv reward component breakdown — pre vs post")
    ax.grid(alpha=0.25, axis="y")
    ax.axhline(0, color="black", lw=0.8)
    ax.legend(loc="best")
    fig.tight_layout()
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path, dpi=150)
    plt.close(fig)
    logger.info("wrote %s", path)


def _format_episode_md(row: Dict[str, Any], rank_label: str) -> str:
    """Render one episode as a markdown section.

    The section has a header, a bullet summary, a per-step action/reward table
    and, when the row carries component totals, a component table. When the
    action and reward lists differ in length, the shorter one is padded with
    ``(no-op)`` / ``0.0``.
    """
    actions = row.get("action_sequence") or []
    rewards = row.get("step_rewards") or []
    lines = [
        f"### {rank_label} — episode {row.get('episode')} "
        f"(scenario={row.get('scenario')}, target={row.get('target_gene')})",
        "",
        f"- tag: `{row.get('tag')}`",
        f"- seed: `{row.get('seed')}`",
        f"- cumulative_reward: **{_reward(row):+.3f}**",
        f"- n_steps: {row.get('n_steps')} (invalid: {row.get('invalid_actions', 0)})",
        f"- submitted: `{row.get('submitted')}` "
        f"submitted_decision: `{row.get('submitted_decision')}` "
        f"correct_decision: `{row.get('correct_decision')}`",
        f"- decision_accuracy: {_as_float(row.get('decision_accuracy')):.3f} "
        f"evidence_coverage: {_as_float(row.get('evidence_coverage')):.3f}",
        "",
        "| step | action | reward |",
        "|-----:|:-------|-------:|",
    ]
    pad = max(len(actions), len(rewards))
    actions = list(actions) + ["(no-op)"] * (pad - len(actions))
    rewards = list(rewards) + [0.0] * (pad - len(rewards))
    for i, (a, r) in enumerate(zip(actions, rewards)):
        # Escape the action text so a "|" inside it cannot split the table row.
        lines.append(f"| {i} | `{_md_cell(a)}` | {_as_float(r):+.3f} |")
    comps = _components(row)
    if comps:
        lines.append("")
        lines.append("**Reward component totals**")
        lines.append("")
        lines.append("| component | total |")
        lines.append("|:----------|------:|")
        for c in _REWARD_COMPONENTS:
            lines.append(f"| `{c}` | {_as_float(comps.get(c)):+.3f} |")
    lines.append("")
    return "\n".join(lines)


def _select_samples(
    pre_rows: Sequence[Dict[str, Any]],
    post_rows: Sequence[Dict[str, Any]],
    n_samples: int,
) -> Sequence[List[Dict[str, Any]]]:
    """Return ``(worst_pre, best_post)``: lowest-reward pre and highest-reward post rows.

    A negative *n_samples* is treated as 0; a bare ``[:n]`` slice would
    otherwise return all but the last ``|n|`` episodes.
    """
    limit = max(0, int(n_samples))
    worst_pre = sorted(pre_rows, key=_reward)[:limit]
    best_post = sorted(post_rows, key=_reward, reverse=True)[:limit]
    return worst_pre, best_post


def _write_samples_md(
    *,
    pre_rows: Sequence[Dict[str, Any]],
    post_rows: Sequence[Dict[str, Any]],
    n_samples: int,
    path: Path,
) -> None:
    """Worst-N pre-train + best-N post-train, sorted by cumulative reward."""
    worst_pre, best_post = _select_samples(pre_rows, post_rows, n_samples)
    parts: List[str] = [
        "# Sample trajectories — DrugEnv before/after",
        "",
        "Generated by `training/summarize.py`. Worst pre-train episodes show "
        "what the warm-started model failed at; best post-train episodes show "
        "the trajectories GRPO actually reinforced.",
        "",
        f"## Worst {len(worst_pre)} pre-train episodes (lowest cumulative reward)",
        "",
    ]
    if not worst_pre:
        parts.append("_(no pre-train episodes recorded)_\n")
    for i, row in enumerate(worst_pre):
        parts.append(_format_episode_md(row, f"Worst-pre #{i + 1}"))
    parts.append(f"## Best {len(best_post)} post-train episodes (highest cumulative reward)")
    parts.append("")
    if not best_post:
        parts.append("_(no post-train episodes recorded)_\n")
    for i, row in enumerate(best_post):
        parts.append(_format_episode_md(row, f"Best-post #{i + 1}"))
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("\n".join(parts), encoding="utf-8")
    logger.info("wrote %s", path)


def main(argv: Optional[List[str]] = None) -> int:
    """CLI entry point: load both eval files, write every artifact, return exit code 0."""
    args = _parse_args(argv)
    evidence_dir = Path(args.evidence_dir)
    evidence_dir.mkdir(parents=True, exist_ok=True)

    pre_rows = _load_jsonl(evidence_dir / args.pre)
    post_rows = _load_jsonl(evidence_dir / args.post)
    logger.info("loaded pre=%d post=%d episodes", len(pre_rows), len(post_rows))

    pre_metrics = _compute_aggregate(pre_rows)
    post_metrics = _compute_aggregate(post_rows)
    delta_metrics = _delta(pre_metrics, post_metrics)

    _write_metrics_json(
        pre=pre_metrics,
        post=post_metrics,
        delta=delta_metrics,
        path=evidence_dir / args.metrics_out,
    )
    _write_components_csv(
        pre_rows=pre_rows,
        post_rows=post_rows,
        path=evidence_dir / args.components_csv,
    )
    _write_summary_png(
        pre=pre_metrics,
        post=post_metrics,
        path=evidence_dir / args.summary_png,
    )
    _write_distribution_png(
        pre_rows=pre_rows,
        post_rows=post_rows,
        path=evidence_dir / args.distribution_png,
    )
    _write_components_png(
        pre_rows=pre_rows,
        post_rows=post_rows,
        path=evidence_dir / args.components_png,
    )
    _write_samples_md(
        pre_rows=pre_rows,
        post_rows=post_rows,
        n_samples=int(args.n_samples),
        path=evidence_dir / args.samples_md,
    )

    logger.info(
        "summary: pre_mean=%.3f post_mean=%.3f Δmean=%+.3f "
        "pre_success=%.2f post_success=%.2f Δsuccess=%+.2f",
        pre_metrics.get("mean_reward", 0.0),
        post_metrics.get("mean_reward", 0.0),
        delta_metrics.get("mean_reward", 0.0),
        pre_metrics.get("success_rate", 0.0),
        post_metrics.get("success_rate", 0.0),
        delta_metrics.get("success_rate", 0.0),
    )
    return 0


if __name__ == "__main__":
    sys.exit(main())