"""Decoder modules — video diffusion, audio codec, speech generation.

These modules convert HIDDEN_DIM relational states into modality-specific
outputs: video (latent diffusion), audio (codec tokens), and speech
(token striding + codec).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm
from .kernel.triton_video import video_denoise_step
from .config import (
    HIDDEN_DIM, AUDIO_VOCAB, AUDIO_SR, AUDIO_FRAME_RATE,
    VIDEO_LATENT_CHANNELS, VIDEO_MAX_STEPS, VIDEO_HEIGHT, VIDEO_WIDTH,
    VIDEOHEAD_ACT_MIN_FPS, VIDEOHEAD_ACT_MAX_FPS, VIDEOHEAD_ACT_FRAME_CHUNK,
    TALKERHEAD_ACT_CHUNK_FRAMES,
)
from .components import TernaryEmbeddingTable


class LTIInjection(nn.Module):
    """LTI state injection: ``h_next = A * h + B * e + trans_out``.

    ``A = exp(-exp(log_dt + log_A))`` with the inner exponent clamped to
    [-20, 20] (a ZOH-style discretization), so each element of A lies in
    (0, 1) in exact arithmetic. NOTE(review): in float32, exp(-exp(-20))
    rounds to exactly 1.0, so the strict ``A < 1`` bound only holds when the
    exponent stays well above the clamp floor.

    All parameters are created with ``requires_grad=False`` (frozen).
    """

    def __init__(self, dim: int):
        super().__init__()
        self.log_A = nn.Parameter(torch.zeros(dim))
        self.log_dt = nn.Parameter(torch.zeros(1))
        self.B = nn.Parameter(torch.ones(dim) * 0.1)
        for p in (self.log_A, self.log_dt, self.B):
            p.requires_grad_(False)

    def get_A(self):
        """Return the per-dimension decay factors A, shape ``[dim]``."""
        return torch.exp(-torch.exp((self.log_dt + self.log_A).clamp(-20, 20)))

    def forward(self, h, e, trans_out):
        """Return ``A * h + B * e + trans_out`` (A and B broadcast over the last dim)."""
        return self.get_A() * h + self.B * e + trans_out


class VideoHead(nn.Module):
    """Scaled latent diffusion with cross-attention conditioning, a frame gate
    and chunked temporal latents (per D-102).

    ``forward`` returns latents of shape
    ``[B, latent_channels, n_latents, height, width]``: one latent slice per
    ``frame_chunk`` output frames, where the frame count is derived from an
    adaptive fps in ``[min_fps, max_fps]`` chosen by the frame gate.
    """

    def __init__(self, tscale_type=TScaleType.T32, max_steps=VIDEO_MAX_STEPS,
                 latent_channels=VIDEO_LATENT_CHANNELS, height=VIDEO_HEIGHT,
                 width=VIDEO_WIDTH, min_fps=VIDEOHEAD_ACT_MIN_FPS,
                 max_fps=VIDEOHEAD_ACT_MAX_FPS,
                 frame_chunk=VIDEOHEAD_ACT_FRAME_CHUNK):
        super().__init__()
        self.max_steps = max_steps
        self.latent_channels = latent_channels
        self.height = height
        self.width = width
        self.latent_dim = latent_channels * height * width
        self.halt_threshold = 0.05
        self.min_fps = min_fps
        self.max_fps = max_fps
        self.frame_chunk = frame_chunk
        self.cross_attn_q = TernaryScaleTensor(self.latent_dim, HIDDEN_DIM, tscale_type=tscale_type)
        self.cross_attn_kv = TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM, tscale_type=tscale_type)
        self.diffusion_step = TernaryScaleTensor(HIDDEN_DIM, self.latent_dim, tscale_type=tscale_type)
        self.halt_unit = TernaryScaleTensor(HIDDEN_DIM, 1, tscale_type=tscale_type)
        self.frame_gate = TernaryScaleTensor(HIDDEN_DIM, 1, tscale_type=tscale_type)
        # One embedding row per denoising step index, sized by the constructor's max_steps.
        self.noise_embed = TernaryEmbeddingTable(max_steps, HIDDEN_DIM, tscale_type=tscale_type)
        self.lti = LTIInjection(self.latent_dim)

    @torch.no_grad()
    def _compute_fps(self, cond):
        """Map conditioning to a scalar fps in [min_fps, max_fps], averaged over the batch."""
        frame_prob = torch.sigmoid(self.frame_gate(cond))
        fps = self.min_fps + frame_prob * (self.max_fps - self.min_fps)
        return fps.mean().item()

    def forward(self, relational, max_steps=None, duration_seconds=1.0):
        """Generate video latents conditioned on ``relational``.

        Args:
            relational: tensor of shape ``[B, T, D]`` (D presumably HIDDEN_DIM).
            max_steps: maximum denoising steps per latent chunk; also caps the
                number of latent chunks. Falsy values fall back to ``self.max_steps``.
            duration_seconds: requested clip length, used with the gated fps to
                derive the number of frames.

        Returns:
            Tensor of shape ``[B, latent_channels, n_latents, height, width]``.
        """
        B, T, _ = relational.shape
        max_steps = max_steps or self.max_steps
        device = relational.device

        cond = relational.mean(dim=1, keepdim=True)  # [B, 1, D]
        fps = self._compute_fps(cond)
        n_frames = max(1, int(fps * duration_seconds))
        n_latents = min((n_frames + self.frame_chunk - 1) // self.frame_chunk, max_steps)

        # Everything below depends only on `cond`, never on the latent or step,
        # so compute it once instead of on every step of every chunk.
        kv = self.cross_attn_kv(cond.expand(-1, T, -1))
        context = kv.mean(dim=1, keepdim=True)  # [B, 1, HIDDEN_DIM]
        h = torch.zeros(B, 1, self.latent_dim, device=context.device)
        # Bounded copy so a latent_dim smaller than HIDDEN_DIM does not fail.
        n_ctx = min(HIDDEN_DIM, self.latent_dim)
        h[:, :, :n_ctx] = context[:, :, :n_ctx]
        halt = torch.sigmoid(self.halt_unit(context))
        # NOTE(review): because `context` is fixed, this decision is identical at
        # every step — either each chunk stops after step index 2 or never halts early.
        should_halt = bool(halt.mean() > self.halt_threshold)

        # The noise-embedding table has self.max_steps rows; clamp so a larger
        # caller-supplied max_steps reuses the last row instead of indexing past it.
        last_embed_idx = self.max_steps - 1

        all_latents = []
        for _ in range(n_latents):
            latent = torch.randn(B, 1, self.latent_dim, device=device,
                                 requires_grad=torch.is_grad_enabled())
            for step in range(max_steps):
                q = self.cross_attn_q(latent)
                step_idx = torch.tensor(min(step, last_embed_idx), device=device)
                step_embed = self.noise_embed(step_idx).expand(B, 1, -1)
                pred_noise = self.diffusion_step(q + context + step_embed)
                alpha = 0.9 ** step
                trans_out = video_denoise_step(latent, pred_noise, alpha)
                latent = self.lti(latent, h, trans_out)
                if should_halt and step > 1:
                    break
            all_latents.append(
                latent.reshape(B, self.latent_channels, 1, self.height, self.width))
        return torch.cat(all_latents, dim=2)


class MRFBlock(nn.Module):
    """Multi-Receptive Field Fusion block inspired by HiFi-GAN.

    Averages parallel ``LeakyReLU -> Conv1d(kernel k)`` branches, one per
    kernel size. Odd kernel sizes (the default) preserve sequence length.
    NOTE(review): unlike HiFi-GAN's residual blocks, the branches here have
    no skip connection.
    """

    def __init__(self, channels, kernel_sizes=(3, 5, 7)):
        super().__init__()
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.LeakyReLU(0.1),
                nn.Conv1d(channels, channels, k, padding=k // 2, dilation=1),
            )
            for k in kernel_sizes
        ])

    def forward(self, x):
        return sum(conv(x) for conv in self.convs) / len(self.convs)


class TinyNeuralCodec(nn.Module):
    """Lightweight neural audio decoder (intended as a frozen float32 sidecar).

    Maps token sequences ``[B, T]`` to tanh-bounded waveforms
    ``[B, 1, T * prod(upsample_ratios)]`` via transposed-conv upsampling.
    With the default ratios (5, 4, 4, 4) each token expands to 320 samples,
    e.g. 50 Hz tokens -> 16 kHz audio.
    """

    def __init__(self, vocab=AUDIO_VOCAB, embed_dim=512, upsample_ratios=(5, 4, 4, 4)):
        super().__init__()
        self.embed = nn.Embedding(vocab, embed_dim)
        in_ch = embed_dim
        self.blocks = nn.ModuleList()
        for i, ratio in enumerate(upsample_ratios):
            out_ch = max(1, embed_dim // (2 ** (i + 1)))
            k = ratio * 2
            # Choose padding / output_padding so each stage maps length L to exactly
            # L * ratio: (L - 1) * ratio - 2 * pad + k + op == L * ratio.
            pad = (ratio + 1) // 2 if ratio % 2 else ratio // 2
            op = max(0, ratio + 2 * pad - k)
            block = nn.Sequential(
                nn.ConvTranspose1d(in_ch, out_ch, k, stride=ratio,
                                   padding=pad, output_padding=op),
                MRFBlock(out_ch),
            )
            self.blocks.append(block)
            in_ch = out_ch
        self.to_audio = nn.Conv1d(in_ch, 1, kernel_size=7, padding=3)

    def forward(self, tokens):
        """Decode integer tokens ``[B, T]`` to a waveform ``[B, 1, samples]`` in [-1, 1]."""
        x = self.embed(tokens)
        x = x.permute(0, 2, 1)  # [B, embed_dim, T] for Conv1d
        for block in self.blocks:
            x = block(x)
        x = self.to_audio(x)
        return torch.tanh(x)

    def encode_audio(self, audio, frame_rate=AUDIO_FRAME_RATE, sr=AUDIO_SR):
        """Quantize a waveform into one token per frame and decode it back.

        The audio is right-padded to a whole number of ``sr // frame_rate``-sample
        frames and channel-averaged. Each frame is reduced to its mean amplitude,
        mapped linearly from [-1, 1] to [0, 255], and clamped to the embedding
        vocabulary.

        Args:
            audio: ``[B, C, T]`` waveform (presumably in [-1, 1]).
            frame_rate: token rate in Hz.
            sr: audio sample rate in Hz.

        Returns:
            ``(tokens, recon)`` — ``tokens`` is ``[B, n_frames]``; ``recon`` is the
            decoder output trimmed to at most ``T`` samples.
        """
        _, _, T = audio.shape
        frame_len = sr // frame_rate
        pad = (frame_len - T % frame_len) % frame_len
        if pad > 0:
            audio = F.pad(audio, (0, pad))
        frames = audio.unfold(2, frame_len, frame_len).mean(dim=1)  # [B, n_frames, frame_len]
        frame_level = frames.mean(dim=-1)  # mean amplitude per frame, [B, n_frames]
        # Never emit an index the embedding table cannot hold.
        max_token = min(255, self.embed.num_embeddings - 1)
        tokens = torch.clamp(((frame_level + 1.0) * 127.5).long(), 0, max_token)
        recon = self(tokens)
        # Drop the padded tail; a no-op when the decoder output already has T samples.
        recon = recon[:, :, :T]
        return tokens, recon


class TalkerHead(nn.Module):
    """Audio generation head with temporal stride.

    A 2-layer MLP (HIDDEN_DIM -> HIDDEN_DIM -> AUDIO_VOCAB) produces per-position
    token logits, which are stretched in time (repeat-interleaved) to the requested
    number of frames. ``chunk_frames`` sets the default frame count. The predicted
    tokens are decoded to a waveform by a lazily created TinyNeuralCodec.
    """

    def __init__(self, tscale_type=TScaleType.T32, chunk_frames=TALKERHEAD_ACT_CHUNK_FRAMES):
        super().__init__()
        self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)
        self.hidden = TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM, tscale_type=tscale_type)
        self.hidden_norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)
        self.head = TernaryScaleTensor(HIDDEN_DIM, AUDIO_VOCAB, tscale_type=tscale_type)
        self.codec = None  # created on first load_codec() call
        self.max_frames = chunk_frames
        self.chunk_frames = chunk_frames

    def load_codec(self, device='cuda'):
        """Create the codec on first use (in eval mode) and return it.

        NOTE: the device is only applied on creation; later calls return the
        existing codec wherever it currently lives.
        """
        if self.codec is None:
            self.codec = TinyNeuralCodec().to(device)
            self.codec.eval()
        return self.codec

    def token_logits(self, x, max_frames=None):
        """Return logits ``[B, max_frames, AUDIO_VOCAB]`` for input ``x`` ``[B, T, HIDDEN_DIM]``.

        Each input position is repeated ``max(1, max_frames // T)`` times; the
        result is truncated, or right-padded with zero logits, to ``max_frames``.
        A falsy ``max_frames`` falls back to ``self.max_frames``.
        """
        max_frames = max_frames or self.max_frames
        cond = self.norm(x)
        cond = F.silu(self.hidden_norm(self.hidden(cond)))
        stride = max(1, max_frames // max(1, cond.shape[1]))
        logits = self.head(cond)
        logits = logits.repeat_interleave(stride, dim=1)
        if logits.shape[1] > max_frames:
            logits = logits[:, :max_frames, :]
        elif logits.shape[1] < max_frames:
            pad = logits.new_zeros(logits.shape[0], max_frames - logits.shape[1], logits.shape[2])
            logits = torch.cat([logits, pad], dim=1)
        return logits

    def forward(self, x, max_frames=None):
        """Return greedy (argmax) token ids ``[B, max_frames]``."""
        return self.token_logits(x, max_frames=max_frames).argmax(dim=-1)

    def generate_audio(self, x, max_frames=None, return_all=True):
        """Predict ``max_frames`` tokens for ``x`` and decode them to audio.

        The whole duration is produced by one strided pass, so the input
        sequence is spread across all ``max_frames`` frames. Previously each
        chunk re-stretched the full input, which repeated the same audio once
        per chunk.

        Args:
            x: conditioning tensor ``[B, T, HIDDEN_DIM]``.
            max_frames: number of token frames to generate (default ``self.max_frames``).
            return_all: accepted for API compatibility; unused.

        Returns:
            ``(waveform, tokens)`` from the codec and the predicted token ids.

        Raises:
            ValueError: if ``max_frames`` is less than 1.
        """
        if max_frames is None:
            max_frames = self.max_frames
        if max_frames < 1:
            raise ValueError(f"max_frames must be >= 1, got {max_frames}")
        tokens = self.forward(x, max_frames=max_frames)
        # Ensure the codec follows the input even if it was loaded on another device.
        codec = self.load_codec(x.device).to(x.device)
        with torch.no_grad():
            waveform = codec(tokens)
        return waveform, tokens