#!/usr/bin/env python3 """Minimal Gomoku MCTS example. This file is intentionally separate from gomoku_pg.py. It uses the simpler AlphaZero-style recipe: 1. self-play with MCTS 2. policy/value targets from search 3. supervised update on policy + value heads """ from __future__ import annotations import argparse import math import random from collections import deque from dataclasses import dataclass, field from pathlib import Path import numpy as np import torch from torch import nn def choose_device(name: str) -> torch.device: if name != "auto": return torch.device(name) if torch.cuda.is_available(): return torch.device("cuda") if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): return torch.device("mps") return torch.device("cpu") def set_seed(seed: int) -> None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) class GomokuEnv: def __init__(self, board_size: int, win_length: int): if board_size <= 1: raise ValueError("board_size must be > 1") if not 1 < win_length <= board_size: raise ValueError("win_length must satisfy 1 < win_length <= board_size") self.board_size = board_size self.win_length = win_length self.reset() def reset(self) -> np.ndarray: self.board = np.zeros((self.board_size, self.board_size), dtype=np.int8) self.current_player = 1 self.done = False self.winner = 0 return self.board def valid_moves(self) -> np.ndarray: return np.flatnonzero((self.board == 0).reshape(-1)) def step(self, action: int) -> tuple[bool, int]: next_board, next_player, done, winner = apply_action_to_board( board=self.board, current_player=self.current_player, action=action, win_length=self.win_length, ) self.board = next_board self.current_player = next_player self.done = done self.winner = winner return done, winner def render(self) -> str: symbols = {1: "X", -1: "O", 0: "."} header = " " + " ".join(f"{i + 1:2d}" for i in range(self.board_size)) rows = [header] for row_idx in range(self.board_size): row = " ".join(f"{symbols[int(v)]:>2}" for v in self.board[row_idx]) rows.append(f"{row_idx + 1:2d} {row}") return "\n".join(rows) def action_to_coords(action: int, board_size: int) -> tuple[int, int]: return divmod(int(action), board_size) def coords_to_action(row: int, col: int, board_size: int) -> int: return row * board_size + col def count_one_side( board: np.ndarray, row: int, col: int, dr: int, dc: int, player: int, ) -> int: board_size = board.shape[0] total = 0 r, c = row + dr, col + dc while 0 <= r < board_size and 0 <= c < board_size: if board[r, c] != player: break total += 1 r += dr c += dc return total def is_winning_move( board: np.ndarray, row: int, col: int, player: int, win_length: int, ) -> bool: directions = ((1, 0), (0, 1), (1, 1), (1, -1)) for dr, dc in directions: count = 1 count += count_one_side(board, row, col, dr, dc, player) count += count_one_side(board, row, col, -dr, -dc, player) if count >= win_length: return True return False def apply_action_to_board( board: np.ndarray, current_player: int, action: int, win_length: int, ) -> tuple[np.ndarray, int, bool, int]: board_size = board.shape[0] row, col = action_to_coords(action, board_size) if board[row, col] != 0: raise ValueError(f"illegal move at ({row}, {col})") next_board = board.copy() next_board[row, col] = current_player if is_winning_move(next_board, row, col, current_player, win_length): return next_board, -current_player, True, current_player if not np.any(next_board == 0): return next_board, -current_player, True, 0 return next_board, -current_player, False, 0 def encode_state(board: np.ndarray, current_player: int) -> torch.Tensor: current = (board == current_player).astype(np.float32) opponent = (board == -current_player).astype(np.float32) legal = (board == 0).astype(np.float32) return torch.from_numpy(np.stack([current, opponent, legal], axis=0)) class PolicyValueNet(nn.Module): def __init__(self, channels: int = 64): super().__init__() self.trunk = nn.Sequential( nn.Conv2d(3, channels, kernel_size=3, padding=1), nn.ReLU(), nn.Conv2d(channels, channels, kernel_size=3, padding=1), nn.ReLU(), nn.Conv2d(channels, channels, kernel_size=3, padding=1), nn.ReLU(), ) self.policy_head = nn.Conv2d(channels, 1, kernel_size=1) self.value_head = nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(channels, channels), nn.ReLU(), nn.Linear(channels, 1), nn.Tanh(), ) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: features = self.trunk(x) policy_logits = self.policy_head(features).flatten(start_dim=1) value = self.value_head(features).squeeze(-1) return policy_logits, value def masked_logits(logits: torch.Tensor, board: np.ndarray) -> torch.Tensor: legal = torch.as_tensor((board == 0).reshape(-1), device=logits.device, dtype=torch.bool) return logits.masked_fill(~legal, -1e9) def evaluate_policy_value( policy: PolicyValueNet, board: np.ndarray, current_player: int, device: torch.device, ) -> tuple[np.ndarray, float]: state = encode_state(board, current_player).unsqueeze(0).to(device) with torch.no_grad(): logits, value = policy(state) logits = masked_logits(logits.squeeze(0), board) probs = torch.softmax(logits, dim=0).cpu().numpy() return probs, float(value.item()) @dataclass class MCTSNode: board: np.ndarray current_player: int win_length: int done: bool = False winner: int = 0 priors: dict[int, float] = field(default_factory=dict) visit_counts: dict[int, int] = field(default_factory=dict) value_sums: dict[int, float] = field(default_factory=dict) children: dict[int, "MCTSNode"] = field(default_factory=dict) expanded: bool = False def expand( self, priors: np.ndarray, add_noise: bool = False, dirichlet_alpha: float = 0.3, noise_eps: float = 0.25, ) -> None: legal_actions = np.flatnonzero((self.board == 0).reshape(-1)) legal_priors = priors[legal_actions] total_prob = float(np.sum(legal_priors)) if total_prob <= 0.0: legal_priors = np.full(len(legal_actions), 1.0 / max(len(legal_actions), 1), dtype=np.float32) else: legal_priors = legal_priors / total_prob if add_noise and len(legal_actions) > 0: noise = np.random.dirichlet([dirichlet_alpha] * len(legal_actions)) legal_priors = (1.0 - noise_eps) * legal_priors + noise_eps * noise self.priors = { int(action): float(prior) for action, prior in zip(legal_actions, legal_priors, strict=False) } self.visit_counts = {action: 0 for action in self.priors} self.value_sums = {action: 0.0 for action in self.priors} self.expanded = True def q_value(self, action: int) -> float: visits = self.visit_counts[action] if visits == 0: return 0.0 return self.value_sums[action] / visits def select_action(self, c_puct: float) -> int: total_visits = sum(self.visit_counts.values()) sqrt_total = math.sqrt(total_visits + 1.0) best_action = -1 best_score = -float("inf") for action, prior in self.priors.items(): q = self.q_value(action) u = c_puct * prior * sqrt_total / (1.0 + self.visit_counts[action]) score = q + u if score > best_score: best_score = score best_action = action return best_action def child_for_action(self, action: int) -> "MCTSNode": child = self.children.get(action) if child is not None: return child next_board, next_player, done, winner = apply_action_to_board( board=self.board, current_player=self.current_player, action=action, win_length=self.win_length, ) child = MCTSNode( board=next_board, current_player=next_player, win_length=self.win_length, done=done, winner=winner, ) self.children[action] = child return child def terminal_value(winner: int, current_player: int) -> float: if winner == 0: return 0.0 return 1.0 if winner == current_player else -1.0 def sample_from_visits(visits: np.ndarray, temperature: float) -> tuple[int, np.ndarray]: flat = visits.reshape(-1).astype(np.float64) if np.all(flat == 0): flat = np.ones_like(flat) if temperature <= 1e-6: probs = np.zeros_like(flat, dtype=np.float64) probs[int(np.argmax(flat))] = 1.0 else: adjusted = np.power(flat, 1.0 / temperature) probs = adjusted / np.sum(adjusted) action = int(np.random.choice(len(probs), p=probs)) return action, probs.reshape(visits.shape).astype(np.float32) def choose_mcts_action( policy: PolicyValueNet, board: np.ndarray, current_player: int, win_length: int, device: torch.device, num_simulations: int, c_puct: float, temperature: float, add_root_noise: bool, dirichlet_alpha: float, noise_eps: float, ) -> tuple[int, np.ndarray]: root = MCTSNode(board=board.copy(), current_player=current_player, win_length=win_length) priors, _ = evaluate_policy_value(policy, root.board, root.current_player, device) root.expand( priors, add_noise=add_root_noise, dirichlet_alpha=dirichlet_alpha, noise_eps=noise_eps, ) for _ in range(num_simulations): node = root path: list[tuple[MCTSNode, int]] = [] while node.expanded and not node.done: action = node.select_action(c_puct) path.append((node, action)) node = node.child_for_action(action) if node.done: value = terminal_value(node.winner, node.current_player) else: priors, value = evaluate_policy_value(policy, node.board, node.current_player, device) node.expand(priors) for parent, action in reversed(path): value = -value parent.visit_counts[action] += 1 parent.value_sums[action] += value visits = np.zeros(board.shape, dtype=np.float32) for action, count in root.visit_counts.items(): row, col = action_to_coords(action, board.shape[0]) visits[row, col] = float(count) action, visit_probs = sample_from_visits(visits, temperature=temperature) return action, visit_probs def choose_ai_action( policy: PolicyValueNet, board: np.ndarray, current_player: int, win_length: int, device: torch.device, agent: str, mcts_sims: int, c_puct: float, ) -> tuple[int, np.ndarray | None]: if agent == "policy": priors, _ = evaluate_policy_value(policy, board, current_player, device) return int(np.argmax(priors)), None return choose_mcts_action( policy=policy, board=board, current_player=current_player, win_length=win_length, device=device, num_simulations=mcts_sims, c_puct=c_puct, temperature=1e-6, add_root_noise=False, dirichlet_alpha=0.3, noise_eps=0.25, ) def self_play_game( policy: PolicyValueNet, board_size: int, win_length: int, device: torch.device, mcts_sims: int, c_puct: float, temperature: float, temperature_drop_moves: int, dirichlet_alpha: float, noise_eps: float, ) -> tuple[list[tuple[torch.Tensor, np.ndarray, float]], int, int]: env = GomokuEnv(board_size=board_size, win_length=win_length) env.reset() history: list[tuple[torch.Tensor, np.ndarray, int]] = [] move_idx = 0 while not env.done: move_temp = temperature if move_idx < temperature_drop_moves else 1e-6 action, visit_probs = choose_mcts_action( policy=policy, board=env.board, current_player=env.current_player, win_length=win_length, device=device, num_simulations=mcts_sims, c_puct=c_puct, temperature=move_temp, add_root_noise=True, dirichlet_alpha=dirichlet_alpha, noise_eps=noise_eps, ) history.append((encode_state(env.board, env.current_player), visit_probs.reshape(-1), env.current_player)) env.step(action) move_idx += 1 examples: list[tuple[torch.Tensor, np.ndarray, float]] = [] for state, visit_probs, player in history: if env.winner == 0: outcome = 0.0 else: outcome = 1.0 if player == env.winner else -1.0 examples.append((state, visit_probs, outcome)) return examples, env.winner, move_idx def train_batch( policy: PolicyValueNet, optimizer: torch.optim.Optimizer, batch: list[tuple[torch.Tensor, np.ndarray, float]], device: torch.device, value_coef: float, ) -> tuple[float, float, float]: states = torch.stack([item[0] for item in batch]).to(device) target_policies = torch.tensor( np.stack([item[1] for item in batch]), dtype=torch.float32, device=device, ) target_values = torch.tensor([item[2] for item in batch], dtype=torch.float32, device=device) logits, values = policy(states) log_probs = torch.log_softmax(logits, dim=1) policy_loss = -(target_policies * log_probs).sum(dim=1).mean() value_loss = torch.mean((values - target_values) ** 2) loss = policy_loss + value_coef * value_loss optimizer.zero_grad(set_to_none=True) loss.backward() nn.utils.clip_grad_norm_(policy.parameters(), 1.0) optimizer.step() return float(loss.item()), float(policy_loss.item()), float(value_loss.item()) def save_checkpoint(path: Path, policy: PolicyValueNet, args: argparse.Namespace) -> None: torch.save( { "state_dict": policy.state_dict(), "channels": args.channels, "board_size": args.board_size, "win_length": args.win_length, }, path, ) def last_checkpoint_path(base_path: Path) -> Path: return base_path.with_name(f"{base_path.stem}_last{base_path.suffix}") def load_checkpoint(path: Path, map_location: torch.device) -> dict: checkpoint = torch.load(path, map_location=map_location) if isinstance(checkpoint, dict) and "state_dict" in checkpoint: return checkpoint raise RuntimeError(f"{path} is not a compatible gomoku_mcts checkpoint") def resolve_game_config( checkpoint_path: Path, arg_board_size: int | None, arg_win_length: int | None, arg_channels: int, device: torch.device, ) -> tuple[PolicyValueNet, int, int]: checkpoint = load_checkpoint(checkpoint_path, map_location=device) board_size = int(checkpoint.get("board_size") or arg_board_size or 15) win_length = int(checkpoint.get("win_length") or arg_win_length or 5) channels = int(checkpoint.get("channels") or arg_channels) policy = PolicyValueNet(channels=channels).to(device) policy.load_state_dict(checkpoint["state_dict"]) policy.eval() return policy, board_size, win_length def play_vs_random_once( policy: PolicyValueNet, board_size: int, win_length: int, device: torch.device, policy_player: int, agent: str, mcts_sims: int, c_puct: float, ) -> int: env = GomokuEnv(board_size=board_size, win_length=win_length) env.reset() while not env.done: if env.current_player == policy_player: action, _ = choose_ai_action( policy=policy, board=env.board, current_player=env.current_player, win_length=win_length, device=device, agent=agent, mcts_sims=mcts_sims, c_puct=c_puct, ) else: action = int(np.random.choice(env.valid_moves())) env.step(action) return env.winner def evaluate_vs_random( policy: PolicyValueNet, board_size: int, win_length: int, device: torch.device, games: int, agent: str, mcts_sims: int, c_puct: float, ) -> tuple[float, int, int, int]: wins = 0 draws = 0 losses = 0 for game_idx in range(games): policy_player = 1 if game_idx < games // 2 else -1 winner = play_vs_random_once( policy=policy, board_size=board_size, win_length=win_length, device=device, policy_player=policy_player, agent=agent, mcts_sims=mcts_sims, c_puct=c_puct, ) if winner == 0: draws += 1 elif winner == policy_player: wins += 1 else: losses += 1 return wins / max(games, 1), wins, draws, losses def train(args: argparse.Namespace) -> None: set_seed(args.seed) device = choose_device(args.device) policy = PolicyValueNet(channels=args.channels).to(device) if args.init_checkpoint is not None and args.init_checkpoint.exists(): checkpoint = load_checkpoint(args.init_checkpoint, map_location=device) policy.load_state_dict(checkpoint["state_dict"]) optimizer = torch.optim.Adam(policy.parameters(), lr=args.lr, weight_decay=args.weight_decay) replay_buffer: deque[tuple[torch.Tensor, np.ndarray, float]] = deque(maxlen=args.buffer_size) print(f"device={device} board={args.board_size} win={args.win_length}") for iteration in range(1, args.iterations + 1): policy.eval() winners: list[int] = [] lengths: list[int] = [] for _ in range(args.games_per_iter): examples, winner, moves = self_play_game( policy=policy, board_size=args.board_size, win_length=args.win_length, device=device, mcts_sims=args.mcts_sims, c_puct=args.c_puct, temperature=args.temperature, temperature_drop_moves=args.temperature_drop_moves, dirichlet_alpha=args.dirichlet_alpha, noise_eps=args.noise_eps, ) replay_buffer.extend(examples) winners.append(winner) lengths.append(moves) losses: list[tuple[float, float, float]] = [] if len(replay_buffer) >= args.batch_size: policy.train() for _ in range(args.train_steps): batch = random.sample(replay_buffer, args.batch_size) losses.append( train_batch( policy=policy, optimizer=optimizer, batch=batch, device=device, value_coef=args.value_coef, ) ) avg_loss = float(np.mean([x[0] for x in losses])) if losses else 0.0 avg_policy_loss = float(np.mean([x[1] for x in losses])) if losses else 0.0 avg_value_loss = float(np.mean([x[2] for x in losses])) if losses else 0.0 p1_wins = sum(1 for x in winners if x == 1) p2_wins = sum(1 for x in winners if x == -1) draws = sum(1 for x in winners if x == 0) avg_len = float(np.mean(lengths)) if lengths else 0.0 message = ( f"iter={iteration:5d} loss={avg_loss:7.4f} policy={avg_policy_loss:7.4f} " f"value={avg_value_loss:7.4f} p1={p1_wins:3d} p2={p2_wins:3d} draw={draws:3d} " f"avg_len={avg_len:6.2f} buffer={len(replay_buffer):6d}" ) if args.eval_every > 0 and iteration % args.eval_every == 0: policy.eval() win_rate, wins, eval_draws, eval_losses = evaluate_vs_random( policy=policy, board_size=args.board_size, win_length=args.win_length, device=device, games=args.eval_games, agent="mcts", mcts_sims=args.eval_mcts_sims, c_puct=args.c_puct, ) message += f" random_win_rate={win_rate:.3f} ({wins}/{eval_draws}/{eval_losses})" print(message) if args.save_every > 0 and iteration % args.save_every == 0: checkpoint_path = last_checkpoint_path(args.checkpoint) save_checkpoint(checkpoint_path, policy, args) print(f"saved checkpoint to {checkpoint_path}") save_checkpoint(args.checkpoint, policy, args) print(f"saved checkpoint to {args.checkpoint}") def evaluate(args: argparse.Namespace) -> None: device = choose_device(args.device) policy, board_size, win_length = resolve_game_config( checkpoint_path=args.checkpoint, arg_board_size=args.board_size, arg_win_length=args.win_length, arg_channels=args.channels, device=device, ) win_rate, wins, draws, losses = evaluate_vs_random( policy=policy, board_size=board_size, win_length=win_length, device=device, games=args.games, agent=args.agent, mcts_sims=args.mcts_sims, c_puct=args.c_puct, ) print(f"device={device}") print(f"agent={args.agent} mcts_sims={args.mcts_sims}") print(f"win_rate={win_rate:.3f} wins={wins} draws={draws} losses={losses}") def ask_human_move(env: GomokuEnv) -> int: while True: text = input("your move (row col): ").strip() parts = text.replace(",", " ").split() if len(parts) != 2: print("please enter: row col") continue try: row, col = (int(parts[0]) - 1, int(parts[1]) - 1) except ValueError: print("row and col must be integers") continue if not (0 <= row < env.board_size and 0 <= col < env.board_size): print("move out of range") continue if env.board[row, col] != 0: print("that position is occupied") continue return coords_to_action(row, col, env.board_size) def play(args: argparse.Namespace) -> None: device = choose_device(args.device) policy, board_size, win_length = resolve_game_config( checkpoint_path=args.checkpoint, arg_board_size=args.board_size, arg_win_length=args.win_length, arg_channels=args.channels, device=device, ) env = GomokuEnv(board_size=board_size, win_length=win_length) human_player = 1 if args.human_first else -1 print(f"device={device}") print( f"human={'X' if human_player == 1 else 'O'} ai={'O' if human_player == 1 else 'X'} " f"agent={args.agent} mcts_sims={args.mcts_sims}" ) while not env.done: print() print(env.render()) print() if env.current_player == human_player: action = ask_human_move(env) else: action, _ = choose_ai_action( policy=policy, board=env.board, current_player=env.current_player, win_length=win_length, device=device, agent=args.agent, mcts_sims=args.mcts_sims, c_puct=args.c_puct, ) row, col = action_to_coords(action, env.board_size) print(f"ai move: {row + 1} {col + 1}") env.step(action) print() print(env.render()) if env.winner == 0: print("draw") elif env.winner == human_player: print("you win") else: print("ai wins") def gui(args: argparse.Namespace) -> None: try: import pygame except ModuleNotFoundError as exc: raise SystemExit( "pygame is not installed. Install it with: " "~/miniconda3/bin/conda run -n lerobot python -m pip install pygame" ) from exc device = choose_device(args.device) policy, board_size, win_length = resolve_game_config( checkpoint_path=args.checkpoint, arg_board_size=args.board_size, arg_win_length=args.win_length, arg_channels=args.channels, device=device, ) env = GomokuEnv(board_size=board_size, win_length=win_length) human_player = 1 if args.human_first else -1 last_search_visits: np.ndarray | None = None pygame.init() pygame.display.set_caption("Gomoku MCTS") font = pygame.font.SysFont("Arial", 24) small_font = pygame.font.SysFont("Arial", 18) cell_size = args.cell_size padding = 40 status_height = 80 board_pixels = board_size * cell_size screen = pygame.display.set_mode( (board_pixels + padding * 2, board_pixels + padding * 2 + status_height) ) clock = pygame.time.Clock() background = (236, 196, 122) line_color = (80, 55, 20) black_stone = (20, 20, 20) white_stone = (245, 245, 245) accent = (180, 40, 40) def board_to_screen(row: int, col: int) -> tuple[int, int]: x = padding + col * cell_size + cell_size // 2 y = padding + row * cell_size + cell_size // 2 return x, y def mouse_to_action(pos: tuple[int, int]) -> int | None: x, y = pos if x < padding or y < padding: return None col = (x - padding) // cell_size row = (y - padding) // cell_size if not (0 <= row < env.board_size and 0 <= col < env.board_size): return None if env.board[row, col] != 0: return None return coords_to_action(row, col, env.board_size) def ai_step() -> None: nonlocal last_search_visits if env.done or env.current_player == human_player: return action, visits = choose_ai_action( policy=policy, board=env.board, current_player=env.current_player, win_length=win_length, device=device, agent=args.agent, mcts_sims=args.mcts_sims, c_puct=args.c_puct, ) last_search_visits = visits env.step(action) def restart() -> None: nonlocal last_search_visits env.reset() last_search_visits = None if env.current_player != human_player: ai_step() def status_text() -> str: if env.done: if env.winner == 0: return "Draw. Press R to restart." if env.winner == human_player: return "You win. Press R to restart." return "AI wins. Press R to restart." if env.current_player == human_player: return "Your turn. Left click to place." return "AI is thinking..." if env.current_player != human_player: ai_step() running = True while running: for event in pygame.event.get(): if event.type == pygame.QUIT: running = False elif event.type == pygame.KEYDOWN: if event.key == pygame.K_ESCAPE: running = False elif event.key == pygame.K_r: restart() elif event.type == pygame.MOUSEBUTTONDOWN and event.button == 1: if env.done or env.current_player != human_player: continue action = mouse_to_action(event.pos) if action is None: continue env.step(action) ai_step() screen.fill(background) for idx in range(board_size + 1): x = padding + idx * cell_size y = padding + idx * cell_size pygame.draw.line(screen, line_color, (x, padding), (x, padding + board_pixels), 2) pygame.draw.line(screen, line_color, (padding, y), (padding + board_pixels, y), 2) for row in range(env.board_size): for col in range(env.board_size): stone = int(env.board[row, col]) if stone == 0: continue x, y = board_to_screen(row, col) color = black_stone if stone == 1 else white_stone pygame.draw.circle(screen, color, (x, y), cell_size // 2 - 4) pygame.draw.circle(screen, line_color, (x, y), cell_size // 2 - 4, 1) for idx in range(board_size): label = small_font.render(str(idx + 1), True, line_color) screen.blit(label, (padding + idx * cell_size + cell_size // 2 - label.get_width() // 2, 8)) screen.blit(label, (8, padding + idx * cell_size + cell_size // 2 - label.get_height() // 2)) info = ( f"{board_size}x{board_size} connect={win_length} device={device} " f"agent={args.agent} sims={args.mcts_sims}" ) screen.blit(small_font.render(info, True, line_color), (padding, padding + board_pixels + 16)) screen.blit(font.render(status_text(), True, accent), (padding, padding + board_pixels + 42)) if last_search_visits is not None and args.agent == "mcts": peak = int(np.max(last_search_visits)) screen.blit( small_font.render(f"peak_visits={peak}", True, line_color), (padding + 420, padding + board_pixels + 16), ) pygame.display.flip() clock.tick(args.fps) pygame.quit() def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Minimal Gomoku MCTS example") subparsers = parser.add_subparsers(dest="mode", required=True) def add_common_arguments(subparser: argparse.ArgumentParser, defaults_from_checkpoint: bool = False) -> None: board_default = None if defaults_from_checkpoint else 15 win_default = None if defaults_from_checkpoint else 5 subparser.add_argument("--board-size", type=int, default=board_default) subparser.add_argument("--win-length", type=int, default=win_default) subparser.add_argument("--channels", type=int, default=64) subparser.add_argument("--device", choices=["auto", "cpu", "cuda", "mps"], default="auto") subparser.add_argument("--checkpoint", type=Path, default=Path("gomoku_mcts.pt")) def add_inference_arguments(subparser: argparse.ArgumentParser) -> None: subparser.add_argument("--agent", choices=["policy", "mcts"], default="mcts") subparser.add_argument("--mcts-sims", type=int, default=120) subparser.add_argument("--c-puct", type=float, default=1.5) train_parser = subparsers.add_parser("train", help="MCTS self-play training") add_common_arguments(train_parser) train_parser.add_argument("--iterations", type=int, default=200) train_parser.add_argument("--games-per-iter", type=int, default=8) train_parser.add_argument("--train-steps", type=int, default=32) train_parser.add_argument("--batch-size", type=int, default=64) train_parser.add_argument("--buffer-size", type=int, default=20000) train_parser.add_argument("--lr", type=float, default=1e-3) train_parser.add_argument("--weight-decay", type=float, default=1e-4) train_parser.add_argument("--value-coef", type=float, default=1.0) train_parser.add_argument("--mcts-sims", type=int, default=64) train_parser.add_argument("--eval-mcts-sims", type=int, default=120) train_parser.add_argument("--c-puct", type=float, default=1.5) train_parser.add_argument("--temperature", type=float, default=1.0) train_parser.add_argument("--temperature-drop-moves", type=int, default=8) train_parser.add_argument("--dirichlet-alpha", type=float, default=0.3) train_parser.add_argument("--noise-eps", type=float, default=0.25) train_parser.add_argument("--eval-every", type=int, default=10) train_parser.add_argument("--eval-games", type=int, default=20) train_parser.add_argument("--save-every", type=int, default=10) train_parser.add_argument("--seed", type=int, default=42) train_parser.add_argument("--init-checkpoint", type=Path, default=None) train_parser.set_defaults(func=train) eval_parser = subparsers.add_parser("eval", help="evaluate against random agent") add_common_arguments(eval_parser) add_inference_arguments(eval_parser) eval_parser.add_argument("--games", type=int, default=40) eval_parser.set_defaults(func=evaluate) play_parser = subparsers.add_parser("play", help="play against the model") add_common_arguments(play_parser, defaults_from_checkpoint=True) add_inference_arguments(play_parser) play_parser.add_argument("--human-first", action="store_true") play_parser.set_defaults(func=play) gui_parser = subparsers.add_parser("gui", help="pygame GUI") add_common_arguments(gui_parser, defaults_from_checkpoint=True) add_inference_arguments(gui_parser) gui_parser.add_argument("--human-first", action="store_true") gui_parser.add_argument("--cell-size", type=int, default=48) gui_parser.add_argument("--fps", type=int, default=30) gui_parser.set_defaults(func=gui) return parser def main() -> None: parser = build_parser() args = parser.parse_args() args.func(args) if __name__ == "__main__": main()