Fix batch/num_generations: 4/4 on 1 GPU, grad_accum=4
Browse files- run_training.py +3 -3
run_training.py
CHANGED
|
@@ -235,14 +235,14 @@ def run_grpo_training():
|
|
| 235 |
grpo_config = GRPOConfig(
|
| 236 |
output_dir="training/outputs/grpo_checkpoints",
|
| 237 |
num_train_epochs=3,
|
| 238 |
-
per_device_train_batch_size=4, # must
|
| 239 |
-
gradient_accumulation_steps=
|
| 240 |
learning_rate=1e-5,
|
| 241 |
logging_steps=1,
|
| 242 |
save_steps=50,
|
| 243 |
max_prompt_length=1024,
|
| 244 |
max_completion_length=96,
|
| 245 |
-
num_generations=4,
|
| 246 |
report_to="none",
|
| 247 |
remove_unused_columns=False,
|
| 248 |
bf16=_bf16,
|
|
|
|
| 235 |
grpo_config = GRPOConfig(
|
| 236 |
output_dir="training/outputs/grpo_checkpoints",
|
| 237 |
num_train_epochs=3,
|
| 238 |
+
per_device_train_batch_size=4, # must equal num_generations on 1 GPU
|
| 239 |
+
gradient_accumulation_steps=4,
|
| 240 |
learning_rate=1e-5,
|
| 241 |
logging_steps=1,
|
| 242 |
save_steps=50,
|
| 243 |
max_prompt_length=1024,
|
| 244 |
max_completion_length=96,
|
| 245 |
+
num_generations=4,
|
| 246 |
report_to="none",
|
| 247 |
remove_unused_columns=False,
|
| 248 |
bf16=_bf16,
|