File size: 14,420 Bytes
63a23d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
"""
DKM Model Compressor

Wraps a pre-trained PyTorch model with DKM layers for weight clustering compression.
Follows the paper's approach of inserting DKM layers into the forward pass
(Section 3.2) without modifying the loss function or model architecture.

Supports per-layer configuration of bits and dimensions as described in Section 4.1.
"""

import torch
import torch.nn as nn
import math
from typing import Dict, Optional, Tuple, List, Union
from collections import OrderedDict

from .dkm_layer import DKMLayer


class DKMCompressor(nn.Module):
    """
    Wraps a pre-trained model with DKM clustering layers.

    During forward pass, each wrapped weight parameter is replaced by its
    DKM-compressed version. The original weights are kept as parameters
    for gradient updates, while DKM layers control the clustering.

    Args:
        model: Pre-trained PyTorch model to compress
        bits: Default number of bits for clustering (k = 2^bits)
        dim: Default dimension for multi-dimensional clustering
        tau: Default temperature for softmax attention
        max_iter: Maximum DKM iterations per layer per forward pass
        epsilon: Convergence threshold
        layer_config: Optional per-layer configuration dict
            Format: {layer_name: {"bits": int, "dim": int, "tau": float}}
            Keys are matched both exactly and as substrings of layer names
            (see _get_layer_config for the precedence caveat).
        skip_layers: List of layer names to skip (not compress); matching is
            by substring, not exact name
        min_params: Minimum number of parameters in a layer to compress
            (paper uses 10000 for special handling)
        skip_first_last: If True, exclude the first and last Conv2d/Linear
            layers from compression (Table 1 protocol)
    """
    
    def __init__(
        self,
        model: nn.Module,
        bits: int = 2,
        dim: int = 1,
        tau: float = 2e-5,
        max_iter: int = 5,
        epsilon: float = 1e-4,
        layer_config: Optional[Dict] = None,
        skip_layers: Optional[List[str]] = None,
        min_params: int = 0,
        skip_first_last: bool = False,
    ) -> None:
        super().__init__()
        
        self.model = model
        self.bits = bits
        self.dim = dim
        self.tau = tau
        self.max_iter = max_iter
        self.epsilon = epsilon
        self.layer_config = layer_config or {}
        self.skip_layers = skip_layers or []
        self.min_params = min_params
        self.skip_first_last = skip_first_last
        
        # Create DKM layers for each applicable weight parameter.
        # Keys are layer names with "." replaced by "_" (ModuleDict forbids dots).
        self.dkm_layers = nn.ModuleDict()
        # Forward pre-hook handles, kept so remove_hooks() can detach them.
        self._hooks = []
        
        self._setup_dkm_layers()
    
    def _get_compressible_layers(self) -> List[Tuple[str, nn.Module]]:
        """
        Identify layers that should be compressed.
        
        Following the paper (Section 4.1):
        - Compress Conv2d and Linear layers
        - Skip layers in skip_layers list (substring match)
        - Optionally skip first and last layers (Table 1 protocol)
        - Skip layers with fewer than min_params parameters
        
        Returns:
            List of (qualified_name, module) pairs in named_modules() order.
        """
        compressible = []
        all_layers = []
        
        # First collect every candidate so first/last indices are well-defined.
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                all_layers.append((name, module))
        
        for i, (name, module) in enumerate(all_layers):
            # Skip first/last layers if requested (common protocol from Table 1)
            if self.skip_first_last:
                if i == 0 or i == len(all_layers) - 1:
                    continue
            
            # Skip explicitly excluded layers.
            # NOTE(review): substring match — skip entry "fc" also excludes
            # e.g. "fc2" and "head.fc_out"; confirm this is intended.
            if any(skip in name for skip in self.skip_layers):
                continue
            
            # Skip small layers
            n_params = module.weight.numel()
            if n_params < self.min_params:
                continue
            
            compressible.append((name, module))
        
        return compressible
    
    def _get_layer_config(self, name: str, module: nn.Module) -> dict:
        """
        Get DKM configuration for a specific layer.
        
        Per the paper (Section 4.1):
        - Different bits/dim for conv vs fc layers
        - Layers with <10000 params get 8-bit clustering
        - Per-layer config overrides defaults
        
        Args:
            name: Qualified layer name as produced by named_modules().
            module: The layer itself (only its weight size is inspected).
        
        Returns:
            Dict with keys bits, dim, tau, max_iter, epsilon.
        """
        # Start from the compressor-wide defaults.
        config = {
            "bits": self.bits,
            "dim": self.dim,
            "tau": self.tau,
            "max_iter": self.max_iter,
            "epsilon": self.epsilon,
        }
        
        # Paper: "we applied 8 bit clustering to a layer with fewer than 10,000 parameters"
        if module.weight.numel() < 10000:
            config["bits"] = 8
            config["dim"] = 1
        
        # Per-layer overrides
        if name in self.layer_config:
            config.update(self.layer_config[name])
        
        # Check for wildcard config (e.g., "conv" applies to all conv layers)
        # NOTE(review): these substring patterns are applied *after* the
        # exact-name override above, so a pattern such as "conv" can silently
        # override an exact entry for "conv1" — confirm intended precedence.
        for pattern, pattern_config in self.layer_config.items():
            if pattern != name and pattern in name:
                config.update(pattern_config)
        
        return config
    
    def _setup_dkm_layers(self) -> None:
        """
        Create DKM layers and register forward hooks to replace weights
        during forward pass.
        """
        compressible_layers = self._get_compressible_layers()
        
        for name, module in compressible_layers:
            config = self._get_layer_config(name, module)
            
            n_clusters = 2 ** config["bits"]
            dim = config["dim"]
            
            # Validate dim is compatible with weight size
            n_elements = module.weight.numel()
            if n_elements % dim != 0:
                # Adjust dim downward to the nearest divisor of n_elements
                # (falls back to 1 in the worst case).
                while dim > 1 and n_elements % dim != 0:
                    dim -= 1
                config["dim"] = dim
            
            # ModuleDict keys may not contain ".", so flatten the name.
            safe_name = name.replace(".", "_")
            dkm_layer = DKMLayer(
                weight_tensor=module.weight,
                n_clusters=n_clusters,
                tau=config["tau"],
                dim=dim,
                max_iter=config["max_iter"],
                epsilon=config["epsilon"],
            )
            
            self.dkm_layers[safe_name] = dkm_layer
            
            # Register forward pre-hook to replace weight during forward pass
            hook = module.register_forward_pre_hook(
                self._make_hook(safe_name, module)
            )
            self._hooks.append(hook)
    
    def _make_hook(self, dkm_name: str, module: nn.Module):
        """
        Create a forward pre-hook that replaces the module's weight with
        the DKM-compressed version during forward pass.
        
        This implements the paper's approach: DKM is inserted into the
        forward pass, making optimization fully aligned with the task objective.
        
        Args:
            dkm_name: Key of the corresponding entry in self.dkm_layers.
            module: The wrapped module (unused here; the hook receives the
                module again as its first argument).
        
        Returns:
            A callable suitable for register_forward_pre_hook.
        """
        def hook(mod, input):
            dkm_layer = self.dkm_layers[dkm_name]
            # Get compressed weight from DKM layer
            compressed_weight = dkm_layer(weight_override=mod.weight)
            # Replace weight for this forward pass.
            # NOTE(review): assigning to .data overwrites the parameter's
            # storage in place, so after the first forward pass mod.weight
            # holds *compressed* values and subsequent forwards re-cluster
            # already-clustered weights; gradients also accumulate on
            # mod.weight directly rather than flowing through the DKM
            # attention map. Whether this matches the intended contract
            # depends on DKMLayer's handling of weight_override — confirm.
            mod.weight.data = compressed_weight
        
        return hook
    
    def forward(self, *args, **kwargs):
        """Forward pass through the wrapped model with DKM compression."""
        return self.model(*args, **kwargs)
    
    def snap_weights(self) -> None:
        """
        Snap all weights to nearest centroids for inference.
        
        This is the final step before deployment: each weight is permanently
        assigned to its nearest centroid. After this, the model can be
        serialized as (codebook + assignments) for compression.
        
        Note: calls each DKM layer in eval mode with no weight_override;
        presumably DKMLayer then uses its stored weight/centroids — confirm
        against the DKMLayer implementation.
        """
        with torch.no_grad():
            for name, module in self.model.named_modules():
                safe_name = name.replace(".", "_")
                if safe_name in self.dkm_layers:
                    dkm_layer = self.dkm_layers[safe_name]
                    dkm_layer.eval()
                    compressed_weight = dkm_layer()
                    module.weight.data.copy_(compressed_weight)
    
    def get_compression_info(self) -> Dict:
        """
        Compute compression statistics for the model.
        
        Returns dict with:
        - total_params: Total number of parameters
        - compressed_params: Number of compressed parameters
        - original_size_mb: Original model size in MB (32-bit float)
        - compressed_size_mb: Compressed model size in MB
        - compression_ratio: Original/Compressed size ratio
        - per_layer: Per-layer compression details
        """
        info = {
            "per_layer": {},
            "total_params": 0,
            "compressed_params": 0,
            "original_bits": 0,
            "compressed_bits": 0,
        }
        
        # Count all parameters (weights and biases alike) at float32 width.
        for name, param in self.model.named_parameters():
            n_params = param.numel()
            info["total_params"] += n_params
            info["original_bits"] += n_params * 32  # float32
        
        # Count compressed layers
        compressed_param_names = set()
        for name, module in self.model.named_modules():
            safe_name = name.replace(".", "_")
            if safe_name in self.dkm_layers:
                dkm_layer = self.dkm_layers[safe_name]
                n_params = module.weight.numel()
                
                # bits is a float (log2 of cluster count); exact for powers of 2.
                bits = math.log2(dkm_layer.n_clusters)
                dim = dkm_layer.dim
                bpw = bits / dim  # effective bits per weight
                
                # Compressed size:
                # - Codebook: k * d * 32 bits (centroids stored in float32)
                # - Assignments: (N/d) * bits indices
                n_vectors = n_params // dim
                codebook_bits = dkm_layer.n_clusters * dim * 32
                assignment_bits = n_vectors * bits
                layer_compressed_bits = codebook_bits + assignment_bits
                
                info["per_layer"][name] = {
                    "n_params": n_params,
                    "n_clusters": dkm_layer.n_clusters,
                    "dim": dim,
                    "bits": bits,
                    "bits_per_weight": bpw,
                    "original_bits": n_params * 32,
                    "compressed_bits": layer_compressed_bits,
                    # max(..., 1) guards against division by zero.
                    "compression_ratio": (n_params * 32) / max(layer_compressed_bits, 1),
                }
                
                info["compressed_params"] += n_params
                info["compressed_bits"] += layer_compressed_bits
                # Only the .weight tensor is compressed; biases stay full size.
                compressed_param_names.add(name + ".weight")
        
        # Uncompressed parameters contribute their full size
        uncompressed_bits = 0
        for pname, param in self.model.named_parameters():
            if pname not in compressed_param_names:
                uncompressed_bits += param.numel() * 32
        
        info["compressed_bits"] += uncompressed_bits
        info["original_size_mb"] = info["original_bits"] / 8 / 1024 / 1024
        info["compressed_size_mb"] = info["compressed_bits"] / 8 / 1024 / 1024
        info["compression_ratio"] = info["original_bits"] / max(info["compressed_bits"], 1)
        
        return info
    
    def export_compressed(self) -> Dict:
        """
        Export the compressed model as codebook + assignments.
        
        Calls snap_weights() first, so the returned state_dict contains the
        snapped (centroid-valued) weights.
        
        Returns a dict with:
        - 'state_dict': Original model state dict (with snapped weights)
        - 'codebooks': {layer_name: centroid tensor}
        - 'assignments': {layer_name: assignment index tensor}
        - 'layer_configs': {layer_name: {bits, dim, ...}}
        """
        self.snap_weights()
        
        export = {
            "state_dict": self.model.state_dict(),
            "codebooks": {},
            "assignments": {},
            "layer_configs": {},
        }
        
        for name, module in self.model.named_modules():
            safe_name = name.replace(".", "_")
            if safe_name in self.dkm_layers:
                dkm_layer = self.dkm_layers[safe_name]
                export["codebooks"][name] = dkm_layer.get_codebook()
                export["assignments"][name] = dkm_layer.get_assignments()
                export["layer_configs"][name] = {
                    "n_clusters": dkm_layer.n_clusters,
                    "dim": dkm_layer.dim,
                    "tau": dkm_layer.tau,
                    "original_shape": list(dkm_layer.original_shape),
                }
        
        return export
    
    def remove_hooks(self) -> None:
        """Remove all forward hooks (for clean serialization)."""
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()
    
    def __del__(self) -> None:
        """Cleanup hooks on deletion.
        
        NOTE(review): __del__ may run during interpreter shutdown when
        attributes (or even module globals) are already torn down, which
        would raise here; consider guarding with getattr/try-except.
        """
        self.remove_hooks()


