Spaces:
Sleeping
Sleeping
| """ | |
| MatDeepLearn MCP Service Test Script | |
| 测试 MCP 服务的各个功能是否正常工作 | |
| 直接测试底层函数,不通过 MCP 装饰器 | |
| """ | |
| import sys | |
| import os | |
| import json | |
# Make both the project root (parent of this script's directory) and the
# sibling ``mcp_plugin`` directory importable. They are inserted at the front
# of sys.path in this order, so ``mcp_plugin`` ends up searched first.
_script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(_script_dir)
mcp_plugin_dir = os.path.join(_script_dir, "mcp_plugin")
for _import_dir in (project_root, mcp_plugin_dir):
    if _import_dir not in sys.path:
        sys.path.insert(0, _import_dir)
def print_result(test_name: str, result: dict):
    """Print a framed report for a single test case.

    Args:
        test_name: Label shown in the report header.
        result: Dict returned by one of the test helpers. Its ``success``
            entry selects the PASS/FAIL banner, and the whole dict is dumped
            as indented JSON (values JSON cannot encode are rendered via ``str``).

    Returns:
        ``result["success"]``, or ``False`` when the key is absent, so the
        caller can use the return value directly as the verdict.
    """
    passed = result.get("success", False)
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"测试: {test_name}")
    print(f"状态: {'✅ PASS' if passed else '❌ FAIL'}")
    print("结果: " + json.dumps(result, indent=2, ensure_ascii=False, default=str))
    print(rule)
    return passed
| # ============== 直接定义测试函数(复制核心逻辑)============== | |
def test_check_environment() -> dict:
    """Report which runtime dependencies of the MCP service are usable.

    Each dependency is imported inside its own guard. Besides a plain
    ``ImportError`` (package absent), a broken install can fail at import
    time with other exceptions (for example an ``OSError`` from a native
    extension that does not load); those are recorded in the result instead
    of aborting the whole test run.

    Returns:
        dict containing:
          * ``torch_available`` / ``torch_geometric_available`` /
            ``matdeeplearn_available`` flags;
          * ``torch_version`` / ``torch_geometric_version`` (``"未安装"`` when
            the import failed) plus ``*_error`` keys holding the import error text;
          * ``gpu_available``, ``gpu_count``, ``gpu_name`` from ``torch.cuda``
            (``gpu_error`` is set and the CPU defaults kept if querying CUDA raised);
          * ``available_models``: the static list of demo model names;
          * ``success``: True iff PyTorch imported. ``error`` is set otherwise,
            and ``warning`` is set when PyTorch works but torch_geometric does not.
    """
    result = {
        "success": False,  # decided at the end from torch_available
        "torch_available": False,
        "torch_geometric_available": False,
        "matdeeplearn_available": False,
        "gpu_available": False,
        "gpu_count": 0,
        "gpu_name": "N/A",
        "available_models": [
            "CGCNN_demo", "MPNN_demo", "SchNet_demo",
            "MEGNet_demo", "GCN_demo", "SOAP_demo", "SM_demo"
        ]
    }

    # PyTorch
    try:
        import torch
    except Exception as exc:  # missing or broken install: report, don't crash
        result["torch_version"] = "未安装"
        result["torch_error"] = str(exc)
    else:
        result["torch_available"] = True
        result["torch_version"] = torch.__version__
        try:
            gpu_available = torch.cuda.is_available()
            gpu_count = torch.cuda.device_count() if gpu_available else 0
            gpu_name = torch.cuda.get_device_name(0) if gpu_available else "N/A"
        except RuntimeError as exc:  # e.g. a CUDA driver problem; keep CPU defaults
            result["gpu_error"] = str(exc)
        else:
            result["gpu_available"] = gpu_available
            result["gpu_count"] = gpu_count
            result["gpu_name"] = gpu_name

    # PyTorch Geometric
    try:
        import torch_geometric
    except Exception as exc:  # missing or broken install: report, don't crash
        result["torch_geometric_version"] = "未安装"
        result["torch_geometric_error"] = str(exc)
    else:
        result["torch_geometric_available"] = True
        result["torch_geometric_version"] = torch_geometric.__version__

    # MatDeepLearn: imported only to prove its sub-packages load; names are unused.
    try:
        from matdeeplearn import models, process, training  # noqa: F401
    except Exception as exc:  # missing or broken install: report, don't crash
        result["matdeeplearn_error"] = str(exc)
    else:
        result["matdeeplearn_available"] = True

    # PyTorch is the only hard requirement; torch_geometric only degrades features.
    result["success"] = result["torch_available"]
    if not result["success"]:
        result["error"] = "PyTorch 未安装"
    elif not result["torch_geometric_available"]:
        result["warning"] = "torch_geometric 未安装,部分功能不可用"
    return result
