File size: 3,033 Bytes
0afe769
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca68ef2
 
 
0afe769
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from safetensors.torch import load_model

from text import cleaned_text_to_sequence
from text.phonemize import phonemize

from .model import GibbsTTS_Model
from amphion_utils import amphion_codec

class AudioResampler:
    """Callable that brings waveforms to a single target sample rate.

    A ``torchaudio.transforms.Resample`` transform is created the first time
    a given source rate is seen, moved to ``device``, and stored in
    ``self.resamplers`` so later calls with the same rate reuse it.
    """

    def __init__(self, device, target_sr=24000):
        """Remember the device and target rate and start with an empty cache.

        Args:
            device: Device each newly built transform is moved to.
            target_sr: Sample rate every waveform is converted to.
        """
        # Maps source sample rate -> transform; filled on demand by __call__.
        self.resamplers = {}
        self.target_sr = target_sr
        self.device = device

    def __call__(self, wav, sr):
        """Return ``wav`` converted from ``sr`` to ``self.target_sr``.

        Args:
            wav: Waveform sampled at ``sr``; handed to the transform as-is.
            sr: Sample rate of ``wav``.

        Returns:
            ``wav`` itself (no copy) when ``sr`` already equals the target
            rate, otherwise the output of the cached transform for ``sr``.
        """
        if sr != self.target_sr:
            transform = self.resamplers.get(sr)
            if transform is None:
                # First waveform at this rate: build the transform once.
                transform = torchaudio.transforms.Resample(
                    orig_freq=sr,
                    new_freq=self.target_sr,
                ).to(self.device)
                self.resamplers[sr] = transform
            return transform(wav)

        # Already at the target rate: nothing to do.
        return wav
    
