File size: 3,955 Bytes
b260242
 
 
 
 
 
 
 
 
 
 
 
 
b992e76
 
 
 
 
 
 
 
b260242
 
b992e76
b260242
b992e76
b260242
 
 
 
b992e76
b260242
b992e76
 
 
b260242
 
b992e76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b260242
 
 
b992e76
 
 
 
b260242
 
b992e76
 
b260242
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""Post-generation: stem separation (Demucs), loudness normalisation
(pyloudnorm), and MP3 export (ffmpeg)."""

from __future__ import annotations

import subprocess
from pathlib import Path
from typing import Any

# Module-wide cache for the htdemucs model; filled in by _get_demucs().
_DEMUCS: Any = None


def _get_demucs() -> Any:
    """Return the shared htdemucs model, loading it on the first call.

    The model is obtained through ``demucs.pretrained.get_model``. The
    more convenient ``demucs.api.Separator`` wrapper is deliberately
    avoided: it only exists from demucs 4.1 onwards, whereas the
    lower-level entry points (``demucs.pretrained.get_model`` /
    ``demucs.apply.apply_model``) are present on the 4.0.x line as well,
    so nobody is forced to upgrade the apple-silicon torch stack.

    ``demucs`` is imported inside the function, so it is only pulled in
    when a model is actually requested.
    """
    global _DEMUCS
    if _DEMUCS is not None:
        # Already loaded by an earlier call; reuse it.
        return _DEMUCS

    from demucs.pretrained import get_model

    _DEMUCS = get_model("htdemucs")
    return _DEMUCS


def separate_stems(audio_path: Path | str) -> dict[str, Path]:
    """Split an audio file into stems (vocals/drums/bass/other) via htdemucs.

    Uses the lower-level ``demucs.apply.apply_model`` so we don't depend
    on the ``demucs.api.Separator`` wrapper (which only ships with
    demucs >= 4.1).

    The input is resampled to the model's sample rate and its channel
    count is matched to the model before separation, which runs on the
    CPU. One WAV file per stem is written next to the input, named
    ``<input without suffix>.<stem>.wav``, at the model's sample rate.

    Args:
        audio_path: Audio file to separate (loaded with ``torchaudio.load``).

    Returns:
        A dict mapping each stem name to the path of the file written for it.

    Raises:
        ValueError: If the input has fewer channels than the model expects
            and is not mono (so it cannot be up-mixed by duplication).
    """
    import soundfile as sf
    import torch
    import torchaudio
    from demucs.apply import apply_model

    model = _get_demucs()
    # Fall back to the htdemucs defaults when the model object does not
    # expose these attributes.
    target_sr = int(getattr(model, "samplerate", 44100))
    sources = list(getattr(model, "sources", ["drums", "bass", "other", "vocals"]))
    audio_channels = int(getattr(model, "audio_channels", 2))

    waveform, sr = torchaudio.load(str(audio_path))  # (channels, frames)
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, sr, target_sr)

    # Match the model's expected channel count (htdemucs is stereo).
    in_channels = waveform.shape[0]
    if in_channels == 1 and audio_channels > 1:
        # Mono input: duplicate the single channel for every model channel.
        waveform = waveform.repeat(audio_channels, 1)
    elif in_channels > audio_channels:
        # Surplus channels are dropped, keeping the first ones.
        waveform = waveform[:audio_channels]
    elif in_channels != audio_channels:
        # Fail early with a readable message instead of a shape error
        # from deep inside the network.
        raise ValueError(
            f"cannot map {in_channels}-channel audio onto a model expecting "
            f"{audio_channels} channels: {audio_path}"
        )

    # apply_model expects shape (batch, channels, frames).
    batch = waveform.unsqueeze(0)
    with torch.no_grad():
        # apply_model returns (batch, sources, channels, frames).
        out = apply_model(model, batch, device="cpu", progress=False)
    out = out[0]  # drop batch dim -> (sources, channels, frames)

    base = Path(audio_path).with_suffix("")
    stems: dict[str, Path] = {}
    for idx, name in enumerate(sources):
        out_path = base.with_name(f"{base.name}.{name}.wav")
        data = out[idx].cpu().numpy()
        # soundfile expects (frames, channels); demucs gives (channels, frames).
        # Every 2-D stem is transposed: the layout is known, so there is no
        # need to guess it from the size of the first axis.
        if data.ndim == 2:
            data = data.T
        sf.write(str(out_path), data, target_sr)
        stems[name] = out_path
    return stems


def _pyloudnorm_normalise(in_path: str, out_path: str, target_lufs: float) -> None:
    """Real pyloudnorm path; isolated for easy mocking in tests.

    Reads ``in_path``, measures its integrated loudness, applies the gain
    needed to reach ``target_lufs`` and writes the result to ``out_path``
    at the input's sample rate.

    If the measured loudness is not finite (digital silence measures as
    ``-inf``), no gain is applied and the audio is written unchanged:
    the required gain would be infinite and would turn the samples
    into NaN.

    Args:
        in_path: Audio file to read.
        out_path: Destination file for the normalised audio.
        target_lufs: Desired integrated loudness in LUFS.
    """
    import math

    import pyloudnorm as pyln
    import soundfile as sf

    data, rate = sf.read(in_path)
    meter = pyln.Meter(rate)
    current = meter.integrated_loudness(data)
    if math.isfinite(current):
        data = pyln.normalize.loudness(data, current, target_lufs)
    # else: nothing measurable to scale; pass the samples through untouched.
    sf.write(out_path, data, rate)


def normalise_lufs(audio_path: Path | str, target_lufs: float = -14.0) -> Path:
    """Normalise to streaming-spec LUFS. Writes a new file alongside the input.

    The output is named ``<input stem>.lufs<target>.wav``. Integral targets
    are rendered without a fractional part (``-14.0`` -> ``lufs-14``);
    fractional targets keep their fraction (``-14.5`` -> ``lufs-14.5``) so
    that two different targets never map onto the same output file.

    Args:
        audio_path: Audio file to normalise.
        target_lufs: Desired integrated loudness in LUFS.

    Returns:
        Path of the normalised WAV file.
    """
    audio_path = Path(audio_path)
    target = float(target_lufs)
    # int() alone would truncate -14.5 to -14 and mislabel the file; only
    # use it when the target really is a whole number.
    label = str(int(target)) if target.is_integer() else f"{target:g}"
    out_path = audio_path.with_name(f"{audio_path.stem}.lufs{label}.wav")
    _pyloudnorm_normalise(str(audio_path), str(out_path), target_lufs)
    return out_path


def to_mp3(
    wav_path: Path | str,
    bitrate_kbps: int = 320,
    *,
    sample_rate: int = 44100,
) -> Path:
    """Encode WAV to MP3 via system ffmpeg.

    The MP3 is written next to the input, with the input's suffix replaced
    by ``.mp3``. An existing file at that location is overwritten (``-y``).

    Args:
        wav_path: Audio file to encode.
        bitrate_kbps: Audio bitrate in kbit/s, passed to ffmpeg as ``-b:a``.
        sample_rate: Output sample rate in Hz, passed to ffmpeg as ``-ar``.

    Returns:
        Path of the MP3 file.

    Raises:
        ValueError: If the input already has the ``.mp3`` suffix, so the
            output path would be the input itself.
        FileNotFoundError: If no ``ffmpeg`` executable can be found.
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero
            status; its captured output is available on the exception.
    """
    wav_path = Path(wav_path)
    out_path = wav_path.with_suffix(".mp3")
    if out_path == wav_path:
        # ffmpeg cannot transcode a file onto itself; report it up front.
        raise ValueError(f"input is already an .mp3 file: {wav_path}")
    cmd = [
        "ffmpeg",
        "-y",
        "-i",
        str(wav_path),
        "-b:a",
        f"{bitrate_kbps}k",
        "-ar",
        str(sample_rate),
        str(out_path),
    ]
    # stdin is detached so ffmpeg's interactive key handling can neither
    # consume the parent's input nor stall when run without a terminal.
    subprocess.run(cmd, check=True, capture_output=True, stdin=subprocess.DEVNULL)
    return out_path