# drugenv-trainer / training/summarize.py
# Uploaded by anugrahteesdollar
# Commit 18adb35 (verified): add: pre/post eval + summarize + bumped GRPO config (training/summarize.py)
"""Aggregate pre/post evaluator JSONL into a before/after evidence pack.
Reads ``evidence/pre_eval.jsonl`` + ``evidence/post_eval.jsonl`` (one
JSONL row per episode, written by ``training/evaluate.py``) and emits
the side-by-side artifacts the trainer dashboard surfaces:
* ``evidence/before_after_metrics.json`` — ``{pre, post, delta}``,
with the headline metrics every reviewer asks for first.
* ``evidence/reward_components.csv`` — one row per (tag, episode) with
all 8 component columns + total, ready to pivot in any spreadsheet.
* ``evidence/before_after_summary.png`` — grouped bar chart of pre vs
post for each headline metric.
* ``evidence/reward_distribution.png`` — overlapping histograms of
cumulative reward, pre vs post.
* ``evidence/reward_components.png`` — grouped bar chart of mean reward
components, pre vs post.
* ``evidence/sample_trajectories.md`` — pretty-printed worst-3 pre +
best-3 post episodes, with action sequences and per-step rewards.
CLI::
python -m training.summarize \\
--evidence_dir evidence \\
--pre pre_eval.jsonl --post post_eval.jsonl
"""
from __future__ import annotations
import argparse
import csv
import json
import logging
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence
# Logging is configured at import time so every message emitted by this
# module carries a timestamp and a level.
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("training.summarize")

# Mirrors ``training/evaluate.py::_REWARD_COMPONENTS`` (cannot import
# directly because this module must run with only stdlib + matplotlib
# in case the heavy ML deps are not installed at summarisation time).
# The tuple order is reused as the column order of the components CSV and
# as the bar order of the components chart.
_REWARD_COMPONENTS = (
    "evidence_coverage",
    "decision_accuracy",
    "credit_efficiency",
    "reasoning_coherence",
    "novelty",
    "penalty",
    "shaping",
    "terminal",
)

# Keys of the per-tag aggregate dict: ``_compute_aggregate`` produces them
# and ``_delta`` computes ``post - pre`` for each one.
_HEADLINE_METRICS = (
    "n",
    "mean_reward",
    "median_reward",
    "success_rate",
    "decision_accuracy",
    "evidence_coverage",
)
def _parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the command-line parser and parse ``argv``.

    Args:
        argv: Argument strings to parse; ``None`` lets argparse fall back
            to the process arguments.

    Returns:
        Namespace holding ``evidence_dir``, the two input filenames
        (``pre`` / ``post``), the six output filenames and ``n_samples``.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--evidence_dir", default="evidence")

    # Input JSONL files (their help text says they live inside evidence_dir).
    input_flags = (
        ("--pre", "pre_eval.jsonl",
         "Filename of the pre-train JSONL inside evidence_dir."),
        ("--post", "post_eval.jsonl",
         "Filename of the post-train JSONL inside evidence_dir."),
    )
    for flag, filename, help_text in input_flags:
        cli.add_argument(flag, default=filename, help=help_text)

    # Output artifact filenames; registration order is kept so ``--help``
    # lists them in the same sequence as before.
    output_flags = (
        ("--metrics_out", "before_after_metrics.json"),
        ("--components_csv", "reward_components.csv"),
        ("--summary_png", "before_after_summary.png"),
        ("--distribution_png", "reward_distribution.png"),
        ("--components_png", "reward_components.png"),
        ("--samples_md", "sample_trajectories.md"),
    )
    for flag, filename in output_flags:
        cli.add_argument(flag, default=filename)

    cli.add_argument(
        "--n_samples",
        type=int,
        default=3,
        help="Number of worst-pre + best-post episodes to dump.",
    )
    return cli.parse_args(argv)
