# DKM: Differentiable K-Means Clustering Layer for Neural Network Compression

**PyTorch implementation of the ICLR 2022 paper by Cho et al.**

πŸ“„ [Paper (arXiv:2108.12659)](https://arxiv.org/abs/2108.12659) | πŸ›οΈ [ICLR 2022](https://openreview.net/forum?id=J_F_qqCE3Z5)

## Overview

DKM casts **k-means weight clustering** as a differentiable **attention problem**, enabling joint optimization of DNN parameters and clustering centroids through standard backpropagation. Unlike prior weight-clustering methods that rely on hard assignments and approximated gradients, DKM uses soft attention-based assignment that is fully differentiable.

### Key Innovation

```
Traditional: weights β†’ hard k-means assignment β†’ fixed centroids (not differentiable)
DKM:         weights β†’ attention-based soft assignment β†’ differentiable centroids
```

The DKM layer (sketched in code after this list):
1. Computes a **distance matrix** D between weights W and centroids C
2. Applies **softmax with temperature τ** over the negated distances, so nearer centroids receive higher attention: A = softmax(-D/τ)
3. Updates centroids: c_j = Ξ£_i(a_ij Γ— w_i) / Ξ£_i(a_ij)
4. Iterates until convergence
5. Returns compressed weights: W̃ = A × C
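
A minimal, self-contained sketch of these five steps in PyTorch. The function names, the squared-distance metric (used here so gradients stay smooth; the paper uses Euclidean distance), and the naive initialization are illustrative and not the repo's `DKMLayer` API:

```python
import torch

def dkm_attention_step(W, C, tau=2e-5):
    """One DKM iteration: soft-assign weight rows W (N, d) to centroids C (k, d)."""
    # Step 1: pairwise squared distances between weights and centroids.
    D = ((W.unsqueeze(1) - C.unsqueeze(0)) ** 2).sum(-1)        # (N, k)
    # Step 2: softmax over negated distances -> attention matrix, rows sum to 1.
    A = torch.softmax(-D / tau, dim=1)                          # (N, k)
    # Step 3: attention-weighted centroid update.
    C_new = (A.t() @ W) / (A.sum(dim=0).unsqueeze(1) + 1e-12)   # (k, d)
    return A, C_new

def dkm_forward(W, k=4, tau=2e-5, iters=5, eps=1e-4):
    """Steps 4-5: iterate until centroids stop moving, then return W_tilde = A @ C."""
    C = W[torch.randperm(W.shape[0])[:k]].detach().clone()      # naive init (repo uses k-means++)
    for _ in range(iters):
        A, C_new = dkm_attention_step(W, C, tau)
        converged = torch.norm(C_new - C) < eps                 # Step 4: convergence check
        C = C_new
        if converged:
            break
    return A @ C                                                # Step 5: compressed weights

W_tilde = dkm_forward(torch.randn(1024, 1))                     # scalar (dim=1) clustering
```

Every operation above is differentiable, so a task loss computed from W̃ back-propagates into the original weights and, through the attention-weighted updates, into the centroids as well; that is the property the paper exploits.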

### Paper Results

| Model | Config | Top-1 Acc (%) | Size (MB) | Compression |
|-------|--------|--------------|-----------|-------------|
| ResNet50 | cv:6/6, fc:6/4 | 74.5 | 3.32 | 29.4Γ— |
| MobileNet-v1 | cv:4/4, fc:4/2 | 63.9 | 0.72 | 22.4Γ— |
| MobileNet-v2 | cv:2/1, fc:4/4 | 68.0 | 0.84 | 15.8Γ— |
| DistilBERT | - | -1.1% acc drop | - | 11.8Γ— |

## Installation

```bash
git clone https://huggingface.co/syedmohaiminulhoque/dkm-compression
cd dkm-compression
pip install torch torchvision
```

## Quick Start

```python
import torch
import torch.nn as nn
import torchvision
from dkm import compress_model
from dkm.utils import print_compression_summary

# Load any pre-trained model
model = torchvision.models.resnet18(weights="DEFAULT")

# Compress with DKM (2-bit clustering)
compressor = compress_model(
    model, 
    bits=2,           # k = 2^bits = 4 clusters
    dim=1,            # scalar clustering (dim=1) or multi-dim
    tau=2e-5,         # temperature (controls softness of assignment)
    skip_first_last=True,  # skip first/last layers (per paper protocol)
)

# Print compression statistics
info = compressor.get_compression_info()
print_compression_summary(info)

# Train with standard PyTorch loop (paper: SGD, lr=0.008, momentum=0.9)
optimizer = torch.optim.SGD(compressor.parameters(), lr=0.008, momentum=0.9)
criterion = nn.CrossEntropyLoss()

compressor.train()
for images, labels in dataloader:
    optimizer.zero_grad()
    outputs = compressor(images)
    loss = criterion(outputs, labels)
    loss.backward()  # Gradients flow through DKM attention layers
    optimizer.step()

# Snap to nearest centroids for inference
compressor.snap_weights()

# Export compressed model (codebook + assignments)
export = compressor.export_compressed()
torch.save(export, "compressed_model.pt")
```

## Multi-Dimensional Clustering (Section 3.3)

DKM supports multi-dimensional weight clustering for higher compression:

```python
# Paper notation: "bits/dim" e.g., "4/4" means 4 bits, 4 dimensions
# Effective bits-per-weight = bits / dim

# Configuration cv:6/8, fc:6/4 (as in Table 3 of the paper)
compressor = compress_model(
    model,
    bits=6,
    conv_config={"bits": 6, "dim": 8},   # 6 bits, 8 dims β†’ 0.75 bpw
    fc_config={"bits": 6, "dim": 4},      # 6 bits, 4 dims β†’ 1.5 bpw
    tau=2e-5,
)
```

| Config | Clusters | Dim | Effective BPW |
|--------|----------|-----|---------------|
| 3-bit  | 8        | 1   | 3.0           |
| 2-bit  | 4        | 1   | 2.0           |
| 1-bit  | 2        | 1   | 1.0           |
| 4/4    | 16       | 4   | 1.0           |
| 8/8    | 256      | 8   | 1.0           |
| 4/8    | 16       | 8   | 0.5           |
| 8/16   | 256      | 16  | 0.5           |
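
For multi-dimensional clustering, the weight tensor is flattened and consecutive scalars are grouped into d-dimensional points that are clustered as units. A minimal reshaping sketch follows; the zero-padding for non-divisible lengths is an assumption, not necessarily the repo's strategy:

```python
import torch

def to_blocks(weight, dim):
    """Group a weight tensor into (N/d, d) points for d-dimensional DKM clustering."""
    flat = weight.reshape(-1)
    pad = (-flat.numel()) % dim                   # pad so the length divides evenly
    if pad:
        flat = torch.cat([flat, flat.new_zeros(pad)])
    return flat.reshape(-1, dim)

blocks = to_blocks(torch.randn(64, 3, 3, 3), dim=4)   # conv weight -> (432, 4) points
W_tilde = dkm_forward(blocks, k=2 ** 6)   # dkm_forward: sketch from Key Innovation above
                                          # 6 bits / 4 dims -> 1.5 effective bits per weight
```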

## Temperature Ο„ Guidelines (Appendix B)

The temperature controls the softness of the cluster assignment (illustrated numerically below):
- **Smaller τ** → harder, near one-hot assignment, closer to standard k-means
- **Larger τ** → softer assignment and more gradient flow, which helps for more aggressive compression targets
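
A quick numerical illustration of this effect, assuming arbitrary negated distances from one weight to four centroids:

```python
import torch

d = -torch.tensor([[0.10, 0.11, 0.30, 0.50]])   # negated distances to 4 centroids
for tau in (1e-1, 1e-2, 1e-3):
    print(tau, torch.softmax(d / tau, dim=1))
# tau=1e-1 -> roughly [0.49, 0.44, 0.07, 0.01]  (soft: gradient reaches several centroids)
# tau=1e-3 -> roughly [1.00, 0.00, 0.00, 0.00]  (hard: effectively plain k-means assignment)
```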

| Model | 3-bit | 2-bit | 1-bit | 4/4 | 8/8 |
|-------|-------|-------|-------|-----|-----|
| ResNet18 | 8e-6 | 2e-5 | 5e-5 | 5e-5 | 8e-5 |
| ResNet50 | 8e-6 | 2e-5 | 5e-5 | 4e-5 | OOM |
| MobileNet-v1 | 5e-5 | 1e-4 | 3e-4 | 1e-4 | 1e-4 |
| MobileNet-v2 | 5e-5 | 1e-4 | 1.5e-4 | 1e-4 | 1e-4 |

## Architecture

```
dkm/
β”œβ”€β”€ __init__.py          # Package exports
β”œβ”€β”€ dkm_layer.py         # Core DKM layer (Section 3.2-3.3)
β”œβ”€β”€ compressor.py        # Model wrapper with DKM layers (Section 4)
└── utils.py             # Compression analysis utilities

tests/
└── test_dkm.py          # 16 comprehensive test groups (all passing)

train.py                 # Full training pipeline (CIFAR-10 demo)
```

### Core Components

- **`DKMLayer`**: The differentiable k-means clustering layer. Implements the iterative attention-based clustering from Fig. 2 of the paper, with k-means++ initialization, warm start across batches, and convergence checking.

- **`DKMCompressor`**: Wraps any PyTorch model by inserting DKM layers via forward pre-hooks. Handles per-layer configuration (different bits/dim for conv vs fc), the paper's protocol for small layers (<10K params → 8-bit), and first/last layer skipping. See the sketch after this list for the underlying re-parameterization idea.

- **`compress_model`**: High-level API matching the paper's notation (cv:bits/dim, fc:bits/dim).
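
The repo attaches DKM via forward pre-hooks; purely as an illustration of the same idea, here is a hedged sketch using `torch.nn.utils.parametrize` (PyTorch 1.9+), where the layer's effective weight is recomputed as a differentiable function of the stored weight on every forward pass. This is an alternative wiring, not the repo's implementation, and the inline clustering is the simplified loop from the Key Innovation sketch:

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize

class DKMParametrization(nn.Module):
    """Recompute a DKM-clustered weight from the stored weight on each forward pass."""
    def __init__(self, k=4, tau=2e-5, iters=5):
        super().__init__()
        self.k, self.tau, self.iters = k, tau, iters

    def forward(self, weight):
        w = weight.reshape(-1, 1)                               # scalar (dim=1) clustering
        c = w[torch.randperm(w.shape[0])[: self.k]].detach().clone()  # naive init for brevity
        for _ in range(self.iters):
            d = ((w.unsqueeze(1) - c.unsqueeze(0)) ** 2).sum(-1)      # squared distances
            a = torch.softmax(-d / self.tau, dim=1)                   # soft assignment
            c = (a.t() @ w) / (a.sum(dim=0, keepdim=True).t() + 1e-12)
        return (a @ c).reshape(weight.shape)                    # differentiable W_tilde

layer = nn.Linear(128, 64)
parametrize.register_parametrization(layer, "weight", DKMParametrization(k=4))
out = layer(torch.randn(8, 128))   # forward now uses the clustered weight
out.sum().backward()               # gradients reach layer.parametrizations.weight.original
```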

## Training Protocol (Section 4)

Following the paper exactly:
- **Optimizer**: SGD with momentum 0.9
- **Learning rate**: 0.008 (fixed, no per-layer tuning)
- **Loss**: Original task loss (no regularizers or modifications)
- **Epochs**: 200 for ImageNet, varies for GLUE
- **Batch size**: 128 per GPU (paper used 8Γ— V100)
- **Convergence**: Ξ΅ = 1e-4, max 5 DKM iterations per layer
- **Small layers**: Layers with <10,000 parameters get 8-bit clustering

## Compressed Model Format

After training, `export_compressed()` returns:
- **state_dict**: Standard PyTorch state dict (with snapped weights)
- **codebooks**: Per-layer centroid tensors (k Γ— d float32)
- **assignments**: Per-layer cluster index tensors (N/d integers, b bits each)
- **layer_configs**: Per-layer DKM configuration

The actual compressed size is Σ(codebook bits + assignment bits) over all compressed layers, plus the size of any parameters left uncompressed.
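
A back-of-the-envelope calculator for that formula; the fp32 codebook and packed-integer assignment layout are assumptions based on the description above:

```python
def dkm_layer_size_bits(n_params, bits, dim, float_bits=32):
    """Estimate the stored size of one DKM-compressed layer, in bits."""
    k = 2 ** bits                              # number of clusters
    codebook = k * dim * float_bits            # k x d float32 centroids
    assignments = (n_params // dim) * bits     # one b-bit index per d-dim block
    return codebook + assignments

# Example: a 2,359,296-parameter conv layer at 4 bits / 4 dims
# -> ~0.30 MB, versus ~9.4 MB in fp32 (roughly 32x smaller, minus codebook overhead)
print(dkm_layer_size_bits(2_359_296, bits=4, dim=4) / 8 / 1e6)
```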

## Tests

All 16 test groups pass, covering:
1. Shape preservation (train & eval)
2. Distance matrix correctness
3. Attention matrix properties (row-sum=1, temperature effect)
4. Centroid convergence to cluster means
5. Gradient flow (differentiability β€” key paper contribution)
6. Multi-dimensional clustering
7. Iterative convergence
8. Full compressor pipeline
9. Weight snapping for inference
10. Model export
11. Multi-step training stability
12. Paper configurations (Table 1)
13. K-means++ initialization
14. Warm start across batches
15. Numerical stability (large/small/uniform weights)
16. ResNet-like model compression

```bash
python tests/test_dkm.py
```

## Citation

```bibtex
@inproceedings{cho2022dkm,
  title={DKM: Differentiable k-Means Clustering Layer for Neural Network Compression},
  author={Cho, Minsik and Alizadeh-Vahid, Keivan and Adya, Saurabh and Rastegari, Mohammad},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2022},
  url={https://openreview.net/forum?id=J_F_qqCE3Z5}
}
```

## License

This is a research implementation. The original paper is by Apple Research (Cho et al., ICLR 2022).