#!/usr/bin/env python3
"""
MaTableGPT MCP Server Launcher (Simplified SSE)
================================================
A minimal MCP SSE server implementation for HuggingFace Space.
Usage:
python start_mcp.py [--host HOST] [--port PORT] [--mode MODE]
"""
# Standard-library imports only; third-party deps are imported lazily inside
# the server functions so a missing package does not break `--help`.
import os
import sys
import argparse
import logging
import json
import asyncio
import uuid

# Add current directory to path for imports — the Space runs this file
# directly, so sibling modules (e.g. mcp_service) must be importable.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger shared by all functions below.
logger = logging.getLogger("matablgpt-mcp")
def run_sse_server(host: str, port: int):
    """Run the MCP server in SSE mode.

    Serves four routes on a Starlette app:
      * GET  /         -- human-readable status page
      * GET  /health   -- JSON health check
      * GET  /sse      -- SSE stream; the first event tells the client where
                          to POST its JSON-RPC requests
      * POST /messages -- JSON-RPC requests; responses are delivered
                          asynchronously over the matching SSE stream

    Args:
        host: Interface to bind (e.g. "0.0.0.0").
        port: TCP port to listen on.

    Blocks until uvicorn is stopped.
    """
    import uvicorn
    from starlette.applications import Starlette
    from starlette.routing import Route
    from starlette.responses import JSONResponse, HTMLResponse, StreamingResponse
    from starlette.requests import Request

    # Try to import the MCP service. The HTTP server still starts when this
    # fails so the Space stays reachable and reports the error on /messages.
    try:
        from mcp_service import mcp
        mcp_available = True
        logger.info("MCP service loaded successfully")
    except Exception as e:
        mcp_available = False
        mcp = None
        logger.error(f"Failed to load MCP service: {e}")

    # Maps SSE connection id -> asyncio.Queue of outbound JSON-RPC messages.
    connections = {}
    # Strong references to in-flight background tasks. asyncio.create_task()
    # keeps only a weak reference to its task, so without this set a request
    # task could be garbage-collected mid-run and its response silently lost.
    background_tasks = set()

    async def sse_endpoint(request: Request):
        """SSE endpoint - client connects here first."""
        conn_id = str(uuid.uuid4())
        queue = asyncio.Queue()
        connections[conn_id] = queue
        logger.info(f"SSE connection: {conn_id}")

        async def generate():
            try:
                # Send the message endpoint URL (just the path, not JSON)
                yield f"event: endpoint\ndata: /messages?sessionId={conn_id}\n\n"
                while True:
                    try:
                        # Short timeout so keepalives go out frequently enough
                        # for proxies (e.g. the HF Space ingress) not to cut us.
                        msg = await asyncio.wait_for(queue.get(), timeout=15)
                        yield f"event: message\ndata: {json.dumps(msg)}\n\n"
                    except asyncio.TimeoutError:
                        # SSE comment line acts as a keepalive; ignored by clients.
                        yield ": keepalive\n\n"
            except asyncio.CancelledError:
                pass  # client disconnected; cleanup happens in finally
            except Exception as e:
                logger.error(f"SSE generate error: {e}")
            finally:
                connections.pop(conn_id, None)
                logger.info(f"SSE closed: {conn_id}")

        return StreamingResponse(
            generate(),
            media_type="text/event-stream",
            headers={
                # Disable caching and reverse-proxy buffering so events flush
                # immediately (X-Accel-Buffering targets nginx-style proxies).
                "Cache-Control": "no-cache, no-store, must-revalidate",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
                "Access-Control-Allow-Origin": "*",
            }
        )

    async def messages_endpoint(request: Request):
        """Messages endpoint - client sends JSON-RPC here."""
        session_id = request.query_params.get("sessionId")
        if not session_id or session_id not in connections:
            logger.error(f"Invalid session: {session_id}, active: {list(connections.keys())}")
            return JSONResponse({"error": "Invalid session"}, status_code=400)
        if not mcp_available:
            return JSONResponse({"error": "MCP service not available"}, status_code=500)

        body = await request.json()
        method = body.get("method", "")
        params = body.get("params", {})
        msg_id = body.get("id")
        logger.info(f"Method: {method}, ID: {msg_id}")

        async def process_request():
            """Process the request in background and push the reply over SSE."""
            try:
                if method == "initialize":
                    result = {
                        "protocolVersion": "2024-11-05",
                        "serverInfo": {"name": "MaTableGPT-MCP", "version": "1.0.0"},
                        "capabilities": {"tools": {}}
                    }
                elif method == "tools/list":
                    tools = []
                    for name, tool in mcp._tool_manager._tools.items():
                        tools.append({
                            "name": tool.name,
                            "description": tool.description or name,
                            "inputSchema": tool.parameters if hasattr(tool, 'parameters') else {"type": "object", "properties": {}}
                        })
                    result = {"tools": tools}
                    logger.info(f"Listed {len(tools)} tools")
                elif method == "tools/call":
                    tool_name = params.get("name")
                    tool_args = params.get("arguments", {})
                    if tool_name not in mcp._tool_manager._tools:
                        raise Exception(f"Unknown tool: {tool_name}")
                    logger.info(f"Calling tool: {tool_name}")
                    tool = mcp._tool_manager._tools[tool_name]
                    # Call tool directly (don't use executor - it breaks httpx/openai)
                    if tool.is_async:
                        tool_result = await tool.fn(**tool_args)
                    else:
                        tool_result = tool.fn(**tool_args)
                    result = {"content": [{"type": "text", "text": json.dumps(tool_result)}]}
                    logger.info(f"Tool {tool_name} completed")
                else:
                    raise Exception(f"Unknown method: {method}")
                response = {"jsonrpc": "2.0", "id": msg_id, "result": result}
            except Exception as e:
                logger.error(f"Error: {e}")
                import traceback
                traceback.print_exc()
                response = {
                    "jsonrpc": "2.0",
                    "id": msg_id,
                    "error": {"code": -32000, "message": str(e)}
                }
            # Send response via SSE
            if session_id in connections:
                await connections[session_id].put(response)
                logger.info(f"Response sent for {method}, id={msg_id}")
            else:
                logger.error(f"Session {session_id} disconnected before response")

        # JSON-RPC notifications carry no id and MUST NOT receive a response;
        # acknowledge every "notifications/*" method immediately instead of
        # routing it through process_request (which would push an error).
        if method.startswith("notifications/"):
            return JSONResponse({"ok": True})

        # Run the request in the background so the POST returns immediately;
        # keep a strong reference until the task completes (see background_tasks).
        task = asyncio.create_task(process_request())
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)
        return JSONResponse({"ok": True})

    async def health(request: Request):
        return JSONResponse({"status": "ok", "service": "MaTableGPT-MCP"})

    async def home(request: Request):
        html = """<!DOCTYPE html>
<html><head><title>MaTableGPT MCP</title></head>
<body>
<h1>🔬 MaTableGPT MCP Service</h1>
<p>SSE Endpoint: <code>/sse</code></p>
<p>Status: ✅ Running</p>
</body></html>"""
        return HTMLResponse(html)

    app = Starlette(routes=[
        Route("/", home),
        Route("/health", health),
        Route("/sse", sse_endpoint),
        Route("/messages", messages_endpoint, methods=["POST"]),
    ])

    logger.info(f"Starting SSE server on {host}:{port}")
    uvicorn.run(app, host=host, port=port, log_level="info")
