File size: 4,348 Bytes
5b7cd5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
Audio source separation tool — three modes via Demucs.

Reuses internals from steps.s1b_separate (model loader, device picker, normaliser,
GPU-decorated apply). The existing separate_audio() returns only (vocals, accompaniment),
so we replicate its flow here and keep all four stems addressable.
"""
from __future__ import annotations

import subprocess
from pathlib import Path
from typing import Literal

import torch
import torchaudio

# Reuse internals — no edits to s1b_separate.py.
from steps.s1b_separate import (
    _apply_demucs,
    _get_model,
    _load_and_normalise,
    _select_device,
)

# Separation modes accepted by separate(): "vocals-only" and "instrumental-only"
# each produce a single output descriptor; "stems" produces one per known stem.
Mode = Literal["vocals-only", "instrumental-only", "stems"]


def _ensure_audio(input_path: Path, out_dir: Path) -> Path:
    """Convert input to a stable WAV format if it's a video or non-WAV audio."""
    if input_path.suffix.lower() == ".wav":
        return input_path
    out = out_dir / "input.wav"
    cmd = [
        "ffmpeg", "-y", "-i", str(input_path),
        "-vn", "-ac", "2", "-ar", "44100",
        "-acodec", "pcm_s16le",
        str(out),
    ]
    result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
    if result.returncode != 0:
        raise RuntimeError(f"ffmpeg input prep failed: {result.stderr[-300:]}")
    return out


def _separate_all_stems(audio_path: Path, out_dir: Path) -> dict[str, Path]:
    """Separate *audio_path* with the Demucs model and write one WAV per source.

    Each source listed in ``model.sources`` is saved as ``<source>.wav`` inside
    *out_dir* at the model's sample rate.

    Returns:
        Mapping of source name to the path of the WAV written for it.
    """
    model = _get_model()
    device = _select_device()
    sample_rate = model.samplerate
    channels = model.audio_channels
    # Per the original author's note, the stock model yields
    # ["drums", "bass", "other", "vocals"] — confirm for other checkpoints.
    stem_names = list(model.sources)

    mix, mean, std = _load_and_normalise(str(audio_path), sample_rate, channels)
    raw = _apply_demucs(mix, device)
    # Undo the normalisation applied at load time, then drop the leading axis
    # (presumably the batch axis, leaving [num_sources, channels, time]).
    separated = (raw * std + mean)[0]

    def _save_stem(index: int, stem: str) -> Path:
        target = out_dir / f"{stem}.wav"
        torchaudio.save(str(target), separated[index], sample_rate)
        return target

    return {stem: _save_stem(index, stem) for index, stem in enumerate(stem_names)}


def _sum_to_wav(stems: list[Path], dest: Path, sample_rate: int = 44100) -> Path:
    """Sum N stem WAVs into one — used to build the instrumental track."""
    mix: torch.Tensor | None = None
    sr_used = sample_rate
    for path in stems:
        wav, sr = torchaudio.load(str(path))
        sr_used = sr
        mix = wav if mix is None else mix + wav
    if mix is None:
        raise RuntimeError("No stems to sum.")
    torchaudio.save(str(dest), mix, sr_used)
    return dest


def _discard_files(paths: list[Path]) -> None:
    """Delete each path, ignoring files that cannot be removed."""
    for path in paths:
        try:
            path.unlink()
        except OSError:
            # Best-effort: a leftover intermediate file must not fail a
            # separation that has already succeeded.
            pass


def separate(
    *,
    input_path: Path,
    out_dir: Path,
    mode: Mode,
) -> list[dict]:
    """Separate *input_path* and describe the files produced for *mode*.

    The input is first converted to WAV if necessary, then every model source
    is written to *out_dir*. What is kept and reported depends on *mode*:

    * ``"vocals-only"`` — only the vocals stem; the other stem files are
      deleted.
    * ``"instrumental-only"`` — the non-vocal stems are summed into
      ``instrumental.wav`` and all individual stem files are deleted.
    * ``"stems"`` (and any other value, which falls through to this branch) —
      vocals, drums, bass and other, in that order, skipping any the model did
      not produce.

    Returns:
        A list of output descriptors, each with the keys ``name``, ``label``,
        ``filename`` and ``sub``, e.g.::

            [{"name": "vocals", "label": "Vocals",
              "filename": "vocals.wav", "sub": "Dialogue track"}, ...]

    Raises:
        KeyError: In ``"vocals-only"`` mode when the model has no ``vocals``
            source.
        RuntimeError: Propagated from input conversion or stem summing.
    """
    audio_in = _ensure_audio(input_path, out_dir)
    stems = _separate_all_stems(audio_in, out_dir)

    if mode == "vocals-only":
        # Look the stem up before deleting anything so a missing "vocals"
        # source fails without touching the files on disk.
        vocals = stems["vocals"]
        # Remove the stems this mode does not expose, mirroring the cleanup
        # done in instrumental-only mode.
        _discard_files([path for name, path in stems.items() if name != "vocals"])
        return [{
            "name": "vocals",
            "label": "Vocals",
            "filename": vocals.name,
            "sub": "Dialogue track",
        }]

    if mode == "instrumental-only":
        non_vocal_stems = [path for name, path in stems.items() if name != "vocals"]
        out = _sum_to_wav(non_vocal_stems, out_dir / "instrumental.wav")
        # Cleanup intermediate stem files we won't expose
        _discard_files(list(stems.values()))
        return [{
            "name": "instrumental",
            "label": "Instrumental",
            "filename": out.name,
            "sub": "Music + ambient (vocals removed)",
        }]

    # stems mode — insertion order of this dict is the stable output order:
    # vocals first, then drums, bass, other.
    label_map = {
        "vocals": ("Vocals", "Dialogue track"),
        "drums": ("Drums", "Percussion"),
        "bass": ("Bass", "Low frequency"),
        "other": ("Other", "Melodic / ambient"),
    }
    return [
        {
            "name": stem_key,
            "label": label,
            "filename": stems[stem_key].name,
            "sub": sub,
        }
        for stem_key, (label, sub) in label_map.items()
        if stem_key in stems
    ]