InosLihka Claude Opus 4.7 (1M context) commited on
Commit
b5ac530
·
1 Parent(s): 786249b

notebook: add belief-accuracy + reward-components plots

Browse files

The original plot cell only generated training_loss.png and reward_curve.png.
For meta-RL submission, the most important plot is missing: how the agent's
belief_accuracy reward evolved over training. That curve directly proves
the meta-learning thesis (agent learns to model the user from observation).

Added two plots:
- plots/reward_components.png: all 4 reward functions overlaid over training
(format_valid, action_legal, env_reward, belief_accuracy) so you can see
which signals were gradient-providing
- plots/belief_accuracy.png: focused belief reward with rolling mean and
neutral-baseline reference line

Plot generation is defensive: it discovers all log_history keys and tries
multiple TRL key conventions (rewards/X/mean, rewards/X, X) since these
vary by TRL version. Prints "Available log keys" so the user can debug
if any series isn't found.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

Files changed (1) hide show
  1. training/RhythmEnv_GRPO_Training.ipynb +109 -54
training/RhythmEnv_GRPO_Training.ipynb CHANGED
@@ -328,77 +328,132 @@
328
  "metadata": {},
329
  "outputs": [],
330
  "source": [
 
 
 
 
331
  "import matplotlib.pyplot as plt\n",
332
- "import json\n",
333
- "import os\n",
334
  "\n",
335
- "# Extract training logs from trainer\n",
336
  "log_history = trainer.state.log_history\n",
 
337
  "\n",
338
- "steps = []\n",
339
- "losses = []\n",
340
- "rewards = []\n",
341
- "reward_stds = []\n",
342
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
  "for entry in log_history:\n",
344
- " if \"loss\" in entry:\n",
345
- " steps.append(entry.get(\"step\", 0))\n",
346
- " losses.append(entry[\"loss\"])\n",
347
- " if \"reward\" in entry:\n",
348
- " rewards.append(entry[\"reward\"])\n",
349
- " if \"reward_std\" in entry:\n",
350
- " reward_stds.append(entry[\"reward_std\"])\n",
351
- "\n",
352
- "# Also try rewards/mean key used by some TRL versions\n",
353
- "if not rewards:\n",
354
- " for entry in log_history:\n",
355
- " if \"rewards/mean\" in entry:\n",
356
- " rewards.append(entry[\"rewards/mean\"])\n",
357
- " if \"rewards/std\" in entry:\n",
358
- " reward_stds.append(entry[\"rewards/std\"])\n",
359
- "\n",
360
- "os.makedirs(\"plots\", exist_ok=True)\n",
361
  "\n",
362
  "# --- Plot 1: Training Loss ---\n",
363
- "fig, ax = plt.subplots(figsize=(10, 5))\n",
364
- "ax.plot(steps[:len(losses)], losses, color=\"#2563eb\", linewidth=1.5, alpha=0.8)\n",
365
- "ax.set_xlabel(\"Training Step\", fontsize=12)\n",
366
- "ax.set_ylabel(\"Loss\", fontsize=12)\n",
367
- "ax.set_title(\"GRPO Training Loss \u00e2\u20ac\u201d RhythmEnv Life Simulator\", fontsize=14)\n",
368
- "ax.grid(True, alpha=0.3)\n",
369
- "plt.tight_layout()\n",
370
- "plt.savefig(\"plots/training_loss.png\", dpi=150)\n",
371
- "plt.show()\n",
372
- "print(\"Saved: plots/training_loss.png\")\n",
 
 
373
  "\n",
374
- "# --- Plot 2: Mean Reward ---\n",
375
- "if rewards:\n",
 
 
376
  " fig, ax = plt.subplots(figsize=(10, 5))\n",
377
- " reward_steps = steps[:len(rewards)]\n",
378
- " ax.plot(reward_steps, rewards, color=\"#16a34a\", linewidth=1.5, alpha=0.8, label=\"Mean Reward\")\n",
379
- " if reward_stds and len(reward_stds) == len(rewards):\n",
380
- " import numpy as np\n",
381
- " r = np.array(rewards)\n",
382
- " s = np.array(reward_stds)\n",
383
- " ax.fill_between(reward_steps, r - s, r + s, color=\"#16a34a\", alpha=0.15, label=\"\u00c2\u00b11 Std Dev\")\n",
384
- " ax.set_xlabel(\"Training Step\", fontsize=12)\n",
385
- " ax.set_ylabel(\"Mean Reward\", fontsize=12)\n",
386
- " ax.set_title(\"GRPO Mean Reward \u00e2\u20ac\u201d RhythmEnv Life Simulator\", fontsize=14)\n",
387
  " ax.legend()\n",
388
  " ax.grid(True, alpha=0.3)\n",
389
  " plt.tight_layout()\n",
390
  " plt.savefig(\"plots/reward_curve.png\", dpi=150)\n",
391
  " plt.show()\n",
392
- " print(\"Saved: plots/reward_curve.png\")\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  "else:\n",
394
- " print(\"No reward data in logs. Check trainer.state.log_history keys:\")\n",
395
- " if log_history:\n",
396
- " print(list(log_history[0].keys()))\n",
397
  "\n",
398
- "# Save raw log data for reference\n",
399
- "with open(\"plots/training_log.json\", \"w\") as f:\n",
400
- " json.dump(log_history, f, indent=2)\n",
401
- "print(\"Saved: plots/training_log.json\")"
402
  ]
