# TRIADS / app.py
# (HuggingFace Spaces page header removed — author Rtx09,
#  commit b90bf21 "fix: clean up phonons tab info")
"""
TRIADS — Multi-Benchmark Materials Property Prediction
HuggingFace Gradio App (Production Redux)
Covers all 6 Matbench benchmarks:
1. matbench_steels — Yield Strength (MPa)
2. matbench_expt_gap — Band Gap (eV)
3. matbench_ismetal — Metallicity (ROC-AUC)
4. matbench_glass — Glass Forming Ability
5. matbench_jdft2d — Exfoliation Energy (meV/atom)
6. matbench_phonons — Peak Phonon Frequency (cm⁻¹)
"""
import os
import warnings
import urllib.request
import json
import traceback
warnings.filterwarnings("ignore")
import numpy as np
import torch
import torch.nn as nn
import gradio as gr
from huggingface_hub import hf_hub_download
# ─────────────────────────────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────────────────────────────
REPO_ID = "Rtx09/TRIADS" # Used only if local weights are missing
BENCHMARK_INFO = {
"steels": {
"title": "🔩 Steel Yield Strength",
"description": "Predict yield strength (MPa) of steel alloys from composition.",
"unit": "MPa",
"example": "Fe0.7Cr0.15Ni0.15",
"examples": ["Fe0.7Cr0.15Ni0.15", "Fe0.8C0.02Mn0.1Si0.05Cr0.03", "Fe0.6Ni0.25Mo0.1Cr0.05"],
"task": "regression",
"result": "91.20 ± 12.23 MPa MAE (5-fold, 5-seed ensemble)",
},
"expt_gap": {
"title": "⚡ Experimental Band Gap",
"description": "Predict experimental electronic band gap (eV) from composition.",
"unit": "eV",
"example": "TiO2",
"examples": ["TiO2", "GaN", "ZnO", "Si", "CdS"],
"task": "regression",
"result": "0.3068 ± 0.0082 eV MAE (5-fold, composition-only)",
},
"ismetal": {
"title": "🔮 Metallicity Classifier",
"description": "Predict whether a material is metallic or non-metallic from composition.",
"unit": "probability (1 = metal)",
"example": "Cu",
"examples": ["Cu", "SiO2", "Fe3O4", "BaTiO3", "Al"],
"task": "classification",
"result": "0.9655 ± 0.0029 ROC-AUC (5-fold, composition-only)",
},
"glass": {
"title": "🪟 Glass Forming Ability",
"description": "Predict metallic glass forming ability from alloy composition.",
"unit": "probability (1 = glass former)",
"example": "Cu46Zr54",
"examples": ["Cu46Zr54", "Fe80B20", "Al86Ni7La6Y1", "Pd40Cu30Ni10P20"],
"task": "classification",
"result": "0.9285 ± 0.0063 ROC-AUC (5-fold, 5-seed ensemble)",
},
"jdft2d": {
"title": "📐 Exfoliation Energy",
"description": "Predict exfoliation energy (meV/atom) of 2D materials from structure+composition.",
"unit": "meV/atom",
"example": "MoS2",
"examples": ["MoS2", "WSe2", "BN", "graphene (C)", "MoTe2"],
"task": "regression",
"result": "35.89 ± 12.40 meV/atom MAE (5-fold, 5-seed ensemble)",
},
"phonons": {
"title": "🎵 Phonon Peak Frequency",
"description": "Predict peak phonon frequency (cm⁻¹) from crystal structure.",
"unit": "cm⁻¹",
"example": "Si (diamond cubic)",
"examples": ["Si", "GaAs", "MgO", "BN (wurtzite)", "TiO2 (rutile)"],
"task": "regression",
"result": "41.91 ± 4.04 cm⁻¹ MAE (5-fold, gate-halt GraphTRIADS)",
},
}
# ─────────────────────────────────────────────────────────────────
# MODEL DEFINITIONS (inlined for self-contained app)
# ─────────────────────────────────────────────────────────────────
class DeepHybridTRM(nn.Module):
"""
HybridTRIADS — composition-only tasks.
Shared across: steels, expt_gap, ismetal, glass, jdft2d.
"""
def __init__(self, n_props=22, stat_dim=6, n_extra=0, mat2vec_dim=200,
d_attn=64, nhead=4, d_hidden=96, ff_dim=150,
dropout=0.2, max_steps=20, **kw):
super().__init__()
self.max_steps, self.D = max_steps, d_hidden
self.n_props, self.stat_dim, self.n_extra = n_props, stat_dim, n_extra
self.tok_proj = nn.Sequential(
nn.Linear(stat_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
self.m2v_proj = nn.Sequential(
nn.Linear(mat2vec_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
self.sa1 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
self.sa1_n = nn.LayerNorm(d_attn)
self.sa1_ff = nn.Sequential(
nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
nn.Linear(d_attn*2, d_attn))
self.sa1_fn = nn.LayerNorm(d_attn)
self.sa2 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
self.sa2_n = nn.LayerNorm(d_attn)
self.sa2_ff = nn.Sequential(
nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
nn.Linear(d_attn*2, d_attn))
self.sa2_fn = nn.LayerNorm(d_attn)
self.ca = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
self.ca_n = nn.LayerNorm(d_attn)
pool_in = d_attn + (n_extra if n_extra > 0 else 0)
self.pool = nn.Sequential(
nn.Linear(pool_in, d_hidden), nn.LayerNorm(d_hidden), nn.GELU())
self.z_up = nn.Sequential(
nn.Linear(d_hidden*3, ff_dim), nn.GELU(), nn.Dropout(dropout),
nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
self.y_up = nn.Sequential(
nn.Linear(d_hidden*2, ff_dim), nn.GELU(), nn.Dropout(dropout),
nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
self.head = nn.Linear(d_hidden, 1)
self._init()
def _init(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None: nn.init.zeros_(m.bias)
def _attention(self, x):
B = x.size(0)
mg_dim = self.n_props * self.stat_dim
if self.n_extra > 0:
extra = x[:, mg_dim:mg_dim + self.n_extra]
m2v = x[:, mg_dim + self.n_extra:]
else:
extra, m2v = None, x[:, mg_dim:]
tok = self.tok_proj(x[:, :mg_dim].view(B, self.n_props, self.stat_dim))
ctx = self.m2v_proj(m2v).unsqueeze(1)
tok = self.sa1_n(tok + self.sa1(tok, tok, tok)[0])
tok = self.sa1_fn(tok + self.sa1_ff(tok))
tok = self.sa2_n(tok + self.sa2(tok, tok, tok)[0])
tok = self.sa2_fn(tok + self.sa2_ff(tok))
tok = self.ca_n(tok + self.ca(tok, ctx, ctx)[0])
pooled = tok.mean(dim=1)
if extra is not None:
pooled = torch.cat([pooled, extra], dim=-1)
return self.pool(pooled)
def forward(self, x, deep_supervision=False):
B = x.size(0)
xp = self._attention(x)
z = torch.zeros(B, self.D, device=x.device)
y = torch.zeros(B, self.D, device=x.device)
step_preds = []
for _ in range(self.max_steps):
z = z + self.z_up(torch.cat([xp, y, z], -1))
y = y + self.y_up(torch.cat([y, z], -1))
step_preds.append(self.head(y).squeeze(1))
return step_preds if deep_supervision else step_preds[-1]
# ─────────────────────────────────────────────────────────────────
# FEATURIZER (composition-only, shared across HybridTRIADS tasks)
# ─────────────────────────────────────────────────────────────────
_featurizer_cache = {}
_mat2vec_cache = {}
_featurizer_err = None
def _get_featurizer():
"""Lazy-load the ExpandedFeaturizer (downloads Mat2Vec once)."""
global _featurizer_err
if "main" in _featurizer_cache:
return _featurizer_cache["main"]
try:
from matminer.featurizers.composition import (
ElementProperty, ElementFraction, Stoichiometry,
ValenceOrbital, IonProperty, BandCenter
)
from matminer.featurizers.base import MultipleFeaturizer
from gensim.models import Word2Vec
from sklearn.preprocessing import StandardScaler
GCS = "https://storage.googleapis.com/mat2vec/"
M2V_FILES = [
"pretrained_embeddings",
"pretrained_embeddings.wv.vectors.npy",
"pretrained_embeddings.trainables.syn1neg.npy",
]
# Use /tmp for writable cache if current dir is read-only
cache_dir = os.path.join(os.getcwd(), "mat2vec_cache")
try:
os.makedirs(cache_dir, exist_ok=True)
# Test write access
test_file = os.path.join(cache_dir, ".test")
with open(test_file, 'w') as f: f.write('1')
os.remove(test_file)
except Exception:
cache_dir = "/tmp/mat2vec_cache"
os.makedirs(cache_dir, exist_ok=True)
for f in M2V_FILES:
p = os.path.join(cache_dir, f)
if not os.path.exists(p):
print(f"Downloading {f}...")
urllib.request.urlretrieve(GCS + f, p)
# Magpie preset can fail if Figshare is down
try:
ep = ElementProperty.from_preset("magpie")
except Exception as e:
print(f"Magpie download failed, retrying once: {e}")
import time
time.sleep(2)
ep = ElementProperty.from_preset("magpie")
m2v = Word2Vec.load(os.path.join(cache_dir, "pretrained_embeddings"))
emb = {w: m2v.wv[w] for w in m2v.wv.index_to_key}
extra = MultipleFeaturizer([ElementFraction(), Stoichiometry(),
ValenceOrbital(), IonProperty(), BandCenter()])
_featurizer_cache["main"] = (ep, m2v, emb, extra)
return _featurizer_cache["main"]
except Exception as e:
_featurizer_err = f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}"
print(f"CRITICAL Featurizer Error: {_featurizer_err}")
return None
def featurize_composition(formula: str):
"""Featurize a chemical formula into the TRIADS feature vector."""
from pymatgen.core import Composition
result = _get_featurizer()
if result is None:
return None, f"Featurizer initialization failed.\nError: {_featurizer_err}"
ep, m2v, emb, extra = result
try:
comp = Composition(formula)
except Exception as e:
return None, f"Invalid formula: '{formula}' | {str(e)}"
try:
mg = np.array(ep.featurize(comp), np.float32)
except Exception as e:
mg = np.zeros(len(ep.feature_labels()), np.float32)
try:
ex = np.array(extra.featurize(comp), np.float32)
ex = np.nan_to_num(ex, nan=0.0)
except Exception as e:
ex = np.zeros(50, np.float32)
# Mat2Vec pooled
v, t = np.zeros(200, np.float32), 0.0
for s, f in comp.get_el_amt_dict().items():
if s in emb:
v += f * emb[s]
t += f
m2v_vec = v / max(t, 1e-8)
mg = np.nan_to_num(mg, nan=0.0)
feat = np.concatenate([mg, ex, m2v_vec])
return feat.astype(np.float32), None
# ─────────────────────────────────────────────────────────────────
# WEIGHT LOADING (lazy, cached)
# ─────────────────────────────────────────────────────────────────
_fold_models = {} # benchmark -> list[nn.Module]
_MODEL_CONFIGS = {
"steels": dict(d_attn=64, d_hidden=96, ff_dim=150, dropout=0.20, max_steps=20),
"expt_gap": dict(d_attn=64, d_hidden=96, ff_dim=150, dropout=0.20, max_steps=20),
"ismetal": dict(d_attn=24, d_hidden=48, ff_dim=72, dropout=0.20, max_steps=16),
"glass": dict(d_attn=24, d_hidden=48, ff_dim=72, dropout=0.20, max_steps=16),
"jdft2d": dict(d_attn=32, d_hidden=64, ff_dim=96, dropout=0.20, max_steps=16),
}
_HF_PATHS = {
"steels": "weights/steels/weights.pt",
"expt_gap": "weights/expt_gap/weights.pt",
"ismetal": "weights/is_metal/weights.pt",
"glass": "weights/glass/weights.pt",
"jdft2d": "weights/jdft2d/weights.pt",
"phonons": "weights/phonons/weights.pt",
}
def _load_benchmark_models(benchmark: str):
if benchmark in _fold_models:
return _fold_models[benchmark]
if benchmark == "phonons":
return None
try:
# 1. Try local path first (relative to app.py)
local_path = _HF_PATHS[benchmark]
if os.path.exists(local_path):
path = local_path
else:
# 2. Try hfHub if local missing
print(f"Local weight {local_path} missing. Attempting hf_hub_download...")
path = hf_hub_download(repo_id=REPO_ID, filename=_HF_PATHS[benchmark])
ckpt = torch.load(path, map_location="cpu", weights_only=False)
fold_entries = ckpt.get("folds", [ckpt])
n_extra = ckpt.get("n_extra", 0)
cfg = {**_MODEL_CONFIGS[benchmark], "n_extra": n_extra}
models = []
for entry in fold_entries:
m = DeepHybridTRM(**cfg)
sd = entry.get("model_state", entry) if isinstance(entry, dict) else entry
m.load_state_dict(sd)
m.eval()
models.append(m)
_fold_models[benchmark] = models
return models
except Exception as e:
err_msg = f"Error loading {benchmark} weights: {e}\n{traceback.format_exc()}"
print(err_msg)
return None
def _ensemble_predict(benchmark: str, x: np.ndarray, is_classification: bool = False):
models = _load_benchmark_models(benchmark)
if not models:
return None, "Weights could not be loaded. See logs."
xt = torch.tensor(x[None], dtype=torch.float32)
preds = []
for m in models:
with torch.no_grad():
out = m(xt).item()
if is_classification:
out = torch.sigmoid(torch.tensor(out)).item()
preds.append(out)
return float(np.mean(preds)), None
# ─────────────────────────────────────────────────────────────────
# PREDICTION FUNCTIONS
# ─────────────────────────────────────────────────────────────────
def predict_steels(formula: str):
feat, err = featurize_composition(formula)
if err: return f"❌ Error: {err}", ""
pred, err = _ensemble_predict("steels", feat)
if err: return f"❌ {err}", ""
return f"### {pred:.1f} MPa", f"**{pred:.1f} MPa** yield strength"
def predict_expt_gap(formula: str):
feat, err = featurize_composition(formula)
if err: return f"❌ Error: {err}", ""
pred, err = _ensemble_predict("expt_gap", feat)
if err: return f"❌ {err}", ""
return f"### {pred:.3f} eV", f"**{pred:.3f} eV** band gap"
def predict_ismetal(formula: str):
feat, err = featurize_composition(formula)
if err: return f"❌ Error: {err}", ""
pred, err = _ensemble_predict("ismetal", feat, True)
if err: return f"❌ {err}", ""
label = "🔩 METALLIC" if pred > 0.5 else "💎 NON-METALLIC"
return f"### {pred:.3f} (Metal)", f"{label} (p={pred:.3f})"
def predict_glass(formula: str):
feat, err = featurize_composition(formula)
if err: return f"❌ Error: {err}", ""
pred, err = _ensemble_predict("glass", feat, True)
if err: return f"❌ {err}", ""
label = "🪟 GLASS-FORMER" if pred > 0.5 else "❌ CRYSTALLINE"
return f"### {pred:.3f} (Glass)", f"{label} (p={pred:.3f})"
def predict_jdft2d(formula: str):
feat, err = featurize_composition(formula)
if err: return f"❌ Error: {err}", ""
pred, err = _ensemble_predict("jdft2d", feat)
if err: return f"❌ {err}", ""
return f"### {pred:.1f} meV/atom", f"**{pred:.1f} meV/atom** exfoliation"
PHONONS_INFO = """
## 🎵 Phonon Peak Frequency
The **TRIADS V6 Graph-TRM** achieves **41.91 ± 4.04 cm⁻¹ MAE** on Matbench phonons, using a gate-based halting Graph Neural Network that adaptively runs 4–16 message-passing cycles.
### Architecture
- **Gate-based halting**: 4–16 adaptive GNN cycles (halts when gate activations drop below threshold)
- **Graph Attention TRM**: line-graph bond updates + joint self-attention + cross-attention
- **Input**: Full crystal structure — atom positions, bond distances, angles (requires CIF/POSCAR)
### Why no live demo?
The phonons model requires a **pre-computed crystal graph** (atom positions, bond lengths, bond angles).
Composition-only featurization is insufficient for phonon prediction — structural details like bond stiffness
and crystal symmetry are essential.
### Benchmark Results
| Model | MAE (cm⁻¹) |
|---|---|
| **TRIADS V6 (ours)** | **41.91 ± 4.04** |
| MEGNet | 28.76 |
| ALIGNN | 29.34 |
| MODNet | 45.39 |
| CrabNet | 47.09 |
| TRIADS V4 | 56.33 |
> **Note**: MEGNet and ALIGNN use full DFT structural relaxation data.
> TRIADS V6 achieves competitive performance with a simpler, more parameter-efficient Graph-TRM architecture (< 50K parameters).
"""
# ─────────────────────────────────────────────────────────────────
# INTERFACE
# ─────────────────────────────────────────────────────────────────
CSS = """
#result_text { font-size: 1.5rem; font-weight: 700; color: #6366f1; }
.benchmark-badge { background: #1e293b; color: #94a3b8; border-radius: 8px; padding: 8px; }
"""
def build():
with gr.Blocks(css=CSS, title="TRIADS") as demo:
gr.Markdown("# ⚡ TRIADS — Materials Property Prediction")
gr.Markdown("Recursive Information-Attention with Deep Supervision for all Matbench benchmarks.")
with gr.Tabs():
with gr.Tab("🔩 Steel Yield"):
f_s = gr.Textbox(label="Formula", value="Fe0.7Cr0.15Ni0.15")
btn_s = gr.Button("Predict", variant="primary")
out_s = gr.Markdown(elem_id="result_text")
ctx_s = gr.Markdown()
btn_s.click(predict_steels, f_s, [out_s, ctx_s])
with gr.Tab("⚡ Band Gap"):
f_g = gr.Textbox(label="Formula", value="TiO2")
btn_g = gr.Button("Predict", variant="primary")
out_g = gr.Markdown(elem_id="result_text")
ctx_g = gr.Markdown()
btn_g.click(predict_expt_gap, f_g, [out_g, ctx_g])
with gr.Tab("🔮 Metallicity"):
f_m = gr.Textbox(label="Formula", value="Cu")
btn_m = gr.Button("Predict", variant="primary")
out_m = gr.Markdown(elem_id="result_text")
ctx_m = gr.Markdown()
btn_m.click(predict_ismetal, f_m, [out_m, ctx_m])
with gr.Tab("🪟 Glass Forming"):
f_gf = gr.Textbox(label="Formula", value="Cu46Zr54")
btn_gf = gr.Button("Predict", variant="primary")
out_gf = gr.Markdown(elem_id="result_text")
ctx_gf = gr.Markdown()
btn_gf.click(predict_glass, f_gf, [out_gf, ctx_gf])
with gr.Tab("📐 JDFT2D"):
f_j = gr.Textbox(label="Formula", value="MoS2")
btn_j = gr.Button("Predict", variant="primary")
out_j = gr.Markdown(elem_id="result_text")
ctx_j = gr.Markdown()
btn_j.click(predict_jdft2d, f_j, [out_j, ctx_j])
with gr.Tab("🎵 Phonons"):
gr.Markdown(PHONONS_INFO)
return demo
if __name__ == "__main__":
build().launch()