| |
| """Minimal Gomoku policy gradient example. |
| |
| Features: |
| 1. Configurable board size and win length, e.g. 5x5 connect-4 or 15x15 connect-5. |
| 2. Shared-policy self-play with REINFORCE, using a learned value head as the baseline. |
| 3. Fully convolutional policy, so the same code works for different board sizes. |
| 4. Optional random-agent evaluation, CLI human play and a pygame GUI, with a greedy-policy or MCTS agent. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import math |
| import random |
| from collections import deque |
| from dataclasses import dataclass, field |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from torch import nn |
| from torch.distributions import Categorical |
|
|
|
|
def choose_device(name: str) -> torch.device:
    """Turn a device name into a ``torch.device``.

    Any name other than ``"auto"`` is handed to ``torch.device`` unchanged.
    ``"auto"`` prefers CUDA, then MPS, and falls back to the CPU.
    """
    if name == "auto":
        mps_backend = getattr(torch.backends, "mps", None)
        if torch.cuda.is_available():
            name = "cuda"
        elif mps_backend is not None and mps_backend.is_available():
            name = "mps"
        else:
            name = "cpu"
    return torch.device(name)
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    # CUDA generators are only seeded when a CUDA device is present.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
|
|
class GomokuEnv:
    """Connect-N game on a square board for two alternating players.

    Cells hold ``1`` (first player), ``-1`` (second player) or ``0`` (empty).
    Actions are flat indices ``row * board_size + col``.
    """

    def __init__(self, board_size: int, win_length: int):
        """Validate the configuration and start a fresh game.

        Raises:
            ValueError: if ``board_size <= 1`` or ``win_length`` is not in
                the range ``1 < win_length <= board_size``.
        """
        if board_size <= 1:
            raise ValueError("board_size must be > 1")
        if not 1 < win_length <= board_size:
            raise ValueError("win_length must satisfy 1 < win_length <= board_size")
        self.board_size = board_size
        self.win_length = win_length
        self.reset()

    def reset(self) -> np.ndarray:
        """Clear the board, give the move to player 1 and return the board."""
        self.board = np.zeros((self.board_size, self.board_size), dtype=np.int8)
        self.current_player = 1
        self.done = False
        self.winner = 0
        return self.board

    def legal_mask(self) -> np.ndarray:
        """Return a boolean array that is True on empty cells."""
        return self.board == 0

    def valid_moves(self) -> np.ndarray:
        """Return the flat indices of all empty cells."""
        return np.flatnonzero(self.legal_mask().reshape(-1))

    def step(self, action: int) -> tuple[bool, int]:
        """Place the current player's stone on the cell given by ``action``.

        Returns:
            ``(done, winner)``. ``winner`` is 0 while the game is running and
            on a draw. The turn only passes to the opponent when the game
            continues.

        Raises:
            RuntimeError: if the game is already finished.
            ValueError: if ``action`` is outside the board or the cell is
                occupied.
        """
        if self.done:
            raise RuntimeError("game is already finished")

        action = int(action)
        # Without this check a negative action would wrap around through
        # NumPy's negative indexing and silently land on the wrong cell.
        if not 0 <= action < self.board_size * self.board_size:
            raise ValueError(
                f"action {action} is out of range for a "
                f"{self.board_size}x{self.board_size} board"
            )

        row, col = divmod(action, self.board_size)
        if self.board[row, col] != 0:
            raise ValueError(f"illegal move at ({row}, {col})")

        player = self.current_player
        self.board[row, col] = player

        if self._is_winning_move(row, col, player):
            self.done = True
            self.winner = player
        elif not np.any(self.board == 0):
            # Board is full and nobody completed a line: draw.
            self.done = True
            self.winner = 0
        else:
            self.current_player = -player

        return self.done, self.winner

    def _is_winning_move(self, row: int, col: int, player: int) -> bool:
        """Check whether the stone at ``(row, col)`` completes a line."""
        # One direction per axis; the opposite side is covered by negation.
        directions = ((1, 0), (0, 1), (1, 1), (1, -1))
        for dr, dc in directions:
            count = 1
            count += self._count_one_side(row, col, dr, dc, player)
            count += self._count_one_side(row, col, -dr, -dc, player)
            if count >= self.win_length:
                return True
        return False

    def _count_one_side(self, row: int, col: int, dr: int, dc: int, player: int) -> int:
        """Count consecutive ``player`` stones from ``(row, col)`` along ``(dr, dc)``.

        The starting cell itself is not counted.
        """
        total = 0
        r, c = row + dr, col + dc
        while 0 <= r < self.board_size and 0 <= c < self.board_size:
            if self.board[r, c] != player:
                break
            total += 1
            r += dr
            c += dc
        return total

    def render(self) -> str:
        """Return a text drawing of the board with 1-based row/column labels."""
        symbols = {1: "X", -1: "O", 0: "."}
        header = "   " + " ".join(f"{i + 1:2d}" for i in range(self.board_size))
        rows = [header]
        for row_idx in range(self.board_size):
            row = " ".join(f"{symbols[int(v)]:>2}" for v in self.board[row_idx])
            rows.append(f"{row_idx + 1:2d} {row}")
        return "\n".join(rows)
|
|
|
|
def encode_state(board: np.ndarray, current_player: int) -> torch.Tensor:
    """Encode a board as three float32 planes from the mover's point of view.

    The planes, stacked on axis 0, mark the current player's stones, the
    opponent's stones and the empty cells.
    """
    masks = (board == current_player, board == -current_player, board == 0)
    planes = [mask.astype(np.float32) for mask in masks]
    return torch.from_numpy(np.stack(planes, axis=0))
|
|
|
|
class PolicyValueNet(nn.Module):
    """Fully convolutional network with a per-cell policy head and a value head.

    The input has 3 planes. The value head pools over the whole board, so
    the module places no constraint on the board's spatial size.
    """

    def __init__(self, channels: int = 64):
        super().__init__()
        # Three 3x3 conv + ReLU stages; padding keeps the spatial size.
        stages = []
        width_in = 3
        for _ in range(3):
            stages.append(nn.Conv2d(width_in, channels, kernel_size=3, padding=1))
            stages.append(nn.ReLU())
            width_in = channels
        self.trunk = nn.Sequential(*stages)
        # 1x1 conv: one logit per board cell.
        self.policy_head = nn.Conv2d(channels, 1, kernel_size=1)
        # Global average pool, then a small MLP squashed by tanh.
        self.value_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(channels, channels),
            nn.ReLU(),
            nn.Linear(channels, 1),
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(policy_logits, value)`` for a batch of encoded boards.

        ``policy_logits`` is flattened to one logit per cell per sample and
        ``value`` has its trailing singleton dimension removed.
        """
        hidden = self.trunk(x)
        per_cell = self.policy_head(hidden)
        scalar = self.value_head(hidden)
        return per_cell.flatten(start_dim=1), scalar.squeeze(-1)
