| """ |
| separate.py — Séparation de sources avec le modèle entraîné |
| Usage : |
| python main.py separate --mix data/mixture/mix_0/mixture.wav |
| python main.py separate --mix data/mixture/mix_0/mixture.wav --ckpt checkpoints/best.ckpt |
| python main.py separate --mix mon_audio.wav --out_dir outputs/separated_audio |
| """ |
|
|
| import os |
| import argparse |
| import torch |
| import torchaudio |
| import soundfile as sf |
|
|
| from src.model import build_model, load_checkpoint |
| import yaml |
|
|
|
|
def load_config(path):
    """Read a YAML configuration file and return its parsed contents."""
    with open(path, "r") as handle:
        parsed = yaml.safe_load(handle)
    return parsed
|
|
|
|
def parse_args():
    """Build and evaluate the CLI argument parser for the separation script."""
    parser = argparse.ArgumentParser(description="Séparation de sources Conv-TasNet")
    # The mixture file is the only mandatory input.
    parser.add_argument(
        "--mix",
        type=str,
        required=True,
        help="Chemin vers le fichier mixture.wav à séparer",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="checkpoints/best.ckpt",
        help="Checkpoint du modèle entraîné",
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        default="outputs/separated_audio",
        help="Dossier de sortie pour les sources séparées",
    )
    parser.add_argument("--train_cfg", type=str, default="configs/train.yaml")
    parser.add_argument("--data_cfg", type=str, default="configs/data.yaml")
    return parser.parse_args()
|
|
|
|
def load_audio(path):
    """Read an audio file, returning a (channels, samples) tensor and its rate.

    torchaudio is tried first; when its backend raises an ImportError that
    mentions TorchCodec, the file is read with soundfile instead. Any other
    ImportError is propagated unchanged.
    """
    try:
        return torchaudio.load(path)
    except ImportError as err:
        message = str(err)
        if "TorchCodec" in message or "torchcodec" in message:
            samples, rate = sf.read(path, dtype="float32", always_2d=True)
            # soundfile yields (samples, channels); torch convention is
            # (channels, samples), hence the transpose.
            return torch.from_numpy(samples).transpose(0, 1), rate
        raise
|
|
|
|
def save_audio(path, waveform, sample_rate):
    """Write a (channels, samples) tensor to *path* as audio.

    torchaudio is tried first; when its backend raises an ImportError that
    mentions TorchCodec, the tensor is written with soundfile instead. Any
    other ImportError is propagated unchanged.
    """
    try:
        torchaudio.save(path, waveform.cpu(), sample_rate)
    except ImportError as err:
        message = str(err)
        if "TorchCodec" in message or "torchcodec" in message:
            # soundfile expects a (samples, channels) array.
            data = waveform.detach().cpu().transpose(0, 1).numpy()
            sf.write(path, data, sample_rate)
        else:
            raise
|
|
|
|
def separate(mix_path, model, sample_rate, device, out_dir):
    """Separate one mixture file with *model* and write each source as .wav.

    The mixture is resampled to *sample_rate* if needed, downmixed to mono,
    passed through the model, and every estimated source is peak-normalized
    and saved under *out_dir*. Returns the (n_src, samples) estimate tensor.
    """
    waveform, native_sr = load_audio(mix_path)

    # Bring the signal to the sample rate the model was trained on.
    if native_sr != sample_rate:
        print(f" Resample {native_sr} Hz → {sample_rate} Hz")
        waveform = torchaudio.functional.resample(waveform, native_sr, sample_rate)

    # Average multi-channel audio down to a single channel.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    n_samples = waveform.shape[-1]
    print(f" Durée : {n_samples / sample_rate:.2f}s ({n_samples} samples)")

    # Inference: add a batch dimension, then drop it from the output.
    waveform = waveform.to(device)
    with torch.no_grad():
        estimates = model(waveform.unsqueeze(0)).squeeze(0)

    os.makedirs(out_dir, exist_ok=True)
    stem = os.path.splitext(os.path.basename(mix_path))[0]

    for idx, source in enumerate(estimates):
        track = source.unsqueeze(0).cpu()

        # Peak-normalize to 0.9 full scale; leave all-zero sources untouched.
        peak = track.abs().max()
        if peak > 0:
            track = track / peak * 0.9

        target = os.path.join(out_dir, f"{stem}_source_{idx + 1}.wav")
        save_audio(target, track, sample_rate)
        print(f" ✓ Source {idx + 1} sauvegardée : {target}")

    return estimates
|
|
|
|
def main():
    """CLI entry point: load configs and checkpoint, then separate one mixture."""
    args = parse_args()
    train_cfg = load_config(args.train_cfg)
    data_cfg = load_config(args.data_cfg)

    model_cfg = train_cfg["model"]
    dataset_cfg = data_cfg["dataset"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n[Config] Device : {device}")
    print(f"[Config] Checkpoint : {args.ckpt}")
    print(f"[Config] Fichier mix : {args.mix}\n")

    # Rebuild the architecture with the same hyperparameters it was trained
    # with; gradient checkpointing is pointless at inference time.
    model = build_model(
        n_src=dataset_cfg["n_src"],
        sample_rate=dataset_cfg["sample_rate"],
        n_filters=model_cfg["n_filters"],
        filter_length=model_cfg["filter_length"],
        stride=model_cfg["stride"],
        n_blocks=model_cfg["n_blocks"],
        n_repeats=model_cfg["n_repeats"],
        bn_chan=model_cfg["bn_chan"],
        hid_chan=model_cfg["hid_chan"],
        skip_chan=model_cfg["skip_chan"],
        norm_type=model_cfg["norm_type"],
        mask_act=model_cfg["mask_act"],
        use_gradient_checkpointing=False,
    )

    if not os.path.exists(args.ckpt):
        raise FileNotFoundError(
            f"Checkpoint introuvable : {args.ckpt}\n"
            f"Lancez d'abord : python main.py train"
        )

    load_checkpoint(model, args.ckpt, device)
    model.to(device)
    model.eval()

    # Reload the raw checkpoint only for its training metadata.
    # NOTE(review): this reads the file a second time after load_checkpoint;
    # acceptable for a one-shot CLI, but could be merged if load_checkpoint
    # exposed the loaded dict.
    metadata = torch.load(args.ckpt, map_location="cpu")
    epoch = metadata.get("epoch", "?")
    best_val = metadata.get("best_val_loss", None)
    if best_val is None:
        print(f"[Model] Checkpoint chargé (epoch {epoch})\n")
    else:
        print(f"[Model] Checkpoint chargé (epoch {epoch}, val loss {best_val:.4f})\n")

    separate(
        mix_path=args.mix,
        model=model,
        sample_rate=dataset_cfg["sample_rate"],
        device=device,
        out_dir=args.out_dir,
    )

    print(f"\n[Done] Sources séparées dans : {args.out_dir}/")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|