# mattext-aligned-embeddings / train_mattext_embeddings.py
# Committed by: n0w0f
# v2: 1024 context, NL property queries (LaCLIP-style), A100 80GB optimized
# Commit 7949a14 (verified)
"""
MatText Multi-Modal Embedding Alignment Training (v2)
Architecture: CLIP-style contrastive learning across 10+ material text representations
+ LaCLIP-style natural language property descriptions for free-form querying
Key upgrades from v1:
- 1024 token context (was 512) — captures long CIFs
- Natural language property query support ("oxide with high bandgap")
- LaCLIP-style diverse NL description generation from structured labels
- A100 80GB optimized (bf16, larger batches, more modalities/step)
- Flash Attention 2 when available
- Phase 2 aligns NL descriptions ↔ all structure modalities
Based on:
- MultiMat (AllPairsCLIP, arxiv:2312.00111)
- MatExpert (property↔structure InfoNCE, arxiv:2410.21317)
- LaCLIP (LLM text augmentation, arxiv:2305.20088)
- SupReMix (property-label-aware soft contrastive, arxiv:2309.16633)
Usage:
pip install torch transformers datasets faiss-cpu huggingface_hub trackio accelerate
python train_mattext_embeddings.py
"""
import os
import json
import math
import time
import logging
import random
import re
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
from datasets import load_dataset, concatenate_datasets
from huggingface_hub import HfApi
import faiss
# Module-level logger; the root handler configured below prints timestamped lines.
logger = logging.getLogger(__name__)
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
# ============================================================================
# Configuration
# ============================================================================
class Config:
    """Hyper-parameters for the two-phase MatText embedding-alignment run.

    Every setting is a plain class attribute, so it can be read from the class
    or from an instance; ``main`` overrides ``use_flash_attn`` on an instance
    after probing for the ``flash_attn`` package.
    """

    # ---- Model --------------------------------------------------------------
    encoder_name = "answerdotai/ModernBERT-base"
    embed_dim = 128  # dimensionality of every projected (shared-space) embedding
    max_length = 1024  # tokens per modality (ModernBERT pretrained at 1024, extended to 8192)

    # ---- Modalities to align (dataset column names) -------------------------
    modalities = [
        "composition",
        "atom_sequences",
        "cif_symmetrized",
        "cif_p1",
        "zmatrix",
        "atom_sequences_plusplus",
        "slices",
        "crystal_text_llm",
        "local_env",
        "robocrys_rep",  # natural-language structural description (pretrain only)
    ]
    # Dedicated head for free-form property queries such as "oxide with high bandgap";
    # it is separate from robocrys_rep and fed with generated descriptions in phase 2.
    nl_query_modality = "nl_property_description"

    # ---- Optimisation ---------------------------------------------------------
    batch_size = 48  # per micro-batch
    learning_rate = 2e-5
    weight_decay = 0.01
    num_epochs_phase1 = 3
    num_epochs_phase2 = 3
    warmup_ratio = 0.1  # fraction of optimizer steps spent in warm-up
    temperature = 0.07  # initial temperature; the model stores log(1 / temperature)
    grad_accum_steps = 6  # effective batch = 48 * 6 = 288
    max_grad_norm = 1.0
    gradient_checkpointing = True
    max_modalities_per_step = 5  # cap on modalities encoded per micro-batch

    # ---- Data -----------------------------------------------------------------
    dataset_name = "n0w0f/MatText"
    pretrain_config = "pretrain100k_v2"
    # (dataset config, split, property name)
    finetune_configs = [
        ("bandgap-train-filtered", "fold_0", "bandgap"),
        ("form_energy-train-filtered", "fold_0", "formation_energy"),
    ]
    max_pretrain_samples = 60_000
    max_finetune_samples = 60_000

    # ---- NL description generation ------------------------------------------
    nl_descriptions_per_sample = 3  # LaCLIP-style paraphrases generated per sample

    # ---- Output ---------------------------------------------------------------
    output_dir = "mattext-embeddings"
    hub_model_id = "n0w0f/mattext-aligned-embeddings"
    push_to_hub = True

    # ---- Device / precision (evaluated once, when the class body runs) -------
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"
    # bf16 only on GPUs reporting compute capability >= 8; otherwise fp16 on any GPU.
    use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    use_fp16 = torch.cuda.is_available() and not use_bf16
    use_flash_attn = False  # set True if flash-attn is installed
