| |
| |
| import os |
| import yaml |
| import argparse |
| from tqdm import tqdm |
|
|
| import torch |
|
|
| from generate import get_best_ckpt, to_device |
| from data import create_dataloader, create_dataset |
|
|
|
|
def _consecutive_latent_dists(model, batch):
    """Return latent-space L2 distances between entries at adjacent positions in one batch.

    Args:
        model: loaded model exposing ``autoencoder.encode``.
        batch: dict already moved to the model's device; must contain the keys
            'X', 'S', 'mask', 'position_ids', 'lengths' and 'atom_mask'.

    Returns:
        1-D tensor of distances, one per neighbouring pair whose position ids
        differ by exactly one.
    """
    _, Z, _, _ = model.autoencoder.encode(
        batch['X'], batch['S'], batch['mask'], batch['position_ids'],
        batch['lengths'], batch['atom_mask'], no_randomness=True
    )
    # Position ids of the masked-in entries. The filter below assumes these
    # line up one-to-one with the rows of Z -- TODO confirm against encode().
    pos = batch['position_ids'][batch['mask']]
    # Assumes Z is (N, 1, latent_dim) -- TODO confirm; squeeze gives (N, latent_dim).
    Z = Z.squeeze(1)
    dists = torch.norm(Z[1:] - Z[:-1], dim=-1)
    # Rows that are adjacent in the flattened batch may not be adjacent in
    # position (presumably a boundary between sequences, or a gap), so keep
    # only pairs whose position ids differ by exactly one.
    return dists[(pos[1:] - pos[:-1]) == 1]


def main(args):
    """Measure consecutive-position latent distances and store their mean/std in the checkpoint.

    Resolves and loads the model checkpoint and builds the test split from
    the YAML config. Each batch is encoded deterministically
    (``no_randomness=True``), and the Euclidean distance is collected between
    latent vectors whose position ids differ by one. The mean and standard
    deviation of those distances are stored via ``model.set_consec_dist``,
    and the model is saved back over the same checkpoint file.

    Args:
        args: namespace with ``config`` (path to a YAML config that has
            'dataset' and 'dataloader' sections), ``ckpt`` (a path ending in
            '.ckpt', otherwise resolved with ``get_best_ckpt``) and ``gpu``
            (CUDA device index, or -1 to run on CPU).

    Raises:
        RuntimeError: if fewer than two distances were collected. Mean/std
            would then be NaN, and the checkpoint would be overwritten with
            them.
    """
    # Use a context manager so the config file handle is closed promptly.
    with open(args.config, 'r') as config_file:
        config = yaml.safe_load(config_file)

    b_ckpt = args.ckpt if args.ckpt.endswith('.ckpt') else get_best_ckpt(args.ckpt)
    print(f'Using checkpoint {b_ckpt}')
    # NOTE(review): this unpickles an entire model object, so only use it on
    # trusted checkpoints. On torch versions that default to
    # weights_only=True, this call may need weights_only=False.
    model = torch.load(b_ckpt, map_location='cpu')
    device = torch.device('cpu' if args.gpu == -1 else f'cuda:{args.gpu}')
    model.to(device)
    model.eval()

    _, _, test_set = create_dataset(config['dataset'])
    test_loader = create_dataloader(test_set, config['dataloader'])

    per_batch = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            batch = to_device(batch, device)
            per_batch.append(_consecutive_latent_dists(model, batch))

    # torch.cat raises on an empty list (empty loader), so fall back to an empty tensor.
    all_dists = torch.cat(per_batch, dim=0) if per_batch else torch.empty(0)
    if all_dists.numel() < 2:
        # Check before set_consec_dist/torch.save so NaN statistics are never
        # written into the checkpoint (which is overwritten in place).
        raise RuntimeError(
            f'Only {all_dists.numel()} consecutive latent distance(s) were collected '
            f'from the test set; at least 2 are needed to estimate mean and std.'
        )
    mean, std = torch.mean(all_dists), torch.std(all_dists)
    print(mean, std)
    model.set_consec_dist(mean.item(), std.item())
    torch.save(model, b_ckpt)
|
|
def parse():
    """Build this script's command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with ``config`` (str, required), ``ckpt`` (str,
        required) and ``gpu`` (int, defaults to 0; ``main`` treats -1 as CPU).
    """
    cli = argparse.ArgumentParser(
        description='Calculate distance between consecutive latent points'
    )
    # Both path-like options are mandatory strings.
    for flag in ('--config', '--ckpt'):
        cli.add_argument(flag, type=str, required=True)
    cli.add_argument('--gpu', type=int, default=0)
    parsed = cli.parse_args()
    return parsed
|
|
|
|
if __name__ == '__main__':
    # Script entry point: parse the CLI first, then run the computation.
    cli_args = parse()
    main(cli_args)