| """ |
| ChessGPT -- LLaMA-style decoder-only transformer for UCI move prediction. |
| |
| Architecture: RMSNorm, RoPE, SwiGLU, QK-Norm, no bias, scaled residual init. |
| HuggingFace-compatible implementation. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.checkpoint import checkpoint as torch_checkpoint |
| from transformers import PreTrainedModel |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
| from .configuration_chessgpt import ChessGPTConfig |
|
|
|
|
| |
| |
| |
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean subtraction, no bias).

    Normalizes over the last dimension and applies a learned per-channel
    gain. The reduction is computed in float32 and cast back to the input
    dtype — matching the reference LLaMA implementation — so the statistics
    stay numerically stable under fp16/bf16 autocast. For float32 input the
    result is identical to computing in-place.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        # Gain initialized to identity; eps keeps rsqrt finite for zero input.
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fix: upcast before the mean/rsqrt — low-precision reductions here
        # measurably degrade training stability.
        x_f = x.float()
        normed = x_f * torch.rsqrt(x_f.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed.type_as(x) * self.weight
|
|
|
|
| def precompute_rope_freqs( |
| head_dim: int, max_seq_len: int, theta: float = 10000.0 |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| """Return (freqs_cos, freqs_sin) as real-valued tensors.""" |
| freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim)) |
| t = torch.arange(max_seq_len) |
| angles = torch.outer(t, freqs) |
| return angles.cos(), angles.sin() |
|
|
|
|
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor,
):
    """Rotate query/key channels by their positional RoPE angles.

    xq/xk are (B, n_heads, T, head_dim) with head_dim read as interleaved
    (even, odd) pairs; freqs_cos/freqs_sin are (max_seq_len, head_dim // 2).
    The rotation runs in float32 and is cast back to the input dtype.
    """
    seq_len = xq.shape[2]
    # Slice to the current length, then broadcast over (batch, heads).
    cos = freqs_cos[:seq_len][None, None, :, :]
    sin = freqs_sin[:seq_len][None, None, :, :]

    def rotate(x: torch.Tensor) -> torch.Tensor:
        # Split the last dim into (.., head_dim // 2, 2) channel pairs.
        pairs = x.float().reshape(*x.shape[:-1], -1, 2)
        even, odd = pairs.unbind(-1)
        # Standard 2-D rotation applied independently to each pair.
        rotated = torch.stack(
            [even * cos - odd * sin, even * sin + odd * cos], dim=-1
        )
        return rotated.flatten(-2).type_as(x)

    return rotate(xq), rotate(xk)
|
|
|
|
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: w2(silu(w1(x)) * w3(x)), all bias-free."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        # Attribute names w1/w3/w2 are state-dict keys; keep them stable.
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Causal self-attention with RoPE and QK-Norm, using PyTorch SDPA."""

    def __init__(self, config: ChessGPTConfig):
        super().__init__()
        assert config.d_model % config.n_heads == 0
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        # Fused projection producing Q, K and V in a single matmul.
        self.qkv = nn.Linear(config.d_model, 3 * config.d_model, bias=False)
        # Per-head RMSNorm on queries/keys stabilizes attention logits.
        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.proj = nn.Linear(config.d_model, config.d_model, bias=False)

    def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
        batch, seq_len, d_model = x.shape
        # (B, T, 3, H, hd) -> three tensors, each moved to (B, H, T, hd).
        projected = self.qkv(x).view(batch, seq_len, 3, self.n_heads, self.head_dim)
        q, k, v = (t.transpose(1, 2) for t in projected.unbind(dim=2))

        # QK-Norm is applied per head, then RoPE rotates the normed q/k.
        q, k = apply_rotary_emb(self.q_norm(q), self.k_norm(k), freqs_cos, freqs_sin)

        # SDPA constructs the causal mask internally; no dropout.
        attn_out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True
        )

        # Merge heads back to (B, T, d_model) and project out.
        merged = attn_out.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
        return self.proj(merged)
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm decoder block: attention and SwiGLU FFN, each with a residual."""

    def __init__(self, config: ChessGPTConfig):
        super().__init__()
        # Attribute names (ln1/attn/ln2/ffn) are state-dict keys; keep stable.
        self.ln1 = RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.attn = CausalSelfAttention(config)
        self.ln2 = RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.ffn = SwiGLU(config.d_model, config.d_ff)

    def forward(self, x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
        # Residual around attention, then residual around the feed-forward.
        attn_out = self.attn(self.ln1(x), freqs_cos, freqs_sin)
        h = x + attn_out
        return h + self.ffn(self.ln2(h))
|
|
|
|
| |
| |
| |
|
|
|
|
class ChessGPTPreTrainedModel(PreTrainedModel):
    """HF base class wiring the config type, checkpointing and weight init."""

    config_class = ChessGPTConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["TransformerBlock"]

    def _init_weights(self, module):
        # Plain normal init for all Linear/Embedding weights; the std comes
        # from the config (scaled residual init per the module docstring).
        std = self.config.weight_init_std
        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)
        elif isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=std)
            # Linears here are built bias-free, but guard just in case.
            if module.bias is not None:
                nn.init.zeros_(module.bias)
|
|
|
|
class ChessGPTModel(ChessGPTPreTrainedModel):
    """The bare ChessGPT transformer outputting raw hidden-states."""

    def __init__(self, config: ChessGPTConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)

        # RoPE tables are precomputed once for the full max_seq_len and
        # sliced per forward pass inside apply_rotary_emb.
        head_dim = config.d_model // config.n_heads
        freqs_cos, freqs_sin = precompute_rope_freqs(
            head_dim, config.max_seq_len, config.rope_theta
        )
        # NOTE(review): persistent=True stores these recomputable tables in
        # every checkpoint; persistent=False would shrink state dicts but
        # change the state-dict keys, breaking existing checkpoints.
        self.register_buffer("freqs_cos", freqs_cos, persistent=True)
        self.register_buffer("freqs_sin", freqs_sin, persistent=True)

        self.blocks = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.n_layers)]
        )
        # Final norm applied after the last block (pre-norm architecture).
        self.ln_f = RMSNorm(config.d_model, eps=config.rms_norm_eps)

        # Toggled by HF's gradient_checkpointing_enable/disable machinery.
        self.gradient_checkpointing = False
        # HF hook: runs _init_weights over submodules and final setup.
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Embed input_ids and run all decoder blocks.

        Returns the final-norm hidden states of shape (B, T, d_model).
        Raises ValueError if T exceeds the configured max_seq_len (the RoPE
        tables are only precomputed up to that length).

        NOTE(review): attention_mask is accepted but never used — attention
        is purely causal (SDPA is_causal=True), so padding tokens are still
        attended to. Confirm callers only pass densely packed sequences.
        """
        B, T = input_ids.shape
        if T > self.config.max_seq_len:
            raise ValueError(
                f"Sequence length {T} > max_seq_len {self.config.max_seq_len}"
            )

        x = self.embed_tokens(input_ids)

        for block in self.blocks:
            if self.gradient_checkpointing and self.training:
                # Trade compute for memory: recompute block activations in
                # the backward pass (non-reentrant checkpointing).
                x = torch_checkpoint(block, x, self.freqs_cos, self.freqs_sin, use_reentrant=False)
            else:
                x = block(x, self.freqs_cos, self.freqs_sin)

        x = self.ln_f(x)
        return x
|
|
|
|
class ChessGPTForCausalLM(ChessGPTPreTrainedModel):
    """ChessGPT with a language-modeling head tied to the token embedding."""

    _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}

    def __init__(self, config: ChessGPTConfig):
        super().__init__(config)
        self.model = ChessGPTModel(config)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # HF hook: runs weight init and applies the declared weight tying.
        self.post_init()

    # --- HF embedding plumbing (used by resize_token_embeddings etc.) ---

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute next-token logits and, if labels are given, the LM loss.

        labels are the unshifted target sequence; they are shifted here so
        position t is scored against token t + 1. Label positions equal to
        the configured pad token id are excluded from the loss.
        """
        hidden_states = self.model(input_ids, attention_mask=attention_mask)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so logits[..., t, :] predicts labels[..., t + 1].
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Fix: cross_entropy rejects ignore_index=None, so fall back to
            # its default (-100) when the config defines no pad token.
            pad_id = self.config.pad_token_id
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=pad_id if pad_id is not None else -100,
            )

        return CausalLMOutputWithPast(loss=loss, logits=logits)
|
|