K446 committed on
Commit
a76abcc
·
1 Parent(s): efbeb4b

Reduce prompt/completion length to fix silent OOM on backward pass

Browse files
Files changed (1) hide show
  1. run_training.py +4 -4
run_training.py CHANGED
@@ -229,19 +229,19 @@ def run_grpo_training():
229
  top_p=0.9,
230
  pad_token_id=tokenizer.pad_token_id,
231
  eos_token_id=tokenizer.eos_token_id,
232
- max_new_tokens=96,
233
  )
234
 
235
  grpo_config = GRPOConfig(
236
  output_dir="training/outputs/grpo_checkpoints",
237
  num_train_epochs=3,
238
- per_device_train_batch_size=4, # must equal num_generations on 1 GPU
239
  gradient_accumulation_steps=4,
240
  learning_rate=1e-5,
241
  logging_steps=1,
242
  save_steps=50,
243
- max_prompt_length=1024,
244
- max_completion_length=96,
245
  num_generations=4,
246
  report_to="none",
247
  remove_unused_columns=False,
 
229
  top_p=0.9,
230
  pad_token_id=tokenizer.pad_token_id,
231
  eos_token_id=tokenizer.eos_token_id,
232
+ max_new_tokens=64,
233
  )
234
 
235
  grpo_config = GRPOConfig(
236
  output_dir="training/outputs/grpo_checkpoints",
237
  num_train_epochs=3,
238
+ per_device_train_batch_size=4,
239
  gradient_accumulation_steps=4,
240
  learning_rate=1e-5,
241
  logging_steps=1,
242
  save_steps=50,
243
+ max_prompt_length=512,
244
+ max_completion_length=64,
245
  num_generations=4,
246
  report_to="none",
247
  remove_unused_columns=False,