class GibbsTTS(nn.Module):
    """Inference wrapper bundling the GibbsTTS model, codec and resampler.

    Construction loads ``GibbsTTS_Model`` weights from
    ``configs.infer_ckpt_path``, moves the model to CUDA when available
    (otherwise CPU), puts it in eval mode and builds the codec. ``synthesize``
    then turns a prompt recording plus its transcript and a target text into
    decoded audio.
    """

    # Per-language reference ratio used to bound the length estimate in
    # ``_target_length``.
    # NOTE(review): presumably the average number of codec tokens per phoneme
    # for each language -- confirm against the training statistics.
    _REFERENCE_RATIO = {"en": 3.224, "zh": 3.286, "mixed": 3.255}
    # The prompt-derived ratio is kept within [0.8x, 1.25x] of the reference.
    _RATIO_MIN_SCALE = 0.8
    _RATIO_MAX_SCALE = 1.25
    # Generated codes are clamped into [0, _MAX_CODE] before decoding.
    # NOTE(review): presumably codebook size - 1 -- confirm with the codec.
    _MAX_CODE = 1023

    def __init__(self, configs):
        """Build the model, codec and resampler from ``configs``.

        Args:
            configs: Configuration object. This class reads
                ``infer_ckpt_path`` here and ``steps``, ``temperature``,
                ``top_p``, ``rescale_cfg`` and ``cfg`` in ``synthesize``; it
                is also passed on to ``GibbsTTS_Model`` and ``amphion_codec``.
        """
        super().__init__()

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.configs = configs

        self.model = GibbsTTS_Model(configs)
        load_model(self.model, configs.infer_ckpt_path)
        self.model = self.model.to(device)
        self.model.eval()

        self.codec = amphion_codec(configs, device)

        self.resampler = AudioResampler(device, target_sr=24000)

        # Language name -> integer id passed to the model.
        self.language_dict = {"en": 0, "zh": 1, "mixed": 2}
        # Id sequence for a single space; inserted between prompt and target.
        self.space_id = cleaned_text_to_sequence([" "])

    @classmethod
    def _target_length(cls, num_prompt_tokens, num_prompt_phones,
                       num_target_phones, language):
        """Estimate how many tokens to generate for the target text.

        The tokens-per-phoneme ratio observed on the prompt is clamped to
        ``[0.8, 1.25]`` times the language's reference ratio (when the
        language has one) and multiplied by the target phoneme count.

        Args:
            num_prompt_tokens: Number of codec tokens in the encoded prompt.
            num_prompt_phones: Number of phoneme ids in the prompt text.
            num_target_phones: Number of phoneme ids in the target text.
            language: Language key; languages without a reference ratio are
                left unclamped.

        Returns:
            The estimated target length, truncated to ``int``.

        Raises:
            ValueError: If ``num_prompt_phones`` is not positive, since the
                ratio would be undefined.
        """
        if num_prompt_phones <= 0:
            raise ValueError(
                "prompt text produced no phonemes; cannot estimate speaking rate"
            )

        ratio = num_prompt_tokens / num_prompt_phones
        reference = cls._REFERENCE_RATIO.get(language)
        if reference is not None:
            ratio = max(
                reference * cls._RATIO_MIN_SCALE,
                min(ratio, reference * cls._RATIO_MAX_SCALE),
            )
        return int(num_target_phones * ratio)

    @torch.no_grad()
    def synthesize(self, prompt_audio, prompt_text, target_text, language):
        """Generate audio for ``target_text`` conditioned on a prompt.

        Args:
            prompt_audio: Audio source accepted by ``torchaudio.load``.
            prompt_text: Transcript of ``prompt_audio``.
            target_text: Text to synthesize.
            language: Key of ``self.language_dict`` ("en", "zh" or "mixed").

        Returns:
            Whatever ``self.codec.decode`` returns for the generated codes.
            NOTE(review): presumably a waveform tensor at 24 kHz -- confirm
            against the codec.

        Raises:
            KeyError: If ``language`` is not in ``self.language_dict``.
            ValueError: If ``prompt_text`` yields no phonemes.
        """
        # Resolve the language first so an unsupported value fails before any
        # audio is loaded or encoded (still a KeyError, just raised earlier).
        lang_id = self.language_dict[language]

        prompt_phone, _ = phonemize(prompt_text)
        target_phone, _ = phonemize(target_text)
        prompt_phone = cleaned_text_to_sequence(prompt_phone)
        if len(prompt_phone) == 0:
            # Without prompt phonemes the tokens-per-phoneme ratio below
            # would be a division by zero; reject before the codec work.
            raise ValueError(
                "prompt text produced no phonemes; cannot estimate speaking rate"
            )
        # A space separates the prompt transcript from the target text.
        target_phone = self.space_id + cleaned_text_to_sequence(target_phone)
        text = prompt_phone + target_phone
        text = torch.tensor(text, dtype=torch.long, device=self.device).unsqueeze(0)

        prompt_wav, sr = torchaudio.load(prompt_audio)
        # torchaudio.load yields (channels, frames) by default, so after the
        # unsqueeze the channel axis is dim 1.
        prompt_wav = self.resampler(prompt_wav.to(self.device), sr).unsqueeze(0)

        if prompt_wav.shape[1] > 1:
            # Down-mix multi-channel prompts to mono by averaging channels.
            prompt_wav = prompt_wav.mean(dim=1, keepdim=True)
        prompt_token = self.codec.encode(prompt_wav)

        # NOTE(review): assumes dim 1 of prompt_token is the token/time axis
        # -- confirm against the codec's encode output.
        length = self._target_length(
            prompt_token.shape[1], len(prompt_phone), len(target_phone), language
        )

        lang = torch.tensor([lang_id], dtype=torch.long, device=self.device)

        outputs = self.model.synthesize(
            text,
            lang,
            length,
            prompt_token,
            n_timesteps=self.configs.steps,
            temperature=self.configs.temperature,
            top_p=self.configs.top_p,
            rescale_cfg=self.configs.rescale_cfg,
            cfg=self.configs.cfg,
        )
        # Keep generated codes inside the valid index range before decoding.
        codes = outputs["x"].clamp(min=0, max=self._MAX_CODE)

        audio = self.codec.decode(codes)
        return audio