File size: 2,397 Bytes
d64efa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""Generate the headline 'student vs baselines across conditions' bar chart
from the v3 eval JSON. Output goes to plots/sft_v3_baseline_vs_trained.png.
"""

import json
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# Input: JSON file that main() loads and aggregates. Relative path, so it is
# resolved against the current working directory.
EVAL_PATH = Path("outputs/sft-v3/eval_results_v2.json")
# Output: PNG file that main() writes the bar chart to (also a relative path).
OUT_PATH = Path("plots/sft_v3_baseline_vs_trained.png")


def main() -> None:
    """Render the grouped bar chart of mean final score per condition/strategy.

    Reads ``EVAL_PATH`` as JSON, expecting an iterable of records that each
    carry ``"condition"``, ``"strategy"`` and ``"final_score"`` keys. Scores
    are averaged per (condition, strategy) pair, drawn as one bar group per
    condition, and the figure is written to ``OUT_PATH``. The output
    directory is created if it does not exist yet.

    Raises:
        FileNotFoundError: if ``EVAL_PATH`` does not exist.
        json.JSONDecodeError: if ``EVAL_PATH`` is not valid JSON.
        KeyError: if a record lacks one of the three expected keys.
    """
    # JSON is UTF-8 by spec; do not depend on the platform's locale encoding.
    rows = json.loads(EVAL_PATH.read_text(encoding="utf-8"))

    # agg[condition][strategy] -> list of final scores
    agg: dict = defaultdict(lambda: defaultdict(list))
    for r in rows:
        agg[r["condition"]][r["strategy"]].append(r["final_score"])

    # Order conditions for display: (key used in the eval JSON, tick label)
    cond_order = [
        ("continuous-in-distribution", "in-distribution"),
        ("continuous-OOD (generalization)", "OOD generalization"),
        ("discrete-3-profiles", "discrete-3-profiles"),
    ]
    strat_order = ["random", "heuristic", "model"]
    strat_labels = ["Random", "Heuristic", "Distilled Qwen 3B"]
    strat_colors = ["#888888", "#5B8FF9", "#5AD8A6"]

    # means[strategy] is aligned index-for-index with cond_order.
    means = {strat: [] for strat in strat_order}
    for cond_key, _ in cond_order:
        for strat in strat_order:
            scores = agg[cond_key][strat]
            # A (condition, strategy) pair with no records is plotted as 0.0.
            means[strat].append(sum(scores) / len(scores) if scores else 0.0)

    fig, ax = plt.subplots(figsize=(8, 5))
    x = np.arange(len(cond_order))
    width = 0.27

    for i, (strat, label, color) in enumerate(zip(strat_order, strat_labels, strat_colors)):
        # With three strategies, (i - 1) centres the middle bar on the tick.
        offset = (i - 1) * width
        bars = ax.bar(x + offset, means[strat], width, label=label, color=color)
        for bar in bars:
            # Print the bar's value just above its top edge.
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() + 0.005,
                f"{bar.get_height():.3f}",
                ha="center", va="bottom", fontsize=9,
            )

    ax.set_xlabel("Eval condition")
    ax.set_ylabel("Final score (v2 grader, 0–1)")
    # NOTE(review): the title states a result that is not checked against the
    # data loaded above -- confirm it still holds when the eval file changes.
    ax.set_title("RhythmEnv: Distilled Qwen 3B beats heuristic on all 3 conditions")
    ax.set_xticks(x)
    ax.set_xticklabels([label for _, label in cond_order])

    # Leave 15% headroom for the value labels. If every mean is 0 the product
    # would give identical lower/upper limits, so fall back to a fixed top.
    top = max(max(v) for v in means.values())
    ax.set_ylim(0, top * 1.15 if top > 0 else 1.0)
    ax.grid(axis="y", alpha=0.3)
    ax.legend(loc="upper right", framealpha=0.95)
    plt.tight_layout()

    # savefig does not create missing directories; make sure the target exists.
    OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(OUT_PATH, dpi=120, bbox_inches="tight")
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f"Saved {OUT_PATH} ({OUT_PATH.stat().st_size // 1024} KB)")


# Generate the chart only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()