fix: notebook uses compute_grpo_reward_env, updated hyperparams, no emojis
training/opengrid_grpo_colab.ipynb CHANGED
@@ -4,17 +4,17 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"# 
+"# OpenGrid \u2014 GRPO Training Notebook\n",
 "\n",
 "**Multi-Agent RL for Power Grid Operations**\n",
 "\n",
 "This notebook trains an LLM (Qwen 2.5 1.5B) to operate a power grid using GRPO (Group Relative Policy Optimization).\n",
 "\n",
-"- **Environment**: OpenGrid 
+"- **Environment**: OpenGrid \u2014 multi-agent POMDP with safety layer & oversight agent\n",
 "- **Task**: Maintain 50 Hz frequency, prevent line overloads, avoid blackouts\n",
 "- **Training**: TRL GRPOTrainer + Unsloth 4-bit quantization\n",
 "\n",
-"
+" **Runtime**: Select `T4 GPU` from Runtime \u2192 Change runtime type"
 ]
 },
 {
@@ -51,7 +51,7 @@
 "source": [
 "import os\n",
 "\n",
-"# 
+"# UPDATE THIS with your actual repo URL\n",
 "REPO_URL = \"https://github.com/krishnagoyal099/Opengrid_env.git\"\n",
 "\n",
 "if not os.path.exists(\"opengrid\"):\n",
@@ -84,7 +84,7 @@
 " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
 " print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n",
 "else:\n",
-" print(\"
+" print(\" No GPU detected! Go to Runtime \u2192 Change runtime type \u2192 T4 GPU\")"
 ]
 },
 {
@@ -177,14 +177,14 @@
 " \"std_reward\": np.std(rewards),\n",
 " \"rewards\": rewards\n",
 " }\n",
-" print(f\"[BASELINE] {task_id}: {np.mean(rewards):.2f} 
+" print(f\"[BASELINE] {task_id}: {np.mean(rewards):.2f} \u00b1 {np.std(rewards):.2f}\")\n",
 "\n",
 "# Save baseline for later comparison\n",
 "import pickle\n",
 "os.makedirs(\"training/outputs\", exist_ok=True)\n",
 "with open(\"training/outputs/baseline_results.pkl\", \"wb\") as f:\n",
 " pickle.dump(baseline_results, f)\n",
-"print(\"\\n
+"print(\"\\n Baseline scores saved.\")"
 ]
 },
 {
@@ -222,7 +222,7 @@
 "if tokenizer.pad_token is None:\n",
 " tokenizer.pad_token = tokenizer.eos_token\n",
 "\n",
-"print(f\"
+"print(f\" Model loaded: {MODEL_NAME}\")\n",
 "print(f\" Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}\")"
 ]
 },
@@ -261,7 +261,7 @@
 "\n",
 " for t in range(min(10, task_config['max_steps'])):\n",
 " for agent_id, obs in zone_obs.items():\n",
-" # model_dump_json()
+" # model_dump_json() \u2192 json.loads() ensures all keys are strings\n",
 " obs_dict = _json.loads(obs.model_dump_json())\n",
 " prompt_text = format_observation_prompt(obs_dict, zone_name=obs.zone_name)\n",
 " messages = [\n",
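Aside: the string-keys round-trip that the `+` comment above relies on, sketched with a hypothetical pydantic v2 model (`ZoneObs` below is illustrative, not the repo's actual observation class):

```python
import json
from pydantic import BaseModel

class ZoneObs(BaseModel):  # hypothetical stand-in for the real zone observation
    zone_name: str
    bus_loads: dict[int, float]  # int keys on the Python side

obs = ZoneObs(zone_name="north", bus_loads={1: 12.5})
# model_dump_json() serializes int dict keys as JSON strings, so
# json.loads() yields uniformly string-keyed dicts.
obs_dict = json.loads(obs.model_dump_json())
print(obs_dict["bus_loads"])  # {'1': 12.5}
```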
@@ -272,7 +272,7 @@
 " messages, tokenize=False, add_generation_prompt=True\n",
 " )\n",
 " prompts.append(formatted)\n",
-" # Store as JSON string
+" # Store as JSON string \u2014 flat scalar, no schema-inference issues\n",
 " obs_contexts.append(_json.dumps(obs_dict))\n",
 "\n",
 " # Advance env with diverse random actions (no slack bus)\n",
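Aside: why a flat string column is the safe choice here. `datasets` infers a PyArrow schema from the values, and nested dicts whose keys differ across rows can break or get coerced during inference; a JSON-string column always round-trips losslessly. A minimal sketch:

```python
import json
from datasets import Dataset

# Rows with different nested keys -- the kind of payload that can trip
# Arrow schema inference if stored as raw dicts.
rows = [
    {"frequency": 50.0, "loads": {"1": 10.0}},
    {"frequency": 49.9, "loads": {"2": 7.5, "3": 1.0}},
]

ds = Dataset.from_dict({"obs_context": [json.dumps(r) for r in rows]})
restored = [json.loads(s) for s in ds["obs_context"]]
assert restored == rows  # lossless round-trip
```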
@@ -293,7 +293,7 @@
 " break\n",
 " zone_obs = result.observations\n",
 "\n",
-"print(f\"
+"print(f\" Generated {len(prompts)} training prompts\")\n",
 "print(f\"\\nSample prompt (first 400 chars):\")\n",
 "print(prompts[0][:400])"
 ]
@@ -312,21 +312,18 @@
 "outputs": [],
 "source": [
 "import json as _json\n",
-"from training.train_grpo import 
+"from training.train_grpo import compute_grpo_reward_env, extract_action\n",
 "\n",
 "def reward_fn(completions, obs_context=None, **kwargs):\n",
-" \"\"\"GRPO
-" obs_context arrives as JSON strings from the dataset column.\n",
-" \"\"\"\n",
+" \"\"\"GRPO reward function with env-grounded physics rewards.\"\"\"\n",
 " texts = []\n",
 " for c in completions:\n",
 " if isinstance(c, list):\n",
-" text = c[-1][
+" text = c[-1][\"content\"] if c else \"\"\n",
 " else:\n",
 " text = str(c)\n",
 " texts.append(text)\n",
 "\n",
-" # Deserialize JSON strings → dicts for the reward scorer\n",
 " if obs_context is None:\n",
 " batch_obs = [None] * len(texts)\n",
 " else:\n",
@@ -334,23 +331,23 @@
 " _json.loads(ctx) if isinstance(ctx, str) else ctx\n",
 " for ctx in obs_context\n",
 " ]\n",
-" return 
+" return compute_grpo_reward_env(texts, batch_obs, task_config, horizon=3)\n",
 "\n",
-"# 
+"# Sanity test\n",
 "test_rewards = reward_fn([\n",
 " '{\"bus_adjustments\": [{\"bus_id\": 1, \"delta\": 5.0}], \"topology_actions\": []}',\n",
-" 
+" \"invalid json here\",\n",
 "])\n",
 "print(f\"Test rewards: {test_rewards}\")\n",
-"assert len(test_rewards) == 2
-"print(\"
+"assert len(test_rewards) == 2\n",
+"print(\"[OK] reward_fn works\")\n"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## 9. Train with GRPO 
+"## 9. Train with GRPO "
 ]
 },
 {
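Aside: the shape contract the reward cell above depends on. TRL's `GRPOTrainer` calls each reward function with the batch's `completions` plus every extra dataset column (here `obs_context`) as a keyword argument, and expects one float per completion. A self-contained toy version, rewarding only JSON validity (the real `compute_grpo_reward_env` scores grid physics and comes from this repo's training code):

```python
import json

def toy_reward_fn(completions, obs_context=None, **kwargs):
    """One float per completion; extra dataset columns arrive as kwargs."""
    rewards = []
    for c in completions:
        # Conversational completions are lists of {"role", "content"} dicts.
        text = c[-1]["content"] if isinstance(c, list) else str(c)
        try:
            json.loads(text)
            rewards.append(1.0)   # parseable action payload
        except json.JSONDecodeError:
            rewards.append(-1.0)  # malformed output
    return rewards

print(toy_reward_fn(['{"bus_adjustments": []}', "not json"]))  # [1.0, -1.0]
```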
@@ -368,21 +365,21 @@
 "\n",
 "grpo_config = GRPOConfig(\n",
 " output_dir=\"training/outputs/grpo_checkpoints\",\n",
-" num_train_epochs=
+" num_train_epochs=3,\n",
 " per_device_train_batch_size=2,\n",
-" gradient_accumulation_steps=
-" learning_rate=
+" gradient_accumulation_steps=8,\n",
+" learning_rate=1e-5,\n",
 " logging_steps=5,\n",
 " save_steps=50,\n",
 " max_completion_length=256,\n",
-" num_generations=
+" num_generations=8,\n",
 " report_to=\"none\",\n",
 " remove_unused_columns=False,\n",
 " bf16=_bf16,\n",
 " fp16=_fp16,\n",
 ")\n",
 "\n",
-"# obs_contexts are JSON strings
+"# obs_contexts are JSON strings \u2014 PyArrow handles flat strings with no issues\n",
 "train_dataset = Dataset.from_dict({\"prompt\": prompts, \"obs_context\": obs_contexts})\n",
 "print(f\"Dataset: {len(train_dataset)} rows, columns: {train_dataset.column_names}\")\n",
 "\n",
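Aside: worked numbers for the updated hyperparameters, assuming a single Colab GPU. (Whether TRL enforces divisibility of the effective batch by `num_generations` varies by version; treat the assert as an assumption.)

```python
per_device_train_batch_size = 2
gradient_accumulation_steps = 8
num_generations = 8  # GRPO group size: completions sampled per prompt

# Completions consumed per optimizer step:
effective_batch = per_device_train_batch_size * gradient_accumulation_steps
print(f"Effective batch size: {effective_batch}")  # 16

# Assumption: the effective batch tiles evenly into GRPO groups.
assert effective_batch % num_generations == 0
print(f"GRPO groups per optimizer step: {effective_batch // num_generations}")  # 2
```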
@@ -396,11 +393,11 @@
 "\n",
 "print(f\"Training on {len(prompts)} prompts, {grpo_config.num_train_epochs} epoch(s)\")\n",
 "print(f\"Effective batch size: {grpo_config.per_device_train_batch_size * grpo_config.gradient_accumulation_steps}\")\n",
-"print(\"\\n
+"print(\"\\n Starting GRPO training...\")\n",
 "\n",
 "train_result = trainer.train()\n",
 "\n",
-"print(\"\\n
+"print(\"\\n Training complete!\")\n",
 "print(f\" Total steps: {trainer.state.global_step}\")"
 ]
 },
@@ -420,7 +417,7 @@
 "OUTPUT_PATH = \"training/outputs/trained_model\"\n",
 "trainer.save_model(OUTPUT_PATH)\n",
 "tokenizer.save_pretrained(OUTPUT_PATH)\n",
-"print(f\"
+"print(f\" Model saved to {OUTPUT_PATH}\")"
 ]
 },
 {
@@ -479,14 +476,14 @@
 " \"std_reward\": np.std(rewards),\n",
 " \"rewards\": rewards\n",
 " }\n",
-" print(f\"[TRAINED] {task_id}: {np.mean(rewards):.2f} 
+" print(f\"[TRAINED] {task_id}: {np.mean(rewards):.2f} \u00b1 {np.std(rewards):.2f}\\n\")"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"## 12. Generate Before/After Plots
+"## 12. Generate Before/After Plots "
 ]
 },
 {
@@ -502,7 +499,7 @@
 "with open(\"training/outputs/baseline_results.pkl\", \"rb\") as f:\n",
 " baseline_results = pickle.load(f)\n",
 "\n",
-"# 
+"# \u2500\u2500 Plot 1: Before vs After Bar Chart \u2500\u2500\n",
 "common_tasks = [t for t in baseline_results if t in trained_results]\n",
 "fig, ax = plt.subplots(figsize=(10, 6))\n",
 "x = np.arange(len(common_tasks))\n",
@@ -516,7 +513,7 @@
 "\n",
 "ax.set_xlabel('Task', fontsize=12)\n",
 "ax.set_ylabel('Average Episode Reward', fontsize=12)\n",
-"ax.set_title('OpenGrid
+"ax.set_title('OpenGrid \u2014 GRPO Training: Before vs After', fontsize=14, fontweight='bold')\n",
 "ax.set_xticks(x)\n",
 "ax.set_xticklabels([t.replace('task_', '').title() for t in common_tasks])\n",
 "ax.legend(fontsize=11)\n",
@@ -536,7 +533,7 @@
 "plt.tight_layout()\n",
 "plt.savefig('training/outputs/before_after.png', dpi=150)\n",
 "plt.show()\n",
-"print(\"
+"print(\" Saved: training/outputs/before_after.png\")"
 ]
 },
 {
@@ -545,7 +542,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# 
+"# \u2500\u2500 Plot 2: Training Reward Curve \u2500\u2500\n",
 "history = trainer.state.log_history\n",
 "\n",
 "steps = [h['step'] for h in history if 'loss' in h]\n",
@@ -560,13 +557,13 @@
 "\n",
 "ax.set_xlabel('Training Step', fontsize=12)\n",
 "ax.set_ylabel('Loss', fontsize=12)\n",
-"ax.set_title('OpenGrid GRPO
+"ax.set_title('OpenGrid GRPO \u2014 Training Loss', fontsize=14, fontweight='bold')\n",
 "ax.legend()\n",
 "ax.grid(True, alpha=0.3)\n",
 "plt.tight_layout()\n",
 "plt.savefig('training/outputs/training_loss.png', dpi=150)\n",
 "plt.show()\n",
-"print(\"
+"print(\" Saved: training/outputs/training_loss.png\")"
 ]
 },
 {
@@ -585,19 +582,19 @@
 "outputs": [],
 "source": [
 "print(\"=\"*60)\n",
-"print(\" OpenGrid GRPO Training 
+"print(\" OpenGrid GRPO Training \u2014 Results Summary\")\n",
 "print(\"=\"*60)\n",
 "\n",
 "# Rebuild common_tasks in case Cell 12 was skipped\n",
 "common_tasks = [t for t in baseline_results if t in trained_results]\n",
 "\n",
-"print(f\"{'Task':<20} {'Baseline':>12} {'Trained':>12} {'
+"print(f\"{'Task':<20} {'Baseline':>12} {'Trained':>12} {'\u0394':>10}\")\n",
 "print(\"-\"*60)\n",
 "for t in common_tasks:\n",
 " b = baseline_results[t]['avg_reward']\n",
 " a = trained_results[t]['avg_reward']\n",
 " delta = a - b\n",
-" arrow = '
+" arrow = '\u2191' if delta > 0 else '\u2193'\n",
 " print(f\"{t:<20} {b:>10.2f} {a:>10.2f} {arrow} {abs(delta):.2f}\")\n",
 "print(\"=\"*60)"
 ]
@@ -608,10 +605,10 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# 
-"from 
-"
-"
+"# Display plots inline\n",
+"from IPython.display import Image, display\n",
+"display(Image(\"training/outputs/before_after.png\"))\n",
+"display(Image(\"training/outputs/training_loss.png\"))\n"
 ]
 }
 ],
@@ -632,4 +629,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 0
-}
+}