"""
Pure Transformer Layers (extracted from Samsung's TRM)

License: Apache 2.0
Source: https://github.com/Sam-Saarinen/TinyRecursiveModels
Attribution: Adapted from Samsung's Tiny Recursive Model (TRM) codebase
"""

import math
from typing import Tuple

import torch
from torch import nn
import torch.nn.functional as F


def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
    """Fill ``tensor`` in place with truncated-normal samples (JAX/Flax style).

    Samples are drawn from a standard normal truncated to ``[lower, upper]``
    and then rescaled so that the standard deviation of the *resulting*
    distribution equals ``std``. The rescaling also stretches the bounds, so
    the final values lie in ``[lower * comp_std, upper * comp_std]``.

    Args:
        tensor: Tensor to initialise in place.
        std: Target standard deviation of the initialised values. ``0`` zeroes
            the tensor.
        lower: Lower truncation bound, in units of the untruncated normal's
            standard deviation.
        upper: Upper truncation bound, same units as ``lower``.

    Returns:
        The same ``tensor`` object, for chaining.

    Raises:
        ValueError: If ``std`` is non-zero and ``lower >= upper``.
    """
    with torch.no_grad():
        if std == 0:
            tensor.zero_()
        else:
            if not lower < upper:
                raise ValueError(
                    f"lower must be smaller than upper (got lower={lower}, upper={upper})"
                )

            sqrt2 = math.sqrt(2)
            # erf(x / sqrt(2)) == 2 * Phi(x) - 1, so (b - a) / 2 is the
            # probability mass Z of the standard normal inside [lower, upper].
            a = math.erf(lower / sqrt2)
            b = math.erf(upper / sqrt2)
            z = (b - a) / 2

            # Standard normal density at each bound. pdf_u must be evaluated
            # at ``upper``: using ``lower`` for both only works by accident
            # when the bounds are symmetric.
            c = (2 * math.pi) ** -0.5
            pdf_u = c * math.exp(-0.5 * upper ** 2)
            pdf_l = c * math.exp(-0.5 * lower ** 2)

            # Variance of the truncated standard normal:
            #   1 - (u*phi(u) - l*phi(l)) / Z - ((phi(u) - phi(l)) / Z) ** 2
            # Dividing by its square root compensates for the shrinkage that
            # truncation causes, so the output really has std ``std``.
            comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)

            # Inverse-CDF sampling expressed on the erf scale.
            tensor.uniform_(a, b)
            tensor.erfinv_()
            tensor.mul_(sqrt2 * comp_std)
            # Guard against numerical overshoot at the edges of the interval.
            tensor.clip_(lower * comp_std, upper * comp_std)

    return tensor


def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float = 1e-5) -> torch.Tensor:
    """Apply RMS normalisation over the last dimension (no learnable scale).

    The computation is carried out in float32 and the result is cast back to
    the dtype of the input.

    Args:
        hidden_states: Tensor to normalise along its last dimension.
        variance_epsilon: Constant added to the mean square before the
            reciprocal square root, for numerical stability.

    Returns:
        Normalised tensor with the same shape and dtype as ``hidden_states``.
    """
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)

    variance = hidden_states.square().mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return hidden_states.to(input_dtype)


def rotate_half(x: torch.Tensor):
    """Rotate the two halves of the last dimension for RoPE.

    Splits the last dimension into ``(x1, x2)`` and returns ``(-x2, x1)``
    concatenated along that dimension.
    """
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary positional embeddings to queries and keys.

    ``cos`` and ``sin`` get a singleton axis inserted before their last
    dimension so they broadcast across the second-to-last axis of ``q``/``k``
    (the head axis in ``TransformerAttention``). The arithmetic is done in the
    dtype of ``cos`` and the results are cast back to the dtype of ``q``.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table, already sliced to the sequence length.
        sin: Sine table, already sliced to the sequence length.

    Returns:
        Tuple ``(q_embed, k_embed)``, both in the original dtype of ``q``.
    """
    orig_dtype = q.dtype
    q = q.to(cos.dtype)
    k = k.to(cos.dtype)

    q_embed = (q * cos.unsqueeze(-2)) + (rotate_half(q) * sin.unsqueeze(-2))
    k_embed = (k * cos.unsqueeze(-2)) + (rotate_half(k) * sin.unsqueeze(-2))

    return q_embed.to(orig_dtype), k_embed.to(orig_dtype)


class CastedLinear(nn.Module):
    """Linear layer with automatic dtype casting for mixed precision.

    The weight (and optional bias) is cast to the dtype of the incoming
    activation on every forward pass.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False):
        """Create the layer.

        Args:
            in_features: Size of the input's last dimension.
            out_features: Size of the output's last dimension.
            bias: Whether to add a zero-initialised learnable bias.
        """
        super().__init__()
        # Truncated-normal init with std = 1 / sqrt(in_features).
        self.weight = nn.Parameter(
            trunc_normal_init_(torch.empty((out_features, in_features)), std=1.0 / (in_features ** 0.5))
        )
        self.bias = None
        if bias:
            self.bias = nn.Parameter(torch.zeros((out_features, )))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input @ weight.T (+ bias)`` computed in ``input``'s dtype."""
        return F.linear(
            input,
            self.weight.to(input.dtype),
            bias=self.bias.to(input.dtype) if self.bias is not None else None,
        )


