| import argparse |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
|
|
| from game import UltimateTicTacToe |
| from mcts import MCTS |
| from model import UltimateTicTacToeModel |
| from trainer import Trainer |
|
|
|
|
# Baseline configuration shared by every sub-command.
#
# * ``build_parser`` uses these values as the defaults of the CLI flags.
# * ``train_command`` copies the mapping, overrides every key from the parsed
#   flags and hands the result to ``Trainer``.
# * ``eval_command`` and ``ai_action`` copy the mapping, override
#   ``num_simulations`` and set the two ``root_*`` keys to ``None`` before
#   handing it to ``MCTS``.
#
# The mixed key naming (``numIters``/``numEps`` next to snake_case keys) is
# kept as-is because the consumers look the values up by these exact names.
DEFAULT_ARGS = dict(
    # Search / self-play settings.
    num_simulations=100,
    numIters=50,
    numEps=20,
    # Optimisation settings.
    epochs=5,
    batch_size=64,
    lr=5e-4,
    weight_decay=1e-4,
    replay_buffer_size=50000,
    value_loss_weight=1.0,
    grad_clip_norm=5.0,
    # Default checkpoint file used by every sub-command.
    checkpoint_path="latest.pth",
    # Exploration settings.
    temperature_threshold=10,
    root_dirichlet_alpha=0.3,
    root_exploration_fraction=0.25,
    # Settings for the comparison games run during training.
    arena_compare_games=6,
    arena_accept_threshold=0.55,
    arena_compare_simulations=8,
)
|
|
|
|
def get_device(device_arg):
    """Resolve the torch device string to run on.

    Args:
        device_arg: Device requested by the caller (the ``--device`` CLI
            value). Any falsy value (``None``, ``""``) means "auto-detect".

    Returns:
        ``device_arg`` unchanged when it is truthy; otherwise ``"cuda"`` if
        ``torch.cuda.is_available()`` is true, else ``"cpu"``.
    """
    if not device_arg:
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"
    return device_arg
|
|
|
|
def build_model(game, device):
    """Create a fresh ``UltimateTicTacToeModel`` sized for ``game``.

    Args:
        game: Object exposing ``get_board_size()`` and ``get_action_size()``.
        device: Device identifier forwarded to the model constructor.

    Returns:
        A newly constructed ``UltimateTicTacToeModel``; no checkpoint is
        loaded here.
    """
    board_size = game.get_board_size()
    action_size = game.get_action_size()
    return UltimateTicTacToeModel(board_size, action_size, device)
|
|
|
|
def load_checkpoint(model, checkpoint_path, device, optimizer=None, required=True):
    """Load weights (and optionally optimizer state) from a checkpoint file.

    Args:
        model: Module whose ``load_state_dict`` receives the ``"state_dict"``
            entry of the checkpoint.
        checkpoint_path: Path (string or ``Path``) of the checkpoint file.
        device: Passed to ``torch.load`` as ``map_location``.
        optimizer: Optional optimizer. Its state is restored only when the
            checkpoint contains an ``"optimizer_state_dict"`` entry.
        required: When true, a missing file raises; when false, a missing
            file is reported by returning ``False``.

    Returns:
        ``True`` if the checkpoint was loaded, ``False`` if the file does not
        exist and ``required`` is false.

    Raises:
        FileNotFoundError: If the file does not exist and ``required`` is true.
    """
    path = Path(checkpoint_path)
    if path.exists():
        # NOTE(review): torch.load deserializes with pickle unless
        # weights-only loading is in effect -- only load trusted checkpoints.
        payload = torch.load(path, map_location=device)
        model.load_state_dict(payload["state_dict"])

        wants_optimizer = optimizer is not None
        if wants_optimizer and "optimizer_state_dict" in payload:
            optimizer.load_state_dict(payload["optimizer_state_dict"])

        # The model is always left in eval mode, including when resuming
        # training. NOTE(review): presumably the trainer switches back to
        # train mode itself -- confirm.
        model.eval()
        return True

    if required:
        raise FileNotFoundError(f"Checkpoint not found: {path}")
    return False
