# lofi-bytes / model / music_transformer.py
# feat: add lofi-bytes-api and gradio app (amosyou, commit 7116323)
import torch
import torch.nn as nn
from torch.nn.modules.normalization import LayerNorm
import random
from utilities.constants import *
from utilities.device import get_device
from .positional_encoding import PositionalEncoding
from .rpr import TransformerEncoderRPR, TransformerEncoderLayerRPR
# MusicTransformer
class MusicTransformer(nn.Module):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Music Transformer reproduction from https://arxiv.org/abs/1809.04281. Arguments allow for
    tweaking the transformer architecture (https://arxiv.org/abs/1706.03762) and the rpr argument
    toggles Relative Position Representations (RPR - https://arxiv.org/abs/1803.02155).

    Supports training and generation using Pytorch's nn.Transformer class with dummy decoder to
    make a decoder-only transformer architecture

    For RPR support, there is modified Pytorch 1.2.0 code in rpr.py. Modified source will be
    kept up to date with Pytorch revisions only as necessary.
    ----------
    """

    def __init__(self, n_layers=6, num_heads=8, d_model=512, dim_feedforward=1024,
                 dropout=0.1, max_sequence=2048, rpr=False):
        """
        Args:
            n_layers: number of stacked (masked) encoder layers.
            num_heads: attention heads per layer.
            d_model: embedding / model width.
            dim_feedforward: hidden size of each layer's feed-forward sublayer.
            dropout: dropout probability used throughout the transformer.
            max_sequence: maximum sequence length; sizes the positional-encoding
                table and, in RPR mode, the relative-position embeddings (er_len).
            rpr: if True, build the modified RPR encoder stack from rpr.py
                instead of the stock nn.Transformer encoder.
        """
        super(MusicTransformer, self).__init__()

        self.dummy = DummyDecoder()

        self.nlayers = n_layers
        self.nhead = num_heads
        self.d_model = d_model
        self.d_ff = dim_feedforward
        self.dropout = dropout
        self.max_seq = max_sequence
        self.rpr = rpr

        # Input embedding
        # VOCAB_SIZE is supplied by the star import from utilities.constants
        self.embedding = nn.Embedding(VOCAB_SIZE, self.d_model)

        # Positional encoding
        self.positional_encoding = PositionalEncoding(self.d_model, self.dropout, self.max_seq)

        # Base transformer
        if(not self.rpr):
            # To make a decoder-only transformer we need to use masked encoder layers
            # Dummy decoder to essentially just return the encoder output
            self.transformer = nn.Transformer(
                d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers,
                num_decoder_layers=0, dropout=self.dropout, # activation=self.ff_activ,
                dim_feedforward=self.d_ff, custom_decoder=self.dummy
            )
        # RPR Transformer
        else:
            encoder_norm = LayerNorm(self.d_model)
            encoder_layer = TransformerEncoderLayerRPR(self.d_model, self.nhead, self.d_ff, self.dropout, er_len=self.max_seq)
            encoder = TransformerEncoderRPR(encoder_layer, self.nlayers, encoder_norm)
            self.transformer = nn.Transformer(
                d_model=self.d_model, nhead=self.nhead, num_encoder_layers=self.nlayers,
                num_decoder_layers=0, dropout=self.dropout, # activation=self.ff_activ,
                dim_feedforward=self.d_ff, custom_decoder=self.dummy, custom_encoder=encoder
            )

        # Final output is a softmaxed linear layer
        # (the softmax is only applied in generate(); forward() returns raw logits)
        self.Wout = nn.Linear(self.d_model, VOCAB_SIZE)
        self.softmax = nn.Softmax(dim=-1)

    # forward
    def forward(self, x, mask=True):
        """
        ----------
        Author: Damon Gwinn
        ----------
        Takes an input sequence and outputs predictions using a sequence to sequence method.
        A prediction at one index is the "next" prediction given all information seen previously.

        Args:
            x: integer token ids, shape (batch, seq) — embedded then permuted
               to (seq, batch, d_model) for nn.Transformer below.
            mask: if exactly True, build a causal (square-subsequent) mask over
                the sequence dimension; any other value disables masking.
        Returns:
            Unsoftmaxed logits of shape (batch, seq, VOCAB_SIZE).
        ----------
        """
        if(mask is True):
            mask = self.transformer.generate_square_subsequent_mask(x.shape[1]).to(get_device())
        else:
            mask = None

        x = self.embedding(x)

        # Input shape is (max_seq, batch_size, d_model)
        x = x.permute(1,0,2)

        x = self.positional_encoding(x)

        # Since there are no true decoder layers, the tgt is unused
        # Pytorch wants src and tgt to have some equal dims however
        x_out = self.transformer(src=x, tgt=x, src_mask=mask)

        # Back to (batch_size, max_seq, d_model)
        x_out = x_out.permute(1,0,2)

        y = self.Wout(x_out)

        # y = self.softmax(y)

        del mask

        # They are trained to predict the next note in sequence (we don't need the last one)
        return y

    # generate
    def generate(self, primer=None, target_seq_length=1024, beam=0, beam_chance=1.0):
        """
        ----------
        Author: Damon Gwinn
        ----------
        Generates midi given a primer sample. Music can be generated using a probability distribution over
        the softmax probabilities (recommended) or by using a beam search.

        Args:
            primer: 1-D tensor of seed token ids (despite the None default it is
                required — len(primer) is taken unconditionally below).
            target_seq_length: maximum length of the generated sequence.
            beam: beam width for topk selection; 0 forces the sampling branch.
            beam_chance: per-step probability of taking a beam step rather
                than sampling (only relevant when beam > 0).
        Returns:
            Tensor of generated token ids, shape (rows, generated_length);
            rows is 1 unless a beam step has re-indexed gen_seq.
        ----------
        """
        assert (not self.training), "Cannot generate while in training mode"

        print("Generating sequence of max length:", target_seq_length)

        # Start from an all-PAD buffer and copy the primer into its head
        gen_seq = torch.full((1,target_seq_length), TOKEN_PAD, dtype=TORCH_LABEL_TYPE, device=get_device())

        num_primer = len(primer)
        gen_seq[..., :num_primer] = primer.type(TORCH_LABEL_TYPE).to(get_device())

        # print("primer:",primer)
        # print(gen_seq)
        cur_i = num_primer
        while(cur_i < target_seq_length):
            # gen_seq_batch = gen_seq.clone()
            # Softmax over logits, keeping only tokens below TOKEN_END
            y = self.softmax(self.forward(gen_seq[..., :cur_i]))[..., :TOKEN_END]
            token_probs = y[:, cur_i-1, :]

            if(beam == 0):
                beam_ran = 2.0  # always > beam_chance (<= 1.0), forcing the sampling branch
            else:
                beam_ran = random.uniform(0,1)

            if(beam_ran <= beam_chance):
                token_probs = token_probs.flatten()
                top_res, top_i = torch.topk(token_probs, beam)

                # NOTE(review): each row of token_probs has TOKEN_END entries,
                # but the flat indices are decomposed with VOCAB_SIZE — confirm
                # these constants agree, otherwise rows/cols are mis-derived.
                beam_rows = top_i // VOCAB_SIZE
                beam_cols = top_i % VOCAB_SIZE

                gen_seq = gen_seq[beam_rows, :]
                gen_seq[..., cur_i] = beam_cols

            else:
                distrib = torch.distributions.categorical.Categorical(probs=token_probs)
                next_token = distrib.sample()
                # print("next token:",next_token)
                gen_seq[:, cur_i] = next_token

                # Let the transformer decide to end if it wants to
                # NOTE(review): probs were sliced to [..., :TOKEN_END] above, so
                # next_token can seemingly never equal TOKEN_END — verify.
                if(next_token == TOKEN_END):
                    print("Model called end of sequence at:", cur_i, "/", target_seq_length)
                    break

            cur_i += 1
            if(cur_i % 50 == 0):
                print(cur_i, "/", target_seq_length)

        return gen_seq[:, :cur_i]
# Used as a dummy to nn.Transformer
# DummyDecoder
class DummyDecoder(nn.Module):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Identity "decoder" whose forward pass simply echoes the encoder memory.
    Plugging it into nn.Transformer as the custom decoder turns the stack
    into a decoder-only architecture (stacked masked encoders do all the work).
    ----------
    """

    def __init__(self):
        super().__init__()

    def forward(self, tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask, **kwargs):
        """
        ----------
        Author: Damon Gwinn
        ----------
        Ignores every argument except ``memory`` and hands it straight back.
        ----------
        """
        return memory