| import requests |
| import json |
| import base64 |
| import threading |
| import time |
| import uuid |
| from datetime import datetime |
| from typing import Dict, Optional, Any |
|
|
| from utils import logger, mask_email |
| import config |
| from retry import with_retry |
|
|
| class OnDemandAPIClient: |
| """OnDemand API 客户端,处理认证、会话管理和查询""" |
| |
def __init__(self, email: str, password: str, client_id: str = "default_client"):
    """Create a client bound to one OnDemand account.

    Stores the credentials, initializes every token/session field to an
    empty value and logs an initialization message that contains the
    masked e-mail address and the client id.

    Args:
        email: OnDemand account e-mail address.
        password: OnDemand account password.
        client_id: Identifier of this client, used in log messages.
    """
    # Who this client is.
    self.client_id = client_id
    self.email = email
    self.password = password

    # Service endpoints.
    self.base_url = "https://gateway.on-demand.io/v1"
    self.chat_base_url = "https://api.on-demand.io/chat/v1/client"

    # Credentials and session id; empty until later calls fill them in.
    self.token = self.refresh_token = ""
    self.user_id = self.company_id = ""
    self.session_id = ""

    # Bookkeeping: re-entrant lock, last activity timestamp, last error text.
    self.lock = threading.RLock()
    self.last_activity = datetime.now()
    self.last_error: Optional[str] = None

    # Per-request association fields; all start unset.
    self._associated_user_identifier: Optional[str] = None
    self._associated_request_ip: Optional[str] = None
    self._current_request_context_hash: Optional[str] = None

    logger.info(f"已为 {mask_email(email)} 初始化 OnDemandAPIClient (ID: {client_id})")
| |
def _log(self, message: str, level: str = "INFO"):
    """Emit a log line prefixed with the client id and the masked e-mail.

    Also refreshes ``last_activity`` with the current time.

    Args:
        message: Text to log.
        level: Name of the logger method to use, case-insensitive
            (e.g. "INFO", "ERROR"). Names the logger does not provide
            fall back to ``logger.info``.
    """
    who = mask_email(self.email)
    # Resolve the level name to a logger method, defaulting to info().
    emit = getattr(logger, level.lower(), logger.info)
    emit(f"[{self.client_id} / {who}] {message}")
    self.last_activity = datetime.now()
|
|
def get_authorization(self) -> str:
    """Return the base64 form of ``email:password`` for a Basic auth header.

    Returns:
        The base64 text only; the caller adds the ``Basic `` prefix.
    """
    credentials = "{}:{}".format(self.email, self.password)
    raw = credentials.encode("utf-8")
    return base64.b64encode(raw).decode("utf-8")
