| """Load pre-quantized INT8 DramaBox DiT weights from safetensors. |
| |
| Quantized layers are stored as: |
| {layer_name}.weight.__int_data (INT8 tensor) |
| {layer_name}.weight.__scale (BF16 per-channel scale) |
| |
| Non-quantized layers are stored as plain BF16 tensors. |
| |
| Usage: |
| from load_int8 import load_int8_dit |
| load_int8_dit(tts._velocity_model, "dramabox-dit-int8-selective.safetensors") |
| """ |
| import json |
| import logging |
| import os |
|
|
| import torch |
| from safetensors.torch import load_file |
|
|
# Import-time side effect: importing this module calls logging.basicConfig on
# the root logger (INFO level, "<asctime> <levelname> <message>" format). The
# functions below log through the root logger via logging.info.
| logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") |
|
|
|
|
def load_int8_dit(
    model: torch.nn.Module,
    safetensors_path: str,
    config_path: str | None = None,
    device: str = "cuda",
) -> torch.nn.Module:
    """Replace model weights with pre-quantized INT8 weights from safetensors.

    For quantized layers, reconstructs the dequantized BF16 weight from
    int_data * scale (equivalent to what torchao does at runtime, but without
    needing torchao installed for loading).

    For runtime INT8 inference (keeping weights in INT8 and dequantizing during
    matmul), use the torchao approach instead — see README.md Option 1.

    Args:
        model: Module whose state-dict entries are overwritten with the
            tensors found in the checkpoint.
        safetensors_path: Checkpoint containing
            ``{layer}.weight.__int_data`` / ``{layer}.weight.__scale`` pairs
            for quantized layers and plain tensors for everything else.
        config_path: Accepted for backward compatibility; this loader does
            not read any config file.
        device: Device handed to ``load_file`` when reading the checkpoint.

    Returns:
        The same ``model`` instance, after ``load_state_dict``.

    Raises:
        KeyError: If a quantized layer has ``__int_data`` but no matching
            ``__scale`` tensor in the checkpoint.
    """
    int_data_suffix = ".weight.__int_data"
    scale_suffix = ".weight.__scale"

    tensors = load_file(safetensors_path, device=device)
    logging.info(f"Loaded {len(tensors)} tensors from {safetensors_path}")

    # Layer names that have an INT8 payload in the checkpoint.
    quantized_names = {
        key.removesuffix(int_data_suffix)
        for key in tensors
        if key.endswith(int_data_suffix)
    }

    state = model.state_dict()
    dequantized = set()  # quantized layers that actually matched a model key
    loaded = skipped = 0

    for key in list(state):
        # "a.b.weight" -> ("a.b", "weight"); a dotless key -> ("", key).
        layer_name, _, param_name = key.rpartition(".")

        if param_name == "weight" and layer_name in quantized_names:
            scale_key = f"{layer_name}{scale_suffix}"
            if scale_key not in tensors:
                raise KeyError(
                    f"Quantized layer {layer_name!r} has INT8 data but no "
                    f"{scale_key!r} tensor in {safetensors_path}"
                )
            int_data = tensors[f"{layer_name}{int_data_suffix}"]
            # One scale per output channel (dim 0), as described in the module
            # docstring. Reshaping to (out, 1, ..., 1) broadcasts correctly
            # whether the scale was saved as (out,) or (out, 1), and for
            # weights with more than two dimensions.
            scale = tensors[scale_key].reshape(
                -1, *([1] * (int_data.dim() - 1))
            )
            state[key] = (int_data.float() * scale).to(torch.bfloat16)
            dequantized.add(layer_name)
            loaded += 1
        elif key in tensors:
            state[key] = tensors[key]
            loaded += 1
        else:
            # Not present in the checkpoint: the model keeps its current value.
            skipped += 1

    unmatched = quantized_names - dequantized
    if unmatched:
        logging.warning(
            f"{len(unmatched)} quantized layers in {safetensors_path} have no "
            f"matching '.weight' in the model, e.g. {sorted(unmatched)[:5]}"
        )

    # `state` still contains every model key, so strict=False only mirrors the
    # original call; nothing is actually missing.
    model.load_state_dict(state, strict=False)
    logging.info(
        f"Loaded {loaded} params ({len(dequantized)} dequantized from INT8), "
        f"skipped {skipped}"
    )
    return model
|
|
|
|
def load_int8_dit_torchao(
    model: torch.nn.Module,
    safetensors_path: str,
    device: str = "cuda",
    config_path: str | None = None,
) -> torch.nn.Module:
    """Load INT8 weights and apply torchao quantization for runtime INT8 matmul.

    This keeps weights in INT8 during inference (lower VRAM) but requires torchao.

    The checkpoint is first loaded through ``load_int8_dit``; torchao's
    ``quantize_`` with ``Int8WeightOnlyConfig`` is then applied to every
    ``torch.nn.Linear`` whose fully qualified name is listed under
    ``"quantized_layers"`` in the config JSON.

    Args:
        model: Module to load into and quantize in place.
        safetensors_path: Path to the pre-quantized safetensors checkpoint.
        device: Device forwarded to ``load_int8_dit``.
        config_path: JSON file containing a ``"quantized_layers"`` list.
            Defaults to ``config.json`` next to ``safetensors_path``.

    Returns:
        The same ``model`` instance.

    Raises:
        ImportError: If torchao is not installed.
        FileNotFoundError: If the config file does not exist.
        KeyError: If the config has no ``"quantized_layers"`` entry.
    """
    from torchao.quantization import quantize_, Int8WeightOnlyConfig

    if config_path is None:
        config_path = os.path.join(
            os.path.dirname(safetensors_path), "config.json"
        )

    # Read and validate the config before touching the model, so a missing or
    # malformed file fails before any weights have been overwritten.
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    if "quantized_layers" not in config:
        raise KeyError(f"'quantized_layers' not found in {config_path}")
    quantized_set = set(config["quantized_layers"])

    load_int8_dit(
        model, safetensors_path, config_path=config_path, device=device
    )

    def filter_fn(mod, fqn):
        return isinstance(mod, torch.nn.Linear) and fqn in quantized_set

    # Work out which configured names really are Linear modules in this model,
    # so the log reports what is quantized rather than what was requested.
    matched = {
        name for name, mod in model.named_modules() if filter_fn(mod, name)
    }
    unmatched = quantized_set - matched
    if unmatched:
        logging.warning(
            f"{len(unmatched)} layers listed in {config_path} are not "
            f"nn.Linear modules of the model, e.g. {sorted(unmatched)[:5]}"
        )

    quantize_(model, Int8WeightOnlyConfig(), filter_fn=filter_fn)
    logging.info(f"Applied torchao INT8 to {len(matched)} layers")
    return model
|
|