# Source: dkm-compression / dkm/compressor.py
# Author: syedmohaiminulhoque
# Commit 63a23d1 (verified): "Add compressor and utils modules"
"""
DKM Model Compressor
Wraps a pre-trained PyTorch model with DKM layers for weight clustering compression.
Follows the paper's approach of inserting DKM layers into the forward pass
(Section 3.2) without modifying the loss function or model architecture.
Supports per-layer configuration of bits and dimensions as described in Section 4.1.
"""
import torch
import torch.nn as nn
import math
from typing import Dict, Optional, Tuple, List, Union
from collections import OrderedDict
from .dkm_layer import DKMLayer
class DKMCompressor(nn.Module):
    """
    Wraps a pre-trained model with DKM clustering layers.

    During each forward pass of a wrapped Conv2d/Linear module, a forward
    pre-hook computes ``dkm_layer(weight_override=weight)`` from the module's
    original ``weight`` parameter and temporarily puts the result in its
    place. A forward hook then puts the original ``nn.Parameter`` back once
    the module has run. As a result:

    - The full-precision weights are never overwritten between iterations.
    - The compressed weight used in the forward pass stays attached to the
      autograd graph, so gradients of the task loss reach the original
      weights (and the DKM layer, assuming ``DKMLayer`` is differentiable).

    Args:
        model: Pre-trained PyTorch model to compress
        bits: Default number of bits for clustering (k = 2^bits)
        dim: Default dimension for multi-dimensional clustering
        tau: Default temperature for softmax attention
        max_iter: Maximum DKM iterations per layer per forward pass
        epsilon: Convergence threshold
        layer_config: Optional per-layer configuration dict
            Format: {layer_name: {"bits": int, "dim": int, "tau": float}}
            A key that is a substring of a layer name (e.g. "conv") applies to
            every matching layer. An exact layer-name key is applied last and
            therefore takes precedence over substring matches.
        skip_layers: Layer-name substrings to skip (not compress)
        min_params: Minimum number of weight elements in a layer to compress
            (paper uses 10000 for special handling)
        skip_first_last: Skip the first and last Conv2d/Linear layer in
            ``named_modules()`` order (Table 1 protocol)
    """

    def __init__(
        self,
        model: nn.Module,
        bits: int = 2,
        dim: int = 1,
        tau: float = 2e-5,
        max_iter: int = 5,
        epsilon: float = 1e-4,
        layer_config: Optional[Dict] = None,
        skip_layers: Optional[List[str]] = None,
        min_params: int = 0,
        skip_first_last: bool = False,
    ):
        super().__init__()
        self.model = model
        self.bits = bits
        self.dim = dim
        self.tau = tau
        self.max_iter = max_iter
        self.epsilon = epsilon
        self.layer_config = layer_config or {}
        self.skip_layers = skip_layers or []
        self.min_params = min_params
        self.skip_first_last = skip_first_last
        # One DKM layer per compressed weight, keyed by a ModuleDict-safe name.
        self.dkm_layers = nn.ModuleDict()
        # ModuleDict key -> (dotted module name, module). A plain dict, so the
        # modules (already owned by ``self.model``) are not registered twice.
        self._compressed_modules: Dict[str, Tuple[str, nn.Module]] = OrderedDict()
        # ModuleDict key -> original weight Parameter while its module's
        # forward is in flight (filled by the pre-hook, emptied by the
        # post-hook or by _restore_weights).
        self._stashed_weights: Dict[str, nn.Parameter] = {}
        self._hooks = []
        self._setup_dkm_layers()

    def _get_compressible_layers(self) -> List[Tuple[str, nn.Module]]:
        """
        Return the ``(name, module)`` pairs that should receive a DKM layer.

        Following the paper (Section 4.1):
        - Only Conv2d and Linear layers are candidates.
        - With ``skip_first_last``, the first and last candidates are skipped
          (Table 1 protocol).
        - Layers whose name contains any entry of ``skip_layers`` are skipped.
        - Layers with fewer than ``min_params`` weight elements are skipped.
        """
        candidates = [
            (name, module)
            for name, module in self.model.named_modules()
            if isinstance(module, (nn.Conv2d, nn.Linear))
        ]
        last_index = len(candidates) - 1
        compressible = []
        for index, (name, module) in enumerate(candidates):
            if self.skip_first_last and index in (0, last_index):
                continue
            if any(skip in name for skip in self.skip_layers):
                continue
            if module.weight.numel() < self.min_params:
                continue
            compressible.append((name, module))
        return compressible

    def _get_layer_config(self, name: str, module: nn.Module) -> dict:
        """
        Resolve the DKM configuration for one layer.

        Per the paper (Section 4.1), layers with fewer than 10000 weight
        elements get 8-bit, 1-D clustering. Entries from ``layer_config`` are
        applied on top of that, least specific first:

        1. Every key that is a substring of ``name``, in dict order.
        2. The exact ``name`` key, so it always wins over a substring match.
        """
        config = {
            "bits": self.bits,
            "dim": self.dim,
            "tau": self.tau,
            "max_iter": self.max_iter,
            "epsilon": self.epsilon,
        }
        # Paper: "we applied 8 bit clustering to a layer with fewer than 10,000 parameters"
        if module.weight.numel() < 10000:
            config["bits"] = 8
            config["dim"] = 1
        # Wildcard entries (e.g. "conv" applies to all layers whose name contains "conv").
        for pattern, pattern_config in self.layer_config.items():
            if pattern != name and pattern in name:
                config.update(pattern_config)
        # The exact-name entry is the most specific, so it is applied last.
        if name in self.layer_config:
            config.update(self.layer_config[name])
        return config

    def _unique_key(self, name: str) -> str:
        """Map a dotted module name to an unused ``nn.ModuleDict`` key."""
        # ModuleDict keys may not contain "." or be empty (the root module's
        # name is ""). Distinct names such as "a.b_c" and "a_b.c" would
        # otherwise collide after the replacement.
        base = name.replace(".", "_") or "_root"
        key, suffix = base, 1
        while key in self.dkm_layers:
            key = f"{base}_{suffix}"
            suffix += 1
        return key

    def _setup_dkm_layers(self):
        """
        Create one DKM layer per compressible module and register the hooks
        that swap the compressed weight in for each forward pass.

        Raises:
            ValueError: if a resolved ``dim`` is smaller than 1.
        """
        for name, module in self._get_compressible_layers():
            config = self._get_layer_config(name, module)
            n_clusters = 2 ** config["bits"]
            dim = config["dim"]
            if dim < 1:
                raise ValueError(f"dim must be >= 1 for layer '{name}', got {dim}")
            # Multi-dimensional clustering needs the element count to be
            # divisible by dim; fall back to the largest smaller divisor
            # (dim == 1 always divides).
            n_elements = module.weight.numel()
            while dim > 1 and n_elements % dim != 0:
                dim -= 1

            safe_name = self._unique_key(name)
            self.dkm_layers[safe_name] = DKMLayer(
                weight_tensor=module.weight,
                n_clusters=n_clusters,
                tau=config["tau"],
                dim=dim,
                max_iter=config["max_iter"],
                epsilon=config["epsilon"],
            )
            self._compressed_modules[safe_name] = (name, module)

            self._hooks.append(
                module.register_forward_pre_hook(self._make_hook(safe_name, module))
            )
            self._hooks.append(
                module.register_forward_hook(self._make_restore_hook(safe_name))
            )

    def _make_hook(self, dkm_name: str, module: nn.Module):
        """
        Create the forward pre-hook that swaps in the DKM-compressed weight.

        The compressed tensor is computed from the original Parameter, which
        is stashed and put back by the matching restore hook. The
        full-precision weights are therefore never overwritten, and the
        compressed tensor keeps its autograd history. This implements the
        paper's approach: DKM is part of the forward pass, so optimization is
        fully aligned with the task objective.

        ``module`` is unused (the hook receives the module as ``mod``); it is
        kept for signature compatibility.
        """

        def hook(mod, inputs):
            stash = self._stashed_weights
            if dkm_name in stash:
                # A previous forward raised before the restore hook ran:
                # recover the real parameter before compressing again.
                mod._parameters["weight"] = stash[dkm_name]
            original = mod._parameters["weight"]
            stash[dkm_name] = original
            # Write straight into _parameters: setattr would reject a
            # non-Parameter tensor for a registered parameter name.
            mod._parameters["weight"] = self.dkm_layers[dkm_name](
                weight_override=original
            )

        return hook

    def _make_restore_hook(self, dkm_name: str):
        """Create the forward hook that puts the original weight Parameter back."""

        def hook(mod, inputs, output):
            original = self._stashed_weights.pop(dkm_name, None)
            if original is not None:
                mod._parameters["weight"] = original

        return hook

    def _restore_weights(self):
        """Put back any original weight still stashed (e.g. after a failed forward)."""
        for safe_name, original in self._stashed_weights.items():
            _, module = self._compressed_modules[safe_name]
            module._parameters["weight"] = original
        self._stashed_weights.clear()

    def forward(self, *args, **kwargs):
        """Forward pass through the wrapped model with DKM compression."""
        try:
            return self.model(*args, **kwargs)
        finally:
            # If a wrapped module raised, its restore hook never ran.
            self._restore_weights()

    def snap_weights(self):
        """
        Snap all weights to nearest centroids for inference.

        Each compressed module's weight is overwritten in place with the
        output of its DKM layer evaluated in eval mode. The DKM layers are
        left in eval mode. After this, the model can be serialized as
        (codebook + assignments) for compression.
        """
        self._restore_weights()
        with torch.no_grad():
            for safe_name, (_, module) in self._compressed_modules.items():
                dkm_layer = self.dkm_layers[safe_name]
                dkm_layer.eval()
                # Same call form as the forward hook, so the live weights
                # (not whatever the DKM layer captured at construction) are used.
                snapped = dkm_layer(weight_override=module.weight)
                module.weight.data.copy_(snapped)

    def get_compression_info(self) -> Dict:
        """
        Compute compression statistics for the model.

        Sizes assume float32 storage (32 bits per parameter). A compressed
        layer costs ``k * d * 32`` bits of codebook plus ``(N / d) * log2(k)``
        bits of assignment indices. All other parameters keep their full
        32-bit size.

        Returns dict with:
        - total_params: Total number of parameters
        - compressed_params: Number of compressed parameters
        - original_bits / compressed_bits: Model sizes in bits
        - original_size_mb: Original model size in MB (32-bit float)
        - compressed_size_mb: Compressed model size in MB
        - compression_ratio: Original/Compressed size ratio
        - per_layer: Per-layer compression details, keyed by module name
        """
        info = {
            "per_layer": {},
            "total_params": 0,
            "compressed_params": 0,
            "original_bits": 0,
            "compressed_bits": 0,
        }

        compressed_param_names = set()
        for safe_name, (name, module) in self._compressed_modules.items():
            dkm_layer = self.dkm_layers[safe_name]
            n_params = module.weight.numel()
            bits = math.log2(dkm_layer.n_clusters)
            dim = dkm_layer.dim
            n_vectors = n_params // dim
            codebook_bits = dkm_layer.n_clusters * dim * 32
            assignment_bits = n_vectors * bits
            layer_compressed_bits = codebook_bits + assignment_bits
            info["per_layer"][name] = {
                "n_params": n_params,
                "n_clusters": dkm_layer.n_clusters,
                "dim": dim,
                "bits": bits,
                "bits_per_weight": bits / dim,  # effective bits per weight
                "original_bits": n_params * 32,
                "compressed_bits": layer_compressed_bits,
                "compression_ratio": (n_params * 32) / max(layer_compressed_bits, 1),
            }
            info["compressed_params"] += n_params
            info["compressed_bits"] += layer_compressed_bits
            compressed_param_names.add(f"{name}.weight" if name else "weight")

        # Every parameter counts toward the original size; those that are not
        # clustered also contribute their full size to the compressed model.
        for pname, param in self.model.named_parameters():
            n_params = param.numel()
            info["total_params"] += n_params
            info["original_bits"] += n_params * 32
            if pname not in compressed_param_names:
                info["compressed_bits"] += n_params * 32

        info["original_size_mb"] = info["original_bits"] / 8 / 1024 / 1024
        info["compressed_size_mb"] = info["compressed_bits"] / 8 / 1024 / 1024
        info["compression_ratio"] = info["original_bits"] / max(info["compressed_bits"], 1)
        return info

    def export_compressed(self) -> Dict:
        """
        Export the compressed model as codebook + assignments.

        Calls :meth:`snap_weights` first, so the wrapped model's weights are
        modified in place.

        Returns a dict with:
        - 'state_dict': Original model state dict (with snapped weights)
        - 'codebooks': {layer_name: centroid tensor}
        - 'assignments': {layer_name: assignment index tensor}
        - 'layer_configs': {layer_name: {n_clusters, dim, tau, original_shape}}
        """
        self.snap_weights()
        export = {
            "state_dict": self.model.state_dict(),
            "codebooks": {},
            "assignments": {},
            "layer_configs": {},
        }
        for safe_name, (name, _) in self._compressed_modules.items():
            dkm_layer = self.dkm_layers[safe_name]
            export["codebooks"][name] = dkm_layer.get_codebook()
            export["assignments"][name] = dkm_layer.get_assignments()
            export["layer_configs"][name] = {
                "n_clusters": dkm_layer.n_clusters,
                "dim": dkm_layer.dim,
                "tau": dkm_layer.tau,
                "original_shape": list(dkm_layer.original_shape),
            }
        return export

    def remove_hooks(self):
        """Remove all hooks registered by this compressor (for clean serialization)."""
        for handle in self._hooks:
            handle.remove()
        self._hooks.clear()
        self._restore_weights()

    def __del__(self):
        """Cleanup hooks on deletion."""
        # __init__ may have failed before the hook list was created.
        if getattr(self, "_hooks", None) is not None:
            self.remove_hooks()
