"""MLX ASR bench: ONNX INT8 audio encoder + MLX 4-bit Qwen3 LLM. For each VITW example: 1. Build mel features via WhisperFeatureExtractor (from Qwen3-ASR preprocessor config) 2. Run mel -> ONNX audio encoder -> audio embeddings 3. Build prompt + audio placeholder token sequence 4. Embed text tokens via the model's embed_tokens (MLX), scatter audio embeddings 5. MLX greedy decode using input_embeddings 6. Compare hypothesis to reference -> WER + agreement % """ from __future__ import annotations import argparse import json import re import sys import warnings from pathlib import Path import numpy as np import mlx.core as mx warnings.filterwarnings("ignore") REFERENCES = { "noise": "I usually take the quieter road home because the main street gets crowded after work.", "far_field": "Please remind me to print the forms before we leave for the appointment tomorrow.", "obstructed": "I forgot my charger at home, so I need to find an outlet before the meeting starts.", "distortion": "The new coffee machine is simple, but everyone keeps forgetting where the filters are stored.", "recording": "Can you check whether the train still stops at the downtown station after eight tonight?", "echo": "I need to return these shoes because the size feels fine standing up but terrible while walking.", "dropout": "My aunt is learning video calls, and she gets excited whenever the picture actually works.", "mixed": "My sister is bringing dinner over later, so we do not need to cook tonight.", } _NORM_RE = re.compile(r"[^a-z0-9\s]") def normalize(text: str) -> str: if "" in text: text = text.split("", 1)[1] text = text.lower() text = _NORM_RE.sub(" ", text) text = re.sub(r"\s+", " ", text).strip() return text def wer(ref: str, hyp: str): r = ref.split(); h = hyp.split() if not r: return (1.0 if h else 0.0, len(h), 0) d = np.zeros((len(r) + 1, len(h) + 1), dtype=np.int32) for i in range(len(r) + 1): d[i, 0] = i for j in range(len(h) + 1): d[0, j] = j for i in range(1, len(r) + 1): for j in range(1, len(h) + 1): cost = 0 if r[i-1] == h[j-1] else 1 d[i, j] = min(d[i-1, j] + 1, d[i, j-1] + 1, d[i-1, j-1] + cost) return d[len(r), len(h)] / max(len(r), 1), d[len(r), len(h)], len(r) def color(pct, s): if pct >= 70: return f"\033[92m{s}\033[0m" # green if pct >= 50: return f"\033[93m{s}\033[0m" # yellow/orange if pct >= 25: return f"\033[33m{s}\033[0m" return f"\033[91m{s}\033[0m" def build_prompt_ids(tokenizer, audio_pad_count: int, audio_pad_id: int): """Build the Qwen3-ASR chat-template prompt + force_language='English'. Inserts `audio_pad_count` copies of audio_pad_id between audio_start and audio_end. Returns input_ids as np.int64 (1, L). """ # Apply chat template manually since the original tokenizer was trained with the template prompt = ( "<|im_start|>system\n<|im_end|>\n" "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|><|im_end|>\n" "<|im_start|>assistant\n" "language English" ) ids = tokenizer.encode(prompt, add_special_tokens=False) # Find the single <|audio_pad|> position and expand it to audio_pad_count copies pos = ids.index(audio_pad_id) expanded = ids[:pos] + [audio_pad_id] * audio_pad_count + ids[pos + 1:] return np.array([expanded], dtype=np.int64) def main(): ap = argparse.ArgumentParser() ap.add_argument("--mlx-llm-path", default="models/mlx/mega-asr-llm-4bit", type=Path) ap.add_argument("--encoder-path", default="models/mega-asr-onnx-hf/onnx/audio_encoder_int8.onnx", type=Path) ap.add_argument("--examples-dir", default="models/mega-asr-onnx-hf/examples", type=Path) ap.add_argument("--max-new-tokens", type=int, default=80) args = ap.parse_args() import soundfile as sf import onnxruntime as ort from mlx_lm.utils import load from transformers import AutoFeatureExtractor, AutoTokenizer print(f"Loading MLX LLM from {args.mlx_llm_path} ...") model, tokenizer = load(str(args.mlx_llm_path)) # Use the ORIGINAL Qwen3-ASR tokenizer (has audio special tokens) hf_tokenizer = AutoTokenizer.from_pretrained("models/mega-asr-hf/Qwen3-ASR-1.7B") # Audio-related token ids per the Qwen3-ASR config audio_pad_id = 151676 # <|audio_pad|> — the placeholder we scatter audio_embeds into eos_id = 151645 # <|im_end|> print(f"audio_pad_id={audio_pad_id} eos_id={eos_id}") print(f"Loading ONNX encoder from {args.encoder_path} ...") enc_sess = ort.InferenceSession(str(args.encoder_path), providers=["CPUExecutionProvider"]) # Use Qwen3-ASR's WhisperFeatureExtractor — load from the original HF path feat_ext = AutoFeatureExtractor.from_pretrained("models/mega-asr-hf/Qwen3-ASR-1.7B") # Bench all 8 examples total_wer = 0.0 total_edits = 0 total_words = 0 n = 0 results = [] for name in sorted(REFERENCES): wav_path = args.examples_dir / f"{name}.wav" if not wav_path.exists(): print(f" skip {name} (missing wav)") continue audio, sr = sf.read(str(wav_path)) if audio.ndim > 1: audio = audio.mean(axis=1) if sr != 16000: import librosa audio = librosa.resample(audio.astype(np.float32), orig_sr=sr, target_sr=16000) # Mel features feats = feat_ext(audio, sampling_rate=16000, return_tensors="np", return_attention_mask=False) mel = feats["input_features"] # (1, 128, T_mel) T_mel = mel.shape[-1] if T_mel > 3000: mel = mel[..., :3000]; T_mel = 3000 mel_padded = np.pad(mel, ((0, 0), (0, 0), (0, 3000 - T_mel)), constant_values=0).astype(np.float32) # Encoder audio_embeds = enc_sess.run(["audio_embeds"], {"mel": mel_padded})[0] # Compute actual audio frames (chunked CNN — see audio_encoder_wrapper.py) real_chunks = (T_mel + 99) // 100 last_chunk_mel = T_mel - (real_chunks - 1) * 100 real_frames = (real_chunks - 1) * 13 + (last_chunk_mel + 7) // 8 audio_embeds = audio_embeds[:, :real_frames] # (1, F, 2048) # Build prompt with placeholder audio_pad tokens of length = real_frames prompt_ids = build_prompt_ids(hf_tokenizer, real_frames, audio_pad_id) # Embed text via MLX model's embed_tokens ids_mx = mx.array(prompt_ids) text_emb = model.model.embed_tokens(ids_mx) # (1, L, 2048) # Scatter audio embeddings into positions where input_ids == audio_pad_id mask_np = (prompt_ids[0] == audio_pad_id).astype(np.int32) mask_idx_np = np.where(mask_np)[0] mask_idx = mx.array(mask_idx_np) audio_emb_mx = mx.array(audio_embeds[0]).astype(text_emb.dtype) # (real_frames, 2048) combined_mx = text_emb combined_mx[0, mask_idx] = audio_emb_mx # MLX greedy decode using input_embeddings from mlx_lm.models.cache import make_prompt_cache cache = make_prompt_cache(model) logits = model(inputs=mx.zeros((1, combined_mx.shape[1]), dtype=mx.int64), cache=cache, input_embeddings=combined_mx) next_tok = int(mx.argmax(logits[0, -1, :]).item()) out_ids = [next_tok] for _ in range(args.max_new_tokens - 1): if next_tok == eos_id: break logits = model(inputs=mx.array([[next_tok]]), cache=cache) next_tok = int(mx.argmax(logits[0, -1, :]).item()) out_ids.append(next_tok) # Strip trailing eos if any if out_ids and out_ids[-1] == eos_id: out_ids = out_ids[:-1] hyp_text = hf_tokenizer.decode(out_ids, skip_special_tokens=True) ref = normalize(REFERENCES[name]) hyp = normalize(hyp_text) w, ed, words = wer(ref, hyp) agree = max(0.0, 1.0 - w) * 100 total_wer += w total_edits += ed total_words += words n += 1 print(f"\n[{color(agree, name.ljust(10))}] WER={w*100:5.1f}% agree={color(agree, f'{agree:5.1f}%')}") print(f" REF: {ref}") print(f" HYP: {hyp}") results.append({"name": name, "wer": w, "agree": agree, "hyp": hyp}) avg_agree = (1 - total_wer / n) * 100 if n else 0 print(f"\n{color(avg_agree, f'=== AVERAGE: agreement {avg_agree:.1f}% WER {total_edits/total_words*100:.1f}% ({total_edits}/{total_words}) ===')}") return 0 if __name__ == "__main__": sys.exit(main())