import json
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutput

from .configuration_ministu import MiniSTUConfig

try:
    from flashfftconv import FlashFFTConv

    flash_fft_available = True
except ImportError as e:
    print(f"Unable to import FlashFFTConv: {e}. Falling back to PyTorch implementation.")
    flash_fft_available = False

try:
    from flash_attn import flash_attn_func
except ImportError as e:
    print(
        f"Unable to import Triton-based flash attention: {e}. No alternative currently available."
    )
|
|
|
|
def precompute_freqs_cis(head_dim: int, max_seq_len: int, theta: float = 10000.0):
    # One inverse frequency per pair of head dimensions.
    freq_seq = torch.arange(0, head_dim, 2).float() / head_dim
    freqs = 1.0 / (theta**freq_seq)

    # Outer product of positions and frequencies gives the rotation angles.
    t = torch.arange(max_seq_len, dtype=torch.float32)
    angles = torch.outer(t, freqs)

    # Represent each rotation as a unit complex number e^{i * angle}.
    freqs_cis = torch.polar(torch.ones_like(angles), angles)
    return freqs_cis
|
|
|
|
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """
    x is [B, seq_len, n_heads, head_dim_as_complex] (the flash-attn layout used in
    Attention.forward), so we want to broadcast freqs_cis from [max_seq_len, half_dim]
    to [1, seq_len, 1, half_dim].
    """
    seq_len = x.shape[1]
    freqs_cis = freqs_cis[:seq_len]
    return freqs_cis.view(1, seq_len, 1, -1)
|
|
|
|
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    # View the last dimension as complex pairs: (..., head_dim) -> (..., head_dim // 2).
    xq_complex = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_complex = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # Broadcast the per-position rotations over batch and heads.
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_complex)

    # Rotate queries and keys.
    xq_complex = xq_complex * freqs_cis
    xk_complex = xk_complex * freqs_cis

    # Back to real-valued tensors with the original shapes and dtypes.
    xq_out = torch.view_as_real(xq_complex).reshape(*xq.shape)
    xk_out = torch.view_as_real(xk_complex).reshape(*xk.shape)
    return xq_out.type_as(xq), xk_out.type_as(xk)
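
# Usage sketch for the RoPE helpers above (illustrative shapes only, not part of the model):
#   freqs_cis = precompute_freqs_cis(head_dim=64, max_seq_len=1024)   # (1024, 32), complex64
#   q = torch.randn(2, 128, 8, 64)   # (batch, seq_len, num_heads, head_dim)
#   k = torch.randn(2, 128, 8, 64)
#   q_rot, k_rot = apply_rotary_emb(q, k, freqs_cis)                  # same shapes as q, k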
|
|
|
|
def _generate_slopes(n: int):
    # Geometric sequence of ALiBi slopes for n attention heads.
    start = 2 ** (-(2 ** -(math.log2(n) - 3)))
    return [start * (start**i) for i in range(n)]


def _get_alibi_slopes(n_heads: int, interpolation_factor: float = 0.25, device: torch.device = None):
    # Module-level helper mirroring Attention._get_alibi_slopes.
    if math.log2(n_heads).is_integer():
        slopes = _generate_slopes(n_heads)
    else:
        # For non-power-of-two head counts, start from the nearest smaller power of two
        # and interleave extra slopes from the doubled head count.
        n = nearest_power_of_two(n_heads, round_up=False)
        slopes_power_of_two = _generate_slopes(n)

        extra_slopes = _generate_slopes(2 * n)
        extra_slopes_trunc = extra_slopes[0::2][: n_heads - n]
        slopes = slopes_power_of_two + extra_slopes_trunc
    slopes = torch.tensor(slopes, device=device)
    slopes = slopes * interpolation_factor
    return slopes
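
# Example: for 8 heads, start = 2 ** -1 = 0.5, so the raw slopes are 0.5 ** 1, ..., 0.5 ** 8;
# with interpolation_factor=0.25 each slope is then scaled by 0.25.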
|
|
|
|
def get_hankel(seq_len: int, use_hankel_L: bool = False) -> torch.Tensor:
    # Hankel matrix indexed by i + j (1-based), following the spectral filtering construction.
    entries = torch.arange(1, seq_len + 1, dtype=torch.float64)
    i_plus_j = entries[:, None] + entries[None, :]

    if use_hankel_L:
        sgn = (-1.0) ** (i_plus_j - 2.0) + 1.0
        denom = (i_plus_j + 3.0) * (i_plus_j - 1.0) * (i_plus_j + 1.0)
        Z = sgn * (8.0 / denom)
    else:
        Z = 2.0 / (i_plus_j**3 - i_plus_j)

    return Z
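
# Example: with use_hankel_L=False the (1, 1) entry is 2 / (2**3 - 2) = 1/3, since i + j = 2
# for the first row and column (1-based indexing).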
|
|
|
|
def get_spectral_filters(
    seq_len: int,
    K: int,
    use_hankel_L: bool = False,
    device: torch.device = None,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    # The top-K eigenvectors of the Hankel matrix, scaled by sigma**0.25, form the spectral filters.
    Z = get_hankel(seq_len, use_hankel_L).to(device)
    sigma, phi = torch.linalg.eigh(Z)
    sigma_k, phi_k = sigma[-K:], phi[:, -K:]
    phi_k *= sigma_k**0.25
    return phi_k.to(dtype=dtype)
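
# Usage sketch (illustrative sizes): get_spectral_filters(seq_len=512, K=24) eigendecomposes a
# 512 x 512 Hankel matrix and returns a (512, 24) tensor of sigma**0.25-scaled eigenvectors.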
|
|
def nearest_power_of_two(x: int, round_up: bool = False) -> int:
    return (
        1 << math.floor(math.log2(x)) if not round_up else 1 << math.ceil(math.log2(x))
    )
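
# Example: nearest_power_of_two(12) == 8, nearest_power_of_two(12, round_up=True) == 16.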
|
|
|
|
def convolve(u: torch.Tensor, v: torch.Tensor, n: int, use_approx: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
    bsz, seq_len, d_in = u.shape

    # Alternating sign sequence (+1, -1, +1, ...) along the sequence dimension.
    sgn = torch.full((1, seq_len, 1), 1, device=u.device)
    sgn[:, 1::2] *= -1
    if use_approx:
        _, d_out = v.shape
        v = v.view(1, -1, d_out, 1).to(torch.float32).contiguous()
    else:
        _, K = v.shape
        sgn = sgn.unsqueeze(-1)
        v = v.view(1, -1, K, 1, 1).to(torch.float32).contiguous()
        u = u.view(bsz, -1, 1, d_in).expand(bsz, -1, K, d_in)

    # FFT-based causal convolution of u (and its sign-flipped copy) with the filters.
    v = torch.fft.rfft(v, n=n, dim=1)

    U = torch.stack([u, u * sgn], dim=-1).to(torch.float32).contiguous()
    U = torch.fft.rfft(U, n=n, dim=1)
    U_conv = torch.fft.irfft(v * U, n=n, dim=1)[:, :seq_len]
    U_plus, U_minus = torch.unbind(U_conv, dim=-1)
    U_minus = U_minus * sgn

    return U_plus, U_minus
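
# Usage sketch (illustrative shapes): with u of shape (4, 16, 32), v of shape (16, 32),
# use_approx=True, and n = nearest_power_of_two(2 * 16 - 1, round_up=True) = 32, both
# U_plus and U_minus come back with shape (4, 16, 32).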
|
|
|
|
def flash_convolve(
    u: torch.Tensor, v: torch.Tensor, flash_fft: FlashFFTConv, use_approx: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Flash FFT convolution.

    Args:
        u (torch.Tensor): Input tensor of shape `(B, L, d_in)`, where:
            - `B` is the batch size,
            - `L` is the sequence length,
            - `d_in` is the input dimension.
        v (torch.Tensor): Filter tensor of shape `(K, d_in)`, where:
            - `K` is the number of filters,
            - `d_in` is the input dimension.
        flash_fft (FlashFFTConv): An instance of the FlashFFTConv module, used to perform the convolution.
        use_approx (bool, optional): If `True`, performs the tensordot approximation (default is `True`).

    Returns:
        tuple[torch.Tensor, torch.Tensor]: A tuple `(U_plus, U_minus)`:
            - `U_plus`: Convolved output tensor with positive eigenvalues.
                - Shape: `(B, L, d_in)` if `use_approx=True`, else `(B, L, K, d_in)`.
            - `U_minus`: Convolved output tensor with negative eigenvalues.
                - Shape: `(B, L, d_in)` if `use_approx=True`, else `(B, L, K, d_in)`.

    Raises:
        ValueError: If the input tensor shapes do not conform to the expected dimensions.

    Example:
        >>> u = torch.randn(4, 16, 32)  # (B, L, d_in)
        >>> v = torch.randn(8, 32)  # (K, d_in)
        >>> flash_fft = FlashFFTConv(n=16, dtype=torch.float32)
        >>> U_plus, U_minus = flash_convolve(u, v, flash_fft, use_approx=True)
        >>> print(U_plus.shape, U_minus.shape)
        torch.Size([4, 16, 32]) torch.Size([4, 16, 32])
    """
    bsz, seq_len, d_in = u.shape
    _, K = v.shape

    padded_len = nearest_power_of_two(seq_len, round_up=True)
    pad_len = padded_len - seq_len

    sgn = torch.full((1, 1, padded_len), 1, device=u.device)
    sgn[:, :, 1::2] = -1

    if use_approx:
        u_padded = F.pad(u.transpose(1, 2), (0, pad_len)).to(torch.bfloat16).contiguous()
        v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).to(torch.float32).contiguous()
        u_conv = torch.stack([u_padded, u_padded * sgn], dim=0).reshape(2 * bsz, d_in, padded_len)
    else:
        u_k_padded = F.pad(u.transpose(1, 2), (0, pad_len)).to(torch.bfloat16).repeat_interleave(K, dim=1).contiguous()
        v_padded = F.pad(v.transpose(0, 1), (0, pad_len)).to(torch.float32).repeat(d_in, 1).contiguous()
        u_conv = torch.stack([u_k_padded, u_k_padded * sgn], dim=0).reshape(2 * bsz, K * d_in, padded_len)

    U_conv = flash_fft(u_conv, v_padded)

    # Truncate the convolution output back to the original sequence length.
    U_conv = U_conv[..., :seq_len]

    u_plus, u_minus = torch.chunk(U_conv, 2, dim=0)

    if use_approx:
        u_minus = u_minus * sgn[:, :, :seq_len]
        U_plus, U_minus = u_plus.transpose(1, 2), u_minus.transpose(1, 2)
    else:
        sgn = sgn[:, :, :seq_len].unsqueeze(-1).transpose(1, 2)
        U_plus = u_plus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous()
        U_minus = u_minus.view(bsz, d_in, K, seq_len).permute(0, 3, 2, 1).contiguous() * sgn

    return U_plus, U_minus
|
|
|
|
class STU(nn.Module):
    def __init__(self, config, filters) -> None:
        super(STU, self).__init__()
        self.config = config
        self.stu_filters = filters
        self.n = nearest_power_of_two(config.seq_len * 2 - 1, round_up=True)
        self.K = config.num_eigh
        self.d_in = config.dim
        self.d_out = config.dim
        self.use_hankel_L = config.use_hankel_L
        self.use_approx = config.use_approx
        self.flash_fft = (
            FlashFFTConv(self.n, dtype=torch.bfloat16)
            if config.use_flash_fft and flash_fft_available
            else None
        )
        if self.use_approx:
            self.M_inputs = nn.Parameter(
                torch.empty(self.d_in, self.d_out, dtype=config.torch_dtype)
            )
            self.M_filters = nn.Parameter(
                torch.empty(self.K, self.d_in, dtype=config.torch_dtype)
            )
        else:
            self.M_phi_plus = nn.Parameter(
                torch.empty(self.K, self.d_in, self.d_out, dtype=config.torch_dtype)
            )
            if not self.use_hankel_L:
                self.M_phi_minus = nn.Parameter(
                    torch.empty(self.K, self.d_in, self.d_out, dtype=config.torch_dtype)
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_approx:
            # Tensordot approximation: project inputs and filters first, then convolve.
            x_proj = x @ self.M_inputs
            phi_proj = self.stu_filters @ self.M_filters
            if self.flash_fft:
                spectral_plus, spectral_minus = flash_convolve(
                    x_proj, phi_proj, self.flash_fft, self.use_approx
                )
            else:
                spectral_plus, spectral_minus = convolve(
                    x_proj, phi_proj, self.n, self.use_approx
                )
        else:
            # Convolve with the raw spectral filters, then contract with M_phi_{plus,minus}.
            if self.flash_fft:
                U_plus, U_minus = flash_convolve(
                    x, self.stu_filters, self.flash_fft, self.use_approx
                )
            else:
                U_plus, U_minus = convolve(x, self.stu_filters, self.n, self.use_approx)

            spectral_plus = torch.tensordot(
                U_plus, self.M_phi_plus, dims=([2, 3], [0, 1])
            )
            if not self.use_hankel_L:
                spectral_minus = torch.tensordot(
                    U_minus, self.M_phi_minus, dims=([2, 3], [0, 1])
                )

        return spectral_plus if self.use_hankel_L else spectral_plus + spectral_minus
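
# Shape sketch: for x of shape (B, L, dim), use_approx=True convolves the projected inputs and
# filters directly, giving (B, L, dim) spectral terms; use_approx=False produces (B, L, K, dim)
# convolutions that the tensordot contractions over (K, dim) reduce back to (B, L, dim).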
|
|
|
|
class STULayer(nn.Module):
    def __init__(self, config, stu_filters):
        super(STULayer, self).__init__()
        self.stu_norm = nn.RMSNorm(config.dim)
        self.stu = STU(config, stu_filters)
        self.mlp_norm = nn.RMSNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.stu(self.stu_norm(x))
        x = x + self.mlp(self.mlp_norm(x))
        return x
|
|
|
|
class Attention(nn.Module):
    def __init__(self, config):
        super(Attention, self).__init__()
        self.dim, self.num_heads = config.dim, config.num_heads
        assert config.dim % config.num_heads == 0, f"dim ({self.dim}) must be divisible by num_heads ({self.num_heads})"
        self.head_dim = config.dim // config.num_heads

        self.c_attn = nn.Linear(self.dim, 3 * self.dim, bias=config.bias)
        self.c_proj = nn.Linear(config.dim, config.dim, bias=config.bias)
        self.c_proj.SCALE_INIT = 1

        self.alibi_slopes = self._get_alibi_slopes(self.num_heads) if config.use_alibi else None
        self.window_size = config.window_size
        self.softcap = config.softcap

        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(self.dropout)

    def _generate_slopes(self, n: int):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * (start**i) for i in range(n)]

    def _get_alibi_slopes(self, num_heads: int, interpolation_factor: float = 0.25):
        # If the number of heads is a power of two, use the slopes directly.
        if math.log2(num_heads).is_integer():
            slopes = self._generate_slopes(num_heads)
        else:
            # Otherwise, start from the nearest smaller power of two
            # and interleave extra slopes from the doubled head count.
            n = nearest_power_of_two(num_heads, round_up=False)
            slopes_power_of_two = self._generate_slopes(n)

            extra_slopes = self._generate_slopes(2 * n)
            extra_slopes_trunc = extra_slopes[0::2][: num_heads - n]
            slopes = slopes_power_of_two + extra_slopes_trunc
        slopes = torch.tensor(slopes, device=torch.device("cuda"))
        slopes = slopes * interpolation_factor
        return slopes
|
|
    def forward(
        self,
        x: torch.Tensor = None,
        q: torch.Tensor = None,
        k: torch.Tensor = None,
        v: torch.Tensor = None,
        freqs_cis: torch.Tensor = None,
    ) -> torch.Tensor:
        if x is not None:
            q = k = v = x
        if any(t is None for t in [q, k, v]):
            raise ValueError("Must provide either x for self-attention or q/k/v for cross-attention.")

        bsz, q_len, dim = q.shape
        _, k_len, _ = k.shape
        _, v_len, _ = v.shape

        # NOTE: the fused QKV projection assumes self-attention, i.e. q == k == v.
        qkv = self.c_attn(q)
        q, k, v = torch.chunk(qkv, 3, dim=2)

        q = q.view(bsz, q_len, self.num_heads, self.head_dim)
        k = k.view(bsz, k_len, self.num_heads, self.head_dim)
        v = v.view(bsz, v_len, self.num_heads, self.head_dim)

        # Use either RoPE or ALiBi for positional information, never both.
        if self.alibi_slopes is None:
            q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

        y = flash_attn_func(
            q=q, k=k, v=v,
            dropout_p=self.dropout if self.training else 0.0,
            causal=True,
            window_size=(self.window_size, 0),
            alibi_slopes=self.alibi_slopes,
            softcap=self.softcap,
        )

        y = y.contiguous().view(bsz, q_len, -1)
        y = self.resid_dropout(self.c_proj(y))
        return y
|
|
|
|
class AttentionLayer(nn.Module):
    def __init__(self, config) -> None:
        super(AttentionLayer, self).__init__()
        self.attn_norm = nn.RMSNorm(config.dim)
        self.attn = Attention(config=config)
        self.mlp_norm = nn.RMSNorm(config.dim)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor = None) -> torch.Tensor:
        x = x + self.attn(x=self.attn_norm(x), freqs_cis=freqs_cis)
        x = x + self.mlp(self.mlp_norm(x))
        return x
|
|
class MLP(nn.Module):
    def __init__(self, config):
        # Gated MLP: GELU(gate_proj(x)) * up_proj(x), then down-projected.
        super().__init__()
        self.hidden_size = config.dim
        self.intermediate_size = config.dim * config.mlp_scale
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        gate = self.gate_proj(x)
        gate = F.gelu(gate, approximate="tanh")
        up = self.up_proj(x)
        fuse = gate * up
        outputs = self.down_proj(fuse)
        outputs = self.dropout(outputs)
        return outputs
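
# Shape sketch (illustrative config values): with dim=512 and mlp_scale=4, gate_proj and up_proj
# map (B, L, 512) -> (B, L, 2048), the GELU-gated product stays (B, L, 2048), and down_proj
# returns (B, L, 512).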
|
|
|
|
class MiniSTU(PreTrainedModel):
    config_class = MiniSTUConfig

    def __init__(self, config) -> None:
        super(MiniSTU, self).__init__(config)
        filters = get_spectral_filters(
            seq_len=config.seq_len,
            K=config.num_eigh,
            use_hankel_L=config.use_hankel_L,
            device=config.device,
            dtype=config.torch_dtype,
        )

        self.num_layers = config.num_layers
        assert config.dim % config.num_heads == 0, f"dim ({config.dim}) must be divisible by num_heads ({config.num_heads})"
        self.head_dim = config.dim // config.num_heads

        # Precompute RoPE frequencies once and keep them in the model's state dict.
        self.register_buffer(
            "freqs_cis",
            precompute_freqs_cis(
                head_dim=self.head_dim,
                max_seq_len=config.seq_len,
                theta=config.theta,
            ),
            persistent=True,
        )

        self.use_approx = config.use_approx
        self.use_hankel_L = config.use_hankel_L

        self.tok_emb = nn.Embedding(config.vocab_size, config.dim, dtype=config.torch_dtype)
        self.dropout = nn.Dropout(config.dropout)

        self.layers = nn.ModuleList()
        for layer_idx in range(config.num_layers):
            # Alternate STU and attention layers (attention only if config.use_attn is set).
            if layer_idx % 2 == 0:
                self.layers.append(STULayer(config, filters))
            else:
                self.layers.append(AttentionLayer(config) if config.use_attn else STULayer(config, filters))

        self.norm = nn.RMSNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=config.bias)

        if config.weight_tying:
            self.tok_emb.weight = self.lm_head.weight

        self.std = config.dim**-0.5
        self.apply(self._init_weights)
        print("Model Parameter Count: %.2fM\n" % (self._get_num_params() / 1e6,))
|
|
    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor = None,
        **kwargs,
    ) -> CausalLMOutput:
        # Token embeddings with dropout.
        tok_emb = self.tok_emb(input_ids)
        tok_emb = self.dropout(tok_emb)

        for layer in self.layers:
            # Only attention layers need the RoPE frequencies.
            if hasattr(layer, "attn"):
                tok_emb = layer(tok_emb, freqs_cis=self.freqs_cis)
            else:
                tok_emb = layer(tok_emb)

        # Final norm and language-modeling head.
        tok_emb = self.norm(tok_emb)
        logits = self.lm_head(tok_emb)

        loss = None
        if labels is not None:
            # Standard causal LM loss: predict token t+1 from tokens up to t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits,
        )
|
|
    def _get_num_params(self):
        n_params = sum(p.numel() for p in self.parameters())
        if hasattr(self, "pos_emb") and self.pos_emb is not None:
            n_params -= self.pos_emb.weight.numel()
        return n_params
|
|
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            # Scale down residual-output projections without mutating self.std across calls.
            std = self.std
            if hasattr(module, "SCALE_INIT"):
                std *= (2 * self.num_layers) ** -0.5
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.std)
        elif isinstance(module, Attention):
            torch.nn.init.xavier_normal_(module.c_attn.weight)
            torch.nn.init.xavier_normal_(module.c_proj.weight)
            if module.c_attn.bias is not None:
                torch.nn.init.zeros_(module.c_attn.bias)
            if module.c_proj.bias is not None:
                torch.nn.init.zeros_(module.c_proj.bias)
        elif isinstance(module, STU):
            if self.use_approx:
                torch.nn.init.xavier_normal_(module.M_inputs)
                torch.nn.init.xavier_normal_(module.M_filters)
            else:
                torch.nn.init.xavier_normal_(module.M_phi_plus)
                if not self.use_hankel_L:
                    torch.nn.init.xavier_normal_(module.M_phi_minus)
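
# Usage sketch (assumes MiniSTUConfig defines the fields referenced above, e.g. dim, num_heads,
# num_layers, seq_len, num_eigh, vocab_size, dropout; values are illustrative, not a recipe):
#   config = MiniSTUConfig()
#   model = MiniSTU(config)
#   input_ids = torch.randint(0, config.vocab_size, (1, 128))
#   out = model(input_ids=input_ids, labels=input_ids)
#   out.logits  # (1, 128, vocab_size); out.loss is the shifted next-token cross-entropy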
|
|
|
|