# PepGLAD / setup_latent_guidance.py
# Committed by Irwiny123
# Commit 52007f8: "Add initial PepGLAD code"
#!/usr/bin/python
# -*- coding:utf-8 -*-
import os
import yaml
import argparse
from tqdm import tqdm
import torch
from generate import get_best_ckpt, to_device
from data import create_dataloader, create_dataset
def _consecutive_latent_dists(Z, pos):
    """Return latent distances between sequence-adjacent residues of one batch.

    Args:
        Z: latent coordinates returned by ``autoencoder.encode``; dim 1 is
            squeezed away before use (presumably shaped [N, 1, d] -- TODO
            confirm against the autoencoder).
        pos: position ids of the same N residues, in the same order as ``Z``.

    Returns:
        1-D tensor of ``||Z[i + 1] - Z[i]||`` for every i whose position id
        increases by exactly one.
    """
    Z = Z.squeeze(1)
    dists = torch.norm(Z[1:] - Z[:-1], dim=-1)  # [N - 1]
    # Rows of the flattened batch that are neighbours in memory may come from
    # different samples; keeping only pairs whose position id steps by exactly
    # one is what filters those boundary pairs out.
    # NOTE(review): this relies on position ids restarting per sample; a
    # sample whose first id happens to be previous-last-id + 1 would still
    # slip through -- confirm against the dataset's position_ids convention.
    return dists[(pos[1:] - pos[:-1]) == 1]


def main(args):
    """Estimate latent spacing of consecutive residues and store it in the model.

    Deterministically encodes the test split with ``model.autoencoder``,
    collects latent distances between sequence-adjacent residues, passes
    their mean/std to ``model.set_consec_dist`` and saves the model back over
    the checkpoint it was loaded from.

    Args:
        args: parsed CLI namespace with ``config`` (YAML config path),
            ``ckpt`` (a ``.ckpt`` file, or a path ``get_best_ckpt`` resolves)
            and ``gpu`` (CUDA device index, -1 for CPU).

    Raises:
        ValueError: if fewer than two consecutive pairs are found; the
            statistics would be NaN, so the checkpoint is left untouched.
    """
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)

    # load model
    b_ckpt = args.ckpt if args.ckpt.endswith('.ckpt') else get_best_ckpt(args.ckpt)
    print(f'Using checkpoint {b_ckpt}')
    # NOTE: this unpickles a whole model object, so only run it on trusted
    # checkpoints. Newer torch releases default torch.load to
    # weights_only=True, which may reject such files -- confirm against the
    # installed torch version.
    model = torch.load(b_ckpt, map_location='cpu')
    device = torch.device('cpu' if args.gpu == -1 else f'cuda:{args.gpu}')
    model.to(device)
    model.eval()

    # load data
    _, _, test_set = create_dataset(config['dataset'])
    test_loader = create_dataloader(test_set, config['dataloader'])

    all_dists = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            batch = to_device(batch, device)
            _, Z, _, _ = model.autoencoder.encode(
                batch['X'], batch['S'], batch['mask'], batch['position_ids'],
                batch['lengths'], batch['atom_mask'], no_randomness=True
            )
            pos = batch['position_ids'][batch['mask']]
            all_dists.append(_consecutive_latent_dists(Z, pos))

    # torch.cat on an empty list raises, and torch.std of fewer than two
    # values is NaN; refuse rather than writing NaN statistics into the
    # checkpoint.
    n_pairs = sum(d.numel() for d in all_dists)
    if n_pairs < 2:
        raise ValueError(
            f'Found only {n_pairs} consecutive latent pairs in the test set; '
            f'need at least 2 to estimate mean/std. Checkpoint {b_ckpt} was not modified.'
        )

    all_dists = torch.cat(all_dists, dim=0)
    mean, std = torch.mean(all_dists), torch.std(all_dists)
    print(mean, std)
    model.set_consec_dist(mean.item(), std.item())

    # Save from CPU so the checkpoint does not embed the CUDA device used
    # here, and write via a temporary file + os.replace so an interrupted
    # save cannot leave a truncated checkpoint in place of the original.
    model.to('cpu')
    tmp_path = b_ckpt + '.tmp'
    torch.save(model, tmp_path)
    os.replace(tmp_path, b_ckpt)
def parse(argv=None):
    """Parse command-line options for this script.

    Args:
        argv: optional list of argument strings; ``None`` (the default)
            makes argparse read ``sys.argv[1:]`` as before.

    Returns:
        ``argparse.Namespace`` with ``config``, ``ckpt`` and ``gpu``.
    """
    parser = argparse.ArgumentParser(description='Calculate distance between consecutive latent points')
    parser.add_argument('--config', type=str, required=True,
                        help='Path to the YAML config providing the dataset and dataloader settings')
    parser.add_argument('--ckpt', type=str, required=True,
                        help='Checkpoint file ending in .ckpt, or a path from which the best checkpoint is selected')
    parser.add_argument('--gpu', type=int, default=0,
                        help='CUDA device index to use; -1 runs on CPU')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: read CLI options, then run the computation.
    cli_args = parse()
    main(cli_args)