| """ |
| 下一 token 的 top-k 解码:与语义分析 logits_gradient 一致,供 semantic / attribution 复用。 |
| """ |
| from typing import List, Tuple |
|
|
| import torch |
|
|
| from .api.utils import round_to_sig_figs |
|
|
| DEFAULT_NEXT_TOKEN_TOPK = 10 |
|
|
|
|
| def decode_topk_ids_to_strings_and_rounded_probs( |
| probs_1d: torch.Tensor, |
| tokenizer, |
| topk_ids_1d: torch.Tensor, |
| ) -> Tuple[List[str], List[float]]: |
| """ |
| probs_1d: 对单位置 logits 的 softmax,shape [vocab_size]。 |
| topk_ids_1d: torch.topk(logits, k) 返回的 indices,shape [k]。 |
| 返回与语义分析 debug_info 相同形态的 topk_tokens、topk_probs(概率已 round_to_sig_figs)。 |
| """ |
| ids_list = topk_ids_1d.tolist() |
| topk_tokens = [tokenizer.decode([int(tid)]) for tid in ids_list] |
| topk_probs = [round_to_sig_figs(probs_1d[int(tid)].item()) for tid in ids_list] |
| return topk_tokens, topk_probs |
|
|