#!/usr/bin/env python3
"""
Render ``evidence_grpo_training.png`` from a TRL ``trainer.state.log_history`` export.

After ``python -m training.train_grpo`` (or Colab), you will have
``<output_dir>/grpo_log_history.json``.

Run:
    python scripts/plot_grpo_log_history.py immunoorg-defender/grpo_log_history.json

Or default path:
    python scripts/plot_grpo_log_history.py

Requires matplotlib only.
"""

from __future__ import annotations

import json
import sys
from pathlib import Path

REPO_ROOT = Path(__file__).resolve().parent.parent


def extract_series(rows: list, key: str) -> tuple[list[int], list[float]]:
    """Return the ``(steps, values)`` series for one metric of a log history.

    Only rows that are dicts and carry a non-``None`` value under *key*
    contribute, and each value is paired with the step of *its own* row.
    Rows without a ``"step"`` entry fall back to their index in *rows*.

    Pairing per row matters because not every row logs every metric (a row
    may hold only evaluation or summary fields); slicing one shared step list
    to the metric's length would shift the x-axis for every later point.
    """
    steps: list[int] = []
    values: list[float] = []
    for index, row in enumerate(rows):
        if not isinstance(row, dict):
            continue
        value = row.get(key)
        if value is None:
            continue
        step = row.get("step")
        steps.append(int(index if step is None else step))
        values.append(float(value))
    return steps, values


def _draw_panel(ax, steps, values, fmt, ylabel, title, empty_message) -> None:
    """Plot one metric on *ax*, or show *empty_message* when fewer than two points exist."""
    if len(values) >= 2:
        ax.plot(steps, values, fmt, markersize=4)
        ax.set_xlabel("Optimization step")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
    else:
        ax.text(
            0.5,
            0.5,
            empty_message,
            ha="center",
            va="center",
            transform=ax.transAxes,
        )


def main() -> None:
    """Read the exported log history and write the two-panel training figure.

    The input path is ``sys.argv[1]`` when given, otherwise
    ``REPO_ROOT / "immunoorg-defender" / "grpo_log_history.json"``.  The figure
    is saved to ``REPO_ROOT / "evidence_grpo_training.png"``.  Exits with
    status 1 when the input file is missing or is not a JSON list.
    """
    # Imported lazily so the module can be imported without matplotlib, and the
    # non-interactive Agg backend is selected before pyplot is loaded.
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    default_json = REPO_ROOT / "immunoorg-defender" / "grpo_log_history.json"
    in_path = Path(sys.argv[1]) if len(sys.argv) > 1 else default_json
    out_png = REPO_ROOT / "evidence_grpo_training.png"

    if not in_path.is_file():
        print(f"Missing {in_path} — run training first (see README / Colab).")
        sys.exit(1)

    raw = json.loads(in_path.read_text(encoding="utf-8"))
    if not isinstance(raw, list):
        print("Expected a JSON list (log_history).")
        sys.exit(1)

    # Each metric keeps its own step axis so values stay aligned with the
    # rows they were logged in.
    loss_steps, loss = extract_series(raw, "loss")
    reward_steps, reward = extract_series(raw, "reward")

    fig, axes = plt.subplots(1, 2, figsize=(11, 4), dpi=120)
    fig.suptitle("GRPO training (ImmunoOrg defender) — exported log history", fontsize=12)

    _draw_panel(
        axes[0],
        loss_steps,
        loss,
        "b-o",
        "Loss",
        "Training loss",
        "No loss entries in log_history.\n(Re-run with TRL that logs `loss`.)",
    )
    _draw_panel(
        axes[1],
        reward_steps,
        reward,
        "g-s",
        "Reward (logged)",
        "Logged reward signal",
        "No reward entries in log_history.\n(OK — loss-only runs are common.)",
    )

    fig.tight_layout()
    fig.savefig(out_png, bbox_inches="tight")
    plt.close(fig)
    print(f"Wrote {out_png}")


if __name__ == "__main__":
    main()