syedmohaiminulhoque committed on
Commit
8a77fae
·
verified ·
1 Parent(s): 63a23d1

Add utils, training script, and tests

Browse files
Files changed (1) hide show
  1. dkm/utils.py +99 -0
dkm/utils.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Utility functions for DKM compression analysis.
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import math
8
+ from typing import Dict, Optional
9
+
10
+
11
def compute_model_size(
    model: nn.Module,
    bits_per_param: float = 32.0,
    *,
    include_buffers: bool = False,
) -> float:
    """
    Compute the storage size of a model in megabytes (MB).

    Args:
        model: PyTorch model whose parameters are counted.
        bits_per_param: Bits used to store each parameter
            (32 for float32, 16 for float16).
        include_buffers: If True, also count non-trainable buffers
            (e.g. BatchNorm running statistics). Defaults to False,
            preserving the original parameter-only behavior.

    Returns:
        Size in MB (1 MB = 1024 * 1024 bytes).
    """
    total_elems = sum(p.numel() for p in model.parameters())
    if include_buffers:
        # Buffers are part of the serialized state_dict but are not
        # returned by model.parameters(), so they are opt-in here.
        total_elems += sum(b.numel() for b in model.buffers())
    # bits -> bytes -> KiB -> MiB
    return total_elems * bits_per_param / 8 / 1024 / 1024
26
+
27
+
28
def compute_compression_ratio(
    original_model: nn.Module,
    compressed_info: Dict,
) -> float:
    """
    Ratio of the original model's size to its compressed size.

    Args:
        original_model: Original uncompressed model.
        compressed_info: Info dict from DKMCompressor.get_compression_info();
            must contain the key "compressed_size_mb".

    Returns:
        Compression ratio (original_size / compressed_size).
    """
    uncompressed_mb = compute_model_size(original_model)
    # Floor the denominator at a tiny epsilon to avoid division by zero
    # when the reported compressed size is 0.
    denominator = max(compressed_info["compressed_size_mb"], 1e-10)
    return uncompressed_mb / denominator
45
+
46
+
47
def compute_effective_bpw(n_clusters: int, dim: int = 1) -> float:
    """
    Compute effective bits-per-weight for a DKM configuration.

    Following Section 3.3:
        bits_per_weight = log2(n_clusters) / dim

    For example:
        - 4 bits (16 clusters), dim=4: 4/4 = 1 bpw
        - 8 bits (256 clusters), dim=8: 8/8 = 1 bpw
        - 4 bits (16 clusters), dim=8: 4/8 = 0.5 bpw

    Args:
        n_clusters: Number of cluster centroids; must be >= 1.
        dim: Dimensionality of the weight vectors assigned to each
            centroid; must be >= 1.

    Returns:
        Effective bits per weight.

    Raises:
        ValueError: If n_clusters < 1 or dim < 1.
    """
    # Explicit validation: math.log2 already raised ValueError for
    # n_clusters <= 0, but dim <= 0 previously produced a
    # ZeroDivisionError or a silently negative bpw.
    if n_clusters < 1:
        raise ValueError(f"n_clusters must be >= 1, got {n_clusters}")
    if dim < 1:
        raise ValueError(f"dim must be >= 1, got {dim}")
    return math.log2(n_clusters) / dim
61
+
62
+
63
def count_unique_weights(model: nn.Module) -> Dict[str, int]:
    """
    Count unique weight values per Conv2d/Linear layer.

    After DKM compression + snapping, each layer should contain at most
    k unique weight values (or k unique d-dimensional vectors), so this
    is a quick way to verify that compression actually took effect.
    """
    counts: Dict[str, int] = {}
    with torch.no_grad():
        for layer_name, layer in model.named_modules():
            # Only weight-clustered layer types are of interest.
            if not isinstance(layer, (nn.Conv2d, nn.Linear)):
                continue
            counts[layer_name] = int(torch.unique(layer.weight.data).numel())
    return counts
77
+
78
+
79
def print_compression_summary(info: Dict):
    """Pretty print compression information."""
    rule_heavy = "=" * 70
    rule_light = "-" * 70
    # Assemble the full report, then emit it with a single print; the
    # resulting stdout is identical to printing line by line.
    lines = [
        rule_heavy,
        "DKM Compression Summary",
        rule_heavy,
        f"Total parameters: {info['total_params']:,}",
        f"Compressed parameters: {info['compressed_params']:,}",
        f"Original size: {info['original_size_mb']:.2f} MB",
        f"Compressed size: {info['compressed_size_mb']:.2f} MB",
        f"Compression ratio: {info['compression_ratio']:.1f}x",
        rule_light,
        f"{'Layer':<40} {'Params':>10} {'Bits':>5} {'Dim':>4} {'BPW':>6} {'CR':>6}",
        rule_light,
    ]
    for layer_name, stats in info["per_layer"].items():
        lines.append(
            f"{layer_name:<40} {stats['n_params']:>10,} "
            f"{stats['bits']:>5.0f} {stats['dim']:>4} "
            f"{stats['bits_per_weight']:>6.2f} "
            f"{stats['compression_ratio']:>6.1f}x"
        )
    lines.append(rule_heavy)
    print("\n".join(lines))