ARBS / arbitor /encoders /audio.py
CLIWorks's picture
Upload folder using huggingface_hub
d8bc908 verified
"""Audio training encoder — VQ encoder for TalkerHead target preparation.
Training-only component (~5M float params). Maps 16 kHz audio to 50 Hz frames of 288-class byte tokens.
TinyNeuralCodec (the decoder) is in arbitor.components — shared with TalkerHead.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..components import TernaryEmbeddingTable
from ..kernel.ternary_scale import TernaryScaleTensor, TScaleType
class TernaryConv1d(nn.Module):
    """1-D convolution built from sliding-window extraction plus a ternary projection.

    Each window of ``kernel_size`` samples across all input channels is
    flattened into one feature vector and passed through a single
    ``TernaryScaleTensor`` that maps ``in_channels * kernel_size`` features to
    ``out_channels``.

    Args:
        in_channels: Number of channels in the input signal.
        out_channels: Number of channels produced per output position.
        kernel_size: Width of each sliding window.
        stride: Step between consecutive windows.
        padding: Number of zeros added to BOTH ends of the time axis.
        tscale_type: Scale type forwarded to ``TernaryScaleTensor``.
        bias: Forwarded to ``TernaryScaleTensor``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 tscale_type=TScaleType.T32, bias=True):
        super().__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size, self.stride, self.padding = kernel_size, stride, padding
        # One flattened window (all channels x all taps) is the projection input.
        window_features = in_channels * kernel_size
        self.proj = TernaryScaleTensor(
            window_features,
            out_channels,
            tscale_type=tscale_type,
            bias=bias,
        )

    def forward(self, x):
        """Apply the convolution to ``x`` of shape [B, in_channels, L].

        Returns a tensor of shape [B, out_channels, L_out], where L_out is the
        number of windows that fit in the (padded) time axis.
        """
        pad = self.padding
        padded = F.pad(x, (pad, pad)) if pad else x
        batch = padded.size(0)
        # unfold appends the window axis: [B, C, L_out, K].
        patches = padded.unfold(2, self.kernel_size, self.stride)
        # Bring the window index forward so each row holds one (C, K) patch.
        patches = patches.permute(0, 2, 1, 3)
        flat = patches.reshape(batch, -1, self.in_channels * self.kernel_size)
        projected = self.proj(flat)
        # Back to channels-first layout.
        return projected.permute(0, 2, 1)
class AudioVQEncoder(nn.Module):
    """Encode a waveform into discrete byte tokens for TalkerHead training.

    Input:  [B, 1, T] audio waveform at 16 kHz.
    Output: ``(logits, indices)`` where ``logits`` is [B, T', vocab] and
            ``indices`` is [B, T'] (nearest codebook entry per frame).

    With the default ``downsample_ratios`` (product 320) and T a multiple of
    320, T' == T / 320, i.e. a 50 Hz frame rate.

    Args:
        vocab: Number of codebook entries and of output logit classes.
        codebook_dim: Dimensionality of each codebook vector.
        downsample_ratios: Stride of each downsampling stage; the total
            downsampling factor is their product.
        tscale_type: Scale type forwarded to every ternary layer.
    """

    def __init__(self, vocab=288, codebook_dim=64, downsample_ratios=(4, 4, 4, 5),
                 tscale_type=TScaleType.T32):
        super().__init__()
        # Kept on the module so forward() does not have to read the codebook
        # size back out of another layer's attributes.
        self.vocab = vocab
        self.codebook_dim = codebook_dim

        in_ch = 1
        # Defined up front so an empty ``downsample_ratios`` does not leave
        # ``out_ch`` unbound when the projection below is built.
        out_ch = in_ch
        self.down_blocks = nn.ModuleList()
        for i, ratio in enumerate(downsample_ratios):
            out_ch = min(128, 32 * (2 ** i))
            block = nn.Sequential(
                # kernel = 2 * ratio, stride = ratio. Padding is ceil(ratio / 2)
                # so an input whose length is a multiple of ``ratio`` yields
                # exactly length / ratio frames. Floor division (ratio // 2)
                # loses one frame for odd ratios (e.g. 250 -> 49 for ratio 5).
                TernaryConv1d(in_ch, out_ch, kernel_size=ratio * 2, stride=ratio,
                              padding=(ratio + 1) // 2, tscale_type=tscale_type),
                nn.LeakyReLU(0.1),
                # Length-preserving refinement conv (kernel 3, padding 1).
                TernaryConv1d(out_ch, out_ch, kernel_size=3, padding=1,
                              tscale_type=tscale_type),
                nn.LeakyReLU(0.1),
            )
            self.down_blocks.append(block)
            in_ch = out_ch
        self.proj = TernaryScaleTensor(out_ch, codebook_dim, tscale_type=tscale_type, bias=True)
        self.codebook = TernaryEmbeddingTable(vocab, codebook_dim, tscale_type=tscale_type)
        self.out_proj = TernaryScaleTensor(codebook_dim, vocab, tscale_type=tscale_type, bias=True)

    def forward(self, audio):
        """Return ``(logits, indices)`` for a batch of waveforms.

        Args:
            audio: Waveform tensor of shape [B, 1, T].

        Returns:
            logits: [B, T', vocab] scores computed from the quantized features.
            indices: [B, T'] index of the nearest codebook vector per frame.
        """
        x = audio
        for block in self.down_blocks:
            x = block(x)
        # Channels-first [B, C, T'] -> channels-last [B, T', C] for projection.
        x = x.permute(0, 2, 1)
        x = self.proj(x)

        # Materialise the whole codebook as a [vocab, codebook_dim] matrix.
        emb_idx = torch.arange(self.vocab, device=x.device)
        emb = self.codebook(emb_idx).to(device=x.device, dtype=x.dtype)

        # Nearest-neighbour lookup. argmin is not differentiable, so no
        # autograd graph is needed for the distance matrix. Distances are
        # computed in float32 regardless of the working dtype.
        with torch.no_grad():
            dist = torch.cdist(x.float(), emb.unsqueeze(0).float())
            indices = dist.argmin(dim=-1)

        quantized = F.embedding(indices, emb)
        # Straight-through estimator: forward value is the codebook vector,
        # gradient flows to ``x`` as if quantization were the identity.
        # NOTE(review): the codebook term is fully detached and no
        # codebook/commitment loss is returned, so the codebook receives no
        # gradient from this forward pass - confirm it is trained elsewhere.
        quantized = x + (quantized - x).detach()
        logits = self.out_proj(quantized)
        return logits, indices

    def encode(self, audio):
        """Return only the [B, T'] codebook indices, with gradients disabled."""
        with torch.no_grad():
            _, indices = self.forward(audio)
        return indices