# rhythm_env/scripts/plot_from_log.py
# InosLihka - feat: HF Jobs training script + plot generator (commit 73c7ea0)
"""
Generate training plots from a saved trainer.state.log_history JSON.
Used by HF Jobs (where there's no notebook to call trainer.state directly)
and locally to regenerate plots after training without re-running.
Usage:
python scripts/plot_from_log.py --log outputs/.../log_history.json --out plots/
"""
import argparse
import json
import os
import matplotlib.pyplot as plt
import numpy as np
def series(log, *keys):
    """Extract the time series for the first candidate key present in the log.

    Candidate keys are tried in the order given. For the first key that
    appears in at least one entry, returns ``(steps, values, key)`` where
    ``values`` holds that key's value from every entry containing it and
    ``steps`` holds each such entry's ``"step"`` field. An entry without a
    ``"step"`` field gets its position among the matching entries instead.

    Returns ``([], [], None)`` when none of the keys occurs in any entry.
    """
    for candidate in keys:
        matching = [record for record in log if candidate in record]
        if not matching:
            continue
        # The fallback step is the index among matching entries, not the
        # index within the whole log.
        xs = [record.get("step", position) for position, record in enumerate(matching)]
        ys = [record[candidate] for record in matching]
        return xs, ys, candidate
    return [], [], None
# Candidate log keys for each reward component, in lookup priority order.
# The belief-accuracy plot reuses the "belief_accuracy" entry so the two
# plots can never drift apart.
_COMPONENT_KEYS = [
    ("format_valid", ["rewards/format_valid/mean", "rewards/format_valid", "format_valid_reward"]),
    ("action_legal", ["rewards/action_legal/mean", "rewards/action_legal", "action_legal_reward"]),
    ("env_reward", ["rewards/env_reward/mean", "rewards/env_reward", "env_reward_reward"]),
    ("belief_accuracy", ["rewards/belief_accuracy/mean", "rewards/belief_accuracy", "belief_accuracy_reward"]),
]
_COMPONENT_COLORS = {
    "format_valid": "#94a3b8",
    "action_legal": "#60a5fa",
    "env_reward": "#22c55e",
    "belief_accuracy": "#a855f7",
}

# Evaluation conditions and strategies expected in the eval results file.
_EVAL_CONDITIONS = [
    "discrete-3-profiles (legacy)",
    "continuous-in-distribution",
    "continuous-OOD (generalization)",
]
# (strategy value in the results file, legend label, bar color)
_EVAL_STRATEGIES = [
    ("random", "Random", "#94a3b8"),
    ("heuristic", "Heuristic", "#60a5fa"),
    ("model", "Trained Qwen", "#22c55e"),
]


def _load_log(path):
    """Load the list of log entries from the JSON file at *path*.

    Accepts either a bare list of entries or a JSON object that carries
    the list under a ``"log_history"`` key.

    Raises:
        ValueError: if the file contains neither of those shapes.
    """
    with open(path, encoding="utf-8") as f:
        data = json.load(f)
    if isinstance(data, dict) and "log_history" in data:
        data = data["log_history"]
    if not isinstance(data, list):
        raise ValueError(
            f"{path}: expected a JSON list of log entries "
            f"(or an object with a 'log_history' list), got {type(data).__name__}"
        )
    return data


def _save_figure(fig, out_dir, filename):
    """Write *fig* to ``out_dir/filename`` at 150 dpi, close it, return the path."""
    path = f"{out_dir}/{filename}"
    fig.tight_layout()
    fig.savefig(path, dpi=150)
    # Close this specific figure so repeated plots do not accumulate in memory.
    plt.close(fig)
    return path


def _plot_loss(log, out_dir):
    """Plot training loss against step; does nothing if no loss key is logged."""
    steps, losses, _ = series(log, "loss", "train/loss")
    if not losses:
        return
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(steps, losses, color="#2563eb", linewidth=1.5, alpha=0.8)
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Loss")
    ax.set_title("GRPO Training Loss - RhythmEnv Meta-RL")
    ax.grid(True, alpha=0.3)
    path = _save_figure(fig, out_dir, "training_loss.png")
    print(f"Saved: {path} ({len(losses)} points)")


