| """
|
| Pure Transformer Layers (extracted from Samsung's TRM)
|
|
|
| License: Apache 2.0
|
| Source: https://github.com/Sam-Saarinen/TinyRecursiveModels
|
| Attribution: Adapted from Samsung's Tiny Recursive Model (TRM) codebase
|
| """
|
| import math
|
| from typing import Tuple
|
| import torch
|
| from torch import nn
|
| import torch.nn.functional as F
|
|
|
|
|
def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
    """Fill ``tensor`` in place with a truncated normal, JAX/Flax style.

    Samples are drawn from a standard normal truncated to ``[lower, upper]``
    and then rescaled so that ``std`` is the standard deviation of the
    *resulting* (truncated) values, not of the underlying normal. Final values
    lie in ``[lower * comp_std, upper * comp_std]``, where ``comp_std`` is the
    rescaling factor computed below.

    Args:
        tensor: Tensor to fill in place.
        std: Target standard deviation of the initialized values. ``0``
            zero-fills the tensor.
        lower: Lower truncation bound, in units of the standard normal.
        upper: Upper truncation bound, in units of the standard normal.

    Returns:
        ``tensor`` itself, for chaining.

    Raises:
        ValueError: If ``std != 0`` and ``lower >= upper``.
    """
    with torch.no_grad():
        if std == 0:
            tensor.zero_()
        else:
            if lower >= upper:
                raise ValueError(f"lower ({lower}) must be < upper ({upper})")

            sqrt2 = math.sqrt(2)
            # Bounds mapped into erf space: erf(x / sqrt2) = 2 * Phi(x) - 1.
            a = math.erf(lower / sqrt2)
            b = math.erf(upper / sqrt2)
            # Probability mass of the standard normal inside [lower, upper].
            z = (b - a) / 2

            c = (2 * math.pi) ** -0.5
            # Standard normal pdf evaluated at each bound (each bound uses its
            # own value; sharing one is only valid for symmetric bounds).
            pdf_u = c * math.exp(-0.5 * upper ** 2)
            pdf_l = c * math.exp(-0.5 * lower ** 2)
            # Variance of a standard normal truncated to [l, u]:
            #   1 - (u*phi(u) - l*phi(l)) / Z - ((phi(u) - phi(l)) / Z)^2
            # Dividing by its sqrt compensates for the spread lost to truncation.
            comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)

            # Inverse-CDF sampling: uniform in erf space -> erfinv -> truncated normal.
            tensor.uniform_(a, b)
            tensor.erfinv_()
            tensor.mul_(sqrt2 * comp_std)
            # Guards against endpoint/rounding overshoot (erfinv(+-1) is +-inf).
            tensor.clip_(lower * comp_std, upper * comp_std)

    return tensor
|
|
|
|
|
def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float = 1e-5) -> torch.Tensor:
    """Scale the last dimension of ``hidden_states`` to unit root-mean-square.

    There is no learnable gain. The arithmetic is done in float32 and the
    result is cast back to the input's dtype.

    Args:
        hidden_states: Input tensor; normalization runs over its last dimension.
        variance_epsilon: Added to the mean square before the reciprocal
            square root, for numerical stability.

    Returns:
        Tensor with the same shape and dtype as ``hidden_states``.
    """
    original_dtype = hidden_states.dtype
    x32 = hidden_states.to(torch.float32)
    inv_rms = torch.rsqrt(x32.square().mean(-1, keepdim=True) + variance_epsilon)
    return (x32 * inv_rms).to(original_dtype)
|
|
|
|
|
def rotate_half(x: torch.Tensor):
    """Split the last dim into halves ``(a, b)`` and return ``cat(-b, a)``.

    For an odd last dimension the first half gets ``size // 2`` elements and
    the second half gets the rest.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((second.neg(), first), dim=-1)
|
|
|
|
|
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Rotate query and key tensors with RoPE cos/sin tables.

    ``cos`` and ``sin`` get a singleton axis inserted at position -2 so they
    broadcast across that axis of ``q``/``k``. In this file's attention, q/k
    are (batch, seq, heads, head_dim) and cos/sin are (seq, head_dim). The
    rotation is computed in ``cos.dtype``. Both outputs are cast back to
    ``q``'s original dtype.

    Returns:
        Tuple ``(q_rotated, k_rotated)``.
    """
    target_dtype = q.dtype
    cos_b = cos.unsqueeze(-2)
    sin_b = sin.unsqueeze(-2)

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        # x * cos + rotate_half(x) * sin, evaluated in the tables' precision.
        t = t.to(cos.dtype)
        return ((t * cos_b) + (rotate_half(t) * sin_b)).to(target_dtype)

    return _rotate(q), _rotate(k)