# ============================================================================
# NL Property Description Generator (LaCLIP-style)
# ============================================================================
class NLPropertyDescriptionGenerator:
    """
    Generates diverse natural language descriptions from structured material properties.
    This bridges the gap between structured labels (bandgap=3.2) and free-form queries
    ("oxide with high bandgap"). LaCLIP-inspired: multiple paraphrases per sample.

    All methods are class/static methods; instances carry no state.
    """

    # Half-open [lo, hi) bandgap ranges in eV -> qualitative label.
    # Values outside the table are clamped to the nearest end bucket.
    BANDGAP_QUALIFIERS = {
        (0, 0.01): "zero",
        (0.01, 0.5): "very narrow",
        (0.5, 1.5): "narrow",
        (1.5, 3.0): "moderate",
        (3.0, 5.0): "wide",
        (5.0, 100): "very wide",
    }
    # Half-open [lo, hi) formation-energy ranges in eV/atom -> stability label.
    FENERGY_QUALIFIERS = {
        (-100, -3.0): "very stable",
        (-3.0, -1.5): "stable",
        (-1.5, -0.5): "moderately stable",
        (-0.5, 0.0): "marginally stable",
        (0.0, 1.0): "metastable",
        (1.0, 100): "unstable",
    }
    # (regex, anion name) checked in order; the first match wins. Each regex matches
    # the element symbol only when it is NOT followed by a lowercase letter, i.e. when
    # it is a complete symbol: "S" does not fire on "Si"/"Sn"/"Se", "O" not on "Os",
    # "C" not on "Ca"/"Cl". Unlike a digit-or-end-of-string check, this also finds an
    # element with an implicit count of 1 mid-formula (the O in "BiOI" or "FeOCl").
    ANION_PATTERNS = [
        (r'O(?![a-z])', "oxide"),
        (r'S(?![a-z])', "sulfide"),
        (r'N(?![a-z])', "nitride"),
        (r'F(?![a-z])', "fluoride"),
        (r'Cl(?![a-z])', "chloride"),
        (r'Br(?![a-z])', "bromide"),
        (r'I(?![a-z])', "iodide"),
        (r'Se(?![a-z])', "selenide"),
        (r'Te(?![a-z])', "telluride"),
        (r'C(?![a-z])', "carbide"),
        (r'H(?![a-z])', "hydride"),
    ]
    ELEMENT_COUNT_NAMES = {
        1: "elemental", 2: "binary", 3: "ternary", 4: "quaternary", 5: "quinary",
    }

    @staticmethod
    def _article(word):
        """Return "an" if ``word`` starts with a vowel letter, else "a" (spelling heuristic)."""
        return "an" if str(word)[:1].lower() in ("a", "e", "i", "o", "u") else "a"

    @staticmethod
    def _qualify(value, table, default):
        """Map ``value`` to the label of the [lo, hi) bucket in ``table`` containing it.

        Values below the lowest bucket or at/above the highest one are clamped to the
        end buckets. Values that fit nowhere (e.g. NaN) get ``default``.
        """
        buckets = sorted(table.items())
        for (lo, hi), qual in buckets:
            if lo <= value < hi:
                return qual
        if buckets:
            (first_lo, _), first_qual = buckets[0]
            (_, last_hi), last_qual = buckets[-1]
            if value < first_lo:
                return first_qual
            if value >= last_hi:
                return last_qual
        return default

    @classmethod
    def _qualify_bandgap(cls, bg):
        """Qualitative bandgap label for ``bg`` (eV); "moderate" if unclassifiable."""
        return cls._qualify(bg, cls.BANDGAP_QUALIFIERS, "moderate")

    @classmethod
    def _qualify_fenergy(cls, fe):
        """Stability label for formation energy ``fe`` (eV/atom); "moderately stable" if unclassifiable."""
        return cls._qualify(fe, cls.FENERGY_QUALIFIERS, "moderately stable")

    @classmethod
    def _detect_anion(cls, composition):
        """Return the first anion class in ANION_PATTERNS present in ``composition``, else "compound"."""
        for pattern, name in cls.ANION_PATTERNS:
            if re.search(pattern, composition):
                return name
        return "compound"

    @classmethod
    def _count_elements(cls, composition):
        """Number of distinct element symbols in ``composition``."""
        return len(set(re.findall(r'[A-Z][a-z]?', composition)))

    @classmethod
    def _get_elements(cls, composition):
        """Distinct element symbols in ``composition``, sorted for deterministic output."""
        return sorted(set(re.findall(r'[A-Z][a-z]?', composition)))

    @classmethod
    def generate_descriptions(cls, composition, property_name=None, property_value=None,
                              crystal_system=None, n=3):
        """Generate ``n`` diverse NL descriptions for a material.

        Args:
            composition: Formula string such as "Fe2O3".
            property_name: "bandgap" or "formation_energy"; anything else (or None)
                yields composition-only descriptions.
            property_value: Property value (eV for bandgap, eV/atom for formation
                energy); ignored when None.
            crystal_system: Optional crystal-system name used by two extra templates.
            n: Number of descriptions to return.

        Returns:
            A list of ``n`` strings, sampled without replacement when enough templates
            exist, otherwise every template plus random repeats.
        """
        anion_type = cls._detect_anion(composition)
        n_elements = cls._count_elements(composition)
        complexity = cls.ELEMENT_COUNT_NAMES.get(n_elements, "complex")
        a_anion = cls._article(anion_type)          # "an oxide", "a sulfide"
        a_complexity = cls._article(complexity)     # "an elemental", "a binary"
        # Avoid "compound compound" when no anion class was detected.
        anion_compound = anion_type if anion_type == "compound" else f"{anion_type} compound"
        element_word = "element" if n_elements == 1 else "elements"

        property_templates = []
        combined_templates = []
        if property_name == "bandgap" and property_value is not None:
            qual = cls._qualify_bandgap(property_value)
            a_qual = cls._article(qual)
            property_templates.extend([
                f"{a_anion.capitalize()} {anion_type} material with {qual} bandgap of {property_value:.2f} eV.",
                f"{composition} is {a_complexity} {complexity} {anion_type} with {a_qual} {qual} electronic band gap ({property_value:.2f} eV).",
                f"This {anion_type} has a bandgap of {property_value:.2f} eV, classified as {qual}.",
                f"{a_qual.capitalize()} {qual} bandgap {anion_type} ({property_value:.1f} eV) with composition {composition}.",
                f"{composition}: {anion_type} semiconductor with {qual} band gap of {property_value:.2f} electron volts.",
                f"{a_anion.capitalize()} {anion_type} with {qual} bandgap around {property_value:.1f} eV, formula {composition}.",
                f"This {complexity} {anion_type} ({composition}) exhibits {a_qual} {qual} bandgap of approximately {property_value:.2f} eV.",
                f"Material {composition} is {a_qual} {qual}-gap {anion_type} with bandgap {property_value:.2f} eV.",
            ])
            if property_value > 3.0:
                property_templates.append(
                    f"{composition} is a wide-gap {anion_type} suitable for UV applications, bandgap {property_value:.2f} eV."
                )
            if 0.01 < property_value < 1.0:
                property_templates.append(
                    f"{composition} is a narrow-gap {anion_type}, potentially useful for infrared applications, bandgap {property_value:.2f} eV."
                )
            if property_value < 0.01:
                property_templates.append(
                    f"{composition} is a metallic or near-zero-gap {anion_type} with bandgap {property_value:.3f} eV."
                )
            combined_templates.extend([
                f"{composition} is {a_complexity} {complexity} {anion_type} with {qual} bandgap of {property_value:.2f} eV.",
                f"{a_qual.capitalize()} {qual} bandgap {complexity} {anion_type} material, {composition}, with band gap {property_value:.1f} eV.",
            ])
        elif property_name == "formation_energy" and property_value is not None:
            qual = cls._qualify_fenergy(property_value)
            a_qual = cls._article(qual)             # "an unstable", "a stable"
            property_templates.extend([
                f"{a_qual.capitalize()} {qual} {anion_type} with formation energy of {property_value:.3f} eV/atom.",
                f"{composition} is {a_complexity} {complexity} {anion_type} that is {qual} with formation energy {property_value:.3f} eV/atom.",
                f"This {anion_type} ({composition}) has a formation energy of {property_value:.3f} eV/atom, making it {qual}.",
                f"{a_qual.capitalize()} {qual} {complexity} {anion_type}: {composition}, formation energy = {property_value:.3f} eV/atom.",
                f"{composition}: thermodynamically {qual} {anion_type} (formation energy {property_value:.3f} eV/atom).",
                f"This material ({composition}) is {a_qual} {qual} {anion_compound} with Ef = {property_value:.3f} eV/atom.",
                f"{a_anion.capitalize()} {anion_type} with composition {composition} showing {qual} thermodynamic stability ({property_value:.3f} eV/atom).",
            ])
            combined_templates.extend([
                f"{composition} is {a_qual} {qual} {complexity} {anion_type} with formation energy {property_value:.3f} eV/atom.",
                f"{a_qual.capitalize()} {qual} {anion_type}, {composition}, with Ef = {property_value:.3f} eV/atom.",
            ])

        composition_templates = [
            f"{a_complexity.capitalize()} {complexity} {anion_type} with formula {composition}.",
            f"{composition} is {a_complexity} {complexity} {anion_compound}.",
            f"This material has composition {composition}, {a_complexity} {complexity} {anion_type}.",
            f"{a_anion.capitalize()} {anion_type} material: {composition} ({n_elements} {element_word}).",
        ]
        if crystal_system:
            a_system = cls._article(crystal_system)
            composition_templates.extend([
                f"{composition} is {a_system} {crystal_system} {anion_type}.",
                f"{a_system.capitalize()} {crystal_system} structured {complexity} {anion_type}: {composition}.",
            ])

        # composition_templates is never empty, so all_templates always has entries.
        all_templates = property_templates + composition_templates + combined_templates
        if len(all_templates) >= n:
            return random.sample(all_templates, n)
        return all_templates + random.choices(all_templates, k=n - len(all_templates))
