Spaces:
Sleeping
Sleeping
| # Copyright 2024 Google LLC (Original code), Modified for MCP Service | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| """ | |
| Model utilities for GNoME Materials Discovery MCP Service. | |
| This module provides: | |
| - GNoME model architecture definitions | |
| - NequIP model architecture definitions | |
| - Model loading and inference utilities | |
| - Crystal graph construction | |
| """ | |
| import functools | |
| import json | |
| import os | |
| from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union | |
| import logging | |
# Loose aliases used purely for readability in signatures; none of them is
# checked at runtime.
Array = PyTree = Dtype = Any
Shape = Iterable[int]

# Width of the per-atom element one-hot vector used throughout this module.
NUM_ELEMENTS = 94

logger = logging.getLogger(__name__)
def get_nonlinearity_by_name(name: str) -> Callable:
    """
    Get nonlinearity function by name.

    Args:
        name: One of 'none' (identity), 'relu', 'raw_swish', 'tanh',
            'sigmoid' or 'silu'.

    Returns:
        The matching elementwise activation function.

    Raises:
        ImportError: If JAX is not installed.
        ValueError: If ``name`` is not one of the supported nonlinearities.
    """
    try:
        import jax
        import jax.numpy as jnp
    except ImportError as err:
        raise ImportError("JAX is required for model utilities") from err

    # flax.linen re-exports these same callables from jax.nn / jax.numpy
    # (nn.relu is jax.nn.relu, nn.tanh is jnp.tanh, ...), so Flax itself is
    # not needed to resolve a nonlinearity.
    nonlinearities = {
        'none': lambda x: x,
        'relu': jax.nn.relu,
        'raw_swish': jax.nn.swish,
        'tanh': jnp.tanh,
        'sigmoid': jax.nn.sigmoid,
        'silu': jax.nn.silu,
    }
    try:
        return nonlinearities[name]
    except KeyError:
        raise ValueError(f'Nonlinearity "{name}" not found.') from None
def create_bessel_embedding(count: int, inner_cutoff: float, outer_cutoff: float):
    """
    Create a Flax module embedding distances in a smoothly cut-off Bessel basis.

    For a distance ``r`` the module computes ``count`` values
    ``2 / r_c * sin(f_k * r / r_c) / r`` with ``r_c = outer_cutoff`` and
    learnable frequencies ``f_k`` initialised to ``k * pi`` (k = 1..count).
    These are multiplied by an envelope that is 1 below ``inner_cutoff``, 0 above
    ``outer_cutoff`` and a half-cosine ramp in between. Distances at or
    below 1e-5 map to zero.

    Args:
        count: Number of Bessel basis functions.
        inner_cutoff: Radius below which the envelope equals 1.
        outer_cutoff: Radius above which the envelope equals 0; also the
            ``r_c`` of the Bessel basis.

    Returns:
        An unbound ``BesselEmbedding`` Flax module. Applied to a 1-D array of
        distances it maps each distance to ``count`` values (vmapped over the
        leading axis).

    Raises:
        ImportError: If JAX or Flax is not installed.
    """
    try:
        import flax.linen as nn
        import jax.numpy as jnp
        from jax import vmap
    except ImportError as err:
        raise ImportError("JAX and Flax are required for Bessel embedding") from err

    f32 = jnp.float32

    def bessel(r_c, frequencies, r):
        # Substitute a harmless radius where r ~ 0 so neither the value nor
        # its gradient becomes NaN; the result is masked to 0 there anyway.
        rp = jnp.where(r > f32(1e-5), r, f32(1000.0))
        b = 2 / r_c * jnp.sin(frequencies * rp / r_c) / rp
        return jnp.where(r > f32(1e-5), b, 0)

    class BesselEmbedding(nn.Module):
        count: int
        inner_cutoff: float
        outer_cutoff: float

        # @nn.compact is required: Flax only allows self.param() inside
        # setup() or a compact method, otherwise init/apply raise.
        @nn.compact
        def __call__(self, rs):
            def init_fn(key, shape):
                del key  # deterministic initialisation: f_k = k * pi
                n = shape[0]
                return jnp.arange(1, n + 1) * jnp.pi

            frequencies = self.param('frequencies', init_fn, (self.count,))

            def envelope(r):
                """Smooth cutoff: 1 below inner, 0 above outer, cosine ramp between."""
                ramp = 0.5 * (1 + jnp.cos(
                    jnp.pi * (r - self.inner_cutoff)
                    / (self.outer_cutoff - self.inner_cutoff)))
                return jnp.where(
                    r < self.inner_cutoff, 1.0,
                    jnp.where(r > self.outer_cutoff, 0.0, ramp))

            def embed(r):
                return bessel(self.outer_cutoff, frequencies, r) * envelope(r)

            return vmap(embed)(rs)

    return BesselEmbedding(count, inner_cutoff, outer_cutoff)
def get_nequip_default_config() -> Dict[str, Any]:
    """
    Build the default hyper-parameters for a NequIP model.

    A brand-new dictionary (including the nested ``nonlinearities`` dict) is
    created on every call, so callers may mutate the result freely.

    Returns:
        Mapping of NequIP hyper-parameter names to their default values.
    """
    defaults = dict(
        graph_net_steps=5,
        nonlinearities=dict(e="raw_swish", o="tanh"),
        use_sc=True,
        n_elements=94,
        hidden_irreps="128x0e + 64x1e + 4x2e",
        sh_irreps="1x0e + 1x1e + 1x2e",
        num_basis=8,
        r_max=5.0,
        radial_net_nonlinearity="raw_swish",
        radial_net_n_hidden=64,
        radial_net_n_layers=2,
        n_neighbors=10.0,
        scalar_mlp_std=4.0,
    )
    return defaults
def get_gnome_default_config() -> Dict[str, Any]:
    """
    Build the default hyper-parameters for the GNoME crystal energy model.

    Each call returns a freshly built dictionary, so mutating it does not
    affect later calls.

    Returns:
        Mapping of GNoME hyper-parameter names to their default values.
    """
    defaults = dict(
        graph_net_steps=5,
        mlp_width=(128, 128, 64),
        mlp_nonlinearity="raw_swish",
        embedding_dim=128,
        featurizer="gaussian",
        shift=-1.6526496,
        scale=1.0,
        feature_band_limit=0,
        conditioning_band_limit=0,
        extra_scalars_for_gating=False,
        residual="none",
        node_aggregation="mean",
        edges_for_globals_aggregation="mean",
        readout_edges_for_globals_aggregation="mean",
    )
    return defaults
class ModelLoader:
    """
    Load GNoME/NequIP checkpoints from disk and cache them per model name.

    Each model lives in its own sub-directory of ``model_dir`` that holds a
    ``config.json`` and at least one file whose name contains ``checkpoint``.
    """

    def __init__(self, model_dir: str = "./models"):
        """
        Initialize ModelLoader.

        Args:
            model_dir: Directory whose sub-directories are model checkpoints.
        """
        self.model_dir = model_dir
        # model_name -> (model, params, config dict), filled by load_model.
        self._models: Dict[str, Any] = {}
        # model_name -> config dict of every successfully loaded model.
        self._configs: Dict[str, Dict] = {}

    @staticmethod
    def _read_config(config_path: str) -> Dict[str, Any]:
        """
        Read and decode a model ``config.json``.

        The loader historically expected a double-encoded file (a JSON string
        whose content is itself JSON); a plain JSON object is accepted too.

        Raises:
            OSError: If the file cannot be read.
            ValueError: If the content is not valid JSON or not a JSON object.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        if isinstance(config, str):
            config = json.loads(config)
        if not isinstance(config, dict):
            raise ValueError(f"Config at {config_path} is not a JSON object")
        return config

    def load_model(self, model_name: str) -> Tuple[Any, Any, Dict]:
        """
        Load a model from checkpoint, reusing the cached result when present.

        Args:
            model_name: Name of the model sub-directory inside ``model_dir``.

        Returns:
            Tuple of (model, params, config).

        Raises:
            ImportError: If JAX, Flax, jraph or ml_collections is missing.
            FileNotFoundError: If the config or a checkpoint file is missing.
            ValueError: If the config is malformed or names an unsupported
                model family.
            NotImplementedError: NequIP model construction is a placeholder.
        """
        if model_name in self._models:
            return self._models[model_name]

        try:
            import jax.numpy as jnp
            import jraph
            from flax import serialization
            from jax import ShapeDtypeStruct, eval_shape, random
            from jax.tree_util import tree_map
            from ml_collections import ConfigDict
        except ImportError as e:
            raise ImportError(f"Required packages not available: {e}") from e

        try:
            model_path = os.path.join(self.model_dir, model_name)
            config_path = os.path.join(model_path, 'config.json')
            if not os.path.exists(config_path):
                raise FileNotFoundError(f"Config not found at {config_path}")
            config = ConfigDict(self._read_config(config_path))

            model_family = config.get('model_family', 'nequip')
            if model_family != 'nequip':
                raise ValueError(f"Unsupported model family: {model_family}")
            model = self._create_nequip_model(config)

            f32 = jnp.float32
            i32 = jnp.int32
            # Shape-only stand-in for a one-node/one-edge graph: lets
            # eval_shape derive the parameter tree without real data.
            graph = jraph.GraphsTuple(
                nodes=ShapeDtypeStruct((1, NUM_ELEMENTS), f32),
                edges=ShapeDtypeStruct((1, 3), f32),
                receivers=ShapeDtypeStruct((1,), i32),
                senders=ShapeDtypeStruct((1,), i32),
                globals=ShapeDtypeStruct((1, 1), f32),
                n_node=ShapeDtypeStruct((1,), i32),
                n_edge=ShapeDtypeStruct((1,), i32),
            )

            # Sorted so the choice is deterministic when several files match;
            # os.listdir order is arbitrary.
            checkpoints = sorted(
                c for c in os.listdir(model_path) if 'checkpoint' in c
            )
            if not checkpoints:
                raise FileNotFoundError(f"No checkpoint found in {model_path}")
            checkpoint_path = os.path.join(model_path, checkpoints[0])

            def init_model(g):
                return model.init(random.PRNGKey(0), g)

            abstract_params = eval_shape(init_model, graph)
            with open(checkpoint_path, 'rb') as f:
                # Deserialised against a 3-tuple target whose middle element
                # is the params tree (presumably (step, params, opt_state) —
                # confirm against the checkpoint writer).
                ckpt = serialization.from_bytes((0, abstract_params, None), f.read())
            params = tree_map(lambda x: x.astype(f32), ckpt[1])
        except Exception as e:
            logger.error("Error loading model %s: %s", model_name, e)
            raise

        self._models[model_name] = (model, params, dict(config))
        self._configs[model_name] = dict(config)
        return model, params, dict(config)

    def _create_nequip_model(self, config: Any) -> Any:
        """
        Create a NequIP model from ``config``.

        Placeholder: always raises until a NequIP implementation is wired in.
        """
        raise NotImplementedError("NequIP model creation requires full JAX stack")

    def get_available_models(self) -> list:
        """
        List model names, i.e. the sub-directories of ``model_dir``, sorted.

        Returns:
            Sorted list of model names; empty when ``model_dir`` does not
            exist or is not a directory.
        """
        if not os.path.isdir(self.model_dir):
            return []
        return sorted(
            entry for entry in os.listdir(self.model_dir)
            if os.path.isdir(os.path.join(self.model_dir, entry))
        )
