| """Decoder modules — video diffusion, audio codec, speech generation. |
| |
| These modules convert HIDDEN_DIM relational states into modality-specific outputs: |
| video (latent diffusion), audio (codec tokens), and speech (token striding + codec). |
| """ |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm |
| from .kernel.triton_video import video_denoise_step |
| from .config import HIDDEN_DIM, AUDIO_VOCAB, AUDIO_SR, AUDIO_FRAME_RATE, \ |
| VIDEO_LATENT_CHANNELS, VIDEO_MAX_STEPS, VIDEO_HEIGHT, VIDEO_WIDTH, \ |
| VIDEOHEAD_ACT_MIN_FPS, VIDEOHEAD_ACT_MAX_FPS, VIDEOHEAD_ACT_FRAME_CHUNK, \ |
| TALKERHEAD_ACT_CHUNK_FRAMES |
| from .components import TernaryEmbeddingTable |
|
|
|
|
class LTIInjection(nn.Module):
    """Frozen diagonal linear state update: ``A * h + B * e + trans_out``.

    ``A`` is never stored directly. It is derived on demand as
    ``exp(-exp(log_dt + log_A))`` with the inner sum clamped to [-20, 20],
    which keeps every entry mathematically inside (0, 1). At the clamp
    extremes the value can round to 0 or 1 in float32.
    """
    def __init__(self, dim: int):
        super().__init__()
        # Registered as parameters, but created with gradients disabled.
        self.log_A = nn.Parameter(torch.zeros(dim), requires_grad=False)
        self.log_dt = nn.Parameter(torch.zeros(1), requires_grad=False)
        self.B = nn.Parameter(torch.ones(dim) * 0.1, requires_grad=False)

    def get_A(self):
        """Return the per-dimension decay factor derived from ``log_dt`` and ``log_A``."""
        exponent = torch.clamp(self.log_dt + self.log_A, min=-20, max=20)
        rate = torch.exp(exponent)
        return torch.exp(-rate)

    def forward(self, h, e, trans_out):
        """Decay ``h`` by ``A``, add ``B``-scaled ``e`` and then ``trans_out``."""
        decayed = self.get_A() * h
        injected = self.B * e
        return decayed + injected + trans_out
