Spaces:
Runtime error
Runtime error
| import numpy as np | |
| import torch | |
| import re | |
| import wandb | |
# RDKit is mandatory for this module: the bond table and every molecule
# builder defined below go through ``Chem``.
try:
    from rdkit import Chem
    print("Found rdkit, all good")
except ModuleNotFoundError:
    import warnings

    use_rdkit = False
    warnings.warn("Didn't find rdkit, this will fail")
    # ``use_rdkit`` was just set to False, so this assertion always fires and
    # aborts the import with an explicit message.
    assert use_rdkit, "Didn't find rdkit"
# Bond counts accepted as "stable" for each element symbol (read by
# check_stability). An int means exactly that many bonds; a list means any of
# the listed counts is accepted.
allowed_bonds = {
    'H': 1,
    'C': 4,
    'N': 3,
    'O': 2,
    'F': 1,
    'B': 3,
    'Al': 3,
    'Si': 4,
    'P': [3, 5],
    'S': 4,
    'Cl': 1,
    'As': 3,
    'Br': 1,
    'I': 1,
    'Hg': [1, 2],
    'Bi': [3, 5],
    'Se': [2, 4, 6],
}

# Integer edge type -> RDKit bond type. Index 0 is a placeholder: the builders
# only look up non-zero edge entries.
bond_dict = [
    None,
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC,
]

