| import os |
| from pathlib import Path |
| from typing import Union |
|
|
| from transformers import BatchEncoding, PythonBackend |
| from transformers.tokenization_utils_base import TruncationStrategy |
| from transformers.utils.generic import PaddingStrategy, TensorType |
|
|
# Guard the optional MIDI dependencies. `miditok` (which itself pulls in
# `symusic`) is required for everything in this module, so fail fast with an
# actionable install hint instead of a bare ImportError deep inside a call.
try:
    from miditok import PerTok, TokSequence
    from symusic.types import Score
    import symusic
except ImportError as err:
    # Chain explicitly (PEP 3134) so the original failing import is shown as
    # the *direct cause*, not as incidental "during handling" context.
    raise ImportError(
        "The `miditok` library is required for processing MIDI files. "
        "Please install it with `pip install miditok`."
    ) from err
|
|
|
|
class Song2MIDIPerTokTokenizer(PythonBackend):
    """Bridge MidiTok's ``PerTok`` MIDI tokenizer into the ``transformers``
    tokenizer interface.

    Inputs may be a ``symusic.Score``, a path to a MIDI file, raw MIDI
    ``bytes``, a pre-tokenized list of ids, or a batch (list) of any of the
    first three; outputs are standard ``BatchEncoding`` objects so the
    tokenizer plugs into the usual ``transformers`` machinery.
    """

    # File name(s) written by `save_vocabulary`; the MidiTok params JSON
    # doubles as the vocabulary file for `from_pretrained`-style loading.
    vocab_files_names = {"vocab_file": "vocab.json"}

    def __init__(
        self,
        vocab_file: str | os.PathLike | Path,
        unk_token: str = "UNK_None",
        bos_token: str = "BOS_None",
        eos_token: str = "EOS_None",
        pad_token: str = "PAD_None",
        **kwargs,
    ):
        """Build the wrapper from a saved MidiTok params file.

        Args:
            vocab_file: Path to the MidiTok ``PerTok`` params JSON (e.g. as
                produced by :meth:`save_vocabulary`).
            unk_token: Unknown-token string (MidiTok ``TYPE_None`` naming).
            bos_token: Beginning-of-sequence token string.
            eos_token: End-of-sequence token string.
            pad_token: Padding token string.
            **kwargs: Forwarded to the ``PythonBackend`` base constructor.
        """
        self._tokenizer = PerTok(params=vocab_file)

        # When positional tokens are enabled but the loaded params did not
        # materialize their locations, rebuild them here.
        # NOTE(review): `_create_position_tok_locations` is a private MidiTok
        # API — confirm it still exists when upgrading miditok.
        if self._tokenizer.use_position_toks and not getattr(self._tokenizer, "position_locations", None):
            self._tokenizer.position_locations = self._tokenizer._create_position_tok_locations()

        # Inverse vocabulary (token id -> token string), used by `_decode`.
        self._decoder = {value: key for key, value in self._tokenizer.vocab.items()}

        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Number of tokens in the underlying MidiTok vocabulary."""
        return len(self._tokenizer)

    def get_vocab(self):
        """Return the wrapped tokenizer's token-string -> id mapping."""
        return self._tokenizer.vocab

    def _encode_plus(
        self,
        text: Union[Score, Path, bytes, list[Union[Score, Path, bytes]], list[int]],
        text_pair: Union[Score, Path, list[Union[Score, Path]], list[int], None] = None,
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: int | None = None,
        stride: int = 0,
        pad_to_multiple_of: int | None = None,
        padding_side: str | None = None,
        return_tensors: str | TensorType | None = None,
        return_token_type_ids: bool | None = None,
        return_attention_mask: bool | None = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ):
        """Tokenize one MIDI input (or a batch of them) into model inputs.

        ``text``/``text_pair`` keep the base-class parameter names but carry
        MIDI data here: a ``Score``, a path, raw MIDI ``bytes``, a list of
        token ids, or a list of the former three (a batch). Batches are
        encoded item by item without padding, then padded together; single
        inputs are forwarded to ``prepare_for_model``.
        """
        midi = text
        midi_pair = text_pair

        # Treat the input as a batch when it is a (possibly empty) list/tuple
        # whose first element is itself MIDI-like (path, Score, or raw
        # bytes). A list of ints is a single pre-tokenized sequence, NOT a
        # batch, and falls through to the single-input path below.
        is_batched = isinstance(midi, (list, tuple)) and (
            (not midi) or (midi and isinstance(midi[0], (str, Path, Score, bytes)))
        )

        if is_batched:
            if midi_pair is not None:
                if not isinstance(midi_pair, (list, tuple)) or len(midi_pair) != len(
                    midi
                ):
                    raise ValueError(
                        "If `midi` is a batch, `midi_pair` must be a batch of the same length."
                    )
            pairs = midi_pair if midi_pair is not None else [None] * len(midi)

            # Recurse per item with padding/tensor conversion disabled; the
            # whole batch is padded at once below so lengths are uniform.
            batch_outputs = {}
            for current_midi, current_pair in zip(midi, pairs):
                current_output = self._encode_plus(
                    text=current_midi,
                    text_pair=current_pair,
                    add_special_tokens=add_special_tokens,
                    padding_strategy=PaddingStrategy.DO_NOT_PAD,
                    truncation_strategy=truncation_strategy,
                    max_length=max_length,
                    stride=stride,
                    pad_to_multiple_of=None,
                    padding_side=None,
                    return_tensors=None,
                    return_token_type_ids=return_token_type_ids,
                    return_attention_mask=False,
                    return_overflowing_tokens=return_overflowing_tokens,
                    return_special_tokens_mask=return_special_tokens_mask,
                    return_length=return_length,
                    verbose=verbose,
                    **kwargs,
                )
                # Accumulate per-item outputs into parallel lists keyed by
                # feature name (input_ids, attention_mask, ...).
                for key, value in current_output.items():
                    batch_outputs.setdefault(key, []).append(value)

            # Overflow bookkeeping is ragged (different length per item) and
            # cannot be stacked into a rectangular tensor, so drop it when a
            # tensor return type was requested.
            if return_tensors and return_overflowing_tokens:
                batch_outputs.pop("overflowing_tokens", None)
                batch_outputs.pop("num_truncated_tokens", None)

            batch_outputs = self.pad(
                batch_outputs,
                padding=padding_strategy.value,
                max_length=max_length,
                pad_to_multiple_of=pad_to_multiple_of,
                padding_side=padding_side,
                return_attention_mask=return_attention_mask,
            )

            return BatchEncoding(batch_outputs, tensor_type=return_tensors)

        # --- single-input path ---
        def get_input_ids(midi_input):
            # Normalize any accepted single input into a flat list of ids.
            if not midi_input:
                return []
            if isinstance(midi_input, (str, Path, Score, bytes)):
                # Raw MIDI bytes are parsed into a Score first; paths and
                # Score objects are handled by MidiTok directly.
                if isinstance(midi_input, bytes):
                    midi_input = symusic.Score.from_midi(midi_input)
                return self._tokenizer.encode(midi_input).ids
            if isinstance(midi_input, (list, tuple)) and midi_input:
                # Already-tokenized input: pass the ids through unchanged.
                if isinstance(midi_input[0], int):
                    return midi_input

            raise ValueError(
                "Input must be a Score, a path to a MIDI file, or a list of token IDs."
            )

        first_ids = get_input_ids(midi)
        second_ids = get_input_ids(midi_pair) if midi_pair is not None else None

        # Delegate truncation, special-token insertion, padding, and tensor
        # conversion to the transformers base implementation.
        return self.prepare_for_model(
            first_ids,
            pair_ids=second_ids,
            add_special_tokens=add_special_tokens,
            padding=padding_strategy.value,
            truncation=truncation_strategy.value,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            padding_side=padding_side,
            prepend_batch_axis=True,
            return_attention_mask=return_attention_mask,
            return_token_type_ids=return_token_type_ids,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_length=return_length,
            verbose=verbose,
        )

    def _decode(
        self,
        token_ids: int | list[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool | None = None,
        **kwargs,
    ) -> str:
        """Convert token ids into a space-separated token-string text.

        ``clean_up_tokenization_spaces`` is accepted for interface
        compatibility but has no effect on MIDI token strings.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        # `are_ids_encoded=True` tells MidiTok the ids may still be in a
        # trained (e.g. BPE) id space; `decode_token_ids` rewrites the
        # sequence's ids back into the base vocabulary.
        # NOTE(review): presumed in-place mutation of `tok_sequence.ids` —
        # confirm against the installed miditok version.
        tok_sequence = TokSequence(ids=token_ids, are_ids_encoded=True)
        self._tokenizer.decode_token_ids(tok_sequence)

        tokens = [self._decoder[token_id] for token_id in tok_sequence.ids]

        if skip_special_tokens:
            tokens = [
                token for token in tokens if token not in self._tokenizer.special_tokens
            ]

        return " ".join(tokens)

    def decode_score(
        self,
        token_ids: int | list[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool | None = None,
        **kwargs,
    ) -> Score:
        """Decode token ids all the way back to a playable ``symusic`` Score.

        Unlike :meth:`_decode`, this returns the reconstructed MIDI object
        rather than a token-string text. ``skip_special_tokens`` and
        ``clean_up_tokenization_spaces`` are accepted for interface symmetry
        but are not consulted here; MidiTok handles special tokens itself.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        tok_sequence = TokSequence(ids=token_ids, are_ids_encoded=True)
        return self._tokenizer.decode(tok_sequence)

    def save_vocabulary(
        self, save_directory: str, filename_prefix: str | None = None
    ) -> tuple[str, ...]:
        """Save the MidiTok tokenizer params to disk.

        Args:
            save_directory: Existing directory to write into. If it does not
                exist, nothing is saved and an empty tuple is returned.
            filename_prefix: Optional prefix, written as ``"<prefix>-"``.

        Returns:
            A 1-tuple with the written file path, or ``()`` when the target
            directory is missing.
        """
        if not os.path.isdir(save_directory):
            return ()

        prefix = f"{filename_prefix}-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")

        # MidiTok serializes its full configuration (params + vocab) into a
        # single JSON file, which is what `__init__` loads back.
        self._tokenizer.save(vocab_file)

        return (vocab_file,)
|
|