shank committed on
Commit Β·
dc8001b
1
Parent(s): 4bac574
Fix eval device selection with CUDA-safe fallback
Browse files- training/train_grpo.py +36 -2
training/train_grpo.py
CHANGED
|
@@ -253,6 +253,40 @@ model = FastLanguageModel.get_peft_model(
|
|
| 253 |
)
|
| 254 |
print(f"Trainable params: {model.num_parameters(only_trainable=True):,}")
|
| 255 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
# ββ Reward function βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 257 |
calculator = DebugRewardCalculator()
|
| 258 |
|
|
@@ -305,7 +339,7 @@ def run_baseline(n: int = 20) -> dict:
|
|
| 305 |
solved = 0
|
| 306 |
for bug in bugs:
|
| 307 |
prompt = bug_to_prompt(bug)
|
| 308 |
-
inputs = tokenizer(prompt, return_tensors="pt").to(
|
| 309 |
with torch.no_grad():
|
| 310 |
out = model.generate(**inputs, max_new_tokens=400, temperature=0.1, do_sample=False)
|
| 311 |
completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
|
|
@@ -379,7 +413,7 @@ post_rewards = []
|
|
| 379 |
post_solved = 0
|
| 380 |
for bug in bugs:
|
| 381 |
prompt = bug_to_prompt(bug)
|
| 382 |
-
inputs = tokenizer(prompt, return_tensors="pt").to(
|
| 383 |
with torch.no_grad():
|
| 384 |
out = model.generate(**inputs, max_new_tokens=400, temperature=0.1, do_sample=False)
|
| 385 |
completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
|
|
|
|
| 253 |
)
|
| 254 |
print(f"Trainable params: {model.num_parameters(only_trainable=True):,}")
|
| 255 |
|
| 256 |
+
# ββ Runtime device selection ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 257 |
+
def _select_runtime_device(model) -> str:
|
| 258 |
+
"""
|
| 259 |
+
Pick the safest generation device without forcing CUDA init on broken drivers.
|
| 260 |
+
"""
|
| 261 |
+
def _cuda_usable() -> bool:
|
| 262 |
+
try:
|
| 263 |
+
if not torch.cuda.is_available():
|
| 264 |
+
return False
|
| 265 |
+
# Force lightweight CUDA init probe.
|
| 266 |
+
_ = torch.zeros(1, device="cuda")
|
| 267 |
+
return True
|
| 268 |
+
except Exception as e:
|
| 269 |
+
print(f"WARNING: CUDA initialization failed ({e}). Falling back to CPU.")
|
| 270 |
+
return False
|
| 271 |
+
|
| 272 |
+
# Prefer model's current device when available.
|
| 273 |
+
try:
|
| 274 |
+
model_device = str(next(model.parameters()).device)
|
| 275 |
+
if model_device.startswith("cuda") and not _cuda_usable():
|
| 276 |
+
return "cpu"
|
| 277 |
+
return model_device
|
| 278 |
+
except Exception:
|
| 279 |
+
pass
|
| 280 |
+
|
| 281 |
+
# Fallback to torch capability checks.
|
| 282 |
+
if _cuda_usable():
|
| 283 |
+
return "cuda"
|
| 284 |
+
return "cpu"
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
RUNTIME_DEVICE = _select_runtime_device(model)
|
| 288 |
+
print(f"Using generation/training runtime device: {RUNTIME_DEVICE}")
|
| 289 |
+
|
| 290 |
# ββ Reward function βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 291 |
calculator = DebugRewardCalculator()
|
| 292 |
|
|
|
|
| 339 |
solved = 0
|
| 340 |
for bug in bugs:
|
| 341 |
prompt = bug_to_prompt(bug)
|
| 342 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(RUNTIME_DEVICE)
|
| 343 |
with torch.no_grad():
|
| 344 |
out = model.generate(**inputs, max_new_tokens=400, temperature=0.1, do_sample=False)
|
| 345 |
completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
|
|
|
|
| 413 |
post_solved = 0
|
| 414 |
for bug in bugs:
|
| 415 |
prompt = bug_to_prompt(bug)
|
| 416 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(RUNTIME_DEVICE)
|
| 417 |
with torch.no_grad():
|
| 418 |
out = model.generate(**inputs, max_new_tokens=400, temperature=0.1, do_sample=False)
|
| 419 |
completion = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
|