import torch
import gc
from typing import Callable, Dict, List, Optional, Tuple
from .api.utils import round_to_sig_figs
from .pred_topk_format import pred_topk_pairs_from_flat_ids_and_probs
from .class_register import register_model, REGISTERED_MODELS
from .device import DeviceManager
from .model_manager import ensure_model_loaded
from .runtime_config import load_runtime_config, DEFAULT_TOPK
from model_paths import DEFAULT_MODEL, MODEL_PATHS, SEMANTIC_MODEL_PATHS, resolve_hf_path
# Cache, keyed by id(model), the last-position vocabulary logits from a single forward pass over
# [BOS] (or an equivalent start token); this distribution does not depend on the analyzed text.
_bos_first_position_logits_cache: Dict[int, torch.Tensor] = {}
def compute_first_token_lm_with_bos_prefix_cache(
model: torch.nn.Module,
tokenizer,
device: torch.device,
first_token_id: int,
effective_topk: int,
) -> Tuple[float, List[Tuple[str, float]]]:
"""
    Workaround for the first token, which has no left context: mirroring the old BOS-prefix
    behavior, run a single forward pass over the one-token input [bos], cache the last-position
    logits (the distribution over the first token of the text) on the CPU, then do softmax/top-k
    on the CPU. The same cached vocabulary logits are reused for a given model instance, so the
    forward pass is not repeated on every analysis.
"""
mid = id(model)
if mid not in _bos_first_position_logits_cache:
if tokenizer.bos_token_id is not None:
bos_id = int(tokenizer.bos_token_id)
elif tokenizer.eos_token_id is not None:
bos_id = int(tokenizer.eos_token_id)
else:
bos_id = 0
with torch.inference_mode():
bos_in = torch.tensor([[bos_id]], device=device, dtype=torch.long)
out = model(input_ids=bos_in)
            # [V]: distribution over the first text token, conditioned on BOS
row = out.logits[0, -1, :].detach().float()
_bos_first_position_logits_cache[mid] = row.cpu()
logits = _bos_first_position_logits_cache[mid]
probs = torch.softmax(logits, dim=-1)
p = float(probs[first_token_id].item())
topk_vals, topk_inds = torch.topk(probs, k=min(effective_topk, probs.shape[0]), dim=-1)
topk_vals = topk_vals.float().numpy()
topk_inds_flat = topk_inds.flatten().tolist()
topk_tokens_decoded = tokenizer.batch_decode(
[[tid] for tid in topk_inds_flat],
skip_special_tokens=False,
)
pred_topk = [
(topk_tokens_decoded[j], round_to_sig_figs(float(topk_vals[j])))
for j in range(len(topk_tokens_decoded))
]
return p, pred_topk
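# Illustrative sketch (not executed here) of how the helper above is meant to be called; `ids`
# stands for an already-encoded input_ids tensor and is a placeholder name:
#
#   p_first, pred0 = compute_first_token_lm_with_bos_prefix_cache(
#       model, tokenizer, device,
#       first_token_id=int(ids[0, 0].item()),
#       effective_topk=10,
#   )
#   # p_first -> probability of the first text token under P(token | BOS)
#   # pred0   -> [(decoded_token, prob), ...] with at most `effective_topk` entries
#   # A second call with the same `model` object reuses the cached CPU logits,
#   # so no additional forward pass is issued.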
class AbstractLanguageChecker:
"""
Abstract Class that defines the Backend API of GLTR.
To extend the GLTR interface, you need to inherit this and
fill in the defined functions.
"""
def __init__(self):
"""
In the subclass, you need to load all necessary components
for the other functions.
Typically, this will comprise a tokenizer and a model.
"""
self.device = DeviceManager.get_device()
def analyze_text(self, in_text):
"""
Function that GLTR interacts with to analyze text and get token probabilities
Params:
- in_text: str -- The text that you want to analyze
        Note: the pred_topk count is not a parameter here; it is taken from runtime_config.DEFAULT_TOPK
Output:
- payload: dict -- The wrapper for results in this function, described below
Payload values
==============
bpe_strings: list of dict -- Each dict contains {"offset": [start, end], "raw": str,
"real_topk": [rank, prob], "pred_topk": [(token, prob), ...]}
- offset: character offsets in the original text [start, end]
- raw: token text extracted from original text
        - real_topk: (rank, prob) for each token (rank defaults to 0)
        - pred_topk: list of top-k candidates (empty list if unavailable)
"""
raise NotImplementedError
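# For reference, a hypothetical payload produced by analyze_text might look like this
# (the values are illustrative only):
#
#   {"bpe_strings": [
#       {"offset": [0, 3], "raw": "The",
#        "real_topk": [0, 0.0123],
#        "pred_topk": [(" the", 0.21), (" a", 0.08)]},
#       ...
#   ]}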
@register_model(name='qwen2.5-0.5b')
class QwenLM(AbstractLanguageChecker):
"""
    Support for the Qwen model family.
    Defaults to the Qwen2.5-0.5B base model (well suited for computing surprisal / information content).
"""
def __init__(self, model_path=None, model_name=None):
super(QwenLM, self).__init__()
model_name = model_name or getattr(self.__class__, '_registered_model_name', DEFAULT_MODEL)
if model_path is not None and str(model_path).strip():
resolved = str(model_path).strip()
else:
resolved = resolve_hf_path(model_name)
        # Load the runtime configuration (partial overrides supported)
self._load_runtime_config(model_name)
self.tokenizer, self.model, self.device = ensure_model_loaded(resolved)
        # ============================================================
        # Conclusions from the torch.compile() performance discussion:
        #
        # CPU environment:
        #   - cost outweighs benefit; not recommended
        #
        # CUDA environment (if this ever moves to a GPU Space):
        #   - speedup: roughly 30-70% (significant)
        #   - compile time: short relative to inference time
        #   - Triton kernels: noticeably less memory traffic
        #   - conclusion: strongly recommended, combined with warmup runs
        #     so the expected input shapes are covered
        # To enable, add here:
        #   if torch.cuda.is_available() and hasattr(torch, 'compile'):
        #       self.model = torch.compile(self.model, mode="default")
        #       # and run warmup inference at startup covering chunk_size lengths
        # ============================================================
        # Initialize the analysis counter (used to throttle how often GPU memory stats are printed)
self._analysis_count = 0
def _load_runtime_config(self, model_name: Optional[str]):
"""
        Load the runtime configuration: a four-layer merge keyed on model and platform.
        Args:
            model_name: model identifier (e.g. 'qwen3-1.7b')
"""
        # Delegate to the config module's full loading pipeline
        # Returns: (platform, max_token_length, chunk_size)
self.platform, self.max_length, self.chunk_size = load_runtime_config(
model_name=model_name or "default_model"
)
def _encode_text(self, in_text: str) -> Tuple[torch.Tensor, List[Tuple[int, int]]]:
"""编码文本并返回 token_ids 和 offsets"""
# 使用 tokenizer 的原生截断功能
enc_out = self.tokenizer(
in_text,
return_tensors='pt',
return_offsets_mapping=True,
max_length=self.max_length,
truncation=True
)
token_ids = enc_out['input_ids']
token_offsets = enc_out['offset_mapping'][0].tolist()
        # Compare the last offset against the text length to detect truncation
if token_offsets:
last_offset_end = token_offsets[-1][1]
if last_offset_end < len(in_text):
                # Text was truncated; warn with both the token limit and the character cutoff
                print(f"⚠️ Text too long; truncated to the first {self.max_length} tokens ({len(in_text)} chars -> {last_offset_end} chars)")
token_ids = token_ids.to(self.device)
return token_ids, token_offsets
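    # Example of what _encode_text returns (illustrative; the exact ids and token boundaries
    # depend on the tokenizer): for in_text = "Hello world" one might get
    #   token_ids     -> a [1, 2] int64 tensor on self.device
    #   token_offsets -> [[0, 5], [5, 11]]  # character spans for "Hello" and " world"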
def _run_inference_and_process_chunked(
self,
token_ids: torch.Tensor,
effective_topk: int,
progress_callback: Optional[Callable[[int, int, str, Optional[int]], None]] = None
) -> Tuple[List[List[Tuple[str, float]]], List[float]]:
"""
        Chunked inference with immediate processing: the core memory optimization.
        Uses the KV cache to compute logits segment by segment, releasing each segment as soon as
        it has been processed, so the full logits tensor is never held in memory.
        Numerical note: in float16 (e.g. on MPS), an element-wise comparison of the same-position
        logits from a "prefix-only forward" vs a "full-sequence forward" can show small differences;
        float16 (MPS/CUDA) may differ by roughly 1% depending on the kernel path, which is not a
        masking bug. Under CPU float32 the results match exactly.
"""
seq_len = token_ids.shape[1]
        # Use the chunk_size determined for this platform at initialization
chunk_size = self.chunk_size
real_probs_list = []
pred_topk_list = []
past_key_values = None
        # Clear device caches up front
DeviceManager.clear_cache(self.device)
full_input_ids = token_ids
        # Causal LM: logits[i] predicts input_ids[i+1]; the first token has no left context and is
        # not scored inside this loop
        # Incremental inference via past_key_values
        # First chunk: feed input_ids[:, :chunk_size]; its logits cover positions 0..chunk_size-1
        # (predicting tokens 1..chunk_size)
total_chunks = (seq_len + chunk_size - 1) // chunk_size
with torch.inference_mode():
for i in range(total_chunks):
start_idx = i * chunk_size
end_idx = min((i + 1) * chunk_size, seq_len)
current_chunk_len = end_idx - start_idx
                # Prepare the input (uniform logic that avoids duplicating boundary tokens)
if i == 0:
input_chunk = full_input_ids[:, :end_idx]
else:
input_chunk = full_input_ids[:, start_idx:end_idx]
                # 1. Run inference
outputs = self.model(
input_ids=input_chunk,
past_key_values=past_key_values,
use_cache=True
)
past_key_values = outputs.past_key_values
logits = outputs.logits
                # Gather the targets
                # full_input_ids[:, 1:] holds all targets
                # This chunk's targets: full_input_ids[:, 1+start_idx : 1+end_idx]
chunk_targets = full_input_ids[:, 1+start_idx : 1+end_idx]
valid_len = chunk_targets.shape[1]
if valid_len == 0:
continue
                # When the last chunk reaches the end of the sequence, the final logit predicts the
                # token after the text and must be sliced off
current_logits = logits[:, :valid_len, :]
                # 2. Softmax and top-k for the current chunk
probs_chunk = torch.softmax(current_logits, dim=2)
                # Extract the probabilities of the actual next tokens
chunk_target_probs = torch.gather(probs_chunk, 2, chunk_targets.unsqueeze(-1))
real_probs_list.extend(chunk_target_probs.flatten().detach().cpu().float().numpy().tolist())
                # Extract the top-k
                # chunk_size is guaranteed to stay below MPS_TOPK_BUG_THRESHOLD, so compute directly
topk_vals, topk_inds = torch.topk(probs_chunk, k=effective_topk, dim=2)
chunk_pred_topk = self._decode_topk_tokens(
topk_vals, topk_inds, effective_topk, valid_len
)
pred_topk_list.extend(chunk_pred_topk)
                # 3. Release memory immediately
del logits
del current_logits
del probs_chunk
del chunk_target_probs
                # outputs is overwritten on the next iteration; no manual cleanup needed
                # Progress update (based on the number of tokens actually processed)
if progress_callback:
                    pct = int(end_idx / seq_len * 100)  # the inference stage has its own 0-100% range
progress_callback(2, 3, 'inference', pct)
        # Loop finished; free the KV cache
del past_key_values
DeviceManager.clear_cache(self.device)
return pred_topk_list, real_probs_list
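    # Worked example of the chunking arithmetic above (hypothetical sizes): with seq_len = 10 and
    # chunk_size = 4 the loop runs 3 times:
    #   i=0: input_chunk = ids[:, 0:4],  targets = ids[:, 1:5]   (4 scored positions)
    #   i=1: input_chunk = ids[:, 4:8],  targets = ids[:, 5:9]   (4 scored positions)
    #   i=2: input_chunk = ids[:, 8:10], targets = ids[:, 9:10]  (1 scored position; the final
    #        logit, which would predict a token past the end of the text, is sliced off)
    # giving 9 scored tokens; the first token is handled separately via the BOS-prefix cache.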
def _decode_topk_tokens(
self,
topk_prob_values: torch.Tensor,
topk_prob_inds: torch.Tensor,
effective_topk: int,
seq_len: int
) -> List[List[Tuple[str, float]]]:
"""解码 TopK tokens 并构建预测列表(长度等于参与 topk 的序列长度)"""
topk_prob_values_cpu = topk_prob_values[0].detach().cpu().float().numpy()
topk_prob_inds_flat = topk_prob_inds[0].cpu().flatten().tolist()
probs_flat = topk_prob_values_cpu.flatten().tolist()
flat_pairs = pred_topk_pairs_from_flat_ids_and_probs(
topk_prob_inds_flat, probs_flat, self.tokenizer
)
return [
flat_pairs[i * effective_topk : (i + 1) * effective_topk]
for i in range(seq_len)
]
def _build_bpe_strings(
self,
token_offsets: List[Tuple[int, int]],
real_topk: List[Tuple[int, float]],
pred_topk: List[List[Tuple[str, float]]],
in_text: str
) -> List[Dict]:
"""构建最终的 BPE 字符串列表"""
# 确保长度一致
min_len = min(len(token_offsets), len(real_topk), len(pred_topk) if pred_topk else len(token_offsets))
bpe_strings = []
for idx in range(min_len):
start, end = token_offsets[idx]
raw_text = in_text[start:end] if start < end else ""
token_payload = {
"offset": [start, end],
"raw": raw_text,
"real_topk": list(real_topk[idx]),
"pred_topk": pred_topk[idx] if pred_topk else []
}
bpe_strings.append(token_payload)
return bpe_strings
def analyze_text(self, in_text: str, progress_callback: Optional[Callable[[int, int, str, Optional[int]], None]] = None) -> Dict[str, List[Dict]]:
"""
        Compute the probability of every token in the text.
        Progress-callback signature: (step: int, total_steps: int, stage: str, percentage: Optional[int])
        - step: current step (1-based)
        - total_steps: total number of steps (fixed at 3)
        - stage: stage name (encoding/inference/processing)
        - percentage: optional percentage, provided only during the inference stage
"""
TOTAL_STEPS = 3
try:
            # Step 1: encode the text
if progress_callback:
progress_callback(1, TOTAL_STEPS, 'encoding', None)
token_ids, token_offsets = self._encode_text(in_text)
            # Step 2: chunked inference and processing (with percentage progress)
            # Replaces the former _run_model_inference, MPS streaming, and _process_topk
if progress_callback:
                progress_callback(2, TOTAL_STEPS, 'inference', 0)
pred_topk, real_topk_probs = self._run_inference_and_process_chunked(
token_ids, DEFAULT_TOPK, progress_callback
)
            # Step 3: build the results
if progress_callback:
progress_callback(3, TOTAL_STEPS, 'processing', None)
if token_ids.shape[1] >= 1:
p0, pred0 = compute_first_token_lm_with_bos_prefix_cache(
self.model,
self.tokenizer,
self.device,
int(token_ids[0, 0].item()),
DEFAULT_TOPK,
)
pred_topk.insert(0, pred0)
real_topk_probs.insert(0, p0)
seq_len = len(real_topk_probs)
real_topk = list(zip([0] * seq_len, [round_to_sig_figs(p) for p in real_topk_probs]))
bpe_strings = self._build_bpe_strings(token_offsets, real_topk, pred_topk, in_text)
            # Final cleanup
DeviceManager.clear_cache(self.device)
gc.collect()
            # Update the analysis counter
self._analysis_count += 1
            # Print memory statistics after the analysis completes (on the 1st, 11th, 21st, ... run)
if self.device.type == "cuda" and (self._analysis_count - 1) % 10 == 0:
device_idx = self.device.index if self.device.index is not None else 0
DeviceManager.print_cuda_memory_summary(device=device_idx)
return {'bpe_strings': bpe_strings}
except Exception as e:
import traceback
print(f"❌ Error in QwenLM.analyze_text: {e}")
traceback.print_exc()
return {'bpe_strings': []}
# The _cleanup_tensors method has been removed; explicit cleanup of small tensors is no longer needed
# ============================================================
# Auto-registration: register every model listed in MODEL_PATHS and SEMANTIC_MODEL_PATHS
# ============================================================
# Adding a model path in model_paths.py is enough; the model is registered automatically
# without a hand-written subclass (DRY principle)
def _auto_register_models():
"""自动注册 MODEL_PATHS 与 SEMANTIC_MODEL_PATHS 中的所有模型"""
for model_name in (*MODEL_PATHS.keys(), *SEMANTIC_MODEL_PATHS.keys()):
if model_name not in REGISTERED_MODELS:
            # Dynamically create and register a model class.
            # No closure capture of model_name is needed: QwenLM.__init__ reads the
            # `_registered_model_name` attribute that register_model sets on the class.
def make_init():
def __init__(self):
QwenLM.__init__(self)
return __init__
model_class = type(
f'QwenLM_{model_name.replace(".", "_").replace("-", "_")}',
(QwenLM,),
{
'__init__': make_init(),
                '__doc__': f'Support for {model_name} (auto-registered)'
}
)
register_model(model_name)(model_class)
# Run auto-registration
_auto_register_models()
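# ------------------------------------------------------------------
# Minimal usage sketch (assumes REGISTERED_MODELS maps names to classes and that the
# 'qwen2.5-0.5b' weights are resolvable via model_paths; adjust the name as needed):
#
#   from .class_register import REGISTERED_MODELS
#   lm = REGISTERED_MODELS['qwen2.5-0.5b']()
#   payload = lm.analyze_text("The quick brown fox jumps over the lazy dog.")
#   for entry in payload['bpe_strings'][:5]:
#       print(entry['raw'], entry['real_topk'], entry['pred_topk'][:3])
# ------------------------------------------------------------------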