Add utils, training script, and tests
Browse files- dkm/utils.py +99 -0
dkm/utils.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions for DKM compression analysis.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import math
|
| 8 |
+
from typing import Dict, Optional
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def compute_model_size(model: nn.Module, bits_per_param: float = 32.0) -> float:
    """
    Return the storage footprint of a model's parameters in MB.

    Only tensors yielded by ``model.parameters()`` are counted.

    Args:
        model: PyTorch model whose parameters are counted
        bits_per_param: Storage width of one parameter in bits
            (32 for float32, 16 for float16)

    Returns:
        Size in MB, where 1 MB = 1024 * 1024 bytes
    """
    n_elements = 0
    for param in model.parameters():
        n_elements += param.numel()

    # bits -> bytes -> KB -> MB
    return n_elements * bits_per_param / 8 / 1024 / 1024
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def compute_compression_ratio(
    original_model: nn.Module,
    compressed_info: Dict,
) -> float:
    """
    Return how many times smaller the compressed model is than the original.

    The original size is obtained from ``compute_model_size`` with its default
    bits-per-parameter; the compressed size is read from the
    ``"compressed_size_mb"`` entry of ``compressed_info``.

    Args:
        original_model: Original uncompressed model
        compressed_info: Info dict from DKMCompressor.get_compression_info();
            must contain the key ``"compressed_size_mb"``

    Returns:
        Compression ratio (original_size / compressed_size)
    """
    numerator_mb = compute_model_size(original_model)
    # Floor the denominator so a zero-sized compressed model cannot
    # trigger a division by zero.
    denominator_mb = max(compressed_info["compressed_size_mb"], 1e-10)
    return numerator_mb / denominator_mb
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def compute_effective_bpw(n_clusters: int, dim: int = 1) -> float:
    """
    Compute effective bits-per-weight for a DKM configuration.

    Following Section 3.3:
        bits_per_weight = log2(n_clusters) / dim

    For example (index bits = log2(n_clusters)):
        - n_clusters=16 (4 bits), dim=4: 4/4 = 1 bpw
        - n_clusters=256 (8 bits), dim=8: 8/8 = 1 bpw
        - n_clusters=16 (4 bits), dim=8: 4/8 = 0.5 bpw

    Args:
        n_clusters: Number of clusters (codebook entries); must be >= 1
        dim: Number of weights grouped into one clustered vector; must be >= 1

    Returns:
        Bits needed per individual weight

    Raises:
        ValueError: If ``n_clusters`` or ``dim`` is smaller than 1
    """
    # Reject bad inputs up front: otherwise log2 fails with an opaque
    # "math domain error", dim == 0 divides by zero, and a negative dim
    # silently yields a negative bits-per-weight.
    if n_clusters < 1:
        raise ValueError(f"n_clusters must be >= 1, got {n_clusters}")
    if dim < 1:
        raise ValueError(f"dim must be >= 1, got {dim}")

    return math.log2(n_clusters) / dim
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def count_unique_weights(model: nn.Module) -> Dict[str, int]:
    """
    Count distinct scalar weight values in every Conv2d / Linear layer.

    Useful for verifying compression: after DKM compression + snapping,
    each layer should have at most k unique weight values. The count is
    taken over individual scalar entries of ``weight``; biases and other
    module types are ignored.

    Args:
        model: Model whose sub-modules are inspected via ``named_modules()``

    Returns:
        Mapping from module name to its number of unique weight values
    """
    counts = {}
    for layer_name, layer in model.named_modules():
        if not isinstance(layer, (nn.Conv2d, nn.Linear)):
            continue
        with torch.no_grad():
            distinct = torch.unique(layer.weight.data)
            counts[layer_name] = len(distinct)
    return counts
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def print_compression_summary(info: Dict):
|
| 80 |
+
"""Pretty print compression information."""
|
| 81 |
+
print("=" * 70)
|
| 82 |
+
print("DKM Compression Summary")
|
| 83 |
+
print("=" * 70)
|
| 84 |
+
print(f"Total parameters: {info['total_params']:,}")
|
| 85 |
+
print(f"Compressed parameters: {info['compressed_params']:,}")
|
| 86 |
+
print(f"Original size: {info['original_size_mb']:.2f} MB")
|
| 87 |
+
print(f"Compressed size: {info['compressed_size_mb']:.2f} MB")
|
| 88 |
+
print(f"Compression ratio: {info['compression_ratio']:.1f}x")
|
| 89 |
+
print("-" * 70)
|
| 90 |
+
print(f"{'Layer':<40} {'Params':>10} {'Bits':>5} {'Dim':>4} {'BPW':>6} {'CR':>6}")
|
| 91 |
+
print("-" * 70)
|
| 92 |
+
for name, layer_info in info["per_layer"].items():
|
| 93 |
+
print(
|
| 94 |
+
f"{name:<40} {layer_info['n_params']:>10,} "
|
| 95 |
+
f"{layer_info['bits']:>5.0f} {layer_info['dim']:>4} "
|
| 96 |
+
f"{layer_info['bits_per_weight']:>6.2f} "
|
| 97 |
+
f"{layer_info['compression_ratio']:>6.1f}x"
|
| 98 |
+
)
|
| 99 |
+
print("=" * 70)
|