K446 committed on
Commit
c09f4cb
·
1 Parent(s): 8d94e97

Fix OOM: reduce batch/gen/tokens, add grad checkpointing + adafactor

Browse files
Files changed (1) hide show
  1. run_training.py +8 -5
run_training.py CHANGED
@@ -4,6 +4,7 @@ Runs env-grounded GRPO training, saves model + plots,
4
  then starts a FastAPI server to serve/download results.
5
  """
6
  import os
 
7
  import sys
8
  import json
9
  import copy
@@ -125,7 +126,7 @@ def run_grpo_training():
125
  obs_contexts = []
126
  rng = np.random.RandomState(base_seed)
127
 
128
- for episode in range(30):
129
  ep_config = copy.deepcopy(task_config)
130
  ep_config['seed'] = base_seed + episode
131
  env = OpenGridEnv(ep_config)
@@ -208,17 +209,19 @@ def run_grpo_training():
208
  grpo_config = GRPOConfig(
209
  output_dir="training/outputs/grpo_checkpoints",
210
  num_train_epochs=3,
211
- per_device_train_batch_size=8,
212
- gradient_accumulation_steps=2,
213
  learning_rate=1e-5,
214
  logging_steps=5,
215
  save_steps=50,
216
- max_completion_length=256,
217
- num_generations=8,
218
  report_to="none",
219
  remove_unused_columns=False,
220
  bf16=_bf16,
221
  fp16=_fp16,
 
 
222
  )
223
 
224
  train_dataset = Dataset.from_dict({"prompt": prompts, "obs_context": obs_contexts})
 
4
  then starts a FastAPI server to serve/download results.
5
  """
6
  import os
7
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
8
  import sys
9
  import json
10
  import copy
 
126
  obs_contexts = []
127
  rng = np.random.RandomState(base_seed)
128
 
129
+ for episode in range(10): # 10 episodes → ~600 prompts, fits training time
130
  ep_config = copy.deepcopy(task_config)
131
  ep_config['seed'] = base_seed + episode
132
  env = OpenGridEnv(ep_config)
 
209
  grpo_config = GRPOConfig(
210
  output_dir="training/outputs/grpo_checkpoints",
211
  num_train_epochs=3,
212
+ per_device_train_batch_size=2,
213
+ gradient_accumulation_steps=8,
214
  learning_rate=1e-5,
215
  logging_steps=5,
216
  save_steps=50,
217
+ max_completion_length=128,
218
+ num_generations=2,
219
  report_to="none",
220
  remove_unused_columns=False,
221
  bf16=_bf16,
222
  fp16=_fp16,
223
+ gradient_checkpointing=True,
224
+ optim="adafactor",
225
  )
226
 
227
  train_dataset = Dataset.from_dict({"prompt": prompts, "obs_context": obs_contexts})