"""Semantic analysis API:返回原文各 token 对 prompt 的平均关注度"""
import gc
import json
import queue
import threading
import time
from typing import Optional
from backend.model_manager import _inference_lock
from backend.oom import exit_if_oom
from backend.semantic_analyzer import analyze_semantic as _analyze_semantic
from backend.api.sse_utils import (
SSEProgressReporter,
consume_progress_queue,
send_result_event,
send_error_event,
)
from backend.access_log import get_client_ip
from backend.api.analyze import QueueTimeoutError, ANALYSIS_TIMEOUT, LOCK_WAIT_TIMEOUT
def _log_request(query, text, client_ip=None):
    """Write the incoming request to the access log; returns the request id."""
    # Imported lazily inside the function — NOTE(review): backend.access_log is
    # already imported at module level (get_client_ip), so presumably this
    # avoids an import-order/circular-import issue for this symbol; confirm.
    from backend.access_log import log_analyze_semantic_request
    return log_analyze_semantic_request(query, text, client_ip)
def _build_success_response(result, debug_info: bool = False):
"""构建成功响应。debug_info=True 时包含 debug_info 对象(abbrev、topk_tokens、topk_probs)"""
resp = {
"success": True,
"model": result["model"],
"token_attention": result["token_attention"],
"full_match_degree": result["full_match_degree"],
}
if debug_info and "debug_info" in result:
resp["debug_info"] = result["debug_info"]
return resp
def _generate_semantic_events(
    query: str, text: str, submode: Optional[str] = None, debug_info: bool = False,
    full_match_degree_only: bool = False, client_ip: Optional[str] = None
):
    """
    Core of streaming semantic analysis: generate the SSE event stream
    (zero or more progress events, then exactly one result or error event).

    Shared by _analyze_semantic_with_stream and _analyze_semantic_plain.

    client_ip must be captured at the entry point and passed in: by the time
    a streaming response drives this generator, the request context has
    already been torn down.
    """
    if client_ip is None:
        # Fallback for callers that did not capture the IP up front; only
        # valid while the request context is still alive.
        client_ip = get_client_ip()
    start_time = time.perf_counter()
    request_id = _log_request(query, text, client_ip)
    # The worker thread feeds tuples into this queue; the generator drains it
    # and turns entries into SSE events.
    progress_queue = queue.Queue()
    analysis_done = threading.Event()
    analysis_result = None
    analysis_error = None
    lock_wait_time = None
    def progress_callback(step: int, total_steps: int, stage: str, percentage: Optional[int]):
        # Invoked from inside the analyzer; just forwards progress to the queue.
        progress_queue.put(("progress", step, total_steps, stage, percentage))
    def run_analysis():
        # Worker-thread body: acquire the global inference lock, run the
        # analysis, and record the outcome in the enclosing scope.
        nonlocal analysis_result, analysis_error, lock_wait_time
        try:
            lock_wait_start = time.perf_counter()
            lock_acquired = _inference_lock.acquire(timeout=LOCK_WAIT_TIMEOUT)
            if not lock_acquired:
                # The model is busy for too long — report a queue timeout
                # (surfaced as 503 below) instead of blocking indefinitely.
                analysis_error = QueueTimeoutError(
                    f"排队等待超过 {LOCK_WAIT_TIMEOUT} 秒,服务繁忙,请稍后重试"
                )
                return
            lock_wait_time = time.perf_counter() - lock_wait_start
            try:
                from backend.access_log import log_analyze_semantic_start
                log_analyze_semantic_start(request_id, lock_wait_time, stream_mode=True)
                result = _analyze_semantic(query, text, submode_override=submode, progress_callback=progress_callback, debug_info=debug_info, full_match_degree_only=full_match_degree_only)
                analysis_result = result
            finally:
                _inference_lock.release()
        except Exception as e:
            analysis_error = e
        finally:
            # Always wake the consumer, even on failure. NOTE(review): this
            # sentinel is a 3-tuple while progress entries are 5-tuples —
            # presumably consume_progress_queue only inspects the first
            # element; confirm against backend.api.sse_utils.
            analysis_done.set()
            progress_queue.put(("done", None, None))
    try:
        analysis_thread = threading.Thread(target=run_analysis, daemon=True)
        analysis_thread.start()
        timeout_reached = False
        # Drain progress events until the worker signals completion or the
        # overall analysis timeout fires.
        for kind, event_str in consume_progress_queue(
            progress_queue, analysis_done, start_time, ANALYSIS_TIMEOUT, "语义分析"
        ):
            if kind == 'timeout':
                # The helper already rendered the timeout event; emit it and stop.
                timeout_reached = True
                yield event_str
                break
            if kind == 'progress':
                yield event_str
            elif kind == 'done':
                break
        if timeout_reached:
            gc.collect()
            return
        if analysis_error:
            if isinstance(analysis_error, QueueTimeoutError):
                # Expected overload condition: report 503 rather than 500.
                print(f"⏱️ 排队超时: {analysis_error}")
                yield send_error_event(str(analysis_error), 503)
                gc.collect()
                return
            # Unexpected worker errors are re-raised so the generic handler
            # below prints the traceback, checks OOM, and emits a 500 event.
            raise analysis_error
        if analysis_result is None:
            print("⚠️ 语义分析结果为空,但没有错误信息")
            yield send_error_event("分析失败:未获取到结果", 500)
            gc.collect()
            return
        elapsed = time.perf_counter() - start_time
        tokens = len(analysis_result.get("token_attention", []))
        print(
            f"\t📤 API analyze_semantic (stream) response: req_id={request_id}, "
            f"tokens={tokens}, response_time={elapsed:.4f}s"
        )
        yield send_result_event(_build_success_response(analysis_result, debug_info))
    except Exception as e:
        import traceback
        traceback.print_exc()
        # If this was an out-of-memory condition, schedule a process exit so
        # the supervisor can restart the service with a clean heap.
        exit_if_oom(e, defer_seconds=1)
        yield send_error_event(str(e), 500)
    finally:
        gc.collect()
