gomoku-ai-code / gomoku_pg.py
lcccluck's picture
Upload Gomoku training and MCTS code
63cdefe verified
#!/usr/bin/env python3
"""Minimal Gomoku policy gradient example.
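Features:
1. Configurable board size and win length, e.g. 5x5 connect-4 or 15x15 connect-5.
2. Shared-policy self-play trained with REINFORCE plus a learned value baseline (actor-critic).
3. Fully convolutional policy/value network, so the same weights can be applied to different board sizes.
4. Evaluation against a random agent, plus human play in the terminal or in a pygame GUI.
5. Optional PUCT-style Monte Carlo tree search on top of the network at play/eval time.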
"""
from __future__ import annotations
import argparse
import math
import random
from collections import deque
from dataclasses import dataclass, field
from pathlib import Path
import numpy as np
import torch
from torch import nn
from torch.distributions import Categorical
def choose_device(name: str) -> torch.device:
    """Resolve a CLI device name into a ``torch.device``.

    Any value other than ``"auto"`` is handed straight to ``torch.device``.
    For ``"auto"`` the first available backend wins, in the order CUDA,
    Apple MPS, CPU.
    """
    if name == "auto":
        mps_backend = getattr(torch.backends, "mps", None)
        if torch.cuda.is_available():
            resolved = "cuda"
        elif mps_backend is not None and mps_backend.is_available():
            resolved = "mps"
        else:
            resolved = "cpu"
        return torch.device(resolved)
    return torch.device(name)
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy and PyTorch (CPU plus every CUDA device, if any)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
class GomokuEnv:
    """Two-player k-in-a-row (Gomoku) game on a square board.

    Cells hold 1 for the first player (X), -1 for the second player (O) and
    0 for empty. Actions are flat, row-major cell indices
    ``row * board_size + col``.
    """

    def __init__(self, board_size: int, win_length: int):
        if board_size <= 1:
            raise ValueError("board_size must be > 1")
        if not 1 < win_length <= board_size:
            raise ValueError("win_length must satisfy 1 < win_length <= board_size")
        self.board_size = board_size
        self.win_length = win_length
        self.reset()

    def reset(self) -> np.ndarray:
        """Start a new game with player 1 to move; returns the live board array (not a copy)."""
        self.board = np.zeros((self.board_size, self.board_size), dtype=np.int8)
        self.current_player = 1
        self.done = False
        self.winner = 0
        return self.board

    def legal_mask(self) -> np.ndarray:
        """Boolean array shaped like the board, True where a cell is empty."""
        return self.board == 0

    def valid_moves(self) -> np.ndarray:
        """Flat indices of all empty cells."""
        return np.flatnonzero(self.legal_mask().reshape(-1))

    def step(self, action: int) -> tuple[bool, int]:
        """Place the current player's stone at flat index ``action``.

        Returns ``(done, winner)``: ``winner`` is 1 or -1 after a win and 0
        for a draw or an unfinished game. The turn only passes to the other
        player when the game continues.

        Raises:
            RuntimeError: if the game has already finished.
            ValueError: if ``action`` is outside the board or the cell is occupied.
        """
        if self.done:
            raise RuntimeError("game is already finished")
        action = int(action)
        # Reject out-of-range indices explicitly: a negative action would
        # otherwise wrap around via NumPy negative indexing (-1 -> last row).
        if not 0 <= action < self.board_size * self.board_size:
            raise ValueError(
                f"action {action} is out of range for a "
                f"{self.board_size}x{self.board_size} board"
            )
        row, col = divmod(action, self.board_size)
        if self.board[row, col] != 0:
            raise ValueError(f"illegal move at ({row}, {col})")
        player = self.current_player
        self.board[row, col] = player
        if self._is_winning_move(row, col, player):
            self.done = True
            self.winner = player
        elif not np.any(self.board == 0):
            # Board full without a winner: draw.
            self.done = True
            self.winner = 0
        else:
            self.current_player = -player
        return self.done, self.winner

    def _is_winning_move(self, row: int, col: int, player: int) -> bool:
        """True if ``player`` has ``win_length`` or more in a line through ``(row, col)``."""
        directions = ((1, 0), (0, 1), (1, 1), (1, -1))
        for dr, dc in directions:
            # The placed stone itself plus the runs on both sides of it.
            count = 1
            count += self._count_one_side(row, col, dr, dc, player)
            count += self._count_one_side(row, col, -dr, -dc, player)
            if count >= self.win_length:
                return True
        return False

    def _count_one_side(self, row: int, col: int, dr: int, dc: int, player: int) -> int:
        """Count consecutive ``player`` stones beyond ``(row, col)`` in direction ``(dr, dc)``."""
        total = 0
        r, c = row + dr, col + dc
        while 0 <= r < self.board_size and 0 <= c < self.board_size:
            if self.board[r, c] != player:
                break
            total += 1
            r += dr
            c += dc
        return total

    def render(self) -> str:
        """Return a text view of the board (X / O / .) with 1-based row and column labels."""
        symbols = {1: "X", -1: "O", 0: "."}
        # Each row line starts with a 2-wide label plus a space (3 chars), so
        # the header needs the same 3-char prefix for its labels to line up.
        header = " " * 3 + " ".join(f"{i + 1:2d}" for i in range(self.board_size))
        rows = [header]
        for row_idx in range(self.board_size):
            row = " ".join(f"{symbols[int(v)]:>2}" for v in self.board[row_idx])
            rows.append(f"{row_idx + 1:2d} {row}")
        return "\n".join(rows)
def encode_state(board: np.ndarray, current_player: int) -> torch.Tensor:
    """Encode a board as a float32 tensor of three planes from the mover's point of view.

    Plane 0 marks ``current_player``'s stones, plane 1 the opponent's stones
    and plane 2 the empty cells; an ``(H, W)`` board becomes ``(3, H, W)``.
    """
    planes = [board == owner for owner in (current_player, -current_player, 0)]
    return torch.from_numpy(np.stack(planes, axis=0).astype(np.float32))
class PolicyValueNet(nn.Module):
    """Fully convolutional actor-critic network.

    Input: ``(batch, 3, H, W)`` planes (see ``encode_state``).
    Output: ``(policy_logits, value)`` with ``policy_logits`` of shape
    ``(batch, H * W)`` (one logit per cell) and ``value`` of shape
    ``(batch,)`` squashed into ``[-1, 1]`` by a tanh. No layer depends on
    H or W, so the same weights apply to any board size.
    """

    def __init__(self, channels: int = 64):
        super().__init__()
        # Three (3x3 conv -> ReLU) stages. Modules are created in the same
        # order as before so parameter names ("trunk.0", "trunk.2",
        # "trunk.4") and seeded initialisation stay compatible.
        stages: list[nn.Module] = []
        in_channels = 3
        for _ in range(3):
            stages.append(nn.Conv2d(in_channels, channels, kernel_size=3, padding=1))
            stages.append(nn.ReLU())
            in_channels = channels
        self.trunk = nn.Sequential(*stages)
        # 1x1 conv gives one logit per board cell.
        self.policy_head = nn.Conv2d(channels, 1, kernel_size=1)
        # Global average pool -> small MLP -> scalar in [-1, 1].
        self.value_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(channels, channels),
            nn.ReLU(),
            nn.Linear(channels, 1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        shared = self.trunk(x)
        cell_logits = torch.flatten(self.policy_head(shared), start_dim=1)
        state_value = self.value_head(shared).squeeze(-1)
        return cell_logits, state_value
def masked_logits(logits: torch.Tensor, legal_mask: np.ndarray) -> torch.Tensor:
    """Set the logits of illegal cells to -1e9 so softmax/sampling effectively ignores them.

    ``legal_mask`` (True = legal) is flattened to line up with the flat
    ``logits`` vector. A new tensor is returned; ``logits`` is not modified.
    """
    flat_legal = torch.as_tensor(legal_mask.reshape(-1), dtype=torch.bool, device=logits.device)
    illegal = torch.logical_not(flat_legal)
    return logits.masked_fill(illegal, -1e9)
def transform_board(board: np.ndarray, rotation_k: int, flip: bool) -> np.ndarray:
    """Apply one of the square's symmetries to ``board``.

    The board is rotated by ``rotation_k`` quarter turns with ``np.rot90``
    (counter-clockwise), then mirrored left-right if ``flip`` is set. The
    result is C-contiguous; when no reordering is needed it may share
    memory with ``board``.
    """
    rotated = np.rot90(board, k=rotation_k)
    oriented = np.fliplr(rotated) if flip else rotated
    return np.ascontiguousarray(oriented)
def action_to_coords(action: int, board_size: int) -> tuple[int, int]:
    """Split a flat row-major action index into ``(row, col)``."""
    index = int(action)
    return index // board_size, index % board_size
def coords_to_action(row: int, col: int, board_size: int) -> int:
    """Flatten ``(row, col)`` into a row-major action index (inverse of ``action_to_coords``)."""
    return col + board_size * row
def count_one_side(
    board: np.ndarray,
    row: int,
    col: int,
    dr: int,
    dc: int,
    player: int,
) -> int:
    """Count consecutive ``player`` stones walking away from ``(row, col)`` along ``(dr, dc)``.

    The start cell is not counted. The walk stops at the first cell that is
    off the board or not owned by ``player``. The board is assumed square;
    its size is taken from ``board.shape[0]``.
    """
    limit = board.shape[0]
    steps = 0
    r, c = row + dr, col + dc
    while 0 <= r < limit and 0 <= c < limit and board[r, c] == player:
        steps += 1
        r, c = r + dr, c + dc
    return steps
def is_winning_move(
    board: np.ndarray,
    row: int,
    col: int,
    player: int,
    win_length: int,
) -> bool:
    """Return True if a line through ``(row, col)`` holds ``win_length`` or more ``player`` stones.

    The four line directions (vertical, horizontal, both diagonals) are
    checked. The cell ``(row, col)`` itself always counts as one stone of
    the line, whatever it contains.
    """
    size = board.shape[0]

    def run_length(dr: int, dc: int) -> int:
        # Consecutive `player` stones strictly beyond (row, col) in one direction.
        length = 0
        r, c = row + dr, col + dc
        while 0 <= r < size and 0 <= c < size and board[r, c] == player:
            length += 1
            r += dr
            c += dc
        return length

    return any(
        1 + run_length(dr, dc) + run_length(-dr, -dc) >= win_length
        for dr, dc in ((1, 0), (0, 1), (1, 1), (1, -1))
    )
def apply_action_to_board(
    board: np.ndarray,
    current_player: int,
    action: int,
    win_length: int,
) -> tuple[np.ndarray, int, bool, int]:
    """Play ``action`` for ``current_player`` on a copy of ``board``; the input is untouched.

    Returns ``(next_board, next_player, done, winner)``. ``next_player`` is
    always ``-current_player``. ``done`` is True after a winning move or
    when the board becomes full. ``winner`` is the mover after a win and 0
    otherwise.

    Raises:
        ValueError: if ``action`` is not a cell index of the board, or the
            cell is already occupied.
    """
    board_size = board.shape[0]
    action = int(action)
    # Guard explicitly: a negative index would silently wrap through NumPy
    # indexing (e.g. -1 -> last row), and MCTSNode.select_action returns -1
    # when it has no actions to choose from.
    if not 0 <= action < board_size * board_size:
        raise ValueError(
            f"action {action} is out of range for a {board_size}x{board_size} board"
        )
    row, col = action_to_coords(action, board_size)
    if board[row, col] != 0:
        raise ValueError(f"illegal move at ({row}, {col})")
    next_board = board.copy()
    next_board[row, col] = current_player
    next_player = -current_player
    if is_winning_move(next_board, row, col, current_player, win_length):
        return next_board, next_player, True, current_player
    if not np.any(next_board == 0):
        return next_board, next_player, True, 0
    return next_board, next_player, False, 0
def forward_transform_coords(
    row: int,
    col: int,
    board_size: int,
    rotation_k: int,
    flip: bool,
) -> tuple[int, int]:
    """Map ``(row, col)`` on the original board to its position after ``transform_board``.

    Each quarter turn matches one counter-clockwise ``np.rot90`` step. The
    optional left-right mirror is applied after the rotation, in the same
    order as ``transform_board``.
    """
    last = board_size - 1
    quarter_turns = rotation_k % 4
    while quarter_turns:
        row, col = last - col, row
        quarter_turns -= 1
    return (row, last - col) if flip else (row, col)
def inverse_transform_coords(
    row: int,
    col: int,
    board_size: int,
    rotation_k: int,
    flip: bool,
) -> tuple[int, int]:
    """Map ``(row, col)`` on a ``transform_board`` output back to the original board.

    This undoes ``forward_transform_coords``. The mirror is removed first,
    then each quarter turn is reversed.
    """
    last = board_size - 1
    if flip:
        col = last - col
    for _ in range(rotation_k % 4):
        col, row = last - row, col
    return row, col
def sample_action(
    policy: PolicyValueNet,
    board: np.ndarray,
    current_player: int,
    device: torch.device,
    greedy: bool,
    augment: bool,
) -> tuple[int, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
    """Pick a move for ``current_player`` from the policy network.

    When ``augment`` is True, the board is randomly rotated (0-3 quarter
    turns) and optionally mirrored before it is fed to the network. The
    chosen cell is then mapped back to original-board coordinates.

    Returns ``(action, log_prob, entropy, value)``:

    * ``greedy=True``: the arg-max legal action, ``None``, ``None`` and the
      value estimate. Greedy mode is only used for inference in this module
      (``choose_ai_action``), so it runs under ``torch.no_grad()`` and builds
      no autograd graph.
    * ``greedy=False``: an action sampled from the masked policy, together
      with its log-probability, the distribution entropy and the value, all
      attached to the autograd graph for training.
    """
    board_size = board.shape[0]
    rotation_k = random.randint(0, 3) if augment else 0
    flip = bool(random.getrandbits(1)) if augment else False
    transformed_board = transform_board(board, rotation_k=rotation_k, flip=flip)
    state = encode_state(transformed_board, current_player).unsqueeze(0).to(device)

    def to_original_action(transformed_action: int) -> int:
        # Undo the symmetry so the action indexes the caller's board.
        t_row, t_col = action_to_coords(transformed_action, board_size)
        row, col = inverse_transform_coords(
            t_row,
            t_col,
            board_size,
            rotation_k=rotation_k,
            flip=flip,
        )
        return coords_to_action(row, col, board_size)

    if greedy:
        with torch.no_grad():
            logits, value = policy(state)
            logits = masked_logits(logits.squeeze(0), transformed_board == 0)
            best = int(torch.argmax(logits).item())
        return to_original_action(best), None, None, value.squeeze(0)

    logits, value = policy(state)
    logits = masked_logits(logits.squeeze(0), transformed_board == 0)
    dist = Categorical(logits=logits)
    action = dist.sample()
    return (
        to_original_action(int(action.item())),
        dist.log_prob(action),
        dist.entropy(),
        value.squeeze(0),
    )
def evaluate_policy_value(
    policy: PolicyValueNet,
    board: np.ndarray,
    current_player: int,
    device: torch.device,
) -> tuple[np.ndarray, float]:
    """Run the network once without gradients and return ``(probs, value)``.

    ``probs`` is a flat NumPy vector of softmax move probabilities over all
    cells. Occupied cells are masked, so their probability is effectively
    zero. ``value`` is the network's scalar estimate for ``current_player``.
    """
    batch = encode_state(board, current_player).unsqueeze(0).to(device)
    with torch.no_grad():
        raw_logits, raw_value = policy(batch)
    legal_logits = masked_logits(raw_logits.squeeze(0), board == 0)
    probabilities = torch.softmax(legal_logits, dim=0)
    return probabilities.detach().cpu().numpy(), float(raw_value.item())
@dataclass
class MCTSNode:
    """A node of the search tree used by ``choose_mcts_action``.

    Per-action edge statistics are kept on this node: the prior probability,
    the visit count and the summed backed-up value of every legal action.
    ``value_sums`` are accumulated from the perspective of this node's
    ``current_player``, who chooses among the actions. Child nodes are
    created lazily by ``child_for_action``.
    """

    board: np.ndarray
    current_player: int  # player to move in this position
    win_length: int
    done: bool = False  # set on children whose incoming move ended the game
    winner: int = 0  # winner when done (0 for a draw)
    priors: dict[int, float] = field(default_factory=dict)
    visit_counts: dict[int, int] = field(default_factory=dict)
    value_sums: dict[int, float] = field(default_factory=dict)
    children: dict[int, "MCTSNode"] = field(default_factory=dict)
    expanded: bool = False

    def expand(self, priors: np.ndarray) -> None:
        """Store renormalised priors for this node's empty cells and zero their statistics.

        ``priors`` is a flat per-cell vector. If the legal entries carry no
        positive mass, every legal move gets the same prior instead.
        """
        legal_actions = np.flatnonzero(self.board.reshape(-1) == 0)
        legal_mass = float(np.sum(priors[legal_actions]))
        if legal_mass <= 0.0:
            even_share = 1.0 / max(len(legal_actions), 1)
            self.priors = dict.fromkeys((int(a) for a in legal_actions), even_share)
        else:
            self.priors = {int(a): float(priors[a] / legal_mass) for a in legal_actions}
        self.visit_counts = dict.fromkeys(self.priors, 0)
        self.value_sums = dict.fromkeys(self.priors, 0.0)
        self.expanded = True

    def q_value(self, action: int) -> float:
        """Mean backed-up value of ``action``; 0.0 if it has never been visited."""
        n = self.visit_counts[action]
        return self.value_sums[action] / n if n else 0.0

    def select_action(self, c_puct: float) -> int:
        """Return the action with the highest PUCT score ``Q + U``.

        ``U = c_puct * prior * sqrt(N + 1) / (1 + n_a)``, where ``N`` is the
        node's total visit count and ``n_a`` is the action's own count. Ties
        keep the earliest action in ``priors`` order. Returns -1 when the
        node has no actions.
        """
        sqrt_total = math.sqrt(sum(self.visit_counts.values()) + 1.0)
        chosen, top_score = -1, -float("inf")
        for candidate, prior in self.priors.items():
            n_a = self.visit_counts[candidate]
            score = self.q_value(candidate) + c_puct * prior * sqrt_total / (1.0 + n_a)
            if score > top_score:
                chosen, top_score = candidate, score
        return chosen

    def child_for_action(self, action: int) -> "MCTSNode":
        """Return the child reached by ``action``, creating and caching it on first use."""
        if action not in self.children:
            next_board, next_player, done, winner = apply_action_to_board(
                board=self.board,
                current_player=self.current_player,
                action=action,
                win_length=self.win_length,
            )
            self.children[action] = MCTSNode(
                board=next_board,
                current_player=next_player,
                win_length=self.win_length,
                done=done,
                winner=winner,
            )
        return self.children[action]
def terminal_value(winner: int, current_player: int) -> float:
    """Score a finished game for ``current_player``: +1 win, -1 loss, 0 draw (``winner == 0``)."""
    outcome = 0.0
    if winner != 0:
        outcome = 1.0 if winner == current_player else -1.0
    return outcome
def choose_mcts_action(
    policy: PolicyValueNet,
    board: np.ndarray,
    current_player: int,
    win_length: int,
    device: torch.device,
    num_simulations: int,
    c_puct: float,
) -> tuple[int, np.ndarray]:
    """Choose a move with PUCT Monte Carlo tree search guided by the policy/value net.

    Each simulation runs three phases:

    * selection: descend through expanded, non-terminal nodes with
      ``MCTSNode.select_action``;
    * evaluation: score the leaf with the exact game result if the game is
      over, otherwise with the network (and expand the leaf);
    * backup: propagate the value up the path, negating it at each level.

    Returns ``(best_action, visits)``. ``visits`` holds the root visit counts
    reshaped to ``board.shape``. The most visited root action is chosen; if
    no simulation ran (``num_simulations <= 0``), the action with the highest
    root prior is returned instead.
    """
    root = MCTSNode(
        board=board.copy(),
        current_player=current_player,
        win_length=win_length,
    )
    # Kept under its own name: leaf evaluations below must not overwrite the
    # root priors that the fallback at the end relies on.
    root_priors, _ = evaluate_policy_value(policy, root.board, root.current_player, device)
    root.expand(root_priors)
    for _ in range(num_simulations):
        node = root
        path: list[tuple[MCTSNode, int]] = []
        while node.expanded and not node.done:
            action = node.select_action(c_puct)
            path.append((node, action))
            node = node.child_for_action(action)
        # `value` is from the perspective of node.current_player (player to move).
        if node.done:
            value = terminal_value(node.winner, node.current_player)
        else:
            leaf_priors, value = evaluate_policy_value(
                policy, node.board, node.current_player, device
            )
            node.expand(leaf_priors)
        # Each parent's edge stats are from the parent's mover's perspective,
        # which is the opponent of the child's player to move: flip the sign.
        for parent, action in reversed(path):
            value = -value
            parent.visit_counts[action] += 1
            parent.value_sums[action] += value
    visits = np.zeros(board.size, dtype=np.float32)
    for action, count in root.visit_counts.items():
        visits[action] = float(count)
    if np.all(visits == 0):
        best_action = int(np.argmax(root_priors))
    else:
        best_action = int(np.argmax(visits))
    return best_action, visits.reshape(board.shape)
def choose_ai_action(
    policy: PolicyValueNet,
    board: np.ndarray,
    current_player: int,
    win_length: int,
    device: torch.device,
    agent: str,
    mcts_sims: int,
    c_puct: float,
) -> tuple[int, np.ndarray | None]:
    """Select the AI's move for ``current_player``.

    With ``agent == "mcts"`` this runs ``choose_mcts_action`` with
    ``mcts_sims`` simulations and returns its root visit map. Any other
    agent plays the network's greedy (arg-max) move without symmetry
    augmentation and returns ``(action, None)``.
    """
    if agent != "mcts":
        greedy_result = sample_action(
            policy=policy,
            board=board,
            current_player=current_player,
            device=device,
            greedy=True,
            augment=False,
        )
        return greedy_result[0], None
    return choose_mcts_action(
        policy=policy,
        board=board,
        current_player=current_player,
        win_length=win_length,
        device=device,
        num_simulations=mcts_sims,
        c_puct=c_puct,
    )
def self_play_episode(
    policy: PolicyValueNet,
    env: GomokuEnv,
    device: torch.device,
    gamma: float,
    augment: bool,
) -> tuple[list[torch.Tensor], list[float], list[torch.Tensor], list[torch.Tensor], int, int]:
    """Play one game in which the same policy controls both players.

    Returns ``(log_probs, returns, entropies, values, winner, total_moves)``
    with one entry per move in the per-move lists. A move's return is the
    final result from its mover's perspective (+1 win, -1 loss, 0 draw)
    multiplied by ``gamma`` to the power of the number of moves that came
    after it, so the last move gets the undiscounted result.
    """
    env.reset()
    log_probs: list[torch.Tensor] = []
    entropies: list[torch.Tensor] = []
    values: list[torch.Tensor] = []
    movers: list[int] = []
    while not env.done:
        mover = env.current_player
        action, log_prob, entropy, value = sample_action(
            policy=policy,
            board=env.board,
            current_player=mover,
            device=device,
            greedy=False,
            augment=augment,
        )
        log_probs.append(log_prob)
        entropies.append(entropy)
        values.append(value)
        movers.append(mover)
        env.step(action)

    total_moves = len(movers)
    final_winner = env.winner
    returns: list[float] = []
    for move_idx, mover in enumerate(movers):
        if final_winner == 0:
            outcome = 0.0
        elif mover == final_winner:
            outcome = 1.0
        else:
            outcome = -1.0
        moves_after = total_moves - move_idx - 1
        returns.append(outcome * gamma ** moves_after)
    return log_probs, returns, entropies, values, env.winner, total_moves
def update_policy(
    optimizer: torch.optim.Optimizer,
    batch_log_probs: list[torch.Tensor],
    batch_returns: list[float],
    batch_entropies: list[torch.Tensor],
    batch_values: list[torch.Tensor],
    entropy_coef: float,
    value_coef: float,
    grad_clip: float,
    device: torch.device,
) -> float:
    """Apply one actor-critic gradient step to a batch of self-play moves.

    ``loss = policy_loss + value_coef * value_loss - entropy_coef * entropy``:

    * ``policy_loss``: ``-(log_prob * advantage)`` averaged over moves, with
      ``advantage = return - value`` (value detached). Advantages are
      standardised across the batch when there is more than one.
    * ``value_loss``: mean squared error between predicted values and returns.
    * ``entropy``: mean policy entropy, subtracted to encourage exploration.

    Gradients are clipped to a global L2 norm of ``grad_clip`` over every
    parameter the optimizer manages (all parameter groups). Returns the
    scalar loss as a Python float.
    """
    returns = torch.tensor(batch_returns, dtype=torch.float32, device=device)
    log_probs = torch.stack(batch_log_probs)
    entropies = torch.stack(batch_entropies)
    values = torch.stack(batch_values)
    advantages = returns - values.detach()
    if advantages.numel() > 1:
        advantages = (advantages - advantages.mean()) / (advantages.std(unbiased=False) + 1e-6)
    policy_loss = -(log_probs * advantages).mean()
    value_loss = torch.mean((values - returns) ** 2)
    entropy_bonus = entropies.mean()
    loss = policy_loss + value_coef * value_loss - entropy_coef * entropy_bonus
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    # Clip over all parameter groups, not just the first one, so no
    # parameters escape the global norm limit.
    all_params = [p for group in optimizer.param_groups for p in group["params"]]
    nn.utils.clip_grad_norm_(all_params, grad_clip)
    optimizer.step()
    return float(loss.item())
def save_checkpoint(
path: Path,
policy: PolicyValueNet,
args: argparse.Namespace,
) -> None:
payload = {
"state_dict": policy.state_dict(),
"channels": args.channels,
"board_size": args.board_size,
"win_length": args.win_length,
}
torch.save(payload, path)
def load_checkpoint(path: Path, map_location: torch.device) -> dict:
    """Load a checkpoint and normalise it to the dict layout written by ``save_checkpoint``.

    Three layouts are handled:

    * current format (a dict with ``"state_dict"``): returned unchanged;
    * old fixed-board format (a dict with ``"policy_state_dict"``):
      rejected, because that architecture is incompatible;
    * anything else is treated as a bare ``state_dict``. The channel width
      is read from the first trunk convolution (``trunk.0.weight``) when it
      is present, otherwise 64 is assumed; board size and win length are
      unknown and returned as None.

    NOTE(review): depending on the installed torch version, ``torch.load``
    may unpickle arbitrary objects — only load trusted checkpoint files.

    Raises:
        RuntimeError: for old fixed-board checkpoints.
    """
    checkpoint = torch.load(path, map_location=map_location)
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        return checkpoint
    if isinstance(checkpoint, dict) and "policy_state_dict" in checkpoint:
        raise RuntimeError(
            f"{path} is an old fixed-board checkpoint from the previous implementation. "
            "It is not compatible with the current fully-convolutional actor-critic model. "
            "Please retrain with the current script."
        )
    # A bare state_dict does not record the width; recover it from the
    # first conv's weight, whose shape is (channels, 3, 3, 3), instead of
    # assuming the default.
    channels = 64
    first_conv = checkpoint.get("trunk.0.weight") if isinstance(checkpoint, dict) else None
    if isinstance(first_conv, torch.Tensor) and first_conv.dim() == 4:
        channels = int(first_conv.shape[0])
    return {
        "state_dict": checkpoint,
        "channels": channels,
        "board_size": None,
        "win_length": None,
    }
def load_policy(path: Path, channels: int, device: torch.device) -> PolicyValueNet:
    """Build a ``PolicyValueNet`` from a checkpoint file and switch it to eval mode.

    The channel width stored in the checkpoint takes precedence. ``channels``
    is only used when the loaded dict has no ``"channels"`` entry.
    """
    checkpoint = load_checkpoint(path, map_location=device)
    width = int(checkpoint.get("channels", channels))
    model = PolicyValueNet(channels=width).to(device)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    return model
def resolve_game_config(
    checkpoint_path: Path,
    arg_board_size: int | None,
    arg_win_length: int | None,
    arg_channels: int,
    device: torch.device,
) -> tuple[PolicyValueNet, int, int]:
    """Load the policy and decide which board size and win length to play with.

    Each setting comes from the checkpoint when it is stored there with a
    truthy value. Otherwise the CLI argument is used, and if that is also
    unset, the built-in default (15x15 board, five in a row). Returns
    ``(policy_in_eval_mode, board_size, win_length)``.
    """
    checkpoint = load_checkpoint(checkpoint_path, map_location=device)

    def first_set(*candidates):
        # Same semantics as `a or b or c`: first truthy candidate, else the last one.
        for candidate in candidates[:-1]:
            if candidate:
                return candidate
        return candidates[-1]

    board_size = int(first_set(checkpoint.get("board_size"), arg_board_size, 15))
    win_length = int(first_set(checkpoint.get("win_length"), arg_win_length, 5))
    channels = int(first_set(checkpoint.get("channels"), arg_channels))
    model = PolicyValueNet(channels=channels).to(device)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    return model, board_size, win_length
def play_vs_random_once(
    policy: PolicyValueNet,
    board_size: int,
    win_length: int,
    device: torch.device,
    policy_player: int,
    agent: str = "policy",
    mcts_sims: int = 100,
    c_puct: float = 1.5,
) -> int:
    """Play one game between the AI (as ``policy_player``) and a uniformly random opponent.

    The opponent's moves are drawn with ``np.random.choice`` from the empty
    cells. Returns the winner (1 or -1), or 0 for a draw.
    """
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    env.reset()
    while not env.done:
        if env.current_player != policy_player:
            move = int(np.random.choice(env.valid_moves()))
        else:
            move, _ = choose_ai_action(
                policy=policy,
                board=env.board,
                current_player=env.current_player,
                win_length=win_length,
                device=device,
                agent=agent,
                mcts_sims=mcts_sims,
                c_puct=c_puct,
            )
        env.step(move)
    return env.winner
def evaluate_vs_random(
    policy: PolicyValueNet,
    board_size: int,
    win_length: int,
    device: torch.device,
    games: int,
    agent: str = "policy",
    mcts_sims: int = 100,
    c_puct: float = 1.5,
) -> tuple[float, int, int, int]:
    """Play ``games`` games against a random opponent and tally the results.

    The AI plays first (as player 1) in the first ``games // 2`` games and
    second in the rest. Returns ``(win_rate, wins, draws, losses)``, where
    ``win_rate = wins / games`` (0.0 when ``games`` is 0).
    """
    tally = {"win": 0, "draw": 0, "loss": 0}
    first_half = games // 2
    for game_idx in range(games):
        policy_player = 1 if game_idx < first_half else -1
        winner = play_vs_random_once(
            policy=policy,
            board_size=board_size,
            win_length=win_length,
            device=device,
            policy_player=policy_player,
            agent=agent,
            mcts_sims=mcts_sims,
            c_puct=c_puct,
        )
        if winner == 0:
            tally["draw"] += 1
        elif winner == policy_player:
            tally["win"] += 1
        else:
            tally["loss"] += 1
    return tally["win"] / max(games, 1), tally["win"], tally["draw"], tally["loss"]
def train(args: argparse.Namespace) -> None:
    """Train the policy/value network by self-play and save it to ``args.checkpoint``.

    One actor-critic update is applied every ``args.batch_size`` episodes
    (and after the last episode), using all moves collected since the
    previous update. A progress line is printed every ``args.print_every``
    episodes. When that episode is also a multiple of ``args.eval_every``
    (> 0), the line includes the win rate against a random opponent.

    Raises:
        ValueError: if ``batch_size`` or ``print_every`` is not positive.
    """
    # Both are used as modulo divisors below (print_every also as a deque
    # length); fail early with a clear message instead of ZeroDivisionError.
    if args.batch_size <= 0:
        raise ValueError("--batch-size must be a positive integer")
    if args.print_every <= 0:
        raise ValueError("--print-every must be a positive integer")
    set_seed(args.seed)
    device = choose_device(args.device)
    env = GomokuEnv(board_size=args.board_size, win_length=args.win_length)
    policy = PolicyValueNet(channels=args.channels).to(device)
    if args.init_checkpoint is not None:
        if args.init_checkpoint.exists():
            checkpoint = load_checkpoint(args.init_checkpoint, map_location=device)
            policy.load_state_dict(checkpoint["state_dict"])
        else:
            # Keep the best-effort resume behaviour, but do not fail silently.
            print(
                f"warning: init checkpoint {args.init_checkpoint} not found; "
                "training from scratch"
            )
    optimizer = torch.optim.Adam(policy.parameters(), lr=args.lr)
    # Rolling statistics over the last `print_every` episodes, for logging.
    recent_winners: deque[int] = deque(maxlen=args.print_every)
    recent_lengths: deque[int] = deque(maxlen=args.print_every)
    # Per-move data accumulated until the next policy update.
    batch_log_probs: list[torch.Tensor] = []
    batch_returns: list[float] = []
    batch_entropies: list[torch.Tensor] = []
    batch_values: list[torch.Tensor] = []
    last_loss = 0.0
    print(f"device={device} board={args.board_size} win={args.win_length}")
    for episode in range(1, args.episodes + 1):
        log_probs, returns, entropies, values, winner, moves = self_play_episode(
            policy=policy,
            env=env,
            device=device,
            gamma=args.gamma,
            augment=args.symmetry_augment,
        )
        batch_log_probs.extend(log_probs)
        batch_returns.extend(returns)
        batch_entropies.extend(entropies)
        batch_values.extend(values)
        recent_winners.append(winner)
        recent_lengths.append(moves)
        if episode % args.batch_size == 0 or episode == args.episodes:
            policy.train()
            last_loss = update_policy(
                optimizer=optimizer,
                batch_log_probs=batch_log_probs,
                batch_returns=batch_returns,
                batch_entropies=batch_entropies,
                batch_values=batch_values,
                entropy_coef=args.entropy_coef,
                value_coef=args.value_coef,
                grad_clip=args.grad_clip,
                device=device,
            )
            batch_log_probs.clear()
            batch_returns.clear()
            batch_entropies.clear()
            batch_values.clear()
        if episode % args.print_every == 0 or episode == args.episodes:
            p1_wins = sum(1 for x in recent_winners if x == 1)
            p2_wins = sum(1 for x in recent_winners if x == -1)
            draws = sum(1 for x in recent_winners if x == 0)
            avg_len = float(np.mean(recent_lengths)) if recent_lengths else 0.0
            message = (
                f"episode={episode:6d} loss={last_loss:8.4f} "
                f"p1={p1_wins:4d} p2={p2_wins:4d} draw={draws:4d} avg_len={avg_len:6.2f}"
            )
            if args.eval_every > 0 and episode % args.eval_every == 0:
                policy.eval()
                win_rate, wins, eval_draws, losses = evaluate_vs_random(
                    policy=policy,
                    board_size=args.board_size,
                    win_length=args.win_length,
                    device=device,
                    games=args.eval_games,
                )
                # Return to training mode for the following self-play episodes.
                policy.train()
                message += (
                    f" random_win_rate={win_rate:.3f}"
                    f" ({wins}/{eval_draws}/{losses})"
                )
            print(message)
    save_checkpoint(args.checkpoint, policy, args)
    print(f"saved checkpoint to {args.checkpoint}")
def evaluate(args: argparse.Namespace) -> None:
    """CLI ``eval`` command: play ``args.games`` games against a random opponent and print the results."""
    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    results = evaluate_vs_random(
        policy=policy,
        board_size=board_size,
        win_length=win_length,
        device=device,
        games=args.games,
        agent=args.agent,
        mcts_sims=args.mcts_sims,
        c_puct=args.c_puct,
    )
    win_rate, wins, draws, losses = results
    report = (
        f"device={device}",
        f"agent={args.agent} mcts_sims={args.mcts_sims}",
        f"win_rate={win_rate:.3f} wins={wins} draws={draws} losses={losses}",
    )
    for line in report:
        print(line)
def ask_human_move(env: GomokuEnv) -> int:
    """Prompt on stdin until the user enters a legal 1-based ``row col`` pair.

    A comma also works as the separator. Returns the flat action index of
    the chosen empty cell.
    """
    while True:
        tokens = input("your move (row col): ").strip().replace(",", " ").split()
        if len(tokens) != 2:
            print("please enter: row col")
            continue
        try:
            row = int(tokens[0]) - 1
            col = int(tokens[1]) - 1
        except ValueError:
            print("row and col must be integers")
            continue
        in_range = 0 <= row < env.board_size and 0 <= col < env.board_size
        if not in_range:
            print("move out of range")
        elif env.board[row, col] != 0:
            print("that position is occupied")
        else:
            return row * env.board_size + col
def play(args: argparse.Namespace) -> None:
    """CLI ``play`` command: a terminal game between a human and the model.

    The human plays X and moves first with ``--human-first``; otherwise the
    human plays O.
    """
    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    human_player = 1 if args.human_first else -1
    human_symbol, ai_symbol = ("X", "O") if human_player == 1 else ("O", "X")
    print(f"device={device}")
    print(
        f"human={human_symbol} ai={ai_symbol} "
        f"agent={args.agent} mcts_sims={args.mcts_sims}"
    )
    while not env.done:
        print()
        print(env.render())
        print()
        if env.current_player == human_player:
            action = ask_human_move(env)
        else:
            action, _ = choose_ai_action(
                policy=policy,
                board=env.board,
                current_player=env.current_player,
                win_length=win_length,
                device=device,
                agent=args.agent,
                mcts_sims=args.mcts_sims,
                c_puct=args.c_puct,
            )
            ai_row, ai_col = divmod(action, env.board_size)
            print(f"ai move: {ai_row + 1} {ai_col + 1}")
        env.step(action)
    print()
    print(env.render())
    # human_player is never 0, so the two keys cannot collide.
    verdict = {0: "draw", human_player: "you win"}.get(env.winner, "ai wins")
    print(verdict)
def gui(args: argparse.Namespace) -> None:
    """CLI ``gui`` command: a pygame window for playing against the model with the mouse.

    Controls: a left click places a stone, R restarts, and Esc or closing
    the window quits. The AI search runs synchronously in the event loop,
    so the window is redrawn right before each AI move. This shows the
    human's latest stone and the "AI is thinking..." status while the
    search runs.
    """
    try:
        import pygame
    except ModuleNotFoundError as exc:
        raise SystemExit(
            "pygame is not installed. Install it with: "
            "python -m pip install pygame"
        ) from exc
    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    human_player = 1 if args.human_first else -1
    # Root visit counts from the most recent AI move (None for the greedy agent).
    last_search_visits: np.ndarray | None = None
    pygame.init()
    try:
        pygame.display.set_caption("Gomoku Policy Gradient")
        font = pygame.font.SysFont("Arial", 24)
        small_font = pygame.font.SysFont("Arial", 18)
        cell_size = args.cell_size
        padding = 40
        status_height = 80
        board_pixels = board_size * cell_size
        screen = pygame.display.set_mode(
            (board_pixels + padding * 2, board_pixels + padding * 2 + status_height)
        )
        clock = pygame.time.Clock()
        background = (236, 196, 122)
        line_color = (80, 55, 20)
        black_stone = (20, 20, 20)
        white_stone = (245, 245, 245)
        accent = (180, 40, 40)

        def board_to_screen(row: int, col: int) -> tuple[int, int]:
            """Pixel centre of board cell ``(row, col)``."""
            x = padding + col * cell_size + cell_size // 2
            y = padding + row * cell_size + cell_size // 2
            return x, y

        def mouse_to_action(pos: tuple[int, int]) -> int | None:
            """Flat action under the mouse, or None if off the board or occupied."""
            x, y = pos
            left = padding
            top = padding
            if x < left or y < top:
                return None
            col = (x - left) // cell_size
            row = (y - top) // cell_size
            if not (0 <= row < env.board_size and 0 <= col < env.board_size):
                return None
            if env.board[row, col] != 0:
                return None
            return row * env.board_size + col

        def ai_step() -> None:
            """Let the AI move if the game is running and it is the AI's turn."""
            nonlocal last_search_visits
            if env.done or env.current_player == human_player:
                return
            action, visits = choose_ai_action(
                policy=policy,
                board=env.board,
                current_player=env.current_player,
                win_length=win_length,
                device=device,
                agent=args.agent,
                mcts_sims=args.mcts_sims,
                c_puct=args.c_puct,
            )
            last_search_visits = visits
            env.step(action)

        def restart() -> None:
            """Reset the game; if the AI moves first, show the empty board, then let it play."""
            nonlocal last_search_visits
            env.reset()
            last_search_visits = None
            if env.current_player != human_player:
                draw()
                ai_step()

        def status_text() -> str:
            if env.done:
                if env.winner == 0:
                    return "Draw. Press R to restart."
                if env.winner == human_player:
                    return "You win. Press R to restart."
                return "AI wins. Press R to restart."
            if env.current_player == human_player:
                return "Your turn. Left click to place."
            return "AI is thinking..."

        def draw() -> None:
            """Render the grid, stones, labels and status bar, then flip the display."""
            screen.fill(background)
            for idx in range(board_size + 1):
                x = padding + idx * cell_size
                pygame.draw.line(screen, line_color, (x, padding), (x, padding + board_pixels), 2)
                y = padding + idx * cell_size
                pygame.draw.line(screen, line_color, (padding, y), (padding + board_pixels, y), 2)
            for row in range(env.board_size):
                for col in range(env.board_size):
                    stone = int(env.board[row, col])
                    if stone == 0:
                        continue
                    x, y = board_to_screen(row, col)
                    color = black_stone if stone == 1 else white_stone
                    pygame.draw.circle(screen, color, (x, y), cell_size // 2 - 4)
                    pygame.draw.circle(screen, line_color, (x, y), cell_size // 2 - 4, 1)
            for idx in range(board_size):
                label = small_font.render(str(idx + 1), True, line_color)
                screen.blit(
                    label,
                    (padding + idx * cell_size + cell_size // 2 - label.get_width() // 2, 8),
                )
                screen.blit(
                    label,
                    (8, padding + idx * cell_size + cell_size // 2 - label.get_height() // 2),
                )
            info = (
                f"{board_size}x{board_size} connect={win_length} "
                f"device={device} human={'X' if human_player == 1 else 'O'} "
                f"agent={args.agent}"
            )
            info_surface = small_font.render(info, True, line_color)
            status_surface = font.render(status_text(), True, accent)
            screen.blit(info_surface, (padding, padding + board_pixels + 16))
            screen.blit(status_surface, (padding, padding + board_pixels + 42))
            if last_search_visits is not None and args.agent == "mcts":
                peak = float(np.max(last_search_visits))
                if peak > 0:
                    stats = small_font.render(
                        f"mcts_sims={args.mcts_sims} peak_visits={int(peak)}",
                        True,
                        line_color,
                    )
                    # Place the stats just after the info text instead of at a
                    # fixed offset, so the two never overlap.
                    stats_x = padding + info_surface.get_width() + 16
                    screen.blit(stats, (stats_x, padding + board_pixels + 16))
            pygame.display.flip()

        if env.current_player != human_player:
            draw()
            ai_step()
        running = True
        while running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_ESCAPE:
                        running = False
                    elif event.key == pygame.K_r:
                        restart()
                elif event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
                    if env.done or env.current_player != human_player:
                        continue
                    action = mouse_to_action(event.pos)
                    if action is None:
                        continue
                    env.step(action)
                    # Show the human's stone and "AI is thinking..." before
                    # the blocking AI search starts.
                    draw()
                    ai_step()
            draw()
            clock.tick(args.fps)
    finally:
        # Always release the window, even if the search or drawing raises.
        pygame.quit()
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser with ``train``, ``eval``, ``play`` and ``gui`` sub-commands.

    Each sub-command stores its handler function in ``args.func``. ``play``
    and ``gui`` default ``--board-size`` / ``--win-length`` to None so the
    checkpoint's values are used; ``train`` and ``eval`` default to 15 and 5.
    """
    parser = argparse.ArgumentParser(description="Minimal Gomoku policy gradient example")
    subparsers = parser.add_subparsers(dest="mode", required=True)

    def add_common_arguments(
        subparser: argparse.ArgumentParser, defaults_from_checkpoint: bool = False
    ) -> None:
        # Board settings default to None when the checkpoint should supply them.
        if defaults_from_checkpoint:
            board_default, win_default = None, None
        else:
            board_default, win_default = 15, 5
        subparser.add_argument("--board-size", type=int, default=board_default)
        subparser.add_argument("--win-length", type=int, default=win_default)
        subparser.add_argument("--channels", type=int, default=64)
        subparser.add_argument("--device", choices=["auto", "cpu", "cuda", "mps"], default="auto")
        subparser.add_argument("--checkpoint", type=Path, default=Path("gomoku_policy.pt"))

    def add_inference_arguments(
        subparser: argparse.ArgumentParser, default_agent: str = "mcts"
    ) -> None:
        # Options controlling how the AI chooses moves at play/eval time.
        subparser.add_argument("--agent", choices=["policy", "mcts"], default=default_agent)
        subparser.add_argument("--mcts-sims", type=int, default=120)
        subparser.add_argument("--c-puct", type=float, default=1.5)

    # train: self-play optimisation hyper-parameters.
    train_parser = subparsers.add_parser("train", help="self-play training")
    add_common_arguments(train_parser)
    train_options = (
        ("--episodes", int, 5000),
        ("--batch-size", int, 32),
        ("--lr", float, 1e-3),
        ("--gamma", float, 0.99),
        ("--entropy-coef", float, 0.01),
        ("--value-coef", float, 0.5),
        ("--grad-clip", float, 1.0),
        ("--print-every", int, 100),
        ("--eval-every", int, 500),
        ("--eval-games", int, 40),
        ("--seed", int, 42),
    )
    for flag, value_type, default in train_options:
        train_parser.add_argument(flag, type=value_type, default=default)
    train_parser.add_argument("--init-checkpoint", type=Path, default=None)
    train_parser.add_argument(
        "--no-symmetry-augment",
        dest="symmetry_augment",
        action="store_false",
        help="disable random rotation/flip augmentation during training",
    )
    train_parser.set_defaults(symmetry_augment=True, func=train)

    # eval: games against a random opponent.
    eval_parser = subparsers.add_parser("eval", help="evaluate against random agent")
    add_common_arguments(eval_parser)
    eval_parser.add_argument("--games", type=int, default=100)
    add_inference_arguments(eval_parser)
    eval_parser.set_defaults(func=evaluate)

    # play: terminal game against the model.
    play_parser = subparsers.add_parser("play", help="play against the trained model")
    add_common_arguments(play_parser, defaults_from_checkpoint=True)
    play_parser.add_argument("--human-first", action="store_true", help="human plays X")
    add_inference_arguments(play_parser)
    play_parser.set_defaults(func=play)

    # gui: pygame front-end.
    gui_parser = subparsers.add_parser("gui", help="pygame GUI for testing against the model")
    add_common_arguments(gui_parser, defaults_from_checkpoint=True)
    gui_parser.add_argument("--human-first", action="store_true", help="human plays X")
    gui_parser.add_argument("--cell-size", type=int, default=48)
    gui_parser.add_argument("--fps", type=int, default=30)
    add_inference_arguments(gui_parser)
    gui_parser.set_defaults(func=gui)
    return parser
def main() -> None:
    """Parse the command line and run the selected sub-command's handler."""
    args = build_parser().parse_args()
    args.func(args)


if __name__ == "__main__":
    main()