| import argparse |
| from pathlib import Path |
|
|
| import torch |
| from openbabel import openbabel |
| openbabel.obErrorLog.StopLogging() |
|
|
| import utils |
| from lightning_modules import LigandPocketDDPM |
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument('checkpoint', type=Path) |
| parser.add_argument('--pdbfile', type=str) |
| parser.add_argument('--resi_list', type=str, nargs='+', default=None) |
| parser.add_argument('--ref_ligand', type=str, default=None) |
| parser.add_argument('--outfile', type=Path) |
| parser.add_argument('--n_samples', type=int, default=20) |
| parser.add_argument('--batch_size', type=int, default=None) |
| parser.add_argument('--num_nodes_lig', type=int, default=None) |
| parser.add_argument('--all_frags', action='store_true') |
| parser.add_argument('--sanitize', action='store_true') |
| parser.add_argument('--relax', action='store_true') |
| parser.add_argument('--resamplings', type=int, default=10) |
| parser.add_argument('--jump_length', type=int, default=1) |
| parser.add_argument('--timesteps', type=int, default=None) |
| args = parser.parse_args() |
|
|
| pdb_id = Path(args.pdbfile).stem |
|
|
| device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
| if args.batch_size is None: |
| args.batch_size = args.n_samples |
| assert args.n_samples % args.batch_size == 0 |
|
|
| |
| model = LigandPocketDDPM.load_from_checkpoint( |
| args.checkpoint, map_location=device) |
| model = model.to(device) |
|
|
| if args.num_nodes_lig is not None: |
| num_nodes_lig = torch.ones(args.n_samples, dtype=int) * \ |
| args.num_nodes_lig |
| else: |
| num_nodes_lig = None |
|
|
| molecules = [] |
| for i in range(args.n_samples // args.batch_size): |
| molecules_batch = model.generate_ligands( |
| args.pdbfile, args.batch_size, args.resi_list, args.ref_ligand, |
| num_nodes_lig, args.sanitize, largest_frag=not args.all_frags, |
| relax_iter=(200 if args.relax else 0), |
| resamplings=args.resamplings, jump_length=args.jump_length, |
| timesteps=args.timesteps) |
| molecules.extend(molecules_batch) |
|
|
| |
| utils.write_sdf_file(args.outfile, molecules) |
|
|