ARBS / arbitor /kernel /flash_vq.py
CLIWorks's picture
Upload folder using huggingface_hub
d8bc908 verified
"""
FlashVQ: Custom Vector Quantization with dual Triton GPU + PyTorch CPU path.
Replaces vector_quantize_pytorch entirely (D-100). FlashVQCodebook is a standalone
nn.Module implementing all VQ operations:
- Cosine similarity codebook lookup
- EMA codebook update
- Dead code reset
- Rotation trick (gradient through quantization)
- Commitment loss
Dispatch pattern (following tscale.py):
if x.is_cuda and _HAS_TRITON → _TritonFlashVQFn.apply()
else → self._cpu_forward()
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
_HAS_TRITON = False
try:
import triton
import triton.language as tl
_HAS_TRITON = True
except ImportError:
pass
class _RotationTrickFn(torch.autograd.Function):
    """
    Identity-on-``quantized`` op with a custom gradient for the encoder output.

    Forward returns ``quantized`` unchanged. Backward sends to ``x`` the
    upstream gradient ``g`` deflected as

        g - <g, x_hat> * (x_hat - q_hat)

    where ``x_hat`` / ``q_hat`` are the L2-normalised encoder output and
    quantized vector (normalised over the last dimension). ``quantized``
    itself receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, quantized):
        # Only detached copies are kept: backward needs the values, not the graph.
        saved = (x.detach(), quantized.detach())
        ctx.save_for_backward(*saved)
        return quantized

    @staticmethod
    def backward(ctx, grad_output):
        enc, code = ctx.saved_tensors
        # All arithmetic is done in fp32 and cast back at the very end.
        enc_dir = F.normalize(enc.float(), dim=-1)
        code_dir = F.normalize(code.float(), dim=-1)
        offset = enc_dir - code_dir
        upstream = grad_output.float()
        # Component of the upstream gradient along the encoder direction.
        along_enc = (upstream * enc_dir).sum(dim=-1, keepdim=True)
        deflected = upstream - along_enc * offset
        return deflected.to(grad_output.dtype), None
