| """LoRA / QLoRA utilities.""" | |
| from __future__ import annotations | |
| from typing import Any | |
def build_lora_config(rank: int = 16, alpha: int = 32, dropout: float = 0.05) -> dict[str, Any]:
    """Assemble the settings mapping for a causal-LM LoRA adapter.

    Args:
        rank: Value stored under the ``"r"`` key.
        alpha: Value stored under the ``"lora_alpha"`` key.
        dropout: Value stored under the ``"lora_dropout"`` key.

    Returns:
        A newly created dict holding the three values above followed by
        the fixed entries ``"bias": "none"`` and
        ``"task_type": "CAUSAL_LM"``. No validation is performed on the
        arguments; they are stored as given.
    """
    # Caller-tunable entries first, so the key order stays r / alpha / dropout.
    config: dict[str, Any] = dict(r=rank, lora_alpha=alpha, lora_dropout=dropout)
    # Fixed entries that this helper never varies.
    config["bias"] = "none"
    config["task_type"] = "CAUSAL_LM"
    return config
def build_qlora_config(rank: int = 16, alpha: int = 32, dropout: float = 0.05) -> dict[str, Any]:
    """Assemble a LoRA settings mapping extended with 4-bit quantization keys.

    Args:
        rank: Forwarded to ``build_lora_config`` as ``rank``.
        alpha: Forwarded to ``build_lora_config`` as ``alpha``.
        dropout: Forwarded to ``build_lora_config`` as ``dropout``.

    Returns:
        A dict containing every entry produced by ``build_lora_config``
        followed by the fixed entries ``"load_in_4bit": True``,
        ``"bnb_4bit_quant_type": "nf4"``,
        ``"bnb_4bit_compute_dtype": "bfloat16"`` and
        ``"bnb_4bit_use_double_quant": True``.
    """
    # Fixed quantization entries layered on top of the plain LoRA settings.
    quantization_overrides: dict[str, Any] = {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": "bfloat16",
        "bnb_4bit_use_double_quant": True,
    }
    lora_settings = build_lora_config(rank=rank, alpha=alpha, dropout=dropout)
    # LoRA keys come first; the overrides follow (and would win on a key clash).
    return {**lora_settings, **quantization_overrides}