# cricket-captain-llm / scripts / generate_training_plots.py
# Last commit 2fc50a9 (verified) by pratinavseth:
#   sync: today's source updates (XML-only prompt, reward unclip, neg-reward on loss, pinned versions, configs reorg)
"""
Generate labeled PNG plots for the README from a WandB run OR from local
episode_stats.jsonl files.
Usage:
# From a WandB run id (preferred — uses the per-step rebalanced metrics)
python scripts/generate_training_plots.py \\
--wandb-run ptnv-s-research/huggingface/<RUN_ID> \\
--output-dir docs/plots/
# From local episode_stats.jsonl (faster, no API call). Quote the glob so
# the shell passes it through unexpanded.
python scripts/generate_training_plots.py \\
--jsonl 'logs/run_*/episode_stats.jsonl' \\
--output-dir docs/plots/
Generates (with axis labels + units):
docs/plots/training_reward_over_steps.png
docs/plots/per_rubric_breakdown.png
docs/plots/tool_call_frequency.png
docs/plots/match_completion_rate.png
docs/plots/before_after_comparison.png (if --compare given)
"""
import argparse
import glob
import json
import os
from pathlib import Path
from typing import Any
import matplotlib
matplotlib.use("Agg") # headless
import matplotlib.pyplot as plt
def _load_jsonl(path: str) -> list[dict[str, Any]]:
    """Load JSON-object rows from one JSONL file or a glob of JSONL files.

    Args:
        path: A file path, or a glob pattern (``*``, ``?`` or ``[...]``)
            matching several ``episode_stats.jsonl`` files. A path naming an
            existing file is always read literally, even if it contains glob
            characters.

    Returns:
        Every non-blank line that parses as a JSON object, in file order.
        Files matched by a glob are read in sorted path order. Lines that are
        not valid JSON, or that decode to something other than an object,
        are skipped. A glob that matches nothing yields ``[]`` (with a notice).

    Raises:
        OSError: if a literal (non-glob) path cannot be opened.
    """
    if os.path.isfile(path) or not any(ch in path for ch in "*?["):
        paths = [path]
    else:
        # glob() returns matches in filesystem order; sort them so the
        # concatenated history is the same on every run.
        paths = sorted(glob.glob(path))
        if not paths:
            print(f"  no files match {path!r}")
    rows: list[dict[str, Any]] = []
    for p in paths:
        with open(p, encoding="utf-8") as f:
            for raw in f:
                line = raw.strip()
                if not line:
                    continue
                try:
                    row = json.loads(line)
                except json.JSONDecodeError:
                    # Skip unparseable lines (e.g. a partially written one).
                    continue
                # The plotters call row.get(...), so only objects are usable.
                if isinstance(row, dict):
                    rows.append(row)
    return rows
def _load_wandb(run_path: str) -> tuple[list[dict[str, Any]], dict[str, Any]]:
    """Fetch a WandB run's logged history and its config.

    Requires ``pip install wandb`` and an authenticated session (``wandb login``).

    Args:
        run_path: ``entity/project/run_id`` of the run to fetch.

    Returns:
        ``(history, config)``: ``history`` is a list with one dict per sampled
        logged step (metric name -> value), up to 10000 samples; ``config``
        is the run's config.

    Raises:
        RuntimeError: if the ``wandb`` package is not installed.
    """
    try:
        import wandb
    except ImportError as err:
        raise RuntimeError("wandb not installed. pip install wandb") from err
    api = wandb.Api()
    run = api.run(run_path)
    # pandas=False: the default return type is a DataFrame, and iterating a
    # DataFrame yields its column names rather than rows. A DataFrame would
    # also fill unlogged cells with NaN, which the plotters' `is not None`
    # checks do not filter out.
    history = list(run.history(samples=10000, pandas=False))
    return history, run.config
def plot_training_reward(history, out_dir: Path, label: str):
    """Write ``training_reward_over_steps.png`` (mean environment reward per step).

    Only rows carrying a non-None ``rewards/environment_reward/mean`` are
    plotted. A point's x value is the row's ``_step``, else its ``step``,
    else the point's 0-based position among the plotted rows. When no row
    qualifies, a notice is printed and no file is written.
    """
    metric = "rewards/environment_reward/mean"
    kept = [row for row in history if metric in row and row[metric] is not None]
    if not kept:
        print(" no environment_reward/mean found, skipping")
        return
    xs = [row.get("_step", row.get("step", idx)) for idx, row in enumerate(kept)]
    ys = [row[metric] for row in kept]

    fig, ax = plt.subplots(figsize=(8, 4.5))
    ax.plot(xs, ys, marker="o", linewidth=1.5, markersize=4, color="#0066cc")
    ax.set(
        xlabel="Training step (gradient updates)",
        ylabel="Mean environment reward (composite)",
        title=f"GRPO training reward over time — {label}",
    )
    ax.grid(alpha=0.3)
    fig.tight_layout()
    target = out_dir / "training_reward_over_steps.png"
    fig.savefig(target, dpi=130)
    plt.close(fig)
    print(f" → {target}")
