Automatic Speech Recognition
MLX
ONNX
Safetensors
asr
speech-recognition
robust-asr
quantized
int4
4bit
mixed-precision
dwq
on-device
apple-silicon
qwen3
qwen3-asr
mega-asr
Instructions to use Reza2kn/mega-asr-mlx with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- MLX
How to use Reza2kn/mega-asr-mlx with MLX:
# Download the model from the Hub pip install huggingface_hub[hf_xet] huggingface-cli download --local-dir mega-asr-mlx Reza2kn/mega-asr-mlx
- Notebooks
- Google Colab
- Kaggle
- Local Apps
- LM Studio
File size: 8,680 Bytes
"""MLX ASR bench: ONNX INT8 audio encoder + MLX 4-bit Qwen3 LLM.
For each VITW example:
1. Build mel features via WhisperFeatureExtractor (from Qwen3-ASR preprocessor config)
2. Run mel -> ONNX audio encoder -> audio embeddings
3. Build prompt + audio placeholder token sequence
4. Embed text tokens via the model's embed_tokens (MLX), scatter audio embeddings
5. MLX greedy decode using input_embeddings
6. Compare hypothesis to reference -> WER + agreement %
"""
from __future__ import annotations
import argparse
import json
import re
import sys
import warnings
from pathlib import Path
import numpy as np
import mlx.core as mx
warnings.filterwarnings("ignore")
# Reference transcripts for the benchmark clips.
# Each key is the stem of a wav file looked up as "<key>.wav" in the examples
# directory; each value is the ground-truth sentence the hypothesis is scored
# against (after normalisation).
REFERENCES = {
"noise": "I usually take the quieter road home because the main street gets crowded after work.",
"far_field": "Please remind me to print the forms before we leave for the appointment tomorrow.",
"obstructed": "I forgot my charger at home, so I need to find an outlet before the meeting starts.",
"distortion": "The new coffee machine is simple, but everyone keeps forgetting where the filters are stored.",
"recording": "Can you check whether the train still stops at the downtown station after eight tonight?",
"echo": "I need to return these shoes because the size feels fine standing up but terrible while walking.",
"dropout": "My aunt is learning video calls, and she gets excited whenever the picture actually works.",
"mixed": "My sister is bringing dinner over later, so we do not need to cook tonight.",
}
# Matches every character that is not a lowercase ASCII letter, digit or whitespace.
_NORM_RE = re.compile(r"[^a-z0-9\s]")


def normalize(text: str) -> str:
    """Canonicalise a transcript so reference and hypothesis compare fairly.

    If the text contains an ``<asr_text>`` marker, only what follows the first
    occurrence is kept. The result is lower-cased, every character other than
    ``a-z``, ``0-9`` or whitespace becomes a space, and whitespace runs are
    collapsed to single spaces with the ends trimmed.
    """
    _, marker, tail = text.partition("<asr_text>")
    if marker:
        text = tail
    cleaned = _NORM_RE.sub(" ", text.lower())
    return re.sub(r"\s+", " ", cleaned).strip()
def wer(ref: str, hyp: str) -> tuple[float, int, int]:
    """Compute the word error rate between a reference and a hypothesis.

    Both strings are split on whitespace and compared with a word-level
    Levenshtein distance (insertions, deletions and substitutions all cost 1).

    Args:
        ref: Reference transcript.
        hyp: Hypothesis transcript.

    Returns:
        A ``(wer, edits, ref_word_count)`` tuple of native Python
        ``(float, int, int)``. For an empty reference the rate is 1.0 when the
        hypothesis is non-empty and 0.0 otherwise, ``edits`` is the hypothesis
        length and the word count is 0.
    """
    ref_words = ref.split()
    hyp_words = hyp.split()
    if not ref_words:
        return (1.0 if hyp_words else 0.0, len(hyp_words), 0)

    # Two-row dynamic programme over plain ints: prev[j] is the distance
    # between the reference prefix processed so far and hyp_words[:j].
    # Plain ints keep NumPy scalar types out of the return value.
    prev = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp_words, start=1):
            substitution = prev[j - 1] + (0 if ref_word == hyp_word else 1)
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, substitution))
        prev = curr

    edits = prev[-1]
    return edits / len(ref_words), edits, len(ref_words)
