| import subprocess |
|
|
| import numpy as np |
| import tempfile |
| from pathlib import Path |
| from tqdm import tqdm |
| from rdkit import Chem, DataStructs |
| from rdkit.Chem import AllChem |
| from rdkit.Chem import Descriptors, Crippen, Lipinski, QED |
| from rdkit.Chem import AtomKekulizeException, AtomValenceException, \ |
| KekulizeException, MolSanitizeException |
| from src.analysis.SA_Score.sascorer import calculateScore |
| from src.utils import write_sdf_file |
|
|
| from copy import deepcopy |
|
|
| from pdb import set_trace |
|
|
|
|
class CategoricalDistribution:
    """
    Discrete probability distribution over a fixed set of categories.

    ``self.p`` holds the normalized reference probabilities and
    ``self.mapping`` a private (deep) copy of the category -> index mapping.
    """
    # Small constant added to the denominator and inside the logarithm of
    # the KL divergence to avoid division by zero and log(0).
    EPS = 1e-10

    def __init__(self, histogram_dict, mapping):
        """
        :param histogram_dict: dict mapping a category to its count
        :param mapping: dict mapping a category to its position in the
            probability vector
        """
        self.mapping = deepcopy(mapping)

        counts = np.zeros(len(mapping))
        for category, count in histogram_dict.items():
            counts[mapping[category]] = count

        # Normalize the counts to probabilities
        self.p = counts / counts.sum()

    def kl_divergence(self, other_sample):
        """
        Compute ``-sum(p * log(q / (p + EPS) + EPS))`` where ``q`` is the
        empirical distribution of ``other_sample``.

        :param other_sample: iterable of category indices; the values are
            used directly as positions in the histogram (they are not
            translated through ``self.mapping``)
        :return: the KL divergence estimate as a NumPy scalar
        """
        counts = np.zeros(len(self.mapping))
        for index in other_sample:
            counts[index] += 1

        # Empirical distribution of the sample
        q = counts / counts.sum()

        ratio = q / (self.p + self.EPS)
        return -np.sum(self.p * np.log(ratio + self.EPS))
|
|
|
|
def check_mol(rdmol):
    """
    Classify a molecule by the outcome of RDKit sanitization.

    See also: https://www.rdkit.org/docs/RDKit_Book.html#molecular-sanitization

    :param rdmol: RDKit molecule or None
    :return: 'is_none' if ``rdmol`` is None, 'valid' if sanitization
        succeeds, otherwise the class name of the raised exception
    """
    if rdmol is None:
        return 'is_none'

    # Sanitization is run on a new Chem.Mol built from the input, not on the
    # input object itself.
    mol_copy = Chem.Mol(rdmol)
    try:
        Chem.SanitizeMol(mol_copy)
    except ValueError as err:
        # Every ValueError raised here is expected to be a sanitization error
        assert isinstance(err, MolSanitizeException)
        return type(err).__name__
    else:
        return 'valid'
|
|
|
|
def validity_analysis(rdmol_list):
    """
    Count how many molecules fall into each sanitization outcome.

    For explanations, see: https://www.rdkit.org/docs/RDKit_Book.html#molecular-sanitization

    :param rdmol_list: list of RDKit molecules (entries may be None)
    :return: dict with the number of molecules per outcome; every outcome
        returned by ``check_mol`` that has no dedicated key is counted
        under 'other'
    """
    categories = (
        'AtomValenceException',
        'AtomKekulizeException',
        'KekulizeException',
        'other',
        'valid',
    )
    counts = dict.fromkeys(categories, 0)

    for rdmol in rdmol_list:
        outcome = check_mol(rdmol)
        bucket = outcome if outcome in counts else 'other'
        counts[bucket] += 1

    # Every molecule must have been counted exactly once
    assert sum(counts.values()) == len(rdmol_list)

    return counts
