| import gradio as gr |
| import os |
|
|
| import copy |
| import os |
| import torch |
|
|
| import time |
| from argparse import ArgumentParser, Namespace, FileType |
| from rdkit.Chem import RemoveHs |
| from functools import partial |
| import numpy as np |
| import pandas as pd |
| from rdkit import RDLogger |
| from rdkit.Chem import MolFromSmiles, AddHs |
| from torch_geometric.loader import DataLoader |
| import yaml |
| import sys |
| import csv |
|
|
# Raise the csv field-size limit as far as the platform allows.
# csv.field_size_limit() stores the value in a C long; on platforms where a
# C long is narrower than sys.maxsize (e.g. 64-bit Windows) passing
# sys.maxsize raises OverflowError, so shrink the value until it is accepted.
_csv_field_limit = sys.maxsize
while True:
    try:
        csv.field_size_limit(_csv_field_limit)
        break
    except OverflowError:
        _csv_field_limit //= 10

print(torch.__version__)

# Directories the app writes into: ESM embeddings and docking results.
for _output_dir in ("data/esm2_output", "results"):
    os.makedirs(_output_dir, exist_ok=True)
| from datasets.process_mols import ( |
| read_molecule, |
| generate_conformer, |
| write_mol_with_coords, |
| ) |
| from datasets.pdbbind import PDBBind |
| from utils.diffusion_utils import t_to_sigma as t_to_sigma_compl, get_t_schedule |
| from utils.sampling import randomize_position, sampling |
| from utils.utils import get_model |
| from utils.visualise import PDBFile |
| from tqdm import tqdm |
| from datasets.esm_embedding_preparation import esm_embedding_prep |
| import subprocess |
|
|
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _read_model_args(yaml_path):
    """Parse the YAML file at *yaml_path* and expose its top-level keys as attributes."""
    with open(yaml_path) as handle:
        return Namespace(**yaml.full_load(handle))


# Hyper-parameters of the score model and of the confidence model.
score_model_args = _read_model_args("workdir/paper_score_model/model_parameters.yml")
confidence_args = _read_model_args(
    "workdir/paper_confidence_model/model_parameters.yml"
)
|
|
| import shutil |
|
|
# Noise-schedule function with the score model's hyper-parameters bound once.
t_to_sigma = partial(t_to_sigma_compl, args=score_model_args)


def _load_eval_model(model_args, checkpoint_path, **model_kwargs):
    """Build a model from *model_args*, load its weights and return it in eval mode.

    The checkpoint is deserialised onto the CPU first and the model is moved
    to ``device`` afterwards. The state dict stays local to this function so
    it can be freed once ``load_state_dict`` has copied the weights, instead
    of lingering as a module-level global.

    Args:
        model_args: Namespace of hyper-parameters passed to ``get_model``.
        checkpoint_path: Path of the ``torch.save``-d state dict.
        **model_kwargs: Extra keyword arguments forwarded to ``get_model``
            (e.g. ``confidence_mode=True``).

    Returns:
        The model on ``device`` with ``eval()`` applied.
    """
    net = get_model(
        model_args, device, t_to_sigma=t_to_sigma, no_parallel=True, **model_kwargs
    )
    weights = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    # strict=True: fail loudly if the checkpoint and the architecture disagree.
    net.load_state_dict(weights, strict=True)
    net = net.to(device)
    net.eval()
    return net


model = _load_eval_model(
    score_model_args,
    "workdir/paper_score_model/best_ema_inference_epoch_model.pt",
)

