from __future__ import annotations

import math
from dataclasses import dataclass
from typing import Iterable

import torch
|
|
|
|
@dataclass
class TensorState:
    """Name, shape, dtype and byte size of a single parameter or buffer.

    Records are normally produced by ``_tensor_state``, which stores the dtype
    name without its ``torch.`` prefix (e.g. ``"float32"``).
    """

    # Qualified tensor name, e.g. from ``named_parameters()``/``named_buffers()``.
    name: str
    shape: tuple[int, ...]
    dtype: str
    # Logical size: element count times element size. (The field name matches
    # the builtin but is only an instance attribute.)
    bytes: int
    # Trainability flag recorded for the tensor; buffers keep the default False.
    trainable: bool = False
|
|
|
|
@dataclass
class TernaryAudit:
    """Byte accounting for a model's ternary-layer state and float tensors.

    The ``ternary_*`` fields are byte totals for the ternary tensors found on
    ternary layers; the three lists hold one ``TensorState`` per
    floating-point parameter or buffer. Convenience properties aggregate them.
    """

    # Total number of logical (unpacked) ternary weights across all layers.
    logical_ternary_weights: int
    # Byte totals per ternary-state tensor kind.
    ternary_packed_bytes: int
    ternary_scale_bytes: int
    ternary_scale_accum_bytes: int
    ternary_accum_bytes: int
    ternary_corr_accum_bytes: int
    ternary_step_counter_bytes: int
    # Per-tensor records for floating-point parameters and buffers.
    trainable_float_params: list[TensorState]
    frozen_float_params: list[TensorState]
    float_buffers: list[TensorState]

    @staticmethod
    def _sum_bytes(records: list[TensorState]) -> int:
        """Add up the ``bytes`` field of every record (0 for an empty list)."""
        total = 0
        for record in records:
            total += record.bytes
        return total

    @property
    def ternary_training_bytes(self) -> int:
        """Combined bytes of all ternary-state tensor kinds."""
        return sum(
            (
                self.ternary_packed_bytes,
                self.ternary_scale_bytes,
                self.ternary_scale_accum_bytes,
                self.ternary_accum_bytes,
                self.ternary_corr_accum_bytes,
                self.ternary_step_counter_bytes,
            )
        )

    @property
    def trainable_float_bytes(self) -> int:
        """Total bytes of the trainable floating-point parameters."""
        return self._sum_bytes(self.trainable_float_params)

    @property
    def frozen_float_bytes(self) -> int:
        """Total bytes of the frozen floating-point parameters."""
        return self._sum_bytes(self.frozen_float_params)

    @property
    def float_buffer_bytes(self) -> int:
        """Total bytes of the floating-point buffers."""
        return self._sum_bytes(self.float_buffers)
|
|
|
|
def _tensor_bytes(t: torch.Tensor) -> int:
    """Return the logical size of *t* in bytes.

    Computed as element size times element count, so a view is measured by
    its own shape rather than by the storage it may share.
    """
    per_element = t.element_size()
    return per_element * t.numel()
|
|
|
|
def _tensor_state(name: str, t: torch.Tensor, trainable: bool = False) -> TensorState:
    """Summarize tensor *t* under *name* as a ``TensorState`` record.

    The dtype is stored without its ``torch.`` prefix (``torch.float32`` ->
    ``"float32"``) and the size comes from ``_tensor_bytes``.
    """
    dtype_name = str(t.dtype).replace("torch.", "")
    dims = tuple(t.size())
    return TensorState(
        name=name,
        shape=dims,
        dtype=dtype_name,
        bytes=_tensor_bytes(t),
        trainable=trainable,
    )
|
|
|
|
def _mb(n_bytes: int) -> float:
    """Convert a byte count to MB, where 1 MB here is 2**20 bytes."""
    bytes_per_mb = 1 << 20
    return n_bytes / bytes_per_mb
