Spaces:
Sleeping
Sleeping
| import gymnasium as gym | |
| from gymnasium import spaces | |
| import numpy as np | |
| import tensorflow as tf | |
| from tensorflow.keras.models import Model | |
| from tensorflow.keras.layers import Input, Dense, Flatten, Conv2D | |
| from tensorflow.keras.optimizers import Adam | |
| import collections | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| # AaduPulliEnv | |
class AaduPulliEnv(gym.Env):
    """Gymnasium environment for Aadu Puli Aattam (goats and tigers).

    Board points are numbered 1..23 in the adjacency, jump and coordinate
    tables, and stored 0-indexed in ``self.board`` (0 = empty, 1 = goat,
    2 = tiger). ``player_turn`` 0 is the goats and 1 is the tigers.

    Game rules as implemented here:
    - Tigers start on points 1, 4 and 5.
    - Goats first place ``NUM_GOATS`` pieces, one per turn, and may only move
      once all of them are placed.
    - Tigers step to an adjacent empty point, or capture by jumping over an
      adjacent goat onto an empty point.
    - Tigers win after ``TIGER_WIN_THRESHOLD`` captures.
    - Goats win when the tigers have no move.
    - The game is a draw (winner -1) after ``MAX_TURNS`` plies.

    Actions ``0 .. BOARD_POSITIONS - 1`` place a goat on that 0-indexed point.
    Higher actions, minus ``placement_actions``, index ``_move_action_map``,
    whose values are 1-indexed ``(from_pos, to_pos)`` pairs.

    ``reset`` returns only the observation, and ``step`` returns the legacy
    4-tuple ``(obs, reward, done, info)`` rather than Gymnasium's 5-tuple.
    """

    metadata = {'render.modes': ['human', 'rgb_array']}

    def __init__(self):
        super().__init__()
        self.NUM_GOATS = 15
        self.NUM_TIGERS = 3
        self.TIGER_WIN_THRESHOLD = 10
        self.BOARD_POSITIONS = 23
        self.MAX_TURNS = 200
        # Coordinates are needed before the jump table: jump midpoints are
        # resolved geometrically.
        self.board_points = self._get_board_coordinates()
        self.adj = self._get_adjacency()
        self.jump_adj = self._get_jump_adjacency()
        # (from_pos, to_pos) -> jumped-over position, for every jump in jump_adj.
        self.jump_mid = self._get_jump_midpoints()
        self.placement_actions = self.BOARD_POSITIONS
        self._move_action_map, self._move_action_lookup = self._create_move_maps()
        self.move_actions_count = len(self._move_action_map)
        total_actions = self.placement_actions + self.move_actions_count
        self.action_space = spaces.Discrete(total_actions)
        self.observation_space = spaces.Dict({
            "board": spaces.Box(low=0, high=2, shape=(self.BOARD_POSITIONS,), dtype=np.int32),
            "player_turn": spaces.Discrete(2),
            "goats_to_place": spaces.Box(low=0, high=self.NUM_GOATS, shape=(1,), dtype=np.int32),
            "goats_captured": spaces.Box(low=0, high=self.TIGER_WIN_THRESHOLD, shape=(1,), dtype=np.int32),
        })
        self.reset()

    def _get_adjacency(self):
        """Return the 1-indexed single-step neighbour lists of every point.

        Below apex 1 the rows are 2-7, 8-13, 14-19 and 20-23. The outer
        columns are 2-8-14 and 7-13-19.
        """
        # 13 links down to 19. This mirrors 8 -> 14, matches 19's own list
        # ([13, 18]) and the jump 7 -> 19 over 13.
        # NOTE(review): render() draws the outer verticals down to y=4, which
        # would also join 14-20 and 19-23. Those links are absent here;
        # confirm which board variant is intended.
        return {
            1: [3, 4, 5, 6], 2: [3, 8], 3: [1, 4, 9, 2], 4: [1, 5, 10, 3], 5: [1, 6, 11, 4], 6: [1, 7, 12, 5], 7: [6, 13],
            8: [2, 9, 14], 9: [3, 10, 15, 8], 10: [4, 11, 16, 9], 11: [5, 12, 17, 10], 12: [6, 13, 18, 11], 13: [7, 19, 12],
            14: [8, 15], 15: [9, 16, 20, 14], 16: [10, 17, 21, 15], 17: [11, 18, 22, 16], 18: [12, 19, 23, 17], 19: [13, 18],
            20: [15, 21], 21: [16, 20, 22], 22: [17, 21, 23], 23: [18, 22],
        }

    def _get_jump_adjacency(self):
        """Return, for every point, the 1-indexed landing points of a two-step jump."""
        return {
            1: [9, 10, 11, 12], 2: [4, 14], 3: [5, 15], 4: [2, 6, 16], 5: [3, 7, 17], 6: [4, 18], 7: [5, 19],
            8: [10], 9: [1, 11, 20], 10: [1, 8, 12, 21], 11: [1, 9, 13, 22], 12: [1, 10, 23], 13: [11],
            14: [2, 16], 15: [3, 17], 16: [4, 14, 18], 17: [5, 15, 19], 18: [6, 16], 19: [7, 17],
            20: [9, 22], 21: [10, 23], 22: [11, 20], 23: [12, 21],
        }

    def _get_jump_midpoints(self):
        """Map each jump ``(from_pos, to_pos)`` to the point being jumped over.

        The jumped point must neighbour both ends. Next to the apex two points
        can qualify: 3 and 5 share both 1 and 4, and 4 and 6 share both 1 and 5.
        The one on the straight line is the common neighbour closest to the
        geometric midpoint of the jump. Jumps with no common neighbour are
        omitted.
        """
        coords = self.board_points
        mids = {}
        for from_pos, targets in self.jump_adj.items():
            from_neighbours = set(self.adj.get(from_pos, []))
            fx, fy = coords[from_pos]
            for to_pos in targets:
                common = from_neighbours.intersection(self.adj.get(to_pos, []))
                if not common:
                    continue
                tx, ty = coords[to_pos]
                cx, cy = (fx + tx) / 2, (fy + ty) / 2
                # (squared distance, point): ties fall back to the lower point.
                mids[(from_pos, to_pos)] = min(
                    ((coords[m][0] - cx) ** 2 + (coords[m][1] - cy) ** 2, m) for m in common
                )[1]
        return mids

    def _create_move_maps(self):
        """Enumerate every geometric move as a 1-indexed ``(from_pos, to_pos)`` pair.

        Returns ``(index -> move, move -> index)``. For each start point its
        steps come first and then its jumps. A move already listed is skipped,
        so every move gets exactly one index.
        """
        action_map, action_lookup = {}, {}
        for start_pos in range(1, self.BOARD_POSITIONS + 1):
            for end_pos in self.adj.get(start_pos, []) + self.jump_adj.get(start_pos, []):
                move = (start_pos, end_pos)
                if move in action_lookup:
                    continue
                index = len(action_map)
                action_map[index] = move
                action_lookup[move] = index
        return action_map, action_lookup

    def is_action_valid(self, action):
        """Check ``action`` against the current position.

        Returns ``(True, details)`` on success. ``details['type']`` is
        ``'place'`` (with ``to_idx``) or ``'move'`` (with ``from_idx`` and
        ``to_idx``). Tiger moves also carry ``is_jump``, plus ``mid_idx`` for a
        capture. All indices are 0-indexed board positions. On failure returns
        ``(False, {'error': message})``.
        """
        if not (0 <= action < self.action_space.n):
            return False, {'error': 'Action out of bounds.'}

        if action < self.placement_actions:
            to_idx = action
            if self.player_turn != 0 or self.goats_placed_count >= self.NUM_GOATS:
                return False, {'error': 'Cannot place piece now.'}
            if self.board[to_idx] != 0:
                return False, {'error': 'Destination square is not empty.'}
            return True, {'type': 'place', 'to_idx': to_idx}

        from_pos, to_pos = self._move_action_map[action - self.placement_actions]
        from_idx, to_idx = from_pos - 1, to_pos - 1
        if self.board[to_idx] != 0:
            return False, {'error': 'Destination square is not empty.'}

        if self.player_turn == 0:
            if self.goats_placed_count < self.NUM_GOATS:
                return False, {'error': 'Goat is still in placement phase.'}
            if self.board[from_idx] != 1:
                return False, {'error': 'Player must move a goat.'}
            if to_pos not in self.adj.get(from_pos, []):
                return False, {'error': 'Goat can only move to adjacent squares.'}
            return True, {'type': 'move', 'from_idx': from_idx, 'to_idx': to_idx}

        if self.board[from_idx] != 2:
            return False, {'error': 'Player must move a tiger.'}
        if to_pos in self.adj.get(from_pos, []):
            return True, {'type': 'move', 'from_idx': from_idx, 'to_idx': to_idx, 'is_jump': False}
        # jump_mid only holds jumps listed in jump_adj, so a miss means "not a jump".
        mid_pos = self.jump_mid.get((from_pos, to_pos))
        if mid_pos is not None and self.board[mid_pos - 1] == 1:
            return True, {'type': 'move', 'from_idx': from_idx, 'to_idx': to_idx, 'is_jump': True, 'mid_idx': mid_pos - 1}
        return False, {'error': 'Invalid tiger move.'}

    def _are_tigers_blocked(self):
        """Return True if no tiger can step to an empty neighbour or capture a goat."""
        for t_idx in np.flatnonzero(self.board == 2):
            t_pos = int(t_idx) + 1
            if any(self.board[dest - 1] == 0 for dest in self.adj.get(t_pos, [])):
                return False
            for dest_pos in self.jump_adj.get(t_pos, []):
                mid_pos = self.jump_mid.get((t_pos, dest_pos))
                if mid_pos is not None and self.board[dest_pos - 1] == 0 and self.board[mid_pos - 1] == 1:
                    return False
        return True

    def _get_current_observation(self):
        """Build the observation dict described by ``observation_space``."""
        return {
            "board": self.board.copy(),
            "player_turn": self.player_turn,
            "goats_to_place": np.array([self.NUM_GOATS - self.goats_placed_count], dtype=np.int32),
            "goats_captured": np.array([self.goats_captured_count], dtype=np.int32),
        }

    def reset(self, *, seed=None, options=None):
        """Start a new game and return the initial observation (no info dict).

        ``seed`` and ``options`` are accepted for Gymnasium compatibility.
        ``seed`` is forwarded to ``gym.Env.reset``. The starting position is
        fixed and ``options`` is ignored.
        """
        super().reset(seed=seed)
        self.board = np.zeros(self.BOARD_POSITIONS, dtype=np.int32)
        self.board[[0, 3, 4]] = 2  # tigers on points 1, 4 and 5
        self.player_turn = 0
        self.goats_placed_count = 0
        self.goats_captured_count = 0
        self.turn_count = 0
        return self._get_current_observation()

    def step(self, action):
        """Apply ``action`` for the player to move.

        Returns ``(obs, reward, done, info)``. ``info`` holds the validation
        details plus ``'winner'`` (0 goats, 1 tigers, -1 draw) when the game
        ends. Rewards:
        - -1 for an invalid action, which does not consume the turn;
        - +5 for a capture;
        - +100 when the goats win and -100 when the tigers win.
        """
        is_valid, details = self.is_action_valid(action)
        reward, done, info = 0, False, details
        if not is_valid:
            reward = -1
        else:
            if details['type'] == 'place':
                self.board[details['to_idx']] = 1
                self.goats_placed_count += 1
            elif details['type'] == 'move':
                piece = self.board[details['from_idx']]
                self.board[details['from_idx']] = 0
                self.board[details['to_idx']] = piece
                if piece == 2 and details.get('is_jump'):
                    self.board[details['mid_idx']] = 0
                    self.goats_captured_count += 1
                    reward = 5
            # Reaching the capture threshold is decisive, so check it first.
            # Tigers only lose when stuck on their own turn, i.e. right after a
            # goat has moved; after a tiger move the goats' reply could still
            # free a square.
            tigers_won = self.goats_captured_count >= self.TIGER_WIN_THRESHOLD
            goats_won = not tigers_won and self.player_turn == 0 and self._are_tigers_blocked()
            if tigers_won:
                done = True
                reward = -100
                info['winner'] = 1
            elif goats_won:
                done = True
                reward = 100
                info['winner'] = 0
            self.player_turn = 1 - self.player_turn
            self.turn_count += 1
        if not done and self.turn_count >= self.MAX_TURNS:
            done = True
            info['winner'] = -1
        return self._get_current_observation(), reward, done, info

    def render(self, mode='rgb_array'):
        """Draw the board and return it as an ``(H, W, 3)`` uint8 RGB array.

        ``mode`` is accepted for API compatibility. An RGB array is always
        returned.
        """
        fig, ax = plt.subplots(figsize=(8, 8))
        try:
            board_lines = [
                ([1, 23], [4, 4]), ([1, 23], [8, 8]), ([1, 23], [12, 12]), ([1, 23], [16, 16]),
                ([1, 1], [4, 16]), ([23, 23], [4, 16]),
                ([1, 12, 23], [4, 20, 4]), ([7, 12, 17], [4, 20, 4]),
            ]
            for xs, ys in board_lines:
                ax.plot(xs, ys, 'k')
            for idx, cell in enumerate(self.board):
                pos = idx + 1
                x, y = self.board_points[pos]
                if cell == 1:
                    ax.plot(x, y, 'o', ms=20, mfc='royalblue', mec='k', zorder=2)
                    ax.text(x, y, 'G', color='w', ha='center', va='center', fontsize=12, fontweight='bold')
                elif cell == 2:
                    ax.plot(x, y, 'o', ms=25, mfc='orangered', mec='k', zorder=2)
                    ax.text(x, y, 'T', color='w', ha='center', va='center', fontsize=12, fontweight='bold')
                else:
                    ax.plot(x, y, 'o', ms=20, mfc='lightgray', mec='k', zorder=1)
                    ax.text(x, y, str(pos), color='k', ha='center', va='center', fontsize=8)
            ax.set_xlim(0, 24)
            ax.set_ylim(0, 21)
            ax.set_aspect('equal')
            ax.axis('off')
            turn_txt = "Goat's Turn" if self.player_turn == 0 else "Tiger's Turn"
            title = f"Aadu Puli Aattam\n{turn_txt}\nGoats to Place: {self.NUM_GOATS-self.goats_placed_count} | Goats Captured: {self.goats_captured_count}"
            ax.set_title(title)
            fig.tight_layout()
            fig.canvas.draw()
            # The Agg buffer is already (height, width, 4). Copy the RGB planes
            # so the result does not alias the figure's memory.
            return np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
        finally:
            # Always release the figure, even if drawing fails.
            plt.close(fig)

    def _get_board_coordinates(self):
        """Return the (x, y) drawing coordinates of each 1-indexed board point."""
        return {
            1: (12, 20),
            2: (1, 16), 3: (9.2, 16), 4: (10.7, 16), 5: (13.3, 16), 6: (14.8, 16), 7: (23, 16),
            8: (1, 12), 9: (6.5, 12), 10: (9.5, 12), 11: (14.5, 12), 12: (17.5, 12), 13: (23, 12),
            14: (1, 8), 15: (3.8, 8), 16: (8.3, 8), 17: (15.7, 8), 18: (20.3, 8), 19: (23, 8),
            20: (1, 4), 21: (7, 4), 22: (17, 4), 23: (23, 4),
        }

    def copy(self):
        """Return an independent environment in the same game position."""
        new_env = AaduPulliEnv()
        new_env.board = self.board.copy()
        new_env.player_turn = self.player_turn
        new_env.goats_placed_count = self.goats_placed_count
        new_env.goats_captured_count = self.goats_captured_count
        new_env.turn_count = self.turn_count
        return new_env
