Spaces:
Runtime error
Runtime error
"""Aggregate pre/post evaluator JSONL into a before/after evidence pack.

Reads ``evidence/pre_eval.jsonl`` + ``evidence/post_eval.jsonl`` (one
JSONL row per episode, written by ``training/evaluate.py``) and emits
the side-by-side artifacts the trainer dashboard surfaces:

* ``evidence/before_after_metrics.json`` — ``{pre, post, delta}``,
  with the headline metrics every reviewer asks for first.
* ``evidence/reward_components.csv`` — one row per (tag, episode) with
  all 8 component columns + total, ready to pivot in any spreadsheet.
* ``evidence/before_after_summary.png`` — grouped bar chart of pre vs
  post for each headline metric.
* ``evidence/reward_distribution.png`` — overlapping histograms of
  cumulative reward, pre vs post.
* ``evidence/reward_components.png`` — grouped bar chart of mean reward
  components, pre vs post.
* ``evidence/sample_trajectories.md`` — pretty-printed worst-3 pre +
  best-3 post episodes, with action sequences and per-step rewards.

CLI::

    python -m training.summarize \\
        --evidence_dir evidence \\
        --pre pre_eval.jsonl --post post_eval.jsonl
"""
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import json | |
| import logging | |
| import sys | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Sequence | |
# Logging is configured at import time so every "wrote <path>" message
# emitted by the writers below carries a timestamp and level.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("training.summarize")

# Mirrors ``training/evaluate.py::_REWARD_COMPONENTS`` (cannot import
# directly because this module must run with only stdlib + matplotlib
# in case the heavy ML deps are not installed at summarisation time).
# The tuple order is significant: it fixes the component column order in
# the CSV and the bar order in the component chart.
_REWARD_COMPONENTS = (
    "evidence_coverage",
    "decision_accuracy",
    "credit_efficiency",
    "reasoning_coherence",
    "novelty",
    "penalty",
    "shaping",
    "terminal",
)

# Keys of the per-tag aggregate produced by ``_compute_aggregate``;
# ``_delta`` iterates this tuple to build the post-minus-pre mapping.
_HEADLINE_METRICS = (
    "n",
    "mean_reward",
    "median_reward",
    "success_rate",
    "decision_accuracy",
    "evidence_coverage",
)
def _parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the summariser's command-line options.

    Args:
        argv: Argument list to parse. ``None`` lets argparse fall back to
            ``sys.argv[1:]``.

    Returns:
        Namespace carrying ``evidence_dir``, the two input filenames
        (``pre``, ``post``), the six output filenames and ``n_samples``.
        All filenames are interpreted relative to ``evidence_dir`` by
        ``main``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # The module docstring is a hand-formatted bullet list followed by a
        # CLI example; the default formatter would re-wrap it into a single
        # paragraph, so keep the description verbatim in ``--help``.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--evidence_dir", default="evidence")
    parser.add_argument("--pre", default="pre_eval.jsonl",
                        help="Filename of the pre-train JSONL inside evidence_dir.")
    parser.add_argument("--post", default="post_eval.jsonl",
                        help="Filename of the post-train JSONL inside evidence_dir.")
    parser.add_argument("--metrics_out", default="before_after_metrics.json")
    parser.add_argument("--components_csv", default="reward_components.csv")
    parser.add_argument("--summary_png", default="before_after_summary.png")
    parser.add_argument("--distribution_png", default="reward_distribution.png")
    parser.add_argument("--components_png", default="reward_components.png")
    parser.add_argument("--samples_md", default="sample_trajectories.md")
    parser.add_argument("--n_samples", type=int, default=3,
                        help="Number of worst-pre + best-post episodes to dump.")
    return parser.parse_args(argv)
