| |
| |
| import os |
| import sys |
| import json |
| import argparse |
| from tqdm import tqdm |
| from os.path import splitext, basename |
|
|
| import ray |
| import numpy as np |
| import torch |
| from torch.utils.data import DataLoader |
|
|
| from data.format import Atom, Block, VOCAB |
| from data.converter.pdb_to_list_blocks import pdb_to_list_blocks |
| from data.converter.list_blocks_to_pdb import list_blocks_to_pdb |
| from data.codesign import calculate_covariance_matrix |
| from utils.const import sidechain_atoms |
| from utils.logger import print_log |
| from evaluation.dG.openmm_relaxer import ForceFieldMinimizer |
|
|
|
|
class DesignDataset(torch.utils.data.Dataset):
    """Inference dataset pairing a receptor epitope with a placeholder peptide chain.

    Each item couples the epitope residues of one receptor (parsed from a PDB
    file plus a JSON epitope definition) with a peptide chain whose atom
    coordinates are all zero. The peptide is either a run of UNK residues whose
    length is drawn from ``lengths_range`` (codesign), or the residues of a
    given sequence from ``seqs`` (structure prediction). Exactly one of
    ``lengths_range`` / ``seqs`` must be provided.
    """

    # Number of atom slots per residue; residues with fewer atoms are padded.
    MAX_N_ATOM = 14

    def __init__(self, pdbs, epitopes, lengths_range=None, seqs=None) -> None:
        """
        Args:
            pdbs: paths to receptor PDB files, one per item.
            epitopes: paths to JSON epitope definitions, aligned with ``pdbs``.
            lengths_range: per-item ``(lmin, lmax)`` peptide length bounds
                (``lmax`` exclusive), used for codesign.
            seqs: per-item peptide sequences (one-letter symbols), used for
                structure prediction.

        Raises:
            ValueError: if both or neither of ``lengths_range`` and ``seqs``
                are given.
        """
        super().__init__()
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if (seqs is None) == (lengths_range is None):
            raise ValueError(
                'Exactly one of `lengths_range` (codesign) and `seqs` '
                '(structure prediction) must be provided'
            )
        self.pdbs = pdbs
        self.epitopes = epitopes
        self.lengths_range = lengths_range
        self.seqs = seqs

    def get_epitope(self, idx):
        """Parse the epitope residues of item ``idx``.

        The epitope JSON is read as a list of ``[chain_name, pos]`` entries. A
        block belongs to the epitope when ``f'{id[0]}-{id[1]}'`` of its
        ``block.id`` matches that of some ``pos`` on the same chain.
        (Presumably residue number and insertion code -- confirm against the
        detect_pocket output.)

        Returns:
            residues: epitope blocks, in chain order.
            position_ids: 1-based index of each epitope block within its chain.
            chain2blocks: all parsed blocks of the epitope chains, keyed by
                chain name.
        """
        pdb, epitope_def = self.pdbs[idx], self.epitopes[idx]

        with open(epitope_def, 'r') as fin:
            epitope = json.load(fin)

        def to_str(pos):
            return f'{pos[0]}-{pos[1]}'

        # chain name -> set of residue keys that belong to the epitope
        epi_map = {}
        for chain_name, pos in epitope:
            epi_map.setdefault(chain_name, set()).add(to_str(pos))

        chain2blocks = pdb_to_list_blocks(pdb, list(epi_map.keys()), dict_form=True)
        if len(chain2blocks) != len(epi_map):
            print_log(f'Some chains in the epitope are missing. Parsed {list(chain2blocks.keys())}, given {list(epi_map.keys())}.', level='WARN')

        residues, position_ids = [], []
        for chain_name, chain in chain2blocks.items():
            epi_keys = epi_map.get(chain_name, set())
            for i, block in enumerate(chain):
                if to_str(block.id) in epi_keys:
                    residues.append(block)
                    position_ids.append(i + 1)
        return residues, position_ids, chain2blocks

    def generate_pep_chain(self, idx):
        """Build the zero-coordinate placeholder peptide chain for item ``idx``.

        Codesign: ``length`` UNK residues, each with a single CA atom, where
        ``length`` is drawn uniformly from ``[lmin, lmax)``.
        Structure prediction: one block per symbol of the sequence, carrying
        the backbone atoms plus that residue's sidechain atoms.
        """
        if self.lengths_range is not None:
            lmin, lmax = self.lengths_range[idx]
            length = np.random.randint(lmin, lmax)  # upper bound is exclusive
            unk_abrv = VOCAB.symbol_to_abrv(VOCAB.UNK)
            # Distinct Block objects, rather than one object aliased `length` times.
            return [Block(unk_abrv, [Atom('CA', [0, 0, 0], 'C')]) for _ in range(length)]

        blocks = []
        for s in self.seqs[idx]:
            atoms = [
                Atom(atom_name, [0, 0, 0], atom_name[0])
                for atom_name in VOCAB.backbone_atoms + sidechain_atoms.get(s, [])
            ]
            blocks.append(Block(VOCAB.symbol_to_abrv(s), atoms))
        return blocks

    def __len__(self):
        return len(self.pdbs)

    def __getitem__(self, idx: int):
        """Featurize item ``idx`` as padded per-residue tensors.

        Residues are ordered epitope first, then peptide. The returned dict has:
            X: float tensor, one row of MAX_N_ATOM coordinates per residue.
               Missing or padded slots hold the mean of the residue's present
               atoms.
            S: long tensor of residue-type indices.
            position_ids: 1-based positions (epitope residues: index within
                their chain; peptide residues: index within the peptide).
            mask: bool tensor, True for peptide residues, False for epitope ones.
            atom_mask: bool tensor, True where a slot holds a real atom.
            lengths: total number of residues (int).
            rec_chain2blocks: parsed receptor chains, passed through unchanged.
            L: Cholesky factor of the jittered covariance of the epitope atoms
                in slot 1 (presumably CA, depending on the VOCAB.backbone_atoms
                order), with a leading dimension of 1.
        """
        rec_blocks, rec_position_ids, rec_chain2blocks = self.get_epitope(idx)
        lig_blocks = self.generate_pep_chain(idx)

        mask = [False] * len(rec_blocks) + [True] * len(lig_blocks)
        position_ids = rec_position_ids + list(range(1, len(lig_blocks) + 1))
        X, S, atom_mask = [], [], []
        for block in rec_blocks + lig_blocks:
            symbol = VOCAB.abrv_to_symbol(block.abrv)
            atom2coord = {unit.name: unit.get_coord() for unit in block.units}
            # Centroid of the present atoms; fills missing and padding slots.
            center = np.mean(list(atom2coord.values()), axis=0).tolist()
            coords, coord_mask = [], []
            for atom_name in VOCAB.backbone_atoms + sidechain_atoms.get(symbol, []):
                present = atom_name in atom2coord
                coords.append(atom2coord[atom_name] if present else center)
                coord_mask.append(present)
            n_pad = self.MAX_N_ATOM - len(coords)
            coords.extend([center] * n_pad)
            coord_mask.extend([False] * n_pad)

            X.append(coords)
            S.append(VOCAB.symbol_to_idx(symbol))
            atom_mask.append(coord_mask)

        X = torch.tensor(X, dtype=torch.float)
        atom_mask = torch.tensor(atom_mask, dtype=torch.bool)
        mask = torch.tensor(mask, dtype=torch.bool)
        rec = ~mask
        cov = calculate_covariance_matrix(X[rec][:, 1][atom_mask[rec][:, 1]].numpy())
        # Small diagonal jitter, presumably so the Cholesky factorization
        # succeeds for degenerate epitopes.
        eps = 1e-4
        cov = cov + eps * np.identity(cov.shape[0])
        L = torch.from_numpy(np.linalg.cholesky(cov)).float().unsqueeze(0)

        return {
            'X': X,
            'S': torch.tensor(S, dtype=torch.long),
            'position_ids': torch.tensor(position_ids, dtype=torch.long),
            'mask': mask,
            'atom_mask': atom_mask,
            'lengths': len(S),
            'rec_chain2blocks': rec_chain2blocks,
            'L': L,
        }

    def collate_fn(self, batch):
        """Merge items into one batch by concatenating along the residue axis.

        ``lengths`` becomes a long tensor of per-item residue counts, so each
        item's span can be recovered. ``rec_chain2blocks`` stays a plain list.
        Every other entry is concatenated with ``torch.cat(..., dim=0)``.
        """
        results = {}
        for key in batch[0]:
            values = [item[key] for item in batch]
            if key == 'lengths':
                results[key] = torch.tensor(values, dtype=torch.long)
            elif key == 'rec_chain2blocks':
                results[key] = values
            else:
                results[key] = torch.cat(values, dim=0)
        return results
