| import { EventEmitter } from 'events'; |
| import { logger } from '@librechat/data-schemas'; |
| import { fetch as undiciFetch, Agent } from 'undici'; |
| import { |
| StdioClientTransport, |
| getDefaultEnvironment, |
| } from '@modelcontextprotocol/sdk/client/stdio.js'; |
| import { Client } from '@modelcontextprotocol/sdk/client/index.js'; |
| import { SSEClientTransport } from '@modelcontextprotocol/sdk/client/sse.js'; |
| import { WebSocketClientTransport } from '@modelcontextprotocol/sdk/client/websocket.js'; |
| import { ResourceListChangedNotificationSchema } from '@modelcontextprotocol/sdk/types.js'; |
| import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; |
| import type { Transport } from '@modelcontextprotocol/sdk/shared/transport.js'; |
| import type { JSONRPCMessage } from '@modelcontextprotocol/sdk/types.js'; |
| import type { |
| RequestInit as UndiciRequestInit, |
| RequestInfo as UndiciRequestInfo, |
| Response as UndiciResponse, |
| } from 'undici'; |
| import type { MCPOAuthTokens } from './oauth/types'; |
| import { withTimeout } from '~/utils/promise'; |
| import type * as t from './types'; |
| import { sanitizeUrlForLogging } from './utils'; |
| import { mcpConfig } from './mcpConfig'; |
|
|
/**
 * Shape of the `fetch` option given to the SSE and streamable-http transports.
 * The undici-based function from `createFetchFunction` is cast to this type.
 */
type FetchLike = {
  (url: string | URL, init?: RequestInit): Promise<Response>;
};
|
|
/**
 * Type guard for stdio server configs: true when the options object has a `command` key
 * (own or inherited, like the `in` operator).
 */
function isStdioOptions(options: t.MCPOptions): options is t.StdioOptions {
  return Reflect.has(options, 'command');
}
|
|
/**
 * Type guard for WebSocket server configs: a `url` whose scheme is `ws:` or `wss:`.
 * A present but unparsable `url` makes `new URL` throw.
 */
function isWebSocketOptions(options: t.MCPOptions): options is t.WebSocketOptions {
  if (!('url' in options)) {
    return false;
  }
  const { protocol } = new URL(options.url);
  return ['ws:', 'wss:'].includes(protocol);
}
|
|
/**
 * Type guard for SSE server configs: any `url` whose scheme is not `ws:`/`wss:`.
 * A present but unparsable `url` makes `new URL` throw.
 */
function isSSEOptions(options: t.MCPOptions): options is t.SSEOptions {
  if (!('url' in options)) {
    return false;
  }
  const scheme = new URL(options.url).protocol;
  const isWebSocket = scheme === 'ws:' || scheme === 'wss:';
  return !isWebSocket;
}
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
/**
 * Type guard for streamable-HTTP server configs. Requires both `url` and `type`, with `type`
 * equal to 'streamable-http' or 'http', and a URL scheme that is not `ws:`/`wss:`.
 */
function isStreamableHTTPOptions(options: t.MCPOptions): options is t.StreamableHTTPOptions {
  if (!('url' in options) || !('type' in options)) {
    return false;
  }
  const declaredType = options.type as string;
  const isHttpType = declaredType === 'streamable-http' || declaredType === 'http';
  if (!isHttpType) {
    return false;
  }
  const { protocol } = new URL(options.url);
  return protocol !== 'ws:' && protocol !== 'wss:';
}
|
|
/** Window (ms) used by the transport send hook to throttle outgoing empty-result messages. */
const FIVE_MINUTES = 1000 * 60 * 5;
/** Fallback undici body/headers timeout (ms) when the server config sets no `timeout`. */
const DEFAULT_TIMEOUT = 1000 * 60;
|
|
/** Constructor arguments for `MCPConnection`. */
interface MCPConnectionParams {
  /** Server name; exposed as `serverName` and used in log prefixes. */
  serverName: string;
  /** Optional user id; included in log prefixes when set. */
  userId?: string;
  /** Server configuration the transport is built from. */
  serverConfig: t.MCPOptions;
  /** Optional OAuth tokens; an `access_token` is sent as a Bearer header on HTTP transports. */
  oauthTokens?: MCPOAuthTokens | null;
}
|
|
/**
 * Client-side connection to a single MCP server.
 *
 * Picks a transport from the server config (stdio, WebSocket, streamable HTTP or SSE),
 * connects the SDK `Client` over it, and tracks connection state.
 *
 * Events emitted on this instance:
 * - `connectionChange` (state): every state transition. An `'error'` state triggers
 *   auto-reconnect once the first connection has succeeded.
 * - `resourcesChanged`: the server sent a resource-list-changed notification.
 * - `oauthRequired` ({ serverName, error, serverUrl, userId }): connecting failed with an auth
 *   error. `connectClient()` then waits for `oauthHandled` or `oauthFailed` to be emitted.
 * - `oauthError` (error): the transport reported a 401/403.
 */
