| from typing import Iterable |
|
|
| import torch |
| from transformers import AutoTokenizer, PreTrainedTokenizerFast |
|
|
|
|
class MagicBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer whose non-special vocabulary entries are card names.

    Adds card-centric helpers on top of ``PreTrainedTokenizerFast`` and
    overrides :meth:`encode` so that pre-tokenized card lists — and batches
    of such lists — can be encoded directly.
    """

    def card_token_ids(self) -> list[int]:
        """Return the sorted ids of every non-special (card) token in the vocab."""
        specials = set(self.all_special_tokens)
        return sorted(
            token_id
            for token, token_id in self.get_vocab().items()
            if token not in specials
        )

    def is_card_token(self, token: str) -> bool:
        """Return True if *token* is not one of this tokenizer's special tokens."""
        # The special-token list is short; a direct membership test is
        # equivalent to (and cheaper than) building a fresh set per call.
        return token not in self.all_special_tokens

    def is_card_id(self, token_id: int) -> bool:
        """Return True if *token_id* maps to a known, non-special token."""
        token = self.convert_ids_to_tokens(token_id)
        # Guard against a falsy result (e.g. None for an id outside the
        # vocabulary — NOTE(review): confirm against the transformers docs).
        return bool(token) and self.is_card_token(token)

    def convert_card_names_to_ids(self, names: Iterable[str]) -> torch.Tensor:
        """Map card-name tokens to their ids as a 1-D ``torch.long`` tensor."""
        ids = [self.convert_tokens_to_ids(name) for name in names]
        return torch.tensor(ids, dtype=torch.long)

    def encode(
        self,
        text,
        text_pair=None,
        add_special_tokens=True,
        padding="max_length",
        truncation=False,
        max_length=None,
        stride=0,
        padding_side=None,
        return_tensors=None,
        **kwargs,
    ):
        """Encode *text*, treating list input as pre-tokenized card names.

        Dispatch (only when ``text_pair`` is None):

        * ``list[str]`` (or an empty list) — encoded as one pre-tokenized
          sequence (``is_split_into_words=True``).
        * ``list[list[str]]`` — encoded as a batch of pre-tokenized
          sequences via ``__call__``; returns only the batch ``input_ids``.
        * anything else — deferred to ``PreTrainedTokenizerFast.encode``.

        Note: unlike the base class, ``padding`` defaults to ``"max_length"``.
        """
        # Shared keyword arguments for all three dispatch targets below —
        # keep them in one place so future changes apply uniformly.
        common = dict(
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            padding_side=padding_side,
            return_tensors=return_tensors,
            **kwargs,
        )
        if isinstance(text, list) and text_pair is None:
            if not text or isinstance(text[0], str):
                # Single pre-tokenized sequence (or an empty list).
                return super().encode(
                    text,
                    text_pair=text_pair,
                    is_split_into_words=True,
                    **common,
                )
            if isinstance(text[0], list):
                # Batch of pre-tokenized sequences; surface only the token
                # ids to mirror encode()'s return shape for batches.
                batch_encoding = super().__call__(
                    text=text,
                    is_split_into_words=True,
                    **common,
                )
                return batch_encoding["input_ids"]
        return super().encode(text, text_pair=text_pair, **common)
|
|
|
|
# Register the custom tokenizer with the AutoTokenizer auto-class machinery
# (presumably so a saved checkpoint can be reloaded through AutoTokenizer,
# e.g. with trust_remote_code — verify against the transformers docs).
MagicBERTTokenizer.register_for_auto_class(AutoTokenizer)
|
|