"""Load pre-quantized INT8 DramaBox DiT weights from safetensors. Quantized layers are stored as: {layer_name}.weight.__int_data (INT8 tensor) {layer_name}.weight.__scale (BF16 per-channel scale) Non-quantized layers are stored as plain BF16 tensors. Usage: from load_int8 import load_int8_dit load_int8_dit(tts._velocity_model, "dramabox-dit-int8-selective.safetensors") """ import json import logging import os import torch from safetensors.torch import load_file logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") def load_int8_dit( model: torch.nn.Module, safetensors_path: str, config_path: str | None = None, device: str = "cuda", ) -> torch.nn.Module: """Replace model weights with pre-quantized INT8 weights from safetensors. For quantized layers, reconstructs the dequantized BF16 weight from int_data * scale (equivalent to what torchao does at runtime, but without needing torchao installed for loading). For runtime INT8 inference (keeping weights in INT8 and dequantizing during matmul), use the torchao approach instead — see README.md Option 1. """ if config_path is None: config_path = os.path.join(os.path.dirname(safetensors_path), "config.json") tensors = load_file(safetensors_path, device=device) logging.info(f"Loaded {len(tensors)} tensors from {safetensors_path}") int_data_suffix = ".weight.__int_data" scale_suffix = ".weight.__scale" quantized_names = set() for key in tensors: if key.endswith(int_data_suffix): name = key[: -len(int_data_suffix)] quantized_names.add(name) sd = model.state_dict() loaded, skipped = 0, 0 for key in list(sd.keys()): parts = key.rsplit(".", 1) if len(parts) == 2: layer_name, param_name = parts else: layer_name, param_name = "", parts[0] if layer_name in quantized_names and param_name == "weight": int_data = tensors[f"{layer_name}{int_data_suffix}"] scale = tensors[f"{layer_name}{scale_suffix}"] sd[key] = (int_data.float() * scale.unsqueeze(1)).to(torch.bfloat16) loaded += 1 elif key in tensors: sd[key] = tensors[key] loaded += 1 elif f"{layer_name}.{param_name}" in tensors: sd[key] = tensors[f"{layer_name}.{param_name}"] loaded += 1 else: skipped += 1 model.load_state_dict(sd, strict=False) logging.info( f"Loaded {loaded} params ({len(quantized_names)} dequantized from INT8), " f"skipped {skipped}" ) return model def load_int8_dit_torchao( model: torch.nn.Module, safetensors_path: str, device: str = "cuda", ) -> torch.nn.Module: """Load INT8 weights and apply torchao quantization for runtime INT8 matmul. This keeps weights in INT8 during inference (lower VRAM) but requires torchao. """ from torchao.quantization import quantize_, Int8WeightOnlyConfig load_int8_dit(model, safetensors_path, device=device) config_path = os.path.join(os.path.dirname(safetensors_path), "config.json") with open(config_path) as f: config = json.load(f) quantized_set = set(config["quantized_layers"]) def filter_fn(mod, fqn): return isinstance(mod, torch.nn.Linear) and fqn in quantized_set quantize_(model, Int8WeightOnlyConfig(), filter_fn=filter_fn) logging.info(f"Applied torchao INT8 to {len(quantized_set)} layers") return model