# Source: vectrayx-paper-code/training/transformer.py
# Uploaded by jsantillana via huggingface_hub (commit 6848cb6, verified).
"""VectraYX-Nano transformer (decoder-only, ~42M params).
Modern small-LLM stack:
RMSNorm (pre-norm) · SwiGLU FFN · RoPE · GQA (8q/2kv)
QK-Norm · no biases · tied embeddings · z-loss
"""
import json
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class ModelConfig:
    """Hyper-parameters for the VectraYX-Nano decoder.

    The defaults describe an 8-layer, 512-wide model with 8 query heads
    sharing 2 key/value heads.
    """

    vocab_size: int = 16384
    n_layers: int = 8
    n_heads: int = 8
    n_kv_heads: int = 2  # grouped-query attention; n_heads must be a multiple of this
    d_model: int = 512
    d_ffn: int = 2048  # hidden width of the SwiGLU feed-forward
    max_seq_len: int = 1024  # RoPE table length and longest accepted sequence
    rope_theta: float = 10000.0
    rms_eps: float = 1e-6
    init_std: float = 0.02
    dropout: float = 0.0  # attention dropout, active only in training mode
    tie_embeddings: bool = True  # share the token-embedding and LM-head weights
    qk_norm: bool = True  # RMSNorm on per-head queries/keys before RoPE
    z_loss_coef: float = 1e-4  # weight of the squared-logsumexp term; 0 disables it

    @classmethod
    def from_json(cls, path):
        """Build a config from the ``"model"`` object of a JSON file.

        Keys that are not dataclass fields are ignored. Fields missing from the
        file keep their defaults.

        Raises:
            KeyError: if the file has no top-level ``"model"`` entry.
        """
        # Context manager so the handle is closed deterministically; JSON is UTF-8.
        with open(path, encoding="utf-8") as fh:
            cfg = json.load(fh)["model"]
        known = cls.__dataclass_fields__
        return cls(**{key: value for key, value in cfg.items() if key in known})
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension, with a learned
    per-feature scale and no bias.

    The mean-square statistic is computed in at least float32. In float16,
    ``x.pow(2)`` overflows to inf once ``|x|`` exceeds ~256, which made
    ``rsqrt`` return 0 and zeroed the whole row. The result is returned in the
    dtype of ``weight``.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Upcast half/bfloat16 to float32; float32/float64 keep their precision.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        xf = x.to(compute_dtype)
        var = xf.pow(2).mean(-1, keepdim=True)
        xf = xf * torch.rsqrt(var + self.eps)
        return xf.to(self.weight.dtype) * self.weight
def precompute_rope(head_dim, max_seq_len, theta=10000.0, device=None):
    """Return ``(cos, sin)`` rotary tables of shape (max_seq_len, ceil(head_dim/2)).

    Entry ``[p, i]`` is the cosine/sine of ``p * theta ** (-2i / head_dim)``.
    The tables are built in float32 on the CPU and then moved to ``device``
    when one is given.
    """
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    tables = (angles.cos(), angles.sin())
    if device is None:
        return tables
    return tuple(table.to(device) for table in tables)
def apply_rope(x, cos, sin):
    """Rotate the last dimension of ``x`` by position-dependent RoPE angles.

    ``x`` is (B, H, T, D) with D even. ``cos``/``sin`` are (>= T, D/2); only
    their first T rows are used. Feature ``i`` of the first half is paired with
    feature ``i + D//2`` of the second half, and each pair is rotated by the
    angle stored for its position.
    """
    seq_len, dim = x.shape[-2:]
    half = dim // 2
    c = cos[:seq_len].view(1, 1, seq_len, half)
    s = sin[:seq_len].view(1, 1, seq_len, half)
    first, second = x[..., :half], x[..., half:]
    return torch.cat((first * c - second * s, first * s + second * c), dim=-1)
