# Space file history: "Update app.py" by kingj233414 (commit 36e764e, verified)
import spaces
import torch
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
import gradio as gr
import traceback
import gc
import numpy as np
import librosa
import tempfile # Added for temporary file handling
from pydub import AudioSegment
from pydub.effects import normalize
from huggingface_hub import snapshot_download
from tts.infer_cli import MegaTTS3DiTInfer, convert_to_wav, cut_wav
def download_weights(repo_id="mrfakename/MegaTTS3-VoiceCloning", weights_dir="checkpoints"):
    """Download the MegaTTS3 model weights from the Hugging Face Hub if needed.

    The download is skipped only when ``weights_dir`` is an existing,
    non-empty directory; an empty directory is treated as "not downloaded"
    so a pre-created or emptied folder does not leave the app without weights.

    Args:
        repo_id: Hub repository that holds the checkpoints.
        weights_dir: Local directory the snapshot is written into.

    Returns:
        The local weights directory path (``weights_dir``).
    """
    if os.path.isdir(weights_dir) and os.listdir(weights_dir):
        print("Model weights already exist.")
        return weights_dir

    print("Downloading model weights from HuggingFace...")
    snapshot_download(
        repo_id=repo_id,
        local_dir=weights_dir,
        # NOTE(review): recent huggingface_hub releases deprecate this
        # argument; kept so behaviour is unchanged on older versions.
        local_dir_use_symlinks=False,
    )
    print("Model weights downloaded successfully!")
    return weights_dir
# Module-level setup, executed once at import time: make sure the checkpoints
# are present (download_weights skips the download when the directory already
# exists), then build the global inference pipeline. `infer_pipe` is the shared
# instance used by generate_speech and is rebound by reset_model() after
# CUDA failures.
download_weights()
print("Initializing MegaTTS3 model...")
infer_pipe = MegaTTS3DiTInfer()
print("Model loaded successfully!")
def reset_model():
    """Rebuild the global inference pipeline to recover from CUDA errors.

    The current pipeline is dropped *before* the replacement is constructed,
    so its memory can be reclaimed instead of briefly holding two models on
    the GPU (which can make recovery from an out-of-memory error impossible).

    Returns:
        True if a fresh ``MegaTTS3DiTInfer`` was created, False otherwise.
        After a failure ``infer_pipe`` may be ``None``; callers should ask
        the user to restart the application.
    """
    global infer_pipe
    try:
        # Drop our reference to the old pipeline so it becomes collectable.
        infer_pipe = None
        gc.collect()
        if torch.cuda.is_available():
            # Wait for outstanding kernels, then return cached blocks.
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
        print("Reinitializing MegaTTS3 model...")
        infer_pipe = MegaTTS3DiTInfer()
        print("Model reinitialized successfully!")
        return True
    except Exception as e:
        print(f"Failed to reinitialize model: {e}")
        return False
# --- Helper functions for building the reference audio from appended clips ---
def merge_audio(current_state_path, new_audio_path, gap_ms=300):
    """Append a newly recorded/uploaded clip to the accumulated reference audio.

    Args:
        current_state_path: Path of the audio accumulated so far, or None when
            nothing has been appended yet.
        new_audio_path: Path of the clip to append. Falsy values are ignored.
        gap_ms: Milliseconds of silence inserted between consecutive clips.

    Returns:
        Path to a new temporary WAV file containing the combined audio, or
        ``current_state_path`` unchanged when there is nothing to append or
        merging fails (the user is notified with a warning in that case).
    """
    if not new_audio_path:
        return current_state_path

    path = None
    try:
        combined = AudioSegment.from_file(new_audio_path)
        if current_state_path is not None:
            # Only existing audio needs a separating gap before the new clip.
            prev_seg = AudioSegment.from_file(current_state_path)
            combined = prev_seg + AudioSegment.silent(duration=gap_ms) + combined

        fd, path = tempfile.mkstemp(suffix=".wav")
        os.close(fd)
        combined.export(path, format="wav")
        return path
    except Exception as e:
        print(f"Error merging audio: {e}")
        # Don't leave a half-written temp file behind.
        if path is not None:
            try:
                os.remove(path)
            except OSError:
                pass
        gr.Warning(f"Could not append audio: {e}")
        return current_state_path
def clear_audio_state():
    """Empty the stored reference path and the reference-audio player.

    Returns:
        A ``(state, player)`` pair, both set to ``None``.
    """
    cleared = None
    return cleared, cleared
