""" FlashVQ: Custom Vector Quantization with dual Triton GPU + PyTorch CPU path. Replaces vector_quantize_pytorch entirely (D-100). FlashVQCodebook is a standalone nn.Module implementing all VQ operations: - Cosine similarity codebook lookup - EMA codebook update - Dead code reset - Rotation trick (gradient through quantization) - Commitment loss Dispatch pattern (following tscale.py): if x.is_cuda and _HAS_TRITON → _TritonFlashVQFn.apply() else → self._cpu_forward() """ import torch import torch.nn as nn import torch.nn.functional as F _HAS_TRITON = False try: import triton import triton.language as tl _HAS_TRITON = True except ImportError: pass class _RotationTrickFn(torch.autograd.Function): """ Rotation trick gradient through vector quantization. Instead of straight-through estimator (STE), rotate the encoder output gradient toward the quantized vector direction. This helps the encoder learn to produce outputs that align with codebook entries. """ @staticmethod def forward(ctx, x, quantized): ctx.save_for_backward(x.detach(), quantized.detach()) return quantized @staticmethod def backward(ctx, grad_output): x, quantized = ctx.saved_tensors # Normalize in fp32 for numerical stability x_norm = F.normalize(x.float(), dim=-1) q_norm = F.normalize(quantized.float(), dim=-1) # Gradient deflection: subtract projection onto (x_norm - q_norm) # This rotates the gradient toward the quantized direction diff = x_norm - q_norm proj = (grad_output.float() * x_norm).sum(dim=-1, keepdim=True) grad_x = grad_output.float() - proj * diff return grad_x.to(grad_output.dtype), None class FlashVQCodebook(nn.Module): """ Vector quantization codebook with dual GPU (Triton) / CPU (PyTorch) paths. Interface matches vector_quantize_pytorch.VectorQuantize: forward(x) → (quantized, indices, commitment_loss) All VQ operations are self-contained: - Cosine similarity codebook lookup - Straight-through estimator (STE) with optional rotation trick - EMA codebook update (decay=0.99) - Dead code reset (threshold_ema_dead_code=2) - Commitment loss """ def __init__( self, codebook_size: int = 8192, codebook_dim: int = 32, decay: float = 0.99, commitment_weight: float = 1.0, threshold_ema_dead_code: int = 2, kmeans_init: bool = True, kmeans_iters: int = 10, rotation_trick: bool = True, ): super().__init__() self.codebook_size = codebook_size self.codebook_dim = codebook_dim self.decay = decay self.commitment_weight = commitment_weight self.threshold_ema_dead_code = threshold_ema_dead_code self.kmeans_init = kmeans_init self.kmeans_iters = kmeans_iters self.rotation_trick = rotation_trick # Codebook buffers self.register_buffer('embed', torch.randn(codebook_size, codebook_dim) * 0.02) self.register_buffer('cluster_size', torch.zeros(codebook_size)) self.register_buffer('embed_avg', torch.zeros(codebook_size, codebook_dim)) # Tile sizes for Triton kernel (set on first GPU forward) self._triton_block_bt = 16 self._triton_tile_k = 1024 def _compute_tile_sizes(self): """ Dynamic tile sizing per D-102. Queries GPU device properties to determine SRAM budget, then computes BLOCK_BT and TILE_K such that: BLOCK_BT * codebook_dim * 2 + TILE_K * codebook_dim * 2 < SRAM * 0.9 For sm_89 (RTX 4060, 99KB SRAM per SM): codebook_size=8192, codebook_dim=32 → BLOCK_BT=16, TILE_K=1024 (65KB) codebook_size=4096, codebook_dim=32 → BLOCK_BT=16, TILE_K=512 (33KB) """ if not torch.cuda.is_available(): return try: props = torch.cuda.get_device_properties(0) sram_budget = 99 * 1024 # SM 8.9: 99KB per SM # Conservative estimate: each element is 2 bytes (bf16) in SRAM elem_bytes = 2 # Find largest TILE_K that fits with BLOCK_BT=16 bt = 16 for tk in [2048, 1024, 512, 256, 128]: sram_usage = bt * self.codebook_dim * elem_bytes + tk * self.codebook_dim * elem_bytes if sram_usage < sram_budget * 0.9: self._triton_block_bt = bt self._triton_tile_k = tk return # Fallback for very constrained SRAM or large codebook_dim self._triton_block_bt = 8 self._triton_tile_k = 256 except Exception: # Default values self._triton_block_bt = 16 self._triton_tile_k = 1024 def forward(self, x: torch.Tensor): """ Args: x: Input tensor of shape [*, codebook_dim] Returns: quantized: Tensor of same shape as x indices: Tensor of shape [*] with codebook indices commitment_loss: Scalar tensor """ orig_shape = x.shape x_flat = x.reshape(-1, self.codebook_dim) if x.is_cuda and _HAS_TRITON: quantized, indices, commitment_loss = self._triton_forward(x_flat) else: quantized, indices, commitment_loss = self._cpu_forward(x_flat) quantized = quantized.reshape(orig_shape) indices = indices.reshape(orig_shape[:-1]) return quantized, indices, commitment_loss def _triton_forward(self, x_flat: torch.Tensor): """Triton GPU path — dispatched when CUDA + Triton available.""" # Use _TritonFlashVQFn for forward + backward via autograd quantized, indices, commitment_loss = _TritonFlashVQFn.apply( x_flat, self.embed, self.cluster_size, self.embed_avg, self.codebook_size, self.codebook_dim, self.commitment_weight, self.rotation_trick, ) # EMA update and dead code reset (under torch.no_grad) with torch.no_grad(): self._ema_update(x_flat, indices) self._dead_code_reset(x_flat) return quantized, indices, commitment_loss def _cpu_forward(self, x_flat: torch.Tensor): """ Pure PyTorch CPU path — implements all VQ operations. Steps: 1. Cosine similarity lookup → nearest codebook entry indices 2. Quantize via straight-through estimator (or rotation trick) 3. Compute commitment loss 4. EMA update codebook (under torch.no_grad) 5. Dead code reset (under torch.no_grad) """ # ── Step 1: Cosine similarity lookup ── x_norm = F.normalize(x_flat.float(), dim=-1) embed_norm = F.normalize(self.embed.float(), dim=-1) sim = x_norm @ embed_norm.T # [N, codebook_size] indices = sim.argmax(dim=-1) # [N] # ── Step 2: Quantize with STE or rotation trick ── with torch.no_grad(): quantized = self.embed[indices] # [N, D] if self.rotation_trick: quantized = _RotationTrickFn.apply(x_flat, quantized) else: # Straight-through estimator quantized = x_flat + (quantized - x_flat).detach() # ── Step 3: Commitment loss ── commitment_loss = self.commitment_weight * F.mse_loss( x_flat.float(), quantized.detach().float() ) # ── Step 4: EMA update ── with torch.no_grad(): self._ema_update(x_flat, indices) # ── Step 5: Dead code reset ── self._dead_code_reset(x_flat) return quantized, indices, commitment_loss def _ema_update(self, x_flat: torch.Tensor, indices: torch.Tensor): """ Exponential moving average codebook update. Args: x_flat: [N, D] input vectors indices: [N] codebook indices for each input vector """ one_hot = F.one_hot(indices, num_classes=self.codebook_size).float() # [N, codebook_size] n_assign = one_hot.sum(dim=0) # [codebook_size] # EMA on cluster_size (how many inputs assigned to each code) self.cluster_size.mul_(self.decay).add_(n_assign * (1 - self.decay)) # EMA on embed_avg: weighted sum of assigned inputs # embed_avg[c] = decay * embed_avg[c] + (1 - decay) * sum(x assigned to c) x_float = x_flat.float() for c in range(self.codebook_size): mask = indices == c count = mask.sum().item() if count > 0: assigned_sum = x_float[mask].sum(dim=0) self.embed_avg[c].mul_(self.decay).add_(assigned_sum * (1 - self.decay)) # Normalize: embed = embed_avg / cluster_size (with epsilon) cluster_size_safe = self.cluster_size.clamp(min=1e-5) self.embed.copy_(self.embed_avg / cluster_size_safe.unsqueeze(1)) def _dead_code_reset(self, x_flat: torch.Tensor): """ Replace dead codebook entries (cluster_size < threshold) with random vectors from the current input batch. """ dead_mask = self.cluster_size < self.threshold_ema_dead_code n_dead = dead_mask.sum().item() if n_dead == 0: return dead_indices = torch.where(dead_mask)[0] # Replace with random input vectors rand_idx = torch.randint(0, x_flat.shape[0], (n_dead,), device=x_flat.device) self.embed[dead_indices] = x_flat[rand_idx].detach() self.cluster_size[dead_indices] = 0.0 self.embed_avg[dead_indices] = 0.0 @torch.no_grad() def kmeans_init_codebook(self, x: torch.Tensor): """Initialize codebook via k-means on first batch.""" x_flat = x.reshape(-1, self.codebook_dim).float() centroids = x_flat[torch.randperm(x_flat.shape[0])[:self.codebook_size]].clone() for _ in range(self.kmeans_iters): dist = torch.cdist(x_flat, centroids) assign = dist.argmin(dim=-1) for i in range(self.codebook_size): mask = assign == i if mask.sum() > 0: centroids[i] = x_flat[mask].mean(dim=0) self.embed.copy_(centroids) @torch.no_grad() def get_codebook_utilization(self) -> float: """Fraction of codebook entries with any usage.""" return (self.cluster_size > 0).float().mean().item() @torch.no_grad() def get_dead_code_count(self) -> int: """Number of codebook entries below EMA dead threshold.""" return (self.cluster_size < self.threshold_ema_dead_code).sum().item() # ─── Triton GPU Kernels ─── # Only defined when Triton is available if _HAS_TRITON: @triton.jit def _triton_flash_vq_lookup_kernel( x_ptr, codebook_ptr, indices_ptr, stride_xb, stride_xd, stride_cb, stride_cd, N_CTX: tl.constexpr, CODEBOOK_SIZE: tl.constexpr, CODEBOOK_DIM: tl.constexpr, BLOCK_BT: tl.constexpr, TILE_K: tl.constexpr, ): """ Tiled cosine similarity + argmax lookup for VQ codebook. Architecture: pid = batch tile index Load input tile [BLOCK_BT, CODEBOOK_DIM] Normalize in fp32 Tile over codebook in TILE_K chunks: Load codebook tile [TILE_K, CODEBOOK_DIM] Normalize in fp32 Compute dot product via tl.dot → [BLOCK_BT, TILE_K] Update running argmax Store best indices SRAM: all arithmetic in fp32 with small tiles to fit 99KB budget. """ pid = tl.program_id(0) offs_bt = pid * BLOCK_BT + tl.arange(0, BLOCK_BT) offs_d = tl.arange(0, CODEBOOK_DIM) # ── Load input tile ── x_ptrs = x_ptr + offs_bt[:, None] * stride_xb + offs_d[None, :] * stride_xd x = tl.load(x_ptrs, mask=offs_bt[:, None] < N_CTX, other=0.0) # ── Normalize input in fp32 (no keepdims in Triton tl.sum) ── x_f32 = x.to(tl.float32) x_sq = tl.sum(x_f32 * x_f32, axis=1) # [BLOCK_BT] x_norm_f32 = x_f32 / tl.sqrt(x_sq[:, None] + 1e-8) # ── Running argmax over tiled codebook ── best_sim = tl.full([BLOCK_BT], -float('inf'), dtype=tl.float32) best_idx = tl.zeros([BLOCK_BT], dtype=tl.int32) for k_start in range(0, CODEBOOK_SIZE, TILE_K): offs_k = k_start + tl.arange(0, TILE_K) k_mask = offs_k < CODEBOOK_SIZE # Load codebook tile into fp32 directly for normalization cb_ptrs = (codebook_ptr + offs_k[:, None] * stride_cb + offs_d[None, :] * stride_cd) cb = tl.load(cb_ptrs, mask=k_mask[:, None], other=0.0) # Normalize codebook tile in fp32 cb_f32 = cb.to(tl.float32) cb_sq = tl.sum(cb_f32 * cb_f32, axis=1) # [TILE_K] cb_norm_f32 = cb_f32 / tl.sqrt(cb_sq[:, None] + 1e-8) # Cosine similarity via tl.dot (tf32 on sm_89) sim = tl.dot(x_norm_f32, tl.trans(cb_norm_f32)) # [BLOCK_BT, TILE_K] # Running argmax within this tile tile_max = tl.max(sim, axis=1) tile_argmax = tl.argmax(sim, axis=1) tile_idx = k_start + tile_argmax # Merge with best across tiles using element-wise mask update_mask = tile_max > best_sim best_sim = tl.where(update_mask, tile_max, best_sim) best_idx = tl.where(update_mask, tile_idx, best_idx) # ── Store results ── tl.store(indices_ptr + offs_bt, best_idx, mask=offs_bt < N_CTX) @triton.jit def _triton_flash_vq_quantize_kernel( codebook_ptr, indices_ptr, quantized_ptr, stride_cb, stride_cd, stride_qb, stride_qd, N_CTX: tl.constexpr, CODEBOOK_DIM: tl.constexpr, BLOCK_BT: tl.constexpr, ): """ Gather quantized vectors from codebook at given indices. Kernel form of: quantized[i] = codebook[indices[i]] """ pid = tl.program_id(0) offs_bt = pid * BLOCK_BT + tl.arange(0, BLOCK_BT) offs_d = tl.arange(0, CODEBOOK_DIM) # Load indices for this batch tile idx = tl.load(indices_ptr + offs_bt, mask=offs_bt < N_CTX, other=0) # Gather: for each i in BLOCK_BT, load codebook[idx[i], :] # Pointer arithmetic with broadcasting gather_ptrs = (codebook_ptr + idx[:, None] * stride_cb + offs_d[None, :] * stride_cd) quantized = tl.load(gather_ptrs, mask=offs_bt[:, None] < N_CTX, other=0.0) # Store quantized output out_ptrs = (quantized_ptr + offs_bt[:, None] * stride_qb + offs_d[None, :] * stride_qd) tl.store(out_ptrs, quantized, mask=offs_bt[:, None] < N_CTX) def _triton_lookup(x, embed, block_bt=None, tile_k=None): """ Launch Triton VQ lookup kernel with SRAM-safe tile sizes. Args: x: [N, D] input tensor (cuda, contiguous) embed: [codebook_size, D] codebook (cuda, contiguous) block_bt: BLOCK_BT tile size (auto-computed if None) tile_k: TILE_K tile size (auto-computed if None) Returns: indices: [N] int64 tensor of argmax indices """ N, D = x.shape codebook_size = embed.shape[0] assert embed.shape[1] == D, f"Codebook dim {embed.shape[1]} != input dim {D}" # SRAM-safe tile sizes: kernel uses tf32 (fp32 math), and Triton # pipelines data through shared memory. Conservative sizing ensures # fits within ~99KB (sm_89) even with default num_stages=3. # # fp32 codebook tile: TILE_K * D * 4 → 128*32*4 = 16KB # fp32 input tile: BLOCK_BT * D * 4 → 8*32*4 = 1KB # Accumulator: BLOCK_BT*TILE_K*4 → 8*128*4 = 4KB # Per stage: ~21KB. With 3 pipeline stages: ~63KB (fits in 99KB). # # Larger tiles oversubscribe SRAM (tested: TILE_K=1024 → 321KB needed). if block_bt is None or tile_k is None: BLOCK_BT = 8 TILE_K = 128 else: BLOCK_BT, TILE_K = block_bt, tile_k grid = (triton.cdiv(N, BLOCK_BT),) indices = torch.empty(N, dtype=torch.int32, device=x.device) _triton_flash_vq_lookup_kernel[grid]( x, embed, indices, x.stride(0), x.stride(1), embed.stride(0), embed.stride(1), N, codebook_size, D, BLOCK_BT=BLOCK_BT, TILE_K=TILE_K, ) return indices.long() class _TritonFlashVQFn(torch.autograd.Function): """ Custom autograd Function wrapping Triton VQ kernels. Forward: Triton tiled cosine similarity + argmax lookup Backward: Rotation trick gradient or straight-through estimator """ @staticmethod def forward(ctx, x_flat, embed, cluster_size, embed_avg, codebook_size, codebook_dim, commitment_weight, rotation_trick): # Triton tiled lookup for indices with torch.no_grad(): indices = _triton_lookup(x_flat.contiguous(), embed.contiguous()) quantized = embed[indices] commitment_loss = commitment_weight * F.mse_loss(x_flat.float(), quantized.detach().float()) # Clone saved tensors to avoid version conflicts with in-place EMA updates ctx.save_for_backward( x_flat.detach().clone(), quantized.detach().clone(), embed.detach().clone(), ) ctx.codebook_dim = codebook_dim ctx.rotation_trick = rotation_trick return quantized, indices, commitment_loss @staticmethod def backward(ctx, grad_quantized, grad_indices, grad_commitment): x_flat, quantized, embed = ctx.saved_tensors if ctx.rotation_trick: # Rotation trick gradient x_norm = F.normalize(x_flat.float(), dim=-1) q_norm = F.normalize(quantized.float(), dim=-1) diff = x_norm - q_norm proj = (grad_quantized.float() * x_norm).sum(dim=-1, keepdim=True) grad_x = grad_quantized.float() - proj * diff else: # Straight-through estimator grad_x = grad_quantized.float() return grad_x.to(grad_quantized.dtype), None, None, None, None, None, None, None # When Triton is not available, define a fallback lookup if not _HAS_TRITON: def _triton_lookup(x, embed): """Fallback: torch-based cosine similarity lookup (CPU or CUDA without Triton).""" with torch.no_grad(): x_norm = F.normalize(x.float(), dim=-1) embed_norm = F.normalize(embed.float(), dim=-1) sim = x_norm @ embed_norm.T indices = sim.argmax(dim=-1) return indices