from dataclasses import dataclass
import typing as tp

import torch
from torch import nn
from einops import rearrange
from tqdm.auto import trange

from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
from .factory import create_pretransform_from_config
from .lm_backbone import AudioLMBackbone, XTransformersAudioLMBackbone, ContinuousTransformerAudioLMBackbone
from .pretransforms import Pretransform, AutoencoderPretransform, PretrainedDACPretransform, AudiocraftCompressionPretransform
from .utils import multinomial, sample_top_k, sample_top_p

from .codebook_patterns import (
    CodebooksPatternProvider,
    DelayedPatternProvider,
    MusicLMPattern,
    ParallelPatternProvider,
    UnrolledPatternProvider
)

@dataclass
class LMOutput:
    # Logits are reverted back into alignment with the original (unshifted) code sequence,
    # so they can be compared against the input codes directly; positions without a valid
    # prediction are flagged by the mask.
    logits: torch.Tensor  # [batch, num_quantizers, seq_len, codebook_size]
    mask: torch.Tensor    # [batch, num_quantizers, seq_len]

class AudioLanguageModel(nn.Module):
    def __init__(
            self,
            pattern_provider: CodebooksPatternProvider,
            backbone: AudioLMBackbone,
            num_quantizers: int,
            codebook_size: int
        ):
        super().__init__()

        self.pattern_provider = pattern_provider
        self.backbone = backbone
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size

        # The masked/special token id is one past the last valid codebook index
        self.masked_token_id = codebook_size

        # One embedding table per quantizer; the extra entry is for the masked token
        self.embeds = nn.ModuleList([nn.Embedding(codebook_size + 1, backbone.embed_dim) for _ in range(num_quantizers)])

        # One output head per quantizer, each predicting a distribution over its codebook
        self.quantizer_heads = nn.ModuleList([
            nn.Linear(backbone.embed_dim, codebook_size) for _ in range(num_quantizers)
        ])

    def forward(self,
            sequence: torch.Tensor,  # [batch, num_quantizers, seq_len]
            prepend_cond=None,       # [batch, cond_seq_len, channels]
            prepend_cond_mask=None,  # [batch, cond_seq_len]
            cross_attn_cond=None,    # [batch, cond_seq_len, channels]
            **kwargs
        ):

        batch, num_quantizers, seq_len = sequence.shape

        assert num_quantizers == self.num_quantizers, "Number of quantizers in sequence must match number of quantizers in model"

        # Sum the per-quantizer embeddings at each timestep to form the backbone input
        backbone_input = sum([self.embeds[i](sequence[:, i]) for i in range(num_quantizers)])  # [batch, seq_len, embed_dim]

        dtype = next(self.parameters()).dtype

        if cross_attn_cond is not None:
            cross_attn_cond = cross_attn_cond.to(dtype)

        if prepend_cond is not None:
            prepend_cond = prepend_cond.to(dtype)

        if prepend_cond_mask is not None:
            prepend_cond_mask = prepend_cond_mask.to(dtype)

        backbone_input = backbone_input.to(dtype)

        output = self.backbone(
            backbone_input,
            cross_attn_cond=cross_attn_cond,
            prepend_cond=prepend_cond,
            prepend_cond_mask=prepend_cond_mask,
            **kwargs
        )  # [batch, seq_len, embed_dim]

        # Project the shared backbone output through each quantizer's head
        logits = torch.stack([self.quantizer_heads[i](output) for i in range(num_quantizers)], dim=1)  # [batch, num_quantizers, seq_len, codebook_size]

        return logits

    def compute_logits(
            self,
            codes,  # [batch, num_quantizers, seq_len]
            **kwargs
        ):
        """
        Compute logits for a batch of codes, optionally conditioning on cross-attention and prepend conditioning.
        Handles translation between the input sequence and the pattern-shifted sequence.
        Only used during training.
        """

        batch, _, seq_len = codes.shape

        pattern = self.pattern_provider.get_pattern(seq_len)

        # Apply the codebook pattern to the input codes (e.g. delaying each codebook by its
        # index for a delay pattern); steps with no valid code are filled with the masked token
        shifted_codes, _, _ = pattern.build_pattern_sequence(
            codes,
            self.masked_token_id,
            keep_only_valid_steps=True
        )

        # Run the model on the pattern-shifted codes
        logits = self(shifted_codes, **kwargs)  # [batch, num_quantizers, seq_len, codebook_size]

        # Rearrange so the pattern can be reverted over the time dimension
        logits = rearrange(logits, "b n s c -> b c n s")

        # Revert the pattern so the logits line up with the original code sequence;
        # positions without a prediction are filled with NaN and flagged by the mask
        logits, _, logits_mask = pattern.revert_pattern_logits(
            logits, float('nan'), keep_only_valid_steps=True
        )

        logits = rearrange(logits, "b c n t -> b n t c")  # [batch, num_quantizers, seq_len, codebook_size]

        logits_mask = logits_mask[None, :, :].expand(batch, -1, -1)  # [batch, num_quantizers, seq_len]

        return LMOutput(logits=logits, mask=logits_mask)
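
# Illustrative sketch (not part of the original module): one way an LMOutput is
# typically consumed during training, i.e. masked cross-entropy against the original
# (unshifted) codes. The actual training loss lives in the training code; the shapes
# assumed here are logits [b, n, t, c], mask [b, n, t], and codes [b, n, t].
def _example_lm_cross_entropy(output: LMOutput, codes: torch.Tensor) -> torch.Tensor:
    import torch.nn.functional as F

    flat_logits = rearrange(output.logits, "b n t c -> (b n t) c")
    flat_targets = rearrange(codes, "b n t -> (b n t)")
    flat_mask = rearrange(output.mask, "b n t -> (b n t)").bool()

    # Only score positions that are valid under the codebook pattern; invalid
    # positions hold NaN logits and are excluded by the mask.
    return F.cross_entropy(flat_logits[flat_mask], flat_targets[flat_mask])
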
class AudioLanguageModelWrapper(nn.Module):
    def __init__(
            self,
            pretransform: Pretransform,
            lm: AudioLanguageModel,
            sample_rate: int,
            min_input_length: int,
            conditioner: MultiConditioner = None,
            cross_attn_cond_ids: tp.List[str] = [],
            prepend_cond_ids: tp.List[str] = [],
            global_cond_ids: tp.List[str] = []
        ):
        super().__init__()

        assert pretransform.is_discrete, "Pretransform must be discrete"
        self.pretransform = pretransform

        # The pretransform is frozen and kept in eval mode; only the LM is trained
        self.pretransform.requires_grad_(False)
        self.pretransform.eval()

        if isinstance(self.pretransform, AutoencoderPretransform):
            self.num_quantizers = self.pretransform.model.bottleneck.num_quantizers
            self.codebook_size = self.pretransform.model.bottleneck.codebook_size
        elif isinstance(self.pretransform, PretrainedDACPretransform):
            self.num_quantizers = self.pretransform.model.num_quantizers
            self.codebook_size = self.pretransform.model.codebook_size
        elif isinstance(self.pretransform, AudiocraftCompressionPretransform):
            self.num_quantizers = self.pretransform.num_quantizers
            self.codebook_size = self.pretransform.codebook_size
        else:
            raise NotImplementedError(f"Unrecognized pretransform type {type(self.pretransform)}")

        self.conditioner = conditioner

        self.lm = lm

        self.sample_rate = sample_rate
        self.min_input_length = min_input_length

        self.cross_attn_cond_ids = cross_attn_cond_ids
        self.prepend_cond_ids = prepend_cond_ids
        self.global_cond_ids = global_cond_ids

    def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False):
        cross_attention_input = None
        prepend_cond = None
        prepend_cond_mask = None
        global_cond = None

        if len(self.cross_attn_cond_ids) > 0:
            # Concatenate all cross-attention conditioning inputs over the sequence dimension
            cross_attention_input = torch.cat([cond[key][0] for key in self.cross_attn_cond_ids], dim=1)

        if len(self.prepend_cond_ids) > 0:
            # Concatenate all prepend conditioning inputs (and their masks) over the sequence dimension
            prepend_cond = torch.cat([cond[key][0] for key in self.prepend_cond_ids], dim=1)
            prepend_cond_mask = torch.cat([cond[key][1] for key in self.prepend_cond_ids], dim=1)

        if len(self.global_cond_ids) > 0:
            # Concatenate all global conditioning inputs over the channel dimension
            global_cond = torch.cat([cond[key][0] for key in self.global_cond_ids], dim=-1)
            if len(global_cond.shape) == 3:
                global_cond = global_cond.squeeze(1)

        if negative:
            return {
                "negative_cross_attn_cond": cross_attention_input,
                "negative_prepend_cond": prepend_cond,
                "negative_prepend_cond_mask": prepend_cond_mask,
                "negative_global_cond": global_cond
            }
        else:
            return {
                "cross_attn_cond": cross_attention_input,
                "prepend_cond": prepend_cond,
                "prepend_cond_mask": prepend_cond_mask,
                "global_cond": global_cond
            }
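
    # Note on the expected `cond` format (an illustrative sketch, inferred from how this
    # method indexes it): each conditioner id maps to a (tensor, mask) pair, e.g.
    #
    #   cond = {
    #       "prompt": (torch.randn(2, 77, 768), torch.ones(2, 77, dtype=torch.bool)),
    #   }
    #   inputs = wrapper.get_conditioning_inputs(cond)
    #
    # The key name ("prompt" here) and tensor shapes depend entirely on the conditioning
    # config; only the (tensor, mask) tuple structure is assumed above.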
    def compute_logits(
            self,
            codes,
            condition_tensors=None,
            cfg_dropout_prob=0.0,
            **kwargs
        ):
        """
        Compute logits for a batch of codes, translating from conditioning inputs to model inputs.
        Handles classifier-free guidance (CFG) dropout.
        """

        if condition_tensors is None:
            condition_tensors = {}

        conditioning_inputs = self.get_conditioning_inputs(condition_tensors)

        cross_attn_cond = conditioning_inputs["cross_attn_cond"]
        prepend_cond = conditioning_inputs["prepend_cond"]
        prepend_cond_mask = conditioning_inputs["prepend_cond_mask"]
        global_cond = conditioning_inputs["global_cond"]

        if cfg_dropout_prob > 0.0:
            # With probability cfg_dropout_prob per batch item, replace each conditioning input
            # with a null (zero) embedding so the model also learns the unconditional distribution
            if cross_attn_cond is not None:
                null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
                dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
                cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)

            if prepend_cond is not None:
                null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
                dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
                prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)

            if global_cond is not None:
                null_embed = torch.zeros_like(global_cond, device=global_cond.device)
                dropout_mask = torch.bernoulli(torch.full((global_cond.shape[0], 1), cfg_dropout_prob, device=global_cond.device)).to(torch.bool)
                global_cond = torch.where(dropout_mask, null_embed, global_cond)

        return self.lm.compute_logits(codes, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs)

    def _sample_next_token(
            self,
            sequence,  # [batch, num_quantizers, seq_len]
            conditioning_tensors=None,
            cross_attn_use_cfg=True,
            prepend_use_cfg=True,
            global_use_cfg=True,
            cfg_scale=1.0,
            top_k=250,
            top_p=0.0,
            temp=1.0,
            **kwargs
        ):
        """
        Sample the next token for a batch of codes, translating from conditioning inputs to model inputs.
        Handles classifier-free guidance (CFG) at inference time.
        """

        if conditioning_tensors is None:
            conditioning_tensors = {}

        conditioning_inputs = self.get_conditioning_inputs(conditioning_tensors)

        cross_attn_cond = conditioning_inputs["cross_attn_cond"]
        prepend_cond = conditioning_inputs["prepend_cond"]
        prepend_cond_mask = conditioning_inputs["prepend_cond_mask"]
        global_cond = conditioning_inputs["global_cond"]

        if cfg_scale != 1.0:
            # Batch the conditional and unconditional (null-conditioned) inputs together
            sequence = torch.cat([sequence, sequence], dim=0)

            if cross_attn_cond is not None and cross_attn_use_cfg:
                null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
                cross_attn_cond = torch.cat([cross_attn_cond, null_embed], dim=0)

            if prepend_cond is not None and prepend_use_cfg:
                null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
                prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)

                if prepend_cond_mask is not None:
                    prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)

            if global_cond is not None and global_use_cfg:
                null_embed = torch.zeros_like(global_cond, device=global_cond.device)
                global_cond = torch.cat([global_cond, null_embed], dim=0)

        logits = self.lm(sequence, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs)

        if cfg_scale != 1.0:
            # Combine the conditional and unconditional logits with the CFG scale
            cond_logits, uncond_logits = logits.chunk(2, dim=0)
            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale

        logits = rearrange(logits, "b n s c -> b n c s")  # [batch, num_quantizers, codebook_size, seq_len]

        # Keep only the logits for the last step
        logits = logits[:, :, :, -1]  # [batch, num_quantizers, codebook_size]

        # Temperature-scaled sampling with optional top-p or top-k filtering;
        # temp <= 0 falls back to greedy (argmax) decoding
        if temp > 0:
            probs = torch.softmax(logits / temp, dim=-1)

            if top_p > 0.0:
                next_token = sample_top_p(probs, p=top_p)
            elif top_k > 0:
                next_token = sample_top_k(probs, k=top_k)
            else:
                next_token = multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)

        return next_token  # [batch, num_quantizers, 1]

    @torch.no_grad()
    def generate(
            self,
            max_gen_len: int = 256,
            batch_size: tp.Optional[int] = None,
            init_data: tp.Optional[torch.Tensor] = None,
            conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None,
            conditioning_tensors: tp.Optional[tp.Dict[str, tp.Any]] = None,
            callback: tp.Optional[tp.Callable[[int, int], None]] = None,
            use_cache: bool = True,
            cfg_scale: float = 1.0,
            **kwargs
        ):
        """
        Autoregressively generate a batch of codes, optionally starting from an init prompt
        and conditioning inputs.
        """
        device = next(self.parameters()).device

        if conditioning_tensors is None and conditioning is not None:
            # Convert conditioning inputs to conditioning tensors
            conditioning_tensors = self.conditioner(conditioning, device)

        # Check that the batch size is consistent across inputs
        possible_batch_sizes = []

        if batch_size is not None:
            possible_batch_sizes.append(batch_size)
        elif init_data is not None:
            possible_batch_sizes.append(init_data.shape[0])
        elif conditioning_tensors is not None:
            # Assume the conditioning tensors are batched and all have the same batch size
            possible_batch_sizes.append(conditioning_tensors[list(conditioning_tensors.keys())[0]][0].shape[0])
        else:
            possible_batch_sizes.append(1)

        assert all(x == possible_batch_sizes[0] for x in possible_batch_sizes), "Batch size must be consistent across inputs"

        batch_size = possible_batch_sizes[0]

        if init_data is None:
            # No prompt provided, start from an empty sequence
            assert batch_size > 0
            init_data = torch.zeros((batch_size, self.num_quantizers, 0), device=device, dtype=torch.long)

        batch_size, num_quantizers, seq_len = init_data.shape

        start_offset = seq_len
        assert start_offset < max_gen_len, "init data must be shorter than max gen length"

        pattern = self.lm.pattern_provider.get_pattern(max_gen_len)

        unknown_token = -1

        # Initialize the generated codes with the init data, and unknown tokens for the rest
        gen_codes = torch.full((batch_size, num_quantizers, max_gen_len), unknown_token, device=device, dtype=torch.long)
        gen_codes[:, :, :start_offset] = init_data

        gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.lm.masked_token_id)

        start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
        assert start_offset_sequence is not None

        # Generation loop
        prev_offset = 0
        gen_sequence_len = gen_sequence.shape[-1]

        # Reset the generation cache (doubled batch size when CFG runs cond/uncond in one batch)
        if use_cache and self.lm.backbone.use_generation_cache:
            self.lm.backbone.reset_generation_cache(max_gen_len, batch_size if cfg_scale == 1.0 else batch_size * 2)

        for offset in trange(start_offset_sequence, gen_sequence_len):
            # Get the sequence up to the current offset
            curr_sequence = gen_sequence[..., prev_offset:offset]

            next_token = self._sample_next_token(
                curr_sequence,
                conditioning_tensors=conditioning_tensors,
                use_cache=use_cache,
                cfg_scale=cfg_scale,
                **kwargs
            )

            valid_mask = mask[..., offset:offset+1].expand(batch_size, -1, -1)
            next_token[~valid_mask] = self.lm.masked_token_id

            # Update the generated sequence with the next token, keeping any pre-filled tokens
            gen_sequence[..., offset:offset+1] = torch.where(
                gen_sequence[..., offset:offset+1] == unknown_token,
                next_token,
                gen_sequence[..., offset:offset+1]
            )

            if use_cache and self.lm.backbone.use_generation_cache:
                # Only advance the offset when caching is used, so the next step sees only the new token
                prev_offset = offset

                self.lm.backbone.update_generation_cache(offset)

            if callback is not None:
                # Report (current step, total steps) relative to the start of generation
                callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)

        assert not (gen_sequence == unknown_token).any(), "Unknown tokens in generated sequence"

        out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)

        # Sanity checks over the returned codes and the corresponding mask
        assert (out_codes[..., :max_gen_len] != unknown_token).all()
        assert (out_mask[..., :max_gen_len] == 1).all()

        return out_codes

    def generate_audio(
            self,
            **kwargs
        ):
        """
        Generate a batch of codes with `generate` and decode them to audio with the pretransform.
        """

        codes = self.generate(**kwargs)

        audio = self.pretransform.decode_tokens(codes)

        return audio
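
    # Illustrative usage sketch (hypothetical flow; the conditioning payload and sensible
    # sampling settings depend on the trained checkpoint and its conditioning config):
    #
    #   model = create_audio_lm_from_config(config).to(device).eval()
    #   audio = model.generate_audio(
    #       batch_size=1,
    #       max_gen_len=1024,
    #       conditioning=conditioning_metadata,  # whatever the configured MultiConditioner expects
    #       cfg_scale=3.0,
    #       top_k=250,
    #       temp=1.0,
    #   )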

def create_audio_lm_from_config(config):
    model_config = config.get('model', None)
    assert model_config is not None, 'model config must be specified in config'

    sample_rate = config.get('sample_rate', None)
    assert sample_rate is not None, "Must specify sample_rate in config"

    lm_config = model_config.get('lm', None)
    assert lm_config is not None, 'lm config must be specified in model config'

    codebook_pattern = lm_config.get("codebook_pattern", "delay")

    pattern_providers = {
        'parallel': ParallelPatternProvider,
        'delay': DelayedPatternProvider,
        'unroll': UnrolledPatternProvider,
        'musiclm': MusicLMPattern,
    }

    pretransform_config = model_config.get("pretransform", None)
    assert pretransform_config is not None, "Must specify pretransform in model config"

    pretransform = create_pretransform_from_config(pretransform_config, sample_rate)

    assert pretransform.is_discrete, "Pretransform must be discrete"

    min_input_length = pretransform.downsampling_ratio

    pattern_provider = pattern_providers[codebook_pattern](n_q=pretransform.num_quantizers)

    conditioning_config = model_config.get('conditioning', None)

    conditioner = None
    if conditioning_config is not None:
        conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)

    cross_attn_cond_ids = lm_config.get('cross_attention_cond_ids', [])
    prepend_cond_ids = lm_config.get('prepend_cond_ids', [])
    global_cond_ids = lm_config.get('global_cond_ids', [])

    lm_type = lm_config.get("type", None)
    lm_model_config = lm_config.get("config", None)

    assert lm_type is not None, "Must specify lm type in lm config"
    assert lm_model_config is not None, "Must specify lm model config in lm config"

    if lm_type == "x-transformers":
        backbone = XTransformersAudioLMBackbone(**lm_model_config)
    elif lm_type == "continuous_transformer":
        backbone = ContinuousTransformerAudioLMBackbone(**lm_model_config)
    else:
        raise NotImplementedError(f"Unrecognized lm type {lm_type}")

    lm = AudioLanguageModel(
        pattern_provider=pattern_provider,
        backbone=backbone,
        num_quantizers=pretransform.num_quantizers,
        codebook_size=pretransform.codebook_size
    )

    model = AudioLanguageModelWrapper(
        pretransform=pretransform,
        lm=lm,
        conditioner=conditioner,
        sample_rate=sample_rate,
        min_input_length=min_input_length,
        cross_attn_cond_ids=cross_attn_cond_ids,
        prepend_cond_ids=prepend_cond_ids,
        global_cond_ids=global_cond_ids
    )

    return model
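
# Illustrative config sketch (not from the source): only the keys read by
# create_audio_lm_from_config are shown; the pretransform and conditioning sub-configs
# are placeholders whose contents depend on the specific checkpoint/setup.
#
#   example_config = {
#       "sample_rate": 44100,
#       "model": {
#           "pretransform": {...},   # must resolve to a discrete pretransform
#           "conditioning": {...},   # optional; passed to the multi-conditioner factory
#           "lm": {
#               "type": "continuous_transformer",   # or "x-transformers"
#               "codebook_pattern": "delay",        # "parallel" | "delay" | "unroll" | "musiclm"
#               "cross_attention_cond_ids": ["prompt"],
#               "prepend_cond_ids": [],
#               "global_cond_ids": [],
#               "config": {...},    # kwargs for the backbone constructor
#           },
#       },
#   }
#   model = create_audio_lm_from_config(example_config)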