# Copyright 2024 Google LLC (Original code), Modified for MCP Service
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
"""
Model utilities for GNoME Materials Discovery MCP Service.

This module provides:
- GNoME model architecture definitions
- NequIP model architecture definitions
- Model loading and inference utilities
- Crystal graph construction
"""

import functools
import json
import logging
import os
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union

logger = logging.getLogger(__name__)

# Type definitions
Array = Any
PyTree = Any
Shape = Iterable[int]
Dtype = Any

# Constants
NUM_ELEMENTS = 94


def _read_config_file(config_path: str) -> Dict[str, Any]:
    """
    Read a model ``config.json`` file.

    The file may hold either a plain JSON object or a JSON string that
    itself contains a JSON object (double-encoded); both are accepted.

    Args:
        config_path: Path to the configuration file

    Returns:
        Configuration dictionary

    Raises:
        ValueError: If the decoded content is not a JSON object
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # Double-encoded files decode to a string first; decode once more.
    if isinstance(config, str):
        config = json.loads(config)

    if not isinstance(config, dict):
        raise ValueError(
            f"Config at {config_path} is not a JSON object "
            f"(got {type(config).__name__})"
        )
    return config


def _graph_dict_to_graphs_tuple(graph: Dict[str, Any], edges: Any = None) -> Any:
    """
    Build a ``jraph.GraphsTuple`` from a graph dictionary.

    Args:
        graph: Dictionary with the keys "nodes", "edges", "senders",
            "receivers", "n_node" and "n_edge"
        edges: Optional replacement for ``graph["edges"]``

    Returns:
        GraphsTuple with a zero-filled ``(1, 1)`` globals array
    """
    import jax.numpy as jnp
    import jraph

    if edges is None:
        edges = jnp.array(graph["edges"])

    return jraph.GraphsTuple(
        nodes=jnp.array(graph["nodes"]),
        edges=edges,
        senders=jnp.array(graph["senders"]),
        receivers=jnp.array(graph["receivers"]),
        globals=jnp.zeros((1, 1)),
        n_node=jnp.array(graph["n_node"]),
        n_edge=jnp.array(graph["n_edge"]),
    )


def get_nonlinearity_by_name(name: str) -> Callable:
    """
    Get nonlinearity function by name.

    Args:
        name: Name of nonlinearity; one of 'none', 'relu', 'raw_swish',
            'tanh', 'sigmoid' or 'silu'

    Returns:
        Nonlinearity function

    Raises:
        ImportError: If Flax cannot be imported
        ValueError: If ``name`` is not a known nonlinearity
    """
    try:
        import flax.linen as nn
    except ImportError as err:
        raise ImportError("JAX and Flax are required for model utilities") from err

    nonlinearities = {
        'none': lambda x: x,
        'relu': nn.relu,
        'raw_swish': nn.swish,
        'tanh': nn.tanh,
        'sigmoid': nn.sigmoid,
        'silu': nn.silu,
    }

    if name in nonlinearities:
        return nonlinearities[name]
    raise ValueError(f'Nonlinearity "{name}" not found.')


def create_bessel_embedding(count: int, inner_cutoff: float, outer_cutoff: float):
    """
    Create Bessel embedding for radial functions.

    Args:
        count: Number of Bessel basis functions
        inner_cutoff: Inner cutoff radius
        outer_cutoff: Outer cutoff radius

    Returns:
        Bessel embedding module

    Raises:
        ImportError: If JAX or Flax cannot be imported
    """
    try:
        import jax.numpy as jnp
        import flax.linen as nn
        from jax import vmap
    except ImportError as err:
        raise ImportError("JAX and Flax are required for Bessel embedding") from err

    f32 = jnp.float32

    def bessel(r_c, frequencies, r):
        # Swap near-zero radii for a dummy value so the division is safe,
        # then zero those entries out in the result.
        rp = jnp.where(r > f32(1e-5), r, f32(1000.0))
        b = 2 / r_c * jnp.sin(frequencies * rp / r_c) / rp
        return jnp.where(r > f32(1e-5), b, 0)

    class BesselEmbedding(nn.Module):
        count: int
        inner_cutoff: float
        outer_cutoff: float

        @nn.compact
        def __call__(self, rs):
            def init_fn(key, shape):
                # Initial frequencies are pi, 2*pi, ..., n*pi.
                n = shape[0]
                return jnp.arange(1, n + 1) * jnp.pi

            frequencies = self.param('frequencies', init_fn, (self.count,))
            bessel_fn = functools.partial(bessel, self.outer_cutoff, frequencies)

            def apply_cutoff(fn, r):
                """Apply smooth cutoff."""
                # 1 below the inner cutoff, 0 above the outer cutoff and a
                # half-cosine ramp in between.
                ramp = 0.5 * (1 + jnp.cos(
                    jnp.pi * (r - self.inner_cutoff)
                    / (self.outer_cutoff - self.inner_cutoff)
                ))
                envelope = jnp.where(
                    r < self.inner_cutoff,
                    1.0,
                    jnp.where(r > self.outer_cutoff, 0.0, ramp),
                )
                return fn(r) * envelope

            return vmap(lambda r: apply_cutoff(bessel_fn, r))(rs)

    return BesselEmbedding(count, inner_cutoff, outer_cutoff)


def get_nequip_default_config() -> Dict[str, Any]:
    """
    Get default NequIP configuration.

    Returns:
        Default configuration dictionary
    """
    return {
        "graph_net_steps": 5,
        "nonlinearities": {"e": "raw_swish", "o": "tanh"},
        "use_sc": True,
        "n_elements": NUM_ELEMENTS,
        "hidden_irreps": "128x0e + 64x1e + 4x2e",
        "sh_irreps": "1x0e + 1x1e + 1x2e",
        "num_basis": 8,
        "r_max": 5.0,
        "radial_net_nonlinearity": "raw_swish",
        "radial_net_n_hidden": 64,
        "radial_net_n_layers": 2,
        "n_neighbors": 10.0,
        "scalar_mlp_std": 4.0,
    }


def get_gnome_default_config() -> Dict[str, Any]:
    """
    Get default GNoME crystal energy model configuration.

    Returns:
        Default configuration dictionary
    """
    return {
        "graph_net_steps": 5,
        "mlp_width": (128, 128, 64),
        "mlp_nonlinearity": "raw_swish",
        "embedding_dim": 128,
        "featurizer": "gaussian",
        "shift": -1.6526496,
        "scale": 1.0,
        "feature_band_limit": 0,
        "conditioning_band_limit": 0,
        "extra_scalars_for_gating": False,
        "residual": "none",
        "node_aggregation": "mean",
        "edges_for_globals_aggregation": "mean",
        "readout_edges_for_globals_aggregation": "mean",
    }


class ModelLoader:
    """Handles loading and caching of GNoME/NequIP models."""

    def __init__(self, model_dir: str = "./models"):
        """
        Initialize ModelLoader.

        Args:
            model_dir: Directory containing model checkpoints
        """
        self.model_dir = model_dir
        # Cache of (model, params, config) tuples keyed by model name.
        self._models: Dict[str, Any] = {}
        # Plain-dict copy of each loaded model's configuration.
        self._configs: Dict[str, Dict] = {}

    def load_model(self, model_name: str) -> Tuple[Any, Any, Dict]:
        """
        Load a model from checkpoint.

        A model that was already loaded is returned from the in-memory
        cache. Only the 'nequip' model family is recognised.

        Args:
            model_name: Name of the model to load

        Returns:
            Tuple of (model, params, config)

        Raises:
            ImportError: If a required package cannot be imported
            FileNotFoundError: If the config file or a checkpoint is missing
            ValueError: If the model family is not supported
        """
        if model_name in self._models:
            return self._models[model_name]

        try:
            import jax.numpy as jnp
            from jax import eval_shape, random
            from jax.tree_util import tree_map
            from jax.core import ShapedArray
            from flax import serialization
            import jraph
            from ml_collections import ConfigDict

            f32 = jnp.float32
            i32 = jnp.int32

            model_path = os.path.join(self.model_dir, model_name)

            # Load config
            config_path = os.path.join(model_path, 'config.json')
            if not os.path.exists(config_path):
                raise FileNotFoundError(f"Config not found at {config_path}")
            config = ConfigDict(_read_config_file(config_path))

            # Initialize model based on model family
            model_family = config.get('model_family', 'nequip')
            if model_family == 'nequip':
                model = self._create_nequip_model(config)
            else:
                raise ValueError(f"Unsupported model family: {model_family}")

            # Create abstract graph for initialization
            graph = jraph.GraphsTuple(
                nodes=ShapedArray((1, NUM_ELEMENTS), f32),
                edges=ShapedArray((1, 3), f32),
                receivers=ShapedArray((1,), i32),
                senders=ShapedArray((1,), i32),
                globals=ShapedArray((1, 1), f32),
                n_node=ShapedArray((1,), i32),
                n_edge=ShapedArray((1,), i32),
            )

            # Find checkpoint file; sorted so the choice does not depend
            # on directory listing order.
            checkpoints = sorted(
                c for c in os.listdir(model_path) if 'checkpoint' in c
            )
            if not checkpoints:
                raise FileNotFoundError(f"No checkpoint found in {model_path}")
            checkpoint_path = os.path.join(model_path, checkpoints[0])

            # Load parameters
            def init_model(graph):
                key = random.PRNGKey(0)
                return model.init(key, graph)

            # Only the parameter structure is needed as a restore target.
            abstract_params = eval_shape(init_model, graph)

            with open(checkpoint_path, 'rb') as f:
                ckpt_data = (0, abstract_params, None)
                ckpt = serialization.from_bytes(ckpt_data, f.read())

            params = tree_map(lambda x: x.astype(f32), ckpt[1])

            self._models[model_name] = (model, params, dict(config))
            self._configs[model_name] = dict(config)
            return model, params, dict(config)

        except ImportError as e:
            raise ImportError(f"Required packages not available: {e}") from e
        except Exception as e:
            logger.error("Error loading model %s: %s", model_name, e)
            raise

    def _create_nequip_model(self, config: Any) -> Any:
        """
        Create NequIP model from config.

        Raises:
            NotImplementedError: Always; this is a placeholder
        """
        # This is a placeholder - actual implementation would use the nequip module
        raise NotImplementedError("NequIP model creation requires full JAX stack")

    def get_available_models(self) -> list:
        """
        Get list of available models.

        Returns:
            List of model names (sub-directories of the model directory);
            empty if the model directory does not exist
        """
        if not os.path.exists(self.model_dir):
            return []
        return [
            d for d in os.listdir(self.model_dir)
            if os.path.isdir(os.path.join(self.model_dir, d))
        ]


def atoms_to_graph(
    atoms: Any,
    cutoff: float = 5.0,
    max_neighbors: int = 100
) -> Dict[str, Any]:
    """
    Convert ASE Atoms to graph representation.

    Args:
        atoms: ASE Atoms object
        cutoff: Cutoff radius for neighbor finding
        max_neighbors: Maximum number of neighbors per atom.
            NOTE: accepted for interface compatibility but not applied;
            every neighbor within ``cutoff`` is kept.

    Returns:
        Graph dictionary with the keys "nodes", "edges", "senders",
        "receivers", "n_node", "n_edge", "positions" and "cell"

    Raises:
        ImportError: If NumPy or ASE cannot be imported
    """
    try:
        import numpy as np
        from ase.neighborlist import neighbor_list
    except ImportError as err:
        raise ImportError("ASE is required for atoms to graph conversion") from err

    # Get neighbor list: pair indices and displacement vectors
    i, j, D = neighbor_list('ijD', atoms, cutoff)

    # Get atomic numbers and one-hot encode
    atomic_numbers = np.asarray(atoms.get_atomic_numbers())
    n_atoms = len(atoms)

    # Column z - 1 is set for atomic number z. Numbers outside
    # 1..NUM_ELEMENTS leave the row all-zero; without the lower bound a
    # value of 0 would index column -1, i.e. the last element.
    node_features = np.zeros((n_atoms, NUM_ELEMENTS))
    valid = (atomic_numbers >= 1) & (atomic_numbers <= NUM_ELEMENTS)
    node_features[np.nonzero(valid)[0], atomic_numbers[valid] - 1] = 1.0

    return {
        "nodes": node_features,
        "edges": D,  # Displacement vectors
        "senders": i,
        "receivers": j,
        "n_node": np.array([n_atoms]),
        "n_edge": np.array([len(i)]),
        "positions": atoms.get_positions(),
        "cell": atoms.get_cell()[:],
    }


def predict_energy(
    model: Any,
    params: Any,
    graph: Dict[str, Any]
) -> float:
    """
    Predict energy for a given graph.

    Args:
        model: Model instance
        params: Model parameters
        graph: Graph dictionary

    Returns:
        Predicted energy, read from element ``[0, 0]`` of the model output

    Raises:
        ImportError: If JAX or jraph cannot be imported
    """
    try:
        # Convert to jraph GraphsTuple
        graph_tuple = _graph_dict_to_graphs_tuple(graph)
        energy = model.apply(params, graph_tuple)
        return float(energy[0, 0])
    except ImportError as err:
        raise ImportError("JAX and jraph are required for energy prediction") from err


def compute_forces(
    model: Any,
    params: Any,
    graph: Dict[str, Any]
) -> Any:
    """
    Compute forces for a given graph.

    Forces are the negative gradient of the predicted energy with respect
    to atomic positions. The neighbor topology (senders/receivers) is held
    fixed; only the edge vectors are recomputed from the positions.

    Args:
        model: Model instance
        params: Model parameters
        graph: Graph dictionary; must contain "positions" in addition to
            the keys needed for energy prediction. Edges are expected to
            be per-edge vectors with the same trailing dimension as the
            positions (as produced by ``atoms_to_graph``).

    Returns:
        Forces array with the same shape as ``graph["positions"]``

    Raises:
        ImportError: If JAX or jraph cannot be imported
    """
    try:
        import jax
        import jax.numpy as jnp

        senders = jnp.asarray(graph["senders"])
        receivers = jnp.asarray(graph["receivers"])
        positions = jnp.asarray(graph["positions"])
        edges = jnp.asarray(graph["edges"])

        # Constant per-edge offset (periodic image shift) such that
        # edge = pos[receiver] - pos[sender] + offset at the input
        # positions.
        image_offsets = edges - (positions[receivers] - positions[senders])

        def energy_fn(pos):
            # Rebuild edges from positions so the energy actually depends
            # on them, and keep the result a traced scalar (no float()).
            displaced = pos[receivers] - pos[senders] + image_offsets
            graph_tuple = _graph_dict_to_graphs_tuple(graph, edges=displaced)
            return model.apply(params, graph_tuple)[0, 0]

        # Compute negative gradient of energy
        return -jax.grad(energy_fn)(positions)

    except ImportError as err:
        raise ImportError("JAX is required for force computation") from err


def get_model_info(model_name: str, model_dir: str = "./models") -> Dict[str, Any]:
    """
    Get information about a model without loading it.

    Args:
        model_name: Name of the model
        model_dir: Directory containing models

    Returns:
        Model information dictionary, or a dictionary with a single
        "error" key if the config is missing or cannot be read
    """
    model_path = os.path.join(model_dir, model_name)
    config_path = os.path.join(model_path, 'config.json')

    if not os.path.exists(config_path):
        return {"error": f"Model {model_name} not found"}

    try:
        config = _read_config_file(config_path)
        return {
            "model_name": model_name,
            "model_family": config.get("model_family", "unknown"),
            "graph_net_steps": config.get("graph_net_steps"),
            "hidden_irreps": config.get("hidden_irreps"),
            "r_max": config.get("r_max"),
            "n_elements": config.get("n_elements", NUM_ELEMENTS),
        }
    except Exception as e:
        # Best-effort lookup: report the problem instead of raising.
        return {"error": str(e)}


class StructureMatcher:
    """Utility class for comparing crystal structures."""

    def __init__(
        self,
        ltol: float = 0.2,
        stol: float = 0.3,
        angle_tol: float = 5.0
    ):
        """
        Initialize StructureMatcher.

        Args:
            ltol: Length tolerance
            stol: Site tolerance
            angle_tol: Angle tolerance in degrees
        """
        self.ltol = ltol
        self.stol = stol
        self.angle_tol = angle_tol

    def _build_matcher(self) -> Any:
        """Create a pymatgen StructureMatcher with this object's tolerances."""
        try:
            from pymatgen.analysis.structure_matcher import (
                StructureMatcher as PmgMatcher,
            )
        except ImportError as err:
            raise ImportError("pymatgen is required for structure matching") from err

        return PmgMatcher(
            ltol=self.ltol,
            stol=self.stol,
            angle_tol=self.angle_tol
        )

    def fit(self, structure1: Any, structure2: Any) -> bool:
        """
        Check if two structures match.

        Args:
            structure1: First pymatgen Structure
            structure2: Second pymatgen Structure

        Returns:
            True if structures match

        Raises:
            ImportError: If pymatgen cannot be imported
        """
        try:
            return self._build_matcher().fit(structure1, structure2)
        except ImportError as err:
            raise ImportError("pymatgen is required for structure matching") from err

    def get_rms_dist(self, structure1: Any, structure2: Any) -> Optional[Tuple[float, float]]:
        """
        Get RMS distance between structures.

        Args:
            structure1: First pymatgen Structure
            structure2: Second pymatgen Structure

        Returns:
            Tuple of (rms_dist, max_dist) or None if no match

        Raises:
            ImportError: If pymatgen cannot be imported
        """
        try:
            return self._build_matcher().get_rms_dist(structure1, structure2)
        except ImportError as err:
            raise ImportError("pymatgen is required for structure matching") from err