---
license: apache-2.0
tags:
  - audio
  - speech
  - source-separation
  - conv-tasnet
  - asteroid
  - pytorch
library_name: pytorch
pipeline_tag: audio-to-audio
---

# Cocktail Party AI - Conv-TasNet 3-Source Separator

This repository contains the best checkpoint from a Conv-TasNet model trained for speech source separation.
The model takes a mixed speech waveform and estimates 3 separated source waveforms.

## Checkpoint

- File: `best.ckpt`
- Architecture: Asteroid `ConvTasNet`
- Number of sources: 3
- Sample rate: 16 kHz
- Training checkpoint epoch: 68
- Best validation loss: -2.909952
- Approximate validation SI-SNR: 2.91 dB
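
The best validation loss (-2.909952) is the negative of the reported SI-SNR (about 2.91 dB), which suggests the training objective was negative scale-invariant SNR. The exact loss used (for example, whether it is permutation-invariant, as is common for Conv-TasNet recipes in Asteroid) is defined in `configs/train.yaml` and is not shown here, so treat the following as an assumption. This is a minimal pure-Python sketch of SI-SNR and of its permutation-invariant (PIT) form for N sources. `si_snr` and `pit_si_snr` are illustrative helpers, not functions from this project or from Asteroid.

```python
import itertools
import math
from typing import List, Sequence, Tuple


def _zero_mean(x: Sequence[float]) -> List[float]:
    mean = sum(x) / len(x)
    return [v - mean for v in x]


def si_snr(reference: Sequence[float], estimate: Sequence[float], eps: float = 1e-8) -> float:
    """Scale-invariant SNR (dB) of one estimated source against its reference."""
    s = _zero_mean(reference)
    s_hat = _zero_mean(estimate)
    # Project the estimate onto the reference: s_target = (<s_hat, s> / ||s||^2) * s.
    alpha = sum(a * b for a, b in zip(s_hat, s)) / (sum(a * a for a in s) + eps)
    s_target = [alpha * a for a in s]
    # Whatever the projection does not explain counts as noise.
    e_noise = [a - b for a, b in zip(s_hat, s_target)]
    target_energy = sum(a * a for a in s_target)
    noise_energy = sum(a * a for a in e_noise)
    return 10.0 * math.log10((target_energy + eps) / (noise_energy + eps))


def pit_si_snr(references, estimates) -> Tuple[float, Tuple[int, ...]]:
    """Best mean SI-SNR over every assignment of estimates to references.

    Returns (score_db, perm) where perm[i] is the estimate matched to reference i.
    """
    n = len(references)
    best_score, best_perm = -math.inf, tuple(range(n))
    for perm in itertools.permutations(range(n)):
        score = sum(si_snr(references[i], estimates[perm[i]]) for i in range(n)) / n
        if score > best_score:
            best_score, best_perm = score, perm
    return best_score, best_perm
```

Under this assumption, the loss for one example would be `-pit_si_snr(refs, ests)[0]`. With 3 sources there are 3! = 6 candidate assignments. Searching them is what makes the loss indifferent to the order in which the model emits speakers.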

## Files

```text
best.ckpt
configs/data.yaml
configs/train.yaml
requirements.txt
src/model.py
src/separate.py
```

## Usage

Install dependencies:

```bash
pip install -r requirements.txt
```

Load the checkpoint with the project code:

```python
import yaml
import torch

from src.model import build_model, load_checkpoint

with open("configs/train.yaml") as f:
    train_cfg = yaml.safe_load(f)
with open("configs/data.yaml") as f:
    data_cfg = yaml.safe_load(f)

mod = train_cfg["model"]
ds = data_cfg["dataset"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = build_model(
    n_src=ds["n_src"],
    sample_rate=ds["sample_rate"],
    n_filters=mod["n_filters"],
    filter_length=mod["filter_length"],
    stride=mod["stride"],
    n_blocks=mod["n_blocks"],
    n_repeats=mod["n_repeats"],
    bn_chan=mod["bn_chan"],
    hid_chan=mod["hid_chan"],
    skip_chan=mod["skip_chan"],
    norm_type=mod["norm_type"],
    mask_act=mod["mask_act"],
    use_gradient_checkpointing=False,
).to(device)

load_checkpoint(model, "best.ckpt", device)
model.eval()
```
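
The `n_blocks`, `n_repeats`, `filter_length`, and `stride` values read from `configs/train.yaml` together determine how much temporal context the separator sees. The sketch below estimates it. It assumes that `build_model` wraps Asteroid's `ConvTasNet` with its default depthwise kernel size of 3 and a dilation of `2**x` for block `x` in each repeat. The helper names are hypothetical, and the sample count is simply the input span covered by that many consecutive encoder frames.

```python
def tcn_receptive_field(n_blocks: int, n_repeats: int, conv_kernel_size: int = 3) -> int:
    """Receptive field of the TCN separator, measured in encoder frames."""
    # Each repeat stacks n_blocks dilated convs with dilations 1, 2, 4, ..., 2**(n_blocks-1).
    per_repeat = sum((conv_kernel_size - 1) * 2 ** x for x in range(n_blocks))
    return 1 + n_repeats * per_repeat


def receptive_field_seconds(
    n_blocks: int,
    n_repeats: int,
    filter_length: int,
    stride: int,
    sample_rate: int,
    conv_kernel_size: int = 3,
) -> float:
    """Approximate input span (seconds) covered by the separator's receptive field."""
    frames = tcn_receptive_field(n_blocks, n_repeats, conv_kernel_size)
    # Consecutive encoder frames start `stride` samples apart and each spans `filter_length`.
    samples = (frames - 1) * stride + filter_length
    return samples / sample_rate
```

With the loaded config, this would be called as `receptive_field_seconds(mod["n_blocks"], mod["n_repeats"], mod["filter_length"], mod["stride"], ds["sample_rate"])`. For reference, Asteroid's defaults (8 blocks, 3 repeats, 16-sample filters, stride 8) give 1531 frames, or about 0.77 s at 16 kHz.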

To separate a WAV file using this project:

```bash
python src/separate.py --mix path/to/mixture.wav --ckpt best.ckpt
```
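
The checkpoint was trained at 16 kHz. Whether `src/separate.py` resamples or downmixes its input is not shown here, so a pre-flight check on the mixture file may be useful. The following is a hypothetical stdlib-only helper (`read_mono_16k` is not part of this project). It assumes a 16-bit PCM WAV and a single-channel model input, as is typical for Conv-TasNet.

```python
import sys
import wave
from array import array
from typing import BinaryIO, List, Union

EXPECTED_RATE = 16_000  # sample rate listed on this model card


def read_mono_16k(source: Union[str, BinaryIO]) -> List[float]:
    """Read a 16-bit PCM WAV and return mono samples scaled to [-1, 1)."""
    with wave.open(source, "rb") as wav:
        rate = wav.getframerate()
        if rate != EXPECTED_RATE:
            raise ValueError(f"expected {EXPECTED_RATE} Hz, got {rate} Hz; resample first")
        if wav.getsampwidth() != 2:
            raise ValueError("only 16-bit PCM is handled in this sketch")
        n_channels = wav.getnchannels()
        raw = wav.readframes(wav.getnframes())
    pcm = array("h")
    pcm.frombytes(raw)
    if sys.byteorder == "big":  # WAV sample data is little-endian
        pcm.byteswap()
    # Average the interleaved channels frame by frame to get one mixture channel.
    mono = []
    for start in range(0, len(pcm), n_channels):
        frame = pcm[start:start + n_channels]
        mono.append(sum(frame) / n_channels / 32768.0)
    return mono
```

The helper raises on a rate mismatch instead of resampling. Resampling is better done with a proper filter, for example `torchaudio.functional.resample`, before the mixture is passed to the separator.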

## Notes

This is a research/training checkpoint, not a fully packaged `transformers` pipeline.
The project code depends on PyTorch, Torchaudio, and Asteroid; `requirements.txt` lists the dependencies.