| from typing import TYPE_CHECKING, List, Optional, Tuple |
|
|
| from transformers.tokenization_utils import PreTrainedTokenizer, BatchEncoding |
| from transformers.utils import logging, TensorType, to_py_obj |
|
|
| try: |
| from ariautils.midi import MidiDict |
| from ariautils.tokenizer import AbsTokenizer |
| from ariautils.tokenizer._base import Token |
| except ImportError: |
| raise ImportError( |
| "ariautils is not installed. Please try `pip install git+https://github.com/EleutherAI/aria-utils.git`." |
| ) |
|
|
| if TYPE_CHECKING: |
| pass |
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
class AriaTokenizer(PreTrainedTokenizer):
    """
    Aria Tokenizer is NOT a BPE tokenizer. A midi file will be converted to a
    MidiDict (note: in fact, a MidiDict is not a single dict. It is more about
    a list of "notes") which represents a sequence of notes, stops, etc. And
    then, aria tokenizer is simply a dictionary that maps MidiDict to discrete
    indices according to a hard-coded rule.

    For a FIM finetuned model, we also follow a simple FIM format to guide a
    piece of music to a (possibly very different) suffix according to the prompts:
    <GUIDANCE-START> ... <GUIDANCE-END> <S> <PROMPT-START> ... <PROMPT-END>
    This way, we expect a continuation that connects PROMPT and GUIDANCE.
    """

    vocab_files_names = {}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        add_eos_token=True,
        add_dim_token=False,
        clean_up_tokenization_spaces=False,
        use_default_system_prompt=False,
        **kwargs,
    ):
        """Build the wrapper around ``ariautils``' ``AbsTokenizer``.

        Args:
            add_eos_token: Default for appending the EOS token in `tokenize`.
            add_dim_token: Default for appending the DIM token in `tokenize`.
            clean_up_tokenization_spaces: Forwarded to the HF base class
                (meaningless for MIDI, kept for API compatibility).
            use_default_system_prompt: Stored and forwarded to the base class.
        """
        self._tokenizer = AbsTokenizer()

        self.add_eos_token = add_eos_token
        self.add_dim_token = add_dim_token
        self.use_default_system_prompt = use_default_system_prompt

        # Special tokens are dictated by the underlying AbsTokenizer rule set.
        super().__init__(
            bos_token=self._tokenizer.bos_tok,
            eos_token=self._tokenizer.eos_tok,
            unk_token=self._tokenizer.unk_tok,
            pad_token=self._tokenizer.pad_tok,
            # Fix: this argument was previously accepted but silently dropped.
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            use_default_system_prompt=use_default_system_prompt,
            **kwargs,
        )

    def __getstate__(self):
        # Pickling is deliberately unsupported: AbsTokenizer state is rebuilt
        # from its hard-coded rules, not serialized.
        return {}

    def __setstate__(self, d):
        raise NotImplementedError()

    @property
    def vocab_size(self):
        """Returns vocab size"""
        return self._tokenizer.vocab_size

    def get_vocab(self):
        """Return the token -> id mapping of the underlying tokenizer."""
        return self._tokenizer.tok_to_id

    def tokenize(
        self,
        midi_dict: MidiDict,
        add_dim_token: Optional[bool] = None,
        add_eos_token: Optional[bool] = None,
        **kwargs,
    ) -> List[Token]:
        """Convert a MidiDict into a list of Aria tokens.

        ``add_dim_token`` / ``add_eos_token`` override the instance defaults
        when explicitly given (``None`` means "use the default").
        """
        return self._tokenizer.tokenize(
            midi_dict=midi_dict,
            add_dim_tok=(
                self.add_dim_token if add_dim_token is None else add_dim_token
            ),
            add_eos_tok=(
                self.add_eos_token if add_eos_token is None else add_eos_token
            ),
        )

    def _tokenize(
        self,
        midi_dict: MidiDict,
        add_dim_token: Optional[bool] = None,
        add_eos_token: Optional[bool] = None,
        **kwargs,
    ) -> List[Token]:
        """Low-level hook: forwards the flags verbatim (no instance defaults)."""
        return self._tokenizer.tokenize(
            midi_dict=midi_dict,
            add_dim_tok=add_dim_token,
            add_eos_tok=add_eos_token,
        )

    def __call__(
        self,
        midi_dicts: MidiDict | list[MidiDict],
        padding: bool = False,
        max_length: int | None = None,
        pad_to_multiple_of: int | None = None,
        return_tensors: str | TensorType | None = None,
        return_attention_mask: bool | None = None,
        **kwargs,
    ) -> BatchEncoding:
        """Encode one or more MidiDicts into ``input_ids`` / ``attention_mask``.

        It is impossible to rely on the parent method because the inputs are
        MidiDict(s) instead of strings. I do not like the idea of going hacky
        so that two entirely different types of inputs can marry. So here I
        reimplement __call__ with limited support of certain useful arguments.
        I do not expect any conflict with other "string-in-ids-out" tokenizers.
        If you have to mix up the API of string-based tokenizers and our
        midi-based tokenizer, there must be a problem with your design.

        Args:
            midi_dicts: A single MidiDict or a list of them.
            padding: Right-pad all sequences to a common length with the pad id.
            max_length: Truncate each encoded sequence to this length.
            pad_to_multiple_of: Round the padded length up to a multiple of this.
            return_tensors: Tensor type forwarded to ``BatchEncoding``.
            return_attention_mask: Set to ``False`` to omit the attention mask.
        """
        if isinstance(midi_dicts, MidiDict):
            midi_dicts = [midi_dicts]

        all_tokens: list[list[int]] = []
        all_attn_masks: list[list[int]] = []
        max_len_encoded = 0
        # NOTE(review): encoding here uses the underlying tokenizer's own
        # defaults and bypasses self.add_eos_token / self.add_dim_token —
        # confirm this asymmetry with `tokenize()` is intended.
        for md in midi_dicts:
            token_ids = self._tokenizer.encode(self._tokenizer.tokenize(md))
            if max_length is not None:
                token_ids = token_ids[:max_length]
            max_len_encoded = max(max_len_encoded, len(token_ids))
            all_tokens.append(token_ids)
            all_attn_masks.append([1] * len(token_ids))

        if pad_to_multiple_of is not None and max_len_encoded % pad_to_multiple_of:
            # Round UP to the nearest multiple. Fix: the previous expression
            # added a whole extra multiple when the length was already aligned.
            max_len_encoded = (
                max_len_encoded // pad_to_multiple_of + 1
            ) * pad_to_multiple_of
        if padding:
            pad_id = self._tokenizer.pad_id
            for token_ids, attn_mask in zip(all_tokens, all_attn_masks):
                # Fix: compute the pad length BEFORE extending `token_ids`;
                # previously the mask was extended by (target - len(tokens))
                # AFTER the tokens were padded, i.e. by zero, leaving the
                # attention masks unpadded and ragged.
                n_pad = max_len_encoded - len(token_ids)
                token_ids.extend([pad_id] * n_pad)
                attn_mask.extend([0] * n_pad)

        data = {"input_ids": all_tokens}
        if return_attention_mask is not False:
            # Fix: key was "attention_masks" (plural), contradicting
            # model_input_names and the kwarg name HF models expect.
            data["attention_mask"] = all_attn_masks
        return BatchEncoding(data, tensor_type=return_tensors)

    def decode(self, token_ids: List[int], **kwargs) -> MidiDict:
        """Convert a sequence of ids back into a MidiDict."""
        token_ids = to_py_obj(token_ids)

        return self._tokenizer.detokenize(self._tokenizer.decode(token_ids))

    def batch_decode(
        self, token_ids_list: List[List[Token]], **kwargs
    ) -> List[MidiDict]:
        """Decode a batch of id sequences; one MidiDict per sequence."""
        return [self.decode(token_ids) for token_ids in token_ids_list]

    def encode_from_file(self, filename: str, **kwargs) -> BatchEncoding:
        """Load a midi file and encode it; kwargs are forwarded to __call__."""
        midi_dict = MidiDict.from_midi(filename)
        return self(midi_dict, **kwargs)

    def encode_from_files(
        self, filenames: list[str], **kwargs
    ) -> BatchEncoding:
        """Load several midi files and encode them as one batch."""
        midi_dicts = [MidiDict.from_midi(file) for file in filenames]
        return self(midi_dicts, **kwargs)

    def _convert_token_to_id(self, token: Token):
        """Converts a token (tuple or str) into an id.

        Unknown tokens map to the id of the unk token.
        """
        return self._tokenizer.tok_to_id.get(
            token, self._tokenizer.tok_to_id[self.unk_token]
        )

    def _convert_id_to_token(self, index: int):
        """Converts an index (integer) in a token (tuple or str)."""
        return self._tokenizer.id_to_tok.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[Token]) -> MidiDict:
        """Converts a sequence of tokens into a single MidiDict."""
        return self._tokenizer.detokenize(tokens)

    def save_vocabulary(
        self, save_directory, filename_prefix: Optional[str] = None
    ) -> Tuple[str]:
        # The vocabulary is hard-coded in ariautils; there is nothing to save.
        raise NotImplementedError()
|
|