| from rdkit import Chem
|
| from rdkit.Chem import AllChem, MACCSkeys
|
| from rdkit.Chem.rdmolops import FastFindRings
|
| from rdkit.Chem.rdMolDescriptors import CalcMolFormula
|
| import torch
|
| import numpy as np
|
| import scipy
|
| import scipy.sparse as ss
|
| import scipy.sparse.linalg
|
| import math
|
| import json
|
| import itertools as it
|
| import re
|
| from GNN import featurizer as ft
|
|
|
import rdkit.RDLogger as rkl

# Silence RDKit's verbose per-molecule diagnostics: keep only ERROR level on
# the Python-side logger, and disable 'rdApp.error' output at the C++ level
# (unsanitizable SMILES are handled explicitly in mol_fp_encoder0 below).
logger = rkl.logger()

logger.setLevel(rkl.ERROR)


import rdkit.rdBase as rkrb

rkrb.DisableLog('rdApp.error')
|
|
|
|
|
# Fixed subset of bit positions of a 2048-bit radius-2 Morgan fingerprint,
# used by the 'morgan1' option of mol_fp_encoder0 to reduce the vector to
# these bits only. Presumably selected offline (e.g. by bit frequency or
# feature importance) — the selection procedure is not shown in this file.
FPBitIdx = [1, 5, 13, 41, 69, 80, 84, 94, 114, 117, 118, 119, 125, 133, 145,

            147, 191, 192, 197, 202, 222, 227, 231, 249, 283, 294, 310, 314,

            322, 333, 352, 361, 378, 387, 389, 392, 401, 406, 441, 478, 486,

            489, 519, 521, 524, 555, 561, 591, 598, 599, 610, 622, 650, 656,

            667, 675, 677, 679, 680, 694, 695, 715, 718, 722, 729, 736, 739,

            745, 750, 760, 775, 781, 787, 794, 798, 802, 807, 811, 823, 835,

            841, 849, 869, 872, 874, 875, 881, 890, 896, 926, 935, 980, 991,

            1004, 1009, 1017, 1019, 1027, 1028, 1035, 1037, 1039, 1057, 1060,

            1066, 1070, 1077, 1088, 1097, 1114, 1126, 1136, 1142, 1143, 1145,

            1152, 1154, 1160, 1162, 1171, 1181, 1195, 1199, 1202, 1218, 1234,

            1236, 1243, 1257, 1267, 1274, 1279, 1283, 1292, 1294, 1309, 1313,

            1323, 1325, 1349, 1356, 1357, 1366, 1380, 1381, 1385, 1386, 1391,

            1399, 1436, 1440, 1441, 1444, 1452, 1454, 1457, 1475, 1476, 1477,

            1480, 1487, 1516, 1536, 1544, 1558, 1564, 1573, 1599, 1602, 1604,

            1607, 1619, 1648, 1670, 1683, 1693, 1716, 1722, 1737, 1738, 1745,

            1747, 1750, 1754, 1755, 1764, 1781, 1803, 1808, 1810, 1816, 1838,

            1844, 1847, 1855, 1860, 1866, 1873, 1905, 1911, 1917, 1921, 1923,

            1928, 1933, 1950, 1951, 1970, 1977, 1980, 1984, 1991, 2002, 2033, 2034, 2038]
|
|
|
class ConfigDict(dict):
    '''
    A dict subclass with attribute-style access: ``cfg.key`` is ``cfg['key']``.

    Supports round-tripping to JSON via :meth:`save` / :meth:`load`.
    '''

    def __getattr__(self, name):
        """Map attribute access to item access; raise AttributeError on miss."""
        try:
            return self[name]
        except KeyError:
            # BUGFIX: was a bare `except:` which also caught SystemExit /
            # KeyboardInterrupt; only a missing key should become an
            # AttributeError.
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        """Map attribute assignment to item assignment."""
        self[name] = value

    def save(self, fn):
        """Serialize this config to the JSON file *fn* (indented)."""
        # Use a context manager so the file handle is closed deterministically.
        with open(fn, 'w') as f:
            json.dump(self, f, indent=2)

    def load_dict(self, dic):
        """Merge all key/value pairs from *dic* into this config."""
        for k, v in dic.items():
            self[k] = v

    def load(self, fn):
        """Load keys from the JSON file *fn*; errors are printed, not raised.

        NOTE(review): the best-effort (print-and-continue) error handling is
        preserved from the original; callers appear to rely on load() never
        raising.
        """
        try:
            with open(fn, 'r') as f:
                self.load_dict(json.load(f))
        except Exception as e:
            print(e)
