| import json |
| import random |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from PL_Encoder import PLEncoder |
| from PL_Aggregator import PLAggregator |
| from PP_Encoder import PPEncoder |
| from PP_Aggregator import PPAggregator |
| import torch.nn.functional as F |
| import torch.utils.data |
| import argparse |
| import os |
| from util import cal_metrics |
| from read_fasta import read_fasta_from_pocket, read_fasta_from_protein |
| from align import get_neighbor_pocket |
|
|
|
|
class HGNN(nn.Module):
    """Heterogeneous GNN scoring pocket/ligand pairs with a contrastive loss.

    Wraps a pocket-side encoder ``enc_u`` and an optional ligand-side encoder
    ``enc_v``; pairs are scored by a scaled dot product of the two embeddings
    and trained with a symmetric InfoNCE-style (CLIP-like) objective.
    """

    def __init__(self, enc_u, enc_v=None, r2e=None):
        """
        Args:
            enc_u: pocket-side encoder; must expose an ``embed_dim`` attribute
                and (for ``refine_pocket``) a ``refine_pocket`` method.
            enc_v: optional ligand-side encoder; required only when
                ``forward``/``loss`` are used.
            r2e: optional relation-to-embedding table, stored for downstream use.
        """
        super(HGNN, self).__init__()
        self.enc_u = enc_u
        self.enc_v = enc_v
        self.embed_dim = enc_u.embed_dim

        self.w_ur1 = nn.Linear(self.embed_dim, self.embed_dim)
        self.w_ur2 = nn.Linear(self.embed_dim, self.embed_dim)
        self.w_vr1 = nn.Linear(self.embed_dim, self.embed_dim)
        self.w_vr2 = nn.Linear(self.embed_dim, self.embed_dim)

        self.r2e = r2e
        self.bn1 = nn.BatchNorm1d(self.embed_dim, momentum=0.5)
        self.bn2 = nn.BatchNorm1d(self.embed_dim, momentum=0.5)

        # Learnable temperature for the contrastive logits, initialized to
        # log(14).  Created on the default (CPU) device — NOT hard-coded to
        # "cuda" — so the model can be built on CPU-only machines; the usual
        # ``model.to(device)`` call moves it along with every other parameter.
        self.logit_scale = nn.Parameter(torch.ones(1) * np.log(14))

    def trainable_parameters(self):
        """Yield every parameter that has ``requires_grad=True``."""
        for _, param in self.named_parameters(recurse=True):
            if param.requires_grad:
                yield param

    def forward(self, nodes_u, nodes_v):
        """Encode both sides and return ``(embeds_u, embeds_v)``.

        Requires ``enc_v`` to have been supplied at construction time.
        """
        embeds_u = self.enc_u(nodes_u, nodes_v)
        embeds_v = self.enc_v(nodes_v)
        return embeds_u, embeds_v

    def criterion(self, x_u, x_v, labels):
        """Symmetric contrastive loss plus per-row ranking metrics.

        Args:
            x_u: pocket embeddings, shape (B, D).
            x_v: ligand embeddings, shape (B, D).
            labels: (B, B) 0/1 activity matrix; the diagonal entry is the
                matched pair for each row.

        Returns:
            Tuple of (scalar loss, dict of metric means from ``cal_metrics``,
            raw similarity matrix ``netout``).
        """
        netout = torch.matmul(x_u, torch.transpose(x_v, 0, 1))
        # Temperature-scaled logits; the scale itself receives no gradient here.
        score = netout * self.logit_scale.exp().detach()
        # Mask out off-diagonal positives with a large negative offset so only
        # the diagonal pair acts as the positive in each softmax row/column.
        score = (labels - torch.eye(len(labels)).to(labels.device)) * -1e6 + score

        lprobs_pocket = F.log_softmax(score.float(), dim=-1)
        lprobs_pocket = lprobs_pocket.view(-1, lprobs_pocket.size(-1))
        sample_size = lprobs_pocket.size(0)
        # Diagonal targets, created on the same device as the logits rather
        # than via a hard-coded .cuda() — keeps the loss usable on CPU.
        targets = torch.arange(
            sample_size, dtype=torch.long, device=lprobs_pocket.device
        )

        # Pocket -> ligand direction.
        loss_pocket = F.nll_loss(lprobs_pocket, targets, reduction="mean")

        # Ligand -> pocket direction (transposed logits).
        lprobs_mol = F.log_softmax(torch.transpose(score.float(), 0, 1), dim=-1)
        lprobs_mol = lprobs_mol.view(-1, lprobs_mol.size(-1))[:sample_size]
        loss_mol = F.nll_loss(lprobs_mol, targets, reduction="mean")

        loss = 0.5 * loss_pocket + 0.5 * loss_mol

        # Ranking metrics computed per row on the raw (unscaled) similarities.
        ef_all = [
            cal_metrics(affi.detach().cpu().numpy(), act.detach().cpu().numpy())
            for affi, act in zip(netout, labels)
        ]
        ef_mean = {k: np.mean([x[k] for x in ef_all]) for k in ef_all[0].keys()}

        return loss, ef_mean, netout

    def loss(self, nodes_u, nodes_v, labels):
        """Run the forward pass and return ``(loss, metric-mean dict)``."""
        x_u, x_v = self.forward(nodes_u, nodes_v)
        loss, ef_mean, _ = self.criterion(x_u, x_v, labels)
        return loss, ef_mean

    def refine_pocket(self, pocket_embed, neighbor_pocket_list):
        """Delegate pocket-embedding refinement to the pocket encoder."""
        return self.enc_u.refine_pocket(pocket_embed, neighbor_pocket_list)