def atoms_to_graph(
    atoms: Any,
    cutoff: float = 5.0,
    max_neighbors: int = 100
) -> Dict[str, Any]:
    """
    Convert an ASE ``Atoms`` object into a graph dictionary.

    Every (i, j) pair returned by ``ase.neighborlist.neighbor_list`` within
    ``cutoff`` becomes an edge. ``i`` is the sender, ``j`` the receiver, and
    the displacement vector ``D`` is the edge feature.

    Args:
        atoms: ASE Atoms object.
        cutoff: Cutoff radius for neighbor finding.
        max_neighbors: Accepted for API compatibility but currently NOT
            enforced: all neighbors within ``cutoff`` are returned.

    Returns:
        Dict with keys ``nodes`` (one row per atom, one-hot over
        ``NUM_ELEMENTS`` columns with column ``Z - 1`` set), ``edges``,
        ``senders``, ``receivers``, ``n_node``, ``n_edge``, ``positions`` and
        ``cell``. Atoms whose atomic number lies outside 1..NUM_ELEMENTS get
        an all-zero row and a warning is logged.

    Raises:
        ImportError: If NumPy or ASE is not installed.
    """
    try:
        import numpy as np
        from ase.neighborlist import neighbor_list
    except ImportError as err:
        raise ImportError("ASE is required for atoms to graph conversion") from err

    # TODO: max_neighbors is not applied. Capping per centre atom would make
    # the edge list asymmetric, so agree on a policy before enforcing it.
    senders, receivers, displacements = neighbor_list('ijD', atoms, cutoff)

    atomic_numbers = np.asarray(atoms.get_atomic_numbers())
    n_atoms = len(atoms)

    node_features = np.zeros((n_atoms, NUM_ELEMENTS))
    # Only 1 <= Z <= NUM_ELEMENTS is encodable. Without the lower bound,
    # Z = 0 (ASE dummy atom) would index column -1, i.e. element 94.
    encodable = (atomic_numbers >= 1) & (atomic_numbers <= NUM_ELEMENTS)
    node_features[np.nonzero(encodable)[0], atomic_numbers[encodable] - 1] = 1.0
    n_unencodable = int(n_atoms - np.count_nonzero(encodable))
    if n_unencodable:
        logger.warning(
            "atoms_to_graph: %d atom(s) with atomic number outside 1..%d "
            "get an all-zero node feature", n_unencodable, NUM_ELEMENTS)

    return {
        "nodes": node_features,
        "edges": displacements,  # Displacement vectors (sender -> receiver)
        "senders": senders,
        "receivers": receivers,
        "n_node": np.array([n_atoms]),
        "n_edge": np.array([len(senders)]),
        "positions": atoms.get_positions(),
        "cell": atoms.get_cell()[:],
    }
