| import mlx as mx |
| import mlx.nn as mx_nn |
| import torch |
| import torch.nn as nn |
| import numpy as np |
|
|
# Compute device for the model. Prefer Apple's Metal (MPS) backend when it is
# available; otherwise fall back to the CPU so the script still runs on
# machines without MPS (moving a module to an unavailable MPS device raises).
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')


# Checkpoint locations: the primary PyTorch checkpoint and a backup that the
# loading code further down tries when the primary file is missing.
# NOTE(review): "t-1" presumably denotes the previous checkpoint; the name may
# also be a typo of "chessy_model_t-1.pth" -- confirm before renaming.
CONFIG = {
    "model_path": "chessy_model.pth",
    "backup_model_path": "chessy_modelt-1.pth",
}
|
|
class NN1(nn.Module):
    """Embedding + single self-attention layer followed by a deep MLP head.

    Input is a batch of integer token ids in ``[0, 13)``. Each token is
    embedded into 64 dimensions and mixed by one 16-head self-attention layer.
    The result is flattened and passed through a stack of fully connected
    ReLU layers that ends in 4 raw (un-activated) outputs.

    The first linear layer expects 4096 = 64 * 64 features, so the sequence
    length must be exactly 64 (presumably one token per board square -- TODO
    confirm against the data pipeline).

    NOTE: ``self.neurons`` is an ``nn.Sequential``. Its submodule indices
    (``neurons.0`` ... ``neurons.28``) appear in saved state-dict keys, so the
    order in which layers are appended must not change.
    """

    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(13, 64)
        self.attention = nn.MultiheadAttention(embed_dim=64, num_heads=16)
        self.neu = 512
        width = self.neu

        # Input projection: flattened (64 tokens x 64 dims) -> hidden width.
        stack = [nn.Linear(4096, width), nn.ReLU()]
        # Twelve identical hidden layers. Building them in a loop creates the
        # modules in the same order as spelling them out, so parameter
        # initialisation and state-dict indices are unchanged.
        for _ in range(12):
            stack.extend((nn.Linear(width, width), nn.ReLU()))
        # Bottleneck down to 64 features, then a 4-way output with no
        # activation.
        stack.extend((nn.Linear(width, 64), nn.ReLU(), nn.Linear(64, 4)))
        self.neurons = nn.Sequential(*stack)

    def forward(self, x):
        """Map token ids ``x`` (assumed ``(batch, 64)``) to ``(batch, 4)`` outputs."""
        embedded = self.embedding(x)  # (batch, seq, 64)
        # nn.MultiheadAttention defaults to batch_first=False, i.e. it expects
        # (seq, batch, embed). Swap the first two axes going in and coming out.
        seq_major = embedded.permute(1, 0, 2)
        attended, _attn_weights = self.attention(seq_major, seq_major, seq_major)
        batch_major = attended.permute(1, 0, 2)
        # Flatten each sample's tokens into one feature vector: (batch, seq * 64).
        flat = batch_major.reshape(batch_major.size(0), -1)
        return self.neurons(flat)
|
|
model = NN1().to(device)


def _load_checkpoint(path):
    """Try to load the state dict stored at ``path`` into the global ``model``.

    Returns True on success and False if ``path`` does not exist. Any other
    error from loading (e.g. a corrupt file, or ``load_state_dict`` rejecting
    mismatched keys) propagates to the caller.

    NOTE(review): ``torch.load`` unpickles the file, so only point CONFIG at
    trusted checkpoints (or pass ``weights_only=True`` on a torch version that
    supports it).
    """
    try:
        state = torch.load(path, map_location=device)
    except FileNotFoundError:
        return False
    model.load_state_dict(state)
    return True


# Try the primary checkpoint first, then the backup. If neither exists, keep
# the freshly initialised weights.
if _load_checkpoint(CONFIG['model_path']):
    print(f"Loaded model from {CONFIG['model_path']}")
elif _load_checkpoint(CONFIG["backup_model_path"]):
    print(f"Loaded backup model from {CONFIG['backup_model_path']}")
else:
    print("No model file found, starting from scratch.")

# Dump every state-dict entry as a NumPy array keyed by its PyTorch name
# (e.g. "neurons.0.weight"), presumably for loading into an MLX port of NN1.
# NOTE(review): when no checkpoint was found above, these are the random
# initial weights, and the export still overwrites chessy_model_mlx.npz --
# confirm that is intended.
weights = {name: tensor.detach().cpu().numpy() for name, tensor in model.state_dict().items()}
np.savez("chessy_model_mlx.npz", **weights)