import type { BpeMergeReason, FrontendAnalyzeResult, FrontendToken } from '../api/GLTR_API';
import type { AttributionApiResponse } from './attributionResultCache';
import { getDigitsMergeEnabled } from '../utils/digitsMergeManager';
import {
getAttentionRawScore,
mergeAttentionTokensFullyForRendering,
normalizeTokenScores,
} from '../utils/semanticUtils';
/** Half-open interval within `context`, limiting pattern matching to the "prompt" part. */
export type ExcludeRegexMatchRegion = {
  start: number;
  end: number;
};
export type AttributionDisplayOptions = {
  colorRangeMax: number | null;
  /** Effective exclusion config (pass '' when disabled): one regex per line, matched with the `g` flag inside {@link ExcludeRegexMatchRegion}. */
  excludePromptPatternsText: string;
  /**
   * Regexes apply only to the `[start, end)` substring of `context`; defaults to
   * `[0, context.length)` (the whole context is treated as the prompt).
   */
  excludePromptPatternsRegion?: ExcludeRegexMatchRegion;
};
/**
 * Maps normalized scores onto [0, 1] for color rendering: scores above the
 * cap `x` saturate to 1, the rest scale linearly as `s / x`.
 *
 * Fix: guard `x <= 0` — the caller only checks `colorRangeMax != null`, so a
 * zero cap previously produced `0 / 0 === NaN` for zero scores. With a
 * non-positive cap, any positive score saturates to 1 and the rest map to 0.
 *
 * @param rawScoresNormed Normalized raw scores (expected in [0, 1]).
 * @param x Saturation cap; scores above it map to 1.
 * @returns One color score per input score.
 */
function mapNormedScoresToColorRange(rawScoresNormed: number[], x: number): number[] {
  if (x <= 0) return rawScoresNormed.map((s) => (s > 0 ? 1 : 0));
  return rawScoresNormed.map((s) => (s > x ? 1 : s / x));
}
/** Inline-comment marker: this prefix and everything after it on a line is ignored by regex parsing (see {@link collectExcludeRegexMatchIntervals}). */
const EXCLUDE_REGEX_LINE_COMMENT_MARKER = '#comment#';
/**
 * Treats each line of `excludeMultiline` as one regex (matched with the `g`
 * flag) and collects every match interval `[start, end)` over the substring of
 * `context` bounded by `region`. Coordinates are absolute indices into the
 * full string; intervals are NOT merged. Omitting `region` is equivalent to
 * `[0, context.length)`.
 * `excludeMultiline` should come from a `textarea.value` (API values are
 * normalized to `\n` newlines); lines are NOT trimmed, as trimming would
 * change regex semantics.
 * A line may hold a regex followed by {@link EXCLUDE_REGEX_LINE_COMMENT_MARKER}
 * and a note; the marker and everything after it are dropped before parsing.
 * Lines that are empty after stripping (including comment-only lines) are
 * skipped. A line with an invalid regex is skipped on its own (other lines
 * still apply), so a bad pattern never throws and blocks re-rendering.
 * Shared by {@link isOffsetSpanFullyExcluded} and the DAG preprocessing step.
 */
export function collectExcludeRegexMatchIntervals(
  context: string,
  excludeMultiline: string,
  region?: ExcludeRegexMatchRegion
): [number, number][] {
  // Clamp the requested region into the bounds of `context`.
  const requestedStart = region?.start ?? 0;
  const requestedEnd = region?.end ?? context.length;
  const regionStart = Math.min(Math.max(requestedStart, 0), context.length);
  const regionEnd = Math.min(Math.max(requestedEnd, regionStart), context.length);
  const searchText = context.slice(regionStart, regionEnd);
  const result: [number, number][] = [];
  for (const rawLine of excludeMultiline.split('\n')) {
    const markerAt = rawLine.indexOf(EXCLUDE_REGEX_LINE_COMMENT_MARKER);
    const pattern = markerAt >= 0 ? rawLine.slice(0, markerAt) : rawLine;
    if (pattern.length === 0) continue;
    let re: RegExp;
    try {
      re = new RegExp(pattern, 'g');
    } catch {
      // Invalid regex: skip this line only; the rest of the lines and the UI stay usable.
      continue;
    }
    for (const match of searchText.matchAll(re)) {
      if (match.index === undefined) continue;
      const absoluteStart = regionStart + match.index;
      result.push([absoluteStart, absoluteStart + match[0].length]);
    }
  }
  return result;
}
/**
 * Returns true iff the span `[ts, te)` lies entirely inside at least one of
 * the given match intervals (the list is not merged; each interval is checked
 * independently).
 */
