n0w0f commited on
Commit
7949a14
·
verified ·
1 Parent(s): 5bc74d1

v2: 1024 context, NL property queries (LaCLIP-style), A100 80GB optimized

Browse files
Files changed (1) hide show
  1. train_mattext_embeddings.py +730 -238
train_mattext_embeddings.py CHANGED
@@ -1,23 +1,26 @@
1
  """
2
- MatText Multi-Modal Embedding Alignment Training
3
 
4
- Architecture: CLIP-style contrastive learning across 8+ material text representations
5
- - Shared encoder (ModernBERT-base, 8192 ctx) with per-modality projection heads
6
- - All-pairs symmetric InfoNCE loss
7
- - Property-conditioned retrieval via property description encoding
8
- - FAISS vector database for cross-modal retrieval
 
 
 
 
 
9
 
10
  Based on:
11
- - MultiMat (AllPairsCLIP, arxiv:2312.00111)
12
  - MatExpert (property↔structure InfoNCE, arxiv:2410.21317)
13
- - CrystalCLR (composition similarity, arxiv:2211.13408)
 
14
 
15
  Usage:
16
- pip install torch transformers datasets faiss-cpu huggingface_hub trackio
17
  python train_mattext_embeddings.py
18
-
19
- # Or on HF Jobs:
20
- # Hardware: a10g-large (24GB VRAM), timeout: 6h
21
  """
22
 
23
  import os
@@ -26,6 +29,7 @@ import math
26
  import time
27
  import logging
28
  import random
 
29
  import numpy as np
30
  import torch
31
  import torch.nn as nn
@@ -39,7 +43,6 @@ import faiss
39
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
40
  logger = logging.getLogger(__name__)
41
 
42
-
43
  # ============================================================================
44
  # Configuration
45
  # ============================================================================
@@ -47,13 +50,13 @@ logger = logging.getLogger(__name__)
47
  class Config:
48
  # Model
49
  encoder_name = "answerdotai/ModernBERT-base"
50
- embed_dim = 128 # projection dimension (MultiMat recipe: 128-d)
51
- max_length = 512 # tokens per modality input (ModernBERT supports up to 8192)
52
 
53
  # Modalities to align (columns in the dataset)
54
  modalities = [
55
  "composition",
56
- "atom_sequences",
57
  "cif_symmetrized",
58
  "cif_p1",
59
  "zmatrix",
@@ -61,29 +64,38 @@ class Config:
61
  "slices",
62
  "crystal_text_llm",
63
  "local_env",
64
- "robocrys_rep", # natural language description (only in pretrain subsets)
65
  ]
66
 
 
 
 
 
67
  # Training
68
- batch_size = 32
69
  learning_rate = 2e-5
70
  weight_decay = 0.01
71
- num_epochs = 3
 
72
  warmup_ratio = 0.1
73
- temperature = 0.07 # InfoNCE temperature (MultiMat/CLIP standard)
74
- grad_accum_steps = 8 # effective batch = 32*8 = 256 (critical for InfoNCE)
75
  max_grad_norm = 1.0
76
  gradient_checkpointing = True
77
- max_modalities_per_step = 4 # randomly sample N modalities per step to save VRAM
78
 
79
  # Data
80
  dataset_name = "n0w0f/MatText"
81
  pretrain_config = "pretrain100k_v2"
82
  finetune_configs = [
83
- ("bandgap-train-filtered", "fold_0"),
84
- ("form_energy-train-filtered", "fold_0"),
85
  ]
86
- max_train_samples = 50000
 
 
 
 
87
 
88
  # Output
89
  output_dir = "mattext-embeddings"
@@ -92,7 +104,171 @@ class Config:
92
 
93
  # Device
94
  device = "cuda" if torch.cuda.is_available() else "cpu"
95
- fp16 = torch.cuda.is_available()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
 
98
  # ============================================================================
@@ -117,51 +293,59 @@ class ModalityProjection(nn.Module):
117
  class MatTextEncoder(nn.Module):
118
  """
119
  Shared transformer encoder with per-modality projection heads.
120
- All modalities share the same backbone but project to a shared
121
- embedding space through modality-specific heads.
122
  """
123
  def __init__(self, config: Config):
124
  super().__init__()
125
  self.config = config
126
 
127
- # Shared backbone
128
- self.backbone = AutoModel.from_pretrained(config.encoder_name)
 
 
 
 
 
129
  hidden_size = self.backbone.config.hidden_size
130
 
131
  if config.gradient_checkpointing:
132
  self.backbone.gradient_checkpointing_enable()
133
 
134
- # Per-modality projection heads
135
  self.projections = nn.ModuleDict({
136
  mod: ModalityProjection(hidden_size, config.embed_dim)
137
  for mod in config.modalities
138
  })
139
 
140
- # Property projection (for property-conditioned queries)
141
- self.property_projection = ModalityProjection(hidden_size, config.embed_dim)
 
 
 
142
 
143
- # Learnable temperature
144
  self.log_temperature = nn.Parameter(
145
  torch.tensor(math.log(1.0 / config.temperature))
146
  )
147
 
148
  def encode(self, input_ids, attention_mask, modality_name):
149
- """Encode a single modality"""
150
  outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
151
-
152
- # Mean pooling
153
  mask = attention_mask.unsqueeze(-1).float()
154
  hidden = outputs.last_hidden_state
155
  pooled = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
156
-
157
- # Project through modality-specific head
158
- if modality_name == "property":
159
- return self.property_projection(pooled)
160
  return self.projections[modality_name](pooled)
161
 
162
  @property
163
  def temperature(self):
164
  return torch.exp(self.log_temperature).clamp(min=0.01, max=100.0)
 
 
 
 
 
 
 
 
 
 
165
 
166
 
167
  # ============================================================================
@@ -169,8 +353,9 @@ class MatTextEncoder(nn.Module):
169
  # ============================================================================
170
 
171
  def symmetric_clip_loss(emb_a, emb_b, temperature):
172
- """Symmetric InfoNCE (CLIP loss)"""
173
  N = emb_a.size(0)
 
 
174
  logits = (emb_a @ emb_b.T) * temperature
175
  labels = torch.arange(N, device=emb_a.device)
176
  loss_a = F.cross_entropy(logits, labels)
@@ -179,13 +364,11 @@ def symmetric_clip_loss(emb_a, emb_b, temperature):
179
 
180
 
181
  def all_pairs_clip_loss(embeddings_dict, temperature):
182
- """AllPairsCLIP: sum symmetric InfoNCE over all modality pairs."""
183
  mods = [k for k, v in embeddings_dict.items() if v is not None]
184
  if len(mods) < 2:
185
- return torch.tensor(0.0, requires_grad=True)
186
 
187
- device = embeddings_dict[mods[0]].device
188
- total_loss = torch.tensor(0.0, device=device)
189
  n_pairs = 0
190
 
191
  for i in range(len(mods)):
@@ -195,14 +378,13 @@ def all_pairs_clip_loss(embeddings_dict, temperature):
195
  )
196
  n_pairs += 1
197
 
198
- return total_loss / n_pairs
199
 
200
 
201
  def property_similarity_loss(embeddings, labels, temperature):
202
- """Property-aware soft contrastive loss (SupReMix-inspired)."""
203
  N = embeddings.size(0)
204
  if N < 2:
205
- return torch.tensor(0.0, requires_grad=True)
206
 
207
  label_diff = torch.abs(labels.unsqueeze(0) - labels.unsqueeze(1))
208
  max_diff = label_diff.max().clamp(min=1e-6)
@@ -220,20 +402,46 @@ def property_similarity_loss(embeddings, labels, temperature):
220
  # Dataset
221
  # ============================================================================
222
 
