Spaces:
Sleeping
Sleeping
Prasham.Jain Claude Sonnet 4.6 commited on
Commit ·
68277e2
1
Parent(s): e3da0da
fix(training): drop unsloth, use bitsandbytes+PEFT for SFT
Browse filesUnsloth requires transformers>=4.51 (for CompileConfig) but torch 2.4.1
in the Docker image forces transformers<=4.46 to avoid the torchao
version conflict. Remove unsloth entirely; use AutoModelForCausalLM +
BitsAndBytesConfig (nf4 4-bit) + PEFT LoRA instead — works identically
on 46 GB VRAM with no version conflicts.
Also fixes:
- MODEL_NAME: Qwen/Qwen3.5-4B → Qwen/Qwen3-4B (correct model ID)
- grpo.py: hp.pop() was called before hp dict was constructed (NameError)
- Dockerfile.train: remove unsloth install step, add bitsandbytes pin
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Dockerfile.train +6 -9
- src/ci_triage_env/training/grpo.py +4 -4
- src/ci_triage_env/training/sft.py +35 -15
Dockerfile.train
CHANGED
|
@@ -20,21 +20,18 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
|
| 20 |
|
| 21 |
WORKDIR /workspace
|
| 22 |
|
| 23 |
-
# 1. Pin
|
| 24 |
-
#
|
| 25 |
-
#
|
| 26 |
RUN pip install --no-cache-dir \
|
| 27 |
"torchao==0.5.0" \
|
| 28 |
"transformers==4.46.3" \
|
| 29 |
"trl==0.11.4" \
|
| 30 |
"peft==0.13.2" \
|
| 31 |
-
"accelerate==0.34.2"
|
|
|
|
| 32 |
|
| 33 |
-
# 2. Install
|
| 34 |
-
RUN pip install --no-cache-dir \
|
| 35 |
-
"unsloth[cu121-torch240] @ git+https://github.com/unslothai/unsloth.git"
|
| 36 |
-
|
| 37 |
-
# 3. Install project deps (transformers/trl/peft already pinned above, won't be overridden)
|
| 38 |
COPY pyproject.toml README.md ./
|
| 39 |
COPY src/ src/
|
| 40 |
RUN pip install --no-cache-dir -e ".[data,training]"
|
|
|
|
| 20 |
|
| 21 |
WORKDIR /workspace
|
| 22 |
|
| 23 |
+
# 1. Pin versions compatible with torch 2.4.1 in this image.
|
| 24 |
+
# torchao latest requires torch>=2.11; transformers>=4.47 pulls torchao as dep.
|
| 25 |
+
# bitsandbytes replaces unsloth for 4-bit quantisation.
|
| 26 |
RUN pip install --no-cache-dir \
|
| 27 |
"torchao==0.5.0" \
|
| 28 |
"transformers==4.46.3" \
|
| 29 |
"trl==0.11.4" \
|
| 30 |
"peft==0.13.2" \
|
| 31 |
+
"accelerate==0.34.2" \
|
| 32 |
+
"bitsandbytes>=0.43.0"
|
| 33 |
|
| 34 |
+
# 2. Install project deps (versions pinned above won't be overridden)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
COPY pyproject.toml README.md ./
|
| 36 |
COPY src/ src/
|
| 37 |
RUN pip install --no-cache-dir -e ".[data,training]"
|
src/ci_triage_env/training/grpo.py
CHANGED
|
@@ -64,6 +64,10 @@ def run_grpo(
|
|
| 64 |
train_dir = Path(scenarios_train_path)
|
| 65 |
scenario_ids = [p.stem for p in train_dir.rglob("*.json")] if train_dir.exists() else []
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
max_turns = hp.pop("max_turns", 4) # short episodes for faster GRPO
|
| 68 |
rollout = TrainingRollout(
|
| 69 |
env_client=env_client,
|
|
@@ -74,10 +78,6 @@ def run_grpo(
|
|
| 74 |
|
| 75 |
model, tokenizer = load_model_for_sft(model_name=sft_checkpoint_dir)
|
| 76 |
|
| 77 |
-
hp = dict(GRPO_HYPERPARAMS)
|
| 78 |
-
if hyperparams:
|
| 79 |
-
hp.update(hyperparams)
|
| 80 |
-
|
| 81 |
config = GRPOConfig(
|
| 82 |
output_dir=output_dir,
|
| 83 |
max_steps=total_steps,
|
|
|
|
| 64 |
train_dir = Path(scenarios_train_path)
|
| 65 |
scenario_ids = [p.stem for p in train_dir.rglob("*.json")] if train_dir.exists() else []
|
| 66 |
|
| 67 |
+
hp = dict(GRPO_HYPERPARAMS)
|
| 68 |
+
if hyperparams:
|
| 69 |
+
hp.update(hyperparams)
|
| 70 |
+
|
| 71 |
max_turns = hp.pop("max_turns", 4) # short episodes for faster GRPO
|
| 72 |
rollout = TrainingRollout(
|
| 73 |
env_client=env_client,
|
|
|
|
| 78 |
|
| 79 |
model, tokenizer = load_model_for_sft(model_name=sft_checkpoint_dir)
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
config = GRPOConfig(
|
| 82 |
output_dir=output_dir,
|
| 83 |
max_steps=total_steps,
|
src/ci_triage_env/training/sft.py
CHANGED
|
@@ -1,12 +1,12 @@
|
|
| 1 |
-
"""SFT warmstart trainer — Qwen3
|
| 2 |
|
| 3 |
-
All GPU-heavy imports (
|
| 4 |
importable without a GPU for testing.
|
| 5 |
"""
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
|
| 9 |
-
MODEL_NAME = "Qwen/Qwen3
|
| 10 |
MAX_SEQ_LEN = 8192
|
| 11 |
|
| 12 |
|
|
@@ -14,28 +14,48 @@ def load_model_for_sft(
|
|
| 14 |
model_name: str = MODEL_NAME,
|
| 15 |
max_seq_length: int = MAX_SEQ_LEN,
|
| 16 |
):
|
| 17 |
-
"""Load
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
-
|
| 21 |
-
model_name=model_name,
|
| 22 |
-
max_seq_length=max_seq_length,
|
| 23 |
load_in_4bit=True,
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
)
|
| 26 |
-
model
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
r=16,
|
|
|
|
| 29 |
target_modules=[
|
| 30 |
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 31 |
"gate_proj", "up_proj", "down_proj",
|
| 32 |
],
|
| 33 |
-
lora_alpha=32,
|
| 34 |
lora_dropout=0.0,
|
| 35 |
bias="none",
|
| 36 |
-
|
| 37 |
-
random_state=3407,
|
| 38 |
)
|
|
|
|
|
|
|
| 39 |
return model, tokenizer
|
| 40 |
|
| 41 |
|
|
@@ -57,7 +77,7 @@ def run_sft(
|
|
| 57 |
gradient_accumulation_steps: int = 4,
|
| 58 |
model_name: str = MODEL_NAME,
|
| 59 |
) -> str:
|
| 60 |
-
"""Train the SFT warmstart model. Requires GPU +
|
| 61 |
|
| 62 |
Args:
|
| 63 |
dataset_path: Path to a HF Dataset saved by trajectory_gen (save_to_disk).
|
|
|
|
| 1 |
+
"""SFT warmstart trainer — Qwen3-4B + LoRA on the C3 trajectory dataset.
|
| 2 |
|
| 3 |
+
All GPU-heavy imports (trl, torch, peft) are lazy so the module is
|
| 4 |
importable without a GPU for testing.
|
| 5 |
"""
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
|
| 9 |
+
MODEL_NAME = "Qwen/Qwen3-4B"
|
| 10 |
MAX_SEQ_LEN = 8192
|
| 11 |
|
| 12 |
|
|
|
|
| 14 |
model_name: str = MODEL_NAME,
|
| 15 |
max_seq_length: int = MAX_SEQ_LEN,
|
| 16 |
):
|
| 17 |
+
"""Load Qwen3-4B in 4-bit via bitsandbytes + LoRA via PEFT. Requires GPU."""
|
| 18 |
+
import torch
|
| 19 |
+
from peft import LoraConfig, TaskType, get_peft_model # type: ignore[import]
|
| 20 |
+
from transformers import ( # type: ignore[import]
|
| 21 |
+
AutoModelForCausalLM,
|
| 22 |
+
AutoTokenizer,
|
| 23 |
+
BitsAndBytesConfig,
|
| 24 |
+
)
|
| 25 |
|
| 26 |
+
bnb_config = BitsAndBytesConfig(
|
|
|
|
|
|
|
| 27 |
load_in_4bit=True,
|
| 28 |
+
bnb_4bit_quant_type="nf4",
|
| 29 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 30 |
+
bnb_4bit_use_double_quant=True,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 34 |
+
model_name,
|
| 35 |
+
quantization_config=bnb_config,
|
| 36 |
+
device_map="auto",
|
| 37 |
+
trust_remote_code=True,
|
| 38 |
)
|
| 39 |
+
model.gradient_checkpointing_enable()
|
| 40 |
+
|
| 41 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
| 42 |
+
if tokenizer.pad_token is None:
|
| 43 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 44 |
+
tokenizer.model_max_length = max_seq_length
|
| 45 |
+
|
| 46 |
+
lora_config = LoraConfig(
|
| 47 |
r=16,
|
| 48 |
+
lora_alpha=32,
|
| 49 |
target_modules=[
|
| 50 |
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 51 |
"gate_proj", "up_proj", "down_proj",
|
| 52 |
],
|
|
|
|
| 53 |
lora_dropout=0.0,
|
| 54 |
bias="none",
|
| 55 |
+
task_type=TaskType.CAUSAL_LM,
|
|
|
|
| 56 |
)
|
| 57 |
+
model = get_peft_model(model, lora_config)
|
| 58 |
+
model.print_trainable_parameters()
|
| 59 |
return model, tokenizer
|
| 60 |
|
| 61 |
|
|
|
|
| 77 |
gradient_accumulation_steps: int = 4,
|
| 78 |
model_name: str = MODEL_NAME,
|
| 79 |
) -> str:
|
| 80 |
+
"""Train the SFT warmstart model. Requires GPU + trl + peft + bitsandbytes.
|
| 81 |
|
| 82 |
Args:
|
| 83 |
dataset_path: Path to a HF Dataset saved by trajectory_gen (save_to_disk).
|