# Source: aadu-pulli-atam/game_logic.py (author: s23deepak)
# Last commit: 615925a, "Update game_logic.py" (verified)
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Flatten, Conv2D
from tensorflow.keras.optimizers import Adam
import collections
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# AaduPulliEnv
class AaduPulliEnv(gym.Env):
    """Two-player Aadu Puli Aattam (goats and tigers) environment.

    Board positions are numbered 1..23 in the adjacency tables and in
    ``board_points``; the ``board`` array is 0-indexed, so position ``p`` is
    stored at ``board[p - 1]``.  Cell values: 0 = empty, 1 = goat, 2 = tiger.
    Player 0 moves the goats, player 1 moves the tigers.

    Actions ``0 .. BOARD_POSITIONS - 1`` place a goat on that (0-indexed)
    square.  The remaining actions index the ``(start_pos, end_pos)`` moves
    built by ``_create_move_maps``.

    NOTE: ``reset`` returns only the observation and ``step`` returns the
    4-tuple ``(obs, reward, done, info)``.  This is the old gym calling
    convention, not the current gymnasium 2-/5-tuple one.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'render_modes': ['human', 'rgb_array'],  # key name used by gymnasium
    }

    def __init__(self):
        super(AaduPulliEnv, self).__init__()
        self.NUM_GOATS = 15
        self.NUM_TIGERS = 3
        self.TIGER_WIN_THRESHOLD = 10  # captures needed for the tigers to win
        self.BOARD_POSITIONS = 23
        self.MAX_TURNS = 200  # step limit after which the episode is a draw
        self.adj = self._get_adjacency()
        self.jump_adj = self._get_jump_adjacency()
        self.board_points = self._get_board_coordinates()
        # (from_pos, to_pos) -> position jumped over; computed once because it
        # is consulted on every tiger-move validity check.
        self.jump_mid = self._get_jump_midpoints()
        self.placement_actions = self.BOARD_POSITIONS
        self._move_action_map, self._move_action_lookup = self._create_move_maps()
        self.move_actions_count = len(self._move_action_map)
        total_actions = self.placement_actions + self.move_actions_count
        self.action_space = spaces.Discrete(total_actions)
        self.observation_space = spaces.Dict({
            "board": spaces.Box(low=0, high=2, shape=(self.BOARD_POSITIONS,), dtype=np.int32),
            "player_turn": spaces.Discrete(2),
            "goats_to_place": spaces.Box(low=0, high=self.NUM_GOATS, shape=(1,), dtype=np.int32),
            "goats_captured": spaces.Box(low=0, high=self.TIGER_WIN_THRESHOLD, shape=(1,), dtype=np.int32),
        })
        self.reset()

    def _get_adjacency(self):
        """Return the single-step neighbours of each 1-indexed position.

        List order matters: it fixes the move-action numbering in
        ``_create_move_maps``.
        """
        return {
            1: [3, 4, 5, 6], 2: [3, 8], 3: [1, 4, 9, 2], 4: [1, 5, 10, 3],
            5: [1, 6, 11, 4], 6: [1, 7, 12, 5], 7: [6, 13],
            8: [2, 9, 14], 9: [3, 10, 15, 8], 10: [4, 11, 16, 9],
            11: [5, 12, 17, 10], 12: [6, 13, 18, 11],
            # 13 sits on the right side (23, 12): its vertical neighbour is
            # 19 (which already lists 13), not 14 on the far-left side.
            13: [7, 19, 12],
            14: [8, 15], 15: [9, 16, 20, 14], 16: [10, 17, 21, 15],
            17: [11, 18, 22, 16], 18: [12, 19, 23, 17], 19: [13, 18],
            20: [15, 21], 21: [16, 20, 22], 22: [17, 21, 23], 23: [18, 22],
        }

    def _get_jump_adjacency(self):
        """Return the positions reachable by jumping over one neighbour in a straight line."""
        return {
            1: [9, 10, 11, 12], 2: [4, 14], 3: [5, 15], 4: [2, 6, 16],
            5: [3, 7, 17], 6: [4, 18], 7: [5, 19],
            8: [10], 9: [1, 11, 20], 10: [1, 8, 12, 21], 11: [1, 9, 13, 22],
            12: [1, 10, 23], 13: [11],
            14: [2, 16], 15: [3, 17], 16: [4, 14, 18], 17: [5, 15, 19],
            18: [6, 16], 19: [7, 17],
            20: [9, 22], 21: [10, 23], 22: [11, 20], 23: [12, 21],
        }

    def _get_jump_midpoints(self):
        """Map every ``(from_pos, to_pos)`` jump to the position it passes over.

        The jumped-over position is a neighbour of both endpoints.  Along the
        top row (3 <-> 5 and 4 <-> 6) the apex, position 1, is *also* adjacent
        to both endpoints, so "any common neighbour" is ambiguous.  The true
        midpoint is the candidate closest to the geometric midpoint of the two
        endpoints in ``board_points``.
        """
        mids = {}
        for from_pos, targets in self.jump_adj.items():
            fx, fy = self.board_points[from_pos]
            for to_pos in targets:
                common = set(self.adj.get(from_pos, [])) & set(self.adj.get(to_pos, []))
                if not common:
                    continue
                tx, ty = self.board_points[to_pos]
                cx, cy = (fx + tx) / 2, (fy + ty) / 2
                mids[(from_pos, to_pos)] = min(
                    sorted(common),
                    key=lambda p: (self.board_points[p][0] - cx) ** 2
                    + (self.board_points[p][1] - cy) ** 2,
                )
        return mids

    def _create_move_maps(self):
        """Number every (start_pos, end_pos) step or jump.

        Returns ``(index -> move, move -> index)``.  For each start position,
        steps come first and then jumps, each in table order.
        """
        action_map, action_lookup = {}, {}
        for start_pos in range(1, self.BOARD_POSITIONS + 1):
            ends = self.adj.get(start_pos, []) + self.jump_adj.get(start_pos, [])
            for end_pos in ends:
                move = (start_pos, end_pos)
                if move not in action_lookup:
                    index = len(action_map)
                    action_map[index] = move
                    action_lookup[move] = index
        return action_map, action_lookup

    def is_action_valid(self, action):
        """Check ``action`` against the current state.

        Returns ``(is_valid, details)``.  For an invalid action ``details`` has
        an ``'error'`` message.  For a valid action it has ``'type'``
        (``'place'`` or ``'move'``) and 0-indexed ``'to_idx'``.  Moves also
        carry ``'from_idx'``.  Tiger moves additionally carry ``'is_jump'``, and
        jumps carry ``'mid_idx'`` of the goat being captured.
        """
        if not (0 <= action < self.action_space.n):
            return False, {'error': 'Action out of bounds.'}

        if action < self.placement_actions:
            to_idx = action
            if self.player_turn != 0 or self.goats_placed_count >= self.NUM_GOATS:
                return False, {'error': 'Cannot place piece now.'}
            if self.board[to_idx] != 0:
                return False, {'error': 'Destination square is not empty.'}
            return True, {'type': 'place', 'to_idx': to_idx}

        from_pos, to_pos = self._move_action_map[action - self.placement_actions]
        from_idx, to_idx = from_pos - 1, to_pos - 1
        if self.board[to_idx] != 0:
            return False, {'error': 'Destination square is not empty.'}

        if self.player_turn == 0:
            if self.goats_placed_count < self.NUM_GOATS:
                return False, {'error': 'Goat is still in placement phase.'}
            if self.board[from_idx] != 1:
                return False, {'error': 'Player must move a goat.'}
            if to_pos not in self.adj.get(from_pos, []):
                return False, {'error': 'Goat can only move to adjacent squares.'}
            return True, {'type': 'move', 'from_idx': from_idx, 'to_idx': to_idx}

        if self.board[from_idx] != 2:
            return False, {'error': 'Player must move a tiger.'}
        if to_pos in self.adj.get(from_pos, []):
            return True, {'type': 'move', 'from_idx': from_idx, 'to_idx': to_idx, 'is_jump': False}
        # jump_mid only has keys for pairs listed in jump_adj.
        mid_pos = self.jump_mid.get((from_pos, to_pos))
        if mid_pos is not None and self.board[mid_pos - 1] == 1:
            return True, {'type': 'move', 'from_idx': from_idx, 'to_idx': to_idx,
                          'is_jump': True, 'mid_idx': mid_pos - 1}
        return False, {'error': 'Invalid tiger move.'}

    def _are_tigers_blocked(self):
        """Return True when no tiger can step to an empty square or capture a goat.

        Also returns True when there are no tigers on the board.
        """
        for t_idx in np.flatnonzero(self.board == 2):
            t_pos = int(t_idx) + 1
            if any(self.board[dest - 1] == 0 for dest in self.adj.get(t_pos, [])):
                return False
            for dest_pos in self.jump_adj.get(t_pos, []):
                mid_pos = self.jump_mid.get((t_pos, dest_pos))
                if (self.board[dest_pos - 1] == 0 and mid_pos is not None
                        and self.board[mid_pos - 1] == 1):
                    return False
        return True

    def _get_current_observation(self):
        """Build the observation dict; board and counters are copies of the current state."""
        return {
            "board": self.board.copy(),
            "player_turn": self.player_turn,
            "goats_to_place": np.array([self.NUM_GOATS - self.goats_placed_count], dtype=np.int32),
            "goats_captured": np.array([self.goats_captured_count], dtype=np.int32),
        }

    def reset(self, *, seed=None, options=None):
        """Reset to the opening position and return the observation.

        Tigers start on positions 1, 4 and 5 and goats move first.  ``seed`` and
        ``options`` are accepted so gymnasium-style callers do not fail.
        ``seed`` only seeds ``np_random`` via the base class, and ``options``
        is ignored.
        """
        super().reset(seed=seed)
        self.board = np.zeros(self.BOARD_POSITIONS, dtype=np.int32)
        self.board[[0, 3, 4]] = 2
        self.player_turn = 0
        self.goats_placed_count = 0
        self.goats_captured_count = 0
        self.turn_count = 0
        return self._get_current_observation()

    def step(self, action):
        """Apply ``action`` for the player to move.

        Returns ``(observation, reward, done, info)``.  ``info`` is the details
        dict from ``is_action_valid``.  When the episode ends it also has
        ``'winner'``: 0 = goats (tigers blocked), 1 = tigers (enough captures),
        -1 = draw on reaching ``MAX_TURNS``.

        An invalid action gets reward -1 and leaves the board and the player to
        move unchanged.  It still counts toward ``MAX_TURNS``, so an agent that
        keeps proposing illegal actions cannot stall an episode forever.

        NOTE(review): rewards mix perspectives.  A tiger capture gives +5, but a
        tiger win gives -100 and a goat win +100.  Confirm the intended sign
        convention before training on these rewards.
        """
        is_valid, details = self.is_action_valid(action)
        reward, done, info = 0, False, details
        if not is_valid:
            reward = -1
        else:
            if details['type'] == 'place':
                self.board[details['to_idx']] = 1
                self.goats_placed_count += 1
            else:
                piece = self.board[details['from_idx']]
                self.board[details['from_idx']] = 0
                self.board[details['to_idx']] = piece
                if piece == 2 and details.get('is_jump'):
                    self.board[details['mid_idx']] = 0
                    self.goats_captured_count += 1
                    reward = 5
            goats_win = self._are_tigers_blocked()
            tigers_win = self.goats_captured_count >= self.TIGER_WIN_THRESHOLD
            if goats_win:
                done, reward = True, 100
                info['winner'] = 0
            elif tigers_win:
                done, reward = True, -100
                info['winner'] = 1
            self.player_turn = 1 - self.player_turn
        self.turn_count += 1
        if not done and self.turn_count >= self.MAX_TURNS:
            done = True
            info['winner'] = -1
        return self._get_current_observation(), reward, done, info

    def render(self, mode='rgb_array'):
        """Draw the board and return it as an ``(H, W, 3)`` uint8 RGB array.

        ``mode`` is accepted for API compatibility; an RGB array is always returned.
        """
        fig, ax = plt.subplots(figsize=(8, 8))
        try:
            for y in (4, 8, 12, 16):
                ax.plot([1, 23], [y, y], 'k')
            # The rectangle sides only join rows 2-4 (2-8-14 and 7-13-19).  The
            # movement rules have no 14-20 or 19-23 edge, so the sides are not
            # drawn down to the bottom row.
            ax.plot([1, 1], [8, 16], 'k')
            ax.plot([23, 23], [8, 16], 'k')
            ax.plot([1, 12, 23], [4, 20, 4], 'k')
            ax.plot([7, 12, 17], [4, 20, 4], 'k')
            for i in range(self.BOARD_POSITIONS):
                pos = i + 1
                x, y = self.board_points[pos]
                if self.board[i] == 1:
                    ax.plot(x, y, 'o', ms=20, mfc='royalblue', mec='k', zorder=2)
                    ax.text(x, y, 'G', color='w', ha='center', va='center', fontsize=12, fontweight='bold')
                elif self.board[i] == 2:
                    ax.plot(x, y, 'o', ms=25, mfc='orangered', mec='k', zorder=2)
                    ax.text(x, y, 'T', color='w', ha='center', va='center', fontsize=12, fontweight='bold')
                else:
                    ax.plot(x, y, 'o', ms=20, mfc='lightgray', mec='k', zorder=1)
                    ax.text(x, y, str(pos), color='k', ha='center', va='center', fontsize=8)
            ax.set_xlim(0, 24)
            ax.set_ylim(0, 21)
            ax.set_aspect('equal')
            ax.axis('off')
            turn_txt = "Goat's Turn" if self.player_turn == 0 else "Tiger's Turn"
            title = f"Aadu Puli Aattam\n{turn_txt}\nGoats to Place: {self.NUM_GOATS-self.goats_placed_count} | Goats Captured: {self.goats_captured_count}"
            ax.set_title(title)
            fig.tight_layout()
            fig.canvas.draw()
            # buffer_rgba() already has shape (H, W, 4).  Copy so the result
            # does not alias the canvas buffer of a figure closed below.
            return np.asarray(fig.canvas.buffer_rgba())[:, :, :3].copy()
        finally:
            plt.close(fig)

    def _get_board_coordinates(self):
        """Return the (x, y) drawing coordinates of each 1-indexed position."""
        return {
            1: (12, 20),
            2: (1, 16), 3: (9.2, 16), 4: (10.7, 16), 5: (13.3, 16), 6: (14.8, 16), 7: (23, 16),
            8: (1, 12), 9: (6.5, 12), 10: (9.5, 12), 11: (14.5, 12), 12: (17.5, 12), 13: (23, 12),
            14: (1, 8), 15: (3.8, 8), 16: (8.3, 8), 17: (15.7, 8), 18: (20.3, 8), 19: (23, 8),
            20: (1, 4), 21: (7, 4), 22: (17, 4), 23: (23, 4),
        }

    def copy(self):
        """Return a new env with a copied board and the same counters and turn."""
        new_env = AaduPulliEnv()
        new_env.board = self.board.copy()
        new_env.player_turn = self.player_turn
        new_env.goats_placed_count = self.goats_placed_count
        new_env.goats_captured_count = self.goats_captured_count
        new_env.turn_count = self.turn_count
        return new_env
# NeuralNetwork
class NeuralNetwork:
    """Two-headed policy/value network over a matrix-encoded board state.

    Args:
        action_space_size: number of policy outputs.
        learning_rate: Adam learning rate.
        state_shape: input shape ``(height, width, channels)``.  The default
            ``(23, 23, 4)`` is the previously hard-coded value.
    """

    def __init__(self, action_space_size, learning_rate=0.001, state_shape=(23, 23, 4)):
        self.state_shape = tuple(state_shape)
        self.action_space_size = action_space_size
        self.learning_rate = learning_rate
        self.model = self._build_model()

    def _build_model(self):
        """Build and compile two conv layers + dense trunk with softmax policy and tanh value heads."""
        input_layer = Input(shape=self.state_shape, name='matrix_input')
        x = Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu')(input_layer)
        x = Conv2D(filters=128, kernel_size=(3, 3), padding='same', activation='relu')(x)
        x = Flatten()(x)
        x = Dense(256, activation='relu')(x)
        policy_output = Dense(self.action_space_size, activation='softmax', name='policy_output')(x)
        value_output = Dense(1, activation='tanh', name='value_output')(x)
        model = Model(inputs=input_layer, outputs=[policy_output, value_output])
        model.compile(
            optimizer=Adam(self.learning_rate),
            loss={'policy_output': 'categorical_crossentropy', 'value_output': 'mean_squared_error'},
        )
        return model

    def predict(self, matrix_state):
        """Evaluate one unbatched state.

        Returns ``[policy, value]`` as numpy arrays of shape
        ``(1, action_space_size)`` and ``(1, 1)``, the same structure
        ``Model.predict`` returned.
        """
        batch = np.expand_dims(matrix_state, axis=0)
        # Calling the model directly skips Model.predict's batching/loop setup,
        # which dominates the cost for a single sample (Keras recommends
        # model(x) for inputs that fit in one batch).
        policy, value = self.model(batch, training=False)
        return [np.asarray(policy), np.asarray(value)]
# AlphaZeroAgent
class AlphaZeroAgent:
    """Monte-Carlo tree search guided by a policy/value network.

    Search statistics are keyed by ``_get_state_key`` and persist across calls:

    * ``Q[s][a]``: running mean of backed-up values for action ``a`` at ``s``,
      from the perspective of the player to move at ``s``.
    * ``N_sa[s][a]``: visit count of action ``a`` at ``s``.
    * ``N_s[s]``: total visits of ``s``.
    * ``P[s]``: network policy cached when ``s`` was first expanded.
    """

    def _zero_array_factory(self):
        # A bound method rather than a lambda, presumably so the defaultdicts
        # below stay picklable.
        return np.zeros(self.env.action_space.n)

    def __init__(self, env, network, simulations_per_move=50, max_depth=25, c_puct=1.0):
        self.env = env
        self.network = network
        self.simulations_per_move = simulations_per_move
        self.max_depth = max_depth
        self.c_puct = c_puct
        self.Q = collections.defaultdict(self._zero_array_factory)
        self.N_sa = collections.defaultdict(self._zero_array_factory)
        self.N_s = collections.defaultdict(int)
        self.P = {}

    def _get_matrix_state(self, state):
        """Encode ``state`` as an ``(n, n, 4)`` float32 tensor, n = board length.

        Channel 0: goats on the diagonal.  Channel 1: tigers on the diagonal.
        Channel 2: the env's adjacency matrix.  Channel 3: ``player_turn``
        broadcast everywhere.

        NOTE(review): goats_to_place and goats_captured are not encoded, so the
        network cannot tell the placement phase from the movement phase for
        the same board.
        """
        board = state['board']
        num_nodes = len(board)
        matrix = np.zeros((num_nodes, num_nodes, 4), dtype=np.float32)
        np.fill_diagonal(matrix[:, :, 0], board == 1)
        np.fill_diagonal(matrix[:, :, 1], board == 2)
        for start, ends in self.env.adj.items():
            for end in ends:
                matrix[start - 1, end - 1, 2] = 1
        matrix[:, :, 3] = state['player_turn']
        return matrix

    def _make_node_env(self, state):
        """Return a copy of ``self.env`` positioned at ``state``.

        ``turn_count`` is inherited from ``self.env``, because observations do
        not carry it.
        """
        node_env = self.env.copy()
        node_env.board = state['board'].copy()
        node_env.player_turn = state['player_turn']
        node_env.goats_placed_count = self.env.NUM_GOATS - state['goats_to_place'][0]
        node_env.goats_captured_count = state['goats_captured'][0]
        return node_env

    @staticmethod
    def _valid_actions(env):
        """List every action index that ``env`` currently accepts."""
        return [a for a in range(env.action_space.n) if env.is_action_valid(a)[0]]

    def search(self, state, depth):
        """Run one MCTS simulation from ``state`` and back up its value.

        Returns the value from the perspective of the player who moved *into*
        ``state`` (i.e. negated for the parent).  Returns 0 if the player to
        move has no valid action.
        """
        if depth >= self.max_depth:
            _, value = self.network.predict(self._get_matrix_state(state))
            return -value[0][0]

        state_key = self._get_state_key(state)
        if state_key not in self.P:
            # First visit: expand with the network prior and back up its value.
            policy, value = self.network.predict(self._get_matrix_state(state))
            self.P[state_key] = policy[0]
            return -value[0][0]

        node_env = self._make_node_env(state)
        valid_actions = self._valid_actions(node_env)
        if not valid_actions:
            return 0

        q, n_sa, prior = self.Q[state_key], self.N_sa[state_key], self.P[state_key]
        # The epsilon keeps the exploration term non-zero on the first
        # selection from a freshly expanded node (N_s == 0).  Without it the
        # prior is ignored and the lowest-numbered valid action always wins.
        sqrt_visits = np.sqrt(self.N_s[state_key] + 1e-8)
        action = max(
            valid_actions,
            key=lambda a: q[a] + self.c_puct * prior[a] * sqrt_visits / (1 + n_sa[a]),
        )

        next_state, _, done, info = node_env.step(action)
        if done:
            winner = info.get('winner', -1)
            value = 0 if winner == -1 else (1 if winner == state['player_turn'] else -1)
        else:
            value = self.search(next_state, depth + 1)

        # Re-read counts after recursion: a cycle may have updated this node.
        visits = n_sa[action]
        q[action] = (visits * q[action] + value) / (visits + 1)
        n_sa[action] += 1
        self.N_s[state_key] += 1
        return -value

    def get_action(self, state, training=False):
        """Run the configured number of simulations and choose a move for ``state``.

        With ``training=True`` the move is sampled in proportion to the root
        visit counts (temperature 1).  Otherwise the most-visited action is
        returned.  If the root was never visited past expansion, a random valid
        action from ``self.env`` is returned, or 0 if there is none.
        """
        state_key = self._get_state_key(state)
        for _ in range(self.simulations_per_move):
            self.search(state, 0)

        visit_counts = self.N_sa[state_key]
        total_visits = np.sum(visit_counts)
        if total_visits == 0:
            valid_actions = self._valid_actions(self.env)
            return np.random.choice(valid_actions) if valid_actions else 0
        if training:
            return np.random.choice(self.env.action_space.n, p=visit_counts / total_visits)
        return np.argmax(visit_counts)

    def _get_state_key(self, state):
        """Return a hashable key identifying a full game state.

        ``goats_to_place`` and ``goats_captured`` are included because the same
        board can occur with different counts.  For example, the same goats can
        be on the board mid-placement or in the movement phase after captures,
        and those states have different legal actions and outcomes.
        """
        return (
            state['board'].tobytes(),
            int(state['player_turn']),
            int(state['goats_to_place'][0]),
            int(state['goats_captured'][0]),
        )