import os
from pathlib import Path
from typing import Union

from transformers import BatchEncoding, PythonBackend
from transformers.tokenization_utils_base import TruncationStrategy
from transformers.utils.generic import PaddingStrategy, TensorType

try:
    from miditok import PerTok, TokSequence
    from symusic.types import Score
    import symusic
except ImportError:
    raise ImportError(
        "The `miditok` library is required for processing MIDI files. "
        "Please install it with `pip install miditok`."
    )


class Song2MIDIPerTokTokenizer(PythonBackend):
    """Slow (Python-backend) ``transformers`` tokenizer wrapping a miditok ``PerTok``
    MIDI tokenizer.

    Lets MIDI inputs — ``symusic`` ``Score`` objects, MIDI file paths
    (``str``/``Path``), raw MIDI ``bytes``, or pre-tokenized lists of int ids — be
    encoded and decoded through the standard ``transformers`` tokenizer API.
    """

    # File name under which `save_vocabulary` persists the miditok params.
    vocab_files_names = {"vocab_file": "vocab.json"}

    def __init__(
        self,
        vocab_file: str | os.PathLike | Path,
        unk_token: str = "UNK_None",
        bos_token: str = "BOS_None",
        eos_token: str = "EOS_None",
        pad_token: str = "PAD_None",
        **kwargs,
    ) -> None:
        """Load a miditok ``PerTok`` tokenizer from *vocab_file* (a miditok params
        JSON file) and register the special tokens with the base class.

        Args:
            vocab_file: Path to the miditok tokenizer params/vocab JSON.
            unk_token / bos_token / eos_token / pad_token: Special-token strings,
                defaulting to miditok's naming scheme (``"<TYPE>_None"``).
            **kwargs: Forwarded to the ``PythonBackend`` constructor.
        """
        self._tokenizer = PerTok(params=vocab_file)
        # PerTok as of miditok version 3.0.6.post1 does not load position token locations from the vocab file.
        # use_position_toks workaround
        if self._tokenizer.use_position_toks and not getattr(self._tokenizer, "position_locations", None):
            self._tokenizer.position_locations = self._tokenizer._create_position_tok_locations()
        # id -> token-string map, used by `_decode` to render token names.
        self._decoder = {value: key for key, value in self._tokenizer.vocab.items()}
        # NOTE: `_decoder` must exist before super().__init__, which may touch the
        # vocab via the base-class tokenizer machinery — TODO confirm.
        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            **kwargs,
        )

    @property
    def vocab_size(self) -> int:
        """Number of tokens in the underlying miditok vocabulary."""
        return len(self._tokenizer)

    def get_vocab(self) -> dict[str, int]:
        """Return the token-string -> id mapping of the miditok tokenizer."""
        return self._tokenizer.vocab

    def _encode_plus(
        self,
        text: Union[str, Score, Path, bytes, list[Union[str, Score, Path, bytes]], list[int]],
        text_pair: Union[str, Score, Path, bytes, list[Union[str, Score, Path, bytes]], list[int], None] = None,
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: int | None = None,
        stride: int = 0,
        pad_to_multiple_of: int | None = None,
        padding_side: str | None = None,
        return_tensors: str | TensorType | None = None,
        return_token_type_ids: bool | None = None,
        return_attention_mask: bool | None = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ):  # ty: ignore[invalid-method-override]
        """Encode one MIDI input (or a batch of them) into model-ready ids.

        Mirrors the slow-tokenizer `_encode_plus` contract from ``transformers``:
        batched inputs are encoded item by item *unpadded*, then padded together,
        so padding/tensor conversion happen once over the whole batch.

        Args:
            text: A ``Score``, a MIDI file path (``str``/``Path``), raw MIDI
                ``bytes``, a list of int token ids, or a list of the former
                (a batch).
            text_pair: Optional second sequence (same accepted forms); for a
                batched call it must be a batch of the same length.

        Returns:
            A ``BatchEncoding`` (single sequences get a prepended batch axis).

        Raises:
            ValueError: On a batch-length mismatch with `text_pair`, or when an
                input is none of the accepted forms.
        """
        midi = text
        midi_pair = text_pair
        # From https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_python.py (v5.3.0)
        # Note: an empty list counts as an (empty) batch; a list[int] does not.
        is_batched = isinstance(midi, (list, tuple)) and (
            (not midi) or (midi and isinstance(midi[0], (str, Path, Score, bytes)))
        )
        if is_batched:
            if midi_pair is not None:
                if not isinstance(midi_pair, (list, tuple)) or len(midi_pair) != len(
                    midi
                ):
                    raise ValueError(
                        "If `midi` is a batch, `midi_pair` must be a batch of the same length."
                    )
            pairs = midi_pair if midi_pair is not None else [None] * len(midi)
            batch_outputs = {}
            for current_midi, current_pair in zip(midi, pairs):
                # Recurse into the single-sequence path with padding and tensor
                # conversion deferred, so they apply to the whole batch below.
                current_output = self._encode_plus(
                    text=current_midi,
                    text_pair=current_pair,
                    add_special_tokens=add_special_tokens,
                    padding_strategy=PaddingStrategy.DO_NOT_PAD,  # we pad in batch afterward
                    truncation_strategy=truncation_strategy,
                    max_length=max_length,
                    stride=stride,
                    pad_to_multiple_of=None,  # we pad in batch afterward
                    padding_side=None,  # we pad in batch afterward
                    return_tensors=None,  # we convert the whole batch to tensors at the end
                    return_token_type_ids=return_token_type_ids,
                    return_attention_mask=False,  # we pad in batch afterward
                    return_overflowing_tokens=return_overflowing_tokens,
                    return_special_tokens_mask=return_special_tokens_mask,
                    return_length=return_length,
                    verbose=verbose,
                    **kwargs,
                )
                # Collect per-item outputs into lists keyed by output name.
                for key, value in current_output.items():
                    batch_outputs.setdefault(key, []).append(value)
            # Remove overflow-related keys before tensor conversion if return_tensors is set
            # Slow tokenizers don't support returning these as tensors
            if return_tensors and return_overflowing_tokens:
                batch_outputs.pop("overflowing_tokens", None)
                batch_outputs.pop("num_truncated_tokens", None)
            batch_outputs = self.pad(
                batch_outputs,
                padding=padding_strategy.value,
                max_length=max_length,
                pad_to_multiple_of=pad_to_multiple_of,
                padding_side=padding_side,
                return_attention_mask=return_attention_mask,
            )
            return BatchEncoding(batch_outputs, tensor_type=return_tensors)

        # Single sequence handling
        def get_input_ids(midi_input):
            # Convert one input of any accepted form into a flat list of int ids.
            if not midi_input:
                # Empty/None input encodes to an empty id sequence.
                return []
            if isinstance(midi_input, (str, Path, Score, bytes)):
                if isinstance(midi_input, bytes):
                    # Raw MIDI bytes: parse into a symusic Score first.
                    midi_input = symusic.Score.from_midi(midi_input)
                return self._tokenizer.encode(midi_input).ids
            if isinstance(midi_input, (list, tuple)) and midi_input:
                if isinstance(midi_input[0], int):
                    # Already tokenized: pass the ids through unchanged.
                    return midi_input
            raise ValueError(
                "Input must be a Score, a path to a MIDI file, or a list of token IDs."
            )

        first_ids = get_input_ids(midi)
        second_ids = get_input_ids(midi_pair) if midi_pair is not None else None
        # Delegate special tokens, truncation, padding, and tensor conversion to
        # the transformers base-class implementation.
        return self.prepare_for_model(
            first_ids,
            pair_ids=second_ids,
            add_special_tokens=add_special_tokens,
            padding=padding_strategy.value,
            truncation=truncation_strategy.value,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            padding_side=padding_side,
            prepend_batch_axis=True,
            return_attention_mask=return_attention_mask,
            return_token_type_ids=return_token_type_ids,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_length=return_length,
            verbose=verbose,
        )

    def _decode(
        self,
        token_ids: int | list[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool | None = None,
        **kwargs,
    ) -> str:
        """Decode token ids to a space-separated string of miditok token names.

        `clean_up_tokenization_spaces` is accepted for base-class API
        compatibility but has no effect here.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        # NOTE(review): `are_ids_encoded=True` + `decode_token_ids` appears to
        # expand ids from a trained (e.g. BPE) model vocabulary back to base
        # token ids in place — confirm against the miditok docs.
        tok_sequence = TokSequence(ids=token_ids, are_ids_encoded=True)
        self._tokenizer.decode_token_ids(tok_sequence)
        tokens = [self._decoder[token_id] for token_id in tok_sequence.ids]
        if skip_special_tokens:
            tokens = [
                token for token in tokens if token not in self._tokenizer.special_tokens
            ]
        return " ".join(tokens)

    def decode_score(
        self,
        token_ids: int | list[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool | None = None,
        **kwargs,
    ) -> Score:
        """Decode token ids back into a ``symusic`` ``Score``.

        `skip_special_tokens` and `clean_up_tokenization_spaces` are accepted for
        signature symmetry with `_decode` but are not used here.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        tok_sequence = TokSequence(ids=token_ids, are_ids_encoded=True)
        return self._tokenizer.decode(tok_sequence)

    def save_vocabulary(
        self, save_directory: str, filename_prefix: str | None = None
    ) -> tuple[str, ...]:
        """Save the MidiTok tokenizer params to disk.

        Returns:
            A 1-tuple with the written file path, or an empty tuple when
            *save_directory* does not exist (the save is silently skipped).
        """
        if not os.path.isdir(save_directory):
            return ()
        prefix = f"{filename_prefix}-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")
        # Use MidiTok's own serialization
        self._tokenizer.save(vocab_file)
        return (vocab_file,)