"""
DKM Layer: Differentiable K-Means Clustering Layer

Implements the core algorithm from Sections 3.2 and 3.3 of the paper:
  1. Compute distance matrix D between weights W and centroids C
  2. Apply softmax with temperature τ to get attention matrix A
  3. Update centroids: c_j = Σ_i(a_ij * w_i) / Σ_i(a_ij)
  4. Iterate until convergence or max iterations
  5. Compute compressed weights: W_tilde = A @ C

Supports both 1D and multi-dimensional clustering (Section 3.3).
"""

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class DKMLayer(nn.Module):
    """
    Differentiable K-Means Clustering Layer.
    
    This layer performs differentiable weight clustering by casting k-means
    as an attention problem. During training, soft assignment via attention
    allows gradients to flow through the clustering process. During inference,
    weights are snapped to nearest centroids (hard assignment).
    
    Args:
        weight_tensor: The weight parameter to cluster (nn.Parameter)
        n_clusters: Number of cluster centroids (k = 2^bits)
        tau: Temperature for softmax attention (controls hardness of assignment)
        dim: Dimension for multi-dimensional clustering (default=1 for scalar)
        max_iter: Maximum number of DKM iterations per forward pass
        epsilon: Convergence threshold for centroid updates
        init_method: Centroid initialization method ('kmeans++' or 'random')
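    
    Example (illustrative sketch; the layer size here is an arbitrary choice):
        >>> linear = nn.Linear(64, 64)
        >>> dkm = DKMLayer(linear.weight, n_clusters=16)
        >>> w_tilde = dkm()  # soft-clustered weights during training
        >>> w_tilde.shape == linear.weight.shape
        True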
    """
    
    def __init__(
        self,
        weight_tensor: nn.Parameter,
        n_clusters: int = 16,
        tau: float = 2e-5,
        dim: int = 1,
        max_iter: int = 5,
        epsilon: float = 1e-4,
        init_method: str = "kmeans++",
    ):
        super().__init__()
        
        self.n_clusters = n_clusters
        self.tau = tau
        self.dim = dim
        self.max_iter = max_iter
        self.epsilon = epsilon
        self.init_method = init_method
        
        # Store reference to the weight parameter (not a copy)
        self.weight = weight_tensor
        self.original_shape = weight_tensor.shape
        
        # Validate dimensions
        n_elements = weight_tensor.numel()
        if n_elements % dim != 0:
            raise ValueError(
                f"Weight tensor has {n_elements} elements, which is not "
                f"divisible by dim={dim}. Choose a different dim."
            )
        
        self.n_vectors = n_elements // dim
        
        # Initialize centroids as a buffer (not a parameter - updated by DKM iterations)
        # Shape: (n_clusters, dim) for multi-dim, (n_clusters, 1) for scalar
        centroids = self._initialize_centroids(weight_tensor)
        self.register_buffer("centroids", centroids)
        
        # Track whether this is the first forward pass
        self.register_buffer("initialized", torch.tensor(False))
    
    def _initialize_centroids(self, weight_tensor: torch.Tensor) -> torch.Tensor:
        """
        Initialize centroids using k-means++ or random selection.
        
        For multi-dimensional clustering, weights are reshaped into 
        (N/d, d) sub-vectors before selecting initial centroids.
        """
        with torch.no_grad():
            # Flatten and reshape for multi-dim clustering
            flat_weights = weight_tensor.detach().reshape(-1)
            
            if self.dim == 1:
                vectors = flat_weights.unsqueeze(1)  # (N, 1)
            else:
                vectors = flat_weights.reshape(self.n_vectors, self.dim)  # (N/d, d)
            
            if self.init_method == "kmeans++":
                centroids = self._kmeans_plus_plus_init(vectors)
            else:
                # Random initialization: select k random weight vectors
                indices = torch.randperm(vectors.shape[0])[:self.n_clusters]
                centroids = vectors[indices].clone()
            
            return centroids  # (n_clusters, dim)
    
    def _kmeans_plus_plus_init(self, vectors: torch.Tensor) -> torch.Tensor:
        """
        K-means++ initialization for better centroid starting positions.
        
        Args:
            vectors: (N, d) weight vectors
        
        Returns:
            centroids: (k, d) initial centroids
        """
        n = vectors.shape[0]
        k = self.n_clusters
        d = vectors.shape[1]
        
        # If fewer vectors than clusters, duplicate with small noise
        if n < k:
            repeats = (k // n) + 1
            vectors_expanded = vectors.repeat(repeats, 1)[:k]
            noise = torch.randn_like(vectors_expanded) * 1e-6
            return (vectors_expanded + noise).clone()
        
        # Choose first centroid randomly
        idx = torch.randint(0, n, (1,)).item()
        centroids = [vectors[idx].clone()]
        
        for _ in range(1, k):
            # Compute distances to nearest existing centroid
            stacked = torch.stack(centroids, dim=0)  # (current_k, d)
            # Distances from each vector to every existing centroid: (N, current_k)
            dists = torch.cdist(vectors, stacked)
            min_dists = dists.min(dim=1).values  # (N,)
            
            # Choose next centroid with probability proportional to distance squared
            probs = min_dists ** 2
            prob_sum = probs.sum()
            
            if prob_sum < 1e-30 or torch.isnan(prob_sum) or torch.isinf(prob_sum):
                # Fallback: degenerate distances (e.g., all weights identical).
                # Pick a random index; any duplicate centroids are perturbed
                # with noise after the loop.
                idx = torch.randint(0, n, (1,)).item()
            else:
                probs = probs / prob_sum
                # Clamp to avoid negative values from float errors
                probs = probs.clamp(min=0.0)
                idx = torch.multinomial(probs, 1).item()
            
            centroids.append(vectors[idx].clone())
        
        result = torch.stack(centroids, dim=0)  # (k, d)
        
        # Add tiny noise to break ties if centroids are identical
        if result.unique(dim=0).shape[0] < k:
            noise = torch.randn_like(result) * (result.abs().mean() * 1e-4 + 1e-8)
            result = result + noise
        
        return result
    
    def _compute_distance_matrix(
        self, weights: torch.Tensor, centroids: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute negative squared Euclidean distance matrix D.
        
        D[i,j] = -||w_i - c_j||^2
        
        Per the paper: d_ij = -f(w_i, c_j) where f is the Euclidean distance.
        We use squared Euclidean for efficiency: it yields the same hard
        (argmax) assignments, though the softmax attention it induces is
        sharper than with plain Euclidean distance.
        
        Args:
            weights: (N/d, d) weight sub-vectors
            centroids: (k, d) centroid vectors
        
        Returns:
            D: (N/d, k) negative distance matrix
        """
        # Efficient computation: ||w - c||^2 = ||w||^2 - 2*w.c + ||c||^2
        w_sq = (weights ** 2).sum(dim=1, keepdim=True)     # (N/d, 1)
        c_sq = (centroids ** 2).sum(dim=1, keepdim=True).t()  # (1, k)
        wc = weights @ centroids.t()                         # (N/d, k)
        
        D = -(w_sq - 2 * wc + c_sq)  # (N/d, k) — negative squared Euclidean
        return D
    
    def _compute_attention(self, D: torch.Tensor) -> torch.Tensor:
        """
        Compute attention matrix A from distance matrix D using softmax with temperature τ.
        
        a_ij = exp(d_ij / τ) / Σ_k exp(d_ik / τ)
        
        This is the key differentiable component that enables gradient flow
        through the clustering assignment.
        
        Args:
            D: (N/d, k) negative distance matrix
        
        Returns:
            A: (N/d, k) attention matrix (rows sum to 1)
        """
        # Scale by temperature and apply softmax along cluster dimension
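        # As tau -> 0 each row of A approaches a one-hot vector (hard k-means
        # assignment); larger tau softens assignments and smooths gradients.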
        A = F.softmax(D / self.tau, dim=1)  # (N/d, k)
        return A
    
    def _update_centroids(
        self, A: torch.Tensor, weights: torch.Tensor
    ) -> torch.Tensor:
        """
        Update centroids using attention-weighted average of weights.
        
        c_j_new = Σ_i(a_ij * w_i) / Σ_i(a_ij)
        
        This is the M-step equivalent in the EM interpretation (Appendix G).
        
        Args:
            A: (N/d, k) attention matrix
            weights: (N/d, d) weight sub-vectors
        
        Returns:
            new_centroids: (k, d) updated centroids
        """
        # Numerator: Σ_i(a_ij * w_i) for each centroid j
        # A.t() @ weights: (k, N/d) @ (N/d, d) = (k, d)
        numerator = A.t() @ weights  # (k, d)
        
        # Denominator: Σ_i(a_ij) for each centroid j
        denominator = A.sum(dim=0, keepdim=True).t()  # (k, 1)
        
        # Avoid division by zero
        denominator = denominator.clamp(min=1e-10)
        
        new_centroids = numerator / denominator  # (k, d)
        return new_centroids
    
    def forward(self, weight_override: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass: perform differentiable k-means clustering.
        
        The iterative process (Fig. 2 of the paper):
        1. Compute distance matrix D between weights and centroids
        2. Compute attention matrix A = softmax(D/τ)
        3. Update centroids: C_new = (A^T @ W) / sum(A)
        4. Repeat until convergence or max_iter reached
        5. Return compressed weights: W_tilde = A @ C
        
        During training: returns soft-assigned weights (differentiable)
        During eval: returns hard-assigned weights (nearest centroid)
        
        Args:
            weight_override: Optional weight tensor to use instead of self.weight
        
        Returns:
            compressed_weight: Tensor with same shape as original weight
        """
        weight_tensor = weight_override if weight_override is not None else self.weight
        
        # Reshape weights into sub-vectors for multi-dim clustering
        flat_weights = weight_tensor.reshape(-1)
        if self.dim == 1:
            W = flat_weights.unsqueeze(1)  # (N, 1)
        else:
            W = flat_weights.reshape(self.n_vectors, self.dim)  # (N/d, d)
        
        # Re-initialize centroids on first call (ensures correct device/dtype)
        if not self.initialized:
            self.centroids = self._initialize_centroids(weight_tensor).to(
                weight_tensor.device
            )
            self.initialized = torch.tensor(True, device=weight_tensor.device)
        
        # Current centroids (detached from previous iteration's graph)
        C = self.centroids.clone()
        
        # Iterative DKM clustering (Section 3.2, Fig. 2)
        for iteration in range(self.max_iter):
            # Step 1: Compute distance matrix
            D = self._compute_distance_matrix(W, C)
            
            # Step 2: Compute attention matrix with temperature τ
            A = self._compute_attention(D)
            
            # Step 3: Update centroids
            C_new = self._update_centroids(A, W)
            
            # Step 4: Check convergence |C - C_new| ≤ ε
            delta = (C - C_new).abs().max().item()
            C = C_new
            
            if delta <= self.epsilon:
                break
        
        # Store converged centroids for next batch (warm start)
        self.centroids = C.detach().clone()
        
        if self.training:
            # Training: use soft assignment (differentiable)
            # W_tilde = A @ C (attention-weighted centroids)
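            # Note: A was computed against the pre-update centroids; once the
            # loop converges they differ from C by at most epsilon.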
            W_tilde = A @ C  # (N/d, d)
        else:
            # Inference: snap to nearest centroid (hard assignment)
            # Find nearest centroid for each weight vector
            D_final = self._compute_distance_matrix(W, C)
            assignments = D_final.argmax(dim=1)  # (N/d,) — argmax because D is negative distance
            W_tilde = C[assignments]  # (N/d, d)
        
        # Reshape back to original weight shape
        compressed_weight = W_tilde.reshape(self.original_shape)
        
        return compressed_weight
    
    def get_assignments(self) -> torch.Tensor:
        """
        Get hard cluster assignments for each weight (for inference/analysis).
        
        Returns:
            assignments: (N/d,) tensor of cluster indices
        """
        with torch.no_grad():
            flat_weights = self.weight.detach().reshape(-1)
            if self.dim == 1:
                W = flat_weights.unsqueeze(1)
            else:
                W = flat_weights.reshape(self.n_vectors, self.dim)
            
            D = self._compute_distance_matrix(W, self.centroids)
            assignments = D.argmax(dim=1)
        
        return assignments
    
    def get_codebook(self) -> torch.Tensor:
        """
        Get the current centroid codebook.
        
        Returns:
            centroids: (k, d) centroid values
        """
        return self.centroids.clone()
    
    def extra_repr(self) -> str:
        bits = math.log2(self.n_clusters)
        bpw = bits / self.dim
        return (
            f"n_clusters={self.n_clusters}, tau={self.tau}, dim={self.dim}, "
            f"max_iter={self.max_iter}, eps={self.epsilon}, "
            f"bits={bits:.1f}, bits_per_weight={bpw:.2f}, "
            f"weight_shape={list(self.original_shape)}"
        )
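

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not from the paper). It clusters
# the weights of a small nn.Linear, checks that gradients flow through the
# soft assignment, and verifies that hard assignment matches a codebook
# lookup. The layer size and tau below are arbitrary demo choices.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    linear = nn.Linear(128, 64)
    dkm = DKMLayer(linear.weight, n_clusters=16, tau=1e-4, dim=1, max_iter=5)

    # Training mode: soft, differentiable assignment (W_tilde = A @ C).
    dkm.train()
    w_soft = dkm()
    loss = ((w_soft - linear.weight.detach()) ** 2).mean()
    loss.backward()  # gradients reach linear.weight through the attention matrix
    print(f"soft reconstruction MSE: {loss.item():.3e}")
    print(f"weight grad norm: {linear.weight.grad.norm().item():.3e}")

    # Eval mode: hard assignment; every weight snaps to its nearest centroid.
    dkm.eval()
    with torch.no_grad():
        w_hard = dkm()
        codebook = dkm.get_codebook()        # (k, d) centroid values
        assignments = dkm.get_assignments()  # (N/d,) cluster index per weight
        w_rebuilt = codebook[assignments].reshape(linear.weight.shape)
    print(f"unique values after hard assignment: {w_hard.unique().numel()}")
    print(f"hard output matches codebook lookup: {torch.equal(w_hard, w_rebuilt)}")
    print(dkm)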