Prasham.Jain Claude Sonnet 4.6 committed on
Commit
8580936
·
1 Parent(s): 4647df7

feat(training): switch from LoRA to QLoRA per mentor recommendation

Browse files

QLoRA = NF4-quantized frozen base + bf16 LoRA adapters trained on top.
unsloth's load_in_4bit=True already implements this; make it explicit:

- lora_alpha: 16 → 32 (= 2×r, a widely used LoRA scaling heuristic)
- lora_dropout: 0 → 0.05 (standard for QLoRA regularisation)
- Add module-level docstring explaining the QLoRA setup clearly
- Note: unsloth calls prepare_model_for_kbit_training() internally so
no manual call needed

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. src/ci_triage_env/training/sft.py +18 -8
src/ci_triage_env/training/sft.py CHANGED
@@ -1,4 +1,10 @@
1
- """SFT warmstart trainer — Qwen3-4B + LoRA via unsloth.
 
 
 
 
 
 
2
 
3
  All GPU-heavy imports (unsloth, trl, torch) are lazy so the module is
4
  importable without a GPU for testing.
@@ -6,8 +12,8 @@ importable without a GPU for testing.
6
 
7
  from __future__ import annotations
8
 
9
- # unsloth hosts optimised weights; the bnb-4bit variant skips on-the-fly quantisation
10
- # so it loads ~2x faster than the base float16 weights.
11
  MODEL_NAME = "unsloth/Qwen3-4B-bnb-4bit"
12
  MAX_SEQ_LEN = 8192
13
 
@@ -16,16 +22,20 @@ def load_model_for_sft(
16
  model_name: str = MODEL_NAME,
17
  max_seq_length: int = MAX_SEQ_LEN,
18
  ):
19
- """Load Qwen3-4B with unsloth 4-bit + LoRA. Requires GPU and unsloth installed."""
20
  from unsloth import FastLanguageModel # type: ignore[import]
21
 
 
 
22
  model, tokenizer = FastLanguageModel.from_pretrained(
23
  model_name=model_name,
24
  max_seq_length=max_seq_length,
25
  load_in_4bit=True,
26
- dtype=None, # auto — bfloat16 on Ampere+
27
  )
28
 
 
 
29
  model = FastLanguageModel.get_peft_model(
30
  model,
31
  r=16,
@@ -33,10 +43,10 @@ def load_model_for_sft(
33
  "q_proj", "k_proj", "v_proj", "o_proj",
34
  "gate_proj", "up_proj", "down_proj",
35
  ],
36
- lora_alpha=16,
37
- lora_dropout=0,
38
  bias="none",
39
- use_gradient_checkpointing="unsloth", # unsloth's gradient checkpointing is 30% faster
40
  random_state=3407,
41
  )
42
  return model, tokenizer
 
1
+ """SFT warmstart trainer — Qwen3-4B + QLoRA via unsloth.
2
+
3
+ QLoRA setup:
4
+ - Base model loaded in NF4 4-bit (frozen) — the "Q" in QLoRA
5
+ - LoRA adapter matrices trained in bf16 on top — the "LoRA" part
6
+ - Double quantization enabled by default in unsloth, saving roughly 0.4 bits per parameter
7
+ - unsloth calls prepare_model_for_kbit_training() internally
8
 
9
  All GPU-heavy imports (unsloth, trl, torch) are lazy so the module is
10
  importable without a GPU for testing.
 
12
 
13
  from __future__ import annotations
14
 
15
+ # unsloth hosts optimised weights; the bnb-4bit variant is pre-quantised to NF4
16
+ # so it loads ~2x faster than quantising float16 on the fly.
17
  MODEL_NAME = "unsloth/Qwen3-4B-bnb-4bit"
18
  MAX_SEQ_LEN = 8192
19
 
 
22
  model_name: str = MODEL_NAME,
23
  max_seq_length: int = MAX_SEQ_LEN,
24
  ):
25
+ """Load Qwen3-4B with QLoRA (NF4 base + bf16 LoRA adapters) via unsloth."""
26
  from unsloth import FastLanguageModel # type: ignore[import]
27
 
28
+ # load_in_4bit=True → base weights frozen in NF4 4-bit (QLoRA)
29
+ # dtype=None → adapter compute dtype auto-selected (bf16 on Ampere+)
30
  model, tokenizer = FastLanguageModel.from_pretrained(
31
  model_name=model_name,
32
  max_seq_length=max_seq_length,
33
  load_in_4bit=True,
34
+ dtype=None,
35
  )
36
 
37
+ # QLoRA LoRA config: r=16 is standard for 4B models.
38
+ # lora_alpha=32 (= 2×r) follows a widely used LoRA scaling heuristic.
39
  model = FastLanguageModel.get_peft_model(
40
  model,
41
  r=16,
 
43
  "q_proj", "k_proj", "v_proj", "o_proj",
44
  "gate_proj", "up_proj", "down_proj",
45
  ],
46
+ lora_alpha=32,
47
+ lora_dropout=0.05,
48
  bias="none",
49
+ use_gradient_checkpointing="unsloth", # 30% faster than standard
50
  random_state=3407,
51
  )
52
  return model, tokenizer