ARBS / arbitor /kernel /ternary_audit.py
CLIWorks's picture
Upload folder using huggingface_hub
d8bc908 verified
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable
import torch
@dataclass
class TensorState:
    """Lightweight description of one named tensor.

    Records identifying metadata and the tensor's size in bytes, plus a
    flag saying whether the tensor was trainable when it was recorded.
    """

    # Name under which the tensor was reported.
    name: str
    # Tensor dimensions as a plain tuple.
    shape: tuple[int, ...]
    # Dtype rendered as a string.
    dtype: str
    # Size of the tensor's elements in bytes.
    bytes: int
    # True when the tensor required gradients at recording time.
    trainable: bool = False
@dataclass
class TernaryAudit:
    """Result of auditing a model's ternary state and float tensors.

    The integer fields are totals accumulated over the audited modules.
    The list fields hold one ``TensorState`` entry per floating-point
    parameter or buffer. Derived byte totals are exposed as read-only
    properties.
    """

    logical_ternary_weights: int
    ternary_packed_bytes: int
    ternary_scale_bytes: int
    ternary_scale_accum_bytes: int
    ternary_accum_bytes: int
    ternary_corr_accum_bytes: int
    ternary_step_counter_bytes: int
    trainable_float_params: list[TensorState]
    frozen_float_params: list[TensorState]
    float_buffers: list[TensorState]

    @staticmethod
    def _total_bytes(states) -> int:
        """Add up the ``bytes`` attribute of every entry in ``states``."""
        running = 0
        for state in states:
            running += state.bytes
        return running

    @property
    def ternary_training_bytes(self) -> int:
        """Sum of all six ternary byte counters."""
        counters = (
            self.ternary_packed_bytes,
            self.ternary_scale_bytes,
            self.ternary_scale_accum_bytes,
            self.ternary_accum_bytes,
            self.ternary_corr_accum_bytes,
            self.ternary_step_counter_bytes,
        )
        return sum(counters)

    @property
    def trainable_float_bytes(self) -> int:
        """Total bytes of the recorded trainable float parameters."""
        return self._total_bytes(self.trainable_float_params)

    @property
    def frozen_float_bytes(self) -> int:
        """Total bytes of the recorded frozen float parameters."""
        return self._total_bytes(self.frozen_float_params)

    @property
    def float_buffer_bytes(self) -> int:
        """Total bytes of the recorded float buffers."""
        return self._total_bytes(self.float_buffers)
def _tensor_bytes(t: torch.Tensor) -> int:
return t.numel() * t.element_size()
def _tensor_state(name: str, t: torch.Tensor, trainable: bool = False) -> TensorState:
    """Build a ``TensorState`` record describing tensor ``t``.

    Args:
        name: Label stored on the record.
        t: Tensor whose shape, dtype and byte size are captured.
        trainable: Flag stored on the record as-is.

    Returns:
        A ``TensorState`` with the shape as a tuple and the dtype rendered
        as a string with the ``"torch."`` text removed.
    """
    dims = tuple(t.shape)
    # str(torch.float32) is "torch.float32"; drop the namespace text.
    dtype_label = str(t.dtype).replace("torch.", "")
    size_in_bytes = _tensor_bytes(t)
    return TensorState(
        name=name,
        shape=dims,
        dtype=dtype_label,
        bytes=size_in_bytes,
        trainable=trainable,
    )
def _mb(n_bytes: int) -> float:
return n_bytes / (1024 * 1024)
def audit_model(model: torch.nn.Module) -> TernaryAudit:
    """Tally ternary-state bytes and float tensor summaries for ``model``.

    A submodule is treated as ternary when it exposes both ``T_packed``
    and ``_T_shape``. For each such submodule the logical weight count
    (the product of the entries of ``_T_shape``) and the byte size of
    ``T_packed`` are accumulated, along with the byte sizes of whichever
    of ``E``, ``E_accum``, ``T_accum``, ``corr_accum`` and
    ``step_counter`` the submodule has.

    Floating-point parameters are then split into trainable and frozen
    lists by ``requires_grad``, and floating-point buffers are listed
    separately.

    Args:
        model: Module whose submodules, parameters and buffers are walked.

    Returns:
        A ``TernaryAudit`` holding the accumulated totals and tensor lists.
    """
    # NOTE(review): the source this was recovered from had lost its
    # indentation. The optional per-module state checks are assumed to be
    # nested inside the ternary-module guard -- confirm against upstream.
    optional_state = ("E", "E_accum", "T_accum", "corr_accum", "step_counter")
    state_bytes = dict.fromkeys(optional_state, 0)
    weight_total = 0
    packed_total = 0

    for submodule in model.modules():
        is_ternary = hasattr(submodule, "T_packed") and hasattr(submodule, "_T_shape")
        if not is_ternary:
            continue

        dims = tuple(int(extent) for extent in submodule._T_shape.tolist())
        count = 1
        for extent in dims:
            count *= extent
        weight_total += count
        packed_total += _tensor_bytes(submodule.T_packed)

        for attr in optional_state:
            if hasattr(submodule, attr):
                state_bytes[attr] += _tensor_bytes(getattr(submodule, attr))

    trainable_states: list[TensorState] = []
    frozen_states: list[TensorState] = []
    for param_name, param in model.named_parameters():
        if not param.dtype.is_floating_point:
            continue
        needs_grad = param.requires_grad
        record = _tensor_state(param_name, param, trainable=needs_grad)
        bucket = trainable_states if needs_grad else frozen_states
        bucket.append(record)

    buffer_states: list[TensorState] = []
    for buffer_name, buffer in model.named_buffers():
        if buffer.dtype.is_floating_point:
            buffer_states.append(_tensor_state(buffer_name, buffer))

    return TernaryAudit(
        logical_ternary_weights=weight_total,
        ternary_packed_bytes=packed_total,
        ternary_scale_bytes=state_bytes["E"],
        ternary_scale_accum_bytes=state_bytes["E_accum"],
        ternary_accum_bytes=state_bytes["T_accum"],
        ternary_corr_accum_bytes=state_bytes["corr_accum"],
        ternary_step_counter_bytes=state_bytes["step_counter"],
        trainable_float_params=trainable_states,
        frozen_float_params=frozen_states,
        float_buffers=buffer_states,
    )