|
|
|
|
def canonical_state(game, state, player):
    """Return ``state`` with its board data canonicalised for ``player``.

    Args:
        game: Object exposing ``get_canonical_board_data(board_data, player)``.
        state: Two-item ``(board_data, active_board)`` pair.
        player: Player identifier forwarded to the game.

    Returns:
        A ``(canonical_board_data, active_board)`` tuple; ``active_board`` is
        passed through untouched.
    """
    cells, active_board = state
    canonical_cells = game.get_canonical_board_data(cells, player)
    return canonical_cells, active_board
|
|
|
|
def apply_moves(game, moves):
    """Replay a move sequence from the initial position.

    Player ``1`` moves first; after each move the state and the player to
    move are whatever ``game.get_next_state`` returns.

    Args:
        game: Object exposing ``get_init_board()`` and
            ``get_next_state(state, player, action, verify_move=True)``.
        moves: Iterable of actions to apply in order.

    Returns:
        ``(state, player)`` after the last move (the initial state and
        player ``1`` when ``moves`` is empty).

    Raises:
        ValueError: If ``get_next_state`` rejects a move by returning
            ``False``.
    """
    board = game.get_init_board()
    mover = 1
    for action in moves:
        outcome = game.get_next_state(board, mover, action, verify_move=True)
        # The game signals an illegal move with the literal ``False``.
        if outcome is False:
            raise ValueError(f"Illegal move in sequence: {action}")
        board, mover = outcome
    return board, mover
|
|
|
|
def format_board(board_data):
    """Render the 9x9 board as a multi-line text grid.

    Cells are read in row-major order (``board_data[row * 9 + col]``) and
    drawn as ``X`` for ``1``, ``O`` for ``-1`` and ``.`` for ``0``. Each row
    is split into three groups of three cells separated by ``" | "``, and a
    dashed separator line follows rows 2 and 5.

    Args:
        board_data: Indexable sequence with at least 81 entries whose values
            convert via ``int()`` to ``1``, ``-1`` or ``0``.

    Returns:
        The rendered board as a single string of 11 newline-joined lines.

    Raises:
        KeyError: If a cell holds a value other than ``1``, ``-1`` or ``0``.
    """
    symbols = {1: "X", -1: "O", 0: "."}
    lines = []
    for row in range(9):
        start = row * 9
        cells = [symbols[int(board_data[start + col])] for col in range(9)]
        line = " | ".join(" ".join(cells[idx:idx + 3]) for idx in (0, 3, 6))
        lines.append(line)
        if row in (2, 5):
            # Size the separator from the rendered row so it always lines up
            # with the grid (a fixed 23-dash rule overhung the 21-char rows).
            lines.append("-" * len(line))
    return "\n".join(lines)
|
|
|
|
def top_policy_moves(policy, limit):
    """Return the highest-probability actions of a policy vector.

    Args:
        policy: One-dimensional array-like of per-action scores.
        limit: Maximum number of entries to return. ``None`` returns every
            action; a negative value is treated as ``0``.

    Returns:
        A list of ``(action_index, probability)`` tuples (plain ``int`` and
        ``float``) ordered from highest to lowest score, at most ``limit``
        long.
    """
    # A negative slice bound would silently mean "all but the last |limit|",
    # which is never what a "top k" caller wants.
    if limit is not None and limit < 0:
        limit = 0
    ranked = np.argsort(policy)[::-1][:limit]
    return [(int(action), float(policy[action])) for action in ranked]
|
|
|
|
def parse_moves(text):
    """Parse a comma-separated move list into integers.

    Args:
        text: String such as ``"40, 41,3"``. A falsy value (``""`` or
            ``None``) yields an empty list. Empty items between commas are
            skipped.

    Returns:
        The moves as a list of ``int`` in their original order.

    Raises:
        ValueError: If a non-empty item is not a valid integer literal.
    """
    if not text:
        return []

    moves = []
    for chunk in text.split(","):
        token = chunk.strip()
        if token:
            moves.append(int(token))
    return moves
