Spaces:
Sleeping
Sleeping
make GRPOConfig kwargs version-tolerant
Browse files- training/train_grpo_hf_job.py +42 -24
training/train_grpo_hf_job.py
CHANGED
|
@@ -288,30 +288,48 @@ assert sane_r > 0, f"reward fn broken (expected >0 on case 0, got {sane_r})"
|
|
| 288 |
# 7. GRPO training
|
| 289 |
# ---------------------------------------------------------------------------
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
|
| 316 |
trainer = GRPOTrainer(
|
| 317 |
model=model,
|
|
|
|
# 7. GRPO training
# ---------------------------------------------------------------------------

# Build kwargs incrementally and only pass args the installed TRL accepts -
# the GRPOConfig surface has shifted across releases (max_prompt_length
# disappeared in some, top_p / epsilon were renamed, etc.).
import inspect

_grpo_sig = inspect.signature(GRPOConfig.__init__).parameters

# FIX: some TRL releases declare GRPOConfig.__init__ with a trailing
# **kwargs (forwarded to the TrainingArguments base).  On those versions a
# plain `name in signature` test would silently drop EVERY optional arg
# below even though the config accepts them — detect VAR_KEYWORD and, when
# present, pass the optional args through unconditionally.
_grpo_accepts_var_kw = any(
    p.kind is inspect.Parameter.VAR_KEYWORD for p in _grpo_sig.values()
)

# Baseline arguments assumed stable across every supported TRL release;
# these are always passed.
_grpo_kwargs: dict = {
    "output_dir": str(OUT_DIR),
    "learning_rate": LEARNING_RATE,
    "weight_decay": 0.1,
    "warmup_ratio": 0.1,
    "lr_scheduler_type": "cosine",
    "optim": "adamw_torch",
    "logging_steps": 1,
    "per_device_train_batch_size": BATCH_SIZE,
    "gradient_accumulation_steps": GRAD_ACCUM,
    "num_generations": NUM_GENERATIONS,
    "max_steps": NUM_GRPO_STEPS,
    "save_steps": 999_999,  # effectively disables intermediate checkpoints
    "report_to": "none",
    "bf16": True,
}

# Arguments that have been renamed or removed in some TRL releases; each is
# forwarded only when the installed GRPOConfig actually accepts it.
_optional_kwargs: dict = {
    "adam_beta1": 0.9,
    "adam_beta2": 0.99,
    "max_prompt_length": MAX_PROMPT_LEN,
    "max_completion_length": MAX_COMPLETION_LEN,
    "temperature": 0.9,
    "top_p": 0.95,
    "epsilon": 0.2,  # PPO-style clip range (renamed in some releases)
    "beta": 0.04,    # KL penalty coefficient
}
for k, v in _optional_kwargs.items():
    if k in _grpo_sig or _grpo_accepts_var_kw:
        _grpo_kwargs[k] = v
    else:
        print(f"[config] skipping unknown GRPOConfig arg: {k}")

print("[config] GRPOConfig kwargs:", sorted(_grpo_kwargs.keys()))
training_args = GRPOConfig(**_grpo_kwargs)
|
| 333 |
|
| 334 |
trainer = GRPOTrainer(
|
| 335 |
model=model,
|