Add sanitized INT4/INT8 quantization script
Browse files- scripts/quantize_pixie.py +306 -0
scripts/quantize_pixie.py
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
quantize_pixie.py β ONNX quantization for XLM-RoBERTa-family embedding models.
|
| 3 |
+
|
| 4 |
+
Produces three self-contained ONNX variants from a float32 source model:
|
| 5 |
+
model_quantized.onnx β INT8 dynamic (all weights including word embeddings)
|
| 6 |
+
model_int4.onnx β INT4 MatMul (MatMulNBits) + INT8 word embedding
|
| 7 |
+
model_int4_full.onnx β INT4 MatMul + INT4 word embedding (opset 21, smallest)
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
python quantize_pixie.py \\
|
| 11 |
+
--input onnx/model.onnx \\
|
| 12 |
+
--outdir onnx/ \\
|
| 13 |
+
[--block-size 32]
|
| 14 |
+
|
| 15 |
+
# Or via environment variables:
|
| 16 |
+
PIXIE_INPUT=onnx/model.onnx PIXIE_OUTDIR=onnx/ python quantize_pixie.py
|
| 17 |
+
|
| 18 |
+
The input model is expected to reside in the same directory as its companion
|
| 19 |
+
data file (model.onnx_data) when using the default HuggingFace layout.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import os
|
| 24 |
+
import struct
|
| 25 |
+
from pathlib import Path
|
| 26 |
+
|
| 27 |
+
import numpy as np
|
| 28 |
+
import onnx
|
| 29 |
+
import onnx.version_converter
|
| 30 |
+
from onnxruntime.quantization import (
|
| 31 |
+
QuantType,
|
| 32 |
+
quantize_dynamic,
|
| 33 |
+
)
|
| 34 |
+
from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ββ helpers ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
+
|
| 39 |
+
def _load(path: Path) -> onnx.ModelProto:
    """Load an ONNX model, materializing both inline and external initializers.

    Args:
        path: Path to the .onnx file. Any companion external-data files are
            expected to live in the same directory.

    Returns:
        A ModelProto with all tensor data loaded into memory.
    """
    model = onnx.load(str(path), load_external_data=False)
    # Load external tensors whenever the graph actually references any,
    # instead of probing for one hard-coded filename ("<stem>.onnx_data") —
    # that check missed layouts such as "model.onnx.data" or per-tensor files.
    uses_external = any(
        init.data_location == onnx.TensorProto.EXTERNAL
        for init in model.graph.initializer
    )
    if uses_external:
        onnx.load_external_data_for_model(model, str(path.parent))
    return model
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _save_temp(model: onnx.ModelProto, path: Path) -> None:
    """Serialize *model* to *path* with every tensor stored inline.

    The quantizers below operate on self-contained files, so this is the
    canonical pre-quantization save step.
    """
    destination = str(path)
    onnx.save(model, destination)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _find_gather_input_name(model: onnx.ModelProto) -> str | None:
    """Return the initializer name fed into the first Gather (word embedding) node."""
    candidates = (
        node.input[0]  # a Gather's first input is its weight initializer
        for node in model.graph.node
        if node.op_type == "Gather"
    )
    return next(candidates, None)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ββ INT8 dynamic quantization βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 62 |
+
|
| 63 |
+
def make_int8(src: Path, dst: Path) -> None:
    """
    INT8 dynamic quantization of every weight tensor (MatMul + Gather).

    Runs onnxruntime's quantize_dynamic with QInt8 weights. Because the word
    embedding Gather is included, the ~977 MB FP32 embedding table shrinks
    to roughly 244 MB.
    """
    print(f" INT8: {src.name} → {dst.name}")
    quantize_dynamic(str(src), str(dst), weight_type=QuantType.QInt8)
    size_mb = dst.stat().st_size / (1024 * 1024)
    print(f" INT8 done ({size_mb:.0f} MB)")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# ββ INT4 MatMulNBits quantization βββββββββββββββββββββββββββββββββββββββββββββ
