""" Generate labeled PNG plots for the README from a WandB run OR from local episode_stats.jsonl files. Usage: # From a WandB run id (preferred — uses the per-step rebalanced metrics) python scripts/generate_training_plots.py \\ --wandb-run ptnv-s-research/huggingface/ \\ --output-dir docs/plots/ # From local episode_stats.jsonl (faster, no API call) python scripts/generate_training_plots.py \\ --jsonl logs/run_*/episode_stats.jsonl \\ --output-dir docs/plots/ Generates (with axis labels + units): docs/plots/training_reward_over_steps.png docs/plots/per_rubric_breakdown.png docs/plots/tool_call_frequency.png docs/plots/match_completion_rate.png docs/plots/before_after_comparison.png (if --compare given) """ import argparse import glob import json import os from pathlib import Path from typing import Any import matplotlib matplotlib.use("Agg") # headless import matplotlib.pyplot as plt def _load_jsonl(path: str) -> list[dict[str, Any]]: rows = [] paths = glob.glob(path) if "*" in path else [path] for p in paths: with open(p) as f: for line in f: line = line.strip() if line: try: rows.append(json.loads(line)) except json.JSONDecodeError: continue return rows def _load_wandb(run_path: str) -> tuple[list[dict[str, Any]], dict[str, Any]]: """Returns (history, config). Requires `pip install wandb` and login.""" try: import wandb except ImportError: raise RuntimeError("wandb not installed. pip install wandb") api = wandb.Api() run = api.run(run_path) history = list(run.history(samples=10000)) return history, run.config def plot_training_reward(history, out_dir: Path, label: str): steps, rewards = [], [] for row in history: if "rewards/environment_reward/mean" in row and row["rewards/environment_reward/mean"] is not None: steps.append(row.get("_step", row.get("step", len(steps)))) rewards.append(row["rewards/environment_reward/mean"]) if not rewards: print(" no environment_reward/mean found, skipping") return fig, ax = plt.subplots(figsize=(8, 4.5)) ax.plot(steps, rewards, marker="o", linewidth=1.5, markersize=4, color="#0066cc") ax.set_xlabel("Training step (gradient updates)") ax.set_ylabel("Mean environment reward (composite)") ax.set_title(f"GRPO training reward over time — {label}") ax.grid(alpha=0.3) fig.tight_layout() out_path = out_dir / "training_reward_over_steps.png" fig.savefig(out_path, dpi=130) plt.close(fig) print(f" → {out_path}") def plot_per_rubric_breakdown(history, out_dir: Path, label: str): """Plot the per-step means of all 4 rubrics on one axes.""" rubrics = ("reward/composite_mean", "reward/r_result_mean", "reward/r_cricket_mean", "reward/r_behavior_mean", "reward/r_validity_mean") series = {r: [] for r in rubrics} steps_per = {r: [] for r in rubrics} for row in history: for r in rubrics: if r in row and row[r] is not None: series[r].append(row[r]) steps_per[r].append(row.get("_step", row.get("step", len(series[r])))) if not any(series.values()): print(" no per-rubric metrics found, skipping") return fig, ax = plt.subplots(figsize=(9, 5)) colors = {"reward/composite_mean": "#000", "reward/r_result_mean": "#cc0000", "reward/r_cricket_mean": "#0066cc", "reward/r_behavior_mean": "#009900", "reward/r_validity_mean": "#9900cc"} for r in rubrics: if series[r]: ax.plot(steps_per[r], series[r], marker="o", markersize=3, linewidth=1.3, label=r.replace("reward/", "").replace("_mean", ""), color=colors[r]) ax.set_xlabel("Training step (gradient updates)") ax.set_ylabel("Mean reward") ax.set_title(f"Per-rubric reward breakdown — {label}") ax.legend(loc="best", fontsize=9) ax.grid(alpha=0.3) fig.tight_layout() out_path = out_dir / "per_rubric_breakdown.png" fig.savefig(out_path, dpi=130) plt.close(fig) print(f" → {out_path}") def plot_tool_call_frequency(history, out_dir: Path, label: str): steps, freq = [], [] for row in history: if "tools/call_frequency" in row and row["tools/call_frequency"] is not None: steps.append(row.get("_step", row.get("step", len(steps)))) freq.append(row["tools/call_frequency"]) if not freq: print(" no tools/call_frequency found, skipping") return fig, ax = plt.subplots(figsize=(8, 4.5)) ax.plot(steps, freq, marker="o", linewidth=1.5, markersize=4, color="#cc6600") ax.set_xlabel("Training step (gradient updates)") ax.set_ylabel("Mean tool calls per rollout") ax.set_title(f"Tool-call execution frequency (proxy for match progress) — {label}") ax.grid(alpha=0.3) fig.tight_layout() out_path = out_dir / "tool_call_frequency.png" fig.savefig(out_path, dpi=130) plt.close(fig) print(f" → {out_path}") def plot_completion_rate(history, out_dir: Path, label: str): steps, rate = [], [] for row in history: if "rollout/match_completion_rate" in row and row["rollout/match_completion_rate"] is not None: steps.append(row.get("_step", row.get("step", len(steps)))) rate.append(row["rollout/match_completion_rate"]) if not rate: print(" no match_completion_rate found, skipping") return fig, ax = plt.subplots(figsize=(8, 4.5)) ax.plot(steps, rate, marker="o", linewidth=1.5, markersize=4, color="#009966") ax.set_xlabel("Training step (gradient updates)") ax.set_ylabel("Match completion rate") ax.set_ylim(0, 1.05) ax.set_title(f"Fraction of rollouts that completed the full match — {label}") ax.grid(alpha=0.3) fig.tight_layout() out_path = out_dir / "match_completion_rate.png" fig.savefig(out_path, dpi=130) plt.close(fig) print(f" → {out_path}") def plot_before_after(baseline_json: str, trained_json: str, out_dir: Path): """Bar chart comparing baseline vs trained on key eval metrics.""" with open(baseline_json) as f: b = json.load(f) with open(trained_json) as f: t = json.load(f) bs, ts = b["summary"], t["summary"] metrics = [ ("match_completion_rate", "Match\ncompletion rate"), ("win_rate_overall", "Overall\nwin rate"), ("mean_validity_rate", "Mean\nvalidity rate"), ("mean_composite_reward", "Mean composite\nreward (scaled)"), ] bvals = [bs.get(k, 0) or 0 for k, _ in metrics] tvals = [ts.get(k, 0) or 0 for k, _ in metrics] labels = [lbl for _, lbl in metrics] x = range(len(metrics)) fig, ax = plt.subplots(figsize=(9, 5)) width = 0.35 bars_b = ax.bar([xi - width/2 for xi in x], bvals, width, label="baseline (untrained)", color="#999") bars_t = ax.bar([xi + width/2 for xi in x], tvals, width, label="trained (LoRA r=64)", color="#0066cc") for bars in (bars_b, bars_t): for bar in bars: h = bar.get_height() ax.text(bar.get_x() + bar.get_width()/2, h + 0.01, f"{h:.2f}", ha="center", fontsize=8) ax.set_xticks(list(x)) ax.set_xticklabels(labels) ax.set_ylabel("Metric value") ax.set_title(f"Before vs After training — {bs['n_episodes']} eval matches each") ax.legend() ax.grid(axis="y", alpha=0.3) fig.tight_layout() out_path = out_dir / "before_after_comparison.png" fig.savefig(out_path, dpi=130) plt.close(fig) print(f" → {out_path}") def main(): p = argparse.ArgumentParser() p.add_argument("--wandb-run", default=None, help="WandB run path: entity/project/run_id (e.g. ptnv-s-research/huggingface/abc123)") p.add_argument("--jsonl", default=None, help="Local episode_stats.jsonl path (or glob)") p.add_argument("--output-dir", default="docs/plots", help="Output directory for PNGs (default: docs/plots/)") p.add_argument("--label", default="warmup", help="Label suffix for plot titles") p.add_argument("--compare", nargs=2, metavar=("BASELINE_JSON", "TRAINED_JSON"), help="Also generate before/after bar chart from two compare_eval JSON files") args = p.parse_args() out_dir = Path(args.output_dir) out_dir.mkdir(parents=True, exist_ok=True) history = [] if args.wandb_run: print(f"Loading WandB run: {args.wandb_run}") history, _ = _load_wandb(args.wandb_run) print(f" {len(history)} history rows") elif args.jsonl: print(f"Loading local jsonl: {args.jsonl}") history = _load_jsonl(args.jsonl) print(f" {len(history)} rows") if history: plot_training_reward(history, out_dir, args.label) plot_per_rubric_breakdown(history, out_dir, args.label) plot_tool_call_frequency(history, out_dir, args.label) plot_completion_rate(history, out_dir, args.label) if args.compare: plot_before_after(args.compare[0], args.compare[1], out_dir) print(f"\nDone — PNGs in {out_dir}/") if __name__ == "__main__": main()