Spaces:
Sleeping
Sleeping
Prasham.Jain Claude Sonnet 4.6 commited on
Commit ·
8580936
1
Parent(s): 4647df7
feat(training): switch from LoRA to QLoRA per mentor recommendation
Browse filesQLoRA = NF4-quantized frozen base + bf16 LoRA adapters trained on top.
unsloth's load_in_4bit=True already implements this; make it explicit:
- lora_alpha: 16 → 32 (= 2×r, follows the QLoRA paper scaling rule)
- lora_dropout: 0 → 0.05 (standard for QLoRA regularisation)
- Add module-level docstring explaining the QLoRA setup clearly
- Note: unsloth calls prepare_model_for_kbit_training() internally so
no manual call needed
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
src/ci_triage_env/training/sft.py
CHANGED
|
@@ -1,4 +1,10 @@
|
|
| 1 |
-
"""SFT warmstart trainer — Qwen3-4B +
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
All GPU-heavy imports (unsloth, trl, torch) are lazy so the module is
|
| 4 |
importable without a GPU for testing.
|
|
@@ -6,8 +12,8 @@ importable without a GPU for testing.
|
|
| 6 |
|
| 7 |
from __future__ import annotations
|
| 8 |
|
| 9 |
-
# unsloth hosts optimised weights; the bnb-4bit variant
|
| 10 |
-
# so it loads ~2x faster than
|
| 11 |
MODEL_NAME = "unsloth/Qwen3-4B-bnb-4bit"
|
| 12 |
MAX_SEQ_LEN = 8192
|
| 13 |
|
|
@@ -16,16 +22,20 @@ def load_model_for_sft(
|
|
| 16 |
model_name: str = MODEL_NAME,
|
| 17 |
max_seq_length: int = MAX_SEQ_LEN,
|
| 18 |
):
|
| 19 |
-
"""Load Qwen3-4B with
|
| 20 |
from unsloth import FastLanguageModel # type: ignore[import]
|
| 21 |
|
|
|
|
|
|
|
| 22 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 23 |
model_name=model_name,
|
| 24 |
max_seq_length=max_seq_length,
|
| 25 |
load_in_4bit=True,
|
| 26 |
-
dtype=None,
|
| 27 |
)
|
| 28 |
|
|
|
|
|
|
|
| 29 |
model = FastLanguageModel.get_peft_model(
|
| 30 |
model,
|
| 31 |
r=16,
|
|
@@ -33,10 +43,10 @@ def load_model_for_sft(
|
|
| 33 |
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 34 |
"gate_proj", "up_proj", "down_proj",
|
| 35 |
],
|
| 36 |
-
lora_alpha=
|
| 37 |
-
lora_dropout=0,
|
| 38 |
bias="none",
|
| 39 |
-
use_gradient_checkpointing="unsloth", #
|
| 40 |
random_state=3407,
|
| 41 |
)
|
| 42 |
return model, tokenizer
|
|
|
|
| 1 |
+
"""SFT warmstart trainer — Qwen3-4B + QLoRA via unsloth.
|
| 2 |
+
|
| 3 |
+
QLoRA setup:
|
| 4 |
+
- Base model loaded in NF4 4-bit (frozen) — the "Q" in QLoRA
|
| 5 |
+
- LoRA adapter matrices trained in bf16 on top — the "LoRA" part
|
| 6 |
+
- Double quantization enabled by default in unsloth for ~0.4 bit extra savings
|
| 7 |
+
- unsloth calls prepare_model_for_kbit_training() internally
|
| 8 |
|
| 9 |
All GPU-heavy imports (unsloth, trl, torch) are lazy so the module is
|
| 10 |
importable without a GPU for testing.
|
|
|
|
| 12 |
|
| 13 |
from __future__ import annotations
|
| 14 |
|
| 15 |
+
# unsloth hosts optimised weights; the bnb-4bit variant is pre-quantised to NF4
|
| 16 |
+
# so it loads ~2x faster than quantising float16 on the fly.
|
| 17 |
MODEL_NAME = "unsloth/Qwen3-4B-bnb-4bit"
|
| 18 |
MAX_SEQ_LEN = 8192
|
| 19 |
|
|
|
|
| 22 |
model_name: str = MODEL_NAME,
|
| 23 |
max_seq_length: int = MAX_SEQ_LEN,
|
| 24 |
):
|
| 25 |
+
"""Load Qwen3-4B with QLoRA (NF4 base + bf16 LoRA adapters) via unsloth."""
|
| 26 |
from unsloth import FastLanguageModel # type: ignore[import]
|
| 27 |
|
| 28 |
+
# load_in_4bit=True → base weights frozen in NF4 4-bit (QLoRA)
|
| 29 |
+
# dtype=None → adapter compute dtype auto-selected (bf16 on Ampere+)
|
| 30 |
model, tokenizer = FastLanguageModel.from_pretrained(
|
| 31 |
model_name=model_name,
|
| 32 |
max_seq_length=max_seq_length,
|
| 33 |
load_in_4bit=True,
|
| 34 |
+
dtype=None,
|
| 35 |
)
|
| 36 |
|
| 37 |
+
# QLoRA LoRA config: r=16 is standard for 4B models.
|
| 38 |
+
# lora_alpha=32 (= 2×r) follows the QLoRA paper's scaling recommendation.
|
| 39 |
model = FastLanguageModel.get_peft_model(
|
| 40 |
model,
|
| 41 |
r=16,
|
|
|
|
| 43 |
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 44 |
"gate_proj", "up_proj", "down_proj",
|
| 45 |
],
|
| 46 |
+
lora_alpha=32,
|
| 47 |
+
lora_dropout=0.05,
|
| 48 |
bias="none",
|
| 49 |
+
use_gradient_checkpointing="unsloth", # 30% faster than standard
|
| 50 |
random_state=3407,
|
| 51 |
)
|
| 52 |
return model, tokenizer
|