export class MCPConnection extends EventEmitter {
  public client: Client;
  private options: t.MCPOptions;
  private transport: Transport | null = null;
  private connectionState: t.ConnectionState = 'disconnected';
  /** In-flight connect attempt; concurrent `connectClient()` callers share it. */
  private connectPromise: Promise<void> | null = null;
  private readonly MAX_RECONNECT_ATTEMPTS = 3;
  public readonly serverName: string;
  private shouldStopReconnecting = false;
  private isReconnecting = false;
  /** True from construction until the first `'connected'` state; blocks auto-reconnect. */
  private isInitializing = false;
  private reconnectAttempts = 0;
  private readonly userId?: string;
  /** Reference time (ms) for throttling empty-result messages in the send hook. */
  private lastPingTime: number;
  /** Start time (ms) of the last liveness probe in `isConnected()`; 0 forces a fresh probe. */
  private lastConnectionCheckAt: number = 0;
  private oauthTokens?: MCPOAuthTokens | null;
  /** Extra headers (lower-cased keys) merged into every request by the custom fetch. */
  private requestHeaders?: Record<string, string> | null;
  /** True while waiting for an external OAuth flow; reconnection is skipped meanwhile. */
  private oauthRequired = false;
  iconPath?: string;
  timeout?: number;
  url?: string;

  /**
   * Stores headers to merge into every HTTP request made through the custom fetch. These
   * headers are spread after the request's own headers, so they win on key collisions.
   * Keys are lower-cased.
   *
   * @param headers - Header map to store. `null` is ignored and the current headers are kept.
   */
  setRequestHeaders(headers: Record<string, string> | null): void {
    if (!headers) {
      return;
    }
    const normalizedHeaders: Record<string, string> = {};
    for (const [key, value] of Object.entries(headers)) {
      normalizedHeaders[key.toLowerCase()] = value;
    }
    this.requestHeaders = normalizedHeaders;
  }

  /** @returns The headers last stored by `setRequestHeaders`, if any. */
  getRequestHeaders(): Record<string, string> | null | undefined {
    return this.requestHeaders;
  }

  /**
   * @param params - Server name/config, optional user id and optional OAuth tokens.
   */
  constructor(params: MCPConnectionParams) {
    super();
    this.options = params.serverConfig;
    this.serverName = params.serverName;
    this.userId = params.userId;
    this.iconPath = params.serverConfig.iconPath;
    this.timeout = params.serverConfig.timeout;
    this.lastPingTime = Date.now();
    if (params.oauthTokens) {
      this.oauthTokens = params.oauthTokens;
    }
    this.client = new Client(
      {
        name: '@librechat/api-client',
        version: '1.2.3',
      },
      {
        capabilities: {},
      },
    );

    this.setupEventListeners();
  }

  /** Builds the `[MCP][User: id][server]` prefix used on every log line. */
  private getLogPrefix(): string {
    const userPart = this.userId ? `[User: ${this.userId}]` : '';
    return `[MCP]${userPart}[${this.serverName}]`;
  }

