Prasham.Jain Claude Sonnet 4.6 commited on
Commit
421885d
Β·
1 Parent(s): b78f85d

fix(spaces): switch training Space to A10G Small, tune notebook for 24 GB

Browse files

A10G Large has been unavailable for 40+ min. A10G Small (24 GB) is more
reliably allocated and fully fits Qwen3-4B-bnb-4bit 4-bit + LoRA + GRPO
with our current hyperparams.

push_to_hf.sh:
- hardware: a10g-small in training Space YAML
- app_port: 7860 (JupyterLab)
- app_port: 8000 in env Space YAML (fixes "Starting" loop)

notebook:
- per_device_batch_size 4β†’2 (SFT, fits 24 GB)
- num_generations 4β†’2 (GRPO, halves peak VRAM)
- max_completion_length 256β†’128
- max_prompt_length 2048β†’1536
- all A10G Large references updated to A10G Small

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. notebooks/train_grpo.ipynb +14 -14
  2. push_to_hf.sh +2 -0
notebooks/train_grpo.ipynb CHANGED
@@ -7,14 +7,14 @@
7
  "source": [
8
  "# CI-Triage-Env β€” GRPO Training\n",
9
  "\n",
10
- "**Hardware target**: A10G Large (46 GB VRAM, 12 vCPU) β€” HuggingFace Space\n",
11
  "\n",
12
  "**Set these as Space secrets** (Settings β†’ Variables and secrets):\n",
13
  "- `HF_TOKEN` β€” HuggingFace write token\n",
14
  "- `HF_USERNAME` β€” your HF username\n",
15
  "- `WANDB_API_KEY` β€” Weights & Biases key (get free at wandb.ai)\n",
16
  "\n",
17
- "**Time budget**: SFTβ‰ˆ45 min + GRPOβ‰ˆ90 min = ~2.5 hours on A10G Large."
18
  ]
19
  },
20
  {
@@ -143,7 +143,7 @@
143
  "metadata": {},
144
  "outputs": [],
145
  "source": [
146
- "# Cell 5 β€” SFT warmstart (~45 min on A10G Large)\n",
147
  "from pathlib import Path\n",
148
  "from ci_triage_env.training.sft import run_sft\n",
149
  "\n",
@@ -156,7 +156,7 @@
156
  " dataset_path=str(SFT_DS_DIR),\n",
157
  " output_dir=str(SFT_CKPT),\n",
158
  " num_epochs=2,\n",
159
- " per_device_batch_size=4,\n",
160
  " gradient_accumulation_steps=4,\n",
161
  " )\n",
162
  " print(f'SFT done β†’ {SFT_CKPT}')\n",
@@ -167,9 +167,9 @@
167
  " repo_id=MODEL_REPO + '-sft',\n",
168
  " repo_type='model',\n",
169
  " token=HF_TOKEN,\n",
170
- " commit_message='SFT warmstart checkpoint (Qwen3.5-4B + LoRA)',\n",
171
  " )\n",
172
- " print(f'SFT checkpoint pushed to {MODEL_REPO}-sft')"
173
  ]
174
  },
175
  {
@@ -179,7 +179,7 @@
179
  "metadata": {},
180
  "outputs": [],
181
  "source": [
182
- "# Cell 6 β€” GRPO fine-tuning (~90 min for 100 steps on A10G Large)\n",
183
  "from pathlib import Path\n",
184
  "from ci_triage_env.training.mock_env_client import MockEnvClient\n",
185
  "from ci_triage_env.training.grpo import run_grpo\n",
@@ -204,9 +204,9 @@
204
  " 'gradient_accumulation_steps': 4,\n",
205
  " 'learning_rate': 5e-6,\n",
206
  " 'kl_coef': 0.04,\n",
207
- " 'num_generations': 4,\n",
208
- " 'max_prompt_length': 2048,\n",
209
- " 'max_completion_length': 256,\n",
210
  " 'temperature': 0.8,\n",
211
  " 'top_p': 0.95,\n",
212
  " 'logging_steps': 5,\n",
@@ -214,7 +214,7 @@
214
  " 'report_to': 'wandb' if WANDB_KEY else 'none',\n",
215
  " },\n",
216
  ")\n",
217
- "print(f'GRPO done β†’ {GRPO_CKPT}')"
218
  ]
219
  },
