File size: 10,848 Bytes
494c9e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
/**
 * 语义分析相关工具函数
 */

import type { BpeMergeReason } from '../api/GLTR_API';
import {
    type DigitMergePipelineOptions,
    digitMergeIndexGroupsByText,
    dropEmptyZeroWidthTokens,
    flattenMergePartsForDigitGroup,
    mergeSequentialOverlap,
    mergeSourcePartsForOverlapPair,
    sliceTextByCodePointOffsets,
} from './mergeTokenSpans';

/**
 * 合并/归一化管线中的原始强度:已写入 rawScore 时用其值,否则用 score。
 */
export function getAttentionRawScore<T extends { score: number }>(t: T): number {
    const ext = t as { rawScore?: number };
    return ext.rawScore !== undefined ? ext.rawScore : t.score;
}

/**
 * 将 score 归一化到 [0,1];写入 rawScore(归一化前的强度,供 tooltip attentionRawScores)。
 * 语义 / 归因路径应在 **overlap 与 digit 合并并对原始 score 求和之后** 再调用,使 max 与合并后强度一致。
 * 若调用方已将「原始梯度」放在 rawScore、且 score 置 0(如未匹配块),则以 rawScore 作为 tooltip 保留值,仅用 score 参与 max 归一。
 */
export function normalizeTokenScores<T extends { score: number }>(tokens: T[]): Array<T & { rawScore: number }> {
    const max = Math.max(0, ...tokens.map((t) => t.score).filter(Number.isFinite));
    return tokens.map((t) => {
        const rawScore = getAttentionRawScore(t);
        if (max <= 0) {
            return { ...t, rawScore };
        }
        return { ...t, rawScore, score: t.score / max };
    });
}

const encoder = new TextEncoder();

/**
 * 返回 text 的 UTF-8 字节数(返回值单位:字节)。buf 为 encodeInto 的写入目标,其长度即上界。
 * 若 text 的真实字节数超过 buf.length,则返回 buf.length(而非精确值),调用方应据此判断"超限"。
 * 用 read < text.length 检测是否还有字符未写入,避免多字节字符边界恰好填满 buf 时的误判。
 */
export function getUtf8ByteLength(text: string, buf: Uint8Array): number {
    const { read, written } = encoder.encodeInto(text, buf);
    return read < text.length ? buf.length : written;
}


/** 从 start 起找下一段落结束位置(段落边界:≥2个连续换行符)。返回值包含尾部所有连续换行符;若无段落边界,返回 text.length。 */
function nextParagraphEnd(text: string, start: number): number {
    const nl = text.indexOf("\n\n", start);
    if (nl === -1) return text.length;
    let end = nl + 2;
    while (end < text.length && text[end] === "\n") end++;
    return end;
}

/** 从 start 起找下一行结束位置。连续换行算作一行(防止切断 BPE 分词)。 */
function nextLineEnd(text: string, start: number): number {
    const nl = text.indexOf("\n", start);
    if (nl === -1) return text.length;
    let end = nl + 1;
    while (end < text.length && text[end] === "\n") end++;
    return end;
}

/** 返回从 start 起累计 UTF-8 字节不超过 byteLimit 的最大字符索引(不切断代理对)。start:字符索引;byteLimit:UTF-8 字节数;返回值:字符索引。 */
export function charIndexForByteLimit(text: string, start: number, byteLimit: number): number {
    const buf = new Uint8Array(4);
    let bytes = 0;
    let i = start;
    while (i < text.length) {
        const cp = text.codePointAt(i)!;
        const charLen = cp > 0xFFFF ? 2 : 1;
        const byteLen = encoder.encodeInto(text.slice(i, i + charLen), buf).written;
        if (bytes + byteLen > byteLimit) break;
        bytes += byteLen;
        i += charLen;
    }
    return i;
}

// 纯中文文章:永远不会触碰英文标点,完全规避 . 的歧义问题
// 纯英文文章:前两组无命中,自然降级到英文标点,行为正确
// 中英混排时:划分效果会变差
const SEPARATOR_GROUPS: string[][] = [
    // 第一优先级:中文句子级
    ["。", "!", "?", "…"],
    // 第二优先级:中文子句级
    [";", ","],
    // 第三优先级:英文句子级
    [".", "!", "?"],
    // 第四优先级:英文子句级
    [";", ","],
    // 第五优先级:空格
    [" ", "\t"],
];

/**
 * 在 [start, maxEnd) 范围内,按 groups 优先级找最靠右的分隔符边界。
 * start、maxEnd、返回值:均为字符索引。同组内取最靠右的;找不到则尝试下一组;均无则回退到 maxEnd。
 */
export function findSplitPoint(text: string, start: number, maxEnd: number): number {
    const window = text.slice(start, maxEnd);
    for (const group of SEPARATOR_GROUPS) {
        let bestEnd = -1;
        for (const sep of group) {
            const i = window.lastIndexOf(sep);
            // 同组内取最靠右的
            if (i !== -1 && i + sep.length > bestEnd) bestEnd = i + sep.length;
        }
        if (bestEnd !== -1) return start + bestEnd;
        // 找不到则尝试下一组
    }
    // todo: 如果回退到字符单位的边界,切分后分词结果有可能和原文分词不一致,会报错。
    // todo: 其实英文句号、空格等作为边界也有可能有分词不一致问题,这里会是一个坑
    return maxEnd;
}

