syedmohaiminulhoque committed
Commit f5e358d (verified)
Parent: 85967a5

Add core DKM layer implementation

Files changed (1):
  dkm/dkm_layer.py  +349 -0
dkm/dkm_layer.py ADDED
"""
DKM Layer: Differentiable K-Means Clustering Layer

Implements the core algorithm from Sections 3.2 and 3.3 of the paper:
1. Compute the distance matrix D between weights W and centroids C
2. Apply a softmax with temperature τ to get the attention matrix A
3. Update centroids: c_j = Σ_i(a_ij * w_i) / Σ_i(a_ij)
4. Iterate until convergence or the maximum number of iterations is reached
5. Compute the compressed weights: W_tilde = A @ C

Supports both 1D and multi-dimensional clustering (Section 3.3).
"""

from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
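
# Compression arithmetic (see extra_repr below): a codebook with n_clusters = 2^b
# entries encodes each d-dimensional sub-vector with b = log2(n_clusters) bits,
# i.e. b / d bits per weight. For example, n_clusters=16 with dim=1 gives 4 bits
# per weight; n_clusters=256 with dim=2 also gives 8 / 2 = 4 bits per weight,
# but with a vector-valued codebook.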


class DKMLayer(nn.Module):
    """
    Differentiable K-Means Clustering Layer.

    This layer performs differentiable weight clustering by casting k-means
    as an attention problem. During training, soft assignment via attention
    allows gradients to flow through the clustering process. During inference,
    weights are snapped to their nearest centroids (hard assignment).

    Args:
        weight_tensor: The weight parameter to cluster (nn.Parameter)
        n_clusters: Number of cluster centroids (k = 2^bits)
        tau: Temperature for the softmax attention (controls the hardness of the assignment)
        dim: Sub-vector length for multi-dimensional clustering (default=1 for scalar clustering)
        max_iter: Maximum number of DKM iterations per forward pass
        epsilon: Convergence threshold for centroid updates
        init_method: Centroid initialization method ('kmeans++' or 'random')
    """

    def __init__(
        self,
        weight_tensor: nn.Parameter,
        n_clusters: int = 16,
        tau: float = 2e-5,
        dim: int = 1,
        max_iter: int = 5,
        epsilon: float = 1e-4,
        init_method: str = "kmeans++",
    ):
        super().__init__()

        self.n_clusters = n_clusters
        self.tau = tau
        self.dim = dim
        self.max_iter = max_iter
        self.epsilon = epsilon
        self.init_method = init_method

        # Store a reference to the weight parameter (not a copy)
        self.weight = weight_tensor
        self.original_shape = weight_tensor.shape

        # Validate dimensions
        n_elements = weight_tensor.numel()
        if n_elements % dim != 0:
            raise ValueError(
                f"Weight tensor has {n_elements} elements, which is not "
                f"divisible by dim={dim}. Choose a different dim."
            )

        self.n_vectors = n_elements // dim

        # Initialize centroids as a buffer (not a parameter - they are updated
        # by the DKM iterations, not by the optimizer).
        # Shape: (n_clusters, dim) for multi-dim, (n_clusters, 1) for scalar
        centroids = self._initialize_centroids(weight_tensor)
        self.register_buffer("centroids", centroids)

        # Track whether this is the first forward pass
        self.register_buffer("initialized", torch.tensor(False))

    def _initialize_centroids(self, weight_tensor: torch.Tensor) -> torch.Tensor:
        """
        Initialize centroids using k-means++ or random selection.

        For multi-dimensional clustering, the weights are reshaped into
        (N/d, d) sub-vectors before selecting the initial centroids.
        """
        with torch.no_grad():
            # Flatten and reshape for multi-dim clustering
            flat_weights = weight_tensor.detach().reshape(-1)

            if self.dim == 1:
                vectors = flat_weights.unsqueeze(1)  # (N, 1)
            else:
                vectors = flat_weights.reshape(self.n_vectors, self.dim)  # (N/d, d)

            if self.init_method == "kmeans++":
                centroids = self._kmeans_plus_plus_init(vectors)
            else:
                # Random initialization: select k random weight vectors
                indices = torch.randperm(vectors.shape[0])[:self.n_clusters]
                centroids = vectors[indices].clone()

        return centroids  # (n_clusters, dim)

    def _kmeans_plus_plus_init(self, vectors: torch.Tensor) -> torch.Tensor:
        """
        K-means++ initialization for better centroid starting positions.

        Args:
            vectors: (N, d) weight vectors

        Returns:
            centroids: (k, d) initial centroids
        """
        n = vectors.shape[0]
        k = self.n_clusters

        # If there are fewer vectors than clusters, duplicate them with small noise
        if n < k:
            repeats = (k // n) + 1
            vectors_expanded = vectors.repeat(repeats, 1)[:k]
            noise = torch.randn_like(vectors_expanded) * 1e-6
            return (vectors_expanded + noise).clone()

        # Choose the first centroid uniformly at random
        idx = torch.randint(0, n, (1,)).item()
        centroids = [vectors[idx].clone()]

        for _ in range(1, k):
            # Compute the distance from each vector to its nearest existing centroid
            stacked = torch.stack(centroids, dim=0)  # (current_k, d)
            # distances: (N, current_k)
            dists = torch.cdist(vectors.unsqueeze(0), stacked.unsqueeze(0)).squeeze(0)
            min_dists = dists.min(dim=1).values  # (N,)

            # Choose the next centroid with probability proportional to squared distance
            probs = min_dists ** 2
            prob_sum = probs.sum()

            if prob_sum < 1e-30 or torch.isnan(prob_sum) or torch.isinf(prob_sum):
                # Fallback: all distances are (near) zero, e.g. uniform weights.
                # Pick a uniformly random index; duplicate centroids are de-tied
                # with noise below.
                idx = torch.randint(0, n, (1,)).item()
            else:
                probs = probs / prob_sum
                # Clamp to avoid negative values from float errors
                probs = probs.clamp(min=0.0)
                idx = torch.multinomial(probs, 1).item()

            centroids.append(vectors[idx].clone())

        result = torch.stack(centroids, dim=0)  # (k, d)

        # Add tiny noise to break ties if any centroids are identical
        if result.unique(dim=0).shape[0] < k:
            noise = torch.randn_like(result) * (result.abs().mean() * 1e-4 + 1e-8)
            result = result + noise

        return result

    def _compute_distance_matrix(
        self, weights: torch.Tensor, centroids: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the negative squared Euclidean distance matrix D:

            D[i,j] = -||w_i - c_j||^2

        Per the paper, d_ij = -f(w_i, c_j) where f is the Euclidean distance.
        We use the squared Euclidean distance for efficiency; the
        nearest-centroid assignment is unchanged, and the softmax attention is
        simply computed over squared rather than plain distances.

        Args:
            weights: (N/d, d) weight sub-vectors
            centroids: (k, d) centroid vectors

        Returns:
            D: (N/d, k) negative distance matrix
        """
        # Efficient computation: ||w - c||^2 = ||w||^2 - 2*w.c + ||c||^2
        w_sq = (weights ** 2).sum(dim=1, keepdim=True)  # (N/d, 1)
        c_sq = (centroids ** 2).sum(dim=1, keepdim=True).t()  # (1, k)
        wc = weights @ centroids.t()  # (N/d, k)

        D = -(w_sq - 2 * wc + c_sq)  # (N/d, k) — negative squared Euclidean
        return D

    def _compute_attention(self, D: torch.Tensor) -> torch.Tensor:
        """
        Compute the attention matrix A from the distance matrix D using a
        softmax with temperature τ:

            a_ij = exp(d_ij / τ) / Σ_j' exp(d_ij' / τ)

        This is the key differentiable component that enables gradient flow
        through the clustering assignment.

        Args:
            D: (N/d, k) negative distance matrix

        Returns:
            A: (N/d, k) attention matrix (rows sum to 1)
        """
        # Scale by the temperature and apply softmax along the cluster dimension
        A = F.softmax(D / self.tau, dim=1)  # (N/d, k)
        return A
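
    # A small numeric illustration of the temperature (approximate values): for
    # one weight with D = [-0.01, -1.0] against two centroids, tau=1.0 yields
    # A ≈ [0.73, 0.27] (a soft assignment), while tau=0.01 yields A ≈ [1.0, 0.0].
    # Smaller tau pushes the attention toward a hard, one-hot k-means assignment,
    # at the cost of smaller gradients for the non-selected centroids.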

    def _update_centroids(
        self, A: torch.Tensor, weights: torch.Tensor
    ) -> torch.Tensor:
        """
        Update the centroids using the attention-weighted average of the weights:

            c_j_new = Σ_i(a_ij * w_i) / Σ_i(a_ij)

        This is the M-step equivalent in the EM interpretation (Appendix G).

        Args:
            A: (N/d, k) attention matrix
            weights: (N/d, d) weight sub-vectors

        Returns:
            new_centroids: (k, d) updated centroids
        """
        # Numerator: Σ_i(a_ij * w_i) for each centroid j
        # A.t() @ weights: (k, N/d) @ (N/d, d) = (k, d)
        numerator = A.t() @ weights  # (k, d)

        # Denominator: Σ_i(a_ij) for each centroid j
        denominator = A.sum(dim=0, keepdim=True).t()  # (k, 1)

        # Avoid division by zero
        denominator = denominator.clamp(min=1e-10)

        new_centroids = numerator / denominator  # (k, d)
        return new_centroids

    def forward(self, weight_override: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass: perform differentiable k-means clustering.

        The iterative process (Fig. 2 of the paper):
        1. Compute the distance matrix D between weights and centroids
        2. Compute the attention matrix A = softmax(D/τ)
        3. Update the centroids: C_new = (A^T @ W) / sum(A)
        4. Repeat until convergence or max_iter is reached
        5. Return the compressed weights: W_tilde = A @ C

        During training: returns soft-assigned weights (differentiable).
        During eval: returns hard-assigned weights (nearest centroid).

        Args:
            weight_override: Optional weight tensor to use instead of self.weight

        Returns:
            compressed_weight: Tensor with the same shape as the original weight
        """
        weight_tensor = weight_override if weight_override is not None else self.weight

        # Reshape the weights into sub-vectors for multi-dim clustering
        flat_weights = weight_tensor.reshape(-1)
        if self.dim == 1:
            W = flat_weights.unsqueeze(1)  # (N, 1)
        else:
            W = flat_weights.reshape(self.n_vectors, self.dim)  # (N/d, d)

        # Re-initialize the centroids on the first call (ensures the correct device/dtype)
        if not self.initialized:
            self.centroids = self._initialize_centroids(weight_tensor).to(
                weight_tensor.device
            )
            self.initialized = torch.tensor(True, device=weight_tensor.device)

        # Current centroids (detached from the previous iteration's graph)
        C = self.centroids.clone()

        # Iterative DKM clustering (Section 3.2, Fig. 2)
        for iteration in range(self.max_iter):
            # Step 1: Compute the distance matrix
            D = self._compute_distance_matrix(W, C)

            # Step 2: Compute the attention matrix with temperature τ
            A = self._compute_attention(D)

            # Step 3: Update the centroids
            C_new = self._update_centroids(A, W)

            # Step 4: Check convergence |C - C_new| ≤ ε
            delta = (C - C_new).abs().max().item()
            C = C_new

            if delta <= self.epsilon:
                break

        # Store the converged centroids for the next batch (warm start)
        self.centroids = C.detach().clone()

        if self.training:
            # Training: use the soft assignment (differentiable)
            # W_tilde = A @ C (attention-weighted centroids)
            W_tilde = A @ C  # (N/d, d)
        else:
            # Inference: snap to the nearest centroid (hard assignment)
            # Find the nearest centroid for each weight vector
            D_final = self._compute_distance_matrix(W, C)
            assignments = D_final.argmax(dim=1)  # (N/d,) — argmax because D is negative distance
            W_tilde = C[assignments]  # (N/d, d)

        # Reshape back to the original weight shape
        compressed_weight = W_tilde.reshape(self.original_shape)

        return compressed_weight

    def get_assignments(self) -> torch.Tensor:
        """
        Get the hard cluster assignment for each weight (for inference/analysis).

        Returns:
            assignments: (N/d,) tensor of cluster indices
        """
        with torch.no_grad():
            flat_weights = self.weight.detach().reshape(-1)
            if self.dim == 1:
                W = flat_weights.unsqueeze(1)
            else:
                W = flat_weights.reshape(self.n_vectors, self.dim)

            D = self._compute_distance_matrix(W, self.centroids)
            assignments = D.argmax(dim=1)

        return assignments

    def get_codebook(self) -> torch.Tensor:
        """
        Get the current centroid codebook.

        Returns:
            centroids: (k, d) centroid values
        """
        return self.centroids.clone()

    def extra_repr(self) -> str:
        bits = math.log2(self.n_clusters)
        bpw = bits / self.dim
        return (
            f"n_clusters={self.n_clusters}, tau={self.tau}, dim={self.dim}, "
            f"max_iter={self.max_iter}, eps={self.epsilon}, "
            f"bits={bits:.1f}, bits_per_weight={bpw:.2f}, "
            f"weight_shape={list(self.original_shape)}"
        )
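
Usage sketch (illustrative, not part of the committed file; it assumes only the DKMLayer API defined above and a standard PyTorch nn.Linear):

import torch
import torch.nn as nn
from dkm.dkm_layer import DKMLayer

torch.manual_seed(0)

# A small linear layer whose weight we cluster into 16 centroids (4 bits per weight).
linear = nn.Linear(64, 32)
dkm = DKMLayer(linear.weight, n_clusters=16, tau=1e-3, dim=1, max_iter=5)  # tau chosen for illustration

# Training mode: soft, differentiable assignment; gradients reach linear.weight.
dkm.train()
w_soft = dkm()                       # same shape as linear.weight: (32, 64)
loss = w_soft.pow(2).mean()          # stand-in for a task loss on the compressed weights
loss.backward()                      # populates linear.weight.grad through the attention

# Eval mode: hard assignment; every weight snaps to its nearest centroid.
dkm.eval()
with torch.no_grad():
    w_hard = dkm()
print(torch.unique(w_hard).numel())  # at most n_clusters (= 16) distinct values

codebook = dkm.get_codebook()        # (16, 1) centroid values
assignments = dkm.get_assignments()  # (2048,) cluster index per weight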