Spaces:
Sleeping
Sleeping
Update start_mcp.py
Browse files- start_mcp.py +125 -52
start_mcp.py
CHANGED
|
@@ -9,7 +9,7 @@ import os
|
|
| 9 |
import sys
|
| 10 |
import uuid
|
| 11 |
|
| 12 |
-
from typing import Any, Dict
|
| 13 |
|
| 14 |
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 15 |
|
|
@@ -20,7 +20,6 @@ logger = logging.getLogger("m3gnet-mcp-server")
|
|
| 20 |
def _load_service():
|
| 21 |
try:
|
| 22 |
from mcp_service import mcp # noqa: WPS433 (runtime import)
|
| 23 |
-
|
| 24 |
return mcp
|
| 25 |
except Exception as exc: # noqa: BLE001
|
| 26 |
logger.error("Failed to import MCP service: %s", exc, exc_info=True)
|
|
@@ -29,7 +28,6 @@ def _load_service():
|
|
| 29 |
|
| 30 |
def run_stdio():
|
| 31 |
"""Run the FastMCP service over stdio."""
|
| 32 |
-
|
| 33 |
service = _load_service()
|
| 34 |
logger.info("Starting stdio server mode")
|
| 35 |
service.run()
|
|
@@ -37,19 +35,22 @@ def run_stdio():
|
|
| 37 |
|
| 38 |
def run_sse(host: str, port: int):
|
| 39 |
"""Run the FastMCP service with a lightweight SSE transport."""
|
| 40 |
-
|
| 41 |
service = _load_service()
|
| 42 |
logger.info("Starting SSE server on %s:%s", host, port)
|
| 43 |
|
| 44 |
-
import uvicorn
|
| 45 |
from starlette.applications import Starlette
|
|
|
|
|
|
|
| 46 |
from starlette.requests import Request
|
| 47 |
-
from starlette.responses import HTMLResponse, JSONResponse, StreamingResponse
|
| 48 |
from starlette.routing import Route
|
| 49 |
|
|
|
|
| 50 |
connections: Dict[str, asyncio.Queue] = {}
|
| 51 |
|
| 52 |
-
async def sse_endpoint(request: Request):
|
|
|
|
| 53 |
connection_id = str(uuid.uuid4())
|
| 54 |
queue: asyncio.Queue = asyncio.Queue()
|
| 55 |
connections[connection_id] = queue
|
|
@@ -57,13 +58,17 @@ def run_sse(host: str, port: int):
|
|
| 57 |
|
| 58 |
async def event_stream():
|
| 59 |
try:
|
|
|
|
| 60 |
yield f"event: endpoint\ndata: /messages?sessionId={connection_id}\n\n"
|
| 61 |
while True:
|
| 62 |
try:
|
| 63 |
-
message = await asyncio.wait_for(queue.get(), timeout=
|
| 64 |
yield f"event: message\ndata: {json.dumps(message)}\n\n"
|
| 65 |
except asyncio.TimeoutError:
|
|
|
|
| 66 |
yield ": keep-alive\n\n"
|
|
|
|
|
|
|
| 67 |
finally:
|
| 68 |
connections.pop(connection_id, None)
|
| 69 |
logger.info("SSE connection closed: %s", connection_id)
|
|
@@ -73,47 +78,79 @@ def run_sse(host: str, port: int):
|
|
| 73 |
"Connection": "keep-alive",
|
| 74 |
"X-Accel-Buffering": "no",
|
| 75 |
"Access-Control-Allow-Origin": "*",
|
|
|
|
|
|
|
| 76 |
}
|
| 77 |
return StreamingResponse(event_stream(), media_type="text/event-stream", headers=headers)
|
| 78 |
|
| 79 |
-
async def messages_endpoint(request: Request):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
session_id = request.query_params.get("sessionId")
|
| 81 |
-
if session_id not in connections:
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
-
body = await request.json()
|
| 85 |
method = body.get("method")
|
| 86 |
params = body.get("params", {})
|
| 87 |
-
message_id = body.get("id")
|
| 88 |
|
| 89 |
logger.info("Received method=%s id=%s", method, message_id)
|
| 90 |
|
| 91 |
-
#
|
| 92 |
-
is_notification =
|
| 93 |
|
| 94 |
async def handle_request():
|
|
|
|
|
|
|
|
|
|
| 95 |
try:
|
| 96 |
if method == "initialize":
|
| 97 |
-
result
|
| 98 |
"protocolVersion": "2024-11-05",
|
| 99 |
"serverInfo": {"name": "M3GNet-MCP", "version": "1.0.0"},
|
| 100 |
"capabilities": {"tools": {}},
|
| 101 |
}
|
| 102 |
elif method == "tools/list":
|
| 103 |
tools = []
|
| 104 |
-
for tool in service._tool_manager._tools.values():
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
|
|
|
| 112 |
result = {"tools": tools}
|
| 113 |
elif method == "tools/call":
|
| 114 |
tool_name = params.get("name")
|
| 115 |
arguments = params.get("arguments", {})
|
| 116 |
-
tool = service._tool_manager._tools.get(tool_name)
|
| 117 |
if tool is None:
|
| 118 |
raise ValueError(f"Unknown tool: {tool_name}")
|
| 119 |
logger.info("Invoking tool '%s'", tool_name)
|
|
@@ -121,48 +158,84 @@ def run_sse(host: str, port: int):
|
|
| 121 |
tool_response = await tool.fn(**arguments)
|
| 122 |
else:
|
| 123 |
tool_response = tool.fn(**arguments)
|
| 124 |
-
result = {"content": [{"type": "text", "text": json.dumps(tool_response)}]}
|
| 125 |
-
elif
|
| 126 |
# Notifications do not require a response per JSON-RPC spec
|
| 127 |
logger.info("Handled notification: %s", method)
|
| 128 |
-
return
|
| 129 |
else:
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
# Only send response if message_id is present (i.e., it's a request, not a notification)
|
| 133 |
-
if message_id is not None:
|
| 134 |
-
response = {"jsonrpc": "2.0", "id": message_id, "result": result}
|
| 135 |
-
await connections[session_id].put(response)
|
| 136 |
|
| 137 |
-
except Exception as exc:
|
| 138 |
logger.exception("Error handling request")
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
|
|
|
| 147 |
|
|
|
|
| 148 |
asyncio.create_task(handle_request())
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
async def health(_: Request):
|
| 152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
async def home(_: Request):
|
|
|
|
| 155 |
return HTMLResponse("""<!DOCTYPE html>
|
| 156 |
-
<html>
|
| 157 |
-
<
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
app = Starlette(
|
| 160 |
routes=[
|
| 161 |
-
Route("/", home),
|
| 162 |
-
Route("/health", health),
|
| 163 |
-
Route("/sse", sse_endpoint),
|
| 164 |
-
Route("/messages", messages_endpoint, methods=["POST"]),
|
| 165 |
-
]
|
|
|
|
| 166 |
)
|
| 167 |
|
| 168 |
uvicorn.run(app, host=host, port=port, log_level="info")
|
|
|
|
| 9 |
import sys
|
| 10 |
import uuid
|
| 11 |
|
| 12 |
+
from typing import Any, Dict, Optional
|
| 13 |
|
| 14 |
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 15 |
|
|
|
|
| 20 |
def _load_service():
|
| 21 |
try:
|
| 22 |
from mcp_service import mcp # noqa: WPS433 (runtime import)
|
|
|
|
| 23 |
return mcp
|
| 24 |
except Exception as exc: # noqa: BLE001
|
| 25 |
logger.error("Failed to import MCP service: %s", exc, exc_info=True)
|
|
|
|
| 28 |
|
| 29 |
def run_stdio():
|
| 30 |
"""Run the FastMCP service over stdio."""
|
|
|
|
| 31 |
service = _load_service()
|
| 32 |
logger.info("Starting stdio server mode")
|
| 33 |
service.run()
|
|
|
|
| 35 |
|
| 36 |
def run_sse(host: str, port: int):
|
| 37 |
"""Run the FastMCP service with a lightweight SSE transport."""
|
|
|
|
| 38 |
service = _load_service()
|
| 39 |
logger.info("Starting SSE server on %s:%s", host, port)
|
| 40 |
|
| 41 |
+
import uvicorn
|
| 42 |
from starlette.applications import Starlette
|
| 43 |
+
from starlette.middleware import Middleware
|
| 44 |
+
from starlette.middleware.cors import CORSMiddleware
|
| 45 |
from starlette.requests import Request
|
| 46 |
+
from starlette.responses import HTMLResponse, JSONResponse, StreamingResponse, Response
|
| 47 |
from starlette.routing import Route
|
| 48 |
|
| 49 |
+
# Store active SSE connections
|
| 50 |
connections: Dict[str, asyncio.Queue] = {}
|
| 51 |
|
| 52 |
+
async def sse_endpoint(request: Request):
|
| 53 |
+
"""SSE endpoint - establishes connection and sends events."""
|
| 54 |
connection_id = str(uuid.uuid4())
|
| 55 |
queue: asyncio.Queue = asyncio.Queue()
|
| 56 |
connections[connection_id] = queue
|
|
|
|
| 58 |
|
| 59 |
async def event_stream():
|
| 60 |
try:
|
| 61 |
+
# Send initial endpoint event telling client where to POST messages
|
| 62 |
yield f"event: endpoint\ndata: /messages?sessionId={connection_id}\n\n"
|
| 63 |
while True:
|
| 64 |
try:
|
| 65 |
+
message = await asyncio.wait_for(queue.get(), timeout=30)
|
| 66 |
yield f"event: message\ndata: {json.dumps(message)}\n\n"
|
| 67 |
except asyncio.TimeoutError:
|
| 68 |
+
# Send keep-alive comment to prevent connection timeout
|
| 69 |
yield ": keep-alive\n\n"
|
| 70 |
+
except asyncio.CancelledError:
|
| 71 |
+
logger.info("SSE stream cancelled: %s", connection_id)
|
| 72 |
finally:
|
| 73 |
connections.pop(connection_id, None)
|
| 74 |
logger.info("SSE connection closed: %s", connection_id)
|
|
|
|
| 78 |
"Connection": "keep-alive",
|
| 79 |
"X-Accel-Buffering": "no",
|
| 80 |
"Access-Control-Allow-Origin": "*",
|
| 81 |
+
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
|
| 82 |
+
"Access-Control-Allow-Headers": "*",
|
| 83 |
}
|
| 84 |
return StreamingResponse(event_stream(), media_type="text/event-stream", headers=headers)
|
| 85 |
|
| 86 |
+
async def messages_endpoint(request: Request):
|
| 87 |
+
"""Messages endpoint - receives MCP protocol messages from client."""
|
| 88 |
+
# Handle CORS preflight
|
| 89 |
+
if request.method == "OPTIONS":
|
| 90 |
+
return Response(
|
| 91 |
+
status_code=200,
|
| 92 |
+
headers={
|
| 93 |
+
"Access-Control-Allow-Origin": "*",
|
| 94 |
+
"Access-Control-Allow-Methods": "GET, POST, OPTIONS",
|
| 95 |
+
"Access-Control-Allow-Headers": "*",
|
| 96 |
+
}
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
session_id = request.query_params.get("sessionId")
|
| 100 |
+
if not session_id or session_id not in connections:
|
| 101 |
+
logger.warning("Invalid session: %s", session_id)
|
| 102 |
+
return JSONResponse(
|
| 103 |
+
{"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid session"}},
|
| 104 |
+
status_code=400,
|
| 105 |
+
headers={"Access-Control-Allow-Origin": "*"}
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
try:
|
| 109 |
+
body = await request.json()
|
| 110 |
+
except Exception as e:
|
| 111 |
+
logger.error("Failed to parse JSON body: %s", e)
|
| 112 |
+
return JSONResponse(
|
| 113 |
+
{"jsonrpc": "2.0", "error": {"code": -32700, "message": "Parse error"}},
|
| 114 |
+
status_code=400,
|
| 115 |
+
headers={"Access-Control-Allow-Origin": "*"}
|
| 116 |
+
)
|
| 117 |
|
|
|
|
| 118 |
method = body.get("method")
|
| 119 |
params = body.get("params", {})
|
| 120 |
+
message_id = body.get("id") # Can be None for notifications
|
| 121 |
|
| 122 |
logger.info("Received method=%s id=%s", method, message_id)
|
| 123 |
|
| 124 |
+
# Check if this is a notification (no id field means notification)
|
| 125 |
+
is_notification = message_id is None
|
| 126 |
|
| 127 |
async def handle_request():
|
| 128 |
+
result: Optional[Dict[str, Any]] = None
|
| 129 |
+
error: Optional[Dict[str, Any]] = None
|
| 130 |
+
|
| 131 |
try:
|
| 132 |
if method == "initialize":
|
| 133 |
+
result = {
|
| 134 |
"protocolVersion": "2024-11-05",
|
| 135 |
"serverInfo": {"name": "M3GNet-MCP", "version": "1.0.0"},
|
| 136 |
"capabilities": {"tools": {}},
|
| 137 |
}
|
| 138 |
elif method == "tools/list":
|
| 139 |
tools = []
|
| 140 |
+
for tool in service._tool_manager._tools.values():
|
| 141 |
+
tool_schema = getattr(tool, "parameters", None)
|
| 142 |
+
if tool_schema is None:
|
| 143 |
+
tool_schema = {"type": "object", "properties": {}}
|
| 144 |
+
tools.append({
|
| 145 |
+
"name": tool.name,
|
| 146 |
+
"description": tool.description or tool.name,
|
| 147 |
+
"inputSchema": tool_schema,
|
| 148 |
+
})
|
| 149 |
result = {"tools": tools}
|
| 150 |
elif method == "tools/call":
|
| 151 |
tool_name = params.get("name")
|
| 152 |
arguments = params.get("arguments", {})
|
| 153 |
+
tool = service._tool_manager._tools.get(tool_name)
|
| 154 |
if tool is None:
|
| 155 |
raise ValueError(f"Unknown tool: {tool_name}")
|
| 156 |
logger.info("Invoking tool '%s'", tool_name)
|
|
|
|
| 158 |
tool_response = await tool.fn(**arguments)
|
| 159 |
else:
|
| 160 |
tool_response = tool.fn(**arguments)
|
| 161 |
+
result = {"content": [{"type": "text", "text": json.dumps(tool_response, default=str)}]}
|
| 162 |
+
elif method and method.startswith("notifications/"):
|
| 163 |
# Notifications do not require a response per JSON-RPC spec
|
| 164 |
logger.info("Handled notification: %s", method)
|
| 165 |
+
return # Exit without sending response
|
| 166 |
else:
|
| 167 |
+
error = {"code": -32601, "message": f"Method not found: {method}"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
+
except Exception as exc:
|
| 170 |
logger.exception("Error handling request")
|
| 171 |
+
error = {"code": -32000, "message": str(exc)}
|
| 172 |
+
|
| 173 |
+
# Only send response for requests (not notifications)
|
| 174 |
+
if not is_notification and session_id in connections:
|
| 175 |
+
if error:
|
| 176 |
+
response = {"jsonrpc": "2.0", "id": message_id, "error": error}
|
| 177 |
+
else:
|
| 178 |
+
response = {"jsonrpc": "2.0", "id": message_id, "result": result}
|
| 179 |
+
await connections[session_id].put(response)
|
| 180 |
|
| 181 |
+
# Run handler in background
|
| 182 |
asyncio.create_task(handle_request())
|
| 183 |
+
|
| 184 |
+
# Return accepted response
|
| 185 |
+
return JSONResponse(
|
| 186 |
+
{"ok": True},
|
| 187 |
+
headers={"Access-Control-Allow-Origin": "*"}
|
| 188 |
+
)
|
| 189 |
|
| 190 |
async def health(_: Request):
|
| 191 |
+
"""Health check endpoint."""
|
| 192 |
+
return JSONResponse(
|
| 193 |
+
{"status": "ok", "service": "M3GNet-MCP"},
|
| 194 |
+
headers={"Access-Control-Allow-Origin": "*"}
|
| 195 |
+
)
|
| 196 |
|
| 197 |
async def home(_: Request):
|
| 198 |
+
"""Homepage with service info."""
|
| 199 |
return HTMLResponse("""<!DOCTYPE html>
|
| 200 |
+
<html>
|
| 201 |
+
<head><title>M3GNet MCP</title></head>
|
| 202 |
+
<body>
|
| 203 |
+
<h1>M3GNet MCP Service</h1>
|
| 204 |
+
<p>Status: <span style="color:green">Running</span></p>
|
| 205 |
+
<p>SSE endpoint: <code>/sse</code></p>
|
| 206 |
+
<p>Messages endpoint: <code>/messages</code></p>
|
| 207 |
+
<h2>Cursor Configuration</h2>
|
| 208 |
+
<pre>
|
| 209 |
+
{
|
| 210 |
+
"mcpServers": {
|
| 211 |
+
"m3gnet": {
|
| 212 |
+
"url": "http://localhost:7860/sse"
|
| 213 |
+
}
|
| 214 |
+
}
|
| 215 |
+
}
|
| 216 |
+
</pre>
|
| 217 |
+
</body>
|
| 218 |
+
</html>""")
|
| 219 |
+
|
| 220 |
+
# Create app with CORS middleware
|
| 221 |
+
middleware = [
|
| 222 |
+
Middleware(
|
| 223 |
+
CORSMiddleware,
|
| 224 |
+
allow_origins=["*"],
|
| 225 |
+
allow_methods=["*"],
|
| 226 |
+
allow_headers=["*"],
|
| 227 |
+
allow_credentials=True,
|
| 228 |
+
)
|
| 229 |
+
]
|
| 230 |
|
| 231 |
app = Starlette(
|
| 232 |
routes=[
|
| 233 |
+
Route("/", home, methods=["GET"]),
|
| 234 |
+
Route("/health", health, methods=["GET"]),
|
| 235 |
+
Route("/sse", sse_endpoint, methods=["GET"]),
|
| 236 |
+
Route("/messages", messages_endpoint, methods=["GET", "POST", "OPTIONS"]),
|
| 237 |
+
],
|
| 238 |
+
middleware=middleware,
|
| 239 |
)
|
| 240 |
|
| 241 |
uvicorn.run(app, host=host, port=port, log_level="info")
|