# File: dkm-compression/dkm/utils.py
# Last commit 8a77fae (verified), by syedmohaiminulhoque:
#   "Add utils, training script, and tests"
"""
Utility functions for DKM compression analysis.
"""
import torch
import torch.nn as nn
import math
from typing import Dict, Optional
def compute_model_size(model: nn.Module, bits_per_param: float = 32.0) -> float:
    """
    Estimate how much storage a model's parameters occupy, in MB.

    Only tensors yielded by ``model.parameters()`` are counted.

    Args:
        model: PyTorch model whose parameters are measured.
        bits_per_param: Storage width of a single parameter in bits
            (32 for float32, 16 for float16).

    Returns:
        Parameter storage size in MB, where 1 MB = 1024 * 1024 bytes.
    """
    n_elements = 0
    for param in model.parameters():
        n_elements += param.numel()

    # bits -> bytes -> KB -> MB
    return n_elements * bits_per_param / 8 / 1024 / 1024
def compute_compression_ratio(
    original_model: nn.Module,
    compressed_info: Dict,
) -> float:
    """
    Return how many times smaller the compressed model is than the original.

    The original size comes from ``compute_model_size`` with its default
    bits-per-parameter; the compressed size is read from
    ``compressed_info["compressed_size_mb"]``.

    Args:
        original_model: Original uncompressed model.
        compressed_info: Info dict from DKMCompressor.get_compression_info();
            must contain the key ``"compressed_size_mb"``.

    Returns:
        original_size / compressed_size, both in MB.
    """
    baseline_mb = compute_model_size(original_model)
    # Floor the denominator so a reported size of zero cannot divide by zero.
    denominator = max(compressed_info["compressed_size_mb"], 1e-10)
    return baseline_mb / denominator
def compute_effective_bpw(n_clusters: int, dim: int = 1) -> float:
    """
    Return the effective bits-per-weight of a DKM configuration.

    Following Section 3.3, one cluster index costs log2(n_clusters) bits and
    is shared by the ``dim`` weights of a clustered vector:

        bits_per_weight = log2(n_clusters) / dim

    Examples (index bits = log2(n_clusters)):
        - 4 bits (16 clusters),  dim=4: 4/4 = 1 bpw
        - 8 bits (256 clusters), dim=8: 8/8 = 1 bpw
        - 4 bits (16 clusters),  dim=8: 4/8 = 0.5 bpw

    Args:
        n_clusters: Number of clusters (centroids).
        dim: Number of weights grouped into one clustered vector.

    Returns:
        Bits per individual weight.
    """
    return math.log2(n_clusters) / dim
def count_unique_weights(model: nn.Module) -> Dict[str, int]:
    """
    Report the number of distinct weight values in each Conv2d/Linear layer.

    Useful for verifying compression: after DKM compression + snapping, each
    layer should have at most k unique weight values (or k unique
    d-dimensional vectors). Note that this function counts distinct scalar
    values in the weight tensor; it does not group weights into vectors.

    Args:
        model: PyTorch model to inspect.

    Returns:
        Mapping from the module name given by ``model.named_modules()`` to
        the number of unique values in that module's ``weight``. Modules
        that are not ``nn.Conv2d`` or ``nn.Linear`` are omitted.
    """
    inspected_types = (nn.Conv2d, nn.Linear)
    counts = {}
    with torch.no_grad():
        for layer_name, layer in model.named_modules():
            if not isinstance(layer, inspected_types):
                continue
            distinct = torch.unique(layer.weight.data)
            counts[layer_name] = len(distinct)
    return counts
def print_compression_summary(info: Dict):
    """
    Pretty print compression information to stdout.

    Prints a header with model-wide totals followed by one table row per
    entry of ``info["per_layer"]``.

    Args:
        info: Dict with the keys ``total_params``, ``compressed_params``,
            ``original_size_mb``, ``compressed_size_mb``,
            ``compression_ratio`` and ``per_layer``. ``per_layer`` maps a
            layer name to a dict with ``n_params``, ``bits``, ``dim``,
            ``bits_per_weight`` and ``compression_ratio``.
    """
    heavy_rule = "=" * 70
    light_rule = "-" * 70

    # Model-wide totals.
    print(heavy_rule)
    print("DKM Compression Summary")
    print(heavy_rule)
    print(f"Total parameters: {info['total_params']:,}")
    print(f"Compressed parameters: {info['compressed_params']:,}")
    print(f"Original size: {info['original_size_mb']:.2f} MB")
    print(f"Compressed size: {info['compressed_size_mb']:.2f} MB")
    print(f"Compression ratio: {info['compression_ratio']:.1f}x")

    # Per-layer table.
    print(light_rule)
    print(f"{'Layer':<40} {'Params':>10} {'Bits':>5} {'Dim':>4} {'BPW':>6} {'CR':>6}")
    print(light_rule)
    for layer_name, stats in info["per_layer"].items():
        row = (
            f"{layer_name:<40} {stats['n_params']:>10,} "
            f"{stats['bits']:>5.0f} {stats['dim']:>4} "
            f"{stats['bits_per_weight']:>6.2f} "
            f"{stats['compression_ratio']:>6.1f}x"
        )
        print(row)
    print(heavy_rule)