---
license: apache-2.0
tags:
- audio
- speech
- source-separation
- conv-tasnet
- asteroid
- pytorch
library_name: pytorch
pipeline_tag: audio-to-audio
---
# Cocktail Party AI - Conv-TasNet 3-Source Separator
This repository contains the best checkpoint from a Conv-TasNet model trained for speech source separation.
The model takes a mixed speech waveform and estimates 3 separated source waveforms.
## Checkpoint
- File: `best.ckpt`
- Architecture: Asteroid `ConvTasNet`
- Number of sources: 3
- Sample rate: 16 kHz
- Training checkpoint epoch: 68
- Best validation loss: -2.909952
- Approximate validation SI-SNR: 2.91 dB
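The reported loss of −2.909952 matches the negative of the SI-SNR figure, consistent with the common practice of training separation models on a negative-SI-SNR objective. As a reference point, here is a minimal pure-Python sketch of the SI-SNR metric; the `si_snr` helper is illustrative, not part of this repository:

```python
import math

def si_snr(est, ref, eps=1e-8):
    """Scale-invariant SNR in dB between an estimated and a reference signal."""
    # Project the estimate onto the reference to get the scaled target,
    # then compare target energy against the residual (noise) energy.
    dot = sum(e * r for e, r in zip(est, ref))
    ref_energy = sum(r * r for r in ref) + eps
    scale = dot / ref_energy
    target = [scale * r for r in ref]
    noise = [e - t for e, t in zip(est, target)]
    target_energy = sum(t * t for t in target)
    noise_energy = sum(n * n for n in noise) + eps
    return 10 * math.log10((target_energy + eps) / noise_energy)

# Scale invariance: a rescaled copy of the reference still scores very high,
# while an estimate contaminated with interference scores much lower.
ref = [math.sin(0.01 * n) for n in range(1600)]
clean = si_snr([2.0 * r for r in ref], ref)   # large positive dB value
noisy = si_snr([r + 0.5 * math.sin(0.3 * n) for n, r in enumerate(ref)], ref)
print(clean > noisy)
```

During training, Asteroid typically pairs this metric with permutation-invariant training (PIT), so each of the 3 estimated sources is matched to its best-fitting reference before the loss is averaged.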
## Files
```text
best.ckpt
configs/data.yaml
configs/train.yaml
requirements.txt
src/model.py
src/separate.py
```
## Usage
Install dependencies:
```bash
pip install -r requirements.txt
```
Load the checkpoint with the project code:
```python
import yaml
import torch
from src.model import build_model, load_checkpoint
with open("configs/train.yaml") as f:
    train_cfg = yaml.safe_load(f)
with open("configs/data.yaml") as f:
    data_cfg = yaml.safe_load(f)
mod = train_cfg["model"]
ds = data_cfg["dataset"]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = build_model(
    n_src=ds["n_src"],
    sample_rate=ds["sample_rate"],
    n_filters=mod["n_filters"],
    filter_length=mod["filter_length"],
    stride=mod["stride"],
    n_blocks=mod["n_blocks"],
    n_repeats=mod["n_repeats"],
    bn_chan=mod["bn_chan"],
    hid_chan=mod["hid_chan"],
    skip_chan=mod["skip_chan"],
    norm_type=mod["norm_type"],
    mask_act=mod["mask_act"],
    use_gradient_checkpointing=False,
).to(device)
load_checkpoint(model, "best.ckpt", device)
model.eval()
```
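Once loaded, the model follows the usual Asteroid separation interface: a `[batch, time]` mixture waveform in, a `[batch, n_src, time]` tensor of estimated sources out. A minimal sketch of that contract, using a stand-in module in place of the real checkpoint so the shape handling is visible on its own (replace `model` with the `ConvTasNet` loaded above in real use):

```python
import torch

# Stand-in with the same input/output contract as the separator:
# a [batch, time] mixture maps to [batch, n_src, time] estimated sources.
class FakeSeparator(torch.nn.Module):
    def __init__(self, n_src=3):
        super().__init__()
        self.n_src = n_src

    def forward(self, mix):
        # Duplicate the mixture along a new source axis; a real model
        # would instead predict masks and reconstruct each source.
        return mix.unsqueeze(1).repeat(1, self.n_src, 1)

model = FakeSeparator().eval()

mixture = torch.randn(1, 16000)  # 1 second of 16 kHz audio, [batch, time]
with torch.no_grad():
    est_sources = model(mixture)
print(est_sources.shape)  # torch.Size([1, 3, 16000])
```

Each slice `est_sources[0, i]` is one estimated speaker waveform at 16 kHz and can be written out with `torchaudio.save`.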
To separate a mixture WAV file with the bundled script:
```bash
python src/separate.py --mix path/to/mixture.wav --ckpt best.ckpt
```
## Notes
This is a research/training checkpoint, not a fully packaged `transformers` pipeline.
It depends on PyTorch, Torchaudio, and Asteroid.