from __future__ import annotations import torch import torch._inductor.config as inductor_config import torch._dynamo as dynamo # Enable TensorFloat32 tensor cores for float32 matmul (Ampere+ GPUs) # Provides significant speedup with minimal precision loss torch.set_float32_matmul_precision('high') # Enable TF32 for matrix multiplications and cuDNN operations torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True # Enable cuDNN autotuner - finds fastest algorithms for your hardware # Best when input sizes are consistent; may slow down first iterations torch.backends.cudnn.benchmark = True # Deterministic operations off for speed (set True if reproducibility needed) torch.backends.cudnn.deterministic = False inductor_config.max_autotune_gemm_backends = "ATEN,CUTLASS,FBGEMM" dynamo.config.capture_scalar_outputs = True torch._dynamo.config.recompile_limit = 16 import io import os import queue import sqlite3 import struct import threading import time import networkx as nx import numpy as np import torch from tqdm.auto import tqdm from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple from torch.utils.data import DataLoader from torch.utils.data import Dataset as TorchDataset from transformers import PreTrainedTokenizerBase # Compact blob serialization constants # Canonical source: core/embed/blob.py. Keep in sync with protify/utils.py. _COMPACT_VERSION = 0x01 _DTYPE_TO_CODE = {torch.float16: 0, torch.bfloat16: 1, torch.float32: 2} _CODE_TO_DTYPE = {0: torch.float16, 1: torch.bfloat16, 2: torch.float32} _CODE_TO_NP_DTYPE = {0: np.float16, 1: np.float16, 2: np.float32} def tensor_to_embedding_blob(tensor: torch.Tensor) -> bytes: """Serialize a tensor to compact binary format for SQLite blob storage. Format: [version:1][dtype_code:1][ndim:4][shape:4*ndim][raw_bytes] bfloat16 tensors are stored as float16 bytes (numpy lacks bfloat16) but tagged with dtype_code=1 so they can be cast back on read. 
Falls back to torch.save for unsupported dtypes. """ t = tensor.cpu() if t.dtype not in _DTYPE_TO_CODE: buffer = io.BytesIO() torch.save(t, buffer) return buffer.getvalue() dtype_code = _DTYPE_TO_CODE[t.dtype] if t.dtype == torch.bfloat16: raw = t.half().numpy().tobytes() else: raw = t.numpy().tobytes() shape = t.shape header = struct.pack(f' bytes: """Build just the compact header for a given dtype and shape.""" dtype_code = _DTYPE_TO_CODE[dtype] return struct.pack(f' List[bytes]: """Serialize a batch of same-shape tensors to compact blobs (fast path for vectors). Builds the header once and slices raw bytes per row. Much faster than per-row tensor_to_embedding_blob calls for uniform-shape batches. """ assert batch.ndim >= 2, f"Expected batch with >= 2 dims, got {batch.ndim}" t = batch.cpu() store_dtype = t.dtype if t.dtype not in _DTYPE_TO_CODE: return [tensor_to_embedding_blob(t[i]) for i in range(t.shape[0])] if t.dtype == torch.bfloat16: arr = t.half().numpy() store_dtype = torch.bfloat16 else: arr = t.numpy() row_shape = tuple(t.shape[1:]) header = _compact_header(store_dtype, row_shape) raw = arr.tobytes() stride = len(raw) // t.shape[0] return [header + raw[i * stride:(i + 1) * stride] for i in range(t.shape[0])] def embedding_blob_to_tensor(blob: bytes, fallback_shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor: """Deserialize a blob back to a tensor. Auto-detects compact vs legacy formats.""" if len(blob) >= 6 and blob[0] == _COMPACT_VERSION: dtype_code = blob[1] ndim = struct.unpack_from(' torch.nn.Module: """Compile model with torch.compile if possible. Skips compilation when dynamic=True (padding='longest') because flex attention's create_block_mask is incompatible with dynamic shapes under torch.compile, causing CUDA illegal memory access. 
""" if dynamic: print("Skipping torch.compile (dynamic shapes + flex attention incompatible)") return model try: model = torch.compile(model) print("Model compiled") except Exception as e: print(f"Skipping torch.compile: {e}") return model def build_collator( tokenizer: PreTrainedTokenizerBase, padding: str = 'max_length', max_length: int = 512, ) -> Callable[[List[str]], Dict[str, torch.Tensor]]: def _collate_fn(sequences: List[str]) -> Dict[str, torch.Tensor]: kwargs: Dict[str, Any] = dict( return_tensors="pt", padding=padding, truncation=True, max_length=max_length, ) if padding != 'max_length': kwargs['pad_to_multiple_of'] = 8 return tokenizer(sequences, **kwargs) return _collate_fn def _make_embedding_progress( dataloader: DataLoader, padding: str, n_warmup: int = 3, n_calibration: int = 5, ) -> Iterator[Tuple[int, Any]]: """Progress-bar wrapper for embedding loops. Drop-in replacement for enumerate(dataloader). When padding='max_length', all batches have uniform cost so plain tqdm works. When padding='longest' (sorted longest-first), batch times vary dramatically. In that case: yield warmup batches first (compiler warmup + OOM check on longest sequences), then time mid-length calibration batches to estimate total ETA. Keep in sync with protify/embedder.py and core/atlas/precomputed.py. 
""" total = len(dataloader) if padding == 'max_length' or total <= n_warmup + n_calibration: for i, batch in tqdm(enumerate(dataloader), total=total, desc='Embedding batches'): yield i, batch return dl_iter = iter(dataloader) # Phase 1: warmup on longest batches (first n_warmup, since sorted longest-first) warmup_bar = tqdm(range(n_warmup), desc='Warmup (longest batches)', leave=False) for i in warmup_bar: batch = next(dl_iter) yield i, batch warmup_bar.close() # Phase 2: skip to middle of dataset for calibration timing # We need to yield all intermediate batches too (they contain real data) mid_start = total // 2 intermediate_bar = tqdm( range(n_warmup, mid_start), desc='Embedding batches', leave=False, ) for i in intermediate_bar: batch = next(dl_iter) yield i, batch intermediate_bar.close() # Phase 3: time calibration batches from the middle calibration_times: List[float] = [] cal_bar = tqdm(range(n_calibration), desc='Calibrating ETA', leave=False) for j in cal_bar: t0 = time.perf_counter() batch = next(dl_iter) yield mid_start + j, batch calibration_times.append(time.perf_counter() - t0) cal_bar.close() avg_time = sum(calibration_times) / len(calibration_times) remaining_start = mid_start + n_calibration remaining_count = total - remaining_start estimated_total_seconds = avg_time * remaining_count # Phase 4: remaining batches with calibrated ETA main_bar = tqdm( range(remaining_count), desc='Embedding batches', bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]', ) main_bar.set_postfix_str(f'ETA ~{estimated_total_seconds:.0f}s (calibrated)') for k in main_bar: batch = next(dl_iter) yield remaining_start + k, batch main_bar.close() class _SQLWriter: """Context manager for async SQL embedding writes. 
Matches core/embed/storage.SQLEmbeddingWriter.""" def __init__(self, conn: sqlite3.Connection, queue_maxsize: int = 4) -> None: self._conn = conn self._queue: queue.Queue = queue.Queue(maxsize=queue_maxsize) self._thread: Optional[threading.Thread] = None def __enter__(self) -> "_SQLWriter": self._thread = threading.Thread(target=self._writer_loop, daemon=True) self._thread.start() return self def write_batch(self, rows: List[Tuple[str, bytes]]) -> None: self._queue.put(rows) def _writer_loop(self) -> None: cursor = self._conn.cursor() while True: item = self._queue.get() if item is None: break cursor.executemany("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", item) if self._queue.qsize() == 0: self._conn.commit() self._conn.commit() def __exit__(self, *exc) -> None: if self._thread is not None: self._queue.put(None) self._thread.join() self._thread = None class Pooler: def __init__(self, pooling_types: List[str]) -> None: self.pooling_types = pooling_types self.pooling_options: Dict[str, Callable] = { 'mean': self.mean_pooling, 'max': self.max_pooling, 'norm': self.norm_pooling, 'median': self.median_pooling, 'std': self.std_pooling, 'var': self.var_pooling, 'cls': self.cls_pooling, 'parti': self._pool_parti, } def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor: assert isinstance(attentions, torch.Tensor) maxed_attentions = torch.max(attentions, dim=1)[0] return maxed_attentions def _page_rank(self, attention_matrix: np.ndarray, personalization: Optional[dict] = None, nstart: Optional[dict] = None, prune_type: str = "top_k_outdegree") -> Dict[int, float]: G = self._convert_to_graph(attention_matrix) if G.number_of_nodes() != attention_matrix.shape[0]: raise Exception( f"The number of nodes in the graph should be equal to the number of tokens in sequence! 
You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.") if G.number_of_edges() == 0: raise Exception(f"You don't seem to have any attention edges left in the graph.") return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100) def _convert_to_graph(self, matrix: np.ndarray) -> nx.DiGraph: G = nx.from_numpy_array(matrix, create_using=nx.DiGraph) return G def _calculate_importance_weights(self, dict_importance: Dict[int, float], attention_mask: Optional[torch.Tensor] = None) -> np.ndarray: if attention_mask is not None: for k in list(dict_importance.keys()): if attention_mask[k] == 0: del dict_importance[k] total = sum(dict_importance.values()) return np.array([v / total for _, v in dict_importance.items()]) def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy() emb_pooled = [] for e, a, mask in zip(emb, maxed_attentions, attention_mask): dict_importance = self._page_rank(a) importance_weights = self._calculate_importance_weights(dict_importance, mask) num_tokens = int(mask.sum().item()) emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0)) pooled = torch.tensor(np.array(emb_pooled)) return pooled def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: if attention_mask is None: return emb.mean(dim=1) else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: if attention_mask is None: return emb.max(dim=1).values else: mask = attention_mask.unsqueeze(-1).bool() return emb.masked_fill(~mask, float('-inf')).max(dim=1).values def norm_pooling(self, emb: torch.Tensor, 
attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: if attention_mask is None: return emb.norm(dim=1, p=2) else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).norm(dim=1, p=2) def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: if attention_mask is None: return emb.median(dim=1).values else: mask = attention_mask.unsqueeze(-1).bool() return emb.masked_fill(~mask, float('nan')).nanmedian(dim=1).values def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: if attention_mask is None: return emb.std(dim=1) else: var = self.var_pooling(emb, attention_mask, **kwargs) return torch.sqrt(var) def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: if attention_mask is None: return emb.var(dim=1) else: attention_mask = attention_mask.unsqueeze(-1) mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) mean = mean.unsqueeze(1) squared_diff = (emb - mean) ** 2 var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) return var def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: return emb[:, 0, :] def __call__( self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, attentions: Optional[torch.Tensor] = None ) -> torch.Tensor: if attention_mask is not None: assert attention_mask.sum(dim=-1).min() > 0, ( "Pooler received samples with all-zero attention masks. " "This causes NaN from division by zero. Filter empty inputs before pooling." 
) final_emb: List[torch.Tensor] = [] for pooling_type in self.pooling_types: final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) return torch.cat(final_emb, dim=-1) class ProteinDataset(TorchDataset): """Simple dataset for protein sequences.""" def __init__(self, sequences: List[str]) -> None: self.sequences = sequences def __len__(self) -> int: return len(self.sequences) def __getitem__(self, idx: int) -> str: return self.sequences[idx] def parse_fasta(fasta_path: str) -> List[str]: assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}" sequences = [] current_seq = [] with open(fasta_path, 'r') as f: for line in f: line = line.strip() if not line: continue if line.startswith('>'): if current_seq: sequences.append(''.join(current_seq)) current_seq = [] else: current_seq.append(line) if current_seq: sequences.append(''.join(current_seq)) return sequences class EmbeddingMixin: def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: raise NotImplementedError @property def device(self) -> torch.device: """Get the device of the model.""" return next(self.parameters()).device def _read_sequences_from_db(self, db_path: str) -> Set[str]: """Read sequences from SQLite database.""" with sqlite3.connect(db_path, timeout=30) as conn: c = conn.cursor() c.execute("SELECT sequence FROM embeddings") return {row[0] for row in c.fetchall()} def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None: cursor = conn.cursor() cursor.execute( "CREATE TABLE IF NOT EXISTS embeddings (" "sequence TEXT PRIMARY KEY, " "embedding BLOB NOT NULL" ")" ) conn.commit() def load_embeddings_from_pth(self, save_path: str) -> Dict[str, torch.Tensor]: assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}" payload = torch.load(save_path, map_location="cpu", weights_only=True) assert isinstance(payload, dict), "Expected .pth embeddings 
file to contain a dictionary." for sequence, tensor in payload.items(): assert isinstance(sequence, str), "Expected embedding dictionary keys to be sequences (str)." assert isinstance(tensor, torch.Tensor), "Expected embedding dictionary values to be tensors." return payload def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> Dict[str, torch.Tensor]: assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}" loaded: Dict[str, torch.Tensor] = {} with sqlite3.connect(db_path, timeout=30) as conn: self._ensure_embeddings_table(conn) cursor = conn.cursor() if sequences is None: cursor.execute("SELECT sequence, embedding FROM embeddings") else: if len(sequences) == 0: return loaded placeholders = ",".join(["?"] * len(sequences)) cursor.execute( f"SELECT sequence, embedding FROM embeddings WHERE sequence IN ({placeholders})", tuple(sequences), ) rows = cursor.fetchall() for row in rows: sequence = row[0] embedding_bytes = row[1] loaded[sequence] = embedding_blob_to_tensor(embedding_bytes) return loaded def embed_dataset( self, sequences: Optional[List[str]] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None, batch_size: int = 2, max_len: int = 512, truncate: bool = True, full_embeddings: bool = False, embed_dtype: torch.dtype = torch.float32, pooling_types: List[str] = ['mean'], num_workers: int = 0, sql: bool = False, save: bool = True, sql_db_path: str = 'embeddings.db', save_path: str = 'embeddings.pth', fasta_path: Optional[str] = None, padding: str = 'max_length', **kwargs, ) -> Optional[Dict[str, torch.Tensor]]: """ Embed a dataset of protein sequences. Supports two modes: - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used. - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used. 
Sequences can be supplied as a list via `sequences`, parsed from a FASTA file via `fasta_path`, or both (the two sources are combined). At least one must be provided. """ if fasta_path is not None: fasta_sequences = parse_fasta(fasta_path) sequences = list(sequences or []) + fasta_sequences assert sequences is not None and len(sequences) > 0, \ "Must provide at least one sequence via `sequences` or `fasta_path`." sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences])) sequences = sorted(sequences, key=len, reverse=True) hidden_size = self.config.hidden_size pooler = Pooler(pooling_types) if not full_embeddings else None tokenizer_mode = tokenizer is not None # Resolve padding and compilation dynamic = padding == 'longest' compiled_model = maybe_compile(self, dynamic=dynamic) if tokenizer_mode: collate_fn = build_collator(tokenizer, padding=padding, max_length=max_len) device = self.device else: collate_fn = None device = None def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: assert isinstance(residue_embeddings, torch.Tensor) if full_embeddings or residue_embeddings.ndim == 2: return residue_embeddings return pooler(residue_embeddings, attention_mask) def iter_batches(to_embed: List[str]): if tokenizer_mode: assert collate_fn is not None assert device is not None dataset = ProteinDataset(to_embed) dataloader = DataLoader( dataset, batch_size=batch_size, num_workers=num_workers, prefetch_factor=2 if num_workers > 0 else None, collate_fn=collate_fn, shuffle=False, pin_memory=True, ) for i, batch in _make_embedding_progress(dataloader, padding): seqs = to_embed[i * batch_size:(i + 1) * batch_size] input_ids = batch['input_ids'].to(device) attention_mask = batch['attention_mask'].to(device) residue_embeddings = compiled_model._embed(input_ids, attention_mask) yield seqs, residue_embeddings, attention_mask else: for batch_start in tqdm(range(0, len(to_embed), batch_size), 
desc='Embedding batches'): seqs = to_embed[batch_start:batch_start + batch_size] batch_output = compiled_model._embed(seqs, return_attention_mask=True, **kwargs) assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)." assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values." residue_embeddings, attention_mask = batch_output assert isinstance(attention_mask, torch.Tensor), "Sequence mode _embed must return attention_mask as a torch.Tensor." yield seqs, residue_embeddings, attention_mask if sql: # Step 1: DEDUPLICATE - check existing embeddings in SQL conn = sqlite3.connect(sql_db_path, timeout=30, check_same_thread=False) conn.execute('PRAGMA journal_mode=WAL') conn.execute('PRAGMA busy_timeout=30000') conn.execute('PRAGMA synchronous=OFF') conn.execute('PRAGMA cache_size=-64000') self._ensure_embeddings_table(conn) already_embedded = self._read_sequences_from_db(sql_db_path) to_embed = [seq for seq in sequences if seq not in already_embedded] print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}") print(f"Embedding {len(to_embed)} new sequences") if len(to_embed) > 0: # Steps 4-7: BATCH+EMBED -> POOL/TRIM -> SERIALIZE -> WRITE (async) with _SQLWriter(conn) as writer: with torch.inference_mode(): for seqs, residue_embeddings, attention_mask in iter_batches(to_embed): embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) if full_embeddings: batch_rows = [] for seq, emb, mask in zip(seqs, embeddings, attention_mask): batch_rows.append((seq, tensor_to_embedding_blob(emb[mask.bool()].reshape(-1, hidden_size)))) else: blobs = batch_tensor_to_blobs(embeddings) batch_rows = list(zip(seqs, blobs)) writer.write_batch(batch_rows) conn.close() return None embeddings_dict = {} if os.path.exists(save_path): embeddings_dict = self.load_embeddings_from_pth(save_path) to_embed = [seq for seq in sequences if seq not in embeddings_dict] 
print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}") print(f"Embedding {len(to_embed)} new sequences") else: to_embed = sequences print(f"Embedding {len(to_embed)} new sequences") if len(to_embed) > 0: with torch.inference_mode(): for seqs, residue_embeddings, attention_mask in iter_batches(to_embed): embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) for seq, emb, mask in zip(seqs, embeddings, attention_mask): if full_embeddings: emb = emb[mask.bool()].reshape(-1, hidden_size) embeddings_dict[seq] = emb.cpu() if save: torch.save(embeddings_dict, save_path) return embeddings_dict if __name__ == "__main__": # py -m pooler pooler = Pooler(pooling_types=['max', 'parti']) batch_size = 8 seq_len = 64 hidden_size = 128 num_layers = 12 emb = torch.randn(batch_size, seq_len, hidden_size) attentions = torch.randn(batch_size, num_layers, seq_len, seq_len) attention_mask = torch.ones(batch_size, seq_len) y = pooler(emb=emb, attention_mask=attention_mask, attentions=attentions) print(y.shape) """Shared attention infrastructure for all FastPLMs models. Contains: AttentionBackend enum, backend resolution, mask creation, flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities. 
"""
from enum import Enum
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange

# flex_attention is optional: on torch builds without it these names are None,
# and callers check before use (resolve_attention_backend / get_attention_mask).
try:
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask
except ImportError:
    create_block_mask = None
    flex_attention = None
    BlockMask = None

# Lazily built torch.compile wrapper around flex_attention, cached for the process.
_compiled_flex_attention = None


def _get_flex_attention_fn():
    """Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.

    Returns None when flex attention is unavailable. The compiled wrapper is
    created on first use with ``dynamic=False`` and cached in
    ``_compiled_flex_attention``. Setting ``_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG``
    to a truthy value on the ``torch.nn.attention.flex_attention`` module makes
    this return the eager function instead.
    """
    global _compiled_flex_attention
    if flex_attention is None:
        return None
    flex_mod = torch.nn.attention.flex_attention
    if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
        return flex_attention
    if _compiled_flex_attention is None:
        _compiled_flex_attention = torch.compile(
            flex_attention,
            dynamic=False,
        )
    return _compiled_flex_attention


### Kernels Flash Attention Detection

def _infer_kernels_flash_variant(kernel) -> Optional[str]:
    """Classify a loaded kernel module by the entry points it exposes.

    Returns "flash_attn2" if it has ``fwd`` and ``varlen_fwd``, "flash_attn3" if it
    has ``flash_attn_func`` and ``flash_attn_varlen_func``, otherwise None.
    """
    if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
        return "flash_attn2"
    if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
        return "flash_attn3"
    return None


def _try_get_kernels_flash():
    """Try to load a flash-attention kernel via the optional ``kernels`` package.

    Tries "kernels-community/flash-attn3" first, then "kernels-community/flash-attn2".
    Any failure, including an unrecognized API, falls through to the next option.

    Returns:
        ``(kernel, variant)`` on success, or ``(None, None)`` when ``kernels`` is not
        installed or neither kernel could be loaded.
    """
    try:
        from kernels import get_kernel
    except ImportError:
        return None, None

    flash_kernel = None
    flash_kernel_variant = None
    try:
        flash_kernel = get_kernel("kernels-community/flash-attn3")
        flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel)
        assert flash_kernel_variant is not None, "Loaded flash-attn3 kernel does not expose a supported API."
    except Exception:
        try:
            flash_kernel = get_kernel("kernels-community/flash-attn2")
            flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel)
            assert flash_kernel_variant is not None, "Loaded flash-attn2 kernel does not expose a supported API."
        except Exception:
            flash_kernel = None
            flash_kernel_variant = None
    return flash_kernel, flash_kernel_variant


# Module-level lazy-load state for the flash kernel; populated by _ensure_flash_kernels_loaded().
_FLASH_KERNELS_LOADED = False
FLASH_KERNEL = None
FLASH_KERNEL_VARIANT = None


def _ensure_flash_kernels_loaded():
    """Load the flash kernel at most once per process.

    The loaded flag is set before the attempt, so a failed load (FLASH_KERNEL stays
    None) is not retried on later calls.
    """
    global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT
    if _FLASH_KERNELS_LOADED:
        return
    _FLASH_KERNELS_LOADED = True
    FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash()


def _kernels_flash_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """Dense (non-varlen) flash attention via the loaded kernel; returns the attention output only."""
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    if FLASH_KERNEL_VARIANT == "flash_attn2":
        return FLASH_KERNEL.fwd(q=query_states, k=key_states, v=value_states, is_causal=causal)[0]
    if FLASH_KERNEL_VARIANT == "flash_attn3":
        try:
            output = FLASH_KERNEL.flash_attn_func(q=query_states, k=key_states, v=value_states, causal=causal)
        except TypeError:
            # Fallback for builds whose signature rejects these keywords.
            # NOTE(review): the meaning of the positional 0.0 / None depends on that
            # build's signature -- confirm against the kernel API.
            output = FLASH_KERNEL.flash_attn_func(query_states, key_states, value_states, 0.0, None, causal)
        # Some builds return (out, softmax_lse, ...); keep only the output.
        if isinstance(output, tuple):
            return output[0]
        return output
    raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")


def _kernels_flash_varlen_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_in_batch_q: int,
    max_seqlen_in_batch_k: int,
    causal: bool = False,
) -> torch.Tensor:
    """Variable-length flash attention on packed (unpadded) tokens.

    ``cu_seqlens_*`` are cumulative sequence offsets as produced by ``_unpad_input``.
    Returns the packed attention output only.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    if FLASH_KERNEL_VARIANT == "flash_attn2":
        return FLASH_KERNEL.varlen_fwd(
            q=query_states,
            k=key_states,
            v=value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            is_causal=causal,
        )[0]
    if FLASH_KERNEL_VARIANT == "flash_attn3":
        try:
            output = FLASH_KERNEL.flash_attn_varlen_func(
                q=query_states,
                k=key_states,
                v=value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                causal=causal,
            )
        except TypeError:
            # Positional fallback; NOTE(review): argument meaning after the
            # max-seqlens depends on the kernel build's signature -- confirm.
            output = FLASH_KERNEL.flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_in_batch_q,
                max_seqlen_in_batch_k,
                0.0,
                None,
                causal,
            )
        if isinstance(output, tuple):
            return output[0]
        return output
    raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}")


