# Copyright 2025 Xiaomi Corporation. import typing as tp from einops import rearrange, repeat import torch from torch import nn import torch.nn.functional as F import torch.distributed as dist def rank(): if dist.is_initialized(): return dist.get_rank() else: return 0 def world_size(): if dist.is_initialized(): return dist.get_world_size() else: return 1 def default(val: tp.Any, d: tp.Any) -> tp.Any: return val if val is not None else d def ema_inplace(moving_avg, new, decay: float): if dist.is_initialized(): dist.all_reduce(new, op=dist.ReduceOp.SUM) moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): return (x + epsilon) / (x.sum() + n_categories * epsilon) def uniform_init(*shape: int): t = torch.empty(shape) nn.init.kaiming_uniform_(t) return t def sample_vectors(samples, num: int): num_samples, device = samples.shape[0], samples.device if num_samples >= num: indices = torch.randperm(num_samples, device=device)[:num] else: indices = torch.randint(0, num_samples, (num,), device=device) selected_samples = samples[indices] if dist.is_initialized(): dist.broadcast(selected_samples, src=0) return selected_samples def kmeans(samples, num_clusters: int, num_iters: int = 10): dim, dtype = samples.shape[-1], samples.dtype means = sample_vectors(samples, num_clusters) for _ in range(num_iters): dists = -( samples.pow(2).sum(1, keepdim=True) - 2 * samples @ means.t() + means.t().pow(2).sum(0, keepdim=True) ) buckets = dists.max(dim=-1).indices bins = torch.bincount(buckets, minlength=num_clusters) new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) new_means = new_means.scatter_add_( 0, repeat(buckets, "n -> n d", d=dim), samples ) if dist.is_initialized(): dist.all_reduce(bins, op=dist.ReduceOp.SUM) dist.all_reduce(new_means, op=dist.ReduceOp.SUM) zero_mask = bins == 0 bins_min_clamped = bins.masked_fill(zero_mask, 1) new_means = new_means / bins_min_clamped[..., None] means = torch.where(zero_mask[..., None], means, new_means) return means, bins class EuclideanCodebook(nn.Module): """Codebook with Euclidean distance. Args: dim (int): Dimension. codebook_size (int): Codebook size. kmeans_init (bool): Whether to use k-means to initialize the codebooks. If set to true, run the k-means algorithm on the first training batch and use the learned centroids as initialization. kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. decay (float): Decay for exponential moving average over the codebooks. epsilon (float): Epsilon value for numerical stability. threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes that have an exponential moving average cluster size less than the specified threshold with randomly selected vector from the current batch. """ def __init__( self, dim: int, codebook_size: int, kmeans_init: int = False, kmeans_iters: int = 10, decay: float = 0.99, epsilon: float = 1e-5, threshold_ema_dead_code: int = 2, ): super().__init__() self.decay = decay init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = ( uniform_init if not kmeans_init else torch.zeros ) embed = init_fn(codebook_size, dim) self.codebook_size = codebook_size self.kmeans_iters = kmeans_iters self.epsilon = epsilon self.threshold_ema_dead_code = threshold_ema_dead_code self.register_buffer("inited", torch.Tensor([not kmeans_init])) self.register_buffer("cluster_size", torch.zeros(codebook_size)) self.register_buffer("embed", embed) self.register_buffer("embed_avg", embed.clone()) @torch.jit.ignore def init_embed_(self, data): if self.inited: return embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) self.embed.data.copy_(embed) self.embed_avg.data.copy_(embed.clone()) self.cluster_size.data.copy_(cluster_size) self.inited.data.copy_(torch.Tensor([True])) def replace_(self, samples, mask): # modified_codebook = torch.where( # mask[..., None], sample_vectors(samples, self.codebook_size), self.embed # ) replace_num = mask.sum() modified_codebook = self.embed.clone() modified_codebook[mask] = sample_vectors(samples, replace_num) self.embed.data.copy_(modified_codebook) def expire_codes_(self, batch_samples): if self.threshold_ema_dead_code == 0: return expired_codes = self.cluster_size < self.threshold_ema_dead_code if not torch.any(expired_codes): return batch_samples = rearrange(batch_samples, "... d -> (...) d") self.replace_(batch_samples, mask=expired_codes) def preprocess(self, x): x = rearrange(x, "... d -> (...) d") return x def quantize(self, x): embed = self.embed.t() dist = -( x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True) ) embed_ind = dist.max(dim=-1).indices return embed_ind def postprocess_emb(self, embed_ind, shape): return embed_ind.view(*shape[:-1]) def dequantize(self, embed_ind): quantize = F.embedding(embed_ind, self.embed) return quantize def encode(self, x): shape = x.shape # pre-process x = self.preprocess(x) # quantize embed_ind = self.quantize(x) # post-process embed_ind = self.postprocess_emb(embed_ind, shape) return embed_ind def decode(self, embed_ind): quantize = self.dequantize(embed_ind) return quantize def forward(self, x): shape, dtype = x.shape, x.dtype x = self.preprocess(x) self.init_embed_(x) embed_ind = self.quantize(x) embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) embed_ind = self.postprocess_emb(embed_ind, shape) quantize = self.dequantize(embed_ind) if self.training: # We do the expiry of code at that point as buffers are in sync # and all the workers will take the same decision. self.expire_codes_(x) ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) embed_sum = x.t() @ embed_onehot ema_inplace(self.embed_avg, embed_sum.t().contiguous(), self.decay) cluster_size = ( laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum() ) embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) self.embed.data.copy_(embed_normalized) return quantize, embed_ind class VectorQuantization(nn.Module): """Vector quantization implementation. Currently supports only euclidean distance. Args: dim (int): Dimension codebook_size (int): Codebook size codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. decay (float): Decay for exponential moving average over the codebooks. epsilon (float): Epsilon value for numerical stability. kmeans_init (bool): Whether to use kmeans to initialize the codebooks. kmeans_iters (int): Number of iterations used for kmeans initialization. threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes that have an exponential moving average cluster size less than the specified threshold with randomly selected vector from the current batch. commitment_weight (float): Weight for commitment loss. """ def __init__( self, dim: int, codebook_size: int, codebook_dim: tp.Optional[int] = None, decay: float = 0.99, epsilon: float = 1e-5, kmeans_init: bool = True, kmeans_iters: int = 50, threshold_ema_dead_code: int = 2, commitment_weight: float = 1.0, ): super().__init__() _codebook_dim: int = default(codebook_dim, dim) requires_projection = _codebook_dim != dim self.project_in = ( nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity() ) self.project_out = ( nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity() ) self.epsilon = epsilon self.commitment_weight = commitment_weight self._codebook = EuclideanCodebook( dim=_codebook_dim, codebook_size=codebook_size, kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, decay=decay, epsilon=epsilon, threshold_ema_dead_code=threshold_ema_dead_code, ) self.codebook_size = codebook_size @property def codebook(self): return self._codebook.embed def encode(self, x): # x = rearrange(x, "b d n -> b n d") x = self.project_in(x) embed_in = self._codebook.encode(x) return embed_in def decode(self, embed_ind): quantize = self._codebook.decode(embed_ind) quantize = self.project_out(quantize) # quantize = rearrange(quantize, "b n d -> b d n") return quantize def forward(self, x): device = x.device x = self.project_in(x) quantize, embed_ind = self._codebook(x) if self.training: quantize = x + (quantize - x).detach() loss = torch.tensor([0.0], device=device, requires_grad=self.training) if self.training: if self.commitment_weight > 0: commit_loss = F.mse_loss(quantize.detach(), x) loss = loss + commit_loss * self.commitment_weight quantize = self.project_out(quantize) # quantize = rearrange(quantize, "b n d -> b d n") return quantize, embed_ind, loss class ResidualVectorQuantization(nn.Module): """Residual vector quantization implementation. Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf """ def __init__(self, *, num_quantizers, codebook_size, **kwargs): super().__init__() if isinstance(codebook_size, int): codebook_size = [codebook_size] * num_quantizers elif len(codebook_size) < num_quantizers: codebook_size += [codebook_size[-1]] * (num_quantizers - len(codebook_size)) self.layers = nn.ModuleList( [ VectorQuantization(codebook_size=codebook_size[i], **kwargs) for i in range(num_quantizers) ] ) def forward( self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None ): quantized_out = 0.0 residual = x all_losses = [] all_indices = [] out_quantized = [] n_q = n_q or len(self.layers) for i, layer in enumerate(self.layers[:n_q]): quantized, indices, loss = layer(residual) residual = residual - quantized quantized_out = quantized_out + quantized all_indices.append(indices) all_losses.append(loss) if layers and i in layers: out_quantized.append(quantized_out) out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) return quantized_out, out_indices, out_losses, out_quantized def encode( self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None ) -> torch.Tensor: residual = x all_indices = [] n_q = len(self.layers) if n_q is None else n_q st = 0 if st is None else st for layer in self.layers[st:n_q]: indices = layer.encode(residual) quantized = layer.decode(indices) residual = residual - quantized all_indices.append(indices) out_indices = torch.stack(all_indices) return out_indices def decode(self, q_indices: torch.Tensor, st: int = 0) -> torch.Tensor: quantized_out = torch.tensor(0.0, device=q_indices.device) for i, indices in enumerate(q_indices): layer = self.layers[st + i] quantized = layer.decode(indices) quantized_out = quantized_out + quantized return quantized_out class ResidualVectorQuantizer(nn.Module): """Residual Vector Quantizer. Args: dimension (int): Dimension of the codebooks. n_q (int): Number of residual vector quantizers used. bins (int): Codebook size. decay (float): Decay for exponential moving average over the codebooks. kmeans_init (bool): Whether to use kmeans to initialize the codebooks. kmeans_iters (int): Number of iterations used for kmeans initialization. threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes that have an exponential moving average cluster size less than the specified threshold with randomly selected vector from the current batch. """ def __init__( self, dimension: int = 256, n_q: int = 8, bins: int | list = 1024, decay: float = 0.99, kmeans_init: bool = True, kmeans_iters: int = 50, threshold_ema_dead_code: int = 2, ): super().__init__() self.n_q = n_q self.dimension = dimension self.bins = bins self.decay = decay self.kmeans_init = kmeans_init self.kmeans_iters = kmeans_iters self.threshold_ema_dead_code = threshold_ema_dead_code self.vq = ResidualVectorQuantization( dim=self.dimension, codebook_size=self.bins, num_quantizers=self.n_q, decay=self.decay, kmeans_init=self.kmeans_init, kmeans_iters=self.kmeans_iters, threshold_ema_dead_code=self.threshold_ema_dead_code, ) def forward( self, x: torch.Tensor, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None, ): """Residual vector quantization on the given input tensor. Args: x (torch.Tensor): Input tensor. n_q (int): Number of quantizer used to quantize. Default: All quantizers. layers (list): Layer that need to return quantized. Defalt: None. Returns: QuantizedResult: The quantized (or approximately quantized) representation with the associated numbert quantizers and layer quantized required to return. """ n_q = n_q if n_q else self.n_q quantized, codes, commit_loss, quantized_list = self.vq( x, n_q=n_q, layers=layers ) return quantized, codes, torch.mean(commit_loss), quantized_list def encode( self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None ) -> torch.Tensor: """Encode a given input tensor with the specified sample rate at the given bandwidth. The RVQ encode method sets the appropriate number of quantizer to use and returns indices for each quantizer. Args: x (torch.Tensor): Input tensor. n_q (int): Number of quantizer used to quantize. Default: All quantizers. st (int): Start to encode input from which layers. Default: 0. """ n_q = n_q if n_q else self.n_q st = st or 0 codes = self.vq.encode(x, n_q=n_q, st=st) return codes def decode(self, codes: torch.Tensor, st: int = 0) -> torch.Tensor: """Decode the given codes to the quantized representation. Args: codes (torch.Tensor): Input indices for each quantizer. st (int): Start to decode input codes from which layers. Default: 0. """ quantized = self.vq.decode(codes, st=st) return quantized