File size: 3,123 Bytes
8a77fae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""
Utility functions for DKM compression analysis.
"""

import torch
import torch.nn as nn
import math
from typing import Dict, Optional


def compute_model_size(model: nn.Module, bits_per_param: float = 32.0) -> float:
    """
    Compute the total size of a model's parameters in megabytes.

    Args:
        model: PyTorch model whose parameters are counted.
        bits_per_param: Storage width per parameter (32 for float32,
            16 for float16).

    Returns:
        Parameter storage size in MB.
    """
    param_count = 0
    for param in model.parameters():
        param_count += param.numel()
    # bits -> bytes -> KB -> MB
    return param_count * bits_per_param / 8 / 1024 / 1024


def compute_compression_ratio(
    original_model: nn.Module,
    compressed_info: Dict,
) -> float:
    """
    Compute the compression ratio achieved by DKM compression.

    Args:
        original_model: The uncompressed reference model.
        compressed_info: Info dict from DKMCompressor.get_compression_info();
            must contain a 'compressed_size_mb' entry.

    Returns:
        original_size / compressed_size (the divisor is clamped to a tiny
        positive value to avoid division by zero).
    """
    compressed_mb = max(compressed_info["compressed_size_mb"], 1e-10)
    return compute_model_size(original_model) / compressed_mb


def compute_effective_bpw(n_clusters: int, dim: int = 1) -> float:
    """
    Effective bits-per-weight for a DKM configuration (Section 3.3).

    Each d-dimensional weight vector is encoded as an index into a
    codebook of n_clusters entries, so:

        bits_per_weight = log2(n_clusters) / dim

    Examples:
        16 clusters (4 bits), dim=4 -> 1.0 bpw
        256 clusters (8 bits), dim=8 -> 1.0 bpw
        16 clusters (4 bits), dim=8 -> 0.5 bpw
    """
    return math.log2(n_clusters) / dim


def count_unique_weights(model: nn.Module) -> Dict[str, int]:
    """
    Count the number of distinct weight values in each Conv2d/Linear layer.

    Useful to verify compression: after DKM compression + snapping, each
    layer should hold at most k unique weight values (or k unique
    d-dimensional vectors).

    Returns:
        Mapping of layer name -> number of unique scalar weight values.
    """
    counts: Dict[str, int] = {}
    # No gradients are needed for this read-only inspection.
    with torch.no_grad():
        for layer_name, layer in model.named_modules():
            if not isinstance(layer, (nn.Conv2d, nn.Linear)):
                continue
            counts[layer_name] = len(torch.unique(layer.weight.data))
    return counts


def print_compression_summary(info: Dict):
    """
    Pretty print compression information.

    Args:
        info: Dict from DKMCompressor.get_compression_info() containing
            'total_params', 'compressed_params', 'original_size_mb',
            'compressed_size_mb', 'compression_ratio', and 'per_layer'
            (a mapping of layer name -> per-layer stats dict).
    """
    # Pad every label to one width via the format spec so the values line
    # up in a single column; the previous hand-padded strings were off by
    # one for "Compressed parameters:".
    width = 22
    print("=" * 70)
    print("DKM Compression Summary")
    print("=" * 70)
    print(f"{'Total parameters:':<{width}} {info['total_params']:,}")
    print(f"{'Compressed parameters:':<{width}} {info['compressed_params']:,}")
    print(f"{'Original size:':<{width}} {info['original_size_mb']:.2f} MB")
    print(f"{'Compressed size:':<{width}} {info['compressed_size_mb']:.2f} MB")
    print(f"{'Compression ratio:':<{width}} {info['compression_ratio']:.1f}x")
    print("-" * 70)
    print(f"{'Layer':<40} {'Params':>10} {'Bits':>5} {'Dim':>4} {'BPW':>6} {'CR':>6}")
    print("-" * 70)
    for name, layer_info in info["per_layer"].items():
        print(
            f"{name:<40} {layer_info['n_params']:>10,} "
            f"{layer_info['bits']:>5.0f} {layer_info['dim']:>4} "
            f"{layer_info['bits_per_weight']:>6.2f} "
            f"{layer_info['compression_ratio']:>6.1f}x"
        )
    print("=" * 70)