| import streamlit as st |
| import numpy as np |
|
|
| import Arena |
|
|
| from MCTS import MCTS |
|
|
| from no_one.NoOneGame import NoOneGame |
|
|
|
|
| from no_one.pytorch.NNet import NNetWrapper as NNet |
| from utils import * |
|
|
| import time |
|
|
|
|
# Glyph drawn for each board cell value: -1 and +1 are the two sides,
# 0 is an empty square.
square_content = {-1: "❌", 0: "·", 1: "⭕"}

# Display names of the two seats, in seat order (the -1 side first).
players = [{"name": square_content[-1]}, {"name": square_content[1]}]

# Labels offered by the "who plays" radio; index 0 is the AI, index 1 the human.
player_types = [":rainbow[AI]", "Human"]

# AI rank label -> checkpoint file name loaded by ``ai_player``.
ai_models = {
    "Super": "super.pth.tar",
    "First Blood": "firstBlood.pth.tar",
    "Random": "random.pth.tar",
}

# Rank labels in the order they are offered in the settings radio.
ai_ranks = list(ai_models)
|
|
# Single game-rules object shared by every function in this module.
# NOTE(review): the argument 4 is presumably the board edge length (the board
# is rendered four cells wide in main) — confirm against NoOneGame.
game = NoOneGame(4)
|
|
def check_clicked(i, j):
    """Return True when cell ``(i, j)`` is the currently selected cell.

    The selection is read from ``st.session_state.clicked``, which holds
    either ``None`` (nothing selected) or an ``(i, j)`` tuple.
    """
    return st.session_state.clicked == (i, j)
|
|
class HumanPlayer:
    """Translate a human's source/target cell pair into a game action."""

    # (src_row - target_row, src_col - target_col) -> direction code passed
    # to ``game.encodeAction``. Only single-step orthogonal offsets are listed.
    _DIRECTION_BY_OFFSET = {
        (-1, 0): 0,  # target is one row below the source
        (0, -1): 1,  # target is one column to the right of the source
        (1, 0): 2,   # target is one row above the source
        (0, 1): 3,   # target is one column to the left of the source
    }

    def __init__(self, game):
        """Keep the game object used for encoding and validating moves."""
        self.game = game

    def encode(self, src, target):
        """Encode the step from ``src`` to ``target`` as an action.

        Args:
            src: ``(row, col)`` of the piece being moved.
            target: ``(row, col)`` of the destination cell.

        Returns:
            The value of ``game.encodeAction(src, direction)``, or ``None``
            when ``target`` is not exactly one orthogonal step from ``src``.
        """
        offset = (src[0] - target[0], src[1] - target[1])
        direction = self._DIRECTION_BY_OFFSET.get(offset)
        if direction is None:
            return None
        return self.game.encodeAction(src, direction)

    def play(self, board, player, src, target):
        """Return the action for ``src`` -> ``target`` if the game allows it.

        Returns:
            The encoded action, or ``None`` when the step cannot be encoded
            or ``game.getValidMoves(board, player)`` does not mark it with 1.
        """
        action = self.encode(src, target)
        if action is None:
            return None
        legal = self.game.getValidMoves(board, player)
        return action if legal[action] == 1 else None
|
|
|
|
def ai_player(ai_model):
    """Build the move-choosing callable for the AI rank ``ai_model``.

    A fresh ``NNet`` is created for the shared ``game``, the checkpoint named
    by ``ai_models[ai_model]`` is loaded from ``./models/``, and the network
    is wrapped in an ``MCTS`` configured with 50 simulations and cpuct 1.0.

    Returns:
        A one-argument callable that takes a board and returns the index of
        the highest entry of ``getActionProb(board, temp=0)``.
    """
    network = NNet(game)
    network.load_checkpoint('./models/', ai_models[ai_model])
    search_args = dotdict({'numMCTSSims': 50, 'cpuct': 1.0})
    search = MCTS(game, network, search_args)

    def choose_action(x):
        return np.argmax(search.getActionProb(x, temp=0))

    return choose_action
