m3gnet / start_mcp.py
SEUyishu's picture
Update start_mcp.py
6d298cf verified
#!/usr/bin/env python3
"""Entry point for running the M3GNet FastMCP service."""
import argparse
import asyncio
import json
import logging
import os
import sys
import uuid
from typing import Any, Dict, Optional
# Put this script's own directory first on the import path so that modules
# sitting next to it (such as ``mcp_service``, imported lazily further down)
# resolve regardless of the working directory the script is launched from.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Process-wide logging setup shared by both transports (stdio and SSE).
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("m3gnet-mcp-server")
def _load_service():
    """Import and return the FastMCP application object.

    The ``mcp_service`` import is deferred until a transport has been chosen,
    so any failure during that import is logged (with traceback) through this
    script's logger before it propagates.

    Returns:
        The ``mcp`` object exported by the ``mcp_service`` module.

    Raises:
        Exception: whatever importing ``mcp_service`` raised, re-raised
            unchanged after being logged.
    """
    try:
        from mcp_service import mcp as fastmcp_app  # noqa: WPS433 (runtime import)
    except Exception as import_error:  # noqa: BLE001
        logger.error("Failed to import MCP service: %s", import_error, exc_info=True)
        raise
    return fastmcp_app
def run_stdio():
    """Serve the FastMCP application over stdio.

    Loads the service first (so an import failure is logged before anything
    else), then hands control to the service's own ``run()`` loop.
    """
    fastmcp_app = _load_service()
    logger.info("Starting stdio server mode")
    fastmcp_app.run()
