| """ |
| Causal LM 模型加载:设备策略与加载逻辑统一封装 |
| |
| 供 language_checker.QwenLM(信息密度分析)与 model_manager.ensure_model_loaded 共用, |
| 消除重复的设备分支、量化配置、加载后处理等逻辑。 |
| |
| 加载策略说明: |
| - INT8 量化:bitsandbytes 8bit,device_map="cpu"/"auto",减少约 4 倍内存 |
| - CPU 手动模式:无 device_map,.to(device),默认 float32 |
| - GPU/MPS 自动模式:device_map="auto",float16 |
| |
| dtype/设备与因果 LM 在「仅前缀 forward」vs「整段 forward」同位置 logits 的逐元素对比: |
| float32(CPU)常完全一致;float16(MPS/CUDA)可能因实现路径出现约 1e-2 量级差,非掩码错误。 |
| 复现与说明见 scripts/reproduce_logits_triple_path.py、scripts/prove_fp16_gemm_shape_sensitivity.py。 |
| """ |
|
|
| import os |
| import time |
| from typing import Any, Dict, Optional |
|
|
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from transformers.utils import is_flash_attn_2_available |
|
|
| from .device import DeviceManager |
| from .load_utils import resolve_and_load |
| from .quantization_config import get_quantization_config |
|
|
|
|
| def get_device_load_strategy(device: torch.device) -> Dict[str, Any]: |
| """ |
| 根据设备推断加载策略(device_map、dtype、use_int8 等)。 |
| |
| 打印设备模式说明,与 QwenLM 风格一致。 |
| 环境变量:FORCE_INT8=1 / CPU_FORCE_BFLOAT16=1 |
| 返回供 load_causal_lm 使用的参数字典。 |
| """ |
| qconfig = get_quantization_config(device) |
| use_int8 = qconfig.use_int8 |
| device_map = None |
| dtype = qconfig.dtype |
| use_low_cpu_mem = False |
|
|
| if device.type == "cpu": |
| print("🔧 CPU 模式:手动控制设备分配") |
| if use_int8: |
| device_map = "cpu" |
| print("⚠️ 启用 INT8 量化(FORCE_INT8=1,实验性,在某些情况下会降低性能)") |
| elif dtype == torch.bfloat16: |
| use_low_cpu_mem = True |
| print("⚠️ 启用 bfloat16(CPU_FORCE_BFLOAT16=1,需硬件支持 AVX-512_BF16 或 AMX,否则可能极慢)") |
| else: |
| use_low_cpu_mem = True |
| print("🔧 dtype: float32") |
| elif device.type == "cuda": |
| print("🔧 CUDA 模式:自动设备分配") |
| device_map = "auto" |
| use_low_cpu_mem = True |
| if use_int8: |
| print("⚠️ 启用 INT8 量化(FORCE_INT8=1)") |
| else: |
| print("🔧 dtype: float16") |
| print("🔧 device_map: auto") |
| else: |
| |
| print(f"🔧 {device.type.upper()} 模式:自动设备分配") |
| if os.environ.get("FORCE_INT8") == "1": |
| print("⚠️ MPS 不支持 INT8 量化,已忽略 FORCE_INT8=1 环境变量") |
| device_map = "auto" |
| use_low_cpu_mem = True |
| print("🔧 dtype: float16") |
| print("🔧 device_map: auto") |
|
|
| return { |
| "device_map": device_map, |
| "dtype": dtype, |
| "use_low_cpu_mem": use_low_cpu_mem, |
| "use_int8": use_int8, |
| } |
|
|
|
|
| def attn_implementation_for_device(device: torch.device) -> str: |
| """ |
| 非 CUDA:eager,兼容性最好(CPU / MPS 等)。 |
| CUDA:已安装 flash-attn 时用 flash_attention_2;否则 eager(不使用 sdpa)。 |
| """ |
| if device.type != "cuda": |
| return "eager" |
| if is_flash_attn_2_available(): |
| return "flash_attention_2" |
| return "eager" |
|
|
|
|
| def load_causal_lm( |
| model_path: str, |
| device: torch.device, |
| *, |
| attn_implementation: Optional[str] = None, |
| extra_model_kwargs: Optional[Dict[str, Any]] = None, |
| ) -> torch.nn.Module: |
| """ |
| 加载 Causal LM 模型,统一处理设备策略、量化、加载后处理。 |
| |
| Args: |
| model_path: HuggingFace 模型路径或本地路径 |
| device: 目标设备 |
| attn_implementation: 可选;未传时可在外层用 attn_implementation_for_device(device) |
| extra_model_kwargs: 可选,额外传给 from_pretrained 的参数 |
| |
| Returns: |
| 已 eval() 的模型 |
| """ |
| strategy = get_device_load_strategy(device) |
| device_map = strategy["device_map"] |
| dtype = strategy["dtype"] |
| use_low_cpu_mem = strategy["use_low_cpu_mem"] |
| use_int8 = strategy["use_int8"] |
|
|
| load_kw: Dict[str, Any] = { |
| "trust_remote_code": True, |
| "low_cpu_mem_usage": use_low_cpu_mem or use_int8, |
| } |
| if attn_implementation is not None: |
| load_kw["attn_implementation"] = attn_implementation |
| if extra_model_kwargs: |
| load_kw.update(extra_model_kwargs) |
|
|
| def _load(path: str, lf: bool): |
| kw = dict(local_files_only=lf, **load_kw) |
| if use_int8: |
| from transformers import BitsAndBytesConfig |
| return AutoModelForCausalLM.from_pretrained( |
| path, |
| quantization_config=BitsAndBytesConfig(load_in_8bit=True), |
| device_map=device_map, |
| **kw, |
| ) |
| if device_map: |
| return AutoModelForCausalLM.from_pretrained( |
| path, |
| device_map=device_map, |
| dtype=dtype, |
| **kw, |
| ) |
| return AutoModelForCausalLM.from_pretrained( |
| path, dtype=dtype, **kw |
| ).to(device) |
|
|
| t0 = time.perf_counter() |
| model = resolve_and_load(model_path, _load) |
| load_time = time.perf_counter() - t0 |
|
|
| DeviceManager.print_model_load_stats(model, load_time) |
| model.eval() |
| if device.type == "cuda": |
| device_idx = device.index if device.index is not None else 0 |
| DeviceManager.print_cuda_memory_summary(device=device_idx) |
| return model |
|
|
|
|
| def load_tokenizer(model_path: str): |
| """加载 tokenizer。本地优先时先解析为缓存路径,避免 tokenizer 内部 model_info 联网。""" |
|
|
| def _load(path: str, lf: bool): |
| return AutoTokenizer.from_pretrained( |
| path, trust_remote_code=True, local_files_only=lf |
| ) |
|
|
| return resolve_and_load(model_path, _load) |
|
|