|
| 76 |
+
|
| 77 |
+
def _apply_matmul_nbits(src_model: onnx.ModelProto, block_size: int) -> onnx.ModelProto:
    """Apply MatMulNBits (INT4) to all MatMul weight tensors.

    MatMulNBitsQuantizer operates on files, so the model is round-tripped
    through two temporary .onnx files that are always removed afterwards.

    Args:
        src_model: Fully materialized FP32 source model.
        block_size: Quantization block size along the K dimension.

    Returns:
        A new ModelProto with MatMul nodes replaced by MatMulNBits.
    """
    # Local import: only this round-trip needs it. (The previous version also
    # imported `copy` here, which was never used.)
    import tempfile

    with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f:
        tmp_in = Path(f.name)
    with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f:
        tmp_out = Path(f.name)
    try:
        _save_temp(src_model, tmp_in)
        q = MatMulNBitsQuantizer(
            str(tmp_in),
            block_size=block_size,
            is_symmetric=True,
            nodes_to_exclude=[],
        )
        q.process()
        q.model.save_model_to_file(str(tmp_out), use_external_data_format=False)
        return onnx.load(str(tmp_out))
    finally:
        tmp_in.unlink(missing_ok=True)
        tmp_out.unlink(missing_ok=True)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def make_int4_int8_emb(src: Path, dst: Path, block_size: int = 32) -> None:
    """
    Two-pass quantization: INT4 MatMul (MatMulNBits) + INT8 word embedding.

    Pass 1 packs the transformer MatMul weights into 4-bit blocks.
    Pass 2 runs quantize_dynamic restricted to Gather ops, which covers the
    large (250,002 x 1024) word-embedding table and stores it as INT8.
    """
    import tempfile

    print(f" INT4+INT8 emb: {src.name} → {dst.name}")
    source_model = _load(src)

    # First pass: 4-bit MatMul weights.
    print(" Pass 1: MatMulNBits INT4 ...")
    int4_model = _apply_matmul_nbits(source_model, block_size=block_size)

    # Second pass: re-quantize only the embedding Gather, to INT8.
    print(" Pass 2: INT8 Gather ...")
    with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f:
        staging = Path(f.name)
    try:
        _save_temp(int4_model, staging)
        quantize_dynamic(
            str(staging),
            str(dst),
            op_types_to_quantize=["Gather"],
            weight_type=QuantType.QInt8,
        )
    finally:
        staging.unlink(missing_ok=True)
    print(f" INT4+INT8 emb done ({dst.stat().st_size / 1024**2:.0f} MB)")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