def _load_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Load one evaluator JSONL file into a list of episode dicts.

    Loading is best-effort: a missing file yields an empty list, and any
    line that is blank, not valid JSON, or valid JSON but not an object is
    skipped (the latter two with a warning). Only dict rows are returned
    because every consumer in this module calls ``.get`` on them.

    Args:
        path: Location of the JSONL file.

    Returns:
        The parsed JSON objects, in file order.
    """
    if not path.exists():
        logger.warning("eval file missing: %s", path)
        return []

    rows: List[Dict[str, Any]] = []
    with open(path, "r", encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            text = raw.strip()
            if not text:
                continue
            try:
                record = json.loads(text)
            except json.JSONDecodeError as exc:
                logger.warning("skipping unparseable JSONL line in %s: %s", path, exc)
                continue
            # ``null``, numbers, strings and arrays parse fine but are not
            # episode rows; letting them through would crash later on ``.get``.
            if not isinstance(record, dict):
                logger.warning(
                    "skipping non-object JSONL line %d in %s (got %s)",
                    lineno, path, type(record).__name__,
                )
                continue
            rows.append(record)
    return rows
def _success(row: Dict[str, Any]) -> bool:
    """Return True when an episode counts as a reviewer-facing success.

    All three conditions must hold: the episode was submitted, the
    submitted decision equals the correct decision, and the cumulative
    reward is strictly positive.
    """
    if not row.get("submitted"):
        return False
    if row.get("submitted_decision") != row.get("correct_decision"):
        return False
    # Checked last so the reward is only converted for submitted, correct rows.
    return float(row.get("cumulative_reward", 0.0)) > 0.0
def _compute_aggregate(rows: Sequence[Dict[str, Any]]) -> Dict[str, float]:
    """Compute the headline metrics for one tag's episodes.

    Args:
        rows: Episode dicts for a single tag (pre or post).

    Returns:
        Dict keyed by the headline metric names: ``n`` (episode count),
        mean and median of ``cumulative_reward``, the fraction of rows for
        which ``_success`` is true, and the means of the per-row
        ``decision_accuracy`` and ``evidence_coverage`` fields. With no
        rows every metric is zero.
    """
    count = len(rows)
    if not count:
        # Same keys as the populated case; ``n`` stays an int zero.
        zeros: Dict[str, float] = dict.fromkeys(_HEADLINE_METRICS, 0.0)
        zeros["n"] = 0
        return zeros

    def _field_mean(field: str) -> float:
        # Rows lacking the field contribute 0.0 to the mean.
        return sum(float(row.get(field, 0.0)) for row in rows) / count

    rewards = [float(row.get("cumulative_reward", 0.0)) for row in rows]
    ordered = sorted(rewards)
    middle, is_odd = divmod(count, 2)
    if is_odd:
        median = ordered[middle]
    else:
        # Even count: average the two central values.
        median = 0.5 * (ordered[middle - 1] + ordered[middle])

    successes = len([row for row in rows if _success(row)])
    return {
        "n": count,
        "mean_reward": sum(rewards) / count,
        "median_reward": median,
        "success_rate": successes / count,
        "decision_accuracy": _field_mean("decision_accuracy"),
        "evidence_coverage": _field_mean("evidence_coverage"),
    }
def _delta(pre: Dict[str, float], post: Dict[str, float]) -> Dict[str, float]:
    """Return ``post - pre`` for every headline metric.

    A metric missing from either side is treated as 0.0, and a value that
    cannot be converted to ``float`` makes that metric's delta 0.0.
    """

    def _difference(metric: str) -> float:
        try:
            return float(post.get(metric, 0.0)) - float(pre.get(metric, 0.0))
        except (TypeError, ValueError):
            # Non-numeric value on either side: report "no change".
            return 0.0

    return {metric: _difference(metric) for metric in _HEADLINE_METRICS}
def _write_metrics_json(
    *,
    pre: Dict[str, float],
    post: Dict[str, float],
    delta: Dict[str, float],
    path: Path,
) -> None:
    """Write the ``{pre, post, delta}`` metrics document as indented JSON.

    The parent directory of ``path`` is created when missing.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    document = json.dumps({"pre": pre, "post": post, "delta": delta}, indent=2)
    path.write_text(document, encoding="utf-8")
    logger.info("wrote %s", path)