| |
| def _do_request(self, method: str, url: str, headers: Dict[str, str], |
| data: Optional[Dict] = None, stream: bool = False, |
| timeout: int = None) -> requests.Response: |
| """执行HTTP请求的实际逻辑,不包含重试 |
| |
| Args: |
| method: HTTP方法 (GET, POST等) |
| url: 请求URL |
| headers: HTTP头 |
| data: 请求数据 |
| stream: 是否使用流式传输 |
| timeout: 请求超时时间 |
| |
| Returns: |
| requests.Response对象 |
| |
| Raises: |
| requests.exceptions.RequestException: 请求失败 |
| """ |
| if method.upper() == 'GET': |
| response = requests.get(url, headers=headers, stream=stream, timeout=timeout) |
| elif method.upper() == 'POST': |
| json_data = json.dumps(data) if data else None |
| response = requests.post(url, data=json_data, headers=headers, stream=stream, timeout=timeout) |
| else: |
| raise ValueError(f"不支持的HTTP方法: {method}") |
| |
| response.raise_for_status() |
| return response |
| |
@with_retry()
def sign_in(self, context: Optional[str] = None) -> bool:
    """Sign in and store token, refresh token, user id and company id.

    POSTs to ``{base_url}/auth/user/signin`` with a Basic Authorization
    header built from the account credentials. The method is wrapped by
    ``with_retry()``; the retry policy is defined by that decorator.

    Args:
        context: When truthy, stored in ``_current_request_context_hash``.

    Returns:
        True when token, user id and company id were all present in the
        response. False when the body is not valid JSON, a required field
        is missing, or an unexpected non-request error occurred; in those
        cases ``last_error`` holds the reason.

    Raises:
        requests.exceptions.RequestException: The HTTP request failed. The
            error is logged, recorded in ``last_error`` and re-raised.
    """
    import copy  # local import: only needed for debug-mode redaction

    with self.lock:
        self.last_error = None
        url = f"{self.base_url}/auth/user/signin"
        payload = {"accountType": "default"}
        headers = {
            'User-Agent': "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/135.0.0.0 Safari/537.36 Edg/135.0.0.0",
            'Accept': "application/json, text/plain, */*",
            'Content-Type': "application/json",
            'Authorization': f"Basic {self.get_authorization()}",
            'Referer': "https://app.on-demand.io/",
        }
        if context:
            self._current_request_context_hash = context

        try:
            self._log(f"尝试登录 {mask_email(self.email)}...")

            response = self._do_request('POST', url, headers, payload, timeout=config.get_config_value('request_timeout'))

            # Decode separately: recent requests versions raise a JSON error
            # that is also a RequestException, which the handler below would
            # otherwise log as a request failure and re-raise.
            try:
                data = response.json()
            except ValueError as e:
                self.last_error = f"登录 JSON 解码失败: {e}. 响应文本: {response.text}"
                self._log(self.last_error, level="ERROR")
                return False

            if config.get_config_value('debug_mode'):
                # Deep copy: redacting a shallow copy would overwrite the
                # real tokens inside `data` before they are stored below.
                debug_data = copy.deepcopy(data)
                if 'data' in debug_data and 'tokenData' in debug_data['data']:
                    debug_data['data']['tokenData']['token'] = '***REDACTED***'
                    debug_data['data']['tokenData']['refreshToken'] = '***REDACTED***'
                self._log(f"登录原始响应: {json.dumps(debug_data, indent=2, ensure_ascii=False)}", "DEBUG")

            body = data.get('data', {})
            token_data = body.get('tokenData', {})
            user = body.get('user', {})
            self.token = token_data.get('token', '')
            self.refresh_token = token_data.get('refreshToken', '')
            self.user_id = user.get('userId', '')
            self.company_id = user.get('default_company_id', '')

            if self.token and self.user_id and self.company_id:
                self._log("登录成功。已获取必要的凭证。")
                return True

            self.last_error = "登录成功,但未能从响应中提取必要的字段。"
            self._log(f"登录失败: {self.last_error}", level="ERROR")
            return False

        except requests.exceptions.RequestException as e:
            self.last_error = f"登录请求失败: {e}"
            self._log(f"登录失败: {e}", level="ERROR")
            raise

        except Exception as e:
            self.last_error = f"登录过程中发生意外错误: {e}"
            self._log(self.last_error, level="ERROR")
            return False
|
|
@with_retry()
def refresh_token_if_needed(self) -> bool:
    """Exchange the stored refresh token for a new access token.

    POSTs the current token and refresh token to
    ``{base_url}/auth/user/refresh_token``. Despite the name, no expiry
    check is made: a refresh is attempted whenever a refresh token is
    stored. The method is wrapped by ``with_retry()``; the retry policy is
    defined by that decorator.

    Returns:
        True when the response contained a new token (the refresh token is
        also replaced when a new one is returned). False when no refresh
        token is stored, the body is not valid JSON, the response has no
        token, or an unexpected non-request error occurred; ``last_error``
        holds the reason.

    Raises:
        requests.exceptions.RequestException: The HTTP request failed. The
            error is logged, recorded in ``last_error`` and re-raised.
    """
    import copy  # local import: only needed for debug-mode redaction

    with self.lock:
        self.last_error = None
        if not self.refresh_token:
            self.last_error = "没有可用的 refresh token 来刷新令牌。"
            self._log(self.last_error, level="WARNING")
            return False

        url = f"{self.base_url}/auth/user/refresh_token"
        payload = {"data": {"token": self.token, "refreshToken": self.refresh_token}}
        headers = {'Content-Type': "application/json"}

        try:
            self._log("尝试刷新令牌...")

            response = self._do_request('POST', url, headers, payload, timeout=config.get_config_value('request_timeout'))

            # Decode separately: recent requests versions raise a JSON error
            # that is also a RequestException, which the handler below would
            # otherwise log as a request failure and re-raise.
            try:
                data = response.json()
            except ValueError as e:
                self.last_error = f"令牌刷新 JSON 解码失败: {e}. 响应文本: {response.text}"
                self._log(self.last_error, level="ERROR")
                return False

            if config.get_config_value('debug_mode'):
                # Deep copy: redacting a shallow copy would replace the real
                # tokens inside `data` with the placeholder before they are
                # read below.
                debug_data = copy.deepcopy(data)
                if 'data' in debug_data:
                    if 'token' in debug_data['data']:
                        debug_data['data']['token'] = '***REDACTED***'
                    if 'refreshToken' in debug_data['data']:
                        debug_data['data']['refreshToken'] = '***REDACTED***'
                self._log(f"刷新令牌原始响应: {json.dumps(debug_data, indent=2, ensure_ascii=False)}", "DEBUG")

            token_fields = data.get('data', {})
            new_token = token_fields.get('token', '')
            new_refresh_token = token_fields.get('refreshToken', '')

            if not new_token:
                self.last_error = "令牌刷新成功,但响应中没有新的 token。"
                self._log(f"令牌刷新失败: {self.last_error}", level="ERROR")
                return False

            self.token = new_token
            # Keep the old refresh token when the server does not rotate it.
            if new_refresh_token:
                self.refresh_token = new_refresh_token
            self._log("令牌刷新成功。")
            return True

        except requests.exceptions.RequestException as e:
            self.last_error = f"令牌刷新请求失败: {e}"
            self._log(f"令牌刷新失败: {e}", level="ERROR")

            failed_response = getattr(e, 'response', None)
            if failed_response is not None and failed_response.status_code == 401:
                self._log("令牌刷新返回401错误,可能需要完全重新登录", level="WARNING")

            raise

        except Exception as e:
            self.last_error = f"令牌刷新过程中发生意外错误: {e}"
            self._log(self.last_error, level="ERROR")
            return False
