Spaces:
Running on Zero
Running on Zero
File size: 4,348 Bytes
5b7cd5f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 | """
Audio source separation tool — three modes via Demucs.
Reuses internals from steps.s1b_separate (model loader, device picker, normaliser,
GPU-decorated apply). The existing separate_audio() returns only (vocals, accompaniment),
so we replicate its flow here and keep all four stems addressable.
"""
from __future__ import annotations
import subprocess
from pathlib import Path
from typing import Literal
import torch
import torchaudio
# Reuse internals — no edits to s1b_separate.py.
from steps.s1b_separate import (
_apply_demucs,
_get_model,
_load_and_normalise,
_select_device,
)
# Separation modes accepted by ``separate()``:
#   "vocals-only"       -> only the vocals stem is returned
#   "instrumental-only" -> every non-vocal stem summed into one track
#   "stems"             -> each stem returned as its own file
Mode = Literal[
    "vocals-only",
    "instrumental-only",
    "stems",
]
def _ensure_audio(input_path: Path, out_dir: Path) -> Path:
"""Convert input to a stable WAV format if it's a video or non-WAV audio."""
if input_path.suffix.lower() == ".wav":
return input_path
out = out_dir / "input.wav"
cmd = [
"ffmpeg", "-y", "-i", str(input_path),
"-vn", "-ac", "2", "-ar", "44100",
"-acodec", "pcm_s16le",
str(out),
]
result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
if result.returncode != 0:
raise RuntimeError(f"ffmpeg input prep failed: {result.stderr[-300:]}")
return out
def _separate_all_stems(audio_path: Path, out_dir: Path) -> dict[str, Path]:
    """Run one Demucs pass and save every source as ``out_dir/<source>.wav``.

    Returns a mapping from source name to the WAV written for it, in the
    order the model lists its sources.
    """
    demucs = _get_model()
    device = _select_device()
    sr, channels = demucs.samplerate, demucs.audio_channels
    names = list(demucs.sources)  # e.g. ["drums", "bass", "other", "vocals"]

    # The loader also hands back the mean/std it normalised with, so the
    # model output can be mapped back to the original signal scale.
    mix, mean, std = _load_and_normalise(str(audio_path), sr, channels)
    denormalised = _apply_demucs(mix, device) * std + mean
    per_source = denormalised[0]  # [num_sources, channels, time]

    written: dict[str, Path] = {}
    for i, stem in enumerate(names):
        target = out_dir / f"{stem}.wav"
        torchaudio.save(str(target), per_source[i], sr)
        written[stem] = target
    return written
def _sum_to_wav(stems: list[Path], dest: Path, sample_rate: int = 44100) -> Path:
"""Sum N stem WAVs into one — used to build the instrumental track."""
mix: torch.Tensor | None = None
sr_used = sample_rate
for path in stems:
wav, sr = torchaudio.load(str(path))
sr_used = sr
mix = wav if mix is None else mix + wav
if mix is None:
raise RuntimeError("No stems to sum.")
torchaudio.save(str(dest), mix, sr_used)
return dest
# Modes ``separate()`` accepts; keep in sync with the ``Mode`` alias.
_VALID_MODES: tuple[str, ...] = ("vocals-only", "instrumental-only", "stems")

# (label, sub-caption) reported for each stem. Dict order is the report
# order in "stems" mode: vocals first, then drums, bass, other.
_STEM_LABELS: dict[str, tuple[str, str]] = {
    "vocals": ("Vocals", "Dialogue track"),
    "drums": ("Drums", "Percussion"),
    "bass": ("Bass", "Low frequency"),
    "other": ("Other", "Melodic / ambient"),
}


def _discard(paths: list[Path]) -> None:
    """Best-effort removal of intermediate files; failures are ignored on purpose."""
    for path in paths:
        try:
            path.unlink()
        except OSError:
            # A stray scratch file left behind must never fail the request.
            pass


def separate(
    *,
    input_path: Path,
    out_dir: Path,
    mode: Mode,
) -> list[dict]:
    """
    Run Demucs separation on ``input_path`` and write the results into ``out_dir``.

    Modes:
        "vocals-only"        -> only the vocals stem.
        "instrumental-only"  -> all non-vocal stems summed into instrumental.wav.
        "stems"              -> every stem, ordered vocals, drums, bass, other.

    Files that are not returned (unrequested stems, and the ``input.wav`` copy
    made when the input had to be converted) are deleted best-effort. A
    ``.wav`` input supplied by the caller is never deleted. ``out_dir`` is
    created if it does not exist.

    Returns a list of output descriptors, for example:
        [{"name": "vocals", "label": "Vocals", "filename": "vocals.wav",
          "sub": "Dialogue track"}, ...]
    ``filename`` is a bare file name inside ``out_dir``.

    Raises:
        ValueError: ``mode`` is not a supported mode. This is checked before
            any file is touched.
        RuntimeError: ffmpeg conversion failed, or the stems could not be
            summed in "instrumental-only" mode.
    """
    if mode not in _VALID_MODES:
        raise ValueError(
            f"Unknown separation mode {mode!r}; "
            f"expected one of: {', '.join(_VALID_MODES)}."
        )
    out_dir.mkdir(parents=True, exist_ok=True)

    audio_in = _ensure_audio(input_path, out_dir)
    # _ensure_audio hands back the caller's own file for .wav input; only a
    # converted copy is ours to remove.
    scratch = [audio_in] if audio_in != input_path else []
    try:
        stems = _separate_all_stems(audio_in, out_dir)
    finally:
        _discard(scratch)

    if mode == "vocals-only":
        vocals = stems["vocals"]
        _discard([path for name, path in stems.items() if name != "vocals"])
        return [{
            "name": "vocals",
            "label": "Vocals",
            "filename": vocals.name,
            "sub": "Dialogue track",
        }]

    if mode == "instrumental-only":
        non_vocal = [path for name, path in stems.items() if name != "vocals"]
        try:
            out = _sum_to_wav(non_vocal, out_dir / "instrumental.wav")
        finally:
            # No individual stem is exposed in this mode, even if summing failed.
            _discard(list(stems.values()))
        return [{
            "name": "instrumental",
            "label": "Instrumental",
            "filename": out.name,
            "sub": "Music + ambient (vocals removed)",
        }]

    # "stems" mode: report every known stem the model produced, in label order.
    return [
        {"name": key, "label": label, "filename": stems[key].name, "sub": sub}
        for key, (label, sub) in _STEM_LABELS.items()
        if key in stems
    ]
|