| """ |
| DKM Layer: Differentiable K-Means Clustering Layer |
| |
| Implements the core algorithm from Section 3.2 and 3.3 of the paper: |
| 1. Compute distance matrix D between weights W and centroids C |
| 2. Apply softmax with temperature τ to get attention matrix A |
| 3. Update centroids: c_j = Σ_i(a_ij * w_i) / Σ_i(a_ij) |
| 4. Iterate until convergence or max iterations |
| 5. Compute compressed weights: W_tilde = A @ C |
| |
| Supports both 1D and multi-dimensional clustering (Section 3.3). |
| """ |
|
|
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
|
class DKMLayer(nn.Module):
    """
    Differentiable K-Means Clustering Layer.

    This layer performs differentiable weight clustering by casting k-means
    as an attention problem. During training, soft assignment via attention
    allows gradients to flow through the clustering process. During inference,
    weights are snapped to nearest centroids (hard assignment).

    Args:
        weight_tensor: The weight parameter to cluster (nn.Parameter)
        n_clusters: Number of cluster centroids (k = 2^bits)
        tau: Temperature for softmax attention (controls hardness of assignment)
        dim: Dimension for multi-dimensional clustering (default=1 for scalar)
        max_iter: Maximum number of DKM iterations per forward pass
        epsilon: Convergence threshold for centroid updates
        init_method: Centroid initialization method ('kmeans++' or 'random')
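
    Example (a minimal sketch; the layer sizes, tau, and the input batch ``x``
    are illustrative assumptions, not values from the paper):

        linear = nn.Linear(128, 64)
        dkm = DKMLayer(linear.weight, n_clusters=16, tau=1e-4)
        w_tilde = dkm()                      # soft-assigned weights in training mode
        out = F.linear(x, w_tilde, linear.bias)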
| """ |
| |
    def __init__(
        self,
        weight_tensor: nn.Parameter,
        n_clusters: int = 16,
        tau: float = 2e-5,
        dim: int = 1,
        max_iter: int = 5,
        epsilon: float = 1e-4,
        init_method: str = "kmeans++",
    ):
        super().__init__()

        # Clustering hyper-parameters
        self.n_clusters = n_clusters
        self.tau = tau
        self.dim = dim
        self.max_iter = max_iter
        self.epsilon = epsilon
        self.init_method = init_method

        # Reference to the weight being clustered and its original shape
        self.weight = weight_tensor
        self.original_shape = weight_tensor.shape

        # The weights must split evenly into d-dimensional sub-vectors
        n_elements = weight_tensor.numel()
        if n_elements % dim != 0:
            raise ValueError(
                f"Weight tensor has {n_elements} elements, which is not "
                f"divisible by dim={dim}. Choose a different dim."
            )

        self.n_vectors = n_elements // dim

        # Centroid codebook of shape (k, d), stored as a buffer (not a parameter)
        centroids = self._initialize_centroids(weight_tensor)
        self.register_buffer("centroids", centroids)

        # Centroids are re-initialized on the weight's device during the first forward pass
        self.register_buffer("initialized", torch.tensor(False))

    def _initialize_centroids(self, weight_tensor: torch.Tensor) -> torch.Tensor:
        """
        Initialize centroids using k-means++ or random selection.

        For multi-dimensional clustering, weights are reshaped into
        (N/d, d) sub-vectors before selecting initial centroids.
        """
        with torch.no_grad():
            # Flatten the weights and group them into (N/d, d) sub-vectors
            flat_weights = weight_tensor.detach().reshape(-1)

            if self.dim == 1:
                vectors = flat_weights.unsqueeze(1)
            else:
                vectors = flat_weights.reshape(self.n_vectors, self.dim)

            if self.init_method == "kmeans++":
                centroids = self._kmeans_plus_plus_init(vectors)
            else:
                # Random initialization: sample k sub-vectors without replacement
                indices = torch.randperm(vectors.shape[0])[:self.n_clusters]
                centroids = vectors[indices].clone()

        return centroids

    def _kmeans_plus_plus_init(self, vectors: torch.Tensor) -> torch.Tensor:
        """
        K-means++ initialization for better centroid starting positions.

        Args:
            vectors: (N, d) weight vectors

        Returns:
            centroids: (k, d) initial centroids
        """
        n = vectors.shape[0]
        k = self.n_clusters
        d = vectors.shape[1]

        # Degenerate case: fewer vectors than clusters -> repeat the vectors and
        # add tiny noise so the k centroids are not identical
        if n < k:
            repeats = (k // n) + 1
            vectors_expanded = vectors.repeat(repeats, 1)[:k]
            noise = torch.randn_like(vectors_expanded) * 1e-6
            return (vectors_expanded + noise).clone()

        # First centroid: uniform random pick
        idx = torch.randint(0, n, (1,)).item()
        centroids = [vectors[idx].clone()]

        for _ in range(1, k):
            stacked = torch.stack(centroids, dim=0)

            # Distance from every vector to its nearest already-chosen centroid
            dists = torch.cdist(vectors.unsqueeze(0), stacked.unsqueeze(0)).squeeze(0)
            min_dists = dists.min(dim=1).values

            # Sample the next centroid proportionally to the squared distance
            probs = min_dists ** 2
            prob_sum = probs.sum()

            if prob_sum < 1e-30 or torch.isnan(prob_sum) or torch.isinf(prob_sum):
                # Degenerate distances (all points coincide with a chosen centroid):
                # fall back to a uniform random pick
                idx = torch.randint(0, n, (1,)).item()
            else:
                probs = probs / prob_sum
                probs = probs.clamp(min=0.0)
                idx = torch.multinomial(probs, 1).item()

            centroids.append(vectors[idx].clone())

        result = torch.stack(centroids, dim=0)

        # Perturb duplicates slightly so that all k centroids are distinct
        if result.unique(dim=0).shape[0] < k:
            noise = torch.randn_like(result) * (result.abs().mean() * 1e-4 + 1e-8)
            result = result + noise

        return result

    def _compute_distance_matrix(
        self, weights: torch.Tensor, centroids: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute the negative squared Euclidean distance matrix D.

        D[i,j] = -||w_i - c_j||^2

        Per the paper, d_ij = -f(w_i, c_j) where f is the Euclidean distance.
        Squared Euclidean distance is used here for efficiency; it preserves the
        nearest-centroid (argmax) assignment, although the soft attention values
        differ from those obtained with the unsquared distance.

        Args:
            weights: (N/d, d) weight sub-vectors
            centroids: (k, d) centroid vectors

        Returns:
            D: (N/d, k) negative distance matrix
        """
        w_sq = (weights ** 2).sum(dim=1, keepdim=True)        # (N/d, 1)
        c_sq = (centroids ** 2).sum(dim=1, keepdim=True).t()  # (1, k)
        wc = weights @ centroids.t()                           # (N/d, k)

        D = -(w_sq - 2 * wc + c_sq)
        return D

    def _compute_attention(self, D: torch.Tensor) -> torch.Tensor:
        """
        Compute the attention matrix A from the distance matrix D using a softmax
        with temperature τ.

        a_ij = exp(d_ij / τ) / Σ_k exp(d_ik / τ)

        This is the key differentiable component that enables gradient flow
        through the clustering assignment.

        Args:
            D: (N/d, k) negative distance matrix

        Returns:
            A: (N/d, k) attention matrix (rows sum to 1)
        """
        A = F.softmax(D / self.tau, dim=1)
        return A

    def _update_centroids(
        self, A: torch.Tensor, weights: torch.Tensor
    ) -> torch.Tensor:
        """
        Update centroids using attention-weighted average of weights.

        c_j_new = Σ_i(a_ij * w_i) / Σ_i(a_ij)

        This is the M-step equivalent in the EM interpretation (Appendix G).

        Args:
            A: (N/d, k) attention matrix
            weights: (N/d, d) weight sub-vectors

        Returns:
            new_centroids: (k, d) updated centroids
        """

        # Numerator: attention-weighted sum of the weights, shape (k, d)
        numerator = A.t() @ weights

        # Denominator: total attention mass assigned to each centroid, shape (k, 1)
        denominator = A.sum(dim=0, keepdim=True).t()

        # Guard against division by zero for centroids receiving (almost) no attention
        denominator = denominator.clamp(min=1e-10)

        new_centroids = numerator / denominator
        return new_centroids

    def forward(self, weight_override: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass: perform differentiable k-means clustering.

        The iterative process (Fig. 2 of the paper):
        1. Compute the distance matrix D between weights and centroids
        2. Compute the attention matrix A = softmax(D/τ)
        3. Update centroids: C_new = (A^T @ W) / sum(A)
        4. Repeat until convergence or max_iter is reached
        5. Return compressed weights: W_tilde = A @ C

        During training: returns soft-assigned weights (differentiable)
        During eval: returns hard-assigned weights (nearest centroid)

        Args:
            weight_override: Optional weight tensor to use instead of self.weight

        Returns:
            compressed_weight: Tensor with the same shape as the original weight
        """
        weight_tensor = weight_override if weight_override is not None else self.weight

        # Reshape the weights into (N/d, d) sub-vectors
        flat_weights = weight_tensor.reshape(-1)
        if self.dim == 1:
            W = flat_weights.unsqueeze(1)
        else:
            W = flat_weights.reshape(self.n_vectors, self.dim)

        # Lazily re-initialize the centroids on the weight's device the first time
        if not self.initialized:
            self.centroids = self._initialize_centroids(weight_tensor).to(
                weight_tensor.device
            )
            self.initialized = torch.tensor(True, device=weight_tensor.device)

        # Work on a local copy; the buffer is updated after the iterations
        C = self.centroids.clone()

        # DKM iterations: attention assignment followed by centroid update
        for iteration in range(self.max_iter):
            D = self._compute_distance_matrix(W, C)
            A = self._compute_attention(D)
            C_new = self._update_centroids(A, W)

            # Check convergence on the largest centroid movement
            delta = (C - C_new).abs().max().item()
            C = C_new

            if delta <= self.epsilon:
                break

        # Persist the converged centroids for the next forward pass (warm start)
        self.centroids = C.detach().clone()

        if self.training:
            # Soft assignment: differentiable mixture of centroids
            W_tilde = A @ C
        else:
            # Hard assignment: snap each sub-vector to its nearest centroid
            D_final = self._compute_distance_matrix(W, C)
            assignments = D_final.argmax(dim=1)
            W_tilde = C[assignments]

        # Restore the original weight shape
        compressed_weight = W_tilde.reshape(self.original_shape)

        return compressed_weight

    def get_assignments(self) -> torch.Tensor:
        """
        Get hard cluster assignments for each weight (for inference/analysis).

        Returns:
            assignments: (N/d,) tensor of cluster indices
        """
        with torch.no_grad():
            flat_weights = self.weight.detach().reshape(-1)
            if self.dim == 1:
                W = flat_weights.unsqueeze(1)
            else:
                W = flat_weights.reshape(self.n_vectors, self.dim)

            D = self._compute_distance_matrix(W, self.centroids)
            assignments = D.argmax(dim=1)

        return assignments

    def get_codebook(self) -> torch.Tensor:
        """
        Get the current centroid codebook.

        Returns:
            centroids: (k, d) centroid values
        """
        return self.centroids.clone()

    def extra_repr(self) -> str:
        # k clusters encode log2(k) bits per d-dimensional sub-vector
        bits = math.log2(self.n_clusters)
        bpw = bits / self.dim
        return (
            f"n_clusters={self.n_clusters}, tau={self.tau}, dim={self.dim}, "
            f"max_iter={self.max_iter}, eps={self.epsilon}, "
            f"bits={bits:.1f}, bits_per_weight={bpw:.2f}, "
            f"weight_shape={list(self.original_shape)}"
        )
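

# ---------------------------------------------------------------------------
# Minimal usage sketch (referenced from the module docstring). This is an
# illustrative smoke test, not part of the paper's training recipe: the layer
# sizes, n_clusters, tau, and the random input below are assumptions chosen
# only to exercise the API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # Wrap the weight of a small linear layer in a DKM layer
    linear = nn.Linear(32, 16)
    dkm = DKMLayer(linear.weight, n_clusters=8, tau=1e-4, dim=1, max_iter=5)
    print(dkm)

    x = torch.randn(4, 32)

    # Training mode: soft-assigned (differentiable) compressed weights
    dkm.train()
    w_soft = dkm()
    out = F.linear(x, w_soft, linear.bias)
    out.sum().backward()  # gradients flow back to linear.weight through the clustering
    print("soft weights:", tuple(w_soft.shape), "unique values:", w_soft.detach().unique().numel())

    # Eval mode: each weight snaps to its nearest centroid (at most k unique values)
    dkm.eval()
    with torch.no_grad():
        w_hard = dkm()
    print("hard weights:", tuple(w_hard.shape), "unique values:", w_hard.unique().numel())
    print("codebook:", dkm.get_codebook().flatten())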
|
|