# rhythm_env / scripts / plot_v3_results.py
# Last commit d64efa6 (InosLihka): "Post-deadline: full eval results + bigger plots via Git LFS"
"""Generate the headline 'student vs baselines across conditions' bar chart
from the v3 eval JSON. Output goes to plots/sft_v3_baseline_vs_trained.png.
"""
import json
import sys
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
# Input: the v3 eval results read by main() -- a JSON list of records, each
# carrying "condition", "strategy" and "final_score" keys.
EVAL_PATH = Path("outputs", "sft-v3", "eval_results_v2.json")
# Output: destination of the grouped bar chart written by main().
OUT_PATH = Path("plots", "sft_v3_baseline_vs_trained.png")
def main() -> None:
    """Plot the mean final score per (eval condition, strategy) as grouped bars.

    Reads the JSON list of eval records at ``EVAL_PATH`` (each a dict with
    ``condition``, ``strategy`` and ``final_score`` keys), averages
    ``final_score`` for every condition/strategy pair, and saves a grouped bar
    chart to ``OUT_PATH``, creating its parent directory if necessary.

    A condition/strategy pair with no records is drawn as a zero-height bar
    labelled "n/a" and reported on stderr, so missing data is not mistaken for
    a real score of 0.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    rows = json.loads(EVAL_PATH.read_text(encoding="utf-8"))

    # condition -> strategy -> list of final scores
    agg: dict = defaultdict(lambda: defaultdict(list))
    for r in rows:
        agg[r["condition"]][r["strategy"]].append(r["final_score"])

    # (condition key in the eval JSON, x-axis label), in display order.
    cond_order = [
        ("continuous-in-distribution", "in-distribution"),
        ("continuous-OOD (generalization)", "OOD generalization"),
        ("discrete-3-profiles", "discrete-3-profiles"),
    ]
    strat_order = ["random", "heuristic", "model"]
    strat_labels = ["Random", "Heuristic", "Distilled Qwen 3B"]
    strat_colors = ["#888888", "#5B8FF9", "#5AD8A6"]

    means: dict[str, list[float]] = {strat: [] for strat in strat_order}
    missing: set[tuple[str, str]] = set()
    for cond_key, _ in cond_order:
        for strat in strat_order:
            scores = agg[cond_key][strat]
            if scores:
                means[strat].append(sum(scores) / len(scores))
            else:
                # Keep the bar slot so groups stay aligned, but remember it is empty.
                means[strat].append(0.0)
                missing.add((cond_key, strat))
    for cond_key, strat in sorted(missing):
        print(
            f"WARNING: no eval records for condition={cond_key!r}, strategy={strat!r}",
            file=sys.stderr,
        )

    fig, ax = plt.subplots(figsize=(8, 5))
    x = np.arange(len(cond_order))
    width = 0.27
    # Centre each group of bars on its tick, whatever the number of strategies.
    center = (len(strat_order) - 1) / 2
    for i, (strat, label, color) in enumerate(zip(strat_order, strat_labels, strat_colors)):
        offset = (i - center) * width
        bars = ax.bar(x + offset, means[strat], width, label=label, color=color)
        # Bars come back in the same order as the x positions, i.e. cond_order.
        for (cond_key, _), bar in zip(cond_order, bars):
            height = bar.get_height()
            text = "n/a" if (cond_key, strat) in missing else f"{height:.3f}"
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                height + 0.005,
                text,
                ha="center", va="bottom", fontsize=9,
            )

    # Only print the headline claim when the computed means actually support it.
    comparable = all(
        (cond_key, strat) not in missing
        for cond_key, _ in cond_order
        for strat in ("model", "heuristic")
    )
    model_wins = comparable and all(
        m > h for m, h in zip(means["model"], means["heuristic"])
    )
    title = (
        "RhythmEnv: Distilled Qwen 3B beats heuristic on all 3 conditions"
        if model_wins
        else "RhythmEnv: Distilled Qwen 3B vs baselines by eval condition"
    )

    ax.set_xlabel("Eval condition")
    ax.set_ylabel("Final score (v2 grader, 0–1)")
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels([label for _, label in cond_order])
    top = max(max(v) for v in means.values())
    # 15% headroom for the value labels; if every bar is empty, fall back to the
    # grader's 0–1 range instead of a degenerate (0, 0) axis.
    ax.set_ylim(0, top * 1.15 if top > 0 else 1.0)
    ax.grid(axis="y", alpha=0.3)
    ax.legend(loc="upper right", framealpha=0.95)
    plt.tight_layout()

    # savefig does not create missing directories.
    OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(OUT_PATH, dpi=120, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {OUT_PATH} ({OUT_PATH.stat().st_size // 1024} KB)")


if __name__ == "__main__":
    main()