|
|
|
def conv_out_dim(length_in, kernel, stride, padding, dilation):
    """Output length of a 1-D convolution (standard Conv1d size formula)."""
    effective_kernel = dilation * (kernel - 1) + 1
    return (length_in + 2 * padding - effective_kernel) // stride + 1
|
|
|
def filter_ms(ms, thr=0.05, max_mz=2000):
    """Threshold and normalize a spectrum.

    Parameters
    ----------
    ms : iterable of (mz, intensity) pairs
    thr : float
        Relative intensity threshold (fraction of the base peak).
    max_mz : float
        Peaks at or above this m/z are dropped (strict ``<`` comparison).

    Returns
    -------
    (list, list)
        Kept m/z values and their intensities rescaled so the base peak
        (below max_mz) is 100, rounded to 2 decimals. Both lists are empty
        when no usable peak exists.
    """
    maxi = 0
    for m, i in ms:
        if m < max_mz and i > maxi:
            maxi = i

    # BUGFIX: if no peak lies below max_mz (or all intensities are <= 0),
    # the divisions below would raise ZeroDivisionError. Return empty lists.
    if maxi <= 0:
        return [], []

    mz = []
    intn = []
    for m, i in ms:
        if m < max_mz and i / maxi > thr:
            mz.append(m)
            intn.append(round(i / maxi * 100, 2))

    return mz, intn
|
|
|
def calc_nls(ms, thr=0.05, max_mz=2000):
    """Compute neutral-loss peaks from a spectrum.

    BUGFIX: ``thr`` and ``max_mz`` are now forwarded to ``filter_ms``;
    previously the hard-coded defaults (0.05 / 2000) were always used,
    silently ignoring the caller's arguments.

    Parameters
    ----------
    ms : iterable of (mz, intensity) pairs
    thr : float
        Relative intensity threshold forwarded to ``filter_ms``.
    max_mz : float
        Upper m/z bound forwarded to ``filter_ms``.

    Returns
    -------
    list of (loss_mass, mean_intensity) tuples, sorted by loss mass, for all
    pairwise m/z differences in the open interval (0, 200).
    """
    mz, intn = filter_ms(ms, thr=thr, max_mz=max_mz)

    nls = []
    # Iterate the reversed list so each pair comes out as (larger, smaller)
    # when mz is ascending, making the difference positive.
    for a, b in it.combinations(mz[::-1], 2):
        loss = a - b
        if 0 < loss < 200:
            avg_intn = (intn[mz.index(a)] + intn[mz.index(b)]) / 2.
            nls.append((round(loss, 5), round(avg_intn, 5)))

    return sorted(nls)
