SEUyishu committed on
Commit 51e7c6b · verified · 1 Parent(s): 596b729

Update mcp_output/mcp_plugin/mcp_service.py

Files changed (1):
  mcp_output/mcp_plugin/mcp_service.py  +1083  -0
mcp_output/mcp_plugin/mcp_service.py CHANGED
 
@@ -9,6 +9,11 @@ import json
 import tempfile
 import yaml
 import numpy as np
+import base64
+import hashlib
+import shutil
+import uuid
+from datetime import datetime
 from typing import Optional, List, Dict, Any
 from pathlib import Path
 
 
@@ -31,6 +36,477 @@ except ImportError as e:

 mcp = FastMCP("matdeeplearn_service")

+# ============================================================================
+# Global storage management - for uploaded datasets and trained models
+# ============================================================================
+
+# Server-side storage directories
+STORAGE_BASE = os.path.join(project_root, "mcp_storage")
+DATASETS_DIR = os.path.join(STORAGE_BASE, "datasets")
+MODELS_DIR = os.path.join(STORAGE_BASE, "models")
+SESSIONS_DIR = os.path.join(STORAGE_BASE, "sessions")
+
+# Ensure the storage directories exist
+for dir_path in [STORAGE_BASE, DATASETS_DIR, MODELS_DIR, SESSIONS_DIR]:
+    os.makedirs(dir_path, exist_ok=True)
+
+# Session registry (session_id -> session_info)
+_sessions: Dict[str, Dict] = {}
+
+
+def _get_session_path(session_id: str) -> str:
+    """Return the directory path for a session."""
+    return os.path.join(SESSIONS_DIR, session_id)
+
+
+def _generate_session_id() -> str:
+    """Generate a unique session ID."""
+    return f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
+
+
+def _generate_dataset_id(name: str) -> str:
+    """Generate a dataset ID."""
+    return f"dataset_{name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
+
+
+def _generate_model_id(model_name: str) -> str:
+    """Generate a model ID."""
+    return f"model_{model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
+
+
+# ============================================================================
+# Session management tools
+# ============================================================================
+
+@mcp.tool(name="create_session", description="Create a new working session for uploading data and training models. Returns a session_id to use in subsequent operations.")
+def create_session(session_name: Optional[str] = None) -> dict:
+    """
+    Create a new working session. Use this before uploading data.
+
+    Parameters:
+        session_name (str, optional): A friendly name for this session.
+
+    Returns:
+        dict: Contains session_id and session info.
+
+    Example:
+        create_session(session_name="my_material_project")
+    """
+    try:
+        session_id = _generate_session_id()
+        session_path = _get_session_path(session_id)
+        os.makedirs(session_path, exist_ok=True)
+        os.makedirs(os.path.join(session_path, "data"), exist_ok=True)
+        os.makedirs(os.path.join(session_path, "models"), exist_ok=True)
+        os.makedirs(os.path.join(session_path, "outputs"), exist_ok=True)
+
+        session_info = {
+            "session_id": session_id,
+            "session_name": session_name or session_id,
+            "created_at": datetime.now().isoformat(),
+            "data_path": os.path.join(session_path, "data"),
+            "models_path": os.path.join(session_path, "models"),
+            "outputs_path": os.path.join(session_path, "outputs"),
+            "uploaded_files": [],
+            "trained_models": [],
+            "status": "active"
+        }
+
+        _sessions[session_id] = session_info
+
+        # Save session info to disk
+        with open(os.path.join(session_path, "session_info.json"), 'w') as f:
+            json.dump(session_info, f, indent=2)
+
+        return {
+            "success": True,
+            "session_id": session_id,
+            "session_name": session_info["session_name"],
+            "message": "Session created successfully. Use this session_id for uploading data and training.",
+            "next_steps": [
+                "1. Upload structure files using upload_structure_files",
+                "2. Upload targets.csv using upload_targets",
+                "3. Process data using process_session_data",
+                "4. Train model using train_session_model"
+            ]
+        }
+    except Exception as e:
+        return {"success": False, "error": str(e)}
+
+
+@mcp.tool(name="get_session_info", description="Get information about an existing session.")
+def get_session_info(session_id: str) -> dict:
+    """
+    Get information about an existing session.
+
+    Parameters:
+        session_id (str): The session ID returned from create_session.
+
+    Returns:
+        dict: Session information including uploaded files and trained models.
+    """
+    try:
+        session_path = _get_session_path(session_id)
+        info_file = os.path.join(session_path, "session_info.json")
+
+        if not os.path.exists(info_file):
+            return {"success": False, "error": f"Session not found: {session_id}"}
+
+        with open(info_file, 'r') as f:
+            session_info = json.load(f)
+
+        # Update with current file counts
+        data_path = session_info["data_path"]
+        if os.path.exists(data_path):
+            files = os.listdir(data_path)
+            session_info["current_files"] = files
+            session_info["file_count"] = len(files)
+            session_info["has_targets"] = "targets.csv" in files
+
+        return {"success": True, **session_info}
+    except Exception as e:
+        return {"success": False, "error": str(e)}
+
+
+@mcp.tool(name="list_sessions", description="List all available sessions.")
+def list_sessions() -> dict:
+    """
+    List all available sessions on the server.
+
+    Returns:
+        dict: List of sessions with their basic info.
+    """
+    try:
+        sessions = []
+        if os.path.exists(SESSIONS_DIR):
+            for session_id in os.listdir(SESSIONS_DIR):
+                info_file = os.path.join(SESSIONS_DIR, session_id, "session_info.json")
+                if os.path.exists(info_file):
+                    with open(info_file, 'r') as f:
+                        info = json.load(f)
+                    sessions.append({
+                        "session_id": session_id,
+                        "session_name": info.get("session_name", session_id),
+                        "created_at": info.get("created_at"),
+                        "status": info.get("status", "unknown")
+                    })
+
+        return {
+            "success": True,
+            "sessions": sessions,
+            "total_sessions": len(sessions)
+        }
+    except Exception as e:
+        return {"success": False, "error": str(e)}
+
+
+@mcp.tool(name="delete_session", description="Delete a session and all its data.")
+def delete_session(session_id: str, confirm: bool = False) -> dict:
+    """
+    Delete a session and all associated data.
+
+    Parameters:
+        session_id (str): The session ID to delete.
+        confirm (bool): Must be True to confirm deletion.
+
+    Returns:
+        dict: Deletion status.
+    """
+    try:
+        if not confirm:
+            return {
+                "success": False,
+                "error": "Please set confirm=True to delete the session. This action cannot be undone."
+            }
+
+        session_path = _get_session_path(session_id)
+        if not os.path.exists(session_path):
+            return {"success": False, "error": f"Session not found: {session_id}"}
+
+        shutil.rmtree(session_path)
+
+        if session_id in _sessions:
+            del _sessions[session_id]
+
+        return {
+            "success": True,
+            "message": f"Session {session_id} deleted successfully."
+        }
+    except Exception as e:
+        return {"success": False, "error": str(e)}
+
+
+# ============================================================================
+# Data upload tools
+# ============================================================================
+
+@mcp.tool(name="upload_structure_file", description="Upload a single structure file to a session. Supports CIF, XYZ, POSCAR, JSON formats.")
+def upload_structure_file(
+    session_id: str,
+    filename: str,
+    file_content: str,
+    file_format: Optional[str] = None
+) -> dict:
+    """
+    Upload a single structure file to a session.
+
+    Parameters:
+        session_id (str): The session ID.
+        filename (str): Name for the file (e.g., "structure1.cif").
+        file_content (str): The complete file content as a string.
+        file_format (str, optional): File format hint (auto-detected from filename if not provided).
+
+    Returns:
+        dict: Upload status and file info.
+
+    Example:
+        upload_structure_file(
+            session_id="session_xxx",
+            filename="NaCl.cif",
+            file_content="data_NaCl\\n_cell_length_a 5.64..."
+        )
+    """
+    try:
+        session_path = _get_session_path(session_id)
+        if not os.path.exists(session_path):
+            return {"success": False, "error": f"Session not found: {session_id}"}
+
+        data_path = os.path.join(session_path, "data")
+        file_path = os.path.join(data_path, filename)
+
+        with open(file_path, 'w', encoding='utf-8') as f:
+            f.write(file_content)
+
+        # Validate structure if possible
+        validation = {"valid": True}
+        try:
+            import ase.io
+            with tempfile.NamedTemporaryFile(mode='w', suffix=os.path.splitext(filename)[1], delete=False) as tmp:
+                tmp.write(file_content)
+                tmp_path = tmp.name
+            try:
+                structure = ase.io.read(tmp_path)
+                validation = {
+                    "valid": True,
+                    "num_atoms": len(structure),
+                    "formula": structure.get_chemical_formula()
+                }
+            finally:
+                os.unlink(tmp_path)
+        except Exception as e:
+            validation = {"valid": False, "warning": str(e)}
+
+        return {
+            "success": True,
+            "filename": filename,
+            "file_size": len(file_content),
+            "saved_to": file_path,
+            "validation": validation
+        }
+    except Exception as e:
+        return {"success": False, "error": str(e)}
+
+
+@mcp.tool(name="upload_structure_files_batch", description="Upload multiple structure files at once to a session.")
+def upload_structure_files_batch(
+    session_id: str,
+    files: Dict[str, str]
+) -> dict:
+    """
+    Upload multiple structure files to a session in one call.
+
+    Parameters:
+        session_id (str): The session ID.
+        files (dict): Dictionary mapping filename to file content.
+            Example: {"struct1.cif": "content1", "struct2.cif": "content2"}
+
+    Returns:
+        dict: Upload status for all files.
+
+    Example:
+        upload_structure_files_batch(
+            session_id="session_xxx",
+            files={
+                "NaCl.cif": "data_NaCl...",
+                "ZnO.cif": "data_ZnO..."
+            }
+        )
+    """
+    try:
+        session_path = _get_session_path(session_id)
+        if not os.path.exists(session_path):
+            return {"success": False, "error": f"Session not found: {session_id}"}
+
+        data_path = os.path.join(session_path, "data")
+        results = []
+        success_count = 0
+
+        for filename, content in files.items():
+            try:
+                file_path = os.path.join(data_path, filename)
+                with open(file_path, 'w', encoding='utf-8') as f:
+                    f.write(content)
+                results.append({
+                    "filename": filename,
+                    "success": True,
+                    "size": len(content)
+                })
+                success_count += 1
+            except Exception as e:
+                results.append({
+                    "filename": filename,
+                    "success": False,
+                    "error": str(e)
+                })
+
+        return {
+            "success": True,
+            "total_files": len(files),
+            "successful_uploads": success_count,
+            "failed_uploads": len(files) - success_count,
+            "results": results
+        }
+    except Exception as e:
+        return {"success": False, "error": str(e)}
+
+
+@mcp.tool(name="upload_targets", description="Upload targets.csv file containing target properties for training.")
+def upload_targets(
+    session_id: str,
+    targets_content: str,
+    validate: bool = True
+) -> dict:
+    """
+    Upload targets.csv file to a session.
+
+    Parameters:
+        session_id (str): The session ID.
+        targets_content (str): Content of targets.csv file.
+            Format: structure_id,target_value (one per line).
+        validate (bool): Whether to validate the targets file.
+
+    Returns:
+        dict: Upload status and validation info.
+
+    Example:
+        upload_targets(
+            session_id="session_xxx",
+            targets_content="NaCl,1.5\\nZnO,2.3\\nTiO2,3.1"
+        )
+    """
+    try:
+        session_path = _get_session_path(session_id)
+        if not os.path.exists(session_path):
+            return {"success": False, "error": f"Session not found: {session_id}"}
+
+        data_path = os.path.join(session_path, "data")
+        targets_path = os.path.join(data_path, "targets.csv")
+
+        with open(targets_path, 'w', encoding='utf-8') as f:
+            f.write(targets_content)
+
+        # Validate and analyze
+        validation = {"valid": True}
+        if validate:
+            import csv
+            from io import StringIO
+
+            reader = csv.reader(StringIO(targets_content))
+            rows = list(reader)
+
+            structure_ids = []
+            target_values = []
+            for row in rows:
+                if len(row) >= 2:
+                    structure_ids.append(row[0])
+                    try:
+                        target_values.append(float(row[1]))
+                    except ValueError:
+                        pass
+
+            # Check for matching structure files
+            existing_files = os.listdir(data_path)
+            structure_files = [f for f in existing_files if f != "targets.csv"]
+            structure_names = [os.path.splitext(f)[0] for f in structure_files]
+
+            matched = [sid for sid in structure_ids if sid in structure_names]
+            unmatched = [sid for sid in structure_ids if sid not in structure_names]
+
+            validation = {
+                "valid": True,
+                "num_samples": len(rows),
+                "num_valid_targets": len(target_values),
+                "target_range": {
+                    "min": min(target_values) if target_values else None,
+                    "max": max(target_values) if target_values else None,
+                    "mean": sum(target_values) / len(target_values) if target_values else None
+                },
+                "matched_structures": len(matched),
+                "unmatched_structures": unmatched[:10] if unmatched else [],
+                "existing_structure_files": len(structure_files)
+            }
+
+        return {
+            "success": True,
+            "saved_to": targets_path,
+            "validation": validation
+        }
+    except Exception as e:
+        return {"success": False, "error": str(e)}
+
+
+@mcp.tool(name="upload_binary_file", description="Upload a binary file (like .pth model file) encoded as base64.")
+def upload_binary_file(
+    session_id: str,
+    filename: str,
+    base64_content: str,
+    destination: str = "models"
+) -> dict:
+    """
+    Upload a binary file (e.g., a pre-trained model .pth file) encoded as base64.
+
+    Parameters:
+        session_id (str): The session ID.
+        filename (str): Name for the file.
+        base64_content (str): File content encoded as a base64 string.
+        destination (str): Where to save - "models" or "data".
+
+    Returns:
+        dict: Upload status.
+
+    Example:
+        # In Python, encode your model file:
+        # import base64
+        # with open("model.pth", "rb") as f:
+        #     encoded = base64.b64encode(f.read()).decode()
+        # Then pass encoded as base64_content
+    """
+    try:
+        session_path = _get_session_path(session_id)
+        if not os.path.exists(session_path):
+            return {"success": False, "error": f"Session not found: {session_id}"}
+
+        if destination == "models":
+            dest_path = os.path.join(session_path, "models")
+        else:
+            dest_path = os.path.join(session_path, "data")
+
+        file_path = os.path.join(dest_path, filename)
+
+        # Decode and write binary content
+        binary_content = base64.b64decode(base64_content)
+        with open(file_path, 'wb') as f:
+            f.write(binary_content)
+
+        return {
+            "success": True,
+            "filename": filename,
+            "file_size_bytes": len(binary_content),
+            "saved_to": file_path
+        }
+    except Exception as e:
+        return {"success": False, "error": str(e)}
+
+

 @mcp.tool(name="check_environment", description="Check if MatDeepLearn environment is properly configured and GPU is available.")
 def check_environment() -> dict:
 
@@ -884,6 +1360,613 @@ def quick_structure_analysis(
         return {"success": False, "error": str(e)}


+# ============================================================================
+# Session-based training and model management tools
+# ============================================================================
+
+@mcp.tool(name="process_session_data", description="Process uploaded structure data in a session into graph format for GNN training.")
+def process_session_data(
+    session_id: str,
+    target_index: int = 0,
+    graph_max_radius: float = 8.0,
+    graph_max_neighbors: int = 12,
+    reprocess: bool = True
+) -> dict:
+    """
+    Process all uploaded structure files in a session into graph format.
+
+    Parameters:
+        session_id (str): The session ID.
+        target_index (int): Index of the target column in targets.csv (default: 0, meaning the second column).
+        graph_max_radius (float): Maximum radius for graph edges (default: 8.0 Angstrom).
+        graph_max_neighbors (int): Maximum neighbors per atom (default: 12).
+        reprocess (bool): Force reprocessing even if already processed (default: True).
+
+    Returns:
+        dict: Processing status and dataset statistics.
+    """
+    try:
+        if not MATDEEPLEARN_AVAILABLE:
+            return {"success": False, "error": "MatDeepLearn not available"}
+
+        session_path = _get_session_path(session_id)
+        if not os.path.exists(session_path):
+            return {"success": False, "error": f"Session not found: {session_id}"}
+
+        data_path = os.path.join(session_path, "data")
+
+        # Check for required files
+        if not os.path.exists(os.path.join(data_path, "targets.csv")):
+            return {
+                "success": False,
+                "error": "targets.csv not found. Please upload targets using upload_targets first."
+            }
+
+        files = [f for f in os.listdir(data_path) if f != "targets.csv" and not f.startswith('.')]
+        if len(files) == 0:
+            return {
+                "success": False,
+                "error": "No structure files found. Please upload structure files first."
+            }
+
+        processing_args = {
+            "dataset_type": "inmemory",
+            "data_path": data_path,
+            "target_path": "targets.csv",
+            "dictionary_source": "default",
+            "dictionary_path": "atom_dict.json",
+            "data_format": "json",
+            "verbose": "True",
+            "graph_max_radius": graph_max_radius,
+            "graph_max_neighbors": graph_max_neighbors,
+            "voronoi": "False",
+            "edge_features": "True",
+            "graph_edge_length": 50,
+            "SM_descriptor": "False",
+            "SOAP_descriptor": "False"
+        }
+
+        dataset = process.get_dataset(
+            data_path,
+            target_index,
+            "True" if reprocess else "False",
+            processing_args
+        )
+
+        # Calculate statistics
+        num_nodes_list = [data.x.shape[0] for data in dataset]
+        num_edges_list = [data.edge_index.shape[1] for data in dataset]
+
+        return {
+            "success": True,
+            "session_id": session_id,
+            "dataset_size": len(dataset),
+            "statistics": {
+                "avg_atoms_per_structure": float(np.mean(num_nodes_list)),
+                "min_atoms": min(num_nodes_list),
+                "max_atoms": max(num_nodes_list),
+                "avg_edges_per_structure": float(np.mean(num_edges_list)),
+                "num_node_features": dataset[0].x.shape[1] if len(dataset) > 0 else 0
+            },
+            "ready_for_training": True,
+            "next_step": "Use train_session_model to train a model on this data."
+        }
+    except Exception as e:
+        return {"success": False, "error": str(e)}
+
+
+@mcp.tool(name="train_session_model", description="Train a GNN model on processed session data.")
+def train_session_model(
+    session_id: str,
+    model_name: str = "CGCNN_demo",
+    epochs: int = 100,
+    batch_size: int = 32,
+    learning_rate: float = 0.002,
+    train_ratio: float = 0.8,
+    val_ratio: float = 0.1,
+    test_ratio: float = 0.1,
+    model_save_name: Optional[str] = None
+) -> dict:
+    """
+    Train a GNN model on processed session data.
+
+    Parameters:
+        session_id (str): The session ID with processed data.
+        model_name (str): Model to use - "CGCNN_demo", "SchNet_demo", "MPNN_demo", etc.
+        epochs (int): Number of training epochs (default: 100).
+        batch_size (int): Batch size (default: 32).
+        learning_rate (float): Learning rate (default: 0.002).
+        train_ratio (float): Training data ratio (default: 0.8).
+        val_ratio (float): Validation data ratio (default: 0.1).
+        test_ratio (float): Test data ratio (default: 0.1).
+        model_save_name (str, optional): Custom name for the saved model.
+
+    Returns:
+        dict: Training results including errors and model path.
+    """
+    try:
+        if not MATDEEPLEARN_AVAILABLE:
+            return {"success": False, "error": "MatDeepLearn not available"}
+
+        session_path = _get_session_path(session_id)
+        if not os.path.exists(session_path):
+            return {"success": False, "error": f"Session not found: {session_id}"}
+
+        data_path = os.path.join(session_path, "data")
+        models_path = os.path.join(session_path, "models")
+        outputs_path = os.path.join(session_path, "outputs")
+
+        # Load config
+        config_path = os.path.join(project_root, "config.yml")
+        with open(config_path, "r") as f:
+            config = yaml.load(f, Loader=yaml.FullLoader)
+
+        if model_name not in config.get("Models", {}):
+            available_models = list(config.get("Models", {}).keys())
+            return {
+                "success": False,
+                "error": f"Model '{model_name}' not found. Available: {available_models}"
+            }
+
+        # Generate model filename
+        if model_save_name:
+            model_filename = f"{model_save_name}.pth"
+        else:
+            model_filename = f"{model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pth"
+
+        model_path = os.path.join(models_path, model_filename)
+
+        # Prepare job config
+        job_config = {
+            "job_name": f"train_{session_id}",
+            "reprocess": "False",
+            "model": model_name,
+            "load_model": "False",
+            "save_model": "True",
+            "model_path": model_path,
+            "write_output": "True",
+            "parallel": "False",
+            "seed": np.random.randint(1, 1e6)
+        }
+
+        training_config = {
+            "target_index": 0,
+            "loss": "l1_loss",
+            "train_ratio": train_ratio,
+            "val_ratio": val_ratio,
+            "test_ratio": test_ratio,
+            "verbosity": 5
+        }
+
+        model_config = config["Models"][model_name].copy()
+        model_config["epochs"] = epochs
+        model_config["batch_size"] = batch_size
+        model_config["lr"] = learning_rate
+
+        # Determine device
+        world_size = torch.cuda.device_count()
+        rank = "cpu" if world_size == 0 else "cuda"
+
+        # Change to outputs directory for writing results
+        original_cwd = os.getcwd()
+        os.chdir(outputs_path)
+
+        try:
+            # Train model
+            error_values = training.train_regular(
+                rank,
+                world_size,
+                data_path,
+                job_config,
+                training_config,
+                model_config
+            )
+        finally:
+            os.chdir(original_cwd)
+
+        # Update session info
+        info_file = os.path.join(session_path, "session_info.json")
+        if os.path.exists(info_file):
+            with open(info_file, 'r') as f:
+                session_info = json.load(f)
+
+            session_info.setdefault("trained_models", []).append({
+                "model_name": model_name,
+                "model_file": model_filename,
+                "model_path": model_path,
+                "trained_at": datetime.now().isoformat(),
+                "epochs": epochs,
+                "train_error": float(error_values[0]) if error_values is not None else None,
+                "val_error": float(error_values[1]) if error_values is not None else None,
+                "test_error": float(error_values[2]) if error_values is not None else None
+            })
+
+            with open(info_file, 'w') as f:
+                json.dump(session_info, f, indent=2)
+
+        return {
+            "success": True,
+            "session_id": session_id,
+            "model_name": model_name,
+            "model_file": model_filename,
+            "model_path": model_path,
+            "epochs": epochs,
+            "device_used": rank,
+            "results": {
+                "train_error": float(error_values[0]) if error_values is not None else None,
+                "val_error": float(error_values[1]) if error_values is not None else None,
+                "test_error": float(error_values[2]) if error_values is not None else None
+            },
+            "next_steps": [
+                "Use predict_with_session_model to make predictions",
+                "Use download_model to get the trained model file",
+                "Use evaluate_session_model for detailed evaluation"
+            ]
+        }
+    except Exception as e:
+        return {"success": False, "error": str(e)}
+
+
+@mcp.tool(name="list_session_models", description="List all trained models in a session.")
+def list_session_models(session_id: str) -> dict:
+    """
+    List all trained models in a session.
+
+    Parameters:
+        session_id (str): The session ID.
+
+    Returns:
+        dict: List of trained models with their info.
+    """
+    try:
+        session_path = _get_session_path(session_id)
+        if not os.path.exists(session_path):
+            return {"success": False, "error": f"Session not found: {session_id}"}
+
+        models_path = os.path.join(session_path, "models")
+
+        # Get model files
+        model_files = []
+        if os.path.exists(models_path):
+            for f in os.listdir(models_path):
+                if f.endswith('.pth'):
+                    file_path = os.path.join(models_path, f)
+                    model_files.append({
+                        "filename": f,
+                        "path": file_path,
+                        "size_mb": os.path.getsize(file_path) / (1024 * 1024),
+                        "created": datetime.fromtimestamp(os.path.getctime(file_path)).isoformat()
+                    })
+
+        # Get training history from session info
+        info_file = os.path.join(session_path, "session_info.json")
+        trained_models = []
+        if os.path.exists(info_file):
+            with open(info_file, 'r') as f:
+                session_info = json.load(f)
+            trained_models = session_info.get("trained_models", [])
+
+        return {
+            "success": True,
+            "session_id": session_id,
+            "model_files": model_files,
+            "training_history": trained_models,
+            "total_models": len(model_files)
+        }
+    except Exception as e:
+        return {"success": False, "error": str(e)}
+
+
+@mcp.tool(name="predict_with_session_model", description="Make predictions using a trained model from the session.")
+def predict_with_session_model(
+    session_id: str,
+    model_filename: str,
+    structure_contents: Optional[Dict[str, str]] = None,
+    use_session_data: bool = False
+) -> dict:
+    """
+    Make predictions using a trained model.
+
+    Parameters:
+        session_id (str): The session ID.
+        model_filename (str): Name of the model file (e.g., "CGCNN_demo_20231201.pth").
+        structure_contents (dict, optional): New structures to predict.
+            Format: {"name1.cif": "content", ...}
+        use_session_data (bool): If True, predict on the session's training data.
+
+    Returns:
+        dict: Predictions for each structure.
+    """
+    try:
+        if not MATDEEPLEARN_AVAILABLE:
+            return {"success": False, "error": "MatDeepLearn not available"}
+
+        session_path = _get_session_path(session_id)
+        if not os.path.exists(session_path):
+            return {"success": False, "error": f"Session not found: {session_id}"}
+
+        model_path = os.path.join(session_path, "models", model_filename)
+        if not os.path.exists(model_path):
+            return {"success": False, "error": f"Model not found: {model_filename}"}
+
+        # Determine data path
+        if use_session_data:
+            data_path = os.path.join(session_path, "data")
+        elif structure_contents:
+            # Create temp directory for new structures
+            temp_dir = tempfile.mkdtemp(prefix="mcp_predict_")
+            data_path = temp_dir
+
+            # Write structures
+            for filename, content in structure_contents.items():
+                with open(os.path.join(temp_dir, filename), 'w') as f:
+                    f.write(content)
+
+            # Create dummy targets.csv
+            struct_names = [os.path.splitext(f)[0] for f in structure_contents.keys()]
+            with open(os.path.join(temp_dir, "targets.csv"), 'w') as f:
+                for name in struct_names:
+                    f.write(f"{name},0.0\n")
+        else:
+            return {
+                "success": False,
+                "error": "Either structure_contents or use_session_data=True must be provided"
+            }
+
+        # Get dataset
+        dataset = process.get_dataset(data_path, 0, "True")
+
+        job_config = {
+            "job_name": f"predict_{session_id}",
+            "model_path": model_path,
+            "write_output": "True"
+        }
+
+        outputs_path = os.path.join(session_path, "outputs")
+        original_cwd = os.getcwd()
+        os.chdir(outputs_path)
+
+        try:
+            # Run prediction
+            test_error = training.predict(dataset, "l1_loss", job_config)
+
+            # Read predictions
+            predictions = []
+            output_file = os.path.join(outputs_path, f"predict_{session_id}_predicted_outputs.csv")
+            if os.path.exists(output_file):
+                import csv
+                with open(output_file, 'r') as f:
+                    reader = csv.reader(f)
+                    for row in reader:
+                        if len(row) >= 2:
+                            predictions.append({
+                                "structure_id": row[0],
+                                "predicted_value": float(row[1]) if row[1] else None
+                            })
+        finally:
+            os.chdir(original_cwd)
+            if structure_contents and 'temp_dir' in locals():
+                shutil.rmtree(temp_dir, ignore_errors=True)
+
+        return {
+            "success": True,
+            "session_id": session_id,
+            "model_used": model_filename,
+            "num_predictions": len(predictions),
+            "predictions": predictions,
+            "mean_absolute_error": float(test_error) if use_session_data else None
+        }
+    except Exception as e:
+        return {"success": False, "error": str(e)}
+
+
+@mcp.tool(name="download_model", description="Get a trained model file as base64 encoded string for download.")
+def download_model(session_id: str, model_filename: str) -> dict:
+    """
+    Get a trained model file as a base64 encoded string.
+    You can decode this to get the .pth file.
+
+    Parameters:
+        session_id (str): The session ID.
+        model_filename (str): Name of the model file.
+
+    Returns:
+        dict: Base64 encoded model file and metadata.
+
+    Usage after receiving:
+        import base64
+        model_data = base64.b64decode(result["model_base64"])
+        with open("my_model.pth", "wb") as f:
+            f.write(model_data)
+    """
+    try:
+        session_path = _get_session_path(session_id)
+        if not os.path.exists(session_path):
+            return {"success": False, "error": f"Session not found: {session_id}"}
+
+        model_path = os.path.join(session_path, "models", model_filename)
+        if not os.path.exists(model_path):
+            return {"success": False, "error": f"Model not found: {model_filename}"}
+
+        # Read and encode model
+        with open(model_path, 'rb') as f:
+            model_bytes = f.read()
+
+        model_base64 = base64.b64encode(model_bytes).decode('utf-8')
+
+        return {
+            "success": True,
+            "model_filename": model_filename,
+            "file_size_bytes": len(model_bytes),
+            "file_size_mb": len(model_bytes) / (1024 * 1024),
+            "model_base64": model_base64,
+            "instructions": "Decode with: base64.b64decode(model_base64) and save as .pth file"
+        }
+    except Exception as e:
+        return {"success": False, "error": str(e)}
+
+
+@mcp.tool(name="compare_session_models", description="Compare multiple trained models in a session on the same dataset.")
+def compare_session_models(
+    session_id: str,
+    model_filenames: Optional[List[str]] = None
+) -> dict:
+    """
+    Compare multiple trained models in a session.
+
+    Parameters:
+        session_id (str): The session ID.
+        model_filenames (list, optional): List of model files to compare. If None, compare all.
+
+    Returns:
+        dict: Comparison results with rankings.
+    """
+    try:
+        session_path = _get_session_path(session_id)
+        if not os.path.exists(session_path):
+            return {"success": False, "error": f"Session not found: {session_id}"}
+
+        # Get training history
+        info_file = os.path.join(session_path, "session_info.json")
+        if not os.path.exists(info_file):
+            return {"success": False, "error": "No training history found"}
+
+        with open(info_file, 'r') as f:
+            session_info = json.load(f)
+
+        trained_models = session_info.get("trained_models", [])
+
+        if model_filenames:
+            trained_models = [m for m in trained_models if m.get("model_file") in model_filenames]
+
+        if len(trained_models) == 0:
+            return {"success": False, "error": "No trained models found"}
+
+        # Sort by test error
+        sorted_models = sorted(
+            trained_models,
+            key=lambda x: x.get("test_error") or float('inf')
+        )
+
+        comparison = []
+        for i, model in enumerate(sorted_models):
+            comparison.append({
+                "rank": i + 1,
+                "model_name": model.get("model_name"),
+                "model_file": model.get("model_file"),
+                "train_error": model.get("train_error"),
+                "val_error": model.get("val_error"),
+                "test_error": model.get("test_error"),
+                "epochs": model.get("epochs"),
+                "trained_at": model.get("trained_at")
+            })
+
+        best_model = sorted_models[0] if sorted_models else None
+
+        return {
+            "success": True,
+            "session_id": session_id,
+            "num_models_compared": len(comparison),
+            "comparison": comparison,
+            "best_model": {
+                "model_file": best_model.get("model_file"),
+                "model_name": best_model.get("model_name"),
+                "test_error": best_model.get("test_error")
+            } if best_model else None,
+            "recommendation": f"Best model is {best_model.get('model_file')} with test error {best_model.get('test_error'):.4f}" if best_model and best_model.get('test_error') else None
+        }
+    except Exception as e:
+        return {"success": False, "error": str(e)}
+
+
+@mcp.tool(name="run_cross_validation_session", description="Run k-fold cross validation on session data.")
+def run_cross_validation_session(
+    session_id: str,
+    model_name: str = "CGCNN_demo",
+    cv_folds: int = 5,
+    epochs: int = 100
+) -> dict:
+    """
+    Run k-fold cross validation on session data.
+
+    Parameters:
+        session_id (str): The session ID.
+        model_name (str): Model to use (default: "CGCNN_demo").
+        cv_folds (int): Number of folds (default: 5).
+        epochs (int): Training epochs per fold (default: 100).
+
+    Returns:
+        dict: Cross validation results.
+    """
+    try:
+        if not MATDEEPLEARN_AVAILABLE:
+            return {"success": False, "error": "MatDeepLearn not available"}
+
+        session_path = _get_session_path(session_id)
+        if not os.path.exists(session_path):
+            return {"success": False, "error": f"Session not found: {session_id}"}
+
+        data_path = os.path.join(session_path, "data")
+        outputs_path = os.path.join(session_path, "outputs")
+
+        # Load config
+        config_path = os.path.join(project_root, "config.yml")
+        with open(config_path, "r") as f:
+            config = yaml.load(f, Loader=yaml.FullLoader)
+
+        if model_name not in config.get("Models", {}):
+            return {"success": False, "error": f"Model '{model_name}' not found"}
+
+        job_config = {
+            "job_name": f"cv_{session_id}",
+            "reprocess": "False",
+            "model": model_name,
+            "cv_folds": cv_folds,
+            "write_output": "True",
+            "parallel": "False",
+            "seed": np.random.randint(1, 1e6)
+        }
+
+        training_config = {
+            "target_index": 0,
+            "loss": "l1_loss",
+            "verbosity": 5
+        }
+
+        model_config = config["Models"][model_name].copy()
+        model_config["epochs"] = epochs
+
+        world_size = torch.cuda.device_count()
+        rank = "cpu" if world_size == 0 else "cuda"
+
+        original_cwd = os.getcwd()
+        os.chdir(outputs_path)
+
+        try:
+            cv_error = training.train_CV(
+                rank,
+                world_size,
+                data_path,
+                job_config,
+                training_config,
+                model_config
+            )
+        finally:
+            os.chdir(original_cwd)
+
+        return {
+            "success": True,
+            "session_id": session_id,
+            "model_name": model_name,
+            "cv_folds": cv_folds,
+            "epochs_per_fold": epochs,
+            "cv_mean_error": float(cv_error) if cv_error is not None else None,
+            "output_file": f"cv_{session_id}_CV_outputs.csv"
+        }
+    except Exception as e:
+        return {"success": False, "error": str(e)}
+
+
 def create_app() -> FastMCP:
     """
     Creates and returns the FastMCP application instance.
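
For reference, the client-side base64 round-trip that the upload_binary_file and download_model docstrings describe needs nothing beyond the Python standard library. Below is a minimal sketch, not part of the commit: call_tool is a hypothetical stand-in for whatever MCP client dispatch you use, and the file names are placeholders; the tool names and parameters match the definitions above.

import base64

def encode_model_for_upload(path: str) -> str:
    """Read a local .pth file and base64-encode it for upload_binary_file."""
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode()

def save_downloaded_model(result: dict, path: str) -> None:
    """Decode a download_model response and write the .pth file to disk."""
    with open(path, "wb") as f:
        f.write(base64.b64decode(result["model_base64"]))

# Hypothetical session workflow (call_tool is a placeholder for your MCP client):
#   sid = call_tool("create_session", session_name="demo")["session_id"]
#   call_tool("upload_structure_file", session_id=sid, filename="NaCl.cif", file_content=cif_text)
#   call_tool("upload_targets", session_id=sid, targets_content="NaCl,1.5")
#   call_tool("process_session_data", session_id=sid)
#   call_tool("train_session_model", session_id=sid, model_name="CGCNN_demo", epochs=100)
#   result = call_tool("download_model", session_id=sid, model_filename="CGCNN_demo.pth")
#   save_downloaded_model(result, "my_model.pth")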