cstr committed on
Commit
4c989c0
Β·
verified Β·
1 Parent(s): 8a5712e

Add sanitized INT4/INT8 quantization script

Browse files
Files changed (1) hide show
  1. scripts/quantize_pixie.py +306 -0
scripts/quantize_pixie.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ quantize_pixie.py β€” ONNX quantization for XLM-RoBERTa-family embedding models.
3
+
4
+ Produces three self-contained ONNX variants from a float32 source model:
5
+ model_quantized.onnx β€” INT8 dynamic (all weights including word embeddings)
6
+ model_int4.onnx β€” INT4 MatMul (MatMulNBits) + INT8 word embedding
7
+ model_int4_full.onnx β€” INT4 MatMul + INT4 word embedding (opset 21, smallest)
8
+
9
+ Usage:
10
+ python quantize_pixie.py \\
11
+ --input onnx/model.onnx \\
12
+ --outdir onnx/ \\
13
+ [--block-size 32]
14
+
15
+ # Or via environment variables:
16
+ PIXIE_INPUT=onnx/model.onnx PIXIE_OUTDIR=onnx/ python quantize_pixie.py
17
+
18
+ The input model is expected to reside in the same directory as its companion
19
+ data file (model.onnx_data) when using the default HuggingFace layout.
20
+ """
21
+
22
+ import argparse
23
+ import os
24
+ import struct
25
+ from pathlib import Path
26
+
27
+ import numpy as np
28
+ import onnx
29
+ import onnx.version_converter
30
+ from onnxruntime.quantization import (
31
+ QuantType,
32
+ quantize_dynamic,
33
+ )
34
+ from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer
35
+
36
+
37
+ # ── helpers ──────────────────────────────────────────────────────────────────
38
+
39
def _load(path: Path) -> onnx.ModelProto:
    """Load an ONNX model, pulling in external initializer data when present.

    HuggingFace exports keep large weights in a sibling ``<stem>.onnx_data``
    file; when that companion exists, its tensors are loaded into the proto.
    """
    model = onnx.load(str(path), load_external_data=False)
    if path.with_suffix(".onnx_data").exists():
        onnx.load_external_data_for_model(model, str(path.parent))
    return model
46
+
47
+
48
def _save_temp(model: onnx.ModelProto, path: Path) -> None:
    """Write *model* to *path* with all tensors inline, as quantizers expect."""
    onnx.save(model, str(path))
51
+
52
+
53
+ def _find_gather_input_name(model: onnx.ModelProto) -> str | None:
54
+ """Return the initializer name fed into the first Gather (word embedding) node."""
55
+ for node in model.graph.node:
56
+ if node.op_type == "Gather":
57
+ return node.input[0] # initializer with embedding weight
58
+ return None
59
+
60
+
61
+ # ── INT8 dynamic quantization ─────────────────────────────────────────────────
62
+
63
def make_int8(src: Path, dst: Path) -> None:
    """
    Produce the INT8 dynamic variant: every weight tensor (MatMul + Gather).

    Delegates to onnxruntime's quantize_dynamic with QInt8 weights. The word
    embedding Gather is included as well, which is where most of the size
    reduction comes from on embedding-heavy models.
    """
    print(f" INT8: {src.name} β†’ {dst.name}")
    quantize_dynamic(str(src), str(dst), weight_type=QuantType.QInt8)
    print(f" INT8 done ({dst.stat().st_size / 1024**2:.0f} MB)")
73
+
74
+
75
+ # ── INT4 MatMulNBits quantization ─────────────────────────────────────────────
76
+
77
def _apply_matmul_nbits(src_model: onnx.ModelProto, block_size: int) -> onnx.ModelProto:
    """Quantize all MatMul weights of *src_model* to 4-bit via MatMulNBits.

    MatMulNBitsQuantizer only accepts a file path, so the model is round-
    tripped through two temporary .onnx files that are removed afterwards.

    Args:
        src_model: FP32 model with inline initializers.
        block_size: per-block quantization group size for MatMulNBits.

    Returns:
        The quantized model, loaded back into memory.
    """
    import tempfile  # local import kept to mirror the file's style

    def _mktemp() -> Path:
        # delete=False: we only want the unique name; the quantizer reopens it.
        with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as handle:
            return Path(handle.name)

    tmp_in, tmp_out = _mktemp(), _mktemp()
    try:
        _save_temp(src_model, tmp_in)
        quantizer = MatMulNBitsQuantizer(
            str(tmp_in),
            block_size=block_size,
            is_symmetric=True,
            nodes_to_exclude=[],
        )
        quantizer.process()
        quantizer.model.save_model_to_file(str(tmp_out), use_external_data_format=False)
        return onnx.load(str(tmp_out))
    finally:
        for leftover in (tmp_in, tmp_out):
            leftover.unlink(missing_ok=True)
98
+
99
+
100
def make_int4_int8_emb(src: Path, dst: Path, block_size: int = 32) -> None:
    """
    Produce model_int4.onnx: 4-bit MatMul weights plus an 8-bit embedding.

    Pass 1 runs MatMulNBitsQuantizer to pack the transformer MatMul weights
    into 4-bit blocks. Pass 2 runs quantize_dynamic restricted to Gather ops
    with QInt8 weights, which converts the FP32 word-embedding table to INT8.
    """
    import tempfile
    print(f" INT4+INT8 emb: {src.name} β†’ {dst.name}")
    source_model = _load(src)

    # Pass 1: pack MatMul weights to 4-bit.
    print(" Pass 1: MatMulNBits INT4 ...")
    int4_model = _apply_matmul_nbits(source_model, block_size=block_size)

    # Pass 2: INT8-quantize the Gather (embedding) weights only.
    print(" Pass 2: INT8 Gather ...")
    with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as handle:
        staging = Path(handle.name)
    try:
        _save_temp(int4_model, staging)
        quantize_dynamic(
            str(staging), str(dst),
            op_types_to_quantize=["Gather"],
            weight_type=QuantType.QInt8,
        )
    finally:
        staging.unlink(missing_ok=True)
    print(f" INT4+INT8 emb done ({dst.stat().st_size / 1024**2:.0f} MB)")
130
+
131
+
132
+ # ── INT4 full (word embeddings packed as INT4 nibbles) ────────────────────────
133
+
134
+ def _pack_int4_rows(weight: np.ndarray) -> tuple[bytes, np.ndarray]:
135
+ """
136
+ Pack a 2-D float32 tensor as per-row symmetric INT4.
137
+
138
+ Each row r is quantized with scale = max(|row_r|) / 7.
139
+ Values are clamped to [-7, 7] and packed as two INT4 nibbles per byte
140
+ (little-endian nibble order: low nibble = even index, high nibble = odd).
141
+
142
+ Returns:
143
+ packed_bytes β€” raw bytes (vocab_size Γ— ceil(dim/2))
144
+ scales β€” float32 scale per row (vocab_size,)
145
+ """
146
+ vocab, dim = weight.shape
147
+ abs_max = np.abs(weight).max(axis=1, keepdims=True).clip(min=1e-9)
148
+ scales = (abs_max / 7.0).squeeze(1).astype(np.float32)
149
+ quantized = np.round(weight / abs_max * 7.0).clip(-7, 7).astype(np.int8)
150
+
151
+ # Pack pairs of INT4 values into bytes
152
+ # Treat negative as unsigned 4-bit: -7..7 β†’ offset doesn't apply for symmetric
153
+ # Use unsigned nibbles with zero_point=0 (symmetric)
154
+ u4 = (quantized % 16).astype(np.uint8) # map negatives: e.g. -1 β†’ 15
155
+ padded = u4 if dim % 2 == 0 else np.pad(u4, ((0, 0), (0, 1)))
156
+ packed = padded[:, 0::2] | (padded[:, 1::2] << 4)
157
+ return packed.tobytes(), scales
158
+
159
+
160
def make_int4_full(src: Path, dst: Path, block_size: int = 32) -> None:
    """
    INT4 full: INT4 MatMul (MatMulNBits) + INT4 word embedding (DequantizeLinear).

    The word embedding Gather's weight initializer is replaced by:
        INT4_packed_tensor -> DequantizeLinear(axis=0, scale=per_row) -> FP32 lookup
    Requires ONNX opset 21 for the INT4 DequantizeLinear kernel in OnnxRuntime.

    Build from the FP32 source (not from model_int4.onnx which already has an
    INT8 DequantizeLinear node on the Gather output, causing a type conflict).

    Raises:
        RuntimeError: if no Gather node or no matching initializer is found.
    """
    import tempfile
    print(f" INT4 full: {src.name} β†’ {dst.name}")
    model = _load(src)

    # Step 1: INT4 MatMul
    print(" Step 1: MatMulNBits INT4 ...")
    matmul_model = _apply_matmul_nbits(model, block_size=block_size)

    # Step 2: Migrate to opset 21 (required for INT4 DequantizeLinear).
    # NOTE(review): the log assumes the source is opset 14 — confirm.
    print(" Step 2: Opset 14 β†’ 21 ...")
    matmul_model = onnx.version_converter.convert_version(matmul_model, 21)

    # Step 3: Find and replace the Gather (word embedding) node
    print(" Step 3: INT4-pack word embedding table ...")
    graph = matmul_model.graph

    # Locate embedding initializer name
    embed_init_name = _find_gather_input_name(matmul_model)
    if embed_init_name is None:
        raise RuntimeError("Could not find Gather (word embedding) node in graph.")

    # Extract current FP32 embedding tensor
    embed_init = next(
        (init for init in graph.initializer if init.name == embed_init_name), None
    )
    if embed_init is None:
        raise RuntimeError(f"Initializer '{embed_init_name}' not found.")

    weight_fp32 = np.array(
        onnx.numpy_helper.to_array(embed_init), dtype=np.float32
    )
    packed_bytes, scales = _pack_int4_rows(weight_fp32)

    # Replace the FP32 initializer with packed INT4
    graph.initializer.remove(embed_init)
    int4_name = embed_init_name + "_int4"
    scales_name = embed_init_name + "_scales"

    # BUGFIX: the previous code set data_type = 17 with a "UINT4" comment, but
    # elem_type 17 is FLOAT8E4M3FN (UINT4 is 21, INT4 is 22). Moreover the
    # nibbles produced by _pack_int4_rows are signed two's complement, so the
    # tensor must be declared as signed INT4 or DequantizeLinear would decode
    # -1 as +15. Dims are the logical (unpacked) shape; ONNX stores 4-bit
    # raw_data two elements per byte, which matches our row packing when the
    # embedding dim is even (1024 here — confirm for other models).
    int4_tensor = onnx.TensorProto()
    int4_tensor.name = int4_name
    int4_tensor.data_type = onnx.TensorProto.INT4
    int4_tensor.dims.extend(list(weight_fp32.shape))
    int4_tensor.raw_data = packed_bytes
    graph.initializer.append(int4_tensor)

    # Per-row scale tensor (float32)
    scales_tensor = onnx.numpy_helper.from_array(scales, name=scales_name)
    graph.initializer.append(scales_tensor)

    # Insert DequantizeLinear(axis=0) between INT4 weights and the Gather node
    dql_out_name = embed_init_name + "_dq"
    dql_node = onnx.helper.make_node(
        "DequantizeLinear",
        inputs=[int4_name, scales_name],
        outputs=[dql_out_name],
        axis=0,
    )

    # Reroute: Gather now reads from dql_out instead of original initializer
    for node in graph.node:
        if node.op_type == "Gather" and node.input[0] == embed_init_name:
            node.input[0] = dql_out_name

    # Insert DequantizeLinear just before the first rerouted Gather node so
    # the graph stays topologically sorted.
    gather_idx = next(
        i for i, n in enumerate(graph.node)
        if n.op_type == "Gather" and n.input[0] == dql_out_name
    )
    graph.node.insert(gather_idx, dql_node)

    # Save via a temp file so the checker validates before dst is overwritten.
    with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f:
        tmp = Path(f.name)
    try:
        onnx.save(matmul_model, str(tmp))
        onnx.checker.check_model(str(tmp))
        import shutil
        shutil.copy(tmp, dst)
    finally:
        tmp.unlink(missing_ok=True)
    print(f" INT4 full done ({dst.stat().st_size / 1024**2:.0f} MB)")
253
+
254
+
255
+ # ── entry point ───────────────────────────────────────────────────────────���──
256
+
257
def parse_args() -> argparse.Namespace:
    """Parse CLI flags, falling back to PIXIE_INPUT / PIXIE_OUTDIR env vars."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--input",
        default=os.environ.get("PIXIE_INPUT"),
        help="Path to the FP32 source model.onnx (may have companion .onnx_data)",
    )
    parser.add_argument(
        "--outdir",
        default=os.environ.get("PIXIE_OUTDIR", "."),
        help="Output directory for quantized models (default: cwd)",
    )
    parser.add_argument(
        "--block-size",
        type=int,
        default=32,
        help="Block size for MatMulNBits INT4 (default: 32)",
    )
    parser.add_argument(
        "--variants",
        nargs="+",
        choices=["int8", "int4", "int4_full", "all"],
        default=["all"],
        help="Which variants to produce (default: all)",
    )
    return parser.parse_args()
271
+
272
+
273
def main() -> None:
    """CLI entry point: resolve paths, then build each requested variant."""
    args = parse_args()
    if not args.input:
        raise SystemExit("Error: --input or PIXIE_INPUT env var required.")

    src = Path(args.input).resolve()
    outdir = Path(args.outdir).resolve()
    outdir.mkdir(parents=True, exist_ok=True)

    requested = set(args.variants)
    if "all" in requested:
        requested = {"int8", "int4", "int4_full"}

    print(f"Source : {src}")
    print(f"Out dir: {outdir}")
    print(f"Targets: {', '.join(sorted(requested))}")
    print()

    # Dispatch table preserves the original build order: int8, int4, int4_full.
    builders = (
        ("int8",
         lambda: make_int8(src, outdir / "model_quantized.onnx")),
        ("int4",
         lambda: make_int4_int8_emb(src, outdir / "model_int4.onnx",
                                    block_size=args.block_size)),
        ("int4_full",
         lambda: make_int4_full(src, outdir / "model_int4_full.onnx",
                                block_size=args.block_size)),
    )
    for tag, build in builders:
        if tag in requested:
            build()

    print("\nAll done.")
303
+
304
+
305
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()