| """ | |
| MatDeepLearn MCP Service | |
| A Model Context Protocol service for materials property prediction using Graph Neural Networks. | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import tempfile | |
| import yaml | |
| import numpy as np | |
| import base64 | |
| import hashlib | |
| import shutil | |
| import uuid | |
| import zipfile | |
| import tarfile | |
| import io | |
| from datetime import datetime | |
| from typing import Optional, List, Dict, Any | |
| from pathlib import Path | |
| # Add MatDeepLearn to path | |
| project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| if project_root not in sys.path: | |
| sys.path.insert(0, project_root) | |
| from fastmcp import FastMCP | |
| # Import MatDeepLearn modules | |
| try: | |
| import torch | |
| from matdeeplearn import models, process, training | |
| from matdeeplearn.models.utils import model_summary | |
| MATDEEPLEARN_AVAILABLE = True | |
| except ImportError as e: | |
| MATDEEPLEARN_AVAILABLE = False | |
| IMPORT_ERROR = str(e) | |
| mcp = FastMCP("matdeeplearn_service") | |
| # ============================================================================ | |
| # Global storage management - tracks uploaded datasets and trained models | |
| # ============================================================================ | |
| # Server-side storage directories | |
| STORAGE_BASE = os.path.join(project_root, "mcp_storage") | |
| DATASETS_DIR = os.path.join(STORAGE_BASE, "datasets") | |
| MODELS_DIR = os.path.join(STORAGE_BASE, "models") | |
| SESSIONS_DIR = os.path.join(STORAGE_BASE, "sessions") | |
| # Ensure the storage directories exist | |
| for dir_path in [STORAGE_BASE, DATASETS_DIR, MODELS_DIR, SESSIONS_DIR]: | |
| os.makedirs(dir_path, exist_ok=True) | |
| # In-memory session registry (session_id -> session_info) | |
| _sessions: Dict[str, Dict] = {} | |
| def _get_session_path(session_id: str) -> str: | |
| """获取会话目录路径""" | |
| return os.path.join(SESSIONS_DIR, session_id) | |
| def _generate_session_id() -> str: | |
| """生成唯一会话ID""" | |
| return f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}" | |
| def _generate_dataset_id(name: str) -> str: | |
| """生成数据集ID""" | |
| return f"dataset_{name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}" | |
| def _generate_model_id(model_name: str) -> str: | |
| """生成模型ID""" | |
| return f"model_{model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}" | |
| def _safe_join(base: str, *paths: str) -> str: | |
| """Join paths and ensure the result stays inside base directory.""" | |
| base_path = Path(base).resolve() | |
| target_path = base_path.joinpath(*paths).resolve() | |
| # Compare resolved paths; a plain startswith() check can be fooled by sibling dirs sharing a prefix | |
| if base_path != target_path and base_path not in target_path.parents: | |
| raise ValueError("Attempted to write outside of the allowed directory") | |
| return str(target_path) | |
| def _normalize_filename(filename: str) -> str: | |
| """Return sanitized filename without directory components.""" | |
| clean_name = os.path.basename(filename) | |
| if not clean_name: | |
| raise ValueError("Filename must not be empty") | |
| return clean_name | |
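| # Hedged usage sketch of the path helpers (illustrative paths only, not executed at import): | |
| # _normalize_filename("../../etc/passwd") -> "passwd" | |
| # _safe_join("/srv/mcp_storage/sessions/s1/data", "sub/a.cif") -> resolved path inside .../data | |
| # _safe_join("/srv/mcp_storage/sessions/s1/data", "../secrets.txt") -> raises ValueError | |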
| def _clear_directory_contents(directory: str) -> None: | |
| """Remove all files and folders inside the provided directory.""" | |
| if not os.path.isdir(directory): | |
| return | |
| for entry in os.listdir(directory): | |
| entry_path = os.path.join(directory, entry) | |
| if os.path.isdir(entry_path): | |
| shutil.rmtree(entry_path) | |
| else: | |
| os.remove(entry_path) | |
| def _copy_tree(src: str, dst: str, overwrite: bool = False) -> Dict[str, List[str]]: | |
| """Copy a directory tree with overwrite and traversal protection.""" | |
| results = {"created": [], "overwritten": [], "skipped": []} | |
| src_path = Path(src) | |
| for root, _, files in os.walk(src_path): | |
| rel_root = os.path.relpath(root, src_path) | |
| rel_root = "" if rel_root == "." else rel_root | |
| for file_name in files: | |
| if file_name.startswith(".__MACOSX"): | |
| continue | |
| rel_path = os.path.normpath(os.path.join(rel_root, file_name)) | |
| dest_path = _safe_join(dst, rel_path) | |
| os.makedirs(os.path.dirname(dest_path), exist_ok=True) | |
| src_file = os.path.join(root, file_name) | |
| if os.path.exists(dest_path): | |
| if overwrite: | |
| shutil.copy2(src_file, dest_path) | |
| results["overwritten"].append(rel_path.replace("\\", "/")) | |
| else: | |
| results["skipped"].append(rel_path.replace("\\", "/")) | |
| else: | |
| shutil.copy2(src_file, dest_path) | |
| results["created"].append(rel_path.replace("\\", "/")) | |
| return results | |
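| # Illustrative shape of the report returned by _copy_tree (example file names only): | |
| # {"created": ["targets.csv", "NaCl.cif"], "overwritten": [], "skipped": ["ZnO.cif"]} | |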
| def _resolve_dataset_root(extracted_dir: str) -> str: | |
| """Select the most probable dataset root inside an extracted archive.""" | |
| entries = [p for p in Path(extracted_dir).iterdir() if not p.name.startswith("__MACOSX")] | |
| if len(entries) == 1 and entries[0].is_dir(): | |
| return str(entries[0]) | |
| return extracted_dir | |
| def _update_session_uploaded_files(session_path: str) -> None: | |
| """Rescan session data directory and persist uploaded file list.""" | |
| info_file = os.path.join(session_path, "session_info.json") | |
| if not os.path.exists(info_file): | |
| return | |
| data_dir = os.path.join(session_path, "data") | |
| uploaded = [] | |
| if os.path.exists(data_dir): | |
| for root, _, files in os.walk(data_dir): | |
| for name in files: | |
| rel_path = os.path.relpath(os.path.join(root, name), data_dir) | |
| uploaded.append(rel_path.replace("\\", "/")) | |
| with open(info_file, 'r', encoding='utf-8') as f: | |
| session_info = json.load(f) | |
| session_info["uploaded_files"] = sorted(uploaded) | |
| with open(info_file, 'w', encoding='utf-8') as f: | |
| json.dump(session_info, f, indent=2) | |
| def _record_session_data_source(session_path: str, source_info: Dict[str, Any]) -> None: | |
| """Append dataset source metadata to the session record.""" | |
| info_file = os.path.join(session_path, "session_info.json") | |
| if not os.path.exists(info_file): | |
| return | |
| with open(info_file, 'r', encoding='utf-8') as f: | |
| session_info = json.load(f) | |
| data_sources = session_info.setdefault("data_sources", []) | |
| data_sources.append(source_info) | |
| session_info["data_sources"] = data_sources[-10:] | |
| with open(info_file, 'w', encoding='utf-8') as f: | |
| json.dump(session_info, f, indent=2) | |
| def _summarize_dataset_directory(data_path: str) -> Dict[str, Any]: | |
| """Collect lightweight statistics about files inside a dataset directory.""" | |
| summary = { | |
| "total_files": 0, | |
| "targets_csv": False, | |
| "structure_extensions": {} | |
| } | |
| if not os.path.exists(data_path): | |
| return summary | |
| for root, _, files in os.walk(data_path): | |
| for name in files: | |
| summary["total_files"] += 1 | |
| if name.lower() == "targets.csv": | |
| summary["targets_csv"] = True | |
| else: | |
| ext = os.path.splitext(name)[1].lower() or "<no_ext>" | |
| summary["structure_extensions"][ext] = summary["structure_extensions"].get(ext, 0) + 1 | |
| return summary | |
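| # Illustrative summary for a folder holding 100 CIF structures plus targets.csv: | |
| # {"total_files": 101, "targets_csv": True, "structure_extensions": {".cif": 100}} | |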
| def _safe_extract_zip(archive: zipfile.ZipFile, destination: str) -> None: | |
| """Extract zip members while preventing path traversal.""" | |
| for member in archive.infolist(): | |
| name = member.filename | |
| if not name: | |
| continue | |
| target_path = _safe_join(destination, name) | |
| if member.is_dir() or name.endswith('/'): | |
| os.makedirs(target_path, exist_ok=True) | |
| continue | |
| os.makedirs(os.path.dirname(target_path), exist_ok=True) | |
| with archive.open(member, 'r') as src, open(target_path, 'wb') as dst: | |
| shutil.copyfileobj(src, dst) | |
| def _safe_extract_tar(archive: tarfile.TarFile, destination: str) -> None: | |
| """Extract tar members while preventing unsafe writes.""" | |
| for member in archive.getmembers(): | |
| name = member.name | |
| if not name: | |
| continue | |
| if member.islnk() or member.issym(): | |
| continue | |
| target_path = _safe_join(destination, name) | |
| if member.isdir(): | |
| os.makedirs(target_path, exist_ok=True) | |
| continue | |
| if member.isfile(): | |
| extracted = archive.extractfile(member) | |
| if extracted is None: | |
| continue | |
| os.makedirs(os.path.dirname(target_path), exist_ok=True) | |
| with extracted as src, open(target_path, 'wb') as dst: | |
| shutil.copyfileobj(src, dst) | |
| def _build_local_training_script( | |
| data_path_literal: str, | |
| model_name: str, | |
| target_index: int, | |
| epochs: int, | |
| batch_size: int, | |
| learning_rate: float, | |
| train_ratio: float, | |
| val_ratio: float, | |
| test_ratio: float, | |
| save_model: bool, | |
| model_output_literal: str, | |
| reprocess: bool = False, | |
| config_path_literal: str = "config.yml" | |
| ) -> str: | |
| """Generate a reusable Python script string for local MatDeepLearn training.""" | |
| job_name = f"local_train_{model_name.lower()}" | |
| reprocess_str = "True" if reprocess else "False" | |
| save_model_str = "True" if save_model else "False" | |
| script = """import yaml | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| from matdeeplearn import training | |
| DATA_PATH = Path(r"{data_path}") | |
| CONFIG_PATH = Path(r"{config_path}") | |
| MODEL_OUTPUT = Path(r"{model_output}") | |
| def main() -> None: | |
| if not DATA_PATH.exists(): | |
| raise FileNotFoundError("Data path not found: {{}}".format(DATA_PATH)) | |
| if not CONFIG_PATH.exists(): | |
| raise FileNotFoundError("Config file not found: {{}}".format(CONFIG_PATH)) | |
| with CONFIG_PATH.open("r", encoding="utf-8") as f: | |
| config = yaml.safe_load(f) | |
| job_config = {{ | |
| "job_name": "{job_name}", | |
| "reprocess": "{reprocess}", | |
| "model": "{model_name}", | |
| "load_model": "False", | |
| "save_model": "{save_model}", | |
| "model_path": str(MODEL_OUTPUT.resolve()), | |
| "write_output": "True", | |
| "parallel": "False", | |
| "seed": int(np.random.randint(1, 1_000_000)), | |
| }} | |
| training_config = config.get("Training", {{}}).copy() | |
| training_config.update({{ | |
| "target_index": {target_index}, | |
| "train_ratio": {train_ratio}, | |
| "val_ratio": {val_ratio}, | |
| "test_ratio": {test_ratio}, | |
| "verbosity": 5, | |
| }}) | |
| model_config = config["Models"]["{model_name}"].copy() | |
| model_config.update({{ | |
| "epochs": {epochs}, | |
| "batch_size": {batch_size}, | |
| "lr": {learning_rate}, | |
| }}) | |
| data_path = str(DATA_PATH.resolve()) | |
| if MODEL_OUTPUT.parent != Path("."): | |
| MODEL_OUTPUT.parent.mkdir(parents=True, exist_ok=True) | |
| world_size = torch.cuda.device_count() | |
| rank = "cuda" if torch.cuda.is_available() else "cpu" | |
| errors = training.train_regular( | |
| rank, | |
| world_size, | |
| data_path, | |
| job_config, | |
| training_config, | |
| model_config, | |
| ) | |
| print("Training complete. Train/Val/Test errors:", errors) | |
| if __name__ == "__main__": | |
| main() | |
| """.format( | |
| data_path=data_path_literal, | |
| config_path=config_path_literal, | |
| model_output=model_output_literal, | |
| job_name=job_name, | |
| reprocess=reprocess_str, | |
| model_name=model_name, | |
| save_model=save_model_str, | |
| target_index=target_index, | |
| train_ratio=train_ratio, | |
| val_ratio=val_ratio, | |
| test_ratio=test_ratio, | |
| epochs=epochs, | |
| batch_size=batch_size, | |
| learning_rate=learning_rate, | |
| ) | |
| return script | |
| def _infer_structure_format(structure_files: List[str]) -> Optional[str]: | |
| """Guess the structure file extension for ASE parsing.""" | |
| normalized: Dict[str, int] = {} | |
| for file_name in structure_files: | |
| root, ext = os.path.splitext(file_name) | |
| if ext: | |
| ext = ext.lower().lstrip('.') | |
| else: | |
| upper_name = os.path.basename(file_name).upper() | |
| if upper_name in {"POSCAR", "CONTCAR"}: | |
| ext = "vasp" | |
| else: | |
| ext = "" | |
| if not ext: | |
| continue | |
| normalized[ext] = normalized.get(ext, 0) + 1 | |
| if not normalized: | |
| return None | |
| if len(normalized) == 1: | |
| return next(iter(normalized)) | |
| # Prefer common solid-state formats | |
| priority = ["cif", "vasp", "poscar", "xyz", "json"] | |
| sorted_items = sorted(normalized.items(), key=lambda item: (-item[1], priority.index(item[0]) if item[0] in priority else 99)) | |
| top_ext, top_count = sorted_items[0] | |
| if len(structure_files) == top_count: | |
| return top_ext | |
| return None | |
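| # Hedged examples of the inference behaviour (file names are illustrative): | |
| # _infer_structure_format(["a.cif", "b.cif"]) -> "cif" | |
| # _infer_structure_format(["POSCAR", "CONTCAR"]) -> "vasp" | |
| # _infer_structure_format(["a.cif", "b.xyz"]) -> None (ambiguous mix) | |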
| # ============================================================================ | |
| # Session management tools | |
| # ============================================================================ | |
| def create_session(session_name: Optional[str] = None) -> dict: | |
| """ | |
| Create a new working session. Use this before uploading data. | |
| Parameters: | |
| session_name (str, optional): A friendly name for this session. | |
| Returns: | |
| dict: Contains session_id and session info. | |
| Example: | |
| create_session(session_name="my_material_project") | |
| """ | |
| try: | |
| session_id = _generate_session_id() | |
| session_path = _get_session_path(session_id) | |
| os.makedirs(session_path, exist_ok=True) | |
| os.makedirs(os.path.join(session_path, "data"), exist_ok=True) | |
| os.makedirs(os.path.join(session_path, "models"), exist_ok=True) | |
| os.makedirs(os.path.join(session_path, "outputs"), exist_ok=True) | |
| # Initialize an id index file under data to avoid missing-file errors | |
| id_index_path = os.path.join(session_path, "data", "id.json") | |
| if not os.path.exists(id_index_path): | |
| try: | |
| with open(id_index_path, 'w', encoding='utf-8') as _idf: | |
| json.dump({}, _idf) | |
| except Exception: | |
| pass | |
| session_info = { | |
| "session_id": session_id, | |
| "session_name": session_name or session_id, | |
| "created_at": datetime.now().isoformat(), | |
| "data_path": os.path.join(session_path, "data"), | |
| "models_path": os.path.join(session_path, "models"), | |
| "outputs_path": os.path.join(session_path, "outputs"), | |
| "uploaded_files": [], | |
| "trained_models": [], | |
| "data_sources": [], | |
| "status": "active" | |
| } | |
| _sessions[session_id] = session_info | |
| # Save session info to disk | |
| with open(os.path.join(session_path, "session_info.json"), 'w') as f: | |
| json.dump(session_info, f, indent=2) | |
| return { | |
| "success": True, | |
| "session_id": session_id, | |
| "session_name": session_info["session_name"], | |
| "message": "Session created successfully. Use this session_id for uploading data and training.", | |
| "next_steps": [ | |
| "1. Upload structure files using upload_structure_files", | |
| "2. Upload targets.csv using upload_targets", | |
| "3. Process data using process_session_data", | |
| "4. Train model using train_session_model" | |
| ] | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def get_session_info(session_id: str) -> dict: | |
| """ | |
| Get information about an existing session. | |
| Parameters: | |
| session_id (str): The session ID returned from create_session. | |
| Returns: | |
| dict: Session information including uploaded files and trained models. | |
| """ | |
| try: | |
| session_path = _get_session_path(session_id) | |
| info_file = os.path.join(session_path, "session_info.json") | |
| if not os.path.exists(info_file): | |
| return {"success": False, "error": f"Session not found: {session_id}"} | |
| with open(info_file, 'r') as f: | |
| session_info = json.load(f) | |
| # Update with current file counts | |
| data_path = session_info["data_path"] | |
| if os.path.exists(data_path): | |
| files = os.listdir(data_path) | |
| session_info["current_files"] = files | |
| session_info["file_count"] = len(files) | |
| session_info["has_targets"] = "targets.csv" in files | |
| session_info["dataset_summary"] = _summarize_dataset_directory(data_path) | |
| return {"success": True, **session_info} | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def list_sessions() -> dict: | |
| """ | |
| List all available sessions on the server. | |
| Returns: | |
| dict: List of sessions with their basic info. | |
| """ | |
| try: | |
| sessions = [] | |
| if os.path.exists(SESSIONS_DIR): | |
| for session_id in os.listdir(SESSIONS_DIR): | |
| info_file = os.path.join(SESSIONS_DIR, session_id, "session_info.json") | |
| if os.path.exists(info_file): | |
| with open(info_file, 'r') as f: | |
| info = json.load(f) | |
| sessions.append({ | |
| "session_id": session_id, | |
| "session_name": info.get("session_name", session_id), | |
| "created_at": info.get("created_at"), | |
| "status": info.get("status", "unknown") | |
| }) | |
| return { | |
| "success": True, | |
| "sessions": sessions, | |
| "total_sessions": len(sessions) | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def delete_session(session_id: str, confirm: bool = False) -> dict: | |
| """ | |
| Delete a session and all associated data. | |
| Parameters: | |
| session_id (str): The session ID to delete. | |
| confirm (bool): Must be True to confirm deletion. | |
| Returns: | |
| dict: Deletion status. | |
| """ | |
| try: | |
| if not confirm: | |
| return { | |
| "success": False, | |
| "error": "Please set confirm=True to delete the session. This action cannot be undone." | |
| } | |
| session_path = _get_session_path(session_id) | |
| if not os.path.exists(session_path): | |
| return {"success": False, "error": f"Session not found: {session_id}"} | |
| shutil.rmtree(session_path) | |
| if session_id in _sessions: | |
| del _sessions[session_id] | |
| return { | |
| "success": True, | |
| "message": f"Session {session_id} deleted successfully." | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| # ============================================================================ | |
| # Data upload tools | |
| # ============================================================================ | |
| def upload_structure_file( | |
| session_id: str, | |
| filename: str, | |
| file_content: str, | |
| file_format: Optional[str] = None | |
| ) -> dict: | |
| """ | |
| Upload a single structure file to a session. | |
| Parameters: | |
| session_id (str): The session ID. | |
| filename (str): Name for the file (e.g., "structure1.cif"). | |
| file_content (str): The complete file content as a string. | |
| file_format (str, optional): File format hint (auto-detected from filename if not provided). | |
| Returns: | |
| dict: Upload status and file info. | |
| Example: | |
| upload_structure_file( | |
| session_id="session_xxx", | |
| filename="NaCl.cif", | |
| file_content="data_NaCl\\n_cell_length_a 5.64..." | |
| ) | |
| """ | |
| try: | |
| session_path = _get_session_path(session_id) | |
| if not os.path.exists(session_path): | |
| return {"success": False, "error": f"Session not found: {session_id}"} | |
| data_path = os.path.join(session_path, "data") | |
| os.makedirs(data_path, exist_ok=True) | |
| filename = _normalize_filename(filename) | |
| file_path = _safe_join(data_path, filename) | |
| with open(file_path, 'w', encoding='utf-8') as f: | |
| f.write(file_content) | |
| # Validate structure if possible | |
| validation = {"valid": True} | |
| try: | |
| import ase.io | |
| with tempfile.NamedTemporaryFile(mode='w', suffix=os.path.splitext(filename)[1], delete=False) as tmp: | |
| tmp.write(file_content) | |
| tmp_path = tmp.name | |
| try: | |
| structure = ase.io.read(tmp_path) | |
| validation = { | |
| "valid": True, | |
| "num_atoms": len(structure), | |
| "formula": structure.get_chemical_formula() | |
| } | |
| finally: | |
| os.unlink(tmp_path) | |
| except Exception as e: | |
| validation = {"valid": False, "warning": str(e)} | |
| # Update id.json index | |
| try: | |
| id_index_path = os.path.join(data_path, "id.json") | |
| if not os.path.exists(id_index_path): | |
| with open(id_index_path, 'w', encoding='utf-8') as _idf: | |
| json.dump({}, _idf) | |
| with open(id_index_path, 'r', encoding='utf-8') as _idf: | |
| id_index = json.load(_idf) | |
| except Exception: | |
| id_index = {} | |
| file_id = uuid.uuid4().hex | |
| id_index[file_id] = { | |
| "filename": filename, | |
| "uploaded_at": datetime.now().isoformat(), | |
| "size": len(file_content) | |
| } | |
| try: | |
| with open(id_index_path, 'w', encoding='utf-8') as _idf: | |
| json.dump(id_index, _idf, indent=2) | |
| except Exception: | |
| pass | |
| try: | |
| _update_session_uploaded_files(session_path) | |
| except Exception: | |
| pass | |
| return { | |
| "success": True, | |
| "filename": filename, | |
| "file_size": len(file_content), | |
| "saved_to": file_path, | |
| "validation": validation, | |
| "file_id": file_id | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def upload_structure_files_batch( | |
| session_id: str, | |
| files: Dict[str, str] | |
| ) -> dict: | |
| """ | |
| Upload multiple structure files to a session in one call. | |
| Parameters: | |
| session_id (str): The session ID. | |
| files (dict): Dictionary mapping filename to file content. | |
| Example: {"struct1.cif": "content1", "struct2.cif": "content2"} | |
| Returns: | |
| dict: Upload status for all files. | |
| Example: | |
| upload_structure_files_batch( | |
| session_id="session_xxx", | |
| files={ | |
| "NaCl.cif": "data_NaCl...", | |
| "ZnO.cif": "data_ZnO..." | |
| } | |
| ) | |
| """ | |
| try: | |
| session_path = _get_session_path(session_id) | |
| if not os.path.exists(session_path): | |
| return {"success": False, "error": f"Session not found: {session_id}"} | |
| data_path = os.path.join(session_path, "data") | |
| os.makedirs(data_path, exist_ok=True) | |
| results = [] | |
| success_count = 0 | |
| for filename, content in files.items(): | |
| try: | |
| clean_name = _normalize_filename(filename) | |
| file_path = _safe_join(data_path, clean_name) | |
| with open(file_path, 'w', encoding='utf-8') as f: | |
| f.write(content) | |
| # update id index | |
| try: | |
| id_index_path = os.path.join(data_path, "id.json") | |
| if not os.path.exists(id_index_path): | |
| with open(id_index_path, 'w', encoding='utf-8') as _idf: | |
| json.dump({}, _idf) | |
| with open(id_index_path, 'r', encoding='utf-8') as _idf: | |
| id_index = json.load(_idf) | |
| except Exception: | |
| id_index = {} | |
| file_id = uuid.uuid4().hex | |
| id_index[file_id] = {"filename": clean_name, "uploaded_at": datetime.now().isoformat(), "size": len(content)} | |
| try: | |
| with open(id_index_path, 'w', encoding='utf-8') as _idf: | |
| json.dump(id_index, _idf, indent=2) | |
| except Exception: | |
| pass | |
| results.append({ | |
| "filename": clean_name, | |
| "success": True, | |
| "size": len(content), | |
| "file_id": file_id | |
| }) | |
| success_count += 1 | |
| except Exception as e: | |
| results.append({ | |
| "filename": filename, | |
| "success": False, | |
| "error": str(e) | |
| }) | |
| try: | |
| _update_session_uploaded_files(session_path) | |
| except Exception: | |
| pass | |
| return { | |
| "success": True, | |
| "total_files": len(files), | |
| "successful_uploads": success_count, | |
| "failed_uploads": len(files) - success_count, | |
| "results": results | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def upload_targets( | |
| session_id: str, | |
| targets_content: str, | |
| validate: bool = True | |
| ) -> dict: | |
| """ | |
| Upload targets.csv file to a session. | |
| Parameters: | |
| session_id (str): The session ID. | |
| targets_content (str): Content of targets.csv file. | |
| Format: structure_id,target_value (one per line). | |
| validate (bool): Whether to validate the targets file. | |
| Returns: | |
| dict: Upload status and validation info. | |
| Example: | |
| upload_targets( | |
| session_id="session_xxx", | |
| targets_content="NaCl,1.5\\nZnO,2.3\\nTiO2,3.1" | |
| ) | |
| """ | |
| try: | |
| session_path = _get_session_path(session_id) | |
| if not os.path.exists(session_path): | |
| return {"success": False, "error": f"Session not found: {session_id}"} | |
| data_path = os.path.join(session_path, "data") | |
| targets_path = os.path.join(data_path, "targets.csv") | |
| with open(targets_path, 'w', encoding='utf-8') as f: | |
| f.write(targets_content) | |
| # Validate and analyze | |
| validation = {"valid": True} | |
| if validate: | |
| import csv | |
| from io import StringIO | |
| reader = csv.reader(StringIO(targets_content)) | |
| rows = list(reader) | |
| structure_ids = [] | |
| target_values = [] | |
| for row in rows: | |
| if len(row) >= 2: | |
| structure_ids.append(row[0]) | |
| try: | |
| target_values.append(float(row[1])) | |
| except ValueError: | |
| pass | |
| # Check for matching structure files | |
| existing_files = os.listdir(data_path) | |
| structure_files = [f for f in existing_files if f != "targets.csv"] | |
| structure_names = [os.path.splitext(f)[0] for f in structure_files] | |
| matched = [sid for sid in structure_ids if sid in structure_names] | |
| unmatched = [sid for sid in structure_ids if sid not in structure_names] | |
| validation = { | |
| "valid": True, | |
| "num_samples": len(rows), | |
| "num_valid_targets": len(target_values), | |
| "target_range": { | |
| "min": min(target_values) if target_values else None, | |
| "max": max(target_values) if target_values else None, | |
| "mean": sum(target_values) / len(target_values) if target_values else None | |
| }, | |
| "matched_structures": len(matched), | |
| "unmatched_structures": unmatched[:10] if unmatched else [], | |
| "existing_structure_files": len(structure_files) | |
| } | |
| try: | |
| _update_session_uploaded_files(session_path) | |
| except Exception: | |
| pass | |
| return { | |
| "success": True, | |
| "saved_to": targets_path, | |
| "validation": validation | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def upload_binary_file( | |
| session_id: str, | |
| filename: str, | |
| base64_content: str, | |
| destination: str = "models" | |
| ) -> dict: | |
| """ | |
| Upload a binary file (e.g., pre-trained model .pth file) encoded as base64. | |
| Parameters: | |
| session_id (str): The session ID. | |
| filename (str): Name for the file. | |
| base64_content (str): File content encoded as base64 string. | |
| destination (str): Where to save - "models" or "data". | |
| Returns: | |
| dict: Upload status. | |
| Example: | |
| # In Python, encode your model file: | |
| # import base64 | |
| # with open("model.pth", "rb") as f: | |
| # encoded = base64.b64encode(f.read()).decode() | |
| # Then pass encoded as base64_content | |
| """ | |
| try: | |
| session_path = _get_session_path(session_id) | |
| if not os.path.exists(session_path): | |
| return {"success": False, "error": f"Session not found: {session_id}"} | |
| if destination == "models": | |
| dest_path = os.path.join(session_path, "models") | |
| else: | |
| dest_path = os.path.join(session_path, "data") | |
| filename = _normalize_filename(filename) | |
| file_path = _safe_join(dest_path, filename) | |
| # Decode and write binary content | |
| binary_content = base64.b64decode(base64_content) | |
| with open(file_path, 'wb') as f: | |
| f.write(binary_content) | |
| try: | |
| if destination != "models": | |
| _update_session_uploaded_files(session_path) | |
| except Exception: | |
| pass | |
| return { | |
| "success": True, | |
| "filename": filename, | |
| "file_size_bytes": len(binary_content), | |
| "saved_to": file_path | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def upload_dataset_archive( | |
| session_id: str, | |
| filename: str, | |
| base64_content: str, | |
| overwrite_existing: bool = False, | |
| clear_existing: bool = False | |
| ) -> dict: | |
| """Decode and extract a dataset archive directly into the session data folder.""" | |
| try: | |
| session_path = _get_session_path(session_id) | |
| if not os.path.exists(session_path): | |
| return {"success": False, "error": f"Session not found: {session_id}"} | |
| data_path = os.path.join(session_path, "data") | |
| os.makedirs(data_path, exist_ok=True) | |
| filename = _normalize_filename(filename) | |
| archive_bytes = base64.b64decode(base64_content) | |
| temp_dir = tempfile.mkdtemp(prefix="mcp_dataset_") | |
| try: | |
| archive_lower = filename.lower() | |
| if archive_lower.endswith(".zip"): | |
| with zipfile.ZipFile(io.BytesIO(archive_bytes)) as archive: | |
| _safe_extract_zip(archive, temp_dir) | |
| elif archive_lower.endswith((".tar", ".tar.gz", ".tgz", ".tar.bz2", ".tbz")): | |
| with tarfile.open(fileobj=io.BytesIO(archive_bytes), mode="r:*") as archive: | |
| _safe_extract_tar(archive, temp_dir) | |
| else: | |
| return { | |
| "success": False, | |
| "error": "Unsupported archive format. Use .zip, .tar, .tar.gz, .tgz, .tar.bz2" | |
| } | |
| dataset_root = _resolve_dataset_root(temp_dir) | |
| if not os.listdir(dataset_root): | |
| return {"success": False, "error": "Archive appears to be empty after extraction."} | |
| if clear_existing: | |
| _clear_directory_contents(data_path) | |
| copy_report = _copy_tree(dataset_root, data_path, overwrite=overwrite_existing) | |
| finally: | |
| shutil.rmtree(temp_dir, ignore_errors=True) | |
| try: | |
| _update_session_uploaded_files(session_path) | |
| except Exception: | |
| pass | |
| summary = _summarize_dataset_directory(data_path) | |
| try: | |
| _record_session_data_source(session_path, { | |
| "type": "archive_upload", | |
| "filename": filename, | |
| "timestamp": datetime.now().isoformat(), | |
| "stats": { | |
| "created": len(copy_report["created"]), | |
| "overwritten": len(copy_report["overwritten"]), | |
| "skipped": len(copy_report["skipped"]) | |
| } | |
| }) | |
| except Exception: | |
| pass | |
| return { | |
| "success": True, | |
| "session_id": session_id, | |
| "archive_name": filename, | |
| "created_files": copy_report["created"], | |
| "overwritten_files": copy_report["overwritten"], | |
| "skipped_files": copy_report["skipped"], | |
| "dataset_summary": summary, | |
| "next_steps": [ | |
| "Use process_session_data to generate graphs", | |
| "Confirm targets.csv is present before training" | |
| ] | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def register_local_dataset( | |
| session_id: str, | |
| dataset_path: str, | |
| overwrite_existing: bool = False, | |
| clear_existing: bool = False | |
| ) -> dict: | |
| """Copy a dataset from disk into the managed session directory.""" | |
| try: | |
| session_path = _get_session_path(session_id) | |
| if not os.path.exists(session_path): | |
| return {"success": False, "error": f"Session not found: {session_id}"} | |
| dataset_abs = os.path.abspath(dataset_path) | |
| if not os.path.exists(dataset_abs): | |
| return {"success": False, "error": f"Dataset path not found: {dataset_path}"} | |
| if not os.path.isdir(dataset_abs): | |
| return {"success": False, "error": "dataset_path must be a directory"} | |
| data_path = os.path.join(session_path, "data") | |
| os.makedirs(data_path, exist_ok=True) | |
| if Path(dataset_abs).resolve() == Path(data_path).resolve(): | |
| summary = _summarize_dataset_directory(data_path) | |
| return { | |
| "success": True, | |
| "session_id": session_id, | |
| "source_path": dataset_abs, | |
| "created_files": [], | |
| "overwritten_files": [], | |
| "skipped_files": [], | |
| "dataset_summary": summary, | |
| "message": "dataset_path already points to the session data directory; nothing to copy." | |
| } | |
| if clear_existing: | |
| _clear_directory_contents(data_path) | |
| copy_report = _copy_tree(dataset_abs, data_path, overwrite=overwrite_existing) | |
| try: | |
| _update_session_uploaded_files(session_path) | |
| except Exception: | |
| pass | |
| summary = _summarize_dataset_directory(data_path) | |
| try: | |
| _record_session_data_source(session_path, { | |
| "type": "local_import", | |
| "source_path": dataset_abs, | |
| "timestamp": datetime.now().isoformat(), | |
| "stats": { | |
| "created": len(copy_report["created"]), | |
| "overwritten": len(copy_report["overwritten"]), | |
| "skipped": len(copy_report["skipped"]) | |
| } | |
| }) | |
| except Exception: | |
| pass | |
| return { | |
| "success": True, | |
| "session_id": session_id, | |
| "source_path": dataset_abs, | |
| "created_files": copy_report["created"], | |
| "overwritten_files": copy_report["overwritten"], | |
| "skipped_files": copy_report["skipped"], | |
| "dataset_summary": summary, | |
| "next_steps": [ | |
| "Verify targets.csv is present in session data", | |
| "Run process_session_data to generate processed graphs" | |
| ] | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def list_session_data_files( | |
| session_id: str, | |
| include_sizes: bool = False, | |
| max_items: int = 200 | |
| ) -> dict: | |
| """Enumerate dataset files stored for a session.""" | |
| try: | |
| session_path = _get_session_path(session_id) | |
| if not os.path.exists(session_path): | |
| return {"success": False, "error": f"Session not found: {session_id}"} | |
| data_path = os.path.join(session_path, "data") | |
| if not os.path.exists(data_path): | |
| return {"success": True, "files": [], "total_files": 0, "dataset_summary": {}} | |
| files_info = [] | |
| for root, _, files in os.walk(data_path): | |
| for name in files: | |
| rel_path = os.path.relpath(os.path.join(root, name), data_path).replace("\\", "/") | |
| file_entry: Dict[str, Any] = {"path": rel_path} | |
| file_abs_path = os.path.join(root, name) | |
| if include_sizes: | |
| file_entry["size_bytes"] = os.path.getsize(file_abs_path) | |
| files_info.append(file_entry) | |
| files_info.sort(key=lambda item: item["path"].lower()) | |
| total_files = len(files_info) | |
| truncated = files_info[:max(0, max_items)] | |
| return { | |
| "success": True, | |
| "session_id": session_id, | |
| "files": truncated, | |
| "total_files": total_files, | |
| "truncated": total_files > len(truncated), | |
| "dataset_summary": _summarize_dataset_directory(data_path) | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def check_environment() -> dict: | |
| """ | |
| Check if the MatDeepLearn environment is properly configured. | |
| Returns: | |
| dict: Contains environment status including GPU availability. | |
| """ | |
| try: | |
| if not MATDEEPLEARN_AVAILABLE: | |
| return { | |
| "success": False, | |
| "error": f"MatDeepLearn not available: {IMPORT_ERROR}" | |
| } | |
| gpu_available = torch.cuda.is_available() | |
| gpu_count = torch.cuda.device_count() if gpu_available else 0 | |
| gpu_name = torch.cuda.get_device_name(0) if gpu_available else "N/A" | |
| return { | |
| "success": True, | |
| "matdeeplearn_available": True, | |
| "torch_version": torch.__version__, | |
| "gpu_available": gpu_available, | |
| "gpu_count": gpu_count, | |
| "gpu_name": gpu_name, | |
| "available_models": [ | |
| "CGCNN_demo", "MPNN_demo", "SchNet_demo", | |
| "MEGNet_demo", "GCN_demo", "SOAP_demo", "SM_demo" | |
| ] | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def list_available_models() -> dict: | |
| """ | |
| List all available Graph Neural Network models. | |
| Returns: | |
| dict: Contains list of available models with descriptions. | |
| """ | |
| try: | |
| models_info = { | |
| "CGCNN_demo": { | |
| "name": "Crystal Graph Convolutional Neural Network", | |
| "description": "A GNN for predicting material properties using crystal graphs.", | |
| "paper": "Xie & Grossman, PRL 2018" | |
| }, | |
| "MPNN_demo": { | |
| "name": "Message Passing Neural Network", | |
| "description": "General message passing framework for molecular graphs.", | |
| "paper": "Gilmer et al., ICML 2017" | |
| }, | |
| "SchNet_demo": { | |
| "name": "SchNet", | |
| "description": "Continuous-filter convolutional neural network for modeling quantum interactions.", | |
| "paper": "Schütt et al., JCP 2017" | |
| }, | |
| "MEGNet_demo": { | |
| "name": "MatErials Graph Network", | |
| "description": "Graph network with global state for materials property prediction.", | |
| "paper": "Chen et al., Chem. Mater. 2019" | |
| }, | |
| "GCN_demo": { | |
| "name": "Graph Convolutional Network", | |
| "description": "Standard graph convolutional network architecture.", | |
| "paper": "Kipf & Welling, ICLR 2017" | |
| }, | |
| "SOAP_demo": { | |
| "name": "Smooth Overlap of Atomic Positions", | |
| "description": "Descriptor-based method using SOAP features.", | |
| "paper": "Bartók et al., PRB 2013" | |
| }, | |
| "SM_demo": { | |
| "name": "Sine Matrix", | |
| "description": "Descriptor-based method using Sine/Coulomb matrix features.", | |
| "paper": "Various" | |
| } | |
| } | |
| return { | |
| "success": True, | |
| "models": models_info, | |
| "total_models": len(models_info) | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def get_model_config(model_name: str) -> dict: | |
| """ | |
| Get the default configuration for a specific GNN model. | |
| Parameters: | |
| model_name (str): Name of the model (e.g., 'CGCNN_demo', 'SchNet_demo'). | |
| Returns: | |
| dict: Contains the default configuration for the model. | |
| """ | |
| try: | |
| config_path = os.path.join(project_root, "config.yml") | |
| if not os.path.exists(config_path): | |
| return {"success": False, "error": "Config file not found"} | |
| with open(config_path, "r") as f: | |
| config = yaml.load(f, Loader=yaml.FullLoader) | |
| if model_name not in config.get("Models", {}): | |
| return { | |
| "success": False, | |
| "error": f"Model '{model_name}' not found. Available models: {list(config.get('Models', {}).keys())}" | |
| } | |
| model_config = config["Models"][model_name] | |
| processing_config = config.get("Processing", {}) | |
| training_config = config.get("Training", {}) | |
| return { | |
| "success": True, | |
| "model_name": model_name, | |
| "model_config": model_config, | |
| "processing_config": processing_config, | |
| "training_config": training_config | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def process_structure_data( | |
| data_path: str = "<your-data-path>", | |
| target_index: int = 0, | |
| graph_max_radius: float = 8.0, | |
| graph_max_neighbors: int = 12, | |
| data_format: str = "cif", | |
| dictionary_source: str = "default" | |
| ) -> dict: | |
| """ | |
| Generate local data processing script and configuration for MatDeepLearn. | |
| Server-side processing is not available; use this script locally. | |
| Parameters: | |
| data_path (str): Path to directory containing structure files (for script template). | |
| target_index (int): Index of target column in targets.csv (default: 0). | |
| graph_max_radius (float): Maximum radius for edges in graph (default: 8.0). | |
| graph_max_neighbors (int): Maximum number of neighbors per atom (default: 12). | |
| data_format (str): Structure file format ('cif', 'vasp', 'xyz', 'json'). | |
| dictionary_source (str): Atom dictionary source ('default', 'blank', 'generated', 'provided'). | |
| Returns: | |
| dict: Contains local processing script and data preparation guide. | |
| """ | |
| try: | |
| process_script = '''"""MatDeepLearn Data Processing Script | |
| Generated for local execution. | |
| """ | |
| import yaml | |
| from pathlib import Path | |
| from matdeeplearn import process | |
| # Configuration - Update this path | |
| DATA_PATH = Path(r"{data_path}") | |
| CONFIG_PATH = Path(r"config.yml") | |
| def main(): | |
| if not DATA_PATH.exists(): | |
| raise FileNotFoundError(f"Data path not found: {{DATA_PATH}}") | |
| data_path_str = str(DATA_PATH.resolve()) | |
| processing_args = {{ | |
| "dataset_type": "inmemory", | |
| "data_path": data_path_str, | |
| "target_path": "targets.csv", | |
| "dictionary_source": "{dictionary_source}", | |
| "dictionary_path": "atom_dict.json", | |
| "data_format": "{data_format}", | |
| "verbose": "True", | |
| "graph_max_radius": {graph_max_radius}, | |
| "graph_max_neighbors": {graph_max_neighbors}, | |
| "voronoi": "False", | |
| "edge_features": "True", | |
| "graph_edge_length": 50, | |
| "SM_descriptor": "False", | |
| "SOAP_descriptor": "False" | |
| }} | |
| print(f"Processing data from {{data_path_str}}...") | |
| dataset = process.get_dataset( | |
| data_path_str, | |
| {target_index}, # target_index | |
| "True", # reprocess | |
| processing_args | |
| ) | |
| print(f"Processing complete!") | |
| print(f"Dataset size: {{len(dataset)}}") | |
| if len(dataset) > 0: | |
| print(f"Atoms per structure: {{dataset[0].x.shape[0]}}") | |
| print(f"Node features: {{dataset[0].x.shape[1]}}") | |
| print(f"Edges per structure: {{dataset[0].edge_index.shape[1]}}") | |
| print(f"Processed data saved to: {{data_path_str}}/processed/") | |
| if __name__ == "__main__": | |
| main() | |
| '''.format( | |
| data_path=data_path, | |
| data_format=data_format, | |
| dictionary_source=dictionary_source, | |
| graph_max_radius=graph_max_radius, | |
| graph_max_neighbors=graph_max_neighbors, | |
| target_index=target_index | |
| ) | |
| return { | |
| "success": True, | |
| "message": "Server-side data processing is not available. Use the provided script locally.", | |
| "local_processing_script": process_script, | |
| "processing_parameters": { | |
| "graph_max_radius": graph_max_radius, | |
| "graph_max_neighbors": graph_max_neighbors, | |
| "data_format": data_format, | |
| "dictionary_source": dictionary_source, | |
| "target_index": target_index | |
| }, | |
| "data_preparation_guide": { | |
| "folder_structure": [ | |
| "your_dataset/", | |
| " ├── structure1.cif", | |
| " ├── structure2.cif", | |
| " ├── ...", | |
| " └── targets.csv" | |
| ], | |
| "targets_csv_format": "structure_id,target_value (no header, one per line)", | |
| "supported_formats": ["cif", "vasp/poscar", "xyz", "json", "extxyz"], | |
| "naming_convention": "Structure filename (without extension) must match structure_id in targets.csv" | |
| }, | |
| "usage_instructions": [ | |
| "1. Prepare your data folder with structure files and targets.csv", | |
| "2. Save the script as 'process_data.py'", | |
| "3. Update DATA_PATH to your data folder", | |
| "4. Run: python process_data.py", | |
| "5. Processed data will be saved to data_folder/processed/" | |
| ] | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def train_model( | |
| data_path: str, | |
| model_name: str = "CGCNN_demo", | |
| epochs: int = 100, | |
| batch_size: int = 32, | |
| learning_rate: float = 0.002, | |
| train_ratio: float = 0.8, | |
| val_ratio: float = 0.1, | |
| test_ratio: float = 0.1, | |
| save_model: bool = True, | |
| model_path: str = "trained_model.pth", | |
| target_index: int = 0, | |
| include_environment_guide: bool = True | |
| ) -> dict: | |
| """ | |
| Generate a complete local training script with environment setup guide for MatDeepLearn. | |
| This tool provides everything needed to train GNN models on your local machine. | |
| Parameters: | |
| data_path (str): Path to directory containing processed structure data. | |
| model_name (str): Name of the model to train. Options: 'CGCNN_demo', 'MPNN_demo', 'SchNet_demo', 'MEGNet_demo', 'GCN_demo', 'SOAP_demo', 'SM_demo'. | |
| epochs (int): Number of training epochs (default: 100). | |
| batch_size (int): Training batch size (default: 32). | |
| learning_rate (float): Learning rate (default: 0.002). | |
| train_ratio (float): Ratio of data for training (default: 0.8). | |
| val_ratio (float): Ratio of data for validation (default: 0.1). | |
| test_ratio (float): Ratio of data for testing (default: 0.1). | |
| save_model (bool): Whether to save the trained model (default: True). | |
| model_path (str): Path to save the trained model (default: 'trained_model.pth'). | |
| target_index (int): Index of target column in targets.csv (default: 0). | |
| include_environment_guide (bool): Whether to include environment setup guide (default: True). | |
| Returns: | |
| dict: Complete local training package including: | |
| - Ready-to-run Python training script | |
| - Model configuration and hyperparameters | |
| - Environment setup instructions | |
| - Dependencies list | |
| """ | |
| try: | |
| # Load default config | |
| config_path = os.path.join(project_root, "config.yml") | |
| if not os.path.exists(config_path): | |
| return {"success": False, "error": "Config file not found on server"} | |
| with open(config_path, "r") as f: | |
| config = yaml.load(f, Loader=yaml.FullLoader) | |
| available_models = list(config.get("Models", {}).keys()) | |
| if model_name not in config.get("Models", {}): | |
| return { | |
| "success": False, | |
| "error": f"Model '{model_name}' not found. Available models: {available_models}" | |
| } | |
| # Get model-specific config | |
| model_config = config["Models"][model_name] | |
| script = _build_local_training_script( | |
| data_path_literal=data_path, | |
| model_name=model_name, | |
| target_index=target_index, | |
| epochs=epochs, | |
| batch_size=batch_size, | |
| learning_rate=learning_rate, | |
| train_ratio=train_ratio, | |
| val_ratio=val_ratio, | |
| test_ratio=test_ratio, | |
| save_model=save_model, | |
| model_output_literal=model_path, | |
| reprocess=False, | |
| config_path_literal="config.yml" | |
| ) | |
| # Build environment setup guide | |
| environment_guide = None | |
| if include_environment_guide: | |
| environment_guide = { | |
| "python_version": "Python 3.7 - 3.9 recommended", | |
| "cuda_requirement": "CUDA 10.2 or 11.x for GPU training", | |
| "setup_steps": [ | |
| "1. Create a virtual environment: python -m venv matdeeplearn_env", | |
| "2. Activate the environment:", | |
| " - Windows: matdeeplearn_env\\Scripts\\activate", | |
| " - Linux/Mac: source matdeeplearn_env/bin/activate", | |
| "3. Install PyTorch with CUDA support:", | |
| " pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118", | |
| "4. Install PyTorch Geometric:", | |
| " pip install torch-geometric", | |
| " pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.0.0+cu118.html", | |
| "5. Install other dependencies:", | |
| " pip install ase pymatgen numpy scipy matplotlib scikit-learn pyyaml", | |
| "6. Clone MatDeepLearn repository:", | |
| " git clone https://github.com/Fung-Lab/MatDeepLearn.git", | |
| "7. Install MatDeepLearn:", | |
| " cd MatDeepLearn && pip install -e ." | |
| ], | |
| "core_dependencies": [ | |
| "torch>=1.8.0", | |
| "torch-geometric>=2.0.0", | |
| "ase>=3.20.0", | |
| "pymatgen>=2020.9.0", | |
| "numpy>=1.20.0", | |
| "scipy>=1.6.0", | |
| "scikit-learn>=0.24.0", | |
| "pyyaml", | |
| "matplotlib" | |
| ], | |
| "optional_dependencies": [ | |
| "dscribe>=0.3.5 (for SOAP descriptor)", | |
| "ray>=1.0.0 (for hyperparameter optimization)" | |
| ], | |
| "gpu_check_code": "import torch; print(f'CUDA available: {torch.cuda.is_available()}'); print(f'GPU count: {torch.cuda.device_count()}')" | |
| } | |
| # Build config.yml content for user | |
| config_yml_content = yaml.dump({ | |
| "Processing": config.get("Processing", {}), | |
| "Training": config.get("Training", {}), | |
| "Models": {model_name: model_config} | |
| }, default_flow_style=False) | |
| return { | |
| "success": True, | |
| "message": "Local training package generated successfully. Save the script and run on your machine.", | |
| "model_name": model_name, | |
| "model_description": _get_model_description(model_name), | |
| "parameters": { | |
| "epochs": epochs, | |
| "batch_size": batch_size, | |
| "learning_rate": learning_rate, | |
| "train_ratio": train_ratio, | |
| "val_ratio": val_ratio, | |
| "test_ratio": test_ratio, | |
| "target_index": target_index, | |
| "save_model": save_model, | |
| "model_output": model_path | |
| }, | |
| "model_architecture": { | |
| "hidden_dims": [model_config.get("dim1"), model_config.get("dim2"), model_config.get("dim3")], | |
| "gc_layers": model_config.get("gc_count"), | |
| "pooling": model_config.get("pool"), | |
| "optimizer": model_config.get("optimizer"), | |
| "scheduler": model_config.get("scheduler") | |
| }, | |
| "local_training_script": script, | |
| "config_yml_template": config_yml_content, | |
| "environment_guide": environment_guide, | |
| "usage_instructions": { | |
| "step1": "Save the training script as 'train_local.py'", | |
| "step2": "Save the config_yml_template as 'config.yml' in the same directory", | |
| "step3": "Update DATA_PATH in the script to point to your dataset directory", | |
| "step4": "Ensure your dataset has structure files and targets.csv", | |
| "step5": "Run: python train_local.py" | |
| }, | |
| "available_models": available_models | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def _get_model_description(model_name: str) -> dict: | |
| """Get description for a model.""" | |
| descriptions = { | |
| "CGCNN_demo": { | |
| "full_name": "Crystal Graph Convolutional Neural Network", | |
| "description": "A GNN for predicting material properties using crystal graphs with edge-conditioned convolutions.", | |
| "best_for": "Crystalline materials with periodic structures", | |
| "paper": "Xie & Grossman, PRL 2018" | |
| }, | |
| "MPNN_demo": { | |
| "full_name": "Message Passing Neural Network", | |
| "description": "General message passing framework for molecular and crystal graphs.", | |
| "best_for": "General molecular and material property prediction", | |
| "paper": "Gilmer et al., ICML 2017" | |
| }, | |
| "SchNet_demo": { | |
| "full_name": "SchNet", | |
| "description": "Continuous-filter convolutional neural network with smooth radial basis functions.", | |
| "best_for": "Quantum chemistry and energy predictions", | |
| "paper": "Schütt et al., JCP 2017" | |
| }, | |
| "MEGNet_demo": { | |
| "full_name": "MatErials Graph Network", | |
| "description": "Graph network with global state for materials property prediction.", | |
| "best_for": "Materials with complex global properties", | |
| "paper": "Chen et al., Chem. Mater. 2019" | |
| }, | |
| "GCN_demo": { | |
| "full_name": "Graph Convolutional Network", | |
| "description": "Standard graph convolutional network architecture.", | |
| "best_for": "Simple and fast baseline models", | |
| "paper": "Kipf & Welling, ICLR 2017" | |
| }, | |
| "SOAP_demo": { | |
| "full_name": "Smooth Overlap of Atomic Positions", | |
| "description": "Descriptor-based method using SOAP features.", | |
| "best_for": "Local atomic environment similarity", | |
| "paper": "Bartók et al., PRB 2013" | |
| }, | |
| "SM_demo": { | |
| "full_name": "Sine Matrix", | |
| "description": "Descriptor-based method using Sine/Coulomb matrix features.", | |
| "best_for": "Small molecules and simple structures", | |
| "paper": "Various" | |
| } | |
| } | |
| return descriptions.get(model_name, {"full_name": model_name, "description": "Custom model"}) | |
| def predict_properties( | |
| data_path: str, | |
| model_path: str, | |
| target_index: int = 0 | |
| ) -> dict: | |
| """ | |
| Generate a local prediction script for predicting properties with a trained model. | |
| Server-side prediction is not available; use this script locally. | |
| Parameters: | |
| data_path (str): Path to directory containing structure files to predict. | |
| model_path (str): Path to the trained model file (.pth). | |
| target_index (int): Index of target column (default: 0). | |
| Returns: | |
| dict: Contains local prediction script and instructions. | |
| """ | |
| try: | |
| predict_script = '''"""MatDeepLearn Prediction Script | |
| Generated for local execution. | |
| """ | |
| import yaml | |
| import torch | |
| from pathlib import Path | |
| from matdeeplearn import training, process | |
| # Configuration - Update these paths | |
| DATA_PATH = Path(r"{data_path}") | |
| MODEL_PATH = Path(r"{model_path}") | |
| CONFIG_PATH = Path(r"config.yml") | |
| TARGET_INDEX = {target_index} | |
| def main(): | |
| if not DATA_PATH.exists(): | |
| raise FileNotFoundError(f"Data path not found: {{DATA_PATH}}") | |
| if not MODEL_PATH.exists(): | |
| raise FileNotFoundError(f"Model path not found: {{MODEL_PATH}}") | |
| if not CONFIG_PATH.exists(): | |
| raise FileNotFoundError(f"Config file not found: {{CONFIG_PATH}}") | |
| with CONFIG_PATH.open("r", encoding="utf-8") as f: | |
| config = yaml.safe_load(f) | |
| data_path_str = str(DATA_PATH.resolve()) | |
| # Process data | |
| processing_args = config.get("Processing", {{}}) | |
| processing_args["data_path"] = data_path_str | |
| dataset = process.get_dataset( | |
| data_path_str, | |
| TARGET_INDEX, | |
| "False", | |
| processing_args | |
| ) | |
| job_config = {{ | |
| "job_name": "prediction_job", | |
| "model_path": str(MODEL_PATH.resolve()), | |
| "write_output": "True" | |
| }} | |
| print(f"Predicting on {{len(dataset)}} structures...") | |
| # Run prediction | |
| test_error = training.predict(dataset, "l1_loss", job_config) | |
| print(f"Prediction complete!") | |
| print(f"MAE (if targets available): {{test_error:.4f}}") | |
| print(f"Results saved to: prediction_job_predicted_outputs.csv") | |
| if __name__ == "__main__": | |
| main() | |
| '''.format( | |
| data_path=data_path, | |
| model_path=model_path, | |
| target_index=target_index | |
| ) | |
| return { | |
| "success": True, | |
| "message": "Server-side prediction is not available. Use the provided script locally.", | |
| "local_prediction_script": predict_script, | |
| "usage_instructions": [ | |
| "1. Save the script as 'predict_local.py'", | |
| "2. Ensure config.yml is in the same directory", | |
| "3. Update DATA_PATH and MODEL_PATH to your local paths", | |
| "4. Prepare your data with structure files and targets.csv", | |
| "5. Run: python predict_local.py" | |
| ], | |
| "data_requirements": { | |
| "structure_files": "CIF, POSCAR, XYZ, or JSON format files", | |
| "targets.csv": "structure_id,target_value (can be dummy values like 0.0 for pure prediction)" | |
| } | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def cross_validation( | |
| data_path: str, | |
| model_name: str = "CGCNN_demo", | |
| cv_folds: int = 5, | |
| epochs: int = 100, | |
| batch_size: int = 32, | |
| learning_rate: float = 0.002 | |
| ) -> dict: | |
| """ | |
| Generate a local cross-validation script for k-fold CV. | |
| Parameters: | |
| data_path (str): Path to directory containing structure data. | |
| model_name (str): Name of the model to use (default: 'CGCNN_demo'). | |
| cv_folds (int): Number of cross-validation folds (default: 5). | |
| epochs (int): Number of training epochs per fold (default: 100). | |
| batch_size (int): Batch size (default: 32). | |
| learning_rate (float): Learning rate (default: 0.002). | |
| Returns: | |
| dict: Contains local cross-validation script and instructions. | |
| """ | |
| try: | |
| # Load config to validate model | |
| config_path = os.path.join(project_root, "config.yml") | |
| with open(config_path, "r") as f: | |
| config = yaml.load(f, Loader=yaml.FullLoader) | |
| if model_name not in config.get("Models", {}): | |
| available = list(config.get("Models", {}).keys()) | |
| return {"success": False, "error": f"Model '{model_name}' not found. Available: {available}"} | |
| # Generate CV script | |
| cv_script = '''"""MatDeepLearn Cross-Validation Script | |
| Generated for local execution. | |
| """ | |
| import yaml | |
| import numpy as np | |
| import torch | |
| from pathlib import Path | |
| from matdeeplearn import training, process | |
| # Configuration - Update these paths | |
| DATA_PATH = Path(r"{data_path}") | |
| CONFIG_PATH = Path(r"config.yml") | |
| # CV Parameters | |
| CV_FOLDS = {cv_folds} | |
| MODEL_NAME = "{model_name}" | |
| EPOCHS = {epochs} | |
| BATCH_SIZE = {batch_size} | |
| LEARNING_RATE = {learning_rate} | |
| def main(): | |
| if not DATA_PATH.exists(): | |
| raise FileNotFoundError(f"Data path not found: {{DATA_PATH}}") | |
| if not CONFIG_PATH.exists(): | |
| raise FileNotFoundError(f"Config file not found: {{CONFIG_PATH}}") | |
| with CONFIG_PATH.open("r", encoding="utf-8") as f: | |
| config = yaml.safe_load(f) | |
| job_config = {{ | |
| "job_name": "cv_job_{model_name}", | |
| "reprocess": "False", | |
| "model": MODEL_NAME, | |
| "write_output": "True", | |
| "parallel": "False", | |
| "seed": int(np.random.randint(1, 1_000_000)), | |
| "cv_folds": CV_FOLDS | |
| }} | |
| training_config = config.get("Training", {{}}).copy() | |
| training_config.update({{ | |
| "target_index": 0, | |
| "verbosity": 5, | |
| }}) | |
| model_config = config["Models"][MODEL_NAME].copy() | |
| model_config.update({{ | |
| "epochs": EPOCHS, | |
| "batch_size": BATCH_SIZE, | |
| "lr": LEARNING_RATE, | |
| }}) | |
| data_path_str = str(DATA_PATH.resolve()) | |
| world_size = torch.cuda.device_count() | |
| rank = "cuda" if torch.cuda.is_available() else "cpu" | |
| print(f"Running {{CV_FOLDS}}-fold cross validation with {{MODEL_NAME}}") | |
| print(f"Device: {{rank}}, GPU count: {{world_size}}") | |
| # Run CV | |
| cv_results = training.train_cv( | |
| rank, | |
| world_size, | |
| data_path_str, | |
| job_config, | |
| training_config, | |
| model_config, | |
| ) | |
| print("\n" + "="*50) | |
| print("Cross-Validation Results:") | |
| print(f"Mean MAE: {{np.mean(cv_results):.4f}} +/- {{np.std(cv_results):.4f}}") | |
| print(f"Individual fold errors: {{cv_results}}") | |
| print("="*50) | |
| if __name__ == "__main__": | |
| main() | |
| '''.format( | |
| data_path=data_path, | |
| cv_folds=cv_folds, | |
| model_name=model_name, | |
| epochs=epochs, | |
| batch_size=batch_size, | |
| learning_rate=learning_rate | |
| ) | |
| return { | |
| "success": True, | |
| "message": "Cross-validation script generated. Run locally on your machine.", | |
| "model_name": model_name, | |
| "cv_folds": cv_folds, | |
| "parameters": { | |
| "epochs": epochs, | |
| "batch_size": batch_size, | |
| "learning_rate": learning_rate | |
| }, | |
| "local_cv_script": cv_script, | |
| "usage_instructions": [ | |
| "1. Save the script as 'run_cv.py'", | |
| "2. Ensure config.yml is in the same directory", | |
| "3. Update DATA_PATH to your dataset location", | |
| "4. Run: python run_cv.py" | |
| ], | |
| "expected_output": f"Will output mean MAE and std across {cv_folds} folds" | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def analyze_structure( | |
| file_content: Optional[str] = None, | |
| file_format: Optional[str] = None, | |
| structure_file: Optional[str] = None | |
| ) -> dict: | |
| """ | |
| Analyze an atomic structure and report its composition, periodicity, and connectivity. | |
| Parameters: | |
| file_content (str, optional): The content of the structure file (CIF, XYZ, POSCAR, JSON format). | |
| Pass the actual file content directly here. | |
| file_format (str, optional): Format of the file content ('cif', 'xyz', 'vasp', 'json'). | |
| Required when file_content is provided. | |
| structure_file (str, optional): Path to a structure file on the server (legacy option). | |
| Returns: | |
| dict: Contains structure analysis including atoms, bonds, and graph info. | |
| Example usage: | |
| analyze_structure(file_content="your CIF file content here...", file_format="cif") | |
| """ | |
| try: | |
| import ase | |
| from ase import io | |
| from io import StringIO | |
| structure = None | |
| # Method 1: Direct file content (preferred for remote access) | |
| if file_content is not None: | |
| if file_format is None: | |
| return {"success": False, "error": "file_format is required when providing file_content. Use 'cif', 'xyz', 'vasp', or 'json'."} | |
| # Map common format names | |
| format_map = { | |
| 'cif': 'cif', | |
| 'xyz': 'xyz', | |
| 'vasp': 'vasp', | |
| 'poscar': 'vasp', | |
| 'json': 'json', | |
| 'extxyz': 'extxyz' | |
| } | |
| fmt = format_map.get(file_format.lower()) | |
| if fmt is None: | |
| return {"success": False, "error": f"Unsupported format: {file_format}. Supported: cif, xyz, vasp, poscar, json, extxyz"} | |
| # Create a temporary file to read the structure | |
| with tempfile.NamedTemporaryFile(mode='w', suffix=f'.{fmt}', delete=False) as tmp: | |
| tmp.write(file_content) | |
| tmp_path = tmp.name | |
| try: | |
| structure = ase.io.read(tmp_path, format=fmt) | |
| finally: | |
| os.unlink(tmp_path) # Clean up temp file | |
| # Method 2: File path on server (legacy) | |
| elif structure_file is not None: | |
| if not os.path.exists(structure_file): | |
| return {"success": False, "error": f"Structure file not found: {structure_file}. Tip: For remote MCP, pass file_content directly instead of file path."} | |
| structure = ase.io.read(structure_file) | |
| else: | |
| return {"success": False, "error": "Either file_content (with file_format) or structure_file must be provided."} | |
| # Get basic info | |
| symbols = structure.get_chemical_symbols() | |
| positions = structure.get_positions().tolist() | |
| cell = structure.get_cell().tolist() if any(structure.pbc) else None | |
| pbc = structure.pbc.tolist() | |
| # Get distance matrix | |
| distance_matrix = structure.get_all_distances(mic=True) | |
| # Analyze connectivity | |
| cutoff_radius = 8.0 | |
| neighbors_count = [] | |
| for i in range(len(structure)): | |
| neighbors = np.sum((distance_matrix[i] > 0) & (distance_matrix[i] < cutoff_radius)) | |
| neighbors_count.append(int(neighbors)) | |
| return { | |
| "success": True, | |
| "num_atoms": len(structure), | |
| "chemical_formula": structure.get_chemical_formula(), | |
| "elements": list(set(symbols)), | |
| "element_counts": {elem: symbols.count(elem) for elem in set(symbols)}, | |
| "has_periodicity": any(pbc), | |
| "pbc": pbc, | |
| "cell": cell, | |
| "positions": positions[:10] if len(positions) > 10 else positions, # First 10 positions | |
| "average_neighbors": float(np.mean(neighbors_count)), | |
| "min_neighbors": min(neighbors_count), | |
| "max_neighbors": max(neighbors_count), | |
| "min_distance": float(distance_matrix[distance_matrix > 0].min()), | |
| "max_distance": float(distance_matrix.max()) | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def compare_models( | |
| data_path: str, | |
| model_list: Optional[List[str]] = None, | |
| epochs: int = 50, | |
| batch_size: int = 32 | |
| ) -> dict: | |
| """ | |
| Generate a local script to compare the performance of different GNN models. | |
| Parameters: | |
| data_path (str): Path to directory containing structure data. | |
| model_list (List[str]): List of models to compare (default: all available). | |
| epochs (int): Number of training epochs per model (default: 50). | |
| batch_size (int): Batch size for training (default: 32). | |
| Returns: | |
| dict: Contains local comparison script and model recommendations. | |
| """ | |
| try: | |
| # Load config | |
| config_path = os.path.join(project_root, "config.yml") | |
| with open(config_path, "r") as f: | |
| config = yaml.load(f, Loader=yaml.FullLoader) | |
| available_models = list(config.get("Models", {}).keys()) | |
| if model_list is None: | |
| model_list = available_models | |
| else: | |
| # Validate models | |
| invalid = [m for m in model_list if m not in available_models] | |
| if invalid: | |
| return { | |
| "success": False, | |
| "error": f"Invalid models: {invalid}. Available: {available_models}" | |
| } | |
| # Generate comparison script | |
| models_str = str(model_list) | |
| compare_script = '''"""MatDeepLearn Model Comparison Script | |
| Generated for local benchmarking. | |
| """ | |
| import yaml | |
| import numpy as np | |
| import torch | |
| import time | |
| from pathlib import Path | |
| from matdeeplearn import training, process | |
| # Configuration - Update these paths | |
| DATA_PATH = Path(r"{data_path}") | |
| CONFIG_PATH = Path(r"config.yml") | |
| # Comparison Parameters | |
| MODEL_LIST = {models} | |
| EPOCHS = {epochs} | |
| BATCH_SIZE = {batch_size} | |
| def train_single_model(model_name, data_path_str, config, rank, world_size): | |
| """Train a single model and return results.""" | |
| job_config = {{ | |
| "job_name": f"compare_{{model_name}}", | |
| "reprocess": "False", | |
| "model": model_name, | |
| "save_model": "False", | |
| "write_output": "False", | |
| "parallel": "False", | |
| "seed": 42, | |
| }} | |
| training_config = config.get("Training", {{}}).copy() | |
| training_config.update({{ | |
| "target_index": 0, | |
| "train_ratio": 0.8, | |
| "val_ratio": 0.1, | |
| "test_ratio": 0.1, | |
| "verbosity": 5, | |
| }}) | |
| model_config = config["Models"][model_name].copy() | |
| model_config.update({{ | |
| "epochs": EPOCHS, | |
| "batch_size": BATCH_SIZE, | |
| }}) | |
| start_time = time.time() | |
| errors = training.train_regular( | |
| rank, | |
| world_size, | |
| data_path_str, | |
| job_config, | |
| training_config, | |
| model_config, | |
| ) | |
| elapsed = time.time() - start_time | |
| return {{ | |
| "model": model_name, | |
| "train_error": errors[0], | |
| "val_error": errors[1], | |
| "test_error": errors[2], | |
| "training_time": elapsed | |
| }} | |
| def main(): | |
| if not DATA_PATH.exists(): | |
| raise FileNotFoundError(f"Data path not found: {{DATA_PATH}}") | |
| if not CONFIG_PATH.exists(): | |
| raise FileNotFoundError(f"Config file not found: {{CONFIG_PATH}}") | |
| with CONFIG_PATH.open("r", encoding="utf-8") as f: | |
| config = yaml.safe_load(f) | |
| data_path_str = str(DATA_PATH.resolve()) | |
| world_size = torch.cuda.device_count() | |
| rank = "cuda" if torch.cuda.is_available() else "cpu" | |
| print(f"Device: {{rank}}, GPU count: {{world_size}}") | |
| print(f"Comparing models: {{MODEL_LIST}}") | |
| print("="*60) | |
| results = [] | |
| for model_name in MODEL_LIST: | |
| print(f"\nTraining {{model_name}}...") | |
| try: | |
| result = train_single_model(model_name, data_path_str, config, rank, world_size) | |
| results.append(result) | |
| print(f" Test MAE: {{result['test_error']:.4f}}, Time: {{result['training_time']:.1f}}s") | |
| except Exception as e: | |
| print(f" Error: {{e}}") | |
| results.append({{ | |
| "model": model_name, | |
| "error": str(e) | |
| }}) | |
| # Summary | |
| print("\n" + "="*60) | |
| print("MODEL COMPARISON RESULTS") | |
| print("="*60) | |
| print(f"{{' Model':<20}} {{'Test MAE':>12}} {{'Val MAE':>12}} {{'Time (s)':>10}}") | |
| print("-"*60) | |
| valid_results = [r for r in results if "test_error" in r] | |
| for r in sorted(valid_results, key=lambda x: x["test_error"]): | |
| print(f"{{r['model']:<20}} {{r['test_error']:>12.4f}} {{r['val_error']:>12.4f}} {{r['training_time']:>10.1f}}") | |
| if valid_results: | |
| best = min(valid_results, key=lambda x: x["test_error"]) | |
| print("\n" + "="*60) | |
| print(f"BEST MODEL: {{best['model']}} with Test MAE = {{best['test_error']:.4f}}") | |
| print("="*60) | |
| if __name__ == "__main__": | |
| main() | |
| '''.format( | |
| data_path=data_path, | |
| models=models_str, | |
| epochs=epochs, | |
| batch_size=batch_size | |
| ) | |
| # Model recommendations | |
| recommendations = { | |
| "for_crystals": ["CGCNN_demo", "MEGNet_demo", "SchNet_demo"], | |
| "for_molecules": ["MPNN_demo", "SchNet_demo"], | |
| "for_speed": ["GCN_demo", "SM_demo"], | |
| "for_accuracy": ["CGCNN_demo", "MEGNet_demo"] | |
| } | |
| return { | |
| "success": True, | |
| "message": "Model comparison script generated. Run locally to benchmark models.", | |
| "models_to_compare": model_list, | |
| "parameters": { | |
| "epochs": epochs, | |
| "batch_size": batch_size | |
| }, | |
| "local_comparison_script": compare_script, | |
| "model_recommendations": recommendations, | |
| "usage_instructions": [ | |
| "1. Save the script as 'compare_models.py'", | |
| "2. Ensure config.yml is in the same directory", | |
| "3. Update DATA_PATH to your dataset location", | |
| "4. Run: python compare_models.py", | |
| "5. Results will show ranked models by test MAE" | |
| ], | |
| "available_models": available_models | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def get_dataset_info( | |
| data_path: Optional[str] = None, | |
| structure_files: Optional[List[str]] = None, | |
| targets_csv_content: Optional[str] = None | |
| ) -> dict: | |
| """ | |
| Get information about a dataset. | |
| Parameters: | |
| data_path (str, optional): Path to directory containing structure data (server-side). | |
| structure_files (List[str], optional): List of structure filenames (for validation check). | |
| targets_csv_content (str, optional): Content of targets.csv file to analyze. | |
| Returns: | |
| dict: Contains dataset information including file counts and formats. | |
| """ | |
| try: | |
| # If analyzing uploaded content | |
| if targets_csv_content is not None: | |
| import csv | |
| from io import StringIO | |
| reader = csv.reader(StringIO(targets_csv_content)) | |
| rows = list(reader) | |
| num_samples = len(rows) | |
| # Parse target values | |
| target_values = [] | |
| for row in rows: | |
| if len(row) >= 2: | |
| try: | |
| target_values.append(float(row[1])) | |
| except ValueError: | |
| pass | |
| result = { | |
| "success": True, | |
| "source": "uploaded_content", | |
| "num_samples": num_samples, | |
| "has_targets_csv": True, | |
| "ready_for_training": True | |
| } | |
| if target_values: | |
| result["target_statistics"] = { | |
| "min": min(target_values), | |
| "max": max(target_values), | |
| "mean": sum(target_values) / len(target_values) | |
| } | |
| if structure_files: | |
| extensions = {} | |
| for f in structure_files: | |
| ext = os.path.splitext(f)[1].lower() | |
| extensions[ext] = extensions.get(ext, 0) + 1 | |
| result["file_extensions"] = extensions | |
| result["num_structure_files"] = len(structure_files) | |
| return result | |
| # Traditional path-based analysis | |
| if data_path is None: | |
| return {"success": False, "error": "Either data_path or targets_csv_content must be provided"} | |
| if not os.path.exists(data_path): | |
| return {"success": False, "error": f"Data path not found: {data_path}"} | |
| # Count files by extension | |
| extensions = {} | |
| for file in os.listdir(data_path): | |
| ext = os.path.splitext(file)[1].lower() | |
| extensions[ext] = extensions.get(ext, 0) + 1 | |
| # Check for required files | |
| has_targets = os.path.exists(os.path.join(data_path, "targets.csv")) | |
| has_atom_dict = os.path.exists(os.path.join(data_path, "atom_dict.json")) | |
| has_processed = os.path.exists(os.path.join(data_path, "processed")) | |
| # Read targets if available | |
| num_samples = 0 | |
| if has_targets: | |
| import csv | |
| with open(os.path.join(data_path, "targets.csv")) as f: | |
| num_samples = sum(1 for _ in csv.reader(f)) | |
| return { | |
| "success": True, | |
| "source": "server_path", | |
| "data_path": data_path, | |
| "file_extensions": extensions, | |
| "has_targets_csv": has_targets, | |
| "has_atom_dict": has_atom_dict, | |
| "has_processed_data": has_processed, | |
| "num_samples": num_samples, | |
| "ready_for_training": has_targets | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def quick_structure_analysis( | |
| file_content: str, | |
| file_format: str, | |
| include_positions: bool = False, | |
| include_distances: bool = True | |
| ) -> dict: | |
| """ | |
| Perform a quick analysis of structure file content passed in directly. | |
| This is the recommended tool for analyzing structures when using remote MCP. | |
| Parameters: | |
| file_content (str): The complete content of the structure file. | |
| file_format (str): Format of the file - 'cif', 'xyz', 'vasp'/'poscar', 'json', 'extxyz'. | |
| include_positions (bool): Whether to include atomic positions in output (default: False). | |
| include_distances (bool): Whether to include distance analysis (default: True). | |
| Returns: | |
| dict: Comprehensive structure analysis. | |
| Example: | |
| quick_structure_analysis( | |
| file_content="data_NaCl\\n_cell_length_a 5.64...", | |
| file_format="cif" | |
| ) | |
| """ | |
| try: | |
| import ase | |
| from ase import io | |
| # Map format names | |
| format_map = { | |
| 'cif': 'cif', | |
| 'xyz': 'xyz', | |
| 'vasp': 'vasp', | |
| 'poscar': 'vasp', | |
| 'json': 'json', | |
| 'extxyz': 'extxyz' | |
| } | |
| fmt = format_map.get(file_format.lower()) | |
| if fmt is None: | |
| return { | |
| "success": False, | |
| "error": f"Unsupported format: {file_format}. Supported: cif, xyz, vasp, poscar, json, extxyz" | |
| } | |
| # Write to temp file and read | |
| with tempfile.NamedTemporaryFile(mode='w', suffix=f'.{fmt}', delete=False) as tmp: | |
| tmp.write(file_content) | |
| tmp_path = tmp.name | |
| try: | |
| structure = ase.io.read(tmp_path, format=fmt) | |
| finally: | |
| os.unlink(tmp_path) | |
| # Basic analysis | |
| symbols = structure.get_chemical_symbols() | |
| cell = structure.get_cell().tolist() if any(structure.pbc) else None | |
| pbc = structure.pbc.tolist() | |
| result = { | |
| "success": True, | |
| "num_atoms": len(structure), | |
| "chemical_formula": structure.get_chemical_formula(), | |
| "reduced_formula": structure.get_chemical_formula(mode='reduce'), | |
| "elements": sorted(list(set(symbols))), | |
| "element_counts": {elem: symbols.count(elem) for elem in set(symbols)}, | |
| "has_periodicity": any(pbc), | |
| "pbc": pbc, | |
| "cell_parameters": cell, | |
| "volume": float(structure.get_volume()) if any(pbc) else None, | |
| } | |
| if include_positions: | |
| positions = structure.get_positions().tolist() | |
| result["positions"] = positions | |
| result["symbols"] = symbols | |
| if include_distances: | |
| distance_matrix = structure.get_all_distances(mic=True) | |
| cutoff_radius = 8.0 | |
| neighbors_count = [] | |
| for i in range(len(structure)): | |
| neighbors = np.sum((distance_matrix[i] > 0) & (distance_matrix[i] < cutoff_radius)) | |
| neighbors_count.append(int(neighbors)) | |
| result["distance_analysis"] = { | |
| "cutoff_radius": cutoff_radius, | |
| "average_neighbors": float(np.mean(neighbors_count)), | |
| "min_neighbors": min(neighbors_count), | |
| "max_neighbors": max(neighbors_count), | |
| "min_distance": float(distance_matrix[distance_matrix > 0].min()), | |
| "max_distance": float(distance_matrix.max()) | |
| } | |
| # Check if suitable for GNN | |
| result["gnn_suitable"] = { | |
| "has_enough_atoms": len(structure) >= 2, | |
| "has_3d_coordinates": True, | |
| "is_periodic": any(pbc), | |
| "recommendation": "Suitable for GNN training" if len(structure) >= 2 else "Too few atoms" | |
| } | |
| return result | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| # ============================================================================ | |
| # Session-based training and model management tools | |
| # ============================================================================ | |
| def process_session_data( | |
| session_id: str, | |
| target_index: int = 0, | |
| graph_max_radius: float = 8.0, | |
| graph_max_neighbors: int = 12, | |
| data_format: Optional[str] = None, | |
| dictionary_source: str = "default" | |
| ) -> dict: | |
| """ | |
| Analyze uploaded session data and generate a local processing script. | |
| Server-side graph processing is not available; use the generated script locally. | |
| Parameters: | |
| session_id (str): The session ID. | |
| target_index (int): Index of target column in targets.csv (default: 0). | |
| graph_max_radius (float): Maximum radius for graph edges (default: 8.0 Angstrom). | |
| graph_max_neighbors (int): Maximum neighbors per atom (default: 12). | |
| data_format (str, optional): Explicit structure format ('cif', 'vasp', 'xyz', 'json'). | |
| dictionary_source (str): Atom dictionary source ('default', 'blank', 'generated', 'provided'). | |
| Returns: | |
| dict: Data analysis and local processing script. | |
| """ | |
| try: | |
| session_path = _get_session_path(session_id) | |
| if not os.path.exists(session_path): | |
| return {"success": False, "error": f"Session not found: {session_id}"} | |
| data_path = os.path.join(session_path, "data") | |
| # Check for required files and analyze | |
| has_targets = os.path.exists(os.path.join(data_path, "targets.csv")) | |
| skip_files = {"targets.csv", "id.json", "atom_dict.json", "processed"} | |
| structure_files = [ | |
| f for f in os.listdir(data_path) | |
| if f not in skip_files and not f.startswith('.') and os.path.isfile(os.path.join(data_path, f)) | |
| ] | |
| # Detect format | |
| determined_format = data_format | |
| if determined_format is None: | |
| detected = _infer_structure_format(structure_files) | |
| determined_format = detected if detected else "json" | |
| # Analyze targets if present | |
| target_stats = None | |
| if has_targets: | |
| import csv | |
| targets_path = os.path.join(data_path, "targets.csv") | |
| with open(targets_path, 'r') as f: | |
| reader = csv.reader(f) | |
| rows = list(reader) | |
| target_values = [] | |
| for row in rows: | |
| if len(row) >= 2: | |
| try: | |
| target_values.append(float(row[1])) | |
| except ValueError: | |
| pass | |
| if target_values: | |
| target_stats = { | |
| "num_samples": len(rows), | |
| "min": min(target_values), | |
| "max": max(target_values), | |
| "mean": sum(target_values) / len(target_values) | |
| } | |
| # Generate processing script | |
| process_script = '''"""MatDeepLearn Data Processing Script | |
| Generated for session: {session_id} | |
| """ | |
| import yaml | |
| from pathlib import Path | |
| from matdeeplearn import process | |
| # Configuration - Update this path to your local data copy | |
| DATA_PATH = Path(r"{data_path}") | |
| CONFIG_PATH = Path(r"config.yml") | |
| def main(): | |
| if not DATA_PATH.exists(): | |
| raise FileNotFoundError(f"Data path not found: {{DATA_PATH}}") | |
| data_path_str = str(DATA_PATH.resolve()) | |
| processing_args = {{ | |
| "dataset_type": "inmemory", | |
| "data_path": data_path_str, | |
| "target_path": "targets.csv", | |
| "dictionary_source": "{dictionary_source}", | |
| "dictionary_path": "atom_dict.json", | |
| "data_format": "{data_format}", | |
| "verbose": "True", | |
| "graph_max_radius": {graph_max_radius}, | |
| "graph_max_neighbors": {graph_max_neighbors}, | |
| "voronoi": "False", | |
| "edge_features": "True", | |
| "graph_edge_length": 50, | |
| "SM_descriptor": "False", | |
| "SOAP_descriptor": "False" | |
| }} | |
| print(f"Processing data from {{data_path_str}}...") | |
| dataset = process.get_dataset( | |
| data_path_str, | |
| {target_index}, # target_index | |
| "True", # reprocess | |
| processing_args | |
| ) | |
| print(f"Processing complete!") | |
| print(f"Dataset size: {{len(dataset)}}") | |
| if len(dataset) > 0: | |
| print(f"Atoms per structure: {{dataset[0].x.shape[0]}}") | |
| print(f"Node features: {{dataset[0].x.shape[1]}}") | |
| print(f"Edges per structure: {{dataset[0].edge_index.shape[1]}}") | |
| if __name__ == "__main__": | |
| main() | |
| '''.format( | |
| session_id=session_id, | |
| data_path="<path-to-your-local-data>", | |
| dictionary_source=dictionary_source, | |
| data_format=determined_format, | |
| graph_max_radius=graph_max_radius, | |
| graph_max_neighbors=graph_max_neighbors, | |
| target_index=target_index | |
| ) | |
| return { | |
| "success": True, | |
| "session_id": session_id, | |
| "data_analysis": { | |
| "num_structure_files": len(structure_files), | |
| "structure_files_sample": structure_files[:10], | |
| "detected_format": determined_format, | |
| "has_targets": has_targets, | |
| "target_statistics": target_stats, | |
| "server_data_path": data_path | |
| }, | |
| "local_processing_script": process_script, | |
| "processing_parameters": { | |
| "graph_max_radius": graph_max_radius, | |
| "graph_max_neighbors": graph_max_neighbors, | |
| "data_format": determined_format, | |
| "dictionary_source": dictionary_source, | |
| "target_index": target_index | |
| }, | |
| "next_steps": [ | |
| "1. Copy session data to your local machine", | |
| "2. Save the processing script as 'process_data.py'", | |
| "3. Update DATA_PATH to your local data folder", | |
| "4. Run: python process_data.py", | |
| "5. Then use train_session_model to get training script" | |
| ], | |
| "note": "Server-side processing is not available. Use the script locally with MatDeepLearn installed." | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def train_session_model( | |
| session_id: str, | |
| model_name: str = "CGCNN_demo", | |
| epochs: int = 100, | |
| batch_size: int = 32, | |
| learning_rate: float = 0.002, | |
| train_ratio: float = 0.8, | |
| val_ratio: float = 0.1, | |
| test_ratio: float = 0.1, | |
| model_save_name: Optional[str] = None, | |
| include_data_export_guide: bool = True | |
| ) -> dict: | |
| """ | |
| Generate a complete local training package for a session dataset. | |
| Includes training script, configuration, and data export instructions. | |
| Parameters: | |
| session_id (str): The session ID with processed data. | |
| model_name (str): Model to use - "CGCNN_demo", "SchNet_demo", "MPNN_demo", etc. | |
| epochs (int): Number of training epochs (default: 100). | |
| batch_size (int): Batch size (default: 32). | |
| learning_rate (float): Learning rate (default: 0.002). | |
| train_ratio (float): Training data ratio (default: 0.8). | |
| val_ratio (float): Validation data ratio (default: 0.1). | |
| test_ratio (float): Test data ratio (default: 0.1). | |
| model_save_name (str, optional): Custom name for saved model. | |
| include_data_export_guide (bool): Include guide for exporting session data (default: True). | |
| Returns: | |
| dict: Complete local training package with script, config, and instructions. | |
| """ | |
| try: | |
| session_path = _get_session_path(session_id) | |
| if not os.path.exists(session_path): | |
| return {"success": False, "error": f"Session not found: {session_id}"} | |
| data_path = os.path.join(session_path, "data") | |
| # Load config | |
| config_path = os.path.join(project_root, "config.yml") | |
| with open(config_path, "r") as f: | |
| config = yaml.load(f, Loader=yaml.FullLoader) | |
| if model_name not in config.get("Models", {}): | |
| available_models = list(config.get("Models", {}).keys()) | |
| return { | |
| "success": False, | |
| "error": f"Model '{model_name}' not found. Available: {available_models}" | |
| } | |
| # Generate model filename | |
| if model_save_name: | |
| model_filename = f"{model_save_name}.pth" | |
| else: | |
| model_filename = f"{model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pth" | |
| placeholder_data_path = "<path-to-your-local-dataset>" | |
| script = _build_local_training_script( | |
| data_path_literal=placeholder_data_path, | |
| model_name=model_name, | |
| target_index=0, | |
| epochs=epochs, | |
| batch_size=batch_size, | |
| learning_rate=learning_rate, | |
| train_ratio=train_ratio, | |
| val_ratio=val_ratio, | |
| test_ratio=test_ratio, | |
| save_model=True, | |
| model_output_literal=model_filename, | |
| reprocess=False, | |
| config_path_literal="config.yml" | |
| ) | |
| # Build config.yml content | |
| model_config = config["Models"][model_name] | |
| config_yml_content = yaml.dump({ | |
| "Processing": config.get("Processing", {}), | |
| "Training": config.get("Training", {}), | |
| "Models": {model_name: model_config} | |
| }, default_flow_style=False) | |
| # Data export guide | |
| data_export_guide = None | |
| if include_data_export_guide: | |
| data_export_guide = { | |
| "session_data_location": data_path, | |
| "required_files": [ | |
| "targets.csv - CSV file with structure_id,target_value per line", | |
| "Structure files (CIF, POSCAR, XYZ, or JSON format)" | |
| ], | |
| "export_options": [ | |
| "Option 1: Copy the session data folder to your local machine", | |
| "Option 2: Use your original dataset files directly", | |
| "Option 3: Re-upload structures to a local directory" | |
| ], | |
| "data_format_requirements": { | |
| "targets.csv": "structure_id,target_value (no header required)", | |
| "structures": "One file per structure, filename should match structure_id in targets.csv" | |
| } | |
| } | |
| return { | |
| "success": True, | |
| "message": "Local training package generated. Follow the instructions to train on your machine.", | |
| "session_id": session_id, | |
| "model_name": model_name, | |
| "model_description": _get_model_description(model_name), | |
| "recommended_model_filename": model_filename, | |
| "server_data_path": data_path, | |
| "parameters": { | |
| "epochs": epochs, | |
| "batch_size": batch_size, | |
| "learning_rate": learning_rate, | |
| "train_ratio": train_ratio, | |
| "val_ratio": val_ratio, | |
| "test_ratio": test_ratio | |
| }, | |
| "local_training_script": script, | |
| "config_yml_template": config_yml_content, | |
| "data_export_guide": data_export_guide, | |
| "setup_instructions": [ | |
| "1. Set up Python environment with MatDeepLearn dependencies", | |
| "2. Save the training script as 'train_local.py'", | |
| "3. Save config_yml_template as 'config.yml'", | |
| "4. Copy your dataset to a local folder or use existing data", | |
| "5. Update DATA_PATH in the script", | |
| "6. Run: python train_local.py" | |
| ], | |
| "environment_tip": "Use get_environment_requirements tool for detailed environment setup guide" | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def predict_with_session_model( | |
| session_id: str, | |
| model_path: str = "<path-to-your-model.pth>" | |
| ) -> dict: | |
| """ | |
| Generate a local prediction script for a session's data. | |
| Server-side prediction is not available; use the generated script locally. | |
| Parameters: | |
| session_id (str): The session ID. | |
| model_path (str): Path to the trained model file on your local machine. | |
| Returns: | |
| dict: Local prediction script and session data info. | |
| """ | |
| try: | |
| session_path = _get_session_path(session_id) | |
| if not os.path.exists(session_path): | |
| return {"success": False, "error": f"Session not found: {session_id}"} | |
| data_path = os.path.join(session_path, "data") | |
| # Get session data info | |
| structure_files = [] | |
| has_targets = os.path.exists(os.path.join(data_path, "targets.csv")) | |
| if os.path.exists(data_path): | |
| skip_files = {"targets.csv", "id.json", "atom_dict.json", "processed"} | |
| structure_files = [ | |
| f for f in os.listdir(data_path) | |
| if f not in skip_files and not f.startswith('.') and os.path.isfile(os.path.join(data_path, f)) | |
| ] | |
| predict_script = '''"""MatDeepLearn Prediction Script | |
| Generated for session: {session_id} | |
| """ | |
| import yaml | |
| import torch | |
| from pathlib import Path | |
| from matdeeplearn import training, process | |
| # Configuration - Update these paths | |
| DATA_PATH = Path(r"{data_placeholder}") | |
| MODEL_PATH = Path(r"{model_path}") | |
| CONFIG_PATH = Path(r"config.yml") | |
| def main(): | |
| if not DATA_PATH.exists(): | |
| raise FileNotFoundError(f"Data path not found: {{DATA_PATH}}") | |
| if not MODEL_PATH.exists(): | |
| raise FileNotFoundError(f"Model path not found: {{MODEL_PATH}}") | |
| if not CONFIG_PATH.exists(): | |
| raise FileNotFoundError(f"Config file not found: {{CONFIG_PATH}}") | |
| with CONFIG_PATH.open("r", encoding="utf-8") as f: | |
| config = yaml.safe_load(f) | |
| data_path_str = str(DATA_PATH.resolve()) | |
| # Process data | |
| processing_args = config.get("Processing", {{}}).copy() | |
| processing_args["data_path"] = data_path_str | |
| dataset = process.get_dataset( | |
| data_path_str, | |
| 0, # target_index | |
| "True", # reprocess | |
| processing_args | |
| ) | |
| job_config = {{ | |
| "job_name": "prediction_{session_id}", | |
| "model_path": str(MODEL_PATH.resolve()), | |
| "write_output": "True" | |
| }} | |
| print(f"Predicting on {{len(dataset)}} structures...") | |
| print(f"Using model: {{MODEL_PATH}}") | |
| # Run prediction | |
| test_error = training.predict(dataset, "l1_loss", job_config) | |
| print(f"\nPrediction complete!") | |
| print(f"MAE (if targets available): {{test_error:.4f}}") | |
| print(f"Results saved to: prediction_{session_id}_predicted_outputs.csv") | |
| if __name__ == "__main__": | |
| main() | |
| '''.format( | |
| session_id=session_id, | |
| data_placeholder="<path-to-your-local-data>", | |
| model_path=model_path | |
| ) | |
| return { | |
| "success": True, | |
| "message": "Server-side prediction is not available. Use the generated script locally.", | |
| "session_id": session_id, | |
| "session_data_info": { | |
| "server_data_path": data_path, | |
| "num_structure_files": len(structure_files), | |
| "has_targets": has_targets | |
| }, | |
| "local_prediction_script": predict_script, | |
| "usage_instructions": [ | |
| "1. Copy the session data to your local machine", | |
| "2. Save the script as 'predict_local.py'", | |
| "3. Ensure config.yml is in the same directory", | |
| "4. Update DATA_PATH and MODEL_PATH", | |
| "5. Run: python predict_local.py" | |
| ], | |
| "note": "You need a trained model (.pth file) to run predictions" | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def download_session_data(session_id: str) -> dict: | |
| """ | |
| Get the session data location and a file summary for copying to a local machine. | |
| Parameters: | |
| session_id (str): The session ID. | |
| Returns: | |
| dict: Session data location and file list. | |
| """ | |
| try: | |
| session_path = _get_session_path(session_id) | |
| if not os.path.exists(session_path): | |
| return {"success": False, "error": f"Session not found: {session_id}"} | |
| data_path = os.path.join(session_path, "data") | |
| # List all files | |
| files_info = [] | |
| total_size = 0 | |
| if os.path.exists(data_path): | |
| for root, dirs, files in os.walk(data_path): | |
| for f in files: | |
| file_path = os.path.join(root, f) | |
| rel_path = os.path.relpath(file_path, data_path) | |
| size = os.path.getsize(file_path) | |
| total_size += size | |
| files_info.append({ | |
| "path": rel_path.replace("\\", "/"), | |
| "size_bytes": size | |
| }) | |
| return { | |
| "success": True, | |
| "session_id": session_id, | |
| "data_location": data_path, | |
| "total_files": len(files_info), | |
| "total_size_mb": total_size / (1024 * 1024), | |
| "files": files_info[:50], # First 50 files | |
| "files_truncated": len(files_info) > 50, | |
| "copy_instructions": [ | |
| f"Server data location: {data_path}", | |
| "To use locally:", | |
| "1. Copy the entire data folder to your local machine", | |
| "2. Or use scp/rsync if accessing remote server", | |
| "3. Update DATA_PATH in your training/processing scripts" | |
| ] | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def run_cross_validation_session( | |
| session_id: str, | |
| model_name: str = "CGCNN_demo", | |
| cv_folds: int = 5, | |
| epochs: int = 100, | |
| batch_size: int = 32, | |
| learning_rate: float = 0.002 | |
| ) -> dict: | |
| """ | |
| Generate a local cross-validation script for session data. | |
| Parameters: | |
| session_id (str): The session ID. | |
| model_name (str): Model to use (default: "CGCNN_demo"). | |
| cv_folds (int): Number of folds (default: 5). | |
| epochs (int): Training epochs per fold (default: 100). | |
| batch_size (int): Batch size (default: 32). | |
| learning_rate (float): Learning rate (default: 0.002). | |
| Returns: | |
| dict: Local cross-validation script and instructions. | |
| """ | |
| try: | |
| session_path = _get_session_path(session_id) | |
| if not os.path.exists(session_path): | |
| return {"success": False, "error": f"Session not found: {session_id}"} | |
| data_path = os.path.join(session_path, "data") | |
| # Use the cross_validation tool to generate the script | |
| result = cross_validation( | |
| data_path=data_path, | |
| model_name=model_name, | |
| cv_folds=cv_folds, | |
| epochs=epochs, | |
| batch_size=batch_size, | |
| learning_rate=learning_rate | |
| ) | |
| if result.get("success"): | |
| result["session_id"] = session_id | |
| result["session_data_path"] = data_path | |
| result["note"] = "Update DATA_PATH in the script to your local dataset copy" | |
| return result | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def get_environment_requirements( | |
| target_os: str = "auto", | |
| cuda_version: str = "11.8", | |
| include_optional: bool = True, | |
| output_format: str = "detailed" | |
| ) -> dict: | |
| """ | |
| Get complete environment setup requirements for local MatDeepLearn training. | |
| Parameters: | |
| target_os (str): Target operating system - "windows", "linux", "macos", or "auto" (default: "auto"). | |
| cuda_version (str): Target CUDA version - "10.2", "11.7", "11.8", "12.1" (default: "11.8"). | |
| include_optional (bool): Include optional dependencies like SOAP, ray (default: True). | |
| output_format (str): Output format - "detailed", "minimal", "script" (default: "detailed"). | |
| Returns: | |
| dict: Complete environment setup guide with installation commands. | |
| """ | |
| try: | |
| import platform | |
| # Auto-detect OS | |
| if target_os == "auto": | |
| system = platform.system().lower() | |
| if system == "windows": | |
| target_os = "windows" | |
| elif system == "darwin": | |
| target_os = "macos" | |
| else: | |
| target_os = "linux" | |
| # PyTorch installation URLs based on CUDA version | |
| pytorch_urls = { | |
| "10.2": "https://download.pytorch.org/whl/cu102", | |
| "11.7": "https://download.pytorch.org/whl/cu117", | |
| "11.8": "https://download.pytorch.org/whl/cu118", | |
| "12.1": "https://download.pytorch.org/whl/cu121", | |
| "cpu": "" | |
| } | |
| pytorch_url = pytorch_urls.get(cuda_version, pytorch_urls["11.8"]) | |
| # CPU-only installs ("cpu") have no extra index URL; omit the flag and use the "+cpu" PyG wheel tag | |
| pytorch_index_flag = f" --index-url {pytorch_url}" if pytorch_url else "" | |
| pyg_wheel_tag = f"cu{cuda_version.replace('.', '')}" if pytorch_url else "cpu" | |
| # Core dependencies | |
| core_deps = [ | |
| "torch>=1.8.0", | |
| "torch-geometric>=2.0.0", | |
| "torch-scatter", | |
| "torch-sparse", | |
| "torch-cluster", | |
| "torch-spline-conv", | |
| "ase>=3.20.0", | |
| "pymatgen>=2020.9.0", | |
| "numpy>=1.20.0", | |
| "scipy>=1.6.0", | |
| "scikit-learn>=0.24.0", | |
| "pyyaml", | |
| "matplotlib", | |
| "pandas" | |
| ] | |
| optional_deps = [ | |
| "dscribe>=0.3.5 # For SOAP descriptor", | |
| "ray>=1.0.0 # For hyperparameter optimization", | |
| "joblib>=0.13.0 # For parallel processing" | |
| ] | |
| # Generate requirements.txt content | |
| requirements_content = "# MatDeepLearn Requirements\n" | |
| requirements_content += "# Core dependencies\n" | |
| for dep in core_deps: | |
| requirements_content += f"{dep}\n" | |
| if include_optional: | |
| requirements_content += "\n# Optional dependencies\n" | |
| for dep in optional_deps: | |
| requirements_content += f"{dep}\n" | |
| # Generate installation script based on OS | |
| if target_os == "windows": | |
| activation_cmd = "matdeeplearn_env\\Scripts\\activate" | |
| python_cmd = "python" | |
| shell_comment = "REM" | |
| script_ext = ".bat" | |
| else: | |
| activation_cmd = "source matdeeplearn_env/bin/activate" | |
| python_cmd = "python3" | |
| shell_comment = "#" | |
| script_ext = ".sh" | |
| install_script = f"""{shell_comment} MatDeepLearn Environment Setup Script | |
| {shell_comment} Target OS: {target_os}, CUDA: {cuda_version} | |
| {shell_comment} Step 1: Create virtual environment | |
| {python_cmd} -m venv matdeeplearn_env | |
| {shell_comment} Step 2: Activate environment | |
| {activation_cmd} | |
| {shell_comment} Step 3: Upgrade pip | |
| pip install --upgrade pip | |
| {shell_comment} Step 4: Install PyTorch with CUDA {cuda_version} | |
| pip install torch torchvision torchaudio{pytorch_index_flag} | |
| {shell_comment} Step 5: Install PyTorch Geometric | |
| pip install torch-geometric | |
| pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.0.0+{pyg_wheel_tag}.html | |
| {shell_comment} Step 6: Install core dependencies | |
| pip install ase pymatgen numpy scipy matplotlib scikit-learn pyyaml pandas | |
| {shell_comment} Step 7: Clone and install MatDeepLearn | |
| git clone https://github.com/Fung-Lab/MatDeepLearn.git | |
| cd MatDeepLearn | |
| pip install -e . | |
| """ | |
| if include_optional: | |
| install_script += f""" | |
| {shell_comment} Step 8 (Optional): Install additional dependencies | |
| pip install dscribe ray joblib | |
| """ | |
| # GPU check script | |
| gpu_check_script = """import torch | |
| print("="*50) | |
| print("MatDeepLearn Environment Check") | |
| print("="*50) | |
| print(f"Python version: {__import__('sys').version}") | |
| print(f"PyTorch version: {torch.__version__}") | |
| print(f"CUDA available: {torch.cuda.is_available()}") | |
| if torch.cuda.is_available(): | |
| print(f"CUDA version: {torch.version.cuda}") | |
| print(f"GPU count: {torch.cuda.device_count()}") | |
| for i in range(torch.cuda.device_count()): | |
| print(f" GPU {i}: {torch.cuda.get_device_name(i)}") | |
| else: | |
| print("WARNING: CUDA not available. Training will run on CPU (slower).") | |
| try: | |
| import torch_geometric | |
| print(f"PyTorch Geometric version: {torch_geometric.__version__}") | |
| except ImportError: | |
| print("ERROR: PyTorch Geometric not installed!") | |
| try: | |
| import ase | |
| print(f"ASE version: {ase.__version__}") | |
| except ImportError: | |
| print("ERROR: ASE not installed!") | |
| try: | |
| import pymatgen | |
| print(f"Pymatgen version: {pymatgen.__version__}") | |
| except ImportError: | |
| print("ERROR: Pymatgen not installed!") | |
| print("="*50) | |
| print("Environment check complete!") | |
| """ | |
| result = { | |
| "success": True, | |
| "target_os": target_os, | |
| "cuda_version": cuda_version, | |
| "python_requirement": "Python 3.7 - 3.10 recommended", | |
| "core_dependencies": core_deps, | |
| "requirements_txt": requirements_content, | |
| "installation_script": install_script, | |
| "script_filename": f"setup_matdeeplearn{script_ext}", | |
| "gpu_check_script": gpu_check_script, | |
| "gpu_check_filename": "check_environment.py", | |
| "quick_start": [ | |
| f"1. Save installation script as setup_matdeeplearn{script_ext}", | |
| f"2. Run the script to set up environment", | |
| "3. Save check_environment.py and run to verify installation", | |
| "4. Use train_model or train_session_model to generate training scripts" | |
| ] | |
| } | |
| if include_optional: | |
| result["optional_dependencies"] = optional_deps | |
| if output_format == "detailed": | |
| result["detailed_steps"] = { | |
| "step1_venv": { | |
| "description": "Create isolated Python environment", | |
| "command": f"{python_cmd} -m venv matdeeplearn_env", | |
| "note": "This creates a clean environment for MatDeepLearn" | |
| }, | |
| "step2_activate": { | |
| "description": "Activate the environment", | |
| "command": activation_cmd, | |
| "note": "Run this each time before using MatDeepLearn" | |
| }, | |
| "step3_pytorch": { | |
| "description": "Install PyTorch with GPU support", | |
| "command": f"pip install torch torchvision torchaudio --index-url {pytorch_url}", | |
| "note": f"This installs PyTorch for CUDA {cuda_version}. For CPU only, omit --index-url" | |
| }, | |
| "step4_pyg": { | |
| "description": "Install PyTorch Geometric", | |
| "commands": [ | |
| "pip install torch-geometric", | |
| f"pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.0.0+cu{cuda_version.replace('.', '')}.html" | |
| ], | |
| "note": "PyG requires matching versions with your PyTorch" | |
| }, | |
| "step5_deps": { | |
| "description": "Install material science libraries", | |
| "command": "pip install ase pymatgen numpy scipy matplotlib scikit-learn pyyaml", | |
| "note": "These handle crystal structures and scientific computing" | |
| }, | |
| "step6_matdeeplearn": { | |
| "description": "Install MatDeepLearn", | |
| "commands": [ | |
| "git clone https://github.com/Fung-Lab/MatDeepLearn.git", | |
| "cd MatDeepLearn", | |
| "pip install -e ." | |
| ], | |
| "note": "Editable install allows code modifications" | |
| } | |
| } | |
| result["troubleshooting"] = { | |
| "cuda_not_found": "Ensure NVIDIA drivers are installed and match CUDA version", | |
| "pyg_version_mismatch": "PyTorch Geometric wheels must match your PyTorch version exactly", | |
| "memory_error": "Reduce batch_size if you encounter GPU memory errors", | |
| "import_error": "Verify all dependencies installed in the same environment" | |
| } | |
| return result | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def get_hyperparameter_guide( | |
| model_name: str = "CGCNN_demo", | |
| dataset_size: Optional[int] = None, | |
| task_type: str = "regression" | |
| ) -> dict: | |
| """ | |
| Get hyperparameter recommendations and tuning guide. | |
| Parameters: | |
| model_name (str): Model name (default: "CGCNN_demo"). | |
| dataset_size (int, optional): Approximate size of your dataset. | |
| task_type (str): Task type - "regression" or "classification" (default: "regression"). | |
| Returns: | |
| dict: Hyperparameter recommendations and tuning guide. | |
| """ | |
| try: | |
| # Load config | |
| config_path = os.path.join(project_root, "config.yml") | |
| with open(config_path, "r") as f: | |
| config = yaml.load(f, Loader=yaml.FullLoader) | |
| available_models = list(config.get("Models", {}).keys()) | |
| if model_name not in config.get("Models", {}): | |
| return { | |
| "success": False, | |
| "error": f"Model '{model_name}' not found. Available: {available_models}" | |
| } | |
| model_config = config["Models"][model_name] | |
| # Base recommendations | |
| recommendations = { | |
| "learning_rate": { | |
| "default": model_config.get("lr", 0.002), | |
| "range": [0.0001, 0.01], | |
| "tuning_tip": "Start with default, reduce by 0.5x if training unstable" | |
| }, | |
| "batch_size": { | |
| "default": model_config.get("batch_size", 32), | |
| "range": [16, 256], | |
| "tuning_tip": "Larger batch = faster but may need higher lr; smaller batch = better generalization" | |
| }, | |
| "epochs": { | |
| "default": model_config.get("epochs", 250), | |
| "range": [50, 500], | |
| "tuning_tip": "Monitor validation loss; stop if no improvement for 20+ epochs" | |
| }, | |
| "hidden_dims": { | |
| "default": [model_config.get("dim1", 100), model_config.get("dim2", 100)], | |
| "range": [[32, 32], [256, 256]], | |
| "tuning_tip": "Increase for complex tasks, decrease for small datasets" | |
| }, | |
| "gc_layers": { | |
| "default": model_config.get("gc_count", 4), | |
| "range": [2, 8], | |
| "tuning_tip": "More layers capture longer-range interactions but risk oversmoothing" | |
| }, | |
| "dropout": { | |
| "default": model_config.get("dropout_rate", 0.0), | |
| "range": [0.0, 0.5], | |
| "tuning_tip": "Increase if overfitting (train << val error)" | |
| } | |
| } | |
| # Adjust based on dataset size | |
| if dataset_size: | |
| if dataset_size < 500: | |
| size_advice = "Small dataset - use regularization, smaller model" | |
| recommendations["batch_size"]["recommended"] = min(32, dataset_size // 10) | |
| recommendations["dropout"]["recommended"] = 0.2 | |
| recommendations["hidden_dims"]["recommended"] = [64, 64] | |
| recommendations["epochs"]["recommended"] = 200 | |
| elif dataset_size < 5000: | |
| size_advice = "Medium dataset - default settings should work" | |
| recommendations["batch_size"]["recommended"] = 64 | |
| recommendations["dropout"]["recommended"] = 0.1 | |
| recommendations["epochs"]["recommended"] = 250 | |
| else: | |
| size_advice = "Large dataset - can use larger model" | |
| recommendations["batch_size"]["recommended"] = 128 | |
| recommendations["hidden_dims"]["recommended"] = [128, 128] | |
| recommendations["epochs"]["recommended"] = 300 | |
| else: | |
| size_advice = "Provide dataset_size for tailored recommendations" | |
| # Generate hyperparameter search script | |
| hp_search_script = '''"""MatDeepLearn Hyperparameter Search Script | |
| Simple grid search for key hyperparameters. | |
| """ | |
| import yaml | |
| import numpy as np | |
| import torch | |
| from pathlib import Path | |
| from matdeeplearn import training | |
| from itertools import product | |
| DATA_PATH = Path(r"<path-to-your-dataset>") | |
| CONFIG_PATH = Path(r"config.yml") | |
| # Hyperparameter grid | |
| PARAM_GRID = { | |
| "lr": [0.001, 0.002, 0.005], | |
| "batch_size": [32, 64, 128], | |
| "gc_count": [3, 4, 5], | |
| } | |
| def run_experiment(params, data_path_str, config, rank, world_size): | |
| job_config = { | |
| "job_name": f"hp_search_lr{params['lr']}_bs{params['batch_size']}", | |
| "reprocess": "False", | |
| "model": "''' + model_name + '''", | |
| "save_model": "False", | |
| "write_output": "False", | |
| "parallel": "False", | |
| "seed": 42, | |
| } | |
| training_config = config.get("Training", {}).copy() | |
| model_config = config["Models"]["''' + model_name + '''"].copy() | |
| model_config.update({ | |
| "lr": params["lr"], | |
| "batch_size": params["batch_size"], | |
| "gc_count": params.get("gc_count", model_config.get("gc_count", 4)), | |
| "epochs": 100, # Shorter for search | |
| }) | |
| errors = training.train_regular(rank, world_size, data_path_str, job_config, training_config, model_config) | |
| return {"params": params, "train_mae": errors[0], "val_mae": errors[1], "test_mae": errors[2]} | |
| def main(): | |
| with CONFIG_PATH.open("r") as f: | |
| config = yaml.safe_load(f) | |
| data_path_str = str(DATA_PATH.resolve()) | |
| rank = "cuda" if torch.cuda.is_available() else "cpu" | |
| world_size = torch.cuda.device_count() | |
| results = [] | |
| param_combinations = [dict(zip(PARAM_GRID.keys(), v)) for v in product(*PARAM_GRID.values())] | |
| print(f"Running {len(param_combinations)} experiments...") | |
| for i, params in enumerate(param_combinations): | |
| print(f"\\nExperiment {i+1}/{len(param_combinations)}: {params}") | |
| result = run_experiment(params, data_path_str, config, rank, world_size) | |
| results.append(result) | |
| print(f" Val MAE: {result['val_mae']:.4f}") | |
| # Find best | |
| best = min(results, key=lambda x: x["val_mae"]) | |
| print("\\n" + "="*50) | |
| print(f"BEST PARAMS: {best['params']}") | |
| print(f"Best Val MAE: {best['val_mae']:.4f}") | |
| if __name__ == "__main__": | |
| main() | |
| ''' | |
| return { | |
| "success": True, | |
| "model_name": model_name, | |
| "model_description": _get_model_description(model_name), | |
| "default_config": model_config, | |
| "recommendations": recommendations, | |
| "dataset_size_advice": size_advice, | |
| "hyperparameter_search_script": hp_search_script, | |
| "tuning_strategy": [ | |
| "1. Start with default hyperparameters", | |
| "2. Train and observe train/val error curves", | |
| "3. If overfitting (train << val): increase dropout, reduce model size, add regularization", | |
| "4. If underfitting (both high): increase model size, more epochs, higher lr", | |
| "5. Use the search script for systematic optimization" | |
| ], | |
| "common_issues": { | |
| "loss_not_decreasing": "Try lower learning rate or different optimizer", | |
| "nan_loss": "Reduce learning rate significantly, check data normalization", | |
| "overfitting": "Add dropout, reduce hidden dims, use early stopping", | |
| "slow_convergence": "Increase learning rate, use scheduler with warmup" | |
| } | |
| } | |
| except Exception as e: | |
| return {"success": False, "error": str(e)} | |
| def create_app() -> FastMCP: | |
| """ | |
| Creates and returns the FastMCP application instance. | |
| Returns: | |
| FastMCP: The FastMCP application instance. | |
| """ | |
| return mcp | |