End-to-end CoreML ASR works (86.9% on VITW); document input_embeds fork + fp32 compute fix
Browse files
README.md
CHANGED
|
@@ -33,114 +33,141 @@ base_model: zhifeixie/Mega-ASR
|
|
| 33 |
base_model_relation: quantized
|
| 34 |
---
|
| 35 |
|
| 36 |
-
# Mega-ASR β CoreML LUT-4 (
|
| 37 |
|
| 38 |
-
CoreML LUT-4 (4-bit lookup-table palettized)
|
| 39 |
-
[zhifeixie/Mega-ASR](https://huggingface.co/zhifeixie/Mega-ASR)
|
| 40 |
-
|
| 41 |
-
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
## What's in this repo
|
| 48 |
|
| 49 |
| File | Size | Role |
|
| 50 |
| --- | ---: | --- |
|
| 51 |
-
| `coreml/mega-asr-
|
| 52 |
-
| `
|
|
|
|
| 53 |
| `tokenizer/*` | β | Original Qwen3-ASR tokenizer (`<\|audio_pad\|>`, `<asr_text>`, etc.) |
|
| 54 |
| `examples/*.wav` | ~3 MB | 8 noisy benchmark clips from Voices-in-the-Wild-Bench |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
| `causal_mask` | `(1, 1, 1, 512)` | float16 |
|
| 66 |
-
| `current_pos` | `(1,)` | int32 |
|
| 67 |
-
| `update_mask` | `(1, 1, 512, 1)` | float16 |
|
| 68 |
-
|
| 69 |
-
**Outputs**: `logits1` β¦ `logits16`, each `(1, 1, 9496)` float16 β concat
|
| 70 |
-
along last axis to get the 151936-dim vocabulary.
|
| 71 |
-
|
| 72 |
-
**State**: `model_model_kv_cache_0` β shape `(56, 8, 512, 128)` float16 (28
|
| 73 |
-
layers Γ 2 (K/V) Γ 8 KV heads Γ 512 max context Γ 128 head dim). Create with
|
| 74 |
-
`model.make_state()` and pass to every `predict()`.
|
| 75 |
-
|
| 76 |
-
## Quick run (Python)
|
| 77 |
-
|
| 78 |
-
```python
|
| 79 |
-
import coremltools as ct
|
| 80 |
-
import numpy as np
|
| 81 |
-
|
| 82 |
-
m = ct.models.MLModel("coreml/mega-asr-llm_lut4.mlpackage",
|
| 83 |
-
compute_units=ct.ComputeUnit.CPU_AND_NE)
|
| 84 |
-
state = m.make_state()
|
| 85 |
-
out = m.predict({
|
| 86 |
-
"input_ids": np.array([[40]], dtype=np.int32), # token 'I'
|
| 87 |
-
"position_ids": np.array([0], dtype=np.int32),
|
| 88 |
-
"causal_mask": np.zeros((1, 1, 1, 512), dtype=np.float16),
|
| 89 |
-
"current_pos": np.array([0], dtype=np.int32),
|
| 90 |
-
"update_mask": np.zeros((1, 1, 512, 1), dtype=np.float16),
|
| 91 |
-
}, state=state)
|
| 92 |
-
all_logits = np.concatenate([out[f"logits{i}"][0, 0] for i in range(1, 17)])
|
| 93 |
```
|
| 94 |
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
**audio embeddings** at `<|audio_pad|>` placeholder positions, which means
|
| 100 |
-
the model needs to accept `input_embeddings` *instead of* `input_ids`.
|
| 101 |
|
| 102 |
-
|
| 103 |
-
hidden_states as the entry point, then re-running the conversion. (See
|
| 104 |
-
[`aoiandroid/Qwen3-ASR-1.7B-CoreML`](https://huggingface.co/aoiandroid/Qwen3-ASR-1.7B-CoreML)
|
| 105 |
-
for a prior community attempt of the same pattern; their decoder is named
|
| 106 |
-
`qwen3_asr_decoder_f32_anemll_int8-mixed.mlpackage` and pairs with a
|
| 107 |
-
separately stored `qwen3_asr_embeddings.bin`.)
|
| 108 |
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
prompt format the base model expects).
|
| 112 |
-
- A starting point for building an ANE-targeted Mega-ASR ASR pipeline by
|
| 113 |
-
re-converting with the embedding bypass.
|
| 114 |
|
| 115 |
-
|
|
|
|
|
|
|
| 116 |
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
python -m anemll.ane_converter.qwen_converter \
|
| 120 |
-
--model /path/to/Qwen3-ASR-1.7B-llm-only \
|
| 121 |
-
--prefix mega-asr-llm --lut 4 \
|
| 122 |
-
--context-length 512 --batch-size 64 --chunk 4 \
|
| 123 |
-
--output /path/to/out
|
| 124 |
```
|
| 125 |
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
`
|
| 133 |
-
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
## Companion repos
|
| 137 |
|
| 138 |
-
- [Reza2kn/mega-asr-onnx](https://huggingface.co/Reza2kn/mega-asr-onnx) β full ONNX pipeline (GPTQ-INT4
|
| 139 |
-
- [Reza2kn/mega-asr-mlx](https://huggingface.co/Reza2kn/mega-asr-mlx) β MLX 4-bit (
|
| 140 |
-
- [Reza2kn/mega-asr-bench](https://huggingface.co/spaces/Reza2kn/mega-asr-bench) β
|
| 141 |
|
| 142 |
## Credits
|
| 143 |
|
| 144 |
-
- Original model: [zhifeixie/Mega-ASR](https://huggingface.co/zhifeixie/Mega-ASR) (1.7B
|
| 145 |
-
- CoreML conversion via [ANEMLL](https://github.com/Anemll/Anemll)
|
| 146 |
- Benchmark: [Voices-in-the-Wild-Bench](https://github.com/xzf-thu/Voices-in-the-Wild-Bench)
|
|
|
|
| 33 |
base_model_relation: quantized
|
| 34 |
---
|
| 35 |
|
| 36 |
+
# Mega-ASR β CoreML LUT-4 (end-to-end ASR)
|
| 37 |
|
| 38 |
+
CoreML LUT-4 (4-bit lookup-table palettized) deployment of
|
| 39 |
+
[zhifeixie/Mega-ASR](https://huggingface.co/zhifeixie/Mega-ASR), with an
|
| 40 |
+
`input_embeds`-aware decoder so audio embeddings can be scattered at
|
| 41 |
+
`<|audio_pad|>` positions to do real ASR β not just text generation.
|
| 42 |
|
| 43 |
+
Converted via [ANEMLL](https://github.com/Anemll/Anemll) with a custom
|
| 44 |
+
`coreml_convert_embeds.py` that monkey-patches `QwenModel.forward` +
|
| 45 |
+
`QwenForCausalLM.forward` to accept pre-embedded `hidden_states` (skipping the
|
| 46 |
+
internal `embed_tokens` lookup). The model is single-token-step, stateful KV
|
| 47 |
+
cache (28 layers Γ 2 Γ 8 KV heads Γ 512 ctx Γ 128 head_dim, fp16), LUT-4
|
| 48 |
+
weights at `--per_channel 8`, and **fp32 compute precision** β `compute_precision=FLOAT16`
|
| 49 |
+
overflows in Qwen3-ASR's RMSNorm/attention layers and produces NaN logits.
|
| 50 |
|
| 51 |
## What's in this repo
|
| 52 |
|
| 53 |
| File | Size | Role |
|
| 54 |
| --- | ---: | --- |
|
| 55 |
+
| `coreml/mega-asr-llm-embeds_fp32compute_lut4.mlpackage/` | **826 MB** | **Recommended.** Qwen3 1.7B LLM, `inputs_embeds` input, fp32 compute, LUT-4 weights. Pair with the ONNX audio encoder for end-to-end ASR. |
|
| 56 |
+
| `coreml/mega-asr-llm_lut4.mlpackage/` | 974 MB | Original `input_ids` variant β standalone Qwen3 1.7B text LLM (no audio scatter). |
|
| 57 |
+
| `onnx/audio_encoder_fp32.onnx` | 1.27 GB | 24-layer Whisper-style audio encoder (ONNX, runs via onnxruntime; CoreML port pending) |
|
| 58 |
| `tokenizer/*` | β | Original Qwen3-ASR tokenizer (`<\|audio_pad\|>`, `<asr_text>`, etc.) |
|
| 59 |
| `examples/*.wav` | ~3 MB | 8 noisy benchmark clips from Voices-in-the-Wild-Bench |
|
| 60 |
+
| `inference_asr.py` | β | End-to-end ASR pipeline: ONNX encoder + CoreML LLM |
|
| 61 |
+
| `convert_embeds.py` | β | The custom converter (use to reproduce / re-quantize) |
|
| 62 |
+
|
| 63 |
+
## Quality (bench)
|
| 64 |
+
|
| 65 |
+
8-clip [Voices-in-the-Wild-Bench](https://github.com/xzf-thu/Voices-in-the-Wild-Bench)
|
| 66 |
+
agreement (1 β WER), prompt forced to `language English`, on M-series Mac
|
| 67 |
+
CPU (CPU_AND_NE failed to compile for ANE due to model size + state):
|
| 68 |
+
|
| 69 |
+
| Per-sample | Hyp β Ref? | Agreement |
|
| 70 |
+
| --- | --- | ---: |
|
| 71 |
+
| distortion | exact match | 100% |
|
| 72 |
+
| dropout | exact match | 100% |
|
| 73 |
+
| far_field | exact match | 100% |
|
| 74 |
+
| mixed | exact match | 100% |
|
| 75 |
+
| noise | exact match | 100% |
|
| 76 |
+
| obstructed | "i have forgotten" vs "i forgot" | 88.2% |
|
| 77 |
+
| echo (hard, heavy reverb) | "size 25 stand not and the 125 walk" | 47.1% |
|
| 78 |
+
| recording (hard, truncated audio) | "train stopped at the station" | 60.0% |
|
| 79 |
+
| **AVERAGE** | | **86.9%** |
|
| 80 |
+
|
| 81 |
+
For reference (same 8 samples, same audio encoder, same prompt):
|
| 82 |
+
|
| 83 |
+
| Backend | Agreement |
|
| 84 |
+
| --- | ---: |
|
| 85 |
+
| ONNX recommended (GPTQ) | 92.7% |
|
| 86 |
+
| MLX recommended (mixed 8/4) | 92.2% |
|
| 87 |
+
| **CoreML LUT-4 (this repo)** | **86.9%** |
|
| 88 |
+
| ONNX RTN INT4 baseline | 87.8% |
|
| 89 |
+
|
| 90 |
+
LUT-4 k-means is a more aggressive quantization than ONNX GPTQ (which uses
|
| 91 |
+
activation-aware error redistribution) or MLX mixed 8/4 (which keeps the
|
| 92 |
+
4 attention projections at 8-bit). The roughly **6% gap** vs the leaders is
|
| 93 |
+
concentrated on the 2 hard samples (`echo`, `recording`) and one near-miss
|
| 94 |
+
on `obstructed`. Six of eight samples produce exact-match transcriptions.
|
| 95 |
+
|
| 96 |
+
## Inference
|
| 97 |
|
| 98 |
+
```bash
|
| 99 |
+
pip install coremltools onnxruntime soundfile transformers safetensors librosa numpy
|
| 100 |
+
git clone https://huggingface.co/Reza2kn/mega-asr-coreml
|
| 101 |
+
cd mega-asr-coreml
|
| 102 |
+
python inference_asr.py \
|
| 103 |
+
--mlpackage coreml/mega-asr-llm-embeds_fp32compute_lut4.mlpackage \
|
| 104 |
+
--encoder-path onnx/audio_encoder_fp32.onnx \
|
| 105 |
+
--examples-dir examples \
|
| 106 |
+
--qwen-asr-dir <local path to Qwen3-ASR-1.7B HF dir>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
```
|
| 108 |
|
| 109 |
+
The pipeline runs:
|
| 110 |
+
1. **Mel features** via Qwen3-ASR's `WhisperFeatureExtractor`.
|
| 111 |
+
2. **Audio encoder** (ONNX fp32) β audio embeddings `(F, 2048)`.
|
| 112 |
+
3. **Prompt + scatter**: build the Qwen3-ASR chat template, expand the single
|
| 113 |
+
`<|audio_pad|>` placeholder to `F` slots, lookup text embeds via the
|
| 114 |
+
original HF model's `embed_tokens` weight, scatter audio embeds in.
|
| 115 |
+
4. **CoreML prefill**: feed each token's embedding one-at-a-time to populate the
|
| 116 |
+
KV cache state.
|
| 117 |
+
5. **CoreML decode**: greedy step-by-step until `<|im_end|>`.
|
| 118 |
|
| 119 |
+
The KV cache lives inside the CoreML model as `state`. Call `model.make_state()`
|
| 120 |
+
once per request, then pass the same state object to every `predict()` call.
|
|
|
|
|
|
|
| 121 |
|
| 122 |
+
## Conversion details
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
+
Two-step monkey-patch in `convert_embeds.py` lets ANEMLL's Qwen3 conversion
|
| 125 |
+
accept pre-embedded inputs:
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
+
```python
|
| 128 |
+
# 1. QwenModel.forward β detect float input_ids and skip embed_tokens
|
| 129 |
+
qm.QwenModel.forward = model_forward_or_embeds
|
| 130 |
|
| 131 |
+
# 2. QwenForCausalLM.forward β relax the 2D assert; replicate lm_head logic
|
| 132 |
+
qm.QwenForCausalLM.forward = causal_forward_or_embeds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
```
|
| 134 |
|
| 135 |
+
ANEMLL's CoreML conversion then traces with a `WrapperEmbeds` module whose
|
| 136 |
+
inputs are `(inputs_embeds, position_ids, causal_mask, current_pos, update_mask)`.
|
| 137 |
+
`coremltools.optimize.coreml.palettize_weights` applies LUT-4 with
|
| 138 |
+
`per_grouped_channel` / `group_size=8`.
|
| 139 |
+
|
| 140 |
+
**Key compute-precision tweak**: `compute_precision=ct.precision.FLOAT32`
|
| 141 |
+
in `ct.convert`. fp16 compute produces all-NaN logits on Qwen3-ASR's
|
| 142 |
+
RMSNorm + attention layers β same finding as the aoiandroid community
|
| 143 |
+
CoreML port. Weights stay LUT-4 (4-bit storage); only activations run fp32.
|
| 144 |
+
|
| 145 |
+
Also patched: `coremltools/converters/mil/frontend/torch/ops.py` `_cast` op
|
| 146 |
+
handler (numpy array of size 1 β extract scalar via `.flatten()[0].item()`).
|
| 147 |
+
Diff lives in `convert_embeds.py` setup notes.
|
| 148 |
+
|
| 149 |
+
## Known limitations
|
| 150 |
+
|
| 151 |
+
1. **CPU compute only** in practice. CoreML's ANE compiler rejects this model
|
| 152 |
+
(`MILCompilerForANE error: failed to compile ANE model using ANEF`) β likely
|
| 153 |
+
due to model size + stateful KV cache. CPU_AND_NE / ALL fail to load;
|
| 154 |
+
CPU_ONLY works and is correct. Per-token latency is ~1.5 s on CPU.
|
| 155 |
+
2. **Audio encoder is ONNX**. The 24-layer Whisper-style encoder hasn't been
|
| 156 |
+
ported to CoreML (ANEMLL is LLM-only). End-to-end inference runs the
|
| 157 |
+
encoder via `onnxruntime` and the LLM via `coremltools`.
|
| 158 |
+
3. **Quality below ONNX/MLX** at 4-bit due to LUT-4 k-means being weaker than
|
| 159 |
+
GPTQ on this architecture. Mitigations: use LUT-6 (`--lut 6` in the
|
| 160 |
+
converter) to recover ~3% at +50% size, or use the fp16 variant
|
| 161 |
+
(`mega-asr-llm-embeds_fp16.mlpackage`, ~3.2 GB) for full quality.
|
| 162 |
|
| 163 |
## Companion repos
|
| 164 |
|
| 165 |
+
- [Reza2kn/mega-asr-onnx](https://huggingface.co/Reza2kn/mega-asr-onnx) β full ONNX pipeline (GPTQ-INT4, 92.7%)
|
| 166 |
+
- [Reza2kn/mega-asr-mlx](https://huggingface.co/Reza2kn/mega-asr-mlx) β MLX 4-bit (mixed 8/4 attn/MLP, 92.2%)
|
| 167 |
+
- [Reza2kn/mega-asr-bench](https://huggingface.co/spaces/Reza2kn/mega-asr-bench) β browser demo (WebGPU)
|
| 168 |
|
| 169 |
## Credits
|
| 170 |
|
| 171 |
+
- Original model: [zhifeixie/Mega-ASR](https://huggingface.co/zhifeixie/Mega-ASR) (1.7B, Apache-2.0)
|
| 172 |
+
- CoreML conversion via [ANEMLL](https://github.com/Anemll/Anemll) with a custom input_embeds patch
|
| 173 |
- Benchmark: [Voices-in-the-Wild-Bench](https://github.com/xzf-thu/Voices-in-the-Wild-Bench)
|