# MatDeepLearn / mcp_output / test_mcp_service.py
# Uploaded by SEUyishu ("Upload 9 files"), commit 778fec6 (verified)
"""
MatDeepLearn MCP Service Test Script
测试 MCP 服务的各个功能是否正常工作
直接测试底层函数,不通过 MCP 装饰器
"""
import sys
import os
import json
# Make the repository root and the bundled ``mcp_plugin`` directory importable.
# Each path is prepended only if absent; because both go to index 0, the
# plugin directory ends up ahead of the project root.
_this_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(_this_dir)
mcp_plugin_dir = os.path.join(_this_dir, "mcp_plugin")
for _extra_path in (project_root, mcp_plugin_dir):
    if _extra_path not in sys.path:
        sys.path.insert(0, _extra_path)
def print_result(test_name: str, result: dict):
    """Print a framed report for one test and return its ``success`` value.

    Args:
        test_name: label shown in the report header.
        result: the dict returned by a test function; it is dumped as
            pretty-printed JSON (non-JSON values are rendered with ``str``).

    Returns:
        ``result["success"]``, or False when that key is absent.
    """
    ok = result.get("success", False)
    separator = "=" * 60
    verdict = "✅ PASS" if ok else "❌ FAIL"
    print(f"\n{separator}")
    print(f"测试: {test_name}")
    print(f"状态: {verdict}")
    # Pretty-print the full payload so failures can be diagnosed from the log.
    rendered = json.dumps(result, indent=2, ensure_ascii=False, default=str)
    print(f"结果: {rendered}")
    print(separator)
    return ok
# ============== 直接定义测试函数(复制核心逻辑)==============
def test_check_environment() -> dict:
    """Report which runtime dependencies of the MCP service can be imported.

    Returns:
        dict containing:
          - ``torch_available`` / ``torch_version`` (plus ``torch_error`` when
            the import failed);
          - GPU details ``gpu_available``, ``gpu_count``, ``gpu_name`` (only
            probed when PyTorch imports);
          - ``torch_geometric_available`` / ``torch_geometric_version`` (plus
            ``torch_geometric_error`` on failure);
          - ``matdeeplearn_available`` (plus ``matdeeplearn_error`` on failure);
          - ``available_models``: the fixed list of demo model names.
        ``success`` is True exactly when PyTorch imported; a missing
        torch_geometric only adds a ``warning`` and MatDeepLearn availability
        does not affect it. When PyTorch is missing, ``error`` is set.
    """
    result = {
        "success": True,
        "torch_available": False,
        "torch_geometric_available": False,
        "matdeeplearn_available": False,
        "gpu_available": False,
        "gpu_count": 0,
        "gpu_name": "N/A",
        "available_models": [
            "CGCNN_demo", "MPNN_demo", "SchNet_demo",
            "MEGNet_demo", "GCN_demo", "SOAP_demo", "SM_demo"
        ]
    }

    # A compiled extension that fails to load (e.g. a dlopen/ABI mismatch) can
    # raise OSError rather than ImportError. Treat both as "not available" so
    # this diagnostic reports the problem instead of crashing.
    probe_errors = (ImportError, OSError)

    # PyTorch (required)
    try:
        import torch
    except probe_errors as exc:
        result["torch_version"] = "未安装"
        result["torch_error"] = str(exc)
    else:
        result["torch_available"] = True
        result["torch_version"] = torch.__version__
        result["gpu_available"] = torch.cuda.is_available()
        result["gpu_count"] = torch.cuda.device_count() if result["gpu_available"] else 0
        result["gpu_name"] = torch.cuda.get_device_name(0) if result["gpu_available"] else "N/A"

    # PyTorch Geometric (optional; only produces a warning when missing)
    try:
        import torch_geometric
    except probe_errors as exc:
        result["torch_geometric_version"] = "未安装"
        result["torch_geometric_error"] = str(exc)
    else:
        result["torch_geometric_available"] = True
        result["torch_geometric_version"] = torch_geometric.__version__

    # MatDeepLearn: imported only to prove the package loads; names are unused.
    try:
        from matdeeplearn import models, process, training  # noqa: F401
    except probe_errors as exc:
        result["matdeeplearn_error"] = str(exc)
    else:
        result["matdeeplearn_available"] = True

    # Overall verdict hinges on PyTorch alone.
    if result["torch_available"]:
        result["success"] = True
        if not result["torch_geometric_available"]:
            result["warning"] = "torch_geometric 未安装,部分功能不可用"
    else:
        result["success"] = False
        result["error"] = "PyTorch 未安装"
    return result
