Reza2kn commited on
Commit
74e5b09
·
verified ·
1 Parent(s): 6e9918c

Working CoreML LUT4 input_embeds variant (86.9% on VITW)

Browse files
Files changed (1) hide show
  1. inference_asr.py +216 -0
inference_asr.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """End-to-end Mega-ASR pipeline on CoreML:
2
+ ONNX audio encoder + CoreML LUT-4 LLM (input_embeds variant) + bench.
3
+
4
+ The CoreML LLM is single-token-step (ANE-friendly). For each token in the
5
+ prompt we feed (inputs_embeds[t], current_pos=t) to populate the KV cache;
6
+ then we greedy-decode by feeding one token at a time.
7
+ """
8
+ from __future__ import annotations
9
+ import argparse
10
+ import json
11
+ import re
12
+ import sys
13
+ import warnings
14
+ from pathlib import Path
15
+
16
+ import numpy as np
17
+
18
+ warnings.filterwarnings("ignore")
19
+
20
+ REFERENCES = {
21
+ "noise": "I usually take the quieter road home because the main street gets crowded after work.",
22
+ "far_field": "Please remind me to print the forms before we leave for the appointment tomorrow.",
23
+ "obstructed": "I forgot my charger at home, so I need to find an outlet before the meeting starts.",
24
+ "distortion": "The new coffee machine is simple, but everyone keeps forgetting where the filters are stored.",
25
+ "recording": "Can you check whether the train still stops at the downtown station after eight tonight?",
26
+ "echo": "I need to return these shoes because the size feels fine standing up but terrible while walking.",
27
+ "dropout": "My aunt is learning video calls, and she gets excited whenever the picture actually works.",
28
+ "mixed": "My sister is bringing dinner over later, so we do not need to cook tonight.",
29
+ }
30
+ _NORM_RE = re.compile(r"[^a-z0-9\s]")
31
+
32
+
33
+ def normalize(t):
34
+ if "<asr_text>" in t:
35
+ t = t.split("<asr_text>", 1)[1]
36
+ return re.sub(r"\s+", " ", _NORM_RE.sub(" ", t.lower())).strip()
37
+
38
+
39
+ def wer(ref, hyp):
40
+ r = ref.split(); h = hyp.split()
41
+ if not r: return (1.0 if h else 0.0, len(h), 0)
42
+ d = np.zeros((len(r) + 1, len(h) + 1), dtype=np.int32)
43
+ for i in range(len(r) + 1): d[i, 0] = i
44
+ for j in range(len(h) + 1): d[0, j] = j
45
+ for i in range(1, len(r) + 1):
46
+ for j in range(1, len(h) + 1):
47
+ d[i, j] = min(d[i-1, j] + 1, d[i, j-1] + 1,
48
+ d[i-1, j-1] + (0 if r[i-1] == h[j-1] else 1))
49
+ return d[len(r), len(h)] / max(len(r), 1), int(d[len(r), len(h)]), len(r)
50
+
51
+
52
+ def color(p, s):
53
+ if p >= 70: return f"\033[92m{s}\033[0m"
54
+ if p >= 50: return f"\033[93m{s}\033[0m"
55
+ if p >= 25: return f"\033[33m{s}\033[0m"
56
+ return f"\033[91m{s}\033[0m"
57
+
58
+
59
+ def build_prompt_ids(tok, audio_pad_count, audio_pad_id=151676):
60
+ prompt = (
61
+ "<|im_start|>system\n<|im_end|>\n"
62
+ "<|im_start|>user\n<|audio_start|><|audio_pad|><|audio_end|><|im_end|>\n"
63
+ "<|im_start|>assistant\n"
64
+ "language English<asr_text>"
65
+ )
66
+ ids = tok.encode(prompt, add_special_tokens=False)
67
+ pos = ids.index(audio_pad_id)
68
+ return ids[:pos] + [audio_pad_id] * audio_pad_count + ids[pos + 1:]
69
+
70
+
71
+ def causal_mask_at(cur, ctx, neg_inf=-1e4):
72
+ """Build (1,1,1,ctx) mask: positions > cur get -inf, others 0."""
73
+ m = np.zeros((1, 1, 1, ctx), dtype=np.float16)
74
+ if cur + 1 < ctx:
75
+ m[0, 0, 0, cur + 1:] = neg_inf
76
+ return m
77
+
78
+
79
+ def update_mask_at(cur, ctx):
80
+ """(1,1,ctx,1) — 1.0 at the current position, 0 elsewhere. Used for KV cache write."""
81
+ m = np.zeros((1, 1, ctx, 1), dtype=np.float16)
82
+ m[0, 0, cur, 0] = 1.0
83
+ return m
84
+
85
+
86
+ def main():
87
+ ap = argparse.ArgumentParser()
88
+ ap.add_argument("--mlpackage", default="models/coreml/mega-asr-llm-embeds_lut4.mlpackage", type=Path)
89
+ ap.add_argument("--encoder-path", default="models/mega-asr-onnx-hf/onnx/audio_encoder_fp32.onnx", type=Path)
90
+ ap.add_argument("--examples-dir", default="models/mega-asr-onnx-hf/examples", type=Path)
91
+ ap.add_argument("--qwen-asr-dir", default="models/mega-asr-hf/Qwen3-ASR-1.7B", type=Path)
92
+ ap.add_argument("--max-new-tokens", type=int, default=80)
93
+ ap.add_argument("--context-length", type=int, default=512)
94
+ ap.add_argument("--compute-unit", default="CPU_AND_NE", choices=["CPU_ONLY", "CPU_AND_NE", "ALL"])
95
+ args = ap.parse_args()
96
+
97
+ import soundfile as sf
98
+ import onnxruntime as ort
99
+ import coremltools as ct
100
+ from transformers import AutoFeatureExtractor, AutoTokenizer
101
+
102
+ print(f"Loading CoreML mlpackage ({args.compute_unit}) ...")
103
+ cu = getattr(ct.ComputeUnit, args.compute_unit)
104
+ mlm = ct.models.MLModel(str(args.mlpackage), compute_units=cu)
105
+
106
+ print(f"Loading ONNX encoder ...")
107
+ enc = ort.InferenceSession(str(args.encoder_path), providers=["CPUExecutionProvider"])
108
+ feat = AutoFeatureExtractor.from_pretrained(str(args.qwen_asr_dir))
109
+ tok = AutoTokenizer.from_pretrained(str(args.qwen_asr_dir))
110
+
111
+ # Embed table from the HF model (for text tokens; audio_pad slots use audio_embeds)
112
+ import safetensors.torch as st
113
+ import torch
114
+ print("Loading embed_tokens (bf16 → fp32) ...")
115
+ # Find embed_tokens.weight from the original Qwen3-ASR safetensors
116
+ idx = json.load(open(args.qwen_asr_dir / "model.safetensors.index.json"))
117
+ embed_key = "thinker.model.embed_tokens.weight"
118
+ shard = idx["weight_map"][embed_key]
119
+ embed_w = st.load_file(str(args.qwen_asr_dir / shard))[embed_key]
120
+ embed_w = embed_w.to(torch.float32).numpy() # (151936, 2048)
121
+ HIDDEN = embed_w.shape[1]
122
+ print(f" embed_w shape: {embed_w.shape}")
123
+
124
+ AUDIO_PAD = 151676
125
+ EOS = 151645
126
+ CTX = args.context_length
127
+
128
+ total_wer = 0.0; total_edits = 0; total_words = 0; n = 0
129
+ for name in sorted(REFERENCES):
130
+ wav_path = args.examples_dir / f"{name}.wav"
131
+ if not wav_path.exists():
132
+ print(f" skip {name} (missing)"); continue
133
+ audio, sr = sf.read(str(wav_path))
134
+ if audio.ndim > 1: audio = audio.mean(axis=1)
135
+ if sr != 16000:
136
+ import librosa
137
+ audio = librosa.resample(audio.astype(np.float32), orig_sr=sr, target_sr=16000)
138
+ f = feat(audio, sampling_rate=16000, return_tensors="np", return_attention_mask=False)
139
+ mel = f["input_features"]
140
+ T_mel = mel.shape[-1]
141
+ if T_mel > 3000: mel = mel[..., :3000]; T_mel = 3000
142
+ mel = np.pad(mel, ((0, 0), (0, 0), (0, 3000 - T_mel)), constant_values=0).astype(np.float32)
143
+ audio_embeds = enc.run(["audio_embeds"], {"mel": mel})[0]
144
+ real_chunks = (T_mel + 99) // 100
145
+ last_chunk = T_mel - (real_chunks - 1) * 100
146
+ real_frames = (real_chunks - 1) * 13 + (last_chunk + 7) // 8
147
+ audio_embeds = audio_embeds[0, :real_frames] # (F, 2048)
148
+
149
+ # Build prompt tokens + per-position embeddings
150
+ prompt_ids = build_prompt_ids(tok, real_frames)
151
+ L = len(prompt_ids)
152
+ if L > CTX - args.max_new_tokens:
153
+ print(f" skip {name} (L={L} too long for ctx={CTX})"); continue
154
+
155
+ # Per-token embeddings: lookup for text, scatter audio_embeds at audio_pad slots
156
+ token_embeds = np.zeros((L, HIDDEN), dtype=np.float32)
157
+ ai = 0
158
+ for i, t in enumerate(prompt_ids):
159
+ if t == AUDIO_PAD:
160
+ token_embeds[i] = audio_embeds[ai]; ai += 1
161
+ else:
162
+ token_embeds[i] = embed_w[t]
163
+ token_embeds = token_embeds.astype(np.float16)
164
+
165
+ # Run prefill: feed each prompt token one at a time
166
+ state = mlm.make_state()
167
+ for i in range(L):
168
+ feeds = {
169
+ "inputs_embeds": token_embeds[i:i+1].reshape(1, 1, HIDDEN),
170
+ "position_ids": np.array([i], dtype=np.int32),
171
+ "causal_mask": causal_mask_at(i, CTX),
172
+ "current_pos": np.array([i], dtype=np.int32),
173
+ "update_mask": update_mask_at(i, CTX),
174
+ }
175
+ out = mlm.predict(feeds, state=state)
176
+ # Argmax of last logit
177
+ logits = np.concatenate([out[f"logits{k}"][0, 0] for k in range(1, 17)])
178
+ nid = int(np.argmax(logits))
179
+ gen = [nid]
180
+
181
+ # Decode step-by-step
182
+ for step in range(args.max_new_tokens - 1):
183
+ if nid == EOS: break
184
+ cur = L + step
185
+ if cur >= CTX: break
186
+ emb = embed_w[nid].astype(np.float16)
187
+ feeds = {
188
+ "inputs_embeds": emb.reshape(1, 1, HIDDEN),
189
+ "position_ids": np.array([cur], dtype=np.int32),
190
+ "causal_mask": causal_mask_at(cur, CTX),
191
+ "current_pos": np.array([cur], dtype=np.int32),
192
+ "update_mask": update_mask_at(cur, CTX),
193
+ }
194
+ out = mlm.predict(feeds, state=state)
195
+ logits = np.concatenate([out[f"logits{k}"][0, 0] for k in range(1, 17)])
196
+ nid = int(np.argmax(logits))
197
+ gen.append(nid)
198
+ # Strip trailing eos
199
+ if gen and gen[-1] == EOS: gen = gen[:-1]
200
+ hyp_text = tok.decode(gen, skip_special_tokens=True)
201
+
202
+ ref = normalize(REFERENCES[name])
203
+ hyp = normalize(hyp_text)
204
+ w, ed, words = wer(ref, hyp)
205
+ agree = max(0.0, 1.0 - w) * 100
206
+ total_wer += w; total_edits += ed; total_words += words; n += 1
207
+ print(f"\n[{color(agree, name.ljust(10))}] WER={w*100:5.1f}% agree={color(agree, f'{agree:5.1f}%')}")
208
+ print(f" REF: {ref}")
209
+ print(f" HYP: {hyp}")
210
+ avg = (1 - total_wer / n) * 100 if n else 0
211
+ print(f"\n{color(avg, f'=== AVERAGE: agreement {avg:.1f}% WER {total_edits/total_words*100:.1f}% ({total_edits}/{total_words}) ===')}")
212
+ return 0
213
+
214
+
215
+ if __name__ == "__main__":
216
+ sys.exit(main())