
How to infer?

#2
by xiaotangsuda - opened

The audio generated by inference with your model is all noise. What could be causing this? Can you share your inference code?

Does https://huggingface.co/mesolitica/Malaysian-F5-TTS-v3#how-to not work?

I used the following code for inference with your model, but the generated speech is all noise.
#! /opt/conda/bin/python
import argparse
import codecs
import re
import tempfile
from pathlib import Path

import librosa
import numpy as np
import soundfile as sf
import tomli
import torch
import torchaudio
import tqdm, os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"

from einops import rearrange
from pydub import AudioSegment, silence
from transformers import pipeline
from vocos import Vocos

from model import CFM, DiT
from model.utils import (
    convert_char_to_pinyin,  # text conversion helper; must match what model/utils provides and what the checkpoint expects
    get_tokenizer,
    load_checkpoint,
    save_spectrogram,
)
import subprocess, io
base_dir_path = "f5-tts-tools/ckpts/"

def load_malaysian_model(model_cfg, ckpt_path, device):
"""
Loads the F5-TTS model specifically for Malaysian language.
"""
print("Loading tokenizer for Malaysian...")
vocab_char_map, vocab_size = get_tokenizer(
base_dir_path + "F5-TTS-Malaysian/vocab.txt", "custom"
)
print("Tokenizer loaded.")

print("Initializing CFM model with DiT backbone...")
model = CFM(
    transformer=DiT(
        **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
    ),
    mel_spec_kwargs=dict(
        target_sample_rate=target_sample_rate,
        n_mel_channels=n_mel_channels,
        hop_length=hop_length,
    ),
    odeint_kwargs=dict(
        method=ode_method,
    ),
    vocab_char_map=vocab_char_map,
).to(device)
print("Model initialized.")

print(f"Loading checkpoint from: {ckpt_path}")
model = load_checkpoint(model, ckpt_path, device, use_ema=True)
print("Checkpoint loaded successfully.")

return model

def chunk_text(text, max_chars=135):
"""
Splits the input text into chunks, each with a maximum number of characters.
"""
chunks = []
current_chunk = ""
# Split the text into sentences based on punctuation followed by whitespace
sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[;:,。!?])", text)

for sentence in sentences:
    if len(current_chunk.encode("utf-8")) + len(
        sentence.encode("utf-8")
    ) <= max_chars:
        current_chunk += (
            sentence + " "
            if sentence and len(sentence[-1].encode("utf-8")) == 1
            else sentence
        )
    else:
        if current_chunk:
            chunks.append(current_chunk.strip())
        current_chunk = (
            sentence + " "
            if sentence and len(sentence[-1].encode("utf-8")) == 1
            else sentence
        )

if current_chunk:
    chunks.append(current_chunk.strip())
return chunks

