---
license: cc-by-4.0
language:
- en
base_model: pyannote/speaker-diarization-community-1
tags:
- speaker-diarization
- speaker-embedding
- onnx
- pyannote
library_name: onnxruntime
---
# Pyannote Community-1 Embedding: Split for Masked Pooling (ONNX)
A 3-piece split of the speaker embedding model from [pyannote/speaker-diarization-community-1](https://huggingface.co/pyannote/speaker-diarization-community-1) that enables **~30× faster batched inference with masked stats pooling** for diarization with overlapping speakers.
This repo derives from [altunenes/speaker-diarization-community-1-onnx](https://huggingface.co/altunenes/speaker-diarization-community-1-onnx) (which exported the original PyTorch ResNet-based embedding model to ONNX). Here we split the graph into three parts so the slow per-(chunk, speaker) loop can be replaced by a single batched encoder pass + NumPy pooling.
## What's in this repo
| File | Size | Purpose |
|---|---|---|
| `embedding_encoder.onnx` | 20 MB | ResNet encoder; outputs frame features `(B, 2560, F)` |
| `resnet_seg_1_weight.npy` | 5 MB | Final Gemm projection weight `(256, 5120)` |
| `resnet_seg_1_bias.npy` | 1 KB | Final Gemm projection bias `(256,)` |
| `split_pyannote_embedding.py` | – | Script that reproduces the split |
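The frame count `F` in the encoder output depends on the input length. A rough way to estimate it is sketched below. It assumes 16 kHz audio, Kaldi-style fbank framing (25 ms window, 10 ms hop, no edge padding) and a WeSpeaker ResNet34 layout in which three residual stages each halve time and frequency using 3×3 convolutions with padding 1. These are assumptions about the upstream export, so the shape that `encoder.run(...)` actually returns is authoritative. The helper names are hypothetical.

```python
def conv_out_len(n: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Output length of one strided convolution along a single axis."""
    return (n + 2 * padding - kernel) // stride + 1


def fbank_frames(num_samples: int, sample_rate: int = 16000,
                 win_ms: float = 25.0, hop_ms: float = 10.0) -> int:
    """Number of fbank frames with Kaldi-style framing (ASSUMED, no edge padding)."""
    win = int(sample_rate * win_ms / 1000)
    hop = int(sample_rate * hop_ms / 1000)
    return 1 + (num_samples - win) // hop


def encoder_output_shape(num_frames: int, mel_bins: int = 80,
                         channels: int = 256, stride2_stages: int = 3):
    """Return (D, F) of the encoder output for a given fbank input.

    ASSUMPTION: three stride-2 residual stages and 256 channels in the last stage.
    Channels and the remaining frequency bins are flattened into D.
    """
    t, f = num_frames, mel_bins
    for _ in range(stride2_stages):
        t = conv_out_len(t)
        f = conv_out_len(f)
    return channels * f, t


n = fbank_frames(10 * 16000)       # one 10 s chunk at 16 kHz
print(n, encoder_output_shape(n))  # 998 (2560, 125)
```

Under these assumptions, 998 fbank frames (the count used in the quick start below) give `D = 256 × 10 = 2560`, consistent with the table, and `F = 125`.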
## Why split? The masked-pooling problem
The full `embedding_model.onnx` graph is `fbank → ResNet → StatsPool → Gemm → 256d`. Its stats-pooling step pools over **all frames** of the input, so there is no way to tell it which frames belong to which speaker.
In diarization, each 10-second chunk may contain overlapping speakers. To extract one embedding per (chunk, speaker), you need a **masked stats pool** that weights frames by per-speaker activity. Two options:
| Approach | ORT calls (2-hour file, ~720 × 10 s chunks) | Correctness |
|---|---|---|
| Run the full model once per (chunk, speaker) after filtering frames | ~720 chunks × 3 speakers ≈ 2,160 | ⚠️ Not equivalent: dropping frames before the encoder changes the convolutional context the ResNet sees, and soft (fractional) weights cannot be expressed as a frame subset |
| **This split**: batch-encode 64 chunks at a time, then pool with the mask in NumPy | ⌈720 / 64⌉ = 12 | ✅ Matches `pyannote.audio.models.blocks.pooling.StatsPool._pool()` |
Combined with `run_with_iobinding` and `enable_cpu_mem_arena=False`, this gives a ~30× speedup on CPU for the embedding-extraction stage of long-form diarization.
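For reference, the weighted statistics computed by `StatsPool._pool()` for one (chunk, speaker) pair can be written as follows. Here $x_t \in \mathbb{R}^{2560}$ is the encoder output at frame $t$ and $w_t \in [0, 1]$ is that speaker's activity. The notation is ours, squares are element-wise, and the small $10^{-8}$ stabilisers used in the code are omitted:

$$
V_1 = \sum_{t=1}^{F} w_t, \qquad V_2 = \sum_{t=1}^{F} w_t^2, \qquad
\mu = \frac{1}{V_1} \sum_{t=1}^{F} w_t\, x_t, \qquad
\sigma^2 = \frac{\sum_{t=1}^{F} w_t\, (x_t - \mu)^2}{V_1 - V_2 / V_1}
$$

$$
e = W \begin{bmatrix} \mu \\ \sigma \end{bmatrix} + b, \qquad W \in \mathbb{R}^{256 \times 5120}, \quad b \in \mathbb{R}^{256}
$$

With a binary mask covering $n$ frames, $V_1 = V_2 = n$ and the denominator becomes $n - 1$, which is the usual unbiased variance over the active frames.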
## Quick start
```python
import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download

local = snapshot_download("welcomyou/pyannote-community-1-onnx-split")

opts = ort.SessionOptions()
opts.enable_cpu_mem_arena = False  # avoid 1.8 GB arena that never shrinks
encoder = ort.InferenceSession(f"{local}/embedding_encoder.onnx", opts,
                               providers=["CPUExecutionProvider"])
input_name = encoder.get_inputs()[0].name  # "fbank_features" in this export

W = np.load(f"{local}/resnet_seg_1_weight.npy")  # (256, 5120)
b = np.load(f"{local}/resnet_seg_1_bias.npy")    # (256,)


def masked_stats_pool(frame_feat, mask):
    """Weighted mean/std pooling over frames.

    frame_feat: (D=2560, F) encoder features for one chunk.
    mask: (F,) float in [0, 1], per-frame activity weight of one speaker.
    Returns (5120,) = concat(weighted_mean, weighted_std).
    Matches pyannote.audio StatsPool._pool().
    """
    w = mask[np.newaxis, :]                     # (1, F)
    v1 = w.sum() + 1e-8                         # sum of weights
    v2 = (w * w).sum()                          # sum of squared weights
    mean = (frame_feat * w).sum(axis=1) / v1    # (D,)
    dx2 = (frame_feat - mean[:, None]) ** 2     # (D, F)
    var = (dx2 * w).sum(axis=1) / (v1 - v2 / v1 + 1e-8)  # unbiased weighted variance
    return np.concatenate([mean, np.sqrt(var)])


# Batched encoder pass: 64 chunks x 998 fbank frames x 80 mel bins
fbank_batch = ...  # shape (64, 998, 80) float32
frame_feats = encoder.run(None, {input_name: fbank_batch})[0]  # (64, 2560, F)

# One embedding per (chunk, speaker). num_speakers and
# per_speaker_activity_masks (64, num_speakers, F) come from your segmentation step.
embeddings = np.zeros((64, num_speakers, 256), dtype=np.float32)
for c in range(64):
    for s in range(num_speakers):
        mask = per_speaker_activity_masks[c, s]  # (F,) float32
        stats = masked_stats_pool(frame_feats[c], mask)
        embeddings[c, s] = stats @ W.T + b       # (256,)
```
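The quick start assumes each mask already has exactly one weight per encoder frame `F`. Segmentation models usually emit activity at a different frame rate, so the masks must be resampled first. pyannote's `StatsPool` appears to handle this internally with nearest-neighbour interpolation (`torch.nn.functional.interpolate(..., mode="nearest")`) when the weight and frame counts differ. The sketch below mirrors that index rule in NumPy. `resample_mask_nearest` is a hypothetical helper, not part of this repo, and the 589-frame segmentation shape and `F = 125` are illustrative assumptions.

```python
import numpy as np


def resample_mask_nearest(activity: np.ndarray, num_frames: int) -> np.ndarray:
    """Resample activity of shape (..., T_in) to (..., num_frames).

    Output frame i takes input frame floor(i * T_in / num_frames), the same
    rule as torch's legacy "nearest" interpolation.
    """
    t_in = activity.shape[-1]
    idx = np.floor(np.arange(num_frames) * (t_in / num_frames)).astype(np.int64)
    idx = np.minimum(idx, t_in - 1)  # guard against float round-up at the last frame
    return activity[..., idx].astype(np.float32)


# Hypothetical segmentation output: 64 chunks x 3 speakers x 589 frames.
seg = np.random.default_rng(0).random((64, 3, 589)).astype(np.float32)
masks = resample_mask_nearest(seg, 125)  # 125 = assumed encoder F for 998 fbank frames
print(masks.shape)  # (64, 3, 125)
```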
Full reference implementation: [`core/speaker_diarization_pure_ort.py`](https://github.com/welcomyou/sherpa-vietnamese-asr/blob/main/core/speaker_diarization_pure_ort.py) (lines 707–810).
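The double Python loop in the quick start calls `masked_stats_pool` 64 × `num_speakers` times. A minimal sketch of pooling all pairs at once with `np.einsum` follows. `batched_masked_stats_pool` is a hypothetical name, and this is one possible vectorisation rather than the referenced implementation. It expands $\sum_t w_t (x_t - \mu)^2$ into weighted sums of $x$ and $x^2$, so no `(C, S, D, F)` tensor is materialised. It computes in float64 because that expansion is more prone to cancellation than the direct form.

```python
import numpy as np


def batched_masked_stats_pool(frame_feats, masks, eps=1e-8):
    """Masked stats pooling for every (chunk, speaker) pair in one shot.

    frame_feats: (C, D, F) encoder output, e.g. (64, 2560, F).
    masks:       (C, S, F) per-speaker frame weights in [0, 1].
    Returns      (C, S, 2 * D) = concat(weighted_mean, weighted_std),
    using the same formula as masked_stats_pool in the quick start.
    """
    x = frame_feats.astype(np.float64)  # float64 limits cancellation in the expansion
    w = masks.astype(np.float64)
    sw = w.sum(axis=-1)[..., None]              # (C, S, 1) sum of weights
    v1 = sw + eps
    v2 = (w * w).sum(axis=-1)[..., None]        # (C, S, 1) sum of squared weights
    sx = np.einsum("csf,cdf->csd", w, x)        # (C, S, D) sum_t w_t x_t
    sxx = np.einsum("csf,cdf->csd", w, x * x)   # (C, S, D) sum_t w_t x_t^2
    mean = sx / v1
    # sum_t w_t (x_t - mean)^2 == sxx - 2 * mean * sx + mean^2 * sw
    ss = sxx - 2.0 * mean * sx + mean * mean * sw
    var = np.maximum(ss, 0.0) / (v1 - v2 / v1 + eps)  # clamp tiny negative round-off
    return np.concatenate([mean, np.sqrt(var)], axis=-1)
```

With the quick-start names, `batched_masked_stats_pool(frame_feats, per_speaker_activity_masks) @ W.T + b` would then yield all `(64, num_speakers, 256)` embeddings in one expression.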
## Reproducing the split
```bash
# 1. Download the upstream ONNX export from altunenes
huggingface-cli download altunenes/speaker-diarization-community-1-onnx \
--include embedding_model.onnx --local-dir pyannote-onnx/
# 2. Run the split script (≈30 s)
python split_pyannote_embedding.py \
--input pyannote-onnx/embedding_model.onnx \
--output_dir pyannote-onnx/
```
The script uses `onnx.utils.extract_model` to carve out the encoder subgraph (output tensor `/resnet/pool/Reshape_output_0`, just before stats pooling) and `onnx.numpy_helper` to dump the final Gemm initializers `resnet.seg_1.weight` and `resnet.seg_1.bias` as `.npy`.
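One way to sanity-check a freshly split model is with an all-ones mask. When every frame has weight 1, masked pooling should reduce to the plain mean and the unbiased (n − 1) standard deviation, which is what pyannote's `StatsPool` computes for unweighted input. We assume the upstream ONNX export preserves this. The sketch below checks that identity on random stand-in features. In practice you would instead compare `split_embedding(encoder_out[0], W, b)` against the 256-d output of the original `embedding_model.onnx` for the same fbank input. Both function names are hypothetical.

```python
import numpy as np


def split_embedding(frame_feat, W, b, mask=None, eps=1e-8):
    """Recombine the split pieces for one chunk: masked stats pool, then Gemm.

    frame_feat: (D, F) encoder output; mask: (F,) weights, all ones by default.
    """
    w = np.ones(frame_feat.shape[1]) if mask is None else mask
    v1 = w.sum() + eps
    mean = (frame_feat * w).sum(axis=1) / v1
    var = ((frame_feat - mean[:, None]) ** 2 * w).sum(axis=1) / (v1 - (w * w).sum() / v1 + eps)
    return np.concatenate([mean, np.sqrt(var)]) @ W.T + b


def unpooled_reference(frame_feat, W, b):
    """Plain mean and unbiased (ddof=1) std followed by the Gemm, as the full
    graph's StatsPool is assumed to compute without weights."""
    stats = np.concatenate([frame_feat.mean(axis=1), frame_feat.std(axis=1, ddof=1)])
    return stats @ W.T + b


rng = np.random.default_rng(0)
feat = rng.standard_normal((2560, 125)).astype(np.float32)  # stand-in encoder output
W = (rng.standard_normal((256, 5120)) * 0.01).astype(np.float32)
b = np.zeros(256, dtype=np.float32)
print(np.allclose(split_embedding(feat, W, b), unpooled_reference(feat, W, b), atol=1e-4))
```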
## Credits & License
- **Original PyTorch model**: [pyannote/speaker-diarization-community-1](https://huggingface.co/pyannote/speaker-diarization-community-1) (segmentation + WeSpeaker-ResNet34 embedding + VBx clustering pipeline)
- **ONNX conversion of upstream embedding model**: [altunenes/speaker-diarization-community-1-onnx](https://huggingface.co/altunenes/speaker-diarization-community-1-onnx)
- **Authors**: Hervé Bredin et al. (pyannote), Hongji Wang et al. (WeSpeaker)
License: **CC-BY-4.0** (inherited from `pyannote/speaker-diarization-community-1`). Attribution required; commercial use allowed.
The original `pyannote/speaker-diarization-community-1` repository is gated and requires accepting a contact-information form. CC-BY-4.0 itself does not impose that restriction on derivative works, but please consider visiting the original repo to support pyannote.
## Used by
- [Sherpa Vietnamese ASR](https://github.com/welcomyou/sherpa-vietnamese-asr): pure-ORT diarization pipeline replacing `pyannote.audio` (~2.1 GB → ~60 MB of dependencies).