# Flash Attention v2 for Ampere (SM80)
Flash Attention v2 forward pass implemented as a CUDA kernel targeting NVIDIA Ampere GPUs (SM80). Based on the CUTLASS CuTe DSL reference implementation.
## Supported Hardware
| GPU | Compute Capability |
|---|---|
| A100 | sm_80 |
| A10 / A30 | sm_86 |
| Ada / L40 | sm_89 |
## Installation

```bash
pip install kernels
```
## Usage

```python
import torch
from kernels import get_kernel

fa2 = get_kernel("pranay5255/flash-attn-v2-ampere")

# Tensor layout: [batch, seqlen, num_heads, head_dim]
q = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")

# Standard forward pass
out = fa2.forward(q, k, v)

# With causal mask
out = fa2.forward(q, k, v, is_causal=True)

# With custom softmax scale
out = fa2.forward(q, k, v, softmax_scale=0.125)
```
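The kernel computes standard scaled-dot-product attention, so its output can be sanity-checked against a plain NumPy reference of the same computation. The sketch below is an illustration under the layout and defaults documented here, not the kernel's own code:

```python
import numpy as np

def attention_ref(q, k, v, softmax_scale=None, is_causal=False):
    """NumPy reference for the computation fa2.forward performs.

    q, k, v: arrays of shape [batch, seqlen, num_heads, head_dim].
    """
    head_dim = q.shape[-1]
    if softmax_scale is None:
        softmax_scale = 1.0 / np.sqrt(head_dim)  # the documented default
    # Move heads before seqlen: [batch, num_heads, seqlen, head_dim]
    qt, kt, vt = (x.transpose(0, 2, 1, 3) for x in (q, k, v))
    scores = qt @ kt.transpose(0, 1, 3, 2) * softmax_scale
    if is_causal:
        # Mask out future positions (upper triangle of the score matrix)
        seqlen = scores.shape[-1]
        future = np.triu(np.ones((seqlen, seqlen), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)
    # Numerically stable softmax over the key dimension
    scores = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    out = probs @ vt
    return out.transpose(0, 2, 1, 3)  # back to [batch, seqlen, num_heads, head_dim]
```

Comparing `fa2.forward(q, k, v)` against `attention_ref(q.float().cpu().numpy(), ...)` with a loose tolerance (fp16 accumulation differs from fp64) is a quick correctness check.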
## API

### `fa2.forward(q, k, v, ...)`
| Parameter | Type | Default | Description |
|---|---|---|---|
| `q`, `k`, `v` | `torch.Tensor` | required | Shape `[batch, seqlen, num_heads, head_dim]`. Must be contiguous and on a CUDA device. |
| `softmax_scale` | `float` | `1/sqrt(head_dim)` | Scale applied to the attention scores before softmax. |
| `is_causal` | `bool` | `False` | Apply a causal mask (zero out the upper triangle of the score matrix so queries cannot attend to later keys). |
| `m_block_size` | `int` | `128` | Tile size along the query (Q) dimension. |
| `n_block_size` | `int` | `64` | Tile size along the key/value (K/V) dimension. |
| `num_threads` | `int` | `128` | Threads per CTA (thread block). |
| `out` | `torch.Tensor` | `None` | Optional pre-allocated output tensor. |
**Returns:** `torch.Tensor` of shape `[batch, seqlen_q, num_heads, head_dim]`, same dtype as the input.

**Supported dtypes:** `torch.float16`, `torch.bfloat16`

**Constraint:** `head_dim` must be a multiple of 8 (16-byte alignment).
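The `m_block_size` and `n_block_size` parameters correspond to the Flash Attention v2 tiling: each CTA owns one `m_block_size` slice of queries and streams over K/V in `n_block_size` chunks, maintaining a running row max and softmax denominator so the full score matrix is never materialized. The following NumPy sketch shows that loop structure for a single head; it illustrates the algorithm, not the CUDA implementation:

```python
import numpy as np

def flash_forward_single_head(q, k, v, m_block_size=128, n_block_size=64,
                              softmax_scale=None):
    """Tiled online-softmax attention for one head (NumPy sketch).

    q: [seqlen_q, head_dim]; k, v: [seqlen_k, head_dim].
    """
    seqlen_q, head_dim = q.shape
    if softmax_scale is None:
        softmax_scale = 1.0 / np.sqrt(head_dim)
    out = np.empty((seqlen_q, head_dim))
    for i in range(0, seqlen_q, m_block_size):       # one Q tile per "CTA"
        qi = q[i:i + m_block_size] * softmax_scale
        row_max = np.full(qi.shape[0], -np.inf)      # running max per query row
        row_sum = np.zeros(qi.shape[0])              # running softmax denominator
        acc = np.zeros_like(qi)                      # unnormalized output accumulator
        for j in range(0, k.shape[0], n_block_size): # stream over K/V tiles
            s = qi @ k[j:j + n_block_size].T         # scores for this tile
            new_max = np.maximum(row_max, s.max(axis=1))
            p = np.exp(s - new_max[:, None])
            rescale = np.exp(row_max - new_max)      # correct earlier partial sums
            row_sum = row_sum * rescale + p.sum(axis=1)
            acc = acc * rescale[:, None] + p @ v[j:j + n_block_size]
            row_max = new_max
        out[i:i + m_block_size] = acc / row_sum[:, None]
    return out
```

The rescaling step is what lets the kernel keep only one `m_block_size x n_block_size` score tile in shared memory at a time while still producing an exact softmax.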
## Build Configuration

```toml
[general]
name = "flash_attn_v2_ampere"
backends = ["cuda"]

[kernel.flash_attn_v2]
backend = "cuda"
src = ["kernel_src/flash_attention_v2.cu"]
cuda-capabilities = ["8.0"]
```
## License

BSD 3-Clause (following the NVIDIA CUTLASS reference implementation).