import os
import tempfile
import time
import warnings

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
import torchaudio
from demucs.apply import apply_model
from demucs.pretrained import get_model
# Suppress every Python warning process-wide (this also hides warnings raised
# by the libraries imported above).
warnings.filterwarnings("ignore")


# --- Device selection -------------------------------------------------------
print("Setting up models...")
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")


# --- HT-Demucs: loaded once at import time and shared by every request ------
print("Loading HT-Demucs model...")
htdemucs_model = get_model(name="htdemucs")
htdemucs_model = htdemucs_model.to(device)
htdemucs_model.eval()  # inference only
print("HT-Demucs model loaded successfully.")


# --- Spleeter: module-level state filled in by setup_spleeter_with_retry() --
print("Setting up Spleeter...")
spleeter_separator = None       # spleeter Separator instance once loaded
spleeter_audio_adapter = None   # spleeter AudioAdapter used to read audio
spleeter_available = False      # True only after a successful setup


def patch_spleeter_redirects():
    """Monkey-patch Spleeter's GitHub model provider to follow HTTP redirects.

    Replaces ``GithubModelProvider.download`` with a wrapper that fetches the
    models listed in its URL table through ``httpx`` with
    ``follow_redirects=True``. Any model not in that table, and any failure of
    the patched download, is delegated to the original ``download`` method.

    Returns:
        bool: True when the patch is installed (or was already installed),
        False when ``httpx``/Spleeter cannot be imported or patching fails.
        The function never raises.
    """
    try:
        # Imported lazily so a missing optional dependency only disables the
        # patch instead of breaking module import.
        import httpx
        from spleeter.model.provider.github import GithubModelProvider

        original_download = GithubModelProvider.download

        # Do not wrap the wrapper again if this function is called twice.
        if getattr(original_download, "_follows_redirects", False):
            return True

        model_urls = {
            '5stems': 'https://github.com/deezer/spleeter/releases/download/v1.4.0/5stems.tar.gz'
        }

        def patched_download(self, name, model_directory):
            """Download and extract model *name* into *model_directory*."""
            import os
            import tarfile
            import tempfile

            print(f"Downloading {name} model with redirect handling...")

            url = model_urls.get(name)
            if url is None:
                return original_download(self, name, model_directory)

            tmp_file_path = None
            try:
                # Stream to disk so the whole archive is never held in memory.
                with tempfile.NamedTemporaryFile(delete=False, suffix='.tar.gz') as tmp_file:
                    tmp_file_path = tmp_file.name
                    downloaded = 0
                    with httpx.Client(follow_redirects=True, timeout=300) as client:
                        print(f"Downloading from: {url}")
                        with client.stream("GET", url) as response:
                            response.raise_for_status()
                            for chunk in response.iter_bytes():
                                tmp_file.write(chunk)
                                downloaded += len(chunk)

                print(f"Downloaded {downloaded} bytes")

                os.makedirs(model_directory, exist_ok=True)
                with tarfile.open(tmp_file_path, 'r:gz') as tar:
                    _extract_tar_safely(tar, model_directory)

                print(f"✅ Successfully downloaded and extracted {name} model")

            except Exception as e:
                print(f"❌ Failed to download {name} model: {e}")
                # Fall back to Spleeter's own downloader.
                return original_download(self, name, model_directory)
            finally:
                # Always remove the temporary archive, on success and failure.
                if tmp_file_path is not None and os.path.exists(tmp_file_path):
                    os.unlink(tmp_file_path)

        patched_download._follows_redirects = True
        GithubModelProvider.download = patched_download
        print("✅ Patched Spleeter to handle GitHub redirects")
        return True

    except Exception as e:
        print(f"⚠️ Could not patch Spleeter redirects: {e}")
        return False


def _extract_tar_safely(tar, destination):
    """Extract an open ``TarFile`` into *destination*, rejecting unsafe members."""
    import os
    import tarfile

    # Interpreters that ship extraction filters can enforce safety themselves.
    if hasattr(tarfile, "data_filter"):
        tar.extractall(destination, filter="data")
        return

    # Older interpreters: refuse links and any path resolving outside the root.
    root = os.path.realpath(destination)
    for member in tar.getmembers():
        target = os.path.realpath(os.path.join(root, member.name))
        if os.path.commonpath([root, target]) != root:
            raise RuntimeError(f"Unsafe path in archive: {member.name}")
        if member.issym() or member.islnk():
            raise RuntimeError(f"Link members are not allowed in archive: {member.name}")
    tar.extractall(destination)


