| |
| |
| |
| |
|
|
| """Tokenization classes for QWen.""" |
|
|
| import base64 |
| import logging |
| import os |
| import re |
| import itertools |
|
|
| import requests |
| import unicodedata |
| from typing import Collection, Dict, List, Set, Tuple, Union, Any, Callable, Optional |
|
|
| import tiktoken |
| import numpy as np |
| from PIL import Image |
| from PIL import ImageFont |
| from PIL import ImageDraw |
| from transformers import PreTrainedTokenizer, AddedToken |
| from transformers.utils import try_to_load_from_cache |
| from transformers.tokenization_utils_base import BatchEncoding,PaddingStrategy,TruncationStrategy,\ |
| TextInput,TextInputPair,PreTokenizedInput,PreTokenizedInputPair,TensorType, EncodedInput, EncodedInputPair |
|
|
| import matplotlib.colors as mcolors |
| from matplotlib.font_manager import FontProperties |
| from .audio import * |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
# Files bundled with a saved tokenizer: the tiktoken BPE vocabulary and a
# font file (the font is used by vision tooling elsewhere, not in this file).
VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken", "ttf": "SimSun.ttf"}

# tiktoken pre-tokenization pattern: splits English contractions, letter runs,
# single digits, punctuation, and whitespace before BPE merging.
# NOTE: uses \p{...} classes, which tiktoken's regex engine supports but the
# stdlib `re` module does not.
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
ENDOFTEXT = "<|endoftext|>"
IMSTART = "<|im_start|>"
IMEND = "<|im_end|>"

