Prasham.Jain Claude Sonnet 4.6 committed on
Commit
68277e2
·
1 Parent(s): e3da0da

fix(training): drop unsloth, use bitsandbytes+PEFT for SFT

Browse files

Unsloth requires transformers>=4.51 (for CompileConfig) but torch 2.4.1
in the Docker image forces transformers<=4.46 to avoid the torchao
version conflict. Remove unsloth entirely; use AutoModelForCausalLM +
BitsAndBytesConfig (nf4 4-bit) + PEFT LoRA instead — works identically
on 46 GB VRAM with no version conflicts.

Also fixes:
- MODEL_NAME: Qwen/Qwen3.5-4B → Qwen/Qwen3-4B (correct model ID)
- grpo.py: hp.pop() was called before the hp dict was constructed (NameError)
- Dockerfile.train: remove unsloth install step, add bitsandbytes pin

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Dockerfile.train CHANGED
@@ -20,21 +20,18 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
20
 
21
  WORKDIR /workspace
22
 
23
- # 1. Pin torchao BEFORE installing anything else.
24
- # Latest torchao requires torch>=2.11 but this image ships torch 2.4.
25
- # transformers>=4.47 pulls torchao as a dep, so we must pin transformers too.
26
  RUN pip install --no-cache-dir \
27
  "torchao==0.5.0" \
28
  "transformers==4.46.3" \
29
  "trl==0.11.4" \
30
  "peft==0.13.2" \
31
- "accelerate==0.34.2"
 
32
 
33
- # 2. Install unsloth (must come after torch)
34
- RUN pip install --no-cache-dir \
35
- "unsloth[cu121-torch240] @ git+https://github.com/unslothai/unsloth.git"
36
-
37
- # 3. Install project deps (transformers/trl/peft already pinned above, won't be overridden)
38
  COPY pyproject.toml README.md ./
39
  COPY src/ src/
40
  RUN pip install --no-cache-dir -e ".[data,training]"
 
20
 
21
  WORKDIR /workspace
22
 
23
+ # 1. Pin versions compatible with torch 2.4.1 in this image.
24
+ # torchao latest requires torch>=2.11; transformers>=4.47 pulls torchao as dep.
25
+ # bitsandbytes replaces unsloth for 4-bit quantisation.
26
  RUN pip install --no-cache-dir \
27
  "torchao==0.5.0" \
28
  "transformers==4.46.3" \
29
  "trl==0.11.4" \
30
  "peft==0.13.2" \
31
+ "accelerate==0.34.2" \
32
+ "bitsandbytes>=0.43.0"
33
 
34
+ # 2. Install project deps (versions pinned above won't be overridden)
 
 
 
 
35
  COPY pyproject.toml README.md ./
36
  COPY src/ src/
37
  RUN pip install --no-cache-dir -e ".[data,training]"