def infer_batch(
    ref_audio,
    ref_text,
    wave_path,
    gen_text_batches,
    ema_model,
    remove_silence,
    cross_fade_duration=0.15,
):
    audio, sr = ref_audio
    if audio.shape[0] > 1:
        audio = torch.mean(audio, dim=0, keepdim=True)

    rms = torch.sqrt(torch.mean(torch.square(audio)))
    if rms < target_rms:
        audio = audio * target_rms / rms
    if sr != target_sample_rate:
        resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
        audio = resampler(audio)
    audio = audio.to(device)
    generated_waves = []
    print(f"Processing {len(gen_text_batches)} text batches...")
    for i, gen_text in enumerate(tqdm.tqdm(gen_text_batches)):
        # Prepare the text
        if len(ref_text[-1].encode("utf-8")) == 1:
            ref_text = ref_text + " "
        text_list = [ref_text + gen_text]
        final_text_list = convert_char_to_pinyin(text_list)

        ref_audio_len = audio.shape[-1] // hop_length
        zh_pause_punc = r"。,、;:?!"
        ref_text_len = len(ref_text.encode("utf-8")) + 3 * len(
            re.findall(zh_pause_punc, ref_text)
        )
        gen_text_len = len(gen_text.encode("utf-8")) + 3 * len(
            re.findall(zh_pause_punc, gen_text)
        )
        duration = ref_audio_len + int(
            ref_audio_len / ref_text_len * gen_text_len / speed
        )

        # Inference
        with torch.inference_mode():
            generated, _ = ema_model.sample(
                cond=audio,
                text=final_text_list,
                duration=duration,
                steps=nfe_step,
                cfg_strength=cfg_strength,
                sway_sampling_coef=sway_sampling_coef,
            )

        generated = generated[:, ref_audio_len:, :]
        generated_mel_spec = rearrange(generated, "1 n d -> 1 d n")
        generated_wave = vocos.decode(generated_mel_spec.cpu())
        if rms < target_rms:
            generated_wave = generated_wave * rms / target_rms

        generated_wave = generated_wave.squeeze().cpu().numpy()
        generated_waves.append(generated_wave)

    # Combine all generated waves with cross-fading
    if cross_fade_duration <= 0 or len(generated_waves) == 1:
        final_wave = np.concatenate(generated_waves)
    else:
        print("Applying cross-fade...")
        final_wave = generated_waves[0]
        for i in range(1, len(generated_waves)):
            prev_wave = final_wave
            next_wave = generated_waves[i]

            cross_fade_samples = int(cross_fade_duration * target_sample_rate)
            cross_fade_samples = min(cross_fade_samples, len(prev_wave), len(next_wave))

            if cross_fade_samples <= 0:
                final_wave = np.concatenate([prev_wave, next_wave])
                continue

            prev_overlap = prev_wave[-cross_fade_samples:]
            next_overlap = next_wave[:cross_fade_samples]

            fade_out = np.linspace(1, 0, cross_fade_samples)
            fade_in = np.linspace(0, 1, cross_fade_samples)

            cross_faded_overlap = prev_overlap * fade_out + next_overlap * fade_in

            final_wave = np.concatenate(
                [
                    prev_wave[:-cross_fade_samples],
                    cross_faded_overlap,
                    next_wave[cross_fade_samples:],
                ]
            )

    # Save the final audio file
    with open(wave_path, "wb") as f:
        print("Resampling to 16kHz for output...")
        resampled_wave = librosa.resample(final_wave, orig_sr=24000, target_sr=16000)
        sf.write(f.name, resampled_wave, 16000)

        if remove_silence:
            print("Removing silence...")
            aseg = AudioSegment.from_file(f.name)
            non_silent_segs = silence.split_on_silence(
                aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500
            )
            non_silent_wave = AudioSegment.silent(duration=0)
            for non_silent_seg in non_silent_segs:
                non_silent_wave += non_silent_seg
            non_silent_wave.export(f.name, format="wav")
        print(f"Output saved to: {f.name}")

