File size: 3,593 Bytes
"""Load pre-quantized INT8 DramaBox DiT weights from safetensors.

Quantized layers are stored as:
    {layer_name}.weight.__int_data  (INT8 tensor)
    {layer_name}.weight.__scale     (BF16 per-channel scale)

Non-quantized layers are stored as plain BF16 tensors.

Usage:
    from load_int8 import load_int8_dit
    load_int8_dit(tts._velocity_model, "dramabox-dit-int8-selective.safetensors")
"""
import json
import logging
import os
import torch
from safetensors.torch import load_file
# Configure the root logger when this module is imported: INFO level, with
# each record rendered as "<timestamp> <level> <message>".
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
def load_int8_dit(
    model: torch.nn.Module,
    safetensors_path: str,
    config_path: str | None = None,
    device: str = "cuda",
) -> torch.nn.Module:
    """Replace model weights with pre-quantized INT8 weights from safetensors.

    For quantized layers, reconstructs the dequantized BF16 weight from
    int_data * scale (equivalent to what torchao does at runtime, but without
    needing torchao installed for loading).

    For runtime INT8 inference (keeping weights in INT8 and dequantizing during
    matmul), use the torchao approach instead — see README.md Option 1.

    Args:
        model: Module whose parameters are overwritten in place.
        safetensors_path: Path of the safetensors checkpoint to read.
        config_path: Accepted for backward compatibility. This loader does not
            read any config file; quantized layers are detected from the
            ``.weight.__int_data`` keys in the checkpoint itself.
        device: Device the checkpoint tensors are loaded onto.

    Returns:
        The same ``model`` instance, after its state dict has been updated.

    Raises:
        KeyError: If a layer has an ``__int_data`` tensor but no ``__scale``.
    """
    tensors = load_file(safetensors_path, device=device)
    logging.info(f"Loaded {len(tensors)} tensors from {safetensors_path}")

    int_data_suffix = ".weight.__int_data"
    scale_suffix = ".weight.__scale"
    # Layer names (without the ".weight" part) that are stored quantized.
    quantized_names = {
        key[: -len(int_data_suffix)]
        for key in tensors
        if key.endswith(int_data_suffix)
    }

    sd = model.state_dict()
    loaded, skipped = 0, 0
    dequantized = set()  # quantized layers that actually matched a model weight
    for key in list(sd):
        # "a.b.weight" -> ("a.b", "weight"); a dot-free key -> ("", key).
        layer_name, _, param_name = key.rpartition(".")
        # Only differs from `key` for dot-free names (gives ".name").
        dotted_key = f"{layer_name}.{param_name}"

        if param_name == "weight" and layer_name in quantized_names:
            scale_key = f"{layer_name}{scale_suffix}"
            if scale_key not in tensors:
                raise KeyError(
                    f"{scale_key} is missing from {safetensors_path} although "
                    f"{layer_name}{int_data_suffix} is present"
                )
            int_data = tensors[f"{layer_name}{int_data_suffix}"]
            scale = tensors[scale_key]
            # A 1-D per-channel scale needs a trailing axis to broadcast over
            # each row of int_data; a scale stored with more (or zero)
            # dimensions is used as is and left to normal broadcasting.
            if scale.dim() == 1:
                scale = scale.unsqueeze(1)
            sd[key] = (int_data.float() * scale).to(torch.bfloat16)
            dequantized.add(layer_name)
            loaded += 1
        elif key in tensors:
            sd[key] = tensors[key]
            loaded += 1
        elif dotted_key in tensors:
            sd[key] = tensors[dotted_key]
            loaded += 1
        else:
            # Parameter absent from the checkpoint: keep the model's value.
            skipped += 1

    model.load_state_dict(sd, strict=False)

    # Surface checkpoint layers that were never applied instead of silently
    # counting them as dequantized.
    unmatched = quantized_names - dequantized
    if unmatched:
        logging.warning(
            f"{len(unmatched)} quantized layers in {safetensors_path} have no "
            f"matching weight in the model"
        )
    logging.info(
        f"Loaded {loaded} params ({len(dequantized)} dequantized from INT8), "
        f"skipped {skipped}"
    )
    return model
def load_int8_dit_torchao(
    model: torch.nn.Module,
    safetensors_path: str,
    device: str = "cuda",
    config_path: str | None = None,
) -> torch.nn.Module:
    """Load INT8 weights and apply torchao quantization for runtime INT8 matmul.

    This keeps weights in INT8 during inference (lower VRAM) but requires torchao.

    Args:
        model: Module whose weights are loaded and then quantized in place.
        safetensors_path: Path of the safetensors checkpoint to read.
        device: Device the checkpoint tensors are loaded onto.
        config_path: JSON file with a ``"quantized_layers"`` list of module
            names. Defaults to ``config.json`` next to ``safetensors_path``.

    Returns:
        The same ``model`` instance.

    Raises:
        ImportError: If torchao is not installed.
        FileNotFoundError: If the config file does not exist.
        KeyError: If the config has no ``"quantized_layers"`` entry.
    """
    from torchao.quantization import quantize_, Int8WeightOnlyConfig

    if config_path is None:
        config_path = os.path.join(os.path.dirname(safetensors_path), "config.json")
    # Read the config before touching the model so a missing or malformed
    # config fails before any weights have been replaced.
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    quantized_set = set(config["quantized_layers"])

    load_int8_dit(model, safetensors_path, device=device)

    matched = set()  # config entries that matched a Linear module

    def filter_fn(mod, fqn):
        # Quantize only Linear modules whose qualified name is in the config.
        selected = isinstance(mod, torch.nn.Linear) and fqn in quantized_set
        if selected:
            matched.add(fqn)
        return selected

    quantize_(model, Int8WeightOnlyConfig(), filter_fn=filter_fn)

    missing = quantized_set - matched
    if missing:
        logging.warning(
            f"{len(missing)} layers listed in {config_path} did not match any "
            f"Linear module in the model"
        )
    logging.info(f"Applied torchao INT8 to {len(matched)} layers")
    return model
|