def _write_components_csv(
    *,
    pre_rows: Sequence[Dict[str, Any]],
    post_rows: Sequence[Dict[str, Any]],
    path: Path,
) -> None:
    """Write one CSV row per (tag, episode) with every component and a total.

    Pre-train rows are written first (tag ``pre``), then post-train rows
    (tag ``post``). Components missing from an episode count as 0.0 and all
    numeric cells are rounded to six decimals.
    """
    header = ["tag", "episode", "scenario", "cumulative_reward", *_REWARD_COMPONENTS, "total"]

    def _record(tag: str, episode: Dict[str, Any]) -> Dict[str, Any]:
        record: Dict[str, Any] = {
            "tag": tag,
            "episode": episode.get("episode"),
            "scenario": episode.get("scenario"),
            "cumulative_reward": round(float(episode.get("cumulative_reward", 0.0)), 6),
        }
        totals = episode.get("reward_components_total") or {}
        values = [float(totals.get(name, 0.0)) for name in _REWARD_COMPONENTS]
        # The total is taken over the unrounded values, then rounded once.
        record["total"] = round(sum(values), 6)
        record.update(
            (name, round(value, 6)) for name, value in zip(_REWARD_COMPONENTS, values)
        )
        return record

    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", newline="", encoding="utf-8") as handle:
        out = csv.DictWriter(handle, fieldnames=header)
        out.writeheader()
        for tag, episodes in (("pre", pre_rows), ("post", post_rows)):
            out.writerows(_record(tag, episode) for episode in episodes)
    logger.info("wrote %s", path)
def _try_matplotlib():
    """Import the plotting stack, tolerating its absence.

    Returns:
        ``(pyplot, numpy)`` when both import, ``(pyplot, None)`` when only
        matplotlib imports, and ``(None, None)`` (after a warning) when
        matplotlib itself cannot be imported.
    """
    try:
        import matplotlib  # type: ignore

        # Select the Agg backend before pyplot is imported.
        matplotlib.use("Agg")
        from matplotlib import pyplot  # type: ignore
    except Exception as err:  # pragma: no cover - plotting best-effort
        logger.warning("matplotlib unavailable (%s); skipping plots", err)
        return None, None

    try:
        import numpy  # type: ignore
    except Exception:
        # numpy is optional here; report its absence as None.
        return pyplot, None
    return pyplot, numpy
def _write_summary_png(
    *,
    pre: Dict[str, float],
    post: Dict[str, float],
    path: Path,
) -> None:
    """Render the grouped pre-vs-post bar chart of the headline metrics.

    Does nothing when matplotlib cannot be imported. Each bar is annotated
    with its value to two decimals, and the legend shows the episode count
    (``n``) of each side. The parent directory of ``path`` is created when
    missing.

    Args:
        pre: Aggregate metrics of the pre-train run.
        post: Aggregate metrics of the post-train run.
        path: Destination PNG file.
    """
    plt, _np = _try_matplotlib()
    if plt is None:
        return

    metrics = [
        ("mean_reward", "Mean reward"),
        ("median_reward", "Median reward"),
        ("success_rate", "Success rate"),
        ("decision_accuracy", "Decision accuracy"),
        ("evidence_coverage", "Evidence coverage"),
    ]
    labels = [label for _, label in metrics]
    pre_vals = [float(pre.get(key, 0.0)) for key, _ in metrics]
    post_vals = [float(post.get(key, 0.0)) for key, _ in metrics]
    x = list(range(len(metrics)))
    width = 0.36

    fig, ax = plt.subplots(figsize=(9, 5))
    try:
        bar_pre = ax.bar([xi - width / 2 for xi in x], pre_vals, width=width,
                         label=f"pre (n={int(pre.get('n', 0))})", color="#94a3b8")
        bar_post = ax.bar([xi + width / 2 for xi in x], post_vals, width=width,
                          label=f"post (n={int(post.get('n', 0))})", color="#1d4ed8")
        ax.set_xticks(x)
        ax.set_xticklabels(labels, rotation=14, ha="right")
        ax.set_ylabel("metric value")
        ax.set_title("DrugEnv before vs after — headline metrics")
        ax.grid(alpha=0.25, axis="y")
        ax.axhline(0, color="black", lw=0.8)
        for bar_group in (bar_pre, bar_post):
            for rect in bar_group:
                h = rect.get_height()
                # Label above positive bars and below negative ones.
                ax.text(rect.get_x() + rect.get_width() / 2, h, f"{h:.2f}",
                        ha="center", va="bottom" if h >= 0 else "top", fontsize=8)
        ax.legend(loc="best")
        fig.tight_layout()
        path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(path, dpi=150)
    finally:
        # Release the figure even when drawing or saving fails; pyplot
        # otherwise keeps it registered for the life of the process.
        plt.close(fig)
    logger.info("wrote %s", path)