confidence_model = _load_eval_model(
    confidence_args,
    "workdir/paper_confidence_model/best_model_epoch75.pt",
    confidence_mode=True,
)
|
|
|
|
def get_pdb(pdb_code="", filepath=""):
    """Return the path of the protein structure file to use.

    An uploaded file takes precedence over a PDB code. When only a code is
    given, ``<code>.pdb`` is downloaded from RCSB into the current working
    directory unless a file of that name already exists there.

    Args:
        pdb_code: Four-character PDB identifier, or ``None``/``""`` if absent.
        filepath: Uploaded file object exposing a ``name`` attribute. Any
            object without that attribute (``None``, ``""``) means "no upload".

    Returns:
        Path of the structure file, or ``None`` when neither an upload nor a
        code was supplied.

    Raises:
        ValueError: If ``pdb_code`` is not four alphanumeric characters.
        OSError: If the download fails (``urllib.error.URLError`` is a
            subclass of ``OSError``).
    """
    # Imported here so this function does not rely on new file-level imports.
    import re
    import shutil
    import urllib.request

    try:
        return filepath.name
    except AttributeError:
        pass

    if pdb_code is None:
        return None
    pdb_code = str(pdb_code).strip()
    if pdb_code == "":
        return None

    # The code comes straight from a web form and ends up in a URL and a file
    # name, so accept nothing but a plain PDB id (no shell/path characters).
    if not re.fullmatch(r"[A-Za-z0-9]{4}", pdb_code):
        raise ValueError(f"Invalid PDB code: {pdb_code!r}")

    target = f"{pdb_code}.pdb"
    if not os.path.exists(target):
        # Download to a temporary name first so an interrupted transfer never
        # leaves a truncated file that later calls would mistake for a cache hit.
        partial_path = f"{target}.part"
        try:
            with urllib.request.urlopen(
                f"https://files.rcsb.org/view/{pdb_code}.pdb", timeout=60
            ) as response, open(partial_path, "wb") as sink:
                shutil.copyfileobj(response, sink)
            os.replace(partial_path, target)
        finally:
            if os.path.exists(partial_path):
                os.remove(partial_path)
    return target
|
|
|
|
def get_ligand(smiles="", filepath=""):
    """Pick the ligand description: the SMILES string if given, else the upload.

    Args:
        smiles: SMILES string; ``None`` and ``""`` count as "not given".
        filepath: Uploaded file object exposing a ``name`` attribute.

    Returns:
        ``smiles`` when it is non-empty, otherwise ``filepath.name``, or
        ``None`` when ``filepath`` has no ``name`` attribute.
    """
    has_smiles = smiles is not None and smiles != ""
    if has_smiles:
        return smiles
    # Objects without a ``name`` attribute (None, "") mean nothing was uploaded.
    return getattr(filepath, "name", None)
|
|
|
|
def read_mol(molpath):
    """Return the full text content of the file at *molpath*.

    The file is opened in text mode with the platform default encoding.
    """
    # A single read() yields the same string as concatenating every line,
    # without building an intermediate list or repeated string copies.
    with open(molpath, "r") as fp:
        return fp.read()