def _load_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Load the JSON objects stored one-per-line in *path*.

    Loading is best-effort: nothing here raises for bad content.

    * If *path* is missing (or is not a regular file) a warning is logged
      and an empty list is returned.
    * Blank lines are ignored.
    * Lines that are not valid JSON are skipped with a warning.
    * Lines that are valid JSON but not an object (a list, number, string
      or ``null``) are skipped with a warning, because every consumer in
      this module calls ``row.get(...)`` on the returned rows.

    Args:
        path: Location of the JSONL file.

    Returns:
        The parsed rows, in file order; every element is a ``dict``.
    """
    if not path.is_file():
        logger.warning("eval file missing: %s", path)
        return []

    rows: List[Dict[str, Any]] = []
    with open(path, "r", encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            text = raw.strip()
            if not text:
                continue
            try:
                record = json.loads(text)
            except json.JSONDecodeError as exc:
                logger.warning("skipping unparseable JSONL line in %s: %s", path, exc)
                continue
            if not isinstance(record, dict):
                logger.warning(
                    "skipping non-object JSONL line %d in %s (got %s)",
                    lineno, path, type(record).__name__,
                )
                continue
            rows.append(record)
    return rows
def _success(row: Dict[str, Any]) -> bool:
    """Return whether an episode counts as a reviewer-facing success.

    All three conditions must hold: the episode was submitted, the
    submitted decision equals the correct decision, and the cumulative
    reward is strictly positive. The reward is only read once the first
    two checks have passed.
    """
    if not row.get("submitted"):
        return False
    decisions_match = row.get("submitted_decision") == row.get("correct_decision")
    if not decisions_match:
        return False
    return float(row.get("cumulative_reward", 0.0)) > 0.0
def _compute_aggregate(rows: Sequence[Dict[str, Any]]) -> Dict[str, float]:
    """Summarise one tag's episodes into the headline metrics.

    Args:
        rows: Episode rows for a single tag (pre or post).

    Returns:
        A dict keyed by the headline metric names. ``n`` is the episode
        count (an ``int``); the other entries are the mean and median of
        ``cumulative_reward``, the fraction of rows accepted by
        ``_success``, and the means of ``decision_accuracy`` and
        ``evidence_coverage``. Missing per-row values count as ``0.0``.
        With no rows every metric is ``0.0`` and ``n`` is ``0``.
    """
    n = len(rows)
    if n == 0:
        # Zero-fill every headline key, then pin "n" to an integer zero.
        return {**{name: 0.0 for name in _HEADLINE_METRICS}, "n": 0}

    rewards = [float(row.get("cumulative_reward", 0.0)) for row in rows]
    ordered = sorted(rewards)
    mid = n // 2
    if n % 2:
        median = ordered[mid]
    else:
        # Even count: average the two central values.
        median = 0.5 * (ordered[mid - 1] + ordered[mid])

    n_success = sum(1 for row in rows if _success(row))
    accuracy_sum = sum(float(row.get("decision_accuracy", 0.0)) for row in rows)
    coverage_sum = sum(float(row.get("evidence_coverage", 0.0)) for row in rows)

    return {
        "n": n,
        "mean_reward": sum(rewards) / n,
        "median_reward": median,
        "success_rate": n_success / n,
        "decision_accuracy": accuracy_sum / n,
        "evidence_coverage": coverage_sum / n,
    }
def _delta(pre: Dict[str, float], post: Dict[str, float]) -> Dict[str, float]:
    """Return ``post - pre`` for every headline metric.

    A metric missing from either side is treated as ``0.0``. If either
    value cannot be converted to ``float`` the delta for that metric is
    reported as ``0.0`` instead of raising.
    """

    def _difference(metric: str) -> float:
        try:
            return float(post.get(metric, 0.0)) - float(pre.get(metric, 0.0))
        except (TypeError, ValueError):
            return 0.0

    return {metric: _difference(metric) for metric in _HEADLINE_METRICS}
def _write_metrics_json(
    *,
    pre: Dict[str, float],
    post: Dict[str, float],
    delta: Dict[str, float],
    path: Path,
) -> None:
    """Write the ``{"pre", "post", "delta"}`` metrics document to *path*.

    The parent directory is created if needed and the JSON is written
    UTF-8 encoded with two-space indentation.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    document = json.dumps({"pre": pre, "post": post, "delta": delta}, indent=2)
    path.write_text(document, encoding="utf-8")
    logger.info("wrote %s", path)
