File size: 1,513 Bytes
494c9e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""
pred_topk 列表的格式化:与 language_checker 中 batch_decode + round_to_sig_figs 语义一致,供信息密度与续写共用。
"""

from typing import List, Tuple

import torch

from backend.api.utils import round_to_sig_figs


def pred_topk_pairs_from_flat_ids_and_probs(
    ids_flat: List[int],
    probs_flat: List[float],
    tokenizer,
) -> List[Tuple[str, float]]:
    """
    对 torch.topk 展平后的 id / 概率序列解码为 [(token 文本, 概率), ...]。
    与 QwenLM._decode_topk_tokens 内层逻辑一致(单次 batch_decode)。
    """
    if len(ids_flat) != len(probs_flat):
        raise ValueError("ids_flat 与 probs_flat 长度须一致")
    if not ids_flat:
        return []
    decoded = tokenizer.batch_decode([[tid] for tid in ids_flat], skip_special_tokens=False)
    return [
        (decoded[j], round_to_sig_figs(float(probs_flat[j])))
        for j in range(len(ids_flat))
    ]


def pred_topk_pairs_from_probs_1d(
    probs: torch.Tensor,
    tokenizer,
    top_k: int,
) -> List[Tuple[str, float]]:
    """单步 1D softmax 概率向量上的 top-k,用于续写 generate 的每步 scores。"""
    top_k = min(int(top_k), int(probs.numel()))
    if top_k <= 0:
        return []
    topk_probs, topk_ids = torch.topk(probs, top_k, dim=-1)
    ids_flat = topk_ids.cpu().flatten().tolist()
    probs_flat = topk_probs.detach().cpu().float().numpy().flatten().tolist()
    return pred_topk_pairs_from_flat_ids_and_probs(ids_flat, probs_flat, tokenizer)