# Materials_discovery / model_utils.py
# Uploaded by SEUyishu ("Upload 16 files", revision 7f0fa00, verified)
# Copyright 2024 Google LLC (Original code), Modified for MCP Service
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
"""
Model utilities for GNoME Materials Discovery MCP Service.
This module provides:
- GNoME model architecture definitions
- NequIP model architecture definitions
- Model loading and inference utilities
- Crystal graph construction
"""
import functools
import json
import os
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
import logging
logger = logging.getLogger(__name__)
# Type definitions
Array = Any
PyTree = Any
Shape = Iterable[int]
Dtype = Any
# Constants
NUM_ELEMENTS = 94
def get_nonlinearity_by_name(name: str) -> Callable:
    """
    Get nonlinearity function by name.

    Args:
        name: Name of nonlinearity. The registered names are 'none', 'relu',
            'raw_swish', 'tanh', 'sigmoid' and 'silu'.

    Returns:
        Nonlinearity function

    Raises:
        ImportError: If Flax (and therefore JAX) cannot be imported.
        ValueError: If ``name`` is not one of the registered names.
    """
    # Only the import lives in the try block, so a missing dependency is the
    # only thing reported as such; the original cause is kept for debugging.
    try:
        import flax.linen as nn
    except ImportError as err:
        raise ImportError("JAX and Flax are required for model utilities") from err

    nonlinearities = {
        'none': lambda x: x,
        'relu': nn.relu,
        'raw_swish': nn.swish,
        'tanh': nn.tanh,
        'sigmoid': nn.sigmoid,
        'silu': nn.silu,
    }
    try:
        return nonlinearities[name]
    except KeyError:
        raise ValueError(f'Nonlinearity "{name}" not found.') from None
def create_bessel_embedding(count: int, inner_cutoff: float, outer_cutoff: float):
    """
    Build a Flax module that expands distances in a radial Bessel basis.

    Each basis function is multiplied by an envelope that equals 1 below
    ``inner_cutoff``, 0 above ``outer_cutoff`` and follows a half-cosine
    taper in between. Distances at or below 1e-5 yield zeros.

    Args:
        count: Number of Bessel basis functions
        inner_cutoff: Radius at which the envelope starts to decay
        outer_cutoff: Radius beyond which the envelope is zero

    Returns:
        A ``BesselEmbedding`` module instance; calling it maps the given
        radii over the basis with ``vmap``.

    Raises:
        ImportError: If JAX or Flax cannot be imported.
    """
    try:
        import jax
        import jax.numpy as jnp
        import flax.linen as nn

        def bessel(r_c, frequencies, r):
            threshold = jnp.float32(1e-5)
            # Swap near-zero radii for a harmless value so the division
            # below never sees zero; those entries are zeroed afterwards.
            safe_r = jnp.where(r > threshold, r, jnp.float32(1000.0))
            values = 2 / r_c * jnp.sin(frequencies * safe_r / r_c) / safe_r
            return jnp.where(r > threshold, values, 0)

        class BesselEmbedding(nn.Module):
            count: int
            inner_cutoff: float
            outer_cutoff: float

            @nn.compact
            def __call__(self, rs):
                def init_frequencies(key, shape):
                    # Initial frequencies are pi, 2*pi, ..., shape[0]*pi.
                    return jnp.arange(1, shape[0] + 1) * jnp.pi

                frequencies = self.param(
                    'frequencies', init_frequencies, (self.count,)
                )

                def envelope(r):
                    """Smooth cutoff factor for a single radius."""
                    taper = 0.5 * (1 + jnp.cos(jnp.pi * (r - self.inner_cutoff) /
                                               (self.outer_cutoff - self.inner_cutoff)))
                    outside = jnp.where(r > self.outer_cutoff, 0.0, taper)
                    return jnp.where(r < self.inner_cutoff, 1.0, outside)

                def embed_one(r):
                    return bessel(self.outer_cutoff, frequencies, r) * envelope(r)

                return jax.vmap(embed_one)(rs)

        return BesselEmbedding(count, inner_cutoff, outer_cutoff)
    except ImportError:
        raise ImportError("JAX and Flax are required for Bessel embedding")
