akhiilll committed on
Commit
e893ade
·
verified ·
1 Parent(s): eed849b

make GRPOConfig kwargs version-tolerant

Browse files
Files changed (1) hide show
  1. training/train_grpo_hf_job.py +42 -24
training/train_grpo_hf_job.py CHANGED
@@ -288,30 +288,48 @@ assert sane_r > 0, f"reward fn broken (expected >0 on case 0, got {sane_r})"
288
  # 7. GRPO training
289
  # ---------------------------------------------------------------------------
290
 
291
- training_args = GRPOConfig(
292
- output_dir=str(OUT_DIR),
293
- learning_rate=LEARNING_RATE,
294
- adam_beta1=0.9,
295
- adam_beta2=0.99,
296
- weight_decay=0.1,
297
- warmup_ratio=0.1,
298
- lr_scheduler_type="cosine",
299
- optim="adamw_torch",
300
- logging_steps=1,
301
- per_device_train_batch_size=BATCH_SIZE,
302
- gradient_accumulation_steps=GRAD_ACCUM,
303
- num_generations=NUM_GENERATIONS,
304
- max_prompt_length=MAX_PROMPT_LEN,
305
- max_completion_length=MAX_COMPLETION_LEN,
306
- max_steps=NUM_GRPO_STEPS,
307
- save_steps=999_999,
308
- report_to="none",
309
- bf16=True,
310
- temperature=0.9,
311
- top_p=0.95,
312
- epsilon=0.2,
313
- beta=0.04,
314
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
315
 
316
  trainer = GRPOTrainer(
317
  model=model,
 
288
  # 7. GRPO training
289
  # ---------------------------------------------------------------------------
290
 
291
# Build kwargs incrementally and only pass args the installed TRL accepts -
# the GRPOConfig surface has shifted across releases (max_prompt_length
# disappeared in some, top_p / epsilon were renamed, etc.).
import inspect

_grpo_params = inspect.signature(GRPOConfig.__init__).parameters

# If the installed GRPOConfig.__init__ takes **kwargs we cannot enumerate
# the accepted names by signature membership; in that case pass every
# optional arg through instead of wrongly dropping valid ones.
_grpo_accepts_any = any(
    p.kind is inspect.Parameter.VAR_KEYWORD for p in _grpo_params.values()
)

# Core arguments expected to exist in every supported TRL release.
_grpo_kwargs: dict = {
    "output_dir": str(OUT_DIR),
    "learning_rate": LEARNING_RATE,
    "weight_decay": 0.1,
    "warmup_ratio": 0.1,
    "lr_scheduler_type": "cosine",
    "optim": "adamw_torch",
    "logging_steps": 1,
    "per_device_train_batch_size": BATCH_SIZE,
    "gradient_accumulation_steps": GRAD_ACCUM,
    "num_generations": NUM_GENERATIONS,
    "max_steps": NUM_GRPO_STEPS,
    "save_steps": 999_999,  # effectively disable checkpointing mid-run
    "report_to": "none",
    "bf16": True,
}

# Arguments that have appeared, vanished, or been renamed across TRL
# releases — added only when the installed version is known to take them.
_optional_kwargs: dict = {
    "adam_beta1": 0.9,
    "adam_beta2": 0.99,
    "max_prompt_length": MAX_PROMPT_LEN,
    "max_completion_length": MAX_COMPLETION_LEN,
    "temperature": 0.9,
    "top_p": 0.95,
    "epsilon": 0.2,
    "beta": 0.04,
}
for k, v in _optional_kwargs.items():
    if _grpo_accepts_any or k in _grpo_params:
        _grpo_kwargs[k] = v
    else:
        print(f"[config] skipping unknown GRPOConfig arg: {k}")

print("[config] GRPOConfig kwargs:", sorted(_grpo_kwargs.keys()))
training_args = GRPOConfig(**_grpo_kwargs)
333
 
334
  trainer = GRPOTrainer(
335
  model=model,