|
|
|
|
def parse_action(text):
    """Parse a user-entered move into a flat action index.

    Two forms are accepted, with commas treated like spaces:

    * a single integer: the flat index itself;
    * two integers ``row col``: converted to ``row * 9 + col``.

    Args:
        text: Raw input string.

    Returns:
        The action index, guaranteed to lie in ``[0, 80]``.

    Raises:
        ValueError: If the input has neither one nor two tokens, a token is
            not an integer, a row/column lies outside ``[0, 8]``, or the
            index lies outside ``[0, 80]``.
    """
    tokens = text.strip().replace(",", " ").split()
    count = len(tokens)
    if count not in (1, 2):
        raise ValueError("Enter either a flat move index or 'row col'.")

    if count == 2:
        row, col = map(int, tokens)
        if row < 0 or row >= 9 or col < 0 or col >= 9:
            raise ValueError("Row and column must be in [0, 8].")
        action = row * 9 + col
    else:
        action = int(tokens[0])

    if action < 0 or action >= 81:
        raise ValueError("Move index must be in [0, 80].")
    return action
|
|
|
|
def scalar_value(value):
    """Return the first element of ``value`` as a plain Python ``float``.

    Args:
        value: A scalar or array-like of any shape; it is converted with
            ``np.asarray`` and flattened before the first element is taken.

    Returns:
        The first flattened element as ``float``.

    Raises:
        IndexError: If ``value`` contains no elements.
    """
    flattened = np.asarray(value).ravel()
    return float(flattened[0])
|
|
|
|
def train_command(args):
    """Handle the ``train`` sub-command.

    Builds the game and model, assembles the training configuration from the
    parsed CLI flags, optionally restores model and optimizer state from the
    checkpoint, and runs ``Trainer.learn()``.

    Args:
        args: Parsed ``argparse`` namespace of the ``train`` sub-parser.

    Raises:
        FileNotFoundError: If ``--resume`` is given and the checkpoint file
            does not exist.
    """
    device = get_device(args.device)
    game = UltimateTicTacToe()
    model = build_model(game, device)

    # Map each configuration key to its CLI value. Note the three keys whose
    # names differ from the flag: numIters, numEps and checkpoint_path.
    cli_overrides = {
        "num_simulations": args.num_simulations,
        "numIters": args.num_iters,
        "numEps": args.num_eps,
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "lr": args.lr,
        "weight_decay": args.weight_decay,
        "replay_buffer_size": args.replay_buffer_size,
        "value_loss_weight": args.value_loss_weight,
        "grad_clip_norm": args.grad_clip_norm,
        "checkpoint_path": args.checkpoint,
        "temperature_threshold": args.temperature_threshold,
        "root_dirichlet_alpha": args.root_dirichlet_alpha,
        "root_exploration_fraction": args.root_exploration_fraction,
        "arena_compare_games": args.arena_compare_games,
        "arena_accept_threshold": args.arena_accept_threshold,
        "arena_compare_simulations": args.arena_compare_simulations,
    }
    train_args = {**DEFAULT_ARGS, **cli_overrides}

    # The trainer is created first so its optimizer exists and can have its
    # state restored alongside the model weights.
    trainer = Trainer(game, model, train_args)
    if args.resume:
        load_checkpoint(model, args.checkpoint, device, optimizer=trainer.optimizer)
    trainer.learn()
|
|
|
|
def eval_command(args):
    """Handle the ``eval`` sub-command: inspect the model on one position.

    Replays ``--moves`` from the initial position, queries the model on the
    canonical form of the resulting state, masks the policy with the legal
    moves (renormalising when any mass remains) and prints the board, the
    side to move, the active small board, the model value and the top
    ``--top-k`` policy entries. With ``--with-mcts`` it additionally runs a
    search and prints the selected move.

    Args:
        args: Parsed ``argparse`` namespace of the ``eval`` sub-parser.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        ValueError: If ``--moves`` cannot be parsed or contains an illegal
            move.
    """
    device = get_device(args.device)
    game = UltimateTicTacToe()
    model = build_model(game, device)
    load_checkpoint(model, args.checkpoint, device)

    state, player = apply_moves(game, parse_moves(args.moves))
    current_state = canonical_state(game, state, player)
    policy, value = model.predict(game.encode_state(current_state))

    # Zero out illegal moves, then renormalise unless nothing is left.
    legal_mask = np.array(game.get_valid_moves(state), dtype=np.float32)
    policy = policy * legal_mask
    total = policy.sum()
    if total > 0:
        policy = policy / total

    side_to_move = "X" if player == 1 else "O"
    print("Board:")
    print(format_board(state[0]))
    print()
    print(f"Side to move: {side_to_move}")
    print(f"Active small board: {state[1]}")
    print(f"Model value: {scalar_value(value):.4f}")
    print("Top policy moves:")
    for action, prob in top_policy_moves(policy, args.top_k):
        print(f" {action:2d} -> {prob:.4f}")

    if not args.with_mcts:
        return

    # Both root-noise keys are set to None for evaluation (presumably this
    # disables root exploration noise -- confirm in MCTS).
    mcts_args = {
        **DEFAULT_ARGS,
        "num_simulations": args.num_simulations,
        "root_dirichlet_alpha": None,
        "root_exploration_fraction": None,
    }
    root = MCTS(game, model, mcts_args).run(model, current_state, to_play=1)
    action = root.select_action(temperature=0)
    print(f"MCTS best move: {action}")