def infer(
    ref_audio_orig,
    ref_text,
    gen_text,
    wave_path,
    model,
    remove_silence,
    cross_fade_duration=0.15,
):
    """
    Main inference function to prepare and run the batch process.
    """
    print(f"Generating text: '{gen_text}' using reference audio: '{ref_audio_orig}'")

    # If reference text is empty, use Whisper for transcription
    if not ref_text.strip():
        print("No reference text provided, transcribing reference audio with Whisper...")
        pipe = pipeline(
            "automatic-speech-recognition",
            model="openai/whisper-large-v3-turbo",
            torch_dtype=torch.float16,
            device=device,
        )
        ref_text = pipe(
            ref_audio_orig,
            chunk_length_s=30,
            batch_size=128,
            generate_kwargs={"task": "transcribe"},
            return_timestamps=False,
        )["text"].strip()
        print(f"Finished transcription. Reference text: '{ref_text}'")
    else:
        print("Using provided reference text.")

    # Ensure reference text ends with punctuation
    if not ref_text.endswith(". ") and not ref_text.endswith("。"):
        if ref_text.endswith("."):
            ref_text += " "
        else:
            ref_text += ". "

    # Split the generation text into batches
    audio, sr = torchaudio.load(ref_audio_orig)
    max_chars = int(
        len(ref_text.encode("utf-8"))
        / (audio.shape[-1] / sr)
        * (25 - audio.shape[-1] / sr)
    )
    gen_text_batches = chunk_text(gen_text, max_chars=max_chars)
    print(f"Reference text: '{ref_text}'")
    for i, gen_text_chunk in enumerate(gen_text_batches):
        print(f"Generation chunk {i+1}: '{gen_text_chunk}'")

    return infer_batch(
        (audio, sr),
        ref_text,
        wave_path,
        gen_text_batches,
        model,
        remove_silence,
        cross_fade_duration,
    )

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="python3 inference_malaysian.py",
        description="Commandline interface for F5-TTS Malaysian model.",
    )
    parser.add_argument(
        "-c",
        "--config",
        help="Configuration file. Default=inference-cli.toml",
        default="inference-cli.toml",
    )
    parser.add_argument(
        "-r", "--ref_audio", type=str, required=True, help="Reference audio file < 15 seconds."
    )
    parser.add_argument(
        "-s",
        "--ref_text",
        type=str,
        default="",
        help="Subtitle for the reference audio. (Optional, will use Whisper if empty)",
    )
    parser.add_argument(
        "-t", "--gen_text", type=str, help="Text to generate."
    )
    parser.add_argument(
        "-f",
        "--gen_file",
        type=str,
        help="File with text to generate. Ignores --gen_text",
    )
    parser.add_argument(
        "-utt2wav",
        "--utt2wav2txt",
        type=str,
        help="File with format 'utt|ref_text|ref_wav|gen_text' for batch processing.",
    )
    parser.add_argument(
        "-o", "--output_dir", type=str, default="out", help="Path to output folder."
    )
    parser.add_argument(
        "--remove_silence", action="store_true", help="Enable silence removal."
    )

    args = parser.parse_args()

    # Load config from file if it exists, otherwise use defaults
    try:
        config = tomli.load(open(args.config, "rb"))
    except FileNotFoundError:
        config = {}

    # Prioritize command line arguments over config file
    ref_audio = args.ref_audio
    ref_text = args.ref_text if args.ref_text else config.get("ref_text", "")
    utt2wav2txt = args.utt2wav2txt
    output_dir = args.output_dir
    remove_silence = args.remove_silence or config.get("remove_silence", False)

    gen_text_input = args.gen_text
    if args.gen_file:
        try:
            with open(args.gen_file, 'r', encoding='utf-8') as f:
                gen_text_input = f.read()
        except Exception as e:
            print(f"Error reading gen_file: {e}")
            exit(1)

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # --- Static Settings ---
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Vocoder setup
    vocos_local_path = os.path.join(base_dir_path, "models--charactr--vocos-mel-24khz/snapshots/0feb3fdd929bcd6649e0e7c5a688cf7dd012ef21/")
    print(f"Load vocos from local path {vocos_local_path}")
    vocos = Vocos.from_hparams(f"{vocos_local_path}/config.yaml")
    state_dict = torch.load(f"{vocos_local_path}/pytorch_model.bin", map_location=device)
    vocos.load_state_dict(state_dict)
    vocos.eval()

    # Model specific settings
    F5TTS_model_cfg = dict(
        dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
    )
    ckpt_path = os.path.join(base_dir_path, "F5-TTS-Malaysian/model_220000.pt")

    # Inference parameters
    target_sample_rate = 24000
    n_mel_channels = 100
    hop_length = 256
    target_rms = 0.1
    nfe_step = 32
    cfg_strength = 2.0
    ode_method = "euler"
    sway_sampling_coef = -1.0
    speed = 1.0

    # Load the Malaysian TTS model
    ema_model = load_malaysian_model(F5TTS_model_cfg, ckpt_path, device)

    # --- Execution Logic ---
    if utt2wav2txt:
        print(f"Processing batch file: {utt2wav2txt}")
        try:
            with open(utt2wav2txt, 'r', encoding='utf-8') as f:
                all_lines = [i.strip() for i in f.readlines()]

            for line in tqdm.tqdm(all_lines, desc="Batch Processing"):
                try:
                    utt, ref_t, ref_w, gen_t = line.split('|')
                    if os.path.isfile(ref_w):
                        wave_path = os.path.join(output_dir, utt + '.wav')
                        infer(ref_w, ref_t, gen_t, wave_path, ema_model, remove_silence)
                    else:
                        print(f"Reference audio not found, skipping: {ref_w}")
                except Exception as e:
                    print(f"Error processing line '{line}': {e}")
        except Exception as e:
            print(f"Error reading batch file: {e}")

    elif gen_text_input:
        output_filename = "output.wav"
        wave_path = os.path.join(output_dir, output_filename)
        infer(ref_audio, ref_text, gen_text_input, wave_path, ema_model, remove_silence)
    else:
        print("No text to generate. Please provide --gen_text or --gen_file or --utt2wav.")
