import functools
import os
import sys

import gradio as gr
import pandas as pd
import selfies as sf
from rdkit import Chem
from rdkit import DataStructs
from rdkit.Chem import AllChem
from rdkit.Chem import Descriptors
from rdkit.Chem import Draw
from rdkit.Chem import RDConfig
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# sascorer lives in RDKit's Contrib directory rather than as an installed
# module, so that directory must be on sys.path before it can be imported.
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
import sascorer  # noqa: E402
|
|
|
|
def get_largest_ring_size(mol):
    """Return the number of atoms in the largest ring of ``mol``.

    Ring membership comes from ``mol.GetRingInfo().AtomRings()``; a molecule
    without any rings yields 0.
    """
    ring_sizes = (len(ring) for ring in mol.GetRingInfo().AtomRings())
    return max(ring_sizes, default=0)
|
|
def plogp(smile):
    """Return the penalized logP of a SMILES string.

    Penalized logP is computed here as
    ``MolLogP - SA score - max(largest_ring_size - 6, 0)``.

    Args:
        smile: SMILES string; may be empty or None.

    Returns:
        The penalized logP as a float. The sentinel -100 is returned when
        ``smile`` is empty/None or RDKit cannot parse it.
    """
    if not smile:
        return -100
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        return -100
    log_p = Descriptors.MolLogP(mol)
    sas_score = sascorer.calculateScore(mol)
    # Only rings larger than six atoms are penalised. Acyclic molecules
    # (largest ring size 0) and molecules with logP == 0 are still scored;
    # the previous truthiness check wrongly mapped them to -100.
    cycle_score = max(get_largest_ring_size(mol) - 6, 0)
    return log_p - sas_score - cycle_score
def sf_decode(selfies):
    """Decode a SELFIES string to SMILES.

    Returns the empty string instead of raising when ``sf.decoder`` reports
    a ``sf.DecoderError``.
    """
    try:
        smiles = sf.decoder(selfies)
    except sf.DecoderError:
        smiles = ''
    return smiles
def sim(input_smile, output_smile):
    """Tanimoto similarity of radius-2 Morgan fingerprints of two SMILES.

    Returns None when either SMILES is empty/None or fails to parse.
    """
    if not (input_smile and output_smile):
        return None
    input_mol = Chem.MolFromSmiles(input_smile)
    output_mol = Chem.MolFromSmiles(output_smile)
    if not (input_mol and output_mol):
        return None
    fingerprints = [AllChem.GetMorganFingerprint(m, 2) for m in (input_mol, output_mol)]
    return DataStructs.TanimotoSimilarity(*fingerprints)
|
|
|
|
@functools.lru_cache(maxsize=1)
def _load_gen_model():
    """Load the MolGen-large tokenizer and model once and reuse them."""
    tokenizer = AutoTokenizer.from_pretrained("zjunlp/MolGen-large")
    model = AutoModelForSeq2SeqLM.from_pretrained("zjunlp/MolGen-large")
    return tokenizer, model


def gen_process(gen_input):
    """Generate molecules from a SELFIES prompt with MolGen-large.

    Runs beam search (num_beams=5, min_length=5, max_length=15) and keeps
    the 4 returned sequences.

    Args:
        gen_input: SELFIES string used as the model prompt.

    Returns:
        A tuple of (newline-joined generated SELFIES strings, RDKit grid
        image of the decoded molecules).
    """
    # The model is cached, so only the first request pays the load cost.
    tokenizer, model = _load_gen_model()

    encoded = tokenizer(gen_input, return_tensors="pt")
    token_ids = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=15,
        min_length=5,
        num_return_sequences=4,
        num_beams=5,
    )

    # Decoded text has spaces between tokens; removing them yields SELFIES.
    generated = [
        tokenizer.decode(
            ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        ).replace(" ", "")
        for ids in token_ids
    ]

    # sf_decode returns '' instead of raising, so one undecodable sequence
    # no longer aborts the whole request.
    mols = [Chem.MolFromSmiles(sf_decode(selfies)) for selfies in generated]
    grid_image = Draw.MolsToGridImage(
        mols,
        molsPerRow=4,
        subImgSize=(200, 200),
        legends=[''] * len(mols),
    )
    return "\n".join(generated), grid_image
@functools.lru_cache(maxsize=1)
def _load_opt_model():
    """Load the MolGen-large-opt tokenizer and model once and reuse them."""
    tokenizer = AutoTokenizer.from_pretrained("zjunlp/MolGen-large-opt")
    model = AutoModelForSeq2SeqLM.from_pretrained("zjunlp/MolGen-large-opt")
    return tokenizer, model


