fix: correct GRPO training hyperparameters to prevent KL explosion
beta=0.01 weakened the KL penalty (the opposite of what was intended), causing the
policy to diverge at step 18 and collapse to a degenerate "learn"×28 output. Fix:
beta=0.1 (stronger constraint), lr=5e-5 (more conservative), max_grad_norm=0.5 (clipping).

Also fix train.py: max_completion_length was 368 (the prompt-pad remainder)
instead of 16, which would have allowed verbose drift in standalone runs.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- training/RhythmEnv_GRPO_Training.ipynb +45 -1
- training/train.py +4 -2
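
For context on the beta fix: in GRPO, beta scales the per-token KL penalty that anchors the policy to the frozen reference model, so a smaller beta means a weaker constraint, not a stronger one. A minimal sketch of the relationship (simplified; the exact loss inside TRL's GRPOTrainer differs in clipping and masking details):

    import torch

    def grpo_loss_sketch(logprobs, ref_logprobs, advantages, beta):
        # k3 KL estimator between the current policy and the reference model
        per_token_kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
        # Policy-gradient term weighted by group-relative advantages
        pg_term = -advantages * logprobs
        # beta=0.01 barely penalizes drift; beta=0.1 pulls the policy back 10x harder
        return (pg_term + beta * per_token_kl).mean()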
training/RhythmEnv_GRPO_Training.ipynb CHANGED

@@ -224,7 +224,51 @@
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source":
+   "source": [
+    "from trl import GRPOConfig, GRPOTrainer\n",
+    "\n",
+    "MAX_STEPS = 500  # Increase to 1000 if time allows\n",
+    "NUM_GENERATIONS = 4\n",
+    "LEARNING_RATE = 5e-5  # Reduced from default — lower lr prevents destabilizing early gradient steps\n",
+    "\n",
+    "max_prompt_length = 400\n",
+    "max_completion_length = 16  # Action names are 3-15 chars\n",
+    "\n",
+    "training_args = GRPOConfig(\n",
+    "    temperature=1.0,\n",
+    "    learning_rate=LEARNING_RATE,\n",
+    "    beta=0.1,  # KL penalty — higher = more conservative, prevents policy drift\n",
+    "    max_grad_norm=0.5,  # Gradient clipping prevents large destabilizing updates\n",
+    "    weight_decay=0.001,\n",
+    "    warmup_ratio=0.1,\n",
+    "    lr_scheduler_type=\"linear\",\n",
+    "    optim=\"adamw_8bit\",\n",
+    "    logging_steps=1,\n",
+    "    per_device_train_batch_size=1,\n",
+    "    gradient_accumulation_steps=4,\n",
+    "    num_generations=NUM_GENERATIONS,\n",
+    "    max_prompt_length=max_prompt_length,\n",
+    "    max_completion_length=max_completion_length,\n",
+    "    max_steps=MAX_STEPS,\n",
+    "    save_steps=100,\n",
+    "    report_to=REPORT_TO,\n",
+    "    output_dir=\"outputs/rhythmenv_trained\",\n",
+    ")\n",
+    "\n",
+    "trainer = GRPOTrainer(\n",
+    "    model=model,\n",
+    "    processing_class=tokenizer,\n",
+    "    reward_funcs=reward_funcs,\n",
+    "    args=training_args,\n",
+    "    train_dataset=dataset,\n",
+    ")\n",
+    "\n",
+    "print(f\"Training config: {MAX_STEPS} steps, {NUM_GENERATIONS} generations, lr={LEARNING_RATE}\")\n",
+    "print(f\"  beta=0.1 (higher KL penalty = more conservative = less policy drift)\")\n",
+    "print(f\"  max_grad_norm=0.5 (gradient clipping for stability)\")\n",
+    "print(f\"  max_completion_length=16 (action names only, no verbose outputs)\")\n",
+    "print(\"Starting training...\")"
+   ]
   },
   {
    "cell_type": "code",
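The reward_funcs wired into GRPOTrainer above are defined earlier in the notebook. For reference, TRL calls each custom reward function with the batch of completions (plus prompts and dataset columns as keyword arguments) and expects one float per completion. A hypothetical example in that shape (the name and scoring rule are illustrative, not the notebook's actual rewards):

    def action_name_reward(completions, **kwargs):
        # Illustrative only: favor short, bare action names (3-15 chars, letters only)
        return [1.0 if 3 <= len(c.strip()) <= 15 and c.strip().isalpha() else 0.0
                for c in completions]
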
training/train.py CHANGED

@@ -32,7 +32,7 @@ def main():
                         help="Maximum training samples")
     parser.add_argument("--num_generations", type=int, default=4,
                         help="Number of completions per prompt for GRPO")
-    parser.add_argument("--learning_rate", type=float, default=
+    parser.add_argument("--learning_rate", type=float, default=5e-5,
                         help="Learning rate")
     parser.add_argument("--output_dir", type=str, default="outputs/rhythmenv_trained",
                         help="Output directory for model and logs")

@@ -128,11 +128,13 @@ def main():
     from trl import GRPOConfig, GRPOTrainer

     max_prompt_length = 400
-    max_completion_length = 368
+    max_completion_length = 16  # Action names are 3-15 chars; cap prevents verbose drift

     training_args = GRPOConfig(
         temperature=1.0,
         learning_rate=args.learning_rate,
+        beta=0.1,  # KL penalty — higher = more conservative, prevents policy drift
+        max_grad_norm=0.5,  # Gradient clipping prevents large destabilizing updates
         weight_decay=0.001,
         warmup_ratio=0.1,
         lr_scheduler_type="linear",