| import numpy as np |
| import torch |
| import bulletchess |
| from typing import List, Tuple, Optional |
| from .vocab import policy_index |
|
|
| |
# Reverse lookup: policy entry -> its position in policy_index.
policy_to_idx = {entry: slot for slot, entry in enumerate(policy_index)}
|
|
|
|
def _board_to_12_piece_planes(board: bulletchess.Board) -> np.ndarray:
    """Build one occupancy plane per (color, piece type) pair.

    Channels are ordered white pawn, knight, bishop, rook, queen, king,
    followed by the same six piece types for black. A cell is set to 1.0
    for every square yielded by ``board[color, piece_type]``, addressed as
    row ``square.index() // 8`` and column ``square.index() % 8``.

    Returns:
        A float32 array of shape (8, 8, 12); it is a transposed view of a
        (12, 8, 8) stack.
    """
    colors = (bulletchess.WHITE, bulletchess.BLACK)
    kinds = (
        bulletchess.PAWN,
        bulletchess.KNIGHT,
        bulletchess.BISHOP,
        bulletchess.ROOK,
        bulletchess.QUEEN,
        bulletchess.KING,
    )
    # Color is the outer key so all white planes precede all black planes.
    pairs = [(color, kind) for color in colors for kind in kinds]

    stack = np.zeros((len(pairs), 8, 8), dtype=np.float32)
    for channel, (color, kind) in enumerate(pairs):
        for square in board[color, kind]:
            row, col = divmod(square.index(), 8)
            stack[channel, row, col] = 1.0

    # Channels-last view: (12, 8, 8) -> (8, 8, 12).
    return stack.transpose(1, 2, 0)
|
|
|
|
def _castling_planes(board: bulletchess.Board) -> np.ndarray:
    """Return four constant 8x8 planes describing the castling rights.

    Plane order is white queenside, white kingside, black queenside,
    black kingside. A plane is filled with 1.0 when the corresponding
    right is contained in ``board.castling_rights`` and with 0.0 otherwise.

    Returns:
        A float32 array of shape (4, 8, 8).
    """
    rights = board.castling_rights
    plane_order = (
        bulletchess.WHITE_QUEENSIDE,
        bulletchess.WHITE_KINGSIDE,
        bulletchess.BLACK_QUEENSIDE,
        bulletchess.BLACK_KINGSIDE,
    )

    planes = np.zeros((len(plane_order), 8, 8), dtype=np.float32)
    for channel, right in enumerate(plane_order):
        if right in rights:
            planes[channel] = 1.0
    return planes
|
|
|
|
def _mirror_board(board: bulletchess.Board) -> bulletchess.Board:
    """Return a color-swapped, rank-flipped copy of ``board``.

    Every piece moves from row ``r`` to row ``7 - r`` (same column) and
    changes color; castling rights are exchanged between white and black,
    the side to move is inverted, and the en passant square (if any) is
    flipped the same way as the pieces. The halfmove clock and fullmove
    number are copied unchanged. The input board is not modified.
    """

    def _flip_rank(square):
        # Same column, row r -> 7 - r, using the index // 8, index % 8 layout.
        row, col = divmod(square.index(), 8)
        return bulletchess.SQUARES[(7 - row) * 8 + col]

    flipped = bulletchess.Board.empty()

    # NOTE(review): pieces are placed before castling rights, turn and
    # en passant are assigned, exactly as before; this order is kept in
    # case the setters validate against the current position -- confirm.
    for square in bulletchess.SQUARES:
        piece = board[square]
        if piece is None:
            continue
        flipped[_flip_rank(square)] = bulletchess.Piece(
            piece.color.opposite, piece.piece_type
        )

    # (right held on the source board, right granted on the mirrored board)
    swapped_rights = (
        (bulletchess.WHITE_KINGSIDE, bulletchess.BLACK_KINGSIDE),
        (bulletchess.WHITE_QUEENSIDE, bulletchess.BLACK_QUEENSIDE),
        (bulletchess.BLACK_KINGSIDE, bulletchess.WHITE_KINGSIDE),
        (bulletchess.BLACK_QUEENSIDE, bulletchess.WHITE_QUEENSIDE),
    )
    granted = [
        target for source, target in swapped_rights
        if source in board.castling_rights
    ]
    if granted:
        flipped.castling_rights = bulletchess.CastlingRights(granted)
    else:
        flipped.castling_rights = bulletchess.NO_CASTLING

    flipped.turn = board.turn.opposite

    en_passant = board.en_passant_square
    if en_passant is not None:
        flipped.en_passant_square = _flip_rank(en_passant)

    flipped.halfmove_clock = board.halfmove_clock
    flipped.fullmove_number = board.fullmove_number

    return flipped
|
|
|
|
| def _build_snapshots(board: bulletchess.Board) -> List[bulletchess.Board]: |
| |
| temp = board.copy() |
| snaps: List[bulletchess.Board] = [temp.copy()] |
| for _ in range(7): |
| |
| try: |
| temp.undo() |
| snaps.append(temp.copy()) |
| except (IndexError, AttributeError): |
| |
| snaps.append(None) |
| return snaps |
|
|
|
|
# Length of the legal-move mask; slots are looked up through policy_to_idx.
_POLICY_SIZE = 1858
# Value stored in every mask slot that does not correspond to a legal move.
_ILLEGAL_MOVE_FILL = -1000.0
# Number of position snapshots encoded, and planes contributed by each one
# (12 piece planes followed by one plane that is always zero).
_HISTORY_LENGTH = 8
_PLANES_PER_SNAPSHOT = 13