class FlashVQCodebook(nn.Module):
    """
    Vector quantization codebook with dual GPU (Triton) / CPU (PyTorch) paths.

    Interface matches vector_quantize_pytorch.VectorQuantize:
        forward(x) → (quantized, indices, commitment_loss)

    All VQ operations are self-contained:
    - Cosine similarity codebook lookup
    - Straight-through estimator (STE) with optional rotation trick
    - EMA codebook update (decay=0.99)
    - Dead code reset (threshold_ema_dead_code=2)
    - Commitment loss

    The EMA update and the dead code reset mutate the codebook buffers and
    therefore only run while the module is in training mode; in ``eval()``
    mode ``forward`` is a pure lookup.
    """

    def __init__(
        self,
        codebook_size: int = 8192,
        codebook_dim: int = 32,
        decay: float = 0.99,
        commitment_weight: float = 1.0,
        threshold_ema_dead_code: int = 2,
        kmeans_init: bool = True,
        kmeans_iters: int = 10,
        rotation_trick: bool = True,
    ):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.decay = decay
        self.commitment_weight = commitment_weight
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.rotation_trick = rotation_trick

        # Codebook buffers: the code vectors plus the two EMA statistics
        # (per-code assignment count and per-code sum of assigned inputs).
        self.register_buffer('embed', torch.randn(codebook_size, codebook_dim) * 0.02)
        self.register_buffer('cluster_size', torch.zeros(codebook_size))
        self.register_buffer('embed_avg', torch.zeros(codebook_size, codebook_dim))

        # Tile sizes for Triton kernel (refined by _compute_tile_sizes)
        self._triton_block_bt = 16
        self._triton_tile_k = 1024

    def _compute_tile_sizes(self):
        """
        Dynamic tile sizing per D-102.

        Picks the largest TILE_K from a fixed candidate list (with
        BLOCK_BT=16) such that

            BLOCK_BT * codebook_dim * 2 + TILE_K * codebook_dim * 2 < SRAM * 0.9

        and stores the result in ``_triton_block_bt`` / ``_triton_tile_k``.
        The SRAM budget is read from the device properties when the running
        PyTorch exposes it, otherwise 99KB (the sm_89 per-SM figure) is
        assumed. Does nothing when CUDA is unavailable; falls back to
        BLOCK_BT=16, TILE_K=1024 if the device cannot be queried.
        """
        if not torch.cuda.is_available():
            return
        try:
            device = self.embed.device if self.embed.is_cuda else torch.cuda.current_device()
            props = torch.cuda.get_device_properties(device)
        except Exception:
            # Best-effort: keep the default values if the query fails.
            self._triton_block_bt = 16
            self._triton_tile_k = 1024
            return

        # Not every PyTorch release exposes the opt-in shared memory size,
        # hence the getattr with the sm_89 figure as fallback.
        sram_budget = getattr(props, "shared_memory_per_block_optin", None) or 99 * 1024
        # Conservative estimate: each element is 2 bytes (bf16) in SRAM
        elem_bytes = 2
        bt = 16
        for tk in (2048, 1024, 512, 256, 128):
            sram_usage = (bt + tk) * self.codebook_dim * elem_bytes
            if sram_usage < sram_budget * 0.9:
                self._triton_block_bt = bt
                self._triton_tile_k = tk
                return
        # Fallback for very constrained SRAM or large codebook_dim
        self._triton_block_bt = 8
        self._triton_tile_k = 256

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: Input tensor of shape [*, codebook_dim]

        Returns:
            quantized: Tensor of same shape as x
            indices: Tensor of shape [*] with codebook indices
            commitment_loss: Scalar tensor

        Raises:
            ValueError: if the last dimension of ``x`` is not ``codebook_dim``.
        """
        # Without this check a mismatched last dim would be silently folded
        # into the batch dimension by the reshape below.
        if x.dim() == 0 or x.shape[-1] != self.codebook_dim:
            raise ValueError(
                f"Expected input with last dim {self.codebook_dim}, got shape {tuple(x.shape)}"
            )
        orig_shape = x.shape
        x_flat = x.reshape(-1, self.codebook_dim)
        if x.is_cuda and _HAS_TRITON:
            quantized, indices, commitment_loss = self._triton_forward(x_flat)
        else:
            quantized, indices, commitment_loss = self._cpu_forward(x_flat)
        quantized = quantized.reshape(orig_shape)
        indices = indices.reshape(orig_shape[:-1])
        return quantized, indices, commitment_loss

    def _triton_forward(self, x_flat: torch.Tensor):
        """Triton GPU path — dispatched when CUDA + Triton available."""
        # Lookup + quantize + custom backward via the autograd Function.
        quantized, indices, _ = _TritonFlashVQFn.apply(
            x_flat, self.embed, self.cluster_size, self.embed_avg,
            self.codebook_size, self.codebook_dim,
            self.commitment_weight, self.rotation_trick,
        )
        # The loss is computed here, outside the Function, so that it is
        # differentiable w.r.t. x_flat exactly like on the CPU path.
        commitment_loss = self.commitment_weight * F.mse_loss(
            x_flat.float(), quantized.detach().float()
        )
        if self.training:
            self._ema_update(x_flat, indices)
            self._dead_code_reset(x_flat)
        return quantized, indices, commitment_loss

    def _cpu_forward(self, x_flat: torch.Tensor):
        """
        Pure PyTorch path — implements all VQ operations.

        Steps:
        1. Cosine similarity lookup → nearest codebook entry indices
        2. Quantize via straight-through estimator (or rotation trick)
        3. Compute commitment loss
        4. EMA update codebook (training mode only)
        5. Dead code reset (training mode only)
        """
        # ── Steps 1-2a: lookup and gather (no graph needed, argmax is not differentiable) ──
        with torch.no_grad():
            x_norm = F.normalize(x_flat.float(), dim=-1)
            embed_norm = F.normalize(self.embed.float(), dim=-1)
            indices = (x_norm @ embed_norm.T).argmax(dim=-1)  # [N]
            quantized = self.embed[indices]  # [N, D]

        # ── Step 2b: attach the gradient path back to x_flat ──
        if self.rotation_trick:
            quantized = _RotationTrickFn.apply(x_flat, quantized)
        else:
            # Straight-through estimator
            quantized = x_flat + (quantized - x_flat).detach()

        # ── Step 3: Commitment loss ──
        commitment_loss = self.commitment_weight * F.mse_loss(
            x_flat.float(), quantized.detach().float()
        )

        # ── Steps 4-5: codebook maintenance ──
        if self.training:
            self._ema_update(x_flat, indices)
            self._dead_code_reset(x_flat)
        return quantized, indices, commitment_loss

    @torch.no_grad()
    def _ema_update(self, x_flat: torch.Tensor, indices: torch.Tensor):
        """
        Exponential moving average codebook update.

        For every code c:
            cluster_size[c] = decay * cluster_size[c] + (1 - decay) * count(c)
            embed_avg[c]    = decay * embed_avg[c]    + (1 - decay) * sum(x assigned to c)
            embed[c]        = embed_avg[c] / cluster_size[c]

        Both statistics are decayed for all codes, assigned or not, so their
        ratio (the code vector) stays put for codes that received no input.
        Codes whose cluster_size is (numerically) zero keep their current
        vector instead of being overwritten by 0 / eps.

        Args:
            x_flat: [N, D] input vectors
            indices: [N] codebook indices for each input vector
        """
        flat_idx = indices.reshape(-1)
        keep = self.decay
        fresh = 1 - self.decay

        # Per-code assignment counts in one pass.
        counts = torch.bincount(flat_idx, minlength=self.codebook_size)
        self.cluster_size.mul_(keep).add_(counts.to(self.cluster_size.dtype) * fresh)

        # Per-code sum of assigned inputs, accumulated in fp32.
        sums = torch.zeros(
            self.codebook_size, self.codebook_dim,
            dtype=torch.float32, device=x_flat.device,
        )
        sums.index_add_(0, flat_idx, x_flat.detach().float())
        self.embed_avg.mul_(keep).add_(sums * fresh)

        # Normalize: embed = embed_avg / cluster_size, only where there is mass.
        eps = 1e-5
        has_mass = (self.cluster_size > eps).unsqueeze(1)
        normalised = self.embed_avg / self.cluster_size.clamp(min=eps).unsqueeze(1)
        self.embed.copy_(torch.where(has_mass, normalised, self.embed))

    @torch.no_grad()
    def _dead_code_reset(self, x_flat: torch.Tensor):
        """
        Replace dead codebook entries (cluster_size < threshold) with
        random vectors from the current input batch and clear their EMA
        statistics. A no-op when the batch is empty or nothing is dead.
        """
        if x_flat.shape[0] == 0:
            # Nothing to sample replacements from.
            return
        dead_indices = torch.where(self.cluster_size < self.threshold_ema_dead_code)[0]
        n_dead = dead_indices.numel()
        if n_dead == 0:
            return
        rand_idx = torch.randint(0, x_flat.shape[0], (n_dead,), device=x_flat.device)
        # Indexed assignment requires matching dtypes; inputs may be bf16/fp16.
        self.embed[dead_indices] = x_flat[rand_idx].detach().to(self.embed.dtype)
        # NOTE(review): zeroing the statistics means a reseeded code needs
        # threshold / (1 - decay) assignments in a single batch to clear the
        # threshold again. vector_quantize_pytorch instead resets
        # cluster_size to a reset value — confirm which is intended.
        self.cluster_size[dead_indices] = 0.0
        self.embed_avg[dead_indices] = 0.0

    @torch.no_grad()
    def kmeans_init_codebook(self, x: torch.Tensor):
        """
        Initialize codebook via k-means on first batch.

        Runs ``kmeans_iters`` Lloyd iterations (Euclidean distance) starting
        from randomly chosen input vectors, then writes the centroids to
        ``embed`` and seeds ``cluster_size`` / ``embed_avg`` from the last
        assignment so the next EMA update is consistent with the centroids.
        If there are fewer input vectors than codes, the remaining initial
        centroids are drawn with replacement.

        Raises:
            ValueError: if ``x`` contains no vectors.
        """
        x_flat = x.reshape(-1, self.codebook_dim).float()
        n_samples = x_flat.shape[0]
        if n_samples == 0:
            raise ValueError("kmeans_init_codebook requires at least one input vector")
        k = self.codebook_size

        pick = torch.randperm(n_samples)[:k]
        if n_samples < k:
            pick = torch.cat([pick, torch.randint(0, n_samples, (k - n_samples,))])
        centroids = x_flat[pick].clone()

        counts = torch.zeros(k, dtype=x_flat.dtype, device=x_flat.device)
        for _ in range(self.kmeans_iters):
            assign = torch.cdist(x_flat, centroids).argmin(dim=-1)
            counts = torch.bincount(assign, minlength=k).to(x_flat.dtype)
            sums = torch.zeros_like(centroids).index_add_(0, assign, x_flat)
            # Empty clusters keep their previous centroid.
            occupied = counts > 0
            centroids[occupied] = sums[occupied] / counts[occupied].unsqueeze(1)

        self.embed.copy_(centroids)
        self.cluster_size.copy_(counts)
        self.embed_avg.copy_(centroids * counts.unsqueeze(1))

    @torch.no_grad()
    def get_codebook_utilization(self) -> float:
        """Fraction of codebook entries whose EMA cluster size is above zero."""
        return (self.cluster_size > 0).float().mean().item()

    @torch.no_grad()
    def get_dead_code_count(self) -> int:
        """Number of codebook entries below EMA dead threshold."""
        return int((self.cluster_size < self.threshold_ema_dead_code).sum().item())
# ─── Triton GPU Kernels ───
# Only defined when Triton is available
if _HAS_TRITON:
@triton.jit
def _triton_flash_vq_lookup_kernel(
    x_ptr, codebook_ptr, indices_ptr,
    stride_xb, stride_xd,
    stride_cb, stride_cd,
    N_CTX,
    CODEBOOK_SIZE: tl.constexpr,
    CODEBOOK_DIM: tl.constexpr,
    BLOCK_BT: tl.constexpr,
    TILE_K: tl.constexpr,
):
    """
    Tiled cosine similarity + argmax lookup for VQ codebook.

    Architecture:
        pid = batch tile index
        Load input tile [BLOCK_BT, CODEBOOK_DIM]
        Normalize in fp32
        Tile over codebook in TILE_K chunks:
            Load codebook tile [TILE_K, CODEBOOK_DIM]
            Normalize in fp32
            Compute dot product via tl.dot → [BLOCK_BT, TILE_K]
            Mask out-of-range codebook columns to -inf
            Update running argmax
        Store best indices

    N_CTX (number of input rows) is a runtime argument: it is only used in
    load/store masks, so it does not need to be a compile-time constant and
    a new batch length does not force a recompile.
    """
    pid = tl.program_id(0)
    offs_bt = pid * BLOCK_BT + tl.arange(0, BLOCK_BT)
    offs_d = tl.arange(0, CODEBOOK_DIM)
    row_mask = offs_bt < N_CTX

    # ── Load input tile (rows past N_CTX read as zeros) ──
    x_ptrs = x_ptr + offs_bt[:, None] * stride_xb + offs_d[None, :] * stride_xd
    x = tl.load(x_ptrs, mask=row_mask[:, None], other=0.0)

    # ── Normalize input in fp32 ──
    x_f32 = x.to(tl.float32)
    x_sq = tl.sum(x_f32 * x_f32, axis=1)  # [BLOCK_BT]
    x_norm_f32 = x_f32 / tl.sqrt(x_sq[:, None] + 1e-8)

    # ── Running argmax over tiled codebook ──
    best_sim = tl.full([BLOCK_BT], -float('inf'), dtype=tl.float32)
    best_idx = tl.zeros([BLOCK_BT], dtype=tl.int32)

    for k_start in range(0, CODEBOOK_SIZE, TILE_K):
        offs_k = k_start + tl.arange(0, TILE_K)
        k_mask = offs_k < CODEBOOK_SIZE

        # Load codebook tile (entries past CODEBOOK_SIZE read as zeros)
        cb_ptrs = (codebook_ptr
                   + offs_k[:, None] * stride_cb
                   + offs_d[None, :] * stride_cd)
        cb = tl.load(cb_ptrs, mask=k_mask[:, None], other=0.0)

        # Normalize codebook tile in fp32
        cb_f32 = cb.to(tl.float32)
        cb_sq = tl.sum(cb_f32 * cb_f32, axis=1)  # [TILE_K]
        cb_norm_f32 = cb_f32 / tl.sqrt(cb_sq[:, None] + 1e-8)

        # Cosine similarity via tl.dot
        sim = tl.dot(x_norm_f32, tl.trans(cb_norm_f32))  # [BLOCK_BT, TILE_K]

        # Padded columns would otherwise score 0 and win the argmax whenever
        # every real similarity is negative, producing an out-of-range index.
        sim = tl.where(k_mask[None, :], sim, -float('inf'))

        # Argmax within this tile
        tile_max = tl.max(sim, axis=1)
        tile_argmax = tl.argmax(sim, axis=1)
        tile_idx = k_start + tile_argmax

        # Merge with best across tiles; strict '>' keeps the earliest winner
        update_mask = tile_max > best_sim
        best_sim = tl.where(update_mask, tile_max, best_sim)
        best_idx = tl.where(update_mask, tile_idx, best_idx)

    # ── Store results ──
    tl.store(indices_ptr + offs_bt, best_idx, mask=row_mask)
@triton.jit
def _triton_flash_vq_quantize_kernel(
    codebook_ptr, indices_ptr, quantized_ptr,
    stride_cb, stride_cd,
    stride_qb, stride_qd,
    N_CTX: tl.constexpr,
    CODEBOOK_DIM: tl.constexpr,
    BLOCK_BT: tl.constexpr,
):
    """
    Codebook gather: quantized[i, :] = codebook[indices[i], :].

    Each program handles BLOCK_BT consecutive rows; rows at or beyond
    N_CTX are neither gathered nor written.
    """
    tile = tl.program_id(0)
    rows = tile * BLOCK_BT + tl.arange(0, BLOCK_BT)
    cols = tl.arange(0, CODEBOOK_DIM)
    in_range = rows < N_CTX
    in_range_2d = rows[:, None] < N_CTX

    # Codebook row selected for every input row of this tile (0 for padding rows).
    code_row = tl.load(indices_ptr + rows, mask=in_range, other=0)

    # Gather the selected codebook rows via broadcast pointer arithmetic.
    src = codebook_ptr + code_row[:, None] * stride_cb + cols[None, :] * stride_cd
    gathered = tl.load(src, mask=in_range_2d, other=0.0)

    # Write them to the matching rows of the output.
    dst = quantized_ptr + rows[:, None] * stride_qb + cols[None, :] * stride_qd
    tl.store(dst, gathered, mask=in_range_2d)
def _triton_lookup(x, embed, block_bt=None, tile_k=None):
    """
    Launch Triton VQ lookup kernel with SRAM-safe tile sizes.

    Args:
        x: [N, D] input tensor (cuda)
        embed: [codebook_size, D] codebook (cuda)
        block_bt: BLOCK_BT tile size (defaults to 8 if None)
        tile_k: TILE_K tile size (defaults to 128 if None)

    Returns:
        indices: [N] int64 tensor of argmax indices (empty when N == 0)
    """
    N, D = x.shape
    codebook_size = embed.shape[0]
    assert embed.shape[1] == D, f"Codebook dim {embed.shape[1]} != input dim {D}"

    # An empty batch would produce a zero-sized launch grid; short-circuit.
    if N == 0:
        return torch.empty(0, dtype=torch.long, device=x.device)

    # SRAM-safe tile sizes: kernel uses tf32 (fp32 math), and Triton
    # pipelines data through shared memory. Conservative sizing ensures
    # fits within ~99KB (sm_89) even with default num_stages=3.
    #
    #   fp32 codebook tile: TILE_K * D * 4  → 128*32*4 = 16KB
    #   fp32 input tile:    BLOCK_BT * D * 4 → 8*32*4   = 1KB
    #   Accumulator:        BLOCK_BT*TILE_K*4 → 8*128*4 = 4KB
    #   Per stage: ~21KB. With 3 pipeline stages: ~63KB (fits in 99KB).
    #
    # Larger tiles oversubscribe SRAM (tested: TILE_K=1024 → 321KB needed).
    #
    # Each tile size falls back to its default independently, so passing
    # only one of them is honoured rather than silently discarded.
    BLOCK_BT = 8 if block_bt is None else block_bt
    TILE_K = 128 if tile_k is None else tile_k

    grid = (triton.cdiv(N, BLOCK_BT),)
    # The kernel writes int32 indices; they are widened to int64 on return.
    indices = torch.empty(N, dtype=torch.int32, device=x.device)
    _triton_flash_vq_lookup_kernel[grid](
        x, embed, indices,
        x.stride(0), x.stride(1),
        embed.stride(0), embed.stride(1),
        N, codebook_size, D,
        BLOCK_BT=BLOCK_BT, TILE_K=TILE_K,
    )
    return indices.long()
class _TritonFlashVQFn(torch.autograd.Function):
    """
    Custom autograd Function wrapping Triton VQ kernels.

    Forward: Triton tiled cosine similarity + argmax lookup, codebook gather
             and commitment loss.
    Backward: Rotation trick gradient or straight-through estimator for the
              quantized output, plus the gradient of the commitment loss.
              Only ``x_flat`` receives a gradient.
    """

    @staticmethod
    def forward(ctx, x_flat, embed, cluster_size, embed_avg,
                codebook_size, codebook_dim,
                commitment_weight, rotation_trick):
        # cluster_size, embed_avg and codebook_size are accepted for
        # interface stability but are not used here.
        with torch.no_grad():
            indices = _triton_lookup(x_flat.contiguous(), embed.contiguous())
            quantized = embed[indices]
            commitment_loss = commitment_weight * F.mse_loss(
                x_flat.float(), quantized.float()
            )

        # Detached copies, so later in-place changes by the caller cannot
        # invalidate what backward reads. The codebook itself is not needed
        # in backward and is therefore not saved.
        ctx.save_for_backward(
            x_flat.detach().clone(),
            quantized.detach().clone(),
        )
        ctx.codebook_dim = codebook_dim
        ctx.rotation_trick = rotation_trick
        ctx.commitment_weight = commitment_weight
        return quantized, indices, commitment_loss

    @staticmethod
    def backward(ctx, grad_quantized, grad_indices, grad_commitment):
        x_flat, quantized = ctx.saved_tensors
        x32 = x_flat.float()
        q32 = quantized.float()
        upstream = grad_quantized.float()

        if ctx.rotation_trick:
            # Rotation trick gradient: g - <g, x_hat> * (x_hat - q_hat)
            x_norm = F.normalize(x32, dim=-1)
            q_norm = F.normalize(q32, dim=-1)
            proj = (upstream * x_norm).sum(dim=-1, keepdim=True)
            grad_x = upstream - proj * (x_norm - q_norm)
        else:
            # Straight-through estimator
            grad_x = upstream

        # Gradient of commitment_weight * mean((x - q)^2) w.r.t. x. It is
        # zero when the caller does not use the returned loss.
        if grad_commitment is not None:
            scale = grad_commitment.float() * (
                2.0 * ctx.commitment_weight / max(x32.numel(), 1)
            )
            grad_x = grad_x + scale * (x32 - q32)

        return grad_x.to(x_flat.dtype), None, None, None, None, None, None, None
# When Triton is not available, define a fallback lookup
if not _HAS_TRITON:
def _triton_lookup(x, embed, block_bt=None, tile_k=None):
    """
    Fallback: torch-based cosine similarity lookup (CPU or CUDA without Triton).

    Args:
        x: [N, D] input tensor
        embed: [codebook_size, D] codebook
        block_bt: ignored; accepted so the signature matches the Triton version
        tile_k: ignored; accepted so the signature matches the Triton version

    Returns:
        indices: [N] int64 tensor with the index of the most cosine-similar
        codebook row for every input row
    """
    with torch.no_grad():
        x_norm = F.normalize(x.float(), dim=-1)
        embed_norm = F.normalize(embed.float(), dim=-1)
        sim = x_norm @ embed_norm.T  # [N, codebook_size]
        indices = sim.argmax(dim=-1)
    return indices