"""
Utility functions for DKM compression analysis.
"""
import torch
import torch.nn as nn
import math
from typing import Dict, Optional
def compute_model_size(model: nn.Module, bits_per_param: float = 32.0) -> float:
    """
    Compute the storage size of a model's parameters in megabytes.

    Args:
        model: PyTorch model whose parameters are counted
        bits_per_param: Bits per parameter (32 for float32, 16 for float16)

    Returns:
        Parameter storage size in MB
    """
    n_params = sum(p.numel() for p in model.parameters())
    # bits -> bytes (/8) -> KiB (/1024) -> MiB (/1024)
    return (n_params * bits_per_param) / (8 * 1024 * 1024)
def compute_compression_ratio(
    original_model: nn.Module,
    compressed_info: Dict,
) -> float:
    """
    Compute compression ratio between original and compressed model.

    Args:
        original_model: Original uncompressed model
        compressed_info: Info dict from DKMCompressor.get_compression_info()

    Returns:
        Compression ratio (original_size / compressed_size)
    """
    uncompressed_mb = compute_model_size(original_model)
    # Clamp the denominator to avoid division by zero on a degenerate size.
    compressed_mb = max(compressed_info["compressed_size_mb"], 1e-10)
    return uncompressed_mb / compressed_mb
def compute_effective_bpw(n_clusters: int, dim: int = 1) -> float:
    """
    Compute effective bits-per-weight for a DKM configuration.

    Following Section 3.3:
        bits_per_weight = log2(n_clusters) / dim

    For example:
        - 16 clusters (4 bits), dim=4: 4/4 = 1 bpw
        - 256 clusters (8 bits), dim=8: 8/8 = 1 bpw
        - 16 clusters (4 bits), dim=8: 4/8 = 0.5 bpw

    Args:
        n_clusters: Number of codebook clusters (k); must be >= 1
        dim: Dimension (d) of each codebook vector; must be >= 1

    Returns:
        Effective bits per weight

    Raises:
        ValueError: If n_clusters or dim is not a positive value.
    """
    # Validate up front so callers get a clear message instead of an opaque
    # "math domain error" (n_clusters <= 0) or ZeroDivisionError (dim == 0).
    if n_clusters < 1:
        raise ValueError(f"n_clusters must be >= 1, got {n_clusters}")
    if dim < 1:
        raise ValueError(f"dim must be >= 1, got {dim}")
    return math.log2(n_clusters) / dim
def count_unique_weights(model: nn.Module) -> Dict[str, int]:
    """
    Count unique weight values per layer (useful for verifying compression).

    After DKM compression + snapping, each layer should have at most
    k unique weight values (or k unique d-dimensional vectors).

    Args:
        model: Model whose Conv2d/Linear layers are inspected

    Returns:
        Mapping from layer name to its number of unique scalar weights
    """
    counts: Dict[str, int] = {}
    with torch.no_grad():
        for layer_name, layer in model.named_modules():
            if not isinstance(layer, (nn.Conv2d, nn.Linear)):
                continue
            counts[layer_name] = int(torch.unique(layer.weight.data).numel())
    return counts
def print_compression_summary(info: Dict):
    """Pretty print compression information."""
    heavy_rule = "=" * 70
    light_rule = "-" * 70

    print(heavy_rule)
    print("DKM Compression Summary")
    print(heavy_rule)
    print(f"Total parameters: {info['total_params']:,}")
    print(f"Compressed parameters: {info['compressed_params']:,}")
    print(f"Original size: {info['original_size_mb']:.2f} MB")
    print(f"Compressed size: {info['compressed_size_mb']:.2f} MB")
    print(f"Compression ratio: {info['compression_ratio']:.1f}x")
    print(light_rule)
    # Per-layer breakdown table.
    print(f"{'Layer':<40} {'Params':>10} {'Bits':>5} {'Dim':>4} {'BPW':>6} {'CR':>6}")
    print(light_rule)
    for layer_name, stats in info["per_layer"].items():
        row = (
            f"{layer_name:<40} {stats['n_params']:>10,} "
            f"{stats['bits']:>5.0f} {stats['dim']:>4} "
            f"{stats['bits_per_weight']:>6.2f} "
            f"{stats['compression_ratio']:>6.1f}x"
        )
        print(row)
    print(heavy_rule)
|