# Hugging Face Hub artifact: copep-checkpoints / modeling_amplify.py
# (re-added after history purge; commit 8ba58fb)
"""Standalone AMPLIFY model for HuggingFace Hub (trust_remote_code=True).
This is a self-contained file that can be shipped in a HuggingFace repo so that
``AutoModel.from_pretrained(..., trust_remote_code=True)`` works without
installing the ``amplify`` package.
Based on: https://github.com/chandar-lab/AMPLIFY
"""
from typing import Tuple
import torch
from torch import nn
from torch.nn.functional import scaled_dot_product_attention
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import MaskedLMOutput
# Optional: flash attention for packed-sequence training. Not required for
# standard inference.
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func # type: ignore
except ImportError:
flash_attn_varlen_func = None
# ---------------------------------------------------------------------------
# Rotary positional embeddings (inlined from amplify.model.rotary)
# ---------------------------------------------------------------------------
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the complex rotary-embedding phase table.

    Returns a complex64 tensor of shape ``(end, dim // 2)`` whose entry
    ``(t, i)`` is ``exp(1j * t * theta**(-2i / dim))`` — a unit-magnitude
    phase for each (position, frequency) pair.
    """
    half = dim // 2
    # Per-pair inverse frequencies: theta^(-0/dim), theta^(-2/dim), ...
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2)[:half].float() / dim))
    positions = torch.arange(end, device=inv_freq.device, dtype=torch.float32)
    # Outer product via broadcasting: (end, 1) * (1, half) -> (end, half).
    angles = positions[:, None] * inv_freq[None, :]
    return torch.polar(torch.ones_like(angles), angles)  # complex64
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Insert a singleton head axis so ``freqs_cis`` broadcasts against ``x``.

    ``x`` is a 4-D (batch, seq, heads, d) tensor; ``freqs_cis`` must already
    match its batch, sequence and last dimensions exactly.
    """
    expected = (x.shape[0], x.shape[1], x.shape[-1])
    assert freqs_cis.shape == expected
    # (batch, seq, d) -> (batch, seq, 1, d): broadcast across heads.
    return freqs_cis.contiguous().unsqueeze(2)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query/key vectors by the complex phases in ``freqs_cis``.

    Adjacent pairs along the last axis are reinterpreted as complex numbers,
    multiplied by the (broadcast) unit phases, then flattened back to real.
    The inputs' original dtypes are preserved on return.
    """

    def _to_complex(t: torch.Tensor) -> torch.Tensor:
        # (..., d) real -> (..., d // 2) complex, pairing consecutive values.
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_c = _to_complex(xq)
    k_c = _to_complex(xk)
    # Add a singleton head axis so the phases broadcast over heads.
    assert freqs_cis.shape == (q_c.shape[0], q_c.shape[1], q_c.shape[-1])
    phases = freqs_cis.contiguous().unsqueeze(2)
    rotated_q = torch.view_as_real(q_c * phases).flatten(3)
    rotated_k = torch.view_as_real(k_c * phases).flatten(3)
    return rotated_q.type_as(xq), rotated_k.type_as(xk)
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
class AMPLIFYConfig(PretrainedConfig):
    """Configuration for the AMPLIFY protein language model.

    Defaults correspond to the 350M-parameter AMPLIFY variant.
    """

    model_type = "AMPLIFY"

    def __init__(
        self,
        hidden_size: int = 960,
        num_hidden_layers: int = 32,
        num_attention_heads: int = 15,
        intermediate_size: int = 3840,
        embedding_init_range: float = 0.02,
        decoder_init_range: float = 0.02,
        norm_eps: float = 1e-05,
        vocab_size: int = 32,
        pad_token_id: int = 0,
        max_length: int = 2048,
        max_protein_length: int = 50000,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Transformer geometry.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Weight-initialization ranges (uniform) and normalization epsilon.
        self.embedding_init_range = embedding_init_range
        self.decoder_init_range = decoder_init_range
        self.norm_eps = norm_eps
        # Vocabulary and sequence-length limits.
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id
        self.max_length = max_length
        self.max_protein_length = max_protein_length
# ---------------------------------------------------------------------------
# Encoder blocks
# ---------------------------------------------------------------------------
class EncoderBlock(nn.Module):
"""Standard transformer encoder block with SwiGLU FFN and RoPE."""
def __init__(self, config: AMPLIFYConfig):
super().__init__()
self.config = config
self.d_head = config.hidden_size // config.num_attention_heads
# Attention
self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=False)
self.wo = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
# SwiGLU FFN
multiple_of = 8
intermediate_size = multiple_of * (
(int(2 * config.intermediate_size / 3) + multiple_of - 1) // multiple_of
)
self.c_fc = nn.Linear(config.hidden_size, 2 * intermediate_size, bias=False)
self.silu = nn.SiLU()
self.mlp_c_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)
self.attention_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
self.ffn_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
def forward(
self,
x: torch.Tensor,
attention_mask: torch.Tensor,
freqs_cis: torch.Tensor,
output_attentions: bool,
max_seqlen: int = None,
cu_seqlens: torch.Tensor = None,
):
batch_size, seq_len, _ = x.shape
xq, xk, xv = (
self.qkv(self.attention_norm(x))
.reshape(batch_size, seq_len, self.config.num_attention_heads, self.d_head * 3)
.chunk(3, axis=-1)
)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
attn_weights = None
if cu_seqlens is not None:
assert flash_attn_varlen_func is not None, (
"flash_attn is required for packed-sequence attention. "
"Install with: pip install flash-attn"
)
attn = flash_attn_varlen_func(
q=xq.squeeze(0),
k=xk.squeeze(0),
v=xv.squeeze(0),
cu_seqlens_q=cu_seqlens.squeeze(),
cu_seqlens_k=cu_seqlens.squeeze(),
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0.0,
causal=False,
)
elif output_attentions:
attn_weights = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
if attention_mask is not None:
attn_weights = attn_weights * attention_mask
attn_weights = attn_weights.softmax(-1)
attn = attn_weights @ xv.permute(0, 2, 1, 3)
attn = attn.transpose(1, 2)
else:
attn = scaled_dot_product_attention(
query=xq.transpose(1, 2),
key=xk.transpose(1, 2),
value=xv.transpose(1, 2),
attn_mask=attention_mask.bool() if attention_mask is not None else None,
dropout_p=0,
).transpose(1, 2)
attn = self.wo(attn.reshape(batch_size, seq_len, self.config.num_attention_heads * self.d_head))
x = x + attn
uv = self.c_fc(self.ffn_norm(x))
u, v = torch.chunk(uv, 2, dim=-1)
x_mlp = u * self.silu(v)
h_mlp = self.mlp_c_proj(x_mlp)
x = x + h_mlp
return x, attn_weights
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class AMPLIFYPreTrainedModel(PreTrainedModel):
    """Hooks AMPLIFYConfig into the HF ``PreTrainedModel`` machinery."""

    config_class = AMPLIFYConfig

    def _init_weights(self, module):
        """Initialize linear/embedding weights uniformly.

        The half-range comes from the config: ``embedding_init_range`` for
        embeddings, ``decoder_init_range`` for linear layers. Other module
        types are left untouched.
        """
        if isinstance(module, nn.Embedding):
            bound = self.config.embedding_init_range
        elif isinstance(module, nn.Linear):
            bound = self.config.decoder_init_range
        else:
            return
        module.weight.data.uniform_(-bound, bound)
class AMPLIFY(AMPLIFYPreTrainedModel):
    """AMPLIFY protein language model.
    A transformer encoder for protein sequences using RoPE and SwiGLU,
    trained with masked language modelling.
    """
    def __init__(self, config: AMPLIFYConfig, **kwargs):
        super().__init__(config)
        self.config = config
        # Token embedding table; the pad row is zeroed via padding_idx.
        self.encoder = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )
        self.transformer_encoder = nn.ModuleList()
        for _ in range(config.num_hidden_layers):
            self.transformer_encoder.append(EncoderBlock(config))
        self.layer_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
        # MLM head projecting back to the vocabulary.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
        # RoPE phase table sized for 2x max_protein_length positions.
        # persistent=False: recomputed here rather than stored in the
        # checkpoint, but still moved along by .to()/.cuda().
        freqs_cis = precompute_freqs_cis(
            config.hidden_size // config.num_attention_heads,
            config.max_protein_length * 2,
        )
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
        self.post_init()
    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor = None,
        max_seqlen: int = None,
        cu_seqlens: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ):
        """Run the encoder stack and return MLM logits.

        Args:
            input_ids: (batch, seq) token ids; with packing, batch must be 1.
            position_ids: optional explicit positions used to index the RoPE
                table (defaults to 0..seq-1 for every batch row).
            max_seqlen / cu_seqlens: flash-attention packing metadata; when
                cu_seqlens is given, output_attentions must be False and the
                input must be on GPU.
            attention_mask: optional per-sequence mask. NOTE(review): the
                blocks downstream reduce it with .bool() / multiply by it,
                which implies a 1-attend / 0-masked convention — confirm
                against callers.
            output_hidden_states: False collects nothing; an int N collects
                outputs of layers with index >= N. NOTE(review): True is an
                instance of int and therefore acts as N=1 (layer 0's output
                is skipped, and the final `else` below is unreachable for
                bools) — confirm this is intended.
            output_attentions: also collect per-layer attention probabilities.

        Returns:
            MaskedLMOutput with logits, the collected (pre-final-norm)
            hidden_states, and attentions — both as plain lists.
        """
        hidden_states, attentions = [], []
        # Translate output_hidden_states (bool-or-int) into the first layer
        # index whose output should be collected.
        if isinstance(output_hidden_states, bool) and not output_hidden_states:
            output_hidden_index = self.config.num_hidden_layers + 1
        elif isinstance(output_hidden_states, int):
            output_hidden_index = output_hidden_states
        else:
            output_hidden_index = 0
        # Expand (batch, seq) -> (batch, heads, seq, seq) for per-head use
        # inside the encoder blocks.
        if attention_mask is not None:
            attention_mask = (
                attention_mask.unsqueeze(1)
                .unsqueeze(1)
                .repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)
            )
        if cu_seqlens is not None:
            assert not output_attentions, "Output attentions is not supported when sequences are packed."
            assert max_seqlen is not None, "Missing max_seqlen. It must be provided when cu_seqlens are not None."
            assert input_ids.shape[0] == 1, "Cumulative sequence lengths are provided but input_ids are not packed."
            assert input_ids.is_cuda, "Packing uses flash-attention and is only supported on GPU."
        # RoPE
        if position_ids is not None:
            freqs_cis = self.freqs_cis[position_ids]
        else:
            # No explicit positions: take the first seq_len phases and repeat
            # them for every batch row.
            freqs_cis = (
                self.freqs_cis[: input_ids.shape[1]]
                .unsqueeze(0)
                .repeat(input_ids.shape[0], 1, 1)
            )
        x = self.encoder(input_ids)
        for idx, layer in enumerate(self.transformer_encoder):
            x, attn = layer(x, attention_mask, freqs_cis, output_attentions, max_seqlen, cu_seqlens)
            if idx >= output_hidden_index:
                hidden_states.append(x)
            if output_attentions:
                attentions.append(attn)
        # Final norm + vocabulary projection (note: collected hidden_states
        # are the pre-final-norm layer outputs).
        logits = self.decoder(self.layer_norm(x))
        return MaskedLMOutput(
            logits=logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )