"""Generate the headline 'student vs baselines across conditions' bar chart from the v3 eval JSON. Output goes to plots/sft_v3_baseline_vs_trained.png. """ import json from collections import defaultdict from pathlib import Path import matplotlib.pyplot as plt import numpy as np EVAL_PATH = Path("outputs/sft-v3/eval_results_v2.json") OUT_PATH = Path("plots/sft_v3_baseline_vs_trained.png") def main() -> None: rows = json.loads(EVAL_PATH.read_text()) agg: dict = defaultdict(lambda: defaultdict(list)) for r in rows: agg[r["condition"]][r["strategy"]].append(r["final_score"]) # Order conditions for display cond_order = [ ("continuous-in-distribution", "in-distribution"), ("continuous-OOD (generalization)", "OOD generalization"), ("discrete-3-profiles", "discrete-3-profiles"), ] strat_order = ["random", "heuristic", "model"] strat_labels = ["Random", "Heuristic", "Distilled Qwen 3B"] strat_colors = ["#888888", "#5B8FF9", "#5AD8A6"] means = {strat: [] for strat in strat_order} for cond_key, _ in cond_order: for strat in strat_order: scores = agg[cond_key][strat] means[strat].append(sum(scores) / len(scores) if scores else 0.0) fig, ax = plt.subplots(figsize=(8, 5)) x = np.arange(len(cond_order)) width = 0.27 for i, (strat, label, color) in enumerate(zip(strat_order, strat_labels, strat_colors)): offset = (i - 1) * width bars = ax.bar(x + offset, means[strat], width, label=label, color=color) for bar in bars: ax.text( bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005, f"{bar.get_height():.3f}", ha="center", va="bottom", fontsize=9, ) ax.set_xlabel("Eval condition") ax.set_ylabel("Final score (v2 grader, 0–1)") ax.set_title("RhythmEnv: Distilled Qwen 3B beats heuristic on all 3 conditions") ax.set_xticks(x) ax.set_xticklabels([label for _, label in cond_order]) ax.set_ylim(0, max(max(v) for v in means.values()) * 1.15) ax.grid(axis="y", alpha=0.3) ax.legend(loc="upper right", framealpha=0.95) plt.tight_layout() plt.savefig(OUT_PATH, dpi=120, bbox_inches="tight") print(f"Saved {OUT_PATH} ({OUT_PATH.stat().st_size // 1024} KB)") if __name__ == "__main__": main()