# -----------------------------------------------
@spaces.GPU
def generate_speech(inp_audio, inp_text, infer_timestep, p_w, t_w):
    """Clone the voice in ``inp_audio`` and synthesize ``inp_text`` with it.

    Args:
        inp_audio: Filepath of the reference audio.
        inp_text: Text to synthesize; empty or whitespace-only text is rejected.
        infer_timestep: Number of inference timesteps. Coerced to ``int``
            because a ``gr.Number`` without ``precision=0`` delivers a float.
        p_w: Intelligibility weight forwarded to ``infer_pipe.forward``.
        t_w: Similarity weight forwarded to ``infer_pipe.forward``.

    Returns:
        The value returned by ``infer_pipe.forward`` on success, or None on
        any failure (a ``gr.Warning`` explains the problem to the user).
    """
    if not inp_audio or not inp_text or not inp_text.strip():
        gr.Warning("Please provide both reference audio and text to generate.")
        return None

    processed_audio_path = None
    try:
        print(f"Generating speech with: {inp_text}...")
        if not torch.cuda.is_available():
            gr.Warning("CUDA is not available. Please check your GPU setup.")
            return None
        torch.cuda.empty_cache()
        print(f"CUDA device: {torch.cuda.get_device_name()}")

        # Convert/validate the reference, then trim it to the model's limit.
        try:
            processed_audio_path = preprocess_audio_robust(inp_audio)
            cut_wav(processed_audio_path, max_len=28)
        except Exception as audio_error:
            gr.Warning(f"Audio preprocessing failed: {str(audio_error)}")
            return None

        with open(processed_audio_path, 'rb') as file:
            file_content = file.read()

        try:
            resource_context = infer_pipe.preprocess(file_content)
            wav_bytes = infer_pipe.forward(
                resource_context,
                inp_text,
                time_step=int(infer_timestep),
                p_w=p_w,
                t_w=t_w,
            )
            cleanup_memory()
            return wav_bytes
        except RuntimeError as cuda_error:
            if "CUDA" not in str(cuda_error):
                raise  # non-CUDA runtime errors go to the generic handler below
            print(f"CUDA error detected: {cuda_error}")
            # Rebuild the pipeline so the next request has a chance to succeed.
            if reset_model():
                gr.Warning("CUDA error occurred. Model has been reset. Please try again.")
            else:
                gr.Warning("CUDA error occurred and model reset failed. Please restart the application.")
            return None
    except Exception as e:
        traceback.print_exc()
        gr.Warning(f"Speech generation failed: {str(e)}")
        cleanup_memory()
        return None
    finally:
        # The processed copy is only needed for this request; remove it so
        # repeated generations don't accumulate files on disk.
        if processed_audio_path and processed_audio_path != inp_audio:
            try:
                os.remove(processed_audio_path)
            except OSError:
                pass
