| import torch |
| import torch.nn as nn |
| from torch.nn import init |
| import torch.nn.functional as F |
| import random |
| import copy |
|
|
class PPEncoder(nn.Module):
    """Pocket encoder that blends a pocket's own features with neighbor features.

    The neighbor features come from ``aggregator`` and are computed over
    pockets linked through ``pocket_graph``; the self features come from
    ``pocket_encoder``.  The two are combined by an unweighted average.
    """

    def __init__(self, pocket_encoder, embed_dim, pocket_graph=None, aggregator=None,
                 assayid_lst_all=None, assayid_lst_train=None, base_model=None, cuda="cpu"):
        """Store the collaborating components and index assay ids.

        Args:
            pocket_encoder: Callable as ``pocket_encoder(nodes_pocket, nodes_lig,
                max_sample)`` and exposing ``refine_pocket(pocket_embed,
                neighbor_list)``.
            embed_dim: Embedding width; used to size ``linear1``.
            pocket_graph: Mapping ``assay id -> iterable of (neighbor assay id,
                score)``.  ``None`` is treated as a graph without edges.
            aggregator: Object exposing ``forward(nodes, to_neighs)`` and
                ``forward_inference(pocket_embed, neighbor_list)``.
            assayid_lst_all: Sequence giving the assay id of every node index.
                Defaults to an empty list.
            assayid_lst_train: Assay ids allowed to be sampled as neighbors.
                Defaults to empty.
            base_model: Optional object stored as ``self.base_model``; the
                attribute is only created when a value is supplied.
            cuda: Device identifier, stored as ``self.device``.
        """
        super().__init__()

        # ``None`` sentinels instead of mutable ``[]`` defaults, which would be
        # shared across every instance constructed without these arguments.
        if assayid_lst_all is None:
            assayid_lst_all = []
        if assayid_lst_train is None:
            assayid_lst_train = []

        self.pocket_encoder = pocket_encoder
        self.pocket_graph = pocket_graph
        self.aggregator = aggregator
        if base_model is not None:
            self.base_model = base_model
        self.embed_dim = embed_dim
        self.device = cuda
        # Not used by the methods in this class, but kept so the module's
        # registered parameters (and state_dict keys) stay unchanged.
        self.linear1 = nn.Linear(2 * self.embed_dim, self.embed_dim)
        self.assayid_lst_all = assayid_lst_all
        self.assayid_set_train = set(assayid_lst_train)

        # assay id -> every node index carrying that id, in order of appearance.
        self.assayid2idxes = {}
        for idx, assayid in enumerate(assayid_lst_all):
            self.assayid2idxes.setdefault(assayid, []).append(idx)

    def forward(self, nodes_pocket, nodes_lig=None, max_sample=10):
        """Encode a batch of pocket nodes together with their graph neighbors.

        For every node, each neighboring assay in ``pocket_graph`` that is a
        different assay and belongs to the training set contributes one
        randomly drawn node index, paired with the edge score.

        Args:
            nodes_pocket: Iterable of node indices into ``assayid_lst_all``.
            nodes_lig: Forwarded unchanged to ``pocket_encoder``.
            max_sample: Forwarded unchanged to ``pocket_encoder``.

        Returns:
            ``(self_feats + neigh_feats) / 2``.
        """
        # A missing graph means no node has neighbors.
        graph = self.pocket_graph if self.pocket_graph is not None else {}

        to_neighs = []
        for node in nodes_pocket:
            assayid = self.assayid_lst_all[node]
            # ``random.choices(...)[0]`` is kept as-is (rather than
            # ``random.choice``) so seeded runs draw the same samples as before.
            neighbors = [
                (random.choices(self.assayid2idxes[n_assayid])[0], score)
                for n_assayid, score in graph.get(assayid, [])
                if n_assayid != assayid and n_assayid in self.assayid_set_train
            ]
            to_neighs.append(neighbors)

        neigh_feats = self.aggregator.forward(nodes_pocket, to_neighs)
        self_feats = self.pocket_encoder(nodes_pocket, nodes_lig, max_sample)

        return (self_feats + neigh_feats) / 2

    def refine_pocket(self, pocket_embed, neighbor_list=None):
        """Refine precomputed pocket embeddings using a given neighbor list.

        Args:
            pocket_embed: Passed to both ``aggregator.forward_inference`` and
                ``pocket_encoder.refine_pocket``.
            neighbor_list: Passed unchanged to the same two calls.

        Returns:
            The average of the two results.
        """
        neigh_feats = self.aggregator.forward_inference(pocket_embed, neighbor_list)
        self_feats = self.pocket_encoder.refine_pocket(pocket_embed, neighbor_list)
        return (self_feats + neigh_feats) / 2