Jayant-Kernel Claude Sonnet 4.6 committed on
Fix notebook: HF Space URL, /step envelope, health check retry on cold start
Browse files
- ENV_BASE_URL now points to ajsaxena-deceit.hf.space by default (USE_LOCAL_DOCKER=False)
- /step calls now use {"action": {...}} envelope as required by OpenEnv
- /reset response unpacked from {"observation": {...}} wrapper
- Health check retries 12x with 10s sleep to handle HF cold start
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- training/sanity_run.ipynb +5 -230
training/sanity_run.ipynb
CHANGED
|
@@ -30,46 +30,7 @@
|
|
| 30 |
"execution_count": null,
|
| 31 |
"metadata": {},
|
| 32 |
"outputs": [],
|
| 33 |
-
"source":
|
| 34 |
-
"# ============================================================\n",
|
| 35 |
-
"# SANITY RUN CONFIG (Phase 3)\n",
|
| 36 |
-
"# ============================================================\n",
|
| 37 |
-
"TRAINING_STEPS = 50\n",
|
| 38 |
-
"ROLLOUTS_PER_PROMPT = 4\n",
|
| 39 |
-
"BATCH_SIZE = 2\n",
|
| 40 |
-
"LEARNING_RATE = 5e-6\n",
|
| 41 |
-
"LORA_RANK = 16\n",
|
| 42 |
-
"SAVE_STEPS = 25\n",
|
| 43 |
-
"\n",
|
| 44 |
-
"# ============================================================\n",
|
| 45 |
-
"# FULL RUN CONFIG (Phase 5) — uncomment to activate\n",
|
| 46 |
-
"# ============================================================\n",
|
| 47 |
-
"# TRAINING_STEPS = 500\n",
|
| 48 |
-
"# ROLLOUTS_PER_PROMPT = 8\n",
|
| 49 |
-
"# BATCH_SIZE = 4\n",
|
| 50 |
-
"# LEARNING_RATE = 2e-6\n",
|
| 51 |
-
"# LORA_RANK = 32\n",
|
| 52 |
-
"# SAVE_STEPS = 100\n",
|
| 53 |
-
"\n",
|
| 54 |
-
"# ============================================================\n",
|
| 55 |
-
"# Environment connection — toggle here\n",
|
| 56 |
-
"# ============================================================\n",
|
| 57 |
-
"USE_LOCAL_DOCKER = True # True = local Docker on port 8000 (default, faster)\n",
|
| 58 |
-
" # False = deployed HF Space (for Phase 5+)\n",
|
| 59 |
-
"\n",
|
| 60 |
-
"HF_SPACE_URL = \"https://<your-hf-username>-deceit-env.hf.space\" # only used if above is False\n",
|
| 61 |
-
"\n",
|
| 62 |
-
"ENV_BASE_URL = \"http://localhost:8000\" if USE_LOCAL_DOCKER else HF_SPACE_URL\n",
|
| 63 |
-
"\n",
|
| 64 |
-
"# ============================================================\n",
|
| 65 |
-
"# Model & logging\n",
|
| 66 |
-
"# ============================================================\n",
|
| 67 |
-
"MODEL_NAME = \"unsloth/Qwen2.5-0.5B-Instruct\"\n",
|
| 68 |
-
"HF_REPO_ID = \"<your-hf-username>/deceit-qwen-0.5b-sanity\" # checkpoint destination\n",
|
| 69 |
-
"WANDB_PROJECT = \"deceit-sanity\"\n",
|
| 70 |
-
"\n",
|
| 71 |
-
"print(f\"Config loaded. Steps={TRAINING_STEPS}, ENV={ENV_BASE_URL}\")"
|
| 72 |
-
]
|
| 73 |
},
|
| 74 |
{
|
| 75 |
"cell_type": "markdown",
|
|
@@ -167,14 +128,7 @@
|
|
| 167 |
"execution_count": null,
|
| 168 |
"metadata": {},
|
| 169 |
"outputs": [],
|
| 170 |
-
"source":
|
| 171 |
-
"import requests\n",
|
| 172 |
-
"\n",
|
| 173 |
-
"# Verify env is reachable\n",
|
| 174 |
-
"resp = requests.get(f\"{ENV_BASE_URL}/health\", timeout=10)\n",
|
| 175 |
-
"print(f\"Health check: {resp.status_code} — {resp.json()}\")\n",
|
| 176 |
-
"assert resp.status_code == 200, f\"Env not reachable at {ENV_BASE_URL}\""
|
| 177 |
-
]
|
| 178 |
},
|
| 179 |
{
|
| 180 |
"cell_type": "markdown",
|
|
@@ -270,96 +224,7 @@
|
|
| 270 |
"execution_count": null,
|
| 271 |
"metadata": {},
|
| 272 |
"outputs": [],
|
| 273 |
-
"source": [
|
| 274 |
-
"def run_rollout(model, tokenizer, base_url: str, verbose: bool = False) -> dict:\n",
|
| 275 |
-
" \"\"\"Run one full episode and return trajectory + total reward.\"\"\"\n",
|
| 276 |
-
" # Reset environment\n",
|
| 277 |
-
" resp = requests.post(f\"{base_url}/reset\", json={}, timeout=15)\n",
|
| 278 |
-
" resp.raise_for_status()\n",
|
| 279 |
-
" obs = resp.json()\n",
|
| 280 |
-
"\n",
|
| 281 |
-
" question = obs.get(\"question\", \"\")\n",
|
| 282 |
-
" context = obs.get(\"context\", [])\n",
|
| 283 |
-
" max_turns = obs.get(\"max_turns\", 3)\n",
|
| 284 |
-
"\n",
|
| 285 |
-
" total_reward = 0.0\n",
|
| 286 |
-
" steps = 0\n",
|
| 287 |
-
" parse_fails = 0\n",
|
| 288 |
-
" trajectory = []\n",
|
| 289 |
-
"\n",
|
| 290 |
-
" for turn in range(max_turns):\n",
|
| 291 |
-
" # Build prompt for this turn\n",
|
| 292 |
-
" context_str = \"\\n\".join(context) if context else \"\"\n",
|
| 293 |
-
" user_content = f\"Question: {question}\"\n",
|
| 294 |
-
" if context_str:\n",
|
| 295 |
-
" user_content += f\"\\n\\n{context_str}\"\n",
|
| 296 |
-
" user_content += f\"\\n\\nTurn {turn + 1} of {max_turns}. Respond in JSON.\"\n",
|
| 297 |
-
"\n",
|
| 298 |
-
" messages = [\n",
|
| 299 |
-
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
|
| 300 |
-
" {\"role\": \"user\", \"content\": user_content},\n",
|
| 301 |
-
" ]\n",
|
| 302 |
-
" prompt = tokenizer.apply_chat_template(\n",
|
| 303 |
-
" messages, tokenize=False, add_generation_prompt=True\n",
|
| 304 |
-
" )\n",
|
| 305 |
-
" inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
|
| 306 |
-
"\n",
|
| 307 |
-
" with torch.no_grad():\n",
|
| 308 |
-
" output_ids = model.generate(\n",
|
| 309 |
-
" **inputs,\n",
|
| 310 |
-
" max_new_tokens=256,\n",
|
| 311 |
-
" do_sample=True,\n",
|
| 312 |
-
" temperature=0.7,\n",
|
| 313 |
-
" pad_token_id=tokenizer.eos_token_id,\n",
|
| 314 |
-
" )\n",
|
| 315 |
-
" generated = tokenizer.decode(\n",
|
| 316 |
-
" output_ids[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True\n",
|
| 317 |
-
" )\n",
|
| 318 |
-
"\n",
|
| 319 |
-
" # Parse action\n",
|
| 320 |
-
" try:\n",
|
| 321 |
-
" action = parse_action(generated)\n",
|
| 322 |
-
" except Exception:\n",
|
| 323 |
-
" action = PARSE_FAIL_ACTION.copy()\n",
|
| 324 |
-
" parse_fails += 1\n",
|
| 325 |
-
"\n",
|
| 326 |
-
" # Force final on last turn\n",
|
| 327 |
-
" if turn == max_turns - 1:\n",
|
| 328 |
-
" action[\"is_final\"] = True\n",
|
| 329 |
-
"\n",
|
| 330 |
-
" if verbose:\n",
|
| 331 |
-
" print(f\" Turn {turn+1}: is_final={action['is_final']} answer='{action['answer']}' confidence={action['confidence']:.2f}\")\n",
|
| 332 |
-
"\n",
|
| 333 |
-
" # Step environment\n",
|
| 334 |
-
" step_resp = requests.post(f\"{base_url}/step\", json=action, timeout=30)\n",
|
| 335 |
-
" step_resp.raise_for_status()\n",
|
| 336 |
-
" step_obs = step_resp.json()\n",
|
| 337 |
-
"\n",
|
| 338 |
-
" reward = step_obs.get(\"reward\", 0.0)\n",
|
| 339 |
-
" done = step_obs.get(\"done\", False)\n",
|
| 340 |
-
" context = step_obs.get(\"context\", [])\n",
|
| 341 |
-
"\n",
|
| 342 |
-
" total_reward += reward\n",
|
| 343 |
-
" steps += 1\n",
|
| 344 |
-
" trajectory.append({\n",
|
| 345 |
-
" \"turn\": turn + 1, \"action\": action, \"reward\": reward,\n",
|
| 346 |
-
" \"done\": done, \"metadata\": step_obs.get(\"metadata\", {})\n",
|
| 347 |
-
" })\n",
|
| 348 |
-
"\n",
|
| 349 |
-
" if done:\n",
|
| 350 |
-
" break\n",
|
| 351 |
-
"\n",
|
| 352 |
-
" return {\n",
|
| 353 |
-
" \"question\": question,\n",
|
| 354 |
-
" \"total_reward\": total_reward,\n",
|
| 355 |
-
" \"steps\": steps,\n",
|
| 356 |
-
" \"parse_fails\": parse_fails,\n",
|
| 357 |
-
" \"trajectory\": trajectory,\n",
|
| 358 |
-
" }\n",
|
| 359 |
-
"\n",
|
| 360 |
-
"\n",
|
| 361 |
-
"print(\"Rollout function ready.\")"
|
| 362 |
-
]
|
| 363 |
},
|
| 364 |
{
|
| 365 |
"cell_type": "markdown",
|
|
@@ -467,97 +332,7 @@
|
|
| 467 |
"execution_count": null,
|
| 468 |
"metadata": {},
|
| 469 |
"outputs": [],
|
| 470 |
-
"source": [
|
| 471 |
-
"import threading\n",
|
| 472 |
-
"\n",
|
| 473 |
-
"_env_lock = threading.Lock()\n",
|
| 474 |
-
"\n",
|
| 475 |
-
"def grpo_reward_fn(completions, prompts=None, **kwargs):\n",
|
| 476 |
-
" \"\"\"GRPO reward function: run one rollout per completion, return list of rewards.\n",
|
| 477 |
-
" \n",
|
| 478 |
-
" GRPO passes a list of completions (generated texts) for the same prompt.\n",
|
| 479 |
-
" Each gets an independent rollout in the environment.\n",
|
| 480 |
-
" \"\"\"\n",
|
| 481 |
-
" rewards = []\n",
|
| 482 |
-
" parse_fail_count = 0\n",
|
| 483 |
-
"\n",
|
| 484 |
-
" for completion_text in completions:\n",
|
| 485 |
-
" # Parse the initial action from the model's first completion\n",
|
| 486 |
-
" try:\n",
|
| 487 |
-
" action = parse_action(completion_text)\n",
|
| 488 |
-
" except Exception:\n",
|
| 489 |
-
" action = PARSE_FAIL_ACTION.copy()\n",
|
| 490 |
-
" parse_fail_count += 1\n",
|
| 491 |
-
"\n",
|
| 492 |
-
" try:\n",
|
| 493 |
-
" with _env_lock:\n",
|
| 494 |
-
" # Reset for fresh episode\n",
|
| 495 |
-
" reset_resp = requests.post(f\"{ENV_BASE_URL}/reset\", json={}, timeout=15)\n",
|
| 496 |
-
" reset_resp.raise_for_status()\n",
|
| 497 |
-
" obs = reset_resp.json()\n",
|
| 498 |
-
" max_turns = obs.get(\"max_turns\", 3)\n",
|
| 499 |
-
"\n",
|
| 500 |
-
" # If model committed on turn 1, just step once\n",
|
| 501 |
-
" # If not final, continue rolling out with greedy decoding\n",
|
| 502 |
-
" total_reward = 0.0\n",
|
| 503 |
-
" current_action = action\n",
|
| 504 |
-
" context = obs.get(\"context\", [])\n",
|
| 505 |
-
" question = obs.get(\"question\", \"\")\n",
|
| 506 |
-
"\n",
|
| 507 |
-
" for turn in range(max_turns):\n",
|
| 508 |
-
" if turn == max_turns - 1:\n",
|
| 509 |
-
" current_action[\"is_final\"] = True\n",
|
| 510 |
-
"\n",
|
| 511 |
-
" step_resp = requests.post(f\"{ENV_BASE_URL}/step\", json=current_action, timeout=30)\n",
|
| 512 |
-
" step_resp.raise_for_status()\n",
|
| 513 |
-
" step_obs = step_resp.json()\n",
|
| 514 |
-
"\n",
|
| 515 |
-
" total_reward += step_obs.get(\"reward\", 0.0)\n",
|
| 516 |
-
" done = step_obs.get(\"done\", False)\n",
|
| 517 |
-
" context = step_obs.get(\"context\", [])\n",
|
| 518 |
-
"\n",
|
| 519 |
-
" if done:\n",
|
| 520 |
-
" break\n",
|
| 521 |
-
"\n",
|
| 522 |
-
" # Continue rollout with model for subsequent turns\n",
|
| 523 |
-
" context_str = \"\\n\".join(context)\n",
|
| 524 |
-
" user_content = f\"Question: {question}\\n\\n{context_str}\\n\\nTurn {turn+2} of {max_turns}. Respond in JSON.\"\n",
|
| 525 |
-
" messages = [\n",
|
| 526 |
-
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
|
| 527 |
-
" {\"role\": \"user\", \"content\": user_content},\n",
|
| 528 |
-
" ]\n",
|
| 529 |
-
" next_prompt = tokenizer.apply_chat_template(\n",
|
| 530 |
-
" messages, tokenize=False, add_generation_prompt=True\n",
|
| 531 |
-
" )\n",
|
| 532 |
-
" inputs = tokenizer(next_prompt, return_tensors=\"pt\").to(model.device)\n",
|
| 533 |
-
" with torch.no_grad():\n",
|
| 534 |
-
" out_ids = model.generate(\n",
|
| 535 |
-
" **inputs, max_new_tokens=256,\n",
|
| 536 |
-
" do_sample=False, # greedy for subsequent turns\n",
|
| 537 |
-
" pad_token_id=tokenizer.eos_token_id,\n",
|
| 538 |
-
" )\n",
|
| 539 |
-
" next_text = tokenizer.decode(\n",
|
| 540 |
-
" out_ids[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True\n",
|
| 541 |
-
" )\n",
|
| 542 |
-
" try:\n",
|
| 543 |
-
" current_action = parse_action(next_text)\n",
|
| 544 |
-
" except Exception:\n",
|
| 545 |
-
" current_action = PARSE_FAIL_ACTION.copy()\n",
|
| 546 |
-
"\n",
|
| 547 |
-
" except Exception as e:\n",
|
| 548 |
-
" print(f\" [reward_fn] Episode error: {e}\")\n",
|
| 549 |
-
" total_reward = -1.3 # worst possible reward on crash\n",
|
| 550 |
-
"\n",
|
| 551 |
-
" rewards.append(total_reward)\n",
|
| 552 |
-
"\n",
|
| 553 |
-
" if parse_fail_count > 0:\n",
|
| 554 |
-
" print(f\" [reward_fn] Parse failures: {parse_fail_count}/{len(completions)}\")\n",
|
| 555 |
-
"\n",
|
| 556 |
-
" return rewards\n",
|
| 557 |
-
"\n",
|
| 558 |
-
"\n",
|
| 559 |
-
"print(\"GRPO reward function ready.\")"
|
| 560 |
-
]
|
| 561 |
},
|
| 562 |
{
|
| 563 |
"cell_type": "markdown",
|
|
@@ -793,4 +568,4 @@
|
|
| 793 |
},
|
| 794 |
"nbformat": 4,
|
| 795 |
"nbformat_minor": 4
|
| 796 |
-
}
|
|
|
|
| 30 |
"execution_count": null,
|
| 31 |
"metadata": {},
|
| 32 |
"outputs": [],
|
| 33 |
+
"source": "# ============================================================\n# SANITY RUN CONFIG (Phase 3)\n# ============================================================\nTRAINING_STEPS = 50\nROLLOUTS_PER_PROMPT = 4\nBATCH_SIZE = 2\nLEARNING_RATE = 5e-6\nLORA_RANK = 16\nSAVE_STEPS = 25\n\n# ============================================================\n# FULL RUN CONFIG (Phase 5) — uncomment to activate\n# ============================================================\n# TRAINING_STEPS = 500\n# ROLLOUTS_PER_PROMPT = 8\n# BATCH_SIZE = 4\n# LEARNING_RATE = 2e-6\n# LORA_RANK = 32\n# SAVE_STEPS = 100\n\n# ============================================================\n# Environment connection — toggle here\n# ============================================================\nUSE_LOCAL_DOCKER = False # True = local Docker on port 8000\n # False = deployed HF Space (default for Colab)\n\nHF_SPACE_URL = \"https://ajsaxena-deceit.hf.space\" # Ajsaxena/DECEIT on HF Spaces\n\nENV_BASE_URL = \"http://localhost:8000\" if USE_LOCAL_DOCKER else HF_SPACE_URL\n\n# ============================================================\n# Model & logging\n# ============================================================\nMODEL_NAME = \"unsloth/Qwen2.5-0.5B-Instruct\"\nHF_REPO_ID = \"Ajsaxena/deceit-qwen-0.5b-sanity\" # checkpoint destination\nWANDB_PROJECT = \"deceit-sanity\"\n\nprint(f\"Config loaded. Steps={TRAINING_STEPS}, ENV={ENV_BASE_URL}\")"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
},
|
| 35 |
{
|
| 36 |
"cell_type": "markdown",
|
|
|
|
| 128 |
"execution_count": null,
|
| 129 |
"metadata": {},
|
| 130 |
"outputs": [],
|
| 131 |
+
"source": "import requests\nimport time\n\n# Verify env is reachable — retries for HF Space cold start (up to 2 min)\nprint(f\"Connecting to {ENV_BASE_URL} ...\")\nfor attempt in range(12):\n try:\n resp = requests.get(f\"{ENV_BASE_URL}/health\", timeout=15)\n if resp.status_code == 200:\n print(f\"✓ Health check passed: {resp.json()}\")\n break\n else:\n print(f\" Attempt {attempt+1}: status {resp.status_code}, retrying...\")\n except Exception as e:\n print(f\" Attempt {attempt+1}: {e}, retrying in 10s...\")\n time.sleep(10)\nelse:\n raise RuntimeError(f\"Env not reachable at {ENV_BASE_URL} after 12 attempts\")"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
},
|
| 133 |
{
|
| 134 |
"cell_type": "markdown",
|
|
|
|
| 224 |
"execution_count": null,
|
| 225 |
"metadata": {},
|
| 226 |
"outputs": [],
|
| 227 |
+
"source": "def run_rollout(model, tokenizer, base_url: str, verbose: bool = False) -> dict:\n \"\"\"Run one full episode and return trajectory + total reward.\"\"\"\n resp = requests.post(f\"{base_url}/reset\", json={}, timeout=30)\n resp.raise_for_status()\n obs = resp.json()\n # OpenEnv wraps observation in {\"observation\": {...}}\n obs_data = obs.get(\"observation\", obs)\n question = obs_data.get(\"question\", \"\")\n context = obs_data.get(\"context\", [])\n max_turns = obs_data.get(\"max_turns\", 3)\n\n total_reward = 0.0\n steps = 0\n parse_fails = 0\n trajectory = []\n\n for turn in range(max_turns):\n context_str = \"\\n\".join(context) if context else \"\"\n user_content = f\"Question: {question}\"\n if context_str:\n user_content += f\"\\n\\n{context_str}\"\n user_content += f\"\\n\\nTurn {turn + 1} of {max_turns}. Respond in JSON.\"\n\n messages = [\n {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n {\"role\": \"user\", \"content\": user_content},\n ]\n prompt = tokenizer.apply_chat_template(\n messages, tokenize=False, add_generation_prompt=True\n )\n inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n\n with torch.no_grad():\n output_ids = model.generate(\n **inputs,\n max_new_tokens=256,\n do_sample=True,\n temperature=0.7,\n pad_token_id=tokenizer.eos_token_id,\n )\n generated = tokenizer.decode(\n output_ids[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True\n )\n\n try:\n action = parse_action(generated)\n except Exception:\n action = PARSE_FAIL_ACTION.copy()\n parse_fails += 1\n\n if turn == max_turns - 1:\n action[\"is_final\"] = True\n\n if verbose:\n print(f\" Turn {turn+1}: is_final={action['is_final']} answer='{action['answer']}' confidence={action['confidence']:.2f}\")\n\n # OpenEnv /step expects {\"action\": {...}}\n step_resp = requests.post(f\"{base_url}/step\", json={\"action\": action}, timeout=30)\n step_resp.raise_for_status()\n step_obs = step_resp.json()\n\n # Response is {\"observation\": {...}, 
\"reward\": ..., \"done\": ...}\n step_obs_data = step_obs.get(\"observation\", step_obs)\n reward = step_obs.get(\"reward\", 0.0) or 0.0\n done = step_obs.get(\"done\", False)\n context = step_obs_data.get(\"context\", [])\n\n total_reward += reward\n steps += 1\n trajectory.append({\n \"turn\": turn + 1, \"action\": action, \"reward\": reward,\n \"done\": done, \"metadata\": step_obs_data.get(\"metadata\", {})\n })\n\n if done:\n break\n\n return {\n \"question\": question,\n \"total_reward\": total_reward,\n \"steps\": steps,\n \"parse_fails\": parse_fails,\n \"trajectory\": trajectory,\n }\n\n\nprint(\"Rollout function ready.\")"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
},
|
| 229 |
{
|
| 230 |
"cell_type": "markdown",
|
|
|
|
| 332 |
"execution_count": null,
|
| 333 |
"metadata": {},
|
| 334 |
"outputs": [],
|
| 335 |
+
"source": "import threading\n\n_env_lock = threading.Lock()\n\ndef grpo_reward_fn(completions, prompts=None, **kwargs):\n \"\"\"GRPO reward function: run one rollout per completion, return list of rewards.\"\"\"\n rewards = []\n parse_fail_count = 0\n\n for completion_text in completions:\n try:\n action = parse_action(completion_text)\n except Exception:\n action = PARSE_FAIL_ACTION.copy()\n parse_fail_count += 1\n\n try:\n with _env_lock:\n reset_resp = requests.post(f\"{ENV_BASE_URL}/reset\", json={}, timeout=30)\n reset_resp.raise_for_status()\n obs = reset_resp.json()\n obs_data = obs.get(\"observation\", obs)\n max_turns = obs_data.get(\"max_turns\", 3)\n question = obs_data.get(\"question\", \"\")\n context = obs_data.get(\"context\", [])\n\n total_reward = 0.0\n current_action = action\n\n for turn in range(max_turns):\n if turn == max_turns - 1:\n current_action[\"is_final\"] = True\n\n # OpenEnv /step expects {\"action\": {...}}\n step_resp = requests.post(\n f\"{ENV_BASE_URL}/step\",\n json={\"action\": current_action},\n timeout=30,\n )\n step_resp.raise_for_status()\n step_obs = step_resp.json()\n step_obs_data = step_obs.get(\"observation\", step_obs)\n\n reward = step_obs.get(\"reward\", 0.0) or 0.0\n done = step_obs.get(\"done\", False)\n context = step_obs_data.get(\"context\", [])\n total_reward += reward\n\n if done:\n break\n\n # Subsequent turns: greedy decoding\n context_str = \"\\n\".join(context)\n user_content = f\"Question: {question}\\n\\n{context_str}\\n\\nTurn {turn+2} of {max_turns}. 
Respond in JSON.\"\n messages = [\n {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n {\"role\": \"user\", \"content\": user_content},\n ]\n next_prompt = tokenizer.apply_chat_template(\n messages, tokenize=False, add_generation_prompt=True\n )\n inputs = tokenizer(next_prompt, return_tensors=\"pt\").to(model.device)\n with torch.no_grad():\n out_ids = model.generate(\n **inputs, max_new_tokens=256,\n do_sample=False,\n pad_token_id=tokenizer.eos_token_id,\n )\n next_text = tokenizer.decode(\n out_ids[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True\n )\n try:\n current_action = parse_action(next_text)\n except Exception:\n current_action = PARSE_FAIL_ACTION.copy()\n\n except Exception as e:\n print(f\" [reward_fn] Episode error: {e}\")\n total_reward = -1.3\n\n rewards.append(total_reward)\n\n if parse_fail_count > 0:\n print(f\" [reward_fn] Parse failures: {parse_fail_count}/{len(completions)}\")\n\n return rewards\n\n\nprint(\"GRPO reward function ready.\")"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 336 |
},
|
| 337 |
{
|
| 338 |
"cell_type": "markdown",
|
|
|
|
| 568 |
},
|
| 569 |
"nbformat": 4,
|
| 570 |
"nbformat_minor": 4
|
| 571 |
+
}
|