def _write_distribution_png(
    *,
    pre_rows: Sequence[Dict[str, Any]],
    post_rows: Sequence[Dict[str, Any]],
    path: Path,
) -> None:
    """Render overlapping histograms of per-episode cumulative reward.

    Does nothing when matplotlib cannot be imported, and skips the plot
    with a warning when neither side has any episode. Both histograms use
    the same bin range so their bars line up and can be compared directly.
    The parent directory of ``path`` is created when missing.

    Args:
        pre_rows: Pre-train episode dicts.
        post_rows: Post-train episode dicts.
        path: Destination PNG file.
    """
    plt, _np = _try_matplotlib()
    if plt is None:
        return

    pre_rewards = [float(r.get("cumulative_reward", 0.0)) for r in pre_rows]
    post_rewards = [float(r.get("cumulative_reward", 0.0)) for r in post_rows]
    if not pre_rewards and not post_rewards:
        logger.warning("no rewards to plot — skipping reward_distribution.png")
        return

    bins = 12
    # Without a common range each ``hist`` call derives its own edges from
    # its own data, so the two series would be binned differently.
    combined = pre_rewards + post_rewards
    shared_range = (min(combined), max(combined))

    fig, ax = plt.subplots(figsize=(9, 5))
    try:
        if pre_rewards:
            ax.hist(pre_rewards, bins=bins, range=shared_range, color="#94a3b8", alpha=0.55,
                    label=f"pre (n={len(pre_rewards)}, μ={sum(pre_rewards) / len(pre_rewards):.2f})")
        if post_rewards:
            ax.hist(post_rewards, bins=bins, range=shared_range, color="#1d4ed8", alpha=0.55,
                    label=f"post (n={len(post_rewards)}, μ={sum(post_rewards) / len(post_rewards):.2f})")
        ax.set_xlabel("cumulative reward (per episode)")
        ax.set_ylabel("episode count")
        ax.set_title("DrugEnv per-episode reward distribution — pre vs post")
        ax.grid(alpha=0.25, axis="y")
        ax.legend(loc="best")
        fig.tight_layout()
        path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(path, dpi=150)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
    logger.info("wrote %s", path)
