| """ |
| BPE tokenizer for resonance-200m. |
| Uses HuggingFace tokenizers (Rust backend) for fast training + encoding. |
| Saves merge rules in binary format compatible with C inference. |
| |
| Replaces naive Python BPE (O(n²) per merge = days on 200MB). |
| Rust backend: minutes. |
| """ |
|
|
| import struct |
| import os |
| import json |
| import numpy as np |
|
|
|
|
| def _byte_to_unicode(): |
| """GPT-2 byte-to-unicode mapping (ByteLevel pre-tokenizer).""" |
| bs = (list(range(ord("!"), ord("~") + 1)) + |
| list(range(ord("¡"), ord("¬") + 1)) + |
| list(range(ord("®"), ord("ÿ") + 1))) |
| cs = bs[:] |
| n = 0 |
| for b in range(256): |
| if b not in bs: |
| bs.append(b) |
| cs.append(256 + n) |
| n += 1 |
| return {b: chr(c) for b, c in zip(bs, cs)} |
|
|
|
|
class BPETokenizer:
    """Byte-level BPE tokenizer: 256 byte tokens + learned merges.

    Training and fast encoding go through the HuggingFace ``tokenizers``
    Rust backend. Merges are also saved in a flat binary format for C
    inference: a little-endian ``uint32`` count, then ``(a, b, new_id)``
    ``uint32`` triples.

    Our IDs 0..255 are raw byte values, and merged tokens get IDs >= 256.
    HF assigns its own IDs, so ``_hf_to_our`` / ``_remap_lut`` translate HF
    IDs into ours on the fast encode path.
    """

    def __init__(self, max_merges=15936):
        self.max_merges = max_merges
        self.merges = []          # (a_id, b_id, new_id) triples, applied in order
        self.vocab_size = 256
        self._hf_tok = None       # HF Tokenizer used by the fast encode path
        self._remap_lut = None    # np.int32 array: HF id -> our id
        # HF id -> our id. Initialised here so save() also works on instances
        # that were never trained or were loaded without the HF sidecar files.
        self._hf_to_our = None

    @staticmethod
    def _build_remap_lut(hf_to_our, hf_vocab_size=0):
        """Return a dense ``np.int32`` table mapping HF token IDs to our IDs.

        The table has at least ``hf_vocab_size`` entries and covers every key
        of ``hf_to_our``, so indexing it with an in-vocabulary HF id never
        goes out of bounds. IDs without a mapping keep their own value
        (identity), as in the original construction.
        """
        size = max(256, hf_vocab_size, max(hf_to_our, default=-1) + 1)
        lut = np.arange(size, dtype=np.int32)
        for hf_id, our_id in hf_to_our.items():
            lut[hf_id] = our_id
        return lut

    @staticmethod
    def _apply_merge(ids, a, b, new_id):
        """Return ``ids`` with each left-to-right, non-overlapping ``(a, b)`` pair replaced by ``new_id``."""
        out = []
        i = 0
        n = len(ids)
        while i < n:
            if i < n - 1 and ids[i] == a and ids[i + 1] == b:
                out.append(new_id)
                i += 2
            else:
                out.append(ids[i])
                i += 1
        return out

    def train(self, text_bytes, num_merges=None, report_every=2000):
        """Learn BPE merges using the Rust backend. Minutes, not days.

        Args:
            text_bytes: training corpus as ``bytes``. It is decoded as UTF-8,
                with invalid sequences replaced by U+FFFD.
            num_merges: number of merges to learn. Defaults to, and is capped
                at, ``self.max_merges``.
            report_every: print a progress line every this many HF merges.

        Afterwards ``self.merges`` holds ``(a_id, b_id, new_id)`` triples in
        our ID space (``new_id = 256 + index``), ``self.vocab_size`` is
        ``256 + len(self.merges)``, and the trained HF tokenizer plus its ID
        remap are kept for fast :meth:`encode`.
        """
        from tokenizers import Tokenizer, models, trainers, pre_tokenizers, decoders

        if num_merges is None:
            num_merges = self.max_merges
        num_merges = min(num_merges, self.max_merges)
        target_vocab = 256 + num_merges

        print(f" [BPE] Training {num_merges} merges on {len(text_bytes)} bytes (Rust backend)...")

        tok = Tokenizer(models.BPE())
        tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
        tok.decoder = decoders.ByteLevel()

        trainer = trainers.BpeTrainer(
            vocab_size=target_vocab,
            min_frequency=2,
            special_tokens=[],
            initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
            show_progress=True,
        )

        text = text_bytes.decode('utf-8', errors='replace')
        # Feed the corpus line by line. The '\n' separators are not part of
        # any training sequence.
        lines = text.split('\n')
        del text  # release the full decoded copy before training

        tok.train_from_iterator(lines, trainer=trainer)
        del lines

        self._hf_tok = tok

        # Pull merges and vocab out of the serialized HF model.
        data = json.loads(tok.to_str())
        hf_merges = data['model']['merges']
        hf_vocab = data['model']['vocab']
        b2u = _byte_to_unicode()

        # Token string (in ByteLevel unicode form) -> our ID. Seeded with
        # the 256 single-byte tokens.
        str_to_our = {b2u[bv]: bv for bv in range(256)}

        self.merges = []
        for i, ms in enumerate(hf_merges):
            if i >= num_merges:
                break
            # Newer tokenizers serialize a merge as ["a", "b"], older ones as "a b".
            if isinstance(ms, list):
                if len(ms) != 2:
                    continue
                a_str, b_str = ms[0], ms[1]
            else:
                parts = ms.split(' ', 1)
                if len(parts) != 2:
                    continue
                a_str, b_str = parts[0], parts[1]
            if a_str not in str_to_our or b_str not in str_to_our:
                continue
            new_id = 256 + len(self.merges)
            self.merges.append((str_to_our[a_str], str_to_our[b_str], new_id))
            str_to_our[a_str + b_str] = new_id
            if (i + 1) % report_every == 0:
                print(f" [BPE] {i + 1}/{len(hf_merges)} merges converted")

        self.vocab_size = 256 + len(self.merges)

        # HF id -> our id for every token string we know about. Single-char
        # entries of str_to_our are always raw bytes, because merged strings
        # have length >= 2.
        hf_to_our = {}
        for tok_str, our_id in str_to_our.items():
            hf_id = hf_vocab.get(tok_str)
            if hf_id is not None:
                hf_to_our[hf_id] = our_id

        unmapped = set(hf_vocab.values()).difference(hf_to_our)
        if unmapped:
            print(f" [BPE] WARNING: {len(unmapped)} HF tokens have no mapping to our IDs; "
                  f"encode() passes their HF IDs through unchanged")

        hf_vocab_size = max(hf_vocab.values(), default=-1) + 1
        self._remap_lut = self._build_remap_lut(hf_to_our, hf_vocab_size)
        self._hf_to_our = hf_to_our

        print(f" [BPE] Done: {len(self.merges)} merges, vocab={self.vocab_size}")

    def encode(self, text):
        """Encode ``text`` (``str`` or ``bytes``) to a list of our token IDs.

        Fast path: when an HF tokenizer and remap table are available, encode
        with the Rust backend and remap the IDs with numpy. Bytes input is
        decoded as UTF-8 (errors replaced) first.

        Fallback: apply ``self.merges`` in order to the raw bytes. ``bytes``
        input is used as-is, so arbitrary (even non-UTF-8) bytes round-trip
        through :meth:`decode`. ``str`` input is UTF-8 encoded.

        NOTE(review): the HF ByteLevel pre-tokenizer presumably splits text
        into pieces before merging, while the fallback merges across the
        whole byte string, so the two paths may disagree on the same input.
        """
        if self._hf_tok is not None and self._remap_lut is not None:
            if isinstance(text, bytes):
                text = text.decode('utf-8', errors='replace')
            hf_ids = np.array(self._hf_tok.encode(text).ids, dtype=np.int32)
            return self._remap_lut[hf_ids].tolist()

        if isinstance(text, str):
            text = text.encode('utf-8', errors='replace')
        ids = list(text)
        for a, b, new_id in self.merges:
            ids = self._apply_merge(ids, a, b, new_id)
        return ids

    def decode(self, ids):
        """Decode our token IDs back to raw ``bytes``.

        Each ID expands to the concatenation of its merge parts. IDs that are
        neither a byte nor a known merge decode to ``b'?'``.
        """
        table = {i: bytes([i]) for i in range(256)}
        for a, b, new_id in self.merges:
            table[new_id] = table[a] + table[b]
        return b''.join(table.get(tid, b'?') for tid in ids)

    def save(self, path):
        """Save binary merges (for C) plus, when available, the HF JSON and ID map.

        Writes ``path`` (``<I`` merge count followed by ``<III`` triples). If
        an HF tokenizer is present it goes to ``<base>_hf.json``. If an ID map
        is present it goes to ``<base>_idmap.json``. ``<base>`` is ``path``
        without its extension.
        """
        with open(path, 'wb') as f:
            f.write(struct.pack('<I', len(self.merges)))
            f.write(b''.join(struct.pack('<III', a, b, new_id)
                             for a, b, new_id in self.merges))
        print(f" [BPE] Saved {len(self.merges)} merges to {path}")

        base = os.path.splitext(path)[0]
        if self._hf_tok is not None:
            jp = base + '_hf.json'
            self._hf_tok.save(jp)
            print(f" [BPE] Saved HF tokenizer to {jp}")

        if self._hf_to_our:
            mp = base + '_idmap.json'
            with open(mp, 'w') as f:
                json.dump({str(k): v for k, v in self._hf_to_our.items()}, f)

    def load(self, path):
        """Load merges from the binary file, plus the HF JSON + ID map if both exist.

        When either sidecar file is missing, any HF state from an earlier
        ``train()``/``load()`` is discarded, because it would no longer match
        the loaded merges. :meth:`encode` then uses the pure-Python fallback.

        Raises:
            struct.error: if the binary file is truncated. ``self.merges`` is
                left unchanged in that case.
        """
        with open(path, 'rb') as f:
            (n,) = struct.unpack('<I', f.read(4))
            merges = [struct.unpack('<III', f.read(12)) for _ in range(n)]
        self.merges = merges
        self.vocab_size = 256 + len(self.merges)
        print(f" [BPE] Loaded {len(self.merges)} merges from {path}, vocab={self.vocab_size}")

        base = os.path.splitext(path)[0]
        jp = base + '_hf.json'
        mp = base + '_idmap.json'
        if os.path.exists(jp) and os.path.exists(mp):
            from tokenizers import Tokenizer
            hf_tok = Tokenizer.from_file(jp)
            with open(mp) as f:
                raw = json.load(f)
            hf_to_our = {int(k): v for k, v in raw.items()}
            self._hf_tok = hf_tok
            self._hf_to_our = hf_to_our
            self._remap_lut = self._build_remap_lut(hf_to_our, hf_tok.get_vocab_size())
            print(" [BPE] Loaded HF tokenizer for fast encode")
        else:
            self._hf_tok = None
            self._remap_lut = None
            self._hf_to_our = None

    def save_copies(self, base_path, n=3):
        """Save the tokenizer ``n`` times: ``base_path`` plus ``<name>_backup<i><ext>``.

        Redundant copies are a lesson from the Janus 285M disaster. Returns
        the list of binary paths written.
        """
        paths = []
        for i in range(n):
            if i == 0:
                p = base_path
            else:
                name, ext = os.path.splitext(base_path)
                p = f"{name}_backup{i}{ext}"
            self.save(p)
            paths.append(p)
        print(f" [BPE] Saved {n} copies: {paths}")
        return paths
|
|