| import os, json
|
| import torch
|
| import utils
|
|
|
def calc_feats(smi, ms, nls, cfg):
    """Build the feature dict for one (SMILES, spectrum) pair.

    Args:
        smi: SMILES string handed to the molecule encoders in ``utils``.
        ms: Spectrum peaks, passed straight to ``utils.ms_binner``.
        nls: Second positional input of ``utils.ms_binner``; callers in this
            file pass ``[]`` when a record has no ``'nls'`` entry.
        cfg: Config object. Reads ``min_mz``, ``max_mz``, ``bin_size``,
            ``add_nl``, ``binary_intn``, ``mol_encoder``, ``fptype`` and
            ``mol_embedding_dim``.

    Returns:
        A dict that always contains ``'ms_bins'`` and, depending on what
        ``cfg.mol_encoder`` contains, ``'mol_fps'``, ``'mol_fmvec'`` and the
        keys produced by ``utils.mol_graph_featurizer``. Returns ``None``
        when the graph featurizer yields a falsy result.
    """
    item = {
        'ms_bins': utils.ms_binner(ms, nls,
                                   min_mz=cfg.min_mz,
                                   max_mz=cfg.max_mz,
                                   bin_size=cfg.bin_size,
                                   add_nl=cfg.add_nl,
                                   binary_intn=cfg.binary_intn),
    }

    encoders = cfg.mol_encoder
    wants_fm = 'fm' in encoders
    fm_done = False

    if 'fp' in encoders:
        if wants_fm:
            # The combined encoder returns fingerprint and fm vector in one
            # call, so the separate smi2fmvec call below is skipped.
            item['mol_fps'], item['mol_fmvec'] = utils.mol_fp_fm_encoder(
                smi,
                tp=cfg.fptype,
                nbits=cfg.mol_embedding_dim)
            fm_done = True
        else:
            item['mol_fps'] = utils.mol_fp_encoder(
                smi,
                tp=cfg.fptype,
                nbits=cfg.mol_embedding_dim)

    if 'gnn' in encoders:
        graph_feats = utils.mol_graph_featurizer(smi)
        if not graph_feats:
            # No usable graph features: signal the caller to drop the sample.
            return None
        item.update(graph_feats)

    if wants_fm and not fm_done:
        item['mol_fmvec'] = utils.smi2fmvec(smi)

    return item
|
|
|
class Dataset(torch.utils.data.Dataset):
    """Dataset over spectrum records, featurized lazily on access.

    Each record is either already featurized (it has an ``'ms_bins'`` key and
    is returned untouched) or raw, in which case it must provide ``'ms'`` and
    ``'smiles'`` and may provide ``'nls'``.
    """

    def __init__(self, inp, cfg):
        """Store the records and the config.

        Args:
            inp: Path (``str`` or ``os.PathLike``) of a JSON file holding the
                records, or the already-loaded records themselves.
            cfg: Config object forwarded to ``calc_feats``.
        """
        if isinstance(inp, (str, os.PathLike)):
            # Context manager so the file handle is closed right after parsing.
            with open(inp) as fh:
                self.data = json.load(fh)
        else:
            self.data = inp

        self.cfg = cfg

    def __getitem__(self, idx):
        """Return the features of record ``idx``, or ``None`` on any failure.

        Failures are printed and reported as ``None`` rather than raised.
        NOTE(review): the broad catch also covers an out-of-range ``idx``, so
        this never raises ``IndexError``; plain ``for x in dataset`` iteration
        would therefore not terminate. Iterate via ``len()`` or a DataLoader.
        """
        try:
            record = self.data[idx]
            if 'ms_bins' in record:
                # Already featurized: hand back the stored record as-is.
                return record

            nls = record['nls'] if 'nls' in record else []
            ms = record['ms']
            smi = record['smiles']

            item = calc_feats(smi, ms, nls, self.cfg)

        except Exception as e:
            # Best-effort by design: report the bad sample and skip it.
            print('='*50, idx, str(e))
            return None

        return item

    def __len__(self):
        """Return the number of records."""
        return len(self.data)