def compress_model(
    model: nn.Module,
    bits: int = 2,
    dim: int = 1,
    tau: float = 2e-5,
    conv_config: Optional[Dict] = None,
    fc_config: Optional[Dict] = None,
    skip_first_last: bool = True,
    min_params: int = 0,
    **kwargs,
) -> DKMCompressor:
    """
    High-level entry point for DKM compression of a pre-trained model.

    Mirrors the paper's notation of separate conv/fc settings, e.g.
    "cv:6/8, fc:6/4" meaning 6 bits with 8-dim clustering for conv layers
    and 6 bits with 4-dim clustering for fully-connected layers.

    Args:
        model: Pre-trained PyTorch model
        bits: Default bits for all layers
        dim: Default dimension for clustering
        tau: Temperature parameter
        conv_config: Config for conv layers {"bits": int, "dim": int}
        fc_config: Config for fc layers {"bits": int, "dim": int}
        skip_first_last: Skip first and last layers (Table 1 protocol)
        min_params: Minimum params to compress a layer
        **kwargs: Forwarded verbatim to DKMCompressor.

    Returns:
        DKMCompressor wrapping the model
    """
    # Translate the conv/fc split into an explicit per-layer config dict.
    # Each (module type, override) pair is checked in order; the first
    # matching type with a non-empty override wins for that layer.
    per_layer_config: Dict[str, Dict] = {}
    type_overrides = ((nn.Conv2d, conv_config), (nn.Linear, fc_config))

    if conv_config or fc_config:
        for layer_name, layer_module in model.named_modules():
            for layer_type, override in type_overrides:
                if override and isinstance(layer_module, layer_type):
                    # Shallow copy so later mutation of the caller's dict
                    # cannot leak into the compressor's stored config.
                    per_layer_config[layer_name] = dict(override)
                    break

    return DKMCompressor(
        model=model,
        bits=bits,
        dim=dim,
        tau=tau,
        layer_config=per_layer_config,
        skip_first_last=skip_first_last,
        min_params=min_params,
        **kwargs,
    )