|
|
|
def ms_binner(ms, nls=None, min_mz=20, max_mz=2000, bin_size=0.05, add_nl=False, binary_intn=False):
    """
    Convert a spectrum to a dense, binned 1-D torch tensor.

    Parameters
    ----------
    ms : iterable of (mz, intensity) pairs
        The spectrum peaks.
    nls : list of (loss_mass, intensity) pairs, optional
        Precomputed neutral losses; computed from ``ms`` when empty/None and
        ``add_nl`` is True. (BUGFIX: default changed from the mutable ``[]``
        to ``None``; the falsiness checks below treat both identically.)
    min_mz : float
        Minimum m/z included in the binned vector.
    max_mz : float
        Maximum m/z included in the binned vector.
    bin_size : float
        Bin width in m/z units.
    add_nl : bool
        If True, prepend a neutral-loss vector (0-200 Da range, same bin
        size) to the returned tensor.
    binary_intn : bool
        If True, use 0/1 peak presence instead of normalized intensities for
        the m/z bins. NOTE(review): the neutral-loss vector is still
        L2-normalized even in binary mode — preserved from the original.

    Returns
    -------
    torch.FloatTensor
        1-D binned spectrum; when ``add_nl`` is True the neutral-loss bins
        are concatenated in front of the m/z bins.
    """
    if nls is None:
        nls = []
    if add_nl and not nls:
        nls = calc_nls(ms, max_mz=max_mz)

    nltensor = None
    # NOTE(review): filter_ms is called with its defaults here (thr=0.05,
    # max_mz=2000) regardless of this function's max_mz; the max_mz cut is
    # re-applied below. Preserved as-is.
    mz, intn = filter_ms(ms)

    if add_nl:
        nlmass = []
        nlintn = []

        for m, i in nls:
            if m < 200:
                if binary_intn:
                    i = 1
                nlmass.append(m)
                nlintn.append(i)

        nlmass = np.array(nlmass)
        nlintn = np.array(nlintn)
        if len(nlintn) > 0:
            nlintn = nlintn / nlintn.max()
        num_nlbins = math.ceil(200 / bin_size)

        # Losses are in [0, 200), so bin indices are valid without an offset.
        nlbins = (nlmass / bin_size).astype(np.int32)

        if len(nlmass) > 0:
            vecnl = ss.csr_matrix(
                (nlintn,
                 (np.repeat(0, len(nlintn)), nlbins)),
                shape=(1, num_nlbins),
                dtype=np.float32)

            # Scale so the vector has L2 norm 100.
            vecnl = (vecnl / scipy.sparse.linalg.norm(vecnl) * 100)
            nltensor = torch.FloatTensor(vecnl.todense()).view(-1)
        else:
            nltensor = torch.zeros(num_nlbins)

    mz = np.array(mz)
    intn = np.array(intn)
    # BUGFIX: also drop peaks below min_mz. Previously only the upper bound
    # was enforced, so an m/z below min_mz produced a negative bin index and
    # crashed the csr_matrix construction below.
    keepidx = (mz >= min_mz) & (mz <= max_mz)
    mz = mz[keepidx]
    intn = intn[keepidx]

    if binary_intn:
        intn[intn > 0] = 1.0
    elif len(intn) > 0:
        intn = intn / intn.max()

    num_bins = math.ceil((max_mz - min_mz) / bin_size)

    bins = ((mz - min_mz) / bin_size).astype(np.int32)

    if len(mz) > 0:
        vec = ss.csr_matrix(
            (intn,
             (np.repeat(0, len(intn)), bins)),
            shape=(1, num_bins),
            dtype=np.float32)

        if not binary_intn:
            # Scale so the vector has L2 norm 100.
            vec = (vec / scipy.sparse.linalg.norm(vec) * 100)

        mstensor = torch.FloatTensor(vec.todense()).view(-1)
    else:
        mstensor = torch.zeros(num_bins)

    if nltensor is not None:
        return torch.cat([nltensor, mstensor], dim=0)

    return mstensor
|
|
|
def formula2vec(formula, elements=['C', 'H', 'O', 'N', 'P', 'S', 'P', 'F', 'Cl', 'Br']):
    """Parse a molecular formula string into an element-count vector.

    Each (symbol, count) token in *formula* (e.g. 'C6H12O6') adds its count
    to the slot of that symbol in *elements*; unknown symbols are ignored.

    NOTE(review): 'P' appears twice in the default element list. Since
    ``list.index`` returns the first match, only index 4 is ever incremented
    and index 6 stays 0; the duplicate is kept so the output dimension (10)
    remains unchanged for downstream consumers.
    """
    counts = np.zeros(len(elements))
    for symbol, digits in re.findall(r'([A-Z][a-z]*)(\d*)', formula):
        if symbol not in elements:
            continue
        counts[elements.index(symbol)] += int(digits) if digits else 1
    return counts