223
- class MatTextMultiModalDataset(Dataset):
224
- def __init__(self, data, modalities, property_col=None, property_name=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  self.data = data
226
  self.modalities = modalities
227
  self.property_col = property_col
228
  self.property_name = property_name
 
 
229
 
230
  available_cols = set(data.column_names) if hasattr(data, 'column_names') else set(data[0].keys())
231
  self.available_modalities = [m for m in modalities if m in available_cols]
232
- logger.info(f"Available modalities: {self.available_modalities}")
233
 
234
- self.has_properties = property_col is not None and property_col in available_cols
235
- if self.has_properties:
236
- logger.info(f"Property column '{property_col}' found")
237
 
238
  def __len__(self):
239
  return len(self.data)
@@ -241,6 +449,7 @@ class MatTextMultiModalDataset(Dataset):
241
  def __getitem__(self, idx):
242
  row = self.data[idx]
243
  item = {}
 
244
  for mod in self.available_modalities:
245
  text = row.get(mod, None)
246
  if text and isinstance(text, str) and len(text.strip()) > 0:
@@ -248,23 +457,35 @@ class MatTextMultiModalDataset(Dataset):
248
  else:
249
  item[mod] = None
250
 
 
 
 
251
  if self.has_properties and row.get(self.property_col) is not None:
252
  label_val = float(row[self.property_col])
253
- comp = row.get("composition", "unknown")
254
- item["property_text"] = f"composition: {comp} | {self.property_name}: {label_val:.4f}"
255
  item["property_label"] = label_val
 
 
 
 
 
 
 
 
 
 
 
256
  else:
257
- item["property_text"] = None
258
  item["property_label"] = None
 
 
259
 
260
  return item
261
 
262
 
263
- def collate_fn(batch, tokenizer, modalities, max_length):
264
  result = {}
265
- all_mod_keys = list(modalities) + ["property_text"]
266
 
267
- for mod in all_mod_keys:
268
  texts = [item.get(mod) for item in batch]
269
  valid_texts = [t for t in texts if t is not None]
270
  if len(valid_texts) == 0:
@@ -274,7 +495,10 @@ def collate_fn(batch, tokenizer, modalities, max_length):
274
  texts_clean = [t if t is not None else "" for t in texts]
275
  mask_valid = [t is not None for t in texts]
276
 
277
- encoded = tokenizer(texts_clean, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
 
 
 
278
  result[mod] = {
279
  "input_ids": encoded["input_ids"],
280
  "attention_mask": encoded["attention_mask"],
@@ -298,15 +522,23 @@ def collate_fn(batch, tokenizer, modalities, max_length):
298
  # Training Loop
299
  # ============================================================================
300
 
301
- def train_epoch(model, dataloader, optimizer, scheduler, config, epoch, scaler=None):
 
302
  model.train()
303
- total_loss = 0; total_clip_loss = 0; total_prop_loss = 0
 
 
 
304
  log_interval = 20
305
 
 
 
 
306
  optimizer.zero_grad()
307
 
308
  for batch_idx, batch in enumerate(dataloader):
309
- # Randomly sample modalities to save VRAM
 
310
  available_mods = [m for m in config.modalities if batch.get(m) is not None]
311
  if len(available_mods) > config.max_modalities_per_step:
312
  must_have = [m for m in ["composition", "crystal_text_llm"] if m in available_mods]
@@ -316,83 +548,144 @@ def train_epoch(model, dataloader, optimizer, scheduler, config, epoch, scaler=N
316
  else:
317
  sampled = available_mods
318
 
 
 
 
 
319
  embeddings = {}
320
- for mod in sampled:
321
- if batch.get(mod) is None:
322
- embeddings[mod] = None; continue
323
-
324
- input_ids = batch[mod]["input_ids"].to(config.device)
325
- attention_mask = batch[mod]["attention_mask"].to(config.device)
326
- valid_mask = batch[mod]["valid_mask"]
327
-
328
- if not valid_mask.any():
329
- embeddings[mod] = None; continue
330
-
331
- with torch.amp.autocast('cuda', enabled=config.fp16):
 
 
332
  emb = model.encode(input_ids, attention_mask, mod)
333
- emb = emb * valid_mask.to(config.device).unsqueeze(-1).float()
334
- embeddings[mod] = emb
335
 
336
- with torch.amp.autocast('cuda', enabled=config.fp16):
337
  temperature = model.temperature
338
  clip_l = all_pairs_clip_loss(embeddings, temperature)
339
 
340
  prop_l = torch.tensor(0.0, device=config.device)
341
- if batch.get("property_text") is not None and batch.get("property_labels") is not None:
342
- prop_ids = batch["property_text"]["input_ids"].to(config.device)
343
- prop_mask = batch["property_text"]["attention_mask"].to(config.device)
344
- prop_valid = batch["property_text"]["valid_mask"]
345
-
346
- if prop_valid.any():
347
- with torch.amp.autocast('cuda', enabled=config.fp16):
348
- prop_emb = model.encode(prop_ids, prop_mask, "property")
349
-
350
- labels = batch["property_labels"].to(config.device)
351
- labels_mask = batch["property_labels_mask"].to(config.device)
352
 
353
- if labels_mask.sum() > 1:
354
- prop_l = property_similarity_loss(prop_emb[labels_mask], labels[labels_mask], temperature)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
 
356
- for anchor_mod in ["robocrys_rep", "crystal_text_llm", "composition"]:
357
- if embeddings.get(anchor_mod) is not None:
358
- with torch.amp.autocast('cuda', enabled=config.fp16):
359
- prop_clip = symmetric_clip_loss(
360
- prop_emb[labels_mask], embeddings[anchor_mod][labels_mask], temperature
361
- )
362
- prop_l = prop_l + 0.5 * prop_clip
363
- break
364
-
365
- loss = (clip_l + 0.3 * prop_l) / config.grad_accum_steps
366
-
367
- if config.fp16 and scaler is not None:
 
 
 
 
 
 
 
 
 
 
 
 
368
  scaler.scale(loss).backward()
369
  else:
370
  loss.backward()
371
 
372
  if (batch_idx + 1) % config.grad_accum_steps == 0:
373
- if config.fp16 and scaler is not None:
374
  scaler.unscale_(optimizer)
375
  torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
376
- scaler.step(optimizer); scaler.update()
 
377
  else:
378
  torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
379
  optimizer.step()
380
- scheduler.step(); optimizer.zero_grad()
 
 
381
 
382
  total_loss += loss.item() * config.grad_accum_steps
383
  total_clip_loss += clip_l.item()
384
  total_prop_loss += prop_l.item() if isinstance(prop_l, torch.Tensor) else prop_l
 
385
 
386
  if (batch_idx + 1) % log_interval == 0:
387
  avg = total_loss / (batch_idx + 1)
 
 
 
 
 
 
388
  logger.info(
389
- f"Epoch {epoch} | {batch_idx+1}/{len(dataloader)} | "
390
- f"Loss: {avg:.4f} | CLIP: {total_clip_loss/(batch_idx+1):.4f} | "
391
- f"Prop: {total_prop_loss/(batch_idx+1):.4f} | "
392
- f"LR: {scheduler.get_last_lr()[0]:.2e} | T: {model.temperature.item():.3f}"
393
  )
 
 
 
 
 
 
 
 
 
 
 
394
 
395
- return total_loss / max(len(dataloader), 1)
396
 
397
 
398
  # ============================================================================
@@ -400,37 +693,56 @@ def train_epoch(model, dataloader, optimizer, scheduler, config, epoch, scaler=N
400
  # ============================================================================
401
 
402
  @torch.no_grad()
403
- def evaluate_retrieval(model, dataloader, config, k_values=[1, 5, 10]):
404
  model.eval()
405
  all_embeddings = {mod: [] for mod in config.modalities}
406
 
 
 
 
407
  for batch in dataloader:
408
  for mod in config.modalities:
409
- if batch.get(mod) is None: continue
 
410
  input_ids = batch[mod]["input_ids"].to(config.device)
411
  attention_mask = batch[mod]["attention_mask"].to(config.device)
412
  valid_mask = batch[mod]["valid_mask"]
413
- if not valid_mask.any(): continue
 
 
 
 
414
 
415
- emb = model.encode(input_ids, attention_mask, mod).cpu()
416
  for i in range(len(emb)):
417
  all_embeddings[mod].append(emb[i] if valid_mask[i] else None)
418
 
419
  results = {}
420
  eval_pairs = [
421
- ("composition", "crystal_text_llm"), ("composition", "cif_symmetrized"),
422
- ("slices", "crystal_text_llm"), ("composition", "slices"),
 
 
 
 
 
423
  ]
424
  if len([e for e in all_embeddings.get("robocrys_rep", []) if e is not None]) > 0:
425
- eval_pairs.extend([("robocrys_rep", "composition"), ("robocrys_rep", "cif_symmetrized")])
 
 
 
 
426
 
427
  for mod_a, mod_b in eval_pairs:
428
- embs_a, embs_b = all_embeddings.get(mod_a, []), all_embeddings.get(mod_b, [])
429
- if not embs_a or not embs_b: continue
 
 
430
 
431
- valid_idx = [i for i in range(min(len(embs_a), len(embs_b)))
432
  if embs_a[i] is not None and embs_b[i] is not None]
433
- if len(valid_idx) < 10: continue
 
434
 
435
  ea = torch.stack([embs_a[i] for i in valid_idx])
436
  eb = torch.stack([embs_b[i] for i in valid_idx])
@@ -439,6 +751,8 @@ def evaluate_retrieval(model, dataloader, config, k_values=[1, 5, 10]):
439
  recalls = {}
440
  for k in k_values:
441
  kk = min(k, len(valid_idx) - 1)
 
 
442
  topk = sim.topk(kk, dim=1).indices
443
  correct = (topk == torch.arange(len(valid_idx)).unsqueeze(1)).any(dim=1)
444
  recalls[f"R@{k}"] = correct.float().mean().item()
@@ -449,15 +763,56 @@ def evaluate_retrieval(model, dataloader, config, k_values=[1, 5, 10]):
449
  return results
450
 
451
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
452
  # ============================================================================
453
  # FAISS Vector Database
454
  # ============================================================================
455
 
456
  def build_vector_database(model, dataset, tokenizer, config, modalities_to_index=None):
457
  if modalities_to_index is None:
458
- modalities_to_index = config.modalities
 
459
  model.eval()
460
 
 
 
 
461
  all_embeddings = {mod: [] for mod in modalities_to_index}
462
  all_metadata = []
463
  bs = 64
@@ -465,30 +820,43 @@ def build_vector_database(model, dataset, tokenizer, config, modalities_to_index
465
  for start in range(0, len(dataset), bs):
466
  end = min(start + bs, len(dataset))
467
  items = [dataset[i] for i in range(start, end)]
468
- batch = collate_fn(items, tokenizer, config.modalities, config.max_length)
469
 
470
  for item in items:
471
- all_metadata.append({"composition": item.get("composition", ""), "property_label": item.get("property_label")})
 
 
 
 
 
 
 
472
 
473
  with torch.no_grad():
474
  for mod in modalities_to_index:
475
  if batch.get(mod) is None:
476
- all_embeddings[mod].extend([None] * len(items)); continue
477
- emb = model.encode(
478
- batch[mod]["input_ids"].to(config.device),
479
- batch[mod]["attention_mask"].to(config.device), mod
480
- ).cpu().numpy()
 
 
 
481
  for i in range(len(emb)):
482
- all_embeddings[mod].append(emb[i] if batch[mod]["valid_mask"][i] else None)
 
 
 
483
 
484
- if (start // bs) % 10 == 0:
485
  logger.info(f"Indexed {end}/{len(dataset)}")
486
 
487
  indices = {}
488
  for mod in modalities_to_index:
489
  valid_embs = [e for e in all_embeddings[mod] if e is not None]
490
  valid_map = [i for i, e in enumerate(all_embeddings[mod]) if e is not None]
491
- if not valid_embs: continue
 
492
 
493
  emb_matrix = np.stack(valid_embs).astype(np.float32)
494
  faiss.normalize_L2(emb_matrix)
@@ -496,33 +864,56 @@ def build_vector_database(model, dataset, tokenizer, config, modalities_to_index
496
 
497
  if len(valid_embs) > 10000:
498
  nlist = min(100, int(np.sqrt(len(valid_embs))))
499
- q = faiss.IndexFlatIP(d)
500
- index = faiss.IndexIVFFlat(q, d, nlist, faiss.METRIC_INNER_PRODUCT)
501
  index.train(emb_matrix)
 
502
  else:
503
  index = faiss.IndexFlatIP(d)
504
- index.add(emb_matrix)
505
 
506
- indices[mod] = {"index": index, "valid_indices_map": valid_map,
507
- "metadata": [all_metadata[i] for i in valid_map]}
 
 
 
 
508
  logger.info(f"FAISS {mod}: {len(valid_embs)} vectors, dim={d}")
509
 
510
  return indices
511
 
512
 
513
  def search_vector_db(query_text, query_modality, model, tokenizer, indices, config, k=10):
 
 
 
 
 
 
514
  model.eval()
515
- enc = tokenizer([query_text], padding=True, truncation=True, max_length=config.max_length, return_tensors="pt")
 
 
 
 
 
 
 
 
516
  with torch.no_grad():
517
- q = model.encode(enc["input_ids"].to(config.device), enc["attention_mask"].to(config.device), query_modality)
518
- q = q.cpu().numpy().astype(np.float32)
519
- faiss.normalize_L2(q)
 
 
 
 
 
520
 
521
  results = []
522
  for mod_name, idx_data in indices.items():
523
- scores, ids = idx_data["index"].search(q, k)
524
  for s, i in zip(scores[0], ids[0]):
525
- if i >= 0:
526
  m = dict(idx_data["metadata"][i])
527
  m["matched_modality"] = mod_name
528
  results.append((float(s), m))
@@ -532,8 +923,10 @@ def search_vector_db(query_text, query_modality, model, tokenizer, indices, conf
532
  for s, m in results:
533
  c = m.get("composition", "")
534
  if c not in seen:
535
- seen.add(c); unique.append((s, m))
536
- if len(unique) >= k: break
 
 
537
  return unique
538
 
539
 
@@ -543,146 +936,245 @@ def search_vector_db(query_text, query_modality, model, tokenizer, indices, conf
543
 
544
  def main():
545
  config = Config()
546
- logger.info(f"Device: {config.device} | Encoder: {config.encoder_name}")
547
- logger.info(f"Batch: {config.batch_size}x{config.grad_accum_steps}={config.batch_size*config.grad_accum_steps}")
548
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
549
  try:
550
  import trackio
551
- trackio.init(project="mattext-embeddings", name=f"align-{config.encoder_name.split('/')[-1]}")
552
  use_trackio = True
553
- except:
554
- use_trackio = False
 
555
 
556
  tokenizer = AutoTokenizer.from_pretrained(config.encoder_name)
557
  model = MatTextEncoder(config).to(config.device)
558
- logger.info(f"Params: {sum(p.numel() for p in model.parameters()):,}")
 
 
559
 
560
- # Load data
561
- pretrain_data = load_dataset(config.dataset_name, config.pretrain_config, split="train")
562
- logger.info(f"Pretrain: {len(pretrain_data)} samples, cols: {pretrain_data.column_names}")
563
 
564
- finetune_data = None
565
- for ft_cfg, ft_split in config.finetune_configs:
566
- try:
567
- ft = load_dataset(config.dataset_name, ft_cfg, split=ft_split)
568
- logger.info(f"Loaded {ft_cfg}/{ft_split}: {len(ft)} samples")
569
- finetune_data = ft if finetune_data is None else concatenate_datasets([
570
- finetune_data.select_columns(list(set(finetune_data.column_names) & set(ft.column_names))),
571
- ft.select_columns(list(set(finetune_data.column_names) & set(ft.column_names)))
572
- ])
573
- except Exception as e:
574
- logger.warning(f"Failed {ft_cfg}: {e}")
575
 
576
- if len(pretrain_data) > config.max_train_samples:
577
- pretrain_data = pretrain_data.shuffle(seed=42).select(range(config.max_train_samples))
 
578
 
579
- make_collate = lambda tok, mods, ml: lambda batch: collate_fn(batch, tok, mods, ml)
 
580
 
581
- pretrain_loader = DataLoader(
582
- MatTextMultiModalDataset(pretrain_data, config.modalities),
583
- batch_size=config.batch_size, shuffle=True, drop_last=True, num_workers=0,
584
- collate_fn=make_collate(tokenizer, config.modalities, config.max_length),
585
- pin_memory=config.device == "cuda",
586
  )
587
 
588
- finetune_loader = None
589
- if finetune_data:
590
- if len(finetune_data) > config.max_train_samples:
591
- finetune_data = finetune_data.shuffle(seed=42).select(range(config.max_train_samples))
592
- finetune_loader = DataLoader(
593
- MatTextMultiModalDataset(finetune_data, config.modalities, "labels", "property_value"),
594
- batch_size=config.batch_size, shuffle=True, drop_last=True, num_workers=0,
595
- collate_fn=make_collate(tokenizer, config.modalities, config.max_length),
596
- pin_memory=config.device == "cuda",
597
- )
598
-
599
  optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
600
- total_steps = len(pretrain_loader) * config.num_epochs // config.grad_accum_steps
601
- if finetune_loader:
602
- total_steps += len(finetune_loader) * config.num_epochs // config.grad_accum_steps
603
- scheduler = get_cosine_schedule_with_warmup(optimizer, int(total_steps * config.warmup_ratio), total_steps)
604
- scaler = torch.amp.GradScaler('cuda') if config.fp16 else None
605
 
606
- logger.info(f"Steps: {total_steps}")
607
-
608
- # Phase 1: Multi-modal alignment
609
- logger.info("=" * 60 + "\nPhase 1: Multi-modal alignment\n" + "=" * 60)
610
  best_loss = float('inf')
611
- for epoch in range(1, config.num_epochs + 1):
 
 
612
  t0 = time.time()
613
- loss = train_epoch(model, pretrain_loader, optimizer, scheduler, config, epoch, scaler)
614
- logger.info(f"Epoch {epoch} | Loss: {loss:.4f} | Time: {time.time()-t0:.0f}s")
615
- if use_trackio:
616
- try: trackio.log({"phase": 1, "epoch": epoch, "loss": loss})
617
- except: pass
 
618
  if loss < best_loss:
619
  best_loss = loss
620
- os.makedirs(config.output_dir, exist_ok=True)
621
- torch.save(model.state_dict(), f"{config.output_dir}/best_model.pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
622
 
623
- # Phase 2: Property-conditioned alignment
624
- if finetune_loader:
625
- logger.info("=" * 60 + "\nPhase 2: Property-conditioned alignment\n" + "=" * 60)
626
- for epoch in range(1, config.num_epochs + 1):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
627
  t0 = time.time()
628
- loss = train_epoch(model, finetune_loader, optimizer, scheduler, config, epoch, scaler)
629
- logger.info(f"P2 Epoch {epoch} | Loss: {loss:.4f} | Time: {time.time()-t0:.0f}s")
 
 
 
 
630
  if loss < best_loss:
631
  best_loss = loss
632
  torch.save(model.state_dict(), f"{config.output_dir}/best_model.pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
633
 
634
- # Evaluate
635
- logger.info("=" * 60 + "\nEvaluation\n" + "=" * 60)
636
  eval_data = load_dataset(config.dataset_name, config.pretrain_config, split="test")
637
  if len(eval_data) > 5000:
638
  eval_data = eval_data.shuffle(seed=42).select(range(5000))
 
639
 
 
640
  eval_loader = DataLoader(
641
- MatTextMultiModalDataset(eval_data, config.modalities),
642
- batch_size=config.batch_size, shuffle=False, num_workers=0,
643
- collate_fn=make_collate(tokenizer, config.modalities, config.max_length),
644
  )
645
- results = evaluate_retrieval(model, eval_loader, config)
646
 
647
- # Build FAISS DB
648
- logger.info("Building FAISS indices...")
649
- db = build_vector_database(
650
- model, MatTextMultiModalDataset(eval_data, config.modalities),
651
- tokenizer, config, ["composition", "crystal_text_llm", "slices", "cif_symmetrized"]
 
652
  )
653
 
654
- os.makedirs(f"{config.output_dir}/faiss", exist_ok=True)
655
- for mod, d in db.items():
656
- faiss.write_index(d["index"], f"{config.output_dir}/faiss/{mod}.index")
657
- with open(f"{config.output_dir}/faiss/{mod}_metadata.json", "w") as f:
 
658
  json.dump(d["metadata"], f)
659
 
660
- # Demo
661
- for q, m in [("Fe2O3", "composition"), ("Si Ge", "composition")]:
662
- logger.info(f"\nQuery: '{q}' ({m})")
663
- for rank, (s, meta) in enumerate(search_vector_db(q, m, model, tokenizer, db, config, 5), 1):
664
- logger.info(f" #{rank}: {s:.4f} | {meta}")
665
 
666
- # Save & push
 
667
  torch.save(model.state_dict(), f"{config.output_dir}/model.pt")
668
  tokenizer.save_pretrained(config.output_dir)
 
 
 
 
 
 
 
 
 
 
 
669
  with open(f"{config.output_dir}/config.json", "w") as f:
670
- json.dump({k: str(v) if not isinstance(v, (int, float, str, bool, list, dict, type(None))) else v
671
- for k, v in vars(Config).items() if not k.startswith("_")}, f, indent=2)
672
  with open(f"{config.output_dir}/retrieval_results.json", "w") as f:
673
- json.dump(results, f, indent=2)
 
 
 
 
 
 
 
 
 
674
 
675
  if config.push_to_hub:
676
  try:
677
  api = HfApi()
678
  api.create_repo(config.hub_model_id, exist_ok=True)
679
- api.upload_folder(folder_path=config.output_dir, repo_id=config.hub_model_id,
680
- commit_message="Upload MatText aligned embeddings + FAISS indices")
681
- logger.info(f"Pushed to https://huggingface.co/{config.hub_model_id}")
 
 
 
682
  except Exception as e:
683
  logger.error(f"Push failed: {e}")
684
 
685
- logger.info("DONE!")
 
 
 
 
 
686
 
687
 
688
  if __name__ == "__main__":
 
1
  """
2
+ MatText Multi-Modal Embedding Alignment Training (v2)
3
 
4
+ Architecture: CLIP-style contrastive learning across 10+ material text representations
5
+ + LaCLIP-style natural language property descriptions for free-form querying
6
+
7
+ Key upgrades from v1:
8
+ - 1024 token context (was 512) — captures long CIFs
9
+ - Natural language property query support ("oxide with high bandgap")
10
+ - LaCLIP-style diverse NL description generation from structured labels
11
+ - A100 80GB optimized (bf16, larger batches, more modalities/step)
12
+ - Flash Attention 2 when available
13
+ - Phase 2 aligns NL descriptions ↔ all structure modalities
14
 
15
  Based on:
16
+ - MultiMat (AllPairsCLIP, arxiv:2312.00111)
17
  - MatExpert (property↔structure InfoNCE, arxiv:2410.21317)
18
+ - LaCLIP (LLM text augmentation, arxiv:2305.20088)
19
+ - SupReMix (property-label-aware soft contrastive, arxiv:2309.16633)
20
 
21
  Usage:
22
+ pip install torch transformers datasets faiss-cpu huggingface_hub trackio accelerate
23
  python train_mattext_embeddings.py
 
 
 
24
  """
25
 
26
  import os
 
29
  import time
30
  import logging
31
  import random
32
+ import re
33
  import numpy as np
34
  import torch
35
  import torch.nn as nn
 
43
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
44
  logger = logging.getLogger(__name__)
45
 
 
46
  # ============================================================================
47
  # Configuration
48
  # ============================================================================
 
50
  class Config:
51
  # Model
52
  encoder_name = "answerdotai/ModernBERT-base"
53
+ embed_dim = 128 # projection dimension
54
+ max_length = 1024 # tokens per modality (ModernBERT pretrained at 1024, extended to 8192)
55
 
56
  # Modalities to align (columns in the dataset)
57
  modalities = [
58
  "composition",
59
+ "atom_sequences",
60
  "cif_symmetrized",
61
  "cif_p1",
62
  "zmatrix",
 
64
  "slices",
65
  "crystal_text_llm",
66
  "local_env",
67
+ "robocrys_rep", # natural language structural description (pretrain only)
68
  ]
69
 
70
+ # Natural language query modality (separate from robocrys_rep)
71
+ # This is the key modality for queries like "oxide with high bandgap"
72
+ nl_query_modality = "nl_property_description"
73
+
74
  # Training
75
+ batch_size = 48 # A100 80GB can handle this at 1024 ctx with bf16
76
  learning_rate = 2e-5
77
  weight_decay = 0.01
78
+ num_epochs_phase1 = 3
79
+ num_epochs_phase2 = 3
80
  warmup_ratio = 0.1
81
+ temperature = 0.07
82
+ grad_accum_steps = 6 # effective batch = 48*6 = 288
83
  max_grad_norm = 1.0
84
  gradient_checkpointing = True
85
+ max_modalities_per_step = 5 # more than v1 since A100 80GB
86
 
87
  # Data
88
  dataset_name = "n0w0f/MatText"
89
  pretrain_config = "pretrain100k_v2"
90
  finetune_configs = [
91
+ ("bandgap-train-filtered", "fold_0", "bandgap"),
92
+ ("form_energy-train-filtered", "fold_0", "formation_energy"),
93
  ]
94
+ max_pretrain_samples = 60000
95
+ max_finetune_samples = 60000
96
+
97
+ # NL description generation
98
+ nl_descriptions_per_sample = 3 # LaCLIP: diverse paraphrases per sample
99
 
100
  # Output
101
  output_dir = "mattext-embeddings"
 
104
 
105
  # Device
106
  device = "cuda" if torch.cuda.is_available() else "cpu"
107
+ use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
108
+ use_fp16 = torch.cuda.is_available() and not use_bf16
109
+ use_flash_attn = False # set True if flash-attn is installed
110
+
111
+
112
+ # ============================================================================
113
+ # NL Property Description Generator (LaCLIP-style)
114
+ # ============================================================================
115
+
116
+ class NLPropertyDescriptionGenerator:
117
+ """
118
+ Generates diverse natural language descriptions from structured material properties.
119
+ This bridges the gap between structured labels (bandgap=3.2) and free-form queries
120
+ ("oxide with high bandgap"). LaCLIP-inspired: multiple paraphrases per sample.
121
+ """
122
+
123
+ BANDGAP_QUALIFIERS = {
124
+ (0, 0.01): "zero",
125
+ (0.01, 0.5): "very narrow",
126
+ (0.5, 1.5): "narrow",
127
+ (1.5, 3.0): "moderate",
128
+ (3.0, 5.0): "wide",
129
+ (5.0, 100): "very wide",
130
+ }
131
+
132
+ FENERGY_QUALIFIERS = {
133
+ (-100, -3.0): "very stable",
134
+ (-3.0, -1.5): "stable",
135
+ (-1.5, -0.5): "moderately stable",
136
+ (-0.5, 0.0): "marginally stable",
137
+ (0.0, 1.0): "metastable",
138
+ (1.0, 100): "unstable",
139
+ }
140
+
141
+ ANION_PATTERNS = [
142
+ (r'O\d*$|O\d+[A-Z]', "oxide"),
143
+ (r'S\d*$|S\d+[A-Z]', "sulfide"),
144
+ (r'N\d*$|N\d+[A-Z]', "nitride"),
145
+ (r'F\d*$|F\d+[A-Z]', "fluoride"),
146
+ (r'Cl\d*$|Cl\d+[A-Z]', "chloride"),
147
+ (r'Br\d*$|Br\d+[A-Z]', "bromide"),
148
+ (r'I\d*$|I\d+[A-Z]', "iodide"),
149
+ (r'Se\d*$|Se\d+[A-Z]', "selenide"),
150
+ (r'Te\d*$|Te\d+[A-Z]', "telluride"),
151
+ (r'C\d*$|C\d+[A-Z]', "carbide"),
152
+ (r'H\d*$|H\d+[A-Z]', "hydride"),
153
+ ]
154
+
155
+ ELEMENT_COUNT_NAMES = {
156
+ 1: "elemental", 2: "binary", 3: "ternary", 4: "quaternary", 5: "quinary",
157
+ }
158
+
159
+ @classmethod
160
+ def _qualify_bandgap(cls, bg):
161
+ for (lo, hi), qual in cls.BANDGAP_QUALIFIERS.items():
162
+ if lo <= bg < hi:
163
+ return qual
164
+ return "moderate"
165
+
166
+ @classmethod
167
+ def _qualify_fenergy(cls, fe):
168
+ for (lo, hi), qual in cls.FENERGY_QUALIFIERS.items():
169
+ if lo <= fe < hi:
170
+ return qual
171
+ return "moderately stable"
172
+
173
+ @classmethod
174
+ def _detect_anion(cls, composition):
175
+ for pattern, name in cls.ANION_PATTERNS:
176
+ if re.search(pattern, composition):
177
+ return name
178
+ return "compound"
179
+
180
+ @classmethod
181
+ def _count_elements(cls, composition):
182
+ elements = re.findall(r'[A-Z][a-z]?', composition)
183
+ return len(set(elements))
184
+
185
+ @classmethod
186
+ def _get_elements(cls, composition):
187
+ return list(set(re.findall(r'[A-Z][a-z]?', composition)))
188
+
189
+ @classmethod
190
+ def generate_descriptions(cls, composition, property_name=None, property_value=None,
191
+ crystal_system=None, n=3):
192
+ """Generate n diverse NL descriptions for a material."""
193
+ anion_type = cls._detect_anion(composition)
194
+ n_elements = cls._count_elements(composition)
195
+ complexity = cls.ELEMENT_COUNT_NAMES.get(n_elements, "complex")
196
+
197
+ property_templates = []
198
+ if property_name == "bandgap" and property_value is not None:
199
+ qual = cls._qualify_bandgap(property_value)
200
+ property_templates.extend([
201
+ f"A {anion_type} material with {qual} bandgap of {property_value:.2f} eV.",
202
+ f"{composition} is a {complexity} {anion_type} with a {qual} electronic band gap ({property_value:.2f} eV).",
203
+ f"This {anion_type} has a bandgap of {property_value:.2f} eV, classified as {qual}.",
204
+ f"A {qual} bandgap {anion_type} ({property_value:.1f} eV) with composition {composition}.",
205
+ f"{composition}: {anion_type} semiconductor with {qual} band gap of {property_value:.2f} electron volts.",
206
+ f"An {anion_type} with {qual} bandgap around {property_value:.1f} eV, formula {composition}.",
207
+ f"This {complexity} {anion_type} ({composition}) exhibits a {qual} bandgap of approximately {property_value:.2f} eV.",
208
+ f"Material {composition} is a {qual}-gap {anion_type} with bandgap {property_value:.2f} eV.",
209
+ ])
210
+ if property_value > 3.0:
211
+ property_templates.append(
212
+ f"{composition} is a wide-gap {anion_type} suitable for UV applications, bandgap {property_value:.2f} eV."
213
+ )
214
+ if property_value < 1.0 and property_value > 0.01:
215
+ property_templates.append(
216
+ f"{composition} is a narrow-gap {anion_type}, potentially useful for infrared applications, bandgap {property_value:.2f} eV."
217
+ )
218
+ if property_value < 0.01:
219
+ property_templates.append(
220
+ f"{composition} is metallic or near-zero gap {anion_type} with bandgap {property_value:.3f} eV."
221
+ )
222
+
223
+ elif property_name == "formation_energy" and property_value is not None:
224
+ qual = cls._qualify_fenergy(property_value)
225
+ property_templates.extend([
226
+ f"A {qual} {anion_type} with formation energy of {property_value:.3f} eV/atom.",
227
+ f"{composition} is a {complexity} {anion_type} that is {qual} with formation energy {property_value:.3f} eV/atom.",
228
+ f"This {anion_type} ({composition}) has a formation energy of {property_value:.3f} eV/atom, making it {qual}.",
229
+ f"A {qual} {complexity} {anion_type}: {composition}, formation energy = {property_value:.3f} eV/atom.",
230
+ f"{composition}: thermodynamically {qual} {anion_type} (formation energy {property_value:.3f} eV/atom).",
231
+ f"This material ({composition}) is a {qual} {anion_type} compound with Ef = {property_value:.3f} eV/atom.",
232
+ f"A {anion_type} with composition {composition} showing {qual} thermodynamic stability ({property_value:.3f} eV/atom).",
233
+ ])
234
+
235
+ composition_templates = [
236
+ f"A {complexity} {anion_type} with formula {composition}.",
237
+ f"{composition} is a {complexity} {anion_type} compound.",
238
+ f"This material has composition {composition}, a {complexity} {anion_type}.",
239
+ f"A {anion_type} material: {composition} ({n_elements} elements).",
240
+ ]
241
+ if crystal_system:
242
+ composition_templates.extend([
243
+ f"{composition} is a {crystal_system} {anion_type}.",
244
+ f"A {crystal_system} structured {complexity} {anion_type}: {composition}.",
245
+ ])
246
+
247
+ combined_templates = []
248
+ if property_name and property_value is not None:
249
+ if property_name == "bandgap":
250
+ qual = cls._qualify_bandgap(property_value)
251
+ combined_templates.extend([
252
+ f"{composition} is a {complexity} {anion_type} with {qual} bandgap of {property_value:.2f} eV.",
253
+ f"A {qual} bandgap {complexity} {anion_type} material, {composition}, with band gap {property_value:.1f} eV.",
254
+ ])
255
+ elif property_name == "formation_energy":
256
+ qual = cls._qualify_fenergy(property_value)
257
+ combined_templates.extend([
258
+ f"{composition} is a {qual} {complexity} {anion_type} with formation energy {property_value:.3f} eV/atom.",
259
+ f"A {qual} {anion_type}, {composition}, with Ef = {property_value:.3f} eV/atom.",
260
+ ])
261
+
262
+ all_templates = property_templates + composition_templates + combined_templates
263
+ if not all_templates:
264
+ all_templates = composition_templates
265
+
266
+ if len(all_templates) >= n:
267
+ descriptions = random.sample(all_templates, n)
268
+ else:
269
+ descriptions = all_templates + random.choices(all_templates, k=n - len(all_templates))
270
+
271
+ return descriptions
272
 
273
 
274
  # ============================================================================
 
293
  class MatTextEncoder(nn.Module):
294
  """
295
  Shared transformer encoder with per-modality projection heads.
296
+ Includes an NL query projection head for free-form text queries.
 
297
  """
298
  def __init__(self, config: Config):
299
  super().__init__()
300
  self.config = config
301
 
302
+ model_kwargs = {}
303
+ if config.use_flash_attn:
304
+ model_kwargs["attn_implementation"] = "flash_attention_2"
305
+ if config.use_bf16:
306
+ model_kwargs["torch_dtype"] = torch.bfloat16
307
+
308
+ self.backbone = AutoModel.from_pretrained(config.encoder_name, **model_kwargs)
309
  hidden_size = self.backbone.config.hidden_size
310
 
311
  if config.gradient_checkpointing:
312
  self.backbone.gradient_checkpointing_enable()
313
 
 
314
  self.projections = nn.ModuleDict({
315
  mod: ModalityProjection(hidden_size, config.embed_dim)
316
  for mod in config.modalities
317
  })
318
 
319
+ # NL query head — for "oxide with high bandgap" style queries
320
+ self.projections[config.nl_query_modality] = ModalityProjection(hidden_size, config.embed_dim)
321
+
322
+ # Property head — for structured property text like "bandgap: 2.1"
323
+ self.projections["property"] = ModalityProjection(hidden_size, config.embed_dim)
324
 
 
325
  self.log_temperature = nn.Parameter(
326
  torch.tensor(math.log(1.0 / config.temperature))
327
  )
328
 
329
  def encode(self, input_ids, attention_mask, modality_name):
 
330
  outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
 
 
331
  mask = attention_mask.unsqueeze(-1).float()
332
  hidden = outputs.last_hidden_state
333
  pooled = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
 
 
 
 
334
  return self.projections[modality_name](pooled)
335
 
336
  @property
337
  def temperature(self):
338
  return torch.exp(self.log_temperature).clamp(min=0.01, max=100.0)
339
+
340
+ def get_config_dict(self):
341
+ return {
342
+ "encoder_name": self.config.encoder_name,
343
+ "embed_dim": self.config.embed_dim,
344
+ "max_length": self.config.max_length,
345
+ "modalities": self.config.modalities,
346
+ "nl_query_modality": self.config.nl_query_modality,
347
+ "temperature": self.temperature.item(),
348
+ }
349
 
350
 
351
  # ============================================================================
 
353
  # ============================================================================
354
 
355
  def symmetric_clip_loss(emb_a, emb_b, temperature):
 
356
  N = emb_a.size(0)
357
+ if N < 2:
358
+ return torch.tensor(0.0, device=emb_a.device, requires_grad=True)
359
  logits = (emb_a @ emb_b.T) * temperature
360
  labels = torch.arange(N, device=emb_a.device)
361
  loss_a = F.cross_entropy(logits, labels)
 
364
 
365
 
366
  def all_pairs_clip_loss(embeddings_dict, temperature):
 
367
  mods = [k for k, v in embeddings_dict.items() if v is not None]
368
  if len(mods) < 2:
369
+ return torch.tensor(0.0, device=temperature.device, requires_grad=True)
370
 
371
+ total_loss = torch.tensor(0.0, device=temperature.device)
 
372
  n_pairs = 0
373
 
374
  for i in range(len(mods)):
 
378
  )
379
  n_pairs += 1
380
 
381
+ return total_loss / max(n_pairs, 1)
382
 
383
 
384
  def property_similarity_loss(embeddings, labels, temperature):
 
385
  N = embeddings.size(0)
386
  if N < 2:
387
+ return torch.tensor(0.0, device=embeddings.device, requires_grad=True)
388
 
389
  label_diff = torch.abs(labels.unsqueeze(0) - labels.unsqueeze(1))
390
  max_diff = label_diff.max().clamp(min=1e-6)
 
402
  # Dataset
403
  # ============================================================================
404
 
405
+ class MatTextPhase1Dataset(Dataset):
406
+ """Phase 1: Multi-modal alignment on pretrain data (no labels)."""
407
+ def __init__(self, data, modalities):
408
+ self.data = data
409
+ self.modalities = modalities
410
+ available_cols = set(data.column_names) if hasattr(data, 'column_names') else set(data[0].keys())
411
+ self.available_modalities = [m for m in modalities if m in available_cols]
412
+ logger.info(f"Phase1 modalities: {self.available_modalities}")
413
+
414
+ def __len__(self):
415
+ return len(self.data)
416
+
417
+ def __getitem__(self, idx):
418
+ row = self.data[idx]
419
+ item = {}
420
+ for mod in self.available_modalities:
421
+ text = row.get(mod, None)
422
+ if text and isinstance(text, str) and len(text.strip()) > 0:
423
+ item[mod] = text.strip()
424
+ else:
425
+ item[mod] = None
426
+ return item
427
+
428
+
429
+ class MatTextPhase2Dataset(Dataset):
430
+ """Phase 2: Property-conditioned alignment with LaCLIP-style NL descriptions."""
431
+ def __init__(self, data, modalities, property_col, property_name, nl_descriptions_per_sample=3):
432
  self.data = data
433
  self.modalities = modalities
434
  self.property_col = property_col
435
  self.property_name = property_name
436
+ self.nl_descriptions_per_sample = nl_descriptions_per_sample
437
+ self.nl_gen = NLPropertyDescriptionGenerator()
438
 
439
  available_cols = set(data.column_names) if hasattr(data, 'column_names') else set(data[0].keys())
440
  self.available_modalities = [m for m in modalities if m in available_cols]
441
+ self.has_properties = property_col in available_cols
442
 
443
+ logger.info(f"Phase2 modalities: {self.available_modalities}")
444
+ logger.info(f"Property: {property_name} (col={property_col}, has={self.has_properties})")
 
445
 
446
  def __len__(self):
447
  return len(self.data)
 
449
  def __getitem__(self, idx):
450
  row = self.data[idx]
451
  item = {}
452
+
453
  for mod in self.available_modalities:
454
  text = row.get(mod, None)
455
  if text and isinstance(text, str) and len(text.strip()) > 0:
 
457
  else:
458
  item[mod] = None
459
 
460
+ composition = row.get("composition", "unknown")
461
+ crystal_system = row.get("crystal_system", None)
462
+
463
  if self.has_properties and row.get(self.property_col) is not None:
464
  label_val = float(row[self.property_col])
 
 
465
  item["property_label"] = label_val
466
+ item["property_text"] = f"composition: {composition} | {self.property_name}: {label_val:.4f}"
467
+
468
+ # LaCLIP-style diverse NL descriptions — randomly sample one per call
469
+ nl_descs = self.nl_gen.generate_descriptions(
470
+ composition=composition,
471
+ property_name=self.property_name,
472
+ property_value=label_val,
473
+ crystal_system=crystal_system,
474
+ n=self.nl_descriptions_per_sample,
475
+ )
476
+ item["nl_property_description"] = random.choice(nl_descs)
477
  else:
 
478
  item["property_label"] = None
479
+ item["property_text"] = None
480
+ item["nl_property_description"] = None
481
 
482
  return item
483
 
484
 
485
+ def collate_fn(batch, tokenizer, all_modality_keys, max_length):
486
  result = {}
 
487
 
488
+ for mod in all_modality_keys:
489
  texts = [item.get(mod) for item in batch]
490
  valid_texts = [t for t in texts if t is not None]
491
  if len(valid_texts) == 0:
 
495
  texts_clean = [t if t is not None else "" for t in texts]
496
  mask_valid = [t is not None for t in texts]
497
 
498
+ encoded = tokenizer(
499
+ texts_clean, padding=True, truncation=True,
500
+ max_length=max_length, return_tensors="pt"
501
+ )
502
  result[mod] = {
503
  "input_ids": encoded["input_ids"],
504
  "attention_mask": encoded["attention_mask"],
 
522
  # Training Loop
523
  # ============================================================================
524
 
525
+ def train_epoch(model, dataloader, optimizer, scheduler, config, epoch, phase,
526
+ scaler=None, use_trackio=False, global_step=0):
527
  model.train()
528
+ total_loss = 0.0
529
+ total_clip_loss = 0.0
530
+ total_prop_loss = 0.0
531
+ total_nl_loss = 0.0
532
  log_interval = 20
533
 
534
+ autocast_dtype = torch.bfloat16 if config.use_bf16 else (torch.float16 if config.use_fp16 else torch.float32)
535
+ use_amp = config.use_bf16 or config.use_fp16
536
+
537
  optimizer.zero_grad()
538
 
539
  for batch_idx, batch in enumerate(dataloader):
540
+ step_start = time.time()
541
+
542
  available_mods = [m for m in config.modalities if batch.get(m) is not None]
543
  if len(available_mods) > config.max_modalities_per_step:
544
  must_have = [m for m in ["composition", "crystal_text_llm"] if m in available_mods]
 
548
  else:
549
  sampled = available_mods
550
 
551
+ if phase == 2 and batch.get(config.nl_query_modality) is not None:
552
+ if config.nl_query_modality not in sampled:
553
+ sampled.append(config.nl_query_modality)
554
+
555
  embeddings = {}
556
+ with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp):
557
+ for mod in sampled:
558
+ if batch.get(mod) is None:
559
+ embeddings[mod] = None
560
+ continue
561
+
562
+ input_ids = batch[mod]["input_ids"].to(config.device)
563
+ attention_mask = batch[mod]["attention_mask"].to(config.device)
564
+ valid_mask = batch[mod]["valid_mask"]
565
+
566
+ if not valid_mask.any():
567
+ embeddings[mod] = None
568
+ continue
569
+
570
  emb = model.encode(input_ids, attention_mask, mod)
571
+ emb = emb * valid_mask.to(config.device).unsqueeze(-1).float()
572
+ embeddings[mod] = emb
573
 
574
+ with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp):
575
  temperature = model.temperature
576
  clip_l = all_pairs_clip_loss(embeddings, temperature)
577
 
578
  prop_l = torch.tensor(0.0, device=config.device)
579
+ nl_l = torch.tensor(0.0, device=config.device)
580
+
581
+ if phase == 2:
582
+ if batch.get("property_text") is not None:
583
+ prop_ids = batch["property_text"]["input_ids"].to(config.device)
584
+ prop_mask_att = batch["property_text"]["attention_mask"].to(config.device)
585
+ prop_valid = batch["property_text"]["valid_mask"]
 
 
 
 
586
 
587
+ if prop_valid.any():
588
+ with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp):
589
+ prop_emb = model.encode(prop_ids, prop_mask_att, "property")
590
+
591
+ labels = batch["property_labels"].to(config.device)
592
+ labels_mask = batch["property_labels_mask"].to(config.device)
593
+
594
+ if labels_mask.sum() > 1:
595
+ prop_l = property_similarity_loss(
596
+ prop_emb[labels_mask], labels[labels_mask], temperature
597
+ )
598
+
599
+ for anchor_mod in ["composition", "crystal_text_llm"]:
600
+ if embeddings.get(anchor_mod) is not None:
601
+ valid_both = labels_mask & batch[anchor_mod]["valid_mask"].to(config.device)
602
+ if valid_both.sum() > 1:
603
+ with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp):
604
+ prop_clip = symmetric_clip_loss(
605
+ prop_emb[valid_both],
606
+ embeddings[anchor_mod][valid_both],
607
+ temperature,
608
+ )
609
+ prop_l = prop_l + 0.5 * prop_clip
610
+
611
+ # NL property description ↔ all structure modalities
612
+ if embeddings.get(config.nl_query_modality) is not None:
613
+ nl_emb = embeddings[config.nl_query_modality]
614
+ nl_valid = batch[config.nl_query_modality]["valid_mask"].to(config.device)
615
 
616
+ if nl_valid.sum() > 1:
617
+ n_nl_pairs = 0
618
+ for struct_mod in sampled:
619
+ if struct_mod in [config.nl_query_modality, "property_text"]:
620
+ continue
621
+ if embeddings.get(struct_mod) is None:
622
+ continue
623
+ struct_valid = batch[struct_mod]["valid_mask"].to(config.device)
624
+ valid_both = nl_valid & struct_valid
625
+ if valid_both.sum() > 1:
626
+ with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp):
627
+ nl_struct_loss = symmetric_clip_loss(
628
+ nl_emb[valid_both],
629
+ embeddings[struct_mod][valid_both],
630
+ temperature,
631
+ )
632
+ nl_l = nl_l + nl_struct_loss
633
+ n_nl_pairs += 1
634
+ if n_nl_pairs > 0:
635
+ nl_l = nl_l / n_nl_pairs
636
+
637
+ loss = (clip_l + 0.3 * prop_l + 0.5 * nl_l) / config.grad_accum_steps
638
+
639
+ if scaler is not None:
640
  scaler.scale(loss).backward()
641
  else:
642
  loss.backward()
643
 
644
  if (batch_idx + 1) % config.grad_accum_steps == 0:
645
+ if scaler is not None:
646
  scaler.unscale_(optimizer)
647
  torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
648
+ scaler.step(optimizer)
649
+ scaler.update()
650
  else:
651
  torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
652
  optimizer.step()
653
+ scheduler.step()
654
+ optimizer.zero_grad()
655
+ global_step += 1
656
 
657
  total_loss += loss.item() * config.grad_accum_steps
658
  total_clip_loss += clip_l.item()
659
  total_prop_loss += prop_l.item() if isinstance(prop_l, torch.Tensor) else prop_l
660
+ total_nl_loss += nl_l.item() if isinstance(nl_l, torch.Tensor) else nl_l
661
 
662
  if (batch_idx + 1) % log_interval == 0:
663
  avg = total_loss / (batch_idx + 1)
664
+ avg_clip = total_clip_loss / (batch_idx + 1)
665
+ avg_prop = total_prop_loss / (batch_idx + 1)
666
+ avg_nl = total_nl_loss / (batch_idx + 1)
667
+ lr = scheduler.get_last_lr()[0]
668
+ step_time = time.time() - step_start
669
+
670
  logger.info(
671
+ f"P{phase} E{epoch} | {batch_idx+1}/{len(dataloader)} | "
672
+ f"Loss: {avg:.4f} | CLIP: {avg_clip:.4f} | Prop: {avg_prop:.4f} | "
673
+ f"NL: {avg_nl:.4f} | LR: {lr:.2e} | T: {model.temperature.item():.3f} | "
674
+ f"mods: {len(sampled)} | {step_time:.1f}s/step"
675
  )
676
+
677
+ if use_trackio:
678
+ try:
679
+ import trackio
680
+ trackio.log({
681
+ "phase": phase, "epoch": epoch, "step": global_step,
682
+ "loss": avg, "clip_loss": avg_clip, "prop_loss": avg_prop,
683
+ "nl_loss": avg_nl, "lr": lr, "temperature": model.temperature.item(),
684
+ })
685
+ except:
686
+ pass
687
 
688
+ return total_loss / max(len(dataloader), 1), global_step
689
 
690
 
691
  # ============================================================================
 
693
  # ============================================================================
694
 
695
  @torch.no_grad()
696
+ def evaluate_retrieval(model, dataloader, config, k_values=[1, 5, 10, 20]):
697
  model.eval()
698
  all_embeddings = {mod: [] for mod in config.modalities}
699
 
700
+ autocast_dtype = torch.bfloat16 if config.use_bf16 else (torch.float16 if config.use_fp16 else torch.float32)
701
+ use_amp = config.use_bf16 or config.use_fp16
702
+
703
  for batch in dataloader:
704
  for mod in config.modalities:
705
+ if batch.get(mod) is None:
706
+ continue
707
  input_ids = batch[mod]["input_ids"].to(config.device)
708
  attention_mask = batch[mod]["attention_mask"].to(config.device)
709
  valid_mask = batch[mod]["valid_mask"]
710
+ if not valid_mask.any():
711
+ continue
712
+
713
+ with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp):
714
+ emb = model.encode(input_ids, attention_mask, mod).float().cpu()
715
 
 
716
  for i in range(len(emb)):
717
  all_embeddings[mod].append(emb[i] if valid_mask[i] else None)
718
 
719
  results = {}
720
  eval_pairs = [
721
+ ("composition", "crystal_text_llm"),
722
+ ("composition", "cif_symmetrized"),
723
+ ("composition", "slices"),
724
+ ("slices", "crystal_text_llm"),
725
+ ("composition", "zmatrix"),
726
+ ("composition", "atom_sequences_plusplus"),
727
+ ("local_env", "composition"),
728
  ]
729
  if len([e for e in all_embeddings.get("robocrys_rep", []) if e is not None]) > 0:
730
+ eval_pairs.extend([
731
+ ("robocrys_rep", "composition"),
732
+ ("robocrys_rep", "cif_symmetrized"),
733
+ ("robocrys_rep", "slices"),
734
+ ])
735
 
736
  for mod_a, mod_b in eval_pairs:
737
+ embs_a = all_embeddings.get(mod_a, [])
738
+ embs_b = all_embeddings.get(mod_b, [])
739
+ if not embs_a or not embs_b:
740
+ continue
741
 
742
+ valid_idx = [i for i in range(min(len(embs_a), len(embs_b)))
743
  if embs_a[i] is not None and embs_b[i] is not None]
744
+ if len(valid_idx) < 10:
745
+ continue
746
 
747
  ea = torch.stack([embs_a[i] for i in valid_idx])
748
  eb = torch.stack([embs_b[i] for i in valid_idx])
 
751
  recalls = {}
752
  for k in k_values:
753
  kk = min(k, len(valid_idx) - 1)
754
+ if kk < 1:
755
+ continue
756
  topk = sim.topk(kk, dim=1).indices
757
  correct = (topk == torch.arange(len(valid_idx)).unsqueeze(1)).any(dim=1)
758
  recalls[f"R@{k}"] = correct.float().mean().item()
 
763
  return results
764
 
765
 
766
+ @torch.no_grad()
767
+ def evaluate_nl_queries(model, tokenizer, indices, config):
768
+ model.eval()
769
+
770
+ test_queries = [
771
+ ("oxide with high bandgap", config.nl_query_modality),
772
+ ("narrow bandgap semiconductor", config.nl_query_modality),
773
+ ("stable binary oxide", config.nl_query_modality),
774
+ ("wide bandgap fluoride", config.nl_query_modality),
775
+ ("ternary sulfide with low formation energy", config.nl_query_modality),
776
+ ("metallic nitride", config.nl_query_modality),
777
+ ("Fe2O3", "composition"),
778
+ ("SiO2", "composition"),
779
+ ("TiO2", "composition"),
780
+ ("GaN", "composition"),
781
+ ("perovskite structure with octahedral coordination", "robocrys_rep"),
782
+ ("cubic crystal with face-centered lattice", "robocrys_rep"),
783
+ ]
784
+
785
+ results = {}
786
+ for query_text, query_modality in test_queries:
787
+ try:
788
+ hits = search_vector_db(query_text, query_modality, model, tokenizer, indices, config, k=5)
789
+ results[query_text] = {
790
+ "modality": query_modality,
791
+ "top_hits": [(s, m) for s, m in hits],
792
+ }
793
+ logger.info(f"\nQuery: '{query_text}' (via {query_modality})")
794
+ for rank, (score, meta) in enumerate(hits[:5], 1):
795
+ logger.info(f" #{rank}: {score:.4f} | {meta.get('composition', 'N/A')} | "
796
+ f"via {meta.get('matched_modality', 'N/A')}")
797
+ except Exception as e:
798
+ logger.warning(f"Query '{query_text}' failed: {e}")
799
+
800
+ return results
801
+
802
+
803
  # ============================================================================
804
  # FAISS Vector Database
805
  # ============================================================================
806
 
807
  def build_vector_database(model, dataset, tokenizer, config, modalities_to_index=None):
808
  if modalities_to_index is None:
809
+ modalities_to_index = ["composition", "crystal_text_llm", "slices",
810
+ "cif_symmetrized", "robocrys_rep"]
811
  model.eval()
812
 
813
+ autocast_dtype = torch.bfloat16 if config.use_bf16 else (torch.float16 if config.use_fp16 else torch.float32)
814
+ use_amp = config.use_bf16 or config.use_fp16
815
+
816
  all_embeddings = {mod: [] for mod in modalities_to_index}
817
  all_metadata = []
818
  bs = 64
 
820
  for start in range(0, len(dataset), bs):
821
  end = min(start + bs, len(dataset))
822
  items = [dataset[i] for i in range(start, end)]
 
823
 
824
  for item in items:
825
+ meta = {
826
+ "composition": item.get("composition", ""),
827
+ "property_label": item.get("property_label"),
828
+ }
829
+ all_metadata.append(meta)
830
+
831
+ all_mod_keys = list(config.modalities)
832
+ batch = collate_fn(items, tokenizer, all_mod_keys, config.max_length)
833
 
834
  with torch.no_grad():
835
  for mod in modalities_to_index:
836
  if batch.get(mod) is None:
837
+ all_embeddings[mod].extend([None] * len(items))
838
+ continue
839
+ with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp):
840
+ emb = model.encode(
841
+ batch[mod]["input_ids"].to(config.device),
842
+ batch[mod]["attention_mask"].to(config.device),
843
+ mod,
844
+ ).float().cpu().numpy()
845
  for i in range(len(emb)):
846
+ if batch[mod]["valid_mask"][i]:
847
+ all_embeddings[mod].append(emb[i])
848
+ else:
849
+ all_embeddings[mod].append(None)
850
 
851
+ if (start // bs) % 20 == 0:
852
  logger.info(f"Indexed {end}/{len(dataset)}")
853
 
854
  indices = {}
855
  for mod in modalities_to_index:
856
  valid_embs = [e for e in all_embeddings[mod] if e is not None]
857
  valid_map = [i for i, e in enumerate(all_embeddings[mod]) if e is not None]
858
+ if not valid_embs:
859
+ continue
860
 
861
  emb_matrix = np.stack(valid_embs).astype(np.float32)
862
  faiss.normalize_L2(emb_matrix)
 
864
 
865
  if len(valid_embs) > 10000:
866
  nlist = min(100, int(np.sqrt(len(valid_embs))))
867
+ quantizer = faiss.IndexFlatIP(d)
868
+ index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT)
869
  index.train(emb_matrix)
870
+ index.nprobe = 10
871
  else:
872
  index = faiss.IndexFlatIP(d)
 
873
 
874
+ index.add(emb_matrix)
875
+ indices[mod] = {
876
+ "index": index,
877
+ "valid_indices_map": valid_map,
878
+ "metadata": [all_metadata[i] for i in valid_map],
879
+ }
880
  logger.info(f"FAISS {mod}: {len(valid_embs)} vectors, dim={d}")
881
 
882
  return indices
883
 
884
 
885
  def search_vector_db(query_text, query_modality, model, tokenizer, indices, config, k=10):
886
+ """Search the vector DB with any modality query.
887
+
888
+ For NL queries like "oxide with high bandgap": query_modality="nl_property_description"
889
+ For composition queries like "Fe2O3": query_modality="composition"
890
+ For structure descriptions: query_modality="robocrys_rep"
891
+ """
892
  model.eval()
893
+
894
+ autocast_dtype = torch.bfloat16 if config.use_bf16 else (torch.float16 if config.use_fp16 else torch.float32)
895
+ use_amp = config.use_bf16 or config.use_fp16
896
+
897
+ enc = tokenizer(
898
+ [query_text], padding=True, truncation=True,
899
+ max_length=config.max_length, return_tensors="pt",
900
+ )
901
+
902
  with torch.no_grad():
903
+ with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp):
904
+ q_emb = model.encode(
905
+ enc["input_ids"].to(config.device),
906
+ enc["attention_mask"].to(config.device),
907
+ query_modality,
908
+ ).float().cpu().numpy().astype(np.float32)
909
+
910
+ faiss.normalize_L2(q_emb)
911
 
912
  results = []
913
  for mod_name, idx_data in indices.items():
914
+ scores, ids = idx_data["index"].search(q_emb, k)
915
  for s, i in zip(scores[0], ids[0]):
916
+ if i >= 0 and i < len(idx_data["metadata"]):
917
  m = dict(idx_data["metadata"][i])
918
  m["matched_modality"] = mod_name
919
  results.append((float(s), m))
 
923
  for s, m in results:
924
  c = m.get("composition", "")
925
  if c not in seen:
926
+ seen.add(c)
927
+ unique.append((s, m))
928
+ if len(unique) >= k:
929
+ break
930
  return unique
931
 
932
 
 
936
 
937
  def main():
938
  config = Config()
 
 
939
 
940
+ try:
941
+ from flash_attn import flash_attn_func
942
+ config.use_flash_attn = True
943
+ logger.info("Flash Attention 2 available — enabling")
944
+ except ImportError:
945
+ config.use_flash_attn = False
946
+ logger.info("Flash Attention 2 not available — using default attention")
947
+
948
+ logger.info(f"Device: {config.device}")
949
+ logger.info(f"Precision: {'bf16' if config.use_bf16 else 'fp16' if config.use_fp16 else 'fp32'}")
950
+ logger.info(f"Max length: {config.max_length}")
951
+ logger.info(f"Batch: {config.batch_size} × {config.grad_accum_steps} = {config.batch_size * config.grad_accum_steps} effective")
952
+ logger.info(f"Encoder: {config.encoder_name}")
953
+
954
+ use_trackio = False
955
  try:
956
  import trackio
957
+ trackio.init(project="mattext-embeddings", name=f"align-v2-{config.max_length}ctx")
958
  use_trackio = True
959
+ logger.info("Trackio initialized")
960
+ except Exception as e:
961
+ logger.warning(f"Trackio init failed: {e}")
962
 
963
  tokenizer = AutoTokenizer.from_pretrained(config.encoder_name)
964
  model = MatTextEncoder(config).to(config.device)
965
+ total_params = sum(p.numel() for p in model.parameters())
966
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
967
+ logger.info(f"Total params: {total_params:,} | Trainable: {trainable_params:,}")
968
 
969
+ # Phase 1
970
+ logger.info("=" * 70 + "\nPHASE 1: Multi-modal alignment on pretrain100k_v2\n" + "=" * 70)
 
971
 
972
+ pretrain_data = load_dataset(config.dataset_name, config.pretrain_config, split="train")
973
+ logger.info(f"Pretrain loaded: {len(pretrain_data)} samples, cols: {pretrain_data.column_names}")
 
 
 
 
 
 
 
 
 
974
 
975
+ if len(pretrain_data) > config.max_pretrain_samples:
976
+ pretrain_data = pretrain_data.shuffle(seed=42).select(range(config.max_pretrain_samples))
977
+ logger.info(f"Subsampled to {len(pretrain_data)}")
978
 
979
+ phase1_dataset = MatTextPhase1Dataset(pretrain_data, config.modalities)
980
+ make_collate = lambda mods: lambda batch: collate_fn(batch, tokenizer, mods, config.max_length)
981
 
982
+ phase1_loader = DataLoader(
983
+ phase1_dataset, batch_size=config.batch_size, shuffle=True, drop_last=True,
984
+ num_workers=2, collate_fn=make_collate(config.modalities),
985
+ pin_memory=(config.device == "cuda"), prefetch_factor=2,
 
986
  )
987
 
 
 
 
 
 
 
 
 
 
 
 
988
  optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
989
+ phase1_steps = len(phase1_loader) * config.num_epochs_phase1 // config.grad_accum_steps
990
+ scheduler = get_cosine_schedule_with_warmup(optimizer, int(phase1_steps * config.warmup_ratio), phase1_steps)
991
+ scaler = torch.amp.GradScaler('cuda') if config.use_fp16 else None
 
 
992
 
993
+ global_step = 0
 
 
 
994
  best_loss = float('inf')
995
+ os.makedirs(config.output_dir, exist_ok=True)
996
+
997
+ for epoch in range(1, config.num_epochs_phase1 + 1):
998
  t0 = time.time()
999
+ loss, global_step = train_epoch(
1000
+ model, phase1_loader, optimizer, scheduler, config,
1001
+ epoch, phase=1, scaler=scaler, use_trackio=use_trackio, global_step=global_step,
1002
+ )
1003
+ elapsed = time.time() - t0
1004
+ logger.info(f"Phase1 Epoch {epoch}/{config.num_epochs_phase1} | Loss: {loss:.4f} | Time: {elapsed:.0f}s ({elapsed/60:.1f}min)")
1005
  if loss < best_loss:
1006
  best_loss = loss
1007
+ torch.save(model.state_dict(), f"{config.output_dir}/best_model_phase1.pt")
1008
+ logger.info(f" → New best model saved (loss={loss:.4f})")
1009
+
1010
+ del pretrain_data, phase1_dataset, phase1_loader
1011
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
1012
+
1013
+ # Phase 2
1014
+ logger.info("=" * 70 + "\nPHASE 2: Property-conditioned alignment + NL query training\n" + "=" * 70)
1015
+
1016
+ finetune_datasets = []
1017
+ for ft_cfg, ft_split, prop_name in config.finetune_configs:
1018
+ try:
1019
+ ft = load_dataset(config.dataset_name, ft_cfg, split=ft_split)
1020
+ logger.info(f"Loaded {ft_cfg}/{ft_split}: {len(ft)} samples")
1021
+ finetune_datasets.append((ft, prop_name))
1022
+ except Exception as e:
1023
+ logger.warning(f"Failed to load {ft_cfg}/{ft_split}: {e}")
1024
 
1025
+ if finetune_datasets:
1026
+ all_phase2_datasets = []
1027
+ for ft_data, prop_name in finetune_datasets:
1028
+ if len(ft_data) > config.max_finetune_samples // len(finetune_datasets):
1029
+ n = config.max_finetune_samples // len(finetune_datasets)
1030
+ ft_data = ft_data.shuffle(seed=42).select(range(n))
1031
+
1032
+ phase2_ds = MatTextPhase2Dataset(
1033
+ ft_data, config.modalities, "labels", prop_name,
1034
+ nl_descriptions_per_sample=config.nl_descriptions_per_sample,
1035
+ )
1036
+ all_phase2_datasets.append(phase2_ds)
1037
+ logger.info(f"Phase2 dataset ({prop_name}): {len(phase2_ds)} samples")
1038
+
1039
+ class ConcatPhase2Dataset(Dataset):
1040
+ def __init__(self, datasets):
1041
+ self.datasets = datasets
1042
+ self.lengths = [len(d) for d in datasets]
1043
+ self.total = sum(self.lengths)
1044
+ self.cum_lengths = []
1045
+ acc = 0
1046
+ for l in self.lengths:
1047
+ self.cum_lengths.append(acc)
1048
+ acc += l
1049
+ def __len__(self):
1050
+ return self.total
1051
+ def __getitem__(self, idx):
1052
+ for i, (cum, length) in enumerate(zip(self.cum_lengths, self.lengths)):
1053
+ if idx < cum + length:
1054
+ return self.datasets[i][idx - cum]
1055
+ return self.datasets[-1][idx - self.cum_lengths[-1]]
1056
+
1057
+ combined_phase2 = ConcatPhase2Dataset(all_phase2_datasets)
1058
+ phase2_mod_keys = list(config.modalities) + [config.nl_query_modality, "property_text"]
1059
+
1060
+ phase2_loader = DataLoader(
1061
+ combined_phase2, batch_size=config.batch_size, shuffle=True, drop_last=True,
1062
+ num_workers=2,
1063
+ collate_fn=lambda batch: collate_fn(batch, tokenizer, phase2_mod_keys, config.max_length),
1064
+ pin_memory=(config.device == "cuda"), prefetch_factor=2,
1065
+ )
1066
+
1067
+ optimizer2 = torch.optim.AdamW(
1068
+ model.parameters(), lr=config.learning_rate * 0.5, weight_decay=config.weight_decay,
1069
+ )
1070
+ phase2_steps = len(phase2_loader) * config.num_epochs_phase2 // config.grad_accum_steps
1071
+ scheduler2 = get_cosine_schedule_with_warmup(optimizer2, int(phase2_steps * config.warmup_ratio), phase2_steps)
1072
+
1073
+ for epoch in range(1, config.num_epochs_phase2 + 1):
1074
  t0 = time.time()
1075
+ loss, global_step = train_epoch(
1076
+ model, phase2_loader, optimizer2, scheduler2, config,
1077
+ epoch, phase=2, scaler=scaler, use_trackio=use_trackio, global_step=global_step,
1078
+ )
1079
+ elapsed = time.time() - t0
1080
+ logger.info(f"Phase2 Epoch {epoch}/{config.num_epochs_phase2} | Loss: {loss:.4f} | Time: {elapsed:.0f}s ({elapsed/60:.1f}min)")
1081
  if loss < best_loss:
1082
  best_loss = loss
1083
  torch.save(model.state_dict(), f"{config.output_dir}/best_model.pt")
1084
+ logger.info(f" → New best model saved (loss={loss:.4f})")
1085
+
1086
+ del combined_phase2, phase2_loader
1087
+ else:
1088
+ logger.warning("No finetune data loaded — skipping Phase 2")
1089
+
1090
+ # Evaluation
1091
+ logger.info("=" * 70 + "\nEVALUATION\n" + "=" * 70)
1092
+
1093
+ best_path = f"{config.output_dir}/best_model.pt"
1094
+ if not os.path.exists(best_path):
1095
+ best_path = f"{config.output_dir}/best_model_phase1.pt"
1096
+ if os.path.exists(best_path):
1097
+ model.load_state_dict(torch.load(best_path, map_location=config.device))
1098
+ logger.info(f"Loaded best model from {best_path}")
1099
 
 
 
1100
  eval_data = load_dataset(config.dataset_name, config.pretrain_config, split="test")
1101
  if len(eval_data) > 5000:
1102
  eval_data = eval_data.shuffle(seed=42).select(range(5000))
1103
+ logger.info(f"Eval data: {len(eval_data)} samples")
1104
 
1105
+ eval_dataset = MatTextPhase1Dataset(eval_data, config.modalities)
1106
  eval_loader = DataLoader(
1107
+ eval_dataset, batch_size=config.batch_size, shuffle=False,
1108
+ num_workers=2, collate_fn=make_collate(config.modalities),
 
1109
  )
 
1110
 
1111
+ retrieval_results = evaluate_retrieval(model, eval_loader, config)
1112
+
1113
+ logger.info("\nBuilding FAISS vector database...")
1114
+ db_indices = build_vector_database(
1115
+ model, eval_dataset, tokenizer, config,
1116
+ modalities_to_index=["composition", "crystal_text_llm", "slices", "cif_symmetrized", "robocrys_rep"],
1117
  )
1118
 
1119
+ faiss_dir = f"{config.output_dir}/faiss"
1120
+ os.makedirs(faiss_dir, exist_ok=True)
1121
+ for mod, d in db_indices.items():
1122
+ faiss.write_index(d["index"], f"{faiss_dir}/{mod}.index")
1123
+ with open(f"{faiss_dir}/{mod}_metadata.json", "w") as f:
1124
  json.dump(d["metadata"], f)
1125
 
1126
+ logger.info("\n" + "=" * 70 + "\nNATURAL LANGUAGE QUERY EVALUATION\n" + "=" * 70)
1127
+ nl_results = evaluate_nl_queries(model, tokenizer, db_indices, config)
 
 
 
1128
 
1129
+ # Save
1130
+ logger.info("\nSaving model and artifacts...")
1131
  torch.save(model.state_dict(), f"{config.output_dir}/model.pt")
1132
  tokenizer.save_pretrained(config.output_dir)
1133
+
1134
+ model_config = model.get_config_dict()
1135
+ model_config["training"] = {
1136
+ "num_epochs_phase1": config.num_epochs_phase1,
1137
+ "num_epochs_phase2": config.num_epochs_phase2,
1138
+ "batch_size": config.batch_size,
1139
+ "grad_accum_steps": config.grad_accum_steps,
1140
+ "learning_rate": config.learning_rate,
1141
+ "max_length": config.max_length,
1142
+ "nl_descriptions_per_sample": config.nl_descriptions_per_sample,
1143
+ }
1144
  with open(f"{config.output_dir}/config.json", "w") as f:
1145
+ json.dump(model_config, f, indent=2)
1146
+
1147
  with open(f"{config.output_dir}/retrieval_results.json", "w") as f:
1148
+ json.dump(retrieval_results, f, indent=2)
1149
+
1150
+ nl_results_serializable = {}
1151
+ for k, v in nl_results.items():
1152
+ nl_results_serializable[k] = {
1153
+ "modality": v["modality"],
1154
+ "top_hits": [(s, m) for s, m in v["top_hits"]],
1155
+ }
1156
+ with open(f"{config.output_dir}/nl_query_results.json", "w") as f:
1157
+ json.dump(nl_results_serializable, f, indent=2)
1158
 
1159
  if config.push_to_hub:
1160
  try:
1161
  api = HfApi()
1162
  api.create_repo(config.hub_model_id, exist_ok=True)
1163
+ api.upload_folder(
1164
+ folder_path=config.output_dir,
1165
+ repo_id=config.hub_model_id,
1166
+ commit_message=f"Upload MatText aligned embeddings v2 (1024 ctx, NL queries)",
1167
+ )
1168
+ logger.info(f"✓ Pushed to https://huggingface.co/{config.hub_model_id}")
1169
  except Exception as e:
1170
  logger.error(f"Push failed: {e}")
1171
 
1172
+ logger.info("\n" + "=" * 70)
1173
+ logger.info("TRAINING COMPLETE")
1174
+ logger.info(f"Model: {config.output_dir}/model.pt")
1175
+ logger.info(f"FAISS: {faiss_dir}/")
1176
+ logger.info(f"Hub: https://huggingface.co/{config.hub_model_id}")
1177
+ logger.info("=" * 70)
1178
 
1179
 
1180
  if __name__ == "__main__":