Spaces:
Sleeping
Sleeping
File size: 14,442 Bytes
387a65a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 | import gymnasium as gym
from gymnasium import spaces
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Flatten, Conv2D
from tensorflow.keras.optimizers import Adam
import collections
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# AaduPulliEnv
class AaduPulliEnv(gym.Env):
    """Aadu Puli Aattam (goats and tigers) as a two-player Gymnasium environment.

    State:
        board: int32 array of 23 cells, 0 = empty, 1 = goat, 2 = tiger. Cell
            ``i`` is board position ``i + 1`` in the adjacency tables.
        player_turn: 0 for the goat player, 1 for the tiger player.

    Actions (``Discrete(23 + len(move map))``):
        ``0..22``  place a goat on that 0-based cell (goat player, placement phase).
        ``23..``   move a piece along ``self._move_action_map[action - 23]``, a
                   1-based ``(from_pos, to_pos)`` pair; tigers may also jump over
                   an adjacent goat along a straight board line to capture it.

    ``reset`` returns only the observation and ``step`` returns the legacy
    4-tuple ``(obs, reward, done, info)``.
    """
    metadata = {'render.modes': ['human', 'rgb_array']}

    def __init__(self):
        super(AaduPulliEnv, self).__init__()
        self.NUM_GOATS = 15
        self.NUM_TIGERS = 3
        self.TIGER_WIN_THRESHOLD = 10
        self.BOARD_POSITIONS = 23
        self.MAX_TURNS = 200
        self.adj = self._get_adjacency()
        self.jump_adj = self._get_jump_adjacency()
        # (from_pos, to_pos) -> position jumped over, one entry per jump in jump_adj.
        self._jump_mid = self._get_jump_midpoints()
        self.placement_actions = self.BOARD_POSITIONS
        self._move_action_map, self._move_action_lookup = self._create_move_maps()
        self.move_actions_count = len(self._move_action_map)
        total_actions = self.placement_actions + self.move_actions_count
        self.action_space = spaces.Discrete(total_actions)
        self.observation_space = spaces.Dict({
            "board": spaces.Box(low=0, high=2, shape=(self.BOARD_POSITIONS,), dtype=np.int32),
            "player_turn": spaces.Discrete(2),
            "goats_to_place": spaces.Box(low=0, high=self.NUM_GOATS, shape=(1,), dtype=np.int32),
            "goats_captured": spaces.Box(low=0, high=self.TIGER_WIN_THRESHOLD, shape=(1,), dtype=np.int32),
        })
        self.board_points = self._get_board_coordinates()
        self.reset()

    def _get_adjacency(self):
        """Return the single-step neighbours of every 1-based board position.

        List order matters: ``_create_move_maps`` assigns action indices in it.
        """
        return {
            1: [3, 4, 5, 6], 2: [3, 8], 3: [1, 4, 9, 2], 4: [1, 5, 10, 3], 5: [1, 6, 11, 4], 6: [1, 7, 12, 5], 7: [6, 13],
            8: [2, 9, 14], 9: [3, 10, 15, 8], 10: [4, 11, 16, 9], 11: [5, 12, 17, 10], 12: [6, 13, 18, 11],
            # 13 connects down to 19 (19 already lists 13). It previously listed 14, which
            # sits on the opposite edge of the board; the fix keeps the same list slot.
            13: [7, 19, 12],
            14: [8, 15], 15: [9, 16, 20, 14], 16: [10, 17, 21, 15], 17: [11, 18, 22, 16], 18: [12, 19, 23, 17], 19: [13, 18],
            20: [15, 21], 21: [16, 20, 22], 22: [17, 21, 23], 23: [18, 22]
        }

    def _get_jump_adjacency(self):
        """Return the landing squares of a straight two-step jump from each position."""
        return {
            1: [9, 10, 11, 12], 2: [4, 14], 3: [5, 15], 4: [2, 6, 16], 5: [3, 7, 17], 6: [4, 18], 7: [5, 19],
            8: [10], 9: [1, 11, 20], 10: [1, 8, 12, 21], 11: [1, 9, 13, 22], 12: [1, 10, 23], 13: [11],
            14: [2, 16], 15: [3, 17], 16: [4, 14, 18], 17: [5, 15, 19], 18: [6, 16], 19: [7, 17],
            20: [9, 22], 21: [10, 23], 22: [11, 20], 23: [12, 21]
        }

    def _get_board_lines(self):
        """Return the board's straight lines as ordered runs of connected positions.

        Four rows, the four diagonals from the apex (position 1), and the two
        rectangle sides. Every jump in ``_get_jump_adjacency`` is three
        consecutive points on one of these runs.
        """
        return [
            [2, 3, 4, 5, 6, 7], [8, 9, 10, 11, 12, 13], [14, 15, 16, 17, 18, 19], [20, 21, 22, 23],
            [1, 3, 9, 15, 20], [1, 4, 10, 16, 21], [1, 5, 11, 17, 22], [1, 6, 12, 18, 23],
            [2, 8, 14], [7, 13, 19],
        ]

    def _get_jump_midpoints(self):
        """Map each straight-line jump ``(from_pos, to_pos)`` to the position jumped over.

        The jumped square cannot be recovered as "the common neighbour" of the two
        ends. For example, 3 and 5 share both 4 (on their row) and the apex 1, and
        only 4 lies between them.
        """
        mids = {}
        for line in self._get_board_lines():
            for a, mid, b in zip(line, line[1:], line[2:]):
                mids[(a, b)] = mid
                mids[(b, a)] = mid
        return mids

    def _create_move_maps(self):
        """Enumerate every (from_pos, to_pos) step or jump as a move-action index.

        Returns ``(index -> move, move -> index)``. For each start position,
        adjacent moves come first, then jumps, in table order.
        """
        action_map, action_lookup, index = {}, {}, 0
        for start_pos in range(1, self.BOARD_POSITIONS + 1):
            for end_pos in self.adj.get(start_pos, []):
                move = (start_pos, end_pos)
                action_map[index] = move
                action_lookup[move] = index
                index += 1
            for end_pos in self.jump_adj.get(start_pos, []):
                move = (start_pos, end_pos)
                if move not in action_lookup:
                    action_map[index] = move
                    action_lookup[move] = index
                    index += 1
        return action_map, action_lookup

    def is_action_valid(self, action):
        """Check ``action`` against the current state without changing it.

        Returns ``(True, details)`` in two shapes:
        - ``details['type'] == 'place'``: includes ``to_idx``.
        - ``details['type'] == 'move'``: includes ``from_idx`` and ``to_idx``.
          Tiger moves also carry ``is_jump``, and ``mid_idx`` when capturing.

        Otherwise returns ``(False, {'error': message})``. All indices are 0-based
        board cells.
        """
        if not (0 <= action < self.action_space.n):
            return False, {'error': 'Action out of bounds.'}
        if action < self.placement_actions:
            to_idx = action
            if self.player_turn != 0 or self.goats_placed_count >= self.NUM_GOATS:
                return False, {'error': 'Cannot place piece now.'}
            if self.board[to_idx] != 0:
                return False, {'error': 'Destination square is not empty.'}
            return True, {'type': 'place', 'to_idx': to_idx}
        move_idx = action - self.placement_actions
        from_pos, to_pos = self._move_action_map[move_idx]
        from_idx, to_idx = from_pos - 1, to_pos - 1
        if self.board[to_idx] != 0:
            return False, {'error': 'Destination square is not empty.'}
        if self.player_turn == 0:
            # Goats may only step once all of them have been placed.
            if self.goats_placed_count < self.NUM_GOATS:
                return False, {'error': 'Goat is still in placement phase.'}
            if self.board[from_idx] != 1:
                return False, {'error': 'Player must move a goat.'}
            if to_pos not in self.adj.get(from_pos, []):
                return False, {'error': 'Goat can only move to adjacent squares.'}
            return True, {'type': 'move', 'from_idx': from_idx, 'to_idx': to_idx}
        if self.board[from_idx] != 2:
            return False, {'error': 'Player must move a tiger.'}
        if to_pos in self.adj.get(from_pos, []):
            return True, {'type': 'move', 'from_idx': from_idx, 'to_idx': to_idx, 'is_jump': False}
        mid_pos = self._jump_mid.get((from_pos, to_pos))
        if mid_pos is not None and to_pos in self.jump_adj.get(from_pos, []) and self.board[mid_pos - 1] == 1:
            return True, {'type': 'move', 'from_idx': from_idx, 'to_idx': to_idx, 'is_jump': True, 'mid_idx': mid_pos - 1}
        return False, {'error': 'Invalid tiger move.'}

    def _are_tigers_blocked(self):
        """Return True when no tiger has an empty neighbour or a capturing jump."""
        for t_idx in np.where(self.board == 2)[0]:
            t_pos = int(t_idx) + 1
            if any(self.board[dest_pos - 1] == 0 for dest_pos in self.adj.get(t_pos, [])):
                return False
            for dest_pos in self.jump_adj.get(t_pos, []):
                mid_pos = self._jump_mid.get((t_pos, dest_pos))
                if mid_pos is not None and self.board[dest_pos - 1] == 0 and self.board[mid_pos - 1] == 1:
                    return False
        return True

    def _get_current_observation(self):
        """Build the observation dict (a copy of the board plus turn and goat counters)."""
        return {
            "board": self.board.copy(),
            "player_turn": self.player_turn,
            "goats_to_place": np.array([self.NUM_GOATS - self.goats_placed_count], dtype=np.int32),
            "goats_captured": np.array([self.goats_captured_count], dtype=np.int32),
        }

    def reset(self):
        """Start a new game: tigers on positions 1, 4 and 5, no goats, goat player to move."""
        self.board = np.zeros(self.BOARD_POSITIONS, dtype=np.int32)
        self.board[0] = 2
        self.board[3] = 2
        self.board[4] = 2
        self.player_turn = 0
        self.goats_placed_count = 0
        self.goats_captured_count = 0
        self.turn_count = 0
        return self._get_current_observation()

    def step(self, action):
        """Apply ``action`` for the player to move, then pass the turn.

        Returns ``(observation, reward, done, info)``. ``info`` is the dict from
        ``is_action_valid`` (with an ``'error'`` entry for invalid actions). When
        the game ends it also gets ``'winner'``: 0 goats, 1 tigers, -1 draw at
        ``MAX_TURNS``.

        Rewards:
        - -1 for an invalid action; the board is left unchanged.
        - +5 for a capture.
        - +100 when the tigers are left with no legal move.
        - -100 once ``TIGER_WIN_THRESHOLD`` goats have been captured.
        End-of-game rewards replace the capture reward.

        The turn and turn counter advance even after an invalid action, so every
        call moves the game toward the ``MAX_TURNS`` draw.
        """
        is_valid, details = self.is_action_valid(action)
        reward, done, info = 0, False, details
        if not is_valid:
            reward = -1
        else:
            if details['type'] == 'place':
                self.board[details['to_idx']] = 1
                self.goats_placed_count += 1
            elif details['type'] == 'move':
                piece = self.board[details['from_idx']]
                self.board[details['from_idx']] = 0
                self.board[details['to_idx']] = piece
                if piece == 2 and details.get('is_jump'):
                    self.board[details['mid_idx']] = 0
                    self.goats_captured_count += 1
                    reward = 5
            goats_win = self._are_tigers_blocked()
            tigers_win = self.goats_captured_count >= self.TIGER_WIN_THRESHOLD
            if goats_win:
                done = True
                reward = 100
                info['winner'] = 0
            elif tigers_win:
                done = True
                reward = -100
                info['winner'] = 1
        self.player_turn = 1 - self.player_turn
        self.turn_count += 1
        if not done and self.turn_count >= self.MAX_TURNS:
            done = True
            info['winner'] = -1
        return self._get_current_observation(), reward, done, info

    def render(self, mode='rgb_array'):
        """Draw the board with matplotlib and return it as an ``(H, W, 3)`` uint8 RGB array.

        ``mode`` is accepted for API compatibility; an image is always returned.
        """
        fig, ax = plt.subplots(figsize=(8, 8))
        try:
            # Board lines: four horizontals, the two vertical sides, and the diagonals from the apex.
            for xs, ys in (([1, 23], [4, 4]), ([1, 23], [8, 8]), ([1, 23], [12, 12]), ([1, 23], [16, 16]),
                           ([1, 1], [4, 16]), ([23, 23], [4, 16]),
                           ([1, 12, 23], [4, 20, 4]), ([7, 12, 17], [4, 20, 4])):
                ax.plot(xs, ys, 'k')
            for i in range(self.BOARD_POSITIONS):
                pos = i + 1
                x, y = self.board_points[pos]
                if self.board[i] == 1:
                    ax.plot(x, y, 'o', ms=20, mfc='royalblue', mec='k', zorder=2)
                    ax.text(x, y, 'G', color='w', ha='center', va='center', fontsize=12, fontweight='bold')
                elif self.board[i] == 2:
                    ax.plot(x, y, 'o', ms=25, mfc='orangered', mec='k', zorder=2)
                    ax.text(x, y, 'T', color='w', ha='center', va='center', fontsize=12, fontweight='bold')
                else:
                    ax.plot(x, y, 'o', ms=20, mfc='lightgray', mec='k', zorder=1)
                    ax.text(x, y, str(pos), color='k', ha='center', va='center', fontsize=8)
            ax.set_xlim(0, 24)
            ax.set_ylim(0, 21)
            ax.set_aspect('equal')
            ax.axis('off')
            turn_txt = "Goat's Turn" if self.player_turn == 0 else "Tiger's Turn"
            title = f"Aadu Puli Aattam\n{turn_txt}\nGoats to Place: {self.NUM_GOATS-self.goats_placed_count} | Goats Captured: {self.goats_captured_count}"
            ax.set_title(title)
            fig.tight_layout()
            fig.canvas.draw()
            # The Agg buffer already has shape (H, W, 4). Copy the RGB channels so the
            # result no longer depends on the canvas once the figure is closed.
            img = np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
        finally:
            # Release the figure even if drawing fails, so repeated renders don't accumulate figures.
            plt.close(fig)
        return img

    def _get_board_coordinates(self):
        """Return (x, y) plotting coordinates for each 1-based position, used by ``render``."""
        return {1: (12, 20), 2: (1, 16), 3: (9.2, 16), 4: (10.7, 16), 5: (13.3, 16), 6: (14.8, 16), 7: (23, 16),
                8: (1, 12), 9: (6.5, 12), 10: (9.5, 12), 11: (14.5, 12), 12: (17.5, 12), 13: (23, 12),
                14: (1, 8), 15: (3.8, 8), 16: (8.3, 8), 17: (15.7, 8), 18: (20.3, 8), 19: (23, 8),
                20: (1, 4), 21: (7, 4), 22: (17, 4), 23: (23, 4)}

    def copy(self):
        """Return a fresh environment carrying this one's board, turn and counters."""
        new_env = AaduPulliEnv()
        new_env.board = self.board.copy()
        new_env.player_turn = self.player_turn
        new_env.goats_placed_count = self.goats_placed_count
        new_env.goats_captured_count = self.goats_captured_count
        new_env.turn_count = self.turn_count
        return new_env