# ============================================================================
# Model: Shared Encoder + Per-Modality Projection Heads
# ============================================================================
class ModalityProjection(nn.Module):
    """Per-modality projection head (2-layer MLP, MultiMat recipe).

    Linear(in, in) -> GELU -> LayerNorm -> Linear(in, out), then L2-normalised
    along the last dimension so dot products between outputs are cosine
    similarities.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        layers = [
            nn.Linear(input_dim, input_dim),
            nn.GELU(),
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, output_dim),
        ]
        # Sequential positions (net.0 .. net.3) define the state_dict keys.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        projected = self.net(x)
        return F.normalize(projected, dim=-1)
class MatTextEncoder(nn.Module):
    """
    Shared transformer encoder with per-modality projection heads.
    Includes an NL query projection head for free-form text queries.

    Every modality runs through the same backbone; only the projection head,
    selected by name in :meth:`encode`, differs.
    """

    def __init__(self, config: "Config"):
        super().__init__()
        self.config = config
        model_kwargs = {}
        if config.use_flash_attn:
            model_kwargs["attn_implementation"] = "flash_attention_2"
        # Backbone weights stay in fp32 even when config.use_bf16 is set. The bf16
        # compute comes from torch.amp.autocast, which every caller in this script
        # wraps around encode(). Storing the parameters themselves in bf16
        # (torch_dtype=bfloat16) and training them with AdamW at lr ~2e-5 rounds most
        # updates away, because bf16 keeps only ~3 significant decimal digits.
        self.backbone = AutoModel.from_pretrained(config.encoder_name, **model_kwargs)
        hidden_size = self.backbone.config.hidden_size
        if config.gradient_checkpointing:
            self.backbone.gradient_checkpointing_enable()
        self.projections = nn.ModuleDict({
            mod: ModalityProjection(hidden_size, config.embed_dim)
            for mod in config.modalities
        })
        # NL query head — for "oxide with high bandgap" style queries
        self.projections[config.nl_query_modality] = ModalityProjection(hidden_size, config.embed_dim)
        # Property head — for structured property text like "bandgap: 2.1"
        self.projections["property"] = ModalityProjection(hidden_size, config.embed_dim)
        # Learned log of the logit scale (1 / temperature), CLIP-style.
        self.log_temperature = nn.Parameter(
            torch.tensor(math.log(1.0 / config.temperature))
        )

    def encode(self, input_ids, attention_mask, modality_name):
        """Encode token ids and project them with the ``modality_name`` head.

        The backbone's last hidden state is mean-pooled over positions where
        ``attention_mask`` is non-zero (padding is ignored), then passed through
        the matching projection head.
        """
        outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        mask = attention_mask.unsqueeze(-1).float()
        hidden = outputs.last_hidden_state
        # clamp guards against division by zero for fully padded rows
        pooled = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        return self.projections[modality_name](pooled)

    @property
    def temperature(self):
        """Logit scale exp(log_temperature), clamped to [0.01, 100].

        Despite the name, this is the multiplier applied to cosine similarities,
        i.e. the inverse of the softmax temperature.
        """
        return torch.exp(self.log_temperature).clamp(min=0.01, max=100.0)

    def get_config_dict(self):
        """Serializable summary of the model configuration.

        Note: the "temperature" entry stores the current logit scale
        (see :attr:`temperature`), not config.temperature.
        """
        return {
            "encoder_name": self.config.encoder_name,
            "embed_dim": self.config.embed_dim,
            "max_length": self.config.max_length,
            "modalities": self.config.modalities,
            "nl_query_modality": self.config.nl_query_modality,
            "temperature": self.temperature.item(),
        }
# ============================================================================
# Loss Functions
# ============================================================================
def symmetric_clip_loss(emb_a, emb_b, temperature):
    """Symmetric InfoNCE (CLIP) loss between two row-aligned embedding batches.

    Row i of ``emb_a`` is the positive for row i of ``emb_b``; all other rows are
    negatives. ``temperature`` multiplies the similarity matrix (a logit scale).

    Returns the mean of the a->b and b->a cross-entropies, or a zero tensor with
    ``requires_grad=True`` when fewer than two rows are given.
    """
    rows = emb_a.size(0)
    if rows < 2:
        return torch.tensor(0.0, device=emb_a.device, requires_grad=True)
    scores = (emb_a @ emb_b.T) * temperature
    targets = torch.arange(rows, device=emb_a.device)
    a_to_b = F.cross_entropy(scores, targets)
    b_to_a = F.cross_entropy(scores.T, targets)
    return (a_to_b + b_to_a) / 2
def all_pairs_clip_loss(embeddings_dict, temperature):
    """Average ``symmetric_clip_loss`` over every unordered pair of present modalities.

    Entries whose value is None are ignored. With fewer than two present
    modalities, a zero tensor (``requires_grad=True``) on ``temperature``'s
    device is returned.
    """
    present = [name for name, emb in embeddings_dict.items() if emb is not None]
    if len(present) < 2:
        return torch.tensor(0.0, device=temperature.device, requires_grad=True)
    pair_losses = [
        symmetric_clip_loss(embeddings_dict[first], embeddings_dict[second], temperature)
        for pos, first in enumerate(present)
        for second in present[pos + 1:]
    ]
    total = torch.tensor(0.0, device=temperature.device)
    for pair_loss in pair_losses:
        total = total + pair_loss
    return total / max(len(pair_losses), 1)
def property_similarity_loss(embeddings, labels, temperature):
    """MSE between pairwise embedding similarity and pairwise label similarity.

    Label similarity is ``1 - |y_i - y_j| / max |y_i - y_j|`` within the batch, so
    the closest labels map to 1 and the farthest pair to 0. Diagonal entries are
    zeroed in both matrices before the mean squared error is taken over all N*N
    entries. ``temperature`` is accepted for signature parity but unused.
    Returns a zero tensor (``requires_grad=True``) for fewer than two rows.
    """
    count = embeddings.size(0)
    if count < 2:
        return torch.tensor(0.0, device=embeddings.device, requires_grad=True)
    gaps = (labels.unsqueeze(0) - labels.unsqueeze(1)).abs()
    target = 1.0 - (gaps / gaps.max().clamp(min=1e-6))
    diagonal = torch.eye(count, device=embeddings.device, dtype=torch.bool)
    similarity = (embeddings @ embeddings.T).masked_fill(diagonal, 0)
    target = target.masked_fill(diagonal, 0)
    return F.mse_loss(similarity, target)
# ============================================================================
# Dataset
# ============================================================================
class MatTextPhase1Dataset(Dataset):
    """Phase 1: Multi-modal alignment on pretrain data (no labels).

    Each item maps every modality column present in ``data`` to its stripped
    text, or to None when the value is missing, not a string, or blank.
    """

    def __init__(self, data, modalities):
        self.data = data
        self.modalities = modalities
        if hasattr(data, 'column_names'):
            columns = set(data.column_names)
        else:
            columns = set(data[0].keys())
        self.available_modalities = [name for name in modalities if name in columns]
        logger.info(f"Phase1 modalities: {self.available_modalities}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        sample = {}
        for name in self.available_modalities:
            value = record.get(name, None)
            stripped = value.strip() if isinstance(value, str) else ""
            sample[name] = stripped if stripped else None
        return sample
class MatTextPhase2Dataset(Dataset):
    """Phase 2: Property-conditioned alignment with LaCLIP-style NL descriptions.

    Each item holds the stripped text of every available modality (None when
    missing/blank) plus, when the row has a finite value in ``property_col``:
      - ``property_label``: the value as float,
      - ``property_text``: "composition: <formula> | <property_name>: <value:.4f>",
      - ``nl_property_description``: one generated NL description.
    Rows without a usable label get None for all three keys.
    """

    def __init__(self, data, modalities, property_col, property_name, nl_descriptions_per_sample=3):
        self.data = data
        self.modalities = modalities
        self.property_col = property_col
        self.property_name = property_name
        self.nl_descriptions_per_sample = nl_descriptions_per_sample
        self.nl_gen = NLPropertyDescriptionGenerator()
        available_cols = set(data.column_names) if hasattr(data, 'column_names') else set(data[0].keys())
        self.available_modalities = [m for m in modalities if m in available_cols]
        self.has_properties = property_col in available_cols
        logger.info(f"Phase2 modalities: {self.available_modalities}")
        logger.info(f"Property: {property_name} (col={property_col}, has={self.has_properties})")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]
        item = {}
        for mod in self.available_modalities:
            text = row.get(mod, None)
            item[mod] = text.strip() if isinstance(text, str) and text.strip() else None

        # The column may exist but hold None/""; the description generator runs
        # regexes over this value, so fall back to a placeholder string.
        composition = row.get("composition")
        if isinstance(composition, str) and composition.strip():
            composition = composition.strip()
        else:
            composition = "unknown"
        crystal_system = row.get("crystal_system", None)

        label_val = None
        if self.has_properties and row.get(self.property_col) is not None:
            label_val = float(row[self.property_col])
            # NaN/inf labels would make the property loss NaN; treat them as missing.
            if not math.isfinite(label_val):
                label_val = None

        if label_val is not None:
            item["property_label"] = label_val
            item["property_text"] = f"composition: {composition} | {self.property_name}: {label_val:.4f}"
            # LaCLIP-style diverse NL descriptions — randomly sample one per call
            nl_descs = self.nl_gen.generate_descriptions(
                composition=composition,
                property_name=self.property_name,
                property_value=label_val,
                crystal_system=crystal_system,
                n=self.nl_descriptions_per_sample,
            )
            item["nl_property_description"] = random.choice(nl_descs)
        else:
            item["property_label"] = None
            item["property_text"] = None
            item["nl_property_description"] = None
        return item
def collate_fn(batch, tokenizer, all_modality_keys, max_length):
    """Tokenize each requested key of a list of dataset items into padded tensors.

    For every key in ``all_modality_keys``: None if no item has text for it;
    otherwise a dict with ``input_ids``, ``attention_mask`` and a boolean
    ``valid_mask`` (missing texts are tokenized as "" and flagged False).
    ``property_labels`` / ``property_labels_mask`` are float/bool tensors when at
    least one item carries a ``property_label`` (missing ones become 0.0), else None.
    """
    collated = {}
    for key in all_modality_keys:
        raw = [sample.get(key) for sample in batch]
        present = [value is not None for value in raw]
        if not any(present):
            collated[key] = None
            continue
        tokens = tokenizer(
            ["" if value is None else value for value in raw],
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        collated[key] = {
            "input_ids": tokens["input_ids"],
            "attention_mask": tokens["attention_mask"],
            "valid_mask": torch.tensor(present, dtype=torch.bool),
        }

    targets = [sample.get("property_label") for sample in batch]
    has_target = [target is not None for target in targets]
    if any(has_target):
        collated["property_labels"] = torch.tensor(
            [0.0 if target is None else target for target in targets], dtype=torch.float32
        )
        collated["property_labels_mask"] = torch.tensor(has_target, dtype=torch.bool)
    else:
        collated["property_labels"] = None
        collated["property_labels_mask"] = None
    return collated
# ============================================================================
# Training Loop
# ============================================================================
def _select_step_modalities(batch, config, phase):
    """Choose which modalities to encode for one micro-batch.

    Caps the count at ``config.max_modalities_per_step`` (keeping "composition"
    and "crystal_text_llm" when present, sampling the rest at random) and, in
    phase 2, appends ``config.nl_query_modality`` when the batch contains it.
    """
    available = [m for m in config.modalities if batch.get(m) is not None]
    if len(available) > config.max_modalities_per_step:
        must_have = [m for m in ["composition", "crystal_text_llm"] if m in available]
        remaining = [m for m in available if m not in must_have]
        n_sample = max(config.max_modalities_per_step - len(must_have), 1)
        sampled = must_have + random.sample(remaining, min(n_sample, len(remaining)))
    else:
        sampled = list(available)
    if phase == 2 and batch.get(config.nl_query_modality) is not None:
        if config.nl_query_modality not in sampled:
            sampled.append(config.nl_query_modality)
    return sampled


def _masked_all_pairs_clip_loss(embeddings, batch, temperature, device):
    """All-pairs symmetric CLIP loss, each pair restricted to rows valid in BOTH modalities.

    Rows with missing text are zero vectors in ``embeddings``. Feeding them to the
    contrastive loss would turn them into meaningless positives/negatives (a zero
    row yields logit 0 against everything), so they are dropped per pair. Pairs
    with fewer than two jointly valid rows are skipped. Returns a constant zero
    tensor (no graph) when no pair qualifies.
    """
    mods = [m for m, e in embeddings.items() if e is not None]
    total = torch.tensor(0.0, device=device)
    n_pairs = 0
    for i, mod_a in enumerate(mods):
        valid_a = batch[mod_a]["valid_mask"].to(device)
        for mod_b in mods[i + 1:]:
            valid_both = valid_a & batch[mod_b]["valid_mask"].to(device)
            if valid_both.sum() < 2:
                continue
            total = total + symmetric_clip_loss(
                embeddings[mod_a][valid_both], embeddings[mod_b][valid_both], temperature
            )
            n_pairs += 1
    return total / max(n_pairs, 1)


def _phase2_losses(model, batch, embeddings, sampled, temperature, config,
                   autocast_dtype, use_amp):
    """Phase-2 auxiliary losses. Returns ``(prop_l, nl_l)`` as tensors.

    prop_l: property-text embedding similarity vs. label similarity, plus
        0.5 x CLIP loss against the "composition"/"crystal_text_llm" anchors.
    nl_l: mean CLIP loss between NL descriptions and every other encoded
        modality, over rows valid in both.
    """
    device = config.device
    prop_l = torch.tensor(0.0, device=device)
    nl_l = torch.tensor(0.0, device=device)

    prop_batch = batch.get("property_text")
    if prop_batch is not None and prop_batch["valid_mask"].any():
        with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp):
            prop_emb = model.encode(
                prop_batch["input_ids"].to(device),
                prop_batch["attention_mask"].to(device),
                "property",
            )
        labels = batch["property_labels"].to(device)
        labels_mask = batch["property_labels_mask"].to(device)
        if labels_mask.sum() > 1:
            prop_l = property_similarity_loss(
                prop_emb[labels_mask], labels[labels_mask], temperature
            )
        for anchor_mod in ["composition", "crystal_text_llm"]:
            if embeddings.get(anchor_mod) is None:
                continue
            valid_both = labels_mask & batch[anchor_mod]["valid_mask"].to(device)
            if valid_both.sum() > 1:
                with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp):
                    prop_clip = symmetric_clip_loss(
                        prop_emb[valid_both],
                        embeddings[anchor_mod][valid_both],
                        temperature,
                    )
                prop_l = prop_l + 0.5 * prop_clip

    # NL property description <-> all structure modalities
    nl_mod = config.nl_query_modality
    nl_emb = embeddings.get(nl_mod)
    if nl_emb is not None:
        nl_valid = batch[nl_mod]["valid_mask"].to(device)
        if nl_valid.sum() > 1:
            n_nl_pairs = 0
            for struct_mod in sampled:
                if struct_mod in [nl_mod, "property_text"]:
                    continue
                if embeddings.get(struct_mod) is None:
                    continue
                valid_both = nl_valid & batch[struct_mod]["valid_mask"].to(device)
                if valid_both.sum() > 1:
                    with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp):
                        nl_struct_loss = symmetric_clip_loss(
                            nl_emb[valid_both],
                            embeddings[struct_mod][valid_both],
                            temperature,
                        )
                    nl_l = nl_l + nl_struct_loss
                    n_nl_pairs += 1
            if n_nl_pairs > 0:
                nl_l = nl_l / n_nl_pairs
    return prop_l, nl_l


def train_epoch(model, dataloader, optimizer, scheduler, config, epoch, phase,
                scaler=None, use_trackio=False, global_step=0):
    """Run one training epoch with gradient accumulation.

    Each micro-batch encodes a subset of modalities, computes the all-pairs CLIP
    loss over rows valid in both modalities of each pair and, in phase 2, the
    property-text and NL-description losses. The total
    ``clip + 0.3 * prop + 0.5 * nl`` is divided by ``config.grad_accum_steps``;
    optimizer and scheduler step once every ``grad_accum_steps`` micro-batches.
    Micro-batches after the last full accumulation window do not trigger a step;
    their gradients are cleared by the ``zero_grad()`` at the start of the next call.

    Args:
        model: MatTextEncoder-like module (``encode``, ``temperature``).
        dataloader: Yields batches from ``collate_fn``.
        optimizer, scheduler: Stepped once per accumulation window.
        config: Training configuration (modalities, device, precision flags, ...).
        epoch, phase: Used for logging; ``phase == 2`` enables the auxiliary losses.
        scaler: GradScaler for fp16 training, or None.
        use_trackio: Also send the running averages to trackio.
        global_step: Optimizer-step counter to continue from.

    Returns:
        ``(mean un-scaled loss per micro-batch, updated global_step)``.
    """
    model.train()
    total_loss = 0.0
    total_clip_loss = 0.0
    total_prop_loss = 0.0
    total_nl_loss = 0.0
    log_interval = 20
    autocast_dtype = torch.bfloat16 if config.use_bf16 else (torch.float16 if config.use_fp16 else torch.float32)
    use_amp = config.use_bf16 or config.use_fp16
    optimizer.zero_grad()
    for batch_idx, batch in enumerate(dataloader):
        step_start = time.time()
        sampled = _select_step_modalities(batch, config, phase)

        embeddings = {}
        with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp):
            for mod in sampled:
                mod_batch = batch.get(mod)
                if mod_batch is None or not mod_batch["valid_mask"].any():
                    embeddings[mod] = None
                    continue
                valid_mask = mod_batch["valid_mask"].to(config.device)
                emb = model.encode(
                    mod_batch["input_ids"].to(config.device),
                    mod_batch["attention_mask"].to(config.device),
                    mod,
                )
                # Zero rows whose text was missing (they are also excluded from every loss).
                embeddings[mod] = emb * valid_mask.unsqueeze(-1).float()
            temperature = model.temperature
            clip_l = _masked_all_pairs_clip_loss(embeddings, batch, temperature, config.device)

        if phase == 2:
            prop_l, nl_l = _phase2_losses(
                model, batch, embeddings, sampled, temperature, config, autocast_dtype, use_amp
            )
        else:
            prop_l = torch.tensor(0.0, device=config.device)
            nl_l = torch.tensor(0.0, device=config.device)

        loss = (clip_l + 0.3 * prop_l + 0.5 * nl_l) / config.grad_accum_steps
        # If no loss term touched the model (e.g. every pair had < 2 jointly valid
        # rows), the loss is a constant and backward() would raise.
        if loss.requires_grad:
            if scaler is not None:
                scaler.scale(loss).backward()
            else:
                loss.backward()

        if (batch_idx + 1) % config.grad_accum_steps == 0:
            if scaler is not None:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
                optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1

        total_loss += loss.item() * config.grad_accum_steps
        total_clip_loss += clip_l.item()
        total_prop_loss += prop_l.item()
        total_nl_loss += nl_l.item()

        if (batch_idx + 1) % log_interval == 0:
            avg = total_loss / (batch_idx + 1)
            avg_clip = total_clip_loss / (batch_idx + 1)
            avg_prop = total_prop_loss / (batch_idx + 1)
            avg_nl = total_nl_loss / (batch_idx + 1)
            lr = scheduler.get_last_lr()[0]
            step_time = time.time() - step_start
            logger.info(
                f"P{phase} E{epoch} | {batch_idx+1}/{len(dataloader)} | "
                f"Loss: {avg:.4f} | CLIP: {avg_clip:.4f} | Prop: {avg_prop:.4f} | "
                f"NL: {avg_nl:.4f} | LR: {lr:.2e} | T: {model.temperature.item():.3f} | "
                f"mods: {len(sampled)} | {step_time:.1f}s/step"
            )
            if use_trackio:
                try:
                    import trackio
                    trackio.log({
                        "phase": phase, "epoch": epoch, "step": global_step,
                        "loss": avg, "clip_loss": avg_clip, "prop_loss": avg_prop,
                        "nl_loss": avg_nl, "lr": lr, "temperature": model.temperature.item(),
                    })
                except Exception as exc:  # metrics logging is best-effort; never stop training
                    logger.debug(f"trackio.log failed: {exc}")
    return total_loss / max(len(dataloader), 1), global_step
# ============================================================================
# Evaluation
# ============================================================================
@torch.no_grad()
def evaluate_retrieval(model, dataloader, config, k_values=(1, 5, 10, 20)):
    """Cross-modal retrieval Recall@k between a fixed set of modality pairs.

    Every modality in ``config.modalities`` is encoded per batch and stored per
    row (None for rows without text). For each pair, the rows valid in both
    modalities form the gallery. Row i of modality A counts as a hit at k when
    row i of modality B is among its k most similar B embeddings.

    Args:
        model: Exposes ``eval()`` and ``encode(input_ids, attention_mask, modality)``.
        dataloader: Iterable of batches produced by ``collate_fn``.
        config: Provides ``modalities``, ``device``, ``use_bf16`` and ``use_fp16``.
        k_values: Cut-offs to report. A k larger than the gallery is clamped to
            the gallery size.

    Returns:
        Dict mapping ``f"{mod_a}{mod_b}"`` to ``{"R@k": recall, ...}``. Pairs with
        fewer than 10 jointly valid rows are omitted.
    """
    model.eval()
    all_embeddings = {mod: [] for mod in config.modalities}
    autocast_dtype = torch.bfloat16 if config.use_bf16 else (torch.float16 if config.use_fp16 else torch.float32)
    use_amp = config.use_bf16 or config.use_fp16
    for batch in dataloader:
        # Row count of this batch, read from any encoded modality. Skipped
        # modalities are padded with this many Nones so that list index i refers
        # to the same sample for every modality. Without the padding, later pairs
        # would compare embeddings of different materials.
        batch_size = next(
            (len(v["valid_mask"]) for v in batch.values()
             if isinstance(v, dict) and "valid_mask" in v),
            0,
        )
        for mod in config.modalities:
            mod_batch = batch.get(mod)
            if mod_batch is None or not mod_batch["valid_mask"].any():
                all_embeddings[mod].extend([None] * batch_size)
                continue
            valid_mask = mod_batch["valid_mask"]
            with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp):
                emb = model.encode(
                    mod_batch["input_ids"].to(config.device),
                    mod_batch["attention_mask"].to(config.device),
                    mod,
                ).float().cpu()
            all_embeddings[mod].extend(
                row if is_valid else None for row, is_valid in zip(emb, valid_mask)
            )

    results = {}
    eval_pairs = [
        ("composition", "crystal_text_llm"),
        ("composition", "cif_symmetrized"),
        ("composition", "slices"),
        ("slices", "crystal_text_llm"),
        ("composition", "zmatrix"),
        ("composition", "atom_sequences_plusplus"),
        ("local_env", "composition"),
    ]
    if any(e is not None for e in all_embeddings.get("robocrys_rep", [])):
        eval_pairs.extend([
            ("robocrys_rep", "composition"),
            ("robocrys_rep", "cif_symmetrized"),
            ("robocrys_rep", "slices"),
        ])
    for mod_a, mod_b in eval_pairs:
        embs_a = all_embeddings.get(mod_a, [])
        embs_b = all_embeddings.get(mod_b, [])
        if not embs_a or not embs_b:
            continue
        valid_idx = [i for i, (a, b) in enumerate(zip(embs_a, embs_b))
                     if a is not None and b is not None]
        n_valid = len(valid_idx)
        if n_valid < 10:
            continue
        ea = torch.stack([embs_a[i] for i in valid_idx])
        eb = torch.stack([embs_b[i] for i in valid_idx])
        sim = ea @ eb.T
        targets = torch.arange(n_valid).unsqueeze(1)
        recalls = {}
        for k in k_values:
            # Cross-modal retrieval: the true match is in the gallery, so all
            # n_valid candidates are eligible (capping at n_valid - 1 under-reports R@k).
            kk = min(k, n_valid)
            if kk < 1:
                continue
            topk = sim.topk(kk, dim=1).indices
            correct = (topk == targets).any(dim=1)
            recalls[f"R@{k}"] = correct.float().mean().item()
        results[f"{mod_a}{mod_b}"] = recalls
        logger.info(f" {mod_a}{mod_b}: {recalls}")
    return results
@torch.no_grad()
def evaluate_nl_queries(model, tokenizer, indices, config):
    """Run a fixed set of demo queries against the FAISS indices and log the top 5 hits.

    Free-text property queries go through ``config.nl_query_modality``; plain
    formulas go through "composition"; structural phrases through "robocrys_rep".
    Returns ``{query: {"modality": ..., "top_hits": [(score, metadata), ...]}}``.
    A query that raises is logged as a warning and left out of the result.
    """
    model.eval()
    nl_head = config.nl_query_modality
    property_queries = [
        "oxide with high bandgap",
        "narrow bandgap semiconductor",
        "stable binary oxide",
        "wide bandgap fluoride",
        "ternary sulfide with low formation energy",
        "metallic nitride",
    ]
    formula_queries = ["Fe2O3", "SiO2", "TiO2", "GaN"]
    structure_queries = [
        "perovskite structure with octahedral coordination",
        "cubic crystal with face-centered lattice",
    ]
    test_queries = (
        [(text, nl_head) for text in property_queries]
        + [(text, "composition") for text in formula_queries]
        + [(text, "robocrys_rep") for text in structure_queries]
    )
    results = {}
    for query_text, query_modality in test_queries:
        try:
            hits = search_vector_db(query_text, query_modality, model, tokenizer, indices, config, k=5)
            results[query_text] = {
                "modality": query_modality,
                "top_hits": [(score, meta) for score, meta in hits],
            }
            logger.info(f"\nQuery: '{query_text}' (via {query_modality})")
            for position, (score, meta) in enumerate(hits[:5], start=1):
                logger.info(f" #{position}: {score:.4f} | {meta.get('composition', 'N/A')} | "
                            f"via {meta.get('matched_modality', 'N/A')}")
        except Exception as err:
            logger.warning(f"Query '{query_text}' failed: {err}")
    return results
# ============================================================================
# FAISS Vector Database
# ============================================================================
def build_vector_database(model, dataset, tokenizer, config, modalities_to_index=None):
    """Embed ``dataset`` per modality and build one FAISS inner-product index each.

    Rows are encoded in chunks of 64 through ``collate_fn`` and ``model.encode``.
    Rows lacking a modality's text are left out of that modality's index. Vectors
    are L2-normalised, so inner product equals cosine similarity. More than 10,000
    vectors use an IVF index (nlist <= 100, nprobe = 10); otherwise a flat index.

    Returns:
        ``{modality: {"index", "valid_indices_map", "metadata"}}``, where
        ``valid_indices_map[j]`` is the dataset row of index entry j and
        ``metadata[j]`` holds that row's "composition" and "property_label".
    """
    if modalities_to_index is None:
        modalities_to_index = ["composition", "crystal_text_llm", "slices",
                               "cif_symmetrized", "robocrys_rep"]
    model.eval()
    autocast_dtype = torch.bfloat16 if config.use_bf16 else (torch.float16 if config.use_fp16 else torch.float32)
    use_amp = config.use_bf16 or config.use_fp16

    per_modality = {mod: [] for mod in modalities_to_index}
    row_metadata = []
    chunk = 64
    n_rows = len(dataset)
    for start in range(0, n_rows, chunk):
        stop = min(start + chunk, n_rows)
        rows = [dataset[i] for i in range(start, stop)]
        row_metadata.extend(
            {"composition": row.get("composition", ""), "property_label": row.get("property_label")}
            for row in rows
        )
        batch = collate_fn(rows, tokenizer, list(config.modalities), config.max_length)
        with torch.no_grad():
            for mod in modalities_to_index:
                mod_batch = batch.get(mod)
                if mod_batch is None:
                    per_modality[mod].extend([None] * len(rows))
                    continue
                with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp):
                    vectors = model.encode(
                        mod_batch["input_ids"].to(config.device),
                        mod_batch["attention_mask"].to(config.device),
                        mod,
                    ).float().cpu().numpy()
                per_modality[mod].extend(
                    vec if is_valid else None
                    for vec, is_valid in zip(vectors, mod_batch["valid_mask"])
                )
        if (start // chunk) % 20 == 0:
            logger.info(f"Indexed {stop}/{n_rows}")

    indices = {}
    for mod in modalities_to_index:
        kept = [(row_idx, vec) for row_idx, vec in enumerate(per_modality[mod]) if vec is not None]
        if not kept:
            continue
        row_ids = [row_idx for row_idx, _ in kept]
        matrix = np.stack([vec for _, vec in kept]).astype(np.float32)
        faiss.normalize_L2(matrix)
        dim = matrix.shape[1]
        n_vectors = len(kept)
        if n_vectors > 10000:
            nlist = min(100, int(np.sqrt(n_vectors)))
            quantizer = faiss.IndexFlatIP(dim)
            index = faiss.IndexIVFFlat(quantizer, dim, nlist, faiss.METRIC_INNER_PRODUCT)
            index.train(matrix)
            index.nprobe = 10
        else:
            index = faiss.IndexFlatIP(dim)
        index.add(matrix)
        indices[mod] = {
            "index": index,
            "valid_indices_map": row_ids,
            "metadata": [row_metadata[row_idx] for row_idx in row_ids],
        }
        logger.info(f"FAISS {mod}: {n_vectors} vectors, dim={dim}")
    return indices
def search_vector_db(query_text, query_modality, model, tokenizer, indices, config, k=10):
    """Search the vector DB with any modality query.

    The query is encoded with the ``query_modality`` projection head, for example:
      - NL queries like "oxide with high bandgap": "nl_property_description"
      - composition queries like "Fe2O3": "composition"
      - structure descriptions: "robocrys_rep"
    Every index in ``indices`` is searched for ``k`` neighbours. Hits are merged,
    sorted by score (highest first) and de-duplicated by composition, keeping the
    best-scoring hit per composition. At most ``k`` ``(score, metadata)`` pairs are
    returned; each metadata dict gains a "matched_modality" key.
    """
    model.eval()
    autocast_dtype = torch.bfloat16 if config.use_bf16 else (torch.float16 if config.use_fp16 else torch.float32)
    use_amp = config.use_bf16 or config.use_fp16
    encoded = tokenizer(
        [query_text], padding=True, truncation=True,
        max_length=config.max_length, return_tensors="pt",
    )
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp):
        query_vec = model.encode(
            encoded["input_ids"].to(config.device),
            encoded["attention_mask"].to(config.device),
            query_modality,
        ).float().cpu().numpy().astype(np.float32)
    faiss.normalize_L2(query_vec)

    candidates = []
    for mod_name, entry in indices.items():
        scores, ids = entry["index"].search(query_vec, k)
        meta_list = entry["metadata"]
        for score, row in zip(scores[0], ids[0]):
            # FAISS pads with id -1 when fewer than k vectors are available.
            if 0 <= row < len(meta_list):
                hit = dict(meta_list[row])
                hit["matched_modality"] = mod_name
                candidates.append((float(score), hit))
    candidates.sort(key=lambda pair: pair[0], reverse=True)

    seen = set()
    unique = []
    for score, hit in candidates:
        comp = hit.get("composition", "")
        if comp in seen:
            continue
        seen.add(comp)
        unique.append((score, hit))
        if len(unique) >= k:
            break
    return unique
# ============================================================================
# Main
# ============================================================================
def _run_training_phase(model, loader, optimizer, scheduler, config, *, phase,
                        num_epochs, scaler, use_trackio, global_step, save_path):
    """Run ``num_epochs`` epochs of one training phase, checkpointing the best epoch.

    The best-loss tracker is deliberately local to this call. Phase 1 and
    Phase 2 train on different objectives (Phase 2 adds NL/property pairs), so
    their loss values are not comparable. A threshold shared across phases
    could stop Phase 2 from ever writing a checkpoint.

    Args:
        model, loader, optimizer, scheduler, config: forwarded to ``train_epoch``.
        phase: phase number, forwarded to ``train_epoch`` and used in log lines.
        num_epochs: number of epochs to run.
        scaler: ``torch.amp.GradScaler`` when training in fp16, else ``None``.
        use_trackio: forwarded to ``train_epoch``.
        global_step: step counter carried over from any previous phase.
        save_path: file that receives ``model.state_dict()`` whenever an epoch
            reaches a new lowest loss for this phase.

    Returns:
        ``(global_step, saved)``: the updated step counter, and whether a
        checkpoint was written to ``save_path`` during this call.
    """
    best_loss = float('inf')
    saved = False
    for epoch in range(1, num_epochs + 1):
        t0 = time.time()
        loss, global_step = train_epoch(
            model, loader, optimizer, scheduler, config,
            epoch, phase=phase, scaler=scaler, use_trackio=use_trackio, global_step=global_step,
        )
        elapsed = time.time() - t0
        logger.info(f"Phase{phase} Epoch {epoch}/{num_epochs} | Loss: {loss:.4f} | Time: {elapsed:.0f}s ({elapsed/60:.1f}min)")
        # A NaN loss compares False here, so a diverged epoch is never saved.
        if loss < best_loss:
            best_loss = loss
            torch.save(model.state_dict(), save_path)
            saved = True
            logger.info(f" → New best model saved (loss={loss:.4f})")
    return global_step, saved


def _build_phase2_dataset(config):
    """Load every configured finetune split and combine them into one Phase 2 dataset.

    Splits that fail to load are logged and skipped. Each loaded split is
    capped at ``config.max_finetune_samples // <number of loaded splits>``
    samples, taken from a seed-42 shuffle.

    Returns:
        A ``torch.utils.data.ConcatDataset`` over the per-property
        ``MatTextPhase2Dataset`` objects, or ``None`` if no split could be loaded.
    """
    finetune_datasets = []
    for ft_cfg, ft_split, prop_name in config.finetune_configs:
        try:
            ft = load_dataset(config.dataset_name, ft_cfg, split=ft_split)
        except Exception as e:
            # Best-effort: one missing/broken config must not abort the whole run.
            logger.warning(f"Failed to load {ft_cfg}/{ft_split}: {e}")
            continue
        logger.info(f"Loaded {ft_cfg}/{ft_split}: {len(ft)} samples")
        finetune_datasets.append((ft, prop_name))
    if not finetune_datasets:
        return None

    per_split_cap = config.max_finetune_samples // len(finetune_datasets)
    parts = []
    for ft_data, prop_name in finetune_datasets:
        if len(ft_data) > per_split_cap:
            ft_data = ft_data.shuffle(seed=42).select(range(per_split_cap))
        phase2_ds = MatTextPhase2Dataset(
            ft_data, config.modalities, "labels", prop_name,
            nl_descriptions_per_sample=config.nl_descriptions_per_sample,
        )
        parts.append(phase2_ds)
        logger.info(f"Phase2 dataset ({prop_name}): {len(phase2_ds)} samples")
    return torch.utils.data.ConcatDataset(parts)


def main():
    """Run the full MatText embedding pipeline end to end.

    Steps:
      1. Phase 1: multi-modal contrastive alignment on the pretrain split.
      2. Phase 2: property-conditioned + NL-query alignment on the finetune
         configs (skipped if none load).
      3. Evaluation: load this run's best checkpoint (Phase 2 preferred),
         compute retrieval metrics, build FAISS indices, and run NL queries.
      4. Export: weights, tokenizer, config, and results to
         ``config.output_dir``, optionally pushed to the Hugging Face Hub.
    """
    config = Config()
    try:
        from flash_attn import flash_attn_func  # noqa: F401  (availability probe only)
        config.use_flash_attn = True
        logger.info("Flash Attention 2 available — enabling")
    except ImportError:
        config.use_flash_attn = False
        logger.info("Flash Attention 2 not available — using default attention")
    logger.info(f"Device: {config.device}")
    logger.info(f"Precision: {'bf16' if config.use_bf16 else 'fp16' if config.use_fp16 else 'fp32'}")
    logger.info(f"Max length: {config.max_length}")
    logger.info(f"Batch: {config.batch_size} × {config.grad_accum_steps} = {config.batch_size * config.grad_accum_steps} effective")
    logger.info(f"Encoder: {config.encoder_name}")

    use_trackio = False
    try:
        import trackio
        trackio.init(project="mattext-embeddings", name=f"align-v2-{config.max_length}ctx")
        use_trackio = True
        logger.info("Trackio initialized")
    except Exception as e:
        # Experiment tracking is optional; training proceeds without it.
        logger.warning(f"Trackio init failed: {e}")

    tokenizer = AutoTokenizer.from_pretrained(config.encoder_name)
    model = MatTextEncoder(config).to(config.device)
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Total params: {total_params:,} | Trainable: {trainable_params:,}")

    def make_collate(mods):
        """Return a DataLoader collate_fn bound to this tokenizer, ``mods`` and max_length."""
        return lambda batch: collate_fn(batch, tokenizer, mods, config.max_length)

    os.makedirs(config.output_dir, exist_ok=True)
    phase1_ckpt = f"{config.output_dir}/best_model_phase1.pt"
    phase2_ckpt = f"{config.output_dir}/best_model.pt"

    # ---------------------------------------------------------------- Phase 1
    logger.info("=" * 70 + "\nPHASE 1: Multi-modal alignment on pretrain100k_v2\n" + "=" * 70)
    pretrain_data = load_dataset(config.dataset_name, config.pretrain_config, split="train")
    logger.info(f"Pretrain loaded: {len(pretrain_data)} samples, cols: {pretrain_data.column_names}")
    if len(pretrain_data) > config.max_pretrain_samples:
        pretrain_data = pretrain_data.shuffle(seed=42).select(range(config.max_pretrain_samples))
        logger.info(f"Subsampled to {len(pretrain_data)}")
    phase1_dataset = MatTextPhase1Dataset(pretrain_data, config.modalities)
    phase1_loader = DataLoader(
        phase1_dataset, batch_size=config.batch_size, shuffle=True, drop_last=True,
        num_workers=2, collate_fn=make_collate(config.modalities),
        pin_memory=(config.device == "cuda"), prefetch_factor=2,
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    phase1_steps = len(phase1_loader) * config.num_epochs_phase1 // config.grad_accum_steps
    scheduler = get_cosine_schedule_with_warmup(optimizer, int(phase1_steps * config.warmup_ratio), phase1_steps)
    # Loss scaling is only needed for fp16; bf16/fp32 run without a scaler.
    scaler = torch.amp.GradScaler('cuda') if config.use_fp16 else None
    global_step, phase1_saved = _run_training_phase(
        model, phase1_loader, optimizer, scheduler, config,
        phase=1, num_epochs=config.num_epochs_phase1, scaler=scaler,
        use_trackio=use_trackio, global_step=0, save_path=phase1_ckpt,
    )
    # Release Phase 1 data and optimizer state (AdamW keeps two moment buffers
    # per parameter) before Phase 2 allocates its own optimizer.
    del pretrain_data, phase1_dataset, phase1_loader, optimizer, scheduler
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # ---------------------------------------------------------------- Phase 2
    logger.info("=" * 70 + "\nPHASE 2: Property-conditioned alignment + NL query training\n" + "=" * 70)
    phase2_saved = False
    combined_phase2 = _build_phase2_dataset(config)
    if combined_phase2 is not None:
        phase2_mod_keys = list(config.modalities) + [config.nl_query_modality, "property_text"]
        phase2_loader = DataLoader(
            combined_phase2, batch_size=config.batch_size, shuffle=True, drop_last=True,
            num_workers=2, collate_fn=make_collate(phase2_mod_keys),
            pin_memory=(config.device == "cuda"), prefetch_factor=2,
        )
        # Fresh optimizer at half the Phase 1 learning rate.
        optimizer2 = torch.optim.AdamW(
            model.parameters(), lr=config.learning_rate * 0.5, weight_decay=config.weight_decay,
        )
        phase2_steps = len(phase2_loader) * config.num_epochs_phase2 // config.grad_accum_steps
        scheduler2 = get_cosine_schedule_with_warmup(optimizer2, int(phase2_steps * config.warmup_ratio), phase2_steps)
        global_step, phase2_saved = _run_training_phase(
            model, phase2_loader, optimizer2, scheduler2, config,
            phase=2, num_epochs=config.num_epochs_phase2, scaler=scaler,
            use_trackio=use_trackio, global_step=global_step, save_path=phase2_ckpt,
        )
        del combined_phase2, phase2_loader, optimizer2, scheduler2
    else:
        logger.warning("No finetune data loaded — skipping Phase 2")

    # ------------------------------------------------------------- Evaluation
    logger.info("=" * 70 + "\nEVALUATION\n" + "=" * 70)
    # Choose among checkpoints *this run* wrote (Phase 2 preferred). Checking
    # os.path.exists instead could pick up a stale file from an earlier run.
    if phase2_saved:
        best_path = phase2_ckpt
    elif phase1_saved:
        best_path = phase1_ckpt
    else:
        best_path = None
    if best_path is not None:
        model.load_state_dict(torch.load(best_path, map_location=config.device, weights_only=True))
        logger.info(f"Loaded best model from {best_path}")

    eval_data = load_dataset(config.dataset_name, config.pretrain_config, split="test")
    if len(eval_data) > 5000:
        eval_data = eval_data.shuffle(seed=42).select(range(5000))
    logger.info(f"Eval data: {len(eval_data)} samples")
    eval_dataset = MatTextPhase1Dataset(eval_data, config.modalities)
    eval_loader = DataLoader(
        eval_dataset, batch_size=config.batch_size, shuffle=False,
        num_workers=2, collate_fn=make_collate(config.modalities),
    )
    retrieval_results = evaluate_retrieval(model, eval_loader, config)

    logger.info("\nBuilding FAISS vector database...")
    db_indices = build_vector_database(
        model, eval_dataset, tokenizer, config,
        modalities_to_index=["composition", "crystal_text_llm", "slices", "cif_symmetrized", "robocrys_rep"],
    )
    faiss_dir = f"{config.output_dir}/faiss"
    os.makedirs(faiss_dir, exist_ok=True)
    for mod, d in db_indices.items():
        faiss.write_index(d["index"], f"{faiss_dir}/{mod}.index")
        with open(f"{faiss_dir}/{mod}_metadata.json", "w") as f:
            json.dump(d["metadata"], f)

    logger.info("\n" + "=" * 70 + "\nNATURAL LANGUAGE QUERY EVALUATION\n" + "=" * 70)
    nl_results = evaluate_nl_queries(model, tokenizer, db_indices, config)

    # ------------------------------------------------------------------ Save
    logger.info("\nSaving model and artifacts...")
    torch.save(model.state_dict(), f"{config.output_dir}/model.pt")
    tokenizer.save_pretrained(config.output_dir)
    model_config = model.get_config_dict()
    model_config["training"] = {
        "num_epochs_phase1": config.num_epochs_phase1,
        "num_epochs_phase2": config.num_epochs_phase2,
        "batch_size": config.batch_size,
        "grad_accum_steps": config.grad_accum_steps,
        "learning_rate": config.learning_rate,
        "max_length": config.max_length,
        "nl_descriptions_per_sample": config.nl_descriptions_per_sample,
    }
    with open(f"{config.output_dir}/config.json", "w") as f:
        json.dump(model_config, f, indent=2)
    with open(f"{config.output_dir}/retrieval_results.json", "w") as f:
        json.dump(retrieval_results, f, indent=2)
    nl_results_serializable = {
        k: {
            "modality": v["modality"],
            "top_hits": [(s, m) for s, m in v["top_hits"]],
        }
        for k, v in nl_results.items()
    }
    with open(f"{config.output_dir}/nl_query_results.json", "w") as f:
        json.dump(nl_results_serializable, f, indent=2)

    if config.push_to_hub:
        try:
            api = HfApi()
            api.create_repo(config.hub_model_id, exist_ok=True)
            api.upload_folder(
                folder_path=config.output_dir,
                repo_id=config.hub_model_id,
                commit_message="Upload MatText aligned embeddings v2 (1024 ctx, NL queries)",
            )
            logger.info(f"✓ Pushed to https://huggingface.co/{config.hub_model_id}")
        except Exception as e:
            # Artifacts are already on disk; a failed upload should not lose them.
            logger.error(f"Push failed: {e}")

    if use_trackio:
        try:
            trackio.finish()
        except Exception as e:
            logger.warning(f"Trackio finish failed: {e}")

    logger.info("\n" + "=" * 70)
    logger.info("TRAINING COMPLETE")
    logger.info(f"Model: {config.output_dir}/model.pt")
    logger.info(f"FAISS: {faiss_dir}/")
    if config.push_to_hub:
        logger.info(f"Hub: https://huggingface.co/{config.hub_model_id}")
    logger.info("=" * 70)
if __name__ == "__main__":
    # Script entry point: run the training pipeline only when executed
    # directly, never as a side effect of importing this module.
    main()