| """ |
| Decomposed Chess Tokenizer (v2) for the Chess Challenge. |
| |
| This tokenizer factorizes each move into a small set of reusable tokens: |
| - One token for (color + piece): e.g. "WP", "BN" |
| - One token for the from-square with role suffix: e.g. "e2_f" |
| - One token for the to-square with role suffix: e.g. "e4_t" |
| - Optional promotion token: "q", "r", "b", "n" |
| |
| It is compatible with the teacher evaluator's supported formats: |
| - Standard: "WPe2e4", "BNg8f6", with optional annotations "(x)", "(+)", "(o)/(O)", "(Q)" |
| - Decomposed: "WP e2_f e4_t" |
| - UCI: "e2e4", "e7e8q" |
| - UCI spaced: "e2 e4" |
| |
| The tokenizer parses those inputs and emits the decomposed tokens above. |
| """ |
|
|
from __future__ import annotations

import json
import os
import re
from pathlib import Path
from typing import Dict, List, Optional

from transformers import PreTrainedTokenizer
|
|
|
|
class ChessTokenizer(PreTrainedTokenizer):
    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    _COLOR_PIECE_RE = re.compile(r"^[WB][PNBRQK]$")
    _SQUARE_RE = re.compile(r"[a-h][1-8]")
    _SQUARE_ROLE_RE = re.compile(r"^([a-h][1-8])_([ft])$", re.IGNORECASE)
    _PLAIN_SQUARE_RE = re.compile(r"^[a-h][1-8]$", re.IGNORECASE)
|
|
    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN

        # Ignore any incoming special-token overrides; this tokenizer always
        # uses its own fixed special tokens.
        kwargs.pop("pad_token", None)
        kwargs.pop("bos_token", None)
        kwargs.pop("eos_token", None)
        kwargs.pop("unk_token", None)

        if vocab is not None:
            self._vocab = vocab
        elif vocab_file is not None and os.path.exists(vocab_file):
            with open(vocab_file, "r", encoding="utf-8") as f:
                self._vocab = json.load(f)
        else:
            self._vocab = self._create_default_vocab()

        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}

        # The vocab must be in place before the parent constructor runs, since
        # it resolves the special tokens against it.
        super().__init__(
            pad_token=self._pad_token,
            bos_token=self._bos_token,
            eos_token=self._eos_token,
            unk_token=self._unk_token,
            **kwargs,
        )
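
    # Construction sketch (hedged; the file path here is hypothetical):
    #   ChessTokenizer()                         -> built-in fixed vocab
    #   ChessTokenizer(vocab_file="vocab.json")  -> reload a previously saved vocab
    #   ChessTokenizer(vocab={...})              -> an explicit mapping wins over both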
|
|
    @classmethod
    def build_vocab_from_dataset(
        cls,
        *_,
        **__,
    ) -> "ChessTokenizer":
        """
        Kept for API compatibility with `train.py`.

        The v2 tokenizer uses a fixed vocabulary (colors/pieces/squares/promotions),
        so dataset statistics are not required.
        """
        return cls()
|
|
    def _create_default_vocab(self) -> Dict[str, int]:
        special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]

        color_pieces = [
            f"{color}{piece}"
            for color in ("W", "B")
            for piece in ("P", "N", "B", "R", "Q", "K")
        ]

        squares = [f"{file}{rank}" for rank in range(1, 9) for file in "abcdefgh"]
        square_from = [f"{sq}_f" for sq in squares]
        square_to = [f"{sq}_t" for sq in squares]

        promotions = ["q", "r", "b", "n"]

        # Fixed, deterministic ordering: specials, color+piece, from-squares,
        # to-squares, promotions.
        all_tokens = special_tokens + color_pieces + square_from + square_to + promotions
        return {tok: idx for idx, tok in enumerate(all_tokens)}
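
    # Resulting layout (derived from the lists above): 4 specials + 12 color/piece
    # + 64 from-squares + 64 to-squares + 4 promotions = 148 tokens in total.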
|
|
    @property
    def vocab_size(self) -> int:
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        return dict(self._vocab)
|
|
    def _tokenize(self, text: str) -> List[str]:
        parts = text.strip().split()
        if not parts:
            return []

        out: List[str] = []
        next_role = "f"

        for part in parts:
            if part in {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}:
                out.append(part)
                next_role = "f"
                continue

            # Color+piece token, e.g. "WP" or "bn".
            if self._COLOR_PIECE_RE.match(part.upper()):
                out.append(part.upper())
                next_role = "f"
                continue

            # Already-decomposed square with an explicit role suffix, e.g. "e2_f".
            m_role = self._SQUARE_ROLE_RE.match(part)
            if m_role:
                sq = m_role.group(1).lower()
                role = m_role.group(2).lower()
                out.append(f"{sq}_{role}")
                next_role = "t" if role == "f" else "f"
                continue

            # Bare square (UCI spaced format); alternate the from/to roles.
            if self._PLAIN_SQUARE_RE.match(part):
                sq = part.lower()
                out.append(f"{sq}_{next_role}")
                next_role = "t" if next_role == "f" else "f"
                continue

            # Standalone promotion piece, e.g. "q", "=Q", or "(q)".
            promo = self._extract_promotion(part)
            if promo and self._looks_like_promo_only(part):
                out.append(promo)
                continue

            # Full move chunk such as "WPe2e4(x)" or "e7e8q".
            move_tokens = self._tokenize_move_chunk(part)
            if move_tokens:
                out.extend(move_tokens)
                next_role = "f"
                continue

            # Annotation-only chunks like "(x)", "(+)", "(o)" carry no move
            # information and are skipped.
            if re.fullmatch(r"[\(\)\+\*xoO=]+", part):
                continue

            out.append(self.UNK_TOKEN)

        return out
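
    # Tokenization sketch for the supported formats (not an executed doctest):
    #   _tokenize("WPe2e4")        -> ["WP", "e2_f", "e4_t"]
    #   _tokenize("WP e2_f e4_t")  -> ["WP", "e2_f", "e4_t"]
    #   _tokenize("e7e8q")         -> ["e7_f", "e8_t", "q"]
    #   _tokenize("e2 e4")         -> ["e2_f", "e4_t"]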
|
|
    def _looks_like_promo_only(self, part: str) -> bool:
        part_stripped = part.strip()
        if re.fullmatch(r"[qrbnQRBN]", part_stripped):
            return True
        if re.fullmatch(r"=[qrbnQRBN]", part_stripped):
            return True
        if re.fullmatch(r"\([qrbnQRBN]\)", part_stripped):
            return True
        return False
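
    # Sketch: accepts standalone promotion spellings such as "q", "=Q", "(n)";
    # rejects full chunks such as "e7e8q", which are handled by _tokenize_move_chunk.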
|
|
    def _extract_promotion(self, text: str) -> Optional[str]:
        text_lower = text.lower()
        m = re.search(r"\(([qrbn])\)", text_lower)
        if m:
            return m.group(1)
        m = re.search(r"=([qrbn])", text_lower)
        if m:
            return m.group(1)
        # A standalone promotion letter (e.g. a trailing "q" in spaced input) is
        # also accepted; full move chunks always contain two squares, so this
        # branch cannot misfire for them.
        if re.fullmatch(r"[qrbn]", text_lower.strip()):
            return text_lower.strip()
        return None
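
    # Sketch: _extract_promotion("(Q)") -> "q", _extract_promotion("e7e8=q") -> "q",
    # _extract_promotion("e2e4") -> None.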
|
|
    def _tokenize_move_chunk(self, chunk: str) -> List[str]:
        chunk_stripped = chunk.strip()
        if not chunk_stripped:
            return []

        chunk_lower = chunk_stripped.lower()
        squares = self._SQUARE_RE.findall(chunk_lower)
        if len(squares) < 2:
            return []

        from_sq, to_sq = squares[0], squares[1]

        color_piece = None
        if len(chunk_stripped) >= 2 and self._COLOR_PIECE_RE.match(chunk_stripped[:2].upper()):
            color_piece = chunk_stripped[:2].upper()

        tokens: List[str] = []
        if color_piece:
            tokens.append(color_piece)

        tokens.append(f"{from_sq}_f")
        tokens.append(f"{to_sq}_t")

        # Look for a promotion letter directly after the to-square, e.g. the
        # trailing "q" in "e7e8q" or the "(Q)" annotation in "WPe7e8(Q)".
        after_to = chunk_lower.find(to_sq)
        if after_to != -1:
            remaining = chunk_lower[after_to + 2 : after_to + 6]
            m = re.search(r"=?([qrbn])", remaining)
            if m:
                tokens.append(m.group(1))

        # Fall back to the generic promotion extraction, avoiding duplicates.
        promo = self._extract_promotion(chunk_stripped)
        if promo and promo not in tokens:
            tokens.append(promo)

        return tokens
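
    # Sketch:
    #   _tokenize_move_chunk("BNg8f6") -> ["BN", "g8_f", "f6_t"]
    #   _tokenize_move_chunk("e7e8q")  -> ["e7_f", "e8_t", "q"]
    #   _tokenize_move_chunk("e4")     -> []  (both squares are required)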
|
|
    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))

    def _convert_id_to_token(self, index: int) -> str:
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
        return " ".join(t for t in tokens if t not in special)
|
|
    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory, exist_ok=True)

        vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
        )

        with open(vocab_file, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)

        return (vocab_file,)
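

# Minimal round-trip sketch (assumes a local `transformers` install; the exact
# ids depend on the fixed vocabulary defined above).
if __name__ == "__main__":
    tok = ChessTokenizer()
    ids = tok("WPe2e4 BNg8f6", add_special_tokens=False)["input_ids"]
    print(tok.convert_ids_to_tokens(ids))
    # Expected (sketch): ['WP', 'e2_f', 'e4_t', 'BN', 'g8_f', 'f6_t']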