class GQAttention(nn.Module):
    """Causal grouped-query self-attention.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads. Each K/V head
    is repeated ``n_heads // n_kv_heads`` times along the head axis, so query
    head ``h`` attends with K/V head ``h // repeat``. When ``qk_norm`` is set,
    queries and keys are RMS-normalized per head before RoPE is applied.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        assert cfg.d_model % cfg.n_heads == 0
        assert cfg.n_heads % cfg.n_kv_heads == 0
        self.n_heads = cfg.n_heads
        self.n_kv_heads = cfg.n_kv_heads
        self.head_dim = cfg.d_model // cfg.n_heads
        self.repeat = self.n_heads // self.n_kv_heads
        q_width = self.n_heads * self.head_dim
        kv_width = self.n_kv_heads * self.head_dim
        # Registration order (wq, wk, wv, wo, q_norm, k_norm) fixes init RNG order.
        self.wq = nn.Linear(cfg.d_model, q_width, bias=False)
        self.wk = nn.Linear(cfg.d_model, kv_width, bias=False)
        self.wv = nn.Linear(cfg.d_model, kv_width, bias=False)
        self.wo = nn.Linear(cfg.d_model, cfg.d_model, bias=False)
        self.qk_norm = cfg.qk_norm
        if self.qk_norm:
            self.q_norm = RMSNorm(self.head_dim, eps=cfg.rms_eps)
            self.k_norm = RMSNorm(self.head_dim, eps=cfg.rms_eps)
        self.dropout = cfg.dropout

    def _to_heads(self, projected, n):
        """Reshape a (B, T, n * head_dim) projection into (B, n, T, head_dim)."""
        batch, seq_len, _ = projected.shape
        return projected.view(batch, seq_len, n, self.head_dim).transpose(1, 2)

    def forward(self, x, cos, sin):
        q = self._to_heads(self.wq(x), self.n_heads)
        k = self._to_heads(self.wk(x), self.n_kv_heads)
        v = self._to_heads(self.wv(x), self.n_kv_heads)
        if self.qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)
        q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)
        if self.repeat > 1:
            # Expand K/V so every query head has a matching K/V head.
            k, v = (t.repeat_interleave(self.repeat, dim=1) for t in (k, v))
        p_drop = self.dropout if self.training else 0.0
        attended = F.scaled_dot_product_attention(q, k, v, dropout_p=p_drop, is_causal=True)
        batch, seq_len = x.shape[:2]
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.wo(merged)
class SwiGLU(nn.Module):
    """Gated feed-forward ``w_down(silu(w_gate(x)) * w_up(x))`` without biases."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        d_in, d_hidden = cfg.d_model, cfg.d_ffn
        self.w_gate = nn.Linear(d_in, d_hidden, bias=False)
        self.w_up = nn.Linear(d_in, d_hidden, bias=False)
        self.w_down = nn.Linear(d_hidden, d_in, bias=False)

    def forward(self, x):
        gate = F.silu(self.w_gate(x))
        up = self.w_up(x)
        return self.w_down(gate * up)