|
|
|
|
class VideoHead(nn.Module):
    """Iterative latent video decoder conditioned on pooled relational states.

    ``forward`` returns latents shaped
    ``[B, latent_channels, n_latents, height, width]``: one independently
    denoised latent per temporal chunk, concatenated along dim 2. A frame
    gate maps the conditioning vector to an fps value in
    ``[min_fps, max_fps]``; together with ``duration_seconds`` and
    ``frame_chunk`` this sets the number of chunks.

    NOTE(review): the earlier docstring described the output as
    ``[B, ch, 4, H', W']`` (4-frame temporal chunks, per D-102). The code
    emits a temporal size of 1 per chunk — confirm which is intended.
    """
    def __init__(self, tscale_type=TScaleType.T32, max_steps=VIDEO_MAX_STEPS,
                 latent_channels=VIDEO_LATENT_CHANNELS, height=VIDEO_HEIGHT, width=VIDEO_WIDTH,
                 min_fps=VIDEOHEAD_ACT_MIN_FPS, max_fps=VIDEOHEAD_ACT_MAX_FPS,
                 frame_chunk=VIDEOHEAD_ACT_FRAME_CHUNK):
        super().__init__()
        self.max_steps = max_steps
        self.latent_channels = latent_channels
        self.height = height
        self.width = width
        # Each latent is handled as one flat vector of this length.
        self.latent_dim = latent_channels * height * width
        self.halt_threshold = 0.05
        self.min_fps = min_fps
        self.max_fps = max_fps
        self.frame_chunk = frame_chunk

        self.cross_attn_q = TernaryScaleTensor(self.latent_dim, HIDDEN_DIM, tscale_type=tscale_type)
        self.cross_attn_kv = TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM, tscale_type=tscale_type)
        self.diffusion_step = TernaryScaleTensor(HIDDEN_DIM, self.latent_dim, tscale_type=tscale_type)
        self.halt_unit = TernaryScaleTensor(HIDDEN_DIM, 1, tscale_type=tscale_type)
        self.frame_gate = TernaryScaleTensor(HIDDEN_DIM, 1, tscale_type=tscale_type)
        # One embedding per denoising step index; sized by the constructor's max_steps.
        self.noise_embed = TernaryEmbeddingTable(max_steps, HIDDEN_DIM, tscale_type=tscale_type)
        self.lti = LTIInjection(self.latent_dim)

    @torch.no_grad()
    def _compute_fps(self, cond):
        """Return a Python float fps: the gate's sigmoid, rescaled to
        ``[min_fps, max_fps]`` and averaged over all elements."""
        frame_prob = torch.sigmoid(self.frame_gate(cond))
        fps = self.min_fps + frame_prob * (self.max_fps - self.min_fps)
        return fps.mean().item()

    def forward(self, relational, max_steps=None, duration_seconds=1.0):
        """Denoise one latent per temporal chunk and stack them along dim 2.

        Args:
            relational: tensor unpacked as ``[B, T, D]``; it is mean-pooled
                over dim 1 to form the conditioning vector.
            max_steps: upper bound on denoising steps per chunk (also caps
                the chunk count). Falsy values fall back to ``self.max_steps``.
            duration_seconds: clip duration used to turn fps into a frame count.

        Returns:
            Tensor ``[B, latent_channels, n_latents, height, width]``.
        """
        B, T, _ = relational.shape
        max_steps = max_steps or self.max_steps
        device = relational.device
        cond = relational.mean(dim=1, keepdim=True)

        fps = self._compute_fps(cond)
        n_frames = max(1, int(fps * duration_seconds))
        n_chunks = (n_frames + self.frame_chunk - 1) // self.frame_chunk
        n_latents = min(n_chunks, max_steps)

        # The embedding table was built with self.max_steps rows (presumably one
        # per step — confirm against TernaryEmbeddingTable). A caller-supplied
        # max_steps larger than that would index past the end, so reuse the
        # last row for any extra steps.
        last_embed_idx = self.max_steps - 1
        # Only as much of the context as fits in the latent vector is injected;
        # identical to the full copy whenever latent_dim >= HIDDEN_DIM.
        inject_width = min(HIDDEN_DIM, self.latent_dim)

        all_latents = []
        for _ in range(n_latents):
            # Fresh noise per chunk; created as a grad-requiring leaf whenever
            # autograd is enabled.
            latent = torch.randn(B, 1, self.latent_dim, device=device,
                                 requires_grad=torch.is_grad_enabled())
            for step in range(max_steps):
                q = self.cross_attn_q(latent)
                # NOTE(review): kv/context/halt depend only on `cond`, which
                # never changes inside these loops. They are recomputed each
                # step to keep behaviour unchanged in case the ternary layers
                # are stateful — candidates for hoisting once confirmed pure.
                kv = self.cross_attn_kv(cond.expand(-1, T, -1))
                context = kv.mean(dim=1, keepdim=True)
                step_idx = torch.tensor(min(step, last_embed_idx), device=device)
                step_embed = self.noise_embed(step_idx).expand(B, 1, -1)
                pred_noise = self.diffusion_step(q + context + step_embed)
                alpha = 0.9 ** step
                trans_out = video_denoise_step(latent, pred_noise, alpha)
                # Zero-padded copy of the context, used as the LTI input term.
                h = torch.zeros(B, 1, self.latent_dim, device=context.device)
                h[:, :, :inject_width] = context[:, :, :inject_width]
                latent = self.lti(latent, h, trans_out)
                halt = torch.sigmoid(self.halt_unit(context))
                # Cheap Python test first so steps 0 and 1 never evaluate the
                # tensor comparison; at least three steps always run.
                if step > 1 and halt.mean() > self.halt_threshold:
                    break
            # reshape (not view) so a non-contiguous result cannot raise.
            all_latents.append(
                latent.reshape(B, self.latent_channels, 1, self.height, self.width)
            )

        return torch.cat(all_latents, dim=2)
