import logging
import math

import numpy as np

EPS = 1e-8

log = logging.getLogger(__name__)


class MCTS:
    """
    This class handles the MCTS tree.
    """

    def __init__(self, game, nnet, args):
        self.game = game
        self.nnet = nnet
        self.args = args
        self.Qsa = {}  # stores Q values for edge (s, a)
        self.Nsa = {}  # stores visit counts for edge (s, a)
        self.Ns = {}  # stores visit counts for state s
        self.Ps = {}  # stores the initial policy returned by the neural net for state s

        self.Es = {}  # caches game.getGameEnded for state s
        self.Vs = {}  # caches game.getValidMoves for state s

    def getActionProb(self, canonicalBoard, temp=1):
| """ |
| This function performs numMCTSSims simulations of MCTS starting from |
| canonicalBoard. |
| |
| Returns: |
| probs: a policy vector where the probability of the ith action is |
| proportional to Nsa[(s,a)]**(1./temp) |
| """ |
| for i in range(self.args.numMCTSSims): |
| |
| self.game.reset_steps() |
| self.search(canonicalBoard) |
|
|
| s = self.game.stringRepresentation(canonicalBoard) |
| counts = [self.Nsa[(s, a)] if (s, a) in self.Nsa else 0 for a in range(self.game.getActionSize())] |
|
|
| if temp == 0: |
| bestAs = np.array(np.argwhere(counts == np.max(counts))).flatten() |
| bestA = np.random.choice(bestAs) |
| probs = [0] * len(counts) |
| probs[bestA] = 1 |
| return probs |
|
|
| counts = [x ** (1. / temp) for x in counts] |
| counts_sum = float(sum(counts)) |
| if counts_sum == 0: |
| print(len(counts)) |
| probs = [x / counts_sum for x in counts] |
| return probs |
| |
| def search_iterate(self, canonicalBoard): |
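        """
        Iterative, explicit-stack version of search(), useful when search
        paths grow deep enough to threaten Python's recursion limit.

        Each stack frame is a (state, payload) pair: a state-0 frame expands
        or selects at a node and, for non-leaf nodes, schedules a state-1
        frame that later backs the child's value up through the edge (s, a).
        Values flow between frames via the results stack, negated at each ply
        exactly as in search(). Note that, unlike search(), this version has
        no depth-based random-move fallback.
        """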
        stack = [(0, (canonicalBoard,))]
        results = []

        while stack:
            st, sv = stack.pop()
            if st == 0:
                result, ns = self.search_iterate_st0(sv[0])
                if result is not None:
                    # terminal or leaf node: its value goes straight onto the
                    # results stack
                    results.append(result)
                if ns is not None:
                    # ns = (next_s, s, best_act): schedule a backup of the
                    # edge (s, best_act) to run after the child is searched
                    stack.append((1, (ns[1], ns[2])))
                    stack.append((0, (ns[0],)))
            elif st == 1:
                v = results.pop()
                v = self.search_iterate_update(v, sv[0], sv[1])
                results.append(v)
            else:
                raise ValueError("Invalid state")
        return results.pop()

    def search_iterate_st0(self, canonicalBoard):
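        """
        Expansion/selection step for search_iterate. Returns a (result, ns)
        pair: when the node is terminal or a freshly expanded leaf, result is
        the negated value and ns is None; otherwise result is None and
        ns = (next_s, s, best_act) describes the child chosen by the upper
        confidence bound.
        """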
        s = self.game.stringRepresentation(canonicalBoard)

        if s not in self.Es:
            self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)
        if self.Es[s] != 0:
            # terminal node
            return -self.Es[s], None
        if s not in self.Ps:
            # leaf node
            self.Ps[s], v = self.nnet.predict(canonicalBoard)
            valids = self.game.getValidMoves(canonicalBoard, 1)
            self.Ps[s] = self.Ps[s] * valids  # mask invalid moves
            sum_Ps_s = np.sum(self.Ps[s])
            if sum_Ps_s > 0:
                self.Ps[s] /= sum_Ps_s  # renormalize
            else:
                # all valid moves were masked; fall back to a uniform policy
                # over valid moves (same workaround as in search())
                log.error("All valid moves were masked, doing a workaround.")
                self.Ps[s] = self.Ps[s] + valids
                self.Ps[s] /= np.sum(self.Ps[s])

            self.Vs[s] = valids
            self.Ns[s] = 0

            return -v, None

        valids = self.Vs[s]
        cur_best = -float('inf')
        best_act = -1
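
        # pick the valid action with the highest upper confidence bound (PUCT):
        #   u(s, a) = Q(s, a) + cpuct * P(s, a) * sqrt(N(s)) / (1 + N(s, a))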
        for a in range(self.game.getActionSize()):
            if valids[a]:
                if (s, a) in self.Qsa:
                    u = self.Qsa[(s, a)] + self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s]) / (1 + self.Nsa[(s, a)])
                else:
                    u = self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s] + EPS)  # Q = 0

                if u > cur_best:
                    cur_best = u
                    best_act = a

        next_s, next_player = self.game.getNextState(canonicalBoard, 1, best_act)
        next_s = self.game.getCanonicalForm(next_s, next_player)

        return None, (next_s, s, best_act)

    def search_iterate_update(self, v, s, a):
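        """
        Backup step for search_iterate: folds the child's value v into the
        running mean Q(s, a), increments the visit counts, and returns -v for
        the parent ply.
        """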
        if (s, a) in self.Qsa:
            # incremental-mean update of Q(s, a)
            self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
            self.Nsa[(s, a)] += 1
        else:
            self.Qsa[(s, a)] = v
            self.Nsa[(s, a)] = 1

        self.Ns[s] += 1
        return -v

    def search(self, canonicalBoard, depth=0):
| """ |
| This function performs one iteration of MCTS. It is recursively called |
| till a leaf node is found. The action chosen at each node is one that |
| has the maximum upper confidence bound as in the paper. |
| |
| Once a leaf node is found, the neural network is called to return an |
| initial policy P and a value v for the state. This value is propagated |
| up the search path. In case the leaf node is a terminal state, the |
| outcome is propagated up the search path. The values of Ns, Nsa, Qsa are |
| updated. |
| |
| NOTE: the return values are the negative of the value of the current |
| state. This is done since v is in [-1,1] and if v is the value of a |
| state for the current player, then its value is -v for the other player. |
| |
| Returns: |
| v: the negative of the value of the current canonicalBoard |
| """ |
|
|
| s = self.game.stringRepresentation(canonicalBoard) |
|
|
| if s not in self.Es: |
| self.Es[s] = self.game.getGameEnded(canonicalBoard, 1) |
| if self.Es[s] != 0: |
| |
| return -self.Es[s] |
|
|
| if s not in self.Ps: |
| |
| self.Ps[s], v = self.nnet.predict(canonicalBoard) |
| valids = self.game.getValidMoves(canonicalBoard, 1) |
| self.Ps[s] = self.Ps[s] * valids |
| sum_Ps_s = np.sum(self.Ps[s]) |
| if sum_Ps_s > 0: |
| self.Ps[s] /= sum_Ps_s |
| else: |
| |
|
|
| |
| |
| log.error("All valid moves were masked, doing a workaround.") |
| self.Ps[s] = self.Ps[s] + valids |
| self.Ps[s] /= np.sum(self.Ps[s]) |
|
|
| self.Vs[s] = valids |
| self.Ns[s] = 0 |
| return -v |
|
|
| valids = self.Vs[s] |
| cur_best = -float('inf') |
| best_act = -1 |
|
|
| |
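
        # pick the valid action with the highest upper confidence bound (PUCT):
        #   u(s, a) = Q(s, a) + cpuct * P(s, a) * sqrt(N(s)) / (1 + N(s, a))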
        for a in range(self.game.getActionSize()):
            if valids[a]:
                if (s, a) in self.Qsa:
                    u = self.Qsa[(s, a)] + self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s]) / (
                            1 + self.Nsa[(s, a)])
                else:
                    u = self.args.cpuct * self.Ps[s][a] * math.sqrt(self.Ns[s] + EPS)  # Q = 0

                if u > cur_best:
                    cur_best = u
                    best_act = a

        a = best_act

        if depth > 100:
            # A very deep search path usually means the simulation is stuck in
            # a cycle of repeated states; break out of it by playing a random
            # valid move instead of the UCB choice, and rewind the depth
            # counter so the fallback can trigger again if the search gets
            # stuck once more.
            candidates = self.game.getValidMoves(canonicalBoard, 1)
            a = np.random.choice([i for i in range(len(candidates)) if candidates[i] == 1])
            depth = 80

        next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
        next_s = self.game.getCanonicalForm(next_s, next_player)

        v = self.search(next_s, depth=depth + 1)

        if (s, a) in self.Qsa:
            # incremental-mean update of Q(s, a)
            self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
            self.Nsa[(s, a)] += 1
        else:
            self.Qsa[(s, a)] = v
            self.Nsa[(s, a)] = 1

        self.Ns[s] += 1
        return -v
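

# Example usage (a sketch, not part of the original module): game, nnet, and
# args are assumed to follow the interfaces used above, with args providing
# numMCTSSims and cpuct:
#
#     mcts = MCTS(game, nnet, args)
#     pi = mcts.getActionProb(canonicalBoard, temp=1)
#     action = np.random.choice(len(pi), p=pi)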