def _analyze_semantic_with_stream(
    query: str, text: str, submode: Optional[str] = None, debug_info: bool = False,
    full_match_degree_only: bool = False, client_ip: Optional[str] = None
):
    """Streaming semantic analysis: an SSE response with stage-level progress."""
    def _event_stream():
        # Deferred so the generator is created by the reporter, not here.
        return _generate_semantic_events(
            query, text, submode, debug_info, full_match_degree_only, client_ip
        )
    reporter = SSEProgressReporter(_event_stream)
    return reporter.create_response()
def _analyze_semantic_plain(
    query: str, text: str, submode: Optional[str] = None, debug_info: bool = False,
    full_match_degree_only: bool = False, client_ip: Optional[str] = None
):
    """
    Non-streaming semantic analysis: wraps the streaming implementation,
    drains its event stream, and returns plain JSON.

    Intended for simple clients such as scripts.
    """
    payload = None
    failure_msg = None
    failure_code = 500
    try:
        event_stream = _generate_semantic_events(
            query, text, submode, debug_info, full_match_degree_only, client_ip
        )
        for raw in event_stream:
            # Only SSE data lines carry JSON; skip anything else.
            if not raw.startswith('data: '):
                continue
            parsed = json.loads(raw[len('data: '):].strip())
            event_type = parsed.get('type')
            if event_type == 'error':
                failure_msg = parsed.get('message', '分析失败')
                failure_code = parsed.get('status_code', 500)
                break
            if event_type == 'result':
                payload = parsed.get('data')
    except Exception as exc:
        import traceback
        traceback.print_exc()
        # Schedule a restart if the failure was an out-of-memory condition.
        exit_if_oom(exc, defer_seconds=1)
        failure_msg = str(exc)
    finally:
        gc.collect()
    if failure_msg:
        return {"success": False, "message": failure_msg}, failure_code
    if payload is None:
        return {"success": False, "message": "分析失败:未获取到结果"}, 500
    return payload, 200
def analyze_semantic(semantic_request):
    """
    Analyse how strongly each token of the source text attends to the prompt.

    Args:
        semantic_request: dict with "query" and "text", plus optional
            "stream", "submode", "debug_info", "full_match_degree_only".

    Returns:
        An SSE response when stream=True; otherwise a (payload, status) tuple.
    """
    query = semantic_request.get("query") or ""
    text = semantic_request.get("text") or ""
    # Validate required fields before doing any further work.
    if not query:
        return {"success": False, "message": "缺少 query 字段"}, 400
    if not text:
        return {"success": False, "message": "缺少 text 字段"}, 400
    submode = (semantic_request.get("submode") or "").strip() or None
    debug_info = bool(semantic_request.get("debug_info", False))
    full_match_degree_only = bool(semantic_request.get("full_match_degree_only", False))
    # IP must be captured here, while the request context is still alive.
    client_ip = get_client_ip()
    handler = (
        _analyze_semantic_with_stream
        if semantic_request.get("stream", False)
        else _analyze_semantic_plain
    )
    return handler(query, text, submode, debug_info, full_match_degree_only, client_ip)
|