""" Generate training plots from a saved trainer.state.log_history JSON. Used by HF Jobs (where there's no notebook to call trainer.state directly) and locally to regenerate plots after training without re-running. Usage: python scripts/plot_from_log.py --log outputs/.../log_history.json --out plots/ """ import argparse import json import os import matplotlib.pyplot as plt import numpy as np def series(log, *keys): """Find the first matching key across log entries; return (steps, values, key).""" for k in keys: steps, vals = [], [] for entry in log: if k in entry: steps.append(entry.get("step", len(steps))) vals.append(entry[k]) if vals: return steps, vals, k return [], [], None def main(): p = argparse.ArgumentParser() p.add_argument("--log", required=True) p.add_argument("--out", default="plots") args = p.parse_args() with open(args.log) as f: log = json.load(f) os.makedirs(args.out, exist_ok=True) # Discover keys to help debug all_keys = set() for entry in log: all_keys.update(entry.keys()) print(f"Available log keys: {sorted(all_keys)}") # 1. Training Loss steps, losses, _ = series(log, "loss", "train/loss") if losses: fig, ax = plt.subplots(figsize=(10, 5)) ax.plot(steps, losses, color="#2563eb", linewidth=1.5, alpha=0.8) ax.set_xlabel("Training Step") ax.set_ylabel("Loss") ax.set_title("GRPO Training Loss - RhythmEnv Meta-RL") ax.grid(True, alpha=0.3) plt.tight_layout() plt.savefig(f"{args.out}/training_loss.png", dpi=150) plt.close() print(f"Saved: {args.out}/training_loss.png ({len(losses)} points)") # 2. Mean Reward rs, rv, rk = series(log, "reward", "rewards/mean", "rewards/total/mean") ss, sv, _ = series(log, "reward_std", "rewards/std", "rewards/total/std") if rv: fig, ax = plt.subplots(figsize=(10, 5)) ax.plot(rs, rv, color="#16a34a", linewidth=1.5, label=f"Mean Reward ({rk})") if sv and len(sv) == len(rv): r, s = np.array(rv), np.array(sv) ax.fill_between(rs, r - s, r + s, color="#16a34a", alpha=0.15, label="+/-1 std") ax.set_xlabel("Training Step") ax.set_ylabel("Mean Total Reward") ax.set_title("GRPO Mean Reward over Training - RhythmEnv Meta-RL") ax.legend() ax.grid(True, alpha=0.3) plt.tight_layout() plt.savefig(f"{args.out}/reward_curve.png", dpi=150) plt.close() print(f"Saved: {args.out}/reward_curve.png ({len(rv)} points)") # 3. Per-Reward-Function components components = [ ("format_valid", ["rewards/format_valid/mean", "rewards/format_valid", "format_valid_reward"]), ("action_legal", ["rewards/action_legal/mean", "rewards/action_legal", "action_legal_reward"]), ("env_reward", ["rewards/env_reward/mean", "rewards/env_reward", "env_reward_reward"]), ("belief_accuracy", ["rewards/belief_accuracy/mean", "rewards/belief_accuracy", "belief_accuracy_reward"]), ] found = [] for name, keys in components: s, v, k = series(log, *keys) if v: found.append((name, s, v)) print(f" {name}: matched key '{k}'") else: print(f" {name}: NOT FOUND") if found: fig, ax = plt.subplots(figsize=(12, 6)) colors = {"format_valid": "#94a3b8", "action_legal": "#60a5fa", "env_reward": "#22c55e", "belief_accuracy": "#a855f7"} for name, s, v in found: ax.plot(s, v, color=colors.get(name, "#000"), linewidth=1.5, alpha=0.85, label=name) ax.axhline(0, color="k", linewidth=0.4) ax.set_xlabel("Training Step") ax.set_ylabel("Mean Reward Component") ax.set_title("4-Layer Reward Stack over Training (RhythmEnv Meta-RL)") ax.legend(loc="best") ax.grid(True, alpha=0.3) plt.tight_layout() plt.savefig(f"{args.out}/reward_components.png", dpi=150) plt.close() print(f"Saved: {args.out}/reward_components.png ({len(found)} components)") # 4. Belief-Accuracy curve bs, bv, _ = series(log, "rewards/belief_accuracy/mean", "rewards/belief_accuracy", "belief_accuracy_reward") if bv: fig, ax = plt.subplots(figsize=(10, 5)) ax.plot(bs, bv, color="#a855f7", linewidth=2.0, alpha=0.9, label="Belief reward") if len(bv) > 20: win = max(10, len(bv) // 30) kernel = np.ones(win) / win smooth = np.convolve(bv, kernel, mode="valid") ax.plot(bs[win - 1:], smooth, color="#7e22ce", linewidth=2.5, label=f"Rolling mean ({win}-step)") ax.axhline(0.0, color="k", linewidth=0.5, linestyle="--", alpha=0.5, label="neutral baseline") ax.set_xlabel("Training Step") ax.set_ylabel("Mean belief_accuracy reward (-0.5 to +0.5)") ax.set_title("Belief-Accuracy Reward over Training (proof agent learned to model user)") ax.legend(loc="best") ax.grid(True, alpha=0.3) plt.tight_layout() plt.savefig(f"{args.out}/belief_accuracy.png", dpi=150) plt.close() print(f"Saved: {args.out}/belief_accuracy.png ({len(bv)} points)") # 5. Comparison plot if eval_results.json is available eval_path = "eval_results.json" if os.path.exists(eval_path): with open(eval_path) as f: results = json.load(f) conditions = ["discrete-3-profiles (legacy)", "continuous-in-distribution", "continuous-OOD (generalization)"] def avg(cond, strat, key="final_score"): rs = [r[key] for r in results if r["condition"] == cond and r["strategy"] == strat] return float(np.mean(rs)) if rs else 0.0 x = np.arange(len(conditions)) width = 0.27 fig, axes = plt.subplots(1, 2, figsize=(14, 5)) rand = [avg(c, "random") for c in conditions] heur = [avg(c, "heuristic") for c in conditions] trnd = [avg(c, "model") for c in conditions] axes[0].bar(x - width, rand, width, label="Random", color="#94a3b8") axes[0].bar(x, heur, width, label="Heuristic", color="#60a5fa") axes[0].bar(x + width, trnd, width, label="Trained Qwen", color="#22c55e") axes[0].set_ylabel("Final score (0-1)") axes[0].set_title("Final score by condition") axes[0].set_xticks(x) axes[0].set_xticklabels([c.split(" ")[0] for c in conditions], fontsize=10) axes[0].legend() axes[0].grid(axis="y", alpha=0.3) rand_a = [avg(c, "random", "adaptation") for c in conditions] heur_a = [avg(c, "heuristic", "adaptation") for c in conditions] trnd_a = [avg(c, "model", "adaptation") for c in conditions] axes[1].bar(x - width, rand_a, width, label="Random", color="#94a3b8") axes[1].bar(x, heur_a, width, label="Heuristic", color="#60a5fa") axes[1].bar(x + width, trnd_a, width, label="Trained Qwen", color="#22c55e") axes[1].set_ylabel("Adaptation (late-half - early-half mean reward)") axes[1].set_title("Adaptation: did agent get better mid-episode?") axes[1].set_xticks(x) axes[1].set_xticklabels([c.split(" ")[0] for c in conditions], fontsize=10) axes[1].axhline(0, color="k", linewidth=0.5) axes[1].legend() axes[1].grid(axis="y", alpha=0.3) plt.tight_layout() plt.savefig(f"{args.out}/baseline_vs_trained.png", dpi=150) plt.close() print(f"Saved: {args.out}/baseline_vs_trained.png") print() print(f"{'Condition':<40} {'Random':>10} {'Heuristic':>10} {'Trained':>10} {'vs Heuristic':>14}") print("-" * 90) for c, r, h, t in zip(conditions, rand, heur, trnd): print(f"{c:<40} {r:>10.3f} {h:>10.3f} {t:>10.3f} {(t - h):>+14.3f}") if __name__ == "__main__": main()