def setup_spleeter_with_retry():
    """Load the Spleeter 5stems separator into the module-level state.

    Despite the name, a single attempt is made; no retry is performed. On
    success the globals ``spleeter_separator``, ``spleeter_audio_adapter`` and
    ``spleeter_available`` are populated. On any failure (Spleeter not
    installed, model download error, ...) they are reset to their "unavailable"
    values and the error is printed rather than raised.

    Returns:
        bool: True if the separator was created, False otherwise.
    """
    global spleeter_separator, spleeter_audio_adapter, spleeter_available

    try:
        # Imported lazily so the app still starts without Spleeter installed.
        from spleeter.separator import Separator
        from spleeter.audio.adapter import AudioAdapter

        # Best effort: setup continues even if the patch cannot be applied.
        patch_spleeter_redirects()

        # NOTE(review): Spleeter's model provider appears to read MODEL_PATH,
        # not SPLEETER_MODEL_PATH -- confirm this variable has any effect.
        os.environ['SPLEETER_MODEL_PATH'] = '/tmp/spleeter_models'

        print("Creating Spleeter 5stems separator...")
        spleeter_separator = Separator('spleeter:5stems')
        spleeter_audio_adapter = AudioAdapter.default()
        spleeter_available = True
        print("✅ Spleeter 5stems model loaded successfully!")
        return True

    except Exception as e:
        print(f"❌ Failed to load Spleeter 5stems: {e}")
        spleeter_separator = None
        spleeter_audio_adapter = None
        spleeter_available = False
        return False


# Try to load Spleeter once at import time. The result is recorded in the
# module-level `spleeter_available` flag, which the UI below reads.
setup_spleeter_with_retry()


def separate_with_htdemucs(audio_path):
    """Separate an audio file into drums, bass, other and vocals with HT-Demucs.

    The audio is converted to stereo and resampled to the model's sample rate
    (when the model exposes one) before separation. Each stem is written as a
    WAV file into a fresh ``htdemucs_stems_<timestamp>`` directory under the
    current working directory.

    Args:
        audio_path: Path of the input audio file, or None.

    Returns:
        tuple: ``(drums, bass, other, vocals, status)``. The first four items
        are output file paths; on missing input or any error they are all None
        and ``status`` carries the message. The function never raises.
    """
    if audio_path is None:
        return None, None, None, None, "Please upload an audio file."

    try:
        print(f"HT-Demucs: Loading audio from: {audio_path}")
        wav, sr = torchaudio.load(audio_path)

        # The model is fed two channels: duplicate mono, truncate surround.
        if wav.shape[0] == 1:
            print("Audio is mono, converting to stereo.")
            wav = wav.repeat(2, 1)
        elif wav.shape[0] > 2:
            print("Audio has more than two channels, keeping the first two.")
            wav = wav[:2]

        # Separating at the wrong rate degrades the output, so match the rate
        # the model was trained at and save the stems at that same rate.
        model_sr = getattr(htdemucs_model, "samplerate", sr)
        if sr != model_sr:
            print(f"HT-Demucs: Resampling from {sr} Hz to {model_sr} Hz.")
            wav = torchaudio.functional.resample(wav, sr, model_sr)
            sr = model_sr

        wav = wav.to(device)

        print("HT-Demucs: Applying the separation model...")
        with torch.no_grad():
            # wav[None] adds the batch dimension; [0] removes it again.
            sources = apply_model(htdemucs_model, wav[None], device=device, progress=True)[0]
        print("HT-Demucs: Separation complete.")

        timestamp = int(time.time() * 1000)
        output_dir = f"htdemucs_stems_{timestamp}"
        os.makedirs(output_dir, exist_ok=True)

        # Order expected by the caller. Indices are looked up in the model's
        # own source list when it has one, instead of assuming its order.
        stem_names = ["drums", "bass", "other", "vocals"]
        model_sources = list(getattr(htdemucs_model, "sources", stem_names))

        output_paths = []
        for name in stem_names:
            out_path = os.path.join(output_dir, f"{name}_{timestamp}.wav")
            torchaudio.save(out_path, sources[model_sources.index(name)].cpu(), sr)
            output_paths.append(out_path)
            print(f"✅ HT-Demucs saved {name} to {out_path}")

        return (*output_paths, "✅ HT-Demucs separation successful!")

    except Exception as e:
        print(f"HT-Demucs Error: {e}")
        return None, None, None, None, f"❌ HT-Demucs Error: {str(e)}"


