| from modules import *
|
| import os, sys
|
| import numpy as np
|
| from tqdm import tqdm
|
| import torch
|
| from torch import nn
|
| from config import CFG
|
| import utils
|
| import json
|
| import pandas as pd
|
| import pickle
|
| from rdkit import Chem
|
| from rdkit.Chem import inchi
|
|
|
def smiles_to_inchikey(smiles, nostereo=True):
    """Convert a SMILES string to an InChIKey via RDKit.

    Args:
        smiles: molecule as a SMILES string.
        nostereo: when True, generate the InChI with the "-SNon" option so
            stereochemistry is ignored (stereoisomers collapse to one key).

    Returns:
        The InChIKey string, or None when parsing / InChI generation fails.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None

        # "-SNon" drops stereo layers; otherwise use RDKit defaults.
        if nostereo:
            inchi_string = inchi.MolToInchi(mol, options="-SNon")
        else:
            inchi_string = inchi.MolToInchi(mol)

        # MolToInchi can return an empty string on failure.
        if not inchi_string:
            return None

        return inchi.InchiToInchiKey(inchi_string)
    except Exception as e:
        # Best-effort conversion: log and signal failure with None.
        print(f"转换失败: {e}")
        return None
|
|
|
def calc_mol_embeddings(model, smis, cfg):
    """Embed a batch of candidate molecules with the trained model.

    Args:
        model: module exposing `mol_gnn_encoder` and `mol_projection`.
        smis: iterable of candidate tuples; position 1 is the SMILES string
            (other positions carry caller-side bookkeeping — index, label).
        cfg: config with `mol_encoder` (substrings 'gnn'/'fp'/'fm' select
            feature types), `fptype`, `mol_embedding_dim`, `device`.

    Returns:
        (mol_embeddings, valid_smis): an embedding row per molecule whose
        featurization succeeded, and the matching input tuples.

    Raises:
        ValueError: when no molecule in the batch could be featurized.
    """
    model.eval()
    fp_featsl = []
    gnn_featsl = []
    fm_featsl = []
    valid_smis = []

    # Featurize each molecule; skip (and log) failures so one bad SMILES
    # does not abort the whole batch.
    for smil in smis:
        smi = smil[1]
        try:
            if 'gnn' in cfg.mol_encoder:
                gnn_featsl.append(utils.mol_graph_featurizer(smi))
            if 'fp' in cfg.mol_encoder:
                fp_featsl.append(
                    utils.mol_fp_encoder(
                        smi, tp=cfg.fptype, nbits=cfg.mol_embedding_dim
                    ).to(cfg.device))
            if 'fm' in cfg.mol_encoder:
                fm_featsl.append(utils.smi2fmvec(smi).to(cfg.device))
            valid_smis.append(smil)
        except Exception as e:
            print(smi, e)
            continue

    # BUG FIX: the original wrapped only `mol_projection` in no_grad(), so
    # `mol_gnn_encoder` built an autograd graph during pure inference and
    # wasted memory.  All model calls below are inference-only.
    with torch.no_grad():
        mol_feat_list = []
        if 'gnn' in cfg.mol_encoder:
            vl, al, msl = [], [], []
            bat = {}
            for b in gnn_featsl:
                if 'V' in b:
                    vl.append(b['V'])
                if 'A' in b:
                    al.append(b['A'])
                if 'mol_size' in b:
                    msl.append(b['mol_size'])

            if vl and al and msl:
                # Pad every graph to the largest atom count in the batch.
                max_n = max(v.shape[0] for v in vl)
                vl1 = [utils.pad_V(v, max_n) for v in vl]
                al1 = [utils.pad_A(a, max_n) for a in al]

                bat['V'] = torch.stack(vl1).to(cfg.device)
                bat['A'] = torch.stack(al1).to(cfg.device)
                bat['mol_size'] = torch.cat(msl, dim=0).to(cfg.device)

            # ROBUSTNESS: only call the encoder when a batch was assembled;
            # the original could pass an empty dict when every graph failed.
            if bat:
                mol_feat_list.append(model.mol_gnn_encoder(bat))
            del bat

        if 'fp' in cfg.mol_encoder:
            mol_feat_list.append(torch.stack(fp_featsl).to(cfg.device))

        if 'fm' in cfg.mol_encoder:
            mol_feat_list.append(torch.stack(fm_featsl).to(cfg.device))

        if not mol_feat_list:
            # Clearer than the IndexError the original raised here.
            raise ValueError('no molecule in the batch could be featurized')

        # Concatenate the selected feature families along the feature axis.
        if len(mol_feat_list) > 1:
            mol_features = torch.cat(mol_feat_list, dim=1).to(cfg.device)
        else:
            mol_features = mol_feat_list[0].to(cfg.device)

        mol_embeddings = model.mol_projection(mol_features)

    del mol_features, mol_feat_list

    return mol_embeddings, valid_smis
|
|
|
def find_matches(model, ms, smis, cfg, n=10, batch_size=64):
    """Rank candidate molecules against one mass spectrum by cosine
    similarity in the shared embedding space.

    Args:
        model: module exposing `ms_projection` (plus the molecule-side
            encoders used via calc_mol_embeddings).
        ms: raw spectrum, binned by utils.ms_binner.
        smis: candidate tuples; position 1 is the SMILES string.
        cfg: config with binning parameters and `device`.
        n: number of top matches to return; -1 (or n > #valid candidates)
            returns all valid candidates.
        batch_size: molecule embedding batch size.

    Returns:
        (matchsmis, scores, indices): top-n candidate tuples, their cosine
        similarities scaled to 0-100 (numpy array), and their indices into
        the list of successfully-embedded candidates.

    NOTE(review): `F` is not imported in this file — presumably
    `torch.nn.functional` pulled in by `from modules import *`; verify.
    """
    model.eval()
    with torch.no_grad():
        # Bin the spectrum into a fixed-length vector and embed it.
        ms_features = utils.ms_binner(ms, min_mz=cfg.min_mz, max_mz=cfg.max_mz, bin_size=cfg.bin_size, add_nl=cfg.add_nl, binary_intn=cfg.binary_intn).to(cfg.device)
        ms_features = ms_features.unsqueeze(0)  # add batch dimension
        ms_embeddings = model.ms_projection(ms_features)
        ms_embeddings_n = F.normalize(ms_embeddings, p=2, dim=1)

        all_similarities = []
        all_valid_smis = []

        # Embed candidates in batches; molecules that fail featurization
        # are dropped inside calc_mol_embeddings, so indices below refer
        # to `all_valid_smis`, not to the input `smis`.
        all_embeddings = []
        for i in tqdm(range(0, len(smis), batch_size)):
            batch_smis = smis[i:i+batch_size]
            batch_embeddings, valid_smis = calc_mol_embeddings(model, batch_smis, cfg)
            all_embeddings.append(batch_embeddings)
            all_valid_smis.extend(valid_smis)

            del batch_embeddings

        all_embeddings = torch.cat(all_embeddings, dim=0)
        all_embeddings_n = F.normalize(all_embeddings, p=2, dim=1)

        # Inputs are already L2-normalized, so this equals a dot product.
        similarities = F.cosine_similarity(all_embeddings_n, ms_embeddings_n, dim=1)

        # Clamp n to the number of candidates that actually embedded.
        if n == -1 or n > len(all_valid_smis):
            n = len(all_valid_smis)

        values, topk_indices = torch.topk(similarities, n)

        topk_indices_list = topk_indices.cpu().tolist()

        matchsmis = [all_valid_smis[idx] for idx in topk_indices_list]

        # Scale similarity to a 0-100 "score" for downstream reporting.
        return matchsmis, values.cpu().numpy()*100, topk_indices_list
|
|
|
def calc(models, datal, cfg):
    """Score every spectrum against its candidate list with each model,
    merge scores per compound, and tabulate at which rank the correct
    candidate appears.

    Args:
        models: list of trained models.
        datal: list of dicts with keys 'ms', 'candidates', 'smiles', 'ikey';
            candidate tuples look like (idx, smiles, ..., is_correct).
        cfg: config passed through to find_matches.

    Returns:
        (sumtop3, dc, dicall): number of spectra whose correct candidate
        ranked in the top 3; a per-rank hit table {'Hit 001': [count], ...}
        with all 100 ranks filled; and the merged per-compound score dict
        keyed by stereochemistry-free InChIKey.
    """
    dicall = {}
    coridxd = {}

    for idx, model in enumerate(models):
        # Loop variable renamed from `nn` — it shadowed `from torch import nn`.
        for ms_no, data in enumerate(datal):
            print(f'Calculating {ms_no}-th MS...')

            try:
                smis, scores, indices = find_matches(model, data['ms'], data['candidates'], cfg, 50)
            except Exception as e:
                print(131, e)
                continue

            # Keep the last score seen per SMILES.  The original's
            # `if smi in dic` / `else` branches were identical, so they
            # are merged into one assignment.
            dic = {}
            for n, smil in enumerate(smis):
                smi = smil[1]
                dic[smi] = {'score': scores[n], 'iscor': smil[-1], 'idx': smil[0]}

            # Group spectra of the same compound under a stereo-free
            # InChIKey; fall back to the precomputed key if RDKit fails.
            ikey = smiles_to_inchikey(data['smiles'], True)
            if ikey is None:
                ikey = data['ikey']

            if ikey in dicall:
                for k, v in dic.items():
                    if k in dicall[ikey]:
                        # Running pairwise average across models/spectra.
                        dicall[ikey][k]['score'] += v['score']
                        dicall[ikey][k]['score'] /= 2
                    else:
                        dicall[ikey][k] = v
            else:
                dicall[ikey] = dic

    # For each compound, rank its candidates and record the rank at which
    # the correct one (iscor == True) appears.
    for ikey, dic in dicall.items():
        scorel = [d['score'] for d in dic.values()]
        iscorl = [d['iscor'] for d in dic.values()]

        scoretsor = torch.tensor(scorel)
        n = min(100, len(scorel))

        values, indices = torch.topk(scoretsor, n)

        indices_list = indices.cpu().tolist()

        iscorl = [iscorl[i] for i in indices_list]

        try:
            i = iscorl.index(True)
            k = 'Hit %.3d' % (i+1)
            coridxd[k] = coridxd.get(k, 0) + 1
        except ValueError:
            # Correct candidate not within the top-n — not counted.
            pass

    # Build the hit table and the top-3 total.
    dc = {}
    sumtop3 = 0
    for k in sorted(coridxd.keys()):
        dc[k] = [coridxd[k]]
        if k in ('Hit 001', 'Hit 002', 'Hit 003'):
            sumtop3 += coridxd[k]

    # Fill every rank 1..100 so downstream tabulation can assume all keys.
    for i in range(100):
        k = 'Hit %.3d' % (i+1)
        if k not in dc:
            dc[k] = [0]

    return sumtop3, dc, dicall
|
|
|
def calc_rank(dicall):
    """Build, per compound, the rank of the correct candidate plus a
    human-readable top-100 ranking list.

    Args:
        dicall: {inchikey: {smiles: {'score', 'iscor', 'idx'}}} as produced
            by calc().

    Returns:
        {inchikey: {'Hit': rank_of_correct (1-based),
                    'Rank': ['score:smiles:inchikey', ...]}} — compounds
        whose correct candidate is outside the top-100 are omitted.
    """
    rankd = {}

    for ikey, dic in dicall.items():
        smis = list(dic.keys())
        scorel = [d['score'] for d in dic.values()]
        iscorl = [d['iscor'] for d in dic.values()]

        scoretsor = torch.tensor(scorel)
        n = min(100, len(scorel))

        values, indices = torch.topk(scoretsor, n)

        # BUG FIX: the original kept `values` as a tensor and indexed the
        # lists with 0-dim tensors, so the rank strings rendered as
        # "tensor(95.2000):..." — convert to plain Python numbers first.
        indices_list = indices.cpu().tolist()
        scorel = values.cpu().tolist()
        smis = [smis[i] for i in indices_list]
        iscorl = [iscorl[i] for i in indices_list]

        sl = [f'{scorel[j]}:{smi}:{smiles_to_inchikey(smi)}'
              for j, smi in enumerate(smis)]

        try:
            i = iscorl.index(True)
            rankd[ikey] = {'Hit': i+1, 'Rank': sl}
        except ValueError:
            # Correct candidate not in the top-n — skip this compound.
            pass

    return rankd
|
|
|
def predict(modelfnl, datal, datafn=''):
    """Evaluate each checkpoint on the dataset and write per-checkpoint
    rank/summary files.

    Args:
        modelfnl: list of '.pth' checkpoint paths.
        datal: dataset list as consumed by calc().
        datafn: dataset filename; its basename becomes part of the output
            file names.

    Returns:
        (maxoutt, maxtop3): the summary line and top-3 count of the best
        checkpoint (by top-3 hits).

    Side effects: writes '<ckpt>-<data>-predict-rank.json' and
    '<ckpt>-<data>-predict-summary.csv' next to each checkpoint.
    """
    maxtop3 = 0
    maxoutt = ''

    for fn in modelfnl:
        # NOTE(review): torch.load unpickles arbitrary objects — only feed
        # it trusted checkpoints.
        d = torch.load(fn)
        CFG.load(d['config'])
        print(d['config'])
        CFG.save('', True)

        model = FragSimiModel(CFG).to(CFG.device)
        model.load_state_dict(d['state_dict'])

        sumtop3, dc, dicall = calc([model], datal, CFG)

        # Hits at rank <= 10 and <= 50.
        sumtop10 = sum(dc.get('Hit %.3d' % (i+1), [0])[0] for i in range(10))
        sumtop50 = sum(dc.get('Hit %.3d' % (i+1), [0])[0] for i in range(50))

        # Cumulative hits: 'Top k' = hits at ranks 1..k.  Running sum
        # replaces the original O(n^2) re-summation; calc() fills every
        # 'Hit 001'..'Hit 100' key, so the totals are identical.
        tops = {}
        running = 0
        for i in range(100):
            k = 'Hit %.3d' % (i+1)
            running += dc.get(k, [0])[0]
            tops[k.replace('Hit', 'Top')] = [running]

        # BUG FIX: dc.get instead of dc.setdefault — the original mutated
        # `dc` as a side effect of building a log string.
        outt = f'Top1: {dc.get("Hit 001", [0])[0]}, top3: {sumtop3}, top10: {sumtop10}, top50: {sumtop50} of {len(datal)}'

        if sumtop3 > maxtop3:
            maxtop3 = sumtop3
            maxoutt = outt

        basefn = fn.replace('.pth', f'-{os.path.basename(datafn).split(".")[0]}')
        rank = calc_rank(dicall)
        # BUG FIX: close the output file (original leaked the handle from
        # json.dump(rank, open(...), ...)).
        with open(basefn + '-predict-rank.json', 'w') as fh:
            json.dump(rank, fh, indent=2)

        df = pd.DataFrame(tops)
        df.to_csv(basefn + '-predict-summary.csv', index=False)

    return maxoutt, maxtop3
|
|
|
def main(datafn, fnl):
    """Run predict() for each checkpoint against one dataset file.

    Args:
        datafn: path to the dataset JSON (a list of spectrum records).
        fnl: list of checkpoint paths, evaluated one at a time.
    """
    outl = []

    # BUG FIX: close the dataset file (original used json.load(open(...))).
    with open(datafn) as f:
        datal = json.load(f)

    for fn in fnl:
        out, _ = predict([fn], datal, datafn)
        print(out, os.path.basename(fn))
        outl.append(out)

    print(outl)
|
|
|
# Script entry point: argv[1] is the dataset JSON file, argv[2:] are model
# checkpoint paths.  Total wall-clock time is printed at the end.
if __name__ == '__main__':
    import time
    t0 = time.time()
    main(sys.argv[1], sys.argv[2:])
    # `300` is a numeric log tag (same style as the `print(131, e)` above).
    print(300, time.time()-t0)
|
|
|