# magicBERT / tokenizer.py
# Uploaded via huggingface_hub by nishtahir (commit 7cf5414, verified).
from typing import Iterable
import torch
from transformers import AutoTokenizer, PreTrainedTokenizerFast
class MagicBERTTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer for MagicBERT where each non-special vocab entry is a card name.

    Card names are atomic tokens, so ``encode`` treats list inputs as
    pre-tokenized words (``is_split_into_words=True``) rather than raw text.
    """

    def card_token_ids(self) -> list[int]:
        """Return the sorted token ids of every non-special (card) token."""
        specials = set(self.all_special_tokens)
        vocab = self.get_vocab()
        return sorted(token_id for token, token_id in vocab.items() if token not in specials)

    def is_card_token(self, token: str) -> bool:
        """Return True if ``token`` is a card name, i.e. not a special token."""
        # Direct membership test on the (short) special-tokens list; the
        # previous code built a throwaway set for a single lookup, which
        # costs more than it saves.
        return token not in self.all_special_tokens

    def is_card_id(self, token_id: int) -> bool:
        """Return True if ``token_id`` maps to a card token.

        Out-of-range ids (for which ``convert_ids_to_tokens`` yields a falsy
        value) and special-token ids return False.
        """
        token: str = self.convert_ids_to_tokens(token_id)  # type: ignore
        return bool(token) and self.is_card_token(token)

    def convert_card_names_to_ids(self, names: Iterable[str]) -> torch.Tensor:
        """Map an iterable of card names to a 1-D ``torch.long`` tensor of ids."""
        ids = [self.convert_tokens_to_ids(name) for name in names]
        return torch.tensor(ids, dtype=torch.long)

    def encode(
        self,
        text,
        text_pair=None,
        add_special_tokens=True,
        padding="max_length",  # NOTE: differs from the base-class default on purpose
        truncation=False,
        max_length=None,
        stride=0,
        padding_side=None,
        return_tensors=None,
        **kwargs,
    ):
        """Encode card names, treating list inputs as pre-tokenized words.

        Dispatch (only when ``text_pair`` is None):
          * ``list[str]`` (or ``[]``) -> one pre-split sequence, encoded with
            ``is_split_into_words=True``.
          * ``list[list[str]]`` -> a batch of pre-split sequences; runs the
            full ``__call__`` pipeline and returns only ``input_ids``.
          * anything else -> deferred to the standard ``encode``.
        """
        if isinstance(text, list) and text_pair is None:
            # Flat (possibly empty) list of strings: a single pre-split sequence.
            if not text or isinstance(text[0], str):
                return super().encode(
                    text,
                    text_pair=text_pair,
                    add_special_tokens=add_special_tokens,
                    padding=padding,
                    truncation=truncation,
                    max_length=max_length,
                    stride=stride,
                    padding_side=padding_side,
                    return_tensors=return_tensors,
                    is_split_into_words=True,
                    **kwargs,
                )
            # Nested lists: a batch of pre-split sequences. ``encode`` proper
            # only handles single sequences, so go through ``__call__`` and
            # surface just the ids.
            if isinstance(text[0], list):
                batch_encoding = super().__call__(
                    text=text,
                    add_special_tokens=add_special_tokens,
                    padding=padding,
                    truncation=truncation,
                    max_length=max_length,
                    stride=stride,
                    padding_side=padding_side,
                    return_tensors=return_tensors,
                    is_split_into_words=True,
                    **kwargs,
                )
                return batch_encoding["input_ids"]
        # Raw string (or paired) input: standard behavior.
        return super().encode(
            text,
            text_pair=text_pair,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            padding_side=padding_side,
            return_tensors=return_tensors,
            **kwargs,
        )
MagicBERTTokenizer.register_for_auto_class(AutoTokenizer)