# palette-edit-classifier / models/transformer_layers.py
# Jonttup — "Upload models/transformer_layers.py with huggingface_hub" (commit 14ed7e4, verified)
"""
Pure Transformer Layers (extracted from Samsung's TRM)
License: Apache 2.0
Source: https://github.com/Sam-Saarinen/TinyRecursiveModels
Attribution: Adapted from Samsung's Tiny Recursive Model (TRM) codebase
"""
import math
from typing import Tuple
import torch
from torch import nn
import torch.nn.functional as F
def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
    """Fill ``tensor`` in place with a truncated normal, JAX/Flax style.

    Unlike ``torch.nn.init.trunc_normal_``, the underlying normal is widened by a
    compensation factor so that the *truncated* samples have standard deviation
    ``std``.

    Args:
        tensor: Floating-point tensor to overwrite in place (``uniform_`` and
            ``erfinv_`` require a floating dtype).
        std: Target standard deviation of the truncated samples. ``0`` zero-fills.
        lower: Lower truncation bound, in units of the underlying normal's std.
        upper: Upper truncation bound, in units of the underlying normal's std.

    Returns:
        The same ``tensor`` object, for chaining.
    """
    with torch.no_grad():
        if std == 0:
            tensor.zero_()
        else:
            sqrt2 = math.sqrt(2)
            # Bounds mapped to erf space: erf(x / sqrt2) = 2 * Phi(x) - 1.
            a = math.erf(lower / sqrt2)
            b = math.erf(upper / sqrt2)
            z = (b - a) / 2  # Phi(upper) - Phi(lower): probability mass kept by truncation

            c = (2 * math.pi) ** -0.5
            # Standard normal PDF at each bound. Each must use its own bound;
            # evaluating both at ``lower`` is only correct for symmetric bounds.
            pdf_u = c * math.exp(-0.5 * upper ** 2)
            pdf_l = c * math.exp(-0.5 * lower ** 2)
            # Variance of a standard normal truncated to [lower, upper].
            trunc_var = 1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2
            comp_std = std / math.sqrt(trunc_var)

            # Inverse-CDF sampling: uniform in erf space, then erfinv * sqrt2 gives a
            # standard normal restricted to [lower, upper]; scale by comp_std.
            tensor.uniform_(a, b)
            tensor.erfinv_()
            tensor.mul_(sqrt2 * comp_std)
            # Guard against erfinv(+-1) = +-inf and rounding just past the bounds.
            tensor.clip_(lower * comp_std, upper * comp_std)
    return tensor
def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float = 1e-5) -> torch.Tensor:
    """Rescale each vector along the last axis to unit root-mean-square.

    No mean-centering and no learnable gain are applied. The arithmetic is done
    in float32 and the result is cast back to the input's dtype.

    Args:
        hidden_states: Tensor to normalize over its last dimension.
        variance_epsilon: Added to the mean square before the reciprocal sqrt.

    Returns:
        Tensor of the same shape and dtype as ``hidden_states``.
    """
    original_dtype = hidden_states.dtype
    as_fp32 = hidden_states.float()
    mean_square = as_fp32.square().mean(dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + variance_epsilon)
    return (as_fp32 * inv_rms).to(original_dtype)
def rotate_half(x: torch.Tensor):
    """Swap the two halves of the last axis and negate the (new) first half.

    For a last axis ``[x1 | x2]`` this returns ``[-x2 | x1]``, the rotation used
    by rotary position embeddings. With an odd width, the first half is the
    shorter one (``width // 2`` elements).
    """
    width = x.shape[-1]
    split_at = width // 2
    front, back = torch.split(x, [split_at, width - split_at], dim=-1)
    return torch.cat((back.neg(), front), dim=-1)
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Rotate query and key vectors by their position-dependent RoPE angles.

    ``cos``/``sin`` get a singleton axis inserted at position -2 so they broadcast
    across that axis of ``q``/``k``. In this file q/k arrive as
    (batch, seq, heads, head_dim) and cos/sin as (seq, head_dim). Assumes that
    layout for other callers too; confirm before reusing.

    The rotation is computed in ``cos.dtype``. Both outputs are cast back to the
    dtype of ``q`` (including the key output).
    """
    out_dtype = q.dtype
    cos_b = cos.unsqueeze(-2)
    sin_b = sin.unsqueeze(-2)

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        t = t.to(cos.dtype)
        return (t * cos_b) + (rotate_half(t) * sin_b)

    q_rot = _rotate(q)
    k_rot = _rotate(k)
    return q_rot.to(out_dtype), k_rot.to(out_dtype)
class CastedLinear(nn.Module):
    """Linear layer whose parameters are cast to the input's dtype on every call.

    The master weight (and optional bias) keep their own dtype. Only the copies
    used in ``forward`` follow the activation dtype. Weights use
    ``trunc_normal_init_`` with std ``1/sqrt(in_features)`` and its default
    truncation bounds. The bias, when enabled, starts at zero.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False):
        super().__init__()
        init_std = 1.0 / (in_features ** 0.5)
        raw_weight = torch.empty((out_features, in_features))
        self.weight = nn.Parameter(trunc_normal_init_(raw_weight, std=init_std))
        self.bias = nn.Parameter(torch.zeros((out_features, ))) if bias else None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        target_dtype = input.dtype
        cast_bias = None if self.bias is None else self.bias.to(target_dtype)
        return F.linear(input, self.weight.to(target_dtype), bias=cast_bias)
