# InfoLens — backend/api/prediction_attribute.py
# Commit a0b7722 (dqy08): improved prediction-attribute stats and logging;
# increased history dropdown height; switched some demos from the 14B model
# to the 1.7B model (more intuitive).
"""预测归因 API"""
import gc
import time
from backend.model_manager import _inference_lock
from backend.oom import exit_if_oom
from backend.prediction_attributor import analyze_prediction_attribution
from backend.api.analyze import LOCK_WAIT_TIMEOUT
from backend.access_log import get_client_ip, log_prediction_attribute_request
def prediction_attribute(attribution_request):
    """Run attribution analysis for the next-token prediction of a context.

    Args:
        attribution_request: dict payload. Required keys: ``context``
            (non-empty str), ``model`` ("base" or "instruct") and
            ``source_page`` (one of the known front-end pages). The
            attribution target is one of three mutually exclusive modes:
            omit both ``target_prediction`` and ``target_token_id`` for the
            top-1 prediction; pass a non-empty ``target_prediction`` string
            (its first token is used); or pass a non-negative int
            ``target_token_id`` (vocabulary id). ``flow_id``/``flow_step``
            are required exactly when ``source_page`` is
            "gen_attribute.html" and forbidden otherwise.

    Returns:
        (response dict, HTTP status code) tuple: 400 on validation or value
        errors, 503 when the inference lock cannot be acquired in time,
        500 on unexpected failures, 200 with the attribution result merged
        into the response on success.
    """
    error = _validate_attribution_request(attribution_request)
    if error is not None:
        return error

    context = attribution_request.get("context")
    target_prediction = attribution_request.get("target_prediction")
    target_token_id = attribution_request.get("target_token_id")
    model = attribution_request.get("model")
    source_page = attribution_request.get("source_page")
    flow_id = attribution_request.get("flow_id")
    flow_step = attribution_request.get("flow_step")

    client_ip = get_client_ip()
    # Timer starts before the queue wait so the logged latency reflects the
    # user-perceived response time, not only inference time.
    start_time = time.perf_counter()
    request_id = log_prediction_attribute_request(
        context=context,
        target_prediction=target_prediction,
        target_token_id=target_token_id,
        model=model,
        source_page=source_page,
        flow_id=flow_id,
        flow_step=flow_step,
        client_ip=client_ip,
    )

    # Serialize model access: only one inference request runs at a time.
    if not _inference_lock.acquire(timeout=LOCK_WAIT_TIMEOUT):
        return {
            "success": False,
            "message": (
                f"Queue wait exceeded {LOCK_WAIT_TIMEOUT} seconds; "
                "server is busy, please try again later."
            ),
        }, 503
    try:
        result = analyze_prediction_attribution(
            context,
            target_prediction,
            model=model,
            target_token_id=target_token_id,
        )
    except ValueError as e:
        # Raised for semantically invalid inputs (e.g. unknown token id).
        return {"success": False, "message": str(e)}, 400
    except Exception as e:
        import traceback
        traceback.print_exc()
        # If this looks like an OOM, schedule a (deferred) process exit so
        # the supervisor restarts us with clean memory.
        exit_if_oom(e, defer_seconds=1)
        return {"success": False, "message": str(e)}, 500
    finally:
        _inference_lock.release()
        gc.collect()

    elapsed = time.perf_counter() - start_time
    tokens = len(result.get("token_attribution", []))
    target_token = result.get("target_token")
    # Include flow fields only for causal-flow requests.
    flow_part = (
        "" if flow_id is None else f"flow_id={flow_id!r}, flow_step={flow_step}, "
    )
    print(
        f"\t📤 API prediction_attribute response: req_id={request_id}, "
        f"{flow_part}"
        f"target={target_token!r}, tokens={tokens}, response_time={elapsed:.4f}s"
    )
    return {"success": True, **result}, 200


def _validate_attribution_request(req):
    """Validate the request payload; return (error dict, 400), or None if OK."""

    def bad(message):
        return {"success": False, "message": message}, 400

    def is_int(value):
        # bool is a subclass of int; reject it explicitly so True/False are
        # not silently treated as token ids / step numbers.
        return isinstance(value, int) and not isinstance(value, bool)

    context = req.get("context")
    target_prediction = req.get("target_prediction")
    target_token_id = req.get("target_token_id")
    model = req.get("model")
    source_page = req.get("source_page")
    flow_id = req.get("flow_id")
    flow_step = req.get("flow_step")

    if context is None:
        return bad("Missing required field: context")
    if not isinstance(context, str):
        return bad("context must be a string")
    if context == "":
        return bad("Missing required field: context")

    if target_prediction is not None and not isinstance(target_prediction, str):
        return bad("target_prediction must be a string")
    if target_prediction == "":
        return bad("target_prediction must not be empty")
    if target_token_id is not None and not is_int(target_token_id):
        return bad("target_token_id must be an integer")
    if target_token_id is not None and target_token_id < 0:
        return bad("target_token_id must be >= 0")
    if target_prediction is not None and target_token_id is not None:
        return bad("target_prediction and target_token_id are mutually exclusive")

    if model is None:
        return bad("Missing required field: model")
    if not isinstance(model, str):
        return bad("model must be a string")
    if model not in ("base", "instruct"):
        return bad('model must be "base" or "instruct"')

    allowed_source_pages = {
        "analysis.html",
        "chat.html",
        "attribution.html",
        "gen_attribute.html",
    }
    if source_page is None:
        return bad("Missing required field: source_page")
    if not isinstance(source_page, str):
        return bad("source_page must be a string")
    if source_page == "":
        return bad("source_page must not be empty")
    if source_page not in allowed_source_pages:
        return bad(
            "source_page must be one of: analysis.html, chat.html, "
            "attribution.html, gen_attribute.html"
        )

    if flow_id is not None and not isinstance(flow_id, str):
        return bad("flow_id must be a string")
    if flow_id == "":
        return bad("flow_id must not be empty")
    if flow_step is not None and not is_int(flow_step):
        return bad("flow_step must be an integer")
    if flow_step is not None and flow_step < 0:
        return bad("flow_step must be >= 0")

    # flow_id/flow_step are required for the causal-flow page and forbidden
    # everywhere else.
    if source_page == "gen_attribute.html":
        if flow_id is None:
            return bad("Missing required field: flow_id for causal flow")
        if flow_step is None:
            return bad("Missing required field: flow_step for causal flow")
    elif flow_id is not None or flow_step is not None:
        return bad(
            "flow_id/flow_step are only allowed when source_page is "
            "gen_attribute.html"
        )
    return None