voice-separation-model / src /separate.py
Rahma89's picture
Add audio fallback for TorchCodec-free environments
1ea18ce verified
"""
separate.py — Séparation de sources avec le modèle entraîné
Usage :
python main.py separate --mix data/mixture/mix_0/mixture.wav
python main.py separate --mix data/mixture/mix_0/mixture.wav --ckpt checkpoints/best.ckpt
python main.py separate --mix mon_audio.wav --out_dir outputs/separated_audio
"""
import os
import argparse
import torch
import torchaudio
import soundfile as sf
from src.model import build_model, load_checkpoint
import yaml
def load_config(path):
    """Read a YAML configuration file and return its parsed content.

    The file is opened explicitly as UTF-8 rather than with the platform's
    locale encoding, so configuration files that contain non-ASCII text
    (accented comments, paths, ...) load the same way on every OS.

    Args:
        path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file content.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
        yaml.YAMLError: If the content is not valid YAML.
    """
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
def parse_args():
    """Parse the command-line options of the separation script.

    Options are read from ``sys.argv`` by argparse.

    Returns:
        An ``argparse.Namespace`` with the attributes ``mix`` (required),
        ``ckpt``, ``out_dir``, ``train_cfg`` and ``data_cfg``.
    """
    # (flag, extra keyword arguments) for every option; all of them are strings.
    option_specs = (
        ("--mix", {"required": True,
                   "help": "Chemin vers le fichier mixture.wav à séparer"}),
        ("--ckpt", {"default": "checkpoints/best.ckpt",
                    "help": "Checkpoint du modèle entraîné"}),
        ("--out_dir", {"default": "outputs/separated_audio",
                       "help": "Dossier de sortie pour les sources séparées"}),
        ("--train_cfg", {"default": "configs/train.yaml"}),
        ("--data_cfg", {"default": "configs/data.yaml"}),
    )

    parser = argparse.ArgumentParser(description="Séparation de sources Conv-TasNet")
    for flag, options in option_specs:
        parser.add_argument(flag, type=str, **options)
    return parser.parse_args()
def load_audio(path):
    """Read an audio file, using soundfile when torchaudio lacks TorchCodec.

    ``torchaudio.load`` is tried first. If it raises an ``ImportError`` whose
    message mentions TorchCodec, the file is read with ``soundfile`` instead;
    any other ``ImportError`` is re-raised unchanged.

    Args:
        path: Path of the audio file to read.

    Returns:
        A ``(waveform, sample_rate)`` pair. In the fallback path the waveform
        is a float32 tensor built from soundfile's 2-D array with its two
        axes swapped.
    """
    try:
        return torchaudio.load(path)
    except ImportError as err:
        reason = str(err)
        # Only a missing TorchCodec backend is recoverable here.
        if not any(tag in reason for tag in ("TorchCodec", "torchcodec")):
            raise
        # always_2d keeps a 2-D array even for single-channel files, so the
        # transpose below is valid for every input.
        samples, sr = sf.read(path, dtype="float32", always_2d=True)
        waveform = torch.transpose(torch.from_numpy(samples), 0, 1)
        return waveform, sr
def save_audio(path, waveform, sample_rate):
    """Write a waveform to disk, using soundfile when torchaudio lacks TorchCodec.

    ``torchaudio.save`` is tried first. If it raises an ``ImportError`` whose
    message mentions TorchCodec, the audio is written with ``soundfile``
    instead; any other ``ImportError`` is re-raised unchanged.

    Args:
        path: Destination file path.
        waveform: 2-D tensor holding the audio to write.
        sample_rate: Sampling rate passed to the writer.
    """
    try:
        torchaudio.save(path, waveform.cpu(), sample_rate)
    except ImportError as err:
        reason = str(err)
        # Only a missing TorchCodec backend is recoverable here.
        if not any(tag in reason for tag in ("TorchCodec", "torchcodec")):
            raise
        # The tensor's two axes are swapped before it is handed to soundfile.
        frames = torch.transpose(waveform.detach().cpu(), 0, 1).numpy()
        sf.write(path, frames, sample_rate)
