# Source: ARBS / arbitor / decoders.py
# Uploaded by CLIWorks using huggingface_hub (commit d8bc908, verified).
"""Decoder modules — video diffusion, audio codec, speech generation.
These modules convert HIDDEN_DIM relational states into modality-specific outputs:
video (latent diffusion), audio (codec tokens), and speech (token striding + codec).
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .kernel.ternary_scale import TernaryScaleTensor, TScaleType, TernaryRMSNorm
from .kernel.triton_video import video_denoise_step
from .config import HIDDEN_DIM, AUDIO_VOCAB, AUDIO_SR, AUDIO_FRAME_RATE, \
VIDEO_LATENT_CHANNELS, VIDEO_MAX_STEPS, VIDEO_HEIGHT, VIDEO_WIDTH, \
VIDEOHEAD_ACT_MIN_FPS, VIDEOHEAD_ACT_MAX_FPS, VIDEOHEAD_ACT_FRAME_CHUNK, \
TALKERHEAD_ACT_CHUNK_FRAMES
from .components import TernaryEmbeddingTable
class LTIInjection(nn.Module):
    """Element-wise linear time-invariant state update.

    ``forward`` returns ``A * h + B * e + trans_out`` where the decay
    ``A = exp(-exp(clamp(log_dt + log_A, -20, 20)))``. Because ``A`` is the
    exponential of a strictly negative number, it is mathematically inside
    ``(0, 1)``, which keeps the recurrence contractive.

    ``log_A``, ``log_dt`` and ``B`` are registered as parameters (so they
    are saved in the state dict) but are created with gradients disabled.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Registration order (log_A, log_dt, B) is kept so state-dict
        # ordering is unchanged.
        self.log_A = nn.Parameter(torch.zeros(dim), requires_grad=False)
        self.log_dt = nn.Parameter(torch.zeros(1), requires_grad=False)
        self.B = nn.Parameter(torch.full((dim,), 0.1), requires_grad=False)

    def get_A(self):
        """Return the per-channel decay factor ``exp(-exp(log_dt + log_A))``."""
        # Clamp the log-rate so the inner exponential cannot overflow.
        log_rate = torch.clamp(self.log_dt + self.log_A, min=-20, max=20)
        rate = torch.exp(log_rate)
        return torch.exp(-rate)

    def forward(self, h, e, trans_out):
        """Decay the state ``h``, inject ``e`` scaled by ``B``, add ``trans_out``."""
        decayed_state = self.get_A() * h
        injected = self.B * e
        return decayed_state + injected + trans_out