def run_sse(host: str, port: int) -> None:
    """Run the FastMCP service with a lightweight, hand-rolled SSE transport.

    Routes served by Starlette/uvicorn (this call blocks in ``uvicorn.run``):

    * ``GET /``          -- HTML info page with a sample client configuration.
    * ``GET /health``    -- JSON liveness probe.
    * ``GET /sse``       -- opens an event stream for a new session.
    * ``POST /messages`` -- receives JSON-RPC messages for a session.

    Flow: the client opens ``/sse`` and receives an ``endpoint`` event naming
    ``/messages?sessionId=<id>``. It then POSTs JSON-RPC messages there. Each
    POST is acknowledged immediately with ``{"ok": true}``. For requests that
    carry an ``id``, the JSON-RPC response is delivered later as a ``message``
    event on that session's stream.

    Args:
        host: Interface uvicorn binds to.
        port: TCP port uvicorn listens on; also shown in the homepage's sample
            client configuration.
    """
    service = _load_service()
    logger.info("Starting SSE server on %s:%s", host, port)
    import uvicorn
    from starlette.applications import Starlette
    from starlette.middleware import Middleware
    from starlette.middleware.cors import CORSMiddleware
    from starlette.requests import Request
    from starlette.responses import HTMLResponse, JSONResponse, StreamingResponse, Response
    from starlette.routing import Route

    # CORS headers attached to the hand-built responses below.
    allow_origin_header = {"Access-Control-Allow-Origin": "*"}
    cors_headers = {
        **allow_origin_header,
        "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
        "Access-Control-Allow-Headers": "*",
    }

    # Active SSE sessions: session id -> queue of outgoing JSON-RPC messages.
    connections: Dict[str, asyncio.Queue] = {}
    # Strong references to in-flight request handlers. The event loop keeps
    # only weak references to tasks, so a task nobody references can be
    # garbage-collected before it finishes.
    background_tasks = set()

    async def sse_endpoint(request: Request):
        """Open an SSE stream for a new session and relay its queued messages."""
        connection_id = str(uuid.uuid4())
        queue: asyncio.Queue = asyncio.Queue()
        connections[connection_id] = queue
        logger.info("SSE connection established: %s", connection_id)

        async def event_stream():
            try:
                # Initial event tells the client where to POST its messages.
                yield f"event: endpoint\ndata: /messages?sessionId={connection_id}\n\n"
                while True:
                    try:
                        message = await asyncio.wait_for(queue.get(), timeout=30)
                        yield f"event: message\ndata: {json.dumps(message)}\n\n"
                    except asyncio.TimeoutError:
                        # SSE comment line sent after 30s of silence so idle
                        # connections are not dropped.
                        yield ": keep-alive\n\n"
            except asyncio.CancelledError:
                logger.info("SSE stream cancelled: %s", connection_id)
                # Re-raise so the server's cancellation of this stream completes
                # instead of being silently swallowed.
                raise
            finally:
                connections.pop(connection_id, None)
                logger.info("SSE connection closed: %s", connection_id)

        headers = {
            "Cache-Control": "no-cache, no-store, must-revalidate",
            "Connection": "keep-alive",
            # Asks nginx-style reverse proxies not to buffer the stream.
            "X-Accel-Buffering": "no",
            **cors_headers,
        }
        return StreamingResponse(event_stream(), media_type="text/event-stream", headers=headers)

    async def messages_endpoint(request: Request):
        """Accept one JSON-RPC message for a session and process it in the background.

        Returns 400 with a JSON-RPC error body for an unknown session, an
        unparseable body, or a body that is not a JSON object. Otherwise it
        returns ``{"ok": true}`` straight away. The real JSON-RPC response is
        pushed onto the session's SSE queue when the handler finishes.
        """
        # Bare OPTIONS requests that reach the route. Starlette's CORSMiddleware
        # answers standard preflights on its own.
        if request.method == "OPTIONS":
            return Response(status_code=200, headers=cors_headers)

        session_id = request.query_params.get("sessionId")
        if not session_id or session_id not in connections:
            logger.warning("Invalid session: %s", session_id)
            return JSONResponse(
                {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid session"}},
                status_code=400,
                headers=allow_origin_header,
            )

        try:
            body = await request.json()
        except Exception as e:
            logger.error("Failed to parse JSON body: %s", e)
            return JSONResponse(
                {"jsonrpc": "2.0", "error": {"code": -32700, "message": "Parse error"}},
                status_code=400,
                headers=allow_origin_header,
            )

        # Only single request objects are supported. A batch (array) or a bare
        # scalar would otherwise crash on ``body.get`` with an HTTP 500.
        if not isinstance(body, dict):
            logger.warning("Rejected non-object JSON-RPC payload: %s", type(body).__name__)
            return JSONResponse(
                {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid Request"}},
                status_code=400,
                headers=allow_origin_header,
            )

        method = body.get("method")
        # "params" may be missing or explicitly null; treat both as empty.
        params = body.get("params") or {}
        message_id = body.get("id")  # None for notifications
        logger.info("Received method=%s id=%s", method, message_id)
        # A message without an id is a notification and gets no response.
        is_notification = message_id is None

        async def handle_request():
            result: Optional[Dict[str, Any]] = None
            error: Optional[Dict[str, Any]] = None
            try:
                if method == "initialize":
                    result = {
                        "protocolVersion": "2024-11-05",
                        "serverInfo": {"name": "M3GNet-MCP", "version": "1.0.0"},
                        "capabilities": {"tools": {}},
                    }
                elif method == "tools/list":
                    # NOTE(review): reads FastMCP's private tool registry.
                    tools = []
                    for tool in service._tool_manager._tools.values():
                        tool_schema = getattr(tool, "parameters", None)
                        if tool_schema is None:
                            tool_schema = {"type": "object", "properties": {}}
                        tools.append({
                            "name": tool.name,
                            "description": tool.description or tool.name,
                            "inputSchema": tool_schema,
                        })
                    result = {"tools": tools}
                elif method == "tools/call":
                    tool_name = params.get("name")
                    # "arguments" may also be missing or null.
                    arguments = params.get("arguments") or {}
                    tool = service._tool_manager._tools.get(tool_name)
                    if tool is None:
                        raise ValueError(f"Unknown tool: {tool_name}")
                    logger.info("Invoking tool '%s'", tool_name)
                    if tool.is_async:
                        tool_response = await tool.fn(**arguments)
                    else:
                        # NOTE(review): a synchronous tool runs directly on the
                        # event loop and stalls every session (including
                        # keep-alives) until it returns. Consider offloading it
                        # to a worker thread once the tools are known to be
                        # thread-safe.
                        tool_response = tool.fn(**arguments)
                    result = {"content": [{"type": "text", "text": json.dumps(tool_response, default=str)}]}
                elif method and method.startswith("notifications/"):
                    # Notifications get no response.
                    logger.info("Handled notification: %s", method)
                    return
                else:
                    error = {"code": -32601, "message": f"Method not found: {method}"}
            except Exception as exc:
                logger.exception("Error handling request")
                error = {"code": -32000, "message": str(exc)}

            # Only requests get a response, and only while the session's
            # stream is still open.
            if is_notification or session_id not in connections:
                return
            if error is not None:
                response = {"jsonrpc": "2.0", "id": message_id, "error": error}
            else:
                response = {"jsonrpc": "2.0", "id": message_id, "result": result}
            await connections[session_id].put(response)

        # Process in the background and keep a reference until the task is done.
        task = asyncio.create_task(handle_request())
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)

        return JSONResponse({"ok": True}, headers=allow_origin_header)

    async def health(_: Request):
        """Health check endpoint."""
        return JSONResponse(
            {"status": "ok", "service": "M3GNet-MCP"},
            headers=allow_origin_header,
        )

    # Homepage HTML, built once. The sample client configuration uses the port
    # actually being served rather than a hard-coded 7860.
    home_page = f"""<!DOCTYPE html>
<html>
<head><title>M3GNet MCP</title></head>
<body>
<h1>M3GNet MCP Service</h1>
<p>Status: <span style="color:green">Running</span></p>
<p>SSE endpoint: <code>/sse</code></p>
<p>Messages endpoint: <code>/messages</code></p>
<h2>Cursor Configuration</h2>
<pre>
{{
  "mcpServers": {{
    "m3gnet": {{
      "url": "http://localhost:{port}/sse"
    }}
  }}
}}
</pre>
</body>
</html>"""

    async def home(_: Request):
        """Homepage with service info."""
        return HTMLResponse(home_page)

    # NOTE(review): allow_origins=["*"] combined with allow_credentials=True
    # makes Starlette echo any caller's Origin. Drop allow_credentials if
    # credentialed cross-origin requests are not needed.
    middleware = [
        Middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_methods=["*"],
            allow_headers=["*"],
            allow_credentials=True,
        )
    ]
    app = Starlette(
        routes=[
            Route("/", home, methods=["GET"]),
            Route("/health", health, methods=["GET"]),
            Route("/sse", sse_endpoint, methods=["GET"]),
            Route("/messages", messages_endpoint, methods=["GET", "POST", "OPTIONS"]),
        ],
        middleware=middleware,
    )
    uvicorn.run(app, host=host, port=port, log_level="info")
