File size: 16,156 Bytes
7f0fa00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
# Copyright 2024 Google LLC (Original code), Modified for MCP Service
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0

"""

Model utilities for GNoME Materials Discovery MCP Service.



This module provides:

- GNoME model architecture definitions

- NequIP model architecture definitions

- Model loading and inference utilities

- Crystal graph construction

"""

import functools
import json
import os
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Union
import logging

logger = logging.getLogger(__name__)

# Readability-only type aliases: every one is permissive (``Any`` or
# ``Iterable[int]``), so they document intent without any runtime checking.
Array = Any
PyTree = Any
Dtype = Any
Shape = Iterable[int]

# Width of the one-hot element encoding used for graph nodes.
NUM_ELEMENTS: int = 94


def get_nonlinearity_by_name(name: str) -> Callable:
    """

    Get nonlinearity function by name.

    

    Args:

        name: Name of nonlinearity ('relu', 'swish', 'tanh', etc.)

        

    Returns:

        Nonlinearity function

    """
    try:
        import jax.numpy as jnp
        import flax.linen as nn
        
        nonlinearities = {
            'none': lambda x: x,
            'relu': nn.relu,
            'raw_swish': nn.swish,
            'tanh': nn.tanh,
            'sigmoid': nn.sigmoid,
            'silu': nn.silu,
        }
        
        if name in nonlinearities:
            return nonlinearities[name]
        raise ValueError(f'Nonlinearity "{name}" not found.')
    except ImportError:
        raise ImportError("JAX and Flax are required for model utilities")


def create_bessel_embedding(count: int, inner_cutoff: float, outer_cutoff: float):
    """
    Build a Flax module that expands distances in a radial Bessel basis.

    For a 1-D array of distances ``rs`` the module returns an array of shape
    ``(len(rs), count)``. Basis function k is
    ``2 / outer_cutoff * sin(w_k * r / outer_cutoff) / r`` with trainable
    frequencies ``w_k`` initialised to ``k * pi`` (k = 1..count); entries with
    ``r <= 1e-5`` are set to zero. Every row is multiplied by an envelope that
    is 1 below ``inner_cutoff``, 0 above ``outer_cutoff`` and a half cosine in
    between.

    Args:
        count: Number of Bessel basis functions.
        inner_cutoff: Radius below which the envelope equals 1.
        outer_cutoff: Radius above which the envelope equals 0; also the
            length scale inside the Bessel functions.

    Returns:
        An uninitialised ``BesselEmbedding`` Flax module.

    Raises:
        ImportError: If JAX or Flax is not installed.
    """
    try:
        from functools import partial

        import flax.linen as nn
        import jax.numpy as jnp
        from jax import vmap

        f32 = jnp.float32
        tiny = f32(1e-5)

        def bessel(r_c, frequencies, r):
            # Swap in a harmless denominator near r == 0 so the division never
            # yields inf/nan; those entries are zeroed by the final where.
            safe_r = jnp.where(r > tiny, r, f32(1000.0))
            basis = 2 / r_c * jnp.sin(frequencies * safe_r / r_c) / safe_r
            return jnp.where(r > tiny, basis, 0)

        class BesselEmbedding(nn.Module):
            count: int
            inner_cutoff: float
            outer_cutoff: float

            @nn.compact
            def __call__(self, rs):
                def init_frequencies(key, shape):
                    # Deterministic k * pi initialisation; the RNG key is unused.
                    return jnp.pi * jnp.arange(1, shape[0] + 1)

                frequencies = self.param('frequencies', init_frequencies, (self.count,))
                radial = partial(bessel, self.outer_cutoff, frequencies)
                lo, hi = self.inner_cutoff, self.outer_cutoff

                def envelope(r):
                    taper = 0.5 * (1 + jnp.cos(jnp.pi * (r - lo) / (hi - lo)))
                    return jnp.where(r < lo, 1.0, jnp.where(r > hi, 0.0, taper))

                def embed_one(r):
                    return radial(r) * envelope(r)

                return vmap(embed_one)(rs)

        return BesselEmbedding(count, inner_cutoff, outer_cutoff)
    except ImportError:
        raise ImportError("JAX and Flax are required for Bessel embedding")