def _plot_reward(log, out_dir):
    """Plot mean reward against step, with a +/-1 std band when available."""
    rs, rv, rk = series(log, "reward", "rewards/mean", "rewards/total/mean")
    ss, sv, _ = series(log, "reward_std", "rewards/std", "rewards/total/std")
    if not rv:
        return
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(rs, rv, color="#16a34a", linewidth=1.5, label=f"Mean Reward ({rk})")
    # Only draw the band when the std series was logged at exactly the same
    # steps as the mean; equal lengths alone do not guarantee alignment.
    if sv and ss == rs:
        r, s = np.array(rv), np.array(sv)
        ax.fill_between(rs, r - s, r + s, color="#16a34a", alpha=0.15, label="+/-1 std")
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Mean Total Reward")
    ax.set_title("GRPO Mean Reward over Training - RhythmEnv Meta-RL")
    ax.legend()
    ax.grid(True, alpha=0.3)
    path = _save_figure(fig, out_dir, "reward_curve.png")
    print(f"Saved: {path} ({len(rv)} points)")


def _plot_components(log, out_dir):
    """Plot every reward component found in the log on one shared axis.

    Prints which key matched for each component (or NOT FOUND) to help
    debug key-name differences between runs.
    """
    found = []
    for name, keys in _COMPONENT_KEYS:
        s, v, k = series(log, *keys)
        if v:
            found.append((name, s, v))
            print(f" {name}: matched key '{k}'")
        else:
            print(f" {name}: NOT FOUND")
    if not found:
        return
    fig, ax = plt.subplots(figsize=(12, 6))
    for name, s, v in found:
        ax.plot(s, v, color=_COMPONENT_COLORS.get(name, "#000"), linewidth=1.5, alpha=0.85, label=name)
    ax.axhline(0, color="k", linewidth=0.4)
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Mean Reward Component")
    ax.set_title("4-Layer Reward Stack over Training (RhythmEnv Meta-RL)")
    ax.legend(loc="best")
    ax.grid(True, alpha=0.3)
    path = _save_figure(fig, out_dir, "reward_components.png")
    print(f"Saved: {path} ({len(found)} components)")


