Spaces:
Running on Zero
Running on Zero
"""
Step 1b: Separate vocals from accompaniment using Demucs (Python API).

Inference runs in-process so that ZeroGPU can intercept the GPU allocation
via `@spaces.GPU`. Works on CUDA, MPS, and CPU without code changes.
Only runs when preserve_music=True.
"""
| import shutil | |
| import subprocess | |
| from pathlib import Path | |
| import torch | |
| import torchaudio | |
| import spaces | |
# Process-wide cache for the Demucs model. Stays None until _get_model()
# performs the first load, after which the same object is reused.
_MODEL = None
def _select_device() -> str:
    """Return the torch device name to run on.

    Preference order is CUDA, then Apple MPS, then CPU.
    """
    if torch.cuda.is_available():
        return "cuda"

    # Older torch builds have no `torch.backends.mps` attribute at all.
    mps_backend = getattr(torch.backends, "mps", None)
    has_mps = mps_backend is not None and mps_backend.is_available()
    return "mps" if has_mps else "cpu"
def _get_model():
    """Return the shared htdemucs model, loading it on the first call.

    The first call imports demucs, fetches the pretrained "htdemucs" model,
    switches it to eval mode, places it on the CPU and stores it in the
    module-level ``_MODEL``. Every later call returns that cached object.
    The demucs import is deferred so importing this module stays cheap.
    """
    global _MODEL
    if _MODEL is not None:
        return _MODEL

    from demucs.pretrained import get_model as load_pretrained

    print("[s1b] Loading htdemucs on cpu...")
    net = load_pretrained("htdemucs")
    net.eval()
    net.to("cpu")
    _MODEL = net
    return _MODEL
def _apply_demucs(mix: torch.Tensor, device: str) -> torch.Tensor:
    """Run htdemucs on ``mix`` and return the separated sources on the CPU.

    Args:
        mix: Mixture tensor shaped ``[1, channels, time]``.
        device: Torch device name on which to run the computation.

    Returns:
        Tensor shaped ``[batch, sources, channels, time]``, on the CPU.
    """
    from demucs.apply import apply_model

    model = _get_model()
    if next(model.parameters()).device.type != device:
        print(f"[s1b] Moving htdemucs to {device} inside GPU scope...")
        model = model.to(device)

    with torch.no_grad():
        # `mix` is intentionally left on its current device. When `device`
        # differs from `mix.device`, apply_model runs only the per-chunk
        # computation on `device` and keeps the full track and the
        # [batch, sources, channels, time] output buffer on `mix.device`.
        # Moving the whole mixture to the GPU first would also place that
        # output buffer there and can exhaust GPU memory on long inputs.
        sources = apply_model(
            model,
            mix,
            shifts=1,
            split=True,
            overlap=0.25,
            device=device,
        )
    # No-op when the mixture was already on the CPU; kept so the return
    # contract holds for callers that pass a device-resident tensor.
    return sources.cpu()
def _match_channels(wav: torch.Tensor, target_ch: int) -> torch.Tensor:
    """Convert a channels-first ``[channels, time]`` tensor to ``target_ch`` channels."""
    src_ch = wav.shape[0]
    if src_ch == target_ch:
        return wav
    if target_ch == 1:
        # Average all channels rather than keeping only the first one, so
        # content present only in the other channels is not discarded.
        return wav.mean(dim=0, keepdim=True)
    if src_ch == 1:
        # Replicate the single channel across every requested channel.
        return wav.repeat(target_ch, 1)
    if src_ch > target_ch:
        # Surplus channels are dropped, not mixed down.
        return wav[:target_ch]
    # Multi-channel input with fewer channels than requested has no obvious
    # mapping; fail here instead of handing the model a wrong channel count.
    raise ValueError(
        f"Cannot convert {src_ch}-channel audio to {target_ch} channels"
    )


