mega-asr-onnx / inference.py
Reza2kn's picture
Initial upload: INT4 ONNX (encoder + prefill + step) + tokenizer + examples + inference.py
cead59c verified
"""Standalone Mega-ASR inference using the INT4 ONNX models.
This script demonstrates how to run Mega-ASR end-to-end from raw audio
using only the ONNX models in this repo plus the qwen-asr tokenizer.
Requirements:
pip install onnxruntime numpy soundfile transformers qwen-asr
(librosa is additionally required when the input audio is not sampled at 16 kHz)
Usage:
python inference.py --audio examples/noise.wav
"""
from __future__ import annotations
import argparse
import re
import sys
from pathlib import Path
import numpy as np
import onnxruntime as ort
import soundfile as sf
# Directory that holds this script; the tokenizer/processor files and the
# ``onnx/`` model folder are looked up relative to it.
HERE = Path(__file__).parents[0]
def _ort(path, providers=("CPUExecutionProvider",)):
    """Open an ONNX Runtime inference session for the model file at ``path``.

    All graph optimizations are enabled. ``providers`` is an iterable of
    execution-provider names and is handed to ONNX Runtime as a list.
    """
    session_options = ort.SessionOptions()
    session_options.graph_optimization_level = (
        ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    )
    provider_list = [*providers]
    return ort.InferenceSession(
        str(path), sess_options=session_options, providers=provider_list
    )
def _resolve_providers(spec):
    """Map a comma-separated alias list (e.g. ``"cuda,cpu"``) to ORT provider names.

    Unknown aliases are reported on stderr and skipped; duplicates are dropped.
    Falls back to the CPU provider when nothing usable remains.
    """
    aliases = {
        "cuda": "CUDAExecutionProvider",
        "cpu": "CPUExecutionProvider",
        "webgpu": "WebGpuExecutionProvider",
    }
    resolved = []
    for raw in spec.split(","):
        key = raw.strip().lower()
        if not key:
            continue
        provider = aliases.get(key)
        if provider is None:
            print(f"warning: ignoring unknown provider {raw.strip()!r}",
                  file=sys.stderr)
        elif provider not in resolved:
            resolved.append(provider)
    if not resolved:
        print("warning: no usable provider requested; falling back to cpu",
              file=sys.stderr)
        resolved.append("CPUExecutionProvider")
    return resolved


