| |
| |
| |
| |
|
|
| import argparse |
| import os |
| import re |
| import time |
| from pathlib import Path |
|
|
| import torch |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
|
|
| from models.vocoders.vocoder_inference import synthesis |
| from torch.utils.data import DataLoader |
| from utils.util import set_all_random_seed |
| from utils.util import load_config |
|
|
|
|
def parse_vocoder(vocoder_dir):
    r"""Pick the highest-numbered vocoder checkpoint in a directory and load its config.

    Checkpoints are the ``*.pt`` files located directly inside ``vocoder_dir``.
    Each file stem is converted with ``int`` and the file with the largest
    value is selected. The configuration is read from ``args.json`` in the
    same directory.

    Args:
        vocoder_dir: Directory containing the ``*.pt`` checkpoints and
            ``args.json``. Relative paths are made absolute first.

    Returns:
        A ``(vocoder_cfg, ckpt_path)`` tuple: the loaded configuration object
        and the absolute path of the selected checkpoint as a ``str``.

    Raises:
        FileNotFoundError: If ``vocoder_dir`` contains no ``*.pt`` file.
        ValueError: If the stem of any ``*.pt`` file is not an integer.
    """
    vocoder_dir = os.path.abspath(vocoder_dir)
    ckpt_list = list(Path(vocoder_dir).glob("*.pt"))
    if not ckpt_list:
        # Fail with an actionable message instead of an IndexError on [0].
        raise FileNotFoundError(
            "No '*.pt' checkpoint found in vocoder directory: {}".format(vocoder_dir)
        )

    # max() returns the first maximal element, which is the same file the
    # original stable descending sort placed at index 0.
    newest_ckpt = max(ckpt_list, key=lambda ckpt: int(ckpt.stem))
    ckpt_path = str(newest_ckpt)

    vocoder_cfg = load_config(os.path.join(vocoder_dir, "args.json"), lowercase=True)
    # Make the vocoder section reachable as ``model.bigvgan`` too.
    vocoder_cfg.model.bigvgan = vocoder_cfg.vocoder
    return vocoder_cfg, ckpt_path
