SEUyishu commited on
Commit
6d298cf
·
verified ·
1 Parent(s): d6de948

Update start_mcp.py

Browse files
Files changed (1) hide show
  1. start_mcp.py +125 -52
start_mcp.py CHANGED
@@ -9,7 +9,7 @@ import os
9
  import sys
10
  import uuid
11
 
12
- from typing import Any, Dict
13
 
14
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
15
 
@@ -20,7 +20,6 @@ logger = logging.getLogger("m3gnet-mcp-server")
20
  def _load_service():
21
  try:
22
  from mcp_service import mcp # noqa: WPS433 (runtime import)
23
-
24
  return mcp
25
  except Exception as exc: # noqa: BLE001
26
  logger.error("Failed to import MCP service: %s", exc, exc_info=True)
@@ -29,7 +28,6 @@ def _load_service():
29
 
30
  def run_stdio():
31
  """Run the FastMCP service over stdio."""
32
-
33
  service = _load_service()
34
  logger.info("Starting stdio server mode")
35
  service.run()
@@ -37,19 +35,22 @@ def run_stdio():
37
 
38
  def run_sse(host: str, port: int):
39
  """Run the FastMCP service with a lightweight SSE transport."""
40
-
41
  service = _load_service()
42
  logger.info("Starting SSE server on %s:%s", host, port)
43
 
44
- import uvicorn # Imported lazily to keep stdio mode lightweight
45
  from starlette.applications import Starlette
 
 
46
  from starlette.requests import Request
47
- from starlette.responses import HTMLResponse, JSONResponse, StreamingResponse
48
  from starlette.routing import Route
49
 
 
50
  connections: Dict[str, asyncio.Queue] = {}
51
 
52
- async def sse_endpoint(request: Request): # noqa: WPS430
 
53
  connection_id = str(uuid.uuid4())
54
  queue: asyncio.Queue = asyncio.Queue()
55
  connections[connection_id] = queue
@@ -57,13 +58,17 @@ def run_sse(host: str, port: int):
57
 
58
  async def event_stream():
59
  try:
 
60
  yield f"event: endpoint\ndata: /messages?sessionId={connection_id}\n\n"
61
  while True:
62
  try:
63
- message = await asyncio.wait_for(queue.get(), timeout=15)
64
  yield f"event: message\ndata: {json.dumps(message)}\n\n"
65
  except asyncio.TimeoutError:
 
66
  yield ": keep-alive\n\n"
 
 
67
  finally:
68
  connections.pop(connection_id, None)
69
  logger.info("SSE connection closed: %s", connection_id)
@@ -73,47 +78,79 @@ def run_sse(host: str, port: int):
73
  "Connection": "keep-alive",
74
  "X-Accel-Buffering": "no",
75
  "Access-Control-Allow-Origin": "*",
 
 
76
  }
77
  return StreamingResponse(event_stream(), media_type="text/event-stream", headers=headers)
78
 
79
- async def messages_endpoint(request: Request): # noqa: WPS430
 
 
 
 
 
 
 
 
 
 
 
 
80
  session_id = request.query_params.get("sessionId")
81
- if session_id not in connections:
82
- return JSONResponse({"error": "invalid session"}, status_code=400)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- body = await request.json()
85
  method = body.get("method")
86
  params = body.get("params", {})
87
- message_id = body.get("id")
88
 
89
  logger.info("Received method=%s id=%s", method, message_id)
90
 
91
- # Notifications (methods starting with "notifications/") do not expect a response
92
- is_notification = method and method.startswith("notifications/")
93
 
94
  async def handle_request():
 
 
 
95
  try:
96
  if method == "initialize":
97
- result: Dict[str, Any] = {
98
  "protocolVersion": "2024-11-05",
99
  "serverInfo": {"name": "M3GNet-MCP", "version": "1.0.0"},
100
  "capabilities": {"tools": {}},
101
  }
