| from __future__ import annotations |
|
|
| import asyncio |
| import copy |
| import json |
| import os |
| import threading |
| import urllib.parse |
| from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping |
| from dataclasses import dataclass |
| from types import MappingProxyType |
| from typing import Any |
|
|
| import aiohttp |
|
|
| from astrbot import logger |
| from astrbot.core import sp |
| from astrbot.core.agent.mcp_client import MCPClient, MCPTool |
| from astrbot.core.agent.tool import FunctionTool, ToolSet |
| from astrbot.core.utils.astrbot_path import get_astrbot_data_path |
|
|
# Skeleton configuration written to mcp_server.json when no file exists yet.
DEFAULT_MCP_CONFIG = {"mcpServers": {}}


# Fallback timeouts (seconds) used when neither an explicit argument nor the
# matching environment variable yields a usable value.
DEFAULT_MCP_INIT_TIMEOUT_SECONDS = 180.0
DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS = 180.0
# Names of the environment variables that override the defaults above.
MCP_INIT_TIMEOUT_ENV = "ASTRBOT_MCP_INIT_TIMEOUT"
ENABLE_MCP_TIMEOUT_ENV = "ASTRBOT_MCP_ENABLE_TIMEOUT"
# Upper bound that _resolve_timeout applies to any configured timeout.
MAX_MCP_TIMEOUT_SECONDS = 300.0
|
|
|
|
class MCPInitError(Exception):
    """Base class for errors raised while initializing MCP services."""
|
|
|
|
class MCPInitTimeoutError(asyncio.TimeoutError, MCPInitError):
    """Raised when MCP client initialization exceeds the configured timeout.

    Derives from both ``asyncio.TimeoutError`` and ``MCPInitError``, so an
    ``except`` clause for either base catches it.
    """
|
|
|
|
class MCPAllServicesFailedError(MCPInitError):
    """Raised when every configured MCP service failed to initialize.

    Only raised by ``init_mcp_clients`` when ``raise_on_all_failed`` is true.
    """
|
|
|
|
class MCPShutdownTimeoutError(asyncio.TimeoutError):
    """Signal that one or more MCP servers did not shut down in time.

    Attributes:
        names: Names of the servers that were still shutting down.
        timeout: The timeout, in seconds, that was exceeded.
    """

    def __init__(self, names: list[str], timeout: float) -> None:
        joined_names = ", ".join(names)
        super().__init__(f"MCP 服务关闭超时({timeout:g} 秒):{joined_names}")
        self.timeout = timeout
        self.names = names
|
|
|
|
@dataclass
class MCPInitSummary:
    """Outcome of one ``init_mcp_clients`` run."""

    total: int  # number of active services that were attempted
    success: int  # number of services whose startup did not fail
    failed: list[str]  # names of the services that failed to initialize
|
|
|
|
@dataclass
class _MCPServerRuntime:
    """Everything the manager tracks for one running MCP server."""

    name: str  # server name; also the key in the manager's runtime dict
    client: MCPClient  # connected client returned by _init_mcp_client
    shutdown_event: asyncio.Event  # set to ask the lifecycle task to stop
    lifecycle_task: asyncio.Task[None]  # waits for the event, then tears down
|
|
|
|
class _MCPClientDictView(Mapping[str, MCPClient]):
    """Live, read-only ``name -> MCPClient`` projection of the runtime dict.

    The view keeps a reference to the runtime dict rather than a copy, so it
    always reflects the current set of servers. It defines no mutating
    methods.
    """

    def __init__(self, runtime: dict[str, _MCPServerRuntime]) -> None:
        # Keep the very same dict object so later registrations show up here.
        self._runtime = runtime

    def __getitem__(self, key: str) -> MCPClient:
        entry = self._runtime[key]
        return entry.client

    def __iter__(self):
        # Iterates server names, exactly like iterating the runtime dict.
        return iter(self._runtime)

    def __len__(self) -> int:
        server_count = len(self._runtime)
        return server_count