|
|
|
|
|
class CastedLinear(nn.Module):
    """Linear layer whose parameters are cast to the input's dtype per call.

    Parameters are created in the default dtype and cast to ``input.dtype``
    on every forward pass. This lets the layer take lower-precision
    activations without converting the module itself.

    The weight is initialized with ``trunc_normal_init_`` using
    ``std = 1 / sqrt(in_features)``. The optional bias starts at zero.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False):
        super().__init__()
        init_std = 1.0 / (in_features ** 0.5)
        self.weight = nn.Parameter(
            trunc_normal_init_(torch.empty((out_features, in_features)), std=init_std)
        )
        # Without a bias this is a plain ``None`` attribute, not a registered parameter.
        self.bias = nn.Parameter(torch.zeros((out_features,))) if bias else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        compute_dtype = input.dtype
        bias = None if self.bias is None else self.bias.to(compute_dtype)
        return F.linear(input, self.weight.to(compute_dtype), bias=bias)
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Precomputed cos/sin tables for rotary position embeddings (RoPE).

    The tables have shape ``(max_position_embeddings, dim)`` and are built
    once at construction. Position ``p`` with frequency index ``i`` uses the
    angle ``p * base ** (-2i / dim)``. The frequency row is tiled twice along
    the last axis, so features ``j`` and ``j + dim/2`` share a frequency.

    The tables are registered as non-persistent buffers: they move with the
    module but are excluded from the state dict.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim
        inv_freq = 1.0 / (base ** exponents)
        positions = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
        # (positions, dim // 2) -> (positions, dim) by repeating the frequency block.
        angles = torch.outer(positions, inv_freq).repeat(1, 2)
        self.register_buffer('cos_cached', angles.cos(), persistent=False)
        self.register_buffer('sin_cached', angles.sin(), persistent=False)

    def forward(self):
        """Return the full ``(cos, sin)`` tables; callers slice to their length."""
        return self.cos_cached, self.sin_cached
|
|
|
|
|
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: ``down(silu(gate) * up)``.

    One fused projection produces both the gate branch and the up branch,
    split in half along the last axis. The intermediate width is
    ``round(expansion * hidden_size * 2 / 3)`` rounded up to a multiple of 256.
    """

    def __init__(self, hidden_size: int, expansion: float = 2.667):
        super().__init__()
        raw_width = round(expansion * hidden_size * 2 / 3)
        # Ceil-divide then multiply: smallest multiple of 256 >= raw_width.
        inter = -(-raw_width // 256) * 256
        self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
        self.down_proj = CastedLinear(inter, hidden_size, bias=False)

    def forward(self, x):
        fused = self.gate_up_proj(x)
        gate, up = torch.chunk(fused, 2, dim=-1)
        return self.down_proj(F.silu(gate) * up)
|
|
|
|
|
class TransformerAttention(nn.Module):
    """Multi-head self-attention with optional rotary position embeddings.

    A fused projection produces Q, K and V for ``num_heads`` heads of width
    ``head_dim``. The concatenated head outputs (``num_heads * head_dim``
    features) are projected back to ``hidden_size``. No attention mask is
    passed, so every position attends to every other position.
    """

    def __init__(self, hidden_size: int, num_heads: int = 8, head_dim: int = 64):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.output_size = num_heads * head_dim

        self.qkv_proj = CastedLinear(hidden_size, num_heads * head_dim * 3, bias=False)
        self.o_proj = CastedLinear(self.output_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor, cos_sin=None) -> torch.Tensor:
        """Apply self-attention to a 3-D input.

        Args:
            hidden_states: Input of shape ``(batch, seq, hidden_size)``.
            cos_sin: Optional ``(cos, sin)`` RoPE tables indexed by position
                along their first axis; only the first ``seq`` rows are used.
                ``None`` skips the rotary embedding.

        Returns:
            Tensor of shape ``(batch, seq, hidden_size)``.
        """
        batch, seq_len, _ = hidden_states.shape

        # Fused output laid out as [Q heads | K heads | V heads] -> (B, S, 3, H, D).
        fused = self.qkv_proj(hidden_states).view(batch, seq_len, 3, self.num_heads, self.head_dim)
        query, key, value = fused.unbind(dim=2)

        if cos_sin is not None:
            cos, sin = cos_sin
            query, key = apply_rotary_pos_emb(query, key, cos[:seq_len], sin[:seq_len])

        # scaled_dot_product_attention expects (batch, heads, seq, head_dim).
        context = F.scaled_dot_product_attention(
            query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
        )
        merged = context.transpose(1, 2).reshape(batch, seq_len, self.output_size)
        return self.o_proj(merged)
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: self-attention then a SwiGLU MLP.

    Each sub-layer sees an RMS-normalized copy of the residual stream, and
    its output is added back: ``x = x + attn(norm(x))`` then
    ``x = x + mlp(norm(x))``.

    The attention's per-head width is ``hidden_size // num_heads``.
    NOTE(review): if ``hidden_size`` is not divisible by ``num_heads``, the
    attention's internal width is smaller than ``hidden_size``. Its output
    projection still maps back to ``hidden_size``.
    """

    def __init__(self, hidden_size: int, num_heads: int = 8, expansion: float = 4.0, rms_eps: float = 1e-5):
        super().__init__()
        self.rms_eps = rms_eps
        self.attention = TransformerAttention(hidden_size, num_heads, hidden_size // num_heads)
        self.mlp = SwiGLU(hidden_size, expansion)

    def forward(self, x: torch.Tensor, cos_sin=None) -> torch.Tensor:
        # Attention sub-layer with residual connection.
        x = x + self.attention(rms_norm(x, self.rms_eps), cos_sin)
        # Feed-forward sub-layer with residual connection.
        return x + self.mlp(rms_norm(x, self.rms_eps))
|
|
|