"""
TRIADS — Multi-Benchmark Materials Property Prediction
HuggingFace Gradio App (Production Redux)

Covers 6 Matbench benchmarks:
1. matbench_steels — Yield Strength (MPa)
2. matbench_expt_gap — Band Gap (eV)
3. matbench_ismetal — Metallicity (classification, scored by ROC-AUC)
4. matbench_glass — Glass Forming Ability (classification)
5. matbench_jdft2d — Exfoliation Energy (meV/atom)
6. matbench_phonons — Peak Phonon Frequency (cm⁻¹); information tab only, no live prediction
"""
|
|
| import os |
| import warnings |
| import urllib.request |
| import json |
| import traceback |
|
|
| warnings.filterwarnings("ignore") |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import gradio as gr |
| from huggingface_hub import hf_hub_download |
|
|
| |
| |
| |
|
|
# Repository id passed to hf_hub_download when a weight file is not found locally.
REPO_ID = "Rtx09/TRIADS"


# Per-benchmark metadata, keyed by benchmark id.
#   title       - heading text for the benchmark
#   description - one-sentence task summary
#   unit        - unit (or meaning) of the predicted number
#   example     - single default input
#   examples    - list of sample inputs
#   task        - "regression" or "classification"
#   result      - reported benchmark score, as display text
# For the five composition-driven benchmarks every entry in "examples" must be
# a plain chemical formula, because those inputs are parsed as formulas by
# featurize_composition. The phonons entries are descriptive only (that
# benchmark has no predictor in this file).
BENCHMARK_INFO = {
    "steels": {
        "title": "🔩 Steel Yield Strength",
        "description": "Predict yield strength (MPa) of steel alloys from composition.",
        "unit": "MPa",
        "example": "Fe0.7Cr0.15Ni0.15",
        "examples": ["Fe0.7Cr0.15Ni0.15", "Fe0.8C0.02Mn0.1Si0.05Cr0.03", "Fe0.6Ni0.25Mo0.1Cr0.05"],
        "task": "regression",
        "result": "91.20 ± 12.23 MPa MAE (5-fold, 5-seed ensemble)",
    },
    "expt_gap": {
        "title": "⚡ Experimental Band Gap",
        "description": "Predict experimental electronic band gap (eV) from composition.",
        "unit": "eV",
        "example": "TiO2",
        "examples": ["TiO2", "GaN", "ZnO", "Si", "CdS"],
        "task": "regression",
        "result": "0.3068 ± 0.0082 eV MAE (5-fold, composition-only)",
    },
    "ismetal": {
        "title": "🔮 Metallicity Classifier",
        "description": "Predict whether a material is metallic or non-metallic from composition.",
        "unit": "probability (1 = metal)",
        "example": "Cu",
        "examples": ["Cu", "SiO2", "Fe3O4", "BaTiO3", "Al"],
        "task": "classification",
        "result": "0.9655 ± 0.0029 ROC-AUC (5-fold, composition-only)",
    },
    "glass": {
        "title": "🪟 Glass Forming Ability",
        "description": "Predict metallic glass forming ability from alloy composition.",
        "unit": "probability (1 = glass former)",
        "example": "Cu46Zr54",
        "examples": ["Cu46Zr54", "Fe80B20", "Al86Ni7La6Y1", "Pd40Cu30Ni10P20"],
        "task": "classification",
        "result": "0.9285 ± 0.0063 ROC-AUC (5-fold, 5-seed ensemble)",
    },
    "jdft2d": {
        "title": "📐 Exfoliation Energy",
        "description": "Predict exfoliation energy (meV/atom) of 2D materials from structure+composition.",
        "unit": "meV/atom",
        "example": "MoS2",
        # Graphene is listed as its formula "C": a label such as
        # "graphene (C)" is not a valid chemical formula and would be
        # rejected by the formula parser.
        "examples": ["MoS2", "WSe2", "BN", "C", "MoTe2"],
        "task": "regression",
        "result": "35.89 ± 12.40 meV/atom MAE (5-fold, 5-seed ensemble)",
    },
    "phonons": {
        "title": "🎵 Phonon Peak Frequency",
        "description": "Predict peak phonon frequency (cm⁻¹) from crystal structure.",
        "unit": "cm⁻¹",
        "example": "Si (diamond cubic)",
        "examples": ["Si", "GaAs", "MgO", "BN (wurtzite)", "TiO2 (rutile)"],
        "task": "regression",
        "result": "41.91 ± 4.04 cm⁻¹ MAE (5-fold, gate-halt GraphTRIADS)",
    },
}
|
|
|
|
| |
| |
| |
|
|
class DeepHybridTRM(nn.Module):
    """
    HybridTRIADS — composition-only tasks.
    Shared across: steels, expt_gap, ismetal, glass, jdft2d.

    The flat input row is sliced as
    ``[n_props * stat_dim token stats | n_extra extras | remaining columns]``;
    the remaining columns feed ``m2v_proj`` and must therefore be
    ``mat2vec_dim`` wide. Token stats are attended (two self-attention blocks,
    then cross-attention against the projected embedding), mean-pooled and
    refined by ``max_steps`` residual updates of two hidden states.
    """

    def __init__(self, n_props=22, stat_dim=6, n_extra=0, mat2vec_dim=200,
                 d_attn=64, nhead=4, d_hidden=96, ff_dim=150,
                 dropout=0.2, max_steps=20, **kw):
        # Extra keyword arguments (**kw) are accepted and ignored.
        super().__init__()
        self.max_steps = max_steps
        self.D = d_hidden
        self.n_props = n_props
        self.stat_dim = stat_dim
        self.n_extra = n_extra

        # NOTE: sub-modules are created in a fixed order; attribute names and
        # Sequential positions define the state_dict keys.
        self.tok_proj = self._projection(stat_dim, d_attn)
        self.m2v_proj = self._projection(mat2vec_dim, d_attn)

        # Self-attention block 1.
        self.sa1 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
        self.sa1_n = nn.LayerNorm(d_attn)
        self.sa1_ff = self._mlp(d_attn, d_attn * 2, d_attn, dropout)
        self.sa1_fn = nn.LayerNorm(d_attn)

        # Self-attention block 2.
        self.sa2 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
        self.sa2_n = nn.LayerNorm(d_attn)
        self.sa2_ff = self._mlp(d_attn, d_attn * 2, d_attn, dropout)
        self.sa2_fn = nn.LayerNorm(d_attn)

        # Cross-attention: tokens query the single embedding context token.
        self.ca = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
        self.ca_n = nn.LayerNorm(d_attn)

        # Pooled tokens (plus extras, when present) -> hidden width.
        self.pool = self._projection(d_attn + max(n_extra, 0), d_hidden)

        # Recurrent updates: z sees (input, y, z); y sees (y, z).
        self.z_up = self._mlp(d_hidden * 3, ff_dim, d_hidden, dropout, final_norm=True)
        self.y_up = self._mlp(d_hidden * 2, ff_dim, d_hidden, dropout, final_norm=True)
        self.head = nn.Linear(d_hidden, 1)
        self._init()

    @staticmethod
    def _projection(in_dim, out_dim):
        """Linear -> LayerNorm -> GELU."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim), nn.LayerNorm(out_dim), nn.GELU())

    @staticmethod
    def _mlp(in_dim, mid_dim, out_dim, dropout, final_norm=False):
        """Linear -> GELU -> Dropout -> Linear, optionally followed by LayerNorm."""
        layers = [
            nn.Linear(in_dim, mid_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(mid_dim, out_dim),
        ]
        if final_norm:
            layers.append(nn.LayerNorm(out_dim))
        return nn.Sequential(*layers)

    def _init(self):
        """Xavier-uniform weights and zero biases for every nn.Linear sub-module."""
        for layer in self.modules():
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def _attention(self, x):
        """Encode a batch of flat feature rows into (batch, d_hidden) vectors."""
        batch = x.size(0)
        stats_end = self.n_props * self.stat_dim
        has_extra = self.n_extra > 0
        extra_end = stats_end + self.n_extra if has_extra else stats_end

        extra = x[:, stats_end:extra_end] if has_extra else None
        m2v = x[:, extra_end:]

        tokens = self.tok_proj(
            x[:, :stats_end].view(batch, self.n_props, self.stat_dim))
        context = self.m2v_proj(m2v).unsqueeze(1)

        # Post-norm residual blocks.
        attended, _ = self.sa1(tokens, tokens, tokens)
        tokens = self.sa1_n(tokens + attended)
        tokens = self.sa1_fn(tokens + self.sa1_ff(tokens))

        attended, _ = self.sa2(tokens, tokens, tokens)
        tokens = self.sa2_n(tokens + attended)
        tokens = self.sa2_fn(tokens + self.sa2_ff(tokens))

        attended, _ = self.ca(tokens, context, context)
        tokens = self.ca_n(tokens + attended)

        pooled = tokens.mean(dim=1)
        if has_extra:
            pooled = torch.cat([pooled, extra], dim=-1)
        return self.pool(pooled)

    def forward(self, x, deep_supervision=False):
        """
        Run ``max_steps`` refinement steps.

        Returns the last step's prediction (shape ``(batch,)``), or the list
        of all per-step predictions when ``deep_supervision`` is true.
        """
        batch = x.size(0)
        encoded = self._attention(x)
        latent = torch.zeros(batch, self.D, device=x.device)
        answer = torch.zeros(batch, self.D, device=x.device)

        outputs = []
        for _ in range(self.max_steps):
            latent = latent + self.z_up(torch.cat([encoded, answer, latent], -1))
            answer = answer + self.y_up(torch.cat([answer, latent], -1))
            outputs.append(self.head(answer).squeeze(1))

        if deep_supervision:
            return outputs
        return outputs[-1]
|
|
|
|
| |
| |
| |
|
|
# Process-wide featurizer state.
#   _featurizer_cache["main"] -> (ElementProperty, Word2Vec model,
#                                 {token: vector} dict, MultipleFeaturizer)
#   _mat2vec_cache            -> not referenced by the code in this section
#   _featurizer_err           -> text of the most recent initialization failure
_featurizer_cache = {}
_mat2vec_cache = {}
_featurizer_err = None


def _writable_mat2vec_dir():
    """Return a directory for the Mat2Vec files that is verified to be writable.

    Tries ``<cwd>/mat2vec_cache`` first and falls back to
    ``/tmp/mat2vec_cache`` when the first location cannot be created or
    written to.
    """
    cache_dir = os.path.join(os.getcwd(), "mat2vec_cache")
    try:
        os.makedirs(cache_dir, exist_ok=True)
        # Probe with a real write: the directory may exist but be read-only.
        probe = os.path.join(cache_dir, ".test")
        with open(probe, "w") as fh:
            fh.write("1")
        os.remove(probe)
    except OSError:
        cache_dir = "/tmp/mat2vec_cache"
        os.makedirs(cache_dir, exist_ok=True)
    return cache_dir


def _download_mat2vec_files(cache_dir, base_url, filenames):
    """Download every file in *filenames* that is not yet in *cache_dir*.

    Each file is fetched to ``<name>.part`` and renamed into place only after
    the download has finished, so an interrupted transfer can never leave a
    truncated file that later runs would mistake for a valid cached copy.
    """
    for name in filenames:
        target = os.path.join(cache_dir, name)
        if os.path.exists(target):
            continue
        print(f"Downloading {name}...")
        partial = target + ".part"
        try:
            urllib.request.urlretrieve(base_url + name, partial)
            os.replace(partial, target)
        finally:
            # Only present if the download or rename failed.
            if os.path.exists(partial):
                os.remove(partial)


def _get_featurizer():
    """Lazy-load the ExpandedFeaturizer (downloads Mat2Vec once).

    Returns:
        The cached tuple ``(ep, m2v, emb, extra)`` — Magpie ElementProperty
        featurizer, loaded Word2Vec model, token -> vector dict, and the
        MultipleFeaturizer for the extra descriptors — or ``None`` when
        initialization fails. On failure the error text is stored in the
        module-level ``_featurizer_err`` and printed; nothing is cached, so
        the next call tries again.
    """
    global _featurizer_err
    if "main" in _featurizer_cache:
        return _featurizer_cache["main"]

    try:
        # Heavy third-party imports are deferred until first use.
        from matminer.featurizers.composition import (
            ElementProperty, ElementFraction, Stoichiometry,
            ValenceOrbital, IonProperty, BandCenter
        )
        from matminer.featurizers.base import MultipleFeaturizer
        from gensim.models import Word2Vec

        GCS = "https://storage.googleapis.com/mat2vec/"
        M2V_FILES = [
            "pretrained_embeddings",
            "pretrained_embeddings.wv.vectors.npy",
            "pretrained_embeddings.trainables.syn1neg.npy",
        ]

        cache_dir = _writable_mat2vec_dir()
        _download_mat2vec_files(cache_dir, GCS, M2V_FILES)

        try:
            ep = ElementProperty.from_preset("magpie")
        except Exception as e:
            # One retry after a short pause; a second failure propagates to
            # the outer handler below.
            print(f"Magpie download failed, retrying once: {e}")
            import time
            time.sleep(2)
            ep = ElementProperty.from_preset("magpie")

        m2v = Word2Vec.load(os.path.join(cache_dir, "pretrained_embeddings"))
        emb = {w: m2v.wv[w] for w in m2v.wv.index_to_key}
        extra = MultipleFeaturizer([ElementFraction(), Stoichiometry(),
                                    ValenceOrbital(), IonProperty(), BandCenter()])

        _featurizer_cache["main"] = (ep, m2v, emb, extra)
        # Clear any message left over from an earlier failed attempt.
        _featurizer_err = None
        return _featurizer_cache["main"]

    except Exception as e:
        _featurizer_err = f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}"
        print(f"CRITICAL Featurizer Error: {_featurizer_err}")
        return None
|
|
|
|
def featurize_composition(formula: str):
    """Featurize a chemical formula into the TRIADS feature vector.

    The vector is the concatenation ``[magpie | extra | mat2vec]``:
    Magpie element-property statistics, the extra matminer descriptors, and
    the amount-weighted mean of the per-element Mat2Vec embeddings.

    Args:
        formula: Chemical formula string, e.g. ``"TiO2"``.

    Returns:
        ``(features, None)`` on success, where ``features`` is a 1-D float32
        array, or ``(None, message)`` when the input is empty, cannot be
        parsed, or the featurizer could not be initialized.

    A failure inside one of the matminer featurizers is logged and that
    segment is filled with zeros (best effort); NaNs are replaced by 0.
    """
    # Reject empty input up front: an empty composition would otherwise be
    # turned into an all-zero vector and produce a meaningless prediction.
    if not isinstance(formula, str) or not formula.strip():
        return None, f"Invalid formula: '{formula}' | empty input"

    from pymatgen.core import Composition

    result = _get_featurizer()
    if result is None:
        return None, f"Featurizer initialization failed.\nError: {_featurizer_err}"

    ep, _, emb, extra = result

    try:
        comp = Composition(formula)
    except Exception as e:
        return None, f"Invalid formula: '{formula}' | {str(e)}"

    try:
        mg = np.array(ep.featurize(comp), np.float32)
    except Exception as e:
        print(f"Magpie featurization failed for '{formula}', using zeros: {e}")
        mg = np.zeros(len(ep.feature_labels()), np.float32)
    mg = np.nan_to_num(mg, nan=0.0)

    try:
        ex = np.array(extra.featurize(comp), np.float32)
        ex = np.nan_to_num(ex, nan=0.0)
    except Exception as e:
        print(f"Extra featurization failed for '{formula}', using zeros: {e}")
        # Size the fallback from the featurizer itself so the vector layout
        # is identical on the success and failure paths.
        ex = np.zeros(len(extra.feature_labels()), np.float32)

    # Amount-weighted mean of the embeddings of the elements that have one.
    weighted_sum, total_amount = np.zeros(200, np.float32), 0.0
    for symbol, amount in comp.get_el_amt_dict().items():
        if symbol in emb:
            weighted_sum += amount * emb[symbol]
            total_amount += amount
    m2v_vec = weighted_sum / max(total_amount, 1e-8)  # guard: no known element

    feat = np.concatenate([mg, ex, m2v_vec])
    return feat.astype(np.float32), None
|
|
|
|
| |
| |
| |
|
|
# Loaded model ensembles, keyed by benchmark id; filled by the weight loader.
_fold_models = {}


# DeepHybridTRM constructor arguments for each benchmark that has a predictor.
_MODEL_CONFIGS = {
    "steels": {"d_attn": 64, "d_hidden": 96, "ff_dim": 150, "dropout": 0.20, "max_steps": 20},
    "expt_gap": {"d_attn": 64, "d_hidden": 96, "ff_dim": 150, "dropout": 0.20, "max_steps": 20},
    "ismetal": {"d_attn": 24, "d_hidden": 48, "ff_dim": 72, "dropout": 0.20, "max_steps": 16},
    "glass": {"d_attn": 24, "d_hidden": 48, "ff_dim": 72, "dropout": 0.20, "max_steps": 16},
    "jdft2d": {"d_attn": 32, "d_hidden": 64, "ff_dim": 96, "dropout": 0.20, "max_steps": 16},
}


# Checkpoint path per benchmark, used both as a local relative path and as the
# file name requested from the hub. Note the folder for "ismetal" is spelled
# "is_metal".
_HF_PATHS = {
    benchmark: f"weights/{folder}/weights.pt"
    for benchmark, folder in (
        ("steels", "steels"),
        ("expt_gap", "expt_gap"),
        ("ismetal", "is_metal"),
        ("glass", "glass"),
        ("jdft2d", "jdft2d"),
        ("phonons", "phonons"),
    )
}
|
|
|
|
def _load_benchmark_models(benchmark: str):
    """Return the list of eval-mode models for *benchmark*, loading on first use.

    A checkpoint is read from the local path in ``_HF_PATHS`` when that file
    exists, otherwise it is fetched with ``hf_hub_download``. The checkpoint
    may hold a ``"folds"`` list (one entry per model) or be a single entry;
    an entry is either a dict with a ``"model_state"`` key or the state dict
    itself. Successful loads are cached in ``_fold_models``.

    Returns ``None`` for ``"phonons"`` and whenever loading fails (the error
    and traceback are printed).
    """
    try:
        return _fold_models[benchmark]
    except KeyError:
        pass

    if benchmark == "phonons":
        return None

    def _restore(entry, model_kwargs):
        # Build one model and load its weights from a checkpoint entry.
        net = DeepHybridTRM(**model_kwargs)
        state = entry.get("model_state", entry) if isinstance(entry, dict) else entry
        net.load_state_dict(state)
        net.eval()
        return net

    try:
        weight_file = _HF_PATHS[benchmark]
        if not os.path.exists(weight_file):
            print(f"Local weight {weight_file} missing. Attempting hf_hub_download...")
            weight_file = hf_hub_download(repo_id=REPO_ID, filename=_HF_PATHS[benchmark])

        # NOTE(review): weights_only=False unpickles arbitrary objects, so the
        # checkpoint source has to be trusted.
        ckpt = torch.load(weight_file, map_location="cpu", weights_only=False)

        fold_entries = ckpt.get("folds", [ckpt])
        n_extra = ckpt.get("n_extra", 0)
        cfg = dict(_MODEL_CONFIGS[benchmark])
        cfg["n_extra"] = n_extra

        models = [_restore(entry, cfg) for entry in fold_entries]
        _fold_models[benchmark] = models
        return models
    except Exception as e:
        err_msg = f"Error loading {benchmark} weights: {e}\n{traceback.format_exc()}"
        print(err_msg)
        return None
|
|
|
|
def _ensemble_predict(benchmark: str, x: np.ndarray, is_classification: bool = False):
    """Average the predictions of every loaded model for *benchmark*.

    Args:
        benchmark: Benchmark id understood by ``_load_benchmark_models``.
        x: 1-D feature vector for a single sample.
        is_classification: When true, a sigmoid is applied to each model
            output before averaging, so the mean is taken over probabilities.

    Returns:
        ``(mean_prediction, None)`` on success, or ``(None, message)`` when
        no models are available or the forward pass fails (for example when
        the feature vector width does not match what the checkpoint expects).
    """
    models = _load_benchmark_models(benchmark)
    if not models:
        return None, "Weights could not be loaded. See logs."

    xt = torch.tensor(x[None], dtype=torch.float32)  # add the batch axis
    preds = []
    try:
        with torch.no_grad():
            for m in models:
                out = m(xt)
                if is_classification:
                    out = torch.sigmoid(out)
                preds.append(out.item())
    except RuntimeError as e:
        # Keep the (value, error) contract instead of letting a shape or
        # dtype error escape into the UI callback.
        print(f"Inference failed for {benchmark}: {e}\n{traceback.format_exc()}")
        return None, f"Inference failed: {e}"
    return float(np.mean(preds)), None
|
|
|
|
| |
| |
| |
|
|
def _predict_value(benchmark: str, formula: str, is_classification: bool = False):
    """Featurize *formula* and run the ensemble for *benchmark*.

    Returns ``(prediction, None)`` on success, or ``(None, outputs)`` where
    ``outputs`` is the ready-to-return ``(headline, detail)`` error pair.
    """
    feat, err = featurize_composition(formula)
    if err:
        return None, (f"❌ Error: {err}", "")
    pred, err = _ensemble_predict(benchmark, feat, is_classification)
    if err:
        return None, (f"❌ {err}", "")
    return pred, None


def predict_steels(formula: str):
    """Predict steel yield strength; returns (headline, detail) markdown strings."""
    pred, failure = _predict_value("steels", formula)
    if failure:
        return failure
    return f"### {pred:.1f} MPa", f"**{pred:.1f} MPa** yield strength"


def predict_expt_gap(formula: str):
    """Predict the experimental band gap; returns (headline, detail) markdown strings."""
    pred, failure = _predict_value("expt_gap", formula)
    if failure:
        return failure
    return f"### {pred:.3f} eV", f"**{pred:.3f} eV** band gap"


def predict_ismetal(formula: str):
    """Predict the metal probability; labelled metallic when it exceeds 0.5."""
    pred, failure = _predict_value("ismetal", formula, True)
    if failure:
        return failure
    label = "🔩 METALLIC" if pred > 0.5 else "💎 NON-METALLIC"
    return f"### {pred:.3f} (Metal)", f"{label} (p={pred:.3f})"


def predict_glass(formula: str):
    """Predict the glass-former probability; labelled glass-former when it exceeds 0.5."""
    pred, failure = _predict_value("glass", formula, True)
    if failure:
        return failure
    label = "🪟 GLASS-FORMER" if pred > 0.5 else "❌ CRYSTALLINE"
    return f"### {pred:.3f} (Glass)", f"{label} (p={pred:.3f})"


def predict_jdft2d(formula: str):
    """Predict the exfoliation energy; returns (headline, detail) markdown strings."""
    pred, failure = _predict_value("jdft2d", formula)
    if failure:
        return failure
    return f"### {pred:.1f} meV/atom", f"**{pred:.1f} meV/atom** exfoliation"
|
|
# Markdown shown in the Phonons tab in place of a predictor (this file has no
# phonons model; the weight loader returns None for that benchmark).
PHONONS_INFO = """
## 🎵 Phonon Peak Frequency

