muskan singh committed on
Commit e22f664 · 1 Parent(s): 9e29238

training notebook with training logs

Files changed (1)
  1. training/grpo_orgos.ipynb +199 -140
training/grpo_orgos.ipynb CHANGED
@@ -21,7 +21,7 @@
21
  "4. GRPO computes relative advantages within the group (which action did better than average?)\n",
22
  "5. Model is updated to favour higher-reward actions\n",
23
  "\n",
24
- "**Key training signal:** Schema drift creates a sharp reward gap.\n",
25
  "Using a stale field name (e.g. `priority` when schema says `severity`) → **−0.20**. \n",
26
  "Using the correct drifted name → **+0.10** adaptation bonus. \n",
27
  "The model learns to read `schema_hints` before constructing action args."
@@ -77,10 +77,52 @@
77
  },
78
  {
79
  "cell_type": "markdown",
80
- "id": "sec3",
81
  "metadata": {},
82
  "source": [
83
- "## 3. Start the OrgOS Environment Server"
84
  ]
85
  },
86
  {
@@ -101,15 +143,15 @@
101
  "\n",
102
  "health = httpx.get(\"http://localhost:8000/health\").json()\n",
103
  "assert health[\"status\"] == \"healthy\", f\"Server not healthy: {health}\"\n",
104
- "print(\"OrgOS server running:\", health)"
105
  ]
106
  },
107
  {
108
  "cell_type": "markdown",
109
- "id": "sec4",
110
  "metadata": {},
111
  "source": [
112
- "## 4. Load Model with Unsloth 4-bit LoRA"
113
  ]
114
  },
115
  {
@@ -124,6 +166,7 @@
124
  "\n",
125
  "MAX_SEQ_LEN = 2048\n",
126
  "MODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\"\n",
127
  "\n",
128
  "model, tokenizer = FastLanguageModel.from_pretrained(\n",
129
  " model_name = MODEL_NAME,\n",
@@ -134,31 +177,27 @@
134
  "\n",
135
  "model = FastLanguageModel.get_peft_model(\n",
136
  " model,\n",
137
- " r = 16,\n",
138
  " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
139
  " \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
140
- " lora_alpha = 16,\n",
141
  " lora_dropout = 0,\n",
142
  " bias = \"none\",\n",
143
  " use_gradient_checkpointing = \"unsloth\",\n",
144
  " random_state = 42,\n",
145
  ")\n",
146
  "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
147
- "print(f\"Model loaded trainable params: {trainable:,}\")"
148
  ]
149
  },
150
  {
151
  "cell_type": "markdown",
152
- "id": "sec5",
153
  "metadata": {},
154
  "source": [
155
- "## 5. Prompt Dataset\n",
156
- "\n",
157
- "We collect **first-turn observations** from fresh episode resets as our prompt dataset.\n",
158
- "These are the most important turns — they contain `schema_hints`, `active_rules`, and the\n",
159
- "full workflow goal. The model must learn to read schema hints and produce a correct first action.\n",
160
- "\n",
161
- "During GRPO training, the reward function will reset the env and evaluate each generated action live."
162
  ]
163
  },
164
  {
@@ -168,7 +207,9 @@
168
  "metadata": {},
169
  "outputs": [],
170
  "source": [
171
- "import json\n",
172
  "from datasets import Dataset\n",
173
  "\n",
174
  "SYSTEM_PROMPT = \"\"\"\\\n",
@@ -209,6 +250,8 @@
209
  "6. Stop when pending_steps is empty or done=true.\n",
210
  "\"\"\"\n",
211
  "\n",
212
  "\n",
213
  "def obs_to_text(obs: dict) -> str:\n",
214
  " hints = obs.get(\"schema_hints\", {})\n",
@@ -243,23 +286,34 @@
243
  "\n",
244
  "\n",
245
  "def build_prompt(obs_text: str) -> str:\n",
246
- " \"\"\"Format as a chat prompt with system injected into first user message.\"\"\"\n",
247
  " messages = [{\"role\": \"user\", \"content\": SYSTEM_PROMPT + \"\\n\\n---\\n\\n\" + obs_text}]\n",
248
  " return tokenizer.apply_chat_template(\n",
249
  " messages, tokenize=False, add_generation_prompt=True\n",
250
  " )\n",
251
  "\n",
252
  "\n",
253
- "# Collect first-turn observations across all 3 workflows, multiple episodes\n",
254
- "# Each episode has a different schema version (seed varies) so we get diverse prompts\n",
255
  "N_PROMPTS_PER_WORKFLOW = 20\n",
256
  "prompt_rows = []\n",
257
  "\n",
258
  "print(\"Collecting prompts from env resets...\")\n",
259
  "for wf in [\"A\", \"B\", \"C\"]:\n",
260
  " for _ in range(N_PROMPTS_PER_WORKFLOW):\n",
261
- " result = httpx.post(\"http://localhost:8000/reset\", json={\"workflow_id\": wf}).json()\n",
262
- " obs = result[\"observation\"]\n",
263
  " obs_text = obs_to_text(obs)\n",
264
  " prompt_rows.append({\n",
265
  " \"prompt\": build_prompt(obs_text),\n",
@@ -268,25 +322,17 @@
268
  " })\n",
269
  "\n",
270
  "prompt_dataset = Dataset.from_list(prompt_rows)\n",
271
- "print(f\"Prompt dataset: {len(prompt_dataset)} examples across 3 workflows\")\n",
272
- "print(\"Sample prompt (truncated):\\n\", prompt_rows[0][\"prompt\"][:600], \"...\")"
273
  ]
274
  },
275
  {
276
  "cell_type": "markdown",
277
- "id": "sec6",
278
  "metadata": {},
279
  "source": [
280
- "## 6. Reward Function\n",
281
- "\n",
282
- "Called by GRPOTrainer during training on each batch of generated completions.\n",
283
- "For each completion:\n",
284
- "1. Parse it as action JSON\n",
285
- "2. Reset the env to a fresh episode for the right workflow\n",
286
- "3. Send the action via `/step`\n",
287
- "4. Return the reward\n",
288
- "\n",
289
- "This gives the model a live signal from the actual environment."
290
  ]
291
  },
292
  {
@@ -296,53 +342,20 @@
296
  "metadata": {},
297
  "outputs": [],
298
  "source": [
299
- "import re\n",
300
- "from typing import List\n",
301
- "\n",
302
- "ENV_URL = \"http://localhost:8000\"\n",
303
- "\n",
304
- "\n",
305
- "def parse_action(text: str):\n",
306
- " \"\"\"Extract JSON action from model output.\"\"\"\n",
307
- " text = text.strip()\n",
308
- " # Strip markdown code fences if present\n",
309
- " text = re.sub(r\"```(?:json)?\\s*\", \"\", text).strip()\n",
310
- " try:\n",
311
- " return json.loads(text)\n",
312
- " except json.JSONDecodeError:\n",
313
- " m = re.search(r\"\\{.*\\}\", text, re.DOTALL)\n",
314
- " if m:\n",
315
- " try:\n",
316
- " return json.loads(m.group())\n",
317
- " except Exception:\n",
318
- " pass\n",
319
- " return None\n",
320
- "\n",
321
- "\n",
322
  "def orgos_reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:\n",
323
  " \"\"\"\n",
324
  " GRPO reward function — called by GRPOTrainer each training step.\n",
325
- "\n",
326
- " For each generated completion:\n",
327
- " - Parse as action JSON\n",
328
- " - Reset env to a fresh episode (workflow inferred from prompt)\n",
329
- " - Step the env with the action\n",
330
- " - Return the step reward\n",
331
- "\n",
332
- " Invalid JSON or failed actions return a -0.1 penalty.\n",
333
  " \"\"\"\n",
334
  " workflow_ids = kwargs.get(\"workflow_id\", [\"A\"] * len(completions))\n",
335
  " rewards = []\n",
336
  "\n",
337
  " for completion, wf_id in zip(completions, workflow_ids):\n",
338
  " action = parse_action(completion)\n",
339
- "\n",
340
  " if action is None:\n",
341
  " rewards.append(-0.1)\n",
342
  " continue\n",
343
- "\n",
344
  " try:\n",
345
- " # Fresh episode for this action evaluation\n",
346
  " httpx.post(f\"{ENV_URL}/reset\", json={\"workflow_id\": wf_id}, timeout=10)\n",
347
  " result = httpx.post(f\"{ENV_URL}/step\", json=action, timeout=10).json()\n",
348
  " rewards.append(float(result[\"reward\"]))\n",
@@ -352,24 +365,22 @@
352
  " return rewards\n",
353
  "\n",
354
  "\n",
355
- "print(\"Reward function defined.\")\n",
356
- "print(\"Quick sanity check...\")\n",
357
- "test_rewards = orgos_reward_fn(\n",
358
- " completions = ['{\"app\": \"zendesk\", \"operation\": \"list_tickets\", \"args\": {\"state\": \"new\"}}',\n",
359
- " 'this is not valid json'],\n",
360
- " prompts = [\"\", \"\"],\n",
361
- " workflow_id = [\"A\", \"A\"],\n",
362
  ")\n",
363
- "print(f\" Valid action reward: {test_rewards[0]:.4f}\")\n",
364
- "print(f\" Invalid action reward: {test_rewards[1]:.4f}\")"
365
  ]
366
  },
367
  {
368
  "cell_type": "markdown",
369
- "id": "sec7",
370
  "metadata": {},
371
  "source": [
372
- "## 7. Collect Baseline Scores (Pre-Training)"
373
  ]
374
  },
375
  {
@@ -379,8 +390,6 @@
379
  "metadata": {},
380
  "outputs": [],
381
  "source": [
382
- "import numpy as np\n",
383
- "\n",
384
  "FastLanguageModel.for_inference(model)\n",
385
  "\n",
386
  "\n",
@@ -397,9 +406,9 @@
397
  " obs_text = obs_to_text(obs)\n",
398
  " history.append({\"role\": \"user\", \"content\": obs_text})\n",
399
  "\n",
400
- " # Inject system prompt into first user message\n",
401
- " messages = list(history)\n",
402
- " messages[0] = {\"role\": \"user\", \"content\": SYSTEM_PROMPT + \"\\n\\n---\\n\\n\" + messages[0][\"content\"]}\n",
403
  "\n",
404
  " text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
405
  " inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
@@ -430,27 +439,29 @@
430
  " return obs.get(\"current_score\", 0.001)\n",
431
  "\n",
432
  "\n",
433
- "N_EVAL = 10 # episodes per workflow for evaluation\n",
434
  "baseline_scores = {wf: [] for wf in [\"A\", \"B\", \"C\"]}\n",
435
  "\n",
436
- "print(\"Collecting pre-training baseline scores...\")\n",
437
  "for wf in [\"A\", \"B\", \"C\"]:\n",
438
  " for ep in range(N_EVAL):\n",
439
  " score = run_episode_with_model(wf)\n",
440
  " baseline_scores[wf].append(score)\n",
441
- " print(f\" Workflow {wf} ep {ep+1}/{N_EVAL}: score={score:.4f}\", end=\"\\r\")\n",
442
- " print(f\" Workflow {wf}: mean={np.mean(baseline_scores[wf]):.4f}\")\n",
443
  "\n",
444
  "baseline_mean = np.mean([s for v in baseline_scores.values() for s in v])\n",
445
- "print(f\"\\nOverall baseline mean: {baseline_mean:.4f}\")"
446
  ]
447
  },
448
  {
449
  "cell_type": "markdown",
450
- "id": "sec8",
451
  "metadata": {},
452
  "source": [
453
- "## 8. GRPO Training"
454
  ]
455
  },
456
  {
@@ -461,57 +472,86 @@
461
  "outputs": [],
462
  "source": [
463
  "from trl import GRPOConfig, GRPOTrainer\n",
464
  "\n",
465
- "# Switch back to training mode\n",
466
  "model.train()\n",
467
  "\n",
468
  "grpo_config = GRPOConfig(\n",
469
  " output_dir = \"./orgos_grpo_ckpt\",\n",
470
- " num_train_epochs = 3,\n",
471
- " per_device_train_batch_size = 4,\n",
472
- " gradient_accumulation_steps = 2,\n",
473
- " learning_rate = 5e-5,\n",
474
  " warmup_steps = 10,\n",
475
  " logging_steps = 5,\n",
476
  " save_steps = 100,\n",
477
  " bf16 = torch.cuda.is_bf16_supported(),\n",
478
  " fp16 = not torch.cuda.is_bf16_supported(),\n",
479
  " max_grad_norm = 1.0,\n",
480
- " # GRPO-specific\n",
481
- " num_generations = 4, # G: candidate actions per prompt\n",
482
  " max_new_tokens = 256,\n",
483
- " temperature = 0.8, # exploration during training\n",
484
- " beta = 0.04, # KL penalty coefficient\n",
485
  " report_to = \"none\",\n",
486
  " seed = 42,\n",
487
  ")\n",
488
  "\n",
489
  "trainer = GRPOTrainer(\n",
490
- " model = model,\n",
491
- " args = grpo_config,\n",
492
- " reward_funcs = orgos_reward_fn,\n",
493
- " train_dataset = prompt_dataset,\n",
494
  " processing_class = tokenizer,\n",
495
  ")\n",
496
  "\n",
497
- "print(\"Starting GRPO training...\")\n",
498
- "print(f\" Prompts: {len(prompt_dataset)}\")\n",
499
- "print(f\" Generations per prompt (G): {grpo_config.num_generations}\")\n",
500
- "print(f\" Epochs: {grpo_config.num_train_epochs}\")\n",
501
- "print(f\" Total env calls per epoch: ~{len(prompt_dataset) * grpo_config.num_generations}\")\n",
502
- "print()\n",
503
- "\n",
504
  "train_result = trainer.train()\n",
505
- "print(\"\\nTraining complete!\")\n",
506
- "print(train_result.metrics)"
507
  ]
508
  },
509
  {
510
  "cell_type": "markdown",
511
- "id": "sec9",
512
  "metadata": {},
513
  "source": [
514
- "## 9. Collect Post-Training Scores"
515
  ]
516
  },
517
  {
@@ -525,25 +565,41 @@
525
  "\n",
526
  "post_scores = {wf: [] for wf in [\"A\", \"B\", \"C\"]}\n",
527
  "\n",
528
- "print(\"Collecting post-training scores...\")\n",
529
  "for wf in [\"A\", \"B\", \"C\"]:\n",
530
  " for ep in range(N_EVAL):\n",
531
  " score = run_episode_with_model(wf)\n",
532
  " post_scores[wf].append(score)\n",
533
- " print(f\" Workflow {wf} ep {ep+1}/{N_EVAL}: score={score:.4f}\", end=\"\\r\")\n",
534
- " print(f\" Workflow {wf}: mean={np.mean(post_scores[wf]):.4f}\")\n",
535
- "\n",
536
- "post_mean = np.mean([s for v in post_scores.values() for s in v])\n",
537
- "print(f\"\\nOverall post-training mean: {post_mean:.4f}\")\n",
538
- "print(f\"Improvement: {post_mean - baseline_mean:+.4f}\")"
539
  ]
540
  },
541
  {
542
  "cell_type": "markdown",
543
- "id": "sec10",
544
  "metadata": {},
545
  "source": [
546
- "## 10. Plot Before / After"
547
  ]
548
  },
549
  {
@@ -561,8 +617,7 @@
561
  " color=\"white\", fontweight=\"bold\", y=0.98)\n",
562
  "\n",
563
  "gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.45, wspace=0.35)\n",
564
- "\n",
565
- "COLORS = {\"before\": \"#f87171\", \"after\": \"#34d399\", \"bg\": \"#1e293b\", \"grid\": \"#334155\"}\n",
566
  "WF_LABELS = {\n",
567
  " \"A\": \"Workflow A\\nCustomer Bug Fix\",\n",
568
  " \"B\": \"Workflow B\\nEmployee Onboarding\",\n",
@@ -570,19 +625,16 @@
570
  "}\n",
571
  "\n",
572
  "for col, wf in enumerate([\"A\", \"B\", \"C\"]):\n",
573
- " ax = fig.add_subplot(gs[0, col])\n",
574
  " ax.set_facecolor(COLORS[\"bg\"])\n",
575
  " ax.grid(color=COLORS[\"grid\"], linewidth=0.5, alpha=0.7)\n",
576
- "\n",
577
  " before = baseline_scores[wf]\n",
578
  " after = post_scores[wf]\n",
579
  " delta = np.mean(after) - np.mean(before)\n",
580
- "\n",
581
  " ax.plot(before, color=COLORS[\"before\"], linewidth=1.5, alpha=0.8, label=\"Before GRPO\")\n",
582
  " ax.plot(after, color=COLORS[\"after\"], linewidth=1.5, alpha=0.8, label=\"After GRPO\")\n",
583
  " ax.axhline(np.mean(before), color=COLORS[\"before\"], linestyle=\"--\", linewidth=1, alpha=0.5)\n",
584
  " ax.axhline(np.mean(after), color=COLORS[\"after\"], linestyle=\"--\", linewidth=1, alpha=0.5)\n",
585
- "\n",
586
  " ax.set_title(WF_LABELS[wf] + f\"\\n(Δ = {delta:+.4f})\", color=\"white\", fontsize=9)\n",
587
  " ax.set_xlabel(\"Episode\", color=\"#94a3b8\", fontsize=8)\n",
588
  " ax.set_ylabel(\"Final Score\", color=\"#94a3b8\", fontsize=8)\n",
@@ -596,18 +648,15 @@
596
  "ax_hist = fig.add_subplot(gs[1, :])\n",
597
  "ax_hist.set_facecolor(COLORS[\"bg\"])\n",
598
  "ax_hist.grid(color=COLORS[\"grid\"], linewidth=0.5, alpha=0.5, axis=\"x\")\n",
599
- "\n",
600
  "all_before = [s for v in baseline_scores.values() for s in v]\n",
601
  "all_after = [s for v in post_scores.values() for s in v]\n",
602
  "bins = np.linspace(0, 1, 25)\n",
603
- "\n",
604
  "ax_hist.hist(all_before, bins=bins, color=COLORS[\"before\"], alpha=0.6,\n",
605
  " label=f\"Before GRPO (mean={np.mean(all_before):.4f})\", edgecolor=\"none\")\n",
606
  "ax_hist.hist(all_after, bins=bins, color=COLORS[\"after\"], alpha=0.6,\n",
607
  " label=f\"After GRPO (mean={np.mean(all_after):.4f})\", edgecolor=\"none\")\n",
608
  "ax_hist.axvline(np.mean(all_before), color=COLORS[\"before\"], linestyle=\"--\", linewidth=1.5)\n",
609
  "ax_hist.axvline(np.mean(all_after), color=COLORS[\"after\"], linestyle=\"--\", linewidth=1.5)\n",
610
- "\n",
611
  "ax_hist.set_title(\"Score Distribution Across All Workflows\", color=\"white\", fontsize=10)\n",
612
  "ax_hist.set_xlabel(\"Final Score\", color=\"#94a3b8\", fontsize=9)\n",
613
  "ax_hist.set_ylabel(\"Count\", color=\"#94a3b8\", fontsize=9)\n",
@@ -620,15 +669,16 @@
620
  "plt.savefig(\"before_after_curves.png\", dpi=150, bbox_inches=\"tight\",\n",
621
  " facecolor=\"#0f172a\", edgecolor=\"none\")\n",
622
  "plt.show()\n",
623
  "print(\"Saved: before_after_curves.png\")"
624
  ]
625
  },
626
  {
627
  "cell_type": "markdown",
628
- "id": "sec11",
629
  "metadata": {},
630
  "source": [
631
- "## 11. Save LoRA Adapter"
632
  ]
633
  },
634
  {
@@ -640,9 +690,18 @@
640
  "source": [
641
  "model.save_pretrained(\"orgos_lora_adapter\")\n",
642
  "tokenizer.save_pretrained(\"orgos_lora_adapter\")\n",
643
- "print(\"LoRA adapter saved to ./orgos_lora_adapter\")\n",
644
- "\n",
645
- "# Push to HuggingFace Hub\n",
646
  "# from huggingface_hub import login\n",
647
  "# login(token=\"YOUR_HF_TOKEN\")\n",
648
  "# model.push_to_hub(\"YOUR_USERNAME/orgos-qwen25-3b-grpo\")\n",
 
21
  "4. GRPO computes relative advantages within the group (which action did better than average?)\n",
22
  "5. Model is updated to favour higher-reward actions\n",
23
  "\n",
24
+ "**Key training signal:** Schema drift creates a sharp reward gap. \n",
25
  "Using a stale field name (e.g. `priority` when schema says `severity`) → **−0.20**. \n",
26
  "Using the correct drifted name → **+0.10** adaptation bonus. \n",
27
  "The model learns to read `schema_hints` before constructing action args."
77
  },
78
  {
79
  "cell_type": "markdown",
80
+ "id": "sec_logger",
81
  "metadata": {},
82
  "source": [
83
+ "## 3. Training Logger\n",
84
+ "\n",
85
+ "Writes structured logs to `training_log.txt` for submission. \n",
86
+ "Format mirrors the OpenEnv inference log spec:\n",
87
+ "- `[TRAIN_CONFIG]` — model, algorithm, hyperparameters\n",
88
+ "- `[EVAL]` — per-episode score during baseline or post-training eval\n",
89
+ "- `[TRAIN_STEP]` — loss, mean reward, KL per training step\n",
90
+ "- `[TRAIN_SUMMARY]` — final before/after comparison"
91
+ ]
92
+ },
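
For concreteness, a hypothetical excerpt of `training_log.txt` in this format (tag and key names match the `tlog` calls below; every value here is an illustrative placeholder, not real run output):

```
[TRAIN_CONFIG] model=Qwen/Qwen2.5-3B-Instruct lora_r=16 max_seq_len=2048 trainable_params=... quantization=4bit
[EVAL] phase=baseline workflow=A episode=1 score=0.3120
[TRAIN_STEP] step=5 loss=0.041200 mean_reward=0.1250 kl=0.002300 lr=4.50e-05
[TRAIN_SUMMARY] model=Qwen/Qwen2.5-3B-Instruct algorithm=GRPO baseline_mean=... post_training_mean=... improvement=...
```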
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "id": "logger",
97
+ "metadata": {},
98
+ "outputs": [],
99
+ "source": [
100
+ "import datetime\n",
101
+ "\n",
102
+ "LOG_FILE = \"training_log.txt\"\n",
103
+ "\n",
104
+ "# Clear any previous log\n",
105
+ "with open(LOG_FILE, \"w\") as f:\n",
106
+ " f.write(f\"# OrgOS GRPO Training Log\\n\")\n",
107
+ " f.write(f\"# Generated: {datetime.datetime.utcnow().isoformat()}Z\\n\\n\")\n",
108
+ "\n",
109
+ "\n",
110
+ "def tlog(line: str) -> None:\n",
111
+ " \"\"\"Append one structured log line to training_log.txt and print it.\"\"\"\n",
112
+ " print(line, flush=True)\n",
113
+ " with open(LOG_FILE, \"a\") as f:\n",
114
+ " f.write(line + \"\\n\")\n",
115
+ "\n",
116
+ "\n",
117
+ "print(f\"Logger ready — writing to {LOG_FILE}\")"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "markdown",
122
+ "id": "sec4",
123
+ "metadata": {},
124
+ "source": [
125
+ "## 4. Start the OrgOS Environment Server"
126
  ]
127
  },
128
  {
143
  "\n",
144
  "health = httpx.get(\"http://localhost:8000/health\").json()\n",
145
  "assert health[\"status\"] == \"healthy\", f\"Server not healthy: {health}\"\n",
146
+ "tlog(f\"[ENV] status=healthy version={health.get('version', '?')}\")"
147
  ]
148
  },
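
For reference, the environment's HTTP surface as the notebook uses it, pulled into a standalone sketch (endpoint paths and field names are taken from the cells in this diff; the `done` field is assumed from the system prompt's stop rule):

```python
import httpx

ENV_URL = "http://localhost:8000"

# Start a fresh episode for one workflow; the reply wraps the first observation.
reset = httpx.post(f"{ENV_URL}/reset", json={"workflow_id": "A"}, timeout=10).json()
obs = reset["observation"]

# Execute one action; the reply carries the scalar reward that GRPO trains on.
action = {"app": "zendesk", "operation": "list_tickets", "args": {"state": "new"}}
result = httpx.post(f"{ENV_URL}/step", json=action, timeout=10).json()
print(result["reward"], result.get("done"))
```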
149
  {
150
  "cell_type": "markdown",
151
+ "id": "sec5",
152
  "metadata": {},
153
  "source": [
154
+ "## 5. Load Model with Unsloth 4-bit LoRA"
155
  ]
156
  },
157
  {
166
  "\n",
167
  "MAX_SEQ_LEN = 2048\n",
168
  "MODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\"\n",
169
+ "LORA_R = 16\n",
170
  "\n",
171
  "model, tokenizer = FastLanguageModel.from_pretrained(\n",
172
  " model_name = MODEL_NAME,\n",
177
  "\n",
178
  "model = FastLanguageModel.get_peft_model(\n",
179
  " model,\n",
180
+ " r = LORA_R,\n",
181
  " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
182
  " \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
183
+ " lora_alpha = LORA_R,\n",
184
  " lora_dropout = 0,\n",
185
  " bias = \"none\",\n",
186
  " use_gradient_checkpointing = \"unsloth\",\n",
187
  " random_state = 42,\n",
188
  ")\n",
189
+ "\n",
190
  "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
191
+ "tlog(f\"[TRAIN_CONFIG] model={MODEL_NAME} lora_r={LORA_R} max_seq_len={MAX_SEQ_LEN} \"\n",
192
+ " f\"trainable_params={trainable:,} quantization=4bit\")"
193
  ]
194
  },
195
  {
196
  "cell_type": "markdown",
197
+ "id": "sec6",
198
  "metadata": {},
199
  "source": [
200
+ "## 6. Prompt Dataset"
201
  ]
202
  },
203
  {
207
  "metadata": {},
208
  "outputs": [],
209
  "source": [
210
+ "import json, re\n",
211
+ "import numpy as np\n",
212
+ "from typing import List\n",
213
  "from datasets import Dataset\n",
214
  "\n",
215
  "SYSTEM_PROMPT = \"\"\"\\\n",
250
  "6. Stop when pending_steps is empty or done=true.\n",
251
  "\"\"\"\n",
252
  "\n",
253
+ "ENV_URL = \"http://localhost:8000\"\n",
254
+ "\n",
255
  "\n",
256
  "def obs_to_text(obs: dict) -> str:\n",
257
  " hints = obs.get(\"schema_hints\", {})\n",
286
  "\n",
287
  "\n",
288
  "def build_prompt(obs_text: str) -> str:\n",
289
  " messages = [{\"role\": \"user\", \"content\": SYSTEM_PROMPT + \"\\n\\n---\\n\\n\" + obs_text}]\n",
290
  " return tokenizer.apply_chat_template(\n",
291
  " messages, tokenize=False, add_generation_prompt=True\n",
292
  " )\n",
293
  "\n",
294
  "\n",
295
+ "def parse_action(text: str):\n",
296
+ " text = re.sub(r\"```(?:json)?\\s*\", \"\", text.strip()).strip()\n",
297
+ " try:\n",
298
+ " return json.loads(text)\n",
299
+ " except json.JSONDecodeError:\n",
300
+ " m = re.search(r\"\\{.*\\}\", text, re.DOTALL)\n",
301
+ " if m:\n",
302
+ " try:\n",
303
+ " return json.loads(m.group())\n",
304
+ " except Exception:\n",
305
+ " pass\n",
306
+ " return None\n",
307
+ "\n",
308
+ "\n",
309
  "N_PROMPTS_PER_WORKFLOW = 20\n",
310
  "prompt_rows = []\n",
311
  "\n",
312
  "print(\"Collecting prompts from env resets...\")\n",
313
  "for wf in [\"A\", \"B\", \"C\"]:\n",
314
  " for _ in range(N_PROMPTS_PER_WORKFLOW):\n",
315
+ " result = httpx.post(f\"{ENV_URL}/reset\", json={\"workflow_id\": wf}).json()\n",
316
+ " obs = result[\"observation\"]\n",
317
  " obs_text = obs_to_text(obs)\n",
318
  " prompt_rows.append({\n",
319
  " \"prompt\": build_prompt(obs_text),\n",
322
  " })\n",
323
  "\n",
324
  "prompt_dataset = Dataset.from_list(prompt_rows)\n",
325
+ "tlog(f\"[TRAIN_CONFIG] algorithm=GRPO prompts={len(prompt_dataset)} \"\n",
326
+ " f\"workflows=A,B,C prompts_per_workflow={N_PROMPTS_PER_WORKFLOW}\")\n",
327
+ "print(f\"Prompt dataset ready: {len(prompt_dataset)} examples\")"
328
  ]
329
  },
330
  {
331
  "cell_type": "markdown",
332
+ "id": "sec7",
333
  "metadata": {},
334
  "source": [
335
+ "## 7. Reward Function"
336
  ]
337
  },
338
  {
342
  "metadata": {},
343
  "outputs": [],
344
  "source": [
345
  "def orgos_reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:\n",
346
  " \"\"\"\n",
347
  " GRPO reward function — called by GRPOTrainer each training step.\n",
348
+ " Parses each completion as an action JSON, steps the live env, returns the reward.\n",
349
  " \"\"\"\n",
350
  " workflow_ids = kwargs.get(\"workflow_id\", [\"A\"] * len(completions))\n",
351
  " rewards = []\n",
352
  "\n",
353
  " for completion, wf_id in zip(completions, workflow_ids):\n",
354
  " action = parse_action(completion)\n",
355
  " if action is None:\n",
356
  " rewards.append(-0.1)\n",
357
  " continue\n",
358
  " try:\n",
359
  " httpx.post(f\"{ENV_URL}/reset\", json={\"workflow_id\": wf_id}, timeout=10)\n",
360
  " result = httpx.post(f\"{ENV_URL}/step\", json=action, timeout=10).json()\n",
361
  " rewards.append(float(result[\"reward\"]))\n",
365
  " return rewards\n",
366
  "\n",
367
  "\n",
368
+ "# Sanity check\n",
369
+ "test_r = orgos_reward_fn(\n",
370
+ " completions = ['{\"app\": \"zendesk\", \"operation\": \"list_tickets\", \"args\": {\"state\": \"new\"}}',\n",
371
+ " 'not json'],\n",
372
+ " prompts = [\"\", \"\"],\n",
373
+ " workflow_id = [\"A\", \"A\"],\n",
374
  ")\n",
375
+ "tlog(f\"[REWARD_FN_CHECK] valid_action={test_r[0]:.4f} invalid_action={test_r[1]:.4f}\")"
376
  ]
377
  },
378
  {
379
  "cell_type": "markdown",
380
+ "id": "sec8",
381
  "metadata": {},
382
  "source": [
383
+ "## 8. Collect Baseline Scores (Pre-Training)"
384
  ]
385
  },
386
  {
390
  "metadata": {},
391
  "outputs": [],
392
  "source": [
393
  "FastLanguageModel.for_inference(model)\n",
394
  "\n",
395
  "\n",
406
  " obs_text = obs_to_text(obs)\n",
407
  " history.append({\"role\": \"user\", \"content\": obs_text})\n",
408
  "\n",
409
+ " messages = list(history)\n",
410
+ " messages[0] = {\"role\": \"user\",\n",
411
+ " \"content\": SYSTEM_PROMPT + \"\\n\\n---\\n\\n\" + messages[0][\"content\"]}\n",
412
  "\n",
413
  " text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
414
  " inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
439
  " return obs.get(\"current_score\", 0.001)\n",
440
  "\n",
441
  "\n",
442
+ "N_EVAL = 10\n",
443
  "baseline_scores = {wf: [] for wf in [\"A\", \"B\", \"C\"]}\n",
444
  "\n",
445
+ "tlog(\"[EVAL_START] phase=baseline\")\n",
446
  "for wf in [\"A\", \"B\", \"C\"]:\n",
447
  " for ep in range(N_EVAL):\n",
448
  " score = run_episode_with_model(wf)\n",
449
  " baseline_scores[wf].append(score)\n",
450
+ " tlog(f\"[EVAL] phase=baseline workflow={wf} episode={ep+1} score={score:.4f}\")\n",
451
+ " wf_mean = np.mean(baseline_scores[wf])\n",
452
+ " tlog(f\"[EVAL_WORKFLOW] phase=baseline workflow={wf} \"\n",
453
+ " f\"mean={wf_mean:.4f} min={min(baseline_scores[wf]):.4f} max={max(baseline_scores[wf]):.4f}\")\n",
454
  "\n",
455
  "baseline_mean = np.mean([s for v in baseline_scores.values() for s in v])\n",
456
+ "tlog(f\"[EVAL_END] phase=baseline overall_mean={baseline_mean:.4f}\")"
457
  ]
458
  },
459
  {
460
  "cell_type": "markdown",
461
+ "id": "sec9",
462
  "metadata": {},
463
  "source": [
464
+ "## 9. GRPO Training"
465
  ]
466
  },
467
  {
472
  "outputs": [],
473
  "source": [
474
  "from trl import GRPOConfig, GRPOTrainer\n",
475
+ "from transformers import TrainerCallback\n",
476
  "\n",
477
  "model.train()\n",
478
  "\n",
479
+ "NUM_EPOCHS = 3\n",
480
+ "BATCH_SIZE = 4\n",
481
+ "GRAD_ACCUM = 2\n",
482
+ "LR = 5e-5\n",
483
+ "NUM_GEN = 4\n",
484
+ "TEMPERATURE = 0.8\n",
485
+ "BETA = 0.04\n",
486
+ "\n",
487
  "grpo_config = GRPOConfig(\n",
488
  " output_dir = \"./orgos_grpo_ckpt\",\n",
489
+ " num_train_epochs = NUM_EPOCHS,\n",
490
+ " per_device_train_batch_size = BATCH_SIZE,\n",
491
+ " gradient_accumulation_steps = GRAD_ACCUM,\n",
492
+ " learning_rate = LR,\n",
493
  " warmup_steps = 10,\n",
494
  " logging_steps = 5,\n",
495
  " save_steps = 100,\n",
496
  " bf16 = torch.cuda.is_bf16_supported(),\n",
497
  " fp16 = not torch.cuda.is_bf16_supported(),\n",
498
  " max_grad_norm = 1.0,\n",
499
+ " num_generations = NUM_GEN,\n",
500
  " max_new_tokens = 256,\n",
501
+ " temperature = TEMPERATURE,\n",
502
+ " beta = BETA,\n",
503
  " report_to = \"none\",\n",
504
  " seed = 42,\n",
505
  ")\n",
506
  "\n",
507
+ "tlog(f\"[TRAIN_CONFIG] epochs={NUM_EPOCHS} batch_size={BATCH_SIZE} \"\n",
508
+ " f\"grad_accum={GRAD_ACCUM} lr={LR} num_generations={NUM_GEN} \"\n",
509
+ " f\"temperature={TEMPERATURE} beta_kl={BETA}\")\n",
510
+ "\n",
511
+ "\n",
512
+ "class OrgOSLogCallback(TrainerCallback):\n",
513
+ " \"\"\"Logs each training step to training_log.txt.\"\"\"\n",
514
+ "\n",
515
+ " def on_log(self, args, state, control, logs=None, **kwargs):\n",
516
+ " if logs is None:\n",
517
+ " return\n",
518
+ " step = state.global_step\n",
519
+ " loss = logs.get(\"loss\", logs.get(\"train_loss\", \"?\"))\n",
520
+ " mean_reward = logs.get(\"reward\", logs.get(\"mean_reward\", \"?\"))\n",
521
+ " kl = logs.get(\"kl\", logs.get(\"approx_kl\", \"?\"))\n",
522
+ " lr_now = logs.get(\"learning_rate\", \"?\")\n",
523
+ "\n",
524
+ " loss_str = f\"{loss:.6f}\" if isinstance(loss, float) else str(loss)\n",
525
+ " reward_str = f\"{mean_reward:.4f}\" if isinstance(mean_reward, float) else str(mean_reward)\n",
526
+ " kl_str = f\"{kl:.6f}\" if isinstance(kl, float) else str(kl)\n",
527
+ " lr_str = f\"{lr_now:.2e}\" if isinstance(lr_now, float) else str(lr_now)\n",
528
+ "\n",
529
+ " tlog(f\"[TRAIN_STEP] step={step} loss={loss_str} \"\n",
530
+ " f\"mean_reward={reward_str} kl={kl_str} lr={lr_str}\")\n",
531
+ "\n",
532
+ "\n",
533
  "trainer = GRPOTrainer(\n",
534
+ " model = model,\n",
535
+ " args = grpo_config,\n",
536
+ " reward_funcs = orgos_reward_fn,\n",
537
+ " train_dataset = prompt_dataset,\n",
538
  " processing_class = tokenizer,\n",
539
+ " callbacks = [OrgOSLogCallback()],\n",
540
  ")\n",
541
  "\n",
542
+ "tlog(\"[TRAIN_START]\")\n",
543
  "train_result = trainer.train()\n",
544
+ "tlog(f\"[TRAIN_END] total_steps={train_result.global_step} \"\n",
545
+ " f\"train_loss={train_result.training_loss:.6f} \"\n",
546
+ " f\"train_runtime_s={train_result.metrics.get('train_runtime', 0):.1f}\")"
547
  ]
548
  },
549
  {
550
  "cell_type": "markdown",
551
+ "id": "sec10",
552
  "metadata": {},
553
  "source": [
554
+ "## 10. Collect Post-Training Scores"
555
  ]
556
  },
557
  {
565
  "\n",
566
  "post_scores = {wf: [] for wf in [\"A\", \"B\", \"C\"]}\n",
567
  "\n",
568
+ "tlog(\"[EVAL_START] phase=post_training\")\n",
569
  "for wf in [\"A\", \"B\", \"C\"]:\n",
570
  " for ep in range(N_EVAL):\n",
571
  " score = run_episode_with_model(wf)\n",
572
  " post_scores[wf].append(score)\n",
573
+ " tlog(f\"[EVAL] phase=post_training workflow={wf} episode={ep+1} score={score:.4f}\")\n",
574
+ " wf_mean = np.mean(post_scores[wf])\n",
575
+ " tlog(f\"[EVAL_WORKFLOW] phase=post_training workflow={wf} \"\n",
576
+ " f\"mean={wf_mean:.4f} min={min(post_scores[wf]):.4f} max={max(post_scores[wf]):.4f}\")\n",
577
+ "\n",
578
+ "post_mean = np.mean([s for v in post_scores.values() for s in v])\n",
579
+ "improvement = post_mean - baseline_mean\n",
580
+ "tlog(f\"[EVAL_END] phase=post_training overall_mean={post_mean:.4f}\")\n",
581
+ "tlog(\n",
582
+ " f\"[TRAIN_SUMMARY] \"\n",
583
+ " f\"model={MODEL_NAME} algorithm=GRPO \"\n",
584
+ " f\"baseline_mean={baseline_mean:.4f} \"\n",
585
+ " f\"post_training_mean={post_mean:.4f} \"\n",
586
+ " f\"improvement={improvement:+.4f} \"\n",
587
+ " f\"workflow_A_before={np.mean(baseline_scores['A']):.4f} \"\n",
588
+ " f\"workflow_A_after={np.mean(post_scores['A']):.4f} \"\n",
589
+ " f\"workflow_B_before={np.mean(baseline_scores['B']):.4f} \"\n",
590
+ " f\"workflow_B_after={np.mean(post_scores['B']):.4f} \"\n",
591
+ " f\"workflow_C_before={np.mean(baseline_scores['C']):.4f} \"\n",
592
+ " f\"workflow_C_after={np.mean(post_scores['C']):.4f}\"\n",
593
+ ")\n",
594
+ "print(f\"\\nImprovement: {baseline_mean:.4f} → {post_mean:.4f} ({improvement:+.4f})\")"
595
  ]
596
  },
597
  {
598
  "cell_type": "markdown",
599
+ "id": "sec11",
600
  "metadata": {},
601
  "source": [
602
+ "## 11. Plot Before / After"
603
  ]
604
  },
605
  {
617
  " color=\"white\", fontweight=\"bold\", y=0.98)\n",
618
  "\n",
619
  "gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.45, wspace=0.35)\n",
620
+ "COLORS = {\"before\": \"#f87171\", \"after\": \"#34d399\", \"bg\": \"#1e293b\", \"grid\": \"#334155\"}\n",
621
  "WF_LABELS = {\n",
622
  " \"A\": \"Workflow A\\nCustomer Bug Fix\",\n",
623
  " \"B\": \"Workflow B\\nEmployee Onboarding\",\n",
625
  "}\n",
626
  "\n",
627
  "for col, wf in enumerate([\"A\", \"B\", \"C\"]):\n",
628
+ " ax = fig.add_subplot(gs[0, col])\n",
629
  " ax.set_facecolor(COLORS[\"bg\"])\n",
630
  " ax.grid(color=COLORS[\"grid\"], linewidth=0.5, alpha=0.7)\n",
631
  " before = baseline_scores[wf]\n",
632
  " after = post_scores[wf]\n",
633
  " delta = np.mean(after) - np.mean(before)\n",
634
  " ax.plot(before, color=COLORS[\"before\"], linewidth=1.5, alpha=0.8, label=\"Before GRPO\")\n",
635
  " ax.plot(after, color=COLORS[\"after\"], linewidth=1.5, alpha=0.8, label=\"After GRPO\")\n",
636
  " ax.axhline(np.mean(before), color=COLORS[\"before\"], linestyle=\"--\", linewidth=1, alpha=0.5)\n",
637
  " ax.axhline(np.mean(after), color=COLORS[\"after\"], linestyle=\"--\", linewidth=1, alpha=0.5)\n",
638
  " ax.set_title(WF_LABELS[wf] + f\"\\n(Δ = {delta:+.4f})\", color=\"white\", fontsize=9)\n",
639
  " ax.set_xlabel(\"Episode\", color=\"#94a3b8\", fontsize=8)\n",
640
  " ax.set_ylabel(\"Final Score\", color=\"#94a3b8\", fontsize=8)\n",
648
  "ax_hist = fig.add_subplot(gs[1, :])\n",
649
  "ax_hist.set_facecolor(COLORS[\"bg\"])\n",
650
  "ax_hist.grid(color=COLORS[\"grid\"], linewidth=0.5, alpha=0.5, axis=\"x\")\n",
651
  "all_before = [s for v in baseline_scores.values() for s in v]\n",
652
  "all_after = [s for v in post_scores.values() for s in v]\n",
653
  "bins = np.linspace(0, 1, 25)\n",
654
  "ax_hist.hist(all_before, bins=bins, color=COLORS[\"before\"], alpha=0.6,\n",
655
  " label=f\"Before GRPO (mean={np.mean(all_before):.4f})\", edgecolor=\"none\")\n",
656
  "ax_hist.hist(all_after, bins=bins, color=COLORS[\"after\"], alpha=0.6,\n",
657
  " label=f\"After GRPO (mean={np.mean(all_after):.4f})\", edgecolor=\"none\")\n",
658
  "ax_hist.axvline(np.mean(all_before), color=COLORS[\"before\"], linestyle=\"--\", linewidth=1.5)\n",
659
  "ax_hist.axvline(np.mean(all_after), color=COLORS[\"after\"], linestyle=\"--\", linewidth=1.5)\n",
660
  "ax_hist.set_title(\"Score Distribution Across All Workflows\", color=\"white\", fontsize=10)\n",
661
  "ax_hist.set_xlabel(\"Final Score\", color=\"#94a3b8\", fontsize=9)\n",
662
  "ax_hist.set_ylabel(\"Count\", color=\"#94a3b8\", fontsize=9)\n",
669
  "plt.savefig(\"before_after_curves.png\", dpi=150, bbox_inches=\"tight\",\n",
670
  " facecolor=\"#0f172a\", edgecolor=\"none\")\n",
671
  "plt.show()\n",
672
+ "tlog(\"[ARTIFACT] file=before_after_curves.png\")\n",
673
  "print(\"Saved: before_after_curves.png\")"
674
  ]
675
  },
676
  {
677
  "cell_type": "markdown",
678
+ "id": "sec12",
679
  "metadata": {},
680
  "source": [
681
+ "## 12. Save LoRA Adapter & Training Log"
682
  ]
683
  },
684
  {
690
  "source": [
691
  "model.save_pretrained(\"orgos_lora_adapter\")\n",
692
  "tokenizer.save_pretrained(\"orgos_lora_adapter\")\n",
693
+ "tlog(\"[ARTIFACT] file=orgos_lora_adapter/\")\n",
694
+ "tlog(\"[ARTIFACT] file=training_log.txt\")\n",
695
+ "\n",
696
+ "print(f\"\\n{'='*60}\")\n",
697
+ "print(\" Submission artefacts\")\n",
698
+ "print(f\"{'='*60}\")\n",
699
+ "print(\" training_log.txt — structured training log\")\n",
700
+ "print(\" before_after_curves.png — score improvement chart\")\n",
701
+ "print(\" orgos_lora_adapter/ — LoRA weights\")\n",
702
+ "print(f\"{'='*60}\")\n",
703
+ "\n",
704
+ "# Optional: push to HuggingFace Hub\n",
705
  "# from huggingface_hub import login\n",
706
  "# login(token=\"YOUR_HF_TOKEN\")\n",
707
  "# model.push_to_hub(\"YOUR_USERNAME/orgos-qwen25-3b-grpo\")\n",