| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import random |
|
|
| import torch |
| import torchaudio |
| from einops import rearrange |
| from torch import einsum, nn |
| from torch.nn.common_types import _size_2_t |
| from transformers import PreTrainedModel |
|
|
| from .configuration_musicfm import MusicFMConfig, MusicFMInferenceConfig |
|
|
|
|
class MusicFM25Hz(PreTrainedModel):
    """Masked-token-modeling music foundation model (25 Hz token rate).

    Pipeline visible in this class: mel-spectrogram front end (MelSTFT) ->
    random-projection quantizers producing discrete training targets ->
    2-layer convolutional subsampling -> Wav2Vec2-Conformer encoder ->
    linear head producing per-codebook token logits.
    """

    config_class = MusicFMConfig

    def __init__(self, config: MusicFMConfig) -> None:
        super().__init__(config)

        # Hyperparameters lifted from the config onto the instance.
        self.num_codebooks = config.num_codebooks
        self.codebook_dim = config.codebook_dim
        self.codebook_size = config.codebook_size
        self.features = config.features  # feature names, e.g. "melspec_2048"
        self.hop_length = config.hop_length
        self.n_mels = config.n_mels
        self.conv_dim = config.conv_dim
        self.encoder_dim = config.encoder_dim
        self.encoder_depth = config.encoder_depth
        self.mask_hop = config.mask_hop
        self.mask_prob = config.mask_prob
        self.is_flash = config.is_flash
        self.stat = config.stat  # per-feature mean/std used by normalize()

        # Mel front end (dB-scaled). The attribute name must stay in sync
        # with the "preprocessor_%s" lookup performed in preprocessing().
        self.preprocessor_melspec_2048 = MelSTFT(
            n_fft=2048, hop_length=self.hop_length, is_db=True
        )

        # One frozen random-projection quantizer per (feature, codebook);
        # fixed seeds keep the random codebooks reproducible across runs.
        # NOTE(review): tokenize() looks these up as "quantizer_%s" % key,
        # WITHOUT the trailing codebook index added here — for
        # num_codebooks == 1 the registered name is "quantizer_<feature>_0",
        # which tokenize() would miss. Confirm against the configs in use.
        seed = 142
        for feature in self.features:
            for i in range(self.num_codebooks):
                setattr(
                    self,
                    f"quantizer_{feature}_{i}",
                    RandomProjectionQuantizer(
                        self.n_mels * 4,  # rearrange() stacks 4 mel frames
                        self.codebook_dim,
                        self.codebook_size,
                        seed=seed + i,
                    ),
                )

        # Two conv blocks subsample time by 4x before the conformer.
        self.conv = Conv2dSubsampling(
            1, self.conv_dim, self.encoder_dim, strides=[2, 2], n_bands=self.n_mels
        )

        # Conformer encoder; the flash variant is a local drop-in copy with
        # the same class names.
        if config.is_flash:
            from .flash_conformer import (
                Wav2Vec2ConformerConfig,
                Wav2Vec2ConformerEncoder,
            )
        else:
            from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
                Wav2Vec2ConformerConfig,
                Wav2Vec2ConformerEncoder,
            )

        # Start from the pretrained conformer config, then override depth/size.
        conformer_config = Wav2Vec2ConformerConfig.from_pretrained(
            "facebook/wav2vec2-conformer-rope-large-960h-ft"
        )
        conformer_config.num_hidden_layers = self.encoder_depth
        conformer_config.hidden_size = self.encoder_dim
        self.conformer = Wav2Vec2ConformerEncoder(conformer_config)

        # Projection to token logits; encoder() slices the output per feature.
        self.linear = nn.Linear(self.encoder_dim, self.codebook_size)

        # Cross-entropy applied to masked positions only (see get_loss()).
        self.loss = nn.CrossEntropyLoss()

        # cls token embedding; not referenced anywhere else in this file —
        # NOTE(review): confirm whether a subclass or caller consumes it.
        random.seed(seed)
        self.cls_token = nn.Parameter(torch.randn(self.encoder_dim))

    def masking(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.LongTensor]:
        """random masking of 400ms with given probability"""
        # x: (batch, time) raw waveform; the 24000 below hard-codes a 24 kHz
        # sample rate — TODO confirm against the feature-extraction settings.
        mx = x.clone()
        b, t = mx.shape
        # Span lengths in raw samples and in encoder tokens (the conv front
        # end divides the frame rate by 2 * 2).
        len_masking_raw = int(24000 * self.mask_hop)
        len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop)

        # Draw one Bernoulli start flag per mask-hop window, then expand to
        # per-sample and per-token index lists.
        start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob
        time_domain_masked_indices = torch.nonzero(
            start_indices.repeat_interleave(len_masking_raw, dim=1)
        )
        token_domain_masked_indices = torch.nonzero(
            start_indices.repeat_interleave(len_masking_token, dim=1)
        )

        # Replace masked samples with low-amplitude Gaussian noise.
        masking_noise = (
            torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1
        )
        mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device)

        return mx, token_domain_masked_indices

    @torch.no_grad()
    def preprocessing(
        self, x: torch.Tensor, features: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        """extract classic audio features"""
        # NOTE(review): the annotation says dict, but callers pass a list of
        # feature-name strings (see get_predictions); only iteration is used.
        # Track input precision so fp16 inputs come back as fp16; the mel
        # transform itself always runs in fp32.
        if x.dtype == torch.float16:
            precision = 16
        else:
            precision = 32

        out = {}
        for key in features:
            # Dynamic lookup of the matching "preprocessor_<key>" module.
            # layer.float() casts that module to fp32 in place (a persistent
            # side effect on the submodule).
            layer = getattr(self, "preprocessor_%s" % key)
            out[key] = layer.float()(x.float())[..., :-1]  # drop last frame
            if precision == 16:
                out[key] = out[key].half()
        return out

    def encoder(self, x: torch.Tensor) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        """2-layer conv + w2v-conformer"""
        x = self.conv(x)  # subsample time by 4x, project to encoder_dim
        out = self.conformer(x, output_hidden_states=True)
        hidden_emb = out["hidden_states"]  # all per-layer embeddings
        last_emb = out["last_hidden_state"]
        logits = self.linear(last_emb)
        # Slice a codebook_size-wide chunk per feature. With a single feature
        # (i == 0) the slice [0:codebook_size] covers the whole logit vector.
        logits = {
            key: logits[:, :, i * self.codebook_size : (i + 1) * self.codebook_size]
            for i, key in enumerate(self.features)
        }
        return logits, hidden_emb

    @torch.no_grad()
    def normalize(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """normalize the input audio to have zero mean unit variance"""
        # Uses precomputed global stats from config.stat, keyed
        # "<feature>_mean" / "<feature>_std". Mutates x in place.
        for key in x.keys():
            x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key]
        return x

    @torch.no_grad()
    def rearrange(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """rearrange the batch to flatten every 4 steps"""
        for key in x.keys():
            if key == "chromagram":
                # chromagram is only transposed, not frame-stacked
                x[key] = rearrange(x[key], "b f t -> b t f")
            else:
                # Stack 4 consecutive frames: (b, f, 4*t') -> (b, t', 4*f).
                # This is why the quantizers take n_mels * 4 input features.
                x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=4)

        return x

    @torch.no_grad()
    def tokenize(self, x: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Quantize each feature map into discrete target token ids."""
        out = {}
        for key in x.keys():
            # NOTE(review): __init__ registers "quantizer_<feature>_<i>";
            # this lookup omits the "<i>" suffix — confirm it resolves.
            layer = getattr(self, "quantizer_%s" % key)
            out[key] = layer(x[key])
        return out

    def get_targets(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Compute discrete target tokens from the (unmasked) waveform."""
        x = self.preprocessing(x, features=self.features)
        x = self.normalize(x)
        x = self.rearrange(x)
        target_tokens = self.tokenize(x)

        return target_tokens

    def get_predictions(
        self, x: torch.Tensor
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
        """Run the model front end + encoder on waveform x.

        Returns (per-feature logits, tuple of hidden-layer embeddings).
        """
        # Preprocessing: only the mel feature feeds the encoder.
        x = self.preprocessing(x, features=["melspec_2048"])
        x = self.normalize(x)

        # Encoding
        logits, hidden_emb = self.encoder(x["melspec_2048"])

        return logits, hidden_emb

    def get_latent(self, x: torch.Tensor, layer_ix: int = 12) -> torch.Tensor:
        """Return the hidden embedding of conformer layer `layer_ix`."""
        _, hidden_states = self.get_predictions(x)
        emb = hidden_states[layer_ix]
        return emb

    def get_loss(
        self,
        logits: dict[str, torch.Tensor],
        target_tokens: dict[str, torch.Tensor],
        masked_indices: torch.LongTensor,
    ) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
        """Cross-entropy loss and accuracy, evaluated on masked positions only."""
        losses = {}
        accuracies = {}
        for key in logits.keys():
            # Select only the masked (batch, token) positions.
            masked_logits = logits[key][tuple(masked_indices.t())]
            masked_tokens = target_tokens[key][tuple(masked_indices.t())]
            losses[key] = self.loss(masked_logits, masked_tokens)
            # Fraction of masked positions predicted exactly right.
            accuracies[key] = (
                torch.sum(masked_logits.argmax(-1) == masked_tokens)
                / masked_tokens.numel()
            )
        return losses, accuracies

    def forward(
        self, x: torch.Tensor
    ) -> tuple[
        dict[str, torch.Tensor],
        torch.Tensor,
        dict[str, torch.Tensor],
        dict[str, torch.Tensor],
    ]:
        """One masked-token-modeling training step on waveform batch x.

        Returns (logits, hidden_emb, losses, accuracies).
        """
        # Get discrete targets from the clean (unmasked) audio.
        target_tokens = self.get_targets(x)

        # Random masking of the raw waveform.
        x, masked_indices = self.masking(x)

        # Forward pass on the masked audio.
        logits, hidden_emb = self.get_predictions(x)

        # Loss/accuracy restricted to the masked token positions.
        losses, accuracies = self.get_loss(logits, target_tokens, masked_indices)

        return logits, hidden_emb, losses, accuracies
|
|
|
|
class MusicFM25HzInference(MusicFM25Hz):
    """Inference-only wrapper: forward() returns one hidden layer's embeddings."""

    config_class = MusicFMInferenceConfig

    def __init__(self, config: MusicFMInferenceConfig) -> None:
        super().__init__(config)
        # Which conformer hidden state forward() returns.
        self.layer_index = config.layer_index

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the hidden embedding at ``self.layer_index`` for waveform x."""
        _, hidden_emb = self.get_predictions(x)
        return hidden_emb[self.layer_index]
|
|
|
|
class MelSTFT(nn.Module):
    """Mel-spectrogram front end with optional amplitude-to-dB scaling.

    Args:
        sample_rate: audio sample rate in Hz.
        n_fft: FFT size.
        hop_length: STFT hop length in samples.
        n_mels: number of mel bands.
        is_db: when True, convert amplitudes to decibels.
    """

    def __init__(
        self,
        sample_rate: int = 24000,
        n_fft: int = 2048,
        hop_length: int = 240,
        n_mels: int = 128,
        is_db: bool = False,
    ) -> None:
        super().__init__()

        # Core mel-spectrogram transform.
        self.mel_stft = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            n_mels=n_mels,
        )

        # Optional dB conversion stage.
        self.is_db = is_db
        if is_db:
            self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Return the (optionally dB-scaled) mel spectrogram of *waveform*."""
        spec = self.mel_stft(waveform)
        return self.amplitude_to_db(spec) if self.is_db else spec
|
|
|
|
class RandomProjectionQuantizer(nn.Module):
    """
    Random projection and codebook lookup module

    Some code is borrowed from:
    https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py
    But I did normalization using pre-computed global mean & variance instead of using layer norm.
    """

    def __init__(
        self,
        input_dim: int,
        codebook_dim: int,
        codebook_size: int,
        seed: int = 142,
    ) -> None:
        super().__init__()

        # Seed inside fork_rng so constructing the module no longer clobbers
        # the caller's global CPU RNG state (previously torch.manual_seed was
        # called unguarded, a hidden global side effect). The draw order and
        # seed are unchanged, so the buffers are bit-identical to before.
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(seed)

            # Fixed (non-trainable) random projection matrix.
            random_projection = torch.empty(input_dim, codebook_dim)
            nn.init.xavier_normal_(random_projection)

            # Fixed random codebook.
            codebook = torch.empty(codebook_size, codebook_dim)
            nn.init.normal_(codebook)

        self.register_buffer("random_projection", random_projection)
        self.register_buffer("codebook", codebook)

    def codebook_lookup(self, x: torch.Tensor) -> torch.Tensor:
        """Map each projected vector to the index of its nearest codeword.

        Args:
            x: projected features of shape (batch, n, codebook_dim).

        Returns:
            LongTensor of codeword indices, shape (batch, n).
        """
        # Flatten (b, n, e) -> (b*n, e) for a single cdist call.
        b = x.shape[0]
        x = x.reshape(-1, x.shape[-1])

        # L2-normalize both sides so Euclidean distance ranks like cosine
        # similarity.
        normalized_x = nn.functional.normalize(x, dim=1, p=2)
        normalized_codebook = nn.functional.normalize(self.codebook, dim=1, p=2)

        # Pairwise distances: (codebook_size, b*n).
        distances = torch.cdist(normalized_codebook, normalized_x)

        # Nearest codeword per input vector.
        nearest_indices = torch.argmin(distances, dim=0)

        # Restore the batch dimension: (b*n,) -> (b, n).
        xq = nearest_indices.reshape(b, -1)

        return xq

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Quantize (batch, n, input_dim) features into codeword indices."""
        # The quantizer is frozen; always run in eval mode.
        self.eval()

        # Random projection: equivalent to einsum("b n d, d e -> b n e").
        x = x @ self.random_projection

        # Codebook lookup.
        xq = self.codebook_lookup(x)

        return xq
|
|
|
|
class Res2dModule(nn.Module):
    """Residual 2-D conv block: conv-bn-relu-conv-bn plus a shortcut.

    Args:
        idim: number of input channels.
        odim: number of output channels.
        stride: stride of the first (and shortcut) convolution. An int is
            expanded to (stride, stride), honoring the ``_size_2_t``
            annotation (previously an int stride crashed on ``stride[0]``).
    """

    def __init__(self, idim: int, odim: int, stride: _size_2_t = (2, 2)) -> None:
        super().__init__()
        # Normalize int strides so the stride[0] check below always works.
        if isinstance(stride, int):
            stride = (stride, stride)
        self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
        self.bn1 = nn.BatchNorm2d(odim)
        self.conv2 = nn.Conv2d(odim, odim, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(odim)
        self.relu = nn.ReLU()

        # The shortcut needs its own conv+bn whenever the residual path
        # changes channel count or spatial size; otherwise it is identity.
        self.diff = False
        if (idim != odim) or (stride[0] > 1):
            self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride)
            self.bn3 = nn.BatchNorm2d(odim)
            self.diff = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the residual block; output is ReLU(shortcut(x) + main(x))."""
        out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x)))))
        if self.diff:
            x = self.bn3(self.conv3(x))
        out = x + out
        out = self.relu(out)
        return out
