| """VQ modules — vector quantization adapters.""" |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm |
| from .components import TernaryVQCodebook |
| from .config import EMBEDDING_DIM, HIDDEN_DIM, CODEBOOK_DIM, SHARED_VQ_SIZE, TIMESTAMP_MAX_PERIOD |
|
|
|
|
class SharedVQ(nn.Module):
    """Single VQ codebook shared by all enabled modalities.

    Every enabled modality (text always; image/audio when enabled) has its
    own ``TernaryScaleTensor(HIDDEN_DIM, codebook_dim)`` projection into the
    shared ``codebook_dim`` space. The projected features of each modality
    are then quantized independently through the same
    ``TernaryVQCodebook``.

    Because one codebook serves every modality, indices from all modalities
    live in the same range ``[0, codebook_size)``. ``codebook_size``
    defaults to ``SHARED_VQ_SIZE``.
    """

    def __init__(self, codebook_size=SHARED_VQ_SIZE, codebook_dim=CODEBOOK_DIM,
                 tscale_type=TScaleType.T32, enable_image=True, enable_audio=True):
        """Build the per-modality projections and the shared codebook.

        Args:
            codebook_size: Number of codebook entries. ``None`` falls back to
                ``SHARED_VQ_SIZE``.
            codebook_dim: Dimension of the shared quantization space.
            tscale_type: Ternary-scale variant forwarded to the projections
                and the codebook.
            enable_image: Whether to create ``image_proj`` and accept
                ``'image'`` inputs.
            enable_audio: Whether to create ``audio_proj`` and accept
                ``'audio'`` inputs.
        """
        super().__init__()
        codebook_size = SHARED_VQ_SIZE if codebook_size is None else codebook_size
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # One projection per modality; forward() looks them up as f'{mod}_proj'.
        self.text_proj = TernaryScaleTensor(HIDDEN_DIM, codebook_dim, tscale_type=tscale_type)
        if enable_image:
            self.image_proj = TernaryScaleTensor(HIDDEN_DIM, codebook_dim, tscale_type=tscale_type)
        if enable_audio:
            self.audio_proj = TernaryScaleTensor(HIDDEN_DIM, codebook_dim, tscale_type=tscale_type)

        self.vq = TernaryVQCodebook(
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            commitment_weight=1.0,
            tscale_type=tscale_type,
        )

        # Processing order in forward(); this also fixes the concat order.
        self.modalities = ['text']
        if enable_image:
            self.modalities.append('image')
        if enable_audio:
            self.modalities.append('audio')

    @staticmethod
    def _sinusoidal_timestamp(seq_len, dim, max_period=TIMESTAMP_MAX_PERIOD, device=None,
                              dtype=None):
        """Return a ``(seq_len, dim)`` sinusoidal position table.

        Even columns hold ``sin(t * f_k)`` and odd columns hold
        ``cos(t * f_k)``, where ``f_k = exp(-2k * log(max_period) / dim)``.
        Odd ``dim`` is supported: the cosine half is trimmed so that it fits
        the ``dim // 2`` odd columns.

        Args:
            seq_len: Number of positions (rows).
            dim: Encoding width (columns).
            max_period: Longest wavelength of the encoding.
            device: Device on which to build the table.
            dtype: Optional dtype to cast the result to. ``None`` keeps
                float32.
        """
        freqs = torch.exp(
            -torch.arange(0, dim, 2, device=device, dtype=torch.float32)
            * (math.log(max_period) / dim)
        )
        t = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
        angles = t * freqs  # (seq_len, ceil(dim / 2))
        pe = torch.zeros(seq_len, dim, device=device)
        pe[:, 0::2] = torch.sin(angles)
        # For odd dim there is one fewer odd column than frequencies.
        pe[:, 1::2] = torch.cos(angles[:, : dim // 2])
        if dtype is not None:
            pe = pe.to(dtype)
        return pe

    def forward(self, modality_inputs, timestep=0):
        """Project, quantize and concatenate every provided modality.

        Args:
            modality_inputs: Mapping from modality name to input tensor.
                Missing or ``None`` entries are skipped. Keys that are not in
                ``self.modalities`` are ignored.
            timestep: When ``> 0``, a sinusoidal position table over the
                concatenated sequence axis is added to the output. The value
                itself is used only as this on/off flag.

        Returns:
            A tuple ``(combined, vq_losses, indices_dict)``:
            - ``combined``: the quantized outputs concatenated along dim 1,
              or ``None`` if no modality was provided.
            - ``vq_losses``: maps ``'<mod>_vq'`` to that modality's VQ loss.
            - ``indices_dict``: maps each processed modality to its codebook
              indices.
        """
        outputs = []
        vq_losses = {}
        indices_dict = {}
        for mod in self.modalities:
            x = modality_inputs.get(mod) if mod in modality_inputs else None
            if x is None:
                continue
            proj = getattr(self, f'{mod}_proj')
            quantized, idx, loss = self.vq(proj(x))
            outputs.append(quantized)
            vq_losses[f'{mod}_vq'] = loss
            indices_dict[mod] = idx

        # Any non-None text input was processed above, so with no outputs
        # there is nothing to return.
        if not outputs:
            return None, vq_losses, indices_dict

        combined = torch.cat(outputs, dim=1)
        if timestep > 0:
            # Indexing shape[1]/shape[2] plus unsqueeze(0) implies combined
            # is (batch, seq, dim). Match its dtype so that half/bfloat16
            # activations are not silently promoted to float32.
            ts_enc = self._sinusoidal_timestamp(
                combined.shape[1], combined.shape[2],
                device=combined.device, dtype=combined.dtype,
            )
            combined = combined + ts_enc.unsqueeze(0)
        return combined, vq_losses, indices_dict

    @property
    def total_codebook_size(self):
        """Number of entries in the single shared codebook."""
        return self.codebook_size

    @torch.no_grad()
    def get_codebook_utilization(self):
        """Return the fraction of codes whose ``vq.cluster_size`` is positive."""
        cluster_size = self.vq.cluster_size
        return (cluster_size > 0).float().mean().item()

    @torch.no_grad()
    def get_dead_code_count(self):
        """Count codes whose ``vq.cluster_size`` is below ``vq.threshold_ema_dead_code``."""
        cluster_size = self.vq.cluster_size
        return (cluster_size < self.vq.threshold_ema_dead_code).sum().item()
|
|