# MatText Aligned Embeddings v2: Multi-Modal Material Retrieval with Natural Language Queries

**A CLIP-style multi-modal embedding model that aligns 10+ material text representations into a shared 128-d vector space. Query with natural language ("oxide with high bandgap"), composition, CIF, SLICES, or any modality → retrieve matching materials.**

## 🆕 v2 Key Features

| Feature | v1 | v2 |
|---------|----|----|
| Context length | 512 tokens | **1024 tokens** (captures long CIFs) |
| Natural language queries | ❌ | **✅ "oxide with high bandgap"** |
| Property-aware retrieval | Basic | **LaCLIP-style diverse NL descriptions** |
| GPU optimization | fp16 / 24GB | **bf16 / 80GB A100 optimized** |
| Effective batch size | 256 | **288** |
| Modalities per step | 4 | **5** |
| Flash Attention 2 | ❌ | **✅ (auto-detect)** |

## 🏗️ Architecture

```
┌───────────────────────────────────────────────────────────────────────┐
│                         MatTextEncoder (157M params)                  │
│                                                                       │
│  ┌──────────────────────────────────────────────────────────────────┐ │
│  │  Shared Backbone: ModernBERT-base (150M params, 8192 ctx)        │ │
│  │  Mean pooling → 768-d representation                             │ │
│  │  Gradient checkpointing + bf16                                   │ │
│  └──────────────────────────────────────────────────────────────────┘ │
│                              │                                        │
│     ┌─────────────┬──────────┴──────────┬──────────────┐              │
│     ▼             ▼                     ▼              ▼              │
│ ┌─────────┐ ┌──────────┐ ┌───────────────────┐ ┌──────────┐           │
│ │comp     │ │cif_sym   │ │nl_property_desc   │ │property  │  ...×12   │
│ │768→768  │ │768→768   │ │768→768→128        │ │768→768   │           │
│ │→128     │ │→128      │ │"oxide with high   │ │→128      │           │
│ │         │ │          │ │ bandgap" queries  │ │          │           │
│ └────┬────┘ └────┬─────┘ └────────┬──────────┘ └────┬─────┘           │
│      ▼           ▼                ▼                  ▼                │
│  128-d L2     128-d L2        128-d L2           128-d L2             │
│                                                                       │
│              ──── Shared 128-d Embedding Space ────                   │
│     (FAISS IndexFlatIP for cosine similarity search)                  │
└───────────────────────────────────────────────────────────────────────┘
```

### 12 Projection Heads

| # | Head | Input | Purpose |
|---|------|-------|---------|
| 1 | `composition` | "Fe2O3" | Formula queries |
| 2 | `atom_sequences` | "Fe Fe O O O" | Element list queries |
| 3 | `cif_symmetrized` | Full CIF | Paste CIF data |
| 4 | `cif_p1` | CIF in P1 | P1 space group CIF |
| 5 | `zmatrix` | Z-matrix coords | Internal coordinates |
| 6 | `atom_sequences_plusplus` | Elements + lattice | Atom sequence + cell |
| 7 | `slices` | SLICES encoding | Compact structure encoding |
| 8 | `crystal_text_llm` | Gruver format | Lattice + coords text |
| 9 | `local_env` | SMILES-like env | Local bonding environment |
| 10 | `robocrys_rep` | NL description | "FeO crystallizes in..." |
| 11 | **`nl_property_description`** | **Free-form NL** | **"oxide with high bandgap"** |
| 12 | `property` | Structured props | "bandgap: 2.1 eV" |
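
Each head in the table follows the pattern shown in the architecture diagram: a small MLP mapping the 768-d pooled backbone output to a 128-d L2-normalized embedding. A minimal sketch — the layer sizes come from the diagram, but the GELU nonlinearity and the exact class shape are assumptions, not the training script's code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModalityProjection(nn.Module):
    """One projection head per modality: 768 -> 768 -> 128, L2-normalized.

    Sketch based on the architecture diagram; the hidden nonlinearity
    is an assumption.
    """
    def __init__(self, in_dim: int = 768, embed_dim: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.GELU(),
            nn.Linear(in_dim, embed_dim),
        )

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # L2-normalize so dot products equal cosine similarity,
        # which is what FAISS IndexFlatIP computes at search time.
        return F.normalize(self.net(h), dim=-1)

heads = nn.ModuleDict({name: ModalityProjection() for name in [
    "composition", "cif_symmetrized", "nl_property_description", "property",
]})
z = heads["composition"](torch.randn(4, 768))  # (4, 128), unit-norm rows
```

L2-normalizing the outputs is what lets a flat inner-product index double as a cosine-similarity index.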

## 🔍 How NL Queries Work

The key innovation is a **LaCLIP-style** training approach ([arxiv:2305.20088](https://arxiv.org/abs/2305.20088)):

1. **During Phase 2 training**, for each material with known properties (bandgap, formation energy), we generate **diverse natural language descriptions** from templates:
   - `"A wide bandgap oxide suitable for UV applications, bandgap 3.20 eV"` 
   - `"TiO2: oxide semiconductor with wide band gap of 3.20 electron volts"`
   - `"This binary oxide (TiO2) exhibits a wide bandgap of approximately 3.20 eV"`
   
2. These NL descriptions are passed through a **dedicated `nl_property_description` projection head** and aligned with ALL structure modalities via InfoNCE.

3. **At inference**, when you query `"oxide with high bandgap"`, the model maps it through the same NL head into the shared embedding space, and FAISS finds the nearest materials: those that were trained to be close to similar descriptions.

This is distinct from `robocrys_rep` (which describes crystal *structure*: "FeO crystallizes in the rock salt structure..."). The NL query head describes *properties* ("wide bandgap oxide").
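
The template diversification in step 1 can be sketched as simple template sampling. Everything below — the template wording, the `describe` helper, and the bandgap thresholds for "wide"/"narrow" — is illustrative, not the training script's actual generator:

```python
import random

# Illustrative templates; the real training script's wording may differ.
TEMPLATES = [
    "A {cls} bandgap {family} suitable for {app}, bandgap {bg:.2f} eV",
    "{formula}: {family} semiconductor with {cls} band gap of {bg:.2f} electron volts",
    "This {family} ({formula}) exhibits a {cls} bandgap of approximately {bg:.2f} eV",
]

def describe(formula: str, bg: float, family: str = "oxide") -> str:
    # Assumed thresholds: > 3 eV counts as "wide", < 1 eV as "narrow".
    cls = "wide" if bg > 3.0 else ("narrow" if bg < 1.0 else "moderate")
    app = ("UV applications" if bg > 3.0
           else "IR detection" if bg < 1.0
           else "visible-light devices")
    # str.format ignores unused keys, so every template can share one kwargs set
    return random.choice(TEMPLATES).format(
        formula=formula, bg=bg, cls=cls, family=family, app=app)

print(describe("TiO2", 3.20))
```

Sampling a fresh template per epoch gives the NL head many phrasings of the same property, which is what makes free-form queries land near the right materials.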

## 🧪 Training Recipe

### Two-Phase Training

**Phase 1: Multi-modal alignment** (pretrain100k_v2, 60k samples, 3 epochs):
- AllPairsCLIP loss across 10 modalities
- Random modality sampling (5 of 10 per step); composition and crystal_text_llm are always included
- Effective batch size 288

**Phase 2: Property-conditioned + NL query alignment** (bandgap + formation_energy, 60k samples, 3 epochs):
- AllPairsCLIP loss (structure modalities)
- **NL description ↔ structure InfoNCE** (the key NL query loss)
- Property ↔ composition/crystal_text_llm InfoNCE ([MatExpert](https://arxiv.org/abs/2410.21317))
- SupReMix-style property similarity MSE ([arxiv:2309.16633](https://arxiv.org/abs/2309.16633))
- Loss weights: `L = L_clip + 0.3 * L_property + 0.5 * L_nl`
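
Putting the Phase 2 bullets together, the combined objective can be sketched as follows. `info_nce` is a standard symmetric CLIP loss; `phase2_loss` and its argument names are hypothetical, and the SupReMix MSE term is omitted for brevity:

```python
import torch
import torch.nn.functional as F

def info_nce(za, zb, temperature=0.07):
    """Symmetric InfoNCE between two batches of L2-normalized embeddings,
    where row i of za and row i of zb embed the same material."""
    logits = za @ zb.T / temperature
    labels = torch.arange(za.size(0))
    return 0.5 * (F.cross_entropy(logits, labels) +
                  F.cross_entropy(logits.T, labels))

def phase2_loss(struct_embs: dict, nl_emb, prop_emb, temperature=0.07):
    # AllPairsCLIP: InfoNCE over every pair of structure modalities
    mods = list(struct_embs)
    l_clip = sum(info_nce(struct_embs[a], struct_embs[b], temperature)
                 for i, a in enumerate(mods) for b in mods[i + 1:])
    # Property <-> structure (MatExpert-style), anchored on the two
    # always-included modalities from Phase 1
    l_prop = sum(info_nce(prop_emb, struct_embs[m], temperature)
                 for m in ("composition", "crystal_text_llm"))
    # NL description <-> ALL structure modalities (the key NL query loss)
    l_nl = sum(info_nce(nl_emb, struct_embs[m], temperature) for m in mods)
    # Weights from the recipe above: L = L_clip + 0.3*L_property + 0.5*L_nl
    return l_clip + 0.3 * l_prop + 0.5 * l_nl
```
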

### Based On

| Paper | Contribution | ArXiv |
|-------|-------------|-------|
| **MultiMat** | AllPairsCLIP loss | [2312.00111](https://arxiv.org/abs/2312.00111) |
| **MatExpert** | Property↔structure InfoNCE | [2410.21317](https://arxiv.org/abs/2410.21317) |
| **LaCLIP** | LLM text augmentation for CLIP | [2305.20088](https://arxiv.org/abs/2305.20088) |
| **SupReMix** | Property-label-aware soft contrastive | [2309.16633](https://arxiv.org/abs/2309.16633) |
| **CrystalCLR** | Composition similarity | [2211.13408](https://arxiv.org/abs/2211.13408) |

### Hyperparameters

```yaml
encoder: answerdotai/ModernBERT-base
embed_dim: 128
max_length: 1024 tokens
batch_size: 48 × 6 grad_accum = 288 effective
learning_rate: 2e-5 (phase 1), 1e-5 (phase 2)
temperature: learnable (init 0.07)
epochs: 3 per phase
optimizer: AdamW (weight_decay=0.01)
precision: bf16 (A100) / fp16 (T4/V100)
gradient_checkpointing: True
max_modalities_per_step: 5
```
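
The `temperature: learnable (init 0.07)` line follows the CLIP convention of optimizing the log of the inverse temperature so it stays positive during training. A minimal sketch (variable names are illustrative, not the script's):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Store log(1/T) as a parameter; exponentiating guarantees a positive scale.
# Init T = 0.07, matching the config above.
logit_scale = nn.Parameter(torch.log(torch.tensor(1.0 / 0.07)))

za = F.normalize(torch.randn(4, 128), dim=-1)
zb = F.normalize(torch.randn(4, 128), dim=-1)
logits = logit_scale.exp() * za @ zb.T  # cosine sims scaled by learnable 1/T
```
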

## 🚀 Quick Start

### Training (your GPU)

```bash
pip install torch transformers datasets faiss-cpu huggingface_hub trackio accelerate

# Optional but recommended for A100/H100:
pip install flash-attn --no-build-isolation

python train_mattext_embeddings.py
```

The script auto-detects:
- GPU capability (bf16 for Ampere+, fp16 otherwise)
- Flash Attention 2 availability
- CUDA vs CPU
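
The detection logic above can be approximated in a few lines; this is a sketch of the idea, not the script's actual code:

```python
import torch

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

# bf16 requires Ampere or newer; fall back to fp16 on older CUDA GPUs.
if use_cuda and torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16
elif use_cuda:
    dtype = torch.float16
else:
    dtype = torch.float32

# Use Flash Attention 2 only when the package is installed.
try:
    import flash_attn  # noqa: F401
    attn_impl = "flash_attention_2"
except ImportError:
    attn_impl = "sdpa"  # PyTorch's built-in scaled dot-product attention
```
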

### Inference & Search

```python
import torch
import faiss
import json
import numpy as np
from transformers import AutoTokenizer
from train_mattext_embeddings import MatTextEncoder, Config, search_vector_db

# Load
config = Config()
config.device = "cuda" if torch.cuda.is_available() else "cpu"
model = MatTextEncoder(config)
model.load_state_dict(torch.load("mattext-embeddings/model.pt", map_location=config.device))
model = model.to(config.device).eval()
tokenizer = AutoTokenizer.from_pretrained(config.encoder_name)

# Load FAISS indices
indices = {}
for mod in ["composition", "crystal_text_llm", "slices", "cif_symmetrized", "robocrys_rep"]:
    index = faiss.read_index(f"mattext-embeddings/faiss/{mod}.index")
    with open(f"mattext-embeddings/faiss/{mod}_metadata.json") as f:
        metadata = json.load(f)
    indices[mod] = {"index": index, "metadata": metadata}
```

### Query Examples

```python
# 🔍 Natural language property queries (THE KEY FEATURE)
search_vector_db("oxide with high bandgap", "nl_property_description", model, tokenizer, indices, config)
search_vector_db("stable ternary nitride", "nl_property_description", model, tokenizer, indices, config)
search_vector_db("narrow bandgap semiconductor for IR", "nl_property_description", model, tokenizer, indices, config)
search_vector_db("metallic binary compound", "nl_property_description", model, tokenizer, indices, config)

# 🧪 Composition queries
search_vector_db("Fe2O3", "composition", model, tokenizer, indices, config)
search_vector_db("BaTiO3", "composition", model, tokenizer, indices, config)

# 📖 Structure description queries
search_vector_db("perovskite with octahedral coordination", "robocrys_rep", model, tokenizer, indices, config)

# 📊 Structured property queries
search_vector_db("composition: TiO2 | bandgap: 3.2000", "property", model, tokenizer, indices, config)

# 🔬 CIF queries (paste your CIF)
search_vector_db("data_TiO2\n_symmetry P1\n_cell 4.59 4.59 2.96 90 90 90", "cif_symmetrized", ...)

# 🧬 SLICES queries
search_vector_db("Ti O 0 1 o o o", "slices", model, tokenizer, indices, config)
```

## 📊 Evaluation Metrics

Cross-modal Recall@k on the test set:

| Pair | R@1 | R@5 | R@10 | R@20 |
|------|-----|-----|------|------|
| composition → crystal_text_llm | TBD | TBD | TBD | TBD |
| composition → cif_symmetrized | TBD | TBD | TBD | TBD |
| composition → slices | TBD | TBD | TBD | TBD |
| slices → crystal_text_llm | TBD | TBD | TBD | TBD |
| robocrys_rep → composition | TBD | TBD | TBD | TBD |

NL Query Results:

| Query | Top-1 Match | Score |
|-------|------------|-------|
| "oxide with high bandgap" | TBD | TBD |
| "narrow bandgap semiconductor" | TBD | TBD |
| "stable binary oxide" | TBD | TBD |

*Results populated after training.*
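
Once results exist, cross-modal Recall@k for the table above is computed by ranking each query embedding against the paired modality's embeddings and checking whether the matching material appears in the top k. A pure-NumPy sketch (the script presumably searches its FAISS indices instead; `recall_at_k` is a hypothetical helper):

```python
import numpy as np

def recall_at_k(query_embs, target_embs, ks=(1, 5, 10, 20)):
    """Cross-modal Recall@k: row i of each matrix embeds the same material.

    Assumes L2-normalized rows, so the dot product is cosine similarity
    (matching the FAISS IndexFlatIP setup).
    """
    sims = query_embs @ target_embs.T          # (N, N) similarity matrix
    order = np.argsort(-sims, axis=1)          # neighbor ids, best first
    gold = np.arange(len(query_embs))[:, None] # correct id for each query
    return {k: float((order[:, :k] == gold).any(axis=1).mean()) for k in ks}
```
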

## 🧩 Extending: Graph Embeddings

The architecture is plug-and-play for new modalities:

```python
# Add a GNN modality. Sketch: assumes ModalityProjection is importable from
# the training script and that the GNN returns a 256-d per-structure vector.
import torch.nn as nn
from torch_geometric.nn import SchNet

from train_mattext_embeddings import ModalityProjection

class GraphEncoder(nn.Module):
    def __init__(self, embed_dim=128):
        super().__init__()
        self.gnn = SchNet(hidden_channels=256)
        self.proj = ModalityProjection(256, embed_dim)

    def forward(self, data):
        # SchNet consumes atomic numbers, positions, and the batch index
        h = self.gnn(data.z, data.pos, data.batch)
        return self.proj(h)

# Register as a new modality
graph_encoder = GraphEncoder()
model.projections["graph"] = graph_encoder.proj
# It gets aligned automatically through AllPairsCLIP
```

## 📦 Dataset

[n0w0f/MatText](https://huggingface.co/datasets/n0w0f/MatText): 100k+ crystal structures in 10+ text representations

## 📚 References

- **MatText**: [arxiv:2406.17295](https://arxiv.org/abs/2406.17295)
- **MultiMat**: [arxiv:2312.00111](https://arxiv.org/abs/2312.00111)
- **MatExpert**: [arxiv:2410.21317](https://arxiv.org/abs/2410.21317)
- **LaCLIP**: [arxiv:2305.20088](https://arxiv.org/abs/2305.20088)
- **SupReMix**: [arxiv:2309.16633](https://arxiv.org/abs/2309.16633)
- **CrystalCLR**: [arxiv:2211.13408](https://arxiv.org/abs/2211.13408)
- **Symile**: [arxiv:2411.01053](https://arxiv.org/abs/2411.01053)

## 📄 License

MIT