| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
|
|
| import copy |
| import math |
| from tqdm.auto import tqdm |
| import functools |
| import os |
| import argparse |
| import pandas as pd |
| from copy import deepcopy |
|
|
| from models_con.pep_dataloader import PepDataset |
|
|
| from pepflow.utils.train import recursive_to |
|
|
| from pepflow.modules.common.geometry import reconstruct_backbone, reconstruct_backbone_partially, align, batch_align |
| from pepflow.modules.protein.writers import save_pdb |
|
|
| from pepflow.utils.data import PaddingCollate |
|
|
| from models_con.utils import process_dic |
|
|
| from models_con.flow_model import FlowModel |
|
|
| from models_con.torsion import full_atom_reconstruction, get_heavyatom_mask |
|
|
# Module-level collate function that turns a list of per-item feature dicts into a
# padded batch (used by item_to_batch below).
# NOTE(review): `eight=False` presumably disables rounding the padded length up to a
# multiple of 8 — confirm against PaddingCollate.
collate_fn = PaddingCollate(eight=False)
|
|
| import argparse |
|
|
|
|
def item_to_batch(item, nums=32):
    """Build a batch made of ``nums`` deep copies of a single data item.

    Args:
        item: One uncollated data item.
        nums: How many replicas of ``item`` to put in the batch.

    Returns:
        The result of the module-level ``collate_fn`` applied to the replicas.
    """
    return collate_fn([deepcopy(item) for _ in range(nums)])
|
|
def sample_for_data_bb(data, model, device, save_root, num_steps=200, sample_structure=True, sample_sequence=True, nums=8):
    """Draw ``nums`` backbone samples for one data item and write them as PDB files.

    The item is replicated ``nums`` times and passed to ``model.sample``. The last
    frame of the returned trajectory is converted to backbone coordinates.
    Residues flagged in ``generate_mask`` take the sampled backbone atoms; the
    other atom slots of those residues are zeroed and masked out. All other
    residues keep their input atoms and mask.

    Output files:
        ``<save_root>/<id>/<id>_<i>.pdb`` for each sample ``i``, and
        ``<save_root>/<id>/<id>_gt.pdb`` for the unmodified input item.

    Args:
        data: A single uncollated data item. It is read for ``id``, ``chain_nb``,
            ``chain_id``, ``resseq`` and ``icode`` here, and passed whole to
            ``item_to_batch`` and ``save_pdb``.
        model: Object whose ``sample(batch, num_steps=..., sample_structure=...,
            sample_sequence=...)`` returns a trajectory. Its last element must be a
            dict with ``rotmats``, ``trans`` and ``seqs``.
        device: Device the batch and the final frame are moved to.
        save_root: Root output directory.
        num_steps: Forwarded to ``model.sample``.
        sample_structure: Forwarded to ``model.sample``.
        sample_sequence: Forwarded to ``model.sample``.
        nums: Number of samples to draw.
    """
    out_dir = os.path.join(save_root, data["id"])
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(out_dir, exist_ok=True)

    batch = recursive_to(item_to_batch(data, nums=nums), device=device)
    traj = model.sample(batch, num_steps=num_steps, sample_structure=sample_structure, sample_sequence=sample_sequence)
    final = recursive_to(traj[-1], device=device)

    pos_bb = reconstruct_backbone(
        R=final['rotmats'], t=final['trans'], aa=final['seqs'],
        chain_nb=batch['chain_nb'], res_nb=batch['res_nb'], mask=batch['res_mask'],
    )
    # Zero-pad the backbone atom axis (dim -2) up to the reference atom count so the
    # shapes line up with pos_heavyatom in torch.where below.
    n_pad = batch['pos_heavyatom'].shape[-2] - pos_bb.shape[-2]
    pos_ha = F.pad(pos_bb, pad=(0, 0, 0, n_pad), value=0.)

    gen_mask = batch['generate_mask']
    pos_new = torch.where(gen_mask[:, :, None, None], pos_ha, batch['pos_heavyatom'])
    # Designed residues expose only the first 4 (backbone) atom slots.
    mask_bb_atoms = torch.zeros_like(batch['mask_heavyatom'])
    mask_bb_atoms[:, :, :4] = True
    mask_new = torch.where(gen_mask[:, :, None], mask_bb_atoms, batch['mask_heavyatom'])
    aa_new = final['seqs']

    for i in range(nums):
        data_saved = {
            'chain_nb': data['chain_nb'], 'chain_id': data['chain_id'], 'resseq': data['resseq'], 'icode': data['icode'],
            'aa': aa_new[i].cpu(), 'mask_heavyatom': mask_new[i].cpu(), 'pos_heavyatom': pos_new[i].cpu(),
        }
        save_pdb(data_saved, path=os.path.join(out_dir, f'{data["id"]}_{i}.pdb'))
    save_pdb(data, path=os.path.join(out_dir, f'{data["id"]}_gt.pdb'))
