| """文本分析 API""" |
| import gc |
| import json |
| import time |
| import queue |
| import threading |
| from typing import Optional |
| from backend.schemas import create_empty_analysis_result |
| from backend.model_manager import project_registry, DEFAULT_MODEL, _inference_lock |
| from model_paths import resolve_hf_path |
| from backend.oom import exit_if_oom |
| from backend.api.sse_utils import ( |
| SSEProgressReporter, |
| consume_progress_queue, |
| send_result_event, |
| send_error_event, |
| ) |


class QueueTimeoutError(Exception):
    """Raised when waiting in line to acquire the inference lock times out."""
    pass


# Timeouts below are in seconds.
ANALYSIS_TIMEOUT = 60.0
LOCK_WAIT_TIMEOUT = 10.0


def _analyze_result_model_display(model: Optional[str]) -> Optional[str]:
    """Expose result.model of the main analysis as a HuggingFace repo id (consistent with model_paths.resolve_hf_path)."""
    if not model or not str(model).strip():
        return None
    return resolve_hf_path(str(model).strip())


def _build_response(model: str, text: str, result):
    """Build the standard response payload."""
    if not isinstance(result, dict):
        result = {}
    result = result.copy()

    # Prefer the model name reported by the analyzer itself, if present.
    if 'model' in result:
        model_value = result.pop('model')
    else:
        model_value = model

    # Put the normalized model id first in the result dict.
    result = {'model': _analyze_result_model_display(model_value), **result}
    return {
        "request": {'text': text},
        "result": result
    }
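
# A minimal sketch of the payload _build_response produces; the field values here
# are illustrative assumptions, not actual output of this module:
#
#     {
#         "request": {"text": "hello world"},
#         "result": {"model": "org/model-name", "bpe_strings": ["hello", " world"]}
#     }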


def _error_response(model: str, text: str, message: str, status_code: int):
    """Build an error response in the unified format."""
    result = create_empty_analysis_result(message, _analyze_result_model_display(model))
    return {
        "success": False,
        "message": message,
        "request": {'text': text or ''},
        "result": result
    }, status_code


def _validate_and_prepare_request(analyze_request):
    """
    Validate the request and prepare its parameters.

    Returns:
        A (model, text, error_msg, status_code) tuple.
        On validation failure: (None, None, error_msg, status_code).
        On success: (model, text, None, None).
    """
    model = analyze_request.get('model')
    text = analyze_request.get('text')

    if not text:
        return None, None, "缺少分析文本,请提供 text 字段", 400

    from backend.app_context import get_app_context
    context = get_app_context(prefer_module_context=True)
    default_model = context.model_name if context.model_name else DEFAULT_MODEL

    # An empty or 'default' model falls back to the configured default model;
    # any other explicit model name is rejected.
    if not model or model == 'default':
        model = default_model
    elif model != default_model:
        return None, None, f"当前仅支持默认模型 '{default_model}',不允许使用其他模型", 400

    return model, text, None, None
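
# For example (values are illustrative assumptions, not taken from this file):
#
#     _validate_and_prepare_request({'text': '你好'})
#     # -> ('<default model>', '你好', None, None)
#     _validate_and_prepare_request({'model': 'other-model', 'text': '你好'})
#     # -> (None, None, "当前仅支持默认模型 ...", 400)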


def _load_project_with_error_handling(model):
    """
    Fetch an already-loaded model; if it is not loaded, either lazy-load it
    (depending on configuration) or return an error.

    Returns:
        A (project_obj, error_msg, status_code) tuple.
        On success: (project_obj, None, None).
        On failure: (None, error_msg, status_code).
    """
    if not project_registry.is_available(model):
        available_models = list(project_registry.available_model_names())
        error_msg = f"❌ 模型 '{model}' 未注册。可用模型: {available_models}"
        print(error_msg)
        return None, error_msg, 404

    p = project_registry.get(model)
    if p is None:
        from backend.app_context import get_app_context
        from backend.model_manager import ensure_main_slot_ready

        context = get_app_context(prefer_module_context=True)
        if context.model_loading:
            error_msg = f"模型 '{model}' 正在后台加载中,请稍后重试"
            print(f"⚠️ {error_msg}")
            return None, error_msg, 503

        if getattr(context.args, 'no_auto_load', False):
            try:
                ensure_main_slot_ready()
                p = project_registry.get(model)
            except Exception as e:
                import traceback
                print(f"⚠️ 模型懒加载失败: {e}")
                traceback.print_exc()
                return None, f"模型加载失败: {str(e)}", 500
        if p is None:
            error_msg = f"模型 '{model}' 未加载,请联系管理员"
            print(f"⚠️ {error_msg}")
            return None, error_msg, 503

    return p, None, None


def _log_request(text, stream_mode=False, client_ip=None):
    """
    Log the incoming request.

    Returns:
        int: request id
    """
    from backend.access_log import log_analyze_request
    return log_analyze_request(text, stream_mode, client_ip)


def _log_response(res, char_count, elapsed_time, stream_mode=False, request_id=None, wait_time=None):
    """Log the response."""
    tokens = len(res.get('bpe_strings', []))
    mode_str = " (stream)" if stream_mode else ""

    msg = f"\t📤 API analyze{mode_str} response:"
    if request_id is not None:
        msg += f" req_id={request_id},"
    msg += f" tokens={tokens}, text_length={char_count}"
    msg += f", response_time={elapsed_time:.4f}s"
    if wait_time is not None:
        msg += f", lock_wait={wait_time:.4f}s"

    print(msg)


def _validate_and_fix_result(res):
    """Validate and repair the result structure."""
    if not isinstance(res, dict):
        res = {'bpe_strings': []}
    if not isinstance(res.get('bpe_strings'), list):
        res['bpe_strings'] = []
    return res


def analyze(analyze_request):
    """
    Analyze text.

    Args:
        analyze_request: analysis request dict containing:
            - model: model name
            - text: text to analyze
            - stream: optional; if True, return an SSE streaming response (with progress info)

    Returns:
        If stream=True: an SSE response object.
        Otherwise: a (response dict, status code) tuple.
    """
    from backend.app_context import get_app_context
    context = get_app_context(prefer_module_context=True)
    if context.model_loading:
        return _error_response('', '', '模型正在加载中,请稍后重试', 503)

    # Capture the client IP while the request context is still available.
    from backend.access_log import get_client_ip
    client_ip = get_client_ip()

    stream = analyze_request.get('stream', False)
    if stream:
        return _analyze_with_stream(analyze_request, client_ip)
    return _analyze_plain(analyze_request, client_ip)
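
# A minimal client-side sketch for the non-streaming path. The endpoint URL and the
# use of `requests` are assumptions for illustration; neither is defined in this file:
#
#     import requests
#     resp = requests.post("http://localhost:5000/api/analyze",
#                          json={"text": "你好,世界", "stream": False})
#     payload = resp.json()
#     print(payload["result"]["model"], len(payload["result"]["bpe_strings"]))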


def _analyze_with_stream(analyze_request, client_ip):
    """
    Streaming text analysis: return progress and the result via SSE (internal function).

    Args:
        analyze_request: analysis request dict containing model and text
        client_ip: client IP, captured at the entry point and passed in

    Returns:
        an SSE response object
    """
    reporter = SSEProgressReporter(lambda: _generate_analyze_events(analyze_request, client_ip))
    return reporter.create_response()
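
# A sketch of consuming the streaming path from a client. The URL and `requests`
# usage are illustrative assumptions, not part of this module:
#
#     import json, requests
#     with requests.post("http://localhost:5000/api/analyze",
#                        json={"text": "你好,世界", "stream": True}, stream=True) as r:
#         for line in r.iter_lines(decode_unicode=True):
#             if line and line.startswith("data: "):
#                 event = json.loads(line[6:])
#                 # event["type"] is "progress", "result", or "error"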


def _analyze_plain(analyze_request, client_ip):
    """
    Non-streaming analysis: wraps the streaming implementation, consumes the event
    stream, and returns JSON. Intended for scripts and other simple clients.
    """
    result = None
    error_msg = None
    status_code = 500
    try:
        for event_str in _generate_analyze_events(analyze_request, client_ip):
            if not event_str.startswith('data: '):
                continue
            data = json.loads(event_str[6:].strip())
            t = data.get('type')
            if t == 'result':
                result = data.get('data')
            elif t == 'error':
                error_msg = data.get('message', '分析失败')
                status_code = data.get('status_code', 500)
                break
    except Exception as e:
        import traceback
        traceback.print_exc()
        exit_if_oom(e, defer_seconds=1)
        error_msg = f"分析失败: {str(e)}"
    finally:
        gc.collect()

    if error_msg:
        model = analyze_request.get('model') or ''
        text = analyze_request.get('text') or ''
        return _error_response(model, text, error_msg, status_code)
    if result is None:
        return _error_response('', '', '分析失败:未获取到结果', 500)
    return result, 200
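
# The SSE event strings parsed above are expected to look roughly like this.
# The exact serialization comes from backend.api.sse_utils; the shapes shown
# here are an assumption based on how the events are consumed:
#
#     data: {"type": "progress", "step": 3, "total_steps": 10, "stage": "...", "percentage": 30}\n\n
#     data: {"type": "result", "data": {...}}\n\n
#     data: {"type": "error", "message": "...", "status_code": 503}\n\n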


def _generate_analyze_events(analyze_request, client_ip):
    """
    Core of streaming analysis: generate the SSE event stream (progress + result/error).
    Shared by _analyze_with_stream and _analyze_plain.
    client_ip must be captured at the entry point and passed in, because the request
    context may no longer be valid by the time the generator runs for a streaming response.
    """
    from backend.app_context import get_app_context
    context = get_app_context(prefer_module_context=True)
    if context.model_loading:
        yield send_error_event('模型正在加载中,请稍后重试', 503)
        return

    start_time = time.perf_counter()

    model, text, error_msg, status_code = _validate_and_prepare_request(analyze_request)
    if error_msg:
        yield send_error_event(error_msg, status_code or 400)
        return

    p, error_msg, status_code = _load_project_with_error_handling(model)
    if error_msg:
        yield send_error_event(error_msg, status_code or 500)
        return

    try:
        char_count = len(text) if text else 0
        request_id = _log_request(text, stream_mode=True, client_ip=client_ip)

        # The analysis runs in a worker thread; progress events are bridged back
        # to this generator through a queue.
        progress_queue = queue.Queue()
        analysis_done = threading.Event()
        analysis_result = None
        analysis_error = None
        lock_wait_time = None

        def progress_callback_func(step: int, total_steps: int, stage: str, percentage: Optional[int]):
            """Progress callback: push the event onto the queue."""
            progress_queue.put(('progress', step, total_steps, stage, percentage))

        def run_analysis():
            """Run the analysis in a separate thread."""
            nonlocal analysis_result, analysis_error, lock_wait_time
            try:
                lock_wait_start = time.perf_counter()

                # Only one inference runs at a time; wait for the lock with a timeout.
                lock_acquired = _inference_lock.acquire(timeout=LOCK_WAIT_TIMEOUT)
                if not lock_acquired:
                    analysis_error = QueueTimeoutError(
                        f"排队等待超过 {LOCK_WAIT_TIMEOUT} 秒,服务繁忙,请稍后重试"
                    )
                    return

                lock_wait_time = time.perf_counter() - lock_wait_start

                try:
                    from backend.access_log import log_analyze_start
                    log_analyze_start(request_id, lock_wait_time, stream_mode=True)

                    res = p.lm.analyze_text(text, progress_callback=progress_callback_func)
                    analysis_result = res
                finally:
                    _inference_lock.release()
            except Exception as e:
                analysis_error = e
            finally:
                analysis_done.set()
                progress_queue.put(('done', None, None))

        analysis_thread = threading.Thread(target=run_analysis, daemon=True)
        analysis_thread.start()

        # Drain progress events until the analysis finishes or times out.
        timeout_reached = False
        for kind, event_str in consume_progress_queue(
            progress_queue, analysis_done, start_time, ANALYSIS_TIMEOUT, "分析"
        ):
            if kind == 'timeout':
                timeout_reached = True
                yield event_str
                break
            if kind == 'progress':
                yield event_str
            elif kind == 'done':
                break

        if timeout_reached:
            gc.collect()
            return

        if analysis_error:
            # Queueing timeouts are reported as 503; anything else is re-raised
            # and handled by the outer except block.
            if isinstance(analysis_error, QueueTimeoutError):
                print(f"⏱️ 排队超时: {analysis_error}")
                yield send_error_event(str(analysis_error), 503)
                gc.collect()
                return
            raise analysis_error

        if analysis_result is None:
            print("⚠️ 分析结果为空,但没有错误信息")
            yield send_error_event("分析失败:未获取到结果", 500)
            gc.collect()
            return

        res = analysis_result

        elapsed_time = time.perf_counter() - start_time
        _log_response(res, char_count, elapsed_time, stream_mode=True,
                      request_id=request_id, wait_time=lock_wait_time)

        res = _validate_and_fix_result(res)

        final_response = _build_response(model, text, res)

        yield send_result_event(final_response)

        gc.collect()

    except Exception as e:
        import traceback
        traceback.print_exc()
        exit_if_oom(e, defer_seconds=1)
        yield send_error_event(str(e), 500)
        gc.collect()
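
# The thread/queue bridging used above follows this general pattern (a standalone,
# simplified sketch; the names here are illustrative and not part of this module):
#
#     import queue, threading
#
#     def events():
#         q = queue.Queue()
#
#         def worker():
#             for i in range(3):
#                 q.put(('progress', i))      # produced by the long-running job
#             q.put(('done', None))
#
#         threading.Thread(target=worker, daemon=True).start()
#         while True:
#             kind, payload = q.get()
#             if kind == 'done':
#                 break
#             yield f"data: {payload}\n\n"    # streamed to the client as SSE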
|
|
|
|
|
|