# tokenization_song2midi.py — Song2MIDI PerTok tokenizer.
# (Uploaded to the Hugging Face Hub under song2midi-processor, commit 56ce0c8.)
import os
from pathlib import Path
from typing import Union
from transformers import BatchEncoding, PythonBackend
from transformers.tokenization_utils_base import TruncationStrategy
from transformers.utils.generic import PaddingStrategy, TensorType
try:
    from miditok import PerTok, TokSequence
    from symusic.types import Score
    import symusic
except ImportError as err:
    # The guard covers both `miditok` and `symusic` (imported above), so the
    # message names both. Chaining with `from err` keeps the original traceback
    # so the user can see exactly which package failed to import.
    raise ImportError(
        "The `miditok` and `symusic` libraries are required for processing MIDI files. "
        "Please install them with `pip install miditok symusic`."
    ) from err
class Song2MIDIPerTokTokenizer(PythonBackend):
    """Slow (Python-backend) ``transformers`` tokenizer wrapping MidiTok's ``PerTok``.

    Instead of text, inputs are MIDI data: a ``symusic`` ``Score``, a path to a
    MIDI file (``str``/``Path``), raw MIDI ``bytes``, or an already-encoded list
    of token IDs. Encoding delegates to the wrapped ``PerTok`` tokenizer;
    padding/truncation/special-token handling is delegated to the
    ``transformers`` base class (``prepare_for_model`` / ``pad``).
    """

    # The MidiTok tokenizer parameters are serialized as a single JSON file.
    vocab_files_names = {"vocab_file": "vocab.json"}

    def __init__(
        self,
        vocab_file: str | os.PathLike | Path,
        unk_token: str = "UNK_None",
        bos_token: str = "BOS_None",
        eos_token: str = "EOS_None",
        pad_token: str = "PAD_None",
        **kwargs,
    ):
        """Load a ``PerTok`` tokenizer from ``vocab_file`` (MidiTok params JSON).

        The special-token defaults follow MidiTok's ``Type_Value`` naming
        convention (e.g. ``"PAD_None"``).
        """
        self._tokenizer = PerTok(params=vocab_file)
        # PerTok as of miditok version 3.0.6.post1 does not load position token
        # locations from the vocab file, so rebuild them here when the loaded
        # tokenizer uses position tokens but came back without locations.
        # use_position_toks workaround
        if self._tokenizer.use_position_toks and not getattr(self._tokenizer, "position_locations", None):
            self._tokenizer.position_locations = self._tokenizer._create_position_tok_locations()
        # Reverse mapping (id -> token string) used by _decode.
        self._decoder = {value: key for key, value in self._tokenizer.vocab.items()}
        # Base-class init last: it may call get_vocab()/vocab_size, which need
        # self._tokenizer to exist already.
        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            **kwargs,
        )

    @property
    def vocab_size(self):
        """Number of tokens in the wrapped MidiTok vocabulary."""
        return len(self._tokenizer)

    def get_vocab(self):
        """Return the token-string -> id mapping of the wrapped tokenizer."""
        return self._tokenizer.vocab

    def _encode_plus(
        self,
        text: Union[Score, Path, bytes, list[Union[Score, Path, bytes]], list[int]],
        text_pair: Union[Score, Path, list[Union[Score, Path]], list[int], None] = None,
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: int | None = None,
        stride: int = 0,
        pad_to_multiple_of: int | None = None,
        padding_side: str | None = None,
        return_tensors: str | TensorType | None = None,
        return_token_type_ids: bool | None = None,
        return_attention_mask: bool | None = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ):  # ty: ignore[invalid-method-override]
        """Encode one MIDI input (or a batch) into model-ready features.

        Parameter names keep the base class's ``text``/``text_pair`` signature;
        they are immediately rebound to ``midi``/``midi_pair`` for clarity.
        A batch is a list/tuple whose elements are paths/Scores/bytes; a plain
        list of ints is treated as a single pre-tokenized sequence.
        """
        midi = text
        midi_pair = text_pair
        # From https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_python.py (v5.3.0)
        # An empty list counts as an (empty) batch; otherwise the first element
        # decides: MIDI-like element -> batch, int -> single id sequence.
        is_batched = isinstance(midi, (list, tuple)) and (
            not midi or isinstance(midi[0], (str, Path, Score, bytes))
        )
        if is_batched:
            if midi_pair is not None:
                if not isinstance(midi_pair, (list, tuple)) or len(midi_pair) != len(
                    midi
                ):
                    raise ValueError(
                        "If `midi` is a batch, `midi_pair` must be a batch of the same length."
                    )
            pairs = midi_pair if midi_pair is not None else [None] * len(midi)
            batch_outputs = {}
            for current_midi, current_pair in zip(midi, pairs):
                # Recurse per element; batch-level padding/tensor conversion is
                # deferred so sequences can be padded to a common length below.
                current_output = self._encode_plus(
                    text=current_midi,
                    text_pair=current_pair,
                    add_special_tokens=add_special_tokens,
                    padding_strategy=PaddingStrategy.DO_NOT_PAD,  # we pad in batch afterward
                    truncation_strategy=truncation_strategy,
                    max_length=max_length,
                    stride=stride,
                    pad_to_multiple_of=None,  # we pad in batch afterward
                    padding_side=None,  # we pad in batch afterward
                    return_tensors=None,  # we convert the whole batch to tensors at the end
                    return_token_type_ids=return_token_type_ids,
                    return_attention_mask=False,  # we pad in batch afterward
                    return_overflowing_tokens=return_overflowing_tokens,
                    return_special_tokens_mask=return_special_tokens_mask,
                    return_length=return_length,
                    verbose=verbose,
                    **kwargs,
                )
                for key, value in current_output.items():
                    batch_outputs.setdefault(key, []).append(value)
            # Remove overflow-related keys before tensor conversion if return_tensors is set
            # Slow tokenizers don't support returning these as tensors
            if return_tensors and return_overflowing_tokens:
                batch_outputs.pop("overflowing_tokens", None)
                batch_outputs.pop("num_truncated_tokens", None)
            batch_outputs = self.pad(
                batch_outputs,
                padding=padding_strategy.value,
                max_length=max_length,
                pad_to_multiple_of=pad_to_multiple_of,
                padding_side=padding_side,
                return_attention_mask=return_attention_mask,
            )
            return BatchEncoding(batch_outputs, tensor_type=return_tensors)

        # Single sequence handling
        def get_input_ids(midi_input):
            """Turn one input into a flat list of token ids via PerTok."""
            if not midi_input:
                return []
            if isinstance(midi_input, (str, Path, Score, bytes)):
                if isinstance(midi_input, bytes):
                    # Raw MIDI bytes must be parsed into a Score first.
                    midi_input = symusic.Score.from_midi(midi_input)
                return self._tokenizer.encode(midi_input).ids
            if isinstance(midi_input, (list, tuple)) and midi_input:
                if isinstance(midi_input[0], int):
                    # Already tokenized: pass through unchanged.
                    return midi_input
            raise ValueError(
                "Input must be a Score, a path to a MIDI file, or a list of token IDs."
            )

        first_ids = get_input_ids(midi)
        second_ids = get_input_ids(midi_pair) if midi_pair is not None else None
        # Special tokens, truncation, padding, masks: all handled by the base class.
        return self.prepare_for_model(
            first_ids,
            pair_ids=second_ids,
            add_special_tokens=add_special_tokens,
            padding=padding_strategy.value,
            truncation=truncation_strategy.value,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            padding_side=padding_side,
            prepend_batch_axis=True,
            return_attention_mask=return_attention_mask,
            return_token_type_ids=return_token_type_ids,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_length=return_length,
            verbose=verbose,
        )

    def _decode(
        self,
        token_ids: int | list[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool | None = None,
        **kwargs,
    ) -> str:
        """Decode ids to a space-joined string of MidiTok token names.

        ``clean_up_tokenization_spaces`` is accepted for interface
        compatibility but has no effect on MIDI token strings.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        tok_sequence = TokSequence(ids=token_ids, are_ids_encoded=True)
        # Resolve any id-level encoding (e.g. BPE) back to base token ids in place.
        self._tokenizer.decode_token_ids(tok_sequence)
        tokens = [self._decoder[token_id] for token_id in tok_sequence.ids]
        if skip_special_tokens:
            tokens = [
                token for token in tokens if token not in self._tokenizer.special_tokens
            ]
        return " ".join(tokens)

    def decode_score(
        self,
        token_ids: int | list[int],
        skip_special_tokens: bool = False,
        clean_up_tokenization_spaces: bool | None = None,
        **kwargs,
    ) -> Score:
        """Decode token ids back into a playable ``symusic`` ``Score``.

        ``skip_special_tokens``/``clean_up_tokenization_spaces`` are accepted
        for signature parity with ``_decode`` but are not used here — MidiTok
        ignores special tokens when reconstructing the score.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        tok_sequence = TokSequence(ids=token_ids, are_ids_encoded=True)
        return self._tokenizer.decode(tok_sequence)

    def save_vocabulary(
        self, save_directory: str, filename_prefix: str | None = None
    ) -> tuple[str, ...]:
        """Save the MidiTok tokenizer params to disk.

        Returns a 1-tuple with the written file path, or an empty tuple if
        ``save_directory`` is not an existing directory.
        """
        import logging  # local import: the module defines no logger of its own

        if not os.path.isdir(save_directory):
            # Fix: this used to fail silently, making save_pretrained() appear
            # to succeed while writing nothing. Log the problem (matching the
            # behavior of transformers' slow tokenizers) before returning.
            logging.getLogger(__name__).error(
                "Vocabulary path (%s) should be a directory", save_directory
            )
            return ()
        prefix = f"{filename_prefix}-" if filename_prefix else ""
        vocab_file = os.path.join(save_directory, prefix + "vocab.json")
        # Use MidiTok's own serialization
        self._tokenizer.save(vocab_file)
        return (vocab_file,)