# ModernProteinLM / modeling_modern_protein.py
"""
ModernProteinLM: A next-generation protein encoder combining:
- ModernBERT architectural improvements (RoPE, Pre-LN, GeGLU, FlashAttention-compatible)
- ELECTRA-style discriminative pre-training
- Deep & narrow design suited to protein sequences
- Curriculum masking (30% -> 5%)
- Span masking for protein structural motifs
Architecture goals (~150M params):
- 28 layers, hidden 576, heads 9, intermediate 2304 (GeGLU)
- RoPE position embeddings (no absolute PE)
- Pre-LayerNorm with extra LN after embedding
- No dropout (following ESM-2)
- Tied input/output embeddings
"""
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
class ModernProteinLMConfig(PretrainedConfig):
model_type = "modern_protein_lm"
def __init__(
self,
vocab_size=33,
hidden_size=576,
num_hidden_layers=28,
num_attention_heads=9,
intermediate_size=2304,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=1026,
initializer_range=0.02,
layer_norm_eps=1e-12,
position_embedding_type="rotary",
rope_theta=10000.0,
use_geglu=True,
tie_word_embeddings=True,
pad_token_id=1,
mask_token_id=32,
cls_token_id=0,
eos_token_id=2,
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
mask_token_id=mask_token_id,
cls_token_id=cls_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self.rope_theta = rope_theta
self.use_geglu = use_geglu
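# Rough parameter budget with the defaults above: per layer, attention holds
# 4 * 576^2 ≈ 1.33M weights and the GeGLU MLP 3 * 576 * 2304 ≈ 3.98M, i.e. ~5.31M
# per layer and ~149M over 28 layers, consistent with the ~150M target stated in
# the module docstring.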
class RotaryEmbedding(nn.Module):
"""RoPE (Rotary Position Embedding) for protein sequences."""
def __init__(self, dim, max_seq_len=1026, base=10000.0, device=None):
super().__init__()
self.dim = dim
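        # Inverse frequencies base^(-2i/dim), one per pair of rotary dimensions:
        # low dimensions rotate quickly with position, high dimensions slowly.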
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.max_seq_len = max_seq_len
def forward(self, seq_len, device):
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos().to(torch.float32), emb.sin().to(torch.float32)
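# RoPE helpers: rotating each (q, k) dimension pair by a position-dependent angle
# makes the attention dot product depend only on relative positions.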
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class ModernProteinAttention(nn.Module):
"""Multi-head attention with RoPE and optional FlashAttention."""
def __init__(self, config: ModernProteinLMConfig):
super().__init__()
self.num_heads = config.num_attention_heads
self.head_dim = config.hidden_size // config.num_attention_heads
self.scale = self.head_dim ** -0.5
self.qkv_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=False)
self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
self.rotary_emb = RotaryEmbedding(self.head_dim, max_seq_len=config.max_position_embeddings, base=config.rope_theta)
self.dropout = nn.Dropout(config.attention_probs_dropout_prob) if config.attention_probs_dropout_prob > 0 else None
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
batch_size, seq_len, _ = hidden_states.shape
qkv = self.qkv_proj(hidden_states)
qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, H, T, D)
q, k, v = qkv[0], qkv[1], qkv[2]
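        # q, k, v each have shape (B, num_heads, T, head_dim)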
# Apply RoPE
cos, sin = self.rotary_emb(seq_len, device=hidden_states.device)
        cos = cos[None, None, :, :].to(q.dtype)  # (1, 1, T, D); match the query dtype so half precision is preserved
        sin = sin[None, None, :, :].to(q.dtype)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
        attn_probs = None
        # Try FlashAttention if available (only usable without a padding mask, in half
        # precision, and when attention weights are not requested)
        try:
            from flash_attn import flash_attn_func
            if attention_mask is None and not output_attentions and q.dtype in (torch.float16, torch.bfloat16):
                # flash_attn_func expects (B, T, H, D); only apply dropout while training
                dropout_p = self.dropout.p if (self.dropout is not None and self.training) else 0.0
                attn_output = flash_attn_func(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
                                              dropout_p=dropout_p,
                                              causal=False)
attn_output = attn_output.transpose(1, 2)
else:
raise ImportError("Fallback to standard attention")
except (ImportError, AttributeError):
# Standard scaled dot-product attention
attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
if attention_mask is not None:
attn_scores = attn_scores + attention_mask
attn_probs = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
if self.dropout is not None:
attn_probs = self.dropout(attn_probs)
attn_output = torch.matmul(attn_probs, v)
attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, -1)
attn_output = self.out_proj(attn_output)
if output_attentions:
return attn_output, attn_probs
return attn_output, None
class GeGLU(nn.Module):
"""GeGLU activation: GELU(gate) * value. More expressive than GELU alone."""
def __init__(self, config: ModernProteinLMConfig):
super().__init__()
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
self.act = nn.GELU()
def forward(self, x):
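        # GeGLU: down_proj( GELU(gate_proj(x)) * up_proj(x) )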
gate = self.act(self.gate_proj(x))
up = self.up_proj(x)
return self.down_proj(gate * up)
class ModernProteinMLP(nn.Module):
def __init__(self, config: ModernProteinLMConfig):
super().__init__()
if config.use_geglu:
self.mlp = GeGLU(config)
else:
self.mlp = nn.Sequential(
nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
nn.GELU(),
nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
)
def forward(self, x):
return self.mlp(x)
class ModernProteinLayer(nn.Module):
"""Pre-LN transformer layer with optional parallel formulation."""
def __init__(self, config: ModernProteinLMConfig):
super().__init__()
self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attn = ModernProteinAttention(config)
self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = ModernProteinMLP(config)
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
# Pre-LN: LN -> Attn -> Residual
attn_out, attn_weights = self.attn(self.ln1(hidden_states), attention_mask, output_attentions)
hidden_states = hidden_states + attn_out
# Pre-LN: LN -> MLP -> Residual
mlp_out = self.mlp(self.ln2(hidden_states))
hidden_states = hidden_states + mlp_out
return hidden_states, attn_weights
class ModernProteinLM(PreTrainedModel):
config_class = ModernProteinLMConfig
def __init__(self, config: ModernProteinLMConfig):
super().__init__(config)
self.config = config
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.embed_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layers = nn.ModuleList([
ModernProteinLayer(config) for _ in range(config.num_hidden_layers)
])
self.final_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize all weights, then tie input/output embeddings if requested
        self.apply(self._init_weights)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embeddings.weight
    def _init_weights(self, module):
        """Per-module init using the standard Hugging Face `_init_weights(self, module)` signature."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)
def get_input_embeddings(self):
return self.embeddings
def set_input_embeddings(self, value):
self.embeddings = value
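    def get_output_embeddings(self):
        # Exposing the LM head lets the standard HF tie_weights / resize_token_embeddings
        # machinery handle the tied output projection.
        return self.lm_head
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings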
def forward(
self,
input_ids,
attention_mask=None,
position_ids=None,
labels=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
batch_size, seq_len = input_ids.shape
# Embedding
hidden_states = self.embeddings(input_ids)
hidden_states = self.embed_ln(hidden_states)
# Attention mask for padding
if attention_mask is not None:
# (B, T) -> (B, 1, 1, T) for broadcasting
            attention_mask = (1.0 - attention_mask[:, None, None, :].to(hidden_states.dtype)) * -10000.0
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
# Transformer layers
for layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
hidden_states, attn_weights = layer(hidden_states, attention_mask, output_attentions)
if output_attentions:
all_attentions += (attn_weights,)
hidden_states = self.final_ln(hidden_states)
if output_hidden_states:
all_hidden_states += (hidden_states,)
# LM head
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
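            # Positions labeled -100 (CrossEntropyLoss's default ignore_index) are skipped,
            # so only masked tokens contribute to the MLM loss.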
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (logits,)
if output_hidden_states:
output += (all_hidden_states,)
if output_attentions:
output += (all_attentions,)
return ((loss,) + output) if loss is not None else output
return MaskedLMOutput(
loss=loss,
logits=logits,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
def get_sequence_embedding(self, input_ids, attention_mask=None):
"""Extract CLS or mean-pooled embedding for downstream tasks."""
outputs = self.forward(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
return_dict=True,
)
hidden = outputs.hidden_states[-1]
if attention_mask is not None:
# Mean pool over non-padded positions
mask_expanded = attention_mask.unsqueeze(-1).float()
sum_hidden = (hidden * mask_expanded).sum(dim=1)
pooled = sum_hidden / mask_expanded.sum(dim=1).clamp(min=1e-9)
else:
pooled = hidden[:, 0] # CLS token
return pooled
class ModernProteinLMForMaskedLM(ModernProteinLM):
"""Masked Language Model wrapper."""
pass
class ModernProteinLMForSequenceClassification(PreTrainedModel):
config_class = ModernProteinLMConfig
def __init__(self, config: ModernProteinLMConfig):
super().__init__(config)
self.modern_protein = ModernProteinLM(config)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, input_ids, attention_mask=None, labels=None):
pooled = self.modern_protein.get_sequence_embedding(input_ids, attention_mask)
logits = self.classifier(pooled)
loss = None
if labels is not None:
if self.config.num_labels == 1:
loss_fct = nn.MSELoss()
loss = loss_fct(logits.squeeze(-1), labels.float())
else:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits, labels)
return SequenceClassifierOutput(loss=loss, logits=logits)
class ModernProteinLMForTokenClassification(PreTrainedModel):
config_class = ModernProteinLMConfig
def __init__(self, config: ModernProteinLMConfig):
super().__init__(config)
self.modern_protein = ModernProteinLM(config)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, input_ids, attention_mask=None, labels=None):
        outputs = self.modern_protein(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,  # required: hidden_states is None unless explicitly requested
            return_dict=True,
        )
logits = self.classifier(outputs.hidden_states[-1])
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
return TokenClassifierOutput(loss=loss, logits=logits)
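if __name__ == "__main__":
    # Minimal smoke-test sketch: instantiate a shrunken config, run a forward pass on
    # random token ids, and print the MLM logit and pooled-embedding shapes. The random
    # ids below are placeholders only and do not correspond to a real amino-acid tokenizer.
    config = ModernProteinLMConfig(num_hidden_layers=2)
    model = ModernProteinLMForMaskedLM(config)
    model.eval()
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)
    print(out.logits.shape)  # expected: (2, 16, vocab_size)
    emb = model.get_sequence_embedding(input_ids, attention_mask)
    print(emb.shape)  # expected: (2, hidden_size)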