def color(pct, s):
    """Wrap *s* in an ANSI colour escape chosen from the score *pct*.

    Scores of 70 and above use code 92 (bright green), 50 and above use 93
    (bright yellow), 25 and above use 33 (yellow), and anything else uses 91
    (bright red). The colour is reset after the text.
    """
    # (lower bound, SGR colour code), checked from the highest bound down.
    bands = ((70, "92"), (50, "93"), (25, "33"))
    code = next((sgr for floor, sgr in bands if pct >= floor), "91")
    return f"\033[{code}m{s}\033[0m"
def build_prompt_ids(tokenizer, audio_pad_count: int, audio_pad_id: int):
    """Tokenise the hand-written Qwen3-ASR chat prompt and widen its audio slot.

    The prompt is written out manually (system turn, user turn holding one
    audio placeholder, assistant turn pre-filled with ``language English``
    and the ``<asr_text>`` marker) and encoded without extra special tokens.
    The first occurrence of ``audio_pad_id`` in the encoded ids is replaced
    by ``audio_pad_count`` copies of it.

    Returns:
        ``np.int64`` array of shape ``(1, L)`` holding the expanded ids.

    Raises:
        ValueError: if ``audio_pad_id`` does not occur in the encoded prompt.
    """
    template = "".join(
        [
            "<|im_start|>system\n<|im_end|>\n",
            "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|><|im_end|>\n",
            "<|im_start|>assistant\n",
            "language English<asr_text>",
        ]
    )
    token_ids = tokenizer.encode(template, add_special_tokens=False)

    # Split around the single placeholder, then splice in the repeated pad ids.
    pad_at = token_ids.index(audio_pad_id)
    head = token_ids[:pad_at]
    tail = token_ids[pad_at + 1:]
    return np.array([head + [audio_pad_id] * audio_pad_count + tail], dtype=np.int64)
