Reduce prompt/completion length to fix silent OOM on backward pass
Browse files- run_training.py +4 -4
run_training.py
CHANGED
|
@@ -229,19 +229,19 @@ def run_grpo_training():
|
|
| 229 |
top_p=0.9,
|
| 230 |
pad_token_id=tokenizer.pad_token_id,
|
| 231 |
eos_token_id=tokenizer.eos_token_id,
|
| 232 |
-
max_new_tokens=
|
| 233 |
)
|
| 234 |
|
| 235 |
grpo_config = GRPOConfig(
|
| 236 |
output_dir="training/outputs/grpo_checkpoints",
|
| 237 |
num_train_epochs=3,
|
| 238 |
-
per_device_train_batch_size=4,
|
| 239 |
gradient_accumulation_steps=4,
|
| 240 |
learning_rate=1e-5,
|
| 241 |
logging_steps=1,
|
| 242 |
save_steps=50,
|
| 243 |
-
max_prompt_length=
|
| 244 |
-
max_completion_length=
|
| 245 |
num_generations=4,
|
| 246 |
report_to="none",
|
| 247 |
remove_unused_columns=False,
|
|
|
|
| 229 |
top_p=0.9,
|
| 230 |
pad_token_id=tokenizer.pad_token_id,
|
| 231 |
eos_token_id=tokenizer.eos_token_id,
|
| 232 |
+
max_new_tokens=64,
|
| 233 |
)
|
| 234 |
|
| 235 |
grpo_config = GRPOConfig(
|
| 236 |
output_dir="training/outputs/grpo_checkpoints",
|
| 237 |
num_train_epochs=3,
|
| 238 |
+
per_device_train_batch_size=4,
|
| 239 |
gradient_accumulation_steps=4,
|
| 240 |
learning_rate=1e-5,
|
| 241 |
logging_steps=1,
|
| 242 |
save_steps=50,
|
| 243 |
+
max_prompt_length=512,
|
| 244 |
+
max_completion_length=64,
|
| 245 |
num_generations=4,
|
| 246 |
report_to="none",
|
| 247 |
remove_unused_columns=False,
|