403
  },
404
  {
 
328
  "metadata": {},
329
  "outputs": [],
330
  "source": [
331
+ "# Generate all training plots from trainer log_history\n",
332
+ "# Saves: training_loss.png, reward_curve.png, reward_components.png, belief_accuracy.png\n",
333
+ "import os, json\n",
334
+ "import numpy as np\n",
335
  "import matplotlib.pyplot as plt\n",
 
 
336
  "\n",
 
337
  "log_history = trainer.state.log_history\n",
338
+ "os.makedirs(\"plots\", exist_ok=True)\n",
339
  "\n",
340
+ "# Save raw log first (always, even if plotting fails)\n",
341
+ "with open(\"plots/training_log.json\", \"w\") as f:\n",
342
+ " json.dump(log_history, f, indent=2)\n",
 
343
  "\n",
344
+ "# Helper: extract a series across all log entries that have a key\n",
345
+ "def series(*keys):\n",
346
+ " \"\"\"Extract (steps, values) for the first matching key across log entries.\"\"\"\n",
347
+ " for k in keys:\n",
348
+ " steps, vals = [], []\n",
349
+ " for entry in log_history:\n",
350
+ " if k in entry:\n",
351
+ " steps.append(entry.get(\"step\", len(steps)))\n",
352
+ " vals.append(entry[k])\n",
353
+ " if vals:\n",
354
+ " return steps, vals, k\n",
355
+ " return [], [], None\n",
356
+ "\n",
357
+ "# Discover all log keys to help debug missing plots\n",
358
+ "all_keys = set()\n",
359
  "for entry in log_history:\n",
360
+ " all_keys.update(entry.keys())\n",
361
+ "print(f\"Available log keys: {sorted(all_keys)}\")\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
  "\n",
363
  "# --- Plot 1: Training Loss ---\n",
364
+ "steps, losses, _ = series(\"loss\", \"train/loss\")\n",
365
+ "if losses:\n",
366
+ " fig, ax = plt.subplots(figsize=(10, 5))\n",
367
+ " ax.plot(steps, losses, color=\"#2563eb\", linewidth=1.5, alpha=0.8)\n",
368
+ " ax.set_xlabel(\"Training Step\")\n",
369
+ " ax.set_ylabel(\"Loss\")\n",
370
+ " ax.set_title(\"GRPO Training Loss \u2014 RhythmEnv Meta-RL\")\n",
371
+ " ax.grid(True, alpha=0.3)\n",
372
+ " plt.tight_layout()\n",
373
+ " plt.savefig(\"plots/training_loss.png\", dpi=150)\n",
374
+ " plt.show()\n",
375
+ " print(f\"Saved: plots/training_loss.png ({len(losses)} points)\")\n",
376
  "\n",
377
+ "# --- Plot 2: Mean Reward (overall) ---\n",
378
+ "rsteps, rvals, rkey = series(\"reward\", \"rewards/mean\", \"rewards/total/mean\")\n",
379
+ "ssteps, svals, _ = series(\"reward_std\", \"rewards/std\", \"rewards/total/std\")\n",
380
+ "if rvals:\n",
381
  " fig, ax = plt.subplots(figsize=(10, 5))\n",
382
+ " ax.plot(rsteps, rvals, color=\"#16a34a\", linewidth=1.5, label=f\"Mean Reward ({rkey})\")\n",
383
+ " if svals and len(svals) == len(rvals):\n",
384
+ " r, s = np.array(rvals), np.array(svals)\n",
385
+ " ax.fill_between(rsteps, r - s, r + s, color=\"#16a34a\", alpha=0.15, label=\"\u00b11 std\")\n",
386
+ " ax.set_xlabel(\"Training Step\")\n",
387
+ " ax.set_ylabel(\"Mean Total Reward\")\n",
388
+ " ax.set_title(\"GRPO Mean Reward over Training \u2014 RhythmEnv Meta-RL\")\n",
 
 
 
389
  " ax.legend()\n",
390
  " ax.grid(True, alpha=0.3)\n",
391
  " plt.tight_layout()\n",
392
  " plt.savefig(\"plots/reward_curve.png\", dpi=150)\n",
393
  " plt.show()\n",
394
+ " print(f\"Saved: plots/reward_curve.png ({len(rvals)} points)\")\n",
395
+ "\n",
396
+ "# --- Plot 3: Per-Reward-Function Components (the 4-layer reward stack) ---\n",
397
+ "# TRL logs these as rewards/<func_name>/mean in newer versions.\n",
398
+ "components = [\n",
399
+ " (\"format_valid\", [\"rewards/format_valid/mean\", \"rewards/format_valid\", \"format_valid_reward\"]),\n",
400
+ " (\"action_legal\", [\"rewards/action_legal/mean\", \"rewards/action_legal\", \"action_legal_reward\"]),\n",
401
+ " (\"env_reward\", [\"rewards/env_reward/mean\", \"rewards/env_reward\", \"env_reward_reward\"]),\n",
402
+ " (\"belief_accuracy\", [\"rewards/belief_accuracy/mean\", \"rewards/belief_accuracy\", \"belief_accuracy_reward\"]),\n",
403
+ "]\n",
404
+ "found = []\n",
405
+ "for name, keys in components:\n",
406
+ " s, v, k = series(*keys)\n",
407
+ " if v:\n",
408
+ " found.append((name, s, v))\n",
409
+ " print(f\" {name}: matched key '{k}'\")\n",
410
+ " else:\n",
411
+ " print(f\" {name}: NOT FOUND (looked for {keys})\")\n",
412
+ "\n",
413
+ "if found:\n",
414
+ " fig, ax = plt.subplots(figsize=(12, 6))\n",
415
+ " colors = {\"format_valid\": \"#94a3b8\", \"action_legal\": \"#60a5fa\", \"env_reward\": \"#22c55e\", \"belief_accuracy\": \"#a855f7\"}\n",
416
+ " for name, s, v in found:\n",
417
+ " ax.plot(s, v, color=colors.get(name, \"#000\"), linewidth=1.5, alpha=0.85, label=name)\n",
418
+ " ax.axhline(0, color=\"k\", linewidth=0.4)\n",
419
+ " ax.set_xlabel(\"Training Step\")\n",
420
+ " ax.set_ylabel(\"Mean Reward Component\")\n",
421
+ " ax.set_title(\"4-Layer Reward Stack over Training (RhythmEnv Meta-RL)\")\n",
422
+ " ax.legend(loc=\"best\")\n",
423
+ " ax.grid(True, alpha=0.3)\n",
424
+ " plt.tight_layout()\n",
425
+ " plt.savefig(\"plots/reward_components.png\", dpi=150)\n",
426
+ " plt.show()\n",
427
+ " print(f\"Saved: plots/reward_components.png ({len(found)} components)\")\n",
428
+ "\n",
429
+ "# --- Plot 4: Belief-Accuracy Curve (THE meta-RL signal) ---\n",
430
+ "bsteps, bvals, bkey = series(\"rewards/belief_accuracy/mean\", \"rewards/belief_accuracy\", \"belief_accuracy_reward\")\n",
431
+ "if bvals:\n",
432
+ " fig, ax = plt.subplots(figsize=(10, 5))\n",
433
+ " ax.plot(bsteps, bvals, color=\"#a855f7\", linewidth=2.0, alpha=0.9, label=\"Belief reward\")\n",
434
+ " # Smoothed line (rolling mean)\n",
435
+ " if len(bvals) > 20:\n",
436
+ " win = max(10, len(bvals) // 30)\n",
437
+ " kernel = np.ones(win) / win\n",
438
+ " smooth = np.convolve(bvals, kernel, mode=\"valid\")\n",
439
+ " smooth_x = bsteps[win - 1:]\n",
440
+ " ax.plot(smooth_x, smooth, color=\"#7e22ce\", linewidth=2.5, label=f\"Rolling mean ({win}-step)\")\n",
441
+ " ax.axhline(0.0, color=\"k\", linewidth=0.5, linestyle=\"--\", alpha=0.5, label=\"neutral belief baseline (0.0)\")\n",
442
+ " ax.set_xlabel(\"Training Step\")\n",
443
+ " ax.set_ylabel(\"Mean belief_accuracy reward (\u22120.5 to +0.5)\")\n",
444
+ " ax.set_title(\"Belief-Accuracy Reward over Training\nProof the agent learned to model the user\")\n",
445
+ " ax.legend(loc=\"best\")\n",
446
+ " ax.grid(True, alpha=0.3)\n",
447
+ " plt.tight_layout()\n",
448
+ " plt.savefig(\"plots/belief_accuracy.png\", dpi=150)\n",
449
+ " plt.show()\n",
450
+ " print(f\"Saved: plots/belief_accuracy.png ({len(bvals)} points)\")\n",
451
  "else:\n",
452
+ " print(\"WARNING: belief_accuracy series not found in log_history.\")\n",
453
+ " print(\" Check the 'Available log keys' line above to find the correct key name.\")\n",
454
+ " print(\" TRL key conventions vary by version; you may need to update the 'series(...)' calls.\")\n",
455
  "\n",
456
+ "print(\"\nAll plots saved to plots/\")\n"
 
 
 
457
  ]
458
  },
459
  {