Jayant-Kernel Claude Sonnet 4.6 committed on
Commit db475da · unverified · 1 Parent(s): 97384d7

Fix notebook: HF Space URL, /step envelope, health check retry on cold start

- ENV_BASE_URL now points to ajsaxena-deceit.hf.space by default (USE_LOCAL_DOCKER=False)
- /step calls now use {"action": {...}} envelope as required by OpenEnv
- /reset response unpacked from {"observation": {...}} wrapper (see the request/response sketch below)
- Health check retries 12x with 10s sleep to handle HF cold start
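For quick reference, a minimal sketch of the request/response shapes this commit assumes. It is a plain `requests` example, not part of the commit itself; the action fields (`answer`, `confidence`, `is_final`) and observation fields (`question`, `context`, `max_turns`) are taken from the notebook cells below, and the exact observation schema is defined by the DECEIT environment rather than by OpenEnv:

    import requests

    BASE = "https://ajsaxena-deceit.hf.space"  # deployed Space; the first request may hit a cold start

    # /reset takes an empty JSON body; the observation comes back wrapped in {"observation": {...}}
    obs = requests.post(f"{BASE}/reset", json={}, timeout=30).json().get("observation", {})
    print(obs.get("question"), obs.get("max_turns"))

    # /step expects the action wrapped in an {"action": {...}} envelope;
    # the reply carries "observation", "reward", and "done" at the top level
    action = {"answer": "example", "confidence": 0.5, "is_final": True}
    step = requests.post(f"{BASE}/step", json={"action": action}, timeout=30).json()
    print(step.get("reward"), step.get("done"))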

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1)
  1. training/sanity_run.ipynb +5 -230
training/sanity_run.ipynb CHANGED
@@ -30,46 +30,7 @@
  "execution_count": null,
  "metadata": {},
  "outputs": [],
- "source": [
- "# ============================================================\n",
- "# SANITY RUN CONFIG (Phase 3)\n",
- "# ============================================================\n",
- "TRAINING_STEPS = 50\n",
- "ROLLOUTS_PER_PROMPT = 4\n",
- "BATCH_SIZE = 2\n",
- "LEARNING_RATE = 5e-6\n",
- "LORA_RANK = 16\n",
- "SAVE_STEPS = 25\n",
- "\n",
- "# ============================================================\n",
- "# FULL RUN CONFIG (Phase 5) — uncomment to activate\n",
- "# ============================================================\n",
- "# TRAINING_STEPS = 500\n",
- "# ROLLOUTS_PER_PROMPT = 8\n",
- "# BATCH_SIZE = 4\n",
- "# LEARNING_RATE = 2e-6\n",
- "# LORA_RANK = 32\n",
- "# SAVE_STEPS = 100\n",
- "\n",
- "# ============================================================\n",
- "# Environment connection — toggle here\n",
- "# ============================================================\n",
- "USE_LOCAL_DOCKER = True # True = local Docker on port 8000 (default, faster)\n",
- " # False = deployed HF Space (for Phase 5+)\n",
- "\n",
- "HF_SPACE_URL = \"https://<your-hf-username>-deceit-env.hf.space\" # only used if above is False\n",
- "\n",
- "ENV_BASE_URL = \"http://localhost:8000\" if USE_LOCAL_DOCKER else HF_SPACE_URL\n",
- "\n",
- "# ============================================================\n",
- "# Model & logging\n",
- "# ============================================================\n",
- "MODEL_NAME = \"unsloth/Qwen2.5-0.5B-Instruct\"\n",
- "HF_REPO_ID = \"<your-hf-username>/deceit-qwen-0.5b-sanity\" # checkpoint destination\n",
- "WANDB_PROJECT = \"deceit-sanity\"\n",
- "\n",
- "print(f\"Config loaded. Steps={TRAINING_STEPS}, ENV={ENV_BASE_URL}\")"
- ]
  },
  {
  "cell_type": "markdown",
@@ -167,14 +128,7 @@
  "execution_count": null,
  "metadata": {},
  "outputs": [],
- "source": [
- "import requests\n",
- "\n",
- "# Verify env is reachable\n",
- "resp = requests.get(f\"{ENV_BASE_URL}/health\", timeout=10)\n",
- "print(f\"Health check: {resp.status_code} — {resp.json()}\")\n",
- "assert resp.status_code == 200, f\"Env not reachable at {ENV_BASE_URL}\""
- ]
  },
  {
  "cell_type": "markdown",
@@ -270,96 +224,7 @@
  "execution_count": null,
  "metadata": {},
  "outputs": [],
- "source": [
- "def run_rollout(model, tokenizer, base_url: str, verbose: bool = False) -> dict:\n",
- " \"\"\"Run one full episode and return trajectory + total reward.\"\"\"\n",
- " # Reset environment\n",
- " resp = requests.post(f\"{base_url}/reset\", json={}, timeout=15)\n",
- " resp.raise_for_status()\n",
- " obs = resp.json()\n",
- "\n",
- " question = obs.get(\"question\", \"\")\n",
- " context = obs.get(\"context\", [])\n",
- " max_turns = obs.get(\"max_turns\", 3)\n",
- "\n",
- " total_reward = 0.0\n",
- " steps = 0\n",
- " parse_fails = 0\n",
- " trajectory = []\n",
- "\n",
- " for turn in range(max_turns):\n",
- " # Build prompt for this turn\n",
- " context_str = \"\\n\".join(context) if context else \"\"\n",
- " user_content = f\"Question: {question}\"\n",
- " if context_str:\n",
- " user_content += f\"\\n\\n{context_str}\"\n",
- " user_content += f\"\\n\\nTurn {turn + 1} of {max_turns}. Respond in JSON.\"\n",
- "\n",
- " messages = [\n",
- " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
- " {\"role\": \"user\", \"content\": user_content},\n",
- " ]\n",
- " prompt = tokenizer.apply_chat_template(\n",
- " messages, tokenize=False, add_generation_prompt=True\n",
- " )\n",
- " inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
- "\n",
- " with torch.no_grad():\n",
- " output_ids = model.generate(\n",
- " **inputs,\n",
- " max_new_tokens=256,\n",
- " do_sample=True,\n",
- " temperature=0.7,\n",
- " pad_token_id=tokenizer.eos_token_id,\n",
- " )\n",
- " generated = tokenizer.decode(\n",
- " output_ids[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True\n",
- " )\n",
- "\n",
- " # Parse action\n",
- " try:\n",
- " action = parse_action(generated)\n",
- " except Exception:\n",
- " action = PARSE_FAIL_ACTION.copy()\n",
- " parse_fails += 1\n",
- "\n",
- " # Force final on last turn\n",
- " if turn == max_turns - 1:\n",
- " action[\"is_final\"] = True\n",
- "\n",
- " if verbose:\n",
- " print(f\" Turn {turn+1}: is_final={action['is_final']} answer='{action['answer']}' confidence={action['confidence']:.2f}\")\n",
- "\n",
- " # Step environment\n",
- " step_resp = requests.post(f\"{base_url}/step\", json=action, timeout=30)\n",
- " step_resp.raise_for_status()\n",
- " step_obs = step_resp.json()\n",
- "\n",
- " reward = step_obs.get(\"reward\", 0.0)\n",
- " done = step_obs.get(\"done\", False)\n",
- " context = step_obs.get(\"context\", [])\n",
- "\n",
- " total_reward += reward\n",
- " steps += 1\n",
- " trajectory.append({\n",
- " \"turn\": turn + 1, \"action\": action, \"reward\": reward,\n",
- " \"done\": done, \"metadata\": step_obs.get(\"metadata\", {})\n",
- " })\n",
- "\n",
- " if done:\n",
- " break\n",
- "\n",
- " return {\n",
- " \"question\": question,\n",
- " \"total_reward\": total_reward,\n",
- " \"steps\": steps,\n",
- " \"parse_fails\": parse_fails,\n",
- " \"trajectory\": trajectory,\n",
- " }\n",
- "\n",
- "\n",
- "print(\"Rollout function ready.\")"
- ]
  },
  {
  "cell_type": "markdown",
@@ -467,97 +332,7 @@
  "execution_count": null,
  "metadata": {},
  "outputs": [],
- "source": [
- "import threading\n",
- "\n",
- "_env_lock = threading.Lock()\n",
- "\n",
- "def grpo_reward_fn(completions, prompts=None, **kwargs):\n",
- " \"\"\"GRPO reward function: run one rollout per completion, return list of rewards.\n",
- " \n",
- " GRPO passes a list of completions (generated texts) for the same prompt.\n",
- " Each gets an independent rollout in the environment.\n",
- " \"\"\"\n",
- " rewards = []\n",
- " parse_fail_count = 0\n",
- "\n",
- " for completion_text in completions:\n",
- " # Parse the initial action from the model's first completion\n",
- " try:\n",
- " action = parse_action(completion_text)\n",
- " except Exception:\n",
- " action = PARSE_FAIL_ACTION.copy()\n",
- " parse_fail_count += 1\n",
- "\n",
- " try:\n",
- " with _env_lock:\n",
- " # Reset for fresh episode\n",
- " reset_resp = requests.post(f\"{ENV_BASE_URL}/reset\", json={}, timeout=15)\n",
- " reset_resp.raise_for_status()\n",
- " obs = reset_resp.json()\n",
- " max_turns = obs.get(\"max_turns\", 3)\n",
- "\n",
- " # If model committed on turn 1, just step once\n",
- " # If not final, continue rolling out with greedy decoding\n",
- " total_reward = 0.0\n",
- " current_action = action\n",
- " context = obs.get(\"context\", [])\n",
- " question = obs.get(\"question\", \"\")\n",
- "\n",
- " for turn in range(max_turns):\n",
- " if turn == max_turns - 1:\n",
- " current_action[\"is_final\"] = True\n",
- "\n",
- " step_resp = requests.post(f\"{ENV_BASE_URL}/step\", json=current_action, timeout=30)\n",
- " step_resp.raise_for_status()\n",
- " step_obs = step_resp.json()\n",
- "\n",
- " total_reward += step_obs.get(\"reward\", 0.0)\n",
- " done = step_obs.get(\"done\", False)\n",
- " context = step_obs.get(\"context\", [])\n",
- "\n",
- " if done:\n",
- " break\n",
- "\n",
- " # Continue rollout with model for subsequent turns\n",
- " context_str = \"\\n\".join(context)\n",
- " user_content = f\"Question: {question}\\n\\n{context_str}\\n\\nTurn {turn+2} of {max_turns}. Respond in JSON.\"\n",
- " messages = [\n",
- " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
- " {\"role\": \"user\", \"content\": user_content},\n",
- " ]\n",
- " next_prompt = tokenizer.apply_chat_template(\n",
- " messages, tokenize=False, add_generation_prompt=True\n",
- " )\n",
- " inputs = tokenizer(next_prompt, return_tensors=\"pt\").to(model.device)\n",
- " with torch.no_grad():\n",
- " out_ids = model.generate(\n",
- " **inputs, max_new_tokens=256,\n",
- " do_sample=False, # greedy for subsequent turns\n",
- " pad_token_id=tokenizer.eos_token_id,\n",
- " )\n",
- " next_text = tokenizer.decode(\n",
- " out_ids[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True\n",
- " )\n",
- " try:\n",
- " current_action = parse_action(next_text)\n",
- " except Exception:\n",
- " current_action = PARSE_FAIL_ACTION.copy()\n",
- "\n",
- " except Exception as e:\n",
- " print(f\" [reward_fn] Episode error: {e}\")\n",
- " total_reward = -1.3 # worst possible reward on crash\n",
- "\n",
- " rewards.append(total_reward)\n",
- "\n",
- " if parse_fail_count > 0:\n",
- " print(f\" [reward_fn] Parse failures: {parse_fail_count}/{len(completions)}\")\n",
- "\n",
- " return rewards\n",
- "\n",
- "\n",
- "print(\"GRPO reward function ready.\")"
- ]
  },
  {
  "cell_type": "markdown",
@@ -793,4 +568,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 4
- }
 
  "execution_count": null,
  "metadata": {},
  "outputs": [],
+ "source": "# ============================================================\n# SANITY RUN CONFIG (Phase 3)\n# ============================================================\nTRAINING_STEPS = 50\nROLLOUTS_PER_PROMPT = 4\nBATCH_SIZE = 2\nLEARNING_RATE = 5e-6\nLORA_RANK = 16\nSAVE_STEPS = 25\n\n# ============================================================\n# FULL RUN CONFIG (Phase 5) — uncomment to activate\n# ============================================================\n# TRAINING_STEPS = 500\n# ROLLOUTS_PER_PROMPT = 8\n# BATCH_SIZE = 4\n# LEARNING_RATE = 2e-6\n# LORA_RANK = 32\n# SAVE_STEPS = 100\n\n# ============================================================\n# Environment connection — toggle here\n# ============================================================\nUSE_LOCAL_DOCKER = False # True = local Docker on port 8000\n # False = deployed HF Space (default for Colab)\n\nHF_SPACE_URL = \"https://ajsaxena-deceit.hf.space\" # Ajsaxena/DECEIT on HF Spaces\n\nENV_BASE_URL = \"http://localhost:8000\" if USE_LOCAL_DOCKER else HF_SPACE_URL\n\n# ============================================================\n# Model & logging\n# ============================================================\nMODEL_NAME = \"unsloth/Qwen2.5-0.5B-Instruct\"\nHF_REPO_ID = \"Ajsaxena/deceit-qwen-0.5b-sanity\" # checkpoint destination\nWANDB_PROJECT = \"deceit-sanity\"\n\nprint(f\"Config loaded. Steps={TRAINING_STEPS}, ENV={ENV_BASE_URL}\")"
  },
  {
  "cell_type": "markdown",
 
  "execution_count": null,
  "metadata": {},
  "outputs": [],
+ "source": "import requests\nimport time\n\n# Verify env is reachable — retries for HF Space cold start (up to 2 min)\nprint(f\"Connecting to {ENV_BASE_URL} ...\")\nfor attempt in range(12):\n try:\n resp = requests.get(f\"{ENV_BASE_URL}/health\", timeout=15)\n if resp.status_code == 200:\n print(f\"✓ Health check passed: {resp.json()}\")\n break\n else:\n print(f\" Attempt {attempt+1}: status {resp.status_code}, retrying...\")\n except Exception as e:\n print(f\" Attempt {attempt+1}: {e}, retrying in 10s...\")\n time.sleep(10)\nelse:\n raise RuntimeError(f\"Env not reachable at {ENV_BASE_URL} after 12 attempts\")"
  },
  {
  "cell_type": "markdown",
 
  "execution_count": null,
  "metadata": {},
  "outputs": [],
+ "source": "def run_rollout(model, tokenizer, base_url: str, verbose: bool = False) -> dict:\n \"\"\"Run one full episode and return trajectory + total reward.\"\"\"\n resp = requests.post(f\"{base_url}/reset\", json={}, timeout=30)\n resp.raise_for_status()\n obs = resp.json()\n # OpenEnv wraps observation in {\"observation\": {...}}\n obs_data = obs.get(\"observation\", obs)\n question = obs_data.get(\"question\", \"\")\n context = obs_data.get(\"context\", [])\n max_turns = obs_data.get(\"max_turns\", 3)\n\n total_reward = 0.0\n steps = 0\n parse_fails = 0\n trajectory = []\n\n for turn in range(max_turns):\n context_str = \"\\n\".join(context) if context else \"\"\n user_content = f\"Question: {question}\"\n if context_str:\n user_content += f\"\\n\\n{context_str}\"\n user_content += f\"\\n\\nTurn {turn + 1} of {max_turns}. Respond in JSON.\"\n\n messages = [\n {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n {\"role\": \"user\", \"content\": user_content},\n ]\n prompt = tokenizer.apply_chat_template(\n messages, tokenize=False, add_generation_prompt=True\n )\n inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n\n with torch.no_grad():\n output_ids = model.generate(\n **inputs,\n max_new_tokens=256,\n do_sample=True,\n temperature=0.7,\n pad_token_id=tokenizer.eos_token_id,\n )\n generated = tokenizer.decode(\n output_ids[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True\n )\n\n try:\n action = parse_action(generated)\n except Exception:\n action = PARSE_FAIL_ACTION.copy()\n parse_fails += 1\n\n if turn == max_turns - 1:\n action[\"is_final\"] = True\n\n if verbose:\n print(f\" Turn {turn+1}: is_final={action['is_final']} answer='{action['answer']}' confidence={action['confidence']:.2f}\")\n\n # OpenEnv /step expects {\"action\": {...}}\n step_resp = requests.post(f\"{base_url}/step\", json={\"action\": action}, timeout=30)\n step_resp.raise_for_status()\n step_obs = step_resp.json()\n\n # Response is {\"observation\": {...}, \"reward\": ..., \"done\": ...}\n step_obs_data = step_obs.get(\"observation\", step_obs)\n reward = step_obs.get(\"reward\", 0.0) or 0.0\n done = step_obs.get(\"done\", False)\n context = step_obs_data.get(\"context\", [])\n\n total_reward += reward\n steps += 1\n trajectory.append({\n \"turn\": turn + 1, \"action\": action, \"reward\": reward,\n \"done\": done, \"metadata\": step_obs_data.get(\"metadata\", {})\n })\n\n if done:\n break\n\n return {\n \"question\": question,\n \"total_reward\": total_reward,\n \"steps\": steps,\n \"parse_fails\": parse_fails,\n \"trajectory\": trajectory,\n }\n\n\nprint(\"Rollout function ready.\")"
  },
  {
  "cell_type": "markdown",
 
  "execution_count": null,
  "metadata": {},
  "outputs": [],
+ "source": "import threading\n\n_env_lock = threading.Lock()\n\ndef grpo_reward_fn(completions, prompts=None, **kwargs):\n \"\"\"GRPO reward function: run one rollout per completion, return list of rewards.\"\"\"\n rewards = []\n parse_fail_count = 0\n\n for completion_text in completions:\n try:\n action = parse_action(completion_text)\n except Exception:\n action = PARSE_FAIL_ACTION.copy()\n parse_fail_count += 1\n\n try:\n with _env_lock:\n reset_resp = requests.post(f\"{ENV_BASE_URL}/reset\", json={}, timeout=30)\n reset_resp.raise_for_status()\n obs = reset_resp.json()\n obs_data = obs.get(\"observation\", obs)\n max_turns = obs_data.get(\"max_turns\", 3)\n question = obs_data.get(\"question\", \"\")\n context = obs_data.get(\"context\", [])\n\n total_reward = 0.0\n current_action = action\n\n for turn in range(max_turns):\n if turn == max_turns - 1:\n current_action[\"is_final\"] = True\n\n # OpenEnv /step expects {\"action\": {...}}\n step_resp = requests.post(\n f\"{ENV_BASE_URL}/step\",\n json={\"action\": current_action},\n timeout=30,\n )\n step_resp.raise_for_status()\n step_obs = step_resp.json()\n step_obs_data = step_obs.get(\"observation\", step_obs)\n\n reward = step_obs.get(\"reward\", 0.0) or 0.0\n done = step_obs.get(\"done\", False)\n context = step_obs_data.get(\"context\", [])\n total_reward += reward\n\n if done:\n break\n\n # Subsequent turns: greedy decoding\n context_str = \"\\n\".join(context)\n user_content = f\"Question: {question}\\n\\n{context_str}\\n\\nTurn {turn+2} of {max_turns}. Respond in JSON.\"\n messages = [\n {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n {\"role\": \"user\", \"content\": user_content},\n ]\n next_prompt = tokenizer.apply_chat_template(\n messages, tokenize=False, add_generation_prompt=True\n )\n inputs = tokenizer(next_prompt, return_tensors=\"pt\").to(model.device)\n with torch.no_grad():\n out_ids = model.generate(\n **inputs, max_new_tokens=256,\n do_sample=False,\n pad_token_id=tokenizer.eos_token_id,\n )\n next_text = tokenizer.decode(\n out_ids[0][inputs[\"input_ids\"].shape[1]:], skip_special_tokens=True\n )\n try:\n current_action = parse_action(next_text)\n except Exception:\n current_action = PARSE_FAIL_ACTION.copy()\n\n except Exception as e:\n print(f\" [reward_fn] Episode error: {e}\")\n total_reward = -1.3\n\n rewards.append(total_reward)\n\n if parse_fail_count > 0:\n print(f\" [reward_fn] Parse failures: {parse_fail_count}/{len(completions)}\")\n\n return rewards\n\n\nprint(\"GRPO reward function ready.\")"
  },
  {
  "cell_type": "markdown",
 
  },
  "nbformat": 4,
  "nbformat_minor": 4
+ }