# d8bc908 | File size: 9,179 Bytes
"""Sequencer modules — input processing for all modalities."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm, GROUP_SIZES, _HAS_TRITON, _HAS_TILELANG
if _HAS_TRITON:
import triton
import triton.language as tl
else:
triton = None
tl = None
try:
from .kernel.ternary_scale import _TritonTernaryEmbedFn
except ImportError:
_TritonTernaryEmbedFn = None
from .converters.convert_to_ternary8 import pack_ternary, unpack_ternary
from math import ceil as _ceil
_ceil_div = lambda a, b: _ceil(a / b) if b > 0 else 0
from .config import VOCAB, EMBEDDING_DIM, HIDDEN_DIM, AUDIO_SR, AUDIO_FRAME_RATE
class ByteEmbedding(nn.Module):
    """Byte-level embedding via packed ternary + BigInt correlation.

    All training state is integer. T_accum/E_accum replaced by
    corr_accum (int64 per group, never clips or resets).

    S = 2^(E + K × mean_corr) where mean_corr = corr_accum / (step × gs)
    and K is the value returned by ``_bigint_corr_strength()``.

    Buffers registered by ``__init__``:
        T_packed     -- ternary sign matrix as packed by ``pack_ternary``.
        _T_shape     -- ``[VOCAB, EMBEDDING_DIM]`` (long), used when unpacking.
        _T_pad       -- padding value returned by ``pack_ternary`` (long).
        E            -- int8 per-group base-2 exponent.
        corr_accum   -- int64 per-group correlation accumulator.
        step_counter -- int64 count of accumulated correlation steps.
    """

    def __init__(self, tscale_type=TScaleType.T32):
        """Initialise the ternary table, group exponents and accumulators.

        Args:
            tscale_type: Key into ``GROUP_SIZES``; when it is missing there,
                the ``TScaleType.T64`` group size is used instead.
        """
        super().__init__()
        self.tscale_type = tscale_type
        self.group_size = GROUP_SIZES.get(tscale_type, GROUP_SIZES[TScaleType.T64])

        init_std = 0.02
        # Ternarisation threshold: at most 0.05, and at most half the init std.
        init_threshold = min(0.05, 0.5 * init_std)
        self.threshold = init_threshold

        # Draw dense weights once, then keep only their thresholded sign.
        w_init = torch.randn(VOCAB, EMBEDDING_DIM) * init_std
        T_init = w_init.sign() * (w_init.abs() > init_threshold).to(w_init.dtype)
        packed_T, _packed_shape, T_pad = pack_ternary(T_init)
        self.register_buffer("T_packed", packed_T)
        self.register_buffer("_T_shape", torch.tensor([VOCAB, EMBEDDING_DIM], dtype=torch.long))
        self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long))

        # Per-group exponent: log2 of the mean |w| over each group of
        # ``group_size`` columns; the last group of a row is zero-padded.
        out_dim, in_dim = VOCAB, EMBEDDING_DIM
        gpr = _ceil_div(in_dim, self.group_size)  # groups per row
        total_in = gpr * self.group_size
        padded = torch.zeros(out_dim, total_in)
        padded[:, :in_dim] = w_init.abs()
        grp_means = padded.view(out_dim, gpr, self.group_size).mean(dim=2)
        # Substitute 1.0 for non-positive means so log2 stays finite (-> 0).
        E_vals = torch.where(grp_means > 0, grp_means, torch.ones_like(grp_means))
        # NOTE(review): the float -> int8 cast truncates toward zero rather
        # than flooring or rounding; confirm that is the intended quantisation.
        self.register_buffer("E", E_vals.flatten().log2().clamp(-128, 127).to(torch.int8))

        # BigInt correlation accumulator (replaces T_accum + E_accum)
        n_grp = out_dim * gpr
        self.register_buffer("corr_accum", torch.zeros(n_grp, dtype=torch.int64))
        self.register_buffer("step_counter", torch.zeros(1, dtype=torch.int64))
        self.norm = TernaryRMSNorm(EMBEDDING_DIM, tscale_type=tscale_type)

    def _get_T(self):
        """Unpack ``T_packed`` using the stored shape and padding."""
        return unpack_ternary(self.T_packed, tuple(self._T_shape.tolist()), int(self._T_pad.item()))

    def _get_S(self):
        """Return the per-element scale ``2 ** e`` of shape (VOCAB, EMBEDDING_DIM).

        ``e`` is the group exponent ``E``; once ``step_counter`` is positive it
        is shifted by ``corr_accum / (step * group_size)`` multiplied by
        ``_bigint_corr_strength()``.
        """
        gpr = _ceil_div(EMBEDDING_DIM, self.group_size)
        e_adj = self.E.float()
        step = int(self.step_counter.item())
        if step > 0:
            from .kernel.ternary_scale import _bigint_corr_strength
            denom = max(step * self.group_size, 1)
            e_adj = e_adj + (self.corr_accum.float() / denom) * _bigint_corr_strength()
        # Broadcast each group's exponent over its columns, then trim padding.
        E_exp = e_adj.view(VOCAB, gpr).repeat_interleave(self.group_size, dim=1)
        if E_exp.shape[1] > EMBEDDING_DIM:
            E_exp = E_exp[:, :EMBEDDING_DIM]
        return torch.exp2(E_exp)

    @torch.no_grad()
    def _accumulate_corr_from_grad_sign(self, grad_sign, corr_step=1):
        """Fold one gradient-sign observation into ``corr_accum``.

        For every group the score is ``sum(grad_sign * T)``; ``corr_accum`` is
        decreased by ``score * corr_step`` and ``step_counter`` is increased by
        ``abs(corr_step)``.

        Args:
            grad_sign: Tensor with the same shape as the ternary table
                (``_T_shape``). ``None`` or any other shape is ignored
                silently and nothing is accumulated.
            corr_step: Integer multiplier applied to the score.
        """
        if grad_sign is None:
            return
        shape = tuple(self._T_shape.tolist())
        if tuple(grad_sign.shape) != shape:
            return
        out_dim, in_dim = shape
        gs = self.group_size
        T = self._get_T().to(device=grad_sign.device, dtype=torch.int16)
        signed = grad_sign.to(torch.int16) * T
        gpr = _ceil_div(in_dim, gs)
        total_in = gpr * gs
        if total_in > in_dim:
            # Zero-pad the tail so the last group is full width.
            signed = F.pad(signed, (0, total_in - in_dim))
        score = signed.view(out_dim, gpr, gs).sum(dim=2, dtype=torch.int16)
        # The score lives on grad_sign's device (T was moved there above);
        # bring it back to the accumulator's device before the in-place update.
        delta = score.flatten().to(device=self.corr_accum.device, dtype=torch.int64)
        self.corr_accum -= delta * int(corr_step)
        self.step_counter += abs(int(corr_step))

    def forward(self, x):
        """Look up embeddings for index tensor ``x`` and apply ``self.norm``.

        On CUDA inputs with Triton and ``_TritonTernaryEmbedFn`` available the
        lookup is delegated to that function. Otherwise the dense weight
        ``S * T`` is materialised and a hook records the sign of its gradient
        in ``self._hook_grad_T_sign`` for a later ``ternary_step``.
        """
        if x.is_cuda and _HAS_TRITON and _TritonTernaryEmbedFn is not None:
            # NOTE(review): presumably the dummy gives autograd a
            # differentiable input so the custom backward runs; confirm
            # against the kernel implementation.
            _dummy = torch.zeros(1, device=x.device, requires_grad=True)
            emb = _TritonTernaryEmbedFn.apply(x, _dummy, self)
            return self.norm(emb)
        T = self._get_T()
        S = self._get_S()
        w_eff = S * T.float()
        # Fresh leaf tensor: gradients stop here and are only observed by the hook.
        w_eff_grad = w_eff.detach().requires_grad_(True)

        def capture_w_grad(grad_w):
            self._hook_grad_T_sign = grad_w.sign().to(torch.int8)

        w_eff_grad.register_hook(capture_w_grad)
        out = self.norm(F.embedding(x, w_eff_grad))
        return out

    def ternary_step(self, accum_threshold=3):
        """Consume the gradient sign captured by ``forward``, if there is one.

        Args:
            accum_threshold: Not used by this implementation; kept so the
                call signature stays unchanged.
        """
        if not hasattr(self, "_hook_grad_T_sign"):
            return
        self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign)
        del self._hook_grad_T_sign

    def update_E(self, loss_signal=None):
        """No-op: E is fixed; S is adjusted via corr_accum."""
        pass
class Sequencer(nn.Module):
    """Common base for the per-modality sequencers.

    It only records the modality label, the window size and the ternary
    scale type; subclasses are expected to override ``forward``.
    """

    def __init__(self, modality, window_size, tscale_type=TScaleType.T32):
        super().__init__()
        self.tscale_type = tscale_type
        self.window_size = window_size
        self.modality = modality

    def forward(self, x):
        """Abstract in the base class: always raises ``NotImplementedError``."""
        raise NotImplementedError()
class TextSequencer(Sequencer):
    """Text sequencer: projects sliding windows of three embeddings to HIDDEN_DIM."""

    def __init__(self, tscale_type=TScaleType.T32):
        super().__init__(modality='text', window_size=3, tscale_type=tscale_type)
        # Each window concatenates ``window_size`` embedding vectors.
        window_features = EMBEDDING_DIM * self.window_size
        self.projection = TernaryScaleTensor(window_features, HIDDEN_DIM, tscale_type=tscale_type)
        self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)

    def forward(self, x):
        """Project every stride-1 window along dim 1 and normalise the result.

        The rearrange pattern implies ``x`` is rank 3 (``b t d``); ``unfold``
        appends the window axis last, which is then merged into the feature
        axis.
        """
        windows = x.unfold(1, self.window_size, 1)
        flattened = rearrange(windows, 'b t d w -> b t (d w)')
        return self.norm(self.projection(flattened))
class VAE2DSequencer(Sequencer):
    """Image sequencer: VAE latent -> one token per spatial position -> HIDDEN_DIM."""

    def __init__(self, tscale_type=TScaleType.T32, quantize=None, device="cpu"):
        super().__init__(modality='image', window_size=1, tscale_type=tscale_type)
        # Imported lazily, inside the constructor, as in the rest of this class.
        from .encoders.vae2d import load_vae2d as _load_vae2d

        self.vae = _load_vae2d(device=device, quantize=quantize)
        self.vae_device = torch.device(device)
        # The projection takes 4 input features per token.
        # NOTE(review): presumably the VAE latent channel count; confirm.
        self.project = TernaryScaleTensor(4, HIDDEN_DIM, tscale_type=tscale_type)
        self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)

    def forward(self, x):
        """Encode ``x`` with the VAE and return normalised, projected tokens."""
        on_vae_device = x.device == self.vae_device
        vae_input = x if on_vae_device else x.to(self.vae_device)
        latent = self.vae(vae_input)
        # Flatten the spatial grid into a token axis, channels last.
        tokens = rearrange(latent, 'b c h w -> b (h w) c')
        return self.norm(self.project(tokens))
class VAEAudioSequencer(Sequencer):
    """Audio sequencer: waveform -> mel front-end -> VAE latent -> HIDDEN_DIM tokens."""

    def __init__(self, tscale_type=TScaleType.T32, quantize=None, device="cpu"):
        super().__init__(modality='audio', window_size=1, tscale_type=tscale_type)
        # Both encoder pieces are imported lazily inside the constructor.
        from .encoders.vae2d import load_vae2d as _load_vae2d
        from .encoders.mel_frontend import MelSpectrogram3Band as _Mel3Band

        self.vae = _load_vae2d(device=device, quantize=quantize)
        self.vae_device = torch.device(device)
        self.mel = _Mel3Band(sample_rate=AUDIO_SR)
        # The projection takes 4 input features per token.
        # NOTE(review): presumably the VAE latent channel count; confirm.
        self.project = TernaryScaleTensor(4, HIDDEN_DIM, tscale_type=tscale_type)
        self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)

    def forward(self, waveform):
        """Turn a waveform into normalised hidden tokens.

        Rank-1 input gets a leading axis; rank-3 input has dim 1 squeezed
        when it is of size 1 and averaged otherwise. Any other rank is passed
        to the mel front-end unchanged.
        """
        rank = waveform.dim()
        if rank == 1:
            waveform = waveform.unsqueeze(0)
        elif rank == 3:
            single = waveform.shape[1] == 1
            waveform = waveform.squeeze(1) if single else waveform.mean(dim=1)

        spec = self.mel(waveform)
        if spec.device != self.vae_device:
            spec = spec.to(self.vae_device)

        latent = self.vae(spec)
        # Flatten the latent grid into a token axis, channels last.
        tokens = rearrange(latent, 'b c h w -> b (h w) c')
        return self.norm(self.project(tokens))
class MultimodalSequencer(nn.Module):
    """Holds one sequencer per enabled modality and dispatches inputs by name."""

    def __init__(self, tscale_type=TScaleType.T32, enable_text=True, enable_image=True, enable_audio=True):
        super().__init__()
        # Disabled modalities keep a ``None`` attribute in place of a sequencer.
        self.text = TextSequencer(tscale_type=tscale_type) if enable_text else None
        self.image = VAE2DSequencer(tscale_type=tscale_type) if enable_image else None
        self.audio = VAEAudioSequencer(tscale_type=tscale_type) if enable_audio else None

        switches = (('text', enable_text), ('image', enable_image), ('audio', enable_audio))
        self.enabled_modalities = [label for label, enabled in switches if enabled]

    def forward(self, modality_inputs):
        """Run each enabled sequencer on its entry of ``modality_inputs``.

        Entries that are absent or ``None`` are skipped. The result is a dict
        keyed by modality name, filled in the order of
        ``enabled_modalities``.
        """
        results = {}
        for label in self.enabled_modalities:
            sequencer = getattr(self, label)
            if label not in modality_inputs:
                continue
            payload = modality_inputs[label]
            if payload is None or sequencer is None:
                continue
            results[label] = sequencer(payload)
        return results
|