Spaces:
Sleeping
Sleeping
Prasham.Jain Claude Sonnet 4.6 commited on
Commit ·
421885d
1
Parent(s): b78f85d
fix(spaces): switch training Space to A10G Small, tune notebook for 24 GB
Browse filesA10G Large has been unavailable for 40+ min. A10G Small (24 GB) is more
reliably allocated and fully fits Qwen3-4B-bnb-4bit 4-bit + LoRA + GRPO
with our current hyperparams.
push_to_hf.sh:
- hardware: a10g-small in training Space YAML
- app_port: 7860 (JupyterLab)
- app_port: 8000 in env Space YAML (fixes "Starting" loop)
notebook:
- per_device_batch_size 4→2 (SFT, fits 24 GB)
- num_generations 4→2 (GRPO, halves peak VRAM)
- max_completion_length 256→128
- max_prompt_length 2048→1536
- all A10G Large references updated to A10G Small
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- notebooks/train_grpo.ipynb +14 -14
- push_to_hf.sh +2 -0
notebooks/train_grpo.ipynb
CHANGED
|
@@ -7,14 +7,14 @@
|
|
| 7 |
"source": [
|
| 8 |
"# CI-Triage-Env — GRPO Training\n",
|
| 9 |
"\n",
|
| 10 |
-
"**Hardware target**: A10G
|
| 11 |
"\n",
|
| 12 |
"**Set these as Space secrets** (Settings → Variables and secrets):\n",
|
| 13 |
"- `HF_TOKEN` β HuggingFace write token\n",
|
| 14 |
"- `HF_USERNAME` β your HF username\n",
|
| 15 |
"- `WANDB_API_KEY` β Weights & Biases key (get free at wandb.ai)\n",
|
| 16 |
"\n",
|
| 17 |
-
"**Time budget**: SFTβ
|
| 18 |
]
|
| 19 |
},
|
| 20 |
{
|
|
@@ -143,7 +143,7 @@
|
|
| 143 |
"metadata": {},
|
| 144 |
"outputs": [],
|
| 145 |
"source": [
|
| 146 |
-
"# Cell 5 β SFT warmstart (~
|
| 147 |
"from pathlib import Path\n",
|
| 148 |
"from ci_triage_env.training.sft import run_sft\n",
|
| 149 |
"\n",
|
|
@@ -156,7 +156,7 @@
|
|
| 156 |
" dataset_path=str(SFT_DS_DIR),\n",
|
| 157 |
" output_dir=str(SFT_CKPT),\n",
|
| 158 |
" num_epochs=2,\n",
|
| 159 |
-
" per_device_batch_size=
|
| 160 |
" gradient_accumulation_steps=4,\n",
|
| 161 |
" )\n",
|
| 162 |
" print(f'SFT done β {SFT_CKPT}')\n",
|
|
@@ -167,9 +167,9 @@
|
|
| 167 |
" repo_id=MODEL_REPO + '-sft',\n",
|
| 168 |
" repo_type='model',\n",
|
| 169 |
" token=HF_TOKEN,\n",
|
| 170 |
-
" commit_message='SFT warmstart checkpoint (Qwen3
|
| 171 |
" )\n",
|
| 172 |
-
" print(f'SFT checkpoint pushed to {MODEL_REPO}-sft')"
|
| 173 |
]
|
| 174 |
},
|
| 175 |
{
|
|
@@ -179,7 +179,7 @@
|
|
| 179 |
"metadata": {},
|
| 180 |
"outputs": [],
|
| 181 |
"source": [
|
| 182 |
-
"# Cell 6 β GRPO fine-tuning (~90 min for 100 steps on A10G
|
| 183 |
"from pathlib import Path\n",
|
| 184 |
"from ci_triage_env.training.mock_env_client import MockEnvClient\n",
|
| 185 |
"from ci_triage_env.training.grpo import run_grpo\n",
|
|
@@ -204,9 +204,9 @@
|
|
| 204 |
" 'gradient_accumulation_steps': 4,\n",
|
| 205 |
" 'learning_rate': 5e-6,\n",
|
| 206 |
" 'kl_coef': 0.04,\n",
|
| 207 |
-
" 'num_generations':
|
| 208 |
-
" 'max_prompt_length':
|
| 209 |
-
" 'max_completion_length':
|
| 210 |
" 'temperature': 0.8,\n",
|
| 211 |
" 'top_p': 0.95,\n",
|
| 212 |
" 'logging_steps': 5,\n",
|
|
@@ -214,7 +214,7 @@
|
|
| 214 |
" 'report_to': 'wandb' if WANDB_KEY else 'none',\n",
|
| 215 |
" },\n",
|
| 216 |
")\n",
|
| 217 |
-
"print(f'GRPO done β {GRPO_CKPT}')"
|
| 218 |
]
|
| 219 |
},
|
| 220 |
{
|
|
@@ -235,9 +235,9 @@
|
|
| 235 |
" repo_id=MODEL_REPO,\n",
|
| 236 |
" repo_type='model',\n",
|
| 237 |
" token=HF_TOKEN,\n",
|
| 238 |
-
" commit_message=f'GRPO-trained adapter β {GRPO_STEPS} steps on A10G
|
| 239 |
")\n",
|
| 240 |
-
"print(f'Final model: https://huggingface.co/{MODEL_REPO}')"
|
| 241 |
]
|
| 242 |
}
|
| 243 |
],
|
|
@@ -254,4 +254,4 @@
|
|
| 254 |
},
|
| 255 |
"nbformat": 4,
|
| 256 |
"nbformat_minor": 5
|
| 257 |
-
}
|
|
|
|
| 7 |
"source": [
|
| 8 |
"# CI-Triage-Env β GRPO Training\n",
|
| 9 |
"\n",
|
| 10 |
+
"**Hardware target**: A10G Small (24 GB VRAM, 4 vCPU) — HuggingFace Space\n",
|
| 11 |
"\n",
|
| 12 |
"**Set these as Space secrets** (Settings → Variables and secrets):\n",
|
| 13 |
"- `HF_TOKEN` β HuggingFace write token\n",
|
| 14 |
"- `HF_USERNAME` β your HF username\n",
|
| 15 |
"- `WANDB_API_KEY` β Weights & Biases key (get free at wandb.ai)\n",
|
| 16 |
"\n",
|
| 17 |
+
"**Time budget**: SFT≈50 min + GRPO≈90 min = ~2.5 hours on A10G Small.\n"
|
| 18 |
]
|
| 19 |
},
|
| 20 |
{
|
|
|
|
| 143 |
"metadata": {},
|
| 144 |
"outputs": [],
|
| 145 |
"source": [
|
| 146 |
+
"# Cell 5 — SFT warmstart (~50 min on A10G Small)\n",
|
| 147 |
"from pathlib import Path\n",
|
| 148 |
"from ci_triage_env.training.sft import run_sft\n",
|
| 149 |
"\n",
|
|
|
|
| 156 |
" dataset_path=str(SFT_DS_DIR),\n",
|
| 157 |
" output_dir=str(SFT_CKPT),\n",
|
| 158 |
" num_epochs=2,\n",
|
| 159 |
+
" per_device_batch_size=2,\n",
|
| 160 |
" gradient_accumulation_steps=4,\n",
|
| 161 |
" )\n",
|
| 162 |
" print(f'SFT done β {SFT_CKPT}')\n",
|
|
|
|
| 167 |
" repo_id=MODEL_REPO + '-sft',\n",
|
| 168 |
" repo_type='model',\n",
|
| 169 |
" token=HF_TOKEN,\n",
|
| 170 |
+
" commit_message='SFT warmstart checkpoint (Qwen3-4B + LoRA)',\n",
|
| 171 |
" )\n",
|
| 172 |
+
" print(f'SFT checkpoint pushed to {MODEL_REPO}-sft')\n"
|
| 173 |
]
|
| 174 |
},
|
| 175 |
{
|
|
|
|
| 179 |
"metadata": {},
|
| 180 |
"outputs": [],
|
| 181 |
"source": [
|
| 182 |
+
"# Cell 6 — GRPO fine-tuning (~90 min for 100 steps on A10G Small)\n",
|
| 183 |
"from pathlib import Path\n",
|
| 184 |
"from ci_triage_env.training.mock_env_client import MockEnvClient\n",
|
| 185 |
"from ci_triage_env.training.grpo import run_grpo\n",
|
|
|
|
| 204 |
" 'gradient_accumulation_steps': 4,\n",
|
| 205 |
" 'learning_rate': 5e-6,\n",
|
| 206 |
" 'kl_coef': 0.04,\n",
|
| 207 |
+
" 'num_generations': 2,\n",
|
| 208 |
+
" 'max_prompt_length': 1536,\n",
|
| 209 |
+
" 'max_completion_length': 128,\n",
|
| 210 |
" 'temperature': 0.8,\n",
|
| 211 |
" 'top_p': 0.95,\n",
|
| 212 |
" 'logging_steps': 5,\n",
|
|
|
|
| 214 |
" 'report_to': 'wandb' if WANDB_KEY else 'none',\n",
|
| 215 |
" },\n",
|
| 216 |
")\n",
|
| 217 |
+
"print(f'GRPO done → {GRPO_CKPT}')\n"
|
| 218 |
]
|
| 219 |
},
|
| 220 |
{
|
|
|
|
| 235 |
" repo_id=MODEL_REPO,\n",
|
| 236 |
" repo_type='model',\n",
|
| 237 |
" token=HF_TOKEN,\n",
|
| 238 |
+
" commit_message=f'GRPO-trained adapter β {GRPO_STEPS} steps on A10G Small',\n",
|
| 239 |
")\n",
|
| 240 |
+
"print(f'Final model: https://huggingface.co/{MODEL_REPO}')\n"
|
| 241 |
]
|
| 242 |
}
|
| 243 |
],
|
|
|
|
| 254 |
},
|
| 255 |
"nbformat": 4,
|
| 256 |
"nbformat_minor": 5
|
| 257 |
+
}
|
push_to_hf.sh
CHANGED
|
@@ -73,6 +73,8 @@ emoji: ποΈ
|
|
| 73 |
colorFrom: yellow
|
| 74 |
colorTo: red
|
| 75 |
sdk: docker
|
|
|
|
|
|
|
| 76 |
pinned: false
|
| 77 |
---"
|
| 78 |
|
|
|
|
| 73 |
colorFrom: yellow
|
| 74 |
colorTo: red
|
| 75 |
sdk: docker
|
| 76 |
+
app_port: 7860
|
| 77 |
+
hardware: a10g-small
|
| 78 |
pinned: false
|
| 79 |
---"
|
| 80 |
|