def _stack_input_planes(
    snapshots: "List[Optional[bulletchess.Board]]",
    black_to_move: bool,
) -> torch.Tensor:
    """Assemble the network input from already-oriented snapshots.

    Channel layout (112 channels in total):
      * 8 groups of 13: the 12 piece planes of one snapshot followed by an
        all-zero plane; a ``None`` snapshot leaves its whole group at zero.
      * 4 castling planes taken from ``snapshots[0]``.
      * 1 plane filled with 1.0 when ``black_to_move`` is true, else 0.0.
      * 2 all-zero planes.
      * 1 all-one plane.

    Returns:
        A float32 tensor of shape (1, 112, 8, 8), produced as a permuted
        view of an (8, 8, 112) array.
    """
    aux_start = _HISTORY_LENGTH * _PLANES_PER_SNAPSHOT
    planes = np.zeros((8, 8, aux_start + 8), dtype=np.float32)

    for slot, snapshot in enumerate(snapshots[:_HISTORY_LENGTH]):
        if snapshot is None:
            continue
        base = slot * _PLANES_PER_SNAPSHOT
        planes[:, :, base:base + 12] = _board_to_12_piece_planes(snapshot)

    # _castling_planes is channels-first (4, 8, 8); convert to channels-last.
    planes[:, :, aux_start:aux_start + 4] = np.transpose(
        _castling_planes(snapshots[0]), (1, 2, 0)
    )
    planes[:, :, aux_start + 4] = 1.0 if black_to_move else 0.0
    # aux_start + 5 and aux_start + 6 intentionally stay zero.
    planes[:, :, aux_start + 7] = 1.0

    return torch.from_numpy(planes).permute(2, 0, 1).unsqueeze(0).float()


def _legal_move_mask(perspective_board: "bulletchess.Board") -> np.ndarray:
    """Build the additive legal-move mask for ``perspective_board``.

    Every slot starts at -1000 and is set to 0 for each legal move whose
    UCI string is found in ``policy_to_idx``. A trailing ``n`` (knight
    promotion) is stripped before the lookup. When a castling move is
    legal, the king-to-rook-square spelling (e.g. ``e1h1``) is enabled as
    well, next to the regular ``e1g1`` form.

    Returns:
        A float32 array of length 1858.
    """
    mask = np.full(_POLICY_SIZE, _ILLEGAL_MOVE_FILL, dtype=np.float32)

    legal_uci = set()
    for move in perspective_board.legal_moves():
        uci = move.uci()
        legal_uci.add(uci[:-1] if uci.endswith("n") else uci)

    for uci in legal_uci:
        slot = policy_to_idx.get(uci)
        if slot is not None:
            mask[slot] = 0.0

    # (king move, king-to-rook-square alias, castling right required)
    castling_aliases = (
        ("e1g1", "e1h1", bulletchess.WHITE_KINGSIDE),
        ("e1c1", "e1a1", bulletchess.WHITE_QUEENSIDE),
        ("e8g8", "e8h8", bulletchess.BLACK_KINGSIDE),
        ("e8c8", "e8a8", bulletchess.BLACK_QUEENSIDE),
    )
    rights = perspective_board.castling_rights
    for king_uci, alias_uci, right in castling_aliases:
        # The UCI string alone is ambiguous: a rook or queen standing on
        # e1/e8 can also play e1g1/e8c8 etc. Requiring the matching
        # castling right (which implies the king is still on its home
        # square) ensures the alias is only enabled for real castling.
        if king_uci in legal_uci and right in rights:
            slot = policy_to_idx.get(alias_uci)
            if slot is not None:
                mask[slot] = 0.0

    return mask


def _encode_position(
    board: "bulletchess.Board",
    snapshots: "List[Optional[bulletchess.Board]]",
) -> Tuple[torch.Tensor, np.ndarray]:
    """Encode ``board`` and its history snapshots from the mover's side.

    When black is to move, every non-``None`` snapshot is passed through
    ``_mirror_board`` first. ``snapshots[0]`` must be a copy of ``board``;
    after orientation it is also the board used for the legal-move mask.
    """
    black_to_move = board.turn == bulletchess.BLACK
    if black_to_move:
        snapshots = [
            _mirror_board(snap) if snap is not None else None
            for snap in snapshots
        ]

    planes = _stack_input_planes(snapshots, black_to_move)
    mask = _legal_move_mask(snapshots[0])
    return planes, mask


def encode_moves_to_tensor(uci_moves: List[str], starting_fen: Optional[str] = None) -> Tuple[torch.Tensor, np.ndarray]:
    """Encode the position reached by playing ``uci_moves``.

    Args:
        uci_moves: Moves in UCI notation, applied in order.
        starting_fen: FEN of the initial position; when ``None`` the
            default ``bulletchess.Board()`` position is used.

    Returns:
        ``(planes, mask)`` where ``planes`` is a float32 tensor of shape
        (1, 112, 8, 8) built from the current position and up to seven
        preceding ones, and ``mask`` is a float32 array of length 1858
        holding 0 for legal moves and -1000 elsewhere. Both are expressed
        from the side to move's point of view (mirrored for black).
    """
    if starting_fen is not None:
        board = bulletchess.Board.from_fen(starting_fen)
    else:
        board = bulletchess.Board()

    for uci in uci_moves:
        board.apply(bulletchess.Move.from_uci(uci))

    return _encode_position(board, _build_snapshots(board))


def encode_fen_to_tensor(fen: str) -> Tuple[torch.Tensor, np.ndarray]:
    """Encode a single position given as FEN, with no move history.

    Args:
        fen: The position to encode.

    Returns:
        ``(planes, mask)`` with the same layout as
        ``encode_moves_to_tensor``; the seven history groups are all zero.
    """
    board = bulletchess.Board.from_fen(fen)
    snapshots = [board.copy()] + [None] * (_HISTORY_LENGTH - 1)
    return _encode_position(board, snapshots)
|
|
|
|