| import numpy as np |
| from tqdm import tqdm |
| from rdkit import Chem, DataStructs |
| from rdkit.Chem import Descriptors, Crippen, Lipinski, QED |
| from analysis.SA_Score.sascorer import calculateScore |
|
|
| from analysis.molecule_builder import build_molecule |
| from copy import deepcopy |
|
|
|
|
class CategoricalDistribution:
    """Reference categorical distribution built from a histogram of counts.

    Used to compare the empirical distribution of a sample of bin indices
    against this fixed reference distribution via the KL divergence.
    """

    # Added inside the log so that bins the sample never hits give a
    # large-but-finite penalty (log(EPS)) instead of log(0) = -inf.
    EPS = 1e-10

    def __init__(self, histogram_dict, mapping):
        """
        Args:
            histogram_dict: dict mapping category keys to (unnormalized)
                counts. Every key must be present in ``mapping``.
            mapping: dict mapping each category key to its integer bin
                index; indices must be valid for an array of length
                ``len(mapping)``.
        """
        histogram = np.zeros(len(mapping))
        for k, v in histogram_dict.items():
            histogram[mapping[k]] = v

        # Normalize counts into a probability vector.
        self.p = histogram / histogram.sum()
        self.mapping = deepcopy(mapping)

    def kl_divergence(self, other_sample):
        """Compute KL(p || q), q being the empirical distribution of a sample.

        Args:
            other_sample: iterable of integer bin indices (already mapped
                through ``mapping``; NOT category keys).

        Returns:
            float: sum over bins with p_i > 0 of p_i * log(p_i / q_i).
            Bins with p_i == 0 contribute nothing (0 * log 0 = 0 convention).
            An empty sample cannot be normalized and yields NaN.
        """
        sample_histogram = np.zeros(len(self.mapping))
        for x in other_sample:
            sample_histogram[x] += 1

        # Normalize sample counts into a probability vector.
        q = sample_histogram / sample_histogram.sum()

        # Restrict to the support of p: dividing by p == 0 would produce
        # 0/0 or 0 * log(inf), both NaN, and poison the whole sum.
        support = self.p > 0
        p_s = self.p[support]
        return -np.sum(p_s * np.log(q[support] / p_s + self.EPS))
|
|
|
|
def rdmol_to_smiles(rdmol):
    """Return the SMILES of *rdmol* with stereochemistry and explicit Hs removed.

    The caller's molecule is not modified: all edits happen on a copy.
    """
    working_copy = Chem.Mol(rdmol)
    # Stereo is stripped in place on the copy, not on the input molecule.
    Chem.RemoveStereochemistry(working_copy)
    return Chem.MolToSmiles(Chem.RemoveHs(working_copy))
