| """ |
| 量化配置(语义分析、信息密度分析共用) |
| |
| 从环境变量读取并返回设备相关的量化策略: |
| - FORCE_INT8=1: INT8 量化(CPU/CUDA 支持,MPS 不支持) |
| - CPU_FORCE_BFLOAT16=1: CPU 使用 bfloat16 |
| """ |
|
|
| import os |
| from typing import NamedTuple |
|
|
| import torch |
|
|
|
|
# Quantization settings shared by the semantic-analysis and
# information-density models: whether to apply INT8 quantization,
# and which floating-point dtype to load weights in otherwise.
QuantizationConfig = NamedTuple(
    "QuantizationConfig",
    [("use_int8", bool), ("dtype", torch.dtype)],
)
QuantizationConfig.__doc__ = (
    "Quantization config shared by the semantic and information-density models."
)
|
|
|
|
def get_quantization_config(device: torch.device) -> QuantizationConfig:
    """Return the quantization strategy for *device*, driven by env vars.

    Environment variables:
        FORCE_INT8=1: request INT8 quantization (honored on CPU/CUDA only).
        CPU_FORCE_BFLOAT16=1: use bfloat16 instead of float32 on CPU.

    Returns:
        QuantizationConfig: the (use_int8, dtype) pair for this device.
    """
    int8_requested = os.environ.get("FORCE_INT8") == "1"
    bf16_requested = os.environ.get("CPU_FORCE_BFLOAT16") == "1"

    if device.type == "cpu":
        return QuantizationConfig(
            use_int8=int8_requested,
            dtype=torch.bfloat16 if bf16_requested else torch.float32,
        )

    if device.type == "cuda":
        return QuantizationConfig(use_int8=int8_requested, dtype=torch.float16)

    # Any other backend (e.g. MPS): INT8 is unsupported there per the module
    # docstring, so it is never enabled regardless of FORCE_INT8.
    return QuantizationConfig(use_int8=False, dtype=torch.float16)
|
|