def main() -> int:
    """Transcribe one audio file with the ONNX models and print the result.

    Command-line options:
        --audio           path of the audio file to transcribe (required)
        --max-new-tokens  upper bound on the number of generated tokens
        --providers       comma-separated ORT provider aliases (cuda, cpu, webgpu)

    Returns:
        0 on success. Failures propagate as exceptions.

    Raises:
        FileNotFoundError: if ``onnx/embed_tokens.npy`` is missing.
        RuntimeError: if the step model exposes no ``past.N.key`` inputs or the
            prefill model returns fewer cache tensors than the step model needs.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--audio", required=True, type=Path)
    ap.add_argument("--max-new-tokens", type=int, default=80)
    ap.add_argument("--providers", default="cuda,cpu",
                    help="Comma-separated ORT providers, e.g. cuda,cpu")
    args = ap.parse_args()
    providers = _resolve_providers(args.providers)

    # 1) Tokenizer + processor (needs HF transformers + qwen-asr installed).
    #    Imported lazily so that --help works without the heavy dependencies.
    from transformers import AutoTokenizer
    from qwen_asr.core.transformers_backend.processing_qwen3_asr import Qwen3ASRProcessor
    tokenizer = AutoTokenizer.from_pretrained(str(HERE))
    processor = Qwen3ASRProcessor.from_pretrained(str(HERE))

    # 2) Read audio + build mel features
    audio, sr = sf.read(str(args.audio), dtype="float32")
    if audio.ndim > 1:
        # soundfile returns (frames, channels) for multi-channel files; downmix
        # to mono so that resampling and feature extraction see a 1-D signal.
        audio = audio.mean(axis=1)
    if sr != 16000:
        import librosa
        audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)

    # The processor expects a prompt with the audio placeholder. Build it as
    # the Qwen3-ASR chat template would (with real newlines between turns).
    prompt = (
        "<|im_start|>system\nYou are a speech recognition model.<|im_end|>\n"
        "<|im_start|>user\nDetect the language and recognize the speech: "
        "<|audio_start|><|audio_pad|><|audio_end|><|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    inputs = processor(text=[prompt], audio=[audio], return_tensors="np", padding=True)
    text_ids = inputs["input_ids"]                # (1, L)
    attention_mask = inputs["attention_mask"]     # (1, L)
    input_features = inputs["input_features"]     # (1, 128, T_mel)

    # 3) Audio encoder: the mel input is cropped / zero-padded to exactly
    #    max_mel_frames frames before being fed to the encoder.
    max_mel_frames = 3000
    T_mel = input_features.shape[-1]
    if T_mel > max_mel_frames:
        print(f"warning: audio has {T_mel} mel frames; truncating to "
              f"{max_mel_frames}", file=sys.stderr)
        input_features = input_features[..., :max_mel_frames]
        T_mel = max_mel_frames
    mel_padded = np.pad(input_features.astype(np.float32),
                        ((0, 0), (0, 0), (0, max_mel_frames - T_mel)))
    enc = _ort(HERE / "onnx" / "audio_encoder_int4.onnx", providers)
    audio_embeds = enc.run(["audio_embeds"], {"mel": mel_padded})[0]  # (1, 390, 2048)
    # Keep only the encoder frames that correspond to real (unpadded) audio:
    # 13 frames per full 100-mel-frame chunk plus ceil(n / 8) for the last chunk.
    real_chunks = (T_mel + 99) // 100
    last_chunk_mel = T_mel - (real_chunks - 1) * 100
    real_audio_frames = (real_chunks - 1) * 13 + (last_chunk_mel + 7) // 8
    audio_embeds = audio_embeds[:, :real_audio_frames]  # (1, N, 2048)

    # 4) Embed text tokens + scatter audio into placeholder positions
    #    The embedding table is published as embed_tokens.npy in the repo
    #    (extracted from the original Qwen3-ASR weights for portability).
    embed_path = HERE / "onnx" / "embed_tokens.npy"
    if not embed_path.exists():
        raise FileNotFoundError(
            "embed_tokens.npy missing — re-extract from the source PT model."
        )
    embed_w = np.load(embed_path).astype(np.float16)
    # Fancy indexing yields a fresh array, so the in-place scatter below never
    # touches the embedding table itself.
    inputs_embeds = embed_w[text_ids[0]][None]  # (1, L, hidden)
    audio_token_id = tokenizer.convert_tokens_to_ids("<|audio_pad|>")
    placeholder_idx = np.where(text_ids[0] == audio_token_id)[0]
    if len(placeholder_idx) != audio_embeds.shape[1]:
        print(f"warning: {len(placeholder_idx)} audio placeholders but "
              f"{audio_embeds.shape[1]} audio frames; using the smaller count",
              file=sys.stderr)
    # Replace placeholder embeddings with audio_embeds[0] in order
    n_replace = min(len(placeholder_idx), audio_embeds.shape[1])
    inputs_embeds[0, placeholder_idx[:n_replace]] = audio_embeds[0, :n_replace]

    # 5) Position ids (1D)
    pos_ids = np.arange(text_ids.shape[1], dtype=np.int64)[None]

    # 6) Prefill ONNX: first output is the logits, the rest is the KV cache.
    prefill = _ort(HERE / "onnx" / "decoder_prefill_int4.onnx", providers)
    feeds = {
        "inputs_embeds": inputs_embeds.astype(np.float16),
        "attention_mask": attention_mask.astype(np.int64),
        "position_ids": pos_ids,
    }
    logits, *past_kvs = prefill.run(None, feeds)
    eos = tokenizer.eos_token_id

    # 7) Greedy decode with step ONNX
    step = _ort(HERE / "onnx" / "decoder_step_int4.onnx", providers)
    # Derive the layer count from the step model's own cache inputs instead of
    # hard-coding it, so the script keeps working for other decoder depths.
    num_layers = sum(
        1 for node in step.get_inputs()
        if re.fullmatch(r"past\.\d+\.key", node.name)
    )
    if num_layers == 0:
        raise RuntimeError("decoder step model exposes no 'past.N.key' inputs")
    if len(past_kvs) < 2 * num_layers:
        raise RuntimeError(
            f"prefill returned {len(past_kvs)} cache tensors, but the step "
            f"model needs {2 * num_layers}"
        )

    cur_len = inputs_embeds.shape[1]
    nid = int(np.argmax(logits[0, -1, :]))
    gen = [nid]
    for _ in range(args.max_new_tokens - 1):
        if nid == eos:
            break
        f = {
            "inputs_embeds": embed_w[nid][None, None].astype(np.float16),
            "attention_mask": np.ones((1, cur_len + 1), dtype=np.int64),
            "position_ids": np.array([[cur_len]], dtype=np.int64),
        }
        # Cache tensors are laid out as key0, value0, key1, value1, ...
        for i in range(num_layers):
            f[f"past.{i}.key"] = past_kvs[2 * i]
            f[f"past.{i}.value"] = past_kvs[2 * i + 1]
        logits, *past_kvs = step.run(None, f)
        nid = int(np.argmax(logits[0, -1, :]))
        gen.append(nid)
        cur_len += 1

    text = tokenizer.decode(gen, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    print(f"TRANSCRIPTION: {text}")
    return 0
# Script entry point: run main() and use its return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())