| """ |
| Wrap torch's flex attention and handle mess info or potentially refactor |
| """ |
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from torch.nn.attention.flex_attention import flex_attention, create_block_mask
    flex_attention_available = True
except ImportError:
    print(f"[Warning] flex attention requires PyTorch 2.5.0+, but your version is {torch.__version__}")
    flex_attention_available = False


def _causal_mask(b, h, q_idx, kv_idx):
    """Standard causal mask: a query may only attend to keys at or before its own position."""
    return q_idx >= kv_idx


def _length_to_offsets(lengths, device):
    """Converts a list of lengths to a tensor of cumulative offsets.

    Args:
        lengths: A list of per-block token counts.
        device: Device on which to create the offsets tensor.

    Returns:
        An int32 tensor of shape (len(lengths) + 1,) whose i-th entry is the start
        offset of block i and whose last entry is the total number of tokens.
    """
    offsets = [0]
    offsets.extend(lengths)
    offsets = torch.tensor(offsets, device=device, dtype=torch.int32)
    offsets = torch.cumsum(offsets, dim=-1)
    return offsets


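# Illustrative example: three blocks of lengths 2, 4 and 3 give
#     _length_to_offsets([2, 4, 3], device='cpu')  ->  tensor([0, 2, 6, 9])
# which is the cumulative-offset format consumed by the mask generators below.

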
def _generate_var_mask_mod(offsets):
    """Generates a mask_mod for flex attention over inputs in the sequence-stacked
    format.

    Args:
        offsets: A tensor of shape (num_documents + 1,) containing the cumulative
            counts of document tokens, e.g. 3 documents of lengths 2, 4, 3 give
            offsets = [0, 2, 6, 9].

    Note:
        What is the sequence-stacked format? When assembling batches of inputs, we
        take multiple sequences (here, VAR scales) and stack them together to form one
        large sequence. The returned mask_mod lets every query attend to all tokens of
        its own document and, causally, to all earlier tokens, which yields VAR's
        block-wise causal attention pattern.
    """

    def _offsets_to_doc_ids_tensor(offsets):
        # Map each token position to the index of the document (block) it belongs to.
        device = offsets.device
        counts = offsets[1:] - offsets[:-1]
        return torch.repeat_interleave(
            torch.arange(len(counts), device=device, dtype=torch.int32), counts
        )

    document_id = _offsets_to_doc_ids_tensor(offsets)

    def var_mask_mod(b, h, q_idx, kv_idx):
        same_doc = document_id[q_idx] == document_id[kv_idx]
        causal_mask = _causal_mask(b, h, q_idx, kv_idx)
        return same_doc | causal_mask

    return var_mask_mod


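# Illustrative sketch (assumes a CUDA device and the helpers above): the returned
# mask_mod can be materialized into a dense boolean mask for inspection, e.g.
#
#     from torch.nn.attention.flex_attention import create_mask
#     offsets = _length_to_offsets([2, 4, 3], device='cuda')
#     dense = create_mask(_generate_var_mask_mod(offsets), B=1, H=1, Q_LEN=9, KV_LEN=9, device='cuda')
#
# Each query position may attend to every token of its own block and to all earlier
# blocks (block-wise causal), rather than being strictly lower-triangular.

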
def _generate_var_infer_mask_with_kv_cache(lengths):
    """Generates a mask_mod for KV-cache inference: every query may attend to all
    cached key/value positions, i.e. all kv_idx < sum(lengths)."""
    kv_len = sum(lengths)

    def var_mask_mod(b, h, q_idx, kv_idx):
        return kv_idx < kv_len

    return var_mask_mod


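# Illustrative note: for lengths [2, 4, 3] the mask above admits kv_idx 0..8 for every
# query row -- it only bounds attention by the total cached sequence length, with no
# per-block structure, on the assumption that every cached token should be visible to
# the block currently being decoded.

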
class FlexAttn(nn.Module):
    def __init__(
        self, block_scales: list, mask_type: str, B, H, L: int, auto_padding=False
    ):
        """
        :param block_scales: VAR's per-scale block sizes as (t, h, w) tuples, e.g. [(1, 1, 1), (1, 2, 2), (1, 3, 3)]
        :param mask_type: one of "var", "causal", "var_infer_mask_with_kv_cache"
        :param B: batch size
        :param H: number of attention heads
        :param L: sequence length
        :param auto_padding: if True, pad q/k/v along the sequence dimension to a multiple of 128 in forward
        """
        super().__init__()
        if not flex_attention_available:
            raise NotImplementedError(f"[Error] flex attention requires PyTorch 2.5.0+, but your version is {torch.__version__}")

        self.support_mask_type = ["var", "causal", "var_infer_mask_with_kv_cache"]
        self.auto_padding = auto_padding

        self.flex_attention = torch.compile(flex_attention)

        self.block_scales = block_scales
        self.lengths = [x * y * z for x, y, z in block_scales]

        self.offsets = _length_to_offsets(self.lengths, device='cuda')

        # If the given scales do not cover the full sequence, treat the remainder as one extra block.
        if self.offsets[-1] < L:
            self.offsets = torch.cat((self.offsets, torch.tensor([L], device='cuda')), dim=0)

        if mask_type == "var":
            self.mask_mod = _generate_var_mask_mod(self.offsets)
        elif mask_type == "causal":
            self.mask_mod = _causal_mask
        elif mask_type == "var_infer_mask_with_kv_cache":
            self.mask_mod = _generate_var_infer_mask_with_kv_cache(self.lengths)
        else:
            raise NotImplementedError(f"{mask_type} not supported in FlexAttn, supported types: {self.support_mask_type}")
        self.block_mask = create_block_mask(self.mask_mod, B=B, H=H, Q_LEN=L, KV_LEN=L, device='cuda', _compile=True)

    def forward(self, q, k, v, scale=None):
        """Run flex attention on q, k, v of shape (B, H, L, head_dim) with the precomputed block mask."""
        # Tile sizes handed to the Triton kernels behind flex attention
        # (the *1/*2 entries apply to the backward pass).
        kernel_options = {
            "BLOCK_M": 64,
            "BLOCK_N": 64,
            "BLOCK_M1": 32,
            "BLOCK_N1": 64,
            "BLOCK_M2": 64,
            "BLOCK_N2": 32,
        }
        if self.auto_padding:
            # Pad the sequence dimension up to a multiple of 128 (the default block-mask
            # granularity), run attention, then crop the padded query positions.
            q_pad_len = (128 - q.shape[-2] % 128) % 128
            kv_pad_len = (128 - k.shape[-2] % 128) % 128
            q_pad = F.pad(q, (0, 0, 0, q_pad_len))
            k_pad = F.pad(k, (0, 0, 0, kv_pad_len))
            v_pad = F.pad(v, (0, 0, 0, kv_pad_len))
            oup = self.flex_attention(q_pad.to(v_pad.dtype), k_pad.to(v_pad.dtype), v_pad, block_mask=self.block_mask, scale=scale, kernel_options=kernel_options)
            if q_pad_len > 0:
                oup = oup[:, :, :-q_pad_len]
        else:
            oup = self.flex_attention(q.to(v.dtype), k.to(v.dtype), v, block_mask=self.block_mask, scale=scale, kernel_options=kernel_options)
        return oup

    def extra_repr(self) -> str:
        return f'block size:{self.block_scales}'
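

# Minimal usage sketch (illustrative, not part of the module's public API). Assumes a
# CUDA device with Triton available, since the block masks above are built on 'cuda'
# and flex_attention is compiled; the scale schedule below is a hypothetical example.
if __name__ == "__main__":
    if flex_attention_available and torch.cuda.is_available():
        scales = [(1, 1, 1), (1, 2, 2), (1, 3, 3), (1, 4, 4), (1, 5, 5), (1, 6, 6)]  # 91 tokens
        B, H, L, D = 2, 4, 128, 64  # L > 91, so the remainder is treated as one extra block
        attn = FlexAttn(block_scales=scales, mask_type="var", B=B, H=H, L=L)
        q = torch.randn(B, H, L, D, device="cuda", dtype=torch.float16)
        k = torch.randn(B, H, L, D, device="cuda", dtype=torch.float16)
        v = torch.randn(B, H, L, D, device="cuda", dtype=torch.float16)
        out = attn(q, k, v, scale=D ** -0.5)
        print(attn, out.shape)  # expected output shape: (B, H, L, D)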