from copy import deepcopy
from pathlib import Path

from torch.utils.data import DataLoader
from torch.cuda import is_available as cuda_available
from torch.backends.mps import is_available as mps_available
from transformers import GenerationConfig, MistralForCausalLM
from transformers.trainer_utils import set_seed
from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DatasetTok, DataCollator
from tqdm import tqdm
|
# Tokenizer configuration. `config` itself is only needed if you train a
# tokenizer from scratch (see the sketch below); this script loads a
# pretrained tokenizer instead.
PITCH_RANGE = (21, 109)
BEAT_RES = {(0, 1): 8, (1, 2): 4, (2, 4): 2, (4, 8): 1}
NUM_VELOCITIES = 24
SPECIAL_TOKENS = ["PAD", "MASK", "BOS", "EOS"]
USE_CHORDS = False
USE_RESTS = False
USE_TEMPOS = True
USE_TIME_SIGNATURE = False
USE_PROGRAMS = False
NUM_TEMPOS = 32
TEMPO_RANGE = (50, 200)  # (min, max) tempo in BPM
TOKENIZER_PARAMS = {
    "pitch_range": PITCH_RANGE,
    "beat_res": BEAT_RES,
    "num_velocities": NUM_VELOCITIES,
    "special_tokens": SPECIAL_TOKENS,
    "use_chords": USE_CHORDS,
    "use_rests": USE_RESTS,
    "use_tempos": USE_TEMPOS,
    "use_time_signatures": USE_TIME_SIGNATURE,
    "use_programs": USE_PROGRAMS,
    "num_tempos": NUM_TEMPOS,
    "tempo_range": TEMPO_RANGE,
}
config = TokenizerConfig(**TOKENIZER_PARAMS)
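# If you wanted to build this tokenizer yourself rather than load the
# pretrained one below, a minimal sketch would be (assuming a miditok version
# that exposes `learn_bpe` -- newer releases rename it to `train` -- and with
# an arbitrary illustrative vocab_size; `midi_paths` is defined further down):
#
#     tokenizer = REMI(config)
#     tokenizer.learn_bpe(vocab_size=24000, files_paths=midi_paths)
#     tokenizer.save_params(Path("tokenizer.json"))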
|
|
# Seed everything for reproducible sampling
set_seed(777)
|
|
# Load the pretrained tokenizer
tokenizer = REMI.from_pretrained("sunsetsobserver/MIDI")
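# `from_pretrained` fetches the saved tokenizer parameters from the
# "sunsetsobserver/MIDI" repository on the Hugging Face Hub, rather than
# building a new vocabulary from `config` above.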
|
|
# Gather all MIDI files to use as generation prompts
midi_paths = list(Path('input').glob('**/*.mid')) + list(Path('input').glob('**/*.midi'))
# To also include the Maestro dataset, extend the list, e.g.:
# midi_paths += list(Path('Maestro').glob('**/*.mid')) + list(Path('Maestro').glob('**/*.midi'))
|
|
# Chunk the tokenized files into sequences of 10 to 1024 tokens
kwargs_dataset = {"min_seq_len": 10, "max_seq_len": 1024, "tokenizer": tokenizer}
dataset_test = DatasetTok(midi_paths, **kwargs_dataset)
collator = DataCollator(
    tokenizer["PAD_None"],
    tokenizer["BOS_None"],
    tokenizer["EOS_None"],
    copy_inputs_as_labels=True,
)
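# Each dataset item exposes its token ids under "input_ids"; the collator
# pads them into batches and builds the matching "attention_mask" that
# `model.generate` consumes below.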
|
|
# Load the trained model checkpoint from the local ./runs directory
model = MistralForCausalLM.from_pretrained("./runs")
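# Optional: run generation on an accelerator when one is available, so that
# `model.device` below points at the GPU / Apple Silicon device rather than
# the CPU.
if cuda_available():
    model = model.to("cuda")
elif mps_available():
    model = model.to("mps")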
|
|
|
|
# Directory for the generated results
gen_results_path = Path('gen_res')
gen_results_path.mkdir(parents=True, exist_ok=True)

# Top-k / nucleus sampling with a mild temperature; beam search is disabled
# (num_beams=1)
generation_config = GenerationConfig(
    max_new_tokens=512,  # extend each prompt by up to 512 tokens
    num_beams=1,
    do_sample=True,
    temperature=0.9,
    top_k=15,
    top_p=0.95,
    epsilon_cutoff=3e-4,
    eta_cutoff=1e-3,
)
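# With these settings, each step samples from at most the 15 most probable
# tokens (top_k) within the 0.95 nucleus (top_p); epsilon_cutoff and
# eta_cutoff additionally drop tokens whose conditional probability falls
# below a small (for eta, entropy-scaled) threshold.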
|
|
# Generate continuations: pad prompts on the left so the generated tokens
# directly follow the prompt, and don't append an EOS token to the prompts
collator.pad_on_left = True
collator.eos_token = None
dataloader_test = DataLoader(dataset_test, batch_size=1, collate_fn=collator)
model.eval()
count = 0
for batch in tqdm(dataloader_test, desc='Testing model / Generating results'):
    res = model.generate(
        inputs=batch["input_ids"].to(model.device),
        attention_mask=batch["attention_mask"].to(model.device),
        generation_config=generation_config)  # res contains the prompt + generated tokens
|
|
    # Decode each prompt / continuation pair into a MIDI file with three
    # tracks: the generated continuation alone, the original prompt, and the
    # prompt followed by its continuation
    for prompt, continuation in zip(batch["input_ids"], res):
        generated = continuation[len(prompt):]
        midi = tokenizer.tokens_to_midi([deepcopy(generated.tolist())])
        tokens = [generated, prompt, continuation]
        tokens = [seq.tolist() for seq in tokens]
        for tok_seq in tokens[1:]:
            _midi = tokenizer.tokens_to_midi([deepcopy(tok_seq)])
            midi.tracks.append(_midi.tracks[0])
        midi.tracks[0].name = f'Continuation of original sample ({len(generated)} tokens)'
        midi.tracks[1].name = f'Original sample ({len(prompt)} tokens)'
        midi.tracks[2].name = 'Original sample and continuation'
        midi.dump_midi(gen_results_path / f'{count}.mid')
        tokenizer.save_tokens(tokens, gen_results_path / f'{count}.json')

        count += 1