# Atomic number -> reference valence, used by
# build_molecule_with_partial_charges to decide whether an over-valent atom
# receives a +1 formal charge.
ATOM_VALENCY = {
    6: 4,
    7: 3,
    8: 2,
    9: 1,
    15: 3,
    16: 2,
    17: 1,
    35: 1,
    53: 1,
}
class BasicMolecularMetrics(object):
    """Validity, uniqueness and novelty metrics for generated molecular graphs.

    Every ``generated`` argument is a sequence of ``(atom_types, edge_types)``
    pairs. Each pair is converted into an RDKit molecule with the module-level
    builders before it is scored.
    """

    def __init__(self, dataset_info, train_smiles=None):
        """Store the dataset description and the optional reference SMILES.

        Args:
            dataset_info: object exposing ``atom_decoder``, which is forwarded
                to the molecule builders to map atom-type indices to the
                values given to ``Chem.Atom``.
            train_smiles: collection of reference SMILES used by
                ``compute_novelty``; ``None`` disables the novelty metric.
        """
        self.atom_decoder = dataset_info.atom_decoder
        self.dataset_info = dataset_info
        # Reference SMILES for the novelty metric (may be None).
        self.dataset_smiles_list = train_smiles

    @staticmethod
    def _largest_fragment_smiles(mol):
        """Return the SMILES of the connected component of ``mol`` with most atoms.

        Raises whatever ``GetMolFrags`` raises while sanitizing the fragments;
        callers decide which of those errors mark the molecule as invalid.
        """
        mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
        largest_mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
        return mol2smiles(largest_mol)

    def compute_validity(self, generated):
        """Score strict validity of the generated graphs.

        Args:
            generated: sequence of ``(atom_types, edge_types)`` pairs.

        Returns:
            ``(valid, validity, num_components, all_smiles)`` where ``valid``
            lists the SMILES of the largest fragment of each valid molecule,
            ``validity`` is ``len(valid) / len(generated)`` (0.0 for empty
            input), ``num_components`` is an array with the fragment count of
            every molecule whose fragments could be extracted, and
            ``all_smiles`` has one entry per input (``None`` when invalid).
        """
        valid = []
        num_components = []
        all_smiles = []
        for atom_types, edge_types in generated:
            mol = build_molecule(atom_types, edge_types, self.dataset_info.atom_decoder)
            smiles = mol2smiles(mol)
            try:
                mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
                num_components.append(len(mol_frags))
            except Exception:
                # Best effort: molecules whose fragments cannot be sanitized
                # are simply left out of the component statistics.
                pass
            if smiles is None:
                all_smiles.append(None)
                continue
            try:
                smiles = self._largest_fragment_smiles(mol)
                valid.append(smiles)
                all_smiles.append(smiles)
            except Chem.rdchem.AtomValenceException:
                print("Valence error in GetmolFrags")
                all_smiles.append(None)
            except Chem.rdchem.KekulizeException:
                print("Can't kekulize molecule")
                all_smiles.append(None)

        n_generated = len(generated)
        # Guard against empty input instead of raising ZeroDivisionError.
        validity = len(valid) / n_generated if n_generated else 0.0
        return valid, validity, np.array(num_components), all_smiles

    def compute_uniqueness(self, valid):
        """Return ``(unique_smiles, fraction)`` for a list of SMILES strings.

        ``fraction`` is the number of distinct strings divided by
        ``len(valid)``; an empty input yields ``([], 0.0)``.
        """
        if len(valid) == 0:
            return [], 0.0
        unique = set(valid)
        return list(unique), len(unique) / len(valid)

    def compute_novelty(self, unique):
        """Return the SMILES absent from the reference set and their fraction.

        Args:
            unique: list of (already de-duplicated) SMILES strings.

        Returns:
            ``(novel, fraction)``. When no reference SMILES were provided the
            legacy placeholder ``(1, 1)`` is returned unchanged so existing
            callers keep working. An empty ``unique`` yields ``([], 0.0)``.
        """
        if self.dataset_smiles_list is None:
            print("Dataset smiles is None, novelty computation skipped")
            return 1, 1
        # Build the lookup once: set membership is O(1) per query, whereas the
        # reference collection may be a long list.
        known = set(self.dataset_smiles_list)
        novel = [smiles for smiles in unique if smiles not in known]
        n_unique = len(unique)
        return novel, (len(novel) / n_unique if n_unique else 0.0)

    def compute_relaxed_validity(self, generated):
        """Score validity after allowing +1 formal charges on over-valent atoms.

        Molecules are built with ``build_molecule_with_partial_charges``.

        Returns:
            ``(valid, validity)`` with the SMILES of the largest fragment of
            each valid molecule and ``len(valid) / len(generated)`` (0.0 for
            empty input).
        """
        valid = []
        for atom_types, edge_types in generated:
            mol = build_molecule_with_partial_charges(atom_types, edge_types, self.dataset_info.atom_decoder)
            if mol2smiles(mol) is None:
                continue
            try:
                valid.append(self._largest_fragment_smiles(mol))
            except Chem.rdchem.AtomValenceException:
                print("Valence error in GetmolFrags")
            except Chem.rdchem.KekulizeException:
                print("Can't kekulize molecule")

        n_generated = len(generated)
        return valid, (len(valid) / n_generated if n_generated else 0.0)

    def evaluate(self, generated):
        """Compute and print all metrics for the generated graphs.

        Args:
            generated: sequence of ``(atom_types, edge_types)`` pairs.

        Returns:
            ``([validity, relaxed_validity, uniqueness, novelty], unique,
            dict(nc_min, nc_max, nc_mu), all_smiles)``. Uniqueness and novelty
            are computed on the relaxed-valid molecules. ``novelty`` is
            ``-1.0`` when there is no reference set or no valid molecule.
        """
        valid, validity, num_components, all_smiles = self.compute_validity(generated)
        has_components = len(num_components) > 0
        nc_mu = num_components.mean() if has_components else 0
        nc_min = num_components.min() if has_components else 0
        nc_max = num_components.max() if has_components else 0
        print(f"Validity over {len(generated)} molecules: {validity * 100 :.2f}%")
        print(f"Number of connected components of {len(generated)} molecules: min:{nc_min:.2f} mean:{nc_mu:.2f} max:{nc_max:.2f}")

        relaxed_valid, relaxed_validity = self.compute_relaxed_validity(generated)
        print(f"Relaxed validity over {len(generated)} molecules: {relaxed_validity * 100 :.2f}%")
        if relaxed_validity > 0:
            unique, uniqueness = self.compute_uniqueness(relaxed_valid)
            print(f"Uniqueness over {len(relaxed_valid)} valid molecules: {uniqueness * 100 :.2f}%")

            if self.dataset_smiles_list is not None:
                _, novelty = self.compute_novelty(unique)
                print(f"Novelty over {len(unique)} unique valid molecules: {novelty * 100 :.2f}%")
            else:
                novelty = -1.0
        else:
            novelty = -1.0
            uniqueness = 0.0
            unique = []
        return ([validity, relaxed_validity, uniqueness, novelty], unique,
                dict(nc_min=nc_min, nc_max=nc_max, nc_mu=nc_mu), all_smiles)