# NeuralNetwork
class NeuralNetwork:
    """Two-headed convolutional policy/value network used by the AlphaZero agent.

    Input: a single ``state_shape`` tensor (default ``(23, 23, 4)``).
    Outputs: a softmax policy over ``action_space_size`` actions and a ``tanh``
    value in ``[-1, 1]``.
    """

    def __init__(self, action_space_size, learning_rate=0.001, state_shape=(23, 23, 4)):
        # state_shape is a keyword with the previous hard-coded value as its default,
        # so existing callers are unaffected.
        self.state_shape = tuple(state_shape)
        self.action_space_size = action_space_size
        self.learning_rate = learning_rate
        self.model = self._build_model()

    def _build_model(self):
        """Build and compile the Keras model: shared conv/dense trunk, policy and value heads."""
        input_layer = Input(shape=self.state_shape, name='matrix_input')
        x = Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu')(input_layer)
        x = Conv2D(filters=128, kernel_size=(3, 3), padding='same', activation='relu')(x)
        x = Flatten()(x)
        x = Dense(256, activation='relu')(x)
        policy_output = Dense(self.action_space_size, activation='softmax', name='policy_output')(x)
        value_output = Dense(1, activation='tanh', name='value_output')(x)
        model = Model(inputs=input_layer, outputs=[policy_output, value_output])
        model.compile(
            optimizer=Adam(self.learning_rate),
            loss={'policy_output': 'categorical_crossentropy', 'value_output': 'mean_squared_error'},
        )
        return model

    def predict(self, matrix_state):
        """Evaluate one state.

        Returns ``[policy, value]`` as numpy arrays of shape
        ``(1, action_space_size)`` and ``(1, 1)``.
        """
        batch = np.expand_dims(matrix_state, axis=0)
        # model.predict builds a batched input pipeline on every call, which dominates
        # the cost of single-sample inference inside tree search. Keras recommends
        # calling the model directly for inputs that fit in one batch.
        outputs = self.model(batch, training=False)
        return [np.asarray(out) for out in outputs]
