Materials_discovery / mcp_server.py
SEUyishu's picture
Update mcp_server.py
7868cdf verified
# Copyright 2024 Google LLC (Original code), Modified for MCP Service
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
"""
GNoME Materials Discovery MCP Server
This is the main MCP server implementation providing tools for:
- Dataset access and querying
- Decomposition energy calculation
- Phase diagram analysis
- Crystal structure operations
- Air stability analysis
- Model inference
"""
import os
import json
import logging
from typing import Optional, List, Dict, Any
from mcp.server.fastmcp import FastMCP
# Import local modules
from data_utils import DataManager, get_data_manager
from phase_diagram_utils import (
compute_decomposition_energy,
build_phase_diagram,
compute_air_stability,
compare_with_materials_project,
find_competing_phases
)
from model_utils import (
ModelLoader,
atoms_to_graph,
get_model_info,
get_nequip_default_config,
get_gnome_default_config,
StructureMatcher
)
# Configure logging: root logger at INFO level, plus a module-level logger
# that every tool below uses to report failures.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Data directory configuration - must match Dockerfile ENV.
# Both locations can be overridden via environment variables; the second
# argument is the fallback used when the variable is unset.
DATA_DIR = os.environ.get("GNOME_DATA_DIR", "/app/gnome_data")
MODEL_DIR = os.environ.get("GNOME_MODEL_DIR", "/app/models")
# Ensure directories exist at module load (import-time side effect;
# exist_ok=True makes this a no-op when they are already present).
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)
# Create simple FastMCP server - SSE is handled externally by start_mcp.py
mcp = FastMCP("GNoME Materials Discovery")
# ============================================================================
# Dataset Access Tools
# ============================================================================
@mcp.tool()
async def get_dataset_statistics() -> Dict[str, Any]:
    """
    Summarize the GNoME materials discovery dataset.

    The statistics are produced by ``DataManager.get_statistics()`` and are
    expected to cover:
    - Total number of materials
    - Unique compositions and formulas
    - Crystal system distribution
    - Average formation energy
    - Element coverage

    Returns:
        ``{"status": "success", "data": <statistics>}`` on success, or
        ``{"status": "error", "message": <text>}`` if anything fails.
    """
    try:
        statistics = get_data_manager(DATA_DIR).get_statistics()
    except Exception as e:
        # Tools report failures as a payload instead of raising to the client.
        logger.error(f"Error getting statistics: {e}")
        return {"status": "error", "message": str(e)}
    return {"status": "success", "data": statistics}
@mcp.tool()
async def query_materials(
    composition: Optional[str] = None,
    elements: Optional[str] = None,
    space_group: Optional[int] = None,
    crystal_system: Optional[str] = None,
    min_bandgap: Optional[float] = None,
    max_bandgap: Optional[float] = None,
    max_decomposition_energy: Optional[float] = None,
    limit: int = 50
) -> Dict[str, Any]:
    """
    Query materials from the GNoME dataset with various filters.

    All filters are optional and are forwarded unchanged to
    ``DataManager.query_by_composition``; only ``elements`` is pre-processed
    (split on commas, whitespace stripped, empty tokens discarded).

    Args:
        composition: Exact composition to match (e.g., "Li2O")
        elements: Comma-separated list of elements that must be present (e.g., "Li,O")
        space_group: Space group number to filter by
        crystal_system: Crystal system name (e.g., "cubic", "hexagonal")
        min_bandgap: Minimum bandgap value in eV
        max_bandgap: Maximum bandgap value in eV
        max_decomposition_energy: Maximum decomposition energy per atom in eV
        limit: Maximum number of results to return (default: 50)

    Returns:
        ``{"status": "success", "count": <n>, "materials": [<record>, ...]}``
        on success, or ``{"status": "error", "message": <text>}`` on failure.
    """
    try:
        dm = get_data_manager(DATA_DIR)
        elements_list = None
        if elements:
            # Discard empty tokens so inputs such as "Li, O," or "Li,,O" do
            # not inject an empty-string "element" into the filter. If no
            # real token remains, fall back to "no element filter".
            tokens = [tok.strip() for tok in elements.split(",")]
            elements_list = [tok for tok in tokens if tok] or None
        results = dm.query_by_composition(
            composition=composition,
            elements=elements_list,
            space_group=space_group,
            crystal_system=crystal_system,
            min_bandgap=min_bandgap,
            max_bandgap=max_bandgap,
            max_decomposition_energy=max_decomposition_energy,
            limit=limit
        )
        # One dict per row, keyed by column name.
        materials = results.to_dict(orient='records')
        return {
            "status": "success",
            "count": len(materials),
            "materials": materials
        }
    except Exception as e:
        logger.error(f"Error querying materials: {e}")
        return {"status": "error", "message": str(e)}
