| import { logger } from '@librechat/data-schemas'; |
| import type { OAuthClientInformation } from '@modelcontextprotocol/sdk/shared/auth.js'; |
| import type { TokenMethods } from '@librechat/data-schemas'; |
| import type { MCPOAuthTokens, MCPOAuthFlowMetadata, OAuthMetadata } from '~/mcp/oauth'; |
| import type { FlowStateManager } from '~/flow/manager'; |
| import type { FlowMetadata } from '~/flow/types'; |
| import type * as t from './types'; |
| import { MCPTokenStorage, MCPOAuthHandler } from '~/mcp/oauth'; |
| import { sanitizeUrlForLogging } from './utils'; |
| import { withTimeout } from '~/utils/promise'; |
| import { MCPConnection } from './connection'; |
| import { processMCPEnv } from '~/utils'; |
|
|
| |
| |
| |
| |
| |
/**
 * Builds {@link MCPConnection} instances for a single MCP server.
 *
 * The factory resolves placeholders in the server config (via `processMCPEnv`), optionally
 * loads or acquires OAuth tokens, listens for `oauthRequired` on the connection while it is
 * connecting, and connects with bounded retries under an overall timeout.
 */
export class MCPConnectionFactory {
  /** Server name used for flow IDs, token storage and log prefixes. */
  protected readonly serverName: string;
  /** Server config after `processMCPEnv` substitution. */
  protected readonly serverConfig: t.MCPOptions;
  /** `[MCP][server]` or `[MCP][server][userId]`, prepended to every log line. */
  protected readonly logPrefix: string;
  /** Whether OAuth token loading and `oauthRequired` handling are enabled. */
  protected readonly useOAuth: boolean;

  // OAuth-related properties; only populated by the constructor when `useOAuth` is set.
  protected readonly userId?: string;
  protected readonly flowManager?: FlowStateManager<MCPOAuthTokens | null>;
  protected readonly tokenMethods?: TokenMethods;
  protected readonly signal?: AbortSignal;
  protected readonly oauthStart?: (authURL: string) => Promise<void>;
  protected readonly oauthEnd?: () => Promise<void>;
  protected readonly returnOnOAuth?: boolean;
  /** Overall connect timeout in ms; falls back to `serverConfig.initTimeout`, then 30000. */
  protected readonly connectionTimeout?: number;

  /**
   * Creates a factory and uses it to build a connected {@link MCPConnection}.
   *
   * @param basic - Server name and raw server config.
   * @param oauth - Optional user/OAuth context; OAuth handling is enabled when `useOAuth` is set.
   * @returns The connection after `connect()` succeeded.
   * @throws When all connection attempts fail, an OAuth error stops the retries, or the
   *   connection timeout elapses.
   */
  static async create(
    basic: t.BasicConnectionOptions,
    oauth?: t.OAuthConnectionOptions,
  ): Promise<MCPConnection> {
    const factory = new this(basic, oauth);
    return factory.createConnection();
  }

  /**
   * Resolves the server config against the user/request context and captures OAuth settings.
   *
   * @param basic - Server name and raw server config.
   * @param oauth - Optional user, custom vars, request body and OAuth callbacks.
   */
  protected constructor(basic: t.BasicConnectionOptions, oauth?: t.OAuthConnectionOptions) {
    this.serverConfig = processMCPEnv({
      options: basic.serverConfig,
      user: oauth?.user,
      customUserVars: oauth?.customUserVars,
      body: oauth?.requestBody,
    });
    this.serverName = basic.serverName;
    this.useOAuth = !!oauth?.useOAuth;
    this.connectionTimeout = oauth?.connectionTimeout;
    this.logPrefix = oauth?.user
      ? `[MCP][${basic.serverName}][${oauth.user.id}]`
      : `[MCP][${basic.serverName}]`;

    if (oauth?.useOAuth) {
      this.userId = oauth.user.id;
      this.flowManager = oauth.flowManager;
      this.tokenMethods = oauth.tokenMethods;
      this.signal = oauth.signal;
      this.oauthStart = oauth.oauthStart;
      this.oauthEnd = oauth.oauthEnd;
      this.returnOnOAuth = oauth.returnOnOAuth;
    }
  }

