Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from torch.nn.modules.normalization import LayerNorm | |
| import random | |
| from utilities.constants import * | |
| from utilities.device import get_device | |
| from .positional_encoding import PositionalEncoding | |
| from .rpr import TransformerEncoderRPR, TransformerEncoderLayerRPR | |
| # MusicTransformer | |
class MusicTransformer(nn.Module):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Music Transformer reproduction from https://arxiv.org/abs/1809.04281. Arguments allow for
    tweaking the transformer architecture (https://arxiv.org/abs/1706.03762) and the rpr argument
    toggles Relative Position Representations (RPR - https://arxiv.org/abs/1803.02155).

    Supports training and generation using Pytorch's nn.Transformer class with a dummy decoder to
    make a decoder-only transformer architecture.

    For RPR support, there is modified Pytorch 1.2.0 code in rpr.py. Modified source will be
    kept up to date with Pytorch revisions only as necessary.

    Args:
        n_layers: number of encoder layers.
        num_heads: number of attention heads.
        d_model: embedding / hidden size.
        dim_feedforward: hidden size of each feed-forward sublayer.
        dropout: dropout probability, passed to the positional encoding and the encoder.
        max_sequence: passed to PositionalEncoding and, when rpr=True, as `er_len` to the
            RPR encoder layer.
        rpr: when True, use the relative-position encoder from rpr.py instead of the stock
            nn.Transformer encoder.
    ----------
    """

    def __init__(self, n_layers=6, num_heads=8, d_model=512, dim_feedforward=1024,
                 dropout=0.1, max_sequence=2048, rpr=False):
        super(MusicTransformer, self).__init__()

        self.dummy = DummyDecoder()

        self.nlayers = n_layers
        self.nhead = num_heads
        self.d_model = d_model
        self.d_ff = dim_feedforward
        self.dropout = dropout
        self.max_seq = max_sequence
        self.rpr = rpr

        # Input embedding
        self.embedding = nn.Embedding(VOCAB_SIZE, self.d_model)

        # Positional encoding
        self.positional_encoding = PositionalEncoding(self.d_model, self.dropout, self.max_seq)

        # custom_encoder=None makes nn.Transformer build its stock encoder; with rpr=True the
        # RPR encoder from rpr.py is supplied instead.
        encoder = None
        if self.rpr:
            encoder_norm = LayerNorm(self.d_model)
            encoder_layer = TransformerEncoderLayerRPR(
                self.d_model, self.nhead, self.d_ff, self.dropout, er_len=self.max_seq
            )
            encoder = TransformerEncoderRPR(encoder_layer, self.nlayers, encoder_norm)

        # To make a decoder-only transformer we use masked encoder layers; the dummy decoder
        # just returns the encoder output.
        self.transformer = nn.Transformer(
            d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers,
            num_decoder_layers=0, dropout=self.dropout,  # activation=self.ff_activ,
            dim_feedforward=self.d_ff, custom_decoder=self.dummy, custom_encoder=encoder,
        )

        # Output projection to vocabulary logits (softmax is applied in generate)
        self.Wout = nn.Linear(self.d_model, VOCAB_SIZE)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=True):
        """
        ----------
        Author: Damon Gwinn
        ----------
        Takes an input sequence and outputs predictions using a sequence to sequence method.
        A prediction at one index is the "next" prediction given all information seen previously.

        Args:
            x: integer token tensor indexed (batch, seq); x.shape[1] is the sequence length.
            mask: True (default) builds a causal square-subsequent mask; a tensor is used as
                src_mask (moved to x's device); any other value disables masking.
        Returns:
            Unnormalized logits of shape (batch, seq, VOCAB_SIZE).
        ----------
        """
        if mask is True:
            mask = self.transformer.generate_square_subsequent_mask(x.shape[1]).to(x.device)
        elif isinstance(mask, torch.Tensor):
            # Honour a caller-supplied mask instead of silently dropping it
            mask = mask.to(x.device)
        else:
            mask = None

        x = self.embedding(x)

        # Input shape is (max_seq, batch_size, d_model)
        x = x.permute(1, 0, 2)
        x = self.positional_encoding(x)

        # Since there are no true decoder layers, the tgt is unused.
        # Pytorch wants src and tgt to have some equal dims however.
        x_out = self.transformer(src=x, tgt=x, src_mask=mask)

        # Back to (batch_size, max_seq, d_model)
        x_out = x_out.permute(1, 0, 2)

        # Raw logits for every position; generate() applies self.softmax itself
        return self.Wout(x_out)

    def generate(self, primer=None, target_seq_length=1024, beam=0, beam_chance=1.0):
        """
        ----------
        Author: Damon Gwinn
        ----------
        Generates midi given a primer sample. Music can be generated using a probability
        distribution over the softmax probabilities (recommended) or by using a beam search.

        Beam steps keep the top-`beam` (row, token) pairs of the current step's probabilities
        only; scores are not accumulated across steps.

        Args:
            primer: 1-D tensor of primer tokens (len(primer) is taken as the token count).
                Must be non-empty and no longer than target_seq_length.
            target_seq_length: maximum length of the returned sequence, primer included.
            beam: beam width; 0 (or less) disables beam steps so every step samples.
            beam_chance: when beam > 0, probability that a step is a beam step rather than a
                sampling step.
        Returns:
            Integer tensor (rows, length): 1 row while only sampling, `beam` rows once a beam
            step has run.
        Raises:
            AssertionError: if the model is in training mode.
            ValueError: if primer is None/empty or longer than target_seq_length.
        ----------
        """
        assert (not self.training), "Cannot generate while in training mode"

        if primer is None or len(primer) == 0:
            raise ValueError("generate requires a non-empty primer sequence")
        num_primer = len(primer)
        if num_primer > target_seq_length:
            raise ValueError("primer is longer than target_seq_length")

        print("Generating sequence of max length:", target_seq_length)

        device = get_device()
        gen_seq = torch.full((1, target_seq_length), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=device)
        gen_seq[..., :num_primer] = primer.type(TORCH_LABEL_TYPE).to(device)

        cur_i = num_primer
        while cur_i < target_seq_length:
            # Next-token probabilities, truncated to token ids [0, TOKEN_END)
            y = self.softmax(self.forward(gen_seq[..., :cur_i]))[..., :TOKEN_END]
            token_probs = y[:, cur_i - 1, :]

            # Short-circuit keeps the RNG untouched when beam search is disabled
            use_beam = beam > 0 and random.uniform(0, 1) <= beam_chance

            if use_beam:
                # Rows are TOKEN_END wide after the slice above (not VOCAB_SIZE), so decode
                # the flat top-k indices with the actual row width.
                n_cols = token_probs.shape[-1]
                _, top_i = torch.topk(token_probs.flatten(), beam)
                beam_rows = top_i // n_cols
                beam_cols = top_i % n_cols

                gen_seq = gen_seq[beam_rows, :]
                gen_seq[..., cur_i] = beam_cols
            else:
                distrib = torch.distributions.categorical.Categorical(probs=token_probs)
                next_token = distrib.sample()
                gen_seq[:, cur_i] = next_token

                # Let the transformer decide to end if it wants to. `.any()` keeps this valid
                # when gen_seq holds several rows after a beam step.
                # NOTE(review): token_probs only covers ids below TOKEN_END, so TOKEN_END is
                # never sampled and this cannot fire; widen the slice above if early stopping
                # is wanted.
                if (next_token == TOKEN_END).any():
                    print("Model called end of sequence at:", cur_i, "/", target_seq_length)
                    break

            cur_i += 1
            if cur_i % 50 == 0:
                print(cur_i, "/", target_seq_length)

        return gen_seq[:, :cur_i]
| # Used as a dummy to nn.Transformer | |
| # DummyDecoder | |
class DummyDecoder(nn.Module):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Pass-through decoder used to turn Pytorch's nn.Transformer into a decoder-only
    architecture: the stacked (masked) encoder does all the work and this module hands
    its output straight back. It owns no parameters.
    ----------
    """

    def __init__(self):
        super().__init__()

    def forward(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask,
                memory_key_padding_mask, **kwargs):
        """
        ----------
        Author: Damon Gwinn
        ----------
        Ignore tgt, every mask argument and any extra keyword arguments; return memory
        (the encoder output) unchanged.
        ----------
        """
        encoder_output = memory
        return encoder_output