SEUyishu commited on
Commit
a4e2367
·
verified ·
1 Parent(s): a13bf36

Upload 3 files

Browse files
Files changed (3) hide show
  1. entrypoint.py +23 -5
  2. mcp_server.py +925 -915
  3. requirements.txt +1 -2
entrypoint.py CHANGED
@@ -97,14 +97,32 @@ def main():
97
  logger.warning("Use download_dataset() tool to download data.")
98
 
99
  logger.info("=" * 60)
100
- logger.info("Starting MCP Server...")
101
  logger.info("=" * 60)
102
 
103
- # Import and run the MCP server
104
- from mcp_server import mcp
105
 
106
- # Run with SSE transport
107
- mcp.run(transport="sse", host=host, port=port)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
 
110
  if __name__ == "__main__":
 
97
  logger.warning("Use download_dataset() tool to download data.")
98
 
99
  logger.info("=" * 60)
100
+ logger.info("Starting MCP Server with SSE transport...")
101
  logger.info("=" * 60)
102
 
103
+ # Import the MCP server
104
+ from mcp_server import mcp, create_sse_app
105
 
106
+ # Create Starlette app with SSE transport for HuggingFace Spaces
107
+ import uvicorn
108
+ from starlette.applications import Starlette
109
+ from starlette.routing import Mount, Route
110
+ from starlette.responses import JSONResponse
111
+
112
+ async def health_check(request):
113
+ """Health check endpoint."""
114
+ return JSONResponse({"status": "healthy", "service": "GNoME MCP Server"})
115
+
116
+ # Create Starlette app
117
+ app = Starlette(
118
+ routes=[
119
+ Route("/health", health_check),
120
+ Mount("/", app=create_sse_app()),
121
+ ]
122
+ )
123
+
124
+ # Run with uvicorn
125
+ uvicorn.run(app, host=host, port=port)
126
 
127
 
128
  if __name__ == "__main__":
