muskan singh, Claude Opus 4.7 committed on
Commit 5ebb26b · 1 Parent(s): 03d30a6

fix: stable GRPO notebook — pin TRL<=0.24, multi-step reward, Drive checkpoints every 30 steps


Key changes vs previous run that stopped at step 21:
- Pin trl>=0.18.2,<=0.24.0 BEFORE unsloth install (trl 1.x breaks Unsloth patches)
- Multi-step reward fn (REWARD_STEPS=2) for richer training signal
- NUM_GENERATIONS=2 to halve VRAM pressure from G×reward_steps inference calls
- max_new_tokens=256 in GRPOConfig (works with pinned TRL, fixes 95% clipping)
- Drive checkpoint every 30 steps via callback (survives Colab disconnects)
- MAX_TRAIN_STEPS=150, LR=8e-6
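
A quick post-install guard (illustrative sketch, not part of this commit) can confirm the pin held before Unsloth is imported:

```python
# Illustrative guard — check the TRL pin before `import unsloth`,
# since trl 1.x breaks Unsloth's GRPOTrainer patches.
from importlib.metadata import version

trl_version = version("trl")
major, minor = (int(x) for x in trl_version.split(".")[:2])
assert (major, minor) <= (0, 24), (
    f"trl=={trl_version} is too new; reinstall with "
    "pip install 'trl>=0.18.2,<=0.24.0'"
)
print(f"trl {trl_version} OK")
```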

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

Files changed (1)
  1. training/grpo_orgos.ipynb +530 -492
training/grpo_orgos.ipynb CHANGED
@@ -1,173 +1,203 @@
1
  {
2
  "cells": [
3
  {
4
  "cell_type": "markdown",
5
- "id": "title",
6
  "metadata": {},
7
  "source": [
8
- "# OrgOS GRPO Training\n",
9
- "\n",
10
- "**Environment:** OrgOS — Multi-App Enterprise RL Environment \n",
11
- "**Model:** `Qwen/Qwen2.5-3B-Instruct` (4-bit LoRA via Unsloth) \n",
12
- "**Algorithm:** GRPO (Group Relative Policy Optimization) via HuggingFace TRL \n",
13
- "**Target hardware:** HuggingFace compute (A10G / A100) \n",
14
- "\n",
15
- "## How this works\n",
16
- "\n",
17
- "GRPO is an **online** RL algorithm:\n",
18
- "1. Each training step takes a batch of **prompts** (observations from the env)\n",
19
- "2. The model generates **G candidate actions** per prompt (the group)\n",
20
- "3. Each action is sent to the **live OrgOS env** to get a real reward\n",
21
- "4. GRPO computes relative advantages within the group (which action did better than average?)\n",
22
- "5. Model is updated to favour higher-reward actions\n",
23
- "\n",
24
- "**Key training signal:** Schema drift creates a sharp reward gap. \n",
25
- "Using a stale field name (e.g. `priority` when schema says `severity`) → **−0.20**. \n",
26
- "Using the correct drifted name → **+0.10** adaptation bonus. \n",
27
- "The model learns to read `schema_hints` before constructing action args."
28
  ]
29
  },
30
  {
31
  "cell_type": "markdown",
32
- "id": "sec1",
33
  "metadata": {},
34
- "source": [
35
- "## 1. Install Dependencies"
36
- ]
37
  },
38
  {
39
  "cell_type": "code",
40
- "execution_count": null,
41
- "id": "install",
42
  "metadata": {},
43
  "outputs": [],
44
  "source": [
45
- "!pip install -q \"unsloth[huggingface]\" \"trl>=0.12.0\" peft accelerate bitsandbytes\n",
46
- "!pip install -q fastapi uvicorn httpx openai pydantic python-dotenv\n",
47
- "!pip install -q matplotlib numpy datasets"
 
 
 
48
  ]
49
  },
50
  {
51
- "cell_type": "markdown",
52
- "id": "sec2",
53
  "metadata": {},
 
54
  "source": [
55
- "## 2. Clone the OrgOS Repo"
56
  ]
57
  },
58
  {
59
  "cell_type": "code",
60
- "execution_count": null,
61
- "id": "clone_repo",
62
  "metadata": {},
63
  "outputs": [],
64
  "source": [
65
- "import os\n",
 
66
  "\n",
67
- "REPO_URL = \"https://huggingface.co/spaces/tanvibisht/orgos-openenv\"\n",
68
- "REPO_DIR = \"/home/user/orgos\"\n",
 
69
  "\n",
70
- "if not os.path.exists(REPO_DIR):\n",
71
- " !git clone {REPO_URL} {REPO_DIR}\n",
72
  "\n",
73
- "os.chdir(REPO_DIR)\n",
74
- "print(\"Working directory:\", os.getcwd())\n",
75
- "!ls"
76
  ]
77
  },
78
  {
79
  "cell_type": "markdown",
80
- "id": "sec_logger",
81
  "metadata": {},
82
  "source": [
83
- "## 3. Training Logger\n",
84
- "\n",
85
- "Writes structured logs to `training_log.txt` for submission. \n",
86
- "Format mirrors the OpenEnv inference log spec:\n",
87
- "- `[TRAIN_CONFIG]` — model, algorithm, hyperparameters\n",
88
- "- `[EVAL]` — per-episode score during baseline or post-training eval\n",
89
- "- `[TRAIN_STEP]` — loss, mean reward, KL per training step\n",
90
- "- `[TRAIN_SUMMARY]` — final before/after comparison"
91
  ]
92
  },
93
  {
94
  "cell_type": "code",
95
- "execution_count": null,
96
- "id": "logger",
97
  "metadata": {},
98
  "outputs": [],
99
  "source": [
100
- "import datetime\n",
 
101
  "\n",
102
- "LOG_FILE = \"training_log.txt\"\n",
103
- "\n",
104
- "# Clear any previous log\n",
105
- "with open(LOG_FILE, \"w\") as f:\n",
106
- " f.write(f\"# OrgOS GRPO Training Log\\n\")\n",
107
- " f.write(f\"# Generated: {datetime.datetime.utcnow().isoformat()}Z\\n\\n\")\n",
108
- "\n",
109
- "\n",
110
- "def tlog(line: str) -> None:\n",
111
- " \"\"\"Append one structured log line to training_log.txt and print it.\"\"\"\n",
112
  " print(line, flush=True)\n",
113
- " with open(LOG_FILE, \"a\") as f:\n",
114
- " f.write(line + \"\\n\")\n",
115
- "\n",
116
- "\n",
117
- "print(f\"Logger ready — writing to {LOG_FILE}\")"
118
  ]
119
  },
120
  {
121
  "cell_type": "markdown",
122
- "id": "sec4",
123
  "metadata": {},
124
  "source": [
125
- "## 4. Start the OrgOS Environment Server"
 
 
126
  ]
127
  },
128
  {
129
  "cell_type": "code",
130
- "execution_count": null,
131
- "id": "start_server",
132
  "metadata": {},
133
  "outputs": [],
134
  "source": [
135
- "import subprocess, time, httpx\n",
136
- "\n",
137
- "server_proc = subprocess.Popen(\n",
138
- " [\"python\", \"-m\", \"uvicorn\", \"server.app:app\", \"--host\", \"0.0.0.0\", \"--port\", \"8000\"],\n",
139
  " stdout=subprocess.DEVNULL,\n",
140
  " stderr=subprocess.DEVNULL,\n",
141
  ")\n",
142
- "time.sleep(4)\n",
143
- "\n",
144
- "health = httpx.get(\"http://localhost:8000/health\").json()\n",
145
- "assert health[\"status\"] == \"healthy\", f\"Server not healthy: {health}\"\n",
146
- "tlog(f\"[ENV] status=healthy version={health.get('version', '?')}\")"
147
  ]
148
  },
149
  {
150
  "cell_type": "markdown",
151
- "id": "sec5",
152
  "metadata": {},
153
- "source": [
154
- "## 5. Load Model with Unsloth 4-bit LoRA"
155
- ]
156
  },
157
  {
158
  "cell_type": "code",
159
- "execution_count": null,
160
- "id": "load_model",
161
  "metadata": {},
162
  "outputs": [],
163
  "source": [
164
- "from unsloth import FastLanguageModel\n",
165
- "import torch\n",
166
- "\n",
167
- "MAX_SEQ_LEN = 2048\n",
168
- "MODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\"\n",
169
- "LORA_R = 16\n",
170
- "\n",
171
  "model, tokenizer = FastLanguageModel.from_pretrained(\n",
172
  " model_name = MODEL_NAME,\n",
173
  " max_seq_length = MAX_SEQ_LEN,\n",
@@ -178,548 +208,556 @@
178
  "model = FastLanguageModel.get_peft_model(\n",
179
  " model,\n",
180
  " r = LORA_R,\n",
181
- " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
182
- " \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
183
- " lora_alpha = LORA_R,\n",
184
- " lora_dropout = 0,\n",
185
- " bias = \"none\",\n",
186
- " use_gradient_checkpointing = \"unsloth\",\n",
187
- " random_state = 42,\n",
188
  ")\n",
189
  "\n",
 
 
 
190
  "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
191
- "tlog(f\"[TRAIN_CONFIG] model={MODEL_NAME} lora_r={LORA_R} max_seq_len={MAX_SEQ_LEN} \"\n",
192
- " f\"trainable_params={trainable:,} quantization=4bit\")"
193
  ]
194
  },
195
  {
196
  "cell_type": "markdown",
197
- "id": "sec6",
198
  "metadata": {},
199
  "source": [
200
- "## 6. Prompt Dataset"
 
 
201
  ]
202
  },
203
  {
204
  "cell_type": "code",
205
- "execution_count": null,
206
- "id": "build_prompts",
207
  "metadata": {},
208
  "outputs": [],
209
  "source": [
210
- "import json, re\n",
211
- "import numpy as np\n",
212
- "from typing import List\n",
213
- "from datasets import Dataset\n",
214
  "\n",
215
- "SYSTEM_PROMPT = \"\"\"\\\n",
216
- "You are OrgOS Agent an enterprise workflow automation agent.\n",
217
- "You operate across four SaaS applications: Jira, Zendesk, Salesforce, and Workday.\n",
218
  "\n",
219
- "Each turn you receive a JSON observation with:\n",
220
- " - workflow_goal : the task you must complete\n",
221
- " - pending_steps : remaining steps in the workflow\n",
222
- " - app_states : current state of each app\n",
223
- " - schema_hints : field renames in effect this episode (e.g. {\"jira.priority\": \"severity\"})\n",
224
- " - active_rules : current SLA / approval thresholds\n",
225
- " - message : feedback from the last action\n",
226
- " - current_score : your cumulative score (0.001-0.999)\n",
227
- "\n",
228
- "Respond ONLY with a valid JSON object — no markdown, no explanation.\n",
229
- "\n",
230
- "Action format:\n",
231
  " {\"app\": \"<app>\", \"operation\": \"<op>\", \"args\": {...}}\n",
232
  "\n",
233
  "Available apps and key operations:\n",
234
  " jira: get_issue, create_issue, update_status, set_priority, assign_owner,\n",
235
  " add_label, link_zendesk_ticket, close_issue, list_issues\n",
236
  " zendesk: get_ticket, acknowledge_ticket, set_urgency, assign_agent,\n",
237
- " escalate_to_jira, resolve_ticket, add_note, list_tickets,\n",
238
- " create_agent_profile\n",
239
  " salesforce: get_account, list_accounts, update_deal_stage, flag_churn_risk,\n",
240
  " assign_account_owner, log_interaction, get_opportunity\n",
241
  " workday: get_employee, list_employees, provision_access, log_sla_event,\n",
242
  " request_budget_approval, create_onboarding_task, complete_task\n",
243
  "\n",
244
  "CRITICAL RULES:\n",
245
- "1. Read schema_hints FIRST if \"jira.priority\" -> \"severity\", use \"severity\" not \"priority\" in args.\n",
246
- "2. Complete ALL pending_steps in order.\n",
247
- "3. Do not repeat a successful action.\n",
248
- "4. If an operation fails, read the message carefully and adapt.\n",
249
- "5. Use list_* operations to discover record IDs when needed.\n",
250
- "6. Stop when pending_steps is empty or done=true.\n",
251
- "\"\"\"\n",
252
- "\n",
253
- "ENV_URL = \"http://localhost:8000\"\n",
254
- "\n",
255
  "\n",
256
  "def obs_to_text(obs: dict) -> str:\n",
257
- " hints = obs.get(\"schema_hints\", {})\n",
258
- " pending = obs.get(\"pending_steps\", [])\n",
259
  " lines = [\n",
260
  " f\"current_score: {obs['current_score']}\",\n",
261
  " f\"step_count: {obs['step_count']}\",\n",
262
  " f\"workflow_id: {obs['workflow_id']}\",\n",
263
- " \"\",\n",
264
- " \"=== WORKFLOW GOAL ===\",\n",
265
- " obs[\"workflow_goal\"],\n",
266
- " \"\",\n",
267
- " \"=== PENDING STEPS ===\",\n",
268
- " \"\\n\".join(f\" - {s}\" for s in pending) or \" (all steps complete!)\",\n",
269
- " \"\",\n",
270
- " \"=== SCHEMA HINTS (use these field names) ===\",\n",
271
- " json.dumps(hints, indent=2) if hints else \" (no drift — use canonical names)\",\n",
272
- " \"\",\n",
273
- " \"=== ACTIVE RULES ===\",\n",
274
- " json.dumps(obs.get(\"active_rules\", {}), indent=2),\n",
275
- " \"\",\n",
276
- " \"=== LAST MESSAGE ===\",\n",
277
- " obs[\"message\"],\n",
278
- " \"\",\n",
279
- " \"=== APP STATES ===\",\n",
280
  " ]\n",
281
- " for app_name, view in obs.get(\"app_states\", {}).items():\n",
282
- " lines.append(f\" [{app_name.upper()}]\")\n",
283
- " lines.append(f\" {view}\")\n",
284
- " lines.append(\"\")\n",
285
- " return \"\\n\".join(lines)\n",
286
- "\n",
287
- "\n",
288
- "def build_prompt(obs_text: str) -> str:\n",
289
- " messages = [{\"role\": \"user\", \"content\": SYSTEM_PROMPT + \"\\n\\n---\\n\\n\" + obs_text}]\n",
290
- " return tokenizer.apply_chat_template(\n",
291
- " messages, tokenize=False, add_generation_prompt=True\n",
292
- " )\n",
293
- "\n",
294
  "\n",
295
  "def parse_action(text: str):\n",
296
- " text = re.sub(r\"```(?:json)?\\s*\", \"\", text.strip()).strip()\n",
297
  " try:\n",
298
  " return json.loads(text)\n",
299
  " except json.JSONDecodeError:\n",
300
- " m = re.search(r\"\\{.*\\}\", text, re.DOTALL)\n",
301
  " if m:\n",
302
- " try:\n",
303
- " return json.loads(m.group())\n",
304
- " except Exception:\n",
305
- " pass\n",
306
  " return None\n",
307
  "\n",
308
  "\n",
309
- "N_PROMPTS_PER_WORKFLOW = 20\n",
310
- "prompt_rows = []\n",
311
  "\n",
312
- "print(\"Collecting prompts from env resets...\")\n",
313
- "for wf in [\"A\", \"B\", \"C\"]:\n",
314
  " for _ in range(N_PROMPTS_PER_WORKFLOW):\n",
315
- " result = httpx.post(f\"{ENV_URL}/reset\", json={\"workflow_id\": wf}).json()\n",
316
- " obs = result[\"observation\"]\n",
317
- " obs_text = obs_to_text(obs)\n",
318
- " prompt_rows.append({\n",
319
- " \"prompt\": build_prompt(obs_text),\n",
320
- " \"workflow_id\": wf,\n",
321
- " \"obs_text\": obs_text,\n",
322
  " })\n",
 
 
 
323
  "\n",
324
- "prompt_dataset = Dataset.from_list(prompt_rows)\n",
325
- "tlog(f\"[TRAIN_CONFIG] algorithm=GRPO prompts={len(prompt_dataset)} \"\n",
326
- " f\"workflows=A,B,C prompts_per_workflow={N_PROMPTS_PER_WORKFLOW}\")\n",
327
- "print(f\"Prompt dataset ready: {len(prompt_dataset)} examples\")"
328
  ]
329
  },
330
  {
331
  "cell_type": "markdown",
332
- "id": "sec7",
333
  "metadata": {},
334
  "source": [
335
- "## 7. Reward Function"
336
  ]
337
  },
338
  {
339
  "cell_type": "code",
340
- "execution_count": null,
341
- "id": "reward_fn",
342
  "metadata": {},
343
  "outputs": [],
344
  "source": [
345
- "def orgos_reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:\n",
346
- " \"\"\"\n",
347
- " GRPO reward function — called by GRPOTrainer each training step.\n",
348
- " Parses each completion as an action JSON, steps the live env, returns the reward.\n",
349
- " \"\"\"\n",
350
- " workflow_ids = kwargs.get(\"workflow_id\", [\"A\"] * len(completions))\n",
351
  " rewards = []\n",
352
- "\n",
353
  " for completion, wf_id in zip(completions, workflow_ids):\n",
354
  " action = parse_action(completion)\n",
355
  " if action is None:\n",
356
  " rewards.append(-0.1)\n",
357
  " continue\n",
358
  " try:\n",
359
- " httpx.post(f\"{ENV_URL}/reset\", json={\"workflow_id\": wf_id}, timeout=10)\n",
360
- " result = httpx.post(f\"{ENV_URL}/step\", json=action, timeout=10).json()\n",
361
- " rewards.append(float(result[\"reward\"]))\n",
362
- " except Exception:\n",
363
  " rewards.append(-0.1)\n",
364
- "\n",
365
  " return rewards\n",
366
  "\n",
367
- "\n",
368
  "# Sanity check\n",
369
- "test_r = orgos_reward_fn(\n",
370
- " completions = ['{\"app\": \"zendesk\", \"operation\": \"list_tickets\", \"args\": {\"state\": \"new\"}}',\n",
371
- " 'not json'],\n",
372
- " prompts = [\"\", \"\"],\n",
373
- " workflow_id = [\"A\", \"A\"],\n",
374
- ")\n",
375
- "tlog(f\"[REWARD_FN_CHECK] valid_action={test_r[0]:.4f} invalid_action={test_r[1]:.4f}\")"
376
  ]
377
  },
378
  {
379
  "cell_type": "markdown",
380
- "id": "sec8",
381
  "metadata": {},
382
- "source": [
383
- "## 8. Collect Baseline Scores (Pre-Training)"
384
- ]
 
385
  },
386
  {
387
  "cell_type": "code",
388
- "execution_count": null,
389
- "id": "baseline",
390
  "metadata": {},
391
  "outputs": [],
392
  "source": [
393
- "FastLanguageModel.for_inference(model)\n",
394
- "\n",
395
- "\n",
396
- "def run_episode_with_model(workflow_id: str, max_steps: int = 15) -> float:\n",
397
- " \"\"\"Run one full episode with the current model. Returns final score.\"\"\"\n",
398
- " result = httpx.post(f\"{ENV_URL}/reset\", json={\"workflow_id\": workflow_id}).json()\n",
399
- " obs = result[\"observation\"]\n",
400
- " history = []\n",
401
- "\n",
402
- " for _ in range(max_steps):\n",
403
- " if obs[\"done\"]:\n",
404
- " break\n",
405
- "\n",
406
- " obs_text = obs_to_text(obs)\n",
407
- " history.append({\"role\": \"user\", \"content\": obs_text})\n",
408
  "\n",
409
- " messages = list(history)\n",
410
- " messages[0] = {\"role\": \"user\",\n",
411
- " \"content\": SYSTEM_PROMPT + \"\\n\\n---\\n\\n\" + messages[0][\"content\"]}\n",
412
- "\n",
413
- " text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
414
- " inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
415
- "\n",
416
- " with torch.no_grad():\n",
417
- " out = model.generate(\n",
418
- " **inputs,\n",
419
- " max_new_tokens = 256,\n",
420
- " temperature = 0.0,\n",
421
- " do_sample = False,\n",
422
- " pad_token_id = tokenizer.eos_token_id,\n",
423
- " )\n",
424
- " action_str = tokenizer.decode(\n",
425
- " out[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True\n",
426
- " ).strip()\n",
427
- "\n",
428
- " history.append({\"role\": \"assistant\", \"content\": action_str})\n",
429
- "\n",
430
- " action = parse_action(action_str)\n",
431
- " if action is None:\n",
432
- " break\n",
433
- "\n",
434
- " result = httpx.post(f\"{ENV_URL}/step\", json=action).json()\n",
435
- " obs = result[\"observation\"]\n",
436
- " if obs[\"done\"]:\n",
437
- " break\n",
438
- "\n",
439
- " return obs.get(\"current_score\", 0.001)\n",
440
- "\n",
441
- "\n",
442
- "N_EVAL = 10\n",
443
- "baseline_scores = {wf: [] for wf in [\"A\", \"B\", \"C\"]}\n",
444
- "\n",
445
- "tlog(\"[EVAL_START] phase=baseline\")\n",
446
- "for wf in [\"A\", \"B\", \"C\"]:\n",
447
- " for ep in range(N_EVAL):\n",
448
- " score = run_episode_with_model(wf)\n",
449
- " baseline_scores[wf].append(score)\n",
450
- " tlog(f\"[EVAL] phase=baseline workflow={wf} episode={ep+1} score={score:.4f}\")\n",
451
- " wf_mean = np.mean(baseline_scores[wf])\n",
452
- " tlog(f\"[EVAL_WORKFLOW] phase=baseline workflow={wf} \"\n",
453
- " f\"mean={wf_mean:.4f} min={min(baseline_scores[wf]):.4f} max={max(baseline_scores[wf]):.4f}\")\n",
454
- "\n",
455
- "baseline_mean = np.mean([s for v in baseline_scores.values() for s in v])\n",
456
- "tlog(f\"[EVAL_END] phase=baseline overall_mean={baseline_mean:.4f}\")"
457
- ]
458
- },
459
- {
460
- "cell_type": "markdown",
461
- "id": "sec9",
462
- "metadata": {},
463
- "source": [
464
- "## 9. GRPO Training"
465
  ]
466
  },
467
  {
468
  "cell_type": "code",
469
- "execution_count": null,
470
- "id": "grpo_training",
471
  "metadata": {},
472
  "outputs": [],
473
  "source": [
474
- "from trl import GRPOConfig, GRPOTrainer\n",
475
- "from transformers import TrainerCallback\n",
476
- "\n",
477
- "model.train()\n",
478
- "\n",
479
- "NUM_EPOCHS = 3\n",
480
- "BATCH_SIZE = 4\n",
481
- "GRAD_ACCUM = 2\n",
482
- "LR = 5e-5\n",
483
- "NUM_GEN = 4\n",
484
- "TEMPERATURE = 0.8\n",
485
- "BETA = 0.04\n",
486
  "\n",
 
 
487
  "grpo_config = GRPOConfig(\n",
488
- " output_dir = \"./orgos_grpo_ckpt\",\n",
489
- " num_train_epochs = NUM_EPOCHS,\n",
490
- " per_device_train_batch_size = BATCH_SIZE,\n",
 
491
  " gradient_accumulation_steps = GRAD_ACCUM,\n",
492
- " learning_rate = LR,\n",
493
- " warmup_steps = 10,\n",
494
- " logging_steps = 5,\n",
495
- " save_steps = 100,\n",
496
- " bf16 = torch.cuda.is_bf16_supported(),\n",
497
- " fp16 = not torch.cuda.is_bf16_supported(),\n",
498
- " max_grad_norm = 1.0,\n",
499
- " num_generations = NUM_GEN,\n",
500
- " max_new_tokens = 256,\n",
501
- " temperature = TEMPERATURE,\n",
502
- " beta = BETA,\n",
503
- " report_to = \"none\",\n",
504
- " seed = 42,\n",
505
  ")\n",
506
  "\n",
507
- "tlog(f\"[TRAIN_CONFIG] epochs={NUM_EPOCHS} batch_size={BATCH_SIZE} \"\n",
508
- " f\"grad_accum={GRAD_ACCUM} lr={LR} num_generations={NUM_GEN} \"\n",
509
- " f\"temperature={TEMPERATURE} beta_kl={BETA}\")\n",
510
- "\n",
511
- "\n",
512
- "class OrgOSLogCallback(TrainerCallback):\n",
513
- " \"\"\"Logs each training step to training_log.txt.\"\"\"\n",
514
- "\n",
515
- " def on_log(self, args, state, control, logs=None, **kwargs):\n",
516
- " if logs is None:\n",
517
- " return\n",
518
- " step = state.global_step\n",
519
- " loss = logs.get(\"loss\", logs.get(\"train_loss\", \"?\"))\n",
520
- " mean_reward = logs.get(\"reward\", logs.get(\"mean_reward\", \"?\"))\n",
521
- " kl = logs.get(\"kl\", logs.get(\"approx_kl\", \"?\"))\n",
522
- " lr_now = logs.get(\"learning_rate\", \"?\")\n",
523
- "\n",
524
- " loss_str = f\"{loss:.6f}\" if isinstance(loss, float) else str(loss)\n",
525
- " reward_str = f\"{mean_reward:.4f}\" if isinstance(mean_reward, float) else str(mean_reward)\n",
526
- " kl_str = f\"{kl:.6f}\" if isinstance(kl, float) else str(kl)\n",
527
- " lr_str = f\"{lr_now:.2e}\" if isinstance(lr_now, float) else str(lr_now)\n",
528
- "\n",
529
- " tlog(f\"[TRAIN_STEP] step={step} loss={loss_str} \"\n",
530
- " f\"mean_reward={reward_str} kl={kl_str} lr={lr_str}\")\n",
531
- "\n",
532
- "\n",
533
  "trainer = GRPOTrainer(\n",
534
  " model = model,\n",
535
- " args = grpo_config,\n",
536
- " reward_funcs = orgos_reward_fn,\n",
537
- " train_dataset = prompt_dataset,\n",
538
  " processing_class = tokenizer,\n",
 
 
 
539
  " callbacks = [OrgOSLogCallback()],\n",
540
  ")\n",
541
  "\n",
542
- "tlog(\"[TRAIN_START]\")\n",
543
- "train_result = trainer.train()\n",
544
- "tlog(f\"[TRAIN_END] total_steps={train_result.global_step} \"\n",
545
- " f\"train_loss={train_result.training_loss:.6f} \"\n",
546
- " f\"train_runtime_s={train_result.metrics.get('train_runtime', 0):.1f}\")"
547
  ]
548
  },
549
  {
550
  "cell_type": "markdown",
551
- "id": "sec10",
552
  "metadata": {},
553
  "source": [
554
- "## 10. Collect Post-Training Scores"
 
 
555
  ]
556
  },
557
  {
558
  "cell_type": "code",
559
- "execution_count": null,
560
- "id": "post_training",
561
  "metadata": {},
562
  "outputs": [],
563
  "source": [
564
  "FastLanguageModel.for_inference(model)\n",
 
 
565
  "\n",
566
- "post_scores = {wf: [] for wf in [\"A\", \"B\", \"C\"]}\n",
567
- "\n",
568
- "tlog(\"[EVAL_START] phase=post_training\")\n",
569
- "for wf in [\"A\", \"B\", \"C\"]:\n",
570
- " for ep in range(N_EVAL):\n",
571
- " score = run_episode_with_model(wf)\n",
572
- " post_scores[wf].append(score)\n",
573
- " tlog(f\"[EVAL] phase=post_training workflow={wf} episode={ep+1} score={score:.4f}\")\n",
574
- " wf_mean = np.mean(post_scores[wf])\n",
575
- " tlog(f\"[EVAL_WORKFLOW] phase=post_training workflow={wf} \"\n",
576
- " f\"mean={wf_mean:.4f} min={min(post_scores[wf]):.4f} max={max(post_scores[wf]):.4f}\")\n",
577
- "\n",
578
- "post_mean = np.mean([s for v in post_scores.values() for s in v])\n",
579
- "improvement = post_mean - baseline_mean\n",
580
- "tlog(f\"[EVAL_END] phase=post_training overall_mean={post_mean:.4f}\")\n",
581
- "tlog(\n",
582
- " f\"[TRAIN_SUMMARY] \"\n",
583
- " f\"model={MODEL_NAME} algorithm=GRPO \"\n",
584
- " f\"baseline_mean={baseline_mean:.4f} \"\n",
585
- " f\"post_training_mean={post_mean:.4f} \"\n",
586
- " f\"improvement={improvement:+.4f} \"\n",
587
- " f\"workflow_A_before={np.mean(baseline_scores['A']):.4f} \"\n",
588
- " f\"workflow_A_after={np.mean(post_scores['A']):.4f} \"\n",
589
- " f\"workflow_B_before={np.mean(baseline_scores['B']):.4f} \"\n",
590
- " f\"workflow_B_after={np.mean(post_scores['B']):.4f} \"\n",
591
- " f\"workflow_C_before={np.mean(baseline_scores['C']):.4f} \"\n",
592
- " f\"workflow_C_after={np.mean(post_scores['C']):.4f}\"\n",
593
- ")\n",
594
- "print(f\"\\nImprovement: {baseline_mean:.4f} → {post_mean:.4f} ({improvement:+.4f})\")"
595
  ]
596
  },
597
  {
598
  "cell_type": "markdown",
599
- "id": "sec11",
600
  "metadata": {},
601
  "source": [
602
- "## 11. Plot Before / After"
 
 
603
  ]
604
  },
605
  {
606
  "cell_type": "code",
607
- "execution_count": null,
608
- "id": "plot",
609
  "metadata": {},
610
  "outputs": [],
611
  "source": [
612
- "import matplotlib.pyplot as plt\n",
613
- "import matplotlib.gridspec as gridspec\n",
614
- "\n",
615
- "fig = plt.figure(figsize=(14, 8), facecolor=\"#0f172a\")\n",
616
- "fig.suptitle(\"OrgOS: Before vs After GRPO Training\", fontsize=15,\n",
617
- " color=\"white\", fontweight=\"bold\", y=0.98)\n",
618
- "\n",
619
- "gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.45, wspace=0.35)\n",
620
- "COLORS = {\"before\": \"#f87171\", \"after\": \"#34d399\", \"bg\": \"#1e293b\", \"grid\": \"#334155\"}\n",
621
- "WF_LABELS = {\n",
622
- " \"A\": \"Workflow A\\nCustomer Bug Fix\",\n",
623
- " \"B\": \"Workflow B\\nEmployee Onboarding\",\n",
624
- " \"C\": \"Workflow C\\nChurn Risk Alert\",\n",
625
- "}\n",
626
- "\n",
627
- "for col, wf in enumerate([\"A\", \"B\", \"C\"]):\n",
628
- " ax = fig.add_subplot(gs[0, col])\n",
629
- " ax.set_facecolor(COLORS[\"bg\"])\n",
630
- " ax.grid(color=COLORS[\"grid\"], linewidth=0.5, alpha=0.7)\n",
631
- " before = baseline_scores[wf]\n",
632
- " after = post_scores[wf]\n",
633
- " delta = np.mean(after) - np.mean(before)\n",
634
- " ax.plot(before, color=COLORS[\"before\"], linewidth=1.5, alpha=0.8, label=\"Before GRPO\")\n",
635
- " ax.plot(after, color=COLORS[\"after\"], linewidth=1.5, alpha=0.8, label=\"After GRPO\")\n",
636
- " ax.axhline(np.mean(before), color=COLORS[\"before\"], linestyle=\"--\", linewidth=1, alpha=0.5)\n",
637
- " ax.axhline(np.mean(after), color=COLORS[\"after\"], linestyle=\"--\", linewidth=1, alpha=0.5)\n",
638
- " ax.set_title(WF_LABELS[wf] + f\"\\n(Δ = {delta:+.4f})\", color=\"white\", fontsize=9)\n",
639
- " ax.set_xlabel(\"Episode\", color=\"#94a3b8\", fontsize=8)\n",
640
- " ax.set_ylabel(\"Final Score\", color=\"#94a3b8\", fontsize=8)\n",
641
- " ax.tick_params(colors=\"#64748b\", labelsize=7)\n",
642
- " ax.set_ylim(0, 1)\n",
643
- " ax.legend(fontsize=7, facecolor=\"#1e293b\", labelcolor=\"white\",\n",
644
- " edgecolor=\"#475569\", framealpha=0.8)\n",
645
- " for spine in ax.spines.values():\n",
646
- " spine.set_edgecolor(\"#334155\")\n",
647
- "\n",
648
- "ax_hist = fig.add_subplot(gs[1, :])\n",
649
- "ax_hist.set_facecolor(COLORS[\"bg\"])\n",
650
- "ax_hist.grid(color=COLORS[\"grid\"], linewidth=0.5, alpha=0.5, axis=\"x\")\n",
651
- "all_before = [s for v in baseline_scores.values() for s in v]\n",
652
- "all_after = [s for v in post_scores.values() for s in v]\n",
653
- "bins = np.linspace(0, 1, 25)\n",
654
- "ax_hist.hist(all_before, bins=bins, color=COLORS[\"before\"], alpha=0.6,\n",
655
- " label=f\"Before GRPO (mean={np.mean(all_before):.4f})\", edgecolor=\"none\")\n",
656
- "ax_hist.hist(all_after, bins=bins, color=COLORS[\"after\"], alpha=0.6,\n",
657
- " label=f\"After GRPO (mean={np.mean(all_after):.4f})\", edgecolor=\"none\")\n",
658
- "ax_hist.axvline(np.mean(all_before), color=COLORS[\"before\"], linestyle=\"--\", linewidth=1.5)\n",
659
- "ax_hist.axvline(np.mean(all_after), color=COLORS[\"after\"], linestyle=\"--\", linewidth=1.5)\n",
660
- "ax_hist.set_title(\"Score Distribution Across All Workflows\", color=\"white\", fontsize=10)\n",
661
- "ax_hist.set_xlabel(\"Final Score\", color=\"#94a3b8\", fontsize=9)\n",
662
- "ax_hist.set_ylabel(\"Count\", color=\"#94a3b8\", fontsize=9)\n",
663
- "ax_hist.tick_params(colors=\"#64748b\", labelsize=8)\n",
664
- "ax_hist.legend(fontsize=9, facecolor=\"#1e293b\", labelcolor=\"white\",\n",
665
- " edgecolor=\"#475569\", framealpha=0.9)\n",
666
- "for spine in ax_hist.spines.values():\n",
667
- " spine.set_edgecolor(\"#334155\")\n",
668
- "\n",
669
- "plt.savefig(\"before_after_curves.png\", dpi=150, bbox_inches=\"tight\",\n",
670
- " facecolor=\"#0f172a\", edgecolor=\"none\")\n",
671
  "plt.show()\n",
672
- "tlog(\"[ARTIFACT] file=before_after_curves.png\")\n",
673
- "print(\"Saved: before_after_curves.png\")"
674
  ]
675
  },
676
  {
677
  "cell_type": "markdown",
678
- "id": "sec12",
679
  "metadata": {},
680
  "source": [
681
- "## 12. Save LoRA Adapter & Training Log"
 
 
682
  ]
683
  },
684
  {
685
  "cell_type": "code",
686
- "execution_count": null,
687
- "id": "save_model",
688
  "metadata": {},
689
  "outputs": [],
690
  "source": [
691
- "model.save_pretrained(\"orgos_lora_adapter\")\n",
692
- "tokenizer.save_pretrained(\"orgos_lora_adapter\")\n",
693
- "tlog(\"[ARTIFACT] file=orgos_lora_adapter/\")\n",
694
- "tlog(\"[ARTIFACT] file=training_log.txt\")\n",
695
- "\n",
696
- "print(f\"\\n{'='*60}\")\n",
697
- "print(\" Submission artefacts\")\n",
698
- "print(f\"{'='*60}\")\n",
699
- "print(\" training_log.txt — structured training log\")\n",
700
- "print(\" before_after_curves.png — score improvement chart\")\n",
701
- "print(\" orgos_lora_adapter/ — LoRA weights\")\n",
702
- "print(f\"{'='*60}\")\n",
703
- "\n",
704
- "# Optional: push to HuggingFace Hub\n",
705
- "# from huggingface_hub import login\n",
706
- "# login(token=\"YOUR_HF_TOKEN\")\n",
707
- "# model.push_to_hub(\"YOUR_USERNAME/orgos-qwen25-3b-grpo\")\n",
708
- "# tokenizer.push_to_hub(\"YOUR_USERNAME/orgos-qwen25-3b-grpo\")"
709
  ]
710
  }
711
- ],
712
- "metadata": {
713
- "kernelspec": {
714
- "display_name": "Python 3",
715
- "language": "python",
716
- "name": "python3"
717
- },
718
- "language_info": {
719
- "name": "python",
720
- "version": "3.10.0"
721
- }
722
- },
723
- "nbformat": 4,
724
- "nbformat_minor": 5
725
  }
 
1
  {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 5,
4
+ "metadata": {
5
+ "kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"},
6
+ "language_info": {"name": "python", "version": "3.10.0"},
7
+ "accelerator": "GPU",
8
+ "colab": {"gpuType": "T4"}
9
+ },
10
  "cells": [
11
  {
12
  "cell_type": "markdown",
13
+ "id": "cell-0",
14
  "metadata": {},
15
  "source": [
16
+ "# OrgOS GRPO Training on a Multi-App Enterprise RL Environment\n",
17
+ "\n",
18
+ "**Project:** OrgOS — an OpenEnv environment that simulates enterprise workflows across **Jira, Zendesk, Salesforce, and Workday** with realistic challenges: schema drift, RBAC, SLA constraints, and policy drift.\n",
19
+ "\n",
20
+ "**Goal of this notebook:** Fine-tune `Qwen2.5-3B-Instruct` with **GRPO** (Group Relative Policy Optimization) using **live environment rewards**, then compare the trained agent against the untrained baseline.\n",
21
+ "\n",
22
+ "**Hardware:** Colab T4 (free tier, 16 GB VRAM). End-to-end runtime ≈ 45–60 min.\n",
23
+ "\n",
24
+ "**Outputs (committed to the repo):**\n",
25
+ "- `training/training_log.txt` structured logs (`[TRAIN_CONFIG]`, `[EVAL]`, `[TRAIN_STEP]`, …)\n",
26
+ "- `training/plots/training_curve.png` mean reward vs GRPO step\n",
27
+ "- `training/plots/baseline_vs_trained.png` per-workflow comparison\n",
28
+ "- `training/plots/score_distribution.png` per-episode score distribution\n",
29
+ "- `training/orgos_lora_adapter/` trained LoRA weights\n",
30
+ "\n",
31
+ "Reviewers can open this notebook on Colab → Runtime → *Run all* and reproduce every artifact end-to-end."
 
 
 
 
32
  ]
33
  },
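
For orientation, the structured log lines produced by this notebook look like the following (values illustrative, not from a real run):

```text
[EVAL] phase=baseline workflow=A episode=1 score=0.3120
[TRAIN_STEP] step=12 reward=0.4150 loss=0.0321 kl=0.0042
[EVAL] phase=trained workflow=A episode=1 score=0.5240
[TRAIN_SUMMARY] baseline_overall=0.3100 trained_overall=0.4500 delta=+0.1400
```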
34
  {
35
  "cell_type": "markdown",
36
+ "id": "cell-1",
37
  "metadata": {},
38
+ "source": ["## 1. Setup — install dependencies and clone the repo"]
 
 
39
  },
40
  {
41
  "cell_type": "code",
42
+ "id": "cell-2",
 
43
  "metadata": {},
44
  "outputs": [],
45
  "source": [
46
+ "# Pin TRL to the version Unsloth requires BEFORE installing unsloth.\n",
47
+ "# trl 1.x breaks Unsloth's GRPOTrainer patches keep it <=0.24.\n",
48
+ "%pip install -q \"trl>=0.18.2,<=0.24.0\" peft accelerate bitsandbytes datasets\n",
49
+ "# Install Unsloth after TRL so its patches apply to the right TRL version.\n",
50
+ "%pip install -q --upgrade unsloth\n",
51
+ "%pip install -q fastapi 'uvicorn[standard]' pydantic httpx faker openai aiofiles"
52
  ]
53
  },
54
  {
55
+ "cell_type": "code",
56
+ "id": "cell-3",
57
  "metadata": {},
58
+ "outputs": [],
59
  "source": [
60
+ "# Clone the OrgOS dev repo (env server, models, business rules)\n",
61
+ "import os\n",
62
+ "REPO_URL = 'https://github.com/Tanvi51204/OpenEnv-Round-2.git'\n",
63
+ "if not os.path.isdir('/content/OpenEnv-Round-2'):\n",
64
+ " !git clone {REPO_URL} /content/OpenEnv-Round-2\n",
65
+ "%cd /content/OpenEnv-Round-2"
66
  ]
67
  },
68
  {
69
  "cell_type": "code",
70
+ "id": "cell-4",
 
71
  "metadata": {},
72
  "outputs": [],
73
  "source": [
74
+ "# Imports — keep `import unsloth` first to register its patches.\n",
75
+ "import unsloth\n",
76
  "\n",
77
+ "import json, os, re, sys, time, subprocess\n",
78
+ "from pathlib import Path\n",
79
+ "from typing import List\n",
80
  "\n",
81
+ "import httpx\n",
82
+ "import numpy as np\n",
83
+ "import torch\n",
84
+ "import matplotlib.pyplot as plt\n",
85
+ "from datasets import Dataset\n",
86
+ "from transformers import TrainerCallback\n",
87
+ "from trl import GRPOConfig, GRPOTrainer\n",
88
+ "from unsloth import FastLanguageModel\n",
89
  "\n",
90
+ "torch.set_float32_matmul_precision('high')"
 
 
91
  ]
92
  },
93
  {
94
  "cell_type": "markdown",
95
+ "id": "cell-5",
96
  "metadata": {},
97
+ "source": ["## 2. Configuration"]
98
+ },
99
+ {
100
+ "cell_type": "code",
101
+ "id": "cell-6",
102
+ "metadata": {},
103
+ "outputs": [],
104
  "source": [
105
+ "# ---- Model ----\n",
106
+ "MODEL_NAME = 'unsloth/Qwen2.5-3B-Instruct-bnb-4bit'\n",
107
+ "MAX_SEQ_LEN = 4096\n",
108
+ "LORA_R = 16\n",
109
+ "LORA_ALPHA = 16\n",
110
+ "\n",
111
+ "# ---- Environment ----\n",
112
+ "ENV_URL = 'http://localhost:8000'\n",
113
+ "WORKFLOWS = ['A', 'B', 'C']\n",
114
+ "\n",
115
+ "# ---- Data / eval ----\n",
116
+ "N_PROMPTS_PER_WORKFLOW = 20 # 20 × 3 = 60 prompts\n",
117
+ "N_EVAL_EPISODES = 5 # episodes per workflow at eval time\n",
118
+ "MAX_EPISODE_STEPS = 15\n",
119
+ "\n",
120
+ "# ---- GRPO ----\n",
121
+ "MAX_TRAIN_STEPS = 150 # more steps for better convergence\n",
122
+ "NUM_GENERATIONS = 2 # G = candidates per prompt (keep low for T4 VRAM)\n",
123
+ "PER_DEVICE_BATCH = 1\n",
124
+ "GRAD_ACCUM = 2 # effective batch = 2 with grad accum\n",
125
+ "LEARNING_RATE = 8e-6\n",
126
+ "MAX_COMPLETION_LENGTH = 256\n",
127
+ "REWARD_STEPS = 2 # multi-step rollout depth in reward fn\n",
128
+ "\n",
129
+ "# ---- Drive checkpoint (saves every N steps so Colab disconnects don't lose progress) ----\n",
130
+ "CKPT_EVERY_STEPS = 30\n",
131
+ "\n",
132
+ "# ---- Output paths ----\n",
133
+ "TRAIN_DIR = Path('/content/OpenEnv-Round-2/training')\n",
134
+ "PLOTS_DIR = TRAIN_DIR / 'plots'\n",
135
+ "ADAPTER_DIR = TRAIN_DIR / 'orgos_lora_adapter'\n",
136
+ "LOG_PATH = TRAIN_DIR / 'training_log.txt'\n",
137
+ "PLOTS_DIR.mkdir(parents=True, exist_ok=True)"
138
  ]
139
  },
140
  {
141
  "cell_type": "code",
142
+ "id": "cell-7",
 
143
  "metadata": {},
144
  "outputs": [],
145
  "source": [
146
+ "# Structured logger — every important event goes through this so submission has a clean log.\n",
147
+ "LOG_PATH.write_text('') # truncate\n",
148
  "\n",
149
+ "def tlog(line: str):\n",
150
  " print(line, flush=True)\n",
151
+ " with open(LOG_PATH, 'a') as f:\n",
152
+ " f.write(line + '\\n')"
 
 
 
153
  ]
154
  },
155
  {
156
  "cell_type": "markdown",
157
+ "id": "cell-8",
158
  "metadata": {},
159
  "source": [
160
+ "## 3. Start the OrgOS environment server\n",
161
+ "\n",
162
+ "We launch the FastAPI env server (`server/app.py`) as a background subprocess. The reward function and eval loop call it over HTTP at `localhost:8000`."
163
  ]
164
  },
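
The server exposes a small HTTP API; a minimal client interaction, with the request/response shapes this notebook relies on, is:

```python
import httpx

ENV_URL = "http://localhost:8000"

# Start a fresh episode for one workflow and read the first observation.
obs = httpx.post(f"{ENV_URL}/reset", json={"workflow_id": "A"}).json()["observation"]

# Apply one action; the env returns the next observation with an updated score.
action = {"app": "zendesk", "operation": "list_tickets", "args": {}}
obs = httpx.post(f"{ENV_URL}/step", json=action).json()["observation"]
print(obs["current_score"], obs["done"])
```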
165
  {
166
  "cell_type": "code",
167
+ "id": "cell-9",
 
168
  "metadata": {},
169
  "outputs": [],
170
  "source": [
171
+ "ENV_PROC = subprocess.Popen(\n",
172
+ " [sys.executable, '-m', 'uvicorn', 'server.app:app', '--host', '0.0.0.0', '--port', '8000'],\n",
173
+ " cwd='/content/OpenEnv-Round-2',\n",
 
174
  " stdout=subprocess.DEVNULL,\n",
175
  " stderr=subprocess.DEVNULL,\n",
176
  ")\n",
177
+ "for _ in range(30):\n",
178
+ " try:\n",
179
+ " r = httpx.get(f'{ENV_URL}/health', timeout=2)\n",
180
+ " if r.status_code == 200:\n",
181
+ " tlog(f\"[ENV] status={r.json().get('status')} version={r.json().get('version','?')}\")\n",
182
+ " break\n",
183
+ " except Exception:\n",
184
+ " time.sleep(1)\n",
185
+ "else:\n",
186
+ " raise RuntimeError('Env server failed to start')"
187
  ]
188
  },
189
  {
190
  "cell_type": "markdown",
191
+ "id": "cell-10",
192
  "metadata": {},
193
+ "source": ["## 4. Load model — Qwen2.5-3B-Instruct, 4-bit, with LoRA adapters"]
 
 
194
  },
195
  {
196
  "cell_type": "code",
197
+ "id": "cell-11",
 
198
  "metadata": {},
199
  "outputs": [],
200
  "source": [
 
 
 
 
 
 
 
201
  "model, tokenizer = FastLanguageModel.from_pretrained(\n",
202
  " model_name = MODEL_NAME,\n",
203
  " max_seq_length = MAX_SEQ_LEN,\n",
 
208
  "model = FastLanguageModel.get_peft_model(\n",
209
  " model,\n",
210
  " r = LORA_R,\n",
211
+ " lora_alpha = LORA_ALPHA,\n",
212
+ " target_modules = ['q_proj','k_proj','v_proj','o_proj','gate_proj','up_proj','down_proj'],\n",
213
+ " use_gradient_checkpointing = 'unsloth',\n",
 
 
 
 
214
  ")\n",
215
  "\n",
216
+ "# Clear max_length so generate() doesn't warn about max_new_tokens vs max_length conflict.\n",
217
+ "model.config.max_length = None\n",
218
+ "\n",
219
  "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
220
+ "tlog(f'[TRAIN_CONFIG] model={MODEL_NAME} lora_r={LORA_R} max_seq_len={MAX_SEQ_LEN} '\n",
221
+ " f'trainable_params={trainable:,} quantization=4bit')"
222
  ]
223
  },
224
  {
225
  "cell_type": "markdown",
226
+ "id": "cell-12",
227
  "metadata": {},
228
  "source": [
229
+ "## 5. Helpers — system prompt, observation formatting, action parsing\n",
230
+ "\n",
231
+ "The agent gets a **stateless single-turn prompt**: `[SYSTEM_PROMPT] + [observation]` → `[action JSON]`. This matches what GRPO trains on, which is critical for eval/train alignment, and prevents context accumulation over a multi-step episode."
232
  ]
233
  },
234
  {
235
  "cell_type": "code",
236
+ "id": "cell-13",
 
237
  "metadata": {},
238
  "outputs": [],
239
  "source": [
240
+ "SYSTEM_PROMPT = '''You are OrgOS Agent — an enterprise workflow automation agent.\n",
241
+ "You operate across four SaaS apps: Jira, Zendesk, Salesforce, and Workday.\n",
 
 
242
  "\n",
243
+ "Each turn you receive a JSON observation with workflow_goal, pending_steps, app_states,\n",
244
+ "schema_hints (field renames in effect this episode, e.g. {\"jira.priority\": \"severity\"}),\n",
245
+ "active_rules, message (feedback from last action), and current_score.\n",
246
  "\n",
247
+ "Respond ONLY with a valid JSON object — no markdown, no explanation:\n",
248
  " {\"app\": \"<app>\", \"operation\": \"<op>\", \"args\": {...}}\n",
249
  "\n",
250
  "Available apps and key operations:\n",
251
  " jira: get_issue, create_issue, update_status, set_priority, assign_owner,\n",
252
  " add_label, link_zendesk_ticket, close_issue, list_issues\n",
253
  " zendesk: get_ticket, acknowledge_ticket, set_urgency, assign_agent,\n",
254
+ " escalate_to_jira, resolve_ticket, add_note, list_tickets, create_agent_profile\n",
 
255
  " salesforce: get_account, list_accounts, update_deal_stage, flag_churn_risk,\n",
256
  " assign_account_owner, log_interaction, get_opportunity\n",
257
  " workday: get_employee, list_employees, provision_access, log_sla_event,\n",
258
  " request_budget_approval, create_onboarding_task, complete_task\n",
259
  "\n",
260
  "CRITICAL RULES:\n",
261
+ "1. Read schema_hints FIRST. If \"salesforce.owner\" -> \"rep_email\", use \"rep_email\" not \"owner\".\n",
262
+ "2. Complete pending_steps in order.\n",
263
+ "3. Never repeat a failed action unchanged — read the message and adapt.\n",
264
+ "4. Use list_* operations to discover record IDs.\n",
265
+ "5. Stop when pending_steps is empty or done=true.'''"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "code",
270
+ "id": "cell-14",
271
+ "metadata": {},
272
+ "outputs": [],
273
+ "source": [
274
+ "WORKFLOW_APPS = {\n",
275
+ " 'A': {'jira', 'zendesk', 'salesforce', 'workday'},\n",
276
+ " 'B': {'zendesk', 'salesforce', 'workday'},\n",
277
+ " 'C': {'jira', 'zendesk', 'salesforce'},\n",
278
+ "}\n",
279
  "\n",
280
  "def obs_to_text(obs: dict) -> str:\n",
281
+ " hints = obs.get('schema_hints', {})\n",
282
+ " pending = obs.get('pending_steps', [])\n",
283
  " lines = [\n",
284
  " f\"current_score: {obs['current_score']}\",\n",
285
  " f\"step_count: {obs['step_count']}\",\n",
286
  " f\"workflow_id: {obs['workflow_id']}\",\n",
287
+ " '',\n",
288
+ " '=== WORKFLOW GOAL ===' , obs['workflow_goal'], '',\n",
289
+ " '=== PENDING STEPS ===',\n",
290
+ " '\\n'.join(f' - {s}' for s in pending) or ' (all steps complete!)',\n",
291
+ " '',\n",
292
+ " '=== SCHEMA HINTS (use these field names) ===',\n",
293
+ " json.dumps(hints, indent=2) if hints else ' (no drift — use canonical names)',\n",
294
+ " '',\n",
295
+ " '=== ACTIVE RULES ===',\n",
296
+ " json.dumps(obs.get('active_rules', {}), indent=2),\n",
297
+ " '',\n",
298
+ " '=== LAST MESSAGE ===', obs['message'], '',\n",
299
+ " '=== APP STATES ===',\n",
 
 
 
 
300
  " ]\n",
301
+ " relevant = WORKFLOW_APPS.get(obs.get('workflow_id', 'A'),\n",
302
+ " {'jira','zendesk','salesforce','workday'})\n",
303
+ " for app_name, view in obs.get('app_states', {}).items():\n",
304
+ " if app_name not in relevant:\n",
305
+ " continue\n",
306
+ " view_str = str(view)\n",
307
+ " if len(view_str) > 600:\n",
308
+ " view_str = view_str[:600] + '...[truncated]'\n",
309
+ " lines += [f' [{app_name.upper()}]', f' {view_str}', '']\n",
310
+ " return '\\n'.join(lines)\n",
 
 
 
311
  "\n",
312
  "def parse_action(text: str):\n",
313
+ " text = re.sub(r'```(?:json)?\\s*', '', text.strip()).strip()\n",
314
  " try:\n",
315
  " return json.loads(text)\n",
316
  " except json.JSONDecodeError:\n",
317
+ " m = re.search(r'\\{.*\\}', text, re.DOTALL)\n",
318
  " if m:\n",
319
+ " try: return json.loads(m.group())\n",
320
+ " except Exception: return None\n",
 
 
321
  " return None\n",
322
  "\n",
323
+ "def build_prompt(obs_text: str) -> str:\n",
324
+ " msgs = [{'role': 'user', 'content': SYSTEM_PROMPT + '\\n\\n---\\n\\n' + obs_text}]\n",
325
+ " return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "markdown",
330
+ "id": "cell-15",
331
+ "metadata": {},
332
+ "source": [
333
+ "## 6. Episode runner & evaluator\n",
334
  "\n",
335
+ "`run_episode_with_model` is **stateless** — each step sends just `[system + current obs]`, no chat history. This (a) keeps prompts under `MAX_SEQ_LEN`, (b) matches the GRPO training format exactly, and (c) avoids context accumulation across multi-step episodes."
336
+ ]
337
+ },
338
+ {
339
+ "cell_type": "code",
340
+ "id": "cell-16",
341
+ "metadata": {},
342
+ "outputs": [],
343
+ "source": [
344
+ "def run_episode_with_model(workflow_id: str, max_steps: int = MAX_EPISODE_STEPS) -> float:\n",
345
+ " obs = httpx.post(f'{ENV_URL}/reset', json={'workflow_id': workflow_id}).json()['observation']\n",
346
+ " for _ in range(max_steps):\n",
347
+ " if obs['done']:\n",
348
+ " break\n",
349
+ " prompt = build_prompt(obs_to_text(obs))\n",
350
+ " inputs = tokenizer(prompt, return_tensors='pt').to(model.device)\n",
351
+ " with torch.no_grad():\n",
352
+ " out = model.generate(\n",
353
+ " **inputs,\n",
354
+ " max_new_tokens = 256,\n",
355
+ " do_sample = False,\n",
356
+ " pad_token_id = tokenizer.eos_token_id,\n",
357
+ " )\n",
358
+ " action_str = tokenizer.decode(out[0][inputs['input_ids'].shape[1]:],\n",
359
+ " skip_special_tokens=True).strip()\n",
360
+ " action = parse_action(action_str)\n",
361
+ " if action is None:\n",
362
+ " break\n",
363
+ " result = httpx.post(f'{ENV_URL}/step', json=action).json()\n",
364
+ " obs = result['observation']\n",
365
+ " if obs['done']:\n",
366
+ " break\n",
367
+ " return float(obs.get('current_score', 0.001))\n",
368
+ "\n",
369
+ "def evaluate(phase: str, n_eval: int = N_EVAL_EPISODES) -> dict:\n",
370
+ " scores = {wf: [] for wf in WORKFLOWS}\n",
371
+ " tlog(f'[EVAL_START] phase={phase}')\n",
372
+ " for wf in WORKFLOWS:\n",
373
+ " for ep in range(n_eval):\n",
374
+ " s = run_episode_with_model(wf)\n",
375
+ " scores[wf].append(s)\n",
376
+ " tlog(f'[EVAL] phase={phase} workflow={wf} episode={ep+1} score={s:.4f}')\n",
377
+ " m = float(np.mean(scores[wf]))\n",
378
+ " tlog(f'[EVAL_WORKFLOW] phase={phase} workflow={wf} '\n",
379
+ " f'mean={m:.4f} min={min(scores[wf]):.4f} max={max(scores[wf]):.4f}')\n",
380
+ " overall = float(np.mean([s for v in scores.values() for s in v]))\n",
381
+ " tlog(f'[EVAL_END] phase={phase} overall_mean={overall:.4f}')\n",
382
+ " return scores"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "markdown",
387
+ "id": "cell-17",
388
+ "metadata": {},
389
+ "source": [
390
+ "## 7. Baseline evaluation — *before* training\n",
391
+ "\n",
392
+ "This is the untrained Qwen2.5-3B reference point. We will compare against this after GRPO training."
393
+ ]
394
+ },
395
+ {
396
+ "cell_type": "code",
397
+ "id": "cell-18",
398
+ "metadata": {},
399
+ "outputs": [],
400
+ "source": [
401
+ "FastLanguageModel.for_inference(model)\n",
402
+ "baseline_scores = evaluate(phase='baseline')\n",
403
+ "baseline_overall = float(np.mean([s for v in baseline_scores.values() for s in v]))"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "markdown",
408
+ "id": "cell-19",
409
+ "metadata": {},
410
+ "source": [
411
+ "## 8. Build the prompt dataset for GRPO\n",
412
  "\n",
413
+ "We collect 60 fresh observations (20 per workflow) by resetting the env. GRPO will sample from this dataset, generate G=2 candidate actions per prompt, score each via the live env, and update the policy from the relative advantages."
414
+ ]
415
+ },
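
TRL computes the group-relative advantages internally; conceptually, for one prompt's G rewards it is just within-group standardization (a sketch, not the trainer's exact code):

```python
import numpy as np

def group_relative_advantages(rewards):
    # GRPO advantage: each completion's reward standardized within its group.
    r = np.asarray(rewards, dtype=np.float64)
    return (r - r.mean()) / (r.std() + 1e-8)  # epsilon guards zero-variance groups

# With G=2, the better completion gets ~+1 and the worse ~-1 whenever rewards differ.
print(group_relative_advantages([0.62, 0.31]))  # ≈ [ 1., -1.]
```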
416
+ {
417
+ "cell_type": "code",
418
+ "id": "cell-20",
419
+ "metadata": {},
420
+ "outputs": [],
421
+ "source": [
422
+ "rows = []\n",
423
+ "for wf in WORKFLOWS:\n",
424
  " for _ in range(N_PROMPTS_PER_WORKFLOW):\n",
425
+ " obs = httpx.post(f'{ENV_URL}/reset', json={'workflow_id': wf}).json()['observation']\n",
426
+ " rows.append({\n",
427
+ " 'prompt': build_prompt(obs_to_text(obs)),\n",
428
+ " 'workflow_id': wf,\n",
 
 
 
429
  " })\n",
430
+ "prompt_dataset = Dataset.from_list(rows)\n",
431
+ "tlog(f'[TRAIN_CONFIG] algorithm=GRPO prompts={len(prompt_dataset)} '\n",
432
+ " f'workflows={\",\".join(WORKFLOWS)} prompts_per_workflow={N_PROMPTS_PER_WORKFLOW}')\n",
433
  "\n",
434
+ "tok_len = len(tokenizer(prompt_dataset[0]['prompt']).input_ids)\n",
435
+ "tlog(f'[PROMPT_DEBUG] first_prompt_tokens={tok_len}')"
 
 
436
  ]
437
  },
438
  {
439
  "cell_type": "markdown",
440
+ "id": "cell-21",
441
  "metadata": {},
442
  "source": [
443
+ "## 9. Reward function — multi-step live environment rollout\n",
444
+ "\n",
445
+ "For each candidate completion we:\n",
446
+ "1. Parse it as a JSON action.\n",
447
+ "2. Reset the env and apply the action (step 1).\n",
448
+ "3. Continue `REWARD_STEPS-1` more steps with the current model (greedy), accumulating env transitions.\n",
449
+ "4. Return the **cumulative episode score** — not just single-step reward.\n",
450
+ "\n",
451
+ "This multi-step signal prevents the model from collapsing to always outputting `list_tickets` (which gives a small single-step reward but doesn't advance the workflow). Invalid JSON gets a −0.1 penalty."
452
  ]
453
  },
454
  {
455
  "cell_type": "code",
456
+ "id": "cell-22",
 
457
  "metadata": {},
458
  "outputs": [],
459
  "source": [
460
+ "def orgos_reward_fn(completions: List[str], prompts: List[str] = None, **kwargs) -> List[float]:\n",
461
+ " workflow_ids = kwargs.get('workflow_id', ['A'] * len(completions))\n",
 
 
 
 
462
  " rewards = []\n",
 
463
  " for completion, wf_id in zip(completions, workflow_ids):\n",
464
  " action = parse_action(completion)\n",
465
  " if action is None:\n",
466
  " rewards.append(-0.1)\n",
467
  " continue\n",
468
  " try:\n",
469
+ " # Reset env and apply GRPO-generated action (step 1)\n",
470
+ " obs = httpx.post(f'{ENV_URL}/reset', json={'workflow_id': wf_id}, timeout=10).json()['observation']\n",
471
+ " result = httpx.post(f'{ENV_URL}/step', json=action, timeout=10).json()\n",
472
+ " obs = result['observation']\n",
473
+ "\n",
474
+ " # Continue REWARD_STEPS-1 more steps with the current model (greedy)\n",
475
+ " for _ in range(REWARD_STEPS - 1):\n",
476
+ " if obs.get('done'):\n",
477
+ " break\n",
478
+ " prompt = build_prompt(obs_to_text(obs))\n",
479
+ " inputs = tokenizer(prompt, return_tensors='pt').to(model.device)\n",
480
+ " with torch.no_grad():\n",
481
+ " out = model.generate(\n",
482
+ " **inputs,\n",
483
+ " max_new_tokens = 128,\n",
484
+ " do_sample = False,\n",
485
+ " pad_token_id = tokenizer.eos_token_id,\n",
486
+ " )\n",
487
+ " cont_str = tokenizer.decode(\n",
488
+ " out[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True\n",
489
+ " ).strip()\n",
490
+ " cont_action = parse_action(cont_str)\n",
491
+ " if cont_action is None:\n",
492
+ " break\n",
493
+ " result = httpx.post(f'{ENV_URL}/step', json=cont_action, timeout=10).json()\n",
494
+ " obs = result['observation']\n",
495
+ "\n",
496
+ " # Return cumulative score — rewards multi-step progress, not just single actions\n",
497
+ " rewards.append(float(obs.get('current_score', 0.001)))\n",
498
+ " except Exception as e:\n",
499
  " rewards.append(-0.1)\n",
 
500
  " return rewards\n",
501
  "\n",
 
502
  "# Sanity check\n",
503
+ "_v = orgos_reward_fn(['{\\'app\\':\\'zendesk\\',\\'operation\\':\\'list_tickets\\',\\'args\\':{}}'], workflow_id=['A'])\n",
504
+ "_i = orgos_reward_fn(['not json'], workflow_id=['A'])\n",
505
+ "tlog(f'[REWARD_FN_CHECK] valid_action={_v[0]:.4f} invalid_action={_i[0]:.4f}')"
 
 
 
 
506
  ]
507
  },
508
  {
509
  "cell_type": "markdown",
510
+ "id": "cell-23",
511
  "metadata": {},
512
+ "source": ["## 10. GRPO training\n",
513
+ "\n",
514
+ "We log every training step's reward into `[TRAIN_STEP]` lines so we can plot a meaningful learning curve.\n",
515
+ "A Drive checkpoint callback saves the adapter every 30 steps so a Colab disconnect doesn't lose progress."]
516
  },
517
  {
518
  "cell_type": "code",
519
+ "id": "cell-24",
 
520
  "metadata": {},
521
  "outputs": [],
522
  "source": [
523
+ "training_step_rewards = [] # captured by callback for the plot\n",
524
  "\n",
525
+ "class OrgOSLogCallback(TrainerCallback):\n",
526
+ " def on_log(self, args, state, control, logs=None, **kwargs):\n",
527
+ " if not logs:\n",
528
+ " return\n",
529
+ " step = state.global_step\n",
530
+ " reward = logs.get('reward') or logs.get('rewards/orgos_reward_fn') or logs.get('reward/mean')\n",
531
+ " loss = logs.get('loss')\n",
532
+ " kl = logs.get('kl')\n",
533
+ " if reward is not None:\n",
534
+ " training_step_rewards.append((step, float(reward)))\n",
535
+ " tlog(f'[TRAIN_STEP] step={step} reward={float(reward):.4f} '\n",
536
+ " f\"loss={('%.4f'%loss) if loss is not None else 'NA'} \"\n",
537
+ " f\"kl={('%.4f'%kl) if kl is not None else 'NA'}\")\n",
538
+ "\n",
539
+ " def on_step_end(self, args, state, control, **kwargs):\n",
540
+ " \"\"\"Save checkpoint to Drive every CKPT_EVERY_STEPS steps.\"\"\"\n",
541
+ " if state.global_step % CKPT_EVERY_STEPS == 0 and state.global_step > 0:\n",
542
+ " try:\n",
543
+ " from google.colab import drive\n",
544
+ " drive.mount('/content/drive', force_remount=False)\n",
545
+ " ckpt_dir = Path(f'/content/drive/MyDrive/orgos-training/ckpt_step{state.global_step}')\n",
546
+ " ckpt_dir.mkdir(parents=True, exist_ok=True)\n",
547
+ " model.save_pretrained(str(ckpt_dir))\n",
548
+ " import shutil\n",
549
+ " shutil.copy(LOG_PATH, ckpt_dir / 'training_log.txt')\n",
550
+ " tlog(f'[CHECKPOINT] step={state.global_step} saved to {ckpt_dir}')\n",
551
+ " except Exception:\n",
552
+ " pass # not on Colab or Drive not mounted — skip silently"
553
  ]
554
  },
555
  {
556
  "cell_type": "code",
557
+ "id": "cell-25",
 
558
  "metadata": {},
559
  "outputs": [],
560
  "source": [
561
+ "FastLanguageModel.for_training(model)\n",
562
  "\n",
563
+ "# GRPOConfig — using TRL <=0.24 (pinned in cell 2) so max_new_tokens is accepted.\n",
564
+ "# Unsloth patches this config; max_prompt_length / max_completion_length are NOT supported.\n",
565
  "grpo_config = GRPOConfig(\n",
566
+ " output_dir = '/content/grpo_ckpt',\n",
567
+ " num_train_epochs = 1,\n",
568
+ " max_steps = MAX_TRAIN_STEPS,\n",
569
+ " per_device_train_batch_size = PER_DEVICE_BATCH,\n",
570
  " gradient_accumulation_steps = GRAD_ACCUM,\n",
571
+ " learning_rate = LEARNING_RATE,\n",
572
+ " num_generations = NUM_GENERATIONS,\n",
573
+ " max_new_tokens = MAX_COMPLETION_LENGTH,\n",
574
+ " temperature = 0.9,\n",
575
+ " logging_steps = 1,\n",
576
+ " save_strategy = 'no',\n",
577
+ " report_to = 'none',\n",
578
+ " bf16 = False,\n",
579
+ " fp16 = True,\n",
580
+ " optim = 'adamw_8bit',\n",
 
 
 
581
  ")\n",
582
  "\n",
583
  "trainer = GRPOTrainer(\n",
584
  " model = model,\n",
 
 
 
585
  " processing_class = tokenizer,\n",
586
+ " reward_funcs = [orgos_reward_fn],\n",
587
+ " train_dataset = prompt_dataset,\n",
588
+ " args = grpo_config,\n",
589
  " callbacks = [OrgOSLogCallback()],\n",
590
  ")\n",
591
  "\n",
592
+ "tlog(f'[TRAIN_START] max_steps={MAX_TRAIN_STEPS} G={NUM_GENERATIONS} lr={LEARNING_RATE} reward_steps={REWARD_STEPS}')\n",
593
+ "trainer.train()\n",
594
+ "tlog(f'[TRAIN_END] steps_completed={trainer.state.global_step}')"
 
 
595
  ]
596
  },
597
  {
598
  "cell_type": "markdown",
599
+ "id": "cell-26",
600
  "metadata": {},
601
  "source": [
602
+ "## 11. Post-training evaluation\n",
603
+ "\n",
604
+ "Same protocol as the baseline (3 workflows × 5 episodes), so the comparison is apples-to-apples."
605
  ]
606
  },
607
  {
608
  "cell_type": "code",
609
+ "id": "cell-27",
 
610
  "metadata": {},
611
  "outputs": [],
612
  "source": [
613
  "FastLanguageModel.for_inference(model)\n",
614
+ "trained_scores = evaluate(phase='trained')\n",
615
+ "trained_overall = float(np.mean([s for v in trained_scores.values() for s in v]))\n",
616
  "\n",
617
+ "tlog('[TRAIN_SUMMARY] '\n",
618
+ " f'baseline_overall={baseline_overall:.4f} trained_overall={trained_overall:.4f} '\n",
619
+ " f'delta={trained_overall - baseline_overall:+.4f}')"
620
  ]
621
  },
622
  {
623
  "cell_type": "markdown",
624
+ "id": "cell-28",
625
  "metadata": {},
626
  "source": [
627
+ "## 12. Plots\n",
628
+ "\n",
629
+ "All plots are saved to `training/plots/` and committed to the repo so reviewers can see them in the README."
630
  ]
631
  },
632
  {
633
  "cell_type": "code",
634
+ "id": "cell-29",
 
635
  "metadata": {},
636
  "outputs": [],
637
  "source": [
638
+ "# 12a. Training curve — mean reward vs GRPO step\n",
639
+ "if training_step_rewards:\n",
640
+ " steps, rewards = zip(*training_step_rewards)\n",
641
+ " plt.figure(figsize=(8,5))\n",
642
+ " plt.plot(steps, rewards, marker='o', markersize=3, linewidth=1.5, color='tab:blue', label='per-step reward')\n",
643
+ " if len(rewards) >= 5:\n",
644
+ " win = max(3, len(rewards)//10)\n",
645
+ " roll = np.convolve(rewards, np.ones(win)/win, mode='valid')\n",
646
+ " plt.plot(steps[win-1:], roll, color='tab:orange', linewidth=2.5, label=f'rolling mean (w={win})')\n",
647
+ " plt.xlabel('GRPO training step')\n",
648
+ " plt.ylabel('mean reward (per batch)')\n",
649
+ " plt.title('OrgOS GRPO training curve — Qwen2.5-3B-Instruct')\n",
650
+ " plt.legend()\n",
651
+ " plt.grid(alpha=0.3)\n",
652
+ " plt.tight_layout()\n",
653
+ " plt.savefig(PLOTS_DIR / 'training_curve.png', dpi=150)\n",
654
+ " plt.show()\n",
655
+ " tlog('[ARTIFACT] training_curve.png saved')"
656
+ ]
657
+ },
658
+ {
659
+ "cell_type": "code",
660
+ "id": "cell-30",
661
+ "metadata": {},
662
+ "outputs": [],
663
+ "source": [
664
+ "# 12b. Baseline vs trained grouped bar per workflow\n",
665
+ "x = np.arange(len(WORKFLOWS))\n",
666
+ "width = 0.35\n",
667
+ "baseline_means = [np.mean(baseline_scores[wf]) for wf in WORKFLOWS]\n",
668
+ "trained_means = [np.mean(trained_scores[wf]) for wf in WORKFLOWS]\n",
669
+ "\n",
670
+ "fig, ax = plt.subplots(figsize=(8,5))\n",
671
+ "ax.bar(x - width/2, baseline_means, width, label='baseline (untrained)', color='tab:gray')\n",
672
+ "ax.bar(x + width/2, trained_means, width, label='GRPO-trained', color='tab:green')\n",
673
+ "ax.set_xticks(x)\n",
674
+ "ax.set_xticklabels([f'Workflow {wf}' for wf in WORKFLOWS])\n",
675
+ "ax.set_ylabel('mean episode score (0–1)')\n",
676
+ "ax.set_ylim(0, 1)\n",
677
+ "ax.set_title(f'Baseline vs GRPO-trained overall {baseline_overall:.3f} {trained_overall:.3f}')\n",
678
+ "ax.legend()\n",
679
+ "ax.grid(axis='y', alpha=0.3)\n",
680
+ "for i, (b, t) in enumerate(zip(baseline_means, trained_means)):\n",
681
+ " ax.text(i - width/2, b + 0.01, f'{b:.2f}', ha='center', fontsize=9)\n",
682
+ " ax.text(i + width/2, t + 0.01, f'{t:.2f}', ha='center', fontsize=9)\n",
683
+ "plt.tight_layout()\n",
684
+ "plt.savefig(PLOTS_DIR / 'baseline_vs_trained.png', dpi=150)\n",
685
  "plt.show()\n",
686
+ "tlog('[ARTIFACT] baseline_vs_trained.png saved')"
687
+ ]
688
+ },
689
+ {
690
+ "cell_type": "code",
691
+ "id": "cell-31",
692
+ "metadata": {},
693
+ "outputs": [],
694
+ "source": [
695
+ "# 12c. Per-episode score distribution (box plot)\n",
696
+ "fig, ax = plt.subplots(figsize=(9,5))\n",
697
+ "data, labels, colors = [], [], []\n",
698
+ "for wf in WORKFLOWS:\n",
699
+ " data += [baseline_scores[wf], trained_scores[wf]]\n",
700
+ " labels += [f'{wf} (base)', f'{wf} (trained)']\n",
701
+ " colors += ['lightgray', 'lightgreen']\n",
702
+ "bp = ax.boxplot(data, labels=labels, patch_artist=True)\n",
703
+ "for patch, c in zip(bp['boxes'], colors):\n",
704
+ " patch.set_facecolor(c)\n",
705
+ "ax.set_ylabel('episode score (0–1)')\n",
706
+ "ax.set_title('Per-episode score distribution — baseline vs GRPO-trained')\n",
707
+ "ax.grid(axis='y', alpha=0.3)\n",
708
+ "plt.tight_layout()\n",
709
+ "plt.savefig(PLOTS_DIR / 'score_distribution.png', dpi=150)\n",
710
+ "plt.show()\n",
711
+ "tlog('[ARTIFACT] score_distribution.png saved')"
712
  ]
713
  },
714
  {
715
  "cell_type": "markdown",
716
+ "id": "cell-32",
717
  "metadata": {},
718
  "source": [
719
+ "## 13. Save artifacts\n",
720
+ "\n",
721
+ "Saves the LoRA adapter and copies all artifacts to Google Drive so they survive a Colab disconnect."
722
  ]
723
  },
724
  {
725
  "cell_type": "code",
726
+ "id": "cell-33",
 
727
  "metadata": {},
728
  "outputs": [],
729
  "source": [
730
+ "model.save_pretrained(str(ADAPTER_DIR))\n",
731
+ "tokenizer.save_pretrained(str(ADAPTER_DIR))\n",
732
+ "tlog(f'[ARTIFACT] lora_adapter saved to {ADAPTER_DIR}')\n",
733
+ "\n",
734
+ "try:\n",
735
+ " from google.colab import drive\n",
736
+ " drive.mount('/content/drive', force_remount=False)\n",
737
+ " DRIVE_DIR = Path('/content/drive/MyDrive/orgos-training')\n",
738
+ " DRIVE_DIR.mkdir(parents=True, exist_ok=True)\n",
739
+ " !cp {LOG_PATH} {DRIVE_DIR}/\n",
740
+ " !cp -r {PLOTS_DIR} {DRIVE_DIR}/\n",
741
+ " !cp -r {ADAPTER_DIR} {DRIVE_DIR}/\n",
742
+ " print(f'Artifacts copied to {DRIVE_DIR}')\n",
743
+ "except ImportError:\n",
744
+ " print('Not on Colab skipping Drive copy')"
745
+ ]
746
+ },
747
+ {
748
+ "cell_type": "code",
749
+ "id": "cell-34",
750
+ "metadata": {},
751
+ "outputs": [],
752
+ "source": [
753
+ "# Stop the env server cleanly\n",
754
+ "ENV_PROC.terminate()\n",
755
+ "tlog('[RUN_END]')\n",
756
+ "print('\\nDone. Commit these to the repo:')\n",
757
+ "print(f' - {LOG_PATH}')\n",
758
+ "print(f' - {PLOTS_DIR}/*.png')\n",
759
+ "print(f' - {ADAPTER_DIR}/')"
760
  ]
761
  }
762
+ ]
763
  }