class Block(nn.Module):
    """Pre-norm transformer layer: attention, then SwiGLU, each on a residual branch."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.attn_norm = RMSNorm(cfg.d_model, eps=cfg.rms_eps)
        self.attn = GQAttention(cfg)
        self.ffn_norm = RMSNorm(cfg.d_model, eps=cfg.rms_eps)
        self.ffn = SwiGLU(cfg)

    def forward(self, x, cos, sin):
        attn_out = self.attn(self.attn_norm(x), cos, sin)
        hidden = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(hidden))
        return hidden + ffn_out
class VectraYXNano(nn.Module):
    """Decoder-only LM: token embedding, ``n_layers`` pre-norm blocks, final
    RMSNorm and an LM head (optionally tied to the embedding).

    RoPE cos/sin tables for ``max_seq_len`` positions are registered as
    non-persistent buffers, so they are rebuilt rather than stored in
    ``state_dict``.
    """

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.layers = nn.ModuleList([Block(cfg) for _ in range(cfg.n_layers)])
        self.final_norm = RMSNorm(cfg.d_model, eps=cfg.rms_eps)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        if cfg.tie_embeddings:
            self.lm_head.weight = self.tok_emb.weight
        head_dim = cfg.d_model // cfg.n_heads
        cos, sin = precompute_rope(head_dim, cfg.max_seq_len, cfg.rope_theta)
        self.register_buffer("rope_cos", cos, persistent=False)
        self.register_buffer("rope_sin", sin, persistent=False)
        self.apply(self._init_weights)
        # Projections that write into the residual stream get a smaller init,
        # scaled by 1/sqrt(2 * n_layers) (two residual branches per layer).
        residual_std = cfg.init_std / math.sqrt(2 * cfg.n_layers)
        for n, p in self.named_parameters():
            if n.endswith("wo.weight") or n.endswith("w_down.weight"):
                nn.init.normal_(p, mean=0.0, std=residual_std)

    def _init_weights(self, m):
        """Normal(0, init_std) for Linear/Embedding weights; zero any Linear bias."""
        std = self.cfg.init_std
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=std)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=std)

    def num_params(self, exclude_embedding=False):
        """Count parameters; ``parameters()`` yields a tied tensor only once.

        With ``exclude_embedding=True`` the token-embedding matrix is
        subtracted, but only when embeddings are tied. For untied models the
        flag has no effect.
        """
        n = sum(p.numel() for p in self.parameters())
        if exclude_embedding and self.cfg.tie_embeddings:
            n -= self.tok_emb.weight.numel()
        return n

    def forward(self, idx, targets=None, loss_mask=None):
        """Run the model on token ids ``idx`` of shape (B, T), with T <= max_seq_len.

        Args:
            idx: integer token ids, shape (B, T).
            targets: optional next-token ids with the same number of elements
                as ``idx``. Positions equal to -100 are ignored.
            loss_mask: optional per-position weights (bool or numeric) with the
                same number of elements as ``targets``. The loss is the
                weighted mean over positions that are also not -100.

        Returns:
            ``(logits, loss)``. ``logits`` is (B, T, vocab_size). ``loss`` is
            ``None`` when ``targets`` is ``None``.
        """
        B, T = idx.shape
        assert T <= self.cfg.max_seq_len, f"seq {T} > max {self.cfg.max_seq_len}"
        x = self.tok_emb(idx)
        for layer in self.layers:
            x = layer(x, self.rope_cos, self.rope_sin)
        logits = self.lm_head(self.final_norm(x))
        if targets is None:
            return logits, None
        return logits, self._lm_loss(logits, targets, loss_mask)

    def _lm_loss(self, logits, targets, loss_mask):
        """Weighted mean cross-entropy plus the optional z-loss term."""
        # reshape, not view: targets are often non-contiguous slices such as
        # batch[:, 1:], on which .view(-1) raises.
        flat_logits = logits.reshape(-1, logits.size(-1))
        flat_tgt = targets.reshape(-1)
        ce = F.cross_entropy(flat_logits, flat_tgt, reduction="none", ignore_index=-100)
        # Ignored targets never count, even if loss_mask selects them. Their CE
        # is 0, so counting them would only inflate the denominator and z-loss.
        weights = (flat_tgt != -100).float()
        if loss_mask is not None:
            weights = weights * loss_mask.reshape(-1).float()
        denom = weights.sum().clamp_min(1.0)
        ce_loss = (ce * weights).sum() / denom
        if self.cfg.z_loss_coef <= 0:
            return ce_loss
        # z-loss: penalize the squared log-partition function, for stability.
        lse = torch.logsumexp(flat_logits.float(), dim=-1)
        z = ((lse ** 2) * weights).sum() / denom
        return ce_loss + self.cfg.z_loss_coef * z

    @staticmethod
    def _apply_repeat_penalty(logits, idx, penalty):
        """Penalize, per row, every token id present in that row of ``idx``.

        Positive logits are divided by ``penalty``; other logits are multiplied by it.
        """
        index = idx.long()
        seen = torch.gather(logits, 1, index)
        seen = torch.where(seen > 0, seen / penalty, seen * penalty)
        # Duplicate ids write the same value, so scatter order does not matter.
        return logits.scatter(1, index, seen)

    @staticmethod
    def _sample_next(logits, temperature, top_k, top_p):
        """Choose one token id per row from (B, V) logits; returns (B, 1)."""
        if temperature <= 0:
            return logits.argmax(-1, keepdim=True)
        logits = logits / temperature
        if top_k:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = -float("inf")
        if top_p and top_p < 1.0:
            sorted_logits, sorted_idx = torch.sort(logits, descending=True)
            cumprobs = F.softmax(sorted_logits, dim=-1).cumsum(-1)
            # Shift right so the token crossing top_p is kept; always keep the best.
            drop = cumprobs > top_p
            drop[..., 1:] = drop[..., :-1].clone()
            drop[..., 0] = False
            sorted_logits[drop] = -float("inf")
            logits = torch.full_like(logits, -float("inf")).scatter(-1, sorted_idx, sorted_logits)
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, 1)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.7, top_k=40, top_p=0.9,
                 eos_id=None, repeat_penalty=1.0):
        """Append up to ``max_new_tokens`` sampled tokens to ``idx`` (B, T).

        Only the last ``max_seq_len`` tokens are fed to the model. With
        ``temperature <= 0`` decoding is greedy. Otherwise top-k, then top-p,
        filtering is applied before sampling. If ``eos_id`` is given,
        generation stops once every row has produced it, and rows that
        finished earlier keep emitting ``eos_id``. The module is switched to
        eval mode while sampling and its previous train/eval mode is restored
        afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            finished = None
            if eos_id is not None:
                finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
            for _ in range(max_new_tokens):
                cond = idx[:, -self.cfg.max_seq_len:]
                logits, _ = self(cond)
                logits = logits[:, -1, :].float()
                if repeat_penalty != 1.0:
                    logits = self._apply_repeat_penalty(logits, idx, repeat_penalty)
                next_id = self._sample_next(logits, temperature, top_k, top_p)
                if finished is not None:
                    next_id = next_id.masked_fill(finished.unsqueeze(-1), eos_id)
                    finished = finished | (next_id.squeeze(-1) == eos_id)
                idx = torch.cat([idx, next_id], dim=-1)
                if finished is not None and bool(finished.all()):
                    break
        finally:
            self.train(was_training)
        return idx