| import torch |
| import torch.nn.functional as F |
| from torch import nn, einsum |
|
|
| from torchaudio.transforms import Spectrogram, TimeStretch, FrequencyMasking, TimeMasking |
|
|
| from audiolm_pytorch import AudioLM |
| from audiolm_pytorch.utils import AudioConditionerBase |
|
|
| from x_clip.tokenizer import tokenizer |
| from vector_quantize_pytorch import ResidualVQ |
|
|
| from einops import rearrange, repeat, reduce, pack, unpack |
|
|
| from beartype.typing import List, Optional, Tuple |
| from beartype import beartype |
|
|
| |
|
|
def exists(val):
    """Return True when val is anything other than None."""
    if val is None:
        return False
    return True
|
|
def default(val, d):
    """Return val unless it is None, in which case fall back to d."""
    if val is None:
        return d
    return val
|
|
def round_down_nearest_multiple(n, divisor):
    """Round n down to the nearest multiple of divisor (floor semantics)."""
    # identical to (n // divisor) * divisor for Python's floor division
    return n - (n % divisor)
|
|
| |
|
|
def log(t, eps = 1e-20):
    """Numerically safe log: clamps t from below at eps before taking torch.log."""
    clamped = torch.clamp(t, min = eps)
    return torch.log(clamped)
|
|
def l2norm(t):
    """Scale t to unit L2 norm along the last axis (eps-guarded for zero vectors)."""
    # F.normalize defaults to p = 2
    return F.normalize(t, dim = -1)
|
|
| |
| |
|
|
def posemb_sincos_2d(patches, temperature = 10000, dtype = None):
    """Build a 2d sine-cosine positional embedding for a grid of patches.

    Args:
        patches: tensor of shape (batch, height, width, dim); only shape,
            device and dtype are read — values are not used
        temperature: frequency scaling constant for the sinusoids
        dtype: output dtype; defaults to patches.dtype. (bugfix: previously
            this parameter was silently shadowed by the unpacking below and
            never honored)

    Returns:
        tensor of shape (height, width, dim)
    """
    _, h, w, dim, device = *patches.shape, patches.device

    # honor an explicitly requested dtype, otherwise follow the input's
    dtype = patches.dtype if dtype is None else dtype

    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')
    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'

    # dim // 4 frequencies, since each of (x.sin, x.cos, y.sin, y.cos) takes a quarter
    omega = torch.arange(dim // 4, device = device) / (dim // 4 - 1)
    omega = 1. / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]

    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    pe = pe.type(dtype)

    # equivalent of rearrange(pe, '(h w) d -> h w d')
    return pe.reshape(h, w, dim)
