| """Neural network for embedding node features.""" |
| import torch |
| from torch import nn |
| from models.utils import get_index_embedding, get_time_embedding, add_RoPE |
|
|
class NodeEmbedder(nn.Module):
    """Embed per-residue node features into a ``c_s``-dimensional representation.

    Per residue, three features are concatenated along the last dimension:

    * a learned embedding of ``aatype`` (21-entry table, width ``c_pos_emb``),
    * a timestep embedding from ``get_time_embedding`` (width ``c_timestep_emb``),
    * a precomputed representation ``node_repr_pre`` (width ``c_node_pre``).

    The concatenation is projected to width ``c_s`` by a Linear-ReLU-Dropout-Linear
    MLP, and the result is passed through ``add_RoPE``.
    """

    # Width of ``node_repr_pre`` used when the config does not provide
    # ``c_node_pre``. NOTE(review): presumably the per-residue width of some
    # pretrained embedding model — confirm against whatever produces it.
    DEFAULT_C_NODE_PRE = 1280

    def __init__(self, module_cfg):
        """Build the residue-type embedding table and the output MLP.

        Args:
            module_cfg: config object exposing the attributes ``c_s``,
                ``c_pos_emb``, ``c_timestep_emb`` and ``dropout``. It may also
                expose ``c_node_pre``; when that is absent or None,
                ``DEFAULT_C_NODE_PRE`` (1280) is used.
        """
        super().__init__()
        self._cfg = module_cfg
        self.c_s = module_cfg.c_s
        self.c_pos_emb = module_cfg.c_pos_emb
        self.c_timestep_emb = module_cfg.c_timestep_emb

        # Configurable width of node_repr_pre. It falls back to the previously
        # hard-coded value, so existing configs build the same network.
        c_node_pre = getattr(module_cfg, "c_node_pre", None)
        self.c_node_pre = (
            self.DEFAULT_C_NODE_PRE if c_node_pre is None else c_node_pre
        )

        # The residue-type embedding reuses the c_pos_emb width.
        self.aatype_emb_dim = module_cfg.c_pos_emb
        # 21 entries: presumably 20 standard residues plus one unknown/mask
        # token — confirm against the aatype encoding used upstream.
        self.aatype_emb = nn.Embedding(21, self.aatype_emb_dim)

        total_node_feats = (
            self.aatype_emb_dim + self.c_timestep_emb + self.c_node_pre
        )
        self.linear = nn.Sequential(
            nn.Linear(total_node_feats, self.c_s),
            nn.ReLU(),
            nn.Dropout(module_cfg.dropout),
            nn.Linear(self.c_s, self.c_s),
        )

    def embed_t(self, timesteps, mask):
        """Return the masked timestep embedding broadcast over residues.

        Args:
            timesteps: tensor of shape [B, 1]; only column 0 is used.
            mask: [B, L] residue mask.

        Returns:
            Tensor of shape [B, L, c_timestep_emb], zeroed where ``mask`` is 0.
            This assumes ``get_time_embedding`` returns [B, c_timestep_emb].
        """
        timestep_emb = get_time_embedding(
            timesteps[:, 0],
            self.c_timestep_emb,
            max_positions=2056,
        )
        # Broadcast one embedding per batch element across all L residues.
        # expand() is a view; the mask multiply below materializes the result,
        # so no intermediate copy is made (unlike repeat()).
        timestep_emb = timestep_emb[:, None, :].expand(-1, mask.shape[1], -1)
        return timestep_emb * mask.unsqueeze(-1)

    def forward(self, timesteps, aatype, node_repr_pre, mask):
        """Embed node features.

        Args:
            timesteps: [B, 1] timesteps.
            aatype: [B, L] integer residue types in ``[0, 21)``.
            node_repr_pre: [B, L, c_node_pre] precomputed per-residue features.
            mask: [B, L] residue mask.

        Returns:
            ``add_RoPE`` applied to the [B, L, c_s] MLP output.
        """
        # The residue-type and timestep features are zeroed at masked residues.
        aatype_emb = self.aatype_emb(aatype) * mask.unsqueeze(-1)
        time_emb = self.embed_t(timesteps, mask)

        # NOTE(review): node_repr_pre is concatenated as given, without the
        # masking applied to the two features above.
        node_feats = torch.cat([aatype_emb, time_emb, node_repr_pre], dim=-1)
        return add_RoPE(self.linear(node_feats))