Spaces:
Sleeping
Sleeping
| # Copyright 2024 Google LLC (Original code), Modified for MCP Service | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| """ | |
| Data utilities for GNoME Materials Discovery MCP Service. | |
| This module handles: | |
| - Dataset downloading from Google Cloud Storage | |
| - Data preprocessing and caching | |
| - Crystal structure loading | |
| """ | |
| import os | |
| import json | |
| import tempfile | |
| import shutil | |
| import zipfile | |
| from typing import Optional, List, Tuple, Dict, Any | |
| from pathlib import Path | |
| import pandas as pd | |
| import requests | |
| import logging | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Remote storage layout: URLs are built as
#   {PUBLIC_LINK}{BUCKET_NAME}/{folder}/{filename}
# ---------------------------------------------------------------------------
PUBLIC_LINK: str = "https://storage.googleapis.com/"
BUCKET_NAME: str = "gdm_materials_discovery"
FOLDER_NAME: str = "gnome_data"
EXTERNAL_FOLDER_NAME: str = "external_data"

# ---------------------------------------------------------------------------
# Known data files, grouped by category.
# ---------------------------------------------------------------------------
GNOME_FILES: Tuple[str, ...] = (
    "stable_materials_summary.csv",
    "stable_materials_r2scan.csv",
)
EXTERNAL_FILES: Tuple[str, ...] = (
    "mp_snapshot_summary.csv",
    "external_materials_summary.csv",
)
STRUCTURE_FILES: Tuple[str, ...] = (
    "by_composition.zip",
    "by_id.zip",
    "by_reduced_formula.zip",
)
AUXILIARY_FILES: Tuple[str, ...] = ("a2c_supporting_data.json",)

# Pseudopotential corrections for MP compatibility, keyed by element symbol.
PP_CORRECTIONS: Dict[str, float] = dict(
    Ga=-0.0028805,
    Ge=0.10417085,
    Li=-0.00301278,
    Mg=0.0924014,
    Na=-0.00447437,
)

# Default data directory - must match Dockerfile ENV.
# The GNOME_DATA_DIR environment variable takes precedence when set.
DEFAULT_DATA_DIR: str = os.environ.get("GNOME_DATA_DIR", "/app/gnome_data")
class DataManager:
    """Manages GNoME dataset downloading and caching.

    Files are downloaded on demand into ``data_dir`` and reused on later
    calls. Parsed DataFrames and the structure archive handle are cached on
    the instance; call :meth:`close` to release the archive handle.
    """

    # (connect, read) timeouts in seconds for HTTP downloads. The read timeout
    # bounds the wait for each chunk, not the duration of the whole transfer.
    DOWNLOAD_TIMEOUT = (10, 60)

    def __init__(self, data_dir: Optional[str] = None):
        """
        Initialize DataManager.

        Args:
            data_dir: Directory to store downloaded data (defaults to
                GNOME_DATA_DIR env var). Created if it does not exist.
        """
        if data_dir is None:
            data_dir = DEFAULT_DATA_DIR
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)

        # Cached dataframes / handles, populated lazily by the load_* and
        # get_* methods below.
        self._gnome_crystals: Optional[pd.DataFrame] = None
        self._reference_crystals: Optional[pd.DataFrame] = None
        self._mp_crystals: Optional[pd.DataFrame] = None
        self._r2scan_crystals: Optional[pd.DataFrame] = None
        self._all_crystals: Optional[pd.DataFrame] = None
        self._grouped_entries: Optional[pd.core.groupby.DataFrameGroupBy] = None
        self._structure_zip: Optional[zipfile.ZipFile] = None

    @staticmethod
    def _parse_elements(raw: Any) -> Tuple[Any, ...]:
        """Parse a stringified element list into a sorted tuple.

        The value is expected to look like ``"['Fe', 'O']"``; single quotes
        are swapped for double quotes so it can be read as JSON.

        Args:
            raw: Cell value from an 'Elements' column.

        Returns:
            Sorted tuple of the parsed elements, or ``()`` when the value is
            not a string or cannot be parsed.
        """
        try:
            return tuple(sorted(json.loads(raw.replace("'", '"'))))
        except (AttributeError, TypeError, ValueError):
            # AttributeError: non-string cell (e.g. NaN);
            # ValueError: invalid JSON; TypeError: unsortable / non-iterable.
            return ()

    def download_file(self, filename: str, folder: str = FOLDER_NAME) -> Path:
        """
        Download a file from Google Cloud Storage.

        The download is skipped when the target file already exists. Data is
        first written to a ``.part`` sibling and only renamed to the final
        name once the transfer completes, so an interrupted download is never
        mistaken for a cached file on the next call.

        Args:
            filename: Name of file to download
            folder: Folder in bucket

        Returns:
            Path to downloaded file

        Raises:
            Exception: Any error raised by the HTTP request or by writing the
                file is logged and re-raised unchanged.
        """
        url = f"{PUBLIC_LINK}{BUCKET_NAME}/{folder}/{filename}"
        output_path = self.data_dir / filename
        if output_path.exists():
            logger.info(f"File {filename} already exists, skipping download")
            return output_path

        logger.info(f"Downloading {filename} from {url}")
        partial_path = output_path.with_name(output_path.name + ".part")
        try:
            # The context manager releases the connection even on failure.
            with requests.get(
                url, stream=True, timeout=self.DOWNLOAD_TIMEOUT
            ) as response:
                response.raise_for_status()
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
            # Atomic on the same filesystem: the final name only ever refers
            # to a complete file.
            os.replace(partial_path, output_path)
        except Exception as e:
            logger.error(f"Failed to download {filename}: {e}")
            if partial_path.exists():
                partial_path.unlink()
            raise
        logger.info(f"Downloaded {filename} successfully")
        return output_path

    def download_summary_data(self) -> Tuple[Path, Path]:
        """
        Download the main summary CSV files.

        Returns:
            Tuple of paths to gnome and external summary files
        """
        gnome_path = self.download_file("stable_materials_summary.csv", FOLDER_NAME)
        external_path = self.download_file(
            "external_materials_summary.csv", EXTERNAL_FOLDER_NAME
        )
        return gnome_path, external_path

    def download_mp_snapshot(self) -> Path:
        """Download Materials Project snapshot."""
        return self.download_file("mp_snapshot_summary.csv", EXTERNAL_FOLDER_NAME)

    def download_r2scan_data(self) -> Path:
        """Download r2SCAN validation data."""
        return self.download_file("stable_materials_r2scan.csv", FOLDER_NAME)

    def download_structure_archive(self, archive_type: str = "by_reduced_formula") -> Path:
        """
        Download structure archive.

        Args:
            archive_type: One of 'by_composition', 'by_id', 'by_reduced_formula'

        Returns:
            Path to downloaded archive
        """
        filename = f"{archive_type}.zip"
        return self.download_file(filename, FOLDER_NAME)

    def download_a2c_data(self) -> Path:
        """Download a2c supporting data."""
        folder = f"{FOLDER_NAME}/auxiliary_gnome_data"
        return self.download_file("a2c_supporting_data.json", folder)

    def annotate_chemical_system(self, crystals: pd.DataFrame) -> pd.DataFrame:
        """
        Annotate dataframe with chemical system tuples.

        The input DataFrame is modified in place and also returned. Rows whose
        'Elements' value cannot be parsed get an empty tuple.

        Args:
            crystals: DataFrame with 'Elements' column

        Returns:
            DataFrame with 'Chemical System' column added
        """
        crystals['Chemical System'] = [
            self._parse_elements(raw) for raw in crystals['Elements']
        ]
        return crystals

    def load_gnome_crystals(self) -> pd.DataFrame:
        """
        Load and preprocess GNoME crystals dataframe.

        The result is cached on the instance after the first call.

        Returns:
            Preprocessed GNoME crystals DataFrame
        """
        if self._gnome_crystals is not None:
            return self._gnome_crystals
        gnome_path, _ = self.download_summary_data()
        self._gnome_crystals = pd.read_csv(gnome_path, index_col=0)
        self._gnome_crystals = self.annotate_chemical_system(self._gnome_crystals)
        return self._gnome_crystals

    def load_reference_crystals(self) -> pd.DataFrame:
        """
        Load and preprocess reference crystals dataframe.

        The result is cached on the instance after the first call.

        Returns:
            Preprocessed reference crystals DataFrame
        """
        if self._reference_crystals is not None:
            return self._reference_crystals
        _, external_path = self.download_summary_data()
        self._reference_crystals = pd.read_csv(external_path)
        self._reference_crystals = self.annotate_chemical_system(self._reference_crystals)
        return self._reference_crystals

    def load_mp_crystals(self) -> pd.DataFrame:
        """
        Load and preprocess Materials Project snapshot.

        The result is cached on the instance after the first call.

        Returns:
            Preprocessed MP crystals DataFrame
        """
        if self._mp_crystals is not None:
            return self._mp_crystals
        mp_path = self.download_mp_snapshot()
        self._mp_crystals = pd.read_csv(mp_path)
        self._mp_crystals = self.annotate_chemical_system(self._mp_crystals)
        return self._mp_crystals

    def load_r2scan_crystals(self) -> pd.DataFrame:
        """
        Load r2SCAN validation data.

        The result is cached on the instance after the first call. Unlike the
        other loaders, no 'Chemical System' column is added.

        Returns:
            r2SCAN crystals DataFrame
        """
        if self._r2scan_crystals is not None:
            return self._r2scan_crystals
        r2scan_path = self.download_r2scan_data()
        self._r2scan_crystals = pd.read_csv(r2scan_path)
        return self._r2scan_crystals

    def load_all_crystals(self) -> pd.DataFrame:
        """
        Load combined GNoME and reference crystals.

        The two frames are concatenated with a fresh integer index; the result
        is cached on the instance after the first call.

        Returns:
            Combined crystals DataFrame
        """
        if self._all_crystals is not None:
            return self._all_crystals
        gnome = self.load_gnome_crystals()
        reference = self.load_reference_crystals()
        self._all_crystals = pd.concat([gnome, reference], ignore_index=True)
        return self._all_crystals

    def get_grouped_entries(self) -> pd.core.groupby.DataFrameGroupBy:
        """
        Get entries grouped by chemical system.

        Only the columns needed downstream are kept before grouping; the
        group-by object is cached on the instance after the first call.

        Returns:
            Grouped DataFrame

        Raises:
            KeyError: If a required column is missing from the combined data.
        """
        if self._grouped_entries is not None:
            return self._grouped_entries
        all_crystals = self.load_all_crystals()
        required_columns = [
            'Composition', 'NSites', 'Corrected Energy',
            'Formation Energy Per Atom', 'Chemical System',
        ]
        minimal_entries = all_crystals[required_columns]
        self._grouped_entries = minimal_entries.groupby('Chemical System')
        return self._grouped_entries

    def get_structure_zip(self) -> zipfile.ZipFile:
        """
        Get zipfile handle for structure archive.

        The 'by_reduced_formula' archive is downloaded if needed and opened
        once; the handle stays open until :meth:`close` is called.

        Returns:
            ZipFile object for structure archive
        """
        if self._structure_zip is not None:
            return self._structure_zip
        archive_path = self.download_structure_archive("by_reduced_formula")
        self._structure_zip = zipfile.ZipFile(archive_path)
        return self._structure_zip

    def load_structure(self, reduced_formula: str) -> Tuple[Any, Any]:
        """
        Load crystal structure by reduced formula.

        The CIF member is extracted to a temporary directory because both
        readers are given a file path; the directory is removed before
        returning.

        Args:
            reduced_formula: Reduced formula of the structure

        Returns:
            Tuple of (ase.Atoms, pymatgen.Structure)

        Raises:
            ImportError: If ase or pymatgen is not installed.
            KeyError: If the archive has no entry for ``reduced_formula``.
        """
        try:
            import ase.io
            from pymatgen.core import Structure as PmgStructure
        except ImportError as err:
            raise ImportError(
                "ase and pymatgen are required for structure loading"
            ) from err

        archive = self.get_structure_zip()
        cif_name = f"{reduced_formula}.CIF"
        # Zip member names always use forward slashes, regardless of the host
        # OS, so the path must not be built with os.path.join.
        member_name = f"by_reduced_formula/{cif_name}"
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = os.path.join(temp_dir, cif_name)
            with archive.open(member_name) as src:
                with open(temp_path, 'wb') as dst:
                    shutil.copyfileobj(src, dst)
            atoms = ase.io.read(temp_path)
            structure = PmgStructure.from_file(temp_path)
        return atoms, structure

    def load_a2c_data(self) -> Dict[str, Any]:
        """
        Load a2c supporting data.

        Returns:
            Dictionary containing a2c data
        """
        a2c_path = self.download_a2c_data()
        with open(a2c_path, 'r') as f:
            return json.load(f)

    def query_by_composition(
        self,
        composition: Optional[str] = None,
        elements: Optional[List[str]] = None,
        space_group: Optional[int] = None,
        crystal_system: Optional[str] = None,
        min_bandgap: Optional[float] = None,
        max_bandgap: Optional[float] = None,
        max_decomposition_energy: Optional[float] = None,
        limit: int = 100
    ) -> pd.DataFrame:
        """
        Query crystals with various filters.

        Filters are combined with AND. Filters left as ``None`` (or otherwise
        falsy for composition / elements / space_group / crystal_system) are
        skipped. Bandgap and decomposition-energy filters are also skipped
        when the corresponding column is absent from the data.

        Args:
            composition: Exact composition to match
            elements: List of elements that must be present
            space_group: Space group number
            crystal_system: Crystal system name
            min_bandgap: Minimum bandgap value
            max_bandgap: Maximum bandgap value
            max_decomposition_energy: Maximum decomposition energy
            limit: Maximum number of results

        Returns:
            Filtered DataFrame
        """
        crystals = self.load_gnome_crystals()
        if composition:
            crystals = crystals[crystals['Composition'] == composition]
        if elements:
            # Reuse the pre-parsed 'Chemical System' tuples instead of
            # re-parsing the 'Elements' string for every row. Unparseable rows
            # hold an empty tuple and therefore never match.
            required = set(elements)
            # Explicit bool dtype keeps this a boolean mask even when no rows
            # are left to test.
            has_all = pd.Series(
                [required.issubset(system) for system in crystals['Chemical System']],
                index=crystals.index,
                dtype=bool,
            )
            crystals = crystals[has_all]
        if space_group:
            crystals = crystals[crystals['Space Group Number'] == space_group]
        if crystal_system:
            crystals = crystals[crystals['Crystal System'] == crystal_system]
        if min_bandgap is not None and 'Bandgap' in crystals.columns:
            crystals = crystals[crystals['Bandgap'] >= min_bandgap]
        if max_bandgap is not None and 'Bandgap' in crystals.columns:
            crystals = crystals[crystals['Bandgap'] <= max_bandgap]
        if max_decomposition_energy is not None:
            col = 'Decomposition Energy Per Atom'
            if col in crystals.columns:
                crystals = crystals[crystals[col] <= max_decomposition_energy]
        return crystals.head(limit)

    def get_crystal_by_id(self, material_id: str) -> Optional[pd.Series]:
        """
        Get crystal by MaterialId.

        Args:
            material_id: Unique material identifier

        Returns:
            Crystal data as Series (first match) or None if no row matches
        """
        crystals = self.load_gnome_crystals()
        matches = crystals[crystals['MaterialId'] == material_id]
        if matches.empty:
            return None
        return matches.iloc[0]

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get dataset statistics.

        Statistics that depend on an optional column fall back to ``None``
        (or ``{}`` for ``crystal_systems``) when that column is absent.

        Returns:
            Dictionary with statistics
        """
        crystals = self.load_gnome_crystals()
        columns = crystals.columns

        element_coverage = None
        if 'Elements' in columns:
            # Rows that cannot be parsed contribute nothing rather than
            # aborting the whole computation.
            seen_elements = set()
            for raw in crystals['Elements']:
                seen_elements.update(self._parse_elements(raw))
            element_coverage = len(seen_elements)

        stats = {
            "total_materials": len(crystals),
            "unique_compositions": crystals['Composition'].nunique(),
            "unique_reduced_formulas": (
                crystals['Reduced Formula'].nunique()
                if 'Reduced Formula' in columns else None
            ),
            "crystal_systems": (
                crystals['Crystal System'].value_counts().to_dict()
                if 'Crystal System' in columns else {}
            ),
            "space_groups_count": (
                crystals['Space Group Number'].nunique()
                if 'Space Group Number' in columns else None
            ),
            "avg_formation_energy": (
                crystals['Formation Energy Per Atom'].mean()
                if 'Formation Energy Per Atom' in columns else None
            ),
            "element_coverage": element_coverage,
        }
        return stats

    def close(self):
        """Close open file handles."""
        if self._structure_zip is not None:
            self._structure_zip.close()
            self._structure_zip = None
# Global data manager instance, created lazily by get_data_manager().
_data_manager: Optional[DataManager] = None


def get_data_manager(data_dir: Optional[str] = None) -> DataManager:
    """
    Get or create global DataManager instance.

    The instance is created on the first call and reused afterwards, so
    ``data_dir`` only takes effect on that first call. A later call that
    passes a different directory gets the existing instance and a warning
    is logged.

    Args:
        data_dir: Directory for data storage (defaults to GNOME_DATA_DIR env var)

    Returns:
        DataManager instance
    """
    global _data_manager
    if _data_manager is None:
        _data_manager = DataManager(data_dir)
    elif data_dir is not None and Path(data_dir) != _data_manager.data_dir:
        # Make the otherwise silent mismatch visible to the caller.
        logger.warning(
            "get_data_manager: ignoring data_dir=%s; existing instance uses %s",
            data_dir,
            _data_manager.data_dir,
        )
    return _data_manager