Rahma89 commited on
Commit
1ea18ce
·
verified ·
1 Parent(s): 27441d2

Add audio fallback for TorchCodec-free environments

Browse files
Files changed (2) hide show
  1. requirements.txt +2 -1
  2. src/separate.py +26 -2
requirements.txt CHANGED
@@ -2,4 +2,5 @@ torch>=2.0.0
2
  torchaudio>=2.0.0
3
  asteroid>=0.6.0
4
  numpy>=1.24.0
5
- matplotlib>=3.10.8
 
 
2
  torchaudio>=2.0.0
3
  asteroid>=0.6.0
4
  numpy>=1.24.0
5
+ matplotlib>=3.10.8
6
+ soundfile>=0.12.1
src/separate.py CHANGED
@@ -10,6 +10,7 @@ import os
10
  import argparse
11
  import torch
12
  import torchaudio
 
13
 
14
  from src.model import build_model, load_checkpoint
15
  import yaml
@@ -33,11 +34,34 @@ def parse_args():
33
  return p.parse_args()
34
 
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def separate(mix_path, model, sample_rate, device, out_dir):
37
  """Charge un mixture.wav, sépare les sources, sauvegarde les .wav."""
38
 
39
  # ── Charger le fichier audio ─────────────
40
- mixture, sr = torchaudio.load(mix_path)
41
 
42
  if sr != sample_rate:
43
  print(f" Resample {sr} Hz → {sample_rate} Hz")
@@ -70,7 +94,7 @@ def separate(mix_path, model, sample_rate, device, out_dir):
70
  src_cpu = src_cpu / max_val * 0.9
71
 
72
  out_path = os.path.join(out_dir, f"{mix_name}_source_{i+1}.wav")
73
- torchaudio.save(out_path, src_cpu, sample_rate)
74
  print(f" ✓ Source {i+1} sauvegardée : {out_path}")
75
 
76
  return est_sources
 
10
  import argparse
11
  import torch
12
  import torchaudio
13
+ import soundfile as sf
14
 
15
  from src.model import build_model, load_checkpoint
16
  import yaml
 
34
  return p.parse_args()
35
 
36
 
37
+ def load_audio(path):
38
+ """Load audio as a mono/stereo tensor, with a fallback for TorchCodec-free envs."""
39
+ try:
40
+ return torchaudio.load(path)
41
+ except ImportError as exc:
42
+ if "TorchCodec" not in str(exc) and "torchcodec" not in str(exc):
43
+ raise
44
+ audio, sr = sf.read(path, dtype="float32", always_2d=True)
45
+ waveform = torch.from_numpy(audio).transpose(0, 1)
46
+ return waveform, sr
47
+
48
+
49
+ def save_audio(path, waveform, sample_rate):
50
+ """Save audio with torchaudio when available, otherwise fall back to soundfile."""
51
+ try:
52
+ torchaudio.save(path, waveform.cpu(), sample_rate)
53
+ except ImportError as exc:
54
+ if "TorchCodec" not in str(exc) and "torchcodec" not in str(exc):
55
+ raise
56
+ audio = waveform.detach().cpu().transpose(0, 1).numpy()
57
+ sf.write(path, audio, sample_rate)
58
+
59
+
60
  def separate(mix_path, model, sample_rate, device, out_dir):
61
  """Charge un mixture.wav, sépare les sources, sauvegarde les .wav."""
62
 
63
  # ── Charger le fichier audio ─────────────
64
+ mixture, sr = load_audio(mix_path)
65
 
66
  if sr != sample_rate:
67
  print(f" Resample {sr} Hz → {sample_rate} Hz")
 
94
  src_cpu = src_cpu / max_val * 0.9
95
 
96
  out_path = os.path.join(out_dir, f"{mix_name}_source_{i+1}.wav")
97
+ save_audio(out_path, src_cpu, sample_rate)
98
  print(f" ✓ Source {i+1} sauvegardée : {out_path}")
99
 
100
  return est_sources