| # NeuralNetwork | |
class NeuralNetwork:
    """Convolutional policy/value network with a shared trunk.

    The model takes one ``state_shape`` input (default ``(23, 23, 4)``). It
    outputs a softmax policy over ``action_space_size`` actions and a scalar
    ``tanh`` value in [-1, 1].
    """

    def __init__(self, action_space_size, learning_rate=0.001, state_shape=(23, 23, 4)):
        self.state_shape = tuple(state_shape)
        self.action_space_size = action_space_size
        self.learning_rate = learning_rate
        self.model = self._build_model()

    def _build_model(self):
        """Build and compile the two-headed Keras model."""
        input_layer = Input(shape=self.state_shape, name='matrix_input')
        x = Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu')(input_layer)
        x = Conv2D(filters=128, kernel_size=(3, 3), padding='same', activation='relu')(x)
        x = Flatten()(x)
        x = Dense(256, activation='relu')(x)
        policy_output = Dense(self.action_space_size, activation='softmax', name='policy_output')(x)
        value_output = Dense(1, activation='tanh', name='value_output')(x)
        model = Model(inputs=input_layer, outputs=[policy_output, value_output])
        model.compile(
            optimizer=Adam(self.learning_rate),
            loss={'policy_output': 'categorical_crossentropy', 'value_output': 'mean_squared_error'},
        )
        return model

    def predict(self, matrix_state):
        """Evaluate a single state.

        Returns ``[policy, value]`` as NumPy arrays of shape
        ``(1, action_space_size)`` and ``(1, 1)``.
        """
        batch = np.expand_dims(np.asarray(matrix_state, dtype=np.float32), axis=0)
        # Keras documents Model.predict as meant for large batched inputs and
        # recommends calling the model directly for small inputs. MCTS calls
        # this once per expanded node.
        outputs = self.model(batch, training=False)
        return [np.asarray(out) for out in outputs]
