# MatText Aligned Embeddings v2: Multi-Modal Material Retrieval with Natural Language Queries

**A CLIP-style multi-modal embedding model that aligns 10+ material text representations into a shared 128-d vector space. Query with natural language ("oxide with high bandgap"), composition, CIF, SLICES, or any modality → retrieve matching materials.**

## v2 Key Features

| Feature | v1 | v2 |
|---------|----|----|
| Context length | 512 tokens | **1024 tokens** (captures long CIFs) |
| Natural language queries | ❌ | ✅ **"oxide with high bandgap"** |
| Property-aware retrieval | Basic | **LaCLIP-style diverse NL descriptions** |
| GPU optimization | fp16 / 24GB | **bf16 / 80GB A100 optimized** |
| Effective batch size | 256 | **288** |
| Modalities per step | 4 | **5** |
| Flash Attention 2 | ❌ | ✅ **(auto-detect)** |

## Architecture

```
┌────────────────────────────────────────────────────────────────────────┐
│                      MatTextEncoder (157M params)                      │
│                                                                        │
│  ┌──────────────────────────────────────────────────────────────────┐  │
│  │  Shared Backbone: ModernBERT-base (150M params, 8192 ctx)        │  │
│  │  Mean pooling → 768-d representation                             │  │
│  │  Gradient checkpointing + bf16                                   │  │
│  └──────────────────────────────────────────────────────────────────┘  │
│      ▼             ▼                  ▼                 ▼              │
│ ┌──────────┐  ┌──────────┐  ┌───────────────────┐  ┌──────────┐        │
│ │ comp     │  │ cif_sym  │  │ nl_property_desc  │  │ property │ ...×12 │
│ │ 768→768  │  │ 768→768  │  │ 768→768→128       │  │ 768→768  │        │
│ │ →128     │  │ →128     │  │ "oxide with high  │  │ →128     │        │
│ │          │  │          │  │  bandgap" queries │  │          │        │
│ └────┬─────┘  └────┬─────┘  └─────────┬─────────┘  └────┬─────┘        │
│      ▼             ▼                  ▼                 ▼              │
│   128-d L2      128-d L2          128-d L2           128-d L2          │
│                                                                        │
│                 ──── Shared 128-d Embedding Space ────                 │
│            (FAISS IndexFlatIP for cosine similarity search)            │
└────────────────────────────────────────────────────────────────────────┘
```
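
The backbone's mean pooling step (token embeddings reduced to a single 768-d vector, with padding positions masked out) can be sketched as follows; this is an illustration of the operation, not the repo's exact code:

```python
import torch

def mean_pool(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token embeddings, ignoring padding positions."""
    mask = attention_mask.unsqueeze(-1).float()   # (B, T, 1)
    summed = (hidden_states * mask).sum(dim=1)    # (B, H)
    counts = mask.sum(dim=1).clamp(min=1e-9)      # (B, 1), avoid divide-by-zero
    return summed / counts

# Toy check: two tokens, the second one padded out
h = torch.tensor([[[1.0, 3.0], [9.0, 9.0]]])      # (1, 2, 2)
m = torch.tensor([[1, 0]])
pooled = mean_pool(h, m)                          # -> tensor([[1., 3.]])
```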

### 12 Projection Heads

| # | Head | Input | Purpose |
|---|------|-------|---------|
| 1 | `composition` | "Fe2O3" | Formula queries |
| 2 | `atom_sequences` | "Fe Fe O O O" | Element list queries |
| 3 | `cif_symmetrized` | Full CIF | Paste CIF data |
| 4 | `cif_p1` | CIF in P1 | P1 space group CIF |
| 5 | `zmatrix` | Z-matrix coords | Internal coordinates |
| 6 | `atom_sequences_plusplus` | Elements + lattice | Atom sequence + cell |
| 7 | `slices` | SLICES encoding | Compact structure encoding |
| 8 | `crystal_text_llm` | Gruver format | Lattice + coords text |
| 9 | `local_env` | SMILES-like env | Local bonding environment |
| 10 | `robocrys_rep` | NL description | "FeO crystallizes in..." |
| 11 | **`nl_property_description`** | **Free-form NL** | **"oxide with high bandgap"** |
| 12 | `property` | Structured props | "bandgap: 2.1 eV" |
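
Each head follows the 768→768→128 pattern shown in the diagram. A minimal sketch of one such head (the hidden activation and exact layer layout are assumptions; the repo's own projection module may differ):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ProjectionHead(nn.Module):
    """768 -> 768 -> 128 MLP whose outputs are L2-normalized."""
    def __init__(self, in_dim: int = 768, embed_dim: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.GELU(),
            nn.Linear(in_dim, embed_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unit-norm outputs: inner product == cosine similarity
        return F.normalize(self.net(x), dim=-1)

head = ProjectionHead()
z = head(torch.randn(4, 768))   # z.shape == (4, 128), each row unit-norm
```

L2-normalizing the output is what lets FAISS `IndexFlatIP` (inner product) act as cosine-similarity search over the shared space.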

## How NL Queries Work

The key innovation is a **LaCLIP-style** training approach ([arxiv:2305.20088](https://arxiv.org/abs/2305.20088)):

1. **During Phase 2 training**, for each material with known properties (bandgap, formation energy), we generate **diverse natural language descriptions** from templates:
   - `"A wide bandgap oxide suitable for UV applications, bandgap 3.20 eV"`
   - `"TiO2: oxide semiconductor with wide band gap of 3.20 electron volts"`
   - `"This binary oxide (TiO2) exhibits a wide bandgap of approximately 3.20 eV"`

2. These NL descriptions are passed through a **dedicated `nl_property_description` projection head** and aligned with ALL structure modalities via InfoNCE.

3. **At inference**, when you query `"oxide with high bandgap"`, the model maps it through the same NL head into the shared embedding space, and FAISS finds the nearest materials: those that were trained to sit close to similar descriptions.

This is distinct from `robocrys_rep`, which describes crystal *structure* ("FeO crystallizes in the rock salt structure..."); the NL query head describes *properties* ("wide bandgap oxide").
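
The template-based generation in step 1 can be sketched like this (the template wording and the 3.0 eV wide/narrow cutoff are illustrative assumptions, not the repo's actual templates):

```python
import random

TEMPLATES = [
    "A {width} bandgap {family} with bandgap {bg:.2f} eV",
    "{formula}: {family} with a {width} band gap of {bg:.2f} electron volts",
    "This {family} ({formula}) exhibits a {width} bandgap of approximately {bg:.2f} eV",
]

def describe(formula: str, family: str, bandgap_ev: float, rng: random.Random) -> str:
    """Sample one diverse NL description for a material with a known bandgap."""
    width = "wide" if bandgap_ev >= 3.0 else "narrow"   # illustrative cutoff
    tpl = rng.choice(TEMPLATES)
    return tpl.format(formula=formula, family=family, width=width, bg=bandgap_ev)

rng = random.Random(0)
print(describe("TiO2", "oxide", 3.20, rng))
```

Sampling a different template per epoch is what gives the NL head the paraphrase diversity that makes free-form queries work at inference time.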

## Training Recipe

### Two-Phase Training

**Phase 1 – Multi-modal alignment** (pretrain100k_v2, 60k samples, 3 epochs):
- AllPairsCLIP loss across 10 modalities
- Random modality sampling (5/10 per step); composition + crystal_text_llm are always included
- Effective batch size 288

**Phase 2 – Property-conditioned + NL query alignment** (bandgap + formation_energy, 60k samples, 3 epochs):
- AllPairsCLIP loss (structure modalities)
- **NL description ↔ structure InfoNCE** (the key NL query loss)
- Property ↔ composition/crystal_text_llm InfoNCE ([MatExpert](https://arxiv.org/abs/2410.21317))
- SupReMix-style property similarity MSE ([arxiv:2309.16633](https://arxiv.org/abs/2309.16633))
- Loss weights: `L = L_clip + 0.3 * L_property + 0.5 * L_nl`
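
The Phase 2 objective can be sketched with symmetric InfoNCE as the building block. How the terms are wired together below is an assumption based on the weights above; the repo's actual `AllPairsCLIP` implementation may differ in detail:

```python
import torch
import torch.nn.functional as F

def info_nce(za, zb, temperature: float = 0.07):
    """Symmetric InfoNCE between two L2-normalized embedding batches."""
    logits = za @ zb.t() / temperature
    labels = torch.arange(za.size(0))
    return 0.5 * (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels))

def all_pairs_clip(embeds: dict):
    """Average InfoNCE over every unordered pair of modalities (MultiMat-style)."""
    mods = list(embeds)
    losses = [info_nce(embeds[a], embeds[b])
              for i, a in enumerate(mods) for b in mods[i + 1:]]
    return torch.stack(losses).mean()

def phase2_loss(struct: dict, prop, nl):
    """L = L_clip + 0.3 * L_property + 0.5 * L_nl."""
    l_clip = all_pairs_clip(struct)
    l_prop = info_nce(prop, struct["composition"])                       # property <-> structure
    l_nl = torch.stack([info_nce(nl, z) for z in struct.values()]).mean()  # NL <-> all structure
    return l_clip + 0.3 * l_prop + 0.5 * l_nl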

### Based On

| Paper | Contribution | ArXiv |
|-------|-------------|-------|
| **MultiMat** | AllPairsCLIP loss | [2312.00111](https://arxiv.org/abs/2312.00111) |
| **MatExpert** | Property ↔ structure InfoNCE | [2410.21317](https://arxiv.org/abs/2410.21317) |
| **LaCLIP** | LLM text augmentation for CLIP | [2305.20088](https://arxiv.org/abs/2305.20088) |
| **SupReMix** | Property-label-aware soft contrastive | [2309.16633](https://arxiv.org/abs/2309.16633) |
| **CrystalCLR** | Composition similarity | [2211.13408](https://arxiv.org/abs/2211.13408) |

### Hyperparameters

```yaml
encoder: answerdotai/ModernBERT-base
embed_dim: 128
max_length: 1024 tokens
batch_size: 48 × 6 grad_accum = 288 effective
learning_rate: 2e-5 (phase 1), 1e-5 (phase 2)
temperature: learnable (init 0.07)
epochs: 3 per phase
optimizer: AdamW (weight_decay=0.01)
precision: bf16 (A100) / fp16 (T4/V100)
gradient_checkpointing: True
max_modalities_per_step: 5
```
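
The `temperature: learnable (init 0.07)` line follows the usual CLIP parameterization: optimize the log of the inverse temperature so it stays positive during training. A sketch (variable names are assumptions):

```python
import math
import torch
import torch.nn as nn

# Store log(1/T); exponentiating keeps the effective scale positive
logit_scale = nn.Parameter(torch.tensor(math.log(1 / 0.07)))

def scaled_logits(za: torch.Tensor, zb: torch.Tensor) -> torch.Tensor:
    # Clamp the scale at 100 (i.e. T >= 0.01), as in the original CLIP recipe
    return logit_scale.clamp(max=math.log(100)).exp() * (za @ zb.t())
```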

## Quick Start

### Training (your GPU)

```bash
pip install torch transformers datasets faiss-cpu huggingface_hub trackio accelerate

# Optional but recommended for A100/H100:
pip install flash-attn --no-build-isolation

python train_mattext_embeddings.py
```

The script auto-detects:
- GPU capability (bf16 for Ampere+, fp16 otherwise)
- Flash Attention 2 availability
- CUDA vs CPU
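
Those checks boil down to a few capability probes; a sketch of the detection logic (the script's actual function names may differ):

```python
import importlib.util
import torch

def detect_settings() -> dict:
    cuda = torch.cuda.is_available()
    # bf16 needs compute capability >= 8.0 (Ampere or newer)
    bf16 = cuda and torch.cuda.get_device_capability()[0] >= 8
    has_flash = importlib.util.find_spec("flash_attn") is not None
    return {
        "device": "cuda" if cuda else "cpu",
        "dtype": torch.bfloat16 if bf16 else (torch.float16 if cuda else torch.float32),
        "attn_implementation": "flash_attention_2" if (cuda and has_flash) else "sdpa",
    }

print(detect_settings())
```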

### Inference & Search

```python
import torch
import faiss
import json
import numpy as np
from transformers import AutoTokenizer
from train_mattext_embeddings import MatTextEncoder, Config, search_vector_db

# Load model + tokenizer
config = Config()
config.device = "cuda" if torch.cuda.is_available() else "cpu"
model = MatTextEncoder(config)
model.load_state_dict(torch.load("mattext-embeddings/model.pt", map_location=config.device))
model = model.to(config.device).eval()
tokenizer = AutoTokenizer.from_pretrained(config.encoder_name)

# Load the per-modality FAISS indices and their metadata
indices = {}
for mod in ["composition", "crystal_text_llm", "slices", "cif_symmetrized", "robocrys_rep"]:
    index = faiss.read_index(f"mattext-embeddings/faiss/{mod}.index")
    with open(f"mattext-embeddings/faiss/{mod}_metadata.json") as f:
        metadata = json.load(f)
    indices[mod] = {"index": index, "metadata": metadata}
```

### Query Examples

```python
# Natural language property queries (THE KEY FEATURE)
search_vector_db("oxide with high bandgap", "nl_property_description", model, tokenizer, indices, config)
search_vector_db("stable ternary nitride", "nl_property_description", model, tokenizer, indices, config)
search_vector_db("narrow bandgap semiconductor for IR", "nl_property_description", model, tokenizer, indices, config)
search_vector_db("metallic binary compound", "nl_property_description", model, tokenizer, indices, config)

# Composition queries
search_vector_db("Fe2O3", "composition", model, tokenizer, indices, config)
search_vector_db("BaTiO3", "composition", model, tokenizer, indices, config)

# Structure description queries
search_vector_db("perovskite with octahedral coordination", "robocrys_rep", model, tokenizer, indices, config)

# Structured property queries
search_vector_db("composition: TiO2 | bandgap: 3.2000", "property", model, tokenizer, indices, config)

# CIF queries (paste your CIF)
search_vector_db("data_TiO2\n_symmetry P1\n_cell 4.59 4.59 2.96 90 90 90", "cif_symmetrized", ...)

# SLICES queries
search_vector_db("Ti O 0 1 o o o", "slices", model, tokenizer, indices, config)
```
## Evaluation Metrics

Cross-modal Recall@k on the test set:

| Pair | R@1 | R@5 | R@10 | R@20 |
|------|-----|-----|------|------|
| composition → crystal_text_llm | TBD | TBD | TBD | TBD |
| composition → cif_symmetrized | TBD | TBD | TBD | TBD |
| composition → slices | TBD | TBD | TBD | TBD |
| slices → crystal_text_llm | TBD | TBD | TBD | TBD |
| robocrys_rep → composition | TBD | TBD | TBD | TBD |
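
Recall@k here means: encode the same test materials under two modalities and check how often material i's modality-B embedding ranks in the top k when querying with its modality-A embedding. A numpy sketch:

```python
import numpy as np

def recall_at_k(query_emb: np.ndarray, gallery_emb: np.ndarray, k: int) -> float:
    """Fraction of queries whose true match (same row index) is in the top-k."""
    sims = query_emb @ gallery_emb.T                 # cosine if rows are unit-norm
    topk = np.argsort(-sims, axis=1)[:, :k]
    hits = (topk == np.arange(len(query_emb))[:, None]).any(axis=1)
    return float(hits.mean())

# Perfectly aligned embeddings give R@1 = 1.0
z = np.eye(4, dtype=np.float32)
print(recall_at_k(z, z, k=1))   # -> 1.0
```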

NL Query Results:

| Query | Top-1 Match | Score |
|-------|------------|-------|
| "oxide with high bandgap" | TBD | TBD |
| "narrow bandgap semiconductor" | TBD | TBD |
| "stable binary oxide" | TBD | TBD |

*Results populated after training.*

## Extending: Graph Embeddings

The architecture is plug-and-play for new modalities:

```python
# Add a GNN modality
import torch.nn as nn
from torch_geometric.nn import SchNet
from train_mattext_embeddings import ModalityProjection

class GraphEncoder(nn.Module):
    def __init__(self, embed_dim=128):
        super().__init__()
        self.gnn = SchNet(hidden_channels=256)
        self.proj = ModalityProjection(256, embed_dim)

    def forward(self, data):
        # SchNet consumes atomic numbers, positions, and the batch assignment vector
        h = self.gnn(data.z, data.pos, data.batch)
        return self.proj(h)

# Register as a new modality; it gets aligned automatically through AllPairsCLIP
graph_encoder = GraphEncoder(embed_dim=128)
model.projections["graph"] = graph_encoder.proj
```

## Dataset

[n0w0f/MatText](https://huggingface.co/datasets/n0w0f/MatText): 100k+ crystal structures in 10+ text representations

## References

- **MatText**: [arxiv:2406.17295](https://arxiv.org/abs/2406.17295)
- **MultiMat**: [arxiv:2312.00111](https://arxiv.org/abs/2312.00111)
- **MatExpert**: [arxiv:2410.21317](https://arxiv.org/abs/2410.21317)
- **LaCLIP**: [arxiv:2305.20088](https://arxiv.org/abs/2305.20088)
- **SupReMix**: [arxiv:2309.16633](https://arxiv.org/abs/2309.16633)
- **CrystalCLR**: [arxiv:2211.13408](https://arxiv.org/abs/2211.13408)
- **Symile**: [arxiv:2411.01053](https://arxiv.org/abs/2411.01053)

## License

MIT