|
|
|
def mol_fp_encoder0(smiles, tp='rdkit', nbits=2048):
    """Encode a SMILES string as a molecular-fingerprint tensor.

    Parameters
    ----------
    smiles : str
        Molecule SMILES; unsanitizable SMILES are retried without
        sanitization.
    tp : str
        Fingerprint type: 'morgan' (radius-2 Morgan, ``nbits`` bits),
        'morgan1' (radius-2 Morgan, 2048 bits reduced to the FPBitIdx
        subset), 'macc' (MACCS keys), or 'rdkit' (RDKit topological
        fingerprint).
    nbits : int
        Bit count — only used by the 'morgan' type.

    Returns
    -------
    (torch.FloatTensor, rdkit Mol), or (None, None) when the SMILES cannot
    be parsed even without sanitization.

    Raises
    ------
    ValueError
        For an unknown ``tp``. (BUGFIX: previously an unknown type left
        ``fp`` unbound and raised a confusing NameError.)
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Fall back to unsanitized parsing, then recompute the property
        # cache and ring info that sanitization would normally provide.
        mol = Chem.MolFromSmiles(smiles, sanitize=False)
        if mol is not None:
            mol.UpdatePropertyCache()
            FastFindRings(mol)

    if mol is None:
        return None, None

    if tp == 'morgan':
        fp_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=nbits)
        fp = (np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0')).tolist()
    elif tp == 'morgan1':
        # Fixed at 2048 bits, then reduced to the pre-selected FPBitIdx
        # subset; nbits is intentionally ignored here.
        fp_vec = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=2048)
        fp = (np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0'))[FPBitIdx].tolist()
    elif tp == 'macc':
        fp_vec = MACCSkeys.GenMACCSKeys(mol)
        fp = (np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0')).tolist()
    elif tp == 'rdkit':
        fp_vec = Chem.RDKFingerprint(mol, nBitsPerHash=1)
        fp = (np.frombuffer(fp_vec.ToBitString().encode(), 'u1') - ord('0')).tolist()
    else:
        raise ValueError(f"unknown fingerprint type: {tp!r}")

    return torch.FloatTensor(fp), mol
|
|
|
def mol_fp_encoder(smiles, tp='rdkit', nbits=2048):
    """Return only the fingerprint tensor for *smiles* (drops the Mol)."""
    encoding, _mol = mol_fp_encoder0(smiles, tp, nbits)
    return encoding
|
|
|
def mol_fp_fm_encoder(smiles, tp='rdkit', nbits=2048):
    """Return (fingerprint tensor, formula-count tensor) for *smiles*.

    Both follow mol_fp_encoder0's failure convention: when the SMILES cannot
    be parsed the fingerprint is None and the formula tensor is None too.
    """
    fpenc, mol = mol_fp_encoder0(smiles, tp, nbits)
    if mol is None:
        return fpenc, None
    formula = CalcMolFormula(mol)
    return fpenc, torch.FloatTensor(formula2vec(formula))
|
|
|
def smi2fmvec(smiles):
    """Map a SMILES string to its element-count FloatTensor.

    Returns None when RDKit cannot parse the SMILES.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    return torch.FloatTensor(formula2vec(CalcMolFormula(mol)))
|
|
|
def mol_graph_featurizer(smiles):
    """Build a GNN graph-feature object from a SMILES string.

    Heavy-atom-only graphs (addh=False) with ring/conjugation and per-atom
    features enabled; substructure fingerprints disabled, Morgan radius 2.
    """
    return ft.calc_data_from_smile(smiles,
                                   addh=False,
                                   with_ring_conj=True,
                                   with_atom_feats=True,
                                   with_submol_fp=False,
                                   radius=2)
|
|
|
def pad_V(V, max_n):
    """Zero-pad a node-feature matrix V of shape (N, C) to (max_n, C).

    Returned unchanged when max_n <= N.
    """
    n_nodes, n_feats = V.shape
    if n_nodes >= max_n:
        return V
    padding = torch.zeros(max_n - n_nodes, n_feats)
    return torch.cat([V, padding], dim=0)
