---
license: cc-by-4.0
language:
- en
base_model: pyannote/speaker-diarization-community-1
tags:
- speaker-diarization
- speaker-embedding
- onnx
- pyannote
library_name: onnxruntime
---

# Pyannote Community-1 Embedding: Split for Masked Pooling (ONNX)

A 3-piece split of the speaker embedding model from [pyannote/speaker-diarization-community-1](https://huggingface.co/pyannote/speaker-diarization-community-1) that enables **~30× faster batched inference with masked stats pooling** (CPU, embedding-extraction stage) for diarization with overlapping speakers.

This repo derives from [altunenes/speaker-diarization-community-1-onnx](https://huggingface.co/altunenes/speaker-diarization-community-1-onnx), which exported the original PyTorch ResNet-based embedding model to ONNX. Here we split the graph into three parts so the slow per-(chunk, speaker) loop can be replaced by a single batched encoder pass followed by NumPy pooling.

## What's in this repo

| File | Size | Purpose |
|---|---|---|
| `embedding_encoder.onnx` | 20 MB | ResNet encoder; outputs frame features `(B, 2560, F)` |
| `resnet_seg_1_weight.npy` | 5 MB | Final Gemm projection weight `(256, 5120)` |
| `resnet_seg_1_bias.npy` | 1 KB | Final Gemm projection bias `(256,)` |
| `split_pyannote_embedding.py` | n/a | Reproducibility script |

## Why split? The masked-pooling problem

The full `embedding_model.onnx` graph is `fbank → ResNet → StatsPool → Gemm → 256d`. The internal `StatsPool` op pools **all frames** of the input, so there is no way to mask which frames belong to which speaker.

In diarization, each 10-second chunk may contain overlapping speakers. To extract one embedding per (chunk, speaker), you need a **masked stats pool** that weights frames by per-speaker activity.
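Concretely, with per-frame weights w_t in [0, 1], a masked stats pool computes the weighted mean μ = Σ w_t x_t / V1 and the weighted variance Σ w_t (x_t − μ)² / (V1 − V2/V1), where V1 = Σ w_t and V2 = Σ w_t² (the same formula as `masked_stats_pool` in the Quick start). With an all-ones mask it reduces to the ordinary mean and unbiased (ddof=1) standard deviation over all frames. The NumPy sketch below vectorizes that formula over every chunk of an encoder batch and loops only over speakers. The name `batched_masked_stats_pool` and the `(B, S, F)` mask layout are hypothetical choices for illustration and are not part of this repo; the linked reference implementation may organize this differently.

```python
import numpy as np


def batched_masked_stats_pool(frame_feats: np.ndarray, masks: np.ndarray,
                              eps: float = 1e-8) -> np.ndarray:
    """Weighted mean/std pooling for every (chunk, speaker) pair in a batch.

    frame_feats: (B, D, F) encoder output, e.g. D = 2560.
    masks:       (B, S, F) per-speaker frame weights in [0, 1] (hypothetical layout).
    Returns:     (B, S, 2 * D) = concat(weighted_mean, weighted_std).
    """
    B, D, F = frame_feats.shape
    if masks.ndim != 3 or masks.shape[0] != B or masks.shape[2] != F:
        raise ValueError(f"masks must have shape (B={B}, S, F={F}), got {masks.shape}")
    S = masks.shape[1]
    out = np.empty((B, S, 2 * D), dtype=np.result_type(frame_feats, masks))
    for s in range(S):  # vectorized over chunks; loop over speakers only
        w = masks[:, s, np.newaxis, :]                         # (B, 1, F)
        v1 = w.sum(axis=-1) + eps                              # (B, 1) sum of weights
        v2 = (w * w).sum(axis=-1)                              # (B, 1) sum of squared weights
        mean = (frame_feats * w).sum(axis=-1) / v1             # (B, D) weighted mean
        dev2 = (frame_feats - mean[..., np.newaxis]) ** 2      # (B, D, F) squared deviations
        var = (dev2 * w).sum(axis=-1) / (v1 - v2 / v1 + eps)   # (B, D) weighted variance
        out[:, s, :D] = mean
        out[:, s, D:] = np.sqrt(var)
    return out
```

Projecting to 256-d embeddings is then one matrix product, e.g. `embeddings = pooled @ W.T + b` with `W` and `b` loaded from the `.npy` files. Looping over speakers instead of chunks keeps peak memory at one `(B, 2560, F)` deviation tensor rather than a `(B, S, 2560, F)` array. A speaker whose mask is all zeros yields all-zero stats (and therefore an embedding equal to `b`), so you will probably want to skip inactive speakers before clustering.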
Two options:

| Approach | Cost (2-hour file) | Correctness |
|---|---|---|
| Run the full model once per (chunk, speaker), filtering frames first | ~720 chunks × 3 speakers = 2160 ORT calls | ⚠️ Wrong: subset filtering is not the same as weighted pooling; the variance is computed with a different denominator |
| **This split**: batch-encode 64 chunks at a time, then pool with the mask in NumPy | ~12 ORT calls (720 / 64, rounded up) | ✅ Matches `pyannote.audio.models.blocks.pooling.StatsPool._pool()` |

Combined with `run_with_iobinding` and `enable_cpu_mem_arena=False`, this gives a ~30× speedup on CPU for the embedding-extraction stage of long-form diarization.

## Quick start

```python
import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download

local = snapshot_download("welcomyou/pyannote-community-1-onnx-split")

opts = ort.SessionOptions()
opts.enable_cpu_mem_arena = False  # avoid a ~1.8 GB arena that never shrinks
encoder = ort.InferenceSession(f"{local}/embedding_encoder.onnx", sess_options=opts,
                               providers=["CPUExecutionProvider"])
W = np.load(f"{local}/resnet_seg_1_weight.npy")  # (256, 5120) final Gemm weight
b = np.load(f"{local}/resnet_seg_1_bias.npy")    # (256,) final Gemm bias

def masked_stats_pool(frame_feat, mask):
    """frame_feat: (D=2560, F); mask: (F,) float in [0, 1], per-frame speaker weight.
    Returns (5120,) = concat(weighted_mean, weighted_std).
    Matches pyannote.audio StatsPool._pool()."""
    w = mask[np.newaxis, :]                      # (1, F)
    v1 = w.sum() + 1e-8                          # sum of weights
    v2 = (w * w).sum()                           # sum of squared weights
    mean = (frame_feat * w).sum(axis=1) / v1     # (D,) weighted mean
    var = ((frame_feat - mean[:, None]) ** 2 * w).sum(axis=1) / (v1 - v2 / v1 + 1e-8)
    return np.concatenate([mean, np.sqrt(var)])

# Batched encoder pass: 64 chunks x 998 frames x 80-dim fbank.
fbank_batch = ...  # placeholder: your fbank features, shape (64, 998, 80), float32
frame_feats = encoder.run(None, {"fbank_features": fbank_batch})[0]  # (64, 2560, F)

# One embedding per (chunk, speaker). `num_speakers` and `per_speaker_activity_masks`
# are placeholders for your segmentation output; each mask must have F entries.
embeddings = np.zeros((frame_feats.shape[0], num_speakers, W.shape[0]), dtype=np.float32)
for c in range(frame_feats.shape[0]):
    for s in range(num_speakers):
        mask = per_speaker_activity_masks[c, s]  # (F,) float32
        stats = masked_stats_pool(frame_feats[c], mask)
        embeddings[c, s] = stats @ W.T + b       # (256,)
```

Full reference implementation: [`core/speaker_diarization_pure_ort.py`](https://github.com/welcomyou/sherpa-vietnamese-asr/blob/main/core/speaker_diarization_pure_ort.py) (lines 707–810).

## Reproducing the split

```bash
# 1. Download the upstream ONNX export from altunenes
huggingface-cli download altunenes/speaker-diarization-community-1-onnx \
    --include embedding_model.onnx --local-dir pyannote-onnx/

# 2. Run the split script (≈30 s)
python split_pyannote_embedding.py \
    --input pyannote-onnx/embedding_model.onnx \
    --output_dir pyannote-onnx/
```

The script uses `onnx.utils.extract_model` to carve out the encoder subgraph (output tensor `/resnet/pool/Reshape_output_0`, just before stats pooling) and `onnx.numpy_helper` to dump the final Gemm initializers `resnet.seg_1.weight` and `resnet.seg_1.bias` as `.npy` files.

## Credits & License

- **Original PyTorch model**: [pyannote/speaker-diarization-community-1](https://huggingface.co/pyannote/speaker-diarization-community-1) (segmentation + WeSpeaker-ResNet34 embedding + VBx clustering pipeline)
- **ONNX conversion of the upstream embedding model**: [altunenes/speaker-diarization-community-1-onnx](https://huggingface.co/altunenes/speaker-diarization-community-1-onnx)
- **Authors**: Hervé Bredin et al. (pyannote), Hongji Wang et al. (WeSpeaker)

License: **CC-BY-4.0** (inherited from `pyannote/speaker-diarization-community-1`). Attribution required; commercial use allowed.

The original `pyannote/speaker-diarization-community-1` repository is gated and requires accepting a contact-information form.
CC-BY-4.0 itself does not impose that restriction on derivative works, but please consider visiting the original repo to support pyannote.

## Used by

- [Sherpa Vietnamese ASR](https://github.com/welcomyou/sherpa-vietnamese-asr): a pure-ORT diarization pipeline replacing `pyannote.audio` (~2.1 GB → ~60 MB of dependencies).