Add audio fallback for TorchCodec-free environments
Browse files- requirements.txt +2 -1
- src/separate.py +26 -2
requirements.txt
CHANGED
|
@@ -2,4 +2,5 @@ torch>=2.0.0
|
|
| 2 |
torchaudio>=2.0.0
|
| 3 |
asteroid>=0.6.0
|
| 4 |
numpy>=1.24.0
|
| 5 |
-
matplotlib>=3.10.8
|
|
|
|
|
|
| 2 |
torchaudio>=2.0.0
|
| 3 |
asteroid>=0.6.0
|
| 4 |
numpy>=1.24.0
|
| 5 |
+
matplotlib>=3.10.8
|
| 6 |
+
soundfile>=0.12.1
|
src/separate.py
CHANGED
|
@@ -10,6 +10,7 @@ import os
|
|
| 10 |
import argparse
|
| 11 |
import torch
|
| 12 |
import torchaudio
|
|
|
|
| 13 |
|
| 14 |
from src.model import build_model, load_checkpoint
|
| 15 |
import yaml
|
|
@@ -33,11 +34,34 @@ def parse_args():
|
|
| 33 |
return p.parse_args()
|
| 34 |
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
def separate(mix_path, model, sample_rate, device, out_dir):
|
| 37 |
"""Charge un mixture.wav, sépare les sources, sauvegarde les .wav."""
|
| 38 |
|
| 39 |
# ── Charger le fichier audio ─────────────
|
| 40 |
-
mixture, sr =
|
| 41 |
|
| 42 |
if sr != sample_rate:
|
| 43 |
print(f" Resample {sr} Hz → {sample_rate} Hz")
|
|
@@ -70,7 +94,7 @@ def separate(mix_path, model, sample_rate, device, out_dir):
|
|
| 70 |
src_cpu = src_cpu / max_val * 0.9
|
| 71 |
|
| 72 |
out_path = os.path.join(out_dir, f"{mix_name}_source_{i+1}.wav")
|
| 73 |
-
|
| 74 |
print(f" ✓ Source {i+1} sauvegardée : {out_path}")
|
| 75 |
|
| 76 |
return est_sources
|
|
|
|
| 10 |
import argparse
|
| 11 |
import torch
|
| 12 |
import torchaudio
|
| 13 |
+
import soundfile as sf
|
| 14 |
|
| 15 |
from src.model import build_model, load_checkpoint
|
| 16 |
import yaml
|
|
|
|
| 34 |
return p.parse_args()
|
| 35 |
|
| 36 |
|
| 37 |
+
def load_audio(path):
|
| 38 |
+
"""Load audio as a mono/stereo tensor, with a fallback for TorchCodec-free envs."""
|
| 39 |
+
try:
|
| 40 |
+
return torchaudio.load(path)
|
| 41 |
+
except ImportError as exc:
|
| 42 |
+
if "TorchCodec" not in str(exc) and "torchcodec" not in str(exc):
|
| 43 |
+
raise
|
| 44 |
+
audio, sr = sf.read(path, dtype="float32", always_2d=True)
|
| 45 |
+
waveform = torch.from_numpy(audio).transpose(0, 1)
|
| 46 |
+
return waveform, sr
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def save_audio(path, waveform, sample_rate):
|
| 50 |
+
"""Save audio with torchaudio when available, otherwise fall back to soundfile."""
|
| 51 |
+
try:
|
| 52 |
+
torchaudio.save(path, waveform.cpu(), sample_rate)
|
| 53 |
+
except ImportError as exc:
|
| 54 |
+
if "TorchCodec" not in str(exc) and "torchcodec" not in str(exc):
|
| 55 |
+
raise
|
| 56 |
+
audio = waveform.detach().cpu().transpose(0, 1).numpy()
|
| 57 |
+
sf.write(path, audio, sample_rate)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
def separate(mix_path, model, sample_rate, device, out_dir):
|
| 61 |
"""Charge un mixture.wav, sépare les sources, sauvegarde les .wav."""
|
| 62 |
|
| 63 |
# ── Charger le fichier audio ─────────────
|
| 64 |
+
mixture, sr = load_audio(mix_path)
|
| 65 |
|
| 66 |
if sr != sample_rate:
|
| 67 |
print(f" Resample {sr} Hz → {sample_rate} Hz")
|
|
|
|
| 94 |
src_cpu = src_cpu / max_val * 0.9
|
| 95 |
|
| 96 |
out_path = os.path.join(out_dir, f"{mix_name}_source_{i+1}.wav")
|
| 97 |
+
save_audio(out_path, src_cpu, sample_rate)
|
| 98 |
print(f" ✓ Source {i+1} sauvegardée : {out_path}")
|
| 99 |
|
| 100 |
return est_sources
|