|
|
|
|
class MoleculeValidity:
    """
    Computes validity and connectivity statistics for a list of molecules.
    """

    def __init__(self, connectivity_thresh=1.0):
        """
        :param connectivity_thresh: minimum fraction of atoms that the
            largest fragment must contain for a molecule to count as
            connected
        """
        self.connectivity_thresh = connectivity_thresh

    def compute_validity(self, generated):
        """
        :param generated: list of RDKit molecules
        :return: (list of copies of the valid molecules, fraction of valid
            molecules); ([], 0.0) for empty input
        """
        if len(generated) < 1:
            return [], 0.0

        # Keep copies of all molecules that pass RDKit sanitization
        valid = [Chem.Mol(mol) for mol in generated if check_mol(mol) == 'valid']
        return valid, len(valid) / len(generated)

    def compute_connectivity(self, valid):
        """
        Consider molecule connected if its largest fragment contains at
        least <self.connectivity_thresh * 100>% of all atoms.

        :param valid: list of valid RDKit molecules; ``Chem.SanitizeMol``
            is called on each of them
        :return: (list with the largest fragment of every connected
            molecule, number of connected molecules divided by
            ``len(valid)``); ([], 0.0) for empty input
        """
        if len(valid) < 1:
            return [], 0.0

        for mol in valid:
            Chem.SanitizeMol(mol)

        connected = []
        for mol in valid:

            # Molecules without atoms are skipped but still count in the
            # denominator of the returned fraction.
            if mol.GetNumAtoms() < 1:
                continue

            try:
                mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True)
            except MolSanitizeException as e:
                print('Error while computing connectivity:', e)
                continue

            largest_frag = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
            if largest_frag.GetNumAtoms() / mol.GetNumAtoms() >= self.connectivity_thresh:
                connected.append(largest_frag)

        return connected, len(connected) / len(valid)

    def __call__(self, rdmols, verbose=False):
        """
        Run the validity and connectivity analysis.

        :param rdmols: list of RDKit molecules
        :param verbose: if True, print validity and connectivity
        :return: dict with the keys 'n_total', 'n_valid', 'validity',
            'n_connected', 'connectivity' and 'valid_and_connected'
        """
        n_total = len(rdmols)

        results = {}
        results['n_total'] = n_total

        valid, validity = self.compute_validity(rdmols)
        results['n_valid'] = len(valid)
        results['validity'] = validity

        connected, connectivity = self.compute_connectivity(valid)
        results['n_connected'] = len(connected)
        results['connectivity'] = connectivity

        # Guard against empty input: report 0.0 like the other ratios do
        # instead of raising ZeroDivisionError.
        results['valid_and_connected'] = (
            results['n_connected'] / n_total if n_total > 0 else 0.0
        )

        if verbose:
            print(f"Validity over {results['n_total']} molecules: {validity * 100 :.2f}%")
            print(f"Connectivity over {results['n_valid']} valid molecules: {connectivity * 100 :.2f}%")

        return results
