File size: 6,591 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable

import torch


@dataclass
class TensorState:
    """Tensor-free summary of one named tensor (name, shape, dtype, size).

    Instances are produced by ``_tensor_state`` from a live ``torch.Tensor``
    so that reports can be kept without holding on to the tensor itself.
    """

    name: str  # name the tensor was reported under
    shape: tuple[int, ...]  # ``tuple(t.shape)`` of the source tensor
    dtype: str  # ``str(t.dtype)`` with the "torch." prefix removed
    bytes: int  # ``numel() * element_size()`` of the source tensor
    trainable: bool = False  # flag supplied by the caller; False unless set


@dataclass
class TernaryAudit:
    """Byte accounting for a model's ternary state and floating-point tensors.

    The integer fields are running totals (in bytes, except
    ``logical_ternary_weights`` which is a weight count); the list fields hold
    one ``TensorState`` per tensor.  The properties below derive aggregate
    byte counts from those raw fields.
    """

    # Number of logical ternary weights (product of each module's ``_T_shape``).
    logical_ternary_weights: int
    # Bytes held by ``T_packed`` tensors.
    ternary_packed_bytes: int
    # Bytes held by ``E`` tensors.
    ternary_scale_bytes: int
    # Bytes held by ``E_accum`` tensors.
    ternary_scale_accum_bytes: int
    # Bytes held by ``T_accum`` tensors.
    ternary_accum_bytes: int
    # Bytes held by ``corr_accum`` tensors.
    ternary_corr_accum_bytes: int
    # Bytes held by ``step_counter`` tensors.
    ternary_step_counter_bytes: int
    # Floating-point parameters with ``requires_grad`` set.
    trainable_float_params: list[TensorState]
    # Floating-point parameters with ``requires_grad`` cleared.
    frozen_float_params: list[TensorState]
    # Floating-point buffers.
    float_buffers: list[TensorState]

    @property
    def ternary_training_bytes(self) -> int:
        """Total bytes across every ternary byte counter on this audit."""
        components = (
            self.ternary_packed_bytes,
            self.ternary_scale_bytes,
            self.ternary_scale_accum_bytes,
            self.ternary_accum_bytes,
            self.ternary_corr_accum_bytes,
            self.ternary_step_counter_bytes,
        )
        return sum(components)

    @property
    def trainable_float_bytes(self) -> int:
        """Total bytes of the trainable floating-point parameters."""
        sizes = [state.bytes for state in self.trainable_float_params]
        return sum(sizes)

    @property
    def frozen_float_bytes(self) -> int:
        """Total bytes of the frozen floating-point parameters."""
        sizes = [state.bytes for state in self.frozen_float_params]
        return sum(sizes)

    @property
    def float_buffer_bytes(self) -> int:
        """Total bytes of the floating-point buffers."""
        sizes = [state.bytes for state in self.float_buffers]
        return sum(sizes)


def _tensor_bytes(t: torch.Tensor) -> int:
    return t.numel() * t.element_size()


def _tensor_state(name: str, t: torch.Tensor, trainable: bool = False) -> TensorState:
    """Build a ``TensorState`` record describing tensor ``t``.

    Args:
        name: Label to store on the record.
        t: Tensor whose shape, dtype and byte size are captured.
        trainable: Value stored in the record's ``trainable`` field.

    Returns:
        A ``TensorState`` holding plain Python values only.
    """
    dims = tuple(t.shape)
    # str(torch.float32) == "torch.float32"; keep only the short dtype name.
    dtype_label = str(t.dtype).replace("torch.", "")
    size_in_bytes = _tensor_bytes(t)
    return TensorState(
        name=name,
        shape=dims,
        dtype=dtype_label,
        bytes=size_in_bytes,
        trainable=trainable,
    )


def _mb(n_bytes: int) -> float:
    return n_bytes / (1024 * 1024)


def audit_model(model: torch.nn.Module) -> TernaryAudit:
    """Measure ternary state and floating-point tensors held by ``model``.

    A submodule counts as ternary when it has both a ``T_packed`` and a
    ``_T_shape`` attribute.  For each such module the logical weight count is
    the product of the entries of ``_T_shape`` and the packed size is the byte
    size of ``T_packed``.  The attributes ``E``, ``E_accum``, ``T_accum``,
    ``corr_accum`` and ``step_counter`` are optional: each one is added to its
    own byte total when present, and skipped when it is missing or ``None``.

    Floating-point parameters are split by ``requires_grad`` into trainable
    and frozen lists; floating-point buffers are listed separately.  No
    de-duplication is done between the ternary totals and the buffer list, so
    a floating-point tensor that is both ternary state and a registered buffer
    is reported in both places.

    Args:
        model: Module tree to inspect (walked with ``model.modules()``).

    Returns:
        A ``TernaryAudit`` with the collected totals and tensor records.
    """
    # Optional per-module state: attribute name -> TernaryAudit field it feeds.
    optional_state = (
        ("E", "ternary_scale_bytes"),
        ("E_accum", "ternary_scale_accum_bytes"),
        ("T_accum", "ternary_accum_bytes"),
        ("corr_accum", "ternary_corr_accum_bytes"),
        ("step_counter", "ternary_step_counter_bytes"),
    )
    optional_totals = {field_name: 0 for _, field_name in optional_state}
    logical_ternary_weights = 0
    ternary_packed_bytes = 0

    for module in model.modules():
        if not (hasattr(module, "T_packed") and hasattr(module, "_T_shape")):
            continue

        # Logical weight count is the product of the stored shape entries.
        n_weights = 1
        for dim in module._T_shape.tolist():
            n_weights *= int(dim)
        logical_ternary_weights += n_weights
        ternary_packed_bytes += _tensor_bytes(module.T_packed)

        for attr_name, field_name in optional_state:
            tensor = getattr(module, attr_name, None)
            # An attribute can exist yet be None (e.g. a buffer registered as
            # None); treat that the same as the attribute being absent
            # instead of failing on ``None.numel()``.
            if tensor is not None:
                optional_totals[field_name] += _tensor_bytes(tensor)

    trainable_float_params: list[TensorState] = []
    frozen_float_params: list[TensorState] = []
    for name, param in model.named_parameters():
        if not param.dtype.is_floating_point:
            continue
        state = _tensor_state(name, param, trainable=param.requires_grad)
        if param.requires_grad:
            trainable_float_params.append(state)
        else:
            frozen_float_params.append(state)

    float_buffers = [
        _tensor_state(name, buf)
        for name, buf in model.named_buffers()
        if buf.dtype.is_floating_point
    ]

    return TernaryAudit(
        logical_ternary_weights=logical_ternary_weights,
        ternary_packed_bytes=ternary_packed_bytes,
        trainable_float_params=trainable_float_params,
        frozen_float_params=frozen_float_params,
        float_buffers=float_buffers,
        **optional_totals,
    )