|
|
|
|
def molecule(input_pdb, ligand_pdb, original_ligand):
    """Build the HTML snippet embedding a 3Dmol.js viewer for a docking result.

    The page shows the protein as a gray cartoon, the predicted ligand as an
    animated multi-frame model (magenta sticks) and, if supplied, the uploaded
    reference ligand (green sticks).

    Args:
        input_pdb: Path of the protein structure file (loaded as PDB).
        ligand_pdb: Path of the PDB file holding the predicted ligand frames.
        original_ligand: Optional uploaded ligand file object exposing a
            ``name`` attribute. Objects without one (e.g. ``None``) are
            skipped. Its file extension is passed to 3Dmol as the format.

    Returns:
        An ``<iframe>`` element (as a string) whose ``srcdoc`` is the page.
    """
    # Imported here so this function does not rely on new file-level imports.
    import html
    import json

    def js_literal(text):
        # json.dumps produces a valid JavaScript string literal whatever the
        # text contains (quotes, backticks, backslashes, "${", newlines).
        # "<" is additionally escaped so the text can never close the
        # surrounding <script> element.
        return json.dumps(text).replace("<", "\\u003c")

    structure = read_mol(input_pdb)
    mol = read_mol(ligand_pdb)

    reference_path = getattr(original_ligand, "name", None)
    if reference_path is None:
        lig_str_1 = ""
        lig_str_2 = ""
    else:
        _, ext = os.path.splitext(reference_path)
        lig_str_1 = (
            "let original_ligand = " + js_literal(read_mol(reference_path)) + ";"
        )
        # The reference ligand is the third model added, hence getModel(2).
        lig_str_2 = (
            "\n        viewer.addModel( original_ligand, "
            + js_literal(ext[1:])
            + " );"
            + '\n        viewer.getModel(2).setStyle({stick:{colorscheme:"greenCarbon"}});'
        )

    # NOTE(review): the script below uses the global `$` (jQuery) although only
    # 3Dmol-min.js is loaded; presumably that build bundles jQuery. Verify
    # against the 3Dmol.js release served at this URL.
    x = (
        """<!DOCTYPE html>
<html>
<head>
<meta http-equiv="content-type" content="text/html; charset=UTF-8" />
<style>
body{
    font-family:sans-serif
}
.mol-container {
    width: 600px;
    height: 600px;
    position: relative;
    mx-auto:0
}
.mol-container select{
    background-image:None;
}
.green{
    width:20px;
    height:20px;
    background-color:#33ff45;
    display:inline-block;
}
.magenta{
    width:20px;
    height:20px;
    background-color:magenta;
    display:inline-block;
}
</style>
<script src="https://3Dmol.csb.pitt.edu/build/3Dmol-min.js"></script>
</head>
<body>
<button id="startanimation">Replay diffusion process</button>
<button id="togglesurface">Toggle surface representation</button>
<div>
    <span class="green"></span> Uploaded ligand position
    <span class="magenta"></span> Predicted ligand position
</div>
<div id="container" class="mol-container"></div>

<script>
    let ligand = """
        + js_literal(mol)
        + """;
    let structure = """
        + js_literal(structure)
        + """;
    """
        + lig_str_1
        + """

    let viewer = null;
    let surface = false;
    let surf = null;
    $(document).ready(function () {
        let element = $("#container");
        let config = { backgroundColor: "white" };
        viewer = $3Dmol.createViewer(element, config);
        viewer.addModel( structure, "pdb" );
        viewer.setStyle({}, {cartoon: {color: "gray"}});
        viewer.zoomTo();
        viewer.zoom(0.7);
        viewer.addModelsAsFrames(ligand, "pdb");
        viewer.animate({loop: "forward",reps: 1});

        viewer.getModel(1).setStyle({stick:{colorscheme:"magentaCarbon"}});
        """
        + lig_str_2
        + """
        viewer.render();

    })

    $("#startanimation").click(function() {
        viewer.animate({loop: "forward",reps: 1});
    });
    $("#togglesurface").click(function() {
        if (surface != true) {
            surf = viewer.addSurface($3Dmol.SurfaceType.VDW, { "opacity": 0.9, "color": "white" }, { model: 0 });
            surface = true;
        } else {
            viewer.removeAllSurfaces()
            surface = false;
        }
    });
</script>
</body></html>"""
    )

    # The page is placed inside an HTML attribute, so it must be
    # attribute-escaped; the browser decodes the entities again before parsing
    # srcdoc. Without this, any "'" in the molecule data (e.g. the atom name
    # O5') would end the attribute early.
    escaped_page = html.escape(x, quote=True)

    return f"""<iframe style="width: 100%; height: 700px" name="result" allow="midi; geolocation; microphone; camera;
    display-capture; encrypted-media;" sandbox="allow-modals allow-forms
    allow-scripts allow-same-origin allow-popups
    allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
    allowpaymentrequest="" frameborder="0" srcdoc='{escaped_page}'></iframe>"""
|
|
|
|
| import sys |
|
|
|
|
def esm(protein_path, out_file):
    """Run ESM-2 embedding extraction for the protein at *protein_path*.

    ``esm_embedding_prep(out_file, protein_path)`` is called first
    (presumably it writes the protein sequence(s) to ``out_file``; confirm in
    ``datasets.esm_embedding_preparation``). Then ``esm/scripts/extract.py``
    is run on ``out_file`` with ``data/esm2_output`` as its output directory.

    Args:
        protein_path: Path of the protein structure file.
        out_file: Path handed to ``esm_embedding_prep`` and to ``extract.py``.
    """
    print("running esm")
    esm_embedding_prep(out_file, protein_path)

    # Override HOME for the child process only, leaving this process's
    # environment untouched. Presumably extract.py locates cached model
    # weights relative to HOME - TODO confirm.
    child_env = dict(os.environ, HOME="esm/model_weights")

    # An argument list without a shell keeps file names containing spaces or
    # shell metacharacters (out_file derives from an uploaded file name) from
    # being split or interpreted.
    command = [
        sys.executable or "python",
        "esm/scripts/extract.py",
        "esm2_t33_650M_UR50D",
        out_file,
        "data/esm2_output",
        "--repr_layers",
        "33",
        "--include",
        "per_tok",
    ]
    return_code = subprocess.call(command, env=child_env)
    if return_code != 0:
        # Report the failure instead of silently continuing; callers are not
        # interrupted, as before.
        print(f"ESM embedding extraction exited with status {return_code}")