|
|
|
|
class MolecularMetrics:
    """
    Per-molecule evaluation: validity, connectivity, QED, SA, LogP and the
    number of fulfilled Lipinski-style rules.
    """

    def __init__(self, connectivity_thresh=1.0):
        """
        :param connectivity_thresh: minimum fraction of atoms that the
            largest fragment must contain for a molecule to count as
            connected
        """
        self.connectivity_thresh = connectivity_thresh

    @staticmethod
    def is_valid(rdmol):
        """
        :param rdmol: RDKit molecule or None
        :return: True if the molecule has at least one atom and a copy of
            it passes ``Chem.SanitizeMol``; False otherwise (also for None)
        """
        # None is treated as invalid instead of raising AttributeError
        if rdmol is None or rdmol.GetNumAtoms() < 1:
            return False

        # Sanitize a new Chem.Mol so that this check leaves `rdmol` as is
        mol_copy = Chem.Mol(rdmol)
        try:
            Chem.SanitizeMol(mol_copy)
        except ValueError:
            return False

        return True

    def is_connected(self, rdmol):
        """
        :param rdmol: RDKit molecule
        :return: True if the largest fragment contains at least
            <self.connectivity_thresh * 100>% of all atoms; False
            otherwise and for molecules without atoms
        """
        if rdmol.GetNumAtoms() < 1:
            return False

        mol_frags = Chem.rdmolops.GetMolFrags(rdmol, asMols=True)

        largest_frag = max(mol_frags, default=rdmol, key=lambda m: m.GetNumAtoms())
        frag_fraction = largest_frag.GetNumAtoms() / rdmol.GetNumAtoms()
        return bool(frag_fraction >= self.connectivity_thresh)

    @staticmethod
    def calculate_qed(rdmol):
        """Return the value of RDKit's ``QED.qed`` for the molecule."""
        return QED.qed(rdmol)

    @staticmethod
    def calculate_sa(rdmol):
        """Return the value of ``sascorer.calculateScore`` for the molecule."""
        return calculateScore(rdmol)

    @staticmethod
    def calculate_logp(rdmol):
        """Return the value of RDKit's ``Crippen.MolLogP`` for the molecule."""
        return Crippen.MolLogP(rdmol)

    @staticmethod
    def calculate_lipinski(rdmol):
        """
        Count how many of the following five rules the molecule fulfils:
        exact molecular weight < 500, at most 5 H-bond donors, at most 10
        H-bond acceptors, -2 <= LogP <= 5, at most 10 rotatable bonds.

        :return: number of fulfilled rules (0 to 5) as a NumPy integer
        """
        rule_1 = Descriptors.ExactMolWt(rdmol) < 500
        rule_2 = Lipinski.NumHDonors(rdmol) <= 5
        rule_3 = Lipinski.NumHAcceptors(rdmol) <= 10
        # Compute LogP first and then test both bounds. Binding it with
        # `:=` inside the comparison would store the boolean result of the
        # first comparison, which makes the upper-bound test always true.
        logp = Crippen.MolLogP(rdmol)
        rule_4 = -2 <= logp <= 5
        rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol) <= 10
        return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])

    def __call__(self, rdmol):
        """
        Evaluate a single molecule.

        :param rdmol: RDKit molecule (or None); ``Chem.SanitizeMol`` is
            called on it if it is valid
        :return: dict with the keys 'valid', 'connected', 'qed', 'sa',
            'logp' and 'lipinski'; all entries except 'valid' are None for
            invalid molecules
        """
        valid = self.is_valid(rdmol)

        if not valid:
            return {
                'valid': valid,
                'connected': None,
                'qed': None,
                'sa': None,
                'logp': None,
                'lipinski': None
            }

        Chem.SanitizeMol(rdmol)

        return {
            'valid': valid,
            'connected': self.is_connected(rdmol),
            'qed': self.calculate_qed(rdmol),
            'sa': self.calculate_sa(rdmol),
            'logp': self.calculate_logp(rdmol),
            'lipinski': self.calculate_lipinski(rdmol)
        }
|
|
|
|
class Diversity:
    """
    Average pairwise Tanimoto dissimilarity of a set of molecules, based on
    RDKit fingerprints (``Chem.RDKFingerprint``).
    """

    @staticmethod
    def similarity(fp1, fp2):
        """Return the Tanimoto similarity of two fingerprints."""
        return DataStructs.TanimotoSimilarity(fp1, fp2)

    def get_fingerprint(self, mol):
        """Return ``Chem.RDKFingerprint`` of the molecule."""
        return Chem.RDKFingerprint(mol)

    def __call__(self, pocket_mols):
        """
        :param pocket_mols: list of RDKit molecules
        :return: mean of (1 - Tanimoto similarity) over all unordered pairs
            of molecules; 0.0 if fewer than two molecules are given
        """
        if len(pocket_mols) < 2:
            return 0.0

        fingerprints = [self.get_fingerprint(mol) for mol in pocket_mols]

        dissimilarity_sum = 0
        n_pairs = 0
        # Visit every unordered pair (a, b) with a before b exactly once
        for idx, fp_a in enumerate(fingerprints):
            for fp_b in fingerprints[idx + 1:]:
                dissimilarity_sum += 1 - self.similarity(fp_a, fp_b)
                n_pairs += 1

        return dissimilarity_sum / n_pairs