  /**
   * Builds the connection (with stored OAuth tokens, if any) and connects it.
   *
   * @returns The connected connection.
   * @throws Whatever {@link attemptToConnect} throws.
   */
  protected async createConnection(): Promise<MCPConnection> {
    const oauthTokens = this.useOAuth ? await this.getOAuthTokens() : null;
    const connection = new MCPConnection({
      serverName: this.serverName,
      serverConfig: this.serverConfig,
      userId: this.userId,
      oauthTokens,
    });

    // The `oauthRequired` listener is only needed while connecting; detach it afterwards
    // whether the connection succeeded or not.
    const cleanupOAuthHandlers = this.useOAuth ? this.handleOAuthEvents(connection) : null;
    try {
      await this.attemptToConnect(connection);
      return connection;
    } finally {
      cleanupOAuthHandlers?.();
    }
  }

  /**
   * Loads this user's stored OAuth tokens for the server, refreshing them when needed.
   *
   * The lookup runs through the flow manager under a `mcp_get_tokens` flow (presumably so that
   * concurrent lookups for the same user/server share one result — verify in FlowStateManager).
   *
   * @returns The tokens, or `null` when token lookup is not configured, nothing is stored, or
   *   loading fails (failures are logged at debug level).
   */
  protected async getOAuthTokens(): Promise<MCPOAuthTokens | null> {
    const { tokenMethods, flowManager } = this;
    if (!tokenMethods?.findToken || !flowManager) return null;
    const { findToken, createToken, updateToken } = tokenMethods;

    try {
      const flowId = MCPOAuthHandler.generateFlowId(this.userId!, this.serverName);
      const tokens = await flowManager.createFlowWithHandler(
        flowId,
        'mcp_get_tokens',
        async () =>
          MCPTokenStorage.getTokens({
            userId: this.userId!,
            serverName: this.serverName,
            findToken,
            createToken,
            updateToken,
            refreshTokens: this.createRefreshTokensFunction(),
          }),
        this.signal,
      );

      if (tokens) logger.info(`${this.logPrefix} Loaded OAuth tokens`);
      return tokens;
    } catch (error) {
      logger.debug(`${this.logPrefix} No existing tokens found or error loading tokens`, error);
      return null;
    }
  }

  /**
   * Returns the callback `MCPTokenStorage` uses to refresh expired tokens.
   *
   * The callback refreshes against this server's URL, using the configured `oauth_headers`
   * (default `{}`) and `oauth` settings. From `metadata`, only `serverName` and `clientInfo`
   * are used.
   */
  protected createRefreshTokensFunction(): (
    refreshToken: string,
    metadata: {
      userId: string;
      serverName: string;
      identifier: string;
      clientInfo?: OAuthClientInformation;
    },
  ) => Promise<MCPOAuthTokens> {
    return async (refreshToken, metadata) => {
      return await MCPOAuthHandler.refreshOAuthTokens(
        refreshToken,
        {
          serverUrl: (this.serverConfig as t.SSEOptions | t.StreamableHTTPOptions).url,
          serverName: metadata.serverName,
          clientInfo: metadata.clientInfo,
        },
        this.serverConfig.oauth_headers ?? {},
        this.serverConfig.oauth,
      );
    };
  }

