# InfoLens/backend/language_checker.py
import torch
import gc
from typing import Callable, Dict, List, Optional, Tuple
from .api.utils import round_to_sig_figs
from .pred_topk_format import pred_topk_pairs_from_flat_ids_and_probs
from .class_register import register_model, REGISTERED_MODELS
from .device import DeviceManager
from .model_manager import ensure_model_loaded
from .runtime_config import load_runtime_config, DEFAULT_TOPK
from model_paths import DEFAULT_MODEL, MODEL_PATHS, SEMANTIC_MODEL_PATHS, resolve_hf_path
# Cache, keyed by id(model), of the last-position vocabulary logits from a single forward pass over [BOS] (or an equivalent start token): full vocabulary, independent of the text being analysed
_bos_first_position_logits_cache: Dict[int, torch.Tensor] = {}
def compute_first_token_lm_with_bos_prefix_cache(
model: torch.nn.Module,
tokenizer,
device: torch.device,
first_token_id: int,
effective_topk: int,
) -> Tuple[float, List[Tuple[str, float]]]:
"""
Workaround for the first token, which has no left context: as in the legacy BOS-prefix behaviour, run a single forward pass over the one-token input [bos],
cache the last-position logits (the distribution that predicts the first text token) on the CPU, then do softmax/top-k on the CPU.
The same model instance reuses the same vocabulary logits, so the forward pass is not repeated on every analysis.
"""
mid = id(model)
if mid not in _bos_first_position_logits_cache:
if tokenizer.bos_token_id is not None:
bos_id = int(tokenizer.bos_token_id)
elif tokenizer.eos_token_id is not None:
bos_id = int(tokenizer.eos_token_id)
else:
bos_id = 0
with torch.inference_mode():
bos_in = torch.tensor([[bos_id]], device=device, dtype=torch.long)
out = model(input_ids=bos_in)
# [V]: the distribution over the first text token, conditioned on BOS
row = out.logits[0, -1, :].detach().float()
_bos_first_position_logits_cache[mid] = row.cpu()
logits = _bos_first_position_logits_cache[mid]
probs = torch.softmax(logits, dim=-1)
p = float(probs[first_token_id].item())
topk_vals, topk_inds = torch.topk(probs, k=min(effective_topk, probs.shape[0]), dim=-1)
topk_vals = topk_vals.float().numpy()
topk_inds_flat = topk_inds.flatten().tolist()
topk_tokens_decoded = tokenizer.batch_decode(
[[tid] for tid in topk_inds_flat],
skip_special_tokens=False,
)
pred_topk = [
(topk_tokens_decoded[j], round_to_sig_figs(float(topk_vals[j])))
for j in range(len(topk_tokens_decoded))
]
return p, pred_topk
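# Usage sketch (hedged): assuming `tok`, `mdl` and `dev` come from ensure_model_loaded(),
# this is how the helper above is driven (it mirrors the call in QwenLM.analyze_text):
#   ids = tok("Hello world", return_tensors="pt")["input_ids"]
#   p_first, topk_first = compute_first_token_lm_with_bos_prefix_cache(
#       mdl, tok, dev, int(ids[0, 0].item()), DEFAULT_TOPK
#   )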
class AbstractLanguageChecker:
"""
Abstract Class that defines the Backend API of GLTR.
To extend the GLTR interface, you need to inherit this and
fill in the defined functions.
"""
def __init__(self):
"""
In the subclass, you need to load all necessary components
for the other functions.
Typically, this will comprise a tokenizer and a model.
"""
self.device = DeviceManager.get_device()
def analyze_text(self, in_text):
"""
Function that GLTR interacts with to analyze text and get token probabilities
Params:
- in_text: str -- The text that you want to analyze
(the pred_topk count is not passed as a parameter; it is taken from runtime_config.DEFAULT_TOPK)
Output:
- payload: dict -- The wrapper for results in this function, described below
Payload values
==============
bpe_strings: list of dict -- Each dict contains {"offset": [start, end], "raw": str,
"real_topk": [rank, prob], "pred_topk": [(token, prob), ...]}
- offset: character offsets in the original text [start, end]
- raw: token text extracted from original text
- real_topk: (rank, prob) of each token (the rank defaults to 0)
- pred_topk: list of top-k candidates (empty list if unavailable)
"""
raise NotImplementedError
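# Illustrative payload shape returned by concrete implementations (values are made up):
#   {"bpe_strings": [
#       {"offset": [0, 5], "raw": "Hello",
#        "real_topk": [0, 0.0123],
#        "pred_topk": [(" the", 0.0567), (",", 0.0321)]},
#       ...
#   ]}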
@register_model(name='qwen2.5-0.5b')
class QwenLM(AbstractLanguageChecker):
"""
Support for the Qwen model family.
Defaults to the Qwen2.5-0.5B base model (well suited to computing surprisal / information content).
"""
def __init__(self, model_path=None, model_name=None):
super(QwenLM, self).__init__()
model_name = model_name or getattr(self.__class__, '_registered_model_name', DEFAULT_MODEL)
if model_path is not None and str(model_path).strip():
resolved = str(model_path).strip()
else:
resolved = resolve_hf_path(model_name)
# Load the runtime configuration (partial overrides supported)
self._load_runtime_config(model_name)
self.tokenizer, self.model, self.device = ensure_model_loaded(resolved)
# ============================================================
# Conclusions from the torch.compile() performance discussion:
#
# CPU:
# - cost outweighs the benefit; not recommended
#
# CUDA (if this ever moves to a GPU Space):
# - speed-up: 30-70% (significant)
# - compile time: short relative to inference time
# - Triton kernel optimisation: markedly fewer memory reads/writes
# - conclusion: strongly recommended, paired with warm-up runs to cover the input shapes
# To enable, add here:
# if torch.cuda.is_available() and hasattr(torch, 'compile'):
#     self.model = torch.compile(self.model, mode="default")
#     # and run warm-up inference at startup covering chunk_size lengths
# ============================================================
# Initialise the analysis counter (used to throttle GPU memory-stat printing)
self._analysis_count = 0
def _load_runtime_config(self, model_name: Optional[str]):
"""
Load the runtime configuration: a four-layer merge based on model and platform.
Args:
model_name: model identifier (e.g. 'qwen3-1.7b')
"""
# Run the config module's full loading pipeline
# Returns: (platform, max_token_length, chunk_size)
self.platform, self.max_length, self.chunk_size = load_runtime_config(
model_name=model_name or "default_model"
)
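# Illustrative outcome only (actual values come from runtime_config): on an Apple-silicon
# dev box this might resolve to platform='mps', max_length=1024, chunk_size=256.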
def _encode_text(self, in_text: str) -> Tuple[torch.Tensor, List[Tuple[int, int]]]:
"""编码文本并返回 token_ids 和 offsets"""
# 使用 tokenizer 的原生截断功能
enc_out = self.tokenizer(
in_text,
return_tensors='pt',
return_offsets_mapping=True,
max_length=self.max_length,
truncation=True
)
token_ids = enc_out['input_ids']
token_offsets = enc_out['offset_mapping'][0].tolist()
# Compare the last offset with the text length to detect truncation
if token_offsets:
last_offset_end = token_offsets[-1][1]
if last_offset_end < len(in_text):
# The text was truncated: warn with both the token limit and the character cut-off
print(f"⚠️ Text too long; truncated to the first {self.max_length} tokens ({len(in_text)} chars -> {last_offset_end} chars)")
token_ids = token_ids.to(self.device)
return token_ids, token_offsets
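# Offset-mapping sketch (illustrative): each entry maps a token back to character positions
# in `in_text`, e.g. a fast tokenizer might return token_offsets == [(0, 5), (5, 11)] for
# "Hello world", so in_text[start:end] recovers the raw text of each token (see _build_bpe_strings).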
def _run_inference_and_process_chunked(
self,
token_ids: torch.Tensor,
effective_topk: int,
progress_callback: Optional[Callable[[int, int, str, Optional[int]], None]] = None
) -> Tuple[List[List[Tuple[str, float]]], List[float]]:
"""
Chunked inference with on-the-fly processing: the core memory optimisation.
Logits are computed segment by segment using the KV cache and released immediately, so the full logits tensor is never kept in memory.
Numerical note: in float16 (e.g. on MPS), an element-wise comparison of the same-position logits from a "prefix-only forward" vs a "full-sequence forward" can show tiny differences;
float16 (MPS/CUDA) may differ by roughly 1% depending on the implementation path. This is not a masking bug. On CPU float32 the results match exactly.
"""
seq_len = token_ids.shape[1]
# Use the chunk_size determined per platform at init time
chunk_size = self.chunk_size
real_probs_list = []
pred_topk_list = []
past_key_values = None
# Clear caches up front
DeviceManager.clear_cache(self.device)
full_input_ids = token_ids
# Causal LM: logits[i] predicts input_ids[i+1]; the first token has no left context and is not scored in this loop
# We use past_key_values for incremental inference
# First pass: feed input_ids[:, :chunk_size]; the output logits cover positions 0..chunk_size-1 (predicting 1..chunk_size)
total_chunks = (seq_len + chunk_size - 1) // chunk_size
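# Chunking sketch (illustrative numbers): with seq_len = 1000 and chunk_size = 256,
# total_chunks = ceil(1000 / 256) = 4 and the chunk boundaries are
#   [0:256], [256:512], [512:768], [768:1000];
# the targets of each chunk are full_input_ids[:, 1+start_idx : 1+end_idx], so the first
# token is never scored here (it is handled separately by the BOS-prefix cache above).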
with torch.inference_mode():
for i in range(total_chunks):
start_idx = i * chunk_size
end_idx = min((i + 1) * chunk_size, seq_len)
current_chunk_len = end_idx - start_idx
# Prepare the input (uniform logic; avoids duplicating the boundary token)
if i == 0:
input_chunk = full_input_ids[:, :end_idx]
else:
input_chunk = full_input_ids[:, start_idx:end_idx]
# 1. Run inference
outputs = self.model(
input_ids=input_chunk,
past_key_values=past_key_values,
use_cache=True
)
past_key_values = outputs.past_key_values
logits = outputs.logits
# Fetch the targets
# full_input_ids[:, 1:] holds every target
# Targets for the current chunk: [start_idx : end_idx]
chunk_targets = full_input_ids[:, 1+start_idx : 1+end_idx]
valid_len = chunk_targets.shape[1]
if valid_len == 0:
continue
# When the last chunk reaches the end of the sequence, the final logit position predicts the "next token" and must be trimmed off
current_logits = logits[:, :valid_len, :]
# 2. Softmax and top-k for the current chunk
probs_chunk = torch.softmax(current_logits, dim=2)
# Extract the probabilities of the actual next tokens
chunk_target_probs = torch.gather(probs_chunk, 2, chunk_targets.unsqueeze(-1))
real_probs_list.extend(chunk_target_probs.flatten().detach().cpu().float().numpy().tolist())
# Extract the top-k
# chunk_size is guaranteed to stay below MPS_TOPK_BUG_THRESHOLD, so compute directly
topk_vals, topk_inds = torch.topk(probs_chunk, k=effective_topk, dim=2)
chunk_pred_topk = self._decode_topk_tokens(
topk_vals, topk_inds, effective_topk, valid_len
)
pred_topk_list.extend(chunk_pred_topk)
# 3. Free memory immediately
del logits
del current_logits
del probs_chunk
del chunk_target_probs
# outputs is overwritten on the next iteration; no manual handling needed
# Progress update (based on the number of tokens actually processed)
if progress_callback:
pct = int(end_idx / seq_len * 100)  # the inference stage has its own 0-100% scale
progress_callback(2, 3, 'inference', pct)
# Loop finished; clear the KV cache
del past_key_values
DeviceManager.clear_cache(self.device)
return pred_topk_list, real_probs_list
def _decode_topk_tokens(
self,
topk_prob_values: torch.Tensor,
topk_prob_inds: torch.Tensor,
effective_topk: int,
seq_len: int
) -> List[List[Tuple[str, float]]]:
"""解码 TopK tokens 并构建预测列表(长度等于参与 topk 的序列长度)"""
topk_prob_values_cpu = topk_prob_values[0].detach().cpu().float().numpy()
topk_prob_inds_flat = topk_prob_inds[0].cpu().flatten().tolist()
probs_flat = topk_prob_values_cpu.flatten().tolist()
flat_pairs = pred_topk_pairs_from_flat_ids_and_probs(
topk_prob_inds_flat, probs_flat, self.tokenizer
)
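# Regroup the flat list of seq_len * effective_topk (token, prob) pairs into one list of
# effective_topk pairs per position, e.g. 3 positions x top-5 -> 15 pairs -> 3 lists of 5.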
return [
flat_pairs[i * effective_topk : (i + 1) * effective_topk]
for i in range(seq_len)
]
def _build_bpe_strings(
self,
token_offsets: List[Tuple[int, int]],
real_topk: List[Tuple[int, float]],
pred_topk: List[List[Tuple[str, float]]],
in_text: str
) -> List[Dict]:
"""构建最终的 BPE 字符串列表"""
# 确保长度一致
min_len = min(len(token_offsets), len(real_topk), len(pred_topk) if pred_topk else len(token_offsets))
bpe_strings = []
for idx in range(min_len):
start, end = token_offsets[idx]
raw_text = in_text[start:end] if start < end else ""
token_payload = {
"offset": [start, end],
"raw": raw_text,
"real_topk": list(real_topk[idx]),
"pred_topk": pred_topk[idx] if pred_topk else []
}
bpe_strings.append(token_payload)
return bpe_strings
def analyze_text(self, in_text: str, progress_callback: Optional[Callable[[int, int, str, Optional[int]], None]] = None) -> Dict[str, List[Dict]]:
"""
Compute the probability of every token in the text.
Progress-callback arguments: (step: int, total_steps: int, stage: str, percentage: Optional[int])
- step: current step (1-based)
- total_steps: total number of steps (fixed at 3)
- stage: stage name (encoding/inference/processing)
- percentage: optional percentage, provided only during the inference stage
"""
TOTAL_STEPS = 3
try:
# Step 1: encode the text
if progress_callback:
progress_callback(1, TOTAL_STEPS, 'encoding', None)
token_ids, token_offsets = self._encode_text(in_text)
# Step 2: chunked inference and processing (with percentage progress)
# This replaces the old _run_model_inference, the MPS streaming path, and _process_topk
if progress_callback:
progress_callback(2, TOTAL_STEPS, 'inference', 0)
pred_topk, real_topk_probs = self._run_inference_and_process_chunked(
token_ids, DEFAULT_TOPK, progress_callback
)
# Step 3: build the result
if progress_callback:
progress_callback(3, TOTAL_STEPS, 'processing', None)
if token_ids.shape[1] >= 1:
p0, pred0 = compute_first_token_lm_with_bos_prefix_cache(
self.model,
self.tokenizer,
self.device,
int(token_ids[0, 0].item()),
DEFAULT_TOPK,
)
pred_topk.insert(0, pred0)
real_topk_probs.insert(0, p0)
seq_len = len(real_topk_probs)
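# The rank is fixed at 0 here; only the probability is meaningful (see the abstract-class docstring).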
real_topk = list(zip([0] * seq_len, [round_to_sig_figs(p) for p in real_topk_probs]))
bpe_strings = self._build_bpe_strings(token_offsets, real_topk, pred_topk, in_text)
# Final cleanup
DeviceManager.clear_cache(self.device)
gc.collect()
# Update the analysis counter
self._analysis_count += 1
# Print memory statistics after the analysis finishes (after the 1st, 11th, 21st, ... analysis)
if self.device.type == "cuda" and (self._analysis_count - 1) % 10 == 0:
device_idx = self.device.index if self.device.index is not None else 0
DeviceManager.print_cuda_memory_summary(device=device_idx)
return {'bpe_strings': bpe_strings}
except Exception as e:
import traceback
print(f"❌ Error in QwenLM.analyze_text: {e}")
traceback.print_exc()
return {'bpe_strings': []}
# The _cleanup_tensors method has been removed; explicitly freeing small tensors is no longer needed
# ============================================================
# Auto-registration: register every model in MODEL_PATHS and SEMANTIC_MODEL_PATHS
# ============================================================
# Adding a model path to model_paths.py is enough; the model is registered automatically,
# with no hand-written subclasses (DRY principle)
def _auto_register_models():
"""自动注册 MODEL_PATHS 与 SEMANTIC_MODEL_PATHS 中的所有模型"""
for model_name in (*MODEL_PATHS.keys(), *SEMANTIC_MODEL_PATHS.keys()):
if model_name not in REGISTERED_MODELS:
# Dynamically create and register a model class
# Capture the current model_name via a default argument so each class binds its own name
def make_init(name=model_name):
def __init__(self):
QwenLM.__init__(self, model_name=name)
return __init__
model_class = type(
f'QwenLM_{model_name.replace(".", "_").replace("-", "_")}',
(QwenLM,),
{
'__init__': make_init(),
'__doc__': f'Support for {model_name} (auto-registered)'
}
)
register_model(model_name)(model_class)
# Perform the auto-registration
_auto_register_models()
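# Usage sketch (hedged, assumes REGISTERED_MODELS behaves like a dict of name -> class):
#   checker_cls = REGISTERED_MODELS.get('qwen2.5-0.5b')
#   checker = checker_cls() if checker_cls else None
#   result = checker.analyze_text("Some text to score.") if checker else None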