| import torch |
| import torch.nn as nn |
| from torch.nn import init |
| import torch.nn.functional as F |
| import random |
|
|
class PLEncoder(nn.Module):
    """Collect neighbouring training ligands for pockets and aggregate them.

    For each pocket node, ``forward`` looks up the node's assay id, walks the
    assay's entries in ``pocket_graph``, keeps the neighbours that belong to
    the training set, and hands the resulting ``(ligand_index, score_bucket)``
    lists to ``aggregator.forward``.

    Args:
        embed_dim: Embedding size; used to size ``linear1``
            (``2 * embed_dim -> embed_dim``).
        pocket_graph: Mapping ``assay_id -> iterable of (neighbour_assay_id,
            score)``; accessed through ``.get(assay_id, [])``.
        aggregator: Object providing ``forward(nodes, to_neighs)`` and
            ``forward_inference(pocket_embed, neighbor_list)``.
        idx2assayid: Lookup from pocket node index to assay id.
        assayid_lst_train: Assay ids that may be used as neighbours.
        mol_smi: Container of SMILES strings, indexable by ligand id; the
            position of each entry under ``enumerate`` becomes its index in
            ``smi2idx``.
        train_label_lst: Records, each with an ``"assay_id"`` key and a
            ``"ligands"`` list whose first element has a ``"smi"`` key.
        cuda: Device identifier, stored as ``self.device``.
        uv: Flag stored as ``self.uv``; not read inside this class.
    """

    def __init__(self, embed_dim, pocket_graph=None, aggregator=None,
                 idx2assayid=None, assayid_lst_train=None, mol_smi=None,
                 train_label_lst=None, cuda="cpu", uv=True):
        super().__init__()

        # Defaults are ``None`` rather than ``{}`` / ``[]`` so that instances
        # built with default arguments never share one mutable container.
        if idx2assayid is None:
            idx2assayid = {}
        if assayid_lst_train is None:
            assayid_lst_train = []
        if mol_smi is None:
            mol_smi = {}
        if train_label_lst is None:
            train_label_lst = []

        self.uv = uv
        self.pocket_graph = pocket_graph
        self.aggregator = aggregator
        self.embed_dim = embed_dim
        self.device = cuda

        self.idx2assayid = idx2assayid
        self.assayid_lst_train = assayid_lst_train
        self.mol_smi = mol_smi
        self.train_label_lst = train_label_lst

        # Reverse lookup: SMILES string -> position in ``mol_smi``.
        self.smi2idx = {smi: idx for idx, smi in enumerate(mol_smi)}
        # Set for O(1) membership tests inside ``forward``.
        self.assayid_set_train = set(assayid_lst_train)
        # Training label records keyed by assay id.
        self.label_dicts = {x["assay_id"]: x for x in self.train_label_lst}
        self.linear1 = nn.Linear(2 * self.embed_dim, self.embed_dim)

    def forward(self, nodes_pocket, nodes_lig=None, max_sample=10):
        """Build per-pocket neighbour lists and run the aggregator on them.

        Args:
            nodes_pocket: Pocket node indices (keys of ``idx2assayid``).
            nodes_lig: Optional ligand ids paired one-to-one with
                ``nodes_pocket``. A neighbour whose first ligand has the same
                SMILES as the paired ligand is skipped. When ``None``, a
                placeholder SMILES (``"----"``) is used for every pocket.
            max_sample: Accepted for interface compatibility; currently
                unused.

        Returns:
            Whatever ``aggregator.forward(nodes_pocket, to_neighs)`` returns,
            where ``to_neighs[i]`` is a list of
            ``(ligand_index, int((score - 0.5) * 10))`` tuples for pocket
            ``i``.

        Raises:
            KeyError: If a pocket index is missing from ``idx2assayid``, a
                training neighbour has no entry in ``label_dicts``, or a
                neighbour SMILES is missing from ``smi2idx``.
        """
        if nodes_lig is None:
            lig_smi_lst = ["----"] * len(nodes_pocket)
        else:
            lig_smi_lst = [self.mol_smi[lig_id] for lig_id in nodes_lig]

        to_neighs = []
        for node, smi in zip(nodes_pocket, lig_smi_lst):
            assayid = self.idx2assayid[node]
            neighbors = []
            for n_assayid, score in self.pocket_graph.get(assayid, []):
                # Filter on assay id BEFORE touching ``label_dicts``: that
                # mapping only holds training records, so looking up a
                # non-training neighbour first would raise KeyError instead
                # of skipping it.
                if n_assayid == assayid:
                    continue
                if n_assayid not in self.assayid_set_train:
                    continue
                nbr_smi = self.label_dicts[n_assayid]["ligands"][0]["smi"]
                # Skip neighbours whose ligand equals the paired query ligand.
                if nbr_smi == smi:
                    continue
                neighbors.append(
                    (self.smi2idx[nbr_smi], int((score - 0.5) * 10))
                )
            to_neighs.append(neighbors)

        return self.aggregator.forward(nodes_pocket, to_neighs)

    def refine_pocket(self, pocket_embed, neighbor_list=None):
        """Delegate to ``aggregator.forward_inference`` with the same arguments."""
        return self.aggregator.forward_inference(pocket_embed, neighbor_list)
|
|