|
|
|
|
def _resolve_timeout(
    timeout: float | int | str | None = None,
    *,
    env_name: str = MCP_INIT_TIMEOUT_ENV,
    default: float = DEFAULT_MCP_INIT_TIMEOUT_SECONDS,
) -> float:
    """Resolve a timeout with precedence: explicit argument > env value > default.

    Args:
        timeout: Explicit timeout in seconds. ``None`` means "read the
            environment variable ``env_name``" (falling back to ``default``
            when the variable is unset).
        env_name: Environment variable consulted when ``timeout`` is ``None``.
        default: Value returned whenever the candidate is unusable.

    Returns:
        A timeout in ``(0, MAX_MCP_TIMEOUT_SECONDS]``, or ``default`` when the
        candidate cannot be parsed, is NaN, or is not positive. A warning is
        logged whenever the candidate is replaced or clamped.
    """
    # Local import keeps this block self-contained.
    import math

    if timeout is None:
        source = f"环境变量 {env_name}"
        timeout = os.getenv(env_name, str(default))
    else:
        source = "显式参数 timeout"

    try:
        timeout_value = float(timeout)
    except (TypeError, ValueError):
        timeout_value = math.nan

    # float("nan") parses successfully but fails every comparison below, so it
    # would otherwise slip through both range checks and be returned as-is.
    if math.isnan(timeout_value):
        logger.warning(
            f"超时配置({source})={timeout!r} 无效,使用默认值 {default:g} 秒。"
        )
        return default

    if timeout_value <= 0:
        logger.warning(
            f"超时配置({source})={timeout_value:g} 必须大于 0,使用默认值 {default:g} 秒。"
        )
        return default

    if timeout_value > MAX_MCP_TIMEOUT_SECONDS:
        logger.warning(
            f"超时配置({source})={timeout_value:g} 过大,已限制为最大值 "
            f"{MAX_MCP_TIMEOUT_SECONDS:g} 秒,以避免长时间等待。"
        )
        return MAX_MCP_TIMEOUT_SECONDS

    return timeout_value
|
|
|
|
# JSON-schema type names. Not referenced in the visible part of this file;
# presumably consumed by importers (TODO confirm).
SUPPORTED_TYPES = [
    "string",
    "number",
    "object",
    "array",
    "boolean",
]


# Maps Python builtin type names to the JSON-schema type names listed above.
PY_TO_JSON_TYPE = {
    "int": "number",
    "float": "number",
    "bool": "boolean",
    "str": "string",
    "dict": "object",
    "list": "array",
    "tuple": "array",
    "set": "array",
}

# Short alias for FunctionTool, used in annotations throughout this module.
FuncTool = FunctionTool
|
|
|
|
def _prepare_config(config: dict) -> dict:
    """Return the effective single-server config without the ``active`` flag.

    Accepts either a flat server config or the nested
    ``{"mcpServers": {name: cfg}}`` form; for the nested form the first
    server entry is used.

    The result is always a new dict. The previous implementation popped
    ``active`` in place, which (for the nested form) modified the caller's
    inner server dict even when the caller passed a shallow copy.

    Args:
        config: Flat or nested MCP server configuration.

    Returns:
        A new dict holding every key of the selected config except ``active``.
    """
    servers = config.get("mcpServers")
    if servers:
        config = servers[next(iter(servers))]
    return {key: value for key, value in config.items() if key != "active"}
|
|
|
|
async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
    """Quickly probe whether an HTTP-based MCP server is reachable.

    For the ``streamable_http`` transport a JSON-RPC ``initialize`` request is
    POSTed; for any other transport a plain GET is issued. Only an HTTP 200
    response counts as success.

    Args:
        config: Flat or nested server config; must resolve to a dict with a
            ``url`` key. Optional keys read here: ``headers``, ``timeout``
            (seconds, default 10) and ``transport``.

    Returns:
        ``(True, "")`` on HTTP 200, otherwise ``(False, reason)``.
    """
    import aiohttp

    cfg = _prepare_config(config.copy())

    url = cfg["url"]
    extra_headers = cfg.get("headers", {})
    timeout = cfg.get("timeout", 10)

    try:
        async with aiohttp.ClientSession() as session:
            if cfg.get("transport") == "streamable_http":
                initialize_request = {
                    "jsonrpc": "2.0",
                    "method": "initialize",
                    "id": 0,
                    "params": {
                        "protocolVersion": "2024-11-05",
                        "capabilities": {},
                        "clientInfo": {"name": "test-client", "version": "1.2.3"},
                    },
                }
                pending_request = session.post(
                    url,
                    headers={
                        **extra_headers,
                        "Content-Type": "application/json",
                        "Accept": "application/json, text/event-stream",
                    },
                    json=initialize_request,
                    timeout=aiohttp.ClientTimeout(total=timeout),
                )
            else:
                pending_request = session.get(
                    url,
                    headers={
                        **extra_headers,
                        "Accept": "application/json, text/event-stream",
                    },
                    timeout=aiohttp.ClientTimeout(total=timeout),
                )

            # Both transports are judged the same way: 200 or failure.
            async with pending_request as response:
                if response.status != 200:
                    return False, f"HTTP {response.status}: {response.reason}"
                return True, ""

    except asyncio.TimeoutError:
        return False, f"连接超时: {timeout}秒"
    except Exception as error:
        return False, f"{error!s}"
