File size: 3,245 Bytes
bcce6af | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 | """Update the notebook: fix rewards, hyperparams, remove emojis, show plots inline."""
import json
import re
# Load the notebook. The context manager closes the handle deterministically
# instead of leaving it to garbage collection.
with open('training/opengrid_grpo_colab.ipynb', encoding='utf-8') as nb_file:
    nb = json.load(nb_file)
# Remove decorative emojis from all cells.
# nbformat stores a cell's "source" either as a list of lines or as a single
# string, so both representations are handled; indexing into a string source
# would raise TypeError.
_EMOJI_BASES = ('🔋', '⚡', '🚀', '📊', '✅', '⚠')


def _strip_emojis(text):
    """Return *text* with the decorative emojis in _EMOJI_BASES removed.

    Each emoji is removed both with and without a trailing U+FE0F variation
    selector, so e.g. '⚠️' and '⚠' are both stripped and no orphan selector
    is left behind.
    """
    for base in _EMOJI_BASES:
        text = text.replace(base + '\ufe0f', '').replace(base, '')
    return text


for cell in nb['cells']:
    source = cell.get('source')
    if isinstance(source, str):
        cell['source'] = _strip_emojis(source)
    elif isinstance(source, list):
        cell['source'] = [_strip_emojis(line) for line in source]
# Fix Cell 8: switch the GRPO reward function to compute_grpo_reward_env.
# The replacement cell is written as one triple-quoted string so its Python
# indentation is explicit; splitlines(keepends=True) turns it back into the
# list-of-lines form nbformat uses.
# NOTE(review): the generated cell reads `task_config`, which must already be
# defined by an earlier notebook cell -- confirm.
REWARD_CELL_SOURCE = '''\
import json as _json
from training.train_grpo import compute_grpo_reward_env, extract_action

def reward_fn(completions, obs_context=None, **kwargs):
    """GRPO reward function with env-grounded physics rewards."""
    texts = []
    for c in completions:
        if isinstance(c, list):
            text = c[-1]["content"] if c else ""
        else:
            text = str(c)
        texts.append(text)

    if obs_context is None:
        batch_obs = [None] * len(texts)
    else:
        batch_obs = [
            _json.loads(ctx) if isinstance(ctx, str) else ctx
            for ctx in obs_context
        ]
    return compute_grpo_reward_env(texts, batch_obs, task_config, horizon=3)

# Sanity test
test_rewards = reward_fn([
    '{"bus_adjustments": [{"bus_id": 1, "delta": 5.0}], "topology_actions": []}',
    "invalid json here",
])
print(f"Test rewards: {test_rewards}")
assert len(test_rewards) == 2
print("[OK] reward_fn works")
'''

for cell in nb['cells']:
    src = ''.join(cell.get('source', []))
    # The old import line ends with "compute_grpo_reward," (trailing comma),
    # so a cell that was already patched to compute_grpo_reward_env is not
    # matched again on a re-run.
    if (
        cell.get('cell_type') == 'code'
        and 'compute_grpo_reward,' in src
        and 'def reward_fn' in src
    ):
        cell['source'] = REWARD_CELL_SOURCE.splitlines(keepends=True)
        # Outputs from the old code would be mistaken for results of the new cell.
        cell['outputs'] = []
        cell['execution_count'] = None
        break
else:
    print("WARNING: reward_fn cell not found; Cell 8 left unchanged")
# Fix Cell 9: update the GRPOConfig hyperparameters.
# Each pattern captures "name =" (preserving the author's spacing) and refuses
# to match when the old value is followed by another digit: a plain substring
# replace would also turn e.g. num_train_epochs=10 into num_train_epochs=30.
# NOTE(review): TRL's GRPO trainer requires the generation batch size to be
# divisible by num_generations; confirm per_device_train_batch_size (and
# gradient_accumulation_steps) still satisfy this with num_generations=8 for
# the installed TRL version.
HYPERPARAM_UPDATES = (
    (r'(\bnum_train_epochs\s*=\s*)1(?!\d)', r'\g<1>3'),
    (r'(\bgradient_accumulation_steps\s*=\s*)4(?!\d)', r'\g<1>8'),
    (r'(\blearning_rate\s*=\s*)5e-6(?!\d)', r'\g<1>1e-5'),
    (r'(\bnum_generations\s*=\s*)4(?!\d)', r'\g<1>8'),
)

for cell in nb['cells']:
    src = ''.join(cell.get('source', []))
    if cell.get('cell_type') == 'code' and 'GRPOConfig(' in src and 'num_generations' in src:
        new_src = src
        for pattern, replacement in HYPERPARAM_UPDATES:
            new_src, n_subs = re.subn(pattern, replacement, new_src)
            if n_subs == 0:
                print(f"WARNING: no match for {pattern!r} in the GRPOConfig cell")
        if new_src != src:
            # Outputs from the old configuration no longer describe this cell.
            cell['outputs'] = []
            cell['execution_count'] = None
        cell['source'] = new_src.splitlines(True)
        break
else:
    print("WARNING: GRPOConfig cell not found; Cell 9 left unchanged")
# Fix download cell: replace the google.colab file-download code with inline
# display of the saved plots. Requiring "download" as well as "google.colab"
# keeps a Colab-detection or drive-mount setup cell (which also mentions
# google.colab and may appear first) from being overwritten.
for cell in nb['cells']:
    src = ''.join(cell.get('source', []))
    if cell.get('cell_type') == 'code' and 'google.colab' in src and 'download' in src:
        cell['source'] = [
            '# Display plots inline\n',
            'from IPython.display import Image, display\n',
            'display(Image("training/outputs/before_after.png"))\n',
            'display(Image("training/outputs/training_loss.png"))\n',
        ]
        # Outputs from the old download code would be stale.
        cell['outputs'] = []
        cell['execution_count'] = None
        break
else:
    print("WARNING: google.colab download cell not found; left unchanged")
# Write the patched notebook back. The context manager flushes and closes the
# file; passing a bare open() to json.dump leaves that to garbage collection.
with open('training/opengrid_grpo_colab.ipynb', 'w', encoding='utf-8') as nb_file:
    json.dump(nb, nb_file, indent=1)
print("Notebook updated successfully")
|