Spaces:
Running on Zero
Running on Zero
File size: 3,033 Bytes
0afe769 ca68ef2 0afe769 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 | import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from safetensors.torch import load_model
from text import cleaned_text_to_sequence
from text.phonemize import phonemize
from .model import GibbsTTS_Model
from amphion_utils import amphion_codec
class AudioResampler:
    """Resample waveforms to a fixed target rate, memoizing one transform per source rate.

    Transforms are built lazily the first time a given source rate is seen and
    are moved to ``device`` once, at construction time.
    """

    def __init__(self, device, target_sr=24000):
        # source sample rate -> torchaudio Resample module (filled on demand)
        self.resamplers = {}
        self.target_sr = target_sr
        self.device = device

    def __call__(self, wav, sr):
        """Return ``wav`` converted from ``sr`` Hz to ``self.target_sr`` Hz.

        When ``sr`` already equals the target rate, the very same ``wav`` object
        is returned without any processing.
        """
        if sr != self.target_sr:
            transform = self.resamplers.get(sr)
            if transform is None:
                transform = torchaudio.transforms.Resample(
                    orig_freq=sr,
                    new_freq=self.target_sr,
                ).to(self.device)
                self.resamplers[sr] = transform
            return transform(wav)
        return wav
class GibbsTTS(nn.Module):
    """Prompt-conditioned text-to-speech wrapper around ``GibbsTTS_Model``.

    Bundles the token model (weights loaded from ``configs.infer_ckpt_path``),
    the codec used to encode the prompt audio and decode generated codes, and
    a resampler that brings prompt audio to 24 kHz.
    """

    # Reference "codec tokens per phone" ratios per language. The ratio measured
    # on the prompt is clamped to [_RATIO_MIN_SCALE, _RATIO_MAX_SCALE] times the
    # reference, so an unusually fast or slow prompt cannot produce an extreme
    # target length. Languages missing from this table are left unclamped.
    _REFERENCE_RATIOS = {"en": 3.224, "zh": 3.286, "mixed": 3.255}
    _RATIO_MIN_SCALE = 0.8
    _RATIO_MAX_SCALE = 1.25
    # Generated codes are clamped into [0, _MAX_CODE] before decoding;
    # presumably the codec codebook has 1024 entries — TODO confirm.
    _MAX_CODE = 1023

    def __init__(self, configs):
        """Build the model, load its checkpoint and set up the codec.

        Args:
            configs: Config object; this class reads ``infer_ckpt_path``,
                ``steps``, ``temperature``, ``top_p``, ``rescale_cfg`` and
                ``cfg`` from it and also passes it to ``GibbsTTS_Model`` and
                ``amphion_codec``.
        """
        super().__init__()
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.configs = configs
        self.model = GibbsTTS_Model(configs)
        load_model(self.model, configs.infer_ckpt_path)
        self.model = self.model.to(device)
        self.model.eval()
        self.codec = amphion_codec(configs, device)
        self.resampler = AudioResampler(device, target_sr=24000)
        self.language_dict = {"en": 0, "zh": 1, "mixed": 2}
        # Token id sequence for a single space, used to separate prompt and target phones.
        self.space_id = cleaned_text_to_sequence([" "])

    @torch.no_grad()
    def synthesize(self, prompt_audio, prompt_text, target_text, language):
        """Synthesize ``target_text`` conditioned on a spoken prompt.

        Args:
            prompt_audio: Anything ``torchaudio.load`` accepts (path or file-like).
            prompt_text: Transcript of ``prompt_audio``.
            target_text: Text to synthesize.
            language: A key of ``self.language_dict`` ("en", "zh" or "mixed").

        Returns:
            The result of ``self.codec.decode`` on the generated codes.

        Raises:
            KeyError: If ``language`` is not in ``self.language_dict``; checked
                before any phonemization, audio loading or model work.
            ValueError: If ``prompt_text`` produces no phonemes, which would
                make the prompt speaking rate undefined.
        """
        if language not in self.language_dict:
            raise KeyError(
                f"unsupported language {language!r}; expected one of {sorted(self.language_dict)}"
            )
        lang = torch.tensor([self.language_dict[language]], dtype=torch.long, device=self.device)

        prompt_phone, _ = phonemize(prompt_text)
        target_phone, _ = phonemize(target_text)
        prompt_phone = cleaned_text_to_sequence(prompt_phone)
        if len(prompt_phone) == 0:
            raise ValueError("prompt_text produced no phonemes; cannot estimate the prompt speaking rate")
        # A space token separates the prompt phones from the target phones.
        target_phone = self.space_id + cleaned_text_to_sequence(target_phone)
        text = torch.tensor(prompt_phone + target_phone, dtype=torch.long, device=self.device).unsqueeze(0)

        # torchaudio.load returns (channels, frames). Downmix to mono *before*
        # resampling: both steps are linear, so the result matches
        # resample-then-average up to float rounding, at 1/channels the cost.
        prompt_wav, sr = torchaudio.load(prompt_audio)
        if prompt_wav.shape[0] > 1:
            prompt_wav = prompt_wav.mean(dim=0, keepdim=True)
        # (1, frames) -> resample to 24 kHz -> (1, 1, frames)
        prompt_wav = self.resampler(prompt_wav.to(self.device), sr).unsqueeze(0)
        prompt_token = self.codec.encode(prompt_wav)

        # Prompt speaking rate in codec tokens per phone; assumes dim 1 of
        # prompt_token is the token/time axis — TODO confirm against the codec.
        ratio = prompt_token.shape[1] / len(prompt_phone)
        reference = self._REFERENCE_RATIOS.get(language)
        if reference is not None:
            ratio = max(reference * self._RATIO_MIN_SCALE, min(ratio, reference * self._RATIO_MAX_SCALE))
        # Number of codec tokens to generate for the target (includes the separator space).
        length = int(len(target_phone) * ratio)

        outputs = self.model.synthesize(
            text, lang, length, prompt_token,
            n_timesteps=self.configs.steps,
            temperature=self.configs.temperature,
            top_p=self.configs.top_p,
            rescale_cfg=self.configs.rescale_cfg,
            cfg=self.configs.cfg,
        )
        codes = outputs["x"].clamp(min=0, max=self._MAX_CODE)
        audio = self.codec.decode(codes)
        return audio