def _write_components_csv(
    *,
    pre_rows: Sequence[Dict[str, Any]],
    post_rows: Sequence[Dict[str, Any]],
    path: Path,
) -> None:
    """Write one CSV row per (tag, episode) with every reward component.

    Columns are ``tag``, ``episode``, ``scenario``, ``cumulative_reward``,
    one column per reward component, and ``total`` (the sum of the
    component values). Pre rows are written before post rows. Numeric
    cells are rounded to six decimals; a component missing from a row's
    ``reward_components_total`` is written as ``0.0``.
    """
    columns = ["tag", "episode", "scenario", "cumulative_reward"]
    columns.extend(_REWARD_COMPONENTS)
    columns.append("total")

    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", newline="", encoding="utf-8") as handle:
        out = csv.DictWriter(handle, fieldnames=columns)
        out.writeheader()
        for tag, episodes in (("pre", pre_rows), ("post", post_rows)):
            for episode in episodes:
                totals = episode.get("reward_components_total") or {}
                record = {
                    "tag": tag,
                    "episode": episode.get("episode"),
                    "scenario": episode.get("scenario"),
                    "cumulative_reward": round(float(episode.get("cumulative_reward", 0.0)), 6),
                }
                values = [float(totals.get(name, 0.0)) for name in _REWARD_COMPONENTS]
                # The total is summed from the unrounded values, then rounded.
                record["total"] = round(sum(values), 6)
                for name, value in zip(_REWARD_COMPONENTS, values):
                    record[name] = round(value, 6)
                out.writerow(record)
    logger.info("wrote %s", path)
def _try_matplotlib():
    """Import the plotting stack on demand.

    Returns:
        ``(pyplot, numpy)`` when matplotlib imports cleanly; ``numpy`` is
        ``None`` if only numpy is missing. ``(None, None)`` when matplotlib
        itself cannot be imported, in which case a warning is logged so
        callers can silently skip plotting.
    """
    try:
        import matplotlib  # type: ignore

        # The backend must be chosen before pyplot is imported.
        matplotlib.use("Agg")
        from matplotlib import pyplot  # type: ignore
    except Exception as err:  # pragma: no cover - plotting best-effort
        logger.warning("matplotlib unavailable (%s); skipping plots", err)
        return None, None

    try:
        import numpy  # type: ignore
    except Exception:
        numpy = None  # type: ignore
    return pyplot, numpy
def _write_summary_png(
    *,
    pre: Dict[str, float],
    post: Dict[str, float],
    path: Path,
) -> None:
    """Render the grouped pre-vs-post bar chart of headline metrics.

    Plots mean reward, median reward, success rate, decision accuracy and
    evidence coverage as side-by-side bars, labels each bar with its value
    to two decimals, and saves the figure to *path* at 150 dpi. Returns
    without writing anything when ``_try_matplotlib`` yields no pyplot.
    """
    plt, _ = _try_matplotlib()
    if plt is None:
        return

    keys = (
        "mean_reward",
        "median_reward",
        "success_rate",
        "decision_accuracy",
        "evidence_coverage",
    )
    tick_labels = [
        "Mean reward",
        "Median reward",
        "Success rate",
        "Decision accuracy",
        "Evidence coverage",
    ]
    pre_heights = [float(pre.get(key, 0.0)) for key in keys]
    post_heights = [float(post.get(key, 0.0)) for key in keys]

    slots = list(range(len(keys)))
    bar_width = 0.36
    offset = bar_width / 2  # each pair straddles its tick

    fig, ax = plt.subplots(figsize=(9, 5))
    pre_bars = ax.bar(
        [slot - offset for slot in slots], pre_heights, width=bar_width,
        label=f"pre (n={int(pre.get('n', 0))})", color="#94a3b8",
    )
    post_bars = ax.bar(
        [slot + offset for slot in slots], post_heights, width=bar_width,
        label=f"post (n={int(post.get('n', 0))})", color="#1d4ed8",
    )
    ax.set_xticks(slots)
    ax.set_xticklabels(tick_labels, rotation=14, ha="right")
    ax.set_ylabel("metric value")
    ax.set_title("DrugEnv before vs after — headline metrics")
    ax.grid(alpha=0.25, axis="y")
    ax.axhline(0, color="black", lw=0.8)

    for bar in (*pre_bars, *post_bars):
        height = bar.get_height()
        # Negative bars get their label below the bar end.
        anchor = "bottom" if height >= 0 else "top"
        ax.text(bar.get_x() + bar.get_width() / 2, height, f"{height:.2f}",
                ha="center", va=anchor, fontsize=8)

    ax.legend(loc="best")
    fig.tight_layout()
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path, dpi=150)
    plt.close(fig)
    logger.info("wrote %s", path)