|
|
|
|
class MRFBlock(nn.Module):
    """Average of parallel ``LeakyReLU -> Conv1d`` branches, one per kernel size.

    Named after the Multi-Receptive Field Fusion block of HiFi-GAN. Every
    branch keeps the channel count and pads by ``k // 2``, so odd kernel
    sizes leave the sequence length unchanged.
    """
    def __init__(self, channels, kernel_sizes=(3, 5, 7)):
        super().__init__()
        branches = []
        for width in kernel_sizes:
            activation = nn.LeakyReLU(0.1)
            conv = nn.Conv1d(channels, channels, width, padding=width // 2, dilation=1)
            branches.append(nn.Sequential(activation, conv))
        self.convs = nn.ModuleList(branches)

    def forward(self, x):
        """Apply every branch to the same input and return their mean."""
        total = 0
        for branch in self.convs:
            total = total + branch(x)
        return total / len(self.convs)
|
|
|
|
class TinyNeuralCodec(nn.Module):
    """Lightweight neural audio decoder built from transposed-conv upsampling.

    ``forward`` maps integer token sequences ``[B, T]`` to waveforms
    ``[B, 1, T * prod(upsample_ratios)]`` squashed into (-1, 1) by ``tanh``.
    With the default ratios ``(5, 4, 4, 4)`` each token expands to 320
    samples (50 Hz tokens -> 16 kHz audio).

    Intended as a frozen float32 sidecar; this class itself does not freeze
    its parameters.
    """
    def __init__(self, vocab=AUDIO_VOCAB, embed_dim=512, upsample_ratios=(5, 4, 4, 4)):
        super().__init__()
        self.embed = nn.Embedding(vocab, embed_dim)

        in_ch = embed_dim
        self.blocks = nn.ModuleList()
        for i, ratio in enumerate(upsample_ratios):
            # Channel count halves at every stage, never dropping below 1.
            out_ch = max(1, embed_dim // (2 ** (i + 1)))
            # kernel = 2 * stride; padding and output_padding are chosen so the
            # stage multiplies the length by exactly `ratio` (for ratio >= 2).
            k = ratio * 2
            pad = (ratio + 1) // 2 if ratio % 2 else ratio // 2
            op = max(0, ratio + 2 * pad - k)
            block = nn.Sequential(
                nn.ConvTranspose1d(in_ch, out_ch, k, stride=ratio, padding=pad, output_padding=op),
                MRFBlock(out_ch),
            )
            self.blocks.append(block)
            in_ch = out_ch

        self.to_audio = nn.Conv1d(in_ch, 1, kernel_size=7, padding=3)

    def forward(self, tokens):
        """Decode integer tokens ``[B, T]`` into a waveform ``[B, 1, T * prod(ratios)]``."""
        x = self.embed(tokens)
        # Embedding output is channels-last; the conv stack wants channels-first.
        x = x.permute(0, 2, 1)
        for block in self.blocks:
            x = block(x)
        x = self.to_audio(x)
        return torch.tanh(x)

    def encode_audio(self, audio, frame_rate=AUDIO_FRAME_RATE, sr=AUDIO_SR):
        """Quantise a waveform into per-frame tokens and decode them back.

        Each frame's token is its channel-averaged mean amplitude mapped
        linearly from [-1, 1] onto [0, 255] (truncated to an integer), then
        clamped so it is always a valid index into the embedding table.

        Args:
            audio: tensor unpacked as ``[B, C, T]``; channels are averaged.
            frame_rate: frames (tokens) per second.
            sr: sample rate in Hz; a frame is ``sr // frame_rate`` samples.

        Returns:
            ``(tokens, recon)``: ``tokens`` is ``[B, n_frames]`` int64 and
            ``recon`` is ``self(tokens)``, trimmed back to ``T`` samples when
            the input had to be right-padded to a whole number of frames.
        """
        _, _, n_samples = audio.shape
        frame_len = sr // frame_rate
        pad = (frame_len - n_samples % frame_len) % frame_len
        if pad > 0:
            audio = F.pad(audio, (0, pad))
        # [B, C, n_frames, frame_len] -> average channels -> [B, n_frames, frame_len]
        frames = audio.unfold(2, frame_len, frame_len).mean(dim=1)
        # Mean amplitude per frame (not energy, despite the historical name).
        frame_mean = frames.mean(dim=-1)
        # Never exceed the byte range, and never index past the embedding
        # table when the codec was built with fewer than 256 entries.
        top_token = min(255, self.embed.num_embeddings - 1)
        tokens = torch.clamp(((frame_mean + 1.0) * 127.5).long(), 0, top_token)
        recon = self(tokens)
        if pad > 0:
            recon = recon[:, :, :n_samples]
        return tokens, recon
|
|
|
|
class TalkerHead(nn.Module):
    """Audio token head with temporal striding and chunked generation.

    Pipeline: RMSNorm -> HIDDEN_DIM x HIDDEN_DIM projection -> RMSNorm + SiLU
    -> projection to ``AUDIO_VOCAB`` logits. Per-position logits are
    stretched along time by repetition to the requested frame count, and
    ``generate_audio`` decodes the arg-max tokens to a waveform with a
    lazily created ``TinyNeuralCodec``.
    """
    def __init__(self, tscale_type=TScaleType.T32,
                 chunk_frames=TALKERHEAD_ACT_CHUNK_FRAMES):
        super().__init__()
        self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)
        self.hidden = TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM, tscale_type=tscale_type)
        self.hidden_norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)
        self.head = TernaryScaleTensor(HIDDEN_DIM, AUDIO_VOCAB, tscale_type=tscale_type)
        # Created on first use by load_codec().
        self.codec = None
        # Default frame budget and per-chunk size both start at chunk_frames.
        self.max_frames = chunk_frames
        self.chunk_frames = chunk_frames

    def load_codec(self, device='cuda'):
        """Create the codec on ``device`` in eval mode on first call; return it.

        An already-created codec is returned as-is and is not moved to
        ``device``.
        """
        if self.codec is None:
            self.codec = TinyNeuralCodec().to(device)
            self.codec.eval()
        return self.codec

    def token_logits(self, x, max_frames=None):
        """Return logits with exactly ``max_frames`` positions along dim 1.

        Each input position is repeated ``max(1, max_frames // T)`` times;
        the result is truncated, or right-padded with all-zero logits, to
        ``max_frames``. Falsy ``max_frames`` falls back to ``self.max_frames``.
        """
        max_frames = max_frames or self.max_frames
        cond = self.norm(x)
        cond = F.silu(self.hidden_norm(self.hidden(cond)))
        stride = max(1, max_frames // max(1, cond.shape[1]))
        logits = self.head(cond)
        logits = logits.repeat_interleave(stride, dim=1)
        n_positions = logits.shape[1]
        if n_positions > max_frames:
            logits = logits[:, :max_frames, :]
        elif n_positions < max_frames:
            filler = logits.new_zeros(logits.shape[0], max_frames - n_positions, logits.shape[2])
            logits = torch.cat([logits, filler], dim=1)
        return logits

    def forward(self, x, max_frames=None):
        """Return arg-max token ids along the last dim of ``token_logits``."""
        return self.token_logits(x, max_frames=max_frames).argmax(dim=-1)

    def generate_audio(self, x, max_frames=None, return_all=True):
        """Generate tokens chunk by chunk and decode them to a waveform.

        Args:
            x: conditioning tensor passed unchanged to ``forward`` for every chunk.
            max_frames: total number of frames; ``None`` uses ``self.max_frames``.
            return_all: accepted for interface compatibility; currently unused.

        Returns:
            ``(waveform, tokens)``: the codec output and the concatenated token ids.

        Raises:
            ValueError: if ``max_frames`` or ``self.chunk_frames`` is not positive.
        """
        if max_frames is None:
            max_frames = self.max_frames
        if max_frames <= 0:
            # Previously surfaced as an opaque torch.cat error on an empty list.
            raise ValueError(f"max_frames must be positive, got {max_frames}")
        if self.chunk_frames <= 0:
            # A non-positive chunk size would never reduce `remaining` below.
            raise ValueError(f"chunk_frames must be positive, got {self.chunk_frames}")

        # NOTE(review): every chunk is generated from the same `x`, so chunks
        # of equal size carry identical tokens — confirm this is intended.
        all_tokens = []
        remaining = max_frames
        while remaining > 0:
            chunk = min(remaining, self.chunk_frames)
            all_tokens.append(self.forward(x, max_frames=chunk))
            remaining -= chunk
        tokens = torch.cat(all_tokens, dim=1)

        codec = self.load_codec(x.device if hasattr(x, 'device') else 'cuda')
        with torch.no_grad():
            waveform = codec(tokens)
        return waveform, tokens
|
|