Reza2kn committed on
Commit
005e85f
·
verified ·
1 Parent(s): 3639f53

Add mixed 8/4 CoreML (90.6% on VITW) — new recommended variant

  1. README.md +102 -78
README.md CHANGED
@@ -18,6 +18,8 @@ tags:
18
  - quantized
19
  - int4
20
  - 4bit
 
 
21
  - lut
22
  - palettize
23
  - on-device
@@ -33,65 +35,78 @@ base_model: zhifeixie/Mega-ASR
33
  base_model_relation: quantized
34
  ---
35
 
36
- # Mega-ASR — CoreML LUT-4 (end-to-end ASR)
37
 
38
- CoreML LUT-4 (4-bit lookup-table palettized) deployment of
39
- [zhifeixie/Mega-ASR](https://huggingface.co/zhifeixie/Mega-ASR), with an
40
- `input_embeds`-aware decoder so audio embeddings can be scattered at
41
- `<|audio_pad|>` positions to do real ASR — not just text generation.
42
 
43
  Converted via [ANEMLL](https://github.com/Anemll/Anemll) with a custom
44
- `coreml_convert_embeds.py` that monkey-patches `QwenModel.forward` +
45
- `QwenForCausalLM.forward` to accept pre-embedded `hidden_states` (skipping the
46
- internal `embed_tokens` lookup). The model is single-token-step, stateful KV
47
- cache (28 layers × 2 × 8 KV heads × 512 ctx × 128 head_dim, fp16), LUT-4
48
- weights at `--per_channel 8`, and **fp32 compute precision**, because `compute_precision=FLOAT16`
49
- overflows in Qwen3-ASR's RMSNorm/attention layers and produces NaN logits.
 
 
 
 
 
50
 
51
  ## What's in this repo
52
 
53
  | File | Size | Role |
54
  | --- | ---: | --- |
55
- | `coreml/mega-asr-llm-embeds_fp32compute_lut4.mlpackage/` | **826 MB** | **Recommended.** Qwen3 1.7B LLM, `inputs_embeds` input, fp32 compute, LUT-4 weights. Pair with the ONNX audio encoder for end-to-end ASR. |
56
- | `coreml/mega-asr-llm_lut4.mlpackage/` | 974 MB | Original `input_ids` variant: standalone Qwen3 1.7B text LLM (no audio scatter). |
 
57
  | `onnx/audio_encoder_fp32.onnx` | 1.27 GB | 24-layer Whisper-style audio encoder (ONNX, runs via onnxruntime; CoreML port pending) |
58
  | `tokenizer/*` | — | Original Qwen3-ASR tokenizer (`<\|audio_pad\|>`, `<asr_text>`, etc.) |
59
  | `examples/*.wav` | ~3 MB | 8 noisy benchmark clips from Voices-in-the-Wild-Bench |
60
- | `inference_asr.py` | — | End-to-end ASR pipeline: ONNX encoder + CoreML LLM |
61
- | `convert_embeds.py` | — | The custom converter (use to reproduce / re-quantize) |
62
 
63
  ## Quality (bench)
64
 
65
  8-clip [Voices-in-the-Wild-Bench](https://github.com/xzf-thu/Voices-in-the-Wild-Bench)
66
- agreement (1 − WER), prompt forced to `language English`, on M-series Mac
67
- CPU (CPU_AND_NE failed to compile for ANE due to model size + state):
68
-
69
- | Per-sample | Hyp ≈ Ref? | Agreement |
70
- | --- | --- | ---: |
71
- | distortion | exact match | 100% |
72
- | dropout | exact match | 100% |
73
- | far_field | exact match | 100% |
74
- | mixed | exact match | 100% |
75
- | noise | exact match | 100% |
76
- | obstructed | "i have forgotten" vs "i forgot" | 88.2% |
77
- | echo (hard, heavy reverb) | "size 25 stand not and the 125 walk" | 47.1% |
78
- | recording (hard, truncated audio) | "train stopped at the station" | 60.0% |
79
- | **AVERAGE** | | **86.9%** |
80
-
81
- For reference (same 8 samples, same audio encoder, same prompt):
 
 
 
 
 
 
 
82
 
83
  | Backend | Agreement |
84
  | --- | ---: |
85
- | ONNX recommended (GPTQ) | 92.7% |
86
  | MLX recommended (mixed 8/4) | 92.2% |
87
- | **CoreML LUT-4 (this repo)** | **86.9%** |
 
88
  | ONNX RTN INT4 baseline | 87.8% |
89
 
90
- LUT-4 k-means is a more aggressive quantization than ONNX GPTQ (which uses
91
- activation-aware error redistribution) or MLX mixed 8/4 (which keeps the
92
- 4 attention projections at 8-bit). The roughly **6% gap** vs the leaders is
93
- concentrated on the 2 hard samples (`echo`, `recording`) and one near-miss
94
- on `obstructed`. Six of eight samples produce exact-match transcriptions.
95
 
96
  ## Inference
97
 
@@ -100,65 +115,74 @@ pip install coremltools onnxruntime soundfile transformers safetensors librosa n
100
  git clone https://huggingface.co/Reza2kn/mega-asr-coreml
101
  cd mega-asr-coreml
102
  python inference_asr.py \
103
- --mlpackage coreml/mega-asr-llm-embeds_fp32compute_lut4.mlpackage \
104
  --encoder-path onnx/audio_encoder_fp32.onnx \
105
  --examples-dir examples \
106
- --qwen-asr-dir <local path to Qwen3-ASR-1.7B HF dir>
 
107
  ```
108
 
109
- The pipeline runs:
110
  1. **Mel features** via Qwen3-ASR's `WhisperFeatureExtractor`.
111
  2. **Audio encoder** (ONNX fp32) → audio embeddings `(F, 2048)`.
112
- 3. **Prompt + scatter**: build the Qwen3-ASR chat template, expand the single
113
- `<|audio_pad|>` placeholder to `F` slots, lookup text embeds via the
114
- original HF model's `embed_tokens` weight, scatter audio embeds in.
115
- 4. **CoreML prefill**: feed each token's embedding one-at-a-time to populate the
116
- KV cache state.
 
117
  5. **CoreML decode**: greedy step-by-step until `<|im_end|>`.
118
 
119
- The KV cache lives inside the CoreML model as `state`. Call `model.make_state()`
120
- once per request, then pass the same state object to every `predict()` call.
 
121
 
122
  ## Conversion details
123
 
124
- Two-step monkey-patch in `convert_embeds.py` lets ANEMLL's Qwen3 conversion
125
- accept pre-embedded inputs:
126
-
127
  ```python
128
- # 1. QwenModel.forward detect float input_ids and skip embed_tokens
129
- qm.QwenModel.forward = model_forward_or_embeds
130
-
131
- # 2. QwenForCausalLM.forward relax the 2D assert; replicate lm_head logic
132
- qm.QwenForCausalLM.forward = causal_forward_or_embeds
 
 
 
 
 
 
 
 
 
 
133
  ```
134
 
135
- ANEMLL's CoreML conversion then traces with a `WrapperEmbeds` module whose
136
- inputs are `(inputs_embeds, position_ids, causal_mask, current_pos, update_mask)`.
137
- `coremltools.optimize.coreml.palettize_weights` applies LUT-4 with
138
- `per_grouped_channel` / `group_size=8`.
139
 
140
- **Key compute-precision tweak**: `compute_precision=ct.precision.FLOAT32`
141
- in `ct.convert`. fp16 compute produces all-NaN logits on Qwen3-ASR's
142
- RMSNorm + attention layers — same finding as the aoiandroid community
143
- CoreML port. Weights stay LUT-4 (4-bit storage); only activations run fp32.
144
 
145
- Also patched: `coremltools/converters/mil/frontend/torch/ops.py` `_cast` op
146
- handler (numpy array of size 1 → extract scalar via `.flatten()[0].item()`).
147
- Diff lives in `convert_embeds.py` setup notes.
 
148
 
149
  ## Known limitations
150
 
151
- 1. **CPU compute only** in practice. CoreML's ANE compiler rejects this model
152
- (`MILCompilerForANE error: failed to compile ANE model using ANEF`) — likely
153
- due to model size + stateful KV cache. CPU_AND_NE / ALL fail to load;
154
- CPU_ONLY works and is correct. Per-token latency is ~1.5 s on CPU.
155
- 2. **Audio encoder is ONNX**. The 24-layer Whisper-style encoder hasn't been
156
- ported to CoreML (ANEMLL is LLM-only). End-to-end inference runs the
 
157
  encoder via `onnxruntime` and the LLM via `coremltools`.
158
- 3. **Quality below ONNX/MLX** at 4-bit due to LUT-4 k-means being weaker than
159
- GPTQ on this architecture. Mitigations: use LUT-6 (`--lut 6` in the
160
- converter) to recover ~3% at +50% size, or use the fp16 variant
161
- (`mega-asr-llm-embeds_fp16.mlpackage`, ~3.2 GB) for full quality.
162
 
163
  ## Companion repos
164
 
@@ -169,5 +193,5 @@ Diff lives in `convert_embeds.py` setup notes.
169
  ## Credits
170
 
171
  - Original model: [zhifeixie/Mega-ASR](https://huggingface.co/zhifeixie/Mega-ASR) (1.7B, Apache-2.0)
172
- - CoreML conversion via [ANEMLL](https://github.com/Anemll/Anemll) with a custom input_embeds patch
173
  - Benchmark: [Voices-in-the-Wild-Bench](https://github.com/xzf-thu/Voices-in-the-Wild-Bench)
 
18
  - quantized
19
  - int4
20
  - 4bit
21
+ - 8bit
22
+ - mixed-precision
23
  - lut
24
  - palettize
25
  - on-device
 
35
  base_model_relation: quantized
36
  ---
37
 
38
+ # Mega-ASR — CoreML mixed 8/4 (end-to-end ASR)
39
 
40
+ CoreML deployment of [zhifeixie/Mega-ASR](https://huggingface.co/zhifeixie/Mega-ASR)
41
+ (Qwen3-ASR-1.7B base) with an **`input_embeds`-aware decoder** so audio
42
+ embeddings can be scattered at `<|audio_pad|>` positions to do real ASR —
43
+ not just text generation.
44
 
45
  Converted via [ANEMLL](https://github.com/Anemll/Anemll) with a custom
46
+ `convert_embeds_mixed.py` that:
47
+ 1. Monkey-patches `QwenModel.forward` + `QwenForCausalLM.forward` to accept
48
+ pre-embedded `hidden_states` (skipping the internal `embed_tokens`
49
+ lookup) so audio scatter works at inference.
50
+ 2. Enumerates the MIL program's const-weight ops by name pattern and applies
51
+ **LUT-8 palettization to attention projections** (q/k/v/o_proj) and
52
+ **LUT-4 to MLP projections** (gate/up/down_proj) — mirroring the MLX
53
+ `mixed8_4` recipe that closed the gap to GPTQ on the LLM portion.
54
+ 3. Runs `compute_precision=FLOAT32` — fp16 compute precision produces
55
+ all-NaN logits on Qwen3-ASR's RMSNorm/attention (matches the aoiandroid
56
+ community finding for the same base model).
57
 
58
  ## What's in this repo
59
 
60
  | File | Size | Role |
61
  | --- | ---: | --- |
62
+ | `coreml/mega-asr-llm-embeds_mixed8_4.mlpackage/` | **1.87 GB** | **Recommended.** Qwen3 1.7B LLM, `inputs_embeds` input, fp32 compute, 8-bit attn + 4-bit MLP, ~5.0 bpw avg. |
63
+ | `coreml/mega-asr-llm-embeds_fp32compute_lut4.mlpackage/` | 826 MB | Smaller variant. Uniform LUT-4 weights. -3.7% agreement vs mixed. |
64
+ | `coreml/mega-asr-llm_lut4.mlpackage/` | 974 MB | Standalone Qwen3 text LLM with `input_ids` input (no audio scatter). |
65
  | `onnx/audio_encoder_fp32.onnx` | 1.27 GB | 24-layer Whisper-style audio encoder (ONNX, runs via onnxruntime; CoreML port pending) |
66
  | `tokenizer/*` | — | Original Qwen3-ASR tokenizer (`<\|audio_pad\|>`, `<asr_text>`, etc.) |
67
  | `examples/*.wav` | ~3 MB | 8 noisy benchmark clips from Voices-in-the-Wild-Bench |
68
+ | `inference_asr.py` | — | End-to-end ASR pipeline (ONNX encoder + CoreML LLM) |
69
+ | `convert_embeds.py` / `convert_embeds_mixed.py` | — | The custom converters |
70
 
71
  ## Quality (bench)
72
 
73
  8-clip [Voices-in-the-Wild-Bench](https://github.com/xzf-thu/Voices-in-the-Wild-Bench)
74
+ agreement (1 − WER), prompt forced to `language English`, ONNX fp32
75
+ audio encoder + the CoreML LLM, run with `compute_units=ALL` (Metal GPU
76
+ since ANE compilation fails on this model size + stateful KV cache):
77
+
78
+ | Per-sample | Mixed 8/4 (recommended) | Uniform LUT-4 |
79
+ | --- | ---: | ---: |
80
+ | distortion | 100% | 100% |
81
+ | dropout | 100% | 100% |
82
+ | echo (hard, heavy reverb) | **64.7%** | 47.1% |
83
+ | far_field | 100% | 100% |
84
+ | mixed | 100% | 100% |
85
+ | noise | 100% | 100% |
86
+ | obstructed | **100%** | 88.2% |
87
+ | recording (hard, truncated audio) | 60.0% | 60.0% |
88
+ | **AVERAGE** | **90.6%** | 86.9% |
89
+
90
+ Mixed 8/4 lifts CoreML from 86.9% → 90.6% (+3.7) by allocating the 4
91
+ attention projections per layer to LUT-8 (256 unique values for every 8
92
+ channels) while keeping the 3 MLP projections at LUT-4 (16 unique values
93
+ per 8 channels). Attention layers in Qwen3 are quality-critical — same
94
+ result we found in the MLX port.
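
The "~5.0 bpw avg" figure can be sanity-checked with a little arithmetic. The sketch below assumes the stock Qwen3-1.7B shapes (hidden size 2048, 16 query heads, intermediate size 6144); only the hidden size, the 8 KV heads and the 128 head_dim are stated in this README, so treat the other two as assumptions. It counts projection weights only and ignores LUT tables, norms, embeddings and `lm_head`.

```python
# Rough bits-per-weight estimate for the mixed 8/4 recipe.
# ASSUMPTIONS: N_Q_HEADS and INTERMEDIATE are the usual Qwen3-1.7B values,
# not numbers taken from this repository.
HIDDEN = 2048
HEAD_DIM = 128
N_KV_HEADS = 8
N_Q_HEADS = 16        # assumption
INTERMEDIATE = 6144   # assumption

q_out = N_Q_HEADS * HEAD_DIM    # q_proj / o_proj width
kv_out = N_KV_HEADS * HEAD_DIM  # k_proj / v_proj width (GQA)

# Per-layer parameter counts for the two groups.
attn_params = HIDDEN * q_out + 2 * HIDDEN * kv_out + q_out * HIDDEN  # q, k, v, o
mlp_params = 3 * HIDDEN * INTERMEDIATE                               # gate, up, down

# Weighted average of 8-bit attention and 4-bit MLP indices.
avg_bits = (8 * attn_params + 4 * mlp_params) / (attn_params + mlp_params)
print(f"attention share: {attn_params / (attn_params + mlp_params):.2%}")
print(f"average bits per weight: {avg_bits:.2f}")
```

Under these assumptions attention holds a quarter of the projection weights, which is why moving it to 8-bit costs only about one extra bit per weight on average.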
95
+
96
+ Cross-backend leaderboard (same 8 samples, same audio encoder):
97
 
98
  | Backend | Agreement |
99
  | --- | ---: |
100
+ | ONNX recommended (GPTQ INT4) | 92.7% |
101
  | MLX recommended (mixed 8/4) | 92.2% |
102
+ | **CoreML recommended (mixed 8/4)** | **90.6%** |
103
+ | CoreML LUT-4 baseline | 86.9% |
104
  | ONNX RTN INT4 baseline | 87.8% |
105
 
106
+ The remaining ~2% gap to ONNX/MLX is the LUT-vs-GPTQ scheme difference
107
+ (k-means clustering vs activation-aware Hessian redistribution). The two
108
+ hard samples (`echo`, `recording`) are audio-quality-limited and stay
109
+ around 60-65% across all 4-bit backends.
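
For readers who want to reproduce the numbers, here is a minimal sketch of the agreement metric (1 − WER) and of the averages in the per-sample table. The `word_error_rate` helper is illustrative, not the benchmark's own scorer: it splits on whitespace and applies no text normalization, which the real evaluation may do differently.

```python
def word_error_rate(reference: str, hypothesis: str) -> float:
    """Word-level Levenshtein distance divided by the reference length."""
    ref, hyp = reference.split(), hypothesis.split()
    prev = list(range(len(hyp) + 1))  # distances against an empty reference
    for i, ref_word in enumerate(ref, start=1):
        curr = [i]
        for j, hyp_word in enumerate(hyp, start=1):
            substitution = prev[j - 1] + (0 if ref_word == hyp_word else 1)
            curr.append(min(prev[j] + 1, curr[j - 1] + 1, substitution))
        prev = curr
    return prev[-1] / max(len(ref), 1)


def agreement(reference: str, hypothesis: str) -> float:
    return 1.0 - word_error_rate(reference, hypothesis)


# Per-sample agreement values copied from the table above (percent).
mixed_8_4 = [100, 100, 64.7, 100, 100, 100, 100, 60.0]
uniform_lut4 = [100, 100, 47.1, 100, 100, 100, 88.2, 60.0]

mean_mixed = sum(mixed_8_4) / len(mixed_8_4)
mean_lut4 = sum(uniform_lut4) / len(uniform_lut4)
print(round(mean_mixed, 1), round(mean_lut4, 1))
```

The two means round to the 90.6% and 86.9% reported in the AVERAGE row.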
 
110
 
111
  ## Inference
112
 
 
115
  git clone https://huggingface.co/Reza2kn/mega-asr-coreml
116
  cd mega-asr-coreml
117
  python inference_asr.py \
118
+ --mlpackage coreml/mega-asr-llm-embeds_mixed8_4.mlpackage \
119
  --encoder-path onnx/audio_encoder_fp32.onnx \
120
  --examples-dir examples \
121
+ --qwen-asr-dir <local path to Qwen3-ASR-1.7B HF dir> \
122
+ --compute-unit ALL
123
  ```
124
 
125
+ The pipeline:
126
  1. **Mel features** via Qwen3-ASR's `WhisperFeatureExtractor`.
127
  2. **Audio encoder** (ONNX fp32) → audio embeddings `(F, 2048)`.
128
+ 3. **Prompt + scatter**: build the Qwen3-ASR chat template with English
129
+ forcing, expand the single `<|audio_pad|>` placeholder to F slots,
130
+ look up text embeds via the HF model's `embed_tokens` weight, scatter
131
+ audio embeds at the placeholder positions.
132
+ 4. **CoreML prefill**: feed each token's embedding one-at-a-time to
133
+ populate the in-model KV cache state.
134
  5. **CoreML decode**: greedy step-by-step until `<|im_end|>`.
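
Step 3 of the pipeline above can be sketched with plain NumPy. This is an illustration under assumptions, not the code in `inference_asr.py`: the function name, the toy vocabulary and the placeholder id `9` are all hypothetical.

```python
import numpy as np


def scatter_audio_embeds(token_ids, embed_table, audio_embeds, audio_pad_id):
    """Build inputs_embeds: text rows from embed_table, audio rows at pad slots."""
    token_ids = np.asarray(token_ids)
    # Fancy indexing returns a fresh (T, H) array, so the table is not modified.
    inputs_embeds = embed_table[token_ids]
    slots = np.flatnonzero(token_ids == audio_pad_id)
    if len(slots) != audio_embeds.shape[0]:
        raise ValueError(
            f"{len(slots)} placeholder slots but {audio_embeds.shape[0]} audio frames"
        )
    inputs_embeds[slots] = audio_embeds
    return inputs_embeds


# Toy example: vocabulary of 10 tokens, hidden size 4, F = 3 audio frames.
AUDIO_PAD_ID = 9  # hypothetical id standing in for <|audio_pad|>
embed_table = np.arange(40, dtype=np.float32).reshape(10, 4)
audio_embeds = -np.ones((3, 4), dtype=np.float32)
prompt_ids = [1, 2, AUDIO_PAD_ID, AUDIO_PAD_ID, AUDIO_PAD_ID, 3]

inputs_embeds = scatter_audio_embeds(prompt_ids, embed_table, audio_embeds, AUDIO_PAD_ID)
print(inputs_embeds.shape)
```

The length check matters in practice: the placeholder must already be expanded to exactly `F` slots before the scatter, otherwise audio frames would be silently dropped or misaligned.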
135
 
136
+ The KV cache lives inside the CoreML model as `state`. Call
137
+ `model.make_state()` once per request, then thread the same state object
138
+ through every `predict()` call.
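
The prefill and decode steps share one control flow, sketched below with a stand-in `step` callable. In the real pipeline that callable would wrap `predict()` on the CoreML model together with the state object from `make_state()`; the model's input and output names are not shown in this README, so the sketch deliberately leaves them out. `ToyStep` and every id here are hypothetical.

```python
import numpy as np


def greedy_generate(step, prompt_embeds, embed_table, eos_id, max_new_tokens=64):
    """Prefill one embedding at a time, then decode greedily until eos_id."""
    if len(prompt_embeds) == 0:
        raise ValueError("prompt must contain at least one embedding")
    position = 0
    logits = None
    for embedding in prompt_embeds:  # prefill: only the last logits are used
        logits = step(embedding, position)
        position += 1
    generated = []
    for _ in range(max_new_tokens):
        next_id = int(np.argmax(logits))
        if next_id == eos_id:
            break
        generated.append(next_id)
        logits = step(embed_table[next_id], position)
        position += 1
    return generated


class ToyStep:
    """Stand-in for a stateful model call; `calls` plays the role of the cache."""

    def __init__(self, vocab_size, script, prompt_len):
        self.vocab_size, self.script, self.prompt_len = vocab_size, script, prompt_len
        self.calls = 0

    def __call__(self, embedding, position):
        self.calls += 1
        logits = np.zeros(self.vocab_size, dtype=np.float32)
        index = position - (self.prompt_len - 1)  # negative during early prefill
        if 0 <= index < len(self.script):
            logits[self.script[index]] = 1.0
        return logits


embed_table = np.eye(8, dtype=np.float32)
prompt_embeds = embed_table[[1, 2, 3]]
toy = ToyStep(vocab_size=8, script=[5, 7, 0], prompt_len=3)
tokens = greedy_generate(toy, prompt_embeds, embed_table, eos_id=0)
print(tokens, toy.calls)
```

Because the model is single-token-step, prefill costs one call per prompt position, so long audio clips (large `F`) dominate the total latency.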
139
 
140
  ## Conversion details
141
 
 
 
 
142
  ```python
+ from coremltools.optimize.coreml import (
+     OpPalettizerConfig, OptimizationConfig, palettize_weights)
+
+ # Apply per-op-name palettize: attention at LUT-8, MLP at LUT-4.
+ attn_ops, mlp_ops = [], []  # const-op names collected per projection group
+ prog = mlmodel._mil_program
+ for op in prog.functions["main"].operations:
+     if op.op_type != "const":
+         continue
+     n = op.name.lower()
+     if "self_attn" in n and any(p in n for p in ("q_proj", "k_proj", "v_proj", "o_proj")):
+         attn_ops.append(op.name)
+     elif "mlp" in n and any(p in n for p in ("gate_proj", "up_proj", "down_proj")):
+         mlp_ops.append(op.name)
+
+ config = OptimizationConfig(op_name_configs={
+     **{n: OpPalettizerConfig(nbits=8, group_size=8) for n in attn_ops},
+     **{n: OpPalettizerConfig(nbits=4, group_size=8) for n in mlp_ops},
+ })
+ mlmodel = palettize_weights(mlmodel, config)
158
  ```
159
 
160
+ The model exposes 84 attention weight ops (28 layers × 3 attention
161
+ projections after the GQA-shared k/v gets clustered into k+v ops) and
162
+ 84 MLP weight ops (28 layers × 3 MLP projections).
 
163
 
164
+ `compute_precision=FLOAT32` is mandatory — fp16 compute on Qwen3-ASR
165
+ produces all-NaN logits (RMSNorm + attention score overflow).
 
 
166
 
167
+ A `coremltools` local patch was needed in
168
+ `coremltools/converters/mil/frontend/torch/ops.py` `_cast`: numpy arrays
169
+ of size 1 need to be coerced to scalar via `.flatten()[0].item()` before
170
+ the dtype call — see `convert_embeds_mixed.py` setup notes.
171
 
172
  ## Known limitations
173
 
174
+ 1. **ANE rejected**. CoreML's ANE compiler fails (`MILCompilerForANE
175
+ error: failed to compile ANE model using ANEF`) — likely due to model
176
+ size + stateful KV cache. `CPU_AND_NE` fails to load. `ALL` runs on
177
+ **Metal GPU** (correct and roughly 3× faster than `CPU_ONLY`), which is the
178
+ recommended setting.
179
+ 2. **Audio encoder is ONNX**. The 24-layer Whisper-style encoder isn't
180
+ ported to CoreML yet (ANEMLL is LLM-only). End-to-end runs the
181
  encoder via `onnxruntime` and the LLM via `coremltools`.
182
+ 3. **Quality below ONNX/MLX** by ~2% at 4-bit, due to LUT k-means being
183
+ weaker than GPTQ on this architecture. The uniform LUT-4 variant is
184
+ smaller (826 MB) if size is critical; the mixed 8/4 (1.87 GB) is
185
+ recommended for best quality.
186
 
187
  ## Companion repos
188
 
 
193
  ## Credits
194
 
195
  - Original model: [zhifeixie/Mega-ASR](https://huggingface.co/zhifeixie/Mega-ASR) (1.7B, Apache-2.0)
196
+ - CoreML conversion via [ANEMLL](https://github.com/Anemll/Anemll) with custom input_embeds + mixed-precision patches
197
  - Benchmark: [Voices-in-the-Wild-Bench](https://github.com/xzf-thu/Voices-in-the-Wild-Bench)