@mcp.tool()
async def get_material_by_id(material_id: str) -> Dict[str, Any]:
    """
    Look up a single material by its identifier.

    Args:
        material_id: The unique MaterialId from the GNoME dataset

    Returns:
        ``{"status": "success", "material": <dict>}`` when the data manager
        finds the record, otherwise ``{"status": "error", "message": <text>}``
        (both for an unknown ID and for any exception raised on the way).
    """
    try:
        record = get_data_manager(DATA_DIR).get_crystal_by_id(material_id)
        if record is not None:
            return {"status": "success", "material": record.to_dict()}
        # The data manager signals "unknown ID" with None rather than raising.
        return {"status": "error", "message": f"Material {material_id} not found"}
    except Exception as e:
        logger.error(f"Error getting material: {e}")
        return {"status": "error", "message": str(e)}
@mcp.tool()
async def get_random_material(
    crystal_system: Optional[str] = None,
    n_elements: Optional[int] = None
) -> Dict[str, Any]:
    """
    Draw one material at random from the GNoME dataset.

    Args:
        crystal_system: Optional filter by crystal system (exact match on the
            'Crystal System' column)
        n_elements: Optional filter by number of elements (e.g., 2 for binary, 3 for ternary)

    Returns:
        ``{"status": "success", "material": <dict>}`` for the sampled row, or
        ``{"status": "error", "message": <text>}`` when no row survives the
        filters or an exception occurs.
    """
    try:
        pool = get_data_manager(DATA_DIR).load_gnome_crystals()
        if crystal_system:
            pool = pool[pool['Crystal System'] == crystal_system]
        if n_elements:
            # NOTE(review): relies on each 'Chemical System' entry being a
            # sized collection of element symbols - confirm in data_utils.
            system_sizes = pool['Chemical System'].map(len)
            pool = pool[system_sizes == n_elements]
        if len(pool) == 0:
            return {"status": "error", "message": "No materials match the criteria"}
        picked = pool.sample(1).iloc[0]
        return {"status": "success", "material": picked.to_dict()}
    except Exception as e:
        logger.error(f"Error getting random material: {e}")
        return {"status": "error", "message": str(e)}
# ============================================================================
# Phase Diagram and Stability Tools
# ============================================================================
@mcp.tool()
async def calculate_decomposition_energy(
    composition: str,
    energy: float
) -> Dict[str, Any]:
    """
    Calculate the decomposition energy of a material relative to the GNoME convex hull.

    This determines whether a material is thermodynamically stable or metastable.
    A negative or zero decomposition energy indicates stability.

    Args:
        composition: Chemical composition (e.g., "LiFePO4", "Li2O")
        energy: Total corrected energy from DFT calculation in eV

    Returns:
        ``{"status": "success", **result}`` where ``result`` is the mapping
        returned by ``compute_decomposition_energy`` with numpy values
        converted to plain Python types, or
        ``{"status": "error", "message": <text>}`` on failure.
    """
    try:
        # Import from the defining subpackage. The pymatgen root package is a
        # namespace package, so `import pymatgen as mg; mg.core...` only
        # works when some other module happens to have imported pymatgen.core.
        from pymatgen.core import Composition
        import numpy as np

        def to_builtin(obj):
            """Recursively replace numpy scalars/arrays with builtin values."""
            if isinstance(obj, dict):
                return {k: to_builtin(v) for k, v in obj.items()}
            if isinstance(obj, (list, tuple)):
                return [to_builtin(v) for v in obj]
            if isinstance(obj, np.ndarray):
                # tolist() unboxes numeric dtypes; recurse for object arrays.
                return to_builtin(obj.tolist())
            if isinstance(obj, np.generic):
                # np.generic covers every numpy scalar, including np.bool_.
                return obj.item()
            return obj

        dm = get_data_manager(DATA_DIR)
        all_crystals = dm.load_all_crystals()
        grouped = dm.get_grouped_entries()
        # Derive the chemical system (list of element symbols) from the composition.
        comp = Composition(composition)
        chemsys = [str(el) for el in comp.elements]
        result = compute_decomposition_energy(
            composition=composition,
            energy=energy,
            chemsys=chemsys,
            grouped_entries=grouped,
            all_crystals=all_crystals
        )
        # Make sure every returned value is a standard Python type
        # (avoids numpy.bool_ and similar leaking into the response).
        safe_result = to_builtin(result)
        return {
            "status": "success",
            **safe_result
        }
    except Exception as e:
        logger.error(f"Error calculating decomposition energy: {e}")
        return {"status": "error", "message": str(e)}
