Reza2kn committed
Commit 3eeae98 · verified · 1 Parent(s): ea99fd4

End-to-end CoreML ASR works (86.9% on VITW); document input_embeds fork + fp32 compute fix

Files changed (1)
  1. README.md +112 -85
README.md CHANGED
@@ -33,114 +33,141 @@ base_model: zhifeixie/Mega-ASR
  base_model_relation: quantized
  ---

- # Mega-ASR: CoreML LUT-4 (Apple Neural Engine)

- CoreML LUT-4 (4-bit lookup-table palettized) export of the LLM portion of
- [zhifeixie/Mega-ASR](https://huggingface.co/zhifeixie/Mega-ASR) (Qwen3-ASR-1.7B
- base), produced via [ANEMLL](https://github.com/Anemll/Anemll), the Apple
- Neural Engine reference converter, with `--chunk 4 --lut 4 --context-length 512`.

- The resulting `.mlpackage` is a stateful CoreML model with native ANE
- attention layouts, in-model KV cache state, and 16-way split LM head for
- efficient ANE residency.

  ## What's in this repo

  | File | Size | Role |
  | --- | ---: | --- |
- | `coreml/mega-asr-llm_lut4.mlpackage/` | **974 MB** | Qwen3 1.7B LLM, ANE-targeted, LUT-4 palettized weights, stateful KV cache |
- | `onnx/audio_encoder_fp32.onnx` | 1.27 GB | 24-layer Whisper-style audio encoder (ONNX fp32, run via onnxruntime; CoreML port pending) |
  | `tokenizer/*` | - | Original Qwen3-ASR tokenizer (`<\|audio_pad\|>`, `<asr_text>`, etc.) |
  | `examples/*.wav` | ~3 MB | 8 noisy benchmark clips from Voices-in-the-Wild-Bench |

- ## Model I/O
-
- The `mega-asr-llm_lut4.mlpackage` follows ANEMLL's stateful step-decoder layout:
-
- **Inputs** (single-token step):
- | name | shape | dtype |
- | --- | --- | --- |
- | `input_ids` | `(1, 1)` | int32 |
- | `position_ids` | `(1,)` | int32 |
- | `causal_mask` | `(1, 1, 1, 512)` | float16 |
- | `current_pos` | `(1,)` | int32 |
- | `update_mask` | `(1, 1, 512, 1)` | float16 |
-
- **Outputs**: `logits1` … `logits16`, each `(1, 1, 9496)` float16; concatenate
- along the last axis to get the 151936-dim vocabulary.
-
- **State**: `model_model_kv_cache_0`, shape `(56, 8, 512, 128)` float16 (28
- layers × 2 (K/V) × 8 KV heads × 512 max context × 128 head dim). Create with
- `model.make_state()` and pass to every `predict()`.
-
- ## Quick run (Python)
-
- ```python
- import coremltools as ct
- import numpy as np
-
- m = ct.models.MLModel("coreml/mega-asr-llm_lut4.mlpackage",
-                       compute_units=ct.ComputeUnit.CPU_AND_NE)
- state = m.make_state()
- out = m.predict({
-     "input_ids": np.array([[40]], dtype=np.int32),  # token 'I'
-     "position_ids": np.array([0], dtype=np.int32),
-     "causal_mask": np.zeros((1, 1, 1, 512), dtype=np.float16),
-     "current_pos": np.array([0], dtype=np.int32),
-     "update_mask": np.zeros((1, 1, 512, 1), dtype=np.float16),
- }, state=state)
- all_logits = np.concatenate([out[f"logits{i}"][0, 0] for i in range(1, 17)])
  ```

- ## ASR limitation (current)

- This conversion exports the **standard text-LLM interface** (`input_ids` →
- internal `embed_tokens` → forward). End-to-end ASR requires scattering
- **audio embeddings** at `<|audio_pad|>` placeholder positions, which means
- the model needs to accept `input_embeddings` *instead of* `input_ids`.

- That requires forking ANEMLL's `qwen_model.py` to expose pre-embedded
- hidden_states as the entry point, then re-running the conversion. (See
- [`aoiandroid/Qwen3-ASR-1.7B-CoreML`](https://huggingface.co/aoiandroid/Qwen3-ASR-1.7B-CoreML)
- for a prior community attempt of the same pattern; their decoder is named
- `qwen3_asr_decoder_f32_anemll_int8-mixed.mlpackage` and pairs with a
- separately stored `qwen3_asr_embeddings.bin`.)

- Until the input_embeddings variant lands, this artifact is usable as:
- - A standalone Qwen3 1.7B CoreML LLM (e.g. text-only chat with the same
-   prompt format the base model expects).
- - A starting point for building an ANE-targeted Mega-ASR ASR pipeline by
-   re-converting with the embedding bypass.

- ## Conversion details

- ```bash
- # After cloning ANEMLL (https://github.com/Anemll/Anemll):
- python -m anemll.ane_converter.qwen_converter \
-     --model /path/to/Qwen3-ASR-1.7B-llm-only \
-     --prefix mega-asr-llm --lut 4 \
-     --context-length 512 --batch-size 64 --chunk 4 \
-     --output /path/to/out
  ```

- The Qwen3-ASR-1.7B LLM weights were first extracted from `zhifeixie/Mega-ASR`
- by stripping the `thinker.model.` prefix and dropping the tied lm_head
- (see [Reza2kn/mega-asr-mlx](https://huggingface.co/Reza2kn/mega-asr-mlx) for
- the extraction script).
-
- Coremltools 9.0 needed one local patch: the `_cast` op handler in
- `coremltools/converters/mil/frontend/torch/ops.py` does not handle numpy
- arrays of size 1; this was fixed by extracting the scalar via `.flatten()[0].item()`
- before the dtype coercion.

  ## Companion repos

- - [Reza2kn/mega-asr-onnx](https://huggingface.co/Reza2kn/mega-asr-onnx): full ONNX pipeline (GPTQ-INT4 decoder, 92.7% on VITW)
- - [Reza2kn/mega-asr-mlx](https://huggingface.co/Reza2kn/mega-asr-mlx): MLX 4-bit (mixed8/4 attention/MLP, 92.2% on VITW)
- - [Reza2kn/mega-asr-bench](https://huggingface.co/spaces/Reza2kn/mega-asr-bench): live browser demo (WebGPU)

  ## Credits

- - Original model: [zhifeixie/Mega-ASR](https://huggingface.co/zhifeixie/Mega-ASR) (1.7B params, Apache-2.0)
- - CoreML conversion via [ANEMLL](https://github.com/Anemll/Anemll) (Apple Neural Engine LLM port toolkit)
  - Benchmark: [Voices-in-the-Wild-Bench](https://github.com/xzf-thu/Voices-in-the-Wild-Bench)

  base_model_relation: quantized
  ---

+ # Mega-ASR: CoreML LUT-4 (end-to-end ASR)

+ CoreML LUT-4 (4-bit lookup-table palettized) deployment of
+ [zhifeixie/Mega-ASR](https://huggingface.co/zhifeixie/Mega-ASR), with an
+ `input_embeds`-aware decoder so audio embeddings can be scattered at
+ `<|audio_pad|>` positions to do real ASR, not just text generation.

+ Converted via [ANEMLL](https://github.com/Anemll/Anemll) with a custom
+ `coreml_convert_embeds.py` that monkey-patches `QwenModel.forward` and
+ `QwenForCausalLM.forward` to accept pre-embedded `hidden_states` (skipping the
+ internal `embed_tokens` lookup). The model is single-token-step with a stateful KV
+ cache (28 layers × 2 × 8 KV heads × 512 ctx × 128 head_dim, fp16), LUT-4
+ weights at `--per_channel 8`, and **fp32 compute precision**: `compute_precision=FLOAT16`
+ overflows in Qwen3-ASR's RMSNorm/attention layers and produces NaN logits.
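As a quick sanity check on the cache dimensions quoted above, the sketch below multiplies them out. It assumes 2 bytes per fp16 element and ignores any padding or alignment CoreML may add, so the result is an estimate of the state size rather than a measured figure.

```python
# KV cache dimensions as quoted in the README text:
# 28 layers x 2 (K and V) x 8 KV heads x 512 context x 128 head_dim, fp16.
layers, k_and_v, kv_heads, context, head_dim = 28, 2, 8, 512, 128
bytes_per_fp16 = 2  # assumption: plain fp16 storage, no padding

elements = layers * k_and_v * kv_heads * context * head_dim
size_bytes = elements * bytes_per_fp16
size_mib = size_bytes / (1024 ** 2)

print(f"{elements:,} elements, about {size_mib:.0f} MiB of KV state")
```

Under these assumptions the state is small next to the weights, which is consistent with keeping the cache in fp16 while only the compute precision is raised.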

  ## What's in this repo

  | File | Size | Role |
  | --- | ---: | --- |
+ | `coreml/mega-asr-llm-embeds_fp32compute_lut4.mlpackage/` | **826 MB** | **Recommended.** Qwen3 1.7B LLM, `inputs_embeds` input, fp32 compute, LUT-4 weights. Pair with the ONNX audio encoder for end-to-end ASR. |
+ | `coreml/mega-asr-llm_lut4.mlpackage/` | 974 MB | Original `input_ids` variant: standalone Qwen3 1.7B text LLM (no audio scatter). |
+ | `onnx/audio_encoder_fp32.onnx` | 1.27 GB | 24-layer Whisper-style audio encoder (ONNX, runs via onnxruntime; CoreML port pending) |
  | `tokenizer/*` | - | Original Qwen3-ASR tokenizer (`<\|audio_pad\|>`, `<asr_text>`, etc.) |
  | `examples/*.wav` | ~3 MB | 8 noisy benchmark clips from Voices-in-the-Wild-Bench |
+ | `inference_asr.py` | - | End-to-end ASR pipeline: ONNX encoder + CoreML LLM |
+ | `convert_embeds.py` | - | The custom converter (use to reproduce / re-quantize) |
+
+ ## Quality (bench)
+
+ 8-clip [Voices-in-the-Wild-Bench](https://github.com/xzf-thu/Voices-in-the-Wild-Bench)
+ agreement (1 − WER), prompt forced to `language English`, on an M-series Mac
+ CPU (CPU_AND_NE failed to compile for ANE, likely due to model size + state):
+
+ | Per-sample | Hyp ≈ Ref? | Agreement |
+ | --- | --- | ---: |
+ | distortion | exact match | 100% |
+ | dropout | exact match | 100% |
+ | far_field | exact match | 100% |
+ | mixed | exact match | 100% |
+ | noise | exact match | 100% |
+ | obstructed | "i have forgotten" vs "i forgot" | 88.2% |
+ | echo (hard, heavy reverb) | "size 25 stand not and the 125 walk" | 47.1% |
+ | recording (hard, truncated audio) | "train stopped at the station" | 60.0% |
+ | **AVERAGE** | | **86.9%** |
+
+ For reference (same 8 samples, same audio encoder, same prompt):
+
+ | Backend | Agreement |
+ | --- | ---: |
+ | ONNX recommended (GPTQ) | 92.7% |
+ | MLX recommended (mixed 8/4) | 92.2% |
+ | **CoreML LUT-4 (this repo)** | **86.9%** |
+ | ONNX RTN INT4 baseline | 87.8% |
+
+ LUT-4 k-means is a more aggressive quantization than ONNX GPTQ (which uses
+ activation-aware error redistribution) or MLX mixed 8/4 (which keeps the
+ 4 attention projections at 8-bit). The roughly **6-point gap** vs the leaders is
+ concentrated on the 2 hard samples (`echo`, `recording`) and one near-miss
+ on `obstructed`. Five of eight samples produce exact-match transcriptions.
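The reported average can be reproduced from the per-sample rows. The snippet below is only that arithmetic, using the agreement values as listed in the table; the dictionary keys are shorthand labels for the rows, and 92.7 is the ONNX GPTQ figure from the reference table.

```python
# Per-sample agreement (1 - WER), in percent, copied from the table above.
agreement = {
    "distortion": 100.0,
    "dropout": 100.0,
    "far_field": 100.0,
    "mixed": 100.0,
    "noise": 100.0,
    "obstructed": 88.2,
    "echo": 47.1,
    "recording": 60.0,
}

average = sum(agreement.values()) / len(agreement)
gap_to_onnx_gptq = 92.7 - average  # percentage points

print(f"average agreement: {average:.1f}%")
print(f"gap to ONNX GPTQ:  {gap_to_onnx_gptq:.1f} points")
```

Because this is an unweighted mean over only eight clips, a single hard sample moves the headline number by several points.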
+
+ ## Inference

+ ```bash
+ pip install coremltools onnxruntime soundfile transformers safetensors librosa numpy
+ git clone https://huggingface.co/Reza2kn/mega-asr-coreml
+ cd mega-asr-coreml
+ python inference_asr.py \
+     --mlpackage coreml/mega-asr-llm-embeds_fp32compute_lut4.mlpackage \
+     --encoder-path onnx/audio_encoder_fp32.onnx \
+     --examples-dir examples \
+     --qwen-asr-dir <local path to Qwen3-ASR-1.7B HF dir>
  ```

+ The pipeline runs:
+ 1. **Mel features** via Qwen3-ASR's `WhisperFeatureExtractor`.
+ 2. **Audio encoder** (ONNX fp32) → audio embeddings `(F, 2048)`.
+ 3. **Prompt + scatter**: build the Qwen3-ASR chat template, expand the single
+    `<|audio_pad|>` placeholder to `F` slots, look up text embeddings via the
+    original HF model's `embed_tokens` weight, and scatter the audio embeddings in.
+ 4. **CoreML prefill**: feed each token's embedding one at a time to populate the
+    KV cache state.
+ 5. **CoreML decode**: greedy step-by-step until `<|im_end|>`.
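Step 3 can be sketched in plain Python as follows. This illustrates the expand-and-scatter idea only and is not the code in `inference_asr.py`: the function names, the toy 2-dimensional embeddings, and the placeholder ID `9` are all hypothetical (the real pipeline works with 2048-dimensional embeddings and the tokenizer's own `<|audio_pad|>` ID).

```python
from typing import Dict, List, Sequence

Vector = List[float]


def expand_audio_pad(token_ids: Sequence[int], audio_pad_id: int, num_frames: int) -> List[int]:
    """Replace the single audio placeholder with `num_frames` copies of itself."""
    if list(token_ids).count(audio_pad_id) != 1:
        raise ValueError("expected exactly one audio placeholder in the prompt")
    expanded: List[int] = []
    for tok in token_ids:
        expanded.extend([tok] * num_frames if tok == audio_pad_id else [tok])
    return expanded


def scatter_audio(
    token_ids: Sequence[int],
    embed_table: Dict[int, Vector],
    audio_embeds: Sequence[Vector],
    audio_pad_id: int,
) -> List[Vector]:
    """Look up text embeddings, then overwrite placeholder rows with audio frames."""
    slots = [i for i, tok in enumerate(token_ids) if tok == audio_pad_id]
    if len(slots) != len(audio_embeds):
        raise ValueError("placeholder slots and audio frames must match in number")
    rows = [list(embed_table[tok]) for tok in token_ids]
    for slot, frame in zip(slots, audio_embeds):
        rows[slot] = list(frame)
    return rows


# Toy data: hypothetical IDs and 2-dim embeddings, for illustration only.
AUDIO_PAD = 9
embed_table = {1: [1.0, 1.0], 2: [2.0, 2.0], AUDIO_PAD: [0.0, 0.0]}
audio_frames = [[0.5, 0.5], [0.6, 0.6], [0.7, 0.7]]  # F = 3 encoder frames

prompt_ids = expand_audio_pad([1, AUDIO_PAD, 2], AUDIO_PAD, len(audio_frames))
inputs_embeds = scatter_audio(prompt_ids, embed_table, audio_frames, AUDIO_PAD)
print(prompt_ids)
print(inputs_embeds)
```

The length check in `scatter_audio` is there because a mismatch between `F` and the number of expanded slots would silently misalign audio and text positions.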

+ The KV cache lives inside the CoreML model as `state`. Call `model.make_state()`
+ once per request, then pass the same state object to every `predict()` call.
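The prefill and decode steps share one loop shape, sketched below. The `step` callable is hypothetical: in a real pipeline it would wrap a `predict()` call that receives the state object created once by `model.make_state()`, while here a scripted stand-in replaces the model so the control flow can be run anywhere. The token IDs, vocabulary size, and `embed_token` helper are made up for the example.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def greedy_transcribe(
    prompt_embeds: Sequence[Vector],
    step: Callable[[Vector, int], List[float]],
    embed_token: Callable[[int], Vector],
    eos_id: int,
    max_context: int = 512,
) -> List[int]:
    """Prefill one embedding at a time, then decode greedily until EOS.

    `step(embedding, position)` returns logits and is assumed to update the
    KV cache as a side effect.
    """
    if not 0 < len(prompt_embeds) <= max_context:
        raise ValueError("prompt must be non-empty and fit in the context window")
    logits: List[float] = []
    pos = 0
    for emb in prompt_embeds:  # prefill: populate the cache
        logits = step(emb, pos)
        pos += 1
    generated: List[int] = []
    while pos < max_context:  # decode: argmax, feed back, repeat
        next_id = max(range(len(logits)), key=logits.__getitem__)
        if next_id == eos_id:
            break
        generated.append(next_id)
        logits = step(embed_token(next_id), pos)
        pos += 1
    return generated


# Scripted stand-in for the model: emits tokens 3, 4, then EOS (id 0).
positions_seen: List[int] = []
script = [3, 4, 0]


def fake_step(embedding: Vector, position: int) -> List[float]:
    positions_seen.append(position)
    index = max(0, len(positions_seen) - 2)  # the toy prompt has 2 embeddings
    target = script[min(index, len(script) - 1)]
    return [1.0 if i == target else 0.0 for i in range(5)]


tokens = greedy_transcribe([[0.0], [0.0]], fake_step, lambda t: [float(t)], eos_id=0)
print(tokens, positions_seen)
```

Keeping the cache behind `step` mirrors the single-state-per-request rule: the loop never touches the cache directly, it only advances the position.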
 
 

+ ## Conversion details

+ A two-step monkey-patch in `convert_embeds.py` lets ANEMLL's Qwen3 conversion
+ accept pre-embedded inputs:

+ ```python
+ # 1. QwenModel.forward: detect float input_ids and skip embed_tokens
+ qm.QwenModel.forward = model_forward_or_embeds

+ # 2. QwenForCausalLM.forward: relax the 2D assert; replicate lm_head logic
+ qm.QwenForCausalLM.forward = causal_forward_or_embeds
  ```

+ ANEMLL's CoreML conversion then traces with a `WrapperEmbeds` module whose
+ inputs are `(inputs_embeds, position_ids, causal_mask, current_pos, update_mask)`.
+ `coremltools.optimize.coreml.palettize_weights` applies LUT-4 with
+ `per_grouped_channel` / `group_size=8`.
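The dispatch idea behind the first patch can be shown with a toy class. This is not ANEMLL's `QwenModel` and not the contents of `convert_embeds.py`: `ToyModel`, its `layers` method, and the list-based type check are hypothetical stand-ins for a tensor dtype check.

```python
from typing import Dict, List, Sequence, Union

Vector = List[float]


class ToyModel:
    """Stand-in for a decoder whose forward() only accepts token IDs."""

    def __init__(self, embed_tokens: Dict[int, Vector]) -> None:
        self.embed_tokens = embed_tokens

    def layers(self, hidden: List[Vector]) -> List[Vector]:
        # Placeholder for the transformer stack: doubles every value.
        return [[2.0 * v for v in row] for row in hidden]

    def forward(self, input_ids: Sequence[int]) -> List[Vector]:
        hidden = [self.embed_tokens[i] for i in input_ids]
        return self.layers(hidden)


_original_forward = ToyModel.forward


def forward_or_embeds(self: ToyModel, inputs: Sequence[Union[int, Vector]]) -> List[Vector]:
    """Accept either token IDs or pre-embedded rows, skipping the lookup for the latter."""
    if inputs and isinstance(inputs[0], (list, tuple)):
        return self.layers([list(row) for row in inputs])
    return _original_forward(self, inputs)


# Monkey-patch: replace the method on the class, so existing instances pick it up.
ToyModel.forward = forward_or_embeds

model = ToyModel({0: [1.0, 2.0]})
from_ids = model.forward([0])
from_embeds = model.forward([[0.5, 0.5]])
print(from_ids, from_embeds)
```

Falling back to the original method for integer inputs keeps the text-only path unchanged, which is one reason to patch rather than rewrite the class.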
+
+ **Key compute-precision tweak**: `compute_precision=ct.precision.FLOAT32`
+ in `ct.convert`. fp16 compute produces all-NaN logits on Qwen3-ASR's
+ RMSNorm + attention layers, the same finding as the aoiandroid community
+ CoreML port. Weights stay LUT-4 (4-bit storage); only activations run in fp32.
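One general way fp16 arithmetic turns into NaN is sketched below: a product exceeds the half-precision maximum (65504), becomes `inf`, and a later `inf - inf` inside softmax yields NaN. This is an illustration of the mechanism with made-up values, not a trace of where the actual model overflows. The helpers emulate rounding with the standard `struct` formats `e` (fp16) and `f` (fp32).

```python
import math
import struct
from typing import Callable, List, Sequence


def _round_to(fmt: str, x: float) -> float:
    """Round x to the IEEE format `fmt`; out-of-range values saturate to inf."""
    if math.isinf(x) or math.isnan(x):
        return x
    try:
        return struct.unpack(fmt, struct.pack(fmt, x))[0]
    except OverflowError:
        # struct rejects out-of-range values; hardware would produce +/-inf.
        return math.copysign(math.inf, x)


def to_fp16(x: float) -> float:
    return _round_to("<e", x)


def to_fp32(x: float) -> float:
    return _round_to("<f", x)


def dot(a: Sequence[float], b: Sequence[float], cast: Callable[[float], float]) -> float:
    """Dot product with every intermediate rounded to the given precision."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = cast(acc + cast(cast(x) * cast(y)))
    return acc


def softmax(scores: Sequence[float]) -> List[float]:
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]  # inf - inf evaluates to NaN
    total = sum(exps)
    return [e / total for e in exps]


# Made-up activations: large enough that 300 * 300 = 90000 exceeds the fp16 range.
query = [300.0, 300.0]
keys = [[300.0, 300.0], [1.0, 1.0]]

scores16 = [dot(query, k, to_fp16) for k in keys]
scores32 = [dot(query, k, to_fp32) for k in keys]
probs16 = softmax(scores16)
probs32 = softmax(scores32)
print("fp16:", scores16, probs16)
print("fp32:", scores32, probs32)
```

In the fp32 run the same inputs stay finite, which matches the README's observation that raising compute precision alone removes the NaNs while weights remain 4-bit.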
+
+ Also patched: `coremltools/converters/mil/frontend/torch/ops.py` `_cast` op
+ handler (numpy array of size 1 → extract scalar via `.flatten()[0].item()`).
+ The diff lives in the `convert_embeds.py` setup notes.
+
+ ## Known limitations
+
+ 1. **CPU compute only** in practice. CoreML's ANE compiler rejects this model
+    (`MILCompilerForANE error: failed to compile ANE model using ANEF`), likely
+    due to model size + stateful KV cache. CPU_AND_NE / ALL fail to load;
+    CPU_ONLY works and is correct. Per-token latency is ~1.5 s on CPU.
+ 2. **Audio encoder is ONNX**. The 24-layer Whisper-style encoder hasn't been
+    ported to CoreML (ANEMLL is LLM-only). End-to-end inference runs the
+    encoder via `onnxruntime` and the LLM via `coremltools`.
+ 3. **Quality below ONNX/MLX** at 4-bit due to LUT-4 k-means being weaker than
+    GPTQ on this architecture. Mitigations: use LUT-6 (`--lut 6` in the
+    converter) to recover ~3% at +50% size, or use the fp16 variant
+    (`mega-asr-llm-embeds_fp16.mlpackage`, ~3.2 GB) for full quality.
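Given the first limitation, a caller may want to try compute units in order of preference and keep the first that loads. The helper below is a hypothetical sketch of that pattern and is not part of `inference_asr.py`. It takes the loader as a parameter so it runs without a model; with coremltools the loader would presumably be something like `lambda u: ct.models.MLModel(path, compute_units=u)` with `ct.ComputeUnit.CPU_AND_NE` and `ct.ComputeUnit.CPU_ONLY`. Catching a broad `Exception` is an assumption, since the README does not say which exception type the failed load raises.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

Unit = TypeVar("Unit")
Model = TypeVar("Model")


def load_with_fallback(
    load: Callable[[Unit], Model],
    compute_units: Sequence[Unit],
) -> Tuple[Unit, Model]:
    """Try each compute unit in order; return the first that loads."""
    errors: List[str] = []
    for unit in compute_units:
        try:
            return unit, load(unit)
        except Exception as exc:  # assumption: load failures surface as exceptions
            errors.append(f"{unit}: {exc}")
    raise RuntimeError("no compute unit could load the model: " + "; ".join(errors))


# Stand-in loader that mimics the behaviour reported above: only CPU_ONLY loads.
def fake_load(unit: str) -> str:
    if unit != "CPU_ONLY":
        raise RuntimeError("failed to compile ANE model using ANEF")
    return f"model on {unit}"


chosen_unit, loaded = load_with_fallback(fake_load, ["CPU_AND_NE", "CPU_ONLY"])
print(chosen_unit, loaded)
```

Collecting the per-unit error messages keeps the ANE compiler failure visible even when the CPU fallback succeeds silently.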

  ## Companion repos

+ - [Reza2kn/mega-asr-onnx](https://huggingface.co/Reza2kn/mega-asr-onnx): full ONNX pipeline (GPTQ-INT4, 92.7%)
+ - [Reza2kn/mega-asr-mlx](https://huggingface.co/Reza2kn/mega-asr-mlx): MLX 4-bit (mixed 8/4 attn/MLP, 92.2%)
+ - [Reza2kn/mega-asr-bench](https://huggingface.co/spaces/Reza2kn/mega-asr-bench): browser demo (WebGPU)

  ## Credits

+ - Original model: [zhifeixie/Mega-ASR](https://huggingface.co/zhifeixie/Mega-ASR) (1.7B, Apache-2.0)
+ - CoreML conversion via [ANEMLL](https://github.com/Anemll/Anemll) with a custom input_embeds patch
  - Benchmark: [Voices-in-the-Wild-Bench](https://github.com/xzf-thu/Voices-in-the-Wild-Bench)