Spaces:
Running on Zero
Running on Zero
test(post): rewrite separate_stems test for demucs.apply.apply_model path
Browse files- tests/test_post_process.py +35 -10
tests/test_post_process.py
CHANGED
|
@@ -9,19 +9,44 @@ import post_process as pp
|
|
| 9 |
|
| 10 |
|
| 11 |
def test_separate_stems_returns_four_paths(tmp_path, monkeypatch):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
src = tmp_path / "song.wav"
|
| 13 |
src.write_bytes(b"RIFF" + b"\0" * 100)
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
monkeypatch.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
stems = pp.separate_stems(src)
|
| 27 |
assert set(stems.keys()) == {"vocals", "drums", "bass", "other"}
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
def test_separate_stems_returns_four_paths(tmp_path, monkeypatch):
|
| 12 |
+
"""Mocks the lower-level demucs path used by post_process.separate_stems:
|
| 13 |
+
torchaudio.load → apply_model → sf.write. The wrapper convenience API
|
| 14 |
+
(Separator.separate_audio_file) is intentionally not in this code path
|
| 15 |
+
because it only ships with demucs >= 4.1."""
|
| 16 |
+
import sys
|
| 17 |
+
import types
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
src = tmp_path / "song.wav"
|
| 23 |
src.write_bytes(b"RIFF" + b"\0" * 100)
|
| 24 |
|
| 25 |
+
fake_model = MagicMock()
|
| 26 |
+
fake_model.samplerate = 44100
|
| 27 |
+
fake_model.sources = ["drums", "bass", "other", "vocals"]
|
| 28 |
+
fake_model.audio_channels = 2
|
| 29 |
+
monkeypatch.setattr(pp, "_get_demucs", lambda: fake_model)
|
| 30 |
+
|
| 31 |
+
fake_torchaudio = types.ModuleType("torchaudio")
|
| 32 |
+
fake_torchaudio.load = lambda _path: (torch.zeros((2, 44100)), 44100)
|
| 33 |
+
fake_torchaudio.functional = types.SimpleNamespace(resample=lambda wav, _sr_in, _sr_out: wav)
|
| 34 |
+
monkeypatch.setitem(sys.modules, "torchaudio", fake_torchaudio)
|
| 35 |
+
|
| 36 |
+
fake_demucs_apply = types.ModuleType("demucs.apply")
|
| 37 |
+
fake_demucs_apply.apply_model = lambda _m, batch, **_kw: torch.zeros((batch.shape[0], 4, 2, 44100))
|
| 38 |
+
monkeypatch.setitem(sys.modules, "demucs.apply", fake_demucs_apply)
|
| 39 |
+
|
| 40 |
+
written: list[str] = []
|
| 41 |
+
|
| 42 |
+
def fake_sf_write(path, _data, _sr):
|
| 43 |
+
written.append(path)
|
| 44 |
+
Path(path).write_bytes(b"RIFF" + b"\0" * 100)
|
| 45 |
+
|
| 46 |
+
fake_sf = types.ModuleType("soundfile")
|
| 47 |
+
fake_sf.write = fake_sf_write
|
| 48 |
+
fake_sf.read = lambda _path: (np.zeros((44100, 2)), 44100)
|
| 49 |
+
monkeypatch.setitem(sys.modules, "soundfile", fake_sf)
|
| 50 |
|
| 51 |
stems = pp.separate_stems(src)
|
| 52 |
assert set(stems.keys()) == {"vocals", "drums", "bass", "other"}
|