### Unpad / Pad helpers for varlen flash attention

class IndexFirstAxis(torch.autograd.Function):
    """Differentiable ``input[indices]`` along the first axis.

    Uses a gather over the flattened trailing dims. Backward scatters gradients
    back into a zero tensor with the original first-axis size.
    """

    @staticmethod
    def forward(ctx, input, indices) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        return torch.gather(
            rearrange(input, "b ... -> b (...)"), 0, indices.unsqueeze(1).expand(-1, second_dim)
        ).reshape(-1, *other_shape)

    @staticmethod
    def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]:
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        grad_output = rearrange(grad_output, "b ... -> b (...)")
        grad_input = torch.zeros(
            [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype
        )
        grad_input.scatter_(0, indices.unsqueeze(1).expand(-1, grad_output.shape[1]), grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


class IndexPutFirstAxis(torch.autograd.Function):
    """Differentiable scatter of ``values`` into rows ``indices`` of a zero tensor with ``first_axis_dim`` rows."""

    @staticmethod
    def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor:
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype)
        output[indices] = values
        return output

    @staticmethod
    def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]:
        (indices,) = ctx.saved_tensors
        return grad_output[indices], None, None


index_first_axis = IndexFirstAxis.apply
index_put_first_axis = IndexPutFirstAxis.apply


