| |
| |
| |
| |
|
|
| import {FrontendAnalyzeResult} from "../api/GLTR_API"; |
| import {calculateSurprisal, calculateSurprisalDensity} from "../utils/Util"; |
| import {getByteSurprisalColor, getTokenSurprisalColor, getDiffColor, getSemanticSimilarityColor} from "../utils/SurprisalColorConfig"; |
| import {TokenFragmentRect, RectCacheEntry, ZERO_WIDTH_FRAGMENT_PLACEHOLDER_PX} from "./types"; |
|
|
| |
/**
 * Settings for diff-mode colouring of the base token rects.
 *
 * Diff colouring replaces the regular surprisal colouring only while
 * `enabled` is true and `deltaByteSurprisals` is non-empty.
 */
export interface DiffOverlayOptions {
  /** Turns diff colouring on or off. */
  enabled: boolean;
  /** Per-byte surprisal deltas, indexed by byte index; averaged over each token's byte range. */
  deltaByteSurprisals: number[];
  /**
   * Lookup from character offset to byte index. When an offset has no entry
   * the character offset itself is used as the byte index; when the map is
   * empty every token gets the colour for a delta of 0.
   */
  charToByteIndexMap: number[];
}
|
|
| |
/**
 * Settings for the semantic-similarity overlay drawn on top of the base rects.
 */
export interface SemanticOverlayOptions {
  /** When true, an extra overlay rect is created for every token fragment. */
  analysisMode: boolean;
  /**
   * Similarity scores indexed by token index. A token whose score is
   * `undefined` gets a `'transparent'` overlay.
   * NOTE(review): the name suggests the scores are normalised — confirm the
   * expected range against `getSemanticSimilarityColor`.
   */
  rawScoresNormed?: number[];
}
|
|
/**
 * Construction options for {@link SvgOverlayManager}.
 */
export interface SvgOverlayManagerOptions {
  /**
   * Looks up a two-number tuple for a token, or `undefined` when none is
   * available. In `'classic'` render style the second element is passed to
   * `calculateSurprisal`; a missing tuple is treated as surprisal 0.
   */
  getTokenRealTopk: (rd: FrontendAnalyzeResult, tokenIndex: number) => [number, number] | undefined;

  /** Invoked once for every token `<g>` group after its rects have been added. */
  addTokenEventListeners: (element: SVGGElement, tokenIndex: number, rd: FrontendAnalyzeResult) => void;

  /** Colouring strategy for base rects; treated as `'classic'` when omitted. */
  tokenRenderStyle?: 'density' | 'classic';

  /** When true (and diff colouring is not active) base rects are filled `'transparent'`. */
  disableInfoDensityRender?: boolean;

  /** Optional diff-mode colouring; takes precedence over surprisal colouring when active. */
  diff?: DiffOverlayOptions;

  /** Optional semantic-similarity overlay configuration. */
  semantic?: SemanticOverlayOptions;

  /**
   * Forwarded as the cap argument to the surprisal colour functions.
   * NOTE(review): exact meaning of the cap is defined in SurprisalColorConfig — confirm there.
   */
  surprisalColorMax?: number;
}
|
|
/**
 * Builds and maintains the SVG layer that paints one `<rect>` per token
 * fragment over `baseNode`.
 *
 * Rects are grouped into one `<g class="token-group">` per token. Two caches,
 * both keyed by `rectKey`, allow later geometry and colour updates without
 * rebuilding the DOM: `rectCache` for the base rects and
 * `semanticOverlayCache` for the optional semantic overlay rects.
 */
export class SvgOverlayManager {
  private static readonly SVG_NS = 'http://www.w3.org/2000/svg';
  /** Width used when the container reports neither `clientWidth` nor `offsetWidth`. */
  private static readonly FALLBACK_WIDTH_PX = 800;
  /** Minimum height used when the container has no `.text-layer` child. */
  private static readonly FALLBACK_HEIGHT_PX = 600;

