| import re
|
| import yaml
|
| from munch import Munch
|
| import numpy as np
|
| import librosa
|
| import noisereduce as nr
|
| from meldataset import TextCleaner
|
| import torch
|
| import torchaudio
|
| from nltk.tokenize import word_tokenize
|
| import nltk
|
# Make sure NLTK's "punkt_tab" tokenizer data is available (presumably needed by
# word_tokenize, which is used during inference). Only download when it is
# missing: nltk.download consults the remote NLTK index even when the package is
# already installed, which slows every import and fails noisily when offline.
try:
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt_tab')
|
|
|
| from models import ProsodyPredictor, TextEncoder, StyleEncoder
|
| from Modules.hifigan import Decoder
|
|
|
class Preprocess:
    """Text and audio helpers used during StyleTTS2 inference."""

    # Punctuation treated as a break between fragments; every match becomes ".".
    _PUNCTUATION = [",", "、", "،", ";", "(", ".", "。", "…", "!", "–", ":", "?"]
    # Compiled once at class creation instead of on every normalization call.
    _PUNCTUATION_PATTERN = re.compile(f"[{''.join(re.escape(p) for p in _PUNCTUATION)}]")

    # Log-mel normalization constants: (log(mel) - mean) / std.
    _MEL_MEAN = -4
    _MEL_STD = 4

    # Lazily-created MelSpectrogram transform (see wave_preprocess).
    _to_mel = None

    def __text_normalize(self, text):
        """Map each punctuation mark in ``_PUNCTUATION`` to "." and collapse whitespace.

        Runs of whitespace become a single space; leading/trailing whitespace
        is stripped.
        """
        text = self._PUNCTUATION_PATTERN.sub(".", text)
        return re.sub(r'\s+', ' ', text).strip()

    def __merge_fragments(self, texts, n):
        """Greedily join consecutive fragments with ", " until each has >= ``n`` words.

        A trailing group that still has fewer than ``n`` words is appended to
        the previous group (when there is one).

        Args:
            texts: list of non-empty text fragments.
            n: minimum number of whitespace-separated words per group.

        Returns:
            The merged list; ``[]`` when ``texts`` is empty.
        """
        merged = []
        i = 0
        while i < len(texts):
            fragment = texts[i]
            i += 1
            while len(fragment.split()) < n and i < len(texts):
                fragment += ", " + texts[i]
                i += 1
            merged.append(fragment)

        # Empty input yields no groups; guard before touching merged[-1].
        if len(merged) > 1 and len(merged[-1].split()) < n:
            tail = merged.pop()
            merged[-1] = merged[-1] + ", " + tail
        return merged

    def wave_preprocess(self, wave):
        """Convert a waveform into a normalized log-mel spectrogram.

        Args:
            wave: numpy array of samples (assumed 1-D mono — TODO confirm for
                other callers).

        Returns:
            ``(log(1e-5 + mel) - (-4)) / 4`` with a leading batch axis added,
            i.e. shape (1, 80, frames) for 1-D input.
        """
        if self._to_mel is None:
            # Building the transform recomputes its mel filterbank, so do it once.
            # NOTE: sample_rate is left at torchaudio's default even though the
            # StyleTTS2 class feeds 24 kHz audio; presumably this matches how the
            # model was trained — do not change without checking.
            self._to_mel = torchaudio.transforms.MelSpectrogram(
                n_mels=80, n_fft=2048, win_length=1200, hop_length=300)

        wave_tensor = torch.from_numpy(wave).float()
        mel_tensor = self._to_mel(wave_tensor)
        return (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - self._MEL_MEAN) / self._MEL_STD

    def text_preprocess(self, text, n_merge=12):
        """Split ``text`` into synthesis chunks of at least ``n_merge`` words where possible.

        Punctuation is normalized to ".", the text is split on it, empty
        pieces are dropped and the rest are merged with ", ".

        Returns:
            A list of chunk strings; ``[]`` if ``text`` has no words.
        """
        pieces = [piece.strip() for piece in self.__text_normalize(text).split(".")]
        pieces = [piece for piece in pieces if piece != '']
        return self.__merge_fragments(pieces, n=n_merge)

    def length_to_mask(self, lengths):
        """Build a padding mask from a 1-D tensor of sequence lengths.

        Returns:
            Bool tensor of shape (len(lengths), lengths.max()) that is True at
            positions >= the row's length (i.e. padding).
        """
        positions = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        return torch.gt(positions + 1, lengths.unsqueeze(1))
