File size: 4,037 Bytes
6850dad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
Read train_metrics.jsonl from a training run and produce the demo plots:

  - reward_curve.png    : mean reward per GRPO step (the headline)
  - length_curve.png    : mean prompt tokens per GRPO step (the compression story)
  - breakdown.png       : 2x2 grid of reward, tokens, raw_task_score, length_factor

Usage:
    python training/make_plots.py --metrics outputs/grpo/train_metrics.jsonl --out-dir plots
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, List


def load(path: Path) -> List[Dict]:
    rows: List[Dict] = []
    with path.open() as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            rows.append(json.loads(line))
    return rows


def smooth(values: List[float], k: int = 10) -> List[float]:
    if not values or k <= 1:
        return values
    out = []
    window: List[float] = []
    for v in values:
        window.append(v)
        if len(window) > k:
            window.pop(0)
        out.append(sum(window) / len(window))
    return out


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser()
    p.add_argument("--metrics", required=True)
    p.add_argument("--out-dir", default="plots")
    p.add_argument("--smooth-k", type=int, default=10)
    p.add_argument("--title-suffix", default="")
    return p.parse_args()


def main() -> None:
    args = parse_args()
    import matplotlib.pyplot as plt

    rows = load(Path(args.metrics))
    if not rows:
        raise SystemExit(f"no rows in {args.metrics}")

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    steps = [r["step"] for r in rows]
    rewards = [r.get("avg_reward") or r.get("reward") or 0.0 for r in rows]
    tokens = [r.get("avg_tokens") or 0.0 for r in rows]
    raws = [r.get("avg_raw_score") or 0.0 for r in rows]
    lfs = [r.get("avg_length_factor") or 0.0 for r in rows]

    # --- reward curve ---
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.plot(steps, rewards, alpha=0.4, label="per-step")
    ax.plot(steps, smooth(rewards, args.smooth_k), linewidth=2, label=f"smoothed (k={args.smooth_k})")
    ax.set_xlabel("GRPO step")
    ax.set_ylabel("mean reward (group-averaged)")
    ax.set_title(f"Prompt Golf — Reward per step{args.title_suffix}")
    ax.grid(True, alpha=0.3)
    ax.legend()
    ax.set_ylim(bottom=0)
    fig.tight_layout()
    fig.savefig(out_dir / "reward_curve.png", dpi=150)
    plt.close(fig)

    # --- length curve ---
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.plot(steps, tokens, alpha=0.4, label="per-step")
    ax.plot(steps, smooth(tokens, args.smooth_k), linewidth=2, label=f"smoothed (k={args.smooth_k})")
    ax.set_xlabel("GRPO step")
    ax.set_ylabel("mean prompt tokens (target tokenizer)")
    ax.set_title(f"Prompt Golf — Compression emerging{args.title_suffix}")
    ax.grid(True, alpha=0.3)
    ax.legend()
    ax.set_ylim(bottom=0)
    fig.tight_layout()
    fig.savefig(out_dir / "length_curve.png", dpi=150)
    plt.close(fig)

    # --- 2x2 breakdown ---
    fig, axes = plt.subplots(2, 2, figsize=(11, 7))
    series = [
        ("reward", rewards, "mean reward"),
        ("tokens", tokens, "mean prompt tokens"),
        ("raw_task_score", raws, "raw task score"),
        ("length_factor", lfs, "length factor"),
    ]
    for ax, (name, vals, ylabel) in zip(axes.flat, series):
        ax.plot(steps, vals, alpha=0.4)
        ax.plot(steps, smooth(vals, args.smooth_k), linewidth=2)
        ax.set_title(name)
        ax.set_xlabel("step")
        ax.set_ylabel(ylabel)
        ax.grid(True, alpha=0.3)
    fig.suptitle(f"Prompt Golf — training breakdown{args.title_suffix}", fontsize=13)
    fig.tight_layout()
    fig.savefig(out_dir / "breakdown.png", dpi=150)
    plt.close(fig)

    print(f"Saved plots to {out_dir}/")
    for f in ("reward_curve.png", "length_curve.png", "breakdown.png"):
        print(f"  {out_dir / f}")


if __name__ == "__main__":
    main()