def main() -> None:
    """Parse CLI/environment configuration and start the selected transport.

    Each option falls back to an environment variable when it is not given on
    the command line:

    * ``MCP_MODE``: ``stdio`` or ``sse`` (default ``sse``)
    * ``MCP_HOST``: default ``0.0.0.0``
    * ``MCP_PORT``: default ``7860``

    Invalid values from either source are reported as an argparse usage error
    (exit status 2). They no longer produce a raw traceback or a silent
    fallback to SSE.
    """
    modes = ("stdio", "sse")
    parser = argparse.ArgumentParser(description="Run the M3GNet MCP service")
    parser.add_argument("--mode", choices=modes, default=os.environ.get("MCP_MODE", "sse"))
    parser.add_argument("--host", default=os.environ.get("MCP_HOST", "0.0.0.0"))
    # Pass the default as a string: argparse then applies ``type=int`` to it
    # and reports a malformed MCP_PORT as a usage error, not a ValueError.
    parser.add_argument("--port", type=int, default=os.environ.get("MCP_PORT", "7860"))
    args = parser.parse_args()
    # argparse does not check defaults against ``choices``, so a mode that
    # came from MCP_MODE must be validated here.
    if args.mode not in modes:
        parser.error(f"invalid MCP_MODE {args.mode!r} (choose from {', '.join(modes)})")
    if args.mode == "stdio":
        run_stdio()
    else:
        run_sse(args.host, args.port)


if __name__ == "__main__":
    main()