---
license: apache-2.0
tags:
- audio
- speech
- source-separation
- conv-tasnet
- asteroid
- pytorch
library_name: pytorch
pipeline_tag: audio-to-audio
---

# Cocktail Party AI - Conv-TasNet 3-Source Separator

This repository contains the best checkpoint from a Conv-TasNet model trained for speech source separation. The model takes a mixed speech waveform and estimates 3 separated source waveforms.

## Checkpoint

- File: `best.ckpt`
- Architecture: Asteroid `ConvTasNet`
- Number of sources: 3
- Sample rate: 16 kHz
- Training checkpoint epoch: 68
- Best validation loss: -2.909952
- Approximate validation SI-SNR: 2.91 dB

## Files

```text
best.ckpt
configs/data.yaml
configs/train.yaml
requirements.txt
src/model.py
src/separate.py
```

## Usage

Install dependencies:

```bash
pip install -r requirements.txt
```

Load the checkpoint with the project code:

```python
import yaml
import torch

from src.model import build_model, load_checkpoint

with open("configs/train.yaml") as f:
    train_cfg = yaml.safe_load(f)
with open("configs/data.yaml") as f:
    data_cfg = yaml.safe_load(f)

mod = train_cfg["model"]
ds = data_cfg["dataset"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = build_model(
    n_src=ds["n_src"],
    sample_rate=ds["sample_rate"],
    n_filters=mod["n_filters"],
    filter_length=mod["filter_length"],
    stride=mod["stride"],
    n_blocks=mod["n_blocks"],
    n_repeats=mod["n_repeats"],
    bn_chan=mod["bn_chan"],
    hid_chan=mod["hid_chan"],
    skip_chan=mod["skip_chan"],
    norm_type=mod["norm_type"],
    mask_act=mod["mask_act"],
    use_gradient_checkpointing=False,
).to(device)
load_checkpoint(model, "best.ckpt", device)
model.eval()
```

To separate a WAV file using this project:

```bash
python src/separate.py --mix path/to/mixture.wav --ckpt best.ckpt
```

## Notes

This is a research/training checkpoint, not a fully packaged `transformers` pipeline. It depends on PyTorch, Torchaudio, and Asteroid.
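The validation loss and SI-SNR figures above differ only in sign, which suggests the training objective is negative SI-SNR (the usual choice for Conv-TasNet), so a loss of -2.91 corresponds to roughly +2.91 dB SI-SNR. As a reference for interpreting these numbers, here is a minimal, dependency-free sketch of the SI-SNR metric itself (this is an illustration, not the project's actual loss code):

```python
import math

def si_snr(est, target, eps=1e-8):
    """Scale-invariant SNR in dB between two equal-length waveforms (floats)."""
    # Zero-mean both signals.
    me = sum(est) / len(est)
    mt = sum(target) / len(target)
    est = [x - me for x in est]
    target = [x - mt for x in target]
    # Project the estimate onto the target to isolate the "clean" component.
    dot = sum(a * b for a, b in zip(est, target))
    t_energy = sum(b * b for b in target) + eps
    s_target = [dot / t_energy * b for b in target]
    e_noise = [a - s for a, s in zip(est, s_target)]
    s_pow = sum(s * s for s in s_target)
    n_pow = sum(e * e for e in e_noise) + eps
    return 10 * math.log10(s_pow / n_pow)
```

Because the metric is scale-invariant, rescaling an estimate leaves its score unchanged; only the shape of the waveform matters.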
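Once the checkpoint is loaded, inference is a single forward pass. The shape convention assumed here follows Asteroid's usual contract — a `(batch, time)` mixture in, a `(batch, n_src, time)` stack of estimates out — illustrated with a hypothetical stand-in module so the snippet runs without the checkpoint; with the real model, replace `StandInSeparator()` with the `model` built in the usage example:

```python
import torch
import torch.nn as nn

class StandInSeparator(nn.Module):
    """Stand-in with the same I/O contract assumed for the trained model."""
    def __init__(self, n_src: int = 3):
        super().__init__()
        self.n_src = n_src

    def forward(self, mix: torch.Tensor) -> torch.Tensor:
        # (batch, time) -> (batch, n_src, time); the real model returns
        # one estimated source waveform per output channel.
        return mix.unsqueeze(1).expand(-1, self.n_src, -1)

model = StandInSeparator()      # substitute the loaded checkpointed model here
mix = torch.randn(1, 16000)     # one second of 16 kHz audio
with torch.no_grad():
    est = model(mix)            # shape (1, 3, 16000): one waveform per speaker
```

Each slice `est[0, i]` can then be written out as its own mono WAV (e.g. with `torchaudio.save`) at 16 kHz.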