|
|
|
|
class BaseInference(object):
    """Common scaffolding for acoustic-model inference.

    Construction seeds the random generators, selects a device, builds the
    model through ``create_model``, restores its weights through
    ``load_state_dict`` and ``load_model``, and, when a vocoder checkpoint is
    given on the command line, records the vocoder configuration.

    The hooks ``create_model``, ``load_model``, ``build_test_dataset``,
    ``build_test_utt_data``, ``inference_each_batch`` and ``inference`` only
    raise ``NotImplementedError`` here and must be provided by subclasses.
    """

    def __init__(self, cfg, args):
        """Set up the device, the model and (optionally) the vocoder info.

        Args:
            cfg: Configuration object; ``cfg.model_type`` is read here, other
                fields are read by the remaining methods.
            args: Parsed command-line namespace (see ``base_parser``).
        """
        self.cfg = cfg
        self.args = args
        self.model_type = cfg.model_type
        # One real-time-factor entry is appended per ``__call__``.
        self.avg_rtf = list()
        set_all_random_seed(10086)
        # ``--output_dir`` defaults to None; os.makedirs(None) would raise
        # TypeError, so only create the directory when one was given.
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
            torch.set_num_threads(10)

        # Build the acoustic model and restore its weights.
        self.model = self.create_model().to(self.device)
        state_dict = self.load_state_dict()
        self.load_model(state_dict)
        self.model.eval()

        # Vocoder information is only collected when a checkpoint was given.
        if self.args.checkpoint_dir_vocoder is not None:
            self.get_vocoder_info()

    def create_model(self):
        """Return the model instance to run inference with (subclass hook)."""
        raise NotImplementedError

    def load_state_dict(self):
        """Load the acoustic-model checkpoint from disk.

        When ``args.checkpoint_file`` is not set, the file name is taken from
        the last line of the ``checkpoint`` index file inside
        ``args.checkpoint_dir``.

        Sets ``self.checkpoint_file``, ``self.checkpoint_dir`` and
        ``self.am_restore_step`` (the text between ``step-`` and ``_loss`` in
        the checkpoint path).

        Returns:
            Whatever ``torch.load`` yields for the checkpoint file.

        Raises:
            ValueError: If neither ``checkpoint_file`` nor ``checkpoint_dir``
                was provided.
            IndexError: If the index file is empty or the checkpoint path does
                not contain ``step-<...>_loss``.
        """
        self.checkpoint_file = self.args.checkpoint_file
        if self.checkpoint_file is None:
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if self.args.checkpoint_dir is None:
                raise ValueError(
                    "Either checkpoint_file or checkpoint_dir must be provided"
                )
            checkpoint_path = os.path.join(self.args.checkpoint_dir, "checkpoint")
            # Context manager so the index file handle is closed promptly.
            with open(checkpoint_path) as index_file:
                checkpoint_filename = index_file.readlines()[-1].strip()
            self.checkpoint_file = os.path.join(
                self.args.checkpoint_dir, checkpoint_filename
            )

        self.checkpoint_dir = os.path.split(self.checkpoint_file)[0]

        print("Restore acoustic model from {}".format(self.checkpoint_file))
        raw_state_dict = torch.load(self.checkpoint_file, map_location=self.device)
        self.am_restore_step = re.findall(r"step-(.+?)_loss", self.checkpoint_file)[0]

        return raw_state_dict

    def load_model(self, model):
        """Apply a loaded checkpoint to ``self.model`` (subclass hook).

        Args:
            model: The object returned by ``load_state_dict``.
        """
        raise NotImplementedError

    def get_vocoder_info(self):
        """Record vocoder checkpoint location, config and naming details.

        Reads ``args.json`` from the parent directory of
        ``args.checkpoint_dir_vocoder`` and stores the parsed result in
        ``self.cfg.vocoder``. ``self.vocoder_tag`` and ``self.vocoder_steps``
        are derived from the last two ``/``-separated path components.
        """
        self.checkpoint_dir_vocoder = self.args.checkpoint_dir_vocoder
        # NOTE(review): ``self.vocoder_cfg`` is the *path* of args.json; the
        # parsed config lives in ``self.cfg.vocoder``.
        self.vocoder_cfg = os.path.join(
            os.path.dirname(self.checkpoint_dir_vocoder), "args.json"
        )
        self.cfg.vocoder = load_config(self.vocoder_cfg, lowercase=True)
        self.vocoder_tag = self.checkpoint_dir_vocoder.split("/")[-2].split(":")[-1]
        self.vocoder_steps = self.checkpoint_dir_vocoder.split("/")[-1].split(".")[0]

    def build_test_dataset(self):
        """Return the ``(dataset, collate)`` factories (subclass hook).

        ``build_testdata_loader`` calls the first as
        ``dataset(cfg, args, target_speaker)`` and the second as
        ``collate(cfg)``.
        """
        raise NotImplementedError

    def build_test_utt_data(self, utt=None):
        """Turn a single utterance into model input features (subclass hook).

        Args:
            utt: The utterance handed to ``__call__``.
        """
        raise NotImplementedError

    def build_testdata_loader(self, args, target_speaker=None):
        """Create the ``DataLoader`` over the test set.

        The batch size is the smaller of ``cfg.train.batch_size`` and the
        number of entries in the dataset's ``metadata``. It is stored in
        ``self.test_batch_size``.

        Args:
            args: Forwarded to the dataset factory.
            target_speaker: Forwarded to the dataset factory.

        Returns:
            A non-shuffling ``DataLoader`` over ``self.test_dataset``.
        """
        datasets, collate = self.build_test_dataset()
        self.test_dataset = datasets(self.cfg, args, target_speaker)
        self.test_collate = collate(self.cfg)
        self.test_batch_size = min(
            self.cfg.train.batch_size, len(self.test_dataset.metadata)
        )
        test_loader = DataLoader(
            self.test_dataset,
            collate_fn=self.test_collate,
            num_workers=self.args.num_workers,
            batch_size=self.test_batch_size,
            shuffle=False,
        )
        return test_loader

    def inference_each_batch(self, batch_data):
        """Run the model on one batch (subclass hook).

        Must return a ``(predictions, stats)`` pair; ``predictions`` is
        appended element-wise to the result of ``inference_for_batches``.
        """
        raise NotImplementedError

    def inference_for_batches(self, args, target_speaker=None):
        """Run inference over the whole test set.

        Args:
            args: Forwarded to ``build_testdata_loader``.
            target_speaker: Forwarded to ``build_testdata_loader``.

        Returns:
            A list with the predictions of every batch, in loader order.
        """
        loader = self.build_testdata_loader(args, target_speaker)

        n_batch = len(loader)
        now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
        print(
            "Model eval time: {}, batch_size = {}, n_batch = {}".format(
                now, self.test_batch_size, n_batch
            )
        )
        self.model.eval()

        # A progress bar is only shown when there is more than one batch.
        batches = loader if n_batch == 1 else tqdm(loader)

        pred_res = []
        with torch.no_grad():
            for batch_data in batches:
                # Move every entry of the batch to the inference device.
                for key, value in batch_data.items():
                    batch_data[key] = value.to(self.device)

                y_pred, _ = self.inference_each_batch(batch_data)
                pred_res += y_pred

        return pred_res

    def inference(self, feature):
        """Run the model on the features of one utterance (subclass hook)."""
        raise NotImplementedError

    def synthesis_by_vocoder(self, pred):
        """Pass predictions to the vocoder's ``synthesis`` and return its result.

        Requires ``get_vocoder_info`` to have run so that ``vocoder_cfg`` and
        ``checkpoint_dir_vocoder`` exist.

        Args:
            pred: Sized collection of predictions; its length is passed along.
        """
        # NOTE(review): ``self.vocoder_cfg`` is a path string here, not the
        # parsed config - confirm that is what ``synthesis`` expects.
        audios_pred = synthesis(
            self.vocoder_cfg,
            self.checkpoint_dir_vocoder,
            len(pred),
            pred,
        )
        return audios_pred

    def __call__(self, utt):
        """Run inference on one utterance and report timing.

        The real-time factor is the elapsed time divided by
        ``outputs.shape[1] * hop_size / sample_rate``; it is printed and
        appended to ``self.avg_rtf``.

        Args:
            utt: Utterance passed to ``build_test_utt_data``.

        Returns:
            The first output of ``inference`` as a NumPy array reshaped to
            ``(-1, 1)``.
        """
        feature = self.build_test_utt_data(utt)
        start_time = time.time()
        with torch.no_grad():
            outputs = self.inference(feature)[0]
        time_used = time.time() - start_time
        # presumably axis 1 of ``outputs`` counts frames - TODO confirm
        rtf = time_used / (
            outputs.shape[1]
            * self.cfg.preprocess.hop_size
            / self.cfg.preprocess.sample_rate
        )
        print("Time used: {:.3f}, RTF: {:.4f}".format(time_used, rtf))
        self.avg_rtf.append(rtf)
        audios = outputs.cpu().squeeze().numpy().reshape(-1, 1)
        return audios