def format_audit(audit: TernaryAudit, limit: int = 12) -> str:
    """Render ``audit`` as a multi-line, human-readable text report.

    The report starts with the summary totals (sizes in MB via ``_mb``) and is
    followed by the largest trainable float parameters and the largest float
    buffers, each section appearing only when its list is non-empty.

    Args:
        audit: The audit to describe.
        limit: Maximum number of entries shown in each "largest" section.

    Returns:
        The report lines joined with newlines.
    """

    def describe(entry) -> str:
        # One indented detail row for a single tensor record.
        return f"    {entry.name}: {entry.shape} {entry.dtype} {_mb(entry.bytes):.2f} MB"

    def largest(entries):
        # Biggest tensors first, truncated to the requested number of rows.
        return sorted(entries, key=lambda entry: entry.bytes, reverse=True)[:limit]

    report = ["Ternary state audit:"]
    report.append(f"  logical ternary weights: {audit.logical_ternary_weights:,}")

    total_mb = _mb(audit.ternary_training_bytes)
    packed_mb = _mb(audit.ternary_packed_bytes)
    scale_mb = _mb(audit.ternary_scale_bytes)
    scale_accum_mb = _mb(audit.ternary_scale_accum_bytes)
    accum_mb = _mb(audit.ternary_accum_bytes)
    corr_mb = _mb(audit.ternary_corr_accum_bytes)
    steps_mb = _mb(audit.ternary_step_counter_bytes)
    report.append(
        "  ternary training state: "
        f"{total_mb:.2f} MB "
        f"(T={packed_mb:.2f}, "
        f"E={scale_mb:.2f}, "
        f"E_accum={scale_accum_mb:.2f}, "
        f"T_accum={accum_mb:.2f}, "
        f"corr_accum={corr_mb:.2f}, "
        f"steps={steps_mb:.4f})"
    )

    summaries = (
        ("  trainable float params: ", audit.trainable_float_params, audit.trainable_float_bytes),
        ("  frozen float params: ", audit.frozen_float_params, audit.frozen_float_bytes),
        ("  float buffers: ", audit.float_buffers, audit.float_buffer_bytes),
    )
    for label, entries, n_bytes in summaries:
        report.append(f"{label}{len(entries)} tensors, {_mb(n_bytes):.2f} MB")

    rankings = (
        ("  largest trainable float params:", audit.trainable_float_params),
        ("  largest float buffers:", audit.float_buffers),
    )
    for heading, entries in rankings:
        if not entries:
            continue
        report.append(heading)
        report.extend(describe(entry) for entry in largest(entries))

    return "\n".join(report)


def freeze_float_parameters(
    model: torch.nn.Module,
    allow_prefixes: Iterable[str] = (),
) -> list[TensorState]:
    """Turn off gradients for floating-point parameters of ``model``.

    Every parameter that is floating point and currently requires grad has
    ``requires_grad`` set to False, unless its name starts with one of
    ``allow_prefixes``.

    Args:
        model: Module whose parameters are modified in place.
        allow_prefixes: Parameter-name prefixes to leave untouched.  A bare
            string is treated as a single prefix; an empty string or empty
            iterable means no parameter is exempt.

    Returns:
        One ``TensorState`` per parameter frozen by this call, recorded
        before the flag is cleared (so ``trainable`` is True on each record).
    """
    if isinstance(allow_prefixes, str):
        # tuple("embed") would split into single characters and exempt every
        # parameter starting with any of those letters; keep the string whole.
        allow = (allow_prefixes,) if allow_prefixes else ()
    else:
        allow = tuple(allow_prefixes)

    frozen: list[TensorState] = []
    for name, param in model.named_parameters():
        if allow and name.startswith(allow):
            continue
        if param.dtype.is_floating_point and param.requires_grad:
            frozen.append(_tensor_state(name, param, trainable=True))
            param.requires_grad_(False)
    return frozen


def trainable_parameters(model: torch.nn.Module) -> list[torch.nn.Parameter]:
    """Return the parameters of ``model`` that have ``requires_grad`` set.

    Order follows ``model.parameters()``.
    """
    return list(filter(lambda param: param.requires_grad, model.parameters()))