lcccluck committed (verified)
Commit 63cdefe · Parent(s): 00ab189

Upload Gomoku training and MCTS code

Files changed (4)
  1. README.md +178 -0
  2. gomoku_mcts.py +975 -0
  3. gomoku_pg.py +1105 -0
  4. train_mcts_15x15_5.sh +73 -0
README.md ADDED
@@ -0,0 +1,178 @@
# Minimal Gomoku Policy Gradient

This is a minimal, learning-oriented Gomoku policy-gradient example. Key features:

- A single file: `gomoku_pg.py`
- Configurable board size, e.g. `5x5` or `15x15`
- Configurable win length, e.g. 4-in-a-row or 5-in-a-row
- Uses `torch` with a stripped-down `actor-critic` policy gradient
- The same policy plays both first and second player, trained via self-play

## Core idea

The state is encoded as 3 planes:

1. The current player's own stones
2. The opponent's stones
3. Legal move positions

The policy network is a small fully convolutional net that outputs a logit for each cell. Illegal positions are masked out, and a move is sampled from the legal positions.

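The encoding and masking above can be sketched in a few lines. This is an illustrative NumPy sketch: the real `gomoku_pg.py` works with `torch` tensors, and the helper names here are hypothetical.

```python
import numpy as np

def encode_state(board: np.ndarray, current_player: int) -> np.ndarray:
    """Stack the 3 input planes. `board` holds +1 / -1 / 0 per cell."""
    current = (board == current_player).astype(np.float32)    # plane 1: own stones
    opponent = (board == -current_player).astype(np.float32)  # plane 2: opponent stones
    legal = (board == 0).astype(np.float32)                   # plane 3: legal moves
    return np.stack([current, opponent, legal], axis=0)

def mask_illegal(logits: np.ndarray, board: np.ndarray) -> np.ndarray:
    """Push occupied cells to a huge negative logit so softmax ignores them."""
    legal = (board == 0).reshape(-1)
    return np.where(legal, logits, -1e9)
```

Sampling then happens over `softmax(mask_illegal(logits, board))`, so illegal moves end up with probability effectively zero.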
During training:

1. Play one complete game of self-play with the current policy
2. Store `log_prob(action)` for every move
3. After the game ends, assign each move a return:
   `+1` if the player who made the move ultimately won,
   `-1` if they lost,
   `0` for a draw
4. The policy head runs policy gradient on the advantage, while the value head predicts the return to reduce variance
5. Boards are randomly rotated/flipped during training to improve sample efficiency

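The terminal-return assignment in step 3 can be sketched as follows (a minimal illustration of the rule, not the exact code in `gomoku_pg.py`):

```python
def assign_returns(move_players: list[int], winner: int) -> list[float]:
    """winner is +1, -1, or 0 (draw); move_players[i] is who played move i."""
    if winner == 0:
        return [0.0] * len(move_players)
    return [1.0 if player == winner else -1.0 for player in move_players]

# A game won by the first player (+1): their moves get +1, the opponent's get -1.
print(assign_returns([1, -1, 1, -1, 1], 1))  # [1.0, -1.0, 1.0, -1.0, 1.0]
```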
## Validate on a small board first

Recommended first run:

```bash
~/miniconda3/bin/conda run -n lerobot python gomoku_pg.py train \
  --board-size 5 \
  --win-length 4 \
  --episodes 5000 \
  --batch-size 32 \
  --eval-every 300 \
  --eval-games 40 \
  --checkpoint gomoku_5x5_4.pt
```

Evaluate:

```bash
~/miniconda3/bin/conda run -n lerobot python gomoku_pg.py eval \
  --board-size 5 \
  --win-length 4 \
  --checkpoint gomoku_5x5_4.pt \
  --agent mcts \
  --mcts-sims 120 \
  --games 100
```

Verify by playing against it in the GUI:

```bash
~/miniconda3/bin/conda run -n lerobot python gomoku_pg.py gui \
  --checkpoint gomoku_5x5_4.pt \
  --agent mcts \
  --mcts-sims 120 \
  --human-first
```

Controls:

- Left click to place a stone
- `R` to restart
- `Esc` to quit

If `pygame` is not installed yet:

```bash
~/miniconda3/bin/conda run -n lerobot python -m pip install pygame
```

Human vs. AI in the terminal:

```bash
~/miniconda3/bin/conda run -n lerobot python gomoku_pg.py play \
  --board-size 5 \
  --win-length 4 \
  --checkpoint gomoku_5x5_4.pt \
  --agent mcts \
  --mcts-sims 120 \
  --human-first
```

## Switching to standard Gomoku

```bash
~/miniconda3/bin/conda run -n lerobot python gomoku_pg.py train \
  --board-size 15 \
  --win-length 5 \
  --episodes 20000 \
  --batch-size 32 \
  --eval-every 1000 \
  --eval-games 40 \
  --checkpoint gomoku_15x15_5.pt
```

Note: the code can switch board sizes directly, but the model must be retrained. Do not expect a policy learned on `5x5 + 4-in-a-row` to transfer directly to `15x15 + 5-in-a-row`.

## How to validate the algorithm

The most direct validation sequence:

1. Train on `5x5 + 4-in-a-row` first
2. Use `eval` to check that the win rate against a random policy is clearly above 50%
3. Use `gui` to play against it yourself; watch whether it prioritizes completing its own four-in-a-row and blocking yours
4. Then switch to `15x15 + 5-in-a-row` and retrain

If you only want to check the implementation for major bugs, the small board is the most effective test: training is fast, and policy errors show up more clearly.

## Why you can beat it easily

If you previously used the most basic terminal-reward `REINFORCE`, these problems are common:

- The terminal reward is too sparse; most early moves receive almost no useful learning signal
- Variance is high, so the trained policy is unstable
- The `15x15` action space is large, so self-play from scratch is very slow

This version has been switched to a more stable `actor-critic`. Even so, standard Gomoku trained from scratch cannot become strong in just a few hundred games.

## MCTS at inference time

`eval`, `play`, and `gui` now all support:

- `--agent policy`: let the policy network move directly
- `--agent mcts`: run an MCTS search guided by the policy and value networks before moving

For human-vs-AI testing, default to `mcts`; it is usually noticeably stronger than moving directly.

For example:

```bash
~/miniconda3/bin/conda run -n lerobot python gomoku_pg.py gui \
  --checkpoint gomoku_15x15_5.pt \
  --agent mcts \
  --mcts-sims 120 \
  --human-first
```

If it feels slow, reduce `--mcts-sims` to `32` or `64`.

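For reference, move selection inside such a search uses the standard PUCT rule, which matches `select_action` in the accompanying `gomoku_mcts.py`. The standalone function below is only a sketch, and its `c_puct` default is an arbitrary illustration:

```python
import math

def puct_score(q: float, prior: float, total_visits: int, action_visits: int,
               c_puct: float = 1.5) -> float:
    # Exploitation term (q) plus an exploration bonus that favors
    # high-prior, rarely visited actions and decays with visits.
    u = c_puct * prior * math.sqrt(total_visits + 1.0) / (1.0 + action_visits)
    return q + u
```

At each node the child with the highest score is descended into; `--mcts-sims` controls how many such descents are run per move.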
## A more realistic training schedule

Recommended approach:

1. Train `5x5 + 4-in-a-row` first
2. Warm-start a larger board from the small-board weights
3. Finally train `15x15 + 5-in-a-row`

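Warm-starting across board sizes works because the network is fully convolutional (the value head pools adaptively), so none of its weight shapes depend on the board size, and `--init-checkpoint` can load small-board weights directly. A rough sketch of the idea, using a toy net rather than the project's actual model:

```python
import torch
from torch import nn

def make_net() -> nn.Sequential:
    # No Linear layer sized to the board, so the weights are board-size agnostic.
    return nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(16, 1, kernel_size=1),
    )

small = make_net()                    # imagine this was trained on 5x5 boards
state = small.state_dict()

big = make_net()                      # new run targeting 15x15
big.load_state_dict(state)            # warm start: every weight shape matches

out = big(torch.zeros(1, 3, 15, 15))  # forward pass works at the new size
print(out.shape)                      # torch.Size([1, 1, 15, 15])
```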
For example:

```bash
~/miniconda3/bin/conda run -n lerobot python gomoku_pg.py train \
  --board-size 7 \
  --win-length 5 \
  --episodes 5000 \
  --init-checkpoint gomoku_5x5_4.pt \
  --checkpoint gomoku_7x7_5.pt
```

Then continue:

```bash
~/miniconda3/bin/conda run -n lerobot python gomoku_pg.py train \
  --board-size 15 \
  --win-length 5 \
  --episodes 20000 \
  --init-checkpoint gomoku_7x7_5.pt \
  --checkpoint gomoku_15x15_5.pt
```
gomoku_mcts.py ADDED
@@ -0,0 +1,975 @@
#!/usr/bin/env python3
"""Minimal Gomoku MCTS example.

This file is intentionally separate from gomoku_pg.py.
It uses the simpler AlphaZero-style recipe:
1. self-play with MCTS
2. policy/value targets from search
3. supervised update on policy + value heads
"""

from __future__ import annotations

import argparse
import math
import random
from collections import deque
from dataclasses import dataclass, field
from pathlib import Path

import numpy as np
import torch
from torch import nn


def choose_device(name: str) -> torch.device:
    if name != "auto":
        return torch.device(name)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def set_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


class GomokuEnv:
    def __init__(self, board_size: int, win_length: int):
        if board_size <= 1:
            raise ValueError("board_size must be > 1")
        if not 1 < win_length <= board_size:
            raise ValueError("win_length must satisfy 1 < win_length <= board_size")
        self.board_size = board_size
        self.win_length = win_length
        self.reset()

    def reset(self) -> np.ndarray:
        self.board = np.zeros((self.board_size, self.board_size), dtype=np.int8)
        self.current_player = 1
        self.done = False
        self.winner = 0
        return self.board

    def valid_moves(self) -> np.ndarray:
        return np.flatnonzero((self.board == 0).reshape(-1))

    def step(self, action: int) -> tuple[bool, int]:
        next_board, next_player, done, winner = apply_action_to_board(
            board=self.board,
            current_player=self.current_player,
            action=action,
            win_length=self.win_length,
        )
        self.board = next_board
        self.current_player = next_player
        self.done = done
        self.winner = winner
        return done, winner

    def render(self) -> str:
        symbols = {1: "X", -1: "O", 0: "."}
        header = " " + " ".join(f"{i + 1:2d}" for i in range(self.board_size))
        rows = [header]
        for row_idx in range(self.board_size):
            row = " ".join(f"{symbols[int(v)]:>2}" for v in self.board[row_idx])
            rows.append(f"{row_idx + 1:2d} {row}")
        return "\n".join(rows)


def action_to_coords(action: int, board_size: int) -> tuple[int, int]:
    return divmod(int(action), board_size)


def coords_to_action(row: int, col: int, board_size: int) -> int:
    return row * board_size + col


def count_one_side(
    board: np.ndarray,
    row: int,
    col: int,
    dr: int,
    dc: int,
    player: int,
) -> int:
    board_size = board.shape[0]
    total = 0
    r, c = row + dr, col + dc
    while 0 <= r < board_size and 0 <= c < board_size:
        if board[r, c] != player:
            break
        total += 1
        r += dr
        c += dc
    return total


def is_winning_move(
    board: np.ndarray,
    row: int,
    col: int,
    player: int,
    win_length: int,
) -> bool:
    directions = ((1, 0), (0, 1), (1, 1), (1, -1))
    for dr, dc in directions:
        count = 1
        count += count_one_side(board, row, col, dr, dc, player)
        count += count_one_side(board, row, col, -dr, -dc, player)
        if count >= win_length:
            return True
    return False


def apply_action_to_board(
    board: np.ndarray,
    current_player: int,
    action: int,
    win_length: int,
) -> tuple[np.ndarray, int, bool, int]:
    board_size = board.shape[0]
    row, col = action_to_coords(action, board_size)
    if board[row, col] != 0:
        raise ValueError(f"illegal move at ({row}, {col})")

    next_board = board.copy()
    next_board[row, col] = current_player

    if is_winning_move(next_board, row, col, current_player, win_length):
        return next_board, -current_player, True, current_player
    if not np.any(next_board == 0):
        return next_board, -current_player, True, 0
    return next_board, -current_player, False, 0


def encode_state(board: np.ndarray, current_player: int) -> torch.Tensor:
    current = (board == current_player).astype(np.float32)
    opponent = (board == -current_player).astype(np.float32)
    legal = (board == 0).astype(np.float32)
    return torch.from_numpy(np.stack([current, opponent, legal], axis=0))


class PolicyValueNet(nn.Module):
    def __init__(self, channels: int = 64):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Conv2d(3, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.policy_head = nn.Conv2d(channels, 1, kernel_size=1)
        self.value_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(channels, channels),
            nn.ReLU(),
            nn.Linear(channels, 1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        features = self.trunk(x)
        policy_logits = self.policy_head(features).flatten(start_dim=1)
        value = self.value_head(features).squeeze(-1)
        return policy_logits, value


def masked_logits(logits: torch.Tensor, board: np.ndarray) -> torch.Tensor:
    legal = torch.as_tensor((board == 0).reshape(-1), device=logits.device, dtype=torch.bool)
    return logits.masked_fill(~legal, -1e9)


def evaluate_policy_value(
    policy: PolicyValueNet,
    board: np.ndarray,
    current_player: int,
    device: torch.device,
) -> tuple[np.ndarray, float]:
    state = encode_state(board, current_player).unsqueeze(0).to(device)
    with torch.no_grad():
        logits, value = policy(state)
    logits = masked_logits(logits.squeeze(0), board)
    probs = torch.softmax(logits, dim=0).cpu().numpy()
    return probs, float(value.item())


@dataclass
class MCTSNode:
    board: np.ndarray
    current_player: int
    win_length: int
    done: bool = False
    winner: int = 0
    priors: dict[int, float] = field(default_factory=dict)
    visit_counts: dict[int, int] = field(default_factory=dict)
    value_sums: dict[int, float] = field(default_factory=dict)
    children: dict[int, "MCTSNode"] = field(default_factory=dict)
    expanded: bool = False

    def expand(
        self,
        priors: np.ndarray,
        add_noise: bool = False,
        dirichlet_alpha: float = 0.3,
        noise_eps: float = 0.25,
    ) -> None:
        legal_actions = np.flatnonzero((self.board == 0).reshape(-1))
        legal_priors = priors[legal_actions]
        total_prob = float(np.sum(legal_priors))
        if total_prob <= 0.0:
            legal_priors = np.full(len(legal_actions), 1.0 / max(len(legal_actions), 1), dtype=np.float32)
        else:
            legal_priors = legal_priors / total_prob

        if add_noise and len(legal_actions) > 0:
            noise = np.random.dirichlet([dirichlet_alpha] * len(legal_actions))
            legal_priors = (1.0 - noise_eps) * legal_priors + noise_eps * noise

        self.priors = {
            int(action): float(prior)
            for action, prior in zip(legal_actions, legal_priors, strict=False)
        }
        self.visit_counts = {action: 0 for action in self.priors}
        self.value_sums = {action: 0.0 for action in self.priors}
        self.expanded = True

    def q_value(self, action: int) -> float:
        visits = self.visit_counts[action]
        if visits == 0:
            return 0.0
        return self.value_sums[action] / visits

    def select_action(self, c_puct: float) -> int:
        total_visits = sum(self.visit_counts.values())
        sqrt_total = math.sqrt(total_visits + 1.0)
        best_action = -1
        best_score = -float("inf")
        for action, prior in self.priors.items():
            q = self.q_value(action)
            u = c_puct * prior * sqrt_total / (1.0 + self.visit_counts[action])
            score = q + u
            if score > best_score:
                best_score = score
                best_action = action
        return best_action

    def child_for_action(self, action: int) -> "MCTSNode":
        child = self.children.get(action)
        if child is not None:
            return child
        next_board, next_player, done, winner = apply_action_to_board(
            board=self.board,
            current_player=self.current_player,
            action=action,
            win_length=self.win_length,
        )
        child = MCTSNode(
            board=next_board,
            current_player=next_player,
            win_length=self.win_length,
            done=done,
            winner=winner,
        )
        self.children[action] = child
        return child


def terminal_value(winner: int, current_player: int) -> float:
    if winner == 0:
        return 0.0
    return 1.0 if winner == current_player else -1.0


def sample_from_visits(visits: np.ndarray, temperature: float) -> tuple[int, np.ndarray]:
    flat = visits.reshape(-1).astype(np.float64)
    if np.all(flat == 0):
        flat = np.ones_like(flat)

    if temperature <= 1e-6:
        probs = np.zeros_like(flat, dtype=np.float64)
        probs[int(np.argmax(flat))] = 1.0
    else:
        adjusted = np.power(flat, 1.0 / temperature)
        probs = adjusted / np.sum(adjusted)

    action = int(np.random.choice(len(probs), p=probs))
    return action, probs.reshape(visits.shape).astype(np.float32)


def choose_mcts_action(
    policy: PolicyValueNet,
    board: np.ndarray,
    current_player: int,
    win_length: int,
    device: torch.device,
    num_simulations: int,
    c_puct: float,
    temperature: float,
    add_root_noise: bool,
    dirichlet_alpha: float,
    noise_eps: float,
) -> tuple[int, np.ndarray]:
    root = MCTSNode(board=board.copy(), current_player=current_player, win_length=win_length)
    priors, _ = evaluate_policy_value(policy, root.board, root.current_player, device)
    root.expand(
        priors,
        add_noise=add_root_noise,
        dirichlet_alpha=dirichlet_alpha,
        noise_eps=noise_eps,
    )

    for _ in range(num_simulations):
        node = root
        path: list[tuple[MCTSNode, int]] = []

        while node.expanded and not node.done:
            action = node.select_action(c_puct)
            path.append((node, action))
            node = node.child_for_action(action)

        if node.done:
            value = terminal_value(node.winner, node.current_player)
        else:
            priors, value = evaluate_policy_value(policy, node.board, node.current_player, device)
            node.expand(priors)

        for parent, action in reversed(path):
            value = -value
            parent.visit_counts[action] += 1
            parent.value_sums[action] += value

    visits = np.zeros(board.shape, dtype=np.float32)
    for action, count in root.visit_counts.items():
        row, col = action_to_coords(action, board.shape[0])
        visits[row, col] = float(count)

    action, visit_probs = sample_from_visits(visits, temperature=temperature)
    return action, visit_probs


def choose_ai_action(
    policy: PolicyValueNet,
    board: np.ndarray,
    current_player: int,
    win_length: int,
    device: torch.device,
    agent: str,
    mcts_sims: int,
    c_puct: float,
) -> tuple[int, np.ndarray | None]:
    if agent == "policy":
        priors, _ = evaluate_policy_value(policy, board, current_player, device)
        return int(np.argmax(priors)), None
    return choose_mcts_action(
        policy=policy,
        board=board,
        current_player=current_player,
        win_length=win_length,
        device=device,
        num_simulations=mcts_sims,
        c_puct=c_puct,
        temperature=1e-6,
        add_root_noise=False,
        dirichlet_alpha=0.3,
        noise_eps=0.25,
    )


def self_play_game(
    policy: PolicyValueNet,
    board_size: int,
    win_length: int,
    device: torch.device,
    mcts_sims: int,
    c_puct: float,
    temperature: float,
    temperature_drop_moves: int,
    dirichlet_alpha: float,
    noise_eps: float,
) -> tuple[list[tuple[torch.Tensor, np.ndarray, float]], int, int]:
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    env.reset()
    history: list[tuple[torch.Tensor, np.ndarray, int]] = []

    move_idx = 0
    while not env.done:
        move_temp = temperature if move_idx < temperature_drop_moves else 1e-6
        action, visit_probs = choose_mcts_action(
            policy=policy,
            board=env.board,
            current_player=env.current_player,
            win_length=win_length,
            device=device,
            num_simulations=mcts_sims,
            c_puct=c_puct,
            temperature=move_temp,
            add_root_noise=True,
            dirichlet_alpha=dirichlet_alpha,
            noise_eps=noise_eps,
        )
        history.append((encode_state(env.board, env.current_player), visit_probs.reshape(-1), env.current_player))
        env.step(action)
        move_idx += 1

    examples: list[tuple[torch.Tensor, np.ndarray, float]] = []
    for state, visit_probs, player in history:
        if env.winner == 0:
            outcome = 0.0
        else:
            outcome = 1.0 if player == env.winner else -1.0
        examples.append((state, visit_probs, outcome))
    return examples, env.winner, move_idx


def train_batch(
    policy: PolicyValueNet,
    optimizer: torch.optim.Optimizer,
    batch: list[tuple[torch.Tensor, np.ndarray, float]],
    device: torch.device,
    value_coef: float,
) -> tuple[float, float, float]:
    states = torch.stack([item[0] for item in batch]).to(device)
    target_policies = torch.tensor(
        np.stack([item[1] for item in batch]),
        dtype=torch.float32,
        device=device,
    )
    target_values = torch.tensor([item[2] for item in batch], dtype=torch.float32, device=device)

    logits, values = policy(states)
    log_probs = torch.log_softmax(logits, dim=1)
    policy_loss = -(target_policies * log_probs).sum(dim=1).mean()
    value_loss = torch.mean((values - target_values) ** 2)
    loss = policy_loss + value_coef * value_loss

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
    optimizer.step()
    return float(loss.item()), float(policy_loss.item()), float(value_loss.item())


def save_checkpoint(path: Path, policy: PolicyValueNet, args: argparse.Namespace) -> None:
    torch.save(
        {
            "state_dict": policy.state_dict(),
            "channels": args.channels,
            "board_size": args.board_size,
            "win_length": args.win_length,
        },
        path,
    )


def last_checkpoint_path(base_path: Path) -> Path:
    return base_path.with_name(f"{base_path.stem}_last{base_path.suffix}")


def load_checkpoint(path: Path, map_location: torch.device) -> dict:
    checkpoint = torch.load(path, map_location=map_location)
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        return checkpoint
    raise RuntimeError(f"{path} is not a compatible gomoku_mcts checkpoint")


def resolve_game_config(
    checkpoint_path: Path,
    arg_board_size: int | None,
    arg_win_length: int | None,
    arg_channels: int,
    device: torch.device,
) -> tuple[PolicyValueNet, int, int]:
    checkpoint = load_checkpoint(checkpoint_path, map_location=device)
    board_size = int(checkpoint.get("board_size") or arg_board_size or 15)
    win_length = int(checkpoint.get("win_length") or arg_win_length or 5)
    channels = int(checkpoint.get("channels") or arg_channels)

    policy = PolicyValueNet(channels=channels).to(device)
    policy.load_state_dict(checkpoint["state_dict"])
    policy.eval()
    return policy, board_size, win_length


def play_vs_random_once(
    policy: PolicyValueNet,
    board_size: int,
    win_length: int,
    device: torch.device,
    policy_player: int,
    agent: str,
    mcts_sims: int,
    c_puct: float,
) -> int:
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    env.reset()
    while not env.done:
        if env.current_player == policy_player:
            action, _ = choose_ai_action(
                policy=policy,
                board=env.board,
                current_player=env.current_player,
                win_length=win_length,
                device=device,
                agent=agent,
                mcts_sims=mcts_sims,
                c_puct=c_puct,
            )
        else:
            action = int(np.random.choice(env.valid_moves()))
        env.step(action)
    return env.winner


def evaluate_vs_random(
    policy: PolicyValueNet,
    board_size: int,
    win_length: int,
    device: torch.device,
    games: int,
    agent: str,
    mcts_sims: int,
    c_puct: float,
) -> tuple[float, int, int, int]:
    wins = 0
    draws = 0
    losses = 0
    for game_idx in range(games):
        policy_player = 1 if game_idx < games // 2 else -1
        winner = play_vs_random_once(
            policy=policy,
            board_size=board_size,
            win_length=win_length,
            device=device,
            policy_player=policy_player,
            agent=agent,
            mcts_sims=mcts_sims,
            c_puct=c_puct,
        )
        if winner == 0:
            draws += 1
        elif winner == policy_player:
            wins += 1
        else:
            losses += 1
    return wins / max(games, 1), wins, draws, losses


def train(args: argparse.Namespace) -> None:
    set_seed(args.seed)
    device = choose_device(args.device)
    policy = PolicyValueNet(channels=args.channels).to(device)
    if args.init_checkpoint is not None and args.init_checkpoint.exists():
        checkpoint = load_checkpoint(args.init_checkpoint, map_location=device)
        policy.load_state_dict(checkpoint["state_dict"])
    optimizer = torch.optim.Adam(policy.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    replay_buffer: deque[tuple[torch.Tensor, np.ndarray, float]] = deque(maxlen=args.buffer_size)

    print(f"device={device} board={args.board_size} win={args.win_length}")

    for iteration in range(1, args.iterations + 1):
        policy.eval()
        winners: list[int] = []
        lengths: list[int] = []
        for _ in range(args.games_per_iter):
            examples, winner, moves = self_play_game(
                policy=policy,
                board_size=args.board_size,
                win_length=args.win_length,
                device=device,
                mcts_sims=args.mcts_sims,
                c_puct=args.c_puct,
                temperature=args.temperature,
                temperature_drop_moves=args.temperature_drop_moves,
                dirichlet_alpha=args.dirichlet_alpha,
                noise_eps=args.noise_eps,
            )
            replay_buffer.extend(examples)
            winners.append(winner)
            lengths.append(moves)

        losses: list[tuple[float, float, float]] = []
        if len(replay_buffer) >= args.batch_size:
            policy.train()
            for _ in range(args.train_steps):
                batch = random.sample(replay_buffer, args.batch_size)
                losses.append(
                    train_batch(
                        policy=policy,
                        optimizer=optimizer,
                        batch=batch,
                        device=device,
                        value_coef=args.value_coef,
                    )
                )

        avg_loss = float(np.mean([x[0] for x in losses])) if losses else 0.0
        avg_policy_loss = float(np.mean([x[1] for x in losses])) if losses else 0.0
        avg_value_loss = float(np.mean([x[2] for x in losses])) if losses else 0.0
        p1_wins = sum(1 for x in winners if x == 1)
        p2_wins = sum(1 for x in winners if x == -1)
        draws = sum(1 for x in winners if x == 0)
        avg_len = float(np.mean(lengths)) if lengths else 0.0

        message = (
            f"iter={iteration:5d} loss={avg_loss:7.4f} policy={avg_policy_loss:7.4f} "
            f"value={avg_value_loss:7.4f} p1={p1_wins:3d} p2={p2_wins:3d} draw={draws:3d} "
            f"avg_len={avg_len:6.2f} buffer={len(replay_buffer):6d}"
        )
        if args.eval_every > 0 and iteration % args.eval_every == 0:
            policy.eval()
            win_rate, wins, eval_draws, eval_losses = evaluate_vs_random(
                policy=policy,
                board_size=args.board_size,
                win_length=args.win_length,
                device=device,
                games=args.eval_games,
                agent="mcts",
                mcts_sims=args.eval_mcts_sims,
                c_puct=args.c_puct,
            )
            message += f" random_win_rate={win_rate:.3f} ({wins}/{eval_draws}/{eval_losses})"
        print(message)

        if args.save_every > 0 and iteration % args.save_every == 0:
            checkpoint_path = last_checkpoint_path(args.checkpoint)
            save_checkpoint(checkpoint_path, policy, args)
            print(f"saved checkpoint to {checkpoint_path}")

    save_checkpoint(args.checkpoint, policy, args)
    print(f"saved checkpoint to {args.checkpoint}")


def evaluate(args: argparse.Namespace) -> None:
    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    win_rate, wins, draws, losses = evaluate_vs_random(
        policy=policy,
        board_size=board_size,
        win_length=win_length,
        device=device,
        games=args.games,
        agent=args.agent,
        mcts_sims=args.mcts_sims,
        c_puct=args.c_puct,
    )
    print(f"device={device}")
    print(f"agent={args.agent} mcts_sims={args.mcts_sims}")
    print(f"win_rate={win_rate:.3f} wins={wins} draws={draws} losses={losses}")


def ask_human_move(env: GomokuEnv) -> int:
    while True:
        text = input("your move (row col): ").strip()
        parts = text.replace(",", " ").split()
        if len(parts) != 2:
            print("please enter: row col")
            continue
        try:
            row, col = (int(parts[0]) - 1, int(parts[1]) - 1)
        except ValueError:
            print("row and col must be integers")
            continue
        if not (0 <= row < env.board_size and 0 <= col < env.board_size):
            print("move out of range")
            continue
        if env.board[row, col] != 0:
            print("that position is occupied")
            continue
        return coords_to_action(row, col, env.board_size)


def play(args: argparse.Namespace) -> None:
    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    human_player = 1 if args.human_first else -1

    print(f"device={device}")
    print(
        f"human={'X' if human_player == 1 else 'O'} ai={'O' if human_player == 1 else 'X'} "
        f"agent={args.agent} mcts_sims={args.mcts_sims}"
    )

    while not env.done:
        print()
        print(env.render())
        print()
        if env.current_player == human_player:
            action = ask_human_move(env)
        else:
            action, _ = choose_ai_action(
                policy=policy,
                board=env.board,
                current_player=env.current_player,
                win_length=win_length,
                device=device,
                agent=args.agent,
                mcts_sims=args.mcts_sims,
                c_puct=args.c_puct,
            )
            row, col = action_to_coords(action, env.board_size)
            print(f"ai move: {row + 1} {col + 1}")
        env.step(action)

    print()
    print(env.render())
    if env.winner == 0:
        print("draw")
    elif env.winner == human_player:
        print("you win")
    else:
        print("ai wins")


def gui(args: argparse.Namespace) -> None:
    try:
        import pygame
    except ModuleNotFoundError as exc:
        raise SystemExit(
            "pygame is not installed. Install it with: "
            "~/miniconda3/bin/conda run -n lerobot python -m pip install pygame"
        ) from exc

    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    human_player = 1 if args.human_first else -1
    last_search_visits: np.ndarray | None = None

    pygame.init()
    pygame.display.set_caption("Gomoku MCTS")
    font = pygame.font.SysFont("Arial", 24)
    small_font = pygame.font.SysFont("Arial", 18)

    cell_size = args.cell_size
    padding = 40
    status_height = 80
    board_pixels = board_size * cell_size
    screen = pygame.display.set_mode(
        (board_pixels + padding * 2, board_pixels + padding * 2 + status_height)
    )
    clock = pygame.time.Clock()

    background = (236, 196, 122)
    line_color = (80, 55, 20)
    black_stone = (20, 20, 20)
    white_stone = (245, 245, 245)
    accent = (180, 40, 40)

    def board_to_screen(row: int, col: int) -> tuple[int, int]:
        x = padding + col * cell_size + cell_size // 2
        y = padding + row * cell_size + cell_size // 2
        return x, y

    def mouse_to_action(pos: tuple[int, int]) -> int | None:
        x, y = pos
        if x < padding or y < padding:
            return None
        col = (x - padding) // cell_size
        row = (y - padding) // cell_size
        if not (0 <= row < env.board_size and 0 <= col < env.board_size):
            return None
        if env.board[row, col] != 0:
            return None
        return coords_to_action(row, col, env.board_size)

    def ai_step() -> None:
        nonlocal last_search_visits
        if env.done or env.current_player == human_player:
            return
        action, visits = choose_ai_action(
            policy=policy,
            board=env.board,
            current_player=env.current_player,
            win_length=win_length,
            device=device,
            agent=args.agent,
            mcts_sims=args.mcts_sims,
            c_puct=args.c_puct,
        )
        last_search_visits = visits
        env.step(action)

    def restart() -> None:
        nonlocal last_search_visits
        env.reset()
        last_search_visits = None
        if env.current_player != human_player:
            ai_step()

    def status_text() -> str:
        if env.done:
            if env.winner == 0:
                return "Draw. Press R to restart."
            if env.winner == human_player:
                return "You win. Press R to restart."
            return "AI wins. Press R to restart."
        if env.current_player == human_player:
            return "Your turn. Left click to place."
        return "AI is thinking..."

    if env.current_player != human_player:
        ai_step()

    running = True
    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_ESCAPE:
                    running = False
                elif event.key == pygame.K_r:
851
+ restart()
852
+ elif event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
853
+ if env.done or env.current_player != human_player:
854
+ continue
855
+ action = mouse_to_action(event.pos)
856
+ if action is None:
857
+ continue
858
+ env.step(action)
859
+ ai_step()
860
+
861
+ screen.fill(background)
862
+ for idx in range(board_size + 1):
863
+ x = padding + idx * cell_size
864
+ y = padding + idx * cell_size
865
+ pygame.draw.line(screen, line_color, (x, padding), (x, padding + board_pixels), 2)
866
+ pygame.draw.line(screen, line_color, (padding, y), (padding + board_pixels, y), 2)
867
+
868
+ for row in range(env.board_size):
869
+ for col in range(env.board_size):
870
+ stone = int(env.board[row, col])
871
+ if stone == 0:
872
+ continue
873
+ x, y = board_to_screen(row, col)
874
+ color = black_stone if stone == 1 else white_stone
875
+ pygame.draw.circle(screen, color, (x, y), cell_size // 2 - 4)
876
+ pygame.draw.circle(screen, line_color, (x, y), cell_size // 2 - 4, 1)
877
+
878
+ for idx in range(board_size):
879
+ label = small_font.render(str(idx + 1), True, line_color)
880
+ screen.blit(label, (padding + idx * cell_size + cell_size // 2 - label.get_width() // 2, 8))
881
+ screen.blit(label, (8, padding + idx * cell_size + cell_size // 2 - label.get_height() // 2))
882
+
883
+ info = (
884
+ f"{board_size}x{board_size} connect={win_length} device={device} "
885
+ f"agent={args.agent} sims={args.mcts_sims}"
886
+ )
887
+ screen.blit(small_font.render(info, True, line_color), (padding, padding + board_pixels + 16))
888
+ screen.blit(font.render(status_text(), True, accent), (padding, padding + board_pixels + 42))
889
+
890
+ if last_search_visits is not None and args.agent == "mcts":
891
+ peak = int(np.max(last_search_visits))
892
+ screen.blit(
893
+ small_font.render(f"peak_visits={peak}", True, line_color),
894
+ (padding + 420, padding + board_pixels + 16),
895
+ )
896
+
897
+ pygame.display.flip()
898
+ clock.tick(args.fps)
899
+
900
+ pygame.quit()
901
+
902
+
903
+ def build_parser() -> argparse.ArgumentParser:
904
+ parser = argparse.ArgumentParser(description="Minimal Gomoku MCTS example")
905
+ subparsers = parser.add_subparsers(dest="mode", required=True)
906
+
907
+ def add_common_arguments(subparser: argparse.ArgumentParser, defaults_from_checkpoint: bool = False) -> None:
908
+ board_default = None if defaults_from_checkpoint else 15
909
+ win_default = None if defaults_from_checkpoint else 5
910
+ subparser.add_argument("--board-size", type=int, default=board_default)
911
+ subparser.add_argument("--win-length", type=int, default=win_default)
912
+ subparser.add_argument("--channels", type=int, default=64)
913
+ subparser.add_argument("--device", choices=["auto", "cpu", "cuda", "mps"], default="auto")
914
+ subparser.add_argument("--checkpoint", type=Path, default=Path("gomoku_mcts.pt"))
915
+
916
+ def add_inference_arguments(subparser: argparse.ArgumentParser) -> None:
917
+ subparser.add_argument("--agent", choices=["policy", "mcts"], default="mcts")
918
+ subparser.add_argument("--mcts-sims", type=int, default=120)
919
+ subparser.add_argument("--c-puct", type=float, default=1.5)
920
+
921
+ train_parser = subparsers.add_parser("train", help="MCTS self-play training")
922
+ add_common_arguments(train_parser)
923
+ train_parser.add_argument("--iterations", type=int, default=200)
924
+ train_parser.add_argument("--games-per-iter", type=int, default=8)
925
+ train_parser.add_argument("--train-steps", type=int, default=32)
926
+ train_parser.add_argument("--batch-size", type=int, default=64)
927
+ train_parser.add_argument("--buffer-size", type=int, default=20000)
928
+ train_parser.add_argument("--lr", type=float, default=1e-3)
929
+ train_parser.add_argument("--weight-decay", type=float, default=1e-4)
930
+ train_parser.add_argument("--value-coef", type=float, default=1.0)
931
+ train_parser.add_argument("--mcts-sims", type=int, default=64)
932
+ train_parser.add_argument("--eval-mcts-sims", type=int, default=120)
933
+ train_parser.add_argument("--c-puct", type=float, default=1.5)
934
+ train_parser.add_argument("--temperature", type=float, default=1.0)
935
+ train_parser.add_argument("--temperature-drop-moves", type=int, default=8)
936
+ train_parser.add_argument("--dirichlet-alpha", type=float, default=0.3)
937
+ train_parser.add_argument("--noise-eps", type=float, default=0.25)
938
+ train_parser.add_argument("--eval-every", type=int, default=10)
939
+ train_parser.add_argument("--eval-games", type=int, default=20)
940
+ train_parser.add_argument("--save-every", type=int, default=10)
941
+ train_parser.add_argument("--seed", type=int, default=42)
942
+ train_parser.add_argument("--init-checkpoint", type=Path, default=None)
943
+ train_parser.set_defaults(func=train)
944
+
945
+ eval_parser = subparsers.add_parser("eval", help="evaluate against random agent")
946
+ add_common_arguments(eval_parser)
947
+ add_inference_arguments(eval_parser)
948
+ eval_parser.add_argument("--games", type=int, default=40)
949
+ eval_parser.set_defaults(func=evaluate)
950
+
951
+ play_parser = subparsers.add_parser("play", help="play against the model")
952
+ add_common_arguments(play_parser, defaults_from_checkpoint=True)
953
+ add_inference_arguments(play_parser)
954
+ play_parser.add_argument("--human-first", action="store_true")
955
+ play_parser.set_defaults(func=play)
956
+
957
+ gui_parser = subparsers.add_parser("gui", help="pygame GUI")
958
+ add_common_arguments(gui_parser, defaults_from_checkpoint=True)
959
+ add_inference_arguments(gui_parser)
960
+ gui_parser.add_argument("--human-first", action="store_true")
961
+ gui_parser.add_argument("--cell-size", type=int, default=48)
962
+ gui_parser.add_argument("--fps", type=int, default=30)
963
+ gui_parser.set_defaults(func=gui)
964
+
965
+ return parser
966
+
967
+
968
+ def main() -> None:
969
+ parser = build_parser()
970
+ args = parser.parse_args()
971
+ args.func(args)
972
+
973
+
974
+ if __name__ == "__main__":
975
+ main()
gomoku_pg.py ADDED
@@ -0,0 +1,1105 @@
+#!/usr/bin/env python3
+"""Minimal Gomoku policy gradient example.
+
+Features:
+1. Configurable board size and win length, e.g. 5x5 connect-4 or 15x15 connect-5.
+2. Shared-policy self-play with REINFORCE.
+3. Fully convolutional policy, so the same code works for different board sizes.
+4. Optional random-agent evaluation and CLI human play.
+"""
+
+from __future__ import annotations
+
+import argparse
+import math
+import random
+from collections import deque
+from dataclasses import dataclass, field
+from pathlib import Path
+
+import numpy as np
+import torch
+from torch import nn
+from torch.distributions import Categorical
+
+
+def choose_device(name: str) -> torch.device:
+    if name != "auto":
+        return torch.device(name)
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        return torch.device("mps")
+    return torch.device("cpu")
+
+
+def set_seed(seed: int) -> None:
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(seed)
+
+
+class GomokuEnv:
+    def __init__(self, board_size: int, win_length: int):
+        if board_size <= 1:
+            raise ValueError("board_size must be > 1")
+        if not 1 < win_length <= board_size:
+            raise ValueError("win_length must satisfy 1 < win_length <= board_size")
+        self.board_size = board_size
+        self.win_length = win_length
+        self.reset()
+
+    def reset(self) -> np.ndarray:
+        self.board = np.zeros((self.board_size, self.board_size), dtype=np.int8)
+        self.current_player = 1
+        self.done = False
+        self.winner = 0
+        return self.board
+
+    def legal_mask(self) -> np.ndarray:
+        return self.board == 0
+
+    def valid_moves(self) -> np.ndarray:
+        return np.flatnonzero(self.legal_mask().reshape(-1))
+
+    def step(self, action: int) -> tuple[bool, int]:
+        if self.done:
+            raise RuntimeError("game is already finished")
+
+        row, col = divmod(int(action), self.board_size)
+        if self.board[row, col] != 0:
+            raise ValueError(f"illegal move at ({row}, {col})")
+
+        player = self.current_player
+        self.board[row, col] = player
+
+        if self._is_winning_move(row, col, player):
+            self.done = True
+            self.winner = player
+        elif not np.any(self.board == 0):
+            self.done = True
+            self.winner = 0
+        else:
+            self.current_player = -player
+
+        return self.done, self.winner
+
+    def _is_winning_move(self, row: int, col: int, player: int) -> bool:
+        directions = ((1, 0), (0, 1), (1, 1), (1, -1))
+        for dr, dc in directions:
+            count = 1
+            count += self._count_one_side(row, col, dr, dc, player)
+            count += self._count_one_side(row, col, -dr, -dc, player)
+            if count >= self.win_length:
+                return True
+        return False
+
+    def _count_one_side(self, row: int, col: int, dr: int, dc: int, player: int) -> int:
+        total = 0
+        r, c = row + dr, col + dc
+        while 0 <= r < self.board_size and 0 <= c < self.board_size:
+            if self.board[r, c] != player:
+                break
+            total += 1
+            r += dr
+            c += dc
+        return total
+
+    def render(self) -> str:
+        symbols = {1: "X", -1: "O", 0: "."}
+        header = "   " + " ".join(f"{i + 1:2d}" for i in range(self.board_size))
+        rows = [header]
+        for row_idx in range(self.board_size):
+            row = " ".join(f"{symbols[int(v)]:>2}" for v in self.board[row_idx])
+            rows.append(f"{row_idx + 1:2d} {row}")
+        return "\n".join(rows)
+
+
+def encode_state(board: np.ndarray, current_player: int) -> torch.Tensor:
+    current = (board == current_player).astype(np.float32)
+    opponent = (board == -current_player).astype(np.float32)
+    legal = (board == 0).astype(np.float32)
+    stacked = np.stack([current, opponent, legal], axis=0)
+    return torch.from_numpy(stacked)
+
+
+class PolicyValueNet(nn.Module):
+    def __init__(self, channels: int = 64):
+        super().__init__()
+        self.trunk = nn.Sequential(
+            nn.Conv2d(3, channels, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
+            nn.ReLU(),
+            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
+            nn.ReLU(),
+        )
+        self.policy_head = nn.Conv2d(channels, 1, kernel_size=1)
+        self.value_head = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            nn.Flatten(),
+            nn.Linear(channels, channels),
+            nn.ReLU(),
+            nn.Linear(channels, 1),
+            nn.Tanh(),
+        )
+
+    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        features = self.trunk(x)
+        policy_logits = self.policy_head(features).flatten(start_dim=1)
+        value = self.value_head(features).squeeze(-1)
+        return policy_logits, value
+
+
+def masked_logits(logits: torch.Tensor, legal_mask: np.ndarray) -> torch.Tensor:
+    legal = torch.as_tensor(legal_mask.reshape(-1), device=logits.device, dtype=torch.bool)
+    return logits.masked_fill(~legal, -1e9)
+
+
+def transform_board(board: np.ndarray, rotation_k: int, flip: bool) -> np.ndarray:
+    transformed = np.rot90(board, k=rotation_k)
+    if flip:
+        transformed = np.fliplr(transformed)
+    return np.ascontiguousarray(transformed)
+
+
+def action_to_coords(action: int, board_size: int) -> tuple[int, int]:
+    return divmod(int(action), board_size)
+
+
+def coords_to_action(row: int, col: int, board_size: int) -> int:
+    return row * board_size + col
+
+
+def count_one_side(
+    board: np.ndarray,
+    row: int,
+    col: int,
+    dr: int,
+    dc: int,
+    player: int,
+) -> int:
+    board_size = board.shape[0]
+    total = 0
+    r, c = row + dr, col + dc
+    while 0 <= r < board_size and 0 <= c < board_size:
+        if board[r, c] != player:
+            break
+        total += 1
+        r += dr
+        c += dc
+    return total
+
+
+def is_winning_move(
+    board: np.ndarray,
+    row: int,
+    col: int,
+    player: int,
+    win_length: int,
+) -> bool:
+    directions = ((1, 0), (0, 1), (1, 1), (1, -1))
+    for dr, dc in directions:
+        count = 1
+        count += count_one_side(board, row, col, dr, dc, player)
+        count += count_one_side(board, row, col, -dr, -dc, player)
+        if count >= win_length:
+            return True
+    return False
+
+
+def apply_action_to_board(
+    board: np.ndarray,
+    current_player: int,
+    action: int,
+    win_length: int,
+) -> tuple[np.ndarray, int, bool, int]:
+    board_size = board.shape[0]
+    row, col = action_to_coords(action, board_size)
+    if board[row, col] != 0:
+        raise ValueError(f"illegal move at ({row}, {col})")
+
+    next_board = board.copy()
+    next_board[row, col] = current_player
+
+    if is_winning_move(next_board, row, col, current_player, win_length):
+        return next_board, -current_player, True, current_player
+    if not np.any(next_board == 0):
+        return next_board, -current_player, True, 0
+    return next_board, -current_player, False, 0
+
+
+def forward_transform_coords(
+    row: int,
+    col: int,
+    board_size: int,
+    rotation_k: int,
+    flip: bool,
+) -> tuple[int, int]:
+    for _ in range(rotation_k % 4):
+        row, col = board_size - 1 - col, row
+    if flip:
+        col = board_size - 1 - col
+    return row, col
+
+
+def inverse_transform_coords(
+    row: int,
+    col: int,
+    board_size: int,
+    rotation_k: int,
+    flip: bool,
+) -> tuple[int, int]:
+    if flip:
+        col = board_size - 1 - col
+    for _ in range(rotation_k % 4):
+        row, col = col, board_size - 1 - row
+    return row, col
+
+
+def sample_action(
+    policy: PolicyValueNet,
+    board: np.ndarray,
+    current_player: int,
+    device: torch.device,
+    greedy: bool,
+    augment: bool,
+) -> tuple[int, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
+    board_size = board.shape[0]
+    rotation_k = random.randint(0, 3) if augment else 0
+    flip = bool(random.getrandbits(1)) if augment else False
+    transformed_board = transform_board(board, rotation_k=rotation_k, flip=flip)
+
+    state = encode_state(transformed_board, current_player).unsqueeze(0).to(device)
+    logits, value = policy(state)
+    logits = masked_logits(logits.squeeze(0), transformed_board == 0)
+
+    if greedy:
+        action = torch.argmax(logits)
+        transformed_row, transformed_col = action_to_coords(int(action.item()), board_size)
+        row, col = inverse_transform_coords(
+            transformed_row,
+            transformed_col,
+            board_size,
+            rotation_k=rotation_k,
+            flip=flip,
+        )
+        return coords_to_action(row, col, board_size), None, None, value.squeeze(0)
+
+    dist = Categorical(logits=logits)
+    action = dist.sample()
+    transformed_row, transformed_col = action_to_coords(int(action.item()), board_size)
+    row, col = inverse_transform_coords(
+        transformed_row,
+        transformed_col,
+        board_size,
+        rotation_k=rotation_k,
+        flip=flip,
+    )
+    return (
+        coords_to_action(row, col, board_size),
+        dist.log_prob(action),
+        dist.entropy(),
+        value.squeeze(0),
+    )
+
+
+def evaluate_policy_value(
+    policy: PolicyValueNet,
+    board: np.ndarray,
+    current_player: int,
+    device: torch.device,
+) -> tuple[np.ndarray, float]:
+    state = encode_state(board, current_player).unsqueeze(0).to(device)
+    with torch.no_grad():
+        logits, value = policy(state)
+    logits = masked_logits(logits.squeeze(0), board == 0)
+    probs = torch.softmax(logits, dim=0).detach().cpu().numpy()
+    return probs, float(value.item())
+
+
+@dataclass
+class MCTSNode:
+    board: np.ndarray
+    current_player: int
+    win_length: int
+    done: bool = False
+    winner: int = 0
+    priors: dict[int, float] = field(default_factory=dict)
+    visit_counts: dict[int, int] = field(default_factory=dict)
+    value_sums: dict[int, float] = field(default_factory=dict)
+    children: dict[int, "MCTSNode"] = field(default_factory=dict)
+    expanded: bool = False
+
+    def expand(self, priors: np.ndarray) -> None:
+        legal_actions = np.flatnonzero(self.board.reshape(-1) == 0)
+        total_prob = float(np.sum(priors[legal_actions]))
+        if total_prob <= 0.0:
+            uniform = 1.0 / max(len(legal_actions), 1)
+            self.priors = {int(action): uniform for action in legal_actions}
+        else:
+            self.priors = {
+                int(action): float(priors[action] / total_prob)
+                for action in legal_actions
+            }
+        self.visit_counts = {action: 0 for action in self.priors}
+        self.value_sums = {action: 0.0 for action in self.priors}
+        self.expanded = True
+
+    def q_value(self, action: int) -> float:
+        visits = self.visit_counts[action]
+        if visits == 0:
+            return 0.0
+        return self.value_sums[action] / visits
+
+    def select_action(self, c_puct: float) -> int:
+        total_visits = sum(self.visit_counts.values())
+        sqrt_total = math.sqrt(total_visits + 1.0)
+        best_action = -1
+        best_score = -float("inf")
+
+        for action, prior in self.priors.items():
+            visits = self.visit_counts[action]
+            q = self.q_value(action)
+            u = c_puct * prior * sqrt_total / (1.0 + visits)
+            score = q + u
+            if score > best_score:
+                best_score = score
+                best_action = action
+
+        return best_action
+
+    def child_for_action(self, action: int) -> "MCTSNode":
+        child = self.children.get(action)
+        if child is not None:
+            return child
+
+        next_board, next_player, done, winner = apply_action_to_board(
+            board=self.board,
+            current_player=self.current_player,
+            action=action,
+            win_length=self.win_length,
+        )
+        child = MCTSNode(
+            board=next_board,
+            current_player=next_player,
+            win_length=self.win_length,
+            done=done,
+            winner=winner,
+        )
+        self.children[action] = child
+        return child
+
+
+def terminal_value(winner: int, current_player: int) -> float:
+    if winner == 0:
+        return 0.0
+    return 1.0 if winner == current_player else -1.0
+
+
+def choose_mcts_action(
+    policy: PolicyValueNet,
+    board: np.ndarray,
+    current_player: int,
+    win_length: int,
+    device: torch.device,
+    num_simulations: int,
+    c_puct: float,
+) -> tuple[int, np.ndarray]:
+    root = MCTSNode(
+        board=board.copy(),
+        current_player=current_player,
+        win_length=win_length,
+    )
+
+    priors, _ = evaluate_policy_value(policy, root.board, root.current_player, device)
+    root.expand(priors)
+
+    for _ in range(num_simulations):
+        node = root
+        path: list[tuple[MCTSNode, int]] = []
+
+        while node.expanded and not node.done:
+            action = node.select_action(c_puct)
+            path.append((node, action))
+            node = node.child_for_action(action)
+
+        if node.done:
+            value = terminal_value(node.winner, node.current_player)
+        else:
+            priors, value = evaluate_policy_value(policy, node.board, node.current_player, device)
+            node.expand(priors)
+
+        for parent, action in reversed(path):
+            value = -value
+            parent.visit_counts[action] += 1
+            parent.value_sums[action] += value
+
+    visits = np.zeros(board.size, dtype=np.float32)
+    for action, count in root.visit_counts.items():
+        visits[action] = float(count)
+
+    if np.all(visits == 0):
+        best_action = int(np.argmax(priors))
+    else:
+        best_action = int(np.argmax(visits))
+
+    return best_action, visits.reshape(board.shape)
+
+
+def choose_ai_action(
+    policy: PolicyValueNet,
+    board: np.ndarray,
+    current_player: int,
+    win_length: int,
+    device: torch.device,
+    agent: str,
+    mcts_sims: int,
+    c_puct: float,
+) -> tuple[int, np.ndarray | None]:
+    if agent == "mcts":
+        return choose_mcts_action(
+            policy=policy,
+            board=board,
+            current_player=current_player,
+            win_length=win_length,
+            device=device,
+            num_simulations=mcts_sims,
+            c_puct=c_puct,
+        )
+
+    action, _, _, _ = sample_action(
+        policy=policy,
+        board=board,
+        current_player=current_player,
+        device=device,
+        greedy=True,
+        augment=False,
+    )
+    return action, None
+
+
484
+ def self_play_episode(
485
+ policy: PolicyValueNet,
486
+ env: GomokuEnv,
487
+ device: torch.device,
488
+ gamma: float,
489
+ augment: bool,
490
+ ) -> tuple[list[torch.Tensor], list[float], list[torch.Tensor], list[torch.Tensor], int, int]:
491
+ env.reset()
492
+ log_probs: list[torch.Tensor] = []
493
+ entropies: list[torch.Tensor] = []
494
+ values: list[torch.Tensor] = []
495
+ players: list[int] = []
496
+
497
+ while not env.done:
498
+ player = env.current_player
499
+ action, log_prob, entropy, value = sample_action(
500
+ policy=policy,
501
+ board=env.board,
502
+ current_player=player,
503
+ device=device,
504
+ greedy=False,
505
+ augment=augment,
506
+ )
507
+ log_probs.append(log_prob)
508
+ entropies.append(entropy)
509
+ values.append(value)
510
+ players.append(player)
511
+ env.step(action)
512
+
513
+ returns: list[float] = []
514
+ total_moves = len(players)
515
+ for move_idx, player in enumerate(players):
516
+ outcome = 0.0
517
+ if env.winner != 0:
518
+ outcome = 1.0 if player == env.winner else -1.0
519
+ discounted = outcome * (gamma ** (total_moves - move_idx - 1))
520
+ returns.append(discounted)
521
+
522
+ return log_probs, returns, entropies, values, env.winner, total_moves
523
+
524
+
525
+ def update_policy(
526
+ optimizer: torch.optim.Optimizer,
527
+ batch_log_probs: list[torch.Tensor],
528
+ batch_returns: list[float],
529
+ batch_entropies: list[torch.Tensor],
530
+ batch_values: list[torch.Tensor],
531
+ entropy_coef: float,
532
+ value_coef: float,
533
+ grad_clip: float,
534
+ device: torch.device,
535
+ ) -> float:
536
+ returns = torch.tensor(batch_returns, dtype=torch.float32, device=device)
537
+ log_probs = torch.stack(batch_log_probs)
538
+ entropies = torch.stack(batch_entropies)
539
+ values = torch.stack(batch_values)
540
+ advantages = returns - values.detach()
541
+ if advantages.numel() > 1:
542
+ advantages = (advantages - advantages.mean()) / (advantages.std(unbiased=False) + 1e-6)
543
+
544
+ policy_loss = -(log_probs * advantages).mean()
545
+ value_loss = torch.mean((values - returns) ** 2)
546
+ entropy_bonus = entropies.mean()
547
+ loss = policy_loss + value_coef * value_loss - entropy_coef * entropy_bonus
548
+
549
+ optimizer.zero_grad(set_to_none=True)
550
+ loss.backward()
551
+ nn.utils.clip_grad_norm_(optimizer.param_groups[0]["params"], grad_clip)
552
+ optimizer.step()
553
+ return float(loss.item())
554
+
555
+
556
+ def save_checkpoint(
557
+ path: Path,
558
+ policy: PolicyValueNet,
559
+ args: argparse.Namespace,
560
+ ) -> None:
561
+ payload = {
562
+ "state_dict": policy.state_dict(),
563
+ "channels": args.channels,
564
+ "board_size": args.board_size,
565
+ "win_length": args.win_length,
566
+ }
567
+ torch.save(payload, path)
568
+
569
+
570
+ def load_checkpoint(path: Path, map_location: torch.device) -> dict:
571
+ checkpoint = torch.load(path, map_location=map_location)
572
+ if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
573
+ return checkpoint
574
+ if isinstance(checkpoint, dict) and "policy_state_dict" in checkpoint:
575
+ raise RuntimeError(
576
+ f"{path} is an old fixed-board checkpoint from the previous implementation. "
577
+ "It is not compatible with the current fully-convolutional actor-critic model. "
578
+ "Please retrain with the current script."
579
+ )
580
+ return {
581
+ "state_dict": checkpoint,
582
+ "channels": 64,
583
+ "board_size": None,
584
+ "win_length": None,
585
+ }
586
+
587
+
588
+ def load_policy(path: Path, channels: int, device: torch.device) -> PolicyValueNet:
589
+ checkpoint = load_checkpoint(path, map_location=device)
590
+ state_dict = checkpoint["state_dict"]
591
+ saved_channels = int(checkpoint.get("channels", channels))
592
+
593
+ policy = PolicyValueNet(channels=saved_channels).to(device)
594
+ policy.load_state_dict(state_dict)
595
+ policy.eval()
596
+ return policy
597
+
598
+
599
+ def resolve_game_config(
600
+ checkpoint_path: Path,
601
+ arg_board_size: int | None,
602
+ arg_win_length: int | None,
603
+ arg_channels: int,
604
+ device: torch.device,
605
+ ) -> tuple[PolicyValueNet, int, int]:
606
+ checkpoint = load_checkpoint(checkpoint_path, map_location=device)
607
+ board_size = int(checkpoint.get("board_size") or arg_board_size or 15)
608
+ win_length = int(checkpoint.get("win_length") or arg_win_length or 5)
609
+ channels = int(checkpoint.get("channels") or arg_channels)
610
+
611
+ policy = PolicyValueNet(channels=channels).to(device)
612
+ policy.load_state_dict(checkpoint["state_dict"])
613
+ policy.eval()
614
+ return policy, board_size, win_length
615
+
616
+
617
+ def play_vs_random_once(
618
+ policy: PolicyValueNet,
619
+ board_size: int,
620
+ win_length: int,
621
+ device: torch.device,
622
+ policy_player: int,
623
+ agent: str = "policy",
624
+ mcts_sims: int = 100,
625
+ c_puct: float = 1.5,
626
+ ) -> int:
627
+ env = GomokuEnv(board_size=board_size, win_length=win_length)
628
+ env.reset()
629
+
630
+ while not env.done:
631
+ if env.current_player == policy_player:
632
+ action, _ = choose_ai_action(
633
+ policy=policy,
634
+ board=env.board,
635
+ current_player=env.current_player,
636
+ win_length=win_length,
637
+ device=device,
638
+ agent=agent,
639
+ mcts_sims=mcts_sims,
640
+ c_puct=c_puct,
641
+ )
642
+ else:
643
+ action = int(np.random.choice(env.valid_moves()))
644
+ env.step(action)
645
+
646
+ return env.winner
647
+
648
+
649
+ def evaluate_vs_random(
650
+ policy: PolicyValueNet,
651
+ board_size: int,
652
+ win_length: int,
653
+ device: torch.device,
654
+ games: int,
655
+ agent: str = "policy",
656
+ mcts_sims: int = 100,
657
+ c_puct: float = 1.5,
658
+ ) -> tuple[float, int, int, int]:
659
+ wins = 0
660
+ draws = 0
661
+ losses = 0
662
+
663
+ for game_idx in range(games):
664
+ policy_player = 1 if game_idx < games // 2 else -1
665
+ winner = play_vs_random_once(
666
+ policy=policy,
667
+ board_size=board_size,
668
+ win_length=win_length,
669
+ device=device,
670
+ policy_player=policy_player,
671
+ agent=agent,
672
+ mcts_sims=mcts_sims,
673
+ c_puct=c_puct,
674
+ )
675
+ if winner == 0:
676
+ draws += 1
677
+ elif winner == policy_player:
678
+ wins += 1
679
+ else:
680
+ losses += 1
681
+
682
+ return wins / max(games, 1), wins, draws, losses
683
+
684
+
685
def train(args: argparse.Namespace) -> None:
    set_seed(args.seed)
    device = choose_device(args.device)
    env = GomokuEnv(board_size=args.board_size, win_length=args.win_length)
    policy = PolicyValueNet(channels=args.channels).to(device)
    if args.init_checkpoint is not None and args.init_checkpoint.exists():
        checkpoint = load_checkpoint(args.init_checkpoint, map_location=device)
        policy.load_state_dict(checkpoint["state_dict"])
    optimizer = torch.optim.Adam(policy.parameters(), lr=args.lr)

    recent_winners: deque[int] = deque(maxlen=args.print_every)
    recent_lengths: deque[int] = deque(maxlen=args.print_every)
    batch_log_probs: list[torch.Tensor] = []
    batch_returns: list[float] = []
    batch_entropies: list[torch.Tensor] = []
    batch_values: list[torch.Tensor] = []
    last_loss = 0.0

    print(f"device={device} board={args.board_size} win={args.win_length}")

    for episode in range(1, args.episodes + 1):
        log_probs, returns, entropies, values, winner, moves = self_play_episode(
            policy=policy,
            env=env,
            device=device,
            gamma=args.gamma,
            augment=args.symmetry_augment,
        )
        batch_log_probs.extend(log_probs)
        batch_returns.extend(returns)
        batch_entropies.extend(entropies)
        batch_values.extend(values)
        recent_winners.append(winner)
        recent_lengths.append(moves)

        if episode % args.batch_size == 0 or episode == args.episodes:
            policy.train()
            last_loss = update_policy(
                optimizer=optimizer,
                batch_log_probs=batch_log_probs,
                batch_returns=batch_returns,
                batch_entropies=batch_entropies,
                batch_values=batch_values,
                entropy_coef=args.entropy_coef,
                value_coef=args.value_coef,
                grad_clip=args.grad_clip,
                device=device,
            )
            batch_log_probs.clear()
            batch_returns.clear()
            batch_entropies.clear()
            batch_values.clear()

        if episode % args.print_every == 0 or episode == args.episodes:
            p1_wins = sum(1 for x in recent_winners if x == 1)
            p2_wins = sum(1 for x in recent_winners if x == -1)
            draws = sum(1 for x in recent_winners if x == 0)
            avg_len = float(np.mean(recent_lengths)) if recent_lengths else 0.0
            message = (
                f"episode={episode:6d} loss={last_loss:8.4f} "
                f"p1={p1_wins:4d} p2={p2_wins:4d} draw={draws:4d} avg_len={avg_len:6.2f}"
            )
            if args.eval_every > 0 and episode % args.eval_every == 0:
                policy.eval()
                win_rate, wins, eval_draws, losses = evaluate_vs_random(
                    policy=policy,
                    board_size=args.board_size,
                    win_length=args.win_length,
                    device=device,
                    games=args.eval_games,
                )
                message += (
                    f" random_win_rate={win_rate:.3f}"
                    f" ({wins}/{eval_draws}/{losses})"
                )
            print(message)

    save_checkpoint(args.checkpoint, policy, args)
    print(f"saved checkpoint to {args.checkpoint}")


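Per the reward scheme described in the README, every stored step is credited with the terminal outcome from the perspective of the player who moved at that step (+1 win, -1 loss, 0 draw). A minimal illustration, simplified to the undiscounted case (gamma = 1) and independent of the training code:

```python
# With no discounting, each stored step receives the terminal outcome
# from the perspective of the player who moved at that step.
winner = 1                       # player 1 won the game
movers = [1, -1, 1, -1, 1]       # who moved at each ply
returns = [float(winner * p) for p in movers]

assert returns == [1.0, -1.0, 1.0, -1.0, 1.0]
```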
766
def evaluate(args: argparse.Namespace) -> None:
    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    win_rate, wins, draws, losses = evaluate_vs_random(
        policy=policy,
        board_size=board_size,
        win_length=win_length,
        device=device,
        games=args.games,
        agent=args.agent,
        mcts_sims=args.mcts_sims,
        c_puct=args.c_puct,
    )
    print(f"device={device}")
    print(f"agent={args.agent} mcts_sims={args.mcts_sims}")
    print(f"win_rate={win_rate:.3f} wins={wins} draws={draws} losses={losses}")


790
def ask_human_move(env: GomokuEnv) -> int:
    while True:
        text = input("your move (row col): ").strip()
        parts = text.replace(",", " ").split()
        if len(parts) != 2:
            print("please enter: row col")
            continue
        try:
            row, col = (int(parts[0]) - 1, int(parts[1]) - 1)
        except ValueError:
            print("row and col must be integers")
            continue
        if not (0 <= row < env.board_size and 0 <= col < env.board_size):
            print("move out of range")
            continue
        if env.board[row, col] != 0:
            print("that position is occupied")
            continue
        return row * env.board_size + col


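`ask_human_move` converts the 1-indexed `row col` input into a flat action index via `row * board_size + col` (after subtracting 1 from each), and `divmod` inverts the mapping when a move is printed back. For example, on a 5x5 board:

```python
board_size = 5
row_in, col_in = 3, 4                        # as typed by the user (1-indexed)
action = (row_in - 1) * board_size + (col_in - 1)

assert action == 13
assert divmod(action, board_size) == (2, 3)  # 0-indexed (row, col)
```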
811
def play(args: argparse.Namespace) -> None:
    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    human_player = 1 if args.human_first else -1

    print(f"device={device}")
    print(
        f"human={'X' if human_player == 1 else 'O'} ai={'O' if human_player == 1 else 'X'} "
        f"agent={args.agent} mcts_sims={args.mcts_sims}"
    )

    while not env.done:
        print()
        print(env.render())
        print()

        if env.current_player == human_player:
            action = ask_human_move(env)
        else:
            action, _ = choose_ai_action(
                policy=policy,
                board=env.board,
                current_player=env.current_player,
                win_length=win_length,
                device=device,
                agent=args.agent,
                mcts_sims=args.mcts_sims,
                c_puct=args.c_puct,
            )
            row, col = divmod(action, env.board_size)
            print(f"ai move: {row + 1} {col + 1}")

        env.step(action)

    print()
    print(env.render())
    if env.winner == 0:
        print("draw")
    elif env.winner == human_player:
        print("you win")
    else:
        print("ai wins")


862
def gui(args: argparse.Namespace) -> None:
    try:
        import pygame
    except ModuleNotFoundError as exc:
        raise SystemExit(
            "pygame is not installed. Install it with: "
            "~/miniconda3/bin/conda run -n lerobot python -m pip install pygame"
        ) from exc

    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    human_player = 1 if args.human_first else -1
    last_search_visits: np.ndarray | None = None

    pygame.init()
    pygame.display.set_caption("Gomoku Policy Gradient")
    font = pygame.font.SysFont("Arial", 24)
    small_font = pygame.font.SysFont("Arial", 18)

    cell_size = args.cell_size
    padding = 40
    status_height = 80
    board_pixels = board_size * cell_size
    screen = pygame.display.set_mode(
        (board_pixels + padding * 2, board_pixels + padding * 2 + status_height)
    )
    clock = pygame.time.Clock()

    background = (236, 196, 122)
    line_color = (80, 55, 20)
    black_stone = (20, 20, 20)
    white_stone = (245, 245, 245)
    accent = (180, 40, 40)

    def board_to_screen(row: int, col: int) -> tuple[int, int]:
        x = padding + col * cell_size + cell_size // 2
        y = padding + row * cell_size + cell_size // 2
        return x, y

    def mouse_to_action(pos: tuple[int, int]) -> int | None:
        x, y = pos
        left = padding
        top = padding
        if x < left or y < top:
            return None
        col = (x - left) // cell_size
        row = (y - top) // cell_size
        if not (0 <= row < env.board_size and 0 <= col < env.board_size):
            return None
        if env.board[row, col] != 0:
            return None
        return row * env.board_size + col

    def restart() -> None:
        nonlocal last_search_visits
        env.reset()
        last_search_visits = None
        if env.current_player != human_player:
            ai_step()

    def ai_step() -> None:
        nonlocal last_search_visits
        if env.done or env.current_player == human_player:
            return
        action, visits = choose_ai_action(
            policy=policy,
            board=env.board,
            current_player=env.current_player,
            win_length=win_length,
            device=device,
            agent=args.agent,
            mcts_sims=args.mcts_sims,
            c_puct=args.c_puct,
        )
        last_search_visits = visits
        env.step(action)

    def status_text() -> str:
        if env.done:
            if env.winner == 0:
                return "Draw. Press R to restart."
            if env.winner == human_player:
                return "You win. Press R to restart."
            return "AI wins. Press R to restart."
        if env.current_player == human_player:
            return "Your turn. Left click to place."
        return "AI is thinking..."

    if env.current_player != human_player:
        ai_step()

    running = True
    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_ESCAPE:
                    running = False
                elif event.key == pygame.K_r:
                    restart()
            elif event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
                if env.done or env.current_player != human_player:
                    continue
                action = mouse_to_action(event.pos)
                if action is None:
                    continue
                env.step(action)
                ai_step()

        screen.fill(background)

        for idx in range(board_size + 1):
            x = padding + idx * cell_size
            pygame.draw.line(screen, line_color, (x, padding), (x, padding + board_pixels), 2)
            y = padding + idx * cell_size
            pygame.draw.line(screen, line_color, (padding, y), (padding + board_pixels, y), 2)

        for row in range(env.board_size):
            for col in range(env.board_size):
                stone = int(env.board[row, col])
                if stone == 0:
                    continue
                x, y = board_to_screen(row, col)
                color = black_stone if stone == 1 else white_stone
                pygame.draw.circle(screen, color, (x, y), cell_size // 2 - 4)
                pygame.draw.circle(screen, line_color, (x, y), cell_size // 2 - 4, 1)

        for idx in range(board_size):
            label = small_font.render(str(idx + 1), True, line_color)
            screen.blit(
                label,
                (padding + idx * cell_size + cell_size // 2 - label.get_width() // 2, 8),
            )
            screen.blit(
                label,
                (8, padding + idx * cell_size + cell_size // 2 - label.get_height() // 2),
            )

        info = (
            f"{board_size}x{board_size} connect={win_length} "
            f"device={device} human={'X' if human_player == 1 else 'O'} "
            f"agent={args.agent}"
        )
        info_surface = small_font.render(info, True, line_color)
        status_surface = font.render(status_text(), True, accent)
        screen.blit(info_surface, (padding, padding + board_pixels + 16))
        screen.blit(status_surface, (padding, padding + board_pixels + 42))

        if last_search_visits is not None and args.agent == "mcts":
            peak = float(np.max(last_search_visits))
            if peak > 0:
                stats = small_font.render(
                    f"mcts_sims={args.mcts_sims} peak_visits={int(peak)}",
                    True,
                    line_color,
                )
                screen.blit(stats, (padding + 380, padding + board_pixels + 16))

        pygame.display.flip()
        clock.tick(args.fps)

    pygame.quit()


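`board_to_screen` and `mouse_to_action` are inverse mappings between pixel and grid coordinates. A quick self-contained check of that round trip, using the same padding/cell_size arithmetic (here reduced to a hypothetical `screen_to_board` helper without the GUI's occupancy and bounds checks):

```python
padding, cell_size, board_size = 40, 48, 15

def board_to_screen(row, col):
    # center of cell (row, col) in pixels, as in the GUI
    return (padding + col * cell_size + cell_size // 2,
            padding + row * cell_size + cell_size // 2)

def screen_to_board(x, y):
    # inverse mapping used by mouse_to_action (checks omitted)
    return ((y - padding) // cell_size, (x - padding) // cell_size)

# The center of every cell maps back to the same (row, col).
for row in range(board_size):
    for col in range(board_size):
        assert screen_to_board(*board_to_screen(row, col)) == (row, col)
```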
1034
def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Minimal Gomoku policy gradient example")
    subparsers = parser.add_subparsers(dest="mode", required=True)

    def add_common_arguments(subparser: argparse.ArgumentParser, defaults_from_checkpoint: bool = False) -> None:
        board_default = None if defaults_from_checkpoint else 15
        win_default = None if defaults_from_checkpoint else 5
        subparser.add_argument("--board-size", type=int, default=board_default)
        subparser.add_argument("--win-length", type=int, default=win_default)
        subparser.add_argument("--channels", type=int, default=64)
        subparser.add_argument("--device", choices=["auto", "cpu", "cuda", "mps"], default="auto")
        subparser.add_argument("--checkpoint", type=Path, default=Path("gomoku_policy.pt"))

    def add_inference_arguments(subparser: argparse.ArgumentParser, default_agent: str = "mcts") -> None:
        subparser.add_argument("--agent", choices=["policy", "mcts"], default=default_agent)
        subparser.add_argument("--mcts-sims", type=int, default=120)
        subparser.add_argument("--c-puct", type=float, default=1.5)

    train_parser = subparsers.add_parser("train", help="self-play training")
    add_common_arguments(train_parser)
    train_parser.add_argument("--episodes", type=int, default=5000)
    train_parser.add_argument("--batch-size", type=int, default=32)
    train_parser.add_argument("--lr", type=float, default=1e-3)
    train_parser.add_argument("--gamma", type=float, default=0.99)
    train_parser.add_argument("--entropy-coef", type=float, default=0.01)
    train_parser.add_argument("--value-coef", type=float, default=0.5)
    train_parser.add_argument("--grad-clip", type=float, default=1.0)
    train_parser.add_argument("--print-every", type=int, default=100)
    train_parser.add_argument("--eval-every", type=int, default=500)
    train_parser.add_argument("--eval-games", type=int, default=40)
    train_parser.add_argument("--seed", type=int, default=42)
    train_parser.add_argument("--init-checkpoint", type=Path, default=None)
    train_parser.add_argument(
        "--no-symmetry-augment",
        dest="symmetry_augment",
        action="store_false",
        help="disable random rotation/flip augmentation during training",
    )
    train_parser.set_defaults(symmetry_augment=True)
    train_parser.set_defaults(func=train)

    eval_parser = subparsers.add_parser("eval", help="evaluate against random agent")
    add_common_arguments(eval_parser)
    eval_parser.add_argument("--games", type=int, default=100)
    add_inference_arguments(eval_parser)
    eval_parser.set_defaults(func=evaluate)

    play_parser = subparsers.add_parser("play", help="play against the trained model")
    add_common_arguments(play_parser, defaults_from_checkpoint=True)
    play_parser.add_argument("--human-first", action="store_true", help="human plays X")
    add_inference_arguments(play_parser)
    play_parser.set_defaults(func=play)

    gui_parser = subparsers.add_parser("gui", help="pygame GUI for testing against the model")
    add_common_arguments(gui_parser, defaults_from_checkpoint=True)
    gui_parser.add_argument("--human-first", action="store_true", help="human plays X")
    gui_parser.add_argument("--cell-size", type=int, default=48)
    gui_parser.add_argument("--fps", type=int, default=30)
    add_inference_arguments(gui_parser)
    gui_parser.set_defaults(func=gui)

    return parser


1098
def main() -> None:
    parser = build_parser()
    args = parser.parse_args()
    args.func(args)


if __name__ == "__main__":
    main()
train_mcts_15x15_5.sh ADDED
@@ -0,0 +1,73 @@
#!/usr/bin/env bash
set -euo pipefail

ROOT_DIR="$(cd "$(dirname "$0")" && pwd)"
CONDA_BIN="${CONDA_BIN:-$HOME/miniconda3/bin/conda}"
ENV_NAME="${ENV_NAME:-lerobot}"

INIT_CHECKPOINT="${INIT_CHECKPOINT:-$ROOT_DIR/gomoku_7x7_5.pt}"
OUTPUT_CHECKPOINT="${OUTPUT_CHECKPOINT:-$ROOT_DIR/gomoku_mcts_15x15_5.pt}"

BOARD_SIZE="${BOARD_SIZE:-15}"
WIN_LENGTH="${WIN_LENGTH:-5}"
CHANNELS="${CHANNELS:-64}"

ITERATIONS="${ITERATIONS:-3000}"
GAMES_PER_ITER="${GAMES_PER_ITER:-12}"
TRAIN_STEPS="${TRAIN_STEPS:-64}"
BATCH_SIZE="${BATCH_SIZE:-128}"
BUFFER_SIZE="${BUFFER_SIZE:-50000}"

MCTS_SIMS="${MCTS_SIMS:-64}"
EVAL_MCTS_SIMS="${EVAL_MCTS_SIMS:-160}"
EVAL_EVERY="${EVAL_EVERY:-10}"
EVAL_GAMES="${EVAL_GAMES:-20}"
SAVE_EVERY="${SAVE_EVERY:-10}"

LR="${LR:-5e-4}"
WEIGHT_DECAY="${WEIGHT_DECAY:-1e-4}"
VALUE_COEF="${VALUE_COEF:-1.0}"
CPUCT="${CPUCT:-1.5}"
TEMPERATURE="${TEMPERATURE:-1.0}"
TEMPERATURE_DROP_MOVES="${TEMPERATURE_DROP_MOVES:-10}"
DIRICHLET_ALPHA="${DIRICHLET_ALPHA:-0.3}"
NOISE_EPS="${NOISE_EPS:-0.25}"
SEED="${SEED:-42}"
DEVICE="${DEVICE:-auto}"

CMD=(
  "$CONDA_BIN" run -n "$ENV_NAME" python "$ROOT_DIR/gomoku_mcts.py" train
  --board-size "$BOARD_SIZE"
  --win-length "$WIN_LENGTH"
  --channels "$CHANNELS"
  --iterations "$ITERATIONS"
  --games-per-iter "$GAMES_PER_ITER"
  --train-steps "$TRAIN_STEPS"
  --batch-size "$BATCH_SIZE"
  --buffer-size "$BUFFER_SIZE"
  --mcts-sims "$MCTS_SIMS"
  --eval-mcts-sims "$EVAL_MCTS_SIMS"
  --eval-every "$EVAL_EVERY"
  --eval-games "$EVAL_GAMES"
  --save-every "$SAVE_EVERY"
  --lr "$LR"
  --weight-decay "$WEIGHT_DECAY"
  --value-coef "$VALUE_COEF"
  --c-puct "$CPUCT"
  --temperature "$TEMPERATURE"
  --temperature-drop-moves "$TEMPERATURE_DROP_MOVES"
  --dirichlet-alpha "$DIRICHLET_ALPHA"
  --noise-eps "$NOISE_EPS"
  --seed "$SEED"
  --device "$DEVICE"
  --checkpoint "$OUTPUT_CHECKPOINT"
)

if [[ -f "$INIT_CHECKPOINT" ]]; then
  CMD+=(--init-checkpoint "$INIT_CHECKPOINT")
else
  echo "init checkpoint not found, training from scratch: $INIT_CHECKPOINT"
fi

printf 'Running command:\n%s\n' "${CMD[*]}"
exec "${CMD[@]}"
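Every knob in the script uses bash's `${VAR:-default}` expansion, so a run can be customized from the environment without editing the file. A self-contained illustration of the pattern (the override values below are illustrative):

```bash
#!/usr/bin/env bash
# Same expansion pattern as the training script: take the value from
# the environment when it is set, otherwise fall back to the default.
MCTS_SIMS="${MCTS_SIMS:-64}"
echo "mcts_sims=$MCTS_SIMS"

# Example override when invoking the training script:
#   MCTS_SIMS=128 DEVICE=cuda bash train_mcts_15x15_5.sh
```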