InfoLens / backend /api /analyze.py
dqy08's picture
initial beta release
494c9e4
"""文本分析 API"""
import gc
import json
import time
import queue
import threading
from typing import Optional
from backend.schemas import create_empty_analysis_result
from backend.model_manager import project_registry, DEFAULT_MODEL, _inference_lock
from model_paths import resolve_hf_path
from backend.oom import exit_if_oom
from backend.api.sse_utils import (
SSEProgressReporter,
consume_progress_queue,
send_result_event,
send_error_event,
)
# 自定义异常:排队超时
class QueueTimeoutError(Exception):
"""排队等待获取锁超时"""
pass
# 使用 model_manager 中的统一推理锁(与 analyze_semantic 共用)
# 单次分析的总处理时长限制(秒)
ANALYSIS_TIMEOUT = 60.0
# 等待获取锁的最大时间(秒)- 如果排队时间过长,直接拒绝请求
LOCK_WAIT_TIMEOUT = 10.0
def _analyze_result_model_display(model: Optional[str]) -> Optional[str]:
"""主分析 result.model:对外统一为 HuggingFace 仓库 id(与 model_paths.resolve_hf_path 一致)。"""
if not model or not str(model).strip():
return None
return resolve_hf_path(str(model).strip())
def _build_response(model: str, text: str, result):
"""构建标准响应"""
# 将 model 添加到 result 中,并确保 model 在最前面
if not isinstance(result, dict):
result = {}
result = result.copy()
# 如果 result 中已有 model,先移除
if 'model' in result:
model_value = result.pop('model')
else:
model_value = model
# 重新构建 result,确保 model 在最前面
result = {'model': _analyze_result_model_display(model_value), **result}
return {
"request": {'text': text},
"result": result
}
def _error_response(model: str, text: str, message: str, status_code: int):
"""构建错误响应(统一格式)"""
# 统一错误格式:包含 success=false 和 message
result = create_empty_analysis_result(message, _analyze_result_model_display(model))
return {
"success": False,
"message": message,
"request": {'text': text or ''},
"result": result
}, status_code
def _validate_and_prepare_request(analyze_request):
"""
验证请求并准备参数
Returns:
(model, text, error_msg, status_code) 元组
如果验证失败,返回 (None, None, error_msg, status_code)
如果成功,返回 (model, text, None, None)
"""
model = analyze_request.get('model')
text = analyze_request.get('text')
if not text:
return None, None, "缺少分析文本,请提供 text 字段", 400
# 获取默认模型(使用模块级上下文以获取持久化的当前活动模型)
from backend.app_context import get_app_context
context = get_app_context(prefer_module_context=True)
default_model = context.model_name if context.model_name else DEFAULT_MODEL
# 处理 default、None 或空字符串,使用默认模型
if not model or model == 'default' or model == '':
model = default_model
else:
# 只允许使用默认模型,其他模型请求将被拒绝
if model != default_model:
return None, None, f"当前仅支持默认模型 '{default_model}',不允许使用其他模型", 400
return model, text, None, None
def _load_project_with_error_handling(model):
"""
获取已加载的模型;若未加载则根据配置进行懒加载或返回错误。
Returns:
(project_obj, error_msg, status_code) 元组
如果成功,返回 (project_obj, None, None)
如果失败,返回 (None, error_msg, status_code)
"""
# 检查模型是否在注册表中
if not project_registry.is_available(model):
available_models = list(project_registry.available_model_names())
error_msg = f"❌ 模型 '{model}' 未注册。可用模型: {available_models}"
print(error_msg)
return None, error_msg, 404
# 检查模型是否已加载
p = project_registry.get(model)
if p is None:
from backend.app_context import get_app_context
from backend.model_manager import ensure_main_slot_ready
context = get_app_context(prefer_module_context=True)
if context.model_loading:
error_msg = f"模型 '{model}' 正在后台加载中,请稍后重试"
print(f"⚠️ {error_msg}")
return None, error_msg, 503
# 懒加载模式 (--no_auto_load):首次请求仅初始化主槽位(权重 + QwenLM 项目)
if getattr(context.args, 'no_auto_load', False):
try:
ensure_main_slot_ready()
p = project_registry.get(model)
except Exception as e: # noqa: BLE001
import traceback
print(f"⚠️ 模型懒加载失败: {e}")
traceback.print_exc()
return None, f"模型加载失败: {str(e)}", 500
if p is None:
error_msg = f"模型 '{model}' 未加载,请联系管理员"
print(f"⚠️ {error_msg}")
return None, error_msg, 503
return p, None, None
def _log_request(text, stream_mode=False, client_ip=None):
"""
打印请求日志
Returns:
int: 请求ID
"""
from backend.access_log import log_analyze_request
return log_analyze_request(text, stream_mode, client_ip)
def _log_response(res, char_count, elapsed_time, stream_mode=False, request_id=None, wait_time=None):
"""打印响应日志"""
tokens = len(res.get('bpe_strings', []))
text_length = char_count
mode_str = "(stream)" if stream_mode else ""
# 构建日志消息
msg = f"\t📤 API analyze {mode_str} response:"
if request_id is not None:
msg += f" req_id={request_id},"
msg += f" tokens={tokens}, text_length={text_length}"
msg += f", response_time={elapsed_time:.4f}s"
print(msg)
def _validate_and_fix_result(res):
"""验证和修复结果结构"""
if not isinstance(res, dict):
res = {'bpe_strings': []}
if 'bpe_strings' not in res or not isinstance(res.get('bpe_strings'), list):
res['bpe_strings'] = res.get('bpe_strings', []) if isinstance(res.get('bpe_strings'), list) else []
return res
def analyze(analyze_request):
"""
分析文本
Args:
analyze_request: 分析请求字典,包含:
- model: 模型名称
- text: 要分析的文本
- stream: 可选,如果为 True 则返回 SSE 流式响应(带进度信息)
Returns:
如果 stream=True: SSE 响应对象
否则: (响应字典, 状态码) 元组
"""
# 检查模型是否正在加载中(使用模块级上下文)
from backend.app_context import get_app_context
context = get_app_context(prefer_module_context=True)
if context.model_loading:
return _error_response('', '', '模型正在加载中,请稍后重试', 503)
# 在请求上下文中获取 client_ip,流式响应时生成器内可能已失效
from backend.access_log import get_client_ip
client_ip = get_client_ip()
# 检查是否启用流式响应
stream = analyze_request.get('stream', False)
if stream:
return _analyze_with_stream(analyze_request, client_ip)
return _analyze_plain(analyze_request, client_ip)
def _analyze_with_stream(analyze_request, client_ip):
"""
流式分析文本,通过SSE返回进度和结果(内部函数)
Args:
analyze_request: 分析请求字典,包含 model 和 text
client_ip: 客户端 IP,在入口处获取后传入
Returns:
SSE响应对象
"""
reporter = SSEProgressReporter(lambda: _generate_analyze_events(analyze_request, client_ip))
return reporter.create_response()
def _analyze_plain(analyze_request, client_ip):
"""
非流式分析:封装流式实现,消费事件流后返回 JSON。
供脚本等简单客户端使用。
"""
result = None
error_msg = None
status_code = 500
try:
for event_str in _generate_analyze_events(analyze_request, client_ip):
if not event_str.startswith('data: '):
continue
data = json.loads(event_str[6:].strip())
t = data.get('type')
if t == 'result':
result = data.get('data')
elif t == 'error':
error_msg = data.get('message', '分析失败')
status_code = data.get('status_code', 500)
break
except Exception as e:
import traceback
traceback.print_exc()
exit_if_oom(e, defer_seconds=1)
error_msg = f"分析失败: {str(e)}"
finally:
gc.collect()
if error_msg:
model = analyze_request.get('model') or ''
text = analyze_request.get('text') or ''
return _error_response(model, text, error_msg, status_code)
if result is None:
return _error_response('', '', '分析失败:未获取到结果', 500)
return result, 200
def _generate_analyze_events(analyze_request, client_ip):
"""
流式分析核心:生成 SSE 事件流(progress + result/error)。
供 _analyze_with_stream 和 _analyze_plain 复用。
client_ip 需在入口处获取并传入,因流式响应时生成器执行时请求上下文可能已失效。
"""
# 再次检查模型加载状态(在生成器内部,使用模块级上下文)
from backend.app_context import get_app_context
context = get_app_context(prefer_module_context=True)
if context.model_loading:
yield send_error_event('模型正在加载中,请稍后重试', 503)
return
start_time = time.perf_counter()
# 验证和准备请求
model, text, error_msg, status_code = _validate_and_prepare_request(analyze_request)
if error_msg:
yield send_error_event(error_msg, status_code or 400)
return
# 加载模型
p, error_msg, status_code = _load_project_with_error_handling(model)
if error_msg:
yield send_error_event(error_msg, status_code or 500)
return
try:
char_count = len(text) if text else 0
request_id = _log_request(text, stream_mode=True, client_ip=client_ip)
# 创建线程安全的进度队列
progress_queue = queue.Queue()
analysis_done = threading.Event()
analysis_result = None
analysis_error = None
lock_wait_time = None # 记录等待锁的时间
def progress_callback_func(step: int, total_steps: int, stage: str, percentage: Optional[int]):
"""进度回调函数,将事件加入队列"""
progress_queue.put(('progress', step, total_steps, stage, percentage))
def run_analysis():
"""在单独线程中运行分析"""
nonlocal analysis_result, analysis_error, lock_wait_time
try:
# 记录开始等待锁的时间
lock_wait_start = time.perf_counter()
# 尝试获取锁,设置超时避免长时间排队
lock_acquired = _inference_lock.acquire(timeout=LOCK_WAIT_TIMEOUT)
if not lock_acquired:
# 获取锁超时,说明前面有任务在执行且耗时较长
analysis_error = QueueTimeoutError(
f"排队等待超过 {LOCK_WAIT_TIMEOUT} 秒,服务繁忙,请稍后重试"
)
return
# 记录等待时间
lock_wait_time = time.perf_counter() - lock_wait_start
try:
from backend.access_log import log_analyze_start
log_analyze_start(request_id, lock_wait_time, stream_mode=True)
# 在持有锁的情况下执行分析
# 注意:这里的执行时长也会受到 ANALYSIS_TIMEOUT 的监控(在外层循环中)
res = p.lm.analyze_text(text, progress_callback=progress_callback_func)
analysis_result = res
finally:
# 确保锁一定会被释放
_inference_lock.release()
except Exception as e:
analysis_error = e
finally:
analysis_done.set()
progress_queue.put(('done', None, None)) # 发送完成信号
# 启动分析线程
analysis_thread = threading.Thread(target=run_analysis, daemon=True)
analysis_thread.start()
# 实时发送进度事件,并检查超时
timeout_reached = False
for kind, event_str in consume_progress_queue(
progress_queue, analysis_done, start_time, ANALYSIS_TIMEOUT, "分析"
):
if kind == 'timeout':
timeout_reached = True
yield event_str
break
if kind == 'progress':
yield event_str
elif kind == 'done':
break
# 如果超时,不等待分析完成,直接返回
if timeout_reached:
gc.collect()
return
# 检查是否有错误
# 注意:此时已收到 'done' 信号,分析线程已完成其工作(或发生错误)
# 线程是 daemon 的,会自动清理,无需显式等待
if analysis_error:
# 排队超时:返回友好的错误消息
if isinstance(analysis_error, QueueTimeoutError):
print(f"⏱️ 排队超时: {analysis_error}")
yield send_error_event(str(analysis_error), 503)
gc.collect()
return
# 其他错误:抛出异常,由外层的 try-except 处理
raise analysis_error
# 检查结果是否为空(理论上不应该发生,因为要么有结果,要么有错误)
if analysis_result is None:
print("⚠️ 分析结果为空,但没有错误信息")
yield send_error_event("分析失败:未获取到结果", 500)
gc.collect()
return
res = analysis_result
elapsed_time = time.perf_counter() - start_time
_log_response(res, char_count, elapsed_time, stream_mode=True,
request_id=request_id, wait_time=lock_wait_time)
# 验证和修复结果
res = _validate_and_fix_result(res)
# 构建最终响应
final_response = _build_response(model, text, res)
# 发送最终结果
yield send_result_event(final_response)
# 强制垃圾回收以释放内存
gc.collect()
except Exception as e:
import traceback
traceback.print_exc()
exit_if_oom(e, defer_seconds=1)
yield send_error_event(str(e), 500)
# 即使出错也进行垃圾回收
gc.collect()