K446 committed on
Commit e1ab78c · 1 Parent(s): 7be88b4

Fix GRPO training: reward variance, batch/gen alignment, generation config

Files changed (2):
  1. run_training.py +1 -1
  2. training/train_grpo.py +2 -2
run_training.py CHANGED
@@ -235,7 +235,7 @@ def run_grpo_training():
     grpo_config = GRPOConfig(
         output_dir="training/outputs/grpo_checkpoints",
         num_train_epochs=3,
-        per_device_train_batch_size=2,
+        per_device_train_batch_size=4,  # must be divisible by num_generations (4)
         gradient_accumulation_steps=2,
         learning_rate=1e-5,
         logging_steps=1,
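Why this hunk matters: GRPO normalizes rewards within a group of num_generations completions sampled per prompt, so each batch must split evenly into whole groups or the reward variance per group is undefined. Below is a minimal sketch of the configuration this change implies, assuming TRL's GRPOConfig and num_generations=4 (the inline comment names that value; the setting itself is not visible in this diff):

# Sketch only; num_generations=4 is an assumption taken from the diff comment.
from trl import GRPOConfig

config = GRPOConfig(
    output_dir="training/outputs/grpo_checkpoints",
    num_train_epochs=3,
    per_device_train_batch_size=4,   # one full group of generations per device
    gradient_accumulation_steps=2,
    num_generations=4,               # completions sampled per prompt (assumed)
    learning_rate=1e-5,
    logging_steps=1,
)

# GRPO computes advantages from the reward spread inside each group of
# num_generations completions, so the batch must divide evenly into groups.
assert config.per_device_train_batch_size % config.num_generations == 0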
training/train_grpo.py CHANGED
@@ -557,8 +557,8 @@ def train_grpo(args):
     grpo_config = GRPOConfig(
         output_dir=str(Path(args.output_dir) / "grpo_checkpoints"),
         num_train_epochs=args.epochs,
-        per_device_train_batch_size=args.batch_size,
-        gradient_accumulation_steps=max(1, 8 // args.batch_size),
+        per_device_train_batch_size=max(args.batch_size, 4),  # must be >= num_generations
+        gradient_accumulation_steps=max(1, 8 // max(args.batch_size, 4)),
         learning_rate=1e-5,
         logging_steps=1,
         save_steps=50,
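This hunk clamps the CLI-supplied batch size to the group size and recomputes gradient accumulation so the effective batch stays near 8. A hypothetical helper (resolve_batch_settings does not exist in the repo) showing the same arithmetic:

# Hypothetical helper mirroring the clamping logic above; num_generations=4
# and the effective-batch target of 8 are assumptions taken from this diff.
def resolve_batch_settings(requested_batch_size: int,
                           num_generations: int = 4,
                           target_effective_batch: int = 8) -> tuple[int, int]:
    per_device = max(requested_batch_size, num_generations)
    grad_accum = max(1, target_effective_batch // per_device)
    return per_device, grad_accum

# A request of --batch_size 2 becomes batch 4 with accumulation 2;
# --batch_size 8 stays at batch 8 with accumulation 1.
print(resolve_batch_settings(2))   # (4, 2)
print(resolve_batch_settings(8))   # (8, 1)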