| """ |
| Utility functions for DKM compression analysis. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import math |
| from typing import Dict, Optional |
|
|
|
|
def compute_model_size(model: nn.Module, bits_per_param: float = 32.0) -> float:
    """
    Return the storage footprint of a model's parameters, in megabytes.

    Args:
        model: PyTorch model whose parameters are counted.
        bits_per_param: Storage width per parameter (32 for float32,
            16 for float16, etc.).

    Returns:
        Parameter storage size in MB (bits -> bytes -> KB -> MB).
    """
    # Accumulate the total element count across every parameter tensor.
    param_count = 0
    for p in model.parameters():
        param_count += p.numel()
    # bits / 8 -> bytes; / 1024^2 -> megabytes.
    return (param_count * bits_per_param) / (8 * 1024 * 1024)
|
|
|
|
def compute_compression_ratio(
    original_model: nn.Module,
    compressed_info: Dict,
) -> float:
    """
    Return how many times smaller the compressed model is than the original.

    Args:
        original_model: The uncompressed reference model (assumed float32).
        compressed_info: Dict produced by DKMCompressor.get_compression_info();
            must contain a "compressed_size_mb" entry.

    Returns:
        original_size_mb / compressed_size_mb, with the denominator clamped
        to 1e-10 so a zero-sized report cannot divide by zero.
    """
    uncompressed_mb = compute_model_size(original_model)
    compressed_mb = compressed_info["compressed_size_mb"]
    # Clamp the denominator to avoid a ZeroDivisionError on degenerate input.
    denominator = compressed_mb if compressed_mb > 1e-10 else 1e-10
    return uncompressed_mb / denominator
|
|
|
|
def compute_effective_bpw(n_clusters: int, dim: int = 1) -> float:
    """
    Return the effective bits-per-weight of a DKM configuration.

    Per Section 3.3, each d-dimensional weight vector is encoded by the
    index of one of k clusters, so:

        bits_per_weight = log2(n_clusters) / dim

    Examples:
        - 16 clusters (4 bits), dim=4 -> 1.0 bpw
        - 256 clusters (8 bits), dim=8 -> 1.0 bpw
        - 16 clusters (4 bits), dim=8 -> 0.5 bpw

    Args:
        n_clusters: Number of cluster centroids (codebook size).
        dim: Dimensionality of each clustered weight vector.

    Returns:
        Effective bits per scalar weight.
    """
    return math.log2(n_clusters) / dim
|
|
|
|
def count_unique_weights(model: nn.Module) -> Dict[str, int]:
    """
    Map each Conv2d/Linear layer name to its count of distinct weight values.

    Useful to verify compression: after DKM compression + snapping, each
    layer should contain at most k unique scalar values (or k unique
    d-dimensional vectors).

    Args:
        model: Model to inspect; only nn.Conv2d and nn.Linear weights
            are counted.

    Returns:
        Dict of module name (from named_modules) -> number of unique
        scalar values in that module's weight tensor.
    """
    counts: Dict[str, int] = {}
    # No gradients needed for a read-only inspection pass.
    with torch.no_grad():
        for layer_name, layer in model.named_modules():
            if not isinstance(layer, (nn.Conv2d, nn.Linear)):
                continue
            counts[layer_name] = int(torch.unique(layer.weight.data).numel())
    return counts
|
|
|
|
def print_compression_summary(info: Dict):
    """
    Pretty-print a DKM compression report to stdout.

    Args:
        info: Dict with keys "total_params", "compressed_params",
            "original_size_mb", "compressed_size_mb", "compression_ratio",
            and "per_layer" (a dict of layer name -> per-layer stats with
            keys "n_params", "bits", "dim", "bits_per_weight",
            "compression_ratio").
    """
    heavy = "=" * 70
    light = "-" * 70
    # Assemble every output line first, then emit in a single print call.
    report = [
        heavy,
        "DKM Compression Summary",
        heavy,
        f"Total parameters: {info['total_params']:,}",
        f"Compressed parameters: {info['compressed_params']:,}",
        f"Original size: {info['original_size_mb']:.2f} MB",
        f"Compressed size: {info['compressed_size_mb']:.2f} MB",
        f"Compression ratio: {info['compression_ratio']:.1f}x",
        light,
        f"{'Layer':<40} {'Params':>10} {'Bits':>5} {'Dim':>4} {'BPW':>6} {'CR':>6}",
        light,
    ]
    for name, layer_info in info["per_layer"].items():
        report.append(
            f"{name:<40} {layer_info['n_params']:>10,} "
            f"{layer_info['bits']:>5.0f} {layer_info['dim']:>4} "
            f"{layer_info['bits_per_weight']:>6.2f} "
            f"{layer_info['compression_ratio']:>6.1f}x"
        )
    report.append(heavy)
    print("\n".join(report))
|
|