# Hugging Face Spaces status: Sleeping
import time
import traceback

import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download

from game_logic import AaduPulliEnv, NeuralNetwork, AlphaZeroAgent
REPO_ID = "AaduPulliAttam/apa-ray"
MODEL_FILENAME = "alphazero_aadu_pulli_ray.weights.h5"
HUMAN_PLAYER = 0  # 0 for Goat, 1 for Tiger
AI_PLAYER = 1

# The AI agent is built once, at import time. Loading is best-effort: if any
# step fails, ai_agent stays None and the UI handlers report that the AI is
# unavailable instead of the whole app failing to start.
ai_agent = None
try:
    print(f"Downloading model from {REPO_ID}...")
    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
    print(f"Model downloaded to: {model_path}")
    # This env instance sizes the network (via its action-space size) and is
    # also handed to the agent.
    env_for_init = AaduPulliEnv()
    action_space_size = env_for_init.action_space.n
    neural_net = NeuralNetwork(action_space_size)
    neural_net.model.load_weights(model_path)
    print("Model weights loaded successfully.")
    ai_agent = AlphaZeroAgent(env_for_init, neural_net, simulations_per_move=50)
except Exception as e:
    # Deliberate top-level catch-all: keep the app running without the AI,
    # but print the full traceback so download / weight-loading errors can be
    # diagnosed from the logs (the message alone loses the stack).
    print(f"Could not load model. AI will be disabled. Error: {e}")
    traceback.print_exc()
def get_pos_from_click(env, click_coords, *, img_size=(800, 800), board_extent=(24, 21), max_dist=1.5):
    """Maps pixel coordinates from a click to the closest board position.

    The click is converted from image pixels into board coordinates (the
    coordinate space of ``env.board_points``) and snapped to the nearest point.

    Args:
        env: game environment exposing ``board_points``, a mapping of
            board position -> ``(x, y)`` in board coordinates.
        click_coords: ``(x, y)`` pixel coordinates of the click, with pixel y
            growing downward (as passed in from ``evt.index``).
        img_size: ``(width, height)`` of the rendered board image in pixels.
            Presumably must match the image produced by ``env.render()``;
            the default keeps the previous hard-coded 800x800.
        board_extent: ``(x_span, y_span)`` of board coordinates covered by the
            plot area (board x runs 0..x_span left-to-right, board y runs
            0..y_span bottom-to-top).
        max_dist: largest allowed distance, in board units, between the click
            and the nearest board point.

    Returns:
        The nearest board position, or None if there is no click, the click
        lies in the image margins, or no point is within ``max_dist``.
    """
    if not click_coords:
        return None
    x_click, y_click = click_coords
    img_width, img_height = img_size
    board_width, board_height = board_extent

    # Fractions of the image lying outside the plot area. These must agree
    # with the figure layout used when rendering the board (not visible
    # here - assumption to verify if the renderer changes).
    top_margin_ratio = 0.12
    bottom_margin_ratio = 0.12
    left_margin_ratio = 0.05
    right_margin_ratio = 0.05

    plot_left = img_width * left_margin_ratio
    plot_top = img_height * top_margin_ratio
    plot_width_px = img_width * (1 - left_margin_ratio - right_margin_ratio)
    plot_height_px = img_height * (1 - top_margin_ratio - bottom_margin_ratio)

    inside_x = plot_left < x_click < img_width * (1 - right_margin_ratio)
    inside_y = plot_top < y_click < img_height * (1 - bottom_margin_ratio)
    if not (inside_x and inside_y):
        return None

    x_scaled = ((x_click - plot_left) / plot_width_px) * board_width
    # Pixel y grows downward while board y grows upward, hence the flip.
    y_scaled = (1 - ((y_click - plot_top) / plot_height_px)) * board_height

    min_dist = float('inf')
    closest_pos = None
    for pos, (x_board, y_board) in env.board_points.items():
        dist = np.hypot(x_scaled - x_board, y_scaled - y_board)
        if dist < min_dist:
            min_dist = dist
            closest_pos = pos

    if min_dist > max_dist:
        return None
    return closest_pos
def check_game_over(info, human_player_side):
    """Turn the step info's ``winner`` entry into a status message.

    Args:
        info: dict returned by ``env.step``; a missing ``'winner'`` key or the
            value -1 means the game is still running.
        human_player_side: 0 if the human plays Goat, 1 if Tiger.

    Returns:
        ``(message, True)`` once the game has ended (human win, AI win, or any
        other winner value, which is reported as a draw), else ``(None, False)``.
    """
    outcome = info.get('winner', -1)
    if outcome == -1:
        return None, False

    if outcome == human_player_side:
        text = "Congratulations, you won!"
    elif outcome == 1 - human_player_side:
        text = "The AI has won. Better luck next time!"
    else:
        text = "The game is a draw!"
    return text, True
