| from typing import Dict, Iterator, List, Optional, Union
|
|
|
| from tokenizers import AddedToken, Tokenizer, decoders, trainers
|
| from tokenizers.models import WordPiece
|
| from tokenizers.normalizers import BertNormalizer
|
| from tokenizers.pre_tokenizers import BertPreTokenizer
|
| from tokenizers.processors import BertProcessing
|
|
|
| from .base_tokenizer import BaseTokenizer
|
|
|
|
|
class BertWordPieceTokenizer(BaseTokenizer):
    """Bert WordPiece Tokenizer

    Wraps a ``tokenizers.Tokenizer`` built from a ``WordPiece`` model, a
    ``BertNormalizer``, a ``BertPreTokenizer`` and a ``decoders.WordPiece``
    decoder. When a vocabulary is supplied at construction time, a
    ``BertProcessing`` post-processor is also installed using the ids of
    ``sep_token`` and ``cls_token`` from that vocabulary.
    """

    def __init__(
        self,
        vocab: Optional[Union[str, Dict[str, int]]] = None,
        unk_token: Union[str, AddedToken] = "[UNK]",
        sep_token: Union[str, AddedToken] = "[SEP]",
        cls_token: Union[str, AddedToken] = "[CLS]",
        pad_token: Union[str, AddedToken] = "[PAD]",
        mask_token: Union[str, AddedToken] = "[MASK]",
        clean_text: bool = True,
        handle_chinese_chars: bool = True,
        strip_accents: Optional[bool] = None,
        lowercase: bool = True,
        wordpieces_prefix: str = "##",
    ):
        """Assemble the WordPiece tokenization pipeline.

        Args:
            vocab: Vocabulary handed directly to the ``WordPiece`` model
                (``from_file`` passes the dict returned by
                ``WordPiece.read_file``). If ``None``, an empty model is built,
                typically to be trained afterwards.
            unk_token, sep_token, cls_token, pad_token, mask_token: Special
                tokens. Each one that already exists in the vocabulary is
                registered as a special token on the tokenizer.
            clean_text, handle_chinese_chars, strip_accents, lowercase:
                Forwarded to ``BertNormalizer``.
            wordpieces_prefix: Continuing-subword prefix used by both the
                model (for encoding) and the decoder.

        Raises:
            TypeError: If ``vocab`` is given but does not contain
                ``sep_token`` or ``cls_token``.
        """
        # The model must use the same continuing-subword prefix as the
        # decoder. Otherwise a non-default ``wordpieces_prefix`` would
        # tokenize with the model's built-in default prefix but decode with
        # ``wordpieces_prefix``.
        model_kwargs = {
            "unk_token": str(unk_token),
            "continuing_subword_prefix": wordpieces_prefix,
        }
        if vocab is not None:
            model = WordPiece(vocab, **model_kwargs)
        else:
            model = WordPiece(**model_kwargs)
        tokenizer = Tokenizer(model)

        # Let the tokenizer know about special tokens, but only those that
        # are actually part of the vocabulary (same order as before).
        for special in (unk_token, sep_token, cls_token, pad_token, mask_token):
            if tokenizer.token_to_id(str(special)) is not None:
                tokenizer.add_special_tokens([str(special)])

        tokenizer.normalizer = BertNormalizer(
            clean_text=clean_text,
            handle_chinese_chars=handle_chinese_chars,
            strip_accents=strip_accents,
            lowercase=lowercase,
        )
        tokenizer.pre_tokenizer = BertPreTokenizer()

        # The post-processor needs concrete ids for [SEP]/[CLS], which only
        # exist when a vocabulary was provided up front.
        if vocab is not None:
            sep_token_id = tokenizer.token_to_id(str(sep_token))
            if sep_token_id is None:
                raise TypeError("sep_token not found in the vocabulary")
            cls_token_id = tokenizer.token_to_id(str(cls_token))
            if cls_token_id is None:
                raise TypeError("cls_token not found in the vocabulary")

            tokenizer.post_processor = BertProcessing(
                (str(sep_token), sep_token_id),
                (str(cls_token), cls_token_id),
            )
        tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix)

        parameters = {
            "model": "BertWordPiece",
            "unk_token": unk_token,
            "sep_token": sep_token,
            "cls_token": cls_token,
            "pad_token": pad_token,
            "mask_token": mask_token,
            "clean_text": clean_text,
            "handle_chinese_chars": handle_chinese_chars,
            "strip_accents": strip_accents,
            "lowercase": lowercase,
            "wordpieces_prefix": wordpieces_prefix,
        }

        super().__init__(tokenizer, parameters)

    @classmethod
    def from_file(cls, vocab: str, **kwargs):
        """Build a tokenizer from a WordPiece vocab file.

        The file is parsed with ``WordPiece.read_file``. The resulting mapping
        and any extra keyword arguments are passed to the constructor. This is
        a classmethod so that subclasses get an instance of their own type.
        """
        return cls(WordPiece.read_file(vocab), **kwargs)

    @staticmethod
    def _build_trainer(
        vocab_size: int,
        min_frequency: int,
        limit_alphabet: int,
        initial_alphabet: Optional[List[str]],
        special_tokens: Optional[List[Union[str, AddedToken]]],
        show_progress: bool,
        wordpieces_prefix: str,
    ):
        """Create the ``WordPieceTrainer`` shared by both training entry points.

        ``None`` for ``initial_alphabet`` / ``special_tokens`` selects the
        historical defaults (an empty alphabet and the five BERT special
        tokens). Fresh lists are built on every call, so no state is shared
        between calls.
        """
        if initial_alphabet is None:
            initial_alphabet = []
        if special_tokens is None:
            special_tokens = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
        return trainers.WordPieceTrainer(
            vocab_size=vocab_size,
            min_frequency=min_frequency,
            limit_alphabet=limit_alphabet,
            initial_alphabet=initial_alphabet,
            special_tokens=special_tokens,
            show_progress=show_progress,
            continuing_subword_prefix=wordpieces_prefix,
        )

    def train(
        self,
        files: Union[str, List[str]],
        vocab_size: int = 30000,
        min_frequency: int = 2,
        limit_alphabet: int = 1000,
        initial_alphabet: Optional[List[str]] = None,
        special_tokens: Optional[List[Union[str, AddedToken]]] = None,
        show_progress: bool = True,
        wordpieces_prefix: str = "##",
    ):
        """Train the model using the given files

        Args:
            files: A single path or a list of paths to train on.
            initial_alphabet: Defaults to an empty list when ``None``.
            special_tokens: Defaults to
                ``["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]`` when ``None``.
            Remaining arguments are forwarded to ``WordPieceTrainer``;
            ``wordpieces_prefix`` becomes its ``continuing_subword_prefix``.
        """
        trainer = self._build_trainer(
            vocab_size,
            min_frequency,
            limit_alphabet,
            initial_alphabet,
            special_tokens,
            show_progress,
            wordpieces_prefix,
        )
        if isinstance(files, str):
            files = [files]
        self._tokenizer.train(files, trainer=trainer)

    def train_from_iterator(
        self,
        iterator: Union[Iterator[str], Iterator[Iterator[str]]],
        vocab_size: int = 30000,
        min_frequency: int = 2,
        limit_alphabet: int = 1000,
        initial_alphabet: Optional[List[str]] = None,
        special_tokens: Optional[List[Union[str, AddedToken]]] = None,
        show_progress: bool = True,
        wordpieces_prefix: str = "##",
        length: Optional[int] = None,
    ):
        """Train the model using the given iterator

        Args:
            iterator: Yields strings, or iterators of strings, to train on.
            length: Optional total number of items; passed to the underlying
                ``train_from_iterator`` call.
            initial_alphabet: Defaults to an empty list when ``None``.
            special_tokens: Defaults to
                ``["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]`` when ``None``.
            Remaining arguments are forwarded to ``WordPieceTrainer``;
            ``wordpieces_prefix`` becomes its ``continuing_subword_prefix``.
        """
        trainer = self._build_trainer(
            vocab_size,
            min_frequency,
            limit_alphabet,
            initial_alphabet,
            special_tokens,
            show_progress,
            wordpieces_prefix,
        )
        self._tokenizer.train_from_iterator(
            iterator,
            trainer=trainer,
            length=length,
        )
|
|
|