# AlphaZeroAgent
class AlphaZeroAgent:
    """AlphaZero-style player: PUCT Monte Carlo tree search guided by a policy/value network.

    ``network.predict(matrix)`` must return ``(policy, value)``. ``policy[0]`` is
    indexed by action id over ``env.action_space.n`` actions. ``value[0][0]`` is
    treated as the value of the state for the player to move (it is negated when
    backed up to the parent).

    Search statistics are keyed by ``_get_state_key`` and are never cleared by
    this class, so they carry over between ``get_action`` calls.
    """

    # Keeps the PUCT exploration term non-zero on a node's first visit (N_s == 0),
    # so the network prior rather than action order picks the first child.
    _EPS = 1e-8

    def _zero_array_factory(self):
        return np.zeros(self.env.action_space.n)

    def __init__(self, env, network, simulations_per_move=50, max_depth=25, c_puct=1.0):
        self.env = env
        self.network = network
        self.simulations_per_move = simulations_per_move
        self.max_depth = max_depth
        self.c_puct = c_puct
        self.Q = collections.defaultdict(self._zero_array_factory)     # mean backed-up value per (state, action)
        self.N_sa = collections.defaultdict(self._zero_array_factory)  # visit count per (state, action)
        self.N_s = collections.defaultdict(int)                        # visit count per state
        self.P = {}                                                    # network prior per expanded state
        self._adj_matrix = None  # adjacency channel, built once on first use

    def _get_matrix_state(self, state):
        """Encode an observation as an ``(n, n, 4)`` float32 tensor, ``n = len(state['board'])``.

        Channels:
        - 0: goats (board value 1) on the diagonal.
        - 1: tigers (board value 2) on the diagonal.
        - 2: adjacency matrix built from ``env.adj``.
        - 3: ``state['player_turn']`` broadcast everywhere.
        """
        board = state['board']
        num_nodes = len(board)
        matrix = np.zeros((num_nodes, num_nodes, 4), dtype=np.float32)
        diag = np.arange(num_nodes)
        matrix[diag, diag, 0] = board == 1
        matrix[diag, diag, 1] = board == 2
        # The board graph does not change, so build the adjacency channel once and reuse it.
        adj_matrix = getattr(self, '_adj_matrix', None)
        if adj_matrix is None or adj_matrix.shape[0] != num_nodes:
            adj_matrix = np.zeros((num_nodes, num_nodes), dtype=np.float32)
            for start, end_list in self.env.adj.items():
                for end in end_list:
                    adj_matrix[start - 1, end - 1] = 1
            self._adj_matrix = adj_matrix
        matrix[:, :, 2] = adj_matrix
        matrix[:, :, 3] = state['player_turn']
        return matrix

    def search(self, state, depth):
        """Run one simulation from ``state``.

        Returns the outcome from the perspective of the player who moved into
        ``state`` (the parent). Values flip sign at every level because turns
        alternate.
        """
        if depth >= self.max_depth:
            _, value = self.network.predict(self._get_matrix_state(state))
            return -value[0][0]
        state_key = self._get_state_key(state)
        if state_key not in self.P:
            # Leaf: store the network prior and back up its value estimate.
            policy, value = self.network.predict(self._get_matrix_state(state))
            self.P[state_key] = policy[0]
            return -value[0][0]
        node_env = self.env.copy()
        node_env.board = state['board'].copy()
        node_env.player_turn = state['player_turn']
        node_env.goats_placed_count = self.env.NUM_GOATS - state['goats_to_place'][0]
        node_env.goats_captured_count = state['goats_captured'][0]
        valid_actions = [a for a in range(node_env.action_space.n) if node_env.is_action_valid(a)[0]]
        if not valid_actions:
            return 0
        valid = np.asarray(valid_actions)
        # PUCT: Q + c * P * sqrt(N_s) / (1 + N_sa); argmax keeps the first maximum on ties.
        exploration = self.c_puct * np.sqrt(self.N_s[state_key] + self._EPS)
        ucb = self.Q[state_key][valid] + exploration * self.P[state_key][valid] / (1 + self.N_sa[state_key][valid])
        action = int(valid[np.argmax(ucb)])
        next_state, _, done, info = node_env.step(action)
        if done:
            winner = info.get('winner', -1)
            # Terminal value for the player who just moved: win +1, loss -1, draw 0.
            value = 0 if winner == -1 else (1 if winner == state['player_turn'] else -1)
        else:
            value = self.search(next_state, depth + 1)
        visits = self.N_sa[state_key][action]
        self.Q[state_key][action] = (visits * self.Q[state_key][action] + value) / (visits + 1)
        self.N_sa[state_key][action] += 1
        self.N_s[state_key] += 1
        return -value

    def get_action(self, state, training=False):
        """Run ``simulations_per_move`` searches from ``state`` and choose an action.

        With ``training=True`` the action is sampled in proportion to root visit
        counts (temperature 1). Otherwise the most-visited action is returned.

        If the root got no visits, the result is a uniformly random action that
        is valid for ``self.env`` (or 0 if there is none).
        """
        state_key = self._get_state_key(state)
        for _ in range(self.simulations_per_move):
            self.search(state, 0)
        visit_counts = self.N_sa[state_key]
        n_actions = self.env.action_space.n
        if np.sum(visit_counts) == 0:
            valid_actions = [a for a in range(n_actions) if self.env.is_action_valid(a)[0]]
            return np.random.choice(valid_actions) if valid_actions else 0
        if not training:
            return np.argmax(visit_counts)
        tau = 1.0
        action_probs = visit_counts ** (1 / tau)
        total = np.sum(action_probs)
        if total > 0:
            action_probs = action_probs / total
        else:
            valid_actions = [a for a in range(n_actions) if self.env.is_action_valid(a)[0]]
            action_probs = np.zeros(n_actions)
            if valid_actions:
                action_probs[valid_actions] = 1 / len(valid_actions)
        return np.random.choice(n_actions, p=action_probs)

    def _get_state_key(self, state):
        """Return a hashable key identifying ``state`` in the search tables.

        The goat counters are part of the key as well as the board and side to
        move. The same board can be reached with different numbers of goats left
        to place (one more placed, one more captured), and those positions allow
        different actions.
        """
        return (state['board'].tobytes(), state['player_turn'],
                int(state['goats_to_place'][0]), int(state['goats_captured'][0]))
|