  /** Base rects (and their token index) keyed by `rectKey`. */
  private rectCache: Map<string, RectCacheEntry> = new Map();
  /** Semantic overlay rects keyed by `rectKey`; filled only in semantic analysis mode. */
  private semanticOverlayCache: Map<string, SVGRectElement> = new Map();
  private baseNode: HTMLElement;
  private options: SvgOverlayManagerOptions;

  /**
   * @param baseNode Container element whose size drives the SVG size.
   * @param options  Colouring callbacks and feature switches.
   */
  constructor(baseNode: HTMLElement, options: SvgOverlayManagerOptions) {
    this.baseNode = baseNode;
    this.options = options;
  }

  /**
   * Creates a fresh overlay `<svg>` containing a group per token.
   *
   * Both caches are cleared first, so they describe only the returned SVG.
   * The `<svg>` itself ignores pointer events; the token groups opt back in.
   *
   * @param positions Fragment rectangles to draw.
   * @param rd        Analysis result used to colour the rects.
   * @returns The new, not yet attached, `<svg>` element.
   */
  createSvgOverlay(positions: TokenFragmentRect[], rd: FrontendAnalyzeResult): SVGSVGElement {
    const svg = document.createElementNS(SvgOverlayManager.SVG_NS, 'svg');
    svg.setAttribute('class', 'svg-overlay');

    this.applyContainerSize(svg, positions);
    svg.setAttribute('preserveAspectRatio', 'none');

    svg.style.position = 'absolute';
    svg.style.top = '0';
    svg.style.left = '0';
    svg.style.width = '100%';
    svg.style.height = '100%';
    svg.style.pointerEvents = 'none';
    svg.style.zIndex = '1';

    this.rectCache.clear();
    this.semanticOverlayCache.clear();

    this.groupPositionsByToken(positions).forEach((tokenPositions, tokenIndex) => {
      svg.appendChild(this.createTokenGroup(svg, tokenPositions, tokenIndex, rd));
    });

    return svg;
  }

  /**
   * Resizes the SVG to the container and moves every cached rect (base and
   * semantic overlay) to its new position. Positions whose `rectKey` is not
   * cached are ignored.
   *
   * The width falls back exactly as in {@link createSvgOverlay}; a zero-width
   * `viewBox` would otherwise disable rendering of the whole overlay.
   */
  updateSvgPositions(svg: SVGSVGElement, positions: TokenFragmentRect[]): void {
    this.applyContainerSize(svg, positions);

    for (const pos of positions) {
      const baseRect = this.rectCache.get(pos.rectKey)?.rect;
      if (baseRect) {
        this.applyRectBox(baseRect, pos);
      }
      const overlayRect = this.semanticOverlayCache.get(pos.rectKey);
      if (overlayRect) {
        this.applyRectBox(overlayRect, pos);
      }
    }
  }

  /** Returns the live base-rect cache (not a copy). */
  getRectCache(): Map<string, RectCacheEntry> {
    return this.rectCache;
  }

  /** Returns the live semantic-overlay cache (not a copy). */
  getSemanticOverlayCache(): Map<string, SVGRectElement> {
    return this.semanticOverlayCache;
  }

  /**
   * Adds token groups for `newPositions` to an existing overlay without
   * touching what is already there. A new group is created for every token
   * index present in `newPositions`.
   */
  appendTokenRects(newPositions: TokenFragmentRect[], svg: SVGSVGElement, rd: FrontendAnalyzeResult): void {
    this.groupPositionsByToken(newPositions).forEach((tokenPositions, tokenIndex) => {
      svg.appendChild(this.createTokenGroup(svg, tokenPositions, tokenIndex, rd));
    });
  }

  /**
   * Recomputes the fill of every cached base rect.
   *
   * @param rd        Analysis result used for colouring.
   * @param overrides Values that take precedence over the constructor options
   *                  for this recolouring pass.
   */
  updateBaseRectColors(
    rd: FrontendAnalyzeResult,
    overrides: { disableInfoDensityRender: boolean; tokenRenderStyle: 'density' | 'classic' }
  ): void {
    this.rectCache.forEach(({ rect, tokenIndex }, rectKey) => {
      const color = this.computeBaseRectColor(rectKey, tokenIndex, rd, overrides);
      rect.setAttribute('fill', color);
      rect.setAttribute('data-target-color', color);
    });
  }

