Spaces:
Sleeping
Sleeping
File size: 2,721 Bytes
"""
Audio separation module: uses Demucs to split a track into a vocal stem
and an instrumental stem.
"""
import os
import logging
import torch
# Module-level logger, named after this module so records can be filtered by origin.
logger = logging.getLogger(__name__)
try:
    import spaces
except ImportError:
    # Fallback for environments where the `spaces` package cannot be imported:
    # expose a no-op `spaces.GPU` so decorated functions still work unchanged.
    class spaces:
        """Minimal stand-in for the `spaces` package with a no-op ``GPU`` decorator."""

        @staticmethod
        def GPU(*args, duration=60, **kwargs):
            """Return the decorated function unchanged.

            Both decorator forms are supported:

            * ``@spaces.GPU(duration=60)`` -- called with options, returns a
              decorator that hands the function back untouched.
            * ``@spaces.GPU`` -- applied directly to the function.

            ``duration`` and any extra keyword arguments are accepted and
            ignored. A callable passed as ``duration=...`` by keyword is treated
            as configuration, not as the function being decorated.
            """
            if args and callable(args[0]):
                # Bare usage: the decorated function arrives as the first
                # positional argument.
                return args[0]

            def decorator(fn):
                return fn

            return decorator
# Directory that receives the separated stem files written by separate_audio().
OUTPUT_DIR = "/tmp/demucs_output"
def _to_stereo(waveform):
    """Return ``waveform`` with exactly two channels along dimension 0."""
    channels = waveform.shape[0]
    if channels == 1:
        # Mono: duplicate the single channel.
        return waveform.repeat(2, 1)
    if channels > 2:
        # Multichannel: keep only the first two channels.
        return waveform[:2]
    return waveform


def _normalization_stats(waveform):
    """Return ``(mean, std)`` of the channel-averaged signal used to scale the mix."""
    ref = waveform.mean(0)
    mean = ref.mean()
    std = ref.std()
    # A silent clip gives std ~ 0 and a one-sample clip gives NaN (unbiased
    # estimator); both would poison the division, so fall back to a tiny scale.
    if not torch.isfinite(std) or std < 1e-6:
        std = torch.tensor(1e-6)
    return mean, std


@spaces.GPU(duration=60)
def separate_audio(audio_path: str, model_name: str = "htdemucs_ft"):
    """
    Separate audio into vocals and instruments using Demucs.

    The input is resampled to the model's sample rate, converted to stereo,
    normalised, run through the model, and de-normalised. The "vocals" stem is
    saved on its own; all remaining stems are summed into one instrumental
    track.

    Args:
        audio_path: Path of the audio file to separate.
        model_name: Name of the pretrained Demucs model to load.

    Returns:
        (vocals_path, instruments_path): paths of the two WAV files written to
        OUTPUT_DIR. File names derive from the input's base name, so two inputs
        sharing a base name overwrite each other's outputs.

    Raises:
        ValueError: if the audio contains no samples, or the model has no
            "vocals" stem.
    """
    # Imported inside the function, so importing this module does not itself
    # require torchaudio or demucs.
    import torchaudio
    from demucs.pretrained import get_model
    from demucs.apply import apply_model

    os.makedirs(OUTPUT_DIR, exist_ok=True)

    logger.info("Loading Demucs model '%s'...", model_name)
    model = get_model(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    logger.info("Loading audio: %s", audio_path)
    waveform, sr = torchaudio.load(audio_path)

    # Resample if needed
    if sr != model.samplerate:
        waveform = torchaudio.transforms.Resample(sr, model.samplerate)(waveform)
        sr = model.samplerate

    if waveform.numel() == 0:
        # Mean/std of an empty signal are NaN; fail with a clear message instead.
        raise ValueError(f"Audio file contains no samples: {audio_path}")

    # Ensure stereo
    waveform = _to_stereo(waveform)

    # Apply model on the normalised mix, then undo the normalisation.
    logger.info("Separating audio...")
    mean, std = _normalization_stats(waveform)
    normalized = (waveform - mean) / std
    sources = apply_model(
        model,
        normalized[None].to(device),
        device=device,
        progress=True,
        num_workers=0,
    )
    # Drop the batch dimension and move to CPU once, rather than once per stem.
    sources = sources[0].cpu()
    sources = sources * std + mean

    # Stems are looked up by name, so their order in the model does not matter.
    source_names = model.sources
    vocals_idx = source_names.index("vocals")
    vocals = sources[vocals_idx]

    # Instruments = everything except vocals
    instruments = torch.zeros_like(vocals)
    for i, name in enumerate(source_names):
        if name != "vocals":
            instruments += sources[i]

    # Save outputs
    base_name = os.path.splitext(os.path.basename(audio_path))[0]
    vocals_path = os.path.join(OUTPUT_DIR, f"{base_name}_vocals.wav")
    instruments_path = os.path.join(OUTPUT_DIR, f"{base_name}_instruments.wav")
    torchaudio.save(vocals_path, vocals, sr)
    torchaudio.save(instruments_path, instruments, sr)

    logger.info(
        "Separation complete. Vocals: %s, Instruments: %s",
        vocals_path,
        instruments_path,
    )
    return vocals_path, instruments_path
|