anuragredbus committed
Commit 1d82571 · Parent: 9536a33

train_grpo: prebuilt flash-attn wheel + verbose training rollouts


- Install flash-attn 2.7.4.post1 via a prebuilt wheel matched to
torch 2.5 / py3.11 / cu12 instead of building from source. The HF
Job container (pytorch/pytorch:*-runtime) has no nvcc, so the
source build always fails with CUDA_HOME unset. Triple fallback:
prebuilt wheel -> source build -> skip and use sdpa (a sketch of the
runtime backend selection follows below).
- Flip verbose=False to verbose=True in the training-loop rollouts so
per-day generation progress is visible while a round runs (~minutes
per round at 6 eps x 15 days x 2 phases); a generic sketch of a
verbose-gated rollout loop follows the diff.
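
The diff below only touches the install cell; the cell that actually picks
the attention backend at model-load time is not part of this change. A
minimal sketch of that runtime selection, assuming a transformers-style
load (the model id here is hypothetical, not from the notebook):

    # Pick the attention implementation based on whether flash-attn imported.
    import torch
    from transformers import AutoModelForCausalLM

    try:
        import flash_attn  # noqa: F401 -- present only if the wheel/source install succeeded
        attn_impl = "flash_attention_2"
    except ImportError:
        attn_impl = "sdpa"  # PyTorch's built-in scaled-dot-product attention

    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-1.5B-Instruct",  # hypothetical model id
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_impl,
    )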

Made-with: Cursor

Files changed (1): training/train_grpo.ipynb (+5 -2)
training/train_grpo.ipynb CHANGED
@@ -33,7 +33,10 @@
  "!pip install -q matplotlib pandas\n",
  "!pip install -q \"typing_extensions>=4.13.0\" pydantic httpx\n",
  "!pip install -q \"openenv-core[core]>=0.2.2\"\n",
- "!pip install -q flash-attn --no-build-isolation || echo \"flash-attn install skipped; will use sdpa\""
+ "# flash-attn: install prebuilt wheel matched to torch 2.5 + py3.11 + cu12 (HF Job container).\n",
+ "# This avoids the from-source build that fails when the container has no nvcc / CUDA_HOME.\n",
+ "# Falls back to sdpa if the wheel install fails (e.g. on a different env).\n",
+ "!pip install -q \"https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.5cxx11abiFALSE-cp311-cp311-linux_x86_64.whl\" || pip install -q flash-attn --no-build-isolation || echo \"flash-attn install skipped; will use sdpa\""
  ],
  "execution_count": null,
  "outputs": []
@@ -881,7 +884,7 @@
  " tasks_seeds = [(TASKS[ep % len(TASKS)], 42 + ep + round_idx * 10) for ep in range(EPISODES_PER_ROUND)]\n",
  " t_roll = time.time()\n",
  " results = run_llm_episodes_batched(\n",
- " peft_model, tokenizer, tasks_seeds, verbose=False, eval=False,\n",
+ " peft_model, tokenizer, tasks_seeds, verbose=True, eval=False,\n",
  " system=sys_prompt, hint_peak_hours=use_hint, reward_mode=reward_mode,\n",
  " log_tag=f\"{phase_name}_r{round_idx}\",\n",
  " )\n",