"""
Step 1b: Separate vocals from accompaniment using Demucs (Python API).
In-process inference so ZeroGPU can intercept the GPU allocation via
`@spaces.GPU`. Works on CUDA, MPS, and CPU without code changes.
Only runs when preserve_music=True.
"""
import shutil
import subprocess
from pathlib import Path
import torch
import torchaudio
import spaces
# Process-wide cache for the htdemucs model; stays None until _get_model()
# loads it on first use.
_MODEL = None
def _select_device() -> str:
if torch.cuda.is_available():
return "cuda"
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return "mps"
return "cpu"
def _get_model():
    """Return the shared htdemucs model, loading it on CPU on first request.

    The demucs import and the model load are deferred to the first call, so
    importing this module does not trigger them. The loaded model is put in
    eval mode and cached in the module-level `_MODEL`.
    """
    global _MODEL
    if _MODEL is not None:
        return _MODEL

    from demucs.pretrained import get_model

    print("[s1b] Loading htdemucs on cpu...")
    loaded = get_model("htdemucs")
    loaded.eval()
    loaded.to("cpu")
    _MODEL = loaded
    return _MODEL
@spaces.GPU(duration=120)
def _apply_demucs(mix: torch.Tensor, device: str) -> torch.Tensor:
    """Run htdemucs separation on `mix` and return the stems on CPU.

    Args:
        mix: Input mixture, shape [1, channels, time].
        device: Device name to run inference on (e.g. "cuda", "mps", "cpu").

    Returns:
        Separated sources as a CPU tensor shaped
        [batch, sources, channels, time].
    """
    from demucs.apply import apply_model

    net = _get_model()
    current_device = next(net.parameters()).device.type
    if current_device != device:
        # The cached model is loaded on CPU; relocate it only when needed.
        print(f"[s1b] Moving htdemucs to {device} inside GPU scope...")
        net = net.to(device)

    inference_opts = dict(shifts=1, split=True, overlap=0.25, device=device)
    with torch.no_grad():
        stems = apply_model(net, mix.to(device), **inference_opts)
    return stems.cpu()
def _load_and_normalise(
    audio_hq_path: str,
    target_sr: int,
    target_ch: int,
) -> tuple[torch.Tensor, float, float]:
    """Load an audio file and prepare it as a normalised model input.

    The waveform is resampled to `target_sr` when its sample rate differs,
    mono is duplicated to stereo when `target_ch` is 2, and surplus channels
    are dropped. The result is shifted and scaled by its overall mean and
    standard deviation.

    Returns:
        (batched_waveform, mean, std), where `batched_waveform` has a leading
        batch dimension of 1 and `mean`/`std` are the values needed to undo
        the normalisation.
    """
    waveform, native_sr = torchaudio.load(audio_hq_path)
    if native_sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, native_sr, target_sr)

    n_channels = waveform.shape[0]
    if n_channels == 1 and target_ch == 2:
        waveform = waveform.repeat(2, 1)
    elif n_channels > target_ch:
        waveform = waveform[:target_ch]

    offset = waveform.mean()
    # Floor the scale so silent input does not divide by zero.
    scale = waveform.std().clamp_min(1e-8)
    batched = ((waveform - offset) / scale).unsqueeze(0)
    return batched, offset.item(), scale.item()
def _resample_to_16k_mono(src_path: str, dst_path: str) -> None:
    """Convert `src_path` to a 16 kHz mono file at `dst_path` using ffmpeg.

    Raises:
        RuntimeError: If the ffmpeg executable cannot be found, or if ffmpeg
            exits with a non-zero status.
    """
    cmd = [
        "ffmpeg", "-y",
        "-i", src_path,
        "-ar", "16000",
        "-ac", "1",
        dst_path,
    ]
    try:
        # stderr is only used for diagnostics, so undecodable bytes are
        # replaced rather than allowed to raise UnicodeDecodeError.
        result = subprocess.run(
            cmd, capture_output=True, text=True, errors="replace"
        )
    except FileNotFoundError as e:
        # Report a missing binary the same way as any other ffmpeg failure.
        raise RuntimeError(
            "FFmpeg vocals resample failed:\n'ffmpeg' executable not found on PATH"
        ) from e
    if result.returncode != 0:
        raise RuntimeError(f"FFmpeg vocals resample failed:\n{result.stderr}")


def separate_audio(
    audio_hq_path: str,
    output_dir: str = "tmp",
) -> tuple[str, str]:
    """
    Separate vocals from accompaniment using Demucs htdemucs (Python API).

    Writes `vocals.wav`, `accompaniment.wav` and `vocals_16k.wav` into
    `output_dir` (created if missing) and removes a leftover `demucs`
    subdirectory if one exists there.

    Args:
        audio_hq_path: Path to input audio (any sample rate / channels).
        output_dir: Directory to write output stems.

    Returns:
        (vocals_16k_path, accompaniment_path)

    Raises:
        RuntimeError: If the model has no 'vocals' source, or if the ffmpeg
            resample step fails (including ffmpeg not being installed).
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    model = _get_model()
    device = _select_device()
    target_sr = model.samplerate
    target_ch = model.audio_channels
    source_names = list(model.sources)

    print(f"[s1b] Running Demucs htdemucs on {device} (Python API)...")
    mix, mean, std = _load_and_normalise(audio_hq_path, target_sr, target_ch)
    sources = _apply_demucs(mix, device)
    # Undo the input normalisation applied before inference.
    sources = sources * std + mean
    sources = sources[0]  # drop batch dim → [num_sources, channels, time]

    try:
        vocals_idx = source_names.index("vocals")
    except ValueError as e:
        raise RuntimeError(f"htdemucs is missing 'vocals' source: {source_names}") from e

    vocals = sources[vocals_idx]
    # Accompaniment is the sum of every stem that is not vocals.
    no_vocals = sum(
        sources[i] for i in range(sources.shape[0]) if i != vocals_idx
    )

    vocals_path = str(out / "vocals.wav")
    accompaniment_path = str(out / "accompaniment.wav")
    vocals_16k_path = str(out / "vocals_16k.wav")

    torchaudio.save(vocals_path, vocals, target_sr)
    torchaudio.save(accompaniment_path, no_vocals, target_sr)
    print(f"[s1b] Vocals saved → {vocals_path}")
    print(f"[s1b] Accompaniment saved → {accompaniment_path}")

    # Resample vocals to 16 kHz mono for Whisper/TTS via ffmpeg
    # (torchaudio resample works but ffmpeg is more predictable for downstream)
    _resample_to_16k_mono(vocals_path, vocals_16k_path)
    print(f"[s1b] Vocals (16 kHz) saved → {vocals_16k_path}")

    # Leftover cleanup for any previously-shelled-out demucs runs
    old_demucs_dir = out / "demucs"
    if old_demucs_dir.exists():
        shutil.rmtree(str(old_demucs_dir), ignore_errors=True)

    return vocals_16k_path, accompaniment_path
|