Spaces:
Running on Zero
Running on Zero
| """ | |
| Audio source separation tool — three modes via Demucs. | |
| Reuses internals from steps.s1b_separate (model loader, device picker, normaliser, | |
| GPU-decorated apply). The existing separate_audio() returns only (vocals, accompaniment), | |
| so we replicate its flow here and keep all four stems addressable. | |
| """ | |
| from __future__ import annotations | |
| import subprocess | |
| from pathlib import Path | |
| from typing import Literal | |
| import torch | |
| import torchaudio | |
| # Reuse internals — no edits to s1b_separate.py. | |
| from steps.s1b_separate import ( | |
| _apply_demucs, | |
| _get_model, | |
| _load_and_normalise, | |
| _select_device, | |
| ) | |
# Separation modes accepted by ``separate()``.
Mode = Literal[
    "vocals-only",
    "instrumental-only",
    "stems",
]
def _ensure_audio(input_path: Path, out_dir: Path, *, timeout: float = 300) -> Path:
    """Return a WAV file Demucs can load, converting the input with ffmpeg if needed.

    Inputs whose suffix is ``.wav`` (any case) are returned unchanged. Anything
    else (video containers, compressed audio) is decoded by ffmpeg into
    ``out_dir / "input.wav"`` as 44.1 kHz, 2-channel, 16-bit PCM, with any
    video stream dropped (``-vn``).

    Args:
        input_path: Media file to prepare.
        out_dir: Directory that receives ``input.wav`` when a conversion runs.
        timeout: Seconds to wait for ffmpeg before giving up (default 300).

    Returns:
        ``input_path`` itself for WAV input, otherwise the converted file.

    Raises:
        RuntimeError: ffmpeg is not installed, exits non-zero, or exceeds
            ``timeout``.
    """
    if input_path.suffix.lower() == ".wav":
        return input_path
    out = out_dir / "input.wav"
    cmd = [
        "ffmpeg", "-y", "-i", str(input_path),
        "-vn", "-ac", "2", "-ar", "44100",
        "-acodec", "pcm_s16le",
        str(out),
    ]
    try:
        result = subprocess.run(
            cmd,
            # ffmpeg reads stdin for interactive keys; never let it consume ours.
            stdin=subprocess.DEVNULL,
            capture_output=True,
            text=True,
            # stderr echoes container metadata, which need not be valid in the
            # locale encoding; strict decoding would raise UnicodeDecodeError.
            errors="replace",
            timeout=timeout,
        )
    except FileNotFoundError as err:
        raise RuntimeError("ffmpeg input prep failed: ffmpeg executable not found") from err
    except subprocess.TimeoutExpired as err:
        raise RuntimeError(f"ffmpeg input prep failed: timed out after {timeout}s") from err
    if result.returncode != 0:
        # Only the tail of stderr carries the actual error line.
        raise RuntimeError(f"ffmpeg input prep failed: {result.stderr[-300:]}")
    return out
def _separate_all_stems(audio_path: Path, out_dir: Path) -> dict[str, Path]:
    """Separate ``audio_path`` with Demucs and save each source as its own WAV.

    Every entry of ``model.sources`` is written to ``out_dir / "<source>.wav"``
    at the model's sample rate.

    Returns:
        ``{source_name: wav_path}`` in ``model.sources`` order.
    """
    model = _get_model()
    device = _select_device()
    # Source names come from the loaded model, e.g. ["drums", "bass", "other", "vocals"].
    names = list(model.sources)
    rate = model.samplerate

    normalised_mix, mean, std = _load_and_normalise(
        str(audio_path), rate, model.audio_channels
    )
    # Rescale with the loader's mean/std (inverse of its normalisation), then
    # drop the leading batch axis -> [num_sources, channels, time].
    per_source = (_apply_demucs(normalised_mix, device) * std + mean)[0]

    written: dict[str, Path] = {}
    for position, name in enumerate(names):
        destination = out_dir / f"{name}.wav"
        torchaudio.save(str(destination), per_source[position], rate)
        written[name] = destination
    return written