def plot_per_rubric_breakdown(history, out_dir: Path, label: str):
    """Write ``per_rubric_breakdown.png``: composite reward plus each rubric, per step.

    Up to five series share one axes: ``reward/composite_mean`` and the four
    rubric means (``r_result``, ``r_cricket``, ``r_behavior``, ``r_validity``).
    Each series keeps only rows where its metric is present and not None;
    empty series are left out of the figure and legend. A point's x value is
    the row's ``_step``, else its ``step``, else the point's 0-based index
    within that series. When every series is empty, a notice is printed and
    no file is written.
    """
    colors = {
        "reward/composite_mean": "#000000",
        "reward/r_result_mean": "#cc0000",
        "reward/r_cricket_mean": "#0066cc",
        "reward/r_behavior_mean": "#009900",
        "reward/r_validity_mean": "#9900cc",
    }
    rubrics = tuple(colors)  # plotting and legend order
    series = {r: [] for r in rubrics}
    steps_per = {r: [] for r in rubrics}
    for row in history:
        for r in rubrics:
            if r in row and row[r] is not None:
                # Compute the fallback index *before* appending so it is
                # 0-based, like the other plot_* functions.
                steps_per[r].append(row.get("_step", row.get("step", len(series[r]))))
                series[r].append(row[r])
    if not any(series.values()):
        print(" no per-rubric metrics found, skipping")
        return
    fig, ax = plt.subplots(figsize=(9, 5))
    for r in rubrics:
        if series[r]:
            ax.plot(steps_per[r], series[r], marker="o", markersize=3, linewidth=1.3,
                    label=r.replace("reward/", "").replace("_mean", ""),
                    color=colors[r])
    ax.set_xlabel("Training step (gradient updates)")
    ax.set_ylabel("Mean reward")
    ax.set_title(f"Per-rubric reward breakdown — {label}")
    ax.legend(loc="best", fontsize=9)
    ax.grid(alpha=0.3)
    fig.tight_layout()
    out_path = out_dir / "per_rubric_breakdown.png"
    fig.savefig(out_path, dpi=130)
    plt.close(fig)
    print(f" → {out_path}")
def plot_tool_call_frequency(history, out_dir: Path, label: str):
    """Write ``tool_call_frequency.png`` (mean tool calls per rollout, per step).

    Only rows carrying a non-None ``tools/call_frequency`` are plotted. A
    point's x value is the row's ``_step``, else its ``step``, else the
    point's 0-based position among the plotted rows. When no row qualifies,
    a notice is printed and no file is written.
    """
    metric = "tools/call_frequency"
    kept = [row for row in history if metric in row and row[metric] is not None]
    if not kept:
        print(" no tools/call_frequency found, skipping")
        return
    xs = [row.get("_step", row.get("step", idx)) for idx, row in enumerate(kept)]
    ys = [row[metric] for row in kept]

    fig, ax = plt.subplots(figsize=(8, 4.5))
    ax.plot(xs, ys, marker="o", linewidth=1.5, markersize=4, color="#cc6600")
    ax.set(
        xlabel="Training step (gradient updates)",
        ylabel="Mean tool calls per rollout",
        title=f"Tool-call execution frequency (proxy for match progress) — {label}",
    )
    ax.grid(alpha=0.3)
    fig.tight_layout()
    target = out_dir / "tool_call_frequency.png"
    fig.savefig(target, dpi=130)
    plt.close(fig)
    print(f" → {target}")
def plot_completion_rate(history, out_dir: Path, label: str):
    """Write ``match_completion_rate.png`` (fraction of completed matches, per step).

    Only rows carrying a non-None ``rollout/match_completion_rate`` are
    plotted. The y-axis is fixed to [0, 1.05]. A point's x value is the row's
    ``_step``, else its ``step``, else the point's 0-based position among the
    plotted rows. When no row qualifies, a notice is printed and no file is
    written.
    """
    metric = "rollout/match_completion_rate"
    kept = [row for row in history if metric in row and row[metric] is not None]
    if not kept:
        print(" no match_completion_rate found, skipping")
        return
    xs = [row.get("_step", row.get("step", idx)) for idx, row in enumerate(kept)]
    ys = [row[metric] for row in kept]

    fig, ax = plt.subplots(figsize=(8, 4.5))
    ax.plot(xs, ys, marker="o", linewidth=1.5, markersize=4, color="#009966")
    ax.set(
        xlabel="Training step (gradient updates)",
        ylabel="Match completion rate",
        ylim=(0, 1.05),
        title=f"Fraction of rollouts that completed the full match — {label}",
    )
    ax.grid(alpha=0.3)
    fig.tight_layout()
    target = out_dir / "match_completion_rate.png"
    fig.savefig(target, dpi=130)
    plt.close(fig)
    print(f" → {target}")
