| """ |
| ModernProteinLM: A next-generation protein encoder combining: |
| - ModernBERT architectural improvements (RoPE, Pre-LN, GeGLU, FlashAttention-compatible) |
| - ELECTRA-style discriminative pre-training |
| - Deep & narrow design optimal for protein sequences |
| - Curriculum masking (30% -> 5%) |
| - Span masking for protein structural motifs |
| |
| Architecture goals (~150M params): |
| - 28 layers, hidden 576, heads 9, intermediate 2304 (GeGLU) |
| - RoPE position embeddings (no absolute PE) |
| - Pre-LayerNorm with extra LN after embedding |
| - No dropout (following ESM-2) |
| - Tied input/output embeddings |
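
Example (illustrative sketch, not a tested entry point; the token ids below are
placeholders for an ESM-style protein tokenizer matching the vocabulary above):

    import torch
    config = ModernProteinLMConfig()
    model = ModernProteinLM(config)
    ids = torch.tensor([[0, 5, 6, 7, 2]])           # <cls>, three residue tokens, <eos>
    mask = torch.ones_like(ids)
    emb = model.get_sequence_embedding(ids, mask)   # (1, hidden_size), mean-pooled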
| """ |
|
|
| import math |
| from typing import Optional, Tuple, Union |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel, PretrainedConfig |
| from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput |
|
|
|
|
class ModernProteinLMConfig(PretrainedConfig):
    model_type = "modern_protein_lm"

    def __init__(
        self,
        vocab_size=33,
        hidden_size=576,
        num_hidden_layers=28,
        num_attention_heads=9,
        intermediate_size=2304,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        max_position_embeddings=1026,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        position_embedding_type="rotary",
        rope_theta=10000.0,
        use_geglu=True,
        tie_word_embeddings=True,
        pad_token_id=1,
        mask_token_id=32,
        cls_token_id=0,
        eos_token_id=2,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            mask_token_id=mask_token_id,
            cls_token_id=cls_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.rope_theta = rope_theta
        self.use_geglu = use_geglu
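

# Rough parameter budget with the defaults above (illustrative arithmetic):
#   per layer: QKV (576*1728) + attention out (576*576) + GeGLU (3*576*2304) ~ 5.3M
#   28 layers ~ 149M, plus the 33x576 embedding table shared with the tied LM head.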


class RotaryEmbedding(nn.Module):
    """RoPE (Rotary Position Embedding) for protein sequences."""

    def __init__(self, dim, max_seq_len=1026, base=10000.0, device=None):
        super().__init__()
        self.dim = dim
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.max_seq_len = max_seq_len

    def forward(self, seq_len, device):
        # Angles for each (position, frequency) pair, duplicated so cos/sin span the full head_dim.
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos().to(torch.float32), emb.sin().to(torch.float32)


def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
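

# Note: applying these rotations makes the attention score between two positions depend on
# their indices only through the relative offset, which is what gives RoPE its
# relative-position behaviour without any absolute position embeddings.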


class ModernProteinAttention(nn.Module):
    """Multi-head attention with RoPE and optional FlashAttention."""

    def __init__(self, config: ModernProteinLMConfig):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.scale = self.head_dim ** -0.5

        self.qkv_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=False)
        self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

        self.rotary_emb = RotaryEmbedding(self.head_dim, max_seq_len=config.max_position_embeddings, base=config.rope_theta)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob) if config.attention_probs_dropout_prob > 0 else None

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        batch_size, seq_len, _ = hidden_states.shape

        # Fused QKV projection; q, k, v: (batch, num_heads, seq_len, head_dim).
        qkv = self.qkv_proj(hidden_states)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Rotary position embeddings, broadcast over the batch and head dimensions.
        cos, sin = self.rotary_emb(seq_len, device=hidden_states.device)
        cos = cos[None, None, :, :]
        sin = sin[None, None, :, :]
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # FlashAttention fast path: only usable without an additive mask, in half
        # precision, and when attention probabilities are not requested (the flash
        # kernel does not return them). Otherwise use standard attention below.
        attn_probs = None
        use_flash = (
            attention_mask is None
            and not output_attentions
            and q.dtype in (torch.float16, torch.bfloat16)
        )
        if use_flash:
            try:
                from flash_attn import flash_attn_func
            except ImportError:
                use_flash = False

        if use_flash:
            # flash_attn_func expects (batch, seq_len, num_heads, head_dim).
            attn_output = flash_attn_func(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
                dropout_p=self.dropout.p if (self.dropout is not None and self.training) else 0.0,
                causal=False,
            )
            attn_output = attn_output.transpose(1, 2)
        else:
            # Standard scaled dot-product attention; scores: (batch, heads, seq_len, seq_len).
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

            if attention_mask is not None:
                attn_scores = attn_scores + attention_mask

            attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
            if self.dropout is not None:
                attn_probs = self.dropout(attn_probs)

            attn_output = torch.matmul(attn_probs, v)

        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, -1)
        attn_output = self.out_proj(attn_output)

        if output_attentions:
            return attn_output, attn_probs
        return attn_output, None


class GeGLU(nn.Module):
    """GeGLU activation: GELU(gate) * value. More expressive than GELU alone."""

    def __init__(self, config: ModernProteinLMConfig):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        self.act = nn.GELU()

    def forward(self, x):
        gate = self.act(self.gate_proj(x))
        up = self.up_proj(x)
        return self.down_proj(gate * up)


class ModernProteinMLP(nn.Module):
    def __init__(self, config: ModernProteinLMConfig):
        super().__init__()
        if config.use_geglu:
            self.mlp = GeGLU(config)
        else:
            self.mlp = nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
                nn.GELU(),
                nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
            )

    def forward(self, x):
        return self.mlp(x)


class ModernProteinLayer(nn.Module):
    """Pre-LN transformer layer: LayerNorm before each sub-block, residual connections after."""

    def __init__(self, config: ModernProteinLMConfig):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.attn = ModernProteinAttention(config)
        self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = ModernProteinMLP(config)

    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        # Attention sub-block: Pre-LN, then residual add.
        attn_out, attn_weights = self.attn(self.ln1(hidden_states), attention_mask, output_attentions)
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block: Pre-LN, then residual add.
        mlp_out = self.mlp(self.ln2(hidden_states))
        hidden_states = hidden_states + mlp_out

        return hidden_states, attn_weights


class ModernProteinLM(PreTrainedModel):
    config_class = ModernProteinLMConfig

    def __init__(self, config: ModernProteinLMConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.embed_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.layers = nn.ModuleList([
            ModernProteinLayer(config) for _ in range(config.num_hidden_layers)
        ])

        self.final_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        # LM head (bias-free); its weight is tied to the input embeddings when configured.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embeddings.weight

        # Standard HF initialization hook; applies _init_weights to every submodule and
        # must run after all submodules (including the LM head) have been created.
        self.post_init()

    def _init_weights(self, module):
        # Per-module initialization used by post_init() and from_pretrained().
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        labels=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        batch_size, seq_len = input_ids.shape

        # Token embeddings followed by an extra LayerNorm (ModernBERT-style); positions
        # are handled by RoPE inside each attention layer, so position_ids is unused.
        hidden_states = self.embeddings(input_ids)
        hidden_states = self.embed_ln(hidden_states)

        if attention_mask is not None:
            # Convert the 1/0 padding mask into an additive bias of shape
            # (batch, 1, 1, seq_len): 0 where attended, -10000 at padding positions.
            attention_mask = (1.0 - attention_mask[:, None, None, :]) * -10000.0

        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # Pre-LN transformer stack.
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            hidden_states, attn_weights = layer(hidden_states, attention_mask, output_attentions)

            if output_attentions:
                all_attentions += (attn_weights,)

        hidden_states = self.final_ln(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # Project to vocabulary logits (weight shared with the input embeddings when tied).
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # CrossEntropyLoss ignores label value -100 (the usual convention for unmasked positions).
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,)
            if output_hidden_states:
                output += (all_hidden_states,)
            if output_attentions:
                output += (all_attentions,)
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )

    def get_sequence_embedding(self, input_ids, attention_mask=None):
        """Return a per-sequence embedding: mean-pooled over non-padding positions when an
        attention mask is given, otherwise the first-token (CLS) embedding."""
        outputs = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        hidden = outputs.hidden_states[-1]

        if attention_mask is not None:
            # Mean-pool over non-padding positions.
            mask_expanded = attention_mask.unsqueeze(-1).float()
            sum_hidden = (hidden * mask_expanded).sum(dim=1)
            pooled = sum_hidden / mask_expanded.sum(dim=1).clamp(min=1e-9)
        else:
            # Without a mask, fall back to the CLS (first-token) embedding.
            pooled = hidden[:, 0]

        return pooled


class ModernProteinLMForMaskedLM(ModernProteinLM):
    """Masked LM wrapper; the base model already provides the tied LM head and MLM loss."""
    pass


class ModernProteinLMForSequenceClassification(PreTrainedModel):
    config_class = ModernProteinLMConfig

    def __init__(self, config: ModernProteinLMConfig):
        super().__init__(config)
        self.modern_protein = ModernProteinLM(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        pooled = self.modern_protein.get_sequence_embedding(input_ids, attention_mask)
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            if self.config.num_labels == 1:
                # num_labels == 1 is treated as regression.
                loss_fct = nn.MSELoss()
                loss = loss_fct(logits.squeeze(-1), labels.float())
            else:
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits, labels)

        return SequenceClassifierOutput(loss=loss, logits=logits)


class ModernProteinLMForTokenClassification(PreTrainedModel):
    config_class = ModernProteinLMConfig

    def __init__(self, config: ModernProteinLMConfig):
        super().__init__(config)
        self.modern_protein = ModernProteinLM(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None):
        outputs = self.modern_protein(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,  # required: hidden_states is None unless requested
            return_dict=True,
        )
        # Per-residue logits from the final (post-LayerNorm) hidden states.
        logits = self.classifier(outputs.hidden_states[-1])

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        return TokenClassifierOutput(loss=loss, logits=logits)
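

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch, not part of the training pipeline).
# The token ids are random placeholders; real inputs would come from an
# ESM-style protein tokenizer matching the vocabulary in the config.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = ModernProteinLMConfig(num_hidden_layers=2)  # tiny stack for a quick CPU check
    model = ModernProteinLMForMaskedLM(config)

    input_ids = torch.randint(4, 24, (2, 16))   # fake residue tokens inside the vocab range
    attention_mask = torch.ones_like(input_ids)
    labels = torch.full_like(input_ids, -100)   # -100 = ignored by the MLM loss
    labels[:, ::4] = input_ids[:, ::4]          # supervise ~25% of positions, as masking would

    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    print("logits:", tuple(out.logits.shape), "loss:", float(out.loss))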