Prasham.Jain Claude Sonnet 4.6 commited on
Commit
ddfe351
·
1 Parent(s): 68277e2

fix(training): upgrade to torch 2.5.1+cu124, restore unsloth for Qwen3

Browse files

Root cause of the dependency chain:
torch 2.4 → can't use torchao>0.5 → must pin transformers<4.47
transformers<4.47 → no Qwen3 (qwen3_5 arch added in 4.51)
transformers<4.51 → no CompileConfig → unsloth import fails

Fix: bump base Docker image to pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel
- torch 2.5.1 is compatible with modern torchao
- unsloth[cu124-torch251] installs transformers>=4.51, peft, trl, xformers
- Qwen3-4B architecture (qwen3_5) now recognized by transformers

sft.py:
- Restore unsloth FastLanguageModel (use_gradient_checkpointing="unsloth")
- MODEL_NAME = "unsloth/Qwen3-4B-bnb-4bit" (pre-quantized, 2x faster load)
- Add bf16=True, dataset_text_field="text" to SFTConfig

pyproject.toml:
- Bump training extras to transformers>=4.51, trl>=0.12, peft>=0.14, torch>=2.5

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Dockerfile.train CHANGED
@@ -9,7 +9,8 @@
9
  # HF_SCENARIOS_REPO, HF_SFT_DATASET_REPO, HF_MODEL_REPO (optional)
10
  # GRPO_STEPS (optional, default 100)
11
 
12
- FROM pytorch/pytorch:2.4.0-cuda12.1-cudnn9-devel
 
13
 
14
  ENV DEBIAN_FRONTEND=noninteractive
15
  ENV PYTHONUNBUFFERED=1
@@ -20,18 +21,13 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
20
 
21
  WORKDIR /workspace
22
 
23
- # 1. Pin versions compatible with torch 2.4.1 in this image.
24
- # torchao latest requires torch>=2.11; transformers>=4.47 pulls torchao as dep.
25
- # bitsandbytes replaces unsloth for 4-bit quantisation.
26
  RUN pip install --no-cache-dir \
27
- "torchao==0.5.0" \
28
- "transformers==4.46.3" \
29
- "trl==0.11.4" \
30
- "peft==0.13.2" \
31
- "accelerate==0.34.2" \
32
- "bitsandbytes>=0.43.0"
33
-
34
- # 2. Install project deps (versions pinned above won't be overridden)
35
  COPY pyproject.toml README.md ./
36
  COPY src/ src/
37
  RUN pip install --no-cache-dir -e ".[data,training]"
 
9
  # HF_SCENARIOS_REPO, HF_SFT_DATASET_REPO, HF_MODEL_REPO (optional)
10
  # GRPO_STEPS (optional, default 100)
11
 
12
+ # torch 2.5.1 + CUDA 12.4 — minimum needed for unsloth + transformers>=4.51 + Qwen3.
13
+ FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel
14
 
15
  ENV DEBIAN_FRONTEND=noninteractive
16
  ENV PYTHONUNBUFFERED=1
 
21
 
22
  WORKDIR /workspace
23
 
24
+ # 1. Install unsloth for this exact torch/CUDA combo.
25
+ # This resolves and installs compatible versions of:
26
+ # transformers>=4.51 (Qwen3 + CompileConfig), peft, trl, accelerate, xformers.
27
  RUN pip install --no-cache-dir \
28
+ "unsloth[cu124-torch251] @ git+https://github.com/unslothai/unsloth.git"
29
+
30
+ # 2. Install project deps (unsloth already locked transformers/trl/peft above).
 
 
 
 
 
31
  COPY pyproject.toml README.md ./
32
  COPY src/ src/
33
  RUN pip install --no-cache-dir -e ".[data,training]"