|
|
|
|
@ray.remote(num_cpus=1, num_gpus=1/16)
def openmm_relax(pdb_path):
    """Ray task: run ForceFieldMinimizer on ``pdb_path``.

    The same path is passed as both input and output, so the minimized
    structure is presumably written back over the original file. Each task
    reserves one CPU and 1/16 of a GPU from the ray scheduler.

    Returns:
        ``pdb_path`` unchanged, once the minimizer call has finished.
    """
    minimizer = ForceFieldMinimizer()
    src = dst = pdb_path  # in-place: read from and write to the same file
    minimizer(src, dst)
    return dst
|
|
|
|
def _expand_inputs(identifiers, pdbs, epitope_defs, lengths_range, seqs, n_samples):
    """Repeat each target's inputs once per requested sample.

    Sample ``i`` of a target with identifier ``_id`` gets the id ``f'{_id}_{i}'``.
    ``lengths_range`` / ``seqs`` may be None, meaning "not given for any target".

    Returns:
        ``(ids, pdbs, epitopes, lengths, seqs)``: five parallel per-sample lists.
    """
    if lengths_range is None:
        lengths_range = [None for _ in pdbs]
    if seqs is None:
        seqs = [None for _ in pdbs]
    ids, out_pdbs, out_epitopes, out_lens, out_seqs = [], [], [], [], []
    for _id, pdb, epitope, length_range, seq, n in zip(identifiers, pdbs, epitope_defs, lengths_range, seqs, n_samples):
        ids.extend(f'{_id}_{i}' for i in range(n))
        out_pdbs.extend([pdb] * n)
        out_epitopes.extend([epitope] * n)
        out_lens.extend([length_range] * n)
        out_seqs.extend([seq] * n)
    return ids, out_pdbs, out_epitopes, out_lens, out_seqs


