"""Prediction attribution API."""
import gc
import time
import traceback

from backend.model_manager import _inference_lock
from backend.oom import exit_if_oom
from backend.prediction_attributor import analyze_prediction_attribution
from backend.api.analyze import LOCK_WAIT_TIMEOUT
from backend.access_log import get_client_ip, log_prediction_attribute_request

# Pages permitted to issue attribution requests. Keep in sync with the
# literal list in the "source_page must be one of" error message below.
_ALLOWED_SOURCE_PAGES = frozenset(
    {"analysis.html", "chat.html", "attribution.html", "gen_attribute.html"}
)

# Requests from this page belong to a multi-step causal flow and must carry
# flow_id/flow_step; requests from any other page must not.
_CAUSAL_FLOW_PAGE = "gen_attribute.html"


def _bad_request(message):
    """Build a ``(response_dict, 400)`` error tuple carrying ``message``."""
    return {"success": False, "message": message}, 400


def _is_int(value):
    """Return True for a genuine integer; ``bool`` is rejected even though it subclasses ``int``."""
    return isinstance(value, int) and not isinstance(value, bool)


def _validate_request(
    context, target_prediction, target_token_id, model, source_page, flow_id, flow_step
):
    """Validate request fields in a fixed order.

    Returns:
        A ``(response_dict, 400)`` tuple for the first failing check, or
        ``None`` when every field is acceptable.
    """
    # --- context: required non-empty string ---
    if context is None:
        return _bad_request("Missing required field: context")
    if not isinstance(context, str):
        return _bad_request("context must be a string")
    if context == "":
        return _bad_request("Missing required field: context")

    # --- attribution target: optional, at most one of the two forms ---
    if target_prediction is not None and not isinstance(target_prediction, str):
        return _bad_request("target_prediction must be a string")
    if target_prediction == "":
        return _bad_request("target_prediction must not be empty")
    if target_token_id is not None and not _is_int(target_token_id):
        return _bad_request("target_token_id must be an integer")
    if target_token_id is not None and target_token_id < 0:
        return _bad_request("target_token_id must be >= 0")
    if target_prediction is not None and target_token_id is not None:
        return _bad_request("target_prediction and target_token_id are mutually exclusive")

    # --- model: required, one of the two supported variants ---
    if model is None:
        return _bad_request("Missing required field: model")
    if not isinstance(model, str):
        return _bad_request("model must be a string")
    if model not in ("base", "instruct"):
        return _bad_request('model must be "base" or "instruct"')

    # --- source_page: required, must be a known page ---
    if source_page is None:
        return _bad_request("Missing required field: source_page")
    if not isinstance(source_page, str):
        return _bad_request("source_page must be a string")
    if source_page == "":
        return _bad_request("source_page must not be empty")
    if source_page not in _ALLOWED_SOURCE_PAGES:
        return _bad_request(
            "source_page must be one of: "
            "analysis.html, chat.html, attribution.html, gen_attribute.html"
        )

    # --- flow_id / flow_step: type checks when present ---
    if flow_id is not None and not isinstance(flow_id, str):
        return _bad_request("flow_id must be a string")
    if flow_id == "":
        return _bad_request("flow_id must not be empty")
    if flow_step is not None and not _is_int(flow_step):
        return _bad_request("flow_step must be an integer")
    if flow_step is not None and flow_step < 0:
        return _bad_request("flow_step must be >= 0")

    # Causal-flow requests must carry both flow fields; all others must carry neither.
    if source_page == _CAUSAL_FLOW_PAGE:
        if flow_id is None:
            return _bad_request("Missing required field: flow_id for causal flow")
        if flow_step is None:
            return _bad_request("Missing required field: flow_step for causal flow")
    elif flow_id is not None or flow_step is not None:
        return _bad_request(
            "flow_id/flow_step are only allowed when source_page is gen_attribute.html"
        )
    return None


def prediction_attribute(attribution_request):
    """
    Run attribution analysis for the next-token prediction of a context text.

    Args:
        attribution_request: mapping that must contain ``context``, ``model``
            and ``source_page``. The attribution target is one of:
            omit both ``target_prediction`` and ``target_token_id`` for top-1;
            or pass a non-empty ``target_prediction`` (first token of the string);
            or pass ``target_token_id`` (non-negative integer vocab id).
            The two target fields may not both be present. ``flow_id`` and
            ``flow_step`` are required when ``source_page`` is
            ``gen_attribute.html`` and forbidden otherwise.

    Returns:
        ``(response_dict, status_code)`` tuple: 400 for invalid input or a
        ``ValueError`` from the analysis, 503 if the inference lock could not
        be acquired within ``LOCK_WAIT_TIMEOUT``, 500 for any other failure,
        and 200 with ``{"success": True, **result}`` on success.
    """
    context = attribution_request.get("context")
    target_prediction = attribution_request.get("target_prediction")
    target_token_id = attribution_request.get("target_token_id")
    model = attribution_request.get("model")
    source_page = attribution_request.get("source_page")
    flow_id = attribution_request.get("flow_id")
    flow_step = attribution_request.get("flow_step")

    error = _validate_request(
        context, target_prediction, target_token_id, model, source_page, flow_id, flow_step
    )
    if error is not None:
        return error

    client_ip = get_client_ip()
    # Timer starts before the lock wait, so response_time includes queueing.
    start_time = time.perf_counter()
    request_id = log_prediction_attribute_request(
        context=context,
        target_prediction=target_prediction,
        target_token_id=target_token_id,
        model=model,
        source_page=source_page,
        flow_id=flow_id,
        flow_step=flow_step,
        client_ip=client_ip,
    )

    # Serialize inference; give up with 503 rather than queueing indefinitely.
    if not _inference_lock.acquire(timeout=LOCK_WAIT_TIMEOUT):
        return {
            "success": False,
            "message": (
                f"Queue wait exceeded {LOCK_WAIT_TIMEOUT} seconds; "
                "server is busy, please try again later."
            ),
        }, 503

    try:
        result = analyze_prediction_attribution(
            context,
            target_prediction,
            model=model,
            target_token_id=target_token_id,
        )
    except ValueError as e:
        # Treated as a client error (e.g. an unusable target) — presumably
        # raised by the analyzer for bad input; verify against its contract.
        return {"success": False, "message": str(e)}, 400
    except Exception as e:
        traceback.print_exc()
        # May schedule a process exit if the failure is an out-of-memory condition.
        exit_if_oom(e, defer_seconds=1)
        return {"success": False, "message": str(e)}, 500
    finally:
        # Always release the lock, then reclaim memory from the inference run.
        _inference_lock.release()
        gc.collect()

    elapsed = time.perf_counter() - start_time
    tokens = len(result.get("token_attribution", []))
    target_token = result.get("target_token")
    flow_part = (
        "" if flow_id is None else f"flow_id={flow_id!r}, flow_step={flow_step}, "
    )
    print(
        f"\t📤 API prediction_attribute response: req_id={request_id}, "
        f"{flow_part}"
        f"target={target_token!r}, tokens={tokens}, response_time={elapsed:.4f}s"
    )
    return {"success": True, **result}, 200