|
|
|
|
def ai_action(game, model, state, player, num_simulations):
    """Pick a move for ``player`` by running MCTS with ``model``.

    The search is run on the canonical form of ``state`` with ``to_play=1``,
    and the action is selected from the root with ``temperature=0``.

    Args:
        game: Game rules object.
        model: Model passed to ``MCTS``.
        state: Current ``(board_data, active_board)`` state.
        player: Player to move.
        num_simulations: Value stored under ``"num_simulations"`` in the
            search configuration.

    Returns:
        The action returned by ``root.select_action(temperature=0)``.
    """
    # Both root-noise keys are set to None for play (presumably this
    # disables root exploration noise -- confirm in MCTS).
    search_config = {
        **DEFAULT_ARGS,
        "num_simulations": num_simulations,
        "root_dirichlet_alpha": None,
        "root_exploration_fraction": None,
    }
    searcher = MCTS(game, model, search_config)
    root = searcher.run(model, canonical_state(game, state, player), to_play=1)
    return root.select_action(temperature=0)
|
|
|
|
def random_action(game, state):
    """Pick a uniformly random legal action.

    Args:
        game: Object exposing ``get_valid_moves(state)``, which returns one
            truthy/falsy flag per action index.
        state: State forwarded to ``get_valid_moves``.

    Returns:
        The chosen action index as a plain ``int``, drawn with
        ``np.random.choice`` (NumPy's global random state).

    Raises:
        ValueError: If no action is flagged as legal.
    """
    legal_actions = []
    for index, allowed in enumerate(game.get_valid_moves(state)):
        if allowed:
            legal_actions.append(index)

    if legal_actions:
        return int(np.random.choice(legal_actions))
    raise ValueError("No legal actions available.")
