"""Semantic analysis API:返回原文各 token 对 prompt 的平均关注度""" import gc import json import queue import threading import time from typing import Optional from backend.model_manager import _inference_lock from backend.oom import exit_if_oom from backend.semantic_analyzer import analyze_semantic as _analyze_semantic from backend.api.sse_utils import ( SSEProgressReporter, consume_progress_queue, send_result_event, send_error_event, ) from backend.access_log import get_client_ip from backend.api.analyze import QueueTimeoutError, ANALYSIS_TIMEOUT, LOCK_WAIT_TIMEOUT def _log_request(query, text, client_ip=None): from backend.access_log import log_analyze_semantic_request return log_analyze_semantic_request(query, text, client_ip) def _build_success_response(result, debug_info: bool = False): """构建成功响应。debug_info=True 时包含 debug_info 对象(abbrev、topk_tokens、topk_probs)""" resp = { "success": True, "model": result["model"], "token_attention": result["token_attention"], "full_match_degree": result["full_match_degree"], } if debug_info and "debug_info" in result: resp["debug_info"] = result["debug_info"] return resp def _generate_semantic_events( query: str, text: str, submode: Optional[str] = None, debug_info: bool = False, full_match_degree_only: bool = False, client_ip: Optional[str] = None ): """ 流式语义分析核心:生成 SSE 事件流(progress + result/error)。 供 _analyze_semantic_with_stream 和 _analyze_semantic_plain 复用。 client_ip 需在入口处获取并传入,因流式响应时生成器执行时请求上下文已失效。 """ if client_ip is None: client_ip = get_client_ip() start_time = time.perf_counter() request_id = _log_request(query, text, client_ip) progress_queue = queue.Queue() analysis_done = threading.Event() analysis_result = None analysis_error = None lock_wait_time = None def progress_callback(step: int, total_steps: int, stage: str, percentage: Optional[int]): progress_queue.put(("progress", step, total_steps, stage, percentage)) def run_analysis(): nonlocal analysis_result, analysis_error, lock_wait_time try: lock_wait_start = time.perf_counter() lock_acquired = _inference_lock.acquire(timeout=LOCK_WAIT_TIMEOUT) if not lock_acquired: analysis_error = QueueTimeoutError( f"排队等待超过 {LOCK_WAIT_TIMEOUT} 秒,服务繁忙,请稍后重试" ) return lock_wait_time = time.perf_counter() - lock_wait_start try: from backend.access_log import log_analyze_semantic_start log_analyze_semantic_start(request_id, lock_wait_time, stream_mode=True) result = _analyze_semantic(query, text, submode_override=submode, progress_callback=progress_callback, debug_info=debug_info, full_match_degree_only=full_match_degree_only) analysis_result = result finally: _inference_lock.release() except Exception as e: analysis_error = e finally: analysis_done.set() progress_queue.put(("done", None, None)) try: analysis_thread = threading.Thread(target=run_analysis, daemon=True) analysis_thread.start() timeout_reached = False for kind, event_str in consume_progress_queue( progress_queue, analysis_done, start_time, ANALYSIS_TIMEOUT, "语义分析" ): if kind == 'timeout': timeout_reached = True yield event_str break if kind == 'progress': yield event_str elif kind == 'done': break if timeout_reached: gc.collect() return if analysis_error: if isinstance(analysis_error, QueueTimeoutError): print(f"⏱️ 排队超时: {analysis_error}") yield send_error_event(str(analysis_error), 503) gc.collect() return raise analysis_error if analysis_result is None: print("⚠️ 语义分析结果为空,但没有错误信息") yield send_error_event("分析失败:未获取到结果", 500) gc.collect() return elapsed = time.perf_counter() - start_time tokens = len(analysis_result.get("token_attention", [])) print( f"\t📤 API analyze_semantic (stream) response: req_id={request_id}, " f"tokens={tokens}, response_time={elapsed:.4f}s" ) yield send_result_event(_build_success_response(analysis_result, debug_info)) except Exception as e: import traceback traceback.print_exc() exit_if_oom(e, defer_seconds=1) yield send_error_event(str(e), 500) finally: gc.collect() def _analyze_semantic_with_stream( query: str, text: str, submode: Optional[str] = None, debug_info: bool = False, full_match_degree_only: bool = False, client_ip: Optional[str] = None ): """流式语义分析,通过 SSE 返回阶段级进度""" return SSEProgressReporter( lambda: _generate_semantic_events(query, text, submode, debug_info, full_match_degree_only, client_ip) ).create_response() def _analyze_semantic_plain( query: str, text: str, submode: Optional[str] = None, debug_info: bool = False, full_match_degree_only: bool = False, client_ip: Optional[str] = None ): """ 非流式语义分析:封装流式实现,消费事件流后返回 JSON。 供脚本等简单客户端使用。 """ result = None error_msg = None status_code = 500 try: for event_str in _generate_semantic_events(query, text, submode, debug_info, full_match_degree_only, client_ip): if not event_str.startswith('data: '): continue data = json.loads(event_str[6:].strip()) t = data.get('type') if t == 'result': result = data.get('data') elif t == 'error': error_msg = data.get('message', '分析失败') status_code = data.get('status_code', 500) break except Exception as e: import traceback traceback.print_exc() exit_if_oom(e, defer_seconds=1) error_msg = str(e) finally: gc.collect() if error_msg: return {"success": False, "message": error_msg}, status_code if result is None: return {"success": False, "message": "分析失败:未获取到结果"}, 500 return result, 200 def analyze_semantic(semantic_request): """ 分析原文 token 对 prompt 的关注度。 Args: semantic_request: 包含 query、text、stream(可选)、submode(可选)的字典 Returns: stream=True 时返回 SSE 响应;否则返回 (响应字典, 状态码) 元组 """ query = (semantic_request.get("query") or "") text = semantic_request.get("text") or "" stream = semantic_request.get("stream", False) submode = (semantic_request.get("submode") or "").strip() or None debug_info = bool(semantic_request.get("debug_info", False)) full_match_degree_only = bool(semantic_request.get("full_match_degree_only", False)) if not query: return {"success": False, "message": "缺少 query 字段"}, 400 if not text: return {"success": False, "message": "缺少 text 字段"}, 400 client_ip = get_client_ip() if stream: return _analyze_semantic_with_stream(query, text, submode, debug_info, full_match_degree_only, client_ip) return _analyze_semantic_plain(query, text, submode, debug_info, full_match_degree_only, client_ip)