def plot_before_after(baseline_json: str, trained_json: str, out_dir: Path):
    """Write ``before_after_comparison.png``: grouped bars, baseline vs trained.

    Args:
        baseline_json: Path to the baseline compare_eval JSON file.
        trained_json: Path to the trained-model compare_eval JSON file.
        out_dir: Directory the PNG is written into.

    Each file must hold a top-level ``"summary"`` object. A metric that is
    missing or null in a summary is drawn as 0. Values may be negative (e.g.
    the composite reward), so each value label sits just beyond its bar's
    end on the correct side of zero, and a zero line is drawn.

    Raises:
        KeyError: if either file has no ``"summary"`` key.
    """
    with open(baseline_json, encoding="utf-8") as f:
        b = json.load(f)
    with open(trained_json, encoding="utf-8") as f:
        t = json.load(f)
    bs, ts = b["summary"], t["summary"]
    metrics = [
        ("match_completion_rate", "Match\ncompletion rate"),
        ("win_rate_overall", "Overall\nwin rate"),
        ("mean_validity_rate", "Mean\nvalidity rate"),
        ("mean_composite_reward", "Mean composite\nreward (scaled)"),
    ]
    bvals = [bs.get(k, 0) or 0 for k, _ in metrics]
    tvals = [ts.get(k, 0) or 0 for k, _ in metrics]
    labels = [lbl for _, lbl in metrics]
    x = list(range(len(metrics)))
    width = 0.35
    fig, ax = plt.subplots(figsize=(9, 5))
    bars_b = ax.bar([xi - width / 2 for xi in x], bvals, width,
                    label="baseline (untrained)", color="#999999")
    bars_t = ax.bar([xi + width / 2 for xi in x], tvals, width,
                    label="trained (LoRA r=64)", color="#0066cc")
    for bars in (bars_b, bars_t):
        for bar in bars:
            h = bar.get_height()
            below = h < 0
            # Offset in points rather than data units so the gap is the same
            # at any y-scale; labels of negative bars go below the bar end.
            ax.annotate(f"{h:.2f}",
                        xy=(bar.get_x() + bar.get_width() / 2, h),
                        xytext=(0, -3 if below else 3),
                        textcoords="offset points",
                        ha="center", va="top" if below else "bottom",
                        fontsize=8)
    ax.axhline(0, color="black", linewidth=0.8)
    ax.set_xticks(x)
    ax.set_xticklabels(labels)
    ax.set_ylabel("Metric value")
    n_b, n_t = bs.get("n_episodes"), ts.get("n_episodes")
    if n_b is None and n_t is None:
        title = "Before vs After training"
    elif n_b == n_t:
        title = f"Before vs After training — {n_b} eval matches each"
    else:
        title = f"Before vs After training — {n_b} baseline / {n_t} trained eval matches"
    ax.set_title(title)
    ax.legend()
    ax.grid(axis="y", alpha=0.3)
    fig.tight_layout()
    out_path = out_dir / "before_after_comparison.png"
    fig.savefig(out_path, dpi=130)
    plt.close(fig)
    print(f" → {out_path}")
def main():
    """CLI entry point: parse arguments, load metrics, and write the PNGs.

    Training-curve metrics come from ``--wandb-run`` if given, otherwise from
    ``--jsonl`` (one or more paths/globs, concatenated in the order given).
    ``--compare`` independently adds the before/after bar chart. At least
    one of the three must be supplied; otherwise argparse exits with an error.
    """
    p = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--wandb-run", default=None,
                   help="WandB run path: entity/project/run_id (e.g. ptnv-s-research/huggingface/abc123)")
    # nargs="+" so an unquoted glob that the shell already expanded into
    # several paths is accepted, as well as a quoted glob.
    p.add_argument("--jsonl", nargs="+", default=None,
                   help="Local episode_stats.jsonl path(s) or glob(s)")
    p.add_argument("--output-dir", default="docs/plots",
                   help="Output directory for PNGs (default: docs/plots/)")
    p.add_argument("--label", default="warmup", help="Label suffix for plot titles")
    p.add_argument("--compare", nargs=2, metavar=("BASELINE_JSON", "TRAINED_JSON"),
                   help="Also generate before/after bar chart from two compare_eval JSON files")
    args = p.parse_args()
    if not (args.wandb_run or args.jsonl or args.compare):
        p.error("nothing to plot: give --wandb-run, --jsonl and/or --compare")

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    history = []
    if args.wandb_run:
        if args.jsonl:
            print("Both --wandb-run and --jsonl given; ignoring --jsonl")
        print(f"Loading WandB run: {args.wandb_run}")
        history, _ = _load_wandb(args.wandb_run)
        print(f" {len(history)} history rows")
    elif args.jsonl:
        print(f"Loading local jsonl: {' '.join(args.jsonl)}")
        for pattern in args.jsonl:
            history.extend(_load_jsonl(pattern))
        print(f" {len(history)} rows")
    if history:
        plot_training_reward(history, out_dir, args.label)
        plot_per_rubric_breakdown(history, out_dir, args.label)
        plot_tool_call_frequency(history, out_dir, args.label)
        plot_completion_rate(history, out_dir, args.label)
    elif args.wandb_run or args.jsonl:
        print(" no history rows loaded; skipping training-curve plots")
    if args.compare:
        plot_before_after(args.compare[0], args.compare[1], out_dir)
    print(f"\nDone — PNGs in {out_dir}/")
# Script entry point; main() returns None, so SystemExit(None) exits with status 0.
if __name__ == "__main__":
    raise SystemExit(main())