import collections
from functools import partial
from itertools import repeat
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as audio_transforms
from torch.cuda.amp import autocast

import einops
from einops import rearrange
from einops.layers.torch import Rearrange
|
|
|
|
| if hasattr(nn.functional, 'scaled_dot_product_attention'): |
| ATTENTION_MODE = 'flash' |
| else: |
| ATTENTION_MODE = 'math' |
| print(f'attention mode is {ATTENTION_MODE}') |
|
|
|
|
def _ntuple(n):

    def parse(x) -> Tuple:
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse
|
|
|
|
| to_2tuple = _ntuple(2) |
|
|
|
|
| class MAELoss(torch.nn.Module): |
|
|
    def __init__(self, norm_pix_loss: Union[bool, str] = True):
| super().__init__() |
| self.norm_pix_loss = norm_pix_loss |
|
|
| @autocast(enabled=False) |
| def forward(self, pred: torch.Tensor, target: torch.Tensor, |
| mask: torch.Tensor) -> torch.Tensor: |
        if self.norm_pix_loss is True:
            # Normalize the target per patch (statistics over the patch dimension)
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6)**.5
        elif self.norm_pix_loss == 'global':
            # Normalize the target with global spectrogram statistics
            mean = target.mean()
            var = target.var()
            target = (target - mean) / (var + 1.e-6)**.5
| loss = (pred - target)**2 |
| loss = loss.mean(dim=-1) |
| loss = (loss * mask).sum() / mask.sum() |
| return loss |
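

# Illustrative sketch (not part of the original module): MAELoss expects
# patchified predictions/targets of shape [batch, num_patches, patch_dim] and
# a binary mask of shape [batch, num_patches] with 1 marking masked patches.
# The sizes below are assumed values for demonstration only.
def _maeloss_shape_sketch():
    loss_fn = MAELoss(norm_pix_loss=True)
    pred = torch.randn(2, 252, 256)
    target = torch.randn(2, 252, 256)
    mask = (torch.rand(2, 252) > 0.25).float()  # 1 = masked, 0 = visible
    return loss_fn(pred, target, mask)  # scalar, averaged over masked patches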
|
|
|
|
| class AudioPatchEmbed(nn.Module): |
|
|
| def __init__(self, |
| input_size: Union[int, Tuple[int, int]] = (64, 100), |
| patch_size: Tuple[int, int] = (64, 4), |
| patch_stride: Tuple[int, int] = (64, 4), |
| in_chans=1, |
| embed_dim=768, |
| norm_layer=None, |
| flatten=False): |
| super().__init__() |
| patch_size = to_2tuple(patch_size) |
| patch_stride = to_2tuple(patch_stride) |
| self.input_size: Tuple[int, int] = to_2tuple(input_size) |
| self.patch_size: Tuple[int, int] = to_2tuple(patch_size) |
| self.patch_stride: Tuple[int, int] = to_2tuple(patch_stride) |
| self.grid_size = (self.input_size[0] // self.patch_stride[0], |
| self.input_size[1] // self.patch_stride[1]) |
| self.num_patches = self.grid_size[0] * self.grid_size[1] |
| self.flatten = flatten |
|
|
| self.proj = nn.Conv2d(in_chans, |
| embed_dim, |
| kernel_size=patch_size, |
| stride=patch_stride) |
| self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() |
|
|
| def forward(self, x): |
| x = self.proj(x) |
| if self.flatten: |
| x = rearrange(x, 'b c f t -> b (f t) c') |
| x = self.norm(x) |
| return x |
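

# Shape sketch (illustrative, not part of the original module), assuming the
# default dasheng-style patching: 64 mel bins and (64, 4) patches with stride
# (64, 4), so every patch spans all frequency bins and 4 time frames.
def _audiopatchembed_shape_sketch():
    patch_embed = AudioPatchEmbed(input_size=(64, 1008),
                                  patch_size=(64, 4),
                                  patch_stride=(64, 4),
                                  embed_dim=768)
    spec = torch.randn(2, 1, 64, 1008)  # (batch, channel, mel bins, frames)
    out = patch_embed(spec)             # (2, 768, 1, 252); grid_size == (1, 252)
    assert out.shape == (2, 768, 1, 252)
    assert patch_embed.num_patches == 252
    return out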
|
|
|
|
| class LayerScale(nn.Module): |
|
|
| def __init__(self, dim: int, init_values=1e-5, inplace=False): |
| super().__init__() |
| self.inplace = inplace |
| self.gamma = nn.Parameter(init_values * torch.ones(dim)) |
|
|
| def forward(self, x): |
| return x.mul_(self.gamma) if self.inplace else x * self.gamma |
|
|
|
|
| class Attention(nn.Module): |
|
|
| def __init__(self, |
| dim, |
| num_heads=8, |
| qkv_bias=False, |
| attn_drop=0., |
| proj_drop=0.): |
| super().__init__() |
| assert dim % num_heads == 0, 'dim should be divisible by num_heads' |
| self.num_heads = num_heads |
| head_dim = dim // num_heads |
| self.scale = head_dim**-0.5 |
|
|
| self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |
| self.attn_drop = nn.Dropout(attn_drop) |
| self.attn_drop_p = attn_drop |
| self.proj = nn.Linear(dim, dim) |
| self.proj_drop = nn.Dropout(proj_drop) |
|
|
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        if ATTENTION_MODE == 'flash':
            # scaled_dot_product_attention with an explicit `scale` argument
            # requires torch >= 2.1; dropout is only applied during training.
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop_p if self.training else 0.,
                scale=self.scale)
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        elif ATTENTION_MODE == 'math':
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
|
| x = self.proj(x) |
| x = self.proj_drop(x) |
| return x |
|
|
|
|
| class Mlp(nn.Module): |
|
|
| def __init__(self, |
| in_features, |
| hidden_features=None, |
| out_features=None, |
| act_layer=nn.GELU, |
| drop=0.): |
| super().__init__() |
| out_features = out_features or in_features |
| hidden_features = hidden_features or in_features |
| self.fc1 = nn.Linear(in_features, hidden_features) |
| self.act = act_layer() |
| self.fc2 = nn.Linear(hidden_features, out_features) |
| self.drop = nn.Dropout(drop) |
|
|
| def forward(self, x): |
| x = self.fc1(x) |
| x = self.act(x) |
| x = self.drop(x) |
| x = self.fc2(x) |
| x = self.drop(x) |
| return x |
|
|
|
|
| class Block(nn.Module): |
|
|
| def __init__( |
| self, |
| dim, |
| num_heads, |
| mlp_ratio=4., |
| qkv_bias=False, |
| drop=0., |
| attn_drop=0., |
| init_values=None, |
| act_layer=nn.GELU, |
| norm_layer=nn.LayerNorm, |
| attention_type='Attention', |
| ): |
| super().__init__() |
| self.norm1 = norm_layer(dim) |
| attn_type = globals()[attention_type] |
| self.attn = attn_type(dim, |
| num_heads=num_heads, |
| qkv_bias=qkv_bias, |
| attn_drop=attn_drop, |
| proj_drop=drop) |
| self.ls1 = LayerScale( |
| dim, init_values=init_values) if init_values else nn.Identity() |
|
|
| self.norm2 = norm_layer(dim) |
| self.mlp = Mlp(in_features=dim, |
| hidden_features=int(dim * mlp_ratio), |
| act_layer=act_layer, |
| drop=drop) |
| self.ls2 = LayerScale( |
| dim, init_values=init_values) if init_values else nn.Identity() |
|
|
| def forward(self, x): |
| x = x + self.ls1(self.attn(self.norm1(x))) |
| x = x + self.ls2(self.mlp(self.norm2(x))) |
| return x |
|
|
|
|
| class AudioTransformerMAE_Encoder(nn.Module): |
|
|
| def __init__(self, |
| patch_size: Tuple[int, int] = (64, 4), |
| patch_stride: Tuple[int, int] = (64, 4), |
| embed_dim: int = 768, |
| depth: int = 12, |
| num_heads=8, |
| mlp_ratio=4., |
| qkv_bias=True, |
| drop_rate=0., |
| attn_drop_rate=0., |
| norm_layer=None, |
| act_layer=None, |
| init_values=None, |
| target_length=1008, |
| pooling='mean', |
| time_patch_out: Optional[float] = None, |
| freq_patch_out: Optional[float] = None, |
| block_type='Block', |
| attention_type='Attention', |
| eval_avg='cat', |
| n_fft: int = 512, |
| n_mels: int = 64, |
| hop_size: int = 160, |
| win_size: int = 512, |
| f_min: int = 0, |
| f_max: int = 8000, |
| center: bool = True, |
| **kwargs): |
| super().__init__() |
| self.pooling = pooling |
| self.embed_dim = embed_dim |
| self.patch_stride = patch_stride |
| self.patch_size = patch_size |
| self.n_mels = n_mels |
| self.eval_avg = eval_avg |
| self.time_patch_out = time_patch_out |
| self.freq_patch_out = freq_patch_out |
|
|
| self.front_end = nn.Sequential( |
| audio_transforms.MelSpectrogram(f_min=f_min, |
| sample_rate=16000, |
| win_length=win_size, |
| center=center, |
| n_fft=n_fft, |
| f_max=f_max, |
| hop_length=hop_size, |
| n_mels=self.n_mels), |
| audio_transforms.AmplitudeToDB(top_db=kwargs.get('top_db', 120))) |
|
|
| self.init_bn = nn.Sequential( |
| Rearrange('b c f t -> b f c t'), |
| nn.BatchNorm2d(self.n_mels, momentum=0.01), |
| Rearrange('b f c t -> b c f t')) |
|
|
| self.target_length = target_length |
| self.patch_embed = AudioPatchEmbed(input_size=(self.n_mels, |
| target_length), |
| embed_dim=self.embed_dim, |
| patch_size=self.patch_size, |
| flatten=False, |
| patch_stride=self.patch_stride) |
| self.num_patches = self.patch_embed.num_patches |
|
|
| if pooling == 'token': |
| self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) |
| self.token_pos_embed = nn.Parameter( |
| torch.randn(1, embed_dim) * .02) |
|
|
| self.time_pos_embed = nn.Parameter( |
| torch.randn(1, embed_dim, 1, self.patch_embed.grid_size[1]) * .02) |
| self.freq_pos_embed = nn.Parameter( |
| torch.randn(1, embed_dim, self.patch_embed.grid_size[0], 1) * .02) |
|
|
| norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) |
| act_layer = act_layer or nn.GELU |
| self.pos_drop = nn.Dropout(p=drop_rate) |
| block_function = globals()[block_type] |
| self.blocks = nn.Sequential(*[ |
| block_function( |
| dim=embed_dim, |
| num_heads=num_heads, |
| mlp_ratio=mlp_ratio, |
| qkv_bias=qkv_bias, |
| init_values=init_values, |
| drop=drop_rate, |
| attn_drop=attn_drop_rate, |
| norm_layer=norm_layer, |
| act_layer=act_layer, |
| attention_type=attention_type, |
| ) for _ in range(depth) |
| ]) |
| self.norm = norm_layer(embed_dim) |
| self.apply(self.init_weights) |
| if hasattr(self, 'cls_token') and self.cls_token is not None: |
| nn.init.normal_(self.cls_token, std=1e-6) |
| group_masking = kwargs.get('group_masking', False) |
| if isinstance(group_masking, bool): |
| if group_masking is True: |
| self.masking_func = self.random_masking_group |
| else: |
| self.masking_func = self.random_masking |
| elif isinstance(group_masking, int): |
| self.masking_func = partial(self.random_masking_group, |
| group_factor=group_masking) |
|
|
| @torch.jit.ignore |
| def no_weight_decay(self): |
| return { |
| 'time_pos_embed', 'cls_token', 'freq_pos_embed', 'token_pos_embed' |
| } |
|
|
| def init_weights(self, module): |
| if isinstance(module, nn.Linear): |
| torch.nn.init.xavier_uniform_(module.weight) |
| if module.bias is not None: |
| nn.init.zeros_(module.bias) |
| elif isinstance(module, nn.LayerNorm): |
| nn.init.constant_(module.bias, 0) |
| nn.init.constant_(module.weight, 1.0) |
|
|
| def random_masking_group(self, x, mask_ratio, group_factor: int = 2): |
| """ |
| Perform per-sample random masking by per-sample shuffling. |
| Per-sample shuffling is done by argsort random noise. |
| x: [N, L, D], sequence |
| """ |
| N, L, D = x.shape |
| len_keep = int(L * (1 - mask_ratio)) |
|
|
        # Draw one noise value per group of `group_factor` consecutive patches
        noise = torch.rand(N, L // group_factor, device=x.device)
        indices = torch.arange(L, device=x.device).view(-1, group_factor)

        # Sort noise for each sample: small values are kept, large are removed
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_shuffle = indices[ids_shuffle].flatten(-2)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Keep the first len_keep patches
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x,
                                dim=1,
                                index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # Unshuffle to get the mask in the original patch order
        mask = torch.gather(mask, dim=1, index=ids_restore)
|
|
| return x_masked, mask, ids_restore |
|
|
| def random_masking(self, x, mask_ratio): |
| """ |
| Perform per-sample random masking by per-sample shuffling. |
| Per-sample shuffling is done by argsort random noise. |
| x: [N, L, D], sequence |
| """ |
| N, L, D = x.shape |
| len_keep = int(L * (1 - mask_ratio)) |
|
|
        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1)

        # Sort noise for each sample: small values are kept, large are removed
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Keep the first len_keep patches
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x,
                                dim=1,
                                index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # Unshuffle to get the mask in the original patch order
        mask = torch.gather(mask, dim=1, index=ids_restore)
|
|
| return x_masked, mask, ids_restore |
|
|
| def forward_features(self, x, mask_ratio): |
| x = self.patch_embed(x) |
| b, c, f, t = x.shape |
| x = x + self.time_pos_embed[:, :, :, :t] |
| x = x + self.freq_pos_embed[:, :, :, :] |
        x = rearrange(x, 'b c f t -> b (f t) c')
        # Mask patches; only the visible patches are fed through the blocks
        x, mask, ids_restore = self.masking_func(x, mask_ratio)
| if self.pooling == 'token': |
| cls_token = self.cls_token.expand(x.shape[0], -1, -1) |
| cls_token = cls_token + self.token_pos_embed[:, :] |
| x = torch.cat((cls_token, x), dim=1) |
| x = self.pos_drop(x) |
| x = self.blocks(x) |
| x = self.norm(x) |
| return x, mask, ids_restore |
|
|
    def load_state_dict(self, state_dict, strict=True, **kwargs):
        if 'time_pos_embed' in state_dict and self.time_pos_embed.shape != state_dict['time_pos_embed'].shape:
            print("Positional embedding shape does not match the model, resizing!")
            self.change_pos_embedding(state_dict)
        super().load_state_dict(state_dict, strict=strict, **kwargs)
|
|
| def change_pos_embedding(self, state_dict): |
| target_time_pos_embed_length = self.time_pos_embed.shape[-1] |
| target_freq_pos_embed_length = self.freq_pos_embed.shape[-2] |
|
|
| pretrained_time_pos_embed = state_dict['time_pos_embed'] |
| pretrained_freq_pos_embed = state_dict['freq_pos_embed'] |
|
|
        if target_time_pos_embed_length <= pretrained_time_pos_embed.shape[-1]:
            state_dict['time_pos_embed'] = pretrained_time_pos_embed[
                ..., :target_time_pos_embed_length]
        else:
            state_dict['time_pos_embed'] = torch.nn.functional.interpolate(
                pretrained_time_pos_embed,
                size=(1, target_time_pos_embed_length),
                align_corners=False,
                mode='bilinear')
        if target_freq_pos_embed_length <= pretrained_freq_pos_embed.shape[-2]:
            state_dict['freq_pos_embed'] = pretrained_freq_pos_embed[
                :, :, :target_freq_pos_embed_length, :]
        else:
            state_dict['freq_pos_embed'] = torch.nn.functional.interpolate(
                pretrained_freq_pos_embed,
                size=(target_freq_pos_embed_length, 1),
                align_corners=False,
                mode='bilinear')
|
|
    def forward_to_spec(self, x):
        # Compute the mel spectrogram in fp32, even when running under autocast
        with autocast(enabled=False):
            X = self.front_end(x)
            X = rearrange(X, 'b f t -> b 1 f t')
            X = self.init_bn(X)
        return X
|
|
| def forward(self, x, mask_ratio: float = 0.75): |
| x = self.forward_to_spec(x) |
| x, mask, restore_idxs = self.forward_features(x, mask_ratio=mask_ratio) |
| return x, mask, restore_idxs |
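

# Toy sketch of the grouped masking (illustrative sizes, not from the original
# module): with group_factor=2, adjacent pairs of time patches are kept or
# removed together, so the restored mask is constant within each pair.
def _group_masking_sketch():
    enc = AudioTransformerMAE_Encoder(embed_dim=64, depth=1, group_masking=True)
    x = torch.randn(2, 8, 16)  # (batch, patches, dim)
    x_vis, mask, ids_restore = enc.random_masking_group(x,
                                                        mask_ratio=0.5,
                                                        group_factor=2)
    assert x_vis.shape == (2, 4, 16)  # 50% of the patches are kept
    # Each pair of adjacent patches shares the same mask value
    assert (mask.view(2, 4, 2)[..., 0] == mask.view(2, 4, 2)[..., 1]).all()
    return x_vis, mask, ids_restore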
|
|
|
|
| class AudioTransformerMAE_Decoder(nn.Module): |
|
|
| def __init__(self, |
| input_dim: int, |
| outputdim: int, |
| patch_size: int = 16, |
| patch_stride: int = 16, |
| embed_dim: int = 768, |
| num_patches: int = 100, |
| depth: int = 12, |
| num_heads: int = 12, |
| mlp_ratio: float = 4., |
| qkv_bias: bool = True, |
| drop_rate: float = 0., |
| attn_drop_rate: float = 0., |
| norm_layer: Optional[torch.nn.Module] = None, |
| act_layer: Optional[torch.nn.Module] = None, |
| cls_token: bool = False, |
| attention_type='Attention', |
| init_values=None, |
| **kwargs): |
| super().__init__() |
| self.embed_dim = embed_dim |
| self.patch_stride = patch_stride |
| self.patch_size = patch_size |
| self.input_dim = input_dim |
|
|
| self.input_proj = nn.Linear(input_dim, embed_dim) |
|
|
| self.mask_token = nn.Parameter(torch.randn(1, 1, embed_dim) * .02) |
| _norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) |
| _act_layer = act_layer or nn.GELU |
| self.use_cls = cls_token |
        # One extra position is needed when a cls token is prepended to the sequence
        num_patches_total = num_patches + 1 if cls_token else num_patches
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches_total, embed_dim))
| self.pos_drop = nn.Dropout(p=drop_rate) |
| self.blocks = nn.Sequential(*[ |
| Block( |
| dim=embed_dim, |
| num_heads=num_heads, |
| mlp_ratio=mlp_ratio, |
| qkv_bias=qkv_bias, |
| init_values=init_values, |
| drop=drop_rate, |
| attn_drop=attn_drop_rate, |
| norm_layer=_norm_layer, |
| act_layer=_act_layer, |
| attention_type=attention_type, |
| ) for i in range(depth) |
| ]) |
| self.norm = _norm_layer(embed_dim) |
| self.outputlayer = nn.Linear(self.embed_dim, outputdim) |
| self.apply(self.init_weights) |
| torch.nn.init.normal_(self.mask_token, std=.02) |
|
|
    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'mask_token'}
|
|
| def init_weights(self, module): |
| if isinstance(module, nn.Linear): |
| nn.init.trunc_normal_(module.weight, std=.02) |
| if module.bias is not None: |
| nn.init.zeros_(module.bias) |
| elif isinstance(module, nn.LayerNorm): |
| nn.init.constant_(module.bias, 0) |
| nn.init.constant_(module.weight, 1.0) |
|
|
    def forward_features(self, x, ids_restore):
        x = self.input_proj(x)
        # Append one mask token per removed patch (the +1 accounts for a
        # possible cls token; any surplus row is dropped by the gather below)
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
        if self.use_cls:
            x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        else:
            x_ = torch.cat([x[:, :, :], mask_tokens], dim=1)
        # Unshuffle the tokens back to the original patch order
        x_ = torch.gather(x_,
                          dim=1,
                          index=ids_restore.unsqueeze(-1).repeat(
                              1, 1, x.shape[2]))
        if self.use_cls:
            x = torch.cat([x[:, :1, :], x_], dim=1)
        else:
            x = x_
| t = x.shape[1] |
|
|
| x = x + self.pos_embed[:, :t, :] |
| x = self.pos_drop(x) |
| x = self.blocks(x) |
| x = self.norm(x) |
| return x |
|
|
| def forward(self, x, restore_idxs): |
| x = self.forward_features(x, restore_idxs) |
| x = self.outputlayer(x) |
| return x |
|
|
|
|
| class AudioTransformerMAE(nn.Module): |
|
|
| def __init__(self, |
| encoder: AudioTransformerMAE_Encoder, |
| decoder: AudioTransformerMAE_Decoder, |
| loss_fn: Optional[torch.nn.Module] = None): |
| super().__init__() |
| self.encoder = encoder |
| self.decoder = decoder |
| self.unfold = nn.Unfold( |
| kernel_size=self.encoder.patch_embed.patch_size, |
| stride=self.encoder.patch_embed.patch_size) |
| self.loss_fn = MAELoss() if loss_fn is None else loss_fn |
|
|
| def forward(self, |
| x: torch.Tensor, |
| mask_ratio: float = 0.75, |
| return_loss: bool = False): |
| latent, mask, restore_ids = self.encoder(x, mask_ratio=mask_ratio) |
| pred = self.decoder(latent, restore_ids) |
| with autocast(enabled=False): |
| targets = self.encoder.front_end(x) |
| targets = self.patchify(targets) |
| if return_loss: |
| return self.loss_fn(pred, targets, mask) |
| return pred, targets, mask |
|
|
    def patchify(self, x):
        # (B, mel, frames) spectrogram -> (B, num_patches, patch_f * patch_t)
        return self.unfold(x.unsqueeze(1)).transpose(-2, -1)
|
|
|
|
| def dasheng_base(**kwargs): |
| encoder_kwargs = dict(embed_dim=768, |
| depth=12, |
| num_heads=12, |
| target_length=1008, |
| patch_size=[64, 4], |
| patch_stride=[64, 4]) |
| encoder_kwargs.update( |
| (k, kwargs[k]) for k in set(kwargs).intersection(encoder_kwargs)) |
| encoder_kwargs = {**encoder_kwargs, **kwargs} |
| encoder = AudioTransformerMAE_Encoder(**encoder_kwargs) |
|
|
| decoder_kwargs = dict(embed_dim=512, |
| depth=8, |
| num_heads=16, |
| input_dim=encoder_kwargs['embed_dim'], |
| outputdim=encoder.patch_embed.patch_size[0] * |
| encoder.patch_embed.patch_size[1], |
| num_patches=encoder.patch_embed.num_patches) |
| decoder = AudioTransformerMAE_Decoder(**decoder_kwargs) |
| return AudioTransformerMAE(encoder, decoder) |
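

# End-to-end sketch (illustrative batch size and 2-second, 16 kHz inputs; not
# part of the original module): the model consumes raw waveforms, masks 75% of
# the mel patches by default, and returns either the (pred, target, mask)
# triplet or the scalar MAE loss.
def _dasheng_base_usage_sketch():
    model = dasheng_base()
    wav = torch.randn(2, 32000)                       # 2 s of 16 kHz audio
    loss = model(wav, mask_ratio=0.75, return_loss=True)
    pred, target, mask = model(wav, mask_ratio=0.75)  # patchified outputs
    return loss, pred.shape, target.shape, mask.shape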
|
|
|
|
def dasheng_06B(**kwargs):
    encoder_kwargs = dict(
        patch_size=[64, 4],
        patch_stride=[64, 4],
        embed_dim=1280,
        depth=32,
        num_heads=16,
        mlp_ratio=4,
    )
| encoder_kwargs.update( |
| (k, kwargs[k]) for k in set(kwargs).intersection(encoder_kwargs)) |
| encoder_kwargs = {**encoder_kwargs, **kwargs} |
| encoder = AudioTransformerMAE_Encoder(**encoder_kwargs) |
|
|
| decoder_kwargs = dict(embed_dim=512, |
| depth=8, |
| num_heads=16, |
| input_dim=encoder_kwargs['embed_dim'], |
| outputdim=encoder.patch_embed.patch_size[0] * |
| encoder.patch_embed.patch_size[1], |
| num_patches=encoder.patch_embed.num_patches) |
| decoder = AudioTransformerMAE_Decoder(**decoder_kwargs) |
| return AudioTransformerMAE(encoder, decoder) |
|
|
|
|
| def dasheng_12B(**kwargs): |
| encoder_kwargs = dict( |
| patch_size=[64, 4], |
| patch_stride=[64, 4], |
| embed_dim=1536, |
| depth=40, |
| num_heads=24, |
| mlp_ratio=4, |
| ) |
| encoder_kwargs.update( |
| (k, kwargs[k]) for k in set(kwargs).intersection(encoder_kwargs)) |
| encoder_kwargs = {**encoder_kwargs, **kwargs} |
| encoder = AudioTransformerMAE_Encoder(**encoder_kwargs) |
|
|
| decoder_kwargs = dict(embed_dim=768, |
| depth=8, |
| num_heads=24, |
| input_dim=encoder_kwargs['embed_dim'], |
| outputdim=encoder.patch_embed.patch_size[0] * |
| encoder.patch_embed.patch_size[1], |
| num_patches=encoder.patch_embed.num_patches) |
| decoder = AudioTransformerMAE_Decoder(**decoder_kwargs) |
| return AudioTransformerMAE(encoder, decoder) |
|
|