# backend/model_loader.py
"""
Causal LM 模型加载:设备策略与加载逻辑统一封装
供 language_checker.QwenLM(信息密度分析)与 model_manager.ensure_model_loaded 共用,
消除重复的设备分支、量化配置、加载后处理等逻辑。
加载策略说明:
- INT8 量化:bitsandbytes 8bit,device_map="cpu"/"auto",减少约 4 倍内存
- CPU 手动模式:无 device_map,.to(device),默认 float32
- GPU/MPS 自动模式:device_map="auto",float16
dtype/设备与因果 LM 在「仅前缀 forward」vs「整段 forward」同位置 logits 的逐元素对比:
float32(CPU)常完全一致;float16(MPS/CUDA)可能因实现路径出现约 1e-2 量级差,非掩码错误。
复现与说明见 scripts/reproduce_logits_triple_path.py、scripts/prove_fp16_gemm_shape_sensitivity.py。
"""
import os
import time
from typing import Any, Dict, Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import is_flash_attn_2_available
from .device import DeviceManager
from .load_utils import resolve_and_load
from .quantization_config import get_quantization_config
def get_device_load_strategy(device: torch.device) -> Dict[str, Any]:
"""
根据设备推断加载策略(device_map、dtype、use_int8 等)。
打印设备模式说明,与 QwenLM 风格一致。
环境变量:FORCE_INT8=1 / CPU_FORCE_BFLOAT16=1
返回供 load_causal_lm 使用的参数字典。
"""
qconfig = get_quantization_config(device)
use_int8 = qconfig.use_int8
device_map = None
dtype = qconfig.dtype
use_low_cpu_mem = False
    if device.type == "cpu":
        print("🔧 CPU mode: manual device placement")
        if use_int8:
            device_map = "cpu"
            print("⚠️ INT8 quantization enabled (FORCE_INT8=1, experimental; may degrade performance in some cases)")
        elif dtype == torch.bfloat16:
            use_low_cpu_mem = True
            print("⚠️ bfloat16 enabled (CPU_FORCE_BFLOAT16=1; requires AVX-512_BF16 or AMX hardware support, otherwise it may be extremely slow)")
        else:
            use_low_cpu_mem = True
            print("🔧 dtype: float32")  # default: float32
    elif device.type == "cuda":
        print("🔧 CUDA mode: automatic device placement")
        device_map = "auto"
        use_low_cpu_mem = True
        if use_int8:
            print("⚠️ INT8 quantization enabled (FORCE_INT8=1)")
        else:
            print("🔧 dtype: float16")
        print("🔧 device_map: auto")
    else:
        # MPS mode: automatic device placement + float16 (MPS does not support INT8 quantization)
        print(f"🔧 {device.type.upper()} mode: automatic device placement")
        if os.environ.get("FORCE_INT8") == "1":
            print("⚠️ MPS does not support INT8 quantization; FORCE_INT8=1 is ignored")
        device_map = "auto"
        use_low_cpu_mem = True
        print("🔧 dtype: float16")
        print("🔧 device_map: auto")
return {
"device_map": device_map,
"dtype": dtype,
"use_low_cpu_mem": use_low_cpu_mem,
"use_int8": use_int8,
}
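# Example (assumption: a CUDA machine with neither FORCE_INT8 nor CPU_FORCE_BFLOAT16 set;
# the exact dict depends on what get_quantization_config returns):
#
#     strategy = get_device_load_strategy(torch.device("cuda"))
#     # -> {"device_map": "auto", "dtype": torch.float16, "use_low_cpu_mem": True, "use_int8": False}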
def attn_implementation_for_device(device: torch.device) -> str:
"""
非 CUDA:eager,兼容性最好(CPU / MPS 等)。
CUDA:已安装 flash-attn 时用 flash_attention_2;否则 eager(不使用 sdpa)。
"""
if device.type != "cuda":
return "eager"
if is_flash_attn_2_available():
return "flash_attention_2"
return "eager"
def load_causal_lm(
model_path: str,
device: torch.device,
*,
attn_implementation: Optional[str] = None,
extra_model_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.nn.Module:
"""
加载 Causal LM 模型,统一处理设备策略、量化、加载后处理。
Args:
model_path: HuggingFace 模型路径或本地路径
device: 目标设备
attn_implementation: 可选;未传时可在外层用 attn_implementation_for_device(device)
extra_model_kwargs: 可选,额外传给 from_pretrained 的参数
Returns:
已 eval() 的模型
"""
strategy = get_device_load_strategy(device)
device_map = strategy["device_map"]
dtype = strategy["dtype"]
use_low_cpu_mem = strategy["use_low_cpu_mem"]
use_int8 = strategy["use_int8"]
load_kw: Dict[str, Any] = {
"trust_remote_code": True,
"low_cpu_mem_usage": use_low_cpu_mem or use_int8,
}
if attn_implementation is not None:
load_kw["attn_implementation"] = attn_implementation
if extra_model_kwargs:
load_kw.update(extra_model_kwargs)
def _load(path: str, lf: bool):
kw = dict(local_files_only=lf, **load_kw)
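        # Path 1 (INT8): bitsandbytes 8-bit quantization; device placement is delegated to device_map.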
if use_int8:
from transformers import BitsAndBytesConfig
return AutoModelForCausalLM.from_pretrained(
path,
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
device_map=device_map,
**kw,
)
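        # Path 2 (accelerate-managed): device_map handles placement, so no manual .to() call.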
if device_map:
return AutoModelForCausalLM.from_pretrained(
path,
device_map=device_map,
dtype=dtype,
**kw,
)
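        # Path 3 (manual, CPU): load in the requested dtype, then move the whole model explicitly.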
return AutoModelForCausalLM.from_pretrained(
path, dtype=dtype, **kw
).to(device)
t0 = time.perf_counter()
model = resolve_and_load(model_path, _load)
load_time = time.perf_counter() - t0
DeviceManager.print_model_load_stats(model, load_time)
model.eval()
if device.type == "cuda":
device_idx = device.index if device.index is not None else 0
DeviceManager.print_cuda_memory_summary(device=device_idx)
return model
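# Usage sketch (the model id below is a placeholder, not necessarily the project's model;
# extra_model_kwargs forwards any from_pretrained kwarg, e.g. pinning a revision):
#
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = load_causal_lm(
#         "Qwen/Qwen2.5-0.5B-Instruct",
#         device,
#         attn_implementation=attn_implementation_for_device(device),
#         extra_model_kwargs={"revision": "main"},
#     )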
def load_tokenizer(model_path: str):
"""加载 tokenizer。本地优先时先解析为缓存路径,避免 tokenizer 内部 model_info 联网。"""
def _load(path: str, lf: bool):
return AutoTokenizer.from_pretrained(
path, trust_remote_code=True, local_files_only=lf
)
return resolve_and_load(model_path, _load)
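# Minimal smoke test when run as a script (assumptions: INFOLENS_SMOKE_MODEL is a
# hypothetical env var introduced here, and the default model id is only a placeholder).
if __name__ == "__main__":
    _model_path = os.environ.get("INFOLENS_SMOKE_MODEL", "Qwen/Qwen2.5-0.5B-Instruct")
    _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _tokenizer = load_tokenizer(_model_path)
    _model = load_causal_lm(
        _model_path,
        _device,
        attn_implementation=attn_implementation_for_device(_device),
    )
    _ids = _tokenizer("hello world", return_tensors="pt").input_ids.to(_model.device)
    with torch.no_grad():
        print("logits shape:", tuple(_model(_ids).logits.shape))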