|
|
|
|
def _make_inference_dataset(
    model_args, protein_path_list, ligand_descriptions, **extra
):
    """Build the ``PDBBind`` dataset for the given inputs using *model_args*.

    ``extra`` is forwarded verbatim to ``PDBBind`` for options only one of the
    two datasets sets.
    """
    return PDBBind(
        transform=None,
        root="",
        protein_path_list=protein_path_list,
        ligand_descriptions=ligand_descriptions,
        receptor_radius=model_args.receptor_radius,
        cache_path="data/cache",
        remove_hs=model_args.remove_hs,
        max_lig_size=None,
        c_alpha_max_neighbors=model_args.c_alpha_max_neighbors,
        matching=False,
        keep_original=False,
        popsize=model_args.matching_popsize,
        maxiter=model_args.matching_maxiter,
        all_atoms=model_args.all_atoms,
        atom_radius=model_args.atom_radius,
        atom_max_neighbors=model_args.atom_max_neighbors,
        esm_embeddings_path="data/esm2_output",
        require_ligand=True,
        num_workers=1,
        **extra,
    )


def update(inp, file, ligand_inp, ligand_file, n_it):
    """Dock the ligand to the protein and prepare the outputs for the UI.

    Steps: resolve the inputs, compute ESM embeddings, build the score and
    confidence datasets, sample ``N`` poses with the score model, order them
    by descending confidence, write SDF / reverse-process PDB files under
    ``results/`` and zip that directory.

    Args:
        inp: PDB code typed by the user (may be empty).
        file: Uploaded protein file object (may be ``None``).
        ligand_inp: SMILES string typed by the user (may be empty).
        ligand_file: Uploaded ligand file object (may be ``None``).
        n_it: Number of inference steps.

    Returns:
        A 5-tuple ``(viewer_html, dropdown_update, filenames, pdb_path,
        zippath)``, or ``None`` when the complex was skipped or sampling
        failed (the reason is printed).

    Raises:
        ValueError: If no protein or no ligand was provided.
    """
    pdb_path = get_pdb(inp, file)
    ligand_path = get_ligand(ligand_inp, ligand_file)
    # Fail early with a readable message instead of a TypeError further down.
    if pdb_path is None:
        raise ValueError("No protein given: enter a PDB code or upload a PDB file.")
    if ligand_path is None:
        raise ValueError(
            "No ligand given: enter a SMILES string or upload a ligand file."
        )
    # The step count is used as an integer; the slider value may arrive as a float.
    n_it = int(n_it)

    esm(
        pdb_path,
        f"data/{os.path.basename(pdb_path)}_prepared_for_esm.fasta",
    )
    # The same time schedule drives translation, rotation and torsion updates.
    tr_schedule = get_t_schedule(inference_steps=n_it)
    rot_schedule = tr_schedule
    tor_schedule = tr_schedule
    print("common t schedule", tr_schedule)

    N = 10  # number of poses sampled per complex
    protein_path_list = [pdb_path]
    ligand_descriptions = [ligand_path]
    no_random = False
    ode = False
    no_final_step_noise = False
    out_dir = "results/"

    test_dataset = _make_inference_dataset(
        score_model_args,
        protein_path_list,
        ligand_descriptions,
        keep_local_structures=False,
    )
    test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
    confidence_test_dataset = _make_inference_dataset(
        confidence_args, protein_path_list, ligand_descriptions
    )
    confidence_complex_dict = {d.name: d for d in confidence_test_dataset}

    # True when the confidence model needs graphs from its own dataset rather
    # than reusing the score model's.
    needs_confidence_graphs = confidence_model is not None and not (
        confidence_args.use_original_model_cache or confidence_args.transfer_weights
    )

    # Initialised up front so the code after the loop is well defined even
    # when every complex is skipped.
    write_dir = None
    filenames = []
    confidences = []

    for idx, orig_complex_graph in tqdm(enumerate(test_loader)):
        if (
            needs_confidence_graphs
            and orig_complex_graph.name[0] not in confidence_complex_dict
        ):
            print(
                f"HAPPENING | The confidence dataset did not contain {orig_complex_graph.name[0]}. We are skipping this complex."
            )
            continue
        try:
            data_list = [copy.deepcopy(orig_complex_graph) for _ in range(N)]
            randomize_position(
                data_list,
                score_model_args.no_torsion,
                no_random,
                score_model_args.tr_sigma_max,
            )
            lig = orig_complex_graph.mol[0]

            # One PDBFile per sample, seeded with the ligand, the original
            # ligand coordinates and the randomized starting coordinates.
            visualization_list = []
            for graph in data_list:
                pdb = PDBFile(lig)
                pdb.add(lig, 0, 0)
                pdb.add(
                    (
                        orig_complex_graph["ligand"].pos
                        + orig_complex_graph.original_center
                    )
                    .detach()
                    .cpu(),
                    1,
                    0,
                )
                pdb.add(
                    (graph["ligand"].pos + graph.original_center).detach().cpu(),
                    part=1,
                    order=1,
                )
                visualization_list.append(pdb)

            if needs_confidence_graphs:
                confidence_data_list = [
                    copy.deepcopy(confidence_complex_dict[orig_complex_graph.name[0]])
                    for _ in range(N)
                ]
            else:
                confidence_data_list = None

            data_list, confidence = sampling(
                data_list=data_list,
                model=model,
                inference_steps=n_it,
                tr_schedule=tr_schedule,
                rot_schedule=rot_schedule,
                tor_schedule=tor_schedule,
                device=device,
                t_to_sigma=t_to_sigma,
                model_args=score_model_args,
                no_random=no_random,
                ode=ode,
                visualization_list=visualization_list,
                confidence_model=confidence_model,
                confidence_data_list=confidence_data_list,
                confidence_model_args=confidence_args,
                batch_size=1,
                no_final_step_noise=no_final_step_noise,
            )
            # Shift the sampled ligand coordinates back by the original center.
            ligand_pos = np.asarray(
                [
                    complex_graph["ligand"].pos.cpu().numpy()
                    + orig_complex_graph.original_center.cpu().numpy()
                    for complex_graph in data_list
                ]
            )

            if confidence is not None and isinstance(
                confidence_args.rmsd_classification_cutoff, list
            ):
                confidence = confidence[:, 0]
            if confidence is not None:
                confidence = confidence.cpu().numpy()
                # Highest confidence first.
                re_order = np.argsort(confidence)[::-1]
                confidence = confidence[re_order]
                ligand_pos = ligand_pos[re_order]
            else:
                # No scores: keep the sampling order. These are integer
                # indices, so they can index visualization_list below.
                re_order = np.arange(len(ligand_pos))

            # NOTE(review): this directory is reused by later requests for the
            # same complex name, and the SDF names embed the confidence, so
            # files from earlier runs may end up in the archive. Consider
            # clearing it first.
            write_dir = (
                f'{out_dir}/index{idx}_{data_list[0]["name"][0].replace("/","-")}'
            )
            os.makedirs(write_dir, exist_ok=True)

            confidences = []
            for rank, pos in enumerate(ligand_pos):
                mol_pred = copy.deepcopy(lig)
                if score_model_args.remove_hs:
                    mol_pred = RemoveHs(mol_pred)
                # The top pose always gets a plain rankN.sdf; without
                # confidences every pose does.
                if rank == 0 or confidence is None:
                    write_mol_with_coords(
                        mol_pred, pos, os.path.join(write_dir, f"rank{rank+1}.sdf")
                    )
                if confidence is not None:
                    confidences.append(confidence[rank])
                    write_mol_with_coords(
                        mol_pred,
                        pos,
                        os.path.join(
                            write_dir,
                            f"rank{rank+1}_confidence{confidence[rank]:.2f}.sdf",
                        ),
                    )

            # Write each sample's reverse-diffusion trajectory in ranked order.
            filenames = []
            for rank, batch_idx in enumerate(re_order):
                trajectory_path = os.path.join(
                    write_dir, f"rank{rank+1}_reverseprocess.pdb"
                )
                visualization_list[batch_idx].write(trajectory_path)
                filenames.append(trajectory_path)
        except Exception as e:
            # NOTE(review): the caller wires five outputs; returning None keeps
            # the previous contract, but the UI likely shows only a generic
            # error. Consider raising a descriptive exception instead.
            print("Failed on", orig_complex_graph["name"], e)
            return None

    if not filenames:
        # Every complex was skipped (or the dataset was empty).
        print("No docking result was produced; nothing to display.")
        return None

    zippath = shutil.make_archive(
        os.path.join("results", os.path.basename(pdb_path)), "zip", write_dir
    )
    print("Zipped outputs to", zippath)
    if confidences:
        labels = [
            f"rank {rank}, confidence {score:.2f}"
            for rank, score in enumerate(confidences, start=1)
        ]
    else:
        labels = [f"rank {rank}" for rank in range(1, len(filenames) + 1)]

    torch.cuda.empty_cache()
    return (
        molecule(pdb_path, filenames[0], ligand_file),
        gr.Dropdown.update(choices=labels, value=labels[0]),
        filenames,
        pdb_path,
        zippath,
    )