def test_list_available_models() -> dict:
    """Describe the demo models bundled with MatDeepLearn.

    Returns:
        dict with ``success`` (always True), ``models`` mapping each demo-model
        key to its full ``name`` and a one-line ``description``, and
        ``total_models`` (the number of entries in ``models``).
    """
    # (key, full name, description), kept in display order.
    catalogue = [
        ("CGCNN_demo", "Crystal Graph Convolutional Neural Network",
         "A GNN for predicting material properties using crystal graphs."),
        ("MPNN_demo", "Message Passing Neural Network",
         "General message passing framework for molecular graphs."),
        ("SchNet_demo", "SchNet",
         "Continuous-filter convolutional neural network."),
        ("MEGNet_demo", "MatErials Graph Network",
         "Graph network with global state for materials."),
        ("GCN_demo", "Graph Convolutional Network",
         "Standard graph convolutional network."),
        ("SOAP_demo", "Smooth Overlap of Atomic Positions",
         "Descriptor-based method using SOAP features."),
        ("SM_demo", "Sine Matrix",
         "Descriptor-based method using Sine/Coulomb matrix."),
    ]
    models_info = {
        key: {"name": full_name, "description": blurb}
        for key, full_name, blurb in catalogue
    }
    return {
        "success": True,
        "models": models_info,
        "total_models": len(models_info),
    }