|
|
|
|
class BasicMolecularMetrics(object):
    """Validity, connectivity, uniqueness and novelty of generated molecules."""

    def __init__(self, dataset_info, dataset_smiles_list=None,
                 connectivity_thresh=1.0):
        """
        Args:
            dataset_info: dict; must contain 'atom_decoder'. It is also passed
                to ``build_molecule`` in ``evaluate``.
            dataset_smiles_list: optional iterable of reference SMILES used for
                novelty. Stored as a set for O(1) membership tests.
            connectivity_thresh: minimum fraction of atoms that must lie in the
                largest fragment for a molecule to count as connected.
        """
        self.atom_decoder = dataset_info['atom_decoder']
        if dataset_smiles_list is not None:
            dataset_smiles_list = set(dataset_smiles_list)
        self.dataset_smiles_list = dataset_smiles_list
        self.dataset_info = dataset_info
        self.connectivity_thresh = connectivity_thresh

    def compute_validity(self, generated):
        """Keep the molecules that pass RDKit sanitization.

        Args:
            generated: list of RDKit molecules; ``Chem.SanitizeMol`` is applied
                to each one in place.

        Returns:
            (list of valid molecules, fraction valid). Empty input gives
            ([], 0.0).
        """
        if len(generated) < 1:
            return [], 0.0

        valid = []
        for mol in generated:
            try:
                Chem.SanitizeMol(mol)
            except ValueError:
                # Presumably RDKit's sanitization failures (valence,
                # kekulization, ...) which subclass ValueError -- confirm for
                # the RDKit version in use.
                continue
            valid.append(mol)

        return valid, len(valid) / len(generated)

    def compute_connectivity(self, valid):
        """Consider a molecule connected if its largest fragment contains at
        least ``self.connectivity_thresh`` of all atoms (1.0, i.e. 100%, by
        default).

        Args:
            valid: list of sanitized RDKit molecules.

        Returns:
            (largest fragments of connected molecules, fraction connected,
            SMILES of those fragments). Always a 3-tuple, also for empty input.
        """
        if len(valid) < 1:
            # Same 3-tuple shape as the normal return: evaluate_rdmols
            # unpacks three values and would otherwise crash.
            return [], 0.0, []

        connected = []
        connected_smiles = []
        for mol in valid:
            n_atoms = mol.GetNumAtoms()
            if n_atoms == 0:
                # An atom-less molecule has no fragments; avoid 0 / 0.
                continue
            mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True)
            largest_mol = \
                max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
            if largest_mol.GetNumAtoms() / n_atoms >= self.connectivity_thresh:
                smiles = rdmol_to_smiles(largest_mol)
                if smiles is not None:
                    connected_smiles.append(smiles)
                    connected.append(largest_mol)

        return connected, len(connected_smiles) / len(valid), connected_smiles

    def compute_uniqueness(self, connected):
        """Deduplicate SMILES strings.

        Args:
            connected: list of SMILES strings.

        Returns:
            (unique SMILES in first-seen order, fraction unique). Empty input
            gives ([], 0.0). Independent of the reference dataset.
        """
        if len(connected) < 1:
            return [], 0.0

        # dict.fromkeys keeps first-seen order, so the result is deterministic
        # (set iteration order depends on string hash randomization).
        unique = list(dict.fromkeys(connected))
        return unique, len(unique) / len(connected)

    def compute_novelty(self, unique):
        """Find SMILES that are absent from the reference dataset.

        Args:
            unique: list of (deduplicated) SMILES strings.

        Returns:
            (novel SMILES, fraction novel). Returns ([], 0.0) for empty input
            or when no reference dataset was given.
        """
        if len(unique) < 1 or self.dataset_smiles_list is None:
            return [], 0.0

        novel = [smiles for smiles in unique
                 if smiles not in self.dataset_smiles_list]
        return novel, len(novel) / len(unique)

    def evaluate_rdmols(self, rdmols):
        """Run the full metric pipeline on RDKit molecules, printing each metric.

        Returns:
            ([validity, connectivity, uniqueness, novelty], [valid, connected])
        """
        valid, validity = self.compute_validity(rdmols)
        print(f"Validity over {len(rdmols)} molecules: {validity * 100:.2f}%")

        connected, connectivity, connected_smiles = \
            self.compute_connectivity(valid)
        print(f"Connectivity over {len(valid)} valid molecules: "
              f"{connectivity * 100:.2f}%")

        unique, uniqueness = self.compute_uniqueness(connected_smiles)
        print(f"Uniqueness over {len(connected)} connected molecules: "
              f"{uniqueness * 100:.2f}%")

        _, novelty = self.compute_novelty(unique)
        print(f"Novelty over {len(unique)} unique connected molecules: "
              f"{novelty * 100:.2f}%")

        return [validity, connectivity, uniqueness, novelty], [valid, connected]

    def evaluate(self, generated):
        """ generated: list of pairs (positions: n x 3, atom_types: n [int])
            the positions and atom types should already be masked.

        Each pair is turned into an RDKit molecule with ``build_molecule`` and
        the result is scored by ``evaluate_rdmols``.
        """
        rdmols = [build_molecule(*graph, self.dataset_info)
                  for graph in generated]
        return self.evaluate_rdmols(rdmols)