  /**
   * Builds the undici-based `fetch` handed to the HTTP transports.
   *
   * `getHeaders()` is called on every request. When it returns headers, they are spread over
   * the request's own headers, so they win on key collisions. All requests go through one
   * undici `Agent` whose body/headers timeouts are `timeout`, or DEFAULT_TIMEOUT when
   * `timeout` is unset or 0.
   *
   * @param getHeaders - Returns extra headers, or null/undefined for none.
   * @param timeout - Body and headers timeout in ms.
   * @returns A fetch-compatible function.
   */
  private createFetchFunction(
    getHeaders: () => Record<string, string> | null | undefined,
    timeout?: number,
  ): (input: UndiciRequestInfo, init?: UndiciRequestInit) => Promise<UndiciResponse> {
    const effectiveTimeout = timeout || DEFAULT_TIMEOUT;
    // One dispatcher per fetch function. Allocating an Agent per request gave every request
    // its own connection pool that was never closed or reused.
    const agent = new Agent({
      bodyTimeout: effectiveTimeout,
      headersTimeout: effectiveTimeout,
    });

    return function customFetch(
      input: UndiciRequestInfo,
      init?: UndiciRequestInit,
    ): Promise<UndiciResponse> {
      const requestHeaders = getHeaders();
      if (!requestHeaders) {
        return undiciFetch(input, { ...init, dispatcher: agent });
      }

      // Flatten the request's own headers (Headers instance, tuple array or plain record)
      // into a plain record so the extra headers can be spread over them.
      let initHeaders: Record<string, string> = {};
      if (init?.headers) {
        if (init.headers instanceof Headers) {
          initHeaders = Object.fromEntries(init.headers.entries());
        } else if (Array.isArray(init.headers)) {
          initHeaders = Object.fromEntries(init.headers);
        } else {
          initHeaders = init.headers as Record<string, string>;
        }
      }

      return undiciFetch(input, {
        ...init,
        headers: {
          ...initHeaders,
          ...requestHeaders,
        },
        dispatcher: agent,
      });
    };
  }

  /** Logs an error with context. Despite the name, no event is emitted. */
  private emitError(error: unknown, errorContext: string): void {
    const errorMessage = error instanceof Error ? error.message : String(error);
    logger.error(`${this.getLogPrefix()} ${errorContext}: ${errorMessage}`);
  }

  /**
   * Creates the transport for `options`. The kind is inferred in this order:
   * stdio (`command`), WebSocket (`ws:`/`wss:` url), streamable HTTP (`type` of
   * 'streamable-http' | 'http'), then SSE (any other url). Sets `this.url` for
   * URL-based transports.
   *
   * SSE and streamable HTTP send the configured headers plus `Authorization: Bearer <token>`
   * when an OAuth access token is present. They route requests through `createFetchFunction`.
   *
   * @throws Error when the kind cannot be inferred or the options do not match it. The error
   *   is logged before being rethrown.
   */
  private constructTransport(options: t.MCPOptions): Transport {
    try {
      let type: t.MCPOptions['type'];
      if (isStdioOptions(options)) {
        type = 'stdio';
      } else if (isWebSocketOptions(options)) {
        type = 'websocket';
      } else if (isStreamableHTTPOptions(options)) {
        // Checked before SSE: isSSEOptions also matches http(s) URLs.
        type = 'streamable-http';
      } else if (isSSEOptions(options)) {
        type = 'sse';
      } else {
        throw new Error(
          'Cannot infer transport type: options.type is not provided and cannot be inferred from other properties.',
        );
      }

      switch (type) {
        case 'stdio':
          if (!isStdioOptions(options)) {
            throw new Error('Invalid options for stdio transport.');
          }
          return new StdioClientTransport({
            command: options.command,
            args: options.args,
            // SDK default environment first, so entries in options.env override it.
            env: { ...getDefaultEnvironment(), ...(options.env ?? {}) },
          });

        case 'websocket':
          if (!isWebSocketOptions(options)) {
            throw new Error('Invalid options for websocket transport.');
          }
          this.url = options.url;
          return new WebSocketClientTransport(new URL(options.url));

        case 'sse': {
          if (!isSSEOptions(options)) {
            throw new Error('Invalid options for sse transport.');
          }
          this.url = options.url;
          const url = new URL(options.url);
          logger.info(
            `${this.getLogPrefix()} Creating SSE transport: ${sanitizeUrlForLogging(url)}`,
          );
          const abortController = new AbortController();

          // Configured headers plus, when available, the OAuth bearer token.
          const headers = { ...options.headers };
          if (this.oauthTokens?.access_token) {
            headers['Authorization'] = `Bearer ${this.oauthTokens.access_token}`;
          }

          const timeoutValue = this.timeout || DEFAULT_TIMEOUT;
          // Shared across EventSource (re)connects instead of a new Agent on every call.
          const sseAgent = new Agent({
            bodyTimeout: timeoutValue,
            headersTimeout: timeoutValue,
          });
          const transport = new SSEClientTransport(url, {
            requestInit: {
              headers,
              signal: abortController.signal,
            },
            eventSourceInit: {
              fetch: (url, init) => {
                const fetchHeaders = new Headers(Object.assign({}, init?.headers, headers));
                return undiciFetch(url, {
                  ...init,
                  dispatcher: sseAgent,
                  headers: fetchHeaders,
                });
              },
            },
            fetch: this.createFetchFunction(
              this.getRequestHeaders.bind(this),
              this.timeout,
            ) as unknown as FetchLike,
          });

          transport.onclose = () => {
            logger.info(`${this.getLogPrefix()} SSE transport closed`);
            this.emit('connectionChange', 'disconnected');
          };

          // No onmessage here: setupTransportDebugHandlers installs it after construction.
          this.setupTransportErrorHandlers(transport);
          return transport;
        }

        case 'streamable-http': {
          if (!isStreamableHTTPOptions(options)) {
            throw new Error('Invalid options for streamable-http transport.');
          }
          this.url = options.url;
          const url = new URL(options.url);
          logger.info(
            `${this.getLogPrefix()} Creating streamable-http transport: ${sanitizeUrlForLogging(url)}`,
          );
          const abortController = new AbortController();

          // Configured headers plus, when available, the OAuth bearer token.
          const headers = { ...options.headers };
          if (this.oauthTokens?.access_token) {
            headers['Authorization'] = `Bearer ${this.oauthTokens.access_token}`;
          }

          const transport = new StreamableHTTPClientTransport(url, {
            requestInit: {
              headers,
              signal: abortController.signal,
            },
            fetch: this.createFetchFunction(
              this.getRequestHeaders.bind(this),
              this.timeout,
            ) as unknown as FetchLike,
          });

          transport.onclose = () => {
            logger.info(`${this.getLogPrefix()} Streamable-http transport closed`);
            this.emit('connectionChange', 'disconnected');
          };

          // No onmessage here: setupTransportDebugHandlers installs it after construction.
          this.setupTransportErrorHandlers(transport);
          return transport;
        }

        default: {
          throw new Error(`Unsupported transport type: ${type}`);
        }
      }
    } catch (error) {
      this.emitError(error, 'Failed to construct transport');
      throw error;
    }
  }

