# dkm-compression/dkm/dkm_layer.py
"""
DKM Layer: Differentiable K-Means Clustering Layer
Implements the core algorithm from Sections 3.2 and 3.3 of the DKM paper
(Cho et al., "DKM: Differentiable K-Means Clustering Layer for Model
Compression", ICLR 2022):
1. Compute distance matrix D between weights W and centroids C
2. Apply softmax with temperature τ to get attention matrix A
3. Update centroids: c_j = Σ_i(a_ij * w_i) / Σ_i(a_ij)
4. Iterate until convergence or max iterations
5. Compute compressed weights: W_tilde = A @ C
Supports both 1D and multi-dimensional clustering (Section 3.3).
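For example, n_clusters=256 with dim=2 stores log2(256)/2 = 4 bits per
weight, plus a small 256 x 2 codebook.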
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional
class DKMLayer(nn.Module):
"""
Differentiable K-Means Clustering Layer.
This layer performs differentiable weight clustering by casting k-means
as an attention problem. During training, soft assignment via attention
allows gradients to flow through the clustering process. During inference,
weights are snapped to nearest centroids (hard assignment).
Args:
weight_tensor: The weight parameter to cluster (nn.Parameter)
n_clusters: Number of cluster centroids (k = 2^bits)
tau: Temperature for softmax attention (controls hardness of assignment)
dim: Dimension for multi-dimensional clustering (default=1 for scalar)
max_iter: Maximum number of DKM iterations per forward pass
epsilon: Convergence threshold for centroid updates
init_method: Centroid initialization method ('kmeans++' or 'random')
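    Example (illustrative):
        layer = DKMLayer(nn.Linear(64, 32).weight, n_clusters=16)
        compressed = layer()  # same shape as the wrapped weight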
"""
def __init__(
self,
weight_tensor: nn.Parameter,
n_clusters: int = 16,
tau: float = 2e-5,
dim: int = 1,
max_iter: int = 5,
epsilon: float = 1e-4,
init_method: str = "kmeans++",
):
super().__init__()
self.n_clusters = n_clusters
self.tau = tau
self.dim = dim
self.max_iter = max_iter
self.epsilon = epsilon
self.init_method = init_method
# Store reference to the weight parameter (not a copy)
self.weight = weight_tensor
self.original_shape = weight_tensor.shape
# Validate dimensions
n_elements = weight_tensor.numel()
if n_elements % dim != 0:
raise ValueError(
f"Weight tensor has {n_elements} elements, which is not "
f"divisible by dim={dim}. Choose a different dim."
)
self.n_vectors = n_elements // dim
# Initialize centroids as a buffer (not a parameter - updated by DKM iterations)
# Shape: (n_clusters, dim) for multi-dim, (n_clusters, 1) for scalar
centroids = self._initialize_centroids(weight_tensor)
self.register_buffer("centroids", centroids)
# Track whether this is the first forward pass
self.register_buffer("initialized", torch.tensor(False))
def _initialize_centroids(self, weight_tensor: torch.Tensor) -> torch.Tensor:
"""
Initialize centroids using k-means++ or random selection.
For multi-dimensional clustering, weights are reshaped into
(N/d, d) sub-vectors before selecting initial centroids.
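        For example, a (32, 64) weight with dim=2 becomes 1024 sub-vectors
        of length 2.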
"""
with torch.no_grad():
# Flatten and reshape for multi-dim clustering
flat_weights = weight_tensor.detach().reshape(-1)
if self.dim == 1:
vectors = flat_weights.unsqueeze(1) # (N, 1)
else:
vectors = flat_weights.reshape(self.n_vectors, self.dim) # (N/d, d)
if self.init_method == "kmeans++":
centroids = self._kmeans_plus_plus_init(vectors)
            else:
                # Random initialization: select k random weight vectors,
                # sampling with replacement if there are fewer vectors than k
                if vectors.shape[0] >= self.n_clusters:
                    indices = torch.randperm(vectors.shape[0])[:self.n_clusters]
                else:
                    indices = torch.randint(0, vectors.shape[0], (self.n_clusters,))
                centroids = vectors[indices].clone()
return centroids # (n_clusters, dim)
def _kmeans_plus_plus_init(self, vectors: torch.Tensor) -> torch.Tensor:
"""
        K-means++ initialization (Arthur & Vassilvitskii, 2007) for better
        centroid starting positions.
Args:
vectors: (N, d) weight vectors
Returns:
centroids: (k, d) initial centroids
"""
        n = vectors.shape[0]
        k = self.n_clusters
# If fewer vectors than clusters, duplicate with small noise
if n < k:
repeats = (k // n) + 1
vectors_expanded = vectors.repeat(repeats, 1)[:k]
noise = torch.randn_like(vectors_expanded) * 1e-6
return (vectors_expanded + noise).clone()
# Choose first centroid randomly
idx = torch.randint(0, n, (1,)).item()
centroids = [vectors[idx].clone()]
for _ in range(1, k):
            # Compute distances to the nearest existing centroid
            stacked = torch.stack(centroids, dim=0)  # (current_k, d)
            dists = torch.cdist(vectors, stacked)  # (N, current_k)
min_dists = dists.min(dim=1).values # (N,)
# Choose next centroid with probability proportional to distance squared
probs = min_dists ** 2
prob_sum = probs.sum()
            if prob_sum < 1e-30 or torch.isnan(prob_sum) or torch.isinf(prob_sum):
                # Fallback: all remaining distances are (near) zero, e.g. for
                # uniform weights. Pick a random index; tie-breaking noise is
                # added after the loop if duplicate centroids remain.
                idx = torch.randint(0, n, (1,)).item()
            else:
                # Clamp first to guard against tiny negative values from
                # float error, then normalize to a probability distribution
                probs = probs.clamp(min=0.0)
                probs = probs / prob_sum
                idx = torch.multinomial(probs, 1).item()
centroids.append(vectors[idx].clone())
result = torch.stack(centroids, dim=0) # (k, d)
# Add tiny noise to break ties if centroids are identical
if result.unique(dim=0).shape[0] < k:
noise = torch.randn_like(result) * (result.abs().mean() * 1e-4 + 1e-8)
result = result + noise
return result
def _compute_distance_matrix(
self, weights: torch.Tensor, centroids: torch.Tensor
) -> torch.Tensor:
"""
Compute negative squared Euclidean distance matrix D.
D[i,j] = -||w_i - c_j||^2
        Per the paper, d_ij = -f(w_i, c_j) where f is the Euclidean distance.
        We use squared Euclidean for efficiency: the nearest-centroid ordering
        is unchanged, though the soft assignments differ slightly from those
        given by plain (unsquared) Euclidean distance.
Args:
weights: (N/d, d) weight sub-vectors
centroids: (k, d) centroid vectors
Returns:
D: (N/d, k) negative distance matrix
"""
# Efficient computation: ||w - c||^2 = ||w||^2 - 2*w.c + ||c||^2
w_sq = (weights ** 2).sum(dim=1, keepdim=True) # (N/d, 1)
c_sq = (centroids ** 2).sum(dim=1, keepdim=True).t() # (1, k)
wc = weights @ centroids.t() # (N/d, k)
D = -(w_sq - 2 * wc + c_sq) # (N/d, k) — negative squared Euclidean
return D
def _compute_attention(self, D: torch.Tensor) -> torch.Tensor:
"""
Compute attention matrix A from distance matrix D using softmax with temperature τ.
        a_ij = exp(d_ij / τ) / Σ_j' exp(d_ij' / τ)
This is the key differentiable component that enables gradient flow
through the clustering assignment.
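        As τ → 0 the attention approaches a hard argmax (recovering vanilla
        k-means assignment); larger τ gives softer assignments and stronger
        gradient flow, at the cost of a blurrier compressed weight.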
Args:
D: (N/d, k) negative distance matrix
Returns:
A: (N/d, k) attention matrix (rows sum to 1)
"""
# Scale by temperature and apply softmax along cluster dimension
A = F.softmax(D / self.tau, dim=1) # (N/d, k)
return A
def _update_centroids(
self, A: torch.Tensor, weights: torch.Tensor
) -> torch.Tensor:
"""
Update centroids using attention-weighted average of weights.
c_j_new = Σ_i(a_ij * w_i) / Σ_i(a_ij)
This is the M-step equivalent in the EM interpretation (Appendix G).
Args:
A: (N/d, k) attention matrix
weights: (N/d, d) weight sub-vectors
Returns:
new_centroids: (k, d) updated centroids
"""
# Numerator: Σ_i(a_ij * w_i) for each centroid j
# A.t() @ weights: (k, N/d) @ (N/d, d) = (k, d)
numerator = A.t() @ weights # (k, d)
# Denominator: Σ_i(a_ij) for each centroid j
denominator = A.sum(dim=0, keepdim=True).t() # (k, 1)
# Avoid division by zero
denominator = denominator.clamp(min=1e-10)
new_centroids = numerator / denominator # (k, d)
return new_centroids

    def forward(self, weight_override: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Forward pass: perform differentiable k-means clustering.
The iterative process (Fig. 2 of the paper):
1. Compute distance matrix D between weights and centroids
2. Compute attention matrix A = softmax(D/τ)
3. Update centroids: C_new = (A^T @ W) / sum(A)
4. Repeat until convergence or max_iter reached
5. Return compressed weights: W_tilde = A @ C
During training: returns soft-assigned weights (differentiable)
During eval: returns hard-assigned weights (nearest centroid)
Args:
weight_override: Optional weight tensor to use instead of self.weight
Returns:
compressed_weight: Tensor with same shape as original weight
"""
weight_tensor = weight_override if weight_override is not None else self.weight
# Reshape weights into sub-vectors for multi-dim clustering
flat_weights = weight_tensor.reshape(-1)
if self.dim == 1:
W = flat_weights.unsqueeze(1) # (N, 1)
else:
W = flat_weights.reshape(self.n_vectors, self.dim) # (N/d, d)
# Re-initialize centroids on first call (ensures correct device/dtype)
if not self.initialized:
self.centroids = self._initialize_centroids(weight_tensor).to(
weight_tensor.device
)
self.initialized = torch.tensor(True, device=weight_tensor.device)
        # Start from the stored centroids (a buffer, so it carries no
        # autograd history from previous batches)
        C = self.centroids.clone()
# Iterative DKM clustering (Section 3.2, Fig. 2)
        for _ in range(self.max_iter):
# Step 1: Compute distance matrix
D = self._compute_distance_matrix(W, C)
# Step 2: Compute attention matrix with temperature τ
A = self._compute_attention(D)
# Step 3: Update centroids
C_new = self._update_centroids(A, W)
# Step 4: Check convergence |C - C_new| ≤ ε
delta = (C - C_new).abs().max().item()
C = C_new
if delta <= self.epsilon:
break
# Store converged centroids for next batch (warm start)
self.centroids = C.detach().clone()
if self.training:
# Training: use soft assignment (differentiable)
# W_tilde = A @ C (attention-weighted centroids)
W_tilde = A @ C # (N/d, d)
else:
# Inference: snap to nearest centroid (hard assignment)
# Find nearest centroid for each weight vector
D_final = self._compute_distance_matrix(W, C)
assignments = D_final.argmax(dim=1) # (N/d,) — argmax because D is negative distance
W_tilde = C[assignments] # (N/d, d)
# Reshape back to original weight shape
compressed_weight = W_tilde.reshape(self.original_shape)
return compressed_weight
def get_assignments(self) -> torch.Tensor:
"""
Get hard cluster assignments for each weight (for inference/analysis).
Returns:
assignments: (N/d,) tensor of cluster indices
"""
with torch.no_grad():
flat_weights = self.weight.detach().reshape(-1)
if self.dim == 1:
W = flat_weights.unsqueeze(1)
else:
W = flat_weights.reshape(self.n_vectors, self.dim)
D = self._compute_distance_matrix(W, self.centroids)
assignments = D.argmax(dim=1)
return assignments
def get_codebook(self) -> torch.Tensor:
"""
Get the current centroid codebook.
Returns:
centroids: (k, d) centroid values
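
        Note:
            Together with get_assignments(), the codebook is enough to rebuild
            the compressed weights: codebook[assignments].reshape(original_shape).
            This pair is what one would serialize for deployment.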
"""
return self.centroids.clone()
def extra_repr(self) -> str:
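        # Indexing one of k centroids costs log2(k) bits; a d-dim centroid is
        # shared across d weights, giving log2(k)/d bits per weight
        # (excluding the small k x d codebook).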
bits = math.log2(self.n_clusters)
bpw = bits / self.dim
return (
f"n_clusters={self.n_clusters}, tau={self.tau}, dim={self.dim}, "
f"max_iter={self.max_iter}, eps={self.epsilon}, "
f"bits={bits:.1f}, bits_per_weight={bpw:.2f}, "
f"weight_shape={list(self.original_shape)}"
)
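

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the layer above): cluster
# the weight of a small nn.Linear and compare the soft (training-mode) output
# with the hard (eval-mode) output. Sizes and hyperparameters are arbitrary.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    linear = nn.Linear(64, 32)

    # 16 clusters over scalar weights -> log2(16) = 4 bits per weight
    dkm = DKMLayer(linear.weight, n_clusters=16, tau=1e-4, dim=1)

    dkm.train()
    soft_w = dkm()  # soft assignment, differentiable w.r.t. linear.weight
    dkm.eval()
    hard_w = dkm()  # every entry snapped to one of the 16 centroids

    print("compressed shape:", tuple(hard_w.shape))  # (32, 64)
    print("unique values in hard weights:", hard_w.unique().numel())  # <= 16
    print("max |soft - hard|:", (soft_w - hard_w).abs().max().item())
    print(dkm)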