"""
DKM Model Compressor
Wraps a pre-trained PyTorch model with DKM layers for weight clustering compression.
Follows the paper's approach of inserting DKM layers into the forward pass
(Section 3.2) without modifying the loss function or model architecture.
Supports per-layer configuration of bits and dimensions as described in Section 4.1.
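
Typical usage (illustrative sketch; the model and training loop are
placeholders for your own):

    compressor = compress_model(model, bits=4, dim=2)
    # ... fine-tune `compressor` with the unchanged task loss ...
    compressor.snap_weights()                 # freeze weights to centroids
    payload = compressor.export_compressed()  # codebooks + assignments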
"""
import math
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

from .dkm_layer import DKMLayer
class DKMCompressor(nn.Module):
"""
Wraps a pre-trained model with DKM clustering layers.
    During the forward pass, each wrapped weight parameter is replaced by its
DKM-compressed version. The original weights are kept as parameters
for gradient updates, while DKM layers control the clustering.
Args:
model: Pre-trained PyTorch model to compress
bits: Default number of bits for clustering (k = 2^bits)
dim: Default dimension for multi-dimensional clustering
tau: Default temperature for softmax attention
max_iter: Maximum DKM iterations per layer per forward pass
epsilon: Convergence threshold
layer_config: Optional per-layer configuration dict
Format: {layer_name: {"bits": int, "dim": int, "tau": float}}
skip_layers: List of layer names to skip (not compress)
min_params: Minimum number of parameters in a layer to compress
(paper uses 10000 for special handling)
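
    Example (illustrative sketch; ``model``, ``images``, ``targets`` and
    ``criterion`` are placeholders):
        >>> compressor = DKMCompressor(model, bits=4, dim=2)
        >>> out = compressor(images)        # weights are clustered on the fly
        >>> loss = criterion(out, targets)  # unchanged task loss
        >>> loss.backward()                 # gradients flow through DKM attention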
"""
def __init__(
self,
model: nn.Module,
bits: int = 2,
dim: int = 1,
tau: float = 2e-5,
max_iter: int = 5,
epsilon: float = 1e-4,
layer_config: Optional[Dict] = None,
skip_layers: Optional[List[str]] = None,
min_params: int = 0,
skip_first_last: bool = False,
):
super().__init__()
self.model = model
self.bits = bits
self.dim = dim
self.tau = tau
self.max_iter = max_iter
self.epsilon = epsilon
self.layer_config = layer_config or {}
self.skip_layers = skip_layers or []
self.min_params = min_params
self.skip_first_last = skip_first_last
# Create DKM layers for each applicable weight parameter
self.dkm_layers = nn.ModuleDict()
self._hooks = []
self._setup_dkm_layers()
def _get_compressible_layers(self) -> List[Tuple[str, nn.Module]]:
"""
Identify layers that should be compressed.
Following the paper (Section 4.1):
- Compress Conv2d and Linear layers
- Skip layers in skip_layers list
- Optionally skip first and last layers (Table 1 protocol)
- Skip layers with fewer than min_params parameters
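
        Example: skip_layers=["downsample"] excludes any layer whose name
        contains "downsample" (substring match; the name is illustrative).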
"""
compressible = []
all_layers = []
for name, module in self.model.named_modules():
if isinstance(module, (nn.Conv2d, nn.Linear)):
all_layers.append((name, module))
for i, (name, module) in enumerate(all_layers):
# Skip first/last layers if requested (common protocol from Table 1)
if self.skip_first_last:
if i == 0 or i == len(all_layers) - 1:
continue
# Skip explicitly excluded layers
if any(skip in name for skip in self.skip_layers):
continue
# Skip small layers
n_params = module.weight.numel()
if n_params < self.min_params:
continue
compressible.append((name, module))
return compressible
def _get_layer_config(self, name: str, module: nn.Module) -> dict:
"""
Get DKM configuration for a specific layer.
Per the paper (Section 4.1):
- Different bits/dim for conv vs fc layers
- Layers with <10000 params get 8-bit clustering
- Per-layer config overrides defaults
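
        Example layer_config (illustrative names):
            {"conv": {"bits": 4, "dim": 4},        # wildcard: names containing "conv"
             "features.0": {"bits": 8, "dim": 1}}  # exact name, wins over wildcards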
"""
config = {
"bits": self.bits,
"dim": self.dim,
"tau": self.tau,
"max_iter": self.max_iter,
"epsilon": self.epsilon,
}
# Paper: "we applied 8 bit clustering to a layer with fewer than 10,000 parameters"
if module.weight.numel() < 10000:
config["bits"] = 8
config["dim"] = 1
        # Wildcard configs first (e.g., "conv" applies to every layer whose
        # name contains "conv"), so exact names can override them below
        for pattern, pattern_config in self.layer_config.items():
            if pattern != name and pattern in name:
                config.update(pattern_config)
        # Exact per-layer overrides take precedence over wildcard patterns
        if name in self.layer_config:
            config.update(self.layer_config[name])
return config
def _setup_dkm_layers(self):
"""
Create DKM layers and register forward hooks to replace weights
during forward pass.
"""
compressible_layers = self._get_compressible_layers()
for name, module in compressible_layers:
config = self._get_layer_config(name, module)
n_clusters = 2 ** config["bits"]
dim = config["dim"]
# Validate dim is compatible with weight size
n_elements = module.weight.numel()
if n_elements % dim != 0:
                # Fall back to the largest dim that divides the element count
while dim > 1 and n_elements % dim != 0:
dim -= 1
config["dim"] = dim
# Create DKM layer
safe_name = name.replace(".", "_")
dkm_layer = DKMLayer(
weight_tensor=module.weight,
n_clusters=n_clusters,
tau=config["tau"],
dim=dim,
max_iter=config["max_iter"],
epsilon=config["epsilon"],
)
self.dkm_layers[safe_name] = dkm_layer
            # Register a pre-hook to swap in the compressed weight and a
            # post-hook to restore the original Parameter afterwards
            pre_hook, post_hook = self._make_hooks(safe_name)
            self._hooks.append(module.register_forward_pre_hook(pre_hook))
            self._hooks.append(module.register_forward_hook(post_hook))
    def _make_hooks(self, dkm_name: str):
        """
        Create a forward pre-hook that swaps the module's weight for its
        DKM-compressed version, plus a forward hook that restores the
        original Parameter afterwards.
        This implements the paper's approach: DKM is inserted into the
        forward pass, making optimization fully aligned with the task
        objective. Swapping the Parameter (rather than overwriting
        ``weight.data``, which would detach the graph and clobber the
        original weights) lets gradients flow through the DKM attention.
        """
        stash = {}

        def pre_hook(mod, input):
            dkm_layer = self.dkm_layers[dkm_name]
            # Compress from the *original* parameter on every forward pass
            compressed_weight = dkm_layer(weight_override=mod.weight)
            # Temporarily expose the compressed tensor as a plain attribute
            stash["weight"] = mod._parameters.pop("weight")
            mod.weight = compressed_weight

        def post_hook(mod, input, output):
            # Put the original Parameter back for the optimizer
            del mod.weight
            mod._parameters["weight"] = stash.pop("weight")

        return pre_hook, post_hook
def forward(self, *args, **kwargs):
"""Forward pass through the wrapped model with DKM compression."""
return self.model(*args, **kwargs)
def snap_weights(self):
"""
Snap all weights to nearest centroids for inference.
This is the final step before deployment: each weight is permanently
assigned to its nearest centroid. After this, the model can be
serialized as (codebook + assignments) for compression.
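
        Example (illustrative):
            >>> compressor.snap_weights()  # after fine-tuning converges
            >>> torch.save(compressor.model.state_dict(), "snapped.pt")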
"""
with torch.no_grad():
for name, module in self.model.named_modules():
safe_name = name.replace(".", "_")
if safe_name in self.dkm_layers:
dkm_layer = self.dkm_layers[safe_name]
dkm_layer.eval()
                    compressed_weight = dkm_layer(weight_override=module.weight)
module.weight.data.copy_(compressed_weight)
def get_compression_info(self) -> Dict:
"""
Compute compression statistics for the model.
Returns dict with:
- total_params: Total number of parameters
- compressed_params: Number of compressed parameters
- original_size_mb: Original model size in MB (32-bit float)
- compressed_size_mb: Compressed model size in MB
- compression_ratio: Original/Compressed size ratio
- per_layer: Per-layer compression details
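
        Example (illustrative):
            >>> info = compressor.get_compression_info()
            >>> print(f"{info['compression_ratio']:.1f}x smaller")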
"""
info = {
"per_layer": {},
"total_params": 0,
"compressed_params": 0,
"original_bits": 0,
"compressed_bits": 0,
}
# Count all parameters
for name, param in self.model.named_parameters():
n_params = param.numel()
info["total_params"] += n_params
info["original_bits"] += n_params * 32 # float32
# Count compressed layers
compressed_param_names = set()
for name, module in self.model.named_modules():
safe_name = name.replace(".", "_")
if safe_name in self.dkm_layers:
dkm_layer = self.dkm_layers[safe_name]
n_params = module.weight.numel()
                bits = int(math.log2(dkm_layer.n_clusters))
dim = dkm_layer.dim
bpw = bits / dim # effective bits per weight
# Compressed size:
# - Codebook: k * d * 32 bits (centroids stored in float32)
# - Assignments: (N/d) * bits indices
n_vectors = n_params // dim
codebook_bits = dkm_layer.n_clusters * dim * 32
assignment_bits = n_vectors * bits
layer_compressed_bits = codebook_bits + assignment_bits
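                # Worked example (hypothetical layer): 1M params at 4 bits with
                # dim=2 -> codebook 16*2*32 = 1,024 bits plus assignments
                # 500,000*4 = 2,000,000 bits, vs. 32,000,000 float32 bits:
                # roughly 16x smaller.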
info["per_layer"][name] = {
"n_params": n_params,
"n_clusters": dkm_layer.n_clusters,
"dim": dim,
"bits": bits,
"bits_per_weight": bpw,
"original_bits": n_params * 32,
"compressed_bits": layer_compressed_bits,
"compression_ratio": (n_params * 32) / max(layer_compressed_bits, 1),
}
info["compressed_params"] += n_params
info["compressed_bits"] += layer_compressed_bits
compressed_param_names.add(name + ".weight")
# Uncompressed parameters contribute their full size
uncompressed_bits = 0
for pname, param in self.model.named_parameters():
if pname not in compressed_param_names:
uncompressed_bits += param.numel() * 32
info["compressed_bits"] += uncompressed_bits
info["original_size_mb"] = info["original_bits"] / 8 / 1024 / 1024
info["compressed_size_mb"] = info["compressed_bits"] / 8 / 1024 / 1024
info["compression_ratio"] = info["original_bits"] / max(info["compressed_bits"], 1)
return info
def export_compressed(self) -> Dict:
"""
Export the compressed model as codebook + assignments.
Returns a dict with:
- 'state_dict': Original model state dict (with snapped weights)
- 'codebooks': {layer_name: centroid tensor}
- 'assignments': {layer_name: assignment index tensor}
- 'layer_configs': {layer_name: {bits, dim, ...}}
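
        Example (illustrative):
            >>> payload = compressor.export_compressed()
            >>> torch.save(payload, "model_dkm.pt")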
"""
self.snap_weights()
export = {
"state_dict": self.model.state_dict(),
"codebooks": {},
"assignments": {},
"layer_configs": {},
}
for name, module in self.model.named_modules():
safe_name = name.replace(".", "_")
if safe_name in self.dkm_layers:
dkm_layer = self.dkm_layers[safe_name]
export["codebooks"][name] = dkm_layer.get_codebook()
export["assignments"][name] = dkm_layer.get_assignments()
export["layer_configs"][name] = {
"n_clusters": dkm_layer.n_clusters,
"dim": dkm_layer.dim,
"tau": dkm_layer.tau,
"original_shape": list(dkm_layer.original_shape),
}
return export
def remove_hooks(self):
"""Remove all forward hooks (for clean serialization)."""
for hook in self._hooks:
hook.remove()
self._hooks.clear()
    def __del__(self):
        """Best-effort hook cleanup on deletion."""
        try:
            self.remove_hooks()
        except Exception:
            # Attributes may already be torn down during interpreter shutdown
            pass
def compress_model(
model: nn.Module,
bits: int = 2,
dim: int = 1,
tau: float = 2e-5,
conv_config: Optional[Dict] = None,
fc_config: Optional[Dict] = None,
skip_first_last: bool = True,
min_params: int = 0,
**kwargs,
) -> DKMCompressor:
"""
High-level API to compress a pre-trained model using DKM.
Follows the paper's convention of separate config for conv and fc layers.
For example, "cv:6/8, fc:6/4" means:
- Conv layers: 6 bits, 8 dimensions
- FC layers: 6 bits, 4 dimensions
Args:
model: Pre-trained PyTorch model
bits: Default bits for all layers
dim: Default dimension for clustering
tau: Temperature parameter
conv_config: Config for conv layers {"bits": int, "dim": int}
fc_config: Config for fc layers {"bits": int, "dim": int}
skip_first_last: Skip first and last layers (Table 1 protocol)
min_params: Minimum params to compress a layer
Returns:
DKMCompressor wrapping the model
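
    Example (illustrative; any pre-trained model works for ``model``):
        >>> compressor = compress_model(
        ...     model,
        ...     conv_config={"bits": 6, "dim": 8},  # "cv:6/8"
        ...     fc_config={"bits": 6, "dim": 4},    # "fc:6/4"
        ... )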
"""
# Build per-layer config based on conv/fc separation
layer_config = {}
if conv_config or fc_config:
for name, module in model.named_modules():
if isinstance(module, nn.Conv2d) and conv_config:
layer_config[name] = {**conv_config}
elif isinstance(module, nn.Linear) and fc_config:
layer_config[name] = {**fc_config}
compressor = DKMCompressor(
model=model,
bits=bits,
dim=dim,
tau=tau,
layer_config=layer_config,
skip_first_last=skip_first_last,
min_params=min_params,
**kwargs,
)
return compressor