rvc / pipeline /training.py
ibcplateformes
Add voice similarity control + improve reference audio processing
969158e
"""
Voice model creation: save a reference audio clip for Seed-VC zero-shot conversion.
No neural network training needed - Seed-VC uses in-context learning from
reference audio at inference time.
"""
import os
import logging
import shutil
logger = logging.getLogger(__name__)

try:
    import spaces
except ImportError:
    class spaces:
        """Local stand-in for the Hugging Face ``spaces`` package.

        Used only when ``spaces`` is not installed (e.g. running outside a
        ZeroGPU Space). Only ``spaces.GPU`` is provided, as a no-op decorator,
        so functions decorated for ZeroGPU still run unchanged locally.
        """

        @staticmethod
        def GPU(duration=60, **kwargs):
            """No-op replacement for ``spaces.GPU``.

            Supports both decorator forms:

            * ``@spaces.GPU(duration=...)`` -> returns a decorator that
              returns the function unchanged;
            * bare ``@spaces.GPU`` -> the function itself arrives as the first
              positional argument and is returned unchanged.

            Args:
                duration: GPU time budget (ignored here), or the decorated
                    function when used without parentheses.
                **kwargs: Any other ``spaces.GPU`` options (ignored).
            """
            # Bare-decorator form: without this check the decorated name would
            # be rebound to the inner ``decorator`` instead of the function.
            if callable(duration):
                return duration

            def decorator(fn):
                return fn

            return decorator
# Placeholder decorated with @spaces.GPU so that ZeroGPU detects a GPU
# function when the app starts; it is not meant to do real work.
@spaces.GPU(duration=10)
def _gpu_warmup():
    """Return whether torch reports CUDA as available (False if it cannot say)."""
    import torch

    if not hasattr(torch.cuda, "is_available"):
        return False
    return torch.cuda.is_available()
def _preprocess_reference(audio, sr, max_seconds=25, target_dbfs=-16.0, peak_ceiling=0.99):
    """Clean up a mono reference clip: trim silence, cap length, filter, normalize.

    Args:
        audio: 1-D float audio array (as returned by ``librosa.load(mono=True)``).
        sr: Sample rate of ``audio`` in Hz.
        max_seconds: Maximum kept length; longer clips are truncated.
        target_dbfs: Target RMS level in dBFS.
        peak_ceiling: Absolute peak limit applied after normalization.

    Returns:
        The processed 1-D audio array.
    """
    import librosa
    import numpy as np

    # 1. Trim leading/trailing silence (aggressive: top_db=20), but keep the
    #    untrimmed audio if trimming would leave 2 seconds or less.
    audio_trimmed, _ = librosa.effects.trim(audio, top_db=20)
    if len(audio_trimmed) > sr * 2:
        audio = audio_trimmed

    # 2. Limit length (Seed-VC clips the reference to 25s internally).
    max_samples = max_seconds * sr
    if len(audio) > max_samples:
        audio = audio[:max_samples]
        logger.info("Trimmed reference to %ss (Seed-VC effective max).", max_seconds)

    # 3. Best-effort cleanup chain: 80 Hz high-pass against low-frequency
    #    noise, light compression to even out the voice level, +1 dB gain.
    #    Skipped (with a warning) if pedalboard is missing or fails.
    try:
        from pedalboard import Pedalboard, HighpassFilter, Compressor, Gain

        ref_board = Pedalboard([
            HighpassFilter(cutoff_frequency_hz=80.0),
            Compressor(threshold_db=-20.0, ratio=2.0, attack_ms=10.0, release_ms=150.0),
            Gain(gain_db=1.0),
        ])
        # pedalboard works on (channels, samples) float32 arrays.
        processed = ref_board(audio.reshape(1, -1).astype(np.float32), sr)
        audio = processed.squeeze()
    except Exception as e:
        logger.warning("Pedalboard processing skipped: %s", e)

    # 4. RMS-normalize toward target_dbfs (slightly louder than converted
    #    vocals, to give the speaker embedding model a strong signal). The gain
    #    is capped so the peak stays under peak_ceiling: clipping after a large
    #    RMS gain would hard-clip transients and bake distortion into the
    #    reference.
    rms = float(np.sqrt(np.mean(audio.astype(np.float64) ** 2)))
    if rms > 1e-6:
        gain = 10 ** (target_dbfs / 20.0) / rms
        peak = float(np.max(np.abs(audio)))
        if peak * gain > peak_ceiling:
            gain = peak_ceiling / peak
        audio = audio * gain
    # Safety net; normally a no-op after the peak-capped gain above.
    return np.clip(audio, -peak_ceiling, peak_ceiling)


