File size: 8,680 Bytes
17540fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
"""MLX ASR bench: ONNX INT8 audio encoder + MLX 4-bit Qwen3 LLM.

For each VITW example:
  1. Build mel features via WhisperFeatureExtractor (from Qwen3-ASR preprocessor config)
  2. Run mel -> ONNX audio encoder -> audio embeddings
  3. Build prompt + audio placeholder token sequence
  4. Embed text tokens via the model's embed_tokens (MLX), scatter audio embeddings
  5. MLX greedy decode using input_embeddings
  6. Compare hypothesis to reference -> WER + agreement %
"""
from __future__ import annotations
import argparse
import json
import re
import sys
import warnings
from pathlib import Path

import numpy as np
import mlx.core as mx

warnings.simplefilter("ignore")  # keep benchmark output free of library warnings


# Ground-truth transcripts for the VITW examples, keyed by example name.
# main() looks for the matching audio at ``<examples-dir>/<name>.wav``.
REFERENCES = dict(
    noise="I usually take the quieter road home because the main street gets crowded after work.",
    far_field="Please remind me to print the forms before we leave for the appointment tomorrow.",
    obstructed="I forgot my charger at home, so I need to find an outlet before the meeting starts.",
    distortion="The new coffee machine is simple, but everyone keeps forgetting where the filters are stored.",
    recording="Can you check whether the train still stops at the downtown station after eight tonight?",
    echo="I need to return these shoes because the size feels fine standing up but terrible while walking.",
    dropout="My aunt is learning video calls, and she gets excited whenever the picture actually works.",
    mixed="My sister is bringing dinner over later, so we do not need to cook tonight.",
)

# Matches every character normalize() replaces with a space: anything other
# than lowercase ASCII letters, digits and whitespace.
_NORM_RE = re.compile(r"[^a-z0-9\s]")


def normalize(text: str) -> str:
    """Canonicalize a transcript for word-level comparison.

    If ``text`` contains an ``<asr_text>`` marker, everything up to and
    including its first occurrence is dropped. The remainder is lower-cased,
    every character that is not an ASCII letter, digit or whitespace becomes a
    space, and runs of whitespace collapse to single spaces with the ends
    trimmed.
    """
    _, marker, after = text.partition("<asr_text>")
    body = after if marker else text
    cleaned = _NORM_RE.sub(" ", body.lower())
    return " ".join(cleaned.split())


def wer(ref: str, hyp: str):
    """Compute the word error rate of ``hyp`` against ``ref``.

    Both strings are tokenized on whitespace. The distance is the word-level
    Levenshtein distance: insertion, deletion and substitution each cost 1.

    Returns:
        ``(rate, edits, ref_words)`` as plain Python ``(float, int, int)``:
        ``rate`` is ``edits / ref_words``, ``edits`` is the word edit distance
        and ``ref_words`` is the number of reference words. For an empty
        reference, ``rate`` is 1.0 if the hypothesis has any words (else 0.0),
        ``edits`` is the hypothesis length and ``ref_words`` is 0.
    """
    ref_words = ref.split()
    hyp_words = hyp.split()
    if not ref_words:
        return (1.0 if hyp_words else 0.0, len(hyp_words), 0)

    # Two-row DP: ``prev[j]`` is the distance between the reference prefix
    # processed so far and the first ``j`` hypothesis words.
    prev = list(range(len(hyp_words) + 1))
    for i, r_word in enumerate(ref_words, start=1):
        cur = [i] + [0] * len(hyp_words)
        for j, h_word in enumerate(hyp_words, start=1):
            cur[j] = min(
                prev[j] + 1,                               # deletion
                cur[j - 1] + 1,                            # insertion
                prev[j - 1] + (0 if r_word == h_word else 1),  # match / substitution
            )
        prev = cur

    edits = prev[-1]
    return edits / len(ref_words), edits, len(ref_words)


def color(pct, s):
    """Wrap ``s`` in an ANSI color escape chosen by the percentage ``pct``.

    Bright green at >= 70, bright yellow at >= 50, yellow at >= 25, and
    bright red otherwise (including values that compare false against every
    threshold).
    """
    for floor, code in ((70, "92"), (50, "93"), (25, "33")):
        if pct >= floor:
            return f"\033[{code}m{s}\033[0m"
    return f"\033[91m{s}\033[0m"