def _write_components_png(
    *,
    pre_rows: Sequence[Dict[str, Any]],
    post_rows: Sequence[Dict[str, Any]],
    path: Path,
) -> None:
    """Render the grouped bar chart of mean reward components, pre vs post.

    Does nothing when matplotlib cannot be imported. For each component the
    bar height is the mean of ``reward_components_total[component]`` over
    the episodes of that side; a missing component counts as 0.0 and a side
    with no episodes plots as all zeros. The parent directory of ``path``
    is created when missing.

    Args:
        pre_rows: Pre-train episode dicts.
        post_rows: Post-train episode dicts.
        path: Destination PNG file.
    """
    plt, _np = _try_matplotlib()
    if plt is None:
        return

    def _means(rows: Sequence[Dict[str, Any]]) -> List[float]:
        if not rows:
            # Avoid dividing by zero when one side has no episodes.
            return [0.0] * len(_REWARD_COMPONENTS)
        out: List[float] = []
        for c in _REWARD_COMPONENTS:
            vals = [float((r.get("reward_components_total") or {}).get(c, 0.0)) for r in rows]
            out.append(sum(vals) / len(vals))
        return out

    pre_means = _means(pre_rows)
    post_means = _means(post_rows)
    x = list(range(len(_REWARD_COMPONENTS)))
    width = 0.36

    fig, ax = plt.subplots(figsize=(11, 5.5))
    try:
        ax.bar([xi - width / 2 for xi in x], pre_means, width=width,
               color="#94a3b8", label=f"pre (n={len(pre_rows)})")
        ax.bar([xi + width / 2 for xi in x], post_means, width=width,
               color="#1d4ed8", label=f"post (n={len(post_rows)})")
        ax.set_xticks(x)
        ax.set_xticklabels(list(_REWARD_COMPONENTS), rotation=18, ha="right")
        ax.set_ylabel("per-episode component sum (mean)")
        ax.set_title("DrugEnv reward component breakdown — pre vs post")
        ax.grid(alpha=0.25, axis="y")
        ax.axhline(0, color="black", lw=0.8)
        ax.legend(loc="best")
        fig.tight_layout()
        path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(path, dpi=150)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)
    logger.info("wrote %s", path)
def _format_episode_md(row: Dict[str, Any], rank_label: str) -> str:
    """Render one episode as a Markdown section.

    The section holds a heading, a bullet list of episode fields, a
    step/action/reward table and, when the row carries
    ``reward_components_total``, a table of component totals. When the
    action and reward lists differ in length, the shorter one is padded
    with ``(no-op)`` actions or ``0.0`` rewards.

    Args:
        row: Episode dict as loaded from the evaluator JSONL.
        rank_label: Text placed at the start of the heading.

    Returns:
        The Markdown text, ending with a newline.
    """
    header = [
        f"### {rank_label} — episode {row.get('episode')} "
        f"(scenario={row.get('scenario')}, target={row.get('target_gene')})",
        "",
        f"- tag: `{row.get('tag')}`",
        f"- seed: `{row.get('seed')}`",
        f"- cumulative_reward: **{float(row.get('cumulative_reward', 0.0)):+.3f}**",
        f"- n_steps: {row.get('n_steps')} (invalid: {row.get('invalid_actions', 0)})",
        f"- submitted: `{row.get('submitted')}` "
        f"submitted_decision: `{row.get('submitted_decision')}` "
        f"correct_decision: `{row.get('correct_decision')}`",
        f"- decision_accuracy: {float(row.get('decision_accuracy', 0.0)):.3f} "
        f"evidence_coverage: {float(row.get('evidence_coverage', 0.0)):.3f}",
        "",
        "| step | action | reward |",
        "|-----:|:-------|-------:|",
    ]

    step_actions = list(row.get("action_sequence") or [])
    step_rewards = list(row.get("step_rewards") or [])
    table: List[str] = []
    for step in range(max(len(step_actions), len(step_rewards))):
        # Fill whichever list ran out first.
        action = step_actions[step] if step < len(step_actions) else "(no-op)"
        reward = step_rewards[step] if step < len(step_rewards) else 0.0
        table.append(f"| {step} | `{action}` | {float(reward):+.3f} |")

    footer: List[str] = []
    totals = row.get("reward_components_total") or {}
    if totals:
        footer = [
            "",
            "**Reward component totals**",
            "",
            "| component | total |",
            "|:----------|------:|",
        ]
        footer.extend(
            f"| `{name}` | {float(totals.get(name, 0.0)):+.3f} |"
            for name in _REWARD_COMPONENTS
        )

    # The trailing empty entry makes the section end with a newline.
    return "\n".join([*header, *table, *footer, ""])
