""" Audio separation module: uses Demucs to separate vocals from instruments. """ import os import logging import torch logger = logging.getLogger(__name__) try: import spaces except ImportError: class spaces: @staticmethod def GPU(duration=60, **kwargs): def decorator(fn): return fn return decorator OUTPUT_DIR = "/tmp/demucs_output" @spaces.GPU(duration=60) def separate_audio(audio_path: str, model_name: str = "htdemucs_ft"): """ Separate audio into vocals and instruments using Demucs. Returns (vocals_path, instruments_path). """ import torchaudio from demucs.pretrained import get_model from demucs.apply import apply_model os.makedirs(OUTPUT_DIR, exist_ok=True) logger.info(f"Loading Demucs model '{model_name}'...") model = get_model(model_name) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) logger.info(f"Loading audio: {audio_path}") waveform, sr = torchaudio.load(audio_path) # Resample if needed if sr != model.samplerate: resampler = torchaudio.transforms.Resample(sr, model.samplerate) waveform = resampler(waveform) sr = model.samplerate # Ensure stereo if waveform.shape[0] == 1: waveform = waveform.repeat(2, 1) elif waveform.shape[0] > 2: waveform = waveform[:2] # Apply model logger.info("Separating audio...") ref = waveform.mean(0) std = ref.std() if std < 1e-6: std = torch.tensor(1e-6) waveform = (waveform - ref.mean()) / std sources = apply_model( model, waveform[None].to(device), device=device, progress=True, num_workers=0, ) sources = sources * std + ref.mean() sources = sources[0] # Remove batch dimension # Demucs sources order: drums, bass, other, vocals source_names = model.sources vocals_idx = source_names.index("vocals") vocals = sources[vocals_idx].cpu() # Instruments = everything except vocals instruments = torch.zeros_like(vocals) for i, name in enumerate(source_names): if name != "vocals": instruments += sources[i].cpu() # Save outputs base_name = os.path.splitext(os.path.basename(audio_path))[0] vocals_path = os.path.join(OUTPUT_DIR, f"{base_name}_vocals.wav") instruments_path = os.path.join(OUTPUT_DIR, f"{base_name}_instruments.wav") torchaudio.save(vocals_path, vocals, sr) torchaudio.save(instruments_path, instruments, sr) logger.info(f"Separation complete. Vocals: {vocals_path}, Instruments: {instruments_path}") return vocals_path, instruments_path