techfreakworm committed on
Commit
b992e76
·
unverified ·
1 Parent(s): a81cc03

fix(post): use lower-level demucs API + stringify stem paths for gr.Files

Browse files

demucs 4.0.x has no demucs.api.Separator (added in 4.1). Refactor
separate_stems() to use demucs.pretrained.get_model +
demucs.apply.apply_model so the installed wheel works without a forced
upgrade of the apple-silicon torch stack. Also stringify PosixPath
values before handing to gr.Files — Gradio 6.14's pydantic FileData
model rejects PosixPath inputs.

Files changed (2) hide show
  1. app.py +3 -1
  2. post_process.py +44 -18
app.py CHANGED
@@ -529,7 +529,9 @@ def on_separate_stems(audio_path):
529
  stems = post_process.separate_stems(audio_path)
530
  except Exception as e:
531
  raise gr.Error(f"Demucs failed: {e}") from e
532
- return gr.Files(value=list(stems.values()), visible=True)
 
 
533
 
534
 
535
  def on_normalise(audio_path):
 
529
  stems = post_process.separate_stems(audio_path)
530
  except Exception as e:
531
  raise gr.Error(f"Demucs failed: {e}") from e
532
+ # gr.Files's pydantic FileData model only accepts str paths in Gradio
533
+ # 6.14; PosixPath objects from separate_stems() trip its validator.
534
+ return gr.Files(value=[str(p) for p in stems.values()], visible=True)
535
 
536
 
537
  def on_normalise(audio_path):
post_process.py CHANGED
@@ -11,39 +11,65 @@ _DEMUCS = None
11
 
12
 
13
  def _get_demucs() -> Any:
 
 
 
 
 
 
 
 
14
  global _DEMUCS
15
  if _DEMUCS is None:
16
- from demucs.api import Separator
17
 
18
- _DEMUCS = Separator(model="htdemucs_ft")
19
  return _DEMUCS
20
 
21
 
22
  def separate_stems(audio_path: Path | str) -> dict[str, Path]:
23
- """Split into vocals/drums/bass/other via htdemucs_ft.
24
 
25
- Returns a dict mapping stem name to written file path.
 
 
26
  """
27
- sep = _get_demucs()
28
- result = sep.separate_audio_file(str(audio_path))
29
- # `result` may be either {name: path} OR (origin, separated) tuple
30
- # depending on demucs version. Normalise to dict[str, Path].
31
- if isinstance(result, dict):
32
- return {name: Path(p) for name, p in result.items()}
33
- # Newer demucs returns (origin_tensor, separated_dict_of_tensors)
34
- # We persist tensors next to the input file with stem suffixes.
35
  import soundfile as sf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- _origin, sep_tensors = result
38
  base = Path(audio_path).with_suffix("")
39
  stems: dict[str, Path] = {}
40
- for name, tensor in sep_tensors.items():
41
- out = base.with_name(f"{base.name}.{name}.wav")
42
- data = tensor.detach().cpu().numpy()
 
43
  if data.ndim == 2 and data.shape[0] in (1, 2):
44
  data = data.T
45
- sf.write(str(out), data, sep.samplerate)
46
- stems[name] = out
47
  return stems
48
 
49
 
 
11
 
12
 
13
  def _get_demucs() -> Any:
14
+ """Lazy-load the htdemucs model.
15
+
16
+ Demucs 4.0.x exposes ``demucs.pretrained.get_model`` and
17
+ ``demucs.apply.apply_model`` — the higher-level
18
+ ``demucs.api.Separator`` convenience wrapper only appears in 4.1+.
19
+ We pin to the lower-level API so this works across both pip-installable
20
+ lines without forcing an upgrade on the apple-silicon torch stack.
21
+ """
22
  global _DEMUCS
23
  if _DEMUCS is None:
24
+ from demucs.pretrained import get_model
25
 
26
+ _DEMUCS = get_model("htdemucs")
27
  return _DEMUCS
28
 
29
 
30
  def separate_stems(audio_path: Path | str) -> dict[str, Path]:
31
+ """Split into vocals/drums/bass/other via htdemucs.
32
 
33
+ Uses the lower-level ``demucs.apply.apply_model`` so we don't depend
34
+ on the ``demucs.api.Separator`` wrapper (which only ships with
35
+ demucs >= 4.1). Returns a dict mapping stem name to written file path.
36
  """
 
 
 
 
 
 
 
 
37
  import soundfile as sf
38
+ import torch
39
+ import torchaudio
40
+ from demucs.apply import apply_model
41
+
42
+ model = _get_demucs()
43
+ target_sr = int(getattr(model, "samplerate", 44100))
44
+ sources = list(getattr(model, "sources", ["drums", "bass", "other", "vocals"]))
45
+ audio_channels = int(getattr(model, "audio_channels", 2))
46
+
47
+ waveform, sr = torchaudio.load(str(audio_path)) # (channels, frames)
48
+ if sr != target_sr:
49
+ waveform = torchaudio.functional.resample(waveform, sr, target_sr)
50
+ # Match the model's expected channel count (htdemucs is stereo).
51
+ if waveform.shape[0] == 1 and audio_channels == 2:
52
+ waveform = waveform.repeat(2, 1)
53
+ elif waveform.shape[0] > audio_channels:
54
+ waveform = waveform[:audio_channels]
55
+
56
+ # apply_model expects shape (batch, channels, frames).
57
+ batch = waveform.unsqueeze(0)
58
+ with torch.no_grad():
59
+ # apply_model returns (batch, sources, channels, frames).
60
+ out = apply_model(model, batch, device="cpu", progress=False)
61
+ out = out[0] # drop batch dim -> (sources, channels, frames)
62
 
 
63
  base = Path(audio_path).with_suffix("")
64
  stems: dict[str, Path] = {}
65
+ for idx, name in enumerate(sources):
66
+ out_path = base.with_name(f"{base.name}.{name}.wav")
67
+ data = out[idx].cpu().numpy()
68
+ # soundfile expects (frames, channels); demucs gives (channels, frames)
69
  if data.ndim == 2 and data.shape[0] in (1, 2):
70
  data = data.T
71
+ sf.write(str(out_path), data, target_sr)
72
+ stems[name] = out_path
73
  return stems
74
 
75