"""
BPE tokenizer for resonance-200m.
Uses HuggingFace tokenizers (Rust backend) for fast training + encoding.
Saves merge rules in binary format compatible with C inference.
Replaces the naive pure-Python BPE (a full corpus rescan per merge,
roughly O(n * num_merges) overall = days on 200MB). Rust backend: minutes.
"""
import struct
import os
import json
import numpy as np
def _byte_to_unicode():
"""GPT-2 byte-to-unicode mapping (ByteLevel pre-tokenizer)."""
bs = (list(range(ord("!"), ord("~") + 1)) +
list(range(ord("¡"), ord("¬") + 1)) +
list(range(ord("®"), ord("ÿ") + 1)))
cs = bs[:]
n = 0
for b in range(256):
if b not in bs:
bs.append(b)
cs.append(256 + n)
n += 1
return {b: chr(c) for b, c in zip(bs, cs)}
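# Illustrative sanity check (not exercised at import time): the space byte
# 0x20 is excluded from the printable ranges above, so it is assigned
# 256 + 32 == 0x120, i.e. 'Ġ' — the familiar GPT-2 ByteLevel space marker.
#   >>> _byte_to_unicode()[ord(' ')]
#   'Ġ'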
class BPETokenizer:
"""BPE tokenizer. 256 byte tokens + learned merges.
Rust backend for speed. Binary format for C inference."""
def __init__(self, max_merges=15936):
self.max_merges = max_merges
self.merges = [] # (a, b, new_id) — C format
self.vocab_size = 256
        self._hf_tok = None
        self._remap_lut = None  # numpy LUT: HF_id → our_id
        self._hf_to_our = None  # dict HF_id → our_id; set by train()/load()
def train(self, text_bytes, num_merges=None, report_every=2000):
"""Learn BPE merges using Rust backend. Minutes, not days."""
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, decoders
if num_merges is None:
num_merges = self.max_merges
num_merges = min(num_merges, self.max_merges)
target_vocab = 256 + num_merges
print(f" [BPE] Training {num_merges} merges on {len(text_bytes)} bytes (Rust backend)...")
tok = Tokenizer(models.BPE())
tok.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tok.decoder = decoders.ByteLevel()
trainer = trainers.BpeTrainer(
vocab_size=target_vocab,
min_frequency=2,
special_tokens=[],
initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
show_progress=True,
)
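        # Seeding initial_alphabet with the full ByteLevel alphabet pins all
        # 256 base tokens, so every learned token beyond them is one merge:
        # vocab_size == 256 + num_merges.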
        # Accept str or bytes; the trainer is fed line by line.
        if isinstance(text_bytes, bytes):
            text_bytes = text_bytes.decode('utf-8', errors='replace')
        lines = text_bytes.split('\n')
        del text_bytes
tok.train_from_iterator(lines, trainer=trainer)
del lines
self._hf_tok = tok
# Extract merges in our (a, b, new_id) format for C inference
data = json.loads(tok.to_str())
hf_merges = data['model']['merges']
hf_vocab = data['model']['vocab']
b2u = _byte_to_unicode()
# str → our_id mapping for merge conversion
str_to_our = {}
for bv in range(256):
str_to_our[b2u[bv]] = bv
self.merges = []
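        # Each HF merge pair is translated into a (left_id, right_id, new_id)
        # triple over our id space. For example, if the first learned merge is
        # ('Ġ', 't'), it becomes (32, 116, 256): 'Ġ' is byte 0x20, 't' is byte
        # 0x74, and new ids are assigned sequentially from 256.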
for i, ms in enumerate(hf_merges):
if i >= num_merges:
break
# HF tokenizers >=0.20 returns lists ['a','b'], older returns "a b"
if isinstance(ms, list):
if len(ms) != 2:
continue
a_str, b_str = ms[0], ms[1]
else:
parts = ms.split(' ', 1)
if len(parts) != 2:
continue
a_str, b_str = parts[0], parts[1]
if a_str not in str_to_our or b_str not in str_to_our:
continue
a_id = str_to_our[a_str]
b_id = str_to_our[b_str]
new_id = 256 + len(self.merges)
self.merges.append((a_id, b_id, new_id))
str_to_our[a_str + b_str] = new_id
if (i + 1) % report_every == 0:
print(f" [BPE] {i + 1}/{len(hf_merges)} merges converted")
self.vocab_size = 256 + len(self.merges)
# Build HF→our remap LUT (numpy vectorized lookup)
hf_to_our = {}
for bv in range(256):
uc = b2u[bv]
if uc in hf_vocab:
hf_to_our[hf_vocab[uc]] = bv
for tok_str, our_id in str_to_our.items():
if tok_str in hf_vocab and our_id >= 256:
hf_to_our[hf_vocab[tok_str]] = our_id
        # The LUT must span the full HF id range (skipped merges still occupy
        # HF ids); unmapped ids fall through as identity.
        max_hf = max(hf_vocab.values()) + 1
        self._remap_lut = np.arange(max_hf, dtype=np.int32)
        for hf_id, our_id in hf_to_our.items():
            self._remap_lut[hf_id] = our_id
self._hf_to_our = hf_to_our
print(f" [BPE] Done: {len(self.merges)} merges, vocab={self.vocab_size}")
def encode(self, text):
"""Encode text to our token IDs. Fast (Rust + numpy remap)."""
if isinstance(text, bytes):
text = text.decode('utf-8', errors='replace')
if self._hf_tok is not None and self._remap_lut is not None:
hf_ids = np.array(self._hf_tok.encode(text).ids, dtype=np.int32)
return self._remap_lut[hf_ids].tolist()
# Slow fallback (binary-only load, no HF JSON)
if isinstance(text, str):
text = text.encode('utf-8', errors='replace')
ids = list(text)
for a, b, new_id in self.merges:
new_ids = []
i = 0
while i < len(ids):
if i < len(ids) - 1 and ids[i] == a and ids[i + 1] == b:
new_ids.append(new_id)
i += 2
else:
new_ids.append(ids[i])
i += 1
ids = new_ids
return ids
    def decode(self, ids):
        """Decode token IDs to bytes."""
        # Rebuild the byte-string vocab from the merge list on each call.
        vocab = {i: bytes([i]) for i in range(256)}
        for a, b, new_id in self.merges:
            vocab[new_id] = vocab[a] + vocab[b]
        return b''.join(vocab.get(tid, b'?') for tid in ids)
def save(self, path):
"""Save binary merges (C) + HF JSON + ID map."""
with open(path, 'wb') as f:
f.write(struct.pack('<I', len(self.merges)))
for a, b, new_id in self.merges:
f.write(struct.pack('<III', a, b, new_id))
print(f" [BPE] Saved {len(self.merges)} merges to {path}")
base = os.path.splitext(path)[0]
if self._hf_tok:
jp = base + '_hf.json'
self._hf_tok.save(jp)
print(f" [BPE] Saved HF tokenizer to {jp}")
if self._hf_to_our:
mp = base + '_idmap.json'
with open(mp, 'w') as f:
json.dump({str(k): v for k, v in self._hf_to_our.items()}, f)
def load(self, path):
"""Load tokenizer from binary + optional HF JSON for fast encode."""
with open(path, 'rb') as f:
n = struct.unpack('<I', f.read(4))[0]
self.merges = []
for _ in range(n):
a, b, new_id = struct.unpack('<III', f.read(12))
self.merges.append((a, b, new_id))
self.vocab_size = 256 + len(self.merges)
print(f" [BPE] Loaded {len(self.merges)} merges from {path}, vocab={self.vocab_size}")
base = os.path.splitext(path)[0]
jp = base + '_hf.json'
mp = base + '_idmap.json'
if os.path.exists(jp) and os.path.exists(mp):
from tokenizers import Tokenizer
self._hf_tok = Tokenizer.from_file(jp)
with open(mp) as f:
raw = json.load(f)
hf_to_our = {int(k): v for k, v in raw.items()}
            # Size the LUT to the full HF vocab, not just the mapped ids.
            max_hf = max(max(hf_to_our.keys()) + 1, self._hf_tok.get_vocab_size())
            self._remap_lut = np.arange(max_hf, dtype=np.int32)
            for hf_id, our_id in hf_to_our.items():
                self._remap_lut[hf_id] = our_id
            self._hf_to_our = hf_to_our
            print(" [BPE] Loaded HF tokenizer for fast encode")
def save_copies(self, base_path, n=3):
"""Save tokenizer in N copies. Lesson from Janus 285M disaster."""
paths = []
for i in range(n):
if i == 0:
p = base_path
else:
name, ext = os.path.splitext(base_path)
p = f"{name}_backup{i}{ext}"
self.save(p)
paths.append(p)
print(f" [BPE] Saved {n} copies: {paths}")
return paths
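

if __name__ == '__main__':
    # Minimal smoke test, illustrative only: trains on a tiny synthetic
    # corpus (the real corpus, paths, and merge count are up to the caller)
    # and checks the encode/decode round trip plus binary save/load.
    import tempfile

    sample = b"the quick brown fox jumps over the lazy dog\n" * 200
    tok = BPETokenizer(max_merges=64)
    tok.train(sample, num_merges=64)

    ids = tok.encode("the quick brown fox")
    assert tok.decode(ids) == b"the quick brown fox"

    with tempfile.TemporaryDirectory() as d:
        path = os.path.join(d, 'bpe_demo.bin')
        tok.save(path)
        fresh = BPETokenizer()
        fresh.load(path)  # picks up the HF JSON saved next to the binary
        assert fresh.decode(fresh.encode("the lazy dog")) == b"the lazy dog"
    print(" [BPE] smoke test passed")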