  /**
   * Tracks `connectionChange` events. A `'connected'` state resets the reconnect bookkeeping.
   * An `'error'` state starts `handleReconnection()`, unless a reconnect is already running
   * or the first connection has not yet succeeded. Also registers the resource-change handler.
   */
  private setupEventListeners(): void {
    this.isInitializing = true;
    this.on('connectionChange', (state: t.ConnectionState) => {
      this.connectionState = state;
      if (state === 'connected') {
        this.isReconnecting = false;
        this.isInitializing = false;
        this.shouldStopReconnecting = false;
        this.reconnectAttempts = 0;
      } else if (state === 'error' && !this.isReconnecting && !this.isInitializing) {
        this.handleReconnection().catch((error) => {
          logger.error(`${this.getLogPrefix()} Reconnection handler failed:`, error);
        });
      }
    });

    this.subscribeToResources();
  }

  /**
   * Retries `connect()` up to MAX_RECONNECT_ATTEMPTS times. Before each attempt it waits
   * min(1000 * 2^attempt, 30000) ms, i.e. 2s, 4s, 8s for attempts 1-3.
   * Does nothing while another reconnect runs, during initialization, when stopping was
   * requested, or while OAuth is pending.
   */
  private async handleReconnection(): Promise<void> {
    if (
      this.isReconnecting ||
      this.shouldStopReconnecting ||
      this.isInitializing ||
      this.oauthRequired
    ) {
      if (this.oauthRequired) {
        logger.info(`${this.getLogPrefix()} OAuth required, skipping reconnection attempts`);
      }
      return;
    }

    this.isReconnecting = true;
    const backoffDelay = (attempt: number) => Math.min(1000 * Math.pow(2, attempt), 30000);

    try {
      while (
        this.reconnectAttempts < this.MAX_RECONNECT_ATTEMPTS &&
        !(this.shouldStopReconnecting as boolean)
      ) {
        this.reconnectAttempts++;
        const delay = backoffDelay(this.reconnectAttempts);

        logger.info(
          `${this.getLogPrefix()} Reconnecting ${this.reconnectAttempts}/${this.MAX_RECONNECT_ATTEMPTS} (delay: ${delay}ms)`,
        );

        await new Promise((resolve) => setTimeout(resolve, delay));

        try {
          await this.connect();
          this.reconnectAttempts = 0;
          return;
        } catch (error) {
          logger.error(`${this.getLogPrefix()} Reconnection attempt failed:`, error);

          if (
            this.reconnectAttempts === this.MAX_RECONNECT_ATTEMPTS ||
            (this.shouldStopReconnecting as boolean)
          ) {
            logger.error(`${this.getLogPrefix()} Stopping reconnection attempts`);
            return;
          }
        }
      }
    } finally {
      this.isReconnecting = false;
    }
  }