|
|
|
def pad_A(A, max_n):
    """Zero-pad an adjacency-like tensor A of shape (N, L, N) to (max_n, L, max_n).

    Padding is applied to the last dimension first, then to the first; the
    tensor is returned unchanged when max_n <= N.
    """
    n, n_rel, _ = A.shape
    if max_n <= n:
        return A
    # Grow columns: (N, L, N) -> (N, L, max_n).
    A = torch.cat([A, torch.zeros(n, n_rel, max_n - n)], dim=-1)
    # Grow rows: (N, L, max_n) -> (max_n, L, max_n).
    A = torch.cat([A, torch.zeros(max_n - n, n_rel, max_n)], dim=0)
    return A
|
|
|
class AvgMeter:
    """Tracks a running, count-weighted average of a scalar metric."""

    def __init__(self, name="Metric"):
        self.name = name
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """Fold in *val* observed *count* times and refresh the average."""
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        return f"{self.name}: {self.avg:.4f}"
|
|
|
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first param group.

    Implicitly returns None when the optimizer has no param groups
    (preserved from the original loop-and-return form).
    """
    for group in optimizer.param_groups:
        return group["lr"]
|
|
|
def segment_max(x, size_list):
    """Per-segment max of x along dim 0; segment lengths given by size_list."""
    sizes = [int(s) for s in size_list]
    chunks = torch.split(x, sizes)
    return torch.stack([chunk.max(dim=0).values for chunk in chunks])
|
|
|
def segment_sum(x, size_list):
    """Per-segment sum of x along dim 0; segment lengths given by size_list."""
    sizes = [int(s) for s in size_list]
    chunks = torch.split(x, sizes)
    return torch.stack([chunk.sum(dim=0) for chunk in chunks])
|
|
|
def segment_softmax(gate, size_list):
    """Softmax over dim 0 computed independently within each segment.

    Each segment's values are shifted by the segment max before
    exponentiation (numerical stability), then normalized by the segment
    sum; a small epsilon guards against division by zero.
    """
    seg_max = segment_max(gate, size_list)
    expanded_max = torch.cat(
        [seg_max[i].repeat(n, 1) for i, n in enumerate(size_list)], dim=0)
    exp = torch.exp(gate - expanded_max)

    seg_sum = segment_sum(exp, size_list)
    expanded_sum = torch.cat(
        [seg_sum[i].repeat(n, 1) for i, n in enumerate(size_list)], dim=0)

    return exp / (expanded_sum + 1e-16)
|
|
|
def pad_ms_list(ms_list, thr=0.05, min_mz=20, max_mz=2000):
    """Normalize, threshold, and zero-pad a batch of spectra to equal length.

    Each spectrum (a sequence of (mz, intensity) rows) is rescaled so its
    base peak has intensity 100, peaks below thr*100 relative intensity or
    outside [min_mz, max_mz] are dropped, and every spectrum is padded with
    (0, 0) rows to the length of the longest one.

    NOTE(review): the in-place intensity rescaling keeps the input array's
    dtype, so integer-valued spectra get truncated intensities — preserved
    from the original.

    Returns (FloatTensor of shape (B, maxlen, 2), IntTensor of peak counts).
    """
    rel_thr = thr * 100
    filtered = []
    for spectrum in ms_list:
        spectrum = np.array(spectrum)
        spectrum[:, 1] = spectrum[:, 1] / spectrum[:, 1].max() * 100

        if rel_thr > 0:
            spectrum = spectrum[spectrum[:, 1] >= rel_thr]

        spectrum = spectrum[spectrum[:, 0] >= min_mz]
        spectrum = spectrum[spectrum[:, 0] <= max_mz]

        filtered.append(spectrum)

    size_list = [spectrum.shape[0] for spectrum in filtered]
    maxlen = max(size_list)

    padded = []
    for spectrum in filtered:
        deficit = maxlen - len(spectrum)
        if deficit > 0:
            spectrum = np.concatenate([spectrum, [[0, 0]] * deficit], axis=0)
        padded.append(spectrum)

    return torch.FloatTensor(np.stack(padded)), torch.IntTensor(size_list)
|
|
|