import torch
import gc
from typing import Callable, Dict, List, Optional, Tuple

from .api.utils import round_to_sig_figs
from .pred_topk_format import pred_topk_pairs_from_flat_ids_and_probs
from .class_register import register_model, REGISTERED_MODELS
from .device import DeviceManager
from .model_manager import ensure_model_loaded
from .runtime_config import load_runtime_config, DEFAULT_TOPK
from model_paths import DEFAULT_MODEL, MODEL_PATHS, SEMANTIC_MODEL_PATHS, resolve_hf_path

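# Module-level cache of the first-position logits obtained from a single-token
# [bos] forward pass, keyed by id(model) so each live model instance runs that
# forward at most once. Values are CPU tensors over the full vocabulary; see
# compute_first_token_lm_with_bos_prefix_cache below.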
_bos_first_position_logits_cache: Dict[int, torch.Tensor] = {}


def compute_first_token_lm_with_bos_prefix_cache(
    model: torch.nn.Module,
    tokenizer,
    device: torch.device,
    first_token_id: int,
    effective_topk: int,
) -> Tuple[float, List[Tuple[str, float]]]:
    """
    Workaround for the first token, which has no left context: matching the legacy
    BOS-prefix behaviour, run one forward pass on the single-token input [bos],
    cache the last-position logits (the distribution predicting the first token of
    the text) on the CPU, then compute softmax/top-k on the CPU.

    The same model instance reuses the cached vocabulary logits rather than
    repeating the forward pass on every analysis.
    """
    mid = id(model)
    if mid not in _bos_first_position_logits_cache:
        # Prefer the BOS token; fall back to EOS, then to token id 0.
        if tokenizer.bos_token_id is not None:
            bos_id = int(tokenizer.bos_token_id)
        elif tokenizer.eos_token_id is not None:
            bos_id = int(tokenizer.eos_token_id)
        else:
            bos_id = 0
        with torch.inference_mode():
            bos_in = torch.tensor([[bos_id]], device=device, dtype=torch.long)
            out = model(input_ids=bos_in)
            row = out.logits[0, -1, :].detach().float()
        _bos_first_position_logits_cache[mid] = row.cpu()

    logits = _bos_first_position_logits_cache[mid]
    probs = torch.softmax(logits, dim=-1)
    p = float(probs[first_token_id].item())

    topk_vals, topk_inds = torch.topk(probs, k=min(effective_topk, probs.shape[0]), dim=-1)
    topk_vals = topk_vals.float().numpy()
    topk_inds_flat = topk_inds.flatten().tolist()
    topk_tokens_decoded = tokenizer.batch_decode(
        [[tid] for tid in topk_inds_flat],
        skip_special_tokens=False,
    )
    pred_topk = [
        (topk_tokens_decoded[j], round_to_sig_figs(float(topk_vals[j])))
        for j in range(len(topk_tokens_decoded))
    ]
    return p, pred_topk

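# Minimal usage sketch for the function above (illustrative; assumes `model`,
# `tokenizer`, and `device` were obtained elsewhere, e.g. via ensure_model_loaded):
#
#   enc = tokenizer("Some text", return_tensors="pt")
#   first_id = int(enc["input_ids"][0, 0].item())
#   p, pred_topk = compute_first_token_lm_with_bos_prefix_cache(
#       model, tokenizer, device, first_id, effective_topk=DEFAULT_TOPK
#   )
#   # p == P(first_token | [bos]); pred_topk == [(token_str, prob), ...]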


class AbstractLanguageChecker:
    """
    Abstract class that defines the backend API of GLTR.

    To extend the GLTR interface, you need to inherit from this class and
    fill in the functions defined below.
    """

    def __init__(self):
        """
        In the subclass, you need to load all necessary components
        for the other functions.
        Typically, this will comprise a tokenizer and a model.
        """
        self.device = DeviceManager.get_device()

    def analyze_text(self, in_text):
        """
        Function that GLTR interacts with to analyze text and get token probabilities.

        Params:
        - in_text: str -- The text that you want to analyze. The number of
          predictions kept per token comes from runtime_config.DEFAULT_TOPK.

        Output:
        - payload: dict -- The wrapper for results in this function, described below

        Payload values
        ==============
        bpe_strings: list of dict -- Each dict contains {"offset": [start, end], "raw": str,
            "real_topk": [rank, prob], "pred_topk": [(token, prob), ...]}
        - offset: character offsets in the original text [start, end]
        - raw: token text extracted from the original text
        - real_topk: (rank, prob) for each token (the rank defaults to 0)
        - pred_topk: list of top-k candidates (an empty list if unavailable)
        """
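        # Illustrative payload entry matching the schema documented above
        # (values are made up, not real model output):
        #
        #   {
        #       "offset": [0, 5],
        #       "raw": "Hello",
        #       "real_topk": [0, 0.0123],
        #       "pred_topk": [("The", 0.21), ("I", 0.08)],
        #   }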
        raise NotImplementedError


@register_model(name='qwen2.5-0.5b')
class QwenLM(AbstractLanguageChecker):
    """
    Support for the Qwen model family.
    Defaults to the Qwen2.5-0.5B base model (well suited to computing
    surprisal and information content).
    """
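    # Usage sketch (illustrative; which model names are actually available
    # depends on the local MODEL_PATHS / registration setup):
    #
    #   checker = QwenLM()  # default registered model
    #   payload = checker.analyze_text("Some text to score")
    #   for tok in payload["bpe_strings"]:
    #       print(tok["raw"], tok["real_topk"])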

    def __init__(self, model_path=None, model_name=None):
        super(QwenLM, self).__init__()
        model_name = model_name or getattr(self.__class__, '_registered_model_name', DEFAULT_MODEL)
        if model_path is not None and str(model_path).strip():
            resolved = str(model_path).strip()
        else:
            resolved = resolve_hf_path(model_name)

        self._load_runtime_config(model_name)

        self.tokenizer, self.model, self.device = ensure_model_loaded(resolved)

        self._analysis_count = 0

    def _load_runtime_config(self, model_name: Optional[str]):
        """
        Load the runtime configuration: a four-layer configuration merge based on
        the model and the platform.

        Args:
            model_name: model identifier (e.g. 'qwen3-1.7b')
        """
        self.platform, self.max_length, self.chunk_size = load_runtime_config(
            model_name=model_name or "default_model"
        )

    def _encode_text(self, in_text: str) -> Tuple[torch.Tensor, List[Tuple[int, int]]]:
        """Encode the text and return token_ids plus character offsets."""
        enc_out = self.tokenizer(
            in_text,
            return_tensors='pt',
            return_offsets_mapping=True,
            max_length=self.max_length,
            truncation=True
        )
        token_ids = enc_out['input_ids']
        token_offsets = enc_out['offset_mapping'][0].tolist()

        # Warn when truncation dropped part of the input text.
        if token_offsets:
            last_offset_end = token_offsets[-1][1]
            if last_offset_end < len(in_text):
                print(f"⚠️ Text too long; truncated to the first {self.max_length} tokens "
                      f"({len(in_text)} chars -> {last_offset_end} chars)")

        token_ids = token_ids.to(self.device)

        return token_ids, token_offsets

    def _run_inference_and_process_chunked(
        self,
        token_ids: torch.Tensor,
        effective_topk: int,
        progress_callback: Optional[Callable[[int, int, str, Optional[int]], None]] = None
    ) -> Tuple[List[List[Tuple[str, float]]], List[float]]:
        """
        Chunked inference with immediate processing: the core memory optimization.
        Logits are computed segment by segment via the KV cache and released as
        soon as they have been consumed, so the full logits tensor is never kept
        in memory.

        Numerical note: in float16 (e.g. on MPS), an element-wise comparison of the
        same position's logits between a prefix-only forward and a full-sequence
        forward may show tiny differences; float16 (MPS/CUDA) can differ on the
        order of ~1% depending on the implementation path. This is not a masking
        bug. With float32 on CPU the results match exactly.
        """
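        # Chunk layout sketch (illustrative numbers): with seq_len = 10 and
        # chunk_size = 4, the chunks cover token positions [0:4), [4:8), [8:10).
        # The logits at position j are scored against target token j + 1, so the
        # last position of the final chunk has no target and is dropped via
        # valid_len below.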
        seq_len = token_ids.shape[1]
        chunk_size = self.chunk_size

        real_probs_list = []
        pred_topk_list = []
        past_key_values = None

        DeviceManager.clear_cache(self.device)

        full_input_ids = token_ids

        total_chunks = (seq_len + chunk_size - 1) // chunk_size

        with torch.inference_mode():
            for i in range(total_chunks):
                start_idx = i * chunk_size
                end_idx = min((i + 1) * chunk_size, seq_len)

                # For the first chunk there is no KV cache yet; later chunks only
                # feed the tokens that are new relative to the cache.
                if i == 0:
                    input_chunk = full_input_ids[:, :end_idx]
                else:
                    input_chunk = full_input_ids[:, start_idx:end_idx]

                outputs = self.model(
                    input_ids=input_chunk,
                    past_key_values=past_key_values,
                    use_cache=True
                )
                past_key_values = outputs.past_key_values

                logits = outputs.logits

                # Targets are the input tokens shifted left by one position.
                chunk_targets = full_input_ids[:, 1 + start_idx : 1 + end_idx]
                valid_len = chunk_targets.shape[1]
                if valid_len == 0:
                    continue

                current_logits = logits[:, :valid_len, :]

                probs_chunk = torch.softmax(current_logits, dim=2)

                # Probability assigned to each actual next token.
                chunk_target_probs = torch.gather(probs_chunk, 2, chunk_targets.unsqueeze(-1))
                real_probs_list.extend(chunk_target_probs.flatten().detach().cpu().float().numpy().tolist())

                # Top-k predictions for every position in this chunk.
                topk_vals, topk_inds = torch.topk(probs_chunk, k=effective_topk, dim=2)
                chunk_pred_topk = self._decode_topk_tokens(
                    topk_vals, topk_inds, effective_topk, valid_len
                )
                pred_topk_list.extend(chunk_pred_topk)

                # Release per-chunk tensors immediately to keep peak memory low.
                del logits
                del current_logits
                del probs_chunk
                del chunk_target_probs

                if progress_callback:
                    pct = int(end_idx / seq_len * 100)
                    progress_callback(2, 3, 'inference', pct)

        del past_key_values
        DeviceManager.clear_cache(self.device)

        return pred_topk_list, real_probs_list

    def _decode_topk_tokens(
        self,
        topk_prob_values: torch.Tensor,
        topk_prob_inds: torch.Tensor,
        effective_topk: int,
        seq_len: int
    ) -> List[List[Tuple[str, float]]]:
        """Decode the top-k tokens and build the per-position prediction lists
        (the returned list has one entry per position included in the top-k pass)."""
        topk_prob_values_cpu = topk_prob_values[0].detach().cpu().float().numpy()
        topk_prob_inds_flat = topk_prob_inds[0].cpu().flatten().tolist()
        probs_flat = topk_prob_values_cpu.flatten().tolist()
        flat_pairs = pred_topk_pairs_from_flat_ids_and_probs(
            topk_prob_inds_flat, probs_flat, self.tokenizer
        )
        return [
            flat_pairs[i * effective_topk : (i + 1) * effective_topk]
            for i in range(seq_len)
        ]

    def _build_bpe_strings(
        self,
        token_offsets: List[Tuple[int, int]],
        real_topk: List[Tuple[int, float]],
        pred_topk: List[List[Tuple[str, float]]],
        in_text: str
    ) -> List[Dict]:
        """Build the final list of per-token BPE payload dicts."""
        # Guard against length mismatches between offsets, ranks, and predictions.
        min_len = min(len(token_offsets), len(real_topk), len(pred_topk) if pred_topk else len(token_offsets))

        bpe_strings = []
        for idx in range(min_len):
            start, end = token_offsets[idx]
            raw_text = in_text[start:end] if start < end else ""
            token_payload = {
                "offset": [start, end],
                "raw": raw_text,
                "real_topk": list(real_topk[idx]),
                "pred_topk": pred_topk[idx] if pred_topk else []
            }
            bpe_strings.append(token_payload)

        return bpe_strings

    def analyze_text(self, in_text: str, progress_callback: Optional[Callable[[int, int, str, Optional[int]], None]] = None) -> Dict[str, List[Dict]]:
        """
        Compute the probability of every token in the text.

        Progress callback arguments: (step: int, total_steps: int, stage: str, percentage: Optional[int])
        - step: current step (1-based)
        - total_steps: total number of steps (fixed at 3)
        - stage: stage name (encoding/inference/processing)
        - percentage: optional percentage, only provided during the inference stage
        """
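        # Example callback (illustrative only), matching the signature documented above:
        #
        #   def on_progress(step, total, stage, pct):
        #       suffix = f" {pct}%" if pct is not None else ""
        #       print(f"[{step}/{total}] {stage}{suffix}")
        #
        #   payload = checker.analyze_text(text, progress_callback=on_progress)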
        TOTAL_STEPS = 3

        try:
            # Step 1: tokenize the input text.
            if progress_callback:
                progress_callback(1, TOTAL_STEPS, 'encoding', None)
            token_ids, token_offsets = self._encode_text(in_text)

            # Step 2: chunked inference over the sequence.
            if progress_callback:
                progress_callback(2, TOTAL_STEPS, 'inference', 0)
            pred_topk, real_topk_probs = self._run_inference_and_process_chunked(
                token_ids, DEFAULT_TOPK, progress_callback
            )

            # Step 3: assemble the payload.
            if progress_callback:
                progress_callback(3, TOTAL_STEPS, 'processing', None)

            # The first token has no left context; score it against the cached
            # [bos]-prefix distribution and prepend the result.
            if token_ids.shape[1] >= 1:
                p0, pred0 = compute_first_token_lm_with_bos_prefix_cache(
                    self.model,
                    self.tokenizer,
                    self.device,
                    int(token_ids[0, 0].item()),
                    DEFAULT_TOPK,
                )
                pred_topk.insert(0, pred0)
                real_topk_probs.insert(0, p0)

            seq_len = len(real_topk_probs)
            real_topk = list(zip([0] * seq_len, [round_to_sig_figs(p) for p in real_topk_probs]))

            bpe_strings = self._build_bpe_strings(token_offsets, real_topk, pred_topk, in_text)

            DeviceManager.clear_cache(self.device)
            gc.collect()

            self._analysis_count += 1

            # Print a CUDA memory summary on the 1st analysis and every 10th thereafter.
            if self.device.type == "cuda" and (self._analysis_count - 1) % 10 == 0:
                device_idx = self.device.index if self.device.index is not None else 0
                DeviceManager.print_cuda_memory_summary(device=device_idx)

            return {'bpe_strings': bpe_strings}

        except Exception as e:
            import traceback
            print(f"❌ Error in QwenLM.analyze_text: {e}")
            traceback.print_exc()
            return {'bpe_strings': []}


def _auto_register_models():
    """Automatically register every model listed in MODEL_PATHS and SEMANTIC_MODEL_PATHS."""
    for model_name in (*MODEL_PATHS.keys(), *SEMANTIC_MODEL_PATHS.keys()):
        if model_name not in REGISTERED_MODELS:

            # The generated __init__ passes no model name; QwenLM.__init__ resolves
            # it from the subclass's '_registered_model_name' attribute, falling
            # back to DEFAULT_MODEL.
            def make_init():
                def __init__(self):
                    QwenLM.__init__(self)
                return __init__

            model_class = type(
                f'QwenLM_{model_name.replace(".", "_").replace("-", "_")}',
                (QwenLM,),
                {
                    '__init__': make_init(),
                    '__doc__': f'Support for {model_name} (auto-registered).'
                }
            )
            register_model(model_name)(model_class)


_auto_register_models()
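# Registry lookup sketch (illustrative; assumes REGISTERED_MODELS maps a model
# name to its checker class, which is how register_model is used above):
#
#   checker_cls = REGISTERED_MODELS.get('qwen2.5-0.5b', QwenLM)
#   checker = checker_cls()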