def _load_model(ckpt, device):
    """Load the pickled model object stored in ``ckpt`` onto ``device`` in eval mode."""
    # The checkpoint holds a full model object (it is used directly via
    # .to/.eval/.sample), not a state dict. torch>=2.6 defaults to
    # weights_only=True, which refuses to unpickle such objects. This unpickles
    # arbitrary code, so only load trusted checkpoints.
    try:
        model = torch.load(ckpt, map_location='cpu', weights_only=False)
    except TypeError:  # torch<1.13 has no `weights_only` argument
        model = torch.load(ckpt, map_location='cpu')
    model.to(device)
    model.eval()
    return model


def _peptide_blocks(X, S):
    """Convert one sampled peptide (per-residue coordinates and symbols) into Blocks.

    Atom names come from the backbone atoms plus the residue's sidechain atoms.
    ``zip`` pairs them with the sampled coordinates and stops at whichever runs
    out first.
    """
    blocks = []
    for x, s in zip(X, S):
        abrv = VOCAB.symbol_to_abrv(s)
        atom_names = VOCAB.backbone_atoms + sidechain_atoms.get(VOCAB.abrv_to_symbol(abrv), [])
        units = [Atom(atom_name, coord, atom_name[0]) for atom_name, coord in zip(atom_names, x)]
        blocks.append(Block(abrv, units))
    return blocks


def _next_chain_id(existing):
    """Choose a one-character chain id for the peptide that is not in ``existing``.

    Keeps the convention of taking the character right after the largest
    existing id (e.g. 'C' after 'A', 'B'), but only when that character is an
    ASCII letter or digit. Otherwise (e.g. after 'Z', or when ``existing`` is
    empty) it returns the first free id among A-Z, a-z, 0-9.

    Raises:
        ValueError: if every candidate id is already taken.
    """
    import string

    candidates = string.ascii_uppercase + string.ascii_lowercase + string.digits
    taken = set(existing)
    if taken:
        nxt = chr(max(ord(c) for c in taken) + 1)
        if nxt in candidates:
            return nxt
    for c in candidates:
        if c not in taken:
            return c
    raise ValueError(f'No free single-character chain id; existing chains: {sorted(taken)}')


