| from einops import rearrange |
| from einops.layers.torch import Rearrange |
| import torchaudio.transforms as audio_transforms |
| import torch |
| import torch.nn as nn |
| from typing import Optional, Type |
|
|
class FrontEnd(nn.Sequential):
    """Log-mel feature extractor: MelSpectrogram followed by AmplitudeToDB.

    The sub-modules are deliberately constructed on the CPU so the mel
    filterbank is materialized there regardless of any active default device.
    Constructor arguments mirror ``torchaudio.transforms.MelSpectrogram``.
    """

    def __init__(
        self,
        f_min: int = 0,
        sample_rate: int = 16000,
        win_size: int = 512,
        center: bool = True,
        n_fft: int = 512,
        f_max: Optional[int] = 8000,
        hop_size: int = 160,
        n_mels: int = 64,
    ):
        with torch.device("cpu"):
            super().__init__(
                audio_transforms.MelSpectrogram(
                    sample_rate=sample_rate,
                    n_fft=n_fft,
                    win_length=win_size,
                    hop_length=hop_size,
                    f_min=f_min,
                    f_max=f_max,
                    n_mels=n_mels,
                    center=center,
                ),
                audio_transforms.AmplitudeToDB(top_db=120),
            )
        # Keep the configuration around for callers (hop_size is needed by
        # forward to downsample sample-level masks to frame rate).
        self.f_min = f_min
        self.sample_rate = sample_rate
        self.win_size = win_size
        self.center = center
        self.n_fft = n_fft
        self.f_max = f_max
        self.hop_size = hop_size
        self.n_mels = n_mels

    @torch.autocast(enabled=False, device_type="cuda")
    def forward(self, x, attention_mask=None):
        """Compute log-mel features and a frame-rate attention mask.

        Args:
            x: Audio tensor of shape (batch_size, num_samples).
            attention_mask: Optional mask of shape (batch_size, num_samples);
                nonzero entries mark valid samples.

        Returns:
            features: Mel features of shape (batch_size, n_mels, num_frames).
            attention_mask: Int mask downsampled to (batch_size, num_frames),
                or None when no mask was supplied.
        """
        features = super().forward(x)
        if attention_mask is None:
            return features, None
        # Valid frames per example = valid samples // hop size.
        # NOTE(review): with center=True torchaudio produces
        # num_samples // hop + 1 frames, so this may under-count the valid
        # region by one frame — confirm intended.
        valid_frames = attention_mask.float().sum(-1) // self.hop_size
        frame_idx = torch.arange(features.shape[-1], device=features.device)
        return features, (frame_idx < valid_frames.unsqueeze(-1)).int()
|
|
|
|
| class Mlp(nn.Module): |
| def __init__( |
| self, |
| in_features: int, |
| hidden_features: Optional[int] = None, |
| out_features: Optional[int] = None, |
| act_layer: Type[torch.nn.Module] = nn.GELU, |
| drop: float = 0.0, |
| ): |
| super().__init__() |
| out_features = out_features or in_features |
| hidden_features = hidden_features or in_features |
| self.fc1 = nn.Linear(in_features, hidden_features) |
| self.act = act_layer() |
| self.fc2 = nn.Linear(hidden_features, out_features) |
| self.drop = nn.Dropout(drop) |
|
|
| def forward(self, x): |
| x = self.fc1(x) |
| x = self.act(x) |
| x = self.drop(x) |
| x = self.fc2(x) |
| x = self.drop(x) |
| return x |
|
|
|
|
| class Attention(nn.Module): |
| def __init__( |
| self, |
| dim: int, |
| num_heads: int = 8, |
| qkv_bias: bool = True, |
| attn_drop: float = 0.0, |
| proj_drop: float = 0.0, |
| causal: bool = False, |
| ): |
| super().__init__() |
| self.num_heads = num_heads |
| self.scale = (dim // num_heads) ** -0.5 |
| self.causal = causal |
|
|
| self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |
| self.attn_drop = nn.Dropout(attn_drop) |
| self.proj = nn.Linear(dim, dim) |
| self.proj_drop = nn.Dropout(proj_drop) |
|
|
| def forward(self, x, mask: Optional[torch.Tensor] = None): |
| B, N, C = x.shape |
| |
| qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
| q, k, v = qkv[0], qkv[1], qkv[2] |
|
|
| attn = (q @ k.transpose(-2, -1)) * self.scale |
|
|
| |
| if self.causal: |
| c_mask = torch.ones(N, N, device=x.device, dtype=torch.bool).triu(1) |
| attn = attn.masked_fill(c_mask, float("-inf")) |
|
|
| |
| if mask is not None: |
| if mask.dtype != torch.bool: |
| padding_mask = (mask == 0) |
| else: |
| padding_mask = mask |
| padding_mask = padding_mask.view(B, 1, 1, N) |
| attn = attn.masked_fill(padding_mask, float("-inf")) |
| attn = attn.softmax(dim=-1).nan_to_num() |
| attn = self.attn_drop(attn) |
|
|
| x = (attn @ v).transpose(1, 2).reshape(B, N, C) |
| return self.proj_drop(self.proj(x)) |
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer block: self-attention and MLP, each wrapped in a
    LayerNorm and a residual connection."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
    ):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = Attention(dim, num_heads, qkv_bias, attn_drop, drop)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_dim, act_layer=nn.GELU, drop=drop)

    def forward(self, x, mask=None):
        """Apply attention then MLP; ``mask`` is forwarded to the attention."""
        x = x + self.attn(self.norm1(x), mask=mask)
        return x + self.mlp(self.norm2(x))
