| import copy |
| import os |
| import warnings |
|
|
| import numpy as np |
| import scipy.spatial as spa |
| import torch |
| from Bio.PDB import PDBParser |
| from Bio.PDB.PDBExceptions import PDBConstructionWarning |
| from rdkit import Chem |
| from rdkit.Chem.rdchem import BondType as BT |
| from rdkit.Chem import AllChem, GetPeriodicTable, RemoveHs |
| from rdkit.Geometry import Point3D |
| from scipy import spatial |
| from scipy.special import softmax |
| from torch_cluster import radius_graph |
|
|
|
|
| import torch.nn.functional as F |
|
|
| from datasets.conformer_matching import get_torsion_angles, optimize_rotatable_bonds |
| from utils.torsion import get_transformation_mask |
|
|
|
|
# Module-level singletons, created once at import time and shared by the helpers below.
biopython_parser = PDBParser()  # reads receptor structures in parse_pdb_from_path
periodic_table = GetPeriodicTable()  # maps element symbols to atomic numbers in rec_atom_featurizer
# Vocabularies of the categorical features. A feature value is encoded as its index in
# the matching list; vocabularies that end in 'misc' use that last slot for values that
# are not listed (see safe_index).
allowable_features = {
    'possible_atomic_num_list': [*range(1, 119), 'misc'],
    'possible_chirality_list': [
        'CHI_UNSPECIFIED',
        'CHI_TETRAHEDRAL_CW',
        'CHI_TETRAHEDRAL_CCW',
        'CHI_OTHER',
    ],
    'possible_degree_list': [*range(11), 'misc'],
    'possible_numring_list': [*range(7), 'misc'],
    'possible_implicit_valence_list': [*range(7), 'misc'],
    'possible_formal_charge_list': [*range(-5, 6), 'misc'],
    'possible_numH_list': [*range(9), 'misc'],
    'possible_number_radical_e_list': [*range(5), 'misc'],
    'possible_hybridization_list': ['SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'misc'],
    'possible_is_aromatic_list': [False, True],
    # One independent [False, True] vocabulary per ring size 3..8.
    **{f'possible_is_in_ring{size}_list': [False, True] for size in range(3, 9)},
    'possible_amino_acids': [
        'ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
        'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL',
        'HIP', 'HIE', 'TPO', 'HID', 'LEV', 'MEU', 'PTR', 'GLV', 'CYT', 'SEP',
        'HIZ', 'CYM', 'GLM', 'ASQ', 'TYS', 'CYX', 'GLZ', 'misc',
    ],
    'possible_atom_type_2': [
        'C*', 'CA', 'CB', 'CD', 'CE', 'CG', 'CH', 'CZ', 'N*', 'ND', 'NE', 'NH',
        'NZ', 'O*', 'OD', 'OE', 'OG', 'OH', 'OX', 'S*', 'SD', 'SG', 'misc',
    ],
    'possible_atom_type_3': [
        'C', 'CA', 'CB', 'CD', 'CD1', 'CD2', 'CE', 'CE1', 'CE2', 'CE3', 'CG',
        'CG1', 'CG2', 'CH2', 'CZ', 'CZ2', 'CZ3', 'N', 'ND1', 'ND2', 'NE', 'NE1',
        'NE2', 'NH1', 'NH2', 'NZ', 'O', 'OD1', 'OD2', 'OE1', 'OE2', 'OG', 'OG1',
        'OH', 'OXT', 'SD', 'SG', 'misc',
    ],
}
# Integer class of each RDKit bond type; used to one-hot encode ligand bond edges.
bonds = {
    BT.SINGLE: 0,
    BT.DOUBLE: 1,
    BT.TRIPLE: 2,
    BT.AROMATIC: 3,
}

# Each *_feature_dims is a pair: (vocabulary size of every categorical feature, in the
# column order of the matching featurizer, 0).
# NOTE(review): the trailing 0 is presumably the number of additional scalar features —
# confirm against the model code that consumes these tuples.
lig_feature_dims = (
    [len(allowable_features[key]) for key in (
        'possible_atomic_num_list',
        'possible_chirality_list',
        'possible_degree_list',
        'possible_formal_charge_list',
        'possible_implicit_valence_list',
        'possible_numH_list',
        'possible_number_radical_e_list',
        'possible_hybridization_list',
        'possible_is_aromatic_list',
        'possible_numring_list',
        'possible_is_in_ring3_list',
        'possible_is_in_ring4_list',
        'possible_is_in_ring5_list',
        'possible_is_in_ring6_list',
        'possible_is_in_ring7_list',
        'possible_is_in_ring8_list',
    )],
    0,
)

rec_atom_feature_dims = (
    [len(allowable_features[key]) for key in (
        'possible_amino_acids',
        'possible_atomic_num_list',
        'possible_atom_type_2',
        'possible_atom_type_3',
    )],
    0,
)

rec_residue_feature_dims = (
    [len(allowable_features[key]) for key in ('possible_amino_acids',)],
    0,
)
|
|
|
|
def lig_atom_featurizer(mol):
    """Encode each atom of an RDKit molecule as a row of categorical feature indices.

    Every column holds the index of the atom's property value inside the matching
    ``allowable_features`` vocabulary, in the same column order as
    ``lig_feature_dims``. Values that are not listed in a vocabulary map to its
    last entry (see ``safe_index``).

    Args:
        mol: RDKit molecule whose atoms are featurized.

    Returns:
        Tensor built from one 16-column list of integer indices per atom.
    """
    ring_info = mol.GetRingInfo()
    features = allowable_features
    rows = []
    for idx, atom in enumerate(mol.GetAtoms()):
        rows.append([
            safe_index(features['possible_atomic_num_list'], atom.GetAtomicNum()),
            # safe_index instead of list.index: chiral tags outside the four listed
            # ones fall back to the last entry ('CHI_OTHER') instead of raising.
            safe_index(features['possible_chirality_list'], str(atom.GetChiralTag())),
            safe_index(features['possible_degree_list'], atom.GetTotalDegree()),
            safe_index(features['possible_formal_charge_list'], atom.GetFormalCharge()),
            safe_index(features['possible_implicit_valence_list'], atom.GetImplicitValence()),
            safe_index(features['possible_numH_list'], atom.GetTotalNumHs()),
            safe_index(features['possible_number_radical_e_list'], atom.GetNumRadicalElectrons()),
            safe_index(features['possible_hybridization_list'], str(atom.GetHybridization())),
            features['possible_is_aromatic_list'].index(atom.GetIsAromatic()),
            safe_index(features['possible_numring_list'], ring_info.NumAtomRings(idx)),
            features['possible_is_in_ring3_list'].index(ring_info.IsAtomInRingOfSize(idx, 3)),
            features['possible_is_in_ring4_list'].index(ring_info.IsAtomInRingOfSize(idx, 4)),
            features['possible_is_in_ring5_list'].index(ring_info.IsAtomInRingOfSize(idx, 5)),
            features['possible_is_in_ring6_list'].index(ring_info.IsAtomInRingOfSize(idx, 6)),
            features['possible_is_in_ring7_list'].index(ring_info.IsAtomInRingOfSize(idx, 7)),
            features['possible_is_in_ring8_list'].index(ring_info.IsAtomInRingOfSize(idx, 8)),
        ])

    return torch.tensor(rows)
|
|
|
|
def rec_residue_featurizer(rec):
    """Encode every residue of ``rec`` by the index of its residue name.

    Residue names that are not in ``allowable_features['possible_amino_acids']``
    map to the last entry of that list.

    Returns:
        float32 tensor built from one single-element row per residue.
    """
    amino_acids = allowable_features['possible_amino_acids']
    rows = [[safe_index(amino_acids, res.get_resname())] for res in rec.get_residues()]
    return torch.tensor(rows, dtype=torch.float32)
|
|
|
|
def safe_index(l, e):
    """Return the index of element e in list l, or the last index if e is absent.

    The vocabularies in ``allowable_features`` end with a 'misc' entry, so the
    fallback maps unknown values onto that slot.
    """
    try:
        return l.index(e)
    except ValueError:
        # list.index signals "not found" with ValueError; anything else is a real
        # error (or an interrupt) and must not be swallowed.
        return len(l) - 1
|
|
|
|
|
|
def parse_receptor(pdbid, pdbbind_dir):
    """Load the processed receptor of complex ``pdbid``; thin alias of ``parsePDB``."""
    return parsePDB(pdbid, pdbbind_dir)
|
|
|
|
def parsePDB(pdbid, pdbbind_dir):
    """Parse ``<pdbbind_dir>/<pdbid>/<pdbid>_protein_processed.pdb``."""
    file_name = f'{pdbid}_protein_processed.pdb'
    return parse_pdb_from_path(os.path.join(pdbbind_dir, pdbid, file_name))
|
|
def parse_pdb_from_path(path):
    """Parse the PDB file at ``path`` and return the first model of the structure.

    ``PDBConstructionWarning`` is silenced while the file is being parsed.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=PDBConstructionWarning)
        parsed = biopython_parser.get_structure('random_id', path)
    return parsed[0]
|
|
|
|
def extract_receptor_structure(rec, lig, lm_embedding_chains=None):
    """Collect backbone and all-atom coordinates of the usable residues of a receptor.

    A residue is kept when it is not water ('HOH') and contains N, CA and C atoms.
    Every other residue is detached from its chain, and chains left without a kept
    residue are detached from ``rec``, so ``rec`` is modified in place.

    Args:
        rec: receptor model that iterates over chains (see ``parse_pdb_from_path``).
        lig: not used. It only fed a nearest-chain fallback that could never select
            a chain (an integer index was compared against chain ids); the parameter
            is kept so existing callers keep working.
        lm_embedding_chains: optional sequence with one embedding array per chain,
            in the iteration order of ``rec`` (presumably one row per residue —
            confirm against the caller).

    Returns:
        Tuple ``(rec, coords, c_alpha_coords, n_coords, c_coords, lm_embeddings)``:
        the pruned ``rec``; a list with one array of atom coordinates per kept
        residue; arrays with the CA, N and C coordinates of the kept residues; and
        the embeddings of the kept chains concatenated along axis 0 (``None`` when
        ``lm_embedding_chains`` is ``None``).

    Raises:
        ValueError: if no chain contains a kept residue, or if a kept chain has no
            entry in ``lm_embedding_chains``.
    """
    coords = []
    c_alpha_coords = []
    n_coords = []
    c_coords = []
    lengths = []
    for chain in rec:
        chain_coords = []
        chain_c_alpha_coords = []
        chain_n_coords = []
        chain_c_coords = []
        invalid_res_ids = []
        for residue in chain:
            if residue.get_resname() == 'HOH':
                invalid_res_ids.append(residue.get_id())
                continue
            residue_coords = []
            c_alpha, n, c = None, None, None
            for atom in residue:
                position = list(atom.get_vector())
                if atom.name == 'CA':
                    c_alpha = position
                elif atom.name == 'N':
                    n = position
                elif atom.name == 'C':
                    c = position
                residue_coords.append(position)

            if c_alpha is not None and n is not None and c is not None:
                chain_c_alpha_coords.append(c_alpha)
                chain_n_coords.append(n)
                chain_c_coords.append(c)
                chain_coords.append(np.array(residue_coords))
            else:
                invalid_res_ids.append(residue.get_id())
        # Detach after the loop so the chain is not mutated while it is iterated.
        for res_id in invalid_res_ids:
            chain.detach_child(res_id)

        lengths.append(len(chain_c_alpha_coords))
        coords.append(chain_coords)
        c_alpha_coords.append(np.array(chain_c_alpha_coords))
        n_coords.append(np.array(chain_n_coords))
        c_coords.append(np.array(chain_c_coords))

    if not any(lengths):
        # Previously this case ended in an opaque ValueError from np.concatenate([]).
        raise ValueError('The receptor contains no residue with N, CA and C atoms')

    valid_coords = []
    valid_c_alpha_coords = []
    valid_n_coords = []
    valid_c_coords = []
    valid_lengths = []
    invalid_chain_ids = []
    valid_lm_embeddings = []
    for i, chain in enumerate(rec):
        if lengths[i] == 0:
            invalid_chain_ids.append(chain.get_id())
            continue
        valid_coords.append(coords[i])
        valid_c_alpha_coords.append(c_alpha_coords[i])
        if lm_embedding_chains is not None:
            if i >= len(lm_embedding_chains):
                raise ValueError('Encountered valid chain id that was not present in the LM embeddings')
            valid_lm_embeddings.append(lm_embedding_chains[i])
        valid_n_coords.append(n_coords[i])
        valid_c_coords.append(c_coords[i])
        valid_lengths.append(lengths[i])

    # Flatten the per-chain lists into one list of per-residue coordinate arrays.
    coords = [residue for chain_coords in valid_coords for residue in chain_coords]

    c_alpha_coords = np.concatenate(valid_c_alpha_coords, axis=0)
    n_coords = np.concatenate(valid_n_coords, axis=0)
    c_coords = np.concatenate(valid_c_coords, axis=0)
    lm_embeddings = np.concatenate(valid_lm_embeddings, axis=0) if lm_embedding_chains is not None else None
    # Detach after the loop so the model is not mutated while it is iterated.
    for invalid_id in invalid_chain_ids:
        rec.detach_child(invalid_id)

    assert len(c_alpha_coords) == len(n_coords)
    assert len(c_alpha_coords) == len(c_coords)
    assert sum(valid_lengths) == len(c_alpha_coords)
    return rec, coords, c_alpha_coords, n_coords, c_coords, lm_embeddings
|
|
|
|
def get_lig_graph(mol, complex_graph):
    """Write the ligand nodes and bond edges of ``mol`` into ``complex_graph``.

    Sets ``x`` (atom features) and ``pos`` (conformer coordinates as float) on the
    'ligand' store, and ``edge_index`` / ``edge_attr`` on the
    ('ligand', 'lig_bond', 'ligand') store. Every bond is stored in both
    directions and its class is one-hot encoded; bonds of type UNSPECIFIED get
    class 0.
    """
    positions = torch.from_numpy(mol.GetConformer().GetPositions()).float()
    node_features = lig_atom_featurizer(mol)

    sources, targets, bond_classes = [], [], []
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        sources.extend((begin, end))
        targets.extend((end, begin))
        bond_type = bond.GetBondType()
        if bond_type == BT.UNSPECIFIED:
            bond_class = 0
        else:
            bond_class = bonds[bond_type]
        bond_classes.extend((bond_class, bond_class))

    bond_index = torch.tensor([sources, targets], dtype=torch.long)
    bond_class_tensor = torch.tensor(bond_classes, dtype=torch.long)
    bond_one_hot = F.one_hot(bond_class_tensor, num_classes=len(bonds)).to(torch.float)

    ligand_nodes = complex_graph['ligand']
    ligand_nodes.x = node_features
    ligand_nodes.pos = positions
    ligand_bonds = complex_graph['ligand', 'lig_bond', 'ligand']
    ligand_bonds.edge_index = bond_index
    ligand_bonds.edge_attr = bond_one_hot
    return
|
|
def generate_conformer(mol):
    """Embed a 3D conformer into ``mol`` in place using ETKDGv2 parameters.

    When the first embedding attempt fails (returns -1), the embedding is retried
    with random starting coordinates and followed by an MMFF optimisation of
    conformer 0.
    """
    params = AllChem.ETKDGv2()
    if AllChem.EmbedMolecule(mol, params) != -1:
        return

    print('rdkit coords could not be generated without using random coords. using random coords now.')
    params.useRandomCoords = True
    AllChem.EmbedMolecule(mol, params)
    # NOTE(review): the MMFF clean-up is assumed to belong to the random-coordinate
    # fallback only — confirm against the intended behaviour.
    AllChem.MMFFOptimizeMolecule(mol, confId=0)
| |
| |
|
|
def get_lig_graph_with_matching(mol_, complex_graph, popsize, maxiter, matching, keep_original, num_conformers, remove_hs):
    """Build the ligand part of ``complex_graph``, optionally from matched conformers.

    Without ``matching`` the graph is built directly from ``mol_`` (hydrogens
    removed first when ``remove_hs``) and ``rmsd_matching`` is set to 0.

    With ``matching``, ``num_conformers`` conformers are generated with
    ``generate_conformer``; when the molecule has rotatable bonds they are
    adjusted with ``optimize_rotatable_bonds`` (using ``popsize`` / ``maxiter``),
    then the conformer is aligned with ``AllChem.AlignMolConformers``. The first
    conformer defines the graph and ``rmsd_matching``; the positions of later
    conformers are appended to ``complex_graph['ligand'].pos``, which then becomes
    a list. ``keep_original`` additionally stores the input coordinates as
    ``orig_pos``.

    In both modes the masks returned by ``get_transformation_mask`` are stored on
    the 'ligand' store as ``edge_mask`` and ``mask_rotate``.
    """
    if not matching:
        complex_graph.rmsd_matching = 0
        if remove_hs:
            mol_ = RemoveHs(mol_)
        get_lig_graph(mol_, complex_graph)
    else:
        reference = copy.deepcopy(mol_)
        if remove_hs:
            reference = RemoveHs(reference, sanitize=True)
        if keep_original:
            complex_graph['ligand'].orig_pos = reference.GetConformer().GetPositions()

        torsions = get_torsion_angles(reference)
        if not torsions:
            print("no_rotable_bonds but still using it")

        for conformer_idx in range(num_conformers):
            # Fresh RDKit conformer, generated with explicit hydrogens.
            generated = copy.deepcopy(mol_)
            generated.RemoveAllConformers()
            generated = AllChem.AddHs(generated)
            generate_conformer(generated)
            if remove_hs:
                generated = RemoveHs(generated, sanitize=True)

            target = copy.deepcopy(reference)
            if torsions:
                optimize_rotatable_bonds(generated, target, torsions, popsize=popsize, maxiter=maxiter)
            # Put the generated conformer next to the reference one, align them, and
            # copy the aligned conformer (index 1) back into the generated molecule.
            target.AddConformer(generated.GetConformer())
            rmsds = []
            AllChem.AlignMolConformers(target, RMSlist=rmsds)
            generated.RemoveAllConformers()
            generated.AddConformer(target.GetConformers()[1])

            if conformer_idx == 0:
                complex_graph.rmsd_matching = rmsds[0]
                get_lig_graph(generated, complex_graph)
                continue

            # Additional conformers: turn pos into a list on first use, then append.
            if torch.is_tensor(complex_graph['ligand'].pos):
                complex_graph['ligand'].pos = [complex_graph['ligand'].pos]
            extra_positions = torch.from_numpy(generated.GetConformer().GetPositions()).float()
            complex_graph['ligand'].pos.append(extra_positions)

    edge_mask, mask_rotate = get_transformation_mask(complex_graph)
    complex_graph['ligand'].edge_mask = torch.tensor(edge_mask)
    complex_graph['ligand'].mask_rotate = mask_rotate

    return
|
|
|
|
def get_calpha_graph(rec, c_alpha_coords, n_coords, c_coords, complex_graph, cutoff=20, max_neighbor=None, lm_embeddings=None):
    """Write the residue-level (C-alpha) receptor graph into ``complex_graph``.

    Each residue is connected to the residues whose C-alpha lies closer than
    ``cutoff``; when more than ``max_neighbor`` qualify only the closest
    ``max_neighbor`` are kept, and a residue without any neighbour is connected
    to its single closest residue.

    Sets on the 'receptor' store: ``x`` (residue features, concatenated with
    ``lm_embeddings`` when given), ``pos`` (C-alpha coordinates), ``mu_r_norm``
    (one value per entry of ``sigma`` for every residue) and ``side_chain_vecs``
    (N and C positions relative to C-alpha). Sets ``edge_index`` on the
    ('receptor', 'rec_contact', 'receptor') store.

    Args:
        rec: receptor passed to ``rec_residue_featurizer``.
        c_alpha_coords, n_coords, c_coords: arrays with one row per residue.
        complex_graph: graph object that receives the attributes above.
        cutoff: distance threshold for neighbours.
        max_neighbor: optional cap on the number of neighbours per residue.
        lm_embeddings: optional per-residue embeddings appended to ``x``.

    Raises:
        ValueError: if fewer than two residues are given.
    """
    n_rel_pos = n_coords - c_alpha_coords
    c_rel_pos = c_coords - c_alpha_coords
    num_residues = len(c_alpha_coords)
    if num_residues <= 1:
        raise ValueError("rec contains only 1 residue!")

    distances = spa.distance.cdist(c_alpha_coords, c_alpha_coords)
    # Length scales of the softmax weighting; loop-invariant, so built once.
    sigma = np.array([1., 2., 5., 10., 30.]).reshape((-1, 1))
    src_list = []
    dst_list = []
    mean_norm_list = []
    for i in range(num_residues):
        row = distances[i, :]
        dst = [j for j in np.where(row < cutoff)[0] if j != i]

        needs_ranking = (max_neighbor is not None and len(dst) > max_neighbor) or len(dst) == 0
        if needs_ranking:
            # All other residues, closest first. The residue itself is removed
            # explicitly instead of assuming it sorts first, which does not hold
            # when another residue has identical coordinates.
            ranked = np.argsort(row)
            ranked = ranked[ranked != i]
            if max_neighbor is not None and len(dst) > max_neighbor:
                dst = list(ranked[:max_neighbor])
            if len(dst) == 0:
                dst = list(ranked[:1])
                print(f'The c_alpha_cutoff {cutoff} was too small for one c_alpha such that it had no neighbors. '
                      f'So we connected it to the closest other c_alpha')

        src = [i] * len(dst)
        src_list.extend(src)
        dst_list.extend(dst)
        valid_dist_np = distances[i, dst]
        weights = softmax(- valid_dist_np.reshape((1, -1)) ** 2 / sigma, axis=1)
        assert 1 - 1e-2 < weights[0].sum() < 1.01
        diff_vecs = c_alpha_coords[src, :] - c_alpha_coords[dst, :]
        mean_vec = weights.dot(diff_vecs)
        denominator = weights.dot(np.linalg.norm(diff_vecs, axis=1))
        mean_vec_ratio_norm = np.linalg.norm(mean_vec, axis=1) / denominator
        mean_norm_list.append(mean_vec_ratio_norm)
    assert len(src_list) == len(dst_list)

    node_feat = rec_residue_featurizer(rec)
    mu_r_norm = torch.from_numpy(np.array(mean_norm_list).astype(np.float32))
    side_chain_vecs = torch.from_numpy(
        np.concatenate([np.expand_dims(n_rel_pos, axis=1), np.expand_dims(c_rel_pos, axis=1)], axis=1))

    complex_graph['receptor'].x = torch.cat([node_feat, torch.tensor(lm_embeddings)], axis=1) if lm_embeddings is not None else node_feat
    complex_graph['receptor'].pos = torch.from_numpy(c_alpha_coords).float()
    complex_graph['receptor'].mu_r_norm = mu_r_norm
    complex_graph['receptor'].side_chain_vecs = side_chain_vecs.float()
    complex_graph['receptor', 'rec_contact', 'receptor'].edge_index = torch.from_numpy(np.asarray([src_list, dst_list]))

    return
|
|
|
|
def rec_atom_featurizer(rec):
    """Encode every atom of ``rec`` as four categorical feature indices.

    The columns follow ``rec_atom_feature_dims``: residue name, atomic number,
    two-character atom type and full atom name. Values missing from a vocabulary
    map to its last entry (see ``safe_index``).

    Returns:
        List with one ``[residue, atomic_num, atom_type_2, atom_type_3]`` list of
        integers per atom.
    """
    atom_feats = []
    for atom in rec.get_atoms():
        atom_name, element = atom.name, atom.element
        # NOTE(review): an element field of 'CD' is treated as carbon — presumably a
        # mis-parsed carbon atom name rather than cadmium; confirm.
        if element == 'CD':
            element = 'C'
        assert element != ''
        try:
            atomic_num = periodic_table.GetAtomicNumber(element)
        except Exception:
            # Symbols the periodic table rejects get -1, which safe_index maps to
            # the last entry of the atomic-number vocabulary.
            atomic_num = -1
        atom_feats.append([
            safe_index(allowable_features['possible_amino_acids'], atom.get_parent().get_resname()),
            safe_index(allowable_features['possible_atomic_num_list'], atomic_num),
            safe_index(allowable_features['possible_atom_type_2'], (atom_name + '*')[:2]),
            safe_index(allowable_features['possible_atom_type_3'], atom_name),
        ])

    return atom_feats
|
|
|
|
def get_rec_graph(rec, rec_coords, c_alpha_coords, n_coords, c_coords, complex_graph, rec_radius, c_alpha_max_neighbors=None, all_atoms=False,
                  atom_radius=5, atom_max_neighbors=None, remove_hs=False, lm_embeddings=None):
    """Build the receptor graph, dispatching on ``all_atoms``.

    With ``all_atoms`` the residue and atom level graph of ``get_fullrec_graph``
    is built; otherwise only the residue-level graph of ``get_calpha_graph``.
    ``rec_radius`` is the C-alpha cutoff in both cases.
    """
    if not all_atoms:
        return get_calpha_graph(rec, c_alpha_coords, n_coords, c_coords, complex_graph,
                                rec_radius, c_alpha_max_neighbors, lm_embeddings=lm_embeddings)

    return get_fullrec_graph(rec, rec_coords, c_alpha_coords, n_coords, c_coords, complex_graph,
                             c_alpha_cutoff=rec_radius, c_alpha_max_neighbors=c_alpha_max_neighbors,
                             atom_cutoff=atom_radius, atom_max_neighbors=atom_max_neighbors,
                             remove_hs=remove_hs, lm_embeddings=lm_embeddings)
|
|
|
|
def get_fullrec_graph(rec, rec_coords, c_alpha_coords, n_coords, c_coords, complex_graph, c_alpha_cutoff=20,
                      c_alpha_max_neighbors=None, atom_cutoff=5, atom_max_neighbors=None, remove_hs=False, lm_embeddings=None):
    """Write the receptor graph with both residue and atom nodes into ``complex_graph``.

    The residue-level part ('receptor' nodes and 'rec_contact' edges) is built by
    ``get_calpha_graph``. On top of that this function sets ``x`` and ``pos`` on
    the 'atom' store, radius-graph edges between atoms on
    ('atom', 'atom_contact', 'atom'), and one edge from every atom to its residue
    on ('atom', 'atom_rec_contact', 'receptor').

    Args:
        rec: receptor passed to the residue and atom featurizers.
        rec_coords: list with one array of atom coordinates per residue.
        c_alpha_coords, n_coords, c_coords: arrays with one row per residue.
        complex_graph: graph object that receives the attributes above.
        c_alpha_cutoff: distance threshold for residue neighbours.
        c_alpha_max_neighbors: optional cap on residue neighbours.
        atom_cutoff: radius of the atom-atom graph.
        atom_max_neighbors: cap on atom neighbours; 1000 is used when falsy.
        remove_hs: drop atoms whose atomic-number feature index is 0 (hydrogen).
        lm_embeddings: optional per-residue embeddings appended to the residue features.

    Raises:
        ValueError: if fewer than two residues are given (from ``get_calpha_graph``).
    """
    # Residue-level graph: identical to the C-alpha-only graph, so reuse its builder
    # instead of duplicating it.
    get_calpha_graph(rec, c_alpha_coords, n_coords, c_coords, complex_graph, cutoff=c_alpha_cutoff,
                     max_neighbor=c_alpha_max_neighbors, lm_embeddings=lm_embeddings)

    # Residue index of every atom, in the same order as the flattened coordinates.
    src_c_alpha_idx = np.concatenate(
        [np.asarray([res_idx] * len(res_atoms)) for res_idx, res_atoms in enumerate(rec_coords)])
    atom_feat = torch.from_numpy(np.asarray(rec_atom_featurizer(rec)))
    atom_coords = torch.from_numpy(np.concatenate(rec_coords, axis=0)).float()

    if remove_hs:
        # Column 1 is the atomic-number index; index 0 is atomic number 1 (hydrogen).
        not_hs = (atom_feat[:, 1] != 0)
        src_c_alpha_idx = src_c_alpha_idx[not_hs.numpy()]
        atom_feat = atom_feat[not_hs]
        atom_coords = atom_coords[not_hs]

    atoms_edge_index = radius_graph(atom_coords, atom_cutoff, max_num_neighbors=atom_max_neighbors if atom_max_neighbors else 1000)
    atom_res_edge_index = torch.from_numpy(np.asarray([np.arange(len(atom_feat)), src_c_alpha_idx])).long()

    complex_graph['atom'].x = atom_feat
    complex_graph['atom'].pos = atom_coords
    complex_graph['atom', 'atom_contact', 'atom'].edge_index = atoms_edge_index
    complex_graph['atom', 'atom_rec_contact', 'receptor'].edge_index = atom_res_edge_index

    return
|
|
def write_mol_with_coords(mol, new_coords, path):
    """Overwrite the conformer coordinates of ``mol`` and write it to an SD file.

    ``mol`` is modified in place: the position of atom ``i`` of its conformer is
    set to ``new_coords[i]``.

    Args:
        mol: RDKit molecule that has a conformer.
        new_coords: array-like with one (x, y, z) row per atom.
        path: destination of the SD file.
    """
    # Convert once instead of copying the whole array for every atom.
    coords = np.asarray(new_coords, dtype=np.double)
    writer = Chem.SDWriter(path)
    try:
        conf = mol.GetConformer()
        for i in range(mol.GetNumAtoms()):
            x, y, z = coords[i]
            conf.SetAtomPosition(i, Point3D(float(x), float(y), float(z)))
        writer.write(mol)
    finally:
        # Close the writer even when updating or writing the molecule fails.
        writer.close()
|
|
def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=False):
    """Load a molecule from a .mol2, .sdf, .pdbqt or .pdb file with RDKit.

    The file is read without sanitization and with hydrogens kept; the optional
    post-processing steps are then applied in order.

    Args:
        molecule_file: path of the file; the extension selects the reader.
        sanitize: sanitize the molecule after loading.
        calc_charges: compute Gasteiger charges (this also sanitizes). A failure
            only triggers a warning.
        remove_hs: remove hydrogens after the steps above.

    Returns:
        The molecule, or ``None`` when a post-processing step raised (the
        exception is printed).

    Raises:
        ValueError: if the extension is not one of the supported formats.
    """
    if molecule_file.endswith('.mol2'):
        mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.sdf'):
        # NOTE(review): this print and the two below look like leftover debug output.
        print(molecule_file)
        supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
        mol = supplier[0]
        print(mol)
    elif molecule_file.endswith('.pdbqt'):
        # Keep only the first 66 characters of every line and parse the result as PDB.
        with open(molecule_file) as file:
            pdb_block = ''.join('{}\n'.format(line[:66]) for line in file)
        mol = Chem.MolFromPDBBlock(pdb_block, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.pdb'):
        mol = Chem.MolFromPDBFile(molecule_file, sanitize=False, removeHs=False)
    else:
        # The exception used to be returned instead of raised, so callers received a
        # ValueError instance in place of a molecule.
        raise ValueError('Expect the format of the molecule_file to be '
                         'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))

    print(sanitize, calc_charges, remove_hs)

    try:
        if sanitize or calc_charges:
            Chem.SanitizeMol(mol)

        if calc_charges:
            try:
                AllChem.ComputeGasteigerCharges(mol)
            except Exception:
                warnings.warn('Unable to compute charges for the molecule.')

        if remove_hs:
            mol = Chem.RemoveHs(mol, sanitize=sanitize)
    except Exception as e:
        print(e)
        return None

    return mol
|
|
|
|
def read_sdf_or_mol2(sdf_fileName, mol2_fileName):
    """Load a ligand from its SDF file, falling back to its MOL2 file.

    Each file is read without sanitization, then sanitized and stripped of
    hydrogens. The MOL2 file is only tried when that fails for the SDF file.

    Returns:
        ``(mol, problem)``; ``problem`` is True when sanitizing or removing
        hydrogens failed for both files, in which case ``mol`` is whatever the
        MOL2 reader returned.
    """
    readers = (
        (Chem.MolFromMolFile, sdf_fileName),
        (Chem.MolFromMol2File, mol2_fileName),
    )
    for reader, file_name in readers:
        mol = reader(file_name, sanitize=False)
        try:
            Chem.SanitizeMol(mol)
            return Chem.RemoveHs(mol), False
        except Exception:
            # Fall through to the next reader, or to the failure return below.
            continue

    return mol, True
|
|