Spaces:
Sleeping
Sleeping
feat: A10G-optimised GRPO config — 256 tokens, bf16, 300 samples
Browse files
- ECHO_Training.ipynb +31 -14
ECHO_Training.ipynb
CHANGED
|
@@ -273,28 +273,36 @@
|
|
| 273 |
"print(\" BAD : all rewards exactly -0.5 → stop & report\")\n",
|
| 274 |
"print(\"=\" * 50)"
|
| 275 |
],
|
| 276 |
-
"id": "081d73fd",
|
| 277 |
"execution_count": null,
|
| 278 |
-
"outputs": []
|
|
|
|
| 279 |
},
|
| 280 |
{
|
| 281 |
"cell_type": "code",
|
| 282 |
"metadata": {},
|
| 283 |
"source": [
|
| 284 |
-
"# Configure GRPO training\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
"training_args = GRPOConfig(\n",
|
| 286 |
" output_dir=\"echo_grpo_output\",\n",
|
| 287 |
-
" num_train_epochs=
|
| 288 |
" per_device_train_batch_size=1,\n",
|
| 289 |
-
" gradient_accumulation_steps=8,\n",
|
| 290 |
" learning_rate=2e-5,\n",
|
| 291 |
-
" warmup_steps=
|
| 292 |
-
" logging_steps=
|
| 293 |
-
" save_steps=
|
| 294 |
-
"
|
|
|
|
| 295 |
" report_to=\"none\",\n",
|
| 296 |
-
" max_completion_length=
|
| 297 |
-
" num_generations=4,
|
| 298 |
" temperature=0.8,\n",
|
| 299 |
")\n",
|
| 300 |
"\n",
|
|
@@ -302,13 +310,22 @@
|
|
| 302 |
" model=model,\n",
|
| 303 |
" args=training_args,\n",
|
| 304 |
" reward_funcs=[echo_reward_function],\n",
|
| 305 |
-
" train_dataset=
|
| 306 |
" tokenizer=tokenizer,\n",
|
| 307 |
")\n",
|
| 308 |
"\n",
|
| 309 |
-
"print(\"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
"trainer.train()\n",
|
| 311 |
-
"print(\"Training complete!\")"
|
| 312 |
],
|
| 313 |
"execution_count": null,
|
| 314 |
"outputs": [],
|
|
|
|
| 273 |
"print(\" BAD : all rewards exactly -0.5 → stop & report\")\n",
|
| 274 |
"print(\"=\" * 50)"
|
| 275 |
],
|
|
|
|
| 276 |
"execution_count": null,
|
| 277 |
+
"outputs": [],
|
| 278 |
+
"id": "081d73fd"
|
| 279 |
},
|
| 280 |
{
|
| 281 |
"cell_type": "code",
|
| 282 |
"metadata": {},
|
| 283 |
"source": [
|
| 284 |
+
"# Configure GRPO training — OPTIMIZED for A10G small (~2.5 hrs, ~$3-4 cost)\n",
|
| 285 |
+
"# Hardware: A10G small ($1.05/hr) — 3x faster than T4 for 7B models\n",
|
| 286 |
+
"# max_completion_length=256: enough for reasoning, 2x faster than 512\n",
|
| 287 |
+
"\n",
|
| 288 |
+
"# Rebuild dataset for A10G run\n",
|
| 289 |
+
"dataset_a10g = build_training_dataset(300)\n",
|
| 290 |
+
"print(f\"Dataset: {len(dataset_a10g)} samples\")\n",
|
| 291 |
+
"\n",
|
| 292 |
"training_args = GRPOConfig(\n",
|
| 293 |
" output_dir=\"echo_grpo_output\",\n",
|
| 294 |
+
" num_train_epochs=1,\n",
|
| 295 |
" per_device_train_batch_size=1,\n",
|
| 296 |
+
" gradient_accumulation_steps=8, # effective batch = 8, keep for GRPO stability\n",
|
| 297 |
" learning_rate=2e-5,\n",
|
| 298 |
+
" warmup_steps=20,\n",
|
| 299 |
+
" logging_steps=5,\n",
|
| 300 |
+
" save_steps=50,\n",
|
| 301 |
+
" bf16=True, # A10G supports bfloat16 — better than fp16\n",
|
| 302 |
+
" fp16=False,\n",
|
| 303 |
" report_to=\"none\",\n",
|
| 304 |
+
" max_completion_length=256, # 256 = enough reasoning space, 2x faster than 512\n",
|
| 305 |
+
" num_generations=4, # GRPO group size — do NOT reduce\n",
|
| 306 |
" temperature=0.8,\n",
|
| 307 |
")\n",
|
| 308 |
"\n",
|
|
|
|
| 310 |
" model=model,\n",
|
| 311 |
" args=training_args,\n",
|
| 312 |
" reward_funcs=[echo_reward_function],\n",
|
| 313 |
+
" train_dataset=dataset_a10g,\n",
|
| 314 |
" tokenizer=tokenizer,\n",
|
| 315 |
")\n",
|
| 316 |
"\n",
|
| 317 |
+
"print(\"=\" * 55)\n",
|
| 318 |
+
"print(\"🚀 ECHO GRPO Training — A10G small + 256 tokens\")\n",
|
| 319 |
+
"print(\" 300 samples | 1 epoch | grad_accum=8\")\n",
|
| 320 |
+
"print(\" Estimated: ~2.5 hrs | Cost: ~$3-4\")\n",
|
| 321 |
+
"print(\"=\" * 55)\n",
|
| 322 |
+
"print()\n",
|
| 323 |
+
"print(\"Watch step output — after step 5 you should see:\")\n",
|
| 324 |
+
"print(\" GOOD: rewards mixed between -0.5 and +0.8\")\n",
|
| 325 |
+
"print(\" BAD : all rewards exactly -0.5 → stop & report\")\n",
|
| 326 |
+
"print()\n",
|
| 327 |
"trainer.train()\n",
|
| 328 |
+
"print(\"\\n✅ Training complete!\")"
|
| 329 |
],
|
| 330 |
"execution_count": null,
|
| 331 |
"outputs": [],
|