|
|
def save_samples_bb(samples, save_dir):
    """Write backbone-only samples plus the reference structure as PDB files.

    Residues flagged in ``batch['generate_mask']`` get the backbone rebuilt from
    the sampled frames and sequence. Only their first 4 atom slots are marked
    present. All other residues keep the reference atoms and mask.

    Output files: ``<save_dir>/sample_<i>.pdb`` for every batch entry ``i``, and
    ``<save_dir>/gt.pdb`` built from batch entry 0.

    Args:
        samples: Dict with ``batch`` (collated input batch), ``rotmats``, ``trans``
            and ``seqs``.
        save_dir: Existing directory to write the PDB files into.
    """
    # Move everything, not only the batch, to CPU. Otherwise GPU-resident sampled
    # tensors would be mixed with the CPU batch in torch.where (device mismatch).
    samples = recursive_to(samples, 'cpu')
    batch = samples['batch']
    # zip(*...) transposes the collated chain ids and [0] takes the first row,
    # presumably sample 0's per-residue chain ids (assumes the collate stored them
    # residue-major — confirm).
    chain_id = [list(item) for item in zip(*batch['chain_id'])][0]
    icode = [' ' for _ in range(len(chain_id))]
    nums = len(batch['id'])

    pos_bb = reconstruct_backbone(
        R=samples['rotmats'], t=samples['trans'], aa=samples['seqs'],
        chain_nb=batch['chain_nb'], res_nb=batch['res_nb'], mask=batch['res_mask'],
    )
    # Zero-pad the backbone atom axis (dim -2) up to the reference atom count.
    n_pad = batch['pos_heavyatom'].shape[-2] - pos_bb.shape[-2]
    pos_ha = F.pad(pos_bb, pad=(0, 0, 0, n_pad), value=0.)

    gen_mask = batch['generate_mask']
    pos_new = torch.where(gen_mask[:, :, None, None], pos_ha, batch['pos_heavyatom'])
    mask_bb_atoms = torch.zeros_like(batch['mask_heavyatom'])
    mask_bb_atoms[:, :, :4] = True
    mask_new = torch.where(gen_mask[:, :, None], mask_bb_atoms, batch['mask_heavyatom'])
    aa_new = samples['seqs']

    for i in range(nums):
        data_saved = {
            'chain_nb': batch['chain_nb'][0], 'chain_id': chain_id, 'resseq': batch['resseq'][0], 'icode': icode,
            'aa': aa_new[i], 'mask_heavyatom': mask_new[i], 'pos_heavyatom': pos_new[i],
        }
        save_pdb(data_saved, path=os.path.join(save_dir, f'sample_{i}.pdb'))
    data_saved = {
        'chain_nb': batch['chain_nb'][0], 'chain_id': chain_id, 'resseq': batch['resseq'][0], 'icode': icode,
        'aa': batch['aa'][0], 'mask_heavyatom': batch['mask_heavyatom'][0], 'pos_heavyatom': batch['pos_heavyatom'][0],
    }
    save_pdb(data_saved, path=os.path.join(save_dir, 'gt.pdb'))
