File size: 3,930 Bytes
cde44c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#!/usr/bin/env python3
"""
Split pyannote Community-1 embedding ONNX into encoder + projection weights.

Input:  embedding_model.onnx (full pipeline: fbank -> encoder -> stats_pool -> Gemm -> 256d)
Output: embedding_encoder.onnx     - everything before stats_pool, output frame features (B, 2560, F)
                                     (stats pool concatenates mean + std -> 2 * 2560 = 5120 = Gemm input)
        resnet_seg_1_weight.npy    - final Gemm weight (256, 5120)
        resnet_seg_1_bias.npy      - final Gemm bias (256,)

Why split? Stats pooling in the full graph operates over ALL frames of the input.
For diarization with overlapping speakers, we need a MASKED stats pool (weighted by
per-speaker activity mask). Doing this in NumPy after running the encoder once per
batch of chunks is ~30x faster than running the full model once per (chunk, speaker).
See core/speaker_diarization_pure_ort.py for usage.

Usage:
    python convert_onnx/split_pyannote_embedding.py \\
        --input  models/pyannote-onnx/embedding_model.onnx \\
        --output_dir models/pyannote-onnx/

Requires: onnx>=1.14, numpy
"""

import argparse
import os
from pathlib import Path

import numpy as np
import onnx
from onnx import numpy_helper

# Name of the intermediate tensor used as the cut point: split() passes it as
# the sole output name to onnx.utils.extract_model, so it becomes the output of
# embedding_encoder.onnx. Judging by its name it is the Reshape inside the
# stats-pooling layer (i.e. the last tensor before pooling) — this depends on
# the exporter's node naming, so re-check it if the source model changes.
ENCODER_OUTPUT_TENSOR = "/resnet/pool/Reshape_output_0"
# Initializer names of the final projection (Gemm) layer; split() looks these
# up in graph.initializer and saves them as resnet_seg_1_{weight,bias}.npy.
GEMM_WEIGHT_NAME = "resnet.seg_1.weight"
GEMM_BIAS_NAME = "resnet.seg_1.bias"


def split(input_path: str, output_dir: str) -> None:
    """Split a full pyannote embedding ONNX model into encoder + Gemm weights.

    Produces, inside ``output_dir``:

    * ``embedding_encoder.onnx``  - subgraph from the model's data input to
      ``ENCODER_OUTPUT_TENSOR``, extracted with ``onnx.utils.extract_model``.
    * ``resnet_seg_1_weight.npy`` - the ``GEMM_WEIGHT_NAME`` initializer.
    * ``resnet_seg_1_bias.npy``   - the ``GEMM_BIAS_NAME`` initializer.

    Every check runs before anything is written. Pointing this at the wrong
    model therefore leaves no partial outputs and no freshly created output
    directory.

    Args:
        input_path: Path to the full ``embedding_model.onnx``.
        output_dir: Destination directory; created (with parents) if missing.

    Raises:
        FileNotFoundError: ``input_path`` does not exist.
        RuntimeError: the graph has no non-initializer input, no node produces
            ``ENCODER_OUTPUT_TENSOR``, or the Gemm initializers are missing.
    """
    src = Path(input_path).resolve()
    out_dir = Path(output_dir).resolve()

    # Validate the input before touching the filesystem, so that a mistyped
    # --input does not leave an empty output directory behind.
    if not src.exists():
        raise FileNotFoundError(f"Input ONNX not found: {src}")

    print(f"[1/3] Loading {src.name}...")
    model = onnx.load(str(src))
    graph = model.graph

    # Models exported with initializers listed as graph inputs
    # (keep_initializers_as_inputs / IR < 4) also list weights in graph.input.
    # Skip those when picking the real data input instead of taking input[0].
    initializer_names = {init.name for init in graph.initializer}
    data_inputs = [inp.name for inp in graph.input if inp.name not in initializer_names]
    if not data_inputs:
        raise RuntimeError(f"No data (non-initializer) input found in {src.name}.")
    input_name = data_inputs[0]

    # Fail with a clear message rather than an opaque extract_model error
    # when the cut point does not exist in this graph.
    produced = {name for node in graph.node for name in node.output}
    if ENCODER_OUTPUT_TENSOR not in produced:
        raise RuntimeError(
            f"Tensor {ENCODER_OUTPUT_TENSOR!r} is not produced by any node "
            f"in {src.name}. Is this the right model?"
        )

    # Read the projection weights now, before writing anything, so a missing
    # initializer cannot leave an encoder file without its matching .npy files.
    found = {}
    for init in graph.initializer:
        if init.name == GEMM_WEIGHT_NAME:
            found["weight"] = numpy_helper.to_array(init)
        elif init.name == GEMM_BIAS_NAME:
            found["bias"] = numpy_helper.to_array(init)
    if "weight" not in found or "bias" not in found:
        raise RuntimeError(
            f"Could not find Gemm initializers ({GEMM_WEIGHT_NAME}, {GEMM_BIAS_NAME}) "
            f"in {src.name}. Is this the right model?"
        )

    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"[2/3] Extracting encoder subgraph (output: {ENCODER_OUTPUT_TENSOR})...")
    encoder_path = out_dir / "embedding_encoder.onnx"
    # extract_model takes file paths (not a ModelProto) and re-reads the source.
    onnx.utils.extract_model(
        str(src),
        str(encoder_path),
        input_names=[input_name],
        output_names=[ENCODER_OUTPUT_TENSOR],
    )
    enc_size_mb = encoder_path.stat().st_size / 1024 / 1024
    print(f"      -> {encoder_path.name} ({enc_size_mb:.1f} MB)")

    print("[3/3] Extracting Gemm projection weights -> .npy...")
    weight_path = out_dir / "resnet_seg_1_weight.npy"
    bias_path = out_dir / "resnet_seg_1_bias.npy"
    np.save(weight_path, found["weight"])
    np.save(bias_path, found["bias"])
    print(f"      -> {weight_path.name} (shape={found['weight'].shape}, "
          f"{weight_path.stat().st_size / 1024:.1f} KB)")
    print(f"      -> {bias_path.name} (shape={found['bias'].shape}, "
          f"{bias_path.stat().st_size:.0f} B)")

    print("\nDone. Verify with core/speaker_diarization_pure_ort.py — it auto-detects")
    print("embedding_encoder.onnx + resnet_seg_1_*.npy and uses the fast batched path.")


def main(argv=None) -> None:
    """Command-line entry point: parse arguments and run ``split``.

    Args:
        argv: Argument list to parse, without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]`` exactly as before.
            Passing a list lets tests or other scripts drive the CLI.
    """
    # __doc__ is None under `python -OO` (docstrings stripped); calling .split
    # on it directly would crash before any argument is parsed. The first
    # paragraph of the module docstring is used as the --help description.
    description = (__doc__ or "").strip().split("\n\n")[0]
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        "--input",
        default="models/pyannote-onnx/embedding_model.onnx",
        help="Path to full embedding_model.onnx (download from "
             "https://huggingface.co/altunenes/speaker-diarization-community-1-onnx)",
    )
    parser.add_argument(
        "--output_dir",
        default="models/pyannote-onnx/",
        help="Output directory for encoder + .npy files",
    )
    args = parser.parse_args(argv)
    split(args.input, args.output_dir)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()