def _write_distribution_png(
    *,
    pre_rows: Sequence[Dict[str, Any]],
    post_rows: Sequence[Dict[str, Any]],
    path: Path,
    bins: int = 12,
) -> None:
    """Render overlapping histograms of per-episode cumulative reward.

    Both histograms are drawn over the same value range, so their bin
    edges coincide and the bars can be compared directly. A tag with no
    episodes is omitted from the plot. Nothing is written when matplotlib
    is unavailable or when neither tag has any episodes.

    Args:
        pre_rows: Pre-train episode rows.
        post_rows: Post-train episode rows.
        path: Destination PNG; its parent directory is created if needed.
        bins: Number of equal-width bins spanning the combined range.
    """
    plt, _ = _try_matplotlib()
    if plt is None:
        return

    pre_rewards = [float(r.get("cumulative_reward", 0.0)) for r in pre_rows]
    post_rewards = [float(r.get("cumulative_reward", 0.0)) for r in post_rows]
    all_rewards = pre_rewards + post_rewards
    if not all_rewards:
        # Report the file actually being skipped; the name is configurable.
        logger.warning("no rewards to plot — skipping %s", path.name)
        return

    # Without an explicit range each hist() call derives edges from its own
    # data, giving the two series different bin widths.
    shared_range = (min(all_rewards), max(all_rewards))

    fig, ax = plt.subplots(figsize=(9, 5))
    series = (
        ("pre", pre_rewards, "#94a3b8"),
        ("post", post_rewards, "#1d4ed8"),
    )
    for tag, rewards, color in series:
        if not rewards:
            continue
        mean = sum(rewards) / len(rewards)
        ax.hist(rewards, bins=bins, range=shared_range, color=color, alpha=0.55,
                label=f"{tag} (n={len(rewards)}, μ={mean:.2f})")
    ax.set_xlabel("cumulative reward (per episode)")
    ax.set_ylabel("episode count")
    ax.set_title("DrugEnv per-episode reward distribution — pre vs post")
    ax.grid(alpha=0.25, axis="y")
    ax.legend(loc="best")
    fig.tight_layout()
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path, dpi=150)
    plt.close(fig)
    logger.info("wrote %s", path)
def _write_components_png(
    *,
    pre_rows: Sequence[Dict[str, Any]],
    post_rows: Sequence[Dict[str, Any]],
    path: Path,
) -> None:
    """Render the grouped bar chart of mean reward components, pre vs post.

    For each reward component the bar height is the mean, over a tag's
    episodes, of that component's entry in ``reward_components_total``
    (missing entries count as ``0.0``; a tag with no episodes plots all
    zeros). The figure is saved to *path* at 150 dpi. Returns without
    writing anything when matplotlib is unavailable.
    """
    plt, _ = _try_matplotlib()
    if plt is None:
        return

    def _component_means(episodes: Sequence[Dict[str, Any]]) -> List[float]:
        count = len(episodes)
        if count == 0:
            return [0.0] * len(_REWARD_COMPONENTS)
        return [
            sum(
                float((episode.get("reward_components_total") or {}).get(name, 0.0))
                for episode in episodes
            ) / count
            for name in _REWARD_COMPONENTS
        ]

    pre_heights = _component_means(pre_rows)
    post_heights = _component_means(post_rows)

    slots = list(range(len(_REWARD_COMPONENTS)))
    bar_width = 0.36
    offset = bar_width / 2  # each pair straddles its tick

    fig, ax = plt.subplots(figsize=(11, 5.5))
    ax.bar([slot - offset for slot in slots], pre_heights, width=bar_width,
           color="#94a3b8", label=f"pre (n={len(pre_rows)})")
    ax.bar([slot + offset for slot in slots], post_heights, width=bar_width,
           color="#1d4ed8", label=f"post (n={len(post_rows)})")
    ax.set_xticks(slots)
    ax.set_xticklabels(list(_REWARD_COMPONENTS), rotation=18, ha="right")
    ax.set_ylabel("per-episode component sum (mean)")
    ax.set_title("DrugEnv reward component breakdown — pre vs post")
    ax.grid(alpha=0.25, axis="y")
    ax.axhline(0, color="black", lw=0.8)
    ax.legend(loc="best")
    fig.tight_layout()
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(path, dpi=150)
    plt.close(fig)
    logger.info("wrote %s", path)