def _relax_all(pdb_paths, num_cpus=8):
    """Run ``openmm_relax`` on every path through ray, with a progress bar.

    Worker exceptions are re-raised here through ``ray.get``.
    """
    # ray.init raises if called twice in one process (e.g. design() run twice).
    if not ray.is_initialized():
        ray.init(num_cpus=num_cpus)
    pending = [openmm_relax.remote(path) for path in pdb_paths]
    pbar = tqdm(total=len(pending))
    while pending:
        done, pending = ray.wait(pending, num_returns=1)
        ray.get(done)
        pbar.update(len(done))
    pbar.close()


def design(mode, ckpt, gpu, pdbs, epitope_defs, n_samples, out_dir,
           lengths_range=None, seqs=None, identifiers=None, batch_size=8, num_workers=4):
    """Sample peptides for every target, write the complexes as PDB, then relax them.

    Args:
        mode: 'codesign' or 'struct_pred'. Sets the sampling ``energy_lambda``
            (0.5 for 'struct_pred', 0.8 otherwise).
        ckpt: path to the pickled model checkpoint.
        gpu: CUDA device index, or -1 to run on CPU.
        pdbs: receptor PDB paths, one per target.
        epitope_defs: epitope JSON paths, aligned with ``pdbs``.
        n_samples: number of samples per target, aligned with ``pdbs``.
        out_dir: output directory, created if missing. Receives one
            ``<id>_<i>.pdb`` per sample plus ``summary.jsonl`` (one JSON
            record per sample).
        lengths_range: per-target ``(lmin, lmax)`` peptide lengths (codesign).
        seqs: per-target peptide sequences (struct_pred).
        identifiers: per-target ids; defaults to the PDB file stems.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker count.
    """
    os.makedirs(out_dir, exist_ok=True)
    if identifiers is None:
        identifiers = [splitext(basename(pdb))[0] for pdb in pdbs]

    expand_ids, expand_pdbs, expand_epitopes, expand_lens, expand_seqs = _expand_inputs(
        identifiers, pdbs, epitope_defs, lengths_range, seqs, n_samples)
    if not expand_ids:
        print_log('No samples requested, nothing to design.', level='WARN')
        return
    if expand_lens[0] is None:
        expand_lens = None
    if expand_seqs[0] is None:
        expand_seqs = None

    dataset = DesignDataset(expand_pdbs, expand_epitopes, expand_lens, expand_seqs)
    # shuffle=False keeps batches in expand_ids order, which `cnt` relies on.
    dataloader = DataLoader(dataset, batch_size=batch_size,
                            num_workers=num_workers,
                            collate_fn=dataset.collate_fn,
                            shuffle=False)

    device = torch.device('cpu' if gpu == -1 else f'cuda:{gpu}')
    model = _load_model(ckpt, device)

    cnt = 0  # index of the next sample in the expanded per-sample lists
    all_pdbs = []
    with open(os.path.join(out_dir, 'summary.jsonl'), 'w') as result_summary:
        for batch in tqdm(dataloader):
            with torch.no_grad():
                for k in batch:
                    if hasattr(batch[k], 'to'):
                        batch[k] = batch[k].to(device)

                batch_X, batch_S, batch_pmetric = model.sample(
                    batch['X'], batch['S'],
                    batch['mask'], batch['position_ids'],
                    batch['lengths'], batch['atom_mask'],
                    L=batch['L'], sample_opt={
                        'energy_func': 'default',
                        'energy_lambda': 0.5 if mode == 'struct_pred' else 0.8,
                    },
                )

                for X, S, _pmetric, rec_chain2blocks in zip(batch_X, batch_S, batch_pmetric, batch['rec_chain2blocks']):
                    if S is None:
                        # No sampled sequence (presumably fixed-sequence
                        # struct_pred): reuse the input one.
                        S = expand_seqs[cnt]
                    lig_blocks = _peptide_blocks(X, S)

                    chain_names = list(rec_chain2blocks.keys())
                    list_blocks = [rec_chain2blocks[chain] for chain in chain_names]
                    pep_chain_id = _next_chain_id(chain_names)
                    list_blocks.append(lig_blocks)
                    chain_names.append(pep_chain_id)

                    out_pdb = os.path.join(out_dir, expand_ids[cnt] + '.pdb')
                    list_blocks_to_pdb(list_blocks, chain_names, out_pdb)
                    all_pdbs.append(out_pdb)
                    result_summary.write(json.dumps({
                        'id': expand_ids[cnt],
                        'rec_chains': list(rec_chain2blocks.keys()),
                        'pep_chain': pep_chain_id,
                        'pep_seq': ''.join(VOCAB.abrv_to_symbol(block.abrv) for block in lig_blocks),
                    }) + '\n')
                    # Flush per record so finished samples survive a later failure.
                    result_summary.flush()
                    cnt += 1

    print_log('Running openmm relaxation...')
    _relax_all(all_pdbs)
    print_log('Done')