/**
 * 合并 token_attention 中因 BPE overlap 产生的重叠 token(offset 几何合并与 mergeTokensForRendering 一致)。
 *
 * BPE overlap 多为 tokenizer 的 offset 与字边界不对齐所致:相邻条目的 raw / offset 在表层可能看起来「重叠」,
 * 但底层仍是按 tokenizer 位置各不相同的嵌入与梯度;并非同一条底层数据被算了两次。
 *
 * 输入须为 API 的原始 `score`(梯度范数);重叠时 **相加**。归一化到 [0,1] 须在合并之后由 normalizeTokenScores 统一做。
 *
 * 与 BPE 一致:先 {@link dropEmptyZeroWidthTokens},再 {@link mergeSequentialOverlap}(含零宽落在下一区间内之合并)。
 */
export function mergeAttentionTokensForRendering<T extends { offset: [number, number]; raw: string; score: number }>(
    tokens: T[],
    text: string
): T[] {
    if (tokens.length === 0) return tokens;
    const prepared = dropEmptyZeroWidthTokens(tokens);
    if (prepared.length === 0) return prepared;
    return mergeSequentialOverlap(prepared, {
        getOffset: (t) => t.offset,
        cloneForStep: (t) => ({ ...t, offset: [t.offset[0], t.offset[1]] as [number, number] }) as T,
        sliceMergedRaw: (start, end) => sliceTextByCodePointOffsets(text, start, end),
        mergeOverlappingPair: (current, next, mergedOffset, mergedRaw) =>
            ({
                ...current,
                offset: mergedOffset,
                raw: mergedRaw,
                score: current.score + next.score,
                bpe_merge_parts: mergeSourcePartsForOverlapPair(text, current, next),
                bpe_merged: 'overlap' satisfies BpeMergeReason,
            }) as T,
    });
}

/**
 * Digit 合并:与 {@link mergeBpeDigitTokens} 相同分组规则({@link digitMergeIndexGroupsByText}),对 attention 的 `score` **求和**(BPE 侧为概率相乘)。
 */
export function mergeAttentionDigitTokens<T extends { offset: [number, number]; raw: string; score: number }>(
    tokens: T[],
    text: string
): T[] {
    const mergeGroups = digitMergeIndexGroupsByText(text, tokens);
    return mergeGroups.map((group) => {
        if (group.length === 1) {
            return tokens[group[0]!]!;
        }
        const first = tokens[group[0]!]!;
        const last = tokens[group[group.length - 1]!]!;
        const mergedRaw = sliceTextByCodePointOffsets(text, first.offset[0], last.offset[1]);
        const mergedScore = group.reduce((sum, idx) => sum + tokens[idx]!.score, 0);
        return {
            ...first,
            offset: [first.offset[0], last.offset[1]] as [number, number],
            raw: mergedRaw,
            score: mergedScore,
            bpe_merge_parts: flattenMergePartsForDigitGroup(group, tokens),
            bpe_merged: 'digit' satisfies BpeMergeReason,
        } as T;
    });
}

/**
 * 语义 / 归因 attention 的统一合并:先 overlap(与 BPE 几何一致),可选再 digit;归一化由调用方 {@link normalizeTokenScores} 完成。
 */
export function mergeAttentionTokensFullyForRendering<T extends { offset: [number, number]; raw: string; score: number }>(
    tokens: T[],
    text: string,
    options: DigitMergePipelineOptions = {}
): T[] {
    const overlapped = mergeAttentionTokensForRendering(tokens, text);
    if (options.digitMerge === false) {
        return overlapped;
    }
    return mergeAttentionDigitTokens(overlapped, text);
}

/** bytesPerChunk:UTF-8 字节数;startOffset:字符索引。 */
export function splitTextToChunks(text: string, bytesPerChunk: number): Array<{ text: string; startOffset: number }> {
    if (bytesPerChunk <= 0) {
        throw new Error("分块字节上限必须大于 0,当前值: " + bytesPerChunk);
    }
    if (text.includes("\r")) {
        throw new Error("文本包含 \\r (CR) 换行符,当前仅支持 \\n (LF)。");
    }
    const chunks: Array<{ text: string; startOffset: number }> = [];
    let pos = 0; // 字符索引
    const encodeBuf = new Uint8Array(bytesPerChunk + 1); // +1 使超长行 written>bytesPerChunk,wouldExceed 恒为 true
    while (pos < text.length) {
        let chunkEnd = pos; // 字符索引
        let chunkBytes = 0; // UTF-8 字节数
        outer: while (chunkEnd < text.length) {
            const paragEnd = nextParagraphEnd(text, chunkEnd);
            const paragBytes = getUtf8ByteLength(text.slice(chunkEnd, paragEnd), encodeBuf);
            if (chunkBytes > 0 && chunkBytes + paragBytes > bytesPerChunk) break;
            if (chunkBytes === 0 && paragBytes > bytesPerChunk) {
                // 段落超限,降级到行模式:贪婪消费行直到 chunk 满或段落结束
                while (chunkEnd < paragEnd) {
                    const lineEnd = nextLineEnd(text, chunkEnd);
                    const lineBytes = getUtf8ByteLength(text.slice(chunkEnd, lineEnd), encodeBuf);
                    if (lineBytes > bytesPerChunk) {
                        const maxEnd = charIndexForByteLimit(text, chunkEnd, bytesPerChunk);
                        chunkEnd = findSplitPoint(text, chunkEnd, maxEnd);
                        break outer;
                    }
                    if (chunkBytes > 0 && chunkBytes + lineBytes > bytesPerChunk) break outer;
                    chunkBytes += lineBytes;
                    chunkEnd = lineEnd;
                }
                continue outer; // 段落内所有行已贪婪填充,继续评估下一段落
            }
            chunkBytes += paragBytes;
            chunkEnd = paragEnd;
        }
        chunks.push({ text: text.slice(pos, chunkEnd), startOffset: pos });
        pos = chunkEnd;
    }
    return chunks;
}