s23deepak committed
Commit 2cf29cb · verified · 1 Parent(s): 4a38cd3

Upload 2 files

Files changed (2):
  1. game_logic.py +227 -0
  2. requirements.txt +6 -0
game_logic.py ADDED
@@ -0,0 +1,227 @@
# /game_logic.py
import gymnasium as gym
from gymnasium import spaces  # fixed: was "from gym import spaces", mixing gym and gymnasium
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Flatten, Conv2D
from tensorflow.keras.optimizers import Adam
import collections
import matplotlib
matplotlib.use('Agg')  # headless backend so render() can run without a display
import matplotlib.pyplot as plt


# AaduPulliEnv: the Aadu Puli Aattam (Tamil goats-and-tigers) board as a
# gym-style environment. Note: reset() and step() keep the classic 4-tuple Gym
# API rather than gymnasium's 5-tuple convention; the agent below consumes it
# directly.
class AaduPulliEnv(gym.Env):
    metadata = {'render_modes': ['human', 'rgb_array']}

    def __init__(self):
        super(AaduPulliEnv, self).__init__()
        self.NUM_GOATS = 15
        self.NUM_TIGERS = 3
        self.TIGER_WIN_THRESHOLD = 10
        self.BOARD_POSITIONS = 23
        self.MAX_TURNS = 200
        self.adj = self._get_adjacency()
        self.jump_adj = self._get_jump_adjacency()
        # Actions 0..22 place a goat on positions 1..23; the remaining actions
        # index the precomputed (from, to) move map.
        self.placement_actions = self.BOARD_POSITIONS
        self._move_action_map, self._move_action_lookup = self._create_move_maps()
        self.move_actions_count = len(self._move_action_map)
        total_actions = self.placement_actions + self.move_actions_count
        self.action_space = spaces.Discrete(total_actions)
        self.observation_space = spaces.Dict({
            "board": spaces.Box(low=0, high=2, shape=(self.BOARD_POSITIONS,), dtype=np.int32),
            "player_turn": spaces.Discrete(2),
            "goats_to_place": spaces.Box(low=0, high=self.NUM_GOATS, shape=(1,), dtype=np.int32),
            "goats_captured": spaces.Box(low=0, high=self.TIGER_WIN_THRESHOLD, shape=(1,), dtype=np.int32),
        })
        self.board_points = self._get_board_coordinates()
        self.reset()

    def _get_adjacency(self):
        return {
            1: [3, 4, 5, 6], 2: [3, 8], 3: [1, 4, 9, 2], 4: [1, 5, 10, 3], 5: [1, 6, 11, 4], 6: [1, 7, 12, 5], 7: [6, 13],
            8: [2, 9, 14], 9: [3, 10, 15, 8], 10: [4, 11, 16, 9], 11: [5, 12, 17, 10], 12: [6, 13, 18, 11], 13: [7, 14, 12],
            14: [8, 15], 15: [9, 16, 20, 14], 16: [10, 17, 21, 15], 17: [11, 18, 22, 16], 18: [12, 19, 23, 17], 19: [13, 18],
            20: [15, 21], 21: [16, 20, 22], 22: [17, 21, 23], 23: [18, 22]
        }

    def _get_jump_adjacency(self):
        return {
            1: [9, 10, 11, 12], 2: [4, 14], 3: [5, 15], 4: [2, 6, 16], 5: [3, 7, 17], 6: [4, 18], 7: [5, 19],
            8: [10], 9: [1, 11, 20], 10: [1, 8, 12, 21], 11: [1, 9, 13, 22], 12: [1, 10, 23], 13: [11],
            14: [2, 16], 15: [3, 17], 16: [4, 14, 18], 17: [5, 15, 19], 18: [6, 16], 19: [7, 17],
            20: [9, 22], 21: [10, 23], 22: [11, 20], 23: [12, 21]
        }

    def _create_move_maps(self):
        action_map, action_lookup, index = {}, {}, 0
        for start_pos in range(1, self.BOARD_POSITIONS + 1):
            for end_pos in self.adj.get(start_pos, []):
                move = (start_pos, end_pos)
                action_map[index] = move
                action_lookup[move] = index
                index += 1
            for end_pos in self.jump_adj.get(start_pos, []):
                move = (start_pos, end_pos)
                if move not in action_lookup:
                    action_map[index] = move
                    action_lookup[move] = index
                    index += 1
        return action_map, action_lookup
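
    # Hypothetical helper (an editorial sketch, not in the original commit):
    # decodes a flat action index into a readable description, illustrating the
    # encoding built above. Indices 0..22 place a goat on positions 1..23, and
    # the remaining indices map through _move_action_map to (from, to) pairs.
    def describe_action(self, action):
        if action < self.placement_actions:
            return f"place goat on position {action + 1}"
        from_pos, to_pos = self._move_action_map[action - self.placement_actions]
        return f"move piece from {from_pos} to {to_pos}"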

    def is_action_valid(self, action):
        if not (0 <= action < self.action_space.n):
            return False, {'error': 'Action out of bounds.'}
        if action < self.placement_actions:
            to_idx = action
            if self.player_turn != 0 or self.goats_placed_count >= self.NUM_GOATS:
                return False, {'error': 'Cannot place piece now.'}
            if self.board[to_idx] != 0:
                return False, {'error': 'Destination square is not empty.'}
            return True, {'type': 'place', 'to_idx': to_idx}
        move_idx = action - self.placement_actions
        from_pos, to_pos = self._move_action_map[move_idx]
        from_idx, to_idx = from_pos - 1, to_pos - 1
        if self.board[to_idx] != 0:
            return False, {'error': 'Destination square is not empty.'}
        if self.player_turn == 0:
            if self.goats_placed_count < self.NUM_GOATS:
                return False, {'error': 'Goat is still in placement phase.'}
            if self.board[from_idx] != 1:
                return False, {'error': 'Player must move a goat.'}
            if to_pos not in self.adj.get(from_pos, []):
                return False, {'error': 'Goat can only move to adjacent squares.'}
            return True, {'type': 'move', 'from_idx': from_idx, 'to_idx': to_idx}
        else:
            if self.board[from_idx] != 2:
                return False, {'error': 'Player must move a tiger.'}
            if to_pos in self.adj.get(from_pos, []):
                return True, {'type': 'move', 'from_idx': from_idx, 'to_idx': to_idx, 'is_jump': False}
            if to_pos in self.jump_adj.get(from_pos, []):
                # A jump is legal only if the square between from and to holds a goat.
                mid_pos_set = set(self.adj.get(from_pos, [])) & set(self.adj.get(to_pos, []))
                if mid_pos_set:
                    mid_pos = mid_pos_set.pop()
                    if self.board[mid_pos - 1] == 1:
                        return True, {'type': 'move', 'from_idx': from_idx, 'to_idx': to_idx, 'is_jump': True, 'mid_idx': mid_pos - 1}
            return False, {'error': 'Invalid tiger move.'}

    def _are_tigers_blocked(self):
        # Goats win when no tiger has a legal step or capture left.
        for t_idx in np.where(self.board == 2)[0]:
            t_pos = t_idx + 1
            for dest_pos in self.adj.get(t_pos, []):
                if self.board[dest_pos - 1] == 0:
                    return False
            for dest_pos in self.jump_adj.get(t_pos, []):
                if self.board[dest_pos - 1] == 0:
                    mid_pos_set = set(self.adj.get(t_pos, [])) & set(self.adj.get(dest_pos, []))
                    if mid_pos_set and self.board[mid_pos_set.pop() - 1] == 1:
                        return False
        return True

    def _get_current_observation(self):
        return {
            "board": self.board.copy(),
            "player_turn": self.player_turn,
            "goats_to_place": np.array([self.NUM_GOATS - self.goats_placed_count], dtype=np.int32),
            "goats_captured": np.array([self.goats_captured_count], dtype=np.int32),
        }

    def reset(self):
        self.board = np.zeros(self.BOARD_POSITIONS, dtype=np.int32)
        # Tigers start on positions 1, 4 and 5; goats enter during the placement phase.
        self.board[0] = 2
        self.board[3] = 2
        self.board[4] = 2
        self.player_turn = 0
        self.goats_placed_count = 0
        self.goats_captured_count = 0
        self.turn_count = 0
        return self._get_current_observation()

    def step(self, action):
        is_valid, details = self.is_action_valid(action)
        reward, done, info = 0, False, details
        if not is_valid:
            reward = -1
        else:
            if details['type'] == 'place':
                self.board[details['to_idx']] = 1
                self.goats_placed_count += 1
            elif details['type'] == 'move':
                p = self.board[details['from_idx']]
                self.board[details['from_idx']] = 0
                self.board[details['to_idx']] = p
                if p == 2 and details.get('is_jump'):
                    self.board[details['mid_idx']] = 0
                    self.goats_captured_count += 1
                    reward = 5
            g_win = self._are_tigers_blocked()
            t_win = self.goats_captured_count >= self.TIGER_WIN_THRESHOLD
            if g_win:
                done = True
                reward = 100
                info['winner'] = 0
            elif t_win:
                done = True
                reward = -100
                info['winner'] = 1
        self.player_turn = 1 - self.player_turn
        self.turn_count += 1
        if not done and self.turn_count >= self.MAX_TURNS:
            done = True
            info['winner'] = -1  # draw by turn limit
        return self._get_current_observation(), reward, done, info

    def render(self, mode='rgb_array'):
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.clear()
        # Board lines: four horizontals, two verticals and two triangle sides.
        ax.plot([1, 23], [4, 4], 'k'); ax.plot([1, 23], [8, 8], 'k'); ax.plot([1, 23], [12, 12], 'k'); ax.plot([1, 23], [16, 16], 'k')
        ax.plot([1, 1], [4, 16], 'k'); ax.plot([23, 23], [4, 16], 'k')
        ax.plot([1, 12, 23], [4, 20, 4], 'k'); ax.plot([7, 12, 17], [4, 20, 4], 'k')
        for i in range(self.BOARD_POSITIONS):
            p, (x, y) = i + 1, self.board_points[i + 1]
            if self.board[i] == 1:
                ax.plot(x, y, 'o', ms=20, mfc='royalblue', mec='k', zorder=2)
                ax.text(x, y, 'G', color='w', ha='center', va='center', fontsize=12, fontweight='bold')
            elif self.board[i] == 2:
                ax.plot(x, y, 'o', ms=25, mfc='orangered', mec='k', zorder=2)
                ax.text(x, y, 'T', color='w', ha='center', va='center', fontsize=12, fontweight='bold')
            else:
                ax.plot(x, y, 'o', ms=20, mfc='lightgray', mec='k', zorder=1)
                ax.text(x, y, str(p), color='k', ha='center', va='center', fontsize=8)
        ax.set_xlim(0, 24); ax.set_ylim(0, 21); ax.set_aspect('equal'); ax.axis('off')
        turn_txt = "Goat's Turn" if self.player_turn == 0 else "Tiger's Turn"
        title = f"Aadu Puli Aattam\n{turn_txt}\nGoats to Place: {self.NUM_GOATS - self.goats_placed_count} | Goats Captured: {self.goats_captured_count}"
        ax.set_title(title)
        plt.tight_layout()
        fig.canvas.draw()
        # Grab the RGBA framebuffer and drop the alpha channel.
        img_buf = fig.canvas.buffer_rgba()
        img = np.frombuffer(img_buf, dtype=np.uint8).reshape(fig.canvas.get_width_height()[::-1] + (4,))
        img = img[:, :, :3]
        plt.close(fig)
        return img

    def _get_board_coordinates(self):
        return {
            1: (12, 20),
            2: (1, 16), 3: (9.2, 16), 4: (10.7, 16), 5: (13.3, 16), 6: (14.8, 16), 7: (23, 16),
            8: (1, 12), 9: (6.5, 12), 10: (9.5, 12), 11: (14.5, 12), 12: (17.5, 12), 13: (23, 12),
            14: (1, 8), 15: (3.8, 8), 16: (8.3, 8), 17: (15.7, 8), 18: (20.3, 8), 19: (23, 8),
            20: (1, 4), 21: (7, 4), 22: (17, 4), 23: (23, 4)
        }

    def copy(self):
        new_env = AaduPulliEnv()
        new_env.board = self.board.copy()
        new_env.player_turn = self.player_turn
        new_env.goats_placed_count = self.goats_placed_count
        new_env.goats_captured_count = self.goats_captured_count
        new_env.turn_count = self.turn_count
        return new_env


# NeuralNetwork: a small shared-trunk policy/value network over the 23x23x4
# matrix encoding produced by AlphaZeroAgent._get_matrix_state.
class NeuralNetwork:
    def __init__(self, action_space_size, learning_rate=0.001):
        self.state_shape = (23, 23, 4)
        self.action_space_size = action_space_size
        self.learning_rate = learning_rate
        self.model = self._build_model()

    def _build_model(self):
        input_layer = Input(shape=self.state_shape, name='matrix_input')
        x = Conv2D(filters=64, kernel_size=(3, 3), padding='same', activation='relu')(input_layer)
        x = Conv2D(filters=128, kernel_size=(3, 3), padding='same', activation='relu')(x)
        x = Flatten()(x)
        x = Dense(256, activation='relu')(x)
        policy_output = Dense(self.action_space_size, activation='softmax', name='policy_output')(x)
        value_output = Dense(1, activation='tanh', name='value_output')(x)
        model = Model(inputs=input_layer, outputs=[policy_output, value_output])
        model.compile(optimizer=Adam(self.learning_rate),
                      loss={'policy_output': 'categorical_crossentropy', 'value_output': 'mean_squared_error'})
        return model

    def predict(self, matrix_state):
        # Adds a batch dimension; returns ([1, n_actions] policy, [1, 1] value).
        return self.model.predict(np.expand_dims(matrix_state, axis=0), verbose=0)
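
# Training-usage sketch (an editorial assumption, not part of the commit): the
# two heads would typically be fit jointly on MCTS visit distributions and
# final game outcomes, e.g.
#     net.model.fit(states, {'policy_output': pi_targets, 'value_output': z_targets})
# with states shaped [N, 23, 23, 4], pi_targets [N, n_actions], z_targets [N, 1].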


# AlphaZeroAgent: PUCT-style tree search guided by the network's policy and value heads.
class AlphaZeroAgent:
    def _zero_array_factory(self):
        return np.zeros(self.env.action_space.n)

    def __init__(self, env, network, simulations_per_move=50, max_depth=25, c_puct=1.0):
        self.env = env
        self.network = network
        self.simulations_per_move = simulations_per_move
        self.max_depth = max_depth
        self.c_puct = c_puct
        # Q: mean action values; N_sa: per-action visit counts; N_s: state visit
        # counts; P: cached network policy priors.
        self.Q = collections.defaultdict(self._zero_array_factory)
        self.N_sa = collections.defaultdict(self._zero_array_factory)
        self.N_s = collections.defaultdict(int)
        self.P = {}

    def _get_matrix_state(self, state):
        # Channels: 0 = goats, 1 = tigers (both on the diagonal), 2 = board
        # adjacency, 3 = side to move.
        board = state['board']
        num_nodes = len(board)
        matrix = np.zeros((num_nodes, num_nodes, 4), dtype=np.float32)
        np.fill_diagonal(matrix[:, :, 0], board == 1)
        np.fill_diagonal(matrix[:, :, 1], board == 2)
        adj_matrix = np.zeros((num_nodes, num_nodes), dtype=np.float32)
        for start, end_list in self.env.adj.items():
            for end in end_list:
                adj_matrix[start - 1, end - 1] = 1
        matrix[:, :, 2] = adj_matrix
        matrix[:, :, 3] = state['player_turn']
        return matrix

    def search(self, state, depth):
        if depth >= self.max_depth:
            _, value = self.network.predict(self._get_matrix_state(state))
            return -value[0][0]
        state_key = self._get_state_key(state)
        if state_key not in self.P:
            # Leaf node: cache the policy prior and back up the value estimate.
            policy, value = self.network.predict(self._get_matrix_state(state))
            self.P[state_key] = policy[0]
            return -value[0][0]
        # Rebuild an env matching this node's state so moves can be simulated.
        node_env = self.env.copy()
        node_env.board = state['board'].copy()
        node_env.player_turn = state['player_turn']
        node_env.goats_placed_count = self.env.NUM_GOATS - state['goats_to_place'][0]
        node_env.goats_captured_count = state['goats_captured'][0]
        best_ucb = -np.inf
        best_action = -1
        valid_actions = [a for a in range(node_env.action_space.n) if node_env.is_action_valid(a)[0]]
        for action in valid_actions:
            q_value = self.Q[state_key][action]
            ucb = q_value + self.c_puct * self.P[state_key][action] * np.sqrt(self.N_s[state_key]) / (1 + self.N_sa[state_key][action])
            if ucb > best_ucb:
                best_ucb = ucb
                best_action = action
        if best_action == -1:
            return 0
        action = best_action
        next_state, _, done, info = node_env.step(action)
        if done:
            winner = info.get('winner', -1)
            value = 0
            if winner != -1:
                value = 1 if winner == state['player_turn'] else -1
        else:
            value = self.search(next_state, depth + 1)
        # Incremental mean update of Q, then bump the visit counts.
        self.Q[state_key][action] = (self.N_sa[state_key][action] * self.Q[state_key][action] + value) / (self.N_sa[state_key][action] + 1)
        self.N_sa[state_key][action] += 1
        self.N_s[state_key] += 1
        return -value

    def get_action(self, state, training=False):
        state_key = self._get_state_key(state)
        for _ in range(self.simulations_per_move):
            self.search(state, 0)
        visit_counts = self.N_sa[state_key]
        if np.sum(visit_counts) == 0:
            valid_actions = [a for a in range(self.env.action_space.n) if self.env.is_action_valid(a)[0]]
            return np.random.choice(valid_actions) if valid_actions else 0
        if training:
            # Sample proportionally to visit counts (temperature tau = 1) for self-play exploration.
            tau = 1.0
            action_probs = visit_counts ** (1 / tau)
            if np.sum(action_probs) > 0:
                action_probs /= np.sum(action_probs)
            else:
                valid_actions = [a for a in range(self.env.action_space.n) if self.env.is_action_valid(a)[0]]
                action_probs = np.zeros(self.env.action_space.n)
                if valid_actions:
                    action_probs[valid_actions] = 1 / len(valid_actions)
            action = np.random.choice(self.env.action_space.n, p=action_probs)
        else:
            action = np.argmax(visit_counts)
        return action

    def _get_state_key(self, state):
        return (state['board'].tobytes(), state['player_turn'])
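
For reference, a minimal smoke test tying the three classes together might look like the sketch below. It is an editorial illustration, not part of this commit; the file name demo.py and the tiny simulation budget are assumptions.

    # demo.py (hypothetical)
    from game_logic import AaduPulliEnv, NeuralNetwork, AlphaZeroAgent

    env = AaduPulliEnv()
    net = NeuralNetwork(action_space_size=env.action_space.n)
    agent = AlphaZeroAgent(env, net, simulations_per_move=5)  # tiny budget, just to exercise the code

    state = env.reset()
    done, info = False, {}
    while not done:
        action = agent.get_action(state, training=True)
        state, reward, done, info = env.step(action)
    print("winner:", info.get('winner'))  # 0 = goats, 1 = tigers, -1 = draw by turn limit
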
requirements.txt ADDED
@@ -0,0 +1,6 @@
gradio
tensorflow
numpy
gymnasium
matplotlib
huggingface-hub