class RotaryEmbedding(nn.Module):
    """Precomputed cosine/sine tables for rotary position embeddings (RoPE).

    Row ``p`` of each table holds the angles ``p * inv_freq``, with the
    frequency vector repeated twice so the row width matches ``rotate_half``'s
    two-half layout. The tables are non-persistent buffers: they move with
    ``.to(device)`` but are excluded from ``state_dict``.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim
        inv_freq = 1.0 / (base ** exponents)
        positions = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
        # Outer product via broadcasting: (positions, 1) * (freqs,) -> (positions, freqs)
        angles = positions[:, None] * inv_freq
        angles = angles.repeat(1, 2)
        self.register_buffer('cos_cached', angles.cos(), persistent=False)
        self.register_buffer('sin_cached', angles.sin(), persistent=False)

    def forward(self):
        """Return the full ``(cos, sin)`` tables; callers slice to their sequence length."""
        return self.cos_cached, self.sin_cached
class SwiGLU(nn.Module):
    """Gated feed-forward block: ``down(silu(gate) * up)``.

    The intermediate width is ``round(expansion * hidden_size * 2 / 3)`` rounded
    up to the next multiple of 256. Gate and up projections share one fused
    ``CastedLinear`` whose output is split in half.
    """

    def __init__(self, hidden_size: int, expansion: float = 2.667):
        super().__init__()
        raw_width = round(expansion * hidden_size * 2 / 3)
        intermediate = 256 * ((raw_width + 255) // 256)  # ceil to a multiple of 256
        self.gate_up_proj = CastedLinear(hidden_size, intermediate * 2, bias=False)
        self.down_proj = CastedLinear(intermediate, hidden_size, bias=False)

    def forward(self, x):
        fused = self.gate_up_proj(x)
        gate, up = fused.chunk(2, dim=-1)
        activated = F.silu(gate)
        return self.down_proj(activated * up)
class TransformerAttention(nn.Module):
    """Bidirectional multi-head self-attention with optional rotary embeddings.

    Q, K and V come from one fused projection. Attention runs through
    ``F.scaled_dot_product_attention`` with no mask and no causal flag, so every
    position attends to every other. ``num_heads * head_dim`` need not equal
    ``hidden_size``; ``o_proj`` maps back to ``hidden_size``.
    """

    def __init__(self, hidden_size: int, num_heads: int = 8, head_dim: int = 64):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.output_size = num_heads * head_dim
        self.qkv_proj = CastedLinear(hidden_size, self.output_size * 3, bias=False)
        self.o_proj = CastedLinear(self.output_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor, cos_sin=None) -> torch.Tensor:
        """Attend over ``hidden_states`` of shape (batch, seq, hidden_size).

        ``cos_sin``, when given, is a ``(cos, sin)`` pair indexed by position
        along its first axis. It is sliced to the sequence length, so it needs at
        least ``seq`` rows.
        """
        batch, seq_len, _ = hidden_states.shape
        heads = self.num_heads

        fused = self.qkv_proj(hidden_states).view(batch, seq_len, heads * 3, self.head_dim)
        query, key, value = fused.split(heads, dim=2)

        if cos_sin is not None:
            cos, sin = cos_sin
            query, key = apply_rotary_pos_emb(query, key, cos[:seq_len], sin[:seq_len])

        # SDPA expects (batch, heads, seq, head_dim).
        query, key, value = (t.transpose(1, 2) for t in (query, key, value))
        attended = F.scaled_dot_product_attention(query, key, value)
        merged = attended.transpose(1, 2).reshape(batch, seq_len, self.output_size)
        return self.o_proj(merged)
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention then SwiGLU MLP, each residual.

    Each sub-layer sees ``rms_norm`` of the running residual stream, and its
    output is added back. Per-head width is ``hidden_size // num_heads``.
    """

    def __init__(self, hidden_size: int, num_heads: int = 8, expansion: float = 4.0, rms_eps: float = 1e-5):
        super().__init__()
        self.rms_eps = rms_eps
        per_head = hidden_size // num_heads
        self.attention = TransformerAttention(hidden_size, num_heads, per_head)
        self.mlp = SwiGLU(hidden_size, expansion)

    def forward(self, x: torch.Tensor, cos_sin=None) -> torch.Tensor:
        attn_out = self.attention(rms_norm(x, self.rms_eps), cos_sin)
        x = x + attn_out
        mlp_out = self.mlp(rms_norm(x, self.rms_eps))
        return x + mlp_out