|
|
|
|
def load_player_model(game, checkpoint, device):
    """Build a model for ``game`` and load ``checkpoint`` into it.

    Args:
        game: Game rules object used to size the model.
        checkpoint: Path of the checkpoint file; it must exist.
        device: Device identifier for the model and for loading.

    Returns:
        The model with the checkpoint weights loaded.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    player_model = build_model(game, device)
    load_checkpoint(player_model, checkpoint, device)
    return player_model
|
|
|
|
def choose_action(game, player_kind, model, state, player, num_simulations):
    """Select a move according to the kind of agent.

    Args:
        game: Game rules object.
        player_kind: ``"random"`` for a uniformly random legal move; any
            other value uses the model-guided search.
        model: Model for the search (unused for ``"random"``).
        state: Current game state.
        player: Player to move.
        num_simulations: Search budget (unused for ``"random"``).

    Returns:
        The chosen action.
    """
    if player_kind != "random":
        return ai_action(game, model, state, player, num_simulations)
    return random_action(game, state)
|
|
|
|
def play_match(game, x_kind, x_model, o_kind, o_model, num_simulations):
    """Play one full game between two agents.

    Player ``1`` uses the ``x_*`` agent and moves first; the other player
    uses the ``o_*`` agent. The game ends as soon as
    ``game.get_reward_for_player`` returns something other than ``None``.

    Args:
        game: Game rules object.
        x_kind: Agent kind for player ``1`` (see ``choose_action``).
        x_model: Model for player ``1`` (may be ``None`` for ``"random"``).
        o_kind: Agent kind for the other player.
        o_model: Model for the other player.
        num_simulations: Search budget forwarded to ``choose_action``.

    Returns:
        ``0`` when the final reward is ``0``; otherwise the winning player:
        the player to move if their reward is ``1``, else their opponent.
    """
    state, player = game.get_init_board(), 1

    reward = game.get_reward_for_player(state, player)
    while reward is None:
        kind, model = (x_kind, x_model) if player == 1 else (o_kind, o_model)
        action = choose_action(game, kind, model, state, player, num_simulations)
        state, player = game.get_next_state(state, player, action)
        reward = game.get_reward_for_player(state, player)

    if reward == 0:
        return 0
    # The reward is from the point of view of the player to move.
    return player if reward == 1 else -player
|
|
|
|
def arena_command(args):
    """Handle the ``arena`` sub-command: run repeated matches and tally them.

    A model is loaded only for a side whose player kind is ``"checkpoint"``;
    a ``"random"`` side gets ``None``. After ``--games`` matches the win and
    draw counts are printed.

    Args:
        args: Parsed ``argparse`` namespace of the ``arena`` sub-parser.

    Raises:
        FileNotFoundError: If a required checkpoint file does not exist.
    """
    device = get_device(args.device)
    game = UltimateTicTacToe()

    x_model = (
        load_player_model(game, args.x_checkpoint, device)
        if args.x_player == "checkpoint"
        else None
    )
    o_model = (
        load_player_model(game, args.o_checkpoint, device)
        if args.o_player == "checkpoint"
        else None
    )

    # Keyed by play_match's return value: 1 = X win, -1 = O win, 0 = draw.
    results = {1: 0, -1: 0, 0: 0}
    for _ in range(args.games):
        outcome = play_match(
            game, args.x_player, x_model, args.o_player, o_model, args.num_simulations
        )
        results[outcome] += 1

    print(f"Games: {args.games}")
    print(f"X ({args.x_player}) wins: {results[1]}")
    print(f"O ({args.o_player}) wins: {results[-1]}")
    print(f"Draws: {results[0]}")
|
|
|
|
def play_command(args):
    """Handle the ``play`` sub-command: interactive human-vs-model game.

    Each turn prints the board, the side to move and the active small board.
    If the game is over the result is printed and the function returns.
    Otherwise the legal moves are listed and either the human is prompted
    (re-prompting until a legal move is entered) or the model chooses a move
    through ``ai_action``.

    Args:
        args: Parsed ``argparse`` namespace of the ``play`` sub-parser.
            ``args.human_player`` is the player identifier the human
            controls.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    device = get_device(args.device)
    game = UltimateTicTacToe()
    model = build_model(game, device)
    load_checkpoint(model, args.checkpoint, device)

    state, player = game.get_init_board(), 1
    human_player = args.human_player

    while True:
        turn_label = "X" if player == 1 else "O"
        print()
        print(format_board(state[0]))
        print(f"Turn: {turn_label}")
        print(f"Active small board: {state[1]}")

        reward = game.get_reward_for_player(state, player)
        if reward is not None:
            if reward == 0:
                print("Result: draw")
                return
            # The reward is from the point of view of the player to move.
            winner = player if reward == 1 else -player
            print(f"Winner: {'X' if winner == 1 else 'O'}")
            return

        legal_actions = [
            index for index, allowed in enumerate(game.get_valid_moves(state)) if allowed
        ]
        print(f"Legal moves: {legal_actions}")

        if player != human_player:
            action = ai_action(game, model, state, player, args.num_simulations)
            print(f"AI move: {action}")
            state, player = game.get_next_state(state, player, action)
            continue

        # Keep prompting until the human enters a parseable, legal move;
        # every ValueError is shown to the user and the prompt repeats.
        accepted = False
        while not accepted:
            try:
                action = parse_action(input("Your move (index or 'row col'): "))
                next_state = game.get_next_state(state, player, action, verify_move=True)
                if next_state is False:
                    raise ValueError(f"Illegal move: {action}")
                state, player = next_state
            except ValueError as exc:
                print(exc)
            else:
                accepted = True
