Spaces:
Running on Zero
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchaudio | |
| from safetensors.torch import load_model | |
| from text import cleaned_text_to_sequence | |
| from text.phonemize import phonemize | |
| from .model import GibbsTTS_Model | |
| from amphion_utils import amphion_codec | |
class AudioResampler:
    """Resample waveforms to a fixed target sample rate.

    One ``torchaudio.transforms.Resample`` module is built lazily per distinct
    source rate, moved to ``device``, and reused on later calls.

    Args:
        device: device the cached resampling modules are moved to.
        target_sr: output sample rate in Hz (default 24000).
    """

    def __init__(self, device, target_sr=24000):
        self.device = device
        self.target_sr = target_sr
        # Cache: source sample rate -> Resample module targeting ``target_sr``.
        self.resamplers = {}

    def __call__(self, wav, sr):
        """Return ``wav`` converted from rate ``sr`` to ``self.target_sr``.

        When ``sr`` already equals the target rate the input object itself is
        returned, with no copy and no device transfer.
        """
        if sr == self.target_sr:
            return wav
        try:
            resample = self.resamplers[sr]
        except KeyError:
            resample = torchaudio.transforms.Resample(
                orig_freq=sr,
                new_freq=self.target_sr,
            ).to(self.device)
            self.resamplers[sr] = resample
        return resample(wav)
class GibbsTTS(nn.Module):
    """Inference wrapper combining ``GibbsTTS_Model`` with the Amphion codec.

    Construction loads the checkpoint and puts the model in eval mode. Use
    :meth:`synthesize` for prompt-conditioned text-to-speech.
    """

    # Reference speaking rate per language, in codec frames per phone id.
    # The rate measured on the prompt is clamped to
    # [_RATE_MIN_SCALE * base, _RATE_MAX_SCALE * base]. Languages missing from
    # this table are left unclamped, matching the original behaviour.
    _BASE_FRAMES_PER_PHONE = {"en": 3.224, "zh": 3.286, "mixed": 3.255}
    _RATE_MIN_SCALE = 0.8
    _RATE_MAX_SCALE = 1.25

    def __init__(self, configs):
        """Build the model, load its weights and set up codec and resampler.

        Args:
            configs: config object. This class reads ``infer_ckpt_path``,
                ``steps``, ``temperature``, ``top_p``, ``rescale_cfg`` and
                ``cfg`` from it, and also passes it to ``GibbsTTS_Model`` and
                ``amphion_codec``.
        """
        super().__init__()
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.configs = configs
        self.model = GibbsTTS_Model(configs)
        # Weights come from a safetensors file; they are loaded before the move to `device`.
        load_model(self.model, configs.infer_ckpt_path)
        self.model = self.model.to(device)
        self.model.eval()
        self.codec = amphion_codec(configs, device)
        # Prompt audio is resampled to 24 kHz (presumably the codec's native rate — confirm).
        self.resampler = AudioResampler(device, target_sr=24000)
        self.language_dict = {"en": 0, "zh": 1, "mixed": 2}
        # Phone id sequence for a single space, placed between prompt and target phones.
        self.space_id = cleaned_text_to_sequence([" "])

    @classmethod
    def _estimate_target_length(cls, n_prompt_frames, n_prompt_phones, n_target_phones, language):
        """Predict how many codec frames to generate for the target phones.

        The speaking rate (frames per phone) is measured on the prompt. It is
        clamped around the per-language reference rate when one is known, then
        scaled by the target phone count and truncated to an int.
        """
        ratio = n_prompt_frames / n_prompt_phones
        base = cls._BASE_FRAMES_PER_PHONE.get(language)
        if base is not None:
            ratio = max(base * cls._RATE_MIN_SCALE, min(ratio, base * cls._RATE_MAX_SCALE))
        return int(n_target_phones * ratio)

    def _load_prompt_wav(self, prompt_audio):
        """Load ``prompt_audio``, downmix it to mono and resample it on ``self.device``.

        Returns a tensor of shape (1, 1, T) at the resampler's target rate.
        """
        wav, sr = torchaudio.load(prompt_audio)  # (channels, frames)
        if wav.shape[0] > 1:
            # Downmix before resampling. Resampling is linear, so the result is
            # equivalent, but only one channel has to be resampled.
            wav = wav.mean(dim=0, keepdim=True)
        return self.resampler(wav.to(self.device), sr).unsqueeze(0)

    @torch.no_grad()
    def synthesize(self, prompt_audio, prompt_text, target_text, language):
        """Generate speech for ``target_text`` in the voice of the prompt.

        Runs under ``torch.no_grad()``, so no autograd graph is recorded.

        Args:
            prompt_audio: source accepted by ``torchaudio.load`` (e.g. a path).
            prompt_text: transcript of ``prompt_audio``.
            target_text: text to synthesize.
            language: a key of ``self.language_dict`` (by default ``"en"``,
                ``"zh"`` or ``"mixed"``).

        Returns:
            The output of ``self.codec.decode`` for the generated codes
            (presumably a waveform tensor — confirm against the codec).

        Raises:
            KeyError: ``language`` is not in ``self.language_dict``.
            ValueError: ``prompt_text`` produced no phone ids, so the speaking
                rate cannot be estimated.
        """
        # Resolve the language first so a bad value fails before any expensive
        # work. This raises the same KeyError the original lookup raised.
        lang_id = self.language_dict[language]

        prompt_phone, _ = phonemize(prompt_text)
        target_phone, _ = phonemize(target_text)
        prompt_ids = cleaned_text_to_sequence(prompt_phone)
        target_ids = self.space_id + cleaned_text_to_sequence(target_phone)
        if len(prompt_ids) == 0:
            raise ValueError(
                "prompt_text produced no phonemes; cannot estimate the speaking rate"
            )

        text = torch.tensor(prompt_ids + target_ids, dtype=torch.long, device=self.device).unsqueeze(0)

        prompt_wav = self._load_prompt_wav(prompt_audio)
        prompt_token = self.codec.encode(prompt_wav)

        # prompt_token.shape[1] is treated as the prompt's frame count
        # (assumed frame axis of the codec output — confirm).
        length = self._estimate_target_length(
            prompt_token.shape[1], len(prompt_ids), len(target_ids), language
        )
        lang = torch.tensor([lang_id], dtype=torch.long, device=self.device)
        outputs = self.model.synthesize(
            text, lang, length, prompt_token,
            n_timesteps=self.configs.steps,
            temperature=self.configs.temperature,
            top_p=self.configs.top_p,
            rescale_cfg=self.configs.rescale_cfg,
            cfg=self.configs.cfg,
        )
        # Keep generated ids within [0, 1023] before decoding
        # (NOTE(review): presumably the codec's 1024-entry codebook — confirm).
        codes = outputs["x"].clamp(min=0, max=1023)
        return self.codec.decode(codes)