"""
separate.py — Source separation with the trained model.

Usage:
    python main.py separate --mix data/mixture/mix_0/mixture.wav
    python main.py separate --mix data/mixture/mix_0/mixture.wav --ckpt checkpoints/best.ckpt
    python main.py separate --mix my_audio.wav --out_dir outputs/separated_audio
"""
import os
import argparse

import torch
import torchaudio
import soundfile as sf
import yaml

from src.model import build_model, load_checkpoint

def load_config(path):
    """Read a YAML config file into a dict."""
    with open(path, "r") as f:
        return yaml.safe_load(f)
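
# main() below reads two YAML files and expects at least the keys used in
# build_model() and separate(). A minimal sketch of the expected shape; the
# values here are illustrative assumptions, not the project's actual defaults:
#
#   # configs/data.yaml
#   dataset:
#     n_src: 2
#     sample_rate: 8000
#
#   # configs/train.yaml
#   model:
#     n_filters: 512
#     filter_length: 16
#     stride: 8
#     n_blocks: 8
#     n_repeats: 3
#     bn_chan: 128
#     hid_chan: 512
#     skip_chan: 128
#     norm_type: "gLN"
#     mask_act: "relu"
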
def parse_args():
    p = argparse.ArgumentParser(description="Conv-TasNet source separation")
    p.add_argument("--mix", type=str, required=True,
                   help="Path to the mixture.wav file to separate")
    p.add_argument("--ckpt", type=str, default="checkpoints/best.ckpt",
                   help="Checkpoint of the trained model")
    p.add_argument("--out_dir", type=str, default="outputs/separated_audio",
                   help="Output directory for the separated sources")
    p.add_argument("--train_cfg", type=str, default="configs/train.yaml")
    p.add_argument("--data_cfg", type=str, default="configs/data.yaml")
    return p.parse_args()
def load_audio(path):
"""Load audio as a mono/stereo tensor, with a fallback for TorchCodec-free envs."""
try:
return torchaudio.load(path)
except ImportError as exc:
if "TorchCodec" not in str(exc) and "torchcodec" not in str(exc):
raise
        # soundfile returns (frames, channels); transpose to torchaudio's (channels, frames)
        audio, sr = sf.read(path, dtype="float32", always_2d=True)
        waveform = torch.from_numpy(audio).transpose(0, 1)
        return waveform, sr
def save_audio(path, waveform, sample_rate):
"""Save audio with torchaudio when available, otherwise fall back to soundfile."""
try:
torchaudio.save(path, waveform.cpu(), sample_rate)
except ImportError as exc:
if "TorchCodec" not in str(exc) and "torchcodec" not in str(exc):
raise
        # torchaudio uses (channels, frames); soundfile expects (frames, channels)
        audio = waveform.detach().cpu().transpose(0, 1).numpy()
        sf.write(path, audio, sample_rate)
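
# Round-trip sanity check for the two I/O helpers above (illustrative only;
# "outputs/roundtrip.wav" is a hypothetical path, not one the project creates):
#
#   waveform, sr = load_audio("data/mixture/mix_0/mixture.wav")
#   save_audio("outputs/roundtrip.wav", waveform, sr)
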
def separate(mix_path, model, sample_rate, device, out_dir):
    """Load a mixture.wav, separate the sources, and save each estimate as .wav."""
    # ── Load the audio file ──────────────────
    mixture, sr = load_audio(mix_path)
if sr != sample_rate:
print(f" Resample {sr} Hz → {sample_rate} Hz")
mixture = torchaudio.functional.resample(mixture, sr, sample_rate)
    # Downmix to mono → (1, T)
    if mixture.shape[0] > 1:
        mixture = mixture.mean(dim=0, keepdim=True)
    print(f" Duration: {mixture.shape[-1] / sample_rate:.2f}s "
          f"({mixture.shape[-1]} samples)")
    # ── Inference ────────────────────────────
    mixture = mixture.to(device)  # (1, T)
    with torch.no_grad():
        # The model expects (B, T) → add the batch dim
        est_sources = model(mixture.unsqueeze(0))  # (1, n_src, T)
        est_sources = est_sources.squeeze(0)       # (n_src, T)
    # ── Save the separated sources ───────────
    os.makedirs(out_dir, exist_ok=True)
    mix_name = os.path.splitext(os.path.basename(mix_path))[0]
    for i, src in enumerate(est_sources):
        src_cpu = src.unsqueeze(0).cpu()  # (1, T)
        # Peak-normalize to avoid clipping in the saved files
        max_val = src_cpu.abs().max()
        if max_val > 0:
            src_cpu = src_cpu / max_val * 0.9
        out_path = os.path.join(out_dir, f"{mix_name}_source_{i+1}.wav")
        save_audio(out_path, src_cpu, sample_rate)
        print(f" ✓ Source {i+1} saved: {out_path}")
return est_sources
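
# separate() can also be driven from another script, e.g. to batch-process a
# whole directory. A minimal sketch, assuming a model built and loaded as in
# main() below (the glob pattern and the sample rate are illustrative; in
# practice use ds["sample_rate"] from the data config):
#
#   import glob
#   for mix in sorted(glob.glob("data/mixture/*/mixture.wav")):
#       separate(mix, model, sample_rate=8000, device=device,
#                out_dir="outputs/separated_audio")
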
def main():
args = parse_args()
tcfg = load_config(args.train_cfg)
dcfg = load_config(args.data_cfg)
mod = tcfg["model"]
ds = dcfg["dataset"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\n[Config] Device : {device}")
print(f"[Config] Checkpoint : {args.ckpt}")
print(f"[Config] Fichier mix : {args.mix}\n")
    # ── Build the model ──────────────────────
model = build_model(
n_src = ds["n_src"],
sample_rate = ds["sample_rate"],
n_filters = mod["n_filters"],
filter_length = mod["filter_length"],
stride = mod["stride"],
n_blocks = mod["n_blocks"],
n_repeats = mod["n_repeats"],
bn_chan = mod["bn_chan"],
hid_chan = mod["hid_chan"],
skip_chan = mod["skip_chan"],
norm_type = mod["norm_type"],
mask_act = mod["mask_act"],
        use_gradient_checkpointing = False,  # not needed at inference time
)
    # ── Load the trained weights ─────────────
    if not os.path.exists(args.ckpt):
        raise FileNotFoundError(
            f"Checkpoint not found: {args.ckpt}\n"
            f"Run training first: python main.py train"
        )
load_checkpoint(model, args.ckpt, device)
model.to(device)
model.eval()
    # Re-read the checkpoint dict only to report training metadata
    ckpt = torch.load(args.ckpt, map_location="cpu")
    epoch = ckpt.get("epoch", "?")
    val = ckpt.get("best_val_loss", None)
    if val is not None:
        print(f"[Model] Checkpoint loaded (epoch {epoch}, val loss {val:.4f})\n")
    else:
        print(f"[Model] Checkpoint loaded (epoch {epoch})\n")
    # ── Separation ───────────────────────────
separate(
mix_path = args.mix,
model = model,
sample_rate = ds["sample_rate"],
device = device,
out_dir = args.out_dir,
)
print(f"\n[Done] Sources séparées dans : {args.out_dir}/")
if __name__ == "__main__":
main()