|
|
def save_samples_sc(samples, save_dir):
    """Write full-atom (side-chain) samples plus the reference structure as PDB files.

    Residues flagged in ``batch['generate_mask']`` get all heavy atoms rebuilt from
    the sampled frames, torsion angles and sequence. Their atom mask is the one
    implied by the sampled sequence. All other residues keep the reference atoms
    AND the reference mask, so atoms missing from the input are not written out
    as spurious atoms.

    Output files: ``<save_dir>/sample_<i>.pdb`` for every batch entry ``i``, and
    ``<save_dir>/gt.pdb`` built from batch entry 0.

    Args:
        samples: Dict with ``batch`` (collated input batch), ``rotmats``, ``trans``,
            ``angles`` and ``seqs``.
        save_dir: Existing directory to write the PDB files into.
    """
    # Move everything, not only the batch, to CPU to avoid device mismatches
    # between sampled tensors and the batch.
    samples = recursive_to(samples, 'cpu')
    batch = samples['batch']
    # zip(*...) transposes the collated chain ids and [0] takes the first row,
    # presumably sample 0's per-residue chain ids (assumes the collate stored them
    # residue-major — confirm).
    chain_id = [list(item) for item in zip(*batch['chain_id'])][0]
    icode = [' ' for _ in range(len(chain_id))]
    nums = len(batch['id'])
    seqs = samples['seqs']
    gen_mask = batch['generate_mask']

    pos_ha, _, _ = full_atom_reconstruction(
        R_bb=samples['rotmats'], t_bb=samples['trans'], angles=samples['angles'], aa=seqs,
    )
    # Zero-pad the reconstructed atom axis (dim -2) up to the reference atom count.
    n_pad = batch['pos_heavyatom'].shape[-2] - pos_ha.shape[-2]
    pos_ha = F.pad(pos_ha, pad=(0, 0, 0, n_pad), value=0.)
    pos_new = torch.where(gen_mask[:, :, None, None], pos_ha, batch['pos_heavyatom'])

    # Copy the sequence-derived atom mask into a tensor shaped and typed like the
    # reference mask. Atom slots it does not cover stay 0/False.
    mask_pred = get_heavyatom_mask(seqs)
    mask_gen = torch.zeros_like(batch['mask_heavyatom'])
    n_atoms = min(mask_pred.shape[-1], mask_gen.shape[-1])
    mask_gen[..., :n_atoms] = mask_pred[..., :n_atoms].to(mask_gen.dtype)
    # Only designed residues use the predicted mask; context residues keep the
    # reference mask that matches the reference coordinates copied above.
    mask_new = torch.where(gen_mask[:, :, None], mask_gen, batch['mask_heavyatom'])
    aa_new = seqs

    for i in range(nums):
        data_saved = {
            'chain_nb': batch['chain_nb'][0], 'chain_id': chain_id, 'resseq': batch['resseq'][0], 'icode': icode,
            'aa': aa_new[i], 'mask_heavyatom': mask_new[i], 'pos_heavyatom': pos_new[i],
        }
        save_pdb(data_saved, path=os.path.join(save_dir, f'sample_{i}.pdb'))
    data_saved = {
        'chain_nb': batch['chain_nb'][0], 'chain_id': chain_id, 'resseq': batch['resseq'][0], 'icode': icode,
        'aa': batch['aa'][0], 'mask_heavyatom': batch['mask_heavyatom'][0], 'pos_heavyatom': batch['pos_heavyatom'][0],
    }
    save_pdb(data_saved, path=os.path.join(save_dir, 'gt.pdb'))
|
|
if __name__ == '__main__':
    # Convert every saved sampling output <SAMPLEDIR>/outputs/<name>.pt into
    # full-atom PDB files under <SAMPLEDIR>/pdbs/<name>/.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--SAMPLEDIR', type=str, required=True)
    cli_args = arg_parser.parse_args()
    SAMPLE_DIR = cli_args.SAMPLEDIR
    output_dir = os.path.join(SAMPLE_DIR, 'outputs')

    # Only .pt files are sample outputs; splitext keeps any dots inside the stem
    # intact, so the reconstructed filename below matches the real file.
    names = sorted(
        os.path.splitext(fname)[0]
        for fname in os.listdir(output_dir)
        if fname.endswith('.pt')
    )
    for name in tqdm(names):
        # torch.load unpickles the file: only run this on trusted, locally produced
        # outputs. map_location='cpu' lets GPU-saved samples load on CPU-only hosts.
        sample = torch.load(os.path.join(output_dir, f'{name}.pt'), map_location='cpu')
        pdb_dir = os.path.join(SAMPLE_DIR, 'pdbs', name)
        os.makedirs(pdb_dir, exist_ok=True)
        save_samples_sc(sample, pdb_dir)