  /**
   * Attaches an `oauthRequired` listener to the connection.
   *
   * - With `returnOnOAuth`, the listener starts a new OAuth flow, passes the authorization URL
   *   to `oauthStart`, and emits `oauthFailed` so the current attempt ends immediately.
   * - Otherwise, it waits for {@link handleOAuthRequired}. On success it sets the tokens on the
   *   connection, stores them best-effort when `createToken` is available, and emits
   *   `oauthHandled`. On failure it emits `oauthFailed`.
   *
   * @returns A function that removes the listener.
   */
  protected handleOAuthEvents(connection: MCPConnection): () => void {
    const oauthHandler = async (data: { serverUrl?: string }) => {
      logger.info(`${this.logPrefix} oauthRequired event received`);

      if (this.returnOnOAuth) {
        try {
          const config = this.serverConfig;
          const { authorizationUrl, flowId, flowMetadata } =
            await MCPOAuthHandler.initiateOAuthFlow(
              this.serverName,
              data.serverUrl || '',
              this.userId!,
              config?.oauth_headers ?? {},
              config?.oauth,
            );

          // Register the flow without waiting for it to complete. Completion is expected to
          // happen elsewhere (presumably the OAuth callback — not visible here), so a rejection
          // is non-fatal but still logged.
          this.flowManager!.createFlow(flowId, 'mcp_oauth', flowMetadata).catch((error) => {
            logger.debug(`${this.logPrefix} Background OAuth flow ended with an error`, error);
          });

          if (this.oauthStart) {
            logger.info(`${this.logPrefix} OAuth flow started, issuing authorization URL`);
            await this.oauthStart(authorizationUrl);
          }

          // End this connection attempt; the caller is expected to reconnect after the user
          // has authorized.
          connection.emit('oauthFailed', new Error('OAuth flow initiated - return early'));
          return;
        } catch (error) {
          logger.error(`${this.logPrefix} Failed to initiate OAuth flow`, error);
          connection.emit('oauthFailed', new Error('OAuth initiation failed'));
          return;
        }
      }

      // Blocking mode: wait for the flow to complete.
      const result = await this.handleOAuthRequired();
      const tokens = result?.tokens;

      if (!tokens) {
        logger.warn(`${this.logPrefix} OAuth failed, emitting oauthFailed event`);
        connection.emit('oauthFailed', new Error('OAuth authentication failed'));
        return;
      }

      // Always give the tokens to the live connection, even when no token storage is configured.
      connection.setOAuthTokens(tokens);

      if (this.tokenMethods?.createToken) {
        try {
          await MCPTokenStorage.storeTokens({
            userId: this.userId!,
            serverName: this.serverName,
            tokens,
            createToken: this.tokenMethods.createToken,
            updateToken: this.tokenMethods.updateToken,
            findToken: this.tokenMethods.findToken,
            clientInfo: result.clientInfo,
            metadata: result.metadata,
          });
          logger.info(`${this.logPrefix} OAuth tokens saved to storage`);
        } catch (error) {
          // Persisting is best-effort; the connection already has the tokens.
          logger.error(`${this.logPrefix} Failed to save OAuth tokens to storage`, error);
        }
      }

      connection.emit('oauthHandled');
    };

    connection.on('oauthRequired', oauthHandler);

    return () => {
      connection.removeListener('oauthRequired', oauthHandler);
    };
  }

  /**
   * Runs {@link connectTo} under an overall timeout.
   *
   * The timeout is `connectionTimeout ?? serverConfig.initTimeout ?? 30000` ms. If the timeout
   * fires (or the attempt fails), the retry loop is told to stop so it does not keep
   * reconnecting in the background.
   *
   * @throws The timeout error from `withTimeout`, or the last error from {@link connectTo}.
   */
  protected async attemptToConnect(connection: MCPConnection): Promise<void> {
    const connectTimeout = this.connectionTimeout ?? this.serverConfig.initTimeout ?? 30000;
    let abandoned = false;
    try {
      await withTimeout(
        this.connectTo(connection, () => abandoned),
        connectTimeout,
        `Connection timeout after ${connectTimeout}ms`,
      );
    } catch (error) {
      // The caller has stopped waiting: prevent further retries on this connection.
      abandoned = true;
      throw error;
    }

    if (await connection.isConnected()) return;
    logger.error(`${this.logPrefix} Failed to establish connection.`);
  }