def build_prompt_ids(tokenizer, audio_pad_count: int, audio_pad_id: int):
    """Build the Qwen3-ASR chat prompt with the language forced to English.

    The chat template is written out by hand and encoded without extra special
    tokens. The single ``<|audio_pad|>`` token it contains is then expanded to
    ``audio_pad_count`` copies, one placeholder per audio-encoder frame.

    Args:
        tokenizer: object whose ``encode(text, add_special_tokens=False)``
            returns a sequence of token ids; it must encode ``<|audio_pad|>``
            as the single id ``audio_pad_id``.
        audio_pad_count: number of placeholder copies to insert (>= 0).
        audio_pad_id: token id of ``<|audio_pad|>``.

    Returns:
        ``np.ndarray`` of dtype int64 and shape ``(1, L)``.

    Raises:
        ValueError: if ``audio_pad_count`` is negative, or if the encoded
            prompt does not contain ``audio_pad_id``.
    """
    if audio_pad_count < 0:
        raise ValueError(f"audio_pad_count must be >= 0, got {audio_pad_count}")

    # Template spelled out manually; the assistant turn is pre-filled with the
    # forced language so decoding starts directly at the transcript.
    prompt = (
        "<|im_start|>system\n<|im_end|>\n"
        "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|><|im_end|>\n"
        "<|im_start|>assistant\n"
        "language English<asr_text>"
    )
    ids = list(tokenizer.encode(prompt, add_special_tokens=False))
    try:
        pos = ids.index(audio_pad_id)
    except ValueError:
        raise ValueError(
            f"audio_pad_id={audio_pad_id} not found in the encoded prompt; the "
            "tokenizer must map '<|audio_pad|>' to a single special token"
        ) from None
    # Replace the single placeholder with audio_pad_count copies of it.
    expanded = ids[:pos] + [audio_pad_id] * audio_pad_count + ids[pos + 1:]
    return np.array([expanded], dtype=np.int64)


def _load_mono_16k(wav_path):
    """Read ``wav_path``, average multi-channel audio to mono, resample to 16 kHz."""
    import soundfile as sf

    audio, sr = sf.read(str(wav_path))
    if audio.ndim > 1:
        audio = audio.mean(axis=1)
    if sr != 16000:
        import librosa  # only needed for inputs that are not already 16 kHz
        audio = librosa.resample(audio.astype(np.float32), orig_sr=sr, target_sr=16000)
    return audio


def _encoder_output_frames(t_mel):
    """Return the number of audio-encoder output frames for ``t_mel`` mel frames.

    Mirrors the chunked CNN front end (see audio_encoder_wrapper.py). Each full
    100-frame mel chunk yields 13 frames, and the trailing chunk of ``last``
    frames (1..100) yields ``ceil(last / 8)``. Returns 0 for ``t_mel == 0``.
    """
    real_chunks = (t_mel + 99) // 100
    last_chunk_mel = t_mel - (real_chunks - 1) * 100
    return (real_chunks - 1) * 13 + (last_chunk_mel + 7) // 8


def _greedy_decode(model, embeds, eos_id, max_new_tokens):
    """Greedy-decode from prompt embeddings ``embeds`` (1, L, D) with the MLX LLM.

    The prompt is fed once through ``input_embeddings``, with zero ids passed
    as the placeholder ``inputs``. After that, one token is fed at a time
    through the KV cache. Decoding stops on ``eos_id`` or after
    ``max(1, max_new_tokens)`` tokens. Returns the generated ids without a
    trailing ``eos_id``.
    """
    from mlx_lm.models.cache import make_prompt_cache

    cache = make_prompt_cache(model)
    logits = model(
        inputs=mx.zeros((1, embeds.shape[1]), dtype=mx.int64),
        cache=cache,
        input_embeddings=embeds,
    )
    next_tok = int(mx.argmax(logits[0, -1, :]).item())
    out_ids = [next_tok]
    for _ in range(max_new_tokens - 1):
        if next_tok == eos_id:
            break
        logits = model(inputs=mx.array([[next_tok]]), cache=cache)
        next_tok = int(mx.argmax(logits[0, -1, :]).item())
        out_ids.append(next_tok)
    if out_ids and out_ids[-1] == eos_id:
        out_ids = out_ids[:-1]
    return out_ids


