| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from functools import partialmethod, partial |
| import math |
| from typing import Optional, List |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from openfold.model.primitives import Linear, LayerNorm, Attention |
| from openfold.utils.tensor_utils import ( |
| chunk_layer, |
| permute_final_dims, |
| flatten_final_dims, |
| ) |
|
|
|
|
class TriangleAttention(nn.Module):
    """
    Triangle self-attention over the pair representation (AlphaFold 2
    Algorithms 13/14). Attention is applied row-wise; for the "ending
    node" variant the input is transposed first so that the same code
    attends column-wise.
    """

    def __init__(
        self, c_in, c_hidden, no_heads, starting, inf=1e9
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Overall hidden channel dimension (not per-head)
            no_heads:
                Number of attention heads
            starting:
                If True, attend around the starting node (rows as-is).
                If False, the input (and mask) are transposed before
                attention and transposed back afterward.
            inf:
                Large positive value used to build the additive mask
                bias (masked positions get -inf-like logits)
        """
        super().__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.starting = starting
        self.inf = inf

        self.layer_norm = LayerNorm(self.c_in)

        # Projects each pair element to one scalar per head, used as the
        # triangle attention bias.
        self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")

        self.mha = Attention(
            self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
        )

    @torch.jit.ignore
    def _chunk(self,
        x: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
    ) -> torch.Tensor:
        """
        Memory-saving path: runs attention over chunks of the leading
        (batch) dimensions instead of all at once.
        """
        mha_inputs = {
            "q_x": x,
            "kv_x": x,
            "biases": biases,
        }
        return chunk_layer(
            # self.mha is already callable; the previous partial(self.mha)
            # wrapper added nothing.
            self.mha,
            mha_inputs,
            chunk_size=chunk_size,
            no_batch_dims=len(x.shape[:-2]),
        )

    def forward(self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, I, J, C_in] input tensor (e.g. the pair representation)
            mask:
                Optional [*, I, J] mask of valid positions; defaults to
                all-ones (nothing masked)
            chunk_size:
                If given, attention is computed in chunks of this many
                elements along the batch dims to reduce peak memory
        Returns:
            [*, I, J, C_in] output tensor
        """
        if mask is None:
            # [*, I, J] — everything attended
            mask = x.new_ones(
                x.shape[:-1],
            )

        # For the "ending node" variant, swap I and J so the shared
        # row-attention code effectively attends over columns.
        if not self.starting:
            x = x.transpose(-2, -3)
            mask = mask.transpose(-1, -2)

        # [*, I, J, C_in]
        x = self.layer_norm(x)

        # Additive mask bias: 0 where valid, -inf where masked.
        # [*, I, 1, 1, J]
        mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

        # Per-head triangle bias: [*, H, I, J]
        triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))

        # [*, 1, H, I, J]
        triangle_bias = triangle_bias.unsqueeze(-4)

        biases = [mask_bias, triangle_bias]

        if chunk_size is not None:
            x = self._chunk(x, biases, chunk_size)
        else:
            x = self.mha(q_x=x, kv_x=x, biases=biases)

        # Undo the transpose for the ending-node variant.
        if not self.starting:
            x = x.transpose(-2, -3)

        return x
|
|
|
|
class TriangleAttentionStartingNode(TriangleAttention):
    """
    Implements Algorithm 13 (triangle attention around the starting node).
    """

    def __init__(self, c_in, c_hidden, no_heads, inf=1e9):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Overall hidden channel dimension (not per-head)
            no_heads:
                Number of attention heads
            inf:
                Large value used for masking attention logits
        """
        # Explicit __init__ (rather than partialmethod) keeps the
        # signature visible to help()/inspect and lets `inf` be passed
        # positionally without colliding with the bound `starting` flag.
        super().__init__(c_in, c_hidden, no_heads, starting=True, inf=inf)
|
|
|
|
class TriangleAttentionEndingNode(TriangleAttention):
    """
    Implements Algorithm 14 (triangle attention around the ending node).
    """

    def __init__(self, c_in, c_hidden, no_heads, inf=1e9):
        """
        Args:
            c_in:
                Input channel dimension
            c_hidden:
                Overall hidden channel dimension (not per-head)
            no_heads:
                Number of attention heads
            inf:
                Large value used for masking attention logits
        """
        # Explicit __init__ (rather than partialmethod) keeps the
        # signature visible to help()/inspect and lets `inf` be passed
        # positionally without colliding with the bound `starting` flag.
        super().__init__(c_in, c_hidden, no_heads, starting=False, inf=inf)
|
|