import torch
import torch.nn as nn

from ..modules import (
    AxialTransformerLayer,
    LearnedPositionalEmbedding,
    RobertaLMHead,
    ESM1bLayerNorm,
    ContactPredictionHead,
)

from ..axial_attention import RowSelfAttention, ColumnSelfAttention


class MSATransformer(nn.Module):
    @classmethod
    def add_args(cls, parser):
        parser.add_argument(
            "--num_layers",
            default=12,
            type=int,
            metavar="N",
            help="number of layers",
        )
        parser.add_argument(
            "--embed_dim",
            default=768,
            type=int,
            metavar="N",
            help="embedding dimension",
        )
        parser.add_argument(
            "--logit_bias",
            action="store_true",
            help="whether to apply bias to logits",
        )
        parser.add_argument(
            "--ffn_embed_dim",
            default=3072,
            type=int,
            metavar="N",
            help="embedding dimension for FFN",
        )
        parser.add_argument(
            "--attention_heads",
            default=12,
            type=int,
            metavar="N",
            help="number of attention heads",
        )
        parser.add_argument(
            "--dropout",
            default=0.1,
            type=float,
            help="dropout probability for embeddings and hidden states",
        )
        parser.add_argument(
            "--attention_dropout",
            default=0.1,
            type=float,
            help="dropout probability for attention weights",
        )
        parser.add_argument(
            "--activation_dropout",
            default=0.1,
            type=float,
            help="dropout probability after the FFN activation",
        )
        parser.add_argument(
            "--max_tokens_per_msa",
            default=2 ** 14,
            type=int,
            help=(
                "Used during inference to batch attention computations in a single "
                "forward pass. This allows increased input sizes with less memory."
            ),
        )

    def __init__(self, args, alphabet):
        super().__init__()
        self.args = args
        self.alphabet_size = len(alphabet)
        self.padding_idx = alphabet.padding_idx
        self.mask_idx = alphabet.mask_idx
        self.cls_idx = alphabet.cls_idx
        self.eos_idx = alphabet.eos_idx
        self.prepend_bos = alphabet.prepend_bos
        self.append_eos = alphabet.append_eos

        self.embed_tokens = nn.Embedding(
            self.alphabet_size, self.args.embed_dim, padding_idx=self.padding_idx
        )

        # Optional learned embedding over MSA depth (rows); trained for at most 1024 alignments.
        if getattr(self.args, "embed_positions_msa", False):
            emb_dim = getattr(self.args, "embed_positions_msa_dim", self.args.embed_dim)
            self.msa_position_embedding = nn.Parameter(
                0.01 * torch.randn(1, 1024, 1, emb_dim),
                requires_grad=True,
            )
        else:
            self.register_parameter("msa_position_embedding", None)

        self.dropout_module = nn.Dropout(self.args.dropout)
        self.layers = nn.ModuleList(
            [
                AxialTransformerLayer(
                    self.args.embed_dim,
                    self.args.ffn_embed_dim,
                    self.args.attention_heads,
                    self.args.dropout,
                    self.args.attention_dropout,
                    self.args.activation_dropout,
                    getattr(self.args, "max_tokens_per_msa", self.args.max_tokens),
                )
                for _ in range(self.args.layers)
            ]
        )

        self.contact_head = ContactPredictionHead(
            self.args.layers * self.args.attention_heads,
            self.prepend_bos,
            self.append_eos,
            eos_idx=self.eos_idx,
        )
        self.embed_positions = LearnedPositionalEmbedding(
            self.args.max_positions,
            self.args.embed_dim,
            self.padding_idx,
        )
        self.emb_layer_norm_before = ESM1bLayerNorm(self.args.embed_dim)
        self.emb_layer_norm_after = ESM1bLayerNorm(self.args.embed_dim)
        self.lm_head = RobertaLMHead(
            embed_dim=self.args.embed_dim,
            output_dim=self.alphabet_size,
            weight=self.embed_tokens.weight,
        )

    def forward(self, tokens, repr_layers=[], need_head_weights=False, return_contacts=False):
        if return_contacts:
            need_head_weights = True

        assert tokens.ndim == 3
        batch_size, num_alignments, seqlen = tokens.size()
        padding_mask = tokens.eq(self.padding_idx)  # B x R x C
        if not padding_mask.any():
            padding_mask = None

        x = self.embed_tokens(tokens)
        x += self.embed_positions(tokens.view(batch_size * num_alignments, seqlen)).view(x.size())
        if self.msa_position_embedding is not None:
            if x.size(1) > 1024:
                raise RuntimeError(
                    "Using model with MSA position embedding trained on maximum MSA "
                    f"depth of 1024, but received {x.size(1)} alignments."
                )
            x += self.msa_position_embedding[:, :num_alignments]

        x = self.emb_layer_norm_before(x)

        x = self.dropout_module(x)

        if padding_mask is not None:
            x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        repr_layers = set(repr_layers)
        hidden_representations = {}
        if 0 in repr_layers:
            hidden_representations[0] = x

        if need_head_weights:
            row_attn_weights = []
            col_attn_weights = []

        # B x R x C x D -> R x C x B x D
        x = x.permute(1, 2, 0, 3)

        for layer_idx, layer in enumerate(self.layers):
            x = layer(
                x,
                self_attn_padding_mask=padding_mask,
                need_head_weights=need_head_weights,
            )
            if need_head_weights:
                x, col_attn, row_attn = x
                # H x C x B x R x R -> B x H x C x R x R
                col_attn_weights.append(col_attn.permute(2, 0, 1, 3, 4))
                # H x B x C x C -> B x H x C x C
                row_attn_weights.append(row_attn.permute(1, 0, 2, 3))
            if (layer_idx + 1) in repr_layers:
                hidden_representations[layer_idx + 1] = x.permute(2, 0, 1, 3)

        x = self.emb_layer_norm_after(x)
        x = x.permute(2, 0, 1, 3)  # R x C x B x D -> B x R x C x D

        # last hidden representation should have layer norm applied
        if (layer_idx + 1) in repr_layers:
            hidden_representations[layer_idx + 1] = x
        x = self.lm_head(x)

        result = {"logits": x, "representations": hidden_representations}
        if need_head_weights:
            # col_attentions: B x L x H x C x R x R
            col_attentions = torch.stack(col_attn_weights, 1)
            # row_attentions: B x L x H x C x C
            row_attentions = torch.stack(row_attn_weights, 1)
            result["col_attentions"] = col_attentions
            result["row_attentions"] = row_attentions
            if return_contacts:
                contacts = self.contact_head(tokens, row_attentions)
                result["contacts"] = contacts

        return result

    def predict_contacts(self, tokens):
        return self(tokens, return_contacts=True)["contacts"]

    @property
    def num_layers(self):
        return self.args.layers

    def max_tokens_per_msa_(self, value: int) -> None:
        """The MSA Transformer automatically batches attention computations when
        gradients are disabled to allow you to pass in larger MSAs at test time than
        you can fit in GPU memory. By default this occurs when more than 2^14 tokens
        are passed in the input MSA. You can set this value to infinity to disable
        this behavior.
        """
        for module in self.modules():
            if isinstance(module, (RowSelfAttention, ColumnSelfAttention)):
                module.max_tokens_per_msa = value
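

# Usage sketch (not part of the original module): a minimal, illustrative way to
# exercise the model with a hand-rolled argparse.Namespace and a stand-in alphabet
# object. The _DemoAlphabet class and every hyperparameter value below are
# assumptions for demonstration only; real checkpoints supply their own args and
# alphabet. Run it inside the package (e.g. `python -m <package>.msa_transformer`)
# so the relative imports above resolve.
if __name__ == "__main__":
    import argparse

    class _DemoAlphabet:
        """Minimal stand-in exposing only the attributes the model reads."""

        padding_idx = 1
        mask_idx = 32
        cls_idx = 0
        eos_idx = 2
        prepend_bos = True
        append_eos = False

        def __len__(self):
            return 33

    demo_args = argparse.Namespace(
        layers=2,
        embed_dim=64,
        ffn_embed_dim=128,
        attention_heads=4,
        dropout=0.1,
        attention_dropout=0.1,
        activation_dropout=0.1,
        max_tokens=2 ** 14,
        max_tokens_per_msa=2 ** 14,
        max_positions=1024,
        embed_positions_msa=False,
    )

    model = MSATransformer(demo_args, _DemoAlphabet()).eval()
    # Raise the inference-time batching threshold described in max_tokens_per_msa_.
    model.max_tokens_per_msa_(2 ** 16)

    # Tokens have shape (batch, num_alignments, seqlen); values avoid the padding index.
    tokens = torch.randint(3, 32, (1, 8, 16))
    with torch.no_grad():
        out = model(tokens, repr_layers=[2], return_contacts=True)

    print(out["logits"].shape)              # (1, 8, 16, 33): per-token alphabet logits
    print(out["representations"][2].shape)  # (1, 8, 16, 64): final-layer embeddings
    print(out["row_attentions"].shape)      # (1, 2, 4, 16, 16): B x layers x heads x C x C
    print(out["contacts"].shape)            # (1, 15, 15): BOS column/row removed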