  /**
   * Calls `connection.connect()` up to 3 times, with linear backoff (2s, then 4s).
   *
   * OAuth errors (when OAuth is enabled) stop the retries immediately. If the error carries an
   * `isOAuthError` flag and `oauthStart` is configured, {@link handleOAuthRequired} is awaited
   * first. Its result is not used here.
   *
   * @param isAbandoned - Polled before each retry; once it returns true, the last error is
   *   rethrown instead of reconnecting. Defaults to never abandoning.
   * @throws The last connection error.
   */
  private async connectTo(
    connection: MCPConnection,
    isAbandoned: () => boolean = () => false,
  ): Promise<void> {
    const maxAttempts = 3;

    for (let attempt = 1; attempt <= maxAttempts; attempt++) {
      try {
        await connection.connect();
        if (await connection.isConnected()) {
          return;
        }
        throw new Error('Connection attempt succeeded but status is not connected');
      } catch (error) {
        if (this.useOAuth && this.isOAuthError(error)) {
          const errorWithFlag = error as (Error & { isOAuthError?: boolean }) | undefined;
          if (this.oauthStart && errorWithFlag?.isOAuthError) {
            logger.info(`${this.logPrefix} Handling OAuth`);
            await this.handleOAuthRequired();
          }
          logger.info(`${this.logPrefix} OAuth required, stopping connection attempts`);
          throw error;
        }

        if (attempt === maxAttempts) {
          logger.error(`${this.logPrefix} Failed to connect after ${maxAttempts} attempts`, error);
          throw error;
        }

        await new Promise((resolve) => setTimeout(resolve, 2000 * attempt));

        if (isAbandoned()) {
          logger.debug(`${this.logPrefix} Connection attempt abandoned, stopping retries`);
          throw error;
        }
      }
    }
  }

  /**
   * Heuristically decides whether a connection error means authentication is required.
   *
   * @returns True when the error message contains `401` (e.g. SSE
   *   `Non-200 status code (401)`), or when the error has a `code` of 401 or 403.
   */
  private isOAuthError(error: unknown): boolean {
    if (!error || typeof error !== 'object') {
      return false;
    }

    if ('message' in error && typeof error.message === 'string' && error.message.includes('401')) {
      return true;
    }

    // Checked even when a message is present: `Error` instances always have one.
    if ('code' in error) {
      const code = (error as { code?: unknown }).code;
      return code === 401 || code === 403;
    }

    return false;
  }