|
|
@with_retry()
def create_session(self, external_user_id: str = "openai-adapter-user", external_context: Optional[str] = None) -> bool:
    """Create a new chat session and store its id in ``session_id``.

    Signs in first when token, user id or company id is missing. If the
    session request is rejected with HTTP 401, the token is refreshed, or,
    when refreshing fails, a full sign-in is performed, and the request is
    sent once more. The method is wrapped by ``with_retry()``; the retry
    policy is defined by that decorator.

    Args:
        external_user_id: Prefix of the external user id; a random UUID hex
            string is appended to make it unique.
        external_context: When truthy, stored in
            ``_current_request_context_hash``.

    Returns:
        True when a session id was returned and stored. False when
        authentication could not be (re)established, the body is not valid
        JSON, the response has no session id, or an unexpected non-request
        error occurred; ``last_error`` holds the reason.

    Raises:
        requests.exceptions.RequestException: A request failed for a reason
            other than a recoverable 401. The error is logged, recorded in
            ``last_error`` and re-raised.
    """
    with self.lock:
        self.last_error = None
        if external_context:
            self._current_request_context_hash = external_context
        if not self.token or not self.user_id or not self.company_id:
            self.last_error = "创建会话缺少 token, user_id, 或 company_id。正在尝试登录。"
            self._log(self.last_error, level="WARNING")
            if not self.sign_in():
                self.last_error = f"无法创建会话:登录失败。最近的客户端错误: {self.last_error}"
                return False

        url = f"{self.chat_base_url}/sessions"

        unique_id = f"{external_user_id}-{uuid.uuid4().hex}"
        payload = {"externalUserId": unique_id, "pluginIds": []}
        headers = {
            'Content-Type': "application/json",
            'Authorization': f"Bearer {self.token}",
            'x-company-id': self.company_id,
            'x-user-id': self.user_id,
        }

        self._log(f"尝试创建会话,company_id: {self.company_id}, user_id: {self.user_id}, external_id: {unique_id}")

        try:
            try:
                response = self._do_request('POST', url, headers, payload, timeout=config.get_config_value('request_timeout'))
            except requests.exceptions.HTTPError as e:
                status = e.response.status_code if e.response is not None else None
                if status != 401:
                    raise

                self._log("创建会话时令牌过期,尝试刷新...", level="INFO")
                # A refresh that raises (for example the refresh endpoint
                # itself answering 401) counts as a failed refresh, so the
                # full re-login fallback below is still attempted.
                try:
                    refreshed = self.refresh_token_if_needed()
                except requests.exceptions.RequestException:
                    refreshed = False

                if not refreshed:
                    self._log("令牌刷新失败。尝试完全重新登录以创建会话。", level="WARNING")
                    if not self.sign_in():
                        self.last_error = f"会话创建失败:令牌刷新和重新登录均失败。最近的客户端错误: {self.last_error}"
                        self._log(self.last_error, level="ERROR")
                        return False

                headers['Authorization'] = f"Bearer {self.token}"
                response = self._do_request('POST', url, headers, payload, timeout=config.get_config_value('request_timeout'))

            # Decode separately: recent requests versions raise a JSON error
            # that is also a RequestException, which the handler below would
            # otherwise log as a request failure and re-raise.
            try:
                data = response.json()
            except ValueError as e:
                self.last_error = f"会话创建 JSON 解码失败: {e}. 响应文本: {response.text}"
                self._log(self.last_error, level="ERROR")
                return False

            if config.get_config_value('debug_mode'):
                self._log(f"创建会话原始响应: {json.dumps(data, indent=2, ensure_ascii=False)}", "DEBUG")

            session_id_val = data.get('data', {}).get('id', '')
            if not session_id_val:
                self.last_error = "会话创建成功,但响应中没有会话 ID。"
                self._log(f"会话创建失败: {self.last_error}", level="ERROR")
                return False

            self.session_id = session_id_val
            self._log(f"会话创建成功。会话 ID: {self.session_id}")
            return True

        except requests.exceptions.RequestException as e:
            self.last_error = f"会话创建请求失败: {e}"
            self._log(f"会话创建失败: {e}", level="ERROR")
            raise

        except Exception as e:
            self.last_error = f"会话创建过程中发生意外错误: {e}"
            self._log(self.last_error, level="ERROR")
            return False
