File size: 9,179 Bytes
d8bc908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
"""Sequencer modules — input processing for all modalities."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm, GROUP_SIZES, _HAS_TRITON, _HAS_TILELANG
if _HAS_TRITON:
    import triton
    import triton.language as tl
else:
    triton = None
    tl = None
try:
    from .kernel.ternary_scale import _TritonTernaryEmbedFn
except ImportError:
    _TritonTernaryEmbedFn = None
from .converters.convert_to_ternary8 import pack_ternary, unpack_ternary
from math import ceil as _ceil

def _ceil_div(a, b):
    """Return ``ceil(a / b)``, or 0 when ``b`` is not positive.

    Integer operands are handled with exact floor-division arithmetic, so the
    result is not affected by the float rounding (or overflow) of ``a / b``
    for large values. Any other numeric operands keep the ``math.ceil(a / b)``
    behaviour.
    """
    if not b > 0:
        return 0
    if isinstance(a, int) and isinstance(b, int):
        # -(-a // b) is the exact integer ceiling for a positive divisor.
        return -(-a // b)
    return _ceil(a / b)
from .config import VOCAB, EMBEDDING_DIM, HIDDEN_DIM, AUDIO_SR, AUDIO_FRAME_RATE


class ByteEmbedding(nn.Module):
    """Byte-level embedding via packed ternary + BigInt correlation.

    All training state is integer. T_accum/E_accum replaced by
    corr_accum (int64 per group, never clips or resets).

    S = 2^(E + K × mean_corr)  where mean_corr = corr_accum / (step × gs)

    The lookup table used in ``forward`` is ``S * T``: ``T`` is a
    (VOCAB, EMBEDDING_DIM) table initialised with -1/0/+1 values and kept in
    packed form, and ``S`` is a power-of-two scale shared by each run of
    ``group_size`` consecutive columns of a row.

    Buffers:
        T_packed: output of ``pack_ternary`` for the initial ternary table.
        _T_shape / _T_pad: shape and padding handed to ``unpack_ternary``.
        E: int8 base exponent, one per (row, column group), stored flat.
        corr_accum: int64 correlation accumulator, one entry per group.
        step_counter: int64 running total of ``abs(corr_step)``.
    """
    def __init__(self, tscale_type=TScaleType.T32):
        """Build the packed ternary table, per-group exponents and accumulators.

        Args:
            tscale_type: selects the group size from ``GROUP_SIZES`` (falling
                back to the ``TScaleType.T64`` entry when absent) and is
                forwarded to the output ``TernaryRMSNorm``.
        """
        super().__init__()
        self.tscale_type = tscale_type
        self.threshold = 0.05
        self.group_size = GROUP_SIZES.get(tscale_type, GROUP_SIZES[TScaleType.T64])
        shape = (VOCAB, EMBEDDING_DIM)

        # The ternarisation threshold is capped at half the init std so that
        # the initial table is not mostly zeros.
        init_std = 0.02
        init_threshold = min(self.threshold, 0.5 * init_std)
        self.threshold = init_threshold
        w_init = torch.randn(VOCAB, EMBEDDING_DIM) * init_std
        # sign(w) where |w| exceeds the threshold, 0 elsewhere.
        T_init = w_init.sign() * (w_init.abs() > init_threshold).to(w_init.dtype)
        packed_T, T_shape, T_pad = pack_ternary(T_init)

        self.register_buffer("T_packed", packed_T)
        self.register_buffer("_T_shape", torch.tensor([VOCAB, EMBEDDING_DIM], dtype=torch.long))
        self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))

        # Per-group base exponent: log2 of the mean |w| over each group of
        # ``group_size`` columns (rows are zero-padded up to a whole number of
        # groups, so a partial last group averages in the padding zeros).
        out_dim, in_dim = shape
        gpr = _ceil_div(in_dim, self.group_size)
        total_in = gpr * self.group_size
        padded = torch.zeros(out_dim, total_in)
        abs_w = w_init.abs()
        padded[:, :in_dim] = abs_w
        grouped = padded.view(out_dim, gpr, self.group_size)
        grp_means = grouped.mean(dim=2)
        # Groups with zero mean get exponent log2(1) == 0 instead of -inf.
        E_vals = torch.where(grp_means > 0, grp_means, torch.ones_like(grp_means))
        # NOTE(review): the int8 cast truncates toward zero rather than
        # rounding, so negative exponents move toward 0 — confirm intended.
        self.register_buffer("E", E_vals.flatten().log2().clamp(-128, 127).to(torch.int8))

        # BigInt correlation accumulator (replaces T_accum + E_accum)
        n_grp = out_dim * gpr
        self.register_buffer("corr_accum", torch.zeros(n_grp, dtype=torch.int64))
        self.register_buffer("step_counter", torch.zeros(1, dtype=torch.int64))

        self.norm = TernaryRMSNorm(EMBEDDING_DIM, tscale_type=tscale_type)

    def _get_T(self):
        """Unpack ``T_packed`` using the stored shape and padding buffers."""
        return unpack_ternary(self.T_packed, tuple(self._T_shape.tolist()), int(self._T_pad.item()))

    def _get_S(self):
        """Return the (VOCAB, EMBEDDING_DIM) scale table ``2 ** exponent``.

        The exponent is ``E`` per group; once ``step_counter`` is positive it
        is shifted by ``corr_accum / (step * group_size)`` multiplied by
        ``_bigint_corr_strength()``. Each group exponent is repeated across
        its ``group_size`` columns and trimmed back to ``EMBEDDING_DIM``.
        """
        gpr = _ceil_div(EMBEDDING_DIM, self.group_size)
        e_adj = self.E.float()
        step = int(self.step_counter.item())
        if step > 0:
            from .kernel.ternary_scale import _bigint_corr_strength
            denom = max(step * self.group_size, 1)
            e_adj = e_adj + (self.corr_accum.float() / denom) * _bigint_corr_strength()
        E_exp = e_adj.view(VOCAB, gpr).repeat_interleave(self.group_size, dim=1)
        if E_exp.shape[1] > EMBEDDING_DIM:
            E_exp = E_exp[:, :EMBEDDING_DIM]
        return torch.exp2(E_exp)

    @torch.no_grad()
    def _accumulate_corr_from_grad_sign(self, grad_sign, corr_step=1):
        """Fold one gradient-sign table into ``corr_accum``.

        For every group the element-wise product ``grad_sign * T`` is summed
        and, scaled by ``corr_step``, subtracted from ``corr_accum``;
        ``step_counter`` grows by ``abs(corr_step)``.

        Args:
            grad_sign: tensor with the same shape as the ternary table. The
                call is a no-op when it is ``None`` or has a different shape.
            corr_step: integer weight applied to this update.
        """
        if grad_sign is None:
            return
        shape = tuple(self._T_shape.tolist())
        out_dim, in_dim = shape
        if tuple(grad_sign.shape) != shape:
            return
        gs = self.group_size
        T = self._get_T().to(device=grad_sign.device, dtype=torch.int16)
        signed = grad_sign.to(torch.int16) * T
        gpr = _ceil_div(in_dim, gs)
        total_in = gpr * gs
        if total_in > in_dim:
            # Zero-pad the last partial group; zeros do not affect the sum.
            signed = F.pad(signed, (0, total_in - in_dim))
        # reshape (not view) so a non-contiguous gradient layout cannot raise.
        score = signed.reshape(out_dim, gpr, gs).sum(dim=2, dtype=torch.int16)
        # Align with the accumulator's device before the in-place update.
        delta = score.flatten().to(device=self.corr_accum.device, dtype=torch.int64)
        self.corr_accum -= delta * int(corr_step)
        self.step_counter += abs(int(corr_step))

    def forward(self, x):
        """Embed index tensor ``x`` and apply the output norm.

        On CUDA with Triton and ``_TritonTernaryEmbedFn`` available the lookup
        is delegated to that autograd function. Otherwise the dense table
        ``S * T`` is materialised and looked up with ``F.embedding``; a hook
        on that table stores the sign of its gradient (int8) in
        ``_hook_grad_T_sign`` for ``ternary_step`` to consume.
        """
        if x.is_cuda and _HAS_TRITON and _TritonTernaryEmbedFn is not None:
            # Presumably the dummy requires-grad input makes autograd invoke
            # the custom function's backward — verify against the kernel.
            _dummy = torch.zeros(1, device=x.device, requires_grad=True)
            emb = _TritonTernaryEmbedFn.apply(x, _dummy, self)
            return self.norm(emb)
        T = self._get_T()
        S = self._get_S()
        w_eff = S * T.float()
        # Fresh leaf tensor: gradients stop here and are only observed by the
        # hook below, never propagated into the buffers.
        w_eff_grad = w_eff.detach().requires_grad_(True)

        def capture_w_grad(grad_w):
            self._hook_grad_T_sign = grad_w.sign().to(torch.int8)

        w_eff_grad.register_hook(capture_w_grad)
        out = self.norm(F.embedding(x, w_eff_grad))
        return out

    def ternary_step(self, accum_threshold=3):
        """Consume the gradient sign captured by the last backward pass.

        Does nothing when no gradient sign is pending. ``accum_threshold`` is
        accepted but not used by this implementation.
        """
        if not hasattr(self, "_hook_grad_T_sign"):
            return
        self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign)
        del self._hook_grad_T_sign

    def update_E(self, loss_signal=None):
        """No-op: ``loss_signal`` is ignored."""
        pass  # E is fixed; S adjusted via corr_accum


class Sequencer(nn.Module):
    """Common base for the per-modality sequencers.

    Only records the modality name, the window size and the ternary-scale
    type; subclasses provide the actual ``forward``.
    """

    def __init__(self, modality, window_size, tscale_type=TScaleType.T32):
        super().__init__()
        (
            self.modality,
            self.window_size,
            self.tscale_type,
        ) = (modality, window_size, tscale_type)

    def forward(self, x):
        """Abstract in this base class; always raises ``NotImplementedError``."""
        raise NotImplementedError


class TextSequencer(Sequencer):
    """Text sequencer: projects sliding windows of three steps to HIDDEN_DIM."""

    def __init__(self, tscale_type=TScaleType.T32):
        super().__init__(modality='text', window_size=3, tscale_type=tscale_type)
        # One projection input per (feature, window position) pair.
        window_features = EMBEDDING_DIM * self.window_size
        self.projection = TernaryScaleTensor(window_features, HIDDEN_DIM, tscale_type=tscale_type)
        self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)

    def forward(self, x):
        """Window ``x`` along dim 1, flatten each window, project and normalise.

        The rearrange pattern requires the unfolded tensor to be rank 4, so
        ``x`` is rank 3; presumably (batch, time, EMBEDDING_DIM) — confirm
        against the caller. With a stride of 1 the time axis shrinks by
        ``window_size - 1``.
        """
        windows = x.unfold(dimension=1, size=self.window_size, step=1)
        flat_windows = rearrange(windows, 'b t d w -> b t (d w)')
        return self.norm(self.projection(flat_windows))
class VAE2DSequencer(Sequencer):
    """Image sequencer: VAE latent map -> one token per spatial position."""

    def __init__(self, tscale_type=TScaleType.T32, quantize=None, device="cpu"):
        super().__init__(modality='image', window_size=1, tscale_type=tscale_type)
        from .encoders.vae2d import load_vae2d

        self.vae = load_vae2d(device=device, quantize=quantize)
        # Remembered so inputs can be moved next to the VAE in forward().
        self.vae_device = torch.device(device)
        # The projection takes 4 input features per token, i.e. the channel
        # axis of the latent after the rearrange in forward().
        self.project = TernaryScaleTensor(4, HIDDEN_DIM, tscale_type=tscale_type)
        self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)

    def forward(self, x):
        """Encode ``x`` with the VAE and return normalised per-position tokens."""
        target = self.vae_device
        pixels = x if x.device == target else x.to(target)
        latent_map = self.vae(pixels)
        # Flatten the spatial grid into a sequence; channels become features.
        token_seq = rearrange(latent_map, 'b c h w -> b (h w) c')
        return self.norm(self.project(token_seq))


class VAEAudioSequencer(Sequencer):
    """Audio sequencer: waveform -> mel front-end -> VAE latent -> tokens."""

    def __init__(self, tscale_type=TScaleType.T32, quantize=None, device="cpu"):
        super().__init__(modality='audio', window_size=1, tscale_type=tscale_type)
        from .encoders.vae2d import load_vae2d
        from .encoders.mel_frontend import MelSpectrogram3Band

        self.vae = load_vae2d(device=device, quantize=quantize)
        # Remembered so the spectrogram can be moved next to the VAE.
        self.vae_device = torch.device(device)
        self.mel = MelSpectrogram3Band(sample_rate=AUDIO_SR)
        self.project = TernaryScaleTensor(4, HIDDEN_DIM, tscale_type=tscale_type)
        self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)

    def forward(self, waveform):
        """Turn a waveform into normalised tokens.

        A rank-1 input gains a leading axis. A rank-3 input has its axis 1
        removed, by squeezing when it has size 1 and by averaging otherwise
        (presumably a channel axis — confirm). Other ranks pass through
        unchanged.
        """
        rank = waveform.dim()
        if rank == 1:
            waveform = waveform.unsqueeze(0)
        elif rank == 3:
            single = waveform.shape[1] == 1
            waveform = waveform.squeeze(1) if single else waveform.mean(dim=1)

        spectrogram = self.mel(waveform)
        target = self.vae_device
        if spectrogram.device != target:
            spectrogram = spectrogram.to(target)

        latent_map = self.vae(spectrogram)
        token_seq = rearrange(latent_map, 'b c h w -> b (h w) c')
        return self.norm(self.project(token_seq))


class MultimodalSequencer(nn.Module):
    """Container that routes each modality's input to its own sequencer."""

    def __init__(self, tscale_type=TScaleType.T32, enable_text=True, enable_image=True, enable_audio=True):
        super().__init__()
        # Disabled modalities keep a ``None`` placeholder attribute.
        if enable_text:
            self.text = TextSequencer(tscale_type=tscale_type)
        else:
            self.text = None
        if enable_image:
            self.image = VAE2DSequencer(tscale_type=tscale_type)
        else:
            self.image = None
        if enable_audio:
            self.audio = VAEAudioSequencer(tscale_type=tscale_type)
        else:
            self.audio = None

        switches = (('text', enable_text), ('image', enable_image), ('audio', enable_audio))
        self.enabled_modalities = [name for name, enabled in switches if enabled]

    def forward(self, modality_inputs):
        """Run every enabled sequencer that has a non-``None`` input.

        Args:
            modality_inputs: mapping from modality name ('text', 'image',
                'audio') to that sequencer's input; missing or ``None``
                entries are skipped.

        Returns:
            dict mapping each processed modality name to its sequencer output.
        """
        results = {}
        for name in self.enabled_modalities:
            sequencer = getattr(self, name)
            if name not in modality_inputs:
                continue
            payload = modality_inputs[name]
            if payload is None or sequencer is None:
                continue
            results[name] = sequencer(payload)
        return results