K446 committed
Commit 69bab30 · 1 Parent(s): be15396

fix: notebook uses compute_grpo_reward_env, updated hyperparams, no emojis

Files changed (1)
  1. training/opengrid_grpo_colab.ipynb +44 -47
training/opengrid_grpo_colab.ipynb CHANGED
@@ -4,17 +4,17 @@
4
  "cell_type": "markdown",
5
  "metadata": {},
6
  "source": [
7
- "# 🔋 OpenGrid GRPO Training Notebook\n",
8
  "\n",
9
  "**Multi-Agent RL for Power Grid Operations**\n",
10
  "\n",
11
  "This notebook trains an LLM (Qwen 2.5 1.5B) to operate a power grid using GRPO (Group Relative Policy Optimization).\n",
12
  "\n",
13
- "- **Environment**: OpenGrid multi-agent POMDP with safety layer & oversight agent\n",
14
  "- **Task**: Maintain 50 Hz frequency, prevent line overloads, avoid blackouts\n",
15
  "- **Training**: TRL GRPOTrainer + Unsloth 4-bit quantization\n",
16
  "\n",
17
- " **Runtime**: Select `T4 GPU` from Runtime Change runtime type"
18
  ]
19
  },
20
  {
@@ -51,7 +51,7 @@
51
  "source": [
52
  "import os\n",
53
  "\n",
54
- "# ⚠️ UPDATE THIS with your actual repo URL\n",
55
  "REPO_URL = \"https://github.com/krishnagoyal099/Opengrid_env.git\"\n",
56
  "\n",
57
  "if not os.path.exists(\"opengrid\"):\n",
@@ -84,7 +84,7 @@
84
  " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
85
  " print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n",
86
  "else:\n",
87
- " print(\"⚠️ No GPU detected! Go to Runtime Change runtime type T4 GPU\")"
88
  ]
89
  },
90
  {
@@ -177,14 +177,14 @@
177
  " \"std_reward\": np.std(rewards),\n",
178
  " \"rewards\": rewards\n",
179
  " }\n",
180
- " print(f\"[BASELINE] {task_id}: {np.mean(rewards):.2f} ± {np.std(rewards):.2f}\")\n",
181
  "\n",
182
  "# Save baseline for later comparison\n",
183
  "import pickle\n",
184
  "os.makedirs(\"training/outputs\", exist_ok=True)\n",
185
  "with open(\"training/outputs/baseline_results.pkl\", \"wb\") as f:\n",
186
  " pickle.dump(baseline_results, f)\n",
187
- "print(\"\\n Baseline scores saved.\")"
188
  ]
189
  },
190
  {
@@ -222,7 +222,7 @@
222
  "if tokenizer.pad_token is None:\n",
223
  " tokenizer.pad_token = tokenizer.eos_token\n",
224
  "\n",
225
- "print(f\" Model loaded: {MODEL_NAME}\")\n",
226
  "print(f\" Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}\")"
227
  ]
228
  },
@@ -261,7 +261,7 @@
261
  "\n",
262
  " for t in range(min(10, task_config['max_steps'])):\n",
263
  " for agent_id, obs in zone_obs.items():\n",
264
- " # model_dump_json() json.loads() ensures all keys are strings\n",
265
  " obs_dict = _json.loads(obs.model_dump_json())\n",
266
  " prompt_text = format_observation_prompt(obs_dict, zone_name=obs.zone_name)\n",
267
  " messages = [\n",
@@ -272,7 +272,7 @@
272
  " messages, tokenize=False, add_generation_prompt=True\n",
273
  " )\n",
274
  " prompts.append(formatted)\n",
275
- " # Store as JSON string flat scalar, no schema-inference issues\n",
276
  " obs_contexts.append(_json.dumps(obs_dict))\n",
277
  "\n",
278
  " # Advance env with diverse random actions (no slack bus)\n",
@@ -293,7 +293,7 @@
293
  " break\n",
294
  " zone_obs = result.observations\n",
295
  "\n",
296
- "print(f\" Generated {len(prompts)} training prompts\")\n",
297
  "print(f\"\\nSample prompt (first 400 chars):\")\n",
298
  "print(prompts[0][:400])"
299
  ]
@@ -312,21 +312,18 @@
312
  "outputs": [],
313
  "source": [
314
  "import json as _json\n",
315
- "from training.train_grpo import compute_grpo_reward, extract_action\n",
316
  "\n",
317
  "def reward_fn(completions, obs_context=None, **kwargs):\n",
318
- " \"\"\"GRPO-compatible reward function for OpenGrid.\n",
319
- " obs_context arrives as JSON strings from the dataset column.\n",
320
- " \"\"\"\n",
321
  " texts = []\n",
322
  " for c in completions:\n",
323
  " if isinstance(c, list):\n",
324
- " text = c[-1]['content'] if c else \"\"\n",
325
  " else:\n",
326
  " text = str(c)\n",
327
  " texts.append(text)\n",
328
  "\n",
329
- " # Deserialize JSON strings → dicts for the reward scorer\n",
330
  " if obs_context is None:\n",
331
  " batch_obs = [None] * len(texts)\n",
332
  " else:\n",
@@ -334,23 +331,23 @@
334
  " _json.loads(ctx) if isinstance(ctx, str) else ctx\n",
335
  " for ctx in obs_context\n",
336
  " ]\n",
337
- " return compute_grpo_reward(texts, batch_obs)\n",
338
  "\n",
339
- "# Quick sanity test\n",
340
  "test_rewards = reward_fn([\n",
341
  " '{\"bus_adjustments\": [{\"bus_id\": 1, \"delta\": 5.0}], \"topology_actions\": []}',\n",
342
- " 'invalid json here',\n",
343
  "])\n",
344
  "print(f\"Test rewards: {test_rewards}\")\n",
345
- "assert len(test_rewards) == 2, \"reward_fn must return one score per completion\"\n",
346
- "print(\" reward_fn OK\")"
347
  ]
348
  },
349
  {
350
  "cell_type": "markdown",
351
  "metadata": {},
352
  "source": [
353
- "## 9. Train with GRPO 🚀"
354
  ]
355
  },
356
  {
@@ -368,21 +365,21 @@
368
  "\n",
369
  "grpo_config = GRPOConfig(\n",
370
  " output_dir=\"training/outputs/grpo_checkpoints\",\n",
371
- " num_train_epochs=1,\n",
372
  " per_device_train_batch_size=2,\n",
373
- " gradient_accumulation_steps=4,\n",
374
- " learning_rate=5e-6,\n",
375
  " logging_steps=5,\n",
376
  " save_steps=50,\n",
377
  " max_completion_length=256,\n",
378
- " num_generations=4,\n",
379
  " report_to=\"none\",\n",
380
  " remove_unused_columns=False,\n",
381
  " bf16=_bf16,\n",
382
  " fp16=_fp16,\n",
383
  ")\n",
384
  "\n",
385
- "# obs_contexts are JSON strings PyArrow handles flat strings with no issues\n",
386
  "train_dataset = Dataset.from_dict({\"prompt\": prompts, \"obs_context\": obs_contexts})\n",
387
  "print(f\"Dataset: {len(train_dataset)} rows, columns: {train_dataset.column_names}\")\n",
388
  "\n",
@@ -396,11 +393,11 @@
396
  "\n",
397
  "print(f\"Training on {len(prompts)} prompts, {grpo_config.num_train_epochs} epoch(s)\")\n",
398
  "print(f\"Effective batch size: {grpo_config.per_device_train_batch_size * grpo_config.gradient_accumulation_steps}\")\n",
399
- "print(\"\\n🚀 Starting GRPO training...\")\n",
400
  "\n",
401
  "train_result = trainer.train()\n",
402
  "\n",
403
- "print(\"\\n Training complete!\")\n",
404
  "print(f\" Total steps: {trainer.state.global_step}\")"
405
  ]
406
  },
@@ -420,7 +417,7 @@
420
  "OUTPUT_PATH = \"training/outputs/trained_model\"\n",
421
  "trainer.save_model(OUTPUT_PATH)\n",
422
  "tokenizer.save_pretrained(OUTPUT_PATH)\n",
423
- "print(f\" Model saved to {OUTPUT_PATH}\")"
424
  ]
425
  },
426
  {
@@ -479,14 +476,14 @@
479
  " \"std_reward\": np.std(rewards),\n",
480
  " \"rewards\": rewards\n",
481
  " }\n",
482
- " print(f\"[TRAINED] {task_id}: {np.mean(rewards):.2f} ± {np.std(rewards):.2f}\\n\")"
483
  ]
484
  },
485
  {
486
  "cell_type": "markdown",
487
  "metadata": {},
488
  "source": [
489
- "## 12. Generate Before/After Plots 📊"
490
  ]
491
  },
492
  {
@@ -502,7 +499,7 @@
502
  "with open(\"training/outputs/baseline_results.pkl\", \"rb\") as f:\n",
503
  " baseline_results = pickle.load(f)\n",
504
  "\n",
505
- "# ── Plot 1: Before vs After Bar Chart ──\n",
506
  "common_tasks = [t for t in baseline_results if t in trained_results]\n",
507
  "fig, ax = plt.subplots(figsize=(10, 6))\n",
508
  "x = np.arange(len(common_tasks))\n",
@@ -516,7 +513,7 @@
516
  "\n",
517
  "ax.set_xlabel('Task', fontsize=12)\n",
518
  "ax.set_ylabel('Average Episode Reward', fontsize=12)\n",
519
- "ax.set_title('OpenGrid GRPO Training: Before vs After', fontsize=14, fontweight='bold')\n",
520
  "ax.set_xticks(x)\n",
521
  "ax.set_xticklabels([t.replace('task_', '').title() for t in common_tasks])\n",
522
  "ax.legend(fontsize=11)\n",
@@ -536,7 +533,7 @@
536
  "plt.tight_layout()\n",
537
  "plt.savefig('training/outputs/before_after.png', dpi=150)\n",
538
  "plt.show()\n",
539
- "print(\" Saved: training/outputs/before_after.png\")"
540
  ]
541
  },
542
  {
@@ -545,7 +542,7 @@
545
  "metadata": {},
546
  "outputs": [],
547
  "source": [
548
- "# ── Plot 2: Training Reward Curve ──\n",
549
  "history = trainer.state.log_history\n",
550
  "\n",
551
  "steps = [h['step'] for h in history if 'loss' in h]\n",
@@ -560,13 +557,13 @@
560
  "\n",
561
  "ax.set_xlabel('Training Step', fontsize=12)\n",
562
  "ax.set_ylabel('Loss', fontsize=12)\n",
563
- "ax.set_title('OpenGrid GRPO Training Loss', fontsize=14, fontweight='bold')\n",
564
  "ax.legend()\n",
565
  "ax.grid(True, alpha=0.3)\n",
566
  "plt.tight_layout()\n",
567
  "plt.savefig('training/outputs/training_loss.png', dpi=150)\n",
568
  "plt.show()\n",
569
- "print(\" Saved: training/outputs/training_loss.png\")"
570
  ]
571
  },
572
  {
@@ -585,19 +582,19 @@
585
  "outputs": [],
586
  "source": [
587
  "print(\"=\"*60)\n",
588
- "print(\" OpenGrid GRPO Training Results Summary\")\n",
589
  "print(\"=\"*60)\n",
590
  "\n",
591
  "# Rebuild common_tasks in case Cell 12 was skipped\n",
592
  "common_tasks = [t for t in baseline_results if t in trained_results]\n",
593
  "\n",
594
- "print(f\"{'Task':<20} {'Baseline':>12} {'Trained':>12} {'Δ':>10}\")\n",
595
  "print(\"-\"*60)\n",
596
  "for t in common_tasks:\n",
597
  " b = baseline_results[t]['avg_reward']\n",
598
  " a = trained_results[t]['avg_reward']\n",
599
  " delta = a - b\n",
600
- " arrow = '' if delta > 0 else ''\n",
601
  " print(f\"{t:<20} {b:>10.2f} {a:>10.2f} {arrow} {abs(delta):.2f}\")\n",
602
  "print(\"=\"*60)"
603
  ]
@@ -608,10 +605,10 @@
608
  "metadata": {},
609
  "outputs": [],
610
  "source": [
611
- "# Download plots for your README\n",
612
- "from google.colab import files\n",
613
- "files.download('training/outputs/before_after.png')\n",
614
- "files.download('training/outputs/training_loss.png')"
615
  ]
616
  }
617
  ],
@@ -632,4 +629,4 @@
632
  },
633
  "nbformat": 4,
634
  "nbformat_minor": 0
635
- }
 
4
  "cell_type": "markdown",
5
  "metadata": {},
6
  "source": [
7
+ "# OpenGrid \u2014 GRPO Training Notebook\n",
8
  "\n",
9
  "**Multi-Agent RL for Power Grid Operations**\n",
10
  "\n",
11
  "This notebook trains an LLM (Qwen 2.5 1.5B) to operate a power grid using GRPO (Group Relative Policy Optimization).\n",
12
  "\n",
13
+ "- **Environment**: OpenGrid \u2014 multi-agent POMDP with safety layer & oversight agent\n",
14
  "- **Task**: Maintain 50 Hz frequency, prevent line overloads, avoid blackouts\n",
15
  "- **Training**: TRL GRPOTrainer + Unsloth 4-bit quantization\n",
16
  "\n",
17
+ " **Runtime**: Select `T4 GPU` from Runtime \u2192 Change runtime type"
18
  ]
19
  },
20
  {
 
51
  "source": [
52
  "import os\n",
53
  "\n",
54
+ "# UPDATE THIS with your actual repo URL\n",
55
  "REPO_URL = \"https://github.com/krishnagoyal099/Opengrid_env.git\"\n",
56
  "\n",
57
  "if not os.path.exists(\"opengrid\"):\n",
 
84
  " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
85
  " print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")\n",
86
  "else:\n",
87
+ " print(\" No GPU detected! Go to Runtime \u2192 Change runtime type \u2192 T4 GPU\")"
88
  ]
89
  },
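The GRPOConfig cell further down consumes `_bf16` and `_fp16`, which are defined outside the visible hunks. A plausible definition, assuming they key off GPU capability (bf16 needs Ampere or newer, so a T4 falls back to fp16):

```python
import torch

# Assumed definitions of the _bf16/_fp16 flags used by GRPOConfig below.
# torch.cuda.is_bf16_supported() is False on a T4, so fp16 is chosen there.
_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
_fp16 = torch.cuda.is_available() and not _bf16
```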
90
  {
 
177
  " \"std_reward\": np.std(rewards),\n",
178
  " \"rewards\": rewards\n",
179
  " }\n",
180
+ " print(f\"[BASELINE] {task_id}: {np.mean(rewards):.2f} \u00b1 {np.std(rewards):.2f}\")\n",
181
  "\n",
182
  "# Save baseline for later comparison\n",
183
  "import pickle\n",
184
  "os.makedirs(\"training/outputs\", exist_ok=True)\n",
185
  "with open(\"training/outputs/baseline_results.pkl\", \"wb\") as f:\n",
186
  " pickle.dump(baseline_results, f)\n",
187
+ "print(\"\\n Baseline scores saved.\")"
188
  ]
189
  },
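The rollout loop that fills `rewards` is elided by the hunk. A sketch of its shape, with `run_episode` as a hypothetical stand-in for the notebook's actual evaluation helper:

```python
import numpy as np

def run_episode(task_id: str) -> float:
    """Hypothetical stand-in for the notebook's elided rollout logic;
    the real cell steps the OpenGrid env and sums step rewards."""
    return float(np.random.uniform(-1.0, 0.0))  # placeholder score

baseline_results = {}
for task_id in ["task_frequency", "task_overload"]:  # assumed task ids
    rewards = [run_episode(task_id) for _ in range(5)]
    baseline_results[task_id] = {
        "avg_reward": np.mean(rewards),
        "std_reward": np.std(rewards),
        "rewards": rewards,
    }
```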
190
  {
 
222
  "if tokenizer.pad_token is None:\n",
223
  " tokenizer.pad_token = tokenizer.eos_token\n",
224
  "\n",
225
+ "print(f\" Model loaded: {MODEL_NAME}\")\n",
226
  "print(f\" Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}\")"
227
  ]
228
  },
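Most of the model-loading cell is elided. Given the intro's mention of Unsloth 4-bit quantization, the load likely resembles the sketch below; the checkpoint id, sequence length, and LoRA settings are assumptions:

```python
from unsloth import FastLanguageModel

MODEL_NAME = "unsloth/Qwen2.5-1.5B-Instruct"  # assumed checkpoint id

# 4-bit base model plus a LoRA adapter, so only adapter params train.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=2048,  # assumed context budget
    load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
```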
 
261
  "\n",
262
  " for t in range(min(10, task_config['max_steps'])):\n",
263
  " for agent_id, obs in zone_obs.items():\n",
264
+ " # model_dump_json() \u2192 json.loads() ensures all keys are strings\n",
265
  " obs_dict = _json.loads(obs.model_dump_json())\n",
266
  " prompt_text = format_observation_prompt(obs_dict, zone_name=obs.zone_name)\n",
267
  " messages = [\n",
 
272
  " messages, tokenize=False, add_generation_prompt=True\n",
273
  " )\n",
274
  " prompts.append(formatted)\n",
275
+ " # Store as JSON string \u2014 flat scalar, no schema-inference issues\n",
276
  " obs_contexts.append(_json.dumps(obs_dict))\n",
277
  "\n",
278
  " # Advance env with diverse random actions (no slack bus)\n",
 
293
  " break\n",
294
  " zone_obs = result.observations\n",
295
  "\n",
296
+ "print(f\" Generated {len(prompts)} training prompts\")\n",
297
  "print(f\"\\nSample prompt (first 400 chars):\")\n",
298
  "print(prompts[0][:400])"
299
  ]
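Why the `model_dump_json()` → `json.loads()` round trip matters: JSON objects only allow string keys, so any int-keyed dicts in the pydantic observation come back string-keyed, which keeps PyArrow's schema inference happy. A self-contained illustration (the `Obs` fields are invented):

```python
import json
from pydantic import BaseModel

class Obs(BaseModel):            # invented toy model, not the real schema
    zone_name: str
    bus_loads: dict[int, float]  # int keys are legal in Python...

obs = Obs(zone_name="north", bus_loads={1: 42.0, 2: 17.5})
round_tripped = json.loads(obs.model_dump_json())
print(round_tripped["bus_loads"])  # {'1': 42.0, '2': 17.5}: keys are now strings
```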
 
312
  "outputs": [],
313
  "source": [
314
  "import json as _json\n",
315
+ "from training.train_grpo import compute_grpo_reward_env, extract_action\n",
316
  "\n",
317
  "def reward_fn(completions, obs_context=None, **kwargs):\n",
318
+ " \"\"\"GRPO reward function with env-grounded physics rewards.\"\"\"\n",
 
 
319
  " texts = []\n",
320
  " for c in completions:\n",
321
  " if isinstance(c, list):\n",
322
+ " text = c[-1][\"content\"] if c else \"\"\n",
323
  " else:\n",
324
  " text = str(c)\n",
325
  " texts.append(text)\n",
326
  "\n",
 
327
  " if obs_context is None:\n",
328
  " batch_obs = [None] * len(texts)\n",
329
  " else:\n",
 
331
  " _json.loads(ctx) if isinstance(ctx, str) else ctx\n",
332
  " for ctx in obs_context\n",
333
  " ]\n",
334
+ " return compute_grpo_reward_env(texts, batch_obs, task_config, horizon=3)\n",
335
  "\n",
336
+ "# Sanity test\n",
337
  "test_rewards = reward_fn([\n",
338
  " '{\"bus_adjustments\": [{\"bus_id\": 1, \"delta\": 5.0}], \"topology_actions\": []}',\n",
339
+ " \"invalid json here\",\n",
340
  "])\n",
341
  "print(f\"Test rewards: {test_rewards}\")\n",
342
+ "assert len(test_rewards) == 2\n",
343
+ "print(\"[OK] reward_fn works\")\n"
344
  ]
345
  },
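TRL forwards extra dataset columns to reward functions as per-sample keyword lists, which is how `obs_context` reaches `reward_fn` at train time. A quick exercise of that branch, assuming the cells above have run; the observation fields here are invented, not the real OpenGrid schema:

```python
import json as _json

# Hand-built context exercising the obs_context branch of reward_fn.
ctx = _json.dumps({"zone_name": "north", "frequency_hz": 49.97})  # illustrative fields
completion = '{"bus_adjustments": [], "topology_actions": []}'

rewards = reward_fn([completion], obs_context=[ctx])
assert len(rewards) == 1
print(f"env-grounded reward: {rewards[0]:.3f}")
```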
346
  {
347
  "cell_type": "markdown",
348
  "metadata": {},
349
  "source": [
350
+ "## 9. Train with GRPO "
351
  ]
352
  },
353
  {
 
365
  "\n",
366
  "grpo_config = GRPOConfig(\n",
367
  " output_dir=\"training/outputs/grpo_checkpoints\",\n",
368
+ " num_train_epochs=3,\n",
369
  " per_device_train_batch_size=2,\n",
370
+ " gradient_accumulation_steps=8,\n",
371
+ " learning_rate=1e-5,\n",
372
  " logging_steps=5,\n",
373
  " save_steps=50,\n",
374
  " max_completion_length=256,\n",
375
+ " num_generations=8,\n",
376
  " report_to=\"none\",\n",
377
  " remove_unused_columns=False,\n",
378
  " bf16=_bf16,\n",
379
  " fp16=_fp16,\n",
380
  ")\n",
381
  "\n",
382
+ "# obs_contexts are JSON strings \u2014 PyArrow handles flat strings with no issues\n",
383
  "train_dataset = Dataset.from_dict({\"prompt\": prompts, \"obs_context\": obs_contexts})\n",
384
  "print(f\"Dataset: {len(train_dataset)} rows, columns: {train_dataset.column_names}\")\n",
385
  "\n",
 
393
  "\n",
394
  "print(f\"Training on {len(prompts)} prompts, {grpo_config.num_train_epochs} epoch(s)\")\n",
395
  "print(f\"Effective batch size: {grpo_config.per_device_train_batch_size * grpo_config.gradient_accumulation_steps}\")\n",
396
+ "print(\"\\n Starting GRPO training...\")\n",
397
  "\n",
398
  "train_result = trainer.train()\n",
399
  "\n",
400
+ "print(\"\\n Training complete!\")\n",
401
  "print(f\" Total steps: {trainer.state.global_step}\")"
402
  ]
403
  },
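For the record, the effective batch printed above works out as follows; note that some TRL versions also require the generation batch to divide evenly by `num_generations`:

```python
# Sampling budget under the updated hyperparameters.
per_device = 2
grad_accum = 8
num_generations = 8  # completions per prompt for the group-relative baseline

effective_batch = per_device * grad_accum  # 16 sequences per optimizer step
print(f"effective batch: {effective_batch}")
# 16 % 8 == 0, so the group size fits the batch cleanly.
assert effective_batch % num_generations == 0
```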
 
417
  "OUTPUT_PATH = \"training/outputs/trained_model\"\n",
418
  "trainer.save_model(OUTPUT_PATH)\n",
419
  "tokenizer.save_pretrained(OUTPUT_PATH)\n",
420
+ "print(f\" Model saved to {OUTPUT_PATH}\")"
421
  ]
422
  },
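To sanity-check the artifact, the saved directory can be reloaded later; whether this needs plain transformers or PEFT depends on what `save_model` wrote (adapter-only vs merged weights):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

OUTPUT_PATH = "training/outputs/trained_model"

# If only LoRA adapters were saved, load via peft.AutoPeftModelForCausalLM
# instead; this sketch assumes a full (or merged) checkpoint.
reloaded_tokenizer = AutoTokenizer.from_pretrained(OUTPUT_PATH)
reloaded_model = AutoModelForCausalLM.from_pretrained(OUTPUT_PATH)
```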
423
  {
 
476
  " \"std_reward\": np.std(rewards),\n",
477
  " \"rewards\": rewards\n",
478
  " }\n",
479
+ " print(f\"[TRAINED] {task_id}: {np.mean(rewards):.2f} \u00b1 {np.std(rewards):.2f}\\n\")"
480
  ]
481
  },
482
  {
483
  "cell_type": "markdown",
484
  "metadata": {},
485
  "source": [
486
+ "## 12. Generate Before/After Plots "
487
  ]
488
  },
489
  {
 
499
  "with open(\"training/outputs/baseline_results.pkl\", \"rb\") as f:\n",
500
  " baseline_results = pickle.load(f)\n",
501
  "\n",
502
+ "# \u2500\u2500 Plot 1: Before vs After Bar Chart \u2500\u2500\n",
503
  "common_tasks = [t for t in baseline_results if t in trained_results]\n",
504
  "fig, ax = plt.subplots(figsize=(10, 6))\n",
505
  "x = np.arange(len(common_tasks))\n",
 
513
  "\n",
514
  "ax.set_xlabel('Task', fontsize=12)\n",
515
  "ax.set_ylabel('Average Episode Reward', fontsize=12)\n",
516
+ "ax.set_title('OpenGrid \u2014 GRPO Training: Before vs After', fontsize=14, fontweight='bold')\n",
517
  "ax.set_xticks(x)\n",
518
  "ax.set_xticklabels([t.replace('task_', '').title() for t in common_tasks])\n",
519
  "ax.legend(fontsize=11)\n",
 
533
  "plt.tight_layout()\n",
534
  "plt.savefig('training/outputs/before_after.png', dpi=150)\n",
535
  "plt.show()\n",
536
+ "print(\" Saved: training/outputs/before_after.png\")"
537
  ]
538
  },
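The bar-drawing calls land in an elided span between the setup and the axis labels. A hedged reconstruction of that middle section (styling details may differ in the real cell):

```python
# Grouped bars with error bars; assumes fig/ax, x, common_tasks,
# baseline_results, and trained_results from the surrounding cell.
width = 0.35
ax.bar(
    x - width / 2,
    [baseline_results[t]["avg_reward"] for t in common_tasks],
    width,
    yerr=[baseline_results[t]["std_reward"] for t in common_tasks],
    label="Baseline", capsize=4,
)
ax.bar(
    x + width / 2,
    [trained_results[t]["avg_reward"] for t in common_tasks],
    width,
    yerr=[trained_results[t]["std_reward"] for t in common_tasks],
    label="Trained (GRPO)", capsize=4,
)
```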
539
  {
 
542
  "metadata": {},
543
  "outputs": [],
544
  "source": [
545
+ "# \u2500\u2500 Plot 2: Training Reward Curve \u2500\u2500\n",
546
  "history = trainer.state.log_history\n",
547
  "\n",
548
  "steps = [h['step'] for h in history if 'loss' in h]\n",
 
557
  "\n",
558
  "ax.set_xlabel('Training Step', fontsize=12)\n",
559
  "ax.set_ylabel('Loss', fontsize=12)\n",
560
+ "ax.set_title('OpenGrid GRPO \u2014 Training Loss', fontsize=14, fontweight='bold')\n",
561
  "ax.legend()\n",
562
  "ax.grid(True, alpha=0.3)\n",
563
  "plt.tight_layout()\n",
564
  "plt.savefig('training/outputs/training_loss.png', dpi=150)\n",
565
  "plt.show()\n",
566
+ "print(\" Saved: training/outputs/training_loss.png\")"
567
  ]
568
  },
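`trainer.state.log_history` usually carries a per-step `reward` entry alongside `loss` when training with GRPOTrainer; if present, it makes a more informative curve than loss alone. A guarded sketch:

```python
import matplotlib.pyplot as plt

history = trainer.state.log_history
reward_steps = [h["step"] for h in history if "reward" in h]
reward_vals = [h["reward"] for h in history if "reward" in h]

if reward_vals:
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(reward_steps, reward_vals, marker="o", linewidth=1.5)
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Mean Group Reward")
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()
else:
    print("No 'reward' entries in log_history for this TRL version.")
```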
569
  {
 
582
  "outputs": [],
583
  "source": [
584
  "print(\"=\"*60)\n",
585
+ "print(\" OpenGrid GRPO Training \u2014 Results Summary\")\n",
586
  "print(\"=\"*60)\n",
587
  "\n",
588
  "# Rebuild common_tasks in case Cell 12 was skipped\n",
589
  "common_tasks = [t for t in baseline_results if t in trained_results]\n",
590
  "\n",
591
+ "print(f\"{'Task':<20} {'Baseline':>12} {'Trained':>12} {'\u0394':>10}\")\n",
592
  "print(\"-\"*60)\n",
593
  "for t in common_tasks:\n",
594
  " b = baseline_results[t]['avg_reward']\n",
595
  " a = trained_results[t]['avg_reward']\n",
596
  " delta = a - b\n",
597
+ " arrow = '\u2191' if delta > 0 else '\u2193'\n",
598
  " print(f\"{t:<20} {b:>10.2f} {a:>10.2f} {arrow} {abs(delta):.2f}\")\n",
599
  "print(\"=\"*60)"
600
  ]
 
605
  "metadata": {},
606
  "outputs": [],
607
  "source": [
608
+ "# Display plots inline\n",
609
+ "from IPython.display import Image, display\n",
610
+ "display(Image(\"training/outputs/before_after.png\"))\n",
611
+ "display(Image(\"training/outputs/training_loss.png\"))\n"
612
  ]
613
  }
614
  ],
 
629
  },
630
  "nbformat": 4,
631
  "nbformat_minor": 0
632
+ }