File size: 2,721 Bytes
2376414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ff3bb8
f729219
2376414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
Audio separation module: uses Demucs to separate vocals from instruments.
"""

import os
import logging
import torch

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

try:
    import spaces
except ImportError:
    # The optional `spaces` package is not installed: provide a stand-in so
    # that `@spaces.GPU(...)` decorations in this module still work.
    class spaces:
        """No-op replacement for the `spaces` package, exposing only `GPU`."""

        @staticmethod
        def GPU(task=None, duration=60, **kwargs):
            """
            Decorator that leaves the decorated function unchanged.

            Supports both the bare form (`@spaces.GPU`) and the parameterised
            form (`@spaces.GPU(duration=60)`). `duration` and any other
            keyword arguments are accepted and ignored.
            """
            # Bare usage: the decorated function itself arrives as the first
            # positional argument, so hand it back untouched.
            if callable(task):
                return task

            # Parameterised usage (including a legacy positional duration,
            # e.g. `GPU(60)`): return an identity decorator.
            def decorator(fn):
                return fn

            return decorator


# Directory the separated stems are written to; separate_audio() creates it
# on demand before saving.
OUTPUT_DIR = "/tmp/demucs_output"


@spaces.GPU(duration=60)
def separate_audio(audio_path: str, model_name: str = "htdemucs_ft"):
    """
    Separate an audio file into a vocals stem and an instruments stem using Demucs.

    The audio is resampled to the model's sample rate, forced to two channels,
    normalised by the statistics of its mono mix, separated, and de-normalised.
    The instruments stem is the sum of every separated source except "vocals".

    Args:
        audio_path: Path of the audio file to separate.
        model_name: Name of the pretrained Demucs model to load.

    Returns:
        (vocals_path, instruments_path): paths of the two WAV files written
        into OUTPUT_DIR, named after the input file's base name.

    Raises:
        ValueError: If the audio file contains no samples.
    """
    # Heavy dependencies are imported lazily, only when separation is requested.
    import torchaudio
    from demucs.pretrained import get_model
    from demucs.apply import apply_model

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    logger.info("Loading Demucs model '%s'...", model_name)
    model = get_model(model_name)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    logger.info("Loading audio: %s", audio_path)
    waveform, sr = torchaudio.load(audio_path)

    # Fail early with a clear message instead of feeding NaN statistics
    # (mean/std of an empty tensor) into the model.
    if waveform.numel() == 0:
        raise ValueError(f"Audio file contains no samples: {audio_path}")

    # Resample if needed
    if sr != model.samplerate:
        resampler = torchaudio.transforms.Resample(sr, model.samplerate)
        waveform = resampler(waveform)
        sr = model.samplerate

    # Ensure stereo: duplicate a mono channel, drop channels beyond the first two
    if waveform.shape[0] == 1:
        waveform = waveform.repeat(2, 1)
    elif waveform.shape[0] > 2:
        waveform = waveform[:2]

    # Normalise using the mean/std of the mono mix; undone after separation.
    logger.info("Separating audio...")
    ref = waveform.mean(0)
    mean = ref.mean()
    std = ref.std()
    # Guard against division by (near) zero for silent input. A NaN std
    # (e.g. a single-sample signal) compares False against the threshold,
    # so non-finite values are checked explicitly.
    if not torch.isfinite(std) or std < 1e-6:
        std = torch.tensor(1e-6)
    waveform = (waveform - mean) / std

    sources = apply_model(
        model,
        waveform[None].to(device),
        device=device,
        progress=True,
        num_workers=0,
    )

    # Undo the normalisation, drop the batch dimension, and move all stems to
    # the CPU in a single transfer.
    sources = sources * std + mean
    sources = sources[0].cpu()

    # Look the vocals stem up by name rather than relying on a fixed order.
    source_names = model.sources
    vocals_idx = source_names.index("vocals")

    vocals = sources[vocals_idx]

    # Instruments = everything except vocals
    instruments = torch.zeros_like(vocals)
    for i, name in enumerate(source_names):
        if name != "vocals":
            instruments += sources[i]

    # Save outputs
    # NOTE(review): file names depend only on the input's base name, so two
    # inputs sharing a base name overwrite each other's stems in OUTPUT_DIR.
    base_name = os.path.splitext(os.path.basename(audio_path))[0]
    vocals_path = os.path.join(OUTPUT_DIR, f"{base_name}_vocals.wav")
    instruments_path = os.path.join(OUTPUT_DIR, f"{base_name}_instruments.wav")

    torchaudio.save(vocals_path, vocals, sr)
    torchaudio.save(instruments_path, instruments, sr)

    logger.info(
        "Separation complete. Vocals: %s, Instruments: %s",
        vocals_path,
        instruments_path,
    )
    return vocals_path, instruments_path