| """VectraYX-Nano transformer (decoder-only, ~42M params). |
| |
| Modern small-LLM stack: |
| RMSNorm (pre-norm) · SwiGLU FFN · RoPE · GQA (8q/2kv) |
| QK-Norm · no biases · tied embeddings · z-loss |
| """ |
|
|
| import json |
| import math |
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
@dataclass
class ModelConfig:
    """Hyper-parameters for VectraYXNano.

    Field order is part of the public interface (positional construction),
    so new fields must only ever be appended.
    """

    vocab_size: int = 16384
    n_layers: int = 8
    n_heads: int = 8          # query heads
    n_kv_heads: int = 2       # key/value heads shared by groups of query heads (GQA)
    d_model: int = 512
    d_ffn: int = 2048         # hidden width of the SwiGLU feed-forward
    max_seq_len: int = 1024   # length of the precomputed RoPE tables
    rope_theta: float = 10000.0
    rms_eps: float = 1e-6
    init_std: float = 0.02
    dropout: float = 0.0      # attention dropout probability (training only)
    tie_embeddings: bool = True
    qk_norm: bool = True
    z_loss_coef: float = 1e-4  # 0 disables the z-loss term

    @classmethod
    def from_json(cls, path):
        """Build a config from the ``"model"`` object of a UTF-8 JSON file.

        Keys that are not dataclass fields are ignored; fields absent from
        the file keep their defaults.

        Raises:
            KeyError: if the JSON document has no top-level ``"model"`` key.
        """
        # Context manager closes the handle deterministically; JSON is UTF-8
        # by spec, so do not depend on the platform's locale encoding.
        with open(path, encoding="utf-8") as fh:
            section = json.load(fh)["model"]
        known = cls.__dataclass_fields__
        return cls(**{k: v for k, v in section.items() if k in known})
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension with a learned gain.

    y = x / sqrt(mean(x ** 2, dim=-1) + eps) * weight   (no centering, no bias)

    Args:
        dim: size of the input's last dimension (length of ``weight``).
        eps: added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x``; the result is cast to ``weight.dtype`` before scaling."""
        # Reduce in at least float32: squaring fp16 activations overflows to
        # inf for |x| > ~256 (fp16 max is 65504), which would zero the output,
        # and bf16 loses precision in the mean. float64 inputs stay float64.
        acc = x.to(torch.promote_types(x.dtype, torch.float32))
        inv_rms = torch.rsqrt(acc.pow(2).mean(-1, keepdim=True) + self.eps)
        return (acc * inv_rms).to(self.weight.dtype) * self.weight
|
|
|
|
def precompute_rope(head_dim, max_seq_len, theta=10000.0, device=None):
    """Build the RoPE cosine/sine lookup tables.

    One inverse frequency is produced per even index j in [0, head_dim):
    inv_freq = 1 / theta ** (j / head_dim). Row t of each table holds
    cos/sin(t * inv_freq) for position t in [0, max_seq_len).

    Args:
        head_dim: per-head feature size.
        max_seq_len: number of positions (table rows).
        theta: RoPE base.
        device: if given, both tables are moved there.

    Returns:
        Tuple ``(cos, sin)`` of float32 tensors, each (max_seq_len, len(inv_freq)).
    """
    even_idx = torch.arange(0, head_dim, 2, dtype=torch.float32)
    inv_freq = 1.0 / (theta ** (even_idx / head_dim))
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    angles = positions[:, None] * inv_freq[None, :]
    tables = (angles.cos(), angles.sin())
    if device is None:
        return tables
    return tuple(table.to(device) for table in tables)
|
|
|
|
def apply_rope(x, cos, sin):
    """Apply rotary position embedding (rotate-half layout) to ``x``.

    The last dimension of ``x`` is split into halves (x1, x2); each pair
    (x1[..., t, i], x2[..., t, i]) is rotated by the angle whose cosine and
    sine are ``cos[t, i]`` and ``sin[t, i]``.

    Args:
        x: tensor whose last two dims are (T, D), D even. Any leading dims
            are allowed; attention here passes (B, heads, T, head_dim).
        cos, sin: tables with at least T rows and D // 2 columns (as built by
            ``precompute_rope``); only the first T rows are used.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    seq_len, dim = x.shape[-2], x.shape[-1]
    half = dim // 2
    # (T, D/2) tables broadcast over any number of leading dims of x.
    cos = cos[:seq_len]
    sin = sin[:seq_len]
    # Rotate in the wider dtype (e.g. bf16 activations with fp32 tables), then
    # return x's dtype so the caller's q/k/v dtypes stay consistent.
    work = x.to(torch.promote_types(x.dtype, cos.dtype))
    x1, x2 = work[..., :half], work[..., half:]
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return rotated.to(x.dtype)
|
|
|
|
class GQAttention(nn.Module):
    """Causal grouped-query self-attention with optional QK-Norm and RoPE.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads. Each K/V head
    is duplicated ``n_heads // n_kv_heads`` times with ``repeat_interleave`` on
    the head axis, so query head h attends with K/V head h // repeat. When
    ``qk_norm`` is on, queries and keys get a per-head RMSNorm before RoPE.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0
        assert cfg.n_heads % cfg.n_kv_heads == 0
        hd = cfg.d_model // cfg.n_heads
        self.n_heads = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        self.head_dim = hd
        self.repeat = cfg.n_heads // cfg.n_kv_heads

        q_width = cfg.n_heads * hd
        kv_width = cfg.n_kv_heads * hd
        # Registration order (wq, wk, wv, wo, q_norm, k_norm) fixes the order
        # in which the model's init consumes the RNG; keep it stable.
        self.wq = nn.Linear(cfg.d_model, q_width, bias=False)
        self.wk = nn.Linear(cfg.d_model, kv_width, bias=False)
        self.wv = nn.Linear(cfg.d_model, kv_width, bias=False)
        self.wo = nn.Linear(cfg.d_model, cfg.d_model, bias=False)

        self.qk_norm = cfg.qk_norm
        if cfg.qk_norm:
            self.q_norm = RMSNorm(hd, eps=cfg.rms_eps)
            self.k_norm = RMSNorm(hd, eps=cfg.rms_eps)

        self.dropout = cfg.dropout

    def forward(self, x, cos, sin):
        """Attend causally over ``x`` (B, T, d_model); returns (B, T, d_model)."""
        bsz, seqlen, _ = x.shape

        def split_heads(t, heads):
            # (B, T, heads * head_dim) -> (B, heads, T, head_dim)
            return t.view(bsz, seqlen, heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.wq(x), self.n_heads)
        k = split_heads(self.wk(x), self.n_kv_heads)
        v = split_heads(self.wv(x), self.n_kv_heads)

        if self.qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)

        q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)

        if self.repeat > 1:
            k, v = (t.repeat_interleave(self.repeat, dim=1) for t in (k, v))

        attn_drop = self.dropout if self.training else 0.0
        y = F.scaled_dot_product_attention(q, k, v, dropout_p=attn_drop, is_causal=True)
        # (B, heads, T, head_dim) -> (B, T, heads * head_dim)
        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(y)
|
|
|
|
class SwiGLU(nn.Module):
    """Gated feed-forward: ``w_down(silu(w_gate(x)) * w_up(x))``, no biases."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        d_in, d_hidden = cfg.d_model, cfg.d_ffn
        self.w_gate = nn.Linear(d_in, d_hidden, bias=False)
        self.w_up = nn.Linear(d_in, d_hidden, bias=False)
        self.w_down = nn.Linear(d_hidden, d_in, bias=False)

    def forward(self, x):
        """Map (..., d_model) -> (..., d_model) through the gated hidden layer."""
        gate = F.silu(self.w_gate(x))
        up = self.w_up(x)
        return self.w_down(gate * up)
