| """Standalone AMPLIFY model for HuggingFace Hub (trust_remote_code=True). |
| |
| This is a self-contained file that can be shipped in a HuggingFace repo so that |
| ``AutoModel.from_pretrained(..., trust_remote_code=True)`` works without |
| installing the ``amplify`` package. |
| |
| Based on: https://github.com/chandar-lab/AMPLIFY |
| """ |
|
|
| from typing import Tuple |
|
|
| import torch |
| from torch import nn |
| from torch.nn.functional import scaled_dot_product_attention |
| from transformers import PreTrainedModel, PretrainedConfig |
| from transformers.modeling_outputs import MaskedLMOutput |
|
|
| |
| |
| try: |
| from flash_attn.flash_attn_interface import flash_attn_varlen_func |
| except ImportError: |
| flash_attn_varlen_func = None |
|
|
|
|
| |
| |
| |
|
|
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Build the complex rotary-embedding table.

    Entry ``[p, i]`` is the unit complex number ``exp(1j * p * w_i)`` with
    ``w_i = 1 / theta ** (2 * i / dim)``, for ``p`` in ``range(end)`` and
    ``i`` in ``range(dim // 2)``.

    Args:
        dim: per-head feature dimension; one frequency per pair of features.
        end: number of positions to tabulate.
        theta: RoPE base.

    Returns:
        A ``(end, dim // 2)`` complex tensor (complex64, built from float32 angles).
    """
    pair_exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inverse_frequencies = 1.0 / (theta ** pair_exponents)
    positions = torch.arange(end, device=inverse_frequencies.device, dtype=torch.float32)
    angles = torch.outer(positions, inverse_frequencies)
    return torch.polar(torch.ones_like(angles), angles)
|
|
|
|
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Insert a singleton axis at position 2 so ``freqs_cis`` broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[0], x.shape[1], x.shape[-1])``;
    otherwise an ``AssertionError`` is raised. The result is a contiguous copy or
    view of shape ``(x.shape[0], x.shape[1], 1, x.shape[-1])``.
    """
    expected_shape = (x.shape[0], x.shape[1], x.shape[-1])
    assert freqs_cis.shape == expected_shape
    # The assert guarantees a 3-D tensor, so this indexing equals unsqueeze(2).
    return freqs_cis.contiguous()[:, :, None, :]
|
|
|
|
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors with rotary position embeddings.

    Consecutive pairs along the last dimension of ``xq``/``xk`` are viewed as the
    real and imaginary parts of complex numbers. These are multiplied by
    ``freqs_cis`` (broadcast over axis 2) and unpacked back to real pairs. Inputs
    are upcast with ``.float()`` for the rotation, and results are cast back to the
    original dtypes.

    Args:
        xq: query tensor, assumed (batch, seq, heads, head_dim) with an even
            head_dim (TODO confirm for all callers).
        xk: key tensor with the same layout as ``xq``.
        freqs_cis: complex rotary factors; ``reshape_for_broadcast`` requires shape
            ``(xq.shape[0], xq.shape[1], head_dim // 2)``.

    Returns:
        ``(xq_rotated, xk_rotated)`` with the dtypes of ``xq`` and ``xk``.
    """
    q_complex, k_complex = (
        torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2)) for t in (xq, xk)
    )
    rotation = reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
|
|
|
|
| |
| |
| |
|
|
class AMPLIFYConfig(PretrainedConfig):
    """Configuration for the AMPLIFY protein language model.

    Args:
        hidden_size: width of the token embeddings and residual stream.
        num_hidden_layers: number of encoder blocks.
        num_attention_heads: number of attention heads; the per-head width is
            ``hidden_size // num_attention_heads``.
        intermediate_size: nominal feed-forward width; the encoder blocks derive
            their SwiGLU width from it.
        embedding_init_range: half-width of the uniform init for embedding weights.
        decoder_init_range: half-width of the uniform init for linear-layer weights.
        norm_eps: epsilon used by the RMSNorm layers.
        vocab_size: number of token ids.
        pad_token_id: id of the padding token.
        max_length: stored on the config; not read by the model code shown here.
        max_protein_length: the model precomputes rotary factors for twice this
            many positions.
        **kwargs: forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "AMPLIFY"

    def __init__(
        self,
        hidden_size: int = 960,
        num_hidden_layers: int = 32,
        num_attention_heads: int = 15,
        intermediate_size: int = 3840,
        embedding_init_range: float = 0.02,
        decoder_init_range: float = 0.02,
        norm_eps: float = 1e-05,
        vocab_size: int = 32,
        pad_token_id: int = 0,
        max_length: int = 2048,
        max_protein_length: int = 50000,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Transformer geometry.
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size

        # Weight-initialisation ranges.
        self.embedding_init_range = embedding_init_range
        self.decoder_init_range = decoder_init_range

        # Normalisation and vocabulary.
        self.norm_eps = norm_eps
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id

        # Sequence-length limits.
        self.max_length = max_length
        self.max_protein_length = max_protein_length
|
|
|
|
| |
| |
| |
|
|
class EncoderBlock(nn.Module):
    """Standard transformer encoder block with SwiGLU FFN and RoPE.

    Pre-norm layout: ``x + Attn(RMSNorm(x))`` followed by ``x + FFN(RMSNorm(x))``.
    """

    def __init__(self, config: AMPLIFYConfig):
        super().__init__()
        # Fail early with a clear message; otherwise the per-head reshape in
        # forward() fails with an opaque size-mismatch error.
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({config.hidden_size}) must be divisible by "
                f"num_attention_heads ({config.num_attention_heads})."
            )
        self.config = config
        self.d_head = config.hidden_size // config.num_attention_heads

        # Fused Q/K/V projection. forward() splits its output per head as
        # [q | k | v], so this weight layout must match existing checkpoints.
        self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=False)
        self.wo = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

        # SwiGLU width: 2/3 of intermediate_size, rounded up to a multiple of 8.
        multiple_of = 8
        intermediate_size = multiple_of * (
            (int(2 * config.intermediate_size / 3) + multiple_of - 1) // multiple_of
        )
        self.c_fc = nn.Linear(config.hidden_size, 2 * intermediate_size, bias=False)
        self.silu = nn.SiLU()
        self.mlp_c_proj = nn.Linear(intermediate_size, config.hidden_size, bias=False)

        self.attention_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
        self.ffn_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        output_attentions: bool,
        max_seqlen: int = None,
        cu_seqlens: torch.Tensor = None,
    ):
        """Apply attention and the feed-forward network with residual connections.

        Args:
            x: hidden states of shape (batch, seq, hidden_size). For packed input
                (``cu_seqlens`` given), the batch axis is squeezed away, so it is
                presumably 1.
            attention_mask: optional mask broadcastable to (batch, heads, seq, seq).
                Non-zero/True entries are attended to, matching the boolean
                ``attn_mask`` semantics of ``scaled_dot_product_attention``.
                It is ignored on the packed path.
            freqs_cis: complex rotary factors passed to ``apply_rotary_emb``.
            output_attentions: if True (and not packed), compute attention
                explicitly and return the probabilities.
            max_seqlen: longest sequence in the pack (packed path only).
            cu_seqlens: cumulative sequence boundaries for flash-attn varlen
                attention; enables the packed path.

        Returns:
            ``(x, attn_weights)``. ``attn_weights`` is (batch, heads, seq, seq) on
            the explicit path and ``None`` otherwise.
        """
        batch_size, seq_len, _ = x.shape

        # (batch, seq, heads, 3 * d_head) -> three (batch, seq, heads, d_head) tensors.
        xq, xk, xv = (
            self.qkv(self.attention_norm(x))
            .reshape(batch_size, seq_len, self.config.num_attention_heads, self.d_head * 3)
            .chunk(3, dim=-1)
        )
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis)

        attn_weights = None

        if cu_seqlens is not None:
            assert flash_attn_varlen_func is not None, (
                "flash_attn is required for packed-sequence attention. "
                "Install with: pip install flash-attn"
            )
            attn = flash_attn_varlen_func(
                q=xq.squeeze(0),
                k=xk.squeeze(0),
                v=xv.squeeze(0),
                cu_seqlens_q=cu_seqlens.squeeze(),
                cu_seqlens_k=cu_seqlens.squeeze(),
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0.0,
                causal=False,
            )
        elif output_attentions:
            attn_weights = xq.permute(0, 2, 1, 3) @ xk.permute(0, 2, 3, 1) / (xq.size(-1) ** 0.5)
            if attention_mask is not None:
                # Same semantics as the SDPA branch (True = attend): masked logits
                # are set to -inf so they get zero probability. Multiplying the
                # logits by the mask only set them to 0, so padded keys still
                # received attention.
                keep = attention_mask.bool()
                attn_weights = attn_weights.masked_fill(~keep, float("-inf")).softmax(-1)
                # A row with no attendable key is NaN after softmax; zero every
                # masked entry so such rows stay finite. Other rows are unchanged.
                attn_weights = attn_weights.masked_fill(~keep, 0.0)
            else:
                attn_weights = attn_weights.softmax(-1)
            attn = attn_weights @ xv.permute(0, 2, 1, 3)
            attn = attn.transpose(1, 2)
        else:
            attn = scaled_dot_product_attention(
                query=xq.transpose(1, 2),
                key=xk.transpose(1, 2),
                value=xv.transpose(1, 2),
                attn_mask=attention_mask.bool() if attention_mask is not None else None,
                dropout_p=0,
            ).transpose(1, 2)

        # Merge heads and project back to the residual stream.
        attn = self.wo(attn.reshape(batch_size, seq_len, self.config.num_attention_heads * self.d_head))
        x = x + attn

        # SwiGLU feed-forward: first half gated by SiLU of the second half.
        uv = self.c_fc(self.ffn_norm(x))
        u, v = torch.chunk(uv, 2, dim=-1)
        x_mlp = u * self.silu(v)
        h_mlp = self.mlp_c_proj(x_mlp)

        x = x + h_mlp
        return x, attn_weights