102
  elif method == "tools/list":
103
  tools = []
104
- for tool in service._tool_manager._tools.values(): # noqa: SLF001 (internal attribute)
105
- tools.append(
106
- {
107
- "name": tool.name,
108
- "description": tool.description or tool.name,
109
- "inputSchema": getattr(tool, "parameters", {"type": "object", "properties": {}}),
110
- }
111
- )
 
112
  result = {"tools": tools}
113
  elif method == "tools/call":
114
  tool_name = params.get("name")
115
  arguments = params.get("arguments", {})
116
- tool = service._tool_manager._tools.get(tool_name) # noqa: SLF001
117
  if tool is None:
118
  raise ValueError(f"Unknown tool: {tool_name}")
119
  logger.info("Invoking tool '%s'", tool_name)
@@ -121,48 +158,84 @@ def run_sse(host: str, port: int):
121
  tool_response = await tool.fn(**arguments)
122
  else:
123
  tool_response = tool.fn(**arguments)
124
- result = {"content": [{"type": "text", "text": json.dumps(tool_response)}]}
125
- elif is_notification:
126
  # Notifications do not require a response per JSON-RPC spec
127
  logger.info("Handled notification: %s", method)
128
- return
129
  else:
130
- raise ValueError(f"Unsupported method: {method}")
131
-
132
- # Only send response if message_id is present (i.e., it's a request, not a notification)
133
- if message_id is not None:
134
- response = {"jsonrpc": "2.0", "id": message_id, "result": result}
135
- await connections[session_id].put(response)
136
 
137
- except Exception as exc: # noqa: BLE001
138
  logger.exception("Error handling request")
139
- # Only send error response if we have a valid message_id
140
- if message_id is not None:
141
- response = {
142
- "jsonrpc": "2.0",
143
- "id": message_id,
144
- "error": {"code": -32000, "message": str(exc)},
145
- }
146
- await connections[session_id].put(response)
 
147
 
 
148
  asyncio.create_task(handle_request())
149
- return JSONResponse({"ok": True})
 
 
 
 
 
150
 
151
  async def health(_: Request):
152
- return JSONResponse({"status": "ok", "service": "M3GNet-MCP"})
 
 
 
 
153
 
154
  async def home(_: Request):
 
155
  return HTMLResponse("""<!DOCTYPE html>
156
- <html><head><title>M3GNet MCP</title></head>
157
- <body><h1>M3GNet MCP Service</h1><p>SSE endpoint: <code>/sse</code></p></body></html>""")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
  app = Starlette(
160
  routes=[
161
- Route("/", home),
162
- Route("/health", health),
163
- Route("/sse", sse_endpoint),
164
- Route("/messages", messages_endpoint, methods=["POST"]),
165
- ]
 
166
  )
167
 
168
  uvicorn.run(app, host=host, port=port, log_level="info")
 
9
  import sys
10
  import uuid
11
 
12
+ from typing import Any, Dict, Optional
13
 
14
  sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
15
 
 
20
  def _load_service():
21
  try:
22
  from mcp_service import mcp # noqa: WPS433 (runtime import)
 
23
  return mcp
24
  except Exception as exc: # noqa: BLE001
25
  logger.error("Failed to import MCP service: %s", exc, exc_info=True)
 
28
 
29
  def run_stdio():
30
  """Run the FastMCP service over stdio."""
 
31
  service = _load_service()
32
  logger.info("Starting stdio server mode")
33
  service.run()
 
35
 
36
  def run_sse(host: str, port: int):
37
  """Run the FastMCP service with a lightweight SSE transport."""
 
38
  service = _load_service()
39
  logger.info("Starting SSE server on %s:%s", host, port)
40
 
41
+ import uvicorn
42
  from starlette.applications import Starlette
43
+ from starlette.middleware import Middleware
44
+ from starlette.middleware.cors import CORSMiddleware
45
  from starlette.requests import Request
