from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable

import torch


@dataclass
class TensorState:
    """Size/dtype summary of one named tensor (parameter or buffer)."""

    name: str
    shape: tuple[int, ...]
    dtype: str  # dtype name; `_tensor_state` strips the "torch." prefix
    bytes: int  # numel * element_size of the tensor
    trainable: bool = False


@dataclass
class TernaryAudit:
    """Byte accounting produced by :func:`audit_model`.

    The ``ternary_*`` counters are sums over every module that exposes both
    ``T_packed`` and ``_T_shape``; the three lists describe the model's
    floating-point parameters and buffers.
    """

    logical_ternary_weights: int  # sum of prod(_T_shape) over ternary modules
    ternary_packed_bytes: int  # bytes of the ``T_packed`` tensors
    ternary_scale_bytes: int  # bytes of the ``E`` tensors
    ternary_scale_accum_bytes: int  # bytes of the ``E_accum`` tensors
    ternary_accum_bytes: int  # bytes of the ``T_accum`` tensors
    ternary_corr_accum_bytes: int  # bytes of the ``corr_accum`` tensors
    ternary_step_counter_bytes: int  # bytes of the ``step_counter`` tensors
    trainable_float_params: list[TensorState]
    frozen_float_params: list[TensorState]
    float_buffers: list[TensorState]

    @property
    def ternary_training_bytes(self) -> int:
        """Total bytes across all six ternary state counters."""
        return (
            self.ternary_packed_bytes
            + self.ternary_scale_bytes
            + self.ternary_scale_accum_bytes
            + self.ternary_accum_bytes
            + self.ternary_corr_accum_bytes
            + self.ternary_step_counter_bytes
        )

    @property
    def trainable_float_bytes(self) -> int:
        """Total bytes of the trainable floating-point parameters."""
        return sum(item.bytes for item in self.trainable_float_params)

    @property
    def frozen_float_bytes(self) -> int:
        """Total bytes of the frozen floating-point parameters."""
        return sum(item.bytes for item in self.frozen_float_params)

    @property
    def float_buffer_bytes(self) -> int:
        """Total bytes of the floating-point buffers."""
        return sum(item.bytes for item in self.float_buffers)


def _tensor_bytes(t: torch.Tensor) -> int:
    """Return the storage size of *t* in bytes (element count * element size)."""
    return t.numel() * t.element_size()


def _tensor_state(name: str, t: torch.Tensor, trainable: bool = False) -> TensorState:
    """Build a :class:`TensorState` record describing *t* under *name*."""
    return TensorState(
        name=name,
        shape=tuple(t.shape),
        dtype=str(t.dtype).replace("torch.", ""),
        bytes=_tensor_bytes(t),
        trainable=trainable,
    )


def _mb(n_bytes: int) -> float:
    """Convert a byte count to mebibytes (1024 * 1024 bytes)."""
    return n_bytes / (1024 * 1024)


def _attr_tensor_bytes(module: torch.nn.Module, attr: str) -> int:
    """Return the byte size of ``module.<attr>``, or 0 if it is not a tensor.

    An attribute can exist yet hold ``None`` (for example a buffer registered
    with ``register_buffer(name, None)``); that must count as zero bytes
    rather than raise.
    """
    value = getattr(module, attr, None)
    if isinstance(value, torch.Tensor):
        return _tensor_bytes(value)
    return 0


def audit_model(model: torch.nn.Module) -> TernaryAudit:
    """Collect ternary-state and floating-point memory usage of *model*.

    A module is counted as ternary when it has both ``T_packed`` and
    ``_T_shape`` set (not ``None``). ``_T_shape`` must support ``.tolist()``
    yielding the logical weight dimensions. The optional state tensors ``E``,
    ``E_accum``, ``T_accum``, ``corr_accum`` and ``step_counter`` contribute
    their byte size when present and contribute nothing when missing or
    ``None``.

    Parameters and buffers are then listed independently of the ternary
    scan: floating-point parameters are split by ``requires_grad``, and every
    floating-point buffer is recorded. Note that a ternary state tensor that
    is itself a registered floating-point buffer is therefore reported both
    in its ``ternary_*`` counter and in ``float_buffers``.

    Args:
        model: Module tree to inspect; it is not modified.

    Returns:
        A populated :class:`TernaryAudit`.
    """
    logical_ternary_weights = 0
    ternary_packed_bytes = 0
    ternary_scale_bytes = 0
    ternary_scale_accum_bytes = 0
    ternary_accum_bytes = 0
    ternary_corr_accum_bytes = 0
    ternary_step_counter_bytes = 0

    for module in model.modules():
        packed = getattr(module, "T_packed", None)
        shape_tensor = getattr(module, "_T_shape", None)
        # Skip non-ternary modules and ones whose packed state is an unset
        # (None) placeholder.
        if packed is None or shape_tensor is None:
            continue

        n_weights = 1
        for dim in shape_tensor.tolist():
            n_weights *= int(dim)
        logical_ternary_weights += n_weights

        ternary_packed_bytes += _tensor_bytes(packed)
        ternary_scale_bytes += _attr_tensor_bytes(module, "E")
        ternary_scale_accum_bytes += _attr_tensor_bytes(module, "E_accum")
        ternary_accum_bytes += _attr_tensor_bytes(module, "T_accum")
        ternary_corr_accum_bytes += _attr_tensor_bytes(module, "corr_accum")
        ternary_step_counter_bytes += _attr_tensor_bytes(module, "step_counter")

    trainable_float_params: list[TensorState] = []
    frozen_float_params: list[TensorState] = []
    for name, param in model.named_parameters():
        if not param.dtype.is_floating_point:
            continue
        state = _tensor_state(name, param, trainable=param.requires_grad)
        if param.requires_grad:
            trainable_float_params.append(state)
        else:
            frozen_float_params.append(state)

    float_buffers = [
        _tensor_state(name, buf)
        for name, buf in model.named_buffers()
        if buf.dtype.is_floating_point
    ]

    return TernaryAudit(
        logical_ternary_weights=logical_ternary_weights,
        ternary_packed_bytes=ternary_packed_bytes,
        ternary_scale_bytes=ternary_scale_bytes,
        ternary_scale_accum_bytes=ternary_scale_accum_bytes,
        ternary_accum_bytes=ternary_accum_bytes,
        ternary_corr_accum_bytes=ternary_corr_accum_bytes,
        ternary_step_counter_bytes=ternary_step_counter_bytes,
        trainable_float_params=trainable_float_params,
        frozen_float_params=frozen_float_params,
        float_buffers=float_buffers,
    )