  /**
   * Recolours semantic overlay rects for tokens `fromTokenIndex` and up.
   *
   * For each token the fragments are looked up as `${tokenIndex}-0`,
   * `${tokenIndex}-1`, ... until a key is missing from the overlay cache.
   * NOTE(review): this assumes `rectKey` has the form
   * `<tokenIndex>-<fragmentIndex>` with contiguous fragment indices — confirm
   * where `TokenFragmentRect.rectKey` is produced.
   *
   * @param rawScoresNormed Score per token index; `undefined` paints `'transparent'`.
   * @param fromTokenIndex  First token index to update (default 0).
   * @returns Number of overlay rects that were recoloured.
   */
  updateSemanticColors(rawScoresNormed: (number | undefined)[], fromTokenIndex = 0): number {
    let count = 0;
    for (let tokenIndex = fromTokenIndex; tokenIndex < rawScoresNormed.length; tokenIndex++) {
      const score = rawScoresNormed[tokenIndex];
      const color = score !== undefined ? getSemanticSimilarityColor(score) : 'transparent';
      for (let i = 0; ; i++) {
        const overlayRect = this.semanticOverlayCache.get(`${tokenIndex}-${i}`);
        if (!overlayRect) break;
        overlayRect.setAttribute('fill', color);
        overlayRect.setAttribute('data-target-color', color);
        count++;
      }
    }
    return count;
  }

  /** Empties both caches; the rect elements themselves are not removed from the DOM. */
  clearRectCache(): void {
    this.rectCache.clear();
    this.semanticOverlayCache.clear();
  }

  /**
   * Reports whether the base-rect cache is out of sync with `positions`:
   * either the entry count differs or some `rectKey` is not cached.
   */
  hasMissingRects(positions: TokenFragmentRect[]): boolean {
    if (positions.length !== this.rectCache.size) {
      return true;
    }
    return positions.some(pos => !this.rectCache.has(pos.rectKey));
  }

  /**
   * Writes `width`, `height` and `viewBox` on the SVG from the container size.
   * Shared by creation and update so both use the same fallbacks.
   */
  private applyContainerSize(svg: SVGSVGElement, positions: TokenFragmentRect[]): void {
    const width =
      this.baseNode.clientWidth || this.baseNode.offsetWidth || SvgOverlayManager.FALLBACK_WIDTH_PX;
    const height = this.baseNode.clientHeight || this.calculateContainerHeight(positions);

    svg.setAttribute('width', width.toString());
    svg.setAttribute('height', height.toString());
    svg.setAttribute('viewBox', `0 0 ${width} ${height}`);
  }

  /**
   * Estimates the overlay height when the container reports no `clientHeight`.
   *
   * With a `.text-layer` child: the largest of that layer's height, the lowest
   * fragment bottom and the container's `clientHeight`. Without one: the
   * largest of the lowest fragment bottom, the container's scroll, client and
   * bounding heights, and a fixed minimum.
   */
  private calculateContainerHeight(positions: TokenFragmentRect[]): number {
    // A loop instead of Math.max(...array): spreading a very large array as
    // call arguments throws a RangeError once the engine's limit is exceeded.
    let maxTokenBottom = positions.length > 0 ? -Infinity : 0;
    for (const p of positions) {
      maxTokenBottom = Math.max(maxTokenBottom, p.y + p.height);
    }

    const textLayer = this.baseNode.querySelector('.text-layer');
    if (textLayer) {
      return Math.max(
        textLayer.getBoundingClientRect().height || 0,
        maxTokenBottom,
        this.baseNode.clientHeight || 0
      );
    }

    return Math.max(
      maxTokenBottom,
      this.baseNode.scrollHeight || 0,
      this.baseNode.clientHeight || 0,
      this.baseNode.getBoundingClientRect().height || 0,
      SvgOverlayManager.FALLBACK_HEIGHT_PX
    );
  }