|
|
|
|
def _optional_tensor_bytes(module: torch.nn.Module, attr: str) -> int:
    """Return the byte size of ``module.<attr>`` when it is a tensor, else 0."""
    # hasattr() is also True for buffers registered as None, so check the type.
    value = getattr(module, attr, None)
    return _tensor_bytes(value) if isinstance(value, torch.Tensor) else 0


def _ternary_logical_shape(raw_shape: object) -> tuple[int, ...]:
    """Normalize a ``_T_shape`` value (tensor or int sequence) to a tuple of ints."""
    values = raw_shape.tolist() if isinstance(raw_shape, torch.Tensor) else raw_shape
    return tuple(int(dim) for dim in values)


def audit_model(model: torch.nn.Module) -> TernaryAudit:
    """Measure the ternary-layer state and floating-point tensors of *model*.

    A submodule is treated as a ternary layer when it has a tensor
    ``T_packed`` and a non-None ``_T_shape`` (a tensor or a sequence of ints
    giving the logical weight shape). For each such layer the logical weight
    count and the bytes of ``T_packed`` are accumulated, plus the bytes of
    ``E``, ``E_accum``, ``T_accum``, ``corr_accum`` and ``step_counter`` when
    those attributes exist and are tensors (missing or ``None`` ones count 0).

    Independently, every floating-point parameter (split by
    ``requires_grad``) and every floating-point buffer is recorded.

    Args:
        model: Module to inspect; it is not modified.

    Returns:
        A ``TernaryAudit`` with the per-category byte totals and records.

    Note:
        If the ternary tensors are themselves registered as floating-point
        buffers or parameters, they also appear in the float categories, so
        the ternary and float totals may overlap.
    """
    logical_ternary_weights = 0
    ternary_packed_bytes = 0
    ternary_scale_bytes = 0
    ternary_scale_accum_bytes = 0
    ternary_accum_bytes = 0
    ternary_corr_accum_bytes = 0
    ternary_step_counter_bytes = 0

    for module in model.modules():
        packed = getattr(module, "T_packed", None)
        raw_shape = getattr(module, "_T_shape", None)
        if not isinstance(packed, torch.Tensor) or raw_shape is None:
            continue
        # math.prod(()) == 1, matching a scalar-shaped weight.
        logical_ternary_weights += math.prod(_ternary_logical_shape(raw_shape))
        ternary_packed_bytes += _tensor_bytes(packed)
        ternary_scale_bytes += _optional_tensor_bytes(module, "E")
        ternary_scale_accum_bytes += _optional_tensor_bytes(module, "E_accum")
        ternary_accum_bytes += _optional_tensor_bytes(module, "T_accum")
        ternary_corr_accum_bytes += _optional_tensor_bytes(module, "corr_accum")
        ternary_step_counter_bytes += _optional_tensor_bytes(module, "step_counter")

    trainable_float_params: list[TensorState] = []
    frozen_float_params: list[TensorState] = []
    for name, param in model.named_parameters():
        if not param.dtype.is_floating_point:
            continue
        state = _tensor_state(name, param, trainable=param.requires_grad)
        if param.requires_grad:
            trainable_float_params.append(state)
        else:
            frozen_float_params.append(state)

    float_buffers = [
        _tensor_state(name, buf)
        for name, buf in model.named_buffers()
        if buf.dtype.is_floating_point
    ]

    return TernaryAudit(
        logical_ternary_weights=logical_ternary_weights,
        ternary_packed_bytes=ternary_packed_bytes,
        ternary_scale_bytes=ternary_scale_bytes,
        ternary_scale_accum_bytes=ternary_scale_accum_bytes,
        ternary_accum_bytes=ternary_accum_bytes,
        ternary_corr_accum_bytes=ternary_corr_accum_bytes,
        ternary_step_counter_bytes=ternary_step_counter_bytes,
        trainable_float_params=trainable_float_params,
        frozen_float_params=frozen_float_params,
        float_buffers=float_buffers,
    )
