"""Audio training encoder — VQ encoder for TalkerHead target preparation. Training-only component (~5M float params). Maps audio at 50 Hz to 289-class byte tokens. TinyNeuralCodec (the decoder) is in arbitor.components — shared with TalkerHead. """ import torch import torch.nn as nn import torch.nn.functional as F from ..components import TernaryEmbeddingTable from ..kernel.ternary_scale import TernaryScaleTensor, TScaleType class TernaryConv1d(nn.Module): """Conv1d implemented as unfold + ternary linear projection.""" def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, tscale_type=TScaleType.T32, bias=True): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size self.stride = stride self.padding = padding self.proj = TernaryScaleTensor( in_channels * kernel_size, out_channels, tscale_type=tscale_type, bias=bias, ) def forward(self, x): if self.padding: x = F.pad(x, (self.padding, self.padding)) windows = x.unfold(2, self.kernel_size, self.stride) windows = windows.permute(0, 2, 1, 3).reshape(x.size(0), -1, self.in_channels * self.kernel_size) return self.proj(windows).permute(0, 2, 1) class AudioVQEncoder(nn.Module): """Encodes audio to discrete byte tokens at 50 Hz for TalkerHead training. Input: [B, 1, T] audio waveform at 16 kHz Output: [B, T/320, 288] logits over byte vocab (50 Hz frame rate) """ def __init__(self, vocab=288, codebook_dim=64, downsample_ratios=(4, 4, 4, 5), tscale_type=TScaleType.T32): super().__init__() in_ch = 1 self.down_blocks = nn.ModuleList() for i, ratio in enumerate(downsample_ratios): out_ch = min(128, 32 * (2 ** i)) block = nn.Sequential( TernaryConv1d(in_ch, out_ch, kernel_size=ratio * 2, stride=ratio, padding=ratio // 2, tscale_type=tscale_type), nn.LeakyReLU(0.1), TernaryConv1d(out_ch, out_ch, kernel_size=3, padding=1, tscale_type=tscale_type), nn.LeakyReLU(0.1), ) self.down_blocks.append(block) in_ch = out_ch self.proj = TernaryScaleTensor(out_ch, codebook_dim, tscale_type=tscale_type, bias=True) self.codebook = TernaryEmbeddingTable(vocab, codebook_dim, tscale_type=tscale_type) self.out_proj = TernaryScaleTensor(codebook_dim, vocab, tscale_type=tscale_type, bias=True) def forward(self, audio): x = audio for block in self.down_blocks: x = block(x) x = x.permute(0, 2, 1) x = self.proj(x) emb_idx = torch.arange(self.out_proj.out_dim, device=x.device) emb = self.codebook(emb_idx).to(device=x.device, dtype=x.dtype) dist = torch.cdist(x.float(), emb.unsqueeze(0).float()) indices = dist.argmin(dim=-1) quantized = F.embedding(indices, emb) quantized = x + (quantized - x).detach() logits = self.out_proj(quantized) return logits, indices def encode(self, audio): with torch.no_grad(): _, indices = self.forward(audio) return indices