  /** Buckets fragments by `tokenIndex`, keeping the input order within each bucket. */
  private groupPositionsByToken(positions: TokenFragmentRect[]): Map<number, TokenFragmentRect[]> {
    const positionsByToken = new Map<number, TokenFragmentRect[]>();
    for (const pos of positions) {
      const bucket = positionsByToken.get(pos.tokenIndex);
      if (bucket) {
        bucket.push(pos);
      } else {
        positionsByToken.set(pos.tokenIndex, [pos]);
      }
    }
    return positionsByToken;
  }

  /**
   * Builds the `<g>` for one token: a base rect per fragment, followed by a
   * semantic overlay rect when `options.semantic.analysisMode` is set, then
   * hands the group to `options.addTokenEventListeners`.
   *
   * The `svg` parameter is currently unused; the caller appends the group.
   */
  private createTokenGroup(
    svg: SVGSVGElement,
    tokenPositions: TokenFragmentRect[],
    tokenIndex: number,
    rd: FrontendAnalyzeResult
  ): SVGGElement {
    const group = document.createElementNS(SvgOverlayManager.SVG_NS, 'g');
    group.setAttribute('data-token-index', tokenIndex.toString());
    group.setAttribute('class', 'token-group');
    // The parent <svg> has pointer-events: none, so groups must opt back in.
    group.style.pointerEvents = 'auto';
    group.style.cursor = 'pointer';

    const withSemanticOverlay = Boolean(this.options.semantic?.analysisMode);
    for (const pos of tokenPositions) {
      group.appendChild(this.createRect(pos, tokenIndex, rd));
      if (withSemanticOverlay) {
        group.appendChild(this.createSemanticOverlayRect(pos, tokenIndex, rd));
      }
    }

    this.options.addTokenEventListeners(group, tokenIndex, rd);
    return group;
  }

  /**
   * Picks the fill colour of a base rect.
   *
   * Precedence: diff colouring (when enabled with non-empty deltas), then
   * `'transparent'` when info-density rendering is disabled, then surprisal
   * colouring according to the render style (`'classic'` by default).
   *
   * @param _rectKey  Unused; kept for call-site symmetry.
   * @param overrides Optional values that take precedence over the options.
   */
  private computeBaseRectColor(
    _rectKey: string,
    tokenIndex: number,
    rd: FrontendAnalyzeResult,
    overrides?: { disableInfoDensityRender: boolean; tokenRenderStyle: 'density' | 'classic' }
  ): string {
    const diff = this.options.diff;
    if (diff?.enabled && diff.deltaByteSurprisals.length > 0) {
      return this.computeDiffColor(diff, tokenIndex, rd);
    }

    const disableInfoDensityRender =
      overrides?.disableInfoDensityRender ?? this.options.disableInfoDensityRender;
    if (disableInfoDensityRender) return 'transparent';

    const tokenRenderStyle = overrides?.tokenRenderStyle ?? this.options.tokenRenderStyle ?? 'classic';
    const cap = this.options.surprisalColorMax;

    if (tokenRenderStyle === 'classic') {
      const topk = this.options.getTokenRealTopk(rd, tokenIndex);
      const surprisal = topk != null ? calculateSurprisal(topk[1]) : 0;
      return getTokenSurprisalColor(surprisal, undefined, cap);
    }

    const tokenData = rd.bpe_strings[tokenIndex];
    return getByteSurprisalColor(calculateSurprisalDensity(tokenData), 1, undefined, cap);
  }