def mol2smiles(mol):
    """Return the SMILES string of ``mol``, or None if it cannot be sanitized.

    ``Chem.SanitizeMol`` is run on ``mol`` first; a ``ValueError`` raised by
    it is treated as "invalid molecule".
    """
    try:
        Chem.SanitizeMol(mol)
    except ValueError:
        smiles = None
    else:
        smiles = Chem.MolToSmiles(mol)
    return smiles
def build_molecule(atom_types, edge_types, atom_decoder, verbose=False):
    """Assemble an RDKit ``RWMol`` from atom-type indices and an edge-type matrix.

    Args:
        atom_types: iterable whose elements support ``.item()`` and yield an
            index into ``atom_decoder``.
        edge_types: matrix accepted by ``torch.triu``; a non-zero entry
            ``(i, j)`` above the diagonal is an index into ``bond_dict``.
        atom_decoder: maps an atom-type index to the value given to
            ``Chem.Atom``.
        verbose: print every atom and bond as it is added.

    Returns:
        The assembled, unsanitized ``Chem.RWMol``.
    """
    if verbose:
        print("building new molecule")

    molecule = Chem.RWMol()
    for atom_type in atom_types:
        type_index = atom_type.item()
        molecule.AddAtom(Chem.Atom(atom_decoder[type_index]))
        if verbose:
            print("Atom added: ", type_index, atom_decoder[type_index])

    # Only the upper triangle is read, so each atom pair is visited once.
    upper = torch.triu(edge_types)
    for pair in torch.nonzero(upper):
        begin, end = pair[0].item(), pair[1].item()
        if begin == end:
            # A diagonal entry would be a self-loop; skip it.
            continue
        order = upper[pair[0], pair[1]].item()
        molecule.AddBond(begin, end, bond_dict[order])
        if verbose:
            print("bond added:", begin, end, order, bond_dict[order])
    return molecule
def build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder, verbose=False):
    """Build an RDKit ``RWMol`` and add +1 formal charges to over-valent N/O/S.

    Atoms and bonds are added exactly as in ``build_molecule``. After every
    added bond the valency check is run; if it reports an atom with atomic
    number 7, 8 or 16 whose valence exceeds ``ATOM_VALENCY`` by exactly one,
    that atom gets a formal charge of +1.

    Args:
        atom_types: iterable whose elements support ``.item()`` and yield an
            index into ``atom_decoder``.
        edge_types: matrix accepted by ``torch.triu``; a non-zero entry
            ``(i, j)`` above the diagonal is an index into ``bond_dict``.
        atom_decoder: maps an atom-type index to the value given to
            ``Chem.Atom``.
        verbose: print progress information.

    Returns:
        The assembled ``Chem.RWMol``.
    """
    if verbose:
        print("\nbuilding new molecule")

    molecule = Chem.RWMol()
    for atom_type in atom_types:
        type_index = atom_type.item()
        molecule.AddAtom(Chem.Atom(atom_decoder[type_index]))
        if verbose:
            print("Atom added: ", type_index, atom_decoder[type_index])

    # Only the upper triangle is read, so each atom pair is visited once.
    upper = torch.triu(edge_types)
    for pair in torch.nonzero(upper):
        begin, end = pair[0].item(), pair[1].item()
        if begin == end:
            # A diagonal entry would be a self-loop; skip it.
            continue
        order = upper[pair[0], pair[1]].item()
        molecule.AddBond(begin, end, bond_dict[order])
        if verbose:
            print("bond added:", begin, end, order, bond_dict[order])

        # add formal charge to atom: e.g. [O+], [N+], [S+]
        # not support [O-], [N-], [S-], [NH+] etc.
        valency_ok, atomid_valence = check_valency(molecule)
        if verbose:
            print("flag, valence", valency_ok, atomid_valence)
        if valency_ok:
            continue

        assert len(atomid_valence) == 2
        atom_idx = atomid_valence[0]
        valence = atomid_valence[1]
        atomic_num = molecule.GetAtomWithIdx(atom_idx).GetAtomicNum()
        if verbose:
            print("atomic num of atom with a large valence", atomic_num)
        if atomic_num in (7, 8, 16) and (valence - ATOM_VALENCY[atomic_num]) == 1:
            molecule.GetAtomWithIdx(atom_idx).SetFormalCharge(1)
    return molecule