46
+ from starlette.responses import HTMLResponse, JSONResponse, StreamingResponse, Response
47
  from starlette.routing import Route
48
 
49
+ # Store active SSE connections
50
  connections: Dict[str, asyncio.Queue] = {}
51
 
52
+ async def sse_endpoint(request: Request):
53
+ """SSE endpoint - establishes connection and sends events."""
54
  connection_id = str(uuid.uuid4())
55
  queue: asyncio.Queue = asyncio.Queue()
56
  connections[connection_id] = queue
 
58
 
59
  async def event_stream():
60
  try:
61
+ # Send initial endpoint event telling client where to POST messages
62
  yield f"event: endpoint\ndata: /messages?sessionId={connection_id}\n\n"
63
  while True:
64
  try:
65
+ message = await asyncio.wait_for(queue.get(), timeout=30)
66
  yield f"event: message\ndata: {json.dumps(message)}\n\n"
67
  except asyncio.TimeoutError:
68
+ # Send keep-alive comment to prevent connection timeout
69
  yield ": keep-alive\n\n"
70
+ except asyncio.CancelledError:
71
+ logger.info("SSE stream cancelled: %s", connection_id)
72
  finally:
73
  connections.pop(connection_id, None)
74
  logger.info("SSE connection closed: %s", connection_id)
 
78
  "Connection": "keep-alive",
79
  "X-Accel-Buffering": "no",
80
  "Access-Control-Allow-Origin": "*",
81
+ "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
82
+ "Access-Control-Allow-Headers": "*",
83
  }
84
  return StreamingResponse(event_stream(), media_type="text/event-stream", headers=headers)
85
 
86
+ async def messages_endpoint(request: Request):
87
+ """Messages endpoint - receives MCP protocol messages from client."""
88
+ # Handle CORS preflight
89
+ if request.method == "OPTIONS":
90
+ return Response(
91
+ status_code=200,
92
+ headers={
93
+ "Access-Control-Allow-Origin": "*",
94
+ "Access-Control-Allow-Methods": "GET, POST, OPTIONS",
95
+ "Access-Control-Allow-Headers": "*",
96
+ }
97
+ )
98
+
99
  session_id = request.query_params.get("sessionId")
100
+ if not session_id or session_id not in connections:
101
+ logger.warning("Invalid session: %s", session_id)
102
+ return JSONResponse(
103
+ {"jsonrpc": "2.0", "error": {"code": -32600, "message": "Invalid session"}},
104
+ status_code=400,
105
+ headers={"Access-Control-Allow-Origin": "*"}
106
+ )
107
+
108
+ try:
109
+ body = await request.json()
110
+ except Exception as e:
111
+ logger.error("Failed to parse JSON body: %s", e)
112
+ return JSONResponse(
113
+ {"jsonrpc": "2.0", "error": {"code": -32700, "message": "Parse error"}},
114
+ status_code=400,
115
+ headers={"Access-Control-Allow-Origin": "*"}
116
+ )
117
 
 
118
  method = body.get("method")
119
  params = body.get("params", {})
120
+ message_id = body.get("id") # Can be None for notifications
121
 
122
  logger.info("Received method=%s id=%s", method, message_id)
123
 
124
+ # Check if this is a notification (no id field means notification)
125
+ is_notification = message_id is None
126
 
127
  async def handle_request():
128
+ result: Optional[Dict[str, Any]] = None
129
+ error: Optional[Dict[str, Any]] = None
130
+
131
  try:
132
  if method == "initialize":
133
+ result = {
134
  "protocolVersion": "2024-11-05",
135
  "serverInfo": {"name": "M3GNet-MCP", "version": "1.0.0"},
136
  "capabilities": {"tools": {}},
137
  }
138
  elif method == "tools/list":
139
  tools = []
