sync (2fc50a9): today's source updates (XML-only prompt, reward unclip, neg-reward on loss, pinned versions, configs reorg)

"""
Generate labeled PNG plots for the README from a WandB run OR from local
episode_stats.jsonl files.

Usage:

    # From a WandB run id (preferred — uses the per-step rebalanced metrics)
    python scripts/generate_training_plots.py \\
        --wandb-run ptnv-s-research/huggingface/<RUN_ID> \\
        --output-dir docs/plots/

    # From local episode_stats.jsonl (faster, no API call); quote the glob so
    # the shell doesn't expand it before this script can
    python scripts/generate_training_plots.py \\
        --jsonl "logs/run_*/episode_stats.jsonl" \\
        --output-dir docs/plots/

Generates (with axis labels + units):

    docs/plots/training_reward_over_steps.png
    docs/plots/per_rubric_breakdown.png
    docs/plots/tool_call_frequency.png
    docs/plots/match_completion_rate.png
    docs/plots/before_after_comparison.png   (if --compare given)
"""

import argparse
import glob
import json
from pathlib import Path
from typing import Any

import matplotlib

matplotlib.use("Agg")  # headless backend: must be selected before importing pyplot
import matplotlib.pyplot as plt


def _load_jsonl(path: str) -> list[dict[str, Any]]:
    """Load one or more JSONL files (`path` may be a glob), skipping bad lines."""
    rows: list[dict[str, Any]] = []
    paths = glob.glob(path) if "*" in path else [path]
    for p in paths:
        with open(p) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    rows.append(json.loads(line))
                except json.JSONDecodeError:
                    continue  # tolerate truncated lines from interrupted runs
    return rows
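

# Illustrative shape of one history row consumed by the plot functions below.
# The key names come from this script; the values here are made up:
#   {"_step": 12, "rewards/environment_reward/mean": 0.41,
#    "reward/r_result_mean": 0.25, "tools/call_frequency": 7.8,
#    "rollout/match_completion_rate": 0.9}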


def _load_wandb(run_path: str) -> tuple[list[dict[str, Any]], dict[str, Any]]:
    """Return (history, config). Requires `pip install wandb` and a logged-in API key."""
    try:
        import wandb
    except ImportError as exc:
        raise RuntimeError("wandb not installed. pip install wandb") from exc
    api = wandb.Api()
    run = api.run(run_path)
    # pandas=False returns one plain dict per logged step; the default DataFrame
    # would make `for row in history` iterate over column names instead of rows.
    history = run.history(samples=10000, pandas=False)
    return history, run.config
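

# The four line plots below all scan the history for one metric key and fall
# back from WandB's `_step` to a local `step` field to the row index; that
# shared logic is factored out here so each plot function only does the drawing.
def _extract_series(history: list[dict[str, Any]], key: str) -> tuple[list, list]:
    """Collect (steps, values) for `key`, skipping rows where it is missing/None."""
    steps, values = [], []
    for row in history:
        if row.get(key) is not None:
            steps.append(row.get("_step", row.get("step", len(values))))
            values.append(row[key])
    return steps, values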


def plot_training_reward(history, out_dir: Path, label: str):
    steps, rewards = _extract_series(history, "rewards/environment_reward/mean")
    if not rewards:
        print(" no environment_reward/mean found, skipping")
        return
    fig, ax = plt.subplots(figsize=(8, 4.5))
    ax.plot(steps, rewards, marker="o", linewidth=1.5, markersize=4, color="#0066cc")
    ax.set_xlabel("Training step (gradient updates)")
    ax.set_ylabel("Mean environment reward (composite)")
    ax.set_title(f"GRPO training reward over time — {label}")
    ax.grid(alpha=0.3)
    fig.tight_layout()
    out_path = out_dir / "training_reward_over_steps.png"
    fig.savefig(out_path, dpi=130)
    plt.close(fig)
    print(f" → {out_path}")


def plot_per_rubric_breakdown(history, out_dir: Path, label: str):
    """Plot the composite mean plus the four per-rubric means on one axes."""
    rubrics = ("reward/composite_mean", "reward/r_result_mean",
               "reward/r_cricket_mean", "reward/r_behavior_mean",
               "reward/r_validity_mean")
    steps_per, series = {}, {}
    for r in rubrics:
        steps_per[r], series[r] = _extract_series(history, r)
    if not any(series.values()):
        print(" no per-rubric metrics found, skipping")
        return
    fig, ax = plt.subplots(figsize=(9, 5))
    colors = {"reward/composite_mean": "#000",
              "reward/r_result_mean": "#cc0000",
              "reward/r_cricket_mean": "#0066cc",
              "reward/r_behavior_mean": "#009900",
              "reward/r_validity_mean": "#9900cc"}
    for r in rubrics:
        if series[r]:
            ax.plot(steps_per[r], series[r], marker="o", markersize=3, linewidth=1.3,
                    label=r.replace("reward/", "").replace("_mean", ""),
                    color=colors[r])
    ax.set_xlabel("Training step (gradient updates)")
    ax.set_ylabel("Mean reward")
    ax.set_title(f"Per-rubric reward breakdown — {label}")
    ax.legend(loc="best", fontsize=9)
    ax.grid(alpha=0.3)
    fig.tight_layout()
    out_path = out_dir / "per_rubric_breakdown.png"
    fig.savefig(out_path, dpi=130)
    plt.close(fig)
    print(f" → {out_path}")


def plot_tool_call_frequency(history, out_dir: Path, label: str):
    steps, freq = _extract_series(history, "tools/call_frequency")
    if not freq:
        print(" no tools/call_frequency found, skipping")
        return
    fig, ax = plt.subplots(figsize=(8, 4.5))
    ax.plot(steps, freq, marker="o", linewidth=1.5, markersize=4, color="#cc6600")
    ax.set_xlabel("Training step (gradient updates)")
    ax.set_ylabel("Mean tool calls per rollout")
    ax.set_title(f"Tool-call execution frequency (proxy for match progress) — {label}")
    ax.grid(alpha=0.3)
    fig.tight_layout()
    out_path = out_dir / "tool_call_frequency.png"
    fig.savefig(out_path, dpi=130)
    plt.close(fig)
    print(f" → {out_path}")


def plot_completion_rate(history, out_dir: Path, label: str):
    steps, rate = _extract_series(history, "rollout/match_completion_rate")
    if not rate:
        print(" no match_completion_rate found, skipping")
        return
    fig, ax = plt.subplots(figsize=(8, 4.5))
    ax.plot(steps, rate, marker="o", linewidth=1.5, markersize=4, color="#009966")
    ax.set_xlabel("Training step (gradient updates)")
    ax.set_ylabel("Match completion rate")
    ax.set_ylim(0, 1.05)  # rates live in [0, 1]; small headroom for markers
    ax.set_title(f"Fraction of rollouts that completed the full match — {label}")
    ax.grid(alpha=0.3)
    fig.tight_layout()
    out_path = out_dir / "match_completion_rate.png"
    fig.savefig(out_path, dpi=130)
    plt.close(fig)
    print(f" → {out_path}")


def plot_before_after(baseline_json: str, trained_json: str, out_dir: Path):
    """Bar chart comparing baseline vs trained on key eval metrics."""
    with open(baseline_json) as f:
        b = json.load(f)
    with open(trained_json) as f:
        t = json.load(f)
    bs, ts = b["summary"], t["summary"]
    metrics = [
        ("match_completion_rate", "Match\ncompletion rate"),
        ("win_rate_overall", "Overall\nwin rate"),
        ("mean_validity_rate", "Mean\nvalidity rate"),
        ("mean_composite_reward", "Mean composite\nreward (scaled)"),
    ]
    bvals = [bs.get(k, 0) or 0 for k, _ in metrics]
    tvals = [ts.get(k, 0) or 0 for k, _ in metrics]
    labels = [lbl for _, lbl in metrics]
    x = range(len(metrics))
    fig, ax = plt.subplots(figsize=(9, 5))
    width = 0.35
    bars_b = ax.bar([xi - width / 2 for xi in x], bvals, width,
                    label="baseline (untrained)", color="#999")
    bars_t = ax.bar([xi + width / 2 for xi in x], tvals, width,
                    label="trained (LoRA r=64)", color="#0066cc")
    # Annotate each bar with its value so the chart stays legible at README scale.
    for bars in (bars_b, bars_t):
        for bar in bars:
            h = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2, h + 0.01,
                    f"{h:.2f}", ha="center", fontsize=8)
    ax.set_xticks(list(x))
    ax.set_xticklabels(labels)
    ax.set_ylabel("Metric value")
    ax.set_title(f"Before vs After training — {bs.get('n_episodes', '?')} eval matches each")
    ax.legend()
    ax.grid(axis="y", alpha=0.3)
    fig.tight_layout()
    out_path = out_dir / "before_after_comparison.png"
    fig.savefig(out_path, dpi=130)
    plt.close(fig)
    print(f" → {out_path}")


def main():
    p = argparse.ArgumentParser()
    p.add_argument("--wandb-run", default=None,
                   help="WandB run path: entity/project/run_id (e.g. ptnv-s-research/huggingface/abc123)")
    p.add_argument("--jsonl", default=None,
                   help="Local episode_stats.jsonl path (or glob)")
    p.add_argument("--output-dir", default="docs/plots",
                   help="Output directory for PNGs (default: docs/plots/)")
    p.add_argument("--label", default="warmup", help="Label suffix for plot titles")
    p.add_argument("--compare", nargs=2, metavar=("BASELINE_JSON", "TRAINED_JSON"),
                   help="Also generate before/after bar chart from two compare_eval JSON files")
    args = p.parse_args()
    if not (args.wandb_run or args.jsonl or args.compare):
        p.error("nothing to plot: provide --wandb-run, --jsonl, or --compare")

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    history = []
    if args.wandb_run:
        print(f"Loading WandB run: {args.wandb_run}")
        history, _ = _load_wandb(args.wandb_run)
        print(f" {len(history)} history rows")
    elif args.jsonl:
        print(f"Loading local jsonl: {args.jsonl}")
        history = _load_jsonl(args.jsonl)
        print(f" {len(history)} rows")

    if history:
        plot_training_reward(history, out_dir, args.label)
        plot_per_rubric_breakdown(history, out_dir, args.label)
        plot_tool_call_frequency(history, out_dir, args.label)
        plot_completion_rate(history, out_dir, args.label)
    if args.compare:
        plot_before_after(args.compare[0], args.compare[1], out_dir)
    print(f"\nDone — PNGs in {out_dir}/")


if __name__ == "__main__":
    main()
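
# Example end-to-end invocation (file names here are hypothetical), producing
# the four training plots plus the before/after chart in one pass:
#   python scripts/generate_training_plots.py \
#       --jsonl "logs/run_warmup/episode_stats.jsonl" \
#       --label warmup \
#       --compare eval/baseline_summary.json eval/trained_summary.json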