"""VQ modules — vector quantization adapters.""" import math import torch import torch.nn as nn import torch.nn.functional as F from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm from .components import TernaryVQCodebook from .config import EMBEDDING_DIM, HIDDEN_DIM, CODEBOOK_DIM, SHARED_VQ_SIZE, TIMESTAMP_MAX_PERIOD class SharedVQ(nn.Module): """Single shared VQ codebook for all modalities (10M entries). Each modality projects to the shared CODEBOOK_DIM=64 space, then quantizes independently through the shared codebook. Text uses CODEBOOK_DIM directly. IDs are globally unique: all modalities share the same range [0, 10M). """ def __init__(self, codebook_size=SHARED_VQ_SIZE, codebook_dim=CODEBOOK_DIM, tscale_type=TScaleType.T32, enable_image=True, enable_audio=True): super().__init__() codebook_size = SHARED_VQ_SIZE if codebook_size is None else codebook_size self.codebook_size = codebook_size self.codebook_dim = codebook_dim # Per-modality input projections (their_dim → CODEBOOK_DIM) self.text_proj = TernaryScaleTensor(HIDDEN_DIM, codebook_dim, tscale_type=tscale_type) if enable_image: self.image_proj = TernaryScaleTensor(HIDDEN_DIM, codebook_dim, tscale_type=tscale_type) if enable_audio: self.audio_proj = TernaryScaleTensor(HIDDEN_DIM, codebook_dim, tscale_type=tscale_type) # Shared VQ codebook self.vq = TernaryVQCodebook( codebook_size=codebook_size, codebook_dim=codebook_dim, commitment_weight=1.0, tscale_type=tscale_type, ) self.modalities = ['text'] if enable_image: self.modalities.append('image') if enable_audio: self.modalities.append('audio') @staticmethod def _sinusoidal_timestamp(seq_len, dim, max_period=TIMESTAMP_MAX_PERIOD, device=None): freqs = torch.exp(-torch.arange(0, dim, 2, device=device).float() * (math.log(max_period) / dim)) t = torch.arange(seq_len, device=device).float().unsqueeze(1) pe = torch.zeros(seq_len, dim, device=device) pe[:, 0::2] = torch.sin(t * freqs) pe[:, 1::2] = torch.cos(t * freqs) return pe def forward(self, modality_inputs, timestep=0): outputs = [] vq_losses = {} indices_dict = {} for mod in self.modalities: if mod not in modality_inputs or modality_inputs[mod] is None: continue x = modality_inputs[mod] proj = getattr(self, f'{mod}_proj') x_proj = proj(x) quantized, idx, loss = self.vq(x_proj) outputs.append(quantized) vq_losses[f'{mod}_vq'] = loss indices_dict[mod] = idx combined = torch.cat(outputs, dim=1) if outputs else modality_inputs.get('text', None) if combined is not None and timestep > 0: ts_enc = self._sinusoidal_timestamp(combined.shape[1], combined.shape[2], device=combined.device) combined = combined + ts_enc.unsqueeze(0) return combined, vq_losses, indices_dict @property def total_codebook_size(self): return self.codebook_size @torch.no_grad() def get_codebook_utilization(self): cluster_size = self.vq.cluster_size return (cluster_size > 0).float().mean().item() @torch.no_grad() def get_dead_code_count(self): cluster_size = self.vq.cluster_size return (cluster_size < self.vq.threshold_ema_dead_code).sum().item()