n0w0f commited on
Commit
5bc74d1
Β·
verified Β·
1 Parent(s): 6e805ad

Add comprehensive README with architecture details and usage

Browse files
Files changed (1) hide show
  1. README.md +231 -0
README.md ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MatText Aligned Embeddings: Multi-Modal Material Retrieval
2
+
3
+ **A CLIP-style multi-modal embedding model that aligns 10 different material text representations into a shared 128-d vector space for cross-modal retrieval.**
4
+
5
+ Query with *any* modality (composition, CIF, SLICES, natural language, z-matrix...) β†’ retrieve materials with similar properties across *all* modalities.
6
+
7
+ ## πŸ—οΈ Architecture
8
+
9
+ ```
10
+ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
11
+ β”‚ MatTextEncoder β”‚
12
+ β”‚ β”‚
13
+ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚
14
+ β”‚ β”‚ Shared Backbone: ModernBERT-base (150M params) β”‚ β”‚
15
+ β”‚ β”‚ - 8192 token context window (handles long CIFs) β”‚ β”‚
16
+ β”‚ β”‚ - Mean pooling β†’ 768-d representation β”‚ β”‚
17
+ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚
18
+ β”‚ β”‚ β”‚
19
+ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”Όβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚
20
+ β”‚ β–Ό β–Ό β–Ό β”‚
21
+ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚
22
+ β”‚ β”‚ Projection β”‚ β”‚ Projection β”‚ β”‚ Projection β”‚ ... β”‚
23
+ β”‚ β”‚ composition β”‚ β”‚ cif_sym β”‚ β”‚ slices β”‚ β”‚
24
+ β”‚ β”‚ 768β†’768β†’128 β”‚ β”‚ 768β†’768β†’128 β”‚ β”‚ 768β†’768β†’128 β”‚ β”‚
25
+ β”‚ β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚
26
+ β”‚ β–Ό β–Ό β–Ό β”‚
27
+ β”‚ 128-d L2-norm 128-d L2-norm 128-d L2-norm β”‚
28
+ β”‚ β”‚
29
+ β”‚ ──── Shared Embedding Space ──── β”‚
30
+ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
31
+ ```
32
+
33
+ ### Key Design Decisions
34
+
35
+ | Decision | Choice | Rationale |
36
+ |----------|--------|-----------|
37
+ | Backbone | ModernBERT-base | 8192 ctx handles long CIFs; fast RoPE attention |
38
+ | Projection | 2-layer MLP per modality | MultiMat recipe: modality-specific heads preserve specialization |
39
+ | Embedding dim | 128 | Standard for contrastive learning; compact for FAISS |
40
+ | Loss | AllPairsCLIP + Property-MSE | Aligns all N(N-1)/2 modality pairs; property regularization |
41
+ | Temperature | Learnable (init 0.07) | CLIP standard; learned Ο„ improves convergence |
42
+
43
+ ## πŸ“Š Modalities Supported
44
+
45
+ | Modality | Column | Example | Query Type |
46
+ |----------|--------|---------|------------|
47
+ | Composition | `composition` | `Fe2O3` | "Find iron oxides" |
48
+ | Atom Sequence | `atom_sequences` | `Fe Fe Fe O O O` | Element lists |
49
+ | CIF (symmetrized) | `cif_symmetrized` | Full CIF text | Paste CIF data |
50
+ | CIF (P1) | `cif_p1` | Full CIF in P1 | Paste CIF data |
51
+ | Z-matrix | `zmatrix` | `Fe\nO 1 2.0\nO 1 2.0 2 90` | Internal coords |
52
+ | Atom Seq++ | `atom_sequences_plusplus` | `Fe O 3.57 3.57 90 90` | Elements + lattice |
53
+ | SLICES | `slices` | `Fe O 0 1 o o o` | SLICES encoding |
54
+ | Crystal Text (LLM) | `crystal_text_llm` | `3.6 3.6 3.6\n90 90 90\nFe...` | Gruver format |
55
+ | Local Environment | `local_env` | SMILES-like env | Local bonding |
56
+ | Natural Language | `robocrys_rep` | "FeO crystallizes in..." | Plain English |
57
+ | **Property Query** | property text | "bandgap: 1.5 eV" | Property search |
58
+
59
+ ## πŸ§ͺ Training Recipe
60
+
61
+ Based on three key papers:
62
+
63
+ 1. **MultiMat** (AllPairsCLIP, [arxiv:2312.00111](https://arxiv.org/abs/2312.00111)): Sum of symmetric InfoNCE over all modality pairs
64
+ 2. **MatExpert** ([arxiv:2410.21317](https://arxiv.org/abs/2410.21317)): Property↔structure contrastive alignment
65
+ 3. **CrystalCLR** ([arxiv:2211.13408](https://arxiv.org/abs/2211.13408)): Composition similarity loss
66
+ 4. **SupReMix** ([arxiv:2309.16633](https://arxiv.org/abs/2309.16633)): Property-label-aware soft contrastive
67
+
68
+ ### Two-Phase Training
69
+
70
+ **Phase 1 β€” Multi-modal alignment** (pretrain100k_v2, 50k samples):
71
+ - AllPairsCLIP loss across all 10 modalities
72
+ - Random modality sampling (4/10 per step) for VRAM efficiency
73
+ - Each step aligns C(4,2)=6 modality pairs
74
+
75
+ **Phase 2 β€” Property-conditioned alignment** (bandgap + form_energy, 50k samples):
76
+ - Same CLIP loss + property similarity MSE loss
77
+ - Property text "composition: Fe2O3 | bandgap: 2.1000" aligned with structure representations
78
+ - Materials with similar property values cluster in embedding space
79
+
80
+ ### Hyperparameters
81
+
82
+ ```
83
+ encoder: answerdotai/ModernBERT-base
84
+ embed_dim: 128
85
+ max_length: 512 tokens
86
+ batch_size: 32 Γ— 8 grad_accum = 256 effective
87
+ learning_rate: 2e-5 (cosine decay, 10% warmup)
88
+ temperature: learnable (init 0.07)
89
+ epochs: 3 per phase
90
+ optimizer: AdamW (weight_decay=0.01)
91
+ fp16: True
92
+ gradient_checkpointing: True
93
+ ```
94
+
95
+ ## πŸš€ Quick Start
96
+
97
+ ### Training
98
+
99
+ ```bash
100
+ pip install torch transformers datasets faiss-cpu huggingface_hub trackio
101
+
102
+ # Local GPU
103
+ python train_mattext_embeddings.py
104
+
105
+ # HF Jobs (recommended: a10g-large, 24GB VRAM)
106
+ # Set timeout to 6h
107
+ ```
108
+
109
+ ### Inference & Search
110
+
111
+ ```python
112
+ import torch
113
+ import faiss
114
+ import json
115
+ import numpy as np
116
+ from transformers import AutoModel, AutoTokenizer
117
+
118
+ # Load model
119
+ from train_mattext_embeddings import MatTextEncoder, Config, search_vector_db
120
+
121
+ config = Config()
122
+ config.device = "cuda" if torch.cuda.is_available() else "cpu"
123
+
124
+ model = MatTextEncoder(config)
125
+ model.load_state_dict(torch.load("mattext-embeddings/model.pt", map_location=config.device))
126
+ model = model.to(config.device)
127
+ model.eval()
128
+
129
+ tokenizer = AutoTokenizer.from_pretrained(config.encoder_name)
130
+
131
+ # Load FAISS indices
132
+ indices = {}
133
+ for mod in ["composition", "crystal_text_llm", "slices", "cif_symmetrized"]:
134
+ index = faiss.read_index(f"mattext-embeddings/faiss/{mod}.index")
135
+ with open(f"mattext-embeddings/faiss/{mod}_metadata.json") as f:
136
+ metadata = json.load(f)
137
+ indices[mod] = {"index": index, "metadata": metadata}
138
+
139
+ # Search!
140
+ results = search_vector_db("Fe2O3", "composition", model, tokenizer, indices, config, k=5)
141
+ for score, meta in results:
142
+ print(f"Score: {score:.4f} | {meta['composition']}")
143
+ ```
144
+
145
+ ### Cross-Modal Query Examples
146
+
147
+ ```python
148
+ # Query by composition β†’ find across all modalities
149
+ search_vector_db("SiO2", "composition", model, tokenizer, indices, config)
150
+
151
+ # Query by natural language β†’ find materials
152
+ search_vector_db("perovskite with high bandgap", "robocrys_rep", model, tokenizer, indices, config)
153
+
154
+ # Query by SLICES representation
155
+ search_vector_db("Si O 0 1 o o o", "slices", model, tokenizer, indices, config)
156
+
157
+ # Query by CIF data
158
+ search_vector_db("data_SiO2\n_symmetry P1\n...", "cif_symmetrized", model, tokenizer, indices, config)
159
+
160
+ # Property-conditioned query
161
+ search_vector_db("composition: Si | bandgap: 1.1200", "property", model, tokenizer, indices, config)
162
+ ```
163
+
164
+ ## πŸ”¬ Evaluation Metrics
165
+
166
+ Cross-modal Recall@k: for each material, embed in modality A, retrieve in modality B, check if correct match is in top-k.
167
+
168
+ | Pair | R@1 | R@5 | R@10 |
169
+ |------|-----|-----|------|
170
+ | composition β†’ crystal_text_llm | TBD | TBD | TBD |
171
+ | composition β†’ cif_symmetrized | TBD | TBD | TBD |
172
+ | slices β†’ crystal_text_llm | TBD | TBD | TBD |
173
+ | robocrys_rep β†’ composition | TBD | TBD | TBD |
174
+
175
+ *Results populated after training.*
176
+
177
+ ## 🧩 Extending: Graph Embeddings
178
+
179
+ The architecture supports adding graph neural network (GNN) embeddings:
180
+
181
+ ```python
182
+ # Add a GNN projection head
183
+ from torch_geometric.nn import SchNet, DimeNet # or CGCNN
184
+
185
+ class GraphEncoder(nn.Module):
186
+ def __init__(self, embed_dim=128):
187
+ super().__init__()
188
+ self.gnn = SchNet(hidden_channels=256, num_filters=128, num_interactions=6)
189
+ self.proj = ModalityProjection(256, embed_dim)
190
+
191
+ def forward(self, data):
192
+ # data: PyG Data with pos, z (atomic numbers), batch
193
+ h = self.gnn(data.z, data.pos, data.batch)
194
+ return self.proj(h)
195
+
196
+ # Add to MatTextEncoder:
197
+ model.graph_encoder = GraphEncoder(config.embed_dim)
198
+ model.projections["graph"] = model.graph_encoder.proj
199
+
200
+ # Training: treat graph embeddings as another modality in AllPairsCLIP
201
+ ```
202
+
203
+ For graph embeddings, convert CIF β†’ PyG Data (using `pymatgen` + `torch_geometric`):
204
+ ```python
205
+ from pymatgen.core import Structure
206
+ from torch_geometric.data import Data
207
+
208
+ def cif_to_graph(cif_string, cutoff=5.0):
209
+ struct = Structure.from_str(cif_string, fmt="cif")
210
+ # Get neighbors within cutoff
211
+ neighbors = struct.get_all_neighbors(cutoff)
212
+ # Build edge_index, pos, z ...
213
+ return Data(z=atomic_numbers, pos=positions, edge_index=edge_index)
214
+ ```
215
+
216
+ ## πŸ“š References
217
+
218
+ - **MatText**: [arxiv:2406.17295](https://arxiv.org/abs/2406.17295) β€” Dataset and text representations
219
+ - **MultiMat**: [arxiv:2312.00111](https://arxiv.org/abs/2312.00111) β€” AllPairsCLIP for materials
220
+ - **MatExpert**: [arxiv:2410.21317](https://arxiv.org/abs/2410.21317) β€” Property↔structure alignment
221
+ - **CrystalCLR**: [arxiv:2211.13408](https://arxiv.org/abs/2211.13408) β€” Contrastive learning for crystals
222
+ - **SupReMix**: [arxiv:2309.16633](https://arxiv.org/abs/2309.16633) β€” Property-aware hard negatives
223
+ - **Symile**: [arxiv:2411.01053](https://arxiv.org/abs/2411.01053) β€” Total-correlation loss for M modalities
224
+
225
+ ## πŸ“„ License
226
+
227
+ MIT
228
+
229
+ ## πŸ”— Dataset
230
+
231
+ [n0w0f/MatText](https://huggingface.co/datasets/n0w0f/MatText) β€” 100k+ crystal structures in 10 text representations