from typing import Optional, Tuple, Union

import torch
import torchaudio
from einops import rearrange
from transformers import PreTrainedModel

from .configuration_dasheng_tokenizer import DashengTokenizerConfig
from .modeling_dasheng_encoder import DashengEncoder
from .vocos import VocosModel


class VocosMelSpec(torch.nn.Module):
    """Log-mel spectrogram frontend for Vocos."""

    def __init__(self, sample_rate=16000, n_fft=1024, hop_length=256, n_mels=100, padding="center"):
        super().__init__()
        if padding not in ("center", "same"):
            raise ValueError("Padding must be 'center' or 'same'.")
        self.padding = padding
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        with torch.device("cpu"):
            self.mel_spec = torchaudio.transforms.MelSpectrogram(
                sample_rate=self.sample_rate,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                n_mels=self.n_mels,
                center=self.padding == "center",
                power=1,
            )

    def forward(self, audio, **kwargs):
        if self.padding == "same":
            # Reflect-pad so the output has exactly len(audio) // hop_length frames.
            pad = self.mel_spec.win_length - self.mel_spec.hop_length
            audio = torch.nn.functional.pad(audio, (pad // 2, pad // 2), mode="reflect")
        mel = self.mel_spec(audio)
        # Clamp before the log to avoid -inf on silent frames.
        return torch.log(torch.clip(mel, min=1e-7))
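

# Quick shape check (illustrative): with the defaults above, one second of
# 16 kHz audio produces 1 + 16000 // 256 = 63 center-padded frames:
#
#     spec = VocosMelSpec()
#     spec(torch.randn(1, 16000)).shape  # -> torch.Size([1, 100, 63])

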
class DashengTokenizerEncoder(torch.nn.Module):
    def __init__(
        self,
        embed_dim: int = 1280,
        depth: int = 32,
        num_heads: int = 16,
        n_mels_patch: int = 128,
        hop_length: int = 160,
        **kwargs,
    ):
        super().__init__()
        self.model = DashengEncoder(embed_dim=embed_dim, depth=depth, num_heads=num_heads)
        self.embed_dim = int(self.model.embed_dim)
        # Drop the encoder's output head; only the hidden states are needed.
        self.model.outputlayer = torch.nn.Identity()

        self.front_end = VocosMelSpec(hop_length=hop_length, n_mels=n_mels_patch)
        # Non-overlapping (n_mels, 4) patches: one acoustic token per 4 mel frames.
        self.patch_embed = torch.nn.Conv2d(
            1, self.model.embed_dim, (n_mels_patch, 4), (n_mels_patch, 4)
        )
        self.norm = torch.nn.LayerNorm(self.model.embed_dim)

        self.n_fft = self.front_end.n_fft
        self.hop_size = self.front_end.hop_length

    @torch.no_grad()
    def forward(
        self,
        input: torch.Tensor,
        input_attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of the encoder.

        Args:
            input: Audio tensor of shape (batch_size, num_samples)
            input_attn_mask: Optional attention mask

        Returns:
            Combined embeddings of shape (batch_size, num_tokens, embed_dim)
        """
        # Gradients are already disabled for the whole pass by the decorator.
        semantic_emb = self.model(input, input_attn_mask)

        mel = self.front_end(input).unsqueeze(1)  # (b, 1, n_mels, t)
        mel_emb = self.patch_embed(mel)           # (b, embed_dim, 1, t // 4)
        acoustic_emb = rearrange(mel_emb, "b c f t -> b (f t) c")
        acoustic_emb = self.norm(acoustic_emb)

        # Trim the semantic stream to the acoustic token count before summing.
        semantic_emb = semantic_emb[:, : acoustic_emb.shape[1], :]
        emb = semantic_emb + acoustic_emb
        return emb
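

# Token-rate arithmetic (illustrative, assuming 16 kHz input): hop_length=160
# yields 100 mel frames per second, and the time-stride-4 patch embedding
# reduces that to roughly 25 tokens per second.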


class DashengTokenizerPreTrainedModel(PreTrainedModel):
    """Base class tying the model to its config and HF weight handling."""

    config_class = DashengTokenizerConfig
    supports_gradient_checkpointing = True


class DashengTokenizerModel(DashengTokenizerPreTrainedModel):
    """
    HuggingFace-compatible Dasheng tokenizer model (encoder + decoder).

    This model includes both the encoder and decoder for end-to-end audio
    processing: audio -> embeddings -> reconstructed audio.
    """

    def __init__(self, config: DashengTokenizerConfig):
        super().__init__(config)
        self.config = config

        self.encoder = DashengTokenizerEncoder(
            embed_dim=config.embed_dim,
            depth=config.depth,
            num_heads=config.num_heads,
            n_mels_patch=config.n_mels_patch,
            hop_length=config.hop_length,
        )

        self.embed_dim = self.encoder.embed_dim

        # Optional learned upsampling of the token sequence before decoding.
        self.upsampler = None
        if config.upsample_tokens > 1:
            self.upsampler = torch.nn.ConvTranspose1d(
                self.embed_dim, self.embed_dim,
                kernel_size=config.upsample_tokens,
                stride=config.upsample_tokens,
            )

        # Vocos decoder: maps embeddings back to a waveform via ISTFT.
        self.decoder = VocosModel(
            input_channels=self.embed_dim,
            hidden_dim=config.decoder_embed_dim,
            intermediate_dim=config.decoder_intermediate_size,
            vocos_istft_hop=config.istft_hop,
            vocos_n_fft=config.istft_n_fft,
            num_layers=config.decoder_depth,
        )

        self.post_init()
|
| def encode( |
| self, |
| audio: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| ) -> torch.Tensor: |
| """Encode audio into embeddings.""" |
| return self.encoder(audio, attention_mask) |
|
|
| def decode(self, embeddings: torch.Tensor) -> torch.Tensor: |
| """Decode embeddings back to audio.""" |
| if self.upsampler is not None: |
| embeddings = self.upsampler(embeddings.transpose(-2, -1)).transpose(-2, -1) |
| output = self.decoder(embeddings.transpose(-2, -1)) |
| return output |
|
|
| def forward( |
| self, |
| audio: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| return_dict: Optional[bool] = None, |
| ) -> Union[Tuple[torch.Tensor], dict]: |
| """ |
| Forward pass of the DashEng tokenizer. |
| |
| Args: |
| audio: Audio tensor of shape (batch_size, num_samples) |
| attention_mask: Optional attention mask |
| output_attentions: Whether to return attention weights |
| output_hidden_states: Whether to return hidden states |
| return_dict: Whether to return a dict |
| |
| Returns: |
| Reconstructed audio of shape (batch_size, num_samples) |
| """ |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| |
| embeddings = self.encoder(audio, attention_mask) |
|
|
| |
| audio_reconstructed = self.decode(embeddings) |
|
|
| if not return_dict: |
| return (audio_reconstructed,) |
|
|
| return { |
| "audio": audio_reconstructed, |
| "embeddings": embeddings, |
| } |
|
|
|
|
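

# Minimal round-trip sketch (illustrative only; assumes DashengTokenizerConfig
# supplies defaults for every field referenced above and that pretrained
# weights are loaded separately):
#
#     config = DashengTokenizerConfig()
#     model = DashengTokenizerModel(config).eval()
#     wav = torch.randn(1, 16000)          # one second of 16 kHz audio
#     out = model(wav, return_dict=True)
#     out["audio"].shape                   # reconstructed waveform
#     out["embeddings"].shape              # (1, num_tokens, embed_dim)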