Commit e22f664 · muskan singh committed · Parent(s): 9e29238
training notebook with training logs (+199 −140)
training/grpo_orgos.ipynb CHANGED
@@ -21,7 +21,7 @@
  21     "4. GRPO computes relative advantages within the group (which action did better than average?)\n",
  22     "5. Model is updated to favour higher-reward actions\n",
  23     "\n",
  24 -   "**Key training signal:** Schema drift creates a sharp reward gap.\n",
  25     "Using a stale field name (e.g. `priority` when schema says `severity`) → **−0.20**. \n",
  26     "Using the correct drifted name → **+0.10** adaptation bonus. \n",
  27     "The model learns to read `schema_hints` before constructing action args."
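The "relative advantages within the group" step named above can be sketched in a few lines (a minimal illustration of the idea, not the trainer's internal implementation; the function name and the sample reward values are my own, chosen to mirror the schema-drift reward gap described in this hunk):

```python
def group_relative_advantages(rewards):
    """Center each candidate's reward on the group mean and scale by the
    group standard deviation, so actions that did better than average get
    a positive advantage and worse-than-average actions a negative one."""
    mean = sum(rewards) / len(rewards)
    std = (sum((r - mean) ** 2 for r in rewards) / len(rewards)) ** 0.5
    # Small epsilon avoids division by zero when all candidates tie
    return [(r - mean) / (std + 1e-8) for r in rewards]

# One prompt, G=4 sampled actions: stale field name (-0.20), correct
# drifted name (+0.10), unparseable output (-0.1), neutral action (0.0)
advs = group_relative_advantages([-0.20, 0.10, -0.1, 0.0])
```

Because advantages are centered per group, they sum to zero: the drifted-name action gets the largest positive advantage and the stale-name action the most negative one, which is exactly the sharp gap the training signal relies on.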
@@ -77,10 +77,52 @@
  77     },
  78     {
  79     "cell_type": "markdown",
  80 -   "id": "
  81     "metadata": {},
  82     "source": [
  83 -   "## 3.
  84     ]
  85     },
  86     {
@@ -101,15 +143,15 @@
 101     "\n",
 102     "health = httpx.get(\"http://localhost:8000/health\").json()\n",
 103     "assert health[\"status\"] == \"healthy\", f\"Server not healthy: {health}\"\n",
 104 -   "
 105     ]
 106     },
 107     {
 108     "cell_type": "markdown",
 109 -   "id": "
 110     "metadata": {},
 111     "source": [
 112 -   "##
 113     ]
 114     },
 115     {
@@ -124,6 +166,7 @@
 124     "\n",
 125     "MAX_SEQ_LEN = 2048\n",
 126     "MODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\"\n",
 127     "\n",
 128     "model, tokenizer = FastLanguageModel.from_pretrained(\n",
 129     "    model_name = MODEL_NAME,\n",
@@ -134,31 +177,27 @@
 134     "\n",
 135     "model = FastLanguageModel.get_peft_model(\n",
 136     "    model,\n",
 137 -   "    r =
 138     "    target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
 139     "                      \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
 140 -   "    lora_alpha =
 141     "    lora_dropout = 0,\n",
 142     "    bias = \"none\",\n",
 143     "    use_gradient_checkpointing = \"unsloth\",\n",
 144     "    random_state = 42,\n",
 145     ")\n",
 146     "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
 147 -   "
 148     ]
 149     },
 150     {
 151     "cell_type": "markdown",
 152 -   "id": "
 153     "metadata": {},
 154     "source": [
 155 -   "##
 156 -   "\n",
 157 -   "We collect **first-turn observations** from fresh episode resets as our prompt dataset.\n",
 158 -   "These are the most important turns — they contain `schema_hints`, `active_rules`, and the\n",
 159 -   "full workflow goal. The model must learn to read schema hints and produce a correct first action.\n",
 160 -   "\n",
 161 -   "During GRPO training, the reward function will reset the env and evaluate each generated action live."
 162     ]
 163     },
 164     {
@@ -168,7 +207,9 @@
 168     "metadata": {},
 169     "outputs": [],
 170     "source": [
 171 -   "import json\n",
 172     "from datasets import Dataset\n",
 173     "\n",
 174     "SYSTEM_PROMPT = \"\"\"\\\n",
@@ -209,6 +250,8 @@
 209     "6. Stop when pending_steps is empty or done=true.\n",
 210     "\"\"\"\n",
 211     "\n",
 212     "\n",
 213     "def obs_to_text(obs: dict) -> str:\n",
 214     "    hints = obs.get(\"schema_hints\", {})\n",
@@ -243,23 +286,34 @@
 243     "\n",
 244     "\n",
 245     "def build_prompt(obs_text: str) -> str:\n",
 246 -   "    \"\"\"Format as a chat prompt with system injected into first user message.\"\"\"\n",
 247     "    messages = [{\"role\": \"user\", \"content\": SYSTEM_PROMPT + \"\\n\\n---\\n\\n\" + obs_text}]\n",
 248     "    return tokenizer.apply_chat_template(\n",
 249     "        messages, tokenize=False, add_generation_prompt=True\n",
 250     "    )\n",
 251     "\n",
 252     "\n",
 253 -   "
 254 -   "
 255     "N_PROMPTS_PER_WORKFLOW = 20\n",
 256     "prompt_rows = []\n",
 257     "\n",
 258     "print(\"Collecting prompts from env resets...\")\n",
 259     "for wf in [\"A\", \"B\", \"C\"]:\n",
 260     "    for _ in range(N_PROMPTS_PER_WORKFLOW):\n",
 261 -   "        result
 262 -   "        obs
 263     "        obs_text = obs_to_text(obs)\n",
 264     "        prompt_rows.append({\n",
 265     "            \"prompt\": build_prompt(obs_text),\n",
|
|
| 268 |
" })\n",
|
| 269 |
"\n",
|
| 270 |
"prompt_dataset = Dataset.from_list(prompt_rows)\n",
|
| 271 |
-
"
|
| 272 |
-
"
|
|
|
|
| 273 |
]
|
| 274 |
},
|
| 275 |
{
|
| 276 |
"cell_type": "markdown",
|
| 277 |
-
"id": "
|
| 278 |
"metadata": {},
|
| 279 |
"source": [
|
| 280 |
-
"##
|
| 281 |
-
"\n",
|
| 282 |
-
"Called by GRPOTrainer during training on each batch of generated completions.\n",
|
| 283 |
-
"For each completion:\n",
|
| 284 |
-
"1. Parse it as action JSON\n",
|
| 285 |
-
"2. Reset the env to a fresh episode for the right workflow\n",
|
| 286 |
-
"3. Send the action via `/step`\n",
|
| 287 |
-
"4. Return the reward\n",
|
| 288 |
-
"\n",
|
| 289 |
-
"This gives the model a live signal from the actual environment."
|
| 290 |
]
|
| 291 |
},
|
| 292 |
{
|
|
@@ -296,53 +342,20 @@
 296     "metadata": {},
 297     "outputs": [],
 298     "source": [
 299 -   "import re\n",
 300 -   "from typing import List\n",
 301 -   "\n",
 302 -   "ENV_URL = \"http://localhost:8000\"\n",
 303 -   "\n",
 304 -   "\n",
 305 -   "def parse_action(text: str):\n",
 306 -   "    \"\"\"Extract JSON action from model output.\"\"\"\n",
 307 -   "    text = text.strip()\n",
 308 -   "    # Strip markdown code fences if present\n",
 309 -   "    text = re.sub(r\"```(?:json)?\\s*\", \"\", text).strip()\n",
 310 -   "    try:\n",
 311 -   "        return json.loads(text)\n",
 312 -   "    except json.JSONDecodeError:\n",
 313 -   "        m = re.search(r\"\\{.*\\}\", text, re.DOTALL)\n",
 314 -   "        if m:\n",
 315 -   "            try:\n",
 316 -   "                return json.loads(m.group())\n",
 317 -   "            except Exception:\n",
 318 -   "                pass\n",
 319 -   "    return None\n",
 320 -   "\n",
 321 -   "\n",
 322     "def orgos_reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:\n",
 323     "    \"\"\"\n",
 324     "    GRPO reward function — called by GRPOTrainer each training step.\n",
 325 -   "\n",
 326 -   "    For each generated completion:\n",
 327 -   "    - Parse as action JSON\n",
 328 -   "    - Reset env to a fresh episode (workflow inferred from prompt)\n",
 329 -   "    - Step the env with the action\n",
 330 -   "    - Return the step reward\n",
 331 -   "\n",
 332 -   "    Invalid JSON or failed actions return a -0.1 penalty.\n",
 333     "    \"\"\"\n",
 334     "    workflow_ids = kwargs.get(\"workflow_id\", [\"A\"] * len(completions))\n",
 335     "    rewards = []\n",
 336     "\n",
 337     "    for completion, wf_id in zip(completions, workflow_ids):\n",
 338     "        action = parse_action(completion)\n",
 339 -   "\n",
 340     "        if action is None:\n",
 341     "            rewards.append(-0.1)\n",
 342     "            continue\n",
 343 -   "\n",
 344     "        try:\n",
 345 -   "            # Fresh episode for this action evaluation\n",
 346     "            httpx.post(f\"{ENV_URL}/reset\", json={\"workflow_id\": wf_id}, timeout=10)\n",
 347     "            result = httpx.post(f\"{ENV_URL}/step\", json=action, timeout=10).json()\n",
 348     "            rewards.append(float(result[\"reward\"]))\n",
@@ -352,24 +365,22 @@
 352     "    return rewards\n",
 353     "\n",
 354     "\n",
 355 -   "
 356 -   "
 357 -   "
 358 -   "
 359 -   "
 360 -   "
 361 -   "    workflow_id = [\"A\", \"A\"],\n",
 362     ")\n",
 363 -   "
 364 -   "print(f\"  Invalid action reward: {test_rewards[1]:.4f}\")"
 365     ]
 366     },
 367     {
 368     "cell_type": "markdown",
 369 -   "id": "
 370     "metadata": {},
 371     "source": [
 372 -   "##
 373     ]
 374     },
 375     {
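The sanity check in this hunk drives the live server. The same reward semantics (a flat −0.1 for unparseable output, the environment's own reward otherwise) can be illustrated offline by pairing the notebook's `parse_action` with a stubbed env step; the stub and its scores are hypothetical stand-ins for the schema-drift gap, not the env's actual scoring:

```python
import json
import re
from typing import Callable, List, Optional


def parse_action(text: str) -> Optional[dict]:
    """Extract a JSON action from model output, tolerating code fences."""
    text = re.sub(r"```(?:json)?\s*", "", text.strip()).strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        m = re.search(r"\{.*\}", text, re.DOTALL)
        if m:
            try:
                return json.loads(m.group())
            except Exception:
                pass
    return None


def reward_with_stub(completions: List[str], env_step: Callable[[dict], float]) -> List[float]:
    """Mirror of the reward loop: parse failures get the flat -0.1 penalty,
    parsed actions are scored by the (stubbed) environment step."""
    rewards = []
    for c in completions:
        action = parse_action(c)
        rewards.append(-0.1 if action is None else float(env_step(action)))
    return rewards


# Hypothetical stub: +0.10 for the drifted field name, -0.20 for the stale one
stub = lambda a: 0.10 if "severity" in a.get("args", {}) else -0.20
out = reward_with_stub(
    ['{"app": "zendesk", "operation": "update", "args": {"severity": "high"}}',
     '{"app": "zendesk", "operation": "update", "args": {"priority": "high"}}',
     'not json'],
    stub,
)
```

The three completions land at 0.10, −0.20, and −0.1 respectively, matching the gap the markdown cell describes: adaptation bonus, stale-schema penalty, and the invalid-JSON floor.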
@@ -379,8 +390,6 @@
 379     "metadata": {},
 380     "outputs": [],
 381     "source": [
 382 -   "import numpy as np\n",
 383 -   "\n",
 384     "FastLanguageModel.for_inference(model)\n",
 385     "\n",
 386     "\n",
@@ -397,9 +406,9 @@
 397     "        obs_text = obs_to_text(obs)\n",
 398     "        history.append({\"role\": \"user\", \"content\": obs_text})\n",
 399     "\n",
 400 -   "
 401 -   "        messages =
 402 -   "
 403     "\n",
 404     "        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
 405     "        inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
@@ -430,27 +439,29 @@
 430     "    return obs.get(\"current_score\", 0.001)\n",
 431     "\n",
 432     "\n",
 433 -   "N_EVAL = 10
 434     "baseline_scores = {wf: [] for wf in [\"A\", \"B\", \"C\"]}\n",
 435     "\n",
 436 -   "
 437     "for wf in [\"A\", \"B\", \"C\"]:\n",
 438     "    for ep in range(N_EVAL):\n",
 439     "        score = run_episode_with_model(wf)\n",
 440     "        baseline_scores[wf].append(score)\n",
 441 -   "
 442 -   "
 443     "\n",
 444     "baseline_mean = np.mean([s for v in baseline_scores.values() for s in v])\n",
 445 -   "
 446     ]
 447     },
 448     {
 449     "cell_type": "markdown",
 450 -   "id": "
 451     "metadata": {},
 452     "source": [
 453 -   "##
 454     ]
 455     },
 456     {
@@ -461,57 +472,86 @@
 461     "outputs": [],
 462     "source": [
 463     "from trl import GRPOConfig, GRPOTrainer\n",
 464     "\n",
 465 -   "# Switch back to training mode\n",
 466     "model.train()\n",
 467     "\n",
 468     "grpo_config = GRPOConfig(\n",
 469     "    output_dir = \"./orgos_grpo_ckpt\",\n",
 470 -   "    num_train_epochs =
 471 -   "    per_device_train_batch_size =
 472 -   "    gradient_accumulation_steps =
 473 -   "    learning_rate =
 474     "    warmup_steps = 10,\n",
 475     "    logging_steps = 5,\n",
 476     "    save_steps = 100,\n",
 477     "    bf16 = torch.cuda.is_bf16_supported(),\n",
 478     "    fp16 = not torch.cuda.is_bf16_supported(),\n",
 479     "    max_grad_norm = 1.0,\n",
 480 -   "
 481 -   "    num_generations = 4,  # G: candidate actions per prompt\n",
 482     "    max_new_tokens = 256,\n",
 483 -   "    temperature =
 484 -   "    beta =
 485     "    report_to = \"none\",\n",
 486     "    seed = 42,\n",
 487     ")\n",
 488     "\n",
 489     "trainer = GRPOTrainer(\n",
 490 -   "    model
 491 -   "    args
 492 -   "    reward_funcs
 493 -   "    train_dataset
 494     "    processing_class = tokenizer,\n",
 495     ")\n",
 496     "\n",
 497 -   "
 498 -   "print(f\"  Prompts: {len(prompt_dataset)}\")\n",
 499 -   "print(f\"  Generations per prompt (G): {grpo_config.num_generations}\")\n",
 500 -   "print(f\"  Epochs: {grpo_config.num_train_epochs}\")\n",
 501 -   "print(f\"  Total env calls per epoch: ~{len(prompt_dataset) * grpo_config.num_generations}\")\n",
 502 -   "print()\n",
 503 -   "\n",
 504     "train_result = trainer.train()\n",
 505 -   "
 506 -   "
 507     ]
 508     },
 509     {
 510     "cell_type": "markdown",
 511 -   "id": "
 512     "metadata": {},
 513     "source": [
 514 -   "##
 515     ]
 516     },
 517     {
@@ -525,25 +565,41 @@
 525     "\n",
 526     "post_scores = {wf: [] for wf in [\"A\", \"B\", \"C\"]}\n",
 527     "\n",
 528 -   "
 529     "for wf in [\"A\", \"B\", \"C\"]:\n",
 530     "    for ep in range(N_EVAL):\n",
 531     "        score = run_episode_with_model(wf)\n",
 532     "        post_scores[wf].append(score)\n",
 533 -   "
 534 -   "
 535 -   "\n",
 536 -   "
 537 -   "
 538 -   "
 539     ]
 540     },
 541     {
 542     "cell_type": "markdown",
 543 -   "id": "
 544     "metadata": {},
 545     "source": [
 546 -   "##
 547     ]
 548     },
 549     {
@@ -561,8 +617,7 @@
 561     "             color=\"white\", fontweight=\"bold\", y=0.98)\n",
 562     "\n",
 563     "gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.45, wspace=0.35)\n",
 564 -   "\n",
 565 -   "COLORS = {\"before\": \"#f87171\", \"after\": \"#34d399\", \"bg\": \"#1e293b\", \"grid\": \"#334155\"}\n",
 566     "WF_LABELS = {\n",
 567     "    \"A\": \"Workflow A\\nCustomer Bug Fix\",\n",
 568     "    \"B\": \"Workflow B\\nEmployee Onboarding\",\n",
@@ -570,19 +625,16 @@
 570     "}\n",
 571     "\n",
 572     "for col, wf in enumerate([\"A\", \"B\", \"C\"]):\n",
 573 -   "    ax
 574     "    ax.set_facecolor(COLORS[\"bg\"])\n",
 575     "    ax.grid(color=COLORS[\"grid\"], linewidth=0.5, alpha=0.7)\n",
 576 -   "\n",
 577     "    before = baseline_scores[wf]\n",
 578     "    after = post_scores[wf]\n",
 579     "    delta = np.mean(after) - np.mean(before)\n",
 580 -   "\n",
 581     "    ax.plot(before, color=COLORS[\"before\"], linewidth=1.5, alpha=0.8, label=\"Before GRPO\")\n",
 582     "    ax.plot(after, color=COLORS[\"after\"], linewidth=1.5, alpha=0.8, label=\"After GRPO\")\n",
 583     "    ax.axhline(np.mean(before), color=COLORS[\"before\"], linestyle=\"--\", linewidth=1, alpha=0.5)\n",
 584     "    ax.axhline(np.mean(after), color=COLORS[\"after\"], linestyle=\"--\", linewidth=1, alpha=0.5)\n",
 585 -   "\n",
 586     "    ax.set_title(WF_LABELS[wf] + f\"\\n(Δ = {delta:+.4f})\", color=\"white\", fontsize=9)\n",
 587     "    ax.set_xlabel(\"Episode\", color=\"#94a3b8\", fontsize=8)\n",
 588     "    ax.set_ylabel(\"Final Score\", color=\"#94a3b8\", fontsize=8)\n",
@@ -596,18 +648,15 @@
 596     "ax_hist = fig.add_subplot(gs[1, :])\n",
 597     "ax_hist.set_facecolor(COLORS[\"bg\"])\n",
 598     "ax_hist.grid(color=COLORS[\"grid\"], linewidth=0.5, alpha=0.5, axis=\"x\")\n",
 599 -   "\n",
 600     "all_before = [s for v in baseline_scores.values() for s in v]\n",
 601     "all_after = [s for v in post_scores.values() for s in v]\n",
 602     "bins = np.linspace(0, 1, 25)\n",
 603 -   "\n",
 604     "ax_hist.hist(all_before, bins=bins, color=COLORS[\"before\"], alpha=0.6,\n",
 605     "             label=f\"Before GRPO (mean={np.mean(all_before):.4f})\", edgecolor=\"none\")\n",
 606     "ax_hist.hist(all_after, bins=bins, color=COLORS[\"after\"], alpha=0.6,\n",
 607     "             label=f\"After GRPO (mean={np.mean(all_after):.4f})\", edgecolor=\"none\")\n",
 608     "ax_hist.axvline(np.mean(all_before), color=COLORS[\"before\"], linestyle=\"--\", linewidth=1.5)\n",
 609     "ax_hist.axvline(np.mean(all_after), color=COLORS[\"after\"], linestyle=\"--\", linewidth=1.5)\n",
 610 -   "\n",
 611     "ax_hist.set_title(\"Score Distribution Across All Workflows\", color=\"white\", fontsize=10)\n",
 612     "ax_hist.set_xlabel(\"Final Score\", color=\"#94a3b8\", fontsize=9)\n",
 613     "ax_hist.set_ylabel(\"Count\", color=\"#94a3b8\", fontsize=9)\n",
@@ -620,15 +669,16 @@
 620     "plt.savefig(\"before_after_curves.png\", dpi=150, bbox_inches=\"tight\",\n",
 621     "            facecolor=\"#0f172a\", edgecolor=\"none\")\n",
 622     "plt.show()\n",
 623     "print(\"Saved: before_after_curves.png\")"
 624     ]
 625     },
 626     {
 627     "cell_type": "markdown",
 628 -   "id": "
 629     "metadata": {},
 630     "source": [
 631 -   "##
 632     ]
 633     },
 634     {
@@ -640,9 +690,18 @@
 640     "source": [
 641     "model.save_pretrained(\"orgos_lora_adapter\")\n",
 642     "tokenizer.save_pretrained(\"orgos_lora_adapter\")\n",
 643 -   "
 644 -   "\n",
 645 -   "
 646     "# from huggingface_hub import login\n",
 647     "# login(token=\"YOUR_HF_TOKEN\")\n",
 648     "# model.push_to_hub(\"YOUR_USERNAME/orgos-qwen25-3b-grpo\")\n",
  21     "4. GRPO computes relative advantages within the group (which action did better than average?)\n",
  22     "5. Model is updated to favour higher-reward actions\n",
  23     "\n",
  24 +   "**Key training signal:** Schema drift creates a sharp reward gap. \n",
  25     "Using a stale field name (e.g. `priority` when schema says `severity`) → **−0.20**. \n",
  26     "Using the correct drifted name → **+0.10** adaptation bonus. \n",
  27     "The model learns to read `schema_hints` before constructing action args."
| 77 |
},
|
| 78 |
{
|
| 79 |
"cell_type": "markdown",
|
| 80 |
+
"id": "sec_logger",
|
| 81 |
"metadata": {},
|
| 82 |
"source": [
|
| 83 |
+
"## 3. Training Logger\n",
|
| 84 |
+
"\n",
|
| 85 |
+
"Writes structured logs to `training_log.txt` for submission. \n",
|
| 86 |
+
"Format mirrors the OpenEnv inference log spec:\n",
|
| 87 |
+
"- `[TRAIN_CONFIG]` — model, algorithm, hyperparameters\n",
|
| 88 |
+
"- `[EVAL]` — per-episode score during baseline or post-training eval\n",
|
| 89 |
+
"- `[TRAIN_STEP]` — loss, mean reward, KL per training step\n",
|
| 90 |
+
"- `[TRAIN_SUMMARY]` — final before/after comparison"
|
| 91 |
+
]
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"cell_type": "code",
|
| 95 |
+
"execution_count": null,
|
| 96 |
+
"id": "logger",
|
| 97 |
+
"metadata": {},
|
| 98 |
+
"outputs": [],
|
| 99 |
+
"source": [
|
| 100 |
+
"import datetime\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"LOG_FILE = \"training_log.txt\"\n",
|
| 103 |
+
"\n",
|
| 104 |
+
"# Clear any previous log\n",
|
| 105 |
+
"with open(LOG_FILE, \"w\") as f:\n",
|
| 106 |
+
" f.write(f\"# OrgOS GRPO Training Log\\n\")\n",
|
| 107 |
+
" f.write(f\"# Generated: {datetime.datetime.utcnow().isoformat()}Z\\n\\n\")\n",
|
| 108 |
+
"\n",
|
| 109 |
+
"\n",
|
| 110 |
+
"def tlog(line: str) -> None:\n",
|
| 111 |
+
" \"\"\"Append one structured log line to training_log.txt and print it.\"\"\"\n",
|
| 112 |
+
" print(line, flush=True)\n",
|
| 113 |
+
" with open(LOG_FILE, \"a\") as f:\n",
|
| 114 |
+
" f.write(line + \"\\n\")\n",
|
| 115 |
+
"\n",
|
| 116 |
+
"\n",
|
| 117 |
+
"print(f\"Logger ready — writing to {LOG_FILE}\")"
|
| 118 |
+
]
|
| 119 |
+
},
|
| 120 |
+
{
|
| 121 |
+
"cell_type": "markdown",
|
| 122 |
+
"id": "sec4",
|
| 123 |
+
"metadata": {},
|
| 124 |
+
"source": [
|
| 125 |
+
"## 4. Start the OrgOS Environment Server"
|
| 126 |
]
|
| 127 |
},
|
| 128 |
{
|
|
|
|
| 143 |
"\n",
|
| 144 |
"health = httpx.get(\"http://localhost:8000/health\").json()\n",
|
| 145 |
"assert health[\"status\"] == \"healthy\", f\"Server not healthy: {health}\"\n",
|
| 146 |
+
"tlog(f\"[ENV] status=healthy version={health.get('version', '?')}\")"
|
| 147 |
]
|
| 148 |
},
|
| 149 |
{
|
| 150 |
"cell_type": "markdown",
|
| 151 |
+
"id": "sec5",
|
| 152 |
"metadata": {},
|
| 153 |
"source": [
|
| 154 |
+
"## 5. Load Model with Unsloth 4-bit LoRA"
|
| 155 |
]
|
| 156 |
},
|
| 157 |
{
|
|
|
|
| 166 |
"\n",
|
| 167 |
"MAX_SEQ_LEN = 2048\n",
|
| 168 |
"MODEL_NAME = \"Qwen/Qwen2.5-3B-Instruct\"\n",
|
| 169 |
+
"LORA_R = 16\n",
|
| 170 |
"\n",
|
| 171 |
"model, tokenizer = FastLanguageModel.from_pretrained(\n",
|
| 172 |
" model_name = MODEL_NAME,\n",
|
|
|
|
| 177 |
"\n",
|
| 178 |
"model = FastLanguageModel.get_peft_model(\n",
|
| 179 |
" model,\n",
|
| 180 |
+
" r = LORA_R,\n",
|
| 181 |
" target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
|
| 182 |
" \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
|
| 183 |
+
" lora_alpha = LORA_R,\n",
|
| 184 |
" lora_dropout = 0,\n",
|
| 185 |
" bias = \"none\",\n",
|
| 186 |
" use_gradient_checkpointing = \"unsloth\",\n",
|
| 187 |
" random_state = 42,\n",
|
| 188 |
")\n",
|
| 189 |
+
"\n",
|
| 190 |
"trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
| 191 |
+
"tlog(f\"[TRAIN_CONFIG] model={MODEL_NAME} lora_r={LORA_R} max_seq_len={MAX_SEQ_LEN} \"\n",
|
| 192 |
+
" f\"trainable_params={trainable:,} quantization=4bit\")"
|
| 193 |
]
|
| 194 |
},
|
| 195 |
{
|
| 196 |
"cell_type": "markdown",
|
| 197 |
+
"id": "sec6",
|
| 198 |
"metadata": {},
|
| 199 |
"source": [
|
| 200 |
+
"## 6. Prompt Dataset"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
]
|
| 202 |
},
|
| 203 |
{
|
|
|
|
| 207 |
"metadata": {},
|
| 208 |
"outputs": [],
|
| 209 |
"source": [
|
| 210 |
+
"import json, re\n",
|
| 211 |
+
"import numpy as np\n",
|
| 212 |
+
"from typing import List\n",
|
| 213 |
"from datasets import Dataset\n",
|
| 214 |
"\n",
|
| 215 |
"SYSTEM_PROMPT = \"\"\"\\\n",
|
|
|
|
| 250 |
"6. Stop when pending_steps is empty or done=true.\n",
|
| 251 |
"\"\"\"\n",
|
| 252 |
"\n",
|
| 253 |
+
"ENV_URL = \"http://localhost:8000\"\n",
|
| 254 |
+
"\n",
|
| 255 |
"\n",
|
| 256 |
"def obs_to_text(obs: dict) -> str:\n",
|
| 257 |
" hints = obs.get(\"schema_hints\", {})\n",
|
|
|
|
| 286 |
"\n",
|
| 287 |
"\n",
|
| 288 |
"def build_prompt(obs_text: str) -> str:\n",
|
|
|
|
| 289 |
" messages = [{\"role\": \"user\", \"content\": SYSTEM_PROMPT + \"\\n\\n---\\n\\n\" + obs_text}]\n",
|
| 290 |
" return tokenizer.apply_chat_template(\n",
|
| 291 |
" messages, tokenize=False, add_generation_prompt=True\n",
|
| 292 |
" )\n",
|
| 293 |
"\n",
|
| 294 |
"\n",
|
| 295 |
+
"def parse_action(text: str):\n",
|
| 296 |
+
" text = re.sub(r\"```(?:json)?\\s*\", \"\", text.strip()).strip()\n",
|
| 297 |
+
" try:\n",
|
| 298 |
+
" return json.loads(text)\n",
|
| 299 |
+
" except json.JSONDecodeError:\n",
|
| 300 |
+
" m = re.search(r\"\\{.*\\}\", text, re.DOTALL)\n",
|
| 301 |
+
" if m:\n",
|
| 302 |
+
" try:\n",
|
| 303 |
+
" return json.loads(m.group())\n",
|
| 304 |
+
" except Exception:\n",
|
| 305 |
+
" pass\n",
|
| 306 |
+
" return None\n",
|
| 307 |
+
"\n",
|
| 308 |
+
"\n",
|
| 309 |
"N_PROMPTS_PER_WORKFLOW = 20\n",
|
| 310 |
"prompt_rows = []\n",
|
| 311 |
"\n",
|
| 312 |
"print(\"Collecting prompts from env resets...\")\n",
|
| 313 |
"for wf in [\"A\", \"B\", \"C\"]:\n",
|
| 314 |
" for _ in range(N_PROMPTS_PER_WORKFLOW):\n",
|
| 315 |
+
" result = httpx.post(f\"{ENV_URL}/reset\", json={\"workflow_id\": wf}).json()\n",
|
| 316 |
+
" obs = result[\"observation\"]\n",
|
| 317 |
" obs_text = obs_to_text(obs)\n",
|
| 318 |
" prompt_rows.append({\n",
|
| 319 |
" \"prompt\": build_prompt(obs_text),\n",
|
|
|
|
| 322 |
" })\n",
|
| 323 |
"\n",
|
| 324 |
"prompt_dataset = Dataset.from_list(prompt_rows)\n",
|
| 325 |
+
"tlog(f\"[TRAIN_CONFIG] algorithm=GRPO prompts={len(prompt_dataset)} \"\n",
|
| 326 |
+
" f\"workflows=A,B,C prompts_per_workflow={N_PROMPTS_PER_WORKFLOW}\")\n",
|
| 327 |
+
"print(f\"Prompt dataset ready: {len(prompt_dataset)} examples\")"
|
| 328 |
]
|
| 329 |
},
|
| 330 |
{
|
| 331 |
"cell_type": "markdown",
|
| 332 |
+
"id": "sec7",
|
| 333 |
"metadata": {},
|
| 334 |
"source": [
|
| 335 |
+
"## 7. Reward Function"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
]
|
| 337 |
},
|
| 338 |
{
|
|
|
|
| 342 |
"metadata": {},
|
| 343 |
"outputs": [],
|
| 344 |
"source": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 345 |
"def orgos_reward_fn(completions: List[str], prompts: List[str], **kwargs) -> List[float]:\n",
|
| 346 |
" \"\"\"\n",
|
| 347 |
" GRPO reward function — called by GRPOTrainer each training step.\n",
|
| 348 |
+
" Parses each completion as an action JSON, steps the live env, returns the reward.\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 349 |
" \"\"\"\n",
|
| 350 |
" workflow_ids = kwargs.get(\"workflow_id\", [\"A\"] * len(completions))\n",
|
| 351 |
" rewards = []\n",
|
| 352 |
"\n",
|
| 353 |
" for completion, wf_id in zip(completions, workflow_ids):\n",
|
| 354 |
" action = parse_action(completion)\n",
|
|
|
|
| 355 |
" if action is None:\n",
|
| 356 |
" rewards.append(-0.1)\n",
|
| 357 |
" continue\n",
|
|
|
|
| 358 |
" try:\n",
|
|
|
|
| 359 |
" httpx.post(f\"{ENV_URL}/reset\", json={\"workflow_id\": wf_id}, timeout=10)\n",
|
| 360 |
" result = httpx.post(f\"{ENV_URL}/step\", json=action, timeout=10).json()\n",
|
| 361 |
" rewards.append(float(result[\"reward\"]))\n",
|
|
|
|
| 365 |
" return rewards\n",
|
| 366 |
"\n",
|
| 367 |
"\n",
|
| 368 |
+
"# Sanity check\n",
|
| 369 |
+
"test_r = orgos_reward_fn(\n",
|
| 370 |
+
" completions = ['{\"app\": \"zendesk\", \"operation\": \"list_tickets\", \"args\": {\"state\": \"new\"}}',\n",
|
| 371 |
+
" 'not json'],\n",
|
| 372 |
+
" prompts = [\"\", \"\"],\n",
|
| 373 |
+
" workflow_id = [\"A\", \"A\"],\n",
|
|
|
|
| 374 |
")\n",
|
| 375 |
+
"tlog(f\"[REWARD_FN_CHECK] valid_action={test_r[0]:.4f} invalid_action={test_r[1]:.4f}\")"
|
|
|
|
| 376 |
]
|
| 377 |
},
|
| 378 |
{
|
| 379 |
"cell_type": "markdown",
|
| 380 |
+
"id": "sec8",
|
| 381 |
"metadata": {},
|
| 382 |
"source": [
|
| 383 |
+
"## 8. Collect Baseline Scores (Pre-Training)"
|
| 384 |
]
|
| 385 |
},
|
| 386 |
{
|
|
|
|
| 390 |
"metadata": {},
|
| 391 |
"outputs": [],
|
| 392 |
"source": [
|
|
|
|
|
|
|
| 393 |
"FastLanguageModel.for_inference(model)\n",
|
| 394 |
"\n",
|
| 395 |
"\n",
|
|
|
|
| 406 |
" obs_text = obs_to_text(obs)\n",
|
| 407 |
" history.append({\"role\": \"user\", \"content\": obs_text})\n",
|
| 408 |
"\n",
|
| 409 |
+
" messages = list(history)\n",
|
| 410 |
+
" messages[0] = {\"role\": \"user\",\n",
|
| 411 |
+
" \"content\": SYSTEM_PROMPT + \"\\n\\n---\\n\\n\" + messages[0][\"content\"]}\n",
|
| 412 |
"\n",
|
| 413 |
" text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n",
|
| 414 |
" inputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\n",
|
|
|
|
| 439 |
" return obs.get(\"current_score\", 0.001)\n",
|
| 440 |
"\n",
|
| 441 |
"\n",
|
| 442 |
+
"N_EVAL = 10\n",
|
| 443 |
"baseline_scores = {wf: [] for wf in [\"A\", \"B\", \"C\"]}\n",
|
| 444 |
"\n",
|
| 445 |
+
"tlog(\"[EVAL_START] phase=baseline\")\n",
|
| 446 |
"for wf in [\"A\", \"B\", \"C\"]:\n",
|
| 447 |
" for ep in range(N_EVAL):\n",
|
| 448 |
" score = run_episode_with_model(wf)\n",
|
| 449 |
" baseline_scores[wf].append(score)\n",
|
| 450 |
+
" tlog(f\"[EVAL] phase=baseline workflow={wf} episode={ep+1} score={score:.4f}\")\n",
|
| 451 |
+
" wf_mean = np.mean(baseline_scores[wf])\n",
|
| 452 |
+
" tlog(f\"[EVAL_WORKFLOW] phase=baseline workflow={wf} \"\n",
|
| 453 |
+
" f\"mean={wf_mean:.4f} min={min(baseline_scores[wf]):.4f} max={max(baseline_scores[wf]):.4f}\")\n",
|
| 454 |
"\n",
|
| 455 |
"baseline_mean = np.mean([s for v in baseline_scores.values() for s in v])\n",
|
| 456 |
+
"tlog(f\"[EVAL_END] phase=baseline overall_mean={baseline_mean:.4f}\")"
|
| 457 |
]
|
| 458 |
},
|
| 459 |
{
|
| 460 |
"cell_type": "markdown",
|
| 461 |
+
"id": "sec9",
|
| 462 |
"metadata": {},
|
| 463 |
"source": [
|
| 464 |
+
"## 9. GRPO Training"
|
| 465 |
]
|
| 466 |
},
|
| 467 |
{
|
|
|
|
| 472 |
"outputs": [],
|
| 473 |
"source": [
|
| 474 |
"from trl import GRPOConfig, GRPOTrainer\n",
|
| 475 |
+
"from transformers import TrainerCallback\n",
|
| 476 |
"\n",
|
|
|
|
| 477 |
"model.train()\n",
|
| 478 |
"\n",
|
| 479 |
+
"NUM_EPOCHS = 3\n",
|
| 480 |
+
"BATCH_SIZE = 4\n",
|
| 481 |
+
"GRAD_ACCUM = 2\n",
|
| 482 |
+
"LR = 5e-5\n",
|
| 483 |
+
"NUM_GEN = 4\n",
|
| 484 |
+
"TEMPERATURE = 0.8\n",
|
| 485 |
+
"BETA = 0.04\n",
|
| 486 |
+
"\n",
|
| 487 |
"grpo_config = GRPOConfig(\n",
|
| 488 |
" output_dir = \"./orgos_grpo_ckpt\",\n",
|
| 489 |
+
" num_train_epochs = NUM_EPOCHS,\n",
|
| 490 |
+
" per_device_train_batch_size = BATCH_SIZE,\n",
|
| 491 |
+
" gradient_accumulation_steps = GRAD_ACCUM,\n",
|
| 492 |
+
" learning_rate = LR,\n",
|
| 493 |
" warmup_steps = 10,\n",
|
| 494 |
" logging_steps = 5,\n",
|
| 495 |
" save_steps = 100,\n",
|
| 496 |
" bf16 = torch.cuda.is_bf16_supported(),\n",
|
| 497 |
" fp16 = not torch.cuda.is_bf16_supported(),\n",
|
| 498 |
" max_grad_norm = 1.0,\n",
|
| 499 |
+
" num_generations = NUM_GEN,\n",
|
|
|
|
| 500 |
" max_new_tokens = 256,\n",
|
| 501 |
+
" temperature = TEMPERATURE,\n",
|
| 502 |
+
" beta = BETA,\n",
|
| 503 |
" report_to = \"none\",\n",
|
| 504 |
" seed = 42,\n",
|
| 505 |
")\n",
|
| 506 |
"\n",
|
| 507 |
+
"tlog(f\"[TRAIN_CONFIG] epochs={NUM_EPOCHS} batch_size={BATCH_SIZE} \"\n",
|
| 508 |
+
" f\"grad_accum={GRAD_ACCUM} lr={LR} num_generations={NUM_GEN} \"\n",
|
| 509 |
+
" f\"temperature={TEMPERATURE} beta_kl={BETA}\")\n",
|
| 510 |
+
"\n",
|
| 511 |
+
"\n",
|
| 512 |
+
"class OrgOSLogCallback(TrainerCallback):\n",
|
| 513 |
+
" \"\"\"Logs each training step to training_log.txt.\"\"\"\n",
|
| 514 |
+
"\n",
|
| 515 |
+
" def on_log(self, args, state, control, logs=None, **kwargs):\n",
|
| 516 |
+
" if logs is None:\n",
|
| 517 |
+
" return\n",
|
| 518 |
+
" step = state.global_step\n",
|
| 519 |
+
" loss = logs.get(\"loss\", logs.get(\"train_loss\", \"?\"))\n",
|
| 520 |
+
" mean_reward = logs.get(\"reward\", logs.get(\"mean_reward\", \"?\"))\n",
|
| 521 |
+
" kl = logs.get(\"kl\", logs.get(\"approx_kl\", \"?\"))\n",
|
| 522 |
+
" lr_now = logs.get(\"learning_rate\", \"?\")\n",
|
| 523 |
+
"\n",
|
| 524 |
+
" loss_str = f\"{loss:.6f}\" if isinstance(loss, float) else str(loss)\n",
|
| 525 |
+
" reward_str = f\"{mean_reward:.4f}\" if isinstance(mean_reward, float) else str(mean_reward)\n",
|
| 526 |
+
" kl_str = f\"{kl:.6f}\" if isinstance(kl, float) else str(kl)\n",
|
| 527 |
+
" lr_str = f\"{lr_now:.2e}\" if isinstance(lr_now, float) else str(lr_now)\n",
|
| 528 |
+
"\n",
|
| 529 |
+
" tlog(f\"[TRAIN_STEP] step={step} loss={loss_str} \"\n",
|
| 530 |
+
" f\"mean_reward={reward_str} kl={kl_str} lr={lr_str}\")\n",
|
| 531 |
+
"\n",
|
| 532 |
+
"\n",
|
| 533 |
"trainer = GRPOTrainer(\n",
|
| 534 |
+
" model = model,\n",
|
| 535 |
+
" args = grpo_config,\n",
|
| 536 |
+
" reward_funcs = orgos_reward_fn,\n",
|
| 537 |
+
" train_dataset = prompt_dataset,\n",
|
| 538 |
" processing_class = tokenizer,\n",
|
| 539 |
+
" callbacks = [OrgOSLogCallback()],\n",
|
| 540 |
")\n",
|
| 541 |
"\n",
|
| 542 |
+
"tlog(\"[TRAIN_START]\")\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 543 |
"train_result = trainer.train()\n",
|
| 544 |
+
"tlog(f\"[TRAIN_END] total_steps={train_result.global_step} \"\n",
|
| 545 |
+
" f\"train_loss={train_result.training_loss:.6f} \"\n",
|
| 546 |
+
" f\"train_runtime_s={train_result.metrics.get('train_runtime', 0):.1f}\")"
|
| 547 |
]
|
| 548 |
},
|
| 549 |
{
|
| 550 |
"cell_type": "markdown",
|
| 551 |
+
"id": "sec10",
|
| 552 |
"metadata": {},
|
| 553 |
"source": [
|
| 554 |
+
"## 10. Collect Post-Training Scores"
|
| 555 |
]
|
| 556 |
},
|
| 557 |
{
"\n",
"post_scores = {wf: [] for wf in [\"A\", \"B\", \"C\"]}\n",
"\n",
"tlog(\"[EVAL_START] phase=post_training\")\n",
"for wf in [\"A\", \"B\", \"C\"]:\n",
"    for ep in range(N_EVAL):\n",
"        score = run_episode_with_model(wf)\n",
"        post_scores[wf].append(score)\n",
"        tlog(f\"[EVAL] phase=post_training workflow={wf} episode={ep+1} score={score:.4f}\")\n",
"    wf_mean = np.mean(post_scores[wf])\n",
"    tlog(f\"[EVAL_WORKFLOW] phase=post_training workflow={wf} \"\n",
"         f\"mean={wf_mean:.4f} min={min(post_scores[wf]):.4f} max={max(post_scores[wf]):.4f}\")\n",
"\n",
"post_mean = np.mean([s for v in post_scores.values() for s in v])\n",
"improvement = post_mean - baseline_mean\n",
"tlog(f\"[EVAL_END] phase=post_training overall_mean={post_mean:.4f}\")\n",
"tlog(\n",
"    f\"[TRAIN_SUMMARY] \"\n",
"    f\"model={MODEL_NAME} algorithm=GRPO \"\n",
"    f\"baseline_mean={baseline_mean:.4f} \"\n",
"    f\"post_training_mean={post_mean:.4f} \"\n",
"    f\"improvement={improvement:+.4f} \"\n",
"    f\"workflow_A_before={np.mean(baseline_scores['A']):.4f} \"\n",
"    f\"workflow_A_after={np.mean(post_scores['A']):.4f} \"\n",
"    f\"workflow_B_before={np.mean(baseline_scores['B']):.4f} \"\n",
"    f\"workflow_B_after={np.mean(post_scores['B']):.4f} \"\n",
"    f\"workflow_C_before={np.mean(baseline_scores['C']):.4f} \"\n",
"    f\"workflow_C_after={np.mean(post_scores['C']):.4f}\"\n",
")\n",
"print(f\"\\nImprovement: {baseline_mean:.4f} → {post_mean:.4f} ({improvement:+.4f})\")"
]
},
{
"cell_type": "markdown",
"id": "sec11",
"metadata": {},
"source": [
"## 11. Plot Before / After"
]
},
{
"             color=\"white\", fontweight=\"bold\", y=0.98)\n",
"\n",
"gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.45, wspace=0.35)\n",
"COLORS = {\"before\": \"#f87171\", \"after\": \"#34d399\", \"bg\": \"#1e293b\", \"grid\": \"#334155\"}\n",
"WF_LABELS = {\n",
"    \"A\": \"Workflow A\\nCustomer Bug Fix\",\n",
"    \"B\": \"Workflow B\\nEmployee Onboarding\",\n",
"}\n",
"\n",
"for col, wf in enumerate([\"A\", \"B\", \"C\"]):\n",
"    ax = fig.add_subplot(gs[0, col])\n",
"    ax.set_facecolor(COLORS[\"bg\"])\n",
"    ax.grid(color=COLORS[\"grid\"], linewidth=0.5, alpha=0.7)\n",
"    before = baseline_scores[wf]\n",
"    after = post_scores[wf]\n",
"    delta = np.mean(after) - np.mean(before)\n",
"    ax.plot(before, color=COLORS[\"before\"], linewidth=1.5, alpha=0.8, label=\"Before GRPO\")\n",
"    ax.plot(after, color=COLORS[\"after\"], linewidth=1.5, alpha=0.8, label=\"After GRPO\")\n",
"    ax.axhline(np.mean(before), color=COLORS[\"before\"], linestyle=\"--\", linewidth=1, alpha=0.5)\n",
"    ax.axhline(np.mean(after), color=COLORS[\"after\"], linestyle=\"--\", linewidth=1, alpha=0.5)\n",
"    ax.set_title(WF_LABELS[wf] + f\"\\n(Δ = {delta:+.4f})\", color=\"white\", fontsize=9)\n",
"    ax.set_xlabel(\"Episode\", color=\"#94a3b8\", fontsize=8)\n",
"    ax.set_ylabel(\"Final Score\", color=\"#94a3b8\", fontsize=8)\n",
"\n",
"ax_hist = fig.add_subplot(gs[1, :])\n",
"ax_hist.set_facecolor(COLORS[\"bg\"])\n",
"ax_hist.grid(color=COLORS[\"grid\"], linewidth=0.5, alpha=0.5, axis=\"x\")\n",
"all_before = [s for v in baseline_scores.values() for s in v]\n",
"all_after = [s for v in post_scores.values() for s in v]\n",
"bins = np.linspace(0, 1, 25)\n",
"ax_hist.hist(all_before, bins=bins, color=COLORS[\"before\"], alpha=0.6,\n",
"             label=f\"Before GRPO (mean={np.mean(all_before):.4f})\", edgecolor=\"none\")\n",
"ax_hist.hist(all_after, bins=bins, color=COLORS[\"after\"], alpha=0.6,\n",
"             label=f\"After GRPO (mean={np.mean(all_after):.4f})\", edgecolor=\"none\")\n",
"ax_hist.axvline(np.mean(all_before), color=COLORS[\"before\"], linestyle=\"--\", linewidth=1.5)\n",
"ax_hist.axvline(np.mean(all_after), color=COLORS[\"after\"], linestyle=\"--\", linewidth=1.5)\n",
"ax_hist.set_title(\"Score Distribution Across All Workflows\", color=\"white\", fontsize=10)\n",
"ax_hist.set_xlabel(\"Final Score\", color=\"#94a3b8\", fontsize=9)\n",
"ax_hist.set_ylabel(\"Count\", color=\"#94a3b8\", fontsize=9)\n",
"plt.savefig(\"before_after_curves.png\", dpi=150, bbox_inches=\"tight\",\n",
"            facecolor=\"#0f172a\", edgecolor=\"none\")\n",
"plt.show()\n",
"tlog(\"[ARTIFACT] file=before_after_curves.png\")\n",
"print(\"Saved: before_after_curves.png\")"
]
},
{
"cell_type": "markdown",
"id": "sec12",
"metadata": {},
"source": [
"## 12. Save LoRA Adapter & Training Log"
]
},
{
"source": [
"model.save_pretrained(\"orgos_lora_adapter\")\n",
"tokenizer.save_pretrained(\"orgos_lora_adapter\")\n",
"tlog(\"[ARTIFACT] file=orgos_lora_adapter/\")\n",
"tlog(\"[ARTIFACT] file=training_log.txt\")\n",
"\n",
"print(f\"\\n{'='*60}\")\n",
"print(\" Submission artefacts\")\n",
"print(f\"{'='*60}\")\n",
"print(\" training_log.txt — structured training log\")\n",
"print(\" before_after_curves.png — score improvement chart\")\n",
"print(\" orgos_lora_adapter/ — LoRA weights\")\n",
"print(f\"{'='*60}\")\n",
"\n",
"# Optional: push to HuggingFace Hub\n",
"# from huggingface_hub import login\n",
"# login(token=\"YOUR_HF_TOKEN\")\n",
"# model.push_to_hub(\"YOUR_USERNAME/orgos-qwen25-3b-grpo\")\n",