class VideoHead(nn.Module):
    """Iterative latent video decoder conditioned on the pooled relational state.

    For every temporal chunk a flattened latent of size
    ``latent_channels * height * width`` is sampled from Gaussian noise and
    refined for up to ``max_steps`` iterations. ``forward`` returns a tensor
    of shape ``[B, latent_channels, n_latents, height, width]``: each chunk
    contributes exactly ONE slice along dim 2.

    NOTE(review): the previous docstring described ``[B, ch, 4, H', W']``
    outputs (4-frame chunks, per D-102). The code below does not emit a
    4-long temporal axis per chunk; ``frame_chunk`` is only used to decide
    how many chunks to generate. Presumably one latent stands for
    ``frame_chunk`` frames -- confirm against the downstream video decoder.

    The frame gate maps the conditioning vector to an fps value in
    ``[min_fps, max_fps]``, which together with ``duration_seconds`` sets
    the number of chunks.
    """
    def __init__(self, tscale_type=TScaleType.T32, max_steps=VIDEO_MAX_STEPS,
                 latent_channels=VIDEO_LATENT_CHANNELS, height=VIDEO_HEIGHT, width=VIDEO_WIDTH,
                 min_fps=VIDEOHEAD_ACT_MIN_FPS, max_fps=VIDEOHEAD_ACT_MAX_FPS,
                 frame_chunk=VIDEOHEAD_ACT_FRAME_CHUNK):
        super().__init__()
        self.max_steps = max_steps
        self.latent_channels = latent_channels
        self.height = height
        self.width = width
        # Size of one latent once channels and spatial dims are flattened.
        self.latent_dim = latent_channels * height * width
        # Mean halt probability above which the refinement loop may stop early.
        self.halt_threshold = 0.05
        self.min_fps = min_fps
        self.max_fps = max_fps
        self.frame_chunk = frame_chunk
        # Presumably TernaryScaleTensor(in_features, out_features) -- TODO confirm.
        # The usage in forward() is consistent with that reading.
        self.cross_attn_q = TernaryScaleTensor(self.latent_dim, HIDDEN_DIM, tscale_type=tscale_type)
        self.cross_attn_kv = TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM, tscale_type=tscale_type)
        self.diffusion_step = TernaryScaleTensor(HIDDEN_DIM, self.latent_dim, tscale_type=tscale_type)
        self.halt_unit = TernaryScaleTensor(HIDDEN_DIM, 1, tscale_type=tscale_type)
        self.frame_gate = TernaryScaleTensor(HIDDEN_DIM, 1, tscale_type=tscale_type)
        # One embedding per refinement step index; built with `max_steps` entries
        # (assuming the first argument is the table size -- TODO confirm).
        self.noise_embed = TernaryEmbeddingTable(max_steps, HIDDEN_DIM, tscale_type=tscale_type)
        self.lti = LTIInjection(self.latent_dim)
    @torch.no_grad()
    def _compute_fps(self, cond):
        """Return the batch-averaged frame rate as a Python float.

        The gate output is squashed with a sigmoid and linearly mapped to
        ``[min_fps, max_fps]``. Runs without gradient tracking, and
        ``.item()`` converts the result to a host-side scalar.
        """
        frame_prob = torch.sigmoid(self.frame_gate(cond))
        fps = self.min_fps + frame_prob * (self.max_fps - self.min_fps)
        return fps.mean().item()
    def forward(self, relational, max_steps=None, duration_seconds=1.0):
        """Generate video latents for ``duration_seconds`` of content.

        Args:
            relational: rank-3 tensor unpacked as ``(B, T, D)``; its last
                dimension is presumably HIDDEN_DIM -- TODO confirm.
            max_steps: refinement-step budget per chunk; a falsy value
                (``None`` or ``0``) falls back to ``self.max_steps``.
            duration_seconds: clip length used to derive the frame count.

        Returns:
            Tensor of shape ``[B, latent_channels, n_latents, height, width]``.
        """
        B, T, D = relational.shape
        max_steps = max_steps or self.max_steps
        # Pool over the sequence axis: [B, 1, D].
        cond = relational.mean(dim=1, keepdim=True)
        fps = self._compute_fps(cond)
        # Always generate at least one frame.
        n_frames = max(1, int(fps * duration_seconds))
        # ceil(n_frames / frame_chunk) chunks, capped by max_steps.
        # NOTE(review): the cap reuses the per-chunk step budget as a limit on
        # the number of chunks -- confirm this is intended.
        n_latents = min((n_frames + self.frame_chunk - 1) // self.frame_chunk, max_steps)
        all_latents = []
        for chunk_idx in range(n_latents):
            # Fresh Gaussian latent per chunk; it tracks gradients whenever
            # autograd is enabled.
            latent = torch.randn(B, 1, self.latent_dim, device=relational.device,
                                 requires_grad=torch.is_grad_enabled())
            for step in range(max_steps):
                q = self.cross_attn_q(latent)
                # NOTE(review): kv, context and (below) halt depend only on
                # `cond`, so they look loop-invariant -- assuming these modules
                # are deterministic and stateless. Candidate for hoisting.
                kv = self.cross_attn_kv(cond.expand(-1, T, -1))
                context = kv.mean(dim=1, keepdim=True)
                # NOTE(review): if the caller passes max_steps larger than the
                # value used to build noise_embed, `step` may exceed the table
                # -- confirm TernaryEmbeddingTable's behaviour.
                step_embed = self.noise_embed(torch.tensor(step, device=relational.device))
                step_embed = step_embed.expand(B, 1, -1)
                # Query, context and step embedding are combined by plain
                # addition; no attention weights are computed in this block.
                step_input = q + context + step_embed
                pred_noise = self.diffusion_step(step_input)
                # Geometric schedule: 1.0, 0.9, 0.81, ...
                alpha = 0.9 ** step
                trans_out = video_denoise_step(latent, pred_noise, alpha)
                # Zero-pad the context up to latent_dim so it can be injected
                # into the latent; this requires latent_dim >= HIDDEN_DIM and a
                # context whose last dimension equals HIDDEN_DIM.
                h = torch.zeros(B, 1, self.latent_dim, device=context.device)
                h[:, :, :HIDDEN_DIM] = context
                # The current latent is the decayed state; the padded context
                # is the injected input.
                latent = self.lti(latent, h, trans_out)
                halt = torch.sigmoid(self.halt_unit(context))
                # `step > 1` means the early exit can only fire from the third
                # iteration (step index 2) onwards.
                if halt.mean() > self.halt_threshold and step > 1:
                    break
            # Unflatten and add a length-1 temporal axis at dim 2.
            all_latents.append(latent.view(B, self.latent_channels, 1, self.height, self.width))
        # Stack the per-chunk latents along the temporal axis.
        return torch.cat(all_latents, dim=2)
class MRFBlock(nn.Module):
    """Multi-receptive-field fusion block (after HiFi-GAN).

    Runs the input through parallel ``LeakyReLU(0.1) -> Conv1d`` branches,
    one per kernel size, and returns the arithmetic mean of the branch
    outputs. Each convolution uses ``padding = k // 2`` and no dilation, so
    odd kernel sizes preserve the sequence length.
    """

    def __init__(self, channels, kernel_sizes=(3, 5, 7)):
        super().__init__()
        branches = []
        for width in kernel_sizes:
            activation = nn.LeakyReLU(0.1)
            conv = nn.Conv1d(channels, channels, width, padding=width // 2, dilation=1)
            branches.append(nn.Sequential(activation, conv))
        self.convs = nn.ModuleList(branches)

    def forward(self, x):
        """Average the outputs of all branches applied to ``x``."""
        # Start from integer 0 exactly like the built-in sum() would.
        total = 0
        for branch in self.convs:
            total = total + branch(x)
        return total / len(self.convs)
class TinyNeuralCodec(nn.Module):
    """Lightweight neural audio decoder built from transposed convolutions.

    Token ids are embedded, upsampled by a stack of
    ``ConvTranspose1d -> MRFBlock`` stages and projected to a single
    ``tanh``-bounded waveform channel. Every stage multiplies the sequence
    length by exactly its ratio, so with the default ratios ``(5, 4, 4, 4)``
    one token expands to 320 samples (e.g. 50 Hz tokens -> 16 kHz audio) and
    ``forward`` maps ``[B, T]`` tokens to ``[B, 1, T * 320]`` audio.

    The codec is meant to be used as a frozen sidecar, but this class does
    not freeze its own parameters; callers are responsible for that.
    """

    def __init__(self, vocab=AUDIO_VOCAB, embed_dim=512, upsample_ratios=(5, 4, 4, 4)):
        """Build the embedding table and the upsampling stack.

        Args:
            vocab: number of token ids in the embedding table.
            embed_dim: embedding width; stage ``i`` outputs
                ``max(1, embed_dim // 2 ** (i + 1))`` channels.
            upsample_ratios: per-stage temporal upsampling factors.
        """
        super().__init__()
        self.embed = nn.Embedding(vocab, embed_dim)
        in_ch = embed_dim
        self.blocks = nn.ModuleList()
        for i, ratio in enumerate(upsample_ratios):
            # Halve the channel count at every stage, never going below 1.
            out_ch = max(1, embed_dim // (2 ** (i + 1)))
            k = ratio * 2
            # ceil(ratio / 2): the same value for odd and even ratios that the
            # previous two-branch expression produced.
            pad = (ratio + 1) // 2
            # Output length is (L - 1) * ratio - 2 * pad + k + op. With
            # k = 2 * ratio this is L * ratio once `op` compensates for the
            # one-sample shortfall of odd ratios (op is 1 for odd, 0 for even).
            op = max(0, ratio + 2 * pad - k)
            block = nn.Sequential(
                nn.ConvTranspose1d(in_ch, out_ch, k, stride=ratio, padding=pad, output_padding=op),
                MRFBlock(out_ch),
            )
            self.blocks.append(block)
            in_ch = out_ch
        self.to_audio = nn.Conv1d(in_ch, 1, kernel_size=7, padding=3)

    def forward(self, tokens):
        """Decode integer tokens ``[B, T]`` into a waveform ``[B, 1, T * prod(ratios)]``."""
        x = self.embed(tokens)
        # Conv layers expect channels first: [B, T, C] -> [B, C, T].
        x = x.permute(0, 2, 1)
        for block in self.blocks:
            x = block(x)
        x = self.to_audio(x)
        # Bound samples to (-1, 1).
        return torch.tanh(x)

    def encode_audio(self, audio, frame_rate=AUDIO_FRAME_RATE, sr=AUDIO_SR):
        """Quantize audio into per-frame byte tokens and decode them back.

        The audio is split into non-overlapping frames of
        ``sr // frame_rate`` samples (zero-padded on the right to a whole
        number of frames), averaged over channels and over the frame, and the
        mean amplitude in ``[-1, 1]`` is mapped linearly onto ``0..255``.

        Args:
            audio: rank-3 tensor unpacked as ``(batch, channels, samples)``.
            frame_rate: token frames per second.
            sr: audio sample rate in Hz.

        Returns:
            ``(tokens, recon)`` where ``tokens`` is ``[B, n_frames]`` (long)
            and ``recon`` is the decoded waveform. When padding was added,
            ``recon`` is cropped back to the original sample count.
        """
        batch, _, n_samples = audio.shape
        frame_len = sr // frame_rate
        pad = (frame_len - n_samples % frame_len) % frame_len
        if pad > 0:
            audio = F.pad(audio, (0, pad))
        # [B, C, n_frames, frame_len] -> channel average -> [B, n_frames, frame_len]
        frames = audio.unfold(2, frame_len, frame_len).mean(dim=1)
        n_frames = frames.shape[1]
        frame_level = frames.mean(dim=-1)
        # Byte quantization, additionally bounded by the embedding table so a
        # codec built with vocab < 256 cannot produce out-of-range ids.
        max_token = min(255, self.embed.num_embeddings - 1)
        tokens = torch.clamp(((frame_level + 1.0) * 127.5).long(), 0, max_token)
        tokens = tokens.reshape(batch, n_frames)
        recon = self(tokens)
        if pad > 0:
            recon = recon[:, :, :n_samples]
        return tokens, recon
class TalkerHead(nn.Module):
    """Audio token head with temporal striding and chunked generation.

    The input is normalized, passed through a HIDDEN_DIM -> HIDDEN_DIM layer,
    normalized again, activated with SiLU and projected to AUDIO_VOCAB
    logits. Per-position logits are then repeated along time so the output
    covers the requested number of frames. ``generate_audio`` produces tokens
    in chunks of ``chunk_frames`` and decodes them to a waveform with a
    lazily created ``TinyNeuralCodec``.
    """

    def __init__(self, tscale_type=TScaleType.T32,
                 chunk_frames=TALKERHEAD_ACT_CHUNK_FRAMES):
        super().__init__()
        self.norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)
        self.hidden = TernaryScaleTensor(HIDDEN_DIM, HIDDEN_DIM, tscale_type=tscale_type)
        self.hidden_norm = TernaryRMSNorm(HIDDEN_DIM, tscale_type=tscale_type)
        self.head = TernaryScaleTensor(HIDDEN_DIM, AUDIO_VOCAB, tscale_type=tscale_type)
        # Created on first use by load_codec(). Once assigned, nn.Module
        # registers it as a regular submodule.
        self.codec = None
        # Default frame budget and chunk size both start at chunk_frames.
        self.max_frames = chunk_frames
        self.chunk_frames = chunk_frames

    def load_codec(self, device='cuda'):
        """Return the codec, creating it on ``device`` in eval mode if needed.

        An already-created codec is returned as is and is not moved.
        """
        if self.codec is None:
            self.codec = TinyNeuralCodec().to(device)
            self.codec.eval()
        return self.codec

    def token_logits(self, x, max_frames=None):
        """Return logits of shape ``[B, max_frames, AUDIO_VOCAB]``.

        Args:
            x: conditioning tensor; dim 1 is treated as the time axis.
            max_frames: number of output frames; a falsy value falls back
                to ``self.max_frames``.
        """
        max_frames = max_frames or self.max_frames
        cond = self.norm(x)
        cond = F.silu(self.hidden_norm(self.hidden(cond)))
        # How many output frames each input position should cover (>= 1).
        stride = max(1, max_frames // max(1, cond.shape[1]))
        logits = self.head(cond)
        logits = logits.repeat_interleave(stride, dim=1)
        n_out = logits.shape[1]
        if n_out > max_frames:
            logits = logits[:, :max_frames, :]
        elif n_out < max_frames:
            # Zero logits for the tail; argmax over them yields token 0.
            pad = logits.new_zeros(logits.shape[0], max_frames - n_out, logits.shape[2])
            logits = torch.cat([logits, pad], dim=1)
        return logits

    def forward(self, x, max_frames=None):
        """Return greedy (argmax) token ids of shape ``[B, max_frames]``."""
        return self.token_logits(x, max_frames=max_frames).argmax(dim=-1)

    def generate_audio(self, x, max_frames=None, return_all=True):
        """Generate tokens chunk by chunk and decode them to audio.

        Args:
            x: conditioning tensor passed unchanged to every chunk.
            max_frames: total number of frames; ``None`` means
                ``self.max_frames``. Must be positive.
            return_all: currently unused; kept for interface compatibility.

        Returns:
            ``(waveform, tokens)`` where ``tokens`` is ``[B, max_frames]``
            and ``waveform`` is the codec output computed without gradients.
        """
        if max_frames is None:
            max_frames = self.max_frames
        all_tokens = []
        remaining = max_frames
        while remaining > 0:
            chunk = min(remaining, self.chunk_frames)
            all_tokens.append(self.forward(x, max_frames=chunk))
            remaining -= chunk
        tokens = torch.cat(all_tokens, dim=1)
        device = getattr(x, 'device', 'cuda')
        # A codec created earlier may live on a different device than the
        # tokens; move it so decoding does not fail with a device mismatch.
        # This is a no-op when the codec is already on `device`.
        codec = self.load_codec(device).to(device)
        with torch.no_grad():
            waveform = codec(tokens)
        return waveform, tokens