Automatic Speech Recognition
MLX
ONNX
Safetensors
asr
speech-recognition
robust-asr
quantized
int4
4bit
mixed-precision
dwq
on-device
apple-silicon
qwen3
qwen3-asr
mega-asr
Instructions to use Reza2kn/mega-asr-mlx with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use Reza2kn/mega-asr-mlx with MLX:
# Download the model from the Hub
pip install huggingface_hub[hf_xet]
huggingface-cli download --local-dir mega-asr-mlx Reza2kn/mega-asr-mlx
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- LM Studio
| """MLX ASR bench: ONNX INT8 audio encoder + MLX 4-bit Qwen3 LLM. | |
| For each VITW example: | |
| 1. Build mel features via WhisperFeatureExtractor (from Qwen3-ASR preprocessor config) | |
| 2. Run mel -> ONNX audio encoder -> audio embeddings | |
| 3. Build prompt + audio placeholder token sequence | |
| 4. Embed text tokens via the model's embed_tokens (MLX), scatter audio embeddings | |
| 5. MLX greedy decode using input_embeddings | |
| 6. Compare hypothesis to reference -> WER + agreement % | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import re | |
| import sys | |
| import warnings | |
| from pathlib import Path | |
| import numpy as np | |
| import mlx.core as mx | |
| warnings.filterwarnings("ignore") | |
# Reference transcripts for the bundled benchmark clips, keyed by example name.
# main() looks for `<examples-dir>/<name>.wav` for every key and scores the
# decoded hypothesis against the corresponding sentence.
# NOTE(review): the keys look like the acoustic condition applied to each clip
# (noise, echo, dropout, ...) — confirm against the examples directory.
REFERENCES: dict[str, str] = {
    "noise": "I usually take the quieter road home because the main street gets crowded after work.",
    "far_field": "Please remind me to print the forms before we leave for the appointment tomorrow.",
    "obstructed": "I forgot my charger at home, so I need to find an outlet before the meeting starts.",
    "distortion": "The new coffee machine is simple, but everyone keeps forgetting where the filters are stored.",
    "recording": "Can you check whether the train still stops at the downtown station after eight tonight?",
    "echo": "I need to return these shoes because the size feels fine standing up but terrible while walking.",
    "dropout": "My aunt is learning video calls, and she gets excited whenever the picture actually works.",
    "mixed": "My sister is bringing dinner over later, so we do not need to cook tonight.",
}
# Anything that is not a lowercase ASCII letter, a digit or whitespace.
_NORM_RE = re.compile(r"[^a-z0-9\s]")


def normalize(text: str) -> str:
    """Canonicalise a transcript so two transcripts can be compared word by word.

    If the text contains the ``<asr_text>`` marker, everything up to and
    including its first occurrence is discarded. The remainder is lowercased,
    every character outside ``[a-z0-9]`` and whitespace becomes a space, and
    whitespace runs are collapsed to single spaces with the ends trimmed.
    """
    _, marker, tail = text.partition("<asr_text>")
    if marker:
        text = tail
    cleaned = _NORM_RE.sub(" ", text.lower())
    # Splitting on whitespace and re-joining collapses runs and trims the ends.
    return " ".join(cleaned.split())
def wer(ref: str, hyp: str) -> tuple[float, int, int]:
    """Compute the word error rate of ``hyp`` against ``ref``.

    Both strings are split on whitespace and compared with a word-level
    Levenshtein distance (substitution, insertion and deletion each cost 1).

    Args:
        ref: Reference transcript.
        hyp: Hypothesis transcript.

    Returns:
        A ``(wer, edits, ref_words)`` tuple of plain Python numbers:
        ``edits / ref_words``, the edit distance, and the number of reference
        words. For an empty reference the rate is ``1.0`` if the hypothesis
        has any words and ``0.0`` otherwise; ``edits`` is then the hypothesis
        word count and ``ref_words`` is ``0``.
    """
    ref_words = ref.split()
    hyp_words = hyp.split()
    if not ref_words:
        return (1.0 if hyp_words else 0.0, len(hyp_words), 0)

    # Only the previous DP row is needed to fill the current one, so keep two
    # rows of Python ints instead of a full matrix. This also makes every
    # return path yield native int/float rather than NumPy scalars.
    prev_row = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        curr_row = [i]
        for j, hyp_word in enumerate(hyp_words, start=1):
            substitution = prev_row[j - 1] + (0 if ref_word == hyp_word else 1)
            deletion = prev_row[j] + 1
            insertion = curr_row[j - 1] + 1
            curr_row.append(min(deletion, insertion, substitution))
        prev_row = curr_row

    edits = prev_row[-1]
    return edits / len(ref_words), edits, len(ref_words)
def color(pct, s):
    """Wrap ``s`` in an ANSI colour escape chosen from the percentage ``pct``.

    Bands are checked from the highest lower bound down: ``>= 70`` uses code
    92 (green), ``>= 50`` code 93 (yellow/orange), ``>= 25`` code 33, and
    anything lower code 91. The string is always followed by the reset
    sequence.
    """
    bands = ((70, "92"), (50, "93"), (25, "33"))
    code = next((sgr for floor, sgr in bands if pct >= floor), "91")
    return f"\033[{code}m{s}\033[0m"
def build_prompt_ids(tokenizer, audio_pad_count: int, audio_pad_id: int):
    """Tokenise the fixed ASR prompt and widen its audio placeholder.

    The prompt is a hand-written chat template (empty system turn, a user turn
    holding one ``<|audio_pad|>`` between ``<|audio_start|>`` and
    ``<|audio_end|>``, then an assistant turn pre-filled with
    ``language English<asr_text>``). After encoding, the first occurrence of
    ``audio_pad_id`` is replaced by ``audio_pad_count`` copies of it.

    Args:
        tokenizer: Object exposing ``encode(text, add_special_tokens=...)``
            that returns a list of token ids.
        audio_pad_count: Number of placeholder tokens wanted in the output.
        audio_pad_id: Token id of the audio placeholder.

    Returns:
        ``np.int64`` array of shape ``(1, L)`` with the expanded ids.

    Raises:
        ValueError: If ``audio_pad_id`` is absent from the encoded prompt.
    """
    system_turn = "<|im_start|>system\n<|im_end|>\n"
    user_turn = "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|><|im_end|>\n"
    assistant_prefix = "<|im_start|>assistant\n" "language English<asr_text>"

    token_ids = tokenizer.encode(
        system_turn + user_turn + assistant_prefix,
        add_special_tokens=False,
    )

    # The template holds exactly one placeholder; swap it for the requested run.
    pad_at = token_ids.index(audio_pad_id)
    before = token_ids[:pad_at]
    after = token_ids[pad_at + 1:]
    placeholders = [audio_pad_id] * audio_pad_count
    return np.array([before + placeholders + after], dtype=np.int64)
def main():
    """Benchmark the ONNX audio encoder + MLX LLM pipeline on the example clips.

    For every key in ``REFERENCES`` whose ``<examples-dir>/<name>.wav`` exists:
    load the audio (down-mixed to mono, resampled to 16 kHz if needed), compute
    mel features, run the ONNX encoder, splice the audio embeddings into the
    embedded prompt at the placeholder positions, greedy-decode with the MLX
    model, and score the normalised hypothesis against the normalised
    reference. Per-example and aggregate results are printed to stdout.

    Command-line options: ``--mlx-llm-path``, ``--encoder-path``,
    ``--examples-dir`` and ``--max-new-tokens``.

    Returns:
        ``0`` always; the value is used as the process exit status.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--mlx-llm-path", default="models/mlx/mega-asr-llm-4bit", type=Path)
    ap.add_argument("--encoder-path", default="models/mega-asr-onnx-hf/onnx/audio_encoder_int8.onnx", type=Path)
    ap.add_argument("--examples-dir", default="models/mega-asr-onnx-hf/examples", type=Path)
    ap.add_argument("--max-new-tokens", type=int, default=80)
    args = ap.parse_args()

    import soundfile as sf
    import onnxruntime as ort
    from mlx_lm.models.cache import make_prompt_cache
    from mlx_lm.utils import load
    from transformers import AutoFeatureExtractor, AutoTokenizer

    print(f"Loading MLX LLM from {args.mlx_llm_path} ...")
    # The tokenizer bundled with the MLX checkpoint is not used; see below.
    model, _ = load(str(args.mlx_llm_path))
    # Use the ORIGINAL Qwen3-ASR tokenizer (has audio special tokens)
    hf_tokenizer = AutoTokenizer.from_pretrained("models/mega-asr-hf/Qwen3-ASR-1.7B")

    # Audio-related token ids per the Qwen3-ASR config
    audio_pad_id = 151676  # <|audio_pad|> — the placeholder we scatter audio_embeds into
    eos_id = 151645  # <|im_end|>
    print(f"audio_pad_id={audio_pad_id} eos_id={eos_id}")

    print(f"Loading ONNX encoder from {args.encoder_path} ...")
    enc_sess = ort.InferenceSession(str(args.encoder_path), providers=["CPUExecutionProvider"])
    # Use Qwen3-ASR's WhisperFeatureExtractor — load from the original HF path
    feat_ext = AutoFeatureExtractor.from_pretrained("models/mega-asr-hf/Qwen3-ASR-1.7B")

    # The encoder is always fed exactly this many mel frames (truncate or zero-pad).
    max_mel_frames = 3000

    # Bench all 8 examples
    total_wer = 0.0
    total_edits = 0
    total_words = 0
    n = 0
    for name in sorted(REFERENCES):
        wav_path = args.examples_dir / f"{name}.wav"
        if not wav_path.exists():
            print(f" skip {name} (missing wav)")
            continue

        audio, sr = sf.read(str(wav_path))
        if audio.ndim > 1:
            # Down-mix multi-channel audio by averaging the channels.
            audio = audio.mean(axis=1)
        if sr != 16000:
            # librosa is only needed for clips that are not already 16 kHz.
            import librosa
            audio = librosa.resample(audio.astype(np.float32), orig_sr=sr, target_sr=16000)

        # Mel features
        feats = feat_ext(audio, sampling_rate=16000, return_tensors="np", return_attention_mask=False)
        mel = feats["input_features"]  # (1, 128, T_mel)
        T_mel = mel.shape[-1]
        if T_mel > max_mel_frames:
            mel = mel[..., :max_mel_frames]
            T_mel = max_mel_frames
        mel_padded = np.pad(
            mel, ((0, 0), (0, 0), (0, max_mel_frames - T_mel)), constant_values=0
        ).astype(np.float32)

        # Encoder
        audio_embeds = enc_sess.run(["audio_embeds"], {"mel": mel_padded})[0]
        # Compute actual audio frames (chunked CNN — see audio_encoder_wrapper.py):
        # each full 100-frame mel chunk counts as 13 output frames and the last,
        # possibly partial, chunk as ceil(len / 8). Frames past that come from
        # the zero padding and are dropped.
        real_chunks = (T_mel + 99) // 100
        last_chunk_mel = T_mel - (real_chunks - 1) * 100
        real_frames = (real_chunks - 1) * 13 + (last_chunk_mel + 7) // 8
        audio_embeds = audio_embeds[:, :real_frames]  # (1, F, 2048)

        # Build prompt with placeholder audio_pad tokens of length = real_frames
        prompt_ids = build_prompt_ids(hf_tokenizer, real_frames, audio_pad_id)

        # Embed text via MLX model's embed_tokens
        ids_mx = mx.array(prompt_ids)
        text_emb = model.model.embed_tokens(ids_mx)  # (1, L, 2048)

        # Scatter audio embeddings into positions where input_ids == audio_pad_id
        mask_idx_np = np.where(prompt_ids[0] == audio_pad_id)[0]
        mask_idx = mx.array(mask_idx_np)
        audio_emb_mx = mx.array(audio_embeds[0]).astype(text_emb.dtype)  # (real_frames, 2048)
        combined_mx = text_emb
        combined_mx[0, mask_idx] = audio_emb_mx

        # MLX greedy decode using input_embeddings. The token ids passed as
        # `inputs` for the prefill are zeros of the right length; the real
        # content is supplied through input_embeddings.
        cache = make_prompt_cache(model)
        logits = model(
            inputs=mx.zeros((1, combined_mx.shape[1]), dtype=mx.int64),
            cache=cache,
            input_embeddings=combined_mx,
        )
        next_tok = int(mx.argmax(logits[0, -1, :]).item())
        out_ids = [next_tok]
        for _ in range(args.max_new_tokens - 1):
            if next_tok == eos_id:
                break
            logits = model(inputs=mx.array([[next_tok]]), cache=cache)
            next_tok = int(mx.argmax(logits[0, -1, :]).item())
            out_ids.append(next_tok)

        # Strip trailing eos if any
        if out_ids and out_ids[-1] == eos_id:
            out_ids = out_ids[:-1]
        hyp_text = hf_tokenizer.decode(out_ids, skip_special_tokens=True)

        ref = normalize(REFERENCES[name])
        hyp = normalize(hyp_text)
        w, ed, words = wer(ref, hyp)
        # Agreement is the complement of WER, floored at 0 because WER can exceed 1.
        agree = max(0.0, 1.0 - w) * 100
        total_wer += w
        total_edits += ed
        total_words += words
        n += 1
        print(f"\n[{color(agree, name.ljust(10))}] WER={w*100:5.1f}% agree={color(agree, f'{agree:5.1f}%')}")
        print(f" REF: {ref}")
        print(f" HYP: {hyp}")

    # Both aggregates fall back to 0 when nothing was scored (e.g. every wav is
    # missing), so the summary line never divides by zero.
    avg_agree = (1 - total_wer / n) * 100 if n else 0
    corpus_wer = total_edits / total_words * 100 if total_words else 0.0
    summary = f"=== AVERAGE: agreement {avg_agree:.1f}% WER {corpus_wer:.1f}% ({total_edits}/{total_words}) ==="
    print(f"\n{color(avg_agree, summary)}")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())