| import spaces |
| import os |
| import random |
| import argparse |
|
|
| import torch |
| import gradio as gr |
| import numpy as np |
|
|
| import ChatTTS |
|
|
| import se_extractor |
| from api import BaseSpeakerTTS, ToneColorConverter |
| import soundfile |
|
|
| from tts_voice import tts_order_voice |
| import edge_tts |
| import tempfile |
| import anyio |
|
|
# Module-level side effect: load the ChatTTS model once at import time so
# every request reuses the same in-memory model.
print("loading ChatTTS model...")
chat = ChatTTS.Chat()
chat.load_models()
|
|
|
|
def generate_seed():
    """Return a Gradio update payload carrying a fresh random seed.

    The ``{"__type__": "update", ...}`` dict is Gradio's wire format for
    updating a component's value from a callback.
    """
    return {"__type__": "update", "value": random.randint(1, 100000000)}
|
|
@spaces.GPU
def chat_tts(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, output_path=None):
    """Synthesize speech for ``text`` with ChatTTS.

    Args:
        text: input text to speak (the refine pass may turn it into a list).
        temperature: ChatTTS sampling temperature.
        top_P: nucleus-sampling threshold.
        top_K: top-k sampling cutoff.
        audio_seed_input: seed controlling the random speaker embedding
            (same seed -> same voice).
        text_seed_input: seed for the text-refinement sampling stage.
        refine_text_flag: when True, run a refine-only pass on the text first.
        refine_text_input: optional refine prompt; overrides the default
            '[oral_2][laugh_0][break_6]' when non-empty.
        output_path: if None, return ``[(sample_rate, audio), text]`` for
            direct display in Gradio; otherwise write a wav file to this
            path and return None.
    """
    # Seed torch *before* drawing the speaker embedding so the voice is
    # reproducible for a given audio seed.
    torch.manual_seed(audio_seed_input)
    rand_spk = torch.randn(768)
    params_infer_code = {
        'spk_emb': rand_spk,
        'temperature': temperature,
        'top_P': top_P,
        'top_K': top_K,
    }
    params_refine_text = {'prompt': '[oral_2][laugh_0][break_6]'}

    # Separate seed for the text-refinement sampling stage.
    torch.manual_seed(text_seed_input)

    if refine_text_flag:
        if refine_text_input:
            params_refine_text['prompt'] = refine_text_input
        # Refine-only pass: returns the rewritten text, not audio.
        text = chat.infer(text,
                          skip_refine_text=False,
                          refine_text_only=True,
                          params_refine_text=params_refine_text,
                          params_infer_code=params_infer_code
                          )
        print("Text has been refined!")

    # Main synthesis pass; refinement is skipped because it either already
    # ran above or was disabled.
    wav = chat.infer(text,
                     skip_refine_text=True,
                     params_refine_text=params_refine_text,
                     params_infer_code=params_infer_code
                     )

    audio_data = np.array(wav[0]).flatten()
    sample_rate = 24000  # ChatTTS output rate -- assumed fixed; TODO confirm
    text_data = text[0] if isinstance(text, list) else text

    if output_path is None:
        return [(sample_rate, audio_data), text_data]
    else:
        soundfile.write(output_path, audio_data, sample_rate)
|
| |
|
|
# OpenVoice checkpoint locations (relative to the working directory).
ckpt_base_en = 'checkpoints/base_speakers/EN'
ckpt_converter_en = 'checkpoints/converter'
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# English base-speaker TTS used by the "Emotion Control" tab.
base_speaker_tts = BaseSpeakerTTS(f'{ckpt_base_en}/config.json', device=device)
base_speaker_tts.load_ckpt(f'{ckpt_base_en}/checkpoint.pth')

