InfoLens/client/src/ts/attribution/tokenGenAttributionRunner.ts
/**
 * Per-token generation attribution built on /api/prediction-attribute.
 * By default `target_prediction` is empty → the server decodes greedily (top-1); when
 * {@link TokenGenAttributionOptions.teacherForcingContinuation} is provided, each step forces the
 * next token of the user continuation and then attributes it.
 */
import type { AttributionApiResponse, PredictionAttributeModelVariant } from './attributionResultCache';
import type { PromptTokenSpan } from './genAttributeDagPreprocess';
import type { CompletionFinishReason } from '../utils/generationEndReasonLabel';
import { fetchPredictionAttribute, fetchTokenize } from './predictionAttributeClient';
/** Matches the default of the "Max tokens" input on the generation-attribution page (including the DAG view) */
export const TOKEN_GEN_MAX_TOKENS_DEFAULT = 100;
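/**
 * Split `text` after `prefixLength` Unicode code points (not UTF-16 units), so surrogate pairs
 * are never cut in half; returns null when the requested prefix exceeds the text.
 * E.g. splitCodePointPrefix('a🙂b', 2) → { prefix: 'a🙂', rest: 'b' }.
 */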
function splitCodePointPrefix(text: string, prefixLength: number): { prefix: string; rest: string } | null {
if (prefixLength < 0) return null;
const chars = Array.from(text);
if (prefixLength > chars.length) return null;
return {
prefix: chars.slice(0, prefixLength).join(''),
rest: chars.slice(prefixLength).join(''),
};
}
export type TokenGenStep = {
/** Context used for this step's attribution (excludes the newly generated token) */
context: string;
/**
 * Exclusive end index of the static initial prompt within `context`; `context.slice(0, promptRegionEnd)` is the prompt without the generated suffix.
 */
promptRegionEnd: number;
response: AttributionApiResponse;
/** Token string generated at this step (i.e. response.target_token) */
token: string;
/** All generated text accumulated so far (including this step's token) */
currentText: string;
};
export type TokenGenAttributionOptions = {
initialContext: string;
apiPrefix: string;
model: PredictionAttributeModelVariant;
/**
 * When non-empty, enables teacher forcing: `/api/tokenize` is called once at startup to prefetch
 * token IDs, and each subsequent step passes `target_token_id` as the attribution target,
 * advancing by the code-point coverage of the returned spans.
 */
teacherForcingContinuation?: string;
/**
 * Whether to stop once the teacher-forcing tokens are exhausted.
 * `true`: stop; `false` (default): switch to top-1 generation until maxTokens or EOS.
 */
stopAfterTeacherForcing?: boolean;
/** Maximum number of tokens to generate; defaults to {@link TOKEN_GEN_MAX_TOKENS_DEFAULT} */
maxTokens?: number;
/** Called after each generated token; `stepIndex` is 0-based and matches the indices of {@link TokenGenAttributionHandle.getAllSteps} */
onStep: (step: TokenGenStep, stepIndex: number) => void;
onComplete: (reason: CompletionFinishReason) => void;
onError: (err: Error) => void;
/** ID of a single continuous generation-attribution session; used by the backend for log compaction and statistics grouping. */
flowId: string;
};
export type TokenGenAttributionHandle = {
abort(): void;
getStep(idx: number): TokenGenStep | undefined;
getAllSteps(): TokenGenStep[];
/** Total number of generated tokens (including the step in progress) */
readonly tokenCount: number;
};
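/**
 * Start a per-token generation-attribution run. Returns a handle immediately; progress is
 * delivered through the `onStep` / `onComplete` / `onError` callbacks in `opts`.
 *
 * A minimal usage sketch; the `apiPrefix` value, model variant string, and callback bodies
 * below are illustrative assumptions, not values fixed by this module:
 *
 * @example
 * const handle = startTokenGenAttribution({
 *   initialContext: 'The capital of France is',
 *   apiPrefix: '/api', // assumption: deployment-specific prefix
 *   model: 'qwen3-1.7b' as PredictionAttributeModelVariant, // illustrative variant name
 *   flowId: crypto.randomUUID(),
 *   onStep: (step, i) => console.log(i, step.token),
 *   onComplete: (reason) => console.log('finished:', reason),
 *   onError: (err) => console.error(err),
 * });
 * // handle.abort() cancels the run; completed steps remain readable via handle.getAllSteps().
 */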
export function startTokenGenAttribution(opts: TokenGenAttributionOptions): TokenGenAttributionHandle {
const {
initialContext,
apiPrefix,
model,
maxTokens = TOKEN_GEN_MAX_TOKENS_DEFAULT,
stopAfterTeacherForcing = false,
flowId,
} = opts;
const tfOpt = opts.teacherForcingContinuation;
const forcingEnabled = typeof tfOpt === 'string' && tfOpt.length > 0;
const promptRegionEnd = initialContext.length;
let aborted = false;
let generatedText = '';
let remainingForcing = tfOpt ?? '';
let forcingPieces: Array<{ token: string; tokenId: number }> = [];
let forcingPieceIndex = 0;
const steps: TokenGenStep[] = [];
const loop = async (): Promise<void> => {
if (forcingEnabled) {
let spans;
try {
spans = await fetchTokenize(apiPrefix, tfOpt, model);
} catch (err) {
const error = err instanceof Error ? err : new Error(String(err));
opts.onError(error);
opts.onComplete('error');
return;
}
if (!spans.length) {
opts.onError(new Error('Teacher forcing tokenize returned empty spans.'));
opts.onComplete('error');
return;
}
const chars = Array.from(tfOpt);
let cursor = 0;
const pieces: Array<{ token: string; tokenId: number }> = [];
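// Walk spans left to right. Offsets are code-point indices into the continuation; gaps are
// rejected, overlapping prefixes are tolerated by slicing from the consumed cursor, and the
// spans must jointly cover every code point (checked after the loop).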
for (const span of spans) {
const [start, end] = span.offset;
const tokenId = (span as PromptTokenSpan).token_id;
if (start < 0 || end <= start || end > chars.length) {
opts.onError(
new Error(`Teacher forcing tokenize returned invalid span [${start}, ${end}) for continuation.`)
);
opts.onComplete('error');
return;
}
if (start > cursor) {
opts.onError(
new Error(
`Teacher forcing tokenize produced gap: span starts at ${start} but consumed cursor is ${cursor}.`
)
);
opts.onComplete('error');
return;
}
if (end <= cursor) {
continue;
}
if (typeof tokenId !== 'number' || !Number.isInteger(tokenId) || tokenId < 0) {
opts.onError(
new Error(
`Teacher forcing tokenize span is missing token_id at offset [${start}, ${end}).`
)
);
opts.onComplete('error');
return;
}
pieces.push({ token: chars.slice(cursor, end).join(''), tokenId });
cursor = end;
}
if (cursor !== chars.length) {
opts.onError(
new Error(
`Teacher forcing tokenize did not fully cover continuation: consumed ${cursor}/${chars.length} code points.`
)
);
opts.onComplete('error');
return;
}
if (!pieces.length) {
opts.onError(new Error('Teacher forcing tokenize produced no consumable pieces.'));
opts.onComplete('error');
return;
}
forcingPieces = pieces;
}
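// Main generation loop: one /api/prediction-attribute call per emitted token.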
while (true) {
if (aborted) {
opts.onComplete('abort');
return;
}
if (steps.length >= maxTokens) {
opts.onComplete('length');
return;
}
const forcingExhausted = forcingEnabled && forcingPieceIndex >= forcingPieces.length;
if (forcingExhausted && stopAfterTeacherForcing) {
opts.onComplete('stop');
return;
}
const context = initialContext + generatedText;
const targetTokenId =
forcingEnabled && !forcingExhausted ? forcingPieces[forcingPieceIndex]!.tokenId : undefined;
let response: AttributionApiResponse;
try {
response = await fetchPredictionAttribute(
apiPrefix,
context,
null,
model,
'gen_attribute.html',
targetTokenId,
flowId,
steps.length,
);
} catch (err) {
const error = err instanceof Error ? err : new Error(String(err));
opts.onError(error);
opts.onComplete('error');
return;
}
if (aborted) {
opts.onComplete('abort');
return;
}
let token = response.target_token ?? '';
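// Under teacher forcing, override the server-reported token with the forced piece and
// advance the remaining-continuation cursor by its code-point length.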
if (forcingEnabled && !forcingExhausted) {
token = forcingPieces[forcingPieceIndex]!.token;
const sliced = splitCodePointPrefix(remainingForcing, Array.from(token).length);
if (!sliced) {
opts.onError(
new Error(
`Teacher forcing piece consume failed at step=${forcingPieceIndex}: token="${token}", remaining="${remainingForcing}"`
)
);
opts.onComplete('error');
return;
}
remainingForcing = sliced.rest;
forcingPieceIndex++;
}
generatedText += token;
if (aborted) {
opts.onComplete('abort');
return;
}
const step: TokenGenStep = {
context,
promptRegionEnd,
response,
token,
currentText: generatedText,
};
const stepIndex = steps.length;
steps.push(step);
try {
opts.onStep(step, stepIndex);
} catch (err) {
const error = err instanceof Error ? err : new Error(String(err));
opts.onError(error);
opts.onComplete('error');
return;
}
if (!token || response.is_eos) {
opts.onComplete('stop');
return;
}
}
};
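// Fire-and-forget: errors inside the loop are routed to onError/onComplete rather than awaited here.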
void loop();
return {
abort() {
aborted = true;
},
getStep(idx) {
return steps[idx];
},
getAllSteps() {
return steps.slice();
},
get tokenCount() {
return steps.length;
},
};
}
/** Hydrate a read-only handle for DAG refresh / exclude replay (array order is step order). */
export function createHydratedTokenGenHandle(frozenSteps: TokenGenStep[]): TokenGenAttributionHandle {
const steps = frozenSteps.slice();
return {
abort() {
/* no-op */
},
getStep(idx) {
return steps[idx];
},
getAllSteps() {
return steps.slice();
},
get tokenCount() {
return steps.length;
},
};
}
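// A minimal hydration sketch; `loadCachedSteps` and `renderDagNode` are hypothetical
// caller-side helpers, not part of this module:
//
//   const cached: TokenGenStep[] = loadCachedSteps();
//   const handle = createHydratedTokenGenHandle(cached);
//   handle.getAllSteps().forEach((step, i) => renderDagNode(i, step));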