| import math
|
| import numpy as np
|
| import torch
|
| from src.game import OthelloGame
|
| from src.bitboard import make_input_planes, bit_to_row_col, popcount
|
|
|
class MCTSNode:
    """A single node in the Monte Carlo search tree."""

    def __init__(self, prior, to_play):
        # prior: policy probability assigned to the move that leads to this node.
        # to_play: the player to move at this node (1 or -1 — TODO confirm encoding).
        self.prior = prior
        self.visit_count = 0
        self.value_sum = 0
        self.children = {}
        self.to_play = to_play

    def value(self):
        """Mean backed-up value of this node; 0 for an unvisited node."""
        return self.value_sum / self.visit_count if self.visit_count else 0

    def expand(self, policy_logits, valid_moves, next_to_play):
        """Create one child per valid move.

        Child priors come from a softmax over the full policy head,
        renormalized so that the mass over the valid moves sums to 1.
        A move bit of 0 denotes a pass and maps to policy index 64
        (presumably bit_to_row_col also signals a pass via r == -1 —
        verify against src.bitboard).
        """
        # Numerically stable softmax over all 65 policy outputs.
        shifted = policy_logits - np.max(policy_logits)
        probs = np.exp(shifted)
        probs = probs / np.sum(probs)

        move_priors = {}
        mass = 0
        for move_bit in valid_moves:
            if move_bit == 0:
                idx = 64  # pass move
            else:
                r, c = bit_to_row_col(move_bit)
                idx = 64 if r == -1 else r * 8 + c
            p = probs[idx]
            mass += p
            move_priors[move_bit] = p

        if mass > 0:
            # Renormalize over the valid moves only.
            for move_bit, p in move_priors.items():
                self.children[move_bit] = MCTSNode(prior=p / mass, to_play=next_to_play)
        else:
            # Degenerate policy (all mass on invalid moves): fall back to uniform.
            uniform = 1.0 / len(valid_moves)
            for move_bit in valid_moves:
                self.children[move_bit] = MCTSNode(prior=uniform, to_play=next_to_play)