|
|
|
|
|
class StyleTTS2(torch.nn.Module):
    """Inference wrapper around the StyleTTS2 sub-models.

    Builds a text encoder, prosody predictor, style encoder and HiFi-GAN
    decoder from a YAML config and loads their weights from a checkpoint.
    Public API: ``get_styles`` / ``save_styles`` / ``load_styles`` to obtain a
    reference style, then ``generate`` to synthesize a waveform from text.
    """

    # Reference audio is loaded (and resampled by librosa) at this rate.
    _SAMPLE_RATE = 24000
    # Trimmed reference audio is truncated to this many seconds.
    _MAX_REF_SECONDS = 20
    # Samples cut from both ends of every generated segment; also the length
    # of the zero padding placed around the final waveform.
    _EDGE_SAMPLES = 4000

    def __init__(self, config_path, models_path):
        """Build the model from a YAML config and load its weights.

        Args:
            config_path: YAML file with a ``symbol`` section (``pad``,
                ``punctuation``, ``letters``, ``letters_ipa``, ``extend``) and
                a ``model_params`` mapping.
            models_path: checkpoint readable by ``torch.load`` whose ``'net'``
                entry maps sub-module names to state dicts.

        Raises:
            SystemExit: if the symbol table cannot be read from the config.
        """
        super().__init__()
        # Zero-size buffer that moves with .to()/.cuda(); used to find the model's device.
        self.register_buffer("get_device", torch.empty(0))
        self.preprocess = Preprocess()
        self.ref_s = None

        # Use a context manager so the config file handle is closed.
        with open(config_path, "r", encoding="utf-8") as config_file:
            config = yaml.safe_load(config_file)

        try:
            symbol_cfg = config['symbol']
            symbols = (
                list(symbol_cfg['pad']) +
                list(symbol_cfg['punctuation']) +
                list(symbol_cfg['letters']) +
                list(symbol_cfg['letters_ipa']) +
                list(symbol_cfg['extend'])
            )
            # A later duplicate symbol overwrites the index of an earlier one.
            symbol_dict = {symbol: index for index, symbol in enumerate(symbols)}
            n_token = len(symbol_dict) + 1
            print("\nFound:", n_token, "symbols")
        except Exception as e:
            print(f"\nERROR: Cannot find {e} in config file!\nYour config file is likely outdated, please download updated version from the repository.")
            raise SystemExit(1) from e

        args = self.__recursive_munch(config['model_params'])
        args['n_token'] = n_token

        self.cleaner = TextCleaner(symbol_dict, debug=False)

        assert args.decoder.type in ['hifigan'], 'Decoder type unknown'

        self.decoder = Decoder(dim_in=args.hidden_dim, style_dim=args.style_dim, dim_out=args.n_mels,
                               resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
                               upsample_rates=args.decoder.upsample_rates,
                               upsample_initial_channel=args.decoder.upsample_initial_channel,
                               resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
                               upsample_kernel_sizes=args.decoder.upsample_kernel_sizes)
        self.predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
        self.text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
        self.style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim)

        self.__load_models(models_path)

    def __recursive_munch(self, d):
        """Recursively convert dicts (also inside lists) to Munch for attribute access."""
        if isinstance(d, dict):
            return Munch((k, self.__recursive_munch(v)) for k, v in d.items())
        elif isinstance(d, list):
            return [self.__recursive_munch(v) for v in d]
        else:
            return d

    def __replace_outliers_zscore(self, tensor, threshold=3.0, factor=0.95):
        """Pull values whose |z-score| exceeds ``threshold`` back toward the mean.

        Each outlier is replaced by ``mean ± threshold * std * factor`` (sign
        taken from the outlier's side of the mean). If ``std`` is 0 or
        undefined, the z-scores are NaN and nothing is replaced.

        Returns:
            A new tensor; the input is not modified.
        """
        mean = tensor.mean()
        std = tensor.std()
        z = (tensor - mean) / std

        outlier_mask = torch.abs(z) > threshold
        sign = torch.sign(tensor - mean)
        replacement = mean + sign * (threshold * std * factor)

        result = tensor.clone()
        result[outlier_mask] = replacement[outlier_mask]
        return result

    def __load_models(self, models_path):
        """Load checkpoint weights into the four sub-modules and print their sizes.

        A state dict that does not load strictly (e.g. one saved from a
        ``DataParallel`` wrapper, whose keys start with ``module.``) is
        retried with that prefix removed, and only from keys that actually
        carry it, using ``strict=False``. Any keys that still do not match
        are reported instead of being silently ignored.

        Raises:
            KeyError: if the checkpoint lacks ``'net'`` or a sub-module entry.
        """
        model = {
            'decoder': self.decoder,
            'predictor': self.predictor,
            'text_encoder': self.text_encoder,
            'style_encoder': self.style_encoder,
        }

        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        params = torch.load(models_path, map_location='cpu')['net']

        prefix = 'module.'
        module_params = []
        for key, module in model.items():
            state_dict = params[key]
            try:
                module.load_state_dict(state_dict)
            except RuntimeError:
                stripped = {
                    (k[len(prefix):] if k.startswith(prefix) else k): v
                    for k, v in state_dict.items()
                }
                result = module.load_state_dict(stripped, strict=False)
                if result.missing_keys or result.unexpected_keys:
                    print(f"Warning: {key}: missing keys {result.missing_keys}, "
                          f"unexpected keys {result.unexpected_keys}")

            total_params = sum(p.numel() for p in module.parameters())
            print(key, ":", total_params)
            module_params.append(total_params)

        print('\nTotal', ":", sum(module_params))

    @staticmethod
    def __style_chunk_spans(total_len, jump, sr):
        """Return the (start, end) sample spans used for averaged style extraction.

        Windows of ``jump`` samples start at 0. The first window is always
        kept. The last window ends at ``total_len`` and is kept only if it
        is at least one second (``sr`` samples) long.
        """
        spans = [(0, jump)]
        for start in range(jump, total_len, jump):
            if start + jump >= total_len:
                if (total_len - start) / sr >= 1:
                    spans.append((start, total_len))
                continue
            spans.append((start, start + jump))
        return spans

    def __compute_style(self, path, denoise, split_dur):
        """Compute a style vector from a reference audio file.

        Args:
            path: audio file readable by ``librosa.load``.
            denoise: blend weight of a noise-reduced copy of the audio; values
                above 1 are clamped to 1, values <= 0 disable denoising.
            split_dur: when non-zero, window length in seconds (coerced to an
                int >= 1). If the trimmed audio is at least 4 s long, the
                style-encoder outputs over the windows are averaged. With 0,
                the whole clip is encoded at once.

        Returns:
            The (possibly averaged) style-encoder output.
        """
        device = self.get_device.device
        sr = self._SAMPLE_RATE
        denoise = min(denoise, 1)
        if split_dur != 0:
            split_dur = max(int(split_dur), 1)

        print("Computing the style for:", path)

        # librosa.load resamples to `sr` itself, so no separate resampling step is needed.
        wave, _ = librosa.load(path, sr=sr)
        audio, _ = librosa.effects.trim(wave, top_db=30)
        audio = audio[:sr * self._MAX_REF_SECONDS]

        if denoise > 0.0:
            audio_denoise = nr.reduce_noise(y=audio, sr=sr, n_fft=2048, win_length=1200, hop_length=300)
            audio = audio * (1 - denoise) + audio_denoise * denoise

        if split_dur > 0 and len(audio) / sr >= 4:
            spans = self.__style_chunk_spans(len(audio), sr * split_dur, sr)
        else:
            spans = [(0, len(audio))]

        with torch.no_grad():
            ref_s = None
            for start, end in spans:
                mel_tensor = self.preprocess.wave_preprocess(audio[start:end]).to(device)
                style = self.style_encoder(mel_tensor.unsqueeze(1))
                ref_s = style if ref_s is None else ref_s + style
            # Mean over windows. For a single span this divides by 1, which leaves the value unchanged.
            ref_s = ref_s / len(spans)

        return ref_s

    def __inference(self, phonem, ref_s, speed=1, prev_d_mean=0, t=0.1):
        """Synthesize one text segment.

        Args:
            phonem: segment text; it is tokenized with ``word_tokenize`` and
                mapped to ids with ``self.cleaner``.
            ref_s: style tensor (see ``get_styles``).
            speed: speaking-rate factor clamped to [0.0001, 2]; predicted
                durations are divided by it.
            prev_d_mean: value returned by the previous call, or 0 for the
                first segment.
            t: weight of the random duration prior blended into the
                predicted durations.

        Returns:
            ``(waveform, d_mean)``: a numpy array, and the mean duration
            *before* speed scaling, meant to be passed back as ``prev_d_mean``.
        """
        device = self.get_device.device
        speed = min(max(speed, 0.0001), 2)

        phonem = ' '.join(word_tokenize(phonem))
        tokens = self.cleaner(phonem)
        # Pad the sequence with id 0 at both ends.
        tokens.insert(0, 0)
        tokens.append(0)
        tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
        n_tokens = tokens.shape[-1]

        with torch.no_grad():
            input_lengths = torch.LongTensor([n_tokens]).to(device)
            text_mask = self.preprocess.length_to_mask(input_lengths).to(device)

            t_en = self.text_encoder(tokens, input_lengths, text_mask)
            s = ref_s.to(device)

            d = self.predictor.text_encoder(t_en, s, input_lengths, text_mask)
            x, _ = self.predictor.lstm(d)
            duration = self.predictor.duration_proj(x)
            duration = torch.sigmoid(duration).sum(axis=-1)

            # Blend in Gaussian noise centred on the previous segment's mean
            # duration (or this segment's own mean for the first segment). The
            # noise is drawn on the CPU and then moved to the device.
            target_mean = prev_d_mean if prev_d_mean != 0 else duration.mean()
            dur_stats = torch.empty(duration.shape).normal_(mean=target_mean, std=duration.std()).to(device)
            duration = duration * (1 - t) + dur_stats * t
            # Skips the leading pad and the last two positions (the trailing
            # pad and, presumably, the final text token).
            duration[:, 1:-2] = self.__replace_outliers_zscore(duration[:, 1:-2])

            # Record the mean before speed scaling. The next call mixes it with
            # unscaled durations, so both must be in the same units.
            d_mean = duration.mean()
            duration /= speed

            pred_dur = torch.round(duration.squeeze()).clamp(min=1)
            # One host transfer instead of per-element .data -> int conversions.
            frame_counts = pred_dur.long().tolist()
            pred_aln_trg = torch.zeros(n_tokens, sum(frame_counts))
            c_frame = 0
            for i, n_frames in enumerate(frame_counts):
                pred_aln_trg[i, c_frame:c_frame + n_frames] = 1
                c_frame += n_frames
            alignment = pred_aln_trg.unsqueeze(0).to(device)

            en = d.transpose(-1, -2) @ alignment
            F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
            asr = t_en @ alignment

            out = self.decoder(asr, F0_pred, N_pred, s)

        return out.squeeze().cpu().numpy(), d_mean

    def get_styles(self, speaker, denoise=0.3, avg_style=True, load_styles=False):
        """Return a style dict for ``speaker``, computing the style unless told to reuse it.

        Args:
            speaker: mapping with ``'path'`` (reference audio) and ``'speed'``.
            denoise: denoise blend weight passed to the style computation.
            avg_style: average over 3-second windows instead of encoding the
                whole clip at once.
            load_styles: reuse ``self.ref_s`` (set by an earlier computation
                or by ``load_styles``) instead of computing a new one.

        Returns:
            ``{'style': ref_s, 'path': ..., 'speed': ...}``.

        Raises:
            RuntimeError: if ``load_styles`` is set but no style is available.
        """
        if load_styles:
            if self.ref_s is None:
                raise RuntimeError("Have to compute or load the styles first!")
        else:
            split_dur = 3 if avg_style else 0
            self.ref_s = self.__compute_style(speaker['path'], denoise=denoise, split_dur=split_dur)

        return {
            'style': self.ref_s,
            'path': speaker['path'],
            'speed': speaker['speed'],
        }

    def save_styles(self, save_dir):
        """Save the current style tensor to ``save_dir`` with ``torch.save``.

        Raises:
            RuntimeError: if no style has been computed or loaded yet.
        """
        if self.ref_s is None:
            raise RuntimeError("Have to compute the styles before saving it.")
        torch.save(self.ref_s, save_dir)
        print("Saved styles!")

    def load_styles(self, save_dir):
        """Load a style saved by ``save_styles`` onto the model's device.

        Best effort: on failure the error is printed and ``self.ref_s`` is left
        unchanged.
        """
        try:
            # map_location lets styles saved on another device (e.g. GPU) load here.
            self.ref_s = torch.load(save_dir, map_location=self.get_device.device)
            print("Loaded styles!")
        except Exception as e:
            print(e)

    def generate(self, phonem, style, stabilize=True, n_merge=16):
        """Synthesize a waveform for ``phonem`` using ``style``.

        The text is split into chunks of at least ``n_merge`` words. Each chunk
        is synthesized, ``_EDGE_SAMPLES`` samples are trimmed from both of its
        ends, and the chunks are concatenated with ``_EDGE_SAMPLES`` zeros of
        padding on each side of the result.

        Args:
            phonem: input text.
            style: dict from ``get_styles`` (uses ``'style'`` and ``'speed'``).
            stabilize: blend a duration prior (weight 0.2) that carries each
                chunk's mean duration into the next chunk.
            n_merge: minimum words per chunk.

        Returns:
            1-D numpy array of samples.

        Raises:
            ValueError: if the text contains nothing to synthesize.
        """
        smooth_value = 0.2 if stabilize else 0
        edge = self._EDGE_SAMPLES

        print("Generating Audio...")
        text_norm = self.preprocess.text_preprocess(phonem, n_merge=n_merge)
        if not text_norm:
            raise ValueError("Nothing to synthesize: the input text is empty after normalization.")

        list_wav = []
        prev_d_mean = 0
        for sentence in text_norm:
            wav, prev_d_mean = self.__inference(sentence, style['style'], speed=style['speed'], prev_d_mean=prev_d_mean, t=smooth_value)
            list_wav.append(wav[edge:-edge])

        padding = np.zeros([edge])
        return np.concatenate([padding, *list_wav, padding], axis=0)