@mcp.tool()
async def get_phase_diagram(
    elements: str
) -> Dict[str, Any]:
    """
    Build and analyze the phase diagram for a chemical system.

    Args:
        elements: Comma or dash separated list of elements (e.g., "Li,Fe,P,O" or "Li-Fe-P-O");
            whitespace also acts as a separator

    Returns:
        ``{"status": "success", **result}`` where ``result`` is whatever
        ``build_phase_diagram`` returns (described as stable and unstable
        entries), or ``{"status": "error", "message": <text>}`` on failure.
    """
    try:
        import re
        dm = get_data_manager(DATA_DIR)
        crystals = dm.load_all_crystals()
        grouped = dm.get_grouped_entries()
        # Split on any run of whitespace, commas or dashes; drop empty tokens
        # produced by leading/trailing separators.
        symbols = [
            tok.strip()
            for tok in re.split(r'[\s,\-]+', elements)
            if tok.strip()
        ]
        diagram = build_phase_diagram(
            chemsys=symbols,
            grouped_entries=grouped,
            all_crystals=crystals
        )
        return {"status": "success", **diagram}
    except Exception as e:
        logger.error(f"Error building phase diagram: {e}")
        return {"status": "error", "message": str(e)}
@mcp.tool()
async def calculate_air_stability(
    composition: str,
    energy: float,
    temperature: float = 300.0,
    oxygen_pressure: float = 21200.0
) -> Dict[str, Any]:
    """
    Calculate the air stability of a material.

    Analyzes stability with respect to:
    - Oxygen (via grand potential phase diagram)
    - Carbon dioxide (CO2 reactivity)
    - Water (H2O reactivity)

    Args:
        composition: Chemical composition (e.g., "Li3N", "NaCl")
        energy: Total corrected energy in eV
        temperature: Temperature in Kelvin (default: 300K)
        oxygen_pressure: Oxygen partial pressure in Pa (default: 21200 Pa, ambient)

    Returns:
        ``{"status": "success", **result}`` where ``result`` is the mapping
        returned by ``compute_air_stability``, or
        ``{"status": "error", "message": <text>}`` on failure.
    """
    try:
        # Import from the defining subpackage. The pymatgen root package is a
        # namespace package, so `import pymatgen as mg; mg.core...` only
        # works when some other module happens to have imported pymatgen.core.
        from pymatgen.core import Composition

        dm = get_data_manager(DATA_DIR)
        all_crystals = dm.load_all_crystals()
        grouped = dm.get_grouped_entries()
        # Chemical system = element symbols present in the composition.
        comp = Composition(composition)
        chemsys = [str(el) for el in comp.elements]
        result = compute_air_stability(
            composition=composition,
            energy=energy,
            chemsys=chemsys,
            grouped_entries=grouped,
            all_crystals=all_crystals,
            temperature=temperature,
            oxygen_pressure=oxygen_pressure
        )
        return {
            "status": "success",
            **result
        }
    except Exception as e:
        logger.error(f"Error calculating air stability: {e}")
        return {"status": "error", "message": str(e)}