|
|
|
|
class MoleculeProperties:
    """Drug-likeness and diversity properties of RDKit molecules."""

    @staticmethod
    def calculate_qed(rdmol):
        """RDKit QED (quantitative estimate of drug-likeness)."""
        return QED.qed(rdmol)

    @staticmethod
    def calculate_sa(rdmol):
        """Synthetic accessibility rescaled as (10 - SA) / 9, rounded to 2 dp.

        Assuming calculateScore's usual 1 (easy) .. 10 (hard) range, this
        maps onto 1 .. 0, so higher is better -- confirm against sascorer.
        """
        sa = calculateScore(rdmol)
        return round((10 - sa) / 9, 2)

    @staticmethod
    def calculate_logp(rdmol):
        """Crippen logP."""
        return Crippen.MolLogP(rdmol)

    @staticmethod
    def calculate_lipinski(rdmol):
        """Count how many of five Lipinski-style rules *rdmol* satisfies (0-5).

        Rules: exact MW < 500, H-bond donors <= 5, H-bond acceptors <= 10,
        -2 <= logP <= 5, rotatable bonds <= 10.
        """
        # Bind the numeric logP first so BOTH bounds are applied to it (an
        # inline `logp := MolLogP(m) >= -2` would bind the comparison result).
        logp = Crippen.MolLogP(rdmol)
        rules = [
            Descriptors.ExactMolWt(rdmol) < 500,
            Lipinski.NumHDonors(rdmol) <= 5,
            Lipinski.NumHAcceptors(rdmol) <= 10,
            -2 <= logp <= 5,
            Chem.rdMolDescriptors.CalcNumRotatableBonds(rdmol) <= 10,
        ]
        return np.sum([int(rule) for rule in rules])

    @classmethod
    def calculate_diversity(cls, pocket_mols):
        """Mean pairwise (1 - Tanimoto similarity) of RDKit fingerprints.

        Returns 0.0 when fewer than two molecules are given.
        """
        if len(pocket_mols) < 2:
            return 0.0

        # Fingerprint each molecule once instead of twice per pair.
        fps = [Chem.RDKFingerprint(mol) for mol in pocket_mols]
        div = 0
        total = 0
        for i, fp_i in enumerate(fps):
            for fp_j in fps[i + 1:]:
                div += 1 - DataStructs.TanimotoSimilarity(fp_i, fp_j)
                total += 1
        return div / total

    @staticmethod
    def similarity(mol_a, mol_b):
        """Tanimoto similarity of the RDKit fingerprints of two molecules."""
        fp1 = Chem.RDKFingerprint(mol_a)
        fp2 = Chem.RDKFingerprint(mol_b)
        return DataStructs.TanimotoSimilarity(fp1, fp2)

    @staticmethod
    def _sanitize_all(rdmols):
        """Reject None entries, then apply Chem.SanitizeMol to each molecule."""
        for mol in rdmols:
            # Check before sanitizing: SanitizeMol(None) would fail first with
            # an unrelated argument error, so a check after it never fires.
            if mol is None:
                raise ValueError("only evaluate valid molecules")
            Chem.SanitizeMol(mol)

    @staticmethod
    def _print_mean_std(name, values):
        """Print '<name>: <mean> \\pm <std>' for a flat list of values."""
        # '\\pm' prints a literal backslash-pm (LaTeX) without relying on the
        # invalid escape sequence '\p'.
        print(f"{name}: {np.mean(values):.3f} \\pm {np.std(values):.2f}")

    def evaluate(self, pocket_rdmols):
        """
        Run full evaluation
        Args:
            pocket_rdmols: list of lists, the inner list contains all RDKit
                molecules generated for a pocket
        Returns:
            QED, SA, LogP, Lipinski (per molecule, nested per pocket), and
            Diversity (one value per pocket)
        Raises:
            ValueError: if any molecule is None.
        """
        for pocket in pocket_rdmols:
            self._sanitize_all(pocket)

        all_qed = []
        all_sa = []
        all_logp = []
        all_lipinski = []
        per_pocket_diversity = []
        for pocket in tqdm(pocket_rdmols):
            all_qed.append([self.calculate_qed(mol) for mol in pocket])
            all_sa.append([self.calculate_sa(mol) for mol in pocket])
            all_logp.append([self.calculate_logp(mol) for mol in pocket])
            all_lipinski.append([self.calculate_lipinski(mol) for mol in pocket])
            per_pocket_diversity.append(self.calculate_diversity(pocket))

        print(f"{sum([len(p) for p in pocket_rdmols])} molecules from "
              f"{len(pocket_rdmols)} pockets evaluated.")

        # Per-molecule metrics are summarized over all pockets combined.
        for name, nested in (("QED", all_qed), ("SA", all_sa),
                             ("LogP", all_logp), ("Lipinski", all_lipinski)):
            self._print_mean_std(name, [x for px in nested for x in px])
        self._print_mean_std("Diversity", per_pocket_diversity)

        return all_qed, all_sa, all_logp, all_lipinski, per_pocket_diversity

    def evaluate_mean(self, rdmols):
        """
        Run full evaluation and return mean of each property
        Args:
            rdmols: list of RDKit molecules
        Returns:
            QED, SA, LogP, Lipinski, and Diversity (all 0.0 for empty input)
        Raises:
            ValueError: if any molecule is None.
        """
        if len(rdmols) < 1:
            return 0.0, 0.0, 0.0, 0.0, 0.0

        self._sanitize_all(rdmols)

        qed = np.mean([self.calculate_qed(mol) for mol in rdmols])
        sa = np.mean([self.calculate_sa(mol) for mol in rdmols])
        logp = np.mean([self.calculate_logp(mol) for mol in rdmols])
        lipinski = np.mean([self.calculate_lipinski(mol) for mol in rdmols])
        diversity = self.calculate_diversity(rdmols)

        return qed, sa, logp, lipinski, diversity
|
|