|
|
| |
|
|
class LayerNorm(nn.Module):
    """Layer normalization with a learnable scale and a fixed zero bias.

    beta is registered as a buffer (not a Parameter) so it stays at zero
    and receives no gradient updates.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))

    def forward(self, x):
        normalized_shape = x.shape[-1:]
        return F.layer_norm(x, normalized_shape, weight = self.gamma, bias = self.beta)
|
|
| |
|
|
class GEGLU(nn.Module):
    """Gated GELU activation: halves the last dim, gating one half by GELU of the other."""

    def forward(self, x):
        out, gate = x.chunk(2, dim = -1)
        return out * F.gelu(gate)
|
|
def FeedForward(dim, mult = 4, dropout = 0.):
    """Pre-norm GEGLU feedforward block.

    The hidden width is scaled by 2/3 to compensate for the GEGLU split,
    keeping the parameter count comparable to a plain GELU MLP of width
    dim * mult.
    """
    dim_hidden = int(dim * mult * 2 / 3)

    layers = [
        LayerNorm(dim),
        nn.Linear(dim, dim_hidden * 2, bias = False),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim_hidden, dim, bias = False),
    ]
    return nn.Sequential(*layers)
|
|
| |
|
|
class Attention(nn.Module):
    """Multi-head self attention with pre-norm, optional causal and padding masks.

    Args:
        dim: model dimension of the input/output
        causal: apply an upper-triangular causal mask when True
        dim_head: per-head dimension
        heads: number of attention heads
        dropout: dropout rate used on attention weights and output projection
    """

    def __init__(
        self,
        dim,
        causal = False,
        dim_head = 64,
        heads = 8,
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.causal = causal
        inner_dim = dim_head * heads

        self.norm = LayerNorm(dim)

        self.attn_dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
        )

    def forward(
        self,
        x,
        mask = None
    ):
        batch, seq_len, _ = x.shape
        heads = self.heads

        # pre-norm
        x = self.norm(x)

        # project to queries, keys, values
        q = self.to_q(x)
        k, v = self.to_kv(x).chunk(2, dim = -1)

        # (b, n, h * d) -> (b, h, n, d)
        def split_heads(t):
            return t.view(batch, -1, heads, t.shape[-1] // heads).transpose(1, 2)

        q, k, v = map(split_heads, (q, k, v))

        # scaled dot-product similarity
        sim = torch.matmul(q * self.scale, k.transpose(-2, -1))

        # padding mask: broadcast (b, j) over heads and query positions
        if mask is not None:
            sim = sim.masked_fill(~mask[:, None, None, :], -torch.finfo(sim.dtype).max)

        # causal mask: forbid attending to future positions
        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        attn = self.attn_dropout(sim.softmax(dim = -1))

        # aggregate values and merge heads back: (b, h, n, d) -> (b, n, h * d)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).reshape(batch, seq_len, -1)
        return self.to_out(out)
|
|
| |
|
|
class Transformer(nn.Module):
    """Stack of pre-norm attention + feedforward blocks with residual connections."""

    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.
    ):
        super().__init__()
        blocks = []
        for _ in range(depth):
            attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
            ff = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            blocks.append(nn.ModuleList([attn, ff]))
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, mask = None):
        """Apply each (attention, feedforward) pair residually; mask is the padding mask."""
        for attn, ff in self.layers:
            x = x + attn(x, mask = mask)
            x = x + ff(x)
        return x
|
|
| |
|
|
def pair(t):
    """Promote a scalar to a (t, t) tuple; tuples pass through unchanged."""
    if isinstance(t, tuple):
        return t
    return (t, t)
|
|
class AudioSpectrogramTransformer(nn.Module):
    """ViT-style transformer encoder over audio spectrograms.

    Pipeline: raw waveform -> Spectrogram -> (train-time SpecAugment) ->
    crop both axes to a multiple of patch_size -> patchify via 1x1 conv ->
    add 2d sincos positional embedding -> Transformer -> mean pool ->
    LayerNorm. Output is a single embedding of shape (batch, dim).
    """

    def __init__(
        self,
        dim,
        depth,
        patch_size = 16,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        spec_n_fft = 128,
        spec_power = 2,
        spec_win_length = 24,
        spec_hop_length = None,
        spec_pad = 0,
        spec_center = True,
        spec_pad_mode = 'reflect',
        spec_aug_stretch_factor = 0.8,
        spec_aug_freq_mask = 80,
        spec_aug_time_mask = 80
    ):
        super().__init__()
        self.dim = dim

        # each (p1 x p2) spectrogram patch is flattened into the channel dim
        # and projected to `dim` with a 1x1 conv (a per-patch linear layer)
        self.patch_size = pair(patch_size)
        self.to_patch_tokens = nn.Conv2d(self.patch_size[0] * self.patch_size[1], dim, 1)

        self.spec = Spectrogram(
            n_fft = spec_n_fft,
            power = spec_power,
            win_length = spec_win_length,
            hop_length = spec_hop_length,
            pad = spec_pad,
            center = spec_center,
            pad_mode = spec_pad_mode
        )

        # SpecAugment chain, applied only while self.training is True.
        # NOTE(review): torchaudio's TimeStretch documents a *complex*
        # spectrogram input, while Spectrogram(power = 2) yields a real power
        # spectrogram — confirm this augmentation runs as intended on the
        # installed torchaudio version
        self.aug = torch.nn.Sequential(
            TimeStretch(spec_aug_stretch_factor, fixed_rate=True),
            FrequencyMasking(freq_mask_param = spec_aug_freq_mask),
            TimeMasking(time_mask_param = spec_aug_time_mask),
        )

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_mult = ff_mult,
            ff_dropout = ff_dropout
        )

        self.norm = LayerNorm(dim)

    def forward(self, x):
        """Encode a batch of waveforms into (batch, dim) embeddings."""
        x = self.spec(x)

        # augment only during training
        if self.training:
            x = self.aug(x)

        # crop the spectrogram so both axes divide evenly into patches
        height, width = x.shape[-2:]
        patch_height, patch_width = self.patch_size

        rounded_height, rounded_width = map(lambda args: round_down_nearest_multiple(*args), ((height, patch_height), (width, patch_width)))

        if (height, width) != (rounded_height, rounded_width):
            print(f'spectrogram yielded shape of {(height, width)}, but had to be cropped to {(rounded_height, rounded_width)} to be patchified for transformer')

        x = x[..., :rounded_height, :rounded_width]

        # to patches: flatten each patch into channels, then project to dim
        x = rearrange(x, 'b (h p1) (w p2) -> b (p1 p2) h w', p1 = patch_height, p2 = patch_width)
        x = self.to_patch_tokens(x)

        # add 2d sincos positional embedding over the patch grid
        x = rearrange(x, 'b c h w -> b h w c')
        x = x + posemb_sincos_2d(x)

        # flatten the grid into a sequence for attention
        x = rearrange(x, 'b ... c -> b (...) c')

        x = self.transformer(x)

        # mean pool over all patch tokens, then final norm
        x = reduce(x, 'b n d -> b d', 'mean')

        return self.norm(x)
|
|
| |
|
|
@beartype
class TextTransformer(nn.Module):
    """Transformer text encoder that pools into a single CLS embedding.

    Accepts either pre-tokenized ids `x` or `raw_texts` (tokenized
    internally). Padding positions (`pad_id`) are masked out of attention.
    Output shape is (batch, dim).
    """

    def __init__(
        self,
        dim,
        depth,
        num_tokens = tokenizer.vocab_size,
        max_seq_len = 256,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_mult = 4,
        pad_id = 0
    ):
        super().__init__()
        self.dim = dim

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        # learned CLS token, prepended to every sequence and used as the pooled output
        self.cls_token = nn.Parameter(torch.randn(dim))

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            ff_mult = ff_mult
        )

        self.pad_id = pad_id
        self.norm = LayerNorm(dim)

    @property
    def device(self):
        # device the module currently lives on (all submodules share it)
        return next(self.parameters()).device

    def forward(
        self,
        x = None,
        raw_texts: Optional[List[str]] = None,
        mask = None
    ):
        assert exists(x) ^ exists(raw_texts), 'either token ids or raw texts must be given, but not both'

        if exists(raw_texts):
            # bugfix: tokenizer.tokenize returns cpu tensors — move them onto
            # the module's device so the embedding lookup works on gpu
            x = tokenizer.tokenize(raw_texts).to(self.device)

        # derive the padding mask from the token ids when not supplied
        if not exists(mask):
            mask = x != self.pad_id

        b, n, device = *x.shape, x.device

        # token + absolute position embeddings
        x = self.token_emb(x)
        x = x + self.pos_emb(torch.arange(n, device = device))

        # prepend the CLS token to every sequence
        cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
        x, ps = pack([cls_tokens, x], 'b * d')

        # the CLS position is always attended to
        mask = F.pad(mask, (1, 0), value = True)

        x = self.transformer(x, mask = mask)

        # pooled representation is the normed CLS token
        cls_tokens, _ = unpack(x, ps, 'b * d')

        return self.norm(cls_tokens)
|
|
| |
|
|
@beartype
class MuLaN(nn.Module):
    """Audio-text joint embedding model trained with a contrastive loss.

    The audio and text transformers embed each modality; both embeddings are
    projected to `dim_latent` and l2-normalized, and a (optionally decoupled)
    InfoNCE-style contrastive loss is computed over the batch's pairwise
    cosine similarities.
    """

    def __init__(
        self,
        audio_transformer: AudioSpectrogramTransformer,
        text_transformer: TextTransformer,
        dim_latent = 128,
        decoupled_contrastive_learning = True,  # removes the positive pair from the denominator
    ):
        super().__init__()
        self.dim_latent = dim_latent

        self.audio = audio_transformer
        self.text = text_transformer

        # learned temperature, applied via exp so it stays positive
        self.temperature = nn.Parameter(torch.tensor(1.))

        self.text_to_latents = nn.Linear(self.text.dim, dim_latent)
        self.audio_to_latents = nn.Linear(self.audio.dim, dim_latent)

        self.decoupled_contrastive_learning = decoupled_contrastive_learning

    def get_audio_latents(
        self,
        wavs
    ):
        """Embed raw waveforms into the l2-normalized joint latent space."""
        audio_embeds = self.audio(wavs)
        audio_latents = self.audio_to_latents(audio_embeds)
        return l2norm(audio_latents)

    def get_text_latents(
        self,
        texts = None,
        raw_texts: Optional[List[str]] = None
    ):
        """Embed tokenized texts (or raw strings) into the joint latent space."""
        # bugfix: raw_texts was previously accepted but silently dropped
        text_embeds = self.text(texts, raw_texts = raw_texts)
        text_latents = self.text_to_latents(text_embeds)
        return l2norm(text_latents)

    def forward(
        self,
        wavs,
        texts = None,
        raw_texts: Optional[List[str]] = None,
        return_similarities = False
    ):
        """Return the mean contrastive loss for a batch of paired (wav, text).

        With return_similarities = True, returns the raw (un-temperature-scaled)
        cosine similarity matrix instead.
        """
        batch, device = wavs.shape[0], wavs.device

        audio_latents = self.get_audio_latents(wavs)
        text_latents = self.get_text_latents(texts, raw_texts = raw_texts)

        # pairwise cosine similarities (latents are already unit norm)
        cosine_sim = einsum('i d, j d -> i j', audio_latents, text_latents)

        assert cosine_sim.shape[0] == cosine_sim.shape[1], 'batch sizes for audio and text are not equal'

        if return_similarities:
            return cosine_sim

        cosine_sim = cosine_sim * self.temperature.exp()

        cosine_sim_exp = cosine_sim.exp()

        # diagonal entries are the positive (matched) pairs
        numerator = cosine_sim_exp.diag()

        if self.decoupled_contrastive_learning:
            # bugfix: masked_fill requires a boolean mask — a float eye tensor
            # raises at runtime
            eye = torch.eye(batch, device = device, dtype = torch.bool)
            cosine_sim_exp = cosine_sim_exp.masked_fill(eye, 0.)

        denominator = reduce(cosine_sim_exp, 'i j -> i', 'sum')

        contrastive_loss = -log(numerator / denominator)
        return contrastive_loss.mean()
|
|
| |
|
|
@beartype
class MuLaNEmbedQuantizer(AudioConditionerBase):
    """Quantizes frozen MuLaN latents into conditioning embeddings for AudioLM.

    A wav or text is mapped into MuLaN's joint latent space, discretized by a
    residual vector quantizer into `rq_num_quantizers` codebook indices, and
    each (namespace, quantizer, code) triple selects a learned conditioning
    embedding of the dimension required by that AudioLM stage (semantic /
    coarse / fine).
    """

    def __init__(
        self,
        mulan: MuLaN,
        conditioning_dims: Tuple[int, ...],
        rq_num_quantizers = 8,
        rq_ema_decay = 0.9,
        codebook_size = 1024,
        namespaces: Tuple[str, ...] = ('semantic', 'coarse', 'fine'),
    ):
        super().__init__()
        self.mulan = mulan

        assert len(namespaces) > 0
        self.namespaces = namespaces
        self.conditioning_dims = conditioning_dims

        assert len(conditioning_dims) == len(namespaces), 'number of conditioning dimensions must be equal to number of namespaces'

        dim = mulan.dim_latent

        # commitment weight of 0 and an EMA decay: the codebooks track the
        # (frozen) mulan latents via EMA — no gradients flow into the quantizer
        self.rq = ResidualVQ(
            dim = dim,
            num_quantizers = rq_num_quantizers,
            codebook_size = codebook_size,
            decay = rq_ema_decay,
            commitment_weight = 0,
            kmeans_init = True,
            threshold_ema_dead_code = 2,
            quantize_dropout = False
        )

        self.dim = dim
        self.num_codebooks = rq_num_quantizers

        # one conditioning embedding table per namespace, each of shape
        # (num_quantizers, codebook_size, conditioning_dim)
        self.cond_embeddings = nn.ParameterDict({})

        for namespace, conditioning_dim in zip(namespaces, conditioning_dims):
            cond_embeddings = nn.Parameter(torch.randn(rq_num_quantizers, codebook_size, conditioning_dim))
            nn.init.normal_(cond_embeddings, std = 0.02)

            self.cond_embeddings[namespace] = cond_embeddings

        self.set_default_namespace(namespaces[0])

    def parameters(self):
        # only the conditioning embeddings are exposed as trainable; mulan is
        # used frozen and the residual VQ codebooks update via EMA, not grads
        return self.cond_embeddings.parameters()

    def set_default_namespace(self, namespace):
        """Select which namespace `forward` uses when none is passed."""
        self._default_namespace = namespace

    def forward(
        self,
        wavs = None,
        texts = None,
        namespace = None
    ):
        """Return conditioning embeddings of shape (batch, num_quantizers, conditioning_dim).

        Exactly one of `wavs` / `texts` must be given; either modality works
        because mulan embeds both into the same joint space.
        """
        assert exists(wavs) ^ exists(texts)

        namespace = default(namespace, self._default_namespace)
        assert namespace in self.namespaces, f'namespace {namespace} not found'
        cond_embeddings = self.cond_embeddings[namespace]

        # mulan is run frozen and without gradient tracking
        with torch.no_grad():
            self.mulan.eval()

            # sound and language live in a joint embedding space thanks to
            # mulan's contrastive training
            if exists(wavs):
                latents = self.mulan.get_audio_latents(wavs)
            elif exists(texts):
                latents = self.mulan.get_text_latents(texts)

        # NOTE(review): assumes ResidualVQ returns (quantized, indices, loss)
        # with indices of shape (batch, num_quantizers) — confirm against the
        # installed vector_quantize_pytorch version
        _, indices, _ = self.rq(latents)

        batch, num_codebooks, dim = indices.shape[0], self.num_codebooks, cond_embeddings.shape[-1]

        # per quantizer, gather the conditioning embedding row selected by
        # that quantizer's code index
        cond_embeddings = repeat(cond_embeddings, 'q c d -> b q c d', b = batch)
        indices = repeat(indices, 'b q -> b q 1 d', q = num_codebooks, d = dim)

        cond_embeddings = cond_embeddings.gather(2, indices)
        return rearrange(cond_embeddings, 'b q 1 d -> b q d')
|
|
@beartype
class MusicLM(nn.Module):
    """Text-to-music wrapper: MuLaN-quantized text conditioning fed into AudioLM."""

    def __init__(
        self,
        audio_lm: AudioLM,
        mulan_embed_quantizer: MuLaNEmbedQuantizer
    ):
        super().__init__()
        assert not exists(audio_lm.audio_conditioner), 'mulan must not have been passed into AudioLM. it will be managed externally now, embedding the text into the joint embedding space for text-to-audio synthesis'

        self.mulan_embed_quantizer = mulan_embed_quantizer
        self.audio_lm = audio_lm

    @torch.no_grad()
    def forward(
        self,
        raw_texts: List[str],
        **audio_lm_kwargs
    ):
        """Generate audio conditioned on raw text prompts (inference only)."""
        self.eval()

        # tokenize, embed into the joint space, and quantize into conditioning embeddings
        token_ids = tokenizer.tokenize(raw_texts)
        text_conds = self.mulan_embed_quantizer(texts = token_ids)

        return self.audio_lm(text_embeds = text_conds, **audio_lm_kwargs)