Spaces:
Sleeping
Sleeping
| """ | |
| Generate training plots from a saved trainer.state.log_history JSON. | |
| Used by HF Jobs (where there's no notebook to call trainer.state directly) | |
| and locally to regenerate plots after training without re-running. | |
| Usage: | |
| python scripts/plot_from_log.py --log outputs/.../log_history.json --out plots/ | |
| """ | |
| import argparse | |
| import json | |
| import os | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
def series(log, *keys):
    """Return (steps, values, key) for the first candidate key found in the log.

    Candidates are tried in order; the first key that produces any values
    wins.  Entries lacking a "step" field fall back to their position among
    the matched entries.  Returns ([], [], None) when nothing matches.
    """
    for candidate in keys:
        pairs = []
        for entry in log:
            if candidate in entry:
                # Fallback step = number of matches collected so far.
                pairs.append((entry.get("step", len(pairs)), entry[candidate]))
        if pairs:
            steps, vals = zip(*pairs)
            return list(steps), list(vals), candidate
    return [], [], None
def _plot_training_loss(log, out):
    """Plot 1: raw training loss vs step."""
    steps, losses, _ = series(log, "loss", "train/loss")
    if not losses:
        return
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(steps, losses, color="#2563eb", linewidth=1.5, alpha=0.8)
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Loss")
    ax.set_title("GRPO Training Loss - RhythmEnv Meta-RL")
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f"{out}/training_loss.png", dpi=150)
    plt.close()
    print(f"Saved: {out}/training_loss.png ({len(losses)} points)")


def _plot_reward_curve(log, out):
    """Plot 2: mean total reward, with a +/-1 std band when std is logged."""
    rs, rv, rk = series(log, "reward", "rewards/mean", "rewards/total/mean")
    ss, sv, _ = series(log, "reward_std", "rewards/std", "rewards/total/std")
    if not rv:
        return
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(rs, rv, color="#16a34a", linewidth=1.5, label=f"Mean Reward ({rk})")
    # Only draw the band when std aligns 1:1 with the mean series.
    if sv and len(sv) == len(rv):
        r, s = np.array(rv), np.array(sv)
        ax.fill_between(rs, r - s, r + s, color="#16a34a", alpha=0.15, label="+/-1 std")
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Mean Total Reward")
    ax.set_title("GRPO Mean Reward over Training - RhythmEnv Meta-RL")
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f"{out}/reward_curve.png", dpi=150)
    plt.close()
    print(f"Saved: {out}/reward_curve.png ({len(rv)} points)")


def _plot_reward_components(log, out):
    """Plot 3: the four per-reward-function components on one axis."""
    components = [
        ("format_valid", ["rewards/format_valid/mean", "rewards/format_valid", "format_valid_reward"]),
        ("action_legal", ["rewards/action_legal/mean", "rewards/action_legal", "action_legal_reward"]),
        ("env_reward", ["rewards/env_reward/mean", "rewards/env_reward", "env_reward_reward"]),
        ("belief_accuracy", ["rewards/belief_accuracy/mean", "rewards/belief_accuracy", "belief_accuracy_reward"]),
    ]
    found = []
    for name, keys in components:
        s, v, k = series(log, *keys)
        if v:
            found.append((name, s, v))
            print(f" {name}: matched key '{k}'")
        else:
            print(f" {name}: NOT FOUND")
    if not found:
        return
    fig, ax = plt.subplots(figsize=(12, 6))
    colors = {"format_valid": "#94a3b8", "action_legal": "#60a5fa", "env_reward": "#22c55e", "belief_accuracy": "#a855f7"}
    for name, s, v in found:
        ax.plot(s, v, color=colors.get(name, "#000"), linewidth=1.5, alpha=0.85, label=name)
    ax.axhline(0, color="k", linewidth=0.4)
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Mean Reward Component")
    ax.set_title("4-Layer Reward Stack over Training (RhythmEnv Meta-RL)")
    ax.legend(loc="best")
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f"{out}/reward_components.png", dpi=150)
    plt.close()
    print(f"Saved: {out}/reward_components.png ({len(found)} components)")