def _format_episode_md(row: Dict[str, Any], rank_label: str) -> str:
    """Render one episode row as a Markdown section.

    The section has a level-3 heading, a bullet list of episode facts, a
    step/action/reward table and, when the row carries a non-empty
    ``reward_components_total``, a table of component totals. If the
    action and reward lists differ in length the shorter one is padded
    (``(no-op)`` for actions, ``0.0`` for rewards). The returned text
    always ends with a newline.

    Args:
        row: Episode record as loaded from the evaluator JSONL.
        rank_label: Text placed at the start of the heading.
    """
    heading = (
        f"### {rank_label} — episode {row.get('episode')} "
        f"(scenario={row.get('scenario')}, target={row.get('target_gene')})"
    )
    decision_line = (
        f"- submitted: `{row.get('submitted')}` "
        f"submitted_decision: `{row.get('submitted_decision')}` "
        f"correct_decision: `{row.get('correct_decision')}`"
    )
    score_line = (
        f"- decision_accuracy: {float(row.get('decision_accuracy', 0.0)):.3f} "
        f"evidence_coverage: {float(row.get('evidence_coverage', 0.0)):.3f}"
    )
    lines = [
        heading,
        "",
        f"- tag: `{row.get('tag')}`",
        f"- seed: `{row.get('seed')}`",
        f"- cumulative_reward: **{float(row.get('cumulative_reward', 0.0)):+.3f}**",
        f"- n_steps: {row.get('n_steps')} (invalid: {row.get('invalid_actions', 0)})",
        decision_line,
        score_line,
        "",
        "| step | action | reward |",
        "|-----:|:-------|-------:|",
    ]

    # Pad the shorter of the two sequences so every step gets a table row.
    actions = list(row.get("action_sequence") or [])
    rewards = list(row.get("step_rewards") or [])
    n_steps = max(len(actions), len(rewards))
    actions.extend(["(no-op)"] * (n_steps - len(actions)))
    rewards.extend([0.0] * (n_steps - len(rewards)))
    lines.extend(
        f"| {step} | `{action}` | {float(reward):+.3f} |"
        for step, (action, reward) in enumerate(zip(actions, rewards))
    )

    totals = row.get("reward_components_total") or {}
    if totals:
        lines += [
            "",
            "**Reward component totals**",
            "",
            "| component | total |",
            "|:----------|------:|",
        ]
        lines.extend(
            f"| `{name}` | {float(totals.get(name, 0.0)):+.3f} |"
            for name in _REWARD_COMPONENTS
        )

    # Trailing empty entry -> the joined text ends with a newline.
    lines.append("")
    return "\n".join(lines)
