Spaces:
Sleeping
Sleeping
muskan singh commited on
Commit ·
9e29238
1
Parent(s): a35bcd0
training notebook
Browse files- training/grpo_orgos.ipynb +452 -336
training/grpo_orgos.ipynb
CHANGED
|
@@ -1,54 +1,39 @@
|
|
| 1 |
{
|
| 2 |
-
"nbformat": 4,
|
| 3 |
-
"nbformat_minor": 5,
|
| 4 |
-
"metadata": {
|
| 5 |
-
"kernelspec": {
|
| 6 |
-
"display_name": "Python 3",
|
| 7 |
-
"language": "python",
|
| 8 |
-
"name": "python3"
|
| 9 |
-
},
|
| 10 |
-
"language_info": {
|
| 11 |
-
"name": "python",
|
| 12 |
-
"version": "3.10.0"
|
| 13 |
-
},
|
| 14 |
-
"colab": {
|
| 15 |
-
"gpuType": "T4",
|
| 16 |
-
"provenance": []
|
| 17 |
-
},
|
| 18 |
-
"accelerator": "GPU"
|
| 19 |
-
},
|
| 20 |
"cells": [
|
| 21 |
{
|
| 22 |
"cell_type": "markdown",
|
| 23 |
"id": "title",
|
| 24 |
"metadata": {},
|
| 25 |
"source": [
|
| 26 |
-
"# OrgOS GRPO Training
|
| 27 |
"\n",
|
| 28 |
"**Environment:** OrgOS — Multi-App Enterprise RL Environment \n",
|
| 29 |
"**Model:** `Qwen/Qwen2.5-3B-Instruct` (4-bit LoRA via Unsloth) \n",
|
| 30 |
"**Algorithm:** GRPO (Group Relative Policy Optimization) via HuggingFace TRL \n",
|
| 31 |
-
"**
|
| 32 |
-
"\n",
|
| 33 |
-
"##
|
| 34 |
-
"
|
| 35 |
-
"
|
| 36 |
-
"
|
| 37 |
-
"
|
| 38 |
-
"
|
| 39 |
-
"
|
| 40 |
-
"\n",
|
| 41 |
-
"
|
| 42 |
-
"
|
| 43 |
-
"
|
| 44 |
-
"
|
|
|
|
| 45 |
]
|
| 46 |
},
|
| 47 |
{
|
| 48 |
"cell_type": "markdown",
|
| 49 |
"id": "sec1",
|
| 50 |
"metadata": {},
|
| 51 |
-
"source": [
|
|
|
|
|
|
|
| 52 |
},
|
| 53 |
{
|
| 54 |
"cell_type": "code",
|
|
@@ -57,67 +42,46 @@
|
|
| 57 |
"metadata": {},
|
| 58 |
"outputs": [],
|
| 59 |
"source": [
|
| 60 |
-
"
|
| 61 |
-
"!pip install -q
|
| 62 |
-
"!pip install -q
|
| 63 |
-
"!pip install -q matplotlib numpy\n",
|
| 64 |
-
"\n",
|
| 65 |
-
"# Clone / mount the OrgOS repo\n",
|
| 66 |
-
"import os\n",
|
| 67 |
-
"if not os.path.exists('/content/openEnv'):\n",
|
| 68 |
-
" !git clone https://huggingface.co/spaces/YOUR_HF_USERNAME/orgos-openenv /content/openEnv\n",
|
| 69 |
-
" # Alternatively: upload the repo zip and unzip it here\n",
|
| 70 |
-
"\n",
|
| 71 |
-
"os.chdir('/content/openEnv')\n",
|
| 72 |
-
"print('Working directory:', os.getcwd())"
|
| 73 |
]
|
| 74 |
},
|
| 75 |
{
|
| 76 |
"cell_type": "markdown",
|
| 77 |
"id": "sec2",
|
| 78 |
"metadata": {},
|
| 79 |
-
"source": [
|
|
|
|
|
|
|
| 80 |
},
|
| 81 |
{
|
| 82 |
"cell_type": "code",
|
| 83 |
"execution_count": null,
|
| 84 |
-
"id": "
|
| 85 |
"metadata": {},
|
| 86 |
"outputs": [],
|
| 87 |
"source": [
|
| 88 |
-
"
|
| 89 |
-
"import torch\n",
|
| 90 |
"\n",
|
| 91 |
-
"
|
| 92 |
-
"
|
| 93 |
"\n",
|
| 94 |
-
"
|
| 95 |
-
"
|
| 96 |
-
" max_seq_length = MAX_SEQ_LEN,\n",
|
| 97 |
-
" dtype = None, # auto-detect\n",
|
| 98 |
-
" load_in_4bit = True,\n",
|
| 99 |
-
")\n",
|
| 100 |
"\n",
|
| 101 |
-
"
|
| 102 |
-
"
|
| 103 |
-
"
|
| 104 |
-
" r = 16,\n",
|
| 105 |
-
" target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj',\n",
|
| 106 |
-
" 'gate_proj', 'up_proj', 'down_proj'],\n",
|
| 107 |
-
" lora_alpha = 16,\n",
|
| 108 |
-
" lora_dropout = 0,\n",
|
| 109 |
-
" bias = 'none',\n",
|
| 110 |
-
" use_gradient_checkpointing = 'unsloth',\n",
|
| 111 |
-
" random_state = 42,\n",
|
| 112 |
-
")\n",
|
| 113 |
-
"print(f'Model loaded — trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}')"
|
| 114 |
]
|
| 115 |
},
|
| 116 |
{
|
| 117 |
"cell_type": "markdown",
|
| 118 |
"id": "sec3",
|
| 119 |
"metadata": {},
|
| 120 |
-
"source": [
|
|
|
|
|
|
|
| 121 |
},
|
| 122 |
{
|
| 123 |
"cell_type": "code",
|
|
@@ -129,203 +93,365 @@
|
|
| 129 |
"import subprocess, time, httpx\n",
|
| 130 |
"\n",
|
| 131 |
"server_proc = subprocess.Popen(\n",
|
| 132 |
-
" [
|
| 133 |
-
" stdout=subprocess.DEVNULL,
|
|
|
|
| 134 |
")\n",
|
| 135 |
-
"time.sleep(
|
| 136 |
"\n",
|
| 137 |
-
"health = httpx.get(
|
| 138 |
-
"assert health[
|
| 139 |
-
"print(
|
| 140 |
]
|
| 141 |
},
|
| 142 |
{
|
| 143 |
"cell_type": "markdown",
|
| 144 |
"id": "sec4",
|
| 145 |
"metadata": {},
|
| 146 |
-
"source": [
|
|
|
|
|
|
|
| 147 |
},
|
| 148 |
{
|
| 149 |
"cell_type": "code",
|
| 150 |
"execution_count": null,
|
| 151 |
-
"id": "
|
| 152 |
"metadata": {},
|
| 153 |
"outputs": [],
|
| 154 |
"source": [
|
| 155 |
-
"
|
| 156 |
-
"
|
| 157 |
"\n",
|
| 158 |
-
"
|
|
|
|
| 159 |
"\n",
|
| 160 |
-
"
|
| 161 |
-
"
|
| 162 |
-
"
|
| 163 |
-
"
|
| 164 |
-
"
|
| 165 |
-
"
|
| 166 |
-
" f\"step_count: {obs['step_count']}\\n\"\n",
|
| 167 |
-
" f\"workflow_id: {obs['workflow_id']}\\n\\n\"\n",
|
| 168 |
-
" f\"=== WORKFLOW GOAL ===\\n{obs['workflow_goal']}\\n\\n\"\n",
|
| 169 |
-
" f\"=== PENDING STEPS ===\\n\" + ('\\n'.join(f'- {s}' for s in pending) or '(done!)') + \"\\n\\n\"\n",
|
| 170 |
-
" f\"=== SCHEMA HINTS ===\\n{json.dumps(hints, indent=2)}\\n\\n\"\n",
|
| 171 |
-
" f\"=== ACTIVE RULES ===\\n{json.dumps(obs.get('active_rules', {}), indent=2)}\\n\\n\"\n",
|
| 172 |
-
" f\"=== LAST MESSAGE ===\\n{obs['message']}\\n\"\n",
|
| 173 |
-
" )\n",
|
| 174 |
"\n",
|
| 175 |
-
"
|
| 176 |
-
"
|
| 177 |
-
"
|
| 178 |
-
"
|
| 179 |
-
"
|
| 180 |
-
"
|
| 181 |
-
"
|
| 182 |
-
"
|
| 183 |
-
"
|
| 184 |
-
"
|
| 185 |
-
"
|
| 186 |
-
"
|
| 187 |
-
"
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
"
|
| 196 |
-
" Run one episode. Returns (trajectory, final_score).\n",
|
| 197 |
-
" trajectory = list of {'messages': [...], 'reward': float}\n",
|
| 198 |
-
" \"\"\"\n",
|
| 199 |
-
" resp = httpx.post('http://localhost:8000/reset', json={'workflow_id': workflow_id})\n",
|
| 200 |
-
" obs = resp.json()['observation']\n",
|
| 201 |
-
" history = []\n",
|
| 202 |
-
" trajectory = []\n",
|
| 203 |
-
" cumulative_reward = 0.0\n",
|
| 204 |
"\n",
|
| 205 |
-
"
|
| 206 |
-
"
|
| 207 |
-
"
|
| 208 |
"\n",
|
| 209 |
-
"
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
"\n",
|
| 212 |
-
"
|
| 213 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
"\n",
|
| 215 |
-
" history.append({'role': 'assistant', 'content': action_str})\n",
|
| 216 |
"\n",
|
| 217 |
-
"
|
| 218 |
-
"
|
| 219 |
-
"
|
| 220 |
-
"
|
| 221 |
-
"
|
| 222 |
-
"
|
| 223 |
-
"
|
| 224 |
-
"
|
| 225 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
"\n",
|
| 227 |
-
" if action is None:\n",
|
| 228 |
-
" cumulative_reward -= 0.05\n",
|
| 229 |
-
" break\n",
|
| 230 |
"\n",
|
| 231 |
-
"
|
| 232 |
-
"
|
| 233 |
-
"
|
| 234 |
-
"
|
| 235 |
"\n",
|
| 236 |
-
"
|
| 237 |
-
"
|
| 238 |
-
"
|
| 239 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
" })\n",
|
| 241 |
"\n",
|
| 242 |
-
"
|
| 243 |
-
"
|
| 244 |
-
"\n",
|
| 245 |
-
" return trajectory, obs.get('current_score', 0.001)\n",
|
| 246 |
-
"\n",
|
| 247 |
-
"print('Rollout harness ready.')"
|
| 248 |
]
|
| 249 |
},
|
| 250 |
{
|
| 251 |
"cell_type": "markdown",
|
| 252 |
-
"id": "
|
| 253 |
"metadata": {},
|
| 254 |
-
"source": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
},
|
| 256 |
{
|
| 257 |
"cell_type": "code",
|
| 258 |
"execution_count": null,
|
| 259 |
-
"id": "
|
| 260 |
"metadata": {},
|
| 261 |
"outputs": [],
|
| 262 |
"source": [
|
| 263 |
-
"import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
"\n",
|
| 265 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
"\n",
|
| 267 |
-
"
|
| 268 |
-
"
|
|
|
|
|
|
|
| 269 |
"\n",
|
| 270 |
-
"
|
| 271 |
-
"
|
| 272 |
-
"
|
| 273 |
-
"
|
| 274 |
-
"
|
| 275 |
-
"
|
| 276 |
-
" print(f' Workflow {wf} ep {ep+1}: score={score:.4f}', end='\\r')\n",
|
| 277 |
-
" print(f' Workflow {wf}: mean={np.mean(baseline_scores[wf]):.4f} ± {np.std(baseline_scores[wf]):.4f}')\n",
|
| 278 |
"\n",
|
| 279 |
-
"
|
| 280 |
-
"
|
| 281 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
]
|
| 283 |
},
|
| 284 |
{
|
| 285 |
"cell_type": "markdown",
|
| 286 |
-
"id": "
|
| 287 |
"metadata": {},
|
| 288 |
-
"source": [
|
|
|
|
|
|
|
| 289 |
},
|
| 290 |
{
|
| 291 |
"cell_type": "code",
|
| 292 |
"execution_count": null,
|
| 293 |
-
"id": "
|
| 294 |
"metadata": {},
|
| 295 |
"outputs": [],
|
| 296 |
"source": [
|
| 297 |
-
"
|
| 298 |
"\n",
|
| 299 |
-
"
|
| 300 |
-
"
|
| 301 |
-
"
|
| 302 |
-
"
|
| 303 |
-
" \"\"\"\n",
|
| 304 |
-
"
|
| 305 |
-
"
|
| 306 |
-
"
|
| 307 |
-
"
|
| 308 |
-
"
|
| 309 |
-
"
|
| 310 |
-
"
|
| 311 |
-
"
|
| 312 |
-
"
|
| 313 |
-
" )\n",
|
| 314 |
-
"
|
| 315 |
-
"
|
| 316 |
-
"\n",
|
| 317 |
-
"
|
| 318 |
-
"
|
| 319 |
-
"
|
| 320 |
-
"
|
| 321 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
]
|
| 323 |
},
|
| 324 |
{
|
| 325 |
"cell_type": "markdown",
|
| 326 |
-
"id": "
|
| 327 |
"metadata": {},
|
| 328 |
-
"source": [
|
|
|
|
|
|
|
| 329 |
},
|
| 330 |
{
|
| 331 |
"cell_type": "code",
|
|
@@ -336,164 +462,174 @@
|
|
| 336 |
"source": [
|
| 337 |
"from trl import GRPOConfig, GRPOTrainer\n",
|
| 338 |
"\n",
|
| 339 |
-
"#
|
| 340 |
-
"
|
| 341 |
-
" \"\"\"GRPO reward function — called on each group of completions.\"\"\"\n",
|
| 342 |
-
" # In GRPO the rewards come from rollouts; we pre-compute them above.\n",
|
| 343 |
-
" # This function returns the rewards already stored in the dataset.\n",
|
| 344 |
-
" return kwargs.get('reward', [0.0] * len(completions))\n",
|
| 345 |
"\n",
|
| 346 |
"grpo_config = GRPOConfig(\n",
|
| 347 |
-
" output_dir
|
| 348 |
-
" num_train_epochs
|
| 349 |
-
" per_device_train_batch_size =
|
| 350 |
-
" gradient_accumulation_steps =
|
| 351 |
-
" learning_rate
|
| 352 |
-
" warmup_steps
|
| 353 |
-
" logging_steps
|
| 354 |
-
" save_steps
|
| 355 |
-
"
|
| 356 |
-
"
|
| 357 |
-
" max_grad_norm
|
| 358 |
" # GRPO-specific\n",
|
| 359 |
-
" num_generations
|
| 360 |
-
" max_new_tokens
|
| 361 |
-
" temperature
|
| 362 |
-
" beta
|
| 363 |
-
" report_to
|
| 364 |
-
" seed
|
| 365 |
")\n",
|
| 366 |
"\n",
|
| 367 |
"trainer = GRPOTrainer(\n",
|
| 368 |
" model = model,\n",
|
| 369 |
" args = grpo_config,\n",
|
| 370 |
-
" reward_funcs =
|
| 371 |
-
" train_dataset =
|
| 372 |
-
"
|
| 373 |
")\n",
|
| 374 |
"\n",
|
| 375 |
-
"print(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
"train_result = trainer.train()\n",
|
| 377 |
-
"print(
|
| 378 |
"print(train_result.metrics)"
|
| 379 |
]
|
| 380 |
},
|
| 381 |
{
|
| 382 |
"cell_type": "markdown",
|
| 383 |
-
"id": "
|
| 384 |
"metadata": {},
|
| 385 |
-
"source": [
|
|
|
|
|
|
|
| 386 |
},
|
| 387 |
{
|
| 388 |
"cell_type": "code",
|
| 389 |
"execution_count": null,
|
| 390 |
-
"id": "
|
| 391 |
"metadata": {},
|
| 392 |
"outputs": [],
|
| 393 |
"source": [
|
| 394 |
-
"# Switch model to inference mode\n",
|
| 395 |
"FastLanguageModel.for_inference(model)\n",
|
| 396 |
"\n",
|
| 397 |
-
"
|
| 398 |
-
"post_scores = {'A': [], 'B': [], 'C': []}\n",
|
| 399 |
"\n",
|
| 400 |
-
"print(
|
| 401 |
-
"for wf in [
|
| 402 |
-
" for ep in range(N_EVAL
|
| 403 |
-
"
|
| 404 |
" post_scores[wf].append(score)\n",
|
| 405 |
-
" print(f
|
| 406 |
-
" print(f
|
| 407 |
"\n",
|
| 408 |
-
"
|
|
|
|
|
|
|
| 409 |
]
|
| 410 |
},
|
| 411 |
{
|
| 412 |
"cell_type": "markdown",
|
| 413 |
-
"id": "
|
| 414 |
"metadata": {},
|
| 415 |
-
"source": [
|
|
|
|
|
|
|
| 416 |
},
|
| 417 |
{
|
| 418 |
"cell_type": "code",
|
| 419 |
"execution_count": null,
|
| 420 |
-
"id": "
|
| 421 |
"metadata": {},
|
| 422 |
"outputs": [],
|
| 423 |
"source": [
|
| 424 |
"import matplotlib.pyplot as plt\n",
|
| 425 |
"import matplotlib.gridspec as gridspec\n",
|
| 426 |
"\n",
|
| 427 |
-
"fig = plt.figure(figsize=(14, 8), facecolor=
|
| 428 |
-
"fig.suptitle(
|
| 429 |
-
" color=
|
| 430 |
"\n",
|
| 431 |
"gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.45, wspace=0.35)\n",
|
| 432 |
"\n",
|
| 433 |
-
"COLORS = {
|
| 434 |
-
"WF_LABELS = {
|
| 435 |
-
"
|
| 436 |
-
"
|
|
|
|
|
|
|
| 437 |
"\n",
|
| 438 |
-
"for col, wf in enumerate([
|
| 439 |
" ax = fig.add_subplot(gs[0, col])\n",
|
| 440 |
-
" ax.set_facecolor(COLORS[
|
| 441 |
-
" ax.grid(color=COLORS[
|
| 442 |
"\n",
|
| 443 |
" before = baseline_scores[wf]\n",
|
| 444 |
" after = post_scores[wf]\n",
|
|
|
|
| 445 |
"\n",
|
| 446 |
-
" ax.plot(before, color=COLORS[
|
| 447 |
-
" ax.plot(after, color=COLORS[
|
| 448 |
-
"\n",
|
| 449 |
-
" ax.axhline(np.mean(
|
| 450 |
-
" ax.axhline(np.mean(after), color=COLORS['after'], linestyle='--', linewidth=1, alpha=0.5)\n",
|
| 451 |
"\n",
|
| 452 |
-
"
|
| 453 |
-
" ax.
|
| 454 |
-
" ax.
|
| 455 |
-
" ax.
|
| 456 |
-
" ax.tick_params(colors='#64748b', labelsize=7)\n",
|
| 457 |
" ax.set_ylim(0, 1)\n",
|
| 458 |
-
" ax.legend(fontsize=7, facecolor=
|
| 459 |
-
" edgecolor=
|
| 460 |
" for spine in ax.spines.values():\n",
|
| 461 |
-
" spine.set_edgecolor(
|
| 462 |
"\n",
|
| 463 |
-
"# Bottom row: combined histogram\n",
|
| 464 |
"ax_hist = fig.add_subplot(gs[1, :])\n",
|
| 465 |
-
"ax_hist.set_facecolor(COLORS[
|
| 466 |
-
"ax_hist.grid(color=COLORS[
|
| 467 |
"\n",
|
| 468 |
"all_before = [s for v in baseline_scores.values() for s in v]\n",
|
| 469 |
"all_after = [s for v in post_scores.values() for s in v]\n",
|
| 470 |
-
"\n",
|
| 471 |
"bins = np.linspace(0, 1, 25)\n",
|
| 472 |
-
"
|
| 473 |
-
"ax_hist.hist(
|
| 474 |
-
"
|
| 475 |
-
"ax_hist.
|
| 476 |
-
"\n",
|
| 477 |
-
"ax_hist.
|
| 478 |
-
"ax_hist.
|
| 479 |
-
"
|
| 480 |
-
"ax_hist.
|
| 481 |
-
"ax_hist.
|
| 482 |
-
"
|
|
|
|
|
|
|
|
|
|
| 483 |
"for spine in ax_hist.spines.values():\n",
|
| 484 |
-
" spine.set_edgecolor(
|
| 485 |
"\n",
|
| 486 |
-
"plt.savefig(
|
| 487 |
-
" facecolor=
|
| 488 |
"plt.show()\n",
|
| 489 |
-
"print(
|
| 490 |
]
|
| 491 |
},
|
| 492 |
{
|
| 493 |
"cell_type": "markdown",
|
| 494 |
-
"id": "
|
| 495 |
"metadata": {},
|
| 496 |
-
"source": [
|
|
|
|
|
|
|
| 497 |
},
|
| 498 |
{
|
| 499 |
"cell_type": "code",
|
|
@@ -502,49 +638,29 @@
|
|
| 502 |
"metadata": {},
|
| 503 |
"outputs": [],
|
| 504 |
"source": [
|
| 505 |
-
"
|
| 506 |
-
"
|
| 507 |
-
"
|
| 508 |
-
"print('LoRA adapter saved to ./orgos_lora_adapter')\n",
|
| 509 |
"\n",
|
| 510 |
-
"#
|
| 511 |
"# from huggingface_hub import login\n",
|
| 512 |
-
"# login(token=
|
| 513 |
-
"# model.push_to_hub(
|
| 514 |
-
"# tokenizer.push_to_hub(
|
| 515 |
-
"# print('Pushed to HuggingFace Hub!')"
|
| 516 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 517 |
},
|
| 518 |
-
{
|
| 519 |
-
"
|
| 520 |
-
"
|
| 521 |
-
"metadata": {},
|
| 522 |
-
"source": [
|
| 523 |
-
"## 11. Summary\n",
|
| 524 |
-
"\n",
|
| 525 |
-
"```\n",
|
| 526 |
-
"OrgOS GRPO Training Summary\n",
|
| 527 |
-
"============================\n",
|
| 528 |
-
"Model: Qwen2.5-3B-Instruct + 4-bit LoRA\n",
|
| 529 |
-
"Algorithm: GRPO (Group Relative Policy Optimization)\n",
|
| 530 |
-
"Epochs: 3\n",
|
| 531 |
-
"Episodes: 30 baseline + 30 post-training\n",
|
| 532 |
-
"\n",
|
| 533 |
-
"Key result: The GRPO-trained model learns to:\n",
|
| 534 |
-
" 1. Read schema_hints before constructing action args\n",
|
| 535 |
-
" 2. Use drifted field names (e.g. 'severity' not 'priority')\n",
|
| 536 |
-
" 3. Complete workflow steps in the correct order\n",
|
| 537 |
-
" 4. Avoid RBAC violations by checking role constraints\n",
|
| 538 |
-
"\n",
|
| 539 |
-
"This produces a clear, measurable improvement visible in\n",
|
| 540 |
-
"before_after_curves.png — the core evidence for judging.\n",
|
| 541 |
-
"```\n",
|
| 542 |
-
"\n",
|
| 543 |
-
"**Artefacts produced:**\n",
|
| 544 |
-
"- `before_after_curves.png` — the money chart for the pitch\n",
|
| 545 |
-
"- `orgos_lora_adapter/` — the trained LoRA weights\n",
|
| 546 |
-
"- `baseline_scores.json` — raw score data"
|
| 547 |
-
]
|
| 548 |
}
|
| 549 |
-
|
|
|
|
|
|
|
| 550 |
}
|
|
|
|
| 1 |
{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "markdown",
|
| 5 |
"id": "title",
|
| 6 |
"metadata": {},
|
| 7 |
"source": [
|
| 8 |
+
"# OrgOS GRPO Training\n",
|
| 9 |
"\n",
|
| 10 |
"**Environment:** OrgOS — Multi-App Enterprise RL Environment \n",
|
| 11 |
"**Model:** `Qwen/Qwen2.5-3B-Instruct` (4-bit LoRA via Unsloth) \n",
|
| 12 |
"**Algorithm:** GRPO (Group Relative Policy Optimization) via HuggingFace TRL \n",
|
| 13 |
+
"**Target hardware:** HuggingFace compute (A10G / A100) \n",
|
| 14 |
+
"\n",
|
| 15 |
+
"## How this works\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"GRPO is an **online** RL algorithm:\n",
|
| 18 |
+
"1. Each training step takes a batch of **prompts** (observations from the env)\n",
|
| 19 |
+
"2. The model generates **G candidate actions** per prompt (the group)\n",
|
| 20 |
+
"3. Each action is sent to the **live OrgOS env** to get a real reward\n",
|
| 21 |
+
"4. GRPO computes relative advantages within the group (which action did better than average?)\n",
|
| 22 |
+
"5. Model is updated to favour higher-reward actions\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"**Key training signal:** Schema drift creates a sharp reward gap.\n",
|
| 25 |
+
"Using a stale field name (e.g. `priority` when schema says `severity`) → **−0.20**. \n",
|
| 26 |
+
"Using the correct drifted name → **+0.10** adaptation bonus. \n",
|
| 27 |
+
"The model learns to read `schema_hints` before constructing action args."
|
| 28 |
]
|
| 29 |
},
|
| 30 |
{
|
| 31 |
"cell_type": "markdown",
|
| 32 |
"id": "sec1",
|
| 33 |
"metadata": {},
|
| 34 |
+
"source": [
|
| 35 |
+
"## 1. Install Dependencies"
|
| 36 |
+
]
|
| 37 |
},
|
| 38 |
{
|
| 39 |
"cell_type": "code",
|
|
|
|
| 42 |
"metadata": {},
|
| 43 |
"outputs": [],
|
| 44 |
"source": [
|
| 45 |
+
"!pip install -q \"unsloth[huggingface]\" \"trl>=0.12.0\" peft accelerate bitsandbytes\n",
|
| 46 |
+
"!pip install -q fastapi uvicorn httpx openai pydantic python-dotenv\n",
|
| 47 |
+
"!pip install -q matplotlib numpy datasets"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
]
|
| 49 |
},
|
| 50 |
{
|
| 51 |
"cell_type": "markdown",
|
| 52 |
"id": "sec2",
|
| 53 |
"metadata": {},
|
| 54 |
+
"source": [
|
| 55 |
+
"## 2. Clone the OrgOS Repo"
|
| 56 |
+
]
|
| 57 |
},
|
| 58 |
{
|
| 59 |
"cell_type": "code",
|
| 60 |
"execution_count": null,
|
| 61 |
+
"id": "clone_repo",
|
| 62 |
"metadata": {},
|
| 63 |
"outputs": [],
|
| 64 |
"source": [
|
| 65 |
+
"import os\n",
|
|
|
|
| 66 |
"\n",
|
| 67 |
+
"REPO_URL = \"https://huggingface.co/spaces/tanvibisht/orgos-openenv\"\n",
|
| 68 |
+
"REPO_DIR = \"/home/user/orgos\"\n",
|
| 69 |
"\n",
|
| 70 |
+
"if not os.path.exists(REPO_DIR):\n",
|
| 71 |
+
" !git clone {REPO_URL} {REPO_DIR}\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
"\n",
|
| 73 |
+
"os.chdir(REPO_DIR)\n",
|
| 74 |
+
"print(\"Working directory:\", os.getcwd())\n",
|
| 75 |
+
"!ls"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
]
|
| 77 |
},
|
| 78 |
{
|
| 79 |
"cell_type": "markdown",
|
| 80 |
"id": "sec3",
|
| 81 |
"metadata": {},
|
| 82 |
+
"source": [
|
| 83 |
+
"## 3. Start the OrgOS Environment Server"
|
| 84 |
+
]
|
| 85 |
},
|
| 86 |
{
|
| 87 |
"cell_type": "code",
|
|
|
|
| 93 |
"import subprocess, time, httpx\n",
|
| 94 |
"\n",
|
| 95 |
"server_proc = subprocess.Popen(\n",
|
| 96 |
+
" [\"python\", \"-m\", \"uvicorn\", \"server.app:app\", \"--host\", \"0.0.0.0\", \"--port\", \"8000\"],\n",
|
| 97 |
+
" stdout=subprocess.DEVNULL,\n",
|
| 98 |
+
" stderr=subprocess.DEVNULL,\n",
|
| 99 |
")\n",
|
| 100 |
+
"time.sleep(4)\n",
|
| 101 |
"\n",
|
| 102 |
+
"health = httpx.get(\"http://localhost:8000/health\").json()\n",
|
| 103 |
+
"assert health[\"status\"] == \"healthy\", f\"Server not healthy: {health}\"\n",
|
| 104 |
+
"print(\"OrgOS server running:\", health)"
|
| 105 |
]
|
| 106 |
},
|
| 107 |
{
|
| 108 |
"cell_type": "markdown",
|
| 109 |
"id": "sec4",
|
| 110 |
"metadata": {},
|
| 111 |
+
"source": [
|
| 112 |
+
"## 4. Load Model with Unsloth 4-bit LoRA"
|
| 113 |
+
]
|
| 114 |
},
|
| 115 |
{
|
| 116 |
"cell_type": "code",
|
| 117 |
"execution_count": null,
|
| 118 |
+
"id": "load_model",
|
| 119 |
"metadata": {},
|
| 120 |
"outputs": [],
|
| 121 |
"source": [
|
| 122 |
+
"from unsloth import FastLanguageModel\n",
|
| 123 |
+
"import torch\n",
|
| 124 |
"\n",
|
| 125 |
+
"MAX_SEQ_LEN = 2048\n",
|
| 126 |
+
"MODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\"\n",
|
| 127 |
"\n",
|
| 128 |
+
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
| 129 |
+
" model_name = MODEL_NAME,\n",
|
| 130 |
+
" max_seq_length = MAX_SEQ_LEN,\n",
|
| 131 |
+
" dtype = None,\n",
|
| 132 |
+
" load_in_4bit = True,\n",
|
| 133 |
+
")\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
"\n",
|
| 135 |
+
"model = FastLanguageModel.get_peft_model(\n",
|
| 136 |
+
" model,\n",
|
| 137 |
+
" r = 16,\n",
|
| 138 |
+
" target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
|
| 139 |
+
" \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
|
| 140 |
+
" lora_alpha = 16,\n",
|
| 141 |
+
" lora_dropout = 0,\n",
|
| 142 |
+
" bias = \"none\",\n",
|
| 143 |
+
" use_gradient_checkpointing = \"unsloth\",\n",
|
| 144 |
+
" random_state = 42,\n",
|
| 145 |
+
")\n",
|
| 146 |
+
"trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
| 147 |
+
"print(f\"Model loaded — trainable params: {trainable:,}\")"
|
| 148 |
+
]
|
| 149 |
+
},
|
| 150 |
+
{
|
| 151 |
+
"cell_type": "markdown",
|
| 152 |
+
"id": "sec5",
|
| 153 |
+
"metadata": {},
|
| 154 |
+
"source": [
|
| 155 |
+
"## 5. Prompt Dataset\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
"\n",
|
| 157 |
+
"We collect **first-turn observations** from fresh episode resets as our prompt dataset.\n",
|
| 158 |
+
"These are the most important turns — they contain `schema_hints`, `active_rules`, and the\n",
|
| 159 |
+
"full workflow goal. The model must learn to read schema hints and produce a correct first action.\n",
|
| 160 |
"\n",
|
| 161 |
+
"During GRPO training, the reward function will reset the env and evaluate each generated action live."
|
| 162 |
+
]
|
| 163 |
+
},
|
| 164 |
+
{
|
| 165 |
+
"cell_type": "code",
|
| 166 |
+
"execution_count": null,
|
| 167 |
+
"id": "build_prompts",
|
| 168 |
+
"metadata": {},
|
| 169 |
+
"outputs": [],
|
| 170 |
+
"source": [
|
| 171 |
+
"import json\n",
|
| 172 |
+
"from datasets import Dataset\n",
|
| 173 |
"\n",
|
| 174 |
+
"SYSTEM_PROMPT = \"\"\"\\\n",
|
| 175 |
+
"You are OrgOS Agent — an enterprise workflow automation agent.\n",
|
| 176 |
+
"You operate across four SaaS applications: Jira, Zendesk, Salesforce, and Workday.\n",
|
| 177 |
+
"\n",
|
| 178 |
+
"Each turn you receive a JSON observation with:\n",
|
| 179 |
+
" - workflow_goal : the task you must complete\n",
|
| 180 |
+
" - pending_steps : remaining steps in the workflow\n",
|
| 181 |
+
" - app_states : current state of each app\n",
|
| 182 |
+
" - schema_hints : field renames in effect this episode (e.g. {\"jira.priority\": \"severity\"})\n",
|
| 183 |
+
" - active_rules : current SLA / approval thresholds\n",
|
| 184 |
+
" - message : feedback from the last action\n",
|
| 185 |
+
" - current_score : your cumulative score (0.001-0.999)\n",
|
| 186 |
+
"\n",
|
| 187 |
+
"Respond ONLY with a valid JSON object — no markdown, no explanation.\n",
|
| 188 |
+
"\n",
|
| 189 |
+
"Action format:\n",
|
| 190 |
+
" {\"app\": \"<app>\", \"operation\": \"<op>\", \"args\": {...}}\n",
|
| 191 |
+
"\n",
|
| 192 |
+
"Available apps and key operations:\n",
|
| 193 |
+
" jira: get_issue, create_issue, update_status, set_priority, assign_owner,\n",
|
| 194 |
+
" add_label, link_zendesk_ticket, close_issue, list_issues\n",
|
| 195 |
+
" zendesk: get_ticket, acknowledge_ticket, set_urgency, assign_agent,\n",
|
| 196 |
+
" escalate_to_jira, resolve_ticket, add_note, list_tickets,\n",
|
| 197 |
+
" create_agent_profile\n",
|
| 198 |
+
" salesforce: get_account, list_accounts, update_deal_stage, flag_churn_risk,\n",
|
| 199 |
+
" assign_account_owner, log_interaction, get_opportunity\n",
|
| 200 |
+
" workday: get_employee, list_employees, provision_access, log_sla_event,\n",
|
| 201 |
+
" request_budget_approval, create_onboarding_task, complete_task\n",
|
| 202 |
+
"\n",
|
| 203 |
+
"CRITICAL RULES:\n",
|
| 204 |
+
"1. Read schema_hints FIRST — if \"jira.priority\" -> \"severity\", use \"severity\" not \"priority\" in args.\n",
|
| 205 |
+
"2. Complete ALL pending_steps in order.\n",
|
| 206 |
+
"3. Do not repeat a successful action.\n",
|
| 207 |
+
"4. If an operation fails, read the message carefully and adapt.\n",
|
| 208 |
+
"5. Use list_* operations to discover record IDs when needed.\n",
|
| 209 |
+
"6. Stop when pending_steps is empty or done=true.\n",
|
| 210 |
+
"\"\"\"\n",
|
| 211 |
"\n",
|
|
|
|
| 212 |
"\n",
|
| 213 |
+
"def obs_to_text(obs: dict) -> str:\n",
|
| 214 |
+
" hints = obs.get(\"schema_hints\", {})\n",
|
| 215 |
+
" pending = obs.get(\"pending_steps\", [])\n",
|
| 216 |
+
" lines = [\n",
|
| 217 |
+
" f\"current_score: {obs['current_score']}\",\n",
|
| 218 |
+
" f\"step_count: {obs['step_count']}\",\n",
|
| 219 |
+
" f\"workflow_id: {obs['workflow_id']}\",\n",
|
| 220 |
+
" \"\",\n",
|
| 221 |
+
" \"=== WORKFLOW GOAL ===\",\n",
|
| 222 |
+
" obs[\"workflow_goal\"],\n",
|
| 223 |
+
" \"\",\n",
|
| 224 |
+
" \"=== PENDING STEPS ===\",\n",
|
| 225 |
+
" \"\\n\".join(f\" - {s}\" for s in pending) or \" (all steps complete!)\",\n",
|
| 226 |
+
" \"\",\n",
|
| 227 |
+
" \"=== SCHEMA HINTS (use these field names) ===\",\n",
|
| 228 |
+
" json.dumps(hints, indent=2) if hints else \" (no drift — use canonical names)\",\n",
|
| 229 |
+
" \"\",\n",
|
| 230 |
+
" \"=== ACTIVE RULES ===\",\n",
|
| 231 |
+
" json.dumps(obs.get(\"active_rules\", {}), indent=2),\n",
|
| 232 |
+
" \"\",\n",
|
| 233 |
+
" \"=== LAST MESSAGE ===\",\n",
|
| 234 |
+
" obs[\"message\"],\n",
|
| 235 |
+
" \"\",\n",
|
| 236 |
+
" \"=== APP STATES ===\",\n",
|
| 237 |
+
" ]\n",
|
| 238 |
+
" for app_name, view in obs.get(\"app_states\", {}).items():\n",
|
| 239 |
+
" lines.append(f\" [{app_name.upper()}]\")\n",
|
| 240 |
+
" lines.append(f\" {view}\")\n",
|
| 241 |
+
" lines.append(\"\")\n",
|
| 242 |
+
" return \"\\n\".join(lines)\n",
|
| 243 |
+
"\n",
|
| 244 |
+
"\n",
|
| 245 |
+
"def build_prompt(obs_text: str) -> str:\n",
|
| 246 |
+
" \"\"\"Format as a chat prompt with system injected into first user message.\"\"\"\n",
|
| 247 |
+
" messages = [{\"role\": \"user\", \"content\": SYSTEM_PROMPT + \"\\n\\n---\\n\\n\" + obs_text}]\n",
|
| 248 |
+
" return tokenizer.apply_chat_template(\n",
|
| 249 |
+
" messages, tokenize=False, add_generation_prompt=True\n",
|
| 250 |
+
" )\n",
|
| 251 |
"\n",
|
|
|
|
|
|
|
|
|
|
| 252 |
"\n",
|
| 253 |
+
"# Collect first-turn observations across all 3 workflows, multiple episodes\n",
|
| 254 |
+
"# Each episode has a different schema version (seed varies) so we get diverse prompts\n",
|
| 255 |
+
"N_PROMPTS_PER_WORKFLOW = 20\n",
|
| 256 |
+
"prompt_rows = []\n",
|
| 257 |
"\n",
|
| 258 |
+
"print(\"Collecting prompts from env resets...\")\n",
|
| 259 |
+
"for wf in [\"A\", \"B\", \"C\"]:\n",
|
| 260 |
+
" for _ in range(N_PROMPTS_PER_WORKFLOW):\n",
|
| 261 |
+
" result = httpx.post(\"http://localhost:8000/reset\", json={\"workflow_id\": wf}).json()\n",
|
| 262 |
+
" obs = result[\"observation\"]\n",
|
| 263 |
+
" obs_text = obs_to_text(obs)\n",
|
| 264 |
+
" prompt_rows.append({\n",
|
| 265 |
+
" \"prompt\": build_prompt(obs_text),\n",
|
| 266 |
+
" \"workflow_id\": wf,\n",
|
| 267 |
+
" \"obs_text\": obs_text,\n",
|
| 268 |
" })\n",
|
| 269 |
"\n",
|
| 270 |
+
"prompt_dataset = Dataset.from_list(prompt_rows)\n",
|
| 271 |
+
"print(f\"Prompt dataset: {len(prompt_dataset)} examples across 3 workflows\")\n",
|
| 272 |
+
"print(\"Sample prompt (truncated):\\n\", prompt_rows[0][\"prompt\"][:600], \"...\")"
|
|
|
|
|
|
|
|
|
|
| 273 |
]
|
| 274 |
},
|
| 275 |
{
|
| 276 |
"cell_type": "markdown",
|
| 277 |
+
"id": "sec6",
|
| 278 |
"metadata": {},
|
| 279 |
+
"source": [
|
| 280 |
+
"## 6. Reward Function\n",
|
| 281 |
+
"\n",
|
| 282 |
+
"Called by GRPOTrainer during training on each batch of generated completions.\n",
|
| 283 |
+
"For each completion:\n",
|
| 284 |
+
"1. Parse it as action JSON\n",
|
| 285 |
+
"2. Reset the env to a fresh episode for the right workflow\n",
|
| 286 |
+
"3. Send the action via `/step`\n",
|
| 287 |
+
"4. Return the reward\n",
|
| 288 |
+
"\n",
|
| 289 |
+
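"This gives the model a live signal from the actual environment. Note that each evaluation resets to a *new* episode, so if schema drift varies per episode the env being scored may not match the `schema_hints` shown in the prompt."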
|
| 290 |
+
]
|
| 291 |
},
|
| 292 |
{
|
| 293 |
"cell_type": "code",
|
| 294 |
"execution_count": null,
|
| 295 |
+
"id": "reward_fn",
|
| 296 |
"metadata": {},
|
| 297 |
"outputs": [],
|
| 298 |
"source": [
|
| 299 |
+
"import re\n",
|
| 300 |
+
"from typing import List\n",
|
| 301 |
+
"\n",
|
| 302 |
+
"ENV_URL = \"http://localhost:8000\"\n",
|
| 303 |
+
"\n",
|
| 304 |
+
"\n",
|
| 305 |
+
"def parse_action(text: str):\n",
|
| 306 |
+
" \"\"\"Extract JSON action from model output.\"\"\"\n",
|
| 307 |
+
" text = text.strip()\n",
|
| 308 |
+
" # Strip markdown code fences if present\n",
|
| 309 |
+
" text = re.sub(r\"```(?:json)?\\s*\", \"\", text).strip()\n",
|
| 310 |
+
" try:\n",
|
| 311 |
+
" return json.loads(text)\n",
|
| 312 |
+
" except json.JSONDecodeError:\n",
|
| 313 |
+
" m = re.search(r\"\\{.*\\}\", text, re.DOTALL)\n",
|
| 314 |
+
" if m:\n",
|
| 315 |
+
" try:\n",
|
| 316 |
+
" return json.loads(m.group())\n",
|
| 317 |
+
" except Exception:\n",
|
| 318 |
+
" pass\n",
|
| 319 |
+
" return None\n",
|
| 320 |
+
"\n",
|
| 321 |
+
"\n",
|
| 322 |
+
"def orgos_reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:\n",
|
| 323 |
+
" \"\"\"\n",
|
| 324 |
+
" GRPO reward function — called by GRPOTrainer each training step.\n",
|
| 325 |
"\n",
|
| 326 |
+
" For each generated completion:\n",
|
| 327 |
+
" - Parse as action JSON\n",
|
| 328 |
+
" - Reset env to a fresh episode (workflow taken from the workflow_id kwarg, defaulting to A)\n",
|
| 329 |
+
" - Step the env with the action\n",
|
| 330 |
+
" - Return the step reward\n",
|
| 331 |
"\n",
|
| 332 |
+
" Unparseable output, or any exception while calling the env, yields a -0.1 penalty; otherwise the env's own step reward is returned.\n",
|
| 333 |
+
" \"\"\"\n",
|
| 334 |
+
" workflow_ids = kwargs.get(\"workflow_id\", [\"A\"] * len(completions))\n",
|
| 335 |
+
" rewards = []\n",
|
| 336 |
"\n",
|
| 337 |
+
" for completion, wf_id in zip(completions, workflow_ids):\n",
|
| 338 |
+
" action = parse_action(completion)\n",
|
| 339 |
+
"\n",
|
| 340 |
+
" if action is None:\n",
|
| 341 |
+
" rewards.append(-0.1)\n",
|
| 342 |
+
" continue\n",
|
|
|
|
|
|
|
| 343 |
"\n",
|
| 344 |
+
" try:\n",
|
| 345 |
+
" # Fresh episode for this action evaluation\n",
|
| 346 |
+
" httpx.post(f\"{ENV_URL}/reset\", json={\"workflow_id\": wf_id}, timeout=10)\n",
|
| 347 |
+
" result = httpx.post(f\"{ENV_URL}/step\", json=action, timeout=10).json()\n",
|
| 348 |
+
" rewards.append(float(result[\"reward\"]))\n",
|
| 349 |
+
" except Exception:\n",
|
| 350 |
+
" rewards.append(-0.1)\n",
|
| 351 |
+
"\n",
|
| 352 |
+
" return rewards\n",
|
| 353 |
+
"\n",
|
| 354 |
+
"\n",
|
| 355 |
+
"print(\"Reward function defined.\")\n",
|
| 356 |
+
"print(\"Quick sanity check...\")\n",
|
| 357 |
+
"test_rewards = orgos_reward_fn(\n",
|
| 358 |
+
" completions = ['{\"app\": \"zendesk\", \"operation\": \"list_tickets\", \"args\": {\"state\": \"new\"}}',\n",
|
| 359 |
+
" 'this is not valid json'],\n",
|
| 360 |
+
" prompts = [\"\", \"\"],\n",
|
| 361 |
+
" workflow_id = [\"A\", \"A\"],\n",
|
| 362 |
+
")\n",
|
| 363 |
+
"print(f\" Valid action reward: {test_rewards[0]:.4f}\")\n",
|
| 364 |
+
"print(f\" Invalid action reward: {test_rewards[1]:.4f}\")"
|
| 365 |
]
|
| 366 |
},
|
| 367 |
{
|
| 368 |
"cell_type": "markdown",
|
| 369 |
+
"id": "sec7",
|
| 370 |
"metadata": {},
|
| 371 |
+
"source": [
|
| 372 |
+
"## 7. Collect Baseline Scores (Pre-Training)"
|
| 373 |
+
]
|
| 374 |
},
|
| 375 |
{
|
| 376 |
"cell_type": "code",
|
| 377 |
"execution_count": null,
|
| 378 |
+
"id": "baseline",
|
| 379 |
"metadata": {},
|
| 380 |
"outputs": [],
|
| 381 |
"source": [
|
| 382 |
+
"import numpy as np\n",
|
| 383 |
"\n",
|
| 384 |
+
"FastLanguageModel.for_inference(model)\n",
|
| 385 |
+
"\n",
|
| 386 |
+
"\n",
|
| 387 |
+
"def run_episode_with_model(workflow_id: str, max_steps: int = 15) -> float:\n",
"    \"\"\"Play one episode with the current model and return the final score.\n",
"\n",
"    The env is reset for `workflow_id`. On every turn the latest observation is\n",
"    rendered with `obs_to_text`, appended to the chat history, and the model\n",
"    greedily decodes one action. The loop ends when the observation reports\n",
"    `done`, when `parse_action` rejects the model output, or after `max_steps`\n",
"    turns.\n",
"\n",
"    Args:\n",
"        workflow_id: Workflow identifier posted to the env's /reset endpoint.\n",
"        max_steps: Maximum number of model actions in the episode.\n",
"\n",
"    Returns:\n",
"        `current_score` from the last observation, or 0.001 if the key is missing.\n",
"\n",
"    Raises:\n",
"        httpx.HTTPStatusError: If /reset or /step answers with an error status.\n",
"    \"\"\"\n",
"    # Same explicit timeout as the reward function; fail loudly on HTTP errors\n",
"    # instead of hitting an opaque KeyError on the response body.\n",
"    response = httpx.post(f\"{ENV_URL}/reset\", json={\"workflow_id\": workflow_id}, timeout=10)\n",
"    response.raise_for_status()\n",
"    obs = response.json()[\"observation\"]\n",
"    history = []\n",
"\n",
"    for _ in range(max_steps):\n",
"        if obs[\"done\"]:\n",
"            break\n",
"\n",
"        history.append({\"role\": \"user\", \"content\": obs_to_text(obs)})\n",
"\n",
"        # Fold the system prompt into the first user turn of a shallow copy,\n",
"        # so `history` itself keeps the raw observation text.\n",
"        messages = list(history)\n",
"        messages[0] = {\"role\": \"user\", \"content\": SYSTEM_PROMPT + \"\\n\\n---\\n\\n\" + messages[0][\"content\"]}\n",
"\n",
"        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
"        inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
"\n",
"        with torch.no_grad():\n",
"            # Greedy decoding: `temperature` is a sampling-only flag, so it is\n",
"            # not passed alongside do_sample=False.\n",
"            out = model.generate(\n",
"                **inputs,\n",
"                max_new_tokens = 256,\n",
"                do_sample      = False,\n",
"                pad_token_id   = tokenizer.eos_token_id,\n",
"            )\n",
"\n",
"        # Decode only the tokens generated after the prompt.\n",
"        prompt_len = inputs[\"input_ids\"].shape[1]\n",
"        action_str = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()\n",
"        history.append({\"role\": \"assistant\", \"content\": action_str})\n",
"\n",
"        action = parse_action(action_str)\n",
"        if action is None:\n",
"            # Unparseable output ends the episode; the score so far is returned.\n",
"            break\n",
"\n",
"        response = httpx.post(f\"{ENV_URL}/step\", json=action, timeout=10)\n",
"        response.raise_for_status()\n",
"        obs = response.json()[\"observation\"]\n",
"\n",
"    return obs.get(\"current_score\", 0.001)\n",
|
| 431 |
+
"\n",
|
| 432 |
+
"\n",
|
| 433 |
+
"N_EVAL = 10   # evaluation episodes run for each workflow\n",
"baseline_scores = {\"A\": [], \"B\": [], \"C\": []}\n",
"\n",
"print(\"Collecting pre-training baseline scores...\")\n",
"for wf, wf_scores in baseline_scores.items():\n",
"    for ep in range(N_EVAL):\n",
"        score = run_episode_with_model(wf)\n",
"        wf_scores.append(score)\n",
"        # Carriage return keeps the per-episode progress on a single line.\n",
"        print(f\"  Workflow {wf} ep {ep+1}/{N_EVAL}: score={score:.4f}\", end=\"\\r\")\n",
"    print(f\"  Workflow {wf}: mean={np.mean(baseline_scores[wf]):.4f}\")\n",
"\n",
"# Pool every episode across workflows for the overall mean.\n",
"baseline_mean = np.mean(sum(baseline_scores.values(), []))\n",
"print(f\"\\nOverall baseline mean: {baseline_mean:.4f}\")"
|
| 446 |
]
|
| 447 |
},
|
| 448 |
{
|
| 449 |
"cell_type": "markdown",
|
| 450 |
+
"id": "sec8",
|
| 451 |
"metadata": {},
|
| 452 |
+
"source": [
|
| 453 |
+
"## 8. GRPO Training"
|
| 454 |
+
]
|
| 455 |
},
|
| 456 |
{
|
| 457 |
"cell_type": "code",
|
|
|
|
| 462 |
"source": [
|
| 463 |
"from trl import GRPOConfig, GRPOTrainer\n",
|
| 464 |
"\n",
|
| 465 |
+
"# Switch back to training mode. The baseline cell put the model into Unsloth's\n",
"# inference mode, so revert that explicitly before handing it to the trainer.\n",
"FastLanguageModel.for_training(model)\n",
"model.train()\n",
"\n",
"grpo_config = GRPOConfig(\n",
"    output_dir                  = \"./orgos_grpo_ckpt\",\n",
"    num_train_epochs            = 3,\n",
"    per_device_train_batch_size = 4,\n",
"    gradient_accumulation_steps = 2,\n",
"    learning_rate               = 5e-5,\n",
"    warmup_steps                = 10,\n",
"    logging_steps               = 5,\n",
"    save_steps                  = 100,\n",
"    bf16                        = torch.cuda.is_bf16_supported(),\n",
"    fp16                        = not torch.cuda.is_bf16_supported(),\n",
"    max_grad_norm               = 1.0,\n",
"    # GRPO-specific\n",
"    num_generations             = 4,       # G: candidate actions per prompt\n",
"    # GRPOConfig has no `max_new_tokens` field; the completion budget is\n",
"    # `max_completion_length`.\n",
"    max_completion_length       = 256,\n",
"    temperature                 = 0.8,     # exploration during training\n",
"    beta                        = 0.04,    # KL penalty coefficient\n",
"    report_to                   = \"none\",\n",
"    seed                        = 42,\n",
")\n",
"\n",
"trainer = GRPOTrainer(\n",
"    model            = model,\n",
"    args             = grpo_config,\n",
"    reward_funcs     = orgos_reward_fn,\n",
"    train_dataset    = prompt_dataset,\n",
"    processing_class = tokenizer,\n",
")\n",
"\n",
"print(\"Starting GRPO training...\")\n",
"print(f\"  Prompts: {len(prompt_dataset)}\")\n",
"print(f\"  Generations per prompt (G): {grpo_config.num_generations}\")\n",
"print(f\"  Epochs: {grpo_config.num_train_epochs}\")\n",
"print(f\"  Total env calls per epoch: ~{len(prompt_dataset) * grpo_config.num_generations}\")\n",
"print()\n",
"\n",
"train_result = trainer.train()\n",
"print(\"\\nTraining complete!\")\n",
"print(train_result.metrics)"
|
| 507 |
]
|
| 508 |
},
|
| 509 |
{
|
| 510 |
"cell_type": "markdown",
|
| 511 |
+
"id": "sec9",
|
| 512 |
"metadata": {},
|
| 513 |
+
"source": [
|
| 514 |
+
"## 9. Collect Post-Training Scores"
|
| 515 |
+
]
|
| 516 |
},
|
| 517 |
{
|
| 518 |
"cell_type": "code",
|
| 519 |
"execution_count": null,
|
| 520 |
+
"id": "post_training",
|
| 521 |
"metadata": {},
|
| 522 |
"outputs": [],
|
| 523 |
"source": [
|
|
|
|
| 524 |
"# Back to inference mode for the evaluation rollouts.\n",
"FastLanguageModel.for_inference(model)\n",
"\n",
"post_scores = {\"A\": [], \"B\": [], \"C\": []}\n",
"\n",
"print(\"Collecting post-training scores...\")\n",
"for wf, wf_scores in post_scores.items():\n",
"    for ep in range(N_EVAL):\n",
"        score = run_episode_with_model(wf)\n",
"        wf_scores.append(score)\n",
"        # Carriage return keeps the per-episode progress on a single line.\n",
"        print(f\"  Workflow {wf} ep {ep+1}/{N_EVAL}: score={score:.4f}\", end=\"\\r\")\n",
"    print(f\"  Workflow {wf}: mean={np.mean(post_scores[wf]):.4f}\")\n",
"\n",
"# Pool every episode across workflows and compare against the baseline mean.\n",
"post_mean = np.mean(sum(post_scores.values(), []))\n",
"print(f\"\\nOverall post-training mean: {post_mean:.4f}\")\n",
"print(f\"Improvement: {post_mean - baseline_mean:+.4f}\")"
|
| 539 |
]
|
| 540 |
},
|
| 541 |
{
|
| 542 |
"cell_type": "markdown",
|
| 543 |
+
"id": "sec10",
|
| 544 |
"metadata": {},
|
| 545 |
+
"source": [
|
| 546 |
+
"## 10. Plot Before / After"
|
| 547 |
+
]
|
| 548 |
},
|
| 549 |
{
|
| 550 |
"cell_type": "code",
|
| 551 |
"execution_count": null,
|
| 552 |
+
"id": "plot",
|
| 553 |
"metadata": {},
|
| 554 |
"outputs": [],
|
| 555 |
"source": [
|
| 556 |
"import matplotlib.pyplot as plt\n",
|
| 557 |
"import matplotlib.gridspec as gridspec\n",
|
| 558 |
"\n",
|
| 559 |
+
"# Dark-themed figure: one score panel per workflow on top, pooled histogram below.\n",
"fig = plt.figure(figsize=(14, 8), facecolor=\"#0f172a\")\n",
"fig.suptitle(\n",
"    \"OrgOS: Before vs After GRPO Training\",\n",
"    fontsize=15, color=\"white\", fontweight=\"bold\", y=0.98,\n",
")\n",
"\n",
"gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.45, wspace=0.35)\n",
"\n",
"COLORS = {\"before\": \"#f87171\", \"after\": \"#34d399\", \"bg\": \"#1e293b\", \"grid\": \"#334155\"}\n",
"WF_LABELS = {\n",
"    \"A\": \"Workflow A\\nCustomer Bug Fix\",\n",
"    \"B\": \"Workflow B\\nEmployee Onboarding\",\n",
"    \"C\": \"Workflow C\\nChurn Risk Alert\",\n",
"}\n",
"\n",
"# Top row: per-episode scores for each workflow, with dashed lines at the means.\n",
"for col, (wf, wf_label) in enumerate(WF_LABELS.items()):\n",
"    ax = fig.add_subplot(gs[0, col])\n",
"    ax.set_facecolor(COLORS[\"bg\"])\n",
"    ax.grid(color=COLORS[\"grid\"], linewidth=0.5, alpha=0.7)\n",
"\n",
"    before, after = baseline_scores[wf], post_scores[wf]\n",
"    before_mean, after_mean = np.mean(before), np.mean(after)\n",
"    delta = after_mean - before_mean\n",
"\n",
"    curve_style = dict(linewidth=1.5, alpha=0.8)\n",
"    mean_style = dict(linestyle=\"--\", linewidth=1, alpha=0.5)\n",
"    ax.plot(before, color=COLORS[\"before\"], label=\"Before GRPO\", **curve_style)\n",
"    ax.plot(after, color=COLORS[\"after\"], label=\"After GRPO\", **curve_style)\n",
"    ax.axhline(before_mean, color=COLORS[\"before\"], **mean_style)\n",
"    ax.axhline(after_mean, color=COLORS[\"after\"], **mean_style)\n",
"\n",
"    ax.set_title(wf_label + f\"\\n(Δ = {delta:+.4f})\", color=\"white\", fontsize=9)\n",
"    ax.set_xlabel(\"Episode\", color=\"#94a3b8\", fontsize=8)\n",
"    ax.set_ylabel(\"Final Score\", color=\"#94a3b8\", fontsize=8)\n",
"    ax.tick_params(colors=\"#64748b\", labelsize=7)\n",
"    ax.set_ylim(0, 1)\n",
"    ax.legend(\n",
"        fontsize=7, facecolor=\"#1e293b\", labelcolor=\"white\",\n",
"        edgecolor=\"#475569\", framealpha=0.8,\n",
"    )\n",
"    for spine in ax.spines.values():\n",
"        spine.set_edgecolor(\"#334155\")\n",
"\n",
"# Bottom row: score histogram pooled over all workflows.\n",
"ax_hist = fig.add_subplot(gs[1, :])\n",
"ax_hist.set_facecolor(COLORS[\"bg\"])\n",
"ax_hist.grid(color=COLORS[\"grid\"], linewidth=0.5, alpha=0.5, axis=\"x\")\n",
"\n",
"all_before = [s for scores in baseline_scores.values() for s in scores]\n",
"all_after  = [s for scores in post_scores.values() for s in scores]\n",
"bins = np.linspace(0, 1, 25)\n",
"\n",
"hist_style = dict(bins=bins, alpha=0.6, edgecolor=\"none\")\n",
"ax_hist.hist(all_before, color=COLORS[\"before\"],\n",
"             label=f\"Before GRPO (mean={np.mean(all_before):.4f})\", **hist_style)\n",
"ax_hist.hist(all_after, color=COLORS[\"after\"],\n",
"             label=f\"After GRPO  (mean={np.mean(all_after):.4f})\", **hist_style)\n",
"ax_hist.axvline(np.mean(all_before), color=COLORS[\"before\"], linestyle=\"--\", linewidth=1.5)\n",
"ax_hist.axvline(np.mean(all_after), color=COLORS[\"after\"], linestyle=\"--\", linewidth=1.5)\n",
"\n",
"ax_hist.set_title(\"Score Distribution Across All Workflows\", color=\"white\", fontsize=10)\n",
"ax_hist.set_xlabel(\"Final Score\", color=\"#94a3b8\", fontsize=9)\n",
"ax_hist.set_ylabel(\"Count\", color=\"#94a3b8\", fontsize=9)\n",
"ax_hist.tick_params(colors=\"#64748b\", labelsize=8)\n",
"ax_hist.legend(\n",
"    fontsize=9, facecolor=\"#1e293b\", labelcolor=\"white\",\n",
"    edgecolor=\"#475569\", framealpha=0.9,\n",
")\n",
"for spine in ax_hist.spines.values():\n",
"    spine.set_edgecolor(\"#334155\")\n",
"\n",
"# Save with the figure's own background colour before showing it.\n",
"plt.savefig(\n",
"    \"before_after_curves.png\", dpi=150, bbox_inches=\"tight\",\n",
"    facecolor=\"#0f172a\", edgecolor=\"none\",\n",
")\n",
"plt.show()\n",
"print(\"Saved: before_after_curves.png\")"
|
| 624 |
]
|
| 625 |
},
|
| 626 |
{
|
| 627 |
"cell_type": "markdown",
|
| 628 |
+
"id": "sec11",
|
| 629 |
"metadata": {},
|
| 630 |
+
"source": [
|
| 631 |
+
"## 11. Save LoRA Adapter"
|
| 632 |
+
]
|
| 633 |
},
|
| 634 |
{
|
| 635 |
"cell_type": "code",
|
|
|
|
| 638 |
"metadata": {},
|
| 639 |
"outputs": [],
|
| 640 |
"source": [
|
| 641 |
+
"# Write the model and tokenizer side by side into the same local directory.\n",
"for artifact in (model, tokenizer):\n",
"    artifact.save_pretrained(\"orgos_lora_adapter\")\n",
"print(\"LoRA adapter saved to ./orgos_lora_adapter\")\n",
"\n",
"# Optional: push to the HuggingFace Hub.\n",
"# Do not paste a token literal into the notebook; `login()` prompts for it.\n",
"# from huggingface_hub import login\n",
"# login()\n",
"# model.push_to_hub(\"YOUR_USERNAME/orgos-qwen25-3b-grpo\")\n",
"# tokenizer.push_to_hub(\"YOUR_USERNAME/orgos-qwen25-3b-grpo\")"
|
|
|
|
| 650 |
]
|
| 651 |
+
}
|
| 652 |
+
],
|
| 653 |
+
"metadata": {
|
| 654 |
+
"kernelspec": {
|
| 655 |
+
"display_name": "Python 3",
|
| 656 |
+
"language": "python",
|
| 657 |
+
"name": "python3"
|
| 658 |
},
|
| 659 |
+
"language_info": {
|
| 660 |
+
"name": "python",
|
| 661 |
+
"version": "3.10.0"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 662 |
}
|
| 663 |
+
},
|
| 664 |
+
"nbformat": 4,
|
| 665 |
+
"nbformat_minor": 5
|
| 666 |
}
|