K446 committed on
Commit
114859b
·
1 Parent(s): d50432b

Fix batch/num_generations: 4/4 on 1 GPU, grad_accum=4

Browse files
Files changed (1) hide show
  1. run_training.py +3 -3
run_training.py CHANGED
@@ -235,14 +235,14 @@ def run_grpo_training():
235
  grpo_config = GRPOConfig(
236
  output_dir="training/outputs/grpo_checkpoints",
237
  num_train_epochs=3,
238
- per_device_train_batch_size=4, # must be divisible by num_generations (4)
239
- gradient_accumulation_steps=2,
240
  learning_rate=1e-5,
241
  logging_steps=1,
242
  save_steps=50,
243
  max_prompt_length=1024,
244
  max_completion_length=96,
245
- num_generations=4, # min for meaningful GRPO variance; 2 gives reward_std=0
246
  report_to="none",
247
  remove_unused_columns=False,
248
  bf16=_bf16,
 
235
  grpo_config = GRPOConfig(
236
  output_dir="training/outputs/grpo_checkpoints",
237
  num_train_epochs=3,
238
+ per_device_train_batch_size=4, # must equal num_generations on 1 GPU
239
+ gradient_accumulation_steps=4,
240
  learning_rate=1e-5,
241
  logging_steps=1,
242
  save_steps=50,
243
  max_prompt_length=1024,
244
  max_completion_length=96,
245
+ num_generations=4,
246
  report_to="none",
247
  remove_unused_columns=False,
248
  bf16=_bf16,