def _write_samples_md(
    *,
    pre_rows: Sequence[Dict[str, Any]],
    post_rows: Sequence[Dict[str, Any]],
    n_samples: int,
    path: Path,
) -> None:
    """Write the worst-N pre-train and best-N post-train episodes as Markdown.

    Episodes are ranked by ``cumulative_reward`` (missing values count as
    ``0.0``): pre rows ascending, post rows descending. Each selected
    episode is rendered with ``_format_episode_md``; a placeholder line is
    written for a section with no episodes.

    Args:
        pre_rows: Pre-train episode rows.
        post_rows: Post-train episode rows.
        n_samples: Maximum episodes per section. Values below zero are
            treated as zero.
        path: Destination Markdown file; its parent directory is created
            if needed.
    """
    # A negative count used as a slice bound would select "all but the
    # last |n|" rows rather than none, so clamp it first.
    limit = max(0, n_samples)

    def _reward(row: Dict[str, Any]) -> float:
        return float(row.get("cumulative_reward", 0.0))

    worst_pre = sorted(pre_rows, key=_reward)[:limit]
    best_post = sorted(post_rows, key=_reward, reverse=True)[:limit]

    parts: List[str] = [
        "# Sample trajectories — DrugEnv before/after",
        "",
        "Generated by `training/summarize.py`. Worst pre-train episodes show "
        "what the warm-started model failed at; best post-train episodes show "
        "the trajectories GRPO actually reinforced.",
        "",
        f"## Worst {len(worst_pre)} pre-train episodes (lowest cumulative reward)",
        "",
    ]
    if not worst_pre:
        parts.append("_(no pre-train episodes recorded)_\n")
    parts.extend(
        _format_episode_md(row, f"Worst-pre #{rank}")
        for rank, row in enumerate(worst_pre, start=1)
    )

    parts.append(f"## Best {len(best_post)} post-train episodes (highest cumulative reward)")
    parts.append("")
    if not best_post:
        parts.append("_(no post-train episodes recorded)_\n")
    parts.extend(
        _format_episode_md(row, f"Best-post #{rank}")
        for rank, row in enumerate(best_post, start=1)
    )

    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("\n".join(parts), encoding="utf-8")
    logger.info("wrote %s", path)
def main(argv: Optional[List[str]] = None) -> int:
    """Build the full before/after evidence pack and return the exit code.

    Loads the pre and post JSONL files from the evidence directory,
    computes the per-tag aggregates and their delta, then writes the
    metrics JSON, the component CSV, the three charts and the sample
    trajectories. A one-line summary is logged at the end.

    Args:
        argv: Command-line arguments; ``None`` defers to ``sys.argv``.

    Returns:
        Always ``0``.
    """
    args = _parse_args(argv)
    evidence_dir = Path(args.evidence_dir)
    evidence_dir.mkdir(parents=True, exist_ok=True)

    def _in_evidence(filename: str) -> Path:
        # Every input and output filename is relative to the evidence directory.
        return evidence_dir / filename

    pre_rows = _load_jsonl(_in_evidence(args.pre))
    post_rows = _load_jsonl(_in_evidence(args.post))
    logger.info("loaded pre=%d post=%d episodes", len(pre_rows), len(post_rows))

    pre_metrics = _compute_aggregate(pre_rows)
    post_metrics = _compute_aggregate(post_rows)
    delta_metrics = _delta(pre_metrics, post_metrics)

    _write_metrics_json(
        pre=pre_metrics,
        post=post_metrics,
        delta=delta_metrics,
        path=_in_evidence(args.metrics_out),
    )
    _write_components_csv(
        pre_rows=pre_rows,
        post_rows=post_rows,
        path=_in_evidence(args.components_csv),
    )
    _write_summary_png(
        pre=pre_metrics,
        post=post_metrics,
        path=_in_evidence(args.summary_png),
    )
    _write_distribution_png(
        pre_rows=pre_rows,
        post_rows=post_rows,
        path=_in_evidence(args.distribution_png),
    )
    _write_components_png(
        pre_rows=pre_rows,
        post_rows=post_rows,
        path=_in_evidence(args.components_png),
    )
    _write_samples_md(
        pre_rows=pre_rows,
        post_rows=post_rows,
        n_samples=int(args.n_samples),
        path=_in_evidence(args.samples_md),
    )

    summary_values = [
        metrics.get(key, 0.0)
        for key in ("mean_reward", "success_rate")
        for metrics in (pre_metrics, post_metrics, delta_metrics)
    ]
    logger.info(
        "summary: pre_mean=%.3f post_mean=%.3f Δmean=%+.3f "
        "pre_success=%.2f post_success=%.2f Δsuccess=%+.2f",
        *summary_values,
    )
    return 0


if __name__ == "__main__":
    sys.exit(main())