def get_nequip_default_config() -> Dict[str, Any]:
    """
    Get default NequIP configuration.

    A new dict is built on every call, so callers may modify the result
    freely. ``n_elements`` is taken from the module constant ``NUM_ELEMENTS``
    so it cannot drift from the one-hot node width used in this module.

    Returns:
        Default configuration dictionary.
    """
    return {
        "graph_net_steps": 5,
        "nonlinearities": {"e": "raw_swish", "o": "tanh"},
        "use_sc": True,
        "n_elements": NUM_ELEMENTS,
        "hidden_irreps": "128x0e + 64x1e + 4x2e",
        "sh_irreps": "1x0e + 1x1e + 1x2e",
        "num_basis": 8,
        "r_max": 5.0,
        "radial_net_nonlinearity": "raw_swish",
        "radial_net_n_hidden": 64,
        "radial_net_n_layers": 2,
        "n_neighbors": 10.0,
        "scalar_mlp_std": 4.0,
    }


def get_gnome_default_config() -> Dict[str, Any]:
    """
    Return the baseline hyperparameters of the GNoME crystal energy model.

    The dictionary is rebuilt on each call, so mutating the result does not
    affect later calls.

    Returns:
        Mapping from hyperparameter name to its default value.
    """
    defaults = dict(
        graph_net_steps=5,
        mlp_width=(128, 128, 64),
        mlp_nonlinearity="raw_swish",
        embedding_dim=128,
        featurizer="gaussian",
        shift=-1.6526496,
        scale=1.0,
        feature_band_limit=0,
        conditioning_band_limit=0,
        extra_scalars_for_gating=False,
        residual="none",
        node_aggregation="mean",
        edges_for_globals_aggregation="mean",
        readout_edges_for_globals_aggregation="mean",
    )
    return defaults


class ModelLoader:
    """Load GNoME/NequIP model checkpoints from disk and cache them by name."""

    def __init__(self, model_dir: str = "./models"):
        """
        Initialize ModelLoader.

        Args:
            model_dir: Directory with one sub-directory per model; each
                sub-directory holds a ``config.json`` and a checkpoint file.
        """
        self.model_dir = model_dir
        # model_name -> (model, params, config) tuple returned by load_model.
        self._models: Dict[str, Any] = {}
        # model_name -> plain-dict config of each successfully loaded model.
        self._configs: Dict[str, Dict] = {}

    @staticmethod
    def _read_config(config_path: str) -> Any:
        """Parse ``config.json``, accepting plain or double-encoded JSON.

        Historically the file was decoded twice (a JSON string whose content
        is itself JSON); a file holding a plain JSON object is also accepted.
        """
        with open(config_path, 'r', encoding='utf-8') as f:
            decoded = json.load(f)
        if isinstance(decoded, str):
            decoded = json.loads(decoded)
        return decoded

    @staticmethod
    def _find_checkpoint(model_path: str) -> str:
        """Return the path of the checkpoint file inside ``model_path``.

        Candidates are regular files whose name contains ``'checkpoint'``.
        They are sorted because ``os.listdir`` order is arbitrary, which would
        otherwise make the choice non-deterministic.

        Raises:
            FileNotFoundError: If no candidate exists.
        """
        candidates = sorted(
            name for name in os.listdir(model_path)
            if 'checkpoint' in name
            and os.path.isfile(os.path.join(model_path, name))
        )
        if not candidates:
            raise FileNotFoundError(f"No checkpoint found in {model_path}")
        return os.path.join(model_path, candidates[0])

    def load_model(self, model_name: str) -> Tuple[Any, Any, Dict]:
        """
        Load a model and its parameters from ``<model_dir>/<model_name>``.

        Successful loads are cached per ``model_name``; later calls return the
        cached tuple without touching the filesystem.

        Args:
            model_name: Name of the model sub-directory to load.

        Returns:
            Tuple of (model, params, config). Every params leaf is cast to
            float32; config is a plain dict.

        Raises:
            ImportError: If JAX, Flax, jraph or ml_collections is missing.
            FileNotFoundError: If the config or checkpoint file is missing.
            ValueError: If the config names an unsupported model family.
            NotImplementedError: For the 'nequip' family, because
                ``_create_nequip_model`` is only a placeholder here.
        """
        if model_name in self._models:
            return self._models[model_name]

        try:
            import jax
            import jax.numpy as jnp
            import jraph
            from flax import serialization
            from jax.tree_util import tree_map
            from ml_collections import ConfigDict
        except ImportError as e:
            raise ImportError(f"Required packages not available: {e}") from e

        f32 = jnp.float32
        i32 = jnp.int32

        try:
            model_path = os.path.join(self.model_dir, model_name)

            config_path = os.path.join(model_path, 'config.json')
            if not os.path.exists(config_path):
                raise FileNotFoundError(f"Config not found at {config_path}")
            config = ConfigDict(self._read_config(config_path))

            model_family = config.get('model_family', 'nequip')
            if model_family != 'nequip':
                raise ValueError(f"Unsupported model family: {model_family}")
            model = self._create_nequip_model(config)

            # Shape-only placeholder graph (1 node, 1 edge, 1 graph) used to
            # derive parameter shapes without running the model.
            graph = jraph.GraphsTuple(
                nodes=jax.ShapeDtypeStruct((1, NUM_ELEMENTS), f32),
                edges=jax.ShapeDtypeStruct((1, 3), f32),
                receivers=jax.ShapeDtypeStruct((1,), i32),
                senders=jax.ShapeDtypeStruct((1,), i32),
                globals=jax.ShapeDtypeStruct((1, 1), f32),
                n_node=jax.ShapeDtypeStruct((1,), i32),
                n_edge=jax.ShapeDtypeStruct((1,), i32),
            )

            checkpoint_path = self._find_checkpoint(model_path)

            def init_model(abstract_graph):
                return model.init(jax.random.PRNGKey(0), abstract_graph)

            abstract_params = jax.eval_shape(init_model, graph)

            with open(checkpoint_path, 'rb') as f:
                # Restored into a 3-tuple template; only slot 1 (params) is
                # used. NOTE(review): the other slots look like
                # (step, optimizer state) — confirm against the writer.
                ckpt = serialization.from_bytes((0, abstract_params, None), f.read())

            params = tree_map(lambda x: x.astype(f32), ckpt[1])
        except Exception as e:
            logger.error("Error loading model %s: %s", model_name, e)
            raise

        config_dict = dict(config)
        entry = (model, params, config_dict)
        self._models[model_name] = entry
        self._configs[model_name] = config_dict
        return entry

    def _create_nequip_model(self, config: Any) -> Any:
        """Create NequIP model from config."""
        # This is a placeholder - actual implementation would use the nequip module
        raise NotImplementedError("NequIP model creation requires full JAX stack")

    def get_available_models(self) -> list:
        """
        List the model sub-directories of ``model_dir``.

        Returns:
            Sorted list of sub-directory names; empty when ``model_dir`` does
            not exist or is not a directory.
        """
        if not os.path.isdir(self.model_dir):
            return []
        return sorted(
            entry for entry in os.listdir(self.model_dir)
            if os.path.isdir(os.path.join(self.model_dir, entry))
        )