|
|
|
|
def format_audit(audit: TernaryAudit, limit: int = 12) -> str:
    """Render *audit* as a newline-separated, human-readable report.

    Sizes are shown in MB as computed by ``_mb`` (bytes / 1024**2). After the
    summary lines, up to ``limit`` of the largest trainable float parameters
    and up to ``limit`` of the largest float buffers are listed, largest
    first. Each of those sections appears only when it has entries. Frozen
    float parameters appear in the summary counts only.

    Args:
        audit: Audit to render, e.g. as returned by ``audit_model``.
        limit: Maximum number of entries shown per "largest" section.

    Returns:
        The report as a single string.
    """

    def ranked_section(title, records):
        # Omit the whole section, header included, when there is nothing to list.
        if not records:
            return []
        biggest = sorted(records, key=lambda rec: rec.bytes, reverse=True)[:limit]
        return [title] + [
            f" {rec.name}: {rec.shape} {rec.dtype} {_mb(rec.bytes):.2f} MB"
            for rec in biggest
        ]

    breakdown = (
        f"T={_mb(audit.ternary_packed_bytes):.2f}, "
        f"E={_mb(audit.ternary_scale_bytes):.2f}, "
        f"E_accum={_mb(audit.ternary_scale_accum_bytes):.2f}, "
        f"T_accum={_mb(audit.ternary_accum_bytes):.2f}, "
        f"corr_accum={_mb(audit.ternary_corr_accum_bytes):.2f}, "
        f"steps={_mb(audit.ternary_step_counter_bytes):.4f}"
    )
    report = [
        "Ternary state audit:",
        f" logical ternary weights: {audit.logical_ternary_weights:,}",
        f" ternary training state: {_mb(audit.ternary_training_bytes):.2f} MB ({breakdown})",
    ]

    categories = (
        ("trainable float params", audit.trainable_float_params, audit.trainable_float_bytes),
        ("frozen float params", audit.frozen_float_params, audit.frozen_float_bytes),
        ("float buffers", audit.float_buffers, audit.float_buffer_bytes),
    )
    for label, records, total in categories:
        report.append(f" {label}: {len(records)} tensors, {_mb(total):.2f} MB")

    report += ranked_section(" largest trainable float params:", audit.trainable_float_params)
    report += ranked_section(" largest float buffers:", audit.float_buffers)
    return "\n".join(report)
|
|
|
|
def freeze_float_parameters(
    model: torch.nn.Module,
    allow_prefixes: Iterable[str] = (),
) -> list[TensorState]:
    """Disable gradients on trainable floating-point parameters of *model*.

    Parameters whose qualified name starts with any entry of
    *allow_prefixes* are left untouched. Non-floating-point parameters and
    parameters that already have ``requires_grad=False`` are skipped.

    Args:
        model: Module whose parameters are frozen in place.
        allow_prefixes: Name prefixes to keep trainable. A single string is
            treated as one prefix, not as a sequence of one-character prefixes.

    Returns:
        One ``TensorState`` per parameter that was frozen, recorded with
        ``trainable=True`` (its state before freezing).
    """
    # tuple("head.") would be ('h', 'e', 'a', 'd', '.'), which matches almost
    # every name, so a bare string is wrapped instead of iterated.
    if isinstance(allow_prefixes, str):
        allow = (allow_prefixes,)
    else:
        allow = tuple(allow_prefixes)

    frozen: list[TensorState] = []
    for name, param in model.named_parameters():
        if allow and name.startswith(allow):
            continue
        if param.dtype.is_floating_point and param.requires_grad:
            frozen.append(_tensor_state(name, param, trainable=True))
            param.requires_grad_(False)
    return frozen
|
|
|
|
def trainable_parameters(model: torch.nn.Module) -> list[torch.nn.Parameter]:
    """Return the parameters of *model* with ``requires_grad`` set, in
    ``model.parameters()`` order."""
    return list(filter(lambda param: param.requires_grad, model.parameters()))
|
|