File size: 4,994 Bytes
494c9e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import type { FrontendAnalyzeResult } from '../api/GLTR_API';
import { calculateSurprisal, calculateSurprisalDensity } from './Util';
import { extractRealTopkFromTokens } from './tokenUtils';

/** 首/末 bin 包含超出范围的值,中间 bin 为 [x0, x1) */
function valueInBinRange(value: number, x0: number, x1: number, binIndex: number, no_bins: number): boolean {
    const isFirstBin = binIndex === 0;
    const isLastBin = binIndex === no_bins - 1;
    if (isFirstBin) return value < x1;
    if (isLastBin) return value >= x0;
    return value >= x0 && value < x1;
}

export type HistogramType = 'token' | 'byte' | 'raw_score_normed';
export type HighlightData = FrontendAnalyzeResult & {
    rawScoresNormed?: number[];
    attentionRawScores?: number[];
    signalProbs?: number[];
    pPwValues?: number[];
    pwScores?: number[];
};

/**
 * 根据直方图 bin 的范围计算需要高亮的 token 索引集合(基于合并后 token 的 surprisal,与直方图数据一致)
 * @param x0 bin 起始值
 * @param x1 bin 结束值
 * @param binIndex bin在bins数组中的索引
 * @param no_bins 直方图的总bin数量
 * @param data 前端分析结果(包含 bpeBpeMergedTokens)
 * @returns 需要高亮的 merged token 索引集合
 */
export function calculateTokenSurprisalHighlights(
    x0: number,
    x1: number,
    binIndex: number,
    no_bins: number,
    data: HighlightData
): Set<number> {
    const highlightedIndices = new Set<number>();
    const bpeBpeMergedTokens = data.bpeBpeMergedTokens;
    if (!bpeBpeMergedTokens?.length) return highlightedIndices;

    const mergedRealTopk = extractRealTopkFromTokens(bpeBpeMergedTokens);
    for (let i = 0; i < bpeBpeMergedTokens.length; i++) {
        const surprisal = calculateSurprisal(mergedRealTopk[i][1]);
        if (!Number.isFinite(surprisal)) continue;
        if (!valueInBinRange(surprisal, x0, x1, binIndex, no_bins)) continue;
        highlightedIndices.add(i);
    }
    return highlightedIndices;
}

/**
 * 根据直方图 bin 的范围计算需要高亮的 token 索引集合(基于信息密度)
 * @param x0 bin 起始值
 * @param x1 bin 结束值
 * @param binIndex bin在bins数组中的索引
 * @param no_bins 直方图的总bin数量
 * @param data 前端分析结果(包含 bpeBpeMergedTokens)
 * @returns 需要高亮的 merged token 索引集合
 */
export function calculateByteSurprisalHighlights(
    x0: number,
    x1: number,
    binIndex: number,
    no_bins: number,
    data: HighlightData
): Set<number> {
    const highlightedIndices = new Set<number>();
    const bpeBpeMergedTokens = data.bpeBpeMergedTokens;
    if (!bpeBpeMergedTokens?.length) return highlightedIndices;

    for (let i = 0; i < bpeBpeMergedTokens.length; i++) {
        const informationDensity = calculateSurprisalDensity(bpeBpeMergedTokens[i]);
        if (!Number.isFinite(informationDensity)) continue;
        if (!valueInBinRange(informationDensity, x0, x1, binIndex, no_bins)) continue;
        highlightedIndices.add(i);
    }
    return highlightedIndices;
}

/**
 * 根据直方图 bin 的范围计算需要高亮的 token 索引集合(基于 raw_score_normed)
 * 使用 rawScoresNormed(与 bpeBpeMergedTokens 对齐),按 bin 范围筛选
 */
export function calculateRawScoreNormedHighlights(
    x0: number,
    x1: number,
    binIndex: number,
    no_bins: number,
    data: HighlightData
): Set<number> {
    const highlightedIndices = new Set<number>();
    const scores = data.rawScoresNormed;
    if (!scores?.length) return highlightedIndices;

    for (let i = 0; i < scores.length; i++) {
        const score = scores[i];
        if (!Number.isFinite(score)) continue;
        if (!valueInBinRange(score, x0, x1, binIndex, no_bins)) continue;
        highlightedIndices.add(i);
    }
    return highlightedIndices;
}

/**
 * 根据直方图类型和 bin 范围计算需要高亮的 token 索引集合
 * @param histogramType 直方图类型
 * @param x0 bin 起始值
 * @param x1 bin 结束值
 * @param binIndex bin在bins数组中的索引
 * @param no_bins 直方图的总bin数量
 * @param data 前端分析结果
 * @returns 需要高亮的 merged token 索引集合和对应的高亮样式
 */
export function calculateHighlights(
    histogramType: HistogramType,
    x0: number,
    x1: number,
    binIndex: number,
    no_bins: number,
    data: HighlightData
): { indices: Set<number>; style: 'border' | 'underline' } {
    if (histogramType === 'byte') {
        return {
            indices: calculateByteSurprisalHighlights(x0, x1, binIndex, no_bins, data),
            style: 'underline'
        };
    }
    if (histogramType === 'raw_score_normed') {
        return {
            indices: calculateRawScoreNormedHighlights(x0, x1, binIndex, no_bins, data),
            style: 'underline'
        };
    }
    return {
        indices: calculateTokenSurprisalHighlights(x0, x1, binIndex, no_bins, data),
        style: 'border'
    };
}