|
|
|
|
def masked_logits(logits: torch.Tensor, legal_mask: np.ndarray) -> torch.Tensor:
    """Return ``logits`` with every illegal entry replaced by ``-1e9``.

    ``legal_mask`` is flattened and moved to the device of ``logits``.
    """
    flat_mask = legal_mask.reshape(-1)
    legal = torch.as_tensor(flat_mask, device=logits.device, dtype=torch.bool)
    illegal = torch.logical_not(legal)
    return logits.masked_fill(illegal, -1e9)
|
|
|
|
def transform_board(board: np.ndarray, rotation_k: int, flip: bool) -> np.ndarray:
    """Rotate ``board`` by ``rotation_k`` quarter turns, then optionally mirror it.

    The rotation is ``np.rot90``; the mirror reverses the column order. The
    result is returned as a C-contiguous array.
    """
    rotated = np.rot90(board, k=rotation_k)
    oriented = rotated[:, ::-1] if flip else rotated
    return np.ascontiguousarray(oriented)
|
|
|
|
def action_to_coords(action: int, board_size: int) -> tuple[int, int]:
    """Split a flat action index into ``(row, col)``."""
    flat = int(action)
    return flat // board_size, flat % board_size
|
|
|
|
def coords_to_action(row: int, col: int, board_size: int) -> int:
    """Combine ``(row, col)`` into a row-major flat action index."""
    row_offset = board_size * row
    return row_offset + col
|
|
|
|
def count_one_side(
    board: np.ndarray,
    row: int,
    col: int,
    dr: int,
    dc: int,
    player: int,
) -> int:
    """Count consecutive ``player`` stones from ``(row, col)`` along ``(dr, dc)``.

    The starting cell is not counted. The walk stops at the board edge or at
    the first cell that does not hold ``player``. Both bounds use
    ``board.shape[0]``.
    """
    size = board.shape[0]
    run = 0
    r, c = row + dr, col + dc
    while 0 <= r < size and 0 <= c < size and board[r, c] == player:
        run += 1
        r, c = r + dr, c + dc
    return run
|
|
|
|
def is_winning_move(
    board: np.ndarray,
    row: int,
    col: int,
    player: int,
    win_length: int,
) -> bool:
    """Report whether the stone at ``(row, col)`` gives ``player`` a line.

    For each of the four axes, the stone itself plus the runs on both sides
    are added up and compared against ``win_length``.
    """
    axes = ((1, 0), (0, 1), (1, 1), (1, -1))
    return any(
        1
        + count_one_side(board, row, col, dr, dc, player)
        + count_one_side(board, row, col, -dr, -dc, player)
        >= win_length
        for dr, dc in axes
    )
|
|
|
|
def apply_action_to_board(
    board: np.ndarray,
    current_player: int,
    action: int,
    win_length: int,
) -> tuple[np.ndarray, int, bool, int]:
    """Play ``action`` on a copy of ``board`` without touching the input.

    Returns:
        ``(next_board, next_player, done, winner)``. ``next_player`` is
        always ``-current_player``, including for terminal positions.
        ``winner`` is 0 for a draw or an unfinished game.

    Raises:
        ValueError: if the target cell is already occupied.
    """
    row, col = action_to_coords(action, board.shape[0])
    if board[row, col] != 0:
        raise ValueError(f"illegal move at ({row}, {col})")

    next_board = board.copy()
    next_board[row, col] = current_player
    opponent = -current_player

    if is_winning_move(next_board, row, col, current_player, win_length):
        done, winner = True, current_player
    elif np.any(next_board == 0):
        done, winner = False, 0
    else:
        # No empty cell left and no line completed: draw.
        done, winner = True, 0
    return next_board, opponent, done, winner
|
|
|
|
def forward_transform_coords(
    row: int,
    col: int,
    board_size: int,
    rotation_k: int,
    flip: bool,
) -> tuple[int, int]:
    """Map a cell to its position after the rotate-then-flip transform.

    Applies ``rotation_k % 4`` quarter-turn steps ``(r, c) -> (n-1-c, r)``
    followed by an optional column mirror.
    """
    last = board_size - 1
    turns_left = rotation_k % 4
    while turns_left > 0:
        row, col = last - col, row
        turns_left -= 1
    return (row, last - col) if flip else (row, col)
