qubitpage commited on
Commit
dfc9678
·
verified ·
1 Parent(s): aa231ce

Upload hf_tokenizer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. hf_tokenizer.py +140 -0
hf_tokenizer.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HuggingFace-compatible tokenizer wrapper for tiktoken cl100k_base.
3
+
4
+ Wraps tiktoken so it works with HF's generate(), lm-evaluation-harness,
5
+ and the Hub (tokenizer.json / tokenizer_config.json).
6
+
7
+ Usage:
8
+ from hf_tokenizer import SentinelBrainTokenizer
9
+ tok = SentinelBrainTokenizer()
10
+ ids = tok("Hello world", return_tensors="pt")
11
+ """
12
+
13
+ import json
14
+ import os
15
+ from typing import Optional, List, Dict, Union
16
+ import tiktoken
17
+ from transformers import PreTrainedTokenizer
18
+
19
+
20
+ class SentinelBrainTokenizer(PreTrainedTokenizer):
21
+ """HuggingFace PreTrainedTokenizer wrapping tiktoken cl100k_base."""
22
+
23
+ vocab_files_names = {"vocab_file": "tiktoken_vocab.json"}
24
+ model_input_names = ["input_ids", "attention_mask"]
25
+
26
+ def __init__(
27
+ self,
28
+ vocab_file: Optional[str] = None,
29
+ eos_token: str = "<|endoftext|>",
30
+ pad_token: str = "<|endoftext|>",
31
+ model_max_length: int = 1024,
32
+ **kwargs,
33
+ ):
34
+ self._enc = tiktoken.get_encoding("cl100k_base")
35
+ self._vocab_size = self._enc.n_vocab # 100277
36
+
37
+ # Build token-to-id mapping for special tokens
38
+ self._special_tokens = {
39
+ "<|endoftext|>": self._enc.eot_token, # 100257
40
+ }
41
+
42
+ super().__init__(
43
+ eos_token=eos_token,
44
+ pad_token=pad_token,
45
+ model_max_length=model_max_length,
46
+ **kwargs,
47
+ )
48
+
49
+ @property
50
+ def vocab_size(self) -> int:
51
+ return self._vocab_size
52
+
53
+ def get_vocab(self) -> Dict[str, int]:
54
+ """Return vocab dict. tiktoken doesn't expose full vocab easily,
55
+ so we return a partial mapping for special tokens."""
56
+ vocab = {}
57
+ # Add special tokens
58
+ for tok, idx in self._special_tokens.items():
59
+ vocab[tok] = idx
60
+ return vocab
61
+
62
+ def _tokenize(self, text: str, **kwargs) -> List[str]:
63
+ """Tokenize into string tokens (HF convention).
64
+ We return token IDs as strings since tiktoken uses bytes."""
65
+ token_ids = self._enc.encode(text, allowed_special={"<|endoftext|>"})
66
+ return [str(tid) for tid in token_ids]
67
+
68
+ def _convert_token_to_id(self, token: str) -> int:
69
+ """Convert string token → ID."""
70
+ if token in self._special_tokens:
71
+ return self._special_tokens[token]
72
+ try:
73
+ return int(token)
74
+ except ValueError:
75
+ return self._enc.eot_token # fallback
76
+
77
+ def _convert_id_to_token(self, index: int) -> str:
78
+ """Convert ID → string token."""
79
+ try:
80
+ return self._enc.decode([index])
81
+ except Exception:
82
+ return "<|unk|>"
83
+
84
+ def convert_tokens_to_string(self, tokens: List[str]) -> str:
85
+ """Convert token strings back to text."""
86
+ ids = []
87
+ for t in tokens:
88
+ try:
89
+ ids.append(int(t))
90
+ except ValueError:
91
+ if t in self._special_tokens:
92
+ ids.append(self._special_tokens[t])
93
+ try:
94
+ return self._enc.decode(ids)
95
+ except Exception:
96
+ return ""
97
+
98
+ def encode(self, text: Union[str, List[str]], add_special_tokens: bool = True,
99
+ **kwargs) -> Union[List[int], List[List[int]]]:
100
+ """Fast-path encode using tiktoken directly."""
101
+ if isinstance(text, str):
102
+ ids = self._enc.encode(text, allowed_special={"<|endoftext|>"})
103
+ return ids
104
+ return [self._enc.encode(t, allowed_special={"<|endoftext|>"}) for t in text]
105
+
106
+ def decode(self, token_ids: Union[List[int], int], skip_special_tokens: bool = False,
107
+ **kwargs) -> str:
108
+ """Fast-path decode using tiktoken directly."""
109
+ if isinstance(token_ids, int):
110
+ token_ids = [token_ids]
111
+ if skip_special_tokens:
112
+ token_ids = [t for t in token_ids if t != self._enc.eot_token]
113
+ try:
114
+ return self._enc.decode(token_ids)
115
+ except Exception:
116
+ return ""
117
+
118
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
119
+ """Save a minimal vocab file so from_pretrained works."""
120
+ if not os.path.isdir(save_directory):
121
+ os.makedirs(save_directory, exist_ok=True)
122
+ prefix = filename_prefix + "-" if filename_prefix else ""
123
+ vocab_file = os.path.join(save_directory, prefix + "tiktoken_vocab.json")
124
+ vocab_data = {
125
+ "encoding": "cl100k_base",
126
+ "vocab_size": self._vocab_size,
127
+ "eos_token_id": self._enc.eot_token,
128
+ "special_tokens": self._special_tokens,
129
+ }
130
+ with open(vocab_file, "w", encoding="utf-8") as f:
131
+ json.dump(vocab_data, f, indent=2)
132
+ return (vocab_file,)
133
+
134
+ @classmethod
135
+ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
136
+ """Load from directory. Falls back to creating fresh tokenizer."""
137
+ try:
138
+ return super().from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
139
+ except Exception:
140
+ return cls(**kwargs)