license: cc-by-4.0
language:
- en
base_model: pyannote/speaker-diarization-community-1
tags:
- speaker-diarization
- speaker-embedding
- onnx
- pyannote
library_name: onnxruntime
Pyannote Community-1 Embedding — Split for Masked Pooling (ONNX)
A 3-piece split of the speaker embedding model from pyannote/speaker-diarization-community-1 that enables ~30× faster batched inference with masked stats pooling for diarization with overlapping speakers.
This repo derives from altunenes/speaker-diarization-community-1-onnx (which exported the original PyTorch ResNet-based embedding model to ONNX). Here we split the graph into three parts so the slow per-(chunk, speaker) loop can be replaced by a single batched encoder pass + NumPy pooling.
What's in this repo
| File | Size | Purpose |
|---|---|---|
| embedding_encoder.onnx | 20 MB | ResNet encoder, output frame features (B, 2560, F) |
| resnet_seg_1_weight.npy | 5 MB | Final Gemm projection weight (256, 5120) |
| resnet_seg_1_bias.npy | 1 KB | Final Gemm projection bias (256,) |
| split_pyannote_embedding.py | — | Reproducibility script |
Why split? The masked-pooling problem
The full embedding_model.onnx graph is fbank → ResNet → StatsPool → Gemm → 256d. The internal StatsPool op pools all frames of the input — there's no way to mask which frames belong to which speaker.
In diarization, each 10-second chunk may contain overlapping speakers. To extract one embedding per (chunk, speaker), you need a masked stats pool that weights frames by per-speaker activity. Two options:
| Approach | Speed | Correctness |
|---|---|---|
| Run full model once per (chunk, speaker), filtering frames first | ~720 chunks × 3 speakers = 2160 ORT calls for a 2-hour file | ⚠️ Wrong: subset-filtering is not the same as weighted pooling — variance is computed on a different denominator |
| This split: batch-encode 64 chunks at a time, then NumPy-pool with mask | ~12 ORT calls for 2 hours | ✅ Matches pyannote.audio.models.blocks.pooling.StatsPool._pool() |
Combined with run_with_iobinding and enable_cpu_mem_arena=False, this gives ~30× speedup on CPU for the embedding extraction stage of long-form diarization.
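The denominator difference noted in the table can be seen on synthetic data. The sketch below is an illustration only: `filtered_stats_pool` is a hypothetical baseline written for this comparison (it thresholds the mask and pools the surviving frames unweighted), and it filters at the pooling stage rather than on the model input, so it isolates the pooling arithmetic and ignores any effect on the encoder. The feature and mask values are random stand-ins, not model outputs.

```python
import numpy as np

def masked_stats_pool(frame_feat, mask):
    """Weighted mean/std over frames, same formula as in the Quick start section."""
    w = mask[np.newaxis, :]
    v1 = w.sum() + 1e-8
    mean = (frame_feat * w).sum(axis=1) / v1
    var = ((frame_feat - mean[:, None]) ** 2 * w).sum(axis=1) / (v1 - (w * w).sum() / v1 + 1e-8)
    return np.concatenate([mean, np.sqrt(var)])

def filtered_stats_pool(frame_feat, mask, threshold=0.5):
    """Hypothetical baseline: drop frames below a threshold, then pool unweighted."""
    kept = frame_feat[:, mask > threshold]
    return np.concatenate([kept.mean(axis=1), kept.std(axis=1, ddof=1)])

rng = np.random.default_rng(0)
frame_feat = rng.normal(size=(8, 50))          # toy (D, F) features
soft_mask = rng.uniform(0.0, 1.0, size=50)     # fractional per-frame speaker weights
hard_mask = (soft_mask > 0.5).astype(float)    # the same mask, binarized

weighted = masked_stats_pool(frame_feat, soft_mask)
filtered = filtered_stats_pool(frame_feat, soft_mask)

# Largest element-wise gap between the two pooling strategies.
gap_soft = np.abs(weighted - filtered).max()
# With a strictly 0/1 mask the weighted formula reduces to the unweighted one.
gap_hard = np.abs(masked_stats_pool(frame_feat, hard_mask) - filtered).max()

print(f"soft mask vs. filtering: {gap_soft:.4f}")
print(f"hard mask vs. filtering: {gap_hard:.2e}")
```

In this toy setup the two strategies agree when the mask is strictly binary and diverge once the weights are fractional, because the weighted variance divides by `v1 - v2 / v1` rather than by the kept-frame count minus one.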
Quick start
import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download

local = snapshot_download("welcomyou/pyannote-community-1-onnx-split")

opts = ort.SessionOptions()
opts.enable_cpu_mem_arena = False  # avoid 1.8 GB arena that never shrinks
encoder = ort.InferenceSession(f"{local}/embedding_encoder.onnx", opts,
                               providers=["CPUExecutionProvider"])
W = np.load(f"{local}/resnet_seg_1_weight.npy")  # (256, 5120)
b = np.load(f"{local}/resnet_seg_1_bias.npy")    # (256,)

def masked_stats_pool(frame_feat, mask):
    """frame_feat: (D=2560, F), mask: (F,) float in [0..1], the per-frame speaker weight.
    Returns (5120,) = concat(weighted_mean, weighted_std).
    Matches pyannote.audio StatsPool._pool()."""
    w = mask[np.newaxis, :]
    v1 = w.sum() + 1e-8
    mean = (frame_feat * w).sum(axis=1) / v1
    var = ((frame_feat - mean[:, None]) ** 2 * w).sum(axis=1) / (v1 - (w * w).sum() / v1 + 1e-8)
    return np.concatenate([mean, np.sqrt(var)])

# Batched encoder pass: (64 chunks, 998 frames, 80 dim fbank)
fbank_batch = ...  # shape (64, 998, 80) float32
frame_feats = encoder.run(None, {"fbank_features": fbank_batch})[0]  # (64, 2560, F)

# For each (chunk, speaker) compute embedding.
# num_speakers and per_speaker_activity_masks are supplied by the caller.
for c in range(64):
    for s in range(num_speakers):
        mask = per_speaker_activity_masks[c, s]  # (F,) float32
        stats = masked_stats_pool(frame_feats[c], mask)
        emb = stats @ W.T + b  # (256,)
Full reference implementation: core/speaker_diarization_pure_ort.py (lines 707–810).
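The per-(chunk, speaker) loop in the Quick start can also be vectorized over the chunk axis. The following is a minimal sketch, not the reference implementation: `masked_stats_pool_batch` is a hypothetical helper, the array sizes are toy values, and `W` and `b` are random stand-ins for the `.npy` projection files. It loops only over speakers so the largest temporary stays at shape (B, D, F).

```python
import numpy as np

def masked_stats_pool(frame_feat, mask):
    """Single (chunk, speaker) pooling, same formula as the Quick start."""
    w = mask[np.newaxis, :]
    v1 = w.sum() + 1e-8
    mean = (frame_feat * w).sum(axis=1) / v1
    var = ((frame_feat - mean[:, None]) ** 2 * w).sum(axis=1) / (v1 - (w * w).sum() / v1 + 1e-8)
    return np.concatenate([mean, np.sqrt(var)])

def masked_stats_pool_batch(frame_feats, masks):
    """frame_feats: (B, D, F), masks: (B, S, F). Returns stats of shape (B, S, 2*D)."""
    B, D, F = frame_feats.shape
    S = masks.shape[1]
    out = np.empty((B, S, 2 * D), dtype=frame_feats.dtype)
    for s in range(S):
        w = masks[:, s, None, :]                        # (B, 1, F)
        v1 = w.sum(axis=2) + 1e-8                       # (B, 1)
        v2 = (w * w).sum(axis=2)                        # (B, 1)
        mean = (frame_feats * w).sum(axis=2) / v1       # (B, D)
        dx2 = (frame_feats - mean[:, :, None]) ** 2     # (B, D, F)
        var = (dx2 * w).sum(axis=2) / (v1 - v2 / v1 + 1e-8)
        out[:, s, :D] = mean
        out[:, s, D:] = np.sqrt(var)
    return out

# Toy sizes for illustration; the real encoder output has D = 2560.
rng = np.random.default_rng(0)
B, S, D, F, E = 4, 3, 16, 20, 8
frame_feats = rng.normal(size=(B, D, F))
masks = rng.uniform(0.0, 1.0, size=(B, S, F))
W = rng.normal(size=(E, 2 * D))   # stand-in for resnet_seg_1_weight.npy
b = rng.normal(size=(E,))         # stand-in for resnet_seg_1_bias.npy

stats = masked_stats_pool_batch(frame_feats, masks)   # (B, S, 2*D)
embeddings = stats @ W.T + b                          # (B, S, E)

# Cross-check against the per-pair loop.
reference = np.stack([
    np.stack([masked_stats_pool(frame_feats[c], masks[c, s]) @ W.T + b for s in range(S)])
    for c in range(B)
])
print(embeddings.shape, np.allclose(embeddings, reference))
```

Looping over speakers instead of broadcasting a (B, S, D, F) tensor trades a few Python iterations for much lower peak memory, which matters at D = 2560.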
Reproducing the split
# 1. Download the upstream ONNX export from altunenes
huggingface-cli download altunenes/speaker-diarization-community-1-onnx \
    --include embedding_model.onnx --local-dir pyannote-onnx/

# 2. Run the split script (≈30 s)
python split_pyannote_embedding.py \
    --input pyannote-onnx/embedding_model.onnx \
    --output_dir pyannote-onnx/
The script uses onnx.utils.extract_model to carve out the encoder subgraph (output tensor /resnet/pool/Reshape_output_0, just before stats pooling) and onnx.numpy_helper to dump the final Gemm initializers resnet.seg_1.weight and resnet.seg_1.bias as .npy.
Credits & License
- Original PyTorch model: pyannote/speaker-diarization-community-1 (segmentation + WeSpeaker-ResNet34 embedding + VBx clustering pipeline)
- ONNX conversion of upstream embedding model: altunenes/speaker-diarization-community-1-onnx
- Authors: Hervé Bredin et al. (pyannote), Hongji Wang et al. (WeSpeaker)
License: CC-BY-4.0 (inherited from pyannote/speaker-diarization-community-1). Attribution required; commercial use allowed.
The original pyannote/speaker-diarization-community-1 repository is gated and requires accepting a contact-information form. CC-BY-4.0 itself does not impose that restriction on derivative works, but please consider visiting the original repo to support pyannote.
Used by
- Sherpa Vietnamese ASR — Pure ORT diarization pipeline replacing pyannote.audio (~2.1 GB → ~60 MB of dependencies).