from functools import partial
import math
from typing import Optional, List

import torch
import torch.nn as nn

from openfold.model.primitives import Linear, LayerNorm, Attention
from openfold.model.dropout import (
    DropoutRowwise,
    DropoutColumnwise,
)
from openfold.model.pair_transition import PairTransition
from openfold.model.triangular_attention import (
    TriangleAttentionStartingNode,
    TriangleAttentionEndingNode,
)
from openfold.model.triangular_multiplicative_update import (
    TriangleMultiplicationOutgoing,
    TriangleMultiplicationIncoming,
)
from openfold.utils.checkpointing import checkpoint_blocks
from openfold.utils.tensor_utils import (
    chunk_layer,
    permute_final_dims,
    flatten_final_dims,
)


class TemplatePointwiseAttention(nn.Module):
    """
    Implements Algorithm 17.
    """
    def __init__(self, c_t, c_z, c_hidden, no_heads, inf, **kwargs):
        """
        Args:
            c_t:
                Template embedding channel dimension
            c_z:
                Pair embedding channel dimension
            c_hidden:
                Hidden channel dimension
            no_heads:
                Number of attention heads
            inf:
                Large value used to mask attention logits
        """
        super(TemplatePointwiseAttention, self).__init__()

        self.c_t = c_t
        self.c_z = c_z
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf

        self.mha = Attention(
            self.c_z,
            self.c_t,
            self.c_t,
            self.c_hidden,
            self.no_heads,
            gating=False,
        )

    def _chunk(self,
        z: torch.Tensor,
        t: torch.Tensor,
        biases: List[torch.Tensor],
        chunk_size: int,
    ) -> torch.Tensor:
        mha_inputs = {
            "q_x": z,
            "kv_x": t,
            "biases": biases,
        }
        return chunk_layer(
            self.mha,
            mha_inputs,
            chunk_size=chunk_size,
            no_batch_dims=len(z.shape[:-2]),
        )

    def forward(self,
        t: torch.Tensor,
        z: torch.Tensor,
        template_mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None
    ) -> torch.Tensor:
        """
        Args:
            t:
                [*, N_templ, N_res, N_res, C_t] template embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            template_mask:
                [*, N_templ] template mask
        Returns:
            [*, N_res, N_res, C_z] pair embedding update
        """
        if template_mask is None:
            template_mask = t.new_ones(t.shape[:-3])

        bias = self.inf * (template_mask[..., None, None, None, None, :] - 1)

        # [*, N_res, N_res, 1, C_z]
        z = z.unsqueeze(-2)

        # [*, N_res, N_res, N_templ, C_t]
        t = permute_final_dims(t, (1, 2, 0, 3))

        # Attend over the template dimension for each pair (i, j)
        biases = [bias]
        if chunk_size is not None:
            z = self._chunk(z, t, biases, chunk_size)
        else:
            z = self.mha(q_x=z, kv_x=t, biases=biases)

        # [*, N_res, N_res, C_z]
        z = z.squeeze(-2)

        return z

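# Illustrative usage sketch for TemplatePointwiseAttention. The channel sizes,
# head count, and batch/sequence dimensions below are assumptions chosen for
# the example, not values taken from the OpenFold configuration:
#
#     attn = TemplatePointwiseAttention(
#         c_t=64, c_z=128, c_hidden=16, no_heads=4, inf=1e9
#     )
#     t = torch.rand(1, 2, 16, 16, 64)  # [*, N_templ, N_res, N_res, C_t]
#     z = torch.rand(1, 16, 16, 128)    # [*, N_res, N_res, C_z]
#     update = attn(t, z)               # [*, N_res, N_res, C_z]
#     # the returned update is typically added to the pair embedding z

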
class TemplatePairStackBlock(nn.Module):
    def __init__(
        self,
        c_t: int,
        c_hidden_tri_att: int,
        c_hidden_tri_mul: int,
        no_heads: int,
        pair_transition_n: int,
        dropout_rate: float,
        inf: float,
        **kwargs,
    ):
        super(TemplatePairStackBlock, self).__init__()

        self.c_t = c_t
        self.c_hidden_tri_att = c_hidden_tri_att
        self.c_hidden_tri_mul = c_hidden_tri_mul
        self.no_heads = no_heads
        self.pair_transition_n = pair_transition_n
        self.dropout_rate = dropout_rate
        self.inf = inf

        self.dropout_row = DropoutRowwise(self.dropout_rate)
        self.dropout_col = DropoutColumnwise(self.dropout_rate)

        self.tri_att_start = TriangleAttentionStartingNode(
            self.c_t,
            self.c_hidden_tri_att,
            self.no_heads,
            inf=inf,
        )
        self.tri_att_end = TriangleAttentionEndingNode(
            self.c_t,
            self.c_hidden_tri_att,
            self.no_heads,
            inf=inf,
        )

        self.tri_mul_out = TriangleMultiplicationOutgoing(
            self.c_t,
            self.c_hidden_tri_mul,
        )
        self.tri_mul_in = TriangleMultiplicationIncoming(
            self.c_t,
            self.c_hidden_tri_mul,
        )

        self.pair_transition = PairTransition(
            self.c_t,
            self.pair_transition_n,
        )

    def forward(self,
        z: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True
    ):
        # Process each template independently along the template dimension
        single_templates = [
            t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)
        ]
        single_templates_masks = [
            m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)
        ]
        for i in range(len(single_templates)):
            single = single_templates[i]
            single_mask = single_templates_masks[i]

            single = single + self.dropout_row(
                self.tri_att_start(
                    single,
                    chunk_size=chunk_size,
                    mask=single_mask
                )
            )
            single = single + self.dropout_col(
                self.tri_att_end(
                    single,
                    chunk_size=chunk_size,
                    mask=single_mask
                )
            )
            single = single + self.dropout_row(
                self.tri_mul_out(
                    single,
                    mask=single_mask
                )
            )
            single = single + self.dropout_row(
                self.tri_mul_in(
                    single,
                    mask=single_mask
                )
            )
            single = single + self.pair_transition(
                single,
                mask=single_mask if _mask_trans else None,
                chunk_size=chunk_size,
            )

            single_templates[i] = single

        z = torch.cat(single_templates, dim=-4)

        return z

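# Illustrative shape walkthrough for TemplatePairStackBlock.forward. The
# hyperparameters and tensor sizes below are assumptions for the example only:
#
#     block = TemplatePairStackBlock(
#         c_t=64, c_hidden_tri_att=16, c_hidden_tri_mul=64,
#         no_heads=4, pair_transition_n=2, dropout_rate=0.25, inf=1e9,
#     )
#     t = torch.rand(1, 2, 16, 16, 64)  # [*, N_templ, N_res, N_res, C_t]
#     mask = torch.ones(1, 2, 16, 16)   # [*, N_templ, N_res, N_res]
#     t = block(t, mask)                # shape unchanged

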
class TemplatePairStack(nn.Module):
    """
    Implements Algorithm 16.
    """
    def __init__(
        self,
        c_t,
        c_hidden_tri_att,
        c_hidden_tri_mul,
        no_blocks,
        no_heads,
        pair_transition_n,
        dropout_rate,
        blocks_per_ckpt,
        inf=1e9,
        **kwargs,
    ):
        """
        Args:
            c_t:
                Template embedding channel dimension
            c_hidden_tri_att:
                Per-head hidden dimension for triangular attention
            c_hidden_tri_mul:
                Hidden dimension for triangular multiplication
            no_blocks:
                Number of blocks in the stack
            no_heads:
                Number of triangular attention heads
            pair_transition_n:
                Scale of pair transition (Alg. 15) hidden dimension
            dropout_rate:
                Dropout rate used throughout the stack
            blocks_per_ckpt:
                Number of blocks per activation checkpoint. None disables
                activation checkpointing
        """
        super(TemplatePairStack, self).__init__()

        self.blocks_per_ckpt = blocks_per_ckpt

        self.blocks = nn.ModuleList()
        for _ in range(no_blocks):
            block = TemplatePairStackBlock(
                c_t=c_t,
                c_hidden_tri_att=c_hidden_tri_att,
                c_hidden_tri_mul=c_hidden_tri_mul,
                no_heads=no_heads,
                pair_transition_n=pair_transition_n,
                dropout_rate=dropout_rate,
                inf=inf,
            )
            self.blocks.append(block)

        self.layer_norm = LayerNorm(c_t)

    def forward(
        self,
        t: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
        _mask_trans: bool = True,
    ):
        """
        Args:
            t:
                [*, N_templ, N_res, N_res, C_t] template embedding
            mask:
                [*, N_templ, N_res, N_res] mask
        Returns:
            [*, N_templ, N_res, N_res, C_t] template embedding update
        """
        # Broadcast a single-template mask across all templates
        if mask.shape[-3] == 1:
            expand_idx = list(mask.shape)
            expand_idx[-3] = t.shape[-4]
            mask = mask.expand(*expand_idx)

        t, = checkpoint_blocks(
            blocks=[
                partial(
                    b,
                    mask=mask,
                    chunk_size=chunk_size,
                    _mask_trans=_mask_trans,
                )
                for b in self.blocks
            ],
            args=(t,),
            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
        )

        t = self.layer_norm(t)

        return t

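# Illustrative usage sketch for TemplatePairStack. Hyperparameters and tensor
# sizes below are assumptions for the example, not OpenFold defaults:
#
#     stack = TemplatePairStack(
#         c_t=64, c_hidden_tri_att=16, c_hidden_tri_mul=64, no_blocks=2,
#         no_heads=4, pair_transition_n=2, dropout_rate=0.25,
#         blocks_per_ckpt=None,
#     )
#     t = torch.rand(1, 2, 16, 16, 64)     # [*, N_templ, N_res, N_res, C_t]
#     mask = torch.ones(1, 2, 16, 16)      # [*, N_templ, N_res, N_res]
#     t = stack(t, mask, chunk_size=None)  # [*, N_templ, N_res, N_res, C_t]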