|
|
|
|
def updateView(out, filenames, pdb, ligand_file):
    """Re-render the viewer for the sample selected in the ranking dropdown.

    Args:
        out: Selected dropdown label; the text before the first comma must be
            ``rank`` followed by a 1-based integer (e.g. ``"rank 3, ..."``).
        filenames: Per-rank file paths, indexed by ``rank - 1``.
        pdb: Path of the protein structure file.
        ligand_file: Optional uploaded reference ligand file object.

    Returns:
        The HTML produced by ``molecule`` for the selected sample.
    """
    print("updating view")
    print(out)
    rank_token = out.split(",")[0]
    # Labels are 1-based; list positions are 0-based.
    selected = int(rank_token.replace("rank", "")) - 1
    return molecule(pdb, filenames[selected], ligand_file)
|
|
|
|
# Example rows, in the order of the inputs wired below:
# [PDB code, protein file, SMILES, ligand file, number of inference steps].
_EXAMPLE_ROWS = [
    [
        "6w70",
        "examples/6w70.pdb",
        "COc1ccc(cc1)n2c3c(c(n2)C(=O)N)CCN(C3=O)c4ccc(cc4)N5CCCCC5=O",
        "examples/6w70_ligand.sdf",
        10,
    ],
    [
        "6moa",
        "examples/6moa_protein_processed.pdb",
        "",
        "examples/6moa_ligand.sdf",
        10,
    ],
    [
        "",
        "examples/6o5u_protein_processed.pdb",
        "",
        "examples/6o5u_ligand.sdf",
        10,
    ],
    [
        "",
        "examples/6o5u_protein_processed.pdb",
        "[NH3+]C[C@H]1O[C@H](O[C@@H]2[C@@H]([NH3+])C[C@H]([C@@H]([C@H]2O)O[C@H]2O[C@H](CO)[C@H]([C@@H]([C@H]2O)[NH3+])O)[NH3+])[C@@H]([C@H]([C@@H]1O)O)O",
        "examples/6o5u_ligand.sdf",
        10,
    ],
    [
        "",
        "examples/6o5u_protein_processed.pdb",
        "",
        "examples/6o5u_ligand.sdf",
        10,
    ],
    [
        "",
        "examples/6ahs_protein_processed.pdb",
        "",
        "examples/6ahs_ligand.sdf",
        10,
    ],
]