# ββ INT4 full (word embeddings packed as INT4 nibbles) ββββββββββββββββββββββββ
|
| 133 |
+
|
| 134 |
+
def _pack_int4_rows(weight: np.ndarray) -> tuple[bytes, np.ndarray]:
    """
    Pack a 2-D float32 tensor as per-row symmetric INT4.

    Each row r is quantized with scale = max(|row_r|) / 7; values are clamped
    to [-7, 7] (the INT4 minimum -8 is deliberately unused so the range stays
    symmetric). Quantized values are stored as two's-complement nibbles, two
    per byte, packed over the *flattened* tensor (low nibble = even flat
    index, high nibble = odd) — exactly the ONNX INT4 raw_data layout.

    NOTE: the previous version padded each row separately, which diverges
    from the ONNX packing whenever dim is odd; for even dims (e.g. the
    1024-wide embedding) the output is byte-identical.

    Returns:
        packed_bytes - raw bytes, ceil(vocab_size * dim / 2) long
        scales       - float32 scale per row, shape (vocab_size,)
    """
    vocab, dim = weight.shape
    # clip(min=1e-9) guards all-zero rows against division by zero.
    abs_max = np.abs(weight).max(axis=1, keepdims=True).clip(min=1e-9)
    scales = (abs_max / 7.0).squeeze(1).astype(np.float32)
    quantized = np.round(weight / abs_max * 7.0).clip(-7, 7).astype(np.int8)

    # Two's-complement nibble mapping: -1 -> 15, -7 -> 9 (matches INT4).
    u4 = (quantized % 16).astype(np.uint8).reshape(-1)
    if u4.size % 2:
        u4 = np.pad(u4, (0, 1))  # single pad nibble at the very end only
    packed = u4[0::2] | (u4[1::2] << 4)
    return packed.tobytes(), scales
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def make_int4_full(src: Path, dst: Path, block_size: int = 32) -> None:
    """
    INT4 full: INT4 MatMul (MatMulNBits) + INT4 word embedding (DequantizeLinear).

    The word embedding Gather is replaced by:
        INT4_packed_tensor → DequantizeLinear(axis=0, scale=per_row) → FP32 lookup
    Requires ONNX opset 21 for the INT4 DequantizeLinear kernel in OnnxRuntime.

    Build from the FP32 source (not from model_int4.onnx which already has an
    INT8 DequantizeLinear node on the Gather output, causing a type conflict).

    Args:
        src: FP32 source model path.
        dst: Output path for the fully INT4 model.
        block_size: MatMulNBits block size.

    Raises:
        RuntimeError: if no Gather node or embedding initializer is found.
    """
    import shutil
    import tempfile

    print(f" INT4 full: {src.name} → {dst.name}")
    model = _load(src)

    # Step 1: INT4 MatMul
    print(" Step 1: MatMulNBits INT4 ...")
    matmul_model = _apply_matmul_nbits(model, block_size=block_size)

    # Step 2: Migrate to opset 21 (required for INT4 DequantizeLinear)
    print(" Step 2: Opset 14 → 21 ...")
    matmul_model = onnx.version_converter.convert_version(matmul_model, 21)

    # Step 3: Find and replace the Gather (word embedding) node
    print(" Step 3: INT4-pack word embedding table ...")
    graph = matmul_model.graph

    # Locate embedding initializer name
    embed_init_name = _find_gather_input_name(matmul_model)
    if embed_init_name is None:
        raise RuntimeError("Could not find Gather (word embedding) node in graph.")

    # Extract current FP32 embedding tensor
    embed_init = next(
        (init for init in graph.initializer if init.name == embed_init_name), None
    )
    if embed_init is None:
        raise RuntimeError(f"Initializer '{embed_init_name}' not found.")

    weight_fp32 = np.array(
        onnx.numpy_helper.to_array(embed_init), dtype=np.float32
    )
    packed_bytes, scales = _pack_int4_rows(weight_fp32)

    # Replace the FP32 initializer with packed INT4
    graph.initializer.remove(embed_init)
    int4_name = embed_init_name + "_int4"
    scales_name = embed_init_name + "_scales"

    # BUGFIX: _pack_int4_rows emits two's-complement *signed* nibbles, so the
    # tensor must be declared INT4 (elem_type 22). The previous code used
    # elem_type 17, which is FLOAT8E4M3FN (UINT4 would be 21) — either wrong
    # type makes DequantizeLinear misinterpret every negative weight.
    int4_tensor = onnx.TensorProto()
    int4_tensor.name = int4_name
    int4_tensor.data_type = onnx.TensorProto.INT4  # signed 4-bit (elem_type 22)
    int4_tensor.dims.extend(list(weight_fp32.shape))
    int4_tensor.raw_data = packed_bytes
    graph.initializer.append(int4_tensor)

    # Per-row scale tensor (float32)
    scales_tensor = onnx.numpy_helper.from_array(scales, name=scales_name)
    graph.initializer.append(scales_tensor)

    # Insert DequantizeLinear(axis=0) between INT4 weights and the Gather node
    dql_out_name = embed_init_name + "_dq"
    dql_node = onnx.helper.make_node(
        "DequantizeLinear",
        inputs=[int4_name, scales_name],
        outputs=[dql_out_name],
        axis=0,
    )

    # Reroute: Gather now reads from dql_out instead of original initializer
    for node in graph.node:
        if node.op_type == "Gather" and node.input[0] == embed_init_name:
            node.input[0] = dql_out_name

    # Insert DequantizeLinear before the Gather node
    gather_idx = next(
        i for i, n in enumerate(graph.node)
        if n.op_type == "Gather" and n.input[0] == dql_out_name
    )
    graph.node.insert(gather_idx, dql_node)

    # Save via a temp file so the checker validates before dst is written.
    with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as f:
        tmp = Path(f.name)
    try:
        onnx.save(matmul_model, str(tmp))
        onnx.checker.check_model(str(tmp))
        shutil.copy(tmp, dst)
    finally:
        tmp.unlink(missing_ok=True)
    print(f" INT4 full done ({dst.stat().st_size / 1024**2:.0f} MB)")
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
# ── entry point ──────────────────────────────────────────────────────────────
|
| 256 |
+
|
| 257 |
+
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments, falling back to PIXIE_* environment variables."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--input",
        default=os.environ.get("PIXIE_INPUT"),
        help="Path to the FP32 source model.onnx (may have companion .onnx_data)",
    )
    parser.add_argument(
        "--outdir",
        default=os.environ.get("PIXIE_OUTDIR", "."),
        help="Output directory for quantized models (default: cwd)",
    )
    parser.add_argument(
        "--block-size",
        type=int,
        default=32,
        help="Block size for MatMulNBits INT4 (default: 32)",
    )
    parser.add_argument(
        "--variants",
        nargs="+",
        choices=["int8", "int4", "int4_full", "all"],
        default=["all"],
        help="Which variants to produce (default: all)",
    )
    return parser.parse_args()
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def main() -> None:
    """Entry point: resolve paths, pick variants, run the requested passes."""
    args = parse_args()
    if not args.input:
        raise SystemExit("Error: --input or PIXIE_INPUT env var required.")

    src = Path(args.input).resolve()
    # Fail fast with a clear message instead of a deep traceback from
    # onnx.load / quantize_dynamic on a missing file.
    if not src.is_file():
        raise SystemExit(f"Error: input model not found: {src}")
    outdir = Path(args.outdir).resolve()
    outdir.mkdir(parents=True, exist_ok=True)

    variants = set(args.variants)
    if "all" in variants:
        variants = {"int8", "int4", "int4_full"}

    print(f"Source : {src}")
    print(f"Out dir: {outdir}")
    print(f"Targets: {', '.join(sorted(variants))}")
    print()

    if "int8" in variants:
        make_int8(src, outdir / "model_quantized.onnx")

    if "int4" in variants:
        make_int4_int8_emb(src, outdir / "model_int4.onnx",
                           block_size=args.block_size)

    if "int4_full" in variants:
        make_int4_full(src, outdir / "model_int4_full.onnx",
                       block_size=args.block_size)

    print("\nAll done.")
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()
|