|
|
|
|
def _str2bool(value):
    """Convert a command-line token such as ``"true"`` or ``"0"`` to a bool.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    if token in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value)
    )


def base_parser():
    """Build the argument parser shared by the inference entry points.

    Returns:
        An ``argparse.ArgumentParser`` with the configuration, checkpoint,
        data-loading and output options registered.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", default="config.json", help="json files for configurations."
    )
    # Without a converter, "--use_ddp_inference False" produced the truthy
    # string "False". Accept an explicit boolean value or the bare flag.
    parser.add_argument(
        "--use_ddp_inference",
        nargs="?",
        const=True,
        default=False,
        type=_str2bool,
    )
    parser.add_argument("--n_workers", default=1, type=int)
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument(
        "--batch_size", default=1, type=int, help="Batch size for inference"
    )
    parser.add_argument(
        "--num_workers",
        default=1,
        type=int,
        help="Worker number for inference dataloader",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default=None,
        help="Checkpoint dir including model file and configuration",
    )
    parser.add_argument(
        "--checkpoint_file", help="checkpoint file", type=str, default=None
    )
    parser.add_argument(
        "--test_list", help="test utterance list for testing", type=str, default=None
    )
    parser.add_argument(
        "--checkpoint_dir_vocoder",
        help="Vocoder's checkpoint dir including model file and configuration",
        type=str,
        default=None,
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Output dir for saving generated results",
    )
    return parser
|
|
|
|
if __name__ == "__main__":
    # Read the command-line options, then load the configuration they name.
    args = base_parser().parse_args()
    cfg = load_config(args.config)

    # NOTE(review): BaseInference leaves create_model unimplemented and
    # __call__ requires an ``utt`` argument, so this looks like a template
    # for subclasses rather than a runnable entry point - confirm.
    inference = BaseInference(cfg, args)
    inference()
|
|