|
|
|
|
class MoleculeUniqueness:
    """Fraction of distinct entries in a list of SMILES strings."""

    def __call__(self, smiles_list):
        """
        :param smiles_list: list of SMILES strings
        :return: number of distinct strings divided by the length of the
            list; 0.0 for an empty list
        """
        n_total = len(smiles_list)
        if n_total < 1:
            return 0.0

        n_unique = len(set(smiles_list))
        return n_unique / n_total
|
|
|
|
class MoleculeNovelty:
    """Fraction of SMILES strings that do not occur in a reference set."""

    def __init__(self, reference_smiles):
        """
        :param reference_smiles: list of SMILES strings
        """
        # Stored as a set for fast membership tests
        self.reference_smiles = set(reference_smiles)

    def __call__(self, smiles_list):
        """
        :param smiles_list: list of SMILES strings
        :return: fraction of entries that are not contained in the
            reference set (duplicates are counted individually); 0.0 for
            an empty list
        """
        n_total = len(smiles_list)
        if n_total < 1:
            return 0.0

        n_novel = sum(
            1 for smi in smiles_list if smi not in self.reference_smiles
        )
        return n_novel / n_total
|
|
|
|
class MolecularProperties:
    """
    Molecular property evaluation (QED, SA, LogP, Lipinski-style rules and
    fingerprint-based diversity) for lists of RDKit molecules.
    """

    @staticmethod
    def calculate_qed(rdmol):
        """Return the value of RDKit's ``QED.qed`` for the molecule."""
        return QED.qed(rdmol)

    @staticmethod
    def calculate_sa(rdmol):
        """Return the value of ``sascorer.calculateScore`` for the molecule."""
        return calculateScore(rdmol)

    @staticmethod
    def calculate_logp(rdmol):
        """Return the value of RDKit's ``Crippen.MolLogP`` for the molecule."""
        return Crippen.MolLogP(rdmol)

    @staticmethod
    def calculate_lipinski(rdmol):
        """
        Count how many of the following five rules the molecule fulfils:
        exact molecular weight < 500, at most 5 H-bond donors, at most 10
        H-bond acceptors, -2 <= LogP <= 5, at most 10 rotatable bonds.

        :return: number of fulfilled rules (0 to 5) as a NumPy integer
        """
        rule_1 = Descriptors.ExactMolWt(rdmol) < 500
        rule_2 = Lipinski.NumHDonors(rdmol) <= 5
        rule_3 = Lipinski.NumHAcceptors(rdmol) <= 10
        # Compute LogP first and then test both bounds. Binding it with
        # `:=` inside the comparison would store the boolean result of the
        # first comparison, which makes the upper-bound test always true.
        logp = Crippen.MolLogP(rdmol)
        rule_4 = -2 <= logp <= 5
        rule_5 = Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol) <= 10
        return np.sum([int(a) for a in [rule_1, rule_2, rule_3, rule_4, rule_5]])

    @classmethod
    def calculate_diversity(cls, pocket_mols):
        """
        :param pocket_mols: list of RDKit molecules
        :return: mean of (1 - similarity) over all unordered pairs of
            molecules; 0.0 if fewer than two molecules are given
        """
        if len(pocket_mols) < 2:
            return 0.0

        div = 0
        total = 0
        for i in range(len(pocket_mols)):
            for j in range(i + 1, len(pocket_mols)):
                div += 1 - cls.similarity(pocket_mols[i], pocket_mols[j])
                total += 1
        return div / total

    @staticmethod
    def similarity(mol_a, mol_b):
        """
        Return the Tanimoto similarity of the RDKit fingerprints
        (``Chem.RDKFingerprint``) of two molecules.
        """
        fp1 = Chem.RDKFingerprint(mol_a)
        fp2 = Chem.RDKFingerprint(mol_b)
        return DataStructs.TanimotoSimilarity(fp1, fp2)

    def evaluate_pockets(self, pocket_rdmols, verbose=False):
        """
        Run full evaluation
        Args:
            pocket_rdmols: list of lists, the inner list contains all RDKit
                molecules generated for a pocket. ``Chem.SanitizeMol`` is
                called on every molecule and its exceptions are not caught
                here.
            verbose: if True, print mean and standard deviation of every
                property
        Returns:
            QED, SA, LogP, Lipinski (per molecule), and Diversity (per pocket)
        """

        for pocket in pocket_rdmols:
            for mol in pocket:
                Chem.SanitizeMol(mol)

        all_qed = []
        all_sa = []
        all_logp = []
        all_lipinski = []
        per_pocket_diversity = []
        for pocket in tqdm(pocket_rdmols):
            all_qed.append([self.calculate_qed(mol) for mol in pocket])
            all_sa.append([self.calculate_sa(mol) for mol in pocket])
            all_logp.append([self.calculate_logp(mol) for mol in pocket])
            all_lipinski.append([self.calculate_lipinski(mol) for mol in pocket])
            per_pocket_diversity.append(self.calculate_diversity(pocket))

        if verbose:
            # Per-molecule values pooled over all pockets
            qed_flattened = [x for px in all_qed for x in px]
            sa_flattened = [x for px in all_sa for x in px]
            logp_flattened = [x for px in all_logp for x in px]
            lipinski_flattened = [x for px in all_lipinski for x in px]

            # "\\pm" prints a literal "\pm"; the backslash is escaped
            # because "\p" is not a valid escape sequence.
            print(f"{sum([len(p) for p in pocket_rdmols])} molecules from "
                  f"{len(pocket_rdmols)} pockets evaluated.")
            print(f"QED: {np.mean(qed_flattened):.3f} \\pm {np.std(qed_flattened):.2f}")
            print(f"SA: {np.mean(sa_flattened):.3f} \\pm {np.std(sa_flattened):.2f}")
            print(f"LogP: {np.mean(logp_flattened):.3f} \\pm {np.std(logp_flattened):.2f}")
            print(f"Lipinski: {np.mean(lipinski_flattened):.3f} \\pm {np.std(lipinski_flattened):.2f}")
            print(f"Diversity: {np.mean(per_pocket_diversity):.3f} \\pm {np.std(per_pocket_diversity):.2f}")

        return all_qed, all_sa, all_logp, all_lipinski, per_pocket_diversity

    def __call__(self, rdmols):
        """
        Run full evaluation and return mean of each property
        Args:
            rdmols: list of RDKit molecules. ``Chem.SanitizeMol`` is called
                on every molecule; molecules that fail are reported and
                excluded from the statistics.
        Returns:
            Dictionary with mean QED, SA, LogP, Lipinski, and Diversity values.
            All values are 0.0 if no molecule is left to evaluate.
        """

        valid_mols = []
        for mol in rdmols:
            try:
                Chem.SanitizeMol(mol)
            except ValueError:
                print("Tried to analyze invalid molecule")
            else:
                valid_mols.append(mol)

        # Covers both empty input and the case that every molecule failed
        # sanitization; np.mean of an empty list would return NaN.
        if len(valid_mols) < 1:
            return {'QED': 0.0, 'SA': 0.0, 'LogP': 0.0, 'Lipinski': 0.0,
                    'Diversity': 0.0}

        qed = np.mean([self.calculate_qed(mol) for mol in valid_mols])
        sa = np.mean([self.calculate_sa(mol) for mol in valid_mols])
        logp = np.mean([self.calculate_logp(mol) for mol in valid_mols])
        lipinski = np.mean([self.calculate_lipinski(mol) for mol in valid_mols])
        diversity = self.calculate_diversity(valid_mols)

        return {'QED': qed, 'SA': sa, 'LogP': logp, 'Lipinski': lipinski,
                'Diversity': diversity}
