"""Update the notebook: fix rewards, hyperparams, remove emojis, show plots inline."""
import json
import re
import sys


# Notebook that is rewritten in place when the script runs with no arguments.
NOTEBOOK_PATH = 'training/opengrid_grpo_colab.ipynb'

# U+FE0F (VARIATION SELECTOR-16) may follow an emoji to request emoji
# presentation.  It is removed together with its base character so no
# invisible leftover stays in the text.
_VARIATION_SELECTOR = '\ufe0f'

# Base code points of the emojis stripped from every cell.
_EMOJI_BASES = (
    '\U0001F50B',  # battery
    '\u26A1',      # high voltage
    '\U0001F680',  # rocket
    '\U0001F4CA',  # bar chart
    '\u2705',      # check mark button
    '\u26A0',      # warning sign
)

# Replacement source for the reward-function cell.  These lines are written
# into the notebook verbatim, so their indentation must be valid Python.
# `task_config` is not defined here; the generated cell expects the notebook
# to have defined it already.
_REWARD_CELL_SOURCE = (
    'import json as _json\n',
    'from training.train_grpo import compute_grpo_reward_env, extract_action\n',
    '\n',
    'def reward_fn(completions, obs_context=None, **kwargs):\n',
    '    """GRPO reward function with env-grounded physics rewards."""\n',
    '    texts = []\n',
    '    for c in completions:\n',
    '        if isinstance(c, list):\n',
    '            text = c[-1]["content"] if c else ""\n',
    '        else:\n',
    '            text = str(c)\n',
    '        texts.append(text)\n',
    '\n',
    '    if obs_context is None:\n',
    '        batch_obs = [None] * len(texts)\n',
    '    else:\n',
    '        batch_obs = [\n',
    '            _json.loads(ctx) if isinstance(ctx, str) else ctx\n',
    '            for ctx in obs_context\n',
    '        ]\n',
    '    return compute_grpo_reward_env(texts, batch_obs, task_config, horizon=3)\n',
    '\n',
    '# Sanity test\n',
    'test_rewards = reward_fn([\n',
    '    \'{"bus_adjustments": [{"bus_id": 1, "delta": 5.0}], "topology_actions": []}\',\n',
    '    "invalid json here",\n',
    '])\n',
    'print(f"Test rewards: {test_rewards}")\n',
    'assert len(test_rewards) == 2\n',
    'print("[OK] reward_fn works")\n',
)

# Replacement source for the first cell that mentions `google.colab`.
_PLOT_CELL_SOURCE = (
    '# Display plots inline\n',
    'from IPython.display import Image, display\n',
    'display(Image("training/outputs/before_after.png"))\n',
    'display(Image("training/outputs/training_loss.png"))\n',
)

# (pattern, replacement) pairs applied in order to the GRPOConfig cell.  The
# `(?!\d)` guard keeps `num_train_epochs=1` from matching inside
# `num_train_epochs=10`, which a plain substring replace would corrupt.
_HYPERPARAM_UPDATES = tuple(
    (re.compile(re.escape(old) + r'(?!\d)'), new)
    for old, new in (
        ('num_train_epochs=1', 'num_train_epochs=3'),
        ('gradient_accumulation_steps=4', 'gradient_accumulation_steps=8'),
        ('learning_rate=5e-6', 'learning_rate=1e-5'),
        ('num_generations=4', 'num_generations=8'),
    )
)


def _cell_text(cell):
    """Return a cell's source as one string.

    Works for both notebook encodings of `source` (a list of lines or a
    single string); a cell without `source` yields an empty string.
    """
    return ''.join(cell.get('source', []))


def _strip_emojis(text):
    """Return `text` with every listed emoji removed, with or without U+FE0F."""
    for base in _EMOJI_BASES:
        # Remove the selector-suffixed form first so the selector is not
        # left behind once the base character is gone.
        text = text.replace(base + _VARIATION_SELECTOR, '').replace(base, '')
    return text


def _replace_first_match(cells, is_match, build_source):
    """Replace the source of the first cell whose text satisfies `is_match`.

    `build_source` receives the matched cell's text and returns the new list
    of source lines.  Returns True if a cell was replaced, False otherwise.
    """
    for cell in cells:
        text = _cell_text(cell)
        if is_match(text):
            cell['source'] = build_source(text)
            return True
    return False


def remove_emojis(cells):
    """Strip the listed emojis from the source of every cell, in place."""
    for cell in cells:
        if 'source' not in cell:
            continue
        source = cell['source']
        if isinstance(source, str):
            # A string source cannot be edited by index; rebuild it whole.
            cell['source'] = _strip_emojis(source)
        else:
            source[:] = [_strip_emojis(line) for line in source]


def replace_reward_cell(cells):
    """Swap the old reward-function cell for the env-based version.

    The target is the first cell containing both `compute_grpo_reward,` and
    `def reward_fn`.  Returns True if such a cell was found.
    """
    return _replace_first_match(
        cells,
        lambda text: 'compute_grpo_reward,' in text and 'def reward_fn' in text,
        lambda _text: list(_REWARD_CELL_SOURCE),
    )


def update_hyperparams(cells):
    """Rewrite training hyperparameters in the first GRPOConfig cell.

    The target is the first cell containing both `GRPOConfig(` and
    `num_generations`.  Returns True if such a cell was found.
    """
    def rewrite(text):
        for pattern, replacement in _HYPERPARAM_UPDATES:
            text = pattern.sub(replacement, text)
        # keepends=True preserves the one-line-per-entry notebook layout.
        return text.splitlines(True)

    return _replace_first_match(
        cells,
        lambda text: 'GRPOConfig(' in text and 'num_generations' in text,
        rewrite,
    )


def replace_plot_cell(cells):
    """Replace the first cell mentioning `google.colab` with inline plots.

    NOTE(review): the match is on the first cell that mentions
    `google.colab` anywhere in its source -- confirm no earlier cell
    (e.g. a setup cell) also references it.
    Returns True if such a cell was found.
    """
    return _replace_first_match(
        cells,
        lambda text: 'google.colab' in text,
        lambda _text: list(_PLOT_CELL_SOURCE),
    )


def main(path=NOTEBOOK_PATH):
    """Load the notebook at `path`, apply every update, and write it back.

    Steps run in a fixed order: emoji removal first, then the three cell
    rewrites.  A rewrite whose target cell is missing is reported on stderr
    instead of being skipped silently.
    """
    with open(path, encoding='utf-8') as fh:
        nb = json.load(fh)

    cells = nb['cells']
    remove_emojis(cells)

    steps = (
        ('reward function', replace_reward_cell),
        ('GRPOConfig', update_hyperparams),
        ('google.colab', replace_plot_cell),
    )
    for label, step in steps:
        if not step(cells):
            print(f"Warning: no {label} cell matched; step skipped",
                  file=sys.stderr)

    # Serialize before opening the target so a serialization error cannot
    # leave a truncated notebook behind.  ensure_ascii=False keeps remaining
    # non-ASCII text as UTF-8 rather than \uXXXX escapes.
    payload = json.dumps(nb, indent=1, ensure_ascii=False)
    with open(path, 'w', encoding='utf-8') as fh:
        fh.write(payload)
    print("Notebook updated successfully")


if __name__ == "__main__":
    main()
|
|