# rvc/pipeline/inference.py
# Commit 969158e (ibcplateformes):
#   "Add voice similarity control + improve reference audio processing"
"""
Voice conversion module using Seed-VC (zero-shot diffusion transformer).
Based on the official Seed-VC app_svc.py implementation.
"""
import os
import sys
import logging
import numpy as np
import torch
import torchaudio
import librosa
logger = logging.getLogger(__name__)

try:
    import spaces
except ImportError:
    class spaces:
        """Minimal stand-in for the Hugging Face ``spaces`` package.

        Lets this module be imported and run outside a ZeroGPU Space: the
        ``GPU`` decorator becomes a no-op that returns the function unchanged.
        """

        @staticmethod
        def GPU(duration=60, **kwargs):
            """No-op replacement for ``spaces.GPU``.

            Supports both ``@spaces.GPU(duration=...)`` and the bare
            ``@spaces.GPU`` form; all arguments are ignored.
            """
            # Bare-decorator form: the decorated function arrives as the
            # first positional argument instead of a duration.
            if callable(duration):
                return duration

            def decorator(fn):
                return fn

            return decorator


# Directory where converted WAV files are written.
OUTPUT_DIR = "/tmp/rvc_output"

# Cached models (loaded once by _load_seed_vc_models, reused across calls
# made in this process).
_model_cache = {}
def _load_seed_vc_models(device):
    """Load (once) and return the Seed-VC singing voice conversion stack.

    Follows the official Seed-VC ``app_svc.py`` setup: an F0-conditioned DiT
    (whisper-base features, 44 kHz config), a CAMPPlus speaker encoder, a
    BigVGAN vocoder, the Whisper encoder used as speech tokenizer, a mel
    spectrogram function and an RMVPE F0 extractor.

    Args:
        device: ``torch.device`` the models are moved to.

    Returns:
        The module-level ``_model_cache`` dict with keys ``model``,
        ``semantic_fn``, ``vocoder_fn``, ``campplus_model``, ``f0_fn``,
        ``to_mel``, ``sr``, ``hop_length``, ``max_context_window``,
        ``overlap_frame_len``, ``overlap_wave_len`` and ``device``.

    The cache is reused only if it was built for the same ``device``; a
    different device triggers a rebuild, since the cached weights would
    otherwise live on the wrong device.
    """
    if "model" in _model_cache:
        cached_device = _model_cache.get("device")
        if cached_device == device:
            return _model_cache
        logger.info(
            "Seed-VC models cached on {}, reloading for {}".format(cached_device, device)
        )
        _model_cache.clear()

    import yaml
    from modules.commons import recursive_munch, build_model, load_checkpoint
    from hf_utils import load_custom_model_from_hf

    logger.info("Loading Seed-VC models on {}...".format(device))

    # Load the singing model (F0-conditioned, whisper-base, 44kHz, BigVGAN)
    dit_checkpoint_path, dit_config_path = load_custom_model_from_hf(
        "Plachta/Seed-VC",
        "DiT_seed_v2_uvit_whisper_base_f0_44k_bigvgan_pruned_ft_ema_v2.pth",
        "config_dit_mel_seed_uvit_whisper_base_f0_44k.yml",
    )
    # Close the config file deterministically instead of leaking the handle.
    with open(dit_config_path, "r", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)
    model_params = recursive_munch(config["model_params"])
    model_params.dit_type = "DiT"
    model = build_model(model_params, stage="DiT")
    spect_params = config["preprocess_params"]["spect_params"]
    hop_length = spect_params["hop_length"]
    sr = config["preprocess_params"]["sr"]

    # Load checkpoint parameters only.
    model, _, _, _ = load_checkpoint(
        model, None, dit_checkpoint_path,
        load_only_params=True, ignore_modules=[], is_distributed=False,
    )
    for key in model:
        model[key].eval()
        model[key].to(device)

    # Setup caches for faster inference
    model.cfm.estimator.setup_caches(max_batch_size=1, max_seq_length=8192)

    # CAMPPlus speaker embedding model (80-dim fbank in, 192-dim embedding out).
    from modules.campplus.DTDNN import CAMPPlus

    campplus_ckpt_path = load_custom_model_from_hf(
        "funasr/campplus", "campplus_cn_common.bin", config_filename=None
    )
    campplus_model = CAMPPlus(feat_dim=80, embedding_size=192)
    campplus_model.load_state_dict(torch.load(campplus_ckpt_path, map_location="cpu"))
    campplus_model.eval()
    campplus_model.to(device)

    # BigVGAN vocoder
    from modules.bigvgan import bigvgan

    bigvgan_model = bigvgan.BigVGAN.from_pretrained(
        model_params.vocoder.name, use_cuda_kernel=False
    )
    bigvgan_model.remove_weight_norm()
    bigvgan_model = bigvgan_model.eval().to(device)

    # Whisper speech tokenizer (transformers' WhisperModel, NOT a custom module)
    from transformers import AutoFeatureExtractor, WhisperModel

    whisper_name = model_params.speech_tokenizer.name
    whisper_model = WhisperModel.from_pretrained(
        whisper_name, torch_dtype=torch.float16
    ).to(device)
    # Only the encoder is used below, so the decoder is dropped.
    del whisper_model.decoder
    whisper_feature_extractor = AutoFeatureExtractor.from_pretrained(whisper_name)

    def semantic_fn(waves_16k):
        """Return float32 Whisper-encoder features for a 16 kHz waveform.

        Expects a single-item batch (the caller's tensors are built with
        ``unsqueeze(0)``, i.e. shape (1, T) — TODO confirm for new callers).
        """
        inputs = whisper_feature_extractor(
            [waves_16k.squeeze(0).cpu().numpy()],
            return_tensors="pt",
            return_attention_mask=True,
        )
        input_features = whisper_model._mask_input_features(
            inputs.input_features, attention_mask=inputs.attention_mask
        ).to(device)
        with torch.no_grad():
            outputs = whisper_model.encoder(
                input_features.to(whisper_model.encoder.dtype),
                head_mask=None,
                output_attentions=False,
                output_hidden_states=False,
                return_dict=True,
            )
        features = outputs.last_hidden_state.to(torch.float32)
        # One encoder frame per 320 input samples (20 ms at 16 kHz); drop
        # frames that only cover the feature extractor's padding.
        return features[:, :waves_16k.size(-1) // 320 + 1]

    # Mel spectrogram
    from modules.audio import mel_spectrogram

    mel_fn_args = {
        "n_fft": spect_params["n_fft"],
        "win_size": spect_params["win_length"],
        "hop_size": spect_params["hop_length"],
        "num_mels": spect_params["n_mels"],
        "sampling_rate": sr,
        "fmin": spect_params.get("fmin", 0),
        # Any fmax other than the string "None" maps to 8000 Hz (kept as-is
        # from the reference implementation).
        "fmax": None if spect_params.get("fmax", "None") == "None" else 8000,
        "center": False,
    }

    def to_mel(x):
        return mel_spectrogram(x, **mel_fn_args)

    # F0 extractor (RMVPE)
    from modules.rmvpe import RMVPE

    rmvpe_path = load_custom_model_from_hf("lj1995/VoiceConversionWebUI", "rmvpe.pt", None)
    rmvpe = RMVPE(rmvpe_path, is_half=False, device=device)
    f0_fn = rmvpe.infer_from_audio

    # sr // hop_length mel frames per second -> 30 s of frames per DiT pass.
    max_context_window = sr // hop_length * 30
    overlap_frame_len = 16
    overlap_wave_len = overlap_frame_len * hop_length

    _model_cache.update({
        "model": model,
        "semantic_fn": semantic_fn,
        "vocoder_fn": bigvgan_model,
        "campplus_model": campplus_model,
        "f0_fn": f0_fn,
        "to_mel": to_mel,
        "sr": sr,
        "hop_length": hop_length,
        "max_context_window": max_context_window,
        "overlap_frame_len": overlap_frame_len,
        "overlap_wave_len": overlap_wave_len,
        "device": device,
    })
    logger.info("Seed-VC models loaded (sr={}, hop={})".format(sr, hop_length))
    return _model_cache
def adjust_f0_semitones(f0_sequence, n_semitones):
    """Transpose an F0 contour by ``n_semitones`` in equal temperament.

    Every value is scaled by 2 ** (n_semitones / 12): +12 doubles the
    frequency (one octave up), -12 halves it, 0 leaves it unchanged.
    """
    octaves = n_semitones / 12
    ratio = 2 ** octaves
    return f0_sequence * ratio
def crossfade(chunk1, chunk2, overlap):
    """Blend the tail of ``chunk1`` into the head of ``chunk2``.

    The first ``overlap`` samples of ``chunk2`` become
    ``chunk2 * fade_in + chunk1[-overlap:] * fade_out`` where
    ``fade_out = cos^2`` ramps 1 -> 0 and ``fade_in`` ramps 0 -> 1; the two
    gains sum to 1 at every sample.

    ``overlap`` is clamped to the length of both chunks; an overlap of 0
    (or less) returns ``chunk2`` unchanged.

    Returns:
        A new array with ``chunk2``'s shape and dtype; neither input is
        modified.
    """
    out = np.array(chunk2, copy=True)
    n = min(int(overlap), len(chunk1), len(out))
    if n <= 0:
        # chunk1[-0:] would be the whole array, so a zero overlap must be
        # special-cased rather than blended.
        return out
    fade_out = np.cos(np.linspace(0, np.pi / 2, n)) ** 2
    fade_in = np.cos(np.linspace(np.pi / 2, 0, n)) ** 2
    out[:n] = out[:n] * fade_in + np.asarray(chunk1)[-n:] * fade_out
    return out
# Diagnostics file written by convert_voice inside the GPU worker.
DEBUG_LOG = "/home/user/app/debug_gpu.log"


def _test_import(name, module_path, subattr=None):
    """Try to import ``module_path`` and, if given, read ``subattr`` from it.

    ``name`` is only a label for the caller; it is not used here.

    Returns:
        ``(True, "OK")`` on success, otherwise
        ``(False, "<ExceptionType>: <message>")`` — any ``Exception`` raised
        while importing or looking up the attribute is captured, not raised.
    """
    import importlib

    try:
        module = importlib.import_module(module_path)
        if subattr:
            getattr(module, subattr)
    except Exception as exc:
        return False, "{}: {}".format(type(exc).__name__, exc)
    return True, "OK"
@spaces.GPU(duration=300)
def convert_voice(
    audio_path,
    reference_path,
    pitch=0,
    diffusion_steps=25,
    similarity=0.7,
):
    """Convert voice using Seed-VC zero-shot singing voice conversion.

    GPU-decorated entry point (based on the official ``app_svc.py``
    voice_conversion function). Prepares the worker environment, writes
    best-effort diagnostics to ``DEBUG_LOG`` and delegates the actual work
    to ``_convert_voice_impl``.

    Args:
        audio_path: path of the source audio to convert.
        reference_path: path of the target-voice reference audio.
        pitch: extra transposition in semitones, applied after the
            automatic F0 matching.
        diffusion_steps: number of diffusion steps for the DiT.
        similarity: forwarded to the DiT as ``inference_cfg_rate``.

    Returns:
        Path of the converted WAV file.

    Raises:
        Whatever ``_convert_voice_impl`` raises; the traceback is logged and
        appended to ``DEBUG_LOG`` first.
    """
    # CRITICAL: Ensure app directory is in sys.path for ZeroGPU worker
    app_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    if app_dir not in sys.path:
        sys.path.insert(0, app_dir)
    # Process-wide side effect: the working directory becomes the app dir.
    os.chdir(app_dir)

    # Write debug diagnostics BEFORE attempting anything
    try:
        with open(DEBUG_LOG, "w") as f:
            f.write("=== GPU Worker Debug ===\n")
            f.write("app_dir: {}\n".format(app_dir))
            f.write("cwd: {}\n".format(os.getcwd()))
            f.write("sys.path[:5]: {}\n".format(sys.path[:5]))
            f.write("modules/ exists: {}\n".format(os.path.isdir(os.path.join(app_dir, "modules"))))
            f.write("hf_utils.py exists: {}\n".format(os.path.isfile(os.path.join(app_dir, "hf_utils.py"))))
            f.write("cuda available: {}\n".format(torch.cuda.is_available()))
            # Test each critical import
            tests = [
                ("yaml", "yaml", None),
                ("munch", "munch", "Munch"),
                ("einops", "einops", None),
                ("transformers", "transformers", "WhisperModel"),
                ("modules.commons", "modules.commons", "build_model"),
                ("hf_utils", "hf_utils", "load_custom_model_from_hf"),
                ("modules.campplus.DTDNN", "modules.campplus.DTDNN", "CAMPPlus"),
                ("modules.bigvgan.bigvgan", "modules.bigvgan.bigvgan", "BigVGAN"),
                ("modules.audio", "modules.audio", "mel_spectrogram"),
                ("modules.rmvpe", "modules.rmvpe", "RMVPE"),
            ]
            for label, mod_path, attr in tests:
                ok, msg = _test_import(label, mod_path, attr)
                f.write("IMPORT {}: {} -> {}\n".format("OK" if ok else "FAIL", label, msg))
            f.write("=== Import tests done ===\n")
    except Exception:
        # Diagnostics are best-effort and must not block a conversion, but
        # the failure is still recorded instead of silently discarded.
        logger.debug("Could not write GPU debug log {}".format(DEBUG_LOG), exc_info=True)

    try:
        return _convert_voice_impl(
            audio_path, reference_path, pitch, diffusion_steps, similarity
        )
    except Exception:
        import traceback

        tb = traceback.format_exc()
        logger.exception("Seed-VC voice conversion failed")
        try:
            with open(DEBUG_LOG, "a") as f:
                f.write("\n=== CONVERSION ERROR ===\n")
                f.write(tb)
        except Exception:
            logger.debug("Could not append to GPU debug log {}".format(DEBUG_LOG), exc_info=True)
        raise
def _extract_semantic_long(semantic_fn, waves_16k):
    """Run ``semantic_fn`` over 16 kHz audio of any length.

    Inputs up to 30 s are processed in one call. Longer inputs are processed
    in windows: the first covers 30 s, each later one is the last 5 s of the
    previous window plus the next 25 s. The features of that 5 s prefix
    (50 frames per second) are dropped before concatenating along dim 1.
    """
    if waves_16k.size(-1) <= 16000 * 30:
        return semantic_fn(waves_16k)
    overlapping_time = 5  # seconds
    feature_chunks = []
    buffer = None
    traversed_time = 0
    while traversed_time < waves_16k.size(-1):
        if buffer is None:
            chunk = waves_16k[:, traversed_time:traversed_time + 16000 * 30]
        else:
            chunk = torch.cat([
                buffer,
                waves_16k[:, traversed_time:traversed_time + 16000 * (30 - overlapping_time)],
            ], dim=-1)
        features = semantic_fn(chunk)
        if traversed_time == 0:
            feature_chunks.append(features)
        else:
            feature_chunks.append(features[:, 50 * overlapping_time:])
        buffer = chunk[:, -16000 * overlapping_time:]
        traversed_time += 30 * 16000 if traversed_time == 0 else chunk.size(-1) - 16000 * overlapping_time
    return torch.cat(feature_chunks, dim=1)


def _match_f0_to_reference(F0_ori, F0_alt, pitch):
    """Shift the source F0 contour ``F0_alt`` into the reference's register.

    Voiced frames (F0 > 1) of ``F0_alt`` are offset in log space so their
    median matches the median of the voiced frames of ``F0_ori``; then
    ``pitch`` semitones are applied to those voiced frames. Unvoiced frames
    come back as ``F0 + 1e-5`` (log/exp round trip).

    If either contour has no voiced frames, the median is undefined
    (``torch.median`` of an empty tensor gives NaN), so the automatic match
    is skipped rather than filling the source contour with NaN.
    """
    voiced_alt = F0_alt > 1
    voiced_ori = F0_ori > 1
    log_f0_alt = torch.log(F0_alt + 1e-5)
    shifted_log_f0_alt = log_f0_alt.clone()
    if bool(voiced_alt.any()) and bool(voiced_ori.any()):
        median_log_f0_ori = torch.median(torch.log(F0_ori[voiced_ori] + 1e-5))
        median_log_f0_alt = torch.median(torch.log(F0_alt[voiced_alt] + 1e-5))
        shifted_log_f0_alt[voiced_alt] = (
            log_f0_alt[voiced_alt] - median_log_f0_alt + median_log_f0_ori
        )
    else:
        logger.warning(
            "No voiced frames in source or reference F0; skipping automatic F0 matching"
        )
    shifted_f0_alt = torch.exp(shifted_log_f0_alt)
    if pitch != 0:
        shifted_f0_alt[voiced_alt] = adjust_f0_semitones(shifted_f0_alt[voiced_alt], pitch)
    return shifted_f0_alt


@torch.inference_mode()
def _convert_voice_impl(audio_path, reference_path, pitch, diffusion_steps, similarity=0.7):
    """Actual conversion implementation (called from GPU-decorated wrapper).

    Pipeline:
    1. Load the source and the reference (clipped to 25 s) at the model
       sample rate.
    2. At 16 kHz, extract Whisper semantic features, the CAMPPlus speaker
       embedding and RMVPE F0.
    3. Shift the source F0 so its voiced median matches the reference, plus
       ``pitch`` semitones.
    4. Run the DiT window by window with the reference mel as prompt, then
       vocode.
    5. Crossfade the windows, normalise loudness and write a 16-bit WAV.

    Args:
        audio_path: source audio file.
        reference_path: target-voice reference audio file.
        pitch: transposition in semitones on top of the automatic F0 match.
        diffusion_steps: number of steps passed to ``cfm.inference``.
        similarity: forwarded as ``inference_cfg_rate`` (presumably the
            classifier-free-guidance rate — confirm in the DiT module).

    Returns:
        Path of ``OUTPUT_DIR/<source basename>_converted.wav``.

    Raises:
        ValueError: if the source or the reference decodes to zero samples.
    """
    import soundfile as sf

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    base_name = os.path.splitext(os.path.basename(audio_path))[0]
    # NOTE(review): concurrent conversions of inputs sharing a basename write
    # to the same output file.
    output_path = os.path.join(OUTPUT_DIR, "{}_converted.wav".format(base_name))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info("Converting voice with Seed-VC on {}".format(device))

    # Load models
    cache = _load_seed_vc_models(device)
    inference_module = cache["model"]
    semantic_fn = cache["semantic_fn"]
    vocoder_fn = cache["vocoder_fn"]
    campplus_model = cache["campplus_model"]
    f0_fn = cache["f0_fn"]
    mel_fn = cache["to_mel"]
    sr = cache["sr"]
    max_context_window = cache["max_context_window"]
    overlap_frame_len = cache["overlap_frame_len"]
    overlap_wave_len = cache["overlap_wave_len"]

    # Load audio at the model rate; the reference is clipped to 25 s as in
    # the official code.
    source_np = librosa.load(audio_path, sr=sr)[0]
    ref_np = librosa.load(reference_path, sr=sr)[0][:sr * 25]
    if source_np.size == 0:
        raise ValueError("Source audio is empty: {}".format(audio_path))
    if ref_np.size == 0:
        raise ValueError("Reference audio is empty: {}".format(reference_path))
    source_audio = torch.tensor(source_np).unsqueeze(0).float().to(device)
    ref_audio = torch.tensor(ref_np).unsqueeze(0).float().to(device)

    # Whisper features, the speaker fbank and RMVPE below all use 16 kHz audio.
    ref_waves_16k = torchaudio.functional.resample(ref_audio, sr, 16000)
    converted_waves_16k = torchaudio.functional.resample(source_audio, sr, 16000)

    # Semantic features (the source may exceed Whisper's 30 s window).
    S_alt = _extract_semantic_long(semantic_fn, converted_waves_16k)
    S_ori = semantic_fn(ref_waves_16k)

    # Mel spectrograms; the reference mel is used as the DiT prompt.
    mel = mel_fn(source_audio)
    mel2 = mel_fn(ref_audio)
    target_lengths = torch.LongTensor([mel.size(2)]).to(mel.device)
    target2_lengths = torch.LongTensor([mel2.size(2)]).to(mel2.device)

    # Speaker embedding from the mean-normalised 80-bin fbank of the reference.
    feat2 = torchaudio.compliance.kaldi.fbank(
        ref_waves_16k,
        num_mel_bins=80,
        dither=0,
        sample_frequency=16000,
    )
    feat2 = feat2 - feat2.mean(dim=0, keepdim=True)
    style2 = campplus_model(feat2.unsqueeze(0))

    # F0 extraction, automatic register matching and manual transposition.
    F0_ori = torch.from_numpy(f0_fn(ref_waves_16k[0], thred=0.03)).to(device)[None]
    F0_alt = torch.from_numpy(f0_fn(converted_waves_16k[0], thred=0.03)).to(device)[None]
    shifted_f0_alt = _match_f0_to_reference(F0_ori, F0_alt, pitch)

    # Length regulation to the mel frame counts given by ylens.
    cond = inference_module.length_regulator(
        S_alt, ylens=target_lengths, n_quantizers=3, f0=shifted_f0_alt
    )[0]
    prompt_condition = inference_module.length_regulator(
        S_ori, ylens=target2_lengths, n_quantizers=3, f0=F0_ori
    )[0]

    # The reference prompt occupies part of each context window; the rest
    # is available for source frames.
    max_source_window = max_context_window - mel2.size(2)

    # Generate window by window. Consecutive windows overlap by
    # overlap_frame_len frames and are joined by crossfading
    # overlap_wave_len samples.
    processed_frames = 0
    generated_wave_chunks = []
    previous_chunk = None
    while processed_frames < cond.size(1):
        chunk_cond = cond[:, processed_frames:processed_frames + max_source_window]
        is_last_chunk = processed_frames + max_source_window >= cond.size(1)
        cat_condition = torch.cat([prompt_condition, chunk_cond], dim=1)
        with torch.autocast(device_type=device.type, dtype=torch.float16):
            vc_target = inference_module.cfm.inference(
                cat_condition,
                torch.LongTensor([cat_condition.size(1)]).to(mel2.device),
                mel2, style2, None, diffusion_steps,
                inference_cfg_rate=similarity,
            )
            # Drop the frames generated for the reference prompt.
            vc_target = vc_target[:, :, mel2.size(-1):]
        vc_wave = vocoder_fn(vc_target.float()).squeeze().cpu()
        if vc_wave.ndim == 1:
            vc_wave = vc_wave.unsqueeze(0)

        if processed_frames == 0:
            if is_last_chunk:
                generated_wave_chunks.append(vc_wave[0].cpu().numpy())
                break
            generated_wave_chunks.append(vc_wave[0, :-overlap_wave_len].cpu().numpy())
            previous_chunk = vc_wave[0, -overlap_wave_len:]
            processed_frames += vc_target.size(2) - overlap_frame_len
        elif is_last_chunk:
            output_wave = crossfade(
                previous_chunk.cpu().numpy(),
                vc_wave[0].cpu().numpy(),
                overlap_wave_len,
            )
            generated_wave_chunks.append(output_wave)
            break
        else:
            output_wave = crossfade(
                previous_chunk.cpu().numpy(),
                vc_wave[0, :-overlap_wave_len].cpu().numpy(),
                overlap_wave_len,
            )
            generated_wave_chunks.append(output_wave)
            previous_chunk = vc_wave[0, -overlap_wave_len:]
            processed_frames += vc_target.size(2) - overlap_frame_len

    # Concatenate and normalize to -18 dBFS RMS (standard vocal level before mixing)
    audio_out = np.concatenate(generated_wave_chunks)
    rms = np.sqrt(np.mean(audio_out ** 2))
    target_rms = 10 ** (-18.0 / 20.0)  # -18 dBFS
    if rms > 1e-6:
        audio_out = audio_out * (target_rms / rms)
    # Safety clip to prevent any overflow.
    # NOTE(review): after RMS normalisation, peaks more than ~18 dB above
    # the RMS level are hard-clipped here; consider a limiter if audible.
    audio_out = np.clip(audio_out, -0.99, 0.99)

    # Save
    sf.write(output_path, audio_out, sr, subtype="PCM_16")
    logger.info("Conversion complete: {} ({:.1f}s)".format(output_path, len(audio_out) / sr))
    return output_path