| """ |
| DKM Model Compressor |
| |
| Wraps a pre-trained PyTorch model with DKM layers for weight clustering compression. |
| Follows the paper's approach of inserting DKM layers into the forward pass |
| (Section 3.2) without modifying the loss function or model architecture. |
| |
| Supports per-layer configuration of bits and dimensions as described in Section 4.1. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import math |
| from typing import Dict, Optional, Tuple, List, Union |
| from collections import OrderedDict |
|
|
| from .dkm_layer import DKMLayer |
|
|
|
|
| class DKMCompressor(nn.Module): |
| """ |
| Wraps a pre-trained model with DKM clustering layers. |
| |
| During forward pass, each wrapped weight parameter is replaced by its |
| DKM-compressed version. The original weights are kept as parameters |
| for gradient updates, while DKM layers control the clustering. |
| |
| Args: |
| model: Pre-trained PyTorch model to compress |
| bits: Default number of bits for clustering (k = 2^bits) |
| dim: Default dimension for multi-dimensional clustering |
| tau: Default temperature for softmax attention |
| max_iter: Maximum DKM iterations per layer per forward pass |
| epsilon: Convergence threshold |
| layer_config: Optional per-layer configuration dict |
| Format: {layer_name: {"bits": int, "dim": int, "tau": float}} |
| skip_layers: List of layer names to skip (not compress) |
| min_params: Minimum number of parameters in a layer to compress |
| (paper uses 10000 for special handling) |
| """ |
| |
| def __init__( |
| self, |
| model: nn.Module, |
| bits: int = 2, |
| dim: int = 1, |
| tau: float = 2e-5, |
| max_iter: int = 5, |
| epsilon: float = 1e-4, |
| layer_config: Optional[Dict] = None, |
| skip_layers: Optional[List[str]] = None, |
| min_params: int = 0, |
| skip_first_last: bool = False, |
| ): |
| super().__init__() |
| |
| self.model = model |
| self.bits = bits |
| self.dim = dim |
| self.tau = tau |
| self.max_iter = max_iter |
| self.epsilon = epsilon |
| self.layer_config = layer_config or {} |
| self.skip_layers = skip_layers or [] |
| self.min_params = min_params |
| self.skip_first_last = skip_first_last |
| |
| |
| self.dkm_layers = nn.ModuleDict() |
| self._hooks = [] |
| |
| self._setup_dkm_layers() |
| |
| def _get_compressible_layers(self) -> List[Tuple[str, nn.Module]]: |
| """ |
| Identify layers that should be compressed. |
| |
| Following the paper (Section 4.1): |
| - Compress Conv2d and Linear layers |
| - Skip layers in skip_layers list |
| - Optionally skip first and last layers (Table 1 protocol) |
| - Skip layers with fewer than min_params parameters |
| """ |
| compressible = [] |
| all_layers = [] |
| |
| for name, module in self.model.named_modules(): |
| if isinstance(module, (nn.Conv2d, nn.Linear)): |
| all_layers.append((name, module)) |
| |
| for i, (name, module) in enumerate(all_layers): |
| |
| if self.skip_first_last: |
| if i == 0 or i == len(all_layers) - 1: |
| continue |
| |
| |
| if any(skip in name for skip in self.skip_layers): |
| continue |
| |
| |
| n_params = module.weight.numel() |
| if n_params < self.min_params: |
| continue |
| |
| compressible.append((name, module)) |
| |
| return compressible |
| |
| def _get_layer_config(self, name: str, module: nn.Module) -> dict: |
| """ |
| Get DKM configuration for a specific layer. |
| |
| Per the paper (Section 4.1): |
| - Different bits/dim for conv vs fc layers |
| - Layers with <10000 params get 8-bit clustering |
| - Per-layer config overrides defaults |
| """ |
| config = { |
| "bits": self.bits, |
| "dim": self.dim, |
| "tau": self.tau, |
| "max_iter": self.max_iter, |
| "epsilon": self.epsilon, |
| } |
| |
| |
| if module.weight.numel() < 10000: |
| config["bits"] = 8 |
| config["dim"] = 1 |
| |
| |
| if name in self.layer_config: |
| config.update(self.layer_config[name]) |
| |
| |
| for pattern, pattern_config in self.layer_config.items(): |
| if pattern != name and pattern in name: |
| config.update(pattern_config) |
| |
| return config |
| |
| def _setup_dkm_layers(self): |
| """ |
| Create DKM layers and register forward hooks to replace weights |
| during forward pass. |
| """ |
| compressible_layers = self._get_compressible_layers() |
| |
| for name, module in compressible_layers: |
| config = self._get_layer_config(name, module) |
| |
| n_clusters = 2 ** config["bits"] |
| dim = config["dim"] |
| |
| |
| n_elements = module.weight.numel() |
| if n_elements % dim != 0: |
| |
| while dim > 1 and n_elements % dim != 0: |
| dim -= 1 |
| config["dim"] = dim |
| |
| |
| safe_name = name.replace(".", "_") |
| dkm_layer = DKMLayer( |
| weight_tensor=module.weight, |
| n_clusters=n_clusters, |
| tau=config["tau"], |
| dim=dim, |
| max_iter=config["max_iter"], |
| epsilon=config["epsilon"], |
| ) |
| |
| self.dkm_layers[safe_name] = dkm_layer |
| |
| |
| hook = module.register_forward_pre_hook( |
| self._make_hook(safe_name, module) |
| ) |
| self._hooks.append(hook) |
| |
| def _make_hook(self, dkm_name: str, module: nn.Module): |
| """ |
| Create a forward pre-hook that replaces the module's weight with |
| the DKM-compressed version during forward pass. |
| |
| This implements the paper's approach: DKM is inserted into the |
| forward pass, making optimization fully aligned with the task objective. |
| """ |
| def hook(mod, input): |
| dkm_layer = self.dkm_layers[dkm_name] |
| |
| compressed_weight = dkm_layer(weight_override=mod.weight) |
| |
| mod.weight.data = compressed_weight |
| |
| return hook |
| |
| def forward(self, *args, **kwargs): |
| """Forward pass through the wrapped model with DKM compression.""" |
| return self.model(*args, **kwargs) |
| |
| def snap_weights(self): |
| """ |
| Snap all weights to nearest centroids for inference. |
| |
| This is the final step before deployment: each weight is permanently |
| assigned to its nearest centroid. After this, the model can be |
| serialized as (codebook + assignments) for compression. |
| """ |
| with torch.no_grad(): |
| for name, module in self.model.named_modules(): |
| safe_name = name.replace(".", "_") |
| if safe_name in self.dkm_layers: |
| dkm_layer = self.dkm_layers[safe_name] |
| dkm_layer.eval() |
| compressed_weight = dkm_layer() |
| module.weight.data.copy_(compressed_weight) |
| |
| def get_compression_info(self) -> Dict: |
| """ |
| Compute compression statistics for the model. |
| |
| Returns dict with: |
| - total_params: Total number of parameters |
| - compressed_params: Number of compressed parameters |
| - original_size_mb: Original model size in MB (32-bit float) |
| - compressed_size_mb: Compressed model size in MB |
| - compression_ratio: Original/Compressed size ratio |
| - per_layer: Per-layer compression details |
| """ |
| info = { |
| "per_layer": {}, |
| "total_params": 0, |
| "compressed_params": 0, |
| "original_bits": 0, |
| "compressed_bits": 0, |
| } |
| |
| |
| for name, param in self.model.named_parameters(): |
| n_params = param.numel() |
| info["total_params"] += n_params |
| info["original_bits"] += n_params * 32 |
| |
| |
| compressed_param_names = set() |
| for name, module in self.model.named_modules(): |
| safe_name = name.replace(".", "_") |
| if safe_name in self.dkm_layers: |
| dkm_layer = self.dkm_layers[safe_name] |
| n_params = module.weight.numel() |
| |
| bits = math.log2(dkm_layer.n_clusters) |
| dim = dkm_layer.dim |
| bpw = bits / dim |
| |
| |
| |
| |
| n_vectors = n_params // dim |
| codebook_bits = dkm_layer.n_clusters * dim * 32 |
| assignment_bits = n_vectors * bits |
| layer_compressed_bits = codebook_bits + assignment_bits |
| |
| info["per_layer"][name] = { |
| "n_params": n_params, |
| "n_clusters": dkm_layer.n_clusters, |
| "dim": dim, |
| "bits": bits, |
| "bits_per_weight": bpw, |
| "original_bits": n_params * 32, |
| "compressed_bits": layer_compressed_bits, |
| "compression_ratio": (n_params * 32) / max(layer_compressed_bits, 1), |
| } |
| |
| info["compressed_params"] += n_params |
| info["compressed_bits"] += layer_compressed_bits |
| compressed_param_names.add(name + ".weight") |
| |
| |
| uncompressed_bits = 0 |
| for pname, param in self.model.named_parameters(): |
| if pname not in compressed_param_names: |
| uncompressed_bits += param.numel() * 32 |
| |
| info["compressed_bits"] += uncompressed_bits |
| info["original_size_mb"] = info["original_bits"] / 8 / 1024 / 1024 |
| info["compressed_size_mb"] = info["compressed_bits"] / 8 / 1024 / 1024 |
| info["compression_ratio"] = info["original_bits"] / max(info["compressed_bits"], 1) |
| |
| return info |
| |
| def export_compressed(self) -> Dict: |
| """ |
| Export the compressed model as codebook + assignments. |
| |
| Returns a dict with: |
| - 'state_dict': Original model state dict (with snapped weights) |
| - 'codebooks': {layer_name: centroid tensor} |
| - 'assignments': {layer_name: assignment index tensor} |
| - 'layer_configs': {layer_name: {bits, dim, ...}} |
| """ |
| self.snap_weights() |
| |
| export = { |
| "state_dict": self.model.state_dict(), |
| "codebooks": {}, |
| "assignments": {}, |
| "layer_configs": {}, |
| } |
| |
| for name, module in self.model.named_modules(): |
| safe_name = name.replace(".", "_") |
| if safe_name in self.dkm_layers: |
| dkm_layer = self.dkm_layers[safe_name] |
| export["codebooks"][name] = dkm_layer.get_codebook() |
| export["assignments"][name] = dkm_layer.get_assignments() |
| export["layer_configs"][name] = { |
| "n_clusters": dkm_layer.n_clusters, |
| "dim": dkm_layer.dim, |
| "tau": dkm_layer.tau, |
| "original_shape": list(dkm_layer.original_shape), |
| } |
| |
| return export |
| |
| def remove_hooks(self): |
| """Remove all forward hooks (for clean serialization).""" |
| for hook in self._hooks: |
| hook.remove() |
| self._hooks.clear() |
| |
| def __del__(self): |
| """Cleanup hooks on deletion.""" |
| self.remove_hooks() |
|
|
|
|
def compress_model(
    model: nn.Module,
    bits: int = 2,
    dim: int = 1,
    tau: float = 2e-5,
    conv_config: Optional[Dict] = None,
    fc_config: Optional[Dict] = None,
    skip_first_last: bool = True,
    min_params: int = 0,
    **kwargs,
) -> DKMCompressor:
    """
    High-level API to compress a pre-trained model using DKM.

    Follows the paper's convention of separate configs for conv and fc
    layers. For example, "cv:6/8, fc:6/4" means:
    - Conv layers: 6 bits, 8 dimensions
    - FC layers: 6 bits, 4 dimensions

    Args:
        model: Pre-trained PyTorch model
        bits: Default bits for all layers
        dim: Default dimension for clustering
        tau: Temperature parameter
        conv_config: Config for conv layers {"bits": int, "dim": int}
        fc_config: Config for fc layers {"bits": int, "dim": int}
        skip_first_last: Skip first and last layers (Table 1 protocol)
        min_params: Minimum params to compress a layer
        **kwargs: Extra keyword arguments forwarded to DKMCompressor

    Returns:
        DKMCompressor wrapping the model
    """
    # Expand the conv/fc shorthand into an explicit per-layer config dict
    # keyed by module name.
    per_layer_config: Dict[str, Dict] = {}
    if conv_config or fc_config:
        for layer_name, layer in model.named_modules():
            if isinstance(layer, nn.Conv2d):
                if conv_config:
                    per_layer_config[layer_name] = dict(conv_config)
            elif isinstance(layer, nn.Linear):
                if fc_config:
                    per_layer_config[layer_name] = dict(fc_config)

    return DKMCompressor(
        model=model,
        bits=bits,
        dim=dim,
        tau=tau,
        layer_config=per_layer_config,
        skip_first_last=skip_first_last,
        min_params=min_params,
        **kwargs,
    )
|
|