def test_list_available_models() -> dict:
    """Return the static catalogue of demo models exposed by the MCP service.

    Returns:
        ``{"success": True, "models": {...}, "total_models": n}``. ``models``
        maps each demo key to a ``{"name": ..., "description": ...}`` dict,
        in a fixed order.
    """
    catalogue = (
        ("CGCNN_demo", "Crystal Graph Convolutional Neural Network",
         "A GNN for predicting material properties using crystal graphs."),
        ("MPNN_demo", "Message Passing Neural Network",
         "General message passing framework for molecular graphs."),
        ("SchNet_demo", "SchNet",
         "Continuous-filter convolutional neural network."),
        ("MEGNet_demo", "MatErials Graph Network",
         "Graph network with global state for materials."),
        ("GCN_demo", "Graph Convolutional Network",
         "Standard graph convolutional network."),
        ("SOAP_demo", "Smooth Overlap of Atomic Positions",
         "Descriptor-based method using SOAP features."),
        ("SM_demo", "Sine Matrix",
         "Descriptor-based method using Sine/Coulomb matrix."),
    )
    models_info = {
        key: {"name": full_name, "description": summary}
        for key, full_name, summary in catalogue
    }
    return {
        "success": True,
        "models": models_info,
        "total_models": len(models_info),
    }
def test_get_model_config(model_name: str) -> dict:
    """Look up one model's settings in the project's ``config.yml``.

    Args:
        model_name: Key under the file's ``Models`` section, e.g. ``"CGCNN_demo"``.

    Returns:
        ``{"success": True, "model_name": ..., "model_config": ...}`` on a hit.
        Otherwise ``{"success": False, "error": ...}``, when the file is
        missing or the model is not listed (an empty file or an empty/non-mapping
        ``Models`` section counts as "not listed").
    """
    import yaml

    config_path = os.path.join(project_root, "config.yml")
    if not os.path.exists(config_path):
        return {"success": False, "error": "Config file not found"}
    with open(config_path, "r", encoding="utf-8") as f:
        # The config is plain data, so safe_load is sufficient and avoids
        # constructing arbitrary Python objects from YAML tags.
        config = yaml.safe_load(f)

    # An empty file loads as None and an empty "Models:" key as None; treat
    # both (and any non-mapping value) as "model not found" instead of crashing.
    models = config.get("Models") if isinstance(config, dict) else None
    if not isinstance(models, dict) or model_name not in models:
        return {"success": False, "error": f"Model '{model_name}' not found"}
    return {
        "success": True,
        "model_name": model_name,
        "model_config": models[model_name]
    }