# Tone-color converter that re-voices synthesized audio to match a
# reference speaker; shared by all three tabs.
tone_color_converter = ToneColorConverter(f'{ckpt_converter_en}/config.json', device=device)
tone_color_converter.load_ckpt(f'{ckpt_converter_en}/checkpoint.pth')
|
|
|
|
def generate_audio(text, audio_ref, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input):
    """Synthesize ``text`` with ChatTTS, then clone the voice of ``audio_ref``.

    Pipeline: ChatTTS writes a temporary wav -> speaker embeddings are
    extracted from both the reference audio (target voice) and the ChatTTS
    output (source voice) -> the tone-color converter re-voices the ChatTTS
    audio into the target voice.

    Args:
        text: input text to speak.
        audio_ref: filepath of the reference audio whose voice is cloned.
        temperature, top_P, top_K: ChatTTS sampling parameters.
        audio_seed_input, text_seed_input: seeds forwarded to chat_tts().
        refine_text_flag, refine_text_input: text-refinement controls
            forwarded to chat_tts().

    Returns:
        Path of the converted output wav ("output.wav").
    """
    # Fix: the original loaded en_default_se.pth into source_se here, but
    # that value was overwritten below before any use -- dead work removed.

    # Target voice: embedding extracted from the user-supplied reference.
    reference_speaker = audio_ref
    target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True)
    save_path = "output.wav"

    # Render the ChatTTS speech to a temp wav on disk.
    src_path = "tmp.wav"
    chat_tts(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag, refine_text_input, src_path)
    print("Ready for voice cloning!")

    # Source voice: embedding extracted from the ChatTTS rendering itself.
    source_se, audio_name = se_extractor.get_se(src_path, tone_color_converter, target_dir='processed', vad=True)
    print("Get source segment!")

    # Watermark message embedded in the converted audio.
    encode_message = "@Hilley"

    tone_color_converter.convert(
        audio_src_path=src_path,
        src_se=source_se,
        tgt_se=target_se,
        output_path=save_path,
        message=encode_message)

    print("Finished!")

    return "output.wav"
|
|
def vc_en(text, audio_ref, style_mode):
    """Speak ``text`` with the OpenVoice base speaker, cloning ``audio_ref``.

    Args:
        text: English text to synthesize.
        audio_ref: filepath of the reference audio whose voice is cloned.
        style_mode: "default" or an emotion style name
            (friendly/whispering/cheerful/terrified/angry/sad).

    Returns:
        Path of the converted output wav ("output.wav").
    """
    # The two original branches were identical except for three values;
    # select them here and run the shared pipeline once.
    if style_mode == "default":
        se_path = f'{ckpt_base_en}/en_default_se.pth'
        speaker, speed = 'default', 1.0
    else:
        se_path = f'{ckpt_base_en}/en_style_se.pth'
        speaker, speed = style_mode, 0.9

    # Source embedding matching the chosen base-speaker style.
    source_se = torch.load(se_path).to(device)

    # Target voice: embedding extracted from the user-supplied reference.
    reference_speaker = audio_ref
    target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True)
    save_path = "output.wav"

    # Render base-speaker speech to a temp wav, then re-voice it.
    src_path = "tmp.wav"
    base_speaker_tts.tts(text, src_path, speaker=speaker, language='English', speed=speed)

    # Watermark message embedded in the converted audio.
    encode_message = "@MyShell"
    tone_color_converter.convert(
        audio_src_path=src_path,
        src_se=source_se,
        tgt_se=target_se,
        output_path=save_path,
        message=encode_message)

    return "output.wav"
|
|
# Mapping of display names -> edge-tts voice identifiers for the
# "multilingual" tab.
language_dict = tts_order_voice

# Module-level source embedding used by text_to_speech_edge(): extracted
# once from the bundled base audio sample.
base_speaker = "base_audio.mp3"
source_se, audio_name = se_extractor.get_se(base_speaker, tone_color_converter, vad=True)
|
|
async def text_to_speech_edge(text, audio_ref, language_code):
    """Synthesize ``text`` with Microsoft edge-tts, cloning ``audio_ref``.

    Args:
        text: text to synthesize (any language supported by the voice).
        audio_ref: filepath of the reference audio whose voice is cloned.
        language_code: key into ``language_dict`` selecting the edge-tts voice.

    Returns:
        Path of the converted output wav ("output.wav").
    """
    voice = language_dict[language_code]
    communicate = edge_tts.Communicate(text, voice)
    # delete=False so the file survives the `with` and edge-tts can write
    # to it by path afterwards.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
        tmp_path = tmp_file.name

    try:
        await communicate.save(tmp_path)

        # Target voice from the user-supplied reference. NOTE: the source
        # embedding is the module-level `source_se` extracted from
        # base_audio.mp3, not from the edge-tts output.
        reference_speaker = audio_ref
        target_se, audio_name = se_extractor.get_se(reference_speaker, tone_color_converter, target_dir='processed', vad=True)
        save_path = "output.wav"

        # Watermark message embedded in the converted audio.
        encode_message = "@MyShell"
        tone_color_converter.convert(
            audio_src_path=tmp_path,
            src_se=source_se,
            tgt_se=target_se,
            output_path=save_path,
            message=encode_message)
    finally:
        # Fix: the temp mp3 was previously never deleted (leaked one file
        # per request); always clean it up.
        os.remove(tmp_path)

    return "output.wav"
