K446 committed
Commit a6ecb81 · Parent(s): 8dab919

Add pre-train gen sanity check, explicit GenerationConfig, dynamic GRPOConfig params, torch_compile/vllm off
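A note on the "dynamic GRPOConfig params" part of this commit: the change probes the installed GRPOConfig's constructor signature and passes kwargs such as torch_compile and use_vllm only when the class accepts them, presumably because those fields are not available in every TRL/transformers combination. Below is a minimal sketch of that pattern; FakeGRPOConfig is a hypothetical stand-in so the snippet runs without trl installed.

    import inspect
    from dataclasses import dataclass

    @dataclass
    class FakeGRPOConfig:        # hypothetical stand-in for trl.GRPOConfig
        output_dir: str = "out"
        use_vllm: bool = True    # a field that exists only in some versions

    # Names the installed class's __init__ actually accepts (plus "self").
    params = set(inspect.signature(FakeGRPOConfig.__init__).parameters)

    # Each conditional splat passes the kwarg when it is supported and an
    # empty dict (i.e., nothing) when it is not.
    cfg = FakeGRPOConfig(
        output_dir="out",
        **({"use_vllm": False} if "use_vllm" in params else {}),
        **({"torch_compile": False} if "torch_compile" in params else {}),
    )
    print(cfg)  # FakeGRPOConfig(output_dir='out', use_vllm=False)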

Files changed (1):
  1. run_training.py +36 -6
run_training.py CHANGED
@@ -190,6 +190,8 @@ def run_grpo_training():
     print("\n[4/6] Starting GRPO training...")
     from trl import GRPOTrainer, GRPOConfig
     from datasets import Dataset
+    import inspect as _inspect
+    _grpo_params = set(_inspect.signature(GRPOConfig.__init__).parameters)

     _bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
     _fp16 = torch.cuda.is_available() and not _bf16
@@ -218,27 +220,40 @@ def run_grpo_training():

         return compute_grpo_reward_env(texts, obs_dicts, task_config, horizon=1)

+    # Set generation config explicitly so EOS is always respected and
+    # generation never runs to max_completion_length every single time.
+    from transformers import GenerationConfig
+    model.generation_config = GenerationConfig(
+        do_sample=True,
+        temperature=0.7,
+        top_p=0.9,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        max_new_tokens=96,
+    )
+
     grpo_config = GRPOConfig(
         output_dir="training/outputs/grpo_checkpoints",
         num_train_epochs=3,
         per_device_train_batch_size=2,
-        gradient_accumulation_steps=2,  # was 8 — first visible step lands ~4x sooner
+        gradient_accumulation_steps=2,
         learning_rate=1e-5,
-        logging_steps=1,  # was 5 — see loss every step
+        logging_steps=1,
         save_steps=50,
-        max_prompt_length=1024,  # default 512 truncates Karnataka prompts
-        max_completion_length=96,  # was 128 — ~25% faster generation
+        max_prompt_length=1024,
+        max_completion_length=96,
         num_generations=2,
-        temperature=0.7,  # was 0.9 default — less wasted sampling
         report_to="none",
         remove_unused_columns=False,
         bf16=_bf16,
         fp16=_fp16,
         gradient_checkpointing=True,
         gradient_checkpointing_kwargs={"use_reentrant": False},
-        optim="paged_adamw_8bit",  # canonical for QLoRA; adafactor fights bf16+bnb
+        optim="paged_adamw_8bit",
         warmup_ratio=0.05,
         lr_scheduler_type="cosine",
+        **({'torch_compile': False} if 'torch_compile' in _grpo_params else {}),
+        **({'use_vllm': False} if 'use_vllm' in _grpo_params else {}),
     )

     train_dataset = Dataset.from_dict({"prompt": prompts, "obs_context": obs_contexts})
@@ -250,6 +265,21 @@ def run_grpo_training():
         reward_funcs=reward_fn, processing_class=tokenizer,
     )

+    # ── Sanity-check generation before handing off to GRPO ──
+    # If this hangs, the model/tokenizer setup is the problem.
+    print(" [DEBUG] Testing model generation (should complete in <30s)...")
+    _test_inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
+    with torch.no_grad():
+        _out = model.generate(
+            **_test_inputs,
+            max_new_tokens=8,
+            do_sample=False,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+        )
+    print(f" [DEBUG] Generation OK: {tokenizer.decode(_out[0][-8:], skip_special_tokens=True)!r}")
+
+    print(" [NOTE] First GRPO step includes Triton JIT — may show 0/N for up to 5 min. That is normal.")
     t0 = time.time()
     trainer.train()
     train_time = time.time() - t0
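A usage note on the explicit GenerationConfig: attaching it to model.generation_config means the EOS and pad ids travel with the model, so sampled completions can stop at EOS instead of always filling max_new_tokens. A hypothetical follow-up check, not part of the commit and reusing the script's model and tokenizer, could confirm the config took effect just before trainer.train():

    # Hypothetical verification (not in the commit): the attached config
    # should carry the tokenizer's special-token ids and the 96-token cap.
    assert model.generation_config.eos_token_id == tokenizer.eos_token_id
    assert model.generation_config.max_new_tokens == 96
    print(model.generation_config)  # do_sample=True, temperature=0.7, top_p=0.9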