| """Compact Chess BPE Tokenizer: Train, Upload, Load, Inference""" |
| import os, json, rustbpe, tiktoken |
| from datasets import load_dataset |
| from huggingface_hub import HfApi, create_repo, upload_folder, hf_hub_download |
|
|
REPO_ID = "ItsMaxNorm/bpess"  # default HuggingFace Hub repo id used by upload/load below
|
|
def train(vocab_size=4096, split="train[0:10000]", dataset='angeluriot/chess_games'):
    """Train a BPE tokenizer on space-joined chess move strings.

    Args:
        vocab_size: target vocabulary size for the BPE merge table.
        split: HuggingFace datasets split expression selecting the games.
        dataset: Hub dataset id; defaults to the original hard-coded
            'angeluriot/chess_games' so existing callers are unaffected.

    Returns:
        A trained rustbpe.Tokenizer.
    """
    ds = load_dataset(dataset, split=split)
    # Lazily join each game's moves into one string; skip games with no moves.
    corpus = (' '.join(g['moves_custom']) for g in ds if g['moves_custom'])
    tok = rustbpe.Tokenizer()
    tok.train_from_iterator(corpus, vocab_size)
    return tok
|
|
def save(tok, path="./tokenizer"):
    """Persist tokenizer vocab and config as JSON files under *path*.

    Args:
        tok: trained tokenizer exposing ``get_mergeable_ranks()``,
            ``get_pattern()`` and ``vocab_size``.
        path: output directory; created if it does not exist.

    Returns:
        The output directory path (same as *path*).
    """
    os.makedirs(path, exist_ok=True)
    ranks = tok.get_mergeable_ranks()
    # Token byte-sequences are stored as UTF-8 text; undecodable bytes are
    # replaced (lossy but JSON-safe), matching the loader's expectations.
    vocab = {bytes(k).decode('utf-8', errors='replace'): v for k, v in ranks}
    # Context managers + explicit encoding: the previous bare open() calls
    # leaked file handles and depended on the platform default encoding,
    # which fails on the Unicode chess glyphs under e.g. Windows cp1252.
    with open(f"{path}/vocab.json", 'w', encoding='utf-8') as f:
        json.dump(vocab, f, indent=2)
    with open(f"{path}/config.json", 'w', encoding='utf-8') as f:
        json.dump({"pattern": tok.get_pattern(), "vocab_size": tok.vocab_size}, f)
    return path
|
|
def upload(tok, repo_id=REPO_ID, private=False):
    """Save the tokenizer locally, then push the folder to the HuggingFace Hub.

    Args:
        tok: trained tokenizer accepted by ``save``.
        repo_id: target Hub repository ("user/name").
        private: create the repo as private if it does not exist yet.
    """
    path = save(tok)
    # exist_ok=True replaces the old bare `except: pass`, which silently
    # swallowed *all* errors (auth failures, network errors, even
    # KeyboardInterrupt) — now only "already exists" is tolerated.
    create_repo(repo_id, private=private, exist_ok=True)
    HfApi().upload_folder(folder_path=path, repo_id=repo_id)
    print(f"Uploaded: https://huggingface.co/{repo_id}")
|
|
def load_tiktoken(repo_id=REPO_ID):
    """Download tokenizer files from the Hub and build a tiktoken Encoding.

    Args:
        repo_id: Hub repository to download ``config.json`` and
            ``vocab.json`` from.

    Returns:
        A ``tiktoken.Encoding`` named "chess" with no special tokens.
    """
    # Context managers + explicit encoding: the previous bare open() calls
    # leaked file handles and relied on the platform default encoding.
    with open(hf_hub_download(repo_id, "config.json"), encoding='utf-8') as f:
        config = json.load(f)
    with open(hf_hub_download(repo_id, "vocab.json"), encoding='utf-8') as f:
        vocab = json.load(f)
    return tiktoken.Encoding(
        name="chess",
        pat_str=config["pattern"],
        # Keys were stored as (possibly lossily decoded) UTF-8 text by save().
        mergeable_ranks={k.encode('utf-8', errors='replace'): v for k, v in vocab.items()},
        special_tokens={},
    )
|
|
if __name__ == "__main__":
    # Train a fresh tokenizer on a slice of the dataset and publish it.
    tok = train(vocab_size=4096, split="train[0:10000]")
    print(f"Trained: {tok.vocab_size} tokens")
    upload(tok, REPO_ID)

    # Smoke-test: round-trip a sample move string through the published encoding.
    enc = load_tiktoken(REPO_ID)
    sample = "w.♘g1♘f3.. b.♟c7♟c5.. w.♙d2♙d4.."
    ids = enc.encode(sample)
    print(f"Encoded: {ids[:10]}... ({len(ids)} tokens)")
    print(f"Decoded: {enc.decode(ids)}")
|
|