File size: 3,593 Bytes
dac920b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""Load pre-quantized INT8 DramaBox DiT weights from safetensors.

Quantized layers are stored as:
  {layer_name}.weight.__int_data  (INT8 tensor)
  {layer_name}.weight.__scale     (BF16 per-channel scale)

Non-quantized layers are stored as plain BF16 tensors.

Usage:
    from load_int8 import load_int8_dit
    load_int8_dit(tts._velocity_model, "dramabox-dit-int8-selective.safetensors")
"""
import json
import logging
import os

import torch
from safetensors.torch import load_file

# NOTE(review): this configures the *root* logger as a side effect of importing
# the module, which also affects any application that imports it. Consider
# moving it under an `if __name__ == "__main__":` guard, or leaving logging
# setup to the caller.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)


def _dequantize_int8(int_data: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Return ``int_data * scale`` as BF16, broadcasting ``scale`` along dim 0."""
    # Reshape the scale to (N, 1, ..., 1) so it lines up with the leading
    # (output-channel) axis of ``int_data`` whatever its rank. This accepts
    # 1-D (N,), column (N, 1) and scalar scales alike, where a plain
    # ``unsqueeze(1)`` only handled a 1-D scale against a 2-D weight.
    scale = scale.reshape(-1, *([1] * (int_data.dim() - 1)))
    return (int_data.float() * scale.float()).to(torch.bfloat16)


def load_int8_dit(
    model: torch.nn.Module,
    safetensors_path: str,
    config_path: str | None = None,
    device: str = "cuda",
) -> torch.nn.Module:
    """Replace model weights with pre-quantized INT8 weights from safetensors.

    For quantized layers, reconstructs the dequantized BF16 weight from
    ``int_data * scale`` (equivalent to what torchao does at runtime, but
    without needing torchao installed for loading). The scale is broadcast
    along the leading axis (dim 0) of the INT8 tensor.

    For runtime INT8 inference (keeping weights in INT8 and dequantizing during
    matmul), use the torchao approach instead — see README.md Option 1.

    Args:
        model: Module whose parameters/buffers are overwritten in place.
        safetensors_path: Path to the safetensors checkpoint.
        config_path: Accepted for backward compatibility; this function does
            not read it.
        device: Device the checkpoint tensors are loaded onto.

    Returns:
        The same ``model`` instance after ``load_state_dict``.

    Raises:
        KeyError: If a layer has an ``__int_data`` tensor but no matching
            ``__scale`` tensor.
    """
    tensors = load_file(safetensors_path, device=device)
    logging.info("Loaded %d tensors from %s", len(tensors), safetensors_path)

    int_data_suffix = ".weight.__int_data"
    scale_suffix = ".weight.__scale"

    # Layer prefixes that have an INT8 payload in the checkpoint.
    quantized_names = {
        key[: -len(int_data_suffix)]
        for key in tensors
        if key.endswith(int_data_suffix)
    }

    updates: dict[str, torch.Tensor] = {}
    dequantized = 0
    skipped: list[str] = []

    for key in list(model.state_dict()):
        # "a.b.weight" -> ("a.b", "weight"); a dot-less key -> ("", key).
        layer_name, _, param_name = key.rpartition(".")

        if param_name == "weight" and layer_name in quantized_names:
            scale_key = f"{layer_name}{scale_suffix}"
            if scale_key not in tensors:
                raise KeyError(
                    f"Missing scale tensor {scale_key!r} for INT8 layer {layer_name!r}"
                )
            updates[key] = _dequantize_int8(
                tensors[f"{layer_name}{int_data_suffix}"], tensors[scale_key]
            )
            dequantized += 1
        elif key in tensors:
            updates[key] = tensors[key]
        else:
            skipped.append(key)

    # Only the keys found in the checkpoint are loaded; strict=False lets the
    # remaining state_dict entries keep their current values.
    model.load_state_dict(updates, strict=False)

    if not updates:
        logging.warning(
            "No checkpoint tensors matched the model's state_dict keys; "
            "check that %s was exported from this model",
            safetensors_path,
        )
    if skipped:
        logging.debug("State-dict keys with no checkpoint tensor: %s", skipped)
    logging.info(
        "Loaded %d params (%d dequantized from INT8), skipped %d",
        len(updates),
        dequantized,
        len(skipped),
    )
    return model


def load_int8_dit_torchao(
    model: torch.nn.Module,
    safetensors_path: str,
    device: str = "cuda",
    config_path: str | None = None,
) -> torch.nn.Module:
    """Load INT8 weights and apply torchao quantization for runtime INT8 matmul.

    This keeps weights in INT8 during inference (lower VRAM) but requires torchao.

    The weights are first loaded (dequantized to BF16) via ``load_int8_dit``.
    Then torchao's ``Int8WeightOnlyConfig`` is applied to every ``nn.Linear``
    whose fully-qualified module name appears in the config's
    ``quantized_layers`` list.

    Args:
        model: Module to load and quantize in place.
        safetensors_path: Path to the safetensors checkpoint.
        device: Device the checkpoint tensors are loaded onto.
        config_path: JSON file containing a ``quantized_layers`` list. Defaults
            to ``config.json`` in the same directory as ``safetensors_path``.

    Returns:
        The same ``model`` instance, quantized in place.

    Raises:
        ImportError: If torchao is not installed.
        OSError: If the config file cannot be opened.
        KeyError: If the config has no ``quantized_layers`` entry.
    """
    from torchao.quantization import quantize_, Int8WeightOnlyConfig

    if config_path is None:
        config_path = os.path.join(os.path.dirname(safetensors_path), "config.json")

    # Read the config before the (expensive) weight load so that a missing or
    # malformed config fails fast and leaves the model untouched.
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)
    quantized_set = set(config["quantized_layers"])

    load_int8_dit(model, safetensors_path, config_path=config_path, device=device)

    def filter_fn(mod, fqn):
        return isinstance(mod, torch.nn.Linear) and fqn in quantized_set

    # Count the layers that will actually be quantized, so the log reflects
    # reality and mismatched config entries are surfaced instead of ignored.
    matched = {fqn for fqn, mod in model.named_modules() if filter_fn(mod, fqn)}
    unmatched = quantized_set - matched
    if unmatched:
        logging.warning(
            "%d layers listed in %s are not nn.Linear modules of the model "
            "(first few: %s)",
            len(unmatched),
            config_path,
            sorted(unmatched)[:5],
        )

    quantize_(model, Int8WeightOnlyConfig(), filter_fn=filter_fn)
    logging.info("Applied torchao INT8 to %d layers", len(matched))
    return model