ARBS / arbitor /vq.py
CLIWorks's picture
Upload folder using huggingface_hub
d8bc908 verified
"""VQ modules — vector quantization adapters."""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm
from .components import TernaryVQCodebook
from .config import EMBEDDING_DIM, HIDDEN_DIM, CODEBOOK_DIM, SHARED_VQ_SIZE, TIMESTAMP_MAX_PERIOD
class SharedVQ(nn.Module):
    """Single VQ codebook shared by every enabled modality.

    Text is always enabled; image and audio are optional. Each enabled
    modality has its own ternary projection from HIDDEN_DIM to
    ``codebook_dim`` (CODEBOOK_DIM by default). Each is then quantized
    independently through the same ``TernaryVQCodebook``.

    Because there is one codebook, code indices from every modality share a
    single global range ``[0, codebook_size)``. ``codebook_size`` defaults to
    SHARED_VQ_SIZE.
    """

    def __init__(self, codebook_size=SHARED_VQ_SIZE, codebook_dim=CODEBOOK_DIM,
                 tscale_type=TScaleType.T32, enable_image=True, enable_audio=True):
        super().__init__()
        # An explicit None also selects the module-level default size.
        codebook_size = SHARED_VQ_SIZE if codebook_size is None else codebook_size
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # Per-modality input projections (HIDDEN_DIM -> codebook_dim).
        self.text_proj = TernaryScaleTensor(HIDDEN_DIM, codebook_dim, tscale_type=tscale_type)
        if enable_image:
            self.image_proj = TernaryScaleTensor(HIDDEN_DIM, codebook_dim, tscale_type=tscale_type)
        if enable_audio:
            self.audio_proj = TernaryScaleTensor(HIDDEN_DIM, codebook_dim, tscale_type=tscale_type)

        # Shared VQ codebook used by all modalities.
        self.vq = TernaryVQCodebook(
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            commitment_weight=1.0,
            tscale_type=tscale_type,
        )

        # Processing order in forward(). Each name must have a matching
        # '<name>_proj' attribute, because forward() looks it up by name.
        self.modalities = ['text']
        if enable_image:
            self.modalities.append('image')
        if enable_audio:
            self.modalities.append('audio')

    @staticmethod
    def _sinusoidal_timestamp(seq_len, dim, max_period=TIMESTAMP_MAX_PERIOD, device=None):
        """Return a (seq_len, dim) float32 sinusoidal position table.

        Even columns hold ``sin(t * f_k)`` and odd columns hold
        ``cos(t * f_k)``, where ``f_k = max_period ** (-2k / dim)`` and ``t`` is
        the row index. An odd ``dim`` is supported: the last (even) column
        then has no cosine partner.
        """
        freqs = torch.exp(
            -torch.arange(0, dim, 2, device=device).float() * (math.log(max_period) / dim)
        )
        t = torch.arange(seq_len, device=device).float().unsqueeze(1)
        pe = torch.zeros(seq_len, dim, device=device)
        pe[:, 0::2] = torch.sin(t * freqs)
        # There are ceil(dim / 2) frequencies but only dim // 2 odd columns.
        # Trim the frequencies so an odd dim does not raise a shape mismatch.
        pe[:, 1::2] = torch.cos(t * freqs[: dim // 2])
        return pe

    def forward(self, modality_inputs, timestep=0):
        """Project and quantize every enabled modality present in the input.

        Args:
            modality_inputs: mapping from modality name ('text', 'image',
                'audio') to a tensor, or to None to skip that modality. Keys
                for modalities not enabled on this module are ignored.
                Presumably each tensor is (batch, seq, HIDDEN_DIM); TODO
                confirm against callers.
            timestep: when > 0, a sinusoidal position table spanning the
                concatenated sequence is added to the output. Only the sign
                matters; the value itself is not used in the encoding.

        Returns:
            A ``(combined, vq_losses, indices_dict)`` tuple:
            - ``combined``: the quantized tensors concatenated along dim 1, in
              ``self.modalities`` order, or None if no modality was provided.
            - ``vq_losses``: maps ``'<modality>_vq'`` to that modality's VQ loss.
            - ``indices_dict``: maps each modality to its code indices.
        """
        outputs = []
        vq_losses = {}
        indices_dict = {}
        for mod in self.modalities:
            if mod not in modality_inputs or modality_inputs[mod] is None:
                continue
            proj = getattr(self, f'{mod}_proj')
            quantized, idx, loss = self.vq(proj(modality_inputs[mod]))
            outputs.append(quantized)
            vq_losses[f'{mod}_vq'] = loss
            indices_dict[mod] = idx

        # 'text' is always in self.modalities, so an empty `outputs` means text
        # was absent or None; there is no unquantized input to fall back to.
        combined = torch.cat(outputs, dim=1) if outputs else None

        if combined is not None and timestep > 0:
            ts_enc = self._sinusoidal_timestamp(
                combined.shape[1], combined.shape[2], device=combined.device
            )
            # Match the activations' dtype so that half/bfloat16 outputs are
            # not silently promoted to float32 by the addition.
            combined = combined + ts_enc.to(combined.dtype).unsqueeze(0)
        return combined, vq_losses, indices_dict

    @property
    def total_codebook_size(self):
        """Number of codebook entries; every modality shares this one range."""
        return self.codebook_size

    @torch.no_grad()
    def get_codebook_utilization(self):
        """Return the fraction of codes whose ``vq.cluster_size`` is > 0."""
        cluster_size = self.vq.cluster_size
        return (cluster_size > 0).float().mean().item()

    @torch.no_grad()
    def get_dead_code_count(self):
        """Count codes whose ``vq.cluster_size`` is below ``vq.threshold_ema_dead_code``."""
        cluster_size = self.vq.cluster_size
        return (cluster_size < self.vq.threshold_ema_dead_code).sum().item()