n0w0f committed (verified)
Commit 6e805ad · 1 Parent(s): efe29ef

Add training script: CLIP-style multi-modal material embedding alignment

Files changed (1):
  train_mattext_embeddings.py +689 -0
train_mattext_embeddings.py ADDED
@@ -0,0 +1,689 @@
"""
MatText Multi-Modal Embedding Alignment Training

Architecture: CLIP-style contrastive learning across 8+ material text representations
- Shared encoder (ModernBERT-base, 8192 ctx) with per-modality projection heads
- All-pairs symmetric InfoNCE loss
- Property-conditioned retrieval via property description encoding
- FAISS vector database for cross-modal retrieval

Based on:
- MultiMat (AllPairsCLIP, arxiv:2312.00111)
- MatExpert (property↔structure InfoNCE, arxiv:2410.21317)
- CrystalCLR (composition similarity, arxiv:2211.13408)

Usage:
    pip install torch transformers datasets faiss-cpu huggingface_hub trackio
    python train_mattext_embeddings.py

    # Or on HF Jobs:
    # Hardware: a10g-large (24GB VRAM), timeout: 6h
"""

import os
import json
import math
import time
import logging
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
from datasets import load_dataset, concatenate_datasets
from huggingface_hub import HfApi
import faiss

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# ============================================================================
# Configuration
# ============================================================================

class Config:
    # Model
    encoder_name = "answerdotai/ModernBERT-base"
    embed_dim = 128   # projection dimension (MultiMat recipe: 128-d)
    max_length = 512  # tokens per modality input (ModernBERT supports up to 8192)

    # Modalities to align (columns in the dataset)
    modalities = [
        "composition",
        "atom_sequences",
        "cif_symmetrized",
        "cif_p1",
        "zmatrix",
        "atom_sequences_plusplus",
        "slices",
        "crystal_text_llm",
        "local_env",
        "robocrys_rep",  # natural-language description (only in pretrain subsets)
    ]

    # Training
    batch_size = 32
    learning_rate = 2e-5
    weight_decay = 0.01
    num_epochs = 3
    warmup_ratio = 0.1
    temperature = 0.07    # InfoNCE temperature (MultiMat/CLIP standard)
    grad_accum_steps = 8  # effective batch = 32 * 8 = 256 (critical for InfoNCE)
    max_grad_norm = 1.0
    gradient_checkpointing = True
    max_modalities_per_step = 4  # randomly sample N modalities per step to save VRAM

    # Data
    dataset_name = "n0w0f/MatText"
    pretrain_config = "pretrain100k_v2"
    finetune_configs = [
        ("bandgap-train-filtered", "fold_0"),
        ("form_energy-train-filtered", "fold_0"),
    ]
    max_train_samples = 50000

    # Output
    output_dir = "mattext-embeddings"
    hub_model_id = "n0w0f/mattext-aligned-embeddings"
    push_to_hub = True

    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    fp16 = torch.cuda.is_available()


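# Back-of-envelope for these defaults: each optimizer update aggregates an
# effective batch of batch_size * grad_accum_steps = 32 * 8 = 256 materials,
# and capping modalities at max_modalities_per_step = 4 bounds the all-pairs
# objective at C(4, 2) = 6 pairwise InfoNCE terms per step (all 10 configured
# modalities would cost 45).
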
# ============================================================================
# Model: Shared Encoder + Per-Modality Projection Heads
# ============================================================================

class ModalityProjection(nn.Module):
    """2-layer MLP projection head (MultiMat recipe)"""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.GELU(),
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, output_dim),
        )

    def forward(self, x):
        return F.normalize(self.net(x), dim=-1)


class MatTextEncoder(nn.Module):
    """
    Shared transformer encoder with per-modality projection heads.
    All modalities share the same backbone but project to a shared
    embedding space through modality-specific heads.
    """

    def __init__(self, config: Config):
        super().__init__()
        self.config = config

        # Shared backbone
        self.backbone = AutoModel.from_pretrained(config.encoder_name)
        hidden_size = self.backbone.config.hidden_size

        if config.gradient_checkpointing:
            self.backbone.gradient_checkpointing_enable()

        # Per-modality projection heads
        self.projections = nn.ModuleDict({
            mod: ModalityProjection(hidden_size, config.embed_dim)
            for mod in config.modalities
        })

        # Property projection (for property-conditioned queries)
        self.property_projection = ModalityProjection(hidden_size, config.embed_dim)

        # Learnable temperature, stored as log(1/T) so exp() yields the
        # logit scale (CLIP parameterization)
        self.log_temperature = nn.Parameter(
            torch.tensor(math.log(1.0 / config.temperature))
        )

    def encode(self, input_ids, attention_mask, modality_name):
        """Encode a single modality"""
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)

        # Mean pooling over non-padding tokens
        mask = attention_mask.unsqueeze(-1).float()
        hidden = outputs.last_hidden_state
        pooled = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)

        # Project through modality-specific head
        if modality_name == "property":
            return self.property_projection(pooled)
        return self.projections[modality_name](pooled)

    @property
    def temperature(self):
        # Note: this is the logit scale 1/T (~14.3 at init for T = 0.07),
        # clamped for stability; logits are multiplied by it downstream.
        return torch.exp(self.log_temperature).clamp(min=0.01, max=100.0)


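# Minimal usage sketch (illustrative; inputs and device follow the defaults above):
#
#   config = Config()
#   tokenizer = AutoTokenizer.from_pretrained(config.encoder_name)
#   model = MatTextEncoder(config).to(config.device)
#   enc = tokenizer(["Fe2O3", "SrTiO3"], padding=True, truncation=True,
#                   max_length=config.max_length, return_tensors="pt")
#   with torch.no_grad():
#       emb = model.encode(enc["input_ids"].to(config.device),
#                          enc["attention_mask"].to(config.device), "composition")
#   # emb: (2, 128), L2-normalized, directly comparable with any other modality
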
# ============================================================================
# Loss Functions
# ============================================================================

def symmetric_clip_loss(emb_a, emb_b, temperature):
    """Symmetric InfoNCE (CLIP loss); `temperature` is the logit scale 1/T."""
    N = emb_a.size(0)
    logits = (emb_a @ emb_b.T) * temperature
    labels = torch.arange(N, device=emb_a.device)
    loss_a = F.cross_entropy(logits, labels)
    loss_b = F.cross_entropy(logits.T, labels)
    return (loss_a + loss_b) / 2


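# Quick sanity check (illustrative): on random, unrelated pairs the loss sits
# near ln(N), since each row of logits is roughly uniform over N candidates:
#
#   a = F.normalize(torch.randn(64, 128), dim=-1)
#   b = F.normalize(torch.randn(64, 128), dim=-1)
#   symmetric_clip_loss(a, b, temperature=1.0)  # ~ math.log(64) ≈ 4.16
#
# Values falling well below ln(N) during training indicate alignment.
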
def all_pairs_clip_loss(embeddings_dict, temperature):
    """AllPairsCLIP: average symmetric InfoNCE over all modality pairs."""
    mods = [k for k, v in embeddings_dict.items() if v is not None]
    if len(mods) < 2:
        # Nothing to contrast; return a zero that is safe to combine with other losses
        device = embeddings_dict[mods[0]].device if mods else "cpu"
        return torch.zeros((), device=device, requires_grad=True)

    device = embeddings_dict[mods[0]].device
    total_loss = torch.tensor(0.0, device=device)
    n_pairs = 0

    for i in range(len(mods)):
        for j in range(i + 1, len(mods)):
            total_loss = total_loss + symmetric_clip_loss(
                embeddings_dict[mods[i]], embeddings_dict[mods[j]], temperature
            )
            n_pairs += 1

    return total_loss / n_pairs


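# The pair count grows quadratically: K aligned modalities yield
# K * (K - 1) / 2 InfoNCE terms, e.g. 4 sampled modalities -> 6 pairs versus
# 45 pairs for all 10 configured modalities. This is why the training loop
# honors Config.max_modalities_per_step.
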
def property_similarity_loss(embeddings, labels, temperature):
    """
    Property-aware soft contrastive loss (SupReMix-inspired): pull pairwise
    cosine similarity toward similarity of the scalar property labels.
    (`temperature` is unused here; kept for a uniform loss signature.)
    """
    N = embeddings.size(0)
    if N < 2:
        return torch.zeros((), device=embeddings.device, requires_grad=True)

    # Label similarity in [0, 1]: identical labels -> 1, most distant pair -> 0
    label_diff = torch.abs(labels.unsqueeze(0) - labels.unsqueeze(1))
    max_diff = label_diff.max().clamp(min=1e-6)
    label_sim = 1.0 - (label_diff / max_diff)

    # Match against cosine similarity, excluding the trivial diagonal
    cos_sim = embeddings @ embeddings.T
    mask = torch.eye(N, device=embeddings.device).bool()
    cos_sim = cos_sim.masked_fill(mask, 0)
    label_sim = label_sim.masked_fill(mask, 0)

    return F.mse_loss(cos_sim, label_sim)


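# Worked example (illustrative): band gaps [1.0, 1.5, 3.0] eV give
# label_diff.max() = 2.0, so label_sim becomes
#     [[1.00, 0.75, 0.00],
#      [0.75, 1.00, 0.25],
#      [0.00, 0.25, 1.00]]
# and the MSE pulls cosine similarities toward these soft targets
# (diagonal masked out).
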
# ============================================================================
# Dataset
# ============================================================================

class MatTextMultiModalDataset(Dataset):
    def __init__(self, data, modalities, property_col=None, property_name=None):
        self.data = data
        self.modalities = modalities
        self.property_col = property_col
        self.property_name = property_name

        available_cols = set(data.column_names) if hasattr(data, "column_names") else set(data[0].keys())
        self.available_modalities = [m for m in modalities if m in available_cols]
        logger.info(f"Available modalities: {self.available_modalities}")

        self.has_properties = property_col is not None and property_col in available_cols
        if self.has_properties:
            logger.info(f"Property column '{property_col}' found")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]
        item = {}
        for mod in self.available_modalities:
            text = row.get(mod, None)
            if text and isinstance(text, str) and len(text.strip()) > 0:
                item[mod] = text.strip()
            else:
                item[mod] = None

        if self.has_properties and row.get(self.property_col) is not None:
            label_val = float(row[self.property_col])
            comp = row.get("composition", "unknown")
            item["property_text"] = f"composition: {comp} | {self.property_name}: {label_val:.4f}"
            item["property_label"] = label_val
        else:
            item["property_text"] = None
            item["property_label"] = None

        return item


def collate_fn(batch, tokenizer, modalities, max_length):
    result = {}
    all_mod_keys = list(modalities) + ["property_text"]

    for mod in all_mod_keys:
        texts = [item.get(mod) for item in batch]
        valid_texts = [t for t in texts if t is not None]
        if len(valid_texts) == 0:
            result[mod] = None
            continue

        # Tokenize with empty-string placeholders; valid_mask records which
        # rows actually carried text for this modality
        texts_clean = [t if t is not None else "" for t in texts]
        mask_valid = [t is not None for t in texts]

        encoded = tokenizer(texts_clean, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
        result[mod] = {
            "input_ids": encoded["input_ids"],
            "attention_mask": encoded["attention_mask"],
            "valid_mask": torch.tensor(mask_valid, dtype=torch.bool),
        }

    labels = [item.get("property_label") for item in batch]
    if any(l is not None for l in labels):
        labels_clean = [l if l is not None else 0.0 for l in labels]
        labels_mask = [l is not None for l in labels]
        result["property_labels"] = torch.tensor(labels_clean, dtype=torch.float32)
        result["property_labels_mask"] = torch.tensor(labels_mask, dtype=torch.bool)
    else:
        result["property_labels"] = None
        result["property_labels_mask"] = None

    return result


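# Shape sketch (illustrative): for a two-row batch where only the first row
# carries a "slices" string, collate_fn returns
#     result["slices"]["input_ids"]   # shape (2, L); row 2 tokenized from ""
#     result["slices"]["valid_mask"]  # tensor([True, False])
# so downstream code can zero out or skip the placeholder row.
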
# ============================================================================
# Training Loop
# ============================================================================

def train_epoch(model, dataloader, optimizer, scheduler, config, epoch, scaler=None):
    model.train()
    total_loss = 0.0
    total_clip_loss = 0.0
    total_prop_loss = 0.0
    log_interval = 20

    optimizer.zero_grad()

    for batch_idx, batch in enumerate(dataloader):
        # Randomly sample modalities to save VRAM, always keeping the
        # cheap/high-coverage anchors when present
        available_mods = [m for m in config.modalities if batch.get(m) is not None]
        if len(available_mods) > config.max_modalities_per_step:
            must_have = [m for m in ["composition", "crystal_text_llm"] if m in available_mods]
            remaining = [m for m in available_mods if m not in must_have]
            n_sample = max(config.max_modalities_per_step - len(must_have), 1)
            sampled = must_have + random.sample(remaining, min(n_sample, len(remaining)))
        else:
            sampled = available_mods

        embeddings = {}
        for mod in sampled:
            if batch.get(mod) is None:
                embeddings[mod] = None
                continue

            input_ids = batch[mod]["input_ids"].to(config.device)
            attention_mask = batch[mod]["attention_mask"].to(config.device)
            valid_mask = batch[mod]["valid_mask"]

            if not valid_mask.any():
                embeddings[mod] = None
                continue

            with torch.amp.autocast("cuda", enabled=config.fp16):
                emb = model.encode(input_ids, attention_mask, mod)
                # Zero out rows that had no text so they act as null entries
                emb = emb * valid_mask.to(config.device).unsqueeze(-1).float()
                embeddings[mod] = emb

        with torch.amp.autocast("cuda", enabled=config.fp16):
            temperature = model.temperature
            clip_l = all_pairs_clip_loss(embeddings, temperature)

        prop_l = torch.tensor(0.0, device=config.device)
        if batch.get("property_text") is not None and batch.get("property_labels") is not None:
            prop_ids = batch["property_text"]["input_ids"].to(config.device)
            prop_mask = batch["property_text"]["attention_mask"].to(config.device)
            prop_valid = batch["property_text"]["valid_mask"]

            if prop_valid.any():
                with torch.amp.autocast("cuda", enabled=config.fp16):
                    prop_emb = model.encode(prop_ids, prop_mask, "property")

                labels = batch["property_labels"].to(config.device)
                labels_mask = batch["property_labels_mask"].to(config.device)

                if labels_mask.sum() > 1:
                    prop_l = property_similarity_loss(prop_emb[labels_mask], labels[labels_mask], temperature)

                    # Also tie property descriptions to the first available
                    # text anchor via InfoNCE
                    for anchor_mod in ["robocrys_rep", "crystal_text_llm", "composition"]:
                        if embeddings.get(anchor_mod) is not None:
                            with torch.amp.autocast("cuda", enabled=config.fp16):
                                prop_clip = symmetric_clip_loss(
                                    prop_emb[labels_mask], embeddings[anchor_mod][labels_mask], temperature
                                )
                            prop_l = prop_l + 0.5 * prop_clip
                            break

        loss = (clip_l + 0.3 * prop_l) / config.grad_accum_steps

        if config.fp16 and scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()

        if (batch_idx + 1) % config.grad_accum_steps == 0:
            if config.fp16 and scaler is not None:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
                optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        total_loss += loss.item() * config.grad_accum_steps
        total_clip_loss += clip_l.item()
        total_prop_loss += prop_l.item()

        if (batch_idx + 1) % log_interval == 0:
            avg = total_loss / (batch_idx + 1)
            logger.info(
                f"Epoch {epoch} | {batch_idx+1}/{len(dataloader)} | "
                f"Loss: {avg:.4f} | CLIP: {total_clip_loss/(batch_idx+1):.4f} | "
                f"Prop: {total_prop_loss/(batch_idx+1):.4f} | "
                f"LR: {scheduler.get_last_lr()[0]:.2e} | T: {model.temperature.item():.3f}"
            )

    return total_loss / max(len(dataloader), 1)


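# Accumulation arithmetic: gradients from grad_accum_steps = 8 consecutive
# micro-batches are summed before one optimizer/scheduler step, so an epoch
# performs len(dataloader) // 8 updates. Any trailing micro-batches that do
# not complete a full window leave gradients that are discarded by the
# zero_grad() at the start of the next epoch.
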
# ============================================================================
# Evaluation
# ============================================================================

@torch.no_grad()
def evaluate_retrieval(model, dataloader, config, k_values=(1, 5, 10)):
    model.eval()
    all_embeddings = {mod: [] for mod in config.modalities}

    for batch in dataloader:
        # Infer the batch size from any present modality so every modality's
        # list stays index-aligned across batches (None = no embedding)
        bsz = next((batch[m]["valid_mask"].numel()
                    for m in config.modalities if batch.get(m) is not None), 0)
        for mod in config.modalities:
            entry = batch.get(mod)
            if entry is None or not entry["valid_mask"].any():
                all_embeddings[mod].extend([None] * bsz)
                continue
            input_ids = entry["input_ids"].to(config.device)
            attention_mask = entry["attention_mask"].to(config.device)
            valid_mask = entry["valid_mask"]

            emb = model.encode(input_ids, attention_mask, mod).cpu()
            for i in range(len(emb)):
                all_embeddings[mod].append(emb[i] if valid_mask[i] else None)

    results = {}
    eval_pairs = [
        ("composition", "crystal_text_llm"), ("composition", "cif_symmetrized"),
        ("slices", "crystal_text_llm"), ("composition", "slices"),
    ]
    if len([e for e in all_embeddings.get("robocrys_rep", []) if e is not None]) > 0:
        eval_pairs.extend([("robocrys_rep", "composition"), ("robocrys_rep", "cif_symmetrized")])

    for mod_a, mod_b in eval_pairs:
        embs_a, embs_b = all_embeddings.get(mod_a, []), all_embeddings.get(mod_b, [])
        if not embs_a or not embs_b:
            continue

        # Keep only indices where both modalities produced an embedding
        valid_idx = [i for i in range(min(len(embs_a), len(embs_b)))
                     if embs_a[i] is not None and embs_b[i] is not None]
        if len(valid_idx) < 10:
            continue

        ea = torch.stack([embs_a[i] for i in valid_idx])
        eb = torch.stack([embs_b[i] for i in valid_idx])
        sim = ea @ eb.T

        # Recall@k: does row i's true partner appear among its top-k neighbors?
        recalls = {}
        for k in k_values:
            kk = min(k, len(valid_idx) - 1)
            topk = sim.topk(kk, dim=1).indices
            correct = (topk == torch.arange(len(valid_idx)).unsqueeze(1)).any(dim=1)
            recalls[f"R@{k}"] = correct.float().mean().item()

        results[f"{mod_a}→{mod_b}"] = recalls
        logger.info(f"  {mod_a}→{mod_b}: {recalls}")

    return results


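# Reading the numbers: with the 5,000-sample eval split used in main(),
# chance-level R@1 is 1/5000 = 0.02%, so even modest absolute recalls signal
# genuine cross-modal alignment; R@k counts only exact index matches
# (the same material retrieved from the other modality).
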
# ============================================================================
# FAISS Vector Database
# ============================================================================

def build_vector_database(model, dataset, tokenizer, config, modalities_to_index=None):
    if modalities_to_index is None:
        modalities_to_index = config.modalities
    model.eval()

    all_embeddings = {mod: [] for mod in modalities_to_index}
    all_metadata = []
    bs = 64

    for start in range(0, len(dataset), bs):
        end = min(start + bs, len(dataset))
        items = [dataset[i] for i in range(start, end)]
        batch = collate_fn(items, tokenizer, config.modalities, config.max_length)

        for item in items:
            all_metadata.append({"composition": item.get("composition", ""), "property_label": item.get("property_label")})

        with torch.no_grad():
            for mod in modalities_to_index:
                if batch.get(mod) is None:
                    all_embeddings[mod].extend([None] * len(items))
                    continue
                emb = model.encode(
                    batch[mod]["input_ids"].to(config.device),
                    batch[mod]["attention_mask"].to(config.device), mod
                ).cpu().numpy()
                for i in range(len(emb)):
                    all_embeddings[mod].append(emb[i] if batch[mod]["valid_mask"][i] else None)

        if (start // bs) % 10 == 0:
            logger.info(f"Indexed {end}/{len(dataset)}")

    indices = {}
    for mod in modalities_to_index:
        valid_embs = [e for e in all_embeddings[mod] if e is not None]
        valid_map = [i for i, e in enumerate(all_embeddings[mod]) if e is not None]
        if not valid_embs:
            continue

        emb_matrix = np.stack(valid_embs).astype(np.float32)
        faiss.normalize_L2(emb_matrix)
        d = emb_matrix.shape[1]

        # Exact flat index for small collections; IVF for larger ones
        if len(valid_embs) > 10000:
            nlist = min(100, int(np.sqrt(len(valid_embs))))
            q = faiss.IndexFlatIP(d)
            index = faiss.IndexIVFFlat(q, d, nlist, faiss.METRIC_INNER_PRODUCT)
            index.train(emb_matrix)
        else:
            index = faiss.IndexFlatIP(d)
        index.add(emb_matrix)

        indices[mod] = {"index": index, "valid_indices_map": valid_map,
                        "metadata": [all_metadata[i] for i in valid_map]}
        logger.info(f"FAISS {mod}: {len(valid_embs)} vectors, dim={d}")

    return indices


def search_vector_db(query_text, query_modality, model, tokenizer, indices, config, k=10):
    model.eval()
    enc = tokenizer([query_text], padding=True, truncation=True, max_length=config.max_length, return_tensors="pt")
    with torch.no_grad():
        q = model.encode(enc["input_ids"].to(config.device), enc["attention_mask"].to(config.device), query_modality)
    q = q.cpu().numpy().astype(np.float32)
    faiss.normalize_L2(q)

    # Search every modality index, then merge and deduplicate by composition
    results = []
    for mod_name, idx_data in indices.items():
        scores, ids = idx_data["index"].search(q, k)
        for s, i in zip(scores[0], ids[0]):
            if i >= 0:
                m = dict(idx_data["metadata"][i])
                m["matched_modality"] = mod_name
                results.append((float(s), m))

    results.sort(key=lambda x: x[0], reverse=True)
    seen, unique = set(), []
    for s, m in results:
        c = m.get("composition", "")
        if c not in seen:
            seen.add(c)
            unique.append((s, m))
        if len(unique) >= k:
            break
    return unique


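# Example query (illustrative; assumes `db` was built by build_vector_database):
#
#   hits = search_vector_db("Fe2O3", "composition", model, tokenizer, db, config, k=5)
#   for score, meta in hits:
#       print(f"{score:.3f}", meta["composition"], meta["matched_modality"])
#
# Because all modalities share one embedding space, a composition query can
# legitimately match against CIF- or SLICES-indexed vectors.
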
# ============================================================================
# Main
# ============================================================================

def main():
    config = Config()
    logger.info(f"Device: {config.device} | Encoder: {config.encoder_name}")
    logger.info(f"Batch: {config.batch_size}x{config.grad_accum_steps}={config.batch_size*config.grad_accum_steps}")

    try:
        import trackio
        trackio.init(project="mattext-embeddings", name=f"align-{config.encoder_name.split('/')[-1]}")
        use_trackio = True
    except Exception:
        use_trackio = False

    tokenizer = AutoTokenizer.from_pretrained(config.encoder_name)
    model = MatTextEncoder(config).to(config.device)
    logger.info(f"Params: {sum(p.numel() for p in model.parameters()):,}")

    # Load data
    pretrain_data = load_dataset(config.dataset_name, config.pretrain_config, split="train")
    logger.info(f"Pretrain: {len(pretrain_data)} samples, cols: {pretrain_data.column_names}")

    finetune_data = None
    for ft_cfg, ft_split in config.finetune_configs:
        try:
            ft = load_dataset(config.dataset_name, ft_cfg, split=ft_split)
            logger.info(f"Loaded {ft_cfg}/{ft_split}: {len(ft)} samples")
            if finetune_data is None:
                finetune_data = ft
            else:
                # Concatenate on the column intersection, in a deterministic order
                common = sorted(set(finetune_data.column_names) & set(ft.column_names))
                finetune_data = concatenate_datasets([
                    finetune_data.select_columns(common),
                    ft.select_columns(common),
                ])
        except Exception as e:
            logger.warning(f"Failed {ft_cfg}: {e}")

    if len(pretrain_data) > config.max_train_samples:
        pretrain_data = pretrain_data.shuffle(seed=42).select(range(config.max_train_samples))

    make_collate = lambda tok, mods, ml: lambda batch: collate_fn(batch, tok, mods, ml)

    pretrain_loader = DataLoader(
        MatTextMultiModalDataset(pretrain_data, config.modalities),
        batch_size=config.batch_size, shuffle=True, drop_last=True, num_workers=0,
        collate_fn=make_collate(tokenizer, config.modalities, config.max_length),
        pin_memory=config.device == "cuda",
    )

    finetune_loader = None
    if finetune_data is not None:
        if len(finetune_data) > config.max_train_samples:
            finetune_data = finetune_data.shuffle(seed=42).select(range(config.max_train_samples))
        finetune_loader = DataLoader(
            MatTextMultiModalDataset(finetune_data, config.modalities, "labels", "property_value"),
            batch_size=config.batch_size, shuffle=True, drop_last=True, num_workers=0,
            collate_fn=make_collate(tokenizer, config.modalities, config.max_length),
            pin_memory=config.device == "cuda",
        )

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    total_steps = len(pretrain_loader) * config.num_epochs // config.grad_accum_steps
    if finetune_loader is not None:
        total_steps += len(finetune_loader) * config.num_epochs // config.grad_accum_steps
    scheduler = get_cosine_schedule_with_warmup(optimizer, int(total_steps * config.warmup_ratio), total_steps)
    scaler = torch.amp.GradScaler("cuda") if config.fp16 else None

    logger.info(f"Steps: {total_steps}")

    # Phase 1: Multi-modal alignment
    logger.info("=" * 60 + "\nPhase 1: Multi-modal alignment\n" + "=" * 60)
    best_loss = float("inf")
    for epoch in range(1, config.num_epochs + 1):
        t0 = time.time()
        loss = train_epoch(model, pretrain_loader, optimizer, scheduler, config, epoch, scaler)
        logger.info(f"Epoch {epoch} | Loss: {loss:.4f} | Time: {time.time()-t0:.0f}s")
        if use_trackio:
            try:
                trackio.log({"phase": 1, "epoch": epoch, "loss": loss})
            except Exception:
                pass
        if loss < best_loss:
            best_loss = loss
            os.makedirs(config.output_dir, exist_ok=True)
            torch.save(model.state_dict(), f"{config.output_dir}/best_model.pt")

    # Phase 2: Property-conditioned alignment
    if finetune_loader is not None:
        logger.info("=" * 60 + "\nPhase 2: Property-conditioned alignment\n" + "=" * 60)
        for epoch in range(1, config.num_epochs + 1):
            t0 = time.time()
            loss = train_epoch(model, finetune_loader, optimizer, scheduler, config, epoch, scaler)
            logger.info(f"P2 Epoch {epoch} | Loss: {loss:.4f} | Time: {time.time()-t0:.0f}s")
            if loss < best_loss:
                best_loss = loss
                torch.save(model.state_dict(), f"{config.output_dir}/best_model.pt")

    # Evaluate
    logger.info("=" * 60 + "\nEvaluation\n" + "=" * 60)
    eval_data = load_dataset(config.dataset_name, config.pretrain_config, split="test")
    if len(eval_data) > 5000:
        eval_data = eval_data.shuffle(seed=42).select(range(5000))

    eval_loader = DataLoader(
        MatTextMultiModalDataset(eval_data, config.modalities),
        batch_size=config.batch_size, shuffle=False, num_workers=0,
        collate_fn=make_collate(tokenizer, config.modalities, config.max_length),
    )
    results = evaluate_retrieval(model, eval_loader, config)

    # Build FAISS DB
    logger.info("Building FAISS indices...")
    db = build_vector_database(
        model, MatTextMultiModalDataset(eval_data, config.modalities),
        tokenizer, config, ["composition", "crystal_text_llm", "slices", "cif_symmetrized"]
    )

    os.makedirs(f"{config.output_dir}/faiss", exist_ok=True)
    for mod, d in db.items():
        faiss.write_index(d["index"], f"{config.output_dir}/faiss/{mod}.index")
        with open(f"{config.output_dir}/faiss/{mod}_metadata.json", "w") as f:
            json.dump(d["metadata"], f)

    # Demo
    for q, m in [("Fe2O3", "composition"), ("Si Ge", "composition")]:
        logger.info(f"\nQuery: '{q}' ({m})")
        for rank, (s, meta) in enumerate(search_vector_db(q, m, model, tokenizer, db, config, 5), 1):
            logger.info(f"  #{rank}: {s:.4f} | {meta}")

    # Save & push
    torch.save(model.state_dict(), f"{config.output_dir}/model.pt")
    tokenizer.save_pretrained(config.output_dir)
    with open(f"{config.output_dir}/config.json", "w") as f:
        json.dump({k: str(v) if not isinstance(v, (int, float, str, bool, list, dict, type(None))) else v
                   for k, v in vars(Config).items() if not k.startswith("_")}, f, indent=2)
    with open(f"{config.output_dir}/retrieval_results.json", "w") as f:
        json.dump(results, f, indent=2)

    if config.push_to_hub:
        try:
            api = HfApi()
            api.create_repo(config.hub_model_id, exist_ok=True)
            api.upload_folder(folder_path=config.output_dir, repo_id=config.hub_model_id,
                              commit_message="Upload MatText aligned embeddings + FAISS indices")
            logger.info(f"Pushed to https://huggingface.co/{config.hub_model_id}")
        except Exception as e:
            logger.error(f"Push failed: {e}")

    logger.info("DONE!")


if __name__ == "__main__":
    main()
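
# To reuse the trained model later (sketch; paths follow the defaults above):
#
#   config = Config()
#   tokenizer = AutoTokenizer.from_pretrained(config.output_dir)
#   model = MatTextEncoder(config)
#   model.load_state_dict(torch.load(f"{config.output_dir}/model.pt", map_location="cpu"))
#   model.eval()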