def predict_energy(
    model: Any,
    params: Any,
    graph: Dict[str, Any]
) -> float:
    """
    Predict the energy of a single graph.

    Args:
        model: Object exposing ``apply(params, graphs_tuple)``.
        params: Model parameters passed straight to ``model.apply``.
        graph: Graph dictionary with ``nodes``, ``edges``, ``senders``,
            ``receivers``, ``n_node`` and ``n_edge``.

    Returns:
        Element ``[0, 0]`` of the model output as a Python float. Because of
        the ``float()`` conversion this function cannot be traced or
        differentiated by JAX.

    Raises:
        ImportError: If JAX or jraph is not installed. ImportErrors raised by
            the model itself propagate unchanged.
    """
    try:
        import jax.numpy as jnp
        import jraph
    except ImportError as err:
        raise ImportError("JAX and jraph are required for energy prediction") from err

    graph_tuple = jraph.GraphsTuple(
        nodes=jnp.array(graph["nodes"]),
        edges=jnp.array(graph["edges"]),
        senders=jnp.array(graph["senders"]),
        receivers=jnp.array(graph["receivers"]),
        globals=jnp.zeros((1, 1)),  # a single graph with an empty global slot
        n_node=jnp.array(graph["n_node"]),
        n_edge=jnp.array(graph["n_edge"]),
    )
    energy = model.apply(params, graph_tuple)
    return float(energy[0, 0])
