| |
| |
| """ |
| Modified from https://github.com/generatebio/chroma/blob/main/chroma/layers/structure/backbone.py |
| """ |
| from typing import Optional, Tuple |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from ..sidechain.structure import geometry |
|
|
|
|
def compose_translation(
    R_a: torch.Tensor, t_a: torch.Tensor, t_b: torch.Tensor
) -> torch.Tensor:
    """Return the translation part of the rigid composition `T_a * T_b`.

    For transforms `T_a = (R_a, t_a)` and `T_b = (R_b, t_b)`, the composed
    translation is `t_a + R_a @ t_b`; the rotation `R_b` does not enter.
    All leading (batch) dimensions broadcast.

    Args:
        R_a (torch.Tensor): Rotation matrix of `T_a`, shape `(..., 3, 3)`.
        t_a (torch.Tensor): Translation of `T_a`, shape `(..., 3)`.
        t_b (torch.Tensor): Translation of `T_b`, shape `(..., 3)`.

    Returns:
        torch.Tensor: Translation of `T_a * T_b`, shape `(..., 3)`.
    """
    # Lift t_b to a column vector so the matrix product applies R_a to it,
    # then drop the trailing singleton column again.
    column_b = t_b[..., None]
    rotated_b = torch.matmul(R_a, column_b)[..., 0]
    return t_a + rotated_b