|
|
|
|
def human_step(i, j):
    """Handle a click on board cell ``(i, j)``.

    Used as the ``on_click`` callback of every board button. A move takes two
    clicks: the first selects one of the current side's pieces, the second
    names the destination.

    * Clicking the selected cell again clears the selection.
    * Clicking any of the current side's pieces selects it, replacing any
      earlier selection.
    * Clicking another cell with nothing selected shows 'Invalid move!'.
    * Clicking another cell with a piece selected attempts the move. If the
      player object rejects it, 'Invalid move!' is shown and the selection
      is kept.

    Clicks are ignored when the game already has a winner, or when the side
    to move is not a ``HumanPlayer``. AI seats hold a plain callable with no
    ``play`` method, so letting the click through would raise, and moving
    after the end would count the win again.
    """
    state = st.session_state

    # The board buttons stay clickable after the game ends; do not let the
    # position or the score change any further.
    if state.winner is not None:
        return

    # players is laid out as [seat for -1, None, seat for +1].
    mover = state.players[state.player + 1]
    if not isinstance(mover, HumanPlayer):
        return

    selected = state.clicked
    if selected == (i, j):
        state.clicked = None
        return

    if state.board[i, j] == state.player:
        state.clicked = (i, j)
        return

    if selected is None:
        st.toast('Invalid move!')
        return

    action = mover.play(state.board, state.player, selected, (i, j))
    if action is None:
        st.toast('Invalid move!')
        return

    move(action)
    state.clicked = None
|
|
|
|
def ai_step():
    """Let the AI seated on the side to move pick an action and apply it.

    The AI callable receives ``game.getCanonicalForm(board, player)`` and the
    action it returns is passed to ``move``.
    """
    state = st.session_state
    side = state.player
    choose_action = state.players[side + 1]
    canonical_board = game.getCanonicalForm(state.board, side)
    move(choose_action(canonical_board))
|
|
|
|
def move(action):
    """Apply ``action`` for the side to move and record a finished game.

    The board and side to move in session state are replaced with the result
    of ``game.getNextState``. ``game.getGameEnded`` is then queried for the
    side that just moved. When it returns a non-zero value, that value times
    the mover's sign is stored as the winner, that side's entry in the
    ``win`` tally is incremented, and balloons are shown.
    """
    state = st.session_state
    mover = state.player

    next_board, next_player = game.getNextState(state.board, mover, action)
    state.board = next_board
    state.player = next_player

    outcome = game.getGameEnded(state.board, mover)
    if outcome == 0:
        return

    victor = outcome * mover
    state.winner = victor
    state.win[victor] += 1
    st.balloons()
|
|
def reinit():
    """Flag that the game must be re-initialised on the next pass of ``main``.

    Used as the ``on_change`` callback of the settings radios.
    """
    st.session_state["reinit"] = True
|
|
def init(post_init=False):
    """Start a new game in session state.

    Args:
        post_init: When false (the default) the win tally is reset to zero
            as well. The 'New Game' button passes ``True`` so the running
            score is kept.

    Resets the board to ``game.getInitBoard()``, gives the first move to
    side -1, rebuilds the seat list from ``player_settings``, and clears the
    winner, the selection and the ``reinit`` flag.
    """
    state = st.session_state
    state.reinit = False
    if not post_init:
        state.win = {-1: 0, 1: 0}

    state.board = game.getInitBoard()
    state.player = -1

    human_label = player_types[1]
    seats = [
        HumanPlayer(game)
        if setting["pt"] == human_label
        else ai_player(setting["ai_model"])
        for setting in state.player_settings
    ]
    # Pad the middle so a seat can be looked up as players[side + 1]
    # for side in {-1, +1}.
    seats.insert(1, None)
    state.players = seats

    state.winner = None
    state.clicked = None
|
|
|
|
def player_to_index(p):
    """Map a side marker to its seat index: -1 -> 0 and 1 -> 1.

    Any other value yields ``None``.
    """
    if p == 1:
        return 1
    return 0 if p == -1 else None