|
|
|
|
class Block(nn.Module):
    """Pre-norm transformer layer: attention sublayer, then SwiGLU sublayer.

    Each sublayer reads an RMSNorm'd copy of the stream and adds its output
    back to the residual stream.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        eps = cfg.rms_eps
        self.attn_norm = RMSNorm(cfg.d_model, eps=eps)
        self.attn = GQAttention(cfg)
        self.ffn_norm = RMSNorm(cfg.d_model, eps=eps)
        self.ffn = SwiGLU(cfg)

    def forward(self, x, cos, sin):
        """Apply both residual sublayers; ``cos``/``sin`` are the RoPE tables."""
        h = x + self.attn(self.attn_norm(x), cos, sin)
        return h + self.ffn(self.ffn_norm(h))
|
|
|
|
class VectraYXNano(nn.Module):
    """Decoder-only transformer language model.

    Layout: token embedding -> ``n_layers`` x Block -> final RMSNorm -> LM head.
    RoPE tables for ``cfg.max_seq_len`` positions are non-persistent buffers
    (rebuilt at construction, not stored in ``state_dict``). With
    ``cfg.tie_embeddings`` the LM head reuses the token-embedding weight.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.layers = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.final_norm = RMSNorm(cfg.d_model, eps=cfg.rms_eps)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)

        if cfg.tie_embeddings:
            self.lm_head.weight = self.tok_emb.weight

        head_dim = cfg.d_model // cfg.n_heads
        cos, sin = precompute_rope(head_dim, cfg.max_seq_len, cfg.rope_theta)
        self.register_buffer("rope_cos", cos, persistent=False)
        self.register_buffer("rope_sin", sin, persistent=False)

        self.apply(self._init_weights)
        # Re-initialise the output projections feeding the residual stream with
        # std scaled by 1/sqrt(2 * n_layers): each Block adds to the stream twice.
        residual_std = cfg.init_std / math.sqrt(2 * cfg.n_layers)
        for n, p in self.named_parameters():
            if n.endswith("wo.weight") or n.endswith("w_down.weight"):
                nn.init.normal_(p, mean=0.0, std=residual_std)

    def _init_weights(self, m):
        """Per-module init for ``self.apply``: N(0, init_std) weights for Linear
        and Embedding, zero Linear biases."""
        std = self.cfg.init_std
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=std)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=std)

    def num_params(self, exclude_embedding=False):
        """Return the number of unique parameters.

        ``self.parameters()`` yields a tied weight only once. With
        ``exclude_embedding=True`` the token-embedding matrix is subtracted,
        whether or not it is tied to the LM head (previously nothing was
        subtracted for untied embeddings).
        """
        n = sum(p.numel() for p in self.parameters())
        if exclude_embedding:
            n -= self.tok_emb.weight.numel()
        return n

    def forward(self, idx, targets=None, loss_mask=None):
        """Compute logits and, if ``targets`` is given, the training loss.

        Args:
            idx: (B, T) token ids, T <= cfg.max_seq_len.
            targets: optional (B, T) target ids; -100 marks ignored positions.
            loss_mask: optional (B, T) per-position weights (e.g. 0/1). They are
                combined with the -100 mask, so a position counts only when
                both allow it.

        Returns:
            ``(logits, loss)``: logits is (B, T, vocab_size). loss is None without
            targets, else the weighted-mean cross-entropy over counted positions
            plus ``z_loss_coef`` times the weighted mean of logsumexp(logits) ** 2.
        """
        _, T = idx.shape
        assert T <= self.cfg.max_seq_len, f"seq {T} > max {self.cfg.max_seq_len}"
        x = self.tok_emb(idx)
        for layer in self.layers:
            x = layer(x, self.rope_cos, self.rope_sin)
        logits = self.lm_head(self.final_norm(x))

        if targets is None:
            return logits, None
        return logits, self._loss(logits, targets, loss_mask)

    def _loss(self, logits, targets, loss_mask):
        """Weighted-mean cross-entropy plus optional z-loss over counted positions."""
        # reshape (not view): targets / loss_mask are often non-contiguous
        # slices such as batch[:, 1:], which .view(-1) rejects.
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_tgt = targets.reshape(-1)
        ce = F.cross_entropy(flat_logits, flat_tgt, reduction="none", ignore_index=-100)

        # A position counts only if its target is not ignored AND (when given)
        # loss_mask allows it. Otherwise ignored targets under a set mask would
        # inflate the denominator and still feed the z-loss.
        keep = (flat_tgt != -100).float()
        if loss_mask is not None:
            keep = keep * loss_mask.reshape(-1).float()
        denom = keep.sum().clamp_min(1.0)
        loss = (ce * keep).sum() / denom

        if self.cfg.z_loss_coef > 0:
            # Penalise the squared log-normaliser, computed in float32.
            lse = torch.logsumexp(flat_logits.float(), dim=-1)
            z = ((lse ** 2) * keep).sum() / denom
            loss = loss + self.cfg.z_loss_coef * z
        return loss

    @staticmethod
    def _apply_repeat_penalty(logits, history, penalty):
        """Penalise, per row, every token id present in that row of ``history``.

        Positive logits are divided by ``penalty`` and non-positive ones are
        multiplied by it. Returns a new (B, vocab) tensor.
        """
        index = history.long()
        seen = logits.gather(-1, index)
        seen = torch.where(seen > 0, seen / penalty, seen * penalty)
        # Duplicate ids in history write identical values, so scatter is safe.
        return logits.scatter(-1, index, seen)

    @staticmethod
    def _sample_next(logits, temperature, top_k, top_p):
        """Pick one token id per row from (B, vocab) logits; returns (B, 1)."""
        if temperature <= 0:
            return logits.argmax(-1, keepdim=True)
        logits = logits / temperature
        if top_k:
            kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
            logits = logits.masked_fill(logits < kth, -float("inf"))
        if top_p and top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            cumprobs = F.softmax(sorted_logits, dim=-1).cumsum(-1)
            drop = cumprobs > top_p
            # Shift right so the token that crosses top_p is kept; never drop the best.
            drop[..., 1:] = drop[..., :-1].clone()
            drop[..., 0] = False
            sorted_logits = sorted_logits.masked_fill(drop, -float("inf"))
            logits = torch.full_like(logits, -float("inf")).scatter(-1, sorted_idx, sorted_logits)
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, 1)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.7, top_k=40, top_p=0.9,
                 eos_id=None, repeat_penalty=1.0):
        """Autoregressively append up to ``max_new_tokens`` tokens to ``idx``.

        Args:
            idx: (B, T) prompt token ids.
            max_new_tokens: maximum number of decoding steps.
            temperature: logits are divided by it; ``<= 0`` means greedy argmax
                (top-k / top-p are then skipped).
            top_k: keep only the ``top_k`` largest logits (falsy disables).
            top_p: nucleus threshold (falsy or >= 1.0 disables).
            eos_id: if given, a row that emits it is finished; later positions of
                that row are filled with ``eos_id`` and decoding stops once every
                row has finished.
            repeat_penalty: applied per row to every id already in that row
                (prompt + generated so far); 1.0 disables.

        Returns:
            ``idx`` with the generated ids concatenated along the last dim.

        The model runs in eval mode during decoding and its previous train/eval
        mode is restored afterwards. Each step re-runs the whole (truncated)
        context; there is no KV cache.
        """
        was_training = self.training
        self.eval()
        try:
            done = None
            if eos_id is not None:
                done = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
            for _ in range(max_new_tokens):
                cond = idx[:, -self.cfg.max_seq_len:]
                logits, _ = self(cond)
                logits = logits[:, -1, :].float()

                if repeat_penalty != 1.0:
                    logits = self._apply_repeat_penalty(logits, idx, repeat_penalty)

                next_id = self._sample_next(logits, temperature, top_k, top_p)
                if done is not None:
                    # Rows that already produced EOS keep emitting EOS (padding).
                    next_id = next_id.masked_fill(done.unsqueeze(-1), eos_id)
                    done = done | next_id.squeeze(-1).eq(eos_id)
                idx = torch.cat([idx, next_id], dim=-1)
                if done is not None and bool(done.all()):
                    break
        finally:
            self.train(was_training)
        return idx
|
|