def get_nequip_default_config() -> Dict[str, Any]:
    """
    Build the default NequIP configuration.

    A new dictionary is created on every call, so callers may modify the
    result freely.

    Returns:
        Default configuration dictionary
    """
    config: Dict[str, Any] = dict(
        graph_net_steps=5,
        nonlinearities=dict(e="raw_swish", o="tanh"),
        use_sc=True,
        n_elements=94,
        hidden_irreps="128x0e + 64x1e + 4x2e",
        sh_irreps="1x0e + 1x1e + 1x2e",
        num_basis=8,
        r_max=5.0,
        radial_net_nonlinearity="raw_swish",
        radial_net_n_hidden=64,
        radial_net_n_layers=2,
        n_neighbors=10.0,
        scalar_mlp_std=4.0,
    )
    return config
def get_gnome_default_config() -> Dict[str, Any]:
    """
    Build the default GNoME crystal energy model configuration.

    A new dictionary is created on every call, so callers may modify the
    result freely.

    Returns:
        Default configuration dictionary
    """
    defaults: Dict[str, Any] = dict(
        graph_net_steps=5,
        mlp_width=(128, 128, 64),
        mlp_nonlinearity="raw_swish",
        embedding_dim=128,
        featurizer="gaussian",
        shift=-1.6526496,
        scale=1.0,
        feature_band_limit=0,
        conditioning_band_limit=0,
        extra_scalars_for_gating=False,
        residual="none",
        node_aggregation="mean",
        edges_for_globals_aggregation="mean",
        readout_edges_for_globals_aggregation="mean",
    )
    return defaults
