from __future__ import annotations import torch import torch._inductor.config as inductor_config import torch._dynamo as dynamo # Enable TensorFloat32 tensor cores for float32 matmul (Ampere+ GPUs) # Provides significant speedup with minimal precision loss torch.set_float32_matmul_precision('high') # Enable TF32 for matrix multiplications and cuDNN operations torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True # Enable cuDNN autotuner - finds fastest algorithms for your hardware # Best when input sizes are consistent; may slow down first iterations torch.backends.cudnn.benchmark = True # Deterministic operations off for speed (set True if reproducibility needed) torch.backends.cudnn.deterministic = False inductor_config.max_autotune_gemm_backends = "ATEN,CUTLASS,FBGEMM" dynamo.config.capture_scalar_outputs = True torch._dynamo.config.recompile_limit = 16 import io import os import queue import sqlite3 import struct import threading import time import networkx as nx import numpy as np import torch from tqdm.auto import tqdm from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple from torch.utils.data import DataLoader from torch.utils.data import Dataset as TorchDataset from transformers import PreTrainedTokenizerBase # Compact blob serialization constants # Canonical source: core/embed/blob.py. Keep in sync with protify/utils.py. _COMPACT_VERSION = 0x01 _DTYPE_TO_CODE = {torch.float16: 0, torch.bfloat16: 1, torch.float32: 2} _CODE_TO_DTYPE = {0: torch.float16, 1: torch.bfloat16, 2: torch.float32} _CODE_TO_NP_DTYPE = {0: np.float16, 1: np.float16, 2: np.float32} def tensor_to_embedding_blob(tensor: torch.Tensor) -> bytes: """Serialize a tensor to compact binary format for SQLite blob storage. Format: [version:1][dtype_code:1][ndim:4][shape:4*ndim][raw_bytes] bfloat16 tensors are stored as float16 bytes (numpy lacks bfloat16) but tagged with dtype_code=1 so they can be cast back on read. Falls back to torch.save for unsupported dtypes. """ t = tensor.cpu() if t.dtype not in _DTYPE_TO_CODE: buffer = io.BytesIO() torch.save(t, buffer) return buffer.getvalue() dtype_code = _DTYPE_TO_CODE[t.dtype] if t.dtype == torch.bfloat16: raw = t.half().numpy().tobytes() else: raw = t.numpy().tobytes() shape = t.shape header = struct.pack(f' bytes: """Build just the compact header for a given dtype and shape.""" dtype_code = _DTYPE_TO_CODE[dtype] return struct.pack(f' List[bytes]: """Serialize a batch of same-shape tensors to compact blobs (fast path for vectors). Builds the header once and slices raw bytes per row. Much faster than per-row tensor_to_embedding_blob calls for uniform-shape batches. """ assert batch.ndim >= 2, f"Expected batch with >= 2 dims, got {batch.ndim}" t = batch.cpu() store_dtype = t.dtype if t.dtype not in _DTYPE_TO_CODE: return [tensor_to_embedding_blob(t[i]) for i in range(t.shape[0])] if t.dtype == torch.bfloat16: arr = t.half().numpy() store_dtype = torch.bfloat16 else: arr = t.numpy() row_shape = tuple(t.shape[1:]) header = _compact_header(store_dtype, row_shape) raw = arr.tobytes() stride = len(raw) // t.shape[0] return [header + raw[i * stride:(i + 1) * stride] for i in range(t.shape[0])] def embedding_blob_to_tensor(blob: bytes, fallback_shape: Optional[Tuple[int, ...]] = None) -> torch.Tensor: """Deserialize a blob back to a tensor. Auto-detects compact vs legacy formats.""" if len(blob) >= 6 and blob[0] == _COMPACT_VERSION: dtype_code = blob[1] ndim = struct.unpack_from(' torch.nn.Module: """Compile model with torch.compile if possible. Skips compilation when dynamic=True (padding='longest') because flex attention's create_block_mask is incompatible with dynamic shapes under torch.compile, causing CUDA illegal memory access. """ if dynamic: print("Skipping torch.compile (dynamic shapes + flex attention incompatible)") return model try: model = torch.compile(model) print("Model compiled") except Exception as e: print(f"Skipping torch.compile: {e}") return model def build_collator( tokenizer: PreTrainedTokenizerBase, padding: str = 'max_length', max_length: int = 512, ) -> Callable[[List[str]], Dict[str, torch.Tensor]]: def _collate_fn(sequences: List[str]) -> Dict[str, torch.Tensor]: kwargs: Dict[str, Any] = dict( return_tensors="pt", padding=padding, truncation=True, max_length=max_length, ) if padding != 'max_length': kwargs['pad_to_multiple_of'] = 8 return tokenizer(sequences, **kwargs) return _collate_fn def _make_embedding_progress( dataloader: DataLoader, padding: str, n_warmup: int = 3, n_calibration: int = 5, ) -> Iterator[Tuple[int, Any]]: """Progress-bar wrapper for embedding loops. Drop-in replacement for enumerate(dataloader). When padding='max_length', all batches have uniform cost so plain tqdm works. When padding='longest' (sorted longest-first), batch times vary dramatically. In that case: yield warmup batches first (compiler warmup + OOM check on longest sequences), then time mid-length calibration batches to estimate total ETA. Keep in sync with protify/embedder.py and core/atlas/precomputed.py. """ total = len(dataloader) if padding == 'max_length' or total <= n_warmup + n_calibration: for i, batch in tqdm(enumerate(dataloader), total=total, desc='Embedding batches'): yield i, batch return dl_iter = iter(dataloader) # Phase 1: warmup on longest batches (first n_warmup, since sorted longest-first) warmup_bar = tqdm(range(n_warmup), desc='Warmup (longest batches)', leave=False) for i in warmup_bar: batch = next(dl_iter) yield i, batch warmup_bar.close() # Phase 2: skip to middle of dataset for calibration timing # We need to yield all intermediate batches too (they contain real data) mid_start = total // 2 intermediate_bar = tqdm( range(n_warmup, mid_start), desc='Embedding batches', leave=False, ) for i in intermediate_bar: batch = next(dl_iter) yield i, batch intermediate_bar.close() # Phase 3: time calibration batches from the middle calibration_times: List[float] = [] cal_bar = tqdm(range(n_calibration), desc='Calibrating ETA', leave=False) for j in cal_bar: t0 = time.perf_counter() batch = next(dl_iter) yield mid_start + j, batch calibration_times.append(time.perf_counter() - t0) cal_bar.close() avg_time = sum(calibration_times) / len(calibration_times) remaining_start = mid_start + n_calibration remaining_count = total - remaining_start estimated_total_seconds = avg_time * remaining_count # Phase 4: remaining batches with calibrated ETA main_bar = tqdm( range(remaining_count), desc='Embedding batches', bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]', ) main_bar.set_postfix_str(f'ETA ~{estimated_total_seconds:.0f}s (calibrated)') for k in main_bar: batch = next(dl_iter) yield remaining_start + k, batch main_bar.close() class _SQLWriter: """Context manager for async SQL embedding writes. Matches core/embed/storage.SQLEmbeddingWriter.""" def __init__(self, conn: sqlite3.Connection, queue_maxsize: int = 4) -> None: self._conn = conn self._queue: queue.Queue = queue.Queue(maxsize=queue_maxsize) self._thread: Optional[threading.Thread] = None def __enter__(self) -> "_SQLWriter": self._thread = threading.Thread(target=self._writer_loop, daemon=True) self._thread.start() return self def write_batch(self, rows: List[Tuple[str, bytes]]) -> None: self._queue.put(rows) def _writer_loop(self) -> None: cursor = self._conn.cursor() while True: item = self._queue.get() if item is None: break cursor.executemany("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", item) if self._queue.qsize() == 0: self._conn.commit() self._conn.commit() def __exit__(self, *exc) -> None: if self._thread is not None: self._queue.put(None) self._thread.join() self._thread = None class Pooler: def __init__(self, pooling_types: List[str]) -> None: self.pooling_types = pooling_types self.pooling_options: Dict[str, Callable] = { 'mean': self.mean_pooling, 'max': self.max_pooling, 'norm': self.norm_pooling, 'median': self.median_pooling, 'std': self.std_pooling, 'var': self.var_pooling, 'cls': self.cls_pooling, 'parti': self._pool_parti, } def _create_pooled_matrices_across_layers(self, attentions: torch.Tensor) -> torch.Tensor: assert isinstance(attentions, torch.Tensor) maxed_attentions = torch.max(attentions, dim=1)[0] return maxed_attentions def _page_rank(self, attention_matrix: np.ndarray, personalization: Optional[dict] = None, nstart: Optional[dict] = None, prune_type: str = "top_k_outdegree") -> Dict[int, float]: G = self._convert_to_graph(attention_matrix) if G.number_of_nodes() != attention_matrix.shape[0]: raise Exception( f"The number of nodes in the graph should be equal to the number of tokens in sequence! You have {G.number_of_nodes()} nodes for {attention_matrix.shape[0]} tokens.") if G.number_of_edges() == 0: raise Exception(f"You don't seem to have any attention edges left in the graph.") return nx.pagerank(G, alpha=0.85, tol=1e-06, weight='weight', personalization=personalization, nstart=nstart, max_iter=100) def _convert_to_graph(self, matrix: np.ndarray) -> nx.DiGraph: G = nx.from_numpy_array(matrix, create_using=nx.DiGraph) return G def _calculate_importance_weights(self, dict_importance: Dict[int, float], attention_mask: Optional[torch.Tensor] = None) -> np.ndarray: if attention_mask is not None: for k in list(dict_importance.keys()): if attention_mask[k] == 0: del dict_importance[k] total = sum(dict_importance.values()) return np.array([v / total for _, v in dict_importance.items()]) def _pool_parti(self, emb: torch.Tensor, attentions: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: maxed_attentions = self._create_pooled_matrices_across_layers(attentions).numpy() emb_pooled = [] for e, a, mask in zip(emb, maxed_attentions, attention_mask): dict_importance = self._page_rank(a) importance_weights = self._calculate_importance_weights(dict_importance, mask) num_tokens = int(mask.sum().item()) emb_pooled.append(np.average(e[:num_tokens], weights=importance_weights, axis=0)) pooled = torch.tensor(np.array(emb_pooled)) return pooled def mean_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: if attention_mask is None: return emb.mean(dim=1) else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) def max_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: if attention_mask is None: return emb.max(dim=1).values else: mask = attention_mask.unsqueeze(-1).bool() return emb.masked_fill(~mask, float('-inf')).max(dim=1).values def norm_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: if attention_mask is None: return emb.norm(dim=1, p=2) else: attention_mask = attention_mask.unsqueeze(-1) return (emb * attention_mask).norm(dim=1, p=2) def median_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: if attention_mask is None: return emb.median(dim=1).values else: mask = attention_mask.unsqueeze(-1).bool() return emb.masked_fill(~mask, float('nan')).nanmedian(dim=1).values def std_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: if attention_mask is None: return emb.std(dim=1) else: var = self.var_pooling(emb, attention_mask, **kwargs) return torch.sqrt(var) def var_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: if attention_mask is None: return emb.var(dim=1) else: attention_mask = attention_mask.unsqueeze(-1) mean = (emb * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) mean = mean.unsqueeze(1) squared_diff = (emb - mean) ** 2 var = (squared_diff * attention_mask).sum(dim=1) / attention_mask.sum(dim=1) return var def cls_pooling(self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: return emb[:, 0, :] def __call__( self, emb: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, attentions: Optional[torch.Tensor] = None ) -> torch.Tensor: if attention_mask is not None: assert attention_mask.sum(dim=-1).min() > 0, ( "Pooler received samples with all-zero attention masks. " "This causes NaN from division by zero. Filter empty inputs before pooling." ) final_emb: List[torch.Tensor] = [] for pooling_type in self.pooling_types: final_emb.append(self.pooling_options[pooling_type](emb=emb, attention_mask=attention_mask, attentions=attentions)) return torch.cat(final_emb, dim=-1) class ProteinDataset(TorchDataset): """Simple dataset for protein sequences.""" def __init__(self, sequences: List[str]) -> None: self.sequences = sequences def __len__(self) -> int: return len(self.sequences) def __getitem__(self, idx: int) -> str: return self.sequences[idx] def parse_fasta(fasta_path: str) -> List[str]: assert os.path.exists(fasta_path), f"FASTA file does not exist: {fasta_path}" sequences = [] current_seq = [] with open(fasta_path, 'r') as f: for line in f: line = line.strip() if not line: continue if line.startswith('>'): if current_seq: sequences.append(''.join(current_seq)) current_seq = [] else: current_seq.append(line) if current_seq: sequences.append(''.join(current_seq)) return sequences class EmbeddingMixin: def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: raise NotImplementedError @property def device(self) -> torch.device: """Get the device of the model.""" return next(self.parameters()).device def _read_sequences_from_db(self, db_path: str) -> Set[str]: """Read sequences from SQLite database.""" with sqlite3.connect(db_path, timeout=30) as conn: c = conn.cursor() c.execute("SELECT sequence FROM embeddings") return {row[0] for row in c.fetchall()} def _ensure_embeddings_table(self, conn: sqlite3.Connection) -> None: cursor = conn.cursor() cursor.execute( "CREATE TABLE IF NOT EXISTS embeddings (" "sequence TEXT PRIMARY KEY, " "embedding BLOB NOT NULL" ")" ) conn.commit() def load_embeddings_from_pth(self, save_path: str) -> Dict[str, torch.Tensor]: assert os.path.exists(save_path), f"Embedding file does not exist: {save_path}" payload = torch.load(save_path, map_location="cpu", weights_only=True) assert isinstance(payload, dict), "Expected .pth embeddings file to contain a dictionary." for sequence, tensor in payload.items(): assert isinstance(sequence, str), "Expected embedding dictionary keys to be sequences (str)." assert isinstance(tensor, torch.Tensor), "Expected embedding dictionary values to be tensors." return payload def load_embeddings_from_db(self, db_path: str, sequences: Optional[List[str]] = None) -> Dict[str, torch.Tensor]: assert os.path.exists(db_path), f"Embedding database does not exist: {db_path}" loaded: Dict[str, torch.Tensor] = {} with sqlite3.connect(db_path, timeout=30) as conn: self._ensure_embeddings_table(conn) cursor = conn.cursor() if sequences is None: cursor.execute("SELECT sequence, embedding FROM embeddings") else: if len(sequences) == 0: return loaded placeholders = ",".join(["?"] * len(sequences)) cursor.execute( f"SELECT sequence, embedding FROM embeddings WHERE sequence IN ({placeholders})", tuple(sequences), ) rows = cursor.fetchall() for row in rows: sequence = row[0] embedding_bytes = row[1] loaded[sequence] = embedding_blob_to_tensor(embedding_bytes) return loaded def embed_dataset( self, sequences: Optional[List[str]] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None, batch_size: int = 2, max_len: int = 512, truncate: bool = True, full_embeddings: bool = False, embed_dtype: torch.dtype = torch.float32, pooling_types: List[str] = ['mean'], num_workers: int = 0, sql: bool = False, save: bool = True, sql_db_path: str = 'embeddings.db', save_path: str = 'embeddings.pth', fasta_path: Optional[str] = None, padding: str = 'max_length', **kwargs, ) -> Optional[Dict[str, torch.Tensor]]: """ Embed a dataset of protein sequences. Supports two modes: - Tokenizer mode (ESM2/ESM++): provide `tokenizer`, `_embed(input_ids, attention_mask)` is used. - Sequence mode (E1): pass `tokenizer=None`, `_embed(sequences, return_attention_mask=True, **kwargs)` is used. Sequences can be supplied as a list via `sequences`, parsed from a FASTA file via `fasta_path`, or both (the two sources are combined). At least one must be provided. """ if fasta_path is not None: fasta_sequences = parse_fasta(fasta_path) sequences = list(sequences or []) + fasta_sequences assert sequences is not None and len(sequences) > 0, \ "Must provide at least one sequence via `sequences` or `fasta_path`." sequences = list(set([seq[:max_len] if truncate else seq for seq in sequences])) sequences = sorted(sequences, key=len, reverse=True) hidden_size = self.config.hidden_size pooler = Pooler(pooling_types) if not full_embeddings else None tokenizer_mode = tokenizer is not None # Resolve padding and compilation dynamic = padding == 'longest' compiled_model = maybe_compile(self, dynamic=dynamic) if tokenizer_mode: collate_fn = build_collator(tokenizer, padding=padding, max_length=max_len) device = self.device else: collate_fn = None device = None def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: assert isinstance(residue_embeddings, torch.Tensor) if full_embeddings or residue_embeddings.ndim == 2: return residue_embeddings return pooler(residue_embeddings, attention_mask) def iter_batches(to_embed: List[str]): if tokenizer_mode: assert collate_fn is not None assert device is not None dataset = ProteinDataset(to_embed) dataloader = DataLoader( dataset, batch_size=batch_size, num_workers=num_workers, prefetch_factor=2 if num_workers > 0 else None, collate_fn=collate_fn, shuffle=False, pin_memory=True, ) for i, batch in _make_embedding_progress(dataloader, padding): seqs = to_embed[i * batch_size:(i + 1) * batch_size] input_ids = batch['input_ids'].to(device) attention_mask = batch['attention_mask'].to(device) residue_embeddings = compiled_model._embed(input_ids, attention_mask) yield seqs, residue_embeddings, attention_mask else: for batch_start in tqdm(range(0, len(to_embed), batch_size), desc='Embedding batches'): seqs = to_embed[batch_start:batch_start + batch_size] batch_output = compiled_model._embed(seqs, return_attention_mask=True, **kwargs) assert isinstance(batch_output, tuple), "Sequence mode _embed must return (last_hidden_state, attention_mask)." assert len(batch_output) == 2, "Sequence mode _embed must return exactly two values." residue_embeddings, attention_mask = batch_output assert isinstance(attention_mask, torch.Tensor), "Sequence mode _embed must return attention_mask as a torch.Tensor." yield seqs, residue_embeddings, attention_mask if sql: # Step 1: DEDUPLICATE - check existing embeddings in SQL conn = sqlite3.connect(sql_db_path, timeout=30, check_same_thread=False) conn.execute('PRAGMA journal_mode=WAL') conn.execute('PRAGMA busy_timeout=30000') conn.execute('PRAGMA synchronous=OFF') conn.execute('PRAGMA cache_size=-64000') self._ensure_embeddings_table(conn) already_embedded = self._read_sequences_from_db(sql_db_path) to_embed = [seq for seq in sequences if seq not in already_embedded] print(f"Found {len(already_embedded)} already embedded sequences in {sql_db_path}") print(f"Embedding {len(to_embed)} new sequences") if len(to_embed) > 0: # Steps 4-7: BATCH+EMBED -> POOL/TRIM -> SERIALIZE -> WRITE (async) with _SQLWriter(conn) as writer: with torch.inference_mode(): for seqs, residue_embeddings, attention_mask in iter_batches(to_embed): embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) if full_embeddings: batch_rows = [] for seq, emb, mask in zip(seqs, embeddings, attention_mask): batch_rows.append((seq, tensor_to_embedding_blob(emb[mask.bool()].reshape(-1, hidden_size)))) else: blobs = batch_tensor_to_blobs(embeddings) batch_rows = list(zip(seqs, blobs)) writer.write_batch(batch_rows) conn.close() return None embeddings_dict = {} if os.path.exists(save_path): embeddings_dict = self.load_embeddings_from_pth(save_path) to_embed = [seq for seq in sequences if seq not in embeddings_dict] print(f"Found {len(embeddings_dict)} already embedded sequences in {save_path}") print(f"Embedding {len(to_embed)} new sequences") else: to_embed = sequences print(f"Embedding {len(to_embed)} new sequences") if len(to_embed) > 0: with torch.inference_mode(): for seqs, residue_embeddings, attention_mask in iter_batches(to_embed): embeddings = get_embeddings(residue_embeddings, attention_mask).to(embed_dtype) for seq, emb, mask in zip(seqs, embeddings, attention_mask): if full_embeddings: emb = emb[mask.bool()].reshape(-1, hidden_size) embeddings_dict[seq] = emb.cpu() if save: torch.save(embeddings_dict, save_path) return embeddings_dict if __name__ == "__main__": # py -m pooler pooler = Pooler(pooling_types=['max', 'parti']) batch_size = 8 seq_len = 64 hidden_size = 128 num_layers = 12 emb = torch.randn(batch_size, seq_len, hidden_size) attentions = torch.randn(batch_size, num_layers, seq_len, seq_len) attention_mask = torch.ones(batch_size, seq_len) y = pooler(emb=emb, attention_mask=attention_mask, attentions=attentions) print(y.shape) """Shared attention infrastructure for all FastPLMs models. Contains: AttentionBackend enum, backend resolution, mask creation, flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities. """ from enum import Enum from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn from torch.nn import functional as F from einops import rearrange try: from torch.nn.attention.flex_attention import create_block_mask, flex_attention, BlockMask except ImportError: create_block_mask = None flex_attention = None BlockMask = None _compiled_flex_attention = None def _get_flex_attention_fn(): """Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.""" global _compiled_flex_attention if flex_attention is None: return None flex_mod = torch.nn.attention.flex_attention if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False): return flex_attention if _compiled_flex_attention is None: _compiled_flex_attention = torch.compile( flex_attention, dynamic=False, ) return _compiled_flex_attention ### Kernels Flash Attention Detection def _infer_kernels_flash_variant(kernel) -> Optional[str]: if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"): return "flash_attn2" if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"): return "flash_attn3" return None def _try_get_kernels_flash(): try: from kernels import get_kernel except ImportError: return None, None flash_kernel = None flash_kernel_variant = None try: flash_kernel = get_kernel("kernels-community/flash-attn3") flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel) assert flash_kernel_variant is not None, "Loaded flash-attn3 kernel does not expose a supported API." except Exception: try: flash_kernel = get_kernel("kernels-community/flash-attn2") flash_kernel_variant = _infer_kernels_flash_variant(flash_kernel) assert flash_kernel_variant is not None, "Loaded flash-attn2 kernel does not expose a supported API." except Exception: flash_kernel = None flash_kernel_variant = None return flash_kernel, flash_kernel_variant _FLASH_KERNELS_LOADED = False FLASH_KERNEL = None FLASH_KERNEL_VARIANT = None def _ensure_flash_kernels_loaded(): global _FLASH_KERNELS_LOADED, FLASH_KERNEL, FLASH_KERNEL_VARIANT if _FLASH_KERNELS_LOADED: return _FLASH_KERNELS_LOADED = True FLASH_KERNEL, FLASH_KERNEL_VARIANT = _try_get_kernels_flash() def _kernels_flash_forward( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, causal: bool = False, ) -> torch.Tensor: assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment." if FLASH_KERNEL_VARIANT == "flash_attn2": return FLASH_KERNEL.fwd(q=query_states, k=key_states, v=value_states, is_causal=causal)[0] if FLASH_KERNEL_VARIANT == "flash_attn3": try: output = FLASH_KERNEL.flash_attn_func(q=query_states, k=key_states, v=value_states, causal=causal) except TypeError: output = FLASH_KERNEL.flash_attn_func(query_states, key_states, value_states, 0.0, None, causal) if isinstance(output, tuple): return output[0] return output raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}") def _kernels_flash_varlen_forward( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_in_batch_q: int, max_seqlen_in_batch_k: int, causal: bool = False, ) -> torch.Tensor: assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment." if FLASH_KERNEL_VARIANT == "flash_attn2": return FLASH_KERNEL.varlen_fwd( q=query_states, k=key_states, v=value_states, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, is_causal=causal, )[0] if FLASH_KERNEL_VARIANT == "flash_attn3": try: output = FLASH_KERNEL.flash_attn_varlen_func( q=query_states, k=key_states, v=value_states, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_in_batch_q, max_seqlen_k=max_seqlen_in_batch_k, causal=causal, ) except TypeError: output = FLASH_KERNEL.flash_attn_varlen_func( query_states, key_states, value_states, cu_seqlens_q, cu_seqlens_k, max_seqlen_in_batch_q, max_seqlen_in_batch_k, 0.0, None, causal, ) if isinstance(output, tuple): return output[0] return output raise AssertionError(f"Unsupported kernels flash attention variant: {FLASH_KERNEL_VARIANT}") ### Unpad / Pad helpers for varlen flash attention class IndexFirstAxis(torch.autograd.Function): @staticmethod def forward(ctx, input, indices) -> torch.Tensor: ctx.save_for_backward(indices) assert input.ndim >= 2 ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] second_dim = other_shape.numel() return torch.gather( rearrange(input, "b ... -> b (...)"), 0, indices.unsqueeze(1).expand(-1, second_dim) ).reshape(-1, *other_shape) @staticmethod def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]: (indices,) = ctx.saved_tensors assert grad_output.ndim >= 2 other_shape = grad_output.shape[1:] grad_output = rearrange(grad_output, "b ... -> b (...)") grad_input = torch.zeros( [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype ) grad_input.scatter_(0, indices.unsqueeze(1).expand(-1, grad_output.shape[1]), grad_output) return grad_input.reshape(ctx.first_axis_dim, *other_shape), None class IndexPutFirstAxis(torch.autograd.Function): @staticmethod def forward(ctx, values, indices, first_axis_dim) -> torch.Tensor: ctx.save_for_backward(indices) assert indices.ndim == 1 assert values.ndim >= 2 output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) output[indices] = values return output @staticmethod def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]: (indices,) = ctx.saved_tensors return grad_output[indices], None, None index_first_axis = IndexFirstAxis.apply index_put_first_axis = IndexPutFirstAxis.apply def pad_input(hidden_states: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor: output = index_put_first_axis(hidden_states, indices, batch * seqlen) return rearrange(output, "(b s) ... -> b s ...", b=batch) def _unpad_input( query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, attention_mask_2d: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]: batch_size, seq_len, num_heads, head_dim = query_layer.shape seqlens = attention_mask_2d.sum(dim=1).int() cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0)) max_seqlen = int(seqlens.max().item()) indices = attention_mask_2d.flatten().nonzero(as_tuple=False).flatten() query_layer = index_first_axis(query_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices) key_layer = index_first_axis(key_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices) value_layer = index_first_axis(value_layer.reshape(batch_size * seq_len, num_heads, head_dim), indices) return query_layer, key_layer, value_layer, indices, (cu_seqlens, cu_seqlens), (max_seqlen, max_seqlen) def kernels_flash_attention_func( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, attention_mask_2d: Optional[torch.Tensor] = None, causal: bool = False, ) -> torch.Tensor: assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment." if not causal and attention_mask_2d is not None: batch_size, q_len = query_states.shape[:2] ( query_states, key_states, value_states, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k), ) = _unpad_input(query_states, key_states, value_states, attention_mask_2d) attn_output_unpad = _kernels_flash_varlen_forward( query_states=query_states, key_states=key_states, value_states=value_states, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_in_batch_q=max_seqlen_q, max_seqlen_in_batch_k=max_seqlen_k, ) return pad_input(attn_output_unpad, indices_q, batch_size, q_len) else: return _kernels_flash_forward( query_states=query_states, key_states=key_states, value_states=value_states, causal=causal, ) ### Attention Backend Enum & Resolution class AttentionBackend(Enum): AUTO = "auto" KERNELS_FLASH = "kernels_flash" FLEX = "flex" SDPA = "sdpa" VALID_ATTENTION_BACKENDS = tuple(b.value for b in AttentionBackend) _BACKEND_CONFIRMED = False def resolve_attention_backend(requested_backend: str) -> AttentionBackend: global _BACKEND_CONFIRMED assert requested_backend in VALID_ATTENTION_BACKENDS, ( f"Unsupported attention backend: {requested_backend}. Expected one of {VALID_ATTENTION_BACKENDS}." ) if requested_backend in (AttentionBackend.AUTO.value, AttentionBackend.KERNELS_FLASH.value): _ensure_flash_kernels_loaded() if requested_backend == AttentionBackend.AUTO.value: if FLASH_KERNEL is not None: resolved = AttentionBackend.KERNELS_FLASH elif flex_attention is not None: resolved = AttentionBackend.FLEX else: resolved = AttentionBackend.SDPA elif requested_backend == AttentionBackend.KERNELS_FLASH.value: assert FLASH_KERNEL is not None, "Kernels Flash Attention is not available in this environment." resolved = AttentionBackend.KERNELS_FLASH elif requested_backend == AttentionBackend.FLEX.value: assert flex_attention is not None, "Flex Attention is not available in this environment." resolved = AttentionBackend.FLEX elif requested_backend == AttentionBackend.SDPA.value: resolved = AttentionBackend.SDPA else: raise AssertionError(f"Unsupported attention backend: {requested_backend}") if not _BACKEND_CONFIRMED: print(f"Attention backend: config='{requested_backend}' -> resolved='{resolved.value}'") _BACKEND_CONFIRMED = True return resolved @torch.compiler.disable def get_attention_mask( effective_backend: AttentionBackend, batch_size: int, seq_len: int, device: torch.device, attention_mask: Optional[torch.Tensor] = None, ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]: """Build padding masks once for all encoder layers. Returns (attention_mask_2d, attention_mask_4d, flex_block_mask). """ if attention_mask is None: return None, None, None attention_mask_2d = attention_mask.bool() if effective_backend == AttentionBackend.KERNELS_FLASH: return attention_mask_2d, None, None if effective_backend == AttentionBackend.FLEX: assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable." valid_lens = attention_mask_2d.sum(dim=-1) def mask_mod(batch_idx, head_idx, q_idx, kv_idx): return (q_idx < valid_lens[batch_idx]) & (kv_idx < valid_lens[batch_idx]) flex_block_mask = create_block_mask(mask_mod, batch_size, 1, seq_len, seq_len, device=device) return attention_mask_2d, None, flex_block_mask # SDPA / manual -- only mask the key dimension so padding query positions attend to # real keys and produce valid (non-NaN) outputs instead of NaN from softmax(-inf,...,-inf). attention_mask_4d = attention_mask_2d[:, None, None, :] return attention_mask_2d, attention_mask_4d, None import os from enum import Enum from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple, TypedDict, Union import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.utils.rnn import pad_sequence from tokenizers import Tokenizer from transformers import PretrainedConfig, PreTrainedModel from transformers.activations import ACT2FN from transformers.modeling_outputs import ModelOutput from transformers.utils import logging logger = logging.get_logger(__name__) from torch.nn.attention.flex_attention import _create_sparse_block_from_block_mask try: from kernels import get_kernel layer_norm = get_kernel("kernels-community/triton-layer-norm") except Exception as e: logger.warning(f"Failed to load triton layer norm kernel: {e}; Will be using PyTorch RMSNorm instead") layer_norm = None @torch.compiler.disable def create_block_causal_mask_optimized(sequence_ids: torch.Tensor) -> BlockMask: # Assumes sequence_ids is sorted in increasing order for each batch item, except for # the -1 values, which are used to indicate the padding tokens. def document_mask(b, h, q_idx, kv_idx): # type: ignore[no-untyped-def] return ( (sequence_ids[b, q_idx] >= sequence_ids[b, kv_idx]) & (sequence_ids[b, q_idx] != -1) & (sequence_ids[b, kv_idx] != -1) ) batch_size, seqlen = sequence_ids.shape return create_block_mask(document_mask, batch_size, 1, seqlen, seqlen, device=sequence_ids.device) @torch.compiler.disable def create_within_seq_block_mask(sequence_ids: torch.Tensor) -> BlockMask: def document_mask(b, h, q_idx, kv_idx): # type: ignore[no-untyped-def] return ( (sequence_ids[b, q_idx] == sequence_ids[b, kv_idx]) & (sequence_ids[b, q_idx] != -1) & (sequence_ids[b, kv_idx] != -1) ) batch_size, seqlen = sequence_ids.shape return create_block_mask(document_mask, batch_size, 1, seqlen, seqlen, device=sequence_ids.device) def build_within_seq_mask_4d(sequence_ids: torch.Tensor) -> torch.Tensor: not_pad = (sequence_ids != -1) same_seq = sequence_ids.unsqueeze(-1) == sequence_ids.unsqueeze(-2) valid = not_pad.unsqueeze(-1) & not_pad.unsqueeze(-2) return (same_seq & valid).unsqueeze(1) def build_block_causal_mask_4d(sequence_ids: torch.Tensor) -> torch.Tensor: not_pad = (sequence_ids != -1) causal = sequence_ids.unsqueeze(-1) >= sequence_ids.unsqueeze(-2) valid = not_pad.unsqueeze(-1) & not_pad.unsqueeze(-2) return (causal & valid).unsqueeze(1) def flex_attention_func( query_states: torch.Tensor, # (bs, seqlen, nh, hs) key_states: torch.Tensor, # (bs, seqlen, nkv, hs) value_states: torch.Tensor, # (bs, seqlen, nkv, hs) score_mod: Optional[Callable] = None, block_mask: Optional[BlockMask] = None, ) -> torch.Tensor: assert flex_attention is not None, "Flex Attention is not available in this environment" assert score_mod is None, "Score mod is not supported yet" query_states = query_states.transpose(1, 2).contiguous() # (bs, nh, seqlen, hs) key_states = key_states.transpose(1, 2).contiguous() # (bs, nkv, seqlen, hs) value_states = value_states.transpose(1, 2).contiguous() # (bs, nkv, seqlen, hs) fn = _get_flex_attention_fn() outputs = fn( query_states, key_states, value_states, block_mask=block_mask, score_mod=score_mod, enable_gqa=query_states.shape[1] != key_states.shape[1], # if nkv != nh ) outputs = outputs.transpose(1, 2) # (bs, seqlen, nh, hs) return outputs def kernels_flash_attention_func( query_states: torch.Tensor, # (bs, seqlen, nh, hs) key_states: torch.Tensor, # (bs, seqlen, nkv, hs) value_states: torch.Tensor, # (bs, seqlen, nkv, hs) q_sequence_ids: torch.Tensor, k_sequence_ids: torch.Tensor, causal: bool = False, ) -> torch.Tensor: # (bs, seqlen, nh, hs) assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment." if not causal: batch_size, q_len = query_states.shape[0], query_states.shape[1] ( query_states, key_states, value_states, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) = _unpad_input(query_states, key_states, value_states, q_sequence_ids, k_sequence_ids) attn_output_unpad = _kernels_flash_varlen_forward( query_states, key_states, value_states, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_in_batch_q=max_seqlen_in_batch_q, max_seqlen_in_batch_k=max_seqlen_in_batch_k, causal=False, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, q_len) else: attn_output = _kernels_flash_forward(query_states, key_states, value_states, causal=True) return attn_output def block_min_max_seq_ids(SLEN: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: device = SLEN.device total_tokens = torch.sum(SLEN) B = (total_tokens + block_size - 1) // block_size padding_tokens = B * block_size - total_tokens SLEN = torch.cat([SLEN, padding_tokens.reshape(1).to(device=device, dtype=SLEN.dtype)], dim=0) assert torch.sum(SLEN) == B * block_size # Cumulative ends (exclusive) for each sequence; cum[i] == end offset of seq i cum = torch.cumsum(SLEN.to(torch.long), dim=0) # (N,) total_tokens = cum[-1].item() # Block start/end offsets [start, end) in token index space block_starts = torch.arange(0, B * block_size, block_size, device=device, dtype=torch.long) # (B,) block_ends = torch.minimum(block_starts + block_size, torch.tensor(total_tokens, device=device)) # (B,) # MIN_SEQ_ID[i] = first sequence whose end > block_start # searchsorted with right=True returns first index where cum > value MIN_SEQ_ID = torch.searchsorted(cum, block_starts, right=True) # MAX_SEQ_ID[i] = sequence containing the last token in the block (block_end - 1) # For empty tail beyond total_tokens we already clipped block_ends. last_token_in_block = torch.clamp(block_ends - 1, min=0) # valid only if block has at least 1 token MAX_SEQ_ID = torch.searchsorted(cum, last_token_in_block, right=True) return MIN_SEQ_ID, MAX_SEQ_ID def get_overlapping_blocks(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: MIN_Q, MAX_Q = block_min_max_seq_ids(SLEN_Q) MIN_K, MAX_K = block_min_max_seq_ids(SLEN_K) cond1 = MIN_Q.unsqueeze(1) <= MAX_K.unsqueeze(0) cond2 = MIN_K.unsqueeze(0) <= MAX_Q.unsqueeze(1) overlap = cond1 & cond2 cond1 = (MIN_Q == MAX_Q).unsqueeze(1) cond2 = (MIN_K == MAX_K).unsqueeze(0) same_seq_in_qk = cond1 & cond2 full_blocks = overlap & same_seq_in_qk partial_blocks = overlap & ~same_seq_in_qk return full_blocks, partial_blocks @torch.compiler.disable def direct_block_mask(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> BlockMask: full_blocks, partial_blocks = get_overlapping_blocks(SLEN_Q, SLEN_K) partial_blocks = partial_blocks[None, None] full_blocks = full_blocks[None, None] q_doc_id = torch.repeat_interleave(SLEN_Q) k_doc_id = torch.repeat_interleave(SLEN_K) def doc_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor: return q_doc_id[q_idx] == k_doc_id[kv_idx] total_q_len = q_doc_id.shape[0] total_k_len = k_doc_id.shape[0] return _create_sparse_block_from_block_mask( (partial_blocks, full_blocks), doc_mask, seq_lengths=(total_q_len, total_k_len), Q_BLOCK_SIZE=128, KV_BLOCK_SIZE=128, ) @torch.compiler.disable def doc_id_mask(SLEN_Q: torch.Tensor, SLEN_K: torch.Tensor) -> BlockMask: q_doc_id = torch.repeat_interleave(SLEN_Q) k_doc_id = torch.repeat_interleave(SLEN_K) def doc_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor: return q_doc_id[q_idx] == k_doc_id[kv_idx] total_q_len = q_doc_id.shape[0] total_k_len = k_doc_id.shape[0] return create_block_mask(doc_mask, 1, 1, total_q_len, total_k_len, BLOCK_SIZE=128, device=SLEN_Q.device) def varlen_flex_attention_func( query_states: torch.Tensor, key_states: torch.Tensor, value_states: torch.Tensor, q_sequence_ids: torch.Tensor, k_sequence_ids: torch.Tensor, ) -> torch.Tensor: batch_size, q_len = query_states.shape[0], query_states.shape[1] ( query_states, key_states, value_states, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) = _unpad_input(query_states, key_states, value_states, q_sequence_ids, k_sequence_ids) query_states = query_states.unsqueeze(0).transpose(1, 2).contiguous() key_states = key_states.unsqueeze(0).transpose(1, 2).contiguous() value_states = value_states.unsqueeze(0).transpose(1, 2).contiguous() seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] seqlens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1] block_mask = block_mask_creator(seqlens_q, seqlens_k) fn = _get_flex_attention_fn() attn_output_unpad = fn( query_states, key_states, value_states, block_mask=block_mask, enable_gqa=query_states.shape[1] != key_states.shape[1], ) attn_output = pad_input(attn_output_unpad.transpose(1, 2).squeeze(0), indices_q, batch_size, q_len) return attn_output def _get_unpad_data(sequence_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: non_pad_indices = sequence_ids != -1 non_pad_indices = torch.nonzero(non_pad_indices.flatten(), as_tuple=False).flatten() sequence_ids = sequence_ids + torch.arange(len(sequence_ids), device=sequence_ids.device)[:, None] * 1e5 sequence_ids = sequence_ids.flatten()[non_pad_indices] _, seqlens_in_batch = torch.unique_consecutive(sequence_ids, return_counts=True) max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) return non_pad_indices, cu_seqlens, max_seqlen_in_batch def _unpad_input( query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor, q_sequence_ids: torch.Tensor, k_sequence_ids: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]: batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape query_length, num_q_heads = query_layer.shape[1], query_layer.shape[2] assert query_layer.shape[:2] == q_sequence_ids.shape, ( f"Shape mismatch between query layer and query sequence ids: {query_layer.shape[:2]} != {q_sequence_ids.shape}" ) assert key_layer.shape[:2] == k_sequence_ids.shape, ( f"Shape mismatch between key layer and key sequence ids: {key_layer.shape[:2]} != {k_sequence_ids.shape}" ) assert query_length <= kv_seq_len, ( f"Query length should be less than or equal to KV sequence length: {query_length} <= {kv_seq_len}" ) indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(k_sequence_ids) key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) if torch.equal(q_sequence_ids, k_sequence_ids): indices_q = indices_k cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k else: indices_q, cu_seqlens_q, max_seqlen_in_batch_q = _get_unpad_data(q_sequence_ids) query_layer = index_first_axis(query_layer.reshape(batch_size * query_length, num_q_heads, head_dim), indices_q) assert cu_seqlens_q.shape == cu_seqlens_k.shape, ( f"Query and KV should have the same number of sequences: {cu_seqlens_q.shape} != {cu_seqlens_k.shape}" ) return ( query_layer, key_layer, value_layer, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) block_mask_creator = direct_block_mask if os.getenv("FAST_BLOCK_MASK", "1") == "1" else doc_id_mask PAD_TOKEN_ID = 0 def get_tokenizer() -> Tokenizer: try: fname = os.path.join(os.path.dirname(__file__), "tokenizer.json") tokenizer: Tokenizer = Tokenizer.from_file(fname) except Exception: print("E1 Tokenizer not found in local directory, downloading from Hugging Face") from huggingface_hub import hf_hub_download fname = hf_hub_download(repo_id="Synthyra/Profluent-E1-150M", filename="tokenizer.json") tokenizer: Tokenizer = Tokenizer.from_file(fname) assert tokenizer.padding["pad_id"] == PAD_TOKEN_ID, ( f"Padding token id must be {PAD_TOKEN_ID}, but got {tokenizer.padding['pad_id']}" ) return tokenizer @dataclass class DataPrepConfig: max_num_sequences: int = 512 max_num_positions_within_seq: int = 8192 remove_X_tokens: bool = False def get_context(sequence: str) -> Optional[str]: if "," in sequence: return sequence.rsplit(",", 1)[0] return None class E1BatchPreparer: def __init__( self, data_prep_config: Optional[DataPrepConfig] = None, tokenizer: Optional[Tokenizer] = None, preserve_context_labels: bool = False, ): self.tokenizer = tokenizer or get_tokenizer() self.data_prep_config = data_prep_config or DataPrepConfig() self.pad_token_id = self.tokenizer.token_to_id("") self.preserve_context_labels = preserve_context_labels device = torch.cuda.current_device() if torch.cuda.is_available() else torch.device("cpu") self.boundary_token_ids = torch.tensor( [self.tokenizer.token_to_id(token) for token in ["", "", "1", "2", ""]], device=device ).long() self.mask_token = "?" # nosec self.mask_token_id = self.tokenizer.token_to_id(self.mask_token) self.X_token_id = self.tokenizer.token_to_id("X") self.vocab = self.tokenizer.get_vocab() def get_batch_kwargs( # type: ignore[override] self, sequences: List[str], device: torch.device = torch.device("cpu"), non_blocking: bool = False ) -> Dict[str, Union[torch.Tensor, List[str], List[int]]]: sequence_encodings = [self.prepare_multiseq(sequence) for sequence in sequences] return self.pad_encodings(sequence_encodings, device, non_blocking) def pad_encodings( self, sequence_encodings: List[Dict[str, torch.Tensor]], device: torch.device = torch.device("cpu"), non_blocking: bool = False, ) -> Dict[str, Union[torch.Tensor, List[str], List[int]]]: non_blocking = non_blocking and device.type == "cuda" padded_encodings = {} # Note: We use -1 as the padding value for sequence and position ids because the 0 value # is a valid value for sequence and position ids. -1 is then used to distinguish valid # tokens from padding tokens, for example, when doing padding/unpadding for flash attention. for key, padding_value in { "input_ids": self.pad_token_id, "sequence_ids": -1, "within_seq_position_ids": -1, "global_position_ids": -1, "labels": self.pad_token_id, }.items(): padded_encodings[key] = pad_sequence( [enc[key] for enc in sequence_encodings], batch_first=True, padding_value=padding_value ).to(device=device, dtype=torch.long, non_blocking=non_blocking) padded_encodings["context"] = [enc["context"] for enc in sequence_encodings] padded_encodings["context_len"] = [enc["context_len"] for enc in sequence_encodings] return padded_encodings def prepare_multiseq(self, sequence: str) -> Dict[str, Union[torch.Tensor, str, int]]: single_sequences = sequence.split(",") if len(single_sequences) > self.data_prep_config.max_num_sequences: raise ValueError( f"Number of sequences {len(single_sequences)} exceeds max number of sequences {self.data_prep_config.max_num_sequences}" " in the provided multi-sequence instance. Please remove some homologous sequences before trying again." ) single_sequence_encodings = [self.prepare_singleseq(sequence) for sequence in single_sequences] num_tokens = [len(x["input_ids"]) for x in single_sequence_encodings] input_ids = torch.cat([x["input_ids"] for x in single_sequence_encodings]) labels = torch.cat([x["labels"] for x in single_sequence_encodings]) within_seq_position_ids = torch.cat([encoding["position_ids"] for encoding in single_sequence_encodings]) global_position_ids, ctx_len = [], 0 for encoding in single_sequence_encodings: global_position_ids.append(encoding["position_ids"] + ctx_len) ctx_len = max(ctx_len, encoding["position_ids"].max().item() + ctx_len + 1) global_position_ids = torch.cat(global_position_ids) sequence_ids = torch.repeat_interleave(torch.tensor(num_tokens)) # Get multi-seq context & mask out all but last sequence in multi-seq instance if desired context_len = sum(num_tokens[:-1]) context = self.tokenizer.decode(input_ids[:context_len].tolist(), skip_special_tokens=False) if not self.preserve_context_labels: labels[:context_len] = self.pad_token_id assert ( input_ids.shape == sequence_ids.shape == within_seq_position_ids.shape == global_position_ids.shape == labels.shape ), "Input ids, sequence ids, within seq position ids, global position ids, and labels must have the same shape" assert input_ids.shape[0] >= context_len, "Input ids must have at least as many tokens as the context length" return { "input_ids": input_ids, "sequence_ids": sequence_ids, "within_seq_position_ids": within_seq_position_ids, "global_position_ids": global_position_ids, "labels": labels, "context": context, "context_len": context_len, } def prepare_singleseq(self, sequence: str) -> Dict[str, torch.Tensor]: if not self.validate_sequence(sequence): raise ValueError(f"Invalid sequence: {sequence}; Input sequence should contain [A-Z] or ? characters only") if len(sequence) > self.data_prep_config.max_num_positions_within_seq: raise ValueError( f"Sequence length {len(sequence)} exceeds max length {self.data_prep_config.max_num_positions_within_seq}" ) # Can also use `tokens = torch.tensor(self.tokenizer.encode(f"1{sequence}2").ids)` # but following is faster since our vocabulary is simple. tokens = torch.tensor([self.vocab[token] for token in ["", "1", *sequence, "2", ""]]) position_ids = torch.arange(len(tokens)) if self.data_prep_config.remove_X_tokens: X_positions = torch.where(tokens != self.X_token_id)[0] tokens = tokens[X_positions] position_ids = position_ids[X_positions] return {"input_ids": tokens, "labels": tokens, "position_ids": position_ids} def get_boundary_token_mask(self, tokens: torch.Tensor) -> torch.BoolTensor: return torch.isin(tokens, self.boundary_token_ids.to(tokens.device)) def get_mask_positions_mask(self, tokens: torch.Tensor) -> torch.BoolTensor: return tokens == self.mask_token_id def validate_sequence(self, sequence: str) -> bool: assert isinstance(sequence, str), "Sequence must be a string" sequence = sequence.replace(self.mask_token, "") return sequence.isalpha() and sequence.isupper() class E1Config(PretrainedConfig): model_type = "E1" keys_to_ignore_at_inference = ["past_key_values"] def __init__( # type: ignore self, # Model architecture/initialization vocab_size=None, hidden_size=4096, intermediate_size=16384, gated_mlp=False, num_hidden_layers=40, num_attention_heads=32, num_key_value_heads=8, hidden_act="silu", rms_norm_eps=1e-5, initializer_range=0.02, dtype="bfloat16", gradient_checkpointing=False, no_ffn_gradient_checkpointing=False, # Tokenization pad_token_id=None, bos_token_id=None, eos_token_id=None, tie_word_embeddings=False, # Attention implementation & rotary positional embeddings global_attention_every_n_layers=0, max_num_sequences=512, max_num_positions_within_seq=8192, max_num_positions_global=1024 * 128, rope_theta_within_seq=10000.0, rope_theta_global=100000.0, clip_qkv=None, attn_backend="sdpa", **kwargs, ) -> None: tokenizer = get_tokenizer() super().__init__( pad_token_id=tokenizer.token_to_id(""), bos_token_id=tokenizer.token_to_id(""), eos_token_id=tokenizer.token_to_id(""), tie_word_embeddings=tie_word_embeddings, dtype=dtype, **kwargs, ) self.hidden_size = hidden_size if intermediate_size is None: intermediate_size = 3 * hidden_size if gated_mlp else 4 * hidden_size self.intermediate_size = intermediate_size self.gated_mlp = gated_mlp self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.max_num_positions_within_seq = max_num_positions_within_seq self.max_num_positions_global = max_num_positions_global # for backward compatibility if num_key_value_heads is None: num_key_value_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads self.hidden_act = hidden_act self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.rope_theta_within_seq = rope_theta_within_seq self.rope_theta_global = rope_theta_global self.max_num_sequences = max_num_sequences assert clip_qkv is None or clip_qkv > 0 self.clip_qkv = clip_qkv self.global_attention_every_n_layers = global_attention_every_n_layers self.vocab_size = tokenizer.get_vocab_size() self.gradient_checkpointing = gradient_checkpointing self.no_ffn_gradient_checkpointing = no_ffn_gradient_checkpointing self.attn_backend = attn_backend if vocab_size is not None: if vocab_size < self.vocab_size: logger.warning( f"Using vocab_size {vocab_size} smaller than {self.vocab_size} from tokenizer. MAKE SURE THIS IS INTENTIONAL." ) self.vocab_size = vocab_size elif vocab_size > self.vocab_size: logger.warning(f"Using vocab_size {vocab_size} instead of smaller {self.vocab_size} from tokenizer.") self.vocab_size = vocab_size if pad_token_id is not None and pad_token_id != self.pad_token_id: logger.warning(f"Ignoring pad_token_id. Using {self.pad_token_id} from tokenizer") if bos_token_id is not None and bos_token_id != self.bos_token_id: logger.warning(f"Ignoring bos_token_id. Using {self.bos_token_id} from tokenizer") if eos_token_id is not None and eos_token_id != self.eos_token_id: logger.warning(f"Ignoring eos_token_id. Using {self.eos_token_id} from tokenizer") class DynamicCache: """ A cache layer that grows dynamically as more tokens are generated. This is the default for generative models. It stores the key and value states as tensors of shape `[batch_size, seq_len, num_heads, head_dim]`. Args: key_cache (`list[torch.Tensor]`): The list of key states. value_cache (`list[torch.Tensor]`): The list of value states. """ def __init__(self) -> None: self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] def update( self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int ) -> Tuple[torch.Tensor, torch.Tensor]: """ Update the key and value caches in-place, and return the necessary keys and value states. Args: key_states (`torch.Tensor`): The new key states to cache of shape [batch_size, seq_len, num_heads, head_dim] value_states (`torch.Tensor`): The new value states to cache of shape [batch_size, seq_len, num_heads, head_dim] layer_idx (`int`): The index of the layer to update. Returns: tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states of shape [batch_size, seq_len, num_heads, head_dim]. """ # Lazy initialization if len(self.key_cache) <= layer_idx: # There may be skipped layers, fill them with empty lists for _ in range(len(self.key_cache), layer_idx): self.key_cache.append(torch.tensor([])) self.value_cache.append(torch.tensor([])) self.key_cache.append(key_states) self.value_cache.append(value_states) elif ( not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model ): # fills previously skipped layers; checking for tensor causes errors self.key_cache[layer_idx] = key_states self.value_cache[layer_idx] = value_states else: self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=1) self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=1) return self.key_cache[layer_idx], self.value_cache[layer_idx] def get_seq_length(self, layer_idx: int = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" is_empty_layer = ( len(self.key_cache) == 0 # no cache in any layer or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it or not self.key_cache[layer_idx].numel() # the layer has no cache ) layer_seq_length = self.key_cache[layer_idx].shape[1] if not is_empty_layer else 0 return layer_seq_length def crop(self, max_length: int) -> None: """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search.""" assert max_length > 0, "max_length must be positive" if self.get_seq_length() <= max_length: return for layer_idx in range(len(self.key_cache)): if self.key_cache[layer_idx].numel(): self.key_cache[layer_idx] = self.key_cache[layer_idx][:, :max_length, ...] self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :max_length, ...] def batch_repeat_interleave(self, repeats: int) -> None: """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search.""" for layer_idx in range(len(self.key_cache)): if self.key_cache[layer_idx].numel(): self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) def batch_select_indices(self, indices: torch.Tensor) -> None: """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search.""" for layer_idx in range(len(self.key_cache)): if self.key_cache[layer_idx].numel(): self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] class KVCache: def __init__(self, cache_size: int = 4) -> None: self.cache_size = cache_size self.tensor_input_field_names = [ "input_ids", "within_seq_position_ids", "global_position_ids", "sequence_ids", "labels", ] self.tensor_output_field_names = ["logits", "embeddings"] self.cache_dict: Dict[str, DynamicCache] = {} self.cache_queue: List[str] = [] def reset(self) -> None: for k in list(self.cache_dict.keys()): del self.cache_dict[k] del self.cache_dict self.cache_dict = {} self.cache_queue = [] torch.cuda.empty_cache() def before_forward(self, batch: Dict[str, torch.Tensor]) -> None: contexts: Optional[List[str]] = batch.get("context", None) if contexts is None or "context_len" not in batch: logger.warning_once( "KVCache requires the batch dict to have both `context` and `context_len` keys to trigger. Skipping." ) return context_lens: List[int] = list(set(batch["context_len"])) contexts: List[str] = list(set(contexts)) # type: ignore[no-redef] if len(contexts) != 1 or len(context_lens) != 1: logger.warning( "SingleContextKVCache requires a single context and context length. " "Multiple contexts or context lengths found in a single batch. Skipping." ) return batch_size = batch["input_ids"].shape[0] unique_context = contexts[0] unique_context_len = context_lens[0] batch["use_cache"] = True if unique_context not in self.cache_dict: return self.cache_dict[unique_context].batch_repeat_interleave(batch_size) past_key_values = self.cache_dict[unique_context] batch["past_key_values"] = past_key_values # Remove context from the input fields for field_name in self.tensor_input_field_names: if batch.get(field_name, None) is not None: batch[field_name] = batch[field_name][:, unique_context_len:] def after_forward(self, batch: Dict[str, Any], outputs: ModelOutput) -> None: contexts = batch.get("context", None) context_lens = batch.get("context_len", []) if contexts is None or len(set(contexts)) != 1 or len(set(context_lens)) != 1 or context_lens[0] == 0: return assert batch["use_cache"] unique_context = contexts[0] unique_context_len = context_lens[0] past_key_values = getattr(outputs, "past_key_values", None) if not isinstance(past_key_values, DynamicCache): logger.warning_once("KVCache is incompatible with models that don't return a DynamicCache. Skipping.") return if "past_key_values" not in batch: if len(self.cache_queue) == self.cache_size: last_context = self.cache_queue.pop(0) if last_context not in self.cache_queue: del self.cache_dict[last_context] torch.cuda.empty_cache() self.cache_dict[unique_context] = past_key_values self.cache_queue.append(unique_context) # Remove context from the input fields for field_name in self.tensor_input_field_names: if field_name in batch and batch[field_name] is not None: batch[field_name] = batch[field_name][:, unique_context_len:] # Remove context from the output fields for field_name in self.tensor_output_field_names: if field_name in outputs and outputs[field_name] is not None: outputs[field_name] = outputs[field_name][:, unique_context_len:] if "hidden_states" in outputs and outputs["hidden_states"] is not None: outputs["hidden_states"] = [h[:, unique_context_len:] for h in outputs["hidden_states"]] self.cache_dict[unique_context].crop(unique_context_len) self.cache_dict[unique_context].batch_select_indices([0]) class AttentionLayerType(Enum): WITHIN_SEQ = "within_seq" GLOBAL = "global" class AttentionArgs(TypedDict, total=False): within_seq_block_mask: Optional[BlockMask] block_causal_block_mask: Optional[BlockMask] within_seq_mask_4d: Optional[torch.Tensor] block_causal_mask_4d: Optional[torch.Tensor] def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) class RotaryPositionalEmbedding(nn.Module): def __init__( self, dim: int, max_position_embeddings: int = 2048, base: int = 10000, device: Optional[torch.device] = None ): super().__init__() self.dim = dim self.base = base self.max_position_embeddings = max_position_embeddings inv_freq = base ** -(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. self._set_sin_cos_cache(seq_len=max_position_embeddings, device=self.inv_freq.device) @staticmethod def rotate_half(x: torch.Tensor) -> torch.Tensor: """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def _set_sin_cos_cache(self, seq_len: int, device: torch.device) -> None: # Different from paper, but it uses a different permutation in order to obtain the same calculation self.max_seq_len_cached = seq_len t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) angles = torch.outer(t, self.inv_freq.to(device)) angles = torch.cat((angles, angles), dim=1) self.register_buffer("cos_cached", angles.cos(), persistent=False) self.register_buffer("sin_cached", angles.sin(), persistent=False) def forward( self, q: torch.Tensor, k: torch.Tensor, position_ids: torch.LongTensor, seq_len: Optional[int] = None ) -> Tuple[torch.Tensor, torch.Tensor]: # x: [bsz, seq_len, num_attention_heads, head_size] device, dtype = q.device, q.dtype seq_len = position_ids.max().item() + 1 if seq_len is None else seq_len if seq_len > self.max_seq_len_cached: self._set_sin_cos_cache(seq_len=seq_len, device=device) # angles_cached[position_ids] gets us something of shape (batch_size, seq_len, head_dim), # so unsqueeze dimension -2 to broadcast to (batch_size, seq_len, n_heads, head_dim). idxs = position_ids.to(device) cos = self.cos_cached.to(device=device, dtype=dtype).unsqueeze(-2)[idxs] sin = self.sin_cached.to(device=device, dtype=dtype).unsqueeze(-2)[idxs] # Apply rotary positional embeddings to q and k (treating them as complex numbers). The first half is # Re[x exp(it)] = Re[x] cos(t) - Im[x] sin(t), while the second half is # Im[x exp(it)] = Im[x] cos(t) + Re[x] sin(t). This works b/c both halves of cos/sin are the same. q_embed = (q * cos) + (self.rotate_half(q) * sin) k_embed = (k * cos) + (self.rotate_half(k) * sin) return q_embed, k_embed class Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper.""" def __init__(self, config: E1Config, layer_idx: int): super().__init__() self.config = config self.layer_idx = layer_idx self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.head_dim = self.hidden_size // self.num_heads self.num_kv_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_kv_heads self.max_num_seqs = config.max_num_sequences self.clip_qkv = config.clip_qkv if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." ) self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) if self.config.global_attention_every_n_layers > 0: self.layer_type = ( AttentionLayerType.GLOBAL if (self.layer_idx + 1) % self.config.global_attention_every_n_layers == 0 else AttentionLayerType.WITHIN_SEQ ) else: self.layer_type = AttentionLayerType.WITHIN_SEQ self.rope_theta = ( config.rope_theta_within_seq if self.layer_type == AttentionLayerType.WITHIN_SEQ else config.rope_theta_global ) self.max_position_embeddings = ( config.max_num_positions_within_seq if self.layer_type == AttentionLayerType.WITHIN_SEQ else config.max_num_positions_global ) self.rotary_emb = RotaryPositionalEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta ) self.attn_backend = resolve_attention_backend(config.attn_backend) def prepare_qkv( self, hidden_states: torch.Tensor, position_ids: torch.LongTensor, past_key_value: Optional[DynamicCache] = None, use_cache: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: bsz, q_len, _ = hidden_states.size() query_states: torch.Tensor = self.q_proj(hidden_states) key_states: torch.Tensor = self.k_proj(hidden_states) val_states: torch.Tensor = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) key_states = key_states.view(bsz, q_len, self.num_kv_heads, self.head_dim) val_states = val_states.view(bsz, q_len, self.num_kv_heads, self.head_dim) if self.clip_qkv is not None: query_states = query_states.clamp(-self.clip_qkv, self.clip_qkv) key_states = key_states.clamp(-self.clip_qkv, self.clip_qkv) val_states = val_states.clamp(-self.clip_qkv, self.clip_qkv) query_states, key_states = self.rotary_emb(query_states, key_states, position_ids) if use_cache and past_key_value is not None: key_states, val_states = past_key_value.update(key_states, val_states, self.layer_idx) input_dtype = query_states.dtype if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() else: target_dtype = self.q_proj.weight.dtype if input_dtype != target_dtype: logger.warning_once( f"The input hidden states seems to be silently casted in {input_dtype}. " f"This might be because you have upcasted embedding or layer norm layers " f"in {input_dtype}. We will cast back the input in {target_dtype}." ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) val_states = val_states.to(target_dtype) return query_states, key_states, val_states def forward( self, hidden_states: torch.Tensor, within_seq_position_ids: torch.LongTensor, global_position_ids: torch.LongTensor, sequence_ids: torch.LongTensor, attention_args: Optional[AttentionArgs] = None, past_key_value: Optional[DynamicCache] = None, output_attentions: bool = False, output_s_max: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[DynamicCache], Optional[List[torch.Tensor]]]: is_cache_prefilled = ( use_cache and past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0 ) query_states, key_states, val_states = self.prepare_qkv( hidden_states=hidden_states, position_ids=within_seq_position_ids if self.layer_type == AttentionLayerType.WITHIN_SEQ else global_position_ids, past_key_value=past_key_value, use_cache=use_cache, ) attn_output, attn_weights, s_max = self._attn( query_states=query_states, key_states=key_states, val_states=val_states, sequence_ids=sequence_ids, attention_args=attention_args, output_attentions=output_attentions, output_s_max=output_s_max, is_cache_prefilled=is_cache_prefilled, ) attn_output = self.o_proj(attn_output) return attn_output, attn_weights, past_key_value, s_max def _attn( self, query_states: torch.Tensor, key_states: torch.Tensor, val_states: torch.Tensor, sequence_ids: torch.Tensor, attention_args: Optional[AttentionArgs] = None, output_attentions: bool = False, output_s_max: bool = False, is_cache_prefilled: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]: effective_layer_type = self.layer_type if is_cache_prefilled and self.layer_type == AttentionLayerType.GLOBAL: effective_layer_type = AttentionLayerType.WITHIN_SEQ if output_attentions: return self._manual_attn( query_states, key_states, val_states, sequence_ids=sequence_ids, attention_args=attention_args, effective_layer_type=effective_layer_type, output_s_max=output_s_max, is_cache_prefilled=is_cache_prefilled, ) if self.attn_backend == AttentionBackend.KERNELS_FLASH: if effective_layer_type == AttentionLayerType.WITHIN_SEQ: attn_output, attn_weights = self._kernels_flash_attn( query_states, key_states, val_states, sequence_ids=sequence_ids, is_cache_prefilled=is_cache_prefilled, ) else: attn_output, attn_weights = self._flex_attn( query_states, key_states, val_states, attention_args=attention_args, effective_layer_type=effective_layer_type, ) elif self.attn_backend == AttentionBackend.FLEX: attn_output, attn_weights = self._flex_attn( query_states, key_states, val_states, attention_args=attention_args, effective_layer_type=effective_layer_type, ) elif self.attn_backend == AttentionBackend.SDPA: attn_output, attn_weights = self._sdpa_attn( query_states, key_states, val_states, sequence_ids=sequence_ids, attention_args=attention_args, effective_layer_type=effective_layer_type, is_cache_prefilled=is_cache_prefilled, ) else: raise AssertionError(f"Unsupported resolved backend: {self.attn_backend}") s_max = self._compute_s_max(query_states, key_states) if output_s_max else None return attn_output, attn_weights, s_max @torch.no_grad() def _compute_s_max( self, query_states: torch.Tensor, # (B, L, H, D) key_states: torch.Tensor, # (B, L, Hkv, D) ) -> List[torch.Tensor]: query_BHLD = query_states.transpose(1, 2).contiguous() key_BHLD = key_states.transpose(1, 2).contiguous() key_BHLD = repeat_kv(key_BHLD, self.num_key_value_groups) scale = 1.0 / (self.head_dim ** 0.5) q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1) k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1) s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values * scale return [s_max_bound[h] for h in range(self.num_heads)] def _kernels_flash_attn( self, query_states: torch.Tensor, key_states: torch.Tensor, val_states: torch.Tensor, sequence_ids: torch.Tensor, is_cache_prefilled: bool = False, ) -> Tuple[torch.Tensor, None]: bsz, q_len = query_states.shape[0], query_states.shape[1] _, kv_len = key_states.shape[0], key_states.shape[1] if self.layer_type == AttentionLayerType.GLOBAL and not is_cache_prefilled: q_sequence_ids = sequence_ids if q_len < kv_len: first_token_id = sequence_ids[:, 0].unsqueeze(1) k_sequence_ids = torch.cat([first_token_id.expand(bsz, kv_len - q_len), sequence_ids], dim=-1) else: k_sequence_ids = sequence_ids else: if q_len < kv_len: key_states = key_states[:, -q_len:] val_states = val_states[:, -q_len:] q_sequence_ids = k_sequence_ids = sequence_ids attn_output = kernels_flash_attention_func( query_states, key_states, val_states, q_sequence_ids=q_sequence_ids, k_sequence_ids=k_sequence_ids, causal=False, ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() return attn_output, None def _flex_attn( self, query_states: torch.Tensor, key_states: torch.Tensor, val_states: torch.Tensor, attention_args: Optional[AttentionArgs] = None, effective_layer_type: AttentionLayerType = AttentionLayerType.WITHIN_SEQ, ) -> Tuple[torch.Tensor, None]: bsz, q_len = query_states.shape[0], query_states.shape[1] if effective_layer_type == AttentionLayerType.WITHIN_SEQ: block_mask = attention_args["within_seq_block_mask"] if attention_args is not None else None else: block_mask = attention_args["block_causal_block_mask"] if attention_args is not None else None outputs = flex_attention_func(query_states, key_states, val_states, block_mask=block_mask) outputs = outputs.reshape(bsz, q_len, self.hidden_size).contiguous() return outputs, None def _sdpa_attn( self, query_states: torch.Tensor, # (B, L, H, D) key_states: torch.Tensor, # (B, L, Hkv, D) val_states: torch.Tensor, # (B, L, Hkv, D) sequence_ids: torch.Tensor, attention_args: Optional[AttentionArgs] = None, effective_layer_type: AttentionLayerType = AttentionLayerType.WITHIN_SEQ, is_cache_prefilled: bool = False, ) -> Tuple[torch.Tensor, None]: bsz, q_len = query_states.shape[:2] kv_len = key_states.shape[1] if is_cache_prefilled and q_len < kv_len: if effective_layer_type == AttentionLayerType.WITHIN_SEQ: key_states = key_states[:, -q_len:] val_states = val_states[:, -q_len:] attention_mask_4d = build_within_seq_mask_4d(sequence_ids) if effective_layer_type == AttentionLayerType.WITHIN_SEQ else None elif attention_args is not None: if effective_layer_type == AttentionLayerType.WITHIN_SEQ: attention_mask_4d = attention_args["within_seq_mask_4d"] else: attention_mask_4d = attention_args["block_causal_mask_4d"] else: attention_mask_4d = None query_BHLD = query_states.transpose(1, 2).contiguous() key_BHLD = key_states.transpose(1, 2).contiguous() val_BHLD = val_states.transpose(1, 2).contiguous() key_BHLD = repeat_kv(key_BHLD, self.num_key_value_groups) val_BHLD = repeat_kv(val_BHLD, self.num_key_value_groups) context_BHLD = F.scaled_dot_product_attention(query_BHLD, key_BHLD, val_BHLD, attn_mask=attention_mask_4d) attn_output = context_BHLD.transpose(1, 2).reshape(bsz, q_len, self.hidden_size).contiguous() return attn_output, None def _manual_attn( self, query_states: torch.Tensor, # (B, L, H, D) key_states: torch.Tensor, # (B, L, Hkv, D) val_states: torch.Tensor, # (B, L, Hkv, D) sequence_ids: torch.Tensor, attention_args: Optional[AttentionArgs] = None, effective_layer_type: AttentionLayerType = AttentionLayerType.WITHIN_SEQ, output_s_max: bool = False, is_cache_prefilled: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, Optional[List[torch.Tensor]]]: bsz, q_len = query_states.shape[:2] kv_len = key_states.shape[1] if is_cache_prefilled and q_len < kv_len: if effective_layer_type == AttentionLayerType.WITHIN_SEQ: key_states = key_states[:, -q_len:] val_states = val_states[:, -q_len:] attention_mask_4d = build_within_seq_mask_4d(sequence_ids) if effective_layer_type == AttentionLayerType.WITHIN_SEQ else None elif attention_args is not None: if effective_layer_type == AttentionLayerType.WITHIN_SEQ: attention_mask_4d = attention_args["within_seq_mask_4d"] else: attention_mask_4d = attention_args["block_causal_mask_4d"] else: attention_mask_4d = None query_BHLD = query_states.transpose(1, 2).contiguous() key_BHLD = key_states.transpose(1, 2).contiguous() val_BHLD = val_states.transpose(1, 2).contiguous() key_BHLD = repeat_kv(key_BHLD, self.num_key_value_groups) val_BHLD = repeat_kv(val_BHLD, self.num_key_value_groups) scale = 1.0 / (self.head_dim ** 0.5) attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * scale if attention_mask_4d is not None: attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf")) attn_weights = F.softmax(attn_weights, dim=-1) context_BHLD = torch.matmul(attn_weights, val_BHLD) attn_output = context_BHLD.transpose(1, 2).reshape(bsz, q_len, self.hidden_size).contiguous() s_max = self._compute_s_max(query_states, key_states) if output_s_max else None return attn_output, attn_weights, s_max class MLP(nn.Module): def __init__(self, config: E1Config): super().__init__() self.ffn_dim = config.intermediate_size self.hidden_dim = config.hidden_size self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.w2(self.act_fn(self.w1(hidden_states))) class GLUMLP(nn.Module): def __init__(self, config: E1Config): super().__init__() self.ffn_dim = config.intermediate_size self.hidden_dim = config.hidden_size self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) hidden_states = self.w2(hidden_states) return hidden_states class FFN(nn.Module): def __init__(self, config: E1Config): super().__init__() mlp_cls = GLUMLP if config.gated_mlp else MLP self.mlp = mlp_cls(config) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.mlp(hidden_states) @dataclass class E1ModelOutputWithPast(ModelOutput): """Base class for model's outputs, with potential hidden states and attentions. Attributes: last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. """ last_hidden_state: Optional[torch.FloatTensor] = None past_key_values: Optional[DynamicCache] = None hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None attentions: Optional[Tuple[torch.FloatTensor, ...]] = None s_max: Optional[Tuple[List[torch.Tensor], ...]] = None @dataclass class E1MaskedLMOutputWithPast(ModelOutput): loss: Optional[torch.FloatTensor] = None mlm_loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None last_hidden_state: Optional[torch.FloatTensor] = None past_key_values: Optional[DynamicCache] = None hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None attentions: Optional[Tuple[torch.FloatTensor, ...]] = None s_max: Optional[Tuple[List[torch.Tensor], ...]] = None @dataclass class E1ClassificationOutputWithPast(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: Optional[torch.FloatTensor] = None last_hidden_state: Optional[torch.FloatTensor] = None past_key_values: Optional[DynamicCache] = None hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None attentions: Optional[Tuple[torch.FloatTensor, ...]] = None s_max: Optional[Tuple[List[torch.Tensor], ...]] = None class RMSNorm(nn.Module): def __init__(self, hidden_size: int, eps: float = 1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps self.hidden_size = hidden_size def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input_dtype = hidden_states.dtype if layer_norm is None: return torch.nn.functional.rms_norm( hidden_states, (self.hidden_size,), self.weight, self.variance_epsilon ).to(input_dtype) else: return layer_norm.rms_norm_fn( x=hidden_states, weight=self.weight, bias=None, # no bias residual=None, eps=self.variance_epsilon, dropout_p=0.0, # no dropout by default prenorm=False, residual_in_fp32=False, ).to(input_dtype) class NormAttentionNorm(nn.Module): def __init__(self, config: E1Config, layer_idx: int): super().__init__() self.self_attn = Attention(config, layer_idx) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, hidden_states: torch.Tensor, within_seq_position_ids: torch.LongTensor, global_position_ids: torch.LongTensor, sequence_ids: torch.LongTensor, attention_args: Optional[AttentionArgs] = None, past_key_value: Optional[DynamicCache] = None, output_attentions: bool = False, output_s_max: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[DynamicCache], Optional[List[torch.Tensor]]]: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) hidden_states, self_attn_weights, present_key_value, s_max = self.self_attn( hidden_states=hidden_states, within_seq_position_ids=within_seq_position_ids, global_position_ids=global_position_ids, sequence_ids=sequence_ids, attention_args=attention_args, past_key_value=past_key_value, output_attentions=output_attentions, output_s_max=output_s_max, use_cache=use_cache, ) hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) return hidden_states, residual, self_attn_weights, present_key_value, s_max class DecoderLayer(nn.Module): def __init__(self, config: E1Config, layer_idx: int): super().__init__() self.initializer_range = config.initializer_range self.hidden_size = config.hidden_size self.norm_attn_norm = NormAttentionNorm(config, layer_idx) self.ffn = FFN(config) def forward( self, hidden_states: torch.Tensor, within_seq_position_ids: torch.LongTensor, global_position_ids: torch.LongTensor, sequence_ids: torch.LongTensor, attention_args: Optional[AttentionArgs] = None, past_key_value: Optional[DynamicCache] = None, output_attentions: bool = False, output_s_max: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[DynamicCache], Optional[List[torch.Tensor]]]: hidden_states, residual, self_attn_weights, present_key_value, s_max = self.norm_attn_norm( hidden_states=hidden_states, within_seq_position_ids=within_seq_position_ids, global_position_ids=global_position_ids, sequence_ids=sequence_ids, attention_args=attention_args, past_key_value=past_key_value, output_attentions=output_attentions, output_s_max=output_s_max, use_cache=use_cache, ) # Fully Connected hidden_states = self.ffn(hidden_states) hidden_states = residual + hidden_states return hidden_states, self_attn_weights, present_key_value, s_max class E1PreTrainedModel(PreTrainedModel): config_class = E1Config config: E1Config base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["DecoderLayer"] _transformer_layer_cls = [DecoderLayer] _skip_keys_device_placement = "past_key_values" all_tied_weights_keys = {} def _init_weights(self, module: nn.Module) -> None: std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, RMSNorm): module.weight.data.fill_(1.0) def _backward_compatibility_gradient_checkpointing(self) -> None: if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False): self.gradient_checkpointing_enable(dict(use_reentrant=False)) def post_init(self) -> None: super().post_init() @property def _device(self) -> torch.device: return next(self.parameters()).device @property def attn_backend(self) -> str: return self.config.attn_backend @attn_backend.setter def attn_backend(self, backend: str) -> None: assert backend in VALID_ATTENTION_BACKENDS, ( f"Unsupported attn_backend: {backend}. Expected one of {VALID_ATTENTION_BACKENDS}." ) self.config.attn_backend = backend resolved = resolve_attention_backend(backend) for module in self.modules(): if isinstance(module, FAST_E1_ENCODER): module._attn_backend = resolved elif isinstance(module, Attention): module.attn_backend = resolved class FAST_E1_ENCODER(E1PreTrainedModel, EmbeddingMixin): config: E1Config config_class = E1Config def __init__(self, config: E1Config, **kwargs): E1PreTrainedModel.__init__(self, config, **kwargs) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.embed_seq_id = nn.Embedding(config.max_num_sequences, config.hidden_size) self.layers = nn.ModuleList([DecoderLayer(config, i) for i in range(config.num_hidden_layers)]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.gradient_checkpointing = config.gradient_checkpointing self.prep_tokens = E1BatchPreparer() self._attn_backend = resolve_attention_backend(config.attn_backend) self.post_init() def get_input_embeddings(self) -> nn.Embedding: return self.embed_tokens def set_input_embeddings(self, value: nn.Embedding) -> None: self.embed_tokens = value def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor: batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device) last_hidden_state = self.forward(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state if return_attention_mask: attention_mask = (batch['sequence_ids'] != -1).long() return last_hidden_state, attention_mask else: return last_hidden_state # Ignore copy def forward( self, input_ids: Optional[torch.LongTensor] = None, within_seq_position_ids: Optional[torch.LongTensor] = None, global_position_ids: Optional[torch.LongTensor] = None, sequence_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, past_key_values: Optional[DynamicCache] = None, use_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, output_s_max: bool = False, **kwargs ) -> E1ModelOutputWithPast: """ Args: input_ids: (batch_size, seq_length) within_seq_position_ids: (batch_size, seq_length) This tensor contains the position of each residue within the sequence itself. For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], the tensor would be [[0,1,2,3,4,5,6,0,1,2,3,4,5,6], [0,1,2,3,4,5,0,1,2,3,4,5,6,-1]] global_position_ids: (batch_size, seq_length) This tensor contains the position of each residue within the global sequence. For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], the tensor would be [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1]] sequence_ids: (batch_size, seq_length) This tensor contains the sequence id of each residue. For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]] inputs_embeds: (batch_size, seq_length, hidden_size) - pre-computed embeddings, bypasses embed_tokens and embed_seq_id when provided. Used by PDE for differentiable soft sequence optimization. past_key_values: DynamicCache use_cache: bool output_attentions: bool output_hidden_states: bool output_s_max: bool Returns: E1ModelOutputWithPast: Model Outputs """ assert not (input_ids is not None and inputs_embeds is not None), ( "Cannot specify both input_ids and inputs_embeds" ) assert input_ids is not None or inputs_embeds is not None, ( "Must specify either input_ids or inputs_embeds" ) if input_ids is not None: batch_size, seq_length = input_ids.shape else: batch_size, seq_length = inputs_embeds.shape[:2] if self.gradient_checkpointing and self.training and torch.is_grad_enabled(): if use_cache: logger.warning_once( "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ) use_cache = False if use_cache and past_key_values is None: past_key_values = DynamicCache() elif not use_cache: past_key_values = None # Synthesize positional IDs for soft embedding path (single-sequence) if inputs_embeds is not None: device = inputs_embeds.device if within_seq_position_ids is None: within_seq_position_ids = torch.arange(seq_length, device=device).unsqueeze(0).expand(batch_size, -1) if global_position_ids is None: global_position_ids = torch.arange(seq_length, device=device).unsqueeze(0).expand(batch_size, -1) if sequence_ids is None: sequence_ids = torch.zeros(batch_size, seq_length, device=device, dtype=torch.long) global_position_ids = global_position_ids.view(-1, seq_length).long() within_seq_position_ids = within_seq_position_ids.view(-1, seq_length).long() sequence_ids = sequence_ids.view(-1, seq_length).long() max_position_id = torch.max(within_seq_position_ids).item() min_position_id = torch.min(within_seq_position_ids).item() assert max_position_id < self.config.max_num_positions_within_seq and min_position_id >= -1, ( f"Position ids must be in the range [-1, {self.config.max_num_positions_within_seq}); got max {max_position_id} and min {min_position_id}" ) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = inputs_embeds + self.embed_seq_id(sequence_ids.clamp(min=0)) if torch.is_autocast_enabled(): target_dtype = torch.get_autocast_gpu_dtype() else: target_dtype = self.layers[0].norm_attn_norm.self_attn.q_proj.weight.dtype hidden_states = inputs_embeds.to(target_dtype) past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 attn_backend = self._attn_backend has_global_layers = self.config.global_attention_every_n_layers > 0 needs_4d_masks = (attn_backend == AttentionBackend.SDPA) or output_attentions needs_block_causal_flex = ( (attn_backend == AttentionBackend.FLEX and has_global_layers) or (attn_backend == AttentionBackend.KERNELS_FLASH and has_global_layers) ) needs_within_seq_flex = (attn_backend == AttentionBackend.FLEX) attention_args: Optional[AttentionArgs] = None if past_key_values_length == 0: attention_args = AttentionArgs( block_causal_block_mask=create_block_causal_mask_optimized(sequence_ids) if needs_block_causal_flex else None, within_seq_block_mask=create_within_seq_block_mask(sequence_ids) if needs_within_seq_flex else None, within_seq_mask_4d=build_within_seq_mask_4d(sequence_ids) if needs_4d_masks else None, block_causal_mask_4d=build_block_causal_mask_4d(sequence_ids) if needs_4d_masks else None, ) all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None full_s_max = () if output_s_max else None next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) # type: ignore[operator] if self.gradient_checkpointing and self.training and torch.is_grad_enabled(): layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, within_seq_position_ids, global_position_ids, sequence_ids, attention_args, past_key_values, output_attentions, output_s_max, use_cache, ) else: layer_outputs = decoder_layer( hidden_states, within_seq_position_ids=within_seq_position_ids, global_position_ids=global_position_ids, sequence_ids=sequence_ids, attention_args=attention_args, past_key_value=past_key_values, output_attentions=output_attentions, output_s_max=output_s_max, use_cache=use_cache, ) hidden_states, self_attn_weights, present_key_value, s_max = layer_outputs if use_cache: next_decoder_cache = past_key_values = present_key_value if output_attentions: all_self_attns += (self_attn_weights,) # type: ignore[operator] if full_s_max is not None: full_s_max += (s_max,) # type: ignore[operator] hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states += (hidden_states,) # type: ignore[operator] next_cache = next_decoder_cache if use_cache else None return E1ModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, s_max=full_s_max, ) class E1Model(E1PreTrainedModel, EmbeddingMixin): config: E1Config config_class = E1Config def __init__(self, config: E1Config, **kwargs): E1PreTrainedModel.__init__(self, config, **kwargs) self.model: FAST_E1_ENCODER = FAST_E1_ENCODER(config, **kwargs) self.prep_tokens = self.model.prep_tokens self.post_init() def get_input_embeddings(self) -> nn.Embedding: return self.model.get_input_embeddings() def set_input_embeddings(self, value: nn.Embedding) -> None: self.model.set_input_embeddings(value) def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor: return self.model._embed(sequences, return_attention_mask=return_attention_mask, **kwargs) def forward( self, input_ids: Optional[torch.LongTensor] = None, within_seq_position_ids: Optional[torch.LongTensor] = None, global_position_ids: Optional[torch.LongTensor] = None, sequence_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, past_key_values: Optional[DynamicCache] = None, use_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, output_s_max: bool = False, **kwargs, ) -> E1ModelOutputWithPast: return self.model( input_ids=input_ids, within_seq_position_ids=within_seq_position_ids, global_position_ids=global_position_ids, sequence_ids=sequence_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_s_max=output_s_max, **kwargs, ) class E1ForMaskedLM(E1PreTrainedModel, EmbeddingMixin): config: E1Config config_class = E1Config def __init__(self, config: E1Config, **kwargs): E1PreTrainedModel.__init__(self, config, **kwargs) self.model: FAST_E1_ENCODER = FAST_E1_ENCODER(config, **kwargs) self.vocab_size = config.vocab_size self.mlm_head = torch.nn.Sequential( nn.Linear(config.hidden_size, config.hidden_size, bias=True), nn.GELU(), nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps), nn.Linear(config.hidden_size, config.vocab_size, bias=True), ) self.gradient_checkpointing = config.gradient_checkpointing self.prep_tokens = self.model.prep_tokens self.post_init() @property def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh: return self.model.device_mesh def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor: batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device) last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state if return_attention_mask: attention_mask = (batch['sequence_ids'] != -1).long() return last_hidden_state, attention_mask else: return last_hidden_state def forward( self, input_ids: Optional[torch.LongTensor] = None, within_seq_position_ids: Optional[torch.LongTensor] = None, global_position_ids: Optional[torch.LongTensor] = None, sequence_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, past_key_values: Optional[DynamicCache] = None, use_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, output_s_max: bool = False, **kwargs, ) -> E1MaskedLMOutputWithPast: """ Args: input_ids: (batch_size, seq_length) within_seq_position_ids: (batch_size, seq_length) This tensor contains the position of each residue within the sequence itself. For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], the tensor would be [[0,1,2,3,4,5,6,0,1,2,3,4,5,6], [0,1,2,3,4,5,0,1,2,3,4,5,6,-1]] global_position_ids: (batch_size, seq_length) This tensor contains the position of each residue within the global sequence. For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], the tensor would be [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, -1]] sequence_ids: (batch_size, seq_length) This tensor contains the sequence id of each residue. For example, if the input is ["1ABC21DEF2", "1GH21JKL2"], the tensor would be [[0,0,0,0,0,0,0,1,1,1,1,1,1,1], [0,0,0,0,0,0,1,1,1,1,1,1,1,-1]] inputs_embeds: (batch_size, seq_length, hidden_size) - pre-computed embeddings labels: (batch_size, seq_length) past_key_values: DynamicCache use_cache: bool output_attentions: bool output_hidden_states: bool output_s_max: bool Returns: E1MaskedLMOutputWithPast: Model Outputs """ outputs: E1ModelOutputWithPast = self.model( input_ids=input_ids, within_seq_position_ids=within_seq_position_ids, global_position_ids=global_position_ids, sequence_ids=sequence_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_s_max=output_s_max, ) last_hidden_state = outputs.last_hidden_state loss = None mlm_logits = self.mlm_head(last_hidden_state).float() mlm_loss = 0.0 if labels is not None: mlm_logits_flat = mlm_logits.contiguous().view(-1, self.config.vocab_size) mlm_labels_flat = labels.to(mlm_logits_flat.device).contiguous().view(-1) mlm_loss = F.cross_entropy(mlm_logits_flat, mlm_labels_flat, reduction="none") mask = mlm_labels_flat != self.model.padding_idx n_mlm = mask.sum() mlm_loss = (mlm_loss * mask.to(mlm_loss)).sum() / (1 if n_mlm == 0 else n_mlm) loss = 0.0 loss += mlm_loss return E1MaskedLMOutputWithPast( loss=loss, mlm_loss=mlm_loss, logits=mlm_logits, last_hidden_state=last_hidden_state, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, s_max=outputs.s_max, ) class E1ForSequenceClassification(E1PreTrainedModel, EmbeddingMixin): config: E1Config config_class = E1Config def __init__(self, config: E1Config, **kwargs): E1PreTrainedModel.__init__(self, config, **kwargs) self.model: FAST_E1_ENCODER = FAST_E1_ENCODER(config, **kwargs) self.vocab_size = config.vocab_size self.num_labels = config.num_labels self.classifier = nn.Sequential( nn.Linear(config.hidden_size * 2, config.hidden_size * 4), nn.GELU(), nn.LayerNorm(config.hidden_size * 4), nn.Linear(config.hidden_size * 4, config.num_labels), ) self.mse = nn.MSELoss() self.ce = nn.CrossEntropyLoss() self.bce = nn.BCEWithLogitsLoss() self.gradient_checkpointing = config.gradient_checkpointing self.prep_tokens = self.model.prep_tokens if 'pooling_types' in kwargs and isinstance(kwargs['pooling_types'], List[str]) and len(kwargs['pooling_types']) > 0: pooling_types = kwargs['pooling_types'] else: pooling_types = ['mean', 'var'] self.pooler = Pooler(pooling_types) self.post_init() @property def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh: return self.model.device_mesh def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor: batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device) last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state if return_attention_mask: attention_mask = (batch['sequence_ids'] != -1).long() return last_hidden_state, attention_mask else: return last_hidden_state def forward( self, input_ids: Optional[torch.LongTensor] = None, within_seq_position_ids: Optional[torch.LongTensor] = None, global_position_ids: Optional[torch.LongTensor] = None, sequence_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, past_key_values: Optional[DynamicCache] = None, use_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, output_s_max: bool = False, **kwargs, ) -> E1ClassificationOutputWithPast: outputs: E1ModelOutputWithPast = self.model( input_ids=input_ids, within_seq_position_ids=within_seq_position_ids, global_position_ids=global_position_ids, sequence_ids=sequence_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_s_max=output_s_max, ) attention_mask = (sequence_ids != -1).long() if sequence_ids is not None else torch.ones(outputs.last_hidden_state.shape[:2], device=outputs.last_hidden_state.device, dtype=torch.long) x = outputs.last_hidden_state features = self.pooler(x, attention_mask) logits = self.classifier(features) loss = None if labels is not None: labels = labels.to(logits.device) if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": if self.num_labels == 1: loss = self.mse(logits.flatten(), labels.flatten()) else: loss = self.mse(logits, labels) elif self.config.problem_type == "single_label_classification": loss = self.ce(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss = self.bce(logits, labels) return E1ClassificationOutputWithPast( loss=loss, logits=logits, last_hidden_state=x, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, s_max=outputs.s_max, ) class E1ForTokenClassification(E1PreTrainedModel, EmbeddingMixin): config: E1Config config_class = E1Config def __init__(self, config: E1Config, **kwargs): E1PreTrainedModel.__init__(self, config, **kwargs) self.model: FAST_E1_ENCODER = FAST_E1_ENCODER(config, **kwargs) self.vocab_size = config.vocab_size self.num_labels = config.num_labels self.classifier = nn.Sequential( nn.Linear(config.hidden_size * 2, config.hidden_size * 4), nn.GELU(), nn.LayerNorm(config.hidden_size * 4), nn.Linear(config.hidden_size * 4, config.num_labels), ) self.loss_fct = nn.CrossEntropyLoss() self.gradient_checkpointing = config.gradient_checkpointing self.prep_tokens = self.model.prep_tokens self.post_init() @property def device_mesh(self) -> torch.distributed.device_mesh.DeviceMesh: return self.model.device_mesh def _embed(self, sequences: List[str], return_attention_mask: bool = False, **kwargs) -> torch.Tensor: batch = self.prep_tokens.get_batch_kwargs(sequences, device=self._device) last_hidden_state = self.model(**batch, output_hidden_states=False, output_attentions=False).last_hidden_state if return_attention_mask: attention_mask = (batch['sequence_ids'] != -1).long() return last_hidden_state, attention_mask else: return last_hidden_state def forward( self, input_ids: Optional[torch.LongTensor] = None, within_seq_position_ids: Optional[torch.LongTensor] = None, global_position_ids: Optional[torch.LongTensor] = None, sequence_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, past_key_values: Optional[DynamicCache] = None, use_cache: bool = False, output_attentions: bool = False, output_hidden_states: bool = False, output_s_max: bool = False, **kwargs, ) -> E1ClassificationOutputWithPast: outputs: E1ModelOutputWithPast = self.model( input_ids=input_ids, within_seq_position_ids=within_seq_position_ids, global_position_ids=global_position_ids, sequence_ids=sequence_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_s_max=output_s_max, ) x = outputs.last_hidden_state logits = self.classifier(x) loss = None if labels is not None: loss = self.loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) return E1ClassificationOutputWithPast( loss=loss, logits=logits, last_hidden_state=x, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, s_max=outputs.s_max, ) if __name__ == "__main__": import random import torch from torch import Tensor def print_tensor_shapes(prefix: str, obj): if isinstance(obj, Tensor): print(f"{prefix}{obj.shape}") elif isinstance(obj, dict): for name, value in obj.items(): print_tensor_shapes(f"{prefix}{name}.", value) elif isinstance(obj, list): for idx, value in enumerate(obj): print_tensor_shapes(f"{prefix}[{idx}].", value) elif isinstance(obj, tuple): for idx, value in enumerate(obj): print_tensor_shapes(f"{prefix}[{idx}].", value) elif hasattr(obj, "__dict__"): for name, value in vars(obj).items(): if name.startswith("_"): continue print_tensor_shapes(f"{prefix}{name}.", value) else: print(f"{prefix}{type(obj)}") def get_e1_batch(tokenizer, sequences: List[str], device: torch.device): preparer = E1BatchPreparer(data_prep_config=DataPrepConfig(max_num_positions_within_seq=64), tokenizer=tokenizer) return preparer.get_batch_kwargs(sequences=sequences, device=device) random.seed(0) torch.manual_seed(0) num_attention_heads = random.choice([2, 4]) config = E1Config( hidden_size=16 * num_attention_heads, intermediate_size=64 * num_attention_heads, num_hidden_layers=random.choice([1, 2]), num_attention_heads=num_attention_heads, num_key_value_heads=num_attention_heads, max_num_positions_within_seq=128, max_num_positions_global=256, max_num_sequences=8, dtype="float32", ) model = E1ForMaskedLM(config=config).eval() tokenizer = get_tokenizer() batch = get_e1_batch(tokenizer=tokenizer, sequences=["ACDEFG", "MKTW"], device=torch.device("cpu")) batch["labels"] = batch["labels"].clone() with torch.no_grad(): output = model( input_ids=batch["input_ids"], within_seq_position_ids=batch["within_seq_position_ids"], global_position_ids=batch["global_position_ids"], sequence_ids=batch["sequence_ids"], labels=batch["labels"], ) print("Batch shape:") print_tensor_shapes("", batch) print("Output shape:") print_tensor_shapes("", output)