""" Utility functions for DKM compression analysis. """ import torch import torch.nn as nn import math from typing import Dict, Optional def compute_model_size(model: nn.Module, bits_per_param: float = 32.0) -> float: """ Compute model size in MB. Args: model: PyTorch model bits_per_param: Bits per parameter (32 for float32, 16 for float16) Returns: Size in MB """ total_params = sum(p.numel() for p in model.parameters()) size_bits = total_params * bits_per_param size_mb = size_bits / 8 / 1024 / 1024 return size_mb def compute_compression_ratio( original_model: nn.Module, compressed_info: Dict, ) -> float: """ Compute compression ratio between original and compressed model. Args: original_model: Original uncompressed model compressed_info: Info dict from DKMCompressor.get_compression_info() Returns: Compression ratio (original_size / compressed_size) """ original_size = compute_model_size(original_model) compressed_size = compressed_info["compressed_size_mb"] return original_size / max(compressed_size, 1e-10) def compute_effective_bpw(n_clusters: int, dim: int = 1) -> float: """ Compute effective bits-per-weight for a DKM configuration. Following Section 3.3: bits_per_weight = log2(n_clusters) / dim For example: - 4 bits, dim=4: 4/4 = 1 bpw - 8 bits, dim=8: 8/8 = 1 bpw - 4 bits, dim=8: 4/8 = 0.5 bpw """ bits = math.log2(n_clusters) return bits / dim def count_unique_weights(model: nn.Module) -> Dict[str, int]: """ Count unique weight values per layer (useful for verifying compression). After DKM compression + snapping, each layer should have at most k unique weight values (or k unique d-dimensional vectors). """ result = {} for name, module in model.named_modules(): if isinstance(module, (nn.Conv2d, nn.Linear)): with torch.no_grad(): unique_vals = torch.unique(module.weight.data) result[name] = len(unique_vals) return result def print_compression_summary(info: Dict): """Pretty print compression information.""" print("=" * 70) print("DKM Compression Summary") print("=" * 70) print(f"Total parameters: {info['total_params']:,}") print(f"Compressed parameters: {info['compressed_params']:,}") print(f"Original size: {info['original_size_mb']:.2f} MB") print(f"Compressed size: {info['compressed_size_mb']:.2f} MB") print(f"Compression ratio: {info['compression_ratio']:.1f}x") print("-" * 70) print(f"{'Layer':<40} {'Params':>10} {'Bits':>5} {'Dim':>4} {'BPW':>6} {'CR':>6}") print("-" * 70) for name, layer_info in info["per_layer"].items(): print( f"{name:<40} {layer_info['n_params']:>10,} " f"{layer_info['bits']:>5.0f} {layer_info['dim']:>4} " f"{layer_info['bits_per_weight']:>6.2f} " f"{layer_info['compression_ratio']:>6.1f}x" ) print("=" * 70)