syedmohaiminulhoque committed on
Commit
63a23d1
·
verified ·
1 Parent(s): f5e358d

Add compressor and utils modules

Browse files
Files changed (1) hide show
  1. dkm/compressor.py +393 -0
dkm/compressor.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ DKM Model Compressor
3
+
4
+ Wraps a pre-trained PyTorch model with DKM layers for weight clustering compression.
5
+ Follows the paper's approach of inserting DKM layers into the forward pass
6
+ (Section 3.2) without modifying the loss function or model architecture.
7
+
8
+ Supports per-layer configuration of bits and dimensions as described in Section 4.1.
9
+ """
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import math
14
+ from typing import Dict, Optional, Tuple, List, Union
15
+ from collections import OrderedDict
16
+
17
+ from .dkm_layer import DKMLayer
18
+
19
+
20
+ class DKMCompressor(nn.Module):
21
+ """
22
+ Wraps a pre-trained model with DKM clustering layers.
23
+
24
+ During forward pass, each wrapped weight parameter is replaced by its
25
+ DKM-compressed version. The original weights are kept as parameters
26
+ for gradient updates, while DKM layers control the clustering.
27
+
28
+ Args:
29
+ model: Pre-trained PyTorch model to compress
30
+ bits: Default number of bits for clustering (k = 2^bits)
31
+ dim: Default dimension for multi-dimensional clustering
32
+ tau: Default temperature for softmax attention
33
+ max_iter: Maximum DKM iterations per layer per forward pass
34
+ epsilon: Convergence threshold
35
+ layer_config: Optional per-layer configuration dict
36
+ Format: {layer_name: {"bits": int, "dim": int, "tau": float}}
37
+ skip_layers: List of layer names to skip (not compress)
38
+ min_params: Minimum number of parameters in a layer to compress
39
+ (paper uses 10000 for special handling)
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ model: nn.Module,
45
+ bits: int = 2,
46
+ dim: int = 1,
47
+ tau: float = 2e-5,
48
+ max_iter: int = 5,
49
+ epsilon: float = 1e-4,
50
+ layer_config: Optional[Dict] = None,
51
+ skip_layers: Optional[List[str]] = None,
52
+ min_params: int = 0,
53
+ skip_first_last: bool = False,
54
+ ):
55
+ super().__init__()
56
+
57
+ self.model = model
58
+ self.bits = bits
59
+ self.dim = dim
60
+ self.tau = tau
61
+ self.max_iter = max_iter
62
+ self.epsilon = epsilon
63
+ self.layer_config = layer_config or {}
64
+ self.skip_layers = skip_layers or []
65
+ self.min_params = min_params
66
+ self.skip_first_last = skip_first_last
67
+
68
+ # Create DKM layers for each applicable weight parameter
69
+ self.dkm_layers = nn.ModuleDict()
70
+ self._hooks = []
71
+
72
+ self._setup_dkm_layers()
73
+
74
+ def _get_compressible_layers(self) -> List[Tuple[str, nn.Module]]:
75
+ """
76
+ Identify layers that should be compressed.
77
+
78
+ Following the paper (Section 4.1):
79
+ - Compress Conv2d and Linear layers
80
+ - Skip layers in skip_layers list
81
+ - Optionally skip first and last layers (Table 1 protocol)
82
+ - Skip layers with fewer than min_params parameters
83
+ """
84
+ compressible = []
85
+ all_layers = []
86
+
87
+ for name, module in self.model.named_modules():
88
+ if isinstance(module, (nn.Conv2d, nn.Linear)):
89
+ all_layers.append((name, module))
90
+
91
+ for i, (name, module) in enumerate(all_layers):
92
+ # Skip first/last layers if requested (common protocol from Table 1)
93
+ if self.skip_first_last:
94
+ if i == 0 or i == len(all_layers) - 1:
95
+ continue
96
+
97
+ # Skip explicitly excluded layers
98
+ if any(skip in name for skip in self.skip_layers):
99
+ continue
100
+
101
+ # Skip small layers
102
+ n_params = module.weight.numel()
103
+ if n_params < self.min_params:
104
+ continue
105
+
106
+ compressible.append((name, module))
107
+
108
+ return compressible
109
+
110
+ def _get_layer_config(self, name: str, module: nn.Module) -> dict:
111
+ """
112
+ Get DKM configuration for a specific layer.
113
+
114
+ Per the paper (Section 4.1):
115
+ - Different bits/dim for conv vs fc layers
116
+ - Layers with <10000 params get 8-bit clustering
117
+ - Per-layer config overrides defaults
118
+ """
119
+ config = {
120
+ "bits": self.bits,
121
+ "dim": self.dim,
122
+ "tau": self.tau,
123
+ "max_iter": self.max_iter,
124
+ "epsilon": self.epsilon,
125
+ }
126
+
127
+ # Paper: "we applied 8 bit clustering to a layer with fewer than 10,000 parameters"
128
+ if module.weight.numel() < 10000:
129
+ config["bits"] = 8
130
+ config["dim"] = 1
131
+
132
+ # Per-layer overrides
133
+ if name in self.layer_config:
134
+ config.update(self.layer_config[name])
135
+
136
+ # Check for wildcard config (e.g., "conv" applies to all conv layers)
137
+ for pattern, pattern_config in self.layer_config.items():
138
+ if pattern != name and pattern in name:
139
+ config.update(pattern_config)
140
+
141
+ return config
142
+
143
+ def _setup_dkm_layers(self):
144
+ """
145
+ Create DKM layers and register forward hooks to replace weights
146
+ during forward pass.
147
+ """
148
+ compressible_layers = self._get_compressible_layers()
149
+
150
+ for name, module in compressible_layers:
151
+ config = self._get_layer_config(name, module)
152
+
153
+ n_clusters = 2 ** config["bits"]
154
+ dim = config["dim"]
155
+
156
+ # Validate dim is compatible with weight size
157
+ n_elements = module.weight.numel()
158
+ if n_elements % dim != 0:
159
+ # Adjust dim to nearest valid value
160
+ while dim > 1 and n_elements % dim != 0:
161
+ dim -= 1
162
+ config["dim"] = dim
163
+
164
+ # Create DKM layer
165
+ safe_name = name.replace(".", "_")
166
+ dkm_layer = DKMLayer(
167
+ weight_tensor=module.weight,
168
+ n_clusters=n_clusters,
169
+ tau=config["tau"],
170
+ dim=dim,
171
+ max_iter=config["max_iter"],
172
+ epsilon=config["epsilon"],
173
+ )
174
+
175
+ self.dkm_layers[safe_name] = dkm_layer
176
+
177
+ # Register forward pre-hook to replace weight during forward pass
178
+ hook = module.register_forward_pre_hook(
179
+ self._make_hook(safe_name, module)
180
+ )
181
+ self._hooks.append(hook)
182
+
183
+ def _make_hook(self, dkm_name: str, module: nn.Module):
184
+ """
185
+ Create a forward pre-hook that replaces the module's weight with
186
+ the DKM-compressed version during forward pass.
187
+
188
+ This implements the paper's approach: DKM is inserted into the
189
+ forward pass, making optimization fully aligned with the task objective.
190
+ """
191
+ def hook(mod, input):
192
+ dkm_layer = self.dkm_layers[dkm_name]
193
+ # Get compressed weight from DKM layer
194
+ compressed_weight = dkm_layer(weight_override=mod.weight)
195
+ # Replace weight for this forward pass
196
+ mod.weight.data = compressed_weight
197
+
198
+ return hook
199
+
200
+ def forward(self, *args, **kwargs):
201
+ """Forward pass through the wrapped model with DKM compression."""
202
+ return self.model(*args, **kwargs)
203
+
204
+ def snap_weights(self):
205
+ """
206
+ Snap all weights to nearest centroids for inference.
207
+
208
+ This is the final step before deployment: each weight is permanently
209
+ assigned to its nearest centroid. After this, the model can be
210
+ serialized as (codebook + assignments) for compression.
211
+ """
212
+ with torch.no_grad():
213
+ for name, module in self.model.named_modules():
214
+ safe_name = name.replace(".", "_")
215
+ if safe_name in self.dkm_layers:
216
+ dkm_layer = self.dkm_layers[safe_name]
217
+ dkm_layer.eval()
218
+ compressed_weight = dkm_layer()
219
+ module.weight.data.copy_(compressed_weight)
220
+
221
+ def get_compression_info(self) -> Dict:
222
+ """
223
+ Compute compression statistics for the model.
224
+
225
+ Returns dict with:
226
+ - total_params: Total number of parameters
227
+ - compressed_params: Number of compressed parameters
228
+ - original_size_mb: Original model size in MB (32-bit float)
229
+ - compressed_size_mb: Compressed model size in MB
230
+ - compression_ratio: Original/Compressed size ratio
231
+ - per_layer: Per-layer compression details
232
+ """
233
+ info = {
234
+ "per_layer": {},
235
+ "total_params": 0,
236
+ "compressed_params": 0,
237
+ "original_bits": 0,
238
+ "compressed_bits": 0,
239
+ }
240
+
241
+ # Count all parameters
242
+ for name, param in self.model.named_parameters():
243
+ n_params = param.numel()
244
+ info["total_params"] += n_params
245
+ info["original_bits"] += n_params * 32 # float32
246
+
247
+ # Count compressed layers
248
+ compressed_param_names = set()
249
+ for name, module in self.model.named_modules():
250
+ safe_name = name.replace(".", "_")
251
+ if safe_name in self.dkm_layers:
252
+ dkm_layer = self.dkm_layers[safe_name]
253
+ n_params = module.weight.numel()
254
+
255
+ bits = math.log2(dkm_layer.n_clusters)
256
+ dim = dkm_layer.dim
257
+ bpw = bits / dim # effective bits per weight
258
+
259
+ # Compressed size:
260
+ # - Codebook: k * d * 32 bits (centroids stored in float32)
261
+ # - Assignments: (N/d) * bits indices
262
+ n_vectors = n_params // dim
263
+ codebook_bits = dkm_layer.n_clusters * dim * 32
264
+ assignment_bits = n_vectors * bits
265
+ layer_compressed_bits = codebook_bits + assignment_bits
266
+
267
+ info["per_layer"][name] = {
268
+ "n_params": n_params,
269
+ "n_clusters": dkm_layer.n_clusters,
270
+ "dim": dim,
271
+ "bits": bits,
272
+ "bits_per_weight": bpw,
273
+ "original_bits": n_params * 32,
274
+ "compressed_bits": layer_compressed_bits,
275
+ "compression_ratio": (n_params * 32) / max(layer_compressed_bits, 1),
276
+ }
277
+
278
+ info["compressed_params"] += n_params
279
+ info["compressed_bits"] += layer_compressed_bits
280
+ compressed_param_names.add(name + ".weight")
281
+
282
+ # Uncompressed parameters contribute their full size
283
+ uncompressed_bits = 0
284
+ for pname, param in self.model.named_parameters():
285
+ if pname not in compressed_param_names:
286
+ uncompressed_bits += param.numel() * 32
287
+
288
+ info["compressed_bits"] += uncompressed_bits
289
+ info["original_size_mb"] = info["original_bits"] / 8 / 1024 / 1024
290
+ info["compressed_size_mb"] = info["compressed_bits"] / 8 / 1024 / 1024
291
+ info["compression_ratio"] = info["original_bits"] / max(info["compressed_bits"], 1)
292
+
293
+ return info
294
+
295
+ def export_compressed(self) -> Dict:
296
+ """
297
+ Export the compressed model as codebook + assignments.
298
+
299
+ Returns a dict with:
300
+ - 'state_dict': Original model state dict (with snapped weights)
301
+ - 'codebooks': {layer_name: centroid tensor}
302
+ - 'assignments': {layer_name: assignment index tensor}
303
+ - 'layer_configs': {layer_name: {bits, dim, ...}}
304
+ """
305
+ self.snap_weights()
306
+
307
+ export = {
308
+ "state_dict": self.model.state_dict(),
309
+ "codebooks": {},
310
+ "assignments": {},
311
+ "layer_configs": {},
312
+ }
313
+
314
+ for name, module in self.model.named_modules():
315
+ safe_name = name.replace(".", "_")
316
+ if safe_name in self.dkm_layers:
317
+ dkm_layer = self.dkm_layers[safe_name]
318
+ export["codebooks"][name] = dkm_layer.get_codebook()
319
+ export["assignments"][name] = dkm_layer.get_assignments()
320
+ export["layer_configs"][name] = {
321
+ "n_clusters": dkm_layer.n_clusters,
322
+ "dim": dkm_layer.dim,
323
+ "tau": dkm_layer.tau,
324
+ "original_shape": list(dkm_layer.original_shape),
325
+ }
326
+
327
+ return export
328
+
329
+ def remove_hooks(self):
330
+ """Remove all forward hooks (for clean serialization)."""
331
+ for hook in self._hooks:
332
+ hook.remove()
333
+ self._hooks.clear()
334
+
335
+ def __del__(self):
336
+ """Cleanup hooks on deletion."""
337
+ self.remove_hooks()
338
+
339
+
340
def compress_model(
    model: nn.Module,
    bits: int = 2,
    dim: int = 1,
    tau: float = 2e-5,
    conv_config: Optional[Dict] = None,
    fc_config: Optional[Dict] = None,
    skip_first_last: bool = True,
    min_params: int = 0,
    **kwargs,
) -> DKMCompressor:
    """
    High-level API to compress a pre-trained model using DKM.

    Follows the paper's convention of configuring conv and fc layers
    separately. For example, "cv:6/8, fc:6/4" means conv layers use 6 bits
    over 8 dimensions and fc layers use 6 bits over 4 dimensions.

    Args:
        model: Pre-trained PyTorch model.
        bits: Default bits for all layers.
        dim: Default dimension for clustering.
        tau: Temperature parameter.
        conv_config: Config for conv layers, e.g. ``{"bits": int, "dim": int}``.
        fc_config: Config for fc layers, same shape.
        skip_first_last: Skip first and last layers (Table 1 protocol).
        min_params: Minimum params for a layer to be compressed.
        **kwargs: Forwarded verbatim to :class:`DKMCompressor`.

    Returns:
        DKMCompressor wrapping the model.
    """
    # Expand the conv/fc split into an explicit per-layer config keyed by
    # each layer's qualified name.
    per_layer_cfg: Dict[str, Dict] = {}
    type_to_cfg = ((nn.Conv2d, conv_config), (nn.Linear, fc_config))

    for layer_name, submodule in model.named_modules():
        for layer_type, type_cfg in type_to_cfg:
            if type_cfg and isinstance(submodule, layer_type):
                per_layer_cfg[layer_name] = dict(type_cfg)
                break

    return DKMCompressor(
        model=model,
        bits=bits,
        dim=dim,
        tau=tau,
        layer_config=per_layer_cfg,
        skip_first_last=skip_first_last,
        min_params=min_params,
        **kwargs,
    )