Add transformers-backend GRPO loader (no triton/Unsloth dep) + fix Jobs deps
The first HF Jobs validation run hit two image-level issues (both reproduced
in the diagnostic sketch below):
* pip's dependency resolution pulled a torch 2.11+cu130 wheel into the
  CUDA-12.4 image, so bitsandbytes failed to load (libnvJitLink.so.13 missing)
* Unsloth requires triton, which JIT-compiles a CUDA helper at runtime
  and needs `cc`, which the slim pytorch image doesn't ship
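For reference, both failure modes show up in a quick in-container check
(a minimal diagnostic sketch, not part of the diff; uses only stock APIs):

    # Diagnostic sketch: run inside the Job container to confirm both issues.
    import shutil

    import torch

    print("torch:", torch.__version__)        # the bad run pulled a +cu130 wheel
    print("torch CUDA:", torch.version.cuda)  # must be "12.4" to match the image
    print("cc:", shutil.which("cc"))          # None on the slim image -> triton JIT breaks

    try:
        import bitsandbytes
        print("bitsandbytes:", bitsandbytes.__version__)
    except Exception as exc:  # e.g. libnvJitLink.so.13: cannot open shared object file
        print("bitsandbytes import failed:", exc)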
Changes:
* grpo_train.py: add load_transformers_model() (plain transformers +
  peft + bnb 4-bit) and a load_model() dispatcher with auto-fallback
  (usage sketch below)
* CLI flag --backend {auto,unsloth,transformers}
* scripts/jobs_grpo_train.sh: pin torch 2.4.1+cu124 and bitsandbytes 0.43.3,
  upper-bound transformers/trl/peft/datasets/accelerate to versions that
  ship CUDA-12 wheels; default GRPO_BACKEND=transformers so triton
  isn't needed in the Job container
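A minimal usage sketch of the new dispatcher (the model name is illustrative;
the signature matches the load_model() added in this diff):

    from chaosops.train.grpo_train import load_model

    # backend="auto" tries Unsloth first, then falls back to plain
    # transformers + peft + bnb if triton/cc is unavailable.
    model, tokenizer = load_model(
        "HuggingFaceTB/SmolLM2-1.7B-Instruct",  # illustrative checkpoint
        backend="auto",
        max_seq_length=2048,
        load_in_4bit=True,
        lora_rank=32,
    )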
110/110 tests still green; new code paths only execute on the GPU side.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
- scripts/jobs_grpo_train.sh +18 -9
- train/grpo_train.py +125 -2
--- a/scripts/jobs_grpo_train.sh
+++ b/scripts/jobs_grpo_train.sh
@@ -34,18 +34,24 @@ echo "==[chaosops]== GPU info"
 nvidia-smi | head -3 || true

 echo "==[chaosops]== installing python deps"
+# Pin torch/torchvision to the image's CUDA 12.4 wheels so transformers/peft
+# don't get pulled into a CUDA-13 wheel set (the bitsandbytes mismatch we
+# saw in the first validation run). bitsandbytes is pinned to the last
+# CUDA-12 compatible release.
 pip install --quiet --upgrade pip
+pip install --quiet --no-deps \
+    "torch==2.4.1+cu124" \
+    --index-url https://download.pytorch.org/whl/cu124 || true
 pip install --quiet \
-    "…" \
-    "…" \
-    "…" \
-    "…" \
-    "…" \
-    "bitsandbytes" \
+    "transformers>=4.44.0,<4.50.0" \
+    "trl>=0.10.0,<0.15.0" \
+    "peft>=0.12.0,<0.14.0" \
+    "datasets>=2.20.0,<3.0.0" \
+    "accelerate>=0.33.0,<0.36.0" \
+    "bitsandbytes==0.43.3" \
     "huggingface_hub>=0.24.0" \
     "pydantic>=2.0.0" \
-    "matplotlib>=3.7.0" \
-    "unsloth"
+    "matplotlib>=3.7.0"

 echo "==[chaosops]== preparing source tree"
 mkdir -p /workspace
@@ -55,9 +61,12 @@ export PYTHONPATH="/tmp:${PYTHONPATH:-}"
 cd /workspace
 mkdir -p "${OUTPUT_DIR}"

-echo "==[chaosops]== launching GRPO ($GRPO_EPISODES episodes, group=$GRPO_GROUP_SIZE, lora_rank=$GRPO_LORA_RANK)"
+GRPO_BACKEND="${GRPO_BACKEND:-transformers}"
+
+echo "==[chaosops]== launching GRPO (backend=$GRPO_BACKEND, $GRPO_EPISODES episodes, group=$GRPO_GROUP_SIZE, lora_rank=$GRPO_LORA_RANK)"
 python -m chaosops.train.grpo_train \
     --model-name "${GRPO_MODEL}" \
+    --backend "${GRPO_BACKEND}" \
     --total-episodes "${GRPO_EPISODES}" \
     --group-size "${GRPO_GROUP_SIZE}" \
     --log-every "${GRPO_LOG_EVERY}" \
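A quick post-install smoke test for the pinned wheel set (a sketch, not part
of the diff; run inside the Job container after the pip step):

    import bitsandbytes  # raised the libnvJitLink.so.13 error before the pins
    import torch

    assert torch.version.cuda == "12.4", torch.version.cuda
    print("bitsandbytes:", bitsandbytes.__version__)  # expect 0.43.3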
--- a/train/grpo_train.py
+++ b/train/grpo_train.py
@@ -317,7 +317,11 @@ def load_unsloth_model(
     load_in_4bit: bool = True,
     lora_rank: int = 32,
 ):
-    """Load a base LLM with Unsloth + LoRA. Returns ``(model, tokenizer)``."""
+    """Load a base LLM with Unsloth + LoRA. Returns ``(model, tokenizer)``.
+
+    Requires triton + a C compiler at runtime; if either is missing,
+    fall back to :func:`load_transformers_model`.
+    """
     from unsloth import FastLanguageModel  # type: ignore[import-not-found]

     model, tokenizer = FastLanguageModel.from_pretrained(
@@ -345,6 +349,117 @@ def load_unsloth_model(
     return model, tokenizer


+def load_transformers_model(
+    model_name: str,
+    *,
+    max_seq_length: int = 2048,
+    load_in_4bit: bool = True,
+    lora_rank: int = 32,
+):
+    """Plain ``transformers + peft`` model loader — no Unsloth/triton dep.
+
+    Used when the runtime image doesn't ship triton/cc (most lightweight
+    CUDA images). Slightly slower per step than Unsloth but works on any
+    standard PyTorch image.
+    """
+    import torch  # type: ignore[import-not-found]
+    from peft import LoraConfig, get_peft_model  # type: ignore[import-not-found]
+    from transformers import (  # type: ignore[import-not-found]
+        AutoModelForCausalLM,
+        AutoTokenizer,
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    load_kwargs: dict[str, Any] = {}
+    if load_in_4bit:
+        try:
+            from transformers import BitsAndBytesConfig  # type: ignore[import-not-found]
+
+            load_kwargs["quantization_config"] = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_quant_type="nf4",
+                bnb_4bit_use_double_quant=True,
+                bnb_4bit_compute_dtype=torch.float16,
+            )
+        except Exception:
+            # bnb unavailable — fall back to fp16 full-precision LoRA.
+            load_kwargs["torch_dtype"] = torch.float16
+    else:
+        load_kwargs["torch_dtype"] = torch.float16
+
+    if torch.cuda.is_available():
+        load_kwargs["device_map"] = {"": 0}
+
+    base = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
+    lora_cfg = LoraConfig(
+        r=lora_rank,
+        lora_alpha=lora_rank,
+        lora_dropout=0.0,
+        bias="none",
+        target_modules=[
+            "q_proj",
+            "k_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "up_proj",
+            "down_proj",
+        ],
+        task_type="CAUSAL_LM",
+    )
+    model = get_peft_model(base, lora_cfg)
+    return model, tokenizer
+
+
+def load_model(
+    model_name: str,
+    *,
+    backend: str = "auto",
+    max_seq_length: int = 2048,
+    load_in_4bit: bool = True,
+    lora_rank: int = 32,
+):
+    """Dispatch to the requested loader, with auto-fallback.
+
+    ``backend`` ∈ ``{"auto", "unsloth", "transformers"}``. ``auto`` tries
+    Unsloth first and falls back to transformers if the import fails or
+    the runtime can't satisfy triton's C-compiler dep.
+    """
+    if backend == "transformers":
+        return load_transformers_model(
+            model_name,
+            max_seq_length=max_seq_length,
+            load_in_4bit=load_in_4bit,
+            lora_rank=lora_rank,
+        )
+    if backend == "unsloth":
+        return load_unsloth_model(
+            model_name,
+            max_seq_length=max_seq_length,
+            load_in_4bit=load_in_4bit,
+            lora_rank=lora_rank,
+        )
+    # auto
+    try:
+        return load_unsloth_model(
+            model_name,
+            max_seq_length=max_seq_length,
+            load_in_4bit=load_in_4bit,
+            lora_rank=lora_rank,
+        )
+    except Exception as exc:
+        print(f"[grpo_train] Unsloth path failed ({exc!r}); using transformers")
+        return load_transformers_model(
+            model_name,
+            max_seq_length=max_seq_length,
+            load_in_4bit=load_in_4bit,
+            lora_rank=lora_rank,
+        )
+
+
 def make_generate_fn(
     model, tokenizer, *, max_new_tokens: int = 96, temperature: float = 0.7
 ) -> GenerateFn:
@@ -588,13 +703,21 @@ def _parse_args() -> argparse.Namespace:
         default=DifficultyTier.EASY.value,
         choices=[t.value for t in DifficultyTier],
     )
+    parser.add_argument(
+        "--backend",
+        type=str,
+        default="auto",
+        choices=["auto", "unsloth", "transformers"],
+        help="Model loader. 'auto' tries Unsloth, falls back to transformers.",
+    )
     return parser.parse_args()


 def main() -> None:
     args = _parse_args()
-    model, tokenizer = load_unsloth_model(
+    model, tokenizer = load_model(
         args.model_name,
+        backend=args.backend,
         max_seq_length=args.max_seq_length,
         lora_rank=args.lora_rank,
     )
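The auto-fallback can be exercised without a GPU by forcing the unsloth
import to fail (a hypothetical test sketch, not in this PR's suite; the tiny
Llama checkpoint is illustrative, chosen because the LoRA target_modules are
Llama-style, and load_in_4bit=False keeps the sketch CPU-safe):

    import builtins
    from unittest import mock

    from chaosops.train.grpo_train import load_model

    _real_import = builtins.__import__

    def _no_unsloth(name, *args, **kwargs):
        # Mimic the slim-image failure: unsloth -> triton -> no C compiler.
        if name.startswith("unsloth"):
            raise ImportError("triton requires a C compiler")
        return _real_import(name, *args, **kwargs)

    with mock.patch.object(builtins, "__import__", side_effect=_no_unsloth):
        # Prints "[grpo_train] Unsloth path failed (...)" and returns the
        # transformers + peft model instead.
        model, tokenizer = load_model(
            "HuggingFaceM4/tiny-random-LlamaForCausalLM",  # illustrative
            backend="auto",
            load_in_4bit=False,  # skip bnb so the sketch runs on CPU
        )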