| # AlphaZeroAgent | |
class AlphaZeroAgent:
    """Monte Carlo tree search guided by a policy/value network (AlphaZero-style).

    Tree statistics are keyed by ``_get_state_key`` and persist across calls:
    - ``Q``: mean action value for the player to move;
    - ``N_sa``: visits per action;
    - ``N_s``: visits per state;
    - ``P``: network prior over actions.

    Values are negated on the way up the tree, which assumes players alternate
    every ply.
    """

    # Keeps the exploration term proportional to the prior on the first
    # selection after a node is expanded (N_s == 0). Without it, sqrt(0)
    # zeroes the term and the lowest-index valid action always wins.
    _UCB_EPS = 1e-8

    def _zero_array_factory(self):
        return np.zeros(self.env.action_space.n)

    def __init__(self, env, network, simulations_per_move=50, max_depth=25, c_puct=1.0):
        self.env = env
        self.network = network
        self.simulations_per_move = simulations_per_move
        self.max_depth = max_depth
        self.c_puct = c_puct
        self.Q = collections.defaultdict(self._zero_array_factory)
        self.N_sa = collections.defaultdict(self._zero_array_factory)
        self.N_s = collections.defaultdict(int)
        self.P = {}

    def _get_matrix_state(self, state):
        """Encode ``state`` as an ``(n, n, 4)`` float32 tensor for the network.

        Channels:
        0 - goats on the diagonal;
        1 - tigers on the diagonal;
        2 - the adjacency matrix of ``self.env.adj`` (1-indexed points);
        3 - ``player_turn`` broadcast over every cell.
        """
        board = state['board']
        n = len(board)
        matrix = np.zeros((n, n, 4), dtype=np.float32)
        np.fill_diagonal(matrix[:, :, 0], board == 1)
        np.fill_diagonal(matrix[:, :, 1], board == 2)
        adjacency = np.zeros((n, n), dtype=np.float32)
        for start, ends in self.env.adj.items():
            for end in ends:
                adjacency[start - 1, end - 1] = 1
        matrix[:, :, 2] = adjacency
        matrix[:, :, 3] = state['player_turn']
        return matrix

    def _valid_actions(self, env):
        """List every action index that ``env`` currently accepts."""
        return [a for a in range(env.action_space.n) if env.is_action_valid(a)[0]]

    def _env_at(self, state, depth):
        """Return a copy of ``self.env`` set to ``state``, ``depth`` plies below the root."""
        node_env = self.env.copy()
        node_env.board = state['board'].copy()
        node_env.player_turn = state['player_turn']
        node_env.goats_placed_count = int(self.env.NUM_GOATS - state['goats_to_place'][0])
        node_env.goats_captured_count = int(state['goats_captured'][0])
        # The observation does not carry the turn counter. Assuming the search
        # root is the live state of self.env, a node at `depth` is `depth`
        # plies later. Otherwise every node would reuse the root's count and
        # MAX_TURNS draws would be misjudged inside the tree.
        node_env.turn_count = self.env.turn_count + depth
        return node_env

    def search(self, state, depth):
        """Run one simulation from ``state`` and return its value for the parent.

        Leaves (unseen states, or ``depth >= max_depth``) are scored by the
        network. This presumes the value head scores the position for the
        player to move. The score is negated before being returned.
        """
        if depth >= self.max_depth:
            _, value = self.network.predict(self._get_matrix_state(state))
            return -value[0][0]
        state_key = self._get_state_key(state)
        if state_key not in self.P:
            policy, value = self.network.predict(self._get_matrix_state(state))
            self.P[state_key] = policy[0]
            return -value[0][0]

        node_env = self._env_at(state, depth)
        q_values = self.Q[state_key]
        visits = self.N_sa[state_key]
        priors = self.P[state_key]
        sqrt_total = np.sqrt(self.N_s[state_key] + self._UCB_EPS)
        best_ucb, action = -np.inf, -1
        for candidate in self._valid_actions(node_env):
            ucb = q_values[candidate] + self.c_puct * priors[candidate] * sqrt_total / (1 + visits[candidate])
            if ucb > best_ucb:
                best_ucb, action = ucb, candidate
        if action == -1:
            return 0

        next_state, _, done, info = node_env.step(action)
        if done:
            winner = info.get('winner', -1)
            if winner == -1:
                value = 0
            else:
                value = 1 if winner == state['player_turn'] else -1
        else:
            value = self.search(next_state, depth + 1)

        n = visits[action]
        q_values[action] = (n * q_values[action] + value) / (n + 1)
        visits[action] += 1
        self.N_s[state_key] += 1
        return -value

    def get_action(self, state, training=False):
        """Run ``simulations_per_move`` simulations from ``state`` and choose an action.

        - ``training=True``: sample in proportion to root visit counts
          (temperature 1).
        - Otherwise: return the most-visited action.
        - If the root was never visited: return a random valid action of
          ``self.env``, or 0 if there is none.
        """
        state_key = self._get_state_key(state)
        for _ in range(self.simulations_per_move):
            self.search(state, 0)
        visit_counts = self.N_sa[state_key]
        if np.sum(visit_counts) == 0:
            valid_actions = self._valid_actions(self.env)
            return np.random.choice(valid_actions) if valid_actions else 0
        if not training:
            return np.argmax(visit_counts)
        tau = 1.0
        action_probs = visit_counts ** (1 / tau)
        total = np.sum(action_probs)
        if total > 0:
            action_probs /= total
        else:
            valid_actions = self._valid_actions(self.env)
            action_probs = np.zeros(self.env.action_space.n)
            if valid_actions:
                action_probs[valid_actions] = 1 / len(valid_actions)
        return np.random.choice(self.env.action_space.n, p=action_probs)

    def _get_state_key(self, state):
        """Return a hashable key for the tree statistics.

        The key includes the placement and capture counters, not just the board
        and player. Two positions with the same board can differ in goats still
        to place, which changes the legal actions, and in goats captured.
        """
        return (
            state['board'].tobytes(),
            state['player_turn'],
            int(state['goats_to_place'][0]),
            int(state['goats_captured'][0]),
        )