ACE-Music-Studio / post_process.py
techfreakworm's picture
fix(post): use lower-level demucs API + stringify stem paths for gr.Files
b992e76 unverified
"""Post-generation: stem separation (Demucs), loudness normalisation
(pyloudnorm), and MP3 export (ffmpeg)."""
from __future__ import annotations
import subprocess
from pathlib import Path
from typing import Any
# Process-wide cache for the htdemucs model; filled in on first use.
_DEMUCS = None


def _get_demucs() -> Any:
    """Return the shared htdemucs model, loading it on the first call.

    The model comes from ``demucs.pretrained.get_model`` and is stored in
    the module-level ``_DEMUCS`` cache, so every later call hands back the
    same object without loading it again.

    Demucs 4.0.x exposes ``demucs.pretrained.get_model`` and
    ``demucs.apply.apply_model``; the higher-level
    ``demucs.api.Separator`` wrapper only appears in 4.1+. Staying on the
    lower-level API keeps this working across both pip-installable lines
    without forcing an upgrade on the apple-silicon torch stack.
    """
    global _DEMUCS
    if _DEMUCS is not None:
        return _DEMUCS

    # Imported here rather than at module level so demucs is only pulled
    # in when a model is actually requested.
    from demucs.pretrained import get_model

    _DEMUCS = get_model("htdemucs")
    return _DEMUCS
def separate_stems(audio_path: Path | str) -> dict[str, Path]:
    """Split an audio file into per-source stems with htdemucs.

    The file is loaded with torchaudio, resampled to the model's sample
    rate, adapted to the model's channel count and run through
    ``demucs.apply.apply_model`` on the CPU. The lower-level API is used so
    there is no dependency on the ``demucs.api.Separator`` wrapper (which
    only ships with demucs >= 4.1).

    Args:
        audio_path: Audio file to separate.

    Returns:
        Mapping of stem name to the WAV file written for it. Stem names are
        the model's ``sources`` (falling back to drums / bass / other /
        vocals); each file is written next to the input as
        ``<input name without extension>.<stem>.wav``.
    """
    import soundfile as sf
    import torch
    import torchaudio
    from demucs.apply import apply_model

    model = _get_demucs()
    target_sr = int(getattr(model, "samplerate", 44100))
    sources = list(getattr(model, "sources", ["drums", "bass", "other", "vocals"]))
    audio_channels = int(getattr(model, "audio_channels", 2))

    waveform, sr = torchaudio.load(str(audio_path))  # (channels, frames)
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, sr, target_sr)

    # Match the model's expected channel count (htdemucs is stereo).
    in_channels = waveform.shape[0]
    if in_channels == 1 and audio_channels > 1:
        # Mono input: duplicate the single channel once per model channel.
        waveform = waveform.repeat(audio_channels, 1)
    elif in_channels > audio_channels:
        # Surplus channels: keep only the leading ones.
        waveform = waveform[:audio_channels]

    # NOTE(review): the demucs command-line tool appears to standardise the
    # waveform (mean/std) before apply_model; raw samples are fed here.
    # Confirm that is intended.
    with torch.no_grad():
        # apply_model expects (batch, channels, frames) and returns
        # (batch, sources, channels, frames).
        out = apply_model(model, waveform.unsqueeze(0), device="cpu", progress=False)
    out = out[0]  # drop batch dim -> (sources, channels, frames)

    base = Path(audio_path).with_suffix("")
    stems: dict[str, Path] = {}
    for idx, name in enumerate(sources):
        out_path = base.with_name(f"{base.name}.{name}.wav")
        # soundfile expects (frames, channels) while demucs yields
        # (channels, frames). Always transpose: guessing the layout from
        # the channel count would mis-write models with more than two
        # channels.
        data = out[idx].cpu().numpy().T
        sf.write(str(out_path), data, target_sr)
        stems[name] = out_path
    return stems
def _pyloudnorm_normalise(in_path: str, out_path: str, target_lufs: float) -> None:
    """Loudness-normalise ``in_path`` to ``target_lufs`` and write ``out_path``.

    This is the real pyloudnorm path, isolated in its own function so it is
    easy to replace in tests.

    Args:
        in_path: Audio file to read (via soundfile).
        out_path: Destination file; written at the input's sample rate.
        target_lufs: Desired integrated loudness in LUFS.

    Raises:
        ValueError: Presumably propagated from pyloudnorm when the audio is
            shorter than its measurement block; verify against the
            installed pyloudnorm version.
    """
    import math

    import pyloudnorm as pyln
    import soundfile as sf

    data, rate = sf.read(in_path)
    current = pyln.Meter(rate).integrated_loudness(data)
    if math.isfinite(current):
        normalised = pyln.normalize.loudness(data, current, target_lufs)
    else:
        # Digital silence measures as -inf LUFS. The gain needed to reach
        # the target would then be infinite, and 0 * inf turns every sample
        # into NaN, so pass the audio through unchanged instead.
        normalised = data
    sf.write(out_path, normalised, rate)
def normalise_lufs(audio_path: Path | str, target_lufs: float = -14.0) -> Path:
    """Write a loudness-normalised copy of ``audio_path`` and return its path.

    The copy is placed alongside the input and named
    ``<stem>.lufs<int(target_lufs)>.wav``; the input file itself is not
    modified. The default target of -14 LUFS is the streaming-spec level.
    """
    source = Path(audio_path)
    # The tag uses the integer part of the target, e.g. "lufs-14".
    loudness_tag = f"lufs{int(target_lufs)}"
    destination = source.with_name(f"{source.stem}.{loudness_tag}.wav")
    _pyloudnorm_normalise(str(source), str(destination), target_lufs)
    return destination
def to_mp3(wav_path: Path | str, bitrate_kbps: int = 320) -> Path:
    """Encode a WAV file to MP3 with the system ``ffmpeg`` binary.

    The MP3 is written next to the input with the suffix replaced by
    ``.mp3``. ffmpeg is run with ``-y`` (overwrite an existing output),
    the requested audio bitrate and a fixed 44100 Hz output sample rate.

    Args:
        wav_path: Source audio file.
        bitrate_kbps: Audio bitrate in kbit/s passed to ``-b:a``.

    Returns:
        Path of the MP3 file.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero
            status (its output is captured on the exception).
    """
    source = Path(wav_path)
    destination = source.with_suffix(".mp3")

    input_args = ["-y", "-i", str(source)]
    output_args = ["-b:a", f"{bitrate_kbps}k", "-ar", "44100", str(destination)]
    subprocess.run(
        ["ffmpeg", *input_args, *output_args],
        check=True,
        capture_output=True,
    )
    return destination