def separate(mix_path, model, sample_rate, device, out_dir):
    """Separate one mixture file and write every estimated source as a WAV file.

    The mixture is loaded, resampled to ``sample_rate`` when its own rate
    differs, averaged down to one channel, and passed through ``model`` with
    gradients disabled. Each estimated source is peak-normalised to 0.9 and
    saved as ``<mixture stem>_source_<k>.wav`` inside ``out_dir``.

    Args:
        mix_path: Path of the mixture audio file.
        model: Callable separation model. It receives the mono mixture with
            an extra leading batch dimension; its output is expected to have
            a leading batch dimension of size 1 followed by the source axis.
        sample_rate: Rate used for inference and for the written files.
        device: Device the mixture is moved to before inference.
        out_dir: Output directory, created if it does not exist.

    Returns:
        The model output with its batch dimension squeezed out. These are the
        raw estimates, not the normalised copies written to disk.
    """
    # -- Load the audio file -------------------
    mixture, sr = load_audio(mix_path)

    # Bring the audio to the requested sampling rate.
    if sr != sample_rate:
        print(f" Resample {sr} Hz → {sample_rate} Hz")
        mixture = torchaudio.functional.resample(mixture, sr, sample_rate)

    # Down-mix to mono while keeping the (1, T) layout.
    n_channels = mixture.shape[0]
    if n_channels > 1:
        mixture = torch.mean(mixture, dim=0, keepdim=True)

    print(f" Durée : {mixture.shape[-1] / sample_rate:.2f}s "
          f"({mixture.shape[-1]} samples)")

    # -- Inference ------------------------------
    mixture = mixture.to(device)  # (1, T)
    with torch.no_grad():
        # Add the batch dimension in front, then drop it from the output.
        batch = mixture.unsqueeze(0)
        est_sources = model(batch).squeeze(0)  # (n_src, T)

    # -- Save the separated sources -------------
    os.makedirs(out_dir, exist_ok=True)
    mix_name, _ext = os.path.splitext(os.path.basename(mix_path))
    for i, estimate in enumerate(est_sources):
        track = estimate.cpu().unsqueeze(0)  # (1, T)

        # Peak-normalise to 0.9 to avoid clipping; silent tracks are left as is.
        peak = torch.max(torch.abs(track))
        if peak > 0:
            track = track.div(peak).mul(0.9)

        out_path = os.path.join(out_dir, f"{mix_name}_source_{i+1}.wav")
        save_audio(out_path, track, sample_rate)
        print(f" ✓ Source {i+1} sauvegardée : {out_path}")

    return est_sources
def main():
    """Command-line entry point: restore the trained model and separate one file.

    Parses the CLI arguments, reads the training and data YAML configs,
    builds the model from their ``model`` and ``dataset`` sections, loads the
    checkpoint weights, prints the checkpoint metadata and finally runs
    ``separate`` on the requested mixture.

    Raises:
        FileNotFoundError: If the checkpoint path does not exist.
    """
    args = parse_args()
    train_settings = load_config(args.train_cfg)
    data_settings = load_config(args.data_cfg)
    model_cfg = train_settings["model"]
    dataset_cfg = data_settings["dataset"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n[Config] Device : {device}")
    print(f"[Config] Checkpoint : {args.ckpt}")
    print(f"[Config] Fichier mix : {args.mix}\n")

    # -- Build the model ------------------------
    # Dataset-level settings first, then the architecture hyper-parameters
    # copied verbatim from the training config.
    architecture_keys = (
        "n_filters", "filter_length", "stride", "n_blocks", "n_repeats",
        "bn_chan", "hid_chan", "skip_chan", "norm_type", "mask_act",
    )
    model_kwargs = {
        "n_src": dataset_cfg["n_src"],
        "sample_rate": dataset_cfg["sample_rate"],
    }
    model_kwargs.update((key, model_cfg[key]) for key in architecture_keys)
    model_kwargs["use_gradient_checkpointing"] = False  # not needed for inference
    model = build_model(**model_kwargs)

    # -- Load the trained weights ---------------
    if not os.path.exists(args.ckpt):
        raise FileNotFoundError(
            f"Checkpoint introuvable : {args.ckpt}\n"
            f"Lancez d'abord : python main.py train"
        )
    load_checkpoint(model, args.ckpt, device)
    model.to(device)
    model.eval()

    # The checkpoint is read a second time, on CPU, only to display its metadata.
    metadata = torch.load(args.ckpt, map_location="cpu")
    epoch = metadata.get("epoch", "?")
    val = metadata.get("best_val_loss")
    if val is None:
        print(f"[Model] Checkpoint chargé (epoch {epoch})\n")
    else:
        print(f"[Model] Checkpoint chargé (epoch {epoch}, val loss {val:.4f})\n")

    # -- Separation -----------------------------
    separate(args.mix, model, dataset_cfg["sample_rate"], device, args.out_dir)
    print(f"\n[Done] Sources séparées dans : {args.out_dir}/")
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()