  /** Re-emits the server's resource-list-changed notification as `resourcesChanged`. */
  private subscribeToResources(): void {
    this.client.setNotificationHandler(ResourceListChangedNotificationSchema, async () => {
      this.emit('resourcesChanged');
    });
  }

  /**
   * Connects the SDK client over a freshly built transport.
   *
   * Returns immediately when already connected or when `shouldStopReconnecting` is set.
   * Concurrent callers share the in-flight attempt. Any existing transport is closed first.
   * The connect is bounded by `options.initTimeout` (default 120000 ms).
   *
   * On an auth error (see `isOAuthError`), it emits `oauthRequired` and waits for
   * `oauthHandled`, `oauthFailed`, or the same timeout:
   * - After `oauthHandled`, it resolves WITHOUT connecting; the caller must retry.
   * - After failure or timeout, it rethrows the original auth error.
   *
   * On any other error, it sets the state to `'error'`, emits it, and rethrows.
   */
  async connectClient(): Promise<void> {
    if (this.connectionState === 'connected') {
      return;
    }

    if (this.connectPromise) {
      return this.connectPromise;
    }

    if (this.shouldStopReconnecting) {
      return;
    }

    this.emit('connectionChange', 'connecting');

    this.connectPromise = (async () => {
      try {
        if (this.transport) {
          try {
            await this.client.close();
            this.transport = null;
          } catch (error) {
            logger.warn(`${this.getLogPrefix()} Error closing connection:`, error);
          }
        }

        this.transport = this.constructTransport(this.options);
        this.setupTransportDebugHandlers();

        const connectTimeout = this.options.initTimeout ?? 120000;
        await withTimeout(
          this.client.connect(this.transport),
          connectTimeout,
          `Connection timeout after ${connectTimeout}ms`,
        );

        this.connectionState = 'connected';
        this.emit('connectionChange', 'connected');
        this.reconnectAttempts = 0;
      } catch (error) {
        if (this.isOAuthError(error)) {
          logger.warn(`${this.getLogPrefix()} OAuth authentication required`);
          this.oauthRequired = true;
          const serverUrl = this.url;
          logger.debug(
            `${this.getLogPrefix()} Server URL for OAuth: ${serverUrl ? sanitizeUrlForLogging(serverUrl) : 'undefined'}`,
          );

          const oauthTimeout = this.options.initTimeout ?? 60000 * 2;
          // Settles on the first of: oauthHandled, oauthFailed, or timeout.
          const oauthHandledPromise = new Promise<void>((resolve, reject) => {
            let timeoutId: NodeJS.Timeout | null = null;
            let oauthHandledListener: (() => void) | null = null;
            let oauthFailedListener: ((error: Error) => void) | null = null;

            // Detach both listeners and the timer, whichever outcome fires first.
            const cleanup = () => {
              if (timeoutId) {
                clearTimeout(timeoutId);
              }
              if (oauthHandledListener) {
                this.off('oauthHandled', oauthHandledListener);
              }
              if (oauthFailedListener) {
                this.off('oauthFailed', oauthFailedListener);
              }
            };

            oauthHandledListener = () => {
              cleanup();
              resolve();
            };

            oauthFailedListener = (error: Error) => {
              cleanup();
              reject(error);
            };

            timeoutId = setTimeout(() => {
              cleanup();
              reject(new Error(`OAuth handling timeout after ${oauthTimeout}ms`));
            }, oauthTimeout);

            this.once('oauthHandled', oauthHandledListener);
            this.once('oauthFailed', oauthFailedListener);
          });

          // Listeners are attached above, so a synchronous handler cannot be missed.
          this.emit('oauthRequired', {
            serverName: this.serverName,
            error,
            serverUrl,
            userId: this.userId,
          });

          try {
            await oauthHandledPromise;
            this.oauthRequired = false;
            // Resolve without connecting; the caller is expected to retry.
            logger.info(
              `${this.getLogPrefix()} OAuth handled successfully, connection will be retried`,
            );
            return;
          } catch (oauthError) {
            this.oauthRequired = false;
            logger.error(`${this.getLogPrefix()} OAuth handling failed:`, oauthError);
            // Surface the original auth error, not the OAuth-flow failure.
            throw error;
          }
        }

        this.connectionState = 'error';
        this.emit('connectionChange', 'error');
        throw error;
      } finally {
        this.connectPromise = null;
      }
    })();

    return this.connectPromise;
  }