def _write_samples_md(
    *,
    pre_rows: Sequence[Dict[str, Any]],
    post_rows: Sequence[Dict[str, Any]],
    n_samples: int,
    path: Path,
) -> None:
    """Write the worst-N pre-train and best-N post-train episodes as Markdown.

    Episodes are ranked by ``cumulative_reward`` (missing values count as
    0.0): ascending for the pre-train section, descending for the
    post-train section. The parent directory of ``path`` is created when
    missing.

    Args:
        pre_rows: Pre-train episode dicts.
        post_rows: Post-train episode dicts.
        n_samples: How many episodes to show per section; values below
            zero are treated as zero.
        path: Destination Markdown file.
    """

    def _reward(row: Dict[str, Any]) -> float:
        return float(row.get("cumulative_reward", 0.0))

    # A negative count would make ``[:n]`` mean "all but the last |n|" and
    # dump nearly every episode, so clamp it before slicing.
    limit = max(0, n_samples)
    worst_pre = sorted(pre_rows, key=_reward)[:limit]
    best_post = sorted(post_rows, key=_reward, reverse=True)[:limit]

    parts: List[str] = [
        "# Sample trajectories — DrugEnv before/after",
        "",
        "Generated by `training/summarize.py`. Worst pre-train episodes show "
        "what the warm-started model failed at; best post-train episodes show "
        "the trajectories GRPO actually reinforced.",
        "",
        f"## Worst {len(worst_pre)} pre-train episodes (lowest cumulative reward)",
        "",
    ]
    if not worst_pre:
        parts.append("_(no pre-train episodes recorded)_\n")
    for i, row in enumerate(worst_pre):
        parts.append(_format_episode_md(row, f"Worst-pre #{i + 1}"))

    parts.append(f"## Best {len(best_post)} post-train episodes (highest cumulative reward)")
    parts.append("")
    if not best_post:
        parts.append("_(no post-train episodes recorded)_\n")
    for i, row in enumerate(best_post):
        parts.append(_format_episode_md(row, f"Best-post #{i + 1}"))

    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("\n".join(parts), encoding="utf-8")
    logger.info("wrote %s", path)
def main(argv: Optional[List[str]] = None) -> int:
    """Run the summariser end to end.

    Loads the pre/post JSONL files from the evidence directory, computes
    the aggregate metrics and their delta, writes every artifact (metrics
    JSON, components CSV, three PNG charts, sample trajectories) into that
    directory, and logs a one-line summary.

    Args:
        argv: Command-line arguments; ``None`` defers to argparse's default.

    Returns:
        Always 0.
    """
    args = _parse_args(argv)
    evidence_dir = Path(args.evidence_dir)
    evidence_dir.mkdir(parents=True, exist_ok=True)

    pre_rows = _load_jsonl(evidence_dir / args.pre)
    post_rows = _load_jsonl(evidence_dir / args.post)
    logger.info("loaded pre=%d post=%d episodes", len(pre_rows), len(post_rows))

    pre_metrics = _compute_aggregate(pre_rows)
    post_metrics = _compute_aggregate(post_rows)
    delta_metrics = _delta(pre_metrics, post_metrics)

    # The writers take either the aggregated metrics or the raw episodes;
    # bundle each set once so the calls below only differ by target path.
    headline = {"pre": pre_metrics, "post": post_metrics}
    episodes = {"pre_rows": pre_rows, "post_rows": post_rows}

    _write_metrics_json(**headline, delta=delta_metrics, path=evidence_dir / args.metrics_out)
    _write_components_csv(**episodes, path=evidence_dir / args.components_csv)
    _write_summary_png(**headline, path=evidence_dir / args.summary_png)
    _write_distribution_png(**episodes, path=evidence_dir / args.distribution_png)
    _write_components_png(**episodes, path=evidence_dir / args.components_png)
    _write_samples_md(
        **episodes,
        n_samples=int(args.n_samples),
        path=evidence_dir / args.samples_md,
    )

    # Values in format order: pre, post, delta for the mean reward, then
    # pre, post, delta for the success rate.
    summary_values = [
        source.get(metric, 0.0)
        for metric in ("mean_reward", "success_rate")
        for source in (pre_metrics, post_metrics, delta_metrics)
    ]
    logger.info(
        "summary: pre_mean=%.3f post_mean=%.3f Δmean=%+.3f "
        "pre_success=%.2f post_success=%.2f Δsuccess=%+.2f",
        *summary_values,
    )
    return 0
# Script entry point: ``main``'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())