"""Render the v2 Analyzer training curves from `logs/v2_trainer_state.json`.
Produces `plots/chakravyuh_plots/training_curves_v2.png` β€” a 4-panel
figure showing reward, loss, KL divergence, and gradient norm over the
615-step GRPO run.
This is the artifact the hackathon's "evidence you actually trained
β€” at minimum, loss and reward plots from a real run" non-negotiable
asks for. The data is the canonical TRL `trainer_state.json` log
captured during the v2 retrain (n=123 logged points at logging_steps=5).
Usage:
python eval/plot_training_curves.py
# β†’ plots/chakravyuh_plots/training_curves_v2.png
"""
from __future__ import annotations

import json
from pathlib import Path

import matplotlib.pyplot as plt


def main() -> int:
    log_path = Path("logs/v2_trainer_state.json")
    out_path = Path("plots/chakravyuh_plots/training_curves_v2.png")
    out_path.parent.mkdir(parents=True, exist_ok=True)
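
    # Fail fast with a readable message if the training log is missing
    # (e.g. a fresh clone without the run artifacts) instead of a traceback.
    if not log_path.exists():
        print(f"error: {log_path} not found – copy trainer_state.json from the v2 run first")
        return 1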
    state = json.loads(log_path.read_text(encoding="utf-8"))
    # TRL's Trainer appends a final summary entry (train_runtime, total_flos,
    # ...) to log_history with no per-step metrics; keep only step entries.
    history = [h for h in state["log_history"] if "reward" in h]
steps = [h["step"] for h in history]
reward = [h["reward"] for h in history]
reward_std = [h.get("reward_std", 0.0) for h in history]
loss = [h["loss"] for h in history]
kl = [h["kl"] for h in history]
grad_norm = [h["grad_norm"] for h in history]
    fig, axes = plt.subplots(2, 2, figsize=(12, 8), constrained_layout=True)
    fig.suptitle(
        "Chakravyuh Analyzer v2 – TRL GRPO training curves "
        f"(n={len(history)} logged steps · base = Qwen2.5-7B-Instruct + LoRA r=64)",
        fontsize=12,
        fontweight="bold",
    )

    # 1. Reward (with ±1 std band)
    ax = axes[0, 0]
    ax.plot(steps, reward, color="#1565c0", linewidth=1.8, label="Mean episode reward")
    upper = [r + s for r, s in zip(reward, reward_std)]
    lower = [r - s for r, s in zip(reward, reward_std)]
    ax.fill_between(steps, lower, upper, color="#1565c0", alpha=0.15, label="±1 std")
    ax.set_xlabel("Training step")
    ax.set_ylabel("Reward (8-rubric weighted sum)")
    ax.set_title("Episode reward – climbs and stabilises")
    ax.grid(True, alpha=0.3)
    ax.legend(loc="lower right", fontsize=9)

    # 2. Loss
    ax = axes[0, 1]
    ax.plot(steps, loss, color="#c62828", linewidth=1.8)
    ax.axhline(0.0, color="black", linestyle=":", linewidth=0.8, alpha=0.5)
    ax.set_xlabel("Training step")
    ax.set_ylabel("GRPO loss")
    ax.set_title("Loss – bounded around zero, clean (no divergence)")
    ax.grid(True, alpha=0.3)

    # 3. KL divergence – anti-collapse anchor
    ax = axes[1, 0]
    ax.plot(steps, kl, color="#558b2f", linewidth=1.8)
    ax.axhline(0.20, color="#e65100", linestyle="--", linewidth=1.0, alpha=0.7,
               label="KL guard threshold (v3 plan)")
    ax.set_xlabel("Training step")
    ax.set_ylabel("KL divergence vs reference")
    ax.set_title("KL – plateaus at 0.25–0.45 (honest read in training_diagnostics.md)")
    ax.grid(True, alpha=0.3)
    ax.legend(loc="lower right", fontsize=9)

    # 4. Gradient norm – stability indicator
    ax = axes[1, 1]
    ax.plot(steps, grad_norm, color="#6a1b9a", linewidth=1.8)
    ax.set_xlabel("Training step")
    ax.set_ylabel("Gradient norm")
    ax.set_title("Grad norm – well-behaved (no explosions)")
    ax.grid(True, alpha=0.3)

    fig.savefig(out_path, dpi=120, bbox_inches="tight")
    print(f"Wrote {out_path} ({out_path.stat().st_size:,} bytes)")
    return 0
if __name__ == "__main__":
raise SystemExit(main())