| """ |
| FlashVQ: Custom Vector Quantization with dual Triton GPU + PyTorch CPU path. |
| |
| Replaces vector_quantize_pytorch entirely (D-100). FlashVQCodebook is a standalone |
| nn.Module implementing all VQ operations: |
| - Cosine similarity codebook lookup |
| - EMA codebook update |
| - Dead code reset |
| - Rotation trick (gradient through quantization) |
| - Commitment loss |
| |
| Dispatch pattern (following tscale.py): |
| if x.is_cuda and _HAS_TRITON → _TritonFlashVQFn.apply() |
| else → self._cpu_forward() |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| _HAS_TRITON = False |
| try: |
| import triton |
| import triton.language as tl |
| _HAS_TRITON = True |
| except ImportError: |
| pass |
|
|
|
|
class _RotationTrickFn(torch.autograd.Function):
    """
    Rotation-trick style gradient estimator for vector quantization.

    Forward is the identity on ``quantized``: the codebook vector is passed
    through unchanged, and detached copies of ``x`` and ``quantized`` are
    saved for the backward pass.

    Backward replaces the straight-through gradient ``g`` with

        g - <g, x_hat> * (x_hat - q_hat)

    where ``x_hat`` / ``q_hat`` are the L2-normalised encoder output and
    codebook vector (computed in float32). The result is cast back to the
    incoming gradient's dtype. ``quantized`` receives no gradient.

    NOTE(review): this is a rank-1 correction of the straight-through
    gradient rather than the full Householder-rotation formulation of the
    "rotation trick" (Fifty et al., 2024); confirm this is intended.
    """

    @staticmethod
    def forward(ctx, x, quantized):
        # Detached copies: backward only needs the values, not their history.
        ctx.save_for_backward(x.detach(), quantized.detach())
        return quantized

    @staticmethod
    def backward(ctx, grad_output):
        saved_x, saved_q = ctx.saved_tensors

        x_hat = F.normalize(saved_x.float(), dim=-1)
        q_hat = F.normalize(saved_q.float(), dim=-1)
        direction = x_hat - q_hat

        g = grad_output.float()
        # Component of the incoming gradient along the encoder direction.
        along_x = (g * x_hat).sum(dim=-1, keepdim=True)
        corrected = g - along_x * direction

        return corrected.to(grad_output.dtype), None
|
|
|
|
class FlashVQCodebook(nn.Module):
    """
    Vector quantization codebook with dual GPU (Triton) / CPU (PyTorch) paths.

    Interface matches vector_quantize_pytorch.VectorQuantize:
        forward(x) -> (quantized, indices, commitment_loss)

    All VQ operations are self-contained:
      - Cosine similarity codebook lookup
      - Straight-through estimator (STE) with optional rotation trick
      - EMA codebook update (default decay=0.99)
      - Dead code reset (default threshold_ema_dead_code=2)
      - Commitment loss

    Codebook state lives in three buffers:
      - ``embed``        [codebook_size, codebook_dim] codebook vectors
      - ``cluster_size`` [codebook_size] EMA of per-code assignment counts
      - ``embed_avg``    [codebook_size, codebook_dim] EMA of per-code input sums
    On every EMA update ``embed`` is recomputed as
    ``embed_avg / cluster_size`` (cluster_size clamped to >= 1e-5).

    The EMA update and dead-code reset only run in training mode
    (``self.training``); in eval mode ``forward`` does not modify any buffer.

    NOTE(review): ``kmeans_init`` is stored but never acted on in this class;
    call ``kmeans_init_codebook`` explicitly to initialise from data.
    """

    def __init__(
        self,
        codebook_size: int = 8192,
        codebook_dim: int = 32,
        decay: float = 0.99,
        commitment_weight: float = 1.0,
        threshold_ema_dead_code: int = 2,
        kmeans_init: bool = True,
        kmeans_iters: int = 10,
        rotation_trick: bool = True,
    ):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.decay = decay
        self.commitment_weight = commitment_weight
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.rotation_trick = rotation_trick

        # Codebook vectors, one row per code (small random init).
        self.register_buffer('embed', torch.randn(codebook_size, codebook_dim) * 0.02)
        # EMA of how many inputs were assigned to each code per step.
        self.register_buffer('cluster_size', torch.zeros(codebook_size))
        # EMA of the sum of inputs assigned to each code per step.
        self.register_buffer('embed_avg', torch.zeros(codebook_size, codebook_dim))

        # Triton tile sizes, refreshed by _compute_tile_sizes().
        # NOTE(review): nothing in this class forwards these to the Triton launch.
        self._triton_block_bt = 16
        self._triton_tile_k = 1024

    def _compute_tile_sizes(self):
        """
        Pick Triton tile sizes (BLOCK_BT, TILE_K) for a fixed SRAM budget (D-102).

        Uses a fixed 99 KB budget (sm_89, e.g. RTX 4060) and BLOCK_BT = 16, and
        picks the largest TILE_K in (2048, 1024, 512, 256, 128) such that

            (BLOCK_BT + TILE_K) * codebook_dim * 2 bytes < 0.9 * budget

        e.g. codebook_dim=32 -> BLOCK_BT=16, TILE_K=1024 (~65 KB).
        TILE_K does not depend on codebook_size. If nothing fits, falls back to
        16 / 128, the smallest tiles Triton's ``tl.dot`` accepts (>= 16 per dim).

        Results are written to ``self._triton_block_bt`` / ``self._triton_tile_k``.
        No-op when CUDA is unavailable.

        NOTE(review): the 2-byte element size assumes fp16/bf16 tiles even though
        the lookup kernel computes in fp32 — confirm the budget model.
        """
        if not torch.cuda.is_available():
            return

        sram_budget = 99 * 1024
        elem_bytes = 2
        bt = 16
        for tk in (2048, 1024, 512, 256, 128):
            if (bt + tk) * self.codebook_dim * elem_bytes < sram_budget * 0.9:
                self._triton_block_bt = bt
                self._triton_tile_k = tk
                return

        # Nothing fit: smallest legal tl.dot tiles (the previous 8 failed to compile).
        self._triton_block_bt = 16
        self._triton_tile_k = 128

    def forward(self, x: torch.Tensor):
        """
        Quantize ``x`` against the codebook.

        Args:
            x: Input tensor of shape [*, codebook_dim]
        Returns:
            quantized: Tensor of same shape as x
            indices: Tensor of shape [*] with codebook indices
            commitment_loss: Scalar tensor
        """
        orig_shape = x.shape
        x_flat = x.reshape(-1, self.codebook_dim)

        if x.is_cuda and _HAS_TRITON:
            quantized, indices, commitment_loss = self._triton_forward(x_flat)
        else:
            quantized, indices, commitment_loss = self._cpu_forward(x_flat)

        return (
            quantized.reshape(orig_shape),
            indices.reshape(orig_shape[:-1]),
            commitment_loss,
        )

    def _update_codebook(self, x_flat: torch.Tensor, indices: torch.Tensor):
        """EMA update followed by dead-code reset; skipped outside training mode."""
        if not self.training:
            return
        with torch.no_grad():
            self._ema_update(x_flat, indices)
            self._dead_code_reset(x_flat)

    def _triton_forward(self, x_flat: torch.Tensor):
        """Triton GPU path — dispatched when CUDA + Triton available."""
        quantized, indices, commitment_loss = _TritonFlashVQFn.apply(
            x_flat, self.embed, self.cluster_size, self.embed_avg,
            self.codebook_size, self.codebook_dim,
            self.commitment_weight, self.rotation_trick,
        )
        self._update_codebook(x_flat, indices)
        return quantized, indices, commitment_loss

    def _cpu_forward(self, x_flat: torch.Tensor):
        """
        Pure PyTorch path (CPU, or CUDA without Triton).

        Steps:
            1. Cosine similarity lookup -> nearest codebook entry indices
            2. Quantize via straight-through estimator (or rotation trick)
            3. Compute commitment loss
            4. EMA update + dead code reset (training mode only, no grad)
        """
        # The lookup only produces integer indices, so no graph is needed.
        with torch.no_grad():
            x_norm = F.normalize(x_flat.float(), dim=-1)
            embed_norm = F.normalize(self.embed.float(), dim=-1)
            indices = (x_norm @ embed_norm.T).argmax(dim=-1)
            # Advanced indexing copies, so later in-place codebook updates
            # do not alter the returned vectors.
            quantized = self.embed[indices]

        if self.rotation_trick:
            quantized = _RotationTrickFn.apply(x_flat, quantized)
        else:
            # Straight-through: forward value is the code, gradient is identity.
            quantized = x_flat + (quantized - x_flat).detach()

        commitment_loss = self.commitment_weight * F.mse_loss(
            x_flat.float(), quantized.detach().float()
        )

        self._update_codebook(x_flat, indices)
        return quantized, indices, commitment_loss

    @torch.no_grad()
    def _ema_update(self, x_flat: torch.Tensor, indices: torch.Tensor):
        """
        Exponential moving average codebook update, applied to every code.

            cluster_size <- decay * cluster_size + (1 - decay) * counts
            embed_avg    <- decay * embed_avg    + (1 - decay) * per-code input sums
            embed        <- embed_avg / clamp(cluster_size, min=1e-5)

        Codes with no assignments decay in both statistics, so their ``embed``
        row stays unchanged (while cluster_size is above the clamp) instead of
        growing by 1/decay each step.

        Args:
            x_flat: [N, D] input vectors
            indices: [N] codebook indices for each input vector
        """
        x_float = x_flat.float()
        counts = torch.bincount(indices, minlength=self.codebook_size).to(self.cluster_size.dtype)
        embed_sum = torch.zeros(
            self.codebook_size, self.codebook_dim,
            dtype=x_float.dtype, device=x_float.device,
        ).index_add_(0, indices, x_float)

        one_minus_decay = 1 - self.decay
        self.cluster_size.mul_(self.decay).add_(counts, alpha=one_minus_decay)
        self.embed_avg.mul_(self.decay).add_(embed_sum, alpha=one_minus_decay)

        cluster_size_safe = self.cluster_size.clamp(min=1e-5)
        self.embed.copy_(self.embed_avg / cluster_size_safe.unsqueeze(1))

    @torch.no_grad()
    def _dead_code_reset(self, x_flat: torch.Tensor):
        """
        Replace dead codebook entries (cluster_size < threshold) with random
        vectors from the current input batch.

        The EMA statistics of a replaced code are re-seeded so that
        ``embed == embed_avg / cluster_size`` still holds: ``cluster_size`` is set
        to the threshold and ``embed_avg`` to ``embed * threshold``. A replaced
        code therefore survives the next step if it gets at least ``threshold``
        assignments. Zeroing them instead would wipe the new vector on the next
        EMA update.

        No-op for an empty batch.
        """
        num_inputs = x_flat.shape[0]
        if num_inputs == 0:
            return

        dead_mask = self.cluster_size < self.threshold_ema_dead_code
        n_dead = int(dead_mask.sum().item())
        if n_dead == 0:
            return
        dead_indices = torch.where(dead_mask)[0]

        rand_idx = torch.randint(0, num_inputs, (n_dead,), device=x_flat.device)
        samples = x_flat[rand_idx].detach().to(self.embed.dtype)
        reset_size = float(self.threshold_ema_dead_code)

        self.embed[dead_indices] = samples
        self.cluster_size[dead_indices] = reset_size
        self.embed_avg[dead_indices] = samples.to(self.embed_avg.dtype) * reset_size

    @torch.no_grad()
    def kmeans_init_codebook(self, x: torch.Tensor):
        """
        Initialize the codebook via (Euclidean) k-means on a batch of inputs.

        Centroids start as ``codebook_size`` random input vectors: sampled
        without replacement when there are enough inputs, with replacement
        otherwise. They are then refined for ``kmeans_iters`` Lloyd iterations;
        empty clusters keep their previous centroid.

        Afterwards the EMA statistics are seeded from a final assignment pass:
        ``cluster_size`` gets the per-centroid counts, and ``embed_avg`` gets
        ``embed * cluster_size``.

        Raises:
            ValueError: if ``x`` contains no vectors.
        """
        x_flat = x.reshape(-1, self.codebook_dim).float()
        num_inputs = x_flat.shape[0]
        if num_inputs == 0:
            raise ValueError("kmeans_init_codebook needs at least one input vector")

        if num_inputs >= self.codebook_size:
            pick = torch.randperm(num_inputs)[:self.codebook_size]
        else:
            # Fewer inputs than codes: sample with replacement so shapes match.
            pick = torch.randint(0, num_inputs, (self.codebook_size,))
        centroids = x_flat[pick.to(x_flat.device)].clone()

        for _ in range(self.kmeans_iters):
            assign = torch.cdist(x_flat, centroids).argmin(dim=-1)
            counts = torch.bincount(assign, minlength=self.codebook_size)
            sums = torch.zeros_like(centroids).index_add_(0, assign, x_flat)
            nonempty = counts > 0
            centroids[nonempty] = sums[nonempty] / counts[nonempty].unsqueeze(1).to(sums.dtype)

        final_assign = torch.cdist(x_flat, centroids).argmin(dim=-1)
        final_counts = torch.bincount(final_assign, minlength=self.codebook_size)

        self.embed.copy_(centroids)
        self.cluster_size.copy_(final_counts.to(self.cluster_size.dtype))
        self.embed_avg.copy_(self.embed * self.cluster_size.unsqueeze(1))

    @torch.no_grad()
    def get_codebook_utilization(self) -> float:
        """Fraction of codebook entries whose EMA cluster size is > 0."""
        return (self.cluster_size > 0).float().mean().item()

    @torch.no_grad()
    def get_dead_code_count(self) -> int:
        """Number of codebook entries below the EMA dead-code threshold."""
        return int((self.cluster_size < self.threshold_ema_dead_code).sum().item())
|
|
|
|
| |
| |
|
|
if _HAS_TRITON:

    @triton.jit
    def _triton_flash_vq_lookup_kernel(
        x_ptr, codebook_ptr, indices_ptr,
        stride_xb, stride_xd,
        stride_cb, stride_cd,
        N_CTX,
        CODEBOOK_SIZE: tl.constexpr,
        CODEBOOK_DIM: tl.constexpr,
        BLOCK_BT: tl.constexpr,
        TILE_K: tl.constexpr,
    ):
        """
        Tiled cosine similarity + argmax lookup for VQ codebook.

        Architecture:
            pid = batch tile index
            Load input tile [BLOCK_BT, CODEBOOK_DIM], normalize in fp32
            Tile over codebook in TILE_K chunks:
                Load codebook tile [TILE_K, CODEBOOK_DIM], normalize in fp32
                Dot product via tl.dot -> [BLOCK_BT, TILE_K]
                Mask padding columns to -inf, update running argmax
            Store best indices (int32)

        N_CTX (number of input rows) is a runtime argument so that a new batch
        size does not trigger a recompile.

        Constraints: ``tl.arange`` needs CODEBOOK_DIM, BLOCK_BT and TILE_K to be
        powers of two, and Triton's ``tl.dot`` rejects block dims below 16 on
        NVIDIA targets. Ties resolve to the lowest index, because a later tile
        must be strictly greater to win.

        NOTE(review): tl.dot on fp32 inputs may use TF32 by default on Ampere+,
        so near-ties can resolve differently from the PyTorch CPU path.
        """
        pid = tl.program_id(0)
        offs_bt = pid * BLOCK_BT + tl.arange(0, BLOCK_BT)
        offs_d = tl.arange(0, CODEBOOK_DIM)
        row_mask = offs_bt < N_CTX

        # Rows past N_CTX load as zeros; their results are never stored.
        x_ptrs = x_ptr + offs_bt[:, None] * stride_xb + offs_d[None, :] * stride_xd
        x = tl.load(x_ptrs, mask=row_mask[:, None], other=0.0)

        x_f32 = x.to(tl.float32)
        x_sq = tl.sum(x_f32 * x_f32, axis=1)
        x_norm_f32 = x_f32 / tl.sqrt(x_sq[:, None] + 1e-8)

        best_sim = tl.full([BLOCK_BT], -float('inf'), dtype=tl.float32)
        best_idx = tl.zeros([BLOCK_BT], dtype=tl.int32)

        for k_start in range(0, CODEBOOK_SIZE, TILE_K):
            offs_k = k_start + tl.arange(0, TILE_K)
            k_mask = offs_k < CODEBOOK_SIZE

            cb_ptrs = (codebook_ptr
                       + offs_k[:, None] * stride_cb
                       + offs_d[None, :] * stride_cd)
            cb = tl.load(cb_ptrs, mask=k_mask[:, None], other=0.0)

            cb_f32 = cb.to(tl.float32)
            cb_sq = tl.sum(cb_f32 * cb_f32, axis=1)
            cb_norm_f32 = cb_f32 / tl.sqrt(cb_sq[:, None] + 1e-8)

            sim = tl.dot(x_norm_f32, tl.trans(cb_norm_f32))
            # Padding columns (past CODEBOOK_SIZE) would otherwise score 0 and
            # beat rows whose real cosine similarities are all negative,
            # producing out-of-range indices.
            sim = tl.where(k_mask[None, :], sim, -float('inf'))

            tile_max = tl.max(sim, axis=1)
            tile_argmax = tl.argmax(sim, axis=1)
            tile_idx = k_start + tile_argmax

            update_mask = tile_max > best_sim
            best_sim = tl.where(update_mask, tile_max, best_sim)
            best_idx = tl.where(update_mask, tile_idx, best_idx)

        tl.store(indices_ptr + offs_bt, best_idx, mask=row_mask)

    @triton.jit
    def _triton_flash_vq_quantize_kernel(
        codebook_ptr, indices_ptr, quantized_ptr,
        stride_cb, stride_cd,
        stride_qb, stride_qd,
        N_CTX,
        CODEBOOK_DIM: tl.constexpr,
        BLOCK_BT: tl.constexpr,
    ):
        """
        Gather quantized vectors from codebook at given indices.
        Kernel form of: quantized[i] = codebook[indices[i]]

        NOTE(review): not launched by any code visible in this module;
        _TritonFlashVQFn gathers with ``embed[indices]`` instead.
        """
        pid = tl.program_id(0)
        offs_bt = pid * BLOCK_BT + tl.arange(0, BLOCK_BT)
        offs_d = tl.arange(0, CODEBOOK_DIM)
        row_mask = offs_bt < N_CTX

        # Out-of-range rows read index 0; their loads/stores are masked anyway.
        idx = tl.load(indices_ptr + offs_bt, mask=row_mask, other=0)

        gather_ptrs = (codebook_ptr
                       + idx[:, None] * stride_cb
                       + offs_d[None, :] * stride_cd)
        quantized = tl.load(gather_ptrs, mask=row_mask[:, None], other=0.0)

        out_ptrs = (quantized_ptr
                    + offs_bt[:, None] * stride_qb
                    + offs_d[None, :] * stride_qd)
        tl.store(out_ptrs, quantized, mask=row_mask[:, None])

    def _triton_lookup(x, embed, block_bt=None, tile_k=None):
        """
        Launch the Triton VQ lookup kernel.

        Args:
            x: [N, D] input tensor (cuda, contiguous)
            embed: [codebook_size, D] codebook (cuda, contiguous)
            block_bt: BLOCK_BT tile size (defaults to 16 if either size is None)
            tile_k: TILE_K tile size (defaults to 128 if either size is None)
        Returns:
            indices: [N] int64 tensor of argmax (max cosine similarity) indices
        Raises:
            ValueError: if ``embed`` and ``x`` disagree on the feature dim.
        """
        N, D = x.shape
        codebook_size = embed.shape[0]
        if embed.shape[1] != D:
            raise ValueError(f"Codebook dim {embed.shape[1]} != input dim {D}")

        if N == 0:
            # Nothing to look up; avoid launching an empty grid.
            return torch.empty(0, dtype=torch.long, device=x.device)

        if block_bt is None or tile_k is None:
            # 16 is the smallest BLOCK_BT tl.dot accepts (the old 8 failed to compile).
            BLOCK_BT, TILE_K = 16, 128
        else:
            BLOCK_BT, TILE_K = block_bt, tile_k

        grid = (triton.cdiv(N, BLOCK_BT),)
        indices = torch.empty(N, dtype=torch.int32, device=x.device)

        _triton_flash_vq_lookup_kernel[grid](
            x, embed, indices,
            x.stride(0), x.stride(1),
            embed.stride(0), embed.stride(1),
            N, codebook_size, D,
            BLOCK_BT=BLOCK_BT, TILE_K=TILE_K,
        )

        return indices.long()

    class _TritonFlashVQFn(torch.autograd.Function):
        """
        Custom autograd Function wrapping the Triton VQ lookup.

        Forward: Triton tiled cosine similarity + argmax lookup, codebook gather,
            and commitment loss ``w * mean((x - q)**2)`` (computed in fp32).
        Backward: rotation-trick or straight-through gradient for ``quantized``,
            plus the commitment-loss gradient ``2 * w * (x - q) / numel``
            scaled by the incoming loss gradient. This matches the PyTorch path,
            where the loss is an ordinary differentiable op on ``x``.
        Only ``x_flat`` receives a gradient.
        """

        @staticmethod
        def forward(ctx, x_flat, embed, cluster_size, embed_avg,
                    codebook_size, codebook_dim,
                    commitment_weight, rotation_trick):
            with torch.no_grad():
                indices = _triton_lookup(x_flat.contiguous(), embed.contiguous())

            # Advanced indexing copies, so later in-place codebook updates
            # do not alias the returned vectors.
            quantized = embed[indices]
            commitment_loss = commitment_weight * F.mse_loss(
                x_flat.float(), quantized.detach().float()
            )

            # Only x and q are needed in backward; the codebook itself is not.
            ctx.save_for_backward(x_flat.detach().clone(), quantized.detach().clone())
            ctx.mark_non_differentiable(indices)
            ctx.codebook_dim = codebook_dim
            ctx.rotation_trick = rotation_trick
            ctx.commitment_weight = commitment_weight

            return quantized, indices, commitment_loss

        @staticmethod
        def backward(ctx, grad_quantized, grad_indices, grad_commitment):
            x_flat, quantized = ctx.saved_tensors
            x_f32 = x_flat.float()
            q_f32 = quantized.float()

            if grad_quantized is None:
                grad_x = torch.zeros_like(x_f32)
            else:
                g = grad_quantized.float()
                if ctx.rotation_trick:
                    x_norm = F.normalize(x_f32, dim=-1)
                    q_norm = F.normalize(q_f32, dim=-1)
                    diff = x_norm - q_norm
                    proj = (g * x_norm).sum(dim=-1, keepdim=True)
                    grad_x = g - proj * diff
                else:
                    # Straight-through estimator.
                    grad_x = g

            if grad_commitment is not None:
                # d/dx [w * mean((x - q)^2)] = 2 * w * (x - q) / numel
                scale = 2.0 * ctx.commitment_weight / max(x_flat.numel(), 1)
                grad_x = grad_x + grad_commitment.float() * scale * (x_f32 - q_f32)

            return grad_x.to(x_flat.dtype), None, None, None, None, None, None, None
|
|
|
|
| |
if not _HAS_TRITON:

    def _triton_lookup(x, embed, block_bt=None, tile_k=None):
        """
        Fallback: torch-based cosine similarity lookup (CPU or CUDA without Triton).

        Accepts the same ``block_bt`` / ``tile_k`` keywords as the Triton
        implementation so callers can use one signature; they are ignored here.

        Args:
            x: [N, D] input tensor
            embed: [codebook_size, D] codebook
        Returns:
            [N] int64 indices of the codebook row with highest cosine similarity.
        """
        with torch.no_grad():
            x_norm = F.normalize(x.float(), dim=-1)
            embed_norm = F.normalize(embed.float(), dim=-1)
            sim = x_norm @ embed_norm.T
            return sim.argmax(dim=-1)
|
|