The **TRIADS V6 Graph-TRM** achieves **41.91 ± 4.04 cm⁻¹ MAE** on Matbench phonons, using a gate-based halting Graph Neural Network that adaptively runs 4–16 message-passing cycles.

### Architecture
- **Gate-based halting**: 4–16 adaptive GNN cycles (halts when gate activations drop below threshold)
- **Graph Attention TRM**: line-graph bond updates + joint self-attention + cross-attention
- **Input**: Full crystal structure — atom positions, bond distances, angles (requires CIF/POSCAR)

### Why no live demo?
The phonons model requires a **pre-computed crystal graph** (atom positions, bond lengths, bond angles).
Composition-only featurization is insufficient for phonon prediction — structural details like bond stiffness
and crystal symmetry are essential.

### Benchmark Results
| Model | MAE (cm⁻¹) |
|---|---|
| **TRIADS V6 (ours)** | **41.91 ± 4.04** |
| MEGNet | 28.76 |
| ALIGNN | 29.34 |
| MODNet | 45.39 |
| CrabNet | 47.09 |
| TRIADS V4 | 56.33 |

> **Note**: MEGNet and ALIGNN use full DFT structural relaxation data.
> TRIADS V6 achieves competitive performance with a simpler, more parameter-efficient Graph-TRM architecture (< 50K parameters).
"""
|
|
|
|
| |
| |
| |
|
|
# Stylesheet handed to gr.Blocks(css=CSS).
#   #result_text     - matches the elem_id given to the headline result components
#   .benchmark-badge - not referenced by any component visible in this file
CSS = """
#result_text { font-size: 1.5rem; font-weight: 700; color: #6366f1; }
.benchmark-badge { background: #1e293b; color: #94a3b8; border-radius: 8px; padding: 8px; }
"""
|
|
def build():
    """Assemble and return the Gradio Blocks app.

    Five predictor tabs share one layout — a formula textbox, a "Predict"
    button, a headline markdown and a detail markdown, with the button wired
    to the tab's predict function — followed by a static Phonons tab.
    """
    # (tab label, default formula, click handler), in display order.
    predictor_tabs = (
        ("🔩 Steel Yield", "Fe0.7Cr0.15Ni0.15", predict_steels),
        ("⚡ Band Gap", "TiO2", predict_expt_gap),
        ("🔮 Metallicity", "Cu", predict_ismetal),
        ("🪟 Glass Forming", "Cu46Zr54", predict_glass),
        ("📐 JDFT2D", "MoS2", predict_jdft2d),
    )

    with gr.Blocks(css=CSS, title="TRIADS") as demo:
        gr.Markdown("# ⚡ TRIADS — Materials Property Prediction")
        gr.Markdown("Recursive Information-Attention with Deep Supervision for all Matbench benchmarks.")

        with gr.Tabs():
            for tab_label, default_formula, predictor in predictor_tabs:
                with gr.Tab(tab_label):
                    formula_box = gr.Textbox(label="Formula", value=default_formula)
                    run_button = gr.Button("Predict", variant="primary")
                    # NOTE(review): the same elem_id is reused for every
                    # headline component, although HTML ids are meant to be
                    # unique per page.
                    headline = gr.Markdown(elem_id="result_text")
                    details = gr.Markdown()
                    run_button.click(predictor, formula_box, [headline, details])

            with gr.Tab("🎵 Phonons"):
                gr.Markdown(PHONONS_INFO)

    return demo


if __name__ == "__main__":
    app = build()
    app.launch()
|
|