Pyannote Community-1 Embedding: Split for Masked Pooling (ONNX)

A 3-piece split of the speaker embedding model from pyannote/speaker-diarization-community-1 that enables ~30× faster batched inference with masked stats pooling for diarization with overlapping speakers.

This repo derives from altunenes/speaker-diarization-community-1-onnx (which exported the original PyTorch ResNet-based embedding model to ONNX). Here we split the graph into three parts so the slow per-(chunk, speaker) loop can be replaced by a single batched encoder pass + NumPy pooling.

What's in this repo

| File | Size | Purpose |
|------|------|---------|
| embedding_encoder.onnx | 20 MB | ResNet encoder; output: frame features (B, 2560, F) |
| resnet_seg_1_weight.npy | 5 MB | Final Gemm projection weight (256, 5120) |
| resnet_seg_1_bias.npy | 1 KB | Final Gemm projection bias (256,) |
| split_pyannote_embedding.py | – | Reproducibility script |
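
After downloading, you can sanity-check the encoder piece with a couple of lines of onnxruntime introspection (illustrative only, not part of the repo; the reported tensor names and symbolic dimensions are whatever the upstream export used):

import onnxruntime as ort

sess = ort.InferenceSession("embedding_encoder.onnx", providers=["CPUExecutionProvider"])
print([(i.name, i.shape) for i in sess.get_inputs()])   # expect one fbank-feature input
print([(o.name, o.shape) for o in sess.get_outputs()])  # expect one (B, 2560, F) frame-feature output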

Why split? The masked-pooling problem

The full embedding_model.onnx graph is fbank → ResNet → StatsPool → Gemm → 256-d. The internal StatsPool op pools all frames of the input; there is no way to mask which frames belong to which speaker.

In diarization, each 10-second chunk may contain overlapping speakers. To extract one embedding per (chunk, speaker), you need a masked stats pool that weights frames by per-speaker activity. Two options:

| Approach | Speed | Correctness |
|----------|-------|-------------|
| Run the full model once per (chunk, speaker), filtering frames first | ~720 chunks × 3 speakers = 2160 ORT calls for a 2-hour file | ⚠️ Wrong: subset-filtering is not the same as weighted pooling; the variance is computed over a different denominator (see the toy comparison below) |
| This split: batch-encode 64 chunks at a time, then NumPy-pool with mask | ~12 ORT calls for 2 hours | ✅ Matches pyannote.audio.models.blocks.pooling.StatsPool._pool() |
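
To make that correctness difference concrete, here is a toy comparison (illustrative only, not from this repo): hard-thresholding a soft activity mask and then taking plain statistics over the surviving frames does not reproduce the weighted statistics that masked_stats_pool in the Quick start computes.

import numpy as np

rng = np.random.default_rng(0)
feat = rng.standard_normal((4, 100)).astype(np.float32)   # toy frame features (D, F)
mask = rng.random(100).astype(np.float32)                  # soft per-frame speaker weight in [0, 1]

# (a) subset-filtering: keep frames where the speaker is "mostly" active, then unweighted stats
keep = mask > 0.5
sub_mean = feat[:, keep].mean(axis=1)
sub_std = feat[:, keep].std(axis=1)

# (b) weighted pooling with the same formula as StatsPool._pool()
v1 = mask.sum() + 1e-8
w_mean = (feat * mask).sum(axis=1) / v1
w_var = ((feat - w_mean[:, None]) ** 2 * mask).sum(axis=1) / (v1 - (mask ** 2).sum() / v1 + 1e-8)

print(np.abs(sub_mean - w_mean).max())         # non-zero: the means already differ
print(np.abs(sub_std - np.sqrt(w_var)).max())  # non-zero: the std uses a different denominator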

Combined with run_with_iobinding and enable_cpu_mem_arena=False, this gives a ~30× speedup on CPU for the embedding extraction stage of long-form diarization.
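
The io-binding pattern looks roughly like this (a minimal sketch reusing the encoder session and fbank_batch defined in the Quick start below; the downstream pipeline may bind and reuse buffers differently):

out_name = encoder.get_outputs()[0].name            # frame-feature output of the encoder

io = encoder.io_binding()
io.bind_cpu_input("fbank_features", fbank_batch)    # bind the existing CPU array, no extra copy in
io.bind_output(out_name, device_type="cpu")         # let ORT manage the CPU output buffer
encoder.run_with_iobinding(io)
frame_feats = io.copy_outputs_to_cpu()[0]            # (64, 2560, F) float32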

Quick start

import numpy as np
import onnxruntime as ort
from huggingface_hub import snapshot_download

local = snapshot_download("welcomyou/pyannote-community-1-onnx-split")

opts = ort.SessionOptions()
opts.enable_cpu_mem_arena = False  # avoid 1.8 GB arena that never shrinks
encoder = ort.InferenceSession(f"{local}/embedding_encoder.onnx", opts,
                                providers=["CPUExecutionProvider"])
W = np.load(f"{local}/resnet_seg_1_weight.npy")  # (256, 5120)
b = np.load(f"{local}/resnet_seg_1_bias.npy")    # (256,)


def masked_stats_pool(frame_feat, mask):
    """frame_feat: (D=2560, F),  mask: (F,) float [0..1] β€” per-frame speaker weight.
    Returns (5120,) = concat(weighted_mean, weighted_std).
    Matches pyannote.audio StatsPool._pool()."""
    w = mask[np.newaxis, :]
    v1 = w.sum() + 1e-8
    mean = (frame_feat * w).sum(axis=1) / v1
    var = ((frame_feat - mean[:, None]) ** 2 * w).sum(axis=1) / (v1 - (w * w).sum() / v1 + 1e-8)
    return np.concatenate([mean, np.sqrt(var)])


# Batched encoder pass: (64 chunks, 998 frames, 80 dim fbank)
fbank_batch = ...  # shape (64, 998, 80) float32
frame_feats = encoder.run(None, {"fbank_features": fbank_batch})[0]  # (64, 2560, F)

# For each (chunk, speaker) compute embedding:
for c in range(64):
    for s in range(num_speakers):
        # per_speaker_activity_masks: per-frame speaker activity (e.g. from the
        # segmentation stage), aligned to the F encoder frames
        mask = per_speaker_activity_masks[c, s]  # (F,) float32
        stats = masked_stats_pool(frame_feats[c], mask)
        emb = stats @ W.T + b  # (256,) speaker embedding

Full reference implementation: core/speaker_diarization_pure_ort.py (lines 707–810).
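
If the pure-Python double loop above ever shows up in profiles, the same pooling can be vectorized over all (chunk, speaker) pairs. The sketch below is only a suggestion, not taken from the reference implementation; it assumes activity masks of shape (C, S, F) and loops over chunks only, to keep the per-chunk intermediate small.

def batched_masked_stats_pool(frame_feats, masks):
    """frame_feats: (C, 2560, F), masks: (C, S, F) -> (C, S, 5120)."""
    v1 = masks.sum(axis=2) + 1e-8                        # (C, S) total weight per speaker
    v2 = (masks ** 2).sum(axis=2)                        # (C, S)
    mean = np.einsum("cdf,csf->csd", frame_feats, masks) / v1[..., None]
    var = np.empty_like(mean)
    for c in range(frame_feats.shape[0]):                # chunk loop keeps the (S, D, F) temp small
        diff2 = (frame_feats[c][None] - mean[c][..., None]) ** 2
        var[c] = np.einsum("sdf,sf->sd", diff2, masks[c]) / (v1[c] - v2[c] / v1[c] + 1e-8)[:, None]
    return np.concatenate([mean, np.sqrt(var)], axis=2)

stats = batched_masked_stats_pool(frame_feats, per_speaker_activity_masks)  # assumes (C, S, F) masks
embs = stats @ W.T + b                                   # (C, S, 256)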

Reproducing the split

# 1. Download the upstream ONNX export from altunenes
huggingface-cli download altunenes/speaker-diarization-community-1-onnx \
    --include embedding_model.onnx --local-dir pyannote-onnx/

# 2. Run the split script (≈30 s)
python split_pyannote_embedding.py \
    --input pyannote-onnx/embedding_model.onnx \
    --output_dir pyannote-onnx/

The script uses onnx.utils.extract_model to carve out the encoder subgraph (output tensor /resnet/pool/Reshape_output_0, just before stats pooling) and onnx.numpy_helper to dump the final Gemm initializers resnet.seg_1.weight and resnet.seg_1.bias as .npy files.
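
A condensed sketch of what the script does (the actual split_pyannote_embedding.py may differ in details; in particular, the encoder's input tensor is assumed here to be named "fbank_features", the name the Quick start feeds):

import numpy as np
import onnx
from onnx import numpy_helper

src = "pyannote-onnx/embedding_model.onnx"

# 1) Carve out the ResNet encoder: everything from the fbank features up to the
#    tensor that feeds StatsPool.
onnx.utils.extract_model(
    src,
    "pyannote-onnx/embedding_encoder.onnx",
    input_names=["fbank_features"],                    # assumed input tensor name
    output_names=["/resnet/pool/Reshape_output_0"],
)

# 2) Dump the final Gemm initializers so the 5120 -> 256 projection runs in NumPy.
model = onnx.load(src)
inits = {t.name: t for t in model.graph.initializer}
np.save("pyannote-onnx/resnet_seg_1_weight.npy", numpy_helper.to_array(inits["resnet.seg_1.weight"]))
np.save("pyannote-onnx/resnet_seg_1_bias.npy", numpy_helper.to_array(inits["resnet.seg_1.bias"]))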

Credits & License

License: CC-BY-4.0 (inherited from pyannote/speaker-diarization-community-1). Attribution required; commercial use allowed.

The original pyannote/speaker-diarization-community-1 repository is gated and requires filling in a contact-information form. CC-BY-4.0 itself does not impose that restriction on derivative works, but please consider visiting the original repo to support pyannote.

Used by

  • Sherpa Vietnamese ASR: pure-ORT diarization pipeline replacing pyannote.audio (~2.1 GB → ~60 MB of dependencies).