|
|
|
class DatasetGNNFP(torch.utils.data.Dataset):
    """Dataset yielding fingerprint plus graph features for each record.

    Only the ``'smiles'`` entry of a record is read.
    """

    def __init__(self, inp, cfg):
        """Store the records and the config.

        Args:
            inp: Path (``str`` or ``os.PathLike``) of a JSON file holding the
                records, or the already-loaded records themselves.
            cfg: Config object; ``fptype`` and ``mol_embedding_dim`` are read
                when items are fetched.
        """
        if isinstance(inp, (str, os.PathLike)):
            # Context manager so the file handle is closed right after parsing.
            with open(inp) as fh:
                self.data = json.load(fh)
        else:
            self.data = inp

        self.cfg = cfg

    def __getitem__(self, idx):
        """Return ``'mol_fps'`` plus graph features for record ``idx``.

        Any failure is printed and reported as ``None`` rather than raised.
        """
        try:
            smi = self.data[idx]['smiles']
            item = {
                'mol_fps': utils.mol_fp_encoder(
                    smi,
                    tp=self.cfg.fptype,
                    nbits=self.cfg.mol_embedding_dim),
            }
            # A None result from the featurizer makes update() raise
            # TypeError, which the handler below turns into a skipped sample.
            item.update(utils.mol_graph_featurizer(smi))
        except Exception as e:
            # Best-effort by design: report the bad sample and skip it.
            print('='*50, idx, str(e))
            return None

        return item

    def __len__(self):
        """Return the number of records."""
        return len(self.data)
|
|
|
class PathDataset(torch.utils.data.Dataset):
    """Dataset that parses one spectrum file per item, on first access.

    Parsed files are cached in ``self.data`` keyed by index, so each file is
    read at most once per dataset instance.
    """

    def __init__(self, pathlist, cfg):
        """Store the file paths and the config.

        Args:
            pathlist: Sequence of spectrum file paths, one per item.
            cfg: Config object; ``energy`` selects the spectrum block to read
                and the whole object is forwarded to ``calc_feats``.
        """
        self.fns = pathlist
        self.cfg = cfg
        self.data = {}

    def __getitem__(self, idx):
        """Return the features of file ``idx``, or ``None`` on any failure."""
        try:
            if idx not in self.data:
                parsed = self.proc_data(self.fns[idx], self.cfg.energy)
                if parsed is None:
                    return None
                self.data[idx] = parsed

            record = self.data[idx]
            # The second positional input of the binner is always empty here.
            return calc_feats(record['smiles'], record['ms'], [], self.cfg)

        except Exception as e:
            # Report the failure the same way the sibling datasets in this
            # file do, instead of dropping the sample without a trace.
            print('='*50, idx, str(e))
            return None

    def proc_data(self, fn, energy='Energy1'):
        """Extract the SMILES and peak list of one energy block from a file.

        The first line containing ``energy`` starts the block; its
        second-to-last ``;``-separated field is taken as the SMILES. Every
        following line up to one containing ``'END IONS'`` must be
        ``"<mz> <intensity>"`` separated by a single space.

        Args:
            fn: Path of the spectrum file.
            energy: Marker text identifying the block to read.

        Returns:
            ``{'ms': [(mz, intensity), ...], 'smiles': str}``, or ``None``
            when the marker is absent or a line in the block is malformed.

        Raises:
            OSError: If the file cannot be opened.
        """
        # Read everything first so the handle is closed before parsing.
        with open(fn) as fh:
            lines = fh.readlines()

        peaks = []
        smi = None
        in_block = False
        try:
            for line in lines:
                if energy in line:
                    smi = line.split(';')[-2]
                    in_block = True
                    continue
                if not in_block:
                    continue
                if 'END IONS' in line:
                    break
                mz, intn = line.split(' ')
                peaks.append((float(mz), float(intn)))
        except (ValueError, IndexError):
            # Malformed header or peak line: treat the file as unusable.
            return None

        if not in_block:
            # Marker never seen, so there is no SMILES to report.
            return None

        return {'ms': peaks, 'smiles': smi}

    def __len__(self):
        """Return the number of files."""
        return len(self.fns)