Spaces:
Sleeping
Sleeping
fix: rename kl_coef to beta (correct param name in TRL GRPOConfig)
Browse files
training/RhythmEnv_GRPO_Training.ipynb
CHANGED
|
@@ -224,7 +224,7 @@
|
|
| 224 |
"execution_count": null,
|
| 225 |
"metadata": {},
|
| 226 |
"outputs": [],
|
| 227 |
-
"source": "from trl import GRPOConfig, GRPOTrainer\n\nMAX_STEPS = 500 # Increase to 1000 if time allows\nNUM_GENERATIONS = 4\nLEARNING_RATE = 2e-4\n\nmax_prompt_length = 400\nmax_completion_length = 16 # Action names are 3-15 chars — no need for more\n\ntraining_args = GRPOConfig(\n temperature=1.0,\n learning_rate=LEARNING_RATE,\n
|
| 228 |
},
|
| 229 |
{
|
| 230 |
"cell_type": "code",
|
|
|
|
| 224 |
"execution_count": null,
|
| 225 |
"metadata": {},
|
| 226 |
"outputs": [],
|
| 227 |
+
"source": "from trl import GRPOConfig, GRPOTrainer\n\nMAX_STEPS = 500 # Increase to 1000 if time allows\nNUM_GENERATIONS = 4\nLEARNING_RATE = 2e-4\n\nmax_prompt_length = 400\nmax_completion_length = 16 # Action names are 3-15 chars — no need for more\n\ntraining_args = GRPOConfig(\n temperature=1.0,\n learning_rate=LEARNING_RATE,\n beta=0.01, # KL penalty coefficient (called beta in TRL, default 0.04 causes KL explosion)\n weight_decay=0.001,\n warmup_ratio=0.1,\n lr_scheduler_type=\"linear\",\n optim=\"adamw_8bit\",\n logging_steps=1,\n per_device_train_batch_size=1,\n gradient_accumulation_steps=4,\n num_generations=NUM_GENERATIONS,\n max_prompt_length=max_prompt_length,\n max_completion_length=max_completion_length,\n max_steps=MAX_STEPS,\n save_steps=100,\n report_to=REPORT_TO,\n output_dir=\"outputs/rhythmenv_trained\",\n)\n\ntrainer = GRPOTrainer(\n model=model,\n processing_class=tokenizer,\n reward_funcs=reward_funcs,\n args=training_args,\n train_dataset=dataset,\n)\n\nprint(f\"Training config: {MAX_STEPS} steps, {NUM_GENERATIONS} generations, lr={LEARNING_RATE}\")\nprint(f\" beta=0.01 (KL penalty — reduced from default to prevent policy drift)\")\nprint(f\" max_completion_length=16 (action names only, no verbose outputs)\")\nprint(\"Starting training...\")"
|
| 228 |
},
|
| 229 |
{
|
| 230 |
"cell_type": "code",
|