videovoice / steps /s1b_separate.py
Rafii's picture
deploy: switch to chatterbox requirements @ 035108d
3474e83
"""
Step 1b: Separate vocals from accompaniment using Demucs (Python API).
In-process inference so ZeroGPU can intercept the GPU allocation via
`@spaces.GPU`. Works on CUDA, MPS, and CPU without code changes.
Only runs when preserve_music=True.
"""
import shutil
import subprocess
from pathlib import Path
import torch
import torchaudio
import spaces
# Process-wide cache for the htdemucs model; stays None until _get_model()
# performs the first (lazy) load.
_MODEL = None
def _select_device() -> str:
    """Return the name of the preferred torch device.

    Preference order: CUDA, then Apple-silicon MPS (only if this torch build
    exposes the ``mps`` backend), then CPU.
    """
    if torch.cuda.is_available():
        return "cuda"
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ready = mps_backend is not None and mps_backend.is_available()
    return "mps" if mps_ready else "cpu"
def _get_model():
    """Return the cached htdemucs model, building it on the first call.

    The demucs import and model construction are deferred until first use so
    that importing this module stays cheap on environments without a GPU.
    A freshly built model is switched to eval mode and placed on the CPU;
    callers may later move the cached instance to another device.
    """
    global _MODEL
    if _MODEL is not None:
        return _MODEL

    from demucs.pretrained import get_model

    print("[s1b] Loading htdemucs on cpu...")
    loaded = get_model("htdemucs")
    loaded.eval()
    loaded.to("cpu")
    _MODEL = loaded
    return _MODEL
@spaces.GPU(duration=120)
def _apply_demucs(mix: torch.Tensor, device: str) -> torch.Tensor:
    """Run htdemucs inference; decorated so ZeroGPU can grant a GPU for it.

    Args:
        mix: Normalised mixture shaped [1, channels, time].
        device: Torch device name to run inference on ("cuda", "mps", "cpu").

    Returns:
        Separated stems moved back to the CPU, shaped
        [batch, sources, channels, time] (the layout apply_model returns).
    """
    from demucs.apply import apply_model

    model = _get_model()
    current_device = next(model.parameters()).device.type
    if current_device != device:
        # The move happens here, inside the decorated call, rather than at
        # load time (the model is loaded on CPU).
        print(f"[s1b] Moving htdemucs to {device} inside GPU scope...")
        model = model.to(device)

    with torch.no_grad():
        stems = apply_model(
            model,
            mix.to(device),
            shifts=1,
            split=True,
            overlap=0.25,
            device=device,
        )
    return stems.cpu()
def _load_and_normalise(audio_hq_path: str, target_sr: int, target_ch: int) -> tuple[torch.Tensor, float, float]:
    """Load an audio file and turn it into a normalised model input.

    The waveform is resampled to ``target_sr``, converted to ``target_ch``
    channels, then z-normalised with one global mean/std so the caller can
    undo the scaling on the model output.

    Channel conversion follows the same rules as demucs'
    ``convert_audio_channels``:
      * same channel count     -> unchanged
      * target is mono         -> average all channels (downmix)
      * source is mono         -> duplicate it to every target channel
      * source has more        -> keep the first ``target_ch`` channels
      * anything else          -> ValueError (no sensible upmix exists)

    Args:
        audio_hq_path: Path to an audio file readable by torchaudio.
        target_sr: Sample rate the model expects.
        target_ch: Channel count the model expects.

    Returns:
        ``(mix, mean, std)``. ``mix`` is shaped [1, channels, time];
        ``mean`` and ``std`` are the scalars used for normalisation. ``std``
        is clamped to at least 1e-8 so silent input does not divide by zero.

    Raises:
        ValueError: if the file holds no samples, or its channel count cannot
            be mapped onto ``target_ch``.
    """
    wav, sr = torchaudio.load(audio_hq_path)
    if wav.shape[-1] == 0:
        # mean()/std() of an empty tensor are NaN and would poison the output.
        raise ValueError(f"No audio samples in {audio_hq_path}")
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)

    src_ch = wav.shape[0]
    if src_ch == target_ch:
        pass
    elif target_ch == 1:
        wav = wav.mean(dim=0, keepdim=True)
    elif src_ch == 1:
        wav = wav.repeat(target_ch, 1)
    elif src_ch > target_ch:
        wav = wav[:target_ch]
    else:
        raise ValueError(
            f"Cannot convert {src_ch}-channel audio to {target_ch} channels"
        )

    mean = wav.mean()
    std = wav.std().clamp_min(1e-8)
    wav_norm = (wav - mean) / std
    return wav_norm.unsqueeze(0), mean.item(), std.item()