export function isOffsetSpanFullyExcluded(ts: number, te: number, intervals: [number, number][]): boolean {
  return intervals.some(([lo, hi]) => lo <= ts && te <= hi);
}
/**
 * Converts an attribution API response into a {@link FrontendAnalyzeResult}
 * usable by {@link GLTR_Text_Box} (carrying rawScoresNormed /
 * attentionRawScores / optional colorScores).
 * Pipeline: overlap + digit merging → {@link normalizeTokenScores}, matching
 * the semantic-attention path.
 */
export function buildAttributionDisplayResult(
  context: string,
  response: AttributionApiResponse,
  options: AttributionDisplayOptions
): FrontendAnalyzeResult {
  const apiTokens = response.token_attribution ?? [];

  // Zero out the score of every token whose span is fully covered by an
  // exclude-pattern match inside the prompt region.
  const matchRegion = options.excludePromptPatternsRegion ?? { start: 0, end: context.length };
  const excludedSpans = collectExcludeRegexMatchIntervals(
    context,
    options.excludePromptPatternsText,
    matchRegion
  );
  const scoredTokens = apiTokens.map((token) => {
    const [spanStart, spanEnd] = token.offset;
    return {
      offset: token.offset,
      raw: token.raw,
      score: isOffsetSpanFullyExcluded(spanStart, spanEnd, excludedSpans) ? 0 : token.score,
    };
  });

  // Untouched token rows, kept for display alongside the merged ones.
  const originalTokens: FrontendToken[] = apiTokens.map((token) => ({
    raw: token.raw,
    offset: token.offset,
    pred_topk: [],
  }));

  // Merge (overlap + optional digit merge) then normalize, same as the
  // semantic-attention pipeline.
  const normalized = normalizeTokenScores(
    mergeAttentionTokensFullyForRendering(scoredTokens, context, {
      digitMerge: getDigitsMergeEnabled(),
    })
  );

  // Strip scores back down to FrontendToken rows, carrying over merge metadata
  // when present.
  const digitMergedTokens: FrontendToken[] = normalized.map((token) => {
    const row: FrontendToken = { offset: token.offset, raw: token.raw, pred_topk: [] };
    const mergeReason = (token as { bpe_merged?: BpeMergeReason }).bpe_merged;
    if (mergeReason !== undefined) {
      row.bpe_merged = mergeReason;
    }
    const mergeParts = (token as { bpe_merge_parts?: string[] }).bpe_merge_parts;
    if (mergeParts !== undefined) {
      row.bpe_merge_parts = [...mergeParts];
    }
    return row;
  });

  const rawScoresNormed = normalized.map((token) => token.score);
  const attentionRawScores = normalized.map((token) => getAttentionRawScore(token));

  const result = {
    model: response.model ?? null,
    error: null,
    bpe_strings: digitMergedTokens,
    originalTokens,
    bpeBpeMergedTokens: digitMergedTokens.map((token) => ({ ...token })),
    originalText: context,
  } as FrontendAnalyzeResult;
  // Attach the score arrays the renderer reads off the result object.
  const ext = result as FrontendAnalyzeResult & {
    rawScoresNormed: number[];
    colorScores?: number[];
    attentionRawScores: number[];
  };
  ext.rawScoresNormed = rawScoresNormed;
  ext.attentionRawScores = attentionRawScores;
  if (options.colorRangeMax != null) {
    ext.colorScores = mapNormedScoresToColorRange(rawScoresNormed, options.colorRangeMax);
  }
  return result;
}