File size: 5,860 Bytes
cead59c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""Standalone Mega-ASR inference using the INT4 ONNX models.

This script demonstrates how to run Mega-ASR end-to-end from raw audio
using only the ONNX models in this repo plus the qwen-asr tokenizer.

Requirements:
    pip install onnxruntime numpy soundfile transformers qwen-asr
    pip install librosa   # optional; only needed for input audio that is not 16 kHz

Usage:
    python inference.py --audio examples/noise.wav
"""
from __future__ import annotations

import argparse
import re
import sys
from pathlib import Path

import numpy as np
import onnxruntime as ort
import soundfile as sf


# Directory containing this script. The tokenizer/processor files and the
# onnx/ model folder are all resolved relative to it.
HERE = Path(__file__).parent


def _ort(path, providers=("CPUExecutionProvider",)):
    """Open an ONNX Runtime inference session for the model at *path*.

    Graph optimizations are set to ``ORT_ENABLE_ALL``. ``providers`` is any
    iterable of execution-provider names, handed to ORT as a list in the
    given order.
    """
    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    provider_list = list(providers)
    model_path = str(path)
    return ort.InferenceSession(model_path, options, providers=provider_list)


# Short --providers aliases -> ONNX Runtime execution-provider names.
_PROVIDER_ALIASES = {
    "cuda": "CUDAExecutionProvider",
    "cpu": "CPUExecutionProvider",
    "webgpu": "WebGpuExecutionProvider",
}

# Chat prompt modeled on the Qwen3-ASR template, with real newlines between
# turns. It holds a single <|audio_pad|>; the scatter step in main() relies
# on the processor expanding it into one placeholder per audio frame.
_PROMPT = (
    "<|im_start|>system\nYou are a speech recognition model.<|im_end|>\n"
    "<|im_start|>user\nDetect the language and recognize the speech: "
    "<|audio_start|><|audio_pad|><|audio_end|><|im_end|>\n"
    "<|im_start|>assistant\n"
)

_SAMPLE_RATE = 16000
# main() zero-pads (or truncates) mel features to exactly this many frames
# before running the audio encoder.
_MAX_MEL_FRAMES = 3000


def _resolve_providers(spec: str) -> list[str]:
    """Translate a comma-separated alias list (e.g. ``"cuda,cpu"``) into ORT
    execution-provider names, keeping order and dropping duplicates.

    Aliases whose provider is not in the installed onnxruntime build are
    skipped. If nothing remains, ``["CPUExecutionProvider"]`` is returned.

    Raises:
        ValueError: if *spec* contains an alias not in ``_PROVIDER_ALIASES``.
    """
    available = set(ort.get_available_providers())
    resolved: list[str] = []
    for raw in spec.split(","):
        alias = raw.strip().lower()
        if not alias:
            continue
        name = _PROVIDER_ALIASES.get(alias)
        if name is None:
            raise ValueError(
                f"unknown provider {raw.strip()!r} "
                f"(expected one of: {', '.join(_PROVIDER_ALIASES)})"
            )
        if name in available and name not in resolved:
            resolved.append(name)
    return resolved or ["CPUExecutionProvider"]


def _load_audio(path: Path) -> np.ndarray:
    """Read *path* and return a mono float32 waveform at ``_SAMPLE_RATE`` Hz.

    Multi-channel files are averaged to mono. librosa is imported only when
    the file's sample rate differs from ``_SAMPLE_RATE``.

    Raises:
        ImportError: if resampling is needed but librosa is not installed.
    """
    audio, sr = sf.read(str(path), dtype="float32")
    if audio.ndim > 1:
        # soundfile returns (frames, channels) for multi-channel input.
        audio = audio.mean(axis=1)
    if sr != _SAMPLE_RATE:
        try:
            import librosa
        except ImportError as err:
            raise ImportError(
                f"librosa is required to resample {sr} Hz audio to "
                f"{_SAMPLE_RATE} Hz (pip install librosa)"
            ) from err
        audio = librosa.resample(audio, orig_sr=sr, target_sr=_SAMPLE_RATE)
    return audio


def _audio_frame_count(n_mel: int) -> int:
    """Return how many encoder output frames come from *n_mel* real mel frames.

    Each full 100-frame chunk yields 13 frames. A trailing partial chunk of
    ``r`` frames yields ``ceil(r / 8)``.
    """
    full_chunks, remainder = divmod(n_mel, 100)
    return full_chunks * 13 + (remainder + 7) // 8


def _greedy_decode(
    step: ort.InferenceSession,
    logits: np.ndarray,
    past_kvs: list[np.ndarray],
    embed_w: np.ndarray,
    attention_mask: np.ndarray,
    stop_ids: set[int],
    max_new_tokens: int,
) -> list[int]:
    """Greedily generate token ids with the single-step decoder session.

    *logits* and *past_kvs* are the prefill outputs. The KV tensors are fed
    back as ``past.{i}.key`` / ``past.{i}.value`` in output order (key, value
    per layer). Generation ends after a token in *stop_ids* (which is kept
    in the result) or once *max_new_tokens* ids have been produced.
    """
    num_layers = len(past_kvs) // 2
    attn = attention_mask.astype(np.int64)
    cur_len = attn.shape[1]
    nid = int(np.argmax(logits[0, -1, :]))
    gen = [nid]
    while nid not in stop_ids and len(gen) < max_new_tokens:
        # Extend the prefill mask by one position for the token being fed.
        attn = np.concatenate([attn, np.ones((1, 1), dtype=np.int64)], axis=1)
        feeds = {
            "inputs_embeds": np.array(embed_w[nid], dtype=np.float16)[None, None],
            "attention_mask": attn,
            "position_ids": np.array([[cur_len]], dtype=np.int64),
        }
        for i in range(num_layers):
            feeds[f"past.{i}.key"] = past_kvs[2 * i]
            feeds[f"past.{i}.value"] = past_kvs[2 * i + 1]
        logits, *past_kvs = step.run(None, feeds)
        nid = int(np.argmax(logits[0, -1, :]))
        gen.append(nid)
        cur_len += 1
    return gen


def main() -> int:
    """Transcribe ``--audio`` with the INT4 ONNX models and print the text.

    Returns the process exit code (0 on success).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--audio", required=True, type=Path)
    ap.add_argument("--max-new-tokens", type=int, default=80)
    ap.add_argument("--providers", default="cuda,cpu",
                    help="Comma-separated ORT providers, e.g. cuda,cpu")
    args = ap.parse_args()
    if args.max_new_tokens < 1:
        ap.error("--max-new-tokens must be at least 1")
    try:
        providers = _resolve_providers(args.providers)
    except ValueError as err:
        ap.error(str(err))

    # 1) Tokenizer + processor (needs HF transformers + qwen-asr installed)
    from transformers import AutoTokenizer
    from qwen_asr.core.transformers_backend.processing_qwen3_asr import Qwen3ASRProcessor
    tokenizer = AutoTokenizer.from_pretrained(str(HERE))
    processor = Qwen3ASRProcessor.from_pretrained(str(HERE))

    # 2) Read audio (mono, 16 kHz) + build mel features
    audio = _load_audio(args.audio)
    inputs = processor(text=[_PROMPT], audio=[audio], return_tensors="np", padding=True)
    text_ids = inputs["input_ids"]                 # (1, L)
    attention_mask = inputs["attention_mask"]      # (1, L)
    input_features = inputs["input_features"]      # (1, 128, T_mel)

    # 3) Audio encoder (fixed-length mel input)
    n_mel = input_features.shape[-1]
    if n_mel > _MAX_MEL_FRAMES:
        print(f"warning: audio has {n_mel} mel frames; only the first "
              f"{_MAX_MEL_FRAMES} are transcribed", file=sys.stderr)
        input_features = input_features[..., :_MAX_MEL_FRAMES]
        n_mel = _MAX_MEL_FRAMES
    mel_padded = np.pad(input_features.astype(np.float32),
                        ((0, 0), (0, 0), (0, _MAX_MEL_FRAMES - n_mel)))
    enc = _ort(HERE / "onnx" / "audio_encoder_int4.onnx", providers)
    audio_embeds = enc.run(["audio_embeds"], {"mel": mel_padded})[0]   # (1, 390, 2048)
    # Keep only the frames computed from real (non-padding) mel input.
    audio_embeds = audio_embeds[:, :_audio_frame_count(n_mel)]         # (1, N, 2048)

    # 4) Embed text tokens + scatter audio into placeholder positions
    # The embedding table is published as embed_tokens.npy in the repo
    # (extracted from the original Qwen3-ASR weights for portability).
    embed_path = HERE / "onnx" / "embed_tokens.npy"
    if not embed_path.exists():
        raise FileNotFoundError(
            "embed_tokens.npy missing — re-extract from the source PT model."
        )
    # Memory-map the table so only the rows actually looked up are read,
    # instead of loading and casting the whole vocabulary up front.
    embed_w = np.load(embed_path, mmap_mode="r")
    inputs_embeds = np.asarray(embed_w[text_ids[0]], dtype=np.float16)[None]  # (1, L, hidden)
    audio_token_id = tokenizer.convert_tokens_to_ids("<|audio_pad|>")
    placeholder_idx = np.flatnonzero(text_ids[0] == audio_token_id)
    n_audio = audio_embeds.shape[1]
    if placeholder_idx.size > n_audio:
        # Happens e.g. after truncation: the processor sized the placeholders
        # for the full clip, so the surplus keep the raw <|audio_pad|> embedding.
        print(f"warning: {placeholder_idx.size} audio placeholders but only "
              f"{n_audio} audio frames", file=sys.stderr)
    n_replace = min(placeholder_idx.size, n_audio)
    inputs_embeds[0, placeholder_idx[:n_replace]] = audio_embeds[0, :n_replace]

    # 5) Position ids (1D)
    pos_ids = np.arange(text_ids.shape[1], dtype=np.int64)[None]

    # 6) Prefill ONNX
    prefill = _ort(HERE / "onnx" / "decoder_prefill_int4.onnx", providers)
    logits, *past_kvs = prefill.run(None, {
        "inputs_embeds": inputs_embeds,
        "attention_mask": attention_mask.astype(np.int64),
        "position_ids": pos_ids,
    })

    # 7) Greedy decode with step ONNX. Stop on EOS and on <|im_end|>, which
    # terminates turns in the prompt format above.
    stop_ids = {
        tid
        for tid in (tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|im_end|>"))
        if tid is not None
    }
    step = _ort(HERE / "onnx" / "decoder_step_int4.onnx", providers)
    gen = _greedy_decode(step, logits, past_kvs, embed_w, attention_mask,
                         stop_ids, args.max_new_tokens)

    text = tokenizer.decode(gen, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    print(f"TRANSCRIPTION: {text}")
    return 0


if __name__ == "__main__":
    sys.exit(main())