def _plot_belief_accuracy(log, out_dir):
    """Plot the belief-accuracy reward, plus a rolling mean for longer runs."""
    belief_keys = dict(_COMPONENT_KEYS)["belief_accuracy"]
    bs, bv, _ = series(log, *belief_keys)
    if not bv:
        return
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(bs, bv, color="#a855f7", linewidth=2.0, alpha=0.9, label="Belief reward")
    if len(bv) > 20:
        # Window is at least 10 points, growing with the run length.
        win = max(10, len(bv) // 30)
        kernel = np.ones(win) / win
        smooth = np.convolve(bv, kernel, mode="valid")
        # "valid" mode yields len(bv) - win + 1 points; align each with the
        # last step of its window.
        ax.plot(bs[win - 1:], smooth, color="#7e22ce", linewidth=2.5, label=f"Rolling mean ({win}-step)")
    ax.axhline(0.0, color="k", linewidth=0.5, linestyle="--", alpha=0.5, label="neutral baseline")
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Mean belief_accuracy reward (-0.5 to +0.5)")
    ax.set_title("Belief-Accuracy Reward over Training (proof agent learned to model user)")
    ax.legend(loc="best")
    ax.grid(True, alpha=0.3)
    path = _save_figure(fig, out_dir, "belief_accuracy.png")
    print(f"Saved: {path} ({len(bv)} points)")


def _mean_metric(results, cond, strat, key):
    """Average *key* over the result rows matching *cond* and *strat*.

    Rows missing the metric (or the condition/strategy fields) are skipped
    rather than raising. Returns 0.0 when no row matches, so a missing
    combination shows up as an empty bar.
    """
    vals = [
        r[key]
        for r in results
        if r.get("condition") == cond and r.get("strategy") == strat and key in r
    ]
    return float(np.mean(vals)) if vals else 0.0


def _plot_eval_comparison(eval_path, out_dir):
    """Draw the baseline-vs-trained bar charts and print the score table.

    Reads the evaluation results from *eval_path* and produces a two-panel
    figure: final score per condition, and adaptation per condition, each
    with one bar per strategy.
    """
    with open(eval_path, encoding="utf-8") as f:
        results = json.load(f)

    x = np.arange(len(_EVAL_CONDITIONS))
    width = 0.27
    tick_labels = [c.split(" ")[0] for c in _EVAL_CONDITIONS]
    # (metric key, y label, title, draw a zero reference line?)
    panels = [
        ("final_score", "Final score (0-1)", "Final score by condition", False),
        (
            "adaptation",
            "Adaptation (late-half - early-half mean reward)",
            "Adaptation: did agent get better mid-episode?",
            True,
        ),
    ]

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    means = {}
    for ax, (metric, ylabel, title, zero_line) in zip(axes, panels):
        means[metric] = {
            strat: [_mean_metric(results, c, strat, metric) for c in _EVAL_CONDITIONS]
            for strat, _, _ in _EVAL_STRATEGIES
        }
        for offset, (strat, label, color) in zip((-width, 0.0, width), _EVAL_STRATEGIES):
            ax.bar(x + offset, means[metric][strat], width, label=label, color=color)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.set_xticks(x)
        ax.set_xticklabels(tick_labels, fontsize=10)
        if zero_line:
            ax.axhline(0, color="k", linewidth=0.5)
        ax.legend()
        ax.grid(axis="y", alpha=0.3)
    path = _save_figure(fig, out_dir, "baseline_vs_trained.png")
    print(f"Saved: {path}")

    scores = means["final_score"]
    print()
    print(f"{'Condition':<40} {'Random':>10} {'Heuristic':>10} {'Trained':>10} {'vs Heuristic':>14}")
    print("-" * 90)
    for c, r, h, t in zip(_EVAL_CONDITIONS, scores["random"], scores["heuristic"], scores["model"]):
        print(f"{c:<40} {r:>10.3f} {h:>10.3f} {t:>10.3f} {(t - h):>+14.3f}")


def main(argv=None):
    """Generate all training plots from a saved log-history JSON file.

    Args:
        argv: Optional list of command-line arguments; defaults to
            ``sys.argv[1:]`` when ``None``.

    Command-line options:
        --log           Path to the log-history JSON (required).
        --out           Output directory for the PNG files (default "plots").
        --eval-results  Path to the evaluation results JSON; the comparison
                        plot is skipped when the file does not exist
                        (default "eval_results.json").
    """
    p = argparse.ArgumentParser()
    p.add_argument("--log", required=True)
    p.add_argument("--out", default="plots")
    p.add_argument("--eval-results", dest="eval_results", default="eval_results.json")
    args = p.parse_args(argv)

    log = _load_log(args.log)
    os.makedirs(args.out, exist_ok=True)

    # Discover keys to help debug
    all_keys = set()
    for entry in log:
        all_keys.update(entry.keys())
    print(f"Available log keys: {sorted(all_keys)}")

    # 1. Training Loss
    _plot_loss(log, args.out)
    # 2. Mean Reward
    _plot_reward(log, args.out)
    # 3. Per-Reward-Function components
    _plot_components(log, args.out)
    # 4. Belief-Accuracy curve
    _plot_belief_accuracy(log, args.out)
    # 5. Comparison plot if the eval results file is available
    if os.path.exists(args.eval_results):
        _plot_eval_comparison(args.eval_results, args.out)
# Script entry point: generate the plots only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()