|
|
|
|
class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

    Two residual conv blocks, each with a fixed frequency stride of 2 and a
    configurable time stride, followed by a linear projection.

    Args:
        idim (int): Input dimension (conv input channels).
        hdim (int): Hidden dimension (conv channel count).
        odim (int): Output dimension.
        strides (list): Time-axis strides of the two conv blocks; defaults
            to [2, 2]. (The default was previously a shared mutable list
            argument; it is now created per call.)
        n_bands (int): Number of frequency bands.
    """

    def __init__(
        self,
        idim: int,
        hdim: int,
        odim: int,
        strides: list[int] | None = None,
        n_bands: int = 64,
    ) -> None:
        """Construct an Conv2dSubsampling object."""
        super().__init__()

        # Avoid the mutable-default-argument pitfall.
        if strides is None:
            strides = [2, 2]

        self.conv = nn.Sequential(
            Res2dModule(idim, hdim, (2, strides[0])),
            Res2dModule(hdim, hdim, (2, strides[1])),
        )
        # The frequency axis is halved twice by the fixed stride of 2 above.
        self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, idim, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
            where time' = time // 4 (with the default strides).
        """
        # Add a channel axis for 3-D inputs: (b, f, t) -> (b, 1, f, t).
        if x.dim() == 3:
            x = x.unsqueeze(1)

        x = self.conv(x)
        # (b, c, f, t) -> (b, t, c*f); equivalent to
        # einops.rearrange(x, "b c f t -> b t (c f)") without the dependency.
        x = x.permute(0, 3, 1, 2).flatten(2)
        x = self.linear(x)

        return x
|
|