| |
| |
| import torch |
| import torch.nn as nn |
|
|
| from utils.nn_utils import graph_to_batch |
|
|
| from .backbone import FrameBuilder |
|
|
|
|
class BackboneModel(nn.Module):
    '''
    Rebuild the backbone atoms of predicted structures via ``FrameBuilder``.

    Per-residue frames are recovered from the predicted coordinates with
    ``FrameBuilder.inverse``. Backbone atoms are then regenerated from those
    frames with ``FrameBuilder.__call__``. Every atom after the first four is
    passed through unchanged.
    '''

    def __init__(self) -> None:
        super().__init__()
        self.backbone_builder = FrameBuilder()

    def forward(self, X, batch_ids):
        '''
        Replace the first four atoms of every residue with a rebuilt backbone.

        Args:
            X: [N, 14, 3], predicted all-atom coordinates (obviously with a
                lot of invalidities). The first 4 atoms are assumed to be
                N, CA, C, O.
            batch_ids: per-residue graph/batch index, passed to
                ``graph_to_batch`` to pad the flat residue list into a batch
                (presumably shape [N] -- TODO confirm against callers).

        Returns:
            Flat coordinates recovered with the padding mask. Atoms 4 onward
            are copied from ``X``; the leading atoms come from the rebuilt
            backbone. The result is [N, 14, 3] when the builder emits exactly
            4 backbone atoms. The residue order equals the input order only if
            ``graph_to_batch`` preserves it -- TODO confirm.
        '''
        # Pad the flat residue list into a batch. With mask_is_pad=False the
        # mask presumably marks real (non-padding) positions -- TODO confirm.
        X, mask = graph_to_batch(X, batch_ids, mask_is_pad=False)
        # FrameBuilder receives an integer per-residue tensor. Here it is just
        # the mask cast to long (assumed: nonzero = valid residue).
        C = mask.long()

        # Recover frames (R, t) from the predicted coordinates, then rebuild
        # backbone atoms from them. The third output of inverse()
        # (presumably a quaternion form of R) is not needed here.
        R, t, _ = self.backbone_builder.inverse(X, C)
        X_bb = self.backbone_builder(R, t, C)
        # Keep atoms 4+ untouched; X_bb is expected to supply the leading
        # backbone atoms (N, CA, C, O) along the atom axis.
        X = torch.cat([X_bb, X[:, :, 4:]], dim=-2)

        # Drop padded positions to return to the flat per-residue layout.
        return X[mask]