@mcp.tool()
async def compare_gnome_with_mp(
    elements: str
) -> Dict[str, Any]:
    """
    Compare GNoME phase diagram with Materials Project for a chemical system.

    The comparison itself is delegated to ``compare_with_materials_project``
    and is described as identifying:
    - New stable phases discovered by GNoME
    - Phases only in Materials Project
    - Common stable phases

    Args:
        elements: Comma or dash separated list of elements

    Returns:
        ``{"status": "success", **result}`` with the comparison mapping, or
        ``{"status": "error", "message": <text>}`` on failure.
    """
    try:
        import re
        dm = get_data_manager(DATA_DIR)
        # Gather both datasets plus the pre-grouped GNoME entries.
        gnome_all = dm.load_all_crystals()
        mp_all = dm.load_mp_crystals()
        gnome_by_system = dm.get_grouped_entries()
        # Group a trimmed-down view of the MP table by chemical system; only
        # these columns are kept for the grouped entries.
        mp_by_system = mp_all[[
            'Composition', 'NSites', 'Corrected Energy',
            'Formation Energy Per Atom', 'Chemical System'
        ]].groupby('Chemical System')
        # Split on whitespace, commas or dashes and discard empty tokens.
        symbols = [
            tok.strip()
            for tok in re.split(r'[\s,\-]+', elements)
            if tok.strip()
        ]
        comparison = compare_with_materials_project(
            chemsys=symbols,
            grouped_entries=gnome_by_system,
            mp_grouped_entries=mp_by_system,
            all_crystals=gnome_all,
            mp_crystals=mp_all
        )
        return {"status": "success", **comparison}
    except Exception as e:
        logger.error(f"Error comparing with MP: {e}")
        return {"status": "error", "message": str(e)}
@mcp.tool()
async def find_competing_phases_for_composition(
    composition: str,
    n_phases: int = 5
) -> Dict[str, Any]:
    """
    Find competing phases for a given composition.

    Identifies the most thermodynamically favorable phases in the same
    chemical space that compete with the given composition.

    Args:
        composition: Chemical composition to analyze
        n_phases: Number of competing phases to return (default: 5)

    Returns:
        ``{"status": "success", "composition": ..., "chemical_system": [...],
        "competing_phases": ...}`` on success, or
        ``{"status": "error", "message": <text>}`` on failure.
    """
    try:
        # Import from the defining subpackage. The pymatgen root package is a
        # namespace package, so `import pymatgen as mg; mg.core...` only
        # works when some other module happens to have imported pymatgen.core.
        from pymatgen.core import Composition

        dm = get_data_manager(DATA_DIR)
        all_crystals = dm.load_all_crystals()
        grouped = dm.get_grouped_entries()
        # Chemical system = element symbols present in the composition.
        comp = Composition(composition)
        chemsys = [str(el) for el in comp.elements]
        phases = find_competing_phases(
            composition=composition,
            chemsys=chemsys,
            grouped_entries=grouped,
            all_crystals=all_crystals,
            n_phases=n_phases
        )
        return {
            "status": "success",
            "composition": composition,
            "chemical_system": chemsys,
            "competing_phases": phases
        }
    except Exception as e:
        logger.error(f"Error finding competing phases: {e}")
        return {"status": "error", "message": str(e)}
# ============================================================================
# Structure Tools
# ============================================================================
@mcp.tool()
async def get_structure(
    reduced_formula: str,
    output_format: str = "json"
) -> Dict[str, Any]:
    """
    Get crystal structure for a given reduced formula.

    Args:
        reduced_formula: Reduced chemical formula (e.g., "LiFePO4", "TiO2")
        output_format: Output format - "json", "cif", or "poscar". Matching
            is case-insensitive and ignores surrounding whitespace; any other
            value falls back to "json".

    Returns:
        ``{"status": "success", "format": <fmt>, "data": <payload>}`` where
        the payload is a CIF/POSCAR string or, for JSON, a dict with formula,
        lattice parameters, per-site species/coordinates, site count and
        space group symbol. ``{"status": "error", "message": <text>}`` on
        failure.
    """
    try:
        dm = get_data_manager(DATA_DIR)
        # load_structure returns a pair; only the second item is used here.
        _, structure = dm.load_structure(reduced_formula)
        # Normalise the requested format so "CIF" or " poscar " are honoured
        # instead of silently falling through to the JSON branch.
        fmt = output_format.strip().lower() if isinstance(output_format, str) else "json"
        if fmt in ("cif", "poscar"):
            return {
                "status": "success",
                "format": fmt,
                "data": structure.to(fmt=fmt)
            }
        # Default / fallback: structured JSON description.
        lattice = structure.lattice
        return {
            "status": "success",
            "format": "json",
            "data": {
                "formula": structure.formula,
                "reduced_formula": structure.composition.reduced_formula,
                "lattice": {
                    "a": lattice.a,
                    "b": lattice.b,
                    "c": lattice.c,
                    "alpha": lattice.alpha,
                    "beta": lattice.beta,
                    "gamma": lattice.gamma,
                    "volume": lattice.volume,
                    "matrix": lattice.matrix.tolist()
                },
                "sites": [
                    {
                        "species": str(site.specie),
                        "coords": site.frac_coords.tolist(),
                        "cart_coords": site.coords.tolist()
                    }
                    for site in structure.sites
                ],
                "n_sites": len(structure),
                # First element of the (symbol, number) pair.
                "space_group": structure.get_space_group_info()[0]
            }
        }
    except Exception as e:
        logger.error(f"Error getting structure: {e}")
        return {"status": "error", "message": str(e)}
