InosLihka (Claude Sonnet 4.6) committed
Commit fb112e4 · 1 Parent(s): 8a56903

fix: correct GRPO training hyperparameters to prevent KL explosion

beta=0.01 weakened the KL penalty relative to TRL's 0.04 default (the
opposite of what was intended), causing the policy to diverge at step 18
and collapse to emitting "learn" ×28. Fix: beta=0.1 (stronger constraint),
lr=5e-5 (more conservative), max_grad_norm=0.5 (gradient clipping).
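
The direction of the beta knob is easy to check numerically. The sketch below
(made-up log-probs and a simplified REINFORCE-style surrogate, not TRL's
internals) shows that beta=0.01 applies a tenth of the KL pull that beta=0.1
does, which is why the smaller value loosened the constraint instead of
tightening it:

```python
import torch

# Illustrative sketch only, not TRL's internal loss. All values are made up.
logp_new = torch.tensor([-1.2, -0.8, -2.0])  # current policy log-probs (hypothetical)
logp_ref = torch.tensor([-1.0, -1.0, -1.0])  # frozen reference log-probs (hypothetical)
advantage = 1.0                              # group-relative advantage (hypothetical)

# k3 KL estimator, the low-variance form commonly used in GRPO-style trainers
kl = torch.exp(logp_ref - logp_new) - (logp_ref - logp_new) - 1.0

for beta in (0.01, 0.1):
    loss = -(advantage * logp_new - beta * kl).mean()
    print(f"beta={beta}: mean KL term = {(beta * kl).mean():.4f}, loss = {loss:.4f}")
# The KL term under beta=0.01 is 10x weaker, so the policy is 10x freer to drift.
```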

Also fix train.py: max_completion_length was 368 (the leftover
max_seq_length - max_prompt_length) instead of 16, which would have
allowed verbose drift in standalone runs.
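
The 16-token cap can be validated before a standalone run with a quick check
like the one below; the action list and tokenizer are placeholders, not the
repo's actual ones:

```python
from transformers import AutoTokenizer

# Hypothetical sanity check: substitute the environment's real action list
# and the model's own tokenizer before relying on the result.
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer
actions = ["rest", "practice_slow", "perform_song", "learn"]  # illustrative names, 3-15 chars
worst = max(len(tokenizer.encode(a, add_special_tokens=False)) for a in actions)
print(f"longest action: {worst} tokens")
assert worst <= 16, "raise max_completion_length above the longest action"
```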

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

training/RhythmEnv_GRPO_Training.ipynb CHANGED
@@ -224,7 +224,51 @@
  "execution_count": null,
  "metadata": {},
  "outputs": [],
- "source": "from trl import GRPOConfig, GRPOTrainer\n\nMAX_STEPS = 500 # Increase to 1000 if time allows\nNUM_GENERATIONS = 4\nLEARNING_RATE = 2e-4\n\nmax_prompt_length = 400\nmax_completion_length = 16 # Action names are 3-15 chars — no need for more\n\ntraining_args = GRPOConfig(\n temperature=1.0,\n learning_rate=LEARNING_RATE,\n beta=0.01, # KL penalty coefficient (called beta in TRL, default 0.04 causes KL explosion)\n weight_decay=0.001,\n warmup_ratio=0.1,\n lr_scheduler_type=\"linear\",\n optim=\"adamw_8bit\",\n logging_steps=1,\n per_device_train_batch_size=1,\n gradient_accumulation_steps=4,\n num_generations=NUM_GENERATIONS,\n max_prompt_length=max_prompt_length,\n max_completion_length=max_completion_length,\n max_steps=MAX_STEPS,\n save_steps=100,\n report_to=REPORT_TO,\n output_dir=\"outputs/rhythmenv_trained\",\n)\n\ntrainer = GRPOTrainer(\n model=model,\n processing_class=tokenizer,\n reward_funcs=reward_funcs,\n args=training_args,\n train_dataset=dataset,\n)\n\nprint(f\"Training config: {MAX_STEPS} steps, {NUM_GENERATIONS} generations, lr={LEARNING_RATE}\")\nprint(f\" beta=0.01 (KL penalty — reduced from default to prevent policy drift)\")\nprint(f\" max_completion_length=16 (action names only, no verbose outputs)\")\nprint(\"Starting training...\")"
+ "source": [
+ "from trl import GRPOConfig, GRPOTrainer\n",
+ "\n",
+ "MAX_STEPS = 500 # Increase to 1000 if time allows\n",
+ "NUM_GENERATIONS = 4\n",
+ "LEARNING_RATE = 5e-5 # Reduced from default — lower lr prevents destabilizing early gradient steps\n",
+ "\n",
+ "max_prompt_length = 400\n",
+ "max_completion_length = 16 # Action names are 3-15 chars\n",
+ "\n",
+ "training_args = GRPOConfig(\n",
+ " temperature=1.0,\n",
+ " learning_rate=LEARNING_RATE,\n",
+ " beta=0.1, # KL penalty — higher = more conservative, prevents policy drift\n",
+ " max_grad_norm=0.5, # Gradient clipping prevents large destabilizing updates\n",
+ " weight_decay=0.001,\n",
+ " warmup_ratio=0.1,\n",
+ " lr_scheduler_type=\"linear\",\n",
+ " optim=\"adamw_8bit\",\n",
+ " logging_steps=1,\n",
+ " per_device_train_batch_size=1,\n",
+ " gradient_accumulation_steps=4,\n",
+ " num_generations=NUM_GENERATIONS,\n",
+ " max_prompt_length=max_prompt_length,\n",
+ " max_completion_length=max_completion_length,\n",
+ " max_steps=MAX_STEPS,\n",
+ " save_steps=100,\n",
+ " report_to=REPORT_TO,\n",
+ " output_dir=\"outputs/rhythmenv_trained\",\n",
+ ")\n",
+ "\n",
+ "trainer = GRPOTrainer(\n",
+ " model=model,\n",
+ " processing_class=tokenizer,\n",
+ " reward_funcs=reward_funcs,\n",
+ " args=training_args,\n",
+ " train_dataset=dataset,\n",
+ ")\n",
+ "\n",
+ "print(f\"Training config: {MAX_STEPS} steps, {NUM_GENERATIONS} generations, lr={LEARNING_RATE}\")\n",
+ "print(f\" beta=0.1 (higher KL penalty = more conservative = less policy drift)\")\n",
+ "print(f\" max_grad_norm=0.5 (gradient clipping for stability)\")\n",
+ "print(f\" max_completion_length=16 (action names only, no verbose outputs)\")\n",
+ "print(\"Starting training...\")"
+ ]
  },
  {
  "cell_type": "code",
training/train.py CHANGED
@@ -32,7 +32,7 @@ def main():
                         help="Maximum training samples")
     parser.add_argument("--num_generations", type=int, default=4,
                         help="Number of completions per prompt for GRPO")
-    parser.add_argument("--learning_rate", type=float, default=2e-4,
+    parser.add_argument("--learning_rate", type=float, default=5e-5,
                         help="Learning rate")
     parser.add_argument("--output_dir", type=str, default="outputs/rhythmenv_trained",
                         help="Output directory for model and logs")
@@ -128,11 +128,13 @@ def main():
     from trl import GRPOConfig, GRPOTrainer
 
     max_prompt_length = 400
-    max_completion_length = max_seq_length - max_prompt_length
+    max_completion_length = 16  # Action names are 3-15 chars; cap prevents verbose drift
 
     training_args = GRPOConfig(
         temperature=1.0,
         learning_rate=args.learning_rate,
+        beta=0.1,  # KL penalty — higher = more conservative, prevents policy drift
+        max_grad_norm=0.5,  # Gradient clipping prevents large destabilizing updates
         weight_decay=0.001,
         warmup_ratio=0.1,
         lr_scheduler_type="linear",