Reza2kn commited on
Commit
ea99fd4
·
verified ·
1 Parent(s): 74e5b09

Working CoreML LUT4 input_embeds variant (86.9% on VITW)

Browse files
Files changed (1) hide show
  1. convert_embeds.py +221 -0
convert_embeds.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Custom ANEMLL conversion that takes inputs_embeds instead of input_ids.
2
+
3
+ Required for Mega-ASR: at inference we scatter audio encoder outputs at
4
+ <|audio_pad|> positions BEFORE the LLM, then feed pre-embedded hidden_states
5
+ to the decoder. The default ANEMLL conversion has embed_tokens baked in
6
+ (takes input_ids); we need it bypassed.
7
+
8
+ This script:
9
+ 1. Loads QwenForCausalLM via ANEMLL's loader
10
+ 2. Monkey-patches QwenModel.forward to accept an optional inputs_embeds arg
11
+ 3. Defines a fresh Wrapper that exposes inputs_embeds as the first input
12
+ 4. Traces + converts via ct.convert with LUT-4 palettization in postprocess
13
+ 5. Saves the resulting .mlpackage
14
+
15
+ Reuses ANEMLL's QwenConverter postprocessing (LUT-4 quantization, state
16
+ declarations) by calling its methods after the inputs are swapped.
17
+ """
18
+ from __future__ import annotations
19
+ import argparse
20
+ import os
21
+ import sys
22
+ from pathlib import Path
23
+
24
+ sys.path.insert(0, "/tmp/Anemll")
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn as nn
28
+ import coremltools as ct
29
+ import coremltools.optimize as cto
30
+
31
+ # Apply the local coremltools _cast patch we made earlier (now resident in the
32
+ # env's installed file; nothing to do here, just import).
33
+
34
+
35
+ def patch_qwen_for_inputs_embeds():
36
+ """Monkey-patch QwenModel.forward + QwenForCausalLM.forward to accept inputs_embeds.
37
+
38
+ When the caller passes a float tensor in the input_ids slot, treat it as
39
+ pre-embedded hidden_states and skip embed_tokens. Also relax the strict
40
+ 2D shape assert in QwenForCausalLM.
41
+ """
42
+ from anemll.models import qwen_model as qm
43
+
44
+ orig_model_forward = qm.QwenModel.forward
45
+
46
+ def model_forward_or_embeds(
47
+ self, input_ids, causal_mask, position_ids, current_pos, IN_PREFILL: bool = False,
48
+ ):
49
+ if input_ids.dtype in (torch.float16, torch.float32, torch.bfloat16):
50
+ hidden_states = input_ids
51
+ if IN_PREFILL:
52
+ rotary_emb = self.get_rotary_embedding_prefill(position_ids)
53
+ else:
54
+ rotary_emb = self.get_rotary_embeddings_s(current_pos)
55
+ hidden_states = self.process_layers(
56
+ hidden_states, position_ids, causal_mask,
57
+ current_pos, rotary_emb, start_layer=0, end_layer=None,
58
+ IN_PREFILL=IN_PREFILL,
59
+ )
60
+ hidden_states = self.norm(hidden_states)
61
+ return hidden_states
62
+ return orig_model_forward(self, input_ids, causal_mask, position_ids,
63
+ current_pos, IN_PREFILL=IN_PREFILL)
64
+
65
+ qm.QwenModel.forward = model_forward_or_embeds
66
+
67
+ # Also patch QwenForCausalLM.forward — it asserts input_ids must be 2D
68
+ # (line 1050 in qwen_model.py). For inputs_embeds (3D), skip that.
69
+ orig_causal_forward = qm.QwenForCausalLM.forward
70
+
71
+ def causal_forward_or_embeds(
72
+ self, input_ids, update_mask, position_ids, causal_mask, current_pos,
73
+ IN_PREFILL: bool = False,
74
+ ):
75
+ if input_ids.dtype in (torch.float16, torch.float32, torch.bfloat16):
76
+ # Pre-embedded path — call QwenModel directly, bypass the 2D assert
77
+ hidden_states = self.model(
78
+ input_ids, causal_mask, position_ids, current_pos,
79
+ IN_PREFILL=IN_PREFILL,
80
+ )
81
+ # Replicate the lm-head projection logic from the original forward
82
+ # (single-token decode case)
83
+ if not IN_PREFILL and current_pos is not None:
84
+ seq_len = hidden_states.shape[1]
85
+ if seq_len == 1:
86
+ pos_tensor = torch.tensor([0], device=hidden_states.device, dtype=torch.long)
87
+ else:
88
+ if isinstance(current_pos, torch.Tensor):
89
+ pos_tensor = current_pos if current_pos.dim() > 0 else current_pos.unsqueeze(0)
90
+ else:
91
+ pos_tensor = torch.tensor([current_pos], device=hidden_states.device, dtype=torch.long)
92
+ hidden_states = torch.index_select(hidden_states, dim=1, index=pos_tensor)
93
+ # Use the same Conv2d / 16-way split as the original
94
+ hidden_states = hidden_states.permute(0, 2, 1).unsqueeze(2).to(qm.MODEL_DTYPE)
95
+ outs = tuple(
96
+ getattr(self, f"lm_head16_{k}")(hidden_states).squeeze(2).transpose(1, 2)
97
+ for k in range(1, 17)
98
+ )
99
+ return outs
100
+ return orig_causal_forward(
101
+ self, input_ids, update_mask, position_ids, causal_mask, current_pos,
102
+ IN_PREFILL=IN_PREFILL,
103
+ )
104
+
105
+ qm.QwenForCausalLM.forward = causal_forward_or_embeds
106
+ print("[patch] QwenModel + QwenForCausalLM now accept float inputs_embeds")
107
+
108
+
109
+ def main():
110
+ ap = argparse.ArgumentParser()
111
+ ap.add_argument("--model", required=True, type=Path)
112
+ ap.add_argument("--output", required=True, type=Path,
113
+ help="Output .mlpackage path")
114
+ ap.add_argument("--lut", type=int, default=4)
115
+ ap.add_argument("--per-channel", type=int, default=8)
116
+ ap.add_argument("--context-length", type=int, default=512)
117
+ ap.add_argument("--hidden-size", type=int, default=2048)
118
+ args = ap.parse_args()
119
+
120
+ patch_qwen_for_inputs_embeds()
121
+
122
+ from anemll.models.qwen_model import (
123
+ QwenForCausalLM, QwenConfig, MODEL_DTYPE, TEST_DEVICE,
124
+ )
125
+ from anemll.ane_converter import qwen_converter as qc
126
+
127
+ # Force CoreML mode flags
128
+ import anemll.models.qwen_model as qm
129
+ qm.ENABLE_COREML = True
130
+
131
+ # Load config + model
132
+ import json
133
+ cfg = json.load(open(args.model / "config.json"))
134
+ cfg["context_length"] = args.context_length
135
+ cfg["state_length"] = args.context_length
136
+ config = QwenConfig(**cfg)
137
+ model = QwenForCausalLM(config, enable_coreml=True)
138
+ model.load_pretrained_weights(str(args.model))
139
+ model.eval()
140
+ for p in model.parameters():
141
+ p.requires_grad = False
142
+ print(f"Model loaded: hidden={config.hidden_size}, layers={config.num_hidden_layers}")
143
+
144
+ # Custom wrapper taking inputs_embeds
145
+ class WrapperEmbeds(torch.nn.Module):
146
+ def __init__(self, model):
147
+ super().__init__()
148
+ self.model = model
149
+
150
+ def forward(self, inputs_embeds, position_ids, causal_mask, current_pos, update_mask):
151
+ return self.model(
152
+ input_ids=inputs_embeds, # float tensor → triggers the patched path
153
+ update_mask=update_mask,
154
+ position_ids=position_ids,
155
+ causal_mask=causal_mask,
156
+ current_pos=current_pos,
157
+ IN_PREFILL=False,
158
+ )
159
+
160
+ wrapper = WrapperEmbeds(model).eval()
161
+
162
+ # Build sample inputs for tracing
163
+ sample_inputs_embeds = torch.zeros(
164
+ (1, 1, config.hidden_size), dtype=torch.float16, device=TEST_DEVICE,
165
+ )
166
+ sample_position_ids = torch.zeros((1,), dtype=torch.int32, device=TEST_DEVICE)
167
+ sample_causal_mask = torch.zeros(
168
+ (1, 1, 1, args.context_length), dtype=torch.float16, device=TEST_DEVICE,
169
+ )
170
+ sample_current_pos = torch.zeros((1,), dtype=torch.int32, device=TEST_DEVICE)
171
+ sample_update_mask = torch.zeros(
172
+ (1, 1, args.context_length, 1), dtype=torch.float16, device=TEST_DEVICE,
173
+ )
174
+
175
+ print("Tracing ...")
176
+ traced = torch.jit.trace(
177
+ wrapper,
178
+ (sample_inputs_embeds, sample_position_ids, sample_causal_mask,
179
+ sample_current_pos, sample_update_mask),
180
+ )
181
+ print("Trace done. Converting to CoreML (fp16) ...")
182
+
183
+ # ANEMLL declares the KV cache as a state via GetTransformerStates
184
+ states = qc.QwenConverter.GetTransformerStates(model, prefix="model.model.")
185
+
186
+ mlmodel = ct.convert(
187
+ traced,
188
+ inputs=[
189
+ ct.TensorType(name="inputs_embeds", shape=sample_inputs_embeds.shape, dtype=np.float16),
190
+ ct.TensorType(name="position_ids", shape=sample_position_ids.shape, dtype=np.int32),
191
+ ct.TensorType(name="causal_mask", shape=sample_causal_mask.shape, dtype=np.float16),
192
+ ct.TensorType(name="current_pos", shape=sample_current_pos.shape, dtype=np.int32),
193
+ ct.TensorType(name="update_mask", shape=sample_update_mask.shape, dtype=np.float16),
194
+ ],
195
+ outputs=[ct.TensorType(name=f"logits{i+1}", dtype=np.float16) for i in range(16)],
196
+ states=states,
197
+ minimum_deployment_target=ct.target.iOS18,
198
+ # fp32 compute (activations) — fp16 overflows in Qwen3-ASR's RMSNorm/attention.
199
+ # Matches aoiandroid's finding for the same base model.
200
+ compute_precision=ct.precision.FLOAT32,
201
+ compute_units=ct.ComputeUnit.CPU_AND_NE,
202
+ convert_to="mlprogram",
203
+ skip_model_load=True,
204
+ )
205
+
206
+ if args.lut and args.lut < 16:
207
+ print(f"Applying LUT-{args.lut} palettization (per_channel={args.per_channel}) ...")
208
+ config_palette = cto.coreml.OpPalettizerConfig(
209
+ nbits=args.lut, mode="kmeans",
210
+ granularity="per_grouped_channel", group_size=args.per_channel,
211
+ )
212
+ pal_config = cto.coreml.OptimizationConfig(global_config=config_palette)
213
+ mlmodel = cto.coreml.palettize_weights(mlmodel, pal_config)
214
+
215
+ args.output.parent.mkdir(parents=True, exist_ok=True)
216
+ mlmodel.save(str(args.output))
217
+ print(f"Saved: {args.output}")
218
+
219
+
220
+ if __name__ == "__main__":
221
+ main()