"""
Step 1b: Separate vocals from accompaniment using Demucs (Python API).

In-process inference so ZeroGPU can intercept the GPU allocation via
`@spaces.GPU`. Works on CUDA, MPS, and CPU without code changes.
Only runs when preserve_music=True.
"""
import shutil
import subprocess
from pathlib import Path

import torch
import torchaudio
import spaces

# Process-wide cache for the htdemucs model; populated lazily by _get_model().
_MODEL = None


def _select_device() -> str:
    """Return the preferred torch device name: "cuda", then "mps", else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def _get_model():
    """Lazy-load htdemucs once per process.

    The model is loaded on the first call rather than at import time, so
    importing this module stays cheap on non-GPU environments. It is put in
    eval mode and placed on CPU; `_apply_demucs` moves it to the inference
    device when needed.
    """
    global _MODEL
    if _MODEL is None:
        from demucs.pretrained import get_model

        print("[s1b] Loading htdemucs on cpu...")
        model = get_model("htdemucs")
        model.eval()
        model.to("cpu")
        _MODEL = model
    return _MODEL


@spaces.GPU(duration=120)
def _apply_demucs(mix: torch.Tensor, device: str) -> torch.Tensor:
    """GPU-bound inference call.

    Args:
        mix: Normalised mixture, shape [1, channels, time].
        device: Torch device name to run inference on.

    Returns:
        Separated stems on CPU, shape [batch, sources, channels, time].
    """
    from demucs.apply import apply_model

    model = _get_model()
    if next(model.parameters()).device.type != device:
        print(f"[s1b] Moving htdemucs to {device} inside GPU scope...")
        # nn.Module.to() modifies the module in place, so the cached _MODEL
        # is moved as well and later calls skip this branch.
        model = model.to(device)
    with torch.no_grad():
        # apply_model returns [batch, sources, channels, time]
        sources = apply_model(
            model,
            mix.to(device),
            shifts=1,
            split=True,
            overlap=0.25,
            device=device,
        )
    # Hand results back on CPU before leaving the GPU-decorated scope.
    return sources.cpu()


def _convert_channels(wav: torch.Tensor, target_ch: int) -> torch.Tensor:
    """Adapt a [channels, time] waveform to `target_ch` channels.

    Mono input is duplicated to every target channel, a mono target is
    produced by averaging all input channels, and surplus channels are
    dropped.

    Raises:
        RuntimeError: if the input has fewer channels than requested and is
            not mono (there is no unambiguous way to up-mix it).
    """
    src_ch = wav.shape[0]
    if src_ch == target_ch:
        return wav
    if target_ch == 1:
        return wav.mean(dim=0, keepdim=True)
    if src_ch == 1:
        return wav.repeat(target_ch, 1)
    if src_ch > target_ch:
        return wav[:target_ch]
    raise RuntimeError(
        f"Cannot convert {src_ch}-channel audio to {target_ch} channels"
    )


def _load_and_normalise(
    audio_hq_path: str, target_sr: int, target_ch: int
) -> tuple[torch.Tensor, float, float]:
    """Load audio, resample/remix it to the model's format, and z-normalise.

    Args:
        audio_hq_path: Path to the input audio file.
        target_sr: Sample rate the model expects.
        target_ch: Channel count the model expects.

    Returns:
        (mix, mean, std): `mix` has shape [1, target_ch, time]; `mean` and
        `std` are the scalar statistics used for normalisation (std clamped
        to >= 1e-8 so silent input does not divide by zero).
    """
    wav, sr = torchaudio.load(audio_hq_path)
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)
    wav = _convert_channels(wav, target_ch)

    # Statistics are taken jointly over all channels and samples.
    mean = wav.mean()
    std = wav.std().clamp_min(1e-8)
    wav_norm = (wav - mean) / std
    return wav_norm.unsqueeze(0), mean.item(), std.item()


def _resample_to_16k_mono(src_path: str, dst_path: str) -> None:
    """Convert `src_path` to a 16 kHz mono file at `dst_path` using ffmpeg.

    Raises:
        RuntimeError: if ffmpeg is not on PATH or exits with a non-zero code.
    """
    cmd = [
        "ffmpeg", "-y",
        "-i", src_path,
        "-ar", "16000",
        "-ac", "1",
        dst_path,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except FileNotFoundError as e:
        raise RuntimeError(
            "FFmpeg vocals resample failed: ffmpeg executable not found"
        ) from e
    if result.returncode != 0:
        raise RuntimeError(f"FFmpeg vocals resample failed:\n{result.stderr}")


def separate_audio(
    audio_hq_path: str,
    output_dir: str = "tmp",
) -> tuple[str, str]:
    """
    Separate vocals from accompaniment using Demucs htdemucs (Python API).

    Writes vocals.wav and accompaniment.wav at the model sample rate, plus
    vocals_16k.wav (16 kHz mono), into `output_dir`.

    Args:
        audio_hq_path: Path to input audio (any sample rate / channels).
        output_dir: Directory to write output stems (created if missing).

    Returns:
        (vocals_16k_path, accompaniment_path)

    Raises:
        RuntimeError: if the model exposes no "vocals" source, the input's
            channel layout cannot be adapted, or the ffmpeg resample fails.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    model = _get_model()
    device = _select_device()
    target_sr = model.samplerate
    target_ch = model.audio_channels
    source_names = list(model.sources)

    print(f"[s1b] Running Demucs htdemucs on {device} (Python API)...")
    mix, mean, std = _load_and_normalise(audio_hq_path, target_sr, target_ch)
    sources = _apply_demucs(mix, device)
    # Undo the z-normalisation on every stem. NOTE: `mean` is added to each
    # stem, so a sum of k stems carries a k*mean DC offset (negligible when
    # the input's DC component is near zero).
    sources = sources * std + mean
    sources = sources[0]  # drop batch dim → [num_sources, channels, time]

    try:
        vocals_idx = source_names.index("vocals")
    except ValueError as e:
        raise RuntimeError(f"htdemucs is missing 'vocals' source: {source_names}") from e

    vocals = sources[vocals_idx]
    # Sum all non-vocal stems. Concatenating the slices and reducing along
    # dim 0 always yields a [channels, time] tensor (zeros if there are no
    # other stems), unlike Python's sum() which would start from int 0.
    no_vocals = torch.cat(
        (sources[:vocals_idx], sources[vocals_idx + 1:])
    ).sum(dim=0)

    vocals_path = str(out / "vocals.wav")
    accompaniment_path = str(out / "accompaniment.wav")
    vocals_16k_path = str(out / "vocals_16k.wav")

    torchaudio.save(vocals_path, vocals, target_sr)
    torchaudio.save(accompaniment_path, no_vocals, target_sr)
    print(f"[s1b] Vocals saved → {vocals_path}")
    print(f"[s1b] Accompaniment saved → {accompaniment_path}")

    # Resample vocals to 16 kHz mono for Whisper/TTS via ffmpeg
    # (torchaudio resample works but ffmpeg is more predictable for downstream)
    _resample_to_16k_mono(vocals_path, vocals_16k_path)
    print(f"[s1b] Vocals (16 kHz) saved → {vocals_16k_path}")

    # Leftover cleanup for any previously-shelled-out demucs runs
    old_demucs_dir = out / "demucs"
    if old_demucs_dir.exists():
        shutil.rmtree(str(old_demucs_dir), ignore_errors=True)

    return vocals_16k_path, accompaniment_path