Add pre-train gen sanity check, explicit GenerationConfig, dynamic GRPOConfig params, torch_compile/vllm off
Browse files- run_training.py +36 -6
run_training.py
CHANGED
|
@@ -190,6 +190,8 @@ def run_grpo_training():
|
|
| 190 |
print("\n[4/6] Starting GRPO training...")
|
| 191 |
from trl import GRPOTrainer, GRPOConfig
|
| 192 |
from datasets import Dataset
|
|
|
|
|
|
|
| 193 |
|
| 194 |
_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
|
| 195 |
_fp16 = torch.cuda.is_available() and not _bf16
|
|
@@ -218,27 +220,40 @@ def run_grpo_training():
|
|
| 218 |
|
| 219 |
return compute_grpo_reward_env(texts, obs_dicts, task_config, horizon=1)
|
| 220 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
grpo_config = GRPOConfig(
|
| 222 |
output_dir="training/outputs/grpo_checkpoints",
|
| 223 |
num_train_epochs=3,
|
| 224 |
per_device_train_batch_size=2,
|
| 225 |
-
gradient_accumulation_steps=2,
|
| 226 |
learning_rate=1e-5,
|
| 227 |
-
logging_steps=1,
|
| 228 |
save_steps=50,
|
| 229 |
-
max_prompt_length=1024,
|
| 230 |
-
max_completion_length=96,
|
| 231 |
num_generations=2,
|
| 232 |
-
temperature=0.7, # was 0.9 default — less wasted sampling
|
| 233 |
report_to="none",
|
| 234 |
remove_unused_columns=False,
|
| 235 |
bf16=_bf16,
|
| 236 |
fp16=_fp16,
|
| 237 |
gradient_checkpointing=True,
|
| 238 |
gradient_checkpointing_kwargs={"use_reentrant": False},
|
| 239 |
-
optim="paged_adamw_8bit",
|
| 240 |
warmup_ratio=0.05,
|
| 241 |
lr_scheduler_type="cosine",
|
|
|
|
|
|
|
| 242 |
)
|
| 243 |
|
| 244 |
train_dataset = Dataset.from_dict({"prompt": prompts, "obs_context": obs_contexts})
|
|
@@ -250,6 +265,21 @@ def run_grpo_training():
|
|
| 250 |
reward_funcs=reward_fn, processing_class=tokenizer,
|
| 251 |
)
|
| 252 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
t0 = time.time()
|
| 254 |
trainer.train()
|
| 255 |
train_time = time.time() - t0
|
|
|
|
| 190 |
print("\n[4/6] Starting GRPO training...")
|
| 191 |
from trl import GRPOTrainer, GRPOConfig
|
| 192 |
from datasets import Dataset
|
| 193 |
+
import inspect as _inspect
|
| 194 |
+
_grpo_params = set(_inspect.signature(GRPOConfig.__init__).parameters)
|
| 195 |
|
| 196 |
_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
|
| 197 |
_fp16 = torch.cuda.is_available() and not _bf16
|
|
|
|
| 220 |
|
| 221 |
return compute_grpo_reward_env(texts, obs_dicts, task_config, horizon=1)
|
| 222 |
|
| 223 |
+
# Set generation config explicitly so EOS is always respected and
|
| 224 |
+
# generation never runs to max_completion_length every single time.
|
| 225 |
+
from transformers import GenerationConfig
|
| 226 |
+
model.generation_config = GenerationConfig(
|
| 227 |
+
do_sample=True,
|
| 228 |
+
temperature=0.7,
|
| 229 |
+
top_p=0.9,
|
| 230 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 231 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 232 |
+
max_new_tokens=96,
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
grpo_config = GRPOConfig(
|
| 236 |
output_dir="training/outputs/grpo_checkpoints",
|
| 237 |
num_train_epochs=3,
|
| 238 |
per_device_train_batch_size=2,
|
| 239 |
+
gradient_accumulation_steps=2,
|
| 240 |
learning_rate=1e-5,
|
| 241 |
+
logging_steps=1,
|
| 242 |
save_steps=50,
|
| 243 |
+
max_prompt_length=1024,
|
| 244 |
+
max_completion_length=96,
|
| 245 |
num_generations=2,
|
|
|
|
| 246 |
report_to="none",
|
| 247 |
remove_unused_columns=False,
|
| 248 |
bf16=_bf16,
|
| 249 |
fp16=_fp16,
|
| 250 |
gradient_checkpointing=True,
|
| 251 |
gradient_checkpointing_kwargs={"use_reentrant": False},
|
| 252 |
+
optim="paged_adamw_8bit",
|
| 253 |
warmup_ratio=0.05,
|
| 254 |
lr_scheduler_type="cosine",
|
| 255 |
+
**({'torch_compile': False} if 'torch_compile' in _grpo_params else {}),
|
| 256 |
+
**({'use_vllm': False} if 'use_vllm' in _grpo_params else {}),
|
| 257 |
)
|
| 258 |
|
| 259 |
train_dataset = Dataset.from_dict({"prompt": prompts, "obs_context": obs_contexts})
|
|
|
|
| 265 |
reward_funcs=reward_fn, processing_class=tokenizer,
|
| 266 |
)
|
| 267 |
|
| 268 |
+
# ── Sanity-check generation before handing off to GRPO ──
|
| 269 |
+
# If this hangs, the model/tokenizer setup is the problem.
|
| 270 |
+
print(" [DEBUG] Testing model generation (should complete in <30s)...")
|
| 271 |
+
_test_inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
|
| 272 |
+
with torch.no_grad():
|
| 273 |
+
_out = model.generate(
|
| 274 |
+
**_test_inputs,
|
| 275 |
+
max_new_tokens=8,
|
| 276 |
+
do_sample=False,
|
| 277 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 278 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 279 |
+
)
|
| 280 |
+
print(f" [DEBUG] Generation OK: {tokenizer.decode(_out[0][-8:], skip_special_tokens=True)!r}")
|
| 281 |
+
|
| 282 |
+
print(" [NOTE] First GRPO step includes Triton JIT — may show 0/N for up to 5 min. That is normal.")
|
| 283 |
t0 = time.time()
|
| 284 |
trainer.train()
|
| 285 |
train_time = time.time() - t0
|