"""Sequencer modules — input processing for all modalities.""" import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm, GROUP_SIZES, _HAS_TRITON, _HAS_TILELANG if _HAS_TRITON: import triton import triton.language as tl else: triton = None tl = None try: from .kernel.ternary_scale import _TritonTernaryEmbedFn except ImportError: _TritonTernaryEmbedFn = None from .converters.convert_to_ternary8 import pack_ternary, unpack_ternary from math import ceil as _ceil _ceil_div = lambda a, b: _ceil(a / b) if b > 0 else 0 from .config import VOCAB, EMBEDDING_DIM, HIDDEN_DIM, AUDIO_SR, AUDIO_FRAME_RATE class ByteEmbedding(nn.Module): """Byte-level embedding via packed ternary + BigInt correlation. All training state is integer. T_accum/E_accum replaced by corr_accum (int64 per group, never clips or resets). S = 2^(E + K × mean_corr) where mean_corr = corr_accum / (step × gs) """ def __init__(self, tscale_type=TScaleType.T32): super().__init__() self.tscale_type = tscale_type self.threshold = 0.05 self.group_size = GROUP_SIZES.get(tscale_type, GROUP_SIZES[TScaleType.T64]) shape = (VOCAB, EMBEDDING_DIM) init_std = 0.02 init_threshold = min(self.threshold, 0.5 * init_std) self.threshold = init_threshold w_init = torch.randn(VOCAB, EMBEDDING_DIM) * init_std T_init = w_init.sign() * (w_init.abs() > init_threshold).to(w_init.dtype) packed_T, T_shape, T_pad = pack_ternary(T_init) self.register_buffer("T_packed", packed_T) self.register_buffer("_T_shape", torch.tensor([VOCAB, EMBEDDING_DIM], dtype=torch.long)) self.register_buffer("_T_pad", torch.tensor(T_pad, dtype=torch.long)) out_dim, in_dim = shape gpr = _ceil_div(in_dim, self.group_size) total_in = gpr * self.group_size padded = torch.zeros(out_dim, total_in) abs_w = w_init.abs() padded[:, :in_dim] = abs_w grouped = padded.view(out_dim, gpr, self.group_size) grp_means = grouped.mean(dim=2) E_vals = torch.where(grp_means > 0, grp_means, torch.ones_like(grp_means)) self.register_buffer("E", E_vals.flatten().log2().clamp(-128, 127).to(torch.int8)) # BigInt correlation accumulator (replaces T_accum + E_accum) n_grp = out_dim * gpr self.register_buffer("corr_accum", torch.zeros(n_grp, dtype=torch.int64)) self.register_buffer("step_counter", torch.zeros(1, dtype=torch.int64)) self.norm = TernaryRMSNorm(EMBEDDING_DIM, tscale_type=tscale_type) def _get_T(self): return unpack_ternary(self.T_packed, tuple(self._T_shape.tolist()), int(self._T_pad.item())) def _get_S(self): gpr = _ceil_div(EMBEDDING_DIM, self.group_size) e_adj = self.E.float() step = int(self.step_counter.item()) if step > 0: from .kernel.ternary_scale import _bigint_corr_strength denom = max(step * self.group_size, 1) e_adj = e_adj + (self.corr_accum.float() / denom) * _bigint_corr_strength() E_exp = e_adj.view(VOCAB, gpr).repeat_interleave(self.group_size, dim=1) if E_exp.shape[1] > EMBEDDING_DIM: E_exp = E_exp[:, :EMBEDDING_DIM] return torch.exp2(E_exp) @torch.no_grad() def _accumulate_corr_from_grad_sign(self, grad_sign, corr_step=1): if grad_sign is None: return shape = tuple(self._T_shape.tolist()) out_dim, in_dim = shape if tuple(grad_sign.shape) != shape: return gs = self.group_size T = self._get_T().to(device=grad_sign.device, dtype=torch.int16) signed = grad_sign.to(torch.int16) * T gpr = _ceil_div(in_dim, gs) total_in = gpr * gs if total_in > in_dim: signed = F.pad(signed, (0, total_in - in_dim)) score = signed.view(out_dim, gpr, gs).sum(dim=2, dtype=torch.int16) self.corr_accum -= score.flatten().to(dtype=torch.int64) * int(corr_step) self.step_counter += abs(int(corr_step)) def forward(self, x): if x.is_cuda and _HAS_TRITON and _TritonTernaryEmbedFn is not None: _dummy = torch.zeros(1, device=x.device, requires_grad=True) emb = _TritonTernaryEmbedFn.apply(x, _dummy, self) return self.norm(emb) T = self._get_T() S = self._get_S() w_eff = S * T.float() w_eff_grad = w_eff.detach().requires_grad_(True) def capture_w_grad(grad_w): self._hook_grad_T_sign = grad_w.sign().to(torch.int8) w_eff_grad.register_hook(capture_w_grad) out = self.norm(F.embedding(x, w_eff_grad)) return out def ternary_step(self, accum_threshold=3): if hasattr(self, "_hook_grad_T_sign"): if hasattr(self, "_accumulate_corr_from_grad_sign"): self._accumulate_corr_from_grad_sign(self._hook_grad_T_sign) del self._hook_grad_T_sign def update_E(self, loss_signal=None): pass # E is fixed; S adjusted via corr_accum class Sequencer(nn.Module): def __init__(self, modality, window_size, tscale_type=TScaleType.T32): super().__init__() self.modality = modality self.window_size = window_size self.tscale_type = tscale_type def forward(self, x): raise NotImplementedError class TextSequencer(Sequencer): def __init__(self, tscale_type=TScaleType.T32): super().__init__(modality='text', window_size=3, tscale_type=tscale_type) self.projection = TernaryScaleTensor(EMBEDDING_DIM * self.window_size, HIDDEN_DIM, tscale_type=tscale_type) self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type) def forward(self, x): trigrams = x.unfold(dimension=1, size=self.window_size, step=1) trigrams = rearrange(trigrams, 'b t d w -> b t (d w)') relational = self.projection(trigrams) return self.norm(relational) class VAE2DSequencer(Sequencer): def __init__(self, tscale_type=TScaleType.T32, quantize=None, device="cpu"): super().__init__(modality='image', window_size=1, tscale_type=tscale_type) from .encoders.vae2d import load_vae2d as _load_vae2d self.vae = _load_vae2d(device=device, quantize=quantize) self.vae_device = torch.device(device) self.project = TernaryScaleTensor(4, HIDDEN_DIM, tscale_type=tscale_type) self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type) def forward(self, x): if x.device != self.vae_device: x = x.to(self.vae_device) latent = self.vae(x) tokens = rearrange(latent, 'b c h w -> b (h w) c') out = self.project(tokens) return self.norm(out) class VAEAudioSequencer(Sequencer): def __init__(self, tscale_type=TScaleType.T32, quantize=None, device="cpu"): super().__init__(modality='audio', window_size=1, tscale_type=tscale_type) from .encoders.vae2d import load_vae2d as _load_vae2d from .encoders.mel_frontend import MelSpectrogram3Band as _Mel3Band self.vae = _load_vae2d(device=device, quantize=quantize) self.vae_device = torch.device(device) self.mel = _Mel3Band(sample_rate=AUDIO_SR) self.project = TernaryScaleTensor(4, HIDDEN_DIM, tscale_type=tscale_type) self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type) def forward(self, waveform): if waveform.dim() == 1: waveform = waveform.unsqueeze(0) elif waveform.dim() == 3: if waveform.shape[1] == 1: waveform = waveform.squeeze(1) else: waveform = waveform.mean(dim=1) spec = self.mel(waveform) if spec.device != self.vae_device: spec = spec.to(self.vae_device) latent = self.vae(spec) tokens = rearrange(latent, 'b c h w -> b (h w) c') out = self.project(tokens) return self.norm(out) class MultimodalSequencer(nn.Module): def __init__(self, tscale_type=TScaleType.T32, enable_text=True, enable_image=True, enable_audio=True): super().__init__() self.text = TextSequencer(tscale_type=tscale_type) if enable_text else None self.image = VAE2DSequencer(tscale_type=tscale_type) if enable_image else None self.audio = VAEAudioSequencer(tscale_type=tscale_type) if enable_audio else None self.enabled_modalities = [] if enable_text: self.enabled_modalities.append('text') if enable_image: self.enabled_modalities.append('image') if enable_audio: self.enabled_modalities.append('audio') def forward(self, modality_inputs): outputs = {} for mod in self.enabled_modalities: seq = getattr(self, mod) if mod in modality_inputs and modality_inputs[mod] is not None and seq is not None: outputs[mod] = seq(modality_inputs[mod]) return outputs