140
+ for tool in service._tool_manager._tools.values():
141
+ tool_schema = getattr(tool, "parameters", None)
142
+ if tool_schema is None:
143
+ tool_schema = {"type": "object", "properties": {}}
144
+ tools.append({
145
+ "name": tool.name,
146
+ "description": tool.description or tool.name,
147
+ "inputSchema": tool_schema,
148
+ })
149
  result = {"tools": tools}
150
  elif method == "tools/call":
151
  tool_name = params.get("name")
152
  arguments = params.get("arguments", {})
153
+ tool = service._tool_manager._tools.get(tool_name)
154
  if tool is None:
155
  raise ValueError(f"Unknown tool: {tool_name}")
156
  logger.info("Invoking tool '%s'", tool_name)
 
158
  tool_response = await tool.fn(**arguments)
159
  else:
160
  tool_response = tool.fn(**arguments)
161
+ result = {"content": [{"type": "text", "text": json.dumps(tool_response, default=str)}]}
162
+ elif method and method.startswith("notifications/"):
163
  # Notifications do not require a response per JSON-RPC spec
164
  logger.info("Handled notification: %s", method)
165
+ return # Exit without sending response
166
  else:
167
+ error = {"code": -32601, "message": f"Method not found: {method}"}
 
 
 
 
 
168
 
169
+ except Exception as exc:
170
  logger.exception("Error handling request")
171
+ error = {"code": -32000, "message": str(exc)}
172
+
173
+ # Only send response for requests (not notifications)
174
+ if not is_notification and session_id in connections:
175
+ if error:
176
+ response = {"jsonrpc": "2.0", "id": message_id, "error": error}
177
+ else:
178
+ response = {"jsonrpc": "2.0", "id": message_id, "result": result}
179
+ await connections[session_id].put(response)
180
 
181
+ # Run handler in background
182
  asyncio.create_task(handle_request())
183
+
184
+ # Return accepted response
185
+ return JSONResponse(
186
+ {"ok": True},
187
+ headers={"Access-Control-Allow-Origin": "*"}
188
+ )
189
 
190
  async def health(_: Request):
191
+ """Health check endpoint."""
192
+ return JSONResponse(
193
+ {"status": "ok", "service": "M3GNet-MCP"},
194
+ headers={"Access-Control-Allow-Origin": "*"}
195
+ )
196
 
197
  async def home(_: Request):
198
+ """Homepage with service info."""
199
  return HTMLResponse("""<!DOCTYPE html>
200
+ <html>
201
+ <head><title>M3GNet MCP</title></head>
202
+ <body>
203
+ <h1>M3GNet MCP Service</h1>
204
+ <p>Status: <span style="color:green">Running</span></p>
205
+ <p>SSE endpoint: <code>/sse</code></p>
206
+ <p>Messages endpoint: <code>/messages</code></p>
207
+ <h2>Cursor Configuration</h2>
208
+ <pre>
209
+ {
210
+ "mcpServers": {
211
+ "m3gnet": {
212
+ "url": "http://localhost:7860/sse"
213
+ }
214
+ }
215
+ }
216
+ </pre>
217
+ </body>
218
+ </html>""")
219
+
220
+ # Create app with CORS middleware
221
+ middleware = [
222
+ Middleware(
223
+ CORSMiddleware,
224
+ allow_origins=["*"],
225
+ allow_methods=["*"],
226
+ allow_headers=["*"],
227
+ allow_credentials=True,
228
+ )
229
+ ]
230
 
231
  app = Starlette(
232
  routes=[
233
+ Route("/", home, methods=["GET"]),
234
+ Route("/health", health, methods=["GET"]),
235
+ Route("/sse", sse_endpoint, methods=["GET"]),
236
+ Route("/messages", messages_endpoint, methods=["GET", "POST", "OPTIONS"]),
237
+ ],
238
+ middleware=middleware,
239
  )
240
 
241
  uvicorn.run(app, host=host, port=port, log_level="info")