File size: 6,411 Bytes
494c9e4 c911b05 494c9e4 c911b05 494c9e4 a0b7722 494c9e4 c911b05 494c9e4 a0b7722 494c9e4 c911b05 a0b7722 c911b05 494c9e4 c911b05 494c9e4 53e5b08 a0b7722 494c9e4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 | """预测归因 API"""
import gc
import time
from backend.model_manager import _inference_lock
from backend.oom import exit_if_oom
from backend.prediction_attributor import analyze_prediction_attribution
from backend.api.analyze import LOCK_WAIT_TIMEOUT
from backend.access_log import get_client_ip, log_prediction_attribute_request
# Pages allowed to call this endpoint; "gen_attribute.html" marks the causal
# generation flow, which additionally requires flow_id / flow_step.
_ALLOWED_SOURCE_PAGES = frozenset({
    "analysis.html",
    "chat.html",
    "attribution.html",
    "gen_attribute.html",
})


def _validate_attribution_request(attribution_request):
    """Validate the request payload field by field.

    Checks are performed in a fixed order so that the first failing rule
    determines the error message (callers/tests depend on that ordering).

    Args:
        attribution_request: dict-like request payload.

    Returns:
        An ``(error_dict, 400)`` tuple for the first violated rule, or
        ``None`` when the payload is valid.
    """
    context = attribution_request.get("context")
    target_prediction = attribution_request.get("target_prediction")
    target_token_id = attribution_request.get("target_token_id")
    model = attribution_request.get("model")
    source_page = attribution_request.get("source_page")
    flow_id = attribution_request.get("flow_id")
    flow_step = attribution_request.get("flow_step")

    # context: required non-empty string. An empty string is reported the
    # same way as a missing field (mirrors the original contract).
    if context is None:
        return {"success": False, "message": "Missing required field: context"}, 400
    if not isinstance(context, str):
        return {"success": False, "message": "context must be a string"}, 400
    if context == "":
        return {"success": False, "message": "Missing required field: context"}, 400

    # Attribution target: omit both fields for top-1, or supply exactly one.
    if target_prediction is not None and not isinstance(target_prediction, str):
        return {"success": False, "message": "target_prediction must be a string"}, 400
    if target_prediction == "":
        return {"success": False, "message": "target_prediction must not be empty"}, 400
    # NOTE(review): isinstance(True, int) is True in Python, so a boolean
    # target_token_id would be accepted as 0/1 here — confirm that is intended.
    if target_token_id is not None and not isinstance(target_token_id, int):
        return {"success": False, "message": "target_token_id must be an integer"}, 400
    if target_token_id is not None and target_token_id < 0:
        return {"success": False, "message": "target_token_id must be >= 0"}, 400
    if target_prediction is not None and target_token_id is not None:
        return {"success": False, "message": "target_prediction and target_token_id are mutually exclusive"}, 400

    # model: required, restricted to the two served variants.
    if model is None:
        return {"success": False, "message": "Missing required field: model"}, 400
    if not isinstance(model, str):
        return {"success": False, "message": "model must be a string"}, 400
    if model not in ("base", "instruct"):
        return {"success": False, "message": 'model must be "base" or "instruct"'}, 400

    # source_page: required, must be one of the known front-end pages.
    if source_page is None:
        return {"success": False, "message": "Missing required field: source_page"}, 400
    if not isinstance(source_page, str):
        return {"success": False, "message": "source_page must be a string"}, 400
    if source_page == "":
        return {"success": False, "message": "source_page must not be empty"}, 400
    if source_page not in _ALLOWED_SOURCE_PAGES:
        return {
            "success": False,
            "message": "source_page must be one of: analysis.html, chat.html, attribution.html, gen_attribute.html",
        }, 400

    # flow_id / flow_step: type checks apply whenever the fields are present.
    if flow_id is not None and not isinstance(flow_id, str):
        return {"success": False, "message": "flow_id must be a string"}, 400
    if flow_id == "":
        return {"success": False, "message": "flow_id must not be empty"}, 400
    if flow_step is not None and not isinstance(flow_step, int):
        return {"success": False, "message": "flow_step must be an integer"}, 400
    if flow_step is not None and flow_step < 0:
        return {"success": False, "message": "flow_step must be >= 0"}, 400

    # Flow fields are mandatory for the causal flow page and forbidden elsewhere.
    is_causal_flow = source_page == "gen_attribute.html"
    if is_causal_flow:
        if flow_id is None:
            return {"success": False, "message": "Missing required field: flow_id for causal flow"}, 400
        if flow_step is None:
            return {"success": False, "message": "Missing required field: flow_step for causal flow"}, 400
    elif flow_id is not None or flow_step is not None:
        return {
            "success": False,
            "message": "flow_id/flow_step are only allowed when source_page is gen_attribute.html",
        }, 400

    return None