def _resample_to_16k_mono(src_path: str, dst_path: str) -> None:
    """Convert ``src_path`` to a 16 kHz mono WAV at ``dst_path`` via ffmpeg.

    Raises:
        RuntimeError: if ffmpeg is not on PATH or exits with a non-zero code.
    """
    cmd = [
        "ffmpeg", "-y",
        "-i", src_path,
        "-ar", "16000",
        "-ac", "1",
        dst_path,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except FileNotFoundError as e:
        # subprocess raises this when the executable itself is missing;
        # report it with the same exception type as an ffmpeg failure.
        raise RuntimeError("FFmpeg not found on PATH; cannot resample vocals") from e
    if result.returncode != 0:
        raise RuntimeError(f"FFmpeg vocals resample failed:\n{result.stderr}")


def separate_audio(
    audio_hq_path: str,
    output_dir: str = "tmp",
) -> tuple[str, str]:
    """
    Separate vocals from accompaniment using Demucs htdemucs (Python API).

    The mixture is normalised, separated, and de-normalised. The "vocals"
    stem and the sum of every other stem are written as WAVs at the model's
    sample rate. The vocals are also converted to 16 kHz mono with ffmpeg.

    Args:
        audio_hq_path: Path to input audio (any sample rate / channels).
        output_dir: Directory to write output stems (created if missing).

    Returns:
        (vocals_16k_path, accompaniment_path)

    Raises:
        RuntimeError: if the model has no "vocals" stem, or has no other stem
            to form the accompaniment; if ffmpeg is missing; or if the ffmpeg
            resample fails.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    model = _get_model()
    device = _select_device()
    target_sr = model.samplerate
    target_ch = model.audio_channels
    source_names = list(model.sources)

    # Validate the stem layout up front, before spending GPU time on inference.
    try:
        vocals_idx = source_names.index("vocals")
    except ValueError as e:
        raise RuntimeError(f"htdemucs is missing 'vocals' source: {source_names}") from e
    other_idx = [i for i in range(len(source_names)) if i != vocals_idx]
    if not other_idx:
        # Builtin sum() over no stems would yield the int 0, not a tensor.
        raise RuntimeError(
            f"htdemucs has no non-vocal sources to build accompaniment: {source_names}"
        )

    print(f"[s1b] Running Demucs htdemucs on {device} (Python API)...")
    mix, mean, std = _load_and_normalise(audio_hq_path, target_sr, target_ch)
    sources = _apply_demucs(mix, device)
    sources = sources * std + mean  # undo the input z-normalisation
    sources = sources[0]  # drop batch dim → [num_sources, channels, time]

    vocals = sources[vocals_idx]
    no_vocals = sources[other_idx].sum(dim=0)

    vocals_path = str(out / "vocals.wav")
    accompaniment_path = str(out / "accompaniment.wav")
    vocals_16k_path = str(out / "vocals_16k.wav")

    torchaudio.save(vocals_path, vocals, target_sr)
    torchaudio.save(accompaniment_path, no_vocals, target_sr)
    print(f"[s1b] Vocals saved → {vocals_path}")
    print(f"[s1b] Accompaniment saved → {accompaniment_path}")

    # Resample vocals to 16 kHz mono for Whisper/TTS via ffmpeg
    # (torchaudio resample works but ffmpeg is more predictable for downstream)
    _resample_to_16k_mono(vocals_path, vocals_16k_path)
    print(f"[s1b] Vocals (16 kHz) saved → {vocals_16k_path}")

    # Leftover cleanup for any previously-shelled-out demucs runs
    old_demucs_dir = out / "demucs"
    if old_demucs_dir.exists():
        shutil.rmtree(str(old_demucs_dir), ignore_errors=True)

    return vocals_16k_path, accompaniment_path