def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    """Scatter packed tokens back to a zero-padded (batch, seqlen, ...) tensor."""
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(output, "(b s) ... -> b s ...", b=batch)


def _unpad_input(
    query_layer: torch.Tensor,
    key_layer: torch.Tensor,
    value_layer: torch.Tensor,
    attention_mask_2d: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
    """Pack the valid tokens of q/k/v for varlen flash attention.

    q, k and v are all reshaped with the query's (batch, seq_len, num_heads,
    head_dim) shape, so they must share it and share ``attention_mask_2d``.

    Returns:
        (q, k, v packed to (total_valid, num_heads, head_dim), flat indices of the
        valid positions, (cu_seqlens, cu_seqlens) as int32 offsets starting at 0,
        (max_seqlen, max_seqlen)).
    """
    batch_size, seq_len, num_heads, head_dim = query_layer.shape
    seqlens = attention_mask_2d.sum(dim=1).int()
    cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
    # .item() forces a host sync; the kernels take max_seqlen as a Python int.
    max_seqlen = int(seqlens.max().item())
    indices = attention_mask_2d.flatten().nonzero(as_tuple=False).flatten()
    query_layer = index_first_axis(query_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices)
    key_layer = index_first_axis(key_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices)
    value_layer = index_first_axis(value_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices)
    return query_layer, key_layer, value_layer, indices, (cu_seqlens, cu_seqlens), (max_seqlen, max_seqlen)


def kernels_flash_attention_func(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask_2d: Optional[torch.Tensor] = None,
    causal: bool = False,
) -> torch.Tensor:
    """Flash attention through the loaded kernel, with padding handled by unpad/repad.

    Non-causal calls with a 2-D padding mask are packed, run through the varlen
    kernel, and scattered back to the padded layout (padding rows are zero).
    Otherwise the dense kernel is used and ``attention_mask_2d`` is ignored.
    NOTE(review): with ``causal=True`` the padding mask is not applied. That is
    only correct when padding sits after the real tokens (right padding) -- confirm
    for all callers.
    """
    assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
    if not causal and attention_mask_2d is not None:
        batch_size, q_len = query_states.shape[:2]
        (
            query_states, key_states, value_states,
            indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k),
        ) = _unpad_input(query_states, key_states, value_states, attention_mask_2d)
        attn_output_unpad = _kernels_flash_varlen_forward(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_in_batch_q=max_seqlen_q,
            max_seqlen_in_batch_k=max_seqlen_k,
        )
        return pad_input(attn_output_unpad, indices_q, batch_size, q_len)
    else:
        return _kernels_flash_forward(
            query_states=query_states,
            key_states=key_states,
            value_states=value_states,
            causal=causal,
        )