def _log_attribution_response(request_id, flow_id, flow_step, result, elapsed):
    """Print a one-line summary of a successful attribution response.

    Args:
        request_id: id returned by the access logger for this request.
        flow_id: causal-flow id, or ``None`` outside the causal flow.
        flow_step: causal-flow step, or ``None`` outside the causal flow.
        result: dict returned by ``analyze_prediction_attribution``.
        elapsed: wall-clock seconds spent handling the request.
    """
    tokens = len(result.get("token_attribution", []))
    target_token = result.get("target_token")
    if flow_id is None:
        print(
            f"\t📤 API prediction_attribute response: req_id={request_id}, "
            f"target={target_token!r}, tokens={tokens}, response_time={elapsed:.4f}s"
        )
    else:
        print(
            f"\t📤 API prediction_attribute response: req_id={request_id}, "
            f"flow_id={flow_id!r}, flow_step={flow_step}, "
            f"target={target_token!r}, tokens={tokens}, response_time={elapsed:.4f}s"
        )


def prediction_attribute(attribution_request):
    """Run next-token prediction attribution over a context string.

    Args:
        attribution_request: must contain ``context`` (non-empty str) and
            ``model`` ("base" or "instruct"). The attribution target is one
            of three mutually exclusive forms:
            omit both ``target_prediction`` and ``target_token_id`` for the
            model's top-1 prediction; or pass a non-empty
            ``target_prediction`` string (its first token is used); or pass
            a non-negative integer vocabulary id ``target_token_id``.
            ``source_page`` is required; ``flow_id``/``flow_step`` are
            required for ``gen_attribute.html`` and forbidden otherwise.

    Returns:
        ``(response_dict, status_code)`` tuple: 400 for validation errors
        or a ``ValueError`` from the analyzer, 503 when the inference lock
        cannot be acquired within ``LOCK_WAIT_TIMEOUT`` seconds, 500 for
        unexpected failures, 200 with the attribution result on success.
    """
    error = _validate_attribution_request(attribution_request)
    if error is not None:
        return error

    context = attribution_request.get("context")
    target_prediction = attribution_request.get("target_prediction")
    target_token_id = attribution_request.get("target_token_id")
    model = attribution_request.get("model")
    source_page = attribution_request.get("source_page")
    flow_id = attribution_request.get("flow_id")
    flow_step = attribution_request.get("flow_step")

    # Log the request before queueing so timed-out requests are still recorded.
    client_ip = get_client_ip()
    start_time = time.perf_counter()
    request_id = log_prediction_attribute_request(
        context=context,
        target_prediction=target_prediction,
        target_token_id=target_token_id,
        model=model,
        source_page=source_page,
        flow_id=flow_id,
        flow_step=flow_step,
        client_ip=client_ip,
    )

    # Inference is serialized behind a global lock; bail out with 503 rather
    # than letting clients queue indefinitely.
    lock_acquired = _inference_lock.acquire(timeout=LOCK_WAIT_TIMEOUT)
    if not lock_acquired:
        return {
            "success": False,
            "message": (
                f"Queue wait exceeded {LOCK_WAIT_TIMEOUT} seconds; "
                "server is busy, please try again later."
            ),
        }, 503
    try:
        result = analyze_prediction_attribution(
            context,
            target_prediction,
            model=model,
            target_token_id=target_token_id,
        )
    except ValueError as e:
        # Analyzer-level input problems (e.g. unknown token) map to 400.
        return {"success": False, "message": str(e)}, 400
    except Exception as e:
        import traceback
        traceback.print_exc()
        # On OOM the process is scheduled to exit so the supervisor restarts it.
        exit_if_oom(e, defer_seconds=1)
        return {"success": False, "message": str(e)}, 500
    finally:
        # Always release the lock and reclaim inference memory promptly.
        _inference_lock.release()
        gc.collect()

    _log_attribution_response(
        request_id, flow_id, flow_step, result, time.perf_counter() - start_time
    )
    return {"success": True, **result}, 200
|