| import math |
| from typing import Optional, Tuple, TypeVar |
| import torch.nn as nn |
| import torch |
| import triton |
|
|
| from functools import lru_cache |
|
|
|
|
| from .triton_flash_blocksparse_attn import get_local_strided_sparse_attention_op, _get_sparse_attn_mask, blocksparse_flash_attn_padded_fwd, blocksparse_flash_attn_varlen_fwd |
|
|
|
|
# Sparse layout handed to the block-sparse kernels: a pair of LongTensors.
# NOTE(review): presumably CSR-style (row-pointer, column-index) tensors
# produced by `_get_sparse_attn_mask` -- confirm against that helper.
Layout = Tuple[torch.LongTensor, torch.LongTensor]
|
|
|
|
def create_sparse_attn_mask(
    n_heads: int,
    max_seq_len: int,
    max_seq_len_k: int,
    dtype: torch.dtype,
    device: torch.device,
    BLOCK: int,
    local_blocks: int,
    vert_stride: int,
    homo_head: bool,
    return_dense: bool,
) -> Tuple[Layout, torch.Tensor]:
    """Build the sparse layout and block mask via ``_get_sparse_attn_mask``.

    Thin wrapper: every argument is forwarded unchanged, except that
    ``max_seq_len`` is passed as ``q_len`` and ``max_seq_len_k`` as ``N_CTX``.
    The helper returns three values; the third one is discarded here, so
    callers always receive exactly two items.

    Args:
        n_heads: Number of attention heads.
        max_seq_len: Forwarded to the helper as ``q_len``.
        max_seq_len_k: Forwarded to the helper as ``N_CTX``.
        dtype: Forwarded to the helper.
        device: Forwarded to the helper.
        BLOCK: Sparse block size, forwarded to the helper.
        local_blocks: Forwarded to the helper.
        vert_stride: Forwarded to the helper.
        homo_head: Forwarded to the helper.
        return_dense: Forwarded to the helper. Whatever it makes the helper
            produce as third return value is dropped by this wrapper.

    Returns:
        ``(layout, block_sparse_pattern)`` -- the first two values returned
        by ``_get_sparse_attn_mask``.
    """
    layout, block_sparse_pattern, _unused_dense = _get_sparse_attn_mask(
        n_heads=n_heads,
        q_len=max_seq_len,
        N_CTX=max_seq_len_k,
        dtype=dtype,
        device=device,
        BLOCK=BLOCK,
        local_blocks=local_blocks,
        vert_stride=vert_stride,
        homo_head=homo_head,
        return_dense=return_dense,
    )
    return layout, block_sparse_pattern
