notebook: add belief-accuracy + reward-components plots
The original plot cell only generated training_loss.png and reward_curve.png.
For the meta-RL submission, the most important plot was missing: how the agent's
belief_accuracy reward evolved over training. That curve directly supports
the meta-learning thesis (the agent learns to model the user from observation).
Added two plots:
- plots/reward_components.png: all 4 reward functions overlaid across training
(format_valid, action_legal, env_reward, belief_accuracy) so you can see
which signals were actually providing gradient
- plots/belief_accuracy.png: focused belief reward with rolling mean and
neutral-baseline reference line
Plot generation is defensive: it discovers all log_history keys and tries
multiple TRL key conventions (rewards/X/mean, rewards/X, X), since these
vary by TRL version. It prints "Available log keys" so the user can debug
if any series isn't found; the fallback pattern is sketched after this message.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
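
For reference, the key-fallback at the heart of the new cell reduces to the
sketch below. It mirrors the series() helper added in the diff, shown here with
log_history passed explicitly rather than closed over. The key names are the
conventions the cell probes (rewards/X/mean, rewards/X, bare function name),
not a guaranteed TRL API; adjust them to whatever "Available log keys" prints.

    def series(log_history, *keys):
        """Return (steps, values, matched_key) for the first key found in any entry."""
        for k in keys:
            steps, vals = [], []
            for entry in log_history:
                if k in entry:
                    steps.append(entry.get("step", len(steps)))  # fall back to index
                    vals.append(entry[k])
            if vals:
                return steps, vals, k  # first convention that matched wins
        return [], [], None

    # Probe conventions most-specific-first:
    steps, vals, key = series(trainer.state.log_history,
                              "rewards/belief_accuracy/mean",  # newer TRL
                              "rewards/belief_accuracy",       # older TRL
                              "belief_accuracy_reward")        # bare function name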
@@ -328,77 +328,132 @@
     "metadata": {},
     "outputs": [],
     "source": [
+     "# Generate all training plots from trainer log_history\n",
+     "# Saves: training_loss.png, reward_curve.png, reward_components.png, belief_accuracy.png\n",
+     "import os, json\n",
+     "import numpy as np\n",
      "import matplotlib.pyplot as plt\n",
-     "import json\n",
-     "import os\n",
      "\n",
-     "# Extract training logs from trainer\n",
      "log_history = trainer.state.log_history\n",
+     "os.makedirs(\"plots\", exist_ok=True)\n",
      "\n",
-     "…
-     "…
-     "…
-     "reward_stds = []\n",
+     "# Save raw log first (always, even if plotting fails)\n",
+     "with open(\"plots/training_log.json\", \"w\") as f:\n",
+     "    json.dump(log_history, f, indent=2)\n",
      "\n",
+     "# Helper: extract a series across all log entries that have a key\n",
+     "def series(*keys):\n",
+     "    \"\"\"Extract (steps, values, key) for the first matching key across log entries.\"\"\"\n",
+     "    for k in keys:\n",
+     "        steps, vals = [], []\n",
+     "        for entry in log_history:\n",
+     "            if k in entry:\n",
+     "                steps.append(entry.get(\"step\", len(steps)))\n",
+     "                vals.append(entry[k])\n",
+     "        if vals:\n",
+     "            return steps, vals, k\n",
+     "    return [], [], None\n",
+     "\n",
+     "# Discover all log keys to help debug missing plots\n",
+     "all_keys = set()\n",
      "for entry in log_history:\n",
-     "…
-     "…
-     "        losses.append(entry[\"loss\"])\n",
-     "    if \"reward\" in entry:\n",
-     "        rewards.append(entry[\"reward\"])\n",
-     "    if \"reward_std\" in entry:\n",
-     "        reward_stds.append(entry[\"reward_std\"])\n",
-     "\n",
-     "# Also try rewards/mean key used by some TRL versions\n",
-     "if not rewards:\n",
-     "    for entry in log_history:\n",
-     "        if \"rewards/mean\" in entry:\n",
-     "            rewards.append(entry[\"rewards/mean\"])\n",
-     "        if \"rewards/std\" in entry:\n",
-     "            reward_stds.append(entry[\"rewards/std\"])\n",
-     "\n",
-     "os.makedirs(\"plots\", exist_ok=True)\n",
+     "    all_keys.update(entry.keys())\n",
+     "print(f\"Available log keys: {sorted(all_keys)}\")\n",
      "\n",
      "# --- Plot 1: Training Loss ---\n",
-     "…
-     "…
-     "ax.…
-     "ax.…
-     "ax.…
-     "ax.…
-     "…
-     "…
-     "plt.…
-     "…
+     "steps, losses, _ = series(\"loss\", \"train/loss\")\n",
+     "if losses:\n",
+     "    fig, ax = plt.subplots(figsize=(10, 5))\n",
+     "    ax.plot(steps, losses, color=\"#2563eb\", linewidth=1.5, alpha=0.8)\n",
+     "    ax.set_xlabel(\"Training Step\")\n",
+     "    ax.set_ylabel(\"Loss\")\n",
+     "    ax.set_title(\"GRPO Training Loss \u2014 RhythmEnv Meta-RL\")\n",
+     "    ax.grid(True, alpha=0.3)\n",
+     "    plt.tight_layout()\n",
+     "    plt.savefig(\"plots/training_loss.png\", dpi=150)\n",
+     "    plt.show()\n",
+     "    print(f\"Saved: plots/training_loss.png ({len(losses)} points)\")\n",
      "\n",
-     "# --- Plot 2: Mean Reward ---\n",
-     "…
+     "# --- Plot 2: Mean Reward (overall) ---\n",
+     "rsteps, rvals, rkey = series(\"reward\", \"rewards/mean\", \"rewards/total/mean\")\n",
+     "ssteps, svals, _ = series(\"reward_std\", \"rewards/std\", \"rewards/total/std\")\n",
+     "if rvals:\n",
      "    fig, ax = plt.subplots(figsize=(10, 5))\n",
-     "…
-     "…
-     "…
-     "…
-     "…
-     "…
-     "…
-     "    ax.set_xlabel(\"Training Step\", fontsize=12)\n",
-     "    ax.set_ylabel(\"Mean Reward\", fontsize=12)\n",
-     "    ax.set_title(\"GRPO Mean Reward \u2014 RhythmEnv Life Simulator\", fontsize=14)\n",
+     "    ax.plot(rsteps, rvals, color=\"#16a34a\", linewidth=1.5, label=f\"Mean Reward ({rkey})\")\n",
+     "    if svals and len(svals) == len(rvals):\n",
+     "        r, s = np.array(rvals), np.array(svals)\n",
+     "        ax.fill_between(rsteps, r - s, r + s, color=\"#16a34a\", alpha=0.15, label=\"\u00b11 std\")\n",
+     "    ax.set_xlabel(\"Training Step\")\n",
+     "    ax.set_ylabel(\"Mean Total Reward\")\n",
+     "    ax.set_title(\"GRPO Mean Reward over Training \u2014 RhythmEnv Meta-RL\")\n",
      "    ax.legend()\n",
      "    ax.grid(True, alpha=0.3)\n",
      "    plt.tight_layout()\n",
      "    plt.savefig(\"plots/reward_curve.png\", dpi=150)\n",
      "    plt.show()\n",
-     "    print(\"Saved: plots/reward_curve.png\")\n",
+     "    print(f\"Saved: plots/reward_curve.png ({len(rvals)} points)\")\n",
+     "\n",
+     "# --- Plot 3: Per-Reward-Function Components (the 4-layer reward stack) ---\n",
+     "# TRL logs these as rewards/<func_name>/mean in newer versions.\n",
+     "components = [\n",
+     "    (\"format_valid\", [\"rewards/format_valid/mean\", \"rewards/format_valid\", \"format_valid_reward\"]),\n",
+     "    (\"action_legal\", [\"rewards/action_legal/mean\", \"rewards/action_legal\", \"action_legal_reward\"]),\n",
+     "    (\"env_reward\", [\"rewards/env_reward/mean\", \"rewards/env_reward\", \"env_reward_reward\"]),\n",
+     "    (\"belief_accuracy\", [\"rewards/belief_accuracy/mean\", \"rewards/belief_accuracy\", \"belief_accuracy_reward\"]),\n",
+     "]\n",
+     "found = []\n",
+     "for name, keys in components:\n",
+     "    s, v, k = series(*keys)\n",
+     "    if v:\n",
+     "        found.append((name, s, v))\n",
+     "        print(f\"  {name}: matched key '{k}'\")\n",
+     "    else:\n",
+     "        print(f\"  {name}: NOT FOUND (looked for {keys})\")\n",
+     "\n",
+     "if found:\n",
+     "    fig, ax = plt.subplots(figsize=(12, 6))\n",
+     "    colors = {\"format_valid\": \"#94a3b8\", \"action_legal\": \"#60a5fa\", \"env_reward\": \"#22c55e\", \"belief_accuracy\": \"#a855f7\"}\n",
+     "    for name, s, v in found:\n",
+     "        ax.plot(s, v, color=colors.get(name, \"#000\"), linewidth=1.5, alpha=0.85, label=name)\n",
+     "    ax.axhline(0, color=\"k\", linewidth=0.4)\n",
+     "    ax.set_xlabel(\"Training Step\")\n",
+     "    ax.set_ylabel(\"Mean Reward Component\")\n",
+     "    ax.set_title(\"4-Layer Reward Stack over Training (RhythmEnv Meta-RL)\")\n",
+     "    ax.legend(loc=\"best\")\n",
+     "    ax.grid(True, alpha=0.3)\n",
+     "    plt.tight_layout()\n",
+     "    plt.savefig(\"plots/reward_components.png\", dpi=150)\n",
+     "    plt.show()\n",
+     "    print(f\"Saved: plots/reward_components.png ({len(found)} components)\")\n",
+     "\n",
+     "# --- Plot 4: Belief-Accuracy Curve (THE meta-RL signal) ---\n",
+     "bsteps, bvals, bkey = series(\"rewards/belief_accuracy/mean\", \"rewards/belief_accuracy\", \"belief_accuracy_reward\")\n",
+     "if bvals:\n",
+     "    fig, ax = plt.subplots(figsize=(10, 5))\n",
+     "    ax.plot(bsteps, bvals, color=\"#a855f7\", linewidth=2.0, alpha=0.9, label=\"Belief reward\")\n",
+     "    # Smoothed line (rolling mean)\n",
+     "    if len(bvals) > 20:\n",
+     "        win = max(10, len(bvals) // 30)\n",
+     "        kernel = np.ones(win) / win\n",
+     "        smooth = np.convolve(bvals, kernel, mode=\"valid\")\n",
+     "        smooth_x = bsteps[win - 1:]\n",
+     "        ax.plot(smooth_x, smooth, color=\"#7e22ce\", linewidth=2.5, label=f\"Rolling mean ({win}-step)\")\n",
+     "    ax.axhline(0.0, color=\"k\", linewidth=0.5, linestyle=\"--\", alpha=0.5, label=\"neutral belief baseline (0.0)\")\n",
+     "    ax.set_xlabel(\"Training Step\")\n",
+     "    ax.set_ylabel(\"Mean belief_accuracy reward (\u22120.5 to +0.5)\")\n",
+     "    ax.set_title(\"Belief-Accuracy Reward over Training\\nProof the agent learned to model the user\")\n",
+     "    ax.legend(loc=\"best\")\n",
+     "    ax.grid(True, alpha=0.3)\n",
+     "    plt.tight_layout()\n",
+     "    plt.savefig(\"plots/belief_accuracy.png\", dpi=150)\n",
+     "    plt.show()\n",
+     "    print(f\"Saved: plots/belief_accuracy.png ({len(bvals)} points)\")\n",
      "else:\n",
-     "    print(\"…
-     "…
-     "…
+     "    print(\"WARNING: belief_accuracy series not found in log_history.\")\n",
+     "    print(\"  Check the 'Available log keys' line above to find the correct key name.\")\n",
+     "    print(\"  TRL key conventions vary by version; you may need to update the 'series(...)' calls.\")\n",
      "\n",
-     "…
-     "with open(\"plots/training_log.json\", \"w\") as f:\n",
-     "    json.dump(log_history, f, indent=2)\n",
-     "print(\"Saved: plots/training_log.json\")"
+     "print(\"\\nAll plots saved to plots/\")\n"
     ]
    },
    {
|