AxiomForgeAI / scripts /plot_grpo_run.py
jampuramprem's picture
Initial Space deployment
ec4ae03
#!/usr/bin/env python3
"""
Generate demo-quality plots from a completed (or in-progress) GRPO run.
Usage
-----
# from the run output directory
python scripts/plot_grpo_run.py checkpoints/grpo/<run_name>/metrics.jsonl
# auto-discover the latest run
python scripts/plot_grpo_run.py --latest
# custom output directory
python scripts/plot_grpo_run.py metrics.jsonl --out-dir plots/my_run
Output
------
Six PNG files saved next to the JSONL (or --out-dir if given):
01_training_objective.png – combined_score vs iteration (PRIMARY demo plot)
02_reward_components.png – 4-panel breakdown: correct / PRM / SymPy / format
03_training_dynamics.png – GRPO loss + batch reward + batch accuracy
04_reward_vs_eval.png – training reward vs eval score on same axis
05_component_area.png – stacked-area chart of the 4 weighted components
06_summary_card.png – single-panel card: all key metrics in one view
All figures use a clean dark-on-white academic style. They are saved at
300 dpi so they look sharp in slides and posters.
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import matplotlib
matplotlib.use("Agg") # headless β€” no display needed on training servers
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
# ── Style ────────────────────────────────────────────────────────────────────
PALETTE = {
"combined": "#2563EB", # blue β€” training objective
"correct": "#16A34A", # green β€” correctness
"prm": "#DC2626", # red β€” PRM step quality
"sympy": "#D97706", # amber β€” SymPy verification
"fmt": "#7C3AED", # violet β€” format
"reward": "#0891B2", # cyan β€” mean batch reward
"loss": "#64748B", # slate β€” loss
"batch_acc": "#059669", # emerald β€” batch accuracy
}
plt.rcParams.update({
"figure.dpi": 150,
"savefig.dpi": 300,
"font.family": "DejaVu Sans",
"axes.spines.top": False,
"axes.spines.right": False,
"axes.grid": True,
"grid.alpha": 0.3,
"grid.linestyle": "--",
"axes.labelsize": 11,
"axes.titlesize": 13,
"legend.fontsize": 9,
"xtick.labelsize": 9,
"ytick.labelsize": 9,
})
# ── Data loading ─────────────────────────────────────────────────────────────
def _load(path: Path) -> List[Dict[str, Any]]:
rows = []
with path.open(encoding="utf-8") as fh:
for line in fh:
line = line.strip()
if line:
rows.append(json.loads(line))
return rows
def _field(rows: List[Dict], key: str) -> Tuple[List[int], List[float]]:
"""Return (iterations, values) for rows that have a non-empty key."""
iters, vals = [], []
for r in rows:
v = r.get(key)
if v is not None and v != "" and not (isinstance(v, float) and np.isnan(v)):
try:
iters.append(int(r["iteration"]))
vals.append(float(v))
except (TypeError, ValueError):
pass
return iters, vals
# ── Individual plots ─────────────────────────────────────────────────────────
def plot_training_objective(rows: List[Dict], out: Path) -> None:
"""Plot 01: combined_score β€” the single most important demo plot."""
xi, xv = _field(rows, "combined_score")
if not xi:
return
fig, ax = plt.subplots(figsize=(9, 5))
ax.plot(xi, xv, color=PALETTE["combined"], linewidth=2.5,
marker="o", markersize=5, label="Training-objective score")
ax.fill_between(xi, xv, alpha=0.12, color=PALETTE["combined"])
# annotate first and last eval points
ax.annotate(f"{xv[0]:.3f}", (xi[0], xv[0]), textcoords="offset points",
xytext=(8, 6), fontsize=8, color=PALETTE["combined"])
ax.annotate(f"{xv[-1]:.3f}", (xi[-1], xv[-1]), textcoords="offset points",
xytext=(8, 6), fontsize=8, color=PALETTE["combined"])
ax.set_xlabel("Iteration")
ax.set_ylabel("Score (0 – 1)")
ax.set_title(
"GRPO Training β€” Combined Reward Score\n"
"0.60 Γ— correct + 0.15 Γ— PRM + 0.15 Γ— SymPy + 0.10 Γ— format",
fontsize=12,
)
ax.set_ylim(0, 1.05)
ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
ax.legend(loc="lower right")
fig.tight_layout()
fig.savefig(out)
plt.close(fig)
print(f" saved {out.name}")
def plot_reward_components(rows: List[Dict], out: Path) -> None:
"""Plot 02: four-panel breakdown of each reward component."""
specs = [
("correct_rate", "correct", "Correctness (gt_match)", "60 %"),
("prm_mean", "prm", "PRM Step Quality", "15 %"),
("sympy_mean", "sympy", "SymPy Verification", "15 %"),
("format_mean", "fmt", "Format Compliance", "10 %"),
]
fig, axes = plt.subplots(2, 2, figsize=(12, 7), sharex=False)
axes = axes.flatten()
for ax, (key, pal, title, weight) in zip(axes, specs):
xi, xv = _field(rows, key)
if not xi:
ax.set_visible(False)
continue
ax.plot(xi, xv, color=PALETTE[pal], linewidth=2,
marker="o", markersize=4)
ax.fill_between(xi, xv, alpha=0.12, color=PALETTE[pal])
ax.set_title(f"{title} (weight {weight})", fontsize=11)
ax.set_xlabel("Iteration")
ax.set_ylabel("Score")
ax.set_ylim(0, 1.05)
ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
if xv:
delta = xv[-1] - xv[0]
sign = "+" if delta >= 0 else ""
ax.set_title(
f"{title} (weight {weight}) Ξ”={sign}{delta:+.1%}",
fontsize=10,
)
fig.suptitle("Reward Component Breakdown over Training", fontsize=13, y=1.01)
fig.tight_layout()
fig.savefig(out, bbox_inches="tight")
plt.close(fig)
print(f" saved {out.name}")
def plot_training_dynamics(rows: List[Dict], out: Path) -> None:
"""Plot 03: loss, mean_reward, batch_accuracy over all iterations."""
li, lv = _field(rows, "loss")
ri, rv = _field(rows, "mean_reward")
bi, bv = _field(rows, "batch_accuracy")
fig, axes = plt.subplots(3, 1, figsize=(10, 8), sharex=True)
if lv:
axes[0].plot(li, lv, color=PALETTE["loss"], linewidth=1.8)
axes[0].fill_between(li, lv, alpha=0.1, color=PALETTE["loss"])
axes[0].set_ylabel("GRPO Loss")
axes[0].set_title("Training Loss", fontsize=11)
axes[0].axhline(0, color="black", linewidth=0.8, linestyle="--", alpha=0.4)
if rv:
axes[1].plot(ri, rv, color=PALETTE["reward"], linewidth=1.8)
axes[1].fill_between(ri, rv, alpha=0.1, color=PALETTE["reward"])
axes[1].set_ylabel("Reward")
axes[1].set_ylim(0, 1.05)
axes[1].set_title("Mean Batch Reward", fontsize=11)
axes[1].yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
if bv:
axes[2].plot(bi, bv, color=PALETTE["batch_acc"], linewidth=1.8)
axes[2].fill_between(bi, bv, alpha=0.1, color=PALETTE["batch_acc"])
axes[2].set_ylabel("Accuracy")
axes[2].set_ylim(0, 1.05)
axes[2].set_title("Batch Accuracy (training rollouts)", fontsize=11)
axes[2].yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
for ax in axes:
ax.set_xlabel("Iteration")
fig.suptitle("GRPO Training Dynamics", fontsize=13)
fig.tight_layout()
fig.savefig(out)
plt.close(fig)
print(f" saved {out.name}")
def plot_reward_vs_eval(rows: List[Dict], out: Path) -> None:
"""Plot 04: mean_reward (all iters) + combined_score (eval iters) overlaid."""
ri, rv = _field(rows, "mean_reward")
ei, ev = _field(rows, "combined_score")
fig, ax = plt.subplots(figsize=(10, 5))
if rv:
ax.plot(ri, rv, color=PALETTE["reward"], linewidth=1.4, alpha=0.7,
label="Batch reward (training)")
ax.fill_between(ri, rv, alpha=0.06, color=PALETTE["reward"])
if ev:
ax.plot(ei, ev, color=PALETTE["combined"], linewidth=2.5,
marker="D", markersize=6, label="Eval score (held-out GSM8K)")
for x, y in zip(ei, ev):
ax.annotate(f"{y:.3f}", (x, y), textcoords="offset points",
xytext=(0, 8), ha="center", fontsize=7,
color=PALETTE["combined"])
ax.set_xlabel("Iteration")
ax.set_ylabel("Score (0 – 1)")
ax.set_ylim(0, 1.05)
ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
ax.set_title("Training Reward vs Held-Out Eval Score", fontsize=12)
ax.legend()
fig.tight_layout()
fig.savefig(out)
plt.close(fig)
print(f" saved {out.name}")
def plot_component_area(rows: List[Dict], out: Path) -> None:
"""Plot 05: stacked-area of the four WEIGHTED components summing to combined_score."""
ei, ev_combined = _field(rows, "combined_score")
if not ei:
return
# Build per-component weighted series aligned to eval iterations
iter_set = set(ei)
aligned: Dict[str, List[float]] = {k: [] for k in ("correct", "prm", "sympy", "fmt")}
weights = {"correct": 0.60, "prm": 0.15, "sympy": 0.15, "fmt": 0.10}
keys = {"correct": "correct_rate", "prm": "prm_mean",
"sympy": "sympy_mean", "fmt": "format_mean"}
# Build lookup per iteration
it_map: Dict[int, Dict] = {r["iteration"]: r for r in rows if r["iteration"] in iter_set}
iters_sorted = sorted(iter_set)
for it in iters_sorted:
row = it_map.get(it, {})
for comp, field in keys.items():
v = row.get(field)
if v is not None and v != "":
aligned[comp].append(float(v) * weights[comp])
else:
aligned[comp].append(0.0)
x = np.array(iters_sorted)
arr = np.array([aligned["correct"], aligned["prm"],
aligned["sympy"], aligned["fmt"]])
fig, ax = plt.subplots(figsize=(10, 5))
labels = ["Correct (Γ—0.60)", "PRM (Γ—0.15)", "SymPy (Γ—0.15)", "Format (Γ—0.10)"]
colors = [PALETTE[k] for k in ("correct", "prm", "sympy", "fmt")]
ax.stackplot(x, arr, labels=labels, colors=colors, alpha=0.75)
ax.plot(x, ev_combined, color="black", linewidth=1.5,
linestyle="--", label="Combined score", zorder=5)
ax.set_xlabel("Iteration")
ax.set_ylabel("Weighted contribution to score")
ax.set_ylim(0, 1.0)
ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
ax.set_title("Contribution of Each Reward Component (Stacked)", fontsize=12)
ax.legend(loc="lower right", ncol=2)
fig.tight_layout()
fig.savefig(out)
plt.close(fig)
print(f" saved {out.name}")
def plot_summary_card(rows: List[Dict], run_name: str, out: Path) -> None:
"""Plot 06: all key metrics on a single clean card β€” ideal for poster / slide."""
ei, ev = _field(rows, "combined_score")
_, crv = _field(rows, "correct_rate")
_, prmv = _field(rows, "prm_mean")
_, syv = _field(rows, "sympy_mean")
_, fmv = _field(rows, "format_mean")
_, lv = _field(rows, "loss")
_, rv = _field(rows, "mean_reward")
li = _field(rows, "loss")[0]
ri = _field(rows, "mean_reward")[0]
fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.flatten()
def _panel(ax, iters, vals, color, title, pct=True):
if not iters:
ax.set_visible(False)
return
ax.plot(iters, vals, color=color, linewidth=2, marker="o", markersize=4)
ax.fill_between(iters, vals, alpha=0.12, color=color)
ax.set_title(title, fontsize=11, fontweight="bold")
ax.set_xlabel("Iteration", fontsize=9)
if pct:
ax.set_ylim(0, 1.05)
ax.yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1, decimals=0))
if vals:
ax.annotate(f"{vals[-1]:.3f}", (iters[-1], vals[-1]),
textcoords="offset points", xytext=(6, 4),
fontsize=8, color=color)
_panel(axes[0], ei, ev, PALETTE["combined"], "Training-Objective Score")
_panel(axes[1], ei, crv, PALETTE["correct"], "Correctness Rate")
_panel(axes[2], ei, prmv, PALETTE["prm"], "PRM Step Quality")
_panel(axes[3], ei, syv, PALETTE["sympy"], "SymPy Verification")
_panel(axes[4], ei, fmv, PALETTE["fmt"], "Format Compliance")
_panel(axes[5], li, lv, PALETTE["loss"], "GRPO Loss", pct=False)
fig.suptitle(f"GRPO Training Summary β€” {run_name}", fontsize=14, fontweight="bold")
fig.tight_layout()
fig.savefig(out, bbox_inches="tight")
plt.close(fig)
print(f" saved {out.name}")
# ── CLI ──────────────────────────────────────────────────────────────────────
def find_latest_metrics() -> Optional[Path]:
"""Find the most recently modified metrics.jsonl under checkpoints/grpo/."""
ckpt = Path("checkpoints/grpo")
if not ckpt.exists():
return None
candidates = sorted(
ckpt.rglob("metrics.jsonl"),
key=lambda p: p.stat().st_mtime,
)
return candidates[-1] if candidates else None
def generate_plots(metrics_path: Path, out_dir: Optional[Path] = None) -> Path:
"""Generate all six plots and return the output directory."""
rows = _load(metrics_path)
if not rows:
print(f"[plot] No data in {metrics_path}", file=sys.stderr)
return metrics_path.parent
out_dir = out_dir or metrics_path.parent / "plots"
out_dir.mkdir(parents=True, exist_ok=True)
# Derive run name from the directory name two levels up
run_name = metrics_path.parent.name
print(f"[plot] Generating plots for run '{run_name}' ({len(rows)} iterations)")
print(f"[plot] Output β†’ {out_dir}")
plot_training_objective(rows, out_dir / "01_training_objective.png")
plot_reward_components(rows, out_dir / "02_reward_components.png")
plot_training_dynamics(rows, out_dir / "03_training_dynamics.png")
plot_reward_vs_eval(rows, out_dir / "04_reward_vs_eval.png")
plot_component_area(rows, out_dir / "05_component_area.png")
plot_summary_card(rows, run_name, out_dir / "06_summary_card.png")
print(f"[plot] Done β€” {len(list(out_dir.glob('*.png')))} PNGs in {out_dir}")
return out_dir
def main() -> None:
parser = argparse.ArgumentParser(
description="Generate demo plots from a GRPO metrics.jsonl file."
)
parser.add_argument(
"metrics_jsonl", nargs="?", type=Path, default=None,
help="Path to metrics.jsonl produced by run_grpo_training.py",
)
parser.add_argument(
"--latest", action="store_true",
help="Auto-discover the most recent metrics.jsonl under checkpoints/grpo/",
)
parser.add_argument(
"--out-dir", type=Path, default=None,
help="Directory to write PNG files (default: <metrics_dir>/plots/)",
)
args = parser.parse_args()
if args.latest:
path = find_latest_metrics()
if path is None:
print("No metrics.jsonl found under checkpoints/grpo/", file=sys.stderr)
sys.exit(1)
print(f"[plot] Auto-selected {path}")
elif args.metrics_jsonl:
path = args.metrics_jsonl
else:
parser.print_help()
sys.exit(1)
if not path.exists():
print(f"File not found: {path}", file=sys.stderr)
sys.exit(1)
generate_plots(path, args.out_dir)
if __name__ == "__main__":
main()