| |
| |
| |
|
|
| import type { BpeMergeReason } from '../api/GLTR_API'; |
| import { |
| type DigitMergePipelineOptions, |
| digitMergeIndexGroupsByText, |
| dropEmptyZeroWidthTokens, |
| flattenMergePartsForDigitGroup, |
| mergeSequentialOverlap, |
| mergeSourcePartsForOverlapPair, |
| sliceTextByCodePointOffsets, |
| } from './mergeTokenSpans'; |
|
|
| |
| |
| |
| export function getAttentionRawScore<T extends { score: number }>(t: T): number { |
| const ext = t as { rawScore?: number }; |
| return ext.rawScore !== undefined ? ext.rawScore : t.score; |
| } |
|
|
| |
| |
| |
| |
| |
| export function normalizeTokenScores<T extends { score: number }>(tokens: T[]): Array<T & { rawScore: number }> { |
| const max = Math.max(0, ...tokens.map((t) => t.score).filter(Number.isFinite)); |
| return tokens.map((t) => { |
| const rawScore = getAttentionRawScore(t); |
| if (max <= 0) { |
| return { ...t, rawScore }; |
| } |
| return { ...t, rawScore, score: t.score / max }; |
| }); |
| } |
|
|
// Module-wide UTF-8 encoder shared by the byte-length helpers in this file
// (TextEncoder is stateless, so a single instance can be reused safely).
const encoder = new TextEncoder();
|
|
| |
| |
| |
| |
| |
| export function getUtf8ByteLength(text: string, buf: Uint8Array): number { |
| const { read, written } = encoder.encodeInto(text, buf); |
| return read < text.length ? buf.length : written; |
| } |
|
|
|
|
| |
| function nextParagraphEnd(text: string, start: number): number { |
| const nl = text.indexOf("\n\n", start); |
| if (nl === -1) return text.length; |
| let end = nl + 2; |
| while (end < text.length && text[end] === "\n") end++; |
| return end; |
| } |
|
|
| |
| function nextLineEnd(text: string, start: number): number { |
| const nl = text.indexOf("\n", start); |
| if (nl === -1) return text.length; |
| let end = nl + 1; |
| while (end < text.length && text[end] === "\n") end++; |
| return end; |
| } |
|
|
| |
| export function charIndexForByteLimit(text: string, start: number, byteLimit: number): number { |
| const buf = new Uint8Array(4); |
| let bytes = 0; |
| let i = start; |
| while (i < text.length) { |
| const cp = text.codePointAt(i)!; |
| const charLen = cp > 0xFFFF ? 2 : 1; |
| const byteLen = encoder.encodeInto(text.slice(i, i + charLen), buf).written; |
| if (bytes + byteLen > byteLimit) break; |
| bytes += byteLen; |
| i += charLen; |
| } |
| return i; |
| } |
|
|
| |
| |
| |
| const SEPARATOR_GROUPS: string[][] = [ |
| |
| ["。", "!", "?", "…"], |
| |
| [";", ","], |
| |
| [".", "!", "?"], |
| |
| [";", ","], |
| |
| [" ", "\t"], |
| ]; |
|
|
| |
| |
| |
| |
| export function findSplitPoint(text: string, start: number, maxEnd: number): number { |
| const window = text.slice(start, maxEnd); |
| for (const group of SEPARATOR_GROUPS) { |
| let bestEnd = -1; |
| for (const sep of group) { |
| const i = window.lastIndexOf(sep); |
| |
| if (i !== -1 && i + sep.length > bestEnd) bestEnd = i + sep.length; |
| } |
| if (bestEnd !== -1) return start + bestEnd; |
| |
| } |
| |
| |
| return maxEnd; |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| export function mergeAttentionTokensForRendering<T extends { offset: [number, number]; raw: string; score: number }>( |
| tokens: T[], |
| text: string |
| ): T[] { |
| if (tokens.length === 0) return tokens; |
| const prepared = dropEmptyZeroWidthTokens(tokens); |
| if (prepared.length === 0) return prepared; |
| return mergeSequentialOverlap(prepared, { |
| getOffset: (t) => t.offset, |
| cloneForStep: (t) => ({ ...t, offset: [t.offset[0], t.offset[1]] as [number, number] }) as T, |
| sliceMergedRaw: (start, end) => sliceTextByCodePointOffsets(text, start, end), |
| mergeOverlappingPair: (current, next, mergedOffset, mergedRaw) => |
| ({ |
| ...current, |
| offset: mergedOffset, |
| raw: mergedRaw, |
| score: current.score + next.score, |
| bpe_merge_parts: mergeSourcePartsForOverlapPair(text, current, next), |
| bpe_merged: 'overlap' satisfies BpeMergeReason, |
| }) as T, |
| }); |
| } |
|
|
| |
| |
| |
| export function mergeAttentionDigitTokens<T extends { offset: [number, number]; raw: string; score: number }>( |
| tokens: T[], |
| text: string |
| ): T[] { |
| const mergeGroups = digitMergeIndexGroupsByText(text, tokens); |
| return mergeGroups.map((group) => { |
| if (group.length === 1) { |
| return tokens[group[0]!]!; |
| } |
| const first = tokens[group[0]!]!; |
| const last = tokens[group[group.length - 1]!]!; |
| const mergedRaw = sliceTextByCodePointOffsets(text, first.offset[0], last.offset[1]); |
| const mergedScore = group.reduce((sum, idx) => sum + tokens[idx]!.score, 0); |
| return { |
| ...first, |
| offset: [first.offset[0], last.offset[1]] as [number, number], |
| raw: mergedRaw, |
| score: mergedScore, |
| bpe_merge_parts: flattenMergePartsForDigitGroup(group, tokens), |
| bpe_merged: 'digit' satisfies BpeMergeReason, |
| } as T; |
| }); |
| } |
|
|
| |
| |
| |
| export function mergeAttentionTokensFullyForRendering<T extends { offset: [number, number]; raw: string; score: number }>( |
| tokens: T[], |
| text: string, |
| options: DigitMergePipelineOptions = {} |
| ): T[] { |
| const overlapped = mergeAttentionTokensForRendering(tokens, text); |
| if (options.digitMerge === false) { |
| return overlapped; |
| } |
| return mergeAttentionDigitTokens(overlapped, text); |
| } |
|
|
| |
| export function splitTextToChunks(text: string, bytesPerChunk: number): Array<{ text: string; startOffset: number }> { |
| if (bytesPerChunk <= 0) { |
| throw new Error("分块字节上限必须大于 0,当前值: " + bytesPerChunk); |
| } |
| if (text.includes("\r")) { |
| throw new Error("文本包含 \\r (CR) 换行符,当前仅支持 \\n (LF)。"); |
| } |
| const chunks: Array<{ text: string; startOffset: number }> = []; |
| let pos = 0; |
| const encodeBuf = new Uint8Array(bytesPerChunk + 1); |
| while (pos < text.length) { |
| let chunkEnd = pos; |
| let chunkBytes = 0; |
| outer: while (chunkEnd < text.length) { |
| const paragEnd = nextParagraphEnd(text, chunkEnd); |
| const paragBytes = getUtf8ByteLength(text.slice(chunkEnd, paragEnd), encodeBuf); |
| if (chunkBytes > 0 && chunkBytes + paragBytes > bytesPerChunk) break; |
| if (chunkBytes === 0 && paragBytes > bytesPerChunk) { |
| |
| while (chunkEnd < paragEnd) { |
| const lineEnd = nextLineEnd(text, chunkEnd); |
| const lineBytes = getUtf8ByteLength(text.slice(chunkEnd, lineEnd), encodeBuf); |
| if (lineBytes > bytesPerChunk) { |
| const maxEnd = charIndexForByteLimit(text, chunkEnd, bytesPerChunk); |
| chunkEnd = findSplitPoint(text, chunkEnd, maxEnd); |
| break outer; |
| } |
| if (chunkBytes > 0 && chunkBytes + lineBytes > bytesPerChunk) break outer; |
| chunkBytes += lineBytes; |
| chunkEnd = lineEnd; |
| } |
| continue outer; |
| } |
| chunkBytes += paragBytes; |
| chunkEnd = paragEnd; |
| } |
| chunks.push({ text: text.slice(pos, chunkEnd), startOffset: pos }); |
| pos = chunkEnd; |
| } |
| return chunks; |
| } |
|
|