  /**
   * Obtains OAuth tokens by running (or joining) an interactive `mcp_oauth` flow.
   *
   * - If a PENDING flow already exists for this user/server, it waits for that flow.
   * - Otherwise, it deletes a non-pending flow older than 2 minutes, starts a new flow, passes
   *   the authorization URL to `oauthStart` (if set), and waits for completion.
   *
   * `oauthEnd` is awaited after the tokens arrive.
   *
   * @returns The tokens plus the client info and OAuth metadata recorded with the flow, or
   *   `null` when the flow manager or server URL is missing, or the flow fails (logged).
   */
  protected async handleOAuthRequired(): Promise<{
    tokens: MCPOAuthTokens | null;
    clientInfo?: OAuthClientInformation;
    metadata?: OAuthMetadata;
  } | null> {
    const serverUrl = (this.serverConfig as t.SSEOptions | t.StreamableHTTPOptions).url;
    logger.debug(
      `${this.logPrefix} \`handleOAuthRequired\` called with serverUrl: ${serverUrl ? sanitizeUrlForLogging(serverUrl) : 'undefined'}`,
    );

    if (!this.flowManager || !serverUrl) {
      logger.error(
        `${this.logPrefix} OAuth required but flow manager not available or server URL missing for ${this.serverName}`,
      );
      logger.warn(`${this.logPrefix} Please configure OAuth credentials for ${this.serverName}`);
      return null;
    }

    try {
      logger.debug(`${this.logPrefix} Checking for existing OAuth flow for ${this.serverName}...`);

      const flowId = MCPOAuthHandler.generateFlowId(this.userId!, this.serverName);
      const existingFlow = await this.flowManager.getFlowState(flowId, 'mcp_oauth');

      if (existingFlow && existingFlow.status === 'PENDING') {
        logger.debug(
          `${this.logPrefix} OAuth flow already exists for ${flowId}, waiting for completion`,
        );
        // Join the pending flow instead of starting a competing one.
        const tokens = await this.flowManager.createFlow(flowId, 'mcp_oauth');
        if (typeof this.oauthEnd === 'function') {
          await this.oauthEnd();
        }
        logger.info(
          `${this.logPrefix} OAuth flow completed, tokens received for ${this.serverName}`,
        );

        // Return the same details as the new-flow path so callers can persist them.
        const existingMetadata = existingFlow.metadata as unknown as MCPOAuthFlowMetadata;
        return {
          tokens,
          clientInfo: existingMetadata?.clientInfo,
          metadata: existingMetadata?.metadata,
        };
      }

      // A finished (e.g. COMPLETED/FAILED) flow is deleted only once it is older than the
      // threshold; recent ones are kept and just logged.
      if (existingFlow && existingFlow.status !== 'PENDING') {
        const STALE_FLOW_THRESHOLD = 2 * 60 * 1000;
        const { isStale, age, status } = await this.flowManager.isFlowStale(
          flowId,
          'mcp_oauth',
          STALE_FLOW_THRESHOLD,
        );

        if (isStale) {
          try {
            await this.flowManager.deleteFlow(flowId, 'mcp_oauth');
            logger.debug(
              `${this.logPrefix} Cleared stale ${status} OAuth flow (age: ${Math.round(age / 1000)}s)`,
            );
          } catch (error) {
            logger.warn(`${this.logPrefix} Failed to clear stale OAuth flow`, error);
          }
        } else {
          logger.debug(
            `${this.logPrefix} Skipping cleanup of recent ${status} flow (age: ${Math.round(age / 1000)}s, threshold: ${STALE_FLOW_THRESHOLD / 1000}s)`,
          );
          if (status === 'FAILED') {
            logger.warn(
              `${this.logPrefix} Recent OAuth flow failed, will retry after ${Math.round((STALE_FLOW_THRESHOLD - age) / 1000)}s`,
            );
          }
        }
      }

      logger.debug(`${this.logPrefix} Initiating new OAuth flow for ${this.serverName}...`);
      const {
        authorizationUrl,
        flowId: newFlowId,
        flowMetadata,
      } = await MCPOAuthHandler.initiateOAuthFlow(
        this.serverName,
        serverUrl,
        this.userId!,
        this.serverConfig.oauth_headers ?? {},
        this.serverConfig.oauth,
      );

      if (typeof this.oauthStart === 'function') {
        logger.info(`${this.logPrefix} OAuth flow started, issued authorization URL to user`);
        await this.oauthStart(authorizationUrl);
      } else {
        logger.info(
          `${this.logPrefix} OAuth flow started, no \`oauthStart\` handler defined, relying on callback endpoint`,
        );
      }

      // Blocks until the flow completes or fails (or `this.signal` aborts, if provided).
      const tokens = await this.flowManager.createFlow(
        newFlowId,
        'mcp_oauth',
        flowMetadata as FlowMetadata,
        this.signal,
      );
      if (typeof this.oauthEnd === 'function') {
        await this.oauthEnd();
      }
      logger.info(`${this.logPrefix} OAuth flow completed, tokens received for ${this.serverName}`);

      return {
        tokens,
        clientInfo: flowMetadata?.clientInfo,
        metadata: flowMetadata?.metadata,
      };
    } catch (error) {
      logger.error(`${this.logPrefix} Failed to complete OAuth flow for ${this.serverName}`, error);
      return null;
    }
  }
}
|
|