Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """Entry point for running the M3GNet FastMCP service.""" | |
| import argparse | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import sys | |
| import uuid | |
| from typing import Any, Dict, Optional | |
# Prepend this script's own directory to the import path so the sibling
# ``mcp_service`` module can be imported regardless of the working directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

# Process-wide logging: INFO and above, timestamped and tagged with the logger name.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("m3gnet-mcp-server")
def _load_service():
    """Import and return the FastMCP application object.

    The ``mcp`` object is imported from the sibling ``mcp_service`` module at
    call time rather than when this file is loaded. Any exception raised by
    that import is logged with its traceback and then re-raised unchanged.

    Returns:
        The ``mcp`` object exported by ``mcp_service``.
    """
    try:
        from mcp_service import mcp as service  # noqa: WPS433 (runtime import)
    except Exception as exc:  # noqa: BLE001
        logger.error("Failed to import MCP service: %s", exc, exc_info=True)
        raise
    return service
def run_stdio():
    """Serve the FastMCP application over standard input/output.

    Transport handling is delegated entirely to the service's own ``run()``.
    """
    mcp_app = _load_service()
    logger.info("Starting stdio server mode")
    mcp_app.run()
def run_sse(host: str, port: int):
    """Run the FastMCP service with a lightweight HTTP + Server-Sent Events transport.

    Routes served by the Starlette app:

    * ``GET /`` -- HTML page describing the service, with a client config snippet.
    * ``GET /health`` -- JSON health check.
    * ``GET /sse`` -- opens an SSE stream for a new session. The first event
      (``endpoint``) gives the path the client must POST its messages to; while
      the stream is idle a keep-alive comment is sent every 30 seconds.
    * ``POST /messages?sessionId=...`` -- accepts one JSON-RPC message for an
      open session. The HTTP reply is only an acknowledgement (``{"ok": true}``);
      the JSON-RPC response is pushed onto that session's SSE stream once the
      request has been handled.

    Only ``initialize``, ``tools/list``, ``tools/call`` and ``notifications/*``
    are implemented; any other method receives a ``-32601`` error.

    Args:
        host: Interface for uvicorn to bind to.
        port: TCP port to listen on; also shown in the config snippet on ``/``.

    Blocks until uvicorn exits. Propagates any exception from ``_load_service``.
    """
    import concurrent.futures
    import functools

    service = _load_service()
    logger.info("Starting SSE server on %s:%s", host, port)
    import uvicorn
    from starlette.applications import Starlette
    from starlette.middleware import Middleware
    from starlette.middleware.cors import CORSMiddleware
    from starlette.requests import Request
    from starlette.responses import HTMLResponse, JSONResponse, StreamingResponse, Response
    from starlette.routing import Route

    # Full CORS header set, attached to SSE streams and OPTIONS replies.
    cors_headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
        "Access-Control-Allow-Headers": "*",
    }
    # Origin-only header used on plain JSON replies.
    allow_origin = {"Access-Control-Allow-Origin": "*"}

    # Active SSE sessions: session id -> queue of JSON-RPC responses to stream.
    connections: Dict[str, asyncio.Queue] = {}
    # The event loop keeps only weak references to tasks, so fire-and-forget
    # handler tasks are held here until they finish; otherwise they could be
    # garbage collected before completing.
    background_tasks: set = set()
    # Synchronous tools run on this worker thread instead of on the event loop,
    # so a long computation no longer freezes SSE keep-alives, /health and other
    # sessions. A single worker keeps sync tool calls serialized, as they were
    # when they ran directly on the loop.
    tool_executor = concurrent.futures.ThreadPoolExecutor(
        max_workers=1, thread_name_prefix="m3gnet-tool"
    )

    def rpc_error(code: int, message: str):
        """Build a 400 JSON-RPC error reply (no request id) for the HTTP response."""
        return JSONResponse(
            {"jsonrpc": "2.0", "error": {"code": code, "message": message}},
            status_code=400,
            headers=allow_origin,
        )

    async def sse_endpoint(request: Request):
        """SSE endpoint - registers a new session and streams its queued responses."""
        connection_id = str(uuid.uuid4())
        queue: asyncio.Queue = asyncio.Queue()
        connections[connection_id] = queue
        logger.info("SSE connection established: %s", connection_id)

        async def event_stream():
            try:
                # Initial endpoint event tells the client where to POST messages.
                yield f"event: endpoint\ndata: /messages?sessionId={connection_id}\n\n"
                while True:
                    try:
                        message = await asyncio.wait_for(queue.get(), timeout=30)
                    except asyncio.TimeoutError:
                        # SSE comment line: keeps an idle connection from timing out.
                        yield ": keep-alive\n\n"
                        continue
                    yield f"event: message\ndata: {json.dumps(message)}\n\n"
            except asyncio.CancelledError:
                logger.info("SSE stream cancelled: %s", connection_id)
                # Re-raise so the server's cancellation (e.g. client disconnect)
                # completes instead of being silently swallowed.
                raise
            finally:
                connections.pop(connection_id, None)
                logger.info("SSE connection closed: %s", connection_id)

        headers = {
            "Cache-Control": "no-cache, no-store, must-revalidate",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
            **cors_headers,
        }
        return StreamingResponse(event_stream(), media_type="text/event-stream", headers=headers)

    async def messages_endpoint(request: Request):
        """Messages endpoint - accepts one JSON-RPC message and handles it in the background."""
        # Plain OPTIONS requests (CORS preflights are answered by CORSMiddleware).
        if request.method == "OPTIONS":
            return Response(status_code=200, headers=cors_headers)

        session_id = request.query_params.get("sessionId")
        if not session_id or session_id not in connections:
            logger.warning("Invalid session: %s", session_id)
            return rpc_error(-32600, "Invalid session")

        try:
            body = await request.json()
        except ValueError as exc:  # json.JSONDecodeError / UnicodeDecodeError
            logger.error("Failed to parse JSON body: %s", exc)
            return rpc_error(-32700, "Parse error")
        if not isinstance(body, dict):
            # JSON-RPC batches (arrays) and bare scalars are not supported.
            logger.warning("Unsupported JSON-RPC payload type: %s", type(body).__name__)
            return rpc_error(-32600, "Invalid Request")

        method = body.get("method")
        # ``"params": null`` is treated the same as an omitted field.
        params = body.get("params") or {}
        message_id = body.get("id")  # None for notifications
        logger.info("Received method=%s id=%s", method, message_id)
        is_notification = message_id is None

        async def handle_request():
            """Execute the method and queue the JSON-RPC response on the session stream."""
            result: Optional[Dict[str, Any]] = None
            error: Optional[Dict[str, Any]] = None
            try:
                if method == "initialize":
                    result = {
                        "protocolVersion": "2024-11-05",
                        "serverInfo": {"name": "M3GNet-MCP", "version": "1.0.0"},
                        "capabilities": {"tools": {}},
                    }
                elif method == "tools/list":
                    # NOTE(review): reads FastMCP's underscore-prefixed tool registry,
                    # which may change between FastMCP releases.
                    tools = []
                    for tool in service._tool_manager._tools.values():
                        tool_schema = getattr(tool, "parameters", None)
                        if tool_schema is None:
                            tool_schema = {"type": "object", "properties": {}}
                        tools.append({
                            "name": tool.name,
                            "description": tool.description or tool.name,
                            "inputSchema": tool_schema,
                        })
                    result = {"tools": tools}
                elif method == "tools/call":
                    tool_name = params.get("name")
                    # ``"arguments": null`` is treated the same as an omitted field.
                    arguments = params.get("arguments") or {}
                    tool = service._tool_manager._tools.get(tool_name)
                    if tool is None:
                        raise ValueError(f"Unknown tool: {tool_name}")
                    logger.info("Invoking tool '%s'", tool_name)
                    # NOTE(review): tool.fn is called directly; presumably this skips
                    # any argument validation FastMCP applies in its own call path.
                    if tool.is_async:
                        tool_response = await tool.fn(**arguments)
                    else:
                        loop = asyncio.get_running_loop()
                        tool_response = await loop.run_in_executor(
                            tool_executor, functools.partial(tool.fn, **arguments)
                        )
                    result = {"content": [{"type": "text", "text": json.dumps(tool_response, default=str)}]}
                elif isinstance(method, str) and method.startswith("notifications/"):
                    # Notifications do not require a response per JSON-RPC spec.
                    logger.info("Handled notification: %s", method)
                    return
                else:
                    error = {"code": -32601, "message": f"Method not found: {method}"}
            except Exception as exc:
                logger.exception("Error handling request")
                error = {"code": -32000, "message": str(exc)}

            # Notifications never get a reply; the session may also have closed meanwhile.
            if is_notification or session_id not in connections:
                return
            response: Dict[str, Any] = {"jsonrpc": "2.0", "id": message_id}
            if error:
                response["error"] = error
            else:
                response["result"] = result
            await connections[session_id].put(response)

        # Handle in the background; keep a strong reference until the task is done.
        task = asyncio.create_task(handle_request())
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)
        return JSONResponse({"ok": True}, headers=allow_origin)

    async def health(_: Request):
        """Health check endpoint."""
        return JSONResponse({"status": "ok", "service": "M3GNet-MCP"}, headers=allow_origin)

    # Rendered once; the snippet advertises the port the server actually listens on.
    home_page = """<!DOCTYPE html>
<html>
<head><title>M3GNet MCP</title></head>
<body>
<h1>M3GNet MCP Service</h1>
<p>Status: <span style="color:green">Running</span></p>
<p>SSE endpoint: <code>/sse</code></p>
<p>Messages endpoint: <code>/messages</code></p>
<h2>Cursor Configuration</h2>
<pre>
{
"mcpServers": {
"m3gnet": {
"url": "http://localhost:%s/sse"
}
}
}
</pre>
</body>
</html>""" % port

    async def home(_: Request):
        """Homepage with service info."""
        return HTMLResponse(home_page)

    middleware = [
        Middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_methods=["*"],
            allow_headers=["*"],
            # NOTE(review): wildcard origins together with allow_credentials=True is
            # a risky CORS combination (see Starlette CORSMiddleware docs); this app
            # has no cookie-based auth, so credentials are likely unnecessary.
            allow_credentials=True,
        )
    ]
    app = Starlette(
        routes=[
            Route("/", home, methods=["GET"]),
            Route("/health", health, methods=["GET"]),
            Route("/sse", sse_endpoint, methods=["GET"]),
            Route("/messages", messages_endpoint, methods=["GET", "POST", "OPTIONS"]),
        ],
        middleware=middleware,
    )
    try:
        uvicorn.run(app, host=host, port=port, log_level="info")
    finally:
        # Stop accepting tool work; do not block shutdown on a running tool.
        tool_executor.shutdown(wait=False)
