muskan singh committed on
Commit
9e29238
·
1 Parent(s): a35bcd0

training notebook

Files changed (1)
  1. training/grpo_orgos.ipynb +452 -336
training/grpo_orgos.ipynb CHANGED
@@ -1,54 +1,39 @@
  {
- "nbformat": 4,
- "nbformat_minor": 5,
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "name": "python",
- "version": "3.10.0"
- },
- "colab": {
- "gpuType": "T4",
- "provenance": []
- },
- "accelerator": "GPU"
- },
  "cells": [
  {
  "cell_type": "markdown",
  "id": "title",
  "metadata": {},
  "source": [
- "# OrgOS GRPO Training Notebook\n",
+ "# OrgOS GRPO Training\n",
  "\n",
  "**Environment:** OrgOS — Multi-App Enterprise RL Environment \n",
  "**Model:** `Qwen/Qwen2.5-3B-Instruct` (4-bit LoRA via Unsloth) \n",
  "**Algorithm:** GRPO (Group Relative Policy Optimization) via HuggingFace TRL \n",
- "**Hardware:** Colab T4 (free tier compatible) \n",
- "\n",
- "## What this notebook does\n",
- "1. Installs dependencies (Unsloth + TRL)\n",
- "2. Loads Qwen2.5-3B-Instruct with 4-bit LoRA\n",
- "3. Collects **baseline rollouts** (untrained model) on Workflows A & C\n",
- "4. Fine-tunes with **GRPOTrainer** using OrgOS dense rewards\n",
- "5. Collects **post-training rollouts** and computes score improvement\n",
- "6. Plots the **before/after reward curve** for the demo\n",
- "\n",
- "**Key training signal:** The schema drift mechanic creates a sharp signal gap —\n",
- "an untrained model uses stale canonical field names (−0.20 per step),\n",
- "while a GRPO-trained model learns to read `schema_hints` first (+reward).\n",
- "This produces a clear, visually compelling before/after improvement."
+ "**Target hardware:** HuggingFace compute (A10G / A100) \n",
+ "\n",
+ "## How this works\n",
+ "\n",
+ "GRPO is an **online** RL algorithm:\n",
+ "1. Each training step takes a batch of **prompts** (observations from the env)\n",
+ "2. The model generates **G candidate actions** per prompt (the group)\n",
+ "3. Each action is sent to the **live OrgOS env** to get a real reward\n",
+ "4. GRPO computes relative advantages within the group (which action did better than average?)\n",
+ "5. The model is updated to favour higher-reward actions\n",
+ "\n",
+ "**Key training signal:** Schema drift creates a sharp reward gap.\n",
+ "Using a stale field name (e.g. `priority` when schema says `severity`) → **−0.20**. \n",
+ "Using the correct drifted name → **+0.10** adaptation bonus. \n",
+ "The model learns to read `schema_hints` before constructing action args."
  ]
  },
  {
  "cell_type": "markdown",
  "id": "sec1",
  "metadata": {},
- "source": ["## 1. Install Dependencies"]
+ "source": [
+ "## 1. Install Dependencies"
+ ]
  },
  {
  "cell_type": "code",
@@ -57,67 +42,46 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "# Install Unsloth (optimised 4-bit LLM training) + TRL (GRPO)\n",
- "!pip install -q unsloth[colab-new] trl>=0.9.0 peft accelerate bitsandbytes\n",
- "!pip install -q fastapi uvicorn httpx openai pydantic\n",
- "!pip install -q matplotlib numpy\n",
- "\n",
- "# Clone / mount the OrgOS repo\n",
- "import os\n",
- "if not os.path.exists('/content/openEnv'):\n",
- " !git clone https://huggingface.co/spaces/YOUR_HF_USERNAME/orgos-openenv /content/openEnv\n",
- " # Alternatively: upload the repo zip and unzip it here\n",
- "\n",
- "os.chdir('/content/openEnv')\n",
- "print('Working directory:', os.getcwd())"
+ "!pip install -q \"unsloth[huggingface]\" \"trl>=0.12.0\" peft accelerate bitsandbytes\n",
+ "!pip install -q fastapi uvicorn httpx openai pydantic python-dotenv\n",
+ "!pip install -q matplotlib numpy datasets"
  ]
  },
  {
  "cell_type": "markdown",
  "id": "sec2",
  "metadata": {},
- "source": ["## 2. Load Model with Unsloth 4-bit LoRA"]
+ "source": [
+ "## 2. Clone the OrgOS Repo"
+ ]
  },
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "load_model",
+ "id": "clone_repo",
  "metadata": {},
  "outputs": [],
  "source": [
- "from unsloth import FastLanguageModel\n",
- "import torch\n",
+ "import os\n",
  "\n",
- "MAX_SEQ_LEN = 2048\n",
- "MODEL_NAME = 'Qwen/Qwen2.5-3B-Instruct'\n",
+ "REPO_URL = \"https://huggingface.co/spaces/tanvibisht/orgos-openenv\"\n",
+ "REPO_DIR = \"/home/user/orgos\"\n",
  "\n",
- "model, tokenizer = FastLanguageModel.from_pretrained(\n",
- " model_name = MODEL_NAME,\n",
- " max_seq_length = MAX_SEQ_LEN,\n",
- " dtype = None, # auto-detect\n",
- " load_in_4bit = True,\n",
- ")\n",
+ "if not os.path.exists(REPO_DIR):\n",
+ " !git clone {REPO_URL} {REPO_DIR}\n",
  "\n",
- "# Add LoRA adapters\n",
- "model = FastLanguageModel.get_peft_model(\n",
- " model,\n",
- " r = 16,\n",
- " target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj',\n",
- " 'gate_proj', 'up_proj', 'down_proj'],\n",
- " lora_alpha = 16,\n",
- " lora_dropout = 0,\n",
- " bias = 'none',\n",
- " use_gradient_checkpointing = 'unsloth',\n",
- " random_state = 42,\n",
- ")\n",
- "print(f'Model loaded — trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}')"
+ "os.chdir(REPO_DIR)\n",
+ "print(\"Working directory:\", os.getcwd())\n",
+ "!ls"
  ]
  },
  {
  "cell_type": "markdown",
  "id": "sec3",
  "metadata": {},
- "source": ["## 3. Start the OrgOS Environment Server (subprocess)"]
+ "source": [
+ "## 3. Start the OrgOS Environment Server"
+ ]
  },
  {
  "cell_type": "code",
@@ -129,203 +93,365 @@
  "import subprocess, time, httpx\n",
  "\n",
  "server_proc = subprocess.Popen(\n",
- " ['python', '-m', 'uvicorn', 'server.app:app', '--host', '0.0.0.0', '--port', '8000'],\n",
- " stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL\n",
+ " [\"python\", \"-m\", \"uvicorn\", \"server.app:app\", \"--host\", \"0.0.0.0\", \"--port\", \"8000\"],\n",
+ " stdout=subprocess.DEVNULL,\n",
+ " stderr=subprocess.DEVNULL,\n",
  ")\n",
- "time.sleep(3)\n",
+ "time.sleep(4)\n",
  "\n",
- "health = httpx.get('http://localhost:8000/health').json()\n",
- "assert health['status'] == 'healthy', f'Server not healthy: {health}'\n",
- "print('OrgOS server running — health:', health)"
+ "health = httpx.get(\"http://localhost:8000/health\").json()\n",
+ "assert health[\"status\"] == \"healthy\", f\"Server not healthy: {health}\"\n",
+ "print(\"OrgOS server running:\", health)"
  ]
  },
  {
  "cell_type": "markdown",
  "id": "sec4",
  "metadata": {},
- "source": ["## 4. Rollout Harness (collect trajectories)"]
+ "source": [
+ "## 4. Load Model with Unsloth 4-bit LoRA"
+ ]
  },
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "rollout_harness",
+ "id": "load_model",
  "metadata": {},
  "outputs": [],
  "source": [
- "import json, re, sys\n",
- "from typing import List, Dict, Tuple\n",
+ "from unsloth import FastLanguageModel\n",
+ "import torch\n",
  "\n",
- "SYSTEM_PROMPT = open('inference.py').read().split('SYSTEM_PROMPT = \\\"\\\"\\\"')[1].split('\\\"\\\"\\\"')[0]\n",
+ "MAX_SEQ_LEN = 2048\n",
+ "MODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\"\n",
  "\n",
- "def obs_to_text(obs: dict) -> str:\n",
- " \"\"\"Convert observation dict to text for the model.\"\"\"\n",
- " hints = obs.get('schema_hints', {})\n",
- " pending = obs.get('pending_steps', [])\n",
- " return (\n",
- " f\"current_score: {obs['current_score']}\\n\"\n",
- " f\"step_count: {obs['step_count']}\\n\"\n",
- " f\"workflow_id: {obs['workflow_id']}\\n\\n\"\n",
- " f\"=== WORKFLOW GOAL ===\\n{obs['workflow_goal']}\\n\\n\"\n",
- " f\"=== PENDING STEPS ===\\n\" + ('\\n'.join(f'- {s}' for s in pending) or '(done!)') + \"\\n\\n\"\n",
- " f\"=== SCHEMA HINTS ===\\n{json.dumps(hints, indent=2)}\\n\\n\"\n",
- " f\"=== ACTIVE RULES ===\\n{json.dumps(obs.get('active_rules', {}), indent=2)}\\n\\n\"\n",
- " f\"=== LAST MESSAGE ===\\n{obs['message']}\\n\"\n",
- " )\n",
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
+ " model_name = MODEL_NAME,\n",
+ " max_seq_length = MAX_SEQ_LEN,\n",
+ " dtype = None,\n",
+ " load_in_4bit = True,\n",
+ ")\n",
  "\n",
- "def generate_action(prompt_messages: List[Dict], max_tokens=256) -> str:\n",
- " \"\"\"Run the model to produce an action JSON string.\"\"\"\n",
- " from transformers import GenerationConfig\n",
- " # Format as chat\n",
- " text = tokenizer.apply_chat_template(\n",
- " prompt_messages, tokenize=False, add_generation_prompt=True\n",
- " )\n",
- " inputs = tokenizer(text, return_tensors='pt').to(model.device)\n",
- " with torch.no_grad():\n",
- " out = model.generate(\n",
- " **inputs,\n",
- " max_new_tokens = max_tokens,\n",
- " temperature = 0.7,\n",
- " do_sample = True,\n",
- " pad_token_id = tokenizer.eos_token_id,\n",
- " )\n",
- " decoded = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)\n",
- " return decoded.strip()\n",
- "\n",
- "def run_episode(workflow_id: str, max_steps: int = 15) -> Tuple[List[dict], float]:\n",
- " \"\"\"\n",
- " Run one episode. Returns (trajectory, final_score).\n",
- " trajectory = list of {'messages': [...], 'reward': float}\n",
- " \"\"\"\n",
- " resp = httpx.post('http://localhost:8000/reset', json={'workflow_id': workflow_id})\n",
- " obs = resp.json()['observation']\n",
- " history = []\n",
- " trajectory = []\n",
- " cumulative_reward = 0.0\n",
+ "model = FastLanguageModel.get_peft_model(\n",
+ " model,\n",
+ " r = 16,\n",
+ " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
+ " \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
+ " lora_alpha = 16,\n",
+ " lora_dropout = 0,\n",
+ " bias = \"none\",\n",
+ " use_gradient_checkpointing = \"unsloth\",\n",
+ " random_state = 42,\n",
+ ")\n",
+ "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
+ "print(f\"Model loaded — trainable params: {trainable:,}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "sec5",
+ "metadata": {},
+ "source": [
+ "## 5. Prompt Dataset\n",
  "\n",
- " for step_i in range(max_steps):\n",
- " if obs['done']:\n",
- " break\n",
+ "We collect **first-turn observations** from fresh episode resets as our prompt dataset.\n",
+ "These are the most important turns — they contain `schema_hints`, `active_rules`, and the\n",
+ "full workflow goal. The model must learn to read schema hints and produce a correct first action.\n",
  "\n",
- " obs_text = obs_to_text(obs)\n",
- " history.append({'role': 'user', 'content': obs_text})\n",
+ "During GRPO training, the reward function will reset the env and evaluate each generated action live."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "build_prompts",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import json\n",
+ "from datasets import Dataset\n",
  "\n",
- " msgs = [{'role': 'system', 'content': SYSTEM_PROMPT}] + history[-10:]\n",
- " action_str = generate_action(msgs)\n",
+ "SYSTEM_PROMPT = \"\"\"\\\n",
+ "You are OrgOS Agent — an enterprise workflow automation agent.\n",
+ "You operate across four SaaS applications: Jira, Zendesk, Salesforce, and Workday.\n",
+ "\n",
+ "Each turn you receive a JSON observation with:\n",
+ " - workflow_goal : the task you must complete\n",
+ " - pending_steps : remaining steps in the workflow\n",
+ " - app_states : current state of each app\n",
+ " - schema_hints : field renames in effect this episode (e.g. {\"jira.priority\": \"severity\"})\n",
+ " - active_rules : current SLA / approval thresholds\n",
+ " - message : feedback from the last action\n",
+ " - current_score : your cumulative score (0.001-0.999)\n",
+ "\n",
+ "Respond ONLY with a valid JSON object — no markdown, no explanation.\n",
+ "\n",
+ "Action format:\n",
+ " {\"app\": \"<app>\", \"operation\": \"<op>\", \"args\": {...}}\n",
+ "\n",
+ "Available apps and key operations:\n",
+ " jira: get_issue, create_issue, update_status, set_priority, assign_owner,\n",
+ " add_label, link_zendesk_ticket, close_issue, list_issues\n",
+ " zendesk: get_ticket, acknowledge_ticket, set_urgency, assign_agent,\n",
+ " escalate_to_jira, resolve_ticket, add_note, list_tickets,\n",
+ " create_agent_profile\n",
+ " salesforce: get_account, list_accounts, update_deal_stage, flag_churn_risk,\n",
+ " assign_account_owner, log_interaction, get_opportunity\n",
+ " workday: get_employee, list_employees, provision_access, log_sla_event,\n",
+ " request_budget_approval, create_onboarding_task, complete_task\n",
+ "\n",
+ "CRITICAL RULES:\n",
+ "1. Read schema_hints FIRST — if \"jira.priority\" -> \"severity\", use \"severity\" not \"priority\" in args.\n",
+ "2. Complete ALL pending_steps in order.\n",
+ "3. Do not repeat a successful action.\n",
+ "4. If an operation fails, read the message carefully and adapt.\n",
+ "5. Use list_* operations to discover record IDs when needed.\n",
+ "6. Stop when pending_steps is empty or done=true.\n",
+ "\"\"\"\n",
  "\n",
- " history.append({'role': 'assistant', 'content': action_str})\n",
  "\n",
- " # Parse action\n",
- " action = None\n",
- " try:\n",
- " action = json.loads(action_str)\n",
- " except:\n",
- " m = re.search(r'\\{.*\\}', action_str, re.DOTALL)\n",
- " if m:\n",
- " try: action = json.loads(m.group())\n",
- " except: pass\n",
+ "def obs_to_text(obs: dict) -> str:\n",
+ " hints = obs.get(\"schema_hints\", {})\n",
+ " pending = obs.get(\"pending_steps\", [])\n",
+ " lines = [\n",
+ " f\"current_score: {obs['current_score']}\",\n",
+ " f\"step_count: {obs['step_count']}\",\n",
+ " f\"workflow_id: {obs['workflow_id']}\",\n",
+ " \"\",\n",
+ " \"=== WORKFLOW GOAL ===\",\n",
+ " obs[\"workflow_goal\"],\n",
+ " \"\",\n",
+ " \"=== PENDING STEPS ===\",\n",
+ " \"\\n\".join(f\" - {s}\" for s in pending) or \" (all steps complete!)\",\n",
+ " \"\",\n",
+ " \"=== SCHEMA HINTS (use these field names) ===\",\n",
+ " json.dumps(hints, indent=2) if hints else \" (no drift — use canonical names)\",\n",
+ " \"\",\n",
+ " \"=== ACTIVE RULES ===\",\n",
+ " json.dumps(obs.get(\"active_rules\", {}), indent=2),\n",
+ " \"\",\n",
+ " \"=== LAST MESSAGE ===\",\n",
+ " obs[\"message\"],\n",
+ " \"\",\n",
+ " \"=== APP STATES ===\",\n",
+ " ]\n",
+ " for app_name, view in obs.get(\"app_states\", {}).items():\n",
+ " lines.append(f\" [{app_name.upper()}]\")\n",
+ " lines.append(f\" {view}\")\n",
+ " lines.append(\"\")\n",
+ " return \"\\n\".join(lines)\n",
+ "\n",
+ "\n",
+ "def build_prompt(obs_text: str) -> str:\n",
+ " \"\"\"Format as a chat prompt with system injected into first user message.\"\"\"\n",
+ " messages = [{\"role\": \"user\", \"content\": SYSTEM_PROMPT + \"\\n\\n---\\n\\n\" + obs_text}]\n",
+ " return tokenizer.apply_chat_template(\n",
+ " messages, tokenize=False, add_generation_prompt=True\n",
+ " )\n",
  "\n",
- " if action is None:\n",
- " cumulative_reward -= 0.05\n",
- " break\n",
  "\n",
- " result = httpx.post('http://localhost:8000/step', json=action).json()\n",
- " obs = result['observation']\n",
- " reward = result['reward']\n",
- " cumulative_reward += reward\n",
+ "# Collect first-turn observations across all 3 workflows, multiple episodes\n",
+ "# Each episode has a different schema version (seed varies) so we get diverse prompts\n",
+ "N_PROMPTS_PER_WORKFLOW = 20\n",
+ "prompt_rows = []\n",
  "\n",
- " # Store step for GRPO\n",
- " trajectory.append({\n",
- " 'messages': msgs + [{'role': 'assistant', 'content': action_str}],\n",
- " 'reward': reward,\n",
+ "print(\"Collecting prompts from env resets...\")\n",
+ "for wf in [\"A\", \"B\", \"C\"]:\n",
+ " for _ in range(N_PROMPTS_PER_WORKFLOW):\n",
+ " result = httpx.post(\"http://localhost:8000/reset\", json={\"workflow_id\": wf}).json()\n",
+ " obs = result[\"observation\"]\n",
+ " obs_text = obs_to_text(obs)\n",
+ " prompt_rows.append({\n",
+ " \"prompt\": build_prompt(obs_text),\n",
+ " \"workflow_id\": wf,\n",
+ " \"obs_text\": obs_text,\n",
  " })\n",
  "\n",
- " if obs['done']:\n",
- " break\n",
- "\n",
- " return trajectory, obs.get('current_score', 0.001)\n",
- "\n",
- "print('Rollout harness ready.')"
+ "prompt_dataset = Dataset.from_list(prompt_rows)\n",
+ "print(f\"Prompt dataset: {len(prompt_dataset)} examples across 3 workflows\")\n",
+ "print(\"Sample prompt (truncated):\\n\", prompt_rows[0][\"prompt\"][:600], \"...\")"
  ]
  },
  {
  "cell_type": "markdown",
- "id": "sec5",
+ "id": "sec6",
  "metadata": {},
- "source": ["## 5. Collect Baseline Rollouts (Pre-Training)"]
+ "source": [
+ "## 6. Reward Function\n",
+ "\n",
+ "Called by GRPOTrainer during training on each batch of generated completions.\n",
+ "For each completion:\n",
+ "1. Parse it as action JSON\n",
+ "2. Reset the env to a fresh episode for the right workflow\n",
+ "3. Send the action via `/step`\n",
+ "4. Return the reward\n",
+ "\n",
+ "This gives the model a live signal from the actual environment."
+ ]
  },
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "baseline_rollouts",
+ "id": "reward_fn",
  "metadata": {},
  "outputs": [],
  "source": [
- "import numpy as np\n",
+ "import re\n",
+ "from typing import List\n",
+ "\n",
+ "ENV_URL = \"http://localhost:8000\"\n",
+ "\n",
+ "\n",
+ "def parse_action(text: str):\n",
+ " \"\"\"Extract JSON action from model output.\"\"\"\n",
+ " text = text.strip()\n",
+ " # Strip markdown code fences if present\n",
+ " text = re.sub(r\"```(?:json)?\\s*\", \"\", text).strip()\n",
+ " try:\n",
+ " return json.loads(text)\n",
+ " except json.JSONDecodeError:\n",
+ " m = re.search(r\"\\{.*\\}\", text, re.DOTALL)\n",
+ " if m:\n",
+ " try:\n",
+ " return json.loads(m.group())\n",
+ " except Exception:\n",
+ " pass\n",
+ " return None\n",
+ "\n",
+ "\n",
+ "def orgos_reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:\n",
+ " \"\"\"\n",
+ " GRPO reward function — called by GRPOTrainer each training step.\n",
  "\n",
- "N_BASELINE = 30 # 30 episodes pre-training (10 per workflow)\n",
+ " For each generated completion:\n",
+ " - Parse as action JSON\n",
+ " - Reset env to a fresh episode (workflow inferred from prompt)\n",
+ " - Step the env with the action\n",
+ " - Return the step reward\n",
  "\n",
- "baseline_scores = {'A': [], 'B': [], 'C': []}\n",
- "all_trajectories = []\n",
+ " Invalid JSON or failed actions return a -0.1 penalty.\n",
+ " \"\"\"\n",
+ " workflow_ids = kwargs.get(\"workflow_id\", [\"A\"] * len(completions))\n",
+ " rewards = []\n",
  "\n",
- "print('Collecting baseline rollouts...')\n",
- "for wf in ['A', 'B', 'C']:\n",
- " for ep in range(N_BASELINE // 3):\n",
- " traj, score = run_episode(wf)\n",
- " baseline_scores[wf].append(score)\n",
- " all_trajectories.extend(traj)\n",
- " print(f' Workflow {wf} ep {ep+1}: score={score:.4f}', end='\\r')\n",
- " print(f' Workflow {wf}: mean={np.mean(baseline_scores[wf]):.4f} ± {np.std(baseline_scores[wf]):.4f}')\n",
+ " for completion, wf_id in zip(completions, workflow_ids):\n",
+ " action = parse_action(completion)\n",
+ "\n",
+ " if action is None:\n",
+ " rewards.append(-0.1)\n",
+ " continue\n",
  "\n",
- "print(f'\\nTotal baseline episodes: {N_BASELINE}')\n",
- "print(f'Total trajectory steps: {len(all_trajectories)}')\n",
- "print(f'Overall baseline mean: {np.mean([s for v in baseline_scores.values() for s in v]):.4f}')"
+ " try:\n",
+ " # Fresh episode for this action evaluation\n",
+ " httpx.post(f\"{ENV_URL}/reset\", json={\"workflow_id\": wf_id}, timeout=10)\n",
+ " result = httpx.post(f\"{ENV_URL}/step\", json=action, timeout=10).json()\n",
+ " rewards.append(float(result[\"reward\"]))\n",
+ " except Exception:\n",
+ " rewards.append(-0.1)\n",
+ "\n",
+ " return rewards\n",
+ "\n",
+ "\n",
+ "print(\"Reward function defined.\")\n",
+ "print(\"Quick sanity check...\")\n",
+ "test_rewards = orgos_reward_fn(\n",
+ " completions = ['{\"app\": \"zendesk\", \"operation\": \"list_tickets\", \"args\": {\"state\": \"new\"}}',\n",
+ " 'this is not valid json'],\n",
+ " prompts = [\"\", \"\"],\n",
+ " workflow_id = [\"A\", \"A\"],\n",
+ ")\n",
+ "print(f\" Valid action reward: {test_rewards[0]:.4f}\")\n",
+ "print(f\" Invalid action reward: {test_rewards[1]:.4f}\")"
  ]
  },
  {
  "cell_type": "markdown",
- "id": "sec6",
+ "id": "sec7",
  "metadata": {},
- "source": ["## 6. Build GRPO Dataset from Trajectories"]
+ "source": [
+ "## 7. Collect Baseline Scores (Pre-Training)"
+ ]
  },
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "build_dataset",
+ "id": "baseline",
  "metadata": {},
  "outputs": [],
  "source": [
- "from datasets import Dataset\n",
+ "import numpy as np\n",
  "\n",
- "def trajectories_to_dataset(trajectories: List[dict]) -> Dataset:\n",
- " \"\"\"\n",
- " Convert trajectory steps into a GRPO-compatible dataset.\n",
- " Each row = one (prompt, completion, reward) triple.\n",
- " \"\"\"\n",
- " rows = []\n",
- " for step in trajectories:\n",
- " messages = step['messages']\n",
- " reward = step['reward']\n",
- " # Separate prompt (all but last assistant turn) from completion\n",
- " prompt_msgs = messages[:-1]\n",
- " completion = messages[-1]['content']\n",
- " prompt_text = tokenizer.apply_chat_template(\n",
- " prompt_msgs, tokenize=False, add_generation_prompt=True\n",
- " )\n",
- " rows.append({'prompt': prompt_text, 'completion': completion, 'reward': reward})\n",
- " return Dataset.from_list(rows)\n",
- "\n",
- "train_dataset = trajectories_to_dataset(all_trajectories)\n",
- "print(f'Training dataset: {len(train_dataset)} examples')\n",
- "print(f'Reward range: [{min(train_dataset[\"reward\"]):.4f}, {max(train_dataset[\"reward\"]):.4f}]')\n",
- "print(f'Mean reward: {np.mean(train_dataset[\"reward\"]):.4f}')\n",
- "train_dataset[0]"
+ "FastLanguageModel.for_inference(model)\n",
+ "\n",
+ "\n",
+ "def run_episode_with_model(workflow_id: str, max_steps: int = 15) -> float:\n",
+ " \"\"\"Run one full episode with the current model. Returns final score.\"\"\"\n",
+ " result = httpx.post(f\"{ENV_URL}/reset\", json={\"workflow_id\": workflow_id}).json()\n",
+ " obs = result[\"observation\"]\n",
+ " history = []\n",
+ "\n",
+ " for _ in range(max_steps):\n",
+ " if obs[\"done\"]:\n",
+ " break\n",
+ "\n",
+ " obs_text = obs_to_text(obs)\n",
+ " history.append({\"role\": \"user\", \"content\": obs_text})\n",
+ "\n",
+ " # Inject system prompt into first user message\n",
+ " messages = list(history)\n",
+ " messages[0] = {\"role\": \"user\", \"content\": SYSTEM_PROMPT + \"\\n\\n---\\n\\n\" + messages[0][\"content\"]}\n",
+ "\n",
+ " text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
+ " inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
+ "\n",
+ " with torch.no_grad():\n",
+ " out = model.generate(\n",
+ " **inputs,\n",
+ " max_new_tokens = 256,\n",
+ " # greedy decoding for deterministic eval\n",
+ " do_sample = False,\n",
+ " pad_token_id = tokenizer.eos_token_id,\n",
+ " )\n",
+ " action_str = tokenizer.decode(\n",
+ " out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True\n",
+ " ).strip()\n",
+ "\n",
+ " history.append({\"role\": \"assistant\", \"content\": action_str})\n",
+ "\n",
+ " action = parse_action(action_str)\n",
+ " if action is None:\n",
+ " break\n",
+ "\n",
+ " result = httpx.post(f\"{ENV_URL}/step\", json=action).json()\n",
+ " obs = result[\"observation\"]\n",
+ " if obs[\"done\"]:\n",
+ " break\n",
+ "\n",
+ " return obs.get(\"current_score\", 0.001)\n",
+ "\n",
+ "\n",
+ "N_EVAL = 10 # episodes per workflow for evaluation\n",
+ "baseline_scores = {wf: [] for wf in [\"A\", \"B\", \"C\"]}\n",
+ "\n",
+ "print(\"Collecting pre-training baseline scores...\")\n",
+ "for wf in [\"A\", \"B\", \"C\"]:\n",
+ " for ep in range(N_EVAL):\n",
+ " score = run_episode_with_model(wf)\n",
+ " baseline_scores[wf].append(score)\n",
+ " print(f\" Workflow {wf} ep {ep+1}/{N_EVAL}: score={score:.4f}\", end=\"\\r\")\n",
+ " print(f\" Workflow {wf}: mean={np.mean(baseline_scores[wf]):.4f}\")\n",
+ "\n",
+ "baseline_mean = np.mean([s for v in baseline_scores.values() for s in v])\n",
+ "print(f\"\\nOverall baseline mean: {baseline_mean:.4f}\")"
  ]
  },
  {
  "cell_type": "markdown",
- "id": "sec7",
+ "id": "sec8",
  "metadata": {},
- "source": ["## 7. GRPO Training"]
+ "source": [
+ "## 8. GRPO Training"
+ ]
  },
  {
  "cell_type": "code",
@@ -336,164 +462,174 @@
  "source": [
  "from trl import GRPOConfig, GRPOTrainer\n",
  "\n",
- "# Reward function for GRPO: directly use the env's per-step reward\n",
- "def reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:\n",
- " \"\"\"GRPO reward function — called on each group of completions.\"\"\"\n",
- " # In GRPO the rewards come from rollouts; we pre-compute them above.\n",
- " # This function returns the rewards already stored in the dataset.\n",
- " return kwargs.get('reward', [0.0] * len(completions))\n",
+ "# Switch back to training mode\n",
+ "model.train()\n",
  "\n",
  "grpo_config = GRPOConfig(\n",
- " output_dir = './orgos_grpo_ckpt',\n",
- " num_train_epochs = 3,\n",
- " per_device_train_batch_size = 2,\n",
- " gradient_accumulation_steps = 4,\n",
- " learning_rate = 5e-5,\n",
- " warmup_steps = 10,\n",
- " logging_steps = 5,\n",
- " save_steps = 50,\n",
- " fp16 = not torch.cuda.is_bf16_supported(),\n",
- " bf16 = torch.cuda.is_bf16_supported(),\n",
- " max_grad_norm = 1.0,\n",
+ " output_dir = \"./orgos_grpo_ckpt\",\n",
+ " num_train_epochs = 3,\n",
+ " per_device_train_batch_size = 4,\n",
+ " gradient_accumulation_steps = 2,\n",
+ " learning_rate = 5e-5,\n",
+ " warmup_steps = 10,\n",
+ " logging_steps = 5,\n",
+ " save_steps = 100,\n",
+ " bf16 = torch.cuda.is_bf16_supported(),\n",
+ " fp16 = not torch.cuda.is_bf16_supported(),\n",
+ " max_grad_norm = 1.0,\n",
  " # GRPO-specific\n",
- " num_generations = 4, # group size G\n",
- " max_new_tokens = 256,\n",
- " temperature = 0.7,\n",
- " beta = 0.04, # KL penalty\n",
- " report_to = 'none',\n",
- " seed = 42,\n",
+ " num_generations = 4, # G: candidate actions per prompt\n",
+ " max_completion_length = 256,\n",
+ " temperature = 0.8, # exploration during training\n",
+ " beta = 0.04, # KL penalty coefficient\n",
+ " report_to = \"none\",\n",
+ " seed = 42,\n",
  ")\n",
  "\n",
  "trainer = GRPOTrainer(\n",
  " model = model,\n",
  " args = grpo_config,\n",
- " reward_funcs = reward_fn,\n",
- " train_dataset = train_dataset,\n",
- " tokenizer = tokenizer,\n",
+ " reward_funcs = orgos_reward_fn,\n",
+ " train_dataset = prompt_dataset,\n",
+ " processing_class = tokenizer,\n",
  ")\n",
  "\n",
- "print('Starting GRPO training...')\n",
+ "print(\"Starting GRPO training...\")\n",
+ "print(f\" Prompts: {len(prompt_dataset)}\")\n",
+ "print(f\" Generations per prompt (G): {grpo_config.num_generations}\")\n",
+ "print(f\" Epochs: {grpo_config.num_train_epochs}\")\n",
+ "print(f\" Total env calls per epoch: ~{len(prompt_dataset) * grpo_config.num_generations}\")\n",
+ "print()\n",
+ "\n",
  "train_result = trainer.train()\n",
- "print('Training complete!')\n",
+ "print(\"\\nTraining complete!\")\n",
  "print(train_result.metrics)"
  ]
  },
  {
  "cell_type": "markdown",
- "id": "sec8",
+ "id": "sec9",
  "metadata": {},
- "source": ["## 8. Collect Post-Training Rollouts"]
+ "source": [
+ "## 9. Collect Post-Training Scores"
+ ]
  },
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "posttraining_rollouts",
+ "id": "post_training",
  "metadata": {},
  "outputs": [],
  "source": [
- "# Switch model to inference mode\n",
  "FastLanguageModel.for_inference(model)\n",
  "\n",
- "N_EVAL = 30\n",
- "post_scores = {'A': [], 'B': [], 'C': []}\n",
+ "post_scores = {wf: [] for wf in [\"A\", \"B\", \"C\"]}\n",
  "\n",
- "print('Collecting post-training rollouts...')\n",
- "for wf in ['A', 'B', 'C']:\n",
- " for ep in range(N_EVAL // 3):\n",
- " _, score = run_episode(wf)\n",
+ "print(\"Collecting post-training scores...\")\n",
+ "for wf in [\"A\", \"B\", \"C\"]:\n",
+ " for ep in range(N_EVAL):\n",
+ " score = run_episode_with_model(wf)\n",
  " post_scores[wf].append(score)\n",
- " print(f' Workflow {wf} ep {ep+1}: score={score:.4f}', end='\\r')\n",
- " print(f' Workflow {wf}: mean={np.mean(post_scores[wf]):.4f} ± {np.std(post_scores[wf]):.4f}')\n",
+ " print(f\" Workflow {wf} ep {ep+1}/{N_EVAL}: score={score:.4f}\", end=\"\\r\")\n",
+ " print(f\" Workflow {wf}: mean={np.mean(post_scores[wf]):.4f}\")\n",
  "\n",
- "print(f'\\nOverall post-training mean: {np.mean([s for v in post_scores.values() for s in v]):.4f}')"
+ "post_mean = np.mean([s for v in post_scores.values() for s in v])\n",
+ "print(f\"\\nOverall post-training mean: {post_mean:.4f}\")\n",
+ "print(f\"Improvement: {post_mean - baseline_mean:+.4f}\")"
  ]
  },
  {
  "cell_type": "markdown",
- "id": "sec9",
+ "id": "sec10",
  "metadata": {},
- "source": ["## 9. Plot Before/After Reward Curves"]
+ "source": [
+ "## 10. Plot Before / After"
+ ]
  },
  {
  "cell_type": "code",
  "execution_count": null,
- "id": "plot_curves",
+ "id": "plot",
  "metadata": {},
  "outputs": [],
  "source": [
  "import matplotlib.pyplot as plt\n",
  "import matplotlib.gridspec as gridspec\n",
  "\n",
- "fig = plt.figure(figsize=(14, 8), facecolor='#0f172a')\n",
- "fig.suptitle('OrgOS: Before vs After GRPO Training', fontsize=15,\n",
- " color='white', fontweight='bold', y=0.98)\n",
+ "fig = plt.figure(figsize=(14, 8), facecolor=\"#0f172a\")\n",
+ "fig.suptitle(\"OrgOS: Before vs After GRPO Training\", fontsize=15,\n",
+ " color=\"white\", fontweight=\"bold\", y=0.98)\n",
  "\n",
  "gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.45, wspace=0.35)\n",
  "\n",
- "COLORS = {'before': '#f87171', 'after': '#34d399', 'bg': '#1e293b', 'grid': '#334155'}\n",
- "WF_LABELS = {'A': 'Workflow A\\nCustomer Bug Fix',\n",
- " 'B': 'Workflow B\\nEmployee Onboarding',\n",
- " 'C': 'Workflow C\\nChurn Risk Alert'}\n",
+ "COLORS = {\"before\": \"#f87171\", \"after\": \"#34d399\", \"bg\": \"#1e293b\", \"grid\": \"#334155\"}\n",
+ "WF_LABELS = {\n",
+ " \"A\": \"Workflow A\\nCustomer Bug Fix\",\n",
+ " \"B\": \"Workflow B\\nEmployee Onboarding\",\n",
+ " \"C\": \"Workflow C\\nChurn Risk Alert\",\n",
+ "}\n",
  "\n",
- "for col, wf in enumerate(['A', 'B', 'C']):\n",
+ "for col, wf in enumerate([\"A\", \"B\", \"C\"]):\n",
  " ax = fig.add_subplot(gs[0, col])\n",
- " ax.set_facecolor(COLORS['bg'])\n",
- " ax.grid(color=COLORS['grid'], linewidth=0.5, alpha=0.7)\n",
+ " ax.set_facecolor(COLORS[\"bg\"])\n",
+ " ax.grid(color=COLORS[\"grid\"], linewidth=0.5, alpha=0.7)\n",
  "\n",
  " before = baseline_scores[wf]\n",
  " after = post_scores[wf]\n",
+ " delta = np.mean(after) - np.mean(before)\n",
  "\n",
- " ax.plot(before, color=COLORS['before'], linewidth=1.5, alpha=0.8, label='Before GRPO')\n",
- " ax.plot(after, color=COLORS['after'], linewidth=1.5, alpha=0.8, label='After GRPO')\n",
- "\n",
- " ax.axhline(np.mean(before), color=COLORS['before'], linestyle='--', linewidth=1, alpha=0.5)\n",
- " ax.axhline(np.mean(after), color=COLORS['after'], linestyle='--', linewidth=1, alpha=0.5)\n",
+ " ax.plot(before, color=COLORS[\"before\"], linewidth=1.5, alpha=0.8, label=\"Before GRPO\")\n",
+ " ax.plot(after, color=COLORS[\"after\"], linewidth=1.5, alpha=0.8, label=\"After GRPO\")\n",
+ " ax.axhline(np.mean(before), color=COLORS[\"before\"], linestyle=\"--\", linewidth=1, alpha=0.5)\n",
+ " ax.axhline(np.mean(after), color=COLORS[\"after\"], linestyle=\"--\", linewidth=1, alpha=0.5)\n",
  "\n",
- " delta = np.mean(after) - np.mean(before)\n",
- " ax.set_title(WF_LABELS[wf] + f'\\n(Δ = {delta:+.4f})', color='white', fontsize=9)\n",
- " ax.set_xlabel('Episode', color='#94a3b8', fontsize=8)\n",
- " ax.set_ylabel('Final Score', color='#94a3b8', fontsize=8)\n",
- " ax.tick_params(colors='#64748b', labelsize=7)\n",
+ " ax.set_title(WF_LABELS[wf] + f\"\\n(Δ = {delta:+.4f})\", color=\"white\", fontsize=9)\n",
+ " ax.set_xlabel(\"Episode\", color=\"#94a3b8\", fontsize=8)\n",
+ " ax.set_ylabel(\"Final Score\", color=\"#94a3b8\", fontsize=8)\n",
+ " ax.tick_params(colors=\"#64748b\", labelsize=7)\n",
  " ax.set_ylim(0, 1)\n",
- " ax.legend(fontsize=7, facecolor='#1e293b', labelcolor='white',\n",
- " edgecolor='#475569', framealpha=0.8)\n",
+ " ax.legend(fontsize=7, facecolor=\"#1e293b\", labelcolor=\"white\",\n",
+ " edgecolor=\"#475569\", framealpha=0.8)\n",
  " for spine in ax.spines.values():\n",
- " spine.set_edgecolor('#334155')\n",
+ " spine.set_edgecolor(\"#334155\")\n",
  "\n",
- "# Bottom row: combined histogram\n",
  "ax_hist = fig.add_subplot(gs[1, :])\n",
- "ax_hist.set_facecolor(COLORS['bg'])\n",
- "ax_hist.grid(color=COLORS['grid'], linewidth=0.5, alpha=0.5, axis='x')\n",
+ "ax_hist.set_facecolor(COLORS[\"bg\"])\n",
+ "ax_hist.grid(color=COLORS[\"grid\"], linewidth=0.5, alpha=0.5, axis=\"x\")\n",
  "\n",
  "all_before = [s for v in baseline_scores.values() for s in v]\n",
  "all_after = [s for v in post_scores.values() for s in v]\n",
- "\n",
  "bins = np.linspace(0, 1, 25)\n",
- "ax_hist.hist(all_before, bins=bins, color=COLORS['before'], alpha=0.6, label=f'Before GRPO (mean={np.mean(all_before):.4f})', edgecolor='none')\n",
- "ax_hist.hist(all_after, bins=bins, color=COLORS['after'], alpha=0.6, label=f'After GRPO (mean={np.mean(all_after):.4f})', edgecolor='none')\n",
- "ax_hist.axvline(np.mean(all_before), color=COLORS['before'], linestyle='--', linewidth=1.5)\n",
- "ax_hist.axvline(np.mean(all_after), color=COLORS['after'], linestyle='--', linewidth=1.5)\n",
- "\n",
- "ax_hist.set_title('Score Distribution Across All Workflows', color='white', fontsize=10)\n",
- "ax_hist.set_xlabel('Final Score', color='#94a3b8', fontsize=9)\n",
- "ax_hist.set_ylabel('Count', color='#94a3b8', fontsize=9)\n",
- "ax_hist.tick_params(colors='#64748b', labelsize=8)\n",
- "ax_hist.legend(fontsize=9, facecolor='#1e293b', labelcolor='white',\n",
- " edgecolor='#475569', framealpha=0.9)\n",
+ "\n",
+ "ax_hist.hist(all_before, bins=bins, color=COLORS[\"before\"], alpha=0.6,\n",
+ " label=f\"Before GRPO (mean={np.mean(all_before):.4f})\", edgecolor=\"none\")\n",
+ "ax_hist.hist(all_after, bins=bins, color=COLORS[\"after\"], alpha=0.6,\n",
+ " label=f\"After GRPO (mean={np.mean(all_after):.4f})\", edgecolor=\"none\")\n",
+ "ax_hist.axvline(np.mean(all_before), color=COLORS[\"before\"], linestyle=\"--\", linewidth=1.5)\n",
+ "ax_hist.axvline(np.mean(all_after), color=COLORS[\"after\"], linestyle=\"--\", linewidth=1.5)\n",
+ "\n",
+ "ax_hist.set_title(\"Score Distribution Across All Workflows\", color=\"white\", fontsize=10)\n",
+ "ax_hist.set_xlabel(\"Final Score\", color=\"#94a3b8\", fontsize=9)\n",
+ "ax_hist.set_ylabel(\"Count\", color=\"#94a3b8\", fontsize=9)\n",
+ "ax_hist.tick_params(colors=\"#64748b\", labelsize=8)\n",
+ "ax_hist.legend(fontsize=9, facecolor=\"#1e293b\", labelcolor=\"white\",\n",
+ " edgecolor=\"#475569\", framealpha=0.9)\n",
  "for spine in ax_hist.spines.values():\n",
- " spine.set_edgecolor('#334155')\n",
+ " spine.set_edgecolor(\"#334155\")\n",
  "\n",
- "plt.savefig('before_after_curves.png', dpi=150, bbox_inches='tight',\n",
- " facecolor='#0f172a', edgecolor='none')\n",
+ "plt.savefig(\"before_after_curves.png\", dpi=150, bbox_inches=\"tight\",\n",
+ " facecolor=\"#0f172a\", edgecolor=\"none\")\n",
  "plt.show()\n",
- "print('Saved: before_after_curves.png')"
+ "print(\"Saved: before_after_curves.png\")"
  ]
  },
  {
  "cell_type": "markdown",
- "id": "sec10",
+ "id": "sec11",
  "metadata": {},
- "source": ["## 10. Save LoRA Adapter & Upload to HuggingFace"]
+ "source": [
+ "## 11. Save LoRA Adapter"
+ ]
  },
  {
  "cell_type": "code",
@@ -502,49 +638,29 @@
  "metadata": {},
  "outputs": [],
  "source": [
- "# Save LoRA adapter locally\n",
- "model.save_pretrained('orgos_lora_adapter')\n",
- "tokenizer.save_pretrained('orgos_lora_adapter')\n",
- "print('LoRA adapter saved to ./orgos_lora_adapter')\n",
+ "model.save_pretrained(\"orgos_lora_adapter\")\n",
+ "tokenizer.save_pretrained(\"orgos_lora_adapter\")\n",
+ "print(\"LoRA adapter saved to ./orgos_lora_adapter\")\n",
  "\n",
- "# Optionally push to HuggingFace Hub\n",
+ "# Push to HuggingFace Hub\n",
  "# from huggingface_hub import login\n",
- "# login(token=os.environ['HF_TOKEN'])\n",
- "# model.push_to_hub('YOUR_HF_USERNAME/orgos-qwen25-3b-grpo-lora')\n",
- "# tokenizer.push_to_hub('YOUR_HF_USERNAME/orgos-qwen25-3b-grpo-lora')\n",
- "# print('Pushed to HuggingFace Hub!')"
+ "# login(token=\"YOUR_HF_TOKEN\")\n",
+ "# model.push_to_hub(\"YOUR_USERNAME/orgos-qwen25-3b-grpo\")\n",
+ "# tokenizer.push_to_hub(\"YOUR_USERNAME/orgos-qwen25-3b-grpo\")"
  ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
  },
- {
- "cell_type": "markdown",
- "id": "sec11",
- "metadata": {},
- "source": [
- "## 11. Summary\n",
- "\n",
- "```\n",
- "OrgOS GRPO Training Summary\n",
- "============================\n",
- "Model: Qwen2.5-3B-Instruct + 4-bit LoRA\n",
- "Algorithm: GRPO (Group Relative Policy Optimization)\n",
- "Epochs: 3\n",
- "Episodes: 30 baseline + 30 post-training\n",
- "\n",
- "Key result: The GRPO-trained model learns to:\n",
- " 1. Read schema_hints before constructing action args\n",
- " 2. Use drifted field names (e.g. 'severity' not 'priority')\n",
- " 3. Complete workflow steps in the correct order\n",
- " 4. Avoid RBAC violations by checking role constraints\n",
- "\n",
- "This produces a clear, measurable improvement visible in\n",
- "before_after_curves.png — the core evidence for judging.\n",
- "```\n",
- "\n",
- "**Artefacts produced:**\n",
- "- `before_after_curves.png` — the money chart for the pitch\n",
- "- `orgos_lora_adapter/` — the trained LoRA weights\n",
- "- `baseline_scores.json` — raw score data"
- ]
+ "language_info": {
+ "name": "python",
+ "version": "3.10.0"
  }
- ]
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
  }
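
Not shown in the commit: reloading the saved adapter later. A minimal sketch, assuming the same Unsloth stack as the notebook and the local `orgos_lora_adapter/` directory written by the save cell:

```python
from unsloth import FastLanguageModel

# Unsloth resolves the base model from the adapter's config, then attaches
# the LoRA weights written by model.save_pretrained("orgos_lora_adapter").
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "orgos_lora_adapter",
    max_seq_length = 2048,
    load_in_4bit = True,
)
FastLanguageModel.for_inference(model)  # enable inference-optimized kernels
```
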