| from typing import Union, Tuple |
| import numpy as np |
| from numpy.typing import NDArray |
| import torch |
| from torch import nn |
| from functools import partial |
| import matplotlib.pyplot as plt |
| from PIL import Image |
| import librosa |
| import miniaudio |
|
|
| from mae import MaskedAutoencoderViT |
|
|
|
|
def load_audio(
    path: str,
    sr: int = 32000,
    duration: float = 20,
) -> np.ndarray:
    """Decode up to ``duration`` seconds of an audio file as mono float32 samples.

    miniaudio decodes the file, downmixes it to one channel and resamples it
    to ``sr`` Hz. Only the first ``round(sr * duration)`` frames are read. A
    file shorter than that yields a shorter array, so callers that need a
    fixed length must pad.

    Args:
        path: Path to the audio file. A leading ``~`` is expanded to the
            user's home directory.
        sr: Target sample rate in Hz.
        duration: Maximum number of seconds to read from the start of the file.

    Returns:
        1-D float32 array of samples.

    Raises:
        ValueError: If no audio frames could be decoded from the file.
    """
    import os

    # miniaudio passes the path straight to its C decoder, which does no
    # shell-style "~" expansion, so expand it here.
    path = os.path.expanduser(path)
    num_frames = int(round(sr * duration))
    stream = miniaudio.stream_file(
        path,
        output_format=miniaudio.SampleFormat.FLOAT32,
        nchannels=1,
        sample_rate=sr,
        frames_to_read=num_frames,
    )
    try:
        chunk = next(stream)
    except StopIteration:
        raise ValueError(f"no audio decoded from {path!r}") from None
    finally:
        # Closing the generator releases the underlying decoder right away
        # instead of leaving it to garbage collection.
        stream.close()
    return np.asarray(chunk, dtype=np.float32)
|
|
|
|
def mel_spectrogram(
    signal: np.ndarray,
    sr: int = 32000,
    n_fft: int = 800,
    hop_length: int = 320,
    n_mels: int = 128,
) -> np.ndarray:
    """Return the dB-scaled mel spectrogram of ``signal``, laid out time-major.

    The power mel spectrogram comes from librosa with a Hann window and
    zero ("constant") padding at the edges. It is converted to decibels with
    ``librosa.power_to_db`` and then transposed, so for a 1-D signal rows
    index frames and columns index mel bands.

    Args:
        signal: Audio samples (presumably 1-D mono; TODO confirm for other shapes).
        sr: Sample rate of ``signal`` in Hz.
        n_fft: FFT window length in samples.
        hop_length: Step between successive frames in samples.
        n_mels: Number of mel bands.
    """
    power = librosa.feature.melspectrogram(
        y=signal,
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        window='hann',
        pad_mode='constant',
    )
    decibels = librosa.power_to_db(power)
    # librosa emits (n_mels, n_frames); reverse the axes to get (n_frames, n_mels).
    return np.transpose(decibels)
|
|
|
|
def display_image(
    img: Union[NDArray, Image.Image],
    figsize: Tuple[float, float] = (5, 5),
) -> None:
    """Show ``img`` in a new figure with a colour bar and the axes hidden.

    ``origin='lower'`` places row 0 at the bottom, which suits spectrograms
    but flips ordinary photographs vertically. ``aspect='auto'`` stretches
    the image to fill the figure instead of keeping square pixels.
    """
    fig, ax = plt.subplots(figsize=figsize)
    shown = ax.imshow(img, aspect='auto', origin='lower')
    ax.set_axis_off()
    fig.colorbar(shown, ax=ax)
    fig.tight_layout()
    plt.show()
|
|
|
|
def normalize(arr: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Standardise ``arr`` to zero mean and (almost) unit standard deviation.

    The mean and standard deviation are computed over all elements together.
    ``eps`` keeps the division finite, so a constant input maps to all zeros.
    """
    centered = arr - arr.mean()
    scale = arr.std() + eps
    return centered / scale
|
|
|
|
def pad_or_truncate(spec: np.ndarray, target_length: int) -> np.ndarray:
    """Return ``spec`` cut or padded to exactly ``target_length`` entries on axis 0.

    Longer inputs are truncated at the end. Shorter inputs are padded at the
    end of axis 0, with every other axis untouched, using ``spec.min()``. For
    a dB spectrogram that is the quietest bin present, so the padding adds no
    energy. An empty input is padded with 0 instead.
    """
    length = spec.shape[0]
    if length >= target_length:
        return spec[:target_length]
    pad_width = [(0, target_length - length)] + [(0, 0)] * (spec.ndim - 1)
    # min() of an empty array raises, so fall back to 0 in that case.
    fill = spec.min() if spec.size else 0
    return np.pad(spec, pad_width, mode='constant', constant_values=fill)


if __name__ == '__main__':
    import os

    # miniaudio does not expand "~" itself, so expand it before loading.
    mp3_file = os.path.expanduser("~/Downloads/songs/See You Again.mp3")
    # At sr=32000 and hop_length=320 (the helpers' defaults) there are about
    # 100 frames per second. 21 s therefore gives a little more than the
    # 2048 frames the model takes, and the surplus is cut below.
    mel_spec = mel_spectrogram(load_audio(mp3_file, duration=21))

    target_length = 2048
    mel_spec = pad_or_truncate(mel_spec, target_length)
    mel_spec = normalize(mel_spec)

    display_image(mel_spec.T, figsize=(10, 4))

    mae = MaskedAutoencoderViT(
        # Tie the model's time dimension to the crop length above; 128 matches
        # mel_spectrogram's default n_mels.
        img_size=(target_length, 128),
        patch_size=16,
        in_chans=1,
        embed_dim=768,
        depth=12,
        num_heads=12,
        decoder_mode=1,
        no_shift=False,
        decoder_embed_dim=512,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        norm_pix_loss=False,
        pos_trainable=False,
    )

    ckpt_path = 'music-mae-32kHz.pth'
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    mae.load_state_dict(torch.load(ckpt_path, map_location='cpu'))

    device = 'cpu'
    mae.to(device)
    # A freshly built module is in training mode; switch to eval mode so any
    # dropout / drop-path layers behave deterministically at inference time.
    mae.eval()

    # Add batch and channel axes -> (1, 1, target_length, n_mels), and match
    # the model's float32 weights explicitly.
    x = torch.from_numpy(mel_spec).float().unsqueeze(0).unsqueeze(0).to(device)

    with torch.no_grad():
        mse_loss, y, mask = mae(x, mask_ratio=0.7)

        # Keep the model's prediction only for masked patches (mask == 1) and
        # paste the original patches back where the encoder could see them.
        visible = mask == 0.
        y[visible] = mae.patchify(x)[visible]
        x_reconstructed = mae.unpatchify(y).squeeze(0).squeeze(0).cpu().numpy()

    print(f'mse_loss: {mse_loss.item()}')
    display_image(x_reconstructed.T, figsize=(10, 4))
|
|