| import torch |
| import torchaudio |
|
|
| |
# Best-effort backend selection: torchaudio 2.x deprecated/removed
# set_audio_backend, so failures here are expected and non-fatal.
try:
    torchaudio.set_audio_backend("soundfile")
except Exception:
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
    # are no longer swallowed during import.
    pass
|
|
| from demucs.apply import apply_model |
| from demucs.pretrained import get_model |
| import os |
| import pathlib |
|
|
| |
| |
# Checkpoint served for each user-facing separation mode. The 2- and 4-stem
# modes share the standard Hybrid Transformer Demucs; 6-stem needs the
# extended checkpoint that also emits guitar and piano.
MODELS = {
    mode: "htdemucs_6s" if mode == "6stem" else "htdemucs"
    for mode in ("2stem", "4stem", "6stem")
}
|
|
class AudioSeparator:
    """Split audio files into stems (vocals, drums, bass, ...) with Demucs.

    Every distinct checkpoint referenced by MODELS is loaded once at
    construction time and kept on the best available device (CUDA when
    present, else CPU).
    """

    def __init__(self):
        # One entry per distinct checkpoint; several modes may share a model.
        self.models = {}
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        for model_name in set(MODELS.values()):
            print(f"Loading Demucs Model: {model_name}...")
            model = get_model(model_name)
            model.to(self.device)
            self.models[model_name] = model
        print(f"All models loaded on {self.device}")

    def separate(self, audio_path: str, output_dir: str, callback=None, mode="4stem"):
        """Separate ``audio_path`` into stem files written under ``output_dir``.

        Modes: 2stem (vocals / instruments), 4stem, 6stem.

        Args:
            audio_path: Input audio file (any format soundfile can read).
            output_dir: Destination directory (created if missing).
            callback: Optional ``callback(message: str, percent: int)``
                progress hook.
            mode: "2stem", "4stem" or "6stem"; unknown values fall back to
                the 4-stem "htdemucs" model.

        Returns:
            ``(results, input_duration)`` — a dict mapping stem name to the
            written file path, and the input length in seconds (measured at
            the ORIGINAL sample rate).
        """
        model_name = MODELS.get(mode, "htdemucs")
        model = self.models[model_name]

        # Local import keeps soundfile optional until separation is requested.
        import soundfile as sf
        wav_np, sr = sf.read(audio_path)

        # BUG FIX: compute the duration against the original sample rate,
        # BEFORE `sr` is overwritten by the resampling step below. The old
        # code divided the original sample count by 44100 after resampling.
        input_duration = wav_np.shape[0] / sr

        wav = torch.from_numpy(wav_np).float()
        if wav.ndim == 1:
            wav = wav.unsqueeze(0)   # mono -> (1, samples)
        else:
            wav = wav.t()            # (samples, channels) -> (channels, samples)

        # Demucs checkpoints are trained at 44.1 kHz.
        if sr != 44100:
            if callback: callback("Resampling audio...", 15)
            resampler = torchaudio.transforms.Resample(sr, 44100)
            wav = resampler(wav)
            sr = 44100

        wav = wav.unsqueeze(0).to(self.device)  # add batch dim: (1, ch, samples)

        # Normalize to zero mean / unit std, as the pretrained models expect.
        ref = wav.mean(0)
        mean = ref.mean()
        std = ref.std()
        # Guard against digital silence (std == 0), which produced NaNs before.
        if std > 0:
            wav = (wav - mean) / std

        if callback: callback("Running Demucs Inference...", 20)
        print(f"Starting separation with {model_name} (mode: {mode})...")

        sources = apply_model(model, wav, shifts=1, split=True, overlap=0.25, progress=True)[0]
        # BUG FIX: undo the input normalization so the stems come back at the
        # original loudness; the old code skipped this step entirely.
        if std > 0:
            sources = sources * std + mean
        source_names = model.sources

        results = {}
        os.makedirs(output_dir, exist_ok=True)

        total_sources = len(source_names)
        source_tensors = {name: source for name, source in zip(source_names, sources)}

        if mode == "2stem":
            # Vocals vs. everything-else: sum all non-vocal stems together.
            if callback: callback("Merging to 2 stems...", 45)

            vocals = source_tensors.get('vocals')
            instruments = None
            for name, src in source_tensors.items():
                if name != 'vocals':
                    if instruments is None:
                        instruments = src.clone()
                    else:
                        instruments += src

            if vocals is not None:
                self._save_audio(vocals, sr, os.path.join(output_dir, "vocals.mp3"))
                results['vocals'] = os.path.join(output_dir, "vocals.mp3")
            if instruments is not None:
                self._save_audio(instruments, sr, os.path.join(output_dir, "instruments.mp3"))
                results['instruments'] = os.path.join(output_dir, "instruments.mp3")

        elif mode == "6stem":
            for i, (name, source) in enumerate(source_tensors.items()):
                progress = 30 + int((i / total_sources) * 20)
                if callback: callback(f"Saving stem: {name}", progress)

                if name == 'guitar':
                    # Guitar gets extra lead/rhythm post-processing.
                    results.update(self._process_guitar(source, sr, output_dir))
                else:
                    stem_path = os.path.join(output_dir, f"{name}.mp3")
                    self._save_audio(source, sr, stem_path)
                    results[name] = stem_path
        else:
            # Default 4-stem path: one file per model source.
            for i, (name, source) in enumerate(source_tensors.items()):
                progress = 30 + int((i / total_sources) * 20)
                if callback: callback(f"Saving stem: {name}", progress)

                stem_path = os.path.join(output_dir, f"{name}.mp3")
                self._save_audio(source, sr, stem_path)
                results[name] = stem_path

        return results, input_duration

    def _process_guitar(self, source, sr, output_dir):
        """Split the guitar stem into Rhythm and Lead via Mid-Side processing.

        - Mid (center) = Rhythm (power chords / strumming usually sit center).
        - Side (stereo difference) = Lead (usually panned or stereo-effected).

        Two strategies are used:
        - Wide stereo (low L/R correlation): spatial split, L -> rhythm,
          R -> lead.
        - Narrow/near-mono: Mid-Side plus band-pass split (rhythm from the
          low-mid body, lead from the upper mids plus the side signal).

        Returns:
            ``{"guitar": path}`` where the file is 2-channel with channel 0 =
            rhythm and channel 1 = lead.
        """
        # Work on the CPU from here on: scipy filtering and `.numpy()` need
        # host memory (BUG FIX: `.numpy()` raises on CUDA tensors).
        source = source.detach().cpu()

        if source.shape[0] < 2:
            print("Warning: Guitar stem is Mono. Cannot split Rhythm/Lead.")
            path = os.path.join(output_dir, "guitar.mp3")
            self._save_audio(source, sr, path)
            # BUG FIX: return the same key as the stereo path so callers see
            # a consistent result shape (was guitar_rhythm/guitar_lead).
            return {"guitar": path}

        left = source[0:1, :]
        right = source[1:2, :]

        # Pearson correlation between channels measures how "wide" the mix is.
        mean_l = left.mean()
        mean_r = right.mean()
        var_l = ((left - mean_l) ** 2).mean()
        var_r = ((right - mean_r) ** 2).mean()
        cov = ((left - mean_l) * (right - mean_r)).mean()

        correlation = 0.0
        if var_l > 0 and var_r > 0:
            correlation = cov / torch.sqrt(var_l * var_r)

        print(f"Guitar Stereo Correlation: {correlation:.4f}")

        if abs(correlation) < 0.6:
            print("Detected Wide Stereo Guitar (Math Rock Style). Using Spatial Split (L=Rhythm, R=Lead).")
            rhythm_stereo = torch.cat([left, left], dim=0)
            lead_stereo = torch.cat([right, right], dim=0)
        else:
            print("Detected Narrow/Mono Guitar. Using Mid-Side Frequency Split.")
            mid = (left + right) / 2.0
            side = (left - right) / 2.0

            try:
                import scipy.signal as signal
                nyquist = sr / 2

                # Rhythm: low-mid body of the guitar (80 Hz - 1.2 kHz).
                rhythm_low = 80 / nyquist
                rhythm_high = 1200 / nyquist
                b_r, a_r = signal.butter(4, [rhythm_low, rhythm_high], btype='band')

                # Lead: upper mids / presence (1 kHz - 8 kHz).
                lead_low = 1000 / nyquist
                lead_high = 8000 / nyquist
                b_l, a_l = signal.butter(4, [lead_low, lead_high], btype='band')

                rhythm_from_mid = signal.filtfilt(b_r, a_r, mid.numpy())
                lead_from_mid = signal.filtfilt(b_l, a_l, mid.numpy())

                side_np = side.numpy()
                rhythm_final = rhythm_from_mid
                # Lead lines tend to be panned / stereo-effected, so fold the
                # boosted side content into the lead.
                lead_final = lead_from_mid + (side_np * 1.5)

                # BUG FIX: filtfilt returns negatively-strided arrays, which
                # torch.from_numpy rejects — copy to a contiguous buffer.
                # (Previously this always raised and silently fell into the
                # "Filter failed" fallback below.)
                rhythm_stereo = torch.from_numpy(rhythm_final.copy()).float()
                rhythm_stereo = torch.cat([rhythm_stereo, rhythm_stereo], dim=0)

                lead_stereo = torch.from_numpy(lead_final.copy()).float()
                lead_stereo = torch.cat([lead_stereo, lead_stereo], dim=0)

            except Exception as e:
                print(f"Filter failed: {e}. Fallback to raw.")
                rhythm_stereo = torch.cat([left, left], dim=0)
                lead_stereo = torch.cat([right, right], dim=0)

        def normalize(tensor):
            # Peak-normalize to 0.89 linear (~ -1 dBFS of headroom).
            peak = tensor.abs().max()
            if peak > 0:
                target_peak = 0.89
                return tensor * (target_peak / peak)
            return tensor

        rhythm_stereo = normalize(rhythm_stereo)
        lead_stereo = normalize(lead_stereo)

        # Pack the split into ONE stereo file: channel 0 = rhythm,
        # channel 1 = lead.
        rhythm_mono = rhythm_stereo.mean(dim=0, keepdim=True)
        lead_mono = lead_stereo.mean(dim=0, keepdim=True)

        guitar_split = torch.cat([rhythm_mono, lead_mono], dim=0)
        guitar_split = normalize(guitar_split)

        path = os.path.join(output_dir, "guitar.mp3")
        self._save_audio(guitar_split, sr, path)

        return {
            "guitar": path
        }

    def _save_audio(self, source, sr, path):
        """Write a (channels, samples) tensor to ``path`` via soundfile."""
        source = source.detach().cpu()

        # Soft limit: keep peaks at 0.89 to avoid clipping on lossy export.
        peak = source.abs().max()
        if peak > 0.89:
            source = source / peak * 0.89

        import soundfile as sf
        # soundfile expects (samples, channels), hence the transpose.
        sf.write(path, source.t().numpy(), sr)
|
|