src/ci_triage_env/training/grpo.py CHANGED
@@ -64,6 +64,10 @@ def run_grpo(
64
  train_dir = Path(scenarios_train_path)
65
  scenario_ids = [p.stem for p in train_dir.rglob("*.json")] if train_dir.exists() else []
66
 
 
 
 
 
67
  max_turns = hp.pop("max_turns", 4) # short episodes for faster GRPO
68
  rollout = TrainingRollout(
69
  env_client=env_client,
@@ -74,10 +78,6 @@ def run_grpo(
74
 
75
  model, tokenizer = load_model_for_sft(model_name=sft_checkpoint_dir)
76
 
77
- hp = dict(GRPO_HYPERPARAMS)
78
- if hyperparams:
79
- hp.update(hyperparams)
80
-
81
  config = GRPOConfig(
82
  output_dir=output_dir,
83
  max_steps=total_steps,
 
64
  train_dir = Path(scenarios_train_path)
65
  scenario_ids = [p.stem for p in train_dir.rglob("*.json")] if train_dir.exists() else []
66
 
67
+ hp = dict(GRPO_HYPERPARAMS)
68
+ if hyperparams:
69
+ hp.update(hyperparams)
70
+
71
  max_turns = hp.pop("max_turns", 4) # short episodes for faster GRPO
72
  rollout = TrainingRollout(
73
  env_client=env_client,
 
78
 
79
  model, tokenizer = load_model_for_sft(model_name=sft_checkpoint_dir)
80
 
 
 
 
 
81
  config = GRPOConfig(
82
  output_dir=output_dir,
83
  max_steps=total_steps,
src/ci_triage_env/training/sft.py CHANGED
@@ -1,12 +1,12 @@
1
- """SFT warmstart trainer — Qwen3.5-4B + LoRA on the C3 trajectory dataset.
2
 
3
- All GPU-heavy imports (unsloth, trl, torch) are lazy so the module is
4
  importable without a GPU for testing.
5
  """
6
 
7
  from __future__ import annotations
8
 
9
- MODEL_NAME = "Qwen/Qwen3.5-4B"
10
  MAX_SEQ_LEN = 8192
11
 
12
 
@@ -14,28 +14,48 @@ def load_model_for_sft(
14
  model_name: str = MODEL_NAME,
15
  max_seq_length: int = MAX_SEQ_LEN,
16
  ):
17
- """Load Qwen model with Unsloth + LoRA. Requires GPU and unsloth installed."""
18
- from unsloth import FastLanguageModel # type: ignore[import]
 
 
 
 
 
 
19
 
20
- model, tokenizer = FastLanguageModel.from_pretrained(
21
- model_name=model_name,
22
- max_seq_length=max_seq_length,
23
  load_in_4bit=True,
24
- dtype=None,
 
 
 
 
 
 
 
 
 
25
  )
26
- model = FastLanguageModel.get_peft_model(
27
- model,
 
 
 
 
 
 
28
  r=16,
 
29
  target_modules=[
30
  "q_proj", "k_proj", "v_proj", "o_proj",
31
  "gate_proj", "up_proj", "down_proj",
32
  ],
33
- lora_alpha=32,
34
  lora_dropout=0.0,
35
  bias="none",
36
- use_gradient_checkpointing="unsloth",
37
- random_state=3407,
38
  )
 
 
39
  return model, tokenizer
40
 
41
 
@@ -57,7 +77,7 @@ def run_sft(
57
  gradient_accumulation_steps: int = 4,
58
  model_name: str = MODEL_NAME,
59
  ) -> str:
60
- """Train the SFT warmstart model. Requires GPU + unsloth + trl installed.
61
 
62
  Args:
63
  dataset_path: Path to a HF Dataset saved by trajectory_gen (save_to_disk).
 
1
+ """SFT warmstart trainer — Qwen3-4B + LoRA on the C3 trajectory dataset.
2
 
3
+ All GPU-heavy imports (trl, torch, peft) are lazy so the module is
4
  importable without a GPU for testing.
5
  """
6
 
7
  from __future__ import annotations
8
 
9
+ MODEL_NAME = "Qwen/Qwen3-4B"
10
  MAX_SEQ_LEN = 8192
11
 
12
 
 
14
  model_name: str = MODEL_NAME,
15
  max_seq_length: int = MAX_SEQ_LEN,
16
  ):
17
+ """Load Qwen3-4B in 4-bit via bitsandbytes + LoRA via PEFT. Requires GPU."""
18
+ import torch
19
+ from peft import LoraConfig, TaskType, get_peft_model # type: ignore[import]
20
+ from transformers import ( # type: ignore[import]
21
+ AutoModelForCausalLM,
22
+ AutoTokenizer,
23
+ BitsAndBytesConfig,
24
+ )
25
 
26
+ bnb_config = BitsAndBytesConfig(
 
 
27
  load_in_4bit=True,
28
+ bnb_4bit_quant_type="nf4",
29
+ bnb_4bit_compute_dtype=torch.bfloat16,
30
+ bnb_4bit_use_double_quant=True,
31
+ )
32
+
33
+ model = AutoModelForCausalLM.from_pretrained(
34
+ model_name,
35
+ quantization_config=bnb_config,
36
+ device_map="auto",
37
+ trust_remote_code=True,
38
  )
39
+ model.gradient_checkpointing_enable()
40
+
41
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
42
+ if tokenizer.pad_token is None:
43
+ tokenizer.pad_token = tokenizer.eos_token
44
+ tokenizer.model_max_length = max_seq_length
45
+
46
+ lora_config = LoraConfig(
47
  r=16,
48
+ lora_alpha=32,
49
  target_modules=[
50
  "q_proj", "k_proj", "v_proj", "o_proj",
51
  "gate_proj", "up_proj", "down_proj",
52
  ],
 
53
  lora_dropout=0.0,
54
  bias="none",
55
+ task_type=TaskType.CAUSAL_LM,
 
56
  )
57
+ model = get_peft_model(model, lora_config)
58
+ model.print_trainable_parameters()
59
  return model, tokenizer
60
 
61
 
 
77
  gradient_accumulation_steps: int = 4,
78
  model_name: str = MODEL_NAME,
79
  ) -> str:
80
+ """Train the SFT warmstart model. Requires GPU + trl + peft + bitsandbytes.
81
 
82
  Args:
83
  dataset_path: Path to a HF Dataset saved by trajectory_gen (save_to_disk).