|
|
|
|
| class FunctionToolManager: |
def __init__(self) -> None:
    """Create an empty tool registry and resolve the default MCP timeouts."""
    self.func_list: list[FuncTool] = []
    # Runtime state of every MCP server (the single source of truth).
    self._mcp_server_runtime: dict[str, _MCPServerRuntime] = {}
    # Read-only projections handed out by the public properties.
    self._mcp_server_runtime_view = MappingProxyType(self._mcp_server_runtime)
    self._mcp_client_dict_view = _MCPClientDictView(self._mcp_server_runtime)
    # One-shot guard (plus its lock) for the timeout-mismatch notice.
    self._timeout_mismatch_warned = False
    self._timeout_warn_lock = threading.Lock()
    self._runtime_lock = asyncio.Lock()
    # Names of servers whose startup is currently in flight.
    self._mcp_starting: set[str] = set()
    # Both defaults come from the environment (or the built-in fallback).
    self._init_timeout_default = _resolve_timeout(
        env_name=MCP_INIT_TIMEOUT_ENV,
        default=DEFAULT_MCP_INIT_TIMEOUT_SECONDS,
    )
    self._enable_timeout_default = _resolve_timeout(
        env_name=ENABLE_MCP_TIMEOUT_ENV,
        default=DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS,
    )
    self._warn_on_timeout_mismatch(
        self._init_timeout_default,
        self._enable_timeout_default,
    )
|
|
@property
def mcp_client_dict(self) -> Mapping[str, MCPClient]:
    """Read-only compatibility view for external callers that still read mcp_client_dict.

    Returns the ``_MCPClientDictView`` built in ``__init__``, which maps each
    server name to its client and reflects the runtime dict live.

    Note: Mutating this mapping is unsupported and will raise TypeError.
    """
    return self._mcp_client_dict_view
|
|
@property
def mcp_server_runtime_view(self) -> Mapping[str, _MCPServerRuntime]:
    """Read-only view of MCP runtime metadata for external callers.

    Returns the ``MappingProxyType`` wrapped around the runtime dict in
    ``__init__``.
    """
    return self._mcp_server_runtime_view
|
|
@property
def mcp_server_runtime(self) -> Mapping[str, _MCPServerRuntime]:
    """Backward-compatible read-only view (deprecated). Do not mutate.

    Returns the same object as ``mcp_server_runtime_view``.

    Note: Mutations are not supported and will raise TypeError.
    """
    return self._mcp_server_runtime_view
|
|
def empty(self) -> bool:
    """Return True when no tool is registered."""
    return not self.func_list
