""" MatText Multi-Modal Embedding Alignment Training (v2) Architecture: CLIP-style contrastive learning across 10+ material text representations + LaCLIP-style natural language property descriptions for free-form querying Key upgrades from v1: - 1024 token context (was 512) — captures long CIFs - Natural language property query support ("oxide with high bandgap") - LaCLIP-style diverse NL description generation from structured labels - A100 80GB optimized (bf16, larger batches, more modalities/step) - Flash Attention 2 when available - Phase 2 aligns NL descriptions ↔ all structure modalities Based on: - MultiMat (AllPairsCLIP, arxiv:2312.00111) - MatExpert (property↔structure InfoNCE, arxiv:2410.21317) - LaCLIP (LLM text augmentation, arxiv:2305.20088) - SupReMix (property-label-aware soft contrastive, arxiv:2309.16633) Usage: pip install torch transformers datasets faiss-cpu huggingface_hub trackio accelerate python train_mattext_embeddings.py """ import os import json import math import time import logging import random import re import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup from datasets import load_dataset, concatenate_datasets from huggingface_hub import HfApi import faiss logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) # ============================================================================ # Configuration # ============================================================================ class Config: # Model encoder_name = "answerdotai/ModernBERT-base" embed_dim = 128 # projection dimension max_length = 1024 # tokens per modality (ModernBERT pretrained at 1024, extended to 8192) # Modalities to align (columns in the dataset) modalities = [ "composition", "atom_sequences", "cif_symmetrized", "cif_p1", 
"zmatrix", "atom_sequences_plusplus", "slices", "crystal_text_llm", "local_env", "robocrys_rep", # natural language structural description (pretrain only) ] # Natural language query modality (separate from robocrys_rep) # This is the key modality for queries like "oxide with high bandgap" nl_query_modality = "nl_property_description" # Training batch_size = 48 # A100 80GB can handle this at 1024 ctx with bf16 learning_rate = 2e-5 weight_decay = 0.01 num_epochs_phase1 = 3 num_epochs_phase2 = 3 warmup_ratio = 0.1 temperature = 0.07 grad_accum_steps = 6 # effective batch = 48*6 = 288 max_grad_norm = 1.0 gradient_checkpointing = True max_modalities_per_step = 5 # more than v1 since A100 80GB # Data dataset_name = "n0w0f/MatText" pretrain_config = "pretrain100k_v2" finetune_configs = [ ("bandgap-train-filtered", "fold_0", "bandgap"), ("form_energy-train-filtered", "fold_0", "formation_energy"), ] max_pretrain_samples = 60000 max_finetune_samples = 60000 # NL description generation nl_descriptions_per_sample = 3 # LaCLIP: diverse paraphrases per sample # Output output_dir = "mattext-embeddings" hub_model_id = "n0w0f/mattext-aligned-embeddings" push_to_hub = True # Device device = "cuda" if torch.cuda.is_available() else "cpu" use_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 use_fp16 = torch.cuda.is_available() and not use_bf16 use_flash_attn = False # set True if flash-attn is installed # ============================================================================ # NL Property Description Generator (LaCLIP-style) # ============================================================================ class NLPropertyDescriptionGenerator: """ Generates diverse natural language descriptions from structured material properties. This bridges the gap between structured labels (bandgap=3.2) and free-form queries ("oxide with high bandgap"). LaCLIP-inspired: multiple paraphrases per sample. 
""" BANDGAP_QUALIFIERS = { (0, 0.01): "zero", (0.01, 0.5): "very narrow", (0.5, 1.5): "narrow", (1.5, 3.0): "moderate", (3.0, 5.0): "wide", (5.0, 100): "very wide", } FENERGY_QUALIFIERS = { (-100, -3.0): "very stable", (-3.0, -1.5): "stable", (-1.5, -0.5): "moderately stable", (-0.5, 0.0): "marginally stable", (0.0, 1.0): "metastable", (1.0, 100): "unstable", } ANION_PATTERNS = [ (r'O\d*$|O\d+[A-Z]', "oxide"), (r'S\d*$|S\d+[A-Z]', "sulfide"), (r'N\d*$|N\d+[A-Z]', "nitride"), (r'F\d*$|F\d+[A-Z]', "fluoride"), (r'Cl\d*$|Cl\d+[A-Z]', "chloride"), (r'Br\d*$|Br\d+[A-Z]', "bromide"), (r'I\d*$|I\d+[A-Z]', "iodide"), (r'Se\d*$|Se\d+[A-Z]', "selenide"), (r'Te\d*$|Te\d+[A-Z]', "telluride"), (r'C\d*$|C\d+[A-Z]', "carbide"), (r'H\d*$|H\d+[A-Z]', "hydride"), ] ELEMENT_COUNT_NAMES = { 1: "elemental", 2: "binary", 3: "ternary", 4: "quaternary", 5: "quinary", } @classmethod def _qualify_bandgap(cls, bg): for (lo, hi), qual in cls.BANDGAP_QUALIFIERS.items(): if lo <= bg < hi: return qual return "moderate" @classmethod def _qualify_fenergy(cls, fe): for (lo, hi), qual in cls.FENERGY_QUALIFIERS.items(): if lo <= fe < hi: return qual return "moderately stable" @classmethod def _detect_anion(cls, composition): for pattern, name in cls.ANION_PATTERNS: if re.search(pattern, composition): return name return "compound" @classmethod def _count_elements(cls, composition): elements = re.findall(r'[A-Z][a-z]?', composition) return len(set(elements)) @classmethod def _get_elements(cls, composition): return list(set(re.findall(r'[A-Z][a-z]?', composition))) @classmethod def generate_descriptions(cls, composition, property_name=None, property_value=None, crystal_system=None, n=3): """Generate n diverse NL descriptions for a material.""" anion_type = cls._detect_anion(composition) n_elements = cls._count_elements(composition) complexity = cls.ELEMENT_COUNT_NAMES.get(n_elements, "complex") property_templates = [] if property_name == "bandgap" and property_value is not None: qual = 
cls._qualify_bandgap(property_value) property_templates.extend([ f"A {anion_type} material with {qual} bandgap of {property_value:.2f} eV.", f"{composition} is a {complexity} {anion_type} with a {qual} electronic band gap ({property_value:.2f} eV).", f"This {anion_type} has a bandgap of {property_value:.2f} eV, classified as {qual}.", f"A {qual} bandgap {anion_type} ({property_value:.1f} eV) with composition {composition}.", f"{composition}: {anion_type} semiconductor with {qual} band gap of {property_value:.2f} electron volts.", f"An {anion_type} with {qual} bandgap around {property_value:.1f} eV, formula {composition}.", f"This {complexity} {anion_type} ({composition}) exhibits a {qual} bandgap of approximately {property_value:.2f} eV.", f"Material {composition} is a {qual}-gap {anion_type} with bandgap {property_value:.2f} eV.", ]) if property_value > 3.0: property_templates.append( f"{composition} is a wide-gap {anion_type} suitable for UV applications, bandgap {property_value:.2f} eV." ) if property_value < 1.0 and property_value > 0.01: property_templates.append( f"{composition} is a narrow-gap {anion_type}, potentially useful for infrared applications, bandgap {property_value:.2f} eV." ) if property_value < 0.01: property_templates.append( f"{composition} is metallic or near-zero gap {anion_type} with bandgap {property_value:.3f} eV." 
) elif property_name == "formation_energy" and property_value is not None: qual = cls._qualify_fenergy(property_value) property_templates.extend([ f"A {qual} {anion_type} with formation energy of {property_value:.3f} eV/atom.", f"{composition} is a {complexity} {anion_type} that is {qual} with formation energy {property_value:.3f} eV/atom.", f"This {anion_type} ({composition}) has a formation energy of {property_value:.3f} eV/atom, making it {qual}.", f"A {qual} {complexity} {anion_type}: {composition}, formation energy = {property_value:.3f} eV/atom.", f"{composition}: thermodynamically {qual} {anion_type} (formation energy {property_value:.3f} eV/atom).", f"This material ({composition}) is a {qual} {anion_type} compound with Ef = {property_value:.3f} eV/atom.", f"A {anion_type} with composition {composition} showing {qual} thermodynamic stability ({property_value:.3f} eV/atom).", ]) composition_templates = [ f"A {complexity} {anion_type} with formula {composition}.", f"{composition} is a {complexity} {anion_type} compound.", f"This material has composition {composition}, a {complexity} {anion_type}.", f"A {anion_type} material: {composition} ({n_elements} elements).", ] if crystal_system: composition_templates.extend([ f"{composition} is a {crystal_system} {anion_type}.", f"A {crystal_system} structured {complexity} {anion_type}: {composition}.", ]) combined_templates = [] if property_name and property_value is not None: if property_name == "bandgap": qual = cls._qualify_bandgap(property_value) combined_templates.extend([ f"{composition} is a {complexity} {anion_type} with {qual} bandgap of {property_value:.2f} eV.", f"A {qual} bandgap {complexity} {anion_type} material, {composition}, with band gap {property_value:.1f} eV.", ]) elif property_name == "formation_energy": qual = cls._qualify_fenergy(property_value) combined_templates.extend([ f"{composition} is a {qual} {complexity} {anion_type} with formation energy {property_value:.3f} eV/atom.", f"A {qual} 
{anion_type}, {composition}, with Ef = {property_value:.3f} eV/atom.", ]) all_templates = property_templates + composition_templates + combined_templates if not all_templates: all_templates = composition_templates if len(all_templates) >= n: descriptions = random.sample(all_templates, n) else: descriptions = all_templates + random.choices(all_templates, k=n - len(all_templates)) return descriptions # ============================================================================ # Model: Shared Encoder + Per-Modality Projection Heads # ============================================================================ class ModalityProjection(nn.Module): """2-layer MLP projection head (MultiMat recipe)""" def __init__(self, input_dim, output_dim): super().__init__() self.net = nn.Sequential( nn.Linear(input_dim, input_dim), nn.GELU(), nn.LayerNorm(input_dim), nn.Linear(input_dim, output_dim), ) def forward(self, x): return F.normalize(self.net(x), dim=-1) class MatTextEncoder(nn.Module): """ Shared transformer encoder with per-modality projection heads. Includes an NL query projection head for free-form text queries. 
""" def __init__(self, config: Config): super().__init__() self.config = config model_kwargs = {} if config.use_flash_attn: model_kwargs["attn_implementation"] = "flash_attention_2" if config.use_bf16: model_kwargs["torch_dtype"] = torch.bfloat16 self.backbone = AutoModel.from_pretrained(config.encoder_name, **model_kwargs) hidden_size = self.backbone.config.hidden_size if config.gradient_checkpointing: self.backbone.gradient_checkpointing_enable() self.projections = nn.ModuleDict({ mod: ModalityProjection(hidden_size, config.embed_dim) for mod in config.modalities }) # NL query head — for "oxide with high bandgap" style queries self.projections[config.nl_query_modality] = ModalityProjection(hidden_size, config.embed_dim) # Property head — for structured property text like "bandgap: 2.1" self.projections["property"] = ModalityProjection(hidden_size, config.embed_dim) self.log_temperature = nn.Parameter( torch.tensor(math.log(1.0 / config.temperature)) ) def encode(self, input_ids, attention_mask, modality_name): outputs = self.backbone(input_ids=input_ids, attention_mask=attention_mask) mask = attention_mask.unsqueeze(-1).float() hidden = outputs.last_hidden_state pooled = (hidden * mask).sum(1) / mask.sum(1).clamp(min=1e-9) return self.projections[modality_name](pooled) @property def temperature(self): return torch.exp(self.log_temperature).clamp(min=0.01, max=100.0) def get_config_dict(self): return { "encoder_name": self.config.encoder_name, "embed_dim": self.config.embed_dim, "max_length": self.config.max_length, "modalities": self.config.modalities, "nl_query_modality": self.config.nl_query_modality, "temperature": self.temperature.item(), } # ============================================================================ # Loss Functions # ============================================================================ def symmetric_clip_loss(emb_a, emb_b, temperature): N = emb_a.size(0) if N < 2: return torch.tensor(0.0, device=emb_a.device, requires_grad=True) 
logits = (emb_a @ emb_b.T) * temperature labels = torch.arange(N, device=emb_a.device) loss_a = F.cross_entropy(logits, labels) loss_b = F.cross_entropy(logits.T, labels) return (loss_a + loss_b) / 2 def all_pairs_clip_loss(embeddings_dict, temperature): mods = [k for k, v in embeddings_dict.items() if v is not None] if len(mods) < 2: return torch.tensor(0.0, device=temperature.device, requires_grad=True) total_loss = torch.tensor(0.0, device=temperature.device) n_pairs = 0 for i in range(len(mods)): for j in range(i + 1, len(mods)): total_loss = total_loss + symmetric_clip_loss( embeddings_dict[mods[i]], embeddings_dict[mods[j]], temperature ) n_pairs += 1 return total_loss / max(n_pairs, 1) def property_similarity_loss(embeddings, labels, temperature): N = embeddings.size(0) if N < 2: return torch.tensor(0.0, device=embeddings.device, requires_grad=True) label_diff = torch.abs(labels.unsqueeze(0) - labels.unsqueeze(1)) max_diff = label_diff.max().clamp(min=1e-6) label_sim = 1.0 - (label_diff / max_diff) cos_sim = embeddings @ embeddings.T mask = torch.eye(N, device=embeddings.device).bool() cos_sim = cos_sim.masked_fill(mask, 0) label_sim = label_sim.masked_fill(mask, 0) return F.mse_loss(cos_sim, label_sim) # ============================================================================ # Dataset # ============================================================================ class MatTextPhase1Dataset(Dataset): """Phase 1: Multi-modal alignment on pretrain data (no labels).""" def __init__(self, data, modalities): self.data = data self.modalities = modalities available_cols = set(data.column_names) if hasattr(data, 'column_names') else set(data[0].keys()) self.available_modalities = [m for m in modalities if m in available_cols] logger.info(f"Phase1 modalities: {self.available_modalities}") def __len__(self): return len(self.data) def __getitem__(self, idx): row = self.data[idx] item = {} for mod in self.available_modalities: text = row.get(mod, None) if text and 
isinstance(text, str) and len(text.strip()) > 0: item[mod] = text.strip() else: item[mod] = None return item class MatTextPhase2Dataset(Dataset): """Phase 2: Property-conditioned alignment with LaCLIP-style NL descriptions.""" def __init__(self, data, modalities, property_col, property_name, nl_descriptions_per_sample=3): self.data = data self.modalities = modalities self.property_col = property_col self.property_name = property_name self.nl_descriptions_per_sample = nl_descriptions_per_sample self.nl_gen = NLPropertyDescriptionGenerator() available_cols = set(data.column_names) if hasattr(data, 'column_names') else set(data[0].keys()) self.available_modalities = [m for m in modalities if m in available_cols] self.has_properties = property_col in available_cols logger.info(f"Phase2 modalities: {self.available_modalities}") logger.info(f"Property: {property_name} (col={property_col}, has={self.has_properties})") def __len__(self): return len(self.data) def __getitem__(self, idx): row = self.data[idx] item = {} for mod in self.available_modalities: text = row.get(mod, None) if text and isinstance(text, str) and len(text.strip()) > 0: item[mod] = text.strip() else: item[mod] = None composition = row.get("composition", "unknown") crystal_system = row.get("crystal_system", None) if self.has_properties and row.get(self.property_col) is not None: label_val = float(row[self.property_col]) item["property_label"] = label_val item["property_text"] = f"composition: {composition} | {self.property_name}: {label_val:.4f}" # LaCLIP-style diverse NL descriptions — randomly sample one per call nl_descs = self.nl_gen.generate_descriptions( composition=composition, property_name=self.property_name, property_value=label_val, crystal_system=crystal_system, n=self.nl_descriptions_per_sample, ) item["nl_property_description"] = random.choice(nl_descs) else: item["property_label"] = None item["property_text"] = None item["nl_property_description"] = None return item def collate_fn(batch, 
tokenizer, all_modality_keys, max_length): result = {} for mod in all_modality_keys: texts = [item.get(mod) for item in batch] valid_texts = [t for t in texts if t is not None] if len(valid_texts) == 0: result[mod] = None continue texts_clean = [t if t is not None else "" for t in texts] mask_valid = [t is not None for t in texts] encoded = tokenizer( texts_clean, padding=True, truncation=True, max_length=max_length, return_tensors="pt" ) result[mod] = { "input_ids": encoded["input_ids"], "attention_mask": encoded["attention_mask"], "valid_mask": torch.tensor(mask_valid, dtype=torch.bool), } labels = [item.get("property_label") for item in batch] if any(l is not None for l in labels): labels_clean = [l if l is not None else 0.0 for l in labels] labels_mask = [l is not None for l in labels] result["property_labels"] = torch.tensor(labels_clean, dtype=torch.float32) result["property_labels_mask"] = torch.tensor(labels_mask, dtype=torch.bool) else: result["property_labels"] = None result["property_labels_mask"] = None return result # ============================================================================ # Training Loop # ============================================================================ def train_epoch(model, dataloader, optimizer, scheduler, config, epoch, phase, scaler=None, use_trackio=False, global_step=0): model.train() total_loss = 0.0 total_clip_loss = 0.0 total_prop_loss = 0.0 total_nl_loss = 0.0 log_interval = 20 autocast_dtype = torch.bfloat16 if config.use_bf16 else (torch.float16 if config.use_fp16 else torch.float32) use_amp = config.use_bf16 or config.use_fp16 optimizer.zero_grad() for batch_idx, batch in enumerate(dataloader): step_start = time.time() available_mods = [m for m in config.modalities if batch.get(m) is not None] if len(available_mods) > config.max_modalities_per_step: must_have = [m for m in ["composition", "crystal_text_llm"] if m in available_mods] remaining = [m for m in available_mods if m not in must_have] n_sample = 
max(config.max_modalities_per_step - len(must_have), 1) sampled = must_have + random.sample(remaining, min(n_sample, len(remaining))) else: sampled = available_mods if phase == 2 and batch.get(config.nl_query_modality) is not None: if config.nl_query_modality not in sampled: sampled.append(config.nl_query_modality) embeddings = {} with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp): for mod in sampled: if batch.get(mod) is None: embeddings[mod] = None continue input_ids = batch[mod]["input_ids"].to(config.device) attention_mask = batch[mod]["attention_mask"].to(config.device) valid_mask = batch[mod]["valid_mask"] if not valid_mask.any(): embeddings[mod] = None continue emb = model.encode(input_ids, attention_mask, mod) emb = emb * valid_mask.to(config.device).unsqueeze(-1).float() embeddings[mod] = emb with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp): temperature = model.temperature clip_l = all_pairs_clip_loss(embeddings, temperature) prop_l = torch.tensor(0.0, device=config.device) nl_l = torch.tensor(0.0, device=config.device) if phase == 2: if batch.get("property_text") is not None: prop_ids = batch["property_text"]["input_ids"].to(config.device) prop_mask_att = batch["property_text"]["attention_mask"].to(config.device) prop_valid = batch["property_text"]["valid_mask"] if prop_valid.any(): with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp): prop_emb = model.encode(prop_ids, prop_mask_att, "property") labels = batch["property_labels"].to(config.device) labels_mask = batch["property_labels_mask"].to(config.device) if labels_mask.sum() > 1: prop_l = property_similarity_loss( prop_emb[labels_mask], labels[labels_mask], temperature ) for anchor_mod in ["composition", "crystal_text_llm"]: if embeddings.get(anchor_mod) is not None: valid_both = labels_mask & batch[anchor_mod]["valid_mask"].to(config.device) if valid_both.sum() > 1: with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp): 
prop_clip = symmetric_clip_loss( prop_emb[valid_both], embeddings[anchor_mod][valid_both], temperature, ) prop_l = prop_l + 0.5 * prop_clip # NL property description ↔ all structure modalities if embeddings.get(config.nl_query_modality) is not None: nl_emb = embeddings[config.nl_query_modality] nl_valid = batch[config.nl_query_modality]["valid_mask"].to(config.device) if nl_valid.sum() > 1: n_nl_pairs = 0 for struct_mod in sampled: if struct_mod in [config.nl_query_modality, "property_text"]: continue if embeddings.get(struct_mod) is None: continue struct_valid = batch[struct_mod]["valid_mask"].to(config.device) valid_both = nl_valid & struct_valid if valid_both.sum() > 1: with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp): nl_struct_loss = symmetric_clip_loss( nl_emb[valid_both], embeddings[struct_mod][valid_both], temperature, ) nl_l = nl_l + nl_struct_loss n_nl_pairs += 1 if n_nl_pairs > 0: nl_l = nl_l / n_nl_pairs loss = (clip_l + 0.3 * prop_l + 0.5 * nl_l) / config.grad_accum_steps if scaler is not None: scaler.scale(loss).backward() else: loss.backward() if (batch_idx + 1) % config.grad_accum_steps == 0: if scaler is not None: scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) scaler.step(optimizer) scaler.update() else: torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) optimizer.step() scheduler.step() optimizer.zero_grad() global_step += 1 total_loss += loss.item() * config.grad_accum_steps total_clip_loss += clip_l.item() total_prop_loss += prop_l.item() if isinstance(prop_l, torch.Tensor) else prop_l total_nl_loss += nl_l.item() if isinstance(nl_l, torch.Tensor) else nl_l if (batch_idx + 1) % log_interval == 0: avg = total_loss / (batch_idx + 1) avg_clip = total_clip_loss / (batch_idx + 1) avg_prop = total_prop_loss / (batch_idx + 1) avg_nl = total_nl_loss / (batch_idx + 1) lr = scheduler.get_last_lr()[0] step_time = time.time() - step_start logger.info( f"P{phase} 
E{epoch} | {batch_idx+1}/{len(dataloader)} | " f"Loss: {avg:.4f} | CLIP: {avg_clip:.4f} | Prop: {avg_prop:.4f} | " f"NL: {avg_nl:.4f} | LR: {lr:.2e} | T: {model.temperature.item():.3f} | " f"mods: {len(sampled)} | {step_time:.1f}s/step" ) if use_trackio: try: import trackio trackio.log({ "phase": phase, "epoch": epoch, "step": global_step, "loss": avg, "clip_loss": avg_clip, "prop_loss": avg_prop, "nl_loss": avg_nl, "lr": lr, "temperature": model.temperature.item(), }) except: pass return total_loss / max(len(dataloader), 1), global_step # ============================================================================ # Evaluation # ============================================================================ @torch.no_grad() def evaluate_retrieval(model, dataloader, config, k_values=[1, 5, 10, 20]): model.eval() all_embeddings = {mod: [] for mod in config.modalities} autocast_dtype = torch.bfloat16 if config.use_bf16 else (torch.float16 if config.use_fp16 else torch.float32) use_amp = config.use_bf16 or config.use_fp16 for batch in dataloader: for mod in config.modalities: if batch.get(mod) is None: continue input_ids = batch[mod]["input_ids"].to(config.device) attention_mask = batch[mod]["attention_mask"].to(config.device) valid_mask = batch[mod]["valid_mask"] if not valid_mask.any(): continue with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp): emb = model.encode(input_ids, attention_mask, mod).float().cpu() for i in range(len(emb)): all_embeddings[mod].append(emb[i] if valid_mask[i] else None) results = {} eval_pairs = [ ("composition", "crystal_text_llm"), ("composition", "cif_symmetrized"), ("composition", "slices"), ("slices", "crystal_text_llm"), ("composition", "zmatrix"), ("composition", "atom_sequences_plusplus"), ("local_env", "composition"), ] if len([e for e in all_embeddings.get("robocrys_rep", []) if e is not None]) > 0: eval_pairs.extend([ ("robocrys_rep", "composition"), ("robocrys_rep", "cif_symmetrized"), ("robocrys_rep", 
"slices"), ]) for mod_a, mod_b in eval_pairs: embs_a = all_embeddings.get(mod_a, []) embs_b = all_embeddings.get(mod_b, []) if not embs_a or not embs_b: continue valid_idx = [i for i in range(min(len(embs_a), len(embs_b))) if embs_a[i] is not None and embs_b[i] is not None] if len(valid_idx) < 10: continue ea = torch.stack([embs_a[i] for i in valid_idx]) eb = torch.stack([embs_b[i] for i in valid_idx]) sim = ea @ eb.T recalls = {} for k in k_values: kk = min(k, len(valid_idx) - 1) if kk < 1: continue topk = sim.topk(kk, dim=1).indices correct = (topk == torch.arange(len(valid_idx)).unsqueeze(1)).any(dim=1) recalls[f"R@{k}"] = correct.float().mean().item() results[f"{mod_a}→{mod_b}"] = recalls logger.info(f" {mod_a}→{mod_b}: {recalls}") return results @torch.no_grad() def evaluate_nl_queries(model, tokenizer, indices, config): model.eval() test_queries = [ ("oxide with high bandgap", config.nl_query_modality), ("narrow bandgap semiconductor", config.nl_query_modality), ("stable binary oxide", config.nl_query_modality), ("wide bandgap fluoride", config.nl_query_modality), ("ternary sulfide with low formation energy", config.nl_query_modality), ("metallic nitride", config.nl_query_modality), ("Fe2O3", "composition"), ("SiO2", "composition"), ("TiO2", "composition"), ("GaN", "composition"), ("perovskite structure with octahedral coordination", "robocrys_rep"), ("cubic crystal with face-centered lattice", "robocrys_rep"), ] results = {} for query_text, query_modality in test_queries: try: hits = search_vector_db(query_text, query_modality, model, tokenizer, indices, config, k=5) results[query_text] = { "modality": query_modality, "top_hits": [(s, m) for s, m in hits], } logger.info(f"\nQuery: '{query_text}' (via {query_modality})") for rank, (score, meta) in enumerate(hits[:5], 1): logger.info(f" #{rank}: {score:.4f} | {meta.get('composition', 'N/A')} | " f"via {meta.get('matched_modality', 'N/A')}") except Exception as e: logger.warning(f"Query '{query_text}' failed: 
{e}") return results # ============================================================================ # FAISS Vector Database # ============================================================================ def build_vector_database(model, dataset, tokenizer, config, modalities_to_index=None): if modalities_to_index is None: modalities_to_index = ["composition", "crystal_text_llm", "slices", "cif_symmetrized", "robocrys_rep"] model.eval() autocast_dtype = torch.bfloat16 if config.use_bf16 else (torch.float16 if config.use_fp16 else torch.float32) use_amp = config.use_bf16 or config.use_fp16 all_embeddings = {mod: [] for mod in modalities_to_index} all_metadata = [] bs = 64 for start in range(0, len(dataset), bs): end = min(start + bs, len(dataset)) items = [dataset[i] for i in range(start, end)] for item in items: meta = { "composition": item.get("composition", ""), "property_label": item.get("property_label"), } all_metadata.append(meta) all_mod_keys = list(config.modalities) batch = collate_fn(items, tokenizer, all_mod_keys, config.max_length) with torch.no_grad(): for mod in modalities_to_index: if batch.get(mod) is None: all_embeddings[mod].extend([None] * len(items)) continue with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp): emb = model.encode( batch[mod]["input_ids"].to(config.device), batch[mod]["attention_mask"].to(config.device), mod, ).float().cpu().numpy() for i in range(len(emb)): if batch[mod]["valid_mask"][i]: all_embeddings[mod].append(emb[i]) else: all_embeddings[mod].append(None) if (start // bs) % 20 == 0: logger.info(f"Indexed {end}/{len(dataset)}") indices = {} for mod in modalities_to_index: valid_embs = [e for e in all_embeddings[mod] if e is not None] valid_map = [i for i, e in enumerate(all_embeddings[mod]) if e is not None] if not valid_embs: continue emb_matrix = np.stack(valid_embs).astype(np.float32) faiss.normalize_L2(emb_matrix) d = emb_matrix.shape[1] if len(valid_embs) > 10000: nlist = min(100, 
int(np.sqrt(len(valid_embs)))) quantizer = faiss.IndexFlatIP(d) index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT) index.train(emb_matrix) index.nprobe = 10 else: index = faiss.IndexFlatIP(d) index.add(emb_matrix) indices[mod] = { "index": index, "valid_indices_map": valid_map, "metadata": [all_metadata[i] for i in valid_map], } logger.info(f"FAISS {mod}: {len(valid_embs)} vectors, dim={d}") return indices def search_vector_db(query_text, query_modality, model, tokenizer, indices, config, k=10): """Search the vector DB with any modality query. For NL queries like "oxide with high bandgap": query_modality="nl_property_description" For composition queries like "Fe2O3": query_modality="composition" For structure descriptions: query_modality="robocrys_rep" """ model.eval() autocast_dtype = torch.bfloat16 if config.use_bf16 else (torch.float16 if config.use_fp16 else torch.float32) use_amp = config.use_bf16 or config.use_fp16 enc = tokenizer( [query_text], padding=True, truncation=True, max_length=config.max_length, return_tensors="pt", ) with torch.no_grad(): with torch.amp.autocast('cuda', dtype=autocast_dtype, enabled=use_amp): q_emb = model.encode( enc["input_ids"].to(config.device), enc["attention_mask"].to(config.device), query_modality, ).float().cpu().numpy().astype(np.float32) faiss.normalize_L2(q_emb) results = [] for mod_name, idx_data in indices.items(): scores, ids = idx_data["index"].search(q_emb, k) for s, i in zip(scores[0], ids[0]): if i >= 0 and i < len(idx_data["metadata"]): m = dict(idx_data["metadata"][i]) m["matched_modality"] = mod_name results.append((float(s), m)) results.sort(key=lambda x: x[0], reverse=True) seen, unique = set(), [] for s, m in results: c = m.get("composition", "") if c not in seen: seen.add(c) unique.append((s, m)) if len(unique) >= k: break return unique # ============================================================================ # Main # 
def main():
    """Run the full MatText embedding-alignment pipeline.

    Stages:
      1. Phase 1 — multi-modal contrastive alignment on the pretrain corpus.
      2. Phase 2 — property-conditioned alignment plus natural-language query
         training on the finetune datasets (skipped when none can be loaded).
      3. Evaluation — cross-modal retrieval, FAISS vector-DB build, and
         natural-language query checks.
      4. Export — weights, tokenizer, configs, result JSONs, optional Hub push.
    """
    config = Config()

    # Flash Attention 2 is optional — detect it and degrade gracefully.
    try:
        from flash_attn import flash_attn_func  # noqa: F401 (presence check only)
        config.use_flash_attn = True
        logger.info("Flash Attention 2 available — enabling")
    except ImportError:
        config.use_flash_attn = False
        logger.info("Flash Attention 2 not available — using default attention")

    logger.info(f"Device: {config.device}")
    logger.info(f"Precision: {'bf16' if config.use_bf16 else 'fp16' if config.use_fp16 else 'fp32'}")
    logger.info(f"Max length: {config.max_length}")
    logger.info(f"Batch: {config.batch_size} × {config.grad_accum_steps} = {config.batch_size * config.grad_accum_steps} effective")
    logger.info(f"Encoder: {config.encoder_name}")

    # Experiment tracking is best-effort; training proceeds without it.
    use_trackio = False
    try:
        import trackio
        trackio.init(project="mattext-embeddings", name=f"align-v2-{config.max_length}ctx")
        use_trackio = True
        logger.info("Trackio initialized")
    except Exception as e:
        logger.warning(f"Trackio init failed: {e}")

    tokenizer = AutoTokenizer.from_pretrained(config.encoder_name)
    model = MatTextEncoder(config).to(config.device)
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"Total params: {total_params:,} | Trainable: {trainable_params:,}")

    # ------------------------------------------------------------------
    # Phase 1: multi-modal alignment on the pretrain corpus
    # ------------------------------------------------------------------
    logger.info("=" * 70 + "\nPHASE 1: Multi-modal alignment on pretrain100k_v2\n" + "=" * 70)
    pretrain_data = load_dataset(config.dataset_name, config.pretrain_config, split="train")
    logger.info(f"Pretrain loaded: {len(pretrain_data)} samples, cols: {pretrain_data.column_names}")
    if len(pretrain_data) > config.max_pretrain_samples:
        pretrain_data = pretrain_data.shuffle(seed=42).select(range(config.max_pretrain_samples))
        logger.info(f"Subsampled to {len(pretrain_data)}")
    phase1_dataset = MatTextPhase1Dataset(pretrain_data, config.modalities)

    def make_collate(mods):
        # Closure factory: binds tokenizer/max_length so the DataLoader only
        # needs the batch argument. (Was a lambda-of-lambda; PEP 8 prefers def.)
        return lambda batch: collate_fn(batch, tokenizer, mods, config.max_length)

    phase1_loader = DataLoader(
        phase1_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        drop_last=True,  # contrastive loss assumes complete batches
        num_workers=2,
        collate_fn=make_collate(config.modalities),
        pin_memory=(config.device == "cuda"),
        prefetch_factor=2,
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    phase1_steps = len(phase1_loader) * config.num_epochs_phase1 // config.grad_accum_steps
    scheduler = get_cosine_schedule_with_warmup(optimizer, int(phase1_steps * config.warmup_ratio), phase1_steps)
    # GradScaler is only needed for fp16; bf16 has enough dynamic range.
    scaler = torch.amp.GradScaler('cuda') if config.use_fp16 else None

    global_step = 0
    best_loss = float('inf')
    os.makedirs(config.output_dir, exist_ok=True)
    for epoch in range(1, config.num_epochs_phase1 + 1):
        t0 = time.time()
        loss, global_step = train_epoch(
            model, phase1_loader, optimizer, scheduler, config, epoch, phase=1,
            scaler=scaler, use_trackio=use_trackio, global_step=global_step,
        )
        elapsed = time.time() - t0
        logger.info(f"Phase1 Epoch {epoch}/{config.num_epochs_phase1} | Loss: {loss:.4f} | Time: {elapsed:.0f}s ({elapsed/60:.1f}min)")
        if loss < best_loss:
            best_loss = loss
            torch.save(model.state_dict(), f"{config.output_dir}/best_model_phase1.pt")
            logger.info(f" → New best model saved (loss={loss:.4f})")

    # Release phase-1 data/loader before loading the finetune sets.
    del pretrain_data, phase1_dataset, phase1_loader
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # ------------------------------------------------------------------
    # Phase 2: property-conditioned alignment + NL query training
    # ------------------------------------------------------------------
    logger.info("=" * 70 + "\nPHASE 2: Property-conditioned alignment + NL query training\n" + "=" * 70)
    finetune_datasets = []
    for ft_cfg, ft_split, prop_name in config.finetune_configs:
        try:
            ft = load_dataset(config.dataset_name, ft_cfg, split=ft_split)
            logger.info(f"Loaded {ft_cfg}/{ft_split}: {len(ft)} samples")
            finetune_datasets.append((ft, prop_name))
        except Exception as e:
            logger.warning(f"Failed to load {ft_cfg}/{ft_split}: {e}")

    if finetune_datasets:
        all_phase2_datasets = []
        # Split the finetune sample budget evenly across datasets (hoisted:
        # loop-invariant).
        per_ds_budget = config.max_finetune_samples // len(finetune_datasets)
        for ft_data, prop_name in finetune_datasets:
            if len(ft_data) > per_ds_budget:
                ft_data = ft_data.shuffle(seed=42).select(range(per_ds_budget))
            phase2_ds = MatTextPhase2Dataset(
                ft_data, config.modalities, "labels", prop_name,
                nl_descriptions_per_sample=config.nl_descriptions_per_sample,
            )
            all_phase2_datasets.append(phase2_ds)
            logger.info(f"Phase2 dataset ({prop_name}): {len(phase2_ds)} samples")

        class ConcatPhase2Dataset(Dataset):
            """Present several map-style datasets as one contiguous index space."""

            def __init__(self, datasets):
                self.datasets = datasets
                self.lengths = [len(d) for d in datasets]
                self.total = sum(self.lengths)
                # cum_lengths[i] = flat-index offset where datasets[i] starts.
                self.cum_lengths = []
                acc = 0
                for length in self.lengths:
                    self.cum_lengths.append(acc)
                    acc += length

            def __len__(self):
                return self.total

            def __getitem__(self, idx):
                for i, (start, length) in enumerate(zip(self.cum_lengths, self.lengths)):
                    if idx < start + length:
                        return self.datasets[i][idx - start]
                # Defensive fallback; unreachable for 0 <= idx < len(self).
                return self.datasets[-1][idx - self.cum_lengths[-1]]

        combined_phase2 = ConcatPhase2Dataset(all_phase2_datasets)
        # Phase 2 additionally aligns the NL property descriptions and the
        # structured property text against all structure modalities.
        phase2_mod_keys = list(config.modalities) + [config.nl_query_modality, "property_text"]
        phase2_loader = DataLoader(
            combined_phase2,
            batch_size=config.batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=2,
            collate_fn=lambda batch: collate_fn(batch, tokenizer, phase2_mod_keys, config.max_length),
            pin_memory=(config.device == "cuda"),
            prefetch_factor=2,
        )
        optimizer2 = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate * 0.5,  # gentler LR for finetuning
            weight_decay=config.weight_decay,
        )
        phase2_steps = len(phase2_loader) * config.num_epochs_phase2 // config.grad_accum_steps
        scheduler2 = get_cosine_schedule_with_warmup(optimizer2, int(phase2_steps * config.warmup_ratio), phase2_steps)

        # Bug fix: the phase-2 loss covers extra modalities/terms and is not
        # comparable to the phase-1 loss. Without resetting the tracker,
        # best_model.pt might never be written and evaluation would silently
        # fall back to the phase-1 checkpoint.
        best_loss = float('inf')

        for epoch in range(1, config.num_epochs_phase2 + 1):
            t0 = time.time()
            loss, global_step = train_epoch(
                model, phase2_loader, optimizer2, scheduler2, config, epoch, phase=2,
                scaler=scaler, use_trackio=use_trackio, global_step=global_step,
            )
            elapsed = time.time() - t0
            logger.info(f"Phase2 Epoch {epoch}/{config.num_epochs_phase2} | Loss: {loss:.4f} | Time: {elapsed:.0f}s ({elapsed/60:.1f}min)")
            if loss < best_loss:
                best_loss = loss
                torch.save(model.state_dict(), f"{config.output_dir}/best_model.pt")
                logger.info(f" → New best model saved (loss={loss:.4f})")
        del combined_phase2, phase2_loader
    else:
        logger.warning("No finetune data loaded — skipping Phase 2")

    # ------------------------------------------------------------------
    # Evaluation
    # ------------------------------------------------------------------
    logger.info("=" * 70 + "\nEVALUATION\n" + "=" * 70)
    # Prefer the phase-2 checkpoint; fall back to phase 1 when Phase 2 was skipped.
    best_path = f"{config.output_dir}/best_model.pt"
    if not os.path.exists(best_path):
        best_path = f"{config.output_dir}/best_model_phase1.pt"
    if os.path.exists(best_path):
        model.load_state_dict(torch.load(best_path, map_location=config.device))
        logger.info(f"Loaded best model from {best_path}")

    eval_data = load_dataset(config.dataset_name, config.pretrain_config, split="test")
    if len(eval_data) > 5000:
        eval_data = eval_data.shuffle(seed=42).select(range(5000))
    logger.info(f"Eval data: {len(eval_data)} samples")
    eval_dataset = MatTextPhase1Dataset(eval_data, config.modalities)
    eval_loader = DataLoader(
        eval_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=2,
        collate_fn=make_collate(config.modalities),
    )
    retrieval_results = evaluate_retrieval(model, eval_loader, config)

    logger.info("\nBuilding FAISS vector database...")
    db_indices = build_vector_database(
        model, eval_dataset, tokenizer, config,
        modalities_to_index=["composition", "crystal_text_llm", "slices", "cif_symmetrized", "robocrys_rep"],
    )
    faiss_dir = f"{config.output_dir}/faiss"
    os.makedirs(faiss_dir, exist_ok=True)
    for mod, d in db_indices.items():
        faiss.write_index(d["index"], f"{faiss_dir}/{mod}.index")
        with open(f"{faiss_dir}/{mod}_metadata.json", "w") as f:
            json.dump(d["metadata"], f)

    logger.info("\n" + "=" * 70 + "\nNATURAL LANGUAGE QUERY EVALUATION\n" + "=" * 70)
    nl_results = evaluate_nl_queries(model, tokenizer, db_indices, config)

    # ------------------------------------------------------------------
    # Save artifacts
    # ------------------------------------------------------------------
    logger.info("\nSaving model and artifacts...")
    torch.save(model.state_dict(), f"{config.output_dir}/model.pt")
    tokenizer.save_pretrained(config.output_dir)
    model_config = model.get_config_dict()
    model_config["training"] = {
        "num_epochs_phase1": config.num_epochs_phase1,
        "num_epochs_phase2": config.num_epochs_phase2,
        "batch_size": config.batch_size,
        "grad_accum_steps": config.grad_accum_steps,
        "learning_rate": config.learning_rate,
        "max_length": config.max_length,
        "nl_descriptions_per_sample": config.nl_descriptions_per_sample,
    }
    with open(f"{config.output_dir}/config.json", "w") as f:
        json.dump(model_config, f, indent=2)
    with open(f"{config.output_dir}/retrieval_results.json", "w") as f:
        json.dump(retrieval_results, f, indent=2)

    # Tuples serialize as JSON arrays, so the (score, metadata) pairs round-trip.
    nl_results_serializable = {
        k: {"modality": v["modality"], "top_hits": [(s, m) for s, m in v["top_hits"]]}
        for k, v in nl_results.items()
    }
    with open(f"{config.output_dir}/nl_query_results.json", "w") as f:
        json.dump(nl_results_serializable, f, indent=2)

    if config.push_to_hub:
        try:
            api = HfApi()
            api.create_repo(config.hub_model_id, exist_ok=True)
            api.upload_folder(
                folder_path=config.output_dir,
                repo_id=config.hub_model_id,
                # plain string: the original was an f-string with no placeholders
                commit_message="Upload MatText aligned embeddings v2 (1024 ctx, NL queries)",
            )
            logger.info(f"✓ Pushed to https://huggingface.co/{config.hub_model_id}")
        except Exception as e:
            logger.error(f"Push failed: {e}")

    logger.info("\n" + "=" * 70)
    logger.info("TRAINING COMPLETE")
    logger.info(f"Model: {config.output_dir}/model.pt")
    logger.info(f"FAISS: {faiss_dir}/")
    logger.info(f"Hub: https://huggingface.co/{config.hub_model_id}")
    logger.info("=" * 70)


if __name__ == "__main__":
    main()