swdo commited on
Commit
54a50ed
·
verified ·
1 Parent(s): 0ac263c

Chess Challenge submission by swdo

Browse files
Files changed (2) hide show
  1. config.json +2 -2
  2. tokenizer.py +84 -205
config.json CHANGED
@@ -3,8 +3,8 @@
3
  "ChessTRMForCausalLM"
4
  ],
5
  "auto_map": {
6
- "AutoConfig": "model.ChessConfig",
7
- "AutoModelForCausalLM": "model.ChessForCausalLM"
8
  },
9
  "bos_token_id": 1,
10
  "dropout": 0.1,
 
3
  "ChessTRMForCausalLM"
4
  ],
5
  "auto_map": {
6
+ "AutoConfig": "model.ChessTRMConfig",
7
+ "AutoModelForCausalLM": "model.ChessTRMForCausalLM"
8
  },
9
  "bos_token_id": 1,
10
  "dropout": 0.1,
tokenizer.py CHANGED
@@ -1,91 +1,63 @@
1
  """
2
- Custom Chess Tokenizer for the Chess Challenge.
3
 
4
- This tokenizer treats each move as a single token using the extended UCI notation
5
- from the Lichess dataset (e.g., WPe2e4, BNg8f6).
 
 
 
6
 
7
- The dataset format uses:
8
- - W/B prefix for White/Black
9
- - Piece letter: P=Pawn, N=Knight, B=Bishop, R=Rook, Q=Queen, K=King
10
- - Source and destination squares (e.g., e2e4)
11
- - Special suffixes: (x)=capture, (+)=check, (+*)=checkmate, (o)/(O)=castling
12
  """
13
 
14
  from __future__ import annotations
15
 
16
  import json
17
  import os
18
- from pathlib import Path
19
  from typing import Dict, List, Optional
20
 
21
  from transformers import PreTrainedTokenizer
22
 
23
 
24
- class ChessTokenizer(PreTrainedTokenizer):
25
- """
26
- A custom tokenizer for chess moves using extended UCI notation.
27
-
28
- This tokenizer maps each possible chess move to a unique token ID.
29
- The vocabulary is built from the training dataset to ensure all moves
30
- encountered during training have a corresponding token.
31
-
32
- Example:
33
- >>> tokenizer = ChessTokenizer()
34
- >>> tokenizer.encode("WPe2e4 BPe7e5")
35
- [1, 42, 87, 2] # [BOS, e2e4, e7e5, EOS]
36
- """
37
-
38
  model_input_names = ["input_ids", "attention_mask"]
39
  vocab_files_names = {"vocab_file": "vocab.json"}
40
-
41
- # Special tokens
42
  PAD_TOKEN = "[PAD]"
43
  BOS_TOKEN = "[BOS]"
44
  EOS_TOKEN = "[EOS]"
45
  UNK_TOKEN = "[UNK]"
46
-
 
 
47
  def __init__(
48
  self,
49
  vocab_file: Optional[str] = None,
50
  vocab: Optional[Dict[str, int]] = None,
51
  **kwargs,
52
  ):
53
- """
54
- Initialize the chess tokenizer.
55
-
56
- Args:
57
- vocab_file: Path to a JSON file containing the vocabulary mapping.
58
- vocab: Dictionary mapping tokens to IDs (alternative to vocab_file).
59
- **kwargs: Additional arguments passed to PreTrainedTokenizer.
60
- """
61
- # Initialize special tokens
62
  self._pad_token = self.PAD_TOKEN
63
  self._bos_token = self.BOS_TOKEN
64
  self._eos_token = self.EOS_TOKEN
65
  self._unk_token = self.UNK_TOKEN
66
 
67
- # Remove any duplicate special-token entries passed through kwargs
68
- # to avoid "multiple values for keyword" errors when loading from disk.
69
  kwargs.pop("pad_token", None)
70
  kwargs.pop("bos_token", None)
71
  kwargs.pop("eos_token", None)
72
  kwargs.pop("unk_token", None)
73
-
74
- # Load or create vocabulary
75
  if vocab is not None:
76
  self._vocab = vocab
77
  elif vocab_file is not None and os.path.exists(vocab_file):
78
  with open(vocab_file, "r", encoding="utf-8") as f:
79
  self._vocab = json.load(f)