def _choose_human_action(env, selected_pos, clicked_pos, human_player_side):
    """Translate a board click into an env action, or a UI response if no move is completed.

    Args:
        env: the current game environment.
        selected_pos: board position of the currently selected piece, or None.
        clicked_pos: board position that was clicked (positions are 1-based;
            ``position - 1`` indexes ``env.board`` and placement actions).
        human_player_side: 0 if the human plays Goat, 1 if Tiger.

    Returns:
        ``(action, None)`` when the click completes a valid placement or move;
        otherwise ``(None, (board_image, message, new_selected_pos))``.
    """
    if env.goats_placed_count < env.NUM_GOATS and human_player_side == 0:
        # Goat placement phase: the placement action index is position - 1.
        action = clicked_pos - 1
        is_valid, details = env.is_action_valid(action)
        if not is_valid:
            return None, (env.render(), f"Invalid placement: {details.get('error')}", None)
        return action, None

    # Board cell value of the human's pieces (assumes 1 = goat, 2 = tiger).
    piece_type = 1 if human_player_side == 0 else 2
    clicked_own_piece = env.board[clicked_pos - 1] == piece_type

    if selected_pos is None:
        if clicked_own_piece:
            return None, (env.render(highlight_pos=clicked_pos),
                          f"Selected piece at {clicked_pos}. Click where to move.", clicked_pos)
        return None, (env.render(), "You must select one of your own pieces.", None)

    if clicked_pos == selected_pos:
        return None, (env.render(), "Selection cleared.", None)
    if clicked_own_piece:
        # A piece can never move onto one of its own side's pieces, so treat
        # this click as switching the selection rather than a failed move.
        return None, (env.render(highlight_pos=clicked_pos),
                      f"Selected piece at {clicked_pos}. Click where to move.", clicked_pos)

    move_tuple = (selected_pos, clicked_pos)
    if move_tuple not in env._move_action_lookup:
        return None, (env.render(), "Invalid move path.", None)
    # Move actions are numbered after all placement actions.
    action = env.placement_actions + env._move_action_lookup[move_tuple]
    is_valid, details = env.is_action_valid(action)
    if not is_valid:
        return None, (env.render(), f"Invalid move: {details.get('error')}", None)
    return action, None


def _play_ai_turn(env):
    """Let the AI choose and play one action on ``env``.

    Returns:
        ``(info, description)``: ``info`` is the info dict from ``env.step``;
        ``description`` is ``"Placement <position>"`` or ``"Move <entry>"``,
        where ``<entry>`` comes from ``env._move_action_map`` ("N/A" if absent).
    """
    state = env._get_current_observation()
    ai_action = ai_agent.get_action(state, training=False)
    _, _, _, info = env.step(ai_action)
    if ai_action < env.placement_actions:
        # Convert the placement action index back to its board position.
        return info, f"Placement {ai_action + 1}"
    move = env._move_action_map.get(ai_action - env.placement_actions, "N/A")
    return info, f"Move {move}"


def handle_click(env, selected_pos, game_over, human_player_side, evt: gr.SelectData):
    """
    Handles all user clicks on the board.

    - Goat placement phase (human plays Goat): a click places a goat, then the AI replies.
    - Movement: the first click selects one of the human's pieces and the second click
      picks the destination. Clicking another own piece switches the selection, and
      clicking the selected piece again clears it. After a legal move, the AI replies.

    Args:
        env: the session's game environment.
        selected_pos: board position of the currently selected piece, or None.
        game_over: whether the game has already ended.
        human_player_side: 0 if the human plays Goat, 1 if Tiger.
        evt: Gradio select event; ``evt.index`` carries the clicked pixel coordinates.

    Returns:
        ``(env, board_image, status_message, selected_pos, game_over)``, in the
        order of the outputs wired to ``board_image.select``.
    """
    if game_over or env.player_turn != human_player_side:
        message = "Game is over." if game_over else "It's not your turn."
        return env, env.render(highlight_pos=selected_pos), message, selected_pos, game_over
    if not ai_agent:
        return env, env.render(), "AI is not available.", None, True

    clicked_pos = get_pos_from_click(env, evt.index)
    if clicked_pos is None:
        return env, env.render(highlight_pos=selected_pos), "Please click on a valid position.", selected_pos, game_over

    action, response = _choose_human_action(env, selected_pos, clicked_pos, human_player_side)
    if response is not None:
        image, message, new_selected_pos = response
        return env, image, message, new_selected_pos, False

    # Human's move.
    _, _, _, info = env.step(action)
    message, game_over = check_game_over(info, human_player_side)
    if game_over:
        return env, env.render(), message, None, True

    if env.player_turn == human_player_side:
        # Guard: the env kept the turn with the human, so don't ask the AI to
        # move (it would presumably choose an action for the side to move,
        # i.e. the human's pieces).
        return env, env.render(), "Your turn.", None, False

    # AI's turn. There is no artificial delay: Gradio shows this handler's
    # outputs only after it returns, so a sleep would just add latency.
    info, ai_move_text = _play_ai_turn(env)
    message, game_over = check_game_over(info, human_player_side)
    if game_over:
        return env, env.render(), message, None, True
    return env, env.render(), f"AI played {ai_move_text}. Your turn.", None, False