class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE).

    Precomputes cosine and sine tables of shape
    ``(max_position_embeddings, 2 * len(arange(0, dim, 2)))`` and stores them
    as non-persistent buffers (they are not written to the state dict).
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        """Build the cos/sin tables.

        Args:
            dim: Rotary dimension (the per-head dimension in
                ``TransformerAttention``).
            max_position_embeddings: Number of positions to precompute.
            base: Base of the geometric frequency progression.
            device: Device on which the tables are created.
        """
        super().__init__()

        # One inverse frequency per pair of channels.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
        t = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
        freqs = torch.outer(t, inv_freq)

        # Duplicate the angles so both halves used by rotate_half share them.
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('cos_cached', emb.cos(), persistent=False)
        self.register_buffer('sin_cached', emb.sin(), persistent=False)

    def forward(self):
        """Return the cached ``(cos, sin)`` tables."""
        return self.cos_cached, self.sin_cached


class SwiGLU(nn.Module):
    """SwiGLU activation (Swish + GLU) - from Samsung TRM.

    Computes ``down_proj(silu(gate) * up)`` where ``gate`` and ``up`` are the
    two halves of a single fused projection.
    """

    def __init__(self, hidden_size: int, expansion: float = 2.667):
        """Create the gated MLP.

        Args:
            hidden_size: Input and output feature size.
            expansion: Expansion factor; the intermediate width is
                ``round(expansion * hidden_size * 2 / 3)`` rounded up to the
                next multiple of 256.
        """
        super().__init__()
        inter = round(expansion * hidden_size * 2 / 3)
        inter = ((inter + 255) // 256) * 256  # Round to multiple of 256

        self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
        self.down_proj = CastedLinear(inter, hidden_size, bias=False)

    def forward(self, x):
        """Apply the gated feed-forward transformation to ``x``."""
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)


class TransformerAttention(nn.Module):
    """Multi-head attention with RoPE support.

    Uses a fused QKV projection and PyTorch's
    ``scaled_dot_product_attention`` with no attention mask.
    """

    def __init__(self, hidden_size: int, num_heads: int = 8, head_dim: int = 64):
        """Create the attention layer.

        Args:
            hidden_size: Feature size of the input and output.
            num_heads: Number of attention heads.
            head_dim: Feature size of each head.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.output_size = head_dim * num_heads

        self.qkv_proj = CastedLinear(hidden_size, num_heads * head_dim * 3, bias=False)
        self.o_proj = CastedLinear(self.output_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor, cos_sin=None) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Tensor of shape ``(B, S, hidden_size)``.
            cos_sin: Optional ``(cos, sin)`` pair as returned by
                ``RotaryEmbedding``; the first ``S`` rows of each are used.
                When ``None``, no positional rotation is applied.

        Returns:
            Tensor of shape ``(B, S, hidden_size)``.
        """
        B, S, _ = hidden_states.shape

        # Project to Q, K, V; heads are laid out as [Q heads | K heads | V heads].
        qkv = self.qkv_proj(hidden_states)
        qkv = qkv.view(B, S, self.num_heads * 3, self.head_dim)
        query = qkv[:, :, :self.num_heads]
        key = qkv[:, :, self.num_heads:self.num_heads * 2]
        value = qkv[:, :, self.num_heads * 2:]

        # Apply RoPE if provided
        if cos_sin is not None:
            cos, sin = cos_sin
            query, key = apply_rotary_pos_emb(query, key, cos[:S], sin[:S])

        # Attention (using PyTorch's optimized SDPA)
        query = query.transpose(1, 2)  # B, H, S, D
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        attn_output = F.scaled_dot_product_attention(query, key, value)

        # Back to (B, S, H * D) before the output projection.
        attn_output = attn_output.transpose(1, 2).reshape(B, S, self.output_size)
        return self.o_proj(attn_output)


class TransformerBlock(nn.Module):
    """Single transformer block with RMS norm and SwiGLU.

    Both sub-layers use pre-normalisation followed by a residual connection.
    """

    def __init__(self, hidden_size: int, num_heads: int = 8, expansion: float = 4.0, rms_eps: float = 1e-5):
        """Create the block.

        Args:
            hidden_size: Feature size of the residual stream.
            num_heads: Number of attention heads; the per-head size is
                ``hidden_size // num_heads``.
            expansion: Expansion factor forwarded to ``SwiGLU``.
            rms_eps: Epsilon forwarded to ``rms_norm``.
        """
        super().__init__()
        self.rms_eps = rms_eps

        self.attention = TransformerAttention(hidden_size, num_heads, hidden_size // num_heads)
        self.mlp = SwiGLU(hidden_size, expansion)

    def forward(self, x: torch.Tensor, cos_sin=None) -> torch.Tensor:
        """Apply attention and MLP sub-layers to ``x``.

        Args:
            x: Tensor of shape ``(B, S, hidden_size)``.
            cos_sin: Optional ``(cos, sin)`` pair passed through to the
                attention layer.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Attention with pre-norm
        h = rms_norm(x, self.rms_eps)
        h = self.attention(h, cos_sin)
        x = x + h

        # MLP with pre-norm
        h = rms_norm(x, self.rms_eps)
        h = self.mlp(h)
        x = x + h

        return x