Reza2kn commited on
Commit
17540fa
·
verified ·
1 Parent(s): b799fc8

Add inference.py

Browse files
Files changed (1) hide show
  1. inference.py +200 -0
inference.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""MLX ASR bench: ONNX INT8 audio encoder + MLX 4-bit Qwen3 LLM.

For each VITW example:
1. Build mel features via WhisperFeatureExtractor (from Qwen3-ASR preprocessor config)
2. Run mel -> ONNX audio encoder -> audio embeddings
3. Build prompt + audio placeholder token sequence
4. Embed text tokens via the model's embed_tokens (MLX), scatter audio embeddings
5. MLX greedy decode using input_embeddings
6. Compare hypothesis to reference -> WER + agreement %
"""
11
+ from __future__ import annotations
12
+ import argparse
13
+ import json
14
+ import re
15
+ import sys
16
+ import warnings
17
+ from pathlib import Path
18
+
19
+ import numpy as np
20
+ import mlx.core as mx
21
+
22
# Suppress every warning process-wide (no category or module filter is given).
warnings.filterwarnings("ignore")


# Reference transcript for each benchmark clip. The key is the clip's file
# stem: main() looks for "<key>.wav" inside the examples directory and scores
# the decoded hypothesis against the value stored here.
REFERENCES = {
    "noise": "I usually take the quieter road home because the main street gets crowded after work.",
    "far_field": "Please remind me to print the forms before we leave for the appointment tomorrow.",
    "obstructed": "I forgot my charger at home, so I need to find an outlet before the meeting starts.",
    "distortion": "The new coffee machine is simple, but everyone keeps forgetting where the filters are stored.",
    "recording": "Can you check whether the train still stops at the downtown station after eight tonight?",
    "echo": "I need to return these shoes because the size feels fine standing up but terrible while walking.",
    "dropout": "My aunt is learning video calls, and she gets excited whenever the picture actually works.",
    "mixed": "My sister is bringing dinner over later, so we do not need to cook tonight.",
}
35
+
36
# Matches any character that is not a lowercase ASCII letter, a digit, or
# whitespace; normalize() replaces each match with a single space.
_NORM_RE = re.compile(r"[^a-z0-9\s]")


def normalize(text: str) -> str:
    """Return *text* in the canonical form used for WER scoring.

    Steps, in order:
      1. If the text contains ``<asr_text>``, keep only what follows the
         first occurrence of that marker.
      2. Lowercase.
      3. Replace every character outside ``[a-z0-9]`` and whitespace with a
         space (so punctuation, including apostrophes, splits words).
      4. Collapse whitespace runs to one space and trim both ends.
    """
    _, marker, tail = text.partition("<asr_text>")
    if marker:
        text = tail
    text = _NORM_RE.sub(" ", text.lower())
    # split() with no argument splits on whitespace runs and discards empty
    # leading/trailing fields, so join() yields the collapsed, stripped form.
    return " ".join(text.split())
46
+
47
+
48
def wer(ref: str, hyp: str) -> tuple[float, int, int]:
    """Compute the word error rate of *hyp* against *ref*.

    Both strings are split on whitespace and compared word by word using
    Levenshtein distance (insertions, deletions and substitutions all cost 1).

    Returns:
        A ``(rate, edits, ref_words)`` tuple of built-in ``float``/``int``
        values, where ``rate == edits / ref_words``.  When *ref* has no
        words, ``ref_words`` is 0, ``edits`` is the number of hypothesis
        words, and ``rate`` is 1.0 if the hypothesis is non-empty, else 0.0.
    """
    ref_words = ref.split()
    hyp_words = hyp.split()
    if not ref_words:
        return (1.0 if hyp_words else 0.0, len(hyp_words), 0)

    # Rolling-row edit distance: `prev[j]` is the distance between the
    # reference prefix processed so far and the first j hypothesis words.
    prev = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        curr = [i]  # deleting all i reference words to reach an empty hypothesis
        for j, hyp_word in enumerate(hyp_words, start=1):
            substitution = prev[j - 1] + (ref_word != hyp_word)
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, substitution))
        prev = curr

    edits = prev[-1]
    return edits / len(ref_words), edits, len(ref_words)
60
+
61
+
62
# (minimum percentage, ANSI SGR colour code), checked from highest to lowest.
_COLOR_BANDS = (
    (70, 92),  # bright green
    (50, 93),  # bright yellow
    (25, 33),  # yellow
)
# SGR code used when the percentage is below every band (bright red).
_COLOR_FALLBACK = 91


def color(pct, s):
    """Wrap *s* in an ANSI colour escape chosen from *pct*.

    Args:
        pct: Percentage used to pick the colour: ``>= 70`` bright green,
            ``>= 50`` bright yellow, ``>= 25`` yellow, anything else
            bright red.
        s: Text to colourise; interpolated with ``str.format`` semantics.

    Returns:
        ``s`` prefixed with the colour sequence and followed by the reset
        sequence ``\\033[0m``.
    """
    for threshold, code in _COLOR_BANDS:
        if pct >= threshold:
            return f"\033[{code}m{s}\033[0m"
    return f"\033[{_COLOR_FALLBACK}m{s}\033[0m"
67
+
68
+
69
def build_prompt_ids(tokenizer, audio_pad_count: int, audio_pad_id: int):
    """Build the Qwen3-ASR chat-template prompt + force_language='English'.

    The prompt text contains a single ``<|audio_pad|>`` between
    ``<|audio_start|>`` and ``<|audio_end|>``; after tokenization that one
    placeholder id is replaced by ``audio_pad_count`` copies of
    ``audio_pad_id``.

    Args:
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``
            and returning a sequence of token ids.
        audio_pad_count: Number of placeholder tokens to emit (>= 0).
        audio_pad_id: Token id the tokenizer produces for ``<|audio_pad|>``.

    Returns:
        ``np.ndarray`` of dtype int64 and shape ``(1, L)`` holding the
        expanded input ids.

    Raises:
        ValueError: If ``audio_pad_count`` is negative, or if
            ``audio_pad_id`` does not occur in the tokenized prompt.
    """
    if audio_pad_count < 0:
        raise ValueError(f"audio_pad_count must be >= 0, got {audio_pad_count}")

    # The chat template is written out by hand rather than via the
    # tokenizer's own template machinery.
    prompt = (
        "<|im_start|>system\n<|im_end|>\n"
        "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|><|im_end|>\n"
        "<|im_start|>assistant\n"
        "language English<asr_text>"
    )
    ids = list(tokenizer.encode(prompt, add_special_tokens=False))

    # Locate the single <|audio_pad|> position and expand it in place.
    try:
        pos = ids.index(audio_pad_id)
    except ValueError:
        raise ValueError(
            f"audio_pad_id={audio_pad_id} not found in the tokenized prompt; "
            "the tokenizer does not map <|audio_pad|> to that id"
        ) from None
    expanded = ids[:pos] + [audio_pad_id] * audio_pad_count + ids[pos + 1:]
    return np.array([expanded], dtype=np.int64)
87
+
88
+
89
# Mel length fed to the ONNX encoder: longer inputs are truncated and shorter
# ones zero-padded to exactly this many frames.
_MAX_MEL_FRAMES = 3000


def _count_audio_frames(mel_frames: int) -> int:
    """Return the number of encoder output frames backed by real (unpadded) mel input.

    Mel frames are grouped into chunks of 100; every chunk before the last
    contributes 13 output frames and the last contributes ceil(len / 8)
    (chunked CNN — see audio_encoder_wrapper.py).
    """
    chunks = (mel_frames + 99) // 100
    last_chunk_mel = mel_frames - (chunks - 1) * 100
    return (chunks - 1) * 13 + (last_chunk_mel + 7) // 8


def _greedy_decode(model, prompt_embeds, max_new_tokens: int, eos_id: int) -> list:
    """Greedy-decode token ids from *model*, primed with *prompt_embeds*.

    The prompt is fed once as ``input_embeddings`` (with all-zero dummy token
    ids of matching length); each later step feeds only the previously chosen
    token and reuses the prompt cache.  Decoding stops at *eos_id* or after
    *max_new_tokens* tokens; a trailing *eos_id* is removed from the result.
    """
    from mlx_lm.models.cache import make_prompt_cache

    cache = make_prompt_cache(model)
    dummy_ids = mx.zeros((1, prompt_embeds.shape[1]), dtype=mx.int64)
    logits = model(inputs=dummy_ids, cache=cache, input_embeddings=prompt_embeds)
    next_tok = int(mx.argmax(logits[0, -1, :]).item())
    out_ids = [next_tok]
    for _ in range(max_new_tokens - 1):
        if next_tok == eos_id:
            break
        logits = model(inputs=mx.array([[next_tok]]), cache=cache)
        next_tok = int(mx.argmax(logits[0, -1, :]).item())
        out_ids.append(next_tok)
    # Strip trailing eos if any
    if out_ids and out_ids[-1] == eos_id:
        out_ids = out_ids[:-1]
    return out_ids


def main():
    """Run the benchmark over every clip named in REFERENCES and print a report.

    For each ``<name>.wav`` found in ``--examples-dir`` the clip is converted
    to 16 kHz mono, turned into mel features, passed through the ONNX audio
    encoder, spliced into the prompt embeddings, greedily decoded with the
    MLX LLM, and scored against its reference with ``wer``.

    Returns:
        0 after printing the summary line, or 1 when no clip could be
        benchmarked (for example, every wav file is missing).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--mlx-llm-path", default="models/mlx/mega-asr-llm-4bit", type=Path)
    ap.add_argument("--encoder-path", default="models/mega-asr-onnx-hf/onnx/audio_encoder_int8.onnx", type=Path)
    ap.add_argument("--examples-dir", default="models/mega-asr-onnx-hf/examples", type=Path)
    ap.add_argument(
        "--hf-model-path",
        default="models/mega-asr-hf/Qwen3-ASR-1.7B",
        help="Qwen3-ASR directory providing the tokenizer and feature extractor",
    )
    ap.add_argument("--max-new-tokens", type=int, default=80)
    args = ap.parse_args()

    # Heavy third-party imports are deferred until after argument parsing.
    import soundfile as sf
    import onnxruntime as ort
    from mlx_lm.utils import load
    from transformers import AutoFeatureExtractor, AutoTokenizer

    print(f"Loading MLX LLM from {args.mlx_llm_path} ...")
    model, _ = load(str(args.mlx_llm_path))
    # Use the ORIGINAL Qwen3-ASR tokenizer (has audio special tokens)
    hf_tokenizer = AutoTokenizer.from_pretrained(args.hf_model_path)

    # Audio-related token ids per the Qwen3-ASR config
    audio_pad_id = 151676  # <|audio_pad|> — the placeholder we scatter audio_embeds into
    eos_id = 151645  # <|im_end|>
    print(f"audio_pad_id={audio_pad_id} eos_id={eos_id}")

    print(f"Loading ONNX encoder from {args.encoder_path} ...")
    enc_sess = ort.InferenceSession(str(args.encoder_path), providers=["CPUExecutionProvider"])

    # Use Qwen3-ASR's WhisperFeatureExtractor — load from the original HF path
    feat_ext = AutoFeatureExtractor.from_pretrained(args.hf_model_path)

    # Bench all 8 examples
    total_wer = 0.0
    total_edits = 0
    total_words = 0
    n = 0
    # Per-example records; only accumulated here, nothing below reads them back.
    results = []
    for name in sorted(REFERENCES):
        wav_path = args.examples_dir / f"{name}.wav"
        if not wav_path.exists():
            print(f" skip {name} (missing wav)")
            continue
        audio, sr = sf.read(str(wav_path))
        if audio.ndim > 1:
            # Down-mix multi-channel audio by averaging the channels.
            audio = audio.mean(axis=1)
        if sr != 16000:
            # librosa is only needed for clips that are not already 16 kHz.
            import librosa
            audio = librosa.resample(audio.astype(np.float32), orig_sr=sr, target_sr=16000)

        # Mel features
        feats = feat_ext(audio, sampling_rate=16000, return_tensors="np", return_attention_mask=False)
        mel = feats["input_features"]  # (1, 128, T_mel)
        T_mel = mel.shape[-1]
        if T_mel > _MAX_MEL_FRAMES:
            mel = mel[..., :_MAX_MEL_FRAMES]
            T_mel = _MAX_MEL_FRAMES
        mel_padded = np.pad(
            mel, ((0, 0), (0, 0), (0, _MAX_MEL_FRAMES - T_mel)), constant_values=0
        ).astype(np.float32)

        # Encoder
        audio_embeds = enc_sess.run(["audio_embeds"], {"mel": mel_padded})[0]
        # Keep only the frames produced from real (unpadded) mel input.
        real_frames = _count_audio_frames(T_mel)
        audio_embeds = audio_embeds[:, :real_frames]  # (1, F, 2048)

        # Build prompt with placeholder audio_pad tokens of length = real_frames
        prompt_ids = build_prompt_ids(hf_tokenizer, real_frames, audio_pad_id)
        # Embed text via MLX model's embed_tokens
        text_emb = model.model.embed_tokens(mx.array(prompt_ids))  # (1, L, 2048)
        # Scatter audio embeddings into positions where input_ids == audio_pad_id
        mask_idx = mx.array(np.where(prompt_ids[0] == audio_pad_id)[0])
        audio_emb_mx = mx.array(audio_embeds[0]).astype(text_emb.dtype)  # (real_frames, 2048)
        combined_mx = text_emb
        combined_mx[0, mask_idx] = audio_emb_mx

        # MLX greedy decode using input_embeddings
        out_ids = _greedy_decode(model, combined_mx, args.max_new_tokens, eos_id)
        hyp_text = hf_tokenizer.decode(out_ids, skip_special_tokens=True)

        ref = normalize(REFERENCES[name])
        hyp = normalize(hyp_text)
        w, ed, words = wer(ref, hyp)
        # Coerce to built-in numbers so the totals and records stay plain Python.
        w, ed, words = float(w), int(ed), int(words)
        agree = max(0.0, 1.0 - w) * 100
        total_wer += w
        total_edits += ed
        total_words += words
        n += 1
        print(f"\n[{color(agree, name.ljust(10))}] WER={w*100:5.1f}% agree={color(agree, f'{agree:5.1f}%')}")
        print(f" REF: {ref}")
        print(f" HYP: {hyp}")
        results.append({"name": name, "wer": w, "agree": agree, "hyp": hyp})

    if n == 0:
        # Without this guard the summary below would divide by zero.
        print(f"No examples were benchmarked; no wav files found in {args.examples_dir}", file=sys.stderr)
        return 1

    avg_agree = (1 - total_wer / n) * 100
    overall_wer = total_edits / total_words * 100 if total_words else 0.0
    summary = f"=== AVERAGE: agreement {avg_agree:.1f}% WER {overall_wer:.1f}% ({total_edits}/{total_words}) ==="
    print(f"\n{color(avg_agree, summary)}")
    return 0


if __name__ == "__main__":
    sys.exit(main())