| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from .attn_map import apm_map, apm_out |
| import math |
| from .encoding_simple import encode_fen_to_tensor, encode_moves_to_tensor |
| from .vocab import policy_index |
| from typing import Union, List, Optional |
| import bulletchess |
| import numpy as np |
|
|
|
|
| from transformers import PretrainedConfig, PreTrainedModel |
|
|
class Gating(nn.Module):
    """Learnable element-wise gate that is either added to or multiplied with its input.

    Args:
        features_shape: Shape of the gate parameter (passed straight to ``torch.full``).
        additive: If True the gate is added to the input, otherwise it is multiplied.
        init_value: Initial fill value of the gate. When omitted, the identity
            element of the chosen operation is used (0 for add, 1 for multiply).
    """

    def __init__(self, features_shape, additive=True, init_value=None):
        super().__init__()
        self.additive = additive
        if init_value is None:
            init_value = 1 if not additive else 0

        initial_gate = torch.full(features_shape, float(init_value))
        self.gate = nn.Parameter(initial_gate)
        if additive:
            return
        # Multiplicative gates only: negative gradient entries are zeroed
        # before they are accumulated into ``gate.grad``.
        self.gate.register_hook(self._non_negative_grad)

    @staticmethod
    def _non_negative_grad(grad):
        """Clamp a gradient tensor from below at zero."""
        return torch.clamp(grad, min=0)

    def forward(self, x):
        """Return ``x + gate`` (additive) or ``x * gate`` (multiplicative)."""
        return x + self.gate if self.additive else x * self.gate
|
|
def ma_gating(x, in_features):
    """Apply a multiplicative gate followed by an additive gate to ``x``.

    Both ``Gating`` modules are constructed on every call, so their
    parameters always hold their initial values (1 and 0 respectively) and
    are not registered on any parent module.
    """
    scaled = Gating(in_features, additive=False)(x)
    shifted = Gating(in_features, additive=True)(scaled)
    return shifted
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Args:
        in_features: Size of the last dimension; used for the ``gamma`` parameter.
        scale: If True, a learnable per-feature ``gamma`` (initialised to ones)
            multiplies the normalised output. If False no parameter is created.
    """

    def __init__(self, in_features, scale=True):
        super().__init__()
        self.scale = scale
        if scale:
            self.gamma = nn.Parameter(torch.ones(in_features))

    def forward(self, x):
        """Divide ``x`` by ``sqrt(mean(x^2) + 1e-5)`` taken over the last axis."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        normalized = x / torch.sqrt(mean_square + 1e-5)
        if not self.scale:
            return normalized
        return normalized * self.gamma
|
|
class ApplyAttentionPolicyMap(nn.Module):
    """Select policy outputs from the flattened attention and promotion logits.

    Two buffers are registered from the imported ``attn_map`` arrays:
    ``fc1`` (``apm_map`` as float) and ``idx`` (``apm_out`` as long).
    Only ``idx`` is used by ``forward``.
    """

    def __init__(self):
        super().__init__()
        policy_map = torch.from_numpy(apm_map).float()
        gather_index = torch.from_numpy(apm_out).long()
        self.register_buffer('fc1', policy_map)
        self.register_buffer('idx', gather_index)

    def forward(self, logits, pp_logits):
        """Gather the policy entries addressed by ``idx``.

        Args:
            logits: Tensor reshaped here to ``(-1, 64 * 64)``.
            pp_logits: Tensor reshaped here to ``(-1, 8 * 24)``.

        Returns:
            Tensor with one row per batch element and one column per entry of ``idx``.
        """
        move_part = logits.reshape(-1, 64 * 64)
        promotion_part = pp_logits.reshape(-1, 8 * 24)
        flat = torch.cat([move_part, promotion_part], dim=1)

        # The same index row is used for every batch element.
        index = self.idx.unsqueeze(0).expand(flat.size(0), -1)
        return torch.gather(flat, 1, index)
