| """ |
| MatText Multi-Modal Embedding Alignment Training (v2) |
| |
| Architecture: CLIP-style contrastive learning across 10+ material text representations |
| + LaCLIP-style natural language property descriptions for free-form querying |
| |
| Key upgrades from v1: |
| - 1024 token context (was 512) — captures long CIFs |
| - Natural language property query support ("oxide with high bandgap") |
| - LaCLIP-style diverse NL description generation from structured labels |
| - A100 80GB optimized (bf16, larger batches, more modalities/step) |
| - Flash Attention 2 when available |
| - Phase 2 aligns NL descriptions ↔ all structure modalities |
| |
| Based on: |
| - MultiMat (AllPairsCLIP, arxiv:2312.00111) |
| - MatExpert (property↔structure InfoNCE, arxiv:2410.21317) |
| - LaCLIP (LLM text augmentation, arxiv:2305.20088) |
| - SupReMix (property-label-aware soft contrastive, arxiv:2309.16633) |
| |
| Usage: |
| pip install torch transformers datasets faiss-cpu huggingface_hub trackio accelerate |
| python train_mattext_embeddings.py |
| """ |
|
|
| import os |
| import json |
| import math |
| import time |
| import logging |
| import random |
| import re |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset, DataLoader |
| from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup |
from datasets import load_dataset
| from huggingface_hub import HfApi |
| import faiss |
|
|
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') |
| logger = logging.getLogger(__name__) |
|
|
| |
| |
| |
|
|
| class Config: |
| |
| encoder_name = "answerdotai/ModernBERT-base" |
| embed_dim = 128 |
| max_length = 1024 |
| |
| |
| modalities = [ |
| "composition", |
| "atom_sequences", |
| "cif_symmetrized", |
| "cif_p1", |
| "zmatrix", |
| "atom_sequences_plusplus", |
| "slices", |
| "crystal_text_llm", |
| "local_env", |
| "robocrys_rep", |
| ] |
| |
| |
| |
| nl_query_modality = "nl_property_description" |
| |
| |
| batch_size = 48 |
| learning_rate = 2e-5 |
| weight_decay = 0.01 |
| num_epochs_phase1 = 3 |
| num_epochs_phase2 = 3 |
| warmup_ratio = 0.1 |
| temperature = 0.07 |
| grad_accum_steps = 6 |
| max_grad_norm = 1.0 |
| gradient_checkpointing = True |
| max_modalities_per_step = 5 |
| |
| |
| dataset_name = "n0w0f/MatText" |
| pretrain_config = "pretrain100k_v2" |
| finetune_configs = [ |
| ("bandgap-train-filtered", "fold_0", "bandgap"), |
| ("form_energy-train-filtered", "fold_0", "formation_energy"), |
| ] |
| max_pretrain_samples = 60000 |
| max_finetune_samples = 60000 |
| |
| |
| nl_descriptions_per_sample = 3 |
| |
| |
| output_dir = "mattext-embeddings" |
| hub_model_id = "n0w0f/mattext-aligned-embeddings" |
| push_to_hub = True |
| |
| |
| device = "cuda" if torch.cuda.is_available() else "cpu" |
| use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 |
| use_fp16 = torch.cuda.is_available() and not use_bf16 |
| use_flash_attn = False |
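
    # Example: adapting this config to a smaller GPU (illustrative values, not the
    # settings used for the A100 run described in the module docstring):
    #
    #   Config.batch_size = 16
    #   Config.grad_accum_steps = 18   # keeps the effective batch near 48 * 6 = 288
    #   Config.max_length = 512        # shorter context if long CIFs are not needed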
|
|
|
|
| |
| |
| |
|
|
| class NLPropertyDescriptionGenerator: |
| """ |
| Generates diverse natural language descriptions from structured material properties. |
| This bridges the gap between structured labels (bandgap=3.2) and free-form queries |
| ("oxide with high bandgap"). LaCLIP-inspired: multiple paraphrases per sample. |
| """ |
| |
| BANDGAP_QUALIFIERS = { |
| (0, 0.01): "zero", |
| (0.01, 0.5): "very narrow", |
| (0.5, 1.5): "narrow", |
| (1.5, 3.0): "moderate", |
| (3.0, 5.0): "wide", |
| (5.0, 100): "very wide", |
| } |
| |
| FENERGY_QUALIFIERS = { |
| (-100, -3.0): "very stable", |
| (-3.0, -1.5): "stable", |
| (-1.5, -0.5): "moderately stable", |
| (-0.5, 0.0): "marginally stable", |
| (0.0, 1.0): "metastable", |
| (1.0, 100): "unstable", |
| } |
| |
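    # Heuristic anion detection from the formula string: patterns are tried in
    # order and the first match wins, so the halide patterns are listed before the
    # carbide pattern. This is a rough tag for text generation, not a parser.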
| ANION_PATTERNS = [ |
| (r'O\d*$|O\d+[A-Z]', "oxide"), |
| (r'S\d*$|S\d+[A-Z]', "sulfide"), |
| (r'N\d*$|N\d+[A-Z]', "nitride"), |
| (r'F\d*$|F\d+[A-Z]', "fluoride"), |
| (r'Cl\d*$|Cl\d+[A-Z]', "chloride"), |
| (r'Br\d*$|Br\d+[A-Z]', "bromide"), |
| (r'I\d*$|I\d+[A-Z]', "iodide"), |
| (r'Se\d*$|Se\d+[A-Z]', "selenide"), |
| (r'Te\d*$|Te\d+[A-Z]', "telluride"), |
| (r'C\d*$|C\d+[A-Z]', "carbide"), |
| (r'H\d*$|H\d+[A-Z]', "hydride"), |
| ] |
| |
| ELEMENT_COUNT_NAMES = { |
| 1: "elemental", 2: "binary", 3: "ternary", 4: "quaternary", 5: "quinary", |
| } |
| |
| @classmethod |
| def _qualify_bandgap(cls, bg): |
| for (lo, hi), qual in cls.BANDGAP_QUALIFIERS.items(): |
| if lo <= bg < hi: |
| return qual |
| return "moderate" |
| |
| @classmethod |
| def _qualify_fenergy(cls, fe): |
| for (lo, hi), qual in cls.FENERGY_QUALIFIERS.items(): |
| if lo <= fe < hi: |
| return qual |
| return "moderately stable" |
| |
| @classmethod |
| def _detect_anion(cls, composition): |
| for pattern, name in cls.ANION_PATTERNS: |
| if re.search(pattern, composition): |
| return name |
| return "compound" |
| |
| @classmethod |
| def _count_elements(cls, composition): |
| elements = re.findall(r'[A-Z][a-z]?', composition) |
| return len(set(elements)) |
| |
| @classmethod |
| def _get_elements(cls, composition): |
| return list(set(re.findall(r'[A-Z][a-z]?', composition))) |
| |
| @classmethod |
| def generate_descriptions(cls, composition, property_name=None, property_value=None, |
| crystal_system=None, n=3): |
| """Generate n diverse NL descriptions for a material.""" |
| anion_type = cls._detect_anion(composition) |
| n_elements = cls._count_elements(composition) |
| complexity = cls.ELEMENT_COUNT_NAMES.get(n_elements, "complex") |
| |
| property_templates = [] |
| if property_name == "bandgap" and property_value is not None: |
| qual = cls._qualify_bandgap(property_value) |
| property_templates.extend([ |
| f"A {anion_type} material with {qual} bandgap of {property_value:.2f} eV.", |
| f"{composition} is a {complexity} {anion_type} with a {qual} electronic band gap ({property_value:.2f} eV).", |
| f"This {anion_type} has a bandgap of {property_value:.2f} eV, classified as {qual}.", |
| f"A {qual} bandgap {anion_type} ({property_value:.1f} eV) with composition {composition}.", |
| f"{composition}: {anion_type} semiconductor with {qual} band gap of {property_value:.2f} electron volts.", |
| f"An {anion_type} with {qual} bandgap around {property_value:.1f} eV, formula {composition}.", |
| f"This {complexity} {anion_type} ({composition}) exhibits a {qual} bandgap of approximately {property_value:.2f} eV.", |
| f"Material {composition} is a {qual}-gap {anion_type} with bandgap {property_value:.2f} eV.", |
| ]) |
| if property_value > 3.0: |
| property_templates.append( |
| f"{composition} is a wide-gap {anion_type} suitable for UV applications, bandgap {property_value:.2f} eV." |
| ) |
            if 0.01 < property_value < 1.0:
                property_templates.append(
                    f"{composition} is a narrow-gap {anion_type}, potentially useful for infrared applications, bandgap {property_value:.2f} eV."
                )
            if property_value < 0.01:
                property_templates.append(
                    f"{composition} is a metallic or near-zero-gap {anion_type} with bandgap {property_value:.3f} eV."
                )
| |
| elif property_name == "formation_energy" and property_value is not None: |
| qual = cls._qualify_fenergy(property_value) |
| property_templates.extend([ |
| f"A {qual} {anion_type} with formation energy of {property_value:.3f} eV/atom.", |
| f"{composition} is a {complexity} {anion_type} that is {qual} with formation energy {property_value:.3f} eV/atom.", |
| f"This {anion_type} ({composition}) has a formation energy of {property_value:.3f} eV/atom, making it {qual}.", |
| f"A {qual} {complexity} {anion_type}: {composition}, formation energy = {property_value:.3f} eV/atom.", |
| f"{composition}: thermodynamically {qual} {anion_type} (formation energy {property_value:.3f} eV/atom).", |
| f"This material ({composition}) is a {qual} {anion_type} compound with Ef = {property_value:.3f} eV/atom.", |
| f"A {anion_type} with composition {composition} showing {qual} thermodynamic stability ({property_value:.3f} eV/atom).", |
| ]) |
| |
| composition_templates = [ |
| f"A {complexity} {anion_type} with formula {composition}.", |
| f"{composition} is a {complexity} {anion_type} compound.", |
| f"This material has composition {composition}, a {complexity} {anion_type}.", |
| f"A {anion_type} material: {composition} ({n_elements} elements).", |
| ] |
| if crystal_system: |
| composition_templates.extend([ |
| f"{composition} is a {crystal_system} {anion_type}.", |
| f"A {crystal_system} structured {complexity} {anion_type}: {composition}.", |
| ]) |
| |
| combined_templates = [] |
| if property_name and property_value is not None: |
| if property_name == "bandgap": |
| qual = cls._qualify_bandgap(property_value) |
| combined_templates.extend([ |
| f"{composition} is a {complexity} {anion_type} with {qual} bandgap of {property_value:.2f} eV.", |
| f"A {qual} bandgap {complexity} {anion_type} material, {composition}, with band gap {property_value:.1f} eV.", |
| ]) |
| elif property_name == "formation_energy": |
| qual = cls._qualify_fenergy(property_value) |
| combined_templates.extend([ |
| f"{composition} is a {qual} {complexity} {anion_type} with formation energy {property_value:.3f} eV/atom.", |
| f"A {qual} {anion_type}, {composition}, with Ef = {property_value:.3f} eV/atom.", |
| ]) |
| |
| all_templates = property_templates + composition_templates + combined_templates |
| if not all_templates: |
| all_templates = composition_templates |
| |
| if len(all_templates) >= n: |
| descriptions = random.sample(all_templates, n) |
| else: |
| descriptions = all_templates + random.choices(all_templates, k=n - len(all_templates)) |
| |
| return descriptions |
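

# Usage sketch for the generator (output wording varies because templates are
# sampled; the values below are illustrative):
#
#   NLPropertyDescriptionGenerator.generate_descriptions(
#       "TiO2", property_name="bandgap", property_value=3.2, n=2)
#   # -> e.g. ["This oxide has a bandgap of 3.20 eV, classified as wide.",
#   #          "A binary oxide with formula TiO2."]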
|
|
|
|
| |
| |
| |
|
|
| class ModalityProjection(nn.Module): |
| """2-layer MLP projection head (MultiMat recipe)""" |
| def __init__(self, input_dim, output_dim): |
| super().__init__() |
| self.net = nn.Sequential( |
| nn.Linear(input_dim, input_dim), |
| nn.GELU(), |
| nn.LayerNorm(input_dim), |
| nn.Linear(input_dim, output_dim), |
| ) |
| |
| def forward(self, x): |
| return F.normalize(self.net(x), dim=-1) |
|
|
|
|
| class MatTextEncoder(nn.Module): |
| """ |
| Shared transformer encoder with per-modality projection heads. |
| Includes an NL query projection head for free-form text queries. |
| """ |
| def __init__(self, config: Config): |
| super().__init__() |
| self.config = config |
| |
| model_kwargs = {} |
| if config.use_flash_attn: |
| model_kwargs["attn_implementation"] = "flash_attention_2" |
| if config.use_bf16: |
| model_kwargs["torch_dtype"] = torch.bfloat16 |
| |
| self.backbone = AutoModel.from_pretrained(config.encoder_name, **model_kwargs) |
| hidden_size = self.backbone.config.hidden_size |
| |
| if config.gradient_checkpointing: |
| self.backbone.gradient_checkpointing_enable() |
| |
| self.projections = nn.ModuleDict({ |
| mod: ModalityProjection(hidden_size, config.embed_dim) |
| for mod in config.modalities |
| }) |
| |
| |
| self.projections[config.nl_query_modality] = ModalityProjection(hidden_size, config.embed_dim) |
| |
| |
| self.projections["property"] = ModalityProjection(hidden_size, config.embed_dim) |
| |
| self.log_temperature = nn.Parameter( |
| torch.tensor(math.log(1.0 / config.temperature)) |
| ) |
| |
| def encode(self, input_ids, attention_mask, modality_name): |
| outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask) |
| mask = attention_mask.unsqueeze(-1).float() |
| hidden = outputs.last_hidden_state |
| pooled = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9) |
| return self.projections[modality_name](pooled) |
| |
| @property |
| def temperature(self): |
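        """Learned inverse temperature (logit scale), exp(log(1/tau)); ~14.3 at init.

        Clamped for stability; cosine similarities are multiplied by this value to
        form logits.
        """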
| return torch.exp(self.log_temperature).clamp(min=0.01, max=100.0) |
| |
| def get_config_dict(self): |
| return { |
| "encoder_name": self.config.encoder_name, |
| "embed_dim": self.config.embed_dim, |
| "max_length": self.config.max_length, |
| "modalities": self.config.modalities, |
| "nl_query_modality": self.config.nl_query_modality, |
| "temperature": self.temperature.item(), |
| } |
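

# Encoding sketch (illustrative; any projection-head name registered above can be
# passed as the modality: an entry of config.modalities, config.nl_query_modality,
# or "property"):
#
#   enc = tokenizer(["Ti2 O4"], truncation=True, max_length=config.max_length,
#                   return_tensors="pt")
#   z = model.encode(enc["input_ids"].to(config.device),
#                    enc["attention_mask"].to(config.device),
#                    "composition")   # shape (1, embed_dim), L2-normalized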
|
|
|
|
| |
| |
| |
|
|
| def symmetric_clip_loss(emb_a, emb_b, temperature): |
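    """Symmetric InfoNCE (CLIP) loss between two aligned embedding batches.

    `temperature` is the learned logit scale (1/tau) exposed by MatTextEncoder, so
    cosine similarities are multiplied by it to form logits.
    """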
| N = emb_a.size(0) |
| if N < 2: |
| return torch.tensor(0.0, device=emb_a.device, requires_grad=True) |
| logits = (emb_a @ emb_b.T) * temperature |
| labels = torch.arange(N, device=emb_a.device) |
| loss_a = F.cross_entropy(logits, labels) |
| loss_b = F.cross_entropy(logits.T, labels) |
| return (loss_a + loss_b) / 2 |
|
|
|
|
| def all_pairs_clip_loss(embeddings_dict, temperature): |
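    """MultiMat-style AllPairsCLIP: mean symmetric CLIP loss over every pair of
    modalities present in `embeddings_dict` (None entries are skipped)."""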
| mods = [k for k, v in embeddings_dict.items() if v is not None] |
| if len(mods) < 2: |
| return torch.tensor(0.0, device=temperature.device, requires_grad=True) |
| |
| total_loss = torch.tensor(0.0, device=temperature.device) |
| n_pairs = 0 |
| |
| for i in range(len(mods)): |
| for j in range(i + 1, len(mods)): |
| total_loss = total_loss + symmetric_clip_loss( |
| embeddings_dict[mods[i]], embeddings_dict[mods[j]], temperature |
| ) |
| n_pairs += 1 |
| |
| return total_loss / max(n_pairs, 1) |
|
|
|
|
| def property_similarity_loss(embeddings, labels, temperature): |
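    """SupReMix-inspired soft alignment: regress pairwise cosine similarity toward
    a label-derived target, 1 - |dy| / max|dy|, with MSE (diagonal masked out).

    `temperature` is accepted for signature symmetry with the other losses but is
    currently unused.
    """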
| N = embeddings.size(0) |
| if N < 2: |
| return torch.tensor(0.0, device=embeddings.device, requires_grad=True) |
| |
| label_diff = torch.abs(labels.unsqueeze(0) - labels.unsqueeze(1)) |
| max_diff = label_diff.max().clamp(min=1e-6) |
| label_sim = 1.0 - (label_diff / max_diff) |
| |
| cos_sim = embeddings @ embeddings.T |
| mask = torch.eye(N, device=embeddings.device).bool() |
| cos_sim = cos_sim.masked_fill(mask, 0) |
| label_sim = label_sim.masked_fill(mask, 0) |
| |
| return F.mse_loss(cos_sim, label_sim) |
|
|
|
|
| |
| |
| |
|
|
| class MatTextPhase1Dataset(Dataset): |
| """Phase 1: Multi-modal alignment on pretrain data (no labels).""" |
| def __init__(self, data, modalities): |
| self.data = data |
| self.modalities = modalities |
| available_cols = set(data.column_names) if hasattr(data, 'column_names') else set(data[0].keys()) |
| self.available_modalities = [m for m in modalities if m in available_cols] |
| logger.info(f"Phase1 modalities: {self.available_modalities}") |
| |
| def __len__(self): |
| return len(self.data) |
| |
| def __getitem__(self, idx): |
| row = self.data[idx] |
| item = {} |
| for mod in self.available_modalities: |
| text = row.get(mod, None) |
| if text and isinstance(text, str) and len(text.strip()) > 0: |
| item[mod] = text.strip() |
| else: |
| item[mod] = None |
| return item |
|
|
|
|
| class MatTextPhase2Dataset(Dataset): |
| """Phase 2: Property-conditioned alignment with LaCLIP-style NL descriptions.""" |
| def __init__(self, data, modalities, property_col, property_name, nl_descriptions_per_sample=3): |
| self.data = data |
| self.modalities = modalities |
| self.property_col = property_col |
| self.property_name = property_name |
| self.nl_descriptions_per_sample = nl_descriptions_per_sample |
| self.nl_gen = NLPropertyDescriptionGenerator() |
| |
| available_cols = set(data.column_names) if hasattr(data, 'column_names') else set(data[0].keys()) |
| self.available_modalities = [m for m in modalities if m in available_cols] |
| self.has_properties = property_col in available_cols |
| |
| logger.info(f"Phase2 modalities: {self.available_modalities}") |
| logger.info(f"Property: {property_name} (col={property_col}, has={self.has_properties})") |
| |
| def __len__(self): |
| return len(self.data) |
| |
| def __getitem__(self, idx): |
| row = self.data[idx] |
| item = {} |
| |
| for mod in self.available_modalities: |
| text = row.get(mod, None) |
| if text and isinstance(text, str) and len(text.strip()) > 0: |
| item[mod] = text.strip() |
| else: |
| item[mod] = None |
| |
| composition = row.get("composition", "unknown") |
| crystal_system = row.get("crystal_system", None) |
| |
| if self.has_properties and row.get(self.property_col) is not None: |
| label_val = float(row[self.property_col]) |
| item["property_label"] = label_val |
| item["property_text"] = f"composition: {composition} | {self.property_name}: {label_val:.4f}" |
| |
| |
| nl_descs = self.nl_gen.generate_descriptions( |
| composition=composition, |
| property_name=self.property_name, |
| property_value=label_val, |
| crystal_system=crystal_system, |
| n=self.nl_descriptions_per_sample, |
| ) |
| item["nl_property_description"] = random.choice(nl_descs) |
| else: |
| item["property_label"] = None |
| item["property_text"] = None |
| item["nl_property_description"] = None |
| |
| return item |
|
|
|
|
| def collate_fn(batch, tokenizer, all_modality_keys, max_length): |
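    """Tokenize every requested modality for a list of dataset items.

    For each key in `all_modality_keys` the result holds either None (no item in
    the batch has that text) or a dict of `input_ids`, `attention_mask`, and a
    boolean `valid_mask`; missing texts are tokenized as "" and flagged invalid.
    Property labels, when present, come back as `property_labels` plus
    `property_labels_mask`.
    """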
| result = {} |
| |
| for mod in all_modality_keys: |
| texts = [item.get(mod) for item in batch] |
| valid_texts = [t for t in texts if t is not None] |
| if len(valid_texts) == 0: |
| result[mod] = None |
| continue |
| |
| texts_clean = [t if t is not None else "" for t in texts] |
| mask_valid = [t is not None for t in texts] |
| |
| encoded = tokenizer( |
| texts_clean, padding=True, truncation=True, |
| max_length=max_length, return_tensors="pt" |
| ) |
| result[mod] = { |
| "input_ids": encoded["input_ids"], |
| "attention_mask": encoded["attention_mask"], |
| "valid_mask": torch.tensor(mask_valid, dtype=torch.bool), |
| } |
| |
| labels = [item.get("property_label") for item in batch] |
| if any(l is not None for l in labels): |
| labels_clean = [l if l is not None else 0.0 for l in labels] |
| labels_mask = [l is not None for l in labels] |
| result["property_labels"] = torch.tensor(labels_clean, dtype=torch.float32) |
| result["property_labels_mask"] = torch.tensor(labels_mask, dtype=torch.bool) |
| else: |
| result["property_labels"] = None |
| result["property_labels_mask"] = None |
| |
| return result |
|
|
|
|
| |
| |
| |
|
|
| def train_epoch(model, dataloader, optimizer, scheduler, config, epoch, phase, |
| scaler=None, use_trackio=False, global_step=0): |
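    """Run one epoch of contrastive alignment (phase 1 or 2).

    Each step samples at most `config.max_modalities_per_step` modalities (keeping
    composition and crystal_text_llm when present), encodes them under autocast,
    and combines losses as CLIP + 0.3 * property + 0.5 * NL-description, with
    gradient accumulation over `config.grad_accum_steps` micro-batches.
    """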
| model.train() |
| total_loss = 0.0 |
| total_clip_loss = 0.0 |
| total_prop_loss = 0.0 |
| total_nl_loss = 0.0 |
| log_interval = 20 |
| |
| autocast_dtype = torch.bfloat16 if config.use_bf16 else (torch.float16 if config.use_fp16 else torch.float32) |
| use_amp = config.use_bf16 or config.use_fp16 |
| |
| optimizer.zero_grad() |
| |
| for batch_idx, batch in enumerate(dataloader): |
| step_start = time.time() |
| |
| available_mods = [m for m in config.modalities if batch.get(m) is not None] |
| if len(available_mods) > config.max_modalities_per_step: |
| must_have = [m for m in ["composition", "crystal_text_llm"] if m in available_mods] |
| remaining = [m for m in available_mods if m not in must_have] |
| n_sample = max(config.max_modalities_per_step - len(must_have), 1) |
| sampled = must_have + random.sample(remaining, min(n_sample, len(remaining))) |
| else: |
| sampled = available_mods |
| |
| if phase == 2 and batch.get(config.nl_query_modality) is not None: |
| if config.nl_query_modality not in sampled: |
| sampled.append(config.nl_query_modality) |
| |
| embeddings = {} |
| with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp): |
| for mod in sampled: |
| if batch.get(mod) is None: |
| embeddings[mod] = None |
| continue |
| |
| input_ids = batch[mod]["input_ids"].to(config.device) |
| attention_mask = batch[mod]["attention_mask"].to(config.device) |
| valid_mask = batch[mod]["valid_mask"] |
| |
| if not valid_mask.any(): |
| embeddings[mod] = None |
| continue |
| |
| emb = model.encode(input_ids, attention_mask, mod) |
| emb = emb * valid_mask.to(config.device).unsqueeze(-1).float() |
| embeddings[mod] = emb |
| |
| with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp): |
| temperature = model.temperature |
| clip_l = all_pairs_clip_loss(embeddings, temperature) |
| |
| prop_l = torch.tensor(0.0, device=config.device) |
| nl_l = torch.tensor(0.0, device=config.device) |
| |
| if phase == 2: |
| if batch.get("property_text") is not None: |
| prop_ids = batch["property_text"]["input_ids"].to(config.device) |
| prop_mask_att = batch["property_text"]["attention_mask"].to(config.device) |
| prop_valid = batch["property_text"]["valid_mask"] |
| |
| if prop_valid.any(): |
| with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp): |
| prop_emb = model.encode(prop_ids, prop_mask_att, "property") |
| |
| labels = batch["property_labels"].to(config.device) |
| labels_mask = batch["property_labels_mask"].to(config.device) |
| |
| if labels_mask.sum() > 1: |
| prop_l = property_similarity_loss( |
| prop_emb[labels_mask], labels[labels_mask], temperature |
| ) |
| |
| for anchor_mod in ["composition", "crystal_text_llm"]: |
| if embeddings.get(anchor_mod) is not None: |
| valid_both = labels_mask & batch[anchor_mod]["valid_mask"].to(config.device) |
| if valid_both.sum() > 1: |
| with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp): |
| prop_clip = symmetric_clip_loss( |
| prop_emb[valid_both], |
| embeddings[anchor_mod][valid_both], |
| temperature, |
| ) |
| prop_l = prop_l + 0.5 * prop_clip |
| |
| |
| if embeddings.get(config.nl_query_modality) is not None: |
| nl_emb = embeddings[config.nl_query_modality] |
| nl_valid = batch[config.nl_query_modality]["valid_mask"].to(config.device) |
| |
| if nl_valid.sum() > 1: |
| n_nl_pairs = 0 |
| for struct_mod in sampled: |
| if struct_mod in [config.nl_query_modality, "property_text"]: |
| continue |
| if embeddings.get(struct_mod) is None: |
| continue |
| struct_valid = batch[struct_mod]["valid_mask"].to(config.device) |
| valid_both = nl_valid & struct_valid |
| if valid_both.sum() > 1: |
| with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp): |
| nl_struct_loss = symmetric_clip_loss( |
| nl_emb[valid_both], |
| embeddings[struct_mod][valid_both], |
| temperature, |
| ) |
| nl_l = nl_l + nl_struct_loss |
| n_nl_pairs += 1 |
| if n_nl_pairs > 0: |
| nl_l = nl_l / n_nl_pairs |
| |
| loss = (clip_l + 0.3 * prop_l + 0.5 * nl_l) / config.grad_accum_steps |
| |
| if scaler is not None: |
| scaler.scale(loss).backward() |
| else: |
| loss.backward() |
| |
| if (batch_idx + 1) % config.grad_accum_steps == 0: |
| if scaler is not None: |
| scaler.unscale_(optimizer) |
| torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) |
| scaler.step(optimizer) |
| scaler.update() |
| else: |
| torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) |
| optimizer.step() |
| scheduler.step() |
| optimizer.zero_grad() |
| global_step += 1 |
| |
| total_loss += loss.item() * config.grad_accum_steps |
| total_clip_loss += clip_l.item() |
| total_prop_loss += prop_l.item() if isinstance(prop_l, torch.Tensor) else prop_l |
| total_nl_loss += nl_l.item() if isinstance(nl_l, torch.Tensor) else nl_l |
| |
| if (batch_idx + 1) % log_interval == 0: |
| avg = total_loss / (batch_idx + 1) |
| avg_clip = total_clip_loss / (batch_idx + 1) |
| avg_prop = total_prop_loss / (batch_idx + 1) |
| avg_nl = total_nl_loss / (batch_idx + 1) |
| lr = scheduler.get_last_lr()[0] |
| step_time = time.time() - step_start |
| |
| logger.info( |
| f"P{phase} E{epoch} | {batch_idx+1}/{len(dataloader)} | " |
| f"Loss: {avg:.4f} | CLIP: {avg_clip:.4f} | Prop: {avg_prop:.4f} | " |
| f"NL: {avg_nl:.4f} | LR: {lr:.2e} | T: {model.temperature.item():.3f} | " |
| f"mods: {len(sampled)} | {step_time:.1f}s/step" |
| ) |
| |
| if use_trackio: |
| try: |
| import trackio |
| trackio.log({ |
| "phase": phase, "epoch": epoch, "step": global_step, |
| "loss": avg, "clip_loss": avg_clip, "prop_loss": avg_prop, |
| "nl_loss": avg_nl, "lr": lr, "temperature": model.temperature.item(), |
| }) |
                except Exception:
                    pass
| |
| return total_loss / max(len(dataloader), 1), global_step |
|
|
|
|
| |
| |
| |
|
|
| @torch.no_grad() |
| def evaluate_retrieval(model, dataloader, config, k_values=[1, 5, 10, 20]): |
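    """Cross-modal retrieval evaluation.

    For selected modality pairs, computes Recall@k where the embedding of the same
    material under the other modality is the single correct match.
    """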
| model.eval() |
| all_embeddings = {mod: [] for mod in config.modalities} |
| |
| autocast_dtype = torch.bfloat16 if config.use_bf16 else (torch.float16 if config.use_fp16 else torch.float32) |
| use_amp = config.use_bf16 or config.use_fp16 |
| |
    for batch in dataloader:
        # Determine the batch size from any modality that is present, so rows stay
        # aligned across modalities even when one modality is missing from a batch.
        n_items = next(
            (batch[m]["valid_mask"].numel() for m in config.modalities if batch.get(m) is not None),
            0,
        )
        for mod in config.modalities:
            if batch.get(mod) is None:
                all_embeddings[mod].extend([None] * n_items)
                continue
            input_ids = batch[mod]["input_ids"].to(config.device)
            attention_mask = batch[mod]["attention_mask"].to(config.device)
            valid_mask = batch[mod]["valid_mask"]
            if not valid_mask.any():
                all_embeddings[mod].extend([None] * n_items)
                continue

            with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp):
                emb = model.encode(input_ids, attention_mask, mod).float().cpu()

            for i in range(len(emb)):
                all_embeddings[mod].append(emb[i] if valid_mask[i] else None)
| |
| results = {} |
| eval_pairs = [ |
| ("composition", "crystal_text_llm"), |
| ("composition", "cif_symmetrized"), |
| ("composition", "slices"), |
| ("slices", "crystal_text_llm"), |
| ("composition", "zmatrix"), |
| ("composition", "atom_sequences_plusplus"), |
| ("local_env", "composition"), |
| ] |
| if len([e for e in all_embeddings.get("robocrys_rep", []) if e is not None]) > 0: |
| eval_pairs.extend([ |
| ("robocrys_rep", "composition"), |
| ("robocrys_rep", "cif_symmetrized"), |
| ("robocrys_rep", "slices"), |
| ]) |
| |
| for mod_a, mod_b in eval_pairs: |
| embs_a = all_embeddings.get(mod_a, []) |
| embs_b = all_embeddings.get(mod_b, []) |
| if not embs_a or not embs_b: |
| continue |
| |
| valid_idx = [i for i in range(min(len(embs_a), len(embs_b))) |
| if embs_a[i] is not None and embs_b[i] is not None] |
| if len(valid_idx) < 10: |
| continue |
| |
| ea = torch.stack([embs_a[i] for i in valid_idx]) |
| eb = torch.stack([embs_b[i] for i in valid_idx]) |
| sim = ea @ eb.T |
| |
| recalls = {} |
| for k in k_values: |
| kk = min(k, len(valid_idx) - 1) |
| if kk < 1: |
| continue |
| topk = sim.topk(kk, dim=1).indices |
| correct = (topk == torch.arange(len(valid_idx)).unsqueeze(1)).any(dim=1) |
| recalls[f"R@{k}"] = correct.float().mean().item() |
| |
| results[f"{mod_a}→{mod_b}"] = recalls |
| logger.info(f" {mod_a}→{mod_b}: {recalls}") |
| |
| return results |
|
|
|
|
| @torch.no_grad() |
| def evaluate_nl_queries(model, tokenizer, indices, config): |
| model.eval() |
| |
| test_queries = [ |
| ("oxide with high bandgap", config.nl_query_modality), |
| ("narrow bandgap semiconductor", config.nl_query_modality), |
| ("stable binary oxide", config.nl_query_modality), |
| ("wide bandgap fluoride", config.nl_query_modality), |
| ("ternary sulfide with low formation energy", config.nl_query_modality), |
| ("metallic nitride", config.nl_query_modality), |
| ("Fe2O3", "composition"), |
| ("SiO2", "composition"), |
| ("TiO2", "composition"), |
| ("GaN", "composition"), |
| ("perovskite structure with octahedral coordination", "robocrys_rep"), |
| ("cubic crystal with face-centered lattice", "robocrys_rep"), |
| ] |
| |
| results = {} |
| for query_text, query_modality in test_queries: |
| try: |
| hits = search_vector_db(query_text, query_modality, model, tokenizer, indices, config, k=5) |
| results[query_text] = { |
| "modality": query_modality, |
| "top_hits": [(s, m) for s, m in hits], |
| } |
| logger.info(f"\nQuery: '{query_text}' (via {query_modality})") |
| for rank, (score, meta) in enumerate(hits[:5], 1): |
| logger.info(f" #{rank}: {score:.4f} | {meta.get('composition', 'N/A')} | " |
| f"via {meta.get('matched_modality', 'N/A')}") |
| except Exception as e: |
| logger.warning(f"Query '{query_text}' failed: {e}") |
| |
| return results |
|
|
|
|
| |
| |
| |
|
|
| def build_vector_database(model, dataset, tokenizer, config, modalities_to_index=None): |
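    """Encode `dataset` with the trained model and build one FAISS inner-product
    index per modality in `modalities_to_index` (IVF above 10k vectors, flat
    otherwise), keeping a metadata list aligned with each index."""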
| if modalities_to_index is None: |
| modalities_to_index = ["composition", "crystal_text_llm", "slices", |
| "cif_symmetrized", "robocrys_rep"] |
| model.eval() |
| |
| autocast_dtype = torch.bfloat16 if config.use_bf16 else (torch.float16 if config.use_fp16 else torch.float32) |
| use_amp = config.use_bf16 or config.use_fp16 |
| |
| all_embeddings = {mod: [] for mod in modalities_to_index} |
| all_metadata = [] |
| bs = 64 |
| |
| for start in range(0, len(dataset), bs): |
| end = min(start + bs, len(dataset)) |
| items = [dataset[i] for i in range(start, end)] |
| |
| for item in items: |
| meta = { |
| "composition": item.get("composition", ""), |
| "property_label": item.get("property_label"), |
| } |
| all_metadata.append(meta) |
| |
| all_mod_keys = list(config.modalities) |
| batch = collate_fn(items, tokenizer, all_mod_keys, config.max_length) |
| |
| with torch.no_grad(): |
| for mod in modalities_to_index: |
| if batch.get(mod) is None: |
| all_embeddings[mod].extend([None] * len(items)) |
| continue |
| with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp): |
| emb = model.encode( |
| batch[mod]["input_ids"].to(config.device), |
| batch[mod]["attention_mask"].to(config.device), |
| mod, |
| ).float().cpu().numpy() |
| for i in range(len(emb)): |
| if batch[mod]["valid_mask"][i]: |
| all_embeddings[mod].append(emb[i]) |
| else: |
| all_embeddings[mod].append(None) |
| |
| if (start // bs) % 20 == 0: |
| logger.info(f"Indexed {end}/{len(dataset)}") |
| |
| indices = {} |
| for mod in modalities_to_index: |
| valid_embs = [e for e in all_embeddings[mod] if e is not None] |
| valid_map = [i for i, e in enumerate(all_embeddings[mod]) if e is not None] |
| if not valid_embs: |
| continue |
| |
| emb_matrix = np.stack(valid_embs).astype(np.float32) |
| faiss.normalize_L2(emb_matrix) |
| d = emb_matrix.shape[1] |
| |
| if len(valid_embs) > 10000: |
| nlist = min(100, int(np.sqrt(len(valid_embs)))) |
| quantizer = faiss.IndexFlatIP(d) |
| index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT) |
| index.train(emb_matrix) |
| index.nprobe = 10 |
| else: |
| index = faiss.IndexFlatIP(d) |
| |
| index.add(emb_matrix) |
| indices[mod] = { |
| "index": index, |
| "valid_indices_map": valid_map, |
| "metadata": [all_metadata[i] for i in valid_map], |
| } |
| logger.info(f"FAISS {mod}: {len(valid_embs)} vectors, dim={d}") |
| |
| return indices |
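

# Reloading a persisted index later (a sketch; paths follow the layout written by
# main() under {output_dir}/faiss/):
#
#   index = faiss.read_index(f"{config.output_dir}/faiss/composition.index")
#   with open(f"{config.output_dir}/faiss/composition_metadata.json") as f:
#       metadata = json.load(f)
#   indices = {"composition": {"index": index, "metadata": metadata,
#                              "valid_indices_map": list(range(index.ntotal))}}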
|
|
|
|
| def search_vector_db(query_text, query_modality, model, tokenizer, indices, config, k=10): |
| """Search the vector DB with any modality query. |
| |
| For NL queries like "oxide with high bandgap": query_modality="nl_property_description" |
| For composition queries like "Fe2O3": query_modality="composition" |
| For structure descriptions: query_modality="robocrys_rep" |
| """ |
| model.eval() |
| |
| autocast_dtype = torch.bfloat16 if config.use_bf16 else (torch.float16 if config.use_fp16 else torch.float32) |
| use_amp = config.use_bf16 or config.use_fp16 |
| |
| enc = tokenizer( |
| [query_text], padding=True, truncation=True, |
| max_length=config.max_length, return_tensors="pt", |
| ) |
| |
| with torch.no_grad(): |
| with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp): |
| q_emb = model.encode( |
| enc["input_ids"].to(config.device), |
| enc["attention_mask"].to(config.device), |
| query_modality, |
| ).float().cpu().numpy().astype(np.float32) |
| |
| faiss.normalize_L2(q_emb) |
| |
| results = [] |
| for mod_name, idx_data in indices.items(): |
| scores, ids = idx_data["index"].search(q_emb, k) |
| for s, i in zip(scores[0], ids[0]): |
| if i >= 0 and i < len(idx_data["metadata"]): |
| m = dict(idx_data["metadata"][i]) |
| m["matched_modality"] = mod_name |
| results.append((float(s), m)) |
| |
| results.sort(key=lambda x: x[0], reverse=True) |
| seen, unique = set(), [] |
| for s, m in results: |
| c = m.get("composition", "") |
| if c not in seen: |
| seen.add(c) |
| unique.append((s, m)) |
| if len(unique) >= k: |
| break |
| return unique |
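

# Example query against the indices returned by build_vector_database (a sketch):
#
#   hits = search_vector_db("wide bandgap oxide", config.nl_query_modality,
#                           model, tokenizer, indices, config, k=5)
#   for score, meta in hits:
#       print(f"{score:.3f}  {meta['composition']}  via {meta['matched_modality']}")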
|
|
|
|
| |
| |
| |
|
|
| def main(): |
| config = Config() |
| |
| try: |
| from flash_attn import flash_attn_func |
| config.use_flash_attn = True |
| logger.info("Flash Attention 2 available — enabling") |
| except ImportError: |
| config.use_flash_attn = False |
| logger.info("Flash Attention 2 not available — using default attention") |
| |
| logger.info(f"Device: {config.device}") |
| logger.info(f"Precision: {'bf16' if config.use_bf16 else 'fp16' if config.use_fp16 else 'fp32'}") |
| logger.info(f"Max length: {config.max_length}") |
| logger.info(f"Batch: {config.batch_size} × {config.grad_accum_steps} = {config.batch_size * config.grad_accum_steps} effective") |
| logger.info(f"Encoder: {config.encoder_name}") |
| |
| use_trackio = False |
| try: |
| import trackio |
| trackio.init(project="mattext-embeddings", name=f"align-v2-{config.max_length}ctx") |
| use_trackio = True |
| logger.info("Trackio initialized") |
| except Exception as e: |
| logger.warning(f"Trackio init failed: {e}") |
| |
| tokenizer = AutoTokenizer.from_pretrained(config.encoder_name) |
| model = MatTextEncoder(config).to(config.device) |
| total_params = sum(p.numel() for p in model.parameters()) |
| trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| logger.info(f"Total params: {total_params:,} | Trainable: {trainable_params:,}") |
| |
| |
| logger.info("=" * 70 + "\nPHASE 1: Multi-modal alignment on pretrain100k_v2\n" + "=" * 70) |
| |
| pretrain_data = load_dataset(config.dataset_name, config.pretrain_config, split="train") |
| logger.info(f"Pretrain loaded: {len(pretrain_data)} samples, cols: {pretrain_data.column_names}") |
| |
| if len(pretrain_data) > config.max_pretrain_samples: |
| pretrain_data = pretrain_data.shuffle(seed=42).select(range(config.max_pretrain_samples)) |
| logger.info(f"Subsampled to {len(pretrain_data)}") |
| |
| phase1_dataset = MatTextPhase1Dataset(pretrain_data, config.modalities) |
| make_collate = lambda mods: lambda batch: collate_fn(batch, tokenizer, mods, config.max_length) |
| |
| phase1_loader = DataLoader( |
| phase1_dataset, batch_size=config.batch_size, shuffle=True, drop_last=True, |
| num_workers=2, collate_fn=make_collate(config.modalities), |
| pin_memory=(config.device == "cuda"), prefetch_factor=2, |
| ) |
| |
| optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay) |
| phase1_steps = len(phase1_loader) * config.num_epochs_phase1 // config.grad_accum_steps |
| scheduler = get_cosine_schedule_with_warmup(optimizer, int(phase1_steps * config.warmup_ratio), phase1_steps) |
| scaler = torch.amp.GradScaler('cuda') if config.use_fp16 else None |
| |
| global_step = 0 |
| best_loss = float('inf') |
| os.makedirs(config.output_dir, exist_ok=True) |
| |
| for epoch in range(1, config.num_epochs_phase1 + 1): |
| t0 = time.time() |
| loss, global_step = train_epoch( |
| model, phase1_loader, optimizer, scheduler, config, |
| epoch, phase=1, scaler=scaler, use_trackio=use_trackio, global_step=global_step, |
| ) |
| elapsed = time.time() - t0 |
| logger.info(f"Phase1 Epoch {epoch}/{config.num_epochs_phase1} | Loss: {loss:.4f} | Time: {elapsed:.0f}s ({elapsed/60:.1f}min)") |
| if loss < best_loss: |
| best_loss = loss |
| torch.save(model.state_dict(), f"{config.output_dir}/best_model_phase1.pt") |
| logger.info(f" → New best model saved (loss={loss:.4f})") |
| |
| del pretrain_data, phase1_dataset, phase1_loader |
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
| |
| |
| logger.info("=" * 70 + "\nPHASE 2: Property-conditioned alignment + NL query training\n" + "=" * 70) |
| |
| finetune_datasets = [] |
| for ft_cfg, ft_split, prop_name in config.finetune_configs: |
| try: |
| ft = load_dataset(config.dataset_name, ft_cfg, split=ft_split) |
| logger.info(f"Loaded {ft_cfg}/{ft_split}: {len(ft)} samples") |
| finetune_datasets.append((ft, prop_name)) |
| except Exception as e: |
| logger.warning(f"Failed to load {ft_cfg}/{ft_split}: {e}") |
| |
| if finetune_datasets: |
| all_phase2_datasets = [] |
| for ft_data, prop_name in finetune_datasets: |
| if len(ft_data) > config.max_finetune_samples // len(finetune_datasets): |
| n = config.max_finetune_samples // len(finetune_datasets) |
| ft_data = ft_data.shuffle(seed=42).select(range(n)) |
| |
| phase2_ds = MatTextPhase2Dataset( |
| ft_data, config.modalities, "labels", prop_name, |
| nl_descriptions_per_sample=config.nl_descriptions_per_sample, |
| ) |
| all_phase2_datasets.append(phase2_ds) |
| logger.info(f"Phase2 dataset ({prop_name}): {len(phase2_ds)} samples") |
| |
| class ConcatPhase2Dataset(Dataset): |
| def __init__(self, datasets): |
| self.datasets = datasets |
| self.lengths = [len(d) for d in datasets] |
| self.total = sum(self.lengths) |
| self.cum_lengths = [] |
| acc = 0 |
| for l in self.lengths: |
| self.cum_lengths.append(acc) |
| acc += l |
| def __len__(self): |
| return self.total |
| def __getitem__(self, idx): |
| for i, (cum, length) in enumerate(zip(self.cum_lengths, self.lengths)): |
| if idx < cum + length: |
| return self.datasets[i][idx - cum] |
| return self.datasets[-1][idx - self.cum_lengths[-1]] |
| |
| combined_phase2 = ConcatPhase2Dataset(all_phase2_datasets) |
| phase2_mod_keys = list(config.modalities) + [config.nl_query_modality, "property_text"] |
| |
| phase2_loader = DataLoader( |
| combined_phase2, batch_size=config.batch_size, shuffle=True, drop_last=True, |
| num_workers=2, |
| collate_fn=lambda batch: collate_fn(batch, tokenizer, phase2_mod_keys, config.max_length), |
| pin_memory=(config.device == "cuda"), prefetch_factor=2, |
| ) |
| |
| optimizer2 = torch.optim.AdamW( |
| model.parameters(), lr=config.learning_rate * 0.5, weight_decay=config.weight_decay, |
| ) |
| phase2_steps = len(phase2_loader) * config.num_epochs_phase2 // config.grad_accum_steps |
| scheduler2 = get_cosine_schedule_with_warmup(optimizer2, int(phase2_steps * config.warmup_ratio), phase2_steps) |
| |
| for epoch in range(1, config.num_epochs_phase2 + 1): |
| t0 = time.time() |
| loss, global_step = train_epoch( |
| model, phase2_loader, optimizer2, scheduler2, config, |
| epoch, phase=2, scaler=scaler, use_trackio=use_trackio, global_step=global_step, |
| ) |
| elapsed = time.time() - t0 |
| logger.info(f"Phase2 Epoch {epoch}/{config.num_epochs_phase2} | Loss: {loss:.4f} | Time: {elapsed:.0f}s ({elapsed/60:.1f}min)") |
| if loss < best_loss: |
| best_loss = loss |
| torch.save(model.state_dict(), f"{config.output_dir}/best_model.pt") |
| logger.info(f" → New best model saved (loss={loss:.4f})") |
| |
| del combined_phase2, phase2_loader |
| else: |
| logger.warning("No finetune data loaded — skipping Phase 2") |
| |
| |
| logger.info("=" * 70 + "\nEVALUATION\n" + "=" * 70) |
| |
| best_path = f"{config.output_dir}/best_model.pt" |
| if not os.path.exists(best_path): |
| best_path = f"{config.output_dir}/best_model_phase1.pt" |
| if os.path.exists(best_path): |
| model.load_state_dict(torch.load(best_path, map_location=config.device)) |
| logger.info(f"Loaded best model from {best_path}") |
| |
| eval_data = load_dataset(config.dataset_name, config.pretrain_config, split="test") |
| if len(eval_data) > 5000: |
| eval_data = eval_data.shuffle(seed=42).select(range(5000)) |
| logger.info(f"Eval data: {len(eval_data)} samples") |
| |
| eval_dataset = MatTextPhase1Dataset(eval_data, config.modalities) |
| eval_loader = DataLoader( |
| eval_dataset, batch_size=config.batch_size, shuffle=False, |
| num_workers=2, collate_fn=make_collate(config.modalities), |
| ) |
| |
| retrieval_results = evaluate_retrieval(model, eval_loader, config) |
| |
| logger.info("\nBuilding FAISS vector database...") |
| db_indices = build_vector_database( |
| model, eval_dataset, tokenizer, config, |
| modalities_to_index=["composition", "crystal_text_llm", "slices", "cif_symmetrized", "robocrys_rep"], |
| ) |
| |
| faiss_dir = f"{config.output_dir}/faiss" |
| os.makedirs(faiss_dir, exist_ok=True) |
| for mod, d in db_indices.items(): |
| faiss.write_index(d["index"], f"{faiss_dir}/{mod}.index") |
| with open(f"{faiss_dir}/{mod}_metadata.json", "w") as f: |
| json.dump(d["metadata"], f) |
| |
| logger.info("\n" + "=" * 70 + "\nNATURAL LANGUAGE QUERY EVALUATION\n" + "=" * 70) |
| nl_results = evaluate_nl_queries(model, tokenizer, db_indices, config) |
| |
| |
| logger.info("\nSaving model and artifacts...") |
| torch.save(model.state_dict(), f"{config.output_dir}/model.pt") |
| tokenizer.save_pretrained(config.output_dir) |
| |
| model_config = model.get_config_dict() |
| model_config["training"] = { |
| "num_epochs_phase1": config.num_epochs_phase1, |
| "num_epochs_phase2": config.num_epochs_phase2, |
| "batch_size": config.batch_size, |
| "grad_accum_steps": config.grad_accum_steps, |
| "learning_rate": config.learning_rate, |
| "max_length": config.max_length, |
| "nl_descriptions_per_sample": config.nl_descriptions_per_sample, |
| } |
| with open(f"{config.output_dir}/config.json", "w") as f: |
| json.dump(model_config, f, indent=2) |
| |
| with open(f"{config.output_dir}/retrieval_results.json", "w") as f: |
| json.dump(retrieval_results, f, indent=2) |
| |
| nl_results_serializable = {} |
| for k, v in nl_results.items(): |
| nl_results_serializable[k] = { |
| "modality": v["modality"], |
| "top_hits": [(s, m) for s, m in v["top_hits"]], |
| } |
| with open(f"{config.output_dir}/nl_query_results.json", "w") as f: |
| json.dump(nl_results_serializable, f, indent=2) |
| |
| if config.push_to_hub: |
| try: |
| api = HfApi() |
| api.create_repo(config.hub_model_id, exist_ok=True) |
| api.upload_folder( |
| folder_path=config.output_dir, |
| repo_id=config.hub_model_id, |
| commit_message=f"Upload MatText aligned embeddings v2 (1024 ctx, NL queries)", |
| ) |
| logger.info(f"✓ Pushed to https://huggingface.co/{config.hub_model_id}") |
| except Exception as e: |
| logger.error(f"Push failed: {e}") |
| |
| logger.info("\n" + "=" * 70) |
| logger.info("TRAINING COMPLETE") |
| logger.info(f"Model: {config.output_dir}/model.pt") |
| logger.info(f"FAISS: {faiss_dir}/") |
| logger.info(f"Hub: https://huggingface.co/{config.hub_model_id}") |
| logger.info("=" * 70) |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|