mcp_server.py CHANGED
@@ -1,915 +1,925 @@
1
- # Copyright 2024 Google LLC (Original code), Modified for MCP Service
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
-
9
- """
10
- GNoME Materials Discovery MCP Server
11
-
12
- This is the main MCP server implementation providing tools for:
13
- - Dataset access and querying
14
- - Decomposition energy calculation
15
- - Phase diagram analysis
16
- - Crystal structure operations
17
- - Air stability analysis
18
- - Model inference
19
- """
20
-
21
- import os
22
- import json
23
- import logging
24
- from typing import Optional, List, Dict, Any
25
- from contextlib import asynccontextmanager
26
- from collections.abc import AsyncIterator
27
-
28
- from mcp.server.fastmcp import FastMCP
29
-
30
- # Import local modules
31
- from data_utils import DataManager, get_data_manager
32
- from phase_diagram_utils import (
33
- compute_decomposition_energy,
34
- build_phase_diagram,
35
- compute_air_stability,
36
- compare_with_materials_project,
37
- find_competing_phases
38
- )
39
- from model_utils import (
40
- ModelLoader,
41
- atoms_to_graph,
42
- get_model_info,
43
- get_nequip_default_config,
44
- get_gnome_default_config,
45
- StructureMatcher
46
- )
47
-
48
- # Configure logging
49
- logging.basicConfig(level=logging.INFO)
50
- logger = logging.getLogger(__name__)
51
-
52
- # Data directory configuration - must match Dockerfile ENV
53
- DATA_DIR = os.environ.get("GNOME_DATA_DIR", "/app/gnome_data")
54
- MODEL_DIR = os.environ.get("GNOME_MODEL_DIR", "/app/models")
55
-
56
- # Ensure directories exist at module load
57
- os.makedirs(DATA_DIR, exist_ok=True)
58
- os.makedirs(MODEL_DIR, exist_ok=True)
59
-
60
-
61
- @asynccontextmanager
62
- async def app_lifespan(server: FastMCP) -> AsyncIterator[Dict[str, Any]]:
63
- """
64
- Application lifespan context manager.
65
- Initializes data manager and model loader.
66
- """
67
- logger.info("Initializing GNoME Materials Discovery MCP Server...")
68
-
69
- # Initialize data manager
70
- data_manager = DataManager(DATA_DIR)
71
- model_loader = ModelLoader(MODEL_DIR)
72
-
73
- logger.info(f"Data directory: {DATA_DIR}")
74
- logger.info(f"Model directory: {MODEL_DIR}")
75
-
76
- yield {
77
- "data_manager": data_manager,
78
- "model_loader": model_loader
79
- }
80
-
81
- # Cleanup
82
- logger.info("Shutting down GNoME MCP Server...")
83
- data_manager.close()
84
-
85
-
86
- # Create FastMCP server
87
- mcp = FastMCP("GNoME Materials Discovery")
88
-
89
-
90
- # ============================================================================
91
- # Dataset Access Tools
92
- # ============================================================================
93
-
94
- @mcp.tool()
95
- async def get_dataset_statistics() -> Dict[str, Any]:
96
- """
97
- Get statistics about the GNoME materials discovery dataset.
98
-
99
- Returns information about:
100
- - Total number of materials
101
- - Unique compositions and formulas
102
- - Crystal system distribution
103
- - Average formation energy
104
- - Element coverage
105
- """
106
- try:
107
- dm = get_data_manager(DATA_DIR)
108
- stats = dm.get_statistics()
109
- return {
110
- "status": "success",
111
- "data": stats
112
- }
113
- except Exception as e:
114
- logger.error(f"Error getting statistics: {e}")
115
- return {"status": "error", "message": str(e)}
116
-
117
-
118
- @mcp.tool()
119
- async def query_materials(
120
- composition: Optional[str] = None,
121
- elements: Optional[str] = None,
122
- space_group: Optional[int] = None,
123
- crystal_system: Optional[str] = None,
124
- min_bandgap: Optional[float] = None,
125
- max_bandgap: Optional[float] = None,
126
- max_decomposition_energy: Optional[float] = None,
127
- limit: int = 50
128
- ) -> Dict[str, Any]:
129
- """
130
- Query materials from the GNoME dataset with various filters.
131
-
132
- Args:
133
- composition: Exact composition to match (e.g., "Li2O")
134
- elements: Comma-separated list of elements that must be present (e.g., "Li,O")
135
- space_group: Space group number to filter by
136
- crystal_system: Crystal system name (e.g., "cubic", "hexagonal")
137
- min_bandgap: Minimum bandgap value in eV
138
- max_bandgap: Maximum bandgap value in eV
139
- max_decomposition_energy: Maximum decomposition energy per atom in eV
140
- limit: Maximum number of results to return (default: 50)
141
-
142
- Returns:
143
- List of matching materials with their properties
144
- """
145
- try:
146
- dm = get_data_manager(DATA_DIR)
147
-
148
- elements_list = None
149
- if elements:
150
- elements_list = [e.strip() for e in elements.split(",")]
151
-
152
- results = dm.query_by_composition(
153
- composition=composition,
154
- elements=elements_list,
155
- space_group=space_group,
156
- crystal_system=crystal_system,
157
- min_bandgap=min_bandgap,
158
- max_bandgap=max_bandgap,
159
- max_decomposition_energy=max_decomposition_energy,
160
- limit=limit
161
- )
162
-
163
- # Convert to list of dicts
164
- materials = results.to_dict(orient='records')
165
-
166
- return {
167
- "status": "success",
168
- "count": len(materials),
169
- "materials": materials
170
- }
171
- except Exception as e:
172
- logger.error(f"Error querying materials: {e}")
173
- return {"status": "error", "message": str(e)}
174
-
175
-
176
- @mcp.tool()
177
- async def get_material_by_id(material_id: str) -> Dict[str, Any]:
178
- """
179
- Get detailed information about a specific material by its ID.
180
-
181
- Args:
182
- material_id: The unique MaterialId from the GNoME dataset
183
-
184
- Returns:
185
- Complete material information including structure and properties
186
- """
187
- try:
188
- dm = get_data_manager(DATA_DIR)
189
- material = dm.get_crystal_by_id(material_id)
190
-
191
- if material is None:
192
- return {"status": "error", "message": f"Material {material_id} not found"}
193
-
194
- return {
195
- "status": "success",
196
- "material": material.to_dict()
197
- }
198
- except Exception as e:
199
- logger.error(f"Error getting material: {e}")
200
- return {"status": "error", "message": str(e)}
201
-
202
-
203
- @mcp.tool()
204
- async def get_random_material(
205
- crystal_system: Optional[str] = None,
206
- n_elements: Optional[int] = None
207
- ) -> Dict[str, Any]:
208
- """
209
- Get a random material from the GNoME dataset.
210
-
211
- Args:
212
- crystal_system: Optional filter by crystal system
213
- n_elements: Optional filter by number of elements (e.g., 2 for binary, 3 for ternary)
214
-
215
- Returns:
216
- Random material information
217
- """
218
- try:
219
- dm = get_data_manager(DATA_DIR)
220
- crystals = dm.load_gnome_crystals()
221
-
222
- if crystal_system:
223
- crystals = crystals[crystals['Crystal System'] == crystal_system]
224
-
225
- if n_elements:
226
- crystals = crystals[crystals['Chemical System'].map(len) == n_elements]
227
-
228
- if len(crystals) == 0:
229
- return {"status": "error", "message": "No materials match the criteria"}
230
-
231
- sample = crystals.sample(1).iloc[0]
232
-
233
- return {
234
- "status": "success",
235
- "material": sample.to_dict()
236
- }
237
- except Exception as e:
238
- logger.error(f"Error getting random material: {e}")
239
- return {"status": "error", "message": str(e)}
240
-
241
-
242
- # ============================================================================
243
- # Phase Diagram and Stability Tools
244
- # ============================================================================
245
-
246
- @mcp.tool()
247
- async def calculate_decomposition_energy(
248
- composition: str,
249
- energy: float
250
- ) -> Dict[str, Any]:
251
- """
252
- Calculate the decomposition energy of a material relative to the GNoME convex hull.
253
-
254
- This determines whether a material is thermodynamically stable or metastable.
255
- A negative or zero decomposition energy indicates stability.
256
-
257
- Args:
258
- composition: Chemical composition (e.g., "LiFePO4", "Li2O")
259
- energy: Total corrected energy from DFT calculation in eV
260
-
261
- Returns:
262
- Decomposition energy and decomposition products
263
- """
264
- try:
265
- import pymatgen as mg
266
-
267
- dm = get_data_manager(DATA_DIR)
268
- all_crystals = dm.load_all_crystals()
269
- grouped = dm.get_grouped_entries()
270
-
271
- # Get chemical system from composition
272
- comp = mg.core.Composition(composition)
273
- chemsys = [str(el) for el in comp.elements]
274
-
275
- result = compute_decomposition_energy(
276
- composition=composition,
277
- energy=energy,
278
- chemsys=chemsys,
279
- grouped_entries=grouped,
280
- all_crystals=all_crystals
281
- )
282
-
283
- return {
284
- "status": "success",
285
- **result
286
- }
287
- except Exception as e:
288
- logger.error(f"Error calculating decomposition energy: {e}")
289
- return {"status": "error", "message": str(e)}
290
-
291
-
292
- @mcp.tool()
293
- async def get_phase_diagram(
294
- elements: str
295
- ) -> Dict[str, Any]:
296
- """
297
- Build and analyze the phase diagram for a chemical system.
298
-
299
- Args:
300
- elements: Comma or dash separated list of elements (e.g., "Li,Fe,P,O" or "Li-Fe-P-O")
301
-
302
- Returns:
303
- Phase diagram information including stable and unstable entries
304
- """
305
- try:
306
- import re
307
-
308
- dm = get_data_manager(DATA_DIR)
309
- all_crystals = dm.load_all_crystals()
310
- grouped = dm.get_grouped_entries()
311
-
312
- # Parse elements
313
- chemsys = re.split(r'[\s,\-]+', elements)
314
- chemsys = [e.strip() for e in chemsys if e.strip()]
315
-
316
- result = build_phase_diagram(
317
- chemsys=chemsys,
318
- grouped_entries=grouped,
319
- all_crystals=all_crystals
320
- )
321
-
322
- return {
323
- "status": "success",
324
- **result
325
- }
326
- except Exception as e:
327
- logger.error(f"Error building phase diagram: {e}")
328
- return {"status": "error", "message": str(e)}
329
-
330
-
331
- @mcp.tool()
332
- async def calculate_air_stability(
333
- composition: str,
334
- energy: float,
335
- temperature: float = 300.0,
336
- oxygen_pressure: float = 21200.0
337
- ) -> Dict[str, Any]:
338
- """
339
- Calculate the air stability of a material.
340
-
341
- Analyzes stability with respect to:
342
- - Oxygen (via grand potential phase diagram)
343
- - Carbon dioxide (CO2 reactivity)
344
- - Water (H2O reactivity)
345
-
346
- Args:
347
- composition: Chemical composition (e.g., "Li3N", "NaCl")
348
- energy: Total corrected energy in eV
349
- temperature: Temperature in Kelvin (default: 300K)
350
- oxygen_pressure: Oxygen partial pressure in Pa (default: 21200 Pa, ambient)
351
-
352
- Returns:
353
- Air stability analysis results
354
- """
355
- try:
356
- import pymatgen as mg
357
-
358
- dm = get_data_manager(DATA_DIR)
359
- all_crystals = dm.load_all_crystals()
360
- grouped = dm.get_grouped_entries()
361
-
362
- comp = mg.core.Composition(composition)
363
- chemsys = [str(el) for el in comp.elements]
364
-
365
- result = compute_air_stability(
366
- composition=composition,
367
- energy=energy,
368
- chemsys=chemsys,
369
- grouped_entries=grouped,
370
- all_crystals=all_crystals,
371
- temperature=temperature,
372
- oxygen_pressure=oxygen_pressure
373
- )
374
-
375
- return {
376
- "status": "success",
377
- **result
378
- }
379
- except Exception as e:
380
- logger.error(f"Error calculating air stability: {e}")
381
- return {"status": "error", "message": str(e)}
382
-
383
-
384
- @mcp.tool()
385
- async def compare_gnome_with_mp(
386
- elements: str
387
- ) -> Dict[str, Any]:
388
- """
389
- Compare GNoME phase diagram with Materials Project for a chemical system.
390
-
391
- Identifies:
392
- - New stable phases discovered by GNoME
393
- - Phases only in Materials Project
394
- - Common stable phases
395
-
396
- Args:
397
- elements: Comma or dash separated list of elements
398
-
399
- Returns:
400
- Comparison results between GNoME and Materials Project
401
- """
402
- try:
403
- import re
404
-
405
- dm = get_data_manager(DATA_DIR)
406
-
407
- # Load required data
408
- all_crystals = dm.load_all_crystals()
409
- mp_crystals = dm.load_mp_crystals()
410
-
411
- gnome_grouped = dm.get_grouped_entries()
412
-
413
- # Create MP grouped entries
414
- required_columns = [
415
- 'Composition', 'NSites', 'Corrected Energy',
416
- 'Formation Energy Per Atom', 'Chemical System'
417
- ]
418
- mp_minimal = mp_crystals[required_columns]
419
- mp_grouped = mp_minimal.groupby('Chemical System')
420
-
421
- # Parse elements
422
- chemsys = re.split(r'[\s,\-]+', elements)
423
- chemsys = [e.strip() for e in chemsys if e.strip()]
424
-
425
- result = compare_with_materials_project(
426
- chemsys=chemsys,
427
- grouped_entries=gnome_grouped,
428
- mp_grouped_entries=mp_grouped,
429
- all_crystals=all_crystals,
430
- mp_crystals=mp_crystals
431
- )
432
-
433
- return {
434
- "status": "success",
435
- **result
436
- }
437
- except Exception as e:
438
- logger.error(f"Error comparing with MP: {e}")
439
- return {"status": "error", "message": str(e)}
440
-
441
-
442
- @mcp.tool()
443
- async def find_competing_phases_for_composition(
444
- composition: str,
445
- n_phases: int = 5
446
- ) -> Dict[str, Any]:
447
- """
448
- Find competing phases for a given composition.
449
-
450
- Identifies the most thermodynamically favorable phases in the same
451
- chemical space that compete with the given composition.
452
-
453
- Args:
454
- composition: Chemical composition to analyze
455
- n_phases: Number of competing phases to return (default: 5)
456
-
457
- Returns:
458
- List of competing phases with their properties
459
- """
460
- try:
461
- import pymatgen as mg
462
-
463
- dm = get_data_manager(DATA_DIR)
464
- all_crystals = dm.load_all_crystals()
465
- grouped = dm.get_grouped_entries()
466
-
467
- comp = mg.core.Composition(composition)
468
- chemsys = [str(el) for el in comp.elements]
469
-
470
- phases = find_competing_phases(
471
- composition=composition,
472
- chemsys=chemsys,
473
- grouped_entries=grouped,
474
- all_crystals=all_crystals,
475
- n_phases=n_phases
476
- )
477
-
478
- return {
479
- "status": "success",
480
- "composition": composition,
481
- "chemical_system": chemsys,
482
- "competing_phases": phases
483
- }
484
- except Exception as e:
485
- logger.error(f"Error finding competing phases: {e}")
486
- return {"status": "error", "message": str(e)}
487
-
488
-
489
- # ============================================================================
490
- # Structure Tools
491
- # ============================================================================
492
-
493
- @mcp.tool()
494
- async def get_structure(
495
- reduced_formula: str,
496
- output_format: str = "json"
497
- ) -> Dict[str, Any]:
498
- """
499
- Get crystal structure for a given reduced formula.
500
-
501
- Args:
502
- reduced_formula: Reduced chemical formula (e.g., "LiFePO4", "TiO2")
503
- output_format: Output format - "json", "cif", or "poscar"
504
-
505
- Returns:
506
- Crystal structure data in the requested format
507
- """
508
- try:
509
- dm = get_data_manager(DATA_DIR)
510
- atoms, structure = dm.load_structure(reduced_formula)
511
-
512
- if output_format == "cif":
513
- return {
514
- "status": "success",
515
- "format": "cif",
516
- "data": structure.to(fmt="cif")
517
- }
518
- elif output_format == "poscar":
519
- return {
520
- "status": "success",
521
- "format": "poscar",
522
- "data": structure.to(fmt="poscar")
523
- }
524
- else: # json
525
- return {
526
- "status": "success",
527
- "format": "json",
528
- "data": {
529
- "formula": structure.formula,
530
- "reduced_formula": structure.composition.reduced_formula,
531
- "lattice": {
532
- "a": structure.lattice.a,
533
- "b": structure.lattice.b,
534
- "c": structure.lattice.c,
535
- "alpha": structure.lattice.alpha,
536
- "beta": structure.lattice.beta,
537
- "gamma": structure.lattice.gamma,
538
- "volume": structure.lattice.volume,
539
- "matrix": structure.lattice.matrix.tolist()
540
- },
541
- "sites": [
542
- {
543
- "species": str(site.specie),
544
- "coords": site.frac_coords.tolist(),
545
- "cart_coords": site.coords.tolist()
546
- }
547
- for site in structure.sites
548
- ],
549
- "n_sites": len(structure),
550
- "space_group": structure.get_space_group_info()[0]
551
- }
552
- }
553
- except Exception as e:
554
- logger.error(f"Error getting structure: {e}")
555
- return {"status": "error", "message": str(e)}
556
-
557
-
558
- @mcp.tool()
559
- async def compare_structures(
560
- formula1: str,
561
- formula2: str,
562
- ltol: float = 0.2,
563
- stol: float = 0.3,
564
- angle_tol: float = 5.0
565
- ) -> Dict[str, Any]:
566
- """
567
- Compare two crystal structures from the GNoME dataset.
568
-
569
- Uses pymatgen's StructureMatcher to determine if structures are equivalent.
570
-
571
- Args:
572
- formula1: First reduced formula
573
- formula2: Second reduced formula
574
- ltol: Length tolerance for matching
575
- stol: Site tolerance for matching
576
- angle_tol: Angle tolerance in degrees
577
-
578
- Returns:
579
- Comparison results including whether structures match
580
- """
581
- try:
582
- dm = get_data_manager(DATA_DIR)
583
- _, structure1 = dm.load_structure(formula1)
584
- _, structure2 = dm.load_structure(formula2)
585
-
586
- matcher = StructureMatcher(ltol=ltol, stol=stol, angle_tol=angle_tol)
587
-
588
- is_match = matcher.fit(structure1, structure2)
589
- rms_result = matcher.get_rms_dist(structure1, structure2)
590
-
591
- return {
592
- "status": "success",
593
- "formula1": formula1,
594
- "formula2": formula2,
595
- "structures_match": is_match,
596
- "rms_dist": rms_result[0] if rms_result else None,
597
- "max_dist": rms_result[1] if rms_result else None,
598
- "tolerances": {
599
- "ltol": ltol,
600
- "stol": stol,
601
- "angle_tol": angle_tol
602
- }
603
- }
604
- except Exception as e:
605
- logger.error(f"Error comparing structures: {e}")
606
- return {"status": "error", "message": str(e)}
607
-
608
-
609
- # ============================================================================
610
- # r²SCAN Validation Tools
611
- # ============================================================================
612
-
613
- @mcp.tool()
614
- async def get_r2scan_data(
615
- composition: Optional[str] = None,
616
- limit: int = 50
617
- ) -> Dict[str, Any]:
618
- """
619
- Get r²SCAN validation data for materials.
620
-
621
- r²SCAN is a more accurate DFT functional used to validate GNoME predictions.
622
-
623
- Args:
624
- composition: Optional composition filter
625
- limit: Maximum number of results
626
-
627
- Returns:
628
- r²SCAN calculated energies and stability metrics
629
- """
630
- try:
631
- dm = get_data_manager(DATA_DIR)
632
- r2scan = dm.load_r2scan_crystals()
633
-
634
- if composition:
635
- r2scan = r2scan[r2scan['Composition'] == composition]
636
-
637
- results = r2scan.head(limit).to_dict(orient='records')
638
-
639
- return {
640
- "status": "success",
641
- "count": len(results),
642
- "data": results
643
- }
644
- except Exception as e:
645
- logger.error(f"Error getting r2scan data: {e}")
646
- return {"status": "error", "message": str(e)}
647
-
648
-
649
- # ============================================================================
650
- # a2c Crystal Structure Prediction Tools
651
- # ============================================================================
652
-
653
- @mcp.tool()
654
- async def get_a2c_supporting_data() -> Dict[str, Any]:
655
- """
656
- Get a2c (amorphous-to-crystalline) structure prediction supporting data.
657
-
658
- The a2c pipeline discovers crystal structures by relaxing amorphous
659
- configurations using GNoME force fields.
660
-
661
- Returns:
662
- List of available a2c campaigns with their chemical systems
663
- """
664
- try:
665
- dm = get_data_manager(DATA_DIR)
666
- a2c_data = dm.load_a2c_data()
667
-
668
- campaigns = []
669
- for key, data in a2c_data.items():
670
- campaigns.append({
671
- "chemical_system": key,
672
- "has_amorphous_structure": "amorphous_structure" in data,
673
- "num_initial_structures": len(data.get("a2c_initial_structures", [])),
674
- "num_matches": len(data.get("a2c_match_after_relax_example", []))
675
- })
676
-
677
- return {
678
- "status": "success",
679
- "num_campaigns": len(campaigns),
680
- "campaigns": campaigns
681
- }
682
- except Exception as e:
683
- logger.error(f"Error getting a2c data: {e}")
684
- return {"status": "error", "message": str(e)}
685
-
686
-
687
- @mcp.tool()
688
- async def get_a2c_campaign_details(
689
- chemical_system: str
690
- ) -> Dict[str, Any]:
691
- """
692
- Get detailed data for a specific a2c campaign.
693
-
694
- Args:
695
- chemical_system: Chemical system name (e.g., "Al2O3", "SiO2")
696
-
697
- Returns:
698
- Detailed a2c campaign data including structures
699
- """
700
- try:
701
- dm = get_data_manager(DATA_DIR)
702
- a2c_data = dm.load_a2c_data()
703
-
704
- if chemical_system not in a2c_data:
705
- return {
706
- "status": "error",
707
- "message": f"Chemical system {chemical_system} not found in a2c data"
708
- }
709
-
710
- data = a2c_data[chemical_system]
711
-
712
- matches = []
713
- for match in data.get("a2c_match_after_relax_example", []):
714
- matches.append({
715
- "index": match.get("index_in_a2c_initial_structures"),
716
- "formula": match.get("formula"),
717
- "has_ff_relaxed": "relaxed_ff" in match,
718
- "has_dft_relaxed": "relaxed_dft" in match
719
- })
720
-
721
- return {
722
- "status": "success",
723
- "chemical_system": chemical_system,
724
- "amorphous_structure": data.get("amorphous_structure", "")[:500] + "...",
725
- "num_initial_structures": len(data.get("a2c_initial_structures", [])),
726
- "matches": matches
727
- }
728
- except Exception as e:
729
- logger.error(f"Error getting a2c campaign details: {e}")
730
- return {"status": "error", "message": str(e)}
731
-
732
-
733
- # ============================================================================
734
- # Model Information Tools
735
- # ============================================================================
736
-
737
- @mcp.tool()
738
- async def get_model_configurations() -> Dict[str, Any]:
739
- """
740
- Get default configurations for GNoME and NequIP models.
741
-
742
- Returns:
743
- Default configuration dictionaries for both model architectures
744
- """
745
- return {
746
- "status": "success",
747
- "nequip_config": get_nequip_default_config(),
748
- "gnome_config": get_gnome_default_config()
749
- }
750
-
751
-
752
- @mcp.tool()
753
- async def list_available_models() -> Dict[str, Any]:
754
- """
755
- List available pre-trained models.
756
-
757
- Returns:
758
- List of available model names and their information
759
- """
760
- try:
761
- loader = ModelLoader(MODEL_DIR)
762
- models = loader.get_available_models()
763
-
764
- model_info = []
765
- for model_name in models:
766
- info = get_model_info(model_name, MODEL_DIR)
767
- model_info.append(info)
768
-
769
- return {
770
- "status": "success",
771
- "num_models": len(models),
772
- "models": model_info
773
- }
774
- except Exception as e:
775
- logger.error(f"Error listing models: {e}")
776
- return {"status": "error", "message": str(e)}
777
-
778
-
779
- # ============================================================================
780
- # Utility Tools
781
- # ============================================================================
782
-
783
- @mcp.tool()
784
- async def get_pseudopotential_corrections() -> Dict[str, Any]:
785
- """
786
- Get pseudopotential corrections for Materials Project compatibility.
787
-
788
- These corrections are needed when comparing energies between GNoME
789
- and Materials Project calculations.
790
-
791
- Returns:
792
- Dictionary of elemental corrections (eV/atom)
793
- """
794
- from data_utils import PP_CORRECTIONS
795
-
796
- return {
797
- "status": "success",
798
- "description": "Pseudopotential corrections for elements where GNoME and MP use different pseudopotentials",
799
- "corrections": PP_CORRECTIONS,
800
- "units": "eV/atom"
801
- }
802
-
803
-
804
- @mcp.tool()
805
- async def download_dataset(
806
- include_structures: bool = False
807
- ) -> Dict[str, Any]:
808
- """
809
- Download the GNoME dataset files to the server.
810
-
811
- Downloads summary CSVs and optionally structure archives.
812
- Data is persisted in the server's /data directory.
813
-
814
- Args:
815
- include_structures: Whether to also download structure archives (~GB)
816
-
817
- Returns:
818
- Status of download operation
819
- """
820
- try:
821
- dm = get_data_manager(DATA_DIR)
822
-
823
- downloaded = []
824
-
825
- # Download summary files
826
- gnome_path, external_path = dm.download_summary_data()
827
- downloaded.extend([str(gnome_path), str(external_path)])
828
-
829
- # Download MP snapshot
830
- mp_path = dm.download_mp_snapshot()
831
- downloaded.append(str(mp_path))
832
-
833
- # Download r2scan
834
- r2scan_path = dm.download_r2scan_data()
835
- downloaded.append(str(r2scan_path))
836
-
837
- if include_structures:
838
- struct_path = dm.download_structure_archive("by_reduced_formula")
839
- downloaded.append(str(struct_path))
840
-
841
- return {
842
- "status": "success",
843
- "downloaded_files": downloaded,
844
- "data_directory": DATA_DIR
845
- }
846
- except Exception as e:
847
- logger.error(f"Error downloading dataset: {e}")
848
- return {"status": "error", "message": str(e)}
849
-
850
-
851
- @mcp.tool()
852
- async def check_data_status() -> Dict[str, Any]:
853
- """
854
- Check the status of downloaded dataset files on the server.
855
-
856
- Returns information about which files are available and their sizes.
857
- Use this to verify data is properly downloaded before querying.
858
-
859
- Returns:
860
- Status of each dataset file (exists, size, path)
861
- """
862
- import os
863
-
864
- files_to_check = {
865
- "gnome_summary": "stable_materials_summary.csv",
866
- "external_summary": "external_materials_summary.csv",
867
- "mp_snapshot": "mp_snapshot_summary.csv",
868
- "r2scan": "stable_materials_r2scan.csv",
869
- "structures": "by_reduced_formula.zip",
870
- "a2c_data": "a2c_supporting_data.json"
871
- }
872
-
873
- file_status = {}
874
- total_size = 0
875
-
876
- for key, filename in files_to_check.items():
877
- filepath = os.path.join(DATA_DIR, filename)
878
- if os.path.exists(filepath):
879
- size = os.path.getsize(filepath)
880
- total_size += size
881
- file_status[key] = {
882
- "exists": True,
883
- "path": filepath,
884
- "size_mb": round(size / (1024 * 1024), 2)
885
- }
886
- else:
887
- file_status[key] = {
888
- "exists": False,
889
- "path": filepath,
890
- "size_mb": 0
891
- }
892
-
893
- # Check if core data is ready
894
- core_ready = (
895
- file_status["gnome_summary"]["exists"] and
896
- file_status["external_summary"]["exists"]
897
- )
898
-
899
- return {
900
- "status": "success",
901
- "data_directory": DATA_DIR,
902
- "core_data_ready": core_ready,
903
- "total_size_mb": round(total_size / (1024 * 1024), 2),
904
- "files": file_status,
905
- "message": "Core data is ready for queries" if core_ready else "Please call download_dataset() first"
906
- }
907
-
908
-
909
- # ============================================================================
910
- # Server Entry Point
911
- # ============================================================================
912
-
913
- if __name__ == "__main__":
914
- # Run the server
915
- mcp.run(transport="sse")
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Google LLC (Original code), Modified for MCP Service
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ """
10
+ GNoME Materials Discovery MCP Server
11
+
12
+ This is the main MCP server implementation providing tools for:
13
+ - Dataset access and querying
14
+ - Decomposition energy calculation
15
+ - Phase diagram analysis
16
+ - Crystal structure operations
17
+ - Air stability analysis
18
+ - Model inference
19
+ """
20
+
21
+ import os
22
+ import json
23
+ import logging
24
+ from typing import Optional, List, Dict, Any
25
+ from contextlib import asynccontextmanager
26
+ from collections.abc import AsyncIterator
27
+
28
+ from mcp.server.fastmcp import FastMCP
29
+
30
+ # Import local modules
31
+ from data_utils import DataManager, get_data_manager
32
+ from phase_diagram_utils import (
33
+ compute_decomposition_energy,
34
+ build_phase_diagram,
35
+ compute_air_stability,
36
+ compare_with_materials_project,
37
+ find_competing_phases
38
+ )
39
+ from model_utils import (
40
+ ModelLoader,
41
+ atoms_to_graph,
42
+ get_model_info,
43
+ get_nequip_default_config,
44
+ get_gnome_default_config,
45
+ StructureMatcher
46
+ )
47
+
48
+ # Configure logging
49
+ logging.basicConfig(level=logging.INFO)
50
+ logger = logging.getLogger(__name__)
51
+
52
+ # Data directory configuration - must match Dockerfile ENV
53
+ DATA_DIR = os.environ.get("GNOME_DATA_DIR", "/app/gnome_data")
54
+ MODEL_DIR = os.environ.get("GNOME_MODEL_DIR", "/app/models")
55
+
56
+ # Ensure directories exist at module load
57
+ os.makedirs(DATA_DIR, exist_ok=True)
58
+ os.makedirs(MODEL_DIR, exist_ok=True)
59
+
60
+
61
+ @asynccontextmanager
62
+ async def app_lifespan(server: FastMCP) -> AsyncIterator[Dict[str, Any]]:
63
+ """
64
+ Application lifespan context manager.
65
+ Initializes data manager and model loader.
66
+ """
67
+ logger.info("Initializing GNoME Materials Discovery MCP Server...")
68
+
69
+ # Initialize data manager
70
+ data_manager = DataManager(DATA_DIR)
71
+ model_loader = ModelLoader(MODEL_DIR)
72
+
73
+ logger.info(f"Data directory: {DATA_DIR}")
74
+ logger.info(f"Model directory: {MODEL_DIR}")
75
+
76
+ yield {
77
+ "data_manager": data_manager,
78
+ "model_loader": model_loader
79
+ }
80
+
81
+ # Cleanup
82
+ logger.info("Shutting down GNoME MCP Server...")
83
+ data_manager.close()
84
+
85
+
86
+ # Create FastMCP server
87
+ mcp = FastMCP("GNoME Materials Discovery")
88
+
89
+
90
+ # ============================================================================
91
+ # Dataset Access Tools
92
+ # ============================================================================
93
+
94
+ @mcp.tool()
95
+ async def get_dataset_statistics() -> Dict[str, Any]:
96
+ """
97
+ Get statistics about the GNoME materials discovery dataset.
98
+
99
+ Returns information about:
100
+ - Total number of materials
101
+ - Unique compositions and formulas
102
+ - Crystal system distribution
103
+ - Average formation energy
104
+ - Element coverage
105
+ """
106
+ try:
107
+ dm = get_data_manager(DATA_DIR)
108
+ stats = dm.get_statistics()
109
+ return {
110
+ "status": "success",
111
+ "data": stats
112
+ }
113
+ except Exception as e:
114
+ logger.error(f"Error getting statistics: {e}")
115
+ return {"status": "error", "message": str(e)}
116
+
117
+
118
+ @mcp.tool()
119
+ async def query_materials(
120
+ composition: Optional[str] = None,
121
+ elements: Optional[str] = None,
122
+ space_group: Optional[int] = None,
123
+ crystal_system: Optional[str] = None,
124
+ min_bandgap: Optional[float] = None,
125
+ max_bandgap: Optional[float] = None,
126
+ max_decomposition_energy: Optional[float] = None,
127
+ limit: int = 50
128
+ ) -> Dict[str, Any]:
129
+ """
130
+ Query materials from the GNoME dataset with various filters.
131
+
132
+ Args:
133
+ composition: Exact composition to match (e.g., "Li2O")
134
+ elements: Comma-separated list of elements that must be present (e.g., "Li,O")
135
+ space_group: Space group number to filter by
136
+ crystal_system: Crystal system name (e.g., "cubic", "hexagonal")
137
+ min_bandgap: Minimum bandgap value in eV
138
+ max_bandgap: Maximum bandgap value in eV
139
+ max_decomposition_energy: Maximum decomposition energy per atom in eV
140
+ limit: Maximum number of results to return (default: 50)
141
+
142
+ Returns:
143
+ List of matching materials with their properties
144
+ """
145
+ try:
146
+ dm = get_data_manager(DATA_DIR)
147
+
148
+ elements_list = None
149
+ if elements:
150
+ elements_list = [e.strip() for e in elements.split(",")]
151
+
152
+ results = dm.query_by_composition(
153
+ composition=composition,
154
+ elements=elements_list,
155
+ space_group=space_group,
156
+ crystal_system=crystal_system,
157
+ min_bandgap=min_bandgap,
158
+ max_bandgap=max_bandgap,
159
+ max_decomposition_energy=max_decomposition_energy,
160
+ limit=limit
161
+ )
162
+
163
+ # Convert to list of dicts
164
+ materials = results.to_dict(orient='records')
165
+
166
+ return {
167
+ "status": "success",
168
+ "count": len(materials),
169
+ "materials": materials
170
+ }
171
+ except Exception as e:
172
+ logger.error(f"Error querying materials: {e}")
173
+ return {"status": "error", "message": str(e)}
174
+
175
+
176
+ @mcp.tool()
177
+ async def get_material_by_id(material_id: str) -> Dict[str, Any]:
178
+ """
179
+ Get detailed information about a specific material by its ID.
180
+
181
+ Args:
182
+ material_id: The unique MaterialId from the GNoME dataset
183
+
184
+ Returns:
185
+ Complete material information including structure and properties
186
+ """
187
+ try:
188
+ dm = get_data_manager(DATA_DIR)
189
+ material = dm.get_crystal_by_id(material_id)
190
+
191
+ if material is None:
192
+ return {"status": "error", "message": f"Material {material_id} not found"}
193
+
194
+ return {
195
+ "status": "success",
196
+ "material": material.to_dict()
197
+ }
198
+ except Exception as e:
199
+ logger.error(f"Error getting material: {e}")
200
+ return {"status": "error", "message": str(e)}
201
+
202
+
203
+ @mcp.tool()
204
+ async def get_random_material(
205
+ crystal_system: Optional[str] = None,
206
+ n_elements: Optional[int] = None
207
+ ) -> Dict[str, Any]:
208
+ """
209
+ Get a random material from the GNoME dataset.
210
+
211
+ Args:
212
+ crystal_system: Optional filter by crystal system
213
+ n_elements: Optional filter by number of elements (e.g., 2 for binary, 3 for ternary)
214
+
215
+ Returns:
216
+ Random material information
217
+ """
218
+ try:
219
+ dm = get_data_manager(DATA_DIR)
220
+ crystals = dm.load_gnome_crystals()
221
+
222
+ if crystal_system:
223
+ crystals = crystals[crystals['Crystal System'] == crystal_system]
224
+
225
+ if n_elements:
226
+ crystals = crystals[crystals['Chemical System'].map(len) == n_elements]
227
+
228
+ if len(crystals) == 0:
229
+ return {"status": "error", "message": "No materials match the criteria"}
230
+
231
+ sample = crystals.sample(1).iloc[0]
232
+
233
+ return {
234
+ "status": "success",
235
+ "material": sample.to_dict()
236
+ }
237
+ except Exception as e:
238
+ logger.error(f"Error getting random material: {e}")
239
+ return {"status": "error", "message": str(e)}
240
+
241
+
242
+ # ============================================================================
243
+ # Phase Diagram and Stability Tools
244
+ # ============================================================================
245
+
246
+ @mcp.tool()
247
+ async def calculate_decomposition_energy(
248
+ composition: str,
249
+ energy: float
250
+ ) -> Dict[str, Any]:
251
+ """
252
+ Calculate the decomposition energy of a material relative to the GNoME convex hull.
253
+
254
+ This determines whether a material is thermodynamically stable or metastable.
255
+ A negative or zero decomposition energy indicates stability.
256
+
257
+ Args:
258
+ composition: Chemical composition (e.g., "LiFePO4", "Li2O")
259
+ energy: Total corrected energy from DFT calculation in eV
260
+
261
+ Returns:
262
+ Decomposition energy and decomposition products
263
+ """
264
+ try:
265
+ import pymatgen as mg
266
+
267
+ dm = get_data_manager(DATA_DIR)
268
+ all_crystals = dm.load_all_crystals()
269
+ grouped = dm.get_grouped_entries()
270
+
271
+ # Get chemical system from composition
272
+ comp = mg.core.Composition(composition)
273
+ chemsys = [str(el) for el in comp.elements]
274
+
275
+ result = compute_decomposition_energy(
276
+ composition=composition,
277
+ energy=energy,
278
+ chemsys=chemsys,
279
+ grouped_entries=grouped,
280
+ all_crystals=all_crystals
281
+ )
282
+
283
+ return {
284
+ "status": "success",
285
+ **result
286
+ }
287
+ except Exception as e:
288
+ logger.error(f"Error calculating decomposition energy: {e}")
289
+ return {"status": "error", "message": str(e)}
290
+
291
+
292
+ @mcp.tool()
293
+ async def get_phase_diagram(
294
+ elements: str
295
+ ) -> Dict[str, Any]:
296
+ """
297
+ Build and analyze the phase diagram for a chemical system.
298
+
299
+ Args:
300
+ elements: Comma or dash separated list of elements (e.g., "Li,Fe,P,O" or "Li-Fe-P-O")
301
+
302
+ Returns:
303
+ Phase diagram information including stable and unstable entries
304
+ """
305
+ try:
306
+ import re
307
+
308
+ dm = get_data_manager(DATA_DIR)
309
+ all_crystals = dm.load_all_crystals()
310
+ grouped = dm.get_grouped_entries()
311
+
312
+ # Parse elements
313
+ chemsys = re.split(r'[\s,\-]+', elements)
314
+ chemsys = [e.strip() for e in chemsys if e.strip()]
315
+
316
+ result = build_phase_diagram(
317
+ chemsys=chemsys,
318
+ grouped_entries=grouped,
319
+ all_crystals=all_crystals
320
+ )
321
+
322
+ return {
323
+ "status": "success",
324
+ **result
325
+ }
326
+ except Exception as e:
327
+ logger.error(f"Error building phase diagram: {e}")
328
+ return {"status": "error", "message": str(e)}
329
+
330
+
331
+ @mcp.tool()
332
+ async def calculate_air_stability(
333
+ composition: str,
334
+ energy: float,
335
+ temperature: float = 300.0,
336
+ oxygen_pressure: float = 21200.0
337
+ ) -> Dict[str, Any]:
338
+ """
339
+ Calculate the air stability of a material.
340
+
341
+ Analyzes stability with respect to:
342
+ - Oxygen (via grand potential phase diagram)
343
+ - Carbon dioxide (CO2 reactivity)
344
+ - Water (H2O reactivity)
345
+
346
+ Args:
347
+ composition: Chemical composition (e.g., "Li3N", "NaCl")
348
+ energy: Total corrected energy in eV
349
+ temperature: Temperature in Kelvin (default: 300K)
350
+ oxygen_pressure: Oxygen partial pressure in Pa (default: 21200 Pa, ambient)
351
+
352
+ Returns:
353
+ Air stability analysis results
354
+ """
355
+ try:
356
+ import pymatgen as mg
357
+
358
+ dm = get_data_manager(DATA_DIR)
359
+ all_crystals = dm.load_all_crystals()
360
+ grouped = dm.get_grouped_entries()
361
+
362
+ comp = mg.core.Composition(composition)
363
+ chemsys = [str(el) for el in comp.elements]
364
+
365
+ result = compute_air_stability(
366
+ composition=composition,
367
+ energy=energy,
368
+ chemsys=chemsys,
369
+ grouped_entries=grouped,
370
+ all_crystals=all_crystals,
371
+ temperature=temperature,
372
+ oxygen_pressure=oxygen_pressure
373
+ )
374
+
375
+ return {
376
+ "status": "success",
377
+ **result
378
+ }
379
+ except Exception as e:
380
+ logger.error(f"Error calculating air stability: {e}")
381
+ return {"status": "error", "message": str(e)}
382
+
383
+
384
+ @mcp.tool()
385
+ async def compare_gnome_with_mp(
386
+ elements: str
387
+ ) -> Dict[str, Any]:
388
+ """
389
+ Compare GNoME phase diagram with Materials Project for a chemical system.
390
+
391
+ Identifies:
392
+ - New stable phases discovered by GNoME
393
+ - Phases only in Materials Project
394
+ - Common stable phases
395
+
396
+ Args:
397
+ elements: Comma or dash separated list of elements
398
+
399
+ Returns:
400
+ Comparison results between GNoME and Materials Project
401
+ """
402
+ try:
403
+ import re
404
+
405
+ dm = get_data_manager(DATA_DIR)
406
+
407
+ # Load required data
408
+ all_crystals = dm.load_all_crystals()
409
+ mp_crystals = dm.load_mp_crystals()
410
+
411
+ gnome_grouped = dm.get_grouped_entries()
412
+
413
+ # Create MP grouped entries
414
+ required_columns = [
415
+ 'Composition', 'NSites', 'Corrected Energy',
416
+ 'Formation Energy Per Atom', 'Chemical System'
417
+ ]
418
+ mp_minimal = mp_crystals[required_columns]
419
+ mp_grouped = mp_minimal.groupby('Chemical System')
420
+
421
+ # Parse elements
422
+ chemsys = re.split(r'[\s,\-]+', elements)
423
+ chemsys = [e.strip() for e in chemsys if e.strip()]
424
+
425
+ result = compare_with_materials_project(
426
+ chemsys=chemsys,
427
+ grouped_entries=gnome_grouped,
428
+ mp_grouped_entries=mp_grouped,
429
+ all_crystals=all_crystals,
430
+ mp_crystals=mp_crystals
431
+ )
432
+
433
+ return {
434
+ "status": "success",
435
+ **result
436
+ }
437
+ except Exception as e:
438
+ logger.error(f"Error comparing with MP: {e}")
439
+ return {"status": "error", "message": str(e)}
440
+
441
+
442
+ @mcp.tool()
443
+ async def find_competing_phases_for_composition(
444
+ composition: str,
445
+ n_phases: int = 5
446
+ ) -> Dict[str, Any]:
447
+ """
448
+ Find competing phases for a given composition.
449
+
450
+ Identifies the most thermodynamically favorable phases in the same
451
+ chemical space that compete with the given composition.
452
+
453
+ Args:
454
+ composition: Chemical composition to analyze
455
+ n_phases: Number of competing phases to return (default: 5)
456
+
457
+ Returns:
458
+ List of competing phases with their properties
459
+ """
460
+ try:
461
+ import pymatgen as mg
462
+
463
+ dm = get_data_manager(DATA_DIR)
464
+ all_crystals = dm.load_all_crystals()
465
+ grouped = dm.get_grouped_entries()
466
+
467
+ comp = mg.core.Composition(composition)
468
+ chemsys = [str(el) for el in comp.elements]
469
+
470
+ phases = find_competing_phases(
471
+ composition=composition,
472
+ chemsys=chemsys,
473
+ grouped_entries=grouped,
474
+ all_crystals=all_crystals,
475
+ n_phases=n_phases
476
+ )
477
+
478
+ return {
479
+ "status": "success",
480
+ "composition": composition,
481
+ "chemical_system": chemsys,
482
+ "competing_phases": phases
483
+ }
484
+ except Exception as e:
485
+ logger.error(f"Error finding competing phases: {e}")
486
+ return {"status": "error", "message": str(e)}
487
+
488
+
489
+ # ============================================================================
490
+ # Structure Tools
491
+ # ============================================================================
492
+
493
+ @mcp.tool()
494
+ async def get_structure(
495
+ reduced_formula: str,
496
+ output_format: str = "json"
497
+ ) -> Dict[str, Any]:
498
+ """
499
+ Get crystal structure for a given reduced formula.
500
+
501
+ Args:
502
+ reduced_formula: Reduced chemical formula (e.g., "LiFePO4", "TiO2")
503
+ output_format: Output format - "json", "cif", or "poscar"
504
+
505
+ Returns:
506
+ Crystal structure data in the requested format
507
+ """
508
+ try:
509
+ dm = get_data_manager(DATA_DIR)
510
+ atoms, structure = dm.load_structure(reduced_formula)
511
+
512
+ if output_format == "cif":
513
+ return {
514
+ "status": "success",
515
+ "format": "cif",
516
+ "data": structure.to(fmt="cif")
517
+ }
518
+ elif output_format == "poscar":
519
+ return {
520
+ "status": "success",
521
+ "format": "poscar",
522
+ "data": structure.to(fmt="poscar")
523
+ }
524
+ else: # json
525
+ return {
526
+ "status": "success",
527
+ "format": "json",
528
+ "data": {
529
+ "formula": structure.formula,
530
+ "reduced_formula": structure.composition.reduced_formula,
531
+ "lattice": {
532
+ "a": structure.lattice.a,
533
+ "b": structure.lattice.b,
534
+ "c": structure.lattice.c,
535
+ "alpha": structure.lattice.alpha,
536
+ "beta": structure.lattice.beta,
537
+ "gamma": structure.lattice.gamma,
538
+ "volume": structure.lattice.volume,
539
+ "matrix": structure.lattice.matrix.tolist()
540
+ },
541
+ "sites": [
542
+ {
543
+ "species": str(site.specie),
544
+ "coords": site.frac_coords.tolist(),
545
+ "cart_coords": site.coords.tolist()
546
+ }
547
+ for site in structure.sites
548
+ ],
549
+ "n_sites": len(structure),
550
+ "space_group": structure.get_space_group_info()[0]
551
+ }
552
+ }
553
+ except Exception as e:
554
+ logger.error(f"Error getting structure: {e}")
555
+ return {"status": "error", "message": str(e)}
556
+
557
+
558
+ @mcp.tool()
559
+ async def compare_structures(
560
+ formula1: str,
561
+ formula2: str,
562
+ ltol: float = 0.2,
563
+ stol: float = 0.3,
564
+ angle_tol: float = 5.0
565
+ ) -> Dict[str, Any]:
566
+ """
567
+ Compare two crystal structures from the GNoME dataset.
568
+
569
+ Uses pymatgen's StructureMatcher to determine if structures are equivalent.
570
+
571
+ Args:
572
+ formula1: First reduced formula
573
+ formula2: Second reduced formula
574
+ ltol: Length tolerance for matching
575
+ stol: Site tolerance for matching
576
+ angle_tol: Angle tolerance in degrees
577
+
578
+ Returns:
579
+ Comparison results including whether structures match
580
+ """
581
+ try:
582
+ dm = get_data_manager(DATA_DIR)
583
+ _, structure1 = dm.load_structure(formula1)
584
+ _, structure2 = dm.load_structure(formula2)
585
+
586
+ matcher = StructureMatcher(ltol=ltol, stol=stol, angle_tol=angle_tol)
587
+
588
+ is_match = matcher.fit(structure1, structure2)
589
+ rms_result = matcher.get_rms_dist(structure1, structure2)
590
+
591
+ return {
592
+ "status": "success",
593
+ "formula1": formula1,
594
+ "formula2": formula2,
595
+ "structures_match": is_match,
596
+ "rms_dist": rms_result[0] if rms_result else None,
597
+ "max_dist": rms_result[1] if rms_result else None,
598
+ "tolerances": {
599
+ "ltol": ltol,
600
+ "stol": stol,
601
+ "angle_tol": angle_tol
602
+ }
603
+ }
604
+ except Exception as e:
605
+ logger.error(f"Error comparing structures: {e}")
606
+ return {"status": "error", "message": str(e)}
607
+
608
+
609
+ # ============================================================================
610
+ # r²SCAN Validation Tools
611
+ # ============================================================================
612
+
613
+ @mcp.tool()
614
+ async def get_r2scan_data(
615
+ composition: Optional[str] = None,
616
+ limit: int = 50
617
+ ) -> Dict[str, Any]:
618
+ """
619
+ Get r²SCAN validation data for materials.
620
+
621
+ r²SCAN is a more accurate DFT functional used to validate GNoME predictions.
622
+
623
+ Args:
624
+ composition: Optional composition filter
625
+ limit: Maximum number of results
626
+
627
+ Returns:
628
+ r²SCAN calculated energies and stability metrics
629
+ """
630
+ try:
631
+ dm = get_data_manager(DATA_DIR)
632
+ r2scan = dm.load_r2scan_crystals()
633
+
634
+ if composition:
635
+ r2scan = r2scan[r2scan['Composition'] == composition]
636
+
637
+ results = r2scan.head(limit).to_dict(orient='records')
638
+
639
+ return {
640
+ "status": "success",
641
+ "count": len(results),
642
+ "data": results
643
+ }
644
+ except Exception as e:
645
+ logger.error(f"Error getting r2scan data: {e}")
646
+ return {"status": "error", "message": str(e)}
647
+
648
+
649
+ # ============================================================================
650
+ # a2c Crystal Structure Prediction Tools
651
+ # ============================================================================
652
+
653
+ @mcp.tool()
654
+ async def get_a2c_supporting_data() -> Dict[str, Any]:
655
+ """
656
+ Get a2c (amorphous-to-crystalline) structure prediction supporting data.
657
+
658
+ The a2c pipeline discovers crystal structures by relaxing amorphous
659
+ configurations using GNoME force fields.
660
+
661
+ Returns:
662
+ List of available a2c campaigns with their chemical systems
663
+ """
664
+ try:
665
+ dm = get_data_manager(DATA_DIR)
666
+ a2c_data = dm.load_a2c_data()
667
+
668
+ campaigns = []
669
+ for key, data in a2c_data.items():
670
+ campaigns.append({
671
+ "chemical_system": key,
672
+ "has_amorphous_structure": "amorphous_structure" in data,
673
+ "num_initial_structures": len(data.get("a2c_initial_structures", [])),
674
+ "num_matches": len(data.get("a2c_match_after_relax_example", []))
675
+ })
676
+
677
+ return {
678
+ "status": "success",
679
+ "num_campaigns": len(campaigns),
680
+ "campaigns": campaigns
681
+ }
682
+ except Exception as e:
683
+ logger.error(f"Error getting a2c data: {e}")
684
+ return {"status": "error", "message": str(e)}
685
+
686
+
687
+ @mcp.tool()
688
+ async def get_a2c_campaign_details(
689
+ chemical_system: str
690
+ ) -> Dict[str, Any]:
691
+ """
692
+ Get detailed data for a specific a2c campaign.
693
+
694
+ Args:
695
+ chemical_system: Chemical system name (e.g., "Al2O3", "SiO2")
696
+
697
+ Returns:
698
+ Detailed a2c campaign data including structures
699
+ """
700
+ try:
701
+ dm = get_data_manager(DATA_DIR)
702
+ a2c_data = dm.load_a2c_data()
703
+
704
+ if chemical_system not in a2c_data:
705
+ return {
706
+ "status": "error",
707
+ "message": f"Chemical system {chemical_system} not found in a2c data"
708
+ }
709
+
710
+ data = a2c_data[chemical_system]
711
+
712
+ matches = []
713
+ for match in data.get("a2c_match_after_relax_example", []):
714
+ matches.append({
715
+ "index": match.get("index_in_a2c_initial_structures"),
716
+ "formula": match.get("formula"),
717
+ "has_ff_relaxed": "relaxed_ff" in match,
718
+ "has_dft_relaxed": "relaxed_dft" in match
719
+ })
720
+
721
+ return {
722
+ "status": "success",
723
+ "chemical_system": chemical_system,
724
+ "amorphous_structure": data.get("amorphous_structure", "")[:500] + "...",
725
+ "num_initial_structures": len(data.get("a2c_initial_structures", [])),
726
+ "matches": matches
727
+ }
728
+ except Exception as e:
729
+ logger.error(f"Error getting a2c campaign details: {e}")
730
+ return {"status": "error", "message": str(e)}
731
+
732
+
733
+ # ============================================================================
734
+ # Model Information Tools
735
+ # ============================================================================
736
+
737
+ @mcp.tool()
738
+ async def get_model_configurations() -> Dict[str, Any]:
739
+ """
740
+ Get default configurations for GNoME and NequIP models.
741
+
742
+ Returns:
743
+ Default configuration dictionaries for both model architectures
744
+ """
745
+ return {
746
+ "status": "success",
747
+ "nequip_config": get_nequip_default_config(),
748
+ "gnome_config": get_gnome_default_config()
749
+ }
750
+
751
+
752
+ @mcp.tool()
753
+ async def list_available_models() -> Dict[str, Any]:
754
+ """
755
+ List available pre-trained models.
756
+
757
+ Returns:
758
+ List of available model names and their information
759
+ """
760
+ try:
761
+ loader = ModelLoader(MODEL_DIR)
762
+ models = loader.get_available_models()
763
+
764
+ model_info = []
765
+ for model_name in models:
766
+ info = get_model_info(model_name, MODEL_DIR)
767
+ model_info.append(info)
768
+
769
+ return {
770
+ "status": "success",
771
+ "num_models": len(models),
772
+ "models": model_info
773
+ }
774
+ except Exception as e:
775
+ logger.error(f"Error listing models: {e}")
776
+ return {"status": "error", "message": str(e)}
777
+
778
+
779
+ # ============================================================================
780
+ # Utility Tools
781
+ # ============================================================================
782
+
783
+ @mcp.tool()
784
+ async def get_pseudopotential_corrections() -> Dict[str, Any]:
785
+ """
786
+ Get pseudopotential corrections for Materials Project compatibility.
787
+
788
+ These corrections are needed when comparing energies between GNoME
789
+ and Materials Project calculations.
790
+
791
+ Returns:
792
+ Dictionary of elemental corrections (eV/atom)
793
+ """
794
+ from data_utils import PP_CORRECTIONS
795
+
796
+ return {
797
+ "status": "success",
798
+ "description": "Pseudopotential corrections for elements where GNoME and MP use different pseudopotentials",
799
+ "corrections": PP_CORRECTIONS,
800
+ "units": "eV/atom"
801
+ }
802
+
803
+
804
+ @mcp.tool()
805
+ async def download_dataset(
806
+ include_structures: bool = False
807
+ ) -> Dict[str, Any]:
808
+ """
809
+ Download the GNoME dataset files to the server.
810
+
811
+ Downloads summary CSVs and optionally structure archives.
812
+ Data is persisted in the server's /data directory.
813
+
814
+ Args:
815
+ include_structures: Whether to also download structure archives (~GB)
816
+
817
+ Returns:
818
+ Status of download operation
819
+ """
820
+ try:
821
+ dm = get_data_manager(DATA_DIR)
822
+
823
+ downloaded = []
824
+
825
+ # Download summary files
826
+ gnome_path, external_path = dm.download_summary_data()
827
+ downloaded.extend([str(gnome_path), str(external_path)])
828
+
829
+ # Download MP snapshot
830
+ mp_path = dm.download_mp_snapshot()
831
+ downloaded.append(str(mp_path))
832
+
833
+ # Download r2scan
834
+ r2scan_path = dm.download_r2scan_data()
835
+ downloaded.append(str(r2scan_path))
836
+
837
+ if include_structures:
838
+ struct_path = dm.download_structure_archive("by_reduced_formula")
839
+ downloaded.append(str(struct_path))
840
+
841
+ return {
842
+ "status": "success",
843
+ "downloaded_files": downloaded,
844
+ "data_directory": DATA_DIR
845
+ }
846
+ except Exception as e:
847
+ logger.error(f"Error downloading dataset: {e}")
848
+ return {"status": "error", "message": str(e)}
849
+
850
+
851
+ @mcp.tool()
852
+ async def check_data_status() -> Dict[str, Any]:
853
+ """
854
+ Check the status of downloaded dataset files on the server.
855
+
856
+ Returns information about which files are available and their sizes.
857
+ Use this to verify data is properly downloaded before querying.
858
+
859
+ Returns:
860
+ Status of each dataset file (exists, size, path)
861
+ """
862
+ import os
863
+
864
+ files_to_check = {
865
+ "gnome_summary": "stable_materials_summary.csv",
866
+ "external_summary": "external_materials_summary.csv",
867
+ "mp_snapshot": "mp_snapshot_summary.csv",
868
+ "r2scan": "stable_materials_r2scan.csv",
869
+ "structures": "by_reduced_formula.zip",
870
+ "a2c_data": "a2c_supporting_data.json"
871
+ }
872
+
873
+ file_status = {}
874
+ total_size = 0
875
+
876
+ for key, filename in files_to_check.items():
877
+ filepath = os.path.join(DATA_DIR, filename)
878
+ if os.path.exists(filepath):
879
+ size = os.path.getsize(filepath)
880
+ total_size += size
881
+ file_status[key] = {
882
+ "exists": True,
883
+ "path": filepath,
884
+ "size_mb": round(size / (1024 * 1024), 2)
885
+ }
886
+ else:
887
+ file_status[key] = {
888
+ "exists": False,
889
+ "path": filepath,
890
+ "size_mb": 0
891
+ }
892
+
893
+ # Check if core data is ready
894
+ core_ready = (
895
+ file_status["gnome_summary"]["exists"] and
896
+ file_status["external_summary"]["exists"]
897
+ )
898
+
899
+ return {
900
+ "status": "success",
901
+ "data_directory": DATA_DIR,
902
+ "core_data_ready": core_ready,
903
+ "total_size_mb": round(total_size / (1024 * 1024), 2),
904
+ "files": file_status,
905
+ "message": "Core data is ready for queries" if core_ready else "Please call download_dataset() first"
906
+ }
907
+
908
+
909
+ # ============================================================================
910
+ # Server Entry Point
911
+ # ============================================================================
912
+
913
+ def create_sse_app():
914
+ """Create SSE app for mounting in Starlette."""
915
+ return mcp.sse_app()
916
+
917
+
918
+ def create_streamable_app():
919
+ """Create streamable HTTP app for mounting in Starlette."""
920
+ return mcp.streamable_http_app()
921
+
922
+
923
+ if __name__ == "__main__":
924
+ # Run the server with stdio transport for local testing
925
+ mcp.run()
requirements.txt CHANGED
@@ -2,8 +2,7 @@
2
  # Requirements for HuggingFace Spaces deployment
3
 
4
  # MCP Framework
5
- mcp>=1.0.0
6
- fastmcp>=0.1.0
7
 
8
  # Web server
9
  uvicorn>=0.29.0
 
2
  # Requirements for HuggingFace Spaces deployment
3
 
4
  # MCP Framework
5
+ mcp[cli]>=1.0.0
 
6
 
7
  # Web server
8
  uvicorn>=0.29.0