InosLihka Claude Sonnet 4.6 committed on
Commit
2c6ee11
·
1 Parent(s): 26b1e6a

fix: rename kl_coef to beta (correct param name in TRL GRPOConfig)

Browse files
training/RhythmEnv_GRPO_Training.ipynb CHANGED
@@ -224,7 +224,7 @@
224
  "execution_count": null,
225
  "metadata": {},
226
  "outputs": [],
227
- "source": "from trl import GRPOConfig, GRPOTrainer\n\nMAX_STEPS = 500 # Increase to 1000 if time allows\nNUM_GENERATIONS = 4\nLEARNING_RATE = 2e-4\n\nmax_prompt_length = 400\nmax_completion_length = 16 # Action names are 3-15 chars — no need for more\n\ntraining_args = GRPOConfig(\n temperature=1.0,\n learning_rate=LEARNING_RATE,\n kl_coef=0.01, # Default 0.1 caused KL explosion at step 205; 0.01 keeps drift in check\n weight_decay=0.001,\n warmup_ratio=0.1,\n lr_scheduler_type=\"linear\",\n optim=\"adamw_8bit\",\n logging_steps=1,\n per_device_train_batch_size=1,\n gradient_accumulation_steps=4,\n num_generations=NUM_GENERATIONS,\n max_prompt_length=max_prompt_length,\n max_completion_length=max_completion_length,\n max_steps=MAX_STEPS,\n save_steps=100,\n report_to=REPORT_TO,\n output_dir=\"outputs/rhythmenv_trained\",\n)\n\ntrainer = GRPOTrainer(\n model=model,\n processing_class=tokenizer,\n reward_funcs=reward_funcs,\n args=training_args,\n train_dataset=dataset,\n)\n\nprint(f\"Training config: {MAX_STEPS} steps, {NUM_GENERATIONS} generations, lr={LEARNING_RATE}\")\nprint(f\" kl_coef=0.01 (reduced from default 0.1 to prevent KL explosion)\")\nprint(f\" max_completion_length=16 (action names only, no verbose outputs)\")\nprint(\"Starting training...\")"
228
  },
229
  {
230
  "cell_type": "code",
 
224
  "execution_count": null,
225
  "metadata": {},
226
  "outputs": [],
227
+ "source": "from trl import GRPOConfig, GRPOTrainer\n\nMAX_STEPS = 500 # Increase to 1000 if time allows\nNUM_GENERATIONS = 4\nLEARNING_RATE = 2e-4\n\nmax_prompt_length = 400\nmax_completion_length = 16 # Action names are 3-15 chars — no need for more\n\ntraining_args = GRPOConfig(\n temperature=1.0,\n learning_rate=LEARNING_RATE,\n beta=0.01, # KL penalty coefficient (called beta in TRL, default 0.04 causes KL explosion)\n weight_decay=0.001,\n warmup_ratio=0.1,\n lr_scheduler_type=\"linear\",\n optim=\"adamw_8bit\",\n logging_steps=1,\n per_device_train_batch_size=1,\n gradient_accumulation_steps=4,\n num_generations=NUM_GENERATIONS,\n max_prompt_length=max_prompt_length,\n max_completion_length=max_completion_length,\n max_steps=MAX_STEPS,\n save_steps=100,\n report_to=REPORT_TO,\n output_dir=\"outputs/rhythmenv_trained\",\n)\n\ntrainer = GRPOTrainer(\n model=model,\n processing_class=tokenizer,\n reward_funcs=reward_funcs,\n args=training_args,\n train_dataset=dataset,\n)\n\nprint(f\"Training config: {MAX_STEPS} steps, {NUM_GENERATIONS} generations, lr={LEARNING_RATE}\")\nprint(f\" beta=0.01 (KL penalty — reduced from default to prevent policy drift)\")\nprint(f\" max_completion_length=16 (action names only, no verbose outputs)\")\nprint(\"Starting training...\")"
228
  },
229
  {
230
  "cell_type": "code",