def test_get_model_config(model_name: str) -> dict:
"""获取模型配置"""
import yaml
config_path = os.path.join(project_root, "config.yml")
if not os.path.exists(config_path):
return {"success": False, "error": "Config file not found"}
with open(config_path, "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
if model_name not in config.get("Models", {}):
return {"success": False, "error": f"Model '{model_name}' not found"}
return {
"success": True,
"model_name": model_name,
"model_config": config["Models"][model_name]
}
def test_get_dataset_info(data_path: str) -> dict:
"""获取数据集信息"""
import csv
if not os.path.exists(data_path):
return {"success": False, "error": f"Data path not found: {data_path}"}
extensions = {}
for f in os.listdir(data_path):
ext = os.path.splitext(f)[1].lower()
extensions[ext] = extensions.get(ext, 0) + 1
has_targets = os.path.exists(os.path.join(data_path, "targets.csv"))
has_processed = os.path.exists(os.path.join(data_path, "processed"))
num_samples = 0
if has_targets:
with open(os.path.join(data_path, "targets.csv")) as f:
num_samples = sum(1 for _ in csv.reader(f))
return {
"success": True,
"data_path": data_path,
"file_extensions": extensions,
"has_targets_csv": has_targets,
"has_processed_data": has_processed,
"num_samples": num_samples
}
def test_analyze_structure(structure_file: str) -> dict:
"""分析结构文件"""
import numpy as np
import ase
from ase import io
if not os.path.exists(structure_file):
return {"success": False, "error": f"File not found: {structure_file}"}
structure = ase.io.read(structure_file)
symbols = structure.get_chemical_symbols()
distance_matrix = structure.get_all_distances(mic=True)
cutoff_radius = 8.0
neighbors_count = []
for i in range(len(structure)):
neighbors = np.sum((distance_matrix[i] > 0) & (distance_matrix[i] < cutoff_radius))
neighbors_count.append(int(neighbors))
return {
"success": True,
"num_atoms": len(structure),
"chemical_formula": structure.get_chemical_formula(),
"elements": list(set(symbols)),
"has_periodicity": any(structure.pbc),
"average_neighbors": float(np.mean(neighbors_count))
}
def run_tests():
    """Run the five smoke tests in order and print a summary.

    Skipped tests (e.g. no dataset available) are reported separately. They
    neither count as passed nor cause the run to fail.

    Returns:
        True when no test failed, False otherwise.
    """
    print("\n" + "="*60)
    print("MatDeepLearn MCP Service 测试")
    print("="*60)

    passed = 0
    failed = 0
    # Tracked separately so a skip does not inflate the pass count.
    skipped = 0

    # Test 1: environment check
    print("\n[测试 1/5] 检查环境配置...")
    result = test_check_environment()
    if print_result("check_environment", result):
        passed += 1
        if result.get("gpu_available"):
            print(f" GPU: {result.get('gpu_name')} (数量: {result.get('gpu_count')})")
        else:
            print(" GPU: 不可用 (将使用 CPU)")
        print(f" PyTorch 版本: {result.get('torch_version')}")
    else:
        failed += 1

    # Test 2: list available models
    print("\n[测试 2/5] 列出可用模型...")
    result = test_list_available_models()
    if print_result("list_available_models", result):
        passed += 1
        print(f" 可用模型数量: {result.get('total_models')}")
        for name, info in result.get("models", {}).items():
            print(f" - {name}: {info.get('name')}")
    else:
        failed += 1

    # Test 3: fetch a known model's configuration
    print("\n[测试 3/5] 获取 CGCNN_demo 模型配置...")
    result = test_get_model_config("CGCNN_demo")
    if print_result("get_model_config", result):
        passed += 1
        config = result.get("model_config", {})
        print(f" 模型类型: {config.get('model')}")
        print(f" Epochs: {config.get('epochs')}")
        print(f" Batch Size: {config.get('batch_size')}")
        print(f" Learning Rate: {config.get('lr')}")
    else:
        failed += 1

    # Test 4: dataset info (prefers the extracted test_data directory)
    print("\n[测试 4/5] 获取数据集信息...")
    test_data_path = os.path.join(project_root, "data", "test_data", "test_data")
    if os.path.exists(test_data_path):
        result = test_get_dataset_info(test_data_path)
        if print_result("get_dataset_info", result):
            passed += 1
            print(f" 数据路径: {result.get('data_path')}")
            print(f" 样本数量: {result.get('num_samples')}")
            print(f" 已处理: {result.get('has_processed_data')}")
        else:
            failed += 1
    else:
        # Fall back to the top-level data directory.
        data_path = os.path.join(project_root, "data")
        result = test_get_dataset_info(data_path)
        if result.get("success"):
            print_result("get_dataset_info (data目录)", result)
            passed += 1
        else:
            print(f"⚠️ 跳过: 测试数据目录不存在 ({test_data_path})")
            print(" 提示: 请解压 data/test_data.tar.gz 以进行完整测试")
            skipped += 1  # a skip is neither a pass nor a failure

    # Test 5: error handling for an unknown model
    print("\n[测试 5/5] 测试错误处理 (不存在的模型)...")
    missing_model = "NonExistentModel"
    result = test_get_model_config(missing_model)
    error = result.get("error")
    # Only the model-specific error proves the lookup path works. A missing
    # config file also yields success=False ("Config file not found") without
    # ever reaching the model lookup.
    if not result.get("success") and missing_model in str(error):
        print(f"✅ 错误处理正常: {error}")
        passed += 1
    elif not result.get("success"):
        print(f"❌ 错误处理失败: 返回了非预期的错误 ({error})")
        failed += 1
    else:
        print("❌ 错误处理失败: 应该返回错误")
        failed += 1

    # Summary
    print("\n" + "="*60)
    print("测试总结")
    print("="*60)
    print(f"通过: {passed}")
    print(f"失败: {failed}")
    if skipped:
        print(f"跳过: {skipped}")
    print(f"总计: {passed + failed + skipped}")
    print("="*60)

    if failed == 0:
        print("\n🎉 所有测试通过!MCP 服务已准备就绪。")
        print("\n下一步:")
        print(" 1. 本地运行: python mcp_output/start_mcp.py")
        print(" 2. HTTP 模式: MCP_TRANSPORT=http python mcp_output/start_mcp.py")
        print(" 3. 部署到 HuggingFace Space")
        return True
    else:
        print(f"\n⚠️ 有 {failed} 个测试失败,请检查错误信息。")
        return False
def run_structure_analysis_test():
    """Analyse one sample structure from the bundled test data, if present.

    Picks the alphabetically first ``*.json`` file other than
    ``atom_dict.json`` in ``data/test_data/test_data``, runs
    ``test_analyze_structure`` on it and prints the summary. Prints a notice
    and returns when the directory or a suitable file is missing.
    """
    print("\n" + "="*60)
    print("额外测试: 结构分析")
    print("="*60)

    test_data_path = os.path.join(project_root, "data", "test_data", "test_data")
    if not os.path.exists(test_data_path):
        print("⚠️ 测试数据不可用,跳过结构分析测试")
        return

    # sorted() makes the choice deterministic; os.listdir order is arbitrary.
    # atom_dict.json is excluded because it is not treated as a structure here.
    structure_name = next(
        (f for f in sorted(os.listdir(test_data_path))
         if f.endswith('.json') and f != 'atom_dict.json'),
        None,
    )
    if structure_name is None:
        print("⚠️ 未找到结构文件 (*.json),跳过结构分析测试")
        return

    structure_file = os.path.join(test_data_path, structure_name)
    print(f"\n分析结构文件: {structure_name}")
    result = test_analyze_structure(structure_file)
    if result.get("success"):
        print(f" 化学式: {result.get('chemical_formula')}")
        print(f" 原子数: {result.get('num_atoms')}")
        print(f" 元素: {result.get('elements')}")
        print(f" 周期性: {result.get('has_periodicity')}")
        print(f" 平均邻居数: {result.get('average_neighbors'):.2f}")
    else:
        print(f" 错误: {result.get('error')}")
if __name__ == "__main__":
    all_passed = run_tests()
    # The structure analysis is best-effort. It only runs after the core tests
    # pass, and any exception it raises is reported but does not affect the
    # exit code.
    if all_passed:
        try:
            run_structure_analysis_test()
        except Exception as exc:
            print(f"\n结构分析测试出错: {exc}")
    sys.exit(int(not all_passed))