Commit 778fec6 by SEUyishu · verified · 1 parent: dfc4f2b

Upload 9 files

mcp_output/README_MCP.md ADDED
@@ -0,0 +1,205 @@
+ # MatDeepLearn MCP Service
+
+ This is an MCP (Model Context Protocol) service wrapper for [MatDeepLearn](https://github.com/Fung-Lab/MatDeepLearn). It lets AI assistants train graph neural networks and run inference for materials property prediction.
+
+ ## Features
+
+ The MatDeepLearn MCP service provides the following tools:
+
+ | Tool | Description |
+ |---------|------|
+ | `check_environment` | Check the environment configuration and GPU availability |
+ | `list_available_models` | List all available GNN models |
+ | `get_model_config` | Get the default configuration for a specific model |
+ | `process_structure_data` | Process atomic structure data into graph format |
+ | `train_model` | Train a GNN model |
+ | `predict_properties` | Predict properties of new structures with a trained model |
+ | `cross_validation` | Run k-fold cross-validation |
+ | `analyze_structure` | Analyze an atomic structure file |
+ | `compare_models` | Compare the performance of different GNN models |
+ | `get_dataset_info` | Get information about a dataset directory |
+
+ ## Supported Models
+
+ - **CGCNN_demo**: Crystal Graph Convolutional Neural Network
+ - **MPNN_demo**: Message Passing Neural Network
+ - **SchNet_demo**: SchNet, a continuous-filter convolutional neural network
+ - **MEGNet_demo**: MatErials Graph Network
+ - **GCN_demo**: Graph Convolutional Network
+ - **SOAP_demo**: Smooth Overlap of Atomic Positions, a descriptor-based method
+ - **SM_demo**: Sine Matrix, a descriptor-based method
+
+ ## Running Locally
+
+ ### Install dependencies
+
+ ```bash
+ cd MatDeepLearn
+ pip install -r mcp_output/requirements.txt
+ ```
+
+ ### Start in STDIO mode (for local AI assistants)
+
+ ```bash
+ python mcp_output/start_mcp.py
+ ```
+
+ ### Start in HTTP mode (for remote access)
+
+ ```bash
+ export MCP_TRANSPORT=http
+ export MCP_PORT=7860
+ python mcp_output/start_mcp.py
+ ```
+
+ ## Deploying to a HuggingFace Space
+
+ ### 1. Create a HuggingFace Space
+
+ 1. Log in to [HuggingFace](https://huggingface.co/)
+ 2. Click "New Space"
+ 3. Select "Docker" as the SDK
+ 4. Enter a Space name (e.g. `matdeeplearn-mcp`)
+
+ ### 2. Upload the code
+
+ Option 1: via Git
+
+ ```bash
+ # Clone your Space repository
+ git clone https://huggingface.co/spaces/YOUR_USERNAME/matdeeplearn-mcp
+ cd matdeeplearn-mcp
+
+ # Copy the MatDeepLearn code
+ cp -r /path/to/MatDeepLearn/* .
+
+ # Commit and push
+ git add .
+ git commit -m "Initial MatDeepLearn MCP deployment"
+ git push
+ ```
+
+ Option 2: via the HuggingFace web interface
+
+ 1. On the Space page, click the "Files" tab
+ 2. Upload all MatDeepLearn files
+ 3. Make sure the `Dockerfile`, the `mcp_output/` directory, and all source code are included
+
+ ### 3. Configure the Space
+
+ In your Space settings, make sure that:
+ - SDK: Docker
+ - Hardware: CPU Basic (free) or GPU (paid, faster)
+
+ ### 4. Wait for the build
+
+ The Space automatically builds the Docker image and starts the service. Once the build finishes, the service is reachable at:
+
+ ```
+ https://YOUR_USERNAME-matdeeplearn-mcp.hf.space
+ ```
+
+ ## Using with AI Assistants
+
+ ### Claude Desktop configuration
+
+ Add the following to `claude_desktop_config.json`:
+
+ ```json
+ {
+   "mcpServers": {
+     "matdeeplearn": {
+       "command": "python",
+       "args": ["/path/to/MatDeepLearn/mcp_output/start_mcp.py"]
+     }
+   }
+ }
+ ```
+
+ ### Using the remote HTTP service
+
+ If the service is deployed to a HuggingFace Space, it can be called over HTTP:
+
+ ```json
+ {
+   "mcpServers": {
+     "matdeeplearn": {
+       "url": "https://YOUR_USERNAME-matdeeplearn-mcp.hf.space/mcp"
+     }
+   }
+ }
+ ```
+
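The configuration entry above can also be merged into an existing config programmatically. The following is a minimal sketch, not part of the commit: the helper name `add_mcp_server` and the pre-existing `other` server are hypothetical, and only the `mcpServers` layout is taken from the README.

```python
import json
from copy import deepcopy


def add_mcp_server(config: dict, name: str, entry: dict) -> dict:
    """Return a copy of `config` with `entry` registered under mcpServers[name]."""
    updated = deepcopy(config)
    updated.setdefault("mcpServers", {})[name] = entry
    return updated


# Hypothetical existing config with another server already registered.
existing = {"mcpServers": {"other": {"command": "node", "args": ["server.js"]}}}

local_entry = {
    "command": "python",
    "args": ["/path/to/MatDeepLearn/mcp_output/start_mcp.py"],
}
merged = add_mcp_server(existing, "matdeeplearn", local_entry)
print(json.dumps(merged, indent=2))
```

Working on a deep copy leaves the original config untouched, which makes it easy to diff or back up before writing the file.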
+ ## Usage Examples
+
+ ### Check the environment
+
+ ```
+ Please check whether the MatDeepLearn environment is working properly
+ ```
+
+ ### List available models
+
+ ```
+ List all graph neural network models available in MatDeepLearn
+ ```
+
+ ### Train a model
+
+ ```
+ Train the CGCNN model on the data/test_data directory for 100 epochs
+ ```
+
+ ### Predict properties
+
+ ```
+ Use the trained_model.pth model to predict properties of the structures in the new_structures/ directory
+ ```
+
+ ### Analyze a structure
+
+ ```
+ Analyze the atomic structure information in the structure.cif file
+ ```
+
+ ## Data Format Requirements
+
+ ### Directory layout
+
+ ```
+ data_directory/
+ ├── targets.csv          # required: structure IDs and target properties
+ ├── atom_dict.json       # optional: atom feature dictionary
+ ├── structure1.json      # structure files (json, cif, xyz, POSCAR, etc.)
+ ├── structure2.json
+ └── ...
+ ```
+
+ ### targets.csv format
+
+ ```csv
+ structure_id,property1,property2
+ structure1,1.23,4.56
+ structure2,2.34,5.67
+ ```
+
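A quick pre-flight check of a dataset directory can catch missing structure files before training. This is a hedged sketch rather than anything shipped in the commit: the function name `find_missing_structures` is hypothetical, and it assumes the first column of `targets.csv` holds the structure ID, that a header row is present as in the example above, and that each structure file is named `<id>.<extension>`.

```python
import csv
import os
import tempfile


def find_missing_structures(data_dir: str, has_header: bool = True) -> list:
    """List structure IDs from targets.csv that have no matching file in data_dir."""
    targets = os.path.join(data_dir, "targets.csv")
    if not os.path.isfile(targets):
        raise FileNotFoundError("targets.csv is required but was not found")

    # File stems present in the directory, e.g. "structure1" for structure1.json.
    stems = {os.path.splitext(name)[0] for name in os.listdir(data_dir)}

    with open(targets, newline="") as f:
        rows = list(csv.reader(f))
    if has_header:
        rows = rows[1:]
    return [row[0] for row in rows if row and row[0] not in stems]


# Demo on a throwaway directory: structure2 is listed but has no file.
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "targets.csv"), "w", newline="") as f:
        f.write("structure_id,property1,property2\n"
                "structure1,1.23,4.56\n"
                "structure2,2.34,5.67\n")
    with open(os.path.join(tmp, "structure1.json"), "w") as f:
        f.write("{}")
    missing = find_missing_structures(tmp)

print(missing)
```

Pass `has_header=False` if your `targets.csv` starts directly with data rows.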
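+ ## FAQ
+
+ ### Q: What if no GPU is available?
+ A: The service automatically falls back to CPU mode. For large datasets, a GPU is recommended.
+
+ ### Q: How do I add a custom model?
+ A: Add the model file under `matdeeplearn/models/` and add its configuration to `config.yml`.
+
+ ### Q: Which structure file formats are supported?
+ A: All formats supported by the ASE library, including json, cif, xyz, POSCAR, vasp, and others.
+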
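The CPU fallback described in the FAQ can be illustrated with a small device-selection helper. This is a sketch under assumptions, not the service's actual code: the name `pick_device` is hypothetical, and the import is made optional only so the example also runs where PyTorch is not installed.

```python
def pick_device() -> str:
    """Return "cuda" when PyTorch reports a usable GPU, otherwise "cpu"."""
    try:
        import torch  # optional here so the sketch also runs without PyTorch
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"


device = pick_device()
print(f"Using device: {device}")
```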
+ ## License
+
+ This project is released under the MIT License.
+
+ ## Acknowledgements
+
+ - [MatDeepLearn](https://github.com/Fung-Lab/MatDeepLearn) - developed by Victor Fung et al.
+ - [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/) - GNN framework
+ - [FastMCP](https://github.com/jlowin/fastmcp) - MCP service framework
mcp_output/__init__.py ADDED
@@ -0,0 +1 @@
+ # MatDeepLearn MCP Output
mcp_output/analysis.json ADDED
@@ -0,0 +1,69 @@
+ {
+   "project_name": "MatDeepLearn",
+   "project_description": "A platform for testing and using graph neural networks (GNNs) for materials chemistry applications",
+   "repository": "https://github.com/Fung-Lab/MatDeepLearn",
+   "mcp_tools": [
+     {
+       "name": "check_environment",
+       "description": "Check if MatDeepLearn environment is properly configured and GPU is available"
+     },
+     {
+       "name": "list_available_models",
+       "description": "List all available GNN models in MatDeepLearn"
+     },
+     {
+       "name": "get_model_config",
+       "description": "Get the default configuration for a specific model"
+     },
+     {
+       "name": "process_structure_data",
+       "description": "Process atomic structure data into graph format for GNN training"
+     },
+     {
+       "name": "train_model",
+       "description": "Train a GNN model on processed structure data"
+     },
+     {
+       "name": "predict_properties",
+       "description": "Use a trained model to predict properties of new structures"
+     },
+     {
+       "name": "cross_validation",
+       "description": "Perform k-fold cross validation on a dataset"
+     },
+     {
+       "name": "analyze_structure",
+       "description": "Analyze the structure of atomic data and convert to graph representation info"
+     },
+     {
+       "name": "compare_models",
+       "description": "Compare performance of different GNN models on a dataset"
+     },
+     {
+       "name": "get_dataset_info",
+       "description": "Get information about a dataset directory"
+     }
+   ],
+   "supported_models": [
+     "CGCNN_demo",
+     "MPNN_demo",
+     "SchNet_demo",
+     "MEGNet_demo",
+     "GCN_demo",
+     "SOAP_demo",
+     "SM_demo"
+   ],
+   "dependencies": [
+     "torch",
+     "torch-geometric",
+     "ase",
+     "pymatgen",
+     "fastmcp",
+     "numpy",
+     "scipy",
+     "scikit-learn"
+   ],
+   "python_version": ">=3.8",
+   "created_at": "2025-12-03",
+   "transport_modes": ["stdio", "http"]
+ }
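Because the tool list is duplicated between this manifest and the server code, a small consistency check may help keep them in sync. The sketch below is an illustration only: `check_manifest` is a hypothetical helper, the manifest text is a trimmed stand-in for `analysis.json`, and the set of registered names is supplied by hand rather than read from a running server.

```python
import json
from collections import Counter

# Trimmed stand-in for analysis.json; the real file lists ten tools.
manifest_text = """
{
  "mcp_tools": [
    {"name": "check_environment", "description": "..."},
    {"name": "train_model", "description": "..."}
  ],
  "transport_modes": ["stdio", "http"]
}
"""


def check_manifest(manifest: dict, registered: set) -> dict:
    """Compare tool names in the manifest with the names a server registers."""
    names = [tool["name"] for tool in manifest.get("mcp_tools", [])]
    duplicates = sorted(n for n, c in Counter(names).items() if c > 1)
    return {
        "duplicates": duplicates,
        "missing_from_server": sorted(set(names) - registered),
        "missing_from_manifest": sorted(registered - set(names)),
    }


report = check_manifest(
    json.loads(manifest_text),
    {"check_environment", "train_model", "predict_properties"},
)
print(report)
```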
mcp_output/mcp_plugin/__init__.py ADDED
@@ -0,0 +1 @@
+ # MatDeepLearn MCP Plugin
mcp_output/mcp_plugin/__pycache__/mcp_service.cpython-311.pyc ADDED
Binary file (26.7 kB).
mcp_output/mcp_plugin/mcp_service.py ADDED
@@ -0,0 +1,640 @@
+ """
+ MatDeepLearn MCP Service
+ A Model Context Protocol service for materials property prediction using Graph Neural Networks.
+ """
+
+ import os
+ import sys
+ import json
+ import tempfile
+ import yaml
+ import numpy as np
+ from typing import Optional, List, Dict, Any
+ from pathlib import Path
+
+ # Add MatDeepLearn to path
+ project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+ if project_root not in sys.path:
+     sys.path.insert(0, project_root)
+
+ from fastmcp import FastMCP
+
+ # Import MatDeepLearn modules
+ try:
+     import torch
+     from matdeeplearn import models, process, training
+     from matdeeplearn.models.utils import model_summary
+     MATDEEPLEARN_AVAILABLE = True
+ except ImportError as e:
+     MATDEEPLEARN_AVAILABLE = False
+     IMPORT_ERROR = str(e)
+
+ mcp = FastMCP("matdeeplearn_service")
+
+
+ @mcp.tool(name="check_environment", description="Check if MatDeepLearn environment is properly configured and GPU is available.")
+ def check_environment() -> dict:
+     """
+     Check if the MatDeepLearn environment is properly configured.
+
+     Returns:
+         dict: Contains environment status including GPU availability.
+     """
+     try:
+         if not MATDEEPLEARN_AVAILABLE:
+             return {
+                 "success": False,
+                 "error": f"MatDeepLearn not available: {IMPORT_ERROR}"
+             }
+
+         gpu_available = torch.cuda.is_available()
+         gpu_count = torch.cuda.device_count() if gpu_available else 0
+         gpu_name = torch.cuda.get_device_name(0) if gpu_available else "N/A"
+
+         return {
+             "success": True,
+             "matdeeplearn_available": True,
+             "torch_version": torch.__version__,
+             "gpu_available": gpu_available,
+             "gpu_count": gpu_count,
+             "gpu_name": gpu_name,
+             "available_models": [
+                 "CGCNN_demo", "MPNN_demo", "SchNet_demo",
+                 "MEGNet_demo", "GCN_demo", "SOAP_demo", "SM_demo"
+             ]
+         }
+     except Exception as e:
+         return {"success": False, "error": str(e)}
+
+
+ @mcp.tool(name="list_available_models", description="List all available GNN models in MatDeepLearn.")
+ def list_available_models() -> dict:
+     """
+     List all available Graph Neural Network models.
+
+     Returns:
+         dict: Contains list of available models with descriptions.
+     """
+     try:
+         models_info = {
+             "CGCNN_demo": {
+                 "name": "Crystal Graph Convolutional Neural Network",
+                 "description": "A GNN for predicting material properties using crystal graphs.",
+                 "paper": "Xie & Grossman, PRL 2018"
+             },
+             "MPNN_demo": {
+                 "name": "Message Passing Neural Network",
+                 "description": "General message passing framework for molecular graphs.",
+                 "paper": "Gilmer et al., ICML 2017"
+             },
+             "SchNet_demo": {
+                 "name": "SchNet",
+                 "description": "Continuous-filter convolutional neural network for modeling quantum interactions.",
+                 "paper": "Schütt et al., JCP 2017"
+             },
+             "MEGNet_demo": {
+                 "name": "MatErials Graph Network",
+                 "description": "Graph network with global state for materials property prediction.",
+                 "paper": "Chen et al., Chem. Mater. 2019"
+             },
+             "GCN_demo": {
+                 "name": "Graph Convolutional Network",
+                 "description": "Standard graph convolutional network architecture.",
+                 "paper": "Kipf & Welling, ICLR 2017"
+             },
+             "SOAP_demo": {
+                 "name": "Smooth Overlap of Atomic Positions",
+                 "description": "Descriptor-based method using SOAP features.",
+                 "paper": "Bartók et al., PRB 2013"
+             },
+             "SM_demo": {
+                 "name": "Sine Matrix",
+                 "description": "Descriptor-based method using Sine/Coulomb matrix features.",
+                 "paper": "Various"
+             }
+         }
+
+         return {
+             "success": True,
+             "models": models_info,
+             "total_models": len(models_info)
+         }
+     except Exception as e:
+         return {"success": False, "error": str(e)}
+
+
+ @mcp.tool(name="get_model_config", description="Get the default configuration for a specific model.")
+ def get_model_config(model_name: str) -> dict:
+     """
+     Get the default configuration for a specific GNN model.
+
+     Parameters:
+         model_name (str): Name of the model (e.g., 'CGCNN_demo', 'SchNet_demo').
+
+     Returns:
+         dict: Contains the default configuration for the model.
+     """
+     try:
+         config_path = os.path.join(project_root, "config.yml")
+
+         if not os.path.exists(config_path):
+             return {"success": False, "error": "Config file not found"}
+
+         with open(config_path, "r") as f:
+             config = yaml.load(f, Loader=yaml.FullLoader)
+
+         if model_name not in config.get("Models", {}):
+             return {
+                 "success": False,
+                 "error": f"Model '{model_name}' not found. Available models: {list(config.get('Models', {}).keys())}"
+             }
+
+         model_config = config["Models"][model_name]
+         processing_config = config.get("Processing", {})
+         training_config = config.get("Training", {})
+
+         return {
+             "success": True,
+             "model_name": model_name,
+             "model_config": model_config,
+             "processing_config": processing_config,
+             "training_config": training_config
+         }
+     except Exception as e:
+         return {"success": False, "error": str(e)}
+
+
+ @mcp.tool(name="process_structure_data", description="Process atomic structure data into graph format for GNN training.")
+ def process_structure_data(
+     data_path: str,
+     target_index: int = 0,
+     graph_max_radius: float = 8.0,
+     graph_max_neighbors: int = 12,
+     reprocess: bool = False
+ ) -> dict:
+     """
+     Process atomic structure data into graph format.
+
+     Parameters:
+         data_path (str): Path to directory containing structure files and targets.csv.
+         target_index (int): Index of target column in targets.csv (default: 0).
+         graph_max_radius (float): Maximum radius for edges in graph (default: 8.0).
+         graph_max_neighbors (int): Maximum number of neighbors per atom (default: 12).
+         reprocess (bool): Whether to reprocess data even if processed files exist.
+
+     Returns:
+         dict: Contains processing status and dataset information.
+     """
+     try:
+         if not MATDEEPLEARN_AVAILABLE:
+             return {"success": False, "error": "MatDeepLearn not available"}
+
+         if not os.path.exists(data_path):
+             return {"success": False, "error": f"Data path not found: {data_path}"}
+
+         processing_args = {
+             "dataset_type": "inmemory",
+             "data_path": data_path,
+             "target_path": "targets.csv",
+             "dictionary_source": "default",
+             "dictionary_path": "atom_dict.json",
+             "data_format": "json",
+             "verbose": "True",
+             "graph_max_radius": graph_max_radius,
+             "graph_max_neighbors": graph_max_neighbors,
+             "voronoi": "False",
+             "edge_features": "True",
+             "graph_edge_length": 50,
+             "SM_descriptor": "False",
+             "SOAP_descriptor": "False"
+         }
+
+         dataset = process.get_dataset(
+             data_path,
+             target_index,
+             "True" if reprocess else "False",
+             processing_args
+         )
+
+         return {
+             "success": True,
+             "dataset_size": len(dataset),
+             "sample_data": {
+                 "num_nodes": dataset[0].x.shape[0] if len(dataset) > 0 else 0,
+                 "num_node_features": dataset[0].x.shape[1] if len(dataset) > 0 else 0,
+                 "num_edges": dataset[0].edge_index.shape[1] if len(dataset) > 0 else 0
+             },
+             "data_path": data_path
+         }
+     except Exception as e:
+         return {"success": False, "error": str(e)}
+
+
+ @mcp.tool(name="train_model", description="Train a GNN model on processed structure data.")
+ def train_model(
+     data_path: str,
+     model_name: str = "CGCNN_demo",
+     epochs: int = 100,
+     batch_size: int = 32,
+     learning_rate: float = 0.002,
+     train_ratio: float = 0.8,
+     val_ratio: float = 0.1,
+     test_ratio: float = 0.1,
+     save_model: bool = True,
+     model_path: str = "trained_model.pth"
+ ) -> dict:
+     """
+     Train a GNN model on processed structure data.
+
+     Parameters:
+         data_path (str): Path to directory containing processed structure data.
+         model_name (str): Name of the model to train (default: 'CGCNN_demo').
+         epochs (int): Number of training epochs (default: 100).
+         batch_size (int): Training batch size (default: 32).
+         learning_rate (float): Learning rate (default: 0.002).
+         train_ratio (float): Ratio of data for training (default: 0.8).
+         val_ratio (float): Ratio of data for validation (default: 0.1).
+         test_ratio (float): Ratio of data for testing (default: 0.1).
+         save_model (bool): Whether to save the trained model (default: True).
+         model_path (str): Path to save the trained model (default: 'trained_model.pth').
+
+     Returns:
+         dict: Contains training results including train/val/test errors.
+     """
+     try:
+         if not MATDEEPLEARN_AVAILABLE:
+             return {"success": False, "error": "MatDeepLearn not available"}
+
+         if not os.path.exists(data_path):
+             return {"success": False, "error": f"Data path not found: {data_path}"}
+
+         # Load default config
+         config_path = os.path.join(project_root, "config.yml")
+         with open(config_path, "r") as f:
+             config = yaml.load(f, Loader=yaml.FullLoader)
+
+         if model_name not in config.get("Models", {}):
+             return {"success": False, "error": f"Model '{model_name}' not found"}
+
+         # Prepare configuration
+         job_config = {
+             "job_name": "mcp_train_job",
+             "reprocess": "False",
+             "model": model_name,
+             "load_model": "False",
+             "save_model": "True" if save_model else "False",
+             "model_path": model_path,
+             "write_output": "True",
+             "parallel": "False",
+             "seed": np.random.randint(1, 1e6)
+         }
+
+         training_config = {
+             "target_index": 0,
+             "loss": "l1_loss",
+             "train_ratio": train_ratio,
+             "val_ratio": val_ratio,
+             "test_ratio": test_ratio,
+             "verbosity": 5
+         }
+
+         model_config = config["Models"][model_name].copy()
+         model_config["epochs"] = epochs
+         model_config["batch_size"] = batch_size
+         model_config["lr"] = learning_rate
+
+         # Determine device
+         world_size = torch.cuda.device_count()
+         if world_size == 0:
+             rank = "cpu"
+         else:
+             rank = "cuda"
+
+         # Train model
+         error_values = training.train_regular(
+             rank,
+             world_size,
+             data_path,
+             job_config,
+             training_config,
+             model_config
+         )
+
+         return {
+             "success": True,
+             "model_name": model_name,
+             "epochs": epochs,
+             "train_error": float(error_values[0]) if error_values is not None else None,
+             "val_error": float(error_values[1]) if error_values is not None else None,
+             "test_error": float(error_values[2]) if error_values is not None else None,
+             "model_saved": save_model,
+             "model_path": model_path if save_model else None
+         }
+     except Exception as e:
+         return {"success": False, "error": str(e)}
+
+
+ @mcp.tool(name="predict_properties", description="Use a trained model to predict properties of new structures.")
+ def predict_properties(
+     data_path: str,
+     model_path: str,
+     target_index: int = 0
+ ) -> dict:
+     """
+     Use a trained model to predict properties of new structures.
+
+     Parameters:
+         data_path (str): Path to directory containing structure files to predict.
+         model_path (str): Path to the trained model file (.pth).
+         target_index (int): Index of target column (default: 0).
+
+     Returns:
+         dict: Contains predictions and error metrics.
+     """
+     try:
+         if not MATDEEPLEARN_AVAILABLE:
+             return {"success": False, "error": "MatDeepLearn not available"}
+
+         if not os.path.exists(data_path):
+             return {"success": False, "error": f"Data path not found: {data_path}"}
+
+         if not os.path.exists(model_path):
+             return {"success": False, "error": f"Model file not found: {model_path}"}
+
+         # Get dataset
+         dataset = process.get_dataset(data_path, target_index, "False")
+
+         job_config = {
+             "job_name": "mcp_predict_job",
+             "model_path": model_path,
+             "write_output": "True"
+         }
+
+         # Run prediction
+         test_error = training.predict(dataset, "l1_loss", job_config)
+
+         return {
+             "success": True,
+             "dataset_size": len(dataset),
+             "test_error": float(test_error),
+             "output_file": "mcp_predict_job_predicted_outputs.csv"
+         }
+     except Exception as e:
+         return {"success": False, "error": str(e)}
+
+
+ @mcp.tool(name="cross_validation", description="Perform k-fold cross validation on a dataset.")
+ def cross_validation(
+     data_path: str,
+     model_name: str = "CGCNN_demo",
+     cv_folds: int = 5,
+     epochs: int = 100
+ ) -> dict:
+     """
+     Perform k-fold cross validation on a dataset.
+
+     Parameters:
+         data_path (str): Path to directory containing structure data.
+         model_name (str): Name of the model to use (default: 'CGCNN_demo').
+         cv_folds (int): Number of cross-validation folds (default: 5).
+         epochs (int): Number of training epochs per fold (default: 100).
+
+     Returns:
+         dict: Contains cross-validation results.
+     """
+     try:
+         if not MATDEEPLEARN_AVAILABLE:
+             return {"success": False, "error": "MatDeepLearn not available"}
+
+         if not os.path.exists(data_path):
+             return {"success": False, "error": f"Data path not found: {data_path}"}
+
+         # Load config
+         config_path = os.path.join(project_root, "config.yml")
+         with open(config_path, "r") as f:
+             config = yaml.load(f, Loader=yaml.FullLoader)
+
+         if model_name not in config.get("Models", {}):
+             return {"success": False, "error": f"Model '{model_name}' not found"}
+
+         job_config = {
+             "job_name": "mcp_cv_job",
+             "reprocess": "False",
+             "model": model_name,
+             "cv_folds": cv_folds,
+             "write_output": "True",
+             "parallel": "False",
+             "seed": np.random.randint(1, 1e6)
+         }
+
+         training_config = {
+             "target_index": 0,
+             "loss": "l1_loss",
+             "verbosity": 5
+         }
+
+         model_config = config["Models"][model_name].copy()
+         model_config["epochs"] = epochs
+
+         world_size = torch.cuda.device_count()
+         rank = "cpu" if world_size == 0 else "cuda"
+
+         cv_error = training.train_CV(
+             rank,
+             world_size,
+             data_path,
+             job_config,
+             training_config,
+             model_config
+         )
+
+         return {
+             "success": True,
+             "model_name": model_name,
+             "cv_folds": cv_folds,
+             "cv_error": float(cv_error) if cv_error is not None else None,
+             "output_file": "mcp_cv_job_CV_outputs.csv"
+         }
+     except Exception as e:
+         return {"success": False, "error": str(e)}
+
+
+ @mcp.tool(name="analyze_structure", description="Analyze the structure of atomic data and convert to graph representation info.")
+ def analyze_structure(structure_file: str) -> dict:
+     """
+     Analyze the structure of an atomic structure file.
+
+     Parameters:
+         structure_file (str): Path to a structure file (json, cif, xyz, POSCAR, etc.).
+
+     Returns:
+         dict: Contains structure analysis including atoms, bonds, and graph info.
+     """
+     try:
+         if not os.path.exists(structure_file):
+             return {"success": False, "error": f"Structure file not found: {structure_file}"}
+
+         import ase
+         from ase import io
+
+         # Read structure
+         structure = ase.io.read(structure_file)
+
+         # Get basic info
+         symbols = structure.get_chemical_symbols()
+         positions = structure.get_positions().tolist()
+         cell = structure.get_cell().tolist() if any(structure.pbc) else None
+         pbc = structure.pbc.tolist()
+
+         # Get distance matrix
+         distance_matrix = structure.get_all_distances(mic=True)
+
+         # Analyze connectivity
+         cutoff_radius = 8.0
+         neighbors_count = []
+         for i in range(len(structure)):
+             neighbors = np.sum((distance_matrix[i] > 0) & (distance_matrix[i] < cutoff_radius))
+             neighbors_count.append(int(neighbors))
+
+         return {
+             "success": True,
+             "num_atoms": len(structure),
+             "chemical_formula": structure.get_chemical_formula(),
+             "elements": list(set(symbols)),
+             "element_counts": {elem: symbols.count(elem) for elem in set(symbols)},
+             "has_periodicity": any(pbc),
+             "pbc": pbc,
+             "cell": cell,
+             "average_neighbors": float(np.mean(neighbors_count)),
+             "min_neighbors": min(neighbors_count),
+             "max_neighbors": max(neighbors_count),
+             "min_distance": float(distance_matrix[distance_matrix > 0].min()),
+             "max_distance": float(distance_matrix.max())
+         }
+     except Exception as e:
+         return {"success": False, "error": str(e)}
+
+
+ @mcp.tool(name="compare_models", description="Compare performance of different GNN models on a dataset.")
+ def compare_models(
+     data_path: str,
+     model_list: Optional[List[str]] = None,
+     epochs: int = 50
+ ) -> dict:
+     """
+     Compare performance of different GNN models on a dataset.
+
+     Parameters:
+         data_path (str): Path to directory containing structure data.
+         model_list (List[str]): List of models to compare
+             (default: CGCNN_demo, GCN_demo, SchNet_demo).
+         epochs (int): Number of training epochs per model (default: 50).
+
+     Returns:
+         dict: Contains comparison results for each model.
+     """
+     try:
+         if not MATDEEPLEARN_AVAILABLE:
+             return {"success": False, "error": "MatDeepLearn not available"}
+
+         if not os.path.exists(data_path):
+             return {"success": False, "error": f"Data path not found: {data_path}"}
+
+         if model_list is None:
+             model_list = ["CGCNN_demo", "GCN_demo", "SchNet_demo"]
+
+         results = {}
+
+         for model_name in model_list:
+             try:
+                 result = train_model(
+                     data_path=data_path,
+                     model_name=model_name,
+                     epochs=epochs,
+                     save_model=False
+                 )
+
+                 if result["success"]:
+                     results[model_name] = {
+                         "train_error": result["train_error"],
+                         "val_error": result["val_error"],
+                         "test_error": result["test_error"]
+                     }
+                 else:
+                     results[model_name] = {"error": result["error"]}
+             except Exception as e:
+                 results[model_name] = {"error": str(e)}
+
+         # Find best model (lowest test error among the runs that succeeded)
+         best_model = None
+         best_error = float("inf")
+         for model, res in results.items():
+             if "test_error" in res and res["test_error"] is not None:
+                 if res["test_error"] < best_error:
+                     best_error = res["test_error"]
+                     best_model = model
+
+         return {
+             "success": True,
+             "results": results,
+             "best_model": best_model,
+             "best_test_error": best_error if best_model else None
+         }
+     except Exception as e:
+         return {"success": False, "error": str(e)}
+
+
+ @mcp.tool(name="get_dataset_info", description="Get information about a dataset directory.")
+ def get_dataset_info(data_path: str) -> dict:
+     """
+     Get information about a dataset directory.
+
+     Parameters:
+         data_path (str): Path to directory containing structure data.
+
+     Returns:
+         dict: Contains dataset information including file counts and formats.
+     """
+     try:
+         if not os.path.exists(data_path):
+             return {"success": False, "error": f"Data path not found: {data_path}"}
+
+         # Count files by extension
+         extensions = {}
+         for file in os.listdir(data_path):
+             ext = os.path.splitext(file)[1].lower()
+             extensions[ext] = extensions.get(ext, 0) + 1
+
+         # Check for required files
+         has_targets = os.path.exists(os.path.join(data_path, "targets.csv"))
+         has_atom_dict = os.path.exists(os.path.join(data_path, "atom_dict.json"))
+         has_processed = os.path.exists(os.path.join(data_path, "processed"))
+
+         # Read targets if available
+         num_samples = 0
+         if has_targets:
+             import csv
+             with open(os.path.join(data_path, "targets.csv")) as f:
+                 num_samples = sum(1 for _ in csv.reader(f))
+
+         return {
+             "success": True,
+             "data_path": data_path,
+             "file_extensions": extensions,
+             "has_targets_csv": has_targets,
+             "has_atom_dict": has_atom_dict,
+             "has_processed_data": has_processed,
+             "num_samples": num_samples,
+             "ready_for_training": has_targets
+         }
+     except Exception as e:
+         return {"success": False, "error": str(e)}
+
+
+ def create_app() -> FastMCP:
+     """
+     Creates and returns the FastMCP application instance.
+
+     Returns:
+         FastMCP: The FastMCP application instance.
+     """
+     return mcp
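Every tool in this file repeats the same `try`/`except` block that turns an exception into `{"success": False, "error": ...}`. One possible way to factor that out is a small decorator. This is only a sketch of the idea, not what the commit does: `safe_tool` and `count_rows` are hypothetical names, and how such a wrapper interacts with `@mcp.tool` would need to be checked against the FastMCP version in use.

```python
import functools


def safe_tool(func):
    """Wrap a tool so uncaught exceptions become {"success": False, "error": ...}."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
        except Exception as exc:  # mirror the broad catch used by the tools above
            return {"success": False, "error": str(exc)}
        # Add the success flag unless the tool already set one.
        return {"success": True, **result} if "success" not in result else result
    return wrapper


@safe_tool
def count_rows(rows: list) -> dict:
    """Toy tool used only to exercise the wrapper."""
    return {"num_rows": len(rows)}


ok = count_rows([1, 2, 3])
failed = count_rows(None)  # len(None) raises TypeError
print(ok)
print(failed)
```

The trade-off is that early-return validation (for example the "Data path not found" checks) would still live inside each tool; the decorator only removes the outer catch-all.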
mcp_output/requirements.txt ADDED
@@ -0,0 +1,43 @@
+ # MatDeepLearn MCP Service Requirements
+
+ # Core MCP Framework
+ fastmcp>=0.1.0
+
+ # PyTorch - CPU version for HuggingFace Space (lighter weight)
+ --extra-index-url https://download.pytorch.org/whl/cpu
+ torch>=2.0.0
+
+ # PyTorch Geometric
+ torch-scatter
+ torch-sparse
+ torch-cluster
+ torch-spline-conv
+ torch-geometric>=2.0.0
+
+ # Scientific Computing
+ numpy>=1.20.0
+ scipy>=1.6.0
+ scikit-learn>=0.24.0
+
+ # Materials Science
+ ase>=3.20.0
+ pymatgen>=2022.0.0
+
+ # Descriptors (optional, for SOAP/SM models)
+ dscribe>=1.0.0
+
+ # Configuration
+ pyyaml>=5.4.0
+
+ # Visualization (optional)
+ matplotlib>=3.1.0
+
+ # Hyperparameter Optimization (optional)
+ ray[tune]>=2.0.0
+
+ # Utilities
+ joblib>=0.13.0
+
+ # HTTP Server
+ uvicorn>=0.20.0
+ starlette>=0.25.0
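After installing these requirements, a quick importability check can confirm that the core packages resolved. The sketch below is illustrative: `missing_modules` is a hypothetical helper, and the list uses import names, which differ from the package names for some entries (torch-geometric imports as `torch_geometric`, scikit-learn as `sklearn`, pyyaml as `yaml`).

```python
import importlib.util


def missing_modules(module_names: list) -> list:
    """Return the top-level module names that cannot be located in this environment."""
    return [name for name in module_names if importlib.util.find_spec(name) is None]


# Import names for the core (non-optional) requirements.
required = ["torch", "torch_geometric", "ase", "pymatgen", "fastmcp", "yaml", "sklearn"]
print("Missing:", missing_modules(required))
```

`find_spec` only locates a module without importing it, so the check stays fast even for heavy packages such as PyTorch.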
mcp_output/start_mcp.py ADDED
@@ -0,0 +1,44 @@
1
+ """
2
+ MatDeepLearn MCP Service Startup Entry
3
+ """
4
+ import sys
5
+ import os
6
+
7
+ # Add project root to path
8
+ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
9
+ matdeeplearn_root = os.path.dirname(project_root)
10
+
11
+ # Add paths
12
+ if matdeeplearn_root not in sys.path:
13
+ sys.path.insert(0, matdeeplearn_root)
14
+
15
+ mcp_plugin_dir = os.path.join(project_root, "mcp_plugin")
16
+ if mcp_plugin_dir not in sys.path:
17
+ sys.path.insert(0, mcp_plugin_dir)
18
+
19
+ from mcp_service import create_app
20
+
21
+
22
+ def main():
23
+ """Start FastMCP service"""
24
+ app = create_app()
25
+
26
+ # Use environment variable to configure port, default 7860 (HuggingFace default)
27
+ port = int(os.environ.get("MCP_PORT", "7860"))
28
+
29
+ # Choose transport mode based on environment variable
30
+ transport = os.environ.get("MCP_TRANSPORT", "stdio")
31
+
32
+ print("Starting MatDeepLearn MCP Service...", file=sys.stderr)
33
+ print(f"Transport: {transport}", file=sys.stderr)
34
+
35
+ if transport == "http":
36
+ print(f"Port: {port}")
37
+ app.run(transport="http", host="0.0.0.0", port=port)
38
+ else:
39
+ # Default to STDIO mode
40
+ app.run()
41
+
42
+
43
+ if __name__ == "__main__":
44
+ main()
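Review note: `main()` reads `MCP_TRANSPORT` and `MCP_PORT` without validation, so a non-numeric port surfaces as a bare `int()` traceback and any unknown transport silently falls back to STDIO. A minimal sketch of a stricter helper follows. The name `resolve_transport` is hypothetical and not part of this commit, and rejecting unknown transports is a design choice that differs from the script's current fallback behavior.

```python
from typing import Mapping, Tuple


def resolve_transport(env: Mapping[str, str]) -> Tuple[str, int]:
    """Hypothetical helper: validate the two variables start_mcp.py reads."""
    # Same default as the script: STDIO unless told otherwise.
    transport = env.get("MCP_TRANSPORT", "stdio").strip().lower()
    if transport not in ("stdio", "http"):
        raise ValueError(f"Unsupported MCP_TRANSPORT: {transport!r}")

    # Same default port as the script (7860).
    raw_port = env.get("MCP_PORT", "7860")
    try:
        port = int(raw_port)
    except ValueError:
        raise ValueError(f"MCP_PORT must be an integer, got {raw_port!r}") from None
    if not 1 <= port <= 65535:
        raise ValueError(f"MCP_PORT out of range: {port}")

    return transport, port
```

Assuming this helper were adopted, `main()` could call `resolve_transport(os.environ)` once and branch on the returned transport.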
mcp_output/test_mcp_service.py ADDED
@@ -0,0 +1,345 @@
1
+ """
2
+ MatDeepLearn MCP Service Test Script
3
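+ Checks that each feature of the MCP service works correctly.
4
+
5
+ Tests the underlying functions directly, without going through the MCP decorators.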
6
+ """
7
+
8
+ import sys
9
+ import os
10
+ import json
11
+
12
+ # Add the project root to the import path
13
+ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
14
+ if project_root not in sys.path:
15
+ sys.path.insert(0, project_root)
16
+
17
+ mcp_plugin_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "mcp_plugin")
18
+ if mcp_plugin_dir not in sys.path:
19
+ sys.path.insert(0, mcp_plugin_dir)
20
+
21
+
22
+ def print_result(test_name: str, result: dict):
23
+ """Print a test result and return whether it passed."""
24
+ status = "✅ PASS" if result.get("success", False) else "❌ FAIL"
25
+ print(f"\n{'='*60}")
26
+ print(f"Test: {test_name}")
27
+ print(f"Status: {status}")
28
+ # Pretty-print the result
29
+ print(f"Result: {json.dumps(result, indent=2, ensure_ascii=False, default=str)}")
30
+ print(f"{'='*60}")
31
+ return result.get("success", False)
32
+
33
+
34
+ # ============== Test functions defined directly (core logic copied) ==============
35
+
36
+ def test_check_environment() -> dict:
37
+ """Check the environment configuration."""
38
+ result = {
39
+ "success": True,
40
+ "torch_available": False,
41
+ "torch_geometric_available": False,
42
+ "matdeeplearn_available": False,
43
+ "gpu_available": False,
44
+ "gpu_count": 0,
45
+ "gpu_name": "N/A",
46
+ "available_models": [
47
+ "CGCNN_demo", "MPNN_demo", "SchNet_demo",
48
+ "MEGNet_demo", "GCN_demo", "SOAP_demo", "SM_demo"
49
+ ]
50
+ }
51
+
52
+ # Check PyTorch
53
+ try:
54
+ import torch
55
+ result["torch_available"] = True
56
+ result["torch_version"] = torch.__version__
57
+ result["gpu_available"] = torch.cuda.is_available()
58
+ result["gpu_count"] = torch.cuda.device_count() if result["gpu_available"] else 0
59
+ result["gpu_name"] = torch.cuda.get_device_name(0) if result["gpu_available"] else "N/A"
60
+ except ImportError:
61
+ result["torch_version"] = "not installed"
62
+
63
+ # Check PyTorch Geometric
64
+ try:
65
+ import torch_geometric
66
+ result["torch_geometric_available"] = True
67
+ result["torch_geometric_version"] = torch_geometric.__version__
68
+ except ImportError:
69
+ result["torch_geometric_version"] = "not installed"
70
+
71
+ # Check MatDeepLearn
72
+ try:
73
+ from matdeeplearn import models, process, training
74
+ result["matdeeplearn_available"] = True
75
+ except ImportError as e:
76
+ result["matdeeplearn_error"] = str(e)
77
+
78
+ # Mark success when PyTorch, the core dependency, is available
79
+ if result["torch_available"]:
80
+ result["success"] = True
81
+ if not result["torch_geometric_available"]:
82
+ result["warning"] = "torch_geometric is not installed; some features are unavailable"
83
+ else:
84
+ result["success"] = False
85
+ result["error"] = "PyTorch is not installed"
86
+
87
+ return result
88
+
89
+
90
+ def test_list_available_models() -> dict:
91
+ """List the available models."""
92
+ models_info = {
93
+ "CGCNN_demo": {
94
+ "name": "Crystal Graph Convolutional Neural Network",
95
+ "description": "A GNN for predicting material properties using crystal graphs."
96
+ },
97
+ "MPNN_demo": {
98
+ "name": "Message Passing Neural Network",
99
+ "description": "General message passing framework for molecular graphs."
100
+ },
101
+ "SchNet_demo": {
102
+ "name": "SchNet",
103
+ "description": "Continuous-filter convolutional neural network."
104
+ },
105
+ "MEGNet_demo": {
106
+ "name": "MatErials Graph Network",
107
+ "description": "Graph network with global state for materials."
108
+ },
109
+ "GCN_demo": {
110
+ "name": "Graph Convolutional Network",
111
+ "description": "Standard graph convolutional network."
112
+ },
113
+ "SOAP_demo": {
114
+ "name": "Smooth Overlap of Atomic Positions",
115
+ "description": "Descriptor-based method using SOAP features."
116
+ },
117
+ "SM_demo": {
118
+ "name": "Sine Matrix",
119
+ "description": "Descriptor-based method using Sine/Coulomb matrix."
120
+ }
121
+ }
122
+ return {"success": True, "models": models_info, "total_models": len(models_info)}
123
+
124
+
125
+ def test_get_model_config(model_name: str) -> dict:
126
+ """Get the configuration for a model."""
127
+ import yaml
128
+
129
+ config_path = os.path.join(project_root, "config.yml")
130
+ if not os.path.exists(config_path):
131
+ return {"success": False, "error": "Config file not found"}
132
+
133
+ with open(config_path, "r") as f:
134
+ config = yaml.safe_load(f)
135
+
136
+ if model_name not in config.get("Models", {}):
137
+ return {"success": False, "error": f"Model '{model_name}' not found"}
138
+
139
+ return {
140
+ "success": True,
141
+ "model_name": model_name,
142
+ "model_config": config["Models"][model_name]
143
+ }
144
+
145
+
146
+ def test_get_dataset_info(data_path: str) -> dict:
147
+ """Get information about a dataset directory."""
148
+ import csv
149
+
150
+ if not os.path.exists(data_path):
151
+ return {"success": False, "error": f"Data path not found: {data_path}"}
152
+
153
+ extensions = {}
154
+ for f in os.listdir(data_path):
155
+ ext = os.path.splitext(f)[1].lower()
156
+ extensions[ext] = extensions.get(ext, 0) + 1
157
+
158
+ has_targets = os.path.exists(os.path.join(data_path, "targets.csv"))
159
+ has_processed = os.path.exists(os.path.join(data_path, "processed"))
160
+
161
+ num_samples = 0
162
+ if has_targets:
163
+ with open(os.path.join(data_path, "targets.csv")) as f:
164
+ num_samples = sum(1 for _ in csv.reader(f))
165
+
166
+ return {
167
+ "success": True,
168
+ "data_path": data_path,
169
+ "file_extensions": extensions,
170
+ "has_targets_csv": has_targets,
171
+ "has_processed_data": has_processed,
172
+ "num_samples": num_samples
173
+ }
174
+
175
+
176
+ def test_analyze_structure(structure_file: str) -> dict:
177
+ """Analyze a structure file."""
178
+ import numpy as np
179
+ import ase
180
+ from ase import io
181
+
182
+ if not os.path.exists(structure_file):
183
+ return {"success": False, "error": f"File not found: {structure_file}"}
184
+
185
+ structure = ase.io.read(structure_file)
186
+ symbols = structure.get_chemical_symbols()
187
+ distance_matrix = structure.get_all_distances(mic=True)
188
+
189
+ cutoff_radius = 8.0
190
+ neighbors_count = []
191
+ for i in range(len(structure)):
192
+ neighbors = np.sum((distance_matrix[i] > 0) & (distance_matrix[i] < cutoff_radius))
193
+ neighbors_count.append(int(neighbors))
194
+
195
+ return {
196
+ "success": True,
197
+ "num_atoms": len(structure),
198
+ "chemical_formula": structure.get_chemical_formula(),
199
+ "elements": list(set(symbols)),
200
+ "has_periodicity": any(structure.pbc),
201
+ "average_neighbors": float(np.mean(neighbors_count))
202
+ }
203
+
204
+
205
+ def run_tests():
206
+ """Run all tests."""
207
+ print("\n" + "="*60)
208
+ print("MatDeepLearn MCP Service tests")
209
+ print("="*60)
210
+
211
+ passed = 0
212
+ failed = 0
213
+
214
+ # Test 1: check the environment
215
+ print("\n[Test 1/5] Checking environment configuration...")
216
+ result = test_check_environment()
217
+ if print_result("check_environment", result):
218
+ passed += 1
219
+ if result.get("gpu_available"):
220
+ print(f" GPU: {result.get('gpu_name')} (count: {result.get('gpu_count')})")
221
+ else:
222
+ print(" GPU: not available (falling back to CPU)")
223
+ print(f" PyTorch version: {result.get('torch_version')}")
224
+ else:
225
+ failed += 1
226
+
227
+ # Test 2: list the available models
228
+ print("\n[Test 2/5] Listing available models...")
229
+ result = test_list_available_models()
230
+ if print_result("list_available_models", result):
231
+ passed += 1
232
+ print(f" Available models: {result.get('total_models')}")
233
+ for name, info in result.get("models", {}).items():
234
+ print(f" - {name}: {info.get('name')}")
235
+ else:
236
+ failed += 1
237
+
238
+ # Test 3: get a model configuration
239
+ print("\n[Test 3/5] Getting the CGCNN_demo model configuration...")
240
+ result = test_get_model_config("CGCNN_demo")
241
+ if print_result("get_model_config", result):
242
+ passed += 1
243
+ config = result.get("model_config", {})
244
+ print(f" Model type: {config.get('model')}")
245
+ print(f" Epochs: {config.get('epochs')}")
246
+ print(f" Batch Size: {config.get('batch_size')}")
247
+ print(f" Learning Rate: {config.get('lr')}")
248
+ else:
249
+ failed += 1
250
+
251
+ # Test 4: get dataset info (uses test_data if present)
252
+ print("\n[Test 4/5] Getting dataset info...")
253
+ test_data_path = os.path.join(project_root, "data", "test_data", "test_data")
254
+ if os.path.exists(test_data_path):
255
+ result = test_get_dataset_info(test_data_path)
256
+ if print_result("get_dataset_info", result):
257
+ passed += 1
258
+ print(f" Data path: {result.get('data_path')}")
259
+ print(f" Samples: {result.get('num_samples')}")
260
+ print(f" Processed: {result.get('has_processed_data')}")
261
+ else:
262
+ failed += 1
263
+ else:
264
+ # Fall back to inspecting the data directory
265
+ data_path = os.path.join(project_root, "data")
266
+ result = test_get_dataset_info(data_path)
267
+ if result.get("success"):
268
+ print_result("get_dataset_info (data directory)", result)
269
+ passed += 1
270
+ else:
271
+ print(f"⚠️ Skipped: test data directory not found ({test_data_path})")
272
+ print(" Hint: extract data/test_data.tar.gz to run the full test")
273
+ passed += 1 # a skip does not count as a failure
274
+
275
+ # Test 5: request a nonexistent model configuration (error handling)
276
+ print("\n[Test 5/5] Testing error handling (nonexistent model)...")
277
+ result = test_get_model_config("NonExistentModel")
278
+ if not result.get("success"):
279
+ print(f"✅ Error handling works: {result.get('error')}")
280
+ passed += 1
281
+ else:
282
+ print("❌ Error handling failed: an error should have been returned")
283
+ failed += 1
284
+
285
+ # Summary
286
+ print("\n" + "="*60)
287
+ print("Test summary")
288
+ print("="*60)
289
+ print(f"Passed: {passed}")
290
+ print(f"Failed: {failed}")
291
+ print(f"Total: {passed + failed}")
292
+ print("="*60)
293
+
294
+ if failed == 0:
295
+ print("\n🎉 All tests passed! The MCP service is ready.")
296
+ print("\nNext steps:")
297
+ print(" 1. Run locally: python mcp_output/start_mcp.py")
298
+ print(" 2. HTTP mode: MCP_TRANSPORT=http python mcp_output/start_mcp.py")
299
+ print(" 3. Deploy to a HuggingFace Space")
300
+ return True
301
+ else:
302
+ print(f"\n⚠️ {failed} test(s) failed; check the error messages above.")
303
+ return False
304
+
305
+
306
+ def run_structure_analysis_test():
307
+ """Test structure analysis (when test data is available)."""
308
+ print("\n" + "="*60)
309
+ print("Extra test: structure analysis")
310
+ print("="*60)
311
+
312
+ # Look for a usable structure file
313
+ test_data_path = os.path.join(project_root, "data", "test_data", "test_data")
314
+
315
+ if os.path.exists(test_data_path):
316
+ # Find the first JSON file
317
+ for f in os.listdir(test_data_path):
318
+ if f.endswith('.json') and f != 'atom_dict.json':
319
+ structure_file = os.path.join(test_data_path, f)
320
+ print(f"\nAnalyzing structure file: {f}")
321
+ result = test_analyze_structure(structure_file)
322
+ if result.get("success"):
323
+ print(f" Formula: {result.get('chemical_formula')}")
324
+ print(f" Atoms: {result.get('num_atoms')}")
325
+ print(f" Elements: {result.get('elements')}")
326
+ print(f" Periodic: {result.get('has_periodicity')}")
327
+ print(f" Average neighbors: {result.get('average_neighbors'):.2f}")
328
+ else:
329
+ print(f" Error: {result.get('error')}")
330
+ break
331
+ else:
332
+ print("⚠️ Test data unavailable; skipping the structure analysis test")
333
+
334
+
335
+ if __name__ == "__main__":
336
+ success = run_tests()
337
+
338
+ # If the basic tests pass, try the structure analysis test
339
+ if success:
340
+ try:
341
+ run_structure_analysis_test()
342
+ except Exception as e:
343
+ print(f"\nStructure analysis test raised an error: {e}")
344
+
345
+ sys.exit(0 if success else 1)
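Review note: `test_get_dataset_info` counts every directory entry, so a subdirectory such as `processed` is tallied under the empty extension. A minimal sketch of a file-only count follows, using just the standard library. The name `count_extensions` is hypothetical and not part of this commit.

```python
import os
from collections import Counter


def count_extensions(data_path: str) -> dict:
    """Hypothetical helper: count regular files per lowercase extension."""
    counts = Counter()
    with os.scandir(data_path) as entries:
        for entry in entries:
            # Skip subdirectories (e.g. "processed") so they are not
            # reported as files with an empty extension.
            if entry.is_file():
                ext = os.path.splitext(entry.name)[1].lower()
                counts[ext] += 1
    return dict(counts)
```

Whether directories should be excluded depends on what callers expect from `file_extensions`; the sketch assumes only regular files are of interest.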