def format_audit(audit: TernaryAudit, limit: int = 12) -> str:
    """Render ``audit`` as a multi-line text report.

    The report starts with the logical weight count and a breakdown of the
    ternary training state in MB. It then gives tensor counts and sizes for
    trainable params, frozen params and float buffers. When present, the
    trainable params and the float buffers are each followed by a listing
    sorted by size, largest first, truncated with ``[:limit]``.

    Args:
        audit: Audit result to format.
        limit: Slice bound applied to each size-sorted listing.

    Returns:
        The report lines joined with newlines.
    """

    def size_of(entry):
        return entry.bytes

    def describe(entry):
        return f" {entry.name}: {entry.shape} {entry.dtype} {_mb(entry.bytes):.2f} MB"

    training_line = (
        " ternary training state: "
        f"{_mb(audit.ternary_training_bytes):.2f} MB "
        f"(T={_mb(audit.ternary_packed_bytes):.2f}, "
        f"E={_mb(audit.ternary_scale_bytes):.2f}, "
        f"E_accum={_mb(audit.ternary_scale_accum_bytes):.2f}, "
        f"T_accum={_mb(audit.ternary_accum_bytes):.2f}, "
        f"corr_accum={_mb(audit.ternary_corr_accum_bytes):.2f}, "
        f"steps={_mb(audit.ternary_step_counter_bytes):.4f})"
    )

    report = [
        "Ternary state audit:",
        f" logical ternary weights: {audit.logical_ternary_weights:,}",
        training_line,
    ]

    # One summary row per tensor category: label, entries, total bytes.
    summary_rows = (
        (" trainable float params: ", audit.trainable_float_params, audit.trainable_float_bytes),
        (" frozen float params: ", audit.frozen_float_params, audit.frozen_float_bytes),
        (" float buffers: ", audit.float_buffers, audit.float_buffer_bytes),
    )
    for label, states, total in summary_rows:
        report.append(f"{label}{len(states)} tensors, {_mb(total):.2f} MB")

    # Frozen params get a summary row only; they have no ranked listing.
    ranked_sections = (
        (" largest trainable float params:", audit.trainable_float_params),
        (" largest float buffers:", audit.float_buffers),
    )
    for heading, states in ranked_sections:
        if not states:
            continue
        report.append(heading)
        biggest_first = sorted(states, key=size_of, reverse=True)
        report.extend(describe(entry) for entry in biggest_first[:limit])

    return "\n".join(report)
def freeze_float_parameters(
    model: torch.nn.Module,
    allow_prefixes: Iterable[str] = (),
) -> list[TensorState]:
    """Turn off gradients for floating-point parameters outside an allow-list.

    Every parameter from ``model.named_parameters()`` whose name starts
    with one of ``allow_prefixes`` is left untouched. Each remaining
    parameter that is floating point and currently requires gradients is
    recorded and then switched to ``requires_grad=False`` in place.

    Args:
        model: Module whose parameters are modified in place.
        allow_prefixes: Name prefixes of parameters to leave alone. A
            single string is treated as one prefix rather than being
            split into its characters.

    Returns:
        One ``TensorState`` per parameter that was frozen. Each is captured
        before the change, so it is marked ``trainable=True``.
    """
    # A bare string is itself an iterable of one-character strings:
    # tuple("embed") would yield ('e', 'm', 'b', 'e', 'd') and exempt any
    # parameter whose name merely starts with one of those letters.
    if isinstance(allow_prefixes, str):
        allow = (allow_prefixes,)
    else:
        allow = tuple(allow_prefixes)

    frozen: list[TensorState] = []
    for name, param in model.named_parameters():
        if allow and name.startswith(allow):
            continue
        if param.dtype.is_floating_point and param.requires_grad:
            # Record the state first so the entry reflects the pre-freeze flag.
            frozen.append(_tensor_state(name, param, trainable=True))
            param.requires_grad_(False)
    return frozen
def trainable_parameters(model: torch.nn.Module) -> list[torch.nn.Parameter]:
    """Return the parameters of ``model`` that currently require gradients.

    Parameters are returned in the order ``model.parameters()`` yields them.
    """
    return list(filter(lambda param: param.requires_grad, model.parameters()))