#!/usr/bin/env python3
"""
Render ``evidence_grpo_training.png`` from a TRL ``trainer.state.log_history`` export.

After ``python -m training.train_grpo`` (or Colab), you will have
``<output_dir>/grpo_log_history.json``. Run:

  python scripts/plot_grpo_log_history.py immunoorg-defender/grpo_log_history.json

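Or, with no argument, read the default path
(``immunoorg-defender/grpo_log_history.json`` under the repository root):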

  python scripts/plot_grpo_log_history.py

Requires only matplotlib. The PNG is written to the repository root.
"""

from __future__ import annotations

import json
import sys
from pathlib import Path

REPO_ROOT = Path(__file__).resolve().parent.parent
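

# Illustrative sketch, not part of the original script: ``extract_series`` is a
# hypothetical helper showing one way to pull a single metric out of
# ``log_history`` while keeping every value paired with its own step. It
# assumes rows are dicts that may or may not carry the requested key, and it
# falls back to the row index when a row has no "step" entry.
def extract_series(rows: list, key: str) -> tuple[list[int], list[float]]:
    xs: list[int] = []
    ys: list[float] = []
    for i, row in enumerate(rows):
        # Skip non-dict rows and rows that do not log this metric.
        if not isinstance(row, dict) or row.get(key) is None:
            continue
        step = row.get("step")
        xs.append(i if step is None else int(step))
        ys.append(float(row[key]))
    return xs, ys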


def main() -> None:
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    default_json = REPO_ROOT / "immunoorg-defender" / "grpo_log_history.json"
    in_path = Path(sys.argv[1]) if len(sys.argv) > 1 else default_json
    out_png = REPO_ROOT / "evidence_grpo_training.png"

    if not in_path.is_file():
        print(f"Missing {in_path} — run training first (see README / Colab).")
        sys.exit(1)

    raw = json.loads(in_path.read_text(encoding="utf-8"))
    if not isinstance(raw, list):
        print("Expected a JSON list (log_history).")
        sys.exit(1)

    # Track steps per metric. Not every log row carries every key (for
    # example, a summary row may have no "loss"), so slicing one shared step
    # list would pair values with the wrong x positions.
    loss_steps: list[int] = []
    loss: list[float] = []
    reward_steps: list[int] = []
    reward: list[float] = []

    for i, row in enumerate(raw):
        if not isinstance(row, dict):
            continue
        step = row.get("step")
        step = i if step is None else int(step)  # fall back to row index
        if row.get("loss") is not None:
            loss_steps.append(step)
            loss.append(float(row["loss"]))
        if row.get("reward") is not None:
            reward_steps.append(step)
            reward.append(float(row["reward"]))

    fig, axes = plt.subplots(1, 2, figsize=(11, 4), dpi=120)
    fig.suptitle("GRPO training (ImmunoOrg defender): exported log history", fontsize=12)

    # A line plot needs at least two points; otherwise show a placeholder.
    if len(loss) >= 2:
        axes[0].plot(loss_steps, loss, "b-o", markersize=4)
        axes[0].set_xlabel("Optimization step")
        axes[0].set_ylabel("Loss")
        axes[0].set_title("Training loss")
        axes[0].grid(True, alpha=0.3)
    else:
        axes[0].text(
            0.5,
            0.5,
            "No loss entries in log_history.\n(Re-run with TRL that logs `loss`.)",
            ha="center",
            va="center",
            transform=axes[0].transAxes,
        )

    if len(reward) >= 2:
        axes[1].plot(reward_steps, reward, "g-s", markersize=4)
        axes[1].set_xlabel("Optimization step")
        axes[1].set_ylabel("Reward (logged)")
        axes[1].set_title("Logged reward signal")
        axes[1].grid(True, alpha=0.3)
    else:
        axes[1].text(
            0.5,
            0.5,
            "No reward entries in log_history.\n(OK — loss-only runs are common.)",
            ha="center",
            va="center",
            transform=axes[1].transAxes,
        )

    fig.tight_layout()
    fig.savefig(out_png, bbox_inches="tight")
    print(f"Wrote {out_png}")


if __name__ == "__main__":
    main()