---
license: cc-by-4.0
language:
- en
base_model: pyannote/speaker-diarization-community-1
tags:
- speaker-diarization
- speaker-embedding
- onnx
- pyannote
library_name: onnxruntime
---

# Pyannote Community-1 Embedding – Split for Masked Pooling (ONNX)

A 3-piece split of the speaker embedding model from [pyannote/speaker-diarization-community-1](https://huggingface.co/pyannote/speaker-diarization-community-1) that enables **~30× faster batched inference with masked stats pooling** for diarization with overlapping speakers.

This repo derives from [altunenes/speaker-diarization-community-1-onnx](https://huggingface.co/altunenes/speaker-diarization-community-1-onnx), which exported the original PyTorch ResNet-based embedding model to ONNX. Here we split the graph into three parts so the slow per-(chunk, speaker) loop can be replaced by a single batched encoder pass plus NumPy pooling.

## What's in this repo

| File | Size | Purpose |
|---|---|---|
| `embedding_encoder.onnx` | 20 MB | ResNet encoder; outputs frame features `(B, 2560, F)` |
| `resnet_seg_1_weight.npy` | 5 MB | Final Gemm projection weight `(256, 5120)` |
| `resnet_seg_1_bias.npy` | 1 KB | Final Gemm projection bias `(256,)` |
| `split_pyannote_embedding.py` | – | Reproducibility script |

## Why split? The masked-pooling problem

The full `embedding_model.onnx` graph is `fbank → ResNet → StatsPool → Gemm → 256d`. The internal `StatsPool` op pools **all frames** of the input; there is no way to mask which frames belong to which speaker.

In diarization, each 10-second chunk may contain overlapping speakers. To extract one embedding per (chunk, speaker), you need a **masked stats pool** that weights frames by per-speaker activity. Two options:

| Approach | Speed | Correctness |
|---|---|---|
| Run full model once per (chunk, speaker), filtering frames first | ~720 chunks × 3 speakers = 2160 ORT calls for a 2-hour file | ⚠️ Wrong: subset-filtering is not the same as weighted pooling; variance is computed over a different denominator |
| **This split**: batch-encode 64 chunks at a time, then NumPy-pool with mask | ~12 ORT calls for 2 hours | ✅ Matches `pyannote.audio.models.blocks.pooling.StatsPool._pool()` |

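The denominator mismatch in the first row can be demonstrated in a few lines of NumPy. This is a toy example with synthetic numbers (not real model features): hard-filtering frames at `mask > 0.5` and pooling uniformly yields different statistics than weighting every frame the way the weighted scheme does.

```python
import numpy as np

rng = np.random.default_rng(0)
feat = rng.normal(size=(4, 10))  # (D, F) toy frame features
mask = np.array([0.9, 0.8, 0.1, 0.0, 0.0, 0.7, 0.2, 0.0, 1.0, 0.3])

# (a) Hard subset: keep frames where mask > 0.5, pool uniformly.
subset = feat[:, mask > 0.5]
mean_a, std_a = subset.mean(axis=1), subset.std(axis=1, ddof=1)

# (b) Weighted pooling over all frames (same scheme as the masked stats pool).
v1 = mask.sum()
mean_b = (feat * mask).sum(axis=1) / v1
var_b = ((feat - mean_b[:, None]) ** 2 * mask).sum(axis=1) / (v1 - (mask ** 2).sum() / v1)
std_b = np.sqrt(var_b)

# The two schemes disagree: both the mean and the std differ,
# because the soft weights change both numerator and denominator.
print(np.abs(mean_a - mean_b).max(), np.abs(std_a - std_b).max())
```
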
Combined with `run_with_iobinding` and `enable_cpu_mem_arena=False`, this gives a ~30× speedup on CPU for the embedding-extraction stage of long-form diarization.

## Quick start

```python
import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download

local = snapshot_download("welcomyou/pyannote-community-1-onnx-split")

opts = ort.SessionOptions()
opts.enable_cpu_mem_arena = False  # avoid a 1.8 GB arena that never shrinks
encoder = ort.InferenceSession(f"{local}/embedding_encoder.onnx", opts,
                               providers=["CPUExecutionProvider"])
W = np.load(f"{local}/resnet_seg_1_weight.npy")  # (256, 5120)
b = np.load(f"{local}/resnet_seg_1_bias.npy")    # (256,)


def masked_stats_pool(frame_feat, mask):
    """frame_feat: (D=2560, F); mask: (F,) float in [0, 1], the per-frame speaker weight.
    Returns (5120,) = concat(weighted_mean, weighted_std).
    Matches pyannote.audio StatsPool._pool()."""
    w = mask[np.newaxis, :]
    v1 = w.sum() + 1e-8
    mean = (frame_feat * w).sum(axis=1) / v1
    var = ((frame_feat - mean[:, None]) ** 2 * w).sum(axis=1) / (v1 - (w * w).sum() / v1 + 1e-8)
    return np.concatenate([mean, np.sqrt(var)])


# Batched encoder pass: (64 chunks, 998 frames, 80-dim fbank)
fbank_batch = ...  # shape (64, 998, 80) float32
frame_feats = encoder.run(None, {"fbank_features": fbank_batch})[0]  # (64, 2560, F)

# For each (chunk, speaker), compute one embedding:
for c in range(64):
    for s in range(num_speakers):
        mask = per_speaker_activity_masks[c, s]  # (F,) float32
        stats = masked_stats_pool(frame_feats[c], mask)
        emb = stats @ W.T + b  # (256,)
```
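
The snippet above leaves `per_speaker_activity_masks` undefined. One way to obtain it is to resample the segmentation model's per-speaker posteriors to the encoder's frame axis. The `(chunks, T, speakers)` layout for the segmentation output and the use of plain linear interpolation are assumptions of this sketch, not guarantees of this repo:

```python
import numpy as np

def resample_masks(seg_posteriors: np.ndarray, num_frames: int) -> np.ndarray:
    """Resample per-speaker activity posteriors (C, T, S), values in [0, 1],
    from the segmentation frame rate to the encoder's `num_frames` frames.
    Returns (C, S, num_frames), linearly interpolated per chunk and speaker."""
    C, T, S = seg_posteriors.shape
    src = np.linspace(0.0, 1.0, T)
    dst = np.linspace(0.0, 1.0, num_frames)
    out = np.empty((C, S, num_frames), dtype=np.float32)
    for c in range(C):
        for s in range(S):
            out[c, s] = np.interp(dst, src, seg_posteriors[c, :, s])
    return out
```
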

Full reference implementation: [`core/speaker_diarization_pure_ort.py`](https://github.com/welcomyou/sherpa-vietnamese-asr/blob/main/core/speaker_diarization_pure_ort.py) (lines 707–810).

| ## Reproducing the split |
|
|
| ```bash |
| # 1. Download the upstream ONNX export from altunenes |
| huggingface-cli download altunenes/speaker-diarization-community-1-onnx \ |
| --include embedding_model.onnx --local-dir pyannote-onnx/ |
| |
| # 2. Run the split script (β30 s) |
| python split_pyannote_embedding.py \ |
| --input pyannote-onnx/embedding_model.onnx \ |
| --output_dir pyannote-onnx/ |
| ``` |

The script uses `onnx.utils.extract_model` to carve out the encoder subgraph (output tensor `/resnet/pool/Reshape_output_0`, just before stats pooling) and `onnx.numpy_helper` to dump the final Gemm initializers `resnet.seg_1.weight` and `resnet.seg_1.bias` as `.npy`.

## Credits & License

- **Original PyTorch model**: [pyannote/speaker-diarization-community-1](https://huggingface.co/pyannote/speaker-diarization-community-1) (segmentation + WeSpeaker-ResNet34 embedding + VBx clustering pipeline)
- **ONNX conversion of the upstream embedding model**: [altunenes/speaker-diarization-community-1-onnx](https://huggingface.co/altunenes/speaker-diarization-community-1-onnx)
- **Authors**: Hervé Bredin et al. (pyannote), Hongji Wang et al. (WeSpeaker)

License: **CC-BY-4.0** (inherited from `pyannote/speaker-diarization-community-1`). Attribution required; commercial use allowed.

The original `pyannote/speaker-diarization-community-1` repository is gated and requires accepting a contact-information form. CC-BY-4.0 itself does not impose that restriction on derivative works, but please consider visiting the original repo to support pyannote.

## Used by

- [Sherpa Vietnamese ASR](https://github.com/welcomyou/sherpa-vietnamese-asr) – pure ORT diarization pipeline replacing `pyannote.audio` (~2.1 GB → ~60 MB of dependencies).