220
  {
@@ -235,9 +235,9 @@
235
  " repo_id=MODEL_REPO,\n",
236
  " repo_type='model',\n",
237
  " token=HF_TOKEN,\n",
238
- " commit_message=f'GRPO-trained adapter β€” {GRPO_STEPS} steps on A10G Large',\n",
239
  ")\n",
240
- "print(f'Final model: https://huggingface.co/{MODEL_REPO}')"
241
  ]
242
  }
243
  ],
@@ -254,4 +254,4 @@
254
  },
255
  "nbformat": 4,
256
  "nbformat_minor": 5
257
- }
 
7
  "source": [
8
  "# CI-Triage-Env β€” GRPO Training\n",
9
  "\n",
10
+ "**Hardware target**: A10G Small (24 GB VRAM, 4 vCPU) β€” HuggingFace Space\n",
11
  "\n",
12
  "**Set these as Space secrets** (Settings β†’ Variables and secrets):\n",
13
  "- `HF_TOKEN` β€” HuggingFace write token\n",
14
  "- `HF_USERNAME` β€” your HF username\n",
15
  "- `WANDB_API_KEY` β€” Weights & Biases key (get free at wandb.ai)\n",
16
  "\n",
17
+ "**Time budget**: SFTβ‰ˆ50 min + GRPOβ‰ˆ90 min = ~2.5 hours on A10G Small.\n"
18
  ]
19
  },
20
  {
 
143
  "metadata": {},
144
  "outputs": [],
145
  "source": [
146
+ "# Cell 5 β€” SFT warmstart (~50 min on A10G Small)\n",
147
  "from pathlib import Path\n",
148
  "from ci_triage_env.training.sft import run_sft\n",
149
  "\n",
 
156
  " dataset_path=str(SFT_DS_DIR),\n",
157
  " output_dir=str(SFT_CKPT),\n",
158
  " num_epochs=2,\n",
159
+ " per_device_batch_size=2,\n",
160
  " gradient_accumulation_steps=4,\n",
161
  " )\n",
162
  " print(f'SFT done β†’ {SFT_CKPT}')\n",
 
167
  " repo_id=MODEL_REPO + '-sft',\n",
168
  " repo_type='model',\n",
169
  " token=HF_TOKEN,\n",
170
+ " commit_message='SFT warmstart checkpoint (Qwen3-4B + LoRA)',\n",
171
  " )\n",
172
+ " print(f'SFT checkpoint pushed to {MODEL_REPO}-sft')\n"
173
  ]
174
  },
175
  {
 
179
  "metadata": {},
180
  "outputs": [],
181
  "source": [
182
+ "# Cell 6 β€” GRPO fine-tuning (~90 min for 100 steps on A10G Small)\n",
183
  "from pathlib import Path\n",
184
  "from ci_triage_env.training.mock_env_client import MockEnvClient\n",
185
  "from ci_triage_env.training.grpo import run_grpo\n",
 
204
  " 'gradient_accumulation_steps': 4,\n",
205
  " 'learning_rate': 5e-6,\n",
206
  " 'kl_coef': 0.04,\n",
207
+ " 'num_generations': 2,\n",
208
+ " 'max_prompt_length': 1536,\n",
209
+ " 'max_completion_length': 128,\n",
210
  " 'temperature': 0.8,\n",
211
  " 'top_p': 0.95,\n",
212
  " 'logging_steps': 5,\n",
 
214
  " 'report_to': 'wandb' if WANDB_KEY else 'none',\n",
215
  " },\n",
216
  ")\n",
217
+ "print(f'GRPO done β†’ {GRPO_CKPT}')\n"
218
  ]
219
  },
220
  {
 
235
  " repo_id=MODEL_REPO,\n",
236
  " repo_type='model',\n",
237
  " token=HF_TOKEN,\n",
238
+ " commit_message=f'GRPO-trained adapter β€” {GRPO_STEPS} steps on A10G Small',\n",
239
  ")\n",
240
+ "print(f'Final model: https://huggingface.co/{MODEL_REPO}')\n"
241
  ]
242
  }
243
  ],
 
254
  },
255
  "nbformat": 4,
256
  "nbformat_minor": 5
257
+ }
push_to_hf.sh CHANGED
@@ -73,6 +73,8 @@ emoji: πŸ‹οΈ
73
  colorFrom: yellow
74
  colorTo: red
75
  sdk: docker
 
 
76
  pinned: false
77
  ---"
78
 
 
73
  colorFrom: yellow
74
  colorTo: red
75
  sdk: docker
76
+ app_port: 7860
77
+ hardware: a10g-small
78
  pinned: false
79
  ---"
80