|
|
def _extract_sync_answer(self, response_body: str) -> str:
    """Return the answer text found in a fully-read non-streaming body.

    The body is first parsed as one JSON document; a dict is searched for a
    string ``answer`` or ``content`` field, or an ``answer`` on a
    ``fulfillment`` event. If the body is not valid JSON it is parsed line
    by line as SSE ``data:`` events and the ``answer`` fields of all
    ``fulfillment`` events are concatenated. Returns "" when nothing is found.
    """
    try:
        data = json.loads(response_body)
    except json.JSONDecodeError:
        self._log(f"非流式响应直接解析JSON失败,尝试按SSE行解析: {response_body[:200]}", "WARNING")
    else:
        if not isinstance(data, dict):
            self._log(f"非流式响应解析为JSON后,不是字典类型: {type(data)}", "WARNING")
            return ""
        for key in ("answer", "content"):
            value = data.get(key)
            if isinstance(value, str):
                return value
        # Only reached with a non-string answer; coerce it so callers
        # always receive text.
        if data.get("eventType") == "fulfillment" and data.get("answer") is not None:
            return str(data["answer"])
        self._log(f"非流式响应解析为JSON后,未在顶层或常见字段找到答案: {response_body[:200]}", "WARNING")
        return ""

    # SSE fallback: gather the answer fragments of every fulfillment event.
    pieces = []
    for line in response_body.splitlines():
        if not line.startswith("data:"):
            continue
        json_str = line[len("data:"):].strip()
        if json_str == "[DONE]":
            break
        try:
            event_data = json.loads(json_str)
        except json.JSONDecodeError:
            self._log(f"非流式后备SSE解析时 JSONDecodeError: {json_str}", level="WARNING")
            continue
        # Skip events that are not JSON objects instead of failing the
        # whole response on them.
        if not isinstance(event_data, dict):
            continue
        if event_data.get("eventType", "") == "fulfillment":
            fragment = event_data.get("answer")
            if fragment is not None:
                pieces.append(fragment if isinstance(fragment, str) else str(fragment))
    return "".join(pieces)

