Spaces:
Sleeping
Sleeping
| """Generate the headline 'student vs baselines across conditions' bar chart | |
| from the v3 eval JSON. Output goes to plots/sft_v3_baseline_vs_trained.png. | |
| """ | |
| import json | |
| from collections import defaultdict | |
| from pathlib import Path | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
# Eval records consumed by main(): a JSON list of dicts carrying
# "condition", "strategy" and "final_score" keys.
EVAL_PATH = Path("outputs") / "sft-v3" / "eval_results_v2.json"
# Destination of the rendered bar chart.
OUT_PATH = Path("plots") / "sft_v3_baseline_vs_trained.png"
def _mean_scores(rows: list, cond_keys: list, strat_order: list) -> dict:
    """Average ``final_score`` per (condition, strategy) cell.

    Args:
        rows: Eval records; each must have ``condition``, ``strategy`` and
            ``final_score`` keys (a missing key raises ``KeyError``).
        cond_keys: Condition values to report, in display order. Rows whose
            condition is not listed are ignored.
        strat_order: Strategy values to report, in display order.

    Returns:
        ``{strategy: [mean score for each condition in cond_keys]}``. A cell
        with no rows is reported as 0.0 and a warning is printed, so an empty
        bar is not silently mistaken for a genuine score of zero.
    """
    grouped: dict = defaultdict(lambda: defaultdict(list))
    for r in rows:
        grouped[r["condition"]][r["strategy"]].append(r["final_score"])

    means: dict = {strat: [] for strat in strat_order}
    for cond_key in cond_keys:
        for strat in strat_order:
            scores = grouped[cond_key][strat]
            if not scores:
                print(
                    f"WARNING: no eval rows for condition={cond_key!r}, "
                    f"strategy={strat!r}; plotting 0.0"
                )
                means[strat].append(0.0)
            else:
                means[strat].append(sum(scores) / len(scores))
    return means


def main() -> None:
    """Render the grouped bar chart of mean final score per condition and strategy.

    Reads the eval records from ``EVAL_PATH``, averages ``final_score`` for
    each (condition, strategy) pair, draws one bar group per condition, and
    writes the PNG to ``OUT_PATH`` (creating its parent directory if needed).
    """
    rows = json.loads(EVAL_PATH.read_text(encoding="utf-8"))

    # (condition value in the JSON, x-axis tick label), in display order.
    cond_order = [
        ("continuous-in-distribution", "in-distribution"),
        ("continuous-OOD (generalization)", "OOD generalization"),
        ("discrete-3-profiles", "discrete-3-profiles"),
    ]
    strat_order = ["random", "heuristic", "model"]
    strat_labels = ["Random", "Heuristic", "Distilled Qwen 3B"]
    strat_colors = ["#888888", "#5B8FF9", "#5AD8A6"]

    means = _mean_scores(rows, [key for key, _ in cond_order], strat_order)

    fig, ax = plt.subplots(figsize=(8, 5))
    x = np.arange(len(cond_order))
    width = 0.27
    n_strats = len(strat_order)
    for i, (strat, label, color) in enumerate(zip(strat_order, strat_labels, strat_colors)):
        # Center each bar group on its tick (offsets -w, 0, +w for 3 strategies).
        offset = (i - (n_strats - 1) / 2) * width
        bars = ax.bar(x + offset, means[strat], width, label=label, color=color)
        for bar in bars:
            height = bar.get_height()
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                height + 0.005,  # nudge the value label just above the bar top
                f"{height:.3f}",
                ha="center", va="bottom", fontsize=9,
            )

    # Derive the headline from the data rather than hard-coding the claim;
    # when the model wins every condition the title is the original one.
    n_conds = len(cond_order)
    wins = sum(m > h for m, h in zip(means["model"], means["heuristic"]))
    if wins == n_conds:
        title = f"RhythmEnv: Distilled Qwen 3B beats heuristic on all {n_conds} conditions"
    else:
        title = f"RhythmEnv: Distilled Qwen 3B beats heuristic on {wins}/{n_conds} conditions"

    ax.set_xlabel("Eval condition")
    ax.set_ylabel("Final score (v2 grader, 0–1)")
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels([label for _, label in cond_order])
    # 15% headroom for the value labels; fall back to 1.0 when every mean is 0
    # so matplotlib is not handed an identical (0, 0) y-range.
    top = max(max(v) for v in means.values())
    ax.set_ylim(0, top * 1.15 if top > 0 else 1.0)
    ax.grid(axis="y", alpha=0.3)
    ax.legend(loc="upper right", framealpha=0.95)
    fig.tight_layout()

    # savefig does not create missing directories.
    OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(OUT_PATH, dpi=120, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved {OUT_PATH} ({OUT_PATH.stat().st_size // 1024} KB)")
if __name__ == "__main__":
    # Render the chart only when run as a script, not when imported.
    main()