80
  else:
81
- # Create a minimal vocabulary with just special tokens
82
- # The full vocabulary should be built from the dataset
83
- self._vocab = self._create_default_vocab()
84
-
85
- # Create reverse mapping
86
  self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
87
-
88
- # Call parent init AFTER setting up vocab
89
  super().__init__(
90
  pad_token=self._pad_token,
91
  bos_token=self._bos_token,
@@ -93,186 +65,93 @@ class ChessTokenizer(PreTrainedTokenizer):
93
  unk_token=self._unk_token,
94
  **kwargs,
95
  )
96
-
97
- def _create_default_vocab(self) -> Dict[str, int]:
98
- """
99
- Create a minimal default vocabulary with just special tokens.
100
-
101
- For the full vocabulary, use `build_vocab_from_dataset()`.
102
- This minimal vocab is just a placeholder - you should build from data.
103
- """
104
- special_tokens = [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
105
- vocab = {token: idx for idx, token in enumerate(special_tokens)}
106
- return vocab
107
-
108
- @classmethod
109
- def build_vocab_from_iterator(
110
- cls,
111
- iterator,
112
- min_frequency: int = 1,
113
- ) -> "ChessTokenizer":
114
- """
115
- Build a tokenizer vocabulary from an iterator of game strings.
116
-
117
- Args:
118
- iterator: An iterator yielding game strings (space-separated moves).
119
- min_frequency: Minimum frequency for a token to be included.
120
-
121
- Returns:
122
- A ChessTokenizer with the built vocabulary.
123
- """
124
- from collections import Counter
125
-
126
- token_counts = Counter()
127
-
128
- for game in iterator:
129
- moves = game.strip().split()
130
- token_counts.update(moves)
131
-
132
- # Filter by frequency
133
- tokens = [
134
- token for token, count in token_counts.items()
135
- if count >= min_frequency
136
  ]
137
-
138
- # Sort for reproducibility
139
- tokens = sorted(tokens)
140
-
141
- # Build vocabulary
142
- special_tokens = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
143
- vocab = {token: idx for idx, token in enumerate(special_tokens + tokens)}
144
-
145
- return cls(vocab=vocab)
146
-
147
- @classmethod
148
- def build_vocab_from_dataset(
149
- cls,
150
- dataset_name: str = "dlouapre/lichess_2025-01_1M",
151
- split: str = "train",
152
- column: str = "text",
153
- min_frequency: int = 500,
154
- max_samples: Optional[int] = 100000,
155
- ) -> "ChessTokenizer":
156
- """
157
- Build a tokenizer vocabulary from a Hugging Face dataset.
158
-
159
- Args:
160
- dataset_name: Name of the dataset on Hugging Face Hub.
161
- split: Dataset split to use.
162
- column: Column containing the game strings.
163
- min_frequency: Minimum frequency for a token to be included (default: 500).
164
- max_samples: Maximum number of samples to process (default: 100k).
165
-
166
- Returns:
167
- A ChessTokenizer with the built vocabulary.
168
- """
169
- from datasets import load_dataset
170
-
171
- dataset = load_dataset(dataset_name, cache_dir=os.environ["HF_HOME"], split=split)
172
-
173
- if max_samples is not None:
174
- dataset = dataset.select(range(min(max_samples, len(dataset))))
175
-
176
- def game_iterator():
177
- for example in dataset:
178
- yield example[column]
179
-
180
- return cls.build_vocab_from_iterator(game_iterator(), min_frequency=min_frequency)
181
-
182
  @property
183
  def vocab_size(self) -> int:
184
- """Return the size of the vocabulary."""
185
  return len(self._vocab)
186
-
187
  def get_vocab(self) -> Dict[str, int]:
188
- """Return the vocabulary as a dictionary."""
189
  return dict(self._vocab)
190
-
191
  def _tokenize(self, text: str) -> List[str]:
192
- """
193
- Tokenize a string of moves into a list of tokens.
194
-
195
- Args:
196
- text: A string of space-separated moves.
197
-
198
- Returns:
199
- List of move tokens.
200
- """
201
- return text.strip().split()
202
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  def _convert_token_to_id(self, token: str) -> int:
204
- """Convert a token to its ID."""
205
  return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))
206
-
207
  def _convert_id_to_token(self, index: int) -> str:
