ataeff committed on
Commit
7bffa1f
·
verified ·
1 Parent(s): 413b3cd

Add bpe_tokenizer.py

Browse files
Files changed (1) hide show
  1. bpe_tokenizer.py +225 -0
bpe_tokenizer.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BPE tokenizer for resonance-200m.
3
+ Uses HuggingFace tokenizers (Rust backend) for fast training + encoding.
4
+ Saves merge rules in binary format compatible with C inference.
5
+
6
+ Replaces naive Python BPE (O(n²) per merge = days on 200MB).
7
+ Rust backend: minutes.
8
+ """
9
+
10
+ import struct
11
+ import os
12
+ import json
13
+ import numpy as np
14
+
15
+
16
+ def _byte_to_unicode():
17
+ """GPT-2 byte-to-unicode mapping (ByteLevel pre-tokenizer)."""
18
+ bs = (list(range(ord("!"), ord("~") + 1)) +
19
+ list(range(ord("¡"), ord("¬") + 1)) +
20
+ list(range(ord("®"), ord("ÿ") + 1)))
21
+ cs = bs[:]
22
+ n = 0
23
+ for b in range(256):
24
+ if b not in bs:
25
+ bs.append(b)
26
+ cs.append(256 + n)
27
+ n += 1
28
+ return {b: chr(c) for b, c in zip(bs, cs)}
29
+
30
+
31
class BPETokenizer:
    """BPE tokenizer. 256 byte tokens + learned merges.
    Rust backend for speed. Binary format for C inference.

    Token IDs 0-255 are raw bytes; each learned merge adds one new ID
    starting at 256, in merge-priority order (the format the C inference
    code consumes).
    """

    def __init__(self, max_merges=15936):
        self.max_merges = max_merges
        self.merges = []          # list of (a, b, new_id) triples — C binary format
        self.vocab_size = 256     # 256 byte tokens + len(self.merges)
        self._hf_tok = None       # HuggingFace Tokenizer for fast encode (set by train/load)
        self._remap_lut = None    # numpy LUT: HF_id -> our_id
        # BUGFIX: previously unset until train()/load(), so save() on a fresh
        # or binary-only-loaded instance raised AttributeError at
        # `if self._hf_to_our:`. Initialize explicitly.
        self._hf_to_our = None    # dict: HF_id -> our_id (None until trained/loaded)

    def train(self, text_bytes, num_merges=None, report_every=2000):
        """Learn BPE merges using Rust backend. Minutes, not days.

        Args:
            text_bytes: training corpus as bytes (decoded as UTF-8 with
                replacement before being fed to the HF trainer).
            num_merges: number of merges to learn (capped at max_merges);
                defaults to max_merges.
            report_every: progress-print interval during merge conversion.
        """
        from tokenizers import Tokenizer, models, trainers, pre_tokenizers, decoders

        if num_merges is None:
            num_merges = self.max_merges
        num_merges = min(num_merges, self.max_merges)
        target_vocab = 256 + num_merges

        print(f" [BPE] Training {num_merges} merges on {len(text_bytes)} bytes (Rust backend)...")

        tok = Tokenizer(models.BPE())
        tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
        tok.decoder = decoders.ByteLevel()

        trainer = trainers.BpeTrainer(
            vocab_size=target_vocab,
            min_frequency=2,
            special_tokens=[],
            initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
            show_progress=True,
        )

        # Train line-by-line; free the intermediate str/list promptly to
        # keep peak memory down on large corpora.
        text = text_bytes.decode('utf-8', errors='replace')
        lines = text.split('\n')
        del text

        tok.train_from_iterator(lines, trainer=trainer)
        del lines

        self._hf_tok = tok

        # Extract merges in our (a, b, new_id) format for C inference.
        data = json.loads(tok.to_str())
        hf_merges = data['model']['merges']
        hf_vocab = data['model']['vocab']
        b2u = _byte_to_unicode()

        # str -> our_id mapping for merge conversion: seed with the 256
        # byte tokens (via the ByteLevel unicode representation).
        str_to_our = {}
        for bv in range(256):
            str_to_our[b2u[bv]] = bv

        self.merges = []
        for i, ms in enumerate(hf_merges):
            if i >= num_merges:
                break
            # HF tokenizers >=0.20 returns lists ['a','b'], older returns "a b"
            if isinstance(ms, list):
                if len(ms) != 2:
                    continue
                a_str, b_str = ms[0], ms[1]
            else:
                parts = ms.split(' ', 1)
                if len(parts) != 2:
                    continue
                a_str, b_str = parts[0], parts[1]
            if a_str not in str_to_our or b_str not in str_to_our:
                # Merge references a token we never registered (e.g. one
                # skipped above) — drop it rather than emit a broken rule.
                continue
            a_id = str_to_our[a_str]
            b_id = str_to_our[b_str]
            new_id = 256 + len(self.merges)
            self.merges.append((a_id, b_id, new_id))
            str_to_our[a_str + b_str] = new_id
            if (i + 1) % report_every == 0:
                print(f" [BPE] {i + 1}/{len(hf_merges)} merges converted")

        self.vocab_size = 256 + len(self.merges)

        # Build HF->our remap LUT (numpy vectorized lookup).
        hf_to_our = {}
        for bv in range(256):
            uc = b2u[bv]
            if uc in hf_vocab:
                hf_to_our[hf_vocab[uc]] = bv
        for tok_str, our_id in str_to_our.items():
            if tok_str in hf_vocab and our_id >= 256:
                hf_to_our[hf_vocab[tok_str]] = our_id

        # Identity default for any HF id not in the map; mapped ids are
        # overwritten below.
        max_hf = max(hf_to_our.keys()) + 1 if hf_to_our else 256
        self._remap_lut = np.arange(max_hf, dtype=np.int32)
        for hf_id, our_id in hf_to_our.items():
            self._remap_lut[hf_id] = our_id
        self._hf_to_our = hf_to_our

        print(f" [BPE] Done: {len(self.merges)} merges, vocab={self.vocab_size}")

    def encode(self, text):
        """Encode text (str or bytes) to a list of our token IDs.

        Fast path: Rust HF encode + numpy remap (available after train()
        or a load() that found the HF JSON sidecars). Otherwise falls back
        to applying the merges sequentially in Python.
        """
        if isinstance(text, bytes):
            text = text.decode('utf-8', errors='replace')

        if self._hf_tok is not None and self._remap_lut is not None:
            hf_ids = np.array(self._hf_tok.encode(text).ids, dtype=np.int32)
            return self._remap_lut[hf_ids].tolist()

        # Slow fallback (binary-only load, no HF JSON): apply each merge
        # rule, in priority order, over the whole sequence.
        if isinstance(text, str):
            text = text.encode('utf-8', errors='replace')
        ids = list(text)
        for a, b, new_id in self.merges:
            new_ids = []
            i = 0
            while i < len(ids):
                if i < len(ids) - 1 and ids[i] == a and ids[i + 1] == b:
                    new_ids.append(new_id)
                    i += 2
                else:
                    new_ids.append(ids[i])
                    i += 1
            ids = new_ids
        return ids

    def decode(self, ids):
        """Decode a sequence of token IDs back to bytes.

        Unknown IDs decode to b'?' rather than raising, so partially
        corrupted sequences still produce output.
        """
        vocab = {}
        for i in range(256):
            vocab[i] = bytes([i])
        # Merges are ordered, so both components always exist by the time
        # each new_id is built.
        for a, b, new_id in self.merges:
            vocab[new_id] = vocab[a] + vocab[b]
        out = b''
        for tid in ids:
            out += vocab.get(tid, b'?')
        return out

    def save(self, path):
        """Save binary merges (C format) + optional HF JSON + ID map.

        Binary layout: uint32 count, then count * (uint32 a, b, new_id),
        little-endian. The HF sidecars are written only when present.
        """
        with open(path, 'wb') as f:
            f.write(struct.pack('<I', len(self.merges)))
            for a, b, new_id in self.merges:
                f.write(struct.pack('<III', a, b, new_id))
        print(f" [BPE] Saved {len(self.merges)} merges to {path}")

        base = os.path.splitext(path)[0]
        if self._hf_tok:
            jp = base + '_hf.json'
            self._hf_tok.save(jp)
            print(f" [BPE] Saved HF tokenizer to {jp}")

        if self._hf_to_our:
            mp = base + '_idmap.json'
            with open(mp, 'w') as f:
                # JSON keys must be strings; load() converts them back.
                json.dump({str(k): v for k, v in self._hf_to_our.items()}, f)

    def load(self, path):
        """Load tokenizer from binary + optional HF JSON for fast encode.

        If the `<base>_hf.json` and `<base>_idmap.json` sidecars exist
        next to the binary, the fast Rust encode path is restored too.
        """
        with open(path, 'rb') as f:
            n = struct.unpack('<I', f.read(4))[0]
            self.merges = []
            for _ in range(n):
                a, b, new_id = struct.unpack('<III', f.read(12))
                self.merges.append((a, b, new_id))
        self.vocab_size = 256 + len(self.merges)
        print(f" [BPE] Loaded {len(self.merges)} merges from {path}, vocab={self.vocab_size}")

        base = os.path.splitext(path)[0]
        jp = base + '_hf.json'
        mp = base + '_idmap.json'
        if os.path.exists(jp) and os.path.exists(mp):
            from tokenizers import Tokenizer
            self._hf_tok = Tokenizer.from_file(jp)
            with open(mp) as f:
                raw = json.load(f)
            hf_to_our = {int(k): v for k, v in raw.items()}
            max_hf = max(hf_to_our.keys()) + 1
            self._remap_lut = np.arange(max_hf, dtype=np.int32)
            for hf_id, our_id in hf_to_our.items():
                self._remap_lut[hf_id] = our_id
            self._hf_to_our = hf_to_our
            print(f" [BPE] Loaded HF tokenizer for fast encode")

    def save_copies(self, base_path, n=3):
        """Save tokenizer in N copies. Lesson from Janus 285M disaster.

        Copy 0 goes to base_path; copies 1..n-1 get a `_backup<i>` suffix
        before the extension. Returns the list of written paths.
        """
        paths = []
        for i in range(n):
            if i == 0:
                p = base_path
            else:
                name, ext = os.path.splitext(base_path)
                p = f"{name}_backup{i}{ext}"
            self.save(p)
            paths.append(p)
        print(f" [BPE] Saved {n} copies: {paths}")
        return paths