| def test_get_dataset_info(data_path: str) -> dict: | |
| """获取数据集信息""" | |
| import csv | |
| if not os.path.exists(data_path): | |
| return {"success": False, "error": f"Data path not found: {data_path}"} | |
| extensions = {} | |
| for f in os.listdir(data_path): | |
| ext = os.path.splitext(f)[1].lower() | |
| extensions[ext] = extensions.get(ext, 0) + 1 | |
| has_targets = os.path.exists(os.path.join(data_path, "targets.csv")) | |
| has_processed = os.path.exists(os.path.join(data_path, "processed")) | |
| num_samples = 0 | |
| if has_targets: | |
| with open(os.path.join(data_path, "targets.csv")) as f: | |
| num_samples = sum(1 for _ in csv.reader(f)) | |
| return { | |
| "success": True, | |
| "data_path": data_path, | |
| "file_extensions": extensions, | |
| "has_targets_csv": has_targets, | |
| "has_processed_data": has_processed, | |
| "num_samples": num_samples | |
| } | |
def test_analyze_structure(structure_file: str, cutoff_radius: float = 8.0) -> dict:
    """Summarise an atomic structure file readable by ASE.

    Args:
        structure_file: Path to a file in any format ``ase.io.read`` accepts.
        cutoff_radius: Distance (in ASE's length unit, Å) below which another
            atom counts as a neighbour. Defaults to 8.0.

    Returns:
        ``{"success": True, "num_atoms", "chemical_formula", "elements"
        (sorted unique symbols), "has_periodicity", "average_neighbors"}``,
        or ``{"success": False, "error": ...}`` when the file does not exist.
        ``average_neighbors`` is 0.0 for a structure with no atoms.
    """
    if not os.path.exists(structure_file):
        return {"success": False, "error": f"File not found: {structure_file}"}

    # Heavy imports are deferred until the file is known to exist.
    import numpy as np
    import ase.io

    structure = ase.io.read(structure_file)
    symbols = structure.get_chemical_symbols()

    if len(structure):
        # mic=True applies the minimum-image convention, so each other atom is
        # counted at most once (at its nearest image); further periodic images
        # inside the cutoff are not counted.
        distances = structure.get_all_distances(mic=True)
        # "> 0" drops the zero self-distance on the diagonal (and any
        # coincident atoms).
        within_cutoff = (distances > 0) & (distances < cutoff_radius)
        average_neighbors = float(np.mean(within_cutoff.sum(axis=1)))
    else:
        average_neighbors = 0.0  # avoid np.mean([]) -> nan + RuntimeWarning

    return {
        "success": True,
        "num_atoms": len(structure),
        "chemical_formula": structure.get_chemical_formula(),
        "elements": sorted(set(symbols)),  # sorted for deterministic output
        "has_periodicity": any(structure.pbc),
        "average_neighbors": average_neighbors
    }
def _call_safely(func, *args) -> dict:
    """Call one test helper, turning an unexpected exception into a failed result.

    Args:
        func: One of the ``test_*`` helpers in this script.
        *args: Positional arguments forwarded to ``func``.

    Returns:
        Whatever ``func`` returns. If it raised, returns
        ``{"success": False, "error": ..., "exception": <class name>}``; the
        ``exception`` key lets callers tell a crash apart from a deliberate
        ``success: False`` answer.
    """
    try:
        return func(*args)
    except Exception as exc:  # a crashing check is a failed check, not a reason to abort
        return {
            "success": False,
            "error": f"{type(exc).__name__}: {exc}",
            "exception": type(exc).__name__,
        }


