"""
Causal LM 模型加载:设备策略与加载逻辑统一封装

供 language_checker.QwenLM(信息密度分析)与 model_manager.ensure_model_loaded 共用,
消除重复的设备分支、量化配置、加载后处理等逻辑。

加载策略说明:
- INT8 量化:bitsandbytes 8bit,device_map="cpu"/"auto",减少约 4 倍内存
- CPU 手动模式:无 device_map,.to(device),默认 float32
- GPU/MPS 自动模式:device_map="auto",float16

dtype/设备与因果 LM 在「仅前缀 forward」vs「整段 forward」同位置 logits 的逐元素对比:
float32(CPU)常完全一致;float16(MPS/CUDA)可能因实现路径出现约 1e-2 量级差,非掩码错误。
复现与说明见 scripts/reproduce_logits_triple_path.py、scripts/prove_fp16_gemm_shape_sensitivity.py。
"""

import os
import time
from typing import Any, Dict, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import is_flash_attn_2_available

from .device import DeviceManager
from .load_utils import resolve_and_load
from .quantization_config import get_quantization_config


def get_device_load_strategy(device: torch.device) -> Dict[str, Any]:
    """
    根据设备推断加载策略(device_map、dtype、use_int8 等)。

    打印设备模式说明,与 QwenLM 风格一致。
    环境变量:FORCE_INT8=1 / CPU_FORCE_BFLOAT16=1
    返回供 load_causal_lm 使用的参数字典。
    """
    qconfig = get_quantization_config(device)
    use_int8 = qconfig.use_int8
    device_map = None
    dtype = qconfig.dtype
    use_low_cpu_mem = False

    if device.type == "cpu":
        print("🔧 CPU 模式:手动控制设备分配")
        if use_int8:
            device_map = "cpu"
            print("⚠️  启用 INT8 量化(FORCE_INT8=1,实验性,在某些情况下会降低性能)")
        elif dtype == torch.bfloat16:
            use_low_cpu_mem = True
            print("⚠️  启用 bfloat16(CPU_FORCE_BFLOAT16=1,需硬件支持 AVX-512_BF16 或 AMX,否则可能极慢)")
        else:
            use_low_cpu_mem = True
            print("🔧 dtype: float32")  # 默认: float32
    elif device.type == "cuda":
        print("🔧 CUDA 模式:自动设备分配")
        device_map = "auto"
        use_low_cpu_mem = True
        if use_int8:
            print("⚠️  启用 INT8 量化(FORCE_INT8=1)")
        else:
            print("🔧 dtype: float16")
        print("🔧 device_map: auto")
    else:
        # MPS 模式:自动设备分配 + float16(MPS 不支持 INT8 量化)
        print(f"🔧 {device.type.upper()} 模式:自动设备分配")
        if os.environ.get("FORCE_INT8") == "1":
            print("⚠️  MPS 不支持 INT8 量化,已忽略 FORCE_INT8=1 环境变量")
        device_map = "auto"
        use_low_cpu_mem = True
        print("🔧 dtype: float16")
        print("🔧 device_map: auto")

    return {
        "device_map": device_map,
        "dtype": dtype,
        "use_low_cpu_mem": use_low_cpu_mem,
        "use_int8": use_int8,
    }
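
# Illustrative result (a sketch, assuming neither FORCE_INT8 nor CPU_FORCE_BFLOAT16 is
# set and get_quantization_config returns float32 / use_int8=False for CPU):
#
#     get_device_load_strategy(torch.device("cpu"))
#     # -> {"device_map": None, "dtype": torch.float32,
#     #     "use_low_cpu_mem": True, "use_int8": False}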


def attn_implementation_for_device(device: torch.device) -> str:
    """
    Non-CUDA (CPU / MPS, etc.): "eager", for best compatibility.
    CUDA: "flash_attention_2" when flash-attn is installed; otherwise "eager" (sdpa is not used).
    """
    if device.type != "cuda":
        return "eager"
    if is_flash_attn_2_available():
        return "flash_attention_2"
    return "eager"


def load_causal_lm(
    model_path: str,
    device: torch.device,
    *,
    attn_implementation: Optional[str] = None,
    extra_model_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.nn.Module:
    """
    加载 Causal LM 模型,统一处理设备策略、量化、加载后处理。

    Args:
        model_path: HuggingFace 模型路径或本地路径
        device: 目标设备
        attn_implementation: 可选;未传时可在外层用 attn_implementation_for_device(device)
        extra_model_kwargs: 可选,额外传给 from_pretrained 的参数

    Returns:
        已 eval() 的模型
    """
    strategy = get_device_load_strategy(device)
    device_map = strategy["device_map"]
    dtype = strategy["dtype"]
    use_low_cpu_mem = strategy["use_low_cpu_mem"]
    use_int8 = strategy["use_int8"]

    load_kw: Dict[str, Any] = {
        "trust_remote_code": True,
        "low_cpu_mem_usage": use_low_cpu_mem or use_int8,
    }
    if attn_implementation is not None:
        load_kw["attn_implementation"] = attn_implementation
    if extra_model_kwargs:
        load_kw.update(extra_model_kwargs)

    def _load(path: str, lf: bool):
        kw = dict(local_files_only=lf, **load_kw)
        if use_int8:
            from transformers import BitsAndBytesConfig
            return AutoModelForCausalLM.from_pretrained(
                path,
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                device_map=device_map,
                **kw,
            )
        if device_map:
            return AutoModelForCausalLM.from_pretrained(
                path,
                device_map=device_map,
                dtype=dtype,
                **kw,
            )
        return AutoModelForCausalLM.from_pretrained(
            path, dtype=dtype, **kw
        ).to(device)

    t0 = time.perf_counter()
    model = resolve_and_load(model_path, _load)
    load_time = time.perf_counter() - t0

    DeviceManager.print_model_load_stats(model, load_time)
    model.eval()
    if device.type == "cuda":
        device_idx = device.index if device.index is not None else 0
        DeviceManager.print_cuda_memory_summary(device=device_idx)
    return model
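
# Typical call-site sketch (the checkpoint id is only a hypothetical example, not pinned
# by this module):
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = load_causal_lm(
#         "Qwen/Qwen2.5-0.5B-Instruct",  # hypothetical checkpoint id
#         device,
#         attn_implementation=attn_implementation_for_device(device),
#     )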


def load_tokenizer(model_path: str):
    """加载 tokenizer。本地优先时先解析为缓存路径,避免 tokenizer 内部 model_info 联网。"""

    def _load(path: str, lf: bool):
        return AutoTokenizer.from_pretrained(
            path, trust_remote_code=True, local_files_only=lf
        )

    return resolve_and_load(model_path, _load)
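

# A minimal smoke check (a sketch): it only exercises the cheap helpers and downloads no
# weights. Because this module uses relative imports, run it as a module, e.g.
# `python -m <package>.<this_module>` (package/module names depend on your layout).
if __name__ == "__main__":
    dev = torch.device("cpu")
    print("attn_implementation:", attn_implementation_for_device(dev))
    print("strategy:", get_device_load_strategy(dev))
    # tokenizer = load_tokenizer("Qwen/Qwen2.5-0.5B-Instruct")  # hypothetical checkpoint id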