| |
| |
| |
| |
| import type { AttributionApiResponse, PredictionAttributeModelVariant } from './attributionResultCache'; |
| import type { PromptTokenSpan } from './genAttributeDagPreprocess'; |
| import type { CompletionFinishReason } from '../utils/generationEndReasonLabel'; |
| import { fetchPredictionAttribute, fetchTokenize } from './predictionAttributeClient'; |
|
|
| |
/** Default cap on the number of generated tokens when `maxTokens` is not supplied. */
export const TOKEN_GEN_MAX_TOKENS_DEFAULT = 100;
|
|
| function splitCodePointPrefix(text: string, prefixLength: number): { prefix: string; rest: string } | null { |
| if (prefixLength < 0) return null; |
| const chars = Array.from(text); |
| if (prefixLength > chars.length) return null; |
| return { |
| prefix: chars.slice(0, prefixLength).join(''), |
| rest: chars.slice(prefixLength).join(''), |
| }; |
| } |
|
|
/** One generated token together with the attribution response that produced it. */
export type TokenGenStep = {
  /** Model input for this step: the initial context plus all text generated before this token. */
  context: string;
  /**
   * Boundary between the original prompt and generated text, taken as
   * `initialContext.length` (UTF-16 code units).
   */
  promptRegionEnd: number;
  /** Raw attribution API response for this step. */
  response: AttributionApiResponse;
  /** The token emitted at this step (the forced piece when teacher forcing is active). */
  token: string;
  /** All generated text up to and including this step's token. */
  currentText: string;
};
|
|
/** Configuration and callbacks for {@link startTokenGenAttribution}. */
export type TokenGenAttributionOptions = {
  /** Prompt text that generation starts from. */
  initialContext: string;
  /** Prefix forwarded to the prediction-attribute API client for every request. */
  apiPrefix: string;
  /** Model variant to run attribution against. */
  model: PredictionAttributeModelVariant;
  /**
   * Optional continuation to teacher-force. It is tokenized first, and each
   * resulting token id is passed as the forced target for successive
   * attribution requests until the continuation is fully consumed.
   * Ignored when undefined or empty.
   */
  teacherForcingContinuation?: string;
  /**
   * When true, the run completes with reason 'stop' once the forced
   * continuation is consumed, instead of switching to free generation.
   */
  stopAfterTeacherForcing?: boolean;
  /** Maximum number of tokens to generate; defaults to TOKEN_GEN_MAX_TOKENS_DEFAULT. */
  maxTokens?: number;
  /** Invoked once per generated token with the recorded step and its 0-based index. */
  onStep: (step: TokenGenStep, stepIndex: number) => void;
  /** Invoked exactly once when the run finishes, with the finish reason. */
  onComplete: (reason: CompletionFinishReason) => void;
  /** Invoked (before onComplete('error')) when a request or callback fails. */
  onError: (err: Error) => void;
  /** Correlation id forwarded with every attribution request. */
  flowId: string;
};
|
|
/** Control and inspection handle returned by {@link startTokenGenAttribution}. */
export type TokenGenAttributionHandle = {
  /** Requests cancellation; the loop finishes with reason 'abort' at its next checkpoint. */
  abort(): void;
  /** Returns the step at `idx`, or undefined if it has not been produced. */
  getStep(idx: number): TokenGenStep | undefined;
  /** Returns a shallow copy of all steps recorded so far. */
  getAllSteps(): TokenGenStep[];
  /** Number of steps (tokens) produced so far. */
  readonly tokenCount: number;
};
|
|
/**
 * Starts an asynchronous token-by-token generation loop with per-token
 * attribution.
 *
 * Each iteration sends the current context (initial context + text generated
 * so far) to the prediction-attribute API, appends the returned — or teacher
 * forced — token, records a {@link TokenGenStep}, and reports it via
 * `opts.onStep`.
 *
 * When `teacherForcingContinuation` is a non-empty string it is tokenized
 * first; the resulting token ids are then passed one at a time as the forced
 * target of successive attribution requests until the continuation is
 * consumed.
 *
 * `opts.onComplete` fires exactly once with the finish reason:
 *  - 'abort'  — `handle.abort()` observed at a checkpoint;
 *  - 'length' — `maxTokens` steps were produced;
 *  - 'stop'   — empty token / EOS response, or the forced continuation was
 *               consumed while `stopAfterTeacherForcing` is set;
 *  - 'error'  — a request or callback failed (`opts.onError` fires first).
 *
 * @returns a handle to abort the loop and inspect the accumulated steps.
 */
