| from functools import partial |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| import torch.utils.checkpoint |
| from timm.models.swin_transformer import SwinTransformerBlock |
| from timm.models.vision_transformer import Block |
| from timm.models.layers import to_2tuple |
|
|
|
|
class PatchEmbed(nn.Module):
    """Image to patch embedding.

    Cuts the input into non-overlapping patch_size x patch_size tiles and maps
    each tile to an embed_dim vector with one strided convolution.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        size = to_2tuple(img_size)
        patch = to_2tuple(patch_size)
        rows = size[0] // patch[0]
        cols = size[1] // patch[1]

        # NOTE(review): despite its name this tuple is (W // p, H // p), i.e.
        # (columns, rows) -- the reverse of the (H // p, W // p) token layout
        # forward() produces. Left unchanged because the positional-embedding
        # setup consumes this exact tuple; confirm before "fixing" it.
        self.patch_hw = (cols, rows)
        self.img_size = size
        self.patch_size = patch
        self.num_patches = cols * rows

        # kernel == stride == patch size, so each output position is one patch.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch, stride=patch)

    def forward(self, x):
        """
        :param x: [B, C, H, W]
        :return: [B, (H // p) * (W // p), embed_dim], patches in row-major order
        """
        _b, _c, _h, _w = x.shape  # unpack only to insist on a 4-D input
        feats = self.proj(x)  # [B, embed_dim, H // p, W // p]
        return feats.flatten(2).transpose(1, 2)
|
|
|
|
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode scalar positions with fixed sine/cosine features.

    embed_dim: output width per position (must be even)
    pos: array of positions, any shape; flattened to (M,)
    out: (M, embed_dim) -- the first embed_dim/2 columns are sin terms, the
         remaining columns the matching cos terms
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Geometric frequency ladder: 1 / 10000 ** (k / half), k = 0 .. half-1.
    exponents = np.arange(half, dtype=float) / (embed_dim / 2.)
    freqs = 1. / 10000 ** exponents
    angles = np.outer(pos.reshape(-1), freqs)  # (M, half)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)
|
|
|
|
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Build 2-D sin/cos embeddings from a stacked coordinate grid.

    embed_dim: output width per position (must be even)
    grid: indexable pair of coordinate arrays; grid[0] and grid[1] are each
          encoded with embed_dim // 2 channels
    return: (M, embed_dim), the grid[0] encoding followed by the grid[1] encoding
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    parts = [get_1d_sincos_pos_embed_from_grid(half, coords) for coords in (grid[0], grid[1])]
    return np.concatenate(parts, axis=1)
|
|
|
|
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    Fixed 2-D sin/cos positional embeddings for a square grid.

    grid_size: int, number of rows (and columns) of the grid
    cls_token: if True, prepend one all-zero row for a class token
    return:
    pos_embed: [grid_size*grid_size, embed_dim] without cls_token,
               [1+grid_size*grid_size, embed_dim] with it
    """
    coords = np.arange(grid_size, dtype=np.float32)
    # meshgrid(x, y) returns (X, Y); X -- the column coordinate -- lands in grid[0].
    grid = np.stack(np.meshgrid(coords, coords), axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if not cls_token:
        return pos_embed
    cls_row = np.zeros([1, embed_dim])
    return np.concatenate([cls_row, pos_embed], axis=0)
|
|
|
|
def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
    """
    Fixed 2-D sin/cos positional embeddings for a rectangular grid.

    grid_size: indexable pair; grid_size[0] is the number of rows and
               grid_size[1] the number of columns
    cls_token: if True, prepend one all-zero row for a class token
    return:
    pos_embed: [rows*cols, embed_dim] without cls_token,
               [1+rows*cols, embed_dim] with it; positions in row-major order
    """
    n_rows = grid_size[0]
    n_cols = grid_size[1]
    row_coords = np.arange(n_rows, dtype=np.float32)
    col_coords = np.arange(n_cols, dtype=np.float32)
    # meshgrid(x, y) -> (X, Y), each (n_rows, n_cols); X (column coordinate) is grid[0].
    grid = np.stack(np.meshgrid(col_coords, row_coords), axis=0)
    grid = grid.reshape([2, 1, n_rows, n_cols])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if not cls_token:
        return pos_embed
    cls_row = np.zeros([1, embed_dim])
    return np.concatenate([cls_row, pos_embed], axis=0)
|
|
|
|
class SwinTransformerBlockWrapper(torch.nn.Module):
    """
    Adapt a SwinTransformerBlock to token-sequence inputs.

    The ViT-style blocks in this file consume [B, N, C], while the wrapped
    SwinTransformerBlock consumes [B, H, W, C] with (H, W) taken from
    block.input_resolution. This wrapper reshapes on the way in and out so
    both kinds of block can sit in the same ModuleList.
    """

    def __init__(self, block: SwinTransformerBlock):
        super().__init__()
        self.block = block
        self.input_resolution = block.input_resolution

    def forward(self, x):
        """
        :param x: [B, N, C], N must equal the product of input_resolution
        :return: [B, N, C]
        """
        batch, tokens, channels = x.shape
        spatial = x.reshape(batch, *self.input_resolution, channels)
        out = self.block(spatial)
        return out.reshape(batch, tokens, channels)
|
|
|
|
class MaskedAutoencoderViT(nn.Module):
    """ Masked Autoencoder with VisionTransformer backbone.

    The encoder embeds patches of an [N, C, H, W] input, keeps a random subset
    of them (see random_masking) and runs ViT blocks over [cls] + kept patches.
    The decoder re-inserts a learned mask token at the dropped positions, adds
    fixed sin/cos positional embeddings and predicts the p*p*C values of every
    patch; forward_loss scores only the masked patches.

    decoder_mode:
      0     -- ViT decoder blocks over [cls] + patch tokens; the cls row is
               dropped from the prediction.
      1     -- 16 Swin blocks (wrapped to [B, N, C]) over patch tokens only.
      other -- ViT decoder blocks over patch tokens only.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        decoder_mode=0,
        no_shift=False,
        decoder_embed_dim=512,
        decoder_depth=8,
        decoder_num_heads=16,
        mlp_ratio=4.,
        norm_layer=nn.LayerNorm,
        norm_pix_loss=False,
        pos_trainable=False,
    ):
        super().__init__()

        self.img_size = to_2tuple(img_size)
        self.embed_dim = embed_dim
        self.decoder_embed_dim = decoder_embed_dim

        # ---------------- encoder ----------------
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Row 0 belongs to the cls token. Filled with fixed sin/cos values in
        # initialize_weights(); frozen unless pos_trainable.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim),
                                      requires_grad=pos_trainable)

        self.encoder_depth = depth
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
            for _ in range(depth)])
        self.norm = norm_layer(embed_dim)

        # ---------------- decoder ----------------
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim),
                                              requires_grad=pos_trainable)

        self.no_shift = no_shift
        self.decoder_mode = decoder_mode

        if self.decoder_mode == 1:
            window_size = (4, 4)
            # Token grid seen by the Swin blocks: (H // p, W // p), the same
            # row-major layout PatchEmbed.forward produces.
            feat_size = (self.img_size[0] // patch_size, self.img_size[1] // patch_size)
            # NOTE(review): depth (16) and head count (16) are fixed in this
            # mode, so decoder_depth / decoder_num_heads are ignored. Presumably
            # kept so existing decoder_mode=1 checkpoints still load -- confirm
            # before changing.
            decoder_modules = []
            for index in range(16):
                # Alternate plain and shifted windows unless no_shift is set.
                if self.no_shift or index % 2 == 0:
                    shift_size = (0, 0)
                else:
                    shift_size = (2, 0)
                decoder_modules.append(
                    SwinTransformerBlockWrapper(
                        SwinTransformerBlock(
                            dim=decoder_embed_dim,
                            input_resolution=feat_size,
                            num_heads=16,
                            window_size=window_size,
                            shift_size=shift_size,
                            mlp_ratio=mlp_ratio,
                            proj_drop=0.0,
                            attn_drop=0.0,
                            drop_path=0.0,
                            norm_layer=norm_layer,
                        )
                    )
                )
            self.decoder_blocks = nn.ModuleList(decoder_modules)
        else:
            self.decoder_blocks = nn.ModuleList([
                Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
                for _ in range(decoder_depth)])

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size ** 2 * in_chans, bias=True)

        self.norm_pix_loss = norm_pix_loss
        self.patch_size = patch_size

        self.initialize_weights()

    def initialize_weights(self):
        """(Re)initialize parameters.

        Positional embeddings receive fixed 2-D sin/cos values (zero row for
        cls), the patch projection is Xavier-initialized as if it were a Linear
        layer, cls/mask tokens are drawn from N(0, 0.02**2), and every
        submodule is passed through _init_weights.
        """
        pos_embed = get_2d_sincos_pos_embed_flexible(
            self.pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        decoder_pos_embed = get_2d_sincos_pos_embed_flexible(
            self.decoder_pos_embed.shape[-1], self.patch_embed.patch_hw, cls_token=True)
        self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))

        # View the conv kernel as a [embed_dim, C*p*p] matrix for Xavier init.
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform Linear weights with zero bias; LayerNorm reset to weight 1, bias 0."""
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # LayerNorms built without affine parameters have None here.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    def patchify(self, imgs):
        """
        Split images into flattened, non-overlapping p x p patches.

        imgs: (N, C, H, W), H and W must be multiples of p
        x: (N, L, p*p*C) with L = (H/p)*(W/p); patches in row-major order,
           each flattened as (row, col, channel)
        """
        p = self.patch_size
        n, c = imgs.shape[0], imgs.shape[1]
        h = imgs.shape[2] // p
        w = imgs.shape[3] // p
        x = imgs.reshape(n, c, h, p, w, p)
        x = torch.einsum('nchpwq->nhwpqc', x)
        return x.reshape(n, h * w, p ** 2 * c)

    def unpatchify(self, x):
        """
        Inverse of patchify for inputs of size self.img_size.

        x: (N, L, p*p*C) with L = (H/p)*(W/p), (H, W) = self.img_size
        specs: (N, C, H, W)
        """
        p = self.patch_size
        h = self.img_size[0] // p
        w = self.img_size[1] // p
        c = x.shape[-1] // (p * p)
        x = x.reshape(x.shape[0], h, w, p, p, c)
        x = torch.einsum('nhwpqc->nchpwq', x)
        specs = x.reshape(x.shape[0], c, h * p, w * p)
        return specs

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.

        x: [N, L, D], sequence
        mask_ratio: fraction of the L tokens to drop; int(L * (1 - mask_ratio)) are kept
        return: x_masked [N, L_keep, D], mask [N, L] (0 keep, 1 removed),
                ids_restore [N, L] (inverse permutation of the shuffle)
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)

        # Ascending sort: the lowest-noise tokens are kept.
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Build the mask in shuffled order, then unshuffle it into token order.
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        """
        :param x: [N, C, H, W]
        :param mask_ratio: float. ratio of masked patches
        :return: tuple. x: [N, 1 + L_keep, D] (cls token first), mask: [N, L], ids_restore: [N, L]
        """
        x = self.patch_embed(x)
        num_tokens = x.shape[1]

        # Positions are added before masking so kept tokens keep their own embedding.
        x = x + self.pos_embed[:, 1:num_tokens + 1, :]

        x, mask, ids_restore = self.random_masking(x, mask_ratio)

        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x, mask, ids_restore

    def forward_encoder_no_mask(
        self,
        x,
        header='mean'
    ):
        """
        Encode every patch (no masking) and pool to one vector per input.

        :param x: [N, C, H, W]
        :param header: str. 'cls' returns the cls token, 'mean' averages the patch tokens
        :return: emb: [N, D]
        :raises NotImplementedError: for any other header
        """
        x = self.patch_embed(x)
        num_tokens = x.shape[1]

        x = x + self.pos_embed[:, 1:num_tokens + 1, :]

        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        if header == 'cls':
            emb = x[:, 0, :]
        elif header == 'mean':
            emb = x[:, 1:, :].mean(dim=1)
        else:
            raise NotImplementedError(f"unsupported header: {header!r}")

        return emb

    def forward_decoder(self, x, ids_restore):
        """
        :param x: [N, 1 + L_keep, D_enc], encoder output with the cls token first
        :param ids_restore: [N, L]
        :return: pred: [N, L, p*p*C]
        """
        x = self.decoder_embed(x)

        # Append one mask token per dropped patch, then unshuffle back to token order.
        num_masked = ids_restore.shape[1] + 1 - x.shape[1]
        mask_tokens = self.mask_token.repeat(x.shape[0], num_masked, 1)
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))
        x = torch.cat([x[:, :1, :], x_], dim=1)

        x = x + self.decoder_pos_embed[:, :x.shape[1], :]

        # Only decoder_mode 0 runs the decoder over the cls token.
        if self.decoder_mode != 0:
            x = x[:, 1:, :]

        # decoder_blocks is always an nn.ModuleList, which has no forward() of
        # its own, so it must be iterated in every mode.
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        pred = self.decoder_pred(x)

        if self.decoder_mode == 0:
            pred = pred[:, 1:, :]

        return pred

    def forward_loss(self, imgs, pred, mask, norm_pix_loss=False):
        """
        Mean-squared reconstruction error over masked patches only.

        imgs: [N, C, H, W]
        pred: [N, L, p*p*C]
        mask: [N, L], 0 is keep, 1 is remove
        norm_pix_loss: normalize each target patch to zero mean / unit variance first
        return: scalar tensor
        """
        target = self.patchify(imgs)
        if norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.e-6) ** .5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], per-patch error

        # NOTE: if no patch is masked (e.g. mask_ratio == 0), mask.sum() is 0 and the result is NaN.
        loss = (loss * mask).sum() / mask.sum()
        return loss

    def forward(self, imgs, mask_ratio=0.8):
        """
        :param imgs: [N, C, H, W]
        :param mask_ratio: float. ratio of masked patches
        :return: tuple. loss_recon: scalar tensor, pred: [N, L, p*p*C], mask: [N, L]
        """
        emb_enc, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
        pred = self.forward_decoder(emb_enc, ids_restore)
        loss_recon = self.forward_loss(imgs, pred, mask, norm_pix_loss=self.norm_pix_loss)
        return loss_recon, pred, mask
|
|
|
|
if __name__ == '__main__':
    device = 'cpu'

    # Configuration for the 'music-mae-32kHz.pth' checkpoint (presumably the one
    # it was trained with; the strict load_state_dict below fails otherwise):
    # 2048 x 128 single-channel inputs, 768-dim / 12-layer / 12-head encoder,
    # Swin decoder (decoder_mode=1).
    audio_mae = MaskedAutoencoderViT(
        img_size=(2048, 128),
        patch_size=16,
        in_chans=1,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_mode=1,
        no_shift=False,
        decoder_embed_dim=512,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        norm_pix_loss=False,
        pos_trainable=False,
    )

    ckpt_path = 'music-mae-32kHz.pth'
    # SECURITY: torch.load unpickles the file -- only load checkpoints from trusted sources.
    audio_mae.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    audio_mae.to(device)
    audio_mae.eval()

    x = torch.randn(4, 1, 2048, 128).to(device)

    # Inference only: skip autograd graph construction to save memory.
    with torch.no_grad():
        emb = audio_mae.forward_encoder_no_mask(x, header='mean')
    print(emb.shape)  # [4, 768]
|
|