| import os
|
| import subprocess
|
| import torchaudio
|
| import soundfile
|
| import numpy as np
|
| from glob import glob
|
| from loguru import logger
|
| from huggingface_hub import snapshot_download
|
|
|
| from VietTTS.utils.vad import get_speech
|
|
|
| import torchaudio
|
| import os
|
| import subprocess
|
| import tempfile
|
|
|
|
|
def convert_to_wav(input_filepath: str, target_sr: int) -> str:
    """
    Convert an input audio file to a mono WAV file at the desired sample rate using FFmpeg.

    The output is written to a freshly created temporary file that is NOT deleted
    automatically on success; the caller owns it and should remove it when done.

    Args:
        input_filepath (str): Path to the input audio file.
        target_sr (int): Target sample rate.

    Returns:
        str: Path to the converted WAV file.

    Raises:
        RuntimeError: If FFmpeg exits with a non-zero status.
        OSError: If the FFmpeg executable cannot be launched (e.g. it is not installed).
    """
    # Reserve a unique output path. The handle is closed immediately so that
    # FFmpeg can open and overwrite the file itself (hence "-y" below).
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_wav_file:
        temp_wav_filepath = temp_wav_file.name

    ffmpeg_command = [
        "ffmpeg", "-y",
        "-loglevel", "error",
        "-i", input_filepath,
        "-ar", str(target_sr),
        "-ac", "1",
        temp_wav_filepath,
    ]

    try:
        result = subprocess.run(ffmpeg_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except OSError:
        # FFmpeg could not even be started; do not leave the reserved file behind.
        os.unlink(temp_wav_filepath)
        raise

    if result.returncode != 0:
        os.unlink(temp_wav_filepath)
        # FFmpeg may echo file names in arbitrary encodings; never let decoding
        # of the diagnostic hide the actual conversion failure.
        error_text = result.stderr.decode(errors="replace")
        raise RuntimeError(f"FFmpeg conversion failed: {error_text}")

    return temp_wav_filepath
|
|
|
|
|
def load_wav(filepath: str, target_sr: int):
    """
    Load an audio file as a mono tensor at the requested sample rate.

    Files whose name does not end in ``.wav`` (case-insensitive) are first
    converted with ``convert_to_wav``; the intermediate temporary WAV file is
    removed once its samples have been read into memory.

    Args:
        filepath (str): Path to the audio file in any format FFmpeg can read.
        target_sr (int): Target sample rate.

    Returns:
        Tensor: Audio averaged over channels (leading dimension of size 1),
        resampled to ``target_sr`` when the file's rate is higher.

    Raises:
        RuntimeError: If converting a non-WAV input with FFmpeg fails.
        AssertionError: If the file's sample rate is lower than ``target_sr``
            (upsampling is rejected; note the check is skipped under ``python -O``).
    """
    temp_filepath = None
    if not filepath.lower().endswith(".wav"):
        logger.info(f"Converting {filepath} to WAV format")
        temp_filepath = convert_to_wav(filepath, target_sr)
        filepath = temp_filepath

    try:
        speech, sample_rate = torchaudio.load(filepath)
    finally:
        # The converted file is only an intermediate artifact: delete it whether
        # or not loading succeeded so repeated calls do not fill the temp dir.
        if temp_filepath is not None:
            try:
                os.unlink(temp_filepath)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove temporary file {temp_filepath}: {cleanup_error}")

    # Down-mix to a single channel while keeping the channel dimension.
    speech = speech.mean(dim=0, keepdim=True)
    if sample_rate != target_sr:
        assert sample_rate > target_sr, f'WAV sample rate {sample_rate} must be greater than {target_sr}'
        speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)

    return speech
|
|
|
|
|
def save_wav(wav: np.ndarray, sr: int, filepath: str):
    """
    Write audio samples to ``filepath`` via ``soundfile``.

    Args:
        wav (np.ndarray): Audio samples to store.
        sr (int): Sample rate of ``wav``.
        filepath (str): Destination path of the audio file.
    """
    soundfile.write(file=filepath, data=wav, samplerate=sr)
|
|
|
|
|
def load_prompt_speech_from_file(filepath: str, min_duration: float = 3, max_duration: float = 5, return_numpy: bool = False):
    """
    Load a prompt recording at 16 kHz, limit its peak level, and pass it to ``get_speech``.

    Args:
        filepath (str): Path to the prompt audio file.
        min_duration (float): Forwarded to ``get_speech`` as ``min_duration``.
        max_duration (float): Forwarded to ``get_speech`` as ``max_duration``.
        return_numpy (bool): Forwarded to ``get_speech`` as ``return_numpy``.

    Returns:
        Whatever ``get_speech`` returns for the prepared audio
        (presumably the selected speech segment; see ``VietTTS.utils.vad``).
    """
    audio = load_wav(filepath, 16000)

    # Scale the signal down only when its absolute peak exceeds 0.9.
    peak = audio.abs().max()
    if peak > 0.9:
        audio = audio / peak * 0.9

    # get_speech receives the samples without the leading channel dimension.
    mono_samples = audio.squeeze(0)
    return get_speech(
        audio_input=mono_samples,
        min_duration=min_duration,
        max_duration=max_duration,
        return_numpy=return_numpy,
    )
|
|
|
|
|
def load_voices(voice_dir: str):
    """
    Map voice names to audio files found directly inside ``voice_dir``.

    Only ``*.wav`` and ``*.mp3`` files are considered. The voice name is the file
    name with its final extension removed, so ``"anna.v2.wav"`` is registered as
    ``"anna.v2"`` rather than being truncated at the first dot (which would make
    differently named files collide on the same key).

    Args:
        voice_dir (str): Directory containing the voice samples.

    Returns:
        dict: ``{voice_name: file_path}``. When a ``.wav`` and an ``.mp3`` share
        the same name, the ``.mp3`` entry wins because it is listed last.
    """
    # Sort each listing so the resulting mapping has a stable, reproducible order.
    wav_files = sorted(glob(os.path.join(voice_dir, '*.wav')))
    mp3_files = sorted(glob(os.path.join(voice_dir, '*.mp3')))

    return {
        os.path.splitext(os.path.basename(path))[0]: path
        for path in wav_files + mp3_files
    }
|
|
|
|
|
def download_model(save_dir: str):
    """
    Download the ``duyv/viet-tts`` repository snapshot from the Hugging Face Hub.

    Args:
        save_dir (str): Local directory passed to ``snapshot_download`` as ``local_dir``.
    """
    hub_repo = "duyv/viet-tts"
    snapshot_download(repo_id=hub_repo, local_dir=save_dir)