  /**
   * Diff colour for one token: the mean of the byte deltas covered by the
   * token's character range (mapped to bytes through `charToByteIndexMap`),
   * or the colour for 0 when the map is empty or no byte falls in range.
   */
  private computeDiffColor(diff: DiffOverlayOptions, tokenIndex: number, rd: FrontendAnalyzeResult): string {
    const offset = rd.bpe_strings[tokenIndex].offset;
    const charStart = offset[0];
    const charEnd = offset[1];
    const { charToByteIndexMap, deltaByteSurprisals } = diff;

    if (!charToByteIndexMap.length) return getDiffColor(0);

    // Offsets without a map entry fall back to the character offset itself.
    const byteStart = charToByteIndexMap[charStart] ?? charStart;
    const byteEnd = charToByteIndexMap[charEnd] ?? charEnd;
    const lastExclusive = Math.min(byteEnd, deltaByteSurprisals.length);

    let sum = 0;
    let count = 0;
    for (let byteIdx = byteStart; byteIdx < lastExclusive; byteIdx++) {
      sum += deltaByteSurprisals[byteIdx];
      count++;
    }
    return getDiffColor(count > 0 ? sum / count : 0);
  }

  /**
   * Creates, colours and caches the base rect for one fragment.
   * The cache entry and data attributes use `pos.tokenIndex`; the
   * `tokenIndex` parameter is currently unused.
   */
  private createRect(
    pos: TokenFragmentRect,
    tokenIndex: number,
    rd: FrontendAnalyzeResult
  ): SVGRectElement {
    const rect = document.createElementNS(SvgOverlayManager.SVG_NS, 'rect');
    this.setRectGeometry(rect, pos);
    this.tagRect(rect, pos);

    this.rectCache.set(pos.rectKey, { rect, tokenIndex: pos.tokenIndex });

    const color = this.computeBaseRectColor(pos.rectKey, pos.tokenIndex, rd);
    rect.setAttribute('fill', color);
    rect.setAttribute('data-target-color', color);
    rect.style.pointerEvents = 'auto';

    return rect;
  }

  /**
   * Creates and caches the semantic overlay rect for one fragment.
   *
   * `rawScoresNormed` is optional in the options, so it is read with optional
   * chaining: analysis mode without scores yields a `'transparent'` rect
   * instead of throwing. The `rd` parameter is currently unused.
   */
  private createSemanticOverlayRect(
    pos: TokenFragmentRect,
    tokenIndex: number,
    rd: FrontendAnalyzeResult
  ): SVGRectElement {
    const rect = document.createElementNS(SvgOverlayManager.SVG_NS, 'rect');
    const score = this.options.semantic?.rawScoresNormed?.[tokenIndex];
    const color = score !== undefined ? getSemanticSimilarityColor(score) : 'transparent';

    this.setRectGeometry(rect, pos);
    this.tagRect(rect, pos);
    rect.setAttribute('fill', color);
    rect.setAttribute('data-target-color', color);
    rect.style.pointerEvents = 'auto';

    this.semanticOverlayCache.set(pos.rectKey, rect);
    return rect;
  }

  /** Writes the identifying `data-*` attributes shared by base and overlay rects. */
  private tagRect(rect: SVGRectElement, pos: TokenFragmentRect): void {
    rect.setAttribute('data-token-index', pos.tokenIndex.toString());
    rect.setAttribute('data-fragment-index', pos.fragmentIndex.toString());
    rect.setAttribute('data-rect-key', pos.rectKey);
  }

  /** Sets position, size and the fixed corner radius on a newly created rect. */
  private setRectGeometry(rect: SVGRectElement, pos: TokenFragmentRect): void {
    this.applyRectBox(rect, pos);
    rect.setAttribute('rx', '6');
    rect.setAttribute('ry', '6');
  }

  /**
   * Writes `x`, `y`, `width` and `height`, clamping the origin to be
   * non-negative and the size to at least 1.
   */
  private applyRectBox(rect: SVGRectElement, pos: TokenFragmentRect): void {
    rect.setAttribute('x', Math.max(0, pos.x).toString());
    rect.setAttribute('y', Math.max(0, pos.y).toString());
    rect.setAttribute('width', Math.max(1, this.displayWidth(pos)).toString());
    rect.setAttribute('height', Math.max(1, pos.height).toString());
  }

  /** Width to draw: the measured width, or the placeholder constant for non-positive widths. */
  private displayWidth(pos: TokenFragmentRect): number {
    return pos.width > 0 ? pos.width : ZERO_WIDTH_FRAGMENT_PLACEHOLDER_PX;
  }
}
|
|
|
|