Commit 494c9e4
initial beta release

Note: this view is limited to 50 files because the commit contains too many changes.
- .cursorindexingignore +3 -0
- .dockerignore +51 -0
- .gitattributes +36 -0
- .gitignore +42 -0
- Dockerfile +66 -0
- LICENSE +201 -0
- README.md +58 -0
- backend/__init__.py +18 -0
- backend/access_log.py +233 -0
- backend/api/__init__.py +2 -0
- backend/api/analyze.py +412 -0
- backend/api/analyze_semantic.py +212 -0
- backend/api/demo.py +183 -0
- backend/api/fetch_url.py +221 -0
- backend/api/folder.py +102 -0
- backend/api/model_switch.py +229 -0
- backend/api/openai_completions.py +379 -0
- backend/api/prediction_attribute.py +79 -0
- backend/api/sse_utils.py +181 -0
- backend/api/static.py +60 -0
- backend/api/utils.py +118 -0
- backend/app_context.py +110 -0
- backend/class_register.py +16 -0
- backend/completion_generator.py +558 -0
- backend/data_utils.py +97 -0
- backend/demo_folder.py +339 -0
- backend/device.py +97 -0
- backend/language_checker.py +422 -0
- backend/load_utils.py +69 -0
- backend/logging_config.py +37 -0
- backend/model_loader.py +169 -0
- backend/model_manager.py +233 -0
- backend/next_token_topk.py +26 -0
- backend/oom.py +55 -0
- backend/path_utils.py +92 -0
- backend/pred_topk_format.py +44 -0
- backend/prediction_attributor.py +185 -0
- backend/project_registry.py +72 -0
- backend/quantization_config.py +42 -0
- backend/runtime_config.py +402 -0
- backend/schemas.py +43 -0
- backend/semantic_analyzer.py +280 -0
- client/src/analysis.html +188 -0
- client/src/attribution.html +166 -0
- client/src/chat.html +171 -0
- client/src/compare.html +69 -0
- client/src/content/home.en.html +91 -0
- client/src/content/home.zh.html +68 -0
- client/src/content/images/attribute-dark.png +3 -0
- client/src/content/images/attribute.png +3 -0
.cursorindexingignore ADDED
@@ -0,0 +1,3 @@
+
+# Don't index SpecStory auto-save files, but allow explicit context inclusion via @ references
+.specstory/**
.dockerignore ADDED
@@ -0,0 +1,51 @@
+# --- Core language & dependencies ---
+__pycache__/
+*.py[cod]
+.venv/
+venv/
+env/
+node_modules/
+client/src/node_modules/
+client/src/.cache-loader/
+
+# --- Build artifacts & caches ---
+client/dist/
+build/
+dist/
+*.egg-info/
+.cache_huggingface/
+*.tsbuildinfo
+
+# --- Project-specific rules ---
+# Ignore all data to avoid shipping large files by mistake
+data/*
+# Whitelist: keep only the public folder
+!data/demo/
+data/demo/*
+!data/demo/public/
+
+# Ignore temporary files and logs
+notes.md
+.env
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+
+# --- OS & IDE ---
+.DS_Store
+.DS_Store?
+._*
+.Spotlight-V100
+.Trashes
+ehthumbs.db
+Thumbs.db
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+
+# --- Git ---
+.git
+.gitignore
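The `data/*` / `!data/demo/` block above is an ignore-then-whitelist pattern: exclude everything under `data/`, re-open `data/demo/`, exclude its contents again, then let only `data/demo/public/` through. A quick way to sanity-check rules like these is a small script; the sketch below uses the third-party `pathspec` package, which implements gitignore-style ("gitwildmatch") matching. The package and the sample paths are assumptions of this example, not dependencies shown in this commit:

```python
# Sanity-check the ignore/whitelist pattern with pathspec (pip install pathspec).
# gitwildmatch approximates .dockerignore semantics closely enough for a preview.
import pathspec

rules = [
    "data/*",              # ignore everything under data/ ...
    "!data/demo/",         # ... but re-open data/demo/ ...
    "data/demo/*",         # ... ignore its contents again ...
    "!data/demo/public/",  # ... except the public folder
]
spec = pathspec.PathSpec.from_lines("gitwildmatch", rules)

# Hypothetical sample paths; the last matching pattern wins.
for path in ["data/model.bin", "data/demo/private.json", "data/demo/public/intro.json"]:
    print(path, "-> ignored" if spec.match_file(path) else "-> kept")
```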
.gitattributes ADDED
@@ -0,0 +1,36 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
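Every pattern above routes matching files through Git LFS; entries of this shape are what `git lfs track "<pattern>"` appends to `.gitattributes`. As a rough preview of which files such patterns capture, a sketch using `fnmatch` (this only approximates gitattributes glob matching, and the sample file names are made up):

```python
# Rough preview of which files the LFS patterns would capture.
# Patterns without "/" in .gitattributes match against the basename.
from fnmatch import fnmatch

lfs_patterns = ["*.bin", "*.safetensors", "*.png", "*.zip"]  # subset of the list above
files = ["model.safetensors", "logo.png", "README.md", "weights/pytorch_model.bin"]

for f in files:
    name = f.rsplit("/", 1)[-1]
    hits = [p for p in lfs_patterns if fnmatch(name, p)]
    print(f, "-> LFS" if hits else "-> regular git", hits)
```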
.gitignore ADDED
@@ -0,0 +1,42 @@
+.cursor/
+.cache_huggingface
+# --- Core language & dependencies ---
+__pycache__/
+*.py[cod]
+.venv/
+venv/
+env/
+node_modules/
+client/src/node_modules/
+client/src/.cache-loader/
+
+# --- Build artifacts ---
+client/dist/
+build/
+dist/
+*.egg-info/
+
+# --- Project-specific rules ---
+# Ignore all data to avoid committing large files by mistake
+data/*
+# Whitelist: keep only the GLTR demo data
+!data/demo/
+data/demo/*
+!data/demo/public/
+data/demo/public/.deleted/
+
+# Ignore scratch notes and the HuggingFace cache
+notes/*
+user_dialog_history/*
+.cache_huggingface/
+.env
+
+# --- OS & IDE ---
+.DS_Store
+.idea/
+.vscode/
+*.swp
+*.log
+.specstory
+scripts/log.py
+scripts/results/
Dockerfile ADDED
@@ -0,0 +1,66 @@
+# syntax=docker/dockerfile:1
+
+# -----------------------------------------------------------------------------
+# Frontend build stage (stable Node toolchain for webpack/TS)
+# -----------------------------------------------------------------------------
+FROM node:20-bookworm-slim AS frontend
+WORKDIR /app/client/src
+
+COPY client/src/package.json client/src/package-lock.json ./
+RUN npm ci
+
+COPY client/src/ ./
+# JSON the prebuild step needs to read; otherwise updateIntroHTML.js fails with ENOENT
+COPY data/demo/public/ /app/data/demo/public/
+RUN npm run build
+
+# -----------------------------------------------------------------------------
+# Runtime stage (Hugging Face Spaces runs container as UID 1000)
+# Reference: https://huggingface.co/docs/hub/spaces-sdks-docker
+# -----------------------------------------------------------------------------
+FROM python:3.10-slim
+
+# System deps (git for Hugging Face Hub downloads, build-essential for triton/AWQ CUDA kernel compilation)
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    git \
+    build-essential \
+    && rm -rf /var/lib/apt/lists/*
+
+# Create a non-root user with UID 1000 (mandatory in Spaces)
+RUN useradd -m -u 1000 user
+USER user
+
+# Only set the environment variables needed at build time (pip install needs these paths)
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+
+WORKDIR $HOME/app
+
+# Python deps (installed to user site-packages when system site is not writable)
+COPY --chown=user:users requirements.txt ./
+RUN pip install --no-cache-dir --upgrade pip && \
+    pip install --no-cache-dir -r requirements.txt
+
+# Runtime env vars moved after the dependency install (they do not affect it)
+ENV PYTHONUNBUFFERED=1
+
+# Enable hf-transfer to speed up downloads
+ENV HF_HUB_ENABLE_HF_TRANSFER=1
+
+# App source (copy only the paths needed at runtime)
+COPY --chown=user:users *.py *.yaml ./
+COPY --chown=user:users backend/ ./backend/
+COPY --chown=user:users data/demo/public/ ./data/demo/public/
+
+# Frontend build artifacts
+COPY --chown=user:users --from=frontend /app/client/dist ./client/dist
+
+# ENV FORCE_INT8=1
+
+EXPOSE 7860
+# Model-to-hardware fit:
+# On CPU basic, the 0.6b model reaches passable speed
+# On CPU upgrade, the 1.7b model reaches passable speed
+# On a local M5 16G chip, the 4b model reaches passable speed (the bottleneck is memory size); 16G of memory can hold only one analysis model at a time (information-density analysis or semantic analysis)
+CMD ["python", "run.py", "--no_auto_load", "--port", "7860", "--model", "qwen3-1.7b", "--semantic_model", "qwen3-1.7b-instruct"]
+# CMD ["python", "run.py", "--no_auto_load", "--port", "7860", "--model", "qwen3-0.6b", "--semantic_model", "qwen3-0.6b-instruct"]
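`run.py` is not among the files shown in this view, so its exact flag handling is unknown; the two CMD lines imply at least `--no_auto_load`, `--port`, `--model`, and `--semantic_model` (and `backend/api/analyze.py` below reads `no_auto_load` off the parsed args). A hypothetical argparse sketch consistent with that command line; names and defaults are inferred, not confirmed by source:

```python
# Hypothetical sketch of the CLI surface implied by the Dockerfile CMD;
# run.py is not included in this view, so treat names/defaults as assumptions.
import argparse

parser = argparse.ArgumentParser(description="Info Lens server")
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--model", default="qwen3-1.7b",
                    help="main (information-density) analysis model slot")
parser.add_argument("--semantic_model", default="qwen3-1.7b-instruct",
                    help="semantic analysis model slot")
parser.add_argument("--no_auto_load", action="store_true",
                    help="defer weight loading until the first request")
args = parser.parse_args()
print(args)
```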
LICENSE ADDED
@@ -0,0 +1,201 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!) The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
README.md ADDED
@@ -0,0 +1,58 @@
+---
+title: Info Lens
+emoji: 🔭
+colorFrom: blue
+colorTo: red
+sdk: docker
+short_description: Explore the informational nature of LLMs and language.
+tags:
+- nlp
+- text-analysis
+- information
+- visualization
+- reading-tools
+app_port: 7860
+pinned: false
+license: apache-2.0
+---
+
+# Info Lens
+
+**Info Lens** is a small toolbox for exploring the informational nature of LLMs and language.
+
+## Legacy name: InfoRadar
+
+InfoRadar is the former project and repo name. It still appears in parts of the codebase.
+
+## 📦 Quick Start
+
+### Using Docker (Recommended)
+
+This is the simplest way to run Info Lens:
+
+```bash
+# 1. Build the image
+docker build -t inforadar .
+
+# 2. Run the container (map the port to 7860)
+docker run -p 7860:7860 inforadar
+```
+Once running, visit `http://localhost:7860` in your browser.
+
+### Local Development
+
+**Backend Environment**:
+```bash
+pip install -r requirements.txt
+python server.py
+```
+
+**Frontend Build**:
+```bash
+cd client/src && npm install && npm run build
+```
+
+## 📜 License
+
+Apache 2.0
+
backend/__init__.py ADDED
@@ -0,0 +1,18 @@
+from .class_register import REGISTERED_MODELS
+
+'''
+Import all classes in this directory so that classes with
+@register_model are registered.
+'''
+
+from os.path import basename, dirname, join
+from glob import glob
+pwd = dirname(__file__)
+for x in glob(join(pwd, '*.py')):
+    if not basename(x).startswith('__'):
+        __import__('backend.' + basename(x)[:-3],
+                   globals(), locals())
+
+__all__ = [
+    'REGISTERED_MODELS'
+]
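`class_register.py` (16 lines in this commit) is not displayed in this view; the glob-import loop above only does useful work if importing each module runs a `@register_model` decorator that records the class in `REGISTERED_MODELS`. A minimal sketch of that registry pattern, assuming the usual shape (the decorator signature and the example registration are assumptions; `QwenLM` is mentioned in `analyze.py`, but its definition is not shown here):

```python
# Minimal sketch of the decorator/registry pattern the auto-import relies on.
# The real class_register.py is not shown in this view.
REGISTERED_MODELS = {}

def register_model(name):
    """Record a model class under `name` at import time."""
    def decorator(cls):
        REGISTERED_MODELS[name] = cls
        return cls
    return decorator

# A module under backend/ would then self-register simply by being imported:
@register_model("qwen3-1.7b")
class QwenLM:
    ...
```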
backend/access_log.py ADDED
@@ -0,0 +1,233 @@
+"""Service access log"""
+from datetime import datetime
+from typing import Optional
+
+from flask import request
+import threading
+
+# Global request counter and its lock
+_request_counter = 0
+_request_counter_lock = threading.Lock()
+
+
+def _get_client_ip():
+    """Get the request's source IP"""
+    try:
+        if request.headers.get('X-Forwarded-For'):
+            return request.headers.get('X-Forwarded-For').split(',')[0].strip()
+        elif request.headers.get('X-Real-IP'):
+            return request.headers.get('X-Real-IP')
+        else:
+            return request.remote_addr
+    except RuntimeError as e:
+        if "Working outside of request context" in str(e):
+            # No request context available: return a placeholder
+            return "unknown"
+        else:
+            raise
+
+
+def get_client_ip():
+    """Get the client IP (for use by other modules)"""
+    return _get_client_ip()
+
+
+def _log_request(event_type: str, details: str = "", client_ip: str = None):
+    """Print a service request log line"""
+    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+    ip = client_ip if client_ip is not None else _get_client_ip()
+
+    log_msg = f"[{timestamp}] {ip:15s} | {event_type}"
+    if details:
+        log_msg += f" | {details}"
+
+    print(log_msg)
+
+
+def log_page_load(path: str):
+    """Log a page visit (including the ?ref= parameter)"""
+    details = f"path='{path}'"
+    try:
+        ref = request.args.get("ref")
+        if ref:
+            details += f", ref='{ref}'"
+    except RuntimeError:
+        pass
+    _log_request("📄 page visit", details)
+
+
+def log_demo_file(path: str):
+    """Log a demo file request"""
+    _log_request("🎯 demo file", f"file='{path}'")
+
+
+def log_analyze_request(text: str, stream_mode: bool = False, client_ip: str = None):
+    """
+    Log receipt of an analyze request
+
+    Returns:
+        int: request ID
+    """
+    global _request_counter
+
+    # Generate a request ID
+    with _request_counter_lock:
+        _request_counter += 1
+        request_id = _request_counter
+
+    preview_length = 100
+    text_preview = text[:preview_length] + '......' if text and len(text) > preview_length else (text if text else '')
+    char_count = len(text) if text else 0
+    byte_count = len(text.encode('utf-8')) if text else 0
+    mode_str = "(stream)" if stream_mode else ""
+
+    details = f"req_id={request_id}, text='{text_preview}', chars={char_count}, bytes={byte_count}"
+    _log_request(f"📥 request received{mode_str}", details, client_ip)
+
+    return request_id
+
+
+def log_analyze_start(request_id: int, wait_time: float, stream_mode: bool = False):
+    """Log the start of analyze-request processing (internal event)"""
+    from backend.app_context import get_verbose
+    if not get_verbose():
+        return
+    mode_str = "(stream)" if stream_mode else ""
+    print(f"\t🔄 API analyze {mode_str} start: req_id={request_id}, wait_time={wait_time:.2f}s")
+
+
+def log_fetch_url(url: str, char_count: int = None):
+    """Log a URL fetch request"""
+    details = f"url='{url}'"
+    if char_count is not None:
+        details += f", chars={char_count}"
+    _log_request("🌐 URL fetch", details)
+
+
+def log_check_admin(success: bool, token: str = None):
+    """Log an admin permission check"""
+    status = "success" if success else "failure"
+    details = f"result={status}"
+    if not success and token:
+        details += f", token='{token}'"
+    _log_request("🔐 admin permission check", details)
+
+
+def log_analyze_semantic_start(request_id: int, wait_time: float, stream_mode: bool = False):
+    """Log the start of semantic-analysis processing (internal event)"""
+    from backend.app_context import get_verbose
+    if not get_verbose():
+        return
+    mode_str = "(stream)" if stream_mode else ""
+    print(f"\t🔄 API analyze_semantic {mode_str} start: req_id={request_id}, wait_time={wait_time:.2f}s")
+
+
+def log_analyze_semantic_request(query: str, text: str, client_ip: str = None):
+    """
+    Log receipt of a semantic analysis request
+
+    Returns:
+        int: request ID
+    """
+    global _request_counter
+
+    with _request_counter_lock:
+        _request_counter += 1
+        request_id = _request_counter
+
+    preview = 50
+    q_preview = query[:preview] + "..." if len(query) > preview else query
+    t_preview = text[:preview] + "..." if len(text) > preview else text
+    details = f"req_id={request_id}, query='{q_preview}', text='{t_preview}', chars={len(text)}"
+    _log_request("📥 semantic analysis request", details, client_ip)
+    return request_id
+
+
+def log_openai_completions_start(request_id: int, wait_time: float):
+    """Log the start of OpenAI-completions processing (internal event)"""
+    from backend.app_context import get_verbose
+    if not get_verbose():
+        return
+    print(f"\t🔄 API openai_completions start: req_id={request_id}, wait_time={wait_time:.2f}s")
+
+
+def log_openai_completions_request(
+    model: str, prompt: str, client_ip: str = None,
+):
+    """
+    Log receipt of an OpenAI completions request
+
+    Returns:
+        int: request ID
+    """
+    global _request_counter
+
+    with _request_counter_lock:
+        _request_counter += 1
+        request_id = _request_counter
+
+    preview = 50
+    p_preview = prompt[:preview] + "..." if len(prompt) > preview else prompt
+    details = (
+        f"req_id={request_id}, model='{model}', "
+        f"prompt='{p_preview}', chars={len(prompt)}"
+    )
+    _log_request("📥 openai completions request", details, client_ip)
+    return request_id
+
+
+def log_prediction_attribute_request(
+    context: str,
+    target_prediction: Optional[str],
+    model: str,
+    client_ip: str = None,
+) -> int:
+    """
+    Log receipt of a prediction_attribute request.
+
+    Returns:
+        int: request ID (drawn from the same counter as the other APIs' req_id)
+    """
+    global _request_counter
+
+    with _request_counter_lock:
+        _request_counter += 1
+        request_id = _request_counter
+
+    preview = 50
+    c_preview = context[:preview] + "..." if len(context) > preview else context
+    if target_prediction is None:
+        t_preview = "<top-1>"
+    else:
+        t_preview = (
+            target_prediction[:preview] + "..."
+            if len(target_prediction) > preview
+            else target_prediction
+        )
+    details = (
+        f"req_id={request_id}, model={model!r}, context='{c_preview}', target='{t_preview}', "
+        f"context_chars={len(context)}"
+    )
+    _log_request("📥 prediction_attribute request", details, client_ip)
+    return request_id
+
+
+def log_openai_completions_prompt_request(
+    model: str,
+    user_prompt: str,
+    system: Optional[str] = None,
+    client_ip: str = None,
+) -> None:
+    """Log POST /v1/completions/prompt (only assembles the chat template; no req_id allocated)."""
+    preview = 50
+
+    def _pv(s: str) -> str:
+        return s[:preview] + "..." if len(s) > preview else s
+
+    up = _pv(user_prompt)
+    if system is None:
+        details = f"model='{model}', user_prompt='{up}'"
+    else:
+        details = f"model='{model}', system='{_pv(system)}', user_prompt='{up}'"
+    _log_request("📥 openai completions/prompt request", details, client_ip)
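Two details above are worth calling out: `X-Forwarded-For` may carry a comma-separated chain (client, proxy1, proxy2, ...), and `_get_client_ip` takes the first entry, which is the original client when the proxy chain is trusted; and the request counter is guarded by a lock because several server worker threads may log concurrently. A standalone illustration of both (the header value and thread counts are made up for the demo):

```python
# Standalone illustration of the X-Forwarded-For split used by _get_client_ip
# and of the locked counter pattern; the header value below is made up.
import threading

xff = "203.0.113.7, 10.0.0.2, 10.0.0.3"   # client first, then intermediate proxies
print(xff.split(',')[0].strip())           # -> 203.0.113.7

counter = 0
lock = threading.Lock()

def bump(n):
    global counter
    for _ in range(n):
        with lock:                         # without the lock, += can lose updates
            counter += 1

threads = [threading.Thread(target=bump, args=(10_000,)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(counter)                             # always 40000 with the lock held
```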
backend/api/__init__.py ADDED
@@ -0,0 +1,2 @@
+"""API route modules"""
+
backend/api/analyze.py ADDED
@@ -0,0 +1,412 @@
+"""Text analysis API"""
+import gc
+import json
+import time
+import queue
+import threading
+from typing import Optional
+from backend.schemas import create_empty_analysis_result
+from backend.model_manager import project_registry, DEFAULT_MODEL, _inference_lock
+from model_paths import resolve_hf_path
+from backend.oom import exit_if_oom
+from backend.api.sse_utils import (
+    SSEProgressReporter,
+    consume_progress_queue,
+    send_result_event,
+    send_error_event,
+)
+
+
+# Custom exception: queue-wait timeout
+class QueueTimeoutError(Exception):
+    """Timed out while queuing to acquire the lock"""
+    pass
+
+
+# Uses the shared inference lock from model_manager (shared with analyze_semantic)
+# Overall processing-time limit for a single analysis (seconds)
+ANALYSIS_TIMEOUT = 60.0
+# Maximum time to wait for the lock (seconds) - reject the request outright if queuing takes too long
+LOCK_WAIT_TIMEOUT = 10.0
+
+
+def _analyze_result_model_display(model: Optional[str]) -> Optional[str]:
+    """result.model for the main analysis: exposed uniformly as the HuggingFace repo id (consistent with model_paths.resolve_hf_path)."""
+    if not model or not str(model).strip():
+        return None
+    return resolve_hf_path(str(model).strip())
+
+
+def _build_response(model: str, text: str, result):
+    """Build the standard response"""
+    # Add model to the result, making sure model comes first
+    if not isinstance(result, dict):
+        result = {}
+    result = result.copy()
+    # If result already carries a model, remove it first
+    if 'model' in result:
+        model_value = result.pop('model')
+    else:
+        model_value = model
+    # Rebuild result so that model is at the front
+    result = {'model': _analyze_result_model_display(model_value), **result}
+    return {
+        "request": {'text': text},
+        "result": result
+    }
+
+
+def _error_response(model: str, text: str, message: str, status_code: int):
+    """Build an error response (unified format)"""
+    # Unified error format: includes success=false and message
+    result = create_empty_analysis_result(message, _analyze_result_model_display(model))
+    return {
+        "success": False,
+        "message": message,
+        "request": {'text': text or ''},
+        "result": result
+    }, status_code
+
+
+def _validate_and_prepare_request(analyze_request):
+    """
+    Validate the request and prepare its parameters
+
+    Returns:
+        (model, text, error_msg, status_code) tuple
+        On validation failure, returns (None, None, error_msg, status_code)
+        On success, returns (model, text, None, None)
+    """
+    model = analyze_request.get('model')
+    text = analyze_request.get('text')
+
+    if not text:
+        return None, None, "Missing text to analyze; please provide the 'text' field", 400
+
+    # Get the default model (use the module-level context to pick up the persisted active model)
+    from backend.app_context import get_app_context
+    context = get_app_context(prefer_module_context=True)
+    default_model = context.model_name if context.model_name else DEFAULT_MODEL
+
+    # Treat 'default', None, or an empty string as the default model
+    if not model or model == 'default' or model == '':
+        model = default_model
+    else:
+        # Only the default model is allowed; requests for other models are rejected
+        if model != default_model:
+            return None, None, f"Only the default model '{default_model}' is currently supported; other models are not allowed", 400
+
+    return model, text, None, None
+
+
+def _load_project_with_error_handling(model):
+    """
+    Fetch the already-loaded model; if it is not loaded, lazy-load per configuration or return an error.
+
+    Returns:
+        (project_obj, error_msg, status_code) tuple
+        On success, returns (project_obj, None, None)
+        On failure, returns (None, error_msg, status_code)
+    """
+    # Check whether the model is in the registry
+    if not project_registry.is_available(model):
+        available_models = list(project_registry.available_model_names())
+        error_msg = f"❌ Model '{model}' is not registered. Available models: {available_models}"
+        print(error_msg)
+        return None, error_msg, 404
+
+    # Check whether the model is loaded
+    p = project_registry.get(model)
+    if p is None:
+        from backend.app_context import get_app_context
+        from backend.model_manager import ensure_main_slot_ready
+
+        context = get_app_context(prefer_module_context=True)
+        if context.model_loading:
+            error_msg = f"Model '{model}' is loading in the background; please retry shortly"
+            print(f"⚠️ {error_msg}")
+            return None, error_msg, 503
+        # Lazy-load mode (--no_auto_load): the first request only initializes the main slot (weights + QwenLM project)
+        if getattr(context.args, 'no_auto_load', False):
+            try:
+                ensure_main_slot_ready()
+                p = project_registry.get(model)
+            except Exception as e:  # noqa: BLE001
+                import traceback
+                print(f"⚠️ Lazy model load failed: {e}")
+                traceback.print_exc()
+                return None, f"Model load failed: {str(e)}", 500
+        if p is None:
+            error_msg = f"Model '{model}' is not loaded; please contact the administrator"
+            print(f"⚠️ {error_msg}")
+            return None, error_msg, 503
+    return p, None, None
+
+
+def _log_request(text, stream_mode=False, client_ip=None):
+    """
+    Print the request log
+
+    Returns:
+        int: request ID
+    """
+    from backend.access_log import log_analyze_request
+    return log_analyze_request(text, stream_mode, client_ip)
+
+
+def _log_response(res, char_count, elapsed_time, stream_mode=False, request_id=None, wait_time=None):
+    """Print the response log"""
+    tokens = len(res.get('bpe_strings', []))
+    text_length = char_count
+    mode_str = "(stream)" if stream_mode else ""
+
+    # Build the log message
+    msg = f"\t📤 API analyze {mode_str} response:"
+    if request_id is not None:
+        msg += f" req_id={request_id},"
+    msg += f" tokens={tokens}, text_length={text_length}"
+    msg += f", response_time={elapsed_time:.4f}s"
+
+    print(msg)
+
+
+def _validate_and_fix_result(res):
+    """Validate and repair the result structure"""
+    if not isinstance(res, dict):
+        res = {'bpe_strings': []}
+    if 'bpe_strings' not in res or not isinstance(res.get('bpe_strings'), list):
+        res['bpe_strings'] = res.get('bpe_strings', []) if isinstance(res.get('bpe_strings'), list) else []
+    return res
+
+
+def analyze(analyze_request):
+    """
+    Analyze text
+
+    Args:
+        analyze_request: analysis request dict containing:
+            - model: model name
+            - text: text to analyze
+            - stream: optional; if True, return an SSE streaming response (with progress info)
+
+    Returns:
+        If stream=True: an SSE response object
+        Otherwise: a (response dict, status code) tuple
+    """
+    # Check whether a model is currently loading (via the module-level context)
+    from backend.app_context import get_app_context
+    context = get_app_context(prefer_module_context=True)
+    if context.model_loading:
+        return _error_response('', '', 'Model is loading; please retry shortly', 503)
+
+    # Grab client_ip inside the request context; it may already be gone inside a streaming generator
+    from backend.access_log import get_client_ip
+    client_ip = get_client_ip()
+
+    # Check whether streaming is requested
+    stream = analyze_request.get('stream', False)
+    if stream:
+        return _analyze_with_stream(analyze_request, client_ip)
+    return _analyze_plain(analyze_request, client_ip)
+
+
+def _analyze_with_stream(analyze_request, client_ip):
+    """
+    Streaming analysis; returns progress and the result over SSE (internal function)
+
+    Args:
+        analyze_request: analysis request dict with model and text
+        client_ip: client IP, captured at the entry point and passed in
+
+    Returns:
+        SSE response object
+    """
+    reporter = SSEProgressReporter(lambda: _generate_analyze_events(analyze_request, client_ip))
+    return reporter.create_response()
+
+
+def _analyze_plain(analyze_request, client_ip):
+    """
+    Non-streaming analysis: wraps the streaming implementation, consumes the event stream, and returns JSON.
+    Intended for scripts and other simple clients.
+    """
+    result = None
+    error_msg = None
+    status_code = 500
+    try:
+        for event_str in _generate_analyze_events(analyze_request, client_ip):
+            if not event_str.startswith('data: '):
+                continue
+            data = json.loads(event_str[6:].strip())
+            t = data.get('type')
+            if t == 'result':
+                result = data.get('data')
+            elif t == 'error':
+                error_msg = data.get('message', 'Analysis failed')
+                status_code = data.get('status_code', 500)
+                break
+    except Exception as e:
+        import traceback
+        traceback.print_exc()
+        exit_if_oom(e, defer_seconds=1)
+        error_msg = f"Analysis failed: {str(e)}"
+    finally:
+        gc.collect()
+
+    if error_msg:
+        model = analyze_request.get('model') or ''
+        text = analyze_request.get('text') or ''
+        return _error_response(model, text, error_msg, status_code)
+    if result is None:
+        return _error_response('', '', 'Analysis failed: no result obtained', 500)
+    return result, 200
+
+
+def _generate_analyze_events(analyze_request, client_ip):
+    """
+    Core of the streaming analysis: generates the SSE event stream (progress + result/error).
+    Shared by _analyze_with_stream and _analyze_plain.
+    client_ip must be captured at the entry point and passed in, because the request context may be gone by the time the generator runs.
+    """
+    # Re-check the model-loading state (inside the generator, via the module-level context)
+    from backend.app_context import get_app_context
+    context = get_app_context(prefer_module_context=True)
+    if context.model_loading:
+        yield send_error_event('Model is loading; please retry shortly', 503)
+        return
+
+    start_time = time.perf_counter()
+
+    # Validate and prepare the request
+    model, text, error_msg, status_code = _validate_and_prepare_request(analyze_request)
+    if error_msg:
+        yield send_error_event(error_msg, status_code or 400)
+        return
+
+    # Load the model
+    p, error_msg, status_code = _load_project_with_error_handling(model)
+    if error_msg:
+        yield send_error_event(error_msg, status_code or 500)
+        return
+
+    try:
+        char_count = len(text) if text else 0
+        request_id = _log_request(text, stream_mode=True, client_ip=client_ip)
+
+        # Create a thread-safe progress queue
+        progress_queue = queue.Queue()
+        analysis_done = threading.Event()
+        analysis_result = None
+        analysis_error = None
+        lock_wait_time = None  # time spent waiting for the lock
+
+        def progress_callback_func(step: int, total_steps: int, stage: str, percentage: Optional[int]):
+            """Progress callback; enqueues an event"""
+            progress_queue.put(('progress', step, total_steps, stage, percentage))
+
+        def run_analysis():
+            """Run the analysis in a separate thread"""
+            nonlocal analysis_result, analysis_error, lock_wait_time
+            try:
+                # Record when we started waiting for the lock
+                lock_wait_start = time.perf_counter()
+
+                # Try to acquire the lock with a timeout to avoid long queuing
+                lock_acquired = _inference_lock.acquire(timeout=LOCK_WAIT_TIMEOUT)
+                if not lock_acquired:
+                    # Lock acquisition timed out: a long-running task is ahead of us
+                    analysis_error = QueueTimeoutError(
+                        f"Queued for more than {LOCK_WAIT_TIMEOUT} seconds; the service is busy, please retry later"
+                    )
+                    return
+
+                # Record the wait time
+                lock_wait_time = time.perf_counter() - lock_wait_start
+
+                try:
+                    from backend.access_log import log_analyze_start
+                    log_analyze_start(request_id, lock_wait_time, stream_mode=True)
+
+                    # Run the analysis while holding the lock
+                    # Note: this execution is also subject to ANALYSIS_TIMEOUT (monitored in the outer loop)
+                    res = p.lm.analyze_text(text, progress_callback=progress_callback_func)
+                    analysis_result = res
+                finally:
+                    # Make sure the lock is always released
+                    _inference_lock.release()
+            except Exception as e:
+                analysis_error = e
+            finally:
+                analysis_done.set()
+                progress_queue.put(('done', None, None))  # send the completion signal
+
+        # Start the analysis thread
+        analysis_thread = threading.Thread(target=run_analysis, daemon=True)
+        analysis_thread.start()
+
+        # Stream progress events in real time and watch for timeout
+        timeout_reached = False
+        for kind, event_str in consume_progress_queue(
+            progress_queue, analysis_done, start_time, ANALYSIS_TIMEOUT, "analysis"
+        ):
+            if kind == 'timeout':
+                timeout_reached = True
+                yield event_str
+                break
+            if kind == 'progress':
+                yield event_str
+            elif kind == 'done':
+                break
+
+        # On timeout, return immediately without waiting for the analysis to finish
+        if timeout_reached:
+            gc.collect()
+            return
+
+        # Check for errors
+        # Note: the 'done' signal has arrived, so the analysis thread has finished its work (or errored)
+        # The thread is a daemon and is cleaned up automatically; no explicit join is needed
+        if analysis_error:
+            # Queue timeout: return a friendly error message
+            if isinstance(analysis_error, QueueTimeoutError):
+                print(f"⏱️ Queue timeout: {analysis_error}")
+                yield send_error_event(str(analysis_error), 503)
+                gc.collect()
+                return
+            # Any other error: re-raise and let the outer try-except handle it
+            raise analysis_error
+
+        # Check for an empty result (in theory impossible: there is either a result or an error)
+        if analysis_result is None:
+            print("⚠️ Analysis result is empty but no error was recorded")
+            yield send_error_event("Analysis failed: no result obtained", 500)
+            gc.collect()
+            return
+
+        res = analysis_result
+
+        elapsed_time = time.perf_counter() - start_time
+        _log_response(res, char_count, elapsed_time, stream_mode=True,
+                      request_id=request_id, wait_time=lock_wait_time)
+
+        # Validate and repair the result
+        res = _validate_and_fix_result(res)
+
+        # Build the final response
+        final_response = _build_response(model, text, res)
+
+        # Send the final result
+        yield send_result_event(final_response)
+
+        # Force garbage collection to free memory
+        gc.collect()
+
+    except Exception as e:
+        import traceback
+        traceback.print_exc()
+        exit_if_oom(e, defer_seconds=1)
+        yield send_error_event(str(e), 500)
+        # Collect garbage even on error
+        gc.collect()
+
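The structure of `_generate_analyze_events` is a reusable pattern: run inference in a worker thread behind a bounded lock wait, push progress into a `queue.Queue`, and let the generator drain the queue into SSE events so the client sees progress while the model runs. A stripped-down, self-contained sketch of that pattern (no Flask, no model; the `type: progress/result/error` payload shapes mirror what `_analyze_plain` parses above, but `sse_utils` itself is not shown in this view):

```python
# Stripped-down version of the worker-thread + progress-queue pattern used by
# _generate_analyze_events: the main loop streams progress while a worker runs.
import json, queue, threading, time

def sse(payload):
    """Frame a payload as a server-sent event (a sketch of sse_utils framing)."""
    return f"data: {json.dumps(payload)}\n\n"

def run_job(events):
    q, done = queue.Queue(), threading.Event()
    result, error = {}, {}

    def worker():
        try:
            for step in range(1, 4):
                time.sleep(0.1)          # stand-in for real inference work
                q.put(("progress", step, 3))
            result["value"] = {"ok": True}
        except Exception as e:
            error["value"] = e
        finally:
            done.set()
            q.put(("done", None, None))  # completion signal, as in run_analysis

    threading.Thread(target=worker, daemon=True).start()
    while True:
        kind, step, total = q.get(timeout=60)   # ANALYSIS_TIMEOUT stand-in
        if kind == "done":
            break
        events.append(sse({"type": "progress", "step": step, "total": total}))
    if "value" in error:
        events.append(sse({"type": "error", "message": str(error["value"])}))
    else:
        events.append(sse({"type": "result", "data": result["value"]}))

out = []
run_job(out)
print("".join(out))
```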
backend/api/analyze_semantic.py  ADDED
@@ -0,0 +1,212 @@

"""Semantic analysis API: returns each source-text token's average attention to the prompt."""
import gc
import json
import queue
import threading
import time
from typing import Optional

from backend.model_manager import _inference_lock
from backend.oom import exit_if_oom
from backend.semantic_analyzer import analyze_semantic as _analyze_semantic
from backend.api.sse_utils import (
    SSEProgressReporter,
    consume_progress_queue,
    send_result_event,
    send_error_event,
)
from backend.access_log import get_client_ip
from backend.api.analyze import QueueTimeoutError, ANALYSIS_TIMEOUT, LOCK_WAIT_TIMEOUT


def _log_request(query, text, client_ip=None):
    from backend.access_log import log_analyze_semantic_request
    return log_analyze_semantic_request(query, text, client_ip)


def _build_success_response(result, debug_info: bool = False):
    """Build the success response. When debug_info=True, include the debug_info object (abbrev, topk_tokens, topk_probs)."""
    resp = {
        "success": True,
        "model": result["model"],
        "token_attention": result["token_attention"],
        "full_match_degree": result["full_match_degree"],
    }
    if debug_info and "debug_info" in result:
        resp["debug_info"] = result["debug_info"]
    return resp


def _generate_semantic_events(
    query: str, text: str, submode: Optional[str] = None, debug_info: bool = False,
    full_match_degree_only: bool = False, client_ip: Optional[str] = None
):
    """
    Core of streaming semantic analysis: generates the SSE event stream (progress + result/error).
    Shared by _analyze_semantic_with_stream and _analyze_semantic_plain.
    client_ip must be captured at the entry point and passed in, because the request
    context is no longer valid by the time the generator runs in a streaming response.
    """
    if client_ip is None:
        client_ip = get_client_ip()
    start_time = time.perf_counter()
    request_id = _log_request(query, text, client_ip)

    progress_queue = queue.Queue()
    analysis_done = threading.Event()
    analysis_result = None
    analysis_error = None
    lock_wait_time = None

    def progress_callback(step: int, total_steps: int, stage: str, percentage: Optional[int]):
        progress_queue.put(("progress", step, total_steps, stage, percentage))

    def run_analysis():
        nonlocal analysis_result, analysis_error, lock_wait_time
        try:
            lock_wait_start = time.perf_counter()
            lock_acquired = _inference_lock.acquire(timeout=LOCK_WAIT_TIMEOUT)
            if not lock_acquired:
                analysis_error = QueueTimeoutError(
                    f"排队等待超过 {LOCK_WAIT_TIMEOUT} 秒,服务繁忙,请稍后重试"
                )
                return
            lock_wait_time = time.perf_counter() - lock_wait_start

            try:
                from backend.access_log import log_analyze_semantic_start
                log_analyze_semantic_start(request_id, lock_wait_time, stream_mode=True)
                result = _analyze_semantic(query, text, submode_override=submode, progress_callback=progress_callback, debug_info=debug_info, full_match_degree_only=full_match_degree_only)
                analysis_result = result
            finally:
                _inference_lock.release()
        except Exception as e:
            analysis_error = e
        finally:
            analysis_done.set()
            progress_queue.put(("done", None, None))

    try:
        analysis_thread = threading.Thread(target=run_analysis, daemon=True)
        analysis_thread.start()

        timeout_reached = False
        for kind, event_str in consume_progress_queue(
            progress_queue, analysis_done, start_time, ANALYSIS_TIMEOUT, "语义分析"
        ):
            if kind == 'timeout':
                timeout_reached = True
                yield event_str
                break
            if kind == 'progress':
                yield event_str
            elif kind == 'done':
                break

        if timeout_reached:
            gc.collect()
            return

        if analysis_error:
            if isinstance(analysis_error, QueueTimeoutError):
                print(f"⏱️ 排队超时: {analysis_error}")
                yield send_error_event(str(analysis_error), 503)
                gc.collect()
                return
            raise analysis_error

        if analysis_result is None:
            print("⚠️ 语义分析结果为空,但没有错误信息")
            yield send_error_event("分析失败:未获取到结果", 500)
            gc.collect()
            return

        elapsed = time.perf_counter() - start_time
        tokens = len(analysis_result.get("token_attention", []))
        print(
            f"\t📤 API analyze_semantic (stream) response: req_id={request_id}, "
            f"tokens={tokens}, response_time={elapsed:.4f}s"
        )
        yield send_result_event(_build_success_response(analysis_result, debug_info))
    except Exception as e:
        import traceback
        traceback.print_exc()
        exit_if_oom(e, defer_seconds=1)
        yield send_error_event(str(e), 500)
    finally:
        gc.collect()


def _analyze_semantic_with_stream(
    query: str, text: str, submode: Optional[str] = None, debug_info: bool = False,
    full_match_degree_only: bool = False, client_ip: Optional[str] = None
):
    """Streaming semantic analysis; returns stage-level progress over SSE."""
    return SSEProgressReporter(
        lambda: _generate_semantic_events(query, text, submode, debug_info, full_match_degree_only, client_ip)
    ).create_response()


def _analyze_semantic_plain(
    query: str, text: str, submode: Optional[str] = None, debug_info: bool = False,
    full_match_degree_only: bool = False, client_ip: Optional[str] = None
):
    """
    Non-streaming semantic analysis: wraps the streaming implementation, consumes
    the event stream, then returns JSON. Intended for scripts and other simple clients.
    """
    result = None
    error_msg = None
    status_code = 500
    try:
        for event_str in _generate_semantic_events(query, text, submode, debug_info, full_match_degree_only, client_ip):
            if not event_str.startswith('data: '):
                continue
            data = json.loads(event_str[6:].strip())
            t = data.get('type')
            if t == 'result':
                result = data.get('data')
            elif t == 'error':
                error_msg = data.get('message', '分析失败')
                status_code = data.get('status_code', 500)
                break
    except Exception as e:
        import traceback
        traceback.print_exc()
        exit_if_oom(e, defer_seconds=1)
        error_msg = str(e)
    finally:
        gc.collect()

    if error_msg:
        return {"success": False, "message": error_msg}, status_code
    if result is None:
        return {"success": False, "message": "分析失败:未获取到结果"}, 500
    return result, 200


def analyze_semantic(semantic_request):
    """
    Analyze how much each source-text token attends to the prompt.

    Args:
        semantic_request: dict containing query, text, stream (optional), submode (optional)

    Returns:
        An SSE response when stream=True; otherwise a (response dict, status code) tuple
    """
    query = (semantic_request.get("query") or "")
    text = semantic_request.get("text") or ""
    stream = semantic_request.get("stream", False)
    submode = (semantic_request.get("submode") or "").strip() or None
    debug_info = bool(semantic_request.get("debug_info", False))
    full_match_degree_only = bool(semantic_request.get("full_match_degree_only", False))

    if not query:
        return {"success": False, "message": "缺少 query 字段"}, 400
    if not text:
        return {"success": False, "message": "缺少 text 字段"}, 400

    client_ip = get_client_ip()
    if stream:
        return _analyze_semantic_with_stream(query, text, submode, debug_info, full_match_degree_only, client_ip)
    return _analyze_semantic_plain(query, text, submode, debug_info, full_match_degree_only, client_ip)
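Reviewer sketch: because `_analyze_semantic_plain` returns a plain `(dict, status)` pair, the module can be exercised in-process. Field names follow the docstrings above; the query/text values are placeholders, and since `analyze_semantic` calls `get_client_ip()` it presumably needs a Flask request context, so run this inside a handler or test client:

payload = {
    "query": "What does the text say about attention?",          # placeholder
    "text": "Attention lets the model weight input tokens ...",  # placeholder
    "stream": False,
}
body, status = analyze_semantic(payload)
if status == 200:
    # token_attention is per source-text token; full_match_degree is one score
    print(body["full_match_degree"], len(body["token_attention"]))
else:
    print(status, body["message"])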
backend/api/demo.py  ADDED
@@ -0,0 +1,183 @@

"""Demo file management API"""
from backend.data_utils import save_demo_payload
from backend.demo_folder import (
    list_demo_items,
    move_demo_file,
    rename_demo_file,
    delete_demo_file,
    move_folder,
)
from backend.api.utils import (
    get_demo_directory,
    handle_api_error,
    handle_api_success,
    require_admin,
    validate_admin_token,
)
from backend.access_log import log_check_admin


def list_demos(path: str = ""):
    """
    Scan the folders and files under the demo directory and return a listing.
    Supports a path parameter and returns the contents of that path.
    The file name (with the .json suffix removed) serves as the demo name.
    Chinese file names and paths are supported.
    Reads from the data/demo directory (a more conventional data-directory layout).

    Args:
        path: optional; the path to list, defaulting to the root (empty string)
    """
    demo_dir = get_demo_directory(create=False)
    try:
        result = list_demo_items(demo_dir, path)
        # if not result.get("items"):
        #     print(f"⚠️ 路径 '{path}' 下没有内容: {demo_dir}")
        # else:
        #     items_count = len(result["items"])
        #     folders_count = sum(1 for item in result["items"] if item["type"] == "folder")
        #     files_count = sum(1 for item in result["items"] if item["type"] == "file")
        #     print(f"✓ 路径 '{path}': {folders_count} 个文件夹, {files_count} 个文件 (共 {items_count} 项)")
        return result
    except Exception as e:
        handle_api_error("Failed to scan demo directory", e)
        return {"path": path, "items": []}


@require_admin
def save_demo(save_request):
    """
    Save a demo file to the data/demo directory.
    Request format: { name: string, data: AnalyzeResponse, path?: string, overwrite?: boolean }
    path: optional save path, defaulting to the root ("/")
    overwrite: optional; whether to overwrite an existing file, defaulting to False
    """
    name = save_request.get('name')
    data = save_request.get('data')
    path = save_request.get('path', '/')  # defaults to the root directory
    overwrite = save_request.get('overwrite', False)  # defaults to False

    if not name or not data:
        return {
            'success': False,
            'message': 'Missing required parameters: name or data'
        }

    try:
        demo_dir = get_demo_directory(create=True)
        result = save_demo_payload(demo_dir, name, data, path, overwrite)
        if result.get('success'):
            print(f"✓ Demo已保存: {demo_dir / result['file']}")
        else:
            print(f"❌ Save failed: {result.get('message')}")
        return result
    except Exception as e:
        return handle_api_error('Save failed', e)


@require_admin
def delete_demo(delete_request):
    """
    Move a demo file into the deleted folder (soft delete).
    Request format: { file: string }  # file name (including the .json suffix)
    """
    file = delete_request.get('file')

    if not file:
        return {
            'success': False,
            'message': 'Missing required parameter: file'
        }

    try:
        demo_dir = get_demo_directory(create=False)
        result = delete_demo_file(demo_dir, file)
        return handle_api_success(result)
    except Exception as e:
        return handle_api_error('Delete failed', e)


@require_admin
def move_demo(move_request):
    """
    Move a demo file or folder.
    Request format: { file: string, target_path: string } or { path: string, target_path: string }
    """
    file = move_request.get('file')
    path = move_request.get('path')
    target_path = move_request.get('target_path', '')

    # Rejects an explicit null target_path; an absent field falls back to '' (the root).
    if not target_path and target_path != '':
        return {
            'success': False,
            'message': 'Missing required parameter: target_path'
        }

    if not file and not path:
        return {
            'success': False,
            'message': 'Missing required parameter: file or path'
        }

    try:
        demo_dir = get_demo_directory(create=False)

        if file:
            # Move a file
            result = move_demo_file(demo_dir, file, target_path)
        else:
            # Move a folder
            result = move_folder(demo_dir, path, target_path)

        return handle_api_success(result)
    except Exception as e:
        return handle_api_error('Move failed', e)


@require_admin
def rename_demo(rename_request):
    """
    Rename a demo file.
    Request format: { file: string, new_name: string }
    """
    file = rename_request.get('file')
    new_name = rename_request.get('new_name')

    if not file or not new_name:
        return {
            'success': False,
            'message': 'Missing required parameter: file or new_name'
        }

    try:
        demo_dir = get_demo_directory(create=False)
        result = rename_demo_file(demo_dir, file, new_name)
        return handle_api_success(result)
    except Exception as e:
        return handle_api_error('Rename failed', e)


def check_admin(check_request):
    """
    Check whether an admin token is valid.
    Request format: { token: string }
    """
    from flask import request

    # Get the token from the request body or the request headers
    request_token = check_request.get('token') or request.headers.get('X-Admin-Token')

    # Validate the token
    is_valid, error_message = validate_admin_token(request_token)

    # Log the admin permission check
    log_check_admin(is_valid, token=request_token)

    if is_valid:
        return {'success': True}
    else:
        return {
            'success': False,
            'message': error_message
        }
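Reviewer sketch of the request bodies these admin endpoints expect, following their docstrings. Names and paths are illustrative, and each call must also carry a valid admin token for `@require_admin` to pass:

save_demo({
    "name": "attention-demo",       # saved as attention-demo.json (name illustrative)
    "data": {"success": True},      # stands in for a full AnalyzeResponse payload
    "path": "/public",              # optional; defaults to "/"
    "overwrite": True,              # optional; defaults to False
})
move_demo({"file": "attention-demo.json", "target_path": "/archive"})
rename_demo({"file": "attention-demo.json", "new_name": "demo-v2"})   # illustrative
delete_demo({"file": "demo-v2.json"})   # soft delete into the deleted folder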
backend/api/fetch_url.py  ADDED
@@ -0,0 +1,221 @@

"""URL text-extraction API"""
import json
import re
from urllib.parse import urlparse
import trafilatura
import requests
from backend.api.utils import handle_api_error

# Upper bound on characters per extraction (keeps unusually large pages from hurting performance)
MAX_EXTRACTED_TEXT_LENGTH = 20000


def _is_valid_url(url: str) -> bool:
    """Validate the URL format"""
    try:
        result = urlparse(url)
        return all([result.scheme in ['http', 'https'], result.netloc])
    except Exception:
        return False


def _is_local_or_private(url: str) -> bool:
    """Check whether the address is local or on a private network (guards against SSRF attacks)"""
    try:
        parsed = urlparse(url)
        hostname = parsed.hostname

        if not hostname:
            return True

        # Is it localhost?
        if hostname in ['localhost', '127.0.0.1', '::1']:
            return True

        # Is it a private IP address?
        private_patterns = [
            r'^10\.',                           # 10.0.0.0/8
            r'^172\.(1[6-9]|2[0-9]|3[0-1])\.',  # 172.16.0.0/12
            r'^192\.168\.',                     # 192.168.0.0/16
            r'^169\.254\.',                     # 169.254.0.0/16 (link-local)
        ]

        for pattern in private_patterns:
            if re.match(pattern, hostname):
                return True

        return False
    except Exception:
        return True  # Fail closed: refuse access when parsing fails


def _format_article_text(metadata: dict) -> str:
    """
    Format the metadata and body into plain text resembling the rendered page.

    Args:
        metadata: the JSON data extracted by trafilatura (already parsed into a dict)

    Returns:
        The formatted article text
    """
    lines = []

    # Title
    if metadata.get('title'):
        lines.append(metadata['title'])
        lines.append('')

    # Metadata line (no labels, values only)
    meta_parts = []
    if metadata.get('author'):
        meta_parts.append(metadata['author'])
    if metadata.get('date'):
        meta_parts.append(metadata['date'])
    # if metadata.get('hostname'):
    #     meta_parts.append(metadata['hostname'])
    if metadata.get('source-hostname'):
        meta_parts.append(metadata['source-hostname'])
    # if metadata.get('filedate'):
    #     meta_parts.append(metadata['filedate'])

    if meta_parts:
        lines.append(' | '.join(meta_parts))
        lines.append('')

    # Body
    if metadata.get('text'):
        lines.append(metadata['text'])

    return '\n'.join(lines)


def fetch_url(fetch_request):
    """
    Extract text content from a URL.

    Args:
        fetch_request: dict containing a url field

    Returns:
        (response dict, status code) tuple
    """
    url = fetch_request.get('url', '').strip()

    # Validate the URL
    if not url:
        return {
            'success': False,
            'message': '缺少 URL 参数,请提供 url 字段'
        }, 400

    if not _is_valid_url(url):
        return {
            'success': False,
            'message': f'无效的 URL 格式: {url}'
        }, 400

    # Security check: guard against SSRF attacks
    if _is_local_or_private(url):
        return {
            'success': False,
            'message': '不允许访问本地或私有网络地址'
        }, 400

    # Extract the text and metadata
    try:
        from backend.access_log import log_fetch_url
        log_fetch_url(url)

        # Download the page with requests, using a browser User-Agent and headers
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
            'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
            'Accept-Language': 'zh-CN,zh;q=0.9,en;q=0.8',
            'Accept-Encoding': 'gzip, deflate, br',
            'Connection': 'keep-alive',
            'Upgrade-Insecure-Requests': '1',
        }

        # Download the page content (with timeout and headers set)
        response = requests.get(url, headers=headers, timeout=10, allow_redirects=True)
        response.raise_for_status()

        # Check the response content type
        content_type = response.headers.get('Content-Type', '').lower()
        if 'text/html' not in content_type and 'text/xml' not in content_type:
            return {
                'success': False,
                'message': f'不支持的内容类型: {content_type},仅支持 HTML/XML 页面'
            }, 400

        # Use trafilatura to extract structured data (metadata plus body)
        result_json = trafilatura.extract(
            response.text,
            url=url,
            with_metadata=True,
            output_format='json'
        )

        if not result_json:
            print("⚠️ 无法提取页面内容")
            return {
                'success': False,
                'message': '无法从网页中提取文本内容,可能不是文章页面或页面需要验证'
            }, 400

        # Parse the JSON data
        metadata = json.loads(result_json)

        # Check that there is body content
        if not metadata.get('text') or not metadata['text'].strip():
            print("⚠️ 提取到元数据但无正文内容")
            print("元数据:", json.dumps(metadata, ensure_ascii=False, indent=2))
            return {
                'success': False,
                'message': '无法从网页中提取正文内容'
            }, 400

        # Format the text (metadata + body)
        formatted_text = _format_article_text(metadata)
        original_char_count = len(formatted_text)

        # Build the response message (add a note when truncated)
        message = None
        # Check for and truncate over-long text
        if original_char_count > MAX_EXTRACTED_TEXT_LENGTH:
            formatted_text = formatted_text[:MAX_EXTRACTED_TEXT_LENGTH]
            message = f'内容较长,已截断为前 {MAX_EXTRACTED_TEXT_LENGTH} 字符(原始长度: {original_char_count} 字符)'

        char_count = len(formatted_text)

        # Debug output of the extraction result
        # print(formatted_text.split('\n')[:4])
        # print(f"✓ 提取成功: {char_count} 字符" + (f" (截断前: {original_char_count} 字符)" if original_char_count > char_count else ""))
        # Metadata with the body stripped, for logging
        metadata_less = metadata.copy()
        metadata_less['raw_text'] = ''
        metadata_less['text'] = ''
        # print(json.dumps(metadata_less, ensure_ascii=False, indent=2))

        return {
            'success': True,
            'text': formatted_text,
            'url': url,
            'char_count': char_count,
            'message': message
        }, 200

    except requests.exceptions.Timeout:
        return {
            'success': False,
            'message': '请求超时,请检查网络连接或稍后重试'
        }, 400
    except requests.exceptions.RequestException as e:
        return {
            'success': False,
            'message': f'无法访问 URL: {str(e)}'
        }, 400
    except Exception as e:  # noqa: BLE001
        error_response = handle_api_error('URL 文本提取失败', e)
        return error_response, 500
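Reviewer note: the SSRF guard is a regex over the literal hostname, so it rejects loopback and RFC 1918 addresses written as IPs but cannot catch public hostnames that resolve to private addresses. A quick sketch of the guard's behavior plus a direct call to the handler (the URL is illustrative):

assert _is_local_or_private("http://127.0.0.1/admin")        # loopback rejected
assert _is_local_or_private("http://192.168.1.5/")           # private range rejected
assert not _is_local_or_private("https://example.com/post")  # public host allowed

body, status = fetch_url({"url": "https://example.com/some-article"})
if status == 200:
    print(body["char_count"], body["text"][:80])
else:
    print(status, body["message"])  # invalid URL, SSRF guard, timeout, ...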
backend/api/folder.py  ADDED
@@ -0,0 +1,102 @@

"""Folder management API"""
from backend.demo_folder import (
    get_all_folders,
    move_folder,
    rename_folder,
    delete_folder,
    create_folder,
)
from backend.api.utils import (
    get_demo_directory,
    handle_api_error,
    handle_api_success,
    require_admin,
)


def _move_folder_internal(demo_dir, path, target_path):
    """Internal helper: move a folder"""
    return move_folder(demo_dir, path, target_path)


@require_admin
def rename_folder_api(rename_request):
    """
    Rename a folder.
    Request format: { path: string, new_name: string }
    """
    path = rename_request.get('path')
    new_name = rename_request.get('new_name')

    if not path or not new_name:
        return {
            'success': False,
            'message': 'Missing required parameter: path or new_name'
        }

    try:
        demo_dir = get_demo_directory(create=False)
        result = rename_folder(demo_dir, path, new_name)
        return handle_api_success(result)
    except Exception as e:
        return handle_api_error('Rename failed', e)


@require_admin
def delete_folder_api(delete_request):
    """
    Delete a folder (move it into the .deleted directory).
    Request format: { path: string }
    """
    path = delete_request.get('path')

    if not path:
        return {
            'success': False,
            'message': 'Missing required parameter: path'
        }

    try:
        demo_dir = get_demo_directory(create=False)
        result = delete_folder(demo_dir, path)
        return handle_api_success(result)
    except Exception as e:
        return handle_api_error('Delete failed', e)


def list_all_folders():
    """
    Get the list of all folders (for the move-operation picker).
    Response format: { folders: string[] }
    """
    try:
        demo_dir = get_demo_directory(create=False)
        folders = get_all_folders(demo_dir)
        return {'folders': folders}
    except Exception as e:
        handle_api_error("Failed to get folder list", e)
        return {'folders': []}


@require_admin
def create_folder_api(create_request):
    """
    Create a new folder.
    Request format: { parent_path: string, folder_name: string }
    """
    parent_path = create_request.get('parent_path', '/')
    folder_name = create_request.get('folder_name')

    if not folder_name:
        return {
            'success': False,
            'message': 'Missing required parameter: folder_name'
        }

    try:
        demo_dir = get_demo_directory(create=False)
        result = create_folder(demo_dir, parent_path, folder_name)
        return handle_api_success(result)
    except Exception as e:
        return handle_api_error('Create failed', e)
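Reviewer sketch of the folder endpoints' request bodies, per the docstrings above (paths are illustrative; the three mutating calls require an admin token via `@require_admin`):

create_folder_api({"parent_path": "/", "folder_name": "experiments"})
rename_folder_api({"path": "/experiments", "new_name": "archived"})
delete_folder_api({"path": "/archived"})   # soft delete into the .deleted directory
print(list_all_folders()["folders"])       # listing needs no admin token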
backend/api/model_switch.py  ADDED
@@ -0,0 +1,229 @@

"""Model-switching API"""
import gc
import os
from typing import Optional

import torch
from backend import REGISTERED_MODELS
from backend.model_manager import project_registry
from backend.app_context import get_app_context
from backend.api.utils import require_admin


def get_available_models():
    """Return the list of all available models"""
    return {
        'success': True,
        'models': list(REGISTERED_MODELS.keys())
    }, 200


def _get_device_type() -> str:
    """Return the current device type"""
    if torch.cuda.is_available():
        return "cuda"
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"


def _restore_env_vars(old_force_int8: Optional[str], old_force_bfloat16: Optional[str]) -> None:
    """Restore the environment-variable configuration"""
    if old_force_int8 is not None:
        os.environ['FORCE_INT8'] = old_force_int8
    else:
        os.environ.pop('FORCE_INT8', None)

    if old_force_bfloat16 is not None:
        os.environ['CPU_FORCE_BFLOAT16'] = old_force_bfloat16
    else:
        os.environ.pop('CPU_FORCE_BFLOAT16', None)


def get_current_model():
    """Return the model currently in use and its quantization configuration"""
    # Use the module-level context to read the persisted model state
    context = get_app_context(prefer_module_context=True)
    device_type = _get_device_type()

    return {
        'success': True,
        'model': context.model_name,
        'loading': context.model_loading,
        'device_type': device_type,
        'use_int8': os.environ.get('FORCE_INT8') == '1',
        'use_bfloat16': os.environ.get('CPU_FORCE_BFLOAT16') == '1'
    }, 200


@require_admin
def switch_model(switch_request):
    """
    Switch models (requires admin privileges).

    Args:
        switch_request: switch-request dict containing:
            - model: target model name
            - use_int8: whether to use INT8 quantization (optional)
            - use_bfloat16: whether to use bfloat16 (optional, CPU only)

    Returns:
        (response dict, status code) tuple
    """
    if False:  # The original online-switch logic is kept but never runs; delete this guard and retest to restore it
        target_model = switch_request.get('model')
        use_int8 = switch_request.get('use_int8', False)
        use_bfloat16 = switch_request.get('use_bfloat16', False)

        # Validate the request
        if not target_model:
            return {
                'success': False,
                'message': 'Missing model parameter'
            }, 400

        # Check that the model is available
        if target_model not in REGISTERED_MODELS:
            available_models = list(REGISTERED_MODELS.keys())
            return {
                'success': False,
                'message': f'Model {target_model} does not exist. Available models: {", ".join(available_models)}'
            }, 404

        # Determine the device type
        device_type = _get_device_type()

        # Validate quantization parameters against the device
        if use_int8 and device_type == "mps":
            return {
                'success': False,
                'message': 'INT8 quantization is not supported on MPS device'
            }, 400

        if use_bfloat16 and device_type != "cpu":
            return {
                'success': False,
                'message': 'bfloat16 quantization is only supported on CPU device'
            }, 400

        if use_int8 and use_bfloat16:
            return {
                'success': False,
                'message': 'Cannot enable both INT8 and bfloat16 quantization'
            }, 400

        # Use the module-level context so state changes persist (and are not reset by later requests)
        context = get_app_context(prefer_module_context=True)
        current_model = context.model_name

        # Save the current environment-variable configuration (for rollback)
        old_force_int8 = os.environ.get('FORCE_INT8')
        old_force_bfloat16 = os.environ.get('CPU_FORCE_BFLOAT16')

        # Check whether we are already on the target model with the same quantization
        current_int8 = os.environ.get('FORCE_INT8') == '1'
        current_bfloat16 = os.environ.get('CPU_FORCE_BFLOAT16') == '1'

        if (current_model == target_model and
                current_int8 == use_int8 and
                current_bfloat16 == use_bfloat16):
            return {
                'success': True,
                'message': f'Already using model {target_model} (same quantization configuration)',
                'model': target_model
            }, 200

        # Check whether a model is currently loading (initial load or switch)
        if context.model_loading:
            return {
                'success': False,
                'message': 'Model is currently loading, please try again later'
            }, 503

        try:
            # Mark loading as started
            context.set_model_loading(True)
            print(f"🔄 开始切换模型: {current_model} -> {target_model}")

            # Set the new quantization environment variables
            if use_int8:
                os.environ['FORCE_INT8'] = '1'
                print("  设置量化: INT8")
            else:
                os.environ.pop('FORCE_INT8', None)

            if use_bfloat16:
                os.environ['CPU_FORCE_BFLOAT16'] = '1'
                print("  设置量化: bfloat16")
            else:
                os.environ.pop('CPU_FORCE_BFLOAT16', None)

            # Unload the old model
            if current_model and current_model in project_registry:
                print(f"  卸载旧模型: {current_model}")
                project_registry.unload(current_model)
                gc.collect()
                if device_type == "cuda":
                    torch.cuda.empty_cache()
                elif device_type == "mps":
                    torch.mps.empty_cache()

            # Load the new model
            print(f"  加载新模型: {target_model}")
            project_registry.ensure_loaded(target_model)

            # Update the current model
            context.set_current_model(target_model)

            print(f"✅ 模型切换成功: {target_model}")

            return {
                'success': True,
                'message': f'Model switched to {target_model}',
                'model': target_model
            }, 200

        except KeyError:
            # The model is not registered (already checked above, but just in case)
            print(f"❌ 模型切换失败: 模型 {target_model} 未注册")
            # Roll back: restore the old model name and environment variables
            context.set_current_model(current_model)
            _restore_env_vars(old_force_int8, old_force_bfloat16)
            return {
                'success': False,
                'message': f'Model {target_model} is not registered'
            }, 404

        except Exception as e:
            # Loading failed; attempt a rollback
            print(f"❌ 模型切换失败: {e}")
            print(f"  尝试回滚到旧模型: {current_model}")

            try:
                # Roll back: restore the environment variables and reload the old model
                _restore_env_vars(old_force_int8, old_force_bfloat16)
                if current_model:
                    project_registry.ensure_loaded(current_model)
                    context.set_current_model(current_model)
                    print(f"✅ 已回滚到旧模型: {current_model}")
            except Exception as rollback_error:
                print(f"⚠️ 回滚失败: {rollback_error}")

            return {
                'success': False,
                'message': f'Model switch failed: {str(e)}'
            }, 500

        finally:
            # Clear the loading flag whether we succeeded or failed
            context.set_model_loading(False)
            gc.collect()

    return (
        {
            'success': False,
            'message': '在线模型切换已禁用,请通过命令行 --model / --semantic_model 指定后重启服务',
        },
        501,
    )
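Reviewer sketch: with online switching disabled, `switch_model` always answers 501 and `get_current_model` is the useful read path. A representative response shape, with illustrative values:

info, status = get_current_model()
# status == 200 and info resembles:
# {
#     "success": True,
#     "model": "<name passed at startup>",
#     "loading": False,
#     "device_type": "cuda",   # or "mps" / "cpu"
#     "use_int8": False,
#     "use_bfloat16": False,
# }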
backend/api/openai_completions.py  ADDED
@@ -0,0 +1,379 @@

"""OpenAI-compatible /v1/completions: continuations use the same model as semantic analysis; the remaining response fields are fixed."""

import gc
import queue
import threading
import time
import traceback
from typing import Any, Callable, Dict, List, Optional, Tuple

from backend.model_manager import _inference_lock, get_semantic_model_display_name
from backend.oom import exit_if_oom, is_oom_error
from backend.completion_generator import (
    PromptTooLongError,
    apply_chat_template_for_completion,
    completion_cancel_requested,
    generate_completion_text,
    global_completion_stop_event,
    inference_shutdown_event,
)
from backend.api.analyze import LOCK_WAIT_TIMEOUT, QueueTimeoutError
from backend.api.sse_utils import (
    SSEProgressReporter,
    send_completion_delta_event,
    send_error_event,
    send_result_event,
)
from backend.access_log import get_client_ip

# Per-completion SSE wall-clock cap, measured from entering the streaming
# generator (includes queueing for the inference lock plus generation).
COMPLETION_WALL_CLOCK_TIMEOUT_SEC = 300.0


def _log_cmpl_issue(request_id: int, msg: str) -> None:
    """One-line note for an abnormally ended completion (emitted instead of the success-path ``_log_completion_finished``)."""
    print(f"\t⚠️ openai_completions req_id={request_id}: {msg}")


def _log_request(model: str, prompt: str, client_ip=None):
    from backend.access_log import log_openai_completions_request
    return log_openai_completions_request(model, prompt, client_ip)


def _build_response(
    completion_text: str,
    finish_reason: str,
    prompt_tokens: int,
    completion_tokens: int,
    bpe_strings: List[Dict[str, Any]],
):
    """OpenAICompletionsResponse: choices + usage; info_radar carries token-level data for the continuation."""
    total = prompt_tokens + completion_tokens
    return {
        "id": "cmpl-stub-info-radar",
        "object": "text_completion",
        "created": int(time.time()),
        "model": get_semantic_model_display_name(),
        "choices": [
            {
                "text": completion_text,
                "index": 0,
                "finish_reason": finish_reason,
            }
        ],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": total,
        },
        "info_radar": {
            "bpe_strings": bpe_strings,
        },
    }


# Mirrors the return of generate_completion_text (last item is TTFT in seconds; None when nothing was generated)
CompletionRunResult = Tuple[str, str, int, int, List[Dict[str, Any]], Optional[float]]


def _completion_inference_after_lock(
    prompt: str,
    request_id: int,
    lock_wait_time: float,
    *,
    stream_delta: Optional[Callable[[str, bool], None]] = None,
    max_tokens: Optional[int] = None,
) -> CompletionRunResult:
    """
    Run the continuation while the inference lock is already held (the lock-held
    body of the old non-streaming path). Streaming callers may pass stream_delta;
    cancellation is decided uniformly by ``completion_cancel_requested()``.
    """
    from backend.access_log import log_openai_completions_start

    log_openai_completions_start(request_id, lock_wait_time)
    return generate_completion_text(prompt, stream_delta=stream_delta, max_tokens=max_tokens)


def _log_completion_finished(
    request_id: int,
    prompt_tokens: int,
    completion_tokens: int,
    elapsed: float,
    ttft_s: Optional[float],
) -> None:
    """The single log line printed before the old non-streaming branch returns JSON, or before streaming emits the final result event.

    prompt tokens/s = prompt_tokens / TTFT; generate tokens/s = completion_tokens / (elapsed − TTFT).
    ``elapsed`` runs from the SSE start to the end; when its origin differs from the TTFT clock, the throughput is approximate.
    When there is no TTFT (``ttft_s`` is ``None``), the time and throughput fields are omitted.
    """
    if ttft_s is None:
        tps_part = ""
    else:
        decode_s = elapsed - ttft_s
        prompt_time_s = f"{ttft_s:.4f}" if ttft_s > 0 else "n/a"
        gen_time_s = f"{decode_s:.4f}" if decode_s > 0 else "n/a"
        prompt_part = f"{prompt_tokens / ttft_s:.2f}" if ttft_s > 0 else "n/a"
        gen_part = (
            f"{completion_tokens / decode_s:.2f}"
            if completion_tokens and decode_s > 0
            else "n/a"
        )
        tps_part = (
            f", time= {prompt_time_s} / {gen_time_s}s, "
            f"tokens/s= {prompt_part} / {gen_part}"
        )
    print(
        f"\t📤 API openai_completions response: req_id={request_id}, "
        f"prompt/generate tokens= {prompt_tokens} / {completion_tokens}, "
        f"{tps_part}"
    )


def _generate_completion_events(
    prompt: str,
    request_id: int,
    *,
    max_tokens: Optional[int] = None,
):
    global_completion_stop_event.clear()
    q: queue.Queue = queue.Queue()
    start_time = time.perf_counter()

    def run():
        try:
            lock_wait_start = time.perf_counter()
            lock_acquired = _inference_lock.acquire(timeout=LOCK_WAIT_TIMEOUT)
            if not lock_acquired:
                q.put(("error", QueueTimeoutError(
                    f"排队等待超过 {LOCK_WAIT_TIMEOUT} 秒,服务繁忙,请稍后重试"
                )))
                return
            lock_wait_time = time.perf_counter() - lock_wait_start
            try:
                def stream_delta(text: str, stream_end: bool) -> None:
                    if completion_cancel_requested():
                        return
                    q.put(("delta", text, stream_end))

                result = _completion_inference_after_lock(
                    prompt,
                    request_id,
                    lock_wait_time,
                    stream_delta=stream_delta,
                    max_tokens=max_tokens,
                )
            finally:
                _inference_lock.release()
                gc.collect()
            q.put(("result", result))
        except Exception as e:
            q.put(("error", e))

    worker = threading.Thread(target=run, daemon=True)
    worker.start()

    try:
        while True:
            elapsed = time.perf_counter() - start_time
            if elapsed >= COMPLETION_WALL_CLOCK_TIMEOUT_SEC:
                try:
                    item = q.get_nowait()
                except queue.Empty:
                    global_completion_stop_event.set()
                    _log_cmpl_issue(
                        request_id,
                        f"墙钟超时 {elapsed:.1f}s / 上限 {COMPLETION_WALL_CLOCK_TIMEOUT_SEC:.0f}s",
                    )
                    yield send_error_event(
                        f"续写处理超过 {COMPLETION_WALL_CLOCK_TIMEOUT_SEC:.0f} 秒(墙钟限制),已中止",
                        504,
                    )
                    return
            else:
                try:
                    # Wake every 100 ms to re-check the wall-clock limit
                    item = q.get(timeout=0.1)
                except queue.Empty:
                    continue
            kind = item[0]
            if kind == "delta":
                _, text, stream_end = item
                if text or stream_end:
                    yield send_completion_delta_event(text, stream_end)
            elif kind == "result":
                (
                    _completion_text,
                    finish_reason,
                    prompt_tokens,
                    completion_tokens,
                    bpe_strings,
                    ttft_s,
                ) = item[1]
                elapsed = time.perf_counter() - start_time
                if global_completion_stop_event.is_set() or inference_shutdown_event.is_set():
                    finish_reason = "abort"
                if inference_shutdown_event.is_set():
                    _log_cmpl_issue(
                        request_id,
                        f"进程终止,续写中止 elapsed={elapsed:.2f}s "
                        f"tokens={prompt_tokens}/{completion_tokens}",
                    )
                elif global_completion_stop_event.is_set():
                    _log_cmpl_issue(
                        request_id,
                        f"用户 Stop,续写中止 elapsed={elapsed:.2f}s "
                        f"tokens={prompt_tokens}/{completion_tokens}",
                    )
                else:
                    _log_completion_finished(
                        request_id,
                        prompt_tokens,
                        completion_tokens,
                        elapsed,
                        ttft_s,
                    )
                yield send_result_event(
                    _build_response(
                        _completion_text,
                        finish_reason,
                        prompt_tokens,
                        completion_tokens,
                        bpe_strings,
                    )
                )
                return
            elif kind == "error":
                err = item[1]
                if isinstance(err, PromptTooLongError):
                    _log_cmpl_issue(request_id, f"prompt too long: {err}")
                    yield send_error_event(str(err), 400)
                elif isinstance(err, QueueTimeoutError):
                    _log_cmpl_issue(request_id, f"排队超时: {err}")
                    yield send_error_event(str(err), 503)
                else:
                    exit_if_oom(err, defer_seconds=1)
                    if is_oom_error(err):
                        yield send_error_event(str(err), 500)
                        return
                    _log_cmpl_issue(
                        request_id,
                        "".join(
                            traceback.format_exception(
                                type(err), err, err.__traceback__
                            )
                        ).strip(),
                    )
                    yield send_error_event(str(err), 500)
                return
    finally:
        gc.collect()


def _completions_sse_response(
    prompt: str,
    request_id: int,
    *,
    max_tokens: Optional[int] = None,
):
    return SSEProgressReporter(
        lambda: _generate_completion_events(prompt, request_id, max_tokens=max_tokens)
    ).create_response()


def completions_stop():
    """
    Single-user serial mode: set the global stop flag so the current continuation
    ends as soon as possible in both generate and the SSE callback.
    No request body is needed; the next POST /v1/completions clears the flag at
    the entry of the streaming generator.
    """
    global_completion_stop_event.set()
    return {"ok": True}, 200


def completions_prompt(completions_prompt_request):
    """
    Apply the chat template to the user's raw text and return the full prompt
    string actually fed to the completion endpoint (JSON).

    Args:
        completions_prompt_request: contains model and prompt (user input); see server_openai_definitions.yaml

    Returns:
        (dict with prompt_used, 200) or a validation / prompt-too-long error
    """
    if not isinstance(completions_prompt_request, dict):
        completions_prompt_request = {}
    model = completions_prompt_request.get("model")
    prompt = completions_prompt_request.get("prompt")

    if not model:
        return {"success": False, "message": "缺少 model 字段"}, 400
    if prompt is None:
        return {"success": False, "message": "缺少 prompt 字段"}, 400
    if not isinstance(prompt, str):
        return {"success": False, "message": "prompt 必须为字符串"}, 400

    system_opt: Optional[str]
    if "system" not in completions_prompt_request:
        system_opt = None
    else:
        system_raw = completions_prompt_request.get("system")
        if not isinstance(system_raw, str):
            return {"success": False, "message": "system 必须为字符串"}, 400
        system_opt = system_raw

    client_ip = get_client_ip()
    from backend.access_log import log_openai_completions_prompt_request

    log_openai_completions_prompt_request(
        model,
        user_prompt=prompt,
        system=system_opt,
        client_ip=client_ip,
    )

    try:
        prompt_used = apply_chat_template_for_completion(prompt, system_opt)
    except PromptTooLongError as e:
        return {"success": False, "message": str(e)}, 400

    return {"prompt_used": prompt_used}, 200


def completions(completions_request):
    """
    Text continuation: shares the inference lock and the semantic model with
    analyze_semantic; the response is always text/event-stream (SSE).
    ``prompt`` must be the complete, final model input (call POST
    /v1/completions/prompt first if a chat template is needed).

    Args:
        completions_request: contains model, prompt, etc.; see server_openai_definitions.yaml

    Returns:
        SSE Response; on validation failure, (error body, 400/503/500)
    """
    if not isinstance(completions_request, dict):
        completions_request = {}
    model = completions_request.get("model")
    prompt = completions_request.get("prompt")

    if not model:
        return {"success": False, "message": "缺少 model 字段"}, 400
    if prompt is None:
        return {"success": False, "message": "缺少 prompt 字段"}, 400
    if not isinstance(prompt, str):
        return {"success": False, "message": "prompt 必须为字符串"}, 400

    max_tokens_raw = completions_request.get("max_tokens")
    max_tokens: Optional[int]
    if max_tokens_raw is None:
        max_tokens = None
    elif type(max_tokens_raw) is not int:
        return {"success": False, "message": "max_tokens 须为正整数"}, 400
    elif max_tokens_raw <= 0:
        return {"success": False, "message": "max_tokens 须 > 0"}, 400
    else:
        max_tokens = max_tokens_raw

    client_ip = get_client_ip()
    request_id = _log_request(model, prompt, client_ip)

    return _completions_sse_response(prompt, request_id, max_tokens=max_tokens)
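To make the throughput line in `_log_completion_finished` concrete, a worked example with illustrative numbers: for prompt_tokens=512, completion_tokens=128, ttft_s=0.8 and elapsed=5.0, the log reports 512 / 0.8 = 640 prefill tokens/s and 128 / (5.0 − 0.8) ≈ 30.48 decode tokens/s.

# Worked example of the two throughput formulas (numbers are illustrative).
prompt_tokens, completion_tokens = 512, 128
ttft_s, elapsed = 0.8, 5.0

prompt_tps = prompt_tokens / ttft_s                # 640.00 (prefill)
gen_tps = completion_tokens / (elapsed - ttft_s)   # ~30.48 (decode)
print(f"tokens/s= {prompt_tps:.2f} / {gen_tps:.2f}")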
backend/api/prediction_attribute.py  ADDED
@@ -0,0 +1,79 @@

"""Prediction-attribution API"""
import gc
import time

from backend.model_manager import _inference_lock
from backend.oom import exit_if_oom
from backend.prediction_attributor import analyze_prediction_attribution
from backend.api.analyze import LOCK_WAIT_TIMEOUT
from backend.access_log import get_client_ip, log_prediction_attribute_request


def prediction_attribute(attribution_request):
    """
    Run attribution analysis on the next-token prediction for a context text.

    Args:
        attribution_request: dict containing context and target_prediction

    Returns:
        (response dict, status code) tuple
    """
    context = attribution_request.get("context")
    target_prediction = attribution_request.get("target_prediction")
    model = attribution_request.get("model")

    if context is None:
        return {"success": False, "message": "Missing required field: context"}, 400
    if not isinstance(context, str):
        return {"success": False, "message": "context must be a string"}, 400
    if context == "":
        return {"success": False, "message": "Missing required field: context"}, 400

    if target_prediction is not None and not isinstance(target_prediction, str):
        return {"success": False, "message": "target_prediction must be a string"}, 400
    if target_prediction == "":
        return {"success": False, "message": "target_prediction must not be empty"}, 400

    if model is None:
        return {"success": False, "message": "Missing required field: model"}, 400
    if not isinstance(model, str):
        return {"success": False, "message": "model must be a string"}, 400
    if model not in ("base", "instruct"):
        return {"success": False, "message": 'model must be "base" or "instruct"'}, 400

    client_ip = get_client_ip()
    start_time = time.perf_counter()
    request_id = log_prediction_attribute_request(context, target_prediction, model, client_ip)

    lock_acquired = _inference_lock.acquire(timeout=LOCK_WAIT_TIMEOUT)
    if not lock_acquired:
        return {
            "success": False,
            "message": (
                f"Queue wait exceeded {LOCK_WAIT_TIMEOUT} seconds; "
                "server is busy, please try again later."
            ),
        }, 503

    try:
        result = analyze_prediction_attribution(context, target_prediction, model=model)
    except ValueError as e:
        return {"success": False, "message": str(e)}, 400
    except Exception as e:
        import traceback
        traceback.print_exc()
        exit_if_oom(e, defer_seconds=1)
        return {"success": False, "message": str(e)}, 500
    finally:
        _inference_lock.release()
        gc.collect()

    elapsed = time.perf_counter() - start_time
    tokens = len(result.get("token_attribution", []))
    print(
        f"\t📤 API prediction_attribute response: req_id={request_id}, "
        f"tokens={tokens}, response_time={elapsed:.4f}s"
    )

    return {"success": True, **result}, 200
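Reviewer sketch of a valid request body for this endpoint. The context and target values are illustrative; the validation above accepts a missing or None target_prediction, whose semantics live in analyze_prediction_attribution, outside this diff:

body, status = prediction_attribute({
    "context": "The capital of France is",   # illustrative
    "target_prediction": " Paris",           # optional; may be omitted or None
    "model": "base",                         # must be "base" or "instruct"
})
if status == 200:
    print(len(body["token_attribution"]))
else:
    print(status, body["message"])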
backend/api/sse_utils.py
ADDED
@@ -0,0 +1,181 @@
"""Server-Sent Events (SSE) utilities."""
import json
import queue
import time
from typing import Callable, Generator, Optional, Tuple
from flask import Response


class SSEProgressReporter:
    """SSE progress reporter."""

    def __init__(self, generator_func: Callable):
        """
        Initialize the SSE progress reporter.

        Args:
            generator_func: generator function that yields SSE events
        """
        self.generator_func = generator_func

    def generate(self):
        """Generate the SSE event stream."""
        try:
            for event in self.generator_func():
                yield event
        except Exception as e:
            # Send an error event
            error_data = {
                'type': 'error',
                'message': str(e)
            }
            yield f"data: {json.dumps(error_data)}\n\n"

    def create_response(self) -> Response:
        """Create the SSE response."""
        return Response(
            self.generate(),
            mimetype='text/event-stream',
            headers={
                'Cache-Control': 'no-cache',
                'X-Accel-Buffering': 'no',  # disable nginx buffering
                'Connection': 'keep-alive'
            }
        )


def send_progress_event(step: int, total_steps: int, stage: str, percentage: Optional[int] = None, message: Optional[str] = None) -> str:
    """
    Build an SSE progress event.

    Args:
        step: current step (1-based)
        total_steps: total number of steps
        stage: stage name (encoding, inference, processing)
        percentage: optional progress percentage (0-100), provided only for stages that display one
        message: optional progress message

    Returns:
        SSE-formatted event string
    """
    data = {
        'type': 'progress',
        'step': step,
        'total_steps': total_steps,
        'stage': stage
    }
    if percentage is not None:
        data['percentage'] = percentage
    if message:
        data['message'] = message
    return f"data: {json.dumps(data)}\n\n"


def send_result_event(result: dict) -> str:
    """
    Build an SSE result event.

    Args:
        result: analysis result dict

    Returns:
        SSE-formatted event string
    """
    data = {
        'type': 'result',
        'data': result
    }
    return f"data: {json.dumps(data)}\n\n"


def send_completion_delta_event(text: str, stream_end: bool) -> str:
    """Streaming completion: parallels analyze's progress/result events, with type=delta."""
    data = {
        "type": "delta",
        "text": text,
    }
    if stream_end:
        data["stream_end"] = True
    return f"data: {json.dumps(data)}\n\n"


def send_prompt_used_event(prompt_used: str) -> str:
    """Streaming completion: before the first delta, push the exact prompt text fed to the model."""
    data = {
        "type": "prompt_used",
        "prompt_used": prompt_used,
    }
    return f"data: {json.dumps(data)}\n\n"


def send_error_event(message: str, status_code: Optional[int] = None) -> str:
    """
    Build an SSE error event.

    Args:
        message: error message
        status_code: optional HTTP status code, parsed by the non-streaming wrapper

    Returns:
        SSE-formatted event string
    """
    data = {'type': 'error', 'message': message}
    if status_code is not None:
        data['status_code'] = status_code
    return f"data: {json.dumps(data)}\n\n"


def consume_progress_queue(
    progress_queue: queue.Queue,
    analysis_done,
    start_time: float,
    timeout_seconds: float,
    timeout_label: str = "analysis",
) -> Generator[Tuple[str, str], None, None]:
    """
    Consume the progress queue and yield (kind, event_str) tuples.
    kind: 'progress' | 'timeout' | 'done'
    event_str: SSE-formatted string (contains the error message on timeout, empty on done)
    """
    done_received = False
    last_progress_info = None

    while True:
        elapsed = time.perf_counter() - start_time
        if elapsed >= timeout_seconds:
            progress_str = f" | {last_progress_info}" if last_progress_info else ""
            print(f"⏱️ {timeout_label} timed out: {elapsed:.2f}s of processing exceeded the {timeout_seconds}s limit; giving up{progress_str}")
            yield ('timeout', send_error_event(f"Analysis timed out: processing exceeded the {timeout_seconds}-second limit; giving up"))
            return

        try:
            event_data = progress_queue.get(timeout=0.1)
            event_type = event_data[0]
            if event_type == 'progress':
                _, step, total_steps, stage, percentage = event_data
                if total_steps > 0:
                    last_progress_info = f"step={step}/{total_steps}"
                else:
                    last_progress_info = f"step={step}"
                if stage:
                    last_progress_info += f" stage={stage}"
                if percentage is not None:
                    last_progress_info += f" {percentage}%"
                yield ('progress', send_progress_event(step, total_steps, stage, percentage))
            elif event_type == 'done':
                done_received = True
                while not progress_queue.empty():
                    try:
                        remaining = progress_queue.get_nowait()
                        if remaining[0] == 'progress':
                            _, step, total_steps, stage, percentage = remaining
                            yield ('progress', send_progress_event(step, total_steps, stage, percentage))
                    except queue.Empty:
                        break
                yield ('done', '')
                return
        except queue.Empty:
            if analysis_done.is_set() and done_received:
                yield ('done', '')
                return
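The (kind, event_str) contract implies a producer thread feeding the queue while the route thread drains it. A minimal wiring sketch under that assumption; the route name, worker body, and step counts are illustrative, only the sse_utils helpers are from this file.

# Hypothetical wiring for consume_progress_queue; the route, worker body, and
# step counts are illustrative, only the sse_utils helpers are from this file.
import queue
import threading
import time

from flask import Flask

from backend.api.sse_utils import (
    SSEProgressReporter,
    consume_progress_queue,
    send_result_event,
)

app = Flask(__name__)

@app.route('/api/analyze_stream')
def analyze_stream():
    progress_queue = queue.Queue()
    analysis_done = threading.Event()
    result_holder = {}

    def worker():
        for step in range(1, 4):
            progress_queue.put(('progress', step, 3, 'inference', None))
        result_holder['result'] = {'ok': True}  # placeholder payload
        progress_queue.put(('done',))
        analysis_done.set()

    threading.Thread(target=worker, daemon=True).start()
    start = time.perf_counter()

    def events():
        for kind, event_str in consume_progress_queue(
            progress_queue, analysis_done, start, timeout_seconds=60
        ):
            if kind == 'done':
                yield send_result_event(result_holder['result'])
            elif event_str:
                yield event_str

    return SSEProgressReporter(events).create_response()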
backend/api/static.py
ADDED
@@ -0,0 +1,60 @@
"""Static file routes."""
import mimetypes
from pathlib import Path
from urllib.parse import unquote

from flask import Response, redirect, abort, request
from werkzeug.utils import safe_join

from backend.access_log import log_page_load, log_demo_file


def _read_static_file(directory: str, path: str) -> Response:
    """Read a static file and return a Response, avoiding the Content-Length mismatch
    (RuntimeError: Response content shorter than Content-Length) that send_from_directory's
    streaming can trigger under ASGI/a2wsgi.
    """
    base = Path(directory).resolve()
    safe_path = safe_join(str(base), path)
    if safe_path is None:
        abort(404)
    full_path = Path(safe_path)
    if not full_path.is_file() or not str(full_path.resolve()).startswith(str(base)):
        abort(404)
    content = full_path.read_bytes()
    mimetype, _ = mimetypes.guess_type(path)
    mimetype = mimetype or "application/octet-stream"
    return Response(content, mimetype=mimetype, headers={"Content-Length": str(len(content))})


def register_static_routes(app):
    """Register the static file routes."""

    @app.route('/')
    def redir():
        target = 'client/index.html'
        if request.query_string:
            target += '?' + request.query_string.decode()
        return redirect(target)

    @app.route('/client/<path:path>')
    def send_static(path):
        """serves all files from ./client/dist/ to ``/client/<path:path>``"""
        if path.endswith('.html'):
            log_page_load(path)
        return _read_static_file('client/dist', path)

    @app.route('/demo/<path:path>')
    def send_demo(path):
        """serves all demo files from the demo dir to ``/demo/<path:path>``"""
        from backend.app_context import get_data_dir
        data_dir = get_data_dir()
        log_demo_file(path)
        try:
            decoded_path = unquote(path)
            return _read_static_file(str(data_dir), decoded_path)
        except Exception:
            try:
                return _read_static_file(str(data_dir), path)
            except Exception:
                abort(404)
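A minimal sketch of how these routes are meant to be mounted; the Flask app construction is assumed to happen elsewhere in the server.

# Hypothetical bootstrap; only register_static_routes is from this file.
from flask import Flask

from backend.api.static import register_static_routes

app = Flask(__name__)
register_static_routes(app)
# GET /              -> redirect to client/index.html (query string preserved)
# GET /client/app.js -> bytes of ./client/dist/app.js with explicit Content-Length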
backend/api/utils.py
ADDED
@@ -0,0 +1,118 @@
"""API utility functions."""
import math
import os
import traceback
from functools import wraps
from pathlib import Path
from typing import Optional

from flask import request, jsonify


def round_to_sig_figs(x: float, n: int = 7) -> float:
    """Round a float to n significant figures. Zero and non-finite values are returned unchanged."""
    if x == 0 or not math.isfinite(x):
        return x
    return float(f"{x:.{n}g}")


def get_demo_directory(create=False):
    """Get the demo directory path."""
    from backend.app_context import get_demo_directory as _get_demo_dir
    return _get_demo_dir(create=create)


def handle_api_error(operation_name: str, error: Exception) -> dict:
    """
    Unified API error handling.

    Args:
        operation_name: operation name (e.g. 'Save failed', 'Delete failed')
        error: the exception object

    Returns:
        Standard error response dict
    """
    error_msg = f'{operation_name}: {str(error)}'
    print(f"❌ {error_msg}")
    traceback.print_exc()
    return {
        'success': False,
        'message': error_msg
    }


def handle_api_success(result: dict, operation_name: Optional[str] = None) -> dict:
    """
    Handle an API success response and print a log line.

    Args:
        result: operation result dict
        operation_name: optional operation name for logging

    Returns:
        The result dict
    """
    if result.get('success'):
        if operation_name:
            print(f"✓ {operation_name}")
        elif result.get('message'):
            print(f"✓ {result.get('message')}")
    else:
        message = result.get('message', 'Operation failed')
        print(f"❌ {message}")
    return result


def get_admin_token() -> Optional[str]:
    """
    Get the admin token (read from the environment).

    Returns:
        The admin token string, or None if not set
    """
    return os.environ.get('INFORADAR_ADMIN_TOKEN')


def validate_admin_token(request_token: str) -> tuple[bool, str]:
    """
    Check whether an admin token is valid.

    Args:
        request_token: the token to validate

    Returns:
        (is valid, error message)
    """
    admin_token = get_admin_token()

    # If INFORADAR_ADMIN_TOKEN is not configured, report admin features as disabled
    if admin_token is None:
        return False, 'Admin features are not enabled'

    # Validate the token
    if request_token == admin_token:
        return True, ''
    else:
        return False, 'Invalid admin token'


def require_admin(f):
    """
    Decorator: require admin permission to access an API.

    Checks whether the X-Admin-Token request header matches the configured
    INFORADAR_ADMIN_TOKEN. If INFORADAR_ADMIN_TOKEN is not configured, everyone
    is treated as a regular user and all write operations are rejected.
    """
    @wraps(f)
    def wrapper(*args, **kwargs):
        request_token = request.headers.get('X-Admin-Token')
        is_valid, error_message = validate_admin_token(request_token)

        if not is_valid:
            return {
                'success': False,
                'message': 'Admin permission required'
            }, 403

        return f(*args, **kwargs)
    return wrapper
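A minimal sketch of protecting a write endpoint with require_admin; the route path and handler body are illustrative.

# Hypothetical route protected by require_admin; the path and handler body are
# illustrative. Clients must send X-Admin-Token matching INFORADAR_ADMIN_TOKEN.
from flask import Flask

from backend.api.utils import require_admin

app = Flask(__name__)

@app.route('/api/demo/delete', methods=['POST'])
@require_admin
def delete_demo():
    return {'success': True, 'message': 'deleted'}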
backend/app_context.py
ADDED
@@ -0,0 +1,110 @@
"""Application context management.

A class-level singleton provides process-wide shared state.
"""

import sys
from pathlib import Path
from typing import Optional
from argparse import Namespace


class AppContext:
    """
    Application context (process-level singleton).

    Initialize via AppContext.init() and retrieve via AppContext.get().
    The singleton guarantees the whole process shares one context, avoiding the
    state inconsistency that module re-imports would otherwise cause.
    """

    _instance: Optional['AppContext'] = None

    @classmethod
    def get(cls) -> 'AppContext':
        """Get the context singleton (init() must have been called first)."""
        if cls._instance is None:
            raise RuntimeError("AppContext is not initialized; call AppContext.init() first")
        return cls._instance

    @classmethod
    def init(cls, args: Namespace, data_dir: Path) -> 'AppContext':
        """
        Initialize the context singleton (idempotent).

        If already initialized, returns the existing instance so module re-imports
        cannot clobber the state.
        """
        if cls._instance is not None:
            return cls._instance
        cls._instance = cls(args, data_dir)
        gc = getattr(args, "gradient_checkpointing", True)
        print(
            f"[Info Radar] gradient_checkpointing={'on' if gc else 'off'}",
            file=sys.stderr,
            flush=True,
        )
        return cls._instance

    @classmethod
    def is_initialized(cls) -> bool:
        """Check whether the context has been initialized."""
        return cls._instance is not None

    def __init__(self, args: Namespace, data_dir: Path):
        """Private constructor; use AppContext.init() instead."""
        self.args = args
        self.data_dir = data_dir
        self._model_loading = True  # starts out in the loading state
        self._current_model_name = getattr(args, 'model', None)

    @property
    def model_name(self) -> str:
        """Name of the current model."""
        return self._current_model_name

    @property
    def model_loading(self) -> bool:
        """Whether the model is currently loading."""
        return self._model_loading

    def set_current_model(self, model_name: str):
        """Set the current model name."""
        self._current_model_name = model_name

    def set_model_loading(self, loading: bool):
        """Set the model loading state."""
        self._model_loading = loading

    def get_demo_dir(self, create: bool = False) -> Path:
        """Get the demo directory path."""
        from backend.data_utils import get_demo_dir
        return get_demo_dir(self.data_dir, create=create)


# ============= Compatibility interface (for migrating old code smoothly) =============

def get_app_context(prefer_module_context: bool = False) -> AppContext:
    """Get the application context (legacy interface; prefer_module_context is ignored)."""
    return AppContext.get()


def get_args() -> Namespace:
    """Get the command-line arguments."""
    return AppContext.get().args


def get_verbose() -> bool:
    """Whether to print verbose debug output (controlled by --verbose)."""
    try:
        return getattr(get_args(), "verbose", False)
    except RuntimeError:
        return False


def get_data_dir() -> Path:
    """Get the data directory."""
    return AppContext.get().data_dir


def get_demo_directory(create: bool = False) -> Path:
    """Get the demo directory."""
    return AppContext.get().get_demo_dir(create=create)
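A minimal startup sketch, assuming the server builds its args with argparse; the flag values here are illustrative.

# Hypothetical startup sketch; the flag values are illustrative. init() is
# idempotent, so this only behaves as shown in a fresh process.
from argparse import Namespace
from pathlib import Path

from backend.app_context import AppContext, get_data_dir, get_verbose

ctx = AppContext.init(Namespace(model="instruct", verbose=True), Path("data/demo/public"))
assert AppContext.is_initialized() and ctx.model_name == "instruct"
print(get_verbose(), get_data_dir())  # True data/demo/public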
backend/class_register.py
ADDED
@@ -0,0 +1,16 @@
REGISTERED_MODELS = {}


def register_model(name):
    """
    Decorator for registering model classes.

    Automatically stores the registered model name on the class attribute
    _registered_model_name, so subclasses need not repeat the name at init time.
    """
    def decorator(cls):
        REGISTERED_MODELS[name] = cls
        # Store the registered model name on the class
        cls._registered_model_name = name
        return cls
    return decorator
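A minimal registration sketch; the class name is illustrative.

# Minimal registration sketch; the class name is illustrative.
from backend.class_register import REGISTERED_MODELS, register_model

@register_model("qwen")
class QwenLoader:
    pass

assert REGISTERED_MODELS["qwen"] is QwenLoader
assert QwenLoader._registered_model_name == "qwen"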
backend/completion_generator.py
ADDED
@@ -0,0 +1,558 @@
"""
OpenAI /v1/completions: core_generate_from_text is the single entry point for completions.

Chat-template assembly lives in apply_chat_template_for_completion (serving POST /v1/completions/prompt);
the prompt for POST /v1/completions must already be a finalized model input string.
The token cap for the whole context (prompt plus completion) is this module's
``completion_max_token_length``; the optional max_tokens limits the completion length,
and prompt plus completion still must not exceed that cap.
"""

import signal
import sys
import threading
import time
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from transformers import StoppingCriteria, StoppingCriteriaList, TextStreamer

from backend.api.utils import round_to_sig_figs
from backend.app_context import get_verbose
from backend.device import DeviceManager
from backend.model_manager import ensure_semantic_slot_ready
from backend.pred_topk_format import pred_topk_pairs_from_probs_1d
from backend.runtime_config import DEFAULT_TOPK

# Completion path: prompt plus completion must not exceed this token count
# (unrelated to the semantic-analysis runtime).
completion_max_token_length = 1000

# Special tokens also count as analysis/display content, so they are not skipped.
_COMPLETION_DECODE_SKIP_SPECIAL = False

# Set when the process receives SIGTERM / SIGINT.
inference_shutdown_event = threading.Event()

# Single-user serialization: set by the user's POST /v1/completions/stop or the SSE
# wall-clock timeout; checked along the completion path together with inference_shutdown.
# Cleared by openai_completions on each new POST /v1/completions (the SSE entry point).
global_completion_stop_event = threading.Event()


def completion_cancel_requested() -> bool:
    """Whether the current completion should stop (process exit or global stop)."""
    return inference_shutdown_event.is_set() or global_completion_stop_event.is_set()


def register_inference_shutdown_handlers() -> None:
    """Register SIGTERM / SIGINT handlers: set inference_shutdown_event so model.generate stops at the next step.

    Call once from the main thread, early in process startup (e.g. when the server loads).
    SIGINT additionally raises KeyboardInterrupt after setting the event, so Ctrl+C exits in development.
    """

    def _on_sigterm(signum: int, frame: Any) -> None:
        inference_shutdown_event.set()

    def _on_sigint(signum: int, frame: Any) -> None:
        inference_shutdown_event.set()
        raise KeyboardInterrupt

    try:
        signal.signal(signal.SIGTERM, _on_sigterm)
    except (ValueError, OSError):
        pass
    try:
        signal.signal(signal.SIGINT, _on_sigint)
    except (ValueError, OSError):
        pass


class PromptTooLongError(ValueError):
    """The prompt is too long or fills the context, so no completion is possible (raised by ``core_generate_from_text`` when ``input_len >= ctx_limit``)."""


def _completion_without_generate(
    prompt_tokens: int,
) -> Tuple[str, str, int, int, List[Dict[str, Any]], Optional[float]]:
    """Return value when a completion is cancelled before entering ``model.generate`` (matches the frontend's ``abort`` display)."""
    return "", "abort", prompt_tokens, 0, [], None


def _print_completion_stream_delta(text: str, stream_end: bool) -> None:
    """Receive the incremental fragments split by TextStreamer and print them from this module (matching default TextStreamer output)."""
    # Print only in verbose mode
    if get_verbose():
        print(text, flush=True, end="" if not stream_end else None)


def _compose_stream_delta(
    stream_delta: Optional[Callable[[str, bool], None]],
) -> Callable[[str, bool], None]:
    """
    Combine the optional SSE/external stream_delta with local verbose printing: neither replaces the other; both can fire.
    """
    def on_delta(text: str, stream_end: bool) -> None:
        if stream_delta is not None:
            stream_delta(text, stream_end)
        _print_completion_stream_delta(text, stream_end)

    return on_delta


class _DeltaTextStreamer(TextStreamer):
    """Inherit put/end's incremental splitting, but hand fragments to a callback instead of printing directly."""

    def __init__(
        self,
        tokenizer,
        on_delta: Callable[[str, bool], None],
        *,
        skip_prompt: bool = False,
        **decode_kwargs: Any,
    ) -> None:
        super().__init__(tokenizer, skip_prompt=skip_prompt, **decode_kwargs)
        self._on_delta = on_delta

    def on_finalized_text(self, text: str, stream_end: bool = False) -> None:
        self._on_delta(text, stream_end)


class _CancelOnEventStoppingCriteria(StoppingCriteria):
    """Check ``completion_cancel_requested()`` at every step so generate ends as soon as possible."""

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs: Any
    ) -> torch.BoolTensor:
        # StoppingCriteria contract: return a bool vector of batch length; True stops that row at this step.
        batch_size = input_ids.shape[0]
        cancel_requested = completion_cancel_requested()
        return torch.full(
            (batch_size,),
            fill_value=cancel_requested,
            device=input_ids.device,
            dtype=torch.bool,
        )


def _stack_scores_to_cpu(
    scores: Tuple[torch.Tensor, ...],
) -> torch.Tensor:
    """Concatenate the scores tuple from ``generate(..., output_scores=True)`` along the batch dim into ``[n, vocab]`` and move it to the CPU in one go."""
    if not scores:
        return torch.empty(0, 0)
    # Each step has shape (batch, vocab); with greedy batch=1, cat(dim=0) -> (n, vocab)
    return torch.cat(scores, dim=0).detach().cpu()


def _print_completion_warning(msg: str) -> None:
    print(msg, file=sys.stderr, flush=True)


def _completion_one_token_debug(tokenizer, token_id: int) -> str:
    """Debugging helper for the completion path: one token's id and decode (repr makes whitespace/newlines visible)."""
    decoded = tokenizer.decode([token_id], skip_special_tokens=False)
    return f"id={token_id}, decode={decoded!r}"


def _warn_decode_reencode_mismatch(
    tokenizer,
    *,
    n: int,
    mismatch_count: int,
    first: int,
    new_cpu: torch.Tensor,
    reencoded: torch.Tensor,
) -> None:
    """Warn when the token sequences disagree (wording matches the former RuntimeError), then fall back to incremental decode offsets."""
    g0 = int(new_cpu[first].item())
    r0 = int(reencoded[first].item())
    lines = [
        "decode→encode of the completion span disagrees with generate's token sequence; offset_mapping cannot be used.",
        f"  {n} tokens total, {mismatch_count} ids differ (first at index={first}).",
        "  First mismatch:",
        f"    generate {_completion_one_token_debug(tokenizer, g0)}",
        f"    reencode {_completion_one_token_debug(tokenizer, r0)}",
    ]
    nxt = first + 1
    if nxt < n:
        g1 = int(new_cpu[nxt].item())
        r1 = int(reencoded[nxt].item())
        lines.extend(
            [
                f"  Next position (index={nxt}):",
                f"    generate {_completion_one_token_debug(tokenizer, g1)}",
                f"    reencode {_completion_one_token_debug(tokenizer, r1)}",
            ]
        )
    _print_completion_warning("\n".join(lines))


def _warn_decode_reencode_length_mismatch(
    new_cpu: torch.Tensor,
    reencoded: torch.Tensor,
) -> None:
    msg = (
        "decode→encode of the completion span disagrees with generate's token sequence (different lengths); offset_mapping cannot be used.\n"
        f"  new_ids:  shape={tuple(new_cpu.shape)}\n"
        f"  reencode: shape={tuple(reencoded.shape)}"
    )
    _print_completion_warning(msg)


def _lcp_prefix_len(a: str, b: str) -> int:
    """Length of the longest common prefix of ``a`` and ``b`` (Python ``str`` indices, Unicode scalars)."""
    k, n = 0, min(len(a), len(b))
    while k < n and a[k] == b[k]:
        k += 1
    return k


def _verbose_incremental_offset_step(
    *,
    step_1based: int,
    n_tokens: int,
    token_id: int,
    tokenizer,
    skip: bool,
    offset: Tuple[int, int],
    matched: int,
    curr_len: int,
    raw: str,
) -> None:
    """Verbose: this step's ``offset``/``raw``; appends ``single_decode`` when the LCP does not cover the full prefix."""
    if not get_verbose():
        return
    s, e = offset
    raw_show = raw if len(raw) <= 240 else raw[:237] + "..."
    line = (
        f"[incremental-offset] step {step_1based}/{n_tokens} id={token_id} "
        f"offset=[{s},{e}) raw={raw_show!r}"
    )
    if matched < curr_len:
        one = tokenizer.decode([token_id], skip_special_tokens=skip)
        line += f" (bpe mismatch) single_decode={one!r}"
    _print_completion_warning(line)


def _print_full_decode_text_mismatch(full_decode: str, text: str) -> None:
    """Print a line-level diagnostic when the full ``decode(ids)`` differs from ``completion_text``."""
    lines = [
        "Full decode of the completion span disagrees with completion_text:",
        f"  len(decode)={len(full_decode)}, len(text)={len(text)}",
    ]
    n = min(len(full_decode), len(text))
    first_diff = next((k for k in range(n) if full_decode[k] != text[k]), None)
    if first_diff is not None:
        a, b = full_decode[first_diff], text[first_diff]
        lines.append(f"  First difference at index={first_diff}: {a!r} vs {b!r}")
    elif len(full_decode) != len(text):
        lines.append("  Code-point prefixes agree; only the lengths differ.")
    _print_completion_warning("\n".join(lines))


def _completion_incremental_offsets_and_raws(
    tokenizer,
    new_ids: torch.Tensor,
    completion_text: str,
    *,
    skip: bool,
) -> Tuple[List[Tuple[int, int]], List[str]]:
    """
    Slow path, in decoder code points. At step ``i``, ``curr = decode(ids[:i+1])`` and
    ``matched = LCP(curr, completion_text)`` (compared in full from 0, avoiding incremental-LCP drift when decode is non-monotonic);
    ``offset``: if ``matched < len(curr)`` (the prefix has not caught up with the full text), use ``(off_left, off_left)``;
    otherwise ``(off_left, len(curr))``. ``raw`` is always ``curr[off_left:]``.
    When unaligned, BPE does not line up with the full text, so the garbled span's code-point count and ``offset``
    carry no reliable display semantics; collapsing the right bound onto the left only prevents the frontend's
    ``completion_text`` slice check on ``raw`` from failing (zero-width spans are not sliced).
    ``off_left``: 0 on the first step; if the previous step had ``matched == len(curr)``, then ``off_left = matched``;
    if the previous step had ``matched < len(curr)``, freeze ``off_left`` until a fully aligned step appears again.
    Requires ``decode(ids) == completion_text``, otherwise raises.
    """
    ids = [int(t) for t in new_ids.tolist()]
    n_tok = len(ids)

    offsets: List[Tuple[int, int]] = []
    raws: List[str] = []
    off_left = 0

    # Decode the full prefix ``ids[:i+1]`` at every step; the repeated slicing is required by the semantics, not an oversight.
    for i in range(n_tok):
        curr = tokenizer.decode(ids[: i + 1], skip_special_tokens=skip)
        matched = _lcp_prefix_len(curr, completion_text)
        curr_len = len(curr)
        raw = curr[off_left:]
        # Unaligned: garbled length and offset carry no reliable meaning; right bound = left bound so frontend text[s:e]==raw checks do not fail.
        if matched < curr_len:
            off = (off_left, off_left)
        else:
            off = (off_left, curr_len)
        # _verbose_incremental_offset_step(
        #     step_1based=i + 1,
        #     n_tokens=n_tok,
        #     token_id=ids[i],
        #     tokenizer=tokenizer,
        #     skip=skip,
        #     offset=off,
        #     matched=matched,
        #     curr_len=curr_len,
        #     raw=raw,
        # )
        offsets.append(off)
        raws.append(raw)
        if matched == len(curr):
            off_left = matched

    full = tokenizer.decode(ids, skip_special_tokens=skip)
    if full != completion_text:
        _print_full_decode_text_mismatch(full, completion_text)
        raise RuntimeError(
            "decode(ids) of the completion span disagrees with completion_text; cannot fill decoder-coordinate offset/raw."
        )
    return offsets, raws


def _build_generated_bpe_strings(
    tokenizer,
    new_ids: torch.Tensor,
    scores_logits: torch.Tensor,
    top_k: int,
    completion_text: str,
) -> List[Dict[str, Any]]:
    """
    Information-density-style entry for every generated token in the completion span: offset/raw (relative to the full completion text), real_topk, pred_topk.

    new_ids: 1D int64, must already be on the CPU, identical to generate's output.
    scores_logits: float, shape ``[n, vocab]``, must already be on the CPU (avoids per-step GPU softmax / .item() round-trips).
    completion_text: the completion text produced with the same parameters as ``tokenizer.decode(new_ids, skip_special_tokens=...)`` (the caller already decoded once; don't repeat it).
    If re-encoding the whole span reproduces ``new_ids``, use ``offset_mapping`` (fast path; offsets index into ``completion_text``);
    otherwise use incremental decode (slow path): when the LCP does not cover the prefix, ``offset`` is ``(off_left, off_left)`` (see that function's docstring: mainly to keep frontend slice checks from failing), otherwise ``(off_left, len(curr))``; ``raw`` is always ``curr[off_left:]``.
    """
    n = int(new_ids.numel())
    if n == 0:
        return []
    if scores_logits.dim() != 2 or scores_logits.shape[0] != n:
        raise RuntimeError(
            f"scores_logits shape disagrees with new_ids: scores_logits.shape={tuple(scores_logits.shape)}, n={n}"
        )
    top_k = min(top_k, int(scores_logits.shape[-1]))
    new_cpu = new_ids.detach().cpu()
    skip = _COMPLETION_DECODE_SKIP_SPECIAL

    enc = tokenizer(
        completion_text,
        return_tensors="pt",
        return_offsets_mapping=True,
        add_special_tokens=False,
    )
    reencoded = enc["input_ids"][0]
    ids_match = reencoded.shape == new_cpu.shape and torch.equal(reencoded, new_cpu)

    incremental_raws: Optional[List[str]]
    if ids_match:
        offset_mapping = enc["offset_mapping"][0].tolist()
        incremental_raws = None
    else:
        if reencoded.shape != new_cpu.shape:
            _warn_decode_reencode_length_mismatch(new_cpu, reencoded)
        else:
            diff = reencoded != new_cpu
            first = int(torch.where(diff)[0][0].item())
            _warn_decode_reencode_mismatch(
                tokenizer,
                n=n,
                mismatch_count=int(diff.sum().item()),
                first=first,
                new_cpu=new_cpu,
                reencoded=reencoded,
            )
        print("Fell back to the incremental-decode alignment path; results are unaffected.", flush=True)
        offset_mapping, incremental_raws = _completion_incremental_offsets_and_raws(
            tokenizer, new_cpu, completion_text, skip=skip
        )

    out: List[Dict[str, Any]] = []
    for step in range(n):
        logits = scores_logits[step]
        probs = torch.softmax(logits, dim=-1)
        tid = int(new_ids[step].item())
        s, e = offset_mapping[step]
        if incremental_raws is not None:
            raw = incremental_raws[step]
        else:
            raw = completion_text[s:e] if s < e else ""
        out.append(
            {
                "offset": [s, e],
                "raw": raw,
                "real_topk": [0, round_to_sig_figs(float(probs[tid].item()))],
                "pred_topk": pred_topk_pairs_from_probs_1d(probs, tokenizer, top_k),
            }
        )
    return out


def core_generate_from_text(
    formatted_text: str,
    *,
    stream_delta: Optional[Callable[[str, bool], None]] = None,
    max_tokens: Optional[int] = None,
) -> Tuple[str, str, int, int, List[Dict[str, Any]], Optional[float]]:
    """
    Run autoregressive completion on a finalized model input string (greedy by default; the in-function ``_use_low_temp_sampling`` can temporarily switch to low-temperature sampling).

    After encoding, the prompt token count must stay under the context limit; the number of completion steps is bounded by the remaining context and by the optional ``max_tokens``.

    Cancellation conditions are in ``completion_cancel_requested()`` (process signal; global stop including user Stop / wall-clock timeout).

    Args:
        stream_delta: optional; called additionally if provided (e.g. SSE). Local verbose printing is controlled separately by ``_print_completion_stream_delta``, regardless of whether stream_delta is passed.
        max_tokens: optional positive integer capping how many new tokens may be generated (takes ``min(max_tokens, limit − prompt)``). If omitted, the full remaining context budget is used.

    Returns:
        (completion text, finish_reason, prompt_tokens, completion_tokens, completion-span bpe_strings, ttft_s).
        ttft_s is the seconds from the start of ``model.generate`` to the first completion fragment; it is ``None`` only on cancellation.
    """
    tokenizer, model, device = ensure_semantic_slot_ready()
    ctx_limit = completion_max_token_length

    model.eval()
    enc = tokenizer(formatted_text, return_tensors="pt")
    input_ids = enc["input_ids"].to(device)
    input_len = input_ids.shape[1]
    n = int(input_len)
    if n >= ctx_limit:
        raise PromptTooLongError(
            "Prompt too long: "
            f"{n} tokens (context limit is {ctx_limit} tokens; prompt plus completion must not exceed this limit)."
        )

    remaining = ctx_limit - n
    if max_tokens is None:
        effective_max_new = remaining
    else:
        effective_max_new = min(max_tokens, remaining)
    if get_verbose():
        print(
            f"📌 completion: inference input (tokens={input_len}, ctx_limit={ctx_limit}, max_new={effective_max_new}):\n"
            f"{formatted_text}",
            end="",  # no newline, so the inference output prints right after
        )

    prompt_tokens = int(input_len)
    # Mainly guards against: the user cancelled while queued on the inference lock;
    # short-circuit here after acquiring it instead of entering generate pointlessly.
    # Wall clock / process signals are rarer cases.
    if completion_cancel_requested():
        return _completion_without_generate(prompt_tokens)

    try:
        base_on_delta = _compose_stream_delta(stream_delta)
        ttft_seconds: Optional[float] = None
        gen_start_t0 = 0.0

        def on_delta_with_ttft(text: str, stream_end: bool) -> None:
            nonlocal ttft_seconds
            if ttft_seconds is None:
                ttft_seconds = time.perf_counter() - gen_start_t0
            base_on_delta(text, stream_end)

        streamer = _DeltaTextStreamer(
            tokenizer,
            on_delta_with_ttft,
            skip_prompt=True,
            skip_special_tokens=_COMPLETION_DECODE_SKIP_SPECIAL,
        )
        # Temporary experiment: set True to enable low-temperature sampling; the default False is greedy decoding (reproducible).
        _use_low_temp_sampling = False
        _low_temperature = 0.2

        gen_kw: Dict[str, Any] = {
            "input_ids": input_ids,
            "max_new_tokens": effective_max_new,
            "return_dict_in_generate": True,
            "output_scores": True,
            "streamer": streamer,
            "stopping_criteria": StoppingCriteriaList([_CancelOnEventStoppingCriteria()]),
        }
        if _use_low_temp_sampling:
            gen_kw["do_sample"] = True
            gen_kw["temperature"] = _low_temperature
        else:
            gen_kw["do_sample"] = False

        gen_start_t0 = time.perf_counter()
        with torch.inference_mode():
            outputs = model.generate(**gen_kw)
            if device.type == "cuda":
                torch.cuda.synchronize(device)
            elif device.type == "mps":
                torch.mps.synchronize()

        gen = outputs.sequences
        new_ids = gen[0, input_len:].detach().cpu().contiguous()
        text = tokenizer.decode(new_ids, skip_special_tokens=_COMPLETION_DECODE_SKIP_SPECIAL)

        if outputs.scores is None:
            raise RuntimeError("model.generate did not return scores (output_scores=True is required)")

        if new_ids.numel() == 0:
            bpe_strings: List[Dict[str, Any]] = []
        else:
            # float32 logits of shape [len, vocab_size]
            # memory cost: 1000 tokens x qwen 150k vocab ~= 600MB
            scores_cpu = _stack_scores_to_cpu(outputs.scores)
            bpe_strings = _build_generated_bpe_strings(
                tokenizer, new_ids, scores_cpu, DEFAULT_TOPK, text
            )

        # The streamed deltas were already printed by _print_completion_stream_delta; do not reprint the full text here.
        if completion_cancel_requested():
            # User Stop / process shutdown etc.: when StoppingCriteria ends early, new_ids is usually below the cap.
            # Do not use "stop" (OpenAI semantics mostly mean a natural end), or the frontend would misrender it as EOS.
            finish_reason = "abort"
        else:
            finish_reason = "length" if new_ids.numel() >= effective_max_new else "stop"
        prompt_tokens = int(input_len)
        completion_tokens = int(new_ids.numel())
        return text, finish_reason, prompt_tokens, completion_tokens, bpe_strings, ttft_seconds
    finally:
        DeviceManager.clear_cache(device)


def apply_chat_template_for_completion(
    user_content: str,
    system: Optional[str] = None,
) -> str:
    """
    Apply the tokenizer chat template to a single user text and return the string actually fed to core_generate_from_text.

    When the caller passes no ``system`` (i.e. ``None``), only the single user message is assembled; when a string is
    passed (including ``""`` or whitespace-only), it is used verbatim as the chat template's system segment, without
    trimming or rewriting. Length and the context limit are validated by ``core_generate_from_text`` before generation.
    """
    tokenizer, _, _ = ensure_semantic_slot_ready()
    if system is None:
        messages = [{"role": "user", "content": user_content}]
    else:
        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": user_content},
        ]
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )


def generate_completion_text(
    prompt: str,
    stream_delta: Optional[Callable[[str, bool], None]] = None,
    *,
    max_tokens: Optional[int] = None,
) -> Tuple[str, str, int, int, List[Dict[str, Any]], Optional[float]]:
    """
    ``prompt`` must be the finalized, complete model input (no chat template is applied server-side anymore).

    Pass stream_delta for streaming; cancellation is decided uniformly by ``completion_cancel_requested()``.
    ``max_tokens`` is the optional positive-integer completion cap (matching the API contract).
    """
    return core_generate_from_text(prompt, stream_delta=stream_delta, max_tokens=max_tokens)
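The off_left bookkeeping in _completion_incremental_offsets_and_raws is easiest to see on a toy decode sequence. A self-contained sketch (no tokenizer involved); the hand-written prefixes simulate a decoder that emits a replacement character until a multi-byte character completes.

# Toy walkthrough of the off_left rule in _completion_incremental_offsets_and_raws.
# The per-step "decoded prefixes" are hand-written: step 1 ends in U+FFFD (an
# incomplete multi-byte sequence), so that step is unaligned with the full text.
full_text = "héllo"
prefixes = ["h", "h\ufffd", "hé", "héllo"]  # stands in for decode(ids[:i+1])

def lcp(a: str, b: str) -> int:
    k = 0
    while k < min(len(a), len(b)) and a[k] == b[k]:
        k += 1
    return k

off_left = 0
for i, curr in enumerate(prefixes):
    matched = lcp(curr, full_text)
    if matched < len(curr):       # unaligned: zero-width offset, raw still exposed
        off = (off_left, off_left)
    else:                         # aligned: offset covers the newly settled text
        off = (off_left, len(curr))
    print(i, off, repr(curr[off_left:]))
    if matched == len(curr):      # off_left advances only on fully aligned steps
        off_left = matched
# prints: 0 (0, 1) 'h' | 1 (1, 1) '�' | 2 (1, 2) 'é' | 3 (2, 5) 'llo'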
backend/data_utils.py
ADDED
@@ -0,0 +1,97 @@
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Optional

DEFAULT_DATA_DIR = Path(os.path.abspath("data/demo/public"))


def resolve_data_dir(dir_arg: Optional[str]) -> Path:
    """
    Resolve the base data directory from CLI args or fall back to demo/public.
    """
    if dir_arg:
        return Path(dir_arg).expanduser().absolute()
    return DEFAULT_DATA_DIR


def get_demo_dir(data_dir: Path, create: bool = False) -> Path:
    """Return the demo directory under the given data dir, optionally creating it."""
    # By default, data_dir is already the absolute path of data/demo/public
    demo_dir = data_dir
    if create:
        demo_dir.mkdir(parents=True, exist_ok=True)
    return demo_dir


def list_demo_files(demo_dir: Path) -> List[Dict[str, str]]:
    """Return sorted demo metadata from a directory. Missing dirs result in empty list."""
    if not demo_dir.exists():
        return []

    demo_list = []
    for file_path in demo_dir.glob("*.json"):
        demo_list.append(
            {
                "name": file_path.stem,
                "file": file_path.name,
            }
        )
    demo_list.sort(key=lambda item: item["name"])
    return demo_list


def sanitize_demo_name(name: str) -> str:
    """Remove unsafe characters from a demo name to create a safe filename."""
    unsafe_chars = ['/', '\\', ':', '*', '?', '"', '<', '>', '|']
    safe_name = name or ""
    for char in unsafe_chars:
        safe_name = safe_name.replace(char, '_')
    safe_name = safe_name.strip(' .')
    return safe_name[:200]


def save_demo_payload(demo_dir: Path, name: str, data: Dict[str, Any], path: str = "", overwrite: bool = False) -> Dict[str, Any]:
    """
    Persist an AnalyzeResponse payload as a demo JSON file.

    Args:
        demo_dir: absolute path of the demo directory
        name: demo file name (without extension)
        data: the data to save
        path: save path; may be "", "/", or a path starting with "/"; defaults to the root
        overwrite: whether to overwrite an existing file, defaults to False
    """
    from backend.path_utils import resolve_demo_path

    safe_name = sanitize_demo_name(name)
    if not safe_name:
        return {"success": False, "message": "Invalid file name"}

    # Resolve the target path
    target_dir = resolve_demo_path(demo_dir, path)
    if target_dir is None:
        return {"success": False, "message": f"Invalid save path: {path}"}

    # Make sure the target directory exists
    target_dir.mkdir(parents=True, exist_ok=True)
    file_path = target_dir / f"{safe_name}.json"

    # Check whether the file already exists
    if file_path.exists() and not overwrite:
        return {
            "success": False,
            "exists": True,
            "message": f'File "{safe_name}.json" already exists',
            "file": file_path.name,
        }

    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)

    return {
        "success": True,
        "message": f'Demo "{name}" saved successfully',
        "file": file_path.name,
    }
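Two hand-checked examples of sanitize_demo_name's behavior: unsafe characters are replaced with underscores, then leading/trailing spaces and dots are stripped.

# Hand-checked examples of the sanitization rules above.
from backend.data_utils import sanitize_demo_name

assert sanitize_demo_name('a/b:c?.json') == 'a_b_c_.json'  # unsafe chars -> '_'
assert sanitize_demo_name('  report. ') == 'report'        # strip spaces/dots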
backend/demo_folder.py
ADDED
@@ -0,0 +1,339 @@
"""
Demo folder operations module.
Provides listing, moving, renaming, and deleting of folders and files.
"""
import os
import shutil
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

from backend.path_utils import (
    normalize_path,
    check_path_in_demo_dir,
    validate_demo_path,
    resolve_demo_path
)


# ==================== Helpers ====================

def _normalize_path(path: str) -> str:
    """Normalize a path: convert the empty string to "/" (backward-compatible wrapper)."""
    return normalize_path(path)


def _build_api_path(parent_path: str, item_name: str) -> str:
    """Build an API path (uniformly in the leading-"/" format)."""
    if parent_path and parent_path != "/":
        return f"{parent_path}/{item_name}"
    return f"/{item_name}"


def _error_response(message: str) -> Dict[str, Any]:
    """Unified error response format."""
    return {"success": False, "message": message}


def _success_response(message: str) -> Dict[str, Any]:
    """Unified success response format."""
    return {"success": True, "message": message}


def _get_timestamped_name(base_name: str, extension: str = "") -> str:
    """Generate a timestamped name."""
    timestamp = int(time.time())
    return f"{base_name}_{timestamp}{extension}"


def _ensure_deleted_dir(demo_dir: Path) -> Path:
    """Ensure the .deleted directory exists and return its path."""
    deleted_dir = demo_dir.resolve() / '.deleted'
    deleted_dir.mkdir(parents=True, exist_ok=True)
    return deleted_dir


def _validate_json_file(file_path: Path) -> Optional[str]:
    """Validate that the path is an existing JSON file; return an error message or None."""
    if not file_path.exists():
        return "File does not exist"
    if not file_path.is_file():
        return "Path is not a file"
    if file_path.suffix != '.json':
        return "Only JSON files can be operated on"
    return None


def _validate_folder(folder_path: Path) -> Optional[str]:
    """Validate that the folder exists; return an error message or None."""
    if not folder_path.exists():
        return "Folder does not exist"
    if not folder_path.is_dir():
        return "Path is not a folder"
    return None


# ==================== File system operations ====================
# Core path handling has moved to backend/path_utils.py

def list_demo_items(demo_dir: Path, path: str = "") -> Dict[str, Any]:
    """Return the folders and files under the given path, skipping hidden folders."""
    normalized_path = _normalize_path(path)
    target_dir = resolve_demo_path(demo_dir, normalized_path)

    if not target_dir or not target_dir.exists():
        return {"path": normalized_path, "items": []}

    items = []

    try:
        for item_path in target_dir.iterdir():
            if item_path.name.startswith('.'):
                continue

            if item_path.is_dir():
                items.append({
                    "type": "folder",
                    "name": item_path.name,
                    "path": _build_api_path(normalized_path, item_path.name)
                })
            elif item_path.is_file() and item_path.suffix == '.json':
                items.append({
                    "type": "file",
                    "name": item_path.stem,
                    "path": _build_api_path(normalized_path, item_path.name)
                })
    except Exception as e:
        import traceback
        print(f"❌ Failed to scan directory: {e}")
        traceback.print_exc()
        return {"path": normalized_path, "items": []}

    # Sort: folders first, then files, each group sorted by name
    folders = sorted([item for item in items if item["type"] == "folder"], key=lambda x: x["name"])
    files = sorted([item for item in items if item["type"] == "file"], key=lambda x: x["name"])

    return {"path": path, "items": folders + files}


def get_all_folders(demo_dir: Path, exclude_path: Optional[str] = None) -> List[str]:
    """Recursively list all folders (for move operations), skipping hidden folders."""
    folders = []

    def _scan_directory(current_dir: Path, current_path: str):
        """Recursively scan a directory."""
        try:
            for item in current_dir.iterdir():
                if item.name.startswith('.'):
                    continue

                if item.is_dir():
                    folder_path = _build_api_path(current_path, item.name)

                    if exclude_path and (folder_path == exclude_path or folder_path.startswith(exclude_path + "/")):
                        continue

                    folders.append(folder_path)
                    _scan_directory(item, folder_path)
        except Exception as e:
            import traceback
            print(f"❌ Failed to scan folders: {e}")
            traceback.print_exc()

    _scan_directory(demo_dir, "/")
    folders.insert(0, "/")
    return folders


def move_demo_file(demo_dir: Path, source_path: str, target_path: str) -> Dict[str, Any]:
    """Move a demo file."""
    source_file = resolve_demo_path(demo_dir, source_path)
    if not source_file:
        return _error_response(f"Source file does not exist: {source_path}")

    error_msg = _validate_json_file(source_file)
    if error_msg:
        return _error_response(f"Source file {error_msg.lower()}: {source_path}" if "does not exist" not in error_msg else error_msg)

    target_dir = resolve_demo_path(demo_dir, target_path)
    if not target_dir:
        return _error_response(f"Invalid target path: {target_path}")

    target_dir.mkdir(parents=True, exist_ok=True)
    target_file = target_dir / source_file.name

    if target_file.exists() and target_file != source_file:
        return _error_response(f"A file with the same name already exists at the target: {source_file.name}")

    try:
        shutil.move(str(source_file), str(target_file))
        return _success_response(f"File moved to {target_path}")
    except Exception as e:
        return _error_response(f"Move failed: {str(e)}")


def rename_demo_file(demo_dir: Path, file_path: str, new_name: str) -> Dict[str, Any]:
    """Rename a demo file."""
    from backend.data_utils import sanitize_demo_name

    source_file = resolve_demo_path(demo_dir, file_path)
    if not source_file:
        return _error_response(f"File does not exist: {file_path}")

    error_msg = _validate_json_file(source_file)
    if error_msg:
        return _error_response(error_msg)

    safe_name = sanitize_demo_name(new_name)
    if not safe_name:
        return _error_response("Invalid new name")

    target_file = source_file.parent / f"{safe_name}.json"

    if target_file.exists() and target_file != source_file:
        return _error_response(f"File '{safe_name}.json' already exists")

    try:
        source_file.rename(target_file)
        return _success_response(f"File renamed to '{safe_name}.json'")
    except Exception as e:
        return _error_response(f"Rename failed: {str(e)}")


def move_folder(demo_dir: Path, source_path: str, target_path: str) -> Dict[str, Any]:
    """Move a folder (recursively)."""
    source_folder = resolve_demo_path(demo_dir, source_path)
    if not source_folder:
        return _error_response(f"Source folder does not exist: {source_path}")

    error_msg = _validate_folder(source_folder)
    if error_msg:
        return _error_response(f"Source {error_msg.lower()}: {source_path}" if "does not exist" not in error_msg else error_msg)

    target_dir = resolve_demo_path(demo_dir, target_path)
    if not target_dir:
        return _error_response(f"Invalid target path: {target_path}")

    target_dir.mkdir(parents=True, exist_ok=True)
    target_folder = target_dir / source_folder.name

    if target_folder.exists():
        return _error_response(f"A folder with the same name already exists at the target: {source_folder.name}")

    # Check for an attempt to move a folder into its own subdirectory
    if check_path_in_demo_dir(target_folder.resolve(), source_folder.resolve()):
        return _error_response("Cannot move a folder into its own subdirectory")

    try:
        shutil.move(str(source_folder), str(target_folder))
        return _success_response(f"Folder moved to {target_path}")
    except Exception as e:
        return _error_response(f"Move failed: {str(e)}")


def rename_folder(demo_dir: Path, folder_path: str, new_name: str) -> Dict[str, Any]:
    """Rename a folder."""
    from backend.data_utils import sanitize_demo_name

    source_folder = resolve_demo_path(demo_dir, folder_path)
    if not source_folder:
        return _error_response(f"Folder does not exist: {folder_path}")

    error_msg = _validate_folder(source_folder)
    if error_msg:
        return _error_response(error_msg)

    safe_name = sanitize_demo_name(new_name)
    if not safe_name:
        return _error_response("Invalid new name")

    target_folder = source_folder.parent / safe_name
|
| 252 |
+
if target_folder.exists():
|
| 253 |
+
return _error_response(f"文件夹 '{safe_name}' 已存在")
|
| 254 |
+
|
| 255 |
+
try:
|
| 256 |
+
source_folder.rename(target_folder)
|
| 257 |
+
return _success_response(f"文件夹已重命名为 '{safe_name}'")
|
| 258 |
+
except Exception as e:
|
| 259 |
+
return _error_response(f"重命名失败: {str(e)}")
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def create_folder(demo_dir: Path, parent_path: str, folder_name: str) -> Dict[str, any]:
|
| 263 |
+
"""创建新文件夹"""
|
| 264 |
+
from backend.data_utils import sanitize_demo_name
|
| 265 |
+
|
| 266 |
+
parent_dir = resolve_demo_path(demo_dir, parent_path)
|
| 267 |
+
if not parent_dir:
|
| 268 |
+
return _error_response(f"无效的父路径: {parent_path}")
|
| 269 |
+
|
| 270 |
+
safe_name = sanitize_demo_name(folder_name)
|
| 271 |
+
if not safe_name:
|
| 272 |
+
return _error_response("文件夹名称无效")
|
| 273 |
+
|
| 274 |
+
target_folder = parent_dir / safe_name
|
| 275 |
+
|
| 276 |
+
if target_folder.exists():
|
| 277 |
+
return _error_response(f"文件夹 '{safe_name}' 已存在")
|
| 278 |
+
|
| 279 |
+
try:
|
| 280 |
+
target_folder.mkdir(parents=True, exist_ok=False)
|
| 281 |
+
return _success_response(f"文件夹 '{safe_name}' 已创建")
|
| 282 |
+
except Exception as e:
|
| 283 |
+
return _error_response(f"创建失败: {str(e)}")
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def delete_folder(demo_dir: Path, folder_path: str) -> Dict[str, any]:
|
| 287 |
+
"""删除文件夹(移动到 .deleted 隐藏目录)"""
|
| 288 |
+
source_folder = resolve_demo_path(demo_dir, folder_path)
|
| 289 |
+
if not source_folder:
|
| 290 |
+
return _error_response(f"文件夹不存在: {folder_path}")
|
| 291 |
+
|
| 292 |
+
error_msg = _validate_folder(source_folder)
|
| 293 |
+
if error_msg:
|
| 294 |
+
return _error_response(error_msg)
|
| 295 |
+
|
| 296 |
+
deleted_dir = _ensure_deleted_dir(demo_dir)
|
| 297 |
+
target_folder = deleted_dir / source_folder.name
|
| 298 |
+
|
| 299 |
+
if target_folder.exists():
|
| 300 |
+
target_folder = deleted_dir / _get_timestamped_name(source_folder.name)
|
| 301 |
+
|
| 302 |
+
try:
|
| 303 |
+
shutil.move(str(source_folder), str(target_folder))
|
| 304 |
+
return _success_response("文件夹已移动到 .deleted 目录")
|
| 305 |
+
except Exception as e:
|
| 306 |
+
return _error_response(f"删除失败: {str(e)}")
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def delete_demo_file(demo_dir: Path, file_path: str) -> Dict[str, any]:
|
| 310 |
+
"""删除demo文件(移动到 .deleted 隐藏目录)"""
|
| 311 |
+
demo_dir_resolved = demo_dir.resolve()
|
| 312 |
+
source_file = resolve_demo_path(demo_dir_resolved, file_path)
|
| 313 |
+
|
| 314 |
+
if not source_file:
|
| 315 |
+
return _error_response(f"文件不存在: {file_path}")
|
| 316 |
+
|
| 317 |
+
error_msg = _validate_json_file(source_file)
|
| 318 |
+
if error_msg:
|
| 319 |
+
return _error_response(error_msg)
|
| 320 |
+
|
| 321 |
+
try:
|
| 322 |
+
relative_path = source_file.relative_to(demo_dir_resolved)
|
| 323 |
+
except ValueError:
|
| 324 |
+
return _error_response("无效的文件路径")
|
| 325 |
+
|
| 326 |
+
deleted_dir = _ensure_deleted_dir(demo_dir_resolved)
|
| 327 |
+
target_file = deleted_dir / relative_path
|
| 328 |
+
target_parent = target_file.parent
|
| 329 |
+
target_parent.mkdir(parents=True, exist_ok=True)
|
| 330 |
+
|
| 331 |
+
if target_file.exists():
|
| 332 |
+
target_file = target_parent / _get_timestamped_name(source_file.stem, ".json")
|
| 333 |
+
|
| 334 |
+
try:
|
| 335 |
+
shutil.move(str(source_file), str(target_file))
|
| 336 |
+
return _success_response(f"文件已移动到 .deleted 目录: {relative_path.as_posix()}")
|
| 337 |
+
except Exception as e:
|
| 338 |
+
return _error_response(f"删除失败: {str(e)}")
|
| 339 |
+
|
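A minimal usage sketch of the file-manager helpers above (not part of the diff). It assumes a local data/demo root; the exact dict shapes come from _success_response / _error_response, defined earlier in this file, and "/sample.json" is hypothetical:

from pathlib import Path

demo_dir = Path("data/demo")  # assumption: the app's demo root

print(list_demo_items(demo_dir))                    # {"path": "/", "items": [...]}
print(create_folder(demo_dir, "/", "experiments"))  # succeeds on the first call
print(move_demo_file(demo_dir, "/sample.json", "/experiments"))
print(get_all_folders(demo_dir, exclude_path="/experiments"))  # ["/", ...]
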
backend/device.py
ADDED
@@ -0,0 +1,97 @@
"""Device management: CPU/CUDA/MPS detection and memory statistics."""

import os
import torch


class DeviceManager:
    """Device-management utility class that centralizes device-related operations."""

    @staticmethod
    def clear_cache(device: torch.device) -> None:
        """Clear the device cache."""
        if device.type == "cuda":
            torch.cuda.empty_cache()
        elif device.type == "mps":
            torch.mps.empty_cache()

    @staticmethod
    def synchronize(device: torch.device) -> None:
        """Synchronize pending device operations."""
        if device.type == "cuda":
            torch.cuda.synchronize()
        elif device.type == "mps":
            torch.mps.synchronize()

    @staticmethod
    def get_device() -> torch.device:
        """
        Pick the compute device.
        Priority: 1. FORCE_CPU=1 forces CPU  2. cuda > mps > cpu
        """
        if os.environ.get('FORCE_CPU') == '1':
            return torch.device("cpu")
        if torch.cuda.is_available():
            return torch.device("cuda")
        if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
            return torch.device("mps")
        return torch.device("cpu")

    @staticmethod
    def get_device_name(device: torch.device) -> str:
        """Return a display name for the device."""
        if device.type == "cuda":
            return "GPU"
        elif device.type == "mps":
            return "Apple Silicon"
        else:
            return "CPU"

    @staticmethod
    def print_model_load_stats(model: torch.nn.Module, load_time: float) -> None:
        """Print model-loading statistics (size, time, speed)."""
        # Compute the model size
        model_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
        model_size_mb = model_size_bytes / (1024 * 1024)
        # Compute the loading speed
        load_speed_mb_per_sec = model_size_mb / load_time if load_time > 0 else 0
        # Format the size
        size_str = f"{model_size_mb:.1f}MB" if model_size_mb < 1024 else f"{model_size_mb / 1024:.2f}GB"
        # Format the time
        if load_time < 1:
            time_str = f"{load_time * 1000:.1f}ms"
        elif load_time < 60:
            time_str = f"{load_time:.2f}s"
        else:
            time_str = f"{int(load_time // 60)}m{load_time % 60:.1f}s"
        print(f"✅ Model loaded [size: {size_str}, time: {time_str}, speed: {load_speed_mb_per_sec:.1f}MB/s]")

    @staticmethod
    def print_cuda_memory_summary(title="GPU memory statistics", device=0):
        """Print detailed CUDA memory statistics."""
        if not torch.cuda.is_available():
            return
        print(f"\n{'='*60}")
        print(f"🔍 {title}")
        print(f"{'='*60}")
        # Basic statistics
        allocated = torch.cuda.memory_allocated(device) / 1024**3
        reserved = torch.cuda.memory_reserved(device) / 1024**3
        max_allocated = torch.cuda.max_memory_allocated(device) / 1024**3
        total = torch.cuda.get_device_properties(device).total_memory / 1024**3
        print(f"📊 Total VRAM: {total:.2f} GB")
        print(f"✅ Allocated: {allocated:.2f} GB ({allocated/total*100:.1f}%)")
        print(f"📦 Reserved: {reserved:.2f} GB ({reserved/total*100:.1f}%)")
        print(f"📈 Peak allocated: {max_allocated:.2f} GB")
        print(f"💚 Free: {total - reserved:.2f} GB ({(total-reserved)/total*100:.1f}%)")
        print(f"🔸 Fragmentation: {reserved - allocated:.2f} GB")
        # Detailed statistics (abbreviated)
        try:
            stats = torch.cuda.memory_stats(device)
            num_allocs = stats.get("num_alloc_retries", 0)
            num_ooms = stats.get("num_ooms", 0)
            if num_allocs > 0 or num_ooms > 0:
                print(f"⚠️ Allocation retries: {num_allocs}, OOMs: {num_ooms}")
        except Exception:
            pass
        print(f"{'='*60}\n")

ADDED
|
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import gc
|
| 3 |
+
from typing import Callable, Dict, List, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
from .api.utils import round_to_sig_figs
|
| 6 |
+
from .pred_topk_format import pred_topk_pairs_from_flat_ids_and_probs
|
| 7 |
+
from .class_register import register_model, REGISTERED_MODELS
|
| 8 |
+
from .device import DeviceManager
|
| 9 |
+
from .model_manager import ensure_model_loaded
|
| 10 |
+
from .runtime_config import load_runtime_config, DEFAULT_TOPK
|
| 11 |
+
from model_paths import DEFAULT_MODEL, MODEL_PATHS, SEMANTIC_MODEL_PATHS, resolve_hf_path
|
| 12 |
+
|
| 13 |
+
# 按 id(model) 缓存「仅含 BOS/等价起始符一步 forward」得到的末位词表 logits(全词表,不随分析文本变)
|
| 14 |
+
_bos_first_position_logits_cache: Dict[int, torch.Tensor] = {}
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def compute_first_token_lm_with_bos_prefix_cache(
|
| 18 |
+
model: torch.nn.Module,
|
| 19 |
+
tokenizer,
|
| 20 |
+
device: torch.device,
|
| 21 |
+
first_token_id: int,
|
| 22 |
+
effective_topk: int,
|
| 23 |
+
) -> Tuple[float, List[Tuple[str, float]]]:
|
| 24 |
+
"""
|
| 25 |
+
首 token 无左文时的 workaround:与旧版 BOS 前缀一致,对单 token 输入 [bos] 做一步 forward,
|
| 26 |
+
将末位 logits(预测首段文本第一个 token 的分布)缓存到 CPU,再在 CPU 上 softmax/topk。
|
| 27 |
+
|
| 28 |
+
同一 model 实例复用同一份词表 logits,不在每次分析时重复 forward。
|
| 29 |
+
"""
|
| 30 |
+
mid = id(model)
|
| 31 |
+
if mid not in _bos_first_position_logits_cache:
|
| 32 |
+
if tokenizer.bos_token_id is not None:
|
| 33 |
+
bos_id = int(tokenizer.bos_token_id)
|
| 34 |
+
elif tokenizer.eos_token_id is not None:
|
| 35 |
+
bos_id = int(tokenizer.eos_token_id)
|
| 36 |
+
else:
|
| 37 |
+
bos_id = 0
|
| 38 |
+
with torch.inference_mode():
|
| 39 |
+
bos_in = torch.tensor([[bos_id]], device=device, dtype=torch.long)
|
| 40 |
+
out = model(input_ids=bos_in)
|
| 41 |
+
# [V]:在 BOS 条件下预测「第一个文本 token」的分布
|
| 42 |
+
row = out.logits[0, -1, :].detach().float()
|
| 43 |
+
_bos_first_position_logits_cache[mid] = row.cpu()
|
| 44 |
+
|
| 45 |
+
logits = _bos_first_position_logits_cache[mid]
|
| 46 |
+
probs = torch.softmax(logits, dim=-1)
|
| 47 |
+
p = float(probs[first_token_id].item())
|
| 48 |
+
|
| 49 |
+
topk_vals, topk_inds = torch.topk(probs, k=min(effective_topk, probs.shape[0]), dim=-1)
|
| 50 |
+
topk_vals = topk_vals.float().numpy()
|
| 51 |
+
topk_inds_flat = topk_inds.flatten().tolist()
|
| 52 |
+
topk_tokens_decoded = tokenizer.batch_decode(
|
| 53 |
+
[[tid] for tid in topk_inds_flat],
|
| 54 |
+
skip_special_tokens=False,
|
| 55 |
+
)
|
| 56 |
+
pred_topk = [
|
| 57 |
+
(topk_tokens_decoded[j], round_to_sig_figs(float(topk_vals[j])))
|
| 58 |
+
for j in range(len(topk_tokens_decoded))
|
| 59 |
+
]
|
| 60 |
+
return p, pred_topk
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class AbstractLanguageChecker:
|
| 64 |
+
"""
|
| 65 |
+
Abstract Class that defines the Backend API of GLTR.
|
| 66 |
+
|
| 67 |
+
To extend the GLTR interface, you need to inherit this and
|
| 68 |
+
fill in the defined functions.
|
| 69 |
+
"""
|
| 70 |
+
|
| 71 |
+
def __init__(self):
|
| 72 |
+
"""
|
| 73 |
+
In the subclass, you need to load all necessary components
|
| 74 |
+
for the other functions.
|
| 75 |
+
Typically, this will comprise a tokenizer and a model.
|
| 76 |
+
"""
|
| 77 |
+
self.device = DeviceManager.get_device()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def analyze_text(self, in_text):
|
| 81 |
+
"""
|
| 82 |
+
Function that GLTR interacts with to analyze text and get token probabilities
|
| 83 |
+
|
| 84 |
+
Params:
|
| 85 |
+
- in_text: str -- The text that you want to analyze
|
| 86 |
+
- topk: int, optional -- Desired pred_topk count (default from runtime_config.DEFAULT_TOPK)
|
| 87 |
+
|
| 88 |
+
Output:
|
| 89 |
+
- payload: dict -- The wrapper for results in this function, described below
|
| 90 |
+
|
| 91 |
+
Payload values
|
| 92 |
+
==============
|
| 93 |
+
bpe_strings: list of dict -- Each dict contains {"offset": [start, end], "raw": str,
|
| 94 |
+
"real_topk": [rank, prob], "pred_topk": [(token, prob), ...]}
|
| 95 |
+
- offset: character offsets in the original text [start, end]
|
| 96 |
+
- raw: token text extracted from original text
|
| 97 |
+
- real_topk: (ranking, prob) of each token(优先级默认0)
|
| 98 |
+
- pred_topk: top-k 候选列表(若不可用则为空数组)
|
| 99 |
+
"""
|
| 100 |
+
raise NotImplementedError
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@register_model(name='qwen2.5-0.5b')
|
| 105 |
+
class QwenLM(AbstractLanguageChecker):
|
| 106 |
+
"""
|
| 107 |
+
Qwen 系列模型支持
|
| 108 |
+
默认使用 Qwen2.5-0.5B Base 模型(适合计算 surprisal 和信息量)
|
| 109 |
+
"""
|
| 110 |
+
def __init__(self, model_path=None, model_name=None):
|
| 111 |
+
super(QwenLM, self).__init__()
|
| 112 |
+
model_name = model_name or getattr(self.__class__, '_registered_model_name', DEFAULT_MODEL)
|
| 113 |
+
if model_path is not None and str(model_path).strip():
|
| 114 |
+
resolved = str(model_path).strip()
|
| 115 |
+
else:
|
| 116 |
+
resolved = resolve_hf_path(model_name)
|
| 117 |
+
|
| 118 |
+
# 加载运行时配置(支持部分覆盖)
|
| 119 |
+
self._load_runtime_config(model_name)
|
| 120 |
+
|
| 121 |
+
self.tokenizer, self.model, self.device = ensure_model_loaded(resolved)
|
| 122 |
+
|
| 123 |
+
# ============================================================
|
| 124 |
+
# 关于 torch.compile() 的性能优化讨论结论:
|
| 125 |
+
#
|
| 126 |
+
# CPU 环境:
|
| 127 |
+
# - 成本 > 收益,不推荐使用
|
| 128 |
+
#
|
| 129 |
+
# CUDA 环境(如果未来升级到 GPU Space):
|
| 130 |
+
# - 加速比:30-70%(显著提升)
|
| 131 |
+
# - 编译时间:相对推理时间更短
|
| 132 |
+
# - Triton 内核优化:显著减少显存读写
|
| 133 |
+
# - 结论:强烈推荐使用,需配合预热确保形状覆盖
|
| 134 |
+
# 如需启用,可在此处添加:
|
| 135 |
+
# if torch.cuda.is_available() and hasattr(torch, 'compile'):
|
| 136 |
+
# self.model = torch.compile(self.model, mode="default")
|
| 137 |
+
# # 并在启动时运行预热推理覆盖 chunk_size 长度
|
| 138 |
+
# ============================================================
|
| 139 |
+
|
| 140 |
+
# 初始化分析计数器(用于控制GPU内存统计打印频率)
|
| 141 |
+
self._analysis_count = 0
|
| 142 |
+
|
| 143 |
+
def _load_runtime_config(self, model_name: Optional[str]):
|
| 144 |
+
"""
|
| 145 |
+
加载运行时配置:基于模型和平台的四层配置合并
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
model_name: 模型标识符(如 'qwen3-1.7b')
|
| 149 |
+
"""
|
| 150 |
+
# 调用配置模块的完整加载流程
|
| 151 |
+
# 返回: (platform, max_token_length, chunk_size)
|
| 152 |
+
self.platform, self.max_length, self.chunk_size = load_runtime_config(
|
| 153 |
+
model_name=model_name or "default_model"
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
def _encode_text(self, in_text: str) -> Tuple[torch.Tensor, List[Tuple[int, int]]]:
|
| 157 |
+
"""编码文本并返回 token_ids 和 offsets"""
|
| 158 |
+
# 使用 tokenizer 的原生截断功能
|
| 159 |
+
enc_out = self.tokenizer(
|
| 160 |
+
in_text,
|
| 161 |
+
return_tensors='pt',
|
| 162 |
+
return_offsets_mapping=True,
|
| 163 |
+
max_length=self.max_length,
|
| 164 |
+
truncation=True
|
| 165 |
+
)
|
| 166 |
+
token_ids = enc_out['input_ids']
|
| 167 |
+
token_offsets = enc_out['offset_mapping'][0].tolist()
|
| 168 |
+
|
| 169 |
+
# 通过最后一个 offset 和文本长度对比判断是否截断
|
| 170 |
+
if token_offsets:
|
| 171 |
+
last_offset_end = token_offsets[-1][1]
|
| 172 |
+
if last_offset_end < len(in_text):
|
| 173 |
+
# 文本被截断了,警告token截断信息,和字数截断信息
|
| 174 |
+
print(f"⚠️ 文本过长,已截断至前 {self.max_length} token ({len(in_text)} char -> {last_offset_end} char)")
|
| 175 |
+
|
| 176 |
+
token_ids = token_ids.to(self.device)
|
| 177 |
+
|
| 178 |
+
return token_ids, token_offsets
|
| 179 |
+
|
| 180 |
+
def _run_inference_and_process_chunked(
|
| 181 |
+
self,
|
| 182 |
+
token_ids: torch.Tensor,
|
| 183 |
+
effective_topk: int,
|
| 184 |
+
progress_callback: Optional[Callable[[int, int, str, Optional[int]], None]] = None
|
| 185 |
+
) -> Tuple[List[List[Tuple[str, float]]], List[float]]:
|
| 186 |
+
"""
|
| 187 |
+
分块推理并即时处理:核心内存优化逻辑
|
| 188 |
+
利用 KV Cache 分段计算 Logits,计算完立即释放,避免保留全量 Logits。
|
| 189 |
+
|
| 190 |
+
数值说明:在 float16(如 MPS)上,在「仅前缀 forward」vs「整段 forward」同位置 logits 的逐元素对比,可能出现微小差异;
|
| 191 |
+
float16(MPS/CUDA)可能因实现路径出现约 1%的 量级差,非掩码错误。CPU float32 下则完全一致。
|
| 192 |
+
"""
|
| 193 |
+
seq_len = token_ids.shape[1]
|
| 194 |
+
|
| 195 |
+
# 使用初始化时根据平台确定的 chunk_size
|
| 196 |
+
chunk_size = self.chunk_size
|
| 197 |
+
|
| 198 |
+
real_probs_list = []
|
| 199 |
+
pred_topk_list = []
|
| 200 |
+
past_key_values = None
|
| 201 |
+
|
| 202 |
+
# 预先清理
|
| 203 |
+
DeviceManager.clear_cache(self.device)
|
| 204 |
+
|
| 205 |
+
full_input_ids = token_ids
|
| 206 |
+
|
| 207 |
+
# 因果 LM:logits[i] 预测 input_ids[i+1];首 token 无左文,不在此循环中计分
|
| 208 |
+
|
| 209 |
+
# 我们使用 past_key_values 增量推理
|
| 210 |
+
# 第一次:输入 input_ids[:, :chunk_size],输出 logits 对应位置 0..chunk_size-1 (预测 1..chunk_size)
|
| 211 |
+
|
| 212 |
+
total_chunks = (seq_len + chunk_size - 1) // chunk_size
|
| 213 |
+
|
| 214 |
+
with torch.inference_mode():
|
| 215 |
+
for i in range(total_chunks):
|
| 216 |
+
start_idx = i * chunk_size
|
| 217 |
+
end_idx = min((i + 1) * chunk_size, seq_len)
|
| 218 |
+
current_chunk_len = end_idx - start_idx
|
| 219 |
+
|
| 220 |
+
# 准备输入(统一逻辑,避免边界 token 重复)
|
| 221 |
+
if i == 0:
|
| 222 |
+
input_chunk = full_input_ids[:, :end_idx]
|
| 223 |
+
else:
|
| 224 |
+
input_chunk = full_input_ids[:, start_idx:end_idx]
|
| 225 |
+
|
| 226 |
+
# 1. 运行推理
|
| 227 |
+
outputs = self.model(
|
| 228 |
+
input_ids=input_chunk,
|
| 229 |
+
past_key_values=past_key_values,
|
| 230 |
+
use_cache=True
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
past_key_values = outputs.past_key_values
|
| 234 |
+
|
| 235 |
+
logits = outputs.logits
|
| 236 |
+
|
| 237 |
+
# 获取 targets
|
| 238 |
+
# full_input_ids[:, 1:] 是所有 targets
|
| 239 |
+
# 当前块 targets 范围: [start_idx : end_idx]
|
| 240 |
+
chunk_targets = full_input_ids[:, 1+start_idx : 1+end_idx]
|
| 241 |
+
valid_len = chunk_targets.shape[1]
|
| 242 |
+
if valid_len == 0:
|
| 243 |
+
continue
|
| 244 |
+
# 最后一块覆盖到序列末尾时,最后一个 logit 位预测的是「下一 token」,需裁掉
|
| 245 |
+
current_logits = logits[:, :valid_len, :]
|
| 246 |
+
|
| 247 |
+
# 2. 处理当前块的 Softmax 和 TopK
|
| 248 |
+
probs_chunk = torch.softmax(current_logits, dim=2)
|
| 249 |
+
|
| 250 |
+
# 提取真实概率
|
| 251 |
+
chunk_target_probs = torch.gather(probs_chunk, 2, chunk_targets.unsqueeze(-1))
|
| 252 |
+
real_probs_list.extend(chunk_target_probs.flatten().detach().cpu().float().numpy().tolist())
|
| 253 |
+
|
| 254 |
+
# 提取 TopK
|
| 255 |
+
# 由于 chunk_size 已确保小于 MPS_TOPK_BUG_THRESHOLD,所以直接计算
|
| 256 |
+
topk_vals, topk_inds = torch.topk(probs_chunk, k=effective_topk, dim=2)
|
| 257 |
+
chunk_pred_topk = self._decode_topk_tokens(
|
| 258 |
+
topk_vals, topk_inds, effective_topk, valid_len
|
| 259 |
+
)
|
| 260 |
+
pred_topk_list.extend(chunk_pred_topk)
|
| 261 |
+
|
| 262 |
+
# 3. 立即释放内存
|
| 263 |
+
del logits
|
| 264 |
+
del current_logits
|
| 265 |
+
del probs_chunk
|
| 266 |
+
del chunk_target_probs
|
| 267 |
+
# outputs 会在下一次循环时被覆盖,无需手动处理
|
| 268 |
+
|
| 269 |
+
# 进度更新(基于实际处理的 token 数量)
|
| 270 |
+
if progress_callback:
|
| 271 |
+
pct = int(end_idx / seq_len * 100) # 推理阶段独立的 0-100%
|
| 272 |
+
progress_callback(2, 3, 'inference', pct)
|
| 273 |
+
|
| 274 |
+
# 循环结束,清理 KV Cache
|
| 275 |
+
del past_key_values
|
| 276 |
+
DeviceManager.clear_cache(self.device)
|
| 277 |
+
|
| 278 |
+
return pred_topk_list, real_probs_list
|
| 279 |
+
|
| 280 |
+
def _decode_topk_tokens(
|
| 281 |
+
self,
|
| 282 |
+
topk_prob_values: torch.Tensor,
|
| 283 |
+
topk_prob_inds: torch.Tensor,
|
| 284 |
+
effective_topk: int,
|
| 285 |
+
seq_len: int
|
| 286 |
+
) -> List[List[Tuple[str, float]]]:
|
| 287 |
+
"""解码 TopK tokens 并构建预测列表(长度等于参与 topk 的序列长度)"""
|
| 288 |
+
topk_prob_values_cpu = topk_prob_values[0].detach().cpu().float().numpy()
|
| 289 |
+
topk_prob_inds_flat = topk_prob_inds[0].cpu().flatten().tolist()
|
| 290 |
+
probs_flat = topk_prob_values_cpu.flatten().tolist()
|
| 291 |
+
flat_pairs = pred_topk_pairs_from_flat_ids_and_probs(
|
| 292 |
+
topk_prob_inds_flat, probs_flat, self.tokenizer
|
| 293 |
+
)
|
| 294 |
+
return [
|
| 295 |
+
flat_pairs[i * effective_topk : (i + 1) * effective_topk]
|
| 296 |
+
for i in range(seq_len)
|
| 297 |
+
]
|
| 298 |
+
|
| 299 |
+
def _build_bpe_strings(
|
| 300 |
+
self,
|
| 301 |
+
token_offsets: List[Tuple[int, int]],
|
| 302 |
+
real_topk: List[Tuple[int, float]],
|
| 303 |
+
pred_topk: List[List[Tuple[str, float]]],
|
| 304 |
+
in_text: str
|
| 305 |
+
) -> List[Dict]:
|
| 306 |
+
"""构建最终的 BPE 字符串列表"""
|
| 307 |
+
# 确保长度一致
|
| 308 |
+
min_len = min(len(token_offsets), len(real_topk), len(pred_topk) if pred_topk else len(token_offsets))
|
| 309 |
+
|
| 310 |
+
bpe_strings = []
|
| 311 |
+
for idx in range(min_len):
|
| 312 |
+
start, end = token_offsets[idx]
|
| 313 |
+
raw_text = in_text[start:end] if start < end else ""
|
| 314 |
+
token_payload = {
|
| 315 |
+
"offset": [start, end],
|
| 316 |
+
"raw": raw_text,
|
| 317 |
+
"real_topk": list(real_topk[idx]),
|
| 318 |
+
"pred_topk": pred_topk[idx] if pred_topk else []
|
| 319 |
+
}
|
| 320 |
+
bpe_strings.append(token_payload)
|
| 321 |
+
|
| 322 |
+
return bpe_strings
|
| 323 |
+
|
| 324 |
+
def analyze_text(self, in_text: str, progress_callback: Optional[Callable[[int, int, str, Optional[int]], None]] = None) -> Dict[str, List[Dict]]:
|
| 325 |
+
"""
|
| 326 |
+
计算文本中每个 token 的概率
|
| 327 |
+
|
| 328 |
+
进度回调参数: (step: int, total_steps: int, stage: str, percentage: Optional[int])
|
| 329 |
+
- step: 当前步骤 (1-based)
|
| 330 |
+
- total_steps: 总步骤数 (固定为 3)
|
| 331 |
+
- stage: 阶段名称 (encoding/inference/processing)
|
| 332 |
+
- percentage: 可选的百分比,仅在 inference 阶段提供
|
| 333 |
+
"""
|
| 334 |
+
TOTAL_STEPS = 3
|
| 335 |
+
|
| 336 |
+
try:
|
| 337 |
+
# Step 1: 编码文本
|
| 338 |
+
if progress_callback:
|
| 339 |
+
progress_callback(1, TOTAL_STEPS, 'encoding', None)
|
| 340 |
+
token_ids, token_offsets = self._encode_text(in_text)
|
| 341 |
+
|
| 342 |
+
# Step 2: 分块推理并处理(带百分比进度)
|
| 343 |
+
# 这取代了原来的 _run_model_inference, MPS 流式处理, 和 _process_topk
|
| 344 |
+
|
| 345 |
+
if progress_callback:
|
| 346 |
+
progress_callback(2, 3, 'inference', 0)
|
| 347 |
+
pred_topk, real_topk_probs = self._run_inference_and_process_chunked(
|
| 348 |
+
token_ids, DEFAULT_TOPK, progress_callback
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
# Step 3: 构建结果
|
| 352 |
+
if progress_callback:
|
| 353 |
+
progress_callback(3, TOTAL_STEPS, 'processing', None)
|
| 354 |
+
|
| 355 |
+
if token_ids.shape[1] >= 1:
|
| 356 |
+
p0, pred0 = compute_first_token_lm_with_bos_prefix_cache(
|
| 357 |
+
self.model,
|
| 358 |
+
self.tokenizer,
|
| 359 |
+
self.device,
|
| 360 |
+
int(token_ids[0, 0].item()),
|
| 361 |
+
DEFAULT_TOPK,
|
| 362 |
+
)
|
| 363 |
+
pred_topk.insert(0, pred0)
|
| 364 |
+
real_topk_probs.insert(0, p0)
|
| 365 |
+
|
| 366 |
+
seq_len = len(real_topk_probs)
|
| 367 |
+
real_topk = list(zip([0] * seq_len, [round_to_sig_figs(p) for p in real_topk_probs]))
|
| 368 |
+
|
| 369 |
+
bpe_strings = self._build_bpe_strings(token_offsets, real_topk, pred_topk, in_text)
|
| 370 |
+
|
| 371 |
+
# 最终清理
|
| 372 |
+
DeviceManager.clear_cache(self.device)
|
| 373 |
+
gc.collect()
|
| 374 |
+
|
| 375 |
+
# 更新分析计数器
|
| 376 |
+
self._analysis_count += 1
|
| 377 |
+
|
| 378 |
+
# 打印分析任务完成后的内存统计(第1、11、21...次分析后打印)
|
| 379 |
+
if self.device.type == "cuda" and (self._analysis_count - 1) % 10 == 0:
|
| 380 |
+
device_idx = self.device.index if self.device.index is not None else 0
|
| 381 |
+
DeviceManager.print_cuda_memory_summary(device=device_idx)
|
| 382 |
+
|
| 383 |
+
return {'bpe_strings': bpe_strings}
|
| 384 |
+
|
| 385 |
+
except Exception as e:
|
| 386 |
+
import traceback
|
| 387 |
+
print(f"❌ Error in QwenLM.analyze_text: {e}")
|
| 388 |
+
traceback.print_exc()
|
| 389 |
+
return {'bpe_strings': []}
|
| 390 |
+
|
| 391 |
+
# _cleanup_tensors 方法已被移除,因为不再需要显式清理小张量
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
# ============================================================
|
| 395 |
+
# 自动注册:根据 MODEL_PATHS 与 SEMANTIC_MODEL_PATHS 自动注册所有模型
|
| 396 |
+
# ============================================================
|
| 397 |
+
# 只需要在 model_paths.py 中添加模型路径,即可自动注册
|
| 398 |
+
# 无需手动创建子类,实现 DRY 原则
|
| 399 |
+
def _auto_register_models():
|
| 400 |
+
"""自动注册 MODEL_PATHS 与 SEMANTIC_MODEL_PATHS 中的所有模型"""
|
| 401 |
+
for model_name in (*MODEL_PATHS.keys(), *SEMANTIC_MODEL_PATHS.keys()):
|
| 402 |
+
if model_name not in REGISTERED_MODELS:
|
| 403 |
+
# 动态创建模型类并注册
|
| 404 |
+
# 使用闭包捕获当前 model_name
|
| 405 |
+
def make_init():
|
| 406 |
+
def __init__(self):
|
| 407 |
+
QwenLM.__init__(self)
|
| 408 |
+
return __init__
|
| 409 |
+
|
| 410 |
+
model_class = type(
|
| 411 |
+
f'QwenLM_{model_name.replace(".", "_").replace("-", "_")}',
|
| 412 |
+
(QwenLM,),
|
| 413 |
+
{
|
| 414 |
+
'__init__': make_init(),
|
| 415 |
+
'__doc__': f'{model_name} 模型支持(自动注册)'
|
| 416 |
+
}
|
| 417 |
+
)
|
| 418 |
+
register_model(model_name)(model_class)
|
| 419 |
+
|
| 420 |
+
# 执行自动注册
|
| 421 |
+
_auto_register_models()
|
| 422 |
+
|
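For orientation, a sketch of consuming the analyze_text payload; the surprisal that an information-density view visualizes is just -log2(p) of each real_topk probability. Instantiating via REGISTERED_MODELS['qwen2.5-0.5b'] assumes the registry maps names to the classes registered above:

import math

checker = REGISTERED_MODELS['qwen2.5-0.5b']()  # assumption: name -> class mapping
payload = checker.analyze_text("The quick brown fox jumps over the lazy dog.")

for tok in payload['bpe_strings']:
    rank, p = tok['real_topk']
    bits = -math.log2(p) if p > 0 else float('inf')
    print(f"{tok['raw']!r:>12}  p={p}  surprisal={bits:.2f} bits")
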
backend/load_utils.py
ADDED
@@ -0,0 +1,69 @@
"""HuggingFace model download and loading: downloading is separate; loading only looks at local files."""

import json
import os
from typing import Callable, TypeVar

T = TypeVar("T")

# Matches the transformers checkpoint naming
_SAFE_WEIGHTS = "model.safetensors"
_SAFE_WEIGHTS_INDEX = "model.safetensors.index.json"
_WEIGHTS = "pytorch_model.bin"
_WEIGHTS_INDEX = "pytorch_model.bin.index.json"


def _is_model_cache_complete(local_path: str) -> bool:
    """
    Check locally whether the model weights are complete. Mirrors the logic of
    transformers' _get_resolved_checkpoint_files.
    """
    def _p(f: str) -> str:
        return os.path.join(local_path, f)

    if os.path.isfile(_p(_SAFE_WEIGHTS)):
        return True
    index_file = _p(_SAFE_WEIGHTS_INDEX)
    if os.path.isfile(index_file):
        with open(index_file) as f:
            index = json.load(f)
        shards = set(index.get("weight_map", {}).values())
        return all(os.path.isfile(_p(s)) for s in shards)
    if os.path.isfile(_p(_WEIGHTS)):
        return True
    index_file = _p(_WEIGHTS_INDEX)
    if os.path.isfile(index_file):
        with open(index_file) as f:
            index = json.load(f)
        shards = set(index.get("weight_map", {}).values())
        return all(os.path.isfile(_p(s)) for s in shards)
    return False


def ensure_model_local(model_path: str, *, force_download: bool = False) -> str:
    """
    Make sure the model is available locally and return its local path.
    - Local directory: return as-is
    - HuggingFace ID: prefer the local cache (no network); if the cache is incomplete,
      force_download can trigger a download
    """
    if os.path.isdir(model_path):
        return model_path
    if "/" in model_path and not os.path.exists(model_path):
        from huggingface_hub import snapshot_download
        if force_download:
            return snapshot_download(model_path)
        try:
            path = snapshot_download(model_path, local_files_only=True)
            if not _is_model_cache_complete(path):
                return snapshot_download(model_path)
            return path
        except Exception:
            return snapshot_download(model_path)
    return model_path


def resolve_and_load(model_path: str, loader: Callable[[str, bool], T]) -> T:
    """
    Ensure the model is available locally, then load it. Loading always uses local_files_only=True.
    """
    path = ensure_model_local(model_path)
    return loader(path, True)

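A sketch of the intended cache-first call pattern (the model ID is illustrative):

# Resolve a HuggingFace ID to a local snapshot path; the network is hit only
# when the local cache is missing or incomplete.
local_path = ensure_model_local("Qwen/Qwen2.5-0.5B")

# resolve_and_load pairs that resolution with an offline-only loader.
from transformers import AutoConfig
config = resolve_and_load(
    "Qwen/Qwen2.5-0.5B",
    lambda path, local_only: AutoConfig.from_pretrained(path, local_files_only=local_only),
)
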
backend/logging_config.py
ADDED
@@ -0,0 +1,37 @@
"""
Logging configuration module.
Centralizes the application's logging setup.
"""

import logging


def configure_logging(app=None):
    """
    Configure application logging: silence all connection- and request-related logs.

    Args:
        app: Connexion/Flask application instance (optional)
    """
    # Silence third-party loggers
    logging.getLogger('werkzeug').setLevel(logging.CRITICAL)
    logging.getLogger('connexion').setLevel(logging.CRITICAL)
    logging.getLogger('flask_cors').setLevel(logging.CRITICAL)
    logging.getLogger('flask').setLevel(logging.CRITICAL)
    logging.getLogger('urllib3').setLevel(logging.CRITICAL)
    logging.getLogger('transformers').setLevel(logging.CRITICAL)
    logging.getLogger('torch').setLevel(logging.CRITICAL)

    # Root log level: only show critical errors
    logging.basicConfig(level=logging.CRITICAL, format='%(message)s')

    # Configure the Flask app logger (if an application instance was provided)
    if app:
        try:
            app.app.logger.setLevel(logging.CRITICAL)
            # Disable Werkzeug's access log
            import werkzeug.serving
            werkzeug.serving.WSGIRequestHandler.log_request = lambda *args, **kwargs: None
        except Exception:
            pass

backend/model_loader.py
ADDED
@@ -0,0 +1,169 @@
"""
Causal LM model loading: device strategy and loading logic in one place.

Shared by language_checker.QwenLM (information-density analysis) and
model_manager.ensure_model_loaded, removing the duplicated device branches,
quantization config, and post-load processing.

Loading strategies:
- INT8 quantization: bitsandbytes 8-bit, device_map="cpu"/"auto", roughly 4x less memory
- CPU manual mode: no device_map, .to(device), float32 by default
- GPU/MPS auto mode: device_map="auto", float16

On dtype/device and element-wise comparison of a causal LM's same-position logits
under "prefix-only forward" vs "full-sequence forward": float32 (CPU) usually matches
exactly; float16 (MPS/CUDA) may differ on the order of ~1e-2 depending on the kernel
path, which is not a masking bug. For reproduction and details see
scripts/reproduce_logits_triple_path.py and scripts/prove_fp16_gemm_shape_sensitivity.py.
"""

import os
import time
from typing import Any, Dict, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.utils import is_flash_attn_2_available

from .device import DeviceManager
from .load_utils import resolve_and_load
from .quantization_config import get_quantization_config


def get_device_load_strategy(device: torch.device) -> Dict[str, Any]:
    """
    Infer the loading strategy (device_map, dtype, use_int8, ...) from the device.

    Prints a description of the device mode, matching the QwenLM style.
    Environment variables: FORCE_INT8=1 / CPU_FORCE_BFLOAT16=1
    Returns a parameter dict consumed by load_causal_lm.
    """
    qconfig = get_quantization_config(device)
    use_int8 = qconfig.use_int8
    device_map = None
    dtype = qconfig.dtype
    use_low_cpu_mem = False

    if device.type == "cpu":
        print("🔧 CPU mode: manual device placement")
        if use_int8:
            device_map = "cpu"
            print("⚠️ INT8 quantization enabled (FORCE_INT8=1, experimental; can reduce performance in some cases)")
        elif dtype == torch.bfloat16:
            use_low_cpu_mem = True
            print("⚠️ bfloat16 enabled (CPU_FORCE_BFLOAT16=1; requires AVX-512_BF16 or AMX hardware support, otherwise it may be extremely slow)")
        else:
            use_low_cpu_mem = True
            print("🔧 dtype: float32")  # default: float32
    elif device.type == "cuda":
        print("🔧 CUDA mode: automatic device placement")
        device_map = "auto"
        use_low_cpu_mem = True
        if use_int8:
            print("⚠️ INT8 quantization enabled (FORCE_INT8=1)")
        else:
            print("🔧 dtype: float16")
        print("🔧 device_map: auto")
    else:
        # MPS mode: automatic device placement + float16 (MPS does not support INT8 quantization)
        print(f"🔧 {device.type.upper()} mode: automatic device placement")
        if os.environ.get("FORCE_INT8") == "1":
            print("⚠️ MPS does not support INT8 quantization; the FORCE_INT8=1 environment variable was ignored")
        device_map = "auto"
        use_low_cpu_mem = True
        print("🔧 dtype: float16")
        print("🔧 device_map: auto")

    return {
        "device_map": device_map,
        "dtype": dtype,
        "use_low_cpu_mem": use_low_cpu_mem,
        "use_int8": use_int8,
    }


def attn_implementation_for_device(device: torch.device) -> str:
    """
    Non-CUDA: eager, the most compatible choice (CPU / MPS etc.).
    CUDA: flash_attention_2 when flash-attn is installed; otherwise eager (sdpa is not used).
    """
    if device.type != "cuda":
        return "eager"
    if is_flash_attn_2_available():
        return "flash_attention_2"
    return "eager"


def load_causal_lm(
    model_path: str,
    device: torch.device,
    *,
    attn_implementation: Optional[str] = None,
    extra_model_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.nn.Module:
    """
    Load a causal LM, handling device strategy, quantization, and post-load steps uniformly.

    Args:
        model_path: HuggingFace model path or local path
        device: target device
        attn_implementation: optional; when omitted, callers can use attn_implementation_for_device(device)
        extra_model_kwargs: optional extra kwargs forwarded to from_pretrained

    Returns:
        The model, already in eval() mode
    """
    strategy = get_device_load_strategy(device)
    device_map = strategy["device_map"]
    dtype = strategy["dtype"]
    use_low_cpu_mem = strategy["use_low_cpu_mem"]
    use_int8 = strategy["use_int8"]

    load_kw: Dict[str, Any] = {
        "trust_remote_code": True,
        "low_cpu_mem_usage": use_low_cpu_mem or use_int8,
    }
    if attn_implementation is not None:
        load_kw["attn_implementation"] = attn_implementation
    if extra_model_kwargs:
        load_kw.update(extra_model_kwargs)

    def _load(path: str, lf: bool):
        kw = dict(local_files_only=lf, **load_kw)
        if use_int8:
            from transformers import BitsAndBytesConfig
            return AutoModelForCausalLM.from_pretrained(
                path,
                quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                device_map=device_map,
                **kw,
            )
        if device_map:
            return AutoModelForCausalLM.from_pretrained(
                path,
                device_map=device_map,
                dtype=dtype,
                **kw,
            )
        return AutoModelForCausalLM.from_pretrained(
            path, dtype=dtype, **kw
        ).to(device)

    t0 = time.perf_counter()
    model = resolve_and_load(model_path, _load)
    load_time = time.perf_counter() - t0

    DeviceManager.print_model_load_stats(model, load_time)
    model.eval()
    if device.type == "cuda":
        device_idx = device.index if device.index is not None else 0
        DeviceManager.print_cuda_memory_summary(device=device_idx)
    return model


def load_tokenizer(model_path: str):
    """Load the tokenizer. Resolve to the cache path first so the tokenizer's internal model_info call never hits the network."""

    def _load(path: str, lf: bool):
        return AutoTokenizer.from_pretrained(
            path, trust_remote_code=True, local_files_only=lf
        )

    return resolve_and_load(model_path, _load)

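How the entry points above compose, sketched (model ID illustrative; assumes the imports at the top of this module plus DeviceManager from backend.device):

import torch

device = DeviceManager.get_device()
tokenizer = load_tokenizer("Qwen/Qwen2.5-0.5B")
model = load_causal_lm(
    "Qwen/Qwen2.5-0.5B",
    device,
    attn_implementation=attn_implementation_for_device(device),
)

# One scoring forward, matching how the analysis code drives the model.
inputs = tokenizer("hello world", return_tensors="pt").to(device)
with torch.inference_mode():
    logits = model(**inputs).logits  # [1, seq_len, vocab_size]
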
backend/model_manager.py
ADDED
@@ -0,0 +1,233 @@
"""Model management: the main and semantic slots are configured symmetrically and share the weight cache."""
from enum import Enum
import threading

from backend import REGISTERED_MODELS
from backend.project_registry import ModelRegistry
from backend.device import DeviceManager
from backend.model_loader import attn_implementation_for_device, load_causal_lm, load_tokenizer

from model_paths import DEFAULT_MODEL, DEFAULT_SEMANTIC_MODEL, resolve_hf_path

project_registry = ModelRegistry(REGISTERED_MODELS)
_init_lock = threading.Lock()

# Shared inference lock: information-density analysis and semantic analysis share it,
# ensuring model inference runs strictly serially
_inference_lock = threading.Lock()

# Cache of loaded models, deduplicated by HuggingFace path (shared by main analysis / semantic / completion)
_hf_load_lock = threading.Lock()
_hf_loaded: dict[str, tuple] = {}


class ModelSlot(str, Enum):
    """The two peer slots corresponding to the CLI flags --model / --semantic_model."""

    MAIN = "main"
    SEMANTIC = "semantic"


# Slot order used for startup preloading and "all weights" enumeration (peers, no hierarchy)
CONFIGURED_SLOTS: tuple[ModelSlot, ...] = (ModelSlot.MAIN, ModelSlot.SEMANTIC)


def _resolved_hf_path_for_slot(slot: ModelSlot) -> str:
    """Resolve the HuggingFace path (or local path string) for a slot from the application context."""
    if slot == ModelSlot.MAIN:
        try:
            from backend.app_context import get_app_context

            context = get_app_context(prefer_module_context=True)
            model_name = context.model_name or DEFAULT_MODEL
        except RuntimeError:
            model_name = DEFAULT_MODEL
        return resolve_hf_path(model_name)
    if slot == ModelSlot.SEMANTIC:
        try:
            from backend.app_context import get_args

            raw = getattr(get_args(), "semantic_model", DEFAULT_SEMANTIC_MODEL)
        except RuntimeError:
            raw = DEFAULT_SEMANTIC_MODEL
        return resolve_hf_path(raw)
    raise ValueError(f"unknown ModelSlot: {slot!r}")


def ensure_slot_weights_loaded(slot: ModelSlot):
    """
    Load the slot's weights if they are not cached; identical entry point for main / semantic.
    Returns (tokenizer, model, device).
    """
    return ensure_model_loaded(_resolved_hf_path_for_slot(slot))


def ensure_model_loaded(resolved_hf_path: str):
    """
    The single low-level loading entry point: guarantees the weights for resolved_hf_path are loaded.
    Returns (tokenizer, model, device), where device is where the model parameters live.
    """
    with _hf_load_lock:
        hit = _hf_loaded.get(resolved_hf_path)
        if hit is not None:
            return hit

        device = DeviceManager.get_device()
        display = resolved_hf_path.split("/")[-1] if "/" in resolved_hf_path else resolved_hf_path
        print(f"📦 Loading model weights: {display}")
        tokenizer = load_tokenizer(resolved_hf_path)
        model = load_causal_lm(
            resolved_hf_path,
            device,
            attn_implementation=attn_implementation_for_device(device),
        )
        for p in model.parameters():
            p.requires_grad_(False)
        model_device = next(model.parameters()).device
        device_name = DeviceManager.get_device_name(device)
        print(f"✓ {display} loaded ({device_name})")
        out = (tokenizer, model, model_device)
        _hf_loaded[resolved_hf_path] = out
        return out


def ensure_project_loaded(project_name: str):
    """Ensure the project is loaded; load it if it is not."""
    if not project_name:
        raise ValueError("model name is required")
    if not project_registry.is_available(project_name):
        raise KeyError(project_name)
    try:
        return project_registry.ensure_loaded(project_name)
    except KeyError:
        # Re-raise to allow the caller to format the message uniformly.
        raise
    except Exception as exc:  # noqa: BLE001 - propagate detailed message
        raise RuntimeError(f"Model '{project_name}' failed to load: {exc}") from exc


def _register_main_qwenlm_if_needed():
    """
    Information-density path: once the MAIN slot's weights are ready, register the QwenLM
    instance in project_registry. The semantic slot has no registry wrapper, so only this
    slot needs it.
    """
    from backend.app_context import get_app_context

    context = get_app_context(prefer_module_context=True)
    selected_name = context.model_name

    if not selected_name:
        raise ValueError("No model name specified")

    if selected_name in project_registry:
        _ensure_default_project_ready(selected_name)
        return

    if not project_registry.is_available(selected_name):
        raise KeyError(f"Model '{selected_name}' not found; available models: {list(REGISTERED_MODELS.keys())}")

    try:
        project_registry.load(selected_name)
        _ensure_default_project_ready(selected_name)
    except Exception as exc:  # noqa: BLE001
        raise RuntimeError(f"Model '{selected_name}' failed to load: {exc}") from exc


def preload_all_slots():
    """
    Startup preloading (unless --no_auto_load): resolve the HF path for each of
    CONFIGURED_SLOTS, deduplicate, load all the weights, then register the main slot's
    QwenLM project. At the "load weights first" level the two slots are fully symmetric.
    """
    from backend.app_context import get_app_context

    get_app_context(prefer_module_context=True)

    paths = {_resolved_hf_path_for_slot(s) for s in CONFIGURED_SLOTS}

    with _init_lock:
        for path in paths:
            ensure_model_loaded(path)
        _register_main_qwenlm_if_needed()


def ensure_slot_ready(slot: ModelSlot):
    """
    Slot readiness (symmetric API): guarantee everything the slot needs for subsequent inference.

    - Both slots first ensure the HF weights are loaded and return (tokenizer, model, device).
    - MAIN additionally hooks QwenLM into project_registry (information-density pipeline);
      SEMANTIC has no registry step.

    Under lazy loading: information density calls ensure_main_slot_ready();
    semantic/completion call ensure_semantic_slot_ready().
    """
    from backend.app_context import get_app_context

    get_app_context(prefer_module_context=True)

    if slot == ModelSlot.MAIN:
        with _init_lock:
            out = ensure_slot_weights_loaded(ModelSlot.MAIN)
            _register_main_qwenlm_if_needed()
            return out
    if slot == ModelSlot.SEMANTIC:
        return ensure_slot_weights_loaded(ModelSlot.SEMANTIC)
    raise ValueError(f"unknown ModelSlot: {slot!r}")


def ensure_main_slot_ready():
    """Lazy-load path for the first information-density request: same as ensure_slot_ready(ModelSlot.MAIN)."""
    return ensure_slot_ready(ModelSlot.MAIN)


def ensure_semantic_slot_ready():
    """Lazy-load path for the first semantic-type request: same as ensure_slot_ready(ModelSlot.SEMANTIC)."""
    return ensure_slot_ready(ModelSlot.SEMANTIC)


def get_current_model_max_token_length() -> int:
    """
    Query the max_token_length of the currently active model.
    Prefer the loaded model instance; fall back to the default_model.default_cpu_machine config.
    """
    from backend.app_context import get_app_context
    from backend.runtime_config import RUNTIME_CONFIGS

    try:
        context = get_app_context(prefer_module_context=True)
        model_name = context.model_name or DEFAULT_MODEL
    except RuntimeError:
        model_name = "default_model"

    project = project_registry.get(model_name)
    if project is not None and hasattr(project.lm, "max_length"):
        return project.lm.max_length
    return RUNTIME_CONFIGS["default_model"]["default_cpu_machine"]["max_token_length"]


def _ensure_default_project_ready(selected_name: str):
    """Ensure the default project is ready."""
    if not selected_name:
        return
    if selected_name in project_registry:
        return
    print(f"⚠️ Default model not cached; preloading: {selected_name}")
    project_registry.ensure_loaded(selected_name)


# Legacy names kept (equivalent to the slot-readiness API)
ensure_semantic_loaded = ensure_semantic_slot_ready
ensure_main_project_ready = ensure_main_slot_ready

def get_semantic_model_display_name() -> str:
    """Return the semantic slot's HuggingFace path (used for the model field in results)."""
    return _resolved_hf_path_for_slot(ModelSlot.SEMANTIC)


def ensure_main_model_loaded():
    """
    For callers that only need the main model's forward pass and do not go through
    project_registry (e.g. attribution): the MAIN slot's weights.
    """
    return ensure_slot_weights_loaded(ModelSlot.MAIN)


def get_main_model_display_name() -> str:
    """Return the main slot's HuggingFace path (used for the model field in results)."""
    return _resolved_hf_path_for_slot(ModelSlot.MAIN)

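A sketch of the slot API from a request handler's point of view; when both slots resolve to the same HuggingFace path, the second call is a cache hit in _hf_loaded:

# Information-density request: weights plus the registry entry, MAIN slot.
tokenizer, model, device = ensure_main_slot_ready()

# Semantic / completion request: weights only, SEMANTIC slot.
sem_tokenizer, sem_model, sem_device = ensure_semantic_slot_ready()

# Forward passes from both pipelines are serialized on the shared lock.
with _inference_lock:
    ...  # run model(...) here
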
backend/next_token_topk.py
ADDED
@@ -0,0 +1,26 @@
"""
Top-k decoding for the next token: consistent with the semantic analysis logits_gradient,
shared by semantic / attribution.
"""
from typing import List, Tuple

import torch

from .api.utils import round_to_sig_figs

DEFAULT_NEXT_TOKEN_TOPK = 10


def decode_topk_ids_to_strings_and_rounded_probs(
    probs_1d: torch.Tensor,
    tokenizer,
    topk_ids_1d: torch.Tensor,
) -> Tuple[List[str], List[float]]:
    """
    probs_1d: softmax over a single position's logits, shape [vocab_size].
    topk_ids_1d: the indices returned by torch.topk(logits, k), shape [k].
    Returns topk_tokens and topk_probs in the same shape as the semantic analysis
    debug_info (probabilities already passed through round_to_sig_figs).
    """
    ids_list = topk_ids_1d.tolist()
    topk_tokens = [tokenizer.decode([int(tid)]) for tid in ids_list]
    topk_probs = [round_to_sig_figs(probs_1d[int(tid)].item()) for tid in ids_list]
    return topk_tokens, topk_probs

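The intended call site, sketched with toy tensors; in the backend, probs_1d and topk_ids_1d come from a real model's last-position logits, and the tokenizer from load_tokenizer:

import torch

logits = torch.randn(151_936)  # assumption: a Qwen-sized vocabulary, purely illustrative
probs = torch.softmax(logits, dim=-1)
_, topk_ids = torch.topk(logits, k=DEFAULT_NEXT_TOKEN_TOPK)

tokens, rounded_probs = decode_topk_ids_to_strings_and_rounded_probs(
    probs, tokenizer, topk_ids  # tokenizer: any HF tokenizer loaded elsewhere
)
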
backend/oom.py
ADDED
@@ -0,0 +1,55 @@
"""OOM handling: exit the process when MPS/CUDA VRAM or CPU RAM runs out, so the process manager can restart it."""
import os
import threading
import time


def _check_oom_msg(msg: str) -> bool:
    patterns = (
        "out of memory",
        "out of memory error",
        "memory allocation",
        "cannot allocate memory",
        "insufficient memory",
        "ran out of memory",
        "resource exhausted",
        "cuda error: out of memory",
        "mps backend out of memory",
    )
    return any(p in msg.lower() for p in patterns)


def is_oom_error(e: Exception) -> bool:
    """Detect OOM (MPS/CUDA VRAM or CPU RAM); the process cannot recover afterwards and must be restarted."""
    if isinstance(e, MemoryError):
        return True
    if _check_oom_msg(str(e)):
        return True
    # Check the exception chain (e.g. an OOM wrapped in a RuntimeError)
    for exc in (getattr(e, "__cause__", None), getattr(e, "__context__", None)):
        if exc is not None and (isinstance(exc, MemoryError) or _check_oom_msg(str(exc))):
            return True
    return False


def exit_if_oom(e: Exception, defer_seconds: float = 0) -> None:
    """Exit the process if the error is an OOM, letting the process manager restart it to reclaim memory.

    defer_seconds: delay before exiting, so an error response can be returned first (must be > 0 for non-streaming).
    """
    if not is_oom_error(e):
        return
    msg = f"🛑 OOM detected, exiting process so it can be restarted: {e}"
    if defer_seconds > 0:
        msg = f"🛑 OOM detected, exiting process in {defer_seconds}s so it can be restarted: {e}"
    print(msg)

    def _exit():
        if defer_seconds > 0:
            time.sleep(defer_seconds)
        os._exit(1)

    if defer_seconds > 0:
        threading.Thread(target=_exit, daemon=False).start()
    else:
        os._exit(1)
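
A minimal usage sketch, assuming the module above is importable as `backend.oom`; `run_model` is a hypothetical inference callable:

from backend.oom import exit_if_oom

def call_with_oom_guard(run_model, *args):
    try:
        return run_model(*args)
    except Exception as e:
        # No-op for ordinary errors; for an OOM, schedules os._exit(1) after 2s
        # so the caller still has time to flush an error response.
        exit_if_oom(e, defer_seconds=2.0)
        raise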
backend/path_utils.py
ADDED
@@ -0,0 +1,92 @@
"""
Path handling utilities.
Centralizes path validation, normalization and resolution.
"""

import os
from pathlib import Path
from typing import Optional


def normalize_path(path: str) -> str:
    """
    Normalize a path: convert the empty string to "/".

    Args:
        path: input path

    Returns:
        the normalized path
    """
    return path if path else "/"


def check_path_in_demo_dir(path: Path, demo_dir: Path) -> bool:
    """
    Check whether a path lies inside the demo directory (Python 3.8 compatible).

    Args:
        path: path to check
        demo_dir: demo directory path

    Returns:
        True if the path is inside the demo directory
    """
    try:
        return path.is_relative_to(demo_dir)
    except AttributeError:
        # Python 3.8 compatibility: fall back to os.path.commonpath
        path_str = str(path)
        demo_dir_str = str(demo_dir)
        common = os.path.commonpath([path_str, demo_dir_str])
        return common == demo_dir_str


def validate_demo_path(path: str, demo_dir: Path) -> bool:
    """
    Validate path safety, guarding against path traversal attacks.

    Args:
        path: relative path to validate
        demo_dir: absolute path of the demo directory

    Returns:
        True if the path is safe
    """
    if not path or path == "/":
        return True

    # Strip leading/trailing slashes and normalize separators
    normalized_path = path.strip('/').replace('\\', '/')

    # Reject paths containing ".." or other dangerous components
    if '..' in normalized_path.split('/'):
        return False

    try:
        resolved_path = (demo_dir / normalized_path).resolve()
        demo_dir_resolved = demo_dir.resolve()
        return check_path_in_demo_dir(resolved_path, demo_dir_resolved)
    except Exception:
        return False


def resolve_demo_path(demo_dir: Path, path: str) -> Optional[Path]:
    """
    Resolve and validate a path, returning an absolute path.

    Args:
        demo_dir: absolute path of the demo directory
        path: relative path to resolve

    Returns:
        the resolved absolute path, or None if validation fails
    """
    if not validate_demo_path(path, demo_dir):
        return None

    if not path or path == "/":
        return demo_dir

    return (demo_dir / path.lstrip('/')).resolve()
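
A runnable sketch of the traversal guard above (the demo directory path is illustrative):

from pathlib import Path

from backend.path_utils import resolve_demo_path

demo_dir = Path("/srv/app/data/demo")                  # hypothetical demo root
print(resolve_demo_path(demo_dir, "public/a.json"))    # absolute path under demo_dir
print(resolve_demo_path(demo_dir, "../secrets.txt"))   # None: ".." components are rejected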
backend/pred_topk_format.py
ADDED
@@ -0,0 +1,44 @@
"""
Formatting of pred_topk lists: semantically identical to batch_decode +
round_to_sig_figs in language_checker; shared by information-density analysis
and continuation.
"""

from typing import List, Tuple

import torch

from backend.api.utils import round_to_sig_figs


def pred_topk_pairs_from_flat_ids_and_probs(
    ids_flat: List[int],
    probs_flat: List[float],
    tokenizer,
) -> List[Tuple[str, float]]:
    """
    Decode flattened torch.topk id / probability sequences into
    [(token text, probability), ...].
    Matches the inner logic of QwenLM._decode_topk_tokens (a single batch_decode).
    """
    if len(ids_flat) != len(probs_flat):
        raise ValueError("ids_flat and probs_flat must have the same length")
    if not ids_flat:
        return []
    decoded = tokenizer.batch_decode([[tid] for tid in ids_flat], skip_special_tokens=False)
    return [
        (decoded[j], round_to_sig_figs(float(probs_flat[j])))
        for j in range(len(ids_flat))
    ]


def pred_topk_pairs_from_probs_1d(
    probs: torch.Tensor,
    tokenizer,
    top_k: int,
) -> List[Tuple[str, float]]:
    """Top-k over a single-step 1D softmax probability vector, used for the per-step scores of continuation generate."""
    top_k = min(int(top_k), int(probs.numel()))
    if top_k <= 0:
        return []
    topk_probs, topk_ids = torch.topk(probs, top_k, dim=-1)
    ids_flat = topk_ids.cpu().flatten().tolist()
    probs_flat = topk_probs.detach().cpu().float().numpy().flatten().tolist()
    return pred_topk_pairs_from_flat_ids_and_probs(ids_flat, probs_flat, tokenizer)
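
A runnable sketch, with a hypothetical stand-in tokenizer implementing only the `batch_decode` call the module relies on:

import torch

from backend.pred_topk_format import pred_topk_pairs_from_probs_1d

class _ToyTokenizer:
    def batch_decode(self, id_batches, skip_special_tokens=False):
        return [f"<tok{ids[0]}>" for ids in id_batches]

probs = torch.softmax(torch.randn(50), dim=-1)   # one generation step's 1D distribution
pairs = pred_topk_pairs_from_probs_1d(probs, _ToyTokenizer(), top_k=5)
print(pairs)                                     # five (token text, rounded prob) pairs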
backend/prediction_attributor.py
ADDED
@@ -0,0 +1,185 @@
"""
Prediction attribution: for the next-token prediction of an arbitrary context,
compute the gradient of a chosen candidate token's logit with respect to each
input token embedding, and use the gradient L2 norm as the attribution score.

The request parameter `model` selects the weight slot: base is the main slot
(--model), instruct is the semantic slot (--semantic_model).
"""

import math
from typing import Dict, Optional

import torch

from .api.utils import round_to_sig_figs
from .device import DeviceManager
from .model_manager import (
    ModelSlot,
    ensure_slot_weights_loaded,
    get_main_model_display_name,
    get_semantic_model_display_name,
)
from .next_token_topk import decode_topk_ids_to_strings_and_rounded_probs, DEFAULT_NEXT_TOKEN_TOPK


def _get_gradient_checkpointing() -> bool:
    """Defaults to True; disabled by ``--no-gradient-checkpointing``."""
    try:
        from backend.app_context import get_args

        return getattr(get_args(), "gradient_checkpointing", True)
    except RuntimeError:
        return True


# Upper bound on attribution input length (in tokens); longer inputs raise an error
ATTRIBUTION_MAX_TOKEN_LENGTH = 500

# Matches the API request body `model`: base = main slot, instruct = semantic slot
PREDICTION_ATTR_MODEL_BASE = "base"
PREDICTION_ATTR_MODEL_INSTRUCT = "instruct"


def _slot_for_prediction_attr_model(model: str) -> ModelSlot:
    if model == PREDICTION_ATTR_MODEL_BASE:
        return ModelSlot.MAIN
    if model == PREDICTION_ATTR_MODEL_INSTRUCT:
        return ModelSlot.SEMANTIC
    raise ValueError(
        f"Unsupported model {model!r}; only {PREDICTION_ATTR_MODEL_BASE!r} and "
        f"{PREDICTION_ATTR_MODEL_INSTRUCT!r} are supported."
    )


def analyze_prediction_attribution(
    context: str, target_prediction: Optional[str] = None, *, model: str
) -> Dict:
    """
    Compute each context token's attribution score for the prediction of
    target_prediction's first token.

    Args:
        context: input context text (must not exceed ATTRIBUTION_MAX_TOKEN_LENGTH tokens, otherwise ValueError)
        target_prediction: target prediction text; its first token after tokenization is the attribution target.
            When omitted or None, the top-1 (greedy) token is used automatically.
        model: ``base`` for the main-slot weights, ``instruct`` for the semantic-slot weights (matching the API request body)

    Returns:
        {
            "model": str,
            "target_token": str,  # string form of the attribution target token
            "target_prob": float,  # the token's predicted probability in the next-token distribution
            "token_attribution": [{"offset": [s, e], "raw": str, "score": float}, ...],
            "debug_info": {"topk_tokens": [...], "topk_probs": [...]},  # same shape as semantic analysis (next-token top10)
            "is_eos": bool,  # whether target_token is the EOS token
        }
    """
    slot = _slot_for_prediction_attr_model(model)
    tokenizer, hf_model, device = ensure_slot_weights_loaded(slot)
    model_display = (
        get_main_model_display_name() if slot == ModelSlot.MAIN else get_semantic_model_display_name()
    )

    # The attribution target id is resolved only after the forward pass yields logits:
    # top-1 uses argmax; an explicit target uses encode (which may differ from argmax).
    use_top1 = target_prediction is None

    # Encode the context, keeping offset_mapping to recover character positions
    enc = tokenizer(context, return_tensors="pt", return_offsets_mapping=True)
    input_ids = enc["input_ids"].to(device)
    offset_mapping = enc["offset_mapping"][0].tolist()
    n_tokens = input_ids.shape[1]
    if n_tokens > ATTRIBUTION_MAX_TOKEN_LENGTH:
        raise ValueError(
            "Context exceeds attribution length limit "
            f"({ATTRIBUTION_MAX_TOKEN_LENGTH} tokens); current length is {n_tokens} tokens."
        )

    # Obtain differentiable inputs through the embedding layer
    embed_layer = hf_model.get_input_embeddings()
    embeds = embed_layer(input_ids).detach().clone().requires_grad_(True)

    use_gc = _get_gradient_checkpointing()
    try:
        hf_model.eval()
        if use_gc:
            hf_model.gradient_checkpointing_enable()
        with torch.set_grad_enabled(True):
            # Attribution only needs the last-step logits, not the KV cache;
            # disabling the cache markedly lowers the memory peak for long contexts.
            outputs = hf_model(inputs_embeds=embeds, output_attentions=False, use_cache=False)

            # Explicit sync to ensure the forward pass has finished
            # (consistent with semantic logits_gradient)
            if device.type == "cuda":
                torch.cuda.synchronize(device)
            elif device.type == "mps":
                torch.mps.synchronize()

            logits = outputs.logits[0, -1, :]  # next-token logits, shape: [vocab_size]
            probs = torch.softmax(logits, dim=-1)
            _, topk_ids = torch.topk(logits, DEFAULT_NEXT_TOKEN_TOPK)
            topk_tokens, topk_probs = decode_topk_ids_to_strings_and_rounded_probs(
                probs, tokenizer, topk_ids
            )

            if use_top1:
                target_token_id = int(topk_ids[0].item())
                target_token = tokenizer.decode([target_token_id])
            else:
                assert target_prediction is not None
                target_ids = tokenizer.encode(target_prediction, add_special_tokens=False)
                if not target_ids:
                    raise ValueError(f"Cannot tokenize target_prediction: {target_prediction!r}")
                target_token_id = target_ids[0]
                target_token = tokenizer.decode([target_token_id])

            target_prob = round_to_sig_figs(probs[target_token_id].item())

            # Backprop from the target token's raw logit (not through softmax,
            # avoiding saturation and competition pollution)
            logits[target_token_id].backward()

            grad = embeds.grad
            if grad is None:
                raise RuntimeError(
                    "Gradient did not propagate; this model may not support attribution (e.g. int8 quantization)."
                )

            # Explicit sync to ensure backward has finished before reading gradients
            # (consistent with semantic logits_gradient)
            if device.type == "cuda":
                torch.cuda.synchronize(device)
            elif device.type == "mps":
                torch.mps.synchronize()

            norms = grad[0].float().norm(dim=-1).cpu().tolist()

            # Filter special tokens by offset (BOS/EOS spans have zero length)
            token_attribution = []
            nan_count = 0
            for (s, e), norm in zip(offset_mapping, norms):
                if s >= e:
                    continue
                if not math.isfinite(norm):
                    score = 0.0
                    nan_count += 1
                else:
                    score = round_to_sig_figs(norm)
                token_attribution.append({
                    "offset": [s, e],
                    "raw": context[s:e],
                    "score": score,
                })
            if nan_count > 0:
                print(f"⚠️ token_attribution contains {nan_count} NaN/Inf scores, replaced with 0.")

            eos_id = tokenizer.eos_token_id
            is_eos = eos_id is not None and target_token_id == int(eos_id)

            return {
                "model": model_display,
                "target_token": target_token,
                "target_prob": target_prob,
                "token_attribution": token_attribution,
                "debug_info": {"topk_tokens": topk_tokens, "topk_probs": topk_probs},
                "is_eos": is_eos,
            }
    finally:
        if use_gc:
            hf_model.gradient_checkpointing_disable()
        # Consistent with semantic_analyzer._analyze_logits_gradient: clear after
        # every inference to avoid MPS/CUDA accumulation
        DeviceManager.clear_cache(device)
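
A self-contained toy sketch of the attribution technique itself (gradient of a target logit with respect to the input embeddings, scored per token by L2 norm); the two-layer model is illustrative, not the project's weights:

import torch
import torch.nn as nn

vocab, dim, seq = 20, 8, 5
embed = nn.Embedding(vocab, dim)
head = nn.Linear(dim, vocab)

ids = torch.randint(0, vocab, (1, seq))
embeds = embed(ids).detach().clone().requires_grad_(True)
logits = head(embeds.mean(dim=1))[0]          # stand-in "next-token" logits
target_id = int(logits.argmax())              # greedy target, as in use_top1
logits[target_id].backward()                  # raw logit, not softmax
scores = embeds.grad[0].norm(dim=-1)          # one attribution score per input token
print(scores.tolist())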
backend/project_registry.py
ADDED
@@ -0,0 +1,72 @@
from typing import Dict, Iterable, Optional, Sequence, Tuple


class ModelInstance:
    """Lightweight wrapper holding a configured language model instance."""

    def __init__(self, model_cls, config):
        self.config = config
        self.lm = model_cls()


class ModelRegistry:
    """Manages lazy loading and caching of backend language models."""

    def __init__(self, available_models: Dict[str, object]):
        self._available_models = available_models
        self._projects: Dict[str, ModelInstance] = {}

    def __contains__(self, project_name: str) -> bool:
        return project_name in self._projects

    def get(self, project_name: str) -> Optional[ModelInstance]:
        return self._projects.get(project_name)

    def configs(self) -> Dict[str, object]:
        return {name: project.config for name, project in self._projects.items()}

    def available_model_names(self) -> Sequence[str]:
        return tuple(self._available_models.keys())

    def is_available(self, project_name: str) -> bool:
        return project_name in self._available_models

    def load(self, project_name: str) -> ModelInstance:
        if project_name not in self._available_models:
            raise KeyError(f"Model '{project_name}' is not registered in REGISTERED_MODELS")

        project = ModelInstance(self._available_models[project_name], project_name)
        self._projects[project_name] = project
        return project

    def ensure_loaded(self, project_name: str) -> ModelInstance:
        """Return a project instance, loading it if necessary."""
        if project_name in self._projects:
            return self._projects[project_name]
        return self.load(project_name)

    def unload(self, project_name: str) -> bool:
        """Unload the given model and free its memory."""
        if project_name in self._projects:
            del self._projects[project_name]
            return True
        return False

    def ensure_any(self, candidates: Iterable[str]) -> Tuple[str, ModelInstance]:
        """Load (or reuse) the first successfully instantiated project."""
        last_error: Optional[Exception] = None
        for candidate in candidates:
            if not candidate:
                continue
            if candidate in self._projects:
                return candidate, self._projects[candidate]
            try:
                project = self.load(candidate)
                return candidate, project
            except Exception as exc:  # noqa: BLE001 - bubble up aggregated info
                last_error = exc
                continue
        if last_error:
            raise last_error
        raise ValueError("No usable model available!")
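
A runnable sketch: the registry only needs a zero-argument class per model name; `DummyLM` is a hypothetical stand-in:

from backend.project_registry import ModelRegistry

class DummyLM:
    """Illustrative stand-in for a registered backend LM class."""

registry = ModelRegistry({"dummy": DummyLM})
name, instance = registry.ensure_any(["missing", "dummy"])  # "missing" fails, "dummy" loads
print(name, "dummy" in registry)                            # -> dummy True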
backend/quantization_config.py
ADDED
@@ -0,0 +1,42 @@
"""
Quantization config (shared by semantic analysis and information-density analysis).

Reads environment variables and returns a device-dependent quantization strategy:
- FORCE_INT8=1: INT8 quantization (supported on CPU/CUDA, not on MPS)
- CPU_FORCE_BFLOAT16=1: use bfloat16 on CPU
"""

import os
from typing import NamedTuple

import torch


class QuantizationConfig(NamedTuple):
    """Quantization config, shared by the semantic model and the information-density model"""
    use_int8: bool
    dtype: torch.dtype


def get_quantization_config(device: torch.device) -> QuantizationConfig:
    """
    Return the quantization config based on device and environment variables.

    Returns:
        QuantizationConfig: use_int8, dtype
    """
    force_int8 = os.environ.get("FORCE_INT8") == "1"
    force_bfloat16 = os.environ.get("CPU_FORCE_BFLOAT16") == "1"

    if device.type == "cpu":
        use_int8 = force_int8
        dtype = torch.bfloat16 if force_bfloat16 else torch.float32
    elif device.type == "cuda":
        use_int8 = force_int8
        dtype = torch.float16
    else:
        # MPS does not support INT8
        use_int8 = False
        dtype = torch.float16

    return QuantizationConfig(use_int8=use_int8, dtype=dtype)
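
A runnable sketch of how the environment variables drive the result (the values shown assume a CPU device and an unset FORCE_INT8):

import os

import torch

from backend.quantization_config import get_quantization_config

os.environ["CPU_FORCE_BFLOAT16"] = "1"
cfg = get_quantization_config(torch.device("cpu"))
print(cfg.use_int8, cfg.dtype)   # -> False torch.bfloat16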
backend/runtime_config.py
ADDED
@@ -0,0 +1,402 @@
"""
Runtime configuration management module.

Manages per-model, per-platform runtime parameters, including:
- max_token_length: maximum token count for text analysis (information density)
- chunk_size: chunk size used during inference
- Semantic analysis has its own SEMANTIC_RUNTIME_CONFIGS, containing only max_token_length

Platform IDs:
- local_mps: local Apple Silicon (M1/M2/M3)
- cloud_cuda: cloud CUDA GPU
- cloud_cpu_16g: cloud large-memory CPU (e.g. the HF Space free tier, 16G RAM)
- cloud_cpu_32g: cloud large-memory CPU (e.g. HF Space CPU upgrade, 32G RAM)
- default_cpu_machine: default CPU machine (unknown or unrecognized CPU environments)
- future extensions: cloud_cuda_a100, cloud_cuda_24g, etc.
"""

import os
import sys
from typing import Dict, Optional, Tuple

import torch


# ============= Platform-level constants =============

# Default number of pred_topk candidates for the analysis API
# The frontend ToolTip shows the same number
DEFAULT_TOPK = 10

# Safe sequence-length cap for a single TopK operation on MPS (works around an MPS bug)
# chunk_size must stay below this value so every chunk's TopK computation is safe
MPS_TOPK_BUG_THRESHOLD = 2048


# ============= Runtime parameter table (Model × Platform) =============
#
# Two-dimensional table: each model configures max_token_length and chunk_size per platform.
#
# Four override layers (highest priority first):
#   1. (model_name, platform)                   - model-specific config for that platform (most precise)
#   2. (model_name, "default_cpu_machine")      - model-wide config (cross-platform)
#   3. ("default_model", platform)              - platform-wide config (cross-model)
#   4. ("default_model", "default_cpu_machine") - global fallback
#
# Each layer supports partial overrides: setting only max_token_length or only chunk_size is fine.

RUNTIME_CONFIGS = {
    # Global default model config
    "default_model": {
        # Default CPU machine config (most conservative, for unrecognized CPU environments)
        "default_cpu_machine": {
            "max_token_length": 2000,
            "chunk_size": 256
        },
        # Cloud CPU (16G), e.g. HF Spaces CPU basic
        "cloud_cpu_16g": {
            "max_token_length": 2000,
            "chunk_size": 256
        },
        # Cloud CPU (32G), e.g. HF Spaces CPU upgrade
        "cloud_cpu_32g": {
            "max_token_length": 5000,
            "chunk_size": 512
        },
        # Cloud GPU with ample VRAM
        "cloud_cuda": {
            # "max_token_length": 10000,
            "max_token_length": 5000,
            "chunk_size": 1024
        },
        # Local Apple Silicon
        "local_mps": {
            "max_token_length": 2000,
            "chunk_size": 512
        }
    },
    # # Qwen3-1.7B
    # "qwen3-1.7b": {
    #     "local_mps": {
    #         "max_token_length": 2000,
    #         "chunk_size": 128
    #     }
    # }
}


# ============= Semantic-analysis runtime config (max_token_length only) =============
# Configured per platform; semantic analysis is independent of the information-density model

SEMANTIC_RUNTIME_CONFIGS = {
    "default_cpu_machine": {"max_token_length": 300},
    "cloud_cpu_16g": {"max_token_length": 300},
    "cloud_cpu_32g": {"max_token_length": 1000},
    "cloud_cuda": {"max_token_length": 1000},
    "local_mps": {"max_token_length": 300},
}


# ============= Platform detection and config resolution =============

def detect_platform(verbose: bool = True) -> str:
    """
    Auto-detect the current runtime platform.

    Priority:
        1. FORCE_CPU environment variable (explicitly force CPU mode)
        2. Hardware auto-detection (cuda/mps/cpu)
        3. CPU sub-typing (e.g. cloud_cpu_16g)

    Args:
        verbose: whether to print detection info

    Returns:
        platform ID string (e.g. 'local_mps', 'cloud_cuda', 'cloud_cpu_16g', 'cloud_cpu_32g', 'default_cpu_machine')
    """
    # 1. Explicitly force CPU (enable via FORCE_CPU=1)
    if os.environ.get("FORCE_CPU") == "1":
        if verbose:
            print("🔧 CPU mode forced")
        return _detect_cpu_variant()

    # 2. Auto-detect GPU/MPS
    if torch.cuda.is_available():
        platform = "cloud_cuda"
    elif torch.backends.mps.is_available():
        platform = "local_mps"
    else:
        # 3. CPU sub-typing
        platform = _detect_cpu_variant()

    if verbose:
        print(f"🔍 Auto-detected platform config: {platform}")
    return platform


def _detect_cpu_variant() -> str:
    """
    Detect the specific CPU environment variant (internal).
    Classifies CPU environments by memory size:
    - >= 30GB: cloud_cpu_32g (32G memory environments)
    - >= 15GB: cloud_cpu_16g (16G memory environments)
    - otherwise: default_cpu_machine (default config)

    Prefers the container memory limit (cgroup) and falls back to system memory detection.
    """
    total_memory = 0

    try:
        # Prefer the container memory limit (cgroup)
        # cgroup v2: /sys/fs/cgroup/memory.max
        # cgroup v1: /sys/fs/cgroup/memory/memory.limit_in_bytes
        cgroup_paths = [
            "/sys/fs/cgroup/memory.max",  # cgroup v2
            "/sys/fs/cgroup/memory/memory.limit_in_bytes",  # cgroup v1
        ]

        for cgroup_path in cgroup_paths:
            try:
                if os.path.exists(cgroup_path):
                    with open(cgroup_path, 'r') as f:
                        limit_str = f.read().strip()
                        # cgroup v2 may return "max", meaning no limit
                        if limit_str == "max":
                            break
                        limit_bytes = int(limit_str)
                        if limit_bytes > 0 and limit_bytes < (2 ** 63):  # sanity range
                            total_memory = limit_bytes
                            print(f"🔍 Container memory limit detected via cgroup: {total_memory / (1024 ** 3):.2f} GB")
                            break
            except (ValueError, IOError, OSError):
                continue

        # If cgroup detection failed, fall back to system memory detection
        if total_memory == 0 and sys.platform != "win32":
            try:
                page_size = os.sysconf('SC_PAGE_SIZE')
                phys_pages = os.sysconf('SC_PHYS_PAGES')
                total_memory = page_size * phys_pages
                print(f"🔍 System memory detected via sysconf: {total_memory / (1024 ** 3):.2f} GB")
            except (ValueError, AttributeError):
                pass

        # Convert to GB
        total_memory_gb = total_memory / (1024 ** 3)

        # Classification:
        # - >= 30GB: cloud_cpu_32g (HF Spaces CPU upgrade typically exposes ~30.x GB)
        # - >= 15GB: cloud_cpu_16g (HF Spaces CPU basic typically exposes ~15.x GB)
        if total_memory_gb >= 30.0:
            return "cloud_cpu_32g"
        elif total_memory_gb >= 15.0:
            return "cloud_cpu_16g"

    except Exception as e:
        print(f"⚠️ CPU environment detection failed, falling back to default config: {e}")

    return "default_cpu_machine"


def merge_runtime_config(model_name: str, platform: str, verbose: bool = True) -> Dict[str, int]:
    """
    Four-layer config merge: supports partial overrides and tracks each value's source.

    Priority (highest first):
        1. (model_name, platform)                   - model-specific config for that platform
        2. (model_name, "default_cpu_machine")      - model-wide config
        3. ("default_model", platform)              - platform-wide config
        4. ("default_model", "default_cpu_machine") - global fallback

    Args:
        model_name: model name (e.g. 'qwen3-1.7b')
        platform: platform ID (e.g. 'local_mps')
        verbose: whether to print config-source hints

    Returns:
        merged config dict {"max_token_length": int, "chunk_size": int}

    Raises:
        ValueError: if the config is incomplete
    """
    # Prepare the four layers (lowest priority first)
    layers = [
        {
            "name": "default_model.default_cpu_machine",
            "config": RUNTIME_CONFIGS.get("default_model", {}).get("default_cpu_machine", {})
        },
        {
            "name": f"default_model.{platform}",
            "config": RUNTIME_CONFIGS.get("default_model", {}).get(platform, {})
        },
        {
            "name": f"{model_name}.default_cpu_machine",
            "config": RUNTIME_CONFIGS.get(model_name, {}).get("default_cpu_machine", {})
        },
        {
            "name": f"{model_name}.{platform}",
            "config": RUNTIME_CONFIGS.get(model_name, {}).get(platform, {})
        }
    ]

    # Track each config key's source
    config_sources = {}  # {"max_token_length": "<layer name>", "chunk_size": "<layer name>"}
    merged = {}

    # Merge in order (later layers override earlier ones)
    for layer in layers:
        layer_config = layer["config"]
        for key, value in layer_config.items():
            merged[key] = value
            config_sources[key] = layer["name"]

    # Ensure required fields exist
    if "max_token_length" not in merged or "chunk_size" not in merged:
        raise ValueError(
            f"Incomplete config: model={model_name}, platform={platform}, "
            f"merged={merged}. Required fields are missing!"
        )

    # Print the source of each config value in use
    if verbose:
        for key, source in config_sources.items():
            actual_value = merged[key]
            print(f"\t{key}={actual_value} (from {source})")

    return merged


_semantic_max_token_length_cache: Optional[int] = None


def get_semantic_max_token_length(verbose: bool = False) -> int:
    """
    Get the semantic-analysis max_token_length (read per platform from SEMANTIC_RUNTIME_CONFIGS).
    The platform-detection result is cached to avoid re-detecting on every analysis.
    """
    global _semantic_max_token_length_cache
    if _semantic_max_token_length_cache is not None:
        return _semantic_max_token_length_cache
    platform = detect_platform(verbose=verbose)
    config = SEMANTIC_RUNTIME_CONFIGS.get(platform, SEMANTIC_RUNTIME_CONFIGS["default_cpu_machine"])
    _semantic_max_token_length_cache = config["max_token_length"]
    return _semantic_max_token_length_cache


def validate_platform_config(platform: str, chunk_size: int, verbose: bool = True) -> None:
    """
    Platform-level safety check (moved up to the initialization phase).

    Args:
        platform: platform ID
        chunk_size: configured chunk_size
        verbose: whether to print check results

    Raises:
        ValueError: if the config violates the platform's limits
    """
    # MPS-specific limit
    if "mps" in platform.lower():
        if chunk_size > MPS_TOPK_BUG_THRESHOLD:
            raise ValueError(
                f"❌ Invalid MPS platform config: chunk_size ({chunk_size}) "
                f"exceeds the safe cap ({MPS_TOPK_BUG_THRESHOLD})\n"
                f"   platform: {platform}\n"
                f"   suggestion: adjust chunk_size for {platform} in RUNTIME_CONFIGS"
            )
        if verbose:
            print(f"✓ MPS platform safety check passed: chunk_size={chunk_size} (cap={MPS_TOPK_BUG_THRESHOLD})")


def _get_cpu_info() -> Optional[str]:
    """
    Read the CPU model name (for display only).

    Returns:
        the CPU model name, or None if it cannot be determined
    """
    model_name = None

    try:
        if sys.platform == 'linux':
            with open('/proc/cpuinfo', 'r') as f:
                for line in f:
                    # Read the model name
                    if model_name is None and 'model name' in line.lower():
                        model_name = line.split(':', 1)[1].strip()

                    # Stop early once the needed info has been read
                    if model_name:
                        break
    except Exception:
        pass

    return model_name


def _print_cpu_info() -> None:
    """
    Print the CPU model (on every platform).
    """
    try:
        cpu_model = _get_cpu_info()
        model = cpu_model or "unknown"

        print(f"💻 CPU model: {model}")
    except Exception as e:
        print(f"⚠️ Failed to read CPU info: {e}")


def _print_cpu_thread_info() -> None:
    """Print the CPU thread configuration (PyTorch defaults)."""
    try:
        intra_threads = torch.get_num_threads()
        inter_threads = torch.get_num_interop_threads()
        print(f"🧵 PyTorch thread config: intra-op={intra_threads}, inter-op={inter_threads}")
    except Exception as e:
        print(f"⚠️ Failed to read CPU thread info: {e}")


def load_runtime_config(model_name: str, verbose: bool = False) -> Tuple[str, int, int]:
    """
    Full runtime-config loading flow: detect platform -> merge config -> validate -> CPU debug info.

    This is the main entry point for config loading and wraps the complete logic.
    (The return annotation uses typing.Tuple for Python 3.8 compatibility.)

    Args:
        model_name: model identifier (e.g. 'qwen3-1.7b')
        verbose: whether to print detailed config info

    Returns:
        tuple[platform, max_token_length, chunk_size]

    Raises:
        ValueError: if the config is incomplete or violates platform limits
    """
    # 1. Detect the platform
    platform = detect_platform(verbose=verbose)

    # 2. Four-layer config merge (partial overrides, with source tracking)
    config = merge_runtime_config(
        model_name=model_name or "default_model",
        platform=platform,
        verbose=verbose
    )

    # 3. Extract the config
    max_token_length = config["max_token_length"]
    chunk_size = config["chunk_size"]

    # 4. Platform-level safety check (MPS limits etc.)
    validate_platform_config(platform, chunk_size, verbose=verbose)

    # 5. Print CPU info (all platforms)
    _print_cpu_info()

    # 6. Print CPU thread config (CPU platforms only)
    if "cpu" in platform.lower():
        _print_cpu_thread_info()  # debug info

    # 7. Print a config summary
    print(
        f"⚙️ Runtime config loaded [model={model_name}, platform={platform}]: "
        f"max_token_length={max_token_length}, chunk_size={chunk_size}"
    )

    return platform, max_token_length, chunk_size
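
A runnable sketch of the four-layer precedence; the `toy-model` entry is hypothetical and registered only for the demonstration:

from backend.runtime_config import RUNTIME_CONFIGS, merge_runtime_config

# Partial override: only chunk_size, so max_token_length falls through the layers.
RUNTIME_CONFIGS["toy-model"] = {"local_mps": {"chunk_size": 128}}
cfg = merge_runtime_config("toy-model", "local_mps")
print(cfg)  # {'max_token_length': 2000, 'chunk_size': 128}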
backend/schemas.py
ADDED
@@ -0,0 +1,43 @@
from dataclasses import asdict, dataclass, field
from typing import Dict, List, Optional, Tuple


@dataclass
class TokenWithOffset:
    offset: Tuple[int, int]
    raw: str
    real_topk: Optional[Tuple[int, float]] = None
    pred_topk: List[Tuple[str, float]] = field(default_factory=list)


@dataclass
class AnalyzeResult:
    model: Optional[str] = None
    bpe_strings: List[TokenWithOffset] = field(default_factory=list)
    error: Optional[str] = None


@dataclass
class AnalyzeRequest:
    model: str
    text: str


@dataclass
class AnalyzeResponse:
    request: AnalyzeRequest
    result: AnalyzeResult


def serialize_analyze_result(result: AnalyzeResult) -> Dict:
    return asdict(result)


def create_empty_analysis_result(error: Optional[str] = None, model: Optional[str] = None) -> Dict:
    result = AnalyzeResult()
    if error:
        result.error = error
    if model:
        result.model = model
    return serialize_analyze_result(result)
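
A runnable sketch: build one token entry and serialize the result to a plain dict (values are illustrative):

from backend.schemas import AnalyzeResult, TokenWithOffset, serialize_analyze_result

result = AnalyzeResult(
    model="demo-model",
    bpe_strings=[TokenWithOffset(offset=(0, 5), raw="Hello", pred_topk=[("Hello", 0.42)])],
)
print(serialize_analyze_result(result))  # nested dataclasses become plain dicts/lists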
backend/semantic_analyzer.py
ADDED
@@ -0,0 +1,280 @@
"""
Semantic analysis: use an instruct model to extract the relevance of
source-text tokens to a query.

Uses the logits_gradient attribution strategy (more consistent with prediction);
the sub-strategy is selected by --logits_gradient_submode:
- count: gradient of the top-10 logits (excluding "0"), prompt asks for a "count".
  At 0.6b only suitable for judging whether the article as a whole is relevant; at 1.7b it handles everything.
- match_score: gradient of the target token's logit, prompt asks for a "relevance score".
  Not competitive at either 0.6b or 1.7b. [deprecated]
- fill_blank: fill-in-the-blank, gradient of the top-10 logits (excluding "无"),
  prompt asks for "the single most relevant word". At 0.6b only suitable for scoring tokens; at 1.7b it handles everything.

count/fill_blank weight by probability (Σ pᵢ·zᵢ).

The model is selected by the --semantic_model argument, default qwen3-0.6b-instruct.
"""

import math
from typing import Callable, Dict, List, Optional

import torch

from .api.utils import round_to_sig_figs
from .device import DeviceManager
from .model_manager import ensure_semantic_slot_ready, get_semantic_model_display_name
from .next_token_topk import decode_topk_ids_to_strings_and_rounded_probs, DEFAULT_NEXT_TOKEN_TOPK


def _get_logits_gradient_submode() -> str:
    """logits_gradient sub-strategy: count / match_score (deprecated) / fill_blank"""
    try:
        from backend.app_context import get_args
        return getattr(get_args(), "logits_gradient_submode", "fill_blank")
    except RuntimeError:
        return "fill_blank"


def _truncate_text_by_tokens(tokenizer, text: str, max_tokens: int) -> str:
    """Truncate text to at most max_tokens tokens; print a notice when truncating."""
    text_ids = tokenizer.encode(text, add_special_tokens=False)
    if len(text_ids) > max_tokens:
        print(f"⚠️ Source text too long, truncated to the first {max_tokens} tokens")
        return tokenizer.decode(text_ids[:max_tokens])
    return text


def _get_gradient_checkpointing() -> bool:
    """Defaults to True (run.py); disabled by ``--no-gradient-checkpointing``."""
    try:
        from backend.app_context import get_args
        return getattr(get_args(), "gradient_checkpointing", True)
    except RuntimeError:
        return True


def _get_verbose() -> bool:
    """Whether to emit detailed debug output (controlled by --verbose)."""
    from backend.app_context import get_verbose
    return get_verbose()


def _analyze_logits_gradient(
    query: str,
    text: str,
    tokenizer,
    model,
    device,
    submode_override: Optional[str] = None,
    progress_callback: Optional[Callable[[int, int, str, Optional[int]], None]] = None,
    debug_info: bool = False,
    full_match_degree_only: bool = False,
) -> Dict:
    """
    Gradient attribution: gradients of the logits w.r.t. the input embeddings.
    Sub-strategies: count / match_score (deprecated) / fill_blank, selected by --logits_gradient_submode.
    submode_override: optional override for evaluation, to test different submodes within one process.
    """
    TOTAL_STEPS = 4

    submode = submode_override if submode_override is not None else _get_logits_gradient_submode()
    max_length = get_semantic_max_token_length()

    if progress_callback:
        progress_callback(1, TOTAL_STEPS, "encoding", None)
    # The instruction depends on the submode (the prompt strings stay in Chinese:
    # the semantic model is prompted in Chinese).
    # "\n\n" separates the instruction from the document so the tokenizer does not
    # merge the first character with a space, which would corrupt offset_mapping.
    if submode == "count":
        instruction = f"请问下面文字中有多少个词与查询主题({query})相关?文字内容:\n\n"
    elif submode == "match_score":  # deprecated
        instruction = f"请问下面文字与查询主题({query})的相关程度是多少?请回答0/1/2(2为最高相关)。文字内容:\n\n"
    elif submode == "fill_blank":
        instruction = f"请问下面文字中哪个词与查询主题({query})最相关?如无相关词则回答“无”。文字内容:\n\n"
    else:
        raise ValueError(f"Unknown submode: {submode}")

    # Truncate text to max_length tokens before concatenating
    truncated_text = _truncate_text_by_tokens(tokenizer, text, max_length)

    messages = [{"role": "user", "content": instruction + truncated_text}]
    formatted = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True,
        enable_thinking=False
    )
    # Generation guide: the chat template only supports complete messages, so the
    # guide must be appended to `formatted`.
    if submode == "count":
        generation_guide = f"原文中与查询主题({query})相关的词的数量 = **"
    elif submode == "match_score":  # deprecated
        generation_guide = f"文章和查询主题({query})的相关程度(0-2)打分为:**"
    elif submode == "fill_blank":
        # The opening “ is deliberate: it keeps the model from generating its own quote mark
        generation_guide = f"原文中与查询主题({query})最相关的一个词是:**“"
    else:
        raise ValueError(f"Unknown submode: {submode}")
    formatted += generation_guide

    # Top-k for logits_gradient count/fill_blank; controls how many candidate
    # words the gradient target covers
    LOGITS_GRADIENT_TOPK = DEFAULT_NEXT_TOKEN_TOPK

    idx = formatted.find(instruction)
    instruction_start_char = idx if idx >= 0 else 0
    text_start_char = instruction_start_char + len(instruction)
    text_end_char = text_start_char + len(truncated_text)
    lines = truncated_text.splitlines()
    abbrev_text = truncated_text if len(lines) <= 2 else f"{lines[0]}\n...\n{lines[-1]}"
    abbrev = formatted[:text_start_char] + abbrev_text + formatted[text_end_char:]

    enc = tokenizer(
        formatted,
        return_tensors="pt",
        return_offsets_mapping=True,
    )

    input_ids = enc["input_ids"].to(device)
    offset_mapping = enc["offset_mapping"][0].tolist()

    prompt_end = len(offset_mapping)
    for i, (s, _) in enumerate(offset_mapping):
        if s >= text_start_char:
            prompt_end = i
            break

    embed_layer = model.get_input_embeddings()
    embeds = embed_layer(input_ids).detach().clone().requires_grad_(True)

    use_gc = _get_gradient_checkpointing()
    if _get_verbose():
        print(f"📌 logits_gradient: inference input (tokens={len(offset_mapping)}):\n{abbrev}")
    if progress_callback:
        progress_callback(2, TOTAL_STEPS, "inference", None)
    model.eval()
    if use_gc:
        model.gradient_checkpointing_enable()
    try:
        with torch.set_grad_enabled(not full_match_degree_only):
            outputs = model(
                inputs_embeds=embeds,
                output_attentions=False,
            )
            # Explicit sync to ensure completion, so progress_callback timing is accurate
            if device.type == "cuda":
                torch.cuda.synchronize(device)
            elif device.type == "mps":
                torch.mps.synchronize()

            logits = outputs.logits[:, -1, :]
            topk_vals, topk_ids = torch.topk(logits, LOGITS_GRADIENT_TOPK, dim=-1)
            probs = torch.softmax(logits, dim=-1)
            topk_tokens, topk_probs = decode_topk_ids_to_strings_and_rounded_probs(
                probs[0], tokenizer, topk_ids[0]
            )
            if _get_verbose():
                print(f"top{LOGITS_GRADIENT_TOPK}: {[f'{t}({p*100:.1f}%)' for t, p in zip(topk_tokens, topk_probs)]}")

            neg_token = "无" if submode == "fill_blank" else "0"
            neg_id = tokenizer.encode(neg_token, add_special_tokens=False)[0]
            # Full-text match degree: count/match_score (deprecated) use 1-P("0"),
            # fill_blank uses 1-P("无")
            p_neg = probs[0, neg_id].item()
            full_match_degree = round(1.0 - p_neg, 4)

            if full_match_degree_only:
                return {
                    "model": get_semantic_model_display_name(),
                    "token_attention": [],
                    "full_match_degree": full_match_degree,
                }

            if progress_callback:
                progress_callback(3, TOTAL_STEPS, "backward", None)
            # Attribution target: raw logits (no backward through softmax),
            # avoiding saturation and competition pollution.
            if submode == "count" or submode == "fill_blank":
                # count/fill_blank both use the top-10, probability-weighted Σ pᵢ·zᵢ,
                # excluding neg_token (0/无) to keep the gradient direction aligned with "relevant".
                vals = topk_vals[0]
                w = probs[0, topk_ids[0]].detach().clone()
                # Exclude neg_token
                w[topk_ids[0] == neg_id] = 0

                target_logit = (w * vals).sum()
            elif submode == "match_score":  # deprecated
                target_ids = tokenizer.encode("2", add_special_tokens=False)
                if not target_ids:
                    raise ValueError("tokenizer cannot encode '2'")
                target_logit = logits[0, target_ids[0]]
            else:
                raise ValueError(f"Unknown submode: {submode}")
            target_logit.backward()
            grad = embeds.grad
            if grad is None:
                raise RuntimeError("logits_gradient: gradient did not propagate; the model may not support it (e.g. int8 quantization)")

            # Explicit sync to ensure completion, so progress_callback timing is accurate
            if device.type == "cuda":
                torch.cuda.synchronize(device)
            elif device.type == "mps":
                torch.mps.synchronize()
            if progress_callback:
                progress_callback(4, TOTAL_STEPS, "processing", None)

            text_token_end = len(offset_mapping)
            # Compute every token's ‖∇f‖ on the GPU in one shot, avoiding 500
            # GPU→CPU syncs from per-iteration .item() calls
            grad_slice = grad[0, prompt_end:text_token_end].float()
            norms = grad_slice.norm(dim=-1).cpu().tolist()
            token_attention: List[Dict] = []
            nan_count = 0
            for i in range(prompt_end, text_token_end):
                s, e = offset_mapping[i]
                if s >= text_start_char and e <= text_end_char:
                    s_rel, e_rel = s - text_start_char, e - text_start_char
                    score = norms[i - prompt_end]
                    if not math.isfinite(score):
                        score = 0.0
                        nan_count += 1
                    else:
                        score = round_to_sig_figs(score)
                    token_attention.append({"offset": [s_rel, e_rel], "raw": truncated_text[s_rel:e_rel], "score": score})
            if nan_count > 0:
                print(f"⚠️ token_attention contains {nan_count} NaN/Inf scores, replaced with 0.")

            out = {
                "model": get_semantic_model_display_name(),
                "token_attention": token_attention,
                "full_match_degree": full_match_degree,
            }
            if debug_info:
                out["debug_info"] = {"abbrev": abbrev, "topk_tokens": topk_tokens, "topk_probs": topk_probs}
            return out
    finally:
        if use_gc:
            model.gradient_checkpointing_disable()
        # Clear after every inference: repeated calls otherwise accumulate
        # MPS/CUDA memory and can hang
        DeviceManager.clear_cache(device)


def analyze_semantic(
    query: str,
    text: str,
    submode_override: Optional[str] = None,
    progress_callback: Optional[Callable[[int, int, str, Optional[int]], None]] = None,
    debug_info: bool = False,
    full_match_degree_only: bool = False,
) -> Dict:
    """
    Analyze the relevance of each source-text token to the query (logits_gradient attribution).

    Args:
        query: query topic
        text: source text
        submode_override: optional submode override for evaluation (count / match_score deprecated / fill_blank)
        progress_callback: optional progress callback (step, total_steps, stage, percentage)
        debug_info: when True, the result includes debug_abbrev (abbreviated inference input);
            topk_tokens and topk_probs are always included

    Returns:
        {"model", "token_attention", "full_match_degree"}; includes a debug_info object when debug_info=True
    """
    tokenizer, model, device = ensure_semantic_slot_ready()
    return _analyze_logits_gradient(
        query, text, tokenizer, model, device,
        submode_override=submode_override,
        progress_callback=progress_callback,
        debug_info=debug_info,
        full_match_degree_only=full_match_degree_only,
    )
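
A self-contained toy sketch of the probability-weighted gradient target Σ pᵢ·zᵢ used by count/fill_blank (the 30-way logits are random stand-ins):

import torch

logits = torch.randn(30, requires_grad=True)   # stand-in next-token logits
probs = torch.softmax(logits, dim=-1)
topk_vals, topk_ids = torch.topk(logits, 10)
w = probs[topk_ids].detach().clone()           # weights p_i, treated as constants
w[topk_ids == 0] = 0                           # exclude a "negative" token id (cf. neg_id)
(w * topk_vals).sum().backward()               # target = sum_i p_i * z_i over the top-k
print(logits.grad[topk_ids])                   # equals w: each top-k logit gets weight p_i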
client/src/analysis.html
ADDED
@@ -0,0 +1,188 @@
<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <title></title>
    <meta name="description"
        content="Info Highlight visualizes token-level information density in text using LLMs, helping you quickly find key content and skip redundancy.">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <link rel="stylesheet" type="text/css" href="start.css">
</head>

<body>

    <main class="main_frame">
        <section class="left_panel">
            <div class="floating_content">
                <header class="app-page-toolbar app-page-toolbar--bleed">
                    <h1 class="page-toolbar-title"><span class="title-main-line"><span data-page-title data-i18n></span><span class="title-tagline" data-page-subtitle data-i18n></span></span></h1>
                    <div class="app-page-toolbar-actions">
                        <a href="index.html" class="home-link" title="InfoLens Home" data-i18n="text,title">InfoLens Home</a>
                        <a href="compare.html?showTextRender=1&demos=/quick-start-1.json,/quick-start-2.json" target="_blank" class="compare-link" style="display: none;" title="Compare analysis results" data-i18n="text,title">Compare results</a>
                        <div class="settings-menu-wrapper">
                            <button id="settings_btn" class="settings-btn" title="Settings" data-i18n="title">
                                <span class="settings-icon">⚙️</span>
                            </button>
                            <div id="settings_menu" class="settings-menu" style="display: none;">
                                <!-- INCLUDE partials/settings-menu-analysis.html -->
                                <!-- INCLUDE partials/settings-menu-common-mid.html -->
                                <!-- INCLUDE partials/settings-menu-trailing-admin.html -->
                            </div>
                        </div>
                    </div>
                </header>

                <!-- Home intro content container (loaded dynamically by JS) -->
                <section id="home-intro-content" class="intro-section">
                    <!-- Content will be loaded dynamically -->
                </section>

                <section class="demo-section">
                    <div class="demo-header">
                        <span id="demo_header_text" data-i18n>Quick start - select a demo:</span>
                        <button id="refresh_demo_btn" class="refresh-btn" title="Refresh demo list" data-i18n="title">↻</button>
                        <div class="file-input-wrapper">
                            <button id="open_local_demo_btn" class="file-input-button" type="button" title="Open demo file from local" data-i18n="text,title">Select local</button>
                            <span id="open_local_demo_filename" class="file-input-filename" data-i18n>No file
                                selected</span>
                            <input type="file" id="open_local_demo_input" style="display: none;"
                                accept=".json,application/json">
                        </div>
                        <span id="demos_loading" class="demos-loading" data-i18n>Refreshing...</span>
                    </div>
                    <div class="demos"></div>
                </section>

                <section class="input-section">
                    <div class="input-header">
                        <span id="input_header_text"><span class="demo" data-i18n>or enter text:</span></span>
                        <div class="text-action-buttons-top">
                            <div class="textarea-counter" id="text_count_display">
                                <span id="text_count_value">0</span> <span id="char_unit" data-i18n>chars</span>
                            </div>
                            <button id="clear_text_btn" class="text-action-btn" data-i18n>Clear</button>
                            <button id="paste_text_btn" class="text-action-btn" data-i18n>Paste</button>
                            <button id="load_url_btn" class="text-action-btn" title="Load text from URL and analyze"
                                data-i18n="text,title">Analyze URL</button>
                            <button id="analyze_save_btn" class="text-action-btn" data-i18n>Analyze&Upload</button>
                        </div>
                    </div>
                    <div class="textarea-wrapper">
                        <textarea id="test_text"></textarea>
                        <div class="button-group">
                            <div class="button-left">
                                <button id="submit_text_btn" class="primary-btn" data-i18n>Analyze</button>
                                <div class="loadersmall loader-small-container"></div>
                                <span id="analyze_progress" class="analyze-progress"></span>
                            </div>
                            <div id="text_metrics" class="text-metrics">
                                <div class="text-metrics-primary">
                                    <span id="metric_bytes">0 B</span>
                                    <span class="text-metrics-divider">|</span>
                                    <span id="metric_chars">0 chars</span>
                                    <span class="text-metrics-divider">|</span>
                                    <span id="metric_tokens">0 tokens</span>
                                </div>
                                <div id="metric_total_surprisal" class="text-metrics-secondary">total information = 0
                                    bits</div>
                                <div id="metric_model" class="text-metrics-secondary">model: </div>
                            </div>
                            <div class="button-right">
                                <button id="save_demo_btn" class="primary-btn inactive" data-i18n>Upload</button>
                                <button id="save_local_demo_btn" class="primary-btn inactive" title="Save to local file"
                                    data-i18n="text,title">Save</button>
                            </div>
                        </div>
                    </div>
                </section>

                <section id="semantic_analysis_section" class="semantic-analysis-section" style="display: none;">
                    <div class="semantic-analysis-controls">
                        <div class="semantic-search-row">
                            <div class="semantic-search-input-wrapper">
                                <input type="text" id="semantic_search_input" class="semantic-search-input" placeholder="Enter query for semantic analysis">
                                <button type="button" id="semantic_search_clear" class="semantic-search-clear demo-delete-btn" title="Clear" aria-label="Clear" data-i18n="title,aria-label">×</button>
                                <ul id="semantic_search_history_dropdown" class="semantic-search-history-dropdown"></ul>
                            </div>
                            <div class="semantic-search-actions">
                                <button id="semantic_search_btn" class="primary-btn" data-i18n>Search</button>
                                <span id="semantic_match_degree" class="semantic-match-degree" style="display: none;"></span>
                                <div id="semantic_search_loader" class="semantic-search-loader" style="visibility: hidden;"></div>
                                <span id="semantic_progress" class="semantic-progress"></span>
                            </div>
                        </div>
                        <div id="semantic_submode_row" class="semantic-submode-row" data-admin-only style="display: none;">
                            <span class="semantic-submode-group">
                                <label><input type="checkbox" id="semantic_chunked_mode" title="analyse in chunks" checked> chunked</label>
                            </span>
                            <span class="semantic-submode-group">
                                <label class="semantic-submode-label" for="semantic_submode_select">submode: </label>
                                <select id="semantic_submode_select" class="semantic-submode-select">
                                    <option value="count">count</option>
                                    <option value="fill_blank">fill_blank</option>
                                    <option value="hybrid" selected>hybrid</option>
                                </select>
                            </span>
                            <span id="semantic_threshold_item" class="semantic-submode-group" data-admin-only style="display: none;">
                                <label class="semantic-submode-label" for="semantic_threshold_input">Match threshold:</label>
                                <input type="number" id="semantic_threshold_input" class="semantic-threshold-input" min="0" max="1">
                            </span>
                            <span class="semantic-submode-group semantic-submode-group-right">
                                <label class="semantic-submode-label" for="semantic_color_source_select">color source: </label>
                                <select id="semantic_color_source_select" class="semantic-submode-select">
                                    <option value="raw_score_normed">raw score normed</option>
                                    <option value="signal_probability">signal probability</option>
                                    <option value="pw_score" selected>pw score</option>
                                </select>
                            </span>
                        </div>
                    </div>
                </section>
            </div>

            <section id="all_result" class="results-section">
                <div id="stats" class="stats-container">
                    <div id="match_score_progress_item" class="histogram-item" style="display: none;">
                        <div id="match_score_progress_title"></div>
                        <svg id="stats_match_score_progress"></svg>
                    </div>
                    <div id="raw_score_normed_histogram_item" class="histogram-item" style="display: none;">
                        <div id="raw_score_normed_histogram_title"></div>
                        <svg id="stats_raw_score_normed"></svg>
+
</div>
|
| 154 |
+
<div id="token_histogram_item" class="histogram-item" style="display: none;">
|
| 155 |
+
<div id="token_histogram_title"></div>
|
| 156 |
+
<svg id="stats_frac"></svg>
|
| 157 |
+
</div>
|
| 158 |
+
<div id="surprisal_progress_item" class="histogram-item" style="display: none;">
|
| 159 |
+
<div id="surprisal_progress_title"></div>
|
| 160 |
+
<svg id="stats_surprisal_progress"></svg>
|
| 161 |
+
</div>
|
| 162 |
+
</div>
|
| 163 |
+
</section>
|
| 164 |
+
</section>
|
| 165 |
+
|
| 166 |
+
<div class="resizer" id="resizer"></div>
|
| 167 |
+
|
| 168 |
+
<section class="right_panel">
|
| 169 |
+
<div id="results">
|
| 170 |
+
<div id="major_tooltip" class="tooltip">
|
| 171 |
+
<div class="currentToken"></div>
|
| 172 |
+
<div class="myDetail"></div>
|
| 173 |
+
<br />
|
| 174 |
+
<div class="predictions predictions-table"></div>
|
| 175 |
+
</div>
|
| 176 |
+
</div>
|
| 177 |
+
</section>
|
| 178 |
+
</main>
|
| 179 |
+
|
| 180 |
+
<div id="toast" class="toast"></div>
|
| 181 |
+
|
| 182 |
+
<!-- INCLUDE partials/attribution-sidebar.html -->
|
| 183 |
+
|
| 184 |
+
<script src="vendor.js"></script>
|
| 185 |
+
<script src="start.js"></script>
|
| 186 |
+
</body>
|
| 187 |
+
|
| 188 |
+
</html>
|
client/src/attribution.html
ADDED
|
@@ -0,0 +1,166 @@
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
|
| 4 |
+
<head>
|
| 5 |
+
<meta charset="UTF-8">
|
| 6 |
+
<title></title>
|
| 7 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 8 |
+
<link rel="stylesheet" type="text/css" href="attribution.css">
|
| 9 |
+
</head>
|
| 10 |
+
|
| 11 |
+
<body>
|
| 12 |
+
|
| 13 |
+
<main class="main_frame">
|
| 14 |
+
<section class="left_panel">
|
| 15 |
+
<div class="floating_content">
|
| 16 |
+
<header class="app-page-toolbar app-page-toolbar--bleed">
|
| 17 |
+
<h1 class="page-toolbar-title"><span class="title-main-line"><span data-page-title data-i18n></span><span class="title-tagline" data-page-subtitle data-i18n></span></span></h1>
|
| 18 |
+
<div class="app-page-toolbar-actions">
|
| 19 |
+
<a href="index.html" class="home-link" title="InfoLens Home" data-i18n="text,title">InfoLens Home</a>
|
| 20 |
+
<div class="settings-menu-wrapper">
|
| 21 |
+
<button id="settings_btn" class="settings-btn" title="Settings" data-i18n="title">
|
| 22 |
+
<span class="settings-icon">⚙️</span>
|
| 23 |
+
</button>
|
| 24 |
+
<div id="settings_menu" class="settings-menu" style="display: none;">
|
| 25 |
+
<!-- INCLUDE partials/settings-menu-common-mid.html -->
|
| 26 |
+
<!-- INCLUDE partials/settings-menu-trailing-admin.html -->
|
| 27 |
+
</div>
|
| 28 |
+
</div>
|
| 29 |
+
</div>
|
| 30 |
+
</header>
|
| 31 |
+
|
| 32 |
+
<div class="chat-cached-history-bar">
|
| 33 |
+
<div class="semantic-search-input-wrapper chat-prompt-history-wrapper">
|
| 34 |
+
<button type="button" id="attribution_cached_history_btn" class="text-action-btn" data-i18n>Cached history</button>
|
| 35 |
+
<ul id="attribution_cached_history_dropdown" class="semantic-search-history-dropdown"></ul>
|
| 36 |
+
</div>
|
| 37 |
+
</div>
|
| 38 |
+
|
| 39 |
+
<section class="input-section">
|
| 40 |
+
<div class="chat-prompt-panel">
|
| 41 |
+
<div class="input-header">
|
| 42 |
+
<span data-i18n>Context</span>
|
| 43 |
+
<div class="text-action-buttons-top">
|
| 44 |
+
<div class="textarea-counter" id="context_count_display">
|
| 45 |
+
<span id="context_count_value">0</span> <span data-i18n>chars</span>
|
| 46 |
+
</div>
|
| 47 |
+
<button type="button" id="clear_context_btn" class="text-action-btn" data-i18n>Clear</button>
|
| 48 |
+
<button type="button" id="paste_context_btn" class="text-action-btn" data-i18n>Paste</button>
|
| 49 |
+
<button type="button" id="context_history_btn" class="text-action-btn" data-i18n>History</button>
|
| 50 |
+
</div>
|
| 51 |
+
</div>
|
| 52 |
+
<div class="textarea-wrapper chat-prompt-textarea-block">
|
| 53 |
+
<div class="semantic-search-input-wrapper chat-prompt-history-wrapper">
|
| 54 |
+
<textarea id="context_text"></textarea>
|
| 55 |
+
<ul id="context_history_dropdown" class="semantic-search-history-dropdown"></ul>
|
| 56 |
+
</div>
|
| 57 |
+
</div>
|
| 58 |
+
</div>
|
| 59 |
+
|
| 60 |
+
<div class="chat-prompt-panel attribution-target-panel">
|
| 61 |
+
<div class="input-header">
|
| 62 |
+
<span data-i18n>Target prediction</span>
|
| 63 |
+
<div class="text-action-buttons-top">
|
| 64 |
+
<div class="textarea-counter" id="target_count_display">
|
| 65 |
+
<span id="target_count_value">0</span> <span data-i18n>chars</span>
|
| 66 |
+
</div>
|
| 67 |
+
<button type="button" id="clear_target_btn" class="text-action-btn" data-i18n>Clear</button>
|
| 68 |
+
<button type="button" id="paste_target_btn" class="text-action-btn" data-i18n>Paste</button>
|
| 69 |
+
<button type="button" id="target_history_btn" class="text-action-btn" data-i18n>History</button>
|
| 70 |
+
</div>
|
| 71 |
+
</div>
|
| 72 |
+
<div class="textarea-wrapper chat-prompt-textarea-block">
|
| 73 |
+
<div class="semantic-search-input-wrapper chat-prompt-history-wrapper">
|
| 74 |
+
<textarea id="target_text"></textarea>
|
| 75 |
+
<ul id="target_history_dropdown" class="semantic-search-history-dropdown"></ul>
|
| 76 |
+
</div>
|
| 77 |
+
</div>
|
| 78 |
+
</div>
|
| 79 |
+
|
| 80 |
+
<div class="textarea-wrapper chat-prompt-actions-row">
|
| 81 |
+
<div class="semantic-submode-row chat-completion-options-row attribution-model-variant-row">
|
| 82 |
+
<span class="semantic-submode-group">
|
| 83 |
+
<label class="semantic-submode-label" for="attribution_model_variant" data-i18n>Model</label>
|
| 84 |
+
<select id="attribution_model_variant" class="semantic-submode-select" aria-label="Attribution model slot" data-i18n="aria-label">
|
| 85 |
+
<option value="base">base</option>
|
| 86 |
+
<option value="instruct">instruct</option>
|
| 87 |
+
</select>
|
| 88 |
+
</span>
|
| 89 |
+
</div>
|
| 90 |
+
<div class="button-group">
|
| 91 |
+
<div class="button-left">
|
| 92 |
+
<button type="button" id="analyze_btn" class="primary-btn inactive" disabled data-i18n>Analyze attribution</button>
|
| 93 |
+
<div class="loadersmall loader-small-container"></div>
|
| 94 |
+
</div>
|
| 95 |
+
<div id="attribution_result_info" class="text-metrics is-hidden"></div>
|
| 96 |
+
<div class="button-right">
|
| 97 |
+
<button type="button" id="force_retry_btn" class="primary-btn inactive" disabled title="Fetch again without using cached result" data-i18n="text,title">Force retry</button>
|
| 98 |
+
</div>
|
| 99 |
+
</div>
|
| 100 |
+
</div>
|
| 101 |
+
|
| 102 |
+
<div class="semantic-submode-row attribution-max-score-row">
|
| 103 |
+
<span class="semantic-submode-group">
|
| 104 |
+
<label class="attribution-use-mapping-label">
|
| 105 |
+
<input type="checkbox" id="attribution_use_mapping">
|
| 106 |
+
<span></span>
|
| 107 |
+
</label>
|
| 108 |
+
</span>
|
| 109 |
+
<span class="semantic-submode-group attribution-max-score-slider-group">
|
| 110 |
+
<label class="semantic-submode-label" for="attribution_max_score_range" data-i18n>Max score</label>
|
| 111 |
+
<input type="range" id="attribution_max_score_range" class="attribution-max-score-range"
|
| 112 |
+
min="0.01" max="1" step="0.01" value="1"
|
| 113 |
+
title="For threshold x∈(0,1]: map normalized scores in [0,x] linearly to display intensities [0,1]; scores above x saturate at maximum intensity. At x=1, equivalent to disabling mapping."
|
| 114 |
+
data-i18n="title"
|
| 115 |
+
disabled>
|
| 116 |
+
<span id="attribution_max_score_value" class="attribution-max-score-value" aria-live="polite">1.00</span>
|
| 117 |
+
</span>
|
| 118 |
+
</div>
|
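The slider tooltip above specifies a saturating linear mapping. A minimal sketch of that rule in Python, assuming scores are already normalized to [0, 1]; the function name is illustrative, not taken from the app's code:

```python
def display_intensity(score: float, max_score: float) -> float:
    """Saturating linear map described by the Max score slider tooltip.

    For a threshold max_score = x in (0, 1], scores in [0, x] map linearly
    to display intensities [0, 1]; scores above x saturate at 1.0.
    At x = 1 the map is the identity, i.e. equivalent to disabling it.
    """
    if not 0.0 < max_score <= 1.0:
        raise ValueError("max_score must lie in (0, 1]")
    return min(score / max_score, 1.0)

# With max_score = 0.5, a score of 0.25 renders at half intensity
assert display_intensity(0.25, 0.5) == 0.5
assert display_intensity(0.80, 0.5) == 1.0
```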
| 119 |
+
|
| 120 |
+
<div class="attribution-exclude-prompt-patterns-row">
|
| 121 |
+
<div class="semantic-submode-row attribution-exclude-prompt-patterns-header">
|
| 122 |
+
<span class="semantic-submode-group">
|
| 123 |
+
<label class="attribution-use-mapping-label"
|
| 124 |
+
title="When enabled, each line below is a regex with the global flag, matched only within the context field below. If a token offset lies fully inside a match, its score is treated as 0."
|
| 125 |
+
data-i18n="title">
|
| 126 |
+
<input type="checkbox" id="attribution_exclude_prompt_patterns_enable" checked>
|
| 127 |
+
<span></span>
|
| 128 |
+
</label>
|
| 129 |
+
</span>
|
| 130 |
+
<span class="semantic-submode-group">
|
| 131 |
+
<label class="semantic-submode-label" for="attribution_exclude_prompt_patterns" data-i18n>Exclude prompt patterns</label>
|
| 132 |
+
</span>
|
| 133 |
+
</div>
|
| 134 |
+
<textarea id="attribution_exclude_prompt_patterns" class="attribution-exclude-prompt-patterns-input" rows="2"
|
| 135 |
+
placeholder="One regex per line (context only)"
|
| 136 |
+
spellcheck="false"
|
| 137 |
+
autocomplete="off"
|
| 138 |
+
title="One regex per line (global flag), matched only within the context text; if a token offset lies fully inside a match, its score is treated as 0."
|
| 139 |
+
data-i18n="placeholder,title"></textarea>
|
| 140 |
+
</div>
|
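The two tooltips above describe the same zeroing rule. A minimal sketch of it, assuming tokens arrive as (start, end) character offsets into the context text; `re.finditer` plays the role of a JS global-flag regex here, and all names are illustrative rather than taken from the actual frontend:

```python
import re

def zero_excluded_scores(context: str, patterns: list[str],
                         offsets: list[tuple[int, int]],
                         scores: list[float]) -> list[float]:
    """Zero the score of any token whose offset lies fully inside a match.

    Each non-empty line of the patterns textarea is one regex; patterns are
    matched only against the context text, never the target prediction.
    """
    spans = [(m.start(), m.end())
             for pat in patterns
             for m in re.finditer(pat, context)]
    return [0.0 if any(s <= start and end <= e for s, e in spans) else score
            for (start, end), score in zip(offsets, scores)]
```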
| 141 |
+
</section>
|
| 142 |
+
</div>
|
| 143 |
+
</section>
|
| 144 |
+
|
| 145 |
+
<div class="resizer" id="resizer"></div>
|
| 146 |
+
|
| 147 |
+
<section class="right_panel">
|
| 148 |
+
<div id="results" class="attribution-inspector-surface">
|
| 149 |
+
<div id="major_tooltip" class="tooltip">
|
| 150 |
+
<div class="currentToken"></div>
|
| 151 |
+
<div class="myDetail"></div>
|
| 152 |
+
<br />
|
| 153 |
+
<div class="predictions predictions-table"></div>
|
| 154 |
+
</div>
|
| 155 |
+
</div>
|
| 156 |
+
</section>
|
| 157 |
+
</main>
|
| 158 |
+
|
| 159 |
+
<div id="toast" class="toast"></div>
|
| 160 |
+
|
| 161 |
+
<script src="vendor.js"></script>
|
| 162 |
+
<script src="attribution.js"></script>
|
| 163 |
+
|
| 164 |
+
</body>
|
| 165 |
+
|
| 166 |
+
</html>
|
client/src/chat.html
ADDED
|
@@ -0,0 +1,171 @@
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
|
| 4 |
+
<head>
|
| 5 |
+
<meta charset="UTF-8">
|
| 6 |
+
<title></title>
|
| 7 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 8 |
+
<link rel="stylesheet" type="text/css" href="chat.css">
|
| 9 |
+
</head>
|
| 10 |
+
|
| 11 |
+
<body>
|
| 12 |
+
|
| 13 |
+
<main class="main_frame">
|
| 14 |
+
<section class="left_panel">
|
| 15 |
+
<div class="floating_content">
|
| 16 |
+
<header class="app-page-toolbar app-page-toolbar--bleed">
|
| 17 |
+
<h1 class="page-toolbar-title"><span class="title-main-line"><span data-page-title data-i18n></span><span class="title-tagline" data-page-subtitle data-i18n></span></span></h1>
|
| 18 |
+
<div class="app-page-toolbar-actions">
|
| 19 |
+
<a href="index.html" class="home-link" title="InfoLens Home" data-i18n="text,title">InfoLens Home</a>
|
| 20 |
+
<div class="settings-menu-wrapper">
|
| 21 |
+
<button id="settings_btn" class="settings-btn" title="Settings" data-i18n="title">
|
| 22 |
+
<span class="settings-icon">⚙️</span>
|
| 23 |
+
</button>
|
| 24 |
+
<div id="settings_menu" class="settings-menu" style="display: none;">
|
| 25 |
+
<!-- INCLUDE partials/settings-menu-common-mid.html -->
|
| 26 |
+
<!-- INCLUDE partials/settings-menu-trailing-admin.html -->
|
| 27 |
+
</div>
|
| 28 |
+
</div>
|
| 29 |
+
</div>
|
| 30 |
+
</header>
|
| 31 |
+
|
| 32 |
+
<div class="chat-cached-history-bar">
|
| 33 |
+
<div class="semantic-search-input-wrapper chat-prompt-history-wrapper">
|
| 34 |
+
<button type="button" id="chat_cached_history_btn" class="text-action-btn" data-i18n>Cached history</button>
|
| 35 |
+
<ul id="chat_cached_history_dropdown" class="semantic-search-history-dropdown"></ul>
|
| 36 |
+
</div>
|
| 37 |
+
</div>
|
| 38 |
+
|
| 39 |
+
<section class="input-section">
|
| 40 |
+
<div class="semantic-submode-row chat-raw-prompt-mode-row">
|
| 41 |
+
<span class="semantic-submode-group">
|
| 42 |
+
<label for="chat_skip_chat_template">
|
| 43 |
+
<input type="checkbox" id="chat_skip_chat_template" />
|
| 44 |
+
<span data-i18n>Raw prompt mode</span>
|
| 45 |
+
</label>
|
| 46 |
+
</span>
|
| 47 |
+
</div>
|
| 48 |
+
<div id="raw_input_panel" class="chat-prompt-panel">
|
| 49 |
+
<div class="input-header">
|
| 50 |
+
<span><span class="demo" data-i18n>Raw prompt</span></span>
|
| 51 |
+
<div class="text-action-buttons-top">
|
| 52 |
+
<div class="textarea-counter" id="text_count_display">
|
| 53 |
+
<span id="text_count_value">0</span> <span data-i18n>chars</span>
|
| 54 |
+
</div>
|
| 55 |
+
<button type="button" id="clear_text_btn" class="text-action-btn">Clear</button>
|
| 56 |
+
<button type="button" id="paste_text_btn" class="text-action-btn">Paste</button>
|
| 57 |
+
<button type="button" id="chat_raw_input_history_btn" class="text-action-btn" data-i18n>History</button>
|
| 58 |
+
</div>
|
| 59 |
+
</div>
|
| 60 |
+
<div class="textarea-wrapper chat-prompt-textarea-block">
|
| 61 |
+
<div class="semantic-search-input-wrapper chat-prompt-history-wrapper">
|
| 62 |
+
<textarea id="test_text"></textarea>
|
| 63 |
+
<ul id="chat_raw_input_history_dropdown" class="semantic-search-history-dropdown"></ul>
|
| 64 |
+
</div>
|
| 65 |
+
</div>
|
| 66 |
+
</div>
|
| 67 |
+
<div id="chat_input_panel" hidden>
|
| 68 |
+
<div class="chat-prompt-panel" id="chat_system_prompt_panel">
|
| 69 |
+
<div class="input-header">
|
| 70 |
+
<label class="chat-use-system-label">
|
| 71 |
+
<input type="checkbox" id="chat_use_system_prompt" checked />
|
| 72 |
+
<span class="demo" data-i18n>System</span>
|
| 73 |
+
</label>
|
| 74 |
+
<div class="text-action-buttons-top">
|
| 75 |
+
<div class="textarea-counter" id="chat_system_text_count_display">
|
| 76 |
+
<span id="chat_system_text_count_value">0</span> <span data-i18n>chars</span>
|
| 77 |
+
</div>
|
| 78 |
+
<button type="button" id="chat_system_clear_text_btn" class="text-action-btn">Clear</button>
|
| 79 |
+
<button type="button" id="chat_system_paste_text_btn" class="text-action-btn">Paste</button>
|
| 80 |
+
<button type="button" id="chat_system_prompt_history_btn" class="text-action-btn">History</button>
|
| 81 |
+
</div>
|
| 82 |
+
</div>
|
| 83 |
+
<div class="textarea-wrapper chat-prompt-textarea-block">
|
| 84 |
+
<div class="semantic-search-input-wrapper chat-prompt-history-wrapper">
|
| 85 |
+
<textarea id="chat_system_text">You are a helpful assistant.</textarea>
|
| 86 |
+
<ul id="chat_system_prompt_history_dropdown" class="semantic-search-history-dropdown"></ul>
|
| 87 |
+
</div>
|
| 88 |
+
</div>
|
| 89 |
+
</div>
|
| 90 |
+
<div class="chat-prompt-panel">
|
| 91 |
+
<div class="input-header">
|
| 92 |
+
<span><span class="demo" data-i18n>User</span></span>
|
| 93 |
+
<div class="text-action-buttons-top">
|
| 94 |
+
<div class="textarea-counter" id="chat_user_text_count_display">
|
| 95 |
+
<span id="chat_user_text_count_value">0</span> <span data-i18n>chars</span>
|
| 96 |
+
</div>
|
| 97 |
+
<button type="button" id="chat_user_clear_text_btn" class="text-action-btn">Clear</button>
|
| 98 |
+
<button type="button" id="chat_user_paste_text_btn" class="text-action-btn">Paste</button>
|
| 99 |
+
<button type="button" id="chat_user_prompt_history_btn" class="text-action-btn">History</button>
|
| 100 |
+
</div>
|
| 101 |
+
</div>
|
| 102 |
+
<div class="textarea-wrapper chat-prompt-textarea-block">
|
| 103 |
+
<div class="semantic-search-input-wrapper chat-prompt-history-wrapper">
|
| 104 |
+
<textarea id="chat_user_text"></textarea>
|
| 105 |
+
<ul id="chat_user_prompt_history_dropdown" class="semantic-search-history-dropdown"></ul>
|
| 106 |
+
</div>
|
| 107 |
+
</div>
|
| 108 |
+
</div>
|
| 109 |
+
</div>
|
| 110 |
+
<div class="textarea-wrapper chat-prompt-actions-row">
|
| 111 |
+
<div class="semantic-submode-row chat-completion-options-row">
|
| 112 |
+
<span class="semantic-submode-group">
|
| 113 |
+
<label class="chat-max-new-tokens-label" for="chat_max_new_tokens">
|
| 114 |
+
<span class="semantic-submode-label" data-i18n>Max new tokens:</span>
|
| 115 |
+
<input type="text" id="chat_max_new_tokens" class="semantic-threshold-input chat-max-new-tokens-input" inputmode="numeric" autocomplete="off" />
|
| 116 |
+
</label>
|
| 117 |
+
</span>
|
| 118 |
+
</div>
|
| 119 |
+
<div class="button-group">
|
| 120 |
+
<div class="button-left">
|
| 121 |
+
<button type="button" id="submit_text_btn" class="primary-btn inactive" disabled data-i18n>Ask</button>
|
| 122 |
+
<div class="generation-status-slot loader-small-container">
|
| 123 |
+
<div class="loadersmall"></div>
|
| 124 |
+
<span id="chat_complete_reason" class="generation-end-reason"></span>
|
| 125 |
+
</div>
|
| 126 |
+
<span id="analyze_progress" class="analyze-progress"></span>
|
| 127 |
+
</div>
|
| 128 |
+
<div id="text_metrics" class="text-metrics text-metrics-chat">
|
| 129 |
+
<div id="metric_usage" class="text-metrics-secondary"></div>
|
| 130 |
+
<div id="metric_model" class="text-metrics-secondary">model: </div>
|
| 131 |
+
</div>
|
| 132 |
+
<div class="button-right">
|
| 133 |
+
<button type="button" id="force_retry_btn" class="primary-btn inactive" disabled title="Fetch again without using cached result" data-i18n="text,title">Force retry</button>
|
| 134 |
+
</div>
|
| 135 |
+
</div>
|
| 136 |
+
</div>
|
| 137 |
+
</section>
|
| 138 |
+
</div>
|
| 139 |
+
</section>
|
| 140 |
+
|
| 141 |
+
<div class="resizer" id="resizer"></div>
|
| 142 |
+
|
| 143 |
+
<section class="right_panel">
|
| 144 |
+
<div class="chat-right-stack">
|
| 145 |
+
<div id="chat_prompt_used" class="chat-prompt-used truncated-text" hidden></div>
|
| 146 |
+
<div id="chat_streaming_preview" class="chat-streaming-preview" hidden></div>
|
| 147 |
+
<div id="results">
|
| 148 |
+
<div id="major_tooltip" class="tooltip">
|
| 149 |
+
<div class="currentToken"></div>
|
| 150 |
+
<div class="myDetail"></div>
|
| 151 |
+
<br />
|
| 152 |
+
<div class="predictions predictions-table"></div>
|
| 153 |
+
</div>
|
| 154 |
+
</div>
|
| 155 |
+
<div class="chat-copy-fulltext-row">
|
| 156 |
+
<button type="button" id="chat_copy_fulltext_btn" class="text-action-btn">Copy full text</button>
|
| 157 |
+
</div>
|
| 158 |
+
</div>
|
| 159 |
+
</section>
|
| 160 |
+
</main>
|
| 161 |
+
|
| 162 |
+
<div id="toast" class="toast"></div>
|
| 163 |
+
|
| 164 |
+
<!-- INCLUDE partials/attribution-sidebar.html -->
|
| 165 |
+
|
| 166 |
+
<script src="vendor.js"></script>
|
| 167 |
+
<script src="chat.js"></script>
|
| 168 |
+
|
| 169 |
+
</body>
|
| 170 |
+
|
| 171 |
+
</html>
|
client/src/compare.html
ADDED
|
@@ -0,0 +1,69 @@
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
|
| 4 |
+
<head>
|
| 5 |
+
<meta charset="UTF-8">
|
| 6 |
+
<title>Info Highlight / Compare</title>
|
| 7 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 8 |
+
<link rel="stylesheet" type="text/css" href="compare.css">
|
| 9 |
+
<!--<link rel="stylesheet" type="text/css" href="vendor.css">-->
|
| 10 |
+
|
| 11 |
+
</head>
|
| 12 |
+
|
| 13 |
+
<body>
|
| 14 |
+
|
| 15 |
+
<div class="main_frame">
|
| 16 |
+
<!-- Grid container wrapping the toolbar and content area -->
|
| 17 |
+
<div class="compare-wrapper">
|
| 18 |
+
<!-- Top toolbar -->
|
| 19 |
+
<div class="app-page-toolbar">
|
| 20 |
+
<h1 class="compare-toolbar-title">
|
| 21 |
+
<span class="compare-toolbar-title-app" data-i18n>Info Highlight</span>
|
| 22 |
+
<span class="compare-toolbar-title-sep" aria-hidden="true">/</span>
|
| 23 |
+
<span class="compare-toolbar-title-page" data-i18n>Compare</span>
|
| 24 |
+
</h1>
|
| 25 |
+
<div class="app-page-toolbar-actions">
|
| 26 |
+
<label style="display: flex; align-items: center; gap: 5px; cursor: pointer;">
|
| 27 |
+
<input type="checkbox" id="show_text_render_toggle">
|
| 28 |
+
<span data-i18n>Show Text Rendering</span>
|
| 29 |
+
</label>
|
| 30 |
+
<label style="display: flex; align-items: center; gap: 5px; cursor: pointer;">
|
| 31 |
+
<input type="checkbox" id="model_diff_mode_toggle">
|
| 32 |
+
<span data-i18n>Diff Mode</span>
|
| 33 |
+
</label>
|
| 34 |
+
<button id="edit_mode_toggle" data-i18n>Edit</button>
|
| 35 |
+
<button id="clear_demos_btn" data-i18n>Clear</button>
|
| 36 |
+
<button id="add_demos_btn" data-i18n>Add</button>
|
| 37 |
+
</div>
|
| 38 |
+
</div>
|
| 39 |
+
|
| 40 |
+
<!-- Comparison results area -->
|
| 41 |
+
<div id="compare-container" class="compare-container">
|
| 42 |
+
<!-- Empty-state hint (shown/hidden automatically) -->
|
| 43 |
+
<div class="compare-empty-state">
|
| 44 |
+
<div class="empty-icon">📊</div>
|
| 45 |
+
<div class="empty-title" data-i18n>No comparison data</div>
|
| 46 |
+
</div>
|
| 47 |
+
<!-- Demo columns are created dynamically via JavaScript -->
|
| 48 |
+
</div>
|
| 49 |
+
</div>
|
| 50 |
+
</div>
|
| 51 |
+
|
| 52 |
+
<!-- Toast notification container -->
|
| 53 |
+
<div id="toast" class="toast"></div>
|
| 54 |
+
|
| 55 |
+
<!-- Global tooltip container (for compare mode) -->
|
| 56 |
+
<div id="global_tooltip" class="tooltip">
|
| 57 |
+
<div class="currentToken"></div>
|
| 58 |
+
<div class="myDetail"></div>
|
| 59 |
+
<br />
|
| 60 |
+
<div class="predictions predictions-table"></div>
|
| 61 |
+
</div>
|
| 62 |
+
|
| 63 |
+
<script src="vendor.js"></script>
|
| 64 |
+
<script src="compare.js"></script>
|
| 65 |
+
|
| 66 |
+
</body>
|
| 67 |
+
|
| 68 |
+
</html>
|
| 69 |
+
|
client/src/content/home.en.html
ADDED
|
@@ -0,0 +1,91 @@
|
| 1 |
+
<!-- Intro / Hero (always visible) -->
|
| 2 |
+
<div class="intro-brief" style="--intro-rgb: 255, 71, 64">
|
| 3 |
+
<span class="intro-token" style="--a:0.56">Want</span><span class="intro-token" style="--a:0.53"> key</span><span class="intro-token" style="--a:0.29"> points</span><span class="intro-token" style="--a:0.31"> at</span><span class="intro-token" style="--a:0.09"> a</span><span class="intro-token" style="--a:0.00"> glance</span><span class="intro-token" style="--a:0.04">?</span><span class="intro-token" style="--a:0.26"> Or</span><span class="intro-token" style="--a:0.31"> simply</span><span class="intro-token" style="--a:0.29"> curious</span><span class="intro-token" style="--a:0.03"> about</span><span class="intro-token" style="--a:0.08"> the</span><span class="intro-token" style="--a:0.33"> information</span><span class="intro-token" style="--a:0.68">-the</span><span class="intro-token" style="--a:0.02">oret</span><span class="intro-token" style="--a:0.00">ic</span><span class="intro-token" style="--a:0.29"> nature</span><span class="intro-token" style="--a:0.00"> of</span><span class="intro-token" style="--a:0.31"> language</span><span class="intro-token" style="--a:0.19">?</span><br><br><span class="intro-token" style="--a:0.32">Try</span><span class="intro-token" style="--a:0.47"> Info</span><span class="intro-token" style="--a:0.70"> Highlight</span><span class="intro-token" style="--a:0.17">.</span><span class="intro-token" style="--a:0.06"> It</span><span class="intro-token" style="--a:0.25"> uses</span><span class="intro-token" style="--a:0.34"> large</span><span class="intro-token" style="--a:0.02"> language</span><span class="intro-token" style="--a:0.00"> models</span><span class="intro-token" style="--a:0.02"> to</span><span class="intro-token" style="--a:0.23"> analyze</span><span class="intro-token" style="--a:0.14"> text</span><span class="intro-token" style="--a:0.37"> information</span><span class="intro-token" style="--a:0.19"> density</span><span class="intro-token" style="--a:0.05"> and</span><span class="intro-token" style="--a:0.34"> visual</span><span class="intro-token" style="--a:0.01">izes</span><span class="intro-token" style="--a:0.39"> where</span><span class="intro-token" style="--a:0.08"> the</span><span class="intro-token" style="--a:0.26"> important</span><span class="intro-token" style="--a:0.13"> parts</span><span class="intro-token" style="--a:0.05"> are</span><span class="intro-token" style="--a:0.08">.</span><br><br><span class="intro-token" style="--a:0.17">The</span><span class="intro-token" style="--a:0.40"> color</span><span class="intro-token" style="--a:0.17"> intensity</span><span class="intro-token" style="--a:0.07"> of</span><span class="intro-token" style="--a:0.06"> each</span><span class="intro-token" style="--a:0.27"> token</span><span class="intro-token" style="--a:0.10"> indicates</span><span class="intro-token" style="--a:0.07"> how</span><span class="intro-token" style="--a:0.04"> much</span><span class="intro-token" style="--a:0.03"> information</span><span class="intro-token" style="--a:0.03"> it</span><span class="intro-token" style="--a:0.09"> carries</span><span class="intro-token" style="--a:0.04">.</span><span class="intro-token" style="--a:0.39"> Try</span><span class="intro-token" style="--a:0.04"> it</span><span class="intro-token" style="--a:0.12"> yourself</span><span class="intro-token" style="--a:0.21">!</span>
|
| 4 |
+
</div>
|
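The hero block above is pre-rendered output of the same analysis: each token sits in a span whose `--a` custom property carries its normalized information score, which the stylesheet combines with the `--intro-rgb` base color to set highlight intensity. A minimal sketch of how such markup could be generated (a hypothetical helper, not the app's actual renderer):

```python
import html

def render_intro_tokens(tokens: list[tuple[str, float]]) -> str:
    """Wrap (token, normalized score) pairs in intro-token spans.

    The --a custom property is read by the stylesheet to scale the
    highlight color toward --intro-rgb.
    """
    return "".join(
        f'<span class="intro-token" style="--a:{alpha:.2f}">{html.escape(tok)}</span>'
        for tok, alpha in tokens
    )

print(render_intro_tokens([("Want", 0.56), (" key", 0.53), (" points", 0.29)]))
```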
| 5 |
+
|
| 6 |
+
<!-- Learn more (collapsed by default) -->
|
| 7 |
+
<details class="intro-more">
|
| 8 |
+
<summary>
|
| 9 |
+
<span class="intro-summary-when-closed">Learn more</span>
|
| 10 |
+
<span class="intro-summary-when-open">Hide</span>
|
| 11 |
+
</summary>
|
| 12 |
+
|
| 13 |
+
<!-- Intuition -->
|
| 14 |
+
<div class="intro-block">
|
| 15 |
+
<h4>Intuitive Understanding of Information</h4>
|
| 16 |
+
<p>From a linguistic perspective, information represents the novelty/surprise/importance of a word. Words that
|
| 17 |
+
are harder to predict from context typically carry more information. A simple example: "This morning I opened the door and saw a 'UFO'."
|
| 18 |
+
vs "This morning I opened the door and saw a 'cat'." — clearly "UFO" carries more information.</p>
|
| 19 |
+
</div>
|
| 20 |
+
|
| 21 |
+
<!-- Technical definition -->
|
| 22 |
+
<div class="intro-block intro-technical">
|
| 23 |
+
<h4>Information-Theoretic Perspective</h4>
|
| 24 |
+
<p>In our implementation, the information content of each token comes from how difficult it is for the LLM to
|
| 25 |
+
predict that token from left to right.</p>
|
| 26 |
+
<p>
|
| 27 |
+
From an information-theoretic perspective, this can be expressed as the conditional information of a token
|
| 28 |
+
given the model and the preceding context:
|
| 29 |
+
</p>
|
| 30 |
+
<pre>
|
| 31 |
+
Information of tokenᵢ in a text = -log₂P(tokenᵢ | model, token₀, …, tokenᵢ₋₁)
|
| 32 |
+
</pre>
|
| 33 |
+
<p>The core assumption behind Info Highlight is that this conditional information aligns with human subjective
|
| 34 |
+
perception, such as novelty, surprise, and potential importance.
|
| 35 |
+
</p>
|
| 36 |
+
</div>
|
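For the technically inclined, a minimal sketch of the formula above using a Hugging Face causal LM; the checkpoint name is one of the Qwen3 base models mentioned in the FAQ below, and this is an illustration of the definition, not the app's actual backend code:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "Qwen/Qwen3-0.6B-Base"  # example checkpoint; any causal LM works
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)
model.eval()

ids = tok("This morning I opened the door and saw a UFO.",
          return_tensors="pt").input_ids

with torch.no_grad():
    logits = model(ids).logits                 # (1, seq_len, vocab)

# The logits at position i-1 predict token i, so shift targets by one.
log_p = torch.log_softmax(logits[0, :-1], dim=-1)
targets = ids[0, 1:]
nats = -log_p[torch.arange(targets.numel()), targets]
bits = nats / torch.log(torch.tensor(2.0))     # -log2 P(token_i | prefix)

for t, b in zip(tok.convert_ids_to_tokens(targets.tolist()), bits.tolist()):
    print(f"{t!r}: {b:.2f} bits")
```

Summing `bits` over all tokens gives the kind of "total information = N bits" figure shown in the analysis page's metrics row.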
| 37 |
+
|
| 38 |
+
<!-- Errors and limitations -->
|
| 39 |
+
<div class="intro-block">
|
| 40 |
+
<h4>Ideal vs Reality</h4>
|
| 41 |
+
<p>
|
| 42 |
+
For an ideal model whose knowledge and contextual understanding match the reader's, the evaluation
|
| 43 |
+
would align perfectly with human subjective perception.
|
| 44 |
+
</p>
|
| 45 |
+
<p>Therefore, the gap between current results and reader perception mainly comes from two aspects:</p>
|
| 46 |
+
<ul>
|
| 47 |
+
<li><strong>Model capability vs human reader:</strong> The model's understanding and knowledge usually fall short of
|
| 48 |
+
the reader's, though they may occasionally exceed them. Imagine comparing a state-of-the-art LLM with a ten-year-old reader.</li>
|
| 49 |
+
<li><strong>Model context vs human reader:</strong> The model only has the text read so far as context, far less
|
| 50 |
+
than the reader has. Info Highlight uses base models without instruction tuning or prompts (which in practice
|
| 51 |
+
gives the best results).</li>
|
| 52 |
+
</ul>
|
| 53 |
+
<p>The good news is that LLMs are improving rapidly: current analysis results already reflect mainstream
|
| 54 |
+
readers' subjective perception to a useful extent, and can be used to evaluate an article's information content
|
| 55 |
+
and to speed up reading.</p>
|
| 56 |
+
</div>
|
| 57 |
+
|
| 58 |
+
<!-- Tribute -->
|
| 59 |
+
<div class="intro-block">
|
| 60 |
+
<h4>Tribute</h4>
|
| 61 |
+
<p>Built on the classic project <a href="http://gltr.io" target="_blank" rel="noopener">GLTR.io</a>,
|
| 62 |
+
developed by Hendrik Strobelt et al. in 2019. GLTR was a web demo that pioneered using GPT-2 prediction
|
| 63 |
+
probabilities to detect generated text.</p>
|
| 64 |
+
<p>However, Info Highlight is not meant to detect AI text, but to evaluate the "information quality" of text.</p>
|
| 65 |
+
</div>
|
| 66 |
+
|
| 67 |
+
<!-- FAQ -->
|
| 68 |
+
<div class="intro-block intro-faq">
|
| 69 |
+
<h4>FAQ</h4>
|
| 70 |
+
|
| 71 |
+
<p><strong>Is it an AI text detector?</strong></p>
|
| 72 |
+
<p>No.</p>
|
| 73 |
+
<p>When we dislike AI text, what we actually dislike is low-quality text. We resent low-quality human-written text
|
| 74 |
+
more than high-quality AI-generated content. So the key is the "information quality" of the text.
|
| 75 |
+
Info Highlight aims to detect "information quality" rather than "signs of AI", though it can be used to flag
|
| 76 |
+
AI-generated filler that carries no information.</p>
|
| 77 |
+
|
| 78 |
+
<p><strong>What LLM is currently used?</strong></p>
|
| 79 |
+
<p>Currently the open-source <strong>Qwen3-0.6B/1.7B/4B/14B-Base</strong> models are used. Of the models the
|
| 80 |
+
author has tested, the 4B model gives results closest to most people's subjective perception (note that a
|
| 81 |
+
larger model does not necessarily agree better with the reader's subjective perception). When
|
| 82 |
+
hardware is limited, the 0.6B/1.7B models are used; they perform slightly worse than 4B (information
|
| 83 |
+
content estimates differ by roughly 15% at most), but the trend is similar.</p>
|
| 84 |
+
|
| 85 |
+
<p><strong>Why does information content affect text quality?</strong></p>
|
| 86 |
+
<p>Low information content means the LLM can easily predict a word from context. If even a machine can predict it,
|
| 87 |
+
how important can it be? Conversely, high information content means the LLM has difficulty predicting it
|
| 88 |
+
from context. Assuming it is not simply an error, it then represents key information that the author wants to convey
|
| 89 |
+
and that the machine does not already know.</p>
|
| 90 |
+
</div>
|
| 91 |
+
</details>
|
client/src/content/home.zh.html
ADDED
|
@@ -0,0 +1,68 @@
|
|
| 1 |
+
<!-- 简介 / Hero(始终可见) -->
|
| 2 |
+
<div class="intro-brief" style="--intro-rgb: 255, 71, 64">
|
| 3 |
+
<span class="intro-token" style="--a:0.63">想</span><span class="intro-token" style="--a:0.58">一眼</span><span class="intro-token" style="--a:0.43">找到</span><span class="intro-token" style="--a:0.52">文章</span><span class="intro-token" style="--a:0.35">的关键</span><span class="intro-token" style="--a:0.13">点</span><span class="intro-token" style="--a:0.31">?</span><span class="intro-token" style="--a:0.29">或者</span><span class="intro-token" style="--a:0.27">只是</span><span class="intro-token" style="--a:0.37">好奇</span><span class="intro-token" style="--a:0.38">文字</span><span class="intro-token" style="--a:0.48">的信息</span><span class="intro-token" style="--a:0.56">论</span><span class="intro-token" style="--a:0.48">奥</span><span class="intro-token" style="--a:0.03">秘</span><span class="intro-token" style="--a:0.18">?</span><br><br><span class="intro-token" style="--a:0.47">试试</span><span class="intro-token" style="--a:0.38">Info</span><span class="intro-token" style="--a:0.70"> Highlight</span><span class="intro-token" style="--a:0.36">.</span><span class="intro-token" style="--a:0.16"> 它</span><span class="intro-token" style="--a:0.00">它</span><span class="intro-token" style="--a:0.27">用</span><span class="intro-token" style="--a:0.29">大</span><span class="intro-token" style="--a:0.13">语言</span><span class="intro-token" style="--a:0.00">模型</span><span class="intro-token" style="--a:0.18">分析</span><span class="intro-token" style="--a:0.15">文本</span><span class="intro-token" style="--a:0.34">的信息</span><span class="intro-token" style="--a:0.09">密度</span><span class="intro-token" style="--a:0.03">,</span><span class="intro-token" style="--a:0.48">可视化</span><span class="intro-token" style="--a:0.19">展示</span><span class="intro-token" style="--a:0.49">哪里</span><span class="intro-token" style="--a:0.30">更重要</span><span class="intro-token" style="--a:0.11">。</span><br><br><span class="intro-token" style="--a:0.41">每个</span><span class="intro-token" style="--a:0.21">字</span><span class="intro-token" style="--a:0.41">的颜色</span><span class="intro-token" style="--a:0.18">深</span><span class="intro-token" style="--a:0.00">浅</span><span class="intro-token" style="--a:0.11">,</span><span class="intro-token" style="--a:0.18">表示</span><span class="intro-token" style="--a:0.11">它</span><span class="intro-token" style="--a:0.31">承载</span><span class="intro-token" style="--a:0.02">的信息</span><span class="intro-token" style="--a:0.03">量</span><span class="intro-token" style="--a:0.11">大小</span><span class="intro-token" style="--a:0.05">。</span><span class="intro-token" style="--a:0.49">自己</span><span class="intro-token" style="--a:0.28">试试</span><span class="intro-token" style="--a:0.06">吧</span><span class="intro-token" style="--a:0.11">!</span>
|
| 4 |
+
</div>
|
| 5 |
+
|
| 6 |
+
<!-- 了解更多(默认折叠) -->
|
| 7 |
+
<details class="intro-more">
|
| 8 |
+
<summary>
|
| 9 |
+
<span class="intro-summary-when-closed">了解更多</span>
|
| 10 |
+
<span class="intro-summary-when-open">收起</span>
|
| 11 |
+
</summary>
|
| 12 |
+
|
| 13 |
+
<!-- 原理直觉 -->
|
| 14 |
+
<div class="intro-block">
|
| 15 |
+
<h4>信息量的直观理解</h4>
|
| 16 |
+
<p>从语言学角度看,信息量代表一个词所包含的新意/意外性/关键程度。越难从上下文中预测出来的词,通常携带的信息就越多。一个简单的例子:"今天早上我打开门看见了一只'飞碟'" 和
|
| 17 |
+
"今天早上我打开门看见了一只'猫'":在这里显然"飞碟"的信息量更大。</p>
|
| 18 |
+
</div>
|
| 19 |
+
|
| 20 |
+
<!-- 技术定义 -->
|
| 21 |
+
<div class="intro-block intro-technical">
|
| 22 |
+
<h4>信息论视角</h4>
|
| 23 |
+
<p>在工程实现中,每个 token 的信息量,来自大模型从左到右预测当前 token 的难度。</p>
|
| 24 |
+
<p>从信息论角度,它可以表示为当前 token 相对于大模型和已读上下文的条件信息量:</p>
|
| 25 |
+
<pre>
|
| 26 |
+
一段文本中的 tokenᵢ 的信息量 = -log₂P(tokenᵢ | model, token₀, …, tokenᵢ₋₁)
|
| 27 |
+
</pre>
|
| 28 |
+
<p>Info Highlight 的核心假设就是,这个条件信息量的大小和人类的主观感受(新意/意外性/潜在关键程度)是一致的。</p>
|
| 29 |
+
</div>
|
| 30 |
+
|
| 31 |
+
<!-- 误差与局限 -->
|
| 32 |
+
<div class="intro-block">
|
| 33 |
+
<h4>理想与现实</h4>
|
| 34 |
+
<p>对于一个想象中的理想模型(它包含上下文理解在内的知识量和阅读者一致),它评估出的结果应该和阅读者的主观感受完全一致。</p>
|
| 35 |
+
<p>所以,目前的实际结果和阅读者主观感受之间的差距,主要来自两个方面:</p>
|
| 36 |
+
<ul>
|
| 37 |
+
<li><strong>模型能力和阅读者的差异:</strong>模型的理解能力和知识量很可能不如阅读者,也有小可能性过剩,想象一下目前的SOTA大模型和一个十岁孩子阅读者相比。</li>
|
| 38 |
+
<li><strong>模型上下文和阅读者的差异:</strong>模型只有文章已读部分作为上下文,远小于阅读者。Info Highlight 使用没有 instruct 微调的 base 模型,也没有任何提示词(其实这样效果已经是最好了)。
|
| 39 |
+
</li>
|
| 40 |
+
</ul>
|
| 41 |
+
<p>好消息是,大模型进步实在太快了:目前的分析结果已经在一定程度上反映了主流阅读者的主观感受,可以用来评估文章的信息含量,还可以提高阅读速度。</p>
|
| 42 |
+
</div>
|
| 43 |
+
|
| 44 |
+
<!-- 致谢 -->
|
| 45 |
+
<div class="intro-block">
|
| 46 |
+
<h4>致谢</h4>
|
| 47 |
+
<p>基于 2019 年 Hendrik Strobelt 等人开发的经典项目 <a href="http://gltr.io" target="_blank" rel="noopener">GLTR.io</a>。GLTR 是一个网页演示,率先用 GPT-2 的预测概率来检测生成文本。</p>
|
| 48 |
+
<p>不过 Info Highlight 的目标不是检测 AI 文本,而是评估文本的“信息质量”。</p>
|
| 49 |
+
</div>
|
| 50 |
+
|
| 51 |
+
<!-- FAQ -->
|
| 52 |
+
<div class="intro-block intro-faq">
|
| 53 |
+
<h4>常见问题</h4>
|
| 54 |
+
|
| 55 |
+
<p><strong>它是 AI 文本检测器吗?</strong></p>
|
| 56 |
+
<p>不是。</p>
|
| 57 |
+
<p>当我们反感AI文本时,我们其实是反感低质量的文本。我们更反感低质量的真人写的文本,而不是AI生成的高质量内容。所以,关键是文本的"信息质量"。Info Highlight 的目标是检测"信息质量"而不是“AI痕迹”,虽然它可以用来检测没有信息量的AI胡编文本。
|
| 58 |
+
</p>
|
| 59 |
+
|
| 60 |
+
<p><strong>目前使用的是什么大模型?</strong></p>
|
| 61 |
+
<p>当前使用的是开源的 <strong>Qwen3-0.6B/1.7B/4B/14B-Base</strong>,其中4B模型是作者测试过的模型里结果挺接近大部分人主观感受的一个(注意并不一定是模型越大越符合阅读者的主观感受)。
|
| 62 |
+
当硬件配置限制时,会用0.6B/1.7B模型,它们效果比4B稍差(信息量评估结果差异约15%以内),但趋势是类似的。</p>
|
| 63 |
+
|
| 64 |
+
<p><strong>说到底,为什么信息量会影响文本的质量?</strong></p>
|
| 65 |
+
<p>一个词的信息量低,意味着大模型能很容易从上文预测出来。既然机器都能预测出来,那它还能有多关键呢?反之,一个词的信息量高,意味着大模型很难从上文预测出来。(如果不是错误表达的话)那它就代表了作者想要表达,而机器不知道的关键信息。
|
| 66 |
+
</p>
|
| 67 |
+
</div>
|
| 68 |
+
</details>
|
client/src/content/images/attribute-dark.png
ADDED
|
Git LFS Details
|
client/src/content/images/attribute.png
ADDED
|
Git LFS Details
|