File size: 3,638 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""VQ modules — vector quantization adapters."""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm
from .components import TernaryVQCodebook
from .config import EMBEDDING_DIM, HIDDEN_DIM, CODEBOOK_DIM, SHARED_VQ_SIZE, TIMESTAMP_MAX_PERIOD


class SharedVQ(nn.Module):
    """Single VQ codebook shared by every enabled modality.

    Each modality's features are first mapped from ``HIDDEN_DIM`` to
    ``codebook_dim`` by that modality's own ternary projection
    (``text_proj`` / ``image_proj`` / ``audio_proj``), then quantized
    independently through the one shared ``TernaryVQCodebook``. Because the
    codebook is shared, code indices from every modality fall in the same
    range ``[0, codebook_size)``.

    Args:
        codebook_size: Number of codebook entries. ``None`` falls back to
            ``SHARED_VQ_SIZE``.
        codebook_dim: Dimension of the shared code space.
        tscale_type: Ternary scale type forwarded to every projection and to
            the codebook.
        enable_image: Build ``image_proj`` and accept ``'image'`` inputs.
        enable_audio: Build ``audio_proj`` and accept ``'audio'`` inputs.
    """
    def __init__(self, codebook_size=SHARED_VQ_SIZE, codebook_dim=CODEBOOK_DIM,
                 tscale_type=TScaleType.T32, enable_image=True, enable_audio=True):
        super().__init__()
        codebook_size = SHARED_VQ_SIZE if codebook_size is None else codebook_size
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        # Per-modality input projections (HIDDEN_DIM -> codebook_dim). Text is
        # always enabled; image/audio are optional. Registration order is kept
        # stable so state_dict keys/ordering do not change.
        self.text_proj = TernaryScaleTensor(HIDDEN_DIM, codebook_dim, tscale_type=tscale_type)
        if enable_image:
            self.image_proj = TernaryScaleTensor(HIDDEN_DIM, codebook_dim, tscale_type=tscale_type)
        if enable_audio:
            self.audio_proj = TernaryScaleTensor(HIDDEN_DIM, codebook_dim, tscale_type=tscale_type)

        # Shared VQ codebook used by every modality.
        self.vq = TernaryVQCodebook(
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            commitment_weight=1.0,
            tscale_type=tscale_type,
        )

        # Processing order in forward(); also the order in which quantized
        # sequences are concatenated.
        self.modalities = ['text']
        if enable_image:
            self.modalities.append('image')
        if enable_audio:
            self.modalities.append('audio')

    @staticmethod
    def _sinusoidal_timestamp(seq_len, dim, max_period=TIMESTAMP_MAX_PERIOD, device=None,
                              dtype=None):
        """Return a ``(seq_len, dim)`` sinusoidal position table.

        Even columns hold ``sin(t * f_k)`` and odd columns ``cos(t * f_k)``
        for position ``t`` in ``[0, seq_len)``, with
        ``f_k = max_period ** (-2k / dim)``.

        Args:
            seq_len: Number of positions (rows).
            dim: Encoding width (columns). Odd values are supported.
            max_period: Longest wavelength scale of the encoding.
            device: Device on which to build the table.
            dtype: If given, the table (computed in the default float dtype)
                is cast to this dtype before returning.
        """
        freqs = torch.exp(
            -torch.arange(0, dim, 2, device=device).float() * (math.log(max_period) / dim)
        )
        positions = torch.arange(seq_len, device=device).float().unsqueeze(1)
        angles = positions * freqs  # (seq_len, ceil(dim / 2))
        pe = torch.zeros(seq_len, dim, device=device)
        pe[:, 0::2] = torch.sin(angles)
        # When dim is odd there is one fewer odd column than even column, so
        # the cos half must be trimmed to dim // 2 frequencies.
        pe[:, 1::2] = torch.cos(angles[:, : dim // 2])
        if dtype is not None:
            pe = pe.to(dtype)
        return pe

    def forward(self, modality_inputs, timestep=0):
        """Project, quantize and concatenate every provided modality.

        Args:
            modality_inputs: Mapping from modality name (``'text'``,
                ``'image'``, ``'audio'``) to a feature tensor. Missing keys,
                ``None`` values and modalities not enabled at construction are
                skipped. Feature tensors are presumably ``(batch, seq,
                HIDDEN_DIM)``; the concat on dim 1 and the use of
                ``shape[2]`` below rely on a 3-D layout (TODO confirm).
            timestep: When ``> 0``, a sinusoidal encoding over the
                concatenated sequence axis is added to the result. Only the
                comparison with 0 is used; the magnitude is ignored.

        Returns:
            ``(combined, vq_losses, indices_dict)``:
            ``combined`` is the quantized outputs concatenated along dim 1,
            or ``None`` if no modality was provided. ``vq_losses`` maps
            ``'<mod>_vq'`` to that modality's VQ loss. ``indices_dict`` maps
            the modality name to its code indices.
        """
        outputs = []
        vq_losses = {}
        indices_dict = {}
        for mod in self.modalities:
            if mod not in modality_inputs or modality_inputs[mod] is None:
                continue
            proj = getattr(self, f'{mod}_proj')
            quantized, idx, loss = self.vq(proj(modality_inputs[mod]))
            outputs.append(quantized)
            vq_losses[f'{mod}_vq'] = loss
            indices_dict[mod] = idx

        # 'text' is always in self.modalities, so an empty `outputs` means no
        # usable text input either; the only possible result is None.
        combined = torch.cat(outputs, dim=1) if outputs else None
        if combined is not None and timestep > 0:
            # Match the activation dtype so half/bf16 inputs are not silently
            # promoted to float32 by the addition.
            ts_enc = self._sinusoidal_timestamp(
                combined.shape[1], combined.shape[2],
                device=combined.device, dtype=combined.dtype,
            )
            combined = combined + ts_enc.unsqueeze(0)
        return combined, vq_losses, indices_dict

    @property
    def total_codebook_size(self):
        """Number of entries in the shared codebook."""
        return self.codebook_size

    @torch.no_grad()
    def get_codebook_utilization(self):
        """Fraction of codebook entries whose ``vq.cluster_size`` is > 0."""
        cluster_size = self.vq.cluster_size
        return (cluster_size > 0).float().mean().item()

    @torch.no_grad()
    def get_dead_code_count(self):
        """Count entries whose ``vq.cluster_size`` is below
        ``vq.threshold_ema_dead_code``."""
        cluster_size = self.vq.cluster_size
        return (cluster_size < self.vq.threshold_ema_dead_code).sum().item()