Fix batch size: 8 to match num_generations=8

run_training.py (+2 -2)

```diff
@@ -208,8 +208,8 @@ def run_grpo_training():
     grpo_config = GRPOConfig(
         output_dir="training/outputs/grpo_checkpoints",
         num_train_epochs=3,
-        per_device_train_batch_size=<value truncated in extracted source>,
-        gradient_accumulation_steps=<value truncated in extracted source>,
+        per_device_train_batch_size=8,
+        gradient_accumulation_steps=2,
         learning_rate=1e-5,
         logging_steps=5,
         save_steps=50,
```

NOTE(review): the removed (old) values for `per_device_train_batch_size` and
`gradient_accumulation_steps` were cut off during page extraction and cannot be
recovered from this text — confirm against the repository history. The added
values (8 and 2) and the commit intent are taken verbatim from the source: TRL's
GRPO requires the effective generation batch to be divisible by
`num_generations`, hence batch size 8 to match `num_generations=8`.