File size: 4,123 Bytes
a15535e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""Matplotlib plotting helpers — produces the 3 PNGs that go into the README.

Plots:
  1. baseline_vs_trained.png — bar/line comparison
  2. training_reward_curve.png — moving-average reward over episodes
  3. success_by_category.png — per-primitive-type success rate

Plots are rendered at 100 dpi on a 6x4 in canvas (7x4 in for the per-category chart),
label at least the value axis, and use a colour-blind-safe palette.
"""
from __future__ import annotations

from pathlib import Path
from typing import Iterable

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt  # noqa: E402

# Colour assignments shared by the charts in this module, keyed by series role.
PALETTE = dict(
    baseline="#888888",  # baseline bar in the baseline-vs-trained chart
    trained="#1F77B4",  # trained bar and the per-category success bars
    ema="#D62728",  # moving-average line on the reward curve
    raw="#1F77B4",  # per-episode line on the reward curve
)


def _moving_average(values: list[float], window: int = 10) -> list[float]:
    if not values:
        return []
    out: list[float] = []
    cumsum = 0.0
    for i, v in enumerate(values):
        cumsum += v
        if i >= window:
            cumsum -= values[i - window]
        out.append(cumsum / min(i + 1, window))
    return out


def plot_baseline_vs_trained(
    baseline_rewards: list[float],
    trained_rewards: list[float],
    out_path: str | Path,
    title: str = "ForgeEnv: Baseline vs Trained (50 eval episodes)",
) -> str:
    """Save a bar chart of mean reward per policy with a per-episode strip overlay.

    Two bars show the mean reward of the baseline and trained runs. Each
    individual episode reward is drawn as a small dot slightly to the right
    of its bar. The y-axis starts at 0 unless a mean or an episode reward is
    negative, in which case it starts at that minimum.

    Args:
        baseline_rewards: Per-episode rewards of the baseline run (may be empty).
        trained_rewards: Per-episode rewards of the trained run (may be empty).
        out_path: Destination image path; parent directories are created.
        title: Title drawn above the axes.

    Returns:
        ``out_path`` as a string.
    """
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Copy into lists so tuples and other iterables are accepted as well.
    baseline = list(baseline_rewards)
    trained = list(trained_rewards)
    # max(1, len) keeps an empty series at a mean of 0 instead of dividing by zero.
    means = [
        sum(baseline) / max(1, len(baseline)),
        sum(trained) / max(1, len(trained)),
    ]
    labels = ["Baseline (no-op)", "Trained (GRPO)"]
    colors = [PALETTE["baseline"], PALETTE["trained"]]

    fig, ax = plt.subplots(figsize=(6, 4), dpi=100)
    try:
        bars = ax.bar(labels, means, color=colors, width=0.5, alpha=0.85)
        ax.bar_label(bars, fmt="%.2f", padding=3)

        # The categorical bars sit at x = 0 and x = 1; offset the dots so they
        # appear beside the bar rather than on top of its value label.
        for x, rewards in enumerate((baseline, trained)):
            if rewards:
                xs = [x + 0.18] * len(rewards)
                ax.scatter(xs, rewards, s=8, color="black", alpha=0.4, zorder=3)

        ax.set_ylabel("Visible verifier reward")
        ax.set_title(title)
        ax.grid(axis="y", linestyle=":", alpha=0.5)
        ax.set_ylim(bottom=min([0, *means, *baseline, *trained]))
        fig.tight_layout()
        fig.savefig(out_path, dpi=100, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed;
        # pyplot keeps figures alive until they are explicitly closed.
        plt.close(fig)
    return str(out_path)


def plot_reward_curve(
    rewards: list[float],
    out_path: str | Path,
    window: int = 10,
    title: str = "ForgeEnv: Repair Agent reward over training",
) -> str:
    """Save a line chart of per-episode reward with a moving-average overlay.

    Episodes are numbered from 1 on the x-axis. The raw rewards are drawn as
    a faint line. When *rewards* is non-empty, the result of
    ``_moving_average(rewards, window=window)`` is drawn on top as a thicker line.

    Args:
        rewards: Reward of each training episode, in order (may be empty).
        out_path: Destination image path; parent directories are created.
        window: Window size passed to ``_moving_average``.
        title: Title drawn above the axes.

    Returns:
        ``out_path`` as a string.
    """
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(6, 4), dpi=100)
    try:
        xs = list(range(1, len(rewards) + 1))
        ax.plot(
            xs,
            rewards,
            color=PALETTE["raw"],
            alpha=0.35,
            linewidth=1.0,
            label="Per-episode",
        )
        if rewards:
            ax.plot(
                xs,
                _moving_average(rewards, window=window),
                color=PALETTE["ema"],
                linewidth=2.0,
                label=f"Moving avg (w={window})",
            )
        ax.set_xlabel("Episode")
        ax.set_ylabel("Visible verifier reward")
        ax.set_title(title)
        ax.legend(loc="lower right")
        ax.grid(linestyle=":", alpha=0.4)
        fig.tight_layout()
        fig.savefig(out_path, dpi=100, bbox_inches="tight")
    finally:
        # Always release the figure, even if smoothing, plotting or saving
        # failed; pyplot keeps figures alive until they are explicitly closed.
        plt.close(fig)
    return str(out_path)


def plot_success_rate_by_category(
    by_category: dict[str, list[bool]],
    out_path: str | Path,
    title: str = "ForgeEnv: Repair success by primitive type",
) -> str:
    """Save a horizontal bar chart of the success rate for each category.

    The rate of a category is the fraction of truthy entries in its outcome
    list. A category with no outcomes is drawn with a rate of 0.

    Args:
        by_category: Mapping from category name to its per-episode success flags.
        out_path: Destination image path; parent directories are created.
        title: Title drawn above the axes.

    Returns:
        ``out_path`` as a string.
    """
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    cats = list(by_category)
    # max(1, len) keeps an empty category at 0 instead of dividing by zero.
    rates = [
        sum(outcomes) / max(1, len(outcomes)) for outcomes in by_category.values()
    ]

    fig, ax = plt.subplots(figsize=(7, 4), dpi=100)
    try:
        bars = ax.barh(cats, rates, color=PALETTE["trained"], alpha=0.85)
        ax.bar_label(bars, fmt="%.2f", padding=3)
        # Upper limit slightly above 1 so a full-length bar does not end on the frame.
        ax.set_xlim(0, 1.05)
        ax.set_xlabel("Success rate (held-out: executed_cleanly)")
        ax.set_title(title)
        ax.grid(axis="x", linestyle=":", alpha=0.4)
        fig.tight_layout()
        fig.savefig(out_path, dpi=100, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed;
        # pyplot keeps figures alive until they are explicitly closed.
        plt.close(fig)
    return str(out_path)