def run_stdio_server():
    """Launch the MCP server over the stdio transport (blocking)."""
    # Import lazily so SSE-only deployments never need this module at startup.
    from mcp_service import mcp as server
    logger.info("Starting stdio mode...")
    server.run()
def main():
    """Parse CLI arguments (with environment fallbacks) and start the server.

    Environment variables:
        MCP_HOST: default bind host (falls back to "0.0.0.0").
        MCP_PORT: default TCP port (falls back to 7860).
        MCP_MODE: default transport, "sse" or "stdio" (falls back to "sse").
    """
    parser = argparse.ArgumentParser(description="MaTableGPT MCP Server")
    parser.add_argument('--host', default=os.environ.get('MCP_HOST', '0.0.0.0'))
    parser.add_argument('--port', type=int, default=int(os.environ.get('MCP_PORT', '7860')))
    # Honor MCP_MODE for consistency with the MCP_HOST / MCP_PORT fallbacks
    # above; default remains 'sse' when the variable is unset.
    parser.add_argument('--mode', choices=['stdio', 'sse'],
                        default=os.environ.get('MCP_MODE', 'sse'))
    args = parser.parse_args()

    # Log which LLM endpoint is configured, if any (informational only).
    api_base = os.environ.get('LLM_API_BASE') or os.environ.get('OPENAI_API_BASE')
    if api_base:
        logger.info(f"API base: {api_base}")

    if args.mode == 'stdio':
        run_stdio_server()
    else:
        run_sse_server(args.host, args.port)
if __name__ == "__main__":
main()