|
|
|
|
| |
| |
| |
|
|
class AMPLIFYPreTrainedModel(PreTrainedModel):
    """Base class binding ``AMPLIFYConfig`` and defining AMPLIFY's weight init."""

    config_class = AMPLIFYConfig

    def _init_weights(self, module):
        """Re-initialise ``module``'s weight in place.

        ``nn.Linear`` weights are drawn from
        U(-decoder_init_range, decoder_init_range). ``nn.Embedding`` weights are
        drawn from U(-embedding_init_range, embedding_init_range). Biases and all
        other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            half_width = self.config.decoder_init_range
        elif isinstance(module, nn.Embedding):
            half_width = self.config.embedding_init_range
        else:
            return
        module.weight.data.uniform_(-half_width, half_width)
|
|
|
|
class AMPLIFY(AMPLIFYPreTrainedModel):
    """AMPLIFY protein language model.

    A transformer encoder for protein sequences using RoPE and SwiGLU,
    trained with masked language modelling.
    """

    def __init__(self, config: AMPLIFYConfig, **kwargs):
        super().__init__(config)
        self.config = config

        # Token embeddings; the padding_idx row does not receive gradient updates.
        self.encoder = nn.Embedding(
            config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
        )

        # num_hidden_layers encoder blocks built from the same config.
        self.transformer_encoder = nn.ModuleList(
            [EncoderBlock(config) for _ in range(config.num_hidden_layers)]
        )

        # Final norm and masked-LM head projecting back to the vocabulary.
        self.layer_norm = nn.RMSNorm(config.hidden_size, config.norm_eps)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size)

        # Rotary table for 2 * max_protein_length positions. Non-persistent: it
        # follows the module across devices but is not stored in checkpoints.
        freqs_cis = precompute_freqs_cis(
            config.hidden_size // config.num_attention_heads,
            config.max_protein_length * 2,
        )
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)

        self.post_init()

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor = None,
        max_seqlen: int = None,
        cu_seqlens: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        output_hidden_states: bool = False,
        output_attentions: bool = False,
    ):
        """Encode ``input_ids`` and return masked-LM logits.

        Args:
            input_ids: token ids of shape (batch, seq). For packed sequences this
                must be (1, total_tokens).
            position_ids: optional indices into the rotary table, used instead of
                ``0..seq-1``.
            max_seqlen: longest sequence in the pack; required with ``cu_seqlens``.
            cu_seqlens: cumulative sequence boundaries; enables packed
                (flash-attn) attention. Requires a CUDA input and batch size 1.
            attention_mask: optional (batch, seq) mask, presumably 1 for tokens to
                attend and 0 for padding (TODO confirm). Each layer receives it
                broadcast to (batch, heads, seq, seq).
            output_hidden_states: ``False`` collects no hidden states. ``True``
                collects the output of every layer. An ``int`` k collects the
                outputs of layers k, k+1, ... (0-indexed). Any other value,
                including ``None``, collects every layer. The embedding output is
                never included.
            output_attentions: collect per-layer attention probabilities (not
                supported with packing).

        Returns:
            ``MaskedLMOutput`` with ``logits`` of shape (batch, seq, vocab_size)
            and ``hidden_states``/``attentions`` as (possibly empty) lists.
        """
        hidden_states, attentions = [], []

        # Index of the first layer whose output is collected. bool must be checked
        # before int because bool subclasses int. Previously True matched the int
        # branch (index 1) and silently dropped the first layer's output.
        if isinstance(output_hidden_states, bool):
            output_hidden_index = 0 if output_hidden_states else self.config.num_hidden_layers + 1
        elif isinstance(output_hidden_states, int):
            output_hidden_index = output_hidden_states
        else:
            # NOTE(review): None and other non-int values collect every layer;
            # kept for backward compatibility.
            output_hidden_index = 0

        if attention_mask is not None:
            # (batch, seq) -> (batch, heads, seq, seq). expand() returns a broadcast
            # view of the same shape instead of materialising a copy as repeat() did.
            seq_len = attention_mask.size(-1)
            attention_mask = (
                attention_mask.unsqueeze(1)
                .unsqueeze(1)
                .expand(-1, self.config.num_attention_heads, seq_len, -1)
            )

        if cu_seqlens is not None:
            assert not output_attentions, "Output attentions is not supported when sequences are packed."
            assert max_seqlen is not None, "Missing max_seqlen. It must be provided when cu_seqlens are not None."
            assert input_ids.shape[0] == 1, "Cumulative sequence lengths are provided but input_ids are not packed."
            assert input_ids.is_cuda, "Packing uses flash-attention and is only supported on GPU."

        freqs_table = self.freqs_cis
        if not freqs_table.is_complex():
            # nn.Module.to(dtype) also casts complex buffers to the target real
            # dtype, which discards the imaginary part and silently breaks RoPE.
            # Rebuild the table on the buffer's current device.
            freqs_table = precompute_freqs_cis(
                self.config.hidden_size // self.config.num_attention_heads,
                self.config.max_protein_length * 2,
            ).to(freqs_table.device)
            self.freqs_cis = freqs_table

        if position_ids is not None:
            freqs_cis = freqs_table[position_ids]
        else:
            freqs_cis = (
                freqs_table[: input_ids.shape[1]]
                .unsqueeze(0)
                .repeat(input_ids.shape[0], 1, 1)
            )

        x = self.encoder(input_ids)

        for idx, layer in enumerate(self.transformer_encoder):
            x, attn = layer(x, attention_mask, freqs_cis, output_attentions, max_seqlen, cu_seqlens)
            if idx >= output_hidden_index:
                hidden_states.append(x)
            if output_attentions:
                attentions.append(attn)

        logits = self.decoder(self.layer_norm(x))

        return MaskedLMOutput(
            logits=logits,
            hidden_states=hidden_states,
            attentions=attentions,
        )
|
|