@mcp.tool()
async def compare_structures(
    formula1: str,
    formula2: str,
    ltol: float = 0.2,
    stol: float = 0.3,
    angle_tol: float = 5.0
) -> Dict[str, Any]:
    """
    Compare two crystal structures from the GNoME dataset.

    Uses pymatgen's StructureMatcher to determine if structures are equivalent.

    Args:
        formula1: First reduced formula
        formula2: Second reduced formula
        ltol: Length tolerance for matching
        stol: Site tolerance for matching
        angle_tol: Angle tolerance in degrees

    Returns:
        ``{"status": "success", "formula1", "formula2", "structures_match",
        "rms_dist", "max_dist", "tolerances"}`` on success; ``rms_dist`` and
        ``max_dist`` are ``None`` when the matcher reports no distance.
        ``{"status": "error", "message": <text>}`` on failure.
    """
    try:
        dm = get_data_manager(DATA_DIR)
        _, structure1 = dm.load_structure(formula1)
        _, structure2 = dm.load_structure(formula2)
        matcher = StructureMatcher(ltol=ltol, stol=stol, angle_tol=angle_tol)
        # The matcher works on numpy values, so its results may be numpy
        # scalars (np.bool_, np.float64). Cast to builtins so the response
        # contains plain bool/float values.
        is_match = bool(matcher.fit(structure1, structure2))
        rms_result = matcher.get_rms_dist(structure1, structure2)
        rms_dist = float(rms_result[0]) if rms_result else None
        max_dist = float(rms_result[1]) if rms_result else None
        return {
            "status": "success",
            "formula1": formula1,
            "formula2": formula2,
            "structures_match": is_match,
            "rms_dist": rms_dist,
            "max_dist": max_dist,
            "tolerances": {
                "ltol": ltol,
                "stol": stol,
                "angle_tol": angle_tol
            }
        }
    except Exception as e:
        logger.error(f"Error comparing structures: {e}")
        return {"status": "error", "message": str(e)}
# ============================================================================
# r²SCAN Validation Tools
# ============================================================================
@mcp.tool()
async def get_r2scan_data(
    composition: Optional[str] = None,
    limit: int = 50
) -> Dict[str, Any]:
    """
    Get r²SCAN validation data for materials.

    r²SCAN is a more accurate DFT functional used to validate GNoME predictions.

    Args:
        composition: Optional composition filter (exact match on the
            'Composition' column)
        limit: Maximum number of results

    Returns:
        ``{"status": "success", "count": <n>, "data": [<record>, ...]}`` with
        at most ``limit`` rows, or ``{"status": "error", "message": <text>}``
        on failure.
    """
    try:
        table = get_data_manager(DATA_DIR).load_r2scan_crystals()
        if composition:
            table = table[table['Composition'] == composition]
        # Truncate first, then convert rows to plain dicts.
        records = table.head(limit).to_dict(orient='records')
        return {"status": "success", "count": len(records), "data": records}
    except Exception as e:
        logger.error(f"Error getting r2scan data: {e}")
        return {"status": "error", "message": str(e)}
# ============================================================================
# a2c Crystal Structure Prediction Tools
# ============================================================================
@mcp.tool()
async def get_a2c_supporting_data() -> Dict[str, Any]:
    """
    Get a2c (amorphous-to-crystalline) structure prediction supporting data.

    The a2c pipeline discovers crystal structures by relaxing amorphous
    configurations using GNoME force fields.

    Returns:
        ``{"status": "success", "num_campaigns": <n>, "campaigns": [...]}``
        with one summary per top-level key of the a2c data (key, whether an
        amorphous structure is present, and the sizes of the initial-structure
        and match lists), or ``{"status": "error", "message": <text>}``.
    """
    try:
        a2c_data = get_data_manager(DATA_DIR).load_a2c_data()
        # Missing lists count as empty.
        summaries = [
            {
                "chemical_system": system,
                "has_amorphous_structure": "amorphous_structure" in entry,
                "num_initial_structures": len(entry.get("a2c_initial_structures", [])),
                "num_matches": len(entry.get("a2c_match_after_relax_example", []))
            }
            for system, entry in a2c_data.items()
        ]
        return {
            "status": "success",
            "num_campaigns": len(summaries),
            "campaigns": summaries
        }
    except Exception as e:
        logger.error(f"Error getting a2c data: {e}")
        return {"status": "error", "message": str(e)}
@mcp.tool()
async def get_a2c_campaign_details(
    chemical_system: str
) -> Dict[str, Any]:
    """
    Get detailed data for a specific a2c campaign.

    Args:
        chemical_system: Chemical system name (e.g., "Al2O3", "SiO2"); must be
            a key of the loaded a2c data

    Returns:
        ``{"status": "success", "chemical_system", "amorphous_structure",
        "num_initial_structures", "matches"}`` where ``amorphous_structure``
        is a preview of at most 500 characters (suffixed with "..." only when
        the original was longer) and ``matches`` summarises each relaxed-match
        record. ``{"status": "error", "message": <text>}`` for an unknown
        system or any failure.
    """
    try:
        dm = get_data_manager(DATA_DIR)
        a2c_data = dm.load_a2c_data()
        if chemical_system not in a2c_data:
            return {
                "status": "error",
                "message": f"Chemical system {chemical_system} not found in a2c data"
            }
        data = a2c_data[chemical_system]
        matches = [
            {
                "index": match.get("index_in_a2c_initial_structures"),
                "formula": match.get("formula"),
                "has_ff_relaxed": "relaxed_ff" in match,
                "has_dft_relaxed": "relaxed_dft" in match
            }
            for match in data.get("a2c_match_after_relax_example", [])
        ]
        # Preview the amorphous structure. Only add the ellipsis when text
        # was actually cut off, so a short or missing structure is not
        # reported as truncated.
        # NOTE(review): assumes "amorphous_structure" is stored as a string.
        amorphous = data.get("amorphous_structure", "")
        preview = amorphous[:500]
        if len(amorphous) > 500:
            preview += "..."
        return {
            "status": "success",
            "chemical_system": chemical_system,
            "amorphous_structure": preview,
            "num_initial_structures": len(data.get("a2c_initial_structures", [])),
            "matches": matches
        }
    except Exception as e:
        logger.error(f"Error getting a2c campaign details: {e}")
        return {"status": "error", "message": str(e)}
# ============================================================================
# Model Information Tools
# ============================================================================
@mcp.tool()
async def get_model_configurations() -> Dict[str, Any]:
    """
    Report the default configurations for the NequIP and GNoME models.

    Returns:
        ``{"status": "success", "nequip_config": ..., "gnome_config": ...}``
        holding the values returned by ``get_nequip_default_config()`` and
        ``get_gnome_default_config()``. Exceptions from those helpers are not
        caught here.
    """
    nequip_defaults = get_nequip_default_config()
    gnome_defaults = get_gnome_default_config()
    return {
        "status": "success",
        "nequip_config": nequip_defaults,
        "gnome_config": gnome_defaults
    }
@mcp.tool()
async def list_available_models() -> Dict[str, Any]:
    """
    List available pre-trained models.

    Returns:
        ``{"status": "success", "num_models": <n>, "models": [<info>, ...]}``
        with one ``get_model_info`` entry per model name found under
        ``MODEL_DIR``, or ``{"status": "error", "message": <text>}``.
    """
    try:
        names = ModelLoader(MODEL_DIR).get_available_models()
        # Resolve the descriptive info for each discovered model name.
        details = [get_model_info(name, MODEL_DIR) for name in names]
        return {
            "status": "success",
            "num_models": len(names),
            "models": details
        }
    except Exception as e:
        logger.error(f"Error listing models: {e}")
        return {"status": "error", "message": str(e)}
