yagizdevre commited on
Commit
35b00e5
·
1 Parent(s): 06d90fc

Add model configuration (MiniSTUConfig), modeling code, and tokenizer files

Browse files
__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .configuration_ministu import MiniSTUConfig
2
+ from .modeling_ministu import MiniSTU
added_tokens.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "<|endofprompt|>": 200018
3
+ }
config.json ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "ministu",
3
+ "_name_or_path": "STU-426M",
4
+ "architectures": ["MiniSTU"],
5
+ "dim": 896,
6
+ "num_heads": 8,
7
+ "num_layers": 12,
8
+ "seq_len": 8192,
9
+ "weight_tying": true,
10
+ "window_size": 1024,
11
+ "vocab_size": 200064,
12
+ "mlp_scale": 12,
13
+ "bias": false,
14
+ "dropout": 0.0,
15
+ "num_eigh": 24,
16
+ "use_hankel_L": false,
17
+ "num_epochs": 1,
18
+ "global_bsz": 524288,
19
+ "bsz": 2,
20
+ "warmup_steps": 1907,
21
+ "eval_period": 50,
22
+ "save_period": 500,
23
+ "max_lr": 3.0e-4,
24
+ "min_lr": 3.0e-5,
25
+ "max_norm": 1.0,
26
+ "dilation": 2,
27
+ "fsdp": true,
28
+ "ddp": false,
29
+ "mixed_precision": true,
30
+ "torch_dtype": "bfloat16",
31
+ "cpu_offload": false,
32
+ "sharding_strategy": "full_shard",
33
+ "state_dict_type": "full",
34
+ "auto_wrap_policy": "partial",
35
+ "backward_prefetch": "backward_pre",
36
+ "forward_prefetch": false,
37
+ "sync_module_states": true,
38
+ "use_orig_params": true,
39
+ "device_id": null,
40
+ "precision": {
41
+ "param": "bfloat16",
42
+ "reduce": "bfloat16",
43
+ "buffer": "bfloat16"
44
+ },
45
+ "fsdp_modules": [
46
+ "STULayer",
47
+ "AttentionLayer"
48
+ ],
49
+ "use_activation_checkpointing": true,
50
+ "use_flash_fft": true,
51
+ "use_approx": true,
52
+ "use_attn": true,
53
+ "softcap": 50.0,
54
+ "theta": 10000.0,
55
+ "use_alibi": false,
56
+ "torch_compile": false
57
+ }
configuration_ministu.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import PretrainedConfig, AutoConfig
3
+
4
class MiniSTUConfig(PretrainedConfig):
    """Hugging Face configuration for the MiniSTU spectral-transform language model."""

    model_type = "ministu"

    def __init__(
        self,
        bsz: int = 1,
        dim: int = 896,
        num_heads: int = 8,
        num_layers: int = 12,
        seq_len: int = 8192,
        weight_tying: bool = False,
        window_size: int = 1024,
        vocab_size: int = 200064,
        mlp_scale: int = 12,
        bias: bool = False,
        dropout: float = 0.0,
        num_eigh: int = 24,
        use_hankel_L: bool = False,
        use_flash_fft: bool = True,
        use_approx: bool = True,
        use_attn: bool = True,
        softcap: float = 50.0,
        theta: float = 10_000.0,
        use_alibi: bool = False,
        dilation: int = 2,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: torch.device = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Batch size and core transformer geometry.
        self.bsz = bsz
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.seq_len = seq_len
        self.weight_tying = weight_tying
        self.window_size = window_size
        self.vocab_size = vocab_size
        # HF-conventional aliases derived from dim / mlp_scale.
        self.hidden_size = dim
        self.mlp_scale = mlp_scale
        self.intermediate_size = self.hidden_size * self.mlp_scale
        self.bias = bias
        self.dropout = dropout
        # Spectral (STU) settings.
        self.num_eigh = num_eigh
        self.use_hankel_L = use_hankel_L
        self.use_flash_fft = use_flash_fft
        self.use_approx = use_approx
        # Attention settings.
        self.use_attn = use_attn
        self.softcap = softcap
        self.theta = theta
        self.use_alibi = use_alibi
        # NOTE(review): `dilation` is accepted but never stored — confirm whether
        # it should be persisted on the config.
        self.torch_dtype = torch_dtype
        self.device = device
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
modeling_ministu.py ADDED
@@ -0,0 +1,536 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from transformers import PreTrainedModel, PretrainedConfig
8
+ from .configuration_ministu import MiniSTUConfig
9
+ try:
10
+ from flashfftconv import FlashFFTConv
11
+
12
+ flash_fft_available = True
13
+ except ImportError as e:
14
+ print(f"Unable to import FlashFFTConv: {e}. Falling back to PyTorch implementation.")
15
+ flash_fft_available = False
16
+
17
+ try:
18
+ from flash_attn import flash_attn_func
19
+ except ImportError as e:
20
+ print(
21
+ f"Unable to import Triton-based flash attention: {e}. No alternative currently available."
22
+ )
23
+
24
+
25
def precompute_freqs_cis(head_dim: int, max_seq_len: int, theta: float = 10000.0):
    """Precompute the complex rotary-embedding table of shape (max_seq_len, head_dim // 2)."""
    # Inverse frequencies for the even dimensions: 1 / theta^(2j / d).
    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (theta ** exponents)

    # Phase angle for every (position, frequency) pair.
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    phase = torch.outer(positions, inv_freq)

    # Unit-magnitude complex exponentials e^{i * phase}.
    return torch.polar(torch.ones_like(phase), phase)
40
+
41
+
42
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    Trim the (max_seq_len, half_dim) frequency table to x's sequence length
    (x.shape[2]) and reshape to (1, 1, seq_len, half_dim) so it broadcasts
    against x of shape [B, n_heads, seq_len, head_dim_as_complex].
    """
    current_len = x.shape[2]
    trimmed = freqs_cis[:current_len]  # slice down to current seq_len
    return trimmed.view(1, 1, current_len, -1)
51
+
52
+
53
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query/key tensors by the complex RoPE table; outputs keep the input shapes and dtypes."""

    def _to_complex(t: torch.Tensor) -> torch.Tensor:
        # Pair up the last dim: [..., head_dim] -> complex [..., head_dim // 2].
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = _to_complex(xq)
    k_complex = _to_complex(xk)

    # Broadcast the table to [1, 1, seq_len, head_dim // 2], then rotate.
    rotation = reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = q_complex * rotation
    k_rotated = k_complex * rotation

    # Back to real-valued tensors of the original shape.
    q_out = torch.view_as_real(q_rotated).reshape(*xq.shape)
    k_out = torch.view_as_real(k_rotated).reshape(*xk.shape)
    return q_out.type_as(xq), k_out.type_as(xk)
74
+
75
+
76
+ def _generate_slopes(self, n: int):
77
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
78
+ return [start * (start**i) for i in range(n)]
79
+
80
+ def _get_alibi_slopes(self, n_heads: int, interpolation_factor: float = 0.25):
81
+ # If n_heads is a power of 2, generate slopes directly
82
+ if math.log2(n_heads).is_integer():
83
+ slopes = self._generate_slopes(n_heads)
84
+ else:
85
+ # Get slopes for the nearest power of two
86
+ n = nearest_power_of_two(n_heads, round_up=False)
87
+ slopes_power_of_two = self._generate_slopes(n)
88
+
89
+ # Generate extra slopes
90
+ extra_slopes = self._generate_slopes(2 * n)
91
+ extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
92
+ slopes = slopes_power_of_two + extra_slopes_trunc
93
+ slopes = torch.tensor(slopes, device=self.device)
94
+ slopes = slopes * interpolation_factor # https://arxiv.org/pdf/2310.13017
95
+ return slopes
96
+
97
+
98
def get_hankel(seq_len: int, use_hankel_L: bool = False) -> torch.Tensor:
    """
    Build the (seq_len, seq_len) float64 Hankel matrix over 1-based indices.

    use_hankel_L=False: Z[i, j] = 2 / ((i+j)^3 - (i+j)).
    use_hankel_L=True:  Hankel-L variant 8 * sgn / ((i+j+3)(i+j-1)(i+j+1)),
                        where sgn = (-1)^(i+j-2) + 1.

    Fixed: the original if/elif/else ended in an unreachable
    `raise ValueError` — `use_hankel_L` / `not use_hankel_L` already cover
    every value, so the dead branch is removed.
    """
    entries = torch.arange(1, seq_len + 1, dtype=torch.float64)
    i_plus_j = entries[:, None] + entries[None, :]

    if use_hankel_L:
        sgn = (-1.0) ** (i_plus_j - 2.0) + 1.0
        denom = (i_plus_j + 3.0) * (i_plus_j - 1.0) * (i_plus_j + 1.0)
        return sgn * (8.0 / denom)
    return 2.0 / (i_plus_j**3 - i_plus_j)
112
+
113
+
114
def get_spectral_filters(
    seq_len: int,
    K: int,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Top-K eigenvectors of the Hankel matrix, scaled by sigma^(1/4); shape (seq_len, K)."""
    hankel = get_hankel(seq_len, use_hankel_L).to(device)
    eigvals, eigvecs = torch.linalg.eigh(hankel)  # eigenvalues in ascending order
    top_vals = eigvals[-K:]
    top_vecs = eigvecs[:, -K:]
    top_vecs *= top_vals ** 0.25
    return top_vecs.to(dtype=dtype)
126
+
127
def nearest_power_of_two(x: int, round_up: bool = False) -> int:
    """Return the nearest power of two at or below x (or at or above, if round_up)."""
    exponent = math.ceil(math.log2(x)) if round_up else math.floor(math.log2(x))
    return 1 << exponent
131
+
132
+
133
def convolve(u: torch.Tensor, v: torch.Tensor, n: int, use_approx: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
    """
    FFT-based causal convolution of inputs u (bsz, seq_len, d_in) with
    spectral filters v, using FFT length n.

    Returns (U_plus, U_minus): convolutions against the filters and against
    the alternating-sign ("negative eigenvalue") filters respectively.
    """
    bsz, seq_len, d_in = u.shape

    # Alternating +1/-1 mask along the sequence dimension.
    sgn = torch.full((1, seq_len, 1), 1, device=u.device)
    sgn[:, 1::2] *= -1

    if use_approx:
        _, d_out = v.shape
        v = v.view(1, -1, d_out, 1).to(torch.float32).contiguous()
    else:
        _, K = v.shape
        sgn = sgn.unsqueeze(-1)
        # (bsz, seq_len, K, d_in, stack)
        v = v.view(1, -1, K, 1, 1).to(torch.float32).contiguous()
        u = u.view(bsz, -1, 1, d_in).expand(bsz, -1, K, d_in)

    v = torch.fft.rfft(v, n=n, dim=1)

    # Stack the raw and sign-flipped inputs so one FFT covers both branches.
    stacked = torch.stack([u, u * sgn], dim=-1).to(torch.float32).contiguous()
    spectrum = torch.fft.rfft(stacked, n=n, dim=1)
    conv = torch.fft.irfft(v * spectrum, n=n, dim=1)[:, :seq_len]

    U_plus, U_minus = torch.unbind(conv, dim=-1)
    return U_plus, U_minus * sgn
156
+
157
+
158
def flash_convolve(
    u: torch.Tensor, v: torch.Tensor, flash_fft: "FlashFFTConv", use_approx: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Flash FFT convolution.

    Fixed: the `flash_fft` annotation was the bare name `FlashFFTConv`, which is
    evaluated at definition time — when the guarded `flashfftconv` import above
    fails, importing this module raised NameError despite the fallback. The
    annotation is now a string, so the module loads without the package.

    Args:
        u (torch.Tensor): Input tensor of shape `(B, L, d_in)`, where:
            - `B` is the batch size,
            - `L` is the sequence length,
            - `d_in` is the input dimension.
        v (torch.Tensor): Filter tensor of shape `(K, d_in)`, where:
            - `K` is the number of filters,
            - `d_in` is the input dimension.
        flash_fft (FlashFFTConv): An instance of the FlashFFTConv module, used to perform the convolution.
        use_approx (bool, optional): If `True`, performs the tensordot approximation (default is `True`).

    Returns:
        tuple[torch.Tensor, torch.Tensor]: A tuple `(U_plus, U_minus)`:
            - `U_plus`: Convolved output tensor with positive eigenvalues.
                - If `use_approx=True`: `(B, L, d_in)`; else `(B, L, K, d_in)`
            - `U_minus`: Convolved output tensor with negative eigenvalues.
                - If `use_approx=True`: `(B, L, d_in)`; else `(B, L, K, d_in)`

    Raises:
        ValueError: If the input tensor shapes do not conform to the expected dimensions.
    """
    bsz, seq_len, d_in = u.shape
    _, K = v.shape

    # FlashFFTConv needs power-of-two lengths; zero-pad up to the next one.
    padded_len = nearest_power_of_two(seq_len, round_up=True)
    pad_len = padded_len - seq_len

    # Alternating +1/-1 mask along the (padded) sequence dimension.
    sgn = torch.full((1, 1, padded_len), 1, device=u.device)
    sgn[:, :, 1::2] = -1

    if use_approx:
        u_padded = F.pad(u.transpose(1, 2), (0, pad_len)).to(torch.bfloat16).contiguous()
        v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).to(torch.float32).contiguous()
        # Batch the raw and sign-flipped inputs so one conv covers both branches.
        u_conv = torch.stack([u_padded, u_padded * sgn], dim=0).reshape(2 * bsz, d_in, padded_len)
    else:
        u_k_padded = F.pad(u.transpose(1, 2), (0, pad_len)).to(torch.bfloat16).repeat_interleave(K, dim=1).contiguous()
        v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).to(torch.float32).repeat(d_in, 1).contiguous()
        u_conv = torch.stack([u_k_padded, u_k_padded * sgn], dim=0).reshape(2 * bsz, K * d_in, padded_len)

    U_conv = flash_fft(u_conv, v_padded)

    # Trim the output back to the original sequence length.
    U_conv = U_conv[..., :seq_len]

    u_plus, u_minus = torch.chunk(U_conv, 2, dim=0)

    if use_approx:
        u_minus = u_minus * sgn[:, :, :seq_len]
        U_plus, U_minus = u_plus.transpose(1, 2), u_minus.transpose(1, 2)
    else:
        sgn = sgn[:, :, :seq_len].unsqueeze(-1).transpose(1, 2)
        U_plus = u_plus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous()
        U_minus = u_minus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous() * sgn

    return U_plus, U_minus
