techfreakworm commited on
Commit
fc8c46f
·
unverified ·
1 Parent(s): db12e72

test(post): rewrite separate_stems test for demucs.apply.apply_model path

Browse files
Files changed (1) hide show
  1. tests/test_post_process.py +35 -10
tests/test_post_process.py CHANGED
@@ -9,19 +9,44 @@ import post_process as pp
9
 
10
 
11
  def test_separate_stems_returns_four_paths(tmp_path, monkeypatch):
 
 
 
 
 
 
 
 
 
 
12
  src = tmp_path / "song.wav"
13
  src.write_bytes(b"RIFF" + b"\0" * 100)
14
 
15
- fake_sep = MagicMock()
16
- fake_sep.separate_audio_file.return_value = {
17
- "vocals": tmp_path / "vocals.wav",
18
- "drums": tmp_path / "drums.wav",
19
- "bass": tmp_path / "bass.wav",
20
- "other": tmp_path / "other.wav",
21
- }
22
- for k in ("vocals", "drums", "bass", "other"):
23
- (tmp_path / f"{k}.wav").write_bytes(b"RIFF" + b"\0" * 100)
24
- monkeypatch.setattr(pp, "_get_demucs", lambda: fake_sep)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  stems = pp.separate_stems(src)
27
  assert set(stems.keys()) == {"vocals", "drums", "bass", "other"}
 
9
 
10
 
11
  def test_separate_stems_returns_four_paths(tmp_path, monkeypatch):
12
+ """Mocks the lower-level demucs path used by post_process.separate_stems:
13
+ torchaudio.load → apply_model → sf.write. The wrapper convenience API
14
+ (Separator.separate_audio_file) is intentionally not in this code path
15
+ because it only ships with demucs >= 4.1."""
16
+ import sys
17
+ import types
18
+
19
+ import numpy as np
20
+ import torch
21
+
22
  src = tmp_path / "song.wav"
23
  src.write_bytes(b"RIFF" + b"\0" * 100)
24
 
25
+ fake_model = MagicMock()
26
+ fake_model.samplerate = 44100
27
+ fake_model.sources = ["drums", "bass", "other", "vocals"]
28
+ fake_model.audio_channels = 2
29
+ monkeypatch.setattr(pp, "_get_demucs", lambda: fake_model)
30
+
31
+ fake_torchaudio = types.ModuleType("torchaudio")
32
+ fake_torchaudio.load = lambda _path: (torch.zeros((2, 44100)), 44100)
33
+ fake_torchaudio.functional = types.SimpleNamespace(resample=lambda wav, _sr_in, _sr_out: wav)
34
+ monkeypatch.setitem(sys.modules, "torchaudio", fake_torchaudio)
35
+
36
+ fake_demucs_apply = types.ModuleType("demucs.apply")
37
+ fake_demucs_apply.apply_model = lambda _m, batch, **_kw: torch.zeros((batch.shape[0], 4, 2, 44100))
38
+ monkeypatch.setitem(sys.modules, "demucs.apply", fake_demucs_apply)
39
+
40
+ written: list[str] = []
41
+
42
+ def fake_sf_write(path, _data, _sr):
43
+ written.append(path)
44
+ Path(path).write_bytes(b"RIFF" + b"\0" * 100)
45
+
46
+ fake_sf = types.ModuleType("soundfile")
47
+ fake_sf.write = fake_sf_write
48
+ fake_sf.read = lambda _path: (np.zeros((44100, 2)), 44100)
49
+ monkeypatch.setitem(sys.modules, "soundfile", fake_sf)
50
 
51
  stems = pp.separate_stems(src)
52
  assert set(stems.keys()) == {"vocals", "drums", "bass", "other"}