|
|
def spec_to_func(
    self,
    name: str,
    func_args: list[dict],
    desc: str,
    handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> FuncTool:
    """Build a ``FuncTool`` from a flat list of argument specs.

    Args:
        name: Tool name.
        func_args: One dict per argument; each must contain ``name``. All
            remaining keys become that argument's schema.
        desc: Tool description.
        handler: Callable invoked when the tool is called.

    Returns:
        A ``FuncTool`` whose ``parameters`` is an object schema keyed by
        argument name.
    """
    properties: dict = {}
    for spec in func_args:
        # Deep-copy so the caller's spec dicts are never modified.
        schema = copy.deepcopy(spec)
        schema.pop("name", None)
        properties[spec["name"]] = schema

    parameters = {"type": "object", "properties": properties}
    return FuncTool(
        name=name,
        parameters=parameters,
        description=desc,
        handler=handler,
    )
|
|
def add_func(
    self,
    name: str,
    func_args: list,
    desc: str,
    handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> None:
    """Register a function-calling tool, replacing any tool with the same name.

    Args:
        name: Tool name.
        func_args: Argument specs, e.g.
            ``[{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]``.
        desc: Tool description.
        handler: Callable that handles invocations of the tool.
    """
    # Drop an earlier registration first so the name stays unique.
    self.remove_func(name)

    new_tool = self.spec_to_func(
        name=name,
        func_args=func_args,
        desc=desc,
        handler=handler,
    )
    self.func_list.append(new_tool)
    logger.info(f"添加函数调用工具: {name}")
|
|
def remove_func(self, name: str) -> None:
    """Remove the first registered tool called ``name``, if there is one."""
    for position, tool in enumerate(self.func_list):
        if tool.name != name:
            continue
        del self.func_list[position]
        return
|
|
def get_func(self, name) -> FuncTool | None:
    """Return the first registered tool called ``name``, or None if absent."""
    matches = (tool for tool in self.func_list if tool.name == name)
    return next(matches, None)
|
|
def get_full_tool_set(self) -> ToolSet:
    """Return a ToolSet built from a shallow copy of every registered tool."""
    return ToolSet(self.func_list.copy())
|
|
| @staticmethod |
| def _log_safe_mcp_debug_config(cfg: dict) -> None: |
| |
| if "command" in cfg: |
| cmd = cfg["command"] |
| executable = str(cmd[0] if isinstance(cmd, (list, tuple)) and cmd else cmd) |
| args_val = cfg.get("args", []) |
| args_count = ( |
| len(args_val) |
| if isinstance(args_val, (list, tuple)) |
| else (0 if args_val is None else 1) |
| ) |
| logger.debug(f" 命令可执行文件: {executable}, 参数数量: {args_count}") |
| return |
|
|
| if "url" in cfg: |
| parsed = urllib.parse.urlparse(str(cfg["url"])) |
| host = parsed.hostname or "" |
| scheme = parsed.scheme or "unknown" |
| try: |
| port = f":{parsed.port}" if parsed.port else "" |
| except ValueError: |
| port = "" |
| logger.debug(f" 主机: {scheme}://{host}{port}") |
|
|
async def init_mcp_clients(
    self, raise_on_all_failed: bool = False
) -> MCPInitSummary:
    """Read mcp_server.json from the AstrBot data directory and start every active MCP server.

    Expected file format::

        {
            "mcpServers": {
                "weather": {
                    "command": "uv",
                    "args": [
                        "--directory",
                        "/ABSOLUTE/PATH/TO/PARENT/FOLDER/weather",
                        "run",
                        "weather.py"
                    ]
                }
            }
            ...
        }

    If the file does not exist, a default one is created and an empty
    summary is returned. Servers whose config has ``"active": false`` are
    skipped. All active servers are started concurrently.

    Timeout behavior:
        - Initialization uses ASTRBOT_MCP_INIT_TIMEOUT or its default.
        - Dynamic enabling uses ASTRBOT_MCP_ENABLE_TIMEOUT (independent of
          the initialization timeout).

    Args:
        raise_on_all_failed: When true, raise instead of only logging if
            every active server failed.

    Returns:
        An ``MCPInitSummary`` with the attempted/succeeded counts and the
        names of the failed servers.

    Raises:
        MCPAllServicesFailedError: If ``raise_on_all_failed`` is true and no
            active server could be started.
    """
    data_dir = get_astrbot_data_path()

    mcp_json_file = os.path.join(data_dir, "mcp_server.json")
    if not os.path.exists(mcp_json_file):
        # Same bootstrap as load_mcp_config: the data directory may not
        # exist yet on a fresh installation.
        os.makedirs(data_dir, exist_ok=True)
        with open(mcp_json_file, "w", encoding="utf-8") as f:
            json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
        logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}")
        return MCPInitSummary(total=0, success=0, failed=[])

    with open(mcp_json_file, encoding="utf-8") as f:
        # A file without (or with a null) "mcpServers" key means "no servers".
        mcp_server_json_obj: dict[str, dict] = json.load(f).get("mcpServers") or {}

    init_timeout = self._init_timeout_default
    timeout_display = f"{init_timeout:g}"

    active_configs: list[tuple[str, dict, asyncio.Event]] = [
        (name, cfg, asyncio.Event())
        for name, cfg in mcp_server_json_obj.items()
        if cfg.get("active", True)
    ]

    if not active_configs:
        return MCPInitSummary(total=0, success=0, failed=[])

    logger.info(f"等待 {len(active_configs)} 个 MCP 服务初始化...")

    init_tasks = [
        asyncio.create_task(
            self._start_mcp_server(
                name=name,
                cfg=cfg,
                shutdown_event=shutdown_event,
                timeout=init_timeout,
            ),
            name=f"mcp-init:{name}",
        )
        for (name, cfg, shutdown_event) in active_configs
    ]
    results = await asyncio.gather(*init_tasks, return_exceptions=True)

    success_count = 0
    failed_services: list[str] = []

    for (name, cfg, _), result in zip(active_configs, results, strict=True):
        # BaseException, not Exception: a cancelled startup task yields
        # asyncio.CancelledError, which must not be counted as a success.
        if not isinstance(result, BaseException):
            success_count += 1
            continue

        if isinstance(result, MCPInitTimeoutError):
            logger.error(
                f"Connected to MCP server {name} timeout ({timeout_display} seconds)"
            )
        else:
            logger.error(f"Failed to initialize MCP server {name}: {result}")
        self._log_safe_mcp_debug_config(cfg)
        failed_services.append(name)
        async with self._runtime_lock:
            self._mcp_server_runtime.pop(name, None)

    if failed_services:
        logger.warning(
            f"The following MCP services failed to initialize: {', '.join(failed_services)}. "
            f"Please check the mcp_server.json file and server availability."
        )

    summary = MCPInitSummary(
        total=len(active_configs), success=success_count, failed=failed_services
    )
    logger.info(
        f"MCP services initialization completed: {summary.success}/{summary.total} successful, {len(summary.failed)} failed."
    )
    if summary.total > 0 and summary.success == 0:
        msg = "All MCP services failed to initialize, please check the mcp_server.json and server availability."
        if raise_on_all_failed:
            raise MCPAllServicesFailedError(msg)
        logger.error(msg)
    return summary
|
|
async def _start_mcp_server(
    self,
    name: str,
    cfg: dict,
    *,
    shutdown_event: asyncio.Event | None = None,
    timeout: float,
) -> None:
    """Initialize MCP server with timeout and register task/event together.

    This method is idempotent. If the server is already running, the existing
    runtime is kept and the new config is ignored.

    Args:
        name: Server name; used as the key in the runtime dict.
        cfg: Server configuration passed on to ``_init_mcp_client``.
        shutdown_event: Event whose setting stops the server. A fresh event
            is created when omitted.
        timeout: Maximum number of seconds to wait for initialization.

    Raises:
        MCPInitTimeoutError: If initialization does not finish in time.
        Exception: Any error raised by ``_init_mcp_client`` is logged and
            re-raised.
    """
    # Claim the name under the lock so concurrent start requests for the
    # same server cannot both proceed.
    async with self._runtime_lock:
        if name in self._mcp_server_runtime or name in self._mcp_starting:
            logger.warning(
                f"Connected to MCP server {name}, ignoring this startup request (timeout={timeout:g})."
            )
            self._log_safe_mcp_debug_config(cfg)
            return
        self._mcp_starting.add(name)

    if shutdown_event is None:
        shutdown_event = asyncio.Event()

    mcp_client: MCPClient | None = None
    try:
        mcp_client = await asyncio.wait_for(
            self._init_mcp_client(name, cfg),
            timeout=timeout,
        )
    except asyncio.TimeoutError as exc:
        raise MCPInitTimeoutError(
            f"Connected to MCP server {name} timeout ({timeout:g} seconds)"
        ) from exc
    except Exception:
        logger.error(f"Failed to initialize MCP client {name}", exc_info=True)
        raise
    finally:
        # mcp_client is still None on every failure path (including
        # cancellation), so release the "starting" claim here.
        if mcp_client is None:
            async with self._runtime_lock:
                self._mcp_starting.discard(name)

    async def lifecycle() -> None:
        # Wait until shutdown is requested; tear the client down whether
        # the wait ends normally or through cancellation.
        try:
            await shutdown_event.wait()
            logger.info(f"Received shutdown signal for MCP client {name}")
        except asyncio.CancelledError:
            logger.debug(f"MCP client {name} task was cancelled")
            raise
        finally:
            await self._terminate_mcp_client(name)

    lifecycle_task = asyncio.create_task(lifecycle(), name=f"mcp-client:{name}")
    # Publish the runtime and drop the "starting" claim in one locked step.
    async with self._runtime_lock:
        self._mcp_server_runtime[name] = _MCPServerRuntime(
            name=name,
            client=mcp_client,
            shutdown_event=shutdown_event,
            lifecycle_task=lifecycle_task,
        )
        self._mcp_starting.discard(name)
|
|
async def _shutdown_runtimes(
    self,
    runtimes: list[_MCPServerRuntime],
    timeout: float,
    *,
    strict: bool = True,
) -> list[str]:
    """Shutdown runtimes and wait for lifecycle tasks to complete.

    Every runtime's shutdown event is set, then the still-running lifecycle
    tasks are awaited for at most ``timeout`` seconds. Tasks that have not
    finished by then are cancelled and awaited.

    ``asyncio.wait`` is used instead of ``wait_for(gather(...))`` because
    ``wait_for`` cancels and awaits the inner tasks *before* raising
    ``TimeoutError``; by the time the handler ran, every task was already
    done and the list of timed-out servers always came out empty.

    Args:
        runtimes: Runtimes to shut down.
        timeout: Seconds to wait for a graceful shutdown.
        strict: Raise on timeout when true; otherwise log a warning.

    Returns:
        Names of servers that did not finish within ``timeout`` (empty when
        everything shut down in time or nothing was running).

    Raises:
        MCPShutdownTimeoutError: On timeout when ``strict`` is true.
    """
    running = [
        runtime for runtime in runtimes if not runtime.lifecycle_task.done()
    ]
    if not running:
        return []

    for runtime in runtimes:
        runtime.shutdown_event.set()

    lifecycle_tasks = [runtime.lifecycle_task for runtime in running]
    # asyncio.wait leaves unfinished tasks untouched, so `pending` is the
    # exact set of servers that exceeded the timeout.
    done, pending = await asyncio.wait(lifecycle_tasks, timeout=timeout)

    if pending:
        pending_names = [
            runtime.name for runtime in running if runtime.lifecycle_task in pending
        ]
        for task in pending:
            task.cancel()
        # Let the cancelled tasks run their cleanup before reporting.
        await asyncio.gather(*pending, return_exceptions=True)
        if strict:
            raise MCPShutdownTimeoutError(pending_names, timeout)
        logger.warning(
            "MCP server shutdown timeout (%s seconds), the following servers were not fully closed: %s",
            f"{timeout:g}",
            ", ".join(pending_names),
        )
        return pending_names

    for task in done:
        if task.cancelled():
            logger.debug("MCP lifecycle task was cancelled during shutdown.")
            continue
        error = task.exception()
        if error is not None:
            logger.error(
                "MCP lifecycle task failed during shutdown.",
                exc_info=(type(error), error, error.__traceback__),
            )
    return []
|
|
async def _cleanup_mcp_client_safely(
    self, mcp_client: MCPClient, name: str
) -> None:
    """Clean up one MCP client, logging (not raising) ordinary failures.

    Keeps a cleanup error from interrupting the calling flow.

    Args:
        mcp_client: Client whose ``cleanup()`` coroutine is awaited.
        name: Server name, used only in the error log.
    """
    try:
        await mcp_client.cleanup()
    except Exception as error:
        message = f"Failed to cleanup MCP client resources {name}: {error}"
        logger.error(message)
|
|
async def _init_mcp_client(self, name: str, config: dict) -> MCPClient:
    """Connect to one MCP server and register its tools.

    Args:
        name: Server name; stored on the client and on each registered tool.
        config: Server configuration forwarded to ``connect_to_server``.

    Returns:
        The connected client.

    Raises:
        Exception: Whatever connecting or listing tools raised (including
            cancellation); the client is cleaned up before re-raising.
    """
    client = MCPClient()
    client.name = name
    try:
        await client.connect_to_server(config, name)
        listing = await client.list_tools_and_save()
    except (asyncio.CancelledError, Exception):
        # Failure and cancellation get the same treatment: release the
        # half-open client, then propagate.
        await self._cleanup_mcp_client_safely(client, name)
        raise
    logger.debug(f"MCP server {name} list tools response: {listing}")
    tool_names = [tool.name for tool in listing.tools]

    def belongs_to_this_server(entry) -> bool:
        return isinstance(entry, MCPTool) and entry.mcp_server_name == name

    # Drop tools left over from an earlier connection to the same server.
    self.func_list = [
        entry for entry in self.func_list if not belongs_to_this_server(entry)
    ]

    # Register the tools the server currently offers.
    for remote_tool in client.tools:
        self.func_list.append(
            MCPTool(
                mcp_tool=remote_tool,
                mcp_client=client,
                mcp_server_name=name,
            )
        )

    logger.info(f"Connected to MCP server {name}, Tools: {tool_names}")
    return client
|
|
async def _terminate_mcp_client(self, name: str) -> None:
    """Close one MCP client and remove every trace of it from the manager.

    If a runtime is registered under ``name`` its client is cleaned up and
    the runtime entry is removed. In every case the server's tools are
    dropped from ``func_list`` and ``name`` is cleared from the "starting"
    set.

    Args:
        name: Name of the server to terminate.
    """
    async with self._runtime_lock:
        runtime = self._mcp_server_runtime.get(name)

    if runtime:
        # Close the underlying connection before unregistering anything.
        await self._cleanup_mcp_client_safely(runtime.client, name)

    # Remove the tools that belong to this server.
    self.func_list = [
        tool
        for tool in self.func_list
        if not (isinstance(tool, MCPTool) and tool.mcp_server_name == name)
    ]

    async with self._runtime_lock:
        if runtime:
            self._mcp_server_runtime.pop(name, None)
        self._mcp_starting.discard(name)

    if runtime:
        logger.info(f"Disconnected from MCP server {name}")
|
|
@staticmethod
async def test_mcp_server_connection(config: dict) -> list[str]:
    """Connect to an MCP server once and report the tools it offers.

    URL-based configs are first probed with a quick HTTP request so that an
    unreachable server fails fast. A throwaway client is then connected and
    always cleaned up afterwards.

    The config is logged through ``_log_safe_mcp_debug_config`` rather than
    verbatim, because headers/env entries may carry credentials.

    Args:
        config: Server configuration to test.

    Returns:
        Names of the tools reported by the server.

    Raises:
        Exception: If the quick probe fails (message holds the reason) or if
            connecting / listing tools fails.
    """
    if "url" in config:
        success, error_msg = await _quick_test_mcp_connection(config)
        if not success:
            raise Exception(error_msg)

    mcp_client = MCPClient()
    try:
        logger.debug("testing MCP server connection")
        FunctionToolManager._log_safe_mcp_debug_config(config)
        await mcp_client.connect_to_server(config, "test")
        tools_res = await mcp_client.list_tools_and_save()
        tool_names = [tool.name for tool in tools_res.tools]
    finally:
        logger.debug("Cleaning up MCP client after testing connection.")
        await mcp_client.cleanup()
    return tool_names
|
|
async def enable_mcp_server(
    self,
    name: str,
    config: dict,
    shutdown_event: asyncio.Event | None = None,
    timeout: float | int | str | None = None,
) -> None:
    """Enable a new MCP server and initialize it.

    Args:
        name: The name of the MCP server.
        config: Configuration for the MCP server.
        shutdown_event: Event to signal when the MCP client should shut down.
        timeout: Timeout in seconds for initialization. When omitted, the
            default resolved from ASTRBOT_MCP_ENABLE_TIMEOUT at construction
            time is used (separate from the init timeout).

    Raises:
        MCPInitTimeoutError: If initialization does not complete within timeout.
        Exception: If there is an error during initialization.
    """
    # An explicit value is validated/clamped; None uses the cached default.
    effective_timeout = (
        self._enable_timeout_default
        if timeout is None
        else _resolve_timeout(
            timeout=timeout,
            env_name=ENABLE_MCP_TIMEOUT_ENV,
            default=self._enable_timeout_default,
        )
    )
    await self._start_mcp_server(
        name=name,
        cfg=config,
        shutdown_event=shutdown_event,
        timeout=effective_timeout,
    )
|
|
async def disable_mcp_server(
    self,
    name: str | None = None,
    timeout: float = 10,
) -> None:
    """Disable an MCP server by its name.

    Args:
        name: The name of the MCP server to disable. If None, ALL MCP
            servers will be disabled. Any other value (including an empty
            string) is treated as a server name; unknown names are a no-op.
        timeout: Seconds to wait for the server(s) to shut down.

    Raises:
        MCPShutdownTimeoutError: If shutdown does not complete within timeout.
            Only raised when disabling a specific server (name is not None).
    """
    # Compare against None explicitly: a truthiness test would make an empty
    # name fall through to the "disable everything" branch.
    if name is not None:
        async with self._runtime_lock:
            runtime = self._mcp_server_runtime.get(name)
        if runtime is None:
            return

        await self._shutdown_runtimes([runtime], timeout, strict=True)
        return

    # Snapshot under the lock; the shutdown itself runs outside it.
    async with self._runtime_lock:
        runtimes = list(self._mcp_server_runtime.values())
    await self._shutdown_runtimes(runtimes, timeout, strict=False)
|
|
def _warn_on_timeout_mismatch(
    self,
    init_timeout: float,
    enable_timeout: float,
) -> None:
    """Log, at most once per manager, that the two default timeouts differ.

    Args:
        init_timeout: Resolved initialization timeout in seconds.
        enable_timeout: Resolved dynamic-enable timeout in seconds.
    """
    if init_timeout != enable_timeout:
        # The flag is checked and set under the lock so the notice is
        # emitted a single time.
        with self._timeout_warn_lock:
            if not self._timeout_mismatch_warned:
                logger.info(
                    "检测到 MCP 初始化超时与动态启用超时配置不同:"
                    "初始化使用 %s 秒,动态启用使用 %s 秒。如需一致,请设置相同值。",
                    f"{init_timeout:g}",
                    f"{enable_timeout:g}",
                )
                self._timeout_mismatch_warned = True
|
|
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
    """Return OpenAI-API-style descriptions of the **active** tools.

    Args:
        omit_empty_parameter_field: Passed through to
            ``ToolSet.openai_schema``.
    """
    active_tools = ToolSet([tool for tool in self.func_list if tool.active])
    return active_tools.openai_schema(
        omit_empty_parameter_field=omit_empty_parameter_field,
    )
|
|
def get_func_desc_anthropic_style(self) -> list:
    """Return Anthropic-API-style descriptions of the **active** tools."""
    active_tools = ToolSet([tool for tool in self.func_list if tool.active])
    return active_tools.anthropic_schema()
|
|
def get_func_desc_google_genai_style(self) -> dict:
    """Return Google-GenAI-API-style descriptions of the **active** tools."""
    active_tools = ToolSet([tool for tool in self.func_list if tool.active])
    return active_tools.google_schema()
|
|
def deactivate_llm_tool(self, name: str) -> bool:
    """Deactivate a registered function-calling tool.

    The tool's ``active`` flag is cleared and its name is added to the
    globally persisted ``inactivated_llm_tools`` list.

    Returns:
        True if the tool was found, False otherwise.
    """
    tool = self.get_func(name)
    if tool is None:
        return False

    tool.active = False

    disabled_names: list = sp.get(
        "inactivated_llm_tools",
        [],
        scope="global",
        scope_id="global",
    )
    # Only write back when the stored list actually changes.
    if name not in disabled_names:
        disabled_names.append(name)
        sp.put(
            "inactivated_llm_tools",
            disabled_names,
            scope="global",
            scope_id="global",
        )

    return True
|
|
| |
def activate_llm_tool(self, name: str, star_map: dict) -> bool:
    """Activate a registered function-calling tool.

    The tool's ``active`` flag is set and its name is removed from the
    globally persisted ``inactivated_llm_tools`` list.

    Args:
        name: Name of the tool to activate.
        star_map: Mapping from handler module path to plugin metadata; the
            entries are read for their ``activated`` and ``name`` attributes.

    Returns:
        True if the tool was found, False otherwise.

    Raises:
        ValueError: If the plugin that owns the tool is itself disabled.
    """
    tool = self.get_func(name)
    if tool is None:
        return False

    module_path = tool.handler_module_path
    if module_path in star_map:
        owner = star_map[module_path]
        if not owner.activated:
            raise ValueError(
                f"此函数调用工具所属的插件 {owner.name} 已被禁用,请先在管理面板启用再激活此工具。",
            )

    tool.active = True

    disabled_names: list = sp.get(
        "inactivated_llm_tools",
        [],
        scope="global",
        scope_id="global",
    )
    # Only write back when the stored list actually changes.
    if name in disabled_names:
        disabled_names.remove(name)
        sp.put(
            "inactivated_llm_tools",
            disabled_names,
            scope="global",
            scope_id="global",
        )

    return True
|
|
@property
def mcp_config_path(self):
    """Path of ``mcp_server.json`` inside the AstrBot data directory."""
    return os.path.join(get_astrbot_data_path(), "mcp_server.json")
|
|
def load_mcp_config(self):
    """Load the MCP configuration from ``mcp_config_path``.

    If the file does not exist it is created (together with its directory)
    with the default content. If it cannot be read or parsed, the error is
    logged and the default configuration is returned instead.

    Returns:
        The parsed configuration, or a fresh deep copy of
        ``DEFAULT_MCP_CONFIG``. A copy is returned because callers add
        entries to the result in place; handing out the module-level
        default itself would let those edits leak into every later
        "default" configuration.
    """
    if not os.path.exists(self.mcp_config_path):
        # Bootstrap a default config file on first use.
        os.makedirs(os.path.dirname(self.mcp_config_path), exist_ok=True)
        with open(self.mcp_config_path, "w", encoding="utf-8") as f:
            json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
        return copy.deepcopy(DEFAULT_MCP_CONFIG)

    try:
        with open(self.mcp_config_path, encoding="utf-8") as f:
            return json.load(f)
    except Exception as e:
        logger.error(f"加载 MCP 配置失败: {e}")
        return copy.deepcopy(DEFAULT_MCP_CONFIG)
|
|
def save_mcp_config(self, config: dict) -> bool:
    """Write ``config`` to ``mcp_config_path`` as indented UTF-8 JSON.

    Returns:
        True on success; False (after logging the error) on any failure.
    """
    try:
        with open(self.mcp_config_path, "w", encoding="utf-8") as config_file:
            json.dump(config, config_file, ensure_ascii=False, indent=4)
    except Exception as error:
        logger.error(f"保存 MCP 配置失败: {error}")
        return False
    return True
|
|
async def sync_modelscope_mcp_servers(self, access_token: str) -> None:
    """Sync MCP server configurations from the ModelScope platform.

    Fetches the account's operational MCP servers, writes an SSE entry for
    each server that exposes a URL into the local config, saves the config
    and enables exactly those servers.

    Servers without an operational URL are skipped entirely; previously the
    enable step iterated the full server list, so a skipped server either
    raised ``KeyError`` or was enabled with a stale local entry.

    Args:
        access_token: ModelScope API token (surrounding whitespace is
            stripped).

    Raises:
        Exception: On network errors, non-200 responses or any failure while
            syncing/enabling; the original error is chained as the cause.
    """
    base_url = "https://www.modelscope.cn/openapi/v1"
    url = f"{base_url}/mcp/servers/operational"
    headers = {
        "Authorization": f"Bearer {access_token.strip()}",
        "Content-Type": "application/json",
    }

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url, headers=headers) as response:
                if response.status != 200:
                    raise Exception(
                        f"ModelScope API 请求失败: HTTP {response.status}",
                    )
                data = await response.json()

        mcp_server_list = data.get("data", {}).get("mcp_server_list", [])
        # Work on a private copy so nothing shared is modified in place.
        local_mcp_config = copy.deepcopy(self.load_mcp_config())
        servers_config = local_mcp_config.setdefault("mcpServers", {})

        synced_names: list[str] = []
        for server in mcp_server_list:
            server_name = server["name"]
            operational_urls = server.get("operational_urls", [])
            if not operational_urls:
                continue
            server_url = operational_urls[0].get("url")
            if not server_url:
                continue
            # Add or replace the local entry for this server.
            servers_config[server_name] = {
                "url": server_url,
                "transport": "sse",
                "active": True,
                "provider": "modelscope",
            }
            synced_names.append(server_name)

        if not synced_names:
            logger.warning("没有找到可用的 ModelScope MCP 服务器")
            return

        self.save_mcp_config(local_mcp_config)
        # Enable only the servers that were actually written above.
        await asyncio.gather(
            *(
                self.enable_mcp_server(
                    name=server_name,
                    config=servers_config[server_name],
                )
                for server_name in synced_names
            )
        )
        logger.info(
            f"从 ModelScope 同步了 {len(synced_names)} 个 MCP 服务器",
        )

    except aiohttp.ClientError as e:
        raise Exception(f"网络连接错误: {e!s}") from e
    except Exception as e:
        raise Exception(f"同步 ModelScope MCP 服务器时发生错误: {e!s}") from e
|
|
def __str__(self) -> str:
    """Return the string form of the registered tool list."""
    return f"{self.func_list!s}"
|
|
def __repr__(self) -> str:
    """Return the same text as ``__str__``: the registered tool list."""
    return f"{self.func_list!s}"
|
|
|
|
| |
# Alias of FunctionToolManager; presumably kept for callers that import the
# older name (TODO confirm).
FuncCall = FunctionToolManager
|
|