def atoms_to_graph(
    atoms: Any,
    cutoff: float = 5.0,
    max_neighbors: Optional[int] = 100,
) -> Dict[str, Any]:
    """
    Convert an ASE ``Atoms`` object to a graph dictionary.

    Edges come from ``ase.neighborlist.neighbor_list`` within ``cutoff``;
    atom ``i`` of each pair is stored as the sender and atom ``j`` as the
    receiver.

    Args:
        atoms: ASE Atoms object.
        cutoff: Cutoff radius for neighbor finding.
        max_neighbors: Maximum number of edges kept per sender atom. If an
            atom has more neighbors inside ``cutoff``, only the nearest
            ``max_neighbors`` are kept (ties keep neighbor-list order).
            ``None`` disables the limit.

    Returns:
        Dict with keys:
            nodes: ``(n_atoms, NUM_ELEMENTS)`` one-hot array; atomic number
                ``z`` sets column ``z - 1``. Atoms whose number lies outside
                ``1..NUM_ELEMENTS`` get an all-zero row.
            edges: Displacement vectors (``'D'``) of the kept edges.
            senders / receivers: Edge endpoint indices (``i`` / ``j``).
            n_node / n_edge: One-element arrays with node and edge counts.
            positions: ``atoms.get_positions()``.
            cell: ``atoms.get_cell()[:]``.

    Raises:
        ImportError: If ASE (or NumPy) is not installed.
    """
    try:
        import numpy as np
        from ase.neighborlist import neighbor_list
    except ImportError as err:
        raise ImportError("ASE is required for atoms to graph conversion") from err

    # i: sender index, j: receiver index, d: distance, D: displacement vector.
    i, j, d, D = neighbor_list('ijdD', atoms, cutoff)

    if max_neighbors is not None and len(i) > 0:
        # Rank each edge among its sender's edges by distance (0 = nearest),
        # then drop edges ranked beyond the limit; surviving edges keep their
        # original order. lexsort's last key is the primary one.
        order = np.lexsort((d, i))
        sorted_senders = i[order]
        group_start = np.searchsorted(sorted_senders, sorted_senders, side='left')
        rank = np.empty(len(i), dtype=np.int64)
        rank[order] = np.arange(len(i)) - group_start
        keep = rank < max_neighbors
        i, j, D = i[keep], j[keep], D[keep]

    atomic_numbers = np.asarray(atoms.get_atomic_numbers())
    n_atoms = len(atoms)

    # One-hot encode. The lower bound matters: z = 0 would otherwise index
    # column -1 and be mis-encoded as element NUM_ELEMENTS.
    node_features = np.zeros((n_atoms, NUM_ELEMENTS))
    valid = (atomic_numbers >= 1) & (atomic_numbers <= NUM_ELEMENTS)
    node_features[np.nonzero(valid)[0], atomic_numbers[valid] - 1] = 1.0

    return {
        "nodes": node_features,
        "edges": D,  # Displacement vectors
        "senders": i,
        "receivers": j,
        "n_node": np.array([n_atoms]),
        "n_edge": np.array([len(i)]),
        "positions": atoms.get_positions(),
        "cell": atoms.get_cell()[:],
    }


