""" DKM Model Compressor Wraps a pre-trained PyTorch model with DKM layers for weight clustering compression. Follows the paper's approach of inserting DKM layers into the forward pass (Section 3.2) without modifying the loss function or model architecture. Supports per-layer configuration of bits and dimensions as described in Section 4.1. """ import torch import torch.nn as nn import math from typing import Dict, Optional, Tuple, List, Union from collections import OrderedDict from .dkm_layer import DKMLayer class DKMCompressor(nn.Module): """ Wraps a pre-trained model with DKM clustering layers. During forward pass, each wrapped weight parameter is replaced by its DKM-compressed version. The original weights are kept as parameters for gradient updates, while DKM layers control the clustering. Args: model: Pre-trained PyTorch model to compress bits: Default number of bits for clustering (k = 2^bits) dim: Default dimension for multi-dimensional clustering tau: Default temperature for softmax attention max_iter: Maximum DKM iterations per layer per forward pass epsilon: Convergence threshold layer_config: Optional per-layer configuration dict Format: {layer_name: {"bits": int, "dim": int, "tau": float}} skip_layers: List of layer names to skip (not compress) min_params: Minimum number of parameters in a layer to compress (paper uses 10000 for special handling) """ def __init__( self, model: nn.Module, bits: int = 2, dim: int = 1, tau: float = 2e-5, max_iter: int = 5, epsilon: float = 1e-4, layer_config: Optional[Dict] = None, skip_layers: Optional[List[str]] = None, min_params: int = 0, skip_first_last: bool = False, ): super().__init__() self.model = model self.bits = bits self.dim = dim self.tau = tau self.max_iter = max_iter self.epsilon = epsilon self.layer_config = layer_config or {} self.skip_layers = skip_layers or [] self.min_params = min_params self.skip_first_last = skip_first_last # Create DKM layers for each applicable weight parameter self.dkm_layers = nn.ModuleDict() self._hooks = [] self._setup_dkm_layers() def _get_compressible_layers(self) -> List[Tuple[str, nn.Module]]: """ Identify layers that should be compressed. Following the paper (Section 4.1): - Compress Conv2d and Linear layers - Skip layers in skip_layers list - Optionally skip first and last layers (Table 1 protocol) - Skip layers with fewer than min_params parameters """ compressible = [] all_layers = [] for name, module in self.model.named_modules(): if isinstance(module, (nn.Conv2d, nn.Linear)): all_layers.append((name, module)) for i, (name, module) in enumerate(all_layers): # Skip first/last layers if requested (common protocol from Table 1) if self.skip_first_last: if i == 0 or i == len(all_layers) - 1: continue # Skip explicitly excluded layers if any(skip in name for skip in self.skip_layers): continue # Skip small layers n_params = module.weight.numel() if n_params < self.min_params: continue compressible.append((name, module)) return compressible def _get_layer_config(self, name: str, module: nn.Module) -> dict: """ Get DKM configuration for a specific layer. 
    def _get_layer_config(self, name: str, module: nn.Module) -> Dict:
        """
        Get the DKM configuration for a specific layer.

        Per the paper (Section 4.1):
        - Different bits/dim for conv vs fc layers
        - Layers with fewer than 10,000 parameters get 8-bit clustering
        - Per-layer config overrides the defaults
        """
        config = {
            "bits": self.bits,
            "dim": self.dim,
            "tau": self.tau,
            "max_iter": self.max_iter,
            "epsilon": self.epsilon,
        }

        # Paper: "we applied 8 bit clustering to a layer with fewer than
        # 10,000 parameters".
        if module.weight.numel() < 10000:
            config["bits"] = 8
            config["dim"] = 1

        # Wildcard configs (e.g. "conv" applies to every layer whose name
        # contains "conv") are applied first ...
        for pattern, pattern_config in self.layer_config.items():
            if pattern != name and pattern in name:
                config.update(pattern_config)
        # ... so that an exact-name entry takes precedence over them.
        if name in self.layer_config:
            config.update(self.layer_config[name])

        return config

    def _setup_dkm_layers(self):
        """
        Create DKM layers and register forward pre-hooks that swap each
        module's weight for its DKM-compressed version during the forward
        pass.
        """
        for name, module in self._get_compressible_layers():
            config = self._get_layer_config(name, module)
            n_clusters = 2 ** config["bits"]
            dim = config["dim"]

            # The weight must split evenly into d-dimensional vectors; if it
            # does not, fall back to the largest compatible dim below it.
            n_elements = module.weight.numel()
            while dim > 1 and n_elements % dim != 0:
                dim -= 1
            config["dim"] = dim

            # ModuleDict keys may not contain dots.
            safe_name = name.replace(".", "_")
            self.dkm_layers[safe_name] = DKMLayer(
                weight_tensor=module.weight,
                n_clusters=n_clusters,
                tau=config["tau"],
                dim=dim,
                max_iter=config["max_iter"],
                epsilon=config["epsilon"],
            )

            # Reparametrize: move the trainable parameter to `weight_orig` and
            # expose `weight` as a plain tensor attribute that the pre-hook
            # recomputes through DKM on every forward pass. Assigning the
            # compressed tensor to `weight.data` instead would overwrite the
            # original weights and detach the clustering from autograd; this
            # pattern (the same one used by the legacy
            # torch.nn.utils.weight_norm) avoids both problems.
            weight = module.weight
            del module._parameters["weight"]
            module.register_parameter("weight_orig", weight)
            module.weight = weight.detach()

            self._hooks.append(
                module.register_forward_pre_hook(self._make_hook(safe_name))
            )

    def _make_hook(self, dkm_name: str):
        """
        Create a forward pre-hook that replaces the module's weight with the
        DKM-compressed version for the current forward pass.

        This implements the paper's approach: DKM sits inside the forward
        pass, so optimizing the task loss directly optimizes the clustering.
        """

        def hook(mod, inputs):
            dkm_layer = self.dkm_layers[dkm_name]
            # Cluster the trainable original weight. The returned tensor stays
            # in the autograd graph, so gradients flow through the soft
            # assignments back to `weight_orig`.
            mod.weight = dkm_layer(weight_override=mod.weight_orig)

        return hook

    def forward(self, *args, **kwargs):
        """Forward pass through the wrapped model with DKM compression."""
        return self.model(*args, **kwargs)

    def snap_weights(self):
        """
        Snap all weights to their nearest centroids for inference.

        This is the final step before deployment: each weight is permanently
        assigned to its nearest centroid. After this, the model can be
        serialized as (codebook + assignments) for compression.
        """
        with torch.no_grad():
            for name, module in self.model.named_modules():
                safe_name = name.replace(".", "_")
                if safe_name in self.dkm_layers:
                    dkm_layer = self.dkm_layers[safe_name]
                    dkm_layer.eval()  # eval mode => hard nearest-centroid assignment
                    snapped = dkm_layer(weight_override=module.weight_orig)
                    module.weight_orig.data.copy_(snapped)
                    module.weight = snapped
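
    # Typical fine-tuning loop (a sketch; `criterion` and `loader` stand in
    # for whatever the task uses). The hooks recompress the weights on every
    # forward, so optimizing the task loss optimizes the clustering, per the
    # paper's central idea:
    #
    #   compressor = DKMCompressor(model, bits=4, dim=1)
    #   opt = torch.optim.SGD(compressor.parameters(), lr=0.01)
    #   for x, y in loader:
    #       loss = criterion(compressor(x), y)  # DKM runs in the forward pass
    #       opt.zero_grad()
    #       loss.backward()                     # gradients reach weight_orig
    #       opt.step()
    #   compressor.snap_weights()               # finalize for deployment
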
    def get_compression_info(self) -> Dict:
        """
        Compute compression statistics for the model.

        Returns a dict with:
        - total_params: Total number of parameters
        - compressed_params: Number of compressed parameters
        - original_size_mb: Original model size in MB (32-bit float)
        - compressed_size_mb: Compressed model size in MB
        - compression_ratio: Original/compressed size ratio
        - per_layer: Per-layer compression details
        """
        info = {
            "per_layer": {},
            "total_params": 0,
            "compressed_params": 0,
            "original_bits": 0,
            "compressed_bits": 0,
        }

        # Count all parameters.
        for _, param in self.model.named_parameters():
            n_params = param.numel()
            info["total_params"] += n_params
            info["original_bits"] += n_params * 32  # float32

        # Count compressed layers.
        compressed_param_names = set()
        for name, module in self.model.named_modules():
            safe_name = name.replace(".", "_")
            if safe_name not in self.dkm_layers:
                continue
            dkm_layer = self.dkm_layers[safe_name]
            n_params = module.weight_orig.numel()
            bits = int(math.log2(dkm_layer.n_clusters))
            dim = dkm_layer.dim
            bpw = bits / dim  # assignment bits per weight (excludes codebook)

            # Compressed size:
            # - Codebook: k * d * 32 bits (centroids stored in float32)
            # - Assignments: (N / d) indices of `bits` bits each
            n_vectors = n_params // dim
            codebook_bits = dkm_layer.n_clusters * dim * 32
            assignment_bits = n_vectors * bits
            layer_compressed_bits = codebook_bits + assignment_bits

            info["per_layer"][name] = {
                "n_params": n_params,
                "n_clusters": dkm_layer.n_clusters,
                "dim": dim,
                "bits": bits,
                "bits_per_weight": bpw,
                "original_bits": n_params * 32,
                "compressed_bits": layer_compressed_bits,
                "compression_ratio": (n_params * 32) / max(layer_compressed_bits, 1),
            }
            info["compressed_params"] += n_params
            info["compressed_bits"] += layer_compressed_bits
            # The trainable copy lives under `weight_orig` after
            # reparametrization, so that is the name to exclude below.
            compressed_param_names.add(name + ".weight_orig")

        # Uncompressed parameters (biases, skipped layers) keep their full size.
        for pname, param in self.model.named_parameters():
            if pname not in compressed_param_names:
                info["compressed_bits"] += param.numel() * 32

        info["original_size_mb"] = info["original_bits"] / 8 / 1024 / 1024
        info["compressed_size_mb"] = info["compressed_bits"] / 8 / 1024 / 1024
        info["compression_ratio"] = info["original_bits"] / max(info["compressed_bits"], 1)
        return info
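
    # Worked example of the arithmetic above (illustrative numbers, not from
    # the paper): a weight with N = 1,048,576 entries at bits=2, dim=1 costs
    #   codebook:    2^2 clusters * 1 dim * 32 bits  =        128 bits
    #   assignments: 1,048,576 indices * 2 bits      =  2,097,152 bits
    # against 33,554,432 bits in float32, i.e. a ~16x per-layer ratio; the
    # codebook overhead is negligible at this size.
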
    def export_compressed(self) -> Dict:
        """
        Export the compressed model as codebooks + assignments.

        Returns a dict with:
        - 'state_dict': Model state dict with snapped weights (reparametrized
          entries are renamed back from 'weight_orig' to 'weight' so the dict
          matches the original architecture)
        - 'codebooks': {layer_name: centroid tensor}
        - 'assignments': {layer_name: assignment index tensor}
        - 'layer_configs': {layer_name: {n_clusters, dim, tau, original_shape}}
        """
        self.snap_weights()

        state_dict = {
            key.replace("weight_orig", "weight"): value
            for key, value in self.model.state_dict().items()
        }
        export = {
            "state_dict": state_dict,
            "codebooks": {},
            "assignments": {},
            "layer_configs": {},
        }
        for name, module in self.model.named_modules():
            safe_name = name.replace(".", "_")
            if safe_name in self.dkm_layers:
                dkm_layer = self.dkm_layers[safe_name]
                export["codebooks"][name] = dkm_layer.get_codebook()
                export["assignments"][name] = dkm_layer.get_assignments()
                export["layer_configs"][name] = {
                    "n_clusters": dkm_layer.n_clusters,
                    "dim": dkm_layer.dim,
                    "tau": dkm_layer.tau,
                    "original_shape": list(dkm_layer.original_shape),
                }
        return export

    def remove_hooks(self):
        """Remove all forward pre-hooks (for clean serialization)."""
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()

    def __del__(self):
        """Clean up hooks on deletion."""
        try:
            self.remove_hooks()
        except Exception:
            # Attribute lookups can fail during interpreter shutdown.
            pass


def compress_model(
    model: nn.Module,
    bits: int = 2,
    dim: int = 1,
    tau: float = 2e-5,
    conv_config: Optional[Dict] = None,
    fc_config: Optional[Dict] = None,
    skip_first_last: bool = True,
    min_params: int = 0,
    **kwargs,
) -> DKMCompressor:
    """
    High-level API to compress a pre-trained model using DKM.

    Follows the paper's convention of separate configs for conv and fc layers.
    For example, "cv:6/8, fc:6/4" means:
    - Conv layers: 6 bits, 8 dimensions
    - FC layers: 6 bits, 4 dimensions

    Args:
        model: Pre-trained PyTorch model
        bits: Default bits for all layers
        dim: Default dimension for clustering
        tau: Temperature parameter
        conv_config: Config for conv layers {"bits": int, "dim": int}
        fc_config: Config for fc layers {"bits": int, "dim": int}
        skip_first_last: Skip the first and last layers (Table 1 protocol)
        min_params: Minimum number of parameters a layer must have to be
            compressed

    Returns:
        DKMCompressor wrapping the model
    """
    # Build a per-layer config from the conv/fc split.
    layer_config = {}
    if conv_config or fc_config:
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d) and conv_config:
                layer_config[name] = dict(conv_config)
            elif isinstance(module, nn.Linear) and fc_config:
                layer_config[name] = dict(fc_config)

    return DKMCompressor(
        model=model,
        bits=bits,
        dim=dim,
        tau=tau,
        layer_config=layer_config,
        skip_first_last=skip_first_last,
        min_params=min_params,
        **kwargs,
    )
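

if __name__ == "__main__":
    # Minimal smoke test / usage sketch. The bits/dim settings are
    # illustrative ("cv:6/8, fc:6/4" in the paper's notation), and the sketch
    # assumes the sibling dkm_layer module provides the DKMLayer interface
    # used above.
    toy = nn.Sequential(
        nn.Conv2d(3, 64, 3, padding=1),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(64 * 8 * 8, 10),
    )
    compressor = compress_model(
        toy,
        conv_config={"bits": 6, "dim": 8},
        fc_config={"bits": 6, "dim": 4},
        skip_first_last=False,  # the toy net has only two eligible layers
    )
    _ = compressor(torch.randn(2, 3, 8, 8))  # DKM runs inside the forward pass
    info = compressor.get_compression_info()
    print(f"compression ratio: {info['compression_ratio']:.2f}x")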