File size: 749 Bytes
877add7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
"""LoRA / QLoRA utilities."""

from __future__ import annotations

from typing import Any


def build_lora_config(rank: int = 16, alpha: int = 32, dropout: float = 0.05) -> dict[str, Any]:
    """Return a LoRA settings dictionary.

    Args:
        rank: Value stored under the ``"r"`` key.
        alpha: Value stored under the ``"lora_alpha"`` key.
        dropout: Value stored under the ``"lora_dropout"`` key.

    Returns:
        A freshly created dict holding the three caller-supplied values plus
        the fixed entries ``"bias": "none"`` and ``"task_type": "CAUSAL_LM"``.
    """
    # Caller-tunable hyperparameters come first so key order stays stable.
    config: dict[str, Any] = {"r": rank, "lora_alpha": alpha, "lora_dropout": dropout}

    # Settings this module pins regardless of the arguments.
    config["bias"] = "none"
    config["task_type"] = "CAUSAL_LM"
    return config


def build_qlora_config(rank: int = 16, alpha: int = 32, dropout: float = 0.05) -> dict[str, Any]:
    """Return a QLoRA settings dictionary.

    The result is the mapping produced by ``build_lora_config`` for the same
    arguments, extended with four 4-bit quantization entries.

    Args:
        rank: Forwarded to ``build_lora_config`` as ``rank``.
        alpha: Forwarded to ``build_lora_config`` as ``alpha``.
        dropout: Forwarded to ``build_lora_config`` as ``dropout``.

    Returns:
        A dict containing the LoRA entries followed by ``"load_in_4bit"``,
        ``"bnb_4bit_quant_type"``, ``"bnb_4bit_compute_dtype"`` and
        ``"bnb_4bit_use_double_quant"``.
    """
    # Quantization entries layered on top of the plain LoRA settings.
    quantization_settings: dict[str, Any] = {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": "bfloat16",
        "bnb_4bit_use_double_quant": True,
    }
    lora_settings = build_lora_config(rank=rank, alpha=alpha, dropout=dropout)
    # Later unpacking wins on key clashes, matching dict.update semantics.
    return {**lora_settings, **quantization_settings}