|
|
|
|
class BlockSparseAttentionLayer(nn.Module):
    """Attention layer that dispatches to block-sparse flash-attention kernels.

    ``forward`` picks one of three code paths:

    * no padding / sequence-length metadata given: the op returned by
      ``get_local_strided_sparse_attention_op`` is used;
    * ``k`` is 3-dimensional: ``blocksparse_flash_attn_varlen_fwd``;
    * ``k`` is 4-dimensional: ``blocksparse_flash_attn_padded_fwd``.

    The last two paths need a sparse layout and block mask, which are built
    lazily on first use (and rebuilt when the dtype or device of ``q``
    changes) by ``_initialize_internals``.
    """

    def __init__(
        self,
        n_heads: int,
        max_seq_len: int,
        sparse_block_size: int,
        local_blocks: int,
        vert_stride: int,
        kernel_block_size: Optional[int] = None,
        homo_head: bool = False,
        active_head_range: Optional[Tuple[int, int]] = None,
    ) -> None:
        """Store the sparsity configuration; no tensors are created here.

        Args:
            n_heads: Number of attention heads.
            max_seq_len: Maximum sequence length the layout is built for.
            sparse_block_size: Block size of the sparsity pattern.
            local_blocks: Forwarded to the layout/op builders.
            vert_stride: Forwarded to the layout/op builders.
            kernel_block_size: Block size handed to the kernels. Falls back
                to ``sparse_block_size`` when ``None`` (or otherwise falsy).
            homo_head: Forwarded to the layout/op builders. When true, the
                layout is never pruned to ``active_head_range``.
            active_head_range: Optional ``(start, end)`` head indices. When
                set and ``homo_head`` is false, the cached layout and mask
                are sliced to ``[start:end]`` along their first dimension.
        """
        super().__init__()

        self.n_heads = n_heads
        self.max_seq_len = max_seq_len
        self.sparse_block_size = sparse_block_size
        self.kernel_block_size = kernel_block_size or sparse_block_size
        self.local_blocks = local_blocks
        self.vert_stride = vert_stride
        self.homo_head = homo_head
        self.active_head_range = active_head_range

        # Lazily populated by `_initialize_internals`; `_dtype` / `_device`
        # record what the cached tensors were built for.
        self._sparse_block_mask = None
        self._sparse_layout = None
        self._dtype = None
        self._device = None

    def prune_blocksparse_layout_to_heads(self, h_start: int, h_end: int) -> None:
        """Restrict the cached mask and layout to heads ``[h_start:h_end]``.

        Every cached tensor is sliced along its first dimension. The layout
        is rebuilt as a new tuple rather than assigned element by element:
        ``Layout`` is a tuple, and item assignment on it raises ``TypeError``.

        Must be called after the internals have been initialised.
        """
        self._sparse_block_mask = self._sparse_block_mask[h_start:h_end]
        self._sparse_layout = tuple(
            part[h_start:h_end] for part in self._sparse_layout
        )

    def _initialize_internals(
        self,
        dtype: torch.dtype,
        device: torch.device,
    ) -> None:
        """Build and cache the sparse layout and block mask.

        Raises:
            AssertionError: If ``sparse_block_size`` is not a multiple of
                ``kernel_block_size``, if ``kernel_block_size`` is not a
                power of two that is at least 16, or if
                ``active_head_range`` does not have exactly two entries.
        """
        # Validate the block sizes before touching any cached state. If the
        # checks ran after the mask was stored, a failed assertion would
        # leave the layer looking initialised and the next `forward` call
        # would silently skip validation.
        assert self.sparse_block_size % self.kernel_block_size == 0, f"The sparse block size must be a multiple of {self.kernel_block_size}. Found {self.sparse_block_size}."
        assert self.kernel_block_size >=16 and math.log2(self.kernel_block_size) % 1 == 0, f"block_size must be power of 2 and at least 16, but {self.kernel_block_size} is given"

        self._dtype, self._device = dtype, device
        self._sparse_layout, self._sparse_block_mask = create_sparse_attn_mask(
            n_heads=self.n_heads,
            max_seq_len=self.max_seq_len,
            max_seq_len_k=self.max_seq_len,
            dtype=dtype,
            device=device,
            BLOCK=self.sparse_block_size,
            local_blocks=self.local_blocks,
            vert_stride=self.vert_stride,
            homo_head=self.homo_head,
            return_dense=False,
        )

        if (not self.homo_head) and (self.active_head_range is not None):
            assert len(self.active_head_range) == 2, "\"active_head_range\" should be a tuple of start/end index of the heads."
            h_start, h_end = self.active_head_range
            self.prune_blocksparse_layout_to_heads(h_start=h_start, h_end=h_end)

        blocks_per_sparse_block = self.sparse_block_size // self.kernel_block_size
        if blocks_per_sparse_block > 1:
            # Expand every sparse block into an
            # (blocks_per_sparse_block x blocks_per_sparse_block) tile of
            # kernel-sized blocks.
            # NOTE(review): only the mask is refined here; `_sparse_layout`
            # keeps the granularity it was built with, yet `forward` passes
            # it to the kernels together with `block_size=kernel_block_size`.
            # Confirm that this is what the kernels expect.
            tile = self._sparse_block_mask.new_ones(
                blocks_per_sparse_block, blocks_per_sparse_block
            )
            self._sparse_block_mask = torch.kron(self._sparse_block_mask, tile)

            # Re-apply block-level causality: the expansion also switches on
            # kernel blocks above the diagonal inside diagonal sparse blocks.
            num_blocks = self._sparse_block_mask.size(-1)
            block_idx = torch.arange(
                num_blocks, device=self._sparse_block_mask.device
            )
            block_causal_mask = block_idx[:, None] >= block_idx[None]
            self._sparse_block_mask *= block_causal_mask.type_as(self._sparse_block_mask)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        sm_scale: float,
        *,
        # Metadata for padded (4-dim) batched inference.
        left_paddings: Optional[torch.LongTensor] = None,
        seqlens: Optional[torch.LongTensor] = None,
        # Metadata for variable-length (3-dim) inference.
        cu_seqlens_k: Optional[torch.LongTensor] = None,
        cu_seqlens_q: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Run block-sparse attention on ``q``, ``k``, ``v``.

        Args:
            q, k, v: Query, key and value tensors. When any metadata
                argument is given, the number of dimensions of ``k`` selects
                the kernel (3: variable-length, 4: padded).
            sm_scale: Softmax scale forwarded to the kernel.
            left_paddings: Forwarded to the padded kernel.
            seqlens: Forwarded to the padded kernel.
            cu_seqlens_k: Forwarded to the variable-length kernel; required
                on that path.
            cu_seqlens_q: Forwarded to the variable-length kernel.

        Returns:
            The output of the selected kernel.

        Raises:
            AssertionError: If metadata is given while autograd is enabled,
                if ``cu_seqlens_k`` is missing on the variable-length path,
                or if both ``left_paddings`` and ``seqlens`` are missing on
                the padded path.
            ValueError: If metadata is given and ``k`` is neither 3- nor
                4-dimensional.
        """
        metadata = (left_paddings, seqlens, cu_seqlens_k, cu_seqlens_q)
        if all(item is None for item in metadata):
            # No padding / length information: use the local-strided op.
            blocksparse_op = get_local_strided_sparse_attention_op(
                n_heads=self.n_heads,
                max_seq_len=self.max_seq_len,
                sparse_block_size=self.sparse_block_size,
                kernel_block_size=self.kernel_block_size,
                local_blocks=self.local_blocks,
                vert_stride=self.vert_stride,
                homo_head=self.homo_head,
                device=q.device,
                inference=not self.training,
            )
            return blocksparse_op(q, k, v, sm_scale)

        assert not torch.is_grad_enabled(), "Variable Length Inference / Batched inference is not supported during training. Please run it in a torch.no_grad() context"

        # (Re)build the cached layout when missing or built for another
        # dtype / device.
        if self._sparse_block_mask is None or (self._dtype != q.dtype) or (self._device != q.device):
            self._initialize_internals(dtype=q.dtype, device=q.device)

        if k.dim() == 3:
            assert cu_seqlens_k is not None
            return blocksparse_flash_attn_varlen_fwd(
                q=q,
                k=k,
                v=v,
                cu_seqlens_k=cu_seqlens_k,
                cu_seqlens_q=cu_seqlens_q,
                sm_scale=sm_scale,
                sparse_layout=self._sparse_layout,
                block_size=self.kernel_block_size,
                max_seqlen=self.max_seq_len,
            )
        if k.dim() == 4:
            assert not (left_paddings is None and seqlens is None), "Either left_paddings or seqlens must be provided for batched inference."
            return blocksparse_flash_attn_padded_fwd(
                q=q,
                k=k,
                v=v,
                sm_scale=sm_scale,
                sparse_layout=self._sparse_layout,
                left_paddings=left_paddings,
                seqlens=seqlens,
                block_size=self.kernel_block_size,
                max_seqlen=self.max_seq_len,
            )
        raise ValueError('q/k/v must be either 3 dim for variable-length input or 4 dim for fixed-length.')
|
|