def predict_energy(
    model: Any,
    params: Any,
    graph: Dict[str, Any],
) -> float:
    """
    Predict energy for a given graph.

    Args:
        model: Object exposing ``apply(params, graphs_tuple)`` (e.g. a Flax
            module).
        params: Model parameters passed straight to ``model.apply``.
        graph: Graph dictionary (as built by ``atoms_to_graph``) with keys
            nodes, edges, senders, receivers, n_node and n_edge.

    Returns:
        ``float(output[0, 0])`` of the model output, i.e. the value for the
        first graph. NOTE(review): assumes the model returns one row per
        graph — confirm against the model definition.

    Raises:
        ImportError: If JAX or jraph is not installed.
    """
    # Only the imports are guarded, so ImportErrors raised by the model itself
    # are not misreported as missing dependencies.
    try:
        import jax.numpy as jnp
        import jraph
    except ImportError as err:
        raise ImportError("JAX and jraph are required for energy prediction") from err

    n_node = jnp.asarray(graph["n_node"])
    n_graph = n_node.shape[0] if n_node.ndim else 1

    graph_tuple = jraph.GraphsTuple(
        nodes=jnp.asarray(graph["nodes"]),
        edges=jnp.asarray(graph["edges"]),
        senders=jnp.asarray(graph["senders"]),
        receivers=jnp.asarray(graph["receivers"]),
        # One zero-valued global feature per graph in the batch.
        globals=jnp.zeros((n_graph, 1)),
        n_node=n_node,
        n_edge=jnp.asarray(graph["n_edge"]),
    )

    energy = model.apply(params, graph_tuple)
    return float(energy[0, 0])


def compute_forces(
    model: Any,
    params: Any,
    graph: Dict[str, Any],
) -> Any:
    """
    Compute forces as the negative gradient of the predicted energy.

    The edge vectors are rebuilt from the atom positions inside the
    differentiated function, so the gradient reaches the positions. The
    previous version never used its positions argument, so it always
    returned zero forces. Each edge's periodic-image shift is recovered once
    as ``edges - (positions[receivers] - positions[senders])`` and held
    constant.

    Args:
        model: Object exposing ``apply(params, graphs_tuple)``.
            NOTE(review): ``output[g, 0]`` is assumed to be the energy of
            graph ``g`` (the same layout ``predict_energy`` assumes).
        params: Model parameters passed straight to ``model.apply``.
        graph: Graph dictionary (as built by ``atoms_to_graph``) with keys
            nodes, edges, senders, receivers, n_node, n_edge and positions.
            Edges must be sender-to-receiver displacement vectors
            (``positions[receivers] - positions[senders]`` plus a fixed
            periodic shift).

    Returns:
        Array shaped like ``graph["positions"]`` holding ``-dE/dr``.

    Raises:
        ImportError: If JAX or jraph is not installed.
    """
    try:
        import jax
        import jax.numpy as jnp
        import jraph
    except ImportError as err:
        raise ImportError("JAX and jraph are required for force computation") from err

    positions = jnp.asarray(graph["positions"])
    senders = jnp.asarray(graph["senders"])
    receivers = jnp.asarray(graph["receivers"])
    nodes = jnp.asarray(graph["nodes"])
    n_node = jnp.asarray(graph["n_node"])
    n_edge = jnp.asarray(graph["n_edge"])
    n_graph = n_node.shape[0] if n_node.ndim else 1
    zero_globals = jnp.zeros((n_graph, 1))

    # Periodic-image offsets do not depend on the atom positions.
    offsets = jnp.asarray(graph["edges"]) - (positions[receivers] - positions[senders])

    def total_energy(pos):
        graph_tuple = jraph.GraphsTuple(
            nodes=nodes,
            edges=pos[receivers] - pos[senders] + offsets,
            senders=senders,
            receivers=receivers,
            globals=zero_globals,
            n_node=n_node,
            n_edge=n_edge,
        )
        energy = model.apply(params, graph_tuple)
        # Summing per-graph energies gives every atom the gradient of its own
        # graph's energy (graphs do not share atoms).
        return jnp.sum(energy[:, 0])

    return -jax.grad(total_energy)(positions)