231
+
232
+
233
class STU(nn.Module):
    # Spectral Transform Unit: convolves inputs with fixed spectral filters
    # via FFT, with learned mixing matrices applied before (approx mode) or
    # after (full mode) the convolution.
    def __init__(self, config, filters) -> None:
        super(STU, self).__init__()
        self.config = config
        self.stu_filters = filters  # fixed spectral filters; not a Parameter, so not trained
        # FFT length: next power of two covering a full linear convolution of seq_len.
        self.n = nearest_power_of_two(config.seq_len * 2 - 1, round_up=True)
        self.K = config.num_eigh  # number of spectral filters
        self.d_in = config.dim
        self.d_out = config.dim
        self.use_hankel_L = config.use_hankel_L
        self.use_approx = config.use_approx
        # Optional fused FFT-conv kernel; falls back to the PyTorch path when
        # the package is missing or disabled in config.
        self.flash_fft = (
            FlashFFTConv(self.n, dtype=torch.bfloat16)
            if config.use_flash_fft and flash_fft_available
            else None
        )  # TODO: Buggy with torch.compile, need to write a custom op wrapper
        if self.use_approx:
            # Tensordot approximation: mix inputs (d_in x d_out) and filters
            # (K x d_in) *before* the convolution.
            self.M_inputs = nn.Parameter(
                torch.empty(self.d_in, self.d_out, dtype=config.torch_dtype)
            )
            self.M_filters = nn.Parameter(
                torch.empty(self.K, self.d_in, dtype=config.torch_dtype)
            )
        else:
            # Full parameterization: one (K, d_in, d_out) mixing tensor per
            # eigenvalue sign; the minus branch is skipped under Hankel-L.
            self.M_phi_plus = nn.Parameter(
                torch.empty(self.K, self.d_in, self.d_out, dtype=config.torch_dtype)
            )
            if not self.use_hankel_L:
                self.M_phi_minus = nn.Parameter(
                    torch.empty(self.K, self.d_in, self.d_out, dtype=config.torch_dtype)
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (bsz, seq_len, dim) — matches convolve's input contract.
        if self.use_approx:
            # Contract inputs and filters over the K and d_in dimensions, then convolve
            x_proj = x @ self.M_inputs
            phi_proj = self.stu_filters @ self.M_filters
            if self.flash_fft:
                spectral_plus, spectral_minus = flash_convolve(
                    x_proj, phi_proj, self.flash_fft, self.use_approx
                )
            else:
                spectral_plus, spectral_minus = convolve(
                    x_proj, phi_proj, self.n, self.use_approx
                )
        else:
            # Convolve inputs and filters,
            if self.flash_fft:
                U_plus, U_minus = flash_convolve(
                    x, self.stu_filters, self.flash_fft, self.use_approx
                )
            else:
                U_plus, U_minus = convolve(x, self.stu_filters, self.n, self.use_approx)
            # Then, contract over the K and d_in dimensions
            spectral_plus = torch.tensordot(
                U_plus, self.M_phi_plus, dims=([2, 3], [0, 1])
            )
            if not self.use_hankel_L:
                spectral_minus = torch.tensordot(
                    U_minus, self.M_phi_minus, dims=([2, 3], [0, 1])
                )

        # Hankel-L uses only the positive-eigenvalue branch.
        return spectral_plus if self.use_hankel_L else spectral_plus + spectral_minus
296
+
297
+
298
class STULayer(nn.Module):
    """Pre-norm residual block: STU spectral mixing followed by an MLP."""

    def __init__(self, config, stu_filters):
        super(STULayer, self).__init__()
        self.stu_norm = nn.RMSNorm(config.dim)
        self.stu = STU(config, stu_filters)
        self.mlp_norm = nn.RMSNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = x + self.stu(self.stu_norm(x))
        return hidden + self.mlp(self.mlp_norm(hidden))
310
+
311
+
312
class Attention(nn.Module):
    """
    Multi-head causal attention with a fused QKV projection, sliding-window
    flash attention, and either RoPE or ALiBi positional encoding.
    """

    def __init__(self, config):
        super(Attention, self).__init__()
        self.dim, self.num_heads = config.dim, config.num_heads
        assert config.dim % config.num_heads == 0, f"dim ({self.dim}) must be divisible num_heads ({self.num_heads})"
        self.head_dim = config.dim // config.num_heads

        # Fused projection producing Q, K, V in a single matmul.
        self.c_attn = nn.Linear(self.dim, 3 * self.dim, bias=config.bias)
        self.c_proj = nn.Linear(config.dim, config.dim, bias=config.bias)
        self.c_proj.SCALE_INIT = 1  # marker consumed by the model's _init_weights

        self.alibi_slopes = self._get_alibi_slopes(self.num_heads) if config.use_alibi else None
        self.window_size = config.window_size
        self.softcap = config.softcap

        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(self.dropout)

    def _generate_slopes(self, n: int):
        # Geometric sequence start, start^2, ..., start^n used as per-head slopes.
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start**i) for i in range(n)]

    def _get_alibi_slopes(self, num_heads: int, interpolation_factor: float = 0.25):
        """ALiBi slopes (https://arxiv.org/pdf/2108.12409), scaled per https://arxiv.org/pdf/2310.13017."""
        # If num_heads is a power of 2, generate slopes directly.
        if math.log2(num_heads).is_integer():
            slopes = self._generate_slopes(num_heads)
        else:
            # Get slopes for the nearest power of two, then interleave extras
            # from the next power of two to cover the remaining heads.
            n = nearest_power_of_two(num_heads, round_up=False)
            slopes_power_of_two = self._generate_slopes(n)
            extra_slopes = self._generate_slopes(2 * n)
            extra_slopes_trunc = extra_slopes[0::2][: num_heads - n]
            slopes = slopes_power_of_two + extra_slopes_trunc
        # Fixed: the original hard-coded device=torch.device("cuda"), which
        # crashed module construction on CPU-only hosts.
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        slopes = torch.tensor(slopes, device=device)
        return slopes * interpolation_factor

    def forward(
        self,
        x: torch.Tensor = None,
        q: torch.Tensor = None,
        k: torch.Tensor = None,
        v: torch.Tensor = None,
        freqs_cis: torch.Tensor = None,
    ) -> torch.Tensor:
        if x is None:
            # Fixed: the original fell through to self.c_attn(x) with x=None and
            # crashed with a TypeError; the fused QKV projection can only serve
            # self-attention, so reject distinct q/k/v explicitly.
            if any(t is None for t in (q, k, v)):
                raise ValueError("Must provide either x for self-attention or q/k/v for cross-attention.")
            if not (q is k and k is v):
                raise NotImplementedError(
                    "c_attn is a fused QKV projection; distinct q/k/v cross-attention is not supported."
                )
            x = q

        bsz, seq_len, _ = x.shape

        qkv = self.c_attn(x)
        q, k, v = torch.chunk(qkv, 3, dim=2)

        # (B, L, H, head_dim) layout expected by flash_attn_func.
        q = q.view(bsz, seq_len, self.num_heads, self.head_dim)
        k = k.view(bsz, seq_len, self.num_heads, self.head_dim)
        v = v.view(bsz, seq_len, self.num_heads, self.head_dim)

        if self.alibi_slopes is None:  # Use either ALiBi or RoPE
            q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

        y = flash_attn_func(  # https://arxiv.org/pdf/2307.08691
            q=q, k=k, v=v,
            dropout_p=self.dropout if self.training else 0.0,
            causal=True,
            window_size=(self.window_size, 0),  # Set to config.seq_len if full attention
            alibi_slopes=self.alibi_slopes,  # https://arxiv.org/pdf/2108.12409
            softcap=self.softcap,  # https://arxiv.org/pdf/2408.00118
        )

        y = y.contiguous().view(bsz, seq_len, -1)
        return self.resid_dropout(self.c_proj(y))
