"""
MatDeepLearn MCP Service
A Model Context Protocol service for materials property prediction using Graph Neural Networks.
"""
import os
import sys
import json
import tempfile
import yaml
import numpy as np
import base64
import hashlib
import shutil
import uuid
import zipfile
import tarfile
import io
from datetime import datetime
from typing import Optional, List, Dict, Any
from pathlib import Path
# Add MatDeepLearn to path
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if project_root not in sys.path:
sys.path.insert(0, project_root)
from fastmcp import FastMCP
# Import MatDeepLearn modules
try:
import torch
from matdeeplearn import models, process, training
from matdeeplearn.models.utils import model_summary
MATDEEPLEARN_AVAILABLE = True
except ImportError as e:
MATDEEPLEARN_AVAILABLE = False
IMPORT_ERROR = str(e)
mcp = FastMCP("matdeeplearn_service")
# ============================================================================
# Global storage management - tracks uploaded datasets and trained models
# ============================================================================
# Server-side storage directories
STORAGE_BASE = os.path.join(project_root, "mcp_storage")
DATASETS_DIR = os.path.join(STORAGE_BASE, "datasets")
MODELS_DIR = os.path.join(STORAGE_BASE, "models")
SESSIONS_DIR = os.path.join(STORAGE_BASE, "sessions")
# Ensure the storage directories exist
for dir_path in [STORAGE_BASE, DATASETS_DIR, MODELS_DIR, SESSIONS_DIR]:
os.makedirs(dir_path, exist_ok=True)
# In-memory session registry (session_id -> session_info)
_sessions: Dict[str, Dict] = {}
def _get_session_path(session_id: str) -> str:
"""获取会话目录路径"""
return os.path.join(SESSIONS_DIR, session_id)
def _generate_session_id() -> str:
"""生成唯一会话ID"""
return f"session_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}"
def _generate_dataset_id(name: str) -> str:
"""生成数据集ID"""
return f"dataset_{name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
def _generate_model_id(model_name: str) -> str:
"""生成模型ID"""
return f"model_{model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
def _safe_join(base: str, *paths: str) -> str:
"""Join paths and ensure the result stays inside base directory."""
base_path = Path(base).resolve()
target_path = base_path.joinpath(*paths).resolve()
    # A plain startswith() check would also accept sibling paths such as
    # "/base-evil" for base "/base"; compare resolved path components instead.
    if base_path != target_path and base_path not in target_path.parents:
        raise ValueError("Attempted to write outside of the allowed directory")
return str(target_path)
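# Illustrative _safe_join behaviour (the paths shown are hypothetical):
#   _safe_join("/srv/storage", "a", "b.cif")  -> "/srv/storage/a/b.cif"
#   _safe_join("/srv/storage", "..", "etc")   -> raises ValueError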
def _normalize_filename(filename: str) -> str:
"""Return sanitized filename without directory components."""
clean_name = os.path.basename(filename)
if not clean_name:
raise ValueError("Filename must not be empty")
return clean_name
def _clear_directory_contents(directory: str) -> None:
"""Remove all files and folders inside the provided directory."""
if not os.path.isdir(directory):
return
for entry in os.listdir(directory):
entry_path = os.path.join(directory, entry)
if os.path.isdir(entry_path):
shutil.rmtree(entry_path)
else:
os.remove(entry_path)
def _copy_tree(src: str, dst: str, overwrite: bool = False) -> Dict[str, List[str]]:
"""Copy a directory tree with overwrite and traversal protection."""
results = {"created": [], "overwritten": [], "skipped": []}
src_path = Path(src)
for root, _, files in os.walk(src_path):
rel_root = os.path.relpath(root, src_path)
rel_root = "" if rel_root == "." else rel_root
for file_name in files:
            # Skip macOS metadata: "._*" resource forks and __MACOSX directories
            if file_name.startswith("._") or "__MACOSX" in Path(root).parts:
                continue
rel_path = os.path.normpath(os.path.join(rel_root, file_name))
dest_path = _safe_join(dst, rel_path)
os.makedirs(os.path.dirname(dest_path), exist_ok=True)
src_file = os.path.join(root, file_name)
if os.path.exists(dest_path):
if overwrite:
shutil.copy2(src_file, dest_path)
results["overwritten"].append(rel_path.replace("\\", "/"))
else:
results["skipped"].append(rel_path.replace("\\", "/"))
else:
shutil.copy2(src_file, dest_path)
results["created"].append(rel_path.replace("\\", "/"))
return results
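# Example _copy_tree report (file names are illustrative):
#   {"created": ["NaCl.cif"], "overwritten": [], "skipped": ["targets.csv"]}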
def _resolve_dataset_root(extracted_dir: str) -> str:
"""Select the most probable dataset root inside an extracted archive."""
entries = [p for p in Path(extracted_dir).iterdir() if not p.name.startswith("__MACOSX")]
if len(entries) == 1 and entries[0].is_dir():
return str(entries[0])
return extracted_dir
def _update_session_uploaded_files(session_path: str) -> None:
"""Rescan session data directory and persist uploaded file list."""
info_file = os.path.join(session_path, "session_info.json")
if not os.path.exists(info_file):
return
data_dir = os.path.join(session_path, "data")
uploaded = []
if os.path.exists(data_dir):
for root, _, files in os.walk(data_dir):
for name in files:
rel_path = os.path.relpath(os.path.join(root, name), data_dir)
uploaded.append(rel_path.replace("\\", "/"))
with open(info_file, 'r', encoding='utf-8') as f:
session_info = json.load(f)
session_info["uploaded_files"] = sorted(uploaded)
with open(info_file, 'w', encoding='utf-8') as f:
json.dump(session_info, f, indent=2)
def _record_session_data_source(session_path: str, source_info: Dict[str, Any]) -> None:
"""Append dataset source metadata to the session record."""
info_file = os.path.join(session_path, "session_info.json")
if not os.path.exists(info_file):
return
with open(info_file, 'r', encoding='utf-8') as f:
session_info = json.load(f)
data_sources = session_info.setdefault("data_sources", [])
data_sources.append(source_info)
session_info["data_sources"] = data_sources[-10:]
with open(info_file, 'w', encoding='utf-8') as f:
json.dump(session_info, f, indent=2)
def _summarize_dataset_directory(data_path: str) -> Dict[str, Any]:
"""Collect lightweight statistics about files inside a dataset directory."""
summary = {
"total_files": 0,
"targets_csv": False,
"structure_extensions": {}
}
if not os.path.exists(data_path):
return summary
for root, _, files in os.walk(data_path):
for name in files:
summary["total_files"] += 1
if name.lower() == "targets.csv":
summary["targets_csv"] = True
else:
ext = os.path.splitext(name)[1].lower() or "<no_ext>"
summary["structure_extensions"][ext] = summary["structure_extensions"].get(ext, 0) + 1
return summary
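# Example _summarize_dataset_directory return value (counts are illustrative):
#   {"total_files": 101, "targets_csv": True, "structure_extensions": {".cif": 100}}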
def _safe_extract_zip(archive: zipfile.ZipFile, destination: str) -> None:
"""Extract zip members while preventing path traversal."""
for member in archive.infolist():
name = member.filename
if not name:
continue
target_path = _safe_join(destination, name)
if member.is_dir() or name.endswith('/'):
os.makedirs(target_path, exist_ok=True)
continue
os.makedirs(os.path.dirname(target_path), exist_ok=True)
with archive.open(member, 'r') as src, open(target_path, 'wb') as dst:
shutil.copyfileobj(src, dst)
def _safe_extract_tar(archive: tarfile.TarFile, destination: str) -> None:
"""Extract tar members while preventing unsafe writes."""
for member in archive.getmembers():
name = member.name
if not name:
continue
if member.islnk() or member.issym():
continue
target_path = _safe_join(destination, name)
if member.isdir():
os.makedirs(target_path, exist_ok=True)
continue
if member.isfile():
extracted = archive.extractfile(member)
if extracted is None:
continue
os.makedirs(os.path.dirname(target_path), exist_ok=True)
with extracted as src, open(target_path, 'wb') as dst:
shutil.copyfileobj(src, dst)
def _build_local_training_script(
data_path_literal: str,
model_name: str,
target_index: int,
epochs: int,
batch_size: int,
learning_rate: float,
train_ratio: float,
val_ratio: float,
test_ratio: float,
save_model: bool,
model_output_literal: str,
reprocess: bool = False,
config_path_literal: str = "config.yml"
) -> str:
"""Generate a reusable Python script string for local MatDeepLearn training."""
job_name = f"local_train_{model_name.lower()}"
reprocess_str = "True" if reprocess else "False"
save_model_str = "True" if save_model else "False"
script = """import yaml
from pathlib import Path
import numpy as np
import torch
from matdeeplearn import training
DATA_PATH = Path(r"{data_path}")
CONFIG_PATH = Path(r"{config_path}")
MODEL_OUTPUT = Path(r"{model_output}")
def main() -> None:
if not DATA_PATH.exists():
raise FileNotFoundError("Data path not found: {{}}".format(DATA_PATH))
if not CONFIG_PATH.exists():
raise FileNotFoundError("Config file not found: {{}}".format(CONFIG_PATH))
with CONFIG_PATH.open("r", encoding="utf-8") as f:
config = yaml.safe_load(f)
job_config = {{
"job_name": "{job_name}",
"reprocess": "{reprocess}",
"model": "{model_name}",
"load_model": "False",
"save_model": "{save_model}",
"model_path": str(MODEL_OUTPUT.resolve()),
"write_output": "True",
"parallel": "False",
"seed": int(np.random.randint(1, 1_000_000)),
}}
training_config = config.get("Training", {{}}).copy()
training_config.update({{
"target_index": {target_index},
"train_ratio": {train_ratio},
"val_ratio": {val_ratio},
"test_ratio": {test_ratio},
"verbosity": 5,
}})
model_config = config["Models"]["{model_name}"].copy()
model_config.update({{
"epochs": {epochs},
"batch_size": {batch_size},
"lr": {learning_rate},
}})
data_path = str(DATA_PATH.resolve())
if MODEL_OUTPUT.parent != Path("."):
MODEL_OUTPUT.parent.mkdir(parents=True, exist_ok=True)
world_size = torch.cuda.device_count()
rank = "cuda" if torch.cuda.is_available() else "cpu"
errors = training.train_regular(
rank,
world_size,
data_path,
job_config,
training_config,
model_config,
)
print("Training complete. Train/Val/Test errors:", errors)
if __name__ == "__main__":
main()
""".format(
data_path=data_path_literal,
config_path=config_path_literal,
model_output=model_output_literal,
job_name=job_name,
reprocess=reprocess_str,
model_name=model_name,
save_model=save_model_str,
target_index=target_index,
train_ratio=train_ratio,
val_ratio=val_ratio,
test_ratio=test_ratio,
epochs=epochs,
batch_size=batch_size,
learning_rate=learning_rate,
)
return script
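# Internal usage sketch (all literal values below are placeholders):
#   script_text = _build_local_training_script(
#       data_path_literal="/data/my_dataset", model_name="CGCNN_demo",
#       target_index=0, epochs=100, batch_size=32, learning_rate=0.002,
#       train_ratio=0.8, val_ratio=0.1, test_ratio=0.1,
#       save_model=True, model_output_literal="trained_model.pth",
#   )
#   # script_text is returned to the caller for local execution.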
def _infer_structure_format(structure_files: List[str]) -> Optional[str]:
"""Guess the structure file extension for ASE parsing."""
normalized: Dict[str, int] = {}
for file_name in structure_files:
root, ext = os.path.splitext(file_name)
if ext:
ext = ext.lower().lstrip('.')
else:
upper_name = os.path.basename(file_name).upper()
if upper_name in {"POSCAR", "CONTCAR"}:
ext = "vasp"
else:
ext = ""
if not ext:
continue
normalized[ext] = normalized.get(ext, 0) + 1
if not normalized:
return None
if len(normalized) == 1:
return next(iter(normalized))
# Prefer common solid-state formats
priority = ["cif", "vasp", "poscar", "xyz", "json"]
sorted_items = sorted(normalized.items(), key=lambda item: (-item[1], priority.index(item[0]) if item[0] in priority else 99))
top_ext, top_count = sorted_items[0]
if len(structure_files) == top_count:
return top_ext
return None
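# Illustrative _infer_structure_format results (file lists are hypothetical):
#   ["a.cif", "b.cif"]     -> "cif"
#   ["POSCAR", "CONTCAR"]  -> "vasp"
#   ["a.cif", "POSCAR"]    -> None (mixed formats without a clear majority)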
# ============================================================================
# Session management tools
# ============================================================================
@mcp.tool(name="create_session", description="Create a new working session for uploading data and training models. Returns a session_id to use in subsequent operations.")
def create_session(session_name: Optional[str] = None) -> dict:
"""
Create a new working session. Use this before uploading data.
Parameters:
session_name (str, optional): A friendly name for this session.
Returns:
dict: Contains session_id and session info.
Example:
create_session(session_name="my_material_project")
"""
try:
session_id = _generate_session_id()
session_path = _get_session_path(session_id)
os.makedirs(session_path, exist_ok=True)
os.makedirs(os.path.join(session_path, "data"), exist_ok=True)
os.makedirs(os.path.join(session_path, "models"), exist_ok=True)
os.makedirs(os.path.join(session_path, "outputs"), exist_ok=True)
# Initialize an id index file under data to avoid missing-file errors
id_index_path = os.path.join(session_path, "data", "id.json")
if not os.path.exists(id_index_path):
try:
with open(id_index_path, 'w', encoding='utf-8') as _idf:
json.dump({}, _idf)
except Exception:
pass
session_info = {
"session_id": session_id,
"session_name": session_name or session_id,
"created_at": datetime.now().isoformat(),
"data_path": os.path.join(session_path, "data"),
"models_path": os.path.join(session_path, "models"),
"outputs_path": os.path.join(session_path, "outputs"),
"uploaded_files": [],
"trained_models": [],
"data_sources": [],
"status": "active"
}
_sessions[session_id] = session_info
# Save session info to disk
with open(os.path.join(session_path, "session_info.json"), 'w') as f:
json.dump(session_info, f, indent=2)
return {
"success": True,
"session_id": session_id,
"session_name": session_info["session_name"],
"message": "Session created successfully. Use this session_id for uploading data and training.",
"next_steps": [
"1. Upload structure files using upload_structure_files",
"2. Upload targets.csv using upload_targets",
"3. Process data using process_session_data",
"4. Train model using train_session_model"
]
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="get_session_info", description="Get information about an existing session.")
def get_session_info(session_id: str) -> dict:
"""
Get information about an existing session.
Parameters:
session_id (str): The session ID returned from create_session.
Returns:
dict: Session information including uploaded files and trained models.
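    Example (the session_id shown is a placeholder):
        get_session_info(session_id="session_20240101_120000_ab12cd34")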
"""
try:
session_path = _get_session_path(session_id)
info_file = os.path.join(session_path, "session_info.json")
if not os.path.exists(info_file):
return {"success": False, "error": f"Session not found: {session_id}"}
with open(info_file, 'r') as f:
session_info = json.load(f)
# Update with current file counts
data_path = session_info["data_path"]
if os.path.exists(data_path):
files = os.listdir(data_path)
session_info["current_files"] = files
session_info["file_count"] = len(files)
session_info["has_targets"] = "targets.csv" in files
session_info["dataset_summary"] = _summarize_dataset_directory(data_path)
return {"success": True, **session_info}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="list_sessions", description="List all available sessions.")
def list_sessions() -> dict:
"""
List all available sessions on the server.
Returns:
dict: List of sessions with their basic info.
"""
try:
sessions = []
if os.path.exists(SESSIONS_DIR):
for session_id in os.listdir(SESSIONS_DIR):
info_file = os.path.join(SESSIONS_DIR, session_id, "session_info.json")
if os.path.exists(info_file):
with open(info_file, 'r') as f:
info = json.load(f)
sessions.append({
"session_id": session_id,
"session_name": info.get("session_name", session_id),
"created_at": info.get("created_at"),
"status": info.get("status", "unknown")
})
return {
"success": True,
"sessions": sessions,
"total_sessions": len(sessions)
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="delete_session", description="Delete a session and all its data.")
def delete_session(session_id: str, confirm: bool = False) -> dict:
"""
Delete a session and all associated data.
Parameters:
session_id (str): The session ID to delete.
confirm (bool): Must be True to confirm deletion.
Returns:
dict: Deletion status.
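    Example (the session_id shown is a placeholder):
        delete_session(session_id="session_20240101_120000_ab12cd34", confirm=True)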
"""
try:
if not confirm:
return {
"success": False,
"error": "Please set confirm=True to delete the session. This action cannot be undone."
}
session_path = _get_session_path(session_id)
if not os.path.exists(session_path):
return {"success": False, "error": f"Session not found: {session_id}"}
shutil.rmtree(session_path)
if session_id in _sessions:
del _sessions[session_id]
return {
"success": True,
"message": f"Session {session_id} deleted successfully."
}
except Exception as e:
return {"success": False, "error": str(e)}
# ============================================================================
# Data upload tools
# ============================================================================
@mcp.tool(name="upload_structure_file", description="Upload a single structure file to a session. Supports CIF, XYZ, POSCAR, JSON formats.")
def upload_structure_file(
session_id: str,
filename: str,
file_content: str,
file_format: Optional[str] = None
) -> dict:
"""
Upload a single structure file to a session.
Parameters:
session_id (str): The session ID.
filename (str): Name for the file (e.g., "structure1.cif").
file_content (str): The complete file content as a string.
file_format (str, optional): File format hint (auto-detected from filename if not provided).
Returns:
dict: Upload status and file info.
Example:
upload_structure_file(
session_id="session_xxx",
filename="NaCl.cif",
file_content="data_NaCl\\n_cell_length_a 5.64..."
)
"""
try:
session_path = _get_session_path(session_id)
if not os.path.exists(session_path):
return {"success": False, "error": f"Session not found: {session_id}"}
data_path = os.path.join(session_path, "data")
os.makedirs(data_path, exist_ok=True)
filename = _normalize_filename(filename)
file_path = _safe_join(data_path, filename)
with open(file_path, 'w', encoding='utf-8') as f:
f.write(file_content)
# Validate structure if possible
validation = {"valid": True}
try:
import ase.io
with tempfile.NamedTemporaryFile(mode='w', suffix=os.path.splitext(filename)[1], delete=False) as tmp:
tmp.write(file_content)
tmp_path = tmp.name
try:
structure = ase.io.read(tmp_path)
validation = {
"valid": True,
"num_atoms": len(structure),
"formula": structure.get_chemical_formula()
}
finally:
os.unlink(tmp_path)
except Exception as e:
validation = {"valid": False, "warning": str(e)}
# Update id.json index
try:
id_index_path = os.path.join(data_path, "id.json")
if not os.path.exists(id_index_path):
with open(id_index_path, 'w', encoding='utf-8') as _idf:
json.dump({}, _idf)
with open(id_index_path, 'r', encoding='utf-8') as _idf:
id_index = json.load(_idf)
except Exception:
id_index = {}
file_id = uuid.uuid4().hex
id_index[file_id] = {
"filename": filename,
"uploaded_at": datetime.now().isoformat(),
"size": len(file_content)
}
try:
with open(id_index_path, 'w', encoding='utf-8') as _idf:
json.dump(id_index, _idf, indent=2)
except Exception:
pass
try:
_update_session_uploaded_files(session_path)
except Exception:
pass
return {
"success": True,
"filename": filename,
"file_size": len(file_content),
"saved_to": file_path,
"validation": validation,
"file_id": file_id
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="upload_structure_files_batch", description="Upload multiple structure files at once to a session.")
def upload_structure_files_batch(
session_id: str,
files: Dict[str, str]
) -> dict:
"""
Upload multiple structure files to a session in one call.
Parameters:
session_id (str): The session ID.
files (dict): Dictionary mapping filename to file content.
Example: {"struct1.cif": "content1", "struct2.cif": "content2"}
Returns:
dict: Upload status for all files.
Example:
upload_structure_files_batch(
session_id="session_xxx",
files={
"NaCl.cif": "data_NaCl...",
"ZnO.cif": "data_ZnO..."
}
)
"""
try:
session_path = _get_session_path(session_id)
if not os.path.exists(session_path):
return {"success": False, "error": f"Session not found: {session_id}"}
data_path = os.path.join(session_path, "data")
os.makedirs(data_path, exist_ok=True)
results = []
success_count = 0
for filename, content in files.items():
try:
clean_name = _normalize_filename(filename)
file_path = _safe_join(data_path, clean_name)
with open(file_path, 'w', encoding='utf-8') as f:
f.write(content)
# update id index
try:
id_index_path = os.path.join(data_path, "id.json")
if not os.path.exists(id_index_path):
with open(id_index_path, 'w', encoding='utf-8') as _idf:
json.dump({}, _idf)
with open(id_index_path, 'r', encoding='utf-8') as _idf:
id_index = json.load(_idf)
except Exception:
id_index = {}
file_id = uuid.uuid4().hex
id_index[file_id] = {"filename": clean_name, "uploaded_at": datetime.now().isoformat(), "size": len(content)}
try:
with open(id_index_path, 'w', encoding='utf-8') as _idf:
json.dump(id_index, _idf, indent=2)
except Exception:
pass
results.append({
"filename": clean_name,
"success": True,
"size": len(content),
"file_id": file_id
})
success_count += 1
except Exception as e:
results.append({
"filename": filename,
"success": False,
"error": str(e)
})
try:
_update_session_uploaded_files(session_path)
except Exception:
pass
return {
"success": True,
"total_files": len(files),
"successful_uploads": success_count,
"failed_uploads": len(files) - success_count,
"results": results
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="upload_targets", description="Upload targets.csv file containing target properties for training.")
def upload_targets(
session_id: str,
targets_content: str,
validate: bool = True
) -> dict:
"""
Upload targets.csv file to a session.
Parameters:
session_id (str): The session ID.
targets_content (str): Content of targets.csv file.
Format: structure_id,target_value (one per line).
validate (bool): Whether to validate the targets file.
Returns:
dict: Upload status and validation info.
Example:
upload_targets(
session_id="session_xxx",
targets_content="NaCl,1.5\\nZnO,2.3\\nTiO2,3.1"
)
"""
try:
session_path = _get_session_path(session_id)
if not os.path.exists(session_path):
return {"success": False, "error": f"Session not found: {session_id}"}
data_path = os.path.join(session_path, "data")
targets_path = os.path.join(data_path, "targets.csv")
with open(targets_path, 'w', encoding='utf-8') as f:
f.write(targets_content)
# Validate and analyze
validation = {"valid": True}
if validate:
import csv
from io import StringIO
reader = csv.reader(StringIO(targets_content))
rows = list(reader)
structure_ids = []
target_values = []
for row in rows:
if len(row) >= 2:
structure_ids.append(row[0])
try:
target_values.append(float(row[1]))
                    except ValueError:
pass
# Check for matching structure files
existing_files = os.listdir(data_path)
            structure_files = [f for f in existing_files if f not in ("targets.csv", "id.json")]
structure_names = [os.path.splitext(f)[0] for f in structure_files]
matched = [sid for sid in structure_ids if sid in structure_names]
unmatched = [sid for sid in structure_ids if sid not in structure_names]
validation = {
"valid": True,
"num_samples": len(rows),
"num_valid_targets": len(target_values),
"target_range": {
"min": min(target_values) if target_values else None,
"max": max(target_values) if target_values else None,
"mean": sum(target_values) / len(target_values) if target_values else None
},
"matched_structures": len(matched),
"unmatched_structures": unmatched[:10] if unmatched else [],
"existing_structure_files": len(structure_files)
}
try:
_update_session_uploaded_files(session_path)
except Exception:
pass
return {
"success": True,
"saved_to": targets_path,
"validation": validation
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="upload_binary_file", description="Upload a binary file (like .pth model file) encoded as base64.")
def upload_binary_file(
session_id: str,
filename: str,
base64_content: str,
destination: str = "models"
) -> dict:
"""
Upload a binary file (e.g., pre-trained model .pth file) encoded as base64.
Parameters:
session_id (str): The session ID.
filename (str): Name for the file.
base64_content (str): File content encoded as base64 string.
destination (str): Where to save - "models" or "data".
Returns:
dict: Upload status.
Example:
# In Python, encode your model file:
# import base64
# with open("model.pth", "rb") as f:
# encoded = base64.b64encode(f.read()).decode()
# Then pass encoded as base64_content
"""
try:
session_path = _get_session_path(session_id)
if not os.path.exists(session_path):
return {"success": False, "error": f"Session not found: {session_id}"}
if destination == "models":
dest_path = os.path.join(session_path, "models")
else:
dest_path = os.path.join(session_path, "data")
filename = _normalize_filename(filename)
file_path = _safe_join(dest_path, filename)
# Decode and write binary content
binary_content = base64.b64decode(base64_content)
with open(file_path, 'wb') as f:
f.write(binary_content)
try:
if destination != "models":
_update_session_uploaded_files(session_path)
except Exception:
pass
return {
"success": True,
"filename": filename,
"file_size_bytes": len(binary_content),
"saved_to": file_path
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(
name="upload_dataset_archive",
description="Upload a compressed dataset archive (zip/tar) into a session's data directory."
)
def upload_dataset_archive(
session_id: str,
filename: str,
base64_content: str,
overwrite_existing: bool = False,
clear_existing: bool = False
) -> dict:
"""Decode and extract a dataset archive directly into the session data folder."""
try:
session_path = _get_session_path(session_id)
if not os.path.exists(session_path):
return {"success": False, "error": f"Session not found: {session_id}"}
data_path = os.path.join(session_path, "data")
os.makedirs(data_path, exist_ok=True)
filename = _normalize_filename(filename)
archive_bytes = base64.b64decode(base64_content)
temp_dir = tempfile.mkdtemp(prefix="mcp_dataset_")
try:
archive_lower = filename.lower()
if archive_lower.endswith(".zip"):
with zipfile.ZipFile(io.BytesIO(archive_bytes)) as archive:
_safe_extract_zip(archive, temp_dir)
elif archive_lower.endswith((".tar", ".tar.gz", ".tgz", ".tar.bz2", ".tbz")):
with tarfile.open(fileobj=io.BytesIO(archive_bytes), mode="r:*") as archive:
_safe_extract_tar(archive, temp_dir)
else:
return {
"success": False,
"error": "Unsupported archive format. Use .zip, .tar, .tar.gz, .tgz, .tar.bz2"
}
dataset_root = _resolve_dataset_root(temp_dir)
if not os.listdir(dataset_root):
return {"success": False, "error": "Archive appears to be empty after extraction."}
if clear_existing:
_clear_directory_contents(data_path)
copy_report = _copy_tree(dataset_root, data_path, overwrite=overwrite_existing)
finally:
shutil.rmtree(temp_dir, ignore_errors=True)
try:
_update_session_uploaded_files(session_path)
except Exception:
pass
summary = _summarize_dataset_directory(data_path)
try:
_record_session_data_source(session_path, {
"type": "archive_upload",
"filename": filename,
"timestamp": datetime.now().isoformat(),
"stats": {
"created": len(copy_report["created"]),
"overwritten": len(copy_report["overwritten"]),
"skipped": len(copy_report["skipped"])
}
})
except Exception:
pass
return {
"success": True,
"session_id": session_id,
"archive_name": filename,
"created_files": copy_report["created"],
"overwritten_files": copy_report["overwritten"],
"skipped_files": copy_report["skipped"],
"dataset_summary": summary,
"next_steps": [
"Use process_session_data to generate graphs",
"Confirm targets.csv is present before training"
]
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(
name="register_local_dataset",
description="Copy an existing local dataset directory into a session's data workspace."
)
def register_local_dataset(
session_id: str,
dataset_path: str,
overwrite_existing: bool = False,
clear_existing: bool = False
) -> dict:
"""Copy a dataset from disk into the managed session directory."""
try:
session_path = _get_session_path(session_id)
if not os.path.exists(session_path):
return {"success": False, "error": f"Session not found: {session_id}"}
dataset_abs = os.path.abspath(dataset_path)
if not os.path.exists(dataset_abs):
return {"success": False, "error": f"Dataset path not found: {dataset_path}"}
if not os.path.isdir(dataset_abs):
return {"success": False, "error": "dataset_path must be a directory"}
data_path = os.path.join(session_path, "data")
os.makedirs(data_path, exist_ok=True)
if Path(dataset_abs).resolve() == Path(data_path).resolve():
summary = _summarize_dataset_directory(data_path)
return {
"success": True,
"session_id": session_id,
"source_path": dataset_abs,
"created_files": [],
"overwritten_files": [],
"skipped_files": [],
"dataset_summary": summary,
"message": "dataset_path already points to the session data directory; nothing to copy."
}
if clear_existing:
_clear_directory_contents(data_path)
copy_report = _copy_tree(dataset_abs, data_path, overwrite=overwrite_existing)
try:
_update_session_uploaded_files(session_path)
except Exception:
pass
summary = _summarize_dataset_directory(data_path)
try:
_record_session_data_source(session_path, {
"type": "local_import",
"source_path": dataset_abs,
"timestamp": datetime.now().isoformat(),
"stats": {
"created": len(copy_report["created"]),
"overwritten": len(copy_report["overwritten"]),
"skipped": len(copy_report["skipped"])
}
})
except Exception:
pass
return {
"success": True,
"session_id": session_id,
"source_path": dataset_abs,
"created_files": copy_report["created"],
"overwritten_files": copy_report["overwritten"],
"skipped_files": copy_report["skipped"],
"dataset_summary": summary,
"next_steps": [
"Verify targets.csv is present in session data",
"Run process_session_data to generate processed graphs"
]
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(
name="list_session_data_files",
description="List files currently available in a session's data directory."
)
def list_session_data_files(
session_id: str,
include_sizes: bool = False,
max_items: int = 200
) -> dict:
"""Enumerate dataset files stored for a session."""
try:
session_path = _get_session_path(session_id)
if not os.path.exists(session_path):
return {"success": False, "error": f"Session not found: {session_id}"}
data_path = os.path.join(session_path, "data")
if not os.path.exists(data_path):
return {"success": True, "files": [], "total_files": 0, "dataset_summary": {}}
files_info = []
for root, _, files in os.walk(data_path):
for name in files:
rel_path = os.path.relpath(os.path.join(root, name), data_path).replace("\\", "/")
file_entry: Dict[str, Any] = {"path": rel_path}
file_abs_path = os.path.join(root, name)
if include_sizes:
file_entry["size_bytes"] = os.path.getsize(file_abs_path)
files_info.append(file_entry)
files_info.sort(key=lambda item: item["path"].lower())
total_files = len(files_info)
truncated = files_info[:max(0, max_items)]
return {
"success": True,
"session_id": session_id,
"files": truncated,
"total_files": total_files,
"truncated": total_files > len(truncated),
"dataset_summary": _summarize_dataset_directory(data_path)
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="check_environment", description="Check if MatDeepLearn environment is properly configured and GPU is available.")
def check_environment() -> dict:
"""
Check if the MatDeepLearn environment is properly configured.
Returns:
dict: Contains environment status including GPU availability.
"""
try:
if not MATDEEPLEARN_AVAILABLE:
return {
"success": False,
"error": f"MatDeepLearn not available: {IMPORT_ERROR}"
}
gpu_available = torch.cuda.is_available()
gpu_count = torch.cuda.device_count() if gpu_available else 0
gpu_name = torch.cuda.get_device_name(0) if gpu_available else "N/A"
return {
"success": True,
"matdeeplearn_available": True,
"torch_version": torch.__version__,
"gpu_available": gpu_available,
"gpu_count": gpu_count,
"gpu_name": gpu_name,
"available_models": [
"CGCNN_demo", "MPNN_demo", "SchNet_demo",
"MEGNet_demo", "GCN_demo", "SOAP_demo", "SM_demo"
]
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="list_available_models", description="List all available GNN models in MatDeepLearn.")
def list_available_models() -> dict:
"""
List all available Graph Neural Network models.
Returns:
dict: Contains list of available models with descriptions.
"""
try:
models_info = {
"CGCNN_demo": {
"name": "Crystal Graph Convolutional Neural Network",
"description": "A GNN for predicting material properties using crystal graphs.",
"paper": "Xie & Grossman, PRL 2018"
},
"MPNN_demo": {
"name": "Message Passing Neural Network",
"description": "General message passing framework for molecular graphs.",
"paper": "Gilmer et al., ICML 2017"
},
"SchNet_demo": {
"name": "SchNet",
"description": "Continuous-filter convolutional neural network for modeling quantum interactions.",
"paper": "Schütt et al., JCP 2017"
},
"MEGNet_demo": {
"name": "MatErials Graph Network",
"description": "Graph network with global state for materials property prediction.",
"paper": "Chen et al., Chem. Mater. 2019"
},
"GCN_demo": {
"name": "Graph Convolutional Network",
"description": "Standard graph convolutional network architecture.",
"paper": "Kipf & Welling, ICLR 2017"
},
"SOAP_demo": {
"name": "Smooth Overlap of Atomic Positions",
"description": "Descriptor-based method using SOAP features.",
"paper": "Bartók et al., PRB 2013"
},
"SM_demo": {
"name": "Sine Matrix",
"description": "Descriptor-based method using Sine/Coulomb matrix features.",
"paper": "Various"
}
}
return {
"success": True,
"models": models_info,
"total_models": len(models_info)
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="get_model_config", description="Get the default configuration for a specific model.")
def get_model_config(model_name: str) -> dict:
"""
Get the default configuration for a specific GNN model.
Parameters:
model_name (str): Name of the model (e.g., 'CGCNN_demo', 'SchNet_demo').
Returns:
dict: Contains the default configuration for the model.
"""
try:
config_path = os.path.join(project_root, "config.yml")
if not os.path.exists(config_path):
return {"success": False, "error": "Config file not found"}
with open(config_path, "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
if model_name not in config.get("Models", {}):
return {
"success": False,
"error": f"Model '{model_name}' not found. Available models: {list(config.get('Models', {}).keys())}"
}
model_config = config["Models"][model_name]
processing_config = config.get("Processing", {})
training_config = config.get("Training", {})
return {
"success": True,
"model_name": model_name,
"model_config": model_config,
"processing_config": processing_config,
"training_config": training_config
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="process_structure_data", description="Get guidance and script for processing atomic structure data into graph format for GNN training locally.")
def process_structure_data(
data_path: str = "<your-data-path>",
target_index: int = 0,
graph_max_radius: float = 8.0,
graph_max_neighbors: int = 12,
data_format: str = "cif",
dictionary_source: str = "default"
) -> dict:
"""
Generate local data processing script and configuration for MatDeepLearn.
Server-side processing is not available; use this script locally.
Parameters:
data_path (str): Path to directory containing structure files (for script template).
target_index (int): Index of target column in targets.csv (default: 0).
graph_max_radius (float): Maximum radius for edges in graph (default: 8.0).
graph_max_neighbors (int): Maximum number of neighbors per atom (default: 12).
data_format (str): Structure file format ('cif', 'vasp', 'xyz', 'json').
dictionary_source (str): Atom dictionary source ('default', 'blank', 'generated', 'provided').
Returns:
dict: Contains local processing script and data preparation guide.
"""
try:
process_script = '''"""MatDeepLearn Data Processing Script
Generated for local execution.
"""
import yaml
from pathlib import Path
from matdeeplearn import process
# Configuration - Update this path
DATA_PATH = Path(r"{data_path}")
CONFIG_PATH = Path(r"config.yml")
def main():
if not DATA_PATH.exists():
raise FileNotFoundError(f"Data path not found: {{DATA_PATH}}")
data_path_str = str(DATA_PATH.resolve())
processing_args = {{
"dataset_type": "inmemory",
"data_path": data_path_str,
"target_path": "targets.csv",
"dictionary_source": "{dictionary_source}",
"dictionary_path": "atom_dict.json",
"data_format": "{data_format}",
"verbose": "True",
"graph_max_radius": {graph_max_radius},
"graph_max_neighbors": {graph_max_neighbors},
"voronoi": "False",
"edge_features": "True",
"graph_edge_length": 50,
"SM_descriptor": "False",
"SOAP_descriptor": "False"
}}
print(f"Processing data from {{data_path_str}}...")
dataset = process.get_dataset(
data_path_str,
{target_index}, # target_index
"True", # reprocess
processing_args
)
print(f"Processing complete!")
print(f"Dataset size: {{len(dataset)}}")
if len(dataset) > 0:
print(f"Atoms per structure: {{dataset[0].x.shape[0]}}")
print(f"Node features: {{dataset[0].x.shape[1]}}")
print(f"Edges per structure: {{dataset[0].edge_index.shape[1]}}")
print(f"Processed data saved to: {{data_path_str}}/processed/")
if __name__ == "__main__":
main()
'''.format(
data_path=data_path,
data_format=data_format,
dictionary_source=dictionary_source,
graph_max_radius=graph_max_radius,
graph_max_neighbors=graph_max_neighbors,
target_index=target_index
)
return {
"success": True,
"message": "Server-side data processing is not available. Use the provided script locally.",
"local_processing_script": process_script,
"processing_parameters": {
"graph_max_radius": graph_max_radius,
"graph_max_neighbors": graph_max_neighbors,
"data_format": data_format,
"dictionary_source": dictionary_source,
"target_index": target_index
},
"data_preparation_guide": {
"folder_structure": [
"your_dataset/",
" ├── structure1.cif",
" ├── structure2.cif",
" ├── ...",
" └── targets.csv"
],
"targets_csv_format": "structure_id,target_value (no header, one per line)",
"supported_formats": ["cif", "vasp/poscar", "xyz", "json", "extxyz"],
"naming_convention": "Structure filename (without extension) must match structure_id in targets.csv"
},
"usage_instructions": [
"1. Prepare your data folder with structure files and targets.csv",
"2. Save the script as 'process_data.py'",
"3. Update DATA_PATH to your data folder",
"4. Run: python process_data.py",
"5. Processed data will be saved to data_folder/processed/"
]
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="train_model", description="Generate local training code, configuration and environment setup guide for a GNN model. This tool helps you set up and run training on your local machine.")
def train_model(
data_path: str,
model_name: str = "CGCNN_demo",
epochs: int = 100,
batch_size: int = 32,
learning_rate: float = 0.002,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
test_ratio: float = 0.1,
save_model: bool = True,
model_path: str = "trained_model.pth",
target_index: int = 0,
include_environment_guide: bool = True
) -> dict:
"""
Generate a complete local training script with environment setup guide for MatDeepLearn.
This tool provides everything needed to train GNN models on your local machine.
Parameters:
data_path (str): Path to directory containing processed structure data.
model_name (str): Name of the model to train. Options: 'CGCNN_demo', 'MPNN_demo', 'SchNet_demo', 'MEGNet_demo', 'GCN_demo', 'SOAP_demo', 'SM_demo'.
epochs (int): Number of training epochs (default: 100).
batch_size (int): Training batch size (default: 32).
learning_rate (float): Learning rate (default: 0.002).
train_ratio (float): Ratio of data for training (default: 0.8).
val_ratio (float): Ratio of data for validation (default: 0.1).
test_ratio (float): Ratio of data for testing (default: 0.1).
save_model (bool): Whether to save the trained model (default: True).
model_path (str): Path to save the trained model (default: 'trained_model.pth').
target_index (int): Index of target column in targets.csv (default: 0).
include_environment_guide (bool): Whether to include environment setup guide (default: True).
Returns:
dict: Complete local training package including:
- Ready-to-run Python training script
- Model configuration and hyperparameters
- Environment setup instructions
- Dependencies list
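    Example (the data path is a placeholder):
        train_model(
            data_path="/path/to/my_dataset",
            model_name="SchNet_demo",
            epochs=200,
            batch_size=64
        )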
"""
try:
# Load default config
config_path = os.path.join(project_root, "config.yml")
if not os.path.exists(config_path):
return {"success": False, "error": "Config file not found on server"}
with open(config_path, "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
available_models = list(config.get("Models", {}).keys())
if model_name not in config.get("Models", {}):
return {
"success": False,
"error": f"Model '{model_name}' not found. Available models: {available_models}"
}
# Get model-specific config
model_config = config["Models"][model_name]
script = _build_local_training_script(
data_path_literal=data_path,
model_name=model_name,
target_index=target_index,
epochs=epochs,
batch_size=batch_size,
learning_rate=learning_rate,
train_ratio=train_ratio,
val_ratio=val_ratio,
test_ratio=test_ratio,
save_model=save_model,
model_output_literal=model_path,
reprocess=False,
config_path_literal="config.yml"
)
# Build environment setup guide
environment_guide = None
if include_environment_guide:
environment_guide = {
"python_version": "Python 3.7 - 3.9 recommended",
"cuda_requirement": "CUDA 10.2 or 11.x for GPU training",
"setup_steps": [
"1. Create a virtual environment: python -m venv matdeeplearn_env",
"2. Activate the environment:",
" - Windows: matdeeplearn_env\\Scripts\\activate",
" - Linux/Mac: source matdeeplearn_env/bin/activate",
"3. Install PyTorch with CUDA support:",
" pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118",
"4. Install PyTorch Geometric:",
" pip install torch-geometric",
" pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.0.0+cu118.html",
"5. Install other dependencies:",
" pip install ase pymatgen numpy scipy matplotlib scikit-learn pyyaml",
"6. Clone MatDeepLearn repository:",
" git clone https://github.com/Fung-Lab/MatDeepLearn.git",
"7. Install MatDeepLearn:",
" cd MatDeepLearn && pip install -e ."
],
"core_dependencies": [
"torch>=1.8.0",
"torch-geometric>=2.0.0",
"ase>=3.20.0",
"pymatgen>=2020.9.0",
"numpy>=1.20.0",
"scipy>=1.6.0",
"scikit-learn>=0.24.0",
"pyyaml",
"matplotlib"
],
"optional_dependencies": [
"dscribe>=0.3.5 (for SOAP descriptor)",
"ray>=1.0.0 (for hyperparameter optimization)"
],
"gpu_check_code": "import torch; print(f'CUDA available: {torch.cuda.is_available()}'); print(f'GPU count: {torch.cuda.device_count()}')"
}
# Build config.yml content for user
config_yml_content = yaml.dump({
"Processing": config.get("Processing", {}),
"Training": config.get("Training", {}),
"Models": {model_name: model_config}
}, default_flow_style=False)
return {
"success": True,
"message": "Local training package generated successfully. Save the script and run on your machine.",
"model_name": model_name,
"model_description": _get_model_description(model_name),
"parameters": {
"epochs": epochs,
"batch_size": batch_size,
"learning_rate": learning_rate,
"train_ratio": train_ratio,
"val_ratio": val_ratio,
"test_ratio": test_ratio,
"target_index": target_index,
"save_model": save_model,
"model_output": model_path
},
"model_architecture": {
"hidden_dims": [model_config.get("dim1"), model_config.get("dim2"), model_config.get("dim3")],
"gc_layers": model_config.get("gc_count"),
"pooling": model_config.get("pool"),
"optimizer": model_config.get("optimizer"),
"scheduler": model_config.get("scheduler")
},
"local_training_script": script,
"config_yml_template": config_yml_content,
"environment_guide": environment_guide,
"usage_instructions": {
"step1": "Save the training script as 'train_local.py'",
"step2": "Save the config_yml_template as 'config.yml' in the same directory",
"step3": "Update DATA_PATH in the script to point to your dataset directory",
"step4": "Ensure your dataset has structure files and targets.csv",
"step5": "Run: python train_local.py"
},
"available_models": available_models
}
except Exception as e:
return {"success": False, "error": str(e)}
def _get_model_description(model_name: str) -> dict:
"""Get description for a model."""
descriptions = {
"CGCNN_demo": {
"full_name": "Crystal Graph Convolutional Neural Network",
"description": "A GNN for predicting material properties using crystal graphs with edge-conditioned convolutions.",
"best_for": "Crystalline materials with periodic structures",
"paper": "Xie & Grossman, PRL 2018"
},
"MPNN_demo": {
"full_name": "Message Passing Neural Network",
"description": "General message passing framework for molecular and crystal graphs.",
"best_for": "General molecular and material property prediction",
"paper": "Gilmer et al., ICML 2017"
},
"SchNet_demo": {
"full_name": "SchNet",
"description": "Continuous-filter convolutional neural network with smooth radial basis functions.",
"best_for": "Quantum chemistry and energy predictions",
"paper": "Schütt et al., JCP 2017"
},
"MEGNet_demo": {
"full_name": "MatErials Graph Network",
"description": "Graph network with global state for materials property prediction.",
"best_for": "Materials with complex global properties",
"paper": "Chen et al., Chem. Mater. 2019"
},
"GCN_demo": {
"full_name": "Graph Convolutional Network",
"description": "Standard graph convolutional network architecture.",
"best_for": "Simple and fast baseline models",
"paper": "Kipf & Welling, ICLR 2017"
},
"SOAP_demo": {
"full_name": "Smooth Overlap of Atomic Positions",
"description": "Descriptor-based method using SOAP features.",
"best_for": "Local atomic environment similarity",
"paper": "Bartók et al., PRB 2013"
},
"SM_demo": {
"full_name": "Sine Matrix",
"description": "Descriptor-based method using Sine/Coulomb matrix features.",
"best_for": "Small molecules and simple structures",
"paper": "Various"
}
}
return descriptions.get(model_name, {"full_name": model_name, "description": "Custom model"})
@mcp.tool(name="predict_properties", description="Generate local prediction script for using a trained model to predict properties of new structures.")
def predict_properties(
data_path: str,
model_path: str,
target_index: int = 0
) -> dict:
"""
Generate a local prediction script for predicting properties with a trained model.
Server-side prediction is not available; use this script locally.
Parameters:
data_path (str): Path to directory containing structure files to predict.
model_path (str): Path to the trained model file (.pth).
target_index (int): Index of target column (default: 0).
Returns:
dict: Contains local prediction script and instructions.
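    Example (both paths are placeholders):
        predict_properties(
            data_path="/path/to/new_structures",
            model_path="/path/to/trained_model.pth"
        )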
"""
try:
predict_script = '''"""MatDeepLearn Prediction Script
Generated for local execution.
"""
import yaml
import torch
from pathlib import Path
from matdeeplearn import training, process
# Configuration - Update these paths
DATA_PATH = Path(r"{data_path}")
MODEL_PATH = Path(r"{model_path}")
CONFIG_PATH = Path(r"config.yml")
TARGET_INDEX = {target_index}
def main():
if not DATA_PATH.exists():
raise FileNotFoundError(f"Data path not found: {{DATA_PATH}}")
if not MODEL_PATH.exists():
raise FileNotFoundError(f"Model path not found: {{MODEL_PATH}}")
if not CONFIG_PATH.exists():
raise FileNotFoundError(f"Config file not found: {{CONFIG_PATH}}")
with CONFIG_PATH.open("r", encoding="utf-8") as f:
config = yaml.safe_load(f)
data_path_str = str(DATA_PATH.resolve())
# Process data
processing_args = config.get("Processing", {{}})
processing_args["data_path"] = data_path_str
dataset = process.get_dataset(
data_path_str,
TARGET_INDEX,
"False",
processing_args
)
job_config = {{
"job_name": "prediction_job",
"model_path": str(MODEL_PATH.resolve()),
"write_output": "True"
}}
print(f"Predicting on {{len(dataset)}} structures...")
# Run prediction
test_error = training.predict(dataset, "l1_loss", job_config)
print(f"Prediction complete!")
print(f"MAE (if targets available): {{test_error:.4f}}")
print(f"Results saved to: prediction_job_predicted_outputs.csv")
if __name__ == "__main__":
main()
'''.format(
data_path=data_path,
model_path=model_path,
target_index=target_index
)
return {
"success": True,
"message": "Server-side prediction is not available. Use the provided script locally.",
"local_prediction_script": predict_script,
"usage_instructions": [
"1. Save the script as 'predict_local.py'",
"2. Ensure config.yml is in the same directory",
"3. Update DATA_PATH and MODEL_PATH to your local paths",
"4. Prepare your data with structure files and targets.csv",
"5. Run: python predict_local.py"
],
"data_requirements": {
"structure_files": "CIF, POSCAR, XYZ, or JSON format files",
"targets.csv": "structure_id,target_value (can be dummy values like 0.0 for pure prediction)"
}
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="cross_validation", description="Generate local cross-validation script and guide for k-fold CV on a dataset.")
def cross_validation(
data_path: str,
model_name: str = "CGCNN_demo",
cv_folds: int = 5,
epochs: int = 100,
batch_size: int = 32,
learning_rate: float = 0.002
) -> dict:
"""
Generate a local cross-validation script for k-fold CV.
Parameters:
data_path (str): Path to directory containing structure data.
model_name (str): Name of the model to use (default: 'CGCNN_demo').
cv_folds (int): Number of cross-validation folds (default: 5).
epochs (int): Number of training epochs per fold (default: 100).
batch_size (int): Batch size (default: 32).
learning_rate (float): Learning rate (default: 0.002).
Returns:
dict: Contains local cross-validation script and instructions.
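    Example (the data path is a placeholder):
        cross_validation(
            data_path="/path/to/my_dataset",
            model_name="CGCNN_demo",
            cv_folds=5
        )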
"""
try:
# Load config to validate model
config_path = os.path.join(project_root, "config.yml")
with open(config_path, "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
if model_name not in config.get("Models", {}):
available = list(config.get("Models", {}).keys())
return {"success": False, "error": f"Model '{model_name}' not found. Available: {available}"}
# Generate CV script
cv_script = '''"""MatDeepLearn Cross-Validation Script
Generated for local execution.
"""
import yaml
import numpy as np
import torch
from pathlib import Path
from matdeeplearn import training, process
# Configuration - Update these paths
DATA_PATH = Path(r"{data_path}")
CONFIG_PATH = Path(r"config.yml")
# CV Parameters
CV_FOLDS = {cv_folds}
MODEL_NAME = "{model_name}"
EPOCHS = {epochs}
BATCH_SIZE = {batch_size}
LEARNING_RATE = {learning_rate}
def main():
if not DATA_PATH.exists():
raise FileNotFoundError(f"Data path not found: {{DATA_PATH}}")
if not CONFIG_PATH.exists():
raise FileNotFoundError(f"Config file not found: {{CONFIG_PATH}}")
with CONFIG_PATH.open("r", encoding="utf-8") as f:
config = yaml.safe_load(f)
job_config = {{
"job_name": "cv_job_{model_name}",
"reprocess": "False",
"model": MODEL_NAME,
"write_output": "True",
"parallel": "False",
"seed": int(np.random.randint(1, 1_000_000)),
"cv_folds": CV_FOLDS
}}
training_config = config.get("Training", {{}}).copy()
training_config.update({{
"target_index": 0,
"verbosity": 5,
}})
model_config = config["Models"][MODEL_NAME].copy()
model_config.update({{
"epochs": EPOCHS,
"batch_size": BATCH_SIZE,
"lr": LEARNING_RATE,
}})
data_path_str = str(DATA_PATH.resolve())
world_size = torch.cuda.device_count()
rank = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running {{CV_FOLDS}}-fold cross validation with {{MODEL_NAME}}")
print(f"Device: {{rank}}, GPU count: {{world_size}}")
# Run CV
cv_results = training.train_cv(
rank,
world_size,
data_path_str,
job_config,
training_config,
model_config,
)
print("\n" + "="*50)
print("Cross-Validation Results:")
print(f"Mean MAE: {{np.mean(cv_results):.4f}} +/- {{np.std(cv_results):.4f}}")
print(f"Individual fold errors: {{cv_results}}")
print("="*50)
if __name__ == "__main__":
main()
'''.format(
data_path=data_path,
cv_folds=cv_folds,
model_name=model_name,
epochs=epochs,
batch_size=batch_size,
learning_rate=learning_rate
)
return {
"success": True,
"message": "Cross-validation script generated. Run locally on your machine.",
"model_name": model_name,
"cv_folds": cv_folds,
"parameters": {
"epochs": epochs,
"batch_size": batch_size,
"learning_rate": learning_rate
},
"local_cv_script": cv_script,
"usage_instructions": [
"1. Save the script as 'run_cv.py'",
"2. Ensure config.yml is in the same directory",
"3. Update DATA_PATH to your dataset location",
"4. Run: python run_cv.py"
],
"expected_output": f"Will output mean MAE and std across {cv_folds} folds"
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="analyze_structure", description="Analyze atomic structure data. You can pass file content directly (for CIF, XYZ, POSCAR formats) or a file path on the server.")
def analyze_structure(
file_content: Optional[str] = None,
file_format: Optional[str] = None,
structure_file: Optional[str] = None
) -> dict:
"""
    Analyze an atomic structure and report composition, periodicity, and neighbor statistics.
Parameters:
file_content (str, optional): The content of the structure file (CIF, XYZ, POSCAR, JSON format).
Pass the actual file content directly here.
file_format (str, optional): Format of the file content ('cif', 'xyz', 'vasp', 'json').
Required when file_content is provided.
structure_file (str, optional): Path to a structure file on the server (legacy option).
Returns:
dict: Contains structure analysis including atoms, bonds, and graph info.
Example usage:
analyze_structure(file_content="your CIF file content here...", file_format="cif")
"""
try:
import ase
from ase import io
from io import StringIO
structure = None
# Method 1: Direct file content (preferred for remote access)
if file_content is not None:
if file_format is None:
return {"success": False, "error": "file_format is required when providing file_content. Use 'cif', 'xyz', 'vasp', or 'json'."}
# Map common format names
format_map = {
'cif': 'cif',
'xyz': 'xyz',
'vasp': 'vasp',
'poscar': 'vasp',
'json': 'json',
'extxyz': 'extxyz'
}
fmt = format_map.get(file_format.lower())
if fmt is None:
return {"success": False, "error": f"Unsupported format: {file_format}. Supported: cif, xyz, vasp, poscar, json, extxyz"}
# Create a temporary file to read the structure
with tempfile.NamedTemporaryFile(mode='w', suffix=f'.{fmt}', delete=False) as tmp:
tmp.write(file_content)
tmp_path = tmp.name
try:
structure = ase.io.read(tmp_path, format=fmt)
finally:
os.unlink(tmp_path) # Clean up temp file
# Method 2: File path on server (legacy)
elif structure_file is not None:
if not os.path.exists(structure_file):
return {"success": False, "error": f"Structure file not found: {structure_file}. Tip: For remote MCP, pass file_content directly instead of file path."}
structure = ase.io.read(structure_file)
else:
return {"success": False, "error": "Either file_content (with file_format) or structure_file must be provided."}
# Get basic info
symbols = structure.get_chemical_symbols()
positions = structure.get_positions().tolist()
cell = structure.get_cell().tolist() if any(structure.pbc) else None
pbc = structure.pbc.tolist()
# Get distance matrix
distance_matrix = structure.get_all_distances(mic=True)
# Analyze connectivity
cutoff_radius = 8.0
neighbors_count = []
for i in range(len(structure)):
neighbors = np.sum((distance_matrix[i] > 0) & (distance_matrix[i] < cutoff_radius))
neighbors_count.append(int(neighbors))
return {
"success": True,
"num_atoms": len(structure),
"chemical_formula": structure.get_chemical_formula(),
"elements": list(set(symbols)),
"element_counts": {elem: symbols.count(elem) for elem in set(symbols)},
"has_periodicity": any(pbc),
"pbc": pbc,
"cell": cell,
"positions": positions[:10] if len(positions) > 10 else positions, # First 10 positions
"average_neighbors": float(np.mean(neighbors_count)),
"min_neighbors": min(neighbors_count),
"max_neighbors": max(neighbors_count),
"min_distance": float(distance_matrix[distance_matrix > 0].min()),
"max_distance": float(distance_matrix.max())
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="compare_models", description="Generate local model comparison script to benchmark different GNN models on a dataset.")
def compare_models(
data_path: str,
    model_list: Optional[List[str]] = None,
epochs: int = 50,
batch_size: int = 32
) -> dict:
"""
Generate a local script to compare performance of different GNN models.
Parameters:
data_path (str): Path to directory containing structure data.
model_list (List[str]): List of models to compare (default: all available).
epochs (int): Number of training epochs per model (default: 50).
batch_size (int): Batch size for training (default: 32).
Returns:
dict: Contains local comparison script and model recommendations.
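    Example (the data path is a placeholder):
        compare_models(
            data_path="/path/to/my_dataset",
            model_list=["CGCNN_demo", "SchNet_demo"],
            epochs=50
        )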
"""
try:
# Load config
config_path = os.path.join(project_root, "config.yml")
with open(config_path, "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
available_models = list(config.get("Models", {}).keys())
if model_list is None:
model_list = available_models
else:
# Validate models
invalid = [m for m in model_list if m not in available_models]
if invalid:
return {
"success": False,
"error": f"Invalid models: {invalid}. Available: {available_models}"
}
# Generate comparison script
models_str = str(model_list)
compare_script = '''"""MatDeepLearn Model Comparison Script
Generated for local benchmarking.
"""
import yaml
import numpy as np
import torch
import time
from pathlib import Path
from matdeeplearn import training, process
# Configuration - Update these paths
DATA_PATH = Path(r"{data_path}")
CONFIG_PATH = Path(r"config.yml")
# Comparison Parameters
MODEL_LIST = {models}
EPOCHS = {epochs}
BATCH_SIZE = {batch_size}
def train_single_model(model_name, data_path_str, config, rank, world_size):
"""Train a single model and return results."""
job_config = {{
"job_name": f"compare_{{model_name}}",
"reprocess": "False",
"model": model_name,
"save_model": "False",
"write_output": "False",
"parallel": "False",
"seed": 42,
}}
training_config = config.get("Training", {{}}).copy()
training_config.update({{
"target_index": 0,
"train_ratio": 0.8,
"val_ratio": 0.1,
"test_ratio": 0.1,
"verbosity": 5,
}})
model_config = config["Models"][model_name].copy()
model_config.update({{
"epochs": EPOCHS,
"batch_size": BATCH_SIZE,
}})
start_time = time.time()
errors = training.train_regular(
rank,
world_size,
data_path_str,
job_config,
training_config,
model_config,
)
elapsed = time.time() - start_time
return {{
"model": model_name,
"train_error": errors[0],
"val_error": errors[1],
"test_error": errors[2],
"training_time": elapsed
}}
def main():
if not DATA_PATH.exists():
raise FileNotFoundError(f"Data path not found: {{DATA_PATH}}")
if not CONFIG_PATH.exists():
raise FileNotFoundError(f"Config file not found: {{CONFIG_PATH}}")
with CONFIG_PATH.open("r", encoding="utf-8") as f:
config = yaml.safe_load(f)
data_path_str = str(DATA_PATH.resolve())
world_size = torch.cuda.device_count()
rank = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {{rank}}, GPU count: {{world_size}}")
print(f"Comparing models: {{MODEL_LIST}}")
print("="*60)
results = []
for model_name in MODEL_LIST:
print(f"\nTraining {{model_name}}...")
try:
result = train_single_model(model_name, data_path_str, config, rank, world_size)
results.append(result)
print(f" Test MAE: {{result['test_error']:.4f}}, Time: {{result['training_time']:.1f}}s")
except Exception as e:
print(f" Error: {{e}}")
results.append({{
"model": model_name,
"error": str(e)
}})
# Summary
print("\n" + "="*60)
print("MODEL COMPARISON RESULTS")
print("="*60)
print(f"{{' Model':<20}} {{'Test MAE':>12}} {{'Val MAE':>12}} {{'Time (s)':>10}}")
print("-"*60)
valid_results = [r for r in results if "test_error" in r]
for r in sorted(valid_results, key=lambda x: x["test_error"]):
print(f"{{r['model']:<20}} {{r['test_error']:>12.4f}} {{r['val_error']:>12.4f}} {{r['training_time']:>10.1f}}")
if valid_results:
best = min(valid_results, key=lambda x: x["test_error"])
print("\n" + "="*60)
print(f"BEST MODEL: {{best['model']}} with Test MAE = {{best['test_error']:.4f}}")
print("="*60)
if __name__ == "__main__":
main()
'''.format(
data_path=data_path,
models=models_str,
epochs=epochs,
batch_size=batch_size
)
# Model recommendations
recommendations = {
"for_crystals": ["CGCNN_demo", "MEGNet_demo", "SchNet_demo"],
"for_molecules": ["MPNN_demo", "SchNet_demo"],
"for_speed": ["GCN_demo", "SM_demo"],
"for_accuracy": ["CGCNN_demo", "MEGNet_demo"]
}
return {
"success": True,
"message": "Model comparison script generated. Run locally to benchmark models.",
"models_to_compare": model_list,
"parameters": {
"epochs": epochs,
"batch_size": batch_size
},
"local_comparison_script": compare_script,
"model_recommendations": recommendations,
"usage_instructions": [
"1. Save the script as 'compare_models.py'",
"2. Ensure config.yml is in the same directory",
"3. Update DATA_PATH to your dataset location",
"4. Run: python compare_models.py",
"5. Results will show ranked models by test MAE"
],
"available_models": available_models
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="get_dataset_info", description="Get information about a dataset directory or uploaded dataset.")
def get_dataset_info(
data_path: Optional[str] = None,
structure_files: Optional[List[str]] = None,
targets_csv_content: Optional[str] = None
) -> dict:
"""
Get information about a dataset.
Parameters:
data_path (str, optional): Path to directory containing structure data (server-side).
structure_files (List[str], optional): List of structure filenames (for validation check).
targets_csv_content (str, optional): Content of targets.csv file to analyze.
Returns:
dict: Contains dataset information including file counts and formats.
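Example (IDs and values are illustrative):
get_dataset_info(targets_csv_content="structure_001,-3.52\\nstructure_002,-2.97")
Note: targets.csv rows follow "structure_id,target_value" with no header.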
"""
try:
# If analyzing uploaded content
if targets_csv_content is not None:
import csv
from io import StringIO
reader = csv.reader(StringIO(targets_csv_content))
rows = [row for row in reader if row]  # skip blank lines
num_samples = len(rows)
# Parse target values
target_values = []
for row in rows:
if len(row) >= 2:
try:
target_values.append(float(row[1]))
except ValueError:
pass  # skip non-numeric rows (e.g., a header line)
result = {
"success": True,
"source": "uploaded_content",
"num_samples": num_samples,
"has_targets_csv": True,
"ready_for_training": True
}
if target_values:
result["target_statistics"] = {
"min": min(target_values),
"max": max(target_values),
"mean": sum(target_values) / len(target_values)
}
if structure_files:
extensions = {}
for f in structure_files:
ext = os.path.splitext(f)[1].lower()
extensions[ext] = extensions.get(ext, 0) + 1
result["file_extensions"] = extensions
result["num_structure_files"] = len(structure_files)
return result
# Traditional path-based analysis
if data_path is None:
return {"success": False, "error": "Either data_path or targets_csv_content must be provided"}
if not os.path.exists(data_path):
return {"success": False, "error": f"Data path not found: {data_path}"}
# Count files by extension
extensions = {}
for file in os.listdir(data_path):
ext = os.path.splitext(file)[1].lower()
extensions[ext] = extensions.get(ext, 0) + 1
# Check for required files
has_targets = os.path.exists(os.path.join(data_path, "targets.csv"))
has_atom_dict = os.path.exists(os.path.join(data_path, "atom_dict.json"))
has_processed = os.path.exists(os.path.join(data_path, "processed"))
# Read targets if available
num_samples = 0
if has_targets:
import csv
with open(os.path.join(data_path, "targets.csv")) as f:
num_samples = sum(1 for row in csv.reader(f) if row)  # skip blank lines
return {
"success": True,
"source": "server_path",
"data_path": data_path,
"file_extensions": extensions,
"has_targets_csv": has_targets,
"has_atom_dict": has_atom_dict,
"has_processed_data": has_processed,
"num_samples": num_samples,
"ready_for_training": has_targets
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="quick_structure_analysis", description="Quick analysis of a structure file content without needing a server path. Ideal for analyzing uploaded files from Cursor.")
def quick_structure_analysis(
file_content: str,
file_format: str,
include_positions: bool = False,
include_distances: bool = True
) -> dict:
"""
Perform quick analysis on structure file content uploaded directly.
This is the recommended tool for analyzing structures when using remote MCP.
Parameters:
file_content (str): The complete content of the structure file.
file_format (str): Format of the file - 'cif', 'xyz', 'vasp'/'poscar', 'json', 'extxyz'.
include_positions (bool): Whether to include atomic positions in output (default: False).
include_distances (bool): Whether to include distance analysis (default: True).
Returns:
dict: Comprehensive structure analysis.
Example:
quick_structure_analysis(
file_content="data_NaCl\\n_cell_length_a 5.64...",
file_format="cif"
)
"""
try:
import ase.io  # lazy import; ASE is only required for structure parsing
# Map format names
format_map = {
'cif': 'cif',
'xyz': 'xyz',
'vasp': 'vasp',
'poscar': 'vasp',
'json': 'json',
'extxyz': 'extxyz'
}
fmt = format_map.get(file_format.lower())
if fmt is None:
return {
"success": False,
"error": f"Unsupported format: {file_format}. Supported: cif, xyz, vasp, poscar, json, extxyz"
}
# Write to temp file and read
with tempfile.NamedTemporaryFile(mode='w', suffix=f'.{fmt}', delete=False) as tmp:
tmp.write(file_content)
tmp_path = tmp.name
try:
structure = ase.io.read(tmp_path, format=fmt)
finally:
os.unlink(tmp_path)
# Basic analysis
symbols = structure.get_chemical_symbols()
cell = structure.get_cell().tolist() if any(structure.pbc) else None
pbc = structure.pbc.tolist()
result = {
"success": True,
"num_atoms": len(structure),
"chemical_formula": structure.get_chemical_formula(),
"reduced_formula": structure.get_chemical_formula(mode='reduce'),
"elements": sorted(list(set(symbols))),
"element_counts": {elem: symbols.count(elem) for elem in set(symbols)},
"has_periodicity": any(pbc),
"pbc": pbc,
"cell_parameters": cell,
"volume": float(structure.get_volume()) if any(pbc) else None,
}
if include_positions:
positions = structure.get_positions().tolist()
result["positions"] = positions
result["symbols"] = symbols
if include_distances:
distance_matrix = structure.get_all_distances(mic=any(structure.pbc))  # MIC only for periodic cells
cutoff_radius = 8.0
neighbors_count = []
for i in range(len(structure)):
neighbors = np.sum((distance_matrix[i] > 0) & (distance_matrix[i] < cutoff_radius))
neighbors_count.append(int(neighbors))
result["distance_analysis"] = {
"cutoff_radius": cutoff_radius,
"average_neighbors": float(np.mean(neighbors_count)),
"min_neighbors": min(neighbors_count),
"max_neighbors": max(neighbors_count),
"min_distance": float(distance_matrix[distance_matrix > 0].min()),
"max_distance": float(distance_matrix.max())
}
# Check if suitable for GNN
result["gnn_suitable"] = {
"has_enough_atoms": len(structure) >= 2,
"has_3d_coordinates": True,
"is_periodic": any(pbc),
"recommendation": "Suitable for GNN training" if len(structure) >= 2 else "Too few atoms"
}
return result
except Exception as e:
return {"success": False, "error": str(e)}
# ============================================================================
# Session-based training and model-management tools
# ============================================================================
@mcp.tool(name="process_session_data", description="Analyze uploaded session data and generate local processing script for GNN training.")
def process_session_data(
session_id: str,
target_index: int = 0,
graph_max_radius: float = 8.0,
graph_max_neighbors: int = 12,
data_format: Optional[str] = None,
dictionary_source: str = "default"
) -> dict:
"""
Analyze uploaded session data and generate local processing script.
Server-side graph processing is not available; use the generated script locally.
Parameters:
session_id (str): The session ID.
target_index (int): Index of target column in targets.csv (default: 0).
graph_max_radius (float): Maximum radius for graph edges (default: 8.0 Angstrom).
graph_max_neighbors (int): Maximum neighbors per atom (default: 12).
data_format (str, optional): Explicit structure format ('cif', 'vasp', 'xyz', 'json').
dictionary_source (str): Atom dictionary source ('default', 'blank', 'generated', 'provided').
Returns:
dict: Data analysis and local processing script.
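Example (session ID is illustrative):
process_session_data("session_20240101_120000_ab12cd34", graph_max_radius=8.0)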
"""
try:
session_path = _get_session_path(session_id)
if not os.path.exists(session_path):
return {"success": False, "error": f"Session not found: {session_id}"}
data_path = os.path.join(session_path, "data")
# Check for required files and analyze
has_targets = os.path.exists(os.path.join(data_path, "targets.csv"))
skip_files = {"targets.csv", "id.json", "atom_dict.json", "processed"}
structure_files = [
f for f in os.listdir(data_path)
if f not in skip_files and not f.startswith('.') and os.path.isfile(os.path.join(data_path, f))
]
# Detect format
determined_format = data_format
if determined_format is None:
detected = _infer_structure_format(structure_files)
determined_format = detected if detected else "json"
# Analyze targets if present
target_stats = None
if has_targets:
import csv
targets_path = os.path.join(data_path, "targets.csv")
with open(targets_path, 'r') as f:
reader = csv.reader(f)
rows = [row for row in reader if row]  # skip blank lines
target_values = []
for row in rows:
if len(row) >= 2:
try:
target_values.append(float(row[1]))
except ValueError:
pass  # skip non-numeric rows (e.g., a header line)
if target_values:
target_stats = {
"num_samples": len(rows),
"min": min(target_values),
"max": max(target_values),
"mean": sum(target_values) / len(target_values)
}
# Generate processing script
process_script = '''"""MatDeepLearn Data Processing Script
Generated for session: {session_id}
"""
from pathlib import Path
from matdeeplearn import process
# Configuration - Update this path to your local data copy
DATA_PATH = Path(r"{data_path}")
CONFIG_PATH = Path(r"config.yml")
def main():
if not DATA_PATH.exists():
raise FileNotFoundError(f"Data path not found: {{DATA_PATH}}")
data_path_str = str(DATA_PATH.resolve())
processing_args = {{
"dataset_type": "inmemory",
"data_path": data_path_str,
"target_path": "targets.csv",
"dictionary_source": "{dictionary_source}",
"dictionary_path": "atom_dict.json",
"data_format": "{data_format}",
"verbose": "True",
"graph_max_radius": {graph_max_radius},
"graph_max_neighbors": {graph_max_neighbors},
"voronoi": "False",
"edge_features": "True",
"graph_edge_length": 50,
"SM_descriptor": "False",
"SOAP_descriptor": "False"
}}
print(f"Processing data from {{data_path_str}}...")
dataset = process.get_dataset(
data_path_str,
{target_index}, # target_index
"True", # reprocess
processing_args
)
print(f"Processing complete!")
print(f"Dataset size: {{len(dataset)}}")
if len(dataset) > 0:
print(f"Atoms per structure: {{dataset[0].x.shape[0]}}")
print(f"Node features: {{dataset[0].x.shape[1]}}")
print(f"Edges per structure: {{dataset[0].edge_index.shape[1]}}")
if __name__ == "__main__":
main()
'''.format(
session_id=session_id,
data_path="<path-to-your-local-data>",
dictionary_source=dictionary_source,
data_format=determined_format,
graph_max_radius=graph_max_radius,
graph_max_neighbors=graph_max_neighbors,
target_index=target_index
)
return {
"success": True,
"session_id": session_id,
"data_analysis": {
"num_structure_files": len(structure_files),
"structure_files_sample": structure_files[:10],
"detected_format": determined_format,
"has_targets": has_targets,
"target_statistics": target_stats,
"server_data_path": data_path
},
"local_processing_script": process_script,
"processing_parameters": {
"graph_max_radius": graph_max_radius,
"graph_max_neighbors": graph_max_neighbors,
"data_format": determined_format,
"dictionary_source": dictionary_source,
"target_index": target_index
},
"next_steps": [
"1. Copy session data to your local machine",
"2. Save the processing script as 'process_data.py'",
"3. Update DATA_PATH to your local data folder",
"4. Run: python process_data.py",
"5. Then use train_session_model to get training script"
],
"note": "Server-side processing is not available. Use the script locally with MatDeepLearn installed."
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="train_session_model", description="Generate comprehensive local training script and environment guide for a session dataset.")
def train_session_model(
session_id: str,
model_name: str = "CGCNN_demo",
epochs: int = 100,
batch_size: int = 32,
learning_rate: float = 0.002,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
test_ratio: float = 0.1,
model_save_name: Optional[str] = None,
include_data_export_guide: bool = True
) -> dict:
"""
Generate a complete local training package for a session dataset.
Includes training script, configuration, and data export instructions.
Parameters:
session_id (str): The session ID with processed data.
model_name (str): Model to use - "CGCNN_demo", "SchNet_demo", "MPNN_demo", etc.
epochs (int): Number of training epochs (default: 100).
batch_size (int): Batch size (default: 32).
learning_rate (float): Learning rate (default: 0.002).
train_ratio (float): Training data ratio (default: 0.8).
val_ratio (float): Validation data ratio (default: 0.1).
test_ratio (float): Test data ratio (default: 0.1).
model_save_name (str, optional): Custom name for saved model.
include_data_export_guide (bool): Include guide for exporting session data (default: True).
Returns:
dict: Complete local training package with script, config, and instructions.
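Example (session ID is illustrative):
train_session_model(
session_id="session_20240101_120000_ab12cd34",
model_name="SchNet_demo",
epochs=200
)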
"""
try:
session_path = _get_session_path(session_id)
if not os.path.exists(session_path):
return {"success": False, "error": f"Session not found: {session_id}"}
data_path = os.path.join(session_path, "data")
# Load config
config_path = os.path.join(project_root, "config.yml")
with open(config_path, "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
if model_name not in config.get("Models", {}):
available_models = list(config.get("Models", {}).keys())
return {
"success": False,
"error": f"Model '{model_name}' not found. Available: {available_models}"
}
# Generate model filename
if model_save_name:
model_filename = f"{model_save_name}.pth"
else:
model_filename = f"{model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pth"
placeholder_data_path = "<path-to-your-local-dataset>"
script = _build_local_training_script(
data_path_literal=placeholder_data_path,
model_name=model_name,
target_index=0,
epochs=epochs,
batch_size=batch_size,
learning_rate=learning_rate,
train_ratio=train_ratio,
val_ratio=val_ratio,
test_ratio=test_ratio,
save_model=True,
model_output_literal=model_filename,
reprocess=False,
config_path_literal="config.yml"
)
# Build config.yml content
model_config = config["Models"][model_name]
config_yml_content = yaml.dump({
"Processing": config.get("Processing", {}),
"Training": config.get("Training", {}),
"Models": {model_name: model_config}
}, default_flow_style=False)
# Data export guide
data_export_guide = None
if include_data_export_guide:
data_export_guide = {
"session_data_location": data_path,
"required_files": [
"targets.csv - CSV file with structure_id,target_value per line",
"Structure files (CIF, POSCAR, XYZ, or JSON format)"
],
"export_options": [
"Option 1: Copy the session data folder to your local machine",
"Option 2: Use your original dataset files directly",
"Option 3: Re-upload structures to a local directory"
],
"data_format_requirements": {
"targets.csv": "structure_id,target_value (no header required)",
"structures": "One file per structure, filename should match structure_id in targets.csv"
}
}
return {
"success": True,
"message": "Local training package generated. Follow the instructions to train on your machine.",
"session_id": session_id,
"model_name": model_name,
"model_description": _get_model_description(model_name),
"recommended_model_filename": model_filename,
"server_data_path": data_path,
"parameters": {
"epochs": epochs,
"batch_size": batch_size,
"learning_rate": learning_rate,
"train_ratio": train_ratio,
"val_ratio": val_ratio,
"test_ratio": test_ratio
},
"local_training_script": script,
"config_yml_template": config_yml_content,
"data_export_guide": data_export_guide,
"setup_instructions": [
"1. Set up Python environment with MatDeepLearn dependencies",
"2. Save the training script as 'train_local.py'",
"3. Save config_yml_template as 'config.yml'",
"4. Copy your dataset to a local folder or use existing data",
"5. Update DATA_PATH in the script",
"6. Run: python train_local.py"
],
"environment_tip": "Use get_environment_requirements tool for detailed environment setup guide"
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="predict_with_session_model", description="Generate local prediction script for using a model with session data.")
def predict_with_session_model(
session_id: str,
model_path: str = "<path-to-your-model.pth>"
) -> dict:
"""
Generate a local prediction script for a session's data.
Server-side prediction is not available; use the generated script locally.
Parameters:
session_id (str): The session ID.
model_path (str): Path to the trained model file on your local machine.
Returns:
dict: Local prediction script and session data info.
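Example (ID and path are illustrative):
predict_with_session_model(
session_id="session_20240101_120000_ab12cd34",
model_path="/home/user/models/my_model.pth"
)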
"""
try:
session_path = _get_session_path(session_id)
if not os.path.exists(session_path):
return {"success": False, "error": f"Session not found: {session_id}"}
data_path = os.path.join(session_path, "data")
# Get session data info
structure_files = []
has_targets = os.path.exists(os.path.join(data_path, "targets.csv"))
if os.path.exists(data_path):
skip_files = {"targets.csv", "id.json", "atom_dict.json", "processed"}
structure_files = [
f for f in os.listdir(data_path)
if f not in skip_files and not f.startswith('.') and os.path.isfile(os.path.join(data_path, f))
]
predict_script = '''"""MatDeepLearn Prediction Script
Generated for session: {session_id}
"""
import yaml
import torch
from pathlib import Path
from matdeeplearn import training, process
# Configuration - Update these paths
DATA_PATH = Path(r"{data_placeholder}")
MODEL_PATH = Path(r"{model_path}")
CONFIG_PATH = Path(r"config.yml")
def main():
if not DATA_PATH.exists():
raise FileNotFoundError(f"Data path not found: {{DATA_PATH}}")
if not MODEL_PATH.exists():
raise FileNotFoundError(f"Model path not found: {{MODEL_PATH}}")
if not CONFIG_PATH.exists():
raise FileNotFoundError(f"Config file not found: {{CONFIG_PATH}}")
with CONFIG_PATH.open("r", encoding="utf-8") as f:
config = yaml.safe_load(f)
data_path_str = str(DATA_PATH.resolve())
# Process data
processing_args = config.get("Processing", {{}}).copy()
processing_args["data_path"] = data_path_str
dataset = process.get_dataset(
data_path_str,
0, # target_index
"True", # reprocess
processing_args
)
job_config = {{
"job_name": "prediction_{session_id}",
"model_path": str(MODEL_PATH.resolve()),
"write_output": "True"
}}
print(f"Predicting on {{len(dataset)}} structures...")
print(f"Using model: {{MODEL_PATH}}")
# Run prediction
test_error = training.predict(dataset, "l1_loss", job_config)
print(f"\nPrediction complete!")
print(f"MAE (if targets available): {{test_error:.4f}}")
print(f"Results saved to: prediction_{session_id}_predicted_outputs.csv")
if __name__ == "__main__":
main()
'''.format(
session_id=session_id,
data_placeholder="<path-to-your-local-data>",
model_path=model_path
)
return {
"success": True,
"message": "Server-side prediction is not available. Use the generated script locally.",
"session_id": session_id,
"session_data_info": {
"server_data_path": data_path,
"num_structure_files": len(structure_files),
"has_targets": has_targets
},
"local_prediction_script": predict_script,
"usage_instructions": [
"1. Copy the session data to your local machine",
"2. Save the script as 'predict_local.py'",
"3. Ensure config.yml is in the same directory",
"4. Update DATA_PATH and MODEL_PATH",
"5. Run: python predict_local.py"
],
"note": "You need a trained model (.pth file) to run predictions"
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="download_session_data", description="Get information about session data for local download/copy.")
def download_session_data(session_id: str) -> dict:
"""
Get session data location and summary for copying to local machine.
Parameters:
session_id (str): The session ID.
Returns:
dict: Session data location and file list.
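Example (session ID is illustrative):
download_session_data("session_20240101_120000_ab12cd34")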
"""
try:
session_path = _get_session_path(session_id)
if not os.path.exists(session_path):
return {"success": False, "error": f"Session not found: {session_id}"}
data_path = os.path.join(session_path, "data")
# List all files
files_info = []
total_size = 0
if os.path.exists(data_path):
for root, dirs, files in os.walk(data_path):
for f in files:
file_path = os.path.join(root, f)
rel_path = os.path.relpath(file_path, data_path)
size = os.path.getsize(file_path)
total_size += size
files_info.append({
"path": rel_path.replace("\\", "/"),
"size_bytes": size
})
return {
"success": True,
"session_id": session_id,
"data_location": data_path,
"total_files": len(files_info),
"total_size_mb": total_size / (1024 * 1024),
"files": files_info[:50], # First 50 files
"files_truncated": len(files_info) > 50,
"copy_instructions": [
f"Server data location: {data_path}",
"To use locally:",
"1. Copy the entire data folder to your local machine",
"2. Or use scp/rsync if accessing remote server",
"3. Update DATA_PATH in your training/processing scripts"
]
}
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="run_cross_validation_session", description="Generate local cross-validation script for session data.")
def run_cross_validation_session(
session_id: str,
model_name: str = "CGCNN_demo",
cv_folds: int = 5,
epochs: int = 100,
batch_size: int = 32,
learning_rate: float = 0.002
) -> dict:
"""
Generate a local cross-validation script for session data.
Parameters:
session_id (str): The session ID.
model_name (str): Model to use (default: "CGCNN_demo").
cv_folds (int): Number of folds (default: 5).
epochs (int): Training epochs per fold (default: 100).
batch_size (int): Batch size (default: 32).
learning_rate (float): Learning rate (default: 0.002).
Returns:
dict: Local cross-validation script and instructions.
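Example (session ID is illustrative):
run_cross_validation_session("session_20240101_120000_ab12cd34", cv_folds=5)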
"""
try:
session_path = _get_session_path(session_id)
if not os.path.exists(session_path):
return {"success": False, "error": f"Session not found: {session_id}"}
data_path = os.path.join(session_path, "data")
# Use the cross_validation tool to generate the script
result = cross_validation(
data_path=data_path,
model_name=model_name,
cv_folds=cv_folds,
epochs=epochs,
batch_size=batch_size,
learning_rate=learning_rate
)
if result.get("success"):
result["session_id"] = session_id
result["session_data_path"] = data_path
result["note"] = "Update DATA_PATH in the script to your local dataset copy"
return result
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="get_environment_requirements", description="Get comprehensive environment setup guide and requirements for running MatDeepLearn locally.")
def get_environment_requirements(
target_os: str = "auto",
cuda_version: str = "11.8",
include_optional: bool = True,
output_format: str = "detailed"
) -> dict:
"""
Get complete environment setup requirements for local MatDeepLearn training.
Parameters:
target_os (str): Target operating system - "windows", "linux", "macos", or "auto" (default: "auto").
cuda_version (str): Target CUDA version - "10.2", "11.7", "11.8", "12.1" (default: "11.8").
include_optional (bool): Include optional dependencies like SOAP, ray (default: True).
output_format (str): Output format - "detailed", "minimal", "script" (default: "detailed").
Returns:
dict: Complete environment setup guide with installation commands.
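Example:
get_environment_requirements(target_os="linux", cuda_version="12.1")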
"""
try:
import platform
# Auto-detect OS
if target_os == "auto":
system = platform.system().lower()
if system == "windows":
target_os = "windows"
elif system == "darwin":
target_os = "macos"
else:
target_os = "linux"
# PyTorch installation URLs based on CUDA version
pytorch_urls = {
"10.2": "https://download.pytorch.org/whl/cu102",
"11.7": "https://download.pytorch.org/whl/cu117",
"11.8": "https://download.pytorch.org/whl/cu118",
"12.1": "https://download.pytorch.org/whl/cu121",
"cpu": ""
}
pytorch_url = pytorch_urls.get(cuda_version, pytorch_urls["11.8"])
index_url_flag = f" --index-url {pytorch_url}" if pytorch_url else ""  # empty for CPU-only builds
cu_tag = "cpu" if cuda_version == "cpu" else f"cu{cuda_version.replace('.', '')}"
# Core dependencies
core_deps = [
"torch>=1.8.0",
"torch-geometric>=2.0.0",
"torch-scatter",
"torch-sparse",
"torch-cluster",
"torch-spline-conv",
"ase>=3.20.0",
"pymatgen>=2020.9.0",
"numpy>=1.20.0",
"scipy>=1.6.0",
"scikit-learn>=0.24.0",
"pyyaml",
"matplotlib",
"pandas"
]
optional_deps = [
"dscribe>=0.3.5 # For SOAP descriptor",
"ray>=1.0.0 # For hyperparameter optimization",
"joblib>=0.13.0 # For parallel processing"
]
# Generate requirements.txt content
requirements_content = "# MatDeepLearn Requirements\n"
requirements_content += "# Core dependencies\n"
for dep in core_deps:
requirements_content += f"{dep}\n"
if include_optional:
requirements_content += "\n# Optional dependencies\n"
for dep in optional_deps:
requirements_content += f"{dep}\n"
# Generate installation script based on OS
if target_os == "windows":
activation_cmd = "matdeeplearn_env\\Scripts\\activate"
python_cmd = "python"
shell_comment = "REM"
script_ext = ".bat"
else:
activation_cmd = "source matdeeplearn_env/bin/activate"
python_cmd = "python3"
shell_comment = "#"
script_ext = ".sh"
install_script = f"""{shell_comment} MatDeepLearn Environment Setup Script
{shell_comment} Target OS: {target_os}, CUDA: {cuda_version}
{shell_comment} Step 1: Create virtual environment
{python_cmd} -m venv matdeeplearn_env
{shell_comment} Step 2: Activate environment
{activation_cmd}
{shell_comment} Step 3: Upgrade pip
pip install --upgrade pip
{shell_comment} Step 4: Install PyTorch (target: {cuda_version})
pip install torch torchvision torchaudio{index_url_flag}
{shell_comment} Step 5: Install PyTorch Geometric
pip install torch-geometric
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.0.0+{cu_tag}.html
{shell_comment} Step 6: Install core dependencies
pip install ase pymatgen numpy scipy matplotlib scikit-learn pyyaml pandas
{shell_comment} Step 7: Clone and install MatDeepLearn
git clone https://github.com/Fung-Lab/MatDeepLearn.git
cd MatDeepLearn
pip install -e .
"""
if include_optional:
install_script += f"""
{shell_comment} Step 8 (Optional): Install additional dependencies
pip install dscribe ray joblib
"""
# GPU check script
gpu_check_script = """import torch
print("="*50)
print("MatDeepLearn Environment Check")
print("="*50)
print(f"Python version: {__import__('sys').version}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA version: {torch.version.cuda}")
print(f"GPU count: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
print(f" GPU {i}: {torch.cuda.get_device_name(i)}")
else:
print("WARNING: CUDA not available. Training will run on CPU (slower).")
try:
import torch_geometric
print(f"PyTorch Geometric version: {torch_geometric.__version__}")
except ImportError:
print("ERROR: PyTorch Geometric not installed!")
try:
import ase
print(f"ASE version: {ase.__version__}")
except ImportError:
print("ERROR: ASE not installed!")
try:
import pymatgen
print(f"Pymatgen version: {pymatgen.__version__}")
except ImportError:
print("ERROR: Pymatgen not installed!")
print("="*50)
print("Environment check complete!")
"""
result = {
"success": True,
"target_os": target_os,
"cuda_version": cuda_version,
"python_requirement": "Python 3.7 - 3.10 recommended",
"core_dependencies": core_deps,
"requirements_txt": requirements_content,
"installation_script": install_script,
"script_filename": f"setup_matdeeplearn{script_ext}",
"gpu_check_script": gpu_check_script,
"gpu_check_filename": "check_environment.py",
"quick_start": [
f"1. Save installation script as setup_matdeeplearn{script_ext}",
f"2. Run the script to set up environment",
"3. Save check_environment.py and run to verify installation",
"4. Use train_model or train_session_model to generate training scripts"
]
}
if include_optional:
result["optional_dependencies"] = optional_deps
if output_format == "detailed":
result["detailed_steps"] = {
"step1_venv": {
"description": "Create isolated Python environment",
"command": f"{python_cmd} -m venv matdeeplearn_env",
"note": "This creates a clean environment for MatDeepLearn"
},
"step2_activate": {
"description": "Activate the environment",
"command": activation_cmd,
"note": "Run this each time before using MatDeepLearn"
},
"step3_pytorch": {
"description": "Install PyTorch with GPU support",
"command": f"pip install torch torchvision torchaudio --index-url {pytorch_url}",
"note": f"This installs PyTorch for CUDA {cuda_version}. For CPU only, omit --index-url"
},
"step4_pyg": {
"description": "Install PyTorch Geometric",
"commands": [
"pip install torch-geometric",
f"pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.0.0+cu{cuda_version.replace('.', '')}.html"
],
"note": "PyG requires matching versions with your PyTorch"
},
"step5_deps": {
"description": "Install material science libraries",
"command": "pip install ase pymatgen numpy scipy matplotlib scikit-learn pyyaml",
"note": "These handle crystal structures and scientific computing"
},
"step6_matdeeplearn": {
"description": "Install MatDeepLearn",
"commands": [
"git clone https://github.com/Fung-Lab/MatDeepLearn.git",
"cd MatDeepLearn",
"pip install -e ."
],
"note": "Editable install allows code modifications"
}
}
result["troubleshooting"] = {
"cuda_not_found": "Ensure NVIDIA drivers are installed and match CUDA version",
"pyg_version_mismatch": "PyTorch Geometric wheels must match your PyTorch version exactly",
"memory_error": "Reduce batch_size if you encounter GPU memory errors",
"import_error": "Verify all dependencies installed in the same environment"
}
return result
except Exception as e:
return {"success": False, "error": str(e)}
@mcp.tool(name="get_hyperparameter_guide", description="Get hyperparameter tuning guide and recommendations for different models and dataset sizes.")
def get_hyperparameter_guide(
model_name: str = "CGCNN_demo",
dataset_size: Optional[int] = None,
task_type: str = "regression"
) -> dict:
"""
Get hyperparameter recommendations and tuning guide.
Parameters:
model_name (str): Model name (default: "CGCNN_demo").
dataset_size (int, optional): Approximate size of your dataset.
task_type (str): Task type - "regression" or "classification" (default: "regression").
Returns:
dict: Hyperparameter recommendations and tuning guide.
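Example:
get_hyperparameter_guide(model_name="SchNet_demo", dataset_size=1200)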
"""
try:
# Load config
config_path = os.path.join(project_root, "config.yml")
with open(config_path, "r") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
available_models = list(config.get("Models", {}).keys())
if model_name not in config.get("Models", {}):
return {
"success": False,
"error": f"Model '{model_name}' not found. Available: {available_models}"
}
model_config = config["Models"][model_name]
# Base recommendations
recommendations = {
"learning_rate": {
"default": model_config.get("lr", 0.002),
"range": [0.0001, 0.01],
"tuning_tip": "Start with default, reduce by 0.5x if training unstable"
},
"batch_size": {
"default": model_config.get("batch_size", 32),
"range": [16, 256],
"tuning_tip": "Larger batch = faster but may need higher lr; smaller batch = better generalization"
},
"epochs": {
"default": model_config.get("epochs", 250),
"range": [50, 500],
"tuning_tip": "Monitor validation loss; stop if no improvement for 20+ epochs"
},
"hidden_dims": {
"default": [model_config.get("dim1", 100), model_config.get("dim2", 100)],
"range": [[32, 32], [256, 256]],
"tuning_tip": "Increase for complex tasks, decrease for small datasets"
},
"gc_layers": {
"default": model_config.get("gc_count", 4),
"range": [2, 8],
"tuning_tip": "More layers capture longer-range interactions but risk oversmoothing"
},
"dropout": {
"default": model_config.get("dropout_rate", 0.0),
"range": [0.0, 0.5],
"tuning_tip": "Increase if overfitting (train << val error)"
}
}
# Adjust based on dataset size
if dataset_size:
if dataset_size < 500:
size_advice = "Small dataset - use regularization, smaller model"
recommendations["batch_size"]["recommended"] = min(32, dataset_size // 10)
recommendations["dropout"]["recommended"] = 0.2
recommendations["hidden_dims"]["recommended"] = [64, 64]
recommendations["epochs"]["recommended"] = 200
elif dataset_size < 5000:
size_advice = "Medium dataset - default settings should work"
recommendations["batch_size"]["recommended"] = 64
recommendations["dropout"]["recommended"] = 0.1
recommendations["epochs"]["recommended"] = 250
else:
size_advice = "Large dataset - can use larger model"
recommendations["batch_size"]["recommended"] = 128
recommendations["hidden_dims"]["recommended"] = [128, 128]
recommendations["epochs"]["recommended"] = 300
else:
size_advice = "Provide dataset_size for tailored recommendations"
# Generate hyperparameter search script
hp_search_script = '''"""MatDeepLearn Hyperparameter Search Script
Simple grid search for key hyperparameters.
"""
import yaml
import torch
from pathlib import Path
from matdeeplearn import training
from itertools import product
DATA_PATH = Path(r"<path-to-your-dataset>")
CONFIG_PATH = Path(r"config.yml")
# Hyperparameter grid
PARAM_GRID = {
"lr": [0.001, 0.002, 0.005],
"batch_size": [32, 64, 128],
"gc_count": [3, 4, 5],
}
def run_experiment(params, data_path_str, config, rank, world_size):
job_config = {
"job_name": f"hp_search_lr{params['lr']}_bs{params['batch_size']}",
"reprocess": "False",
"model": "''' + model_name + '''",
"save_model": "False",
"write_output": "False",
"parallel": "False",
"seed": 42,
}
training_config = config.get("Training", {}).copy()
model_config = config["Models"]["''' + model_name + '''"].copy()
model_config.update({
"lr": params["lr"],
"batch_size": params["batch_size"],
"gc_count": params.get("gc_count", model_config.get("gc_count", 4)),
"epochs": 100, # Shorter for search
})
errors = training.train_regular(rank, world_size, data_path_str, job_config, training_config, model_config)
return {"params": params, "train_mae": errors[0], "val_mae": errors[1], "test_mae": errors[2]}
def main():
with CONFIG_PATH.open("r") as f:
config = yaml.safe_load(f)
data_path_str = str(DATA_PATH.resolve())
rank = "cuda" if torch.cuda.is_available() else "cpu"
world_size = torch.cuda.device_count()
results = []
param_combinations = [dict(zip(PARAM_GRID.keys(), v)) for v in product(*PARAM_GRID.values())]
print(f"Running {len(param_combinations)} experiments...")
for i, params in enumerate(param_combinations):
print(f"\\nExperiment {i+1}/{len(param_combinations)}: {params}")
result = run_experiment(params, data_path_str, config, rank, world_size)
results.append(result)
print(f" Val MAE: {result['val_mae']:.4f}")
# Find best
best = min(results, key=lambda x: x["val_mae"])
print("\\n" + "="*50)
print(f"BEST PARAMS: {best['params']}")
print(f"Best Val MAE: {best['val_mae']:.4f}")
if __name__ == "__main__":
main()
'''
return {
"success": True,
"model_name": model_name,
"model_description": _get_model_description(model_name),
"default_config": model_config,
"recommendations": recommendations,
"dataset_size_advice": size_advice,
"hyperparameter_search_script": hp_search_script,
"tuning_strategy": [
"1. Start with default hyperparameters",
"2. Train and observe train/val error curves",
"3. If overfitting (train << val): increase dropout, reduce model size, add regularization",
"4. If underfitting (both high): increase model size, more epochs, higher lr",
"5. Use the search script for systematic optimization"
],
"common_issues": {
"loss_not_decreasing": "Try lower learning rate or different optimizer",
"nan_loss": "Reduce learning rate significantly, check data normalization",
"overfitting": "Add dropout, reduce hidden dims, use early stopping",
"slow_convergence": "Increase learning rate, use scheduler with warmup"
}
}
except Exception as e:
return {"success": False, "error": str(e)}
def create_app() -> FastMCP:
"""
Creates and returns the FastMCP application instance.
Returns:
FastMCP: The FastMCP application instance.
"""
return mcp
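# Optional entry point for running the service directly. A minimal sketch:
# FastMCP defaults to the stdio transport; HTTP transport and host/port
# options depend on your FastMCP version and deployment.
if __name__ == "__main__":
create_app().run()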