| """ Code by Nathan Fradet https://github.com/Natooz """ |
| """ Reorganised from his original Jupyter Notebook into a straight-forward code for quick execution on a supercomputing cluster """ |
|
|
| from copy import deepcopy |
| from pathlib import Path |
| from random import shuffle |
|
|
| from torch import Tensor, argmax |
| from torch.utils.data import DataLoader |
| from torch.cuda import is_available as cuda_available, is_bf16_supported |
| from torch.backends.mps import is_available as mps_available |
| from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig |
| from transformers.trainer_utils import set_seed |
| from evaluate import load as load_metric |
| from miditok import REMI, TokenizerConfig |
| from miditok.pytorch_data import DatasetTok, DataCollator |
| from tqdm import tqdm |
|
|
| |
# Fix every RNG transformers touches (python / numpy / torch) for reproducibility.
set_seed(777)


# --- REMI tokenizer configuration ---
# MIDI pitches 21..108 — presumably an exclusive upper bound, i.e. the 88-key
# piano range A0..C8 — TODO confirm against miditok's convention.
PITCH_RANGE = (21, 109)
# Ticks-per-beat resolution per beat span: finer quantisation near the beat.
BEAT_RES = {(0, 1): 8, (1, 2): 4, (2, 4): 2, (4, 8): 1}
NUM_VELOCITIES = 24
SPECIAL_TOKENS = ["PAD", "MASK", "BOS", "EOS"]
USE_CHORDS = False
USE_RESTS = False
USE_TEMPOS = True
USE_TIME_SIGNATURE = False
USE_PROGRAMS = False
NUM_TEMPOS = 32
TEMPO_RANGE = (50, 200)  # min/max BPM covered by the tempo token grid
TOKENIZER_PARAMS = {
    "pitch_range": PITCH_RANGE,
    "beat_res": BEAT_RES,
    "num_velocities": NUM_VELOCITIES,
    "special_tokens": SPECIAL_TOKENS,
    "use_chords": USE_CHORDS,
    "use_rests": USE_RESTS,
    "use_tempos": USE_TEMPOS,
    "use_time_signatures": USE_TIME_SIGNATURE,
    "use_programs": USE_PROGRAMS,
    "num_tempos": NUM_TEMPOS,
    "tempo_range": TEMPO_RANGE,
}
config = TokenizerConfig(**TOKENIZER_PARAMS)


# Build the REMI tokenizer from the configuration above.
tokenizer = REMI(config)
|
|
| |
# Collect every MIDI file under Maestro/ recursively; .mid files first,
# then .midi, matching the original concatenation order.
midi_paths = [
    path
    for pattern in ("*.mid", "*.midi")
    for path in Path("Maestro").rglob(pattern)
]

# Sanity check: show a handful of the discovered files.
print(midi_paths[:5])

# Learn a 1000-token BPE vocabulary on top of the base REMI vocabulary,
# then persist the tokenizer so the exact same vocab can be reloaded later.
tokenizer.learn_bpe(
    vocab_size=1000,
    files_paths=midi_paths,
    start_from_empty_voc=False,
)
tokenizer.save_params("tokenizer.json")
|
|
| |
# Split the corpus file-wise: 20% validation, 10% test, remainder training.
# Shuffling before slicing makes the split random.
total_num_files = len(midi_paths)
num_files_valid = round(total_num_files * 0.2)
num_files_test = round(total_num_files * 0.1)
shuffle(midi_paths)
midi_paths_valid = midi_paths[:num_files_valid]
midi_paths_test = midi_paths[num_files_valid:num_files_valid + num_files_test]
midi_paths_train = midi_paths[num_files_valid + num_files_test:]


# Datasets of tokenized sequences, chunked to 256-1024 tokens per sample.
# NOTE: dataset_test is built here but not used by the Trainer below;
# it is kept for downstream/final evaluation.
kwargs_dataset = {"min_seq_len": 256, "max_seq_len": 1024, "tokenizer": tokenizer}
dataset_train = DatasetTok(midi_paths_train, **kwargs_dataset)
dataset_valid = DatasetTok(midi_paths_valid, **kwargs_dataset)
dataset_test = DatasetTok(midi_paths_test, **kwargs_dataset)
# The original also constructed a DataCollator at this point, but it was dead
# code: a new collator (with copy_inputs_as_labels=True) is unconditionally
# created right before the Trainer, so the early instance has been removed.
|
|
# Small Mistral-architecture config sized for the BPE vocabulary.
# Special-token ids come straight from the tokenizer so model and data agree.
model_config = MistralConfig(
    vocab_size=len(tokenizer),
    pad_token_id=tokenizer["PAD_None"],
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
    hidden_size=512,
    intermediate_size=2048,
    num_hidden_layers=8,
    num_attention_heads=8,
    num_key_value_heads=4,  # grouped-query attention: 2 query heads per KV head
    sliding_window=256,
    max_position_embeddings=8192,
)


