# DKM: Differentiable K-Means Clustering Layer for Neural Network Compression
**PyTorch implementation of the ICLR 2022 paper by Cho et al.**
[Paper (arXiv:2108.12659)](https://arxiv.org/abs/2108.12659) | [ICLR 2022](https://openreview.net/forum?id=J_F_qqCE3Z5)
## Overview
DKM casts **k-means weight clustering** as a differentiable **attention problem**, enabling joint optimization of DNN parameters and clustering centroids through standard backpropagation. Unlike prior weight-clustering methods that rely on hard assignments and approximated gradients, DKM uses soft attention-based assignment that is fully differentiable.
### Key Innovation
```
Traditional: weights → hard k-means assignment → fixed centroids (not differentiable)
DKM:         weights → attention-based soft assignment → differentiable centroids
```
The DKM layer:
1. Computes a **distance matrix** D between weights W and centroids C
2. Applies **softmax with temperature τ** to get the attention matrix A = softmax(-D/τ)
3. Updates centroids: c_j = Σ_i(a_ij × w_i) / Σ_i(a_ij)
4. Iterates until convergence
5. Returns compressed weights: W̃ = A × C
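For intuition, here is a minimal sketch of that iteration in plain PyTorch, assuming scalar clustering (dim=1); the function name and tensor layout are illustrative and are not the repo's `DKMLayer` API:

```python
# Minimal sketch of the iteration above (illustrative; not the repo's
# DKMLayer API). Assumes scalar clustering, i.e. weights flattened to (N, 1).
import torch

def dkm_soft_cluster(w, c, tau=2e-5, max_iters=5, eps=1e-4):
    """w: (N, 1) flattened weights, c: (k, 1) centroids."""
    for _ in range(max_iters):
        d = torch.cdist(w, c)                              # 1. distance matrix (N, k)
        a = torch.softmax(-d / tau, dim=1)                 # 2. attention, rows sum to 1
        c_new = (a.t() @ w) / a.sum(0, keepdim=True).t()   # 3. attention-weighted centroid update
        done = torch.norm(c_new - c) < eps                 # 4. convergence check
        c = c_new
        if done:
            break
    w_tilde = a @ c                                        # 5. compressed weights (N, 1)
    return w_tilde, c, a
```

Because every step is a softmax or matrix product, gradients of the task loss reach both the weights and the centroids, which is the property the paper exploits.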
### Paper Results
| Model | Config | Top-1 Acc (%) | Size (MB) | Compression |
|-------|--------|--------------|-----------|-------------|
| ResNet50 | cv:6/6, fc:6/4 | 74.5 | 3.32 | 29.4× |
| MobileNet-v1 | cv:4/4, fc:4/2 | 63.9 | 0.72 | 22.4× |
| MobileNet-v2 | cv:2/1, fc:4/4 | 68.0 | 0.84 | 15.8× |
| DistilBERT | - | -1.1% acc drop | - | 11.8× |
## Installation
```bash
git clone https://huggingface.co/syedmohaiminulhoque/dkm-compression
cd dkm-compression
pip install torch torchvision
```
## Quick Start
```python
import torch
import torch.nn as nn
import torchvision
from dkm import compress_model
from dkm.utils import print_compression_summary
# Load any pre-trained model
model = torchvision.models.resnet18(weights="DEFAULT")
# Compress with DKM (2-bit clustering)
compressor = compress_model(
    model,
    bits=2,                # k = 2^bits = 4 clusters
    dim=1,                 # scalar clustering (dim=1) or multi-dim
    tau=2e-5,              # temperature (controls softness of assignment)
    skip_first_last=True,  # skip first/last layers (per paper protocol)
)
# Print compression statistics
info = compressor.get_compression_info()
print_compression_summary(info)
# Train with standard PyTorch loop (paper: SGD, lr=0.008, momentum=0.9)
optimizer = torch.optim.SGD(compressor.parameters(), lr=0.008, momentum=0.9)
criterion = nn.CrossEntropyLoss()
compressor.train()
for images, labels in dataloader:
    optimizer.zero_grad()
    outputs = compressor(images)
    loss = criterion(outputs, labels)
    loss.backward()  # Gradients flow through DKM attention layers
    optimizer.step()
# Snap to nearest centroids for inference
compressor.snap_weights()
# Export compressed model (codebook + assignments)
export = compressor.export_compressed()
torch.save(export, "compressed_model.pt")
```
## Multi-Dimensional Clustering (Section 3.3)
DKM supports multi-dimensional weight clustering for higher compression:
```python
# Paper notation: "bits/dim" e.g., "4/4" means 4 bits, 4 dimensions
# Effective bits-per-weight = bits / dim
# Configuration cv:6/8, fc:6/4 (as in Table 3 of the paper)
compressor = compress_model(
    model,
    bits=6,
    conv_config={"bits": 6, "dim": 8},  # 6 bits, 8 dims → 0.75 bpw
    fc_config={"bits": 6, "dim": 4},    # 6 bits, 4 dims → 1.5 bpw
    tau=2e-5,
)
```
| Config | Clusters | Dim | Effective BPW |
|--------|----------|-----|---------------|
| 3-bit | 8 | 1 | 3.0 |
| 2-bit | 4 | 1 | 2.0 |
| 1-bit | 2 | 1 | 1.0 |
| 4/4 | 16 | 4 | 1.0 |
| 8/8 | 256 | 8 | 1.0 |
| 4/8 | 16 | 8 | 0.5 |
| 8/16 | 256 | 16 | 0.5 |
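As a rough illustration of the multi-dimensional view (the reshaping and padding choices here are ours, not necessarily the library's), each weight tensor is regrouped into d-dimensional vectors so that one cluster index covers d weights, which is what lowers the effective bits-per-weight:

```python
# Illustrative reshaping only: grouping weights into d-dimensional vectors
# is what drops the effective bits-per-weight to bits / dim.
import torch

bits, dim = 4, 4                       # paper notation "4/4"
k = 2 ** bits                          # 16 clusters
w = torch.randn(64, 128)               # example weight tensor (8192 scalars)

flat = w.reshape(-1)
pad = (-flat.numel()) % dim            # pad so the length divides evenly by dim
vectors = torch.cat([flat, flat.new_zeros(pad)]).reshape(-1, dim)   # (2048, 4)

effective_bpw = bits / dim             # 4 / 4 = 1.0 bit per weight
print(vectors.shape, effective_bpw)
```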
## Temperature τ Guidelines (Appendix B)
The temperature controls the softness of cluster assignment:
- **Smaller τ** → harder assignment (near one-hot), closer to standard k-means
- **Larger τ** → softer assignment, more gradient flow, better for hard compression tasks
| Model | 3-bit | 2-bit | 1-bit | 4/4 | 8/8 |
|-------|-------|-------|-------|-----|-----|
| ResNet18 | 8e-6 | 2e-5 | 5e-5 | 5e-5 | 8e-5 |
| ResNet50 | 8e-6 | 2e-5 | 5e-5 | 4e-5 | OOM |
| MobileNet-v1 | 5e-5 | 1e-4 | 3e-4 | 1e-4 | 1e-4 |
| MobileNet-v2 | 5e-5 | 1e-4 | 1.5e-4 | 1e-4 | 1e-4 |
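A small self-contained demo of the effect (the distance values are made up for illustration) shows how shrinking τ pushes an attention row toward one-hot:

```python
# Toy demo of how tau changes one attention row (distances are made up).
import torch

d = torch.tensor([[1e-4, 2e-4, 8e-4]])   # distances from one weight to 3 centroids

for tau in (8e-6, 2e-5, 5e-5):
    a = torch.softmax(-d / tau, dim=1)
    print(f"tau={tau:.0e} -> {a.squeeze().tolist()}")
# Smaller tau gives a near one-hot row (hard k-means behaviour);
# larger tau spreads attention and lets more gradient through.
```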
## Architecture
```
dkm/
├── __init__.py       # Package exports
├── dkm_layer.py      # Core DKM layer (Section 3.2-3.3)
├── compressor.py     # Model wrapper with DKM layers (Section 4)
└── utils.py          # Compression analysis utilities
tests/
└── test_dkm.py       # 16 comprehensive test groups (all passing)
train.py # Full training pipeline (CIFAR-10 demo)
```
### Core Components
- **`DKMLayer`**: The differentiable k-means clustering layer. Implements the iterative attention-based clustering from Fig. 2 of the paper, with k-means++ initialization, warm start across batches, and convergence checking.
- **`DKMCompressor`**: Wraps any PyTorch model by inserting DKM layers via forward pre-hooks. Handles per-layer configuration (different bits/dim for conv vs fc), the paper's protocol for small layers (<10K params → 8-bit), and first/last layer skipping.
- **`compress_model`**: High-level API matching the paper's notation (cv:bits/dim, fc:bits/dim).
## Training Protocol (Section 4)
Following the paper exactly:
- **Optimizer**: SGD with momentum 0.9
- **Learning rate**: 0.008 (fixed, no per-layer tuning)
- **Loss**: Original task loss (no regularizers or modifications)
- **Epochs**: 200 for ImageNet, varies for GLUE
- **Batch size**: 128 per GPU (paper used 8Γ V100)
- **Convergence**: ε = 1e-4, max 5 DKM iterations per layer
- **Small layers**: Layers with <10,000 parameters get 8-bit clustering
## Compressed Model Format
After training, `export_compressed()` returns:
- **state_dict**: Standard PyTorch state dict (with snapped weights)
- **codebooks**: Per-layer centroid tensors (k × d float32)
- **assignments**: Per-layer cluster index tensors (N/d integers, b bits each)
- **layer_configs**: Per-layer DKM configuration
The actual compressed size = Σ(codebook_bits + assignment_bits) per layer + uncompressed params.
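As a sketch of how that formula and the export fields fit together (the helper functions below are illustrative, not part of the package), the size is tallied per layer and weights are rebuilt by indexing the codebook with the assignments:

```python
# Illustrative helpers (not part of the dkm package) for the size formula
# above and for rebuilding weights from codebook + assignments.
import torch

def compressed_bits(codebook, assignments, bits):
    codebook_bits = codebook.numel() * 32           # k x d float32 centroids
    assignment_bits = assignments.numel() * bits    # one b-bit index per d-dim vector
    return codebook_bits + assignment_bits

def reconstruct(codebook, assignments, shape):
    # Look up each index in the codebook, then restore the original layout.
    return codebook[assignments].reshape(shape)

codebook = torch.randn(16, 4)                       # "4/4": 16 clusters, dim 4
assignments = torch.randint(0, 16, (2048,))         # 8192 weights / dim 4
print(compressed_bits(codebook, assignments, bits=4) / 8, "bytes")   # 1280.0 bytes
w_hat = reconstruct(codebook, assignments, (64, 128))
```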
## Tests
All 16 test groups pass, covering:
1. Shape preservation (train & eval)
2. Distance matrix correctness
3. Attention matrix properties (row-sum=1, temperature effect)
4. Centroid convergence to cluster means
5. Gradient flow (differentiability, the key paper contribution; see the sketch below)
6. Multi-dimensional clustering
7. Iterative convergence
8. Full compressor pipeline
9. Weight snapping for inference
10. Model export
11. Multi-step training stability
12. Paper configurations (Table 1)
13. K-means++ initialization
14. Warm start across batches
15. Numerical stability (large/small/uniform weights)
16. ResNet-like model compression
```bash
python tests/test_dkm.py
```
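For reference, a gradient-flow check in the spirit of test 5 can be written in a few lines (an illustrative sketch, not the repo's test code):

```python
# Illustrative gradient-flow check (not the repo's tests): the soft,
# attention-based assignment keeps both weights and centroids differentiable.
import torch

w = torch.randn(32, 1, requires_grad=True)   # flattened layer weights
c = torch.randn(4, 1, requires_grad=True)    # k = 4 centroids

a = torch.softmax(-torch.cdist(w, c) / 1e-2, dim=1)   # soft assignment matrix
w_tilde = a @ c                                       # compressed weights
loss = (w_tilde ** 2).sum()                           # stand-in for a task loss
loss.backward()

assert w.grad is not None and c.grad is not None      # gradients reach both
```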
## Citation
```bibtex
@inproceedings{cho2022dkm,
title={DKM: Differentiable k-Means Clustering Layer for Neural Network Compression},
author={Cho, Minsik and Alizadeh-Vahid, Keivan and Adya, Saurabh and Rastegari, Mohammad},
booktitle={International Conference on Learning Representations (ICLR)},
year={2022},
url={https://openreview.net/forum?id=J_F_qqCE3Z5}
}
```
## License
This is a research implementation. The original paper is by Apple Research (Cho et al., ICLR 2022).