def compute_forces(
    model: Any,
    params: Any,
    graph: Dict[str, Any]
) -> Any:
    """
    Compute forces as the negative gradient of the predicted energy.

    The gradient is taken with respect to ``graph["positions"]``. To make the
    energy depend on the positions, edge vectors are rebuilt inside the
    differentiated function as
    ``positions[receivers] - positions[senders] + offset``. Here ``offset``
    is recovered once from the stored ``edges`` and held constant; it is the
    periodic-image shift when edges are sender->receiver displacement vectors.

    Args:
        model: Object exposing ``apply(params, graphs_tuple)`` whose output
            element ``[0, 0]`` is the energy.
        params: Model parameters passed straight to ``model.apply``.
        graph: Graph dictionary with ``nodes``, ``edges``, ``senders``,
            ``receivers``, ``n_node``, ``n_edge`` and ``positions``.

    Returns:
        JAX array of forces with the same shape as ``graph["positions"]``.

    Raises:
        ImportError: If JAX or jraph is not installed.
    """
    try:
        import jax
        import jax.numpy as jnp
        import jraph
    except ImportError as err:
        raise ImportError("JAX and jraph are required for force computation") from err

    positions0 = jnp.asarray(graph["positions"])
    senders = jnp.asarray(graph["senders"])
    receivers = jnp.asarray(graph["receivers"])
    nodes = jnp.asarray(graph["nodes"])
    n_node = jnp.asarray(graph["n_node"])
    n_edge = jnp.asarray(graph["n_edge"])
    # Constant per-edge shift (e.g. lattice image offset), independent of
    # the positions being differentiated.
    offsets = jnp.asarray(graph["edges"]) - (positions0[receivers] - positions0[senders])

    def energy_fn(positions):
        edges = positions[receivers] - positions[senders] + offsets
        graph_tuple = jraph.GraphsTuple(
            nodes=nodes,
            edges=edges,
            senders=senders,
            receivers=receivers,
            globals=jnp.zeros((1, 1)),
            n_node=n_node,
            n_edge=n_edge,
        )
        # Return a JAX scalar: float() on a tracer would fail under jax.grad.
        return model.apply(params, graph_tuple)[0, 0]

    return -jax.grad(energy_fn)(positions0)