def separate_with_spleeter(audio_path):
    """Separate an audio file into vocals, drums, bass, other and piano with Spleeter.

    Uses the Spleeter Python API through the module-level separator and audio
    adapter (the original notes credit ``stem_separation_spleeter.py`` for the
    approach). Stems are written as WAV files into a fresh
    ``spleeter_stems_<timestamp>`` directory under the current working
    directory.

    Args:
        audio_path: Path of the input audio file, or None.

    Returns:
        tuple: ``(vocals, drums, bass, other, piano, status)``. Each stem is a
        file path, or None when the model did not produce it. On missing
        input, unavailable Spleeter or any error all five paths are None and
        ``status`` carries the message. The function never raises.
    """
    if audio_path is None:
        return None, None, None, None, None, "Please upload an audio file."

    if not spleeter_available or spleeter_separator is None or spleeter_audio_adapter is None:
        return None, None, None, None, None, "❌ Spleeter not available. Please install Spleeter."

    try:
        print(f"Spleeter: Processing audio from: {audio_path}")
        timestamp = int(time.time() * 1000)

        print("Spleeter: Loading audio...")
        waveform, sample_rate = spleeter_audio_adapter.load(audio_path, sample_rate=44100)
        print(f"Spleeter: Loaded audio - shape: {waveform.shape}, sr: {sample_rate}")

        print("Spleeter: Separating audio sources...")
        prediction = spleeter_separator.separate(waveform)
        print("Spleeter: Separation complete.")
        print(f"Spleeter: Prediction keys: {list(prediction.keys())}")

        # Created only once there is something to write, so a failed load or
        # separation does not leave an empty directory behind.
        output_dir = f"spleeter_stems_{timestamp}"
        os.makedirs(output_dir, exist_ok=True)

        # One slot per stem, in the order the caller expects.
        stem_names = ["vocals", "drums", "bass", "other", "piano"]
        output_paths = []
        for stem_name in stem_names:
            stem_audio = prediction.get(stem_name)
            if stem_audio is None:
                print(f"⚠️ Warning: {stem_name} not found in prediction")
                output_paths.append(None)
                continue

            out_path = os.path.join(output_dir, f"{stem_name}_{timestamp}.wav")
            print(f"Spleeter: {stem_name} audio shape: {stem_audio.shape}, dtype: {stem_audio.dtype}")
            sf.write(out_path, stem_audio, sample_rate)
            output_paths.append(out_path)
            print(f"✅ Spleeter saved {stem_name} to {out_path}")

        return (*output_paths, "✅ Spleeter separation successful!")

    except Exception as e:
        print(f"Spleeter Error: {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, None, None, f"❌ Spleeter Error: {str(e)}"


def separate_selected_models(audio_path, run_htdemucs, run_spleeter):
    """Run the selected separation models and collect their stems for the UI.

    Args:
        audio_path: Path of the input audio file, or None.
        run_htdemucs: Whether to run HT-Demucs.
        run_spleeter: Whether to run Spleeter.

    Returns:
        list: Always exactly 10 items, matching the UI's output components:
        4 HT-Demucs stems (drums, bass, other, vocals), 5 Spleeter stems
        (vocals, drums, bass, other, piano), then a status string. Stems of a
        model that was not run, or that failed, are None.
    """
    htdemucs_stem_count = 4
    spleeter_stem_count = 5

    def failure(message):
        # Error paths must have the same flat 10-item shape as the success
        # path; a (list, message) pair would not match the UI outputs.
        return [None] * (htdemucs_stem_count + spleeter_stem_count) + [message]

    if audio_path is None:
        return failure("Please upload an audio file.")

    if not run_htdemucs and not run_spleeter:
        return failure("❌ Please select at least one model to run.")

    try:
        # Placeholders (stems + status slot) for models that are not run.
        htdemucs_results = [None] * (htdemucs_stem_count + 1)
        spleeter_results = [None] * (spleeter_stem_count + 1)
        status_messages = []
        models_used = []

        if run_htdemucs:
            print("Running HT-Demucs...")
            htdemucs_results = separate_with_htdemucs(audio_path)
            status_messages.append(htdemucs_results[-1])
            models_used.append("HT-Demucs")

        if run_spleeter:
            print("Running Spleeter...")
            spleeter_results = separate_with_spleeter(audio_path)
            status_messages.append(spleeter_results[-1])
            models_used.append("Spleeter")

        # Drop each model's trailing status; one combined status is appended.
        all_results = list(htdemucs_results[:-1]) + list(spleeter_results[:-1])

        combined_status = f"🎵 {' + '.join(models_used)} completed!\n\n" + "\n".join(status_messages)

        return all_results + [combined_status]

    except Exception as e:
        print(f"Combined Error: {e}")
        import traceback
        traceback.print_exc()
        return failure(f"❌ Error: {str(e)}")


# --- Gradio interface ---------------------------------------------------------
# NOTE(review): this section was recovered from text whose indentation and
# emoji were damaged. The layout nesting is a best-effort reconstruction. The
# glyphs on the button, status box, comparison heading, quality row and footer
# had unrecoverable bytes and are best guesses -- confirm against the original.
print("Creating Gradio interface...")

header_markdown = """
# 🎵 Spleeter & Demucs - Now Both Work!

**Follow me on:** [ Hugging Face @ahk-d](https://huggingface.co/ahk-d) | [ GitHub @ahk-d](https://github.com/ahk-d)
"""

# Spleeter cells depend on whether the 5stems model loaded at startup.
comparison_text = f"""
### 📊 Model Comparison

| Feature | HT-Demucs | Spleeter 2025 (5stems) |
|---------|-----------|----------|
| **Vocals** | ✅ High Quality | {'✅ Available' if spleeter_available else '❌ N/A'} |
| **Drums** | ✅ High Quality | {'✅ Available' if spleeter_available else '❌ N/A'} |
| **Bass** | ✅ High Quality | {'✅ Available' if spleeter_available else '❌ N/A'} |
| **Other** | ✅ High Quality | {'✅ Available' if spleeter_available else '❌ N/A'} |
| **Piano** | ❌ Not Available | {'✅ **Available**' if spleeter_available else '❌ N/A'} |
| **Speed** | ⚡ Fast | {'⚡ Fast' if spleeter_available else '❌ N/A'} |
| **Quality** | 🌟 Excellent | {'👍 Good' if spleeter_available else '❌ N/A'} |

**💡 Tip:** Use Spleeter 2025 for piano separation, HT-Demucs for other instruments!
"""

footer_markdown = """
---
<p style='text-align: center; font-size: small;'>
🚀 Powered by <strong>HT-Demucs</strong> & <strong>Spleeter 2025</strong> |
🎵 Compare and choose your best stems!
</p>
"""

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(header_markdown)

    # Input panel: upload, model selection, run button and status.
    with gr.Row():
        with gr.Column(scale=1):
            audio_input = gr.Audio(type="filepath", label="🎵 Upload Your Song")

            gr.Markdown("### 🎛️ Select Models to Run")
            with gr.Row():
                htdemucs_toggle = gr.Checkbox(label="🎯 HT-Demucs", value=True, info="Drums, Bass, Other, Vocals")
                spleeter_label = "🎵 Spleeter 2025 (5stems)" if spleeter_available else "🎵 Spleeter 2025"
                spleeter_info = "Vocals, Drums, Bass, Other, Piano" if spleeter_available else "5stems model not available"
                # Disabled and unchecked when Spleeter failed to load.
                spleeter_toggle = gr.Checkbox(
                    label=spleeter_label,
                    value=spleeter_available,
                    info=spleeter_info,
                    interactive=spleeter_available
                )

            separate_button = gr.Button("🚀 Separate Music", variant="primary", size="lg")
            status_output = gr.Textbox(label="📊 Status", interactive=False, lines=4)

    gr.Markdown("---")

    # Result panels, one column per model.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 🎯 HT-Demucs Results")
            with gr.Row():
                htdemucs_drums = gr.Audio(label="🥁 Drums", type="filepath")
                htdemucs_bass = gr.Audio(label="🎸 Bass", type="filepath")
            with gr.Row():
                htdemucs_other = gr.Audio(label="🎼 Other", type="filepath")
                htdemucs_vocals = gr.Audio(label="🎤 Vocals", type="filepath")

        with gr.Column():
            gr.Markdown("### 🎵 Spleeter 2025 Results")
            with gr.Row():
                spleeter_vocals = gr.Audio(label="🎤 Vocals", type="filepath")
                spleeter_drums = gr.Audio(label="🥁 Drums", type="filepath")
            with gr.Row():
                spleeter_bass = gr.Audio(label="🎸 Bass", type="filepath")
                spleeter_other = gr.Audio(label="🎼 Other", type="filepath")
            with gr.Row():
                spleeter_piano = gr.Audio(label="🎹 Piano", type="filepath")

            if spleeter_available:
                gr.Markdown("*5stems model: Vocals, Drums, Bass, Other, Piano*")
            else:
                gr.Markdown("*Note: Spleeter 5stems model not available*")

    gr.Markdown("---")

    with gr.Row():
        gr.Markdown(comparison_text)

    # The outputs must stay in the exact order separate_selected_models
    # returns them: 4 HT-Demucs stems, 5 Spleeter stems, then the status text.
    separate_button.click(
        fn=separate_selected_models,
        inputs=[audio_input, htdemucs_toggle, spleeter_toggle],
        outputs=[
            htdemucs_drums, htdemucs_bass, htdemucs_other, htdemucs_vocals,
            spleeter_vocals, spleeter_drums, spleeter_bass, spleeter_other, spleeter_piano,
            status_output
        ]
    )

    gr.Markdown(footer_markdown)


if __name__ == "__main__":
    # share=True asks Gradio for a public share link in addition to localhost.
    demo.launch(share=True)