|
|
|
|
def inverse_transform_coords(
    row: int,
    col: int,
    board_size: int,
    rotation_k: int,
    flip: bool,
) -> tuple[int, int]:
    """Map a cell of the transformed board back to the original board.

    Undoes the optional column mirror first, then applies ``rotation_k % 4``
    steps ``(r, c) -> (c, n-1-r)``.
    """
    last = board_size - 1
    if flip:
        col = last - col
    turns_left = rotation_k % 4
    while turns_left > 0:
        row, col = col, last - row
        turns_left -= 1
    return row, col
|
|
|
|
def sample_action(
    policy: PolicyValueNet,
    board: np.ndarray,
    current_player: int,
    device: torch.device,
    greedy: bool,
    augment: bool,
) -> tuple[int, torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
    """Pick a move for ``current_player`` from the policy network.

    With ``augment`` the board is shown to the network under a random
    rotation/flip and the chosen cell is mapped back to original coordinates.

    Args:
        greedy: take the arg-max move instead of sampling. This is the
            inference path, so no autograd graph is built for it.
        augment: apply a random symmetry before the forward pass.

    Returns:
        ``(action, log_prob, entropy, value)``. ``log_prob`` and ``entropy``
        are ``None`` in greedy mode. ``action`` is a flat index on the
        untransformed board.
    """
    board_size = board.shape[0]
    rotation_k = random.randint(0, 3) if augment else 0
    flip = bool(random.getrandbits(1)) if augment else False
    transformed_board = transform_board(board, rotation_k=rotation_k, flip=flip)

    state = encode_state(transformed_board, current_player).unsqueeze(0).to(device)

    # Greedy selection never feeds a loss, so do not track gradients for it.
    # An enclosing no_grad context is respected in the sampling case too.
    track_grad = torch.is_grad_enabled() and not greedy
    with torch.set_grad_enabled(track_grad):
        logits, value = policy(state)
        logits = masked_logits(logits.squeeze(0), transformed_board == 0)

        log_prob = None
        entropy = None
        if greedy:
            chosen = torch.argmax(logits)
        else:
            dist = Categorical(logits=logits)
            chosen = dist.sample()
            log_prob = dist.log_prob(chosen)
            entropy = dist.entropy()

    # The network saw the transformed board; translate the pick back.
    transformed_row, transformed_col = action_to_coords(int(chosen.item()), board_size)
    row, col = inverse_transform_coords(
        transformed_row,
        transformed_col,
        board_size,
        rotation_k=rotation_k,
        flip=flip,
    )
    return coords_to_action(row, col, board_size), log_prob, entropy, value.squeeze(0)
|
|
|
|
def evaluate_policy_value(
    policy: PolicyValueNet,
    board: np.ndarray,
    current_player: int,
    device: torch.device,
) -> tuple[np.ndarray, float]:
    """Run the network once, without gradients, on an untransformed board.

    Returns:
        ``(probs, value)``: the softmax over the masked logits as a flat
        NumPy array, and the value head output as a Python float.
    """
    batch = encode_state(board, current_player).unsqueeze(0).to(device)
    with torch.no_grad():
        raw_logits, value = policy(batch)
    legal_logits = masked_logits(raw_logits.squeeze(0), board == 0)
    move_probs = torch.softmax(legal_logits, dim=0)
    return move_probs.detach().cpu().numpy(), float(value.item())
|
|
|
|
@dataclass
class MCTSNode:
    """Search-tree node holding a position and per-action statistics."""

    board: np.ndarray
    current_player: int
    win_length: int
    done: bool = False
    winner: int = 0
    # Per-action statistics, keyed by flat action index.
    priors: dict[int, float] = field(default_factory=dict)
    visit_counts: dict[int, int] = field(default_factory=dict)
    value_sums: dict[int, float] = field(default_factory=dict)
    children: dict[int, "MCTSNode"] = field(default_factory=dict)
    expanded: bool = False

    def expand(self, priors: np.ndarray) -> None:
        """Store priors renormalised over the empty cells and reset the stats.

        Falls back to a uniform prior when the legal entries sum to zero or
        less.
        """
        legal_actions = np.flatnonzero(self.board.reshape(-1) == 0)
        total_prob = float(np.sum(priors[legal_actions]))
        if total_prob <= 0.0:
            share = 1.0 / max(len(legal_actions), 1)
            self.priors = {int(move): share for move in legal_actions}
        else:
            self.priors = {
                int(move): float(priors[move] / total_prob)
                for move in legal_actions
            }
        self.visit_counts = dict.fromkeys(self.priors, 0)
        self.value_sums = dict.fromkeys(self.priors, 0.0)
        self.expanded = True

    def q_value(self, action: int) -> float:
        """Return the mean backed-up value of ``action`` (0.0 if unvisited)."""
        count = self.visit_counts[action]
        return self.value_sums[action] / count if count != 0 else 0.0

    def select_action(self, c_puct: float) -> int:
        """Return the action with the highest ``Q + U`` score.

        ``U = c_puct * prior * sqrt(total_visits + 1) / (1 + visits)``.
        The first action wins ties; ``-1`` is returned when no action scores
        above negative infinity.
        """
        sqrt_total = math.sqrt(sum(self.visit_counts.values()) + 1.0)
        chosen = -1
        top_score = -float("inf")

        for move, prior in self.priors.items():
            count = self.visit_counts[move]
            exploration = c_puct * prior * sqrt_total / (1.0 + count)
            candidate = self.q_value(move) + exploration
            if candidate > top_score:
                top_score, chosen = candidate, move

        return chosen

    def child_for_action(self, action: int) -> "MCTSNode":
        """Return the child reached by ``action``, creating it on first use."""
        cached = self.children.get(action)
        if cached is None:
            next_board, next_player, done, winner = apply_action_to_board(
                board=self.board,
                current_player=self.current_player,
                action=action,
                win_length=self.win_length,
            )
            cached = MCTSNode(
                board=next_board,
                current_player=next_player,
                win_length=self.win_length,
                done=done,
                winner=winner,
            )
            self.children[action] = cached
        return cached
|
|
|
|
def terminal_value(winner: int, current_player: int) -> float:
    """Score a finished game for ``current_player``: 1.0 win, -1.0 loss, 0.0 draw."""
    if winner == 0:
        return 0.0
    if winner == current_player:
        return 1.0
    return -1.0
|
|
|
|
def choose_mcts_action(
    policy: PolicyValueNet,
    board: np.ndarray,
    current_player: int,
    win_length: int,
    device: torch.device,
    num_simulations: int,
    c_puct: float,
) -> tuple[int, np.ndarray]:
    """Run a network-guided tree search and return the most visited move.

    Returns:
        ``(best_action, visits)`` where ``visits`` holds the root visit
        counts reshaped to the board shape. If no root action was visited,
        the arg-max of the root network priors is returned instead.
    """
    root = MCTSNode(
        board=board.copy(),
        current_player=current_player,
        win_length=win_length,
    )
    root_priors, _ = evaluate_policy_value(policy, root.board, root.current_player, device)
    root.expand(root_priors)

    for _ in range(num_simulations):
        # Selection: descend until an unexpanded or terminal node is reached.
        leaf = root
        trail: list[tuple[MCTSNode, int]] = []
        while leaf.expanded and not leaf.done:
            move = leaf.select_action(c_puct)
            trail.append((leaf, move))
            leaf = leaf.child_for_action(move)

        # Evaluation: exact result at terminal nodes, network value otherwise.
        if leaf.done:
            value = terminal_value(leaf.winner, leaf.current_player)
        else:
            leaf_priors, value = evaluate_policy_value(
                policy, leaf.board, leaf.current_player, device
            )
            leaf.expand(leaf_priors)

        # Backup: the sign flips at every ply because the players alternate.
        for parent, move in reversed(trail):
            value = -value
            parent.visit_counts[move] += 1
            parent.value_sums[move] += value

    visits = np.zeros(board.size, dtype=np.float32)
    for move, count in root.visit_counts.items():
        visits[move] = float(count)

    ranking = visits if visits.any() else root_priors
    return int(np.argmax(ranking)), visits.reshape(board.shape)
|
|
|
|
def choose_ai_action(
    policy: PolicyValueNet,
    board: np.ndarray,
    current_player: int,
    win_length: int,
    device: torch.device,
    agent: str,
    mcts_sims: int,
    c_puct: float,
) -> tuple[int, np.ndarray | None]:
    """Select the AI's move with either the MCTS or the greedy-policy agent.

    Returns:
        ``(action, visits)``. ``visits`` is the root visit-count array for
        ``agent == "mcts"`` and ``None`` for any other agent value, which
        falls through to the greedy policy.
    """
    if agent != "mcts":
        greedy_action, *_ = sample_action(
            policy=policy,
            board=board,
            current_player=current_player,
            device=device,
            greedy=True,
            augment=False,
        )
        return greedy_action, None

    return choose_mcts_action(
        policy=policy,
        board=board,
        current_player=current_player,
        win_length=win_length,
        device=device,
        num_simulations=mcts_sims,
        c_puct=c_puct,
    )
|
|
|
|
def self_play_episode(
    policy: PolicyValueNet,
    env: GomokuEnv,
    device: torch.device,
    gamma: float,
    augment: bool,
) -> tuple[list[torch.Tensor], list[float], list[torch.Tensor], list[torch.Tensor], int, int]:
    """Play one full game with the policy moving for both sides.

    Returns:
        ``(log_probs, returns, entropies, values, winner, total_moves)``.
        Each move's return is +1/-1/0 from the mover's point of view,
        multiplied by ``gamma`` raised to the number of moves that followed.
    """
    env.reset()
    movers: list[int] = []
    log_probs: list[torch.Tensor] = []
    entropies: list[torch.Tensor] = []
    values: list[torch.Tensor] = []

    while not env.done:
        mover = env.current_player
        action, log_prob, entropy, value = sample_action(
            policy=policy,
            board=env.board,
            current_player=mover,
            device=device,
            greedy=False,
            augment=augment,
        )
        movers.append(mover)
        log_probs.append(log_prob)
        entropies.append(entropy)
        values.append(value)
        env.step(action)

    total_moves = len(movers)
    final_winner = env.winner
    returns: list[float] = []
    for move_idx, mover in enumerate(movers):
        if final_winner == 0:
            outcome = 0.0
        elif mover == final_winner:
            outcome = 1.0
        else:
            outcome = -1.0
        moves_after = total_moves - move_idx - 1
        returns.append(outcome * (gamma ** moves_after))

    return log_probs, returns, entropies, values, env.winner, total_moves
|
|
|
|
def update_policy(
    optimizer: torch.optim.Optimizer,
    batch_log_probs: list[torch.Tensor],
    batch_returns: list[float],
    batch_entropies: list[torch.Tensor],
    batch_values: list[torch.Tensor],
    entropy_coef: float,
    value_coef: float,
    grad_clip: float,
    device: torch.device,
) -> float:
    """Take one policy-gradient step with a value baseline over a batch of moves.

    The loss is ``policy_loss + value_coef * value_loss - entropy_coef *
    mean_entropy``. Advantages are ``returns - values`` (values detached) and
    are standardised when the batch has more than one element.

    Args:
        grad_clip: maximum global gradient norm, applied across the
            parameters of every optimizer parameter group.

    Returns:
        The scalar loss as a Python float.
    """
    returns = torch.tensor(batch_returns, dtype=torch.float32, device=device)
    log_probs = torch.stack(batch_log_probs)
    entropies = torch.stack(batch_entropies)
    values = torch.stack(batch_values)

    # The baseline must not receive gradient through the policy term.
    advantages = returns - values.detach()
    if advantages.numel() > 1:
        advantages = (advantages - advantages.mean()) / (advantages.std(unbiased=False) + 1e-6)

    policy_loss = -(log_probs * advantages).mean()
    value_loss = torch.mean((values - returns) ** 2)
    entropy_bonus = entropies.mean()
    loss = policy_loss + value_coef * value_loss - entropy_coef * entropy_bonus

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    # Clip over all parameter groups; looking only at group 0 would leave the
    # parameters of any additional group unclipped.
    clipped_params = [
        param for group in optimizer.param_groups for param in group["params"]
    ]
    nn.utils.clip_grad_norm_(clipped_params, grad_clip)
    optimizer.step()
    return float(loss.item())
|
|
|
|
def save_checkpoint(
    path: Path,
    policy: PolicyValueNet,
    args: argparse.Namespace,
) -> None:
    """Write the model weights and game configuration to ``path``.

    The payload holds ``state_dict``, ``channels``, ``board_size`` and
    ``win_length``. Missing parent directories are created first, so a
    finished training run is not lost to a non-existent output folder.
    """
    payload = {
        "state_dict": policy.state_dict(),
        "channels": args.channels,
        "board_size": args.board_size,
        "win_length": args.win_length,
    }
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(payload, path)
|
|
|
|
def load_checkpoint(path: Path, map_location: torch.device) -> dict:
    """Load a checkpoint and normalise it to the current payload layout.

    A dict containing ``"state_dict"`` is returned unchanged. Anything else,
    except the rejected legacy layout, is treated as a bare state dict and
    wrapped with ``channels=64`` and unknown board settings.

    Raises:
        RuntimeError: if the file uses the legacy ``"policy_state_dict"``
            layout.
    """
    # NOTE(review): torch.load unpickles the file; only open trusted checkpoints.
    loaded = torch.load(path, map_location=map_location)
    is_mapping = isinstance(loaded, dict)

    if is_mapping and "state_dict" in loaded:
        return loaded

    if not (is_mapping and "policy_state_dict" in loaded):
        return {
            "state_dict": loaded,
            "channels": 64,
            "board_size": None,
            "win_length": None,
        }

    raise RuntimeError(
        f"{path} is an old fixed-board checkpoint from the previous implementation. "
        "It is not compatible with the current fully-convolutional actor-critic model. "
        "Please retrain with the current script."
    )
|
|
|
|
def load_policy(path: Path, channels: int, device: torch.device) -> PolicyValueNet:
    """Build a network from a checkpoint and return it in eval mode.

    The channel count stored in the checkpoint wins; ``channels`` is used
    only when the checkpoint has no ``"channels"`` entry.
    """
    checkpoint = load_checkpoint(path, map_location=device)
    width = int(checkpoint.get("channels", channels))

    network = PolicyValueNet(channels=width).to(device)
    network.load_state_dict(checkpoint["state_dict"])
    network.eval()
    return network
|
|
|
|
def resolve_game_config(
    checkpoint_path: Path,
    arg_board_size: int | None,
    arg_win_length: int | None,
    arg_channels: int,
    device: torch.device,
) -> tuple[PolicyValueNet, int, int]:
    """Load the model and decide which board size and win length to use.

    For every setting, a truthy value stored in the checkpoint takes
    precedence over the corresponding argument. The final fallbacks are a
    15x15 board and a win length of 5.

    Returns:
        ``(policy, board_size, win_length)`` with ``policy`` in eval mode.
    """
    checkpoint = load_checkpoint(checkpoint_path, map_location=device)

    def pick(key: str, *fallbacks):
        # First truthy value among the checkpoint entry and the fallbacks;
        # the last fallback is returned even when it is falsy, like `or`.
        chosen = checkpoint.get(key)
        for candidate in fallbacks:
            if chosen:
                break
            chosen = candidate
        return int(chosen)

    board_size = pick("board_size", arg_board_size, 15)
    win_length = pick("win_length", arg_win_length, 5)
    channels = pick("channels", arg_channels)

    policy = PolicyValueNet(channels=channels).to(device)
    policy.load_state_dict(checkpoint["state_dict"])
    policy.eval()
    return policy, board_size, win_length
|
|
|
|
def play_vs_random_once(
    policy: PolicyValueNet,
    board_size: int,
    win_length: int,
    device: torch.device,
    policy_player: int,
    agent: str = "policy",
    mcts_sims: int = 100,
    c_puct: float = 1.5,
) -> int:
    """Play one game of the AI against a uniformly random opponent.

    ``policy_player`` (1 or -1) is the side the AI plays. Returns the
    winner: 1, -1, or 0 for a draw.
    """
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    env.reset()

    while not env.done:
        ai_to_move = env.current_player == policy_player
        if not ai_to_move:
            move = int(np.random.choice(env.valid_moves()))
        else:
            move, _ = choose_ai_action(
                policy=policy,
                board=env.board,
                current_player=env.current_player,
                win_length=win_length,
                device=device,
                agent=agent,
                mcts_sims=mcts_sims,
                c_puct=c_puct,
            )
        env.step(move)

    return env.winner
|
|
|
|
def evaluate_vs_random(
    policy: PolicyValueNet,
    board_size: int,
    win_length: int,
    device: torch.device,
    games: int,
    agent: str = "policy",
    mcts_sims: int = 100,
    c_puct: float = 1.5,
) -> tuple[float, int, int, int]:
    """Play ``games`` games against the random agent and tally the results.

    The AI moves first in the first ``games // 2`` games and second in the
    rest.

    Returns:
        ``(win_rate, wins, draws, losses)`` from the AI's point of view.
    """
    tally = {"wins": 0, "draws": 0, "losses": 0}
    games_as_first = games // 2

    for game_idx in range(games):
        policy_player = 1 if game_idx < games_as_first else -1
        winner = play_vs_random_once(
            policy=policy,
            board_size=board_size,
            win_length=win_length,
            device=device,
            policy_player=policy_player,
            agent=agent,
            mcts_sims=mcts_sims,
            c_puct=c_puct,
        )
        if winner == 0:
            outcome = "draws"
        elif winner == policy_player:
            outcome = "wins"
        else:
            outcome = "losses"
        tally[outcome] += 1

    win_rate = tally["wins"] / max(games, 1)
    return win_rate, tally["wins"], tally["draws"], tally["losses"]
|
|
|
|
def train(args: argparse.Namespace) -> None:
    """Train the policy by self-play and save a checkpoint when done.

    Episodes are collected into a batch; a gradient step runs every
    ``args.batch_size`` episodes and after the last episode. Progress is
    printed every ``args.print_every`` episodes, with an evaluation against
    the random agent whenever that episode is also a multiple of
    ``args.eval_every``.

    Raises:
        ValueError: if ``args.batch_size`` or ``args.print_every`` is below 1.
    """
    # Fail early with a clear message instead of a ZeroDivisionError or a
    # deque error deep inside the training loop.
    if args.batch_size < 1 or args.print_every < 1:
        raise ValueError("--batch-size and --print-every must be >= 1")

    set_seed(args.seed)
    device = choose_device(args.device)
    env = GomokuEnv(board_size=args.board_size, win_length=args.win_length)
    policy = PolicyValueNet(channels=args.channels).to(device)
    if args.init_checkpoint is not None:
        if args.init_checkpoint.exists():
            checkpoint = load_checkpoint(args.init_checkpoint, map_location=device)
            policy.load_state_dict(checkpoint["state_dict"])
        else:
            # Keep the best-effort behaviour, but make the fallback visible.
            print(
                f"warning: init checkpoint {args.init_checkpoint} not found; "
                "training from scratch"
            )
    optimizer = torch.optim.Adam(policy.parameters(), lr=args.lr)

    recent_winners: deque[int] = deque(maxlen=args.print_every)
    recent_lengths: deque[int] = deque(maxlen=args.print_every)
    batch_log_probs: list[torch.Tensor] = []
    batch_returns: list[float] = []
    batch_entropies: list[torch.Tensor] = []
    batch_values: list[torch.Tensor] = []
    last_loss = 0.0

    print(f"device={device} board={args.board_size} win={args.win_length}")

    for episode in range(1, args.episodes + 1):
        log_probs, returns, entropies, values, winner, moves = self_play_episode(
            policy=policy,
            env=env,
            device=device,
            gamma=args.gamma,
            augment=args.symmetry_augment,
        )
        batch_log_probs.extend(log_probs)
        batch_returns.extend(returns)
        batch_entropies.extend(entropies)
        batch_values.extend(values)
        recent_winners.append(winner)
        recent_lengths.append(moves)

        is_last_episode = episode == args.episodes

        if episode % args.batch_size == 0 or is_last_episode:
            policy.train()
            last_loss = update_policy(
                optimizer=optimizer,
                batch_log_probs=batch_log_probs,
                batch_returns=batch_returns,
                batch_entropies=batch_entropies,
                batch_values=batch_values,
                entropy_coef=args.entropy_coef,
                value_coef=args.value_coef,
                grad_clip=args.grad_clip,
                device=device,
            )
            batch_log_probs.clear()
            batch_returns.clear()
            batch_entropies.clear()
            batch_values.clear()

        if episode % args.print_every == 0 or is_last_episode:
            p1_wins = sum(1 for x in recent_winners if x == 1)
            p2_wins = sum(1 for x in recent_winners if x == -1)
            draws = sum(1 for x in recent_winners if x == 0)
            avg_len = float(np.mean(recent_lengths)) if recent_lengths else 0.0
            message = (
                f"episode={episode:6d} loss={last_loss:8.4f} "
                f"p1={p1_wins:4d} p2={p2_wins:4d} draw={draws:4d} avg_len={avg_len:6.2f}"
            )
            if args.eval_every > 0 and episode % args.eval_every == 0:
                policy.eval()
                # Evaluation is inference only; skip autograd bookkeeping.
                with torch.no_grad():
                    win_rate, wins, eval_draws, losses = evaluate_vs_random(
                        policy=policy,
                        board_size=args.board_size,
                        win_length=args.win_length,
                        device=device,
                        games=args.eval_games,
                    )
                message += (
                    f" random_win_rate={win_rate:.3f}"
                    f" ({wins}/{eval_draws}/{losses})"
                )
            print(message)

    save_checkpoint(args.checkpoint, policy, args)
    print(f"saved checkpoint to {args.checkpoint}")
|
|
|
|
def evaluate(args: argparse.Namespace) -> None:
    """Load a checkpoint, play it against the random agent and print the tally."""
    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    results = evaluate_vs_random(
        policy=policy,
        board_size=board_size,
        win_length=win_length,
        device=device,
        games=args.games,
        agent=args.agent,
        mcts_sims=args.mcts_sims,
        c_puct=args.c_puct,
    )
    win_rate, wins, draws, losses = results
    report = (
        f"device={device}",
        f"agent={args.agent} mcts_sims={args.mcts_sims}",
        f"win_rate={win_rate:.3f} wins={wins} draws={draws} losses={losses}",
    )
    for line in report:
        print(line)
|
|
|
|
def ask_human_move(env: GomokuEnv) -> int:
    """Prompt on stdin until the user enters a legal move.

    Input is two 1-based integers, ``row col``, separated by whitespace or a
    comma. Returns the flat action index of the chosen empty cell.
    """
    size = env.board_size
    while True:
        tokens = input("your move (row col): ").strip().replace(",", " ").split()
        if len(tokens) != 2:
            print("please enter: row col")
            continue
        try:
            row = int(tokens[0]) - 1
            col = int(tokens[1]) - 1
        except ValueError:
            print("row and col must be integers")
            continue
        if row not in range(size) or col not in range(size):
            print("move out of range")
            continue
        if env.board[row, col] != 0:
            print("that position is occupied")
            continue
        return row * size + col
|
|
|
|
def play(args: argparse.Namespace) -> None:
    """Run a terminal game between a human and the loaded model.

    The human plays X (and moves first) when ``args.human_first`` is set,
    and O otherwise. The board is printed before every move and at the end.
    """
    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    human_player = 1 if args.human_first else -1
    human_symbol, ai_symbol = ("X", "O") if human_player == 1 else ("O", "X")

    print(f"device={device}")
    print(
        f"human={human_symbol} ai={ai_symbol} "
        f"agent={args.agent} mcts_sims={args.mcts_sims}"
    )

    while not env.done:
        print()
        print(env.render())
        print()

        if env.current_player != human_player:
            action, _ = choose_ai_action(
                policy=policy,
                board=env.board,
                current_player=env.current_player,
                win_length=win_length,
                device=device,
                agent=args.agent,
                mcts_sims=args.mcts_sims,
                c_puct=args.c_puct,
            )
            ai_row, ai_col = divmod(action, env.board_size)
            print(f"ai move: {ai_row + 1} {ai_col + 1}")
        else:
            action = ask_human_move(env)

        env.step(action)

    print()
    print(env.render())
    if env.winner == 0:
        verdict = "draw"
    elif env.winner == human_player:
        verdict = "you win"
    else:
        verdict = "ai wins"
    print(verdict)
|
|
|
|
def gui(args: argparse.Namespace) -> None:
    """Run a pygame window for playing against the loaded model.

    Left click places a stone, ``R`` restarts and ``Esc`` or closing the
    window quits. The AI move is computed after the current frame has been
    drawn, so the human's stone and the "AI is thinking..." status are
    visible while the (blocking) move selection runs.

    Raises:
        SystemExit: if pygame is not installed.
    """
    try:
        import pygame
    except ModuleNotFoundError as exc:
        # The hint is environment-neutral rather than tied to one conda env.
        raise SystemExit(
            "pygame is not installed. Install it with: python -m pip install pygame"
        ) from exc

    device = choose_device(args.device)
    policy, board_size, win_length = resolve_game_config(
        checkpoint_path=args.checkpoint,
        arg_board_size=args.board_size,
        arg_win_length=args.win_length,
        arg_channels=args.channels,
        device=device,
    )
    env = GomokuEnv(board_size=board_size, win_length=win_length)
    human_player = 1 if args.human_first else -1
    last_search_visits: np.ndarray | None = None

    cell_size = args.cell_size
    padding = 40
    status_height = 80
    board_pixels = board_size * cell_size

    background = (236, 196, 122)
    line_color = (80, 55, 20)
    black_stone = (20, 20, 20)
    white_stone = (245, 245, 245)
    accent = (180, 40, 40)

    def board_to_screen(row: int, col: int) -> tuple[int, int]:
        # Pixel centre of a board cell.
        x = padding + col * cell_size + cell_size // 2
        y = padding + row * cell_size + cell_size // 2
        return x, y

    def mouse_to_action(pos: tuple[int, int]) -> int | None:
        # Flat action for a click, or None if it misses an empty cell.
        x, y = pos
        if x < padding or y < padding:
            return None
        col = (x - padding) // cell_size
        row = (y - padding) // cell_size
        if not (0 <= row < env.board_size and 0 <= col < env.board_size):
            return None
        if env.board[row, col] != 0:
            return None
        return row * env.board_size + col

    def ai_to_move() -> bool:
        return not env.done and env.current_player != human_player

    def restart() -> None:
        # The main loop plays the AI's opening move if the AI goes first.
        nonlocal last_search_visits
        env.reset()
        last_search_visits = None

    def ai_step() -> None:
        nonlocal last_search_visits
        if not ai_to_move():
            return
        action, visits = choose_ai_action(
            policy=policy,
            board=env.board,
            current_player=env.current_player,
            win_length=win_length,
            device=device,
            agent=args.agent,
            mcts_sims=args.mcts_sims,
            c_puct=args.c_puct,
        )
        last_search_visits = visits
        env.step(action)

    def status_text() -> str:
        if env.done:
            if env.winner == 0:
                return "Draw. Press R to restart."
            if env.winner == human_player:
                return "You win. Press R to restart."
            return "AI wins. Press R to restart."
        if env.current_player == human_player:
            return "Your turn. Left click to place."
        return "AI is thinking..."

    pygame.init()
    try:
        pygame.display.set_caption("Gomoku Policy Gradient")
        font = pygame.font.SysFont("Arial", 24)
        small_font = pygame.font.SysFont("Arial", 18)
        screen = pygame.display.set_mode(
            (board_pixels + padding * 2, board_pixels + padding * 2 + status_height)
        )
        clock = pygame.time.Clock()

        def draw_frame() -> None:
            screen.fill(background)

            for idx in range(board_size + 1):
                offset = padding + idx * cell_size
                pygame.draw.line(
                    screen, line_color, (offset, padding), (offset, padding + board_pixels), 2
                )
                pygame.draw.line(
                    screen, line_color, (padding, offset), (padding + board_pixels, offset), 2
                )

            for row in range(env.board_size):
                for col in range(env.board_size):
                    stone = int(env.board[row, col])
                    if stone == 0:
                        continue
                    x, y = board_to_screen(row, col)
                    color = black_stone if stone == 1 else white_stone
                    pygame.draw.circle(screen, color, (x, y), cell_size // 2 - 4)
                    pygame.draw.circle(screen, line_color, (x, y), cell_size // 2 - 4, 1)

            for idx in range(board_size):
                label = small_font.render(str(idx + 1), True, line_color)
                screen.blit(
                    label,
                    (padding + idx * cell_size + cell_size // 2 - label.get_width() // 2, 8),
                )
                screen.blit(
                    label,
                    (8, padding + idx * cell_size + cell_size // 2 - label.get_height() // 2),
                )

            info = (
                f"{board_size}x{board_size} connect={win_length} "
                f"device={device} human={'X' if human_player == 1 else 'O'} "
                f"agent={args.agent}"
            )
            info_surface = small_font.render(info, True, line_color)
            status_surface = font.render(status_text(), True, accent)
            screen.blit(info_surface, (padding, padding + board_pixels + 16))
            screen.blit(status_surface, (padding, padding + board_pixels + 42))

            if last_search_visits is not None and args.agent == "mcts":
                peak = float(np.max(last_search_visits))
                if peak > 0:
                    stats = small_font.render(
                        f"mcts_sims={args.mcts_sims} peak_visits={int(peak)}",
                        True,
                        line_color,
                    )
                    # Place the stats right after the info text instead of at a
                    # fixed offset that can fall outside a small window.
                    stats_x = padding + info_surface.get_width() + 16
                    screen.blit(stats, (stats_x, padding + board_pixels + 16))

        running = True
        while running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_ESCAPE:
                        running = False
                    elif event.key == pygame.K_r:
                        restart()
                elif event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
                    if env.done or env.current_player != human_player:
                        continue
                    action = mouse_to_action(event.pos)
                    if action is None:
                        continue
                    env.step(action)

            draw_frame()
            pygame.display.flip()

            # Run the blocking AI move only after the frame is on screen.
            if running and ai_to_move():
                ai_step()

            clock.tick(args.fps)
    finally:
        # Always release the pygame window, also when an exception escapes.
        pygame.quit()
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser with the ``train``, ``eval``, ``play`` and ``gui`` modes.

    Every sub-parser stores its handler function under ``func``.
    """
    parser = argparse.ArgumentParser(description="Minimal Gomoku policy gradient example")
    subparsers = parser.add_subparsers(dest="mode", required=True)

    def add_common_arguments(subparser: argparse.ArgumentParser, defaults_from_checkpoint: bool = False) -> None:
        # With defaults_from_checkpoint the board options default to None.
        board_default, win_default = (None, None) if defaults_from_checkpoint else (15, 5)
        subparser.add_argument("--board-size", type=int, default=board_default)
        subparser.add_argument("--win-length", type=int, default=win_default)
        subparser.add_argument("--channels", type=int, default=64)
        subparser.add_argument("--device", choices=["auto", "cpu", "cuda", "mps"], default="auto")
        subparser.add_argument("--checkpoint", type=Path, default=Path("gomoku_policy.pt"))

    def add_inference_arguments(subparser: argparse.ArgumentParser, default_agent: str = "mcts") -> None:
        subparser.add_argument("--agent", choices=["policy", "mcts"], default=default_agent)
        subparser.add_argument("--mcts-sims", type=int, default=120)
        subparser.add_argument("--c-puct", type=float, default=1.5)

    # --- train ---
    train_parser = subparsers.add_parser("train", help="self-play training")
    add_common_arguments(train_parser)
    train_options = (
        ("--episodes", int, 5000),
        ("--batch-size", int, 32),
        ("--lr", float, 1e-3),
        ("--gamma", float, 0.99),
        ("--entropy-coef", float, 0.01),
        ("--value-coef", float, 0.5),
        ("--grad-clip", float, 1.0),
        ("--print-every", int, 100),
        ("--eval-every", int, 500),
        ("--eval-games", int, 40),
        ("--seed", int, 42),
        ("--init-checkpoint", Path, None),
    )
    for flag, converter, default in train_options:
        train_parser.add_argument(flag, type=converter, default=default)
    train_parser.add_argument(
        "--no-symmetry-augment",
        dest="symmetry_augment",
        action="store_false",
        help="disable random rotation/flip augmentation during training",
    )
    train_parser.set_defaults(symmetry_augment=True, func=train)

    # --- eval ---
    eval_parser = subparsers.add_parser("eval", help="evaluate against random agent")
    add_common_arguments(eval_parser)
    eval_parser.add_argument("--games", type=int, default=100)
    add_inference_arguments(eval_parser)
    eval_parser.set_defaults(func=evaluate)

    # --- play ---
    play_parser = subparsers.add_parser("play", help="play against the trained model")
    add_common_arguments(play_parser, defaults_from_checkpoint=True)
    play_parser.add_argument("--human-first", action="store_true", help="human plays X")
    add_inference_arguments(play_parser)
    play_parser.set_defaults(func=play)

    # --- gui ---
    gui_parser = subparsers.add_parser("gui", help="pygame GUI for testing against the model")
    add_common_arguments(gui_parser, defaults_from_checkpoint=True)
    gui_parser.add_argument("--human-first", action="store_true", help="human plays X")
    gui_parser.add_argument("--cell-size", type=int, default=48)
    gui_parser.add_argument("--fps", type=int, default=30)
    add_inference_arguments(gui_parser)
    gui_parser.set_defaults(func=gui)

    return parser
|
|
|
|
def main() -> None:
    """Parse the command line and run the handler of the selected mode."""
    namespace = build_parser().parse_args()
    handler = namespace.func
    handler(namespace)
|
|
|
|
# Run the CLI only when the file is executed as a script.
if "__main__" == __name__:
    main()
|
|