File size: 4,700 Bytes
02ad302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3474e83
02ad302
 
3474e83
02ad302
 
3474e83
02ad302
3474e83
02ad302
 
 
 
 
 
 
3474e83
 
 
 
02ad302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3474e83
 
02ad302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
Step 1b: Separate vocals from accompaniment using Demucs (Python API).

In-process inference so ZeroGPU can intercept the GPU allocation via
`@spaces.GPU`. Works on CUDA, MPS, and CPU without code changes.
Only runs when preserve_music=True.
"""
import shutil
import subprocess
from pathlib import Path

import torch
import torchaudio

import spaces


# Process-wide cache for the htdemucs model. Stays None until the first
# _get_model() call loads the weights and stores the model here.
_MODEL = None


def _select_device() -> str:
    """Return the name of the preferred torch device.

    Backends are probed in priority order: CUDA first, then Apple's MPS
    backend (only if this torch build exposes it), and finally the CPU.
    """
    probes = (
        ("cuda", lambda: torch.cuda.is_available()),
        (
            "mps",
            lambda: hasattr(torch.backends, "mps")
            and torch.backends.mps.is_available(),
        ),
    )
    for name, is_usable in probes:
        # Probes run lazily and in order, so later backends are only queried
        # when every earlier one is unavailable.
        if is_usable():
            return name
    return "cpu"


def _get_model():
    """Return the shared htdemucs model, loading it on the first call.

    The demucs import and the weight download/load are deferred until the
    first call, so importing this module stays cheap on environments without
    a GPU. The freshly loaded model is switched to eval mode and placed on the
    CPU, then cached in the module-level ``_MODEL`` for every later call.
    """
    global _MODEL
    if _MODEL is not None:
        return _MODEL

    from demucs.pretrained import get_model

    print("[s1b] Loading htdemucs on cpu...")
    loaded = get_model("htdemucs")
    loaded.eval()
    loaded.to("cpu")
    _MODEL = loaded
    return _MODEL


@spaces.GPU(duration=120)
def _apply_demucs(mix: torch.Tensor, device: str) -> torch.Tensor:
    """Run htdemucs inference on ``mix`` inside the @spaces.GPU scope.

    Args:
        mix: Normalised mixture tensor shaped [1, channels, time].
        device: Torch device name to run on (the caller passes the result
            of ``_select_device()``).

    Returns:
        The separated stems moved back to the CPU, shaped
        [batch, sources, channels, time].
    """
    from demucs.apply import apply_model

    net = _get_model()
    current_device = next(net.parameters()).device.type
    if current_device != device:
        # The model is loaded on the CPU; it is moved only here, while the
        # decorated GPU scope is active.
        print(f"[s1b] Moving htdemucs to {device} inside GPU scope...")
        net = net.to(device)

    with torch.no_grad():
        stems = apply_model(
            net,
            mix.to(device),
            shifts=1,
            split=True,
            overlap=0.25,
            device=device,
        )
    return stems.cpu()


def _load_and_normalise(audio_hq_path: str, target_sr: int, target_ch: int) -> tuple[torch.Tensor, float, float]:
    """Load an audio file and turn it into a normalised model input.

    The waveform is resampled to ``target_sr``, converted to ``target_ch``
    channels, and z-normalised with one global mean/std. The caller uses
    those two scalars to undo the normalisation on the separated stems.

    Channel conversion rules:
      * same channel count        -> unchanged
      * mono target               -> average of all source channels
      * mono source               -> duplicated into every target channel
      * more source channels      -> first ``target_ch`` channels kept
      * anything else (up-mixing a multi-channel source) -> ValueError

    Args:
        audio_hq_path: Path of the audio file to load.
        target_sr: Sample rate expected by the model.
        target_ch: Channel count expected by the model.

    Returns:
        ``(mix, mean, std)``: ``mix`` is shaped [1, target_ch, time]; ``mean``
        and ``std`` are the scalars used for normalisation. ``std`` is clamped
        to at least 1e-8 so silent input does not divide by zero.

    Raises:
        ValueError: If the file contains no samples, or its channel layout
            cannot be converted to ``target_ch`` channels.
    """
    wav, sr = torchaudio.load(audio_hq_path)
    if wav.numel() == 0:
        # An empty waveform would produce NaN statistics and an opaque
        # failure deep inside the model; fail early with a clear message.
        raise ValueError(f"No audio samples found in {audio_hq_path!r}")

    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)

    src_ch = wav.shape[0]
    if src_ch == target_ch:
        pass
    elif target_ch == 1:
        # Down-mix to mono by averaging so no channel's content is dropped.
        wav = wav.mean(dim=0, keepdim=True)
    elif src_ch == 1:
        wav = wav.repeat(target_ch, 1)
    elif src_ch > target_ch:
        wav = wav[:target_ch]
    else:
        raise ValueError(
            f"Cannot convert {src_ch}-channel audio to {target_ch} channels"
        )

    mean = wav.mean()
    std = wav.std().clamp_min(1e-8)
    wav_norm = (wav - mean) / std
    return wav_norm.unsqueeze(0), mean.item(), std.item()


def separate_audio(
    audio_hq_path: str,
    output_dir: str = "tmp",
) -> tuple[str, str]:
    """
    Separate vocals from accompaniment using Demucs htdemucs (Python API).

    Files written into ``output_dir``:
      * vocals.wav        - vocal stem at the model's sample rate
      * accompaniment.wav - sum of every non-vocal stem, same sample rate
      * vocals_16k.wav    - vocal stem resampled by ffmpeg to 16 kHz mono

    A leftover ``demucs`` subdirectory from earlier CLI-based runs is also
    removed if present.

    Args:
        audio_hq_path: Path to input audio (any sample rate / channels).
        output_dir: Directory to write output stems (created if missing).

    Returns:
        (vocals_16k_path, accompaniment_path)

    Raises:
        RuntimeError: If the model exposes no "vocals" source, or if the
            ffmpeg resample step exits with a non-zero status.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    model = _get_model()
    device = _select_device()
    target_sr = model.samplerate
    target_ch = model.audio_channels
    source_names = list(model.sources)

    # Validate the stem layout up front, before loading audio and paying for
    # GPU inference that would otherwise be thrown away.
    try:
        vocals_idx = source_names.index("vocals")
    except ValueError as e:
        raise RuntimeError(f"htdemucs is missing 'vocals' source: {source_names}") from e

    print(f"[s1b] Running Demucs htdemucs on {device} (Python API)...")
    mix, mean, std = _load_and_normalise(audio_hq_path, target_sr, target_ch)

    sources = _apply_demucs(mix, device)
    sources = sources * std + mean  # undo the input z-normalisation
    sources = sources[0]  # drop batch dim → [num_sources, channels, time]

    vocals = sources[vocals_idx]
    # Accompaniment is every non-vocal stem summed. Tensor indexing (instead of
    # Python's sum(), which starts from int 0) always yields a
    # [channels, time] tensor, even when no other stems exist.
    other_idx = torch.tensor(
        [i for i in range(sources.shape[0]) if i != vocals_idx],
        dtype=torch.long,
    )
    no_vocals = sources.index_select(0, other_idx).sum(dim=0)

    vocals_path = str(out / "vocals.wav")
    accompaniment_path = str(out / "accompaniment.wav")
    vocals_16k_path = str(out / "vocals_16k.wav")

    torchaudio.save(vocals_path, vocals, target_sr)
    torchaudio.save(accompaniment_path, no_vocals, target_sr)
    print(f"[s1b] Vocals saved → {vocals_path}")
    print(f"[s1b] Accompaniment saved → {accompaniment_path}")

    # Resample vocals to 16 kHz mono for Whisper/TTS via ffmpeg
    # (torchaudio resample works but ffmpeg is more predictable for downstream)
    cmd = [
        "ffmpeg", "-y",
        "-i", vocals_path,
        "-ar", "16000",
        "-ac", "1",
        vocals_16k_path,
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"FFmpeg vocals resample failed:\n{result.stderr}")

    print(f"[s1b] Vocals (16 kHz) saved → {vocals_16k_path}")

    # Leftover cleanup for any previously-shelled-out demucs runs
    old_demucs_dir = out / "demucs"
    if old_demucs_dir.exists():
        shutil.rmtree(str(old_demucs_dir), ignore_errors=True)

    return vocals_16k_path, accompaniment_path