|
|
|
|
def compute_gnina_scores(ligands, receptors, gnina):
    """
    Run gnina with ``--minimize`` on ligand/receptor pairs and average the
    scores it writes to the output SDF file.

    :param ligands: iterable of ligands; each one is written with
        ``write_sdf_file``
    :param receptors: iterable of receptors, paired with ``ligands`` via
        ``zip`` (extra entries of the longer iterable are ignored); each
        one is written with ``Chem.MolToPDBFile``
    :param gnina: gnina command, inserted at the start of a shell command
        line
    :return: dict mapping every metric name to its mean over all pairs for
        which an output molecule could be read; 0 if there is none
    """
    metrics = ['minimizedAffinity', 'minimizedRMSD', 'CNNscore', 'CNNaffinity', 'CNN_VS', 'CNNaffinity_variance']
    out = {m: [] for m in metrics}

    def _has_content(path):
        # True if the file exists and is not empty
        return path.exists() and path.stat().st_size > 0

    with tempfile.TemporaryDirectory() as tmpdir:
        # The same three file names are used for every pair
        in_ligand_path = Path(tmpdir, 'in_ligand.sdf')
        out_ligand_path = Path(tmpdir, 'out_ligand.sdf')
        receptor_path = Path(tmpdir, 'receptor.pdb')

        # NOTE: the command is run through the shell. It is kept as a
        # single string because `gnina` may consist of several tokens;
        # `gnina` must therefore come from a trusted source.
        cmd = (
            f'{gnina} -r {receptor_path} -l {in_ligand_path} '
            f'--minimize --seed 42 -o {out_ligand_path} --no_gpu 1> /dev/null'
        )

        for ligand, receptor in zip(tqdm(ligands, desc='Docking'), receptors):
            # Remove the files of the previous pair. Otherwise a failed
            # write or a failed gnina run would go unnoticed and the data
            # of the previous pair would be used (and counted) again.
            for path in (in_ligand_path, out_ligand_path, receptor_path):
                path.unlink(missing_ok=True)

            write_sdf_file(in_ligand_path, [ligand], catch_errors=True)
            Chem.MolToPDBFile(receptor, str(receptor_path))
            if not (_has_content(in_ligand_path) and _has_content(receptor_path)):
                continue

            subprocess.run(cmd, shell=True)
            if not _has_content(out_ligand_path):
                continue

            mol = Chem.SDMolSupplier(str(out_ligand_path), sanitize=False)[0]
            # The supplier yields None for entries it cannot parse
            if mol is None:
                continue

            for metric in metrics:
                out[metric].append(float(mol.GetProp(metric)))

    for metric in metrics:
        out[metric] = sum(out[metric]) / len(out[metric]) if len(out[metric]) > 0 else 0

    return out