pyproject.toml CHANGED
@@ -19,10 +19,11 @@ dependencies = [
19
 
20
  [project.optional-dependencies]
21
  training = [
22
- "torch>=2.3",
23
- "transformers>=4.45",
24
- "trl>=0.11",
25
- "accelerate>=0.30",
 
26
  "wandb>=0.17",
27
  "matplotlib>=3.8",
28
  "seaborn>=0.13",
 
19
 
20
  [project.optional-dependencies]
21
  training = [
22
+ "torch>=2.5",
23
+ "transformers>=4.51",
24
+ "trl>=0.12",
25
+ "peft>=0.14",
26
+ "accelerate>=0.34",
27
  "wandb>=0.17",
28
  "matplotlib>=3.8",
29
  "seaborn>=0.13",
src/ci_triage_env/training/sft.py CHANGED
@@ -1,12 +1,14 @@
1
- """SFT warmstart trainer — Qwen3-4B + LoRA on the C3 trajectory dataset.
2
 
3
- All GPU-heavy imports (trl, torch, peft) are lazy so the module is
4
  importable without a GPU for testing.
5
  """
6
 
7
  from __future__ import annotations
8
 
9
- MODEL_NAME = "Qwen/Qwen3-4B"
 
 
10
  MAX_SEQ_LEN = 8192
11
 
12
 
@@ -14,48 +16,29 @@ def load_model_for_sft(
14
  model_name: str = MODEL_NAME,
15
  max_seq_length: int = MAX_SEQ_LEN,
16
  ):
17
- """Load Qwen3-4B in 4-bit via bitsandbytes + LoRA via PEFT. Requires GPU."""
18
- import torch
19
- from peft import LoraConfig, TaskType, get_peft_model # type: ignore[import]
20
- from transformers import ( # type: ignore[import]
21
- AutoModelForCausalLM,
22
- AutoTokenizer,
23
- BitsAndBytesConfig,
24
- )
25
 
26
- bnb_config = BitsAndBytesConfig(
 
 
27
  load_in_4bit=True,
28
- bnb_4bit_quant_type="nf4",
29
- bnb_4bit_compute_dtype=torch.bfloat16,
30
- bnb_4bit_use_double_quant=True,
31
- )
32
-
33
- model = AutoModelForCausalLM.from_pretrained(
34
- model_name,
35
- quantization_config=bnb_config,
36
- device_map="auto",
37
- trust_remote_code=True,
38
  )
39
- model.gradient_checkpointing_enable()
40
-
41
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
42
- if tokenizer.pad_token is None:
43
- tokenizer.pad_token = tokenizer.eos_token
44
- tokenizer.model_max_length = max_seq_length
45
 
46
- lora_config = LoraConfig(
 
47
  r=16,
48
- lora_alpha=32,
49
  target_modules=[
50
  "q_proj", "k_proj", "v_proj", "o_proj",
51
  "gate_proj", "up_proj", "down_proj",
52
  ],
53
- lora_dropout=0.0,
 
54
  bias="none",
55
- task_type=TaskType.CAUSAL_LM,
 
56
  )
57
- model = get_peft_model(model, lora_config)
58
- model.print_trainable_parameters()
59
  return model, tokenizer
60
 
61
 
@@ -77,7 +60,7 @@ def run_sft(
77
  gradient_accumulation_steps: int = 4,
78
  model_name: str = MODEL_NAME,
79
  ) -> str:
80
- """Train the SFT warmstart model. Requires GPU + trl + peft + bitsandbytes.
81
 
82
  Args:
83
  dataset_path: Path to a HF Dataset saved by trajectory_gen (save_to_disk).
@@ -104,10 +87,13 @@ def run_sft(
104
  gradient_accumulation_steps=gradient_accumulation_steps,
105
  learning_rate=2e-5,
106
  warmup_ratio=0.05,
 
 
107
  logging_steps=10,
108
  save_steps=100,
109
  report_to="wandb",
110
  max_seq_length=MAX_SEQ_LEN,
 
111
  )
112
  trainer = SFTTrainer(
113
  model=model,
 
1
+ """SFT warmstart trainer — Qwen3-4B + LoRA via unsloth.
2
 
3
+ All GPU-heavy imports (unsloth, trl, torch) are lazy so the module is
4
  importable without a GPU for testing.
5
  """
6
 
7
  from __future__ import annotations
8
 
9
+ # unsloth hosts optimised weights; the bnb-4bit variant skips on-the-fly quantisation
10
+ # so it loads ~2x faster than the base float16 weights.
11
+ MODEL_NAME = "unsloth/Qwen3-4B-bnb-4bit"
12
  MAX_SEQ_LEN = 8192
13
 
14
 
 
16
  model_name: str = MODEL_NAME,
17
  max_seq_length: int = MAX_SEQ_LEN,
18
  ):
19
+ """Load Qwen3-4B with unsloth 4-bit + LoRA. Requires GPU and unsloth installed."""
20
+ from unsloth import FastLanguageModel # type: ignore[import]
 
 
 
 
 
 
21
 
22
+ model, tokenizer = FastLanguageModel.from_pretrained(
23
+ model_name=model_name,
24
+ max_seq_length=max_seq_length,
25
  load_in_4bit=True,
26
+ dtype=None, # auto — bfloat16 on Ampere+
 
 
 
 
 
 
 
 
 
27
  )
 
 
 
 
 
 
28
 
29
+ model = FastLanguageModel.get_peft_model(
30
+ model,
31
  r=16,
 
32
  target_modules=[
33
  "q_proj", "k_proj", "v_proj", "o_proj",
34
  "gate_proj", "up_proj", "down_proj",
35
  ],
36
+ lora_alpha=16,
37
+ lora_dropout=0,
38
  bias="none",
39
+ use_gradient_checkpointing="unsloth", # unsloth's gradient checkpointing is 30% faster
40
+ random_state=3407,
41
  )
 
 
42
  return model, tokenizer
43
 
44
 
 
60
  gradient_accumulation_steps: int = 4,
61
  model_name: str = MODEL_NAME,
62
  ) -> str:
63
+ """Train the SFT warmstart model. Requires GPU + unsloth + trl installed.
64
 
65
  Args:
66
  dataset_path: Path to a HF Dataset saved by trajectory_gen (save_to_disk).
 
87
  gradient_accumulation_steps=gradient_accumulation_steps,
88
  learning_rate=2e-5,
89
  warmup_ratio=0.05,
90
+ bf16=True,
91
+ fp16=False,
92
  logging_steps=10,
93
  save_steps=100,
94
  report_to="wandb",
95
  max_seq_length=MAX_SEQ_LEN,
96
+ dataset_text_field="text",
97
  )
98
  trainer = SFTTrainer(
99
  model=model,