| |
| """ |
| Split pyannote Community-1 embedding ONNX into encoder + projection weights. |
| |
| Input: embedding_model.onnx (full pipeline: fbank -> encoder -> stats_pool -> Gemm -> 256d) |
| Output: embedding_encoder.onnx - everything before stats_pool, output frame features (B, 256, F) |
| resnet_seg_1_weight.npy - final Gemm weight (256, 5120) |
| resnet_seg_1_bias.npy - final Gemm bias (256,) |
| |
| Why split? Stats pooling in the full graph operates over ALL frames of the input. |
| For diarization with overlapping speakers, we need a MASKED stats pool (weighted by |
| per-speaker activity mask). Doing this in NumPy after running the encoder once per |
| batch of chunks is ~30x faster than running the full model once per (chunk, speaker). |
| See core/speaker_diarization_pure_ort.py for usage. |
| |
| Usage: |
| python convert_onnx/split_pyannote_embedding.py \\ |
| --input models/pyannote-onnx/embedding_model.onnx \\ |
| --output_dir models/pyannote-onnx/ |
| |
| Requires: onnx>=1.14, numpy |
| """ |
|
|
| import argparse |
| import os |
| from pathlib import Path |
|
|
| import numpy as np |
| import onnx |
| from onnx import numpy_helper |
|
|
| ENCODER_OUTPUT_TENSOR = "/resnet/pool/Reshape_output_0" |
| GEMM_WEIGHT_NAME = "resnet.seg_1.weight" |
| GEMM_BIAS_NAME = "resnet.seg_1.bias" |
|
|
|
|
| def split(input_path: str, output_dir: str) -> None: |
| input_path = Path(input_path).resolve() |
| output_dir = Path(output_dir).resolve() |
| output_dir.mkdir(parents=True, exist_ok=True) |
|
|
| if not input_path.exists(): |
| raise FileNotFoundError(f"Input ONNX not found: {input_path}") |
|
|
| print(f"[1/3] Loading {input_path.name}...") |
| model = onnx.load(str(input_path)) |
| input_name = model.graph.input[0].name |
|
|
| print(f"[2/3] Extracting encoder subgraph (output: {ENCODER_OUTPUT_TENSOR})...") |
| encoder_path = output_dir / "embedding_encoder.onnx" |
| onnx.utils.extract_model( |
| str(input_path), |
| str(encoder_path), |
| input_names=[input_name], |
| output_names=[ENCODER_OUTPUT_TENSOR], |
| ) |
| enc_size_mb = encoder_path.stat().st_size / 1024 / 1024 |
| print(f" -> {encoder_path.name} ({enc_size_mb:.1f} MB)") |
|
|
| print(f"[3/3] Extracting Gemm projection weights -> .npy...") |
| found = {} |
| for init in model.graph.initializer: |
| if init.name == GEMM_WEIGHT_NAME: |
| found["weight"] = numpy_helper.to_array(init) |
| elif init.name == GEMM_BIAS_NAME: |
| found["bias"] = numpy_helper.to_array(init) |
|
|
| if "weight" not in found or "bias" not in found: |
| raise RuntimeError( |
| f"Could not find Gemm initializers ({GEMM_WEIGHT_NAME}, {GEMM_BIAS_NAME}) " |
| f"in {input_path.name}. Is this the right model?" |
| ) |
|
|
| weight_path = output_dir / "resnet_seg_1_weight.npy" |
| bias_path = output_dir / "resnet_seg_1_bias.npy" |
| np.save(weight_path, found["weight"]) |
| np.save(bias_path, found["bias"]) |
| print(f" -> {weight_path.name} (shape={found['weight'].shape}, " |
| f"{weight_path.stat().st_size / 1024:.1f} KB)") |
| print(f" -> {bias_path.name} (shape={found['bias'].shape}, " |
| f"{bias_path.stat().st_size:.0f} B)") |
|
|
| print("\nDone. Verify with core/speaker_diarization_pure_ort.py — it auto-detects") |
| print("embedding_encoder.onnx + resnet_seg_1_*.npy and uses the fast batched path.") |
|
|
|
|
| def main(): |
| parser = argparse.ArgumentParser(description=__doc__.split("\n\n")[0]) |
| parser.add_argument( |
| "--input", |
| default="models/pyannote-onnx/embedding_model.onnx", |
| help="Path to full embedding_model.onnx (download from " |
| "https://huggingface.co/altunenes/speaker-diarization-community-1-onnx)", |
| ) |
| parser.add_argument( |
| "--output_dir", |
| default="models/pyannote-onnx/", |
| help="Output directory for encoder + .npy files", |
| ) |
| args = parser.parse_args() |
| split(args.input, args.output_dir) |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|