"""Audio training encoder — VQ encoder for TalkerHead target preparation.

Training-only component (~5M float params). Maps 16 kHz audio to discrete byte
tokens at 50 Hz (288-class vocabulary by default).
TinyNeuralCodec (the decoder) lives in arbitor.components and is shared with TalkerHead.
"""
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from ..components import TernaryEmbeddingTable |
| from ..kernel.ternary_scale import TernaryScaleTensor, TScaleType |
|
|
|
|
class TernaryConv1d(nn.Module):
    """1-D convolution computed as sliding-window extraction plus a ternary projection.

    Windows of ``kernel_size`` samples (hop ``stride``) are cut from every input
    channel. Each window is flattened channel-major into a vector of length
    ``in_channels * kernel_size``, which is mapped to ``out_channels`` values by a
    ``TernaryScaleTensor``.

    Args:
        in_channels: Channels in the input (dim 1).
        out_channels: Channels in the output.
        kernel_size: Window length along the time axis.
        stride: Hop between consecutive windows.
        padding: Zeros added to each end of the time axis; 0 skips padding.
        tscale_type: Ternary scale type passed through to the projection.
        bias: Passed through to the projection.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 tscale_type=TScaleType.T32, bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        window_len = in_channels * kernel_size
        self.proj = TernaryScaleTensor(window_len, out_channels,
                                       tscale_type=tscale_type, bias=bias)

    def forward(self, x):
        """Convolve ``x``.

        Args:
            x: Expected ``[B, in_channels, T]``; windows are taken along dim 2.

        Returns:
            ``[B, out_channels, n_windows]`` where
            ``n_windows = (T + 2 * padding - kernel_size) // stride + 1``.
        """
        signal = x if not self.padding else F.pad(x, (self.padding, self.padding))
        # unfold gives [B, C, W, K]; move W ahead of C so each window flattens as (channel, tap).
        patches = signal.unfold(2, self.kernel_size, self.stride).permute(0, 2, 1, 3)
        batch_size = signal.size(0)
        flat = patches.reshape(batch_size, -1, self.in_channels * self.kernel_size)
        # assumes the projection acts on the last dim, giving [B, W, out_channels]
        projected = self.proj(flat)
        return projected.permute(0, 2, 1)
|
|
|
|
class AudioVQEncoder(nn.Module):
    """Encodes audio to discrete byte tokens at 50 Hz for TalkerHead training.

    The waveform passes through one strided ternary conv block per entry of
    ``downsample_ratios``. The total downsampling is ``prod(downsample_ratios)``,
    which is 320 with the defaults (16 kHz -> 50 Hz). Each frame is then
    projected to ``codebook_dim`` and snapped to the nearest of ``vocab``
    codebook vectors (Euclidean distance). The quantized vector is mapped to
    ``vocab`` logits.

    Input:  ``audio`` — ``[B, 1, T]`` waveform at 16 kHz.
    Output: ``(logits, indices)``:
            ``logits`` is ``[B, T/320, vocab]``, and ``indices`` is the
            ``[B, T/320]`` int64 tensor of selected codebook rows. The frame
            count is exactly ``T // 320`` when ``T`` is a multiple of 320.

    Args:
        vocab: Number of codebook entries and output classes.
        codebook_dim: Width of each codebook vector.
        downsample_ratios: Stride of each downsampling block, in order.
        tscale_type: Ternary scale type forwarded to every ternary layer.
    """

    def __init__(self, vocab=288, codebook_dim=64, downsample_ratios=(4, 4, 4, 5),
                 tscale_type=TScaleType.T32):
        super().__init__()
        # Stored so the codebook size does not have to be inferred from out_proj.
        self.vocab = vocab
        in_ch = 1
        self.down_blocks = nn.ModuleList()
        for i, ratio in enumerate(downsample_ratios):
            out_ch = min(128, 32 * (2 ** i))
            # kernel 2r / stride r needs ~r total padding to emit L/r frames.
            # ceil(r/2) per side achieves that for odd r too. The previous
            # r // 2 dropped a frame for the default final ratio of 5, giving
            # T/320 - 1 frames instead of T/320.
            pad = (ratio + 1) // 2
            block = nn.Sequential(
                TernaryConv1d(in_ch, out_ch, kernel_size=ratio * 2, stride=ratio,
                              padding=pad, tscale_type=tscale_type),
                nn.LeakyReLU(0.1),
                TernaryConv1d(out_ch, out_ch, kernel_size=3, padding=1,
                              tscale_type=tscale_type),
                nn.LeakyReLU(0.1),
            )
            self.down_blocks.append(block)
            in_ch = out_ch
        # in_ch (not out_ch) so an empty downsample_ratios still constructs.
        self.proj = TernaryScaleTensor(in_ch, codebook_dim, tscale_type=tscale_type, bias=True)
        self.codebook = TernaryEmbeddingTable(vocab, codebook_dim, tscale_type=tscale_type)
        self.out_proj = TernaryScaleTensor(codebook_dim, vocab, tscale_type=tscale_type, bias=True)

    def forward(self, audio):
        """Quantize ``audio`` and return ``(logits, indices)`` (see class docstring)."""
        x = audio
        for block in self.down_blocks:
            x = block(x)
        # [B, C, frames] -> [B, frames, C] so the projection acts on channels.
        x = self.proj(x.permute(0, 2, 1))
        code_ids = torch.arange(self.vocab, device=x.device)
        emb = self.codebook(code_ids).to(device=x.device, dtype=x.dtype)
        # Distances in float32 for a stable argmin. emb is broadcast over the batch.
        dist = torch.cdist(x.float(), emb.unsqueeze(0).float())
        indices = dist.argmin(dim=-1)
        quantized = F.embedding(indices, emb)
        # Straight-through estimator: forward uses the codebook vector, and the
        # backward gradient flows to x unchanged.
        # NOTE(review): because the difference is detached and no codebook or
        # commitment loss is computed here, the codebook receives no gradient
        # from this method. Confirm it is trained (or intentionally frozen)
        # elsewhere.
        quantized = x + (quantized - x).detach()
        logits = self.out_proj(quantized)
        return logits, indices

    def encode(self, audio):
        """Return only the codebook indices for ``audio``, computed without autograd."""
        with torch.no_grad():
            _, indices = self.forward(audio)
        return indices
|
|