|
|
|
|
def legacy_clash_score(rdmol1, rdmol2=None, margin=0.75):
    """
    Fraction of atoms of ``rdmol1`` that are involved in at least one clash.

    Two atoms clash if their distance is smaller than ``margin`` times the
    sum of their radii. The radii are looked up by element symbol in the
    tables defined below; a symbol missing from the table raises KeyError.

    INTERMOLECULAR CLASH SCORE
    If rdmol2 is provided, atoms of rdmol1 are compared with the atoms of
    rdmol2 using the van der Waals radii table.

    INTRAMOLECULAR CLASH SCORE
    If rdmol2 is not provided, atoms of rdmol1 are compared with all other
    atoms of rdmol1 using the covalent radii table. No connectivity
    information is used.

    :param rdmol1: molecule with a conformer
    :param rdmol2: optional second molecule with a conformer
    :param margin: scaling factor applied to the sum of the radii
    :return: mean of the per-atom clash flags of rdmol1
    """
    vdw_radii = {'N': 1.55, 'O': 1.52, 'C': 1.70, 'H': 1.10, 'S': 1.80, 'P': 1.80,
                 'Se': 1.90, 'K': 2.75, 'Na': 2.27, 'Mg': 1.73, 'Zn': 1.39, 'B': 1.92,
                 'Br': 1.85, 'Cl': 1.75, 'I': 1.98, 'F': 1.47}

    covalent_radii = {'H': 0.32, 'C': 0.60, 'N': 0.54, 'O': 0.53, 'F': 0.53, 'B': 0.73,
                      'Al': 1.11, 'Si': 1.02, 'P': 0.94, 'S': 0.94, 'Cl': 0.93, 'As': 1.06,
                      'Br': 1.09, 'I': 1.25, 'Hg': 1.33, 'Bi': 1.35}

    def _lookup_radii(mol, table):
        return np.array([table[atom.GetSymbol()] for atom in mol.GetAtoms()])

    intramolecular = rdmol2 is None
    pos1 = rdmol1.GetConformer().GetPositions()

    if intramolecular:
        r1 = _lookup_radii(rdmol1, covalent_radii)
        assert pos1.shape[0] == r1.shape[0]
        # The molecule is compared with itself
        pos2, r2 = pos1, r1
    else:
        pos2 = rdmol2.GetConformer().GetPositions()
        r1 = _lookup_radii(rdmol1, vdw_radii)
        assert pos1.shape[0] == r1.shape[0]
        r2 = _lookup_radii(rdmol2, vdw_radii)
        assert pos2.shape[0] == r2.shape[0]

    # Pairwise Euclidean distances, shape (n_atoms1, n_atoms2)
    delta = pos1[:, None, :] - pos2[None, :, :]
    dist = np.sqrt(np.sum(delta ** 2, axis=-1))
    if intramolecular:
        # An atom cannot clash with itself
        np.fill_diagonal(dist, np.inf)

    threshold = margin * (r1[:, None] + r2[None, :])
    atom_has_clash = np.any(dist < threshold, axis=1)
    return np.mean(atom_has_clash)