### Attention Backend Enum & Resolution

class AttentionBackend(Enum):
    """Attention implementations selectable via the ``attn_backend`` config string."""
    AUTO = "auto"
    KERNELS_FLASH = "kernels_flash"
    FLEX = "flex"
    SDPA = "sdpa"


VALID_ATTENTION_BACKENDS = tuple(b.value for b in AttentionBackend)

# Ensures the resolved backend is printed only once per process.
_BACKEND_CONFIRMED = False


def resolve_attention_backend(requested_backend: str) -> AttentionBackend:
    """Map a backend config string to a concrete, available AttentionBackend.

    "auto" prefers the flash kernel, then flex attention, then SDPA. Explicit
    "kernels_flash" / "flex" requests assert that the implementation is available.
    The flash kernel is lazily loaded only for "auto" and "kernels_flash".

    Raises:
        AssertionError: for unknown names or unavailable explicit backends.
    """
    global _BACKEND_CONFIRMED
    assert requested_backend in VALID_ATTENTION_BACKENDS, (
        f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}."
    )
    if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value):
        _ensure_flash_kernels_loaded()

    if requested_backend == AttentionBackend.AUTO.value:
        if FLASH_KERNEL is not None:
            resolved = AttentionBackend.KERNELS_FLASH
        elif flex_attention is not None:
            resolved = AttentionBackend.FLEX
        else:
            resolved = AttentionBackend.SDPA
    elif requested_backend == AttentionBackend.KERNELS_FLASH.value:
        assert FLASH_KERNEL is not None, "Kernels Flash Attention is not available in this environment."
        resolved = AttentionBackend.KERNELS_FLASH
    elif requested_backend == AttentionBackend.FLEX.value:
        assert flex_attention is not None, "Flex Attention is not available in this environment."
        resolved = AttentionBackend.FLEX
    elif requested_backend == AttentionBackend.SDPA.value:
        resolved = AttentionBackend.SDPA
    else:
        raise AssertionError(f"Unsupported attention backend: {requested_backend}")

    if not _BACKEND_CONFIRMED:
        print(f"Attention backend: config='{requested_backend}' -> resolved='{resolved.value}'")
        _BACKEND_CONFIRMED = True
    return resolved


