import os

# torchrl's ChessEnv loads its SAN move vocabulary from a san_moves.txt file inside the
# package, so copy the local move list into the installed package directory first.
os.system("mv NeoChess/san_moves.txt /usr/local/python/3.12.1/lib/python3.12/site-packages/torchrl/envs/custom/")
import torchrl
import torch
import chess
import chess.engine
import gymnasium
import numpy as np
import tensordict
from collections import defaultdict
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
import torch.nn.functional as F
from torch.distributions import Categorical
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (
    Compose,
    DoubleToFloat,
    ObservationNorm,
    StepCounter,
    TransformedEnv,
)
from torchrl.envs.custom.chess import ChessEnv
from torchrl.envs.libs.gym import GymEnv, GymWrapper, set_gym_backend
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator, MaskedCategorical, ActorCriticWrapper
from torchrl.objectives import ClipPPOLoss, ValueEstimators
from torchrl.objectives.value import GAE
from tqdm import tqdm


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def board_to_tensor(board):
    """Encode a python-chess board as a (1, 64) tensor of piece indices (0 = empty)."""
    piece_encoding = {
        'P': 1, 'N': 2, 'B': 3, 'R': 4, 'Q': 5, 'K': 6,
        'p': 7, 'n': 8, 'b': 9, 'r': 10, 'q': 11, 'k': 12
    }

    tensor = torch.zeros(64, dtype=torch.long)
    for square in chess.SQUARES:
        piece = board.piece_at(square)
        if piece:
            tensor[square] = piece_encoding[piece.symbol()]

    # Add a leading batch dimension: shape (1, 64).
    return tensor.unsqueeze(0)
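

# A quick sanity check of the encoding (illustrative only, not used below): for the
# initial position the first rank a1..h1 is R, N, B, Q, K, B, N, R, i.e.
# board_to_tensor(chess.Board())[0, :8] == tensor([4, 2, 3, 5, 6, 3, 2, 4]).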


class Policy(nn.Module):
    """Policy head: maps a FEN string to 29275 move logits (ChessEnv's SAN action space)."""

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(13, 32)
        self.attention = nn.MultiheadAttention(embed_dim=32, num_heads=16)
        self.neu = 256
        self.neurons = nn.Sequential(
            nn.Linear(64 * 32, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, 128),
            nn.ReLU(),
            nn.Linear(128, 29275),
        )

    def forward(self, x):
        board = chess.Board(x)
        color = board.turn
        # Move the encoded board onto the same device as the model parameters.
        x = board_to_tensor(board).to(self.embedding.weight.device)
        x = self.embedding(x)
        x = x.permute(1, 0, 2)  # (seq, batch, embed) layout for nn.MultiheadAttention
        attn_output, _ = self.attention(x, x, x)
        x = attn_output.permute(1, 0, 2).contiguous()
        x = x.view(x.size(0), -1)
        x = self.neurons(x)
        # board.turn is True for White; the logits' sign is flipped when Black is to move.
        return x if color == chess.WHITE else -x


class Value(nn.Module):
    """Value head: maps a FEN string to a scalar evaluation of the position."""

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(13, 64)
        self.attention = nn.MultiheadAttention(embed_dim=64, num_heads=16)
        self.neu = 512
        self.neurons = nn.Sequential(
            nn.Linear(64 * 64, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, self.neu),
            nn.ReLU(),
            nn.Linear(self.neu, 64),
            nn.ReLU(),
            nn.Linear(64, 4)
        )

    def forward(self, x):
        board = chess.Board(x)
        color = board.turn
        # Move the encoded board onto the same device as the model parameters.
        x = board_to_tensor(board).to(self.embedding.weight.device)
        x = self.embedding(x)
        x = x.permute(1, 0, 2)  # (seq, batch, embed) layout for nn.MultiheadAttention
        attn_output, _ = self.attention(x, x, x)
        x = attn_output.permute(1, 0, 2).contiguous()
        x = x.view(x.size(0), -1)
        x = self.neurons(x)
        # Keep only the first of the four outputs, scaled down; the sign is flipped
        # when White is to move.
        x = x[0][0] / 10
        if color == chess.WHITE:
            x = -x
        return x


with set_gym_backend("gymnasium"):
    env = ChessEnv(
        stateful=True,
        include_fen=True,
        include_san=False,
    )


# Instantiate the networks and load the pretrained checkpoints.
policy = Policy().to(device)
value = Value().to(device)
valweight = torch.load("NeoChess-Community/chessy_modelt-1.pth", map_location=device, weights_only=False)
value.load_state_dict(valweight)
polweight = torch.load("NeoChess-Community/chessy_policy.pth", map_location=device, weights_only=False)
policy.load_state_dict(polweight)


def sample_masked_action(logits, mask):
    """Sample an action index from `logits`, restricted to entries where `mask` is True."""
    masked_logits = logits.clone()
    masked_logits[~mask] = float('-inf')
    probs = F.softmax(masked_logits, dim=-1)
    dist = Categorical(probs=probs)
    action = dist.sample()
    log_prob = dist.log_prob(action)
    return action, log_prob
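

# Illustrative use of the helper above (the training loop itself relies on the
# ProbabilisticActor below for masked sampling): only the first two entries can be drawn.
# sample_masked_action(torch.zeros(3), torch.tensor([True, True, False]))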


class FENPolicyWrapper(nn.Module):
    """Adapt the Policy net so it accepts FEN strings (or batches of them) from a TensorDict."""

    def __init__(self, policy_net):
        super().__init__()
        self.policy_net = policy_net

    def forward(self, fens, action_mask=None) -> torch.Tensor:
        if isinstance(fens, (TensorDict, dict)):
            fens = fens["fen"]

        if isinstance(fens, str):
            fens = [fens]

        # Unwrap any nesting introduced by non-tensor batching.
        while isinstance(fens[0], list):
            fens = fens[0]

        if action_mask is not None:
            if isinstance(action_mask, torch.Tensor):
                action_mask = action_mask.unsqueeze(0) if action_mask.ndim == 1 else action_mask
            if not isinstance(action_mask, list):
                action_mask = [action_mask[i] for i in range(len(fens))]

        logits_list = []
        for i, fen in enumerate(fens):
            logits = self.policy_net(fen)

            # Mask out illegal moves before the distribution is built.
            if action_mask is not None:
                mask = action_mask[i].bool()
                logits = logits.masked_fill(~mask, float("-inf"))

            logits_list.append(logits)

        # Singleton dims are squeezed so a single FEN yields a 1-D vector of logits.
        return torch.stack(logits_list).squeeze(-2).squeeze(-2)


class FENValueWrapper(nn.Module):
    """Adapt the Value net so it accepts FEN strings (or batches of them) from a TensorDict."""

    def __init__(self, value_net):
        super().__init__()
        self.value_net = value_net

    def forward(self, fens) -> torch.Tensor:
        if isinstance(fens, (TensorDict, dict)):
            fens = fens["fen"]
        if isinstance(fens, str):
            fens = [fens]
        while isinstance(fens[0], list):
            fens = fens[0]
        state_value = [self.value_net(fen) for fen in fens]
        state_value = torch.stack(state_value)

        if state_value.ndim == 0:
            state_value = state_value.unsqueeze(0)
        return state_value
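

# Quick illustration of the wrappers: both accept a bare FEN string, e.g.
# FENValueWrapper(value)(chess.STARTING_FEN) returns a one-element value tensor and
# FENPolicyWrapper(policy)(chess.STARTING_FEN) returns the 29275 move logits.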


# 64 squares x 73 move-type planes; note that the SAN-based policy head above uses
# 29275 logits instead.
ACTION_DIM = 64 * 73

from functools import partial

policy_module = TensorDictModule(
    FENPolicyWrapper(policy),
    in_keys=["fen"],
    out_keys=["logits"],
)
value_module = TensorDictModule(
    FENValueWrapper(value),
    in_keys=["fen"],
    out_keys=["state_value"],
)


def masked_categorical_factory(logits, action_mask):
    return MaskedCategorical(logits=logits, mask=action_mask)
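

# MaskedCategorical gives zero probability to masked-out entries; for example,
# masked_categorical_factory(torch.zeros(4), torch.tensor([True, False, True, False]))
# only ever samples indices 0 or 2.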


actor = ProbabilisticActor(
    module=policy_module,
    in_keys=["logits", "action_mask"],
    out_keys=["action"],
    distribution_class=masked_categorical_factory,
    return_log_prob=True,
)

# Smoke test: run one reset through each module before training.
obs = env.reset()
print(obs)
print(policy_module(obs))
print(value_module(obs))
print(actor(obs))

rollout = env.rollout(3)
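
# `rollout` holds 3 random-action steps (no policy was passed to env.rollout); keys such
# as rollout["fen"], rollout["action"] and rollout["next", "reward"] can be inspected here.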


from torchrl.record.loggers import generate_exp_name, get_logger


def train_ppo_chess(chess_env, num_iterations=1, frames_per_batch=1000,
                    num_epochs=100, lr=3e-4, gamma=0.99, lmbda=0.95,
                    clip_epsilon=0.2, device="cpu"):
    """
    Main PPO training loop for chess.

    Args:
        chess_env: Your ChessEnv instance
        num_iterations: Number of training iterations
        frames_per_batch: Number of environment steps per batch
        num_epochs: Number of PPO epochs per iteration
        lr: Learning rate
        gamma: Discount factor
        lmbda: GAE lambda parameter
        clip_epsilon: PPO clipping parameter
        device: Training device
    """
    global actor_module, value_module, loss_module

    env = chess_env
    actor_module = actor

    collector = SyncDataCollector(
        env,
        actor_module,
        frames_per_batch=frames_per_batch,
        total_frames=-1,
        device=device,
    )

    replay_buffer = ReplayBuffer(
        storage=LazyTensorStorage(frames_per_batch),
        sampler=SamplerWithoutReplacement(),
        batch_size=256,
    )

    loss_module = ClipPPOLoss(
        actor_network=actor_module,
        critic_network=value_module,
        clip_epsilon=clip_epsilon,
        entropy_bonus=True,
        entropy_coef=0.01,
        critic_coef=1.0,
        normalize_advantage=True,
    )
    # Configure the GAE value estimator with the requested gamma and lmbda.
    loss_module.make_value_estimator(ValueEstimators.GAE, gamma=gamma, lmbda=lmbda)

    optim = torch.optim.Adam(loss_module.parameters(), lr=lr)

    logger = get_logger("tensorboard", logger_name="ppo_chess", experiment_name=generate_exp_name("PPO", "Chess"))

    collected_frames = 0

    for iteration in range(num_iterations):
        print(f"\n=== Iteration {iteration + 1}/{num_iterations} ===")

        # Collect at least `frames_per_batch` environment frames.
        batch_data = []
        for i, batch in enumerate(collector):
            batch_data.append(batch)
            collected_frames += batch.numel()

            if len(batch_data) * collector.frames_per_batch >= frames_per_batch:
                break

        if batch_data:
            full_batch = torch.cat(batch_data, dim=0)

            # Compute advantages and value targets before filling the buffer.
            with torch.no_grad():
                full_batch = loss_module.value_estimator(full_batch)

            replay_buffer.extend(full_batch)

        # PPO update epochs over the freshly collected data.
        total_loss = 0
        total_actor_loss = 0
        total_critic_loss = 0
        total_entropy_loss = 0

        for epoch in range(num_epochs):
            epoch_loss = 0
            epoch_actor_loss = 0
            epoch_critic_loss = 0
            epoch_entropy_loss = 0
            num_batches = 0

            for batch in replay_buffer:
                # Drop trailing singleton dimensions the value head may have produced.
                if "state_value" in batch and batch["state_value"].dim() > 1:
                    batch["state_value"] = batch["state_value"].squeeze(-1)
                if "value_target" in batch and batch["value_target"].dim() > 1:
                    batch["value_target"] = batch["value_target"].squeeze(1)

                loss_dict = loss_module(batch)
                loss = loss_dict["loss_objective"] + loss_dict["loss_critic"] + loss_dict["loss_entropy"]

                optim.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_norm=0.5)
                optim.step()

                epoch_loss += loss.item()
                epoch_actor_loss += loss_dict["loss_objective"].item()
                epoch_critic_loss += loss_dict["loss_critic"].item()
                epoch_entropy_loss += loss_dict["loss_entropy"].item()
                num_batches += 1

            if num_batches > 0:
                total_loss += epoch_loss / num_batches
                total_actor_loss += epoch_actor_loss / num_batches
                total_critic_loss += epoch_critic_loss / num_batches
                total_entropy_loss += epoch_entropy_loss / num_batches

        avg_total_loss = total_loss / num_epochs
        avg_actor_loss = total_actor_loss / num_epochs
        avg_critic_loss = total_critic_loss / num_epochs
        avg_entropy_loss = total_entropy_loss / num_epochs

        metrics = {
            "train/total_loss": avg_total_loss,
            "train/actor_loss": avg_actor_loss,
            "train/critic_loss": avg_critic_loss,
            "train/entropy_loss": avg_entropy_loss,
            "train/collected_frames": collected_frames,
        }

        # Rewards for collected transitions live under the "next" sub-tensordict.
        if batch_data and ("next", "reward") in full_batch.keys(True):
            avg_reward = full_batch["next", "reward"].float().mean().item()
            metrics["train/avg_reward"] = avg_reward
            print(f"Average Reward: {avg_reward:.3f}")

        for key, val in metrics.items():
            logger.log_scalar(key, val, step=iteration)

        print(f"Total Loss: {avg_total_loss:.4f}")
        print(f"Actor Loss: {avg_actor_loss:.4f}")
        print(f"Critic Loss: {avg_critic_loss:.4f}")
        print(f"Entropy Loss: {avg_entropy_loss:.4f}")
        print(f"Collected Frames: {collected_frames}")

        # On-policy: clear the buffer before the next collection round.
        replay_buffer.empty()

    print("\nTraining completed!")


train_ppo_chess(env)
torch.save(value.state_dict(), "NeoChess-Community/chessy_model.pth")
torch.save(policy.state_dict(), "NeoChess-Community/chessy_policy.pth")
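
# Example of a longer run (hypothetical settings, for illustration):
# train_ppo_chess(env, num_iterations=10, frames_per_batch=2000, num_epochs=20,
#                 lr=1e-4, device=device)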