| """DDP inference script.""" |
| import os |
| import time |
| import numpy as np |
| import hydra |
| import torch |
| import GPUtil |
| import sys |
|
|
| from pytorch_lightning import Trainer |
| from omegaconf import DictConfig, OmegaConf |
| from experiments import utils as eu |
| from models.flow_module import FlowModule |
| import re |
| from typing import Optional |
| import subprocess |
| from biotite.sequence.io import fasta |
| from data import utils as du |
| from analysis import metrics |
| import pandas as pd |
| import esm |
| import shutil |
| import biotite.structure.io as bsio |
|
|
|
|
# Trade a little float32 matmul precision for speed (TF32 on Ampere+ GPUs);
# acceptable for inference.
torch.set_float32_matmul_precision('high')
# Module-level logger shared by the whole script.
log = eu.get_pylogger(__name__)
|
|
class Sampler:
    """Loads a FlowModule checkpoint plus ESMFold and runs DDP sampling."""

    def __init__(self, cfg: DictConfig):
        """Initialize sampler.

        Args:
            cfg: inference config. Must contain `inference.ckpt_path`,
                `inference.output_dir`, `inference.name`, and
                `inference.samples`.
        """
        ckpt_path = cfg.inference.ckpt_path

        # Allow new keys to be added to the config (inference settings are
        # attached to the loaded module below).
        OmegaConf.set_struct(cfg, False)

        self._cfg = cfg
        self._infer_cfg = cfg.inference
        self._samples_cfg = self._infer_cfg.samples

        # Output directory is keyed by the last three path components of the
        # checkpoint path (without the .ckpt suffix) plus the run name.
        self._ckpt_name = '/'.join(ckpt_path.replace('.ckpt', '').split('/')[-3:])
        self._output_dir = os.path.join(
            self._infer_cfg.output_dir,
            self._ckpt_name,
            self._infer_cfg.name,
        )
        os.makedirs(self._output_dir, exist_ok=True)
        log.info(f'Saving results to {self._output_dir}')

        # Persist the resolved config next to the results for reproducibility.
        config_path = os.path.join(self._output_dir, 'config.yaml')
        with open(config_path, 'w') as f:
            OmegaConf.save(config=self._cfg, f=f)
        log.info(f'Saving inference config to {config_path}')

        # Load the flow model in eval mode and attach the inference settings
        # it reads during predict().
        self._flow_module = FlowModule.load_from_checkpoint(
            checkpoint_path=ckpt_path,
        )
        self._flow_module.eval()
        self._flow_module._infer_cfg = self._infer_cfg
        self._flow_module._samples_cfg = self._samples_cfg
        self._flow_module._output_dir = self._output_dir

        # Compute the device list once; run_sampling() reuses it instead of
        # re-querying CUDA.
        self._devices = [torch.cuda.current_device()]

        # ESMFold, kept in eval mode on the last listed device. Presumably
        # used by the flow module's predict step for refolding sampled
        # sequences — confirm against FlowModule.
        self._folding_model = esm.pretrained.esmfold_v1().eval()
        self._folding_model = self._folding_model.to(self._devices[-1])

    def run_sampling(self):
        """Run DDP prediction over the lengths specified in the samples config."""
        devices = self._devices
        log.info(f"Using devices: {devices}")

        eval_dataset = eu.LengthDataset(self._samples_cfg)
        dataloader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=self._samples_cfg.sample_batch,
            shuffle=False,
            drop_last=False,
        )

        trainer = Trainer(
            accelerator="gpu",
            strategy="ddp",
            devices=devices,
        )
        trainer.predict(self._flow_module, dataloaders=dataloader)
| |
|
|
|
|
@hydra.main(version_base=None, config_path="../configs", config_name="inference")
def run(cfg: DictConfig) -> None:
    """Hydra entry point: build a Sampler from the config, run inference,
    and log the wall-clock time taken."""
    log.info(f'Starting inference with {cfg.inference.num_gpus} GPUs')
    t0 = time.time()
    Sampler(cfg).run_sampling()
    elapsed_time = time.time() - t0
    log.info(f'Finished in {elapsed_time:.2f}s')


if __name__ == '__main__':
    run()
|
|