InfoLens / backend /pred_topk_format.py
dqy08's picture
initial beta release
494c9e4
"""
pred_topk 列表的格式化:与 language_checker 中 batch_decode + round_to_sig_figs 语义一致,供信息密度与续写共用。
"""
from typing import List, Tuple
import torch
from backend.api.utils import round_to_sig_figs
def pred_topk_pairs_from_flat_ids_and_probs(
ids_flat: List[int],
probs_flat: List[float],
tokenizer,
) -> List[Tuple[str, float]]:
"""
对 torch.topk 展平后的 id / 概率序列解码为 [(token 文本, 概率), ...]。
与 QwenLM._decode_topk_tokens 内层逻辑一致(单次 batch_decode)。
"""
if len(ids_flat) != len(probs_flat):
raise ValueError("ids_flat 与 probs_flat 长度须一致")
if not ids_flat:
return []
decoded = tokenizer.batch_decode([[tid] for tid in ids_flat], skip_special_tokens=False)
return [
(decoded[j], round_to_sig_figs(float(probs_flat[j])))
for j in range(len(ids_flat))
]
def pred_topk_pairs_from_probs_1d(
probs: torch.Tensor,
tokenizer,
top_k: int,
) -> List[Tuple[str, float]]:
"""单步 1D softmax 概率向量上的 top-k,用于续写 generate 的每步 scores。"""
top_k = min(int(top_k), int(probs.numel()))
if top_k <= 0:
return []
topk_probs, topk_ids = torch.topk(probs, top_k, dim=-1)
ids_flat = topk_ids.cpu().flatten().tolist()
probs_flat = topk_probs.detach().cpu().float().numpy().flatten().tolist()
return pred_topk_pairs_from_flat_ids_and_probs(ids_flat, probs_flat, tokenizer)