def _sum_to_wav(stems: list[Path], dest: Path, sample_rate: int = 44100) -> Path:
    """Mix stem WAVs into one file by sample-wise addition (builds the instrumental track).

    Args:
        stems: WAV files to add together. All must share one sample rate and one
            (channels, frames) shape; otherwise the output would be written at
            the wrong rate, or mono would be silently broadcast onto stereo.
        dest: Output WAV path.
        sample_rate: Not used for the output. The file is written at the
            inputs' own rate. Kept so existing call sites keep working.

    Returns:
        ``dest``.

    Raises:
        RuntimeError: ``stems`` is empty, or the inputs disagree on sample rate
            or tensor shape.
    """
    mix: torch.Tensor | None = None
    mix_sr = sample_rate
    for path in stems:
        wav, sr = torchaudio.load(str(path))
        if mix is None:
            mix, mix_sr = wav, sr
            continue
        if sr != mix_sr:
            raise RuntimeError(
                f"Cannot sum stems: {path} is {sr} Hz but earlier stems are {mix_sr} Hz."
            )
        if wav.shape != mix.shape:
            raise RuntimeError(
                f"Cannot sum stems: {path} has shape {tuple(wav.shape)}, "
                f"expected {tuple(mix.shape)}."
            )
        mix = mix + wav
    if mix is None:
        raise RuntimeError("No stems to sum.")
    torchaudio.save(str(dest), mix, mix_sr)
    return dest
# (label, sub-caption) displayed for each Demucs source. Insertion order is the
# order stems are listed in "stems" mode: vocals first, then drums, bass, other.
_STEM_DISPLAY: dict[str, tuple[str, str]] = {
    "vocals": ("Vocals", "Dialogue track"),
    "drums": ("Drums", "Percussion"),
    "bass": ("Bass", "Low frequency"),
    "other": ("Other", "Melodic / ambient"),
}

# Runtime counterpart of the ``Mode`` type alias; keep the two in sync.
_VALID_MODES: tuple[str, ...] = ("vocals-only", "instrumental-only", "stems")


def _discard(paths: list[Path]) -> None:
    """Best-effort removal of intermediate files that are not returned to the caller."""
    for path in paths:
        try:
            path.unlink()
        except OSError:
            # A stray temp file is harmless; never fail a finished separation over it.
            pass


def _stem_descriptor(stem_key: str, filename: str) -> dict:
    """Build the output descriptor for one Demucs source stem."""
    label, sub = _STEM_DISPLAY[stem_key]
    return {"name": stem_key, "label": label, "filename": filename, "sub": sub}


def separate(
    *,
    input_path: Path,
    out_dir: Path,
    mode: Mode,
) -> list[dict]:
    """
    Run Demucs separation on ``input_path`` and write the requested outputs to ``out_dir``.

    Modes:
        "vocals-only"       -> the vocals stem.
        "instrumental-only" -> every non-vocal stem summed into ``instrumental.wav``.
        "stems"             -> each known stem, ordered vocals, drums, bass, other.

    Stem files that are not listed in the return value are deleted on a
    best-effort basis. So is the ffmpeg-converted ``input.wav``, but only
    when a conversion actually happened.

    Returns a list of output descriptors, e.g.::

        [{"name": "vocals", "label": "Vocals",
          "filename": "vocals.wav", "sub": "Dialogue track"}, ...]

    Raises:
        ValueError: ``mode`` is not a supported mode. This is checked before
            any audio work starts.
        RuntimeError: e.g. input conversion or stem mixing failed.
    """
    # Validate first: the separation itself is expensive, and a typo must not
    # silently fall through to "stems".
    if mode not in _VALID_MODES:
        raise ValueError(
            f"Unknown separation mode {mode!r}; expected one of: "
            + ", ".join(_VALID_MODES)
        )

    audio_in = _ensure_audio(input_path, out_dir)
    try:
        stems = _separate_all_stems(audio_in, out_dir)
    finally:
        # A converted WAV only feeds Demucs; never delete the caller's own input.
        if audio_in != input_path:
            _discard([audio_in])

    if mode == "vocals-only":
        results = [_stem_descriptor("vocals", stems["vocals"].name)]
    elif mode == "instrumental-only":
        non_vocal_stems = [path for name, path in stems.items() if name != "vocals"]
        out = _sum_to_wav(non_vocal_stems, out_dir / "instrumental.wav")
        results = [{
            "name": "instrumental",
            "label": "Instrumental",
            "filename": out.name,
            "sub": "Music + ambient (vocals removed)",
        }]
    else:  # "stems"
        results = [
            _stem_descriptor(stem_key, stems[stem_key].name)
            for stem_key in _STEM_DISPLAY
            if stem_key in stems
        ]

    # Same policy in every mode: only files we hand back stay on disk.
    exposed = {entry["filename"] for entry in results}
    _discard([path for path in stems.values() if path.name not in exposed])
    return results