def compress_model(
    model: nn.Module,
    bits: int = 2,
    dim: int = 1,
    tau: float = 2e-5,
    conv_config: Optional[Dict] = None,
    fc_config: Optional[Dict] = None,
    skip_first_last: bool = True,
    min_params: int = 0,
    **kwargs,
) -> DKMCompressor:
    """
    High-level API to compress a pre-trained model using DKM.

    Follows the paper's convention of separate config for conv and fc layers.
    For example, "cv:6/8, fc:6/4" means:
    - Conv layers: 6 bits, 8 dimensions
    - FC layers: 6 bits, 4 dimensions

    Args:
        model: Pre-trained PyTorch model
        bits: Default bits for all layers
        dim: Default dimension for clustering
        tau: Temperature parameter
        conv_config: Config for conv layers {"bits": int, "dim": int}.
            Applied as an exact-name entry to every nn.Conv2d in ``model``.
        fc_config: Config for fc layers {"bits": int, "dim": int}.
            Applied as an exact-name entry to every nn.Linear in ``model``.
        skip_first_last: Skip first and last layers (Table 1 protocol)
        min_params: Minimum params to compress a layer
        **kwargs: Forwarded to DKMCompressor. A ``layer_config`` given here
            is merged over the conv/fc entries: for a key naming a layer, its
            values override ``conv_config``/``fc_config`` for that layer. Any
            other keys are passed through unchanged.

    Returns:
        DKMCompressor wrapping the model
    """
    # Pull out a caller-supplied layer_config; forwarding it through **kwargs
    # would collide with the layer_config argument built below (TypeError).
    user_layer_config = kwargs.pop("layer_config", None) or {}

    # Build per-layer config based on conv/fc separation
    layer_config: Dict[str, dict] = {}
    if conv_config or fc_config:
        for name, module in model.named_modules():
            if isinstance(module, nn.Conv2d) and conv_config:
                layer_config[name] = dict(conv_config)
            elif isinstance(module, nn.Linear) and fc_config:
                layer_config[name] = dict(fc_config)

    # Explicit user entries refine (and win over) the conv/fc defaults.
    for key, overrides in user_layer_config.items():
        layer_config[key] = {**layer_config.get(key, {}), **overrides}

    return DKMCompressor(
        model=model,
        bits=bits,
        dim=dim,
        tau=tau,
        layer_config=layer_config,
        skip_first_last=skip_first_last,
        min_params=min_params,
        **kwargs,
    )