diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..c3ad15cdd70cd708e0e9f2892a249b53bffc7169 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +*.so filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..fba9b80ad86faed74e2c283873f856fa03acac1d --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.bak +__pycache__ \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d8ad54f9a038c1db9c45e36f6ffff63507457b09 --- /dev/null +++ b/README.md @@ -0,0 +1,24 @@ +--- +tags: +- kernels +- flash-mla +- deepseek +- kernel-builder +--- + +![Status](https://hubwebhook.dholtz.com/shield?repo=kernels-community/flash-mla) + +## flash-mla + +This repo builds Deepseeks [FlashMLA](https://github.com/deepseek-ai/FlashMLA) kernel via the HF [kernel-builder](https://github.com/huggingface/kernel-builder) + +### Dev +```bash +nix develop -L +pytest -vv tests/ +``` + +### Build +```bash +nix build .#bundle -L +``` diff --git a/benchmarks/benchmark.py b/benchmarks/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..f03557a2cd35f110f7d53f2e52a0047ac9a9a1ba --- /dev/null +++ b/benchmarks/benchmark.py @@ -0,0 +1,322 @@ +import math +import torch + +from kernels.benchmark import Benchmark + + +def _cdiv(a, b): + return (a + b - 1) // b + + +def _extract_output(result): + if isinstance(result, tuple): + return result[0] + return result + + +def _reference_mla_decode(q, blocked_k, block_table, cache_seqlens, head_dim_v, causal=False): + b, s_q, h_q, d = q.size() + block_size = blocked_k.size(1) + h_kv = blocked_k.size(2) + + out = torch.empty(b, s_q, h_q, head_dim_v, dtype=torch.float32, device=q.device) + + for i in range(b): + cur_len = int(cache_seqlens[i].item()) + num_blocks = _cdiv(cur_len, block_size) + cur_blocks = block_table[i][:num_blocks] + kv = blocked_k[cur_blocks].reshape(-1, h_kv, d)[:cur_len] + + query = q[i].transpose(0, 1).float() # [h_q, s_q, d] + key_val = kv.transpose(0, 1).float() # [h_kv, s_k, d] + + if h_kv != h_q: + key_val = key_val.repeat_interleave(h_q // h_kv, dim=0) + + attn = query @ key_val.transpose(-2, -1) / math.sqrt(d) + + s_k = key_val.size(1) + if causal and s_q > 1: + mask = torch.ones(s_q, s_k, dtype=torch.bool, device=q.device).tril( + diagonal=s_k - s_q + ) + attn.masked_fill_(~mask, float("-inf")) + + attn = torch.softmax(attn, dim=-1) + output = attn @ key_val[..., :head_dim_v] + out[i] = output.transpose(0, 1) + + return out.to(q.dtype) + + +def _varlen_reference_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, causal=False): + batch_size = cu_seqlens_q.shape[0] - 1 + total_tokens_q = q.shape[0] + num_heads = q.shape[1] + head_dim_v = v.shape[2] + scale = q.shape[-1] ** (-0.5) + + out = torch.zeros( + (total_tokens_q, num_heads, head_dim_v), device=q.device, dtype=q.dtype + ) + + for b in range(batch_size): + start_q, end_q = cu_seqlens_q[b], cu_seqlens_q[b + 1] + start_k, end_k = cu_seqlens_k[b], cu_seqlens_k[b + 1] + + q_b = q[start_q:end_q].transpose(0, 1).float() # [H, seq_q, D_qk] + k_b = k[start_k:end_k].transpose(0, 1).float() # [H, seq_k, D_qk] + v_b = v[start_k:end_k].transpose(0, 1).float() # [H, seq_k, D_v] + + attn = q_b @ k_b.transpose(-2, -1) * scale + + if causal: + seq_q, seq_k = q_b.size(1), k_b.size(1) + mask = torch.ones(seq_q, seq_k, dtype=torch.bool, device=q.device).tril( + diagonal=seq_k - seq_q + ) + attn.masked_fill_(~mask, float("-inf")) + + attn = torch.softmax(attn, dim=-1) + result = attn @ v_b # [H, seq_q, D_v] + out[start_q:end_q] = result.transpose(0, 1).to(q.dtype) + + return out + + +# MLA decode constants (DeepSeek V3 architecture) +_HEAD_DIM = 576 # Q/K head dimension +_HEAD_DIM_V = 512 # V head dimension +_NUM_HEADS_K = 1 # MLA uses single KV head +_PAGE_BLOCK_SIZE = 64 # Page block size + + +def _setup_mla_decode(bench, batch_size, seq_k, num_heads_q): + max_num_blocks = _cdiv(seq_k, _PAGE_BLOCK_SIZE) + total_blocks = batch_size * max_num_blocks + + bench.q = ( + torch.randn( + batch_size, 1, num_heads_q, _HEAD_DIM, device="cuda", dtype=torch.bfloat16 + ) + / 10 + ) + bench.blocked_k = ( + torch.randn( + total_blocks, + _PAGE_BLOCK_SIZE, + _NUM_HEADS_K, + _HEAD_DIM, + device="cuda", + dtype=torch.bfloat16, + ) + / 10 + ) + bench.block_table = torch.arange( + total_blocks, device="cuda", dtype=torch.int32 + ).view(batch_size, max_num_blocks) + bench.cache_seqlens = torch.full( + (batch_size,), seq_k, device="cuda", dtype=torch.int32 + ) + bench.tile_scheduler_metadata, _ = bench.kernel.get_mla_metadata() + bench.out = torch.empty( + batch_size, 1, num_heads_q, _HEAD_DIM_V, device="cuda", dtype=torch.bfloat16 + ) + + +def _run_mla_decode(bench, causal=False): + out, lse = bench.kernel.flash_mla_with_kvcache( + q=bench.q, + k_cache=bench.blocked_k, + block_table=bench.block_table, + cache_seqlens=bench.cache_seqlens, + head_dim_v=_HEAD_DIM_V, + tile_scheduler_metadata=bench.tile_scheduler_metadata, + causal=causal, + ) + bench.out = out + + +def _verify_mla_decode(bench, causal=False): + return _reference_mla_decode( + bench.q, + bench.blocked_k, + bench.block_table, + bench.cache_seqlens, + _HEAD_DIM_V, + causal=causal, + ) + + +class FlashMLABenchmark(Benchmark): + seed: int = 42 + + # Workload: small (B=2, S_k=256, H_q=64) + def setup_small(self): + _setup_mla_decode(self, batch_size=2, seq_k=256, num_heads_q=64) + + def benchmark_small(self): + _run_mla_decode(self, causal=False) + + def verify_small(self) -> torch.Tensor: + return _verify_mla_decode(self, causal=False) + + # Workload: medium (B=4, S_k=1024, H_q=64) + def setup_medium(self): + _setup_mla_decode(self, batch_size=4, seq_k=1024, num_heads_q=64) + + def benchmark_medium(self): + _run_mla_decode(self, causal=False) + + def verify_medium(self) -> torch.Tensor: + return _verify_mla_decode(self, causal=False) + + # Workload: large (B=8, S_k=4096, H_q=128) + def setup_large(self): + _setup_mla_decode(self, batch_size=8, seq_k=4096, num_heads_q=128) + + def benchmark_large(self): + _run_mla_decode(self, causal=False) + + def verify_large(self) -> torch.Tensor: + return _verify_mla_decode(self, causal=False) + + +class FlashMLACausalBenchmark(Benchmark): + seed: int = 42 + + # Workload: small (B=2, S_k=256, H_q=64) + def setup_small(self): + _setup_mla_decode(self, batch_size=2, seq_k=256, num_heads_q=64) + + def benchmark_small(self): + _run_mla_decode(self, causal=True) + + def verify_small(self) -> torch.Tensor: + return _verify_mla_decode(self, causal=True) + + # Workload: medium (B=4, S_k=1024, H_q=64) + def setup_medium(self): + _setup_mla_decode(self, batch_size=4, seq_k=1024, num_heads_q=64) + + def benchmark_medium(self): + _run_mla_decode(self, causal=True) + + def verify_medium(self) -> torch.Tensor: + return _verify_mla_decode(self, causal=True) + + # Workload: large (B=8, S_k=4096, H_q=128) + def setup_large(self): + _setup_mla_decode(self, batch_size=8, seq_k=4096, num_heads_q=128) + + def benchmark_large(self): + _run_mla_decode(self, causal=True) + + def verify_large(self) -> torch.Tensor: + return _verify_mla_decode(self, causal=True) + + +# class FlashMLAVarlenBenchmark(Benchmark): +# seed: int = 42 + +# # Workload: small (3 sequences, max_seqlen=64) +# def setup_small(self): +# H, D = 8, 64 +# seqlens = [32, 48, 64] +# total = sum(seqlens) +# self.q = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16) +# self.k = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16) +# self.v = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16) +# self.cu_seqlens = torch.tensor( +# [0] + list(torch.cumsum(torch.tensor(seqlens), 0)), +# device="cuda", +# dtype=torch.int32, +# ) +# self.max_seqlen = max(seqlens) +# self.out = torch.empty(total, H, D, device="cuda", dtype=torch.bfloat16) + +# def benchmark_small(self): +# self.out = _extract_output( +# self.kernel.flash_attn_varlen_func( +# self.q, +# self.k, +# self.v, +# self.cu_seqlens, +# self.cu_seqlens, +# self.max_seqlen, +# self.max_seqlen, +# ) +# ) + +# def verify_small(self) -> torch.Tensor: +# return _varlen_reference_attention( +# self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False +# ) + +# # Workload: medium (5 sequences, max_seqlen=256) +# def setup_medium(self): +# H, D = 16, 64 +# seqlens = [128, 192, 256, 200, 150] +# total = sum(seqlens) +# self.q = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16) +# self.k = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16) +# self.v = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16) +# self.cu_seqlens = torch.tensor( +# [0] + list(torch.cumsum(torch.tensor(seqlens), 0)), +# device="cuda", +# dtype=torch.int32, +# ) +# self.max_seqlen = max(seqlens) +# self.out = torch.empty(total, H, D, device="cuda", dtype=torch.bfloat16) + +# def benchmark_medium(self): +# self.out = _extract_output( +# self.kernel.flash_attn_varlen_func( +# self.q, +# self.k, +# self.v, +# self.cu_seqlens, +# self.cu_seqlens, +# self.max_seqlen, +# self.max_seqlen, +# ) +# ) + +# def verify_medium(self) -> torch.Tensor: +# return _varlen_reference_attention( +# self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False +# ) + +# # Workload: large (8 sequences, max_seqlen=512) +# def setup_large(self): +# H, D = 32, 128 +# seqlens = [256, 384, 512, 448, 320, 480, 400, 512] +# total = sum(seqlens) +# self.q = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16) +# self.k = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16) +# self.v = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16) +# self.cu_seqlens = torch.tensor( +# [0] + list(torch.cumsum(torch.tensor(seqlens), 0)), +# device="cuda", +# dtype=torch.int32, +# ) +# self.max_seqlen = max(seqlens) +# self.out = torch.empty(total, H, D, device="cuda", dtype=torch.bfloat16) + +# def benchmark_large(self): +# self.out = _extract_output( +# self.kernel.flash_attn_varlen_func( +# self.q, +# self.k, +# self.v, +# self.cu_seqlens, +# self.cu_seqlens, +# self.max_seqlen, +# self.max_seqlen, +# ) +# ) + +# def verify_large(self) -> torch.Tensor: +# return _varlen_reference_attention( +# self.q, self.k, self.v, self.cu_seqlens, self.cu_seqlens, causal=False +# ) diff --git a/build.toml b/build.toml new file mode 100644 index 0000000000000000000000000000000000000000..66307e488ea10c8b514089489091f67a144c7d25 --- /dev/null +++ b/build.toml @@ -0,0 +1,27 @@ +[general] +name = "flash_mla" + +[torch] +src = ["torch-ext/torch_binding.cpp", "torch-ext/torch_binding.h"] + + +[kernel.activation] +cuda-capabilities = [ + # "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", + + # Only available on H100 and H200 + "9.0", # (Hopper) +] +src = [ + "flash_mla/flash_mla_api.cu", + "flash_mla/flash_fwd_mla_bf16_sm90.cu", + "flash_mla/flash_fwd_mla_fp16_sm90.cu", + "flash_mla/flash_fwd_mla_kernel.h", + "flash_mla/flash_fwd_mla_metadata.cu", + "flash_mla/flash_mla.h", + "flash_mla/named_barrier.h", + "flash_mla/softmax.h", + "flash_mla/static_switch.h", + "flash_mla/utils.h", +] +depends = ["torch", "cutlass_3_6"] diff --git a/build/torch210-cxx11-cu128-aarch64-linux/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db300fe9b95176a20b27b3641d89be657d0c4319 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/__init__.py @@ -0,0 +1,163 @@ +from typing import Optional, Tuple +import torch + +from .flash_mla_interface import FlashMLASchedMeta +from . import flash_mla_interface as _impl + + +def get_mla_metadata(*args, **kwargs) -> Tuple[FlashMLASchedMeta, None]: + return _impl.get_mla_metadata(*args, **kwargs) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_mla_with_kvcache( + q=q, + k_cache=k_cache, + block_table=block_table, + cache_seqlens=cache_seqlens, + head_dim_v=head_dim_v, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + softmax_scale=softmax_scale, + causal=causal, + is_fp8_kvcache=is_fp8_kvcache, + indices=indices, + attn_sink=attn_sink, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices_in_kvcache, + topk_length=topk_length, + extra_topk_length=extra_topk_length, + ) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return _impl.flash_mla_sparse_fwd( + q=q, + kv=kv, + indices=indices, + sm_scale=sm_scale, + d_v=d_v, + attn_sink=attn_sink, + topk_length=topk_length, + ) + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_qkvpacked_func( + qkv=qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +__all__ = [ + "__version__", + "FlashMLASchedMeta", + "get_mla_metadata", + "flash_mla_with_kvcache", + "flash_attn_varlen_func", + "flash_attn_varlen_qkvpacked_func", + "flash_attn_varlen_kvpacked_func", + "flash_mla_sparse_fwd", +] diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_flash_mla_cuda_89d7fc1.abi3.so b/build/torch210-cxx11-cu128-aarch64-linux/_flash_mla_cuda_89d7fc1.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..00fa14eb9a9ead741de9208cad2db3d99ee0b774 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/_flash_mla_cuda_89d7fc1.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50f679700cb2de6e4ed31faec0195d268b829f42fa5aa68b9e095cb6674e42c3 +size 3671112 diff --git a/build/torch210-cxx11-cu128-aarch64-linux/_ops.py b/build/torch210-cxx11-cu128-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd5978e1b450acfa19dfe796dceea8fe35f5736 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_cuda_89d7fc1 +ops = torch.ops._flash_mla_cuda_89d7fc1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_cuda_89d7fc1::{op_name}" diff --git a/build/torch210-cxx11-cu128-aarch64-linux/flash_mla/__init__.py b/build/torch210-cxx11-cu128-aarch64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/flash_mla/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu128-aarch64-linux/flash_mla_interface.py b/build/torch210-cxx11-cu128-aarch64-linux/flash_mla_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..a84e448ffe741bb6d3dafaf7888ed8cc94984467 --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/flash_mla_interface.py @@ -0,0 +1,435 @@ +from typing import Optional, Tuple +import dataclasses + +import torch + +from ._ops import ops as flash_mla_cuda + +@dataclasses.dataclass +class FlashMLASchedMeta: + """ + A class that stores the tile scheduler metadata of FlashMLA + """ + + @dataclasses.dataclass + class Config: + b: int + s_q: int + h_q: int + page_block_size: int + h_k: int + + causal: bool + is_fp8_kvcache: bool + topk: Optional[int] + + extra_page_block_size: Optional[int] + extra_topk: Optional[int] + + have_initialized: bool = False + + config: Optional[Config] = None + + tile_scheduler_metadata: Optional[torch.Tensor] = None # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: Optional[torch.Tensor] = None # (1), dtype torch.int32. + + +def get_mla_metadata( + *args, + **kwargs +) -> Tuple[FlashMLASchedMeta, None]: + """ + Returns an empty instance of FlashMLASchedMeta. The actual scheduling metadata will be generated during the first invocation of flash_mla_with_kvcache. + + Arguments: + This function does not need any arguments, but we keep *args and **kwargs to be compatible with the old interface. + + Return: + A tuple. Due to historical reasons, we return a tuple of (FlashMLASchedMeta, None) now. Only the first element is useful. + """ + return FlashMLASchedMeta(), None + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details. + The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks. + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used. + cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. + head_dim_v: Head_dim of v. Must be 512 + sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same. + num_splits_placeholder: must be "None" (to be compatible with the old interface). + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k). + causal: bool. Whether to apply causal attention mask. Only valid for dense attention + is_fp8_kvcache: bool. + indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled. + Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block), + where t is the k-th token of the j-th q-sequence in the i-th batch. + attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0. + extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. Their format requirements are the same as k_cache and indices_in_kvcache respectively. + topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking. + + For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2: + head_dim should be 576 while head_dim_v should be 512. + In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as: + - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1. + - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values. + - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on. + - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + sched_meta = tile_scheduler_metadata + indices_in_kvcache = indices + assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta" + assert num_splits is None, "num_splits must be None" + + topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None + extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None + extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + if not sched_meta.have_initialized: + # Sanity check. We only perform sanity check during the first invocation to save CPU time. + if indices_in_kvcache is not None: + assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)" + + # Initialize the tile scheduler metadata during the first invocation. + sched_meta.have_initialized = True + sched_meta.config = FlashMLASchedMeta.Config( + q.shape[0], + q.shape[1], + q.shape[2], + k_cache.shape[1], + k_cache.shape[2], + + causal, + is_fp8_kvcache, + topk, + + extra_k_page_block_size, + extra_topk, + ) + else: + # Check whether the input arguments are consistent with sched_meta + helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta." + assert sched_meta.config is not None + assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg + assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg + assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg + assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg + assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg + assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg + assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg + assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg + assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg + assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg + + if topk is not None: + # Sparse attention + assert not causal, "causal must be False when sparse attention is enabled" + assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled" + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd( + q, k_cache, indices_in_kvcache, topk_length, attn_sink, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits, + extra_k_cache, extra_indices_in_kvcache, extra_topk_length, + head_dim_v, softmax_scale + ) + else: + # Dense attention + assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used." + assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used." + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd( + q, k_cache, head_dim_v, + cache_seqlens, block_table, + softmax_scale, causal, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits + ) + sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata + sched_meta.num_splits = new_num_splits + return (out, lse) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. Can only be 512 + attn_sink: optional, [h_q], float32. + If attn_sink is provided, when computing output, output will be additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)). + +-inf in attn_sink will be handled normally (i.e., -inf has no effect, +inf will make corresponding output all zeros). + This argument has no effect on lse and max_logits. + topk_length: optional, [s_q], int32. If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if provided in indices). + In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), operator output will contain NaN, so please avoid this situation. + + Returns: + (output, max_logits, lse) + Please refer to tests/ref.py for the precise definitions of these parameters. + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, log-sum-exp of attention scores + """ + results = flash_mla_cuda.sparse_prefill_fwd( + q, kv, indices, sm_scale, d_v, attn_sink, topk_length + ) + return results + + +def _flash_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if out is None: + out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype) + if lse is None: + # Make lse contiguous on seqlen dim + lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_fwd( + workspace_buffer, + q, + k, + v, + cu_seqlens_qo, + cu_seqlens_kv, + out, + lse, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return out, lse + + +def _flash_attn_varlen_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + # TODO: fix bwd GQA + if num_qo_heads != num_kv_heads: + raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.") + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if dq is None: + dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dk is None: + dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dv is None: + dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype) + + max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 + bs = cu_seqlens_qo.shape[0] - 1 + workspace_bytes = 0 + workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc + workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse + if num_qo_heads != num_kv_heads: + workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc + workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_bwd( + workspace_buffer, + do, + q, + k, + v, + out, + lse, + cu_seqlens_qo, + cu_seqlens_kv, + dq, + dk, + dv, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return dq, dk, dv + + +class FlashAttnVarlenFunc(torch.autograd.Function): + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + out, lse = _flash_attn_varlen_forward( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal=causal, softmax_scale=softmax_scale, + is_varlen=is_varlen, + ) + ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv) + ctx.max_seqlen_qo = max_seqlen_qo + ctx.max_seqlen_kv = max_seqlen_kv + ctx.causal = causal + ctx.softmax_scale = softmax_scale + ctx.is_varlen = is_varlen + return out, lse + + def backward( + ctx, + do: torch.Tensor, + dlse: torch.Tensor, + ): + del dlse # LSE doesn't support backward currently + q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors + dq, dk, dv = _flash_attn_varlen_backward( + do, q, k, v, out, lse, + cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv, + causal=ctx.causal, softmax_scale=ctx.softmax_scale, + is_varlen=ctx.is_varlen, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:], + cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:], + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) diff --git a/build/torch210-cxx11-cu128-aarch64-linux/metadata.json b/build/torch210-cxx11-cu128-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..4899badb63d45293425e2164944268b6058af95d --- /dev/null +++ b/build/torch210-cxx11-cu128-aarch64-linux/metadata.json @@ -0,0 +1,11 @@ +{ + "version": 1, + "license": "MIT", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "9.0a" + ] + } +} diff --git a/build/torch210-cxx11-cu128-x86_64-linux/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db300fe9b95176a20b27b3641d89be657d0c4319 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,163 @@ +from typing import Optional, Tuple +import torch + +from .flash_mla_interface import FlashMLASchedMeta +from . import flash_mla_interface as _impl + + +def get_mla_metadata(*args, **kwargs) -> Tuple[FlashMLASchedMeta, None]: + return _impl.get_mla_metadata(*args, **kwargs) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_mla_with_kvcache( + q=q, + k_cache=k_cache, + block_table=block_table, + cache_seqlens=cache_seqlens, + head_dim_v=head_dim_v, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + softmax_scale=softmax_scale, + causal=causal, + is_fp8_kvcache=is_fp8_kvcache, + indices=indices, + attn_sink=attn_sink, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices_in_kvcache, + topk_length=topk_length, + extra_topk_length=extra_topk_length, + ) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return _impl.flash_mla_sparse_fwd( + q=q, + kv=kv, + indices=indices, + sm_scale=sm_scale, + d_v=d_v, + attn_sink=attn_sink, + topk_length=topk_length, + ) + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_qkvpacked_func( + qkv=qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +__all__ = [ + "__version__", + "FlashMLASchedMeta", + "get_mla_metadata", + "flash_mla_with_kvcache", + "flash_attn_varlen_func", + "flash_attn_varlen_qkvpacked_func", + "flash_attn_varlen_kvpacked_func", + "flash_mla_sparse_fwd", +] diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_flash_mla_cuda_89d7fc1.abi3.so b/build/torch210-cxx11-cu128-x86_64-linux/_flash_mla_cuda_89d7fc1.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..0a790d515681501232de9508b9bf6381611f9b29 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/_flash_mla_cuda_89d7fc1.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a5a848e71e943d30f41eea9b201cb51b7ebdf5e962652e836d46d4fc3f4769f6 +size 3534072 diff --git a/build/torch210-cxx11-cu128-x86_64-linux/_ops.py b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd5978e1b450acfa19dfe796dceea8fe35f5736 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_cuda_89d7fc1 +ops = torch.ops._flash_mla_cuda_89d7fc1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_cuda_89d7fc1::{op_name}" diff --git a/build/torch210-cxx11-cu128-x86_64-linux/flash_mla/__init__.py b/build/torch210-cxx11-cu128-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/flash_mla_interface.py b/build/torch210-cxx11-cu128-x86_64-linux/flash_mla_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..a84e448ffe741bb6d3dafaf7888ed8cc94984467 --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/flash_mla_interface.py @@ -0,0 +1,435 @@ +from typing import Optional, Tuple +import dataclasses + +import torch + +from ._ops import ops as flash_mla_cuda + +@dataclasses.dataclass +class FlashMLASchedMeta: + """ + A class that stores the tile scheduler metadata of FlashMLA + """ + + @dataclasses.dataclass + class Config: + b: int + s_q: int + h_q: int + page_block_size: int + h_k: int + + causal: bool + is_fp8_kvcache: bool + topk: Optional[int] + + extra_page_block_size: Optional[int] + extra_topk: Optional[int] + + have_initialized: bool = False + + config: Optional[Config] = None + + tile_scheduler_metadata: Optional[torch.Tensor] = None # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: Optional[torch.Tensor] = None # (1), dtype torch.int32. + + +def get_mla_metadata( + *args, + **kwargs +) -> Tuple[FlashMLASchedMeta, None]: + """ + Returns an empty instance of FlashMLASchedMeta. The actual scheduling metadata will be generated during the first invocation of flash_mla_with_kvcache. + + Arguments: + This function does not need any arguments, but we keep *args and **kwargs to be compatible with the old interface. + + Return: + A tuple. Due to historical reasons, we return a tuple of (FlashMLASchedMeta, None) now. Only the first element is useful. + """ + return FlashMLASchedMeta(), None + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details. + The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks. + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used. + cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. + head_dim_v: Head_dim of v. Must be 512 + sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same. + num_splits_placeholder: must be "None" (to be compatible with the old interface). + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k). + causal: bool. Whether to apply causal attention mask. Only valid for dense attention + is_fp8_kvcache: bool. + indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled. + Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block), + where t is the k-th token of the j-th q-sequence in the i-th batch. + attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0. + extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. Their format requirements are the same as k_cache and indices_in_kvcache respectively. + topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking. + + For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2: + head_dim should be 576 while head_dim_v should be 512. + In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as: + - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1. + - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values. + - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on. + - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + sched_meta = tile_scheduler_metadata + indices_in_kvcache = indices + assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta" + assert num_splits is None, "num_splits must be None" + + topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None + extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None + extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + if not sched_meta.have_initialized: + # Sanity check. We only perform sanity check during the first invocation to save CPU time. + if indices_in_kvcache is not None: + assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)" + + # Initialize the tile scheduler metadata during the first invocation. + sched_meta.have_initialized = True + sched_meta.config = FlashMLASchedMeta.Config( + q.shape[0], + q.shape[1], + q.shape[2], + k_cache.shape[1], + k_cache.shape[2], + + causal, + is_fp8_kvcache, + topk, + + extra_k_page_block_size, + extra_topk, + ) + else: + # Check whether the input arguments are consistent with sched_meta + helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta." + assert sched_meta.config is not None + assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg + assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg + assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg + assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg + assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg + assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg + assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg + assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg + assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg + assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg + + if topk is not None: + # Sparse attention + assert not causal, "causal must be False when sparse attention is enabled" + assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled" + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd( + q, k_cache, indices_in_kvcache, topk_length, attn_sink, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits, + extra_k_cache, extra_indices_in_kvcache, extra_topk_length, + head_dim_v, softmax_scale + ) + else: + # Dense attention + assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used." + assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used." + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd( + q, k_cache, head_dim_v, + cache_seqlens, block_table, + softmax_scale, causal, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits + ) + sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata + sched_meta.num_splits = new_num_splits + return (out, lse) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. Can only be 512 + attn_sink: optional, [h_q], float32. + If attn_sink is provided, when computing output, output will be additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)). + +-inf in attn_sink will be handled normally (i.e., -inf has no effect, +inf will make corresponding output all zeros). + This argument has no effect on lse and max_logits. + topk_length: optional, [s_q], int32. If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if provided in indices). + In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), operator output will contain NaN, so please avoid this situation. + + Returns: + (output, max_logits, lse) + Please refer to tests/ref.py for the precise definitions of these parameters. + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, log-sum-exp of attention scores + """ + results = flash_mla_cuda.sparse_prefill_fwd( + q, kv, indices, sm_scale, d_v, attn_sink, topk_length + ) + return results + + +def _flash_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if out is None: + out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype) + if lse is None: + # Make lse contiguous on seqlen dim + lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_fwd( + workspace_buffer, + q, + k, + v, + cu_seqlens_qo, + cu_seqlens_kv, + out, + lse, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return out, lse + + +def _flash_attn_varlen_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + # TODO: fix bwd GQA + if num_qo_heads != num_kv_heads: + raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.") + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if dq is None: + dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dk is None: + dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dv is None: + dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype) + + max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 + bs = cu_seqlens_qo.shape[0] - 1 + workspace_bytes = 0 + workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc + workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse + if num_qo_heads != num_kv_heads: + workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc + workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_bwd( + workspace_buffer, + do, + q, + k, + v, + out, + lse, + cu_seqlens_qo, + cu_seqlens_kv, + dq, + dk, + dv, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return dq, dk, dv + + +class FlashAttnVarlenFunc(torch.autograd.Function): + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + out, lse = _flash_attn_varlen_forward( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal=causal, softmax_scale=softmax_scale, + is_varlen=is_varlen, + ) + ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv) + ctx.max_seqlen_qo = max_seqlen_qo + ctx.max_seqlen_kv = max_seqlen_kv + ctx.causal = causal + ctx.softmax_scale = softmax_scale + ctx.is_varlen = is_varlen + return out, lse + + def backward( + ctx, + do: torch.Tensor, + dlse: torch.Tensor, + ): + del dlse # LSE doesn't support backward currently + q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors + dq, dk, dv = _flash_attn_varlen_backward( + do, q, k, v, out, lse, + cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv, + causal=ctx.causal, softmax_scale=ctx.softmax_scale, + is_varlen=ctx.is_varlen, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:], + cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:], + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) diff --git a/build/torch210-cxx11-cu128-x86_64-linux/metadata.json b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..4899badb63d45293425e2164944268b6058af95d --- /dev/null +++ b/build/torch210-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1,11 @@ +{ + "version": 1, + "license": "MIT", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "9.0a" + ] + } +} diff --git a/build/torch210-cxx11-cu130-aarch64-linux/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db300fe9b95176a20b27b3641d89be657d0c4319 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/__init__.py @@ -0,0 +1,163 @@ +from typing import Optional, Tuple +import torch + +from .flash_mla_interface import FlashMLASchedMeta +from . import flash_mla_interface as _impl + + +def get_mla_metadata(*args, **kwargs) -> Tuple[FlashMLASchedMeta, None]: + return _impl.get_mla_metadata(*args, **kwargs) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_mla_with_kvcache( + q=q, + k_cache=k_cache, + block_table=block_table, + cache_seqlens=cache_seqlens, + head_dim_v=head_dim_v, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + softmax_scale=softmax_scale, + causal=causal, + is_fp8_kvcache=is_fp8_kvcache, + indices=indices, + attn_sink=attn_sink, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices_in_kvcache, + topk_length=topk_length, + extra_topk_length=extra_topk_length, + ) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return _impl.flash_mla_sparse_fwd( + q=q, + kv=kv, + indices=indices, + sm_scale=sm_scale, + d_v=d_v, + attn_sink=attn_sink, + topk_length=topk_length, + ) + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_qkvpacked_func( + qkv=qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +__all__ = [ + "__version__", + "FlashMLASchedMeta", + "get_mla_metadata", + "flash_mla_with_kvcache", + "flash_attn_varlen_func", + "flash_attn_varlen_qkvpacked_func", + "flash_attn_varlen_kvpacked_func", + "flash_mla_sparse_fwd", +] diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_flash_mla_cuda_89d7fc1.abi3.so b/build/torch210-cxx11-cu130-aarch64-linux/_flash_mla_cuda_89d7fc1.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..dcecabed01c07ddde3e0ceaeeb9d7de4af2f31fb --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/_flash_mla_cuda_89d7fc1.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9cddae191a26874b54292a433f7d1b136f99dfc1549347f8a040b723ac5da98f +size 9445408 diff --git a/build/torch210-cxx11-cu130-aarch64-linux/_ops.py b/build/torch210-cxx11-cu130-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd5978e1b450acfa19dfe796dceea8fe35f5736 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_cuda_89d7fc1 +ops = torch.ops._flash_mla_cuda_89d7fc1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_cuda_89d7fc1::{op_name}" diff --git a/build/torch210-cxx11-cu130-aarch64-linux/flash_mla/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/flash_mla/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu130-aarch64-linux/flash_mla_interface.py b/build/torch210-cxx11-cu130-aarch64-linux/flash_mla_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..a84e448ffe741bb6d3dafaf7888ed8cc94984467 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/flash_mla_interface.py @@ -0,0 +1,435 @@ +from typing import Optional, Tuple +import dataclasses + +import torch + +from ._ops import ops as flash_mla_cuda + +@dataclasses.dataclass +class FlashMLASchedMeta: + """ + A class that stores the tile scheduler metadata of FlashMLA + """ + + @dataclasses.dataclass + class Config: + b: int + s_q: int + h_q: int + page_block_size: int + h_k: int + + causal: bool + is_fp8_kvcache: bool + topk: Optional[int] + + extra_page_block_size: Optional[int] + extra_topk: Optional[int] + + have_initialized: bool = False + + config: Optional[Config] = None + + tile_scheduler_metadata: Optional[torch.Tensor] = None # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: Optional[torch.Tensor] = None # (1), dtype torch.int32. + + +def get_mla_metadata( + *args, + **kwargs +) -> Tuple[FlashMLASchedMeta, None]: + """ + Returns an empty instance of FlashMLASchedMeta. The actual scheduling metadata will be generated during the first invocation of flash_mla_with_kvcache. + + Arguments: + This function does not need any arguments, but we keep *args and **kwargs to be compatible with the old interface. + + Return: + A tuple. Due to historical reasons, we return a tuple of (FlashMLASchedMeta, None) now. Only the first element is useful. + """ + return FlashMLASchedMeta(), None + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details. + The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks. + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used. + cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. + head_dim_v: Head_dim of v. Must be 512 + sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same. + num_splits_placeholder: must be "None" (to be compatible with the old interface). + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k). + causal: bool. Whether to apply causal attention mask. Only valid for dense attention + is_fp8_kvcache: bool. + indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled. + Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block), + where t is the k-th token of the j-th q-sequence in the i-th batch. + attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0. + extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. Their format requirements are the same as k_cache and indices_in_kvcache respectively. + topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking. + + For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2: + head_dim should be 576 while head_dim_v should be 512. + In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as: + - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1. + - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values. + - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on. + - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + sched_meta = tile_scheduler_metadata + indices_in_kvcache = indices + assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta" + assert num_splits is None, "num_splits must be None" + + topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None + extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None + extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + if not sched_meta.have_initialized: + # Sanity check. We only perform sanity check during the first invocation to save CPU time. + if indices_in_kvcache is not None: + assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)" + + # Initialize the tile scheduler metadata during the first invocation. + sched_meta.have_initialized = True + sched_meta.config = FlashMLASchedMeta.Config( + q.shape[0], + q.shape[1], + q.shape[2], + k_cache.shape[1], + k_cache.shape[2], + + causal, + is_fp8_kvcache, + topk, + + extra_k_page_block_size, + extra_topk, + ) + else: + # Check whether the input arguments are consistent with sched_meta + helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta." + assert sched_meta.config is not None + assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg + assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg + assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg + assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg + assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg + assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg + assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg + assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg + assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg + assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg + + if topk is not None: + # Sparse attention + assert not causal, "causal must be False when sparse attention is enabled" + assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled" + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd( + q, k_cache, indices_in_kvcache, topk_length, attn_sink, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits, + extra_k_cache, extra_indices_in_kvcache, extra_topk_length, + head_dim_v, softmax_scale + ) + else: + # Dense attention + assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used." + assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used." + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd( + q, k_cache, head_dim_v, + cache_seqlens, block_table, + softmax_scale, causal, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits + ) + sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata + sched_meta.num_splits = new_num_splits + return (out, lse) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. Can only be 512 + attn_sink: optional, [h_q], float32. + If attn_sink is provided, when computing output, output will be additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)). + +-inf in attn_sink will be handled normally (i.e., -inf has no effect, +inf will make corresponding output all zeros). + This argument has no effect on lse and max_logits. + topk_length: optional, [s_q], int32. If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if provided in indices). + In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), operator output will contain NaN, so please avoid this situation. + + Returns: + (output, max_logits, lse) + Please refer to tests/ref.py for the precise definitions of these parameters. + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, log-sum-exp of attention scores + """ + results = flash_mla_cuda.sparse_prefill_fwd( + q, kv, indices, sm_scale, d_v, attn_sink, topk_length + ) + return results + + +def _flash_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if out is None: + out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype) + if lse is None: + # Make lse contiguous on seqlen dim + lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_fwd( + workspace_buffer, + q, + k, + v, + cu_seqlens_qo, + cu_seqlens_kv, + out, + lse, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return out, lse + + +def _flash_attn_varlen_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + # TODO: fix bwd GQA + if num_qo_heads != num_kv_heads: + raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.") + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if dq is None: + dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dk is None: + dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dv is None: + dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype) + + max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 + bs = cu_seqlens_qo.shape[0] - 1 + workspace_bytes = 0 + workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc + workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse + if num_qo_heads != num_kv_heads: + workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc + workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_bwd( + workspace_buffer, + do, + q, + k, + v, + out, + lse, + cu_seqlens_qo, + cu_seqlens_kv, + dq, + dk, + dv, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return dq, dk, dv + + +class FlashAttnVarlenFunc(torch.autograd.Function): + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + out, lse = _flash_attn_varlen_forward( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal=causal, softmax_scale=softmax_scale, + is_varlen=is_varlen, + ) + ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv) + ctx.max_seqlen_qo = max_seqlen_qo + ctx.max_seqlen_kv = max_seqlen_kv + ctx.causal = causal + ctx.softmax_scale = softmax_scale + ctx.is_varlen = is_varlen + return out, lse + + def backward( + ctx, + do: torch.Tensor, + dlse: torch.Tensor, + ): + del dlse # LSE doesn't support backward currently + q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors + dq, dk, dv = _flash_attn_varlen_backward( + do, q, k, v, out, lse, + cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv, + causal=ctx.causal, softmax_scale=ctx.softmax_scale, + is_varlen=ctx.is_varlen, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:], + cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:], + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) diff --git a/build/torch210-cxx11-cu130-aarch64-linux/metadata.json b/build/torch210-cxx11-cu130-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..8190d75efa8fd6449ddcd73de2072f17086e0842 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/metadata.json @@ -0,0 +1,12 @@ +{ + "version": 1, + "license": "MIT", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0f", + "9.0a" + ] + } +} diff --git a/build/torch210-cxx11-cu130-x86_64-linux/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db300fe9b95176a20b27b3641d89be657d0c4319 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/__init__.py @@ -0,0 +1,163 @@ +from typing import Optional, Tuple +import torch + +from .flash_mla_interface import FlashMLASchedMeta +from . import flash_mla_interface as _impl + + +def get_mla_metadata(*args, **kwargs) -> Tuple[FlashMLASchedMeta, None]: + return _impl.get_mla_metadata(*args, **kwargs) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_mla_with_kvcache( + q=q, + k_cache=k_cache, + block_table=block_table, + cache_seqlens=cache_seqlens, + head_dim_v=head_dim_v, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + softmax_scale=softmax_scale, + causal=causal, + is_fp8_kvcache=is_fp8_kvcache, + indices=indices, + attn_sink=attn_sink, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices_in_kvcache, + topk_length=topk_length, + extra_topk_length=extra_topk_length, + ) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return _impl.flash_mla_sparse_fwd( + q=q, + kv=kv, + indices=indices, + sm_scale=sm_scale, + d_v=d_v, + attn_sink=attn_sink, + topk_length=topk_length, + ) + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_qkvpacked_func( + qkv=qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +__all__ = [ + "__version__", + "FlashMLASchedMeta", + "get_mla_metadata", + "flash_mla_with_kvcache", + "flash_attn_varlen_func", + "flash_attn_varlen_qkvpacked_func", + "flash_attn_varlen_kvpacked_func", + "flash_mla_sparse_fwd", +] diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_flash_mla_cuda_89d7fc1.abi3.so b/build/torch210-cxx11-cu130-x86_64-linux/_flash_mla_cuda_89d7fc1.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..4571d1a3fc239cbb32a051fd7af800444d8e019f --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/_flash_mla_cuda_89d7fc1.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ebb93cd1323c0ade7e8f102af41d8c4b310722ceea363da2ceadbd831dbf723e +size 9395216 diff --git a/build/torch210-cxx11-cu130-x86_64-linux/_ops.py b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd5978e1b450acfa19dfe796dceea8fe35f5736 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_cuda_89d7fc1 +ops = torch.ops._flash_mla_cuda_89d7fc1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_cuda_89d7fc1::{op_name}" diff --git a/build/torch210-cxx11-cu130-x86_64-linux/flash_mla/__init__.py b/build/torch210-cxx11-cu130-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/flash_mla_interface.py b/build/torch210-cxx11-cu130-x86_64-linux/flash_mla_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..a84e448ffe741bb6d3dafaf7888ed8cc94984467 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/flash_mla_interface.py @@ -0,0 +1,435 @@ +from typing import Optional, Tuple +import dataclasses + +import torch + +from ._ops import ops as flash_mla_cuda + +@dataclasses.dataclass +class FlashMLASchedMeta: + """ + A class that stores the tile scheduler metadata of FlashMLA + """ + + @dataclasses.dataclass + class Config: + b: int + s_q: int + h_q: int + page_block_size: int + h_k: int + + causal: bool + is_fp8_kvcache: bool + topk: Optional[int] + + extra_page_block_size: Optional[int] + extra_topk: Optional[int] + + have_initialized: bool = False + + config: Optional[Config] = None + + tile_scheduler_metadata: Optional[torch.Tensor] = None # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: Optional[torch.Tensor] = None # (1), dtype torch.int32. + + +def get_mla_metadata( + *args, + **kwargs +) -> Tuple[FlashMLASchedMeta, None]: + """ + Returns an empty instance of FlashMLASchedMeta. The actual scheduling metadata will be generated during the first invocation of flash_mla_with_kvcache. + + Arguments: + This function does not need any arguments, but we keep *args and **kwargs to be compatible with the old interface. + + Return: + A tuple. Due to historical reasons, we return a tuple of (FlashMLASchedMeta, None) now. Only the first element is useful. + """ + return FlashMLASchedMeta(), None + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details. + The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks. + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used. + cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. + head_dim_v: Head_dim of v. Must be 512 + sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same. + num_splits_placeholder: must be "None" (to be compatible with the old interface). + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k). + causal: bool. Whether to apply causal attention mask. Only valid for dense attention + is_fp8_kvcache: bool. + indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled. + Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block), + where t is the k-th token of the j-th q-sequence in the i-th batch. + attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0. + extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. Their format requirements are the same as k_cache and indices_in_kvcache respectively. + topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking. + + For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2: + head_dim should be 576 while head_dim_v should be 512. + In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as: + - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1. + - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values. + - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on. + - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + sched_meta = tile_scheduler_metadata + indices_in_kvcache = indices + assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta" + assert num_splits is None, "num_splits must be None" + + topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None + extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None + extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + if not sched_meta.have_initialized: + # Sanity check. We only perform sanity check during the first invocation to save CPU time. + if indices_in_kvcache is not None: + assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)" + + # Initialize the tile scheduler metadata during the first invocation. + sched_meta.have_initialized = True + sched_meta.config = FlashMLASchedMeta.Config( + q.shape[0], + q.shape[1], + q.shape[2], + k_cache.shape[1], + k_cache.shape[2], + + causal, + is_fp8_kvcache, + topk, + + extra_k_page_block_size, + extra_topk, + ) + else: + # Check whether the input arguments are consistent with sched_meta + helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta." + assert sched_meta.config is not None + assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg + assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg + assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg + assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg + assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg + assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg + assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg + assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg + assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg + assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg + + if topk is not None: + # Sparse attention + assert not causal, "causal must be False when sparse attention is enabled" + assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled" + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd( + q, k_cache, indices_in_kvcache, topk_length, attn_sink, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits, + extra_k_cache, extra_indices_in_kvcache, extra_topk_length, + head_dim_v, softmax_scale + ) + else: + # Dense attention + assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used." + assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used." + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd( + q, k_cache, head_dim_v, + cache_seqlens, block_table, + softmax_scale, causal, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits + ) + sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata + sched_meta.num_splits = new_num_splits + return (out, lse) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. Can only be 512 + attn_sink: optional, [h_q], float32. + If attn_sink is provided, when computing output, output will be additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)). + +-inf in attn_sink will be handled normally (i.e., -inf has no effect, +inf will make corresponding output all zeros). + This argument has no effect on lse and max_logits. + topk_length: optional, [s_q], int32. If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if provided in indices). + In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), operator output will contain NaN, so please avoid this situation. + + Returns: + (output, max_logits, lse) + Please refer to tests/ref.py for the precise definitions of these parameters. + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, log-sum-exp of attention scores + """ + results = flash_mla_cuda.sparse_prefill_fwd( + q, kv, indices, sm_scale, d_v, attn_sink, topk_length + ) + return results + + +def _flash_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if out is None: + out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype) + if lse is None: + # Make lse contiguous on seqlen dim + lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_fwd( + workspace_buffer, + q, + k, + v, + cu_seqlens_qo, + cu_seqlens_kv, + out, + lse, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return out, lse + + +def _flash_attn_varlen_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + # TODO: fix bwd GQA + if num_qo_heads != num_kv_heads: + raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.") + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if dq is None: + dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dk is None: + dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dv is None: + dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype) + + max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 + bs = cu_seqlens_qo.shape[0] - 1 + workspace_bytes = 0 + workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc + workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse + if num_qo_heads != num_kv_heads: + workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc + workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_bwd( + workspace_buffer, + do, + q, + k, + v, + out, + lse, + cu_seqlens_qo, + cu_seqlens_kv, + dq, + dk, + dv, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return dq, dk, dv + + +class FlashAttnVarlenFunc(torch.autograd.Function): + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + out, lse = _flash_attn_varlen_forward( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal=causal, softmax_scale=softmax_scale, + is_varlen=is_varlen, + ) + ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv) + ctx.max_seqlen_qo = max_seqlen_qo + ctx.max_seqlen_kv = max_seqlen_kv + ctx.causal = causal + ctx.softmax_scale = softmax_scale + ctx.is_varlen = is_varlen + return out, lse + + def backward( + ctx, + do: torch.Tensor, + dlse: torch.Tensor, + ): + del dlse # LSE doesn't support backward currently + q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors + dq, dk, dv = _flash_attn_varlen_backward( + do, q, k, v, out, lse, + cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv, + causal=ctx.causal, softmax_scale=ctx.softmax_scale, + is_varlen=ctx.is_varlen, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:], + cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:], + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) diff --git a/build/torch210-cxx11-cu130-x86_64-linux/metadata.json b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..8190d75efa8fd6449ddcd73de2072f17086e0842 --- /dev/null +++ b/build/torch210-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1,12 @@ +{ + "version": 1, + "license": "MIT", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0f", + "9.0a" + ] + } +} diff --git a/build/torch211-cxx11-cu128-aarch64-linux/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db300fe9b95176a20b27b3641d89be657d0c4319 --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/__init__.py @@ -0,0 +1,163 @@ +from typing import Optional, Tuple +import torch + +from .flash_mla_interface import FlashMLASchedMeta +from . import flash_mla_interface as _impl + + +def get_mla_metadata(*args, **kwargs) -> Tuple[FlashMLASchedMeta, None]: + return _impl.get_mla_metadata(*args, **kwargs) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_mla_with_kvcache( + q=q, + k_cache=k_cache, + block_table=block_table, + cache_seqlens=cache_seqlens, + head_dim_v=head_dim_v, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + softmax_scale=softmax_scale, + causal=causal, + is_fp8_kvcache=is_fp8_kvcache, + indices=indices, + attn_sink=attn_sink, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices_in_kvcache, + topk_length=topk_length, + extra_topk_length=extra_topk_length, + ) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return _impl.flash_mla_sparse_fwd( + q=q, + kv=kv, + indices=indices, + sm_scale=sm_scale, + d_v=d_v, + attn_sink=attn_sink, + topk_length=topk_length, + ) + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_qkvpacked_func( + qkv=qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +__all__ = [ + "__version__", + "FlashMLASchedMeta", + "get_mla_metadata", + "flash_mla_with_kvcache", + "flash_attn_varlen_func", + "flash_attn_varlen_qkvpacked_func", + "flash_attn_varlen_kvpacked_func", + "flash_mla_sparse_fwd", +] diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_flash_mla_cuda_89d7fc1.abi3.so b/build/torch211-cxx11-cu128-aarch64-linux/_flash_mla_cuda_89d7fc1.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..77f8e71ce401715c6b75f5b4d2098e5ab3e6136f --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/_flash_mla_cuda_89d7fc1.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7e4dbe8aa77fe391658f60bddcfaa7b9ffd1ab34a62d870a8fc2db40752faf88 +size 3667336 diff --git a/build/torch211-cxx11-cu128-aarch64-linux/_ops.py b/build/torch211-cxx11-cu128-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd5978e1b450acfa19dfe796dceea8fe35f5736 --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_cuda_89d7fc1 +ops = torch.ops._flash_mla_cuda_89d7fc1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_cuda_89d7fc1::{op_name}" diff --git a/build/torch211-cxx11-cu128-aarch64-linux/flash_mla/__init__.py b/build/torch211-cxx11-cu128-aarch64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/flash_mla/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu128-aarch64-linux/flash_mla_interface.py b/build/torch211-cxx11-cu128-aarch64-linux/flash_mla_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..a84e448ffe741bb6d3dafaf7888ed8cc94984467 --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/flash_mla_interface.py @@ -0,0 +1,435 @@ +from typing import Optional, Tuple +import dataclasses + +import torch + +from ._ops import ops as flash_mla_cuda + +@dataclasses.dataclass +class FlashMLASchedMeta: + """ + A class that stores the tile scheduler metadata of FlashMLA + """ + + @dataclasses.dataclass + class Config: + b: int + s_q: int + h_q: int + page_block_size: int + h_k: int + + causal: bool + is_fp8_kvcache: bool + topk: Optional[int] + + extra_page_block_size: Optional[int] + extra_topk: Optional[int] + + have_initialized: bool = False + + config: Optional[Config] = None + + tile_scheduler_metadata: Optional[torch.Tensor] = None # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: Optional[torch.Tensor] = None # (1), dtype torch.int32. + + +def get_mla_metadata( + *args, + **kwargs +) -> Tuple[FlashMLASchedMeta, None]: + """ + Returns an empty instance of FlashMLASchedMeta. The actual scheduling metadata will be generated during the first invocation of flash_mla_with_kvcache. + + Arguments: + This function does not need any arguments, but we keep *args and **kwargs to be compatible with the old interface. + + Return: + A tuple. Due to historical reasons, we return a tuple of (FlashMLASchedMeta, None) now. Only the first element is useful. + """ + return FlashMLASchedMeta(), None + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details. + The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks. + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used. + cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. + head_dim_v: Head_dim of v. Must be 512 + sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same. + num_splits_placeholder: must be "None" (to be compatible with the old interface). + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k). + causal: bool. Whether to apply causal attention mask. Only valid for dense attention + is_fp8_kvcache: bool. + indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled. + Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block), + where t is the k-th token of the j-th q-sequence in the i-th batch. + attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0. + extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. Their format requirements are the same as k_cache and indices_in_kvcache respectively. + topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking. + + For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2: + head_dim should be 576 while head_dim_v should be 512. + In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as: + - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1. + - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values. + - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on. + - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + sched_meta = tile_scheduler_metadata + indices_in_kvcache = indices + assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta" + assert num_splits is None, "num_splits must be None" + + topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None + extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None + extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + if not sched_meta.have_initialized: + # Sanity check. We only perform sanity check during the first invocation to save CPU time. + if indices_in_kvcache is not None: + assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)" + + # Initialize the tile scheduler metadata during the first invocation. + sched_meta.have_initialized = True + sched_meta.config = FlashMLASchedMeta.Config( + q.shape[0], + q.shape[1], + q.shape[2], + k_cache.shape[1], + k_cache.shape[2], + + causal, + is_fp8_kvcache, + topk, + + extra_k_page_block_size, + extra_topk, + ) + else: + # Check whether the input arguments are consistent with sched_meta + helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta." + assert sched_meta.config is not None + assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg + assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg + assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg + assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg + assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg + assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg + assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg + assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg + assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg + assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg + + if topk is not None: + # Sparse attention + assert not causal, "causal must be False when sparse attention is enabled" + assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled" + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd( + q, k_cache, indices_in_kvcache, topk_length, attn_sink, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits, + extra_k_cache, extra_indices_in_kvcache, extra_topk_length, + head_dim_v, softmax_scale + ) + else: + # Dense attention + assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used." + assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used." + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd( + q, k_cache, head_dim_v, + cache_seqlens, block_table, + softmax_scale, causal, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits + ) + sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata + sched_meta.num_splits = new_num_splits + return (out, lse) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. Can only be 512 + attn_sink: optional, [h_q], float32. + If attn_sink is provided, when computing output, output will be additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)). + +-inf in attn_sink will be handled normally (i.e., -inf has no effect, +inf will make corresponding output all zeros). + This argument has no effect on lse and max_logits. + topk_length: optional, [s_q], int32. If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if provided in indices). + In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), operator output will contain NaN, so please avoid this situation. + + Returns: + (output, max_logits, lse) + Please refer to tests/ref.py for the precise definitions of these parameters. + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, log-sum-exp of attention scores + """ + results = flash_mla_cuda.sparse_prefill_fwd( + q, kv, indices, sm_scale, d_v, attn_sink, topk_length + ) + return results + + +def _flash_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if out is None: + out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype) + if lse is None: + # Make lse contiguous on seqlen dim + lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_fwd( + workspace_buffer, + q, + k, + v, + cu_seqlens_qo, + cu_seqlens_kv, + out, + lse, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return out, lse + + +def _flash_attn_varlen_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + # TODO: fix bwd GQA + if num_qo_heads != num_kv_heads: + raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.") + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if dq is None: + dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dk is None: + dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dv is None: + dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype) + + max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 + bs = cu_seqlens_qo.shape[0] - 1 + workspace_bytes = 0 + workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc + workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse + if num_qo_heads != num_kv_heads: + workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc + workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_bwd( + workspace_buffer, + do, + q, + k, + v, + out, + lse, + cu_seqlens_qo, + cu_seqlens_kv, + dq, + dk, + dv, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return dq, dk, dv + + +class FlashAttnVarlenFunc(torch.autograd.Function): + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + out, lse = _flash_attn_varlen_forward( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal=causal, softmax_scale=softmax_scale, + is_varlen=is_varlen, + ) + ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv) + ctx.max_seqlen_qo = max_seqlen_qo + ctx.max_seqlen_kv = max_seqlen_kv + ctx.causal = causal + ctx.softmax_scale = softmax_scale + ctx.is_varlen = is_varlen + return out, lse + + def backward( + ctx, + do: torch.Tensor, + dlse: torch.Tensor, + ): + del dlse # LSE doesn't support backward currently + q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors + dq, dk, dv = _flash_attn_varlen_backward( + do, q, k, v, out, lse, + cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv, + causal=ctx.causal, softmax_scale=ctx.softmax_scale, + is_varlen=ctx.is_varlen, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:], + cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:], + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) diff --git a/build/torch211-cxx11-cu128-aarch64-linux/metadata.json b/build/torch211-cxx11-cu128-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..4899badb63d45293425e2164944268b6058af95d --- /dev/null +++ b/build/torch211-cxx11-cu128-aarch64-linux/metadata.json @@ -0,0 +1,11 @@ +{ + "version": 1, + "license": "MIT", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "9.0a" + ] + } +} diff --git a/build/torch211-cxx11-cu128-x86_64-linux/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db300fe9b95176a20b27b3641d89be657d0c4319 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,163 @@ +from typing import Optional, Tuple +import torch + +from .flash_mla_interface import FlashMLASchedMeta +from . import flash_mla_interface as _impl + + +def get_mla_metadata(*args, **kwargs) -> Tuple[FlashMLASchedMeta, None]: + return _impl.get_mla_metadata(*args, **kwargs) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_mla_with_kvcache( + q=q, + k_cache=k_cache, + block_table=block_table, + cache_seqlens=cache_seqlens, + head_dim_v=head_dim_v, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + softmax_scale=softmax_scale, + causal=causal, + is_fp8_kvcache=is_fp8_kvcache, + indices=indices, + attn_sink=attn_sink, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices_in_kvcache, + topk_length=topk_length, + extra_topk_length=extra_topk_length, + ) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return _impl.flash_mla_sparse_fwd( + q=q, + kv=kv, + indices=indices, + sm_scale=sm_scale, + d_v=d_v, + attn_sink=attn_sink, + topk_length=topk_length, + ) + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_qkvpacked_func( + qkv=qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +__all__ = [ + "__version__", + "FlashMLASchedMeta", + "get_mla_metadata", + "flash_mla_with_kvcache", + "flash_attn_varlen_func", + "flash_attn_varlen_qkvpacked_func", + "flash_attn_varlen_kvpacked_func", + "flash_mla_sparse_fwd", +] diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_flash_mla_cuda_89d7fc1.abi3.so b/build/torch211-cxx11-cu128-x86_64-linux/_flash_mla_cuda_89d7fc1.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..2ed92062db7dc5c98ec2fa0b8258e4fdbefc2353 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_flash_mla_cuda_89d7fc1.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3a1b76510b5134347620c4874e7f3e272dcbe3360e522fd7ffaede1d7f850cf6 +size 3522920 diff --git a/build/torch211-cxx11-cu128-x86_64-linux/_ops.py b/build/torch211-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd5978e1b450acfa19dfe796dceea8fe35f5736 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_cuda_89d7fc1 +ops = torch.ops._flash_mla_cuda_89d7fc1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_cuda_89d7fc1::{op_name}" diff --git a/build/torch211-cxx11-cu128-x86_64-linux/flash_mla/__init__.py b/build/torch211-cxx11-cu128-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/flash_mla_interface.py b/build/torch211-cxx11-cu128-x86_64-linux/flash_mla_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..a84e448ffe741bb6d3dafaf7888ed8cc94984467 --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/flash_mla_interface.py @@ -0,0 +1,435 @@ +from typing import Optional, Tuple +import dataclasses + +import torch + +from ._ops import ops as flash_mla_cuda + +@dataclasses.dataclass +class FlashMLASchedMeta: + """ + A class that stores the tile scheduler metadata of FlashMLA + """ + + @dataclasses.dataclass + class Config: + b: int + s_q: int + h_q: int + page_block_size: int + h_k: int + + causal: bool + is_fp8_kvcache: bool + topk: Optional[int] + + extra_page_block_size: Optional[int] + extra_topk: Optional[int] + + have_initialized: bool = False + + config: Optional[Config] = None + + tile_scheduler_metadata: Optional[torch.Tensor] = None # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: Optional[torch.Tensor] = None # (1), dtype torch.int32. + + +def get_mla_metadata( + *args, + **kwargs +) -> Tuple[FlashMLASchedMeta, None]: + """ + Returns an empty instance of FlashMLASchedMeta. The actual scheduling metadata will be generated during the first invocation of flash_mla_with_kvcache. + + Arguments: + This function does not need any arguments, but we keep *args and **kwargs to be compatible with the old interface. + + Return: + A tuple. Due to historical reasons, we return a tuple of (FlashMLASchedMeta, None) now. Only the first element is useful. + """ + return FlashMLASchedMeta(), None + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details. + The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks. + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used. + cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. + head_dim_v: Head_dim of v. Must be 512 + sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same. + num_splits_placeholder: must be "None" (to be compatible with the old interface). + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k). + causal: bool. Whether to apply causal attention mask. Only valid for dense attention + is_fp8_kvcache: bool. + indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled. + Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block), + where t is the k-th token of the j-th q-sequence in the i-th batch. + attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0. + extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. Their format requirements are the same as k_cache and indices_in_kvcache respectively. + topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking. + + For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2: + head_dim should be 576 while head_dim_v should be 512. + In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as: + - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1. + - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values. + - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on. + - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + sched_meta = tile_scheduler_metadata + indices_in_kvcache = indices + assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta" + assert num_splits is None, "num_splits must be None" + + topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None + extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None + extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + if not sched_meta.have_initialized: + # Sanity check. We only perform sanity check during the first invocation to save CPU time. + if indices_in_kvcache is not None: + assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)" + + # Initialize the tile scheduler metadata during the first invocation. + sched_meta.have_initialized = True + sched_meta.config = FlashMLASchedMeta.Config( + q.shape[0], + q.shape[1], + q.shape[2], + k_cache.shape[1], + k_cache.shape[2], + + causal, + is_fp8_kvcache, + topk, + + extra_k_page_block_size, + extra_topk, + ) + else: + # Check whether the input arguments are consistent with sched_meta + helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta." + assert sched_meta.config is not None + assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg + assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg + assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg + assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg + assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg + assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg + assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg + assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg + assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg + assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg + + if topk is not None: + # Sparse attention + assert not causal, "causal must be False when sparse attention is enabled" + assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled" + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd( + q, k_cache, indices_in_kvcache, topk_length, attn_sink, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits, + extra_k_cache, extra_indices_in_kvcache, extra_topk_length, + head_dim_v, softmax_scale + ) + else: + # Dense attention + assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used." + assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used." + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd( + q, k_cache, head_dim_v, + cache_seqlens, block_table, + softmax_scale, causal, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits + ) + sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata + sched_meta.num_splits = new_num_splits + return (out, lse) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. Can only be 512 + attn_sink: optional, [h_q], float32. + If attn_sink is provided, when computing output, output will be additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)). + +-inf in attn_sink will be handled normally (i.e., -inf has no effect, +inf will make corresponding output all zeros). + This argument has no effect on lse and max_logits. + topk_length: optional, [s_q], int32. If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if provided in indices). + In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), operator output will contain NaN, so please avoid this situation. + + Returns: + (output, max_logits, lse) + Please refer to tests/ref.py for the precise definitions of these parameters. + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, log-sum-exp of attention scores + """ + results = flash_mla_cuda.sparse_prefill_fwd( + q, kv, indices, sm_scale, d_v, attn_sink, topk_length + ) + return results + + +def _flash_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if out is None: + out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype) + if lse is None: + # Make lse contiguous on seqlen dim + lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_fwd( + workspace_buffer, + q, + k, + v, + cu_seqlens_qo, + cu_seqlens_kv, + out, + lse, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return out, lse + + +def _flash_attn_varlen_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + # TODO: fix bwd GQA + if num_qo_heads != num_kv_heads: + raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.") + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if dq is None: + dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dk is None: + dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dv is None: + dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype) + + max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 + bs = cu_seqlens_qo.shape[0] - 1 + workspace_bytes = 0 + workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc + workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse + if num_qo_heads != num_kv_heads: + workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc + workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_bwd( + workspace_buffer, + do, + q, + k, + v, + out, + lse, + cu_seqlens_qo, + cu_seqlens_kv, + dq, + dk, + dv, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return dq, dk, dv + + +class FlashAttnVarlenFunc(torch.autograd.Function): + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + out, lse = _flash_attn_varlen_forward( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal=causal, softmax_scale=softmax_scale, + is_varlen=is_varlen, + ) + ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv) + ctx.max_seqlen_qo = max_seqlen_qo + ctx.max_seqlen_kv = max_seqlen_kv + ctx.causal = causal + ctx.softmax_scale = softmax_scale + ctx.is_varlen = is_varlen + return out, lse + + def backward( + ctx, + do: torch.Tensor, + dlse: torch.Tensor, + ): + del dlse # LSE doesn't support backward currently + q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors + dq, dk, dv = _flash_attn_varlen_backward( + do, q, k, v, out, lse, + cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv, + causal=ctx.causal, softmax_scale=ctx.softmax_scale, + is_varlen=ctx.is_varlen, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:], + cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:], + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) diff --git a/build/torch211-cxx11-cu128-x86_64-linux/metadata.json b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..4899badb63d45293425e2164944268b6058af95d --- /dev/null +++ b/build/torch211-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1,11 @@ +{ + "version": 1, + "license": "MIT", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "9.0a" + ] + } +} diff --git a/build/torch211-cxx11-cu130-aarch64-linux/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db300fe9b95176a20b27b3641d89be657d0c4319 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/__init__.py @@ -0,0 +1,163 @@ +from typing import Optional, Tuple +import torch + +from .flash_mla_interface import FlashMLASchedMeta +from . import flash_mla_interface as _impl + + +def get_mla_metadata(*args, **kwargs) -> Tuple[FlashMLASchedMeta, None]: + return _impl.get_mla_metadata(*args, **kwargs) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_mla_with_kvcache( + q=q, + k_cache=k_cache, + block_table=block_table, + cache_seqlens=cache_seqlens, + head_dim_v=head_dim_v, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + softmax_scale=softmax_scale, + causal=causal, + is_fp8_kvcache=is_fp8_kvcache, + indices=indices, + attn_sink=attn_sink, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices_in_kvcache, + topk_length=topk_length, + extra_topk_length=extra_topk_length, + ) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return _impl.flash_mla_sparse_fwd( + q=q, + kv=kv, + indices=indices, + sm_scale=sm_scale, + d_v=d_v, + attn_sink=attn_sink, + topk_length=topk_length, + ) + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_qkvpacked_func( + qkv=qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +__all__ = [ + "__version__", + "FlashMLASchedMeta", + "get_mla_metadata", + "flash_mla_with_kvcache", + "flash_attn_varlen_func", + "flash_attn_varlen_qkvpacked_func", + "flash_attn_varlen_kvpacked_func", + "flash_mla_sparse_fwd", +] diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_flash_mla_cuda_89d7fc1.abi3.so b/build/torch211-cxx11-cu130-aarch64-linux/_flash_mla_cuda_89d7fc1.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..f1fb3727ad09cd93f620991815595b1822d9f875 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/_flash_mla_cuda_89d7fc1.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:24e204ec4ae541edbec284824951032598375f8b63c7e06d622543b42f6f339f +size 9441328 diff --git a/build/torch211-cxx11-cu130-aarch64-linux/_ops.py b/build/torch211-cxx11-cu130-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd5978e1b450acfa19dfe796dceea8fe35f5736 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_cuda_89d7fc1 +ops = torch.ops._flash_mla_cuda_89d7fc1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_cuda_89d7fc1::{op_name}" diff --git a/build/torch211-cxx11-cu130-aarch64-linux/flash_mla/__init__.py b/build/torch211-cxx11-cu130-aarch64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/flash_mla/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu130-aarch64-linux/flash_mla_interface.py b/build/torch211-cxx11-cu130-aarch64-linux/flash_mla_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..a84e448ffe741bb6d3dafaf7888ed8cc94984467 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/flash_mla_interface.py @@ -0,0 +1,435 @@ +from typing import Optional, Tuple +import dataclasses + +import torch + +from ._ops import ops as flash_mla_cuda + +@dataclasses.dataclass +class FlashMLASchedMeta: + """ + A class that stores the tile scheduler metadata of FlashMLA + """ + + @dataclasses.dataclass + class Config: + b: int + s_q: int + h_q: int + page_block_size: int + h_k: int + + causal: bool + is_fp8_kvcache: bool + topk: Optional[int] + + extra_page_block_size: Optional[int] + extra_topk: Optional[int] + + have_initialized: bool = False + + config: Optional[Config] = None + + tile_scheduler_metadata: Optional[torch.Tensor] = None # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: Optional[torch.Tensor] = None # (1), dtype torch.int32. + + +def get_mla_metadata( + *args, + **kwargs +) -> Tuple[FlashMLASchedMeta, None]: + """ + Returns an empty instance of FlashMLASchedMeta. The actual scheduling metadata will be generated during the first invocation of flash_mla_with_kvcache. + + Arguments: + This function does not need any arguments, but we keep *args and **kwargs to be compatible with the old interface. + + Return: + A tuple. Due to historical reasons, we return a tuple of (FlashMLASchedMeta, None) now. Only the first element is useful. + """ + return FlashMLASchedMeta(), None + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details. + The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks. + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used. + cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. + head_dim_v: Head_dim of v. Must be 512 + sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same. + num_splits_placeholder: must be "None" (to be compatible with the old interface). + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k). + causal: bool. Whether to apply causal attention mask. Only valid for dense attention + is_fp8_kvcache: bool. + indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled. + Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block), + where t is the k-th token of the j-th q-sequence in the i-th batch. + attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0. + extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. Their format requirements are the same as k_cache and indices_in_kvcache respectively. + topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking. + + For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2: + head_dim should be 576 while head_dim_v should be 512. + In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as: + - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1. + - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values. + - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on. + - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + sched_meta = tile_scheduler_metadata + indices_in_kvcache = indices + assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta" + assert num_splits is None, "num_splits must be None" + + topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None + extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None + extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + if not sched_meta.have_initialized: + # Sanity check. We only perform sanity check during the first invocation to save CPU time. + if indices_in_kvcache is not None: + assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)" + + # Initialize the tile scheduler metadata during the first invocation. + sched_meta.have_initialized = True + sched_meta.config = FlashMLASchedMeta.Config( + q.shape[0], + q.shape[1], + q.shape[2], + k_cache.shape[1], + k_cache.shape[2], + + causal, + is_fp8_kvcache, + topk, + + extra_k_page_block_size, + extra_topk, + ) + else: + # Check whether the input arguments are consistent with sched_meta + helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta." + assert sched_meta.config is not None + assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg + assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg + assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg + assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg + assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg + assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg + assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg + assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg + assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg + assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg + + if topk is not None: + # Sparse attention + assert not causal, "causal must be False when sparse attention is enabled" + assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled" + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd( + q, k_cache, indices_in_kvcache, topk_length, attn_sink, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits, + extra_k_cache, extra_indices_in_kvcache, extra_topk_length, + head_dim_v, softmax_scale + ) + else: + # Dense attention + assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used." + assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used." + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd( + q, k_cache, head_dim_v, + cache_seqlens, block_table, + softmax_scale, causal, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits + ) + sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata + sched_meta.num_splits = new_num_splits + return (out, lse) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. Can only be 512 + attn_sink: optional, [h_q], float32. + If attn_sink is provided, when computing output, output will be additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)). + +-inf in attn_sink will be handled normally (i.e., -inf has no effect, +inf will make corresponding output all zeros). + This argument has no effect on lse and max_logits. + topk_length: optional, [s_q], int32. If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if provided in indices). + In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), operator output will contain NaN, so please avoid this situation. + + Returns: + (output, max_logits, lse) + Please refer to tests/ref.py for the precise definitions of these parameters. + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, log-sum-exp of attention scores + """ + results = flash_mla_cuda.sparse_prefill_fwd( + q, kv, indices, sm_scale, d_v, attn_sink, topk_length + ) + return results + + +def _flash_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if out is None: + out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype) + if lse is None: + # Make lse contiguous on seqlen dim + lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_fwd( + workspace_buffer, + q, + k, + v, + cu_seqlens_qo, + cu_seqlens_kv, + out, + lse, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return out, lse + + +def _flash_attn_varlen_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + # TODO: fix bwd GQA + if num_qo_heads != num_kv_heads: + raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.") + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if dq is None: + dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dk is None: + dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dv is None: + dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype) + + max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 + bs = cu_seqlens_qo.shape[0] - 1 + workspace_bytes = 0 + workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc + workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse + if num_qo_heads != num_kv_heads: + workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc + workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_bwd( + workspace_buffer, + do, + q, + k, + v, + out, + lse, + cu_seqlens_qo, + cu_seqlens_kv, + dq, + dk, + dv, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return dq, dk, dv + + +class FlashAttnVarlenFunc(torch.autograd.Function): + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + out, lse = _flash_attn_varlen_forward( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal=causal, softmax_scale=softmax_scale, + is_varlen=is_varlen, + ) + ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv) + ctx.max_seqlen_qo = max_seqlen_qo + ctx.max_seqlen_kv = max_seqlen_kv + ctx.causal = causal + ctx.softmax_scale = softmax_scale + ctx.is_varlen = is_varlen + return out, lse + + def backward( + ctx, + do: torch.Tensor, + dlse: torch.Tensor, + ): + del dlse # LSE doesn't support backward currently + q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors + dq, dk, dv = _flash_attn_varlen_backward( + do, q, k, v, out, lse, + cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv, + causal=ctx.causal, softmax_scale=ctx.softmax_scale, + is_varlen=ctx.is_varlen, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:], + cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:], + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) diff --git a/build/torch211-cxx11-cu130-aarch64-linux/metadata.json b/build/torch211-cxx11-cu130-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..8190d75efa8fd6449ddcd73de2072f17086e0842 --- /dev/null +++ b/build/torch211-cxx11-cu130-aarch64-linux/metadata.json @@ -0,0 +1,12 @@ +{ + "version": 1, + "license": "MIT", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0f", + "9.0a" + ] + } +} diff --git a/build/torch211-cxx11-cu130-x86_64-linux/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db300fe9b95176a20b27b3641d89be657d0c4319 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/__init__.py @@ -0,0 +1,163 @@ +from typing import Optional, Tuple +import torch + +from .flash_mla_interface import FlashMLASchedMeta +from . import flash_mla_interface as _impl + + +def get_mla_metadata(*args, **kwargs) -> Tuple[FlashMLASchedMeta, None]: + return _impl.get_mla_metadata(*args, **kwargs) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_mla_with_kvcache( + q=q, + k_cache=k_cache, + block_table=block_table, + cache_seqlens=cache_seqlens, + head_dim_v=head_dim_v, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + softmax_scale=softmax_scale, + causal=causal, + is_fp8_kvcache=is_fp8_kvcache, + indices=indices, + attn_sink=attn_sink, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices_in_kvcache, + topk_length=topk_length, + extra_topk_length=extra_topk_length, + ) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return _impl.flash_mla_sparse_fwd( + q=q, + kv=kv, + indices=indices, + sm_scale=sm_scale, + d_v=d_v, + attn_sink=attn_sink, + topk_length=topk_length, + ) + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_qkvpacked_func( + qkv=qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +__all__ = [ + "__version__", + "FlashMLASchedMeta", + "get_mla_metadata", + "flash_mla_with_kvcache", + "flash_attn_varlen_func", + "flash_attn_varlen_qkvpacked_func", + "flash_attn_varlen_kvpacked_func", + "flash_mla_sparse_fwd", +] diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_flash_mla_cuda_89d7fc1.abi3.so b/build/torch211-cxx11-cu130-x86_64-linux/_flash_mla_cuda_89d7fc1.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..37e43160775e17348e148cacb526a2d4cd2d72a0 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_flash_mla_cuda_89d7fc1.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7e585a4092d2aaab0ba89ddf1e266034c0e574fd9388e5a6d2cb21948e80cce0 +size 9383872 diff --git a/build/torch211-cxx11-cu130-x86_64-linux/_ops.py b/build/torch211-cxx11-cu130-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd5978e1b450acfa19dfe796dceea8fe35f5736 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_cuda_89d7fc1 +ops = torch.ops._flash_mla_cuda_89d7fc1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_cuda_89d7fc1::{op_name}" diff --git a/build/torch211-cxx11-cu130-x86_64-linux/flash_mla/__init__.py b/build/torch211-cxx11-cu130-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/flash_mla_interface.py b/build/torch211-cxx11-cu130-x86_64-linux/flash_mla_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..a84e448ffe741bb6d3dafaf7888ed8cc94984467 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/flash_mla_interface.py @@ -0,0 +1,435 @@ +from typing import Optional, Tuple +import dataclasses + +import torch + +from ._ops import ops as flash_mla_cuda + +@dataclasses.dataclass +class FlashMLASchedMeta: + """ + A class that stores the tile scheduler metadata of FlashMLA + """ + + @dataclasses.dataclass + class Config: + b: int + s_q: int + h_q: int + page_block_size: int + h_k: int + + causal: bool + is_fp8_kvcache: bool + topk: Optional[int] + + extra_page_block_size: Optional[int] + extra_topk: Optional[int] + + have_initialized: bool = False + + config: Optional[Config] = None + + tile_scheduler_metadata: Optional[torch.Tensor] = None # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: Optional[torch.Tensor] = None # (1), dtype torch.int32. + + +def get_mla_metadata( + *args, + **kwargs +) -> Tuple[FlashMLASchedMeta, None]: + """ + Returns an empty instance of FlashMLASchedMeta. The actual scheduling metadata will be generated during the first invocation of flash_mla_with_kvcache. + + Arguments: + This function does not need any arguments, but we keep *args and **kwargs to be compatible with the old interface. + + Return: + A tuple. Due to historical reasons, we return a tuple of (FlashMLASchedMeta, None) now. Only the first element is useful. + """ + return FlashMLASchedMeta(), None + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details. + The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks. + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used. + cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. + head_dim_v: Head_dim of v. Must be 512 + sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same. + num_splits_placeholder: must be "None" (to be compatible with the old interface). + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k). + causal: bool. Whether to apply causal attention mask. Only valid for dense attention + is_fp8_kvcache: bool. + indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled. + Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block), + where t is the k-th token of the j-th q-sequence in the i-th batch. + attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0. + extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. Their format requirements are the same as k_cache and indices_in_kvcache respectively. + topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking. + + For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2: + head_dim should be 576 while head_dim_v should be 512. + In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as: + - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1. + - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values. + - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on. + - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + sched_meta = tile_scheduler_metadata + indices_in_kvcache = indices + assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta" + assert num_splits is None, "num_splits must be None" + + topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None + extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None + extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + if not sched_meta.have_initialized: + # Sanity check. We only perform sanity check during the first invocation to save CPU time. + if indices_in_kvcache is not None: + assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)" + + # Initialize the tile scheduler metadata during the first invocation. + sched_meta.have_initialized = True + sched_meta.config = FlashMLASchedMeta.Config( + q.shape[0], + q.shape[1], + q.shape[2], + k_cache.shape[1], + k_cache.shape[2], + + causal, + is_fp8_kvcache, + topk, + + extra_k_page_block_size, + extra_topk, + ) + else: + # Check whether the input arguments are consistent with sched_meta + helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta." + assert sched_meta.config is not None + assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg + assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg + assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg + assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg + assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg + assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg + assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg + assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg + assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg + assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg + + if topk is not None: + # Sparse attention + assert not causal, "causal must be False when sparse attention is enabled" + assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled" + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd( + q, k_cache, indices_in_kvcache, topk_length, attn_sink, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits, + extra_k_cache, extra_indices_in_kvcache, extra_topk_length, + head_dim_v, softmax_scale + ) + else: + # Dense attention + assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used." + assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used." + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd( + q, k_cache, head_dim_v, + cache_seqlens, block_table, + softmax_scale, causal, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits + ) + sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata + sched_meta.num_splits = new_num_splits + return (out, lse) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. Can only be 512 + attn_sink: optional, [h_q], float32. + If attn_sink is provided, when computing output, output will be additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)). + +-inf in attn_sink will be handled normally (i.e., -inf has no effect, +inf will make corresponding output all zeros). + This argument has no effect on lse and max_logits. + topk_length: optional, [s_q], int32. If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if provided in indices). + In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), operator output will contain NaN, so please avoid this situation. + + Returns: + (output, max_logits, lse) + Please refer to tests/ref.py for the precise definitions of these parameters. + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, log-sum-exp of attention scores + """ + results = flash_mla_cuda.sparse_prefill_fwd( + q, kv, indices, sm_scale, d_v, attn_sink, topk_length + ) + return results + + +def _flash_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if out is None: + out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype) + if lse is None: + # Make lse contiguous on seqlen dim + lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_fwd( + workspace_buffer, + q, + k, + v, + cu_seqlens_qo, + cu_seqlens_kv, + out, + lse, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return out, lse + + +def _flash_attn_varlen_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + # TODO: fix bwd GQA + if num_qo_heads != num_kv_heads: + raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.") + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if dq is None: + dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dk is None: + dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dv is None: + dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype) + + max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 + bs = cu_seqlens_qo.shape[0] - 1 + workspace_bytes = 0 + workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc + workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse + if num_qo_heads != num_kv_heads: + workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc + workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_bwd( + workspace_buffer, + do, + q, + k, + v, + out, + lse, + cu_seqlens_qo, + cu_seqlens_kv, + dq, + dk, + dv, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return dq, dk, dv + + +class FlashAttnVarlenFunc(torch.autograd.Function): + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + out, lse = _flash_attn_varlen_forward( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal=causal, softmax_scale=softmax_scale, + is_varlen=is_varlen, + ) + ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv) + ctx.max_seqlen_qo = max_seqlen_qo + ctx.max_seqlen_kv = max_seqlen_kv + ctx.causal = causal + ctx.softmax_scale = softmax_scale + ctx.is_varlen = is_varlen + return out, lse + + def backward( + ctx, + do: torch.Tensor, + dlse: torch.Tensor, + ): + del dlse # LSE doesn't support backward currently + q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors + dq, dk, dv = _flash_attn_varlen_backward( + do, q, k, v, out, lse, + cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv, + causal=ctx.causal, softmax_scale=ctx.softmax_scale, + is_varlen=ctx.is_varlen, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:], + cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:], + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) diff --git a/build/torch211-cxx11-cu130-x86_64-linux/metadata.json b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..8190d75efa8fd6449ddcd73de2072f17086e0842 --- /dev/null +++ b/build/torch211-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1,12 @@ +{ + "version": 1, + "license": "MIT", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0f", + "9.0a" + ] + } +} diff --git a/build/torch25-cxx11-cu118-x86_64-linux/flash_mla/__init__.py b/build/torch25-cxx11-cu118-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a129688bf920047cca1ca70a9c3521a9b4182 --- /dev/null +++ b/build/torch25-cxx11-cu118-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,33 @@ +import torch + +from ._ops import ops + + +def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int): + return ops.get_mla_metadata(seqlens_k, s_q, h_kv) + + +def mha_fwd_kvcache_mla( + q: torch.Tensor, + kcache: torch.Tensor, + vcache_: torch.Tensor, + head_size_v: int, + seqlens_k: torch.Tensor, + block_table: torch.Tensor, + softmax_scale: float, + is_causal_: bool, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, +) -> torch.Tensor: + return ops.mha_fwd_kvcache_mla( + q, + kcache, + vcache_, + head_size_v, + seqlens_k, + block_table, + softmax_scale, + is_causal_, + tile_scheduler_metadata, + num_splits + ) diff --git a/build/torch25-cxx11-cu118-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so b/build/torch25-cxx11-cu118-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..495e029409a66d955828ec98a963d997e7f55803 --- /dev/null +++ b/build/torch25-cxx11-cu118-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a9a2b99b276b5aa714b27d1f54cc5da2d451e65a9ed385c583daf528f2c030a9 +size 2564144 diff --git a/build/torch25-cxx11-cu118-x86_64-linux/flash_mla/_ops.py b/build/torch25-cxx11-cu118-x86_64-linux/flash_mla/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5aa4c56e24b75711edaa5d90f25828a6eb2484 --- /dev/null +++ b/build/torch25-cxx11-cu118-x86_64-linux/flash_mla/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_d4f4195 +ops = torch.ops._flash_mla_d4f4195 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_d4f4195::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx11-cu121-x86_64-linux/flash_mla/__init__.py b/build/torch25-cxx11-cu121-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a129688bf920047cca1ca70a9c3521a9b4182 --- /dev/null +++ b/build/torch25-cxx11-cu121-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,33 @@ +import torch + +from ._ops import ops + + +def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int): + return ops.get_mla_metadata(seqlens_k, s_q, h_kv) + + +def mha_fwd_kvcache_mla( + q: torch.Tensor, + kcache: torch.Tensor, + vcache_: torch.Tensor, + head_size_v: int, + seqlens_k: torch.Tensor, + block_table: torch.Tensor, + softmax_scale: float, + is_causal_: bool, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, +) -> torch.Tensor: + return ops.mha_fwd_kvcache_mla( + q, + kcache, + vcache_, + head_size_v, + seqlens_k, + block_table, + softmax_scale, + is_causal_, + tile_scheduler_metadata, + num_splits + ) diff --git a/build/torch25-cxx11-cu121-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so b/build/torch25-cxx11-cu121-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..4576745bcbec270ebb59949a1aac7146e5754c46 --- /dev/null +++ b/build/torch25-cxx11-cu121-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:77e241f633fa5b103f379ba6ac58d2cc068e0c3fc4d4f20ac1e1c679fc19614f +size 2595176 diff --git a/build/torch25-cxx11-cu121-x86_64-linux/flash_mla/_ops.py b/build/torch25-cxx11-cu121-x86_64-linux/flash_mla/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5aa4c56e24b75711edaa5d90f25828a6eb2484 --- /dev/null +++ b/build/torch25-cxx11-cu121-x86_64-linux/flash_mla/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_d4f4195 +ops = torch.ops._flash_mla_d4f4195 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_d4f4195::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx11-cu124-x86_64-linux/flash_mla/__init__.py b/build/torch25-cxx11-cu124-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a129688bf920047cca1ca70a9c3521a9b4182 --- /dev/null +++ b/build/torch25-cxx11-cu124-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,33 @@ +import torch + +from ._ops import ops + + +def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int): + return ops.get_mla_metadata(seqlens_k, s_q, h_kv) + + +def mha_fwd_kvcache_mla( + q: torch.Tensor, + kcache: torch.Tensor, + vcache_: torch.Tensor, + head_size_v: int, + seqlens_k: torch.Tensor, + block_table: torch.Tensor, + softmax_scale: float, + is_causal_: bool, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, +) -> torch.Tensor: + return ops.mha_fwd_kvcache_mla( + q, + kcache, + vcache_, + head_size_v, + seqlens_k, + block_table, + softmax_scale, + is_causal_, + tile_scheduler_metadata, + num_splits + ) diff --git a/build/torch25-cxx11-cu124-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so b/build/torch25-cxx11-cu124-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..60646291e7927b00f5a921ae85e16102a115fb52 --- /dev/null +++ b/build/torch25-cxx11-cu124-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:31aa895a57efbd29aeff693b65b02842926bf1788d6f98022c32470a60265f9e +size 2580248 diff --git a/build/torch25-cxx11-cu124-x86_64-linux/flash_mla/_ops.py b/build/torch25-cxx11-cu124-x86_64-linux/flash_mla/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5aa4c56e24b75711edaa5d90f25828a6eb2484 --- /dev/null +++ b/build/torch25-cxx11-cu124-x86_64-linux/flash_mla/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_d4f4195 +ops = torch.ops._flash_mla_d4f4195 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_d4f4195::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx98-cu118-x86_64-linux/flash_mla/__init__.py b/build/torch25-cxx98-cu118-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a129688bf920047cca1ca70a9c3521a9b4182 --- /dev/null +++ b/build/torch25-cxx98-cu118-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,33 @@ +import torch + +from ._ops import ops + + +def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int): + return ops.get_mla_metadata(seqlens_k, s_q, h_kv) + + +def mha_fwd_kvcache_mla( + q: torch.Tensor, + kcache: torch.Tensor, + vcache_: torch.Tensor, + head_size_v: int, + seqlens_k: torch.Tensor, + block_table: torch.Tensor, + softmax_scale: float, + is_causal_: bool, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, +) -> torch.Tensor: + return ops.mha_fwd_kvcache_mla( + q, + kcache, + vcache_, + head_size_v, + seqlens_k, + block_table, + softmax_scale, + is_causal_, + tile_scheduler_metadata, + num_splits + ) diff --git a/build/torch25-cxx98-cu118-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so b/build/torch25-cxx98-cu118-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..c22b643f0c606f54dced21fa2116796a02f23198 --- /dev/null +++ b/build/torch25-cxx98-cu118-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7454c10a3b29128e035bdb3fa18d5fc3706f7970542a0bcb55d9714f0999d42f +size 2556792 diff --git a/build/torch25-cxx98-cu118-x86_64-linux/flash_mla/_ops.py b/build/torch25-cxx98-cu118-x86_64-linux/flash_mla/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5aa4c56e24b75711edaa5d90f25828a6eb2484 --- /dev/null +++ b/build/torch25-cxx98-cu118-x86_64-linux/flash_mla/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_d4f4195 +ops = torch.ops._flash_mla_d4f4195 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_d4f4195::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx98-cu121-x86_64-linux/flash_mla/__init__.py b/build/torch25-cxx98-cu121-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a129688bf920047cca1ca70a9c3521a9b4182 --- /dev/null +++ b/build/torch25-cxx98-cu121-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,33 @@ +import torch + +from ._ops import ops + + +def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int): + return ops.get_mla_metadata(seqlens_k, s_q, h_kv) + + +def mha_fwd_kvcache_mla( + q: torch.Tensor, + kcache: torch.Tensor, + vcache_: torch.Tensor, + head_size_v: int, + seqlens_k: torch.Tensor, + block_table: torch.Tensor, + softmax_scale: float, + is_causal_: bool, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, +) -> torch.Tensor: + return ops.mha_fwd_kvcache_mla( + q, + kcache, + vcache_, + head_size_v, + seqlens_k, + block_table, + softmax_scale, + is_causal_, + tile_scheduler_metadata, + num_splits + ) diff --git a/build/torch25-cxx98-cu121-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so b/build/torch25-cxx98-cu121-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..5fb00b76fa56c3200f744e14a11afa5a3090dd7b --- /dev/null +++ b/build/torch25-cxx98-cu121-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c8cb9402f3091420227cbccf1ec4938a444765e26f5d34c356c76bf7c85630d0 +size 2587896 diff --git a/build/torch25-cxx98-cu121-x86_64-linux/flash_mla/_ops.py b/build/torch25-cxx98-cu121-x86_64-linux/flash_mla/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5aa4c56e24b75711edaa5d90f25828a6eb2484 --- /dev/null +++ b/build/torch25-cxx98-cu121-x86_64-linux/flash_mla/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_d4f4195 +ops = torch.ops._flash_mla_d4f4195 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_d4f4195::{op_name}" \ No newline at end of file diff --git a/build/torch25-cxx98-cu124-x86_64-linux/flash_mla/__init__.py b/build/torch25-cxx98-cu124-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a129688bf920047cca1ca70a9c3521a9b4182 --- /dev/null +++ b/build/torch25-cxx98-cu124-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,33 @@ +import torch + +from ._ops import ops + + +def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int): + return ops.get_mla_metadata(seqlens_k, s_q, h_kv) + + +def mha_fwd_kvcache_mla( + q: torch.Tensor, + kcache: torch.Tensor, + vcache_: torch.Tensor, + head_size_v: int, + seqlens_k: torch.Tensor, + block_table: torch.Tensor, + softmax_scale: float, + is_causal_: bool, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, +) -> torch.Tensor: + return ops.mha_fwd_kvcache_mla( + q, + kcache, + vcache_, + head_size_v, + seqlens_k, + block_table, + softmax_scale, + is_causal_, + tile_scheduler_metadata, + num_splits + ) diff --git a/build/torch25-cxx98-cu124-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so b/build/torch25-cxx98-cu124-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..85a192d3f985a99cc8c984fabe3614ac2f44d9b2 --- /dev/null +++ b/build/torch25-cxx98-cu124-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bb4be09cbde1979c1aa17e3bc93c1538f129b438d305bee0fe96f3c08efeee04 +size 2572968 diff --git a/build/torch25-cxx98-cu124-x86_64-linux/flash_mla/_ops.py b/build/torch25-cxx98-cu124-x86_64-linux/flash_mla/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5aa4c56e24b75711edaa5d90f25828a6eb2484 --- /dev/null +++ b/build/torch25-cxx98-cu124-x86_64-linux/flash_mla/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_d4f4195 +ops = torch.ops._flash_mla_d4f4195 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_d4f4195::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_mla/__init__.py b/build/torch26-cxx11-cu118-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a129688bf920047cca1ca70a9c3521a9b4182 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,33 @@ +import torch + +from ._ops import ops + + +def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int): + return ops.get_mla_metadata(seqlens_k, s_q, h_kv) + + +def mha_fwd_kvcache_mla( + q: torch.Tensor, + kcache: torch.Tensor, + vcache_: torch.Tensor, + head_size_v: int, + seqlens_k: torch.Tensor, + block_table: torch.Tensor, + softmax_scale: float, + is_causal_: bool, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, +) -> torch.Tensor: + return ops.mha_fwd_kvcache_mla( + q, + kcache, + vcache_, + head_size_v, + seqlens_k, + block_table, + softmax_scale, + is_causal_, + tile_scheduler_metadata, + num_splits + ) diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so b/build/torch26-cxx11-cu118-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..513d74291a3c84f2b48b590a9fd6cec6c72f8f9e --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:35c67c788220d8988e47cd4ad976495450b71cd682bd8ab08af3db066d625126 +size 2564496 diff --git a/build/torch26-cxx11-cu118-x86_64-linux/flash_mla/_ops.py b/build/torch26-cxx11-cu118-x86_64-linux/flash_mla/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5aa4c56e24b75711edaa5d90f25828a6eb2484 --- /dev/null +++ b/build/torch26-cxx11-cu118-x86_64-linux/flash_mla/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_d4f4195 +ops = torch.ops._flash_mla_d4f4195 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_d4f4195::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_mla/__init__.py b/build/torch26-cxx11-cu124-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a129688bf920047cca1ca70a9c3521a9b4182 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,33 @@ +import torch + +from ._ops import ops + + +def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int): + return ops.get_mla_metadata(seqlens_k, s_q, h_kv) + + +def mha_fwd_kvcache_mla( + q: torch.Tensor, + kcache: torch.Tensor, + vcache_: torch.Tensor, + head_size_v: int, + seqlens_k: torch.Tensor, + block_table: torch.Tensor, + softmax_scale: float, + is_causal_: bool, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, +) -> torch.Tensor: + return ops.mha_fwd_kvcache_mla( + q, + kcache, + vcache_, + head_size_v, + seqlens_k, + block_table, + softmax_scale, + is_causal_, + tile_scheduler_metadata, + num_splits + ) diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so b/build/torch26-cxx11-cu124-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..c5820e98edcaaf6652865037166c27a77cb8cdca --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:069fb3e3a051c91e73390245c7463218829b8decf0f60bd6fc9a0ba8127b5bd2 +size 2580592 diff --git a/build/torch26-cxx11-cu124-x86_64-linux/flash_mla/_ops.py b/build/torch26-cxx11-cu124-x86_64-linux/flash_mla/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5aa4c56e24b75711edaa5d90f25828a6eb2484 --- /dev/null +++ b/build/torch26-cxx11-cu124-x86_64-linux/flash_mla/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_d4f4195 +ops = torch.ops._flash_mla_d4f4195 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_d4f4195::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu126-aarch64-linux/flash_mla/__init__.py b/build/torch26-cxx11-cu126-aarch64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a129688bf920047cca1ca70a9c3521a9b4182 --- /dev/null +++ b/build/torch26-cxx11-cu126-aarch64-linux/flash_mla/__init__.py @@ -0,0 +1,33 @@ +import torch + +from ._ops import ops + + +def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int): + return ops.get_mla_metadata(seqlens_k, s_q, h_kv) + + +def mha_fwd_kvcache_mla( + q: torch.Tensor, + kcache: torch.Tensor, + vcache_: torch.Tensor, + head_size_v: int, + seqlens_k: torch.Tensor, + block_table: torch.Tensor, + softmax_scale: float, + is_causal_: bool, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, +) -> torch.Tensor: + return ops.mha_fwd_kvcache_mla( + q, + kcache, + vcache_, + head_size_v, + seqlens_k, + block_table, + softmax_scale, + is_causal_, + tile_scheduler_metadata, + num_splits + ) diff --git a/build/torch26-cxx11-cu126-aarch64-linux/flash_mla/_flash_mla_341ab77.abi3.so b/build/torch26-cxx11-cu126-aarch64-linux/flash_mla/_flash_mla_341ab77.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..b3622a1486b5cc9b43381b072ead09c68986493b --- /dev/null +++ b/build/torch26-cxx11-cu126-aarch64-linux/flash_mla/_flash_mla_341ab77.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e1e97fef62f5ebbe6b19b0d5fbe700fcdf6b9acd7a54cba6f0b1d23665188fa9 +size 2643848 diff --git a/build/torch26-cxx11-cu126-aarch64-linux/flash_mla/_ops.py b/build/torch26-cxx11-cu126-aarch64-linux/flash_mla/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6ca4becc90e11facbc2ad156a8ef8bb23aeebed0 --- /dev/null +++ b/build/torch26-cxx11-cu126-aarch64-linux/flash_mla/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_341ab77 +ops = torch.ops._flash_mla_341ab77 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_341ab77::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_mla/__init__.py b/build/torch26-cxx11-cu126-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a129688bf920047cca1ca70a9c3521a9b4182 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,33 @@ +import torch + +from ._ops import ops + + +def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int): + return ops.get_mla_metadata(seqlens_k, s_q, h_kv) + + +def mha_fwd_kvcache_mla( + q: torch.Tensor, + kcache: torch.Tensor, + vcache_: torch.Tensor, + head_size_v: int, + seqlens_k: torch.Tensor, + block_table: torch.Tensor, + softmax_scale: float, + is_causal_: bool, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, +) -> torch.Tensor: + return ops.mha_fwd_kvcache_mla( + q, + kcache, + vcache_, + head_size_v, + seqlens_k, + block_table, + softmax_scale, + is_causal_, + tile_scheduler_metadata, + num_splits + ) diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so b/build/torch26-cxx11-cu126-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..48e2389e6245bf0a3a220f3dddf52de715b00564 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:50fad86fa7bc15096c2a1feadf8091b20e188e32b8c0633423ec26e4e8e8e7ce +size 2560552 diff --git a/build/torch26-cxx11-cu126-x86_64-linux/flash_mla/_ops.py b/build/torch26-cxx11-cu126-x86_64-linux/flash_mla/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5aa4c56e24b75711edaa5d90f25828a6eb2484 --- /dev/null +++ b/build/torch26-cxx11-cu126-x86_64-linux/flash_mla/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_d4f4195 +ops = torch.ops._flash_mla_d4f4195 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_d4f4195::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_mla/__init__.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a129688bf920047cca1ca70a9c3521a9b4182 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,33 @@ +import torch + +from ._ops import ops + + +def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int): + return ops.get_mla_metadata(seqlens_k, s_q, h_kv) + + +def mha_fwd_kvcache_mla( + q: torch.Tensor, + kcache: torch.Tensor, + vcache_: torch.Tensor, + head_size_v: int, + seqlens_k: torch.Tensor, + block_table: torch.Tensor, + softmax_scale: float, + is_causal_: bool, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, +) -> torch.Tensor: + return ops.mha_fwd_kvcache_mla( + q, + kcache, + vcache_, + head_size_v, + seqlens_k, + block_table, + softmax_scale, + is_causal_, + tile_scheduler_metadata, + num_splits + ) diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so b/build/torch26-cxx98-cu118-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..8f8b0677194aee3da40b16b1d4ccef0bcb1c6a75 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae937ddfbc3e6097b2fdd9197f2ddb5b9f66c65146a4de30ccab59dab6e18dd4 +size 2557136 diff --git a/build/torch26-cxx98-cu118-x86_64-linux/flash_mla/_ops.py b/build/torch26-cxx98-cu118-x86_64-linux/flash_mla/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5aa4c56e24b75711edaa5d90f25828a6eb2484 --- /dev/null +++ b/build/torch26-cxx98-cu118-x86_64-linux/flash_mla/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_d4f4195 +ops = torch.ops._flash_mla_d4f4195 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_d4f4195::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_mla/__init__.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a129688bf920047cca1ca70a9c3521a9b4182 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,33 @@ +import torch + +from ._ops import ops + + +def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int): + return ops.get_mla_metadata(seqlens_k, s_q, h_kv) + + +def mha_fwd_kvcache_mla( + q: torch.Tensor, + kcache: torch.Tensor, + vcache_: torch.Tensor, + head_size_v: int, + seqlens_k: torch.Tensor, + block_table: torch.Tensor, + softmax_scale: float, + is_causal_: bool, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, +) -> torch.Tensor: + return ops.mha_fwd_kvcache_mla( + q, + kcache, + vcache_, + head_size_v, + seqlens_k, + block_table, + softmax_scale, + is_causal_, + tile_scheduler_metadata, + num_splits + ) diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so b/build/torch26-cxx98-cu124-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..e086d51f613f85e3cdeb42100e4350be2a181b28 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:503910324475f8bd9dab47687339005f58e5b623bf0c9e4234fabf099c08da33 +size 2573312 diff --git a/build/torch26-cxx98-cu124-x86_64-linux/flash_mla/_ops.py b/build/torch26-cxx98-cu124-x86_64-linux/flash_mla/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5aa4c56e24b75711edaa5d90f25828a6eb2484 --- /dev/null +++ b/build/torch26-cxx98-cu124-x86_64-linux/flash_mla/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_d4f4195 +ops = torch.ops._flash_mla_d4f4195 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_d4f4195::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu126-aarch64-linux/flash_mla/__init__.py b/build/torch26-cxx98-cu126-aarch64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a129688bf920047cca1ca70a9c3521a9b4182 --- /dev/null +++ b/build/torch26-cxx98-cu126-aarch64-linux/flash_mla/__init__.py @@ -0,0 +1,33 @@ +import torch + +from ._ops import ops + + +def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int): + return ops.get_mla_metadata(seqlens_k, s_q, h_kv) + + +def mha_fwd_kvcache_mla( + q: torch.Tensor, + kcache: torch.Tensor, + vcache_: torch.Tensor, + head_size_v: int, + seqlens_k: torch.Tensor, + block_table: torch.Tensor, + softmax_scale: float, + is_causal_: bool, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, +) -> torch.Tensor: + return ops.mha_fwd_kvcache_mla( + q, + kcache, + vcache_, + head_size_v, + seqlens_k, + block_table, + softmax_scale, + is_causal_, + tile_scheduler_metadata, + num_splits + ) diff --git a/build/torch26-cxx98-cu126-aarch64-linux/flash_mla/_flash_mla_341ab77.abi3.so b/build/torch26-cxx98-cu126-aarch64-linux/flash_mla/_flash_mla_341ab77.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..fe8ba1fea4456f4132b5245ca04ffae76c2d43f4 --- /dev/null +++ b/build/torch26-cxx98-cu126-aarch64-linux/flash_mla/_flash_mla_341ab77.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6f15b3b0bd0bee56760bd6500175ca5a1fd17f2742ef9496c28ea3720d038c66 +size 2640208 diff --git a/build/torch26-cxx98-cu126-aarch64-linux/flash_mla/_ops.py b/build/torch26-cxx98-cu126-aarch64-linux/flash_mla/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6ca4becc90e11facbc2ad156a8ef8bb23aeebed0 --- /dev/null +++ b/build/torch26-cxx98-cu126-aarch64-linux/flash_mla/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_341ab77 +ops = torch.ops._flash_mla_341ab77 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_341ab77::{op_name}" \ No newline at end of file diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_mla/__init__.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a129688bf920047cca1ca70a9c3521a9b4182 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,33 @@ +import torch + +from ._ops import ops + + +def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int): + return ops.get_mla_metadata(seqlens_k, s_q, h_kv) + + +def mha_fwd_kvcache_mla( + q: torch.Tensor, + kcache: torch.Tensor, + vcache_: torch.Tensor, + head_size_v: int, + seqlens_k: torch.Tensor, + block_table: torch.Tensor, + softmax_scale: float, + is_causal_: bool, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, +) -> torch.Tensor: + return ops.mha_fwd_kvcache_mla( + q, + kcache, + vcache_, + head_size_v, + seqlens_k, + block_table, + softmax_scale, + is_causal_, + tile_scheduler_metadata, + num_splits + ) diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so b/build/torch26-cxx98-cu126-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..2339738fe7aa81421ecdc9e619b68c1b5a2db07c --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2c41fa4058ee2bb5d3d90458a7f92f0ef1c10e8bc854329cf7c208025bb244b2 +size 2553280 diff --git a/build/torch26-cxx98-cu126-x86_64-linux/flash_mla/_ops.py b/build/torch26-cxx98-cu126-x86_64-linux/flash_mla/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5aa4c56e24b75711edaa5d90f25828a6eb2484 --- /dev/null +++ b/build/torch26-cxx98-cu126-x86_64-linux/flash_mla/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_d4f4195 +ops = torch.ops._flash_mla_d4f4195 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_d4f4195::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu118-x86_64-linux/flash_mla/__init__.py b/build/torch27-cxx11-cu118-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a129688bf920047cca1ca70a9c3521a9b4182 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,33 @@ +import torch + +from ._ops import ops + + +def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int): + return ops.get_mla_metadata(seqlens_k, s_q, h_kv) + + +def mha_fwd_kvcache_mla( + q: torch.Tensor, + kcache: torch.Tensor, + vcache_: torch.Tensor, + head_size_v: int, + seqlens_k: torch.Tensor, + block_table: torch.Tensor, + softmax_scale: float, + is_causal_: bool, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, +) -> torch.Tensor: + return ops.mha_fwd_kvcache_mla( + q, + kcache, + vcache_, + head_size_v, + seqlens_k, + block_table, + softmax_scale, + is_causal_, + tile_scheduler_metadata, + num_splits + ) diff --git a/build/torch27-cxx11-cu118-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..518b48e7aaa58c369e3c721c64c4f2e5c7a88035 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:59c4034880f4482b06e447a2c4810aaf8009b7d4c86a4fd71356f169df986535 +size 2564632 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/flash_mla/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/flash_mla/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5aa4c56e24b75711edaa5d90f25828a6eb2484 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/flash_mla/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_d4f4195 +ops = torch.ops._flash_mla_d4f4195 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_d4f4195::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-aarch64-linux/flash_mla/__init__.py b/build/torch27-cxx11-cu126-aarch64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a129688bf920047cca1ca70a9c3521a9b4182 --- /dev/null +++ b/build/torch27-cxx11-cu126-aarch64-linux/flash_mla/__init__.py @@ -0,0 +1,33 @@ +import torch + +from ._ops import ops + + +def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int): + return ops.get_mla_metadata(seqlens_k, s_q, h_kv) + + +def mha_fwd_kvcache_mla( + q: torch.Tensor, + kcache: torch.Tensor, + vcache_: torch.Tensor, + head_size_v: int, + seqlens_k: torch.Tensor, + block_table: torch.Tensor, + softmax_scale: float, + is_causal_: bool, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, +) -> torch.Tensor: + return ops.mha_fwd_kvcache_mla( + q, + kcache, + vcache_, + head_size_v, + seqlens_k, + block_table, + softmax_scale, + is_causal_, + tile_scheduler_metadata, + num_splits + ) diff --git a/build/torch27-cxx11-cu126-aarch64-linux/flash_mla/_flash_mla_341ab77.abi3.so b/build/torch27-cxx11-cu126-aarch64-linux/flash_mla/_flash_mla_341ab77.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..5770b43c7c5f6d6b5beff9fff23624279651b0aa --- /dev/null +++ b/build/torch27-cxx11-cu126-aarch64-linux/flash_mla/_flash_mla_341ab77.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fb925b062d31034672a45d925a3767d953e97a3c6c483467e6b81833d42b5a27 +size 2644048 diff --git a/build/torch27-cxx11-cu126-aarch64-linux/flash_mla/_ops.py b/build/torch27-cxx11-cu126-aarch64-linux/flash_mla/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6ca4becc90e11facbc2ad156a8ef8bb23aeebed0 --- /dev/null +++ b/build/torch27-cxx11-cu126-aarch64-linux/flash_mla/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_341ab77 +ops = torch.ops._flash_mla_341ab77 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_341ab77::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/flash_mla/__init__.py b/build/torch27-cxx11-cu126-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a129688bf920047cca1ca70a9c3521a9b4182 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,33 @@ +import torch + +from ._ops import ops + + +def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int): + return ops.get_mla_metadata(seqlens_k, s_q, h_kv) + + +def mha_fwd_kvcache_mla( + q: torch.Tensor, + kcache: torch.Tensor, + vcache_: torch.Tensor, + head_size_v: int, + seqlens_k: torch.Tensor, + block_table: torch.Tensor, + softmax_scale: float, + is_causal_: bool, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, +) -> torch.Tensor: + return ops.mha_fwd_kvcache_mla( + q, + kcache, + vcache_, + head_size_v, + seqlens_k, + block_table, + softmax_scale, + is_causal_, + tile_scheduler_metadata, + num_splits + ) diff --git a/build/torch27-cxx11-cu126-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..0b5c36b5a1b53830df46e7d3081382621e484bdc --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5db69ef4975e2eee001e6a9b7466c1fe40bc2228ed64eb8c24caf3e0fb6ed0b2 +size 2560584 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/flash_mla/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/flash_mla/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5aa4c56e24b75711edaa5d90f25828a6eb2484 --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/flash_mla/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_d4f4195 +ops = torch.ops._flash_mla_d4f4195 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_d4f4195::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-aarch64-linux/flash_mla/__init__.py b/build/torch27-cxx11-cu128-aarch64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a129688bf920047cca1ca70a9c3521a9b4182 --- /dev/null +++ b/build/torch27-cxx11-cu128-aarch64-linux/flash_mla/__init__.py @@ -0,0 +1,33 @@ +import torch + +from ._ops import ops + + +def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int): + return ops.get_mla_metadata(seqlens_k, s_q, h_kv) + + +def mha_fwd_kvcache_mla( + q: torch.Tensor, + kcache: torch.Tensor, + vcache_: torch.Tensor, + head_size_v: int, + seqlens_k: torch.Tensor, + block_table: torch.Tensor, + softmax_scale: float, + is_causal_: bool, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, +) -> torch.Tensor: + return ops.mha_fwd_kvcache_mla( + q, + kcache, + vcache_, + head_size_v, + seqlens_k, + block_table, + softmax_scale, + is_causal_, + tile_scheduler_metadata, + num_splits + ) diff --git a/build/torch27-cxx11-cu128-aarch64-linux/flash_mla/_flash_mla_341ab77.abi3.so b/build/torch27-cxx11-cu128-aarch64-linux/flash_mla/_flash_mla_341ab77.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..12e46dd52eaa72ac67e34c8f60333c9a9d111c80 --- /dev/null +++ b/build/torch27-cxx11-cu128-aarch64-linux/flash_mla/_flash_mla_341ab77.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7776c629263bc0b32b82b8a094ead0749d6c393b6ca25c9ffa812bd8fbdb3002 +size 2709472 diff --git a/build/torch27-cxx11-cu128-aarch64-linux/flash_mla/_ops.py b/build/torch27-cxx11-cu128-aarch64-linux/flash_mla/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6ca4becc90e11facbc2ad156a8ef8bb23aeebed0 --- /dev/null +++ b/build/torch27-cxx11-cu128-aarch64-linux/flash_mla/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_341ab77 +ops = torch.ops._flash_mla_341ab77 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_341ab77::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/flash_mla/__init__.py b/build/torch27-cxx11-cu128-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a129688bf920047cca1ca70a9c3521a9b4182 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,33 @@ +import torch + +from ._ops import ops + + +def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int): + return ops.get_mla_metadata(seqlens_k, s_q, h_kv) + + +def mha_fwd_kvcache_mla( + q: torch.Tensor, + kcache: torch.Tensor, + vcache_: torch.Tensor, + head_size_v: int, + seqlens_k: torch.Tensor, + block_table: torch.Tensor, + softmax_scale: float, + is_causal_: bool, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, +) -> torch.Tensor: + return ops.mha_fwd_kvcache_mla( + q, + kcache, + vcache_, + head_size_v, + seqlens_k, + block_table, + softmax_scale, + is_causal_, + tile_scheduler_metadata, + num_splits + ) diff --git a/build/torch27-cxx11-cu128-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..7769a2d82af13e218c14d783100f0c9e36090cbc --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/flash_mla/_flash_mla_d4f4195.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3fc7eb9341c975d0e313d837977ca3ed13556e6fe63926e0bf117f62499ea052 +size 2615448 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/flash_mla/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/flash_mla/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..6e5aa4c56e24b75711edaa5d90f25828a6eb2484 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/flash_mla/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_d4f4195 +ops = torch.ops._flash_mla_d4f4195 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_d4f4195::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-aarch64-linux/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db300fe9b95176a20b27b3641d89be657d0c4319 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/__init__.py @@ -0,0 +1,163 @@ +from typing import Optional, Tuple +import torch + +from .flash_mla_interface import FlashMLASchedMeta +from . import flash_mla_interface as _impl + + +def get_mla_metadata(*args, **kwargs) -> Tuple[FlashMLASchedMeta, None]: + return _impl.get_mla_metadata(*args, **kwargs) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_mla_with_kvcache( + q=q, + k_cache=k_cache, + block_table=block_table, + cache_seqlens=cache_seqlens, + head_dim_v=head_dim_v, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + softmax_scale=softmax_scale, + causal=causal, + is_fp8_kvcache=is_fp8_kvcache, + indices=indices, + attn_sink=attn_sink, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices_in_kvcache, + topk_length=topk_length, + extra_topk_length=extra_topk_length, + ) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return _impl.flash_mla_sparse_fwd( + q=q, + kv=kv, + indices=indices, + sm_scale=sm_scale, + d_v=d_v, + attn_sink=attn_sink, + topk_length=topk_length, + ) + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_qkvpacked_func( + qkv=qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +__all__ = [ + "__version__", + "FlashMLASchedMeta", + "get_mla_metadata", + "flash_mla_with_kvcache", + "flash_attn_varlen_func", + "flash_attn_varlen_qkvpacked_func", + "flash_attn_varlen_kvpacked_func", + "flash_mla_sparse_fwd", +] diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_flash_mla_cuda_09f70ef.abi3.so b/build/torch29-cxx11-cu128-aarch64-linux/_flash_mla_cuda_09f70ef.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..8d043d5d2fc760bfae5abe1b6cbd8b9887e0aecd --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/_flash_mla_cuda_09f70ef.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f3f3b9f82de962911a1b8467a3130b34ea10e0b8e7db32432bb42ac24862e52e +size 3667784 diff --git a/build/torch29-cxx11-cu128-aarch64-linux/_ops.py b/build/torch29-cxx11-cu128-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ea7ed02f7680582f28bdb0d1e552de1dc177f7c5 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_cuda_09f70ef +ops = torch.ops._flash_mla_cuda_09f70ef + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_cuda_09f70ef::{op_name}" diff --git a/build/torch29-cxx11-cu128-aarch64-linux/flash_mla/__init__.py b/build/torch29-cxx11-cu128-aarch64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/flash_mla/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu128-aarch64-linux/flash_mla_interface.py b/build/torch29-cxx11-cu128-aarch64-linux/flash_mla_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..a84e448ffe741bb6d3dafaf7888ed8cc94984467 --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/flash_mla_interface.py @@ -0,0 +1,435 @@ +from typing import Optional, Tuple +import dataclasses + +import torch + +from ._ops import ops as flash_mla_cuda + +@dataclasses.dataclass +class FlashMLASchedMeta: + """ + A class that stores the tile scheduler metadata of FlashMLA + """ + + @dataclasses.dataclass + class Config: + b: int + s_q: int + h_q: int + page_block_size: int + h_k: int + + causal: bool + is_fp8_kvcache: bool + topk: Optional[int] + + extra_page_block_size: Optional[int] + extra_topk: Optional[int] + + have_initialized: bool = False + + config: Optional[Config] = None + + tile_scheduler_metadata: Optional[torch.Tensor] = None # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: Optional[torch.Tensor] = None # (1), dtype torch.int32. + + +def get_mla_metadata( + *args, + **kwargs +) -> Tuple[FlashMLASchedMeta, None]: + """ + Returns an empty instance of FlashMLASchedMeta. The actual scheduling metadata will be generated during the first invocation of flash_mla_with_kvcache. + + Arguments: + This function does not need any arguments, but we keep *args and **kwargs to be compatible with the old interface. + + Return: + A tuple. Due to historical reasons, we return a tuple of (FlashMLASchedMeta, None) now. Only the first element is useful. + """ + return FlashMLASchedMeta(), None + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details. + The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks. + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used. + cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. + head_dim_v: Head_dim of v. Must be 512 + sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same. + num_splits_placeholder: must be "None" (to be compatible with the old interface). + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k). + causal: bool. Whether to apply causal attention mask. Only valid for dense attention + is_fp8_kvcache: bool. + indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled. + Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block), + where t is the k-th token of the j-th q-sequence in the i-th batch. + attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0. + extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. Their format requirements are the same as k_cache and indices_in_kvcache respectively. + topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking. + + For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2: + head_dim should be 576 while head_dim_v should be 512. + In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as: + - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1. + - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values. + - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on. + - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + sched_meta = tile_scheduler_metadata + indices_in_kvcache = indices + assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta" + assert num_splits is None, "num_splits must be None" + + topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None + extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None + extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + if not sched_meta.have_initialized: + # Sanity check. We only perform sanity check during the first invocation to save CPU time. + if indices_in_kvcache is not None: + assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)" + + # Initialize the tile scheduler metadata during the first invocation. + sched_meta.have_initialized = True + sched_meta.config = FlashMLASchedMeta.Config( + q.shape[0], + q.shape[1], + q.shape[2], + k_cache.shape[1], + k_cache.shape[2], + + causal, + is_fp8_kvcache, + topk, + + extra_k_page_block_size, + extra_topk, + ) + else: + # Check whether the input arguments are consistent with sched_meta + helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta." + assert sched_meta.config is not None + assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg + assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg + assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg + assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg + assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg + assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg + assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg + assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg + assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg + assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg + + if topk is not None: + # Sparse attention + assert not causal, "causal must be False when sparse attention is enabled" + assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled" + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd( + q, k_cache, indices_in_kvcache, topk_length, attn_sink, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits, + extra_k_cache, extra_indices_in_kvcache, extra_topk_length, + head_dim_v, softmax_scale + ) + else: + # Dense attention + assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used." + assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used." + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd( + q, k_cache, head_dim_v, + cache_seqlens, block_table, + softmax_scale, causal, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits + ) + sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata + sched_meta.num_splits = new_num_splits + return (out, lse) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. Can only be 512 + attn_sink: optional, [h_q], float32. + If attn_sink is provided, when computing output, output will be additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)). + +-inf in attn_sink will be handled normally (i.e., -inf has no effect, +inf will make corresponding output all zeros). + This argument has no effect on lse and max_logits. + topk_length: optional, [s_q], int32. If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if provided in indices). + In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), operator output will contain NaN, so please avoid this situation. + + Returns: + (output, max_logits, lse) + Please refer to tests/ref.py for the precise definitions of these parameters. + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, log-sum-exp of attention scores + """ + results = flash_mla_cuda.sparse_prefill_fwd( + q, kv, indices, sm_scale, d_v, attn_sink, topk_length + ) + return results + + +def _flash_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if out is None: + out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype) + if lse is None: + # Make lse contiguous on seqlen dim + lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_fwd( + workspace_buffer, + q, + k, + v, + cu_seqlens_qo, + cu_seqlens_kv, + out, + lse, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return out, lse + + +def _flash_attn_varlen_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + # TODO: fix bwd GQA + if num_qo_heads != num_kv_heads: + raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.") + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if dq is None: + dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dk is None: + dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dv is None: + dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype) + + max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 + bs = cu_seqlens_qo.shape[0] - 1 + workspace_bytes = 0 + workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc + workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse + if num_qo_heads != num_kv_heads: + workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc + workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_bwd( + workspace_buffer, + do, + q, + k, + v, + out, + lse, + cu_seqlens_qo, + cu_seqlens_kv, + dq, + dk, + dv, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return dq, dk, dv + + +class FlashAttnVarlenFunc(torch.autograd.Function): + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + out, lse = _flash_attn_varlen_forward( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal=causal, softmax_scale=softmax_scale, + is_varlen=is_varlen, + ) + ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv) + ctx.max_seqlen_qo = max_seqlen_qo + ctx.max_seqlen_kv = max_seqlen_kv + ctx.causal = causal + ctx.softmax_scale = softmax_scale + ctx.is_varlen = is_varlen + return out, lse + + def backward( + ctx, + do: torch.Tensor, + dlse: torch.Tensor, + ): + del dlse # LSE doesn't support backward currently + q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors + dq, dk, dv = _flash_attn_varlen_backward( + do, q, k, v, out, lse, + cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv, + causal=ctx.causal, softmax_scale=ctx.softmax_scale, + is_varlen=ctx.is_varlen, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:], + cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:], + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) diff --git a/build/torch29-cxx11-cu128-aarch64-linux/metadata.json b/build/torch29-cxx11-cu128-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..4899badb63d45293425e2164944268b6058af95d --- /dev/null +++ b/build/torch29-cxx11-cu128-aarch64-linux/metadata.json @@ -0,0 +1,11 @@ +{ + "version": 1, + "license": "MIT", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "9.0a" + ] + } +} diff --git a/build/torch29-cxx11-cu128-x86_64-linux/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db300fe9b95176a20b27b3641d89be657d0c4319 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/__init__.py @@ -0,0 +1,163 @@ +from typing import Optional, Tuple +import torch + +from .flash_mla_interface import FlashMLASchedMeta +from . import flash_mla_interface as _impl + + +def get_mla_metadata(*args, **kwargs) -> Tuple[FlashMLASchedMeta, None]: + return _impl.get_mla_metadata(*args, **kwargs) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_mla_with_kvcache( + q=q, + k_cache=k_cache, + block_table=block_table, + cache_seqlens=cache_seqlens, + head_dim_v=head_dim_v, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + softmax_scale=softmax_scale, + causal=causal, + is_fp8_kvcache=is_fp8_kvcache, + indices=indices, + attn_sink=attn_sink, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices_in_kvcache, + topk_length=topk_length, + extra_topk_length=extra_topk_length, + ) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return _impl.flash_mla_sparse_fwd( + q=q, + kv=kv, + indices=indices, + sm_scale=sm_scale, + d_v=d_v, + attn_sink=attn_sink, + topk_length=topk_length, + ) + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_qkvpacked_func( + qkv=qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +__all__ = [ + "__version__", + "FlashMLASchedMeta", + "get_mla_metadata", + "flash_mla_with_kvcache", + "flash_attn_varlen_func", + "flash_attn_varlen_qkvpacked_func", + "flash_attn_varlen_kvpacked_func", + "flash_mla_sparse_fwd", +] diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_flash_mla_cuda_09f70ef.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/_flash_mla_cuda_09f70ef.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..17cee4c7563a933314a32d9c5e074a58f2443baf --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/_flash_mla_cuda_09f70ef.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:22efc944dccf34d020a87fe713033bfae026d91eda6ecc86a4d491abe38edc51 +size 3523096 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ea7ed02f7680582f28bdb0d1e552de1dc177f7c5 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_cuda_09f70ef +ops = torch.ops._flash_mla_cuda_09f70ef + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_cuda_09f70ef::{op_name}" diff --git a/build/torch29-cxx11-cu128-x86_64-linux/flash_mla/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/flash_mla_interface.py b/build/torch29-cxx11-cu128-x86_64-linux/flash_mla_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..a84e448ffe741bb6d3dafaf7888ed8cc94984467 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/flash_mla_interface.py @@ -0,0 +1,435 @@ +from typing import Optional, Tuple +import dataclasses + +import torch + +from ._ops import ops as flash_mla_cuda + +@dataclasses.dataclass +class FlashMLASchedMeta: + """ + A class that stores the tile scheduler metadata of FlashMLA + """ + + @dataclasses.dataclass + class Config: + b: int + s_q: int + h_q: int + page_block_size: int + h_k: int + + causal: bool + is_fp8_kvcache: bool + topk: Optional[int] + + extra_page_block_size: Optional[int] + extra_topk: Optional[int] + + have_initialized: bool = False + + config: Optional[Config] = None + + tile_scheduler_metadata: Optional[torch.Tensor] = None # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: Optional[torch.Tensor] = None # (1), dtype torch.int32. + + +def get_mla_metadata( + *args, + **kwargs +) -> Tuple[FlashMLASchedMeta, None]: + """ + Returns an empty instance of FlashMLASchedMeta. The actual scheduling metadata will be generated during the first invocation of flash_mla_with_kvcache. + + Arguments: + This function does not need any arguments, but we keep *args and **kwargs to be compatible with the old interface. + + Return: + A tuple. Due to historical reasons, we return a tuple of (FlashMLASchedMeta, None) now. Only the first element is useful. + """ + return FlashMLASchedMeta(), None + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details. + The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks. + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used. + cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. + head_dim_v: Head_dim of v. Must be 512 + sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same. + num_splits_placeholder: must be "None" (to be compatible with the old interface). + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k). + causal: bool. Whether to apply causal attention mask. Only valid for dense attention + is_fp8_kvcache: bool. + indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled. + Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block), + where t is the k-th token of the j-th q-sequence in the i-th batch. + attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0. + extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. Their format requirements are the same as k_cache and indices_in_kvcache respectively. + topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking. + + For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2: + head_dim should be 576 while head_dim_v should be 512. + In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as: + - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1. + - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values. + - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on. + - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + sched_meta = tile_scheduler_metadata + indices_in_kvcache = indices + assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta" + assert num_splits is None, "num_splits must be None" + + topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None + extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None + extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + if not sched_meta.have_initialized: + # Sanity check. We only perform sanity check during the first invocation to save CPU time. + if indices_in_kvcache is not None: + assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)" + + # Initialize the tile scheduler metadata during the first invocation. + sched_meta.have_initialized = True + sched_meta.config = FlashMLASchedMeta.Config( + q.shape[0], + q.shape[1], + q.shape[2], + k_cache.shape[1], + k_cache.shape[2], + + causal, + is_fp8_kvcache, + topk, + + extra_k_page_block_size, + extra_topk, + ) + else: + # Check whether the input arguments are consistent with sched_meta + helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta." + assert sched_meta.config is not None + assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg + assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg + assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg + assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg + assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg + assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg + assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg + assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg + assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg + assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg + + if topk is not None: + # Sparse attention + assert not causal, "causal must be False when sparse attention is enabled" + assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled" + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd( + q, k_cache, indices_in_kvcache, topk_length, attn_sink, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits, + extra_k_cache, extra_indices_in_kvcache, extra_topk_length, + head_dim_v, softmax_scale + ) + else: + # Dense attention + assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used." + assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used." + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd( + q, k_cache, head_dim_v, + cache_seqlens, block_table, + softmax_scale, causal, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits + ) + sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata + sched_meta.num_splits = new_num_splits + return (out, lse) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. Can only be 512 + attn_sink: optional, [h_q], float32. + If attn_sink is provided, when computing output, output will be additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)). + +-inf in attn_sink will be handled normally (i.e., -inf has no effect, +inf will make corresponding output all zeros). + This argument has no effect on lse and max_logits. + topk_length: optional, [s_q], int32. If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if provided in indices). + In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), operator output will contain NaN, so please avoid this situation. + + Returns: + (output, max_logits, lse) + Please refer to tests/ref.py for the precise definitions of these parameters. + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, log-sum-exp of attention scores + """ + results = flash_mla_cuda.sparse_prefill_fwd( + q, kv, indices, sm_scale, d_v, attn_sink, topk_length + ) + return results + + +def _flash_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if out is None: + out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype) + if lse is None: + # Make lse contiguous on seqlen dim + lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_fwd( + workspace_buffer, + q, + k, + v, + cu_seqlens_qo, + cu_seqlens_kv, + out, + lse, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return out, lse + + +def _flash_attn_varlen_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + # TODO: fix bwd GQA + if num_qo_heads != num_kv_heads: + raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.") + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if dq is None: + dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dk is None: + dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dv is None: + dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype) + + max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 + bs = cu_seqlens_qo.shape[0] - 1 + workspace_bytes = 0 + workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc + workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse + if num_qo_heads != num_kv_heads: + workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc + workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_bwd( + workspace_buffer, + do, + q, + k, + v, + out, + lse, + cu_seqlens_qo, + cu_seqlens_kv, + dq, + dk, + dv, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return dq, dk, dv + + +class FlashAttnVarlenFunc(torch.autograd.Function): + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + out, lse = _flash_attn_varlen_forward( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal=causal, softmax_scale=softmax_scale, + is_varlen=is_varlen, + ) + ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv) + ctx.max_seqlen_qo = max_seqlen_qo + ctx.max_seqlen_kv = max_seqlen_kv + ctx.causal = causal + ctx.softmax_scale = softmax_scale + ctx.is_varlen = is_varlen + return out, lse + + def backward( + ctx, + do: torch.Tensor, + dlse: torch.Tensor, + ): + del dlse # LSE doesn't support backward currently + q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors + dq, dk, dv = _flash_attn_varlen_backward( + do, q, k, v, out, lse, + cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv, + causal=ctx.causal, softmax_scale=ctx.softmax_scale, + is_varlen=ctx.is_varlen, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:], + cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:], + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/metadata.json b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..4899badb63d45293425e2164944268b6058af95d --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/metadata.json @@ -0,0 +1,11 @@ +{ + "version": 1, + "license": "MIT", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "9.0a" + ] + } +} diff --git a/build/torch29-cxx11-cu129-aarch64-linux/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db300fe9b95176a20b27b3641d89be657d0c4319 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/__init__.py @@ -0,0 +1,163 @@ +from typing import Optional, Tuple +import torch + +from .flash_mla_interface import FlashMLASchedMeta +from . import flash_mla_interface as _impl + + +def get_mla_metadata(*args, **kwargs) -> Tuple[FlashMLASchedMeta, None]: + return _impl.get_mla_metadata(*args, **kwargs) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_mla_with_kvcache( + q=q, + k_cache=k_cache, + block_table=block_table, + cache_seqlens=cache_seqlens, + head_dim_v=head_dim_v, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + softmax_scale=softmax_scale, + causal=causal, + is_fp8_kvcache=is_fp8_kvcache, + indices=indices, + attn_sink=attn_sink, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices_in_kvcache, + topk_length=topk_length, + extra_topk_length=extra_topk_length, + ) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return _impl.flash_mla_sparse_fwd( + q=q, + kv=kv, + indices=indices, + sm_scale=sm_scale, + d_v=d_v, + attn_sink=attn_sink, + topk_length=topk_length, + ) + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_qkvpacked_func( + qkv=qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +__all__ = [ + "__version__", + "FlashMLASchedMeta", + "get_mla_metadata", + "flash_mla_with_kvcache", + "flash_attn_varlen_func", + "flash_attn_varlen_qkvpacked_func", + "flash_attn_varlen_kvpacked_func", + "flash_mla_sparse_fwd", +] diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_flash_mla_cuda_89d7fc1.abi3.so b/build/torch29-cxx11-cu129-aarch64-linux/_flash_mla_cuda_89d7fc1.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..36cdedde23ec83979438ec79aaf0bb2de356a47d --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/_flash_mla_cuda_89d7fc1.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:18ec10cef6a11ac67f4fe62df9bd02e015a9355c493423c479d22acc80428e99 +size 9363528 diff --git a/build/torch29-cxx11-cu129-aarch64-linux/_ops.py b/build/torch29-cxx11-cu129-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd5978e1b450acfa19dfe796dceea8fe35f5736 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_cuda_89d7fc1 +ops = torch.ops._flash_mla_cuda_89d7fc1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_cuda_89d7fc1::{op_name}" diff --git a/build/torch29-cxx11-cu129-aarch64-linux/flash_mla/__init__.py b/build/torch29-cxx11-cu129-aarch64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/flash_mla/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/flash_mla_interface.py b/build/torch29-cxx11-cu129-aarch64-linux/flash_mla_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..a84e448ffe741bb6d3dafaf7888ed8cc94984467 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/flash_mla_interface.py @@ -0,0 +1,435 @@ +from typing import Optional, Tuple +import dataclasses + +import torch + +from ._ops import ops as flash_mla_cuda + +@dataclasses.dataclass +class FlashMLASchedMeta: + """ + A class that stores the tile scheduler metadata of FlashMLA + """ + + @dataclasses.dataclass + class Config: + b: int + s_q: int + h_q: int + page_block_size: int + h_k: int + + causal: bool + is_fp8_kvcache: bool + topk: Optional[int] + + extra_page_block_size: Optional[int] + extra_topk: Optional[int] + + have_initialized: bool = False + + config: Optional[Config] = None + + tile_scheduler_metadata: Optional[torch.Tensor] = None # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: Optional[torch.Tensor] = None # (1), dtype torch.int32. + + +def get_mla_metadata( + *args, + **kwargs +) -> Tuple[FlashMLASchedMeta, None]: + """ + Returns an empty instance of FlashMLASchedMeta. The actual scheduling metadata will be generated during the first invocation of flash_mla_with_kvcache. + + Arguments: + This function does not need any arguments, but we keep *args and **kwargs to be compatible with the old interface. + + Return: + A tuple. Due to historical reasons, we return a tuple of (FlashMLASchedMeta, None) now. Only the first element is useful. + """ + return FlashMLASchedMeta(), None + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details. + The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks. + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used. + cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. + head_dim_v: Head_dim of v. Must be 512 + sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same. + num_splits_placeholder: must be "None" (to be compatible with the old interface). + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k). + causal: bool. Whether to apply causal attention mask. Only valid for dense attention + is_fp8_kvcache: bool. + indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled. + Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block), + where t is the k-th token of the j-th q-sequence in the i-th batch. + attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0. + extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. Their format requirements are the same as k_cache and indices_in_kvcache respectively. + topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking. + + For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2: + head_dim should be 576 while head_dim_v should be 512. + In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as: + - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1. + - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values. + - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on. + - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + sched_meta = tile_scheduler_metadata + indices_in_kvcache = indices + assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta" + assert num_splits is None, "num_splits must be None" + + topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None + extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None + extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + if not sched_meta.have_initialized: + # Sanity check. We only perform sanity check during the first invocation to save CPU time. + if indices_in_kvcache is not None: + assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)" + + # Initialize the tile scheduler metadata during the first invocation. + sched_meta.have_initialized = True + sched_meta.config = FlashMLASchedMeta.Config( + q.shape[0], + q.shape[1], + q.shape[2], + k_cache.shape[1], + k_cache.shape[2], + + causal, + is_fp8_kvcache, + topk, + + extra_k_page_block_size, + extra_topk, + ) + else: + # Check whether the input arguments are consistent with sched_meta + helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta." + assert sched_meta.config is not None + assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg + assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg + assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg + assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg + assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg + assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg + assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg + assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg + assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg + assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg + + if topk is not None: + # Sparse attention + assert not causal, "causal must be False when sparse attention is enabled" + assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled" + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd( + q, k_cache, indices_in_kvcache, topk_length, attn_sink, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits, + extra_k_cache, extra_indices_in_kvcache, extra_topk_length, + head_dim_v, softmax_scale + ) + else: + # Dense attention + assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used." + assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used." + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd( + q, k_cache, head_dim_v, + cache_seqlens, block_table, + softmax_scale, causal, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits + ) + sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata + sched_meta.num_splits = new_num_splits + return (out, lse) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. Can only be 512 + attn_sink: optional, [h_q], float32. + If attn_sink is provided, when computing output, output will be additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)). + +-inf in attn_sink will be handled normally (i.e., -inf has no effect, +inf will make corresponding output all zeros). + This argument has no effect on lse and max_logits. + topk_length: optional, [s_q], int32. If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if provided in indices). + In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), operator output will contain NaN, so please avoid this situation. + + Returns: + (output, max_logits, lse) + Please refer to tests/ref.py for the precise definitions of these parameters. + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, log-sum-exp of attention scores + """ + results = flash_mla_cuda.sparse_prefill_fwd( + q, kv, indices, sm_scale, d_v, attn_sink, topk_length + ) + return results + + +def _flash_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if out is None: + out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype) + if lse is None: + # Make lse contiguous on seqlen dim + lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_fwd( + workspace_buffer, + q, + k, + v, + cu_seqlens_qo, + cu_seqlens_kv, + out, + lse, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return out, lse + + +def _flash_attn_varlen_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + # TODO: fix bwd GQA + if num_qo_heads != num_kv_heads: + raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.") + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if dq is None: + dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dk is None: + dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dv is None: + dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype) + + max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 + bs = cu_seqlens_qo.shape[0] - 1 + workspace_bytes = 0 + workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc + workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse + if num_qo_heads != num_kv_heads: + workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc + workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_bwd( + workspace_buffer, + do, + q, + k, + v, + out, + lse, + cu_seqlens_qo, + cu_seqlens_kv, + dq, + dk, + dv, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return dq, dk, dv + + +class FlashAttnVarlenFunc(torch.autograd.Function): + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + out, lse = _flash_attn_varlen_forward( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal=causal, softmax_scale=softmax_scale, + is_varlen=is_varlen, + ) + ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv) + ctx.max_seqlen_qo = max_seqlen_qo + ctx.max_seqlen_kv = max_seqlen_kv + ctx.causal = causal + ctx.softmax_scale = softmax_scale + ctx.is_varlen = is_varlen + return out, lse + + def backward( + ctx, + do: torch.Tensor, + dlse: torch.Tensor, + ): + del dlse # LSE doesn't support backward currently + q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors + dq, dk, dv = _flash_attn_varlen_backward( + do, q, k, v, out, lse, + cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv, + causal=ctx.causal, softmax_scale=ctx.softmax_scale, + is_varlen=ctx.is_varlen, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:], + cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:], + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) diff --git a/build/torch29-cxx11-cu129-aarch64-linux/metadata.json b/build/torch29-cxx11-cu129-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..8190d75efa8fd6449ddcd73de2072f17086e0842 --- /dev/null +++ b/build/torch29-cxx11-cu129-aarch64-linux/metadata.json @@ -0,0 +1,12 @@ +{ + "version": 1, + "license": "MIT", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0f", + "9.0a" + ] + } +} diff --git a/build/torch29-cxx11-cu129-x86_64-linux/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db300fe9b95176a20b27b3641d89be657d0c4319 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/__init__.py @@ -0,0 +1,163 @@ +from typing import Optional, Tuple +import torch + +from .flash_mla_interface import FlashMLASchedMeta +from . import flash_mla_interface as _impl + + +def get_mla_metadata(*args, **kwargs) -> Tuple[FlashMLASchedMeta, None]: + return _impl.get_mla_metadata(*args, **kwargs) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_mla_with_kvcache( + q=q, + k_cache=k_cache, + block_table=block_table, + cache_seqlens=cache_seqlens, + head_dim_v=head_dim_v, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + softmax_scale=softmax_scale, + causal=causal, + is_fp8_kvcache=is_fp8_kvcache, + indices=indices, + attn_sink=attn_sink, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices_in_kvcache, + topk_length=topk_length, + extra_topk_length=extra_topk_length, + ) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return _impl.flash_mla_sparse_fwd( + q=q, + kv=kv, + indices=indices, + sm_scale=sm_scale, + d_v=d_v, + attn_sink=attn_sink, + topk_length=topk_length, + ) + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_qkvpacked_func( + qkv=qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +__all__ = [ + "__version__", + "FlashMLASchedMeta", + "get_mla_metadata", + "flash_mla_with_kvcache", + "flash_attn_varlen_func", + "flash_attn_varlen_qkvpacked_func", + "flash_attn_varlen_kvpacked_func", + "flash_mla_sparse_fwd", +] diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_flash_mla_cuda_89d7fc1.abi3.so b/build/torch29-cxx11-cu129-x86_64-linux/_flash_mla_cuda_89d7fc1.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..ff77569f64b238e7132b594f1e5dd6417700ffaf --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/_flash_mla_cuda_89d7fc1.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dbee04ad485babe4202c1c634e2914fddb814f7225dc8b3ee556260d6b03557e +size 9283880 diff --git a/build/torch29-cxx11-cu129-x86_64-linux/_ops.py b/build/torch29-cxx11-cu129-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..4dd5978e1b450acfa19dfe796dceea8fe35f5736 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_cuda_89d7fc1 +ops = torch.ops._flash_mla_cuda_89d7fc1 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_cuda_89d7fc1::{op_name}" diff --git a/build/torch29-cxx11-cu129-x86_64-linux/flash_mla/__init__.py b/build/torch29-cxx11-cu129-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b2672c1cd85b74c1b3ded0fc0b2100e1aeac23 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/flash_mla_interface.py b/build/torch29-cxx11-cu129-x86_64-linux/flash_mla_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..a84e448ffe741bb6d3dafaf7888ed8cc94984467 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/flash_mla_interface.py @@ -0,0 +1,435 @@ +from typing import Optional, Tuple +import dataclasses + +import torch + +from ._ops import ops as flash_mla_cuda + +@dataclasses.dataclass +class FlashMLASchedMeta: + """ + A class that stores the tile scheduler metadata of FlashMLA + """ + + @dataclasses.dataclass + class Config: + b: int + s_q: int + h_q: int + page_block_size: int + h_k: int + + causal: bool + is_fp8_kvcache: bool + topk: Optional[int] + + extra_page_block_size: Optional[int] + extra_topk: Optional[int] + + have_initialized: bool = False + + config: Optional[Config] = None + + tile_scheduler_metadata: Optional[torch.Tensor] = None # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: Optional[torch.Tensor] = None # (1), dtype torch.int32. + + +def get_mla_metadata( + *args, + **kwargs +) -> Tuple[FlashMLASchedMeta, None]: + """ + Returns an empty instance of FlashMLASchedMeta. The actual scheduling metadata will be generated during the first invocation of flash_mla_with_kvcache. + + Arguments: + This function does not need any arguments, but we keep *args and **kwargs to be compatible with the old interface. + + Return: + A tuple. Due to historical reasons, we return a tuple of (FlashMLASchedMeta, None) now. Only the first element is useful. + """ + return FlashMLASchedMeta(), None + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details. + The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks. + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used. + cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. + head_dim_v: Head_dim of v. Must be 512 + sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same. + num_splits_placeholder: must be "None" (to be compatible with the old interface). + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k). + causal: bool. Whether to apply causal attention mask. Only valid for dense attention + is_fp8_kvcache: bool. + indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled. + Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block), + where t is the k-th token of the j-th q-sequence in the i-th batch. + attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0. + extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. Their format requirements are the same as k_cache and indices_in_kvcache respectively. + topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking. + + For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2: + head_dim should be 576 while head_dim_v should be 512. + In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as: + - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1. + - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values. + - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on. + - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + sched_meta = tile_scheduler_metadata + indices_in_kvcache = indices + assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta" + assert num_splits is None, "num_splits must be None" + + topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None + extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None + extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + if not sched_meta.have_initialized: + # Sanity check. We only perform sanity check during the first invocation to save CPU time. + if indices_in_kvcache is not None: + assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)" + + # Initialize the tile scheduler metadata during the first invocation. + sched_meta.have_initialized = True + sched_meta.config = FlashMLASchedMeta.Config( + q.shape[0], + q.shape[1], + q.shape[2], + k_cache.shape[1], + k_cache.shape[2], + + causal, + is_fp8_kvcache, + topk, + + extra_k_page_block_size, + extra_topk, + ) + else: + # Check whether the input arguments are consistent with sched_meta + helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta." + assert sched_meta.config is not None + assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg + assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg + assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg + assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg + assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg + assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg + assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg + assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg + assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg + assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg + + if topk is not None: + # Sparse attention + assert not causal, "causal must be False when sparse attention is enabled" + assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled" + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd( + q, k_cache, indices_in_kvcache, topk_length, attn_sink, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits, + extra_k_cache, extra_indices_in_kvcache, extra_topk_length, + head_dim_v, softmax_scale + ) + else: + # Dense attention + assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used." + assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used." + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd( + q, k_cache, head_dim_v, + cache_seqlens, block_table, + softmax_scale, causal, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits + ) + sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata + sched_meta.num_splits = new_num_splits + return (out, lse) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. Can only be 512 + attn_sink: optional, [h_q], float32. + If attn_sink is provided, when computing output, output will be additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)). + +-inf in attn_sink will be handled normally (i.e., -inf has no effect, +inf will make corresponding output all zeros). + This argument has no effect on lse and max_logits. + topk_length: optional, [s_q], int32. If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if provided in indices). + In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), operator output will contain NaN, so please avoid this situation. + + Returns: + (output, max_logits, lse) + Please refer to tests/ref.py for the precise definitions of these parameters. + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, log-sum-exp of attention scores + """ + results = flash_mla_cuda.sparse_prefill_fwd( + q, kv, indices, sm_scale, d_v, attn_sink, topk_length + ) + return results + + +def _flash_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if out is None: + out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype) + if lse is None: + # Make lse contiguous on seqlen dim + lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_fwd( + workspace_buffer, + q, + k, + v, + cu_seqlens_qo, + cu_seqlens_kv, + out, + lse, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return out, lse + + +def _flash_attn_varlen_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + # TODO: fix bwd GQA + if num_qo_heads != num_kv_heads: + raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.") + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if dq is None: + dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dk is None: + dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dv is None: + dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype) + + max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 + bs = cu_seqlens_qo.shape[0] - 1 + workspace_bytes = 0 + workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc + workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse + if num_qo_heads != num_kv_heads: + workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc + workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_bwd( + workspace_buffer, + do, + q, + k, + v, + out, + lse, + cu_seqlens_qo, + cu_seqlens_kv, + dq, + dk, + dv, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return dq, dk, dv + + +class FlashAttnVarlenFunc(torch.autograd.Function): + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + out, lse = _flash_attn_varlen_forward( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal=causal, softmax_scale=softmax_scale, + is_varlen=is_varlen, + ) + ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv) + ctx.max_seqlen_qo = max_seqlen_qo + ctx.max_seqlen_kv = max_seqlen_kv + ctx.causal = causal + ctx.softmax_scale = softmax_scale + ctx.is_varlen = is_varlen + return out, lse + + def backward( + ctx, + do: torch.Tensor, + dlse: torch.Tensor, + ): + del dlse # LSE doesn't support backward currently + q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors + dq, dk, dv = _flash_attn_varlen_backward( + do, q, k, v, out, lse, + cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv, + causal=ctx.causal, softmax_scale=ctx.softmax_scale, + is_varlen=ctx.is_varlen, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:], + cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:], + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) diff --git a/build/torch29-cxx11-cu129-x86_64-linux/metadata.json b/build/torch29-cxx11-cu129-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..8190d75efa8fd6449ddcd73de2072f17086e0842 --- /dev/null +++ b/build/torch29-cxx11-cu129-x86_64-linux/metadata.json @@ -0,0 +1,12 @@ +{ + "version": 1, + "license": "MIT", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0f", + "9.0a" + ] + } +} diff --git a/build/torch29-cxx11-cu130-aarch64-linux/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db300fe9b95176a20b27b3641d89be657d0c4319 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/__init__.py @@ -0,0 +1,163 @@ +from typing import Optional, Tuple +import torch + +from .flash_mla_interface import FlashMLASchedMeta +from . import flash_mla_interface as _impl + + +def get_mla_metadata(*args, **kwargs) -> Tuple[FlashMLASchedMeta, None]: + return _impl.get_mla_metadata(*args, **kwargs) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_mla_with_kvcache( + q=q, + k_cache=k_cache, + block_table=block_table, + cache_seqlens=cache_seqlens, + head_dim_v=head_dim_v, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + softmax_scale=softmax_scale, + causal=causal, + is_fp8_kvcache=is_fp8_kvcache, + indices=indices, + attn_sink=attn_sink, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices_in_kvcache, + topk_length=topk_length, + extra_topk_length=extra_topk_length, + ) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return _impl.flash_mla_sparse_fwd( + q=q, + kv=kv, + indices=indices, + sm_scale=sm_scale, + d_v=d_v, + attn_sink=attn_sink, + topk_length=topk_length, + ) + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_qkvpacked_func( + qkv=qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +__all__ = [ + "__version__", + "FlashMLASchedMeta", + "get_mla_metadata", + "flash_mla_with_kvcache", + "flash_attn_varlen_func", + "flash_attn_varlen_qkvpacked_func", + "flash_attn_varlen_kvpacked_func", + "flash_mla_sparse_fwd", +] diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_flash_mla_cuda_09f70ef.abi3.so b/build/torch29-cxx11-cu130-aarch64-linux/_flash_mla_cuda_09f70ef.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..69ac8d6d7b2a2b86ac945f814f49c1c032296812 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/_flash_mla_cuda_09f70ef.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:97dd39bb9a543eef84a778853629313312dfb6af07ae5a01d9d5e7c10dc7df16 +size 9441960 diff --git a/build/torch29-cxx11-cu130-aarch64-linux/_ops.py b/build/torch29-cxx11-cu130-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ea7ed02f7680582f28bdb0d1e552de1dc177f7c5 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_cuda_09f70ef +ops = torch.ops._flash_mla_cuda_09f70ef + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_cuda_09f70ef::{op_name}" diff --git a/build/torch29-cxx11-cu130-aarch64-linux/flash_mla/__init__.py b/build/torch29-cxx11-cu130-aarch64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/flash_mla/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/flash_mla_interface.py b/build/torch29-cxx11-cu130-aarch64-linux/flash_mla_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..a84e448ffe741bb6d3dafaf7888ed8cc94984467 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/flash_mla_interface.py @@ -0,0 +1,435 @@ +from typing import Optional, Tuple +import dataclasses + +import torch + +from ._ops import ops as flash_mla_cuda + +@dataclasses.dataclass +class FlashMLASchedMeta: + """ + A class that stores the tile scheduler metadata of FlashMLA + """ + + @dataclasses.dataclass + class Config: + b: int + s_q: int + h_q: int + page_block_size: int + h_k: int + + causal: bool + is_fp8_kvcache: bool + topk: Optional[int] + + extra_page_block_size: Optional[int] + extra_topk: Optional[int] + + have_initialized: bool = False + + config: Optional[Config] = None + + tile_scheduler_metadata: Optional[torch.Tensor] = None # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: Optional[torch.Tensor] = None # (1), dtype torch.int32. + + +def get_mla_metadata( + *args, + **kwargs +) -> Tuple[FlashMLASchedMeta, None]: + """ + Returns an empty instance of FlashMLASchedMeta. The actual scheduling metadata will be generated during the first invocation of flash_mla_with_kvcache. + + Arguments: + This function does not need any arguments, but we keep *args and **kwargs to be compatible with the old interface. + + Return: + A tuple. Due to historical reasons, we return a tuple of (FlashMLASchedMeta, None) now. Only the first element is useful. + """ + return FlashMLASchedMeta(), None + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details. + The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks. + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used. + cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. + head_dim_v: Head_dim of v. Must be 512 + sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same. + num_splits_placeholder: must be "None" (to be compatible with the old interface). + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k). + causal: bool. Whether to apply causal attention mask. Only valid for dense attention + is_fp8_kvcache: bool. + indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled. + Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block), + where t is the k-th token of the j-th q-sequence in the i-th batch. + attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0. + extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. Their format requirements are the same as k_cache and indices_in_kvcache respectively. + topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking. + + For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2: + head_dim should be 576 while head_dim_v should be 512. + In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as: + - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1. + - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values. + - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on. + - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + sched_meta = tile_scheduler_metadata + indices_in_kvcache = indices + assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta" + assert num_splits is None, "num_splits must be None" + + topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None + extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None + extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + if not sched_meta.have_initialized: + # Sanity check. We only perform sanity check during the first invocation to save CPU time. + if indices_in_kvcache is not None: + assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)" + + # Initialize the tile scheduler metadata during the first invocation. + sched_meta.have_initialized = True + sched_meta.config = FlashMLASchedMeta.Config( + q.shape[0], + q.shape[1], + q.shape[2], + k_cache.shape[1], + k_cache.shape[2], + + causal, + is_fp8_kvcache, + topk, + + extra_k_page_block_size, + extra_topk, + ) + else: + # Check whether the input arguments are consistent with sched_meta + helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta." + assert sched_meta.config is not None + assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg + assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg + assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg + assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg + assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg + assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg + assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg + assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg + assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg + assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg + + if topk is not None: + # Sparse attention + assert not causal, "causal must be False when sparse attention is enabled" + assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled" + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd( + q, k_cache, indices_in_kvcache, topk_length, attn_sink, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits, + extra_k_cache, extra_indices_in_kvcache, extra_topk_length, + head_dim_v, softmax_scale + ) + else: + # Dense attention + assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used." + assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used." + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd( + q, k_cache, head_dim_v, + cache_seqlens, block_table, + softmax_scale, causal, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits + ) + sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata + sched_meta.num_splits = new_num_splits + return (out, lse) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. Can only be 512 + attn_sink: optional, [h_q], float32. + If attn_sink is provided, when computing output, output will be additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)). + +-inf in attn_sink will be handled normally (i.e., -inf has no effect, +inf will make corresponding output all zeros). + This argument has no effect on lse and max_logits. + topk_length: optional, [s_q], int32. If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if provided in indices). + In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), operator output will contain NaN, so please avoid this situation. + + Returns: + (output, max_logits, lse) + Please refer to tests/ref.py for the precise definitions of these parameters. + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, log-sum-exp of attention scores + """ + results = flash_mla_cuda.sparse_prefill_fwd( + q, kv, indices, sm_scale, d_v, attn_sink, topk_length + ) + return results + + +def _flash_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if out is None: + out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype) + if lse is None: + # Make lse contiguous on seqlen dim + lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_fwd( + workspace_buffer, + q, + k, + v, + cu_seqlens_qo, + cu_seqlens_kv, + out, + lse, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return out, lse + + +def _flash_attn_varlen_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + # TODO: fix bwd GQA + if num_qo_heads != num_kv_heads: + raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.") + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if dq is None: + dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dk is None: + dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dv is None: + dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype) + + max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 + bs = cu_seqlens_qo.shape[0] - 1 + workspace_bytes = 0 + workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc + workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse + if num_qo_heads != num_kv_heads: + workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc + workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_bwd( + workspace_buffer, + do, + q, + k, + v, + out, + lse, + cu_seqlens_qo, + cu_seqlens_kv, + dq, + dk, + dv, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return dq, dk, dv + + +class FlashAttnVarlenFunc(torch.autograd.Function): + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + out, lse = _flash_attn_varlen_forward( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal=causal, softmax_scale=softmax_scale, + is_varlen=is_varlen, + ) + ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv) + ctx.max_seqlen_qo = max_seqlen_qo + ctx.max_seqlen_kv = max_seqlen_kv + ctx.causal = causal + ctx.softmax_scale = softmax_scale + ctx.is_varlen = is_varlen + return out, lse + + def backward( + ctx, + do: torch.Tensor, + dlse: torch.Tensor, + ): + del dlse # LSE doesn't support backward currently + q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors + dq, dk, dv = _flash_attn_varlen_backward( + do, q, k, v, out, lse, + cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv, + causal=ctx.causal, softmax_scale=ctx.softmax_scale, + is_varlen=ctx.is_varlen, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:], + cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:], + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) diff --git a/build/torch29-cxx11-cu130-aarch64-linux/metadata.json b/build/torch29-cxx11-cu130-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..8190d75efa8fd6449ddcd73de2072f17086e0842 --- /dev/null +++ b/build/torch29-cxx11-cu130-aarch64-linux/metadata.json @@ -0,0 +1,12 @@ +{ + "version": 1, + "license": "MIT", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0f", + "9.0a" + ] + } +} diff --git a/build/torch29-cxx11-cu130-x86_64-linux/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..db300fe9b95176a20b27b3641d89be657d0c4319 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/__init__.py @@ -0,0 +1,163 @@ +from typing import Optional, Tuple +import torch + +from .flash_mla_interface import FlashMLASchedMeta +from . import flash_mla_interface as _impl + + +def get_mla_metadata(*args, **kwargs) -> Tuple[FlashMLASchedMeta, None]: + return _impl.get_mla_metadata(*args, **kwargs) + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_mla_with_kvcache( + q=q, + k_cache=k_cache, + block_table=block_table, + cache_seqlens=cache_seqlens, + head_dim_v=head_dim_v, + tile_scheduler_metadata=tile_scheduler_metadata, + num_splits=num_splits, + softmax_scale=softmax_scale, + causal=causal, + is_fp8_kvcache=is_fp8_kvcache, + indices=indices, + attn_sink=attn_sink, + extra_k_cache=extra_k_cache, + extra_indices_in_kvcache=extra_indices_in_kvcache, + topk_length=topk_length, + extra_topk_length=extra_topk_length, + ) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return _impl.flash_mla_sparse_fwd( + q=q, + kv=kv, + indices=indices, + sm_scale=sm_scale, + d_v=d_v, + attn_sink=attn_sink, + topk_length=topk_length, + ) + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_func( + q=q, + k=k, + v=v, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_qkvpacked_func( + qkv=qkv, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + return _impl.flash_attn_varlen_kvpacked_func( + q=q, + kv=kv, + cu_seqlens_qo=cu_seqlens_qo, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_qo=max_seqlen_qo, + max_seqlen_kv=max_seqlen_kv, + head_dim_qk=head_dim_qk, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + deterministic=deterministic, + is_varlen=is_varlen, + ) + + +__all__ = [ + "__version__", + "FlashMLASchedMeta", + "get_mla_metadata", + "flash_mla_with_kvcache", + "flash_attn_varlen_func", + "flash_attn_varlen_qkvpacked_func", + "flash_attn_varlen_kvpacked_func", + "flash_mla_sparse_fwd", +] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_flash_mla_cuda_09f70ef.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/_flash_mla_cuda_09f70ef.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..c5f00086319d5168c2913aad7489d8d83204b054 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/_flash_mla_cuda_09f70ef.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:243747242a38da9e2034590bfd52c3f883683b4faec3ccfd6ec76a0f97addf43 +size 9380384 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ea7ed02f7680582f28bdb0d1e552de1dc177f7c5 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _flash_mla_cuda_09f70ef +ops = torch.ops._flash_mla_cuda_09f70ef + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_flash_mla_cuda_09f70ef::{op_name}" diff --git a/build/torch29-cxx11-cu130-x86_64-linux/flash_mla/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/flash_mla/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/flash_mla_interface.py b/build/torch29-cxx11-cu130-x86_64-linux/flash_mla_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..a84e448ffe741bb6d3dafaf7888ed8cc94984467 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/flash_mla_interface.py @@ -0,0 +1,435 @@ +from typing import Optional, Tuple +import dataclasses + +import torch + +from ._ops import ops as flash_mla_cuda + +@dataclasses.dataclass +class FlashMLASchedMeta: + """ + A class that stores the tile scheduler metadata of FlashMLA + """ + + @dataclasses.dataclass + class Config: + b: int + s_q: int + h_q: int + page_block_size: int + h_k: int + + causal: bool + is_fp8_kvcache: bool + topk: Optional[int] + + extra_page_block_size: Optional[int] + extra_topk: Optional[int] + + have_initialized: bool = False + + config: Optional[Config] = None + + tile_scheduler_metadata: Optional[torch.Tensor] = None # (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. + num_splits: Optional[torch.Tensor] = None # (1), dtype torch.int32. + + +def get_mla_metadata( + *args, + **kwargs +) -> Tuple[FlashMLASchedMeta, None]: + """ + Returns an empty instance of FlashMLASchedMeta. The actual scheduling metadata will be generated during the first invocation of flash_mla_with_kvcache. + + Arguments: + This function does not need any arguments, but we keep *args and **kwargs to be compatible with the old interface. + + Return: + A tuple. Due to historical reasons, we return a tuple of (FlashMLASchedMeta, None) now. Only the first element is useful. + """ + return FlashMLASchedMeta(), None + + +def flash_mla_with_kvcache( + q: torch.Tensor, + k_cache: torch.Tensor, + block_table: Optional[torch.Tensor], + cache_seqlens: Optional[torch.Tensor], + head_dim_v: int, + tile_scheduler_metadata: FlashMLASchedMeta, + num_splits: None = None, + softmax_scale: Optional[float] = None, + causal: bool = False, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, + attn_sink: Optional[torch.Tensor] = None, + extra_k_cache: Optional[torch.Tensor] = None, + extra_indices_in_kvcache: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, + extra_topk_length: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Arguments: + q: (batch_size, seq_len_q, num_heads_q, head_dim). + k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). + Different modes (including fp8/bf16, and sparsity) has different KV cache layouts. See comments below for details. + The KV cache must be contiguously valid for sparse attention on sm100. Here "contiguously valid" means that every byte, from the very beginning of the KV cache, till the last byte in the KV cache, is valid memory address to visit (i.e. won't IMA). In other words, the KV cache could be a slice of a larger array, but cannot be a list of disjoint memory blocks. + block_table: (batch_size, max_num_blocks_per_seq), torch.int32. Can be None when sparse attention is used. + cache_seqlens: (batch_size), torch.int32. Can be None when sparse attention is used. + head_dim_v: Head_dim of v. Must be 512 + sched_meta: FlashMLASchedMeta, return by get_mla_metadata. You may reuse the same sched_meta across different invocations, but only when the tensor shapes and the values of cache_seqlens, topk_length, and extra_topk_length remain the same. + num_splits_placeholder: must be "None" (to be compatible with the old interface). + softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim_k). + causal: bool. Whether to apply causal attention mask. Only valid for dense attention + is_fp8_kvcache: bool. + indices: (batch_size, seq_len_q, topk). KV indices when sparse attention is enabled. + Pay attention that indices_in_kvcache[i][j][k] = (the index of the page block where token t resides) * block_size + (the offset of token t among the page block), + where t is the k-th token of the j-th q-sequence in the i-th batch. + attn_sink: Optional[torch.Tensor], (num_heads_q, ), torch.float32. If presented, the final output will be scaled by exp(lse) / (exp(lse) + exp(attn_sink)). Have no affect on the returned softmax_lse. +inf will cause the result to become 0. + extra_k_cache and extra_indices_in_kvcache: If provided, will attend to these extra tokens in addition to those in k_cache and indices_in_kvcache. Their format requirements are the same as k_cache and indices_in_kvcache respectively. + topk_length/extra_topk_length: (batch_size, ), torch.int32. If provided, only the leftmost topk_length indices will be processed. Useful when the actual topk for different queries are different so that we can save some computation, compared to masking. + + For DeepSeek V3, DeepSeek V3.1, and DeepSeek V3.2: + head_dim should be 576 while head_dim_v should be 512. + In FP8+sparse mode, each token's KV cache is 656 Bytes, structured as: + - The shape of the tensor `k_cache` is (num_blocks, page_block_size, num_heads_k, head_dim), and num_heads_k must be 1. + - First 512 bytes: The "quantized NoPE" part, containing 512 float8_e4m3 values. + - Next 16 bytes: Scale factors, containing 4 float32 values. The first float32 is the scale for the first 128 float8_e4m3 values, the second for the next 128, and so on. + - Last 128 bytes: The "RoPE" part, containing 64 bfloat16 values. This part is not quantized for accuracy. + + Return: + out: (batch_size, seq_len_q, num_heads_q, head_dim_v). + softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. + """ + sched_meta = tile_scheduler_metadata + indices_in_kvcache = indices + assert isinstance(sched_meta, FlashMLASchedMeta), "tile_scheduler_metadata must be of type FlashMLASchedMeta" + assert num_splits is None, "num_splits must be None" + + topk = indices_in_kvcache.shape[-1] if indices_in_kvcache is not None else None + extra_k_page_block_size = extra_k_cache.shape[1] if extra_k_cache is not None else None + extra_topk = extra_indices_in_kvcache.shape[-1] if extra_indices_in_kvcache is not None else None + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + if not sched_meta.have_initialized: + # Sanity check. We only perform sanity check during the first invocation to save CPU time. + if indices_in_kvcache is not None: + assert not causal, "causal must be False when indices_in_kvcache is not None (i.e. sparse attention is enabled)" + + # Initialize the tile scheduler metadata during the first invocation. + sched_meta.have_initialized = True + sched_meta.config = FlashMLASchedMeta.Config( + q.shape[0], + q.shape[1], + q.shape[2], + k_cache.shape[1], + k_cache.shape[2], + + causal, + is_fp8_kvcache, + topk, + + extra_k_page_block_size, + extra_topk, + ) + else: + # Check whether the input arguments are consistent with sched_meta + helper_msg = " Your input arguments are inconsistent with sched_meta. Please make sure the input arguments are consistent across different invocations of flash_mla_with_kvcache on the same sched_meta." + assert sched_meta.config is not None + assert sched_meta.config.b == q.shape[0], "sched_meta.config.b must be equal to batch_size." + helper_msg + assert sched_meta.config.s_q == q.shape[1], "sched_meta.config.s_q must be equal to seq_len_q." + helper_msg + assert sched_meta.config.h_q == q.shape[2], "sched_meta.config.h_q must be equal to num_heads_q." + helper_msg + assert sched_meta.config.page_block_size == k_cache.shape[1], "sched_meta.config.page_block_size must be equal to page_block_size." + helper_msg + assert sched_meta.config.h_k == k_cache.shape[2], "sched_meta.config.h_k must be equal to num_heads_k." + helper_msg + assert sched_meta.config.causal == causal, "sched_meta.config.causal must be equal to causal." + helper_msg + assert sched_meta.config.is_fp8_kvcache == is_fp8_kvcache, "sched_meta.config.is_fp8_kvcache must be equal to is_fp8_kvcache." + helper_msg + assert sched_meta.config.topk == topk, "sched_meta.config.topk must be equal to the last dim of indices_in_kvcache." + helper_msg + assert sched_meta.config.extra_page_block_size == extra_k_page_block_size, "sched_meta.config.extra_page_block_size must be equal to the page_block_size of extra_k_cache." + helper_msg + assert sched_meta.config.extra_topk == extra_topk, "sched_meta.config.extra_topk must be equal to the last dim of extra_indices_in_kvcache." + helper_msg + + if topk is not None: + # Sparse attention + assert not causal, "causal must be False when sparse attention is enabled" + assert is_fp8_kvcache, "is_fp8_kvcache must be True when sparse attention is enabled" + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.sparse_decode_fwd( + q, k_cache, indices_in_kvcache, topk_length, attn_sink, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits, + extra_k_cache, extra_indices_in_kvcache, extra_topk_length, + head_dim_v, softmax_scale + ) + else: + # Dense attention + assert indices_in_kvcache is None and attn_sink is None and extra_k_cache is None and extra_indices_in_kvcache is None and topk_length is None and extra_topk_length is None, "indices_in_kvcache, attn_sink, extra_k_cache, extra_indices_in_kvcache, topk_length and extra_topk_length must be None when dense attention is used." + assert block_table is not None and cache_seqlens is not None, "block_table and cache_seqlens must be provided when dense attention is used." + out, lse, new_tile_scheduler_metadata, new_num_splits = flash_mla_cuda.dense_decode_fwd( + q, k_cache, head_dim_v, + cache_seqlens, block_table, + softmax_scale, causal, + sched_meta.tile_scheduler_metadata, sched_meta.num_splits + ) + sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata + sched_meta.num_splits = new_num_splits + return (out, lse) + + +def flash_mla_sparse_fwd( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, + attn_sink: Optional[torch.Tensor] = None, + topk_length: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. Can only be 512 + attn_sink: optional, [h_q], float32. + If attn_sink is provided, when computing output, output will be additionally multiplied by exp(lse) / (exp(lse) + exp(attn_sink)). + +-inf in attn_sink will be handled normally (i.e., -inf has no effect, +inf will make corresponding output all zeros). + This argument has no effect on lse and max_logits. + topk_length: optional, [s_q], int32. If provided, the i-th q token will only attend to k tokens specified by indices[i, :, :topk_length[i]], ignoring later k/v tokens (even if provided in indices). + In extremely rare cases (topk_length provided, there is a valid topk index between topk_length[i] ~ s_kv, and that topk index points to a k token containing NaN), operator output will contain NaN, so please avoid this situation. + + Returns: + (output, max_logits, lse) + Please refer to tests/ref.py for the precise definitions of these parameters. + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, log-sum-exp of attention scores + """ + results = flash_mla_cuda.sparse_prefill_fwd( + q, kv, indices, sm_scale, d_v, attn_sink, topk_length + ) + return results + + +def _flash_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if out is None: + out = torch.empty(qo_total_len, num_qo_heads, head_dim_vo, device=q.device, dtype=q.dtype) + if lse is None: + # Make lse contiguous on seqlen dim + lse = torch.empty(num_qo_heads, qo_total_len, device=q.device, dtype=torch.float32).T + + workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_fwd( + workspace_buffer, + q, + k, + v, + cu_seqlens_qo, + cu_seqlens_kv, + out, + lse, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return out, lse + + +def _flash_attn_varlen_backward( + do: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + lse: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dq: Optional[torch.Tensor] = None, + dk: Optional[torch.Tensor] = None, + dv: Optional[torch.Tensor] = None, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + qo_total_len, num_qo_heads, head_dim_qk = q.shape + kv_total_len, num_kv_heads, head_dim_vo = v.shape + + # TODO: fix bwd GQA + if num_qo_heads != num_kv_heads: + raise ValueError(f"SM100 bwd doesn't support GQA now. num_qo_heads: {num_qo_heads}, num_kv_heads: {num_kv_heads}.") + + mask_mode_code = 1 if causal else 0 + if softmax_scale is None: + softmax_scale = head_dim_qk ** (-0.5) + + if dq is None: + dq = torch.empty(qo_total_len, num_qo_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dk is None: + dk = torch.empty(kv_total_len, num_kv_heads, head_dim_qk, device=q.device, dtype=q.dtype) + if dv is None: + dv = torch.empty(kv_total_len, num_kv_heads, head_dim_vo, device=q.device, dtype=q.dtype) + + max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 + bs = cu_seqlens_qo.shape[0] - 1 + workspace_bytes = 0 + workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc + workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse + if num_qo_heads != num_kv_heads: + workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc + workspace_buffer = torch.empty(workspace_bytes, dtype=torch.uint8, device=q.device) + flash_mla_cuda.dense_prefill_bwd( + workspace_buffer, + do, + q, + k, + v, + out, + lse, + cu_seqlens_qo, + cu_seqlens_kv, + dq, + dk, + dv, + mask_mode_code, + softmax_scale, + max_seqlen_qo, + max_seqlen_kv, + is_varlen, + ) + + return dq, dk, dv + + +class FlashAttnVarlenFunc(torch.autograd.Function): + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + causal: bool = False, + softmax_scale: Optional[float] = None, + is_varlen: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + out, lse = _flash_attn_varlen_forward( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal=causal, softmax_scale=softmax_scale, + is_varlen=is_varlen, + ) + ctx.save_for_backward(q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv) + ctx.max_seqlen_qo = max_seqlen_qo + ctx.max_seqlen_kv = max_seqlen_kv + ctx.causal = causal + ctx.softmax_scale = softmax_scale + ctx.is_varlen = is_varlen + return out, lse + + def backward( + ctx, + do: torch.Tensor, + dlse: torch.Tensor, + ): + del dlse # LSE doesn't support backward currently + q, k, v, out, lse, cu_seqlens_qo, cu_seqlens_kv = ctx.saved_tensors + dq, dk, dv = _flash_attn_varlen_backward( + do, q, k, v, out, lse, + cu_seqlens_qo, cu_seqlens_kv, ctx.max_seqlen_qo, ctx.max_seqlen_kv, + causal=ctx.causal, softmax_scale=ctx.softmax_scale, + is_varlen=ctx.is_varlen, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attn_varlen_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, k, v, + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + qkv[:, :, :head_dim_qk], qkv[:, :, head_dim_qk:head_dim_qk * 2], qkv[:, :, head_dim_qk * 2:], + cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, + causal, softmax_scale, is_varlen, + ) + + +def flash_attn_varlen_kvpacked_func( + q: torch.Tensor, + kv: torch.Tensor, + cu_seqlens_qo: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + max_seqlen_qo: int, + max_seqlen_kv: int, + head_dim_qk: int, + dropout_p: float = 0.0, + softmax_scale: Optional[float] = None, + causal: bool = False, + deterministic: bool = False, + is_varlen: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + assert dropout_p == 0.0 + assert not deterministic + return FlashAttnVarlenFunc.apply( + q, kv[:, :, :head_dim_qk], kv[:, :, head_dim_qk:], + cu_seqlens_qo, cu_seqlens_kv, max_seqlen_qo, max_seqlen_kv, + causal, softmax_scale, is_varlen, + ) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/metadata.json b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..8190d75efa8fd6449ddcd73de2072f17086e0842 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/metadata.json @@ -0,0 +1,12 @@ +{ + "version": 1, + "license": "MIT", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "10.0f", + "9.0a" + ] + } +} diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000000000000000000000000000000000000..514cfea459f8dc1ea70ccd9abc4c8f8d7571e250 --- /dev/null +++ b/flake.lock @@ -0,0 +1,117 @@ +{ + "nodes": { + "flake-compat": { + "locked": { + "lastModified": 1733328505, + "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "kernel-builder": { + "inputs": { + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs", + "rocm-nix": "rocm-nix" + }, + "locked": { + "lastModified": 1744736115, + "narHash": "sha256-9PPp6XHoMx9jZjwCP7XvAlc52+TmmVuCbUqwh3snuI8=", + "owner": "huggingface", + "repo": "kernel-builder", + "rev": "319af881b27c3645dfc33128f99092c7c1176281", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "kernel-builder", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1743559129, + "narHash": "sha256-7gpAWsENV3tY2HmeHYQ2MoQxGpys+jQWnkS/BHAMXVk=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "adae22bea8bcc0aa2fd6e8732044660fb7755f5e", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable-small", + "repo": "nixpkgs", + "type": "github" + } + }, + "rocm-nix": { + "inputs": { + "nixpkgs": [ + "kernel-builder", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1743085847, + "narHash": "sha256-uWG29p+nhZmGRV1LffWwRGjwtPIXeu1F0YTQbXgB+GU=", + "owner": "huggingface", + "repo": "rocm-nix", + "rev": "245cdc9bfb4bfafa818711c5f5e0b889afe1ba39", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "rocm-nix", + "type": "github" + } + }, + "root": { + "inputs": { + "kernel-builder": "kernel-builder" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000000000000000000000000000000000000..220c3165749291b4037be983721b8135904e5623 --- /dev/null +++ b/flake.nix @@ -0,0 +1,17 @@ +{ + description = "Flake for FlashMLA kernel"; + + inputs = { + kernel-builder.url = "github:huggingface/kernel-builder"; + }; + + outputs = + { + self, + kernel-builder, + }: + kernel-builder.lib.genFlakeOutputs { + path = ./.; + rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate; + }; +} diff --git a/flash_mla/flash_fwd_mla_bf16_sm90.cu b/flash_mla/flash_fwd_mla_bf16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..35691f28628d4f66f974d9457a08fef8ea4a8c8f --- /dev/null +++ b/flash_mla/flash_fwd_mla_bf16_sm90.cu @@ -0,0 +1,3 @@ +#include "flash_fwd_mla_kernel.h" + +template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); diff --git a/flash_mla/flash_fwd_mla_fp16_sm90.cu b/flash_mla/flash_fwd_mla_fp16_sm90.cu new file mode 100644 index 0000000000000000000000000000000000000000..abdaf7b5ae35e2d94c4ea9951fd7703b761b0fe0 --- /dev/null +++ b/flash_mla/flash_fwd_mla_fp16_sm90.cu @@ -0,0 +1,3 @@ +#include "flash_fwd_mla_kernel.h" + +template void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); diff --git a/flash_mla/flash_fwd_mla_kernel.h b/flash_mla/flash_fwd_mla_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..d96acd88d556b2bdcde31bf580a95350334cfa74 --- /dev/null +++ b/flash_mla/flash_fwd_mla_kernel.h @@ -0,0 +1,603 @@ +#pragma once + +#include +#include +#include +#include + +using namespace cute; + +#include "named_barrier.h" +#include "utils.h" +#include "softmax.h" +#include "static_switch.h" +#include "flash_mla.h" + + +template +constexpr auto getSmemLayoutK() { + constexpr int headSizeBytes = sizeof(PrecType) * DIM; + constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2; + + if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) { + return GMMA::Layout_K_SW128_Atom{}; + } else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) { + return GMMA::Layout_K_SW64_Atom{}; + } else { + return GMMA::Layout_K_SW32_Atom{}; + } +} + +template +struct Flash_fwd_kernel_traits_mla { + using Element = elem_type; + using ElementAccum = float; + using index_t = int64_t; + + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + static constexpr int kNWarpsS = 4; + static constexpr int kNThreadsS = kNWarpsS * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim; + static_assert(kHeadDimV % 32 == 0); + static_assert(kHeadDimV <= kHeadDim); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = decltype(make_tiled_mma( + cute::GMMA::ss_op_selector, Int, Int>, + GMMA::Major::K, GMMA::Major::K>(), + Layout, _1, _1>>{})); + + static constexpr int AtomLayoutNO = kNThreads / kNThreadsS; + using TiledMmaO = decltype(make_tiled_mma( + cute::GMMA::rs_op_selector, Int, Int>, + GMMA::Major::K, GMMA::Major::MN>(), + Layout, Int, _1>>{})); + + using SmemLayoutQ = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int>{})); + + using SmemLayoutK = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int>{})); + + using SmemLayoutV = decltype(tile_to_shape( + getSmemLayoutK(), + Shape, Int>{})); + using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape, Int>{}, GenRowMajor{}))); + + using SmemLayoutP = Layout, Int, _1, Int>>; + using SmemLayoutRow = Layout>, Stride<_1, _2>>; + + using SmemLayoutAtomO = decltype(composition( + Swizzle{}, + Layout, Int>, Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + using SmemCopyAtomOaccum = Copy_Atom, ElementAccum>; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL; + static constexpr int kNThreadsLoad = kNThreads - kNThreadsS; + static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + + using GmemLayoutAtom = Layout< + Shape, Int>, + Stride, _1>>; + using GmemTiledCopy = decltype(make_tiled_copy( + Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + + using GmemLayoutAtomO = Layout< + Shape, Int>, + Stride, _1>>; + using GmemTiledCopyO = decltype(make_tiled_copy( + Copy_Atom, Element>{}, + GmemLayoutAtomO{}, + Layout>{})); // Val layout, 8 vals per store + + static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum); + static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum; + using GmemLayoutAtomOaccum = Layout< + Shape, Int>, + Stride, _1>>; + using GmemTiledCopyOaccum = decltype(make_tiled_copy( + Copy_Atom, ElementAccum>{}, + GmemLayoutAtomOaccum{}, + Layout>{})); // Val layout, 4 vals per store +}; + +namespace flash { + +using namespace cute; + +template +struct SharedStorageMLA { + union { + struct { + cute::array_aligned> smem_q; + cute::array_aligned * 2> smem_k; // Double buffer + cute::array_aligned> smem_p; + cute::array_aligned> smem_scale; + }; + struct { + cute::array_aligned> smem_max; + cute::array_aligned> smem_sum; + cute::array_aligned> smem_o; + }; + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void store(const Flash_fwd_mla_params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, + SharedStorage &shared_storage, AccO tOrO, Softmax softmax) { + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kHeadDimV = Kernel_traits::kHeadDimV; + constexpr int kNThreadsS = Kernel_traits::kNThreadsS; + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + const int tidx = threadIdx.x; + + typename Kernel_traits::TiledMmaO tiled_mma_o; + auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); + + // Epilogue + + const int split_offset = __ldg(params.num_splits_ptr + bidb); + + Tensor lse = softmax.template normalize_softmax_lse(tOrO, params.scale_softmax); + + using ElementO = std::conditional_t; + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::SmemCopyAtomO, + typename Kernel_traits::SmemCopyAtomOaccum + >; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = flash::convert_type(tOrO); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + __syncthreads(); + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape, Int>{}, + make_stride(Split ? kHeadDimV : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)), + Shape>{}, Stride<_1>{}); + + using GmemTiledCopyO = std::conditional_t; + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + if (tidx >= kNThreadsS) { return; } + + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1) + Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor(make_shape(size<2>(tOgOaccum))); + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM + ); +} + +template +__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params ¶ms, + const int bidb, const int bidh, const int m_block, + const int n_split_idx, const int seqlen_k, + const int n_block_min, const int n_block_max, const bool NoSplit, + SharedStorage &shared_storage) { + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kHeadDimV = Kernel_traits::kHeadDimV; + constexpr int kNThreads = Kernel_traits::kNThreads; + constexpr int kNThreadsS = Kernel_traits::kNThreadsS; + static_assert(kNThreads == 256 and kNThreadsS == 128); + using Element = typename Kernel_traits::Element; + using index_t = typename Kernel_traits::index_t; + + const int tidx = threadIdx.x; + int n_block = n_block_max - 1; + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{}); + Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{}); + + Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{}); + Tensor tPsP = sP(_, tidx % kNThreadsS, _, _); + Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS); + Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS); + Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{}); + Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS); + + typename Kernel_traits::TiledMmaO tiled_mma_o; + auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx); + Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N) + Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) + clear(tOrO); + + flash::Softmax<2 * size<1>(tOrO)> softmax; + + int warp_group_idx = cutlass::canonical_warp_group_idx(); + if (warp_group_idx == 0) { + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + + if (n_block % 2 == 1) { + // Double buffer for sK + constexpr int sK_offset = size(sK); + tSrK.data() = tSrK.data() + sK_offset / 8; + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1; +#pragma unroll 1 + for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) { + __syncthreads(); + + Tensor tSrS = partition_fragment_C(tiled_mma, Shape, Int>{}); // ((MMA=4, X), MMA_M, MMA_N=1) + flash::gemm(tiled_mma, tSrQ, tSrK, tSrS); + + const bool is_masking_step = masking_step > 0; + const bool is_first_masking_step = masking_step == n_masking_steps; + + if (is_masking_step) { + Tensor cS = make_identity_tensor(Shape, Int>{}); + Tensor tScS = thr_mma.partition_C(cS); +#pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + if constexpr (!Is_causal) { // Just masking based on col + if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY; + } else { + // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups + // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups + int row = int(get<0>(tScS(i))); + int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups; + if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY; + } + } + } + + // We have key_padding_mask so we'll need to Check_inf + Tensor scale_o = is_first_masking_step + ? softmax.template softmax(tSrS, params.scale_softmax_log2) + : is_masking_step ? + softmax.template softmax(tSrS, params.scale_softmax_log2) + : softmax.template softmax(tSrS, params.scale_softmax_log2); + + Tensor rP = flash::convert_type(tSrS); + cute::copy(rP, tPsP); + cute::copy(scale_o, tScale_osScale_o); + + cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SReady)); + + flash::rescale_o(tOrO, scale_o); + + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); + + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); + tSrK.data() = tSrK.data() + sK_offset / 8; + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + + cute::copy(softmax.row_max, tRow_maxsRow_max); + cute::copy(softmax.row_sum, tRow_sumsRow_sum); + cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); + } else { + const int *block_table = params.block_table + bidb * params.block_table_batch_stride; + int cur_block_table = __ldg(&block_table[n_block]); + + const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + Shape, Int>{}, + make_stride(params.q_row_stride, _1{})); + typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q; + auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS); + Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ); + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tQpQ = make_tensor(make_shape(size<2>(tQsQ))); + + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ, + params.seqlen_q - m_block * kBlockM); + + const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride; + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + Shape, Int>{}, + make_stride(params.k_row_stride, _1{})); + typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K; + auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS); + Tensor tKgK = gmem_thr_copy_K.partition_S(gK); + Tensor tKsK = gmem_thr_copy_K.partition_D(sK); + Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + Tensor tKcK = gmem_thr_copy_K.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); + + if (n_block % 2 == 1) { + // Double buffer for sK + constexpr int sK_offset = size(sK); + tKsK.data() = tKsK.data() + sK_offset; + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + + // We need to clear the sK smem tiles because K is V. + const index_t offset_k = cur_block_table * params.k_batch_stride; + tKgK.data() = tKgK.data() + offset_k; + flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK, + seqlen_k - n_block * kBlockN); + tKgK.data() = tKgK.data() + -offset_k; + cute::cp_async_fence(); + + if (n_block - 1 >= n_block_min) { + cur_block_table = __ldg(&block_table[n_block - 1]); + } + +#pragma unroll 1 + for (; n_block >= n_block_min; --n_block) { + flash::cp_async_wait<0>(); + __syncthreads(); + + if (n_block - 1 >= n_block_min) { + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); + tKsK.data() = tKsK.data() + sK_offset; + + const index_t offset_k = cur_block_table * params.k_batch_stride; + tKgK.data() = tKgK.data() + offset_k; + flash::copy(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK); + tKgK.data() = tKgK.data() + -offset_k; + cute::cp_async_fence(); + } + + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SReady)); + + if (n_block - 2 >= n_block_min) { + cur_block_table = __ldg(&block_table[n_block - 2]); + } + + typename Kernel_traits::TiledMma tiled_mma; + auto tSrS_layout = partition_fragment_C(tiled_mma, Shape, Int>{}).layout(); + Tensor rP = make_tensor(tSrS_layout); + Tensor scale_o = make_tensor(Shape<_2>{}); + cute::copy(tScale_osScale_o, scale_o); + cute::copy(tPsP, rP); + + flash::rescale_o(tOrO, scale_o); + + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs(rP.layout())); + flash::gemm(tiled_mma_o, tOrP, tOrVt, tOrO); + + // Double buffer for sK + const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK); + tOrVt.data() = tOrVt.data() + sK_offset / 8; + } + + cutlass::arch::NamedBarrier::sync(kNThreads, static_cast(NamedBarriers::SoftmaxReady)); + cute::copy(tRow_maxsRow_max, softmax.row_max); + cute::copy(tRow_sumsRow_sum, softmax.row_sum); + } + + if (NoSplit) + store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax); + else + store(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax); +} + +template +__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1) +flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) { + constexpr int kBlockN = Kernel_traits::kBlockN; + const int m_block = blockIdx.x; + const int bidh = blockIdx.y; + const int partition_idx = blockIdx.z; + + extern __shared__ char shared_memory[]; + auto &shared_storage = *reinterpret_cast(shared_memory); + + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize; + int4 tile_scheduler_metadata = __ldg(reinterpret_cast(tile_scheduler_metadata_ptr)); + int begin_idx = tile_scheduler_metadata.x; + int begin_seqlen = tile_scheduler_metadata.y; + int end_idx = tile_scheduler_metadata.z; + int end_seqlen = tile_scheduler_metadata.w; + if (begin_idx >= params.b) return; + int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4); + +#pragma unroll 1 + for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) { + const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0; + const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id); + const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0; + const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN); + const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN); + if (batch_id > begin_idx) { + __syncthreads(); // Barrier between two tiles. + } + flash::compute_attn_1rowblock_splitkv_mla(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void __launch_bounds__(256, 1, 1) +flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) { + constexpr int kNThreads = 128; + + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + const int hs = params.h * params.seqlen_q; + const int batch_idx = bidx / hs; + const int hs_idx = bidx % hs; + + const int split_offset = __ldg(params.num_splits_ptr + batch_idx); + const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset; + FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits); + if (actual_num_splits == 1) return; + + __shared__ ElementAccum sLseScale[kMaxSplits]; + + const index_t row_offset_lseaccum = split_offset * hs + hs_idx; + const index_t row_offset_lse = bidx; + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lseaccum), + Shape>{}, make_stride(hs)); + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + Shape<_1>{}, Stride<_1>{}); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == 0) { + constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32); + + float local_lse[kNLsePerThread]; + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + tidx; + local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY; + } + + float max_lse = -INFINITY; + for (int i = 0; i < kNLsePerThread; ++i) max_lse = max(max_lse, local_lse[i]); + for (int offset = 16; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset)); + max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf + + float sum_lse = 0; + for (int i = 0; i < kNLsePerThread; ++i) sum_lse = sum_lse + expf(local_lse[i] - max_lse); + for (int offset = 16; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset); + + float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : logf(sum_lse) + max_lse; + if (tidx == 0) gLSE(0) = global_lse; + + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + tidx; + if (split < actual_num_splits) sLseScale[split] = expf(local_lse[i] - global_lse); + } + } + __syncthreads(); + + static_assert(kHeadDimV % kNThreads == 0); + constexpr int Elements = kHeadDimV / kNThreads; + const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.oaccum_ptr) + row_offset_oaccum), + Shape>{}, Stride<_1>{}); + using GmemTiledCopyOaccum = decltype(make_tiled_copy( + Copy_Atom, ElementAccum>{}, + Layout>>{}, + Layout>>{})); + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOrOaccum = make_tensor(shape(tOgOaccum)); + Tensor tOrO = make_tensor(shape(tOgOaccum)); + clear(tOrO); + + for (int split = 0; split < actual_num_splits; ++split) { + cute::copy(tOgOaccum, tOrOaccum); + ElementAccum lse_scale = sLseScale[split]; + for (int i = 0; i < size(tOrO); ++i) { + tOrO(i) += lse_scale * tOrOaccum(i); + } + tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV; + } + + Tensor rO = flash::convert_type(tOrO); + const int head_idx = (bidx - batch_idx * hs) / params.seqlen_q; + const int row = bidx - batch_idx * hs - head_idx * params.seqlen_q; + auto o_ptr = reinterpret_cast(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride; + Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape(rO))::value>>{}, Stride<_1>{}); + cute::copy(rO, gO); +} + +} // namespace flash + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { + FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN); + const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + auto kernel = &flash::flash_fwd_splitkv_mla_kernel; + constexpr size_t smem_size = sizeof(SharedStorage); + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + kernel<<>>(params); + }); + CHECK_CUDA_KERNEL_LAUNCH(); + + dim3 grid_combine(params.b * params.h * params.seqlen_q); + MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] { + auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel< + typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>; + combine_kernel<<>>(params); + }); + CHECK_CUDA_KERNEL_LAUNCH(); +} + +template +void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream) { + static_assert(Headdim == 576); + FLASH_ASSERT(params.d_v == 512); + FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV + using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>; + run_flash_splitkv_fwd_mla>(params, stream); +} diff --git a/flash_mla/flash_fwd_mla_metadata.cu b/flash_mla/flash_fwd_mla_metadata.cu new file mode 100644 index 0000000000000000000000000000000000000000..82f5b5ace6c80b08a47e4c93218edb79686c47fd --- /dev/null +++ b/flash_mla/flash_fwd_mla_metadata.cu @@ -0,0 +1,77 @@ +#include "flash_fwd_mla_kernel.h" + +static constexpr int MaxBatchSize = 4096; + +__global__ void __launch_bounds__(256, 1, 1) +get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) { + int *seqlens_k_ptr = params.seqlens_k_ptr; + int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr; + int *num_splits_ptr = params.num_splits_ptr; + int batch_size = params.batch_size; + int block_size_n = params.block_size_n; + int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks; + int num_sm_parts = params.num_sm_parts; + + __shared__ int num_blocks_shared[MaxBatchSize]; + __shared__ int num_splits_shared[MaxBatchSize]; + + int total_num_blocks = 0; + for (int i = threadIdx.x; i < batch_size; i += 32) { + int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n); + total_num_blocks += num_blocks + fixed_overhead_num_blocks; + num_blocks_shared[i] = num_blocks; + } + for (int offset = 16; offset >= 1; offset /= 2) { + total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset); + } + __syncwarp(); + + if (threadIdx.x == 0) { + int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks; + + int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0; + num_splits_shared[0] = 0; + for (int i = 0; i < num_sm_parts; ++i) { + int tile_scheduler_metadata0[4], tile_scheduler_metadata1; + tile_scheduler_metadata0[0] = now_idx; + tile_scheduler_metadata0[1] = now_block * block_size_n; + tile_scheduler_metadata1 = now_n_split_idx; + int remain_payload = payload; + while (now_idx < batch_size) { + int num_blocks = num_blocks_shared[now_idx]; + int now_remain_blocks = num_blocks - now_block; + if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) { + cum_num_splits += now_n_split_idx + 1; + num_splits_shared[now_idx + 1] = cum_num_splits; + remain_payload -= now_remain_blocks + fixed_overhead_num_blocks; + ++now_idx; + now_block = 0; + now_n_split_idx = 0; + } else { + if (remain_payload - fixed_overhead_num_blocks > 0) { + now_block += remain_payload - fixed_overhead_num_blocks; + ++now_n_split_idx; + remain_payload = 0; + } + break; + } + } + tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1; + tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1]; + *reinterpret_cast(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast(tile_scheduler_metadata0); + tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1; + } + FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0); + } + __syncwarp(); + + for (int i = threadIdx.x; i <= batch_size; i += 32) { + num_splits_ptr[i] = num_splits_shared[i]; + } +} + +void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream) { + FLASH_ASSERT(params.batch_size < MaxBatchSize); + get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params); + CHECK_CUDA_KERNEL_LAUNCH(); +} \ No newline at end of file diff --git a/flash_mla/flash_mla.h b/flash_mla/flash_mla.h new file mode 100644 index 0000000000000000000000000000000000000000..2994cb783f7d57277709e937de61adc648bd2fe9 --- /dev/null +++ b/flash_mla/flash_mla.h @@ -0,0 +1,63 @@ +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_mla_params { + using index_t = int64_t; + + int b, seqlen_q, d, d_v; + int h, h_h_k_ratio, ngroups; + bool is_causal; + float scale_softmax, scale_softmax_log2; + int *__restrict__ cu_seqlens_k; + + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + void *__restrict__ o_ptr; + void *__restrict__ softmax_lse_ptr; + + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t o_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t o_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + index_t o_head_stride; + + int *__restrict__ block_table; + index_t block_table_batch_stride; + int page_block_size; + + int *__restrict__ tile_scheduler_metadata_ptr; + int num_sm_parts; + int *__restrict__ num_splits_ptr; + + void *__restrict__ softmax_lseaccum_ptr; + void *__restrict__ oaccum_ptr; +}; + +static constexpr int TileSchedulerMetaDataSize = 8; +// [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _] + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, cudaStream_t stream); + +struct Mla_metadata_params { + int *__restrict__ seqlens_k_ptr; + int *__restrict__ tile_scheduler_metadata_ptr; + int *__restrict__ num_splits_ptr; + int batch_size; + int block_size_n; + int fixed_overhead_num_blocks; + int num_sm_parts; +}; + +void get_mla_metadata_func(Mla_metadata_params ¶ms, cudaStream_t stream); diff --git a/flash_mla/flash_mla_api.cu b/flash_mla/flash_mla_api.cu new file mode 100644 index 0000000000000000000000000000000000000000..848c8639ab425c71d77496537ca2e1e4d1b423ca --- /dev/null +++ b/flash_mla/flash_mla_api.cu @@ -0,0 +1,208 @@ +#include +#include +#include +#include + +#include "flash_mla.h" +#include "static_switch.h" + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +std::vector +get_mla_metadata( + at::Tensor &seqlens_k, + const int64_t num_heads_per_head_k, + const int64_t num_heads_k +) { + // This should match the logic in the MLA kernel. + static constexpr int block_size_m = 64; + static constexpr int block_size_n = 64; + static constexpr int fixed_overhead_num_blocks = 5; + + CHECK_DEVICE(seqlens_k); + TORCH_CHECK(seqlens_k.is_contiguous()); + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32); + + int batch_size = seqlens_k.size(0); + int *seqlens_k_ptr = seqlens_k.data_ptr(); + auto options = seqlens_k.options(); + + auto dprops = at::cuda::getCurrentDeviceProperties(); + int sm_count = dprops->multiProcessorCount; + int num_sm_parts = sm_count / num_heads_k / cutlass::ceil_div(num_heads_per_head_k, block_size_m); + + auto tile_scheduler_metadata = torch::empty({num_sm_parts, TileSchedulerMetaDataSize}, options); + auto num_splits = torch::empty({batch_size + 1}, options); + int *tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); + int *num_splits_ptr = num_splits.data_ptr(); + + at::cuda::CUDAGuard device_guard{(char)seqlens_k.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + Mla_metadata_params params = {}; + params.seqlens_k_ptr = seqlens_k_ptr; + params.tile_scheduler_metadata_ptr = tile_scheduler_metadata_ptr; + params.num_splits_ptr = num_splits_ptr; + params.batch_size = batch_size; + params.block_size_n = block_size_n; + params.fixed_overhead_num_blocks = fixed_overhead_num_blocks; + params.num_sm_parts = num_sm_parts; + get_mla_metadata_func(params, stream); + + return {tile_scheduler_metadata, num_splits}; +} + +// note doubles and longs are used in place of floats and ints +// https://github.com/pytorch/pytorch/blob/338ed67a1e7aa98dd849f297533c5a71bea4b661/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h#L211 +std::vector +mha_fwd_kvcache_mla( + at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &kcache, // num_blocks x page_block_size x num_heads_k x head_size + const c10::optional &vcache_, // num_blocks x page_block_size x num_heads_k x head_size_v + const int64_t head_size_v, + const at::Tensor &seqlens_k, // batch_size + const at::Tensor &block_table, // batch_size x max_num_blocks_per_seq + const double softmax_scale, + bool is_causal, + const at::Tensor &tile_scheduler_metadata, // num_sm_parts x TileSchedulerMetaDataSize + const at::Tensor &num_splits // batch_size + 1 +) { + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90); + + at::Tensor vcache = vcache_.has_value() ? vcache_.value() : kcache; + auto q_dtype = q.dtype(); + TORCH_CHECK(kcache.dtype() == q_dtype, "query and key must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(kcache); CHECK_DEVICE(vcache); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + CHECK_DEVICE(block_table); + TORCH_CHECK(block_table.dtype() == torch::kInt32, "block_table must have dtype torch.int32"); + TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension"); + + const auto sizes = q.sizes(); + const int batch_size = sizes[0]; + const int seqlen_q_ori = sizes[1]; + const int num_heads_ori = sizes[2]; + const int head_size = sizes[3]; + TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + TORCH_CHECK(head_size_v % 32 == 0, "head_size_v should be a multiple of 32"); + + const int max_num_blocks_per_seq = block_table.size(1); + const int num_blocks = kcache.size(0); + const int page_block_size = kcache.size(1); + const int num_heads_k = kcache.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be postive"); + TORCH_CHECK(num_heads_ori % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + if (seqlen_q_ori == 1) { is_causal = false; } + + const int ngroups = num_heads_ori / num_heads_k; + const int seqlen_q = seqlen_q_ori * ngroups; + const int num_heads = num_heads_k; + q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3) + .reshape({batch_size, seqlen_q, num_heads, head_size}); + + int head_size_k = head_size; + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_k); + + // TODO: fix for optional + // if (vcache_.has_value()) { CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); } + CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_v); + + CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq); + + + TORCH_CHECK(seqlens_k.dtype() == torch::kInt32, "seqlens_k must have dtype int32"); + CHECK_DEVICE(seqlens_k); + CHECK_CONTIGUOUS(seqlens_k); + CHECK_SHAPE(seqlens_k, batch_size); + + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + at::Tensor out = torch::empty({batch_size, seqlen_q, num_heads, head_size_v}, opts); + at::Tensor softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + + Flash_fwd_mla_params params = {}; + // Set the sizes. + params.b = batch_size; + params.seqlen_q = seqlen_q; + params.cu_seqlens_k = seqlens_k.data_ptr(); + params.h = num_heads; + params.h_h_k_ratio = num_heads / num_heads_k; + params.ngroups = ngroups; + params.is_causal = is_causal; + params.d = head_size; + params.d_v = head_size_v; + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = float(softmax_scale * M_LOG2E); + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = kcache.data_ptr(); + params.v_ptr = vcache.data_ptr(); + params.o_ptr = out.data_ptr(); + params.softmax_lse_ptr = softmax_lse.data_ptr(); + // All stride are in elements, not bytes. + params.q_batch_stride = q.stride(0); + params.k_batch_stride = kcache.stride(0); + params.v_batch_stride = vcache.stride(0); + params.o_batch_stride = out.stride(0); + params.q_row_stride = q.stride(-3); + params.k_row_stride = kcache.stride(-3); + params.v_row_stride = vcache.stride(-3); + params.o_row_stride = out.stride(-3); + params.q_head_stride = q.stride(-2); + params.k_head_stride = kcache.stride(-2); + params.v_head_stride = vcache.stride(-2); + params.o_head_stride = out.stride(-2); + + params.block_table = block_table.data_ptr(); + params.block_table_batch_stride = block_table.stride(0); + params.page_block_size = page_block_size; + + TORCH_CHECK(tile_scheduler_metadata.dtype() == torch::kInt32, "tile_scheduler_metadata must have dtype int32"); + TORCH_CHECK(tile_scheduler_metadata.size(1) == TileSchedulerMetaDataSize); + CHECK_DEVICE(tile_scheduler_metadata); + CHECK_CONTIGUOUS(tile_scheduler_metadata); + params.tile_scheduler_metadata_ptr = tile_scheduler_metadata.data_ptr(); + params.num_sm_parts = tile_scheduler_metadata.size(0); + TORCH_CHECK(num_splits.dtype() == torch::kInt32, "num_splits must have dtype int32"); + CHECK_DEVICE(num_splits); + CHECK_CONTIGUOUS(num_splits); + params.num_splits_ptr = num_splits.data_ptr(); + + at::Tensor softmax_lse_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor out_accum = torch::empty({batch_size + params.num_sm_parts, num_heads, seqlen_q, head_size_v}, opts.dtype(at::kFloat)); + params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr(); + params.oaccum_ptr = out_accum.data_ptr(); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + TORCH_CHECK(head_size == 576); + + if (q_dtype == torch::kBFloat16) { + run_mha_fwd_splitkv_mla(params, stream); + } + #ifndef FLASH_MLA_DISABLE_FP16 + else if (q_dtype == torch::kHalf) { + run_mha_fwd_splitkv_mla(params, stream); + } + #endif + else { + TORCH_CHECK(false, "Unsupported tensor dtype for query"); + } + + out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3) + .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v}); + softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3) + .reshape({batch_size, num_heads_ori, seqlen_q_ori}); + + return {out, softmax_lse}; +} diff --git a/flash_mla/named_barrier.h b/flash_mla/named_barrier.h new file mode 100644 index 0000000000000000000000000000000000000000..cefa936ca769ad54d86fdfb585fa3bb70204fc06 --- /dev/null +++ b/flash_mla/named_barrier.h @@ -0,0 +1,15 @@ +#pragma once + +#include "cutlass/barrier.h" + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Enumerates the reserved named barriers to avoid potential conflicts + +enum class NamedBarriers { + SReady = 1, + SoftmaxReady = 2, +}; + +} // flash diff --git a/flash_mla/softmax.h b/flash_mla/softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..4ab6ae9c6c0188c21914ae09bbd7ac36d4effe28 --- /dev/null +++ b/flash_mla/softmax.h @@ -0,0 +1,197 @@ +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h + +#pragma once + +#include + +#include +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++){ + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ auto scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + // The following macro will disable the use of fma. + // See: https://github.com/pytorch/pytorch/issues/121558 for more details + // This macro is set in PyTorch and not FlashAttention + #ifdef UNFUSE_FMA + tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled); + #else + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + #endif + } + } + return tensor; +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +template +__forceinline__ __device__ void rescale_o(Tensor0 &acc_o, Tensor1 &scale_o) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + #pragma unroll + for (int mi = 0; mi < size(scale_o); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale_o(mi); } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax { + + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + + __forceinline__ __device__ Softmax() {}; + + template + __forceinline__ __device__ TensorT softmax(Tensor0 &acc_s, float softmax_scale_log2) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + TensorT scale_o; + clear(scale_o); + if (Is_first) { + flash::template reduce_max(scores, row_max); + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + flash::reduce_sum(scores, row_sum); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + flash::template reduce_max(scores, row_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + scale_o(mi) = scores_scale; + row_sum(mi) *= scores_scale; + } + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + flash::reduce_sum(scores, row_sum); + } + return scale_o; + }; + + template + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT lse = make_fragment_like(row_sum); + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); + float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + return lse; + }; +}; + +} // namespace flash diff --git a/flash_mla/static_switch.h b/flash_mla/static_switch.h new file mode 100644 index 0000000000000000000000000000000000000000..f156adcca5857ca0d714cfe284d0550073b113ce --- /dev/null +++ b/flash_mla/static_switch.h @@ -0,0 +1,65 @@ +#pragma once + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while(0) + +#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError()) + + +#define FLASH_ASSERT(cond) \ + do { \ + if (not (cond)) { \ + fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \ + exit(1); \ + } \ + } while(0) + + +#define FLASH_DEVICE_ASSERT(cond) \ + do { \ + if (not (cond)) { \ + printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ + } while(0) + + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + + +#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \ + [&] { \ + if (NUM_SPLITS <= 32) { \ + constexpr static int NAME = 32; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 64) { \ + constexpr static int NAME = 64; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 96) { \ + constexpr static int NAME = 96; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 128) { \ + constexpr static int NAME = 128; \ + return __VA_ARGS__(); \ + } else if (NUM_SPLITS <= 160) { \ + constexpr static int NAME = 160; \ + return __VA_ARGS__(); \ + } else { \ + FLASH_ASSERT(false); \ + } \ + }() diff --git a/flash_mla/utils.h b/flash_mla/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..3b8dd527597aed39536f6313cd4dc837933b3db5 --- /dev/null +++ b/flash_mla/utils.h @@ -0,0 +1,238 @@ +// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/utils.h + +#pragma once + +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { +// This is slightly faster +__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ __forceinline__ T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template +static __device__ __forceinline__ T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { + constexpr bool Is_RS = !cute::is_base_of::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast(tCrA)); } + warpgroup_fence_operand(tCrC); + if constexpr (arrive) { + warpgroup_arrive(); + } + if constexpr (zero_init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } else { + // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } + if constexpr (commit) { + warpgroup_commit_batch(); + } + if constexpr (wg_wait >= 0) { warpgroup_wait(); } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { warpgroup_fence_operand(const_cast(tCrA)); } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) { + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = acc_layout; + if constexpr (!Transposed) { + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))); + } else { + return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); + } + + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + if constexpr (!Transposed) { + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + } else { + return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. +// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) +// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N)) +template +__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) { + using X = Underscore; + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); + if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) { + auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{}); // ((2, N / 16)) + return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); + } else { + static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1); + static_assert(decltype(stride<0, 0>(acc_layout))::value == 1); + static_assert(decltype(stride<0, 1>(acc_layout))::value == 2); + auto l = logical_divide(get<0, 2>(acc_layout), Tile>>{}); // (((2, 2), N / 32)) + // This combines the first two modes (<0, 0> and <0, 1>) into one mode. + // Will require register shuffling later to be correct. + return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)), + get<1>(acc_layout), + coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); // ((4, 2, 2), MMA_M, N / 32 * MMA_N) + // This combination is right but doesn't work with register shuffling. + // return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)), + // get<1>(acc_layout), + // coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); + } + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + if constexpr (mma_shape_K == 8) { + return acc_layout; + } else { + auto l = logical_divide(acc_layout, Shape{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ auto convert_type(Tensor const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast *>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. +// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE +void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor const &S, + Tensor &D, Tensor const &identity_MN, + Tensor const &predicate_K, const int max_MN=0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_flash_mla.py b/tests/test_flash_mla.py new file mode 100644 index 0000000000000000000000000000000000000000..75665410f16f598538164849907c49573ad28bd3 --- /dev/null +++ b/tests/test_flash_mla.py @@ -0,0 +1,69 @@ +import torch +import random +import torch.nn.functional as F + +import flash_mla + +# TODO: revise to use the same test as the original code + + +def test_flash_mla(): + # b = 128 + # s_q = 4096 + # mean_sk = 8192 + # h_q = 16 + # h_kv = 1 + # d = 576 + # dv = 512 + + b = 16 + s_q = 16 + mean_sk = 16 + h_q = 16 + h_kv = 1 + d = 576 + dv = 512 + + + causal = True + varlen = False + + print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}") + + cache_seqlens = torch.full((b,), mean_sk, dtype=torch.int32) + if varlen: + for i in range(b): + cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), s_q) + total_seqlens = cache_seqlens.sum().item() + mean_seqlens = cache_seqlens.float().mean().int().item() + max_seqlen = cache_seqlens.max().item() + # TODO: avoid triton from original code + # max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + max_seqlen_pad = max_seqlen + 255 & ~255 # round up to multiple of 256 + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view( + b, max_seqlen_pad // block_size + ) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + print(blocked_k.shape) + for i in range(b): + blocked_k.view(b, max_seqlen_pad, h_kv, d)[i, cache_seqlens[i].item() :] = float( + "nan" + ) + blocked_v = blocked_k[..., :dv] + print(blocked_k.shape, blocked_v.shape) + + cache_seqlens = cache_seqlens.to("cuda") + + tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata( + seqlens_k=cache_seqlens, + # + s_q=s_q * h_q // h_kv, + h_kv=h_kv, + ) + print(tile_scheduler_metadata, num_splits) + + # TODO: update to expect the correct output + assert False diff --git a/torch-ext/flash_mla/__init__.py b/torch-ext/flash_mla/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2a129688bf920047cca1ca70a9c3521a9b4182 --- /dev/null +++ b/torch-ext/flash_mla/__init__.py @@ -0,0 +1,33 @@ +import torch + +from ._ops import ops + + +def get_mla_metadata(seqlens_k: torch.Tensor, s_q: int, h_kv: int): + return ops.get_mla_metadata(seqlens_k, s_q, h_kv) + + +def mha_fwd_kvcache_mla( + q: torch.Tensor, + kcache: torch.Tensor, + vcache_: torch.Tensor, + head_size_v: int, + seqlens_k: torch.Tensor, + block_table: torch.Tensor, + softmax_scale: float, + is_causal_: bool, + tile_scheduler_metadata: torch.Tensor, + num_splits: torch.Tensor, +) -> torch.Tensor: + return ops.mha_fwd_kvcache_mla( + q, + kcache, + vcache_, + head_size_v, + seqlens_k, + block_table, + softmax_scale, + is_causal_, + tile_scheduler_metadata, + num_splits + ) diff --git a/torch-ext/torch_binding.cpp b/torch-ext/torch_binding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..220c91b4e896292a8e6e054623f6b2a2510aead3 --- /dev/null +++ b/torch-ext/torch_binding.cpp @@ -0,0 +1,15 @@ +#include + +#include "registration.h" +#include "torch_binding.h" + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { + ops.def("get_mla_metadata(Tensor! seqlens_k, int num_heads_per_head_k, int num_heads_k) -> Tensor[]"); + ops.impl("get_mla_metadata", torch::kCUDA, &get_mla_metadata); + + // TOOD: remove last unknown_param when resolved + ops.def("mha_fwd_kvcache_mla(Tensor! q, Tensor! kcache, Tensor? vcache_, int head_size_v, Tensor! seqlens_k, Tensor! block_table, float softmax_scale, bool is_causal_, Tensor! tile_scheduler_metadata, Tensor! num_splits) -> Tensor[]"); + ops.impl("mha_fwd_kvcache_mla", torch::kCUDA, &mha_fwd_kvcache_mla); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/torch-ext/torch_binding.h b/torch-ext/torch_binding.h new file mode 100644 index 0000000000000000000000000000000000000000..bb9ff36af1b0187e0a1f5778cfa0639b1bfbc5eb --- /dev/null +++ b/torch-ext/torch_binding.h @@ -0,0 +1,24 @@ +#pragma once + +#include + +std::vector +get_mla_metadata( + torch::Tensor &seqlens_k, + const int64_t num_heads_per_head_k, + const int64_t num_heads_k +); + +std::vector +mha_fwd_kvcache_mla( + torch::Tensor &q, + const torch::Tensor &kcache, + const c10::optional &vcache_, + const int64_t head_size_v, + const torch::Tensor &seqlens_k, + const torch::Tensor &block_table, + const double softmax_scale, + bool is_causal, + const torch::Tensor &tile_scheduler_metadata, + const torch::Tensor &num_splits +); \ No newline at end of file