|
|
|
|
| class AudioPatchEmbed(torch.nn.Module): |
|
|
| def __init__(self, *args, **kwargs) -> None: |
| super().__init__() |
| self.stride = kwargs.get('stride', [None, 4])[-1] |
| self.proj = nn.Conv2d(*args, **kwargs) |
|
|
| def forward(self, x:torch.Tensor, attention_mask:torch.Tensor | None =None): |
| x = self.proj(x) |
| if attention_mask is not None: |
| lengths = attention_mask.float().sum(-1) // self.stride |
| attention_mask = (torch.arange(x.shape[-1], device=x.device) < lengths.unsqueeze(-1)).int() |
| return x, attention_mask |
|
|
|
|
|
|
|
|
class DashengEncoder(nn.Module):
    """Dasheng-style audio transformer encoder.

    Pipeline: waveform -> log-mel frontend -> per-mel-bin batch norm ->
    conv patch embedding (+ positional embeddings) -> transformer blocks.
    Inputs longer than ``target_length`` frames are processed in chunks
    whose outputs are concatenated along the token axis.
    """

    def __init__(
        self,
        embed_dim: int = 1280,
        depth: int = 32,
        num_heads: int = 20,
        patch_size=(64, 4),
        patch_stride=(64, 4),
        target_length=1008,
    ):
        super().__init__()
        # Immutable tuple defaults replace the original mutable list
        # defaults (shared-state anti-pattern); Conv2d accepts either.
        patch_size = tuple(patch_size)
        patch_stride = tuple(patch_stride)
        self.embed_dim = embed_dim
        self.time_patches = patch_stride[-1]
        self.front_end = FrontEnd()
        self.target_length = target_length
        self.max_t_tokens = target_length // patch_stride[-1]
        self.patch_embed = AudioPatchEmbed(1, embed_dim, kernel_size=patch_size, stride=patch_stride)
        # Swap channel/freq axes so BatchNorm2d normalizes per mel bin,
        # then swap back.
        self.init_bn = nn.Sequential(
            Rearrange("b c f t -> b f c t"),
            torch.nn.BatchNorm2d(self.front_end.n_mels, momentum=0.01),
            Rearrange("b f c t -> b c f t"),
        )

        # Learned positional embeddings: one per time token; the frequency
        # embedding is a single vector broadcast over all positions.
        self.time_pos_embed = nn.Parameter(torch.randn(1, embed_dim, 1, target_length // self.time_patches) * 0.02)
        self.freq_pos_embed = nn.Parameter(torch.randn(1, embed_dim, 1, 1) * 0.02)

        self.blocks = nn.ModuleList([Block(embed_dim, num_heads) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

    def _forward_main(self, x, attention_mask, mask_to_zero: bool = False):
        """Run patch embedding and the transformer stack on one chunk.

        Args:
            x: spectrogram chunk of shape (batch, 1, n_mels, frames).
            attention_mask: frame-rate mask for the chunk, or None.
            mask_to_zero: if True, zero embeddings at padded positions.

        Returns:
            Token embeddings of shape (batch, num_tokens, embed_dim).
        """
        x, attention_mask = self.patch_embed(x, attention_mask)
        t = x.shape[-1]
        # Slice the time embedding: the final chunk may be shorter than
        # target_length.
        x = x + self.time_pos_embed[:, :, :, :t] + self.freq_pos_embed
        x = rearrange(x, "b c f t -> b (f t) c")
        for block in self.blocks:
            x = block(x, mask=attention_mask)
        x = self.norm(x)
        if attention_mask is not None and mask_to_zero:
            x = x * attention_mask.unsqueeze(-1)
        return x

    def forward(self, x: torch.Tensor, attention_mask=None):
        """
        Forward pass of the AudioTransformer.

        Args:
            x: Audio tensor of shape (batch_size, num_samples)
            attention_mask: Optional attention mask of shape (batch_size, num_samples)
                where True indicates valid samples and False indicates padding

        Returns:
            embeddings: Token embeddings of shape (batch_size, num_tokens, embed_dim)
        """
        x, attention_mask = self.front_end(x, attention_mask)

        x = rearrange(x, "b f t -> b 1 f t")
        x = self.init_bn(x)

        # Process long inputs in target_length-frame chunks.
        input_splits = x.split(self.target_length, dim=-1)
        masks = [None] * len(input_splits)
        if attention_mask is not None:
            masks = attention_mask.split(self.target_length, dim=-1)

        outputs = []
        for i, (input_split_x, mask) in enumerate(zip(input_splits, masks)):
            # NOTE(review): padded positions are zeroed only for chunks after
            # the first (i != 0) — preserved from the original; confirm
            # intended.
            outputs.append(self._forward_main(input_split_x, attention_mask=mask, mask_to_zero=i != 0))
        return torch.cat(outputs, dim=1)
|
|