| import sys | |
| import os | |
| import base64 | |
| from tts import TTS | |
| from utils.file_utils import load_prompt_speech_from_file, load_voices | |
| from app.config import settings | |
| from datetime import datetime | |
def load_model(
    speed,
    voice,
    text,
    output_path,
    output_format="wav",
):
    """Synthesize ``text`` to ``output_path`` using a sample voice as the prompt.

    Args:
        speed: Speech speed; any value ``float()`` accepts (e.g. a CLI string).
        voice: Either a key of the voice map returned by ``load_voices`` or a
            string of digits, which is treated as a zero-based position in
            that map's values.
        text: Text to synthesize.
        output_path: Destination passed to ``TTS.tts_to_file``.
        output_format: Accepted for interface compatibility; this function
            does not use it.

    Raises:
        ValueError: If ``speed`` cannot be parsed as a float, or if ``voice``
            does not resolve to an existing file (unknown name or
            out-of-range index).
    """
    # Parse speed before constructing the TTS object so a malformed value
    # fails immediately instead of after the model and voices are loaded.
    speed = float(speed)

    print("Loading TTS model...", settings.DIR_ROOT)
    tts_obj = TTS(model_dir=os.path.join(settings.DIR_ROOT, "VietTTS", "models"))
    voice_map = load_voices(os.path.join(settings.DIR_ROOT, "VietTTS", "samples"))

    if voice.isdigit():
        # str.isdigit() also accepts characters int() rejects (e.g. "²"), and
        # the index may exceed the number of voices; report both cases through
        # the same "not found" error as an unknown voice name.
        try:
            voice_file = list(voice_map.values())[int(voice)]
        except (IndexError, ValueError):
            voice_file = None
    else:
        voice_file = voice_map.get(voice)

    if not voice_file or not os.path.exists(voice_file):
        raise ValueError("Voice file not found")

    print(f"Output path: {output_path}")
    prompt_speech_16k = load_prompt_speech_from_file(
        filepath=voice_file, min_duration=3, max_duration=10
    )
    tts_obj.tts_to_file(
        text=text,
        prompt_speech_16k=prompt_speech_16k,
        output_path=output_path,
        speed=speed,
    )
    print("END TTS worker")
if __name__ == "__main__":
    # CLI: SPEED VOICE TEXT OUTPUT_PATH [OUTPUT_FORMAT]
    # Fail with a usage message rather than an IndexError traceback when
    # required arguments are missing (exit status is 1 either way).
    if len(sys.argv) < 5:
        sys.exit(
            f"Usage: {sys.argv[0]} SPEED VOICE TEXT OUTPUT_PATH [OUTPUT_FORMAT]"
        )
    speed, voice, text, output_path = sys.argv[1:5]
    output_format = sys.argv[5] if len(sys.argv) > 5 else "wav"
    load_model(speed, voice, text, output_path, output_format)