# Reserved "<|extra_i|>" tokens pad the special-token id space.
EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
# Base special tokens; ids are assigned in this order, starting after the
# BPE ranks (see QWenTokenizer.__init__).
SPECIAL_TOKENS = (
    ENDOFTEXT,
    IMSTART,
    IMEND,
) + EXTRAS
# Number of placeholder tokens reserved per image span.
# NOTE(review): defined here but not referenced in this file — presumably
# used by the vision counterpart of this tokenizer; confirm before removing.
IMG_TOKEN_SPAN = 256
# ISO code -> language name for the "<|en|>", "<|zh|>", ... language tags
# built into AUDIO_ST below.
LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "it": "italian",
}
|
|
|
|
| def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: |
| with open(tiktoken_bpe_file, "rb") as f: |
| contents = f.read() |
| return { |
| base64.b64decode(token): int(rank) |
| for token, rank in (line.split() for line in contents.splitlines() if line) |
| } |
|
|
| def _list_find( |
| input_list: List[Any], |
| candidates: Tuple[Any], |
| start: int = 0, |
| ): |
| for i in range(start, len(input_list)): |
| if input_list[i] in candidates: |
| return i |
| return -1 |
|
|
| def _replace_closed_tag( |
| input_tokens: List[Any], |
| start_tags: Union[Any, Tuple[Any]], |
| end_tags: Union[Any, Tuple[Any]], |
| inclusive_replace_func: Callable, |
| exclusive_replace_func: Callable = lambda x: x, |
| audio_info: Dict = None |
| ): |
| if isinstance(start_tags, (str, int)): |
| start_tags = (start_tags,) |
| if isinstance(end_tags, (str, int)): |
| end_tags = (end_tags,) |
| assert len(start_tags) == len(end_tags) |
|
|
| output_tokens = [] |
| end = 0 |
| audio_idx = 0 |
| while True: |
| start = _list_find(input_tokens, start_tags, end) |
| if start == -1: |
| break |
| output_tokens.extend(exclusive_replace_func(input_tokens[end : start])) |
| tag_idx = start_tags.index(input_tokens[start]) |
| end = _list_find(input_tokens, (end_tags[tag_idx],), start) |
| if end == -1: |
| raise ValueError("Unclosed image token") |
| output_tokens.extend(inclusive_replace_func(input_tokens[start : end + 1], audio_info, audio_idx)) |
| end += 1 |
| audio_idx += 1 |
| output_tokens.extend(exclusive_replace_func(input_tokens[end : ])) |
| return output_tokens |
|
|
class QWenTokenizer(PreTrainedTokenizer):
    """QWen tokenizer.

    tiktoken-backed BPE tokenizer extended for audio: text wrapped in
    ``audio_start_tag``/``audio_end_tag`` is expanded into a fixed-length run
    of audio pad tokens during tokenization (see ``tokenize``) and collapsed
    back into the audio URL during decoding (see ``_decode``).
    """

    vocab_files_names = VOCAB_FILES_NAMES

    def __init__(
        self,
        vocab_file,
        errors="replace",
        audio_start_tag='<audio>',
        audio_end_tag='</audio>',
        **kwargs,
    ):
        """
        Args:
            vocab_file: path to the tiktoken BPE vocabulary ("qwen.tiktoken").
            errors: error policy forwarded to ``bytes.decode`` when turning
                tokens back into text.
            audio_start_tag / audio_end_tag: surface forms delimiting audio
                URLs inside prompts.
        """
        super().__init__(**kwargs)
        self.audio_start_tag = audio_start_tag
        self.audio_end_tag = audio_end_tag
        # Placeholder token repeated to reserve embedding positions for audio.
        self.audio_pad_tag = "[[[AUDIO:modality]]]"
        # Vision-related special tokens; only consulted by _add_tokens here.
        self.IMAGE_ST = ("<ref>", "</ref>", "<box>", "</box>", "<quad>", "</quad>")

        # Audio-task special tokens, appended after the base SPECIAL_TOKENS
        # when ids are assigned below.  Order matters: it fixes the ids.
        self.AUDIO_ST = (
            '[[[AUDIO:modality]]]',
            # transcription / captioning task markers
            "<|startoftranscript|>",
            "<|startofcaption|>",

            "<|translate|>",
            "<|transcribe|>",
            "<|caption|>",
            "<|keyword|>",

            # language tags built from LANGUAGES, plus extras
            "<|unknown|>",
            *[f"<|{lang}|>" for lang in LANGUAGES.keys()],
            "<|zh_tw|>",

            # timestamp tokens: <|0.00|> ... <|30.00|> in 0.01 s steps
            "<|notimestamps|>",
            "<|sil|>",
            "<|timestamps|>",
            *[f"<|{i * 0.01:.2f}|>" for i in range(3001)],

            # captioning datasets / output styles
            "<|caption_audiocaps|>",
            "<|caption_clotho|>",
            "<|audioset_ontology|>",
            "<|caption_plain|>",
            "<|itn|>",
            "<|wo_itn|>",

            # named-entity recognition markers
            "<|startofentityvalue|>",
            "<|endofentityvalue|>",
            "<|startofentitytype|>",
            "<|endofentitytype|>",
            "<|named_entity_recognition|>",

            # word-level grounding markers
            "<|grounding|>",
            "<|startofword|>",
            "<|endofword|>",
            "<|delim|>",

            "<|emotion_recognition|>",

            "<|music_description|>",

            # music note analysis: full MIDI pitch/velocity grids
            "<|note_analysis|>",
            "<|pitch|>",
            *[f"<|midi_pitch_{i}|>" for i in range(128)],
            "<|velocity|>",
            *[f"<|midi_velocity_{i}|>" for i in range(128)],
            "<|sonic|>",
            "<|instrument|>",

            "<|speaker_meta|>",

            "<|song_meta|>",

            # question answering
            "<|question|>",
            "<|answer|>",
            "<|choice|>",

            "<|scene|>",

            "<|event|>",

            "<|vocal_classification|>",

            # speech understanding / editing / command tasks
            "<|speech_understanding|>",
            "<|scenario|>",
            "<|action|>",
            "<|entities|>",

            "<|speech_edit|>",

            "<|speech_command|>",
            audio_start_tag,
            audio_end_tag
        )

        self.errors = errors  # how to handle decode errors

        # {token_bytes: rank} BPE table loaded from the vocab file.
        self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)
        # Special-token ids start immediately after the BPE ranks.
        self.special_tokens = {
            token: index
            for index, token in enumerate(

                SPECIAL_TOKENS + self.AUDIO_ST, start=len(self.mergeable_ranks)

            )
        }
        self.audio_start_id = self.special_tokens[self.audio_start_tag]
        self.audio_end_id = self.special_tokens[self.audio_end_tag]
        self.audio_pad_id = self.special_tokens[self.audio_pad_tag]
        # NOTE(review): prints on every construction; logger.info would be
        # quieter, but changing it would alter observable behavior.
        print(f"audio_start_id: {self.audio_start_id}, "
              f"audio_end_id: {self.audio_end_id}, "
              f"audio_pad_id: {self.audio_pad_id}.")

        enc = tiktoken.Encoding(
            "Qwen",
            pat_str=PAT_STR,
            mergeable_ranks=self.mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        # Sanity check: the encoding's vocab must cover exactly our tables.
        assert (
            len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
        ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"

        # id -> token surface form (bytes for BPE tokens, str for specials).
        self.decoder = {
            v: k for k, v in self.mergeable_ranks.items()
        }
        self.decoder.update({v: k for k, v in self.special_tokens.items()})

        self.tokenizer = enc  # the underlying tiktoken Encoding

        # Convenience ids for common special tokens.
        self.eod_id = self.tokenizer.eot_token
        self.im_start_id = self.special_tokens[IMSTART]
        self.im_end_id = self.special_tokens[IMEND]

    def __getstate__(self):
        # The tiktoken Encoding holds a non-picklable Rust object; drop it
        # here and rebuild it in __setstate__.
        state = self.__dict__.copy()
        del state['tokenizer']
        return state

    def __setstate__(self, state):
        # Rebuild the tiktoken Encoding dropped by __getstate__ from the
        # pickled rank/special-token tables.
        self.__dict__.update(state)
        enc = tiktoken.Encoding(
            "Qwen",
            pat_str=PAT_STR,
            mergeable_ranks=self.mergeable_ranks,
            special_tokens=self.special_tokens,
        )
        self.tokenizer = enc

    def __len__(self) -> int:
        # Total vocabulary size, BPE ranks plus special tokens.
        return self.tokenizer.n_vocab

    def get_vocab(self) -> Dict[bytes, int]:
        # NOTE(review): returns only the BPE table; special tokens are not
        # included — confirm whether callers expect them here.
        return self.mergeable_ranks

    def convert_tokens_to_ids(
        self, tokens: Union[bytes, str, List[Union[bytes, str]]]
    ) -> List[int]:
        """Map token surface forms to ids.

        NOTE(review): despite the annotation, a single str/bytes input
        returns a single id (or None for an unknown token), not a list.
        """
        ids = []
        if isinstance(tokens, (str, bytes)):
            if tokens in self.special_tokens:
                return self.special_tokens[tokens]
            else:
                return self.mergeable_ranks.get(tokens)
        for token in tokens:
            if token in self.special_tokens:
                ids.append(self.special_tokens[token])
            else:
                # Unknown tokens yield None entries — callers must not pass
                # out-of-vocabulary surface forms.
                ids.append(self.mergeable_ranks.get(token))
        return ids

    def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
        """The vocabulary is fixed: only already-known special tokens may be
        "added" (as a no-op); anything else raises."""
        if not special_tokens and new_tokens:
            raise ValueError('Adding regular tokens is not supported')
        for token in new_tokens:
            surface_form = token.content if isinstance(token, AddedToken) else token
            if surface_form not in SPECIAL_TOKENS + self.IMAGE_ST+ self.AUDIO_ST:
                raise ValueError('Adding unknown special tokens is not supported')
        return 0

    def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
        """
        Save only the vocabulary of the tokenizer (vocabulary).

        Writes the BPE table back in tiktoken format: one
        "<base64 token> <rank>" pair per line.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        file_path = os.path.join(save_directory, "qwen.tiktoken")
        with open(file_path, "w", encoding="utf8") as w:
            for k, v in self.mergeable_ranks.items():
                line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
                w.write(line)
        return (file_path,)

    def tokenize(
        self,
        text: str,
        allowed_special: Union[Set, str] = "all",
        disallowed_special: Union[Collection, str] = (),
        audio_info: Optional[Dict] = None,
        **kwargs,
    ) -> List[Union[bytes, str]]:
        """
        Converts a string in a sequence of tokens.

        Args:
            text (`str`):
                The sequence to be encoded.
            allowed_special (`Literal["all"]` or `set`):
                The surface forms of the tokens to be encoded as special tokens in regular texts.
                Default to "all".
            disallowed_special (`Literal["all"]` or `Collection`):
                The surface forms of the tokens that should not be in regular texts and trigger errors.
                Default to an empty tuple.
            audio_info (`Dict`, *optional*):
                Output of ``process_audio``; its "audio_span_tokens" entry
                determines how many pad tokens replace each audio span.

            kwargs (additional keyword arguments, *optional*):
                Will be passed to the underlying model specific encode method.

        Returns:
            `List[bytes|str]`: The list of tokens (bytes for BPE tokens,
            str for special tokens).
        """
        tokens = []
        text = unicodedata.normalize("NFC", text)

        # Encode to ids, then map back to surface forms so audio spans can be
        # rewritten at the token level below.
        for t in self.tokenizer.encode(
            text, allowed_special=allowed_special, disallowed_special=disallowed_special
        ):
            tokens.append(self.decoder[t])

        def _encode_audiourl(audio_tokens, audio_info, audio_idx):
            # Replace the URL between the audio tags with audio pad tokens;
            # span length (tags included) comes from process_audio.
            assert audio_tokens[0] == self.audio_start_tag and audio_tokens[-1] == self.audio_end_tag
            audio_token_span = audio_info['audio_span_tokens'][audio_idx]
            out_audio_tokens = [self.audio_start_tag] + [self.audio_pad_tag]*(audio_token_span-2) + [self.audio_end_tag]
            return out_audio_tokens

        return _replace_closed_tag(tokens, self.audio_start_tag, self.audio_end_tag, _encode_audiourl, audio_info=audio_info)

    def _batch_encode_plus(
        self,
        batch_text_or_text_pairs: Union[
            List[TextInput],
            List[TextInputPair],
            List[PreTokenizedInput],
            List[PreTokenizedInputPair],
            List[EncodedInput],
            List[EncodedInputPair],
        ],
        add_special_tokens: bool = True,
        padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
        truncation_strategy: TruncationStrategy = TruncationStrategy.DO_NOT_TRUNCATE,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """Batch version of encode: overridden from the HF base class so a
        per-sample ``audio_info`` can be threaded through ``tokenize``."""

        def get_input_ids(text):
            # Accepts raw text, pre-tokenized token lists, or id lists.
            if isinstance(text, str):
                tokens = self.tokenize(text, **kwargs)
                return self.convert_tokens_to_ids(tokens)
            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
                if is_split_into_words:
                    tokens = list(
                        itertools.chain(*(self.tokenize(t, is_split_into_words=True, **kwargs) for t in text))
                    )
                    return self.convert_tokens_to_ids(tokens)
                else:
                    return self.convert_tokens_to_ids(text)
            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
                return text
            else:
                raise ValueError(
                    "Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers."
                )

        if return_offsets_mapping:
            raise NotImplementedError(
                "return_offset_mapping is not available when using Python tokenizers. "
                "To use this feature, change your tokenizer to one deriving from "
                "transformers.PreTrainedTokenizerFast."
            )

        input_ids = []
        audio_info = kwargs.pop("audio_info", None)
        for pair_id in range(len(batch_text_or_text_pairs)):
            # NOTE(review): assumes audio_info is a per-sample indexable
            # aligned with the batch; raises TypeError if it is None —
            # confirm callers always pass it for batched encoding.
            kwargs['audio_info'] = audio_info[pair_id]
            ids_or_pair_ids = batch_text_or_text_pairs[pair_id]
            # Disambiguate single inputs from (text, text_pair) tuples.
            if not isinstance(ids_or_pair_ids, (list, tuple)):
                ids, pair_ids = ids_or_pair_ids, None
            elif is_split_into_words and not isinstance(ids_or_pair_ids[0], (list, tuple)):
                ids, pair_ids = ids_or_pair_ids, None
            else:
                ids, pair_ids = ids_or_pair_ids

            first_ids = get_input_ids(ids)
            second_ids = get_input_ids(pair_ids) if pair_ids is not None else None
            input_ids.append((first_ids, second_ids))

        # Padding/truncation/tensor conversion is delegated to the base class.
        batch_outputs = self._batch_prepare_for_model(
            input_ids,
            add_special_tokens=add_special_tokens,
            padding_strategy=padding_strategy,
            truncation_strategy=truncation_strategy,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
            return_token_type_ids=return_token_type_ids,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_length=return_length,
            return_tensors=return_tensors,
            verbose=verbose,
        )

        return BatchEncoding(batch_outputs)

    def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
        """
        Converts a sequence of tokens in a single string.

        Consecutive bytes tokens are buffered and decoded together so that
        multi-byte UTF-8 characters split across BPE tokens survive.
        """
        text = ""
        temp = b""
        for t in tokens:
            if isinstance(t, str):
                # str tokens are special tokens; flush pending bytes first.
                if temp:
                    text += temp.decode("utf-8", errors=self.errors)
                    temp = b""
                text += t
            elif isinstance(t, bytes):
                temp += t
            else:
                raise TypeError("token should only be of type types or str")
        if temp:
            text += temp.decode("utf-8", errors=self.errors)
        return text

    @property
    def vocab_size(self):
        # Full vocab size including special tokens (mirrors __len__).
        return self.tokenizer.n_vocab

    def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
        """Converts an id to a token, special tokens included"""
        if index in self.decoder:
            return self.decoder[index]
        raise ValueError("unknown ids")

    def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
        """Converts a token to an id using the vocab, special tokens included"""
        if token in self.special_tokens:
            return self.special_tokens[token]
        if token in self.mergeable_ranks:
            return self.mergeable_ranks[token]
        raise ValueError("unknown token")

    def _tokenize(self, text: str, **kwargs):
        """
        Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
        vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).

        Do NOT take care of added tokens.

        Unsupported here: this tokenizer overrides ``tokenize`` directly.
        """
        raise NotImplementedError

    def _decode(
        self,
        token_ids: Union[int, List[int]],
        skip_special_tokens: bool = False,
        errors: Optional[str] = None,
        **kwargs,
    ) -> str:
        """Decode ids to text, restoring audio URLs in place of the pad-token
        spans that ``tokenize`` inserted (requires ``audio_info`` in kwargs)."""
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        audio_info = kwargs.pop("audio_info", None)

        def _decode_audiourl(audio_token_ids, audio_info, audio_idx):
            # Inverse of tokenize's _encode_audiourl: swap the pad-token span
            # back for the encoded audio URL.
            assert audio_token_ids[0] == self.audio_start_id and audio_token_ids[-1] == self.audio_end_id
            audio_url = audio_info["audio_urls"][audio_idx]
            return [self.audio_start_id] + self.tokenizer.encode(audio_url) + [self.audio_end_id]

        token_ids = _replace_closed_tag(token_ids, self.audio_start_id, self.audio_end_id, _decode_audiourl, audio_info=audio_info)

        if skip_special_tokens:
            # Relies on every special-token id being >= eod_id (they are all
            # allocated after the BPE ranks, with ENDOFTEXT first).
            token_ids = [i for i in token_ids if i < self.eod_id]
        return self.tokenizer.decode(token_ids, errors=errors or self.errors)

    def to_list_format(self, text: str):
        """Split raw text into a list of {'audio': url} / {'text': str}
        dicts, one entry per audio span or intervening text run."""
        text = unicodedata.normalize("NFC", text)
        token_ids = self.tokenizer.encode(
            text, allowed_special=set(self.IMAGE_ST + self.AUDIO_ST + (ENDOFTEXT,)))

        def _encode_audio_info(tokens):
            # Used as both the inclusive and exclusive replace function:
            # tag-delimited spans become {'audio': url}, everything else
            # becomes {'text': ...}.
            if len(tokens) == 0:
                return []
            if tokens[0] == self.audio_start_id and tokens[-1] == self.audio_end_id:
                key = 'audio'
            else:
                _tobytes = lambda x: x.encode('utf-8') if isinstance(x, str) else x
                return [{'text': b''.join(map(_tobytes, map(self.decoder.get, tokens))).decode('utf-8')}]
            _tobytes = lambda x: x.encode('utf-8') if isinstance(x, str) else x
            val = b''.join(map(_tobytes, map(self.decoder.get, tokens[1:-1]))).decode('utf-8')
            return [{key: val}]

        # Note: the parenthesized ids below are plain ints, not tuples;
        # _replace_closed_tag wraps scalars itself.
        return _replace_closed_tag(
            token_ids,
            (self.audio_start_id),
            (self.audio_end_id),
            _encode_audio_info,
            _encode_audio_info,
        )

    def from_list_format(self, list_format: List[Dict]):
        """Inverse of ``to_list_format``: render a list of element dicts back
        into a single prompt string."""
        text = ''
        num_audios = 0
        for ele in list_format:
            if 'audio' in ele:
                num_audios += 1
                text += f'Audio {num_audios}:'
                text += self.audio_start_tag + ele['audio'] + self.audio_end_tag
                text += '\n'
            elif 'text' in ele:
                text += ele['text']
            elif 'box' in ele:
                # NOTE(review): self.ref_start_tag / ref_end_tag /
                # box_start_tag / box_end_tag are never defined in this class,
                # so this branch raises AttributeError if reached — likely
                # copied from the vision tokenizer; confirm and fix upstream.
                if 'ref' in ele:
                    text += self.ref_start_tag + ele['ref'] + self.ref_end_tag
                for box in ele['box']:
                    text += self.box_start_tag + '(%d,%d),(%d,%d)' % (box[0], box[1], box[2], box[3]) + self.box_end_tag
            else:
                raise ValueError("Unsupport element: " + str(ele))
        return text

    def extract_audio_urls(self, text):
        """Return every URL enclosed in audio tags, in order of appearance."""
        pattern = rf"{self.audio_start_tag}(.*?){self.audio_end_tag}"
        return re.findall(pattern, text)

    def process_audio(self, text):
        """Load every audio referenced in ``text`` and build the model inputs.

        Returns a dict with stacked mel spectrograms, per-audio lengths, the
        token-span sizes used by ``tokenize``, and the raw URLs — or None if
        the text contains no audio tags.

        NOTE(review): `torch` is not imported in this file; it is presumably
        re-exported by `from .audio import *` along with load_audio /
        pad_or_trim / log_mel_spectrogram / get_T_after_cnn — confirm.
        """
        audio_urls = self.extract_audio_urls(text)
        if len(audio_urls)> 0:
            audios, audio_lens, audio_span_tokens = [], [], []
            for audio_path in audio_urls:
                if audio_path.startswith("http://") or audio_path.startswith("https://"):
                    # Remote audio: fetch the raw bytes over HTTP.
                    data = bytes(requests.get(audio_path, stream=True).content)
                    audio = load_bytesio_audio(data)
                else:
                    audio = load_audio(audio_path)
                # Cap at 480000 samples — presumably 30 s at 16 kHz; confirm
                # against the audio module's constants.
                L = (audio.shape[0] if audio.shape[0] <= 480000 else 480000)
                mel_len = L // 160  # frames; presumably hop length 160
                audio = pad_or_trim(audio.flatten())
                mel = log_mel_spectrogram(audio)
                audio_len_after_cnn = get_T_after_cnn(mel_len)
                # Token count after the audio encoder's downsampling.
                audio_token_num = (audio_len_after_cnn - 2) // 2 + 1
                audio_len = [audio_len_after_cnn, audio_token_num]
                audios.append(mel)
                audio_lens.append(audio_len)
                # +2 accounts for the start/end audio tag tokens.
                audio_span_tokens.append(audio_token_num+2)
            input_audio_lengths = torch.IntTensor(audio_lens)
            input_audios = torch.stack(audios, dim=0)
            return {"input_audios": input_audios,
                    "input_audio_lengths": input_audio_lengths,
                    "audio_span_tokens": audio_span_tokens,
                    "audio_urls": audio_urls}
        else:
            return None
|
|
|
|
|
|
|
|
|
|
|
|