# NOTE(review): the nesting of the layout containers below was reconstructed
# because the source's indentation was lost. Confirm the intended placement of
# the slider and the button.
with gr.Blocks() as demo:
    gr.Markdown("# DiffDock")
    gr.Markdown(
        ">**DiffDock: Diffusion Steps, Twists, and Turns for Molecular Docking**, Corso, Gabriele and Stärk, Hannes and Jing, Bowen and Barzilay, Regina and Jaakkola, Tommi, arXiv:2210.01776 [GitHub](https://github.com/gcorso/diffdock)"
    )
    gr.Markdown("")

    # Input panel: protein on the left, ligand on the right.
    with gr.Box():
        with gr.Row():
            with gr.Column():
                gr.Markdown("## Protein")
                inp = gr.Textbox(
                    placeholder="PDB Code or upload file below", label="Input structure"
                )
                file = gr.File(file_count="single", label="Input PDB")
            with gr.Column():
                gr.Markdown("## Ligand")
                ligand_inp = gr.Textbox(
                    placeholder="Provide SMILES input or upload mol2/sdf file below",
                    label="SMILES string",
                )
                ligand_file = gr.File(file_count="single", label="Input Ligand")
        n_it = gr.Slider(
            minimum=10, maximum=40, label="Number of inference steps", step=1
        )

        btn = gr.Button("Run predictions")

    # Output panel; `pdb` and `filenames` are state holders kept between the
    # prediction call and later dropdown changes.
    gr.Markdown("## Output")
    pdb = gr.Variable()
    filenames = gr.Variable()
    out = gr.Dropdown(interactive=True, label="Ranked samples")
    mol = gr.HTML()
    output_file = gr.File(file_count="single", label="Output files")

    prediction_inputs = [inp, file, ligand_inp, ligand_file, n_it]
    prediction_outputs = [mol, out, filenames, pdb, output_file]

    gr.Examples(
        _EXAMPLE_ROWS,
        [inp, file, ligand_inp, ligand_file, n_it],
        [mol, out, filenames, pdb, output_file],
    )
    btn.click(
        fn=update,
        inputs=prediction_inputs,
        outputs=prediction_outputs,
    )
    # Selecting another ranked sample re-renders the viewer only.
    out.change(fn=updateView, inputs=[out, filenames, pdb, ligand_file], outputs=mol)

demo.launch()
|
|