def main():
    """Benchmark the ONNX audio encoder + MLX LLM pipeline on the example clips.

    For every key of ``REFERENCES`` whose ``<key>.wav`` exists under
    ``--examples-dir`` the clip is converted to mel features, encoded with the
    ONNX audio encoder, spliced into the prompt embeddings at the audio
    placeholder positions, greedily decoded with the MLX model, and scored
    against the reference with ``wer``. Per-clip and aggregate results are
    printed to stdout.

    Returns:
        0 when at least one clip was scored, 1 when no clip was found.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--mlx-llm-path", default="models/mlx/mega-asr-llm-4bit", type=Path)
    ap.add_argument("--encoder-path", default="models/mega-asr-onnx-hf/onnx/audio_encoder_int8.onnx", type=Path)
    ap.add_argument("--examples-dir", default="models/mega-asr-onnx-hf/examples", type=Path)
    # Location of the original HF checkpoint that provides the tokenizer and
    # feature-extractor configs; the default is the previously hard-coded path.
    ap.add_argument("--hf-model-path", default="models/mega-asr-hf/Qwen3-ASR-1.7B")
    ap.add_argument("--max-new-tokens", type=int, default=80)
    args = ap.parse_args()

    # Heavy third-party imports are deferred until after argument parsing.
    import soundfile as sf
    import onnxruntime as ort
    from mlx_lm.models.cache import make_prompt_cache
    from mlx_lm.utils import load
    from transformers import AutoFeatureExtractor, AutoTokenizer

    print(f"Loading MLX LLM from {args.mlx_llm_path} ...")
    # The tokenizer returned by load() is not used; see hf_tokenizer below.
    model, _ = load(str(args.mlx_llm_path))
    # Use the ORIGINAL Qwen3-ASR tokenizer (has audio special tokens)
    hf_tokenizer = AutoTokenizer.from_pretrained(args.hf_model_path)

    # Audio-related token ids per the Qwen3-ASR config
    audio_pad_id = 151676  # <|audio_pad|> — the placeholder we scatter audio_embeds into
    eos_id = 151645  # <|im_end|>
    print(f"audio_pad_id={audio_pad_id} eos_id={eos_id}")

    print(f"Loading ONNX encoder from {args.encoder_path} ...")
    enc_sess = ort.InferenceSession(str(args.encoder_path), providers=["CPUExecutionProvider"])
    # Use Qwen3-ASR's WhisperFeatureExtractor — load from the original HF path
    feat_ext = AutoFeatureExtractor.from_pretrained(args.hf_model_path)

    max_mel_frames = 3000  # the encoder input is always truncated/padded to this length

    # Running totals across all scored examples.
    total_wer = 0.0
    total_edits = 0
    total_words = 0
    n = 0

    for name in sorted(REFERENCES):
        wav_path = args.examples_dir / f"{name}.wav"
        if not wav_path.exists():
            print(f" skip {name} (missing wav)")
            continue

        audio, sr = sf.read(str(wav_path))
        if audio.ndim > 1:
            # Down-mix multi-channel audio by averaging the channels.
            audio = audio.mean(axis=1)
        if sr != 16000:
            import librosa  # only needed when a clip is not already 16 kHz
            audio = librosa.resample(audio.astype(np.float32), orig_sr=sr, target_sr=16000)

        # Mel features, truncated then zero-padded on the time axis.
        feats = feat_ext(audio, sampling_rate=16000, return_tensors="np", return_attention_mask=False)
        mel = feats["input_features"][..., :max_mel_frames]
        T_mel = mel.shape[-1]
        mel_padded = np.pad(
            mel, ((0, 0), (0, 0), (0, max_mel_frames - T_mel)), constant_values=0
        ).astype(np.float32)

        # Encoder
        audio_embeds = enc_sess.run(["audio_embeds"], {"mel": mel_padded})[0]

        # Compute actual audio frames (chunked CNN — see audio_encoder_wrapper.py):
        # each full 100-frame mel chunk yields 13 frames, the last chunk ceil(len / 8).
        # NOTE(review): T_mel is the feature extractor's output length. If the
        # extractor pads every clip to a fixed window, this counts padding
        # frames as real audio — confirm against the preprocessor config.
        real_chunks = (T_mel + 99) // 100
        last_chunk_mel = T_mel - (real_chunks - 1) * 100
        real_frames = (real_chunks - 1) * 13 + (last_chunk_mel + 7) // 8
        audio_embeds = audio_embeds[:, :real_frames]

        # Build prompt with placeholder audio_pad tokens of length = real_frames
        prompt_ids = build_prompt_ids(hf_tokenizer, real_frames, audio_pad_id)

        # Embed text via MLX model's embed_tokens
        text_emb = model.model.embed_tokens(mx.array(prompt_ids))

        # Scatter audio embeddings into positions where input_ids == audio_pad_id
        mask_idx = mx.array(np.flatnonzero(prompt_ids[0] == audio_pad_id))
        audio_emb_mx = mx.array(audio_embeds[0]).astype(text_emb.dtype)
        combined_mx = text_emb
        combined_mx[0, mask_idx] = audio_emb_mx

        # MLX greedy decode: prefill with the combined embeddings (token ids
        # are dummy zeros), then feed back one argmax token at a time.
        cache = make_prompt_cache(model)
        logits = model(
            inputs=mx.zeros((1, combined_mx.shape[1]), dtype=mx.int64),
            cache=cache,
            input_embeddings=combined_mx,
        )
        next_tok = int(mx.argmax(logits[0, -1, :]).item())
        out_ids = [next_tok]
        for _ in range(args.max_new_tokens - 1):
            if next_tok == eos_id:
                break
            logits = model(inputs=mx.array([[next_tok]]), cache=cache)
            next_tok = int(mx.argmax(logits[0, -1, :]).item())
            out_ids.append(next_tok)

        # Strip trailing eos if any
        if out_ids and out_ids[-1] == eos_id:
            out_ids = out_ids[:-1]
        hyp_text = hf_tokenizer.decode(out_ids, skip_special_tokens=True)

        ref = normalize(REFERENCES[name])
        hyp = normalize(hyp_text)
        w, ed, words = wer(ref, hyp)
        # Agreement is the complement of WER, floored at 0 because WER can exceed 1.
        agree = max(0.0, 1.0 - w) * 100
        total_wer += w
        total_edits += ed
        total_words += words
        n += 1

        print(f"\n[{color(agree, name.ljust(10))}] WER={w*100:5.1f}% agree={color(agree, f'{agree:5.1f}%')}")
        print(f" REF: {ref}")
        print(f" HYP: {hyp}")

    if not n:
        # Nothing was scored, so the aggregate statistics are undefined.
        print(f"No example wav files found in {args.examples_dir}; nothing to report.", file=sys.stderr)
        return 1

    # Agreement averages the per-example WERs; the WER figure pools edits over all words.
    avg_agree = (1 - total_wer / n) * 100
    corpus_wer = total_edits / total_words * 100 if total_words else 0.0
    summary = f"=== AVERAGE: agreement {avg_agree:.1f}% WER {corpus_wer:.1f}% ({total_edits}/{total_words}) ==="
    print(f"\n{color(avg_agree, summary)}")
    return 0
# Script entry point: run the benchmark and use its return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|