|
|
|
|
def parse(argv=None):
    """Parse command-line arguments for the design script.

    Args:
        argv: list of argument strings to parse. None (the default) makes
            argparse read the process command line.

    Returns:
        argparse.Namespace. In 'struct_pred' mode ``peptide_seq`` is set. In
        'codesign' mode ``length_min`` and ``length_max`` are set and satisfy
        ``length_min < length_max``. Invalid combinations exit through
        ``parser.error`` (SystemExit).
    """
    parser = argparse.ArgumentParser(description='run pepglad for codesign or structure prediction')
    parser.add_argument('--mode', type=str, required=True, choices=['codesign', 'struct_pred'], help='Running mode')
    parser.add_argument('--pdb', type=str, required=True, help='Path to the PDB file of the target protein')
    parser.add_argument('--pocket', type=str, required=True, help='Path to the pocket definition (*.json generated by detect_pocket)')
    parser.add_argument('--n_samples', type=int, default=10, help='Number of samples')
    parser.add_argument('--out_dir', type=str, required=True, help='Output directory')
    parser.add_argument('--peptide_seq', type=str, default=None, help='Peptide sequence for structure prediction (required for struct_pred)')
    parser.add_argument('--length_min', type=int, default=None, help='Minimum peptide length for codesign (inclusive, required for codesign)')
    parser.add_argument('--length_max', type=int, default=None, help='Maximum peptide length for codesign (exclusive, required for codesign)')
    parser.add_argument('--gpu', type=int, default=0, help='GPU to use')
    args = parser.parse_args(argv)

    # Mode-dependent requirements are checked after parsing. Scanning sys.argv
    # for the mode name misses `--mode=struct_pred` and falsely triggers when
    # another argument happens to equal a mode name.
    if args.mode == 'struct_pred' and args.peptide_seq is None:
        parser.error('--peptide_seq is required when --mode is struct_pred')
    if args.mode == 'codesign':
        if args.length_min is None or args.length_max is None:
            parser.error('--length_min and --length_max are required when --mode is codesign')
        # The length is later sampled with np.random.randint(min, max), which
        # needs min < max.
        if args.length_min >= args.length_max:
            parser.error('--length_min must be smaller than --length_max (the maximum is exclusive)')
    return args
|
|
|
|
if __name__ == '__main__':
    args = parse()
    is_struct_pred = args.mode == 'struct_pred'
    is_codesign = args.mode == 'codesign'

    # Checkpoints live in <project>/checkpoints, one directory above this script.
    proj_dir = os.path.join(os.path.dirname(__file__), '..')
    ckpt_file = 'fixseq.ckpt' if is_struct_pred else 'codesign.ckpt'
    ckpt = os.path.join(proj_dir, 'checkpoints', ckpt_file)
    print_log(f'Loading checkpoint: {ckpt}')

    # Single-target run: every per-target argument is a one-element list.
    target_id = os.path.basename(os.path.splitext(args.pdb)[0])
    length_bounds = [(args.length_min, args.length_max)] if is_codesign else None
    peptide_seqs = [args.peptide_seq] if is_struct_pred else None

    design(
        mode=args.mode,
        ckpt=ckpt,
        gpu=args.gpu,
        pdbs=[args.pdb],
        epitope_defs=[args.pocket],
        n_samples=[args.n_samples],
        out_dir=args.out_dir,
        identifiers=[target_id],
        lengths_range=length_bounds,
        seqs=peptide_seqs,
    )