export function startTokenGenAttribution(opts: TokenGenAttributionOptions): TokenGenAttributionHandle {
  const {
    initialContext,
    apiPrefix,
    model,
    maxTokens = TOKEN_GEN_MAX_TOKENS_DEFAULT,
    stopAfterTeacherForcing = false,
    flowId,
  } = opts;
  const tfOpt = opts.teacherForcingContinuation;
  // Teacher forcing is active only for a non-empty continuation string.
  const forcingEnabled = typeof tfOpt === 'string' && tfOpt.length > 0;
  // NOTE(review): UTF-16 code-unit length, while the forcing bookkeeping below
  // counts code points — confirm consumers of promptRegionEnd expect code units.
  const promptRegionEnd = initialContext.length;
  let aborted = false;
  let generatedText = ''; // all tokens appended so far
  let remainingForcing = tfOpt ?? ''; // unconsumed suffix of the continuation
  let forcingPieces: Array<{ token: string; tokenId: number }> = [];
  let forcingPieceIndex = 0; // index of the next piece to force
  const steps: TokenGenStep[] = [];

  const loop = async (): Promise<void> => {
    // Phase 1: tokenize the teacher-forcing continuation and validate that the
    // returned spans tile it exactly — in order, gap-free, fully covering, and
    // each consumable span carrying a valid non-negative integer token_id.
    if (forcingEnabled) {
      let spans;
      try {
        spans = await fetchTokenize(apiPrefix, tfOpt, model);
      } catch (err) {
        const error = err instanceof Error ? err : new Error(String(err));
        opts.onError(error);
        opts.onComplete('error');
        return;
      }
      if (!spans.length) {
        opts.onError(new Error('Teacher forcing tokenize returned empty spans.'));
        opts.onComplete('error');
        return;
      }
      // Offsets are interpreted as code-point positions within the continuation.
      const chars = Array.from(tfOpt);
      let cursor = 0; // number of code points consumed so far
      const pieces: Array<{ token: string; tokenId: number }> = [];
      for (const span of spans) {
        const [start, end] = span.offset;
        const tokenId = (span as PromptTokenSpan).token_id;
        if (start < 0 || end <= start || end > chars.length) {
          opts.onError(
            new Error(`Teacher forcing tokenize returned invalid span [${start}, ${end}) for continuation.`)
          );
          opts.onComplete('error');
          return;
        }
        if (start > cursor) {
          // A gap would leave text that no forced token covers — fail loudly.
          opts.onError(
            new Error(
              `Teacher forcing tokenize produced gap: span starts at ${start} but consumed cursor is ${cursor}.`
            )
          );
          opts.onComplete('error');
          return;
        }
        if (end <= cursor) {
          // Span entirely behind the cursor (e.g. fully overlapped): skip it.
          continue;
        }
        if (typeof tokenId !== 'number' || !Number.isInteger(tokenId) || tokenId < 0) {
          opts.onError(
            new Error(
              `Teacher forcing tokenize span is missing token_id at offset [${start}, ${end}).`
            )
          );
          opts.onComplete('error');
          return;
        }
        // Consume [cursor, end): overlapping spans contribute only their new tail.
        pieces.push({ token: chars.slice(cursor, end).join(''), tokenId });
        cursor = end;
      }
      if (cursor !== chars.length) {
        opts.onError(
          new Error(
            `Teacher forcing tokenize did not fully cover continuation: consumed ${cursor}/${chars.length} code points.`
          )
        );
        opts.onComplete('error');
        return;
      }
      if (!pieces.length) {
        opts.onError(new Error('Teacher forcing tokenize produced no consumable pieces.'));
        opts.onComplete('error');
        return;
      }
      forcingPieces = pieces;
    }

    // Phase 2: generation loop — one attribution request per token.
    while (true) {
      // Abort checkpoint before issuing a request.
      if (aborted) {
        opts.onComplete('abort');
        return;
      }
      if (steps.length >= maxTokens) {
        opts.onComplete('length');
        return;
      }
      const forcingExhausted = forcingEnabled && forcingPieceIndex >= forcingPieces.length;
      if (forcingExhausted && stopAfterTeacherForcing) {
        opts.onComplete('stop');
        return;
      }

      const context = initialContext + generatedText;
      // While forcing, pin the model's next token to the continuation's token id.
      const targetTokenId =
        forcingEnabled && !forcingExhausted ? forcingPieces[forcingPieceIndex]!.tokenId : undefined;

      let response: AttributionApiResponse;
      try {
        response = await fetchPredictionAttribute(
          apiPrefix,
          context,
          null,
          model,
          'gen_attribute.html',
          targetTokenId,
          flowId,
          steps.length,
        );
      } catch (err) {
        const error = err instanceof Error ? err : new Error(String(err));
        opts.onError(error);
        opts.onComplete('error');
        return;
      }

      // Abort checkpoint after the await: drop the response if we were aborted
      // while the request was in flight.
      if (aborted) {
        opts.onComplete('abort');
        return;
      }

      let token = response.target_token ?? '';

      // While forcing, trust the locally tokenized piece as the emitted token
      // and advance the forcing bookkeeping.
      if (forcingEnabled && !forcingExhausted) {
        token = forcingPieces[forcingPieceIndex]!.token;
        const sliced = splitCodePointPrefix(remainingForcing, Array.from(token).length);
        if (!sliced) {
          opts.onError(
            new Error(
              `Teacher forcing piece consume failed at step=${forcingPieceIndex}: token="${token}", remaining="${remainingForcing}"`
            )
          );
          opts.onComplete('error');
          return;
        }
        remainingForcing = sliced.rest;
        forcingPieceIndex++;
      }
      generatedText += token;

      // Abort checkpoint before reporting: forcing state was advanced above,
      // but no step is recorded or reported for an aborted round.
      if (aborted) {
        opts.onComplete('abort');
        return;
      }

      const step: TokenGenStep = {
        context,
        promptRegionEnd,
        response,
        token,
        currentText: generatedText,
      };
      const stepIndex = steps.length;
      steps.push(step);

      // A throwing onStep is treated like any other failure.
      try {
        opts.onStep(step, stepIndex);
      } catch (err) {
        const error = err instanceof Error ? err : new Error(String(err));
        opts.onError(error);
        opts.onComplete('error');
        return;
      }

      // Natural end of generation: empty token or an explicit EOS flag.
      if (!token || response.is_eos) {
        opts.onComplete('stop');
        return;
      }
    }
  };

  // Fire and forget: every failure path inside loop() is routed through
  // onError/onComplete, so the promise itself never rejects.
  // NOTE(review): if onComplete/onError themselves throw, the rejection is
  // unobserved — consider a trailing .catch if callbacks are untrusted.
  void loop();

  return {
    abort() {
      aborted = true;
    },
    getStep(idx) {
      return steps[idx];
    },
    getAllSteps() {
      return steps.slice();
    },
    get tokenCount() {
      return steps.length;
    },
  };
}
|
|
| |
| export function createHydratedTokenGenHandle(frozenSteps: TokenGenStep[]): TokenGenAttributionHandle { |
| const steps = frozenSteps.slice(); |
| return { |
| abort() { |
| |
| }, |
| getStep(idx) { |
| return steps[idx]; |
| }, |
| getAllSteps() { |
| return steps.slice(); |
| }, |
| get tokenCount() { |
| return steps.length; |
| }, |
| }; |
| } |
|
|