|
|
|
|
|
|
def main():
    """Load a trained HGNN checkpoint and refine one pocket embedding.

    Reads the embedding from ``--pocket_embed``, derives the pocket sequence
    (either directly from ``--pocket_pdb`` or from ``--protein_pdb`` plus
    ``--ligand_pdb``), retrieves similar reference pockets via alignment, runs
    the model's ``refine_pocket`` over them, and saves the refined embedding
    to ``--save_file``.
    """
    parser = argparse.ArgumentParser(description='HGNN model inference')
    parser.add_argument('--embed_dim', type=int, default=128, metavar='N', help='embedding size')
    parser.add_argument("--test_ckpt", type=str, default=None)
    parser.add_argument("--data_root", type=str, default="../data")
    parser.add_argument("--result_root", type=str, default="../result/pocket_ranking")
    parser.add_argument("--pocket_embed", type=str, default="../example/pocket_embed.npy")
    parser.add_argument("--save_file", type=str, default="../example/refined_pocket.npy")
    parser.add_argument("--pocket_pdb", type=str, default=None)
    parser.add_argument("--protein_pdb", type=str, default="../example/protein.pdb")
    parser.add_argument("--ligand_pdb", type=str, default="../example/ligand.pdb")
    args = parser.parse_args()

    # Fail fast with a clear CLI error instead of letting torch.load(None)
    # raise a cryptic exception after the model has been built.
    if args.test_ckpt is None:
        parser.error("--test_ckpt is required (path to a trained HGNN checkpoint)")

    # Fix every RNG for reproducible inference; cuda seeding is a no-op on CPU.
    seed = 42
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    embed_dim = args.embed_dim
    type2e = nn.Embedding(10, embed_dim).to(device)

    # Pocket-ligand encoder wrapped by the pocket-pocket similarity encoder.
    agg_pocket = PLAggregator(r2e=type2e, embed_dim=embed_dim, cuda=device, uv=True)
    enc_pocket = PLEncoder(embed_dim=embed_dim, aggregator=agg_pocket, cuda=device, uv=True)
    agg_pocket_sim = PPAggregator(embed_dim=embed_dim, cuda=device)
    enc_pocket = PPEncoder(enc_pocket, embed_dim=embed_dim, aggregator=agg_pocket_sim, cuda=device)

    model = HGNN(enc_pocket).to(device)
    # strict=False: the checkpoint may omit heads unused at inference time.
    model.load_state_dict(torch.load(args.test_ckpt, weights_only=True), strict=False)
    model.eval()

    pocket_embed = torch.tensor(np.load(args.pocket_embed)).to(device)

    # Pocket sequence: either taken directly from a pocket PDB, or carved out
    # of the full protein using the bound ligand's position.
    if args.pocket_pdb is not None:
        pocket_fasta = read_fasta_from_pocket(args.pocket_pdb)
    else:
        pocket_fasta = read_fasta_from_protein(args.protein_pdb, args.ligand_pdb)

    neighbor_pocket_list = get_neighbor_pocket(pocket_fasta, args.data_root, args.result_root, device)

    # Refine only when the alignment found similar pockets; otherwise pass the
    # input embedding through unchanged.
    if len(neighbor_pocket_list) > 0:
        with torch.no_grad():
            refined_pocket = model.refine_pocket(pocket_embed, neighbor_pocket_list).cpu().numpy()
    else:
        refined_pocket = pocket_embed.cpu().numpy()

    print("finished, saving refined pocket embedding into:", args.save_file)
    np.save(args.save_file, refined_pocket)
|
|
# Run inference only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|
|