def _largest_lines(title: str, items: list[TensorState], limit: int) -> list[str]:
    """Return *title* plus one line per item for the ``items[:limit]`` largest.

    Items are ordered by descending byte size. An empty *items* list yields
    no lines at all (not even the title).
    """
    if not items:
        return []
    ranked = sorted(items, key=lambda x: x.bytes, reverse=True)[:limit]
    return [title] + [
        f" {item.name}: {item.shape} {item.dtype} {_mb(item.bytes):.2f} MB"
        for item in ranked
    ]


def format_audit(audit: TernaryAudit, limit: int = 12) -> str:
    """Render *audit* as a multi-line, human-readable report.

    Args:
        audit: Result of :func:`audit_model`.
        limit: Slice bound applied to the size-sorted "largest ..." listings.

    Returns:
        The report lines joined with newlines (no trailing newline).
    """
    lines = [
        "Ternary state audit:",
        f" logical ternary weights: {audit.logical_ternary_weights:,}",
        (
            " ternary training state: "
            f"{_mb(audit.ternary_training_bytes):.2f} MB "
            f"(T={_mb(audit.ternary_packed_bytes):.2f}, "
            f"E={_mb(audit.ternary_scale_bytes):.2f}, "
            f"E_accum={_mb(audit.ternary_scale_accum_bytes):.2f}, "
            f"T_accum={_mb(audit.ternary_accum_bytes):.2f}, "
            f"corr_accum={_mb(audit.ternary_corr_accum_bytes):.2f}, "
            f"steps={_mb(audit.ternary_step_counter_bytes):.4f})"
        ),
        (
            " trainable float params: "
            f"{len(audit.trainable_float_params)} tensors, "
            f"{_mb(audit.trainable_float_bytes):.2f} MB"
        ),
        (
            " frozen float params: "
            f"{len(audit.frozen_float_params)} tensors, "
            f"{_mb(audit.frozen_float_bytes):.2f} MB"
        ),
        (
            " float buffers: "
            f"{len(audit.float_buffers)} tensors, "
            f"{_mb(audit.float_buffer_bytes):.2f} MB"
        ),
    ]
    lines.extend(
        _largest_lines(
            " largest trainable float params:", audit.trainable_float_params, limit
        )
    )
    lines.extend(_largest_lines(" largest float buffers:", audit.float_buffers, limit))
    return "\n".join(lines)


def freeze_float_parameters(
    model: torch.nn.Module,
    allow_prefixes: Iterable[str] = (),
) -> list[TensorState]:
    """Disable gradients on trainable floating-point parameters of *model*.

    Args:
        model: Module whose parameters are modified in place.
        allow_prefixes: Parameter-name prefixes that are left untouched. A
            single string is accepted and treated as one prefix (not as an
            iterable of characters); an empty string or empty iterable means
            no parameter is exempt.

    Returns:
        One :class:`TensorState` per parameter that was frozen, recorded with
        ``trainable=True`` to reflect its state before freezing.
    """
    if isinstance(allow_prefixes, str):
        # tuple("head") would become ('h', 'e', 'a', 'd') and exempt every
        # parameter starting with any of those characters.
        allow = (allow_prefixes,) if allow_prefixes else ()
    else:
        allow = tuple(allow_prefixes)

    frozen: list[TensorState] = []
    for name, param in model.named_parameters():
        if allow and name.startswith(allow):
            continue
        if param.dtype.is_floating_point and param.requires_grad:
            frozen.append(_tensor_state(name, param, trainable=True))
            param.requires_grad_(False)
    return frozen


def trainable_parameters(model: torch.nn.Module) -> list[torch.nn.Parameter]:
    """Return the parameters of *model* that currently require gradients."""
    return [p for p in model.parameters() if p.requires_grad]