# videovoice-dramabox / tools_api/audio_cleanup.py
# github-actions[bot]
# deploy: switch to dramabox requirements @ a95fda4
# 0422215
"""
Audio source separation tool — three modes via Demucs.
Reuses internals from steps.s1b_separate (model loader, device picker, normaliser,
GPU-decorated apply). The existing separate_audio() returns only (vocals, accompaniment),
so we replicate its flow here and keep all four stems addressable.
"""
from __future__ import annotations
import subprocess
from pathlib import Path
from typing import Literal
import torch
import torchaudio
# Reuse internals — no edits to s1b_separate.py.
from steps.s1b_separate import (
_apply_demucs,
_get_model,
_load_and_normalise,
_select_device,
)
# Separation modes accepted by ``separate``:
#   "vocals-only"       -> only the vocals stem is exposed
#   "instrumental-only" -> all non-vocal stems mixed into one track
#   "stems"             -> every stem exposed as its own file
Mode = Literal[
    "vocals-only",
    "instrumental-only",
    "stems",
]
def _ensure_audio(input_path: Path, out_dir: Path) -> Path:
    """Return a WAV path for *input_path*, transcoding with ffmpeg if needed.

    A ``.wav`` input (suffix compared case-insensitively) is returned
    unchanged; its channel count and sample rate are not inspected here.
    Any other input (video or non-WAV audio) is decoded by ffmpeg into
    ``out_dir / "input.wav"`` as 44.1 kHz, 2-channel, 16-bit PCM. An existing
    file at that path is overwritten (``-y``).

    Args:
        input_path: Source media file.
        out_dir: Directory that receives the converted ``input.wav``.

    Returns:
        *input_path* itself for WAV input, otherwise the converted file.

    Raises:
        RuntimeError: if ffmpeg cannot be started (not on PATH), does not
            finish within 300 s, or exits non-zero. When a lower-level
            exception caused it, that exception is chained as ``__cause__``.
    """
    if input_path.suffix.lower() == ".wav":
        return input_path

    out = out_dir / "input.wav"
    cmd = [
        "ffmpeg", "-y", "-i", str(input_path),
        "-vn",  # ignore any video stream
        "-ac", "2", "-ar", "44100",
        "-acodec", "pcm_s16le",
        str(out),
    ]
    try:
        # errors="replace": ffmpeg's stderr can contain bytes that are not
        # valid in the decoding codec (e.g. odd metadata); a strict decode
        # would crash here instead of reporting the real ffmpeg result.
        result = subprocess.run(
            cmd, capture_output=True, text=True, errors="replace", timeout=300,
        )
    except FileNotFoundError as err:
        raise RuntimeError(
            "ffmpeg input prep failed: ffmpeg executable not found on PATH"
        ) from err
    except subprocess.TimeoutExpired as err:
        raise RuntimeError(
            "ffmpeg input prep failed: timed out after 300 s"
        ) from err
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg input prep failed: {result.stderr[-300:]}")
    return out
def _separate_all_stems(audio_path: Path, out_dir: Path) -> dict[str, Path]:
    """Run Demucs on *audio_path* and write one WAV per model source.

    Each source is saved as ``<source name>.wav`` in *out_dir* at the model's
    sample rate.

    Returns:
        Mapping of source name (as listed in ``model.sources``) to the WAV
        path written for it, inserted in ``model.sources`` order.
    """
    model = _get_model()
    device = _select_device()
    sr = model.samplerate
    channels = model.audio_channels
    # e.g. ["drums", "bass", "other", "vocals"] for the stock htdemucs models.
    names = list(model.sources)

    normalised_mix, mean, std = _load_and_normalise(str(audio_path), sr, channels)
    # `* std + mean` presumably inverts the normalisation done on load; `[0]`
    # drops the leading batch dim -> [num_sources, channels, time].
    # NOTE(review): torchaudio.save generally wants a CPU tensor — confirm
    # _apply_demucs returns its output on CPU.
    per_source = (_apply_demucs(normalised_mix, device) * std + mean)[0]

    written: dict[str, Path] = {}
    for position, name in enumerate(names):
        target = out_dir / f"{name}.wav"
        torchaudio.save(str(target), per_source[position], sr)
        written[name] = target
    return written
def _accumulate_stem(
    mix: torch.Tensor | None,
    mix_sr: int | None,
    wav: torch.Tensor,
    sr: int,
    label: str,
) -> tuple[torch.Tensor, int]:
    """Add one decoded stem to the running mix; return ``(mix, sample_rate)``.

    The first stem (``mix is None``) simply becomes the mix. Every later stem
    must match the mix's sample rate and tensor shape exactly.

    Raises:
        RuntimeError: on a sample-rate or shape mismatch. Without the shape
            check, torch broadcasting would silently add e.g. a ``[1, T]``
            stem into a ``[2, T]`` mix.
    """
    if mix is None:
        return wav, sr
    if sr != mix_sr:
        raise RuntimeError(
            f"Cannot sum stems with different sample rates: "
            f"{label} is {sr} Hz, expected {mix_sr} Hz."
        )
    if wav.shape != mix.shape:
        raise RuntimeError(
            f"Cannot sum stems with different shapes: "
            f"{label} is {tuple(wav.shape)}, expected {tuple(mix.shape)}."
        )
    return mix + wav, mix_sr