390
+
391
+
392
class AttentionLayer(nn.Module):
    """Pre-norm residual block: attention followed by an MLP."""

    def __init__(self, config) -> None:
        super(AttentionLayer, self).__init__()
        self.attn_norm = nn.RMSNorm(config.dim)
        self.attn = Attention(config=config)
        self.mlp_norm = nn.RMSNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor = None) -> torch.Tensor:
        hidden = x + self.attn(x=self.attn_norm(x), freqs_cis=freqs_cis)
        return hidden + self.mlp(self.mlp_norm(hidden))
404
+
405
class MLP(nn.Module):
    """Gated MLP with tanh-approximated GELU gating; see https://arxiv.org/pdf/2002.05202."""

    def __init__(self, config):
        super().__init__()
        width = config.dim
        inner = config.dim * config.mlp_scale
        self.hidden_size = width
        self.intermediate_size = inner
        self.gate_proj = nn.Linear(width, inner, bias=config.bias)
        self.up_proj = nn.Linear(width, inner, bias=config.bias)
        self.down_proj = nn.Linear(inner, width, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # gelu(gate) * up, projected back down, then dropout.
        gated = F.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
424
+
425
+
426
class MiniSTU(PreTrainedModel):
    """
    Causal language model interleaving STU (spectral) layers with attention
    layers: even-indexed layers are STULayers; odd-indexed layers are
    AttentionLayers when config.use_attn is set, otherwise STULayers.
    """

    config_class = MiniSTUConfig

    def __init__(self, config, filters) -> None:
        super(MiniSTU, self).__init__(config)
        self.num_layers = config.num_layers
        # Fixed: the original assertion message read self.dim / self.num_heads,
        # neither of which exists on MiniSTU — a failing assert would itself
        # raise AttributeError instead of the intended message.
        assert config.dim % config.num_heads == 0, (
            f"dim ({config.dim}) must be divisible num_heads ({config.num_heads})"
        )
        self.head_dim = config.dim // config.num_heads

        # From pytorch/pytorch#123411, we set persistent=True for torch.compile and PP compatibility
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                head_dim=self.head_dim,
                max_seq_len=config.seq_len,
                theta=config.theta,
            ),
            persistent=True,
        )

        self.use_approx = config.use_approx
        self.use_hankel_L = config.use_hankel_L

        self.tok_emb = nn.Embedding(config.vocab_size, config.dim, dtype=config.torch_dtype)
        self.dropout = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList()
        for layer_idx in range(config.num_layers):
            # For more complex %-split arrangements, see https://arxiv.org/pdf/2406.07887
            if layer_idx % 2 == 0:
                self.layers.append(STULayer(config, filters))
            else:
                self.layers.append(AttentionLayer(config) if config.use_attn else STULayer(config, filters))

        self.norm = nn.RMSNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=config.bias)

        if config.weight_tying:
            self.tok_emb.weight = self.lm_head.weight

        self.std = config.dim**-0.5
        self.apply(self._init_weights)
        print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor = None,
        **kwargs
    ) -> "CausalLMOutput":
        """Run the decoder; returns a CausalLMOutput with logits and, if labels are given, the LM loss."""
        # Fixed: CausalLMOutput was referenced but never imported anywhere in
        # the module (NameError on every forward); import it where it is used
        # and quote the return annotation so the class definition doesn't
        # evaluate the missing name.
        from transformers.modeling_outputs import CausalLMOutput

        # Compute embeddings
        tok_emb = self.tok_emb(input_ids)
        tok_emb = self.dropout(tok_emb)

        for layer in self.layers:
            # Attention layers need the rotary table; STU layers do not.
            if hasattr(layer, "attn"):
                tok_emb = layer(tok_emb, freqs_cis=self.freqs_cis)
            else:
                tok_emb = layer(tok_emb)

        # Normalize and project to vocabulary
        tok_emb = self.norm(tok_emb)
        logits = self.lm_head(tok_emb)

        loss = None
        if labels is not None:
            # Shift so that tokens predict the next token
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )

    def _get_num_params(self):
        """Total parameter count, excluding a positional-embedding table if one exists."""
        n_params = sum(p.numel() for p in self.parameters())
        if hasattr(self, "pos_emb") and self.pos_emb is not None:
            n_params -= self.pos_emb.weight.numel()
        return n_params

    def _init_weights(self, module):
        """Initialize weights; SCALE_INIT-marked linears get a residual-scaled std."""
        if isinstance(module, nn.Linear):
            std = self.std
            if hasattr(module, "SCALE_INIT"):
                # Fixed: the original did `self.std *= (2 * num_layers) ** -0.5`,
                # permanently shrinking the stored std so each subsequent
                # SCALE_INIT module (and every later embedding) was initialized
                # with a smaller and smaller scale. Use a local value instead.
                std = std * (2 * self.num_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
        elif isinstance(module, Attention):
            torch.nn.init.xavier_normal_(module.c_attn.weight)
            torch.nn.init.xavier_normal_(module.c_proj.weight)
            if module.c_attn.bias is not None:
                torch.nn.init.zeros_(module.c_attn.bias)
            if module.c_proj.bias is not None:
                torch.nn.init.zeros_(module.c_proj.bias)
        elif isinstance(module, STU):
            if self.use_approx:
                torch.nn.init.xavier_normal_(module.M_inputs)
                torch.nn.init.xavier_normal_(module.M_filters)
            else:
                torch.nn.init.xavier_normal_(module.M_phi_plus)
                if not self.use_hankel_L:
                    torch.nn.init.xavier_normal_(module.M_phi_minus)
536
+
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "199999": {
5
+ "content": "<|endoftext|>",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "200018": {
13
+ "content": "<|endofprompt|>",
14
+ "lstrip": false,
15
+ "normalized": false,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ }
20
+ },
21
+ "bos_token": "<|endoftext|>",
22
+ "clean_up_tokenization_spaces": false,
23
+ "eos_token": "<|endoftext|>",
24
+ "model_max_length": 128000,
25
+ "tokenizer_class": "GPT2Tokenizer",
26
+ "unk_token": "<|endoftext|>"
27
+ }
vocab.json ADDED
The diff for this file is too large to render. See raw diff