def start_game(player_choice):
    """Start a new game and, if the human chose Tiger, let the AI make the first move.

    Args:
        player_choice: "Goat" or "Tiger", the value of the side-selection radio.

    Returns:
        A dict keyed by the Gradio components wired as this handler's outputs.
        It hides the start screen, shows the game screen, and seeds the
        session state (fresh env, no selection, game-over flag, human's side).
        If the human chose Tiger and the AI is unavailable, the game is marked
        over immediately.
    """
    env = AaduPulliEnv()
    human_player_side = 0 if player_choice == "Goat" else 1
    message = "New game started. Your turn!"
    game_over = False
    if human_player_side == 1:
        if ai_agent:
            # The AI plays Goat and opens the game before the human's first turn.
            state = env._get_current_observation()
            ai_action = ai_agent.get_action(state, training=False)
            env.step(ai_action)
            message = "New game started. AI placed a goat. Your turn to move a Tiger."
        else:
            message = "AI not available. Cannot start game as Tiger."
            game_over = True
    return {
        start_screen: gr.update(visible=False),
        game_screen: gr.update(visible=True),
        game_env_state: env,
        board_image: env.render(),
        status_message: message,
        selected_pos_state: None,
        game_over_state: game_over,
        human_player_state: human_player_side,
    }
def go_to_start_screen():
    """Show the side-selection screen again and hide the game screen.

    Only visibility changes; the session's game state is left as it is, and
    the next "Start Game" click replaces it.
    """
    show = gr.update(visible=True)
    hide = gr.update(visible=False)
    return {start_screen: show, game_screen: hide}
# Gradio UI: one Blocks app holding two screens, a side-selection screen and
# the game screen. Handlers toggle their visibility.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Aadu Puli Aattam: Human vs AlphaZero AI")

    # Session state (gr.State keeps a separate value per user session).
    game_env_state = gr.State()  # the AaduPulliEnv for this game, set by start_game
    selected_pos_state = gr.State()  # board position of the selected piece, or None
    game_over_state = gr.State(False)  # True once the game has ended
    human_player_state = gr.State()  # human's side: 0 = Goat, 1 = Tiger

    # Screen 1: choose a side (visible on page load).
    with gr.Column(visible=True) as start_screen:
        gr.Markdown("## Choose Your Side")
        player_choice_radio = gr.Radio(["Goat", "Tiger"], label="Play as", value="Goat")
        start_button = gr.Button("Start Game")

    # Screen 2: board image plus status panel (hidden until start_game shows it).
    with gr.Column(visible=False) as game_screen:
        gr.Markdown("You play as the Goats (Blue) or the Tigers (Red). Your goal is to trap the opponent or capture their pieces.")
        with gr.Row():
            with gr.Column(scale=2):
                board_image = gr.Image(label="Game Board", interactive=True)
            with gr.Column(scale=1):
                status_message = gr.Markdown(label="Game Status")
                new_game_button = gr.Button("New Game")

    # start_game returns a dict keyed by these output components.
    start_button.click(
        fn=start_game,
        inputs=[player_choice_radio],
        outputs=[start_screen, game_screen, game_env_state, board_image, status_message, selected_pos_state, game_over_state, human_player_state]
    )
    # Clicking the board fires a select event. handle_click receives the click
    # coordinates via gr.SelectData, and its return tuple must follow the order
    # of `outputs` below.
    board_image.select(
        fn=handle_click,
        inputs=[game_env_state, selected_pos_state, game_over_state, human_player_state],
        outputs=[game_env_state, board_image, status_message, selected_pos_state, game_over_state]
    )
    # "New Game" only switches back to the start screen; start_game rebuilds the state.
    new_game_button.click(
        fn=go_to_start_screen,
        outputs=[start_screen, game_screen]
    )

if __name__ == "__main__":
    demo.launch()