def run_tests():
    """Run the five core checks, print a summary and report overall success.

    Every helper is invoked through ``_call_safely``. An exception (for
    example ``yaml`` not being installed, or a malformed ``config.yml``) is
    therefore counted as a failed test, and the summary is still printed.

    Returns:
        True when no test failed. A missing test-data directory is treated as
        a skipped test and tallied as passed.
    """
    print("\n" + "="*60)
    print("MatDeepLearn MCP Service 测试")
    print("="*60)
    passed = 0
    failed = 0

    # Test 1: dependency / environment probe
    print("\n[测试 1/5] 检查环境配置...")
    result = _call_safely(test_check_environment)
    if print_result("check_environment", result):
        passed += 1
        if result.get("gpu_available"):
            print(f" GPU: {result.get('gpu_name')} (数量: {result.get('gpu_count')})")
        else:
            print(" GPU: 不可用 (将使用 CPU)")
        print(f" PyTorch 版本: {result.get('torch_version')}")
    else:
        failed += 1

    # Test 2: static model catalogue
    print("\n[测试 2/5] 列出可用模型...")
    result = _call_safely(test_list_available_models)
    if print_result("list_available_models", result):
        passed += 1
        print(f" 可用模型数量: {result.get('total_models')}")
        for name, info in result.get("models", {}).items():
            print(f" - {name}: {info.get('name')}")
    else:
        failed += 1

    # Test 3: read the CGCNN_demo section of config.yml
    print("\n[测试 3/5] 获取 CGCNN_demo 模型配置...")
    result = _call_safely(test_get_model_config, "CGCNN_demo")
    if print_result("get_model_config", result):
        passed += 1
        config = result.get("model_config") or {}  # a bare "CGCNN_demo:" key loads as None
        print(f" 模型类型: {config.get('model')}")
        print(f" Epochs: {config.get('epochs')}")
        print(f" Batch Size: {config.get('batch_size')}")
        print(f" Learning Rate: {config.get('lr')}")
    else:
        failed += 1

    # Test 4: dataset summary (bundled test data if unpacked, else data/)
    print("\n[测试 4/5] 获取数据集信息...")
    test_data_path = os.path.join(project_root, "data", "test_data", "test_data")
    if os.path.exists(test_data_path):
        result = _call_safely(test_get_dataset_info, test_data_path)
        if print_result("get_dataset_info", result):
            passed += 1
            print(f" 数据路径: {result.get('data_path')}")
            print(f" 样本数量: {result.get('num_samples')}")
            print(f" 已处理: {result.get('has_processed_data')}")
        else:
            failed += 1
    else:
        # Fall back to the top-level data directory.
        data_path = os.path.join(project_root, "data")
        result = _call_safely(test_get_dataset_info, data_path)
        if result.get("success"):
            print_result("get_dataset_info (data目录)", result)
            passed += 1
        elif "exception" in result:
            # The helper crashed: a real failure, not a missing-data skip.
            print_result("get_dataset_info (data目录)", result)
            failed += 1
        else:
            print(f"⚠️ 跳过: 测试数据目录不存在 ({test_data_path})")
            print(" 提示: 请解压 data/test_data.tar.gz 以进行完整测试")
            passed += 1  # a skip is not counted as a failure

    # Test 5: an unknown model must yield an error result (not success, not a crash)
    print("\n[测试 5/5] 测试错误处理 (不存在的模型)...")
    result = _call_safely(test_get_model_config, "NonExistentModel")
    if "exception" in result:
        print_result("get_model_config (NonExistentModel)", result)
        failed += 1
    elif not result.get("success"):
        print(f"✅ 错误处理正常: {result.get('error')}")
        passed += 1
    else:
        print("❌ 错误处理失败: 应该返回错误")
        failed += 1

    # Summary
    print("\n" + "="*60)
    print("测试总结")
    print("="*60)
    print(f"通过: {passed}")
    print(f"失败: {failed}")
    print(f"总计: {passed + failed}")
    print("="*60)
    if failed == 0:
        print("\n🎉 所有测试通过!MCP 服务已准备就绪。")
        print("\n下一步:")
        print(" 1. 本地运行: python mcp_output/start_mcp.py")
        print(" 2. HTTP 模式: MCP_TRANSPORT=http python mcp_output/start_mcp.py")
        print(" 3. 部署到 HuggingFace Space")
        return True
    else:
        print(f"\n⚠️ 有 {failed} 个测试失败,请检查错误信息。")
        return False
def run_structure_analysis_test():
    """Analyse one sample structure from the bundled test data, if available.

    Picks the alphabetically first ``*.json`` file (excluding
    ``atom_dict.json``) in ``data/test_data/test_data`` so repeated runs
    analyse the same file. It then prints the summary returned by
    ``test_analyze_structure``. Prints a notice and returns early when the
    directory or a suitable file is missing.
    """
    print("\n" + "="*60)
    print("额外测试: 结构分析")
    print("="*60)

    test_data_path = os.path.join(project_root, "data", "test_data", "test_data")
    if not os.path.isdir(test_data_path):
        print("⚠️ 测试数据不可用,跳过结构分析测试")
        return

    # os.listdir order is arbitrary; sort so "the first file" is reproducible.
    candidates = sorted(
        name for name in os.listdir(test_data_path)
        if name.endswith('.json') and name != 'atom_dict.json'
    )
    if not candidates:
        print("⚠️ 未找到可分析的结构文件,跳过结构分析测试")
        return

    chosen = candidates[0]
    structure_file = os.path.join(test_data_path, chosen)
    print(f"\n分析结构文件: {chosen}")
    result = test_analyze_structure(structure_file)
    if result.get("success"):
        print(f" 化学式: {result.get('chemical_formula')}")
        print(f" 原子数: {result.get('num_atoms')}")
        print(f" 元素: {result.get('elements')}")
        print(f" 周期性: {result.get('has_periodicity')}")
        print(f" 平均邻居数: {result.get('average_neighbors'):.2f}")
    else:
        print(f" 错误: {result.get('error')}")
if __name__ == "__main__":
    success = run_tests()
    # The structure-analysis check is optional. Run it only after the core
    # suite passed, and never let it change the exit status.
    if success:
        try:
            run_structure_analysis_test()
        except Exception as exc:
            print(f"\n结构分析测试出错: {exc}")
    exit_code = 0 if success else 1
    sys.exit(exit_code)