#!/usr/bin/env python3
"""
Generate demo-quality plots from a completed (or in-progress) GRPO run.

Usage
-----
    # from the run output directory
    python scripts/plot_grpo_run.py checkpoints/grpo/<run_name>/metrics.jsonl

    # auto-discover the latest run
    python scripts/plot_grpo_run.py --latest

    # custom output directory
    python scripts/plot_grpo_run.py metrics.jsonl --out-dir plots/my_run

Output
------
Six PNG files saved next to the JSONL (or --out-dir if given):

    01_training_objective.png – combined_score vs iteration (PRIMARY demo plot)
    02_reward_components.png  – 4-panel breakdown: correct / PRM / SymPy / format
    03_training_dynamics.png  – GRPO loss + batch reward + batch accuracy
    04_reward_vs_eval.png     – training reward vs eval score on same axis
    05_component_area.png     – stacked-area chart of the 4 weighted components
    06_summary_card.png       – single-panel card: all key metrics in one view

All figures use a clean dark-on-white academic style. They are saved at
300 dpi so they look sharp in slides and posters.
"""
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import matplotlib

matplotlib.use("Agg")  # headless – no display needed on training servers

import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
# ── Style ─────────────────────────────────────────────────────────────────────

PALETTE = {
    "combined": "#2563EB",   # blue    – training objective
    "correct": "#16A34A",    # green   – correctness
    "prm": "#DC2626",        # red     – PRM step quality
    "sympy": "#D97706",      # amber   – SymPy verification
    "fmt": "#7C3AED",        # violet  – format
    "reward": "#0891B2",     # cyan    – mean batch reward
    "loss": "#64748B",       # slate   – loss
    "batch_acc": "#059669",  # emerald – batch accuracy
}

plt.rcParams.update({
    "figure.dpi": 150,
    "savefig.dpi": 300,
    "font.family": "DejaVu Sans",
    "axes.spines.top": False,
    "axes.spines.right": False,
    "axes.grid": True,
    "grid.alpha": 0.3,
    "grid.linestyle": "--",
    "axes.labelsize": 11,
    "axes.titlesize": 13,
    "legend.fontsize": 9,
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
})
# ── Data loading ──────────────────────────────────────────────────────────────

def _load(path: Path) -> List[Dict[str, Any]]:
    rows = []
    with path.open(encoding="utf-8") as fh:
        for line in fh:
            line = line.strip()
            if line:
                rows.append(json.loads(line))
    return rows
def _field(rows: List[Dict], key: str) -> Tuple[List[int], List[float]]:
    """Return (iterations, values) for rows that have a non-empty key."""
    iters, vals = [], []
    for r in rows:
        v = r.get(key)
        if v is not None and v != "" and not (isinstance(v, float) and np.isnan(v)):
            # Convert first, append second, so the two lists always stay in lockstep
            # even when a row carries a malformed value or no "iteration" field.
            try:
                it = int(r["iteration"])
                val = float(v)
            except (KeyError, TypeError, ValueError):
                continue
            iters.append(it)
            vals.append(val)
    return iters, vals
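
# Each JSONL line is one logged iteration. A hypothetical row, shown only to
# illustrate the fields this script reads (the actual schema is whatever
# run_grpo_training.py writes):
#   {"iteration": 40, "loss": -0.012, "mean_reward": 0.61, "batch_accuracy": 0.55,
#    "combined_score": 0.58, "correct_rate": 0.60, "prm_mean": 0.71,
#    "sympy_mean": 0.52, "format_mean": 0.93}
# Eval-only fields (combined_score and the component means) may be absent on
# non-eval iterations; _field() simply skips rows where the key is missing.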
# ── Individual plots ──────────────────────────────────────────────────────────

def plot_training_objective(rows: List[Dict], out: Path) -> None:
    """Plot 01: combined_score – the single most important demo plot."""
    xi, xv = _field(rows, "combined_score")
    if not xi:
        return
    fig, ax = plt.subplots(figsize=(9, 5))
    ax.plot(xi, xv, color=PALETTE["combined"], linewidth=2.5,
            marker="o", markersize=5, label="Training-objective score")
    ax.fill_between(xi, xv, alpha=0.12, color=PALETTE["combined"])
    # annotate first and last eval points
    ax.annotate(f"{xv[0]:.3f}", (xi[0], xv[0]), textcoords="offset points",
                xytext=(8, 6), fontsize=8, color=PALETTE["combined"])
    ax.annotate(f"{xv[-1]:.3f}", (xi[-1], xv[-1]), textcoords="offset points",
                xytext=(8, 6), fontsize=8, color=PALETTE["combined"])
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Score (0 – 1)")
    ax.set_title(
        "GRPO Training – Combined Reward Score\n"
        "0.60 × correct + 0.15 × PRM + 0.15 × SymPy + 0.10 × format",
        fontsize=12,
    )
    ax.set_ylim(0, 1.05)
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
    ax.legend(loc="lower right")
    fig.tight_layout()
    fig.savefig(out)
    plt.close(fig)
    print(f" saved {out.name}")
def plot_reward_components(rows: List[Dict], out: Path) -> None:
    """Plot 02: four-panel breakdown of each reward component."""
    specs = [
        ("correct_rate", "correct", "Correctness (gt_match)", "60 %"),
        ("prm_mean", "prm", "PRM Step Quality", "15 %"),
        ("sympy_mean", "sympy", "SymPy Verification", "15 %"),
        ("format_mean", "fmt", "Format Compliance", "10 %"),
    ]
    fig, axes = plt.subplots(2, 2, figsize=(12, 7), sharex=False)
    axes = axes.flatten()
    for ax, (key, pal, title, weight) in zip(axes, specs):
        xi, xv = _field(rows, key)
        if not xi:
            ax.set_visible(False)
            continue
        ax.plot(xi, xv, color=PALETTE[pal], linewidth=2,
                marker="o", markersize=4)
        ax.fill_between(xi, xv, alpha=0.12, color=PALETTE[pal])
        ax.set_xlabel("Iteration")
        ax.set_ylabel("Score")
        ax.set_ylim(0, 1.05)
        ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
        # Panel title carries the net change from first to last eval point.
        delta = xv[-1] - xv[0]
        ax.set_title(f"{title} (weight {weight})  Δ={delta:+.1%}", fontsize=10)
    fig.suptitle("Reward Component Breakdown over Training", fontsize=13, y=1.01)
    fig.tight_layout()
    fig.savefig(out, bbox_inches="tight")
    plt.close(fig)
    print(f" saved {out.name}")
def plot_training_dynamics(rows: List[Dict], out: Path) -> None:
    """Plot 03: loss, mean_reward, batch_accuracy over all iterations."""
    li, lv = _field(rows, "loss")
    ri, rv = _field(rows, "mean_reward")
    bi, bv = _field(rows, "batch_accuracy")
    fig, axes = plt.subplots(3, 1, figsize=(10, 8), sharex=True)
    if lv:
        axes[0].plot(li, lv, color=PALETTE["loss"], linewidth=1.8)
        axes[0].fill_between(li, lv, alpha=0.1, color=PALETTE["loss"])
        axes[0].set_ylabel("GRPO Loss")
        axes[0].set_title("Training Loss", fontsize=11)
        axes[0].axhline(0, color="black", linewidth=0.8, linestyle="--", alpha=0.4)
    if rv:
        axes[1].plot(ri, rv, color=PALETTE["reward"], linewidth=1.8)
        axes[1].fill_between(ri, rv, alpha=0.1, color=PALETTE["reward"])
        axes[1].set_ylabel("Reward")
        axes[1].set_ylim(0, 1.05)
        axes[1].set_title("Mean Batch Reward", fontsize=11)
        axes[1].yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
    if bv:
        axes[2].plot(bi, bv, color=PALETTE["batch_acc"], linewidth=1.8)
        axes[2].fill_between(bi, bv, alpha=0.1, color=PALETTE["batch_acc"])
        axes[2].set_ylabel("Accuracy")
        axes[2].set_ylim(0, 1.05)
        axes[2].set_title("Batch Accuracy (training rollouts)", fontsize=11)
        axes[2].yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
    for ax in axes:
        ax.set_xlabel("Iteration")
    fig.suptitle("GRPO Training Dynamics", fontsize=13)
    fig.tight_layout()
    fig.savefig(out)
    plt.close(fig)
    print(f" saved {out.name}")
def plot_reward_vs_eval(rows: List[Dict], out: Path) -> None:
    """Plot 04: mean_reward (all iters) + combined_score (eval iters) overlaid."""
    ri, rv = _field(rows, "mean_reward")
    ei, ev = _field(rows, "combined_score")
    fig, ax = plt.subplots(figsize=(10, 5))
    if rv:
        ax.plot(ri, rv, color=PALETTE["reward"], linewidth=1.4, alpha=0.7,
                label="Batch reward (training)")
        ax.fill_between(ri, rv, alpha=0.06, color=PALETTE["reward"])
    if ev:
        ax.plot(ei, ev, color=PALETTE["combined"], linewidth=2.5,
                marker="D", markersize=6, label="Eval score (held-out GSM8K)")
        for x, y in zip(ei, ev):
            ax.annotate(f"{y:.3f}", (x, y), textcoords="offset points",
                        xytext=(0, 8), ha="center", fontsize=7,
                        color=PALETTE["combined"])
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Score (0 – 1)")
    ax.set_ylim(0, 1.05)
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
    ax.set_title("Training Reward vs Held-Out Eval Score", fontsize=12)
    ax.legend()
    fig.tight_layout()
    fig.savefig(out)
    plt.close(fig)
    print(f" saved {out.name}")
def plot_component_area(rows: List[Dict], out: Path) -> None:
    """Plot 05: stacked-area of the four WEIGHTED components summing to combined_score."""
    ei, ev_combined = _field(rows, "combined_score")
    if not ei:
        return
    # Build per-component weighted series aligned to eval iterations
    iter_set = set(ei)
    aligned: Dict[str, List[float]] = {k: [] for k in ("correct", "prm", "sympy", "fmt")}
    weights = {"correct": 0.60, "prm": 0.15, "sympy": 0.15, "fmt": 0.10}
    keys = {"correct": "correct_rate", "prm": "prm_mean",
            "sympy": "sympy_mean", "fmt": "format_mean"}
    # Build lookup per iteration
    it_map: Dict[int, Dict] = {r["iteration"]: r for r in rows if r["iteration"] in iter_set}
    iters_sorted = sorted(iter_set)
    for it in iters_sorted:
        row = it_map.get(it, {})
        for comp, field in keys.items():
            v = row.get(field)
            if v is not None and v != "":
                aligned[comp].append(float(v) * weights[comp])
            else:
                aligned[comp].append(0.0)
    x = np.array(iters_sorted)
    arr = np.array([aligned["correct"], aligned["prm"],
                    aligned["sympy"], aligned["fmt"]])
    fig, ax = plt.subplots(figsize=(10, 5))
    labels = ["Correct (×0.60)", "PRM (×0.15)", "SymPy (×0.15)", "Format (×0.10)"]
    colors = [PALETTE[k] for k in ("correct", "prm", "sympy", "fmt")]
    ax.stackplot(x, arr, labels=labels, colors=colors, alpha=0.75)
    ax.plot(x, ev_combined, color="black", linewidth=1.5,
            linestyle="--", label="Combined score", zorder=5)
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Weighted contribution to score")
    ax.set_ylim(0, 1.0)
    ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
    ax.set_title("Contribution of Each Reward Component (Stacked)", fontsize=12)
    ax.legend(loc="lower right", ncol=2)
    fig.tight_layout()
    fig.savefig(out)
    plt.close(fig)
    print(f" saved {out.name}")
def plot_summary_card(rows: List[Dict], run_name: str, out: Path) -> None:
    """Plot 06: all key metrics on a single clean card – ideal for poster / slide."""
    # Load each series with its own iteration axis so panels never pair
    # mismatched x/y lists when a component is missing on some eval rows.
    ei, ev = _field(rows, "combined_score")
    ci, crv = _field(rows, "correct_rate")
    pi, prmv = _field(rows, "prm_mean")
    si, syv = _field(rows, "sympy_mean")
    fi, fmv = _field(rows, "format_mean")
    li, lv = _field(rows, "loss")
    fig, axes = plt.subplots(2, 3, figsize=(15, 8))
    axes = axes.flatten()

    def _panel(ax, iters, vals, color, title, pct=True):
        if not iters:
            ax.set_visible(False)
            return
        ax.plot(iters, vals, color=color, linewidth=2, marker="o", markersize=4)
        ax.fill_between(iters, vals, alpha=0.12, color=color)
        ax.set_title(title, fontsize=11, fontweight="bold")
        ax.set_xlabel("Iteration", fontsize=9)
        if pct:
            ax.set_ylim(0, 1.05)
            ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
        if vals:
            ax.annotate(f"{vals[-1]:.3f}", (iters[-1], vals[-1]),
                        textcoords="offset points", xytext=(6, 4),
                        fontsize=8, color=color)

    _panel(axes[0], ei, ev, PALETTE["combined"], "Training-Objective Score")
    _panel(axes[1], ci, crv, PALETTE["correct"], "Correctness Rate")
    _panel(axes[2], pi, prmv, PALETTE["prm"], "PRM Step Quality")
    _panel(axes[3], si, syv, PALETTE["sympy"], "SymPy Verification")
    _panel(axes[4], fi, fmv, PALETTE["fmt"], "Format Compliance")
    _panel(axes[5], li, lv, PALETTE["loss"], "GRPO Loss", pct=False)
    fig.suptitle(f"GRPO Training Summary – {run_name}", fontsize=14, fontweight="bold")
    fig.tight_layout()
    fig.savefig(out, bbox_inches="tight")
    plt.close(fig)
    print(f" saved {out.name}")
# ── CLI ───────────────────────────────────────────────────────────────────────

def find_latest_metrics() -> Optional[Path]:
    """Find the most recently modified metrics.jsonl under checkpoints/grpo/."""
    ckpt = Path("checkpoints/grpo")
    if not ckpt.exists():
        return None
    candidates = sorted(
        ckpt.rglob("metrics.jsonl"),
        key=lambda p: p.stat().st_mtime,
    )
    return candidates[-1] if candidates else None
def generate_plots(metrics_path: Path, out_dir: Optional[Path] = None) -> Path:
    """Generate all six plots and return the output directory."""
    rows = _load(metrics_path)
    if not rows:
        print(f"[plot] No data in {metrics_path}", file=sys.stderr)
        return metrics_path.parent
    out_dir = out_dir or metrics_path.parent / "plots"
    out_dir.mkdir(parents=True, exist_ok=True)
    # Derive the run name from the directory containing metrics.jsonl
    run_name = metrics_path.parent.name
    print(f"[plot] Generating plots for run '{run_name}' ({len(rows)} iterations)")
    print(f"[plot] Output → {out_dir}")
    plot_training_objective(rows, out_dir / "01_training_objective.png")
    plot_reward_components(rows, out_dir / "02_reward_components.png")
    plot_training_dynamics(rows, out_dir / "03_training_dynamics.png")
    plot_reward_vs_eval(rows, out_dir / "04_reward_vs_eval.png")
    plot_component_area(rows, out_dir / "05_component_area.png")
    plot_summary_card(rows, run_name, out_dir / "06_summary_card.png")
    print(f"[plot] Done → {len(list(out_dir.glob('*.png')))} PNGs in {out_dir}")
    return out_dir
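
# Programmatic use (a minimal sketch; the run directory below is hypothetical):
#
#     from pathlib import Path
#     generate_plots(Path("checkpoints/grpo/my_run/metrics.jsonl"))
#
# This mirrors what the CLI does after resolving the metrics path.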
def main() -> None:
    parser = argparse.ArgumentParser(
        description="Generate demo plots from a GRPO metrics.jsonl file."
    )
    parser.add_argument(
        "metrics_jsonl", nargs="?", type=Path, default=None,
        help="Path to metrics.jsonl produced by run_grpo_training.py",
    )
    parser.add_argument(
        "--latest", action="store_true",
        help="Auto-discover the most recent metrics.jsonl under checkpoints/grpo/",
    )
    parser.add_argument(
        "--out-dir", type=Path, default=None,
        help="Directory to write PNG files (default: <metrics_dir>/plots/)",
    )
    args = parser.parse_args()

    if args.latest:
        path = find_latest_metrics()
        if path is None:
            print("No metrics.jsonl found under checkpoints/grpo/", file=sys.stderr)
            sys.exit(1)
        print(f"[plot] Auto-selected {path}")
    elif args.metrics_jsonl:
        path = args.metrics_jsonl
    else:
        parser.print_help()
        sys.exit(1)

    if not path.exists():
        print(f"File not found: {path}", file=sys.stderr)
        sys.exit(1)

    generate_plots(path, args.out_dir)


if __name__ == "__main__":
    main()