---
license: apache-2.0
tags:
- audio
- speech
- source-separation
- conv-tasnet
- asteroid
- pytorch
library_name: pytorch
pipeline_tag: audio-to-audio
---

# Cocktail Party AI - Conv-TasNet 3-Source Separator

This repository contains the best checkpoint from a Conv-TasNet model trained for speech source separation.
The model takes a mixed speech waveform and estimates 3 separated source waveforms.

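This assumes the standard single-channel separation setup; the card does not say how the training mixtures were built. Under that assumption, the input mixture $x$ is the sum of the three sources $s_i$, and the network $f_\theta$ returns one estimate $\hat{s}_i$ per source:

$$
x(t) = \sum_{i=1}^{3} s_i(t), \qquad \left(\hat{s}_1, \hat{s}_2, \hat{s}_3\right) = f_\theta(x)
$$
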
## Checkpoint

- File: `best.ckpt`
- Architecture: Asteroid `ConvTasNet`
- Number of sources: 3
- Sample rate: 16 kHz
- Epoch at which this checkpoint was saved: 68
- Best validation loss: -2.909952
- Approximate validation SI-SNR: 2.91 dB (the validation loss with its sign flipped)

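The validation loss above is reported as a scale-invariant signal-to-noise ratio (SI-SNR). The following is a minimal pure-Python sketch of the usual definition. It zero-means both signals, projects the estimate onto the reference to get the target component, and compares that component's energy with the residual's energy. Because a separator's output order is arbitrary, the sketch also scores three estimates against three references under the best permutation.

It rests on two assumptions:
- The project's loss is the negative mean SI-SNR. This is consistent with a loss of -2.909952 and an SI-SNR of 2.91 dB.
- The project matches outputs to references permutation-invariantly.

`si_snr` and `best_permutation_si_snr` are hypothetical helpers, not functions from `src/`.

```python
import itertools
import math


def si_snr(estimate, reference, eps=1e-8):
    """Scale-invariant SNR in dB between two equal-length sample sequences."""
    n = len(reference)
    # Remove the mean so a DC offset does not affect the score.
    est_mean = sum(estimate) / n
    ref_mean = sum(reference) / n
    est = [e - est_mean for e in estimate]
    ref = [r - ref_mean for r in reference]
    # Project the estimate onto the reference: s_target = (<est, ref> / ||ref||^2) * ref.
    dot = sum(e * r for e, r in zip(est, ref))
    ref_energy = sum(r * r for r in ref) + eps
    scale = dot / ref_energy
    target = [scale * r for r in ref]
    # Whatever is left over counts as noise/interference.
    noise = [e - t for e, t in zip(est, target)]
    target_energy = sum(t * t for t in target)
    noise_energy = sum(v * v for v in noise)
    return 10 * math.log10((target_energy + eps) / (noise_energy + eps))


def best_permutation_si_snr(estimates, references):
    """Return (perm, mean SI-SNR) where estimates[perm[i]] is matched to references[i]."""
    best_perm, best_score = None, -math.inf
    # Try every assignment of estimates to references and keep the best one.
    for perm in itertools.permutations(range(len(references))):
        score = sum(
            si_snr(estimates[p], references[i]) for i, p in enumerate(perm)
        ) / len(references)
        if score > best_score:
            best_perm, best_score = perm, score
    return best_perm, best_score
```

With three sources, the search covers only 3! = 6 assignments. Because the target is a projection, multiplying an estimate by any nonzero gain leaves its score unchanged. That gain-insensitivity is the point of using SI-SNR rather than plain SNR for separation.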
## Files

```text
best.ckpt
configs/data.yaml
configs/train.yaml
requirements.txt
src/model.py
src/separate.py
```

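## Usage

Install dependencies:

```bash
pip install -r requirements.txt
```

Load the checkpoint with the project code. The config paths (`configs/train.yaml`, `configs/data.yaml`) and the `src.model` import are relative, so run this from the repository root:
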
```python
import yaml
import torch

from src.model import build_model, load_checkpoint

# Read the model hyperparameters and dataset settings used for training.
with open("configs/train.yaml") as f:
    train_cfg = yaml.safe_load(f)
with open("configs/data.yaml") as f:
    data_cfg = yaml.safe_load(f)

mod = train_cfg["model"]
ds = data_cfg["dataset"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the network with the training hyperparameters so that the layer
# shapes match the weights stored in the checkpoint.
model = build_model(
    n_src=ds["n_src"],
    sample_rate=ds["sample_rate"],
    n_filters=mod["n_filters"],
    filter_length=mod["filter_length"],
    stride=mod["stride"],
    n_blocks=mod["n_blocks"],
    n_repeats=mod["n_repeats"],
    bn_chan=mod["bn_chan"],
    hid_chan=mod["hid_chan"],
    skip_chan=mod["skip_chan"],
    norm_type=mod["norm_type"],
    mask_act=mod["mask_act"],
    # Gradient checkpointing only saves memory during backpropagation,
    # so it is disabled for inference.
    use_gradient_checkpointing=False,
).to(device)

# Load the trained weights, then switch the model to evaluation mode.
load_checkpoint(model, "best.ckpt", device)
model.eval()
```

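The snippet above only restores the weights; the model still needs a waveform at the 16 kHz rate listed for this checkpoint. The following is a minimal standard-library sketch that:
- reads a 16-bit PCM WAV file,
- averages its channels to mono,
- rejects files at any other sample rate.

`read_wav_mono` is a hypothetical helper, not part of `src/`. Feeding the model single-channel audio is also an assumption.

```python
import array
import sys
import wave


def read_wav_mono(path, expected_rate=16000):
    """Read a 16-bit PCM WAV file and return (mono float samples in [-1, 1), rate)."""
    with wave.open(path, "rb") as wf:
        n_channels = wf.getnchannels()
        sample_width = wf.getsampwidth()
        rate = wf.getframerate()
        frames = wf.readframes(wf.getnframes())
    if sample_width != 2:
        raise ValueError(f"expected 16-bit PCM, got {8 * sample_width}-bit samples")
    if rate != expected_rate:
        raise ValueError(f"expected {expected_rate} Hz, got {rate} Hz; resample first")
    pcm = array.array("h")
    pcm.frombytes(frames)
    if sys.byteorder == "big":
        pcm.byteswap()  # WAV sample data is little-endian
    # Samples are interleaved per frame; average the channels and scale to [-1, 1).
    mono = [
        sum(pcm[i + c] for c in range(n_channels)) / n_channels / 32768.0
        for i in range(0, len(pcm), n_channels)
    ]
    return mono, rate
```

To run the loaded model on the result, you could write `mix = torch.tensor(samples).unsqueeze(0).to(device)` and then call `model(mix)` inside `with torch.no_grad():`. That should yield a `(1, 3, time)` tensor, assuming `build_model` returns Asteroid's `ConvTasNet`, whose forward pass takes `(batch, time)` input and returns `(batch, n_src, time)`. Audio recorded at another rate can be converted first, for example with `torchaudio.functional.resample`.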
To separate a WAV file using this project:

```bash
python src/separate.py --mix path/to/mixture.wav --ckpt best.ckpt
```

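If you call the model from Python instead of through `src/separate.py`, you have to write each of the three estimated sources out yourself. The following is a minimal standard-library sketch. It assumes each estimate has already been turned into a plain list of floats, for example with `estimates[0, i].cpu().tolist()`. `write_wav_mono` and `write_sources` are hypothetical helpers, not part of the project. Both clip to [-1, 1] and write mono 16-bit PCM at 16 kHz.

```python
import array
import sys
import wave


def write_wav_mono(path, samples, sample_rate=16000):
    """Write float samples (nominally in [-1, 1]) as a mono 16-bit PCM WAV file."""
    # Clip to the valid range, then scale to signed 16-bit integers.
    pcm = array.array(
        "h", [int(round(max(-1.0, min(1.0, s)) * 32767)) for s in samples]
    )
    if sys.byteorder == "big":
        pcm.byteswap()  # WAV sample data is little-endian
    with wave.open(path, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(sample_rate)
        wf.writeframes(pcm.tobytes())


def write_sources(sources, prefix="source", sample_rate=16000):
    """Write each estimated source to '<prefix>_<index>.wav' and return the paths."""
    paths = []
    for i, samples in enumerate(sources):
        path = f"{prefix}_{i}.wav"
        write_wav_mono(path, samples, sample_rate)
        paths.append(path)
    return paths
```

Clipping is the simplest choice, but it distorts any estimate that exceeds full scale. Peak-normalizing each source instead avoids that distortion, at the cost of losing the estimate's absolute level.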
## Notes

This is a research/training checkpoint, not a packaged `transformers` pipeline; load it with the model code in `src/`.
It requires PyTorch, Torchaudio, and Asteroid.
|