def opt_process(opt_input):
    """Propose property-improving variants of a SELFIES molecule.

    Samples 10 sequences from MolGen-large-opt (top_k=30, top_p=1,
    max_length=100) and de-duplicates them. It keeps the candidates whose
    ``plogp`` is higher than the input's and whose ``sim`` to the input
    exceeds 0.4.

    Args:
        opt_input: SELFIES string of the molecule to optimize.

    Returns:
        A 5-tuple:
        - grid image of the input molecule;
        - newline-joined kept SELFIES;
        - newline-joined plogp improvements;
        - newline-joined similarities;
        - grid image of the kept molecules, or None if nothing was kept.

    Raises:
        sf.DecoderError: propagated from ``sf.decoder`` if ``opt_input``
            cannot be decoded.
    """
    # The model is cached, so only the first request pays the load cost.
    tokenizer, model = _load_opt_model()

    # Decode the input exactly once; a bad input surfaces as DecoderError.
    input_smiles = sf.decoder(opt_input)
    opt_input_img = Draw.MolsToGridImage(
        [Chem.MolFromSmiles(input_smiles)],
        molsPerRow=4,
        subImgSize=(200, 200),
        legends=[''],
    )

    encoded = tokenizer(opt_input, return_tensors="pt")
    token_ids = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        do_sample=True,
        max_length=100,
        min_length=5,
        top_k=30,
        top_p=1,
        num_return_sequences=10,
    )
    # Decoded text has spaces between tokens; removing them yields SELFIES.
    sampled = [
        tokenizer.decode(
            ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        ).replace(" ", "")
        for ids in token_ids
    ]
    # De-duplicate while keeping generation order; iterating a set() of
    # strings gives an order that can change from run to run.
    candidates = list(dict.fromkeys(sampled))

    input_plogp = plogp(input_smiles)
    kept_selfies, kept_mols, kept_imps, kept_sims = [], [], [], []
    for candidate in candidates:
        candidate_smiles = sf_decode(candidate)
        improvement = plogp(candidate_smiles) - input_plogp
        similarity = sim(candidate_smiles, input_smiles)
        # sim() returns None when either SMILES is empty or unparsable.
        # Reject such candidates explicitly; the previous pandas filter
        # raised TypeError when every similarity was None.
        if improvement > 0 and similarity is not None and similarity > 0.4:
            kept_selfies.append(candidate)
            kept_mols.append(Chem.MolFromSmiles(candidate_smiles))
            kept_imps.append(str(improvement))
            kept_sims.append(str(similarity))

    if kept_mols:
        opt_output_img = Draw.MolsToGridImage(
            kept_mols,
            molsPerRow=4,
            subImgSize=(200, 200),
            legends=[''] * len(kept_mols),
        )
    else:
        # Nothing passed the filters; leave the output image empty rather
        # than asking RDKit to draw a zero-row grid.
        opt_output_img = None

    return (
        opt_input_img,
        "\n".join(kept_selfies),
        "\n".join(kept_imps),
        "\n".join(kept_sims),
        opt_output_img,
    )
|
|
# UI: one tab per task. Each tab pairs an input column with an output column
# and is followed by clickable examples whose results are computed up front
# (cache_examples=True).
with gr.Blocks() as demo:
    gr.Markdown("# MolGen: Domain-Agnostic Molecular Generation with Self-feedback")

    with gr.Tabs():
        # --- Tab 1: unconstrained generation from a SELFIES prompt ---
        with gr.TabItem("Molecular Generation"):
            with gr.Row():
                with gr.Column():
                    gen_input = gr.Textbox(label="Input", lines=1, placeholder="SELFIES Input")
                    gen_button = gr.Button("Generate")
                with gr.Column():
                    gen_output = gr.Textbox(label="Generation Results", lines=5, placeholder="")
                    gen_output_image = gr.Image(label="Visualization")

            # Order must match the tuple returned by gen_process.
            gen_outputs = [gen_output, gen_output_image]
            gr.Examples(
                examples=[
                    ["[C][=C][C][=C][C][=C][Ring1][=Branch1]"],
                    ["[C]"],
                ],
                inputs=[gen_input],
                outputs=gen_outputs,
                fn=gen_process,
                cache_examples=True,
            )

        # --- Tab 2: similarity-constrained property optimization ---
        with gr.TabItem("Constrained Molecular Property Optimization"):
            with gr.Row():
                with gr.Column():
                    opt_input = gr.Textbox(label="Input", lines=1, placeholder="SELFIES Input")
                    opt_button = gr.Button("Optimize")
                with gr.Column():
                    opt_input_img = gr.Image(label="Input Visualization")
                    opt_output = gr.Textbox(label="Optimization Results", lines=3, placeholder="")
                    opt_output_imp = gr.Textbox(label="Optimization Property Improvements", lines=3, placeholder="")
                    opt_output_sim = gr.Textbox(label="Similarity", lines=3, placeholder="")
                    opt_output_img = gr.Image(label="Output Visualization")

            # Order must match the 5-tuple returned by opt_process.
            opt_outputs = [opt_input_img, opt_output, opt_output_imp, opt_output_sim, opt_output_img]
            gr.Examples(
                examples=[
                    ["[C][C][=Branch1][C][=O][N][C][C][O][C][C][O][C][C][O][C][C][Ring1][N]"],
                    ["[C][C][S][C][C][S][C][C][C][S][C][C][S][C][Ring1][=C]"],
                    ["[N][#C][C][C][C@@H1][C][C][C][C][C][C][C][C][C][C][C][Ring1][N][=O]"],
                ],
                inputs=[opt_input],
                outputs=opt_outputs,
                fn=opt_process,
                cache_examples=True,
            )

    # Event listeners are registered inside the Blocks context.
    gen_button.click(fn=gen_process, inputs=[gen_input], outputs=gen_outputs)
    opt_button.click(fn=opt_process, inputs=[opt_input], outputs=opt_outputs)

demo.launch()
|
|