def save_voice_reference(
    audio_path,
    model_name,
    progress_callback=None,
):
    """
    Save a voice reference audio clip as the user's 'voice model'.

    With Seed-VC, no training is needed. The processed reference audio is used
    directly at inference time for zero-shot voice conversion.

    The clip is loaded as 44.1 kHz mono, must be at least 2 seconds long,
    has leading/trailing silence trimmed, is truncated to 25 seconds,
    passes through a best-effort high-pass/compression chain, and is
    loudness-normalized without clipping.

    Args:
        audio_path: Path to the uploaded voice recording.
        model_name: Name for the voice model. It is used as a directory and
            file name under LOCAL_MODELS_DIR, so it must be a single path
            component (no path separators, not "." or "..").
        progress_callback: Optional callable ``(fraction, message)`` for
            progress updates.

    Returns:
        (marker_path, reference_path): path of the ``.pth`` marker file and
        path of the saved reference ``.wav`` file.

    Raises:
        RuntimeError: if ``model_name`` is not a valid single path component,
            or if the loaded audio is shorter than 2 seconds.
    """
    import librosa
    import soundfile as sf
    from pipeline.storage import LOCAL_MODELS_DIR, upload_model

    # model_name ends up in filesystem paths; reject anything that could
    # escape LOCAL_MODELS_DIR (e.g. "../x" or an absolute path).
    if (
        not model_name
        or model_name in (".", "..")
        or os.path.basename(model_name) != model_name
    ):
        raise RuntimeError("Nom de modele invalide : {!r}".format(model_name))

    def _progress(fraction, message):
        if progress_callback:
            progress_callback(fraction, message)

    _progress(0.1, "Chargement de l'audio...")
    audio, sr = librosa.load(audio_path, sr=44100, mono=True)
    duration = len(audio) / sr
    logger.info("Reference audio: %.1fs at %dHz", duration, sr)
    if duration < 2.0:
        raise RuntimeError(
            "Audio trop court ({:.1f}s). Minimum 3 secondes recommande.".format(duration)
        )

    _progress(0.3, "Optimisation de la reference vocale...")
    audio = _preprocess_reference(audio, sr)
    final_duration = len(audio) / sr

    _progress(0.6, "Sauvegarde de la reference vocale...")
    local_model_dir = os.path.join(LOCAL_MODELS_DIR, model_name)
    os.makedirs(local_model_dir, exist_ok=True)
    ref_filename = "{}_ref.wav".format(model_name)
    reference_path = os.path.join(local_model_dir, ref_filename)
    # Write at the same rate the audio was loaded and processed at.
    sf.write(reference_path, audio, sr, subtype="PCM_16")

    # Also save a .pth marker for compatibility with storage/listing code.
    import torch

    marker_path = os.path.join(local_model_dir, "{}.pth".format(model_name))
    torch.save({
        "type": "seed_vc_reference",
        "reference_audio": ref_filename,
        "duration": final_duration,
        "sample_rate": sr,
    }, marker_path)

    _progress(0.8, "Upload vers HuggingFace...")
    try:
        upload_model(model_name, marker_path, reference_path=reference_path)
    except Exception as e:
        # Deliberately non-fatal: the local files above are already written.
        logger.warning("Failed to upload to HF (non-critical): %s", e)

    _progress(1.0, "Reference vocale sauvegardee !")
    logger.info("Voice reference saved: %s (%.1fs)", reference_path, final_duration)
    return marker_path, reference_path