""" Audio source separation tool — three modes via Demucs. Reuses internals from steps.s1b_separate (model loader, device picker, normaliser, GPU-decorated apply). The existing separate_audio() returns only (vocals, accompaniment), so we replicate its flow here and keep all four stems addressable. """ from __future__ import annotations import subprocess from pathlib import Path from typing import Literal import torch import torchaudio # Reuse internals — no edits to s1b_separate.py. from steps.s1b_separate import ( _apply_demucs, _get_model, _load_and_normalise, _select_device, ) Mode = Literal["vocals-only", "instrumental-only", "stems"] def _ensure_audio(input_path: Path, out_dir: Path) -> Path: """Convert input to a stable WAV format if it's a video or non-WAV audio.""" if input_path.suffix.lower() == ".wav": return input_path out = out_dir / "input.wav" cmd = [ "ffmpeg", "-y", "-i", str(input_path), "-vn", "-ac", "2", "-ar", "44100", "-acodec", "pcm_s16le", str(out), ] result = subprocess.run(cmd, capture_output=True, text=True, timeout=300) if result.returncode != 0: raise RuntimeError(f"ffmpeg input prep failed: {result.stderr[-300:]}") return out def _separate_all_stems(audio_path: Path, out_dir: Path) -> dict[str, Path]: """Return {stem_name: wav_path} for every demucs source.""" model = _get_model() device = _select_device() target_sr = model.samplerate target_ch = model.audio_channels source_names = list(model.sources) # ["drums", "bass", "other", "vocals"] mix, mean, std = _load_and_normalise(str(audio_path), target_sr, target_ch) sources = _apply_demucs(mix, device) sources = sources * std + mean sources = sources[0] # [num_sources, channels, time] stems: dict[str, Path] = {} for idx, name in enumerate(source_names): wav_path = out_dir / f"{name}.wav" torchaudio.save(str(wav_path), sources[idx], target_sr) stems[name] = wav_path return stems def _sum_to_wav(stems: list[Path], dest: Path, sample_rate: int = 44100) -> Path: """Sum N stem WAVs into one — used to build the instrumental track.""" mix: torch.Tensor | None = None sr_used = sample_rate for path in stems: wav, sr = torchaudio.load(str(path)) sr_used = sr mix = wav if mix is None else mix + wav if mix is None: raise RuntimeError("No stems to sum.") torchaudio.save(str(dest), mix, sr_used) return dest def separate( *, input_path: Path, out_dir: Path, mode: Mode, ) -> list[dict]: """ Run separation. Returns a list of output descriptors: [{"name": "vocals.wav", "label": "Vocals", "filename": "vocals.wav"}, ...] """ audio_in = _ensure_audio(input_path, out_dir) stems = _separate_all_stems(audio_in, out_dir) if mode == "vocals-only": return [{ "name": "vocals", "label": "Vocals", "filename": stems["vocals"].name, "sub": "Dialogue track", }] if mode == "instrumental-only": non_vocal_stems = [stems[n] for n in stems if n != "vocals"] out = _sum_to_wav(non_vocal_stems, out_dir / "instrumental.wav") # Cleanup intermediate stem files we won't expose for path in stems.values(): try: path.unlink() except OSError: pass return [{ "name": "instrumental", "label": "Instrumental", "filename": out.name, "sub": "Music + ambient (vocals removed)", }] # stems mode — return all four label_map = { "vocals": ("Vocals", "Dialogue track"), "drums": ("Drums", "Percussion"), "bass": ("Bass", "Low frequency"), "other": ("Other", "Melodic / ambient"), } results: list[dict] = [] # Stable order: vocals first, then drums, bass, other for stem_key in ("vocals", "drums", "bass", "other"): if stem_key not in stems: continue label, sub = label_map[stem_key] results.append({ "name": stem_key, "label": label, "filename": stems[stem_key].name, "sub": sub, }) return results