208
- """Convert an ID to its token."""
209
  return self._ids_to_tokens.get(index, self.UNK_TOKEN)
210
-
211
  def convert_tokens_to_string(self, tokens: List[str]) -> str:
212
- """Convert a list of tokens back to a string."""
213
- # Filter out special tokens for cleaner output
214
  special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
215
  return " ".join(t for t in tokens if t not in special)
216
-
217
- def save_vocabulary(
218
- self,
219
- save_directory: str,
220
- filename_prefix: Optional[str] = None,
221
- ) -> tuple:
222
- """
223
- Save the vocabulary to a JSON file.
224
-
225
- Args:
226
- save_directory: Directory to save the vocabulary.
227
- filename_prefix: Optional prefix for the filename.
228
-
229
- Returns:
230
- Tuple containing the path to the saved vocabulary file.
231
- """
232
  if not os.path.isdir(save_directory):
233
  os.makedirs(save_directory, exist_ok=True)
234
-
235
  vocab_file = os.path.join(
236
  save_directory,
237
  (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
238
  )
239
-
240
  with open(vocab_file, "w", encoding="utf-8") as f:
241
  json.dump(self._vocab, f, ensure_ascii=False, indent=2)
242
-
243
- return (vocab_file,)
244
 
 
245
 
246
- def count_vocab_from_dataset(
247
- dataset_name: str = "dlouapre/lichess_2025-01_1M",
248
- split: str = "train",
249
- column: str = "text",
250
- max_samples: Optional[int] = 10000,
251
- ) -> Dict[str, int]:
252
- """
253
- Count token frequencies in a dataset (useful for vocabulary analysis).
254
-
255
- Args:
256
- dataset_name: Name of the dataset on Hugging Face Hub.
257
- split: Dataset split to use.
258
- column: Column containing the game strings.
259
- max_samples: Maximum number of samples to process.
260
-
261
- Returns:
262
- Dictionary mapping tokens to their frequencies.
263
- """
264
- from collections import Counter
265
- from datasets import load_dataset
266
-
267
- dataset = load_dataset(dataset_name, cache_dir=os.environ["HF_HOME"], split=split)
268
-
269
- if max_samples is not None:
270
- dataset = dataset.select(range(min(max_samples, len(dataset))))
271
-
272
- token_counts = Counter()
273
-
274
- for example in dataset:
275
- moves = example[column].strip().split()
276
- token_counts.update(moves)
277
-
278
- return dict(token_counts)
 
1
  """
2
+ Decomposed Chess Tokenizer.
3
 
4
+ This tokenizer decomposes each move into 3-4 tokens:
5
+ - color+piece token (e.g., "WP", "BN")
6
+ - from-square token with suffix "_f" (e.g., "e2_f")
7
+ - to-square token with suffix "_t" (e.g., "e4_t")
8
+ - optional promotion token (one of "q", "r", "b", "n")
9
 
10
+ This avoids UNKs for rare moves and makes legality learning easier because the model
11
+ always emits explicit squares.
 
 
 
12
  """
13
 
14
  from __future__ import annotations
15
 
16
  import json
17
  import os
18
+ import re
19
  from typing import Dict, List, Optional
20
 
21
  from transformers import PreTrainedTokenizer
22
 
23
 
24
+ class ChessDecomposedTokenizer(PreTrainedTokenizer):
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  model_input_names = ["input_ids", "attention_mask"]
26
  vocab_files_names = {"vocab_file": "vocab.json"}
27
+
 
28
  PAD_TOKEN = "[PAD]"
29
  BOS_TOKEN = "[BOS]"
30
  EOS_TOKEN = "[EOS]"
31
  UNK_TOKEN = "[UNK]"
32
+
33
+ _MOVE_RE = re.compile(r"^[WB][PNBRQK][a-h][1-8][a-h][1-8].*$")
34
+
35
  def __init__(
36
  self,
37
  vocab_file: Optional[str] = None,
38
  vocab: Optional[Dict[str, int]] = None,
39
  **kwargs,
40
  ):
 
 
 
 
 
 
 
 
 
41
  self._pad_token = self.PAD_TOKEN
42
  self._bos_token = self.BOS_TOKEN
43
  self._eos_token = self.EOS_TOKEN
44
  self._unk_token = self.UNK_TOKEN
45
 
 
 
46
  kwargs.pop("pad_token", None)
47
  kwargs.pop("bos_token", None)
48
  kwargs.pop("eos_token", None)
49
  kwargs.pop("unk_token", None)
50
+
 
51
  if vocab is not None:
52
  self._vocab = vocab
53
  elif vocab_file is not None and os.path.exists(vocab_file):
54
  with open(vocab_file, "r", encoding="utf-8") as f:
55
  self._vocab = json.load(f)
56
  else:
57
+ self._vocab = self._create_full_vocab()
58
+
 
 
 
59
  self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
60
+
 
61
  super().__init__(
62
  pad_token=self._pad_token,
63
  bos_token=self._bos_token,
 
65
  unk_token=self._unk_token,
66
  **kwargs,
67
  )
68
+
69
+ @staticmethod
70
+ def _create_full_vocab() -> Dict[str, int]:
71
+ special_tokens = [
72
+ ChessDecomposedTokenizer.PAD_TOKEN,
73
+ ChessDecomposedTokenizer.BOS_TOKEN,
74
+ ChessDecomposedTokenizer.EOS_TOKEN,
75
+ ChessDecomposedTokenizer.UNK_TOKEN,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  ]
77
+
78
+ pieces = ["P", "N", "B", "R", "Q", "K"]
79
+ colors = ["W", "B"]
80
+ piece_tokens = [f"{c}{p}" for c in colors for p in pieces]
81
+
82
+ files = "abcdefgh"
83
+ ranks = "12345678"
84
+ squares = [f"{f}{r}" for f in files for r in ranks]
85
+ from_tokens = [f"{sq}_f" for sq in squares]
86
+ to_tokens = [f"{sq}_t" for sq in squares]
87
+
88
+ promo_tokens = ["q", "r", "b", "n"]
89
+
90
+ tokens = special_tokens + piece_tokens + from_tokens + to_tokens + promo_tokens
91
+ return {tok: idx for idx, tok in enumerate(tokens)}
92
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  @property
94
  def vocab_size(self) -> int:
 
95
  return len(self._vocab)
96
+
97
  def get_vocab(self) -> Dict[str, int]:
 
98
  return dict(self._vocab)
99
+
100
  def _tokenize(self, text: str) -> List[str]:
101
+ raw = text.strip()
102
+ if not raw:
103
+ return []
104
+
105
+ parts = raw.split()
106
+ out: List[str] = []
107
+
108
+ for part in parts:
109
+ if part in {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}:
110
+ out.append(part)
111
+ continue
112
+
113
+ if not self._MOVE_RE.match(part):
114
+ out.append(self.UNK_TOKEN)
115
+ continue
116
+
117
+ color = part[0]
118
+ piece = part[1]
119
+ from_sq = part[2:4]
120
+ to_sq = part[4:6]
121
+ out.append(f"{color}{piece}")
122
+ out.append(f"{from_sq}_f")
123
+ out.append(f"{to_sq}_t")
124
+
125
+ if "=" in part:
126
+ promo_idx = part.find("=")
127
+ if promo_idx != -1 and promo_idx + 1 < len(part):
128
+ promo = part[promo_idx + 1].lower()
129
+ if promo in {"q", "r", "b", "n"}:
130
+ out.append(promo)
131
+
132
+ return out
133
+
134
  def _convert_token_to_id(self, token: str) -> int:
 
135
  return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN, 0))
136
+
137
  def _convert_id_to_token(self, index: int) -> str:
 
138
  return self._ids_to_tokens.get(index, self.UNK_TOKEN)
139
+
140
  def convert_tokens_to_string(self, tokens: List[str]) -> str:
 
 
141
  special = {self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN}
142
  return " ".join(t for t in tokens if t not in special)
143
+
144
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  if not os.path.isdir(save_directory):
146
  os.makedirs(save_directory, exist_ok=True)
147
+
148
  vocab_file = os.path.join(
149
  save_directory,
150
  (filename_prefix + "-" if filename_prefix else "") + "vocab.json",
151
  )
152
+
153
  with open(vocab_file, "w", encoding="utf-8") as f:
154
  json.dump(self._vocab, f, ensure_ascii=False, indent=2)
 
 
155
 
156
+ return (vocab_file,)
157