def main():
    """Run the ONNX-encoder + MLX-LLM ASR benchmark over the VITW examples.

    For every name in REFERENCES that has a matching ``<name>.wav`` in
    ``--examples-dir``, main() does the following:

    - computes mel features;
    - runs the ONNX audio encoder;
    - splices the encoder frames into the prompt embeddings;
    - greedy-decodes with the MLX LLM;
    - scores the hypothesis against the reference.

    It prints per-example and aggregate WER / agreement and can write the
    per-example results to ``--json-out``.

    Returns:
        0 on success, 1 if no example could be scored.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--mlx-llm-path", default="models/mlx/mega-asr-llm-4bit", type=Path)
    ap.add_argument("--encoder-path", default="models/mega-asr-onnx-hf/onnx/audio_encoder_int8.onnx", type=Path)
    ap.add_argument("--examples-dir", default="models/mega-asr-onnx-hf/examples", type=Path)
    ap.add_argument("--max-new-tokens", type=int, default=80)
    # Original Qwen3-ASR checkpoint: provides the tokenizer with the audio
    # special tokens and the feature-extractor config.
    ap.add_argument("--hf-path", default="models/mega-asr-hf/Qwen3-ASR-1.7B", type=Path)
    ap.add_argument("--json-out", default=None, type=Path,
                    help="optional path to write per-example results as JSON")
    args = ap.parse_args()

    import onnxruntime as ort
    from mlx_lm.utils import load
    from transformers import AutoFeatureExtractor, AutoTokenizer

    print(f"Loading MLX LLM from {args.mlx_llm_path} ...")
    # The MLX tokenizer is unused; prompts are encoded and decoded with the
    # original Qwen3-ASR tokenizer, which has the audio special tokens.
    model, _ = load(str(args.mlx_llm_path))
    hf_tokenizer = AutoTokenizer.from_pretrained(str(args.hf_path))

    # Audio-related token ids per the Qwen3-ASR config
    audio_pad_id = 151676   # <|audio_pad|> — the placeholder we scatter audio_embeds into
    eos_id = 151645         # <|im_end|>
    print(f"audio_pad_id={audio_pad_id} eos_id={eos_id}")

    print(f"Loading ONNX encoder from {args.encoder_path} ...")
    enc_sess = ort.InferenceSession(str(args.encoder_path), providers=["CPUExecutionProvider"])
    feat_ext = AutoFeatureExtractor.from_pretrained(str(args.hf_path))

    # Fixed mel width fed to the ONNX encoder (presumably 30 s at a 10 ms hop).
    max_mel_frames = 3000

    total_edits = 0
    total_words = 0
    results = []
    for name in sorted(REFERENCES):
        wav_path = args.examples_dir / f"{name}.wav"
        if not wav_path.exists():
            print(f"  skip {name} (missing wav)")
            continue
        audio = _load_mono_16k(wav_path)

        # Mel features. WhisperFeatureExtractor pads to its maximum length by
        # default, so the width of input_features is not the audio length.
        # The attention mask marks the frames that actually carry audio.
        feats = feat_ext(audio, sampling_rate=16000, return_tensors="np",
                         return_attention_mask=True)
        mel = feats["input_features"][..., :max_mel_frames]  # (1, 128, T)
        if "attention_mask" in feats:
            t_real = int(np.asarray(feats["attention_mask"]).sum())
        else:
            t_real = mel.shape[-1]
        t_real = min(t_real, mel.shape[-1])
        mel_padded = np.pad(
            mel, ((0, 0), (0, 0), (0, max_mel_frames - mel.shape[-1])), constant_values=0
        ).astype(np.float32)

        # Encoder; keep only the frames produced by real (non-padding) audio.
        audio_embeds = enc_sess.run(["audio_embeds"], {"mel": mel_padded})[0]
        audio_embeds = audio_embeds[:, :_encoder_output_frames(t_real)]  # (1, F, 2048)
        real_frames = audio_embeds.shape[1]

        # Prompt with one audio_pad placeholder per encoder frame, embedded
        # via the MLX model; the placeholders are then overwritten with the
        # audio embeddings.
        prompt_ids = build_prompt_ids(hf_tokenizer, real_frames, audio_pad_id)
        embeds = model.model.embed_tokens(mx.array(prompt_ids))  # (1, L, 2048)
        if real_frames:
            pad_positions = mx.array(np.flatnonzero(prompt_ids[0] == audio_pad_id))
            embeds[0, pad_positions] = mx.array(audio_embeds[0]).astype(embeds.dtype)

        out_ids = _greedy_decode(model, embeds, eos_id, args.max_new_tokens)
        hyp_text = hf_tokenizer.decode(out_ids, skip_special_tokens=True)

        ref = normalize(REFERENCES[name])
        hyp = normalize(hyp_text)
        w, ed, words = wer(ref, hyp)
        w = float(w)
        agree = max(0.0, 1.0 - w) * 100
        total_edits += int(ed)
        total_words += int(words)
        print(f"\n[{color(agree, name.ljust(10))}] WER={w*100:5.1f}%  agree={color(agree, f'{agree:5.1f}%')}")
        print(f"  REF: {ref}")
        print(f"  HYP: {hyp}")
        results.append({"name": name, "wer": w, "agree": agree, "hyp": hyp})

    if not results:
        print("\nNo examples were scored (no matching .wav files found).")
        return 1

    # Mean of the per-example (clamped) agreement values printed above;
    # WER is corpus-level (total edits over total reference words).
    avg_agree = sum(r["agree"] for r in results) / len(results)
    corpus_wer = total_edits / total_words * 100 if total_words else 0.0
    summary = (f"=== AVERAGE: agreement {avg_agree:.1f}%  WER {corpus_wer:.1f}% "
               f"({total_edits}/{total_words}) ===")
    print(f"\n{color(avg_agree, summary)}")

    if args.json_out is not None:
        payload = {"average_agree": avg_agree, "corpus_wer": corpus_wer, "examples": results}
        args.json_out.write_text(json.dumps(payload, indent=2), encoding="utf-8")
        print(f"Wrote {args.json_out}")
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_code = main()
    sys.exit(exit_code)