def _load_and_normalise(audio_hq_path: str, target_sr: int, target_ch: int) -> tuple[torch.Tensor, float, float]:
    """Load audio, match the requested sample rate / channels, z-normalise.

    Args:
        audio_hq_path: Path of the audio file to load.
        target_sr: Sample rate the audio is resampled to when it differs.
        target_ch: Number of channels the audio is converted to.

    Returns:
        ``(mix, mean, std)``: the normalised audio with a leading batch
        dimension added, plus the mean and standard deviation (as Python
        floats) that were used, so the caller can undo the normalisation.

    Raises:
        ValueError: If the file has more than one channel but fewer than
            ``target_ch``.
    """
    wav, sr = torchaudio.load(audio_hq_path)
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)
    wav = _match_channels(wav, target_ch)

    mean = wav.mean()
    # Floor the deviation so silent input does not cause a division by zero.
    std = wav.std().clamp_min(1e-8)
    wav_norm = (wav - mean) / std
    return wav_norm.unsqueeze(0), mean.item(), std.item()
def separate_audio(
    audio_hq_path: str,
    output_dir: str = "tmp",
) -> tuple[str, str]:
    """
    Split an audio file into a vocals stem and an accompaniment stem.

    Runs htdemucs in-process and writes three files into ``output_dir``:
    ``vocals.wav`` and ``accompaniment.wav`` at the model's sample rate, and
    ``vocals_16k.wav`` (16 kHz mono, converted by ffmpeg). The accompaniment
    is the sum of every non-vocal source. A leftover ``demucs`` sub-directory
    in ``output_dir`` is removed at the end.

    Args:
        audio_hq_path: Path to the input audio file.
        output_dir: Directory that receives the stems; created if missing.
    Returns:
        (vocals_16k_path, accompaniment_path)
    Raises:
        RuntimeError: If the model has no "vocals" source, or ffmpeg exits
            with a non-zero status.
    """
    work_dir = Path(output_dir)
    work_dir.mkdir(parents=True, exist_ok=True)

    demucs_model = _get_model()
    device = _select_device()
    sample_rate = demucs_model.samplerate
    channel_count = demucs_model.audio_channels
    stem_names = list(demucs_model.sources)

    print(f"[s1b] Running Demucs htdemucs on {device} (Python API)...")
    mix, mean, std = _load_and_normalise(audio_hq_path, sample_rate, channel_count)
    separated = _apply_demucs(mix, device)

    # Undo the z-normalisation, then strip the batch axis so the tensor is
    # [num_sources, channels, time].
    stems = (separated * std + mean)[0]

    try:
        vocals_idx = stem_names.index("vocals")
    except ValueError as e:
        raise RuntimeError(f"htdemucs is missing 'vocals' source: {stem_names}") from e

    vocals = stems[vocals_idx]
    no_vocals = sum(stem for idx, stem in enumerate(stems) if idx != vocals_idx)

    vocals_path = str(work_dir / "vocals.wav")
    accompaniment_path = str(work_dir / "accompaniment.wav")
    vocals_16k_path = str(work_dir / "vocals_16k.wav")

    torchaudio.save(vocals_path, vocals, sample_rate)
    torchaudio.save(accompaniment_path, no_vocals, sample_rate)
    print(f"[s1b] Vocals saved → {vocals_path}")
    print(f"[s1b] Accompaniment saved → {accompaniment_path}")

    # The 16 kHz mono copy for Whisper/TTS is produced with ffmpeg rather
    # than torchaudio, which the author found more predictable downstream.
    ffmpeg_cmd = ["ffmpeg", "-y", "-i", vocals_path, "-ar", "16000", "-ac", "1", vocals_16k_path]
    proc = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
    if proc.returncode != 0:
        raise RuntimeError(f"FFmpeg vocals resample failed:\n{proc.stderr}")
    print(f"[s1b] Vocals (16 kHz) saved → {vocals_16k_path}")

    # Remove the directory left behind by earlier shelled-out demucs runs.
    stale_dir = work_dir / "demucs"
    if stale_dir.exists():
        shutil.rmtree(str(stale_dir), ignore_errors=True)

    return vocals_16k_path, accompaniment_path