Fix GRPO training: reward variance, batch/gen alignment, generation config
Browse files- run_training.py +1 -1
- training/train_grpo.py +2 -2
run_training.py
CHANGED
|
@@ -235,7 +235,7 @@ def run_grpo_training():
|
|
| 235 |
grpo_config = GRPOConfig(
|
| 236 |
output_dir="training/outputs/grpo_checkpoints",
|
| 237 |
num_train_epochs=3,
|
| 238 |
-
per_device_train_batch_size=
|
| 239 |
gradient_accumulation_steps=2,
|
| 240 |
learning_rate=1e-5,
|
| 241 |
logging_steps=1,
|
|
|
|
| 235 |
grpo_config = GRPOConfig(
|
| 236 |
output_dir="training/outputs/grpo_checkpoints",
|
| 237 |
num_train_epochs=3,
|
| 238 |
+
per_device_train_batch_size=4, # must be divisible by num_generations (4)
|
| 239 |
gradient_accumulation_steps=2,
|
| 240 |
learning_rate=1e-5,
|
| 241 |
logging_steps=1,
|
training/train_grpo.py
CHANGED
|
@@ -557,8 +557,8 @@ def train_grpo(args):
|
|
| 557 |
grpo_config = GRPOConfig(
|
| 558 |
output_dir=str(Path(args.output_dir) / "grpo_checkpoints"),
|
| 559 |
num_train_epochs=args.epochs,
|
| 560 |
-
per_device_train_batch_size=args.batch_size,
|
| 561 |
-
gradient_accumulation_steps=max(1, 8 // args.batch_size),
|
| 562 |
learning_rate=1e-5,
|
| 563 |
logging_steps=1,
|
| 564 |
save_steps=50,
|
|
|
|
| 557 |
grpo_config = GRPOConfig(
|
| 558 |
output_dir=str(Path(args.output_dir) / "grpo_checkpoints"),
|
| 559 |
num_train_epochs=args.epochs,
|
| 560 |
+
per_device_train_batch_size=max(args.batch_size, 4), # must be >= num_generations
|
| 561 |
+
gradient_accumulation_steps=max(1, 8 // max(args.batch_size, 4)),
|
| 562 |
learning_rate=1e-5,
|
| 563 |
logging_steps=1,
|
| 564 |
save_steps=50,
|