def get_model_info(model_name: str, model_dir: str = "./models") -> Dict[str, Any]:
    """
    Summarise a model's config without loading weights or importing JAX.

    Args:
        model_name: Name of the model sub-directory.
        model_dir: Directory containing models.

    Returns:
        Dict with ``model_name``, ``model_family``, ``graph_net_steps``,
        ``hidden_irreps``, ``r_max`` and ``n_elements``. Returns
        ``{"error": message}`` instead of raising when the config is missing
        or cannot be read or parsed.
    """
    config_path = os.path.join(model_dir, model_name, 'config.json')
    if not os.path.exists(config_path):
        return {"error": f"Model {model_name} not found"}

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        # Configs may be double-encoded (a JSON string holding JSON); plain
        # JSON objects are accepted as well.
        if isinstance(config, str):
            config = json.loads(config)
        if not isinstance(config, dict):
            raise ValueError(f"Config at {config_path} is not a JSON object")
        return {
            "model_name": model_name,
            "model_family": config.get("model_family", "unknown"),
            "graph_net_steps": config.get("graph_net_steps"),
            "hidden_irreps": config.get("hidden_irreps"),
            "r_max": config.get("r_max"),
            "n_elements": config.get("n_elements", NUM_ELEMENTS),
        }
    except Exception as e:  # deliberate best-effort: report, don't raise
        return {"error": str(e)}
class StructureMatcher:
    """
    Compare crystal structures via pymatgen's ``StructureMatcher``.

    Tolerances are stored on the instance, and a pymatgen matcher is built
    from them on each comparison. pymatgen is imported lazily, so it is only
    required once a comparison is actually made.
    """

    def __init__(
        self,
        ltol: float = 0.2,
        stol: float = 0.3,
        angle_tol: float = 5.0
    ):
        """
        Initialize StructureMatcher.

        Args:
            ltol: Length tolerance.
            stol: Site tolerance.
            angle_tol: Angle tolerance in degrees.
        """
        self.ltol = ltol
        self.stol = stol
        self.angle_tol = angle_tol

    def _matcher(self) -> Any:
        """Build a pymatgen StructureMatcher configured with this instance's tolerances."""
        try:
            from pymatgen.analysis.structure_matcher import StructureMatcher as PmgMatcher
        except ImportError as err:
            raise ImportError("pymatgen is required for structure matching") from err
        return PmgMatcher(
            ltol=self.ltol,
            stol=self.stol,
            angle_tol=self.angle_tol
        )

    def fit(self, structure1: Any, structure2: Any) -> bool:
        """
        Check if two structures match.

        Args:
            structure1: First pymatgen Structure.
            structure2: Second pymatgen Structure.

        Returns:
            True if structures match.

        Raises:
            ImportError: If pymatgen is not installed.
        """
        return self._matcher().fit(structure1, structure2)

    def get_rms_dist(self, structure1: Any, structure2: Any) -> Optional[Tuple[float, float]]:
        """
        Get RMS distance between structures.

        Args:
            structure1: First pymatgen Structure.
            structure2: Second pymatgen Structure.

        Returns:
            Tuple of (rms_dist, max_dist) or None if no match.

        Raises:
            ImportError: If pymatgen is not installed.
        """
        return self._matcher().get_rms_dist(structure1, structure2)