@torch.compiler.disable
def get_attention_mask(
    effective_backend: AttentionBackend,
    batch_size: int,
    seq_len: int,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]:
    """Build padding masks once for all encoder layers.

    Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).

    - No mask given: (None, None, None).
    - KERNELS_FLASH: (bool 2-D mask, None, None).
    - FLEX: (bool 2-D mask, None, BlockMask). The block mask allows
      (q, kv) only when both indices are below that sample's valid-token count,
      which assumes right padding (valid tokens form a prefix) -- TODO confirm.
    - Otherwise (SDPA/manual): (bool 2-D mask, bool mask of shape (batch, 1, 1, seq), None).

    Decorated with ``torch.compiler.disable`` so mask construction runs eagerly.
    """
    if attention_mask is None:
        return None, None, None

    attention_mask_2d = attention_mask.bool()

    if effective_backend == AttentionBackend.KERNELS_FLASH:
        return attention_mask_2d, None, None

    if effective_backend == AttentionBackend.FLEX:
        assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
        valid_lens = attention_mask_2d.sum(dim=-1)

        def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
            return (q_idx < valid_lens[batch_idx]) & (kv_idx < valid_lens[batch_idx])

        flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device)
        return attention_mask_2d, None, flex_block_mask

    # SDPA / manual -- only mask the key dimension so padding query positions attend to
    # real keys and produce valid (non-NaN) outputs instead of NaN from softmax(-inf,...,-inf).
    attention_mask_4d = attention_mask_2d[:, None, None, :]
    return attention_mask_2d, attention_mask_4d, None


# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0
"""
FastPLMs-compatible DPLM implementation.
""" import torch import torch.nn as nn from torch.nn import functional as F from dataclasses import dataclass from typing import List, Optional, Tuple, Union from einops import rearrange from transformers import EsmTokenizer from transformers.modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions, ModelOutput, SequenceClassifierOutput, TokenClassifierOutput, ) from transformers.models.esm.configuration_esm import EsmConfig from transformers.models.esm.modeling_esm import ( EsmAttention, EsmClassificationHead, EsmContactPredictionHead, EsmEmbeddings, EsmEncoder, EsmIntermediate, EsmLayer, EsmLMHead, EsmOutput, EsmPooler, EsmPreTrainedModel, EsmSelfAttention, EsmSelfOutput, ) @dataclass class DPLMMaskedLMOutput(ModelOutput): loss: Optional[torch.Tensor] = None logits: Optional[torch.Tensor] = None last_hidden_state: Optional[torch.Tensor] = None hidden_states: Optional[Tuple[torch.Tensor, ...]] = None attentions: Optional[Tuple[torch.Tensor, ...]] = None s_max: Optional[Tuple[List[torch.Tensor], ...]] = None @dataclass class DPLMEncoderOutput(ModelOutput): last_hidden_state: Optional[torch.Tensor] = None hidden_states: Optional[Tuple[torch.Tensor, ...]] = None attentions: Optional[Tuple[torch.Tensor, ...]] = None s_max: Optional[Tuple[List[torch.Tensor], ...]] = None class DPLMConfig(EsmConfig): model_type = "dplm" def __init__( self, attn_backend: str = "sdpa", **kwargs, ): super().__init__(**kwargs) self.attn_backend = attn_backend self.tie_word_embeddings = False class DPLMPreTrainedModel(EsmPreTrainedModel): config_class = DPLMConfig base_model_prefix = "dplm" supports_gradient_checkpointing = True tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D") all_tied_weights_keys = {} @classmethod def is_remote_code(cls) -> bool: # Prevent post-load reinitialization of tensors already loaded from checkpoints. 
return True @property def attn_backend(self) -> str: return self.config.attn_backend @attn_backend.setter def attn_backend(self, backend: str) -> None: assert backend in VALID_ATTENTION_BACKENDS, f"Unsupported attn_backend: {backend}. Expected one of {VALID_ATTENTION_BACKENDS}." self.config.attn_backend = backend resolved = resolve_attention_backend(backend) for module in self.modules(): if isinstance(module, ModifiedEsmEncoder): module.attention_backend = resolved elif isinstance(module, ModifiedEsmSelfAttention): module.attn_backend = resolved class ModifiedEsmSelfAttention(EsmSelfAttention): def __init__(self, config, position_embedding_type=None): super().__init__(config, position_embedding_type) self.config = config self.scale = self.attention_head_size**-0.5 self.attn_backend = resolve_attention_backend(config.attn_backend) def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) def forward( self, hidden_states: torch.Tensor, attention_mask_2d: Optional[torch.Tensor] = None, attention_mask_4d: Optional[torch.Tensor] = None, flex_block_mask: Optional[object] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, output_s_max: Optional[bool] = False, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]: if past_key_values is not None: past_key_value = past_key_values mixed_query_layer = self.query(hidden_states) is_cross_attention = encoder_hidden_states is not None if is_cross_attention and past_key_value is not None: key_layer = past_key_value[0] value_layer = past_key_value[1] cross_attn_mask = 
encoder_attention_mask elif is_cross_attention: key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) cross_attn_mask = encoder_attention_mask elif past_key_value is not None: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) key_layer = torch.cat([past_key_value[0], key_layer], dim=2) value_layer = torch.cat([past_key_value[1], value_layer], dim=2) cross_attn_mask = None else: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) cross_attn_mask = None query_layer = self.transpose_for_scores(mixed_query_layer) * self.scale if self.position_embedding_type == "rotary": query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer) if self.position_embedding_type in ["relative_key", "relative_key_query"]: raise NotImplementedError query_layer = query_layer.contiguous() key_layer = key_layer.contiguous() value_layer = value_layer.contiguous() if is_cross_attention: if output_attentions: attn_output, attn_weights, s_max = self._manual_attn( query_layer, key_layer, value_layer, cross_attn_mask, output_s_max, ) else: attn_output, attn_weights = self._sdpa_attn( query_layer, key_layer, value_layer, cross_attn_mask, ) s_max = self._compute_s_max(query_layer, key_layer) if output_s_max else None else: attn_output, attn_weights, s_max = self._attn( query_layer, key_layer, value_layer, attention_mask_2d=attention_mask_2d, attention_mask_4d=attention_mask_4d, flex_block_mask=flex_block_mask, output_attentions=output_attentions, output_s_max=output_s_max, ) if head_mask is not None and torch.is_tensor(head_mask): batch_size, seq_len, _ = attn_output.shape attn_output = attn_output.view(batch_size, seq_len, self.num_attention_heads, self.attention_head_size) attn_output = attn_output.permute(0, 2, 1, 3) * head_mask attn_output 
= rearrange(attn_output, "b h s d -> b s (h d)") return attn_output, attn_weights, s_max def _attn( self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor, value_BHLD: torch.Tensor, attention_mask_2d: Optional[torch.Tensor] = None, attention_mask_4d: Optional[torch.Tensor] = None, flex_block_mask: Optional[BlockMask] = None, output_attentions: bool = False, output_s_max: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]: if output_attentions: return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max) if self.attn_backend == AttentionBackend.KERNELS_FLASH: attn_output, attn_weights = self._kernels_flash_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_2d) elif self.attn_backend == AttentionBackend.FLEX: attn_output, attn_weights = self._flex_attn(query_BHLD, key_BHLD, value_BHLD, flex_block_mask) elif self.attn_backend == AttentionBackend.SDPA: attn_output, attn_weights = self._sdpa_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d) else: raise AssertionError(f"Unsupported resolved backend: {self.attn_backend}") s_max = self._compute_s_max(query_BHLD, key_BHLD) if output_s_max else None return attn_output, attn_weights, s_max @torch.no_grad() def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) -> List[torch.Tensor]: q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1) k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1) s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values return [s_max_bound[h] for h in range(self.num_attention_heads)] def _manual_attn( self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor, value_BHLD: torch.Tensor, attention_mask_4d: Optional[torch.Tensor] = None, output_s_max: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, Optional[List[torch.Tensor]]]: attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2)) if attention_mask_4d is not None: attn_weights = 
attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf")) attn_weights = F.softmax(attn_weights, dim=-1) context_BHLD = torch.matmul(attn_weights, value_BHLD) attn_output = rearrange(context_BHLD, "b h s d -> b s (h d)") s_max = self._compute_s_max(query_BHLD, key_BHLD) if output_s_max else None return attn_output, attn_weights, s_max def _kernels_flash_attn( self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor, value_BHLD: torch.Tensor, attention_mask_2d: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, None]: query_BLHD = query_BHLD.transpose(1, 2).contiguous() key_BLHD = key_BHLD.transpose(1, 2).contiguous() value_BLHD = value_BHLD.transpose(1, 2).contiguous() attn_output = kernels_flash_attention_func( query_states=query_BLHD, key_states=key_BLHD, value_states=value_BLHD, attention_mask_2d=attention_mask_2d, causal=False, ) return rearrange(attn_output, "b s h d -> b s (h d)"), None def _flex_attn( self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor, value_BHLD: torch.Tensor, flex_block_mask: Optional[BlockMask] = None, ) -> Tuple[torch.Tensor, None]: assert flex_attention is not None, "Flex attention is not available in this environment." 
fn = _get_flex_attention_fn() context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0) return rearrange(context_BHLD, "b h s d -> b s (h d)"), None def _sdpa_attn( self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor, value_BHLD: torch.Tensor, attention_mask_4d: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, None]: context_BHLD = F.scaled_dot_product_attention( query_BHLD, key_BHLD, value_BHLD, attn_mask=attention_mask_4d, scale=1.0, ) return rearrange(context_BHLD, "b h s d -> b s (h d)"), None class ModifiedEsmAttention(EsmAttention): def __init__(self, config): nn.Module.__init__(self) self.self = ModifiedEsmSelfAttention(config) self.output = EsmSelfOutput(config) self.pruned_heads = set() self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask_2d: Optional[torch.Tensor] = None, attention_mask_4d: Optional[torch.Tensor] = None, flex_block_mask: Optional[object] = None, head_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: bool = False, output_s_max: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]: hidden_states_ln = self.LayerNorm(hidden_states) attn_output, attn_weights, s_max = self.self( hidden_states_ln, attention_mask_2d=attention_mask_2d, attention_mask_4d=attention_mask_4d, flex_block_mask=flex_block_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, output_s_max=output_s_max, ) attention_output = self.output(attn_output, hidden_states) return attention_output, attn_weights, s_max class ModifiedEsmLayer(EsmLayer): def __init__(self, config): nn.Module.__init__(self) 
self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 self.attention = ModifiedEsmAttention(config) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if self.is_decoder is False: raise RuntimeError(f"{self} should be used as a decoder model if cross attention is added") self.crossattention = ModifiedEsmAttention(config) self.intermediate = EsmIntermediate(config) self.output = EsmOutput(config) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) def forward( self, hidden_states: torch.Tensor, attention_mask_2d: Optional[torch.Tensor] = None, attention_mask_4d: Optional[torch.Tensor] = None, flex_block_mask: Optional[object] = None, head_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, output_attentions: bool = False, output_s_max: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]: attention_output, attn_weights, s_max = self.attention( hidden_states, attention_mask_2d=attention_mask_2d, attention_mask_4d=attention_mask_4d, flex_block_mask=flex_block_mask, head_mask=head_mask, output_attentions=output_attentions, output_s_max=output_s_max, past_key_value=past_key_value[:2] if past_key_value is not None else None, ) if self.is_decoder and encoder_hidden_states is not None: if self.add_cross_attention is False: raise AttributeError( f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention " "layers by setting `config.add_cross_attention=True`" ) cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_output, _, _ = self.crossattention( attention_output, attention_mask_2d=attention_mask_2d, attention_mask_4d=attention_mask_4d, head_mask=head_mask, 
encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, past_key_value=cross_attn_past_key_value, output_attentions=output_attentions, output_s_max=False, ) attention_output = cross_attention_output layer_output = self.feed_forward_chunk(attention_output) return layer_output, attn_weights, s_max class ModifiedEsmEncoder(EsmEncoder): def __init__(self, config): nn.Module.__init__(self) self.config = config self.attention_backend = resolve_attention_backend(config.attn_backend) self.layer = nn.ModuleList([ModifiedEsmLayer(config) for _ in range(config.num_hidden_layers)]) self.emb_layer_norm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.gradient_checkpointing = False def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[Tuple[Tuple[torch.FloatTensor]]]] = None, use_cache: Optional[bool] = None, output_attentions: bool = False, output_hidden_states: bool = False, output_s_max: bool = False, ) -> DPLMEncoderOutput: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None full_s_max = () if output_s_max else None attention_mask_2d, attention_mask_4d, flex_block_mask = get_attention_mask( effective_backend=self.attention_backend, batch_size=hidden_states.shape[0], seq_len=hidden_states.shape[1], device=hidden_states.device, attention_mask=attention_mask, ) for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None if self.gradient_checkpointing and self.training: hidden_states, attn_weights, s_max = 
self._gradient_checkpointing_func( layer_module.__call__, hidden_states, attention_mask_2d, attention_mask_4d, flex_block_mask, layer_head_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, output_s_max, ) else: hidden_states, attn_weights, s_max = layer_module( hidden_states, attention_mask_2d=attention_mask_2d, attention_mask_4d=attention_mask_4d, flex_block_mask=flex_block_mask, head_mask=layer_head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, output_attentions=output_attentions, output_s_max=output_s_max, ) if all_self_attentions is not None: all_self_attentions = all_self_attentions + (attn_weights,) if full_s_max is not None: full_s_max = full_s_max + (s_max,) if self.emb_layer_norm_after: hidden_states = self.emb_layer_norm_after(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) return DPLMEncoderOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions, s_max=full_s_max, ) class FAST_DPLM_ENCODER(DPLMPreTrainedModel, EmbeddingMixin): """Inner encoder class that holds the actual ESM-style weights (embeddings, encoder, contact_head) so that the weight keys are prefixed with 'esm.' 
in the outer DPLMModel, matching pretrained DPLM checkpoints.""" def __init__(self, config, **kwargs): DPLMPreTrainedModel.__init__(self, config, **kwargs) self.config = config self.embeddings = EsmEmbeddings(config) self.encoder = ModifiedEsmEncoder(config) self.contact_head = EsmContactPredictionHead( in_features=config.num_hidden_layers * config.num_attention_heads, bias=True, ) self.post_init() def get_input_embeddings(self) -> nn.Module: return self.embeddings.word_embeddings def set_input_embeddings(self, value): self.embeddings.word_embeddings = value def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: if attention_mask is None: attention_mask = input_ids.ne(self.config.pad_token_id) embedding_output = self.embeddings(input_ids, attention_mask=attention_mask) encoder_outputs = self.encoder( embedding_output, attention_mask=attention_mask, output_hidden_states=False, output_attentions=False, ) return encoder_outputs.last_hidden_state def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: attns = self(input_ids, attention_mask=attention_mask, output_attentions=True).attentions attns = torch.stack(attns, dim=1) attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(3) attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4) return self.contact_head(input_ids, attns) def _convert_head_mask_to_5d(self, head_mask: torch.Tensor, num_hidden_layers: int) -> torch.Tensor: if head_mask.dim() == 1: head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1) elif head_mask.dim() == 2: head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) assert head_mask.dim() == 5, f"head_mask.dim != 5, got {head_mask.dim()}" head_mask = head_mask.to(dtype=self.dtype) return head_mask def get_head_mask( self, head_mask: Optional[torch.Tensor], num_hidden_layers: int, is_attention_chunked: 
bool = False, ) -> Union[torch.Tensor, List[None]]: if head_mask is None: return [None] * num_hidden_layers head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) if is_attention_chunked: head_mask = head_mask.unsqueeze(-1) return head_mask def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_s_max: Optional[bool] = False, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], DPLMEncoderOutput]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states return_dict = return_dict if return_dict is not None else self.config.use_return_dict if self.config.is_decoder: use_cache = use_cache if use_cache is not None else self.config.use_cache else: use_cache = False if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") if input_ids is not None: input_shape = input_ids.size() elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] else: raise ValueError("You have to specify either input_ids or inputs_embeds") batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device if attention_mask is None: attention_mask_2d = torch.ones((batch_size, seq_length), device=device).bool() elif attention_mask.dim() == 2: attention_mask_2d = attention_mask.bool() 
elif attention_mask.dim() == 4: assert input_ids is not None, "4D attention_mask requires input_ids to infer token-level mask." attention_mask_2d = input_ids.ne(self.config.pad_token_id) else: raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}") encoder_extended_attention_mask = encoder_attention_mask if self.config.is_decoder and encoder_hidden_states is not None: encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) if encoder_attention_mask is None: encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) embedding_output = self.embeddings( input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask_2d, inputs_embeds=inputs_embeds, ) encoder_outputs = self.encoder( embedding_output, attention_mask=attention_mask_2d, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_s_max=output_s_max, ) sequence_output = encoder_outputs.last_hidden_state if return_dict is False: return (sequence_output,) + encoder_outputs[1:] return DPLMEncoderOutput( last_hidden_state=sequence_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, s_max=encoder_outputs.s_max, ) class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin): config_class = DPLMConfig def __init__(self, config, add_pooling_layer=True): DPLMPreTrainedModel.__init__(self, config) self.config = config self.esm = FAST_DPLM_ENCODER(config) self.pooler = EsmPooler(config) if add_pooling_layer else None self.post_init() def get_input_embeddings(self) -> nn.Module: return 
self.esm.embeddings.word_embeddings def set_input_embeddings(self, value): self.esm.embeddings.word_embeddings = value def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: return self.esm._embed(input_ids, attention_mask) def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: return self.esm.predict_contacts(input_ids, attention_mask) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_s_max: Optional[bool] = False, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], DPLMEncoderOutput]: outputs = self.esm( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_s_max=output_s_max, return_dict=return_dict, ) sequence_output = outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None return_dict = return_dict if return_dict is not None else self.config.use_return_dict if return_dict is False: return (sequence_output, pooled_output) + outputs[1:] return DPLMEncoderOutput( last_hidden_state=sequence_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions, s_max=outputs.s_max, ) class DPLMForMaskedLM(DPLMPreTrainedModel, 
EmbeddingMixin): config_class = DPLMConfig def __init__(self, config, dropout: float = 0.1): config.hidden_dropout_prob = dropout DPLMPreTrainedModel.__init__(self, config) self.esm = FAST_DPLM_ENCODER(config) self.lm_head = EsmLMHead(config) self.loss_fct = nn.CrossEntropyLoss() self.post_init() self.tokenizer = self.__class__.tokenizer if isinstance(config._name_or_path, str) and len(config._name_or_path) > 0: try: self.tokenizer = EsmTokenizer.from_pretrained(config._name_or_path) except Exception: self.tokenizer = self.__class__.tokenizer self.mask_id = self.tokenizer.mask_token_id self.pad_id = self.tokenizer.pad_token_id self.bos_id = self.tokenizer.cls_token_id self.eos_id = self.tokenizer.eos_token_id self.x_id = self.tokenizer.convert_tokens_to_ids("X") self.contact_head = None def get_input_embeddings(self) -> nn.Module: return self.esm.get_input_embeddings() def get_output_embeddings(self): return self.lm_head.decoder def set_output_embeddings(self, new_embeddings): self.lm_head.decoder = new_embeddings def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: return self.esm._embed(input_ids, attention_mask) def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: return self.esm.predict_contacts(input_ids, attention_mask=attention_mask) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, decoder_input_ids: Optional[torch.Tensor] = None, decoder_attention_mask: Optional[torch.Tensor] = None, decoder_inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_s_max: Optional[bool] = False, return_dict: Optional[bool] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, ) -> 
Union[Tuple[torch.Tensor], DPLMMaskedLMOutput]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict if attention_mask is None and input_ids is not None: attention_mask = input_ids.ne(self.pad_id) outputs = self.esm( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_s_max=output_s_max, return_dict=True, ) sequence_output = outputs.last_hidden_state logits = self.lm_head(sequence_output) loss = None if labels is not None: labels = labels.to(logits.device) loss = self.loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) if return_dict is False: output = (logits, sequence_output, outputs.hidden_states, outputs.attentions) if loss is not None: return (loss,) + output return output return DPLMMaskedLMOutput( loss=loss, logits=logits, last_hidden_state=sequence_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions, s_max=outputs.s_max, ) class DPLMForSequenceClassification(DPLMPreTrainedModel, EmbeddingMixin): config_class = DPLMConfig def get_input_embeddings(self) -> nn.Module: return self.esm.get_input_embeddings() def __init__(self, config): DPLMPreTrainedModel.__init__(self, config) self.num_labels = config.num_labels self.esm = FAST_DPLM_ENCODER(config) self.classifier = EsmClassificationHead(config) self.mse = nn.MSELoss() self.ce = nn.CrossEntropyLoss() self.bce = nn.BCEWithLogitsLoss() self.post_init() def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: return self.esm._embed(input_ids, attention_mask) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, 
output_hidden_states: Optional[bool] = None, output_s_max: Optional[bool] = False, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple[torch.Tensor], DPLMMaskedLMOutput]: outputs = self.esm( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_s_max=output_s_max, return_dict=True, ) sequence_output = outputs.last_hidden_state logits = self.classifier(sequence_output) loss = None if labels is not None: labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": if self.num_labels == 1: loss = self.mse(logits.squeeze(), labels.squeeze()) else: loss = self.mse(logits, labels) elif self.config.problem_type == "single_label_classification": loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss = self.bce(logits, labels) return DPLMMaskedLMOutput( loss=loss, logits=logits, last_hidden_state=sequence_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions, s_max=outputs.s_max, ) class DPLMForTokenClassification(DPLMPreTrainedModel, EmbeddingMixin): config_class = DPLMConfig def get_input_embeddings(self) -> nn.Module: return self.esm.get_input_embeddings() def __init__(self, config): DPLMPreTrainedModel.__init__(self, config) self.num_labels = config.num_labels self.esm = FAST_DPLM_ENCODER(config) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.loss_fct = nn.CrossEntropyLoss() self.post_init() def _embed(self, input_ids: torch.Tensor, attention_mask: 
Optional[torch.Tensor] = None) -> torch.Tensor: return self.esm._embed(input_ids, attention_mask) def forward( self, input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_s_max: Optional[bool] = False, return_dict: Optional[bool] = None, **kwargs, ) -> Union[Tuple[torch.Tensor], DPLMMaskedLMOutput]: outputs = self.esm( input_ids=input_ids, attention_mask=attention_mask, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_s_max=output_s_max, return_dict=True, ) sequence_output = self.dropout(outputs.last_hidden_state) logits = self.classifier(sequence_output) loss = None if labels is not None: labels = labels.to(logits.device) loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) return DPLMMaskedLMOutput( loss=loss, logits=logits, last_hidden_state=sequence_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions, s_max=outputs.s_max, ) if __name__ == "__main__": import random import torch from torch import Tensor from transformers import EsmTokenizer def print_tensor_shapes(prefix: str, obj): if isinstance(obj, Tensor): print(f"{prefix}{obj.shape}") elif isinstance(obj, dict): for name, value in obj.items(): print_tensor_shapes(f"{prefix}{name}.", value) elif isinstance(obj, list): for idx, value in enumerate(obj): print_tensor_shapes(f"{prefix}[{idx}].", value) elif isinstance(obj, tuple): for idx, value in enumerate(obj): print_tensor_shapes(f"{prefix}[{idx}].", value) elif hasattr(obj, "__dict__"): for name, value in vars(obj).items(): if name.startswith("_"): continue print_tensor_shapes(f"{prefix}{name}.", value) else: print(f"{prefix}{type(obj)}") random.seed(0) torch.manual_seed(0) num_attention_heads = random.choice([2, 4]) config = DPLMConfig( 
hidden_size=16 * num_attention_heads, num_attention_heads=num_attention_heads, num_hidden_layers=random.choice([1, 2]), attention_probs_dropout_prob=0.0, hidden_dropout_prob=0.0, attn_backend="sdpa", ) tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D") batch = tokenizer(["ACDEFG", "MKTW"], return_tensors="pt", padding="longest") batch["labels"] = batch["input_ids"].clone() model = DPLMForMaskedLM(config=config).eval() with torch.no_grad(): output = model(**batch, return_dict=True) print("Batch shape:") print_tensor_shapes("", batch) print("Output shape:") print_tensor_shapes("", output)