# prompt_golf_env/training/make_plots.py
# Don Rishabh
# Initial commit: Prompt Golf environment for OpenEnv
# 6850dad
"""
Read train_metrics.jsonl from a training run and produce the demo plots:

  - reward_curve.png : mean reward per GRPO step (the headline)
  - length_curve.png : mean prompt tokens per GRPO step (the compression story)
  - breakdown.png    : 2x2 grid of reward, tokens, raw_task_score, length_factor

Usage:
    python training/make_plots.py --metrics outputs/grpo/train_metrics.jsonl --out-dir plots
"""
from __future__ import annotations

import argparse
import json
from collections import deque
from pathlib import Path
from typing import Dict, List, Optional
def load(path: Path) -> List[Dict]:
    """Parse a JSON-lines file into a list of decoded rows.

    Lines that are empty after stripping whitespace are skipped; every other
    line is decoded with ``json.loads``.

    Args:
        path: Location of the ``.jsonl`` file to read.

    Returns:
        The decoded rows, in file order. An empty file yields ``[]``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Pin the encoding so decoding does not depend on the platform's locale
    # default (which is not UTF-8 everywhere).
    with path.open(encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        return [json.loads(text) for text in stripped if text]
def smooth(values: List[float], k: int = 10) -> List[float]:
    """Return the trailing moving average of *values* over a window of *k*.

    Element ``i`` of the result is the mean of the last ``min(i + 1, k)``
    inputs ending at ``values[i]``, so the output has the same length as the
    input and the first ``k - 1`` points average a shorter window.

    Args:
        values: Series to smooth.
        k: Maximum window length. Values ``<= 1`` disable smoothing.

    Returns:
        A new list. The input list is never returned or modified, so callers
        may mutate the result freely.
    """
    if not values or k <= 1:
        # Copy rather than alias the caller's list.
        return list(values or ())
    # A bounded deque evicts the oldest sample in O(1); list.pop(0) is O(k).
    window = deque(maxlen=k)
    out: List[float] = []
    for v in values:
        window.append(v)
        out.append(sum(window) / len(window))
    return out
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the plotting script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script entry
            point uses; pass a list to drive the parser programmatically.

    Returns:
        Namespace with ``metrics``, ``out_dir``, ``smooth_k`` and
        ``title_suffix`` attributes.
    """
    p = argparse.ArgumentParser(
        description="Plot reward and prompt-length curves from a train_metrics.jsonl file.",
    )
    p.add_argument(
        "--metrics",
        required=True,
        help="path to the train_metrics.jsonl file to plot",
    )
    p.add_argument(
        "--out-dir",
        default="plots",
        help="directory the PNG files are written to (created if missing)",
    )
    p.add_argument(
        "--smooth-k",
        type=int,
        default=10,
        help="moving-average window for the smoothed curves; values <= 1 disable smoothing",
    )
    p.add_argument(
        "--title-suffix",
        default="",
        help="text appended verbatim to every plot title",
    )
    return p.parse_args(argv)
def _first_present(row: Dict, *keys: str, default: float = 0.0) -> float:
    """Return the first value among *keys* in *row* that is not ``None``.

    Unlike an ``a or b or default`` chain, a legitimate ``0`` / ``0.0`` metric
    is kept instead of being treated as missing.
    """
    for key in keys:
        value = row.get(key)
        if value is not None:
            return value
    return default


def _save_curve(
    plt,
    steps: List,
    values: List[float],
    smooth_k: int,
    *,
    ylabel: str,
    title: str,
    path: Path,
) -> None:
    """Draw one raw-plus-smoothed line chart against GRPO step and save it to *path*."""
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.plot(steps, values, alpha=0.4, label="per-step")
    ax.plot(steps, smooth(values, smooth_k), linewidth=2, label=f"smoothed (k={smooth_k})")
    ax.set_xlabel("GRPO step")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.legend()
    ax.set_ylim(bottom=0)
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    plt.close(fig)


def main() -> None:
    """Read the metrics file named on the command line and write three PNG plots.

    Writes ``reward_curve.png``, ``length_curve.png`` and ``breakdown.png``
    into ``--out-dir`` (created if needed) and prints the saved paths.

    Raises:
        SystemExit: If the metrics file contains no rows.
        KeyError: If a row has no ``"step"`` field.
    """
    args = parse_args()
    # Imported after argument parsing, so --help and usage errors exit before
    # matplotlib is loaded.
    import matplotlib.pyplot as plt

    rows = load(Path(args.metrics))
    if not rows:
        raise SystemExit(f"no rows in {args.metrics}")
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    steps = [r["step"] for r in rows]
    # "avg_reward" is preferred; "reward" is the fallback key. Missing or null
    # metrics plot as 0.0.
    rewards = [_first_present(r, "avg_reward", "reward") for r in rows]
    tokens = [_first_present(r, "avg_tokens") for r in rows]
    raws = [_first_present(r, "avg_raw_score") for r in rows]
    lfs = [_first_present(r, "avg_length_factor") for r in rows]

    # --- reward curve ---
    _save_curve(
        plt,
        steps,
        rewards,
        args.smooth_k,
        ylabel="mean reward (group-averaged)",
        title=f"Prompt Golf — Reward per step{args.title_suffix}",
        path=out_dir / "reward_curve.png",
    )

    # --- length curve ---
    _save_curve(
        plt,
        steps,
        tokens,
        args.smooth_k,
        ylabel="mean prompt tokens (target tokenizer)",
        title=f"Prompt Golf — Compression emerging{args.title_suffix}",
        path=out_dir / "length_curve.png",
    )

    # --- 2x2 breakdown ---
    fig, axes = plt.subplots(2, 2, figsize=(11, 7))
    series = [
        ("reward", rewards, "mean reward"),
        ("tokens", tokens, "mean prompt tokens"),
        ("raw_task_score", raws, "raw task score"),
        ("length_factor", lfs, "length factor"),
    ]
    for ax, (name, vals, ylabel) in zip(axes.flat, series):
        ax.plot(steps, vals, alpha=0.4)
        ax.plot(steps, smooth(vals, args.smooth_k), linewidth=2)
        ax.set_title(name)
        ax.set_xlabel("step")
        ax.set_ylabel(ylabel)
        ax.grid(True, alpha=0.3)
    fig.suptitle(f"Prompt Golf — training breakdown{args.title_suffix}", fontsize=13)
    fig.tight_layout()
    fig.savefig(out_dir / "breakdown.png", dpi=150)
    plt.close(fig)

    print(f"Saved plots to {out_dir}/")
    for filename in ("reward_curve.png", "length_curve.png", "breakdown.png"):
        print(f" {out_dir / filename}")
# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()