|
|
|
|
def main():
    """Render one pass of the Streamlit page.

    In order, the function:

    1. seeds default settings (human plays the first seat, an AI the second);
    2. draws the 'New Game' button and the 'Settings' expander;
    3. (re)initialises the game when there is no board yet or a setting
       changed;
    4. optionally shows the how-to-play text in the sidebar;
    5. draws the two player headers;
    6. lets the AI move if it is an AI's turn and the game is not over;
    7. draws the board, the score and the turn indicator.
    """
    state = st.session_state
    ai_label, human_label = player_types

    if "player_settings" not in state:
        state.player_settings = [
            {"pt": human_label, "ai_model": ai_ranks[0]},
            {"pt": ai_label, "ai_model": ai_ranks[1]},
        ]
    if "reinit" not in state:
        state.reinit = False

    fire, settings = st.columns([1, 2])
    # post_init=True: a new game keeps the running score.
    fire.button('New Game', on_click=init, args=(True,))
    with settings.expander('Settings', expanded=False):
        st.warning('Any setting changing will restart the game immediately', icon="⚠️")
        for i, column in enumerate(st.columns([0.5, 0.5])):
            setting = state.player_settings[i]
            with column:
                setting["pt"] = st.radio(
                    f"Who will play {players[i]['name']}",
                    player_types,
                    key=f"xp_type_{i}",
                    horizontal=True,
                    index=player_types.index(setting["pt"]),
                    on_change=reinit,
                )
                # The rank picker is only shown for AI-controlled seats.
                if setting["pt"] == ai_label:
                    setting["ai_model"] = st.radio(
                        "AI rank",
                        ai_ranks,
                        key=f"xp_rank_{i}",
                        index=ai_ranks.index(setting["ai_model"]),
                        on_change=reinit,
                    )

        st.divider()
        st.checkbox("Show me how to play", key="show_how_to_play")

    if "board" not in state or state.reinit:
        init()

    if state.show_how_to_play:
        with st.sidebar:
            st.title("How to play")
            # User-facing help text; kept verbatim.
            how_to = '''1. 游戏的目标是吃掉对方的棋子
开局双方各有四枚棋子,被吃剩一枚棋子即可判负
2. 棋子可以向上下左右移动到空白位置
3. 如何吃子
1. 移动后的棋子需要跟本方其他某个棋子,在水平方向,或者垂直方向上连住
2. 如果连住后的棋子两端有对方一个棋子,那么这个对方的棋子就被吃掉
3. 如果对方也有两颗棋子,则互相不吃
4. 如果落子在对方连起来的两枚棋子一端,也不会被吃
5. 如果选择 AI vs AI,需要手动点击比分下面的回合按钮,触发下一次落子
'''
            st.markdown(how_to)

    xp, score, op = st.columns([2, 8, 2])
    for i, header in enumerate([xp, op]):
        setting = state.player_settings[i]
        header.title(players[i]["name"], anchor=False)
        caption = setting["pt"]
        if setting["pt"] == ai_label:
            caption = setting["ai_model"] + " " + caption
        header.caption(caption)

    # Let the AI reply before the board is drawn, so the board rendered
    # below already contains its move.
    side_to_move = state.player_settings[player_to_index(state.player)]
    if side_to_move["pt"] != human_label and state.winner is None:
        ai_step()

    st.divider()

    for i, row in enumerate(state.board):
        # Wide spacer columns on both sides centre the row; one narrow
        # column per cell in between.
        cols = st.columns([5, *([1] * len(row)), 5])
        for j, field in enumerate(row):
            cols[j + 1].button(
                square_content[field],
                key=f"{i}-{j}",
                type="primary" if check_clicked(i, j) else "secondary",
                on_click=human_step,
                args=(i, j),
            )

    s = score.columns([2, 2, 2])
    s[1].title(f'{state.win[-1]} : {state.win[1]}', anchor=False)
    if state.winner is None:
        mark = "❌" if state.player == -1 else "⭕"
        turn_label = f"{mark}'s turn"
    else:
        turn_label = '🏁 Game finished'
    # The how-to text tells users to click this button to trigger the next
    # move in AI vs AI games.
    s[1].button(turn_label)
|
|
|
|
|
|
# Render the page when this file is executed as a script.
if __name__ == '__main__':
    main()