| # Functions from GDSS | |
def check_valency(mol):
    """Run RDKit's property sanitization on ``mol`` and report valence problems.

    Returns:
        ``(True, None)`` when sanitization succeeds. Otherwise
        ``(False, numbers)`` where ``numbers`` holds every integer found in
        the error message from its first ``#`` onwards; callers expect this to
        be ``[atom_index, valence]``.
    """
    try:
        Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
    except ValueError as err:
        message = str(err)
        tail = message[message.find('#'):]
        return False, [int(token) for token in re.findall(r'\d+', tail)]
    return True, None
def correct_mol(m):
    """Lower bond orders on over-valent atoms until ``m`` passes the valency check.

    While ``check_valency`` reports a problem, the bonds of the offending atom
    are ranked by bond-type value and the highest-ranked bond whose type is
    not 12 is reduced by one order (or removed once it would drop below a
    single bond). ``m`` is modified in place.

    Args:
        m: RDKit molecule supporting ``RemoveBond`` / ``AddBond``.

    Returns:
        ``(mol, no_correct)``. ``mol`` is ``m`` after correction, or ``None``
        when it cannot be repaired (the offending atom has no bonds, or all of
        its bonds have type 12). ``no_correct`` is True when ``m`` already
        passed the valency check before any change.
    """
    mol = m

    # True means the input was already valid and nothing had to be corrected.
    no_correct, _ = check_valency(mol)

    while True:
        valency_ok, atomid_valence = check_valency(mol)
        if valency_ok:
            break

        assert len(atomid_valence) == 2
        idx = atomid_valence[0]

        bonds = []
        # 12 is presumably int(BondType.AROMATIC) -- confirm against RDKit.
        n_type_12 = 0
        for b in mol.GetAtomWithIdx(idx).GetBonds():
            bond_type = int(b.GetBondType())
            bonds.append((b.GetIdx(), bond_type, b.GetBeginAtomIdx(), b.GetEndAtomIdx()))
            if bond_type == 12:
                n_type_12 += 1

        if not bonds:
            # No bond can be weakened on this atom: report failure instead of
            # raising IndexError on the lookups below.
            return None, no_correct

        bonds.sort(key=lambda tup: tup[1], reverse=True)
        if bonds[-1][1] == 12:
            # The smallest bond type is 12, so there is nothing to reduce.
            return None, no_correct

        # Skip as many leading entries as there are type-12 bonds, then lower
        # the selected bond by one order.
        _, bond_type, start, end = bonds[n_type_12]
        lowered = bond_type - 1
        mol.RemoveBond(start, end)
        if lowered >= 1:
            mol.AddBond(start, end, bond_dict[lowered])
    return mol, no_correct
def valid_mol_can_with_seg(m, largest_connected_comp=True):
    """Round-trip ``m`` through SMILES, optionally keeping only one fragment.

    Args:
        m: RDKit molecule, or None.
        largest_connected_comp: when True and the SMILES contains several
            ``.``-separated fragments, keep only the fragment with the longest
            SMILES string (the first one wins ties).

    Returns:
        The molecule parsed back from the chosen SMILES, or None when ``m`` is
        None.
    """
    if m is None:
        return None

    smiles = Chem.MolToSmiles(m, isomericSmiles=True)
    if largest_connected_comp and '.' in smiles:
        # e.g. 'C.CC.CCc1ccc(N)cc1CCC=O' -> 'CCc1ccc(N)cc1CCC=O'
        smiles = max(smiles.split('.'), key=len)
    return Chem.MolFromSmiles(smiles)
if __name__ == '__main__':
    # Small smoke test: parse cyclobutane and print its mol block.
    smiles_mol = 'C1CCC1'
    print("Smiles mol %s" % smiles_mol)
    block_mol = Chem.MolToMolBlock(Chem.MolFromSmiles(smiles_mol))
    print("Block mol:", block_mol, sep="\n")