  /**
   * Installs debug logging for incoming and outgoing transport messages.
   *
   * The `send` wrapper also throttles outgoing responses whose `result` is empty. Such a
   * message is only sent if at least FIVE_MINUTES have passed since `lastPingTime`, which is
   * then updated. Otherwise `send` throws `Error('Empty result')`.
   * NOTE(review): the purpose of this throttle is not evident here (presumably to suppress
   * frequent empty/ping replies); confirm before changing.
   */
  private setupTransportDebugHandlers(): void {
    if (!this.transport) {
      return;
    }

    this.transport.onmessage = (msg: JSONRPCMessage) => {
      logger.debug(`${this.getLogPrefix()} Transport received: ${JSON.stringify(msg)}`);
    };

    const originalSend = this.transport.send.bind(this.transport);
    this.transport.send = async (msg) => {
      if ('result' in msg && !('method' in msg) && Object.keys(msg.result ?? {}).length === 0) {
        if (Date.now() - this.lastPingTime < FIVE_MINUTES) {
          throw new Error('Empty result');
        }
        this.lastPingTime = Date.now();
      }
      logger.debug(`${this.getLogPrefix()} Transport sending: ${JSON.stringify(msg)}`);
      return originalSend(msg);
    };
  }

  /**
   * Disconnects, reconnects, and verifies liveness.
   *
   * @throws Error('Connection not established') when `isConnected()` is false afterwards
   *   (e.g. OAuth was handled but no reconnect happened), or whatever the connect step threw.
   */
  async connect(): Promise<void> {
    try {
      await this.disconnect();
      await this.connectClient();
      if (!(await this.isConnected())) {
        throw new Error('Connection not established');
      }
    } catch (error) {
      logger.error(`${this.getLogPrefix()} Connection failed:`, error);
      throw error;
    }
  }

  /**
   * Installs the transport error handler:
   * - A 404 whose message mentions "failed to open sse stream" is logged and ignored.
   * - A 401/403 `code` additionally emits `oauthError`.
   * - Every non-ignored error is logged and emitted as `connectionChange` `'error'`.
   */
  private setupTransportErrorHandlers(transport: Transport): void {
    transport.onerror = (error) => {
      if (error && typeof error === 'object' && 'code' in error) {
        const errorCode = (error as unknown as { code?: number }).code;

        if (
          errorCode === 404 &&
          String(error?.message).toLowerCase().includes('failed to open sse stream')
        ) {
          logger.warn(`${this.getLogPrefix()} SSE stream not available (404). Ignoring.`);
          return;
        }

        if (errorCode === 401 || errorCode === 403) {
          logger.warn(`${this.getLogPrefix()} OAuth authentication error detected`);
          this.emit('oauthError', error);
        }
      }

      logger.error(`${this.getLogPrefix()} Transport error:`, error);

      this.emit('connectionChange', 'error');
    };
  }

  /**
   * Closes the client if a transport exists, then moves to `'disconnected'` and emits the
   * change (skipped when already disconnected). Always clears the pending connect promise.
   */
  public async disconnect(): Promise<void> {
    try {
      if (this.transport) {
        await this.client.close();
        this.transport = null;
      }
      if (this.connectionState === 'disconnected') {
        return;
      }
      this.connectionState = 'disconnected';
      this.emit('connectionChange', 'disconnected');
    } finally {
      this.connectPromise = null;
    }
  }