def cleanup_memory():
    """Run the garbage collector and, on GPU hosts, release cached CUDA memory."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
def preprocess_audio_robust(audio_path, target_sr=22050, max_duration=30):
    """Convert a reference clip into a validated mono WAV for the model.

    Steps: downmix to mono, truncate to ``max_duration`` seconds, peak-normalize
    with ``pydub.effects.normalize``, resample to ``target_sr``, export as WAV
    (pcm_s16le), then reload with librosa to reject empty, non-finite or
    near-silent audio before re-saving the validated samples.

    Args:
        audio_path: Path to the input audio (any format pydub can open).
        target_sr: Output sample rate in Hz.
        max_duration: Maximum kept duration in seconds.

    Returns:
        Path of the processed file, written next to ``audio_path`` as
        ``<stem>_processed.wav``.

    Raises:
        ValueError: If loading, conversion or validation fails.
    """
    try:
        audio = AudioSegment.from_file(audio_path)
        if audio.channels > 1:
            audio = audio.set_channels(1)

        # pydub lengths and slices are in milliseconds.
        max_ms = max_duration * 1000
        if len(audio) > max_ms:
            audio = audio[:max_ms]

        audio = normalize(audio)
        audio = audio.set_frame_rate(target_sr)

        # Derive the output name from the stem. The previous
        # `audio_path.replace(ext, ...)` corrupted the path when the extension
        # was empty (replace('') inserts between every character) or when the
        # extension text also appeared earlier in the path.
        temp_path = os.path.splitext(audio_path)[0] + '_processed.wav'
        audio.export(
            temp_path,
            format="wav",
            parameters=["-acodec", "pcm_s16le", "-ac", "1", "-ar", str(target_sr)]
        )

        wav, sr = librosa.load(temp_path, sr=target_sr, mono=True)
        if wav.size == 0:
            raise ValueError("Audio is empty")
        if not np.all(np.isfinite(wav)):
            raise ValueError("Audio contains NaN or infinite values")
        if np.max(np.abs(wav)) < 1e-6:
            raise ValueError("Audio signal is too quiet")

        import soundfile as sf
        sf.write(temp_path, wav, sr)
        return temp_path
    except Exception as e:
        print(f"Audio preprocessing failed: {e}")
        raise ValueError(f"Failed to process audio: {str(e)}") from e
with gr.Blocks(title="MegaTTS3 Voice Cloning") as demo:
    gr.Markdown("# MegaTTS 3 Voice Cloning")
    gr.Markdown("MegaTTS 3 is a text-to-speech model trained by ByteDance with exceptional voice cloning capabilities. The original authors did not release the WavVAE encoder, so voice cloning was not publicly available; however, thanks to [@ACoderPassBy](https://modelscope.cn/models/ACoderPassBy/MegaTTS-SFT)'s WavVAE encoder, we can now clone voices with MegaTTS 3!")
    gr.Markdown("This is by no means the best voice cloning solution, but it works pretty well for some specific use-cases. Try out multiple and see which one works best for you.")
    gr.Markdown("**Please use this Space responsibly and do not abuse it!** This demo is for research and educational purposes only!")
    gr.Markdown("h/t to MysteryShack on Discord for the info about the unofficial WavVAE encoder!")
    gr.Markdown("### Instructions for Reference Audio:")
    gr.Markdown("1. Record or upload a short clip in the 'New Input' box. \n 2. Click **Append to Reference** to add it to your master file. \n 3. The **Full Reference Audio** player shows what will be used for cloning.")

    # Filepath of the combined reference audio, carried between events.
    reference_state = gr.State(value=None)

    with gr.Row():
        # Left column: build the reference audio and configure generation.
        with gr.Column():
            with gr.Group():
                gr.Markdown("### Step 1: Input Source")
                mic_input = gr.Audio(
                    label="New Input (Mic/Upload)",
                    type="filepath",
                    sources=["microphone", "upload"]
                )
                with gr.Row():
                    add_btn = gr.Button("Append to Reference", variant="secondary")
                    clear_btn = gr.Button("Clear Reference", variant="stop")

            with gr.Group():
                gr.Markdown("### Step 2: Full Reference Audio")
                master_audio = gr.Audio(
                    label="Full Reference Audio (Used for Cloning)",
                    type="filepath",
                    interactive=False  # only filled via "Append to Reference"
                )

            text_input = gr.Textbox(
                label="Text to Generate",
                placeholder="Enter the text you want to synthesize...",
                lines=3
            )

            with gr.Accordion("Advanced Options", open=False):
                infer_timestep = gr.Number(
                    label="Inference Timesteps",
                    value=32,
                    minimum=1,
                    maximum=100,
                    step=1,
                    # Deliver an int: without precision=0, gr.Number
                    # passes a float to the handler.
                    precision=0
                )
                p_w = gr.Number(
                    label="Intelligibility Weight",
                    value=1.4,
                    minimum=0.1,
                    maximum=5.0,
                    step=0.1
                )
                t_w = gr.Number(
                    label="Similarity Weight",
                    value=3.0,
                    minimum=0.1,
                    maximum=10.0,
                    step=0.1
                )

            generate_btn = gr.Button("Generate Speech", variant="primary")

        # Right column: generated output.
        with gr.Column():
            output_audio = gr.Audio(label="Generated Audio")

    # --- EVENT HANDLERS ---
    # 1. Append: update the stored path first, then mirror it in the player.
    add_btn.click(
        fn=merge_audio,
        inputs=[reference_state, mic_input],
        outputs=[reference_state]
    ).then(
        fn=lambda x: x,
        inputs=[reference_state],
        outputs=[master_audio]
    )

    # 2. Clear: reset both the stored path and the player.
    clear_btn.click(
        fn=clear_audio_state,
        inputs=[],
        outputs=[reference_state, master_audio]
    )

    # 3. Generate: the combined reference audio is the voice prompt.
    generate_btn.click(
        fn=generate_speech,
        inputs=[master_audio, text_input, infer_timestep, p_w, t_w],
        outputs=[output_audio]
    )
if __name__ == '__main__':
    # Serve on all interfaces on the standard Spaces port, with debug logging.
    demo.launch(
        debug=True,
        server_port=7860,
        server_name='0.0.0.0',
    )