""" pred_topk 列表的格式化:与 language_checker 中 batch_decode + round_to_sig_figs 语义一致,供信息密度与续写共用。 """ from typing import List, Tuple import torch from backend.api.utils import round_to_sig_figs def pred_topk_pairs_from_flat_ids_and_probs( ids_flat: List[int], probs_flat: List[float], tokenizer, ) -> List[Tuple[str, float]]: """ 对 torch.topk 展平后的 id / 概率序列解码为 [(token 文本, 概率), ...]。 与 QwenLM._decode_topk_tokens 内层逻辑一致(单次 batch_decode)。 """ if len(ids_flat) != len(probs_flat): raise ValueError("ids_flat 与 probs_flat 长度须一致") if not ids_flat: return [] decoded = tokenizer.batch_decode([[tid] for tid in ids_flat], skip_special_tokens=False) return [ (decoded[j], round_to_sig_figs(float(probs_flat[j]))) for j in range(len(ids_flat)) ] def pred_topk_pairs_from_probs_1d( probs: torch.Tensor, tokenizer, top_k: int, ) -> List[Tuple[str, float]]: """单步 1D softmax 概率向量上的 top-k,用于续写 generate 的每步 scores。""" top_k = min(int(top_k), int(probs.numel())) if top_k <= 0: return [] topk_probs, topk_ids = torch.topk(probs, top_k, dim=-1) ids_flat = topk_ids.cpu().flatten().tolist() probs_flat = topk_probs.detach().cpu().float().numpy().flatten().tolist() return pred_topk_pairs_from_flat_ids_and_probs(ids_flat, probs_flat, tokenizer)