# ============================================================================
# Utility Tools
# ============================================================================
@mcp.tool()
async def get_pseudopotential_corrections() -> Dict[str, Any]:
    """
    Get pseudopotential corrections for Materials Project compatibility.

    These corrections are needed when comparing energies between GNoME
    and Materials Project calculations.

    Returns:
        ``{"status": "success", "description": ..., "corrections": ...,
        "units": "eV/atom"}`` where ``corrections`` is the ``PP_CORRECTIONS``
        table from ``data_utils``.
    """
    # Imported lazily, at call time, as in the rest of this module's
    # function-scope imports.
    from data_utils import PP_CORRECTIONS
    return dict(
        status="success",
        description="Pseudopotential corrections for elements where GNoME and MP use different pseudopotentials",
        corrections=PP_CORRECTIONS,
        units="eV/atom",
    )
@mcp.tool()
async def download_dataset(
    include_structures: bool = False
) -> Dict[str, Any]:
    """
    Download the GNoME dataset files to the server.

    Fetches the summary data, the Materials Project snapshot and the r²SCAN
    data through the data manager, and optionally the "by_reduced_formula"
    structure archive. Files are handled by the data manager configured with
    the server's data directory (``DATA_DIR``).

    Args:
        include_structures: Whether to also download structure archives (~GB)

    Returns:
        ``{"status": "success", "downloaded_files": [<path>, ...],
        "data_directory": <DATA_DIR>}`` on success, or
        ``{"status": "error", "message": <text>}`` if any download fails.
    """
    try:
        dm = get_data_manager(DATA_DIR)
        # Same order as reported: summaries, MP snapshot, r2scan, then the
        # optional structure archive.
        gnome_path, external_path = dm.download_summary_data()
        mp_path = dm.download_mp_snapshot()
        r2scan_path = dm.download_r2scan_data()
        fetched = [gnome_path, external_path, mp_path, r2scan_path]
        if include_structures:
            fetched.append(dm.download_structure_archive("by_reduced_formula"))
        return {
            "status": "success",
            "downloaded_files": [str(path) for path in fetched],
            "data_directory": DATA_DIR
        }
    except Exception as e:
        logger.error(f"Error downloading dataset: {e}")
        return {"status": "error", "message": str(e)}
@mcp.tool()
async def check_data_status() -> Dict[str, Any]:
    """
    Check the status of downloaded dataset files on the server.

    Returns information about which files are available and their sizes.
    Use this to verify data is properly downloaded before querying.

    Returns:
        ``{"status": "success", "data_directory", "core_data_ready",
        "total_size_mb", "files", "message"}``. ``files`` maps each dataset
        key to ``{"exists", "path", "size_mb"}``; ``core_data_ready`` is true
        when both the GNoME and external summary CSVs are present.
    """
    files_to_check = {
        "gnome_summary": "stable_materials_summary.csv",
        "external_summary": "external_materials_summary.csv",
        "mp_snapshot": "mp_snapshot_summary.csv",
        "r2scan": "stable_materials_r2scan.csv",
        "structures": "by_reduced_formula.zip",
        "a2c_data": "a2c_supporting_data.json"
    }
    file_status = {}
    total_size = 0
    for key, filename in files_to_check.items():
        filepath = os.path.join(DATA_DIR, filename)
        # Ask for the size directly instead of exists()-then-getsize(): a
        # file removed between the two calls would otherwise raise out of
        # this tool, which has no surrounding try/except.
        try:
            size = os.path.getsize(filepath)
        except OSError:
            file_status[key] = {
                "exists": False,
                "path": filepath,
                "size_mb": 0
            }
            continue
        total_size += size
        file_status[key] = {
            "exists": True,
            "path": filepath,
            "size_mb": round(size / (1024 * 1024), 2)
        }
    # Core data = the two summary CSVs.
    core_ready = (
        file_status["gnome_summary"]["exists"] and
        file_status["external_summary"]["exists"]
    )
    return {
        "status": "success",
        "data_directory": DATA_DIR,
        "core_data_ready": core_ready,
        "total_size_mb": round(total_size / (1024 * 1024), 2),
        "files": file_status,
        "message": "Core data is ready for queries" if core_ready else "Please call download_dataset() first"
    }
# ============================================================================
# Server Entry Point
# ============================================================================
if __name__ == "__main__":
    # Run the server with stdio transport for local testing; this branch only
    # executes when the module is run directly, not when it is imported.
    mcp.run()