  /** Lists server resources; logs and returns `[]` on failure. */
  async fetchResources(): Promise<t.MCPResource[]> {
    try {
      const { resources } = await this.client.listResources();
      return resources;
    } catch (error) {
      this.emitError(error, 'Failed to fetch resources');
      return [];
    }
  }

  /** Lists server tools; logs and returns `[]` on failure. */
  async fetchTools() {
    try {
      const { tools } = await this.client.listTools();
      return tools;
    } catch (error) {
      this.emitError(error, 'Failed to fetch tools');
      return [];
    }
  }

  /** Lists server prompts; logs and returns `[]` on failure. */
  async fetchPrompts(): Promise<t.MCPPrompt[]> {
    try {
      const { prompts } = await this.client.listPrompts();
      return prompts;
    } catch (error) {
      this.emitError(error, 'Failed to fetch prompts');
      return [];
    }
  }

  /**
   * Reports whether the connection is usable.
   *
   * Returns false unless the state is `'connected'`. A probe that started within
   * `mcpConfig.CONNECTION_CHECK_TTL` is trusted without a new round-trip. Otherwise it sends
   * `ping`. If the server does not support `ping`, it falls back to listing
   * tools/resources/prompts, whichever capability the server advertises. A failed probe
   * resets the TTL stamp, so the failure is not cached as success.
   */
  public async isConnected(): Promise<boolean> {
    if (this.connectionState !== 'connected') {
      return false;
    }

    const now = Date.now();
    if (now - this.lastConnectionCheckAt < mcpConfig.CONNECTION_CHECK_TTL) {
      return true;
    }
    this.lastConnectionCheckAt = now;

    try {
      await this.client.ping();
      return this.connectionState === 'connected';
    } catch (error) {
      // JSON-RPC "method not found" / "invalid params" variants mean ping is unsupported.
      const pingUnsupported =
        error instanceof Error &&
        (error.message.includes('-32601') ||
          error.message.includes('-32602') ||
          error.message.includes('invalid method ping') ||
          error.message.includes('Unsupported method: ping') ||
          error.message.includes('method not found'));

      if (!pingUnsupported) {
        // Force the next call to probe again instead of returning a cached `true`.
        this.lastConnectionCheckAt = 0;
        logger.error(`${this.getLogPrefix()} Ping failed:`, error);
        return false;
      }

      logger.debug(
        `${this.getLogPrefix()} Server does not support ping method, verifying connection with capabilities`,
      );

      try {
        const capabilities = this.client.getServerCapabilities();

        if (capabilities?.tools) {
          await this.client.listTools();
          return this.connectionState === 'connected';
        } else if (capabilities?.resources) {
          await this.client.listResources();
          return this.connectionState === 'connected';
        } else if (capabilities?.prompts) {
          await this.client.listPrompts();
          return this.connectionState === 'connected';
        } else {
          logger.debug(
            `${this.getLogPrefix()} No capabilities to test, assuming connected based on state`,
          );
          return this.connectionState === 'connected';
        }
      } catch (capabilityError) {
        // Force the next call to probe again instead of returning a cached `true`.
        this.lastConnectionCheckAt = 0;
        logger.error(`${this.getLogPrefix()} Connection verification failed:`, capabilityError);
        return false;
      }
    }
  }

  /** Replaces the stored OAuth tokens; they are read when the next transport is constructed. */
  public setOAuthTokens(tokens: MCPOAuthTokens): void {
    this.oauthTokens = tokens;
  }

  /**
   * Heuristically detects authentication failures.
   *
   * @returns true when the error message contains "401" (e.g. "Non-200 status code (401)"),
   *   or when the error exposes a numeric `code` of 401 or 403.
   */
  private isOAuthError(error: unknown): boolean {
    if (!error || typeof error !== 'object') {
      return false;
    }

    if (
      'message' in error &&
      typeof error.message === 'string' &&
      error.message.includes('401')
    ) {
      return true;
    }

    // Checked even when a message exists, so Error subclasses carrying an HTTP status
    // `code` (e.g. 403) are recognized too, consistent with setupTransportErrorHandlers.
    if ('code' in error) {
      const code = (error as { code?: number }).code;
      return code === 401 || code === 403;
    }

    return false;
  }
}
|
|