def _sum_to_wav(stems: list[Path], dest: Path, sample_rate: int = 44100) -> Path:
    """Sum N stem WAVs into one — used to build the instrumental track.

    Args:
        stems: WAV files to mix sample-by-sample. All must share one sample
            rate and one tensor shape (as returned by ``torchaudio.load``).
        dest: Output WAV path, written with ``torchaudio.save``.
        sample_rate: Unused; kept so existing callers keep working. The
            output is written at the stems' own sample rate and nothing is
            resampled.

    Returns:
        *dest*.

    Raises:
        RuntimeError: if *stems* is empty, or the stems disagree in sample
            rate or shape.
    """
    mix: torch.Tensor | None = None
    mix_sr: int | None = None
    for path in stems:
        wav, sr = torchaudio.load(str(path))
        mix, mix_sr = _accumulate_stem(mix, mix_sr, wav, sr, path.name)
    if mix is None:
        raise RuntimeError("No stems to sum.")
    torchaudio.save(str(dest), mix, mix_sr)
    return dest
def separate(
    *,
    input_path: Path,
    out_dir: Path,
    mode: Mode,
) -> list[dict]:
    """
    Run Demucs separation on *input_path*, writing WAV outputs into *out_dir*.

    Args:
        input_path: Audio or video file. Non-WAV inputs are first converted
            by ffmpeg to ``input.wav`` inside *out_dir*.
        out_dir: Directory that receives the stem WAVs.
        mode: ``"vocals-only"`` exposes the vocals stem.
            ``"instrumental-only"`` exposes one ``instrumental.wav``, which is
            the sum of every non-vocal stem; the per-stem files are then
            deleted (best effort). ``"stems"`` exposes every stem, vocals
            first, then drums, bass and other.

    Returns:
        A list of output descriptors, for example::

            [{"name": "vocals", "label": "Vocals",
              "filename": "vocals.wav", "sub": "Dialogue track"}, ...]

        ``filename`` is a bare file name inside *out_dir*.

    Raises:
        ValueError: if *mode* is not a supported mode. This is checked before
            any conversion or separation work starts.
    """
    # ``Mode`` is a typing.Literal, which only static type checkers enforce.
    # Without this check a typo would silently fall through to the "stems"
    # branch after a full Demucs run. Keep this tuple in sync with ``Mode``.
    supported_modes = ("vocals-only", "instrumental-only", "stems")
    if mode not in supported_modes:
        raise ValueError(
            f"Unsupported separation mode {mode!r}; "
            f"expected one of: {', '.join(supported_modes)}."
        )

    audio_in = _ensure_audio(input_path, out_dir)
    stems = _separate_all_stems(audio_in, out_dir)

    if mode == "vocals-only":
        # NOTE(review): the drums/bass/other WAVs written by
        # _separate_all_stems stay in out_dir in this mode (only
        # instrumental-only cleans up). Confirm that is intended.
        return [{
            "name": "vocals",
            "label": "Vocals",
            "filename": stems["vocals"].name,
            "sub": "Dialogue track",
        }]

    if mode == "instrumental-only":
        non_vocal_stems = [path for name, path in stems.items() if name != "vocals"]
        out = _sum_to_wav(non_vocal_stems, out_dir / "instrumental.wav")
        # Best-effort cleanup of the intermediate stem files we won't expose;
        # a failed delete must not fail the whole request.
        for path in stems.values():
            try:
                path.unlink()
            except OSError:
                pass
        return [{
            "name": "instrumental",
            "label": "Instrumental",
            "filename": out.name,
            "sub": "Music + ambient (vocals removed)",
        }]

    # "stems" mode: return every known stem with a display label/subtitle.
    label_map = {
        "vocals": ("Vocals", "Dialogue track"),
        "drums": ("Drums", "Percussion"),
        "bass": ("Bass", "Low frequency"),
        "other": ("Other", "Melodic / ambient"),
    }
    results: list[dict] = []
    # Stable order: vocals first, then drums, bass, other. A stem the model
    # did not produce is skipped rather than raising.
    for stem_key in ("vocals", "drums", "bass", "other"):
        if stem_key not in stems:
            continue
        label, sub = label_map[stem_key]
        results.append({
            "name": stem_key,
            "label": label,
            "filename": stems[stem_key].name,
            "sub": sub,
        })
    return results