Upload Gomoku training and MCTS code

- README.md (+178)
- gomoku_mcts.py (+975)
- gomoku_pg.py (+1105)
- train_mcts_15x15_5.sh (+73)

README.md (added):
# Minimal Gomoku Policy Gradient

This is a minimal, learning-oriented policy-gradient example for Gomoku. Key points:

- A single file: `gomoku_pg.py`
- Configurable board size, e.g. `5x5` or `15x15`
- Configurable win length, e.g. `4`-in-a-row or `5`-in-a-row
- Uses `torch` and a stripped-down `actor-critic` policy gradient
- One policy plays both first and second player, trained via self-play

## Core idea

The state is encoded as 3 planes:

1. The current player's own stones
2. The opponent's stones
3. Legal move positions

The policy network is a small fully convolutional network that outputs a logit for every cell. Illegal positions are masked out, then an action is sampled from the legal ones.

During training:

1. Play a complete self-play game with the current policy
2. Save `log_prob(action)` for every move
3. After the game ends, assign each move a return:
   `+1` if that move's player ultimately won,
   `-1` if it lost,
   `0` on a draw
4. The policy head does policy gradient with an advantage; the value head predicts the return, which reduces variance
5. Boards are randomly rotated/flipped during training to improve sample efficiency
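The plane encoding and terminal-return assignment described above can be sketched as follows. This is a minimal illustration, not the actual `gomoku_pg.py` code; the function names here are hypothetical:

```python
import numpy as np

def encode_state(board: np.ndarray, current_player: int) -> np.ndarray:
    """Stack the three planes: own stones, opponent stones, legal cells."""
    own = (board == current_player).astype(np.float32)
    opp = (board == -current_player).astype(np.float32)
    legal = (board == 0).astype(np.float32)
    return np.stack([own, opp, legal], axis=0)

def final_returns(players: list[int], winner: int) -> list[float]:
    """Per-move return: +1 if that move's player won, -1 if it lost, 0 on a draw."""
    if winner == 0:
        return [0.0 for _ in players]
    return [1.0 if p == winner else -1.0 for p in players]

board = np.zeros((5, 5), dtype=np.int8)
board[2, 2] = 1    # a stone of player +1
board[2, 3] = -1   # a stone of player -1
planes = encode_state(board, current_player=1)  # shape (3, 5, 5)
```

Note that the encoding is always from the mover's perspective, so the same network can play both colors.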
## Validate on a small board first

Recommended first check:

```bash
~/miniconda3/bin/conda run -n lerobot python gomoku_pg.py train \
    --board-size 5 \
    --win-length 4 \
    --episodes 5000 \
    --batch-size 32 \
    --eval-every 300 \
    --eval-games 40 \
    --checkpoint gomoku_5x5_4.pt
```

Evaluation:

```bash
~/miniconda3/bin/conda run -n lerobot python gomoku_pg.py eval \
    --board-size 5 \
    --win-length 4 \
    --checkpoint gomoku_5x5_4.pt \
    --agent mcts \
    --mcts-sims 120 \
    --games 100
```

Play against it in a GUI:

```bash
~/miniconda3/bin/conda run -n lerobot python gomoku_pg.py gui \
    --checkpoint gomoku_5x5_4.pt \
    --agent mcts \
    --mcts-sims 120 \
    --human-first
```

Controls:

- Left mouse button places a stone
- `R` restarts
- `Esc` quits

If `pygame` is not installed yet:

```bash
~/miniconda3/bin/conda run -n lerobot python -m pip install pygame
```

Human-vs-AI play in the terminal:

```bash
~/miniconda3/bin/conda run -n lerobot python gomoku_pg.py play \
    --board-size 5 \
    --win-length 4 \
    --checkpoint gomoku_5x5_4.pt \
    --agent mcts \
    --mcts-sims 120 \
    --human-first
```
## Switching to standard Gomoku

```bash
~/miniconda3/bin/conda run -n lerobot python gomoku_pg.py train \
    --board-size 15 \
    --win-length 5 \
    --episodes 20000 \
    --batch-size 32 \
    --eval-every 1000 \
    --eval-games 40 \
    --checkpoint gomoku_15x15_5.pt
```

Note: the code can switch board sizes directly, but the model parameters must be retrained. Do not expect a policy learned on `5x5 + 4`-in-a-row to transfer directly to `15x15 + 5`-in-a-row.
## How to validate the algorithm

The most direct order of checks:

1. Train on `5x5 + 4`-in-a-row first
2. Use `eval` to confirm the win rate against a random policy is clearly above 50%
3. Use `gui` to play it yourself and watch whether it completes its own fours and blocks yours
4. Only then switch to `15x15 + 5`-in-a-row and retrain

If you just want to check the implementation for major bugs, the small board is by far the most effective: training is fast and policy errors show up much more clearly.

## Why you can easily beat it

If you previously used the most basic terminal-reward `REINFORCE`, these problems appear quickly:

- The terminal reward is too sparse; most early moves receive almost no useful learning signal
- Variance is high, so the trained policy is unstable
- The `15x15` action space is large, so self-play from scratch is very slow

This version has been switched to a more stable `actor-critic`. Even so, standard Gomoku trained from scratch cannot become strong in just a few hundred games.
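The variance-reduction idea above, a learned value baseline subtracted from the return, can be sketched as a loss computation. This is illustrative only; the tensor names are assumptions, not the actual `gomoku_pg.py` internals:

```python
import torch

def actor_critic_loss(log_probs: torch.Tensor,
                      values: torch.Tensor,
                      returns: torch.Tensor,
                      value_coef: float = 0.5) -> torch.Tensor:
    """Policy gradient with a learned baseline: advantage = return - V(s)."""
    advantage = returns - values.detach()  # stop gradient through the baseline
    policy_loss = -(advantage * log_probs).mean()
    value_loss = torch.nn.functional.mse_loss(values, returns)
    return policy_loss + value_coef * value_loss

# One tiny batch: per-move log-probs, value predictions, and final returns.
log_probs = torch.log(torch.tensor([0.4, 0.3, 0.8]))
values = torch.tensor([0.2, -0.1, 0.5])
returns = torch.tensor([1.0, -1.0, 1.0])
loss = actor_critic_loss(log_probs, values, returns)
```

The `detach()` is the key detail: the baseline reduces gradient variance for the policy head without the policy loss training the value head.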
## MCTS at inference time

`eval`, `play`, and `gui` now all support:

- `--agent policy`: let the policy network move directly
- `--agent mcts`: run an MCTS search guided by the policy and value networks before moving

For human testing, default to `mcts`; it is usually noticeably stronger than moving directly.

For example:

```bash
~/miniconda3/bin/conda run -n lerobot python gomoku_pg.py gui \
    --checkpoint gomoku_15x15_5.pt \
    --agent mcts \
    --mcts-sims 120 \
    --human-first
```

If it feels slow, lower `--mcts-sims` to `32` or `64` first.
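The `--agent mcts` search picks children with a PUCT-style score: an exploitation term (the mean value of a move so far) plus a prior-weighted exploration bonus that shrinks as the move gets visited. A minimal sketch of that rule, with hypothetical numbers:

```python
import math

def puct_score(q: float, prior: float,
               parent_visits: int, child_visits: int,
               c_puct: float = 1.5) -> float:
    """Exploitation (q) plus a prior-weighted exploration bonus."""
    u = c_puct * prior * math.sqrt(parent_visits + 1.0) / (1.0 + child_visits)
    return q + u

# A heavily visited, mildly positive move vs. an unvisited move with a strong prior:
explored = puct_score(q=0.1, prior=0.05, parent_visits=100, child_visits=50)
fresh = puct_score(q=0.0, prior=0.40, parent_visits=100, child_visits=0)
# The unvisited high-prior move wins the selection, which is what drives exploration.
```

This is why more `--mcts-sims` helps: each simulation refines the `q` estimates that the score eventually trusts over the prior.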
## A more realistic training schedule

Recommended progression:

1. Train `5x5 + 4`-in-a-row first
2. Warm-start a larger board from the small-board weights
3. Finally train `15x15 + 5`-in-a-row

For example:

```bash
~/miniconda3/bin/conda run -n lerobot python gomoku_pg.py train \
    --board-size 7 \
    --win-length 5 \
    --episodes 5000 \
    --init-checkpoint gomoku_5x5_4.pt \
    --checkpoint gomoku_7x7_5.pt
```

Then continue:

```bash
~/miniconda3/bin/conda run -n lerobot python gomoku_pg.py train \
    --board-size 15 \
    --win-length 5 \
    --episodes 20000 \
    --init-checkpoint gomoku_7x7_5.pt \
    --checkpoint gomoku_15x15_5.pt
```
gomoku_mcts.py (added):
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Minimal Gomoku MCTS example.
|
| 3 |
+
|
| 4 |
+
This file is intentionally separate from gomoku_pg.py.
|
| 5 |
+
It uses the simpler AlphaZero-style recipe:
|
| 6 |
+
1. self-play with MCTS
|
| 7 |
+
2. policy/value targets from search
|
| 8 |
+
3. supervised update on policy + value heads
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import math
|
| 15 |
+
import random
|
| 16 |
+
from collections import deque
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
from torch import nn
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def choose_device(name: str) -> torch.device:
|
| 26 |
+
if name != "auto":
|
| 27 |
+
return torch.device(name)
|
| 28 |
+
if torch.cuda.is_available():
|
| 29 |
+
return torch.device("cuda")
|
| 30 |
+
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
| 31 |
+
return torch.device("mps")
|
| 32 |
+
return torch.device("cpu")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def set_seed(seed: int) -> None:
|
| 36 |
+
random.seed(seed)
|
| 37 |
+
np.random.seed(seed)
|
| 38 |
+
torch.manual_seed(seed)
|
| 39 |
+
if torch.cuda.is_available():
|
| 40 |
+
torch.cuda.manual_seed_all(seed)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class GomokuEnv:
|
| 44 |
+
def __init__(self, board_size: int, win_length: int):
|
| 45 |
+
if board_size <= 1:
|
| 46 |
+
raise ValueError("board_size must be > 1")
|
| 47 |
+
if not 1 < win_length <= board_size:
|
| 48 |
+
raise ValueError("win_length must satisfy 1 < win_length <= board_size")
|
| 49 |
+
self.board_size = board_size
|
| 50 |
+
self.win_length = win_length
|
| 51 |
+
self.reset()
|
| 52 |
+
|
| 53 |
+
def reset(self) -> np.ndarray:
|
| 54 |
+
self.board = np.zeros((self.board_size, self.board_size), dtype=np.int8)
|
| 55 |
+
self.current_player = 1
|
| 56 |
+
self.done = False
|
| 57 |
+
self.winner = 0
|
| 58 |
+
return self.board
|
| 59 |
+
|
| 60 |
+
def valid_moves(self) -> np.ndarray:
|
| 61 |
+
return np.flatnonzero((self.board == 0).reshape(-1))
|
| 62 |
+
|
| 63 |
+
def step(self, action: int) -> tuple[bool, int]:
|
| 64 |
+
next_board, next_player, done, winner = apply_action_to_board(
|
| 65 |
+
board=self.board,
|
| 66 |
+
current_player=self.current_player,
|
| 67 |
+
action=action,
|
| 68 |
+
win_length=self.win_length,
|
| 69 |
+
)
|
| 70 |
+
self.board = next_board
|
| 71 |
+
self.current_player = next_player
|
| 72 |
+
self.done = done
|
| 73 |
+
self.winner = winner
|
| 74 |
+
return done, winner
|
| 75 |
+
|
| 76 |
+
def render(self) -> str:
|
| 77 |
+
symbols = {1: "X", -1: "O", 0: "."}
|
| 78 |
+
header = " " + " ".join(f"{i + 1:2d}" for i in range(self.board_size))
|
| 79 |
+
rows = [header]
|
| 80 |
+
for row_idx in range(self.board_size):
|
| 81 |
+
row = " ".join(f"{symbols[int(v)]:>2}" for v in self.board[row_idx])
|
| 82 |
+
rows.append(f"{row_idx + 1:2d} {row}")
|
| 83 |
+
return "\n".join(rows)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def action_to_coords(action: int, board_size: int) -> tuple[int, int]:
|
| 87 |
+
return divmod(int(action), board_size)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def coords_to_action(row: int, col: int, board_size: int) -> int:
|
| 91 |
+
return row * board_size + col
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def count_one_side(
|
| 95 |
+
board: np.ndarray,
|
| 96 |
+
row: int,
|
| 97 |
+
col: int,
|
| 98 |
+
dr: int,
|
| 99 |
+
dc: int,
|
| 100 |
+
player: int,
|
| 101 |
+
) -> int:
|
| 102 |
+
board_size = board.shape[0]
|
| 103 |
+
total = 0
|
| 104 |
+
r, c = row + dr, col + dc
|
| 105 |
+
while 0 <= r < board_size and 0 <= c < board_size:
|
| 106 |
+
if board[r, c] != player:
|
| 107 |
+
break
|
| 108 |
+
total += 1
|
| 109 |
+
r += dr
|
| 110 |
+
c += dc
|
| 111 |
+
return total
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def is_winning_move(
|
| 115 |
+
board: np.ndarray,
|
| 116 |
+
row: int,
|
| 117 |
+
col: int,
|
| 118 |
+
player: int,
|
| 119 |
+
win_length: int,
|
| 120 |
+
) -> bool:
|
| 121 |
+
directions = ((1, 0), (0, 1), (1, 1), (1, -1))
|
| 122 |
+
for dr, dc in directions:
|
| 123 |
+
count = 1
|
| 124 |
+
count += count_one_side(board, row, col, dr, dc, player)
|
| 125 |
+
count += count_one_side(board, row, col, -dr, -dc, player)
|
| 126 |
+
if count >= win_length:
|
| 127 |
+
return True
|
| 128 |
+
return False
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def apply_action_to_board(
|
| 132 |
+
board: np.ndarray,
|
| 133 |
+
current_player: int,
|
| 134 |
+
action: int,
|
| 135 |
+
win_length: int,
|
| 136 |
+
) -> tuple[np.ndarray, int, bool, int]:
|
| 137 |
+
board_size = board.shape[0]
|
| 138 |
+
row, col = action_to_coords(action, board_size)
|
| 139 |
+
if board[row, col] != 0:
|
| 140 |
+
raise ValueError(f"illegal move at ({row}, {col})")
|
| 141 |
+
|
| 142 |
+
next_board = board.copy()
|
| 143 |
+
next_board[row, col] = current_player
|
| 144 |
+
|
| 145 |
+
if is_winning_move(next_board, row, col, current_player, win_length):
|
| 146 |
+
return next_board, -current_player, True, current_player
|
| 147 |
+
if not np.any(next_board == 0):
|
| 148 |
+
return next_board, -current_player, True, 0
|
| 149 |
+
return next_board, -current_player, False, 0
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def encode_state(board: np.ndarray, current_player: int) -> torch.Tensor:
|
| 153 |
+
current = (board == current_player).astype(np.float32)
|
| 154 |
+
opponent = (board == -current_player).astype(np.float32)
|
| 155 |
+
legal = (board == 0).astype(np.float32)
|
| 156 |
+
return torch.from_numpy(np.stack([current, opponent, legal], axis=0))
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class PolicyValueNet(nn.Module):
|
| 160 |
+
def __init__(self, channels: int = 64):
|
| 161 |
+
super().__init__()
|
| 162 |
+
self.trunk = nn.Sequential(
|
| 163 |
+
nn.Conv2d(3, channels, kernel_size=3, padding=1),
|
| 164 |
+
nn.ReLU(),
|
| 165 |
+
nn.Conv2d(channels, channels, kernel_size=3, padding=1),
|
| 166 |
+
nn.ReLU(),
|
| 167 |
+
nn.Conv2d(channels, channels, kernel_size=3, padding=1),
|
| 168 |
+
nn.ReLU(),
|
| 169 |
+
)
|
| 170 |
+
self.policy_head = nn.Conv2d(channels, 1, kernel_size=1)
|
| 171 |
+
self.value_head = nn.Sequential(
|
| 172 |
+
nn.AdaptiveAvgPool2d(1),
|
| 173 |
+
nn.Flatten(),
|
| 174 |
+
nn.Linear(channels, channels),
|
| 175 |
+
nn.ReLU(),
|
| 176 |
+
nn.Linear(channels, 1),
|
| 177 |
+
nn.Tanh(),
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 181 |
+
features = self.trunk(x)
|
| 182 |
+
policy_logits = self.policy_head(features).flatten(start_dim=1)
|
| 183 |
+
value = self.value_head(features).squeeze(-1)
|
| 184 |
+
return policy_logits, value
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def masked_logits(logits: torch.Tensor, board: np.ndarray) -> torch.Tensor:
|
| 188 |
+
legal = torch.as_tensor((board == 0).reshape(-1), device=logits.device, dtype=torch.bool)
|
| 189 |
+
return logits.masked_fill(~legal, -1e9)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def evaluate_policy_value(
|
| 193 |
+
policy: PolicyValueNet,
|
| 194 |
+
board: np.ndarray,
|
| 195 |
+
current_player: int,
|
| 196 |
+
device: torch.device,
|
| 197 |
+
) -> tuple[np.ndarray, float]:
|
| 198 |
+
state = encode_state(board, current_player).unsqueeze(0).to(device)
|
| 199 |
+
with torch.no_grad():
|
| 200 |
+
logits, value = policy(state)
|
| 201 |
+
logits = masked_logits(logits.squeeze(0), board)
|
| 202 |
+
probs = torch.softmax(logits, dim=0).cpu().numpy()
|
| 203 |
+
return probs, float(value.item())
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
@dataclass
|
| 207 |
+
class MCTSNode:
|
| 208 |
+
board: np.ndarray
|
| 209 |
+
current_player: int
|
| 210 |
+
win_length: int
|
| 211 |
+
done: bool = False
|
| 212 |
+
winner: int = 0
|
| 213 |
+
priors: dict[int, float] = field(default_factory=dict)
|
| 214 |
+
visit_counts: dict[int, int] = field(default_factory=dict)
|
| 215 |
+
value_sums: dict[int, float] = field(default_factory=dict)
|
| 216 |
+
children: dict[int, "MCTSNode"] = field(default_factory=dict)
|
| 217 |
+
expanded: bool = False
|
| 218 |
+
|
| 219 |
+
def expand(
|
| 220 |
+
self,
|
| 221 |
+
priors: np.ndarray,
|
| 222 |
+
add_noise: bool = False,
|
| 223 |
+
dirichlet_alpha: float = 0.3,
|
| 224 |
+
noise_eps: float = 0.25,
|
| 225 |
+
) -> None:
|
| 226 |
+
legal_actions = np.flatnonzero((self.board == 0).reshape(-1))
|
| 227 |
+
legal_priors = priors[legal_actions]
|
| 228 |
+
total_prob = float(np.sum(legal_priors))
|
| 229 |
+
if total_prob <= 0.0:
|
| 230 |
+
legal_priors = np.full(len(legal_actions), 1.0 / max(len(legal_actions), 1), dtype=np.float32)
|
| 231 |
+
else:
|
| 232 |
+
legal_priors = legal_priors / total_prob
|
| 233 |
+
|
| 234 |
+
if add_noise and len(legal_actions) > 0:
|
| 235 |
+
noise = np.random.dirichlet([dirichlet_alpha] * len(legal_actions))
|
| 236 |
+
legal_priors = (1.0 - noise_eps) * legal_priors + noise_eps * noise
|
| 237 |
+
|
| 238 |
+
self.priors = {
|
| 239 |
+
int(action): float(prior)
|
| 240 |
+
for action, prior in zip(legal_actions, legal_priors, strict=False)
|
| 241 |
+
}
|
| 242 |
+
self.visit_counts = {action: 0 for action in self.priors}
|
| 243 |
+
self.value_sums = {action: 0.0 for action in self.priors}
|
| 244 |
+
self.expanded = True
|
| 245 |
+
|
| 246 |
+
def q_value(self, action: int) -> float:
|
| 247 |
+
visits = self.visit_counts[action]
|
| 248 |
+
if visits == 0:
|
| 249 |
+
return 0.0
|
| 250 |
+
return self.value_sums[action] / visits
|
| 251 |
+
|
| 252 |
+
def select_action(self, c_puct: float) -> int:
|
| 253 |
+
total_visits = sum(self.visit_counts.values())
|
| 254 |
+
sqrt_total = math.sqrt(total_visits + 1.0)
|
| 255 |
+
best_action = -1
|
| 256 |
+
best_score = -float("inf")
|
| 257 |
+
for action, prior in self.priors.items():
|
| 258 |
+
q = self.q_value(action)
|
| 259 |
+
u = c_puct * prior * sqrt_total / (1.0 + self.visit_counts[action])
|
| 260 |
+
score = q + u
|
| 261 |
+
if score > best_score:
|
| 262 |
+
best_score = score
|
| 263 |
+
best_action = action
|
| 264 |
+
return best_action
|
| 265 |
+
|
| 266 |
+
def child_for_action(self, action: int) -> "MCTSNode":
|
| 267 |
+
child = self.children.get(action)
|
| 268 |
+
if child is not None:
|
| 269 |
+
return child
|
| 270 |
+
next_board, next_player, done, winner = apply_action_to_board(
|
| 271 |
+
board=self.board,
|
| 272 |
+
current_player=self.current_player,
|
| 273 |
+
action=action,
|
| 274 |
+
win_length=self.win_length,
|
| 275 |
+
)
|
| 276 |
+
child = MCTSNode(
|
| 277 |
+
board=next_board,
|
| 278 |
+
current_player=next_player,
|
| 279 |
+
win_length=self.win_length,
|
| 280 |
+
done=done,
|
| 281 |
+
winner=winner,
|
| 282 |
+
)
|
| 283 |
+
self.children[action] = child
|
| 284 |
+
return child
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def terminal_value(winner: int, current_player: int) -> float:
|
| 288 |
+
if winner == 0:
|
| 289 |
+
return 0.0
|
| 290 |
+
return 1.0 if winner == current_player else -1.0
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def sample_from_visits(visits: np.ndarray, temperature: float) -> tuple[int, np.ndarray]:
|
| 294 |
+
flat = visits.reshape(-1).astype(np.float64)
|
| 295 |
+
if np.all(flat == 0):
|
| 296 |
+
flat = np.ones_like(flat)
|
| 297 |
+
|
| 298 |
+
if temperature <= 1e-6:
|
| 299 |
+
probs = np.zeros_like(flat, dtype=np.float64)
|
| 300 |
+
probs[int(np.argmax(flat))] = 1.0
|
| 301 |
+
else:
|
| 302 |
+
adjusted = np.power(flat, 1.0 / temperature)
|
| 303 |
+
probs = adjusted / np.sum(adjusted)
|
| 304 |
+
|
| 305 |
+
action = int(np.random.choice(len(probs), p=probs))
|
| 306 |
+
return action, probs.reshape(visits.shape).astype(np.float32)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def choose_mcts_action(
|
| 310 |
+
policy: PolicyValueNet,
|
| 311 |
+
board: np.ndarray,
|
| 312 |
+
current_player: int,
|
| 313 |
+
win_length: int,
|
| 314 |
+
device: torch.device,
|
| 315 |
+
num_simulations: int,
|
| 316 |
+
c_puct: float,
|
| 317 |
+
temperature: float,
|
| 318 |
+
add_root_noise: bool,
|
| 319 |
+
dirichlet_alpha: float,
|
| 320 |
+
noise_eps: float,
|
| 321 |
+
) -> tuple[int, np.ndarray]:
|
| 322 |
+
root = MCTSNode(board=board.copy(), current_player=current_player, win_length=win_length)
|
| 323 |
+
priors, _ = evaluate_policy_value(policy, root.board, root.current_player, device)
|
| 324 |
+
root.expand(
|
| 325 |
+
priors,
|
| 326 |
+
add_noise=add_root_noise,
|
| 327 |
+
dirichlet_alpha=dirichlet_alpha,
|
| 328 |
+
noise_eps=noise_eps,
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
for _ in range(num_simulations):
|
| 332 |
+
node = root
|
| 333 |
+
path: list[tuple[MCTSNode, int]] = []
|
| 334 |
+
|
| 335 |
+
while node.expanded and not node.done:
|
| 336 |
+
action = node.select_action(c_puct)
|
| 337 |
+
path.append((node, action))
|
| 338 |
+
node = node.child_for_action(action)
|
| 339 |
+
|
| 340 |
+
if node.done:
|
| 341 |
+
value = terminal_value(node.winner, node.current_player)
|
| 342 |
+
else:
|
| 343 |
+
priors, value = evaluate_policy_value(policy, node.board, node.current_player, device)
|
| 344 |
+
node.expand(priors)
|
| 345 |
+
|
| 346 |
+
for parent, action in reversed(path):
|
| 347 |
+
value = -value
|
| 348 |
+
parent.visit_counts[action] += 1
|
| 349 |
+
parent.value_sums[action] += value
|
| 350 |
+
|
| 351 |
+
visits = np.zeros(board.shape, dtype=np.float32)
|
| 352 |
+
for action, count in root.visit_counts.items():
|
| 353 |
+
row, col = action_to_coords(action, board.shape[0])
|
| 354 |
+
visits[row, col] = float(count)
|
| 355 |
+
|
| 356 |
+
action, visit_probs = sample_from_visits(visits, temperature=temperature)
|
| 357 |
+
return action, visit_probs
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def choose_ai_action(
|
| 361 |
+
policy: PolicyValueNet,
|
| 362 |
+
board: np.ndarray,
|
| 363 |
+
current_player: int,
|
| 364 |
+
win_length: int,
|
| 365 |
+
device: torch.device,
|
| 366 |
+
agent: str,
|
| 367 |
+
mcts_sims: int,
|
| 368 |
+
c_puct: float,
|
| 369 |
+
) -> tuple[int, np.ndarray | None]:
|
| 370 |
+
if agent == "policy":
|
| 371 |
+
priors, _ = evaluate_policy_value(policy, board, current_player, device)
|
| 372 |
+
return int(np.argmax(priors)), None
|
| 373 |
+
return choose_mcts_action(
|
| 374 |
+
policy=policy,
|
| 375 |
+
board=board,
|
| 376 |
+
current_player=current_player,
|
| 377 |
+
win_length=win_length,
|
| 378 |
+
device=device,
|
| 379 |
+
num_simulations=mcts_sims,
|
| 380 |
+
c_puct=c_puct,
|
| 381 |
+
temperature=1e-6,
|
| 382 |
+
add_root_noise=False,
|
| 383 |
+
dirichlet_alpha=0.3,
|
| 384 |
+
noise_eps=0.25,
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def self_play_game(
|
| 389 |
+
policy: PolicyValueNet,
|
| 390 |
+
board_size: int,
|
| 391 |
+
win_length: int,
|
| 392 |
+
device: torch.device,
|
| 393 |
+
mcts_sims: int,
|
| 394 |
+
c_puct: float,
|
| 395 |
+
temperature: float,
|
| 396 |
+
temperature_drop_moves: int,
|
| 397 |
+
dirichlet_alpha: float,
|
| 398 |
+
noise_eps: float,
|
| 399 |
+
) -> tuple[list[tuple[torch.Tensor, np.ndarray, float]], int, int]:
|
| 400 |
+
env = GomokuEnv(board_size=board_size, win_length=win_length)
|
| 401 |
+
env.reset()
|
| 402 |
+
history: list[tuple[torch.Tensor, np.ndarray, int]] = []
|
| 403 |
+
|
| 404 |
+
move_idx = 0
|
| 405 |
+
while not env.done:
|
| 406 |
+
move_temp = temperature if move_idx < temperature_drop_moves else 1e-6
|
| 407 |
+
action, visit_probs = choose_mcts_action(
|
| 408 |
+
policy=policy,
|
| 409 |
+
board=env.board,
|
| 410 |
+
current_player=env.current_player,
|
| 411 |
+
win_length=win_length,
|
| 412 |
+
device=device,
|
| 413 |
+
num_simulations=mcts_sims,
|
| 414 |
+
c_puct=c_puct,
|
| 415 |
+
temperature=move_temp,
|
| 416 |
+
add_root_noise=True,
|
| 417 |
+
dirichlet_alpha=dirichlet_alpha,
|
| 418 |
+
noise_eps=noise_eps,
|
| 419 |
+
)
|
| 420 |
+
history.append((encode_state(env.board, env.current_player), visit_probs.reshape(-1), env.current_player))
|
| 421 |
+
env.step(action)
|
| 422 |
+
move_idx += 1
|
| 423 |
+
|
| 424 |
+
examples: list[tuple[torch.Tensor, np.ndarray, float]] = []
|
| 425 |
+
for state, visit_probs, player in history:
|
| 426 |
+
if env.winner == 0:
|
| 427 |
+
outcome = 0.0
|
| 428 |
+
else:
|
| 429 |
+
outcome = 1.0 if player == env.winner else -1.0
|
| 430 |
+
examples.append((state, visit_probs, outcome))
|
| 431 |
+
return examples, env.winner, move_idx
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
def train_batch(
|
| 435 |
+
policy: PolicyValueNet,
|
| 436 |
+
optimizer: torch.optim.Optimizer,
|
| 437 |
+
batch: list[tuple[torch.Tensor, np.ndarray, float]],
|
| 438 |
+
device: torch.device,
|
| 439 |
+
value_coef: float,
|
| 440 |
+
) -> tuple[float, float, float]:
|
| 441 |
+
states = torch.stack([item[0] for item in batch]).to(device)
|
| 442 |
+
target_policies = torch.tensor(
|
| 443 |
+
np.stack([item[1] for item in batch]),
|
| 444 |
+
dtype=torch.float32,
|
| 445 |
+
device=device,
|
| 446 |
+
)
|
| 447 |
+
target_values = torch.tensor([item[2] for item in batch], dtype=torch.float32, device=device)
|
| 448 |
+
|
| 449 |
+
logits, values = policy(states)
|
| 450 |
+
log_probs = torch.log_softmax(logits, dim=1)
|
| 451 |
+
policy_loss = -(target_policies * log_probs).sum(dim=1).mean()
|
| 452 |
+
value_loss = torch.mean((values - target_values) ** 2)
|
| 453 |
+
loss = policy_loss + value_coef * value_loss
|
| 454 |
+
|
| 455 |
+
optimizer.zero_grad(set_to_none=True)
|
| 456 |
+
loss.backward()
|
| 457 |
+
nn.utils.clip_grad_norm_(policy.parameters(), 1.0)
|
| 458 |
+
optimizer.step()
|
| 459 |
+
return float(loss.item()), float(policy_loss.item()), float(value_loss.item())
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
def save_checkpoint(path: Path, policy: PolicyValueNet, args: argparse.Namespace) -> None:
|
| 463 |
+
torch.save(
|
| 464 |
+
{
|
| 465 |
+
"state_dict": policy.state_dict(),
|
| 466 |
+
"channels": args.channels,
|
| 467 |
+
"board_size": args.board_size,
|
| 468 |
+
"win_length": args.win_length,
|
```python
        },
        path,
    )


def last_checkpoint_path(base_path: Path) -> Path:
    return base_path.with_name(f"{base_path.stem}_last{base_path.suffix}")


def load_checkpoint(path: Path, map_location: torch.device) -> dict:
    checkpoint = torch.load(path, map_location=map_location)
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        return checkpoint
    raise RuntimeError(f"{path} is not a compatible gomoku_mcts checkpoint")


def resolve_game_config(
    checkpoint_path: Path,
    arg_board_size: int | None,
    arg_win_length: int | None,
    arg_channels: int,
    device: torch.device,
) -> tuple[PolicyValueNet, int, int]:
    checkpoint = load_checkpoint(checkpoint_path, map_location=device)
    board_size = int(checkpoint.get("board_size") or arg_board_size or 15)
    win_length = int(checkpoint.get("win_length") or arg_win_length or 5)
    channels = int(checkpoint.get("channels") or arg_channels)

    policy = PolicyValueNet(channels=channels).to(device)
    policy.load_state_dict(checkpoint["state_dict"])
    policy.eval()
    return policy, board_size, win_length


def play_vs_random_once(
    policy: PolicyValueNet,
    board_size: int,
    win_length: int,
    device: torch.device,
    policy_player: int,
    agent: str,
    mcts_sims: int,
    c_puct: float,
) -> int:
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    env.reset()
    while not env.done:
        if env.current_player == policy_player:
            action, _ = choose_ai_action(
                policy=policy,
                board=env.board,
                current_player=env.current_player,
                win_length=win_length,
                device=device,
                agent=agent,
                mcts_sims=mcts_sims,
                c_puct=c_puct,
            )
        else:
            action = int(np.random.choice(env.valid_moves()))
        env.step(action)
    return env.winner


def evaluate_vs_random(
    policy: PolicyValueNet,
    board_size: int,
    win_length: int,
    device: torch.device,
    games: int,
    agent: str,
    mcts_sims: int,
    c_puct: float,
) -> tuple[float, int, int, int]:
    wins = 0
    draws = 0
    losses = 0
    for game_idx in range(games):
        policy_player = 1 if game_idx < games // 2 else -1
        winner = play_vs_random_once(
            policy=policy,
            board_size=board_size,
            win_length=win_length,
            device=device,
            policy_player=policy_player,
            agent=agent,
            mcts_sims=mcts_sims,
            c_puct=c_puct,
        )
        if winner == 0:
            draws += 1
        elif winner == policy_player:
            wins += 1
        else:
            losses += 1
    return wins / max(games, 1), wins, draws, losses


def train(args: argparse.Namespace) -> None:
    set_seed(args.seed)
    device = choose_device(args.device)
    policy = PolicyValueNet(channels=args.channels).to(device)
    if args.init_checkpoint is not None and args.init_checkpoint.exists():
        checkpoint = load_checkpoint(args.init_checkpoint, map_location=device)
        policy.load_state_dict(checkpoint["state_dict"])
    optimizer = torch.optim.Adam(policy.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    replay_buffer: deque[tuple[torch.Tensor, np.ndarray, float]] = deque(maxlen=args.buffer_size)

    print(f"device={device} board={args.board_size} win={args.win_length}")

    for iteration in range(1, args.iterations + 1):
        policy.eval()
        winners: list[int] = []
        lengths: list[int] = []
        for _ in range(args.games_per_iter):
            examples, winner, moves = self_play_game(
                policy=policy,
                board_size=args.board_size,
                win_length=args.win_length,
                device=device,
                mcts_sims=args.mcts_sims,
                c_puct=args.c_puct,
                temperature=args.temperature,
                temperature_drop_moves=args.temperature_drop_moves,
                dirichlet_alpha=args.dirichlet_alpha,
                noise_eps=args.noise_eps,
            )
            replay_buffer.extend(examples)
            winners.append(winner)
            lengths.append(moves)

        losses: list[tuple[float, float, float]] = []
        if len(replay_buffer) >= args.batch_size:
            policy.train()
            for _ in range(args.train_steps):
                batch = random.sample(replay_buffer, args.batch_size)
                losses.append(
                    train_batch(
                        policy=policy,
                        optimizer=optimizer,
                        batch=batch,
                        device=device,
                        value_coef=args.value_coef,
                    )
                )

        avg_loss = float(np.mean([x[0] for x in losses])) if losses else 0.0
        avg_policy_loss = float(np.mean([x[1] for x in losses])) if losses else 0.0
        avg_value_loss = float(np.mean([x[2] for x in losses])) if losses else 0.0
        p1_wins = sum(1 for x in winners if x == 1)
        p2_wins = sum(1 for x in winners if x == -1)
        draws = sum(1 for x in winners if x == 0)
        avg_len = float(np.mean(lengths)) if lengths else 0.0

        message = (
            f"iter={iteration:5d} loss={avg_loss:7.4f} policy={avg_policy_loss:7.4f} "
            f"value={avg_value_loss:7.4f} p1={p1_wins:3d} p2={p2_wins:3d} draw={draws:3d} "
            f"avg_len={avg_len:6.2f} buffer={len(replay_buffer):6d}"
        )
        if args.eval_every > 0 and iteration % args.eval_every == 0:
            policy.eval()
            win_rate, wins, eval_draws, eval_losses = evaluate_vs_random(
                policy=policy,
                board_size=args.board_size,
                win_length=args.win_length,
                device=device,
                games=args.eval_games,
                agent="mcts",
                mcts_sims=args.eval_mcts_sims,
                c_puct=args.c_puct,
            )
            message += f" random_win_rate={win_rate:.3f} ({wins}/{eval_draws}/{eval_losses})"
        print(message)

        if args.save_every > 0 and iteration % args.save_every == 0:
            checkpoint_path = last_checkpoint_path(args.checkpoint)
            save_checkpoint(checkpoint_path, policy, args)
            print(f"saved checkpoint to {checkpoint_path}")

    save_checkpoint(args.checkpoint, policy, args)
    print(f"saved checkpoint to {args.checkpoint}")


def evaluate(args: argparse.Namespace) -> None:
    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    win_rate, wins, draws, losses = evaluate_vs_random(
        policy=policy,
        board_size=board_size,
        win_length=win_length,
        device=device,
        games=args.games,
        agent=args.agent,
        mcts_sims=args.mcts_sims,
        c_puct=args.c_puct,
    )
    print(f"device={device}")
    print(f"agent={args.agent} mcts_sims={args.mcts_sims}")
    print(f"win_rate={win_rate:.3f} wins={wins} draws={draws} losses={losses}")


def ask_human_move(env: GomokuEnv) -> int:
    while True:
        text = input("your move (row col): ").strip()
        parts = text.replace(",", " ").split()
        if len(parts) != 2:
            print("please enter: row col")
            continue
        try:
            row, col = (int(parts[0]) - 1, int(parts[1]) - 1)
        except ValueError:
            print("row and col must be integers")
            continue
        if not (0 <= row < env.board_size and 0 <= col < env.board_size):
            print("move out of range")
            continue
        if env.board[row, col] != 0:
            print("that position is occupied")
            continue
        return coords_to_action(row, col, env.board_size)


def play(args: argparse.Namespace) -> None:
    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    human_player = 1 if args.human_first else -1

    print(f"device={device}")
    print(
        f"human={'X' if human_player == 1 else 'O'} ai={'O' if human_player == 1 else 'X'} "
        f"agent={args.agent} mcts_sims={args.mcts_sims}"
    )

    while not env.done:
        print()
        print(env.render())
        print()
        if env.current_player == human_player:
            action = ask_human_move(env)
        else:
            action, _ = choose_ai_action(
                policy=policy,
                board=env.board,
                current_player=env.current_player,
                win_length=win_length,
                device=device,
                agent=args.agent,
                mcts_sims=args.mcts_sims,
                c_puct=args.c_puct,
            )
            row, col = action_to_coords(action, env.board_size)
            print(f"ai move: {row + 1} {col + 1}")
        env.step(action)

    print()
    print(env.render())
    if env.winner == 0:
        print("draw")
    elif env.winner == human_player:
        print("you win")
    else:
        print("ai wins")


def gui(args: argparse.Namespace) -> None:
    try:
        import pygame
    except ModuleNotFoundError as exc:
        raise SystemExit(
            "pygame is not installed. Install it with: "
            "~/miniconda3/bin/conda run -n lerobot python -m pip install pygame"
        ) from exc

    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    human_player = 1 if args.human_first else -1
    last_search_visits: np.ndarray | None = None

    pygame.init()
    pygame.display.set_caption("Gomoku MCTS")
    font = pygame.font.SysFont("Arial", 24)
    small_font = pygame.font.SysFont("Arial", 18)

    cell_size = args.cell_size
    padding = 40
    status_height = 80
    board_pixels = board_size * cell_size
    screen = pygame.display.set_mode(
        (board_pixels + padding * 2, board_pixels + padding * 2 + status_height)
    )
    clock = pygame.time.Clock()

    background = (236, 196, 122)
    line_color = (80, 55, 20)
    black_stone = (20, 20, 20)
    white_stone = (245, 245, 245)
    accent = (180, 40, 40)

    def board_to_screen(row: int, col: int) -> tuple[int, int]:
        x = padding + col * cell_size + cell_size // 2
        y = padding + row * cell_size + cell_size // 2
        return x, y

    def mouse_to_action(pos: tuple[int, int]) -> int | None:
        x, y = pos
        if x < padding or y < padding:
            return None
        col = (x - padding) // cell_size
        row = (y - padding) // cell_size
        if not (0 <= row < env.board_size and 0 <= col < env.board_size):
            return None
        if env.board[row, col] != 0:
            return None
        return coords_to_action(row, col, env.board_size)

    def ai_step() -> None:
        nonlocal last_search_visits
        if env.done or env.current_player == human_player:
            return
        action, visits = choose_ai_action(
            policy=policy,
            board=env.board,
            current_player=env.current_player,
            win_length=win_length,
            device=device,
            agent=args.agent,
            mcts_sims=args.mcts_sims,
            c_puct=args.c_puct,
        )
        last_search_visits = visits
        env.step(action)

    def restart() -> None:
        nonlocal last_search_visits
        env.reset()
        last_search_visits = None
        if env.current_player != human_player:
            ai_step()

    def status_text() -> str:
        if env.done:
            if env.winner == 0:
                return "Draw. Press R to restart."
            if env.winner == human_player:
                return "You win. Press R to restart."
            return "AI wins. Press R to restart."
        if env.current_player == human_player:
            return "Your turn. Left click to place."
        return "AI is thinking..."

    if env.current_player != human_player:
        ai_step()

    running = True
    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_ESCAPE:
                    running = False
                elif event.key == pygame.K_r:
                    restart()
            elif event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
                if env.done or env.current_player != human_player:
                    continue
                action = mouse_to_action(event.pos)
                if action is None:
                    continue
                env.step(action)
                ai_step()

        screen.fill(background)
        for idx in range(board_size + 1):
            x = padding + idx * cell_size
            y = padding + idx * cell_size
            pygame.draw.line(screen, line_color, (x, padding), (x, padding + board_pixels), 2)
            pygame.draw.line(screen, line_color, (padding, y), (padding + board_pixels, y), 2)

        for row in range(env.board_size):
            for col in range(env.board_size):
                stone = int(env.board[row, col])
                if stone == 0:
                    continue
                x, y = board_to_screen(row, col)
                color = black_stone if stone == 1 else white_stone
                pygame.draw.circle(screen, color, (x, y), cell_size // 2 - 4)
                pygame.draw.circle(screen, line_color, (x, y), cell_size // 2 - 4, 1)

        for idx in range(board_size):
            label = small_font.render(str(idx + 1), True, line_color)
            screen.blit(label, (padding + idx * cell_size + cell_size // 2 - label.get_width() // 2, 8))
            screen.blit(label, (8, padding + idx * cell_size + cell_size // 2 - label.get_height() // 2))

        info = (
            f"{board_size}x{board_size} connect={win_length} device={device} "
            f"agent={args.agent} sims={args.mcts_sims}"
        )
        screen.blit(small_font.render(info, True, line_color), (padding, padding + board_pixels + 16))
        screen.blit(font.render(status_text(), True, accent), (padding, padding + board_pixels + 42))

        if last_search_visits is not None and args.agent == "mcts":
            peak = int(np.max(last_search_visits))
            screen.blit(
                small_font.render(f"peak_visits={peak}", True, line_color),
                (padding + 420, padding + board_pixels + 16),
            )

        pygame.display.flip()
        clock.tick(args.fps)

    pygame.quit()


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Minimal Gomoku MCTS example")
    subparsers = parser.add_subparsers(dest="mode", required=True)

    def add_common_arguments(subparser: argparse.ArgumentParser, defaults_from_checkpoint: bool = False) -> None:
        board_default = None if defaults_from_checkpoint else 15
        win_default = None if defaults_from_checkpoint else 5
        subparser.add_argument("--board-size", type=int, default=board_default)
        subparser.add_argument("--win-length", type=int, default=win_default)
        subparser.add_argument("--channels", type=int, default=64)
        subparser.add_argument("--device", choices=["auto", "cpu", "cuda", "mps"], default="auto")
        subparser.add_argument("--checkpoint", type=Path, default=Path("gomoku_mcts.pt"))

    def add_inference_arguments(subparser: argparse.ArgumentParser) -> None:
        subparser.add_argument("--agent", choices=["policy", "mcts"], default="mcts")
        subparser.add_argument("--mcts-sims", type=int, default=120)
        subparser.add_argument("--c-puct", type=float, default=1.5)

    train_parser = subparsers.add_parser("train", help="MCTS self-play training")
    add_common_arguments(train_parser)
    train_parser.add_argument("--iterations", type=int, default=200)
    train_parser.add_argument("--games-per-iter", type=int, default=8)
    train_parser.add_argument("--train-steps", type=int, default=32)
    train_parser.add_argument("--batch-size", type=int, default=64)
    train_parser.add_argument("--buffer-size", type=int, default=20000)
    train_parser.add_argument("--lr", type=float, default=1e-3)
    train_parser.add_argument("--weight-decay", type=float, default=1e-4)
    train_parser.add_argument("--value-coef", type=float, default=1.0)
    train_parser.add_argument("--mcts-sims", type=int, default=64)
    train_parser.add_argument("--eval-mcts-sims", type=int, default=120)
    train_parser.add_argument("--c-puct", type=float, default=1.5)
    train_parser.add_argument("--temperature", type=float, default=1.0)
    train_parser.add_argument("--temperature-drop-moves", type=int, default=8)
    train_parser.add_argument("--dirichlet-alpha", type=float, default=0.3)
    train_parser.add_argument("--noise-eps", type=float, default=0.25)
    train_parser.add_argument("--eval-every", type=int, default=10)
    train_parser.add_argument("--eval-games", type=int, default=20)
    train_parser.add_argument("--save-every", type=int, default=10)
    train_parser.add_argument("--seed", type=int, default=42)
    train_parser.add_argument("--init-checkpoint", type=Path, default=None)
    train_parser.set_defaults(func=train)

    eval_parser = subparsers.add_parser("eval", help="evaluate against random agent")
    add_common_arguments(eval_parser)
    add_inference_arguments(eval_parser)
    eval_parser.add_argument("--games", type=int, default=40)
    eval_parser.set_defaults(func=evaluate)

    play_parser = subparsers.add_parser("play", help="play against the model")
    add_common_arguments(play_parser, defaults_from_checkpoint=True)
    add_inference_arguments(play_parser)
    play_parser.add_argument("--human-first", action="store_true")
    play_parser.set_defaults(func=play)

    gui_parser = subparsers.add_parser("gui", help="pygame GUI")
    add_common_arguments(gui_parser, defaults_from_checkpoint=True)
    add_inference_arguments(gui_parser)
    gui_parser.add_argument("--human-first", action="store_true")
    gui_parser.add_argument("--cell-size", type=int, default=48)
    gui_parser.add_argument("--fps", type=int, default=30)
    gui_parser.set_defaults(func=gui)

    return parser


def main() -> None:
    parser = build_parser()
    args = parser.parse_args()
    args.func(args)


if __name__ == "__main__":
    main()
```
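The checkpoint-naming convention used by `--save-every` is easy to sanity-check in isolation. Below is a standalone copy of `last_checkpoint_path` (it needs only `pathlib`), showing that `_last` is inserted between the file stem and its extension, so the periodic snapshot and the final checkpoint live side by side:

```python
from pathlib import Path


# Standalone copy of the helper from gomoku_mcts.py: the "_last" marker
# goes before the extension so the snapshot never overwrites the final file.
def last_checkpoint_path(base_path: Path) -> Path:
    return base_path.with_name(f"{base_path.stem}_last{base_path.suffix}")


print(last_checkpoint_path(Path("gomoku_mcts.pt")))   # gomoku_mcts_last.pt
print(last_checkpoint_path(Path("runs/15x15_5.pt")).name)  # 15x15_5_last.pt
```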
gomoku_pg.py
ADDED
@@ -0,0 +1,1105 @@
```python
#!/usr/bin/env python3
"""Minimal Gomoku policy gradient example.

Features:
1. Configurable board size and win length, e.g. 5x5 connect-4 or 15x15 connect-5.
2. Shared-policy self-play with REINFORCE.
3. Fully convolutional policy, so the same code works for different board sizes.
4. Optional random-agent evaluation and CLI human play.
"""

from __future__ import annotations

import argparse
import math
import random
from collections import deque
from dataclasses import dataclass, field
from pathlib import Path

import numpy as np
import torch
from torch import nn
from torch.distributions import Categorical


def choose_device(name: str) -> torch.device:
    if name != "auto":
        return torch.device(name)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def set_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


class GomokuEnv:
    def __init__(self, board_size: int, win_length: int):
        if board_size <= 1:
            raise ValueError("board_size must be > 1")
        if not 1 < win_length <= board_size:
            raise ValueError("win_length must satisfy 1 < win_length <= board_size")
        self.board_size = board_size
        self.win_length = win_length
        self.reset()

    def reset(self) -> np.ndarray:
        self.board = np.zeros((self.board_size, self.board_size), dtype=np.int8)
        self.current_player = 1
        self.done = False
        self.winner = 0
        return self.board

    def legal_mask(self) -> np.ndarray:
        return self.board == 0

    def valid_moves(self) -> np.ndarray:
        return np.flatnonzero(self.legal_mask().reshape(-1))

    def step(self, action: int) -> tuple[bool, int]:
        if self.done:
            raise RuntimeError("game is already finished")

        row, col = divmod(int(action), self.board_size)
        if self.board[row, col] != 0:
            raise ValueError(f"illegal move at ({row}, {col})")

        player = self.current_player
        self.board[row, col] = player

        if self._is_winning_move(row, col, player):
            self.done = True
            self.winner = player
        elif not np.any(self.board == 0):
            self.done = True
            self.winner = 0
        else:
            self.current_player = -player

        return self.done, self.winner

    def _is_winning_move(self, row: int, col: int, player: int) -> bool:
        directions = ((1, 0), (0, 1), (1, 1), (1, -1))
        for dr, dc in directions:
            count = 1
            count += self._count_one_side(row, col, dr, dc, player)
            count += self._count_one_side(row, col, -dr, -dc, player)
            if count >= self.win_length:
                return True
        return False

    def _count_one_side(self, row: int, col: int, dr: int, dc: int, player: int) -> int:
        total = 0
        r, c = row + dr, col + dc
        while 0 <= r < self.board_size and 0 <= c < self.board_size:
            if self.board[r, c] != player:
                break
            total += 1
            r += dr
            c += dc
        return total

    def render(self) -> str:
        symbols = {1: "X", -1: "O", 0: "."}
        header = "   " + " ".join(f"{i + 1:2d}" for i in range(self.board_size))
        rows = [header]
        for row_idx in range(self.board_size):
            row = " ".join(f"{symbols[int(v)]:>2}" for v in self.board[row_idx])
            rows.append(f"{row_idx + 1:2d} {row}")
        return "\n".join(rows)


def encode_state(board: np.ndarray, current_player: int) -> torch.Tensor:
    current = (board == current_player).astype(np.float32)
    opponent = (board == -current_player).astype(np.float32)
    legal = (board == 0).astype(np.float32)
    stacked = np.stack([current, opponent, legal], axis=0)
    return torch.from_numpy(stacked)
```
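To make the three-plane state encoding concrete, here is a small numpy-only sketch of the stacking step (`encode_state` above additionally wraps the result with `torch.from_numpy`; `encode_planes` is just an illustrative standalone name):

```python
import numpy as np


# Plane 0: stones of the side to move, plane 1: the opponent's stones,
# plane 2: empty (legal) cells -- the same layout encode_state produces.
def encode_planes(board: np.ndarray, current_player: int) -> np.ndarray:
    current = (board == current_player).astype(np.float32)
    opponent = (board == -current_player).astype(np.float32)
    legal = (board == 0).astype(np.float32)
    return np.stack([current, opponent, legal], axis=0)


board = np.zeros((5, 5), dtype=np.int8)
board[2, 2] = 1   # one X stone
board[2, 3] = -1  # one O stone
planes = encode_planes(board, current_player=1)
print(planes.shape)          # (3, 5, 5)
print(int(planes[0].sum()))  # 1  -> one own stone
print(int(planes[2].sum()))  # 23 -> 25 cells minus 2 occupied
```

Because the planes are relative to the side to move, the same network plays both colors: from O's perspective (`current_player=-1`) the own/opponent planes simply swap.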
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class PolicyValueNet(nn.Module):
    def __init__(self, channels: int = 64):
        super().__init__()
        self.trunk = nn.Sequential(
            nn.Conv2d(3, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        self.policy_head = nn.Conv2d(channels, 1, kernel_size=1)
        self.value_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(channels, channels),
            nn.ReLU(),
            nn.Linear(channels, 1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        features = self.trunk(x)
        policy_logits = self.policy_head(features).flatten(start_dim=1)
        value = self.value_head(features).squeeze(-1)
        return policy_logits, value


def masked_logits(logits: torch.Tensor, legal_mask: np.ndarray) -> torch.Tensor:
    legal = torch.as_tensor(legal_mask.reshape(-1), device=logits.device, dtype=torch.bool)
    return logits.masked_fill(~legal, -1e9)


def transform_board(board: np.ndarray, rotation_k: int, flip: bool) -> np.ndarray:
    transformed = np.rot90(board, k=rotation_k)
    if flip:
        transformed = np.fliplr(transformed)
    return np.ascontiguousarray(transformed)


def action_to_coords(action: int, board_size: int) -> tuple[int, int]:
    return divmod(int(action), board_size)


def coords_to_action(row: int, col: int, board_size: int) -> int:
    return row * board_size + col


def count_one_side(
    board: np.ndarray,
    row: int,
    col: int,
    dr: int,
    dc: int,
    player: int,
) -> int:
    board_size = board.shape[0]
    total = 0
    r, c = row + dr, col + dc
    while 0 <= r < board_size and 0 <= c < board_size:
        if board[r, c] != player:
            break
        total += 1
        r += dr
        c += dc
    return total


def is_winning_move(
    board: np.ndarray,
    row: int,
    col: int,
    player: int,
    win_length: int,
) -> bool:
    directions = ((1, 0), (0, 1), (1, 1), (1, -1))
    for dr, dc in directions:
        count = 1
        count += count_one_side(board, row, col, dr, dc, player)
        count += count_one_side(board, row, col, -dr, -dc, player)
        if count >= win_length:
            return True
    return False

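The two helpers above are pure functions over the board array, so they are easy to sanity-check in isolation. A minimal standalone check (restated here with plain Python lists instead of the NumPy board so the snippet runs on its own; the 7x7 board and the stone positions are made up for illustration):

```python
def count_one_side(board, row, col, dr, dc, player):
    # Count consecutive stones of `player` walking away from (row, col).
    board_size = len(board)
    total = 0
    r, c = row + dr, col + dc
    while 0 <= r < board_size and 0 <= c < board_size:
        if board[r][c] != player:
            break
        total += 1
        r += dr
        c += dc
    return total


def is_winning_move(board, row, col, player, win_length):
    # A move wins if, in any of the four line directions, the stone just
    # placed plus both sides reach win_length.
    for dr, dc in ((1, 0), (0, 1), (1, 1), (1, -1)):
        count = 1 + count_one_side(board, row, col, dr, dc, player)
        count += count_one_side(board, row, col, -dr, -dc, player)
        if count >= win_length:
            return True
    return False


board = [[0] * 7 for _ in range(7)]
for c in range(1, 6):          # five stones in a row on row 3, cols 1..5
    board[3][c] = 1

print(is_winning_move(board, 3, 3, 1, 5))   # True
print(is_winning_move(board, 3, 3, 1, 6))   # False
```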
def apply_action_to_board(
    board: np.ndarray,
    current_player: int,
    action: int,
    win_length: int,
) -> tuple[np.ndarray, int, bool, int]:
    board_size = board.shape[0]
    row, col = action_to_coords(action, board_size)
    if board[row, col] != 0:
        raise ValueError(f"illegal move at ({row}, {col})")

    next_board = board.copy()
    next_board[row, col] = current_player

    if is_winning_move(next_board, row, col, current_player, win_length):
        return next_board, -current_player, True, current_player
    if not np.any(next_board == 0):
        return next_board, -current_player, True, 0
    return next_board, -current_player, False, 0


def forward_transform_coords(
    row: int,
    col: int,
    board_size: int,
    rotation_k: int,
    flip: bool,
) -> tuple[int, int]:
    for _ in range(rotation_k % 4):
        row, col = board_size - 1 - col, row
    if flip:
        col = board_size - 1 - col
    return row, col


def inverse_transform_coords(
    row: int,
    col: int,
    board_size: int,
    rotation_k: int,
    flip: bool,
) -> tuple[int, int]:
    if flip:
        col = board_size - 1 - col
    for _ in range(rotation_k % 4):
        row, col = col, board_size - 1 - row
    return row, col

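Because `sample_action` below maps the sampled move back through `inverse_transform_coords`, the two coordinate transforms must be exact inverses for every rotation/flip combination. A standalone round-trip check (the two helpers are repeated verbatim so the snippet runs on its own; the 5x5 size is an arbitrary example):

```python
def forward_transform_coords(row, col, board_size, rotation_k, flip):
    # Mirrors np.rot90 followed by np.fliplr on coordinates.
    for _ in range(rotation_k % 4):
        row, col = board_size - 1 - col, row
    if flip:
        col = board_size - 1 - col
    return row, col


def inverse_transform_coords(row, col, board_size, rotation_k, flip):
    # Undo the flip first, then undo the rotations.
    if flip:
        col = board_size - 1 - col
    for _ in range(rotation_k % 4):
        row, col = col, board_size - 1 - row
    return row, col


# Inverse(forward(square)) must be the identity for all 8 symmetries.
ok = all(
    inverse_transform_coords(*forward_transform_coords(r, c, 5, k, f), 5, k, f) == (r, c)
    for r in range(5)
    for c in range(5)
    for k in range(4)
    for f in (False, True)
)
print(ok)  # True
```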
def sample_action(
    policy: PolicyValueNet,
    board: np.ndarray,
    current_player: int,
    device: torch.device,
    greedy: bool,
    augment: bool,
) -> tuple[int, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
    board_size = board.shape[0]
    rotation_k = random.randint(0, 3) if augment else 0
    flip = bool(random.getrandbits(1)) if augment else False
    transformed_board = transform_board(board, rotation_k=rotation_k, flip=flip)

    state = encode_state(transformed_board, current_player).unsqueeze(0).to(device)
    logits, value = policy(state)
    logits = masked_logits(logits.squeeze(0), transformed_board == 0)

    if greedy:
        action = torch.argmax(logits)
        transformed_row, transformed_col = action_to_coords(int(action.item()), board_size)
        row, col = inverse_transform_coords(
            transformed_row,
            transformed_col,
            board_size,
            rotation_k=rotation_k,
            flip=flip,
        )
        return coords_to_action(row, col, board_size), None, None, value.squeeze(0)

    dist = Categorical(logits=logits)
    action = dist.sample()
    transformed_row, transformed_col = action_to_coords(int(action.item()), board_size)
    row, col = inverse_transform_coords(
        transformed_row,
        transformed_col,
        board_size,
        rotation_k=rotation_k,
        flip=flip,
    )
    return (
        coords_to_action(row, col, board_size),
        dist.log_prob(action),
        dist.entropy(),
        value.squeeze(0),
    )


def evaluate_policy_value(
    policy: PolicyValueNet,
    board: np.ndarray,
    current_player: int,
    device: torch.device,
) -> tuple[np.ndarray, float]:
    state = encode_state(board, current_player).unsqueeze(0).to(device)
    with torch.no_grad():
        logits, value = policy(state)
    logits = masked_logits(logits.squeeze(0), board == 0)
    probs = torch.softmax(logits, dim=0).detach().cpu().numpy()
    return probs, float(value.item())

@dataclass
class MCTSNode:
    board: np.ndarray
    current_player: int
    win_length: int
    done: bool = False
    winner: int = 0
    priors: dict[int, float] = field(default_factory=dict)
    visit_counts: dict[int, int] = field(default_factory=dict)
    value_sums: dict[int, float] = field(default_factory=dict)
    children: dict[int, "MCTSNode"] = field(default_factory=dict)
    expanded: bool = False

    def expand(self, priors: np.ndarray) -> None:
        legal_actions = np.flatnonzero(self.board.reshape(-1) == 0)
        total_prob = float(np.sum(priors[legal_actions]))
        if total_prob <= 0.0:
            uniform = 1.0 / max(len(legal_actions), 1)
            self.priors = {int(action): uniform for action in legal_actions}
        else:
            self.priors = {
                int(action): float(priors[action] / total_prob)
                for action in legal_actions
            }
        self.visit_counts = {action: 0 for action in self.priors}
        self.value_sums = {action: 0.0 for action in self.priors}
        self.expanded = True

    def q_value(self, action: int) -> float:
        visits = self.visit_counts[action]
        if visits == 0:
            return 0.0
        return self.value_sums[action] / visits

    def select_action(self, c_puct: float) -> int:
        total_visits = sum(self.visit_counts.values())
        sqrt_total = math.sqrt(total_visits + 1.0)
        best_action = -1
        best_score = -float("inf")

        for action, prior in self.priors.items():
            visits = self.visit_counts[action]
            q = self.q_value(action)
            u = c_puct * prior * sqrt_total / (1.0 + visits)
            score = q + u
            if score > best_score:
                best_score = score
                best_action = action

        return best_action

    def child_for_action(self, action: int) -> "MCTSNode":
        child = self.children.get(action)
        if child is not None:
            return child

        next_board, next_player, done, winner = apply_action_to_board(
            board=self.board,
            current_player=self.current_player,
            action=action,
            win_length=self.win_length,
        )
        child = MCTSNode(
            board=next_board,
            current_player=next_player,
            win_length=self.win_length,
            done=done,
            winner=winner,
        )
        self.children[action] = child
        return child


def terminal_value(winner: int, current_player: int) -> float:
    if winner == 0:
        return 0.0
    return 1.0 if winner == current_player else -1.0

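`select_action` above scores each child with the PUCT rule `q + c_puct * prior * sqrt(N + 1) / (1 + n)`, trading exploitation (the running mean value `q`) against exploration (the prior-weighted, visit-discounted bonus `u`). A standalone sketch of the same arithmetic, with made-up statistics for three candidate moves:

```python
import math

# Hypothetical node statistics for three candidate moves (for illustration only).
priors = {0: 0.6, 1: 0.3, 2: 0.1}
visit_counts = {0: 10, 1: 2, 2: 0}
value_sums = {0: 3.0, 1: 1.0, 2: 0.0}
c_puct = 1.5

total_visits = sum(visit_counts.values())
sqrt_total = math.sqrt(total_visits + 1.0)


def puct_score(action):
    visits = visit_counts[action]
    q = value_sums[action] / visits if visits else 0.0  # unvisited -> q = 0
    u = c_puct * priors[action] * sqrt_total / (1.0 + visits)
    return q + u


best = max(priors, key=puct_score)
print(best)  # 1: best running value, and still a sizable exploration bonus
```

Note that move 0 has the largest prior but its bonus is divided by 11 visits, so the less-explored move 1 wins the comparison here.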
def choose_mcts_action(
    policy: PolicyValueNet,
    board: np.ndarray,
    current_player: int,
    win_length: int,
    device: torch.device,
    num_simulations: int,
    c_puct: float,
) -> tuple[int, np.ndarray]:
    root = MCTSNode(
        board=board.copy(),
        current_player=current_player,
        win_length=win_length,
    )

    priors, _ = evaluate_policy_value(policy, root.board, root.current_player, device)
    root.expand(priors)

    for _ in range(num_simulations):
        node = root
        path: list[tuple[MCTSNode, int]] = []

        while node.expanded and not node.done:
            action = node.select_action(c_puct)
            path.append((node, action))
            node = node.child_for_action(action)

        if node.done:
            value = terminal_value(node.winner, node.current_player)
        else:
            priors, value = evaluate_policy_value(policy, node.board, node.current_player, device)
            node.expand(priors)

        for parent, action in reversed(path):
            value = -value
            parent.visit_counts[action] += 1
            parent.value_sums[action] += value

    visits = np.zeros(board.size, dtype=np.float32)
    for action, count in root.visit_counts.items():
        visits[action] = float(count)

    if np.all(visits == 0):
        best_action = int(np.argmax(priors))
    else:
        best_action = int(np.argmax(visits))

    return best_action, visits.reshape(board.shape)


def choose_ai_action(
    policy: PolicyValueNet,
    board: np.ndarray,
    current_player: int,
    win_length: int,
    device: torch.device,
    agent: str,
    mcts_sims: int,
    c_puct: float,
) -> tuple[int, np.ndarray | None]:
    if agent == "mcts":
        return choose_mcts_action(
            policy=policy,
            board=board,
            current_player=current_player,
            win_length=win_length,
            device=device,
            num_simulations=mcts_sims,
            c_puct=c_puct,
        )

    action, _, _, _ = sample_action(
        policy=policy,
        board=board,
        current_player=current_player,
        device=device,
        greedy=True,
        augment=False,
    )
    return action, None

def self_play_episode(
    policy: PolicyValueNet,
    env: GomokuEnv,
    device: torch.device,
    gamma: float,
    augment: bool,
) -> tuple[list[torch.Tensor], list[float], list[torch.Tensor], list[torch.Tensor], int, int]:
    env.reset()
    log_probs: list[torch.Tensor] = []
    entropies: list[torch.Tensor] = []
    values: list[torch.Tensor] = []
    players: list[int] = []

    while not env.done:
        player = env.current_player
        action, log_prob, entropy, value = sample_action(
            policy=policy,
            board=env.board,
            current_player=player,
            device=device,
            greedy=False,
            augment=augment,
        )
        log_probs.append(log_prob)
        entropies.append(entropy)
        values.append(value)
        players.append(player)
        env.step(action)

    returns: list[float] = []
    total_moves = len(players)
    for move_idx, player in enumerate(players):
        outcome = 0.0
        if env.winner != 0:
            outcome = 1.0 if player == env.winner else -1.0
        discounted = outcome * (gamma ** (total_moves - move_idx - 1))
        returns.append(discounted)

    return log_probs, returns, entropies, values, env.winner, total_moves

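The return computation at the end of `self_play_episode` gives each move the final outcome from that move's player's perspective, discounted by how far the move sits from the end of the game. A standalone sketch with a hypothetical 5-move game won by player 1 (`gamma = 0.95` is an example value, not the script's default):

```python
gamma = 0.95
players = [1, -1, 1, -1, 1]   # hypothetical move order; player 1 moves last and wins
winner = 1
total_moves = len(players)

returns = []
for move_idx, player in enumerate(players):
    # +1 for the eventual winner's moves, -1 for the loser's, scaled by
    # gamma^(distance from the final move).
    outcome = 1.0 if player == winner else -1.0
    returns.append(outcome * gamma ** (total_moves - move_idx - 1))

print([round(r, 4) for r in returns])  # [0.8145, -0.8574, 0.9025, -0.95, 1.0]
```

The winning move itself gets the full +1, and earlier moves of either side fade toward zero, so the credit assignment concentrates near the decisive end of the game.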
def update_policy(
    optimizer: torch.optim.Optimizer,
    batch_log_probs: list[torch.Tensor],
    batch_returns: list[float],
    batch_entropies: list[torch.Tensor],
    batch_values: list[torch.Tensor],
    entropy_coef: float,
    value_coef: float,
    grad_clip: float,
    device: torch.device,
) -> float:
    returns = torch.tensor(batch_returns, dtype=torch.float32, device=device)
    log_probs = torch.stack(batch_log_probs)
    entropies = torch.stack(batch_entropies)
    values = torch.stack(batch_values)
    advantages = returns - values.detach()
    if advantages.numel() > 1:
        advantages = (advantages - advantages.mean()) / (advantages.std(unbiased=False) + 1e-6)

    policy_loss = -(log_probs * advantages).mean()
    value_loss = torch.mean((values - returns) ** 2)
    entropy_bonus = entropies.mean()
    loss = policy_loss + value_coef * value_loss - entropy_coef * entropy_bonus

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    nn.utils.clip_grad_norm_(optimizer.param_groups[0]["params"], grad_clip)
    optimizer.step()
    return float(loss.item())

def save_checkpoint(
    path: Path,
    policy: PolicyValueNet,
    args: argparse.Namespace,
) -> None:
    payload = {
        "state_dict": policy.state_dict(),
        "channels": args.channels,
        "board_size": args.board_size,
        "win_length": args.win_length,
    }
    torch.save(payload, path)


def load_checkpoint(path: Path, map_location: torch.device) -> dict:
    checkpoint = torch.load(path, map_location=map_location)
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        return checkpoint
    if isinstance(checkpoint, dict) and "policy_state_dict" in checkpoint:
        raise RuntimeError(
            f"{path} is an old fixed-board checkpoint from the previous implementation. "
            "It is not compatible with the current fully-convolutional actor-critic model. "
            "Please retrain with the current script."
        )
    return {
        "state_dict": checkpoint,
        "channels": 64,
        "board_size": None,
        "win_length": None,
    }


def load_policy(path: Path, channels: int, device: torch.device) -> PolicyValueNet:
    checkpoint = load_checkpoint(path, map_location=device)
    state_dict = checkpoint["state_dict"]
    saved_channels = int(checkpoint.get("channels", channels))

    policy = PolicyValueNet(channels=saved_channels).to(device)
    policy.load_state_dict(state_dict)
    policy.eval()
    return policy


def resolve_game_config(
    checkpoint_path: Path,
    arg_board_size: int | None,
    arg_win_length: int | None,
    arg_channels: int,
    device: torch.device,
) -> tuple[PolicyValueNet, int, int]:
    checkpoint = load_checkpoint(checkpoint_path, map_location=device)
    board_size = int(checkpoint.get("board_size") or arg_board_size or 15)
    win_length = int(checkpoint.get("win_length") or arg_win_length or 5)
    channels = int(checkpoint.get("channels") or arg_channels)

    policy = PolicyValueNet(channels=channels).to(device)
    policy.load_state_dict(checkpoint["state_dict"])
    policy.eval()
    return policy, board_size, win_length

def play_vs_random_once(
    policy: PolicyValueNet,
    board_size: int,
    win_length: int,
    device: torch.device,
    policy_player: int,
    agent: str = "policy",
    mcts_sims: int = 100,
    c_puct: float = 1.5,
) -> int:
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    env.reset()

    while not env.done:
        if env.current_player == policy_player:
            action, _ = choose_ai_action(
                policy=policy,
                board=env.board,
                current_player=env.current_player,
                win_length=win_length,
                device=device,
                agent=agent,
                mcts_sims=mcts_sims,
                c_puct=c_puct,
            )
        else:
            action = int(np.random.choice(env.valid_moves()))
        env.step(action)

    return env.winner


def evaluate_vs_random(
    policy: PolicyValueNet,
    board_size: int,
    win_length: int,
    device: torch.device,
    games: int,
    agent: str = "policy",
    mcts_sims: int = 100,
    c_puct: float = 1.5,
) -> tuple[float, int, int, int]:
    wins = 0
    draws = 0
    losses = 0

    for game_idx in range(games):
        policy_player = 1 if game_idx < games // 2 else -1
        winner = play_vs_random_once(
            policy=policy,
            board_size=board_size,
            win_length=win_length,
            device=device,
            policy_player=policy_player,
            agent=agent,
            mcts_sims=mcts_sims,
            c_puct=c_puct,
        )
        if winner == 0:
            draws += 1
        elif winner == policy_player:
            wins += 1
        else:
            losses += 1

    return wins / max(games, 1), wins, draws, losses

def train(args: argparse.Namespace) -> None:
    set_seed(args.seed)
    device = choose_device(args.device)
    env = GomokuEnv(board_size=args.board_size, win_length=args.win_length)
    policy = PolicyValueNet(channels=args.channels).to(device)
    if args.init_checkpoint is not None and args.init_checkpoint.exists():
        checkpoint = load_checkpoint(args.init_checkpoint, map_location=device)
        policy.load_state_dict(checkpoint["state_dict"])
    optimizer = torch.optim.Adam(policy.parameters(), lr=args.lr)

    recent_winners: deque[int] = deque(maxlen=args.print_every)
    recent_lengths: deque[int] = deque(maxlen=args.print_every)
    batch_log_probs: list[torch.Tensor] = []
    batch_returns: list[float] = []
    batch_entropies: list[torch.Tensor] = []
    batch_values: list[torch.Tensor] = []
    last_loss = 0.0

    print(f"device={device} board={args.board_size} win={args.win_length}")

    for episode in range(1, args.episodes + 1):
        log_probs, returns, entropies, values, winner, moves = self_play_episode(
            policy=policy,
            env=env,
            device=device,
            gamma=args.gamma,
            augment=args.symmetry_augment,
        )
        batch_log_probs.extend(log_probs)
        batch_returns.extend(returns)
        batch_entropies.extend(entropies)
        batch_values.extend(values)
        recent_winners.append(winner)
        recent_lengths.append(moves)

        if episode % args.batch_size == 0 or episode == args.episodes:
            policy.train()
            last_loss = update_policy(
                optimizer=optimizer,
                batch_log_probs=batch_log_probs,
                batch_returns=batch_returns,
                batch_entropies=batch_entropies,
                batch_values=batch_values,
                entropy_coef=args.entropy_coef,
                value_coef=args.value_coef,
                grad_clip=args.grad_clip,
                device=device,
            )
            batch_log_probs.clear()
            batch_returns.clear()
            batch_entropies.clear()
            batch_values.clear()

        if episode % args.print_every == 0 or episode == args.episodes:
            p1_wins = sum(1 for x in recent_winners if x == 1)
            p2_wins = sum(1 for x in recent_winners if x == -1)
            draws = sum(1 for x in recent_winners if x == 0)
            avg_len = float(np.mean(recent_lengths)) if recent_lengths else 0.0
            message = (
                f"episode={episode:6d} loss={last_loss:8.4f} "
                f"p1={p1_wins:4d} p2={p2_wins:4d} draw={draws:4d} avg_len={avg_len:6.2f}"
            )
            if args.eval_every > 0 and episode % args.eval_every == 0:
                policy.eval()
                win_rate, wins, eval_draws, losses = evaluate_vs_random(
                    policy=policy,
                    board_size=args.board_size,
                    win_length=args.win_length,
                    device=device,
                    games=args.eval_games,
                )
                message += (
                    f" random_win_rate={win_rate:.3f}"
                    f" ({wins}/{eval_draws}/{losses})"
                )
            print(message)

    save_checkpoint(args.checkpoint, policy, args)
    print(f"saved checkpoint to {args.checkpoint}")


def evaluate(args: argparse.Namespace) -> None:
    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    win_rate, wins, draws, losses = evaluate_vs_random(
        policy=policy,
        board_size=board_size,
        win_length=win_length,
        device=device,
        games=args.games,
        agent=args.agent,
        mcts_sims=args.mcts_sims,
        c_puct=args.c_puct,
    )
    print(f"device={device}")
    print(f"agent={args.agent} mcts_sims={args.mcts_sims}")
    print(f"win_rate={win_rate:.3f} wins={wins} draws={draws} losses={losses}")

def ask_human_move(env: GomokuEnv) -> int:
    while True:
        text = input("your move (row col): ").strip()
        parts = text.replace(",", " ").split()
        if len(parts) != 2:
            print("please enter: row col")
            continue
        try:
            row, col = (int(parts[0]) - 1, int(parts[1]) - 1)
        except ValueError:
            print("row and col must be integers")
            continue
        if not (0 <= row < env.board_size and 0 <= col < env.board_size):
            print("move out of range")
            continue
        if env.board[row, col] != 0:
            print("that position is occupied")
            continue
        return row * env.board_size + col


def play(args: argparse.Namespace) -> None:
    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    human_player = 1 if args.human_first else -1

    print(f"device={device}")
    print(
        f"human={'X' if human_player == 1 else 'O'} ai={'O' if human_player == 1 else 'X'} "
        f"agent={args.agent} mcts_sims={args.mcts_sims}"
    )

    while not env.done:
        print()
        print(env.render())
        print()

        if env.current_player == human_player:
            action = ask_human_move(env)
        else:
            action, _ = choose_ai_action(
                policy=policy,
                board=env.board,
                current_player=env.current_player,
                win_length=win_length,
                device=device,
                agent=args.agent,
                mcts_sims=args.mcts_sims,
                c_puct=args.c_puct,
            )
            row, col = divmod(action, env.board_size)
            print(f"ai move: {row + 1} {col + 1}")

        env.step(action)

    print()
    print(env.render())
    if env.winner == 0:
        print("draw")
    elif env.winner == human_player:
        print("you win")
    else:
        print("ai wins")

def gui(args: argparse.Namespace) -> None:
    try:
        import pygame
    except ModuleNotFoundError as exc:
        raise SystemExit(
            "pygame is not installed. Install it with: "
            "~/miniconda3/bin/conda run -n lerobot python -m pip install pygame"
        ) from exc

    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    human_player = 1 if args.human_first else -1
    last_search_visits: np.ndarray | None = None

    pygame.init()
    pygame.display.set_caption("Gomoku Policy Gradient")
    font = pygame.font.SysFont("Arial", 24)
    small_font = pygame.font.SysFont("Arial", 18)

    cell_size = args.cell_size
    padding = 40
    status_height = 80
    board_pixels = board_size * cell_size
    screen = pygame.display.set_mode(
        (board_pixels + padding * 2, board_pixels + padding * 2 + status_height)
    )
    clock = pygame.time.Clock()

    background = (236, 196, 122)
    line_color = (80, 55, 20)
    black_stone = (20, 20, 20)
    white_stone = (245, 245, 245)
    accent = (180, 40, 40)

    def board_to_screen(row: int, col: int) -> tuple[int, int]:
        x = padding + col * cell_size + cell_size // 2
        y = padding + row * cell_size + cell_size // 2
        return x, y

    def mouse_to_action(pos: tuple[int, int]) -> int | None:
        x, y = pos
        left = padding
        top = padding
        if x < left or y < top:
            return None
        col = (x - left) // cell_size
        row = (y - top) // cell_size
        if not (0 <= row < env.board_size and 0 <= col < env.board_size):
            return None
        if env.board[row, col] != 0:
            return None
        return row * env.board_size + col

    def restart() -> None:
        nonlocal last_search_visits
        env.reset()
        last_search_visits = None
        if env.current_player != human_player:
            ai_step()

    def ai_step() -> None:
        nonlocal last_search_visits
        if env.done or env.current_player == human_player:
            return
        action, visits = choose_ai_action(
            policy=policy,
            board=env.board,
            current_player=env.current_player,
            win_length=win_length,
            device=device,
            agent=args.agent,
            mcts_sims=args.mcts_sims,
            c_puct=args.c_puct,
        )
        last_search_visits = visits
        env.step(action)

    def status_text() -> str:
        if env.done:
            if env.winner == 0:
                return "Draw. Press R to restart."
            if env.winner == human_player:
                return "You win. Press R to restart."
            return "AI wins. Press R to restart."
        if env.current_player == human_player:
            return "Your turn. Left click to place."
        return "AI is thinking..."

    if env.current_player != human_player:
        ai_step()

    running = True
    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False
            elif event.type == pygame.KEYDOWN:
                if event.key == pygame.K_ESCAPE:
                    running = False
                elif event.key == pygame.K_r:
                    restart()
            elif event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
                if env.done or env.current_player != human_player:
                    continue
                action = mouse_to_action(event.pos)
                if action is None:
                    continue
                env.step(action)
                ai_step()

        screen.fill(background)

        for idx in range(board_size + 1):
            x = padding + idx * cell_size
            pygame.draw.line(screen, line_color, (x, padding), (x, padding + board_pixels), 2)
            y = padding + idx * cell_size
            pygame.draw.line(screen, line_color, (padding, y), (padding + board_pixels, y), 2)

        for row in range(env.board_size):
            for col in range(env.board_size):
                stone = int(env.board[row, col])
                if stone == 0:
                    continue
                x, y = board_to_screen(row, col)
                color = black_stone if stone == 1 else white_stone
                pygame.draw.circle(screen, color, (x, y), cell_size // 2 - 4)
                pygame.draw.circle(screen, line_color, (x, y), cell_size // 2 - 4, 1)

        for idx in range(board_size):
            label = small_font.render(str(idx + 1), True, line_color)
            screen.blit(
                label,
                (padding + idx * cell_size + cell_size // 2 - label.get_width() // 2, 8),
            )
            screen.blit(
                label,
                (8, padding + idx * cell_size + cell_size // 2 - label.get_height() // 2),
            )

        info = (
            f"{board_size}x{board_size} connect={win_length} "
            f"device={device} human={'X' if human_player == 1 else 'O'} "
            f"agent={args.agent}"
        )
        info_surface = small_font.render(info, True, line_color)
        status_surface = font.render(status_text(), True, accent)
        screen.blit(info_surface, (padding, padding + board_pixels + 16))
        screen.blit(status_surface, (padding, padding + board_pixels + 42))

        if last_search_visits is not None and args.agent == "mcts":
            peak = float(np.max(last_search_visits))
            if peak > 0:
                stats = small_font.render(
                    f"mcts_sims={args.mcts_sims} peak_visits={int(peak)}",
                    True,
                    line_color,
                )
                screen.blit(stats, (padding + 380, padding + board_pixels + 16))

        pygame.display.flip()
        clock.tick(args.fps)

    pygame.quit()

def build_parser() -> argparse.ArgumentParser:
|
| 1035 |
+
parser = argparse.ArgumentParser(description="Minimal Gomoku policy gradient example")
|
| 1036 |
+
subparsers = parser.add_subparsers(dest="mode", required=True)
|
| 1037 |
+
|
| 1038 |
+
def add_common_arguments(subparser: argparse.ArgumentParser, defaults_from_checkpoint: bool = False) -> None:
|
| 1039 |
+
board_default = None if defaults_from_checkpoint else 15
|
| 1040 |
+
win_default = None if defaults_from_checkpoint else 5
|
| 1041 |
+
subparser.add_argument("--board-size", type=int, default=board_default)
|
| 1042 |
+
subparser.add_argument("--win-length", type=int, default=win_default)
|
| 1043 |
+
subparser.add_argument("--channels", type=int, default=64)
|
| 1044 |
+
subparser.add_argument("--device", choices=["auto", "cpu", "cuda", "mps"], default="auto")
|
| 1045 |
+
subparser.add_argument("--checkpoint", type=Path, default=Path("gomoku_policy.pt"))
|
| 1046 |
+
|
| 1047 |
+
def add_inference_arguments(subparser: argparse.ArgumentParser, default_agent: str = "mcts") -> None:
|
| 1048 |
+
subparser.add_argument("--agent", choices=["policy", "mcts"], default=default_agent)
|
| 1049 |
+
subparser.add_argument("--mcts-sims", type=int, default=120)
|
| 1050 |
+
subparser.add_argument("--c-puct", type=float, default=1.5)
|
| 1051 |
+
|
| 1052 |
+
train_parser = subparsers.add_parser("train", help="self-play training")
|
| 1053 |
+
add_common_arguments(train_parser)
|
| 1054 |
+
train_parser.add_argument("--episodes", type=int, default=5000)
|
| 1055 |
+
train_parser.add_argument("--batch-size", type=int, default=32)
|
| 1056 |
+
train_parser.add_argument("--lr", type=float, default=1e-3)
|
| 1057 |
+
train_parser.add_argument("--gamma", type=float, default=0.99)
|
| 1058 |
+
train_parser.add_argument("--entropy-coef", type=float, default=0.01)
|
| 1059 |
+
train_parser.add_argument("--value-coef", type=float, default=0.5)
|
| 1060 |
+
train_parser.add_argument("--grad-clip", type=float, default=1.0)
|
| 1061 |
+
train_parser.add_argument("--print-every", type=int, default=100)
|
| 1062 |
+
train_parser.add_argument("--eval-every", type=int, default=500)
|
| 1063 |
+
train_parser.add_argument("--eval-games", type=int, default=40)
|
| 1064 |
+
train_parser.add_argument("--seed", type=int, default=42)
|
| 1065 |
+
train_parser.add_argument("--init-checkpoint", type=Path, default=None)
|
| 1066 |
+
train_parser.add_argument(
|
| 1067 |
+
"--no-symmetry-augment",
|
| 1068 |
+
dest="symmetry_augment",
|
| 1069 |
+
action="store_false",
|
| 1070 |
+
help="disable random rotation/flip augmentation during training",
|
| 1071 |
+
)
|
| 1072 |
+
train_parser.set_defaults(symmetry_augment=True)
|
| 1073 |
+
train_parser.set_defaults(func=train)
|
| 1074 |
+
|
| 1075 |
+
eval_parser = subparsers.add_parser("eval", help="evaluate against random agent")
|
| 1076 |
+
add_common_arguments(eval_parser)
|
| 1077 |
+
eval_parser.add_argument("--games", type=int, default=100)
|
| 1078 |
+
add_inference_arguments(eval_parser)
|
| 1079 |
+
eval_parser.set_defaults(func=evaluate)
|
| 1080 |
+
|
| 1081 |
+
play_parser = subparsers.add_parser("play", help="play against the trained model")
|
| 1082 |
+
add_common_arguments(play_parser, defaults_from_checkpoint=True)
|
| 1083 |
+
play_parser.add_argument("--human-first", action="store_true", help="human plays X")
|
| 1084 |
+
add_inference_arguments(play_parser)
|
| 1085 |
+
play_parser.set_defaults(func=play)
|
| 1086 |
+
|
| 1087 |
+
gui_parser = subparsers.add_parser("gui", help="pygame GUI for testing against the model")
|
| 1088 |
+
add_common_arguments(gui_parser, defaults_from_checkpoint=True)
|
| 1089 |
+
gui_parser.add_argument("--human-first", action="store_true", help="human plays X")
|
| 1090 |
+
gui_parser.add_argument("--cell-size", type=int, default=48)
|
| 1091 |
+
gui_parser.add_argument("--fps", type=int, default=30)
|
| 1092 |
+
add_inference_arguments(gui_parser)
|
| 1093 |
+
gui_parser.set_defaults(func=gui)
|
| 1094 |
+
|
| 1095 |
+
return parser
|
| 1096 |
+
|
| 1097 |
+
|
| 1098 |
+
def main() -> None:
|
| 1099 |
+
parser = build_parser()
|
| 1100 |
+
args = parser.parse_args()
|
| 1101 |
+
args.func(args)
|
| 1102 |
+
|
| 1103 |
+
|
| 1104 |
+
if __name__ == "__main__":
|
| 1105 |
+
main()
|
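`build_parser` wires each subcommand to its handler with `set_defaults(func=...)`, so `main` can dispatch with a single `args.func(args)` call. A minimal standalone sketch of that pattern (the lambda handlers here are hypothetical stand-ins for the real `train`/`play` functions):

```python
import argparse

def make_parser() -> argparse.ArgumentParser:
    # Each subparser binds its handler via set_defaults(func=...);
    # the caller then dispatches with args.func(args).
    parser = argparse.ArgumentParser(description="dispatch sketch")
    subparsers = parser.add_subparsers(dest="mode", required=True)

    train = subparsers.add_parser("train")
    train.add_argument("--episodes", type=int, default=5000)
    train.set_defaults(func=lambda a: f"train for {a.episodes} episodes")

    play = subparsers.add_parser("play")
    play.set_defaults(func=lambda a: "play mode")
    return parser

args = make_parser().parse_args(["train", "--episodes", "10"])
result = args.func(args)  # dispatches to the train handler
```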
train_mcts_15x15_5.sh
ADDED
#!/usr/bin/env bash
set -euo pipefail

ROOT_DIR="$(cd "$(dirname "$0")" && pwd)"
CONDA_BIN="${CONDA_BIN:-$HOME/miniconda3/bin/conda}"
ENV_NAME="${ENV_NAME:-lerobot}"

INIT_CHECKPOINT="${INIT_CHECKPOINT:-$ROOT_DIR/gomoku_7x7_5.pt}"
OUTPUT_CHECKPOINT="${OUTPUT_CHECKPOINT:-$ROOT_DIR/gomoku_mcts_15x15_5.pt}"

BOARD_SIZE="${BOARD_SIZE:-15}"
WIN_LENGTH="${WIN_LENGTH:-5}"
CHANNELS="${CHANNELS:-64}"

ITERATIONS="${ITERATIONS:-3000}"
GAMES_PER_ITER="${GAMES_PER_ITER:-12}"
TRAIN_STEPS="${TRAIN_STEPS:-64}"
BATCH_SIZE="${BATCH_SIZE:-128}"
BUFFER_SIZE="${BUFFER_SIZE:-50000}"

MCTS_SIMS="${MCTS_SIMS:-64}"
EVAL_MCTS_SIMS="${EVAL_MCTS_SIMS:-160}"
EVAL_EVERY="${EVAL_EVERY:-10}"
EVAL_GAMES="${EVAL_GAMES:-20}"
SAVE_EVERY="${SAVE_EVERY:-10}"

LR="${LR:-5e-4}"
WEIGHT_DECAY="${WEIGHT_DECAY:-1e-4}"
VALUE_COEF="${VALUE_COEF:-1.0}"
CPUCT="${CPUCT:-1.5}"
TEMPERATURE="${TEMPERATURE:-1.0}"
TEMPERATURE_DROP_MOVES="${TEMPERATURE_DROP_MOVES:-10}"
DIRICHLET_ALPHA="${DIRICHLET_ALPHA:-0.3}"
NOISE_EPS="${NOISE_EPS:-0.25}"
SEED="${SEED:-42}"
DEVICE="${DEVICE:-auto}"

CMD=(
    "$CONDA_BIN" run -n "$ENV_NAME" python "$ROOT_DIR/gomoku_mcts.py" train
    --board-size "$BOARD_SIZE"
    --win-length "$WIN_LENGTH"
    --channels "$CHANNELS"
    --iterations "$ITERATIONS"
    --games-per-iter "$GAMES_PER_ITER"
    --train-steps "$TRAIN_STEPS"
    --batch-size "$BATCH_SIZE"
    --buffer-size "$BUFFER_SIZE"
    --mcts-sims "$MCTS_SIMS"
    --eval-mcts-sims "$EVAL_MCTS_SIMS"
    --eval-every "$EVAL_EVERY"
    --eval-games "$EVAL_GAMES"
    --save-every "$SAVE_EVERY"
    --lr "$LR"
    --weight-decay "$WEIGHT_DECAY"
    --value-coef "$VALUE_COEF"
    --c-puct "$CPUCT"
    --temperature "$TEMPERATURE"
    --temperature-drop-moves "$TEMPERATURE_DROP_MOVES"
    --dirichlet-alpha "$DIRICHLET_ALPHA"
    --noise-eps "$NOISE_EPS"
    --seed "$SEED"
    --device "$DEVICE"
    --checkpoint "$OUTPUT_CHECKPOINT"
)

if [[ -f "$INIT_CHECKPOINT" ]]; then
    CMD+=(--init-checkpoint "$INIT_CHECKPOINT")
else
    echo "init checkpoint not found, training from scratch: $INIT_CHECKPOINT"
fi

printf 'Running command:\n%s\n' "${CMD[*]}"
exec "${CMD[@]}"
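The launcher builds its command as a bash array and appends the optional `--init-checkpoint` flag only when the file exists, then expands with `"${CMD[@]}"` so each element stays a single argument. A minimal standalone sketch of that pattern (`echo` stands in for the real conda invocation, and the checkpoint path is a hypothetical placeholder):

```shell
#!/usr/bin/env bash
set -euo pipefail

# Build the command as an array: quoting survives word splitting.
CMD=(echo python gomoku_mcts.py train --board-size 15)

# Conditionally append an optional flag (path is hypothetical; does not exist here).
INIT_CHECKPOINT="/nonexistent/gomoku_7x7_5.pt"
if [[ -f "$INIT_CHECKPOINT" ]]; then
    CMD+=(--init-checkpoint "$INIT_CHECKPOINT")
fi

# Expand element-by-element; echo prints the assembled command line.
"${CMD[@]}"
```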