|
|
|
|
def build_parser():
    """Build the command-line parser with its four sub-commands.

    Sub-commands are ``train``, ``eval``, ``play`` and ``arena``; one is
    required. Each sub-parser stores its handler under ``func`` so the caller
    can dispatch with ``args.func(args)``. Flag defaults come from
    ``DEFAULT_ARGS`` where a matching key exists.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description="Ultimate Tic-Tac-Toe Runner")
    subparsers = parser.add_subparsers(dest="command", required=True)

    default_checkpoint = DEFAULT_ARGS["checkpoint_path"]
    default_simulations = DEFAULT_ARGS["num_simulations"]

    # --- train -----------------------------------------------------------
    train = subparsers.add_parser("train", help="Train the model with self-play")
    train.add_argument("--device")
    train.add_argument("--checkpoint", default=default_checkpoint)
    train.add_argument("--resume", action="store_true")
    # (flag, value type, DEFAULT_ARGS key supplying the default)
    train_tunables = (
        ("--num-simulations", int, "num_simulations"),
        ("--num-iters", int, "numIters"),
        ("--num-eps", int, "numEps"),
        ("--epochs", int, "epochs"),
        ("--batch-size", int, "batch_size"),
        ("--lr", float, "lr"),
        ("--weight-decay", float, "weight_decay"),
        ("--replay-buffer-size", int, "replay_buffer_size"),
        ("--value-loss-weight", float, "value_loss_weight"),
        ("--grad-clip-norm", float, "grad_clip_norm"),
        ("--temperature-threshold", int, "temperature_threshold"),
        ("--root-dirichlet-alpha", float, "root_dirichlet_alpha"),
        ("--root-exploration-fraction", float, "root_exploration_fraction"),
        ("--arena-compare-games", int, "arena_compare_games"),
        ("--arena-accept-threshold", float, "arena_accept_threshold"),
        ("--arena-compare-simulations", int, "arena_compare_simulations"),
    )
    for flag, value_type, key in train_tunables:
        train.add_argument(flag, type=value_type, default=DEFAULT_ARGS[key])
    train.set_defaults(func=train_command)

    # --- eval ------------------------------------------------------------
    evaluate = subparsers.add_parser("eval", help="Inspect a checkpoint on a position")
    evaluate.add_argument("--device")
    evaluate.add_argument("--checkpoint", default=default_checkpoint)
    evaluate.add_argument("--moves", default="", help="Comma-separated move sequence")
    evaluate.add_argument("--top-k", type=int, default=10)
    evaluate.add_argument("--with-mcts", action="store_true")
    evaluate.add_argument("--num-simulations", type=int, default=default_simulations)
    evaluate.set_defaults(func=eval_command)

    # --- play ------------------------------------------------------------
    play = subparsers.add_parser("play", help="Play against the checkpoint")
    play.add_argument("--device")
    play.add_argument("--checkpoint", default=default_checkpoint)
    play.add_argument("--human-player", type=int, choices=[1, -1], default=1)
    play.add_argument("--num-simulations", type=int, default=default_simulations)
    play.set_defaults(func=play_command)

    # --- arena -----------------------------------------------------------
    arena = subparsers.add_parser("arena", help="Run repeated matches between agents")
    arena.add_argument("--device")
    arena.add_argument("--games", type=int, default=20)
    arena.add_argument("--num-simulations", type=int, default=default_simulations)
    agent_kinds = ["checkpoint", "random"]
    arena.add_argument("--x-player", choices=agent_kinds, default="checkpoint")
    arena.add_argument("--o-player", choices=agent_kinds, default="random")
    arena.add_argument("--x-checkpoint", default=default_checkpoint)
    arena.add_argument("--o-checkpoint", default=default_checkpoint)
    arena.set_defaults(func=arena_command)

    return parser
|
|
|
|
def main():
    """Parse the command line and run the selected sub-command.

    The handler is the ``func`` default that ``build_parser`` attaches to
    each sub-parser; it receives the parsed namespace.
    """
    args = build_parser().parse_args()
    handler = args.func
    handler(args)
|
|
|
|
# Run the CLI only when this file is executed as a script, so importing the
# module (e.g. to reuse its helpers) has no side effects.
if __name__ == "__main__":
    main()
|
|