class ModelLoader:
    """Handles loading and caching of GNoME/NequIP models.

    Every successfully loaded ``(model, params, config)`` tuple is kept in
    ``self._models`` under its model name, so later ``load_model`` calls for
    the same name return the cached tuple without touching the disk.
    """

    def __init__(self, model_dir: str = "./models"):
        """
        Initialize ModelLoader.

        Args:
            model_dir: Directory containing model checkpoints
        """
        self.model_dir = model_dir
        self._models: Dict[str, Any] = {}
        self._configs: Dict[str, Dict] = {}

    @staticmethod
    def _read_config(config_path: str) -> Dict[str, Any]:
        """
        Parse a ``config.json`` file into a dictionary.

        The file may hold a JSON string that itself contains the JSON object
        (double-encoded, the only form the loader used to accept) or a plain
        JSON object.

        Args:
            config_path: Path of the JSON file to read

        Returns:
            The decoded configuration dictionary

        Raises:
            ValueError: If the decoded content is not a JSON object.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        if isinstance(config, str):
            # Double-encoded config: decode the inner document as well.
            config = json.loads(config)
        if not isinstance(config, dict):
            raise ValueError(f"Config at {config_path} is not a JSON object")
        return config

    def load_model(self, model_name: str) -> Tuple[Any, Any, Dict]:
        """
        Load a model from checkpoint.

        Args:
            model_name: Name of the model to load (a sub-directory of
                ``model_dir`` holding ``config.json`` and a checkpoint file)

        Returns:
            Tuple of (model, params, config)

        Raises:
            ImportError: If JAX, Flax, jraph or ml_collections is missing.
            FileNotFoundError: If the config or a checkpoint file is missing.
            ValueError: If the config names an unsupported model family.
        """
        if model_name in self._models:
            return self._models[model_name]

        try:
            import jax.numpy as jnp
            from jax import eval_shape, random
            from jax.tree_util import tree_map
            from jax.core import ShapedArray
            from flax import serialization
            import jraph
            from ml_collections import ConfigDict

            f32 = jnp.float32
            i32 = jnp.int32

            model_path = os.path.join(self.model_dir, model_name)

            # Load config
            config_path = os.path.join(model_path, 'config.json')
            if not os.path.exists(config_path):
                raise FileNotFoundError(f"Config not found at {config_path}")
            config = ConfigDict(self._read_config(config_path))

            # Initialize model based on model family
            model_family = config.get('model_family', 'nequip')
            if model_family == 'nequip':
                model = self._create_nequip_model(config)
            else:
                raise ValueError(f"Unsupported model family: {model_family}")

            # Abstract (shape/dtype only) graph used to trace model.init
            # without allocating real arrays.
            graph = jraph.GraphsTuple(
                nodes=ShapedArray((1, NUM_ELEMENTS), f32),
                edges=ShapedArray((1, 3), f32),
                receivers=ShapedArray((1,), i32),
                senders=ShapedArray((1,), i32),
                globals=ShapedArray((1, 1), f32),
                n_node=ShapedArray((1,), i32),
                n_edge=ShapedArray((1,), i32),
            )

            # Find checkpoint file. os.listdir order is arbitrary, so sort to
            # make the choice reproducible when several files match.
            checkpoints = sorted(
                c for c in os.listdir(model_path) if 'checkpoint' in c
            )
            if not checkpoints:
                raise FileNotFoundError(f"No checkpoint found in {model_path}")
            checkpoint_path = os.path.join(model_path, checkpoints[0])

            # Load parameters
            def init_model(graph):
                key = random.PRNGKey(0)
                return model.init(key, graph)

            abstract_params = eval_shape(init_model, graph)

            with open(checkpoint_path, 'rb') as f:
                # Restore target: a 3-tuple whose second entry has the
                # parameter structure; only that entry is read back below.
                ckpt_data = (0, abstract_params, None)
                ckpt = serialization.from_bytes(ckpt_data, f.read())

            params = tree_map(lambda x: x.astype(f32), ckpt[1])

            self._models[model_name] = (model, params, dict(config))
            self._configs[model_name] = dict(config)
            return model, params, dict(config)
        except ImportError as e:
            raise ImportError(f"Required packages not available: {e}") from e
        except Exception as e:
            logger.error("Error loading model %s: %s", model_name, e)
            raise

    def _create_nequip_model(self, config: Any) -> Any:
        """Create NequIP model from config.

        Raises:
            NotImplementedError: Always; this is a placeholder.
        """
        # This is a placeholder - actual implementation would use the nequip module
        raise NotImplementedError("NequIP model creation requires full JAX stack")

    def get_available_models(self) -> list:
        """
        Get list of available models.

        Returns:
            Sorted list of the sub-directory names found in ``model_dir``;
            empty when ``model_dir`` does not exist or is not a directory.
        """
        try:
            entries = os.listdir(self.model_dir)
        except (FileNotFoundError, NotADirectoryError):
            return []
        return sorted(
            d for d in entries
            if os.path.isdir(os.path.join(self.model_dir, d))
        )
def atoms_to_graph(
    atoms: Any,
    cutoff: float = 5.0,
    max_neighbors: int = 100
) -> Dict[str, Any]:
    """
    Convert ASE Atoms to graph representation.

    Args:
        atoms: ASE Atoms object
        cutoff: Cutoff radius for neighbor finding
        max_neighbors: Maximum number of neighbors per atom. NOTE: accepted
            for interface compatibility but not enforced; every pair within
            ``cutoff`` is returned.

    Returns:
        Graph dictionary with keys ``nodes`` (one-hot element encoding of
        width ``NUM_ELEMENTS``), ``edges`` (displacement vectors),
        ``senders``, ``receivers``, ``n_node``, ``n_edge``, ``positions``
        and ``cell``.

    Raises:
        ImportError: If NumPy or ASE cannot be imported.
    """
    try:
        import numpy as np
        from ase.neighborlist import neighbor_list
    except ImportError as err:
        raise ImportError("ASE is required for atoms to graph conversion") from err

    # Pair indices and displacement vectors; scalar distances are not
    # needed, so they are not requested.
    i, j, D = neighbor_list('ijD', atoms, cutoff)

    atomic_numbers = np.asarray(atoms.get_atomic_numbers())
    n_atoms = len(atoms)

    # One-hot encoding: column z - 1 represents atomic number z. Only
    # 1..NUM_ELEMENTS have a column. An atomic number of 0 must be excluded
    # explicitly: it would index column -1 and be mislabelled as the last
    # element. Atoms outside the range keep an all-zero row.
    node_features = np.zeros((n_atoms, NUM_ELEMENTS))
    rows = np.flatnonzero((atomic_numbers >= 1) & (atomic_numbers <= NUM_ELEMENTS))
    node_features[rows, atomic_numbers[rows] - 1] = 1.0

    return {
        "nodes": node_features,
        "edges": D,  # Displacement vectors
        "senders": i,
        "receivers": j,
        "n_node": np.array([n_atoms]),
        "n_edge": np.array([len(i)]),
        "positions": atoms.get_positions(),
        "cell": atoms.get_cell()[:],
    }
def predict_energy(
    model: Any,
    params: Any,
    graph: Dict[str, Any]
) -> float:
    """
    Evaluate the model on one graph and return the energy as a float.

    Args:
        model: Model instance exposing ``apply(params, graph_tuple)``
        params: Model parameters
        graph: Graph dictionary with the keys ``nodes``, ``edges``,
            ``senders``, ``receivers``, ``n_node`` and ``n_edge``

    Returns:
        Element ``[0, 0]`` of the model output, converted to ``float``

    Raises:
        ImportError: If JAX or jraph cannot be imported.
    """
    try:
        import jax.numpy as jnp
        import jraph

        # Copy the array-valued entries of the dictionary into JAX arrays.
        array_keys = ("nodes", "edges", "senders", "receivers", "n_node", "n_edge")
        fields = {key: jnp.array(graph[key]) for key in array_keys}
        # The dictionary carries no global features; a (1, 1) zero block is used.
        fields["globals"] = jnp.zeros((1, 1))

        prediction = model.apply(params, jraph.GraphsTuple(**fields))
        return float(prediction[0, 0])
    except ImportError:
        raise ImportError("JAX and jraph are required for energy prediction")
def compute_forces(
    model: Any,
    params: Any,
    graph: Dict[str, Any]
) -> Any:
    """
    Compute forces for a given graph as the negative energy gradient.

    The energy is made a function of the atomic positions by rebuilding the
    edge vectors as ``positions[receivers] - positions[senders] + offset``.
    ``offset`` is derived from the stored edges at the stored positions and
    held constant. This assumes ``graph["edges"]`` follows that convention
    (as produced by ``atoms_to_graph``) and that the neighbor list and cell
    stay fixed.

    Args:
        model: Model instance exposing ``apply(params, graph_tuple)``
        params: Model parameters
        graph: Graph dictionary with the keys ``nodes``, ``edges``,
            ``senders``, ``receivers``, ``n_node``, ``n_edge`` and
            ``positions`` (floating point)

    Returns:
        Forces array with the same shape as ``graph["positions"]``

    Raises:
        ImportError: If JAX or jraph cannot be imported.
    """
    try:
        import jax
        import jax.numpy as jnp
        import jraph
    except ImportError as err:
        raise ImportError("JAX is required for force computation") from err

    nodes = jnp.asarray(graph["nodes"])
    senders = jnp.asarray(graph["senders"])
    receivers = jnp.asarray(graph["receivers"])
    n_node = jnp.asarray(graph["n_node"])
    n_edge = jnp.asarray(graph["n_edge"])
    positions = jnp.asarray(graph["positions"])

    # Whatever the stored edges contain beyond the plain position difference
    # (e.g. periodic-image shifts) does not depend on the positions, so it
    # is recovered once and held constant.
    offsets = jnp.asarray(graph["edges"]) - (positions[receivers] - positions[senders])

    def energy_fn(pos):
        graph_tuple = jraph.GraphsTuple(
            nodes=nodes,
            edges=pos[receivers] - pos[senders] + offsets,
            senders=senders,
            receivers=receivers,
            globals=jnp.zeros((1, 1)),
            n_node=n_node,
            n_edge=n_edge,
        )
        # Return the JAX scalar, not a Python float, so that the energy
        # remains traceable for jax.grad.
        return model.apply(params, graph_tuple)[0, 0]

    # Compute negative gradient of energy
    return -jax.grad(energy_fn)(positions)
def get_model_info(model_name: str, model_dir: str = "./models") -> Dict[str, Any]:
    """
    Get information about a model without loading it.

    Only ``config.json`` inside the model directory is read. The file may
    hold a JSON string that itself contains the JSON object (double-encoded)
    or a plain JSON object.

    Args:
        model_name: Name of the model
        model_dir: Directory containing models

    Returns:
        Model information dictionary, or ``{"error": message}`` when the
        config is missing, unreadable or not a JSON object. This function
        reports problems through the returned dictionary rather than raising.
    """
    model_path = os.path.join(model_dir, model_name)
    config_path = os.path.join(model_path, 'config.json')
    if not os.path.exists(config_path):
        return {"error": f"Model {model_name} not found"}

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        if isinstance(config, str):
            # Double-encoded config: decode the inner document as well.
            config = json.loads(config)
        if not isinstance(config, dict):
            return {"error": f"Config for model {model_name} is not a JSON object"}

        return {
            "model_name": model_name,
            "model_family": config.get("model_family", "unknown"),
            "graph_net_steps": config.get("graph_net_steps"),
            "hidden_irreps": config.get("hidden_irreps"),
            "r_max": config.get("r_max"),
            "n_elements": config.get("n_elements", NUM_ELEMENTS),
        }
    except Exception as e:
        # Deliberate best-effort boundary: callers receive an error
        # dictionary instead of an exception.
        return {"error": str(e)}
class StructureMatcher:
    """Thin wrapper around pymatgen's structure matcher with stored tolerances."""

    def __init__(
        self,
        ltol: float = 0.2,
        stol: float = 0.3,
        angle_tol: float = 5.0
    ):
        """
        Store the tolerances used for every comparison.

        Args:
            ltol: Length tolerance
            stol: Site tolerance
            angle_tol: Angle tolerance in degrees
        """
        self.ltol, self.stol, self.angle_tol = ltol, stol, angle_tol

    def _new_pymatgen_matcher(self) -> Any:
        """Import pymatgen lazily and build a matcher with the stored tolerances."""
        from pymatgen.analysis.structure_matcher import StructureMatcher as PmgMatcher

        tolerances = {
            "ltol": self.ltol,
            "stol": self.stol,
            "angle_tol": self.angle_tol,
        }
        return PmgMatcher(**tolerances)

    def fit(self, structure1: Any, structure2: Any) -> bool:
        """
        Check if two structures match.

        Args:
            structure1: First pymatgen Structure
            structure2: Second pymatgen Structure

        Returns:
            Result of pymatgen's ``fit`` for the two structures

        Raises:
            ImportError: If pymatgen cannot be imported.
        """
        try:
            pmg_matcher = self._new_pymatgen_matcher()
            return pmg_matcher.fit(structure1, structure2)
        except ImportError:
            raise ImportError("pymatgen is required for structure matching")

    def get_rms_dist(self, structure1: Any, structure2: Any) -> Optional[Tuple[float, float]]:
        """
        Get RMS distance between structures.

        Args:
            structure1: First pymatgen Structure
            structure2: Second pymatgen Structure

        Returns:
            Tuple of (rms_dist, max_dist) or None if no match

        Raises:
            ImportError: If pymatgen cannot be imported.
        """
        try:
            pmg_matcher = self._new_pymatgen_matcher()
            return pmg_matcher.get_rms_dist(structure1, structure2)
        except ImportError:
            raise ImportError("pymatgen is required for structure matching")