Reza2kn committed on
Commit
3639f53
·
verified ·
1 Parent(s): 4159db6

Mixed 8/4 CoreML (90.6% on VITW): LUT-8 attn + LUT-4 MLP

Browse files
Files changed (1) hide show
  1. convert_embeds_mixed.py +228 -0
convert_embeds_mixed.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Mixed-precision CoreML convert for Qwen3-ASR LLM (input_embeds variant).
2
+
3
+ Attention layers (q/k/v/o_proj) → LUT-8 (8-bit palettize).
4
+ MLP layers (gate/up/down_proj) → LUT-4 (4-bit palettize).
5
+ Everything else (norms, lm_head, embed) → kept as fp16.
6
+
7
+ Compute precision = fp32 to avoid Qwen3-ASR RMSNorm/attention NaN.
8
+ """
9
+ from __future__ import annotations
10
+ import argparse
11
+ import os
12
+ import sys
13
+ import re
14
+ from pathlib import Path
15
+
16
+ sys.path.insert(0, "/tmp/Anemll")
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import coremltools as ct
21
+ import coremltools.optimize as cto
22
+
23
+
24
def patch_qwen_for_inputs_embeds():
    """Monkey-patch Anemll's Qwen classes so ``input_ids`` may carry embeddings.

    Both ``QwenModel.forward`` and ``QwenForCausalLM.forward`` are replaced by
    wrappers that dispatch on the dtype of ``input_ids``:

    * a float16 / float32 / bfloat16 tensor is treated as already-embedded
      hidden states and is passed directly to the transformer layers;
    * any other dtype is forwarded unchanged to the original ``forward``.

    The patch is applied at class level, so it affects every instance.
    """
    from anemll.models import qwen_model as qm

    # Dtypes that mark ``input_ids`` as embeddings rather than token ids.
    embed_dtypes = (torch.float16, torch.float32, torch.bfloat16)

    stock_model_forward = qm.QwenModel.forward

    def model_forward_or_embeds(
        self, input_ids, causal_mask, position_ids, current_pos, IN_PREFILL: bool = False,
    ):
        # Token-id input: defer to the unpatched implementation.
        if input_ids.dtype not in embed_dtypes:
            return stock_model_forward(
                self, input_ids, causal_mask, position_ids, current_pos,
                IN_PREFILL=IN_PREFILL,
            )

        # Embedding input: pick the rotary table for the current mode, run all
        # layers on the tensor as given, then apply the final norm.
        if IN_PREFILL:
            rotary = self.get_rotary_embedding_prefill(position_ids)
        else:
            rotary = self.get_rotary_embeddings_s(current_pos)
        layer_out = self.process_layers(
            input_ids, position_ids, causal_mask,
            current_pos, rotary, start_layer=0, end_layer=None,
            IN_PREFILL=IN_PREFILL,
        )
        return self.norm(layer_out)

    qm.QwenModel.forward = model_forward_or_embeds

    stock_causal_forward = qm.QwenForCausalLM.forward

    def causal_forward_or_embeds(
        self, input_ids, update_mask, position_ids, causal_mask, current_pos,
        IN_PREFILL: bool = False,
    ):
        # Token-id input: defer to the unpatched implementation.
        if input_ids.dtype not in embed_dtypes:
            return stock_causal_forward(
                self, input_ids, update_mask, position_ids, causal_mask, current_pos,
                IN_PREFILL=IN_PREFILL,
            )

        hidden = self.model(
            input_ids, causal_mask, position_ids, current_pos,
            IN_PREFILL=IN_PREFILL,
        )

        # Outside prefill, keep only the position that feeds the LM heads.
        if not IN_PREFILL and current_pos is not None:
            if hidden.shape[1] == 1:
                # Single-token input: the only valid index is 0.
                pick = torch.tensor([0], device=hidden.device, dtype=torch.long)
            elif not isinstance(current_pos, torch.Tensor):
                pick = torch.tensor([current_pos], device=hidden.device, dtype=torch.long)
            elif current_pos.dim() > 0:
                pick = current_pos
            else:
                # index_select needs a 1-D index, so lift the 0-d tensor.
                pick = current_pos.unsqueeze(0)
            hidden = torch.index_select(hidden, dim=1, index=pick)

        # Rearrange to the layout the lm_head16_* modules are called with.
        hidden = hidden.permute(0, 2, 1).unsqueeze(2).to(qm.MODEL_DTYPE)

        # One output per head, lm_head16_1 .. lm_head16_16, in order.
        head_outputs = []
        for head_idx in range(1, 17):
            head = getattr(self, f"lm_head16_{head_idx}")
            head_outputs.append(head(hidden).squeeze(2).transpose(1, 2))
        return tuple(head_outputs)

    qm.QwenForCausalLM.forward = causal_forward_or_embeds
    print("[patch] QwenModel + QwenForCausalLM accept inputs_embeds")
83
+
84
+
85
def select_attn_layer(op):
    """Return True if ``op`` names a self-attention projection (q/k/v/o_proj).

    The match is case-insensitive and purely name-based. The attention module
    may be spelled ``self_attn`` or ``self.attn``; both spellings are accepted
    so this predicate agrees with the name test ``main`` applies when it
    classifies weight constants.

    Args:
        op: Any object exposing a string ``name`` attribute.

    Returns:
        bool: True when the name mentions the attention module and one of
        ``q_proj``, ``k_proj``, ``v_proj`` or ``o_proj``.
    """
    name = op.name.lower()
    in_attention = "self_attn" in name or "self.attn" in name
    return in_attention and any(
        proj in name for proj in ("q_proj", "k_proj", "v_proj", "o_proj")
    )
89
+
90
+
91
def select_mlp_layer(op):
    """Tell whether ``op`` names an MLP projection (gate/up/down_proj).

    The check is case-insensitive and purely name-based: the lower-cased
    ``op.name`` must contain ``mlp`` and at least one projection name.
    """
    lowered = op.name.lower()
    if "mlp" not in lowered:
        return False
    for projection in ("gate_proj", "up_proj", "down_proj"):
        if projection in lowered:
            return True
    return False
95
+
96
+
97
def _palettizer_config(nbits: int, group_size: int):
    """Build the k-means, per-grouped-channel palettizer config for ``nbits``."""
    return cto.coreml.OpPalettizerConfig(
        nbits=nbits, mode="kmeans",
        granularity="per_grouped_channel", group_size=group_size,
    )


def _collect_projection_weight_ops(fn):
    """Split the large weight constants of a MIL function by projection type.

    Only ``const`` ops whose value has at least two dimensions and at least
    64 * 64 elements are considered, which skips small constants such as
    norms, biases and indices. Classification is by lower-cased op name.

    Args:
        fn: A MIL function exposing an ``operations`` iterable.

    Returns:
        tuple[list[str], list[str]]: Names of attention projection constants
        (q/k/v/o_proj) and of MLP projection constants (gate/up/down_proj).
    """
    min_weight_elements = 64 * 64
    attn_op_names, mlp_op_names = [], []
    for op in fn.operations:
        if op.op_type != "const":
            continue
        # Best-effort probe: constants without a concrete array value (or with
        # a value that has no ndim/size) are simply not palettization targets.
        try:
            arr = op.val.val
            is_large_matrix = arr.ndim >= 2 and arr.size >= min_weight_elements
        except (AttributeError, TypeError):
            continue
        if not is_large_matrix:
            continue
        n = op.name.lower()
        if ("self_attn" in n or "self.attn" in n) and any(
            p in n for p in ("q_proj", "k_proj", "v_proj", "o_proj")
        ):
            attn_op_names.append(op.name)
        elif "mlp" in n and any(p in n for p in ("gate_proj", "up_proj", "down_proj")):
            mlp_op_names.append(op.name)
    return attn_op_names, mlp_op_names


def main() -> None:
    """Convert a Qwen checkpoint to a mixed-precision palettized CoreML model.

    Steps:

    1. Parse the command line and patch the Anemll Qwen classes so they accept
       embeddings in place of token ids.
    2. Load the model described by ``<model>/config.json`` with the context
       and state length overridden by ``--context-length``.
    3. Trace a single-token decode step and convert it to an ML Program with
       fp32 compute precision.
    4. Palettize attention projection weights with ``--attn-bits`` and MLP
       projection weights with ``--mlp-bits``. If either group has no matching
       constants, fall back to one global config at ``--mlp-bits``.
    5. Save the result to ``--output``, creating parent directories as needed.
    """
    ap = argparse.ArgumentParser(
        description="Mixed-precision CoreML conversion (inputs_embeds variant).",
    )
    ap.add_argument("--model", required=True, type=Path,
                    help="model directory containing config.json and the weights")
    ap.add_argument("--output", required=True, type=Path,
                    help="path the converted CoreML model is saved to")
    ap.add_argument("--attn-bits", type=int, default=8,
                    help="palette bits for attention q/k/v/o projections")
    ap.add_argument("--mlp-bits", type=int, default=4,
                    help="palette bits for MLP gate/up/down projections")
    ap.add_argument("--group-size", type=int, default=8,
                    help="group size for per-grouped-channel palettization")
    ap.add_argument("--context-length", type=int, default=512,
                    help="context length, also used as the state length")
    args = ap.parse_args()

    patch_qwen_for_inputs_embeds()

    import json

    from anemll.models.qwen_model import QwenForCausalLM, QwenConfig, TEST_DEVICE
    from anemll.ane_converter import qwen_converter as qc
    import anemll.models.qwen_model as qm
    qm.ENABLE_COREML = True

    # Close the config file deterministically instead of leaking the handle.
    with open(args.model / "config.json", encoding="utf-8") as config_file:
        cfg = json.load(config_file)
    cfg["context_length"] = args.context_length
    cfg["state_length"] = args.context_length
    config = QwenConfig(**cfg)
    model = QwenForCausalLM(config, enable_coreml=True)
    model.load_pretrained_weights(str(args.model))
    model.eval()
    for p in model.parameters():
        p.requires_grad = False
    print(f"Model loaded: hidden={config.hidden_size}, layers={config.num_hidden_layers}")

    class WrapperEmbeds(torch.nn.Module):
        """Adapter that fixes the argument order for tracing and disables prefill."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, inputs_embeds, position_ids, causal_mask, current_pos, update_mask):
            # Embeddings travel through the ``input_ids`` slot; the patched
            # forward recognises them by their floating-point dtype.
            return self.model(
                input_ids=inputs_embeds, update_mask=update_mask,
                position_ids=position_ids, causal_mask=causal_mask,
                current_pos=current_pos, IN_PREFILL=False,
            )

    wrapper = WrapperEmbeds(model).eval()

    # Example inputs for a single-token step; their shapes become the fixed
    # input shapes of the converted model.
    sample_inputs_embeds = torch.zeros((1, 1, config.hidden_size), dtype=torch.float16, device=TEST_DEVICE)
    sample_position_ids = torch.zeros((1,), dtype=torch.int32, device=TEST_DEVICE)
    sample_causal_mask = torch.zeros((1, 1, 1, args.context_length), dtype=torch.float16, device=TEST_DEVICE)
    sample_current_pos = torch.zeros((1,), dtype=torch.int32, device=TEST_DEVICE)
    sample_update_mask = torch.zeros((1, 1, args.context_length, 1), dtype=torch.float16, device=TEST_DEVICE)

    print("Tracing ...")
    traced = torch.jit.trace(
        wrapper,
        (sample_inputs_embeds, sample_position_ids, sample_causal_mask,
         sample_current_pos, sample_update_mask),
    )
    print("Converting (fp32 compute, no palettize yet) ...")

    states = qc.QwenConverter.GetTransformerStates(model, prefix="model.model.")

    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="inputs_embeds", shape=sample_inputs_embeds.shape, dtype=np.float16),
            ct.TensorType(name="position_ids", shape=sample_position_ids.shape, dtype=np.int32),
            ct.TensorType(name="causal_mask", shape=sample_causal_mask.shape, dtype=np.float16),
            ct.TensorType(name="current_pos", shape=sample_current_pos.shape, dtype=np.int32),
            ct.TensorType(name="update_mask", shape=sample_update_mask.shape, dtype=np.float16),
        ],
        # One output per LM head: logits1 .. logits16.
        outputs=[ct.TensorType(name=f"logits{i+1}", dtype=np.float16) for i in range(16)],
        states=states,
        minimum_deployment_target=ct.target.iOS18,
        compute_precision=ct.precision.FLOAT32,
        compute_units=ct.ComputeUnit.CPU_AND_NE,
        convert_to="mlprogram",
        skip_model_load=True,
    )

    # Walk the MIL program to enumerate const-weight ops; classify by name.
    prog = mlmodel._mil_program
    attn_op_names, mlp_op_names = _collect_projection_weight_ops(prog.functions["main"])
    print(f"Found {len(attn_op_names)} attention weight ops and {len(mlp_op_names)} MLP weight ops")

    if not attn_op_names or not mlp_op_names:
        # Name matching failed for at least one group, so per-op configs are
        # not reliable; use a single global config at the MLP bit width.
        print("WARN: matched zero ops — falling back to global LUT-4")
        global_cfg = _palettizer_config(args.mlp_bits, args.group_size)
        mlmodel = cto.coreml.palettize_weights(
            mlmodel, cto.coreml.OptimizationConfig(global_config=global_cfg),
        )
    else:
        cfg_attn = _palettizer_config(args.attn_bits, args.group_size)
        cfg_mlp = _palettizer_config(args.mlp_bits, args.group_size)
        # Only named ops receive a config; no global config is set.
        op_name_configs = dict.fromkeys(attn_op_names, cfg_attn)
        op_name_configs.update(dict.fromkeys(mlp_op_names, cfg_mlp))
        pal_cfg = cto.coreml.OptimizationConfig(op_name_configs=op_name_configs)
        print(f"Mixed palettize: {len(attn_op_names)} ops @ LUT-{args.attn_bits}, {len(mlp_op_names)} ops @ LUT-{args.mlp_bits}, rest fp16")
        mlmodel = cto.coreml.palettize_weights(mlmodel, pal_cfg)

    args.output.parent.mkdir(parents=True, exist_ok=True)
    mlmodel.save(str(args.output))
    print(f"Saved: {args.output}")
225
+
226
+
227
# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()