""" DKM Layer: Differentiable K-Means Clustering Layer Implements the core algorithm from Section 3.2 and 3.3 of the paper: 1. Compute distance matrix D between weights W and centroids C 2. Apply softmax with temperature τ to get attention matrix A 3. Update centroids: c_j = Σ_i(a_ij * w_i) / Σ_i(a_ij) 4. Iterate until convergence or max iterations 5. Compute compressed weights: W_tilde = A @ C Supports both 1D and multi-dimensional clustering (Section 3.3). """ import torch import torch.nn as nn import torch.nn.functional as F import math class DKMLayer(nn.Module): """ Differentiable K-Means Clustering Layer. This layer performs differentiable weight clustering by casting k-means as an attention problem. During training, soft assignment via attention allows gradients to flow through the clustering process. During inference, weights are snapped to nearest centroids (hard assignment). Args: weight_tensor: The weight parameter to cluster (nn.Parameter) n_clusters: Number of cluster centroids (k = 2^bits) tau: Temperature for softmax attention (controls hardness of assignment) dim: Dimension for multi-dimensional clustering (default=1 for scalar) max_iter: Maximum number of DKM iterations per forward pass epsilon: Convergence threshold for centroid updates init_method: Centroid initialization method ('kmeans++' or 'random') """ def __init__( self, weight_tensor: nn.Parameter, n_clusters: int = 16, tau: float = 2e-5, dim: int = 1, max_iter: int = 5, epsilon: float = 1e-4, init_method: str = "kmeans++", ): super().__init__() self.n_clusters = n_clusters self.tau = tau self.dim = dim self.max_iter = max_iter self.epsilon = epsilon self.init_method = init_method # Store reference to the weight parameter (not a copy) self.weight = weight_tensor self.original_shape = weight_tensor.shape # Validate dimensions n_elements = weight_tensor.numel() if n_elements % dim != 0: raise ValueError( f"Weight tensor has {n_elements} elements, which is not " f"divisible by dim={dim}. Choose a different dim." ) self.n_vectors = n_elements // dim # Initialize centroids as a buffer (not a parameter - updated by DKM iterations) # Shape: (n_clusters, dim) for multi-dim, (n_clusters, 1) for scalar centroids = self._initialize_centroids(weight_tensor) self.register_buffer("centroids", centroids) # Track whether this is the first forward pass self.register_buffer("initialized", torch.tensor(False)) def _initialize_centroids(self, weight_tensor: torch.Tensor) -> torch.Tensor: """ Initialize centroids using k-means++ or random selection. For multi-dimensional clustering, weights are reshaped into (N/d, d) sub-vectors before selecting initial centroids. """ with torch.no_grad(): # Flatten and reshape for multi-dim clustering flat_weights = weight_tensor.detach().reshape(-1) if self.dim == 1: vectors = flat_weights.unsqueeze(1) # (N, 1) else: vectors = flat_weights.reshape(self.n_vectors, self.dim) # (N/d, d) if self.init_method == "kmeans++": centroids = self._kmeans_plus_plus_init(vectors) else: # Random initialization: select k random weight vectors indices = torch.randperm(vectors.shape[0])[:self.n_clusters] centroids = vectors[indices].clone() return centroids # (n_clusters, dim) def _kmeans_plus_plus_init(self, vectors: torch.Tensor) -> torch.Tensor: """ K-means++ initialization for better centroid starting positions. 

        Args:
            vectors: (N, d) weight vectors
        Returns:
            centroids: (k, d) initial centroids
        """
        n = vectors.shape[0]
        k = self.n_clusters
        d = vectors.shape[1]

        # If fewer vectors than clusters, duplicate with small noise
        if n < k:
            repeats = (k // n) + 1
            vectors_expanded = vectors.repeat(repeats, 1)[:k]
            noise = torch.randn_like(vectors_expanded) * 1e-6
            return (vectors_expanded + noise).clone()

        # Choose first centroid randomly
        idx = torch.randint(0, n, (1,)).item()
        centroids = [vectors[idx].clone()]

        for _ in range(1, k):
            # Compute distances to nearest existing centroid
            stacked = torch.stack(centroids, dim=0)  # (current_k, d)
            # distances: (N, current_k)
            dists = torch.cdist(vectors.unsqueeze(0), stacked.unsqueeze(0)).squeeze(0)
            min_dists = dists.min(dim=1).values  # (N,)

            # Choose next centroid with probability proportional to distance squared
            probs = min_dists ** 2
            prob_sum = probs.sum()
            if prob_sum < 1e-30 or torch.isnan(prob_sum) or torch.isinf(prob_sum):
                # Fallback: all distances are (near) zero (e.g., uniform weights).
                # Pick a random index; identical centroids are perturbed after the loop.
                idx = torch.randint(0, n, (1,)).item()
            else:
                probs = probs / prob_sum
                # Clamp to avoid tiny negative values from float error
                probs = probs.clamp(min=0.0)
                idx = torch.multinomial(probs, 1).item()
            centroids.append(vectors[idx].clone())

        result = torch.stack(centroids, dim=0)  # (k, d)

        # Add tiny noise to break ties if centroids are identical
        if result.unique(dim=0).shape[0] < k:
            noise = torch.randn_like(result) * (result.abs().mean() * 1e-4 + 1e-8)
            result = result + noise

        return result

    def _compute_distance_matrix(
        self, weights: torch.Tensor, centroids: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute negative squared Euclidean distance matrix D.

        D[i,j] = -||w_i - c_j||^2

        Per the paper: d_ij = -f(w_i, c_j) where f is the Euclidean distance.
        We use squared Euclidean for efficiency; the nearest-centroid (argmax)
        assignment is identical, although the soft attention values differ
        slightly from the unsquared form.

        Args:
            weights: (N/d, d) weight sub-vectors
            centroids: (k, d) centroid vectors
        Returns:
            D: (N/d, k) negative distance matrix
        """
        # Efficient computation: ||w - c||^2 = ||w||^2 - 2*w.c + ||c||^2
        w_sq = (weights ** 2).sum(dim=1, keepdim=True)  # (N/d, 1)
        c_sq = (centroids ** 2).sum(dim=1, keepdim=True).t()  # (1, k)
        wc = weights @ centroids.t()  # (N/d, k)
        D = -(w_sq - 2 * wc + c_sq)  # (N/d, k) — negative squared Euclidean
        return D

    def _compute_attention(self, D: torch.Tensor) -> torch.Tensor:
        """
        Compute attention matrix A from distance matrix D using softmax
        with temperature τ.

        a_ij = exp(d_ij / τ) / Σ_j' exp(d_ij' / τ)

        This is the key differentiable component that enables gradient flow
        through the clustering assignment.

        Args:
            D: (N/d, k) negative distance matrix
        Returns:
            A: (N/d, k) attention matrix (rows sum to 1)
        """
        # Scale by temperature and apply softmax along cluster dimension
        A = F.softmax(D / self.tau, dim=1)  # (N/d, k)
        return A

    def _update_centroids(
        self, A: torch.Tensor, weights: torch.Tensor
    ) -> torch.Tensor:
        """
        Update centroids using attention-weighted average of weights.

        c_j_new = Σ_i(a_ij * w_i) / Σ_i(a_ij)

        This is the M-step equivalent in the EM interpretation (Appendix G).

        Args:
            A: (N/d, k) attention matrix
            weights: (N/d, d) weight sub-vectors
        Returns:
            new_centroids: (k, d) updated centroids
        """
        # Numerator: Σ_i(a_ij * w_i) for each centroid j
        # A.t() @ weights: (k, N/d) @ (N/d, d) = (k, d)
        numerator = A.t() @ weights  # (k, d)

        # Denominator: Σ_i(a_ij) for each centroid j
        denominator = A.sum(dim=0, keepdim=True).t()  # (k, 1)

        # Avoid division by zero
        denominator = denominator.clamp(min=1e-10)

        new_centroids = numerator / denominator  # (k, d)
        return new_centroids

    def forward(self, weight_override: torch.Tensor = None) -> torch.Tensor:
        """
        Forward pass: perform differentiable k-means clustering.

        The iterative process (Fig. 2 of the paper):
        1. Compute distance matrix D between weights and centroids
        2. Compute attention matrix A = softmax(D/τ)
        3. Update centroids: C_new = (A^T @ W) / sum(A)
        4. Repeat until convergence or max_iter reached
        5. Return compressed weights: W_tilde = A @ C

        During training: returns soft-assigned weights (differentiable)
        During eval: returns hard-assigned weights (nearest centroid)

        Args:
            weight_override: Optional weight tensor to use instead of self.weight
        Returns:
            compressed_weight: Tensor with same shape as original weight
        """
        weight_tensor = weight_override if weight_override is not None else self.weight

        # Reshape weights into sub-vectors for multi-dim clustering
        flat_weights = weight_tensor.reshape(-1)
        if self.dim == 1:
            W = flat_weights.unsqueeze(1)  # (N, 1)
        else:
            W = flat_weights.reshape(self.n_vectors, self.dim)  # (N/d, d)

        # Re-initialize centroids on first call (ensures correct device/dtype)
        if not self.initialized:
            self.centroids = self._initialize_centroids(weight_tensor).to(
                weight_tensor.device
            )
            self.initialized = torch.tensor(True, device=weight_tensor.device)

        # Current centroids (detached from previous iteration's graph)
        C = self.centroids.clone()

        # Iterative DKM clustering (Section 3.2, Fig. 2)
        for iteration in range(self.max_iter):
            # Step 1: Compute distance matrix
            D = self._compute_distance_matrix(W, C)

            # Step 2: Compute attention matrix with temperature τ
            A = self._compute_attention(D)

            # Step 3: Update centroids
            C_new = self._update_centroids(A, W)

            # Step 4: Check convergence |C - C_new| ≤ ε
            delta = (C - C_new).abs().max().item()
            C = C_new

            if delta <= self.epsilon:
                break

        # Store converged centroids for next batch (warm start)
        self.centroids = C.detach().clone()

        if self.training:
            # Training: use soft assignment (differentiable)
            # W_tilde = A @ C (attention-weighted centroids)
            W_tilde = A @ C  # (N/d, d)
        else:
            # Inference: snap to nearest centroid (hard assignment)
            # Find nearest centroid for each weight vector
            D_final = self._compute_distance_matrix(W, C)
            assignments = D_final.argmax(dim=1)  # (N/d,) — argmax because D is negative distance
            W_tilde = C[assignments]  # (N/d, d)

        # Reshape back to original weight shape
        compressed_weight = W_tilde.reshape(self.original_shape)
        return compressed_weight

    def get_assignments(self) -> torch.Tensor:
        """
        Get hard cluster assignments for each weight (for inference/analysis).

        Returns:
            assignments: (N/d,) tensor of cluster indices
        """
        with torch.no_grad():
            flat_weights = self.weight.detach().reshape(-1)
            if self.dim == 1:
                W = flat_weights.unsqueeze(1)
            else:
                W = flat_weights.reshape(self.n_vectors, self.dim)
            D = self._compute_distance_matrix(W, self.centroids)
            assignments = D.argmax(dim=1)
        return assignments

    def get_codebook(self) -> torch.Tensor:
        """
        Get the current centroid codebook.

        Returns:
            centroids: (k, d) centroid values
        """
        return self.centroids.clone()

    def extra_repr(self) -> str:
        bits = math.log2(self.n_clusters)
        bpw = bits / self.dim
        return (
            f"n_clusters={self.n_clusters}, tau={self.tau}, dim={self.dim}, "
            f"max_iter={self.max_iter}, eps={self.epsilon}, "
            f"bits={bits:.1f}, bits_per_weight={bpw:.2f}, "
            f"weight_shape={list(self.original_shape)}"
        )
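
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the layer API above):
# wraps an nn.Linear weight in a DKMLayer and shows soft (training) vs. hard
# (eval) compression. The layer sizes, tau, and the stand-in loss below are
# assumptions for demonstration; real training loops and hyperparameters
# (n_clusters, tau, dim) are model- and paper-specific.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    linear = nn.Linear(64, 32)
    # 16 clusters on scalar weights (dim=1) corresponds to 4-bit palettization
    dkm = DKMLayer(linear.weight, n_clusters=16, tau=1e-4, dim=1)

    # Training mode: soft assignment; gradients flow back to linear.weight
    dkm.train()
    w_soft = dkm()
    loss = w_soft.pow(2).mean()  # stand-in for a task loss
    loss.backward()
    print("grad flows to original weight:", linear.weight.grad is not None)

    # Eval mode: each weight snaps to its nearest centroid
    dkm.eval()
    with torch.no_grad():
        w_hard = dkm()
    n_unique = w_hard.unique().numel()
    print(f"unique values after hard assignment: {n_unique} (<= {dkm.n_clusters})")
    print(dkm)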