# Reaching this line means the rdkit import at the top of the file succeeded.
use_rdkit = True
def check_stability(atom_types, edge_types, dataset_info, debug=False, atom_decoder=None):
    """Count the atoms of one molecule whose bond total is an allowed value.

    The bond total of atom ``i`` accumulates, over every other atom ``j``,
    ``abs((edge_types[i, j] + edge_types[j, i]) / 2)`` into an integer
    counter. It is then compared with ``allowed_bonds`` for the atom's decoded
    symbol.

    Args:
        atom_types: sequence of atom-type indices usable as indices into
            ``atom_decoder``.
        edge_types: square matrix indexable as ``edge_types[i, j]``.
        dataset_info: object exposing ``atom_decoder``; only read when
            ``atom_decoder`` is None.
        debug: print a message for every atom with a disallowed bond total.
        atom_decoder: optional override for ``dataset_info.atom_decoder``.

    Returns:
        ``(molecule_stable, n_stable_atoms, n_atoms)`` where
        ``molecule_stable`` is True when every atom is stable.
    """
    if atom_decoder is None:
        atom_decoder = dataset_info.atom_decoder

    n_atoms = len(atom_types)
    n_bonds = np.zeros(n_atoms, dtype='int')

    for i in range(n_atoms):
        for j in range(i + 1, n_atoms):
            # Average both directions once and credit it to both endpoints.
            pair_order = abs((edge_types[i, j] + edge_types[j, i]) / 2)
            n_bonds[i] += pair_order
            n_bonds[j] += pair_order

    n_stable_bonds = 0
    for atom_type, atom_n_bond in zip(atom_types, n_bonds):
        symbol = atom_decoder[atom_type]
        possible_bonds = allowed_bonds[symbol]
        if isinstance(possible_bonds, int):
            is_stable = possible_bonds == atom_n_bond
        else:
            is_stable = atom_n_bond in possible_bonds
        if not is_stable and debug:
            print("Invalid bonds for molecule %s with %d bonds" % (symbol, atom_n_bond))
        n_stable_bonds += int(is_stable)

    molecule_stable = n_stable_bonds == n_atoms
    return molecule_stable, n_stable_bonds, n_atoms
def compute_molecular_metrics(molecule_list, train_smiles, dataset_info):
    """Compute stability and RDKit metrics for a batch of generated molecules.

    Args:
        molecule_list: sized sequence of ``(atom_types, edge_types)`` pairs.
        train_smiles: reference SMILES forwarded to ``BasicMolecularMetrics``
            for the novelty metric (may be None).
        dataset_info: object exposing ``remove_h`` and ``atom_decoder``.

    Returns:
        ``(validity_dict, rdkit_metrics, all_smiles)``. ``validity_dict`` holds
        ``mol_stable`` and ``atm_stable`` fractions, or -1 for both when
        ``dataset_info.remove_h`` is set (stability is not evaluated then).
        ``rdkit_metrics`` is the tuple returned by
        ``BasicMolecularMetrics.evaluate`` and ``all_smiles`` is its last
        element. Both dictionaries are also sent to ``wandb`` when a run is
        active.
    """
    if not dataset_info.remove_h:
        print('Analyzing molecule stability...')

        molecule_stable = 0
        nr_stable_bonds = 0
        n_atoms = 0
        for atom_types, edge_types in molecule_list:
            mol_ok, stable_atoms, atom_count = check_stability(atom_types, edge_types, dataset_info)
            molecule_stable += int(mol_ok)
            nr_stable_bonds += int(stable_atoms)
            n_atoms += int(atom_count)

        # Validity; an empty batch (or one without atoms) yields 0.0 rather
        # than a ZeroDivisionError.
        n_molecules = len(molecule_list)
        fraction_mol_stable = molecule_stable / float(n_molecules) if n_molecules else 0.0
        fraction_atm_stable = nr_stable_bonds / float(n_atoms) if n_atoms else 0.0
        validity_dict = {'mol_stable': fraction_mol_stable, 'atm_stable': fraction_atm_stable}
        if wandb.run:
            wandb.log(validity_dict)
    else:
        validity_dict = {'mol_stable': -1, 'atm_stable': -1}

    metrics = BasicMolecularMetrics(dataset_info, train_smiles)
    rdkit_metrics = metrics.evaluate(molecule_list)
    all_smiles = rdkit_metrics[-1]
    if wandb.run:
        scores = rdkit_metrics[0]
        nc = rdkit_metrics[-2]
        dic = {'Validity': scores[0], 'Relaxed Validity': scores[1],
               'Uniqueness': scores[2], 'Novelty': scores[3],
               'nc_max': nc['nc_max'], 'nc_mu': nc['nc_mu']}
        wandb.log(dic)

    return validity_dict, rdkit_metrics, all_smiles