|
|
|
|
def clash_score(rdmol1, rdmol2=None, margin=0.75, ignore={'H'}):
    """
    Fraction of (non-ignored) atoms of ``rdmol1`` that are involved in at
    least one clash.

    Two atoms clash if their distance is smaller than ``margin`` times the
    sum of their radii. The radii are taken from RDKit's periodic table.

    INTERMOLECULAR CLASH SCORE
    If rdmol2 is provided, atoms of rdmol1 are compared with the atoms of
    rdmol2 using the radii returned by ``GetRvdw``.

    INTRAMOLECULAR CLASH SCORE
    If rdmol2 is not provided, atoms of rdmol1 are compared with all other
    atoms of rdmol1 using the radii returned by ``GetRcovalent``. No
    connectivity information is used.

    :param rdmol1: RDKit molecule with a conformer
    :param rdmol2: optional second RDKit molecule with a conformer
    :param margin: scaling factor applied to the sum of the radii
    :param ignore: element symbols that are excluded from both molecules
        before the clashes are computed (the default set is only read,
        never modified, by this function)
    :return: mean of the per-atom clash flags of the remaining atoms of
        rdmol1
    """
    periodic_table = AllChem.GetPeriodicTable()

    intramolecular = rdmol2 is None
    if intramolecular:
        # Compare the molecule with itself
        rdmol2 = rdmol1

    radius_of = (
        periodic_table.GetRcovalent if intramolecular
        else periodic_table.GetRvdw
    )

    def _positions_and_radii(rdmol):
        # Coordinates and radii of all atoms whose symbol is not ignored
        positions = rdmol.GetConformer().GetPositions()
        all_radii = np.array(
            [radius_of(atom.GetSymbol()) for atom in rdmol.GetAtoms()]
        )
        keep = np.array(
            [atom.GetSymbol() not in ignore for atom in rdmol.GetAtoms()]
        )

        positions = positions[keep]
        kept_radii = all_radii[keep]
        assert positions.shape[0] == kept_radii.shape[0]
        return positions, kept_radii

    pos1, r1 = _positions_and_radii(rdmol1)
    pos2, r2 = _positions_and_radii(rdmol2)

    # Pairwise Euclidean distances, shape (n_atoms1, n_atoms2)
    delta = pos1[:, None, :] - pos2[None, :, :]
    dist = np.sqrt(np.sum(delta ** 2, axis=-1))
    if intramolecular:
        # An atom cannot clash with itself
        np.fill_diagonal(dist, np.inf)

    threshold = margin * (r1[:, None] + r2[None, :])
    atom_has_clash = np.any(dist < threshold, axis=1)
    return np.mean(atom_has_clash)
|
|