Jonttup committed on
Commit
14ed7e4
·
verified ·
1 Parent(s): 27fdf85

Upload models/transformer_layers.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/transformer_layers.py +171 -0
models/transformer_layers.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pure Transformer Layers (extracted from Samsung's TRM)
3
+
4
+ License: Apache 2.0
5
+ Source: https://github.com/Sam-Saarinen/TinyRecursiveModels
6
+ Attribution: Adapted from Samsung's Tiny Recursive Model (TRM) codebase
7
+ """
8
+ import math
9
+ from typing import Tuple
10
+ import torch
11
+ from torch import nn
12
+ import torch.nn.functional as F
13
+
14
+
15
def trunc_normal_init_(tensor: torch.Tensor, std: float = 1.0, lower: float = -2.0, upper: float = 2.0):
    """Truncated normal initialization from JAX/Flax (in-place).

    Fills `tensor` with samples from a normal distribution truncated to
    `[lower * comp_std, upper * comp_std]`, where `comp_std` is chosen so the
    truncated distribution's standard deviation equals `std`.

    Args:
        tensor: Tensor to initialize in-place.
        std: Target standard deviation of the truncated distribution.
        lower: Lower truncation bound, in (pre-correction) standard deviations.
        upper: Upper truncation bound, in (pre-correction) standard deviations.

    Returns:
        The same tensor, for call chaining.
    """
    with torch.no_grad():
        if std == 0:
            tensor.zero_()
        else:
            sqrt2 = math.sqrt(2)
            # erf values at the bounds; (b - a) / 2 == Phi(upper) - Phi(lower),
            # the probability mass of the untruncated normal inside the bounds.
            a = math.erf(lower / sqrt2)
            b = math.erf(upper / sqrt2)
            z = (b - a) / 2

            # Standard-normal pdf at the bounds, used in the variance
            # correction formula for a truncated normal.
            c = (2 * math.pi) ** -0.5
            pdf_u = c * math.exp(-0.5 * upper ** 2)  # BUGFIX: was `lower ** 2` (copy-paste);
            pdf_l = c * math.exp(-0.5 * lower ** 2)  # identical at symmetric defaults, wrong otherwise.
            comp_std = std / math.sqrt(1 - (upper * pdf_u - lower * pdf_l) / z - ((pdf_u - pdf_l) / z) ** 2)

            # Inverse-CDF sampling: uniform on [erf(l/sqrt2), erf(u/sqrt2)],
            # mapped back through erfinv, then scaled to the corrected std.
            tensor.uniform_(a, b)
            tensor.erfinv_()
            tensor.mul_(sqrt2 * comp_std)
            tensor.clip_(lower * comp_std, upper * comp_std)
    return tensor
36
+
37
+
38
def rms_norm(hidden_states: torch.Tensor, variance_epsilon: float = 1e-5) -> torch.Tensor:
    """Root-mean-square normalization over the last dimension.

    Computed in float32 for numerical stability and cast back to the input
    dtype. Cheaper than LayerNorm: no mean subtraction, no learned affine.
    """
    orig_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    inv_rms = torch.rsqrt(x.square().mean(dim=-1, keepdim=True) + variance_epsilon)
    return (x * inv_rms).to(orig_dtype)
45
+
46
+
47
def rotate_half(x: torch.Tensor):
    """Map (x1, x2) -> (-x2, x1) across the two halves of the last dim (RoPE)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
52
+
53
+
54
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Rotate query/key tensors by the given cos/sin position tables.

    The tables are broadcast over the heads axis (inserted at dim -2); math is
    done in the cos table's dtype and results are cast back to q/k's dtype.
    """
    out_dtype = q.dtype

    def _half_rot(t: torch.Tensor) -> torch.Tensor:
        # (t1, t2) -> (-t2, t1) across the two halves of the last dim.
        mid = t.shape[-1] // 2
        return torch.cat((-t[..., mid:], t[..., :mid]), dim=-1)

    cos_b = cos.unsqueeze(-2)
    sin_b = sin.unsqueeze(-2)
    q32 = q.to(cos.dtype)
    k32 = k.to(cos.dtype)
    q_rot = q32 * cos_b + _half_rot(q32) * sin_b
    k_rot = k32 * cos_b + _half_rot(k32) * sin_b
    return q_rot.to(out_dtype), k_rot.to(out_dtype)
64
+
65
+
66
class CastedLinear(nn.Module):
    """Linear layer whose weight/bias are cast to the input's dtype per call.

    Useful under mixed precision: parameters keep their storage dtype and are
    converted on the fly to match the activations. Weight is initialized with
    a truncated normal of std 1/sqrt(in_features); bias (if any) with zeros.
    """

    def __init__(self, in_features: int, out_features: int, bias: bool = False):
        super().__init__()
        init_std = 1.0 / (in_features ** 0.5)
        self.weight = nn.Parameter(
            trunc_normal_init_(torch.empty((out_features, in_features)), std=init_std)
        )
        if bias:
            self.bias = nn.Parameter(torch.zeros((out_features, )))
        else:
            self.bias = None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        w = self.weight.to(input.dtype)
        b = self.bias.to(input.dtype) if self.bias is not None else None
        return F.linear(input, w, bias=b)
80
+
81
+
82
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) with precomputed tables.

    cos/sin tables of shape (max_position_embeddings, dim) are built once at
    construction (as non-persistent buffers); forward() just returns them.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        # Per-pair inverse frequencies: base^(-2i/dim) for i = 0 .. dim/2 - 1.
        half_idx = torch.arange(0, dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (half_idx / dim))
        positions = torch.arange(max_position_embeddings, dtype=torch.float32, device=device)
        angles = torch.outer(positions, inv_freq)
        # Duplicate so the table spans the full head dim (pairs with rotate_half).
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer('cos_cached', table.cos(), persistent=False)
        self.register_buffer('sin_cached', table.sin(), persistent=False)

    def forward(self):
        return self.cos_cached, self.sin_cached
95
+
96
+
97
class SwiGLU(nn.Module):
    """SwiGLU feed-forward (SiLU-gated MLP), from Samsung TRM.

    Intermediate width is expansion * hidden * 2/3 (the usual GLU size
    compensation), rounded up to a multiple of 256 for hardware friendliness.
    """

    def __init__(self, hidden_size: int, expansion: float = 2.667):
        super().__init__()
        inter = round(expansion * hidden_size * 2 / 3)
        inter = ((inter + 255) // 256) * 256  # round up to a multiple of 256

        self.gate_up_proj = CastedLinear(hidden_size, inter * 2, bias=False)
        self.down_proj = CastedLinear(inter, hidden_size, bias=False)

    def forward(self, x):
        fused = self.gate_up_proj(x)
        gate, up = fused.chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)
110
+
111
+
112
class TransformerAttention(nn.Module):
    """Multi-head self-attention with optional RoPE, backed by SDPA.

    Uses a fused QKV projection; no attention mask or dropout is applied
    (F.scaled_dot_product_attention is called with defaults).
    """

    def __init__(self, hidden_size: int, num_heads: int = 8, head_dim: int = 64):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.output_size = head_dim * num_heads

        self.qkv_proj = CastedLinear(hidden_size, num_heads * head_dim * 3, bias=False)
        self.o_proj = CastedLinear(self.output_size, hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor, cos_sin=None) -> torch.Tensor:
        batch, seq_len, _ = hidden_states.shape
        nh = self.num_heads

        # Fused projection, then slice the stacked head axis into [q | k | v].
        qkv = self.qkv_proj(hidden_states).view(batch, seq_len, nh * 3, self.head_dim)
        query = qkv[:, :, :nh]
        key = qkv[:, :, nh:2 * nh]
        value = qkv[:, :, 2 * nh:]

        if cos_sin is not None:
            cos, sin = cos_sin
            # Trim the cached RoPE tables to the current sequence length.
            query, key = apply_rotary_pos_emb(query, key, cos[:seq_len], sin[:seq_len])

        # PyTorch's optimized SDPA expects (B, H, S, D).
        attn = F.scaled_dot_product_attention(
            query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
        )
        attn = attn.transpose(1, 2).reshape(batch, seq_len, self.output_size)
        return self.o_proj(attn)
149
+
150
+
151
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: RMSNorm -> attention -> residual, then
    RMSNorm -> SwiGLU MLP -> residual.

    Head dimension is derived as hidden_size // num_heads.
    """

    def __init__(self, hidden_size: int, num_heads: int = 8, expansion: float = 4.0, rms_eps: float = 1e-5):
        super().__init__()
        self.rms_eps = rms_eps
        self.attention = TransformerAttention(hidden_size, num_heads, hidden_size // num_heads)
        self.mlp = SwiGLU(hidden_size, expansion)

    def forward(self, x: torch.Tensor, cos_sin=None) -> torch.Tensor:
        # Attention sub-layer with pre-norm residual.
        x = x + self.attention(rms_norm(x, self.rms_eps), cos_sin)
        # Feed-forward sub-layer with pre-norm residual.
        x = x + self.mlp(rms_norm(x, self.rms_eps))
        return x