| """Building-block nn.Modules for the Needle Simple Attention Network. |
| |
| ZCRMSNorm — zero-centred RMSNorm: (1+γ)*x / RMS(x) |
| RoPE — pre-computed cos/sin freqs + static apply() |
| MultiHeadAttention — GQA, optional RoPE, optional past-KV caching |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .config import TransformerConfig |
|
|
|
|
| |
| |
| |
|
|
class ZCRMSNorm(nn.Module):
    """Zero-centred RMSNorm over the last dimension.

    Computes ``(1 + γ) * x / sqrt(mean(x², dim=-1) + ε)`` where γ
    (``self.scale``) is a learnable per-feature vector initialised to zero,
    so a freshly constructed layer is a plain RMS normaliser with gain 1.

    The arithmetic is carried out in float32 and the result is cast back to
    the input dtype. Intended to match Flax architecture.py ZCRMSNorm.
    """

    def __init__(self, d: int, epsilon: float = 1e-6):
        super().__init__()
        self.epsilon = epsilon
        # γ: one entry per feature; zero init => effective gain (1 + γ) == 1.
        self.scale = nn.Parameter(torch.zeros(d))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype
        # Upcast so the mean of squares is not computed in low precision.
        h = x.to(torch.float32)
        mean_sq = torch.mean(torch.square(h), dim=-1, keepdim=True)
        denom = (mean_sq + self.epsilon).sqrt()
        gain = 1.0 + self.scale
        return (gain * h / denom).to(in_dtype)
|
|
|
|
| |
| |
| |
|
|
class RoPE(nn.Module):
    """Pre-computed rotary position embeddings (rotate-half layout).

    The ``cos`` / ``sin`` buffers have shape (max_seq_len, head_dim // 2);
    row ``p`` holds the rotation for absolute position ``p``. They are
    registered with ``persistent=False``, so they are NOT parameters and are
    not written to the state dict (no gradient needed).

    Exposes a static apply() helper for use inside MultiHeadAttention.

    Raises:
        ValueError: if ``head_dim`` is odd (apply() splits the last dimension
            into two equal halves).
    """

    def __init__(self, head_dim: int, max_seq_len: int, theta: float = 10000.0):
        super().__init__()
        if head_dim % 2 != 0:
            # With an odd head_dim the tables would get ceil(head_dim/2)
            # columns while apply()'s x1 gets floor(head_dim/2): they could
            # never broadcast together.
            raise ValueError(f"RoPE head_dim must be even, got {head_dim}")
        # Inverse frequencies theta^(-2i / head_dim), i = 0 .. head_dim/2 - 1.
        inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
        positions = torch.arange(max_seq_len).float()
        angles = torch.outer(positions, inv_freq)  # (max_seq_len, head_dim // 2)
        self.register_buffer("cos", torch.cos(angles), persistent=False)
        self.register_buffer("sin", torch.sin(angles), persistent=False)

    @staticmethod
    def apply(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Apply RoPE to x of shape (B, num_heads, T, head_dim).

        Matches Flax apply_rope():
            x1 = x[..., :half]    x2 = x[..., half:]
            return cat([x1*cos - x2*sin, x2*cos + x1*sin], dim=-1)

        Only rows ``0 .. T-1`` of ``cos`` / ``sin`` are used. Callers whose
        tokens start at position ``p`` (e.g. incremental decoding) should
        pass tables that are already offset, e.g. ``get_cos_sin(T, offset=p)``.

        The result always has ``x.dtype``. If the tables have a wider dtype
        than ``x``, the rotation is computed in the promoted dtype and then
        cast back, instead of silently changing the output dtype.
        """
        T = x.shape[2]
        # (T, half) -> (1, 1, T, half): broadcast over batch and heads.
        cos_t = cos[:T].unsqueeze(0).unsqueeze(0)
        sin_t = sin[:T].unsqueeze(0).unsqueeze(0)
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        rotated = torch.cat([x1 * cos_t - x2 * sin_t,
                             x2 * cos_t + x1 * sin_t], dim=-1)
        return rotated.to(x.dtype)

    def get_cos_sin(self, seq_len: int, offset: int = 0) -> tuple[torch.Tensor, torch.Tensor]:
        """Return (cos, sin) rows for positions ``offset .. offset + seq_len - 1``.

        Args:
            seq_len: number of positions to return.
            offset: first absolute position (default 0, the original behavior).

        Raises:
            ValueError: if the requested range is negative or exceeds the
                precomputed ``max_seq_len``. Such requests used to be silently
                truncated and only failed later with a broadcasting error.
        """
        end = offset + seq_len
        max_len = self.cos.shape[0]
        if offset < 0 or seq_len < 0 or end > max_len:
            raise ValueError(
                f"RoPE positions [{offset}, {end}) out of range for "
                f"max_seq_len={max_len}"
            )
        return self.cos[offset:end], self.sin[offset:end]
|
|
|
|
| |
| |
| |
|
|
def make_causal_mask(seq: int, past_seq: int = 0, device=None) -> torch.Tensor:
    """Build a lower-triangular boolean attention mask.

    Returns a tensor of shape (1, 1, seq, seq + past_seq) where True means
    "may attend". Query row ``i`` sits at absolute position ``past_seq + i``
    and can see key columns ``0 .. i + past_seq`` inclusive.
    """
    allowed = torch.ones(seq, seq + past_seq, dtype=torch.bool, device=device)
    # Shifting the diagonal right by past_seq keeps every cached key visible.
    allowed = allowed.tril(diagonal=past_seq)
    return allowed[None, None]
|
|
|
|
def make_padding_mask(tokens: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    """Mark non-pad tokens for attention.

    For ``tokens`` of shape (B, T), returns a bool tensor of shape
    (B, 1, 1, T) that is True wherever ``tokens != pad_token_id``. The two
    singleton axes let it broadcast against (B, heads, T_q, T) scores.
    """
    keep = tokens.ne(pad_token_id)
    return keep[:, None, None]
|
|
|
|
| |
| |
| |
|
|
class MultiHeadAttention(nn.Module):
    """Grouped-query attention with optional RoPE and past-KV caching.

    Queries use ``config.num_heads`` heads and keys/values use
    ``config.num_kv_heads`` heads. Each KV head is shared by
    ``num_heads // num_kv_heads`` consecutive query heads. A per-head
    ZCRMSNorm is applied to q and k before RoPE.

    Args:
        config: TransformerConfig. Reads d_model, num_heads, num_kv_heads
            and rope_keys_only.
        is_cross_attn: stored as ``self.is_cross_attn``. Q is always
            projected from ``q_input`` and K/V from ``kv_input``; for
            cross-attention these are different tensors.
        is_causal: stored as ``self.is_causal``. NOTE: ``forward()`` does not
            build a causal mask itself, and it accepts ``past_kv`` regardless
            of this flag. Callers must supply the causal mask via ``mask``
            (e.g. from ``make_causal_mask``).

    Raises:
        ValueError: if num_heads / num_kv_heads are not positive, d_model is
            not divisible by num_heads, or num_heads is not a multiple of
            num_kv_heads. Each of these would otherwise fail later with an
            opaque reshape, broadcast or ZeroDivision error.
    """

    def __init__(self, config: TransformerConfig, is_cross_attn: bool = False,
                 is_causal: bool = False):
        super().__init__()
        if config.num_heads <= 0 or config.num_kv_heads <= 0:
            raise ValueError(
                f"num_heads ({config.num_heads}) and num_kv_heads "
                f"({config.num_kv_heads}) must be positive"
            )
        if config.d_model % config.num_heads != 0:
            raise ValueError(
                f"d_model ({config.d_model}) must be divisible by "
                f"num_heads ({config.num_heads})"
            )
        if config.num_heads % config.num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({config.num_heads}) must be a multiple of "
                f"num_kv_heads ({config.num_kv_heads})"
            )

        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_kv_heads
        self.d_model = config.d_model
        self.head_dim = config.d_model // config.num_heads
        self.is_cross_attn = is_cross_attn
        self.is_causal = is_causal
        self.rope_keys_only = config.rope_keys_only
        # Number of query heads served by each KV head (GQA group size).
        self.repeats = config.num_heads // config.num_kv_heads

        kv_dim = config.num_kv_heads * self.head_dim

        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, kv_dim, bias=False)
        self.v_proj = nn.Linear(config.d_model, kv_dim, bias=False)
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        # QK-norm: normalises each head's vector over head_dim.
        self.q_norm = ZCRMSNorm(self.head_dim)
        self.k_norm = ZCRMSNorm(self.head_dim)

        # Scores are divided by sqrt(head_dim).
        self._scale = math.sqrt(self.head_dim)

    def forward(
        self,
        q_input: torch.Tensor,
        kv_input: torch.Tensor,
        mask: torch.Tensor | None = None,
        rope: tuple[torch.Tensor, torch.Tensor] | None = None,
        past_kv: tuple[torch.Tensor, torch.Tensor] | None = None,
    ):
        """
        Args:
            q_input:  (B, T_q, d_model)
            kv_input: (B, T_kv, d_model)
            mask:     bool, True = attend. Must broadcast to
                      (B, num_heads, T_q, T_total), e.g. (B, 1, T_q, T_total)
                      or (B, 1, 1, T_total).
            rope:     (cos, sin) tensors of shape (T, head_dim//2) each.
                      RoPE.apply uses their first rows, so when past_kv is
                      given the caller must pass tables already offset to the
                      new tokens' positions.
            past_kv:  (k_cache, v_cache), each (B, num_kv_heads, past_T, head_dim).
                      These are prepended to the new keys/values; presumably a
                      previous call's present_kv.

        Returns:
            out:        (B, T_q, d_model)
            present_kv: (k, v) each (B, num_kv_heads, T_total, head_dim), where
                        T_total = past_T + T_kv. Keys are post-norm/RoPE and
                        stored before GQA head repetition.
        """
        B, T_q, _ = q_input.shape

        q = self.q_proj(q_input)
        k = self.k_proj(kv_input)
        v = self.v_proj(kv_input)

        # (B, T, heads * head_dim) -> (B, heads, T, head_dim)
        q = q.reshape(B, T_q, self.num_heads, self.head_dim).transpose(1, 2)
        T_kv = k.shape[1]
        k = k.reshape(B, T_kv, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, T_kv, self.num_kv_heads, self.head_dim).transpose(1, 2)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if rope is not None:
            cos, sin = rope
            if not self.rope_keys_only:
                q = RoPE.apply(q, cos, sin)
            k = RoPE.apply(k, cos, sin)

        # Only the new tokens were normed/rotated above. Cached entries are
        # concatenated as-is along the time axis.
        if past_kv is not None:
            k_past, v_past = past_kv
            k = torch.cat([k_past, k], dim=2)
            v = torch.cat([v_past, v], dim=2)

        # Cache the compact num_kv_heads layout, not the repeated one.
        present_kv = (k, v)

        if self.repeats > 1:
            k = k.repeat_interleave(self.repeats, dim=1)
            v = v.repeat_interleave(self.repeats, dim=1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / self._scale
        # Masking and softmax run in float32 for numerical stability.
        scores = scores.float()
        if mask is not None:
            # Use the most negative finite float32 rather than -inf. Rows with
            # at least one visible key are unchanged, because exp() of the
            # filled entries underflows to exactly 0. A fully masked row now
            # gets uniform weights instead of NaN (softmax over all -inf is
            # 0/0), which would otherwise poison outputs and gradients.
            scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)

        # Cast to v's dtype (identical to q's in the usual case) so the next
        # matmul never sees mismatched dtypes, even if q/k were promoted
        # (e.g. by float32 RoPE tables).
        attn_weights = F.softmax(scores, dim=-1).to(v.dtype)

        out = torch.matmul(attn_weights, v)
        # (B, num_heads, T_q, head_dim) -> (B, T_q, d_model)
        out = out.transpose(1, 2).reshape(B, T_q, self.d_model)
        out = self.out_proj(out)

        return out, present_kv
|
|