def get_model_info(model_name: str, model_dir: str = "./models") -> Dict[str, Any]:
    """
    Summarise a model's configuration without loading its weights.

    Reads ``<model_dir>/<model_name>/config.json``. The file may contain a
    JSON object, or a JSON string that itself encodes the object (the
    double-encoded layout ``ModelLoader.load_model`` reads); both work.

    Args:
        model_name: Name of the model sub-directory.
        model_dir: Directory containing models.

    Returns:
        Dict with model_name, model_family ('unknown' if absent),
        graph_net_steps, hidden_irreps, r_max (None if absent) and
        n_elements (``NUM_ELEMENTS`` if absent). On any failure a dict with a
        single "error" key is returned instead of raising.
    """
    config_path = os.path.join(model_dir, model_name, 'config.json')

    if not os.path.exists(config_path):
        return {"error": f"Model {model_name} not found"}

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        if isinstance(config, str):
            config = json.loads(config)

        return {
            "model_name": model_name,
            "model_family": config.get("model_family", "unknown"),
            "graph_net_steps": config.get("graph_net_steps"),
            "hidden_irreps": config.get("hidden_irreps"),
            "r_max": config.get("r_max"),
            "n_elements": config.get("n_elements", NUM_ELEMENTS),
        }
    except Exception as e:
        # Best-effort API: report the problem to the caller instead of raising.
        return {"error": str(e)}


class StructureMatcher:
    """Compare crystal structures via pymatgen using stored tolerances."""

    def __init__(
        self,
        ltol: float = 0.2,
        stol: float = 0.3,
        angle_tol: float = 5.0,
    ):
        """
        Initialize StructureMatcher.

        Args:
            ltol: Length tolerance.
            stol: Site tolerance.
            angle_tol: Angle tolerance in degrees.
        """
        self.ltol = ltol
        self.stol = stol
        self.angle_tol = angle_tol

    def _build_matcher(self) -> Any:
        """Create a pymatgen matcher configured with this instance's tolerances.

        Raises:
            ImportError: If pymatgen is not installed.
        """
        # Guard only the import so errors raised inside pymatgen's matching
        # code are not misreported as a missing dependency.
        try:
            from pymatgen.analysis.structure_matcher import StructureMatcher as PmgMatcher
        except ImportError as err:
            raise ImportError("pymatgen is required for structure matching") from err
        return PmgMatcher(ltol=self.ltol, stol=self.stol, angle_tol=self.angle_tol)

    def fit(self, structure1: Any, structure2: Any) -> bool:
        """
        Check if two structures match.

        Args:
            structure1: First pymatgen Structure.
            structure2: Second pymatgen Structure.

        Returns:
            True if structures match.

        Raises:
            ImportError: If pymatgen is not installed.
        """
        return self._build_matcher().fit(structure1, structure2)

    def get_rms_dist(self, structure1: Any, structure2: Any) -> Optional[Tuple[float, float]]:
        """
        Get RMS distance between structures.

        Args:
            structure1: First pymatgen Structure.
            structure2: Second pymatgen Structure.

        Returns:
            Tuple of (rms_dist, max_dist) or None if no match.

        Raises:
            ImportError: If pymatgen is not installed.
        """
        return self._build_matcher().get_rms_dist(structure1, structure2)