|
|
|
|
# Gradio UI: shared text/reference-audio inputs plus three tabs, one per
# synthesis backend (ChatTTS, OpenVoice base speaker, edge-tts).
with gr.Blocks() as demo:

    # Inputs shared by all three tabs.
    default_text = "Today a man knocked on my door and asked for a small donation toward the local swimming pool. I gave him a glass of water."
    text_input = gr.Textbox(label="Input Text", lines=4, placeholder="Please Input Text...", value=default_text)
    voice_ref = gr.Audio(label="Reference Audio", type="filepath", value="base_audio.mp3")

    # Tab 1: ChatTTS synthesis + voice cloning (generate_audio).
    with gr.Tab("💕Super Natural"):
        default_refine_text = "[oral_2][laugh_0][break_6]"
        refine_text_checkbox = gr.Checkbox(label="Refine text", info="'oral' means add filler words, 'laugh' means add laughter, and 'break' means add a pause. (0-10) ", value=True)
        refine_text_input = gr.Textbox(label="Refine Prompt", lines=1, placeholder="Please Refine Prompt...", value=default_refine_text)

        # ChatTTS sampling controls.
        with gr.Row():
            temperature_slider = gr.Slider(minimum=0.00001, maximum=1.0, step=0.00001, value=0.3, label="Audio temperature")
            top_p_slider = gr.Slider(minimum=0.1, maximum=0.9, step=0.05, value=0.7, label="top_P")
            top_k_slider = gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_K")

        # Seed inputs, each with a dice button wired to generate_seed().
        with gr.Row():
            audio_seed_input = gr.Number(value=42, label="Speaker Seed")
            generate_audio_seed = gr.Button("\U0001F3B2")
            text_seed_input = gr.Number(value=42, label="Text Seed")
            generate_text_seed = gr.Button("\U0001F3B2")

        generate_button = gr.Button("Generate!")

        audio_output = gr.Audio(label="Output Audio")

        generate_audio_seed.click(generate_seed,
                                  inputs=[],
                                  outputs=audio_seed_input)

        generate_text_seed.click(generate_seed,
                                 inputs=[],
                                 outputs=text_seed_input)

        generate_button.click(generate_audio,
                              inputs=[text_input, voice_ref, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, text_seed_input, refine_text_checkbox, refine_text_input],
                              outputs=audio_output)

    # Tab 2: OpenVoice base speaker with emotion styles (vc_en).
    with gr.Tab("💕Emotion Control"):
        emo_pick = gr.Dropdown(label="Emotion", info="🙂default😊friendly🤫whispering😄cheerful😱terrified😡angry😢sad", choices=["default", "friendly", "whispering", "cheerful", "terrified", "angry", "sad"], value="default")
        generate_button_emo = gr.Button("Generate!", variant="primary")
        audio_emo = gr.Audio(label="Output Audio", type="filepath")
        generate_button_emo.click(vc_en, [text_input, voice_ref, emo_pick], audio_emo)

    # Tab 3: edge-tts multilingual synthesis (text_to_speech_edge).
    with gr.Tab("💕multilingual"):
        language = gr.Dropdown(choices=list(language_dict.keys()), value=list(language_dict.keys())[15], label="Language")
        generate_button_ml = gr.Button("Generate!", variant="primary")
        audio_ml = gr.Audio(label="Output Audio", type="filepath")
        generate_button_ml.click(text_to_speech_edge, [text_input, voice_ref, language], audio_ml)
|
|
# CLI options for hosting the demo.
# NOTE(review): `args` is parsed but never used -- demo.launch() below is
# called without server_name/server_port, so these flags have no effect.
# TODO: either forward args to launch() or drop the parser. Left unchanged
# because on HF Spaces (this file imports `spaces`) the default launch()
# binding is likely intentional -- confirm before wiring.
parser = argparse.ArgumentParser(description='ChatVC demo Launch')
parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Server name')
parser.add_argument('--server_port', type=int, default=8080, help='Server port')
args = parser.parse_args()


if __name__ == '__main__':
    # Start the Gradio server with default host/port settings.
    demo.launch()