# Fresh, randomly-initialised weights — pretraining from scratch, no checkpoint.
model = AutoModelForCausalLM.from_config(model_config)
|
|
# Hugging Face `evaluate` metrics keyed by name; only accuracy is needed here.
metrics = {"accuracy": load_metric("accuracy")}
|
|
def compute_metrics(eval_pred):
    """
    Compute evaluation metrics for pretraining.

    Relies on a ``preprocess_logits_for_metrics`` hook having already reduced
    the logits to token-id predictions (argmax or sampling).

    :param eval_pred: EvalPrediction carrying predictions and labels.
    :return: dict with accuracy over non-ignored positions.
    """
    predictions, labels = eval_pred
    valid = labels != -100  # -100 marks padding / ignored label positions
    return metrics["accuracy"].compute(
        predictions=predictions[valid].flatten(),
        references=labels[valid].flatten(),
    )
|
|
def preprocess_logits(logits: Tensor, _: Tensor) -> Tensor:
    """
    Reduce per-token logits to predicted token ids (greedy argmax).

    Accumulating ids instead of full vocabulary-sized logits during
    evaluation drastically reduces memory usage and keeps training tractable.
    """
    return argmax(logits, dim=-1)
|
|
| |
# Mixed-precision selection: bf16 on GPUs that support it, fp16 on older
# CUDA GPUs, full fp32 on CPU / Apple Silicon (MPS).
USE_CUDA = cuda_available()
if not USE_CUDA:  # reuse the flag instead of a second cuda_available() call
    FP16 = FP16_EVAL = BF16 = BF16_EVAL = False
elif is_bf16_supported():
    BF16 = BF16_EVAL = True
    FP16 = FP16_EVAL = False
else:
    BF16 = BF16_EVAL = False
    FP16 = FP16_EVAL = True
USE_MPS = not USE_CUDA and mps_available()

# Training hyper-parameters. The first six arguments were originally passed
# positionally, which silently breaks whenever TrainingArguments' signature
# shifts between transformers releases — keywords make the mapping explicit.
training_config = TrainingArguments(
    output_dir="runs",
    overwrite_output_dir=False,
    do_train=True,
    do_eval=True,
    do_predict=False,
    evaluation_strategy="steps",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=48,
    gradient_accumulation_steps=3,  # effective train batch = 16 * 3 per device
    eval_accumulation_steps=None,
    eval_steps=100,
    learning_rate=1e-4,
    weight_decay=0.01,
    max_grad_norm=3.0,
    max_steps=1000,
    lr_scheduler_type="cosine_with_restarts",
    warmup_ratio=0.3,
    log_level="debug",
    logging_strategy="steps",
    logging_steps=20,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=5,
    no_cuda=not USE_CUDA,
    seed=444,
    fp16=FP16,
    fp16_full_eval=FP16_EVAL,
    bf16=BF16,
    bf16_full_eval=BF16_EVAL,
    load_best_model_at_end=True,
    label_smoothing_factor=0.,
    optim="adamw_torch",
    report_to=["tensorboard"],
    gradient_checkpointing=True,  # trade compute for activation memory
)
|
|
# Collator that pads batches with the tokenizer's special-token ids and
# copies inputs as labels, as required for causal-LM (next-token) training.
collator = DataCollator(
    tokenizer["PAD_None"],
    tokenizer["BOS_None"],
    tokenizer["EOS_None"],
    copy_inputs_as_labels=True,
)

# Wire model, data and metric hooks into the Hugging Face Trainer.
trainer = Trainer(
    model=model,
    args=training_config,
    data_collator=collator,
    train_dataset=dataset_train,
    eval_dataset=dataset_valid,
    compute_metrics=compute_metrics,
    callbacks=None,
    preprocess_logits_for_metrics=preprocess_logits,
)
|
|
| |
# Run training, then persist everything needed to resume or deploy:
# final weights, the training metrics, and the full trainer state.
train_result = trainer.train()
trainer.save_model()  # writes the (best, per load_best_model_at_end) model to "runs"
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()
# NOTE(review): push_to_hub assumes Hub credentials / hub_model_id are
# configured in the environment; it will raise otherwise — confirm before
# launching on the cluster.
trainer.push_to_hub()