|
|
class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply the activation element-wise."""
        squashed = torch.tanh(F.softplus(x))
        return x * squashed
|
|
class CustomMHA(nn.Module):
    """Multi-head self-attention with an optional generated additive attention bias.

    When ``use_smolgen`` is True, a small network ("smolgen") compresses the
    whole token sequence and produces a ``(num_heads, 64, 64)`` bias per sample
    that is added to the scaled dot-product logits. Because the smolgen layers
    are sized with a hard-coded 64, that path only works for sequences of
    exactly 64 tokens.

    Args:
        emb_size: Feature size of the input and output tokens.
        d_model: Total width of the query/key/value projections.
        num_heads: Number of attention heads; must divide ``d_model``.
        dropout: Dropout probability applied to the attention weights.
        use_bias_qkv: Whether the q/k/v projections have biases.
        use_bias_out: Whether the output projection has a bias.
        use_smolgen: Whether to build and apply the smolgen bias network.
        smol_hidden_channels: Per-token width after smolgen compression.
        smol_hidden_sz: Width of the first smolgen hidden layer.
        smol_gen_sz: Per-head width fed into the shared weight generator.
        smol_activation: Stored on the module for reference only. The smolgen
            path applies SiLU regardless of this value (the previous code
            selected ``F.silu`` in both branches of its conditional).
    """

    def __init__(self, emb_size, d_model, num_heads, dropout=0.0, use_bias_qkv=True, use_bias_out=True,
                 use_smolgen=True, smol_hidden_channels=32, smol_hidden_sz=256, smol_gen_sz=256, smol_activation='swish'):
        super().__init__()
        assert d_model % num_heads == 0
        self.emb_size = emb_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.wq = nn.Linear(emb_size, d_model, bias=use_bias_qkv)
        self.wk = nn.Linear(emb_size, d_model, bias=use_bias_qkv)
        self.wv = nn.Linear(emb_size, d_model, bias=use_bias_qkv)
        self.out_proj = nn.Linear(d_model, emb_size, bias=use_bias_out)
        self.attn_dropout = nn.Dropout(dropout)

        # Always recorded so the attribute exists whether or not smolgen is built.
        self.smol_activation = smol_activation

        self.smol_compress = None
        self.smol_hidden1 = None
        self.smol_hidden1_ln = None
        self.smol_gen_from = None
        self.smol_gen_from_ln = None
        self.smol_weight_gen = None
        if use_smolgen:
            self.smol_compress = nn.Linear(emb_size, smol_hidden_channels, bias=False)
            self.smol_hidden1 = nn.Linear(64 * smol_hidden_channels, smol_hidden_sz, bias=True)
            self.smol_hidden1_ln = nn.LayerNorm(smol_hidden_sz, eps=1e-3)
            self.smol_gen_from = nn.Linear(smol_hidden_sz, num_heads * smol_gen_sz, bias=True)
            self.smol_gen_from_ln = nn.LayerNorm(num_heads * smol_gen_sz, eps=1e-3)
            self.smol_weight_gen = nn.Linear(smol_gen_sz, 64 * 64, bias=False)

    def _shape(self, x):
        """Split the last axis into heads: (b, l, d_model) -> (b, heads, l, head_dim)."""
        b, l, _ = x.shape
        return x.view(b, l, self.num_heads, self.head_dim).transpose(1, 2)

    def _smolgen_bias(self, x):
        """Compute the per-head additive attention bias of shape (b, heads, l, l)."""
        b, l, _ = x.shape
        # Flatten all tokens of a sample into one vector.
        compressed = self.smol_compress(x).reshape(b, -1)
        hidden = self.smol_hidden1_ln(F.silu(self.smol_hidden1(compressed)))
        gen_from = self.smol_gen_from_ln(F.silu(self.smol_gen_from(hidden)))
        gen_from = gen_from.view(b, self.num_heads, -1)
        # One weight generator is shared by all heads.
        return self.smol_weight_gen(gen_from).view(b, self.num_heads, l, l)

    def forward(self, x, return_attn=False):
        """Run self-attention over ``x``.

        Args:
            x: Tensor unpacked as ``(batch, length, emb_size)``.
            return_attn: If True, also return attention internals.

        Returns:
            The projected attention output, or, when ``return_attn`` is True,
            the tuple ``(out, attn_weights, smol_w, attn_logits)`` where
            ``smol_w`` is None without smolgen and ``attn_logits`` are the
            logits after the row-wise maximum has been subtracted.
        """
        q = self._shape(self.wq(x))
        k = self._shape(self.wk(x))
        v = self._shape(self.wv(x))
        scale = torch.sqrt(torch.tensor(self.head_dim, dtype=x.dtype, device=x.device))
        attn_logits = torch.matmul(q, k.transpose(-2, -1)) / scale

        smol_w = None
        if self.smol_compress is not None:
            smol_w = self._smolgen_bias(x)
            attn_logits = attn_logits + smol_w

        # Softmax written out explicitly; the shifted logits are part of the
        # ``return_attn`` result, so the shift is kept as a separate step.
        attn_logits = attn_logits - attn_logits.max(dim=-1, keepdim=True)[0]
        attn_weights = torch.exp(attn_logits)
        attn_weights = attn_weights / attn_weights.sum(dim=-1, keepdim=True)
        attn_weights = self.attn_dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).contiguous().view(x.size(0), x.size(1), self.d_model)
        out = self.out_proj(attn_output)
        if return_attn:
            return out, attn_weights, smol_w, attn_logits
        return out
|
|
class FFN(nn.Module):
    """Two-layer position-wise feed-forward block: Linear -> activation -> Linear.

    Args:
        emb_size: Input and output feature size.
        dff: Hidden feature size.
        activation: Activation module applied between the two layers. When
            None (the default) a fresh ``Mish()`` is created for this instance,
            so instances no longer share one default module object.
        omit_other_biases: If True, both linear layers are built without biases.
    """

    def __init__(self, emb_size, dff, activation=None, omit_other_biases=False):
        super().__init__()
        use_bias = not omit_other_biases
        self.dense1 = nn.Linear(emb_size, dff, bias=use_bias)
        self.activation = Mish() if activation is None else activation
        self.dense2 = nn.Linear(dff, emb_size, bias=use_bias)

    def forward(self, x):
        """Return ``dense2(activation(dense1(x)))``."""
        return self.dense2(self.activation(self.dense1(x)))
|
|
class EncoderLayer(nn.Module):
    """One encoder block: self-attention and feed-forward, each followed by a norm.

    The output of each sub-layer is multiplied by ``alpha = (2 * encoder_layers) ** -0.25``
    before it is added to the residual stream, and normalisation is applied
    after the addition.

    Args:
        emb_size: Token feature size.
        d_model: Attention projection width.
        num_heads: Number of attention heads.
        dff: Hidden width of the feed-forward block.
        dropout_rate: Dropout probability (attention weights and both sub-layer outputs).
        encoder_layers: Total number of encoder layers; only used to compute ``alpha``.
        skip_first_ln: If True, the norm after the attention residual is skipped.
        encoder_rms_norm: Use ``RMSNorm`` instead of ``nn.LayerNorm(eps=0.001)``.
        omit_qkv_biases: Build the q/k/v projections without biases.
        omit_other_biases: Build the attention output projection and the
            feed-forward layers without biases.
        use_smolgen, smol_hidden_channels, smol_hidden_sz, smol_gen_sz,
        smol_activation: Forwarded unchanged to ``CustomMHA``.
    """

    def __init__(self, emb_size, d_model, num_heads, dff, dropout_rate, encoder_layers, skip_first_ln=False, encoder_rms_norm=False, omit_qkv_biases=False, omit_other_biases=False,
                 use_smolgen=True, smol_hidden_channels=32, smol_hidden_sz=256, smol_gen_sz=256, smol_activation='swish'):
        super().__init__()
        self.mha = CustomMHA(
            emb_size,
            d_model,
            num_heads,
            dropout=dropout_rate,
            use_bias_qkv=not omit_qkv_biases,
            use_bias_out=not omit_other_biases,
            use_smolgen=use_smolgen,
            smol_hidden_channels=smol_hidden_channels,
            smol_hidden_sz=smol_hidden_sz,
            smol_gen_sz=smol_gen_sz,
            smol_activation=smol_activation,
        )
        self.ffn = FFN(emb_size, dff, omit_other_biases=omit_other_biases)

        def build_norm():
            if encoder_rms_norm:
                return RMSNorm(emb_size)
            return nn.LayerNorm(emb_size, eps=0.001)

        self.norm1 = build_norm()
        self.norm2 = build_norm()

        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

        self.alpha = (2. * encoder_layers)**-0.25
        self.skip_first_ln = skip_first_ln

    def forward(self, x):
        """Apply attention and feed-forward sub-layers with scaled residual connections."""
        attended = self.dropout1(self.mha(x))
        hidden = x + attended * self.alpha
        if not self.skip_first_ln:
            hidden = self.norm1(hidden)

        transformed = self.dropout2(self.ffn(hidden))
        return self.norm2(hidden + transformed * self.alpha)
|
|
class PolicyHead(nn.Module):
    """Attention-style policy head producing from-token x to-token logits plus promotion logits.

    Args:
        pol_embedding_size: Feature size of the incoming policy tokens.
        policy_d_model: Width of the query and key projections.
        opponent: If True, the token axis (dim 1) is reversed before projection.
    """

    def __init__(self, pol_embedding_size, policy_d_model, opponent=False):
        super().__init__()
        self.opponent = opponent
        self.wq = nn.Linear(pol_embedding_size, policy_d_model)
        self.wk = nn.Linear(pol_embedding_size, policy_d_model)
        self.ppo = nn.Linear(policy_d_model, 4, bias=False)
        self.apply_map = ApplyAttentionPolicyMap()

    def forward(self, x):
        """Compute the mapped policy logits for a batch of token sequences.

        The slicing below (last 8 and second-to-last 8 tokens) assumes 64
        tokens per sample laid out as 8 rows of 8 — TODO confirm the square
        ordering against the encoder.
        """
        tokens = torch.flip(x, [1]) if self.opponent else x

        queries = self.wq(tokens)
        keys = self.wk(tokens)
        square_logits = torch.matmul(queries, keys.transpose(-2, -1))
        dk = torch.sqrt(torch.tensor(keys.shape[-1], dtype=keys.dtype, device=keys.device))

        # Four offsets per key among the last 8 tokens; the fourth is added to
        # each of the first three. Pre-multiplied by dk because everything is
        # divided by dk again at the end.
        raw_offsets = self.ppo(keys[:, -8:, :]).transpose(-2, -1) * dk
        piece_offsets = raw_offsets[:, :3, :] + raw_offsets[:, 3:4, :]

        # Base logits: queries from the second-to-last 8 tokens, keys from the last 8.
        base_logits = square_logits[:, -16:-8, -8:]
        per_piece = [
            (base_logits + piece_offsets[:, i:i + 1, :]).unsqueeze(3)
            for i in range(3)
        ]
        promotion_logits = torch.cat(per_piece, dim=3).reshape(-1, 8, 24)

        promotion_logits = promotion_logits / dk
        policy_attn_logits = square_logits / dk
        return self.apply_map(policy_attn_logits, promotion_logits)
|
|
class ValueHead(nn.Module):
    """Value head: per-token embedding, flatten, then two dense layers to 3 outputs.

    Args:
        embedding_size: Feature size of the incoming tokens.
        val_embedding_size: Per-token width after the embedding layer.
        default_activation: Activation module used after the embedding and the
            first dense layer. When None (the default) a fresh ``Mish()`` is
            created for this instance, so instances no longer share one
            default module object.

    ``dense1`` is sized ``val_embedding_size * 64``, so the input must carry
    64 tokens per sample.
    """

    def __init__(self, embedding_size, val_embedding_size, default_activation=None):
        super().__init__()
        self.embedding = nn.Linear(embedding_size, val_embedding_size)
        self.activation = Mish() if default_activation is None else default_activation
        self.flatten = nn.Flatten()
        self.dense1 = nn.Linear(val_embedding_size * 64, 128)
        self.dense2 = nn.Linear(128, 3)

    def forward(self, x):
        """Return the three unnormalised value outputs per sample."""
        x = self.activation(self.embedding(x))
        x = self.flatten(x)
        x = self.activation(self.dense1(x))
        return self.dense2(x)
|
|
class BT4Config(PretrainedConfig):
    """Configuration class for BT4 model.

    Every constructor argument is stored as an attribute of the same name;
    any additional keyword arguments are passed to ``PretrainedConfig``.
    """
    model_type = "bt4"

    def __init__(
        self,
        embedding_size=1024,
        embedding_dense_sz=512,
        encoder_layers=15,
        encoder_d_model=1024,
        encoder_heads=32,
        encoder_dff=1536,
        dropout_rate=0.0,
        pol_embedding_size=1024,
        policy_d_model=1024,
        val_embedding_size=128,
        use_smolgen=True,
        smol_hidden_channels=32,
        smol_hidden_sz=256,
        smol_gen_sz=256,
        smol_activation="swish",
        **kwargs
    ):
        super().__init__(**kwargs)
        hyperparameters = {
            "embedding_size": embedding_size,
            "embedding_dense_sz": embedding_dense_sz,
            "encoder_layers": encoder_layers,
            "encoder_d_model": encoder_d_model,
            "encoder_heads": encoder_heads,
            "encoder_dff": encoder_dff,
            "dropout_rate": dropout_rate,
            "pol_embedding_size": pol_embedding_size,
            "policy_d_model": policy_d_model,
            "val_embedding_size": val_embedding_size,
            "use_smolgen": use_smolgen,
            "smol_hidden_channels": smol_hidden_channels,
            "smol_hidden_sz": smol_hidden_sz,
            "smol_gen_sz": smol_gen_sz,
            "smol_activation": smol_activation,
        }
        # Assigned after the base initialiser, in declaration order.
        for name, value in hyperparameters.items():
            setattr(self, name, value)
|
|
class BT4(PreTrainedModel):
    """Transformer chess network with one policy head and two value heads.

    The model embeds a 112-plane board tensor into 64 tokens, runs them through
    a stack of ``EncoderLayer`` blocks and produces policy logits plus two
    3-way value outputs. Convenience methods turn FEN strings or UCI move lists
    into UCI moves and value estimates.
    """
    config_class = BT4Config

    # (move as emitted by the policy, castling-right letter, standard UCI king move).
    # The policy vocabulary expresses castling as the king moving onto its rook.
    _CASTLING_REWRITES = (
        ("e1h1", "K", "e1g1"),
        ("e1a1", "Q", "e1c1"),
        ("e8h8", "k", "e8g8"),
        ("e8a8", "q", "e8c8"),
    )

    def __init__(self, config=None, embedding_size=1024, embedding_dense_sz=512, encoder_layers=15, encoder_d_model=1024, encoder_heads=32, encoder_dff=1536, dropout_rate=0.0, pol_embedding_size=1024, policy_d_model=1024, val_embedding_size=128, default_activation=None,
                 use_smolgen=True, smol_hidden_channels=32, smol_hidden_sz=256, smol_gen_sz=256, smol_activation='swish'):
        """Build the network.

        Args:
            config: Optional ``BT4Config``. When given, it takes precedence over
                all the individual hyper-parameter arguments.
            embedding_size ... smol_activation: Hyper-parameters used to create a
                ``BT4Config`` when ``config`` is None.
            default_activation: Activation module applied after the input
                embedding and the policy embedding. None (the default) creates
                a fresh ``Mish()`` for this model.
        """
        if config is None:
            config = BT4Config(
                embedding_size=embedding_size,
                embedding_dense_sz=embedding_dense_sz,
                encoder_layers=encoder_layers,
                encoder_d_model=encoder_d_model,
                encoder_heads=encoder_heads,
                encoder_dff=encoder_dff,
                dropout_rate=dropout_rate,
                pol_embedding_size=pol_embedding_size,
                policy_d_model=policy_d_model,
                val_embedding_size=val_embedding_size,
                use_smolgen=use_smolgen,
                smol_hidden_channels=smol_hidden_channels,
                smol_hidden_sz=smol_hidden_sz,
                smol_gen_sz=smol_gen_sz,
                smol_activation=smol_activation,
            )
        super().__init__(config)

        # From here on the config is the single source of truth.
        embedding_size = config.embedding_size
        encoder_layers = config.encoder_layers
        encoder_dff = config.encoder_dff
        self.embedding_dense_sz = config.embedding_dense_sz

        self.deepnorm_alpha = (2. * encoder_layers) ** -0.25

        self.embedding_preprocess = nn.Linear(64*12, 64*self.embedding_dense_sz)
        self.embedding = nn.Linear(112 + self.embedding_dense_sz, embedding_size)
        # NOTE: this initialisation is overwritten by self.apply(self._init_weights)
        # at the end of the constructor; it is kept so that the random-number
        # stream (and therefore seeded initialisation) is unchanged.
        nn.init.xavier_uniform_(self.embedding.weight)
        nn.init.zeros_(self.embedding.bias)

        self.embedding_ln = nn.LayerNorm(embedding_size, eps=0.001)

        self.gating_mult = Gating((64, embedding_size), additive=False)
        self.gating_add = Gating((64, embedding_size), additive=True)

        self.embedding_ffn = FFN(embedding_size, encoder_dff)
        self.embedding_ffn_ln = nn.LayerNorm(embedding_size, eps=0.001)

        self.encoder_layers_list = nn.ModuleList([
            EncoderLayer(embedding_size, config.encoder_d_model, config.encoder_heads, encoder_dff, config.dropout_rate, encoder_layers,
                         use_smolgen=config.use_smolgen, smol_hidden_channels=config.smol_hidden_channels,
                         smol_hidden_sz=config.smol_hidden_sz, smol_gen_sz=config.smol_gen_sz,
                         smol_activation=config.smol_activation)
            for _ in range(encoder_layers)
        ])

        self.policy_embedding = nn.Linear(embedding_size, config.pol_embedding_size)
        self.policy_head = PolicyHead(config.pol_embedding_size, config.policy_d_model)
        self.value_head_winner = ValueHead(embedding_size, config.val_embedding_size)
        self.value_head_q = ValueHead(embedding_size, config.val_embedding_size)
        self.activation = Mish() if default_activation is None else default_activation

        self.apply(self._init_weights)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load model from pretrained checkpoint (required by transformers).

        Args:
            pretrained_model_name_or_path: Hub repo id or local directory. A
                value containing "/" that is not an existing directory is
                treated as a Hub repo id.
            *model_args: Ignored.
            **kwargs: Only ``cache_dir`` and ``token`` are used; they are
                forwarded to both the config and the weight download.

        Returns:
            A ``BT4`` instance with weights loaded non-strictly. Counts of
            missing/unexpected keys are printed as warnings.

        Raises:
            FileNotFoundError: If a local directory contains neither
                ``model.safetensors`` nor ``model.pt``.
        """
        from transformers import AutoConfig
        from huggingface_hub import hf_hub_download
        from safetensors.torch import load_file
        import os

        # Forward Hub credentials / cache location to the config fetch as well,
        # otherwise private repos fail before the weights are even requested.
        hub_kwargs = {
            key: kwargs[key] for key in ("cache_dir", "token") if kwargs.get(key) is not None
        }
        config = AutoConfig.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=True, **hub_kwargs
        )
        model = cls(config=config)

        is_hf_hub = "/" in pretrained_model_name_or_path and not os.path.isdir(pretrained_model_name_or_path)

        if is_hf_hub:
            safetensors_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="model.safetensors",
                cache_dir=kwargs.get("cache_dir", None),
                token=kwargs.get("token", None),
            )
            state_dict = load_file(safetensors_path)
        else:
            safetensors_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
            if os.path.exists(safetensors_path):
                state_dict = load_file(safetensors_path)
            else:
                pt_path = os.path.join(pretrained_model_name_or_path, "model.pt")
                if not os.path.exists(pt_path):
                    raise FileNotFoundError(
                        f"No model.safetensors or model.pt found in {pretrained_model_name_or_path!r}"
                    )
                # SECURITY: torch.load unpickles the file. Only load model.pt
                # checkpoints from trusted sources (prefer model.safetensors).
                checkpoint = torch.load(pt_path, map_location="cpu")
                state_dict = cls._extract_state_dict(checkpoint)

        cls._load_weights(model, state_dict)
        return model

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Unwrap a checkpoint that stores its weights under "state_dict" or "model"."""
        if isinstance(checkpoint, dict):
            for wrapper_key in ("state_dict", "model"):
                if wrapper_key in checkpoint:
                    return checkpoint[wrapper_key]
        return checkpoint

    @staticmethod
    def _load_weights(model, state_dict):
        """Load ``state_dict`` non-strictly and print how many keys did not match."""
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        if missing_keys:
            print(f"Warning: Missing keys when loading weights: {len(missing_keys)} keys")
        if unexpected_keys:
            print(f"Warning: Unexpected keys when loading weights: {len(unexpected_keys)} keys")

    @classmethod
    def register_for_auto_class(cls, auto_class):
        """Register this class for auto class loading (required by transformers).

        Intentionally a no-op: the base-class behaviour is overridden and
        ``auto_class`` is ignored.
        """

    def _init_weights(self, module):
        """Xavier-normal initialise every ``nn.Linear`` weight and zero its bias."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Run the network on a batch of encoded positions.

        Args:
            x: 4-D tensor that is permuted to channels-last and reshaped to
                ``(-1, 64, 112)`` — presumably ``(batch, 112, 8, 8)``; confirm
                against the encoders.

        Returns:
            Tuple ``(policy_logits, value_winner, value_q)``. Both value
            tensors hold 3 unnormalised outputs per sample.
        """
        flow = x.permute(0, 2, 3, 1).reshape(-1, 64, 112)

        # The first 12 planes of every token are flattened over the whole
        # board, passed through a dense layer and appended to each token.
        pos_info_flat = flow[..., :12].reshape(-1, 64 * 12)
        pos_info = self.embedding_preprocess(pos_info_flat).reshape(-1, 64, self.embedding_dense_sz)
        flow = torch.cat([flow, pos_info], dim=-1)

        flow = self.embedding(flow)
        flow = self.activation(flow)
        flow = self.embedding_ln(flow)

        flow = self.gating_mult(flow)
        flow = self.gating_add(flow)

        ffn_output = self.embedding_ffn(flow)
        flow = self.embedding_ffn_ln(flow + ffn_output * self.deepnorm_alpha)

        for layer in self.encoder_layers_list:
            flow = layer(flow)

        policy_tokens = self.activation(self.policy_embedding(flow))
        policy_logits = self.policy_head(policy_tokens)

        value_winner = self.value_head_winner(flow)
        value_q = self.value_head_q(flow)

        return policy_logits, value_winner, value_q

    # ------------------------------------------------------------------
    # Private helpers shared by the inference convenience methods.
    # ------------------------------------------------------------------

    def _resolve_device(self, device):
        """Return ``device`` as a ``torch.device``; default to the first parameter's device."""
        if device is None:
            return next(self.parameters()).device
        return torch.device(device)

    def _infer(self, inputs, use_autocast=False):
        """Switch to eval mode and run ``forward`` under inference mode (optionally fp16 autocast)."""
        self.eval()
        with torch.inference_mode():
            if use_autocast:
                with torch.autocast(device_type='cuda', dtype=torch.float16):
                    return self.forward(inputs)
            return self.forward(inputs)

    @staticmethod
    def _fen_side_and_castling(fen):
        """Return ``(is_black_to_move, castling_field)`` parsed from a FEN string."""
        fields = fen.split()
        return fields[1] == 'b', (fields[2] if len(fields) > 2 else "")

    @staticmethod
    def _replay_side_and_castling(move_history, starting_fen=None):
        """Replay UCI moves on a board and return ``(is_black_to_move, castling_field)``."""
        board = bulletchess.Board.from_fen(starting_fen) if starting_fen is not None else bulletchess.Board()
        for mv in move_history:
            board.apply(bulletchess.Move.from_uci(mv))
        fen_parts = board.fen().split()
        castling_rights = fen_parts[2] if len(fen_parts) > 2 else ""
        return board.turn == bulletchess.BLACK, castling_rights

    def _encode_position(self, fen_or_moves):
        """Encode one position; return ``(tensor, legal_mask, is_black_to_move, castling)``."""
        if isinstance(fen_or_moves, str):
            is_black_to_move, castling_rights = self._fen_side_and_castling(fen_or_moves)
            input_tensor, legal_moves_mask = encode_fen_to_tensor(fen_or_moves)
        elif isinstance(fen_or_moves, list):
            input_tensor, legal_moves_mask = encode_moves_to_tensor(fen_or_moves)
            is_black_to_move, castling_rights = self._replay_side_and_castling(fen_or_moves)
        else:
            raise ValueError("Input must be a FEN string or a list of UCI moves")
        return input_tensor, legal_moves_mask, is_black_to_move, castling_rights

    @staticmethod
    def _masked_logits(policy_row, legal_moves_mask):
        """Add the (numpy) legal-move mask to one row of policy logits."""
        mask = torch.from_numpy(legal_moves_mask).to(policy_row.device, dtype=policy_row.dtype)
        return policy_row + mask

    @staticmethod
    def _choose_move_index(logits, T):
        """Pick a policy index: argmax when ``T == 0``, otherwise sample at temperature ``T``."""
        if T < 0:
            raise ValueError("Temperature T must be non-negative")
        if T == 0.0:
            return torch.argmax(logits).item()
        probs = F.softmax(logits / T, dim=0)
        return torch.multinomial(probs, 1).item()

    @classmethod
    def _to_board_uci(cls, uci_move, is_black_to_move, castling_rights):
        """Convert a policy-vocabulary move into a standard UCI move.

        For black to move, both rank digits are mirrored (rank r -> 9 - r).
        Afterwards a king-onto-rook castling move is rewritten to the usual
        two-square king move when the matching castling right is present.
        """
        if is_black_to_move and len(uci_move) >= 4:
            uci_move = (
                uci_move[0] + str(9 - int(uci_move[1]))
                + uci_move[2] + str(9 - int(uci_move[3]))
                + uci_move[4:]
            )
        for rook_target, right, king_target in cls._CASTLING_REWRITES:
            if uci_move == rook_target and right in castling_rights:
                return king_target
        return uci_move

    def _decode_moves(self, policy_logits, legal_moves_masks, is_black_to_move_list, castling_rights_list, T):
        """Turn a batch of policy logits into one UCI move per position."""
        moves = []
        for i, legal_mask in enumerate(legal_moves_masks):
            logits = self._masked_logits(policy_logits[i], legal_mask)
            uci_move = policy_index[self._choose_move_index(logits, T)]
            moves.append(self._to_board_uci(uci_move, is_black_to_move_list[i], castling_rights_list[i]))
        return moves

    def _batched_policy_logits(self, input_tensors, device, use_fp16):
        """Stack per-position tensors, move them to ``device`` and return the policy logits."""
        batch_tensor = torch.stack(input_tensors).to(device, non_blocking=True)
        half_precision = use_fp16 and device.type == 'cuda'
        if half_precision:
            batch_tensor = batch_tensor.half()
        policy_logits, _, _ = self._infer(batch_tensor, use_autocast=half_precision)
        return policy_logits

    # ------------------------------------------------------------------
    # Public inference API.
    # ------------------------------------------------------------------

    def get_move_from_history(self, fen_or_moves: Union[str, List[str]], T: float, device: Optional[str] = None, **kwargs) -> Union[str, dict]:
        """
        Predict a move from a move history or FEN position.

        Args:
            fen_or_moves: Either a FEN string representing the chess position, or a list of UCI moves
            T: Temperature for sampling (0.0 = deterministic/argmax, >0.0 = stochastic)
            device: Device the input is moved to (if None, uses model's device).
                The model itself is not moved.
            return_probs: Keyword-only option. If True, returns a dictionary of
                move probabilities instead of a single move. In that mode a
                non-positive T leaves the logits unscaled.

        Returns:
            UCI move string (e.g., 'e2e4'), or, if return_probs=True, a
            dictionary mapping moves to probabilities above 1e-6. NOTE: the
            dictionary keys are the raw policy-vocabulary entries; they are
            NOT rank-mirrored for black and NOT castling-normalised.

        Raises:
            ValueError: If the input is neither str nor list, or if T is
                negative when sampling a move.
        """
        device = self._resolve_device(device)
        input_tensor_112, legal_moves_mask, is_black_to_move, castling_rights = self._encode_position(fen_or_moves)

        policy_logits, _, _ = self._infer(input_tensor_112.to(device, non_blocking=True))
        logits0 = self._masked_logits(policy_logits[0], legal_moves_mask)

        if kwargs.get('return_probs', False):
            scaled_logits = logits0 / T if T > 0 else logits0
            # One transfer to host instead of one .item() per vocabulary entry.
            probs = F.softmax(scaled_logits, dim=0).tolist()
            return {move: prob for move, prob in zip(policy_index, probs) if prob > 1e-6}

        uci_move = policy_index[self._choose_move_index(logits0, T)]
        return self._to_board_uci(uci_move, is_black_to_move, castling_rights)

    def get_best_move_value(self, fen_or_moves: Union[str, List[str]], T: float = 0.0, device: Optional[str] = None) -> tuple:
        """
        Get a move together with the position's value estimate from one forward pass.

        Args:
            fen_or_moves: Either a FEN string representing the chess position, or a list of UCI moves
            T: Temperature for sampling (0.0 = deterministic/argmax, >0.0 = stochastic)
            device: Device the input is moved to (if None, uses model's device)

        Returns:
            Tuple of (move, value) where move is a UCI string and value is a
            length-3 numpy array: the softmax of the value_q head output.

        Raises:
            ValueError: If the input is neither str nor list, or if T is negative.
        """
        device = self._resolve_device(device)
        input_tensor_112, legal_moves_mask, is_black_to_move, castling_rights = self._encode_position(fen_or_moves)

        policy_logits, _, value_q = self._infer(input_tensor_112.to(device, non_blocking=True))
        logits0 = self._masked_logits(policy_logits[0], legal_moves_mask)

        uci_move = policy_index[self._choose_move_index(logits0, T)]
        uci_move = self._to_board_uci(uci_move, is_black_to_move, castling_rights)

        value = F.softmax(value_q[0], dim=0).cpu().numpy()
        return uci_move, value

    def get_position_value(self, fen_or_moves: Union[str, List[str]], device: Optional[str] = None) -> np.ndarray:
        """
        Get position evaluation using value_q.

        Args:
            fen_or_moves: Either a FEN string representing the chess position, or a list of UCI moves
            device: Device the input is moved to (if None, uses model's device)

        Returns:
            Length-3 probability array (softmax over the value_q outputs).
            NOTE(review): previously documented as [black_win, draw, white_win];
            the ordering/perspective cannot be verified from this file —
            confirm against the training code.

        Raises:
            ValueError: If the input is neither str nor list.
        """
        device = self._resolve_device(device)

        if isinstance(fen_or_moves, str):
            input_tensor_112, _ = encode_fen_to_tensor(fen_or_moves)
        elif isinstance(fen_or_moves, list):
            input_tensor_112, _ = encode_moves_to_tensor(fen_or_moves)
        else:
            raise ValueError("Input must be a FEN string or a list of UCI moves")

        _, _, value_q = self._infer(input_tensor_112.to(device, non_blocking=True))
        return F.softmax(value_q[0], dim=0).cpu().numpy()

    def batch_get_moves_from_fens(self, fens: List[str], T: float, device: Optional[str] = None, use_fp16: bool = False) -> List[str]:
        """
        Get moves for multiple FEN positions using batched inference.

        Args:
            fens: List of FEN strings representing chess positions
            T: Temperature for sampling (0.0 = deterministic/argmax, >0.0 = stochastic)
            device: Device the batch is moved to (if None, uses model's device)
            use_fp16: If True and the device is CUDA, the batch is cast to half
                precision and the forward pass runs under fp16 autocast.

        Returns:
            List of UCI move strings, one per FEN (empty list for empty input)

        Raises:
            ValueError: If T is negative.
        """
        if not fens:
            return []

        device = self._resolve_device(device)

        input_tensors = []
        legal_moves_masks = []
        is_black_to_move_list = []
        castling_rights_list = []
        for fen in fens:
            input_tensor, legal_mask = encode_fen_to_tensor(fen)
            # Encoders return a leading batch axis of size 1; drop it before stacking.
            input_tensors.append(input_tensor.squeeze(0))
            legal_moves_masks.append(legal_mask)
            is_black, castling = self._fen_side_and_castling(fen)
            is_black_to_move_list.append(is_black)
            castling_rights_list.append(castling)

        policy_logits = self._batched_policy_logits(input_tensors, device, use_fp16)
        return self._decode_moves(policy_logits, legal_moves_masks, is_black_to_move_list, castling_rights_list, T)

    def batch_get_moves_from_move_lists(self, move_lists: List[List[str]], T: float, device: Optional[str] = None, use_fp16: bool = False, fens: Optional[List[str]] = None):
        """
        Get moves for multiple move histories using batched inference.

        Args:
            move_lists: List of move sequences, where each sequence is a list of UCI moves
            T: Temperature for sampling (0.0 = deterministic/argmax, >0.0 = stochastic)
            device: Device the batch is moved to (if None, uses model's device)
            use_fp16: If True and the device is CUDA, the batch is cast to half
                precision and the forward pass runs under fp16 autocast.
            fens: Optional list of FEN strings that represent the board state prior to
                  applying the corresponding move list. When provided, each move history
                  is applied starting from the supplied FEN instead of the standard initial position.

        Returns:
            List of UCI move strings, one per move list (empty list for empty input)

        Raises:
            ValueError: If ``fens`` is given with a different length than
                ``move_lists``, or if T is negative.
        """
        if not move_lists:
            return []

        device = self._resolve_device(device)

        if fens is not None and len(fens) != len(move_lists):
            raise ValueError("Length of fens must match length of move_lists when provided.")

        input_tensors = []
        legal_moves_masks = []
        is_black_to_move_list = []
        castling_rights_list = []
        for idx, move_history in enumerate(move_lists):
            starting_fen = fens[idx] if fens is not None else None
            input_tensor, legal_mask = encode_moves_to_tensor(move_history, starting_fen=starting_fen)
            # Encoders return a leading batch axis of size 1; drop it before stacking.
            input_tensors.append(input_tensor.squeeze(0))
            legal_moves_masks.append(legal_mask)
            is_black, castling = self._replay_side_and_castling(move_history, starting_fen)
            is_black_to_move_list.append(is_black)
            castling_rights_list.append(castling)

        policy_logits = self._batched_policy_logits(input_tensors, device, use_fp16)
        return self._decode_moves(policy_logits, legal_moves_masks, is_black_to_move_list, castling_rights_list, T)
|
|