from typing import Iterable

import torch
from transformers import AutoTokenizer, PreTrainedTokenizerFast


class MagicBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer whose vocabulary maps one card name to one token.

    Every non-special token in the vocabulary is treated as a card name.
    """

    def card_token_ids(self) -> list[int]:
        # All vocabulary ids except the special tokens ([CLS], [PAD], ...).
        specials = set(self.all_special_tokens)
        vocab = self.get_vocab()
        return sorted(
            token_id for token, token_id in vocab.items() if token not in specials
        )

    def is_card_token(self, token: str) -> bool:
        # Any non-special token is a card name.
        return token not in set(self.all_special_tokens)

    def is_card_id(self, token_id: int) -> bool:
        token: str = self.convert_ids_to_tokens(token_id)  # type: ignore
        return bool(token) and self.is_card_token(token)

    def convert_card_names_to_ids(self, names: Iterable[str]) -> torch.Tensor:
        # Look up each card name directly; no sub-word splitting is involved.
        ids = [self.convert_tokens_to_ids(name) for name in names]
        return torch.tensor(ids, dtype=torch.long)

    def encode(
        self,
        text,
        text_pair=None,
        add_special_tokens=True,
        padding="max_length",
        truncation=False,
        max_length=None,
        stride=0,
        padding_side=None,
        return_tensors=None,
        **kwargs,
    ):
        if isinstance(text, list) and text_pair is None:
            if not text or isinstance(text[0], str):
                # A flat list of card names: pass each name as a pre-split
                # "word" so the tokenizer never breaks it into sub-words.
                return super().encode(
                    text,
                    text_pair=text_pair,
                    add_special_tokens=add_special_tokens,
                    padding=padding,
                    truncation=truncation,
                    max_length=max_length,
                    stride=stride,
                    padding_side=padding_side,
                    return_tensors=return_tensors,
                    is_split_into_words=True,
                    **kwargs,
                )
            if isinstance(text[0], list):
                # A batch of decks (list of lists of card names): encode the
                # whole batch at once and return only the input ids.
                batch_encoding = super().__call__(
                    text=text,
                    add_special_tokens=add_special_tokens,
                    padding=padding,
                    truncation=truncation,
                    max_length=max_length,
                    stride=stride,
                    padding_side=padding_side,
                    return_tensors=return_tensors,
                    is_split_into_words=True,
                    **kwargs,
                )
                return batch_encoding["input_ids"]
        # Fall back to the stock behaviour for plain strings and text pairs.
        return super().encode(
            text,
            text_pair=text_pair,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            padding_side=padding_side,
            return_tensors=return_tensors,
            **kwargs,
        )


# Register the class so that AutoTokenizer.from_pretrained(...,
# trust_remote_code=True) resolves to MagicBERTTokenizer on load.
MagicBERTTokenizer.register_for_auto_class(AutoTokenizer)
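

# --- Usage sketch ------------------------------------------------------------
# A minimal demonstration of the two list-shaped encode() paths above. The
# checkpoint path "path/to/magicbert-tokenizer" and the card names are
# hypothetical: they stand in for any saved tokenizer whose vocabulary holds
# one token per card name.
if __name__ == "__main__":
    tokenizer = MagicBERTTokenizer.from_pretrained("path/to/magicbert-tokenizer")

    # Flat list of card names -> one encoded sequence; is_split_into_words
    # keeps each name intact as a single token.
    deck = ["Lightning Bolt", "Counterspell", "Brainstorm"]
    ids = tokenizer.encode(deck, max_length=8, return_tensors="pt")

    # List of lists -> batched input_ids, one row per deck, padded to
    # max_length.
    batch_ids = tokenizer.encode(
        [deck, ["Llanowar Elves", "Giant Growth"]],
        max_length=8,
        return_tensors="pt",
    )

    # Direct name -> id lookup, with no special tokens or padding involved.
    card_ids = tokenizer.convert_card_names_to_ids(deck)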