@with_retry()
def send_query(self, query: str, endpoint_id: str = "predefined-claude-3.7-sonnet",
               stream: bool = False, model_configs_input: Optional[Dict] = None,
               full_query_override: Optional[str] = None) -> Dict:
    """Send a query to the chat session and return the result.

    Creates a session first when none is stored. The HTTP request is always
    made with ``stream=True``; for a non-streaming call the body is read
    here, the response is closed and the answer text is extracted. The
    method is wrapped by ``with_retry()``; the retry policy is defined by
    that decorator.

    Args:
        query: Query text. ``None`` becomes ""; non-string values are
            converted with ``str()``. Not sent when ``full_query_override``
            is given, but still used for the truncated log line.
        endpoint_id: OnDemand endpoint id.
        stream: Whether to request a streaming response.
        model_configs_input: Model options; entries whose value is ``None``
            are dropped and the rest is sent as ``modelConfigs``.
        full_query_override: When not ``None``, sent as the query instead
            of ``query``.

    Returns:
        ``{"stream": True, "response_obj": response}`` for streaming calls;
        the caller consumes the open response object.
        ``{"stream": False, "content": text}`` for non-streaming calls.
        A dict with an ``"error"`` key when the query could not be sent or
        the non-streaming body could not be processed.

    Raises:
        requests.exceptions.RequestException: The HTTP request itself
            failed. The error is logged, recorded in ``last_error`` and
            re-raised.
    """
    with self.lock:
        self.last_error = None

        if not self.session_id:
            self.last_error = "没有可用的会话 ID。正在尝试创建新会话。"
            self._log(self.last_error, level="WARNING")
            if not self.create_session():
                self.last_error = f"查询失败:会话创建失败。最近的客户端错误: {self.last_error}"
                self._log(self.last_error, level="ERROR")
                return {"error": self.last_error}

        if not self.token:
            self.last_error = "发送查询没有可用的 token。"
            self._log(self.last_error, level="ERROR")
            return {"error": self.last_error}

        url = f"{self.chat_base_url}/sessions/{self.session_id}/query"

        # Normalize the query to a string.
        current_query = ""
        if query is None:
            self._log("警告:查询内容为None,已替换为空字符串", level="WARNING")
        elif not isinstance(query, str):
            current_query = str(query)
            self._log(f"警告:查询内容不是字符串类型,已转换为字符串: {type(query)} -> {type(current_query)}", level="WARNING")
        else:
            current_query = query

        query_to_send = full_query_override if full_query_override is not None else current_query
        if full_query_override is not None:
            self._log(f"使用 full_query_override (长度: {len(full_query_override)}) 代替原始 query。", "DEBUG")

        payload = {
            "endpointId": endpoint_id,
            "query": query_to_send,
            "pluginIds": [],
            "responseMode": "stream" if stream else "sync",
            "debugMode": "on" if config.get_config_value('debug_mode') else "off",
            "fulfillmentOnly": False,
        }

        if model_configs_input:
            # Only forward options that were actually set.
            processed_model_configs = {k: v for k, v in model_configs_input.items() if v is not None}
            if processed_model_configs:
                payload["modelConfigs"] = processed_model_configs

        self._log(f"最终的payload: {json.dumps(payload, ensure_ascii=False)}", level="DEBUG")

        headers = {
            'Content-Type': "application/json",
            'Authorization': f"Bearer {self.token}",
            'x-company-id': self.company_id,
        }

        truncated_query_log = current_query[:100] + "..." if len(current_query) > 100 else current_query
        self._log(f"向端点 {endpoint_id} 发送查询 (stream={stream})。查询内容: {truncated_query_log}")

        try:
            response = self._do_request('POST', url, headers, payload, stream=True, timeout=config.get_config_value('stream_timeout'))

            if stream:
                self._log("返回流式响应对象供外部处理")
                return {"stream": True, "response_obj": response}

            try:
                # Close the response even when reading the body fails.
                try:
                    response_body = response.text
                finally:
                    response.close()

                self._log(f"非流式响应原始文本 (前500字符): {response_body[:500]}", "DEBUG")

                full_answer = self._extract_sync_answer(response_body)

                self._log(f"非流式响应接收完毕。聚合内容长度: {len(full_answer)}")
                return {"stream": False, "content": full_answer}

            except requests.exceptions.RequestException as e:
                self.last_error = f"非流式请求时发生错误: {e}"
                self._log(self.last_error, level="ERROR")
                return {"error": self.last_error, "stream": False, "content": ""}
            except Exception as e:
                self.last_error = f"非流式处理中发生意外错误: {e}"
                self._log(self.last_error, level="ERROR")
                return {"error": self.last_error, "stream": False, "content": ""}

        except requests.exceptions.RequestException as e:
            self.last_error = f"请求失败: {e}"
            self._log(f"查询失败: {e}", level="ERROR")
            raise

        except Exception as e:
            error_message = f"send_query 过程中发生意外错误: {e}"
            error_type = type(e).__name__
            self.last_error = error_message
            self._log(f"{error_message} (错误类型: {error_type})", level="CRITICAL")
            return {"error": str(e)}