def _plot_belief_accuracy(log, out):
    """Plot 4: belief-accuracy reward curve with an optional rolling mean."""
    bs, bv, _ = series(log, "rewards/belief_accuracy/mean", "rewards/belief_accuracy", "belief_accuracy_reward")
    if not bv:
        return
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(bs, bv, color="#a855f7", linewidth=2.0, alpha=0.9, label="Belief reward")
    # Overlay a rolling mean once there are enough points to smooth.
    if len(bv) > 20:
        win = max(10, len(bv) // 30)
        kernel = np.ones(win) / win
        smooth = np.convolve(bv, kernel, mode="valid")
        ax.plot(bs[win - 1:], smooth, color="#7e22ce", linewidth=2.5, label=f"Rolling mean ({win}-step)")
    ax.axhline(0.0, color="k", linewidth=0.5, linestyle="--", alpha=0.5, label="neutral baseline")
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Mean belief_accuracy reward (-0.5 to +0.5)")
    ax.set_title("Belief-Accuracy Reward over Training (proof agent learned to model user)")
    ax.legend(loc="best")
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(f"{out}/belief_accuracy.png", dpi=150)
    plt.close()
    print(f"Saved: {out}/belief_accuracy.png ({len(bv)} points)")


def _plot_eval_comparison(eval_path, out):
    """Plot 5: baseline-vs-trained bar charts, if an eval results JSON exists."""
    if not os.path.exists(eval_path):
        return
    with open(eval_path) as f:
        results = json.load(f)
    conditions = ["discrete-3-profiles (legacy)", "continuous-in-distribution", "continuous-OOD (generalization)"]

    def avg(cond, strat, key="final_score"):
        # Mean of `key` over eval runs matching this (condition, strategy) pair.
        vals = [r[key] for r in results if r["condition"] == cond and r["strategy"] == strat]
        return float(np.mean(vals)) if vals else 0.0

    x = np.arange(len(conditions))
    width = 0.27
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    rand = [avg(c, "random") for c in conditions]
    heur = [avg(c, "heuristic") for c in conditions]
    trnd = [avg(c, "model") for c in conditions]
    axes[0].bar(x - width, rand, width, label="Random", color="#94a3b8")
    axes[0].bar(x, heur, width, label="Heuristic", color="#60a5fa")
    axes[0].bar(x + width, trnd, width, label="Trained Qwen", color="#22c55e")
    axes[0].set_ylabel("Final score (0-1)")
    axes[0].set_title("Final score by condition")
    axes[0].set_xticks(x)
    axes[0].set_xticklabels([c.split(" ")[0] for c in conditions], fontsize=10)
    axes[0].legend()
    axes[0].grid(axis="y", alpha=0.3)
    rand_a = [avg(c, "random", "adaptation") for c in conditions]
    heur_a = [avg(c, "heuristic", "adaptation") for c in conditions]
    trnd_a = [avg(c, "model", "adaptation") for c in conditions]
    axes[1].bar(x - width, rand_a, width, label="Random", color="#94a3b8")
    axes[1].bar(x, heur_a, width, label="Heuristic", color="#60a5fa")
    axes[1].bar(x + width, trnd_a, width, label="Trained Qwen", color="#22c55e")
    axes[1].set_ylabel("Adaptation (late-half - early-half mean reward)")
    axes[1].set_title("Adaptation: did agent get better mid-episode?")
    axes[1].set_xticks(x)
    axes[1].set_xticklabels([c.split(" ")[0] for c in conditions], fontsize=10)
    axes[1].axhline(0, color="k", linewidth=0.5)
    axes[1].legend()
    axes[1].grid(axis="y", alpha=0.3)
    plt.tight_layout()
    plt.savefig(f"{out}/baseline_vs_trained.png", dpi=150)
    plt.close()
    print(f"Saved: {out}/baseline_vs_trained.png")
    # Summary table: trained model vs the heuristic baseline per condition.
    print()
    print(f"{'Condition':<40} {'Random':>10} {'Heuristic':>10} {'Trained':>10} {'vs Heuristic':>14}")
    print("-" * 90)
    for c, r, h, t in zip(conditions, rand, heur, trnd):
        print(f"{c:<40} {r:>10.3f} {h:>10.3f} {t:>10.3f} {(t - h):>+14.3f}")


def main():
    """Load a log_history JSON, print its available keys, and save plot PNGs.

    Emits up to five figures into --out: training loss, mean reward
    (+/- std band), the four per-function reward components, the
    belief-accuracy curve, and -- when the --eval JSON exists -- a
    baseline-vs-trained comparison with a printed summary table.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--log", required=True, help="path to log_history.json")
    p.add_argument("--out", default="plots", help="output directory for PNGs")
    # Generalized: previously hard-coded to "eval_results.json" in the cwd;
    # the default keeps the old behavior.
    p.add_argument("--eval", default="eval_results.json",
                   help="path to eval results JSON for the comparison plot")
    args = p.parse_args()

    with open(args.log) as f:
        log = json.load(f)
    os.makedirs(args.out, exist_ok=True)

    # Discover keys to help debug (trainer versions log under different names).
    all_keys = set()
    for entry in log:
        all_keys.update(entry.keys())
    print(f"Available log keys: {sorted(all_keys)}")

    _plot_training_loss(log, args.out)
    _plot_reward_curve(log, args.out)
    _plot_reward_components(log, args.out)
    _plot_belief_accuracy(log, args.out)
    _plot_eval_comparison(args.eval, args.out)


if __name__ == "__main__":
    main()