|
|
|
|
class FrameBuilder(nn.Module):
    """Build protein backbones from rigid residue poses.

    Inputs:
        R (torch.Tensor, optional): Rotation of residue orientations
            with shape `(num_batch, num_residues, 3, 3)`. If `None`,
            then `q` must be provided instead.
        t (torch.Tensor): Translation of residue orientations
            with shape `(num_batch, num_residues, 3)`. This is the
            location of the C-alpha coordinates.
        C (torch.Tensor): Chain map with shape `(num_batch, num_residues)`.
            Only positions with `C > 0` are built; all other positions are
            zeroed in the output.
        q (torch.Tensor, optional): Quaternions representing residue
            orientations with shape `(num_batch, num_residues, 4)`. Must be
            given if and only if `R` is `None`.

    Outputs:
        X (torch.Tensor): Backbone coordinates with shape
            `(num_batch, num_residues, 4, 3)`, atoms ordered (N, CA, C, O).
    """

    def __init__(self, distance_eps: float = 1e-3):
        """
        Args:
            distance_eps (float): Epsilon forwarded to the geometry helpers
                (quaternion -> rotation conversion and frame extraction).
        """
        super().__init__()

        # Ideal N, CA and C positions in the residue-local frame (rows in that
        # order; CA sits at the origin). Shaped (1, 1, 3, 3) so they broadcast
        # against per-residue rotations in `forward`.
        t = torch.tensor(
            [
                [1.459, 0.0, 0.0],
                [0.0, 0.0, 0.0],
                [-0.547, 0.0, -1.424],
            ],
            dtype=torch.float32,
        ).reshape([1, 1, 3, 3])
        # NOTE(review): `_R_atom` is not read by any method of this class; it is
        # kept registered so existing state_dicts continue to load.
        R = torch.eye(3).reshape([1, 1, 1, 3, 3])
        self.register_buffer("_t_atom", t)
        self.register_buffer("_R_atom", R)

        # Internal coordinates used to place the carbonyl O from
        # (N of next residue, CA, C): C=O bond length (presumably Angstroms),
        # CA-C-O angle and N(next)-CA-C-O dihedral, both in degrees
        # (`_build_O` passes `degrees=True`).
        self._length_C_O = 1.2297
        self._angle_CA_C_O = 122.5200
        self._dihedral_Np_CA_C_O = 180
        self.distance_eps = distance_eps

    def _build_O(self, X_chain: torch.Tensor, C: torch.LongTensor) -> torch.Tensor:
        """Build the backbone carbonyl oxygen and assemble the 4-atom backbone.

        Args:
            X_chain (torch.Tensor): N, CA, C coordinates with shape
                `(num_batch, num_residues, 3, 3)`.
            C (torch.LongTensor): Chain map with shape
                `(num_batch, num_residues)`.

        Returns:
            X (torch.Tensor): Coordinates with shape
                `(num_batch, num_residues, 4, 3)`, atoms ordered
                (N, CA, C, O) and zeroed wherever `C <= 0`.
        """
        X_N, X_CA, X_C = X_chain.unbind(-2)
        dtype = X_chain.dtype

        # The O of residue i is placed using the N of residue i + 1: shift N
        # back by one position, zero it where the next residue is absent
        # (C <= 0) and zero-pad the final position.
        # NOTE(review): the last residue (and any residue followed by a masked
        # one) is therefore built against an all-zero "next N", and the next N
        # is used even across chain boundaries, so O placement at termini is not
        # physically meaningful.
        mask_next = (C[:, 1:] > 0).to(dtype).unsqueeze(-1)
        X_N_next = F.pad(mask_next * X_N[:, 1:], (0, 0, 0, 1))

        # Per-residue internal coordinates, in the coordinate dtype.
        ones = torch.ones(C.shape, dtype=dtype, device=C.device)
        X_O = geometry.extend_atoms(
            X_N_next,
            X_CA,
            X_C,
            self._length_C_O * ones,
            self._angle_CA_C_O * ones,
            self._dihedral_Np_CA_C_O * ones,
            degrees=True,
        )

        # Zero every atom at padded / missing positions (C <= 0).
        mask = (C > 0).to(dtype).reshape(list(C.shape) + [1, 1])
        X = mask * torch.stack([X_N, X_CA, X_C, X_O], dim=-2)
        return X

    def forward(
        self,
        R: Optional[torch.Tensor],
        t: torch.Tensor,
        C: torch.LongTensor,
        q: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Build backbone coordinates from residue poses.

        Args:
            R (torch.Tensor, optional): Rotations with shape
                `(num_batch, num_residues, 3, 3)`; `None` to use `q` instead.
            t (torch.Tensor): CA positions with shape
                `(num_batch, num_residues, 3)`.
            C (torch.LongTensor): Chain map with shape
                `(num_batch, num_residues)`.
            q (torch.Tensor, optional): Quaternions with shape
                `(num_batch, num_residues, 4)`; used only when `R` is `None`.

        Returns:
            X (torch.Tensor): Backbone coordinates with shape
                `(num_batch, num_residues, 4, 3)`, ordered (N, CA, C, O).

        Raises:
            ValueError: If both or neither of `R` and `q` are provided.
        """
        # Explicit check instead of `assert` (which is stripped under -O) and
        # also covers the case where neither input is given.
        if (R is None) == (q is None):
            raise ValueError("Exactly one of `R` or `q` must be provided.")

        if R is None:
            R = geometry.rotations_from_quaternions(
                q, normalize=True, eps=self.distance_eps
            )

        # Rotate the ideal local N/CA/C positions by each residue's rotation and
        # translate by its CA position:
        # (B, L, 1, 3, 3) @ (1, 1, 3, 3, 1) -> (B, L, 3, 3) atom coordinates.
        R = R.unsqueeze(-3)
        t_frame = t.unsqueeze(-2)
        # torch.matmul does not promote mismatched dtypes, so match the
        # reference-atom buffer to R (e.g. float64 poses on a float32 module).
        t_atom = self._t_atom.to(dtype=R.dtype)
        X_chain = compose_translation(R, t_frame, t_atom)
        X = self._build_O(X_chain, C)
        return X

    def inverse(
        self, X: torch.Tensor, C: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Reconstruct transformations from poses.

        Inputs:
            X (torch.Tensor): Protein coordinates with shape
                `(num_batch, num_residues, num_atoms, 3)`, `num_atoms >= 4`;
                only the first four atoms (the backbone) are used.
            C (torch.Tensor): Chain map with shape `(num_batch, num_residues)`.

        Outputs:
            R (torch.Tensor): Rotation of residue orientations
                with shape `(num_batch, num_residues, 3, 3)`.
            t (torch.Tensor): Translation of residue orientations
                with shape `(num_batch, num_residues, 3)`. This is the
                location of the C-alpha coordinates.
            q (torch.Tensor): Quaternions representing residue orientations
                with shape `(num_batch, num_residues, 4)`.

            All three are zeroed at positions where `C <= 0`.
        """
        X_bb = X[:, :, :4, :]
        R, t = geometry.frames_from_backbone(X_bb, distance_eps=self.distance_eps)
        q = geometry.quaternions_from_rotations(R, eps=self.distance_eps)

        # Zero every output at padded / missing positions (C <= 0).
        mask = (C > 0).to(R.dtype).unsqueeze(-1)
        R = mask.unsqueeze(-1) * R
        t = mask * t
        q = mask * q
        return R, t, q