| import asyncio |
| import logging |
| from contextlib import AsyncExitStack |
| from datetime import timedelta |
| from typing import Generic |
|
|
| from tenacity import ( |
| before_sleep_log, |
| retry, |
| retry_if_exception_type, |
| stop_after_attempt, |
| wait_exponential, |
| ) |
|
|
| from astrbot import logger |
| from astrbot.core.agent.run_context import ContextWrapper |
| from astrbot.core.utils.log_pipe import LogPipe |
|
|
| from .run_context import TContext |
| from .tool import FunctionTool |
|
|
| try: |
| import anyio |
| import mcp |
| from mcp.client.sse import sse_client |
| except (ModuleNotFoundError, ImportError): |
| logger.warning( |
| "Warning: Missing 'mcp' dependency, MCP services will be unavailable." |
| ) |
|
|
| try: |
| from mcp.client.streamable_http import streamablehttp_client |
| except (ModuleNotFoundError, ImportError): |
| logger.warning( |
| "Warning: Missing 'mcp' dependency or MCP library version too old, Streamable HTTP connection unavailable.", |
| ) |
|
|
|
|
| def _prepare_config(config: dict) -> dict: |
| """Prepare configuration, handle nested format""" |
| if config.get("mcpServers"): |
| first_key = next(iter(config["mcpServers"])) |
| config = config["mcpServers"][first_key] |
| config.pop("active", None) |
| return config |
|
|
|
|
| async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]: |
| """Quick test MCP server connectivity""" |
| import aiohttp |
|
|
| cfg = _prepare_config(config.copy()) |
|
|
| url = cfg["url"] |
| headers = cfg.get("headers", {}) |
| timeout = cfg.get("timeout", 10) |
|
|
| try: |
| if "transport" in cfg: |
| transport_type = cfg["transport"] |
| elif "type" in cfg: |
| transport_type = cfg["type"] |
| else: |
| raise Exception("MCP connection config missing transport or type field") |
|
|
| async with aiohttp.ClientSession() as session: |
| if transport_type == "streamable_http": |
| test_payload = { |
| "jsonrpc": "2.0", |
| "method": "initialize", |
| "id": 0, |
| "params": { |
| "protocolVersion": "2024-11-05", |
| "capabilities": {}, |
| "clientInfo": {"name": "test-client", "version": "1.2.3"}, |
| }, |
| } |
| async with session.post( |
| url, |
| headers={ |
| **headers, |
| "Content-Type": "application/json", |
| "Accept": "application/json, text/event-stream", |
| }, |
| json=test_payload, |
| timeout=aiohttp.ClientTimeout(total=timeout), |
| ) as response: |
| if response.status == 200: |
| return True, "" |
| return False, f"HTTP {response.status}: {response.reason}" |
| else: |
| async with session.get( |
| url, |
| headers={ |
| **headers, |
| "Accept": "application/json, text/event-stream", |
| }, |
| timeout=aiohttp.ClientTimeout(total=timeout), |
| ) as response: |
| if response.status == 200: |
| return True, "" |
| return False, f"HTTP {response.status}: {response.reason}" |
|
|
| except asyncio.TimeoutError: |
| return False, f"Connection timeout: {timeout} seconds" |
| except Exception as e: |
| return False, f"{e!s}" |
|
|
|
|
| class MCPClient: |
| def __init__(self) -> None: |
| |
| self.session: mcp.ClientSession | None = None |
| self.exit_stack = AsyncExitStack() |
| self._old_exit_stacks: list[AsyncExitStack] = [] |
|
|
| self.name: str | None = None |
| self.active: bool = True |
| self.tools: list[mcp.Tool] = [] |
| self.server_errlogs: list[str] = [] |
| self.running_event = asyncio.Event() |
|
|
| |
| self._mcp_server_config: dict | None = None |
| self._server_name: str | None = None |
| self._reconnect_lock = asyncio.Lock() |
| self._reconnecting: bool = False |
|
|
| async def connect_to_server(self, mcp_server_config: dict, name: str) -> None: |
| """Connect to MCP server |
| |
| If `url` parameter exists: |
| 1. When transport is specified as `streamable_http`, use Streamable HTTP connection. |
| 2. When transport is specified as `sse`, use SSE connection. |
| 3. If not specified, default to SSE connection to MCP service. |
| |
| Args: |
| mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server |
| |
| """ |
| |
| self._mcp_server_config = mcp_server_config |
| self._server_name = name |
|
|
| cfg = _prepare_config(mcp_server_config.copy()) |
|
|
| def logging_callback( |
| msg: str | mcp.types.LoggingMessageNotificationParams, |
| ) -> None: |
| |
| if isinstance(msg, mcp.types.LoggingMessageNotificationParams): |
| if msg.level in ("warning", "error", "critical", "alert", "emergency"): |
| log_msg = f"[{msg.level.upper()}] {str(msg.data)}" |
| self.server_errlogs.append(log_msg) |
|
|
| if "url" in cfg: |
| success, error_msg = await _quick_test_mcp_connection(cfg) |
| if not success: |
| raise Exception(error_msg) |
|
|
| if "transport" in cfg: |
| transport_type = cfg["transport"] |
| elif "type" in cfg: |
| transport_type = cfg["type"] |
| else: |
| raise Exception("MCP connection config missing transport or type field") |
|
|
| if transport_type != "streamable_http": |
| |
| self._streams_context = sse_client( |
| url=cfg["url"], |
| headers=cfg.get("headers", {}), |
| timeout=cfg.get("timeout", 5), |
| sse_read_timeout=cfg.get("sse_read_timeout", 60 * 5), |
| ) |
| streams = await self.exit_stack.enter_async_context( |
| self._streams_context, |
| ) |
|
|
| |
| read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60)) |
| self.session = await self.exit_stack.enter_async_context( |
| mcp.ClientSession( |
| *streams, |
| read_timeout_seconds=read_timeout, |
| logging_callback=logging_callback, |
| ), |
| ) |
| else: |
| timeout = timedelta(seconds=cfg.get("timeout", 30)) |
| sse_read_timeout = timedelta( |
| seconds=cfg.get("sse_read_timeout", 60 * 5), |
| ) |
| self._streams_context = streamablehttp_client( |
| url=cfg["url"], |
| headers=cfg.get("headers", {}), |
| timeout=timeout, |
| sse_read_timeout=sse_read_timeout, |
| terminate_on_close=cfg.get("terminate_on_close", True), |
| ) |
| read_s, write_s, _ = await self.exit_stack.enter_async_context( |
| self._streams_context, |
| ) |
|
|
| |
| read_timeout = timedelta(seconds=cfg.get("session_read_timeout", 60)) |
| self.session = await self.exit_stack.enter_async_context( |
| mcp.ClientSession( |
| read_stream=read_s, |
| write_stream=write_s, |
| read_timeout_seconds=read_timeout, |
| logging_callback=logging_callback, |
| ), |
| ) |
|
|
| else: |
| server_params = mcp.StdioServerParameters( |
| **cfg, |
| ) |
|
|
| def callback(msg: str | mcp.types.LoggingMessageNotificationParams) -> None: |
| |
| if isinstance(msg, mcp.types.LoggingMessageNotificationParams): |
| if msg.level in ( |
| "warning", |
| "error", |
| "critical", |
| "alert", |
| "emergency", |
| ): |
| log_msg = f"[{msg.level.upper()}] {str(msg.data)}" |
| self.server_errlogs.append(log_msg) |
|
|
| stdio_transport = await self.exit_stack.enter_async_context( |
| mcp.stdio_client( |
| server_params, |
| errlog=LogPipe( |
| level=logging.INFO, |
| logger=logger, |
| identifier=f"MCPServer-{name}", |
| callback=callback, |
| ), |
| ), |
| ) |
|
|
| |
| self.session = await self.exit_stack.enter_async_context( |
| mcp.ClientSession(*stdio_transport), |
| ) |
| await self.session.initialize() |
|
|
| async def list_tools_and_save(self) -> mcp.ListToolsResult: |
| """List all tools from the server and save them to self.tools""" |
| if not self.session: |
| raise Exception("MCP Client is not initialized") |
| response = await self.session.list_tools() |
| self.tools = response.tools |
| return response |
|
|
| async def _reconnect(self) -> None: |
| """Reconnect to the MCP server using the stored configuration. |
| |
| Uses asyncio.Lock to ensure thread-safe reconnection in concurrent environments. |
| |
| Raises: |
| Exception: raised when reconnection fails |
| """ |
| async with self._reconnect_lock: |
| |
| if self._reconnecting: |
| logger.debug( |
| f"MCP Client {self._server_name} is already reconnecting, skipping" |
| ) |
| return |
|
|
| if not self._mcp_server_config or not self._server_name: |
| raise Exception("Cannot reconnect: missing connection configuration") |
|
|
| self._reconnecting = True |
| try: |
| logger.info( |
| f"Attempting to reconnect to MCP server {self._server_name}..." |
| ) |
|
|
| |
| if self.exit_stack: |
| self._old_exit_stacks.append(self.exit_stack) |
|
|
| |
| self.session = None |
|
|
| |
| self.exit_stack = AsyncExitStack() |
|
|
| |
| await self.connect_to_server(self._mcp_server_config, self._server_name) |
| await self.list_tools_and_save() |
|
|
| logger.info( |
| f"Successfully reconnected to MCP server {self._server_name}" |
| ) |
| except Exception as e: |
| logger.error( |
| f"Failed to reconnect to MCP server {self._server_name}: {e}" |
| ) |
| raise |
| finally: |
| self._reconnecting = False |
|
|
| async def call_tool_with_reconnect( |
| self, |
| tool_name: str, |
| arguments: dict, |
| read_timeout_seconds: timedelta, |
| ) -> mcp.types.CallToolResult: |
| """Call MCP tool with automatic reconnection on failure, max 2 retries. |
| |
| Args: |
| tool_name: tool name |
| arguments: tool arguments |
| read_timeout_seconds: read timeout |
| |
| Returns: |
| MCP tool call result |
| |
| Raises: |
| ValueError: MCP session is not available |
| anyio.ClosedResourceError: raised after reconnection failure |
| """ |
|
|
| @retry( |
| retry=retry_if_exception_type(anyio.ClosedResourceError), |
| stop=stop_after_attempt(2), |
| wait=wait_exponential(multiplier=1, min=1, max=3), |
| before_sleep=before_sleep_log(logger, logging.WARNING), |
| reraise=True, |
| ) |
| async def _call_with_retry(): |
| if not self.session: |
| raise ValueError("MCP session is not available for MCP function tools.") |
|
|
| try: |
| return await self.session.call_tool( |
| name=tool_name, |
| arguments=arguments, |
| read_timeout_seconds=read_timeout_seconds, |
| ) |
| except anyio.ClosedResourceError: |
| logger.warning( |
| f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..." |
| ) |
| |
| await self._reconnect() |
| |
| raise |
|
|
| return await _call_with_retry() |
|
|
| async def cleanup(self) -> None: |
| """Clean up resources including old exit stacks from reconnections""" |
| |
| try: |
| await self.exit_stack.aclose() |
| except Exception as e: |
| logger.debug(f"Error closing current exit stack: {e}") |
|
|
| |
| |
| |
| self._old_exit_stacks.clear() |
|
|
| |
| self.running_event.set() |
|
|
|
|
| class MCPTool(FunctionTool, Generic[TContext]): |
| """A function tool that calls an MCP service.""" |
|
|
| def __init__( |
| self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs |
| ) -> None: |
| super().__init__( |
| name=mcp_tool.name, |
| description=mcp_tool.description or "", |
| parameters=mcp_tool.inputSchema, |
| ) |
| self.mcp_tool = mcp_tool |
| self.mcp_client = mcp_client |
| self.mcp_server_name = mcp_server_name |
|
|
| async def call( |
| self, context: ContextWrapper[TContext], **kwargs |
| ) -> mcp.types.CallToolResult: |
| return await self.mcp_client.call_tool_with_reconnect( |
| tool_name=self.mcp_tool.name, |
| arguments=kwargs, |
| read_timeout_seconds=timedelta(seconds=context.tool_call_timeout), |
| ) |
|
|