| from collections import deque
|
| import random
|
| import torch
|
| import torch
|
| from engine import GameState
|
| from move_finder import find_best_move_shallow
|
| from infer_nnue import gs_to_nnue_features
|
| from nnue_model import NNUE
|
| from tqdm import tqdm
|
| from infer_nnue import NNUEInfer
|
# Fixed input width expected by the NNUE feature tensor.
NNUE_FEATURES = 32

def pad_features(feats):
    """Return *feats* resized to exactly NNUE_FEATURES entries.

    Shorter lists are right-padded with zeros; longer lists are truncated.
    """
    shortfall = NNUE_FEATURES - len(feats)
    if shortfall > 0:
        return feats + [0] * shortfall
    return feats[:NNUE_FEATURES]
|
|
|
| import pickle
|
|
|
def load_pgn_dataset(path):
    """Read a stream of pickled chunks from *path* and group the contained
    position dicts into trajectories.

    The file is read as back-to-back ``pickle`` payloads until EOF; each
    payload is assumed to be an iterable of position dicts carrying at
    least an ``"stm"`` (side-to-move) key.  Returns a list of trajectories
    (lists of position dicts).
    """
    trajectories = []
    current_traj = []

    with open(path, "rb") as f:
        while True:
            try:
                chunk = pickle.load(f)
                for item in chunk:
                    current_traj.append(item)

                    # NOTE(review): this closes a trajectory whenever the
                    # side-to-move of the last two positions DIFFERS.  In a
                    # normal game stm alternates every ply, so `!=` would
                    # slice the stream into length-2 trajectories; a game
                    # boundary is usually signalled by stm staying the SAME
                    # twice in a row.  Confirm the intended comparison
                    # (`!=` vs `==`) against the dataset format — and note
                    # that when a split fires, the position that triggered
                    # it is kept in the *previous* trajectory.
                    if len(current_traj) > 1 and \
                       current_traj[-1]["stm"] != current_traj[-2]["stm"]:
                        trajectories.append(current_traj)
                        current_traj = []

            except EOFError:
                # End of the concatenated pickle stream.
                break

    # Flush whatever remains as the final (possibly partial) trajectory.
    if current_traj:
        trajectories.append(current_traj)

    return trajectories
|
|
|
|
|
@torch.no_grad()
def td_targets_from_traj(model, traj, gamma=0.99):
    """Compute one-step TD(0) targets for every position of a trajectory.

    Each position's target is the discounted, sign-flipped evaluation of
    the *next* position (the opponent is to move there).  Targets are
    clamped to [-1, 1] and returned as a plain Python list of floats.

    Fixes vs. the previous version: the duplicated ``@torch.no_grad()``
    decorator is removed, the device is taken from the model instead of
    being hard-coded to ``"cuda"``, and the redundant ``.detach()`` (a
    no-op under ``no_grad``) is dropped.
    """
    # A single position has no successor to bootstrap from.
    if len(traj) == 1:
        return [0.0]

    feats = [pad_features(x["features"]) for x in traj]
    stm = [x["stm"] for x in traj]

    # Evaluate the whole trajectory in one batch on the model's device.
    device = next(model.parameters()).device
    feats = torch.tensor(feats, dtype=torch.long, device=device)
    stm = torch.tensor(stm, dtype=torch.long, device=device)

    values = model(feats, stm).view(-1)

    targets = torch.empty_like(values)

    # TD(0) with side-to-move flip: V(s_i) <- gamma * (-V(s_{i+1})).
    targets[:-1] = gamma * (-values[1:])
    # NOTE(review): the final position bootstraps from its own evaluation,
    # which produces zero learning signal there; if trajectories end in
    # terminal positions the game result should probably be used instead —
    # confirm against the dataset format.
    targets[-1] = values[-1]

    targets = torch.clamp(targets, -1.0, 1.0)

    return targets.cpu().tolist()
|
|
|
|
|
|
|
| from collections import deque
|
| import random
|
|
|
class ReplayBuffer:
    """Bounded FIFO store of (features, side-to-move, target) samples."""

    def __init__(self, capacity=300_000):
        # A maxlen-bounded deque evicts the oldest entries automatically
        # once the buffer is full.
        self.buf = deque(maxlen=capacity)

    def add(self, f, stm, t):
        """Store one training sample as a (features, stm, target) tuple."""
        self.buf.append((f, stm, t))

    def sample(self, n):
        """Return n distinct samples drawn uniformly at random."""
        return random.sample(self.buf, n)

    def __len__(self):
        """Number of samples currently held."""
        return len(self.buf)
|
|
|
|
|
def train_from_replay(model, optimizer, replay, batch_size):
    """Run one optimizer step on a random batch from the replay buffer.

    Silently returns without training until the buffer holds at least
    ``batch_size`` samples.  Returns ``None``.

    Fix vs. the previous version: tensors are placed on the model's
    device instead of a hard-coded ``"cuda"``, so the same code runs on
    CPU-only hosts (behavior is unchanged when the model is on CUDA).
    """
    if len(replay) < batch_size:
        return

    batch = replay.sample(batch_size)
    feats, stm, targets = zip(*batch)

    device = next(model.parameters()).device
    feats = torch.tensor(feats, dtype=torch.long, device=device)
    stm = torch.tensor(stm, dtype=torch.long, device=device)
    targ = torch.tensor(targets, dtype=torch.float, device=device)

    preds = model(feats, stm).view(-1)

    # Huber loss: robust to the occasional noisy TD target.
    loss = torch.nn.functional.smooth_l1_loss(preds, targ)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
|
|
|
|
|
from tqdm import tqdm  # NOTE(review): duplicate of the top-of-file import

# --- TD self-training driver -------------------------------------------------
# Fine-tunes a pretrained NNUE with TD(0) targets recomputed each epoch and
# a replay buffer between target generation and optimizer steps.
# NOTE(review): device is hard-coded; this will raise on CPU-only hosts —
# confirm whether a CPU fallback is wanted.
device = "cuda"
model = NNUE().to(device)
# Start from the pretrained supervised checkpoint.
model.load_state_dict(torch.load("nnue_model.pt", weights_only=True))
# Small learning rate: this is fine-tuning, not training from scratch.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

replay = ReplayBuffer()
trajectories = load_pgn_dataset("nnue_dataset.pkl")

for epoch in range(3):
    print(f"Epoch {epoch}")

    for traj in tqdm(trajectories):
        # Single-position trajectories carry no TD signal; skip them.
        if len(traj) < 2:
            continue

        # Targets are recomputed with the current network each epoch, so
        # they track the improving model.
        targets = td_targets_from_traj(model, traj)

        for x, t in zip(traj, targets):
            replay.add(
                pad_features(x["features"]),
                x["stm"],
                t
            )

        # Several gradient steps per trajectory, sampled from the replay
        # buffer (no-ops until the buffer reaches batch_size).
        for _ in range(3):
            train_from_replay(model, optimizer, replay, batch_size=512)

torch.save(model.state_dict(), "nnue_model_td.pt")
|
|
|