|
|
|
class MCTS:
    """AlphaZero-style Monte Carlo Tree Search driven by a policy/value network."""

    def __init__(self, model, cpuct=1.0, num_simulations=800):
        # model: network mapping input planes -> (policy_logits, value).
        # cpuct: exploration constant in the PUCT selection formula.
        # num_simulations: number of tree-search rollouts per call to search().
        self.model = model
        self.cpuct = cpuct
        self.num_simulations = num_simulations

    def search(self, game: "OthelloGame"):
        """Run MCTS simulations from `game` and return the root MCTSNode.

        Returns None when the position is terminal (nothing to search).
        The root's children carry visit counts the caller uses for move
        selection.
        """
        valid_moves_bb = game.get_valid_moves(game.player_bb, game.opponent_bb)
        valid_moves_list = self._get_moves_list(valid_moves_bb)

        if valid_moves_bb == 0:
            if game.is_terminal():
                return None
            # No legal placements but the game continues: the only move is a
            # pass, encoded as move bit 0 (policy index 64).
            valid_moves_list = [0]

        root = MCTSNode(prior=0, to_play=game.turn)

        self.model.eval()
        policy_logits, _ = self._evaluate(game.player_bb, game.opponent_bb)
        root.expand(policy_logits, valid_moves_list, -game.turn)

        # Root exploration noise (AlphaZero constants: eps=0.25, alpha=0.3).
        self._add_dirichlet_noise(root)

        for _ in range(self.num_simulations):
            node = root
            sim_game = self._clone_game(game)
            search_path = [node]
            last_value = 0

            # Selection: descend via PUCT until an unexpanded leaf is reached.
            while node.children:
                move_bit, node = self._select_child(node)
                search_path.append(node)
                sim_game.play_move(move_bit)

            if sim_game.is_terminal():
                # Exact game outcome, from Black's perspective.
                last_value = self._terminal_value(sim_game)
            else:
                # Expansion + evaluation via the network.
                policy_logits, val_for_current = self._evaluate(
                    sim_game.player_bb, sim_game.opponent_bb
                )

                # The network evaluates from the side-to-move's perspective;
                # convert to Black's perspective for backpropagation.
                last_value = val_for_current if sim_game.turn == 1 else -val_for_current

                valid_bb = sim_game.get_valid_moves(sim_game.player_bb, sim_game.opponent_bb)
                # Empty bitboard means the side to move must pass (move bit 0).
                valid_list = self._get_moves_list(valid_bb) or [0]
                node.expand(policy_logits, valid_list, -sim_game.turn)

            self._backpropagate(search_path, last_value)

        return root

    def _evaluate(self, player_bb, opponent_bb):
        """Run the network on one position.

        Returns (flat policy-logits ndarray, value float). The device lookup
        and tensor transfer live here so root and leaf evaluation share one
        code path.
        """
        state_tensor = make_input_planes(player_bb, opponent_bb)
        device = next(self.model.parameters()).device
        state_tensor = state_tensor.to(device)
        with torch.no_grad():
            policy_logits, v = self.model(state_tensor)
        return policy_logits.cpu().numpy().flatten(), v.item()

    def _terminal_value(self, sim_game):
        """Final result from Black's perspective: 1.0/-1.0/0.0 for Black win/loss/draw.

        player_bb always belongs to the side to move, so turn decides which
        bitboard is Black's.
        """
        if sim_game.turn == 1:
            black_score = popcount(sim_game.player_bb)
            white_score = popcount(sim_game.opponent_bb)
        else:
            black_score = popcount(sim_game.opponent_bb)
            white_score = popcount(sim_game.player_bb)

        diff = black_score - white_score
        if diff > 0:
            return 1.0
        if diff < 0:
            return -1.0
        return 0.0

    def _select_child(self, node):
        """Return (move, child) maximizing the PUCT score Q + U.

        Stored values are from Black's perspective, so Q is negated when
        White is to play — both players then maximize their own outcome.
        """
        best_score = -float('inf')
        best_action = None
        best_child = None

        for action, child in node.children.items():
            mean_val = child.value()
            q = mean_val if node.to_play == 1 else -mean_val

            # Exploration bonus: high for high-prior, rarely-visited children.
            u = self.cpuct * child.prior * math.sqrt(node.visit_count) / (1 + child.visit_count)

            score = q + u
            if score > best_score:
                best_score, best_action, best_child = score, action, child

        return best_action, best_child

    def _backpropagate(self, search_path, value):
        """Add `value` (the leaf evaluation, from BLACK's perspective:
        1=Black wins, -1=White wins) to every node on the path and bump
        visit counts.
        """
        for node in search_path:
            node.value_sum += value
            node.visit_count += 1

    def _add_dirichlet_noise(self, node):
        """Mix Dirichlet noise into the root priors to encourage exploration."""
        eps = 0.25
        alpha = 0.3
        moves = list(node.children.keys())
        noise = np.random.dirichlet([alpha] * len(moves))

        for i, move in enumerate(moves):
            node.children[move].prior = (1 - eps) * node.children[move].prior + eps * noise[i]

    def _get_moves_list(self, moves_bb):
        """Split a move bitboard into a list of single-bit moves, lowest bit first."""
        moves = []
        temp = moves_bb
        while temp:
            lsb = temp & -temp  # isolate lowest set bit
            moves.append(lsb)
            temp ^= lsb
        return moves

    def _clone_game(self, game):
        """Lightweight copy of the game state for a simulation rollout."""
        new_game = OthelloGame()
        new_game.player_bb = game.player_bb
        new_game.opponent_bb = game.opponent_bb
        new_game.turn = game.turn
        return new_game
|
|
|