def main(argv: Optional[list] = None) -> None:
    """Parse options and start the selected MCP transport.

    Each option falls back to an environment variable when it is not given on
    the command line: ``MCP_MODE`` (default ``sse``), ``MCP_HOST`` (default
    ``0.0.0.0``) and ``MCP_PORT`` (default ``7860``). The mode is matched
    case-insensitively.

    Args:
        argv: Arguments to parse; ``None`` (the default) means ``sys.argv[1:]``.

    Exits with status 2 (via ``argparse``) when the mode or the port is invalid,
    whether it came from the command line or the environment.
    """
    parser = argparse.ArgumentParser(description="Run the M3GNet MCP service")
    parser.add_argument(
        "--mode",
        type=str.lower,
        choices=["stdio", "sse"],
        default=os.environ.get("MCP_MODE", "sse"),
        help="transport to serve (env: MCP_MODE)",
    )
    parser.add_argument(
        "--host",
        default=os.environ.get("MCP_HOST", "0.0.0.0"),
        help="interface to bind in SSE mode (env: MCP_HOST)",
    )
    # argparse runs string defaults through ``type``, so a malformed MCP_PORT
    # produces a usage error instead of a ValueError traceback before parsing.
    parser.add_argument(
        "--port",
        type=int,
        default=os.environ.get("MCP_PORT", "7860"),
        help="port to listen on in SSE mode (env: MCP_PORT)",
    )
    args = parser.parse_args(argv)

    # argparse does not check ``choices`` against default values, so a bad
    # MCP_MODE would otherwise fall through silently to SSE mode.
    if args.mode not in ("stdio", "sse"):
        parser.error(f"invalid MCP_MODE {args.mode!r} (choose from 'stdio', 'sse')")

    if args.mode == "stdio":
        run_stdio()
    else:
        run_sse(args.host, args.port)
if __name__ == "__main__":
    # Only start a server when executed as a script, never on import.
    main()