"""Standalone Mega-ASR inference using the INT4 ONNX models. This script demonstrates how to run Mega-ASR end-to-end from raw audio using only the ONNX models in this repo plus the qwen-asr tokenizer. Requirements: pip install onnxruntime numpy soundfile transformers qwen-asr Usage: python inference.py --audio examples/noise.wav """ from __future__ import annotations import argparse import re import sys from pathlib import Path import numpy as np import onnxruntime as ort import soundfile as sf HERE = Path(__file__).parent def _ort(path, providers=("CPUExecutionProvider",)): so = ort.SessionOptions() so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL return ort.InferenceSession(str(path), so, providers=list(providers)) def main() -> int: ap = argparse.ArgumentParser() ap.add_argument("--audio", required=True, type=Path) ap.add_argument("--max-new-tokens", type=int, default=80) ap.add_argument("--providers", default="cuda,cpu", help="Comma-separated ORT providers, e.g. cuda,cpu") args = ap.parse_args() providers = [] for p in args.providers.split(","): p = p.strip().lower() if p == "cuda": providers.append("CUDAExecutionProvider") elif p == "cpu": providers.append("CPUExecutionProvider") elif p == "webgpu": providers.append("WebGpuExecutionProvider") # 1) Tokenizer + processor (needs HF transformers + qwen-asr installed) from transformers import AutoTokenizer from qwen_asr.core.transformers_backend.processing_qwen3_asr import Qwen3ASRProcessor tokenizer = AutoTokenizer.from_pretrained(str(HERE)) processor = Qwen3ASRProcessor.from_pretrained(str(HERE)) # 2) Read audio + build mel features audio, sr = sf.read(str(args.audio)) if sr != 16000: import librosa audio = librosa.resample(audio.astype(np.float32), orig_sr=sr, target_sr=16000) # The processor expects a prompt with the audio placeholder. Build it as # the Qwen3-ASR chat template would. prompt = ( "<|im_start|>system\\nYou are a speech recognition model.<|im_end|>\\n" "<|im_start|>user\\nDetect the language and recognize the speech: " "<|audio_start|><|audio_pad|><|audio_end|><|im_end|>\\n" "<|im_start|>assistant\\n" ) inputs = processor(text=[prompt], audio=[audio], return_tensors="np", padding=True) text_ids = inputs["input_ids"] # (1, L) attention_mask = inputs["attention_mask"] # (1, L) input_features = inputs["input_features"] # (1, 128, T_mel) # 3) Audio encoder T_mel = input_features.shape[-1] if T_mel > 3000: input_features = input_features[..., :3000] T_mel = 3000 mel_padded = np.pad(input_features.astype(np.float32), ((0, 0), (0, 0), (0, 3000 - T_mel))) enc = _ort(HERE / "onnx" / "audio_encoder_int4.onnx", providers) audio_embeds = enc.run(["audio_embeds"], {"mel": mel_padded})[0] # (1, 390, 2048) real_chunks = (T_mel + 99) // 100 last_chunk_mel = T_mel - (real_chunks - 1) * 100 real_audio_frames = (real_chunks - 1) * 13 + (last_chunk_mel + 7) // 8 audio_embeds = audio_embeds[:, :real_audio_frames] # (1, N, 2048) # 4) Embed text tokens + scatter audio into placeholder positions # The embedding table is published as embed_tokens.npy in the repo # (extracted from the original Qwen3-ASR weights for portability). embed_path = HERE / "onnx" / "embed_tokens.npy" if embed_path.exists(): embed_w = np.load(embed_path).astype(np.float16) else: raise FileNotFoundError( "embed_tokens.npy missing — re-extract from the source PT model." ) inputs_embeds = embed_w[text_ids[0]][None] # (1, L, hidden) audio_token_id = tokenizer.convert_tokens_to_ids("<|audio_pad|>") placeholder_mask = (text_ids == audio_token_id) # (1, L) # Replace placeholder embeddings with audio_embeds[0] in order placeholder_idx = np.where(placeholder_mask[0])[0] n_replace = min(len(placeholder_idx), audio_embeds.shape[1]) inputs_embeds[0, placeholder_idx[:n_replace]] = audio_embeds[0, :n_replace] # 5) Position ids (1D) pos_ids = np.arange(text_ids.shape[1])[None].astype(np.int64) # 6) Prefill ONNX prefill = _ort(HERE / "onnx" / "decoder_prefill_int4.onnx", providers) feeds = { "inputs_embeds": inputs_embeds.astype(np.float16), "attention_mask": attention_mask.astype(np.int64), "position_ids": pos_ids, } out = prefill.run(None, feeds) logits, *past_kvs = out eos = tokenizer.eos_token_id # 7) Greedy decode with step ONNX step = _ort(HERE / "onnx" / "decoder_step_int4.onnx", providers) num_layers = 28 cur_len = inputs_embeds.shape[1] gen = [] nid = int(np.argmax(logits[0, -1, :])) gen.append(nid) for _ in range(args.max_new_tokens - 1): if nid == eos: break new_embed = embed_w[nid][None, None].astype(np.float16) new_attn = np.ones((1, cur_len + 1), dtype=np.int64) new_pos = np.array([[cur_len]], dtype=np.int64) f = {"inputs_embeds": new_embed, "attention_mask": new_attn, "position_ids": new_pos} for i in range(num_layers): f[f"past.{i}.key"] = past_kvs[2 * i] f[f"past.{i}.value"] = past_kvs[2 * i + 1] out = step.run(None, f) logits, *past_kvs = out nid = int(np.argmax(logits[0, -1, :])) gen.append(nid) cur_len += 1 text = tokenizer.decode(gen, skip_special_tokens=True, clean_up_tokenization_spaces=False) print(f"TRANSCRIPTION: {text}") return 0 if __name__ == "__main__": sys.exit(main())