""" ModernProteinLM: A next-generation protein encoder combining: - ModernBERT architectural improvements (RoPE, Pre-LN, GeGLU, FlashAttention-compatible) - ELECTRA-style discriminative pre-training - Deep & narrow design optimal for protein sequences - Curriculum masking (30% -> 5%) - Span masking for protein structural motifs Architecture goals (~150M params): - 28 layers, hidden 576, heads 9, intermediate 2304 (GeGLU) - RoPE position embeddings (no absolute PE) - Pre-LayerNorm with extra LN after embedding - No dropout (following ESM-2) - Tied input/output embeddings """ import math from typing import Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel, PretrainedConfig from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput class ModernProteinLMConfig(PretrainedConfig): model_type = "modern_protein_lm" def __init__( self, vocab_size=33, hidden_size=576, num_hidden_layers=28, num_attention_heads=9, intermediate_size=2304, hidden_act="gelu", hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, max_position_embeddings=1026, initializer_range=0.02, layer_norm_eps=1e-12, position_embedding_type="rotary", rope_theta=10000.0, use_geglu=True, tie_word_embeddings=True, pad_token_id=1, mask_token_id=32, cls_token_id=0, eos_token_id=2, **kwargs, ): super().__init__( pad_token_id=pad_token_id, mask_token_id=mask_token_id, cls_token_id=cls_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs, ) self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.intermediate_size = intermediate_size self.hidden_act = hidden_act self.hidden_dropout_prob = hidden_dropout_prob self.attention_probs_dropout_prob = attention_probs_dropout_prob self.max_position_embeddings = max_position_embeddings self.initializer_range = initializer_range self.layer_norm_eps = layer_norm_eps self.position_embedding_type = position_embedding_type self.rope_theta = rope_theta self.use_geglu = use_geglu class RotaryEmbedding(nn.Module): """RoPE (Rotary Position Embedding) for protein sequences.""" def __init__(self, dim, max_seq_len=1026, base=10000.0, device=None): super().__init__() self.dim = dim inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) self.max_seq_len = max_seq_len def forward(self, seq_len, device): t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) freqs = torch.outer(t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) return emb.cos().to(torch.float32), emb.sin().to(torch.float32) def rotate_half(x): x1, x2 = x.chunk(2, dim=-1) return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin): q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed class ModernProteinAttention(nn.Module): """Multi-head attention with RoPE and optional FlashAttention.""" def __init__(self, config: ModernProteinLMConfig): super().__init__() self.num_heads = config.num_attention_heads self.head_dim = config.hidden_size // config.num_attention_heads self.scale = self.head_dim ** -0.5 self.qkv_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=False) self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) self.rotary_emb = RotaryEmbedding(self.head_dim, 


class RotaryEmbedding(nn.Module):
    """RoPE (Rotary Position Embedding) for protein sequences."""

    def __init__(self, dim, max_seq_len=1026, base=10000.0, device=None):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.max_seq_len = max_seq_len

    def forward(self, seq_len, device):
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(torch.float32), emb.sin().to(torch.float32)


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class ModernProteinAttention(nn.Module):
    """Multi-head attention with RoPE and optional FlashAttention."""

    def __init__(self, config: ModernProteinLMConfig):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.scale = self.head_dim ** -0.5
        self.qkv_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=False)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.rotary_emb = RotaryEmbedding(
            self.head_dim,
            max_seq_len=config.max_position_embeddings,
            base=config.rope_theta,
        )
        self.dropout = (
            nn.Dropout(config.attention_probs_dropout_prob)
            if config.attention_probs_dropout_prob > 0
            else None
        )

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        batch_size, seq_len, _ = hidden_states.shape
        qkv = self.qkv_proj(hidden_states)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, T, D)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Apply RoPE. Cast cos/sin to the query dtype so fp16/bf16 inputs are not
        # silently promoted to fp32 (which would disable the FlashAttention path).
        cos, sin = self.rotary_emb(seq_len, device=hidden_states.device)
        cos = cos[None, None, :, :].to(q.dtype)  # (1, 1, T, D)
        sin = sin[None, None, :, :].to(q.dtype)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        attn_output = None
        attn_probs = None

        # Use FlashAttention when it is available and applicable: no padding mask,
        # half-precision inputs, and attention weights not requested.
        if attention_mask is None and not output_attentions and q.dtype in (torch.float16, torch.bfloat16):
            try:
                from flash_attn import flash_attn_func

                # flash_attn_func expects (B, T, H, D)
                attn_output = flash_attn_func(
                    q.transpose(1, 2),
                    k.transpose(1, 2),
                    v.transpose(1, 2),
                    dropout_p=self.dropout.p if self.dropout is not None else 0.0,
                    causal=False,
                ).transpose(1, 2)
            except ImportError:
                attn_output = None

        if attn_output is None:
            # Standard scaled dot-product attention fallback
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
            if attention_mask is not None:
                attn_scores = attn_scores + attention_mask
            attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
            if self.dropout is not None:
                attn_probs = self.dropout(attn_probs)
            attn_output = torch.matmul(attn_probs, v)

        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, -1)
        attn_output = self.out_proj(attn_output)

        if output_attentions:
            return attn_output, attn_probs
        return attn_output, None


class GeGLU(nn.Module):
    """GeGLU activation: GELU(gate) * value. More expressive than GELU alone."""

    def __init__(self, config: ModernProteinLMConfig):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act = nn.GELU()

    def forward(self, x):
        gate = self.act(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)


class ModernProteinMLP(nn.Module):
    def __init__(self, config: ModernProteinLMConfig):
        super().__init__()
        if config.use_geglu:
            self.mlp = GeGLU(config)
        else:
            self.mlp = nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
                nn.GELU(),
                nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
            )

    def forward(self, x):
        return self.mlp(x)


class ModernProteinLayer(nn.Module):
    """Pre-LN transformer layer."""

    def __init__(self, config: ModernProteinLMConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attn = ModernProteinAttention(config)
        self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = ModernProteinMLP(config)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        # Pre-LN: LN -> Attn -> Residual
        attn_out, attn_weights = self.attn(self.ln1(hidden_states), attention_mask, output_attentions)
        hidden_states = hidden_states + attn_out
        # Pre-LN: LN -> MLP -> Residual
        mlp_out = self.mlp(self.ln2(hidden_states))
        hidden_states = hidden_states + mlp_out
        return hidden_states, attn_weights


class ModernProteinLM(PreTrainedModel):
    config_class = ModernProteinLMConfig

    def __init__(self, config: ModernProteinLMConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embed_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.layers = nn.ModuleList(
            [ModernProteinLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.final_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize all weights first, then tie the input/output embeddings if
        # requested (tying after init keeps the untied case properly initialized).
        self.apply(self._init_weights)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embeddings.weight

    def _init_weights(self, module):
        """Per-module initialization, following the HF `_init_weights(module)` convention."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def forward(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        labels=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        batch_size, seq_len = input_ids.shape

        # Embedding
        hidden_states = self.embeddings(input_ids)
        hidden_states = self.embed_ln(hidden_states)

        # Convert the 0/1 padding mask to an additive mask:
        # (B, T) -> (B, 1, 1, T) for broadcasting over heads and query positions
        if attention_mask is not None:
            attention_mask = (1.0 - attention_mask[:, None, None, :].to(hidden_states.dtype)) * -10000.0

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # Transformer layers
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            hidden_states, attn_weights = layer(hidden_states, attention_mask, output_attentions)
            if output_attentions:
                all_attentions += (attn_weights,)

        hidden_states = self.final_ln(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # LM head
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # CrossEntropyLoss ignores positions labeled -100 by default
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,)
            if output_hidden_states:
                output += (all_hidden_states,)
            if output_attentions:
                output += (all_attentions,)
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )

    def get_sequence_embedding(self, input_ids, attention_mask=None):
        """Extract CLS or mean-pooled embedding for downstream tasks."""
        outputs = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        hidden = outputs.hidden_states[-1]
        if attention_mask is not None:
            # Mean pool over non-padded positions
            mask_expanded = attention_mask.unsqueeze(-1).float()
            sum_hidden = (hidden * mask_expanded).sum(dim=1)
            pooled = sum_hidden / mask_expanded.sum(dim=1).clamp(min=1e-9)
        else:
            pooled = hidden[:, 0]  # CLS token
        return pooled


class ModernProteinLMForMaskedLM(ModernProteinLM):
    """Masked Language Model wrapper."""

    pass
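
# Illustrative sketch only: the module docstring lists span masking for protein
# structural motifs as a pre-training objective, but no masking collator is
# defined in this file. The helper below shows one way to select contiguous
# spans to mask; the function name and the Poisson span-length heuristic are
# assumptions for illustration, not part of the trained model.
def sample_span_mask(seq_len, mask_rate, mean_span_len=3.0, generator=None):
    """Return a boolean (seq_len,) tensor marking positions to mask as contiguous spans."""
    mask = torch.zeros(seq_len, dtype=torch.bool)
    target = int(round(mask_rate * seq_len))
    tries = 0
    while int(mask.sum()) < target and tries < 10 * seq_len:
        tries += 1
        # Poisson-distributed span length (mean ~= mean_span_len), capped by the remaining budget
        span = int(torch.poisson(torch.tensor(mean_span_len - 1.0)).item()) + 1
        span = min(span, max(target - int(mask.sum()), 1))
        start = int(torch.randint(0, max(seq_len - span, 1), (1,), generator=generator).item())
        mask[start:start + span] = True
    return mask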


class ModernProteinLMForSequenceClassification(PreTrainedModel):
    config_class = ModernProteinLMConfig

    def __init__(self, config: ModernProteinLMConfig):
        super().__init__(config)
        self.modern_protein = ModernProteinLM(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        pooled = self.modern_protein.get_sequence_embedding(input_ids, attention_mask)
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            if self.config.num_labels == 1:
                # Regression head
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.squeeze(-1), labels.float())
            else:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits, labels)

        return SequenceClassifierOutput(loss=loss, logits=logits)


class ModernProteinLMForTokenClassification(PreTrainedModel):
    config_class = ModernProteinLMConfig

    def __init__(self, config: ModernProteinLMConfig):
        super().__init__(config)
        self.modern_protein = ModernProteinLM(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        # Request hidden states so each residue can be classified from the
        # final-layer representation.
        outputs = self.modern_protein(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        logits = self.classifier(outputs.hidden_states[-1])

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        return TokenClassifierOutput(loss=loss, logits=logits)
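

# Illustrative usage sketch, not part of the library API: builds a small random
# batch and runs the masked-LM and classification heads end to end. The token
# ids below simply mirror the config defaults (pad=1, mask=32); a real run
# would use an actual protein tokenizer.
if __name__ == "__main__":
    config = ModernProteinLMConfig(num_hidden_layers=2)  # tiny depth for a quick smoke test
    model = ModernProteinLMForMaskedLM(config)

    input_ids = torch.randint(4, 24, (2, 64))             # fake residue tokens
    attention_mask = torch.ones_like(input_ids)
    labels = input_ids.clone()
    input_ids[:, ::7] = config.mask_token_id               # mask roughly 15% of positions
    labels[input_ids != config.mask_token_id] = -100       # score only the masked positions

    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    print("MLM loss:", float(out.loss), "logits:", tuple(out.logits.shape))

    clf = ModernProteinLMForSequenceClassification(
        ModernProteinLMConfig(num_hidden_layers=2, num_labels=3)
    )
    clf_out = clf(input_ids=input_ids, attention_mask=attention_mask, labels=torch.tensor([0, 2]))
    print("Classification loss:", float(clf_out.loss), "logits:", tuple(clf_out.logits.shape))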