FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
Paper • 2205.14135
A pure-Triton implementation of Flash Attention 1 (Dao et al., 2022) packaged for the Hugging Face Kernel Hub.
Unlike hand-written CUDA implementations, this kernel is written entirely in Python/Triton and is JIT-compiled at runtime, making it easy to read, modify, and experiment with.
Flash Attention avoids materialising the full N × N attention matrix in HBM by fusing the softmax and the value-weighted sum into a single tiled pass using the online softmax trick (Milakov & Gimelshein, 2018):

O_i ← softmax(Q_i · Kᵀ) · V    (tiled over K/V, never storing the full score matrix S)

Attention memory complexity drops from O(N²) to O(N · d); the full score matrix is the primary memory bottleneck for long-context inference and training.
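The tiled pass can be sketched in plain PyTorch for a single head (a CPU sketch; the function name and `block` size are illustrative, not part of the kernel):

```python
import torch

def tiled_attention(q, k, v, block=128):
    # Online-softmax attention: process K/V in tiles, never materialising
    # the full N x N score matrix (illustrative sketch, single head).
    N, d = q.shape
    scale = d ** -0.5
    m = torch.full((N, 1), float("-inf"))   # running row maximum
    l = torch.zeros(N, 1)                   # running softmax denominator
    o = torch.zeros(N, d)                   # running weighted sum of V
    for start in range(0, N, block):
        kj = k[start:start + block]
        vj = v[start:start + block]
        s = (q @ kj.T) * scale              # scores for this tile only
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        alpha = torch.exp(m - m_new)        # rescale old accumulators
        p = torch.exp(s - m_new)
        l = alpha * l + p.sum(dim=-1, keepdim=True)
        o = alpha * o + p @ vj
        m = m_new
    return o / l
```

The rescaling factor `alpha` is what lets each tile be folded into the running accumulators without ever revisiting earlier tiles; the real kernel applies the same recurrence per thread-block in SRAM.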
Usage with the `kernels` package:

```python
import torch
from kernels import get_kernel

fa = get_kernel("kernels-community/flash-attention", version=1)

B, H, N, d = 2, 8, 1024, 64
q = torch.randn(B, H, N, d, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, N, d, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, N, d, device="cuda", dtype=torch.float16)

out = fa.flash_attention_forward(q, k, v, causal=False)
print(out.shape)  # [2, 8, 1024, 64]
```
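The kernel's output can be sanity-checked against a naive PyTorch reference; the helper below is an illustrative sketch, not part of the package:

```python
import torch

def reference_attention(q, k, v, causal=False):
    # Naive O(N^2) attention in [B, H, N, d] layout, used as a
    # correctness reference (illustrative sketch).
    scale = q.shape[-1] ** -0.5
    s = torch.einsum("bhqd,bhkd->bhqk", q, k) * scale
    if causal:
        n = s.shape[-1]
        mask = torch.triu(torch.ones(n, n, dtype=torch.bool, device=s.device),
                          diagonal=1)
        s = s.masked_fill(mask, float("-inf"))
    return torch.softmax(s, dim=-1) @ v
```

In fp16 a loose tolerance such as `torch.allclose(out, reference_attention(q, k, v), atol=2e-2)` is appropriate, since the fused kernel and the naive reference accumulate rounding error differently.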
Local development:

```shell
# 1. Clone
git clone https://huggingface.co/kernels-community/flash-attention
cd flash-attention

# 2. Install dependencies
pip install torch triton pytest

# 3. Run tests
pytest tests/ -v

# 4. Run benchmark
python benchmarks/bench_flash_attention.py
```
| Sequence length | Flash-Attn Triton | PyTorch ref | Speedup |
|---|---|---|---|
| 128 | 0.11 ms | 0.19 ms | 1.70× |
| 256 | 0.15 ms | 0.24 ms | 1.57× |
| 512 | 0.17 ms | 0.43 ms | 2.47× |
| 1024 | 0.29 ms | 1.78 ms | 6.15× |
| 2048 | 0.79 ms | 7.11 ms | 8.98× |
| 4096 | 2.54 ms | 27.01 ms | 10.63× |
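The repo's benchmark uses `triton.testing` for proper CUDA-event timing; for a quick spot-check, a crude wall-clock harness can be sketched as follows (the helper name is illustrative):

```python
import time
import torch

def time_fn(fn, *args, warmup=3, iters=10):
    # Crude wall-clock timer in ms per call (illustrative sketch only;
    # prefer CUDA events / triton.testing for real GPU measurements).
    for _ in range(warmup):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # drain queued kernels before timing
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1e3
```

The warmup iterations matter for Triton kernels in particular, since the first call triggers JIT compilation and would otherwise dominate the measurement.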
```
flash-attention-1-triton/
├── build.toml                   # kernel-builder configuration
├── flake.nix                    # Nix build environment
├── flash_attention_kernel/
│   └── flash_attention.py       # Triton forward/backward kernels + launcher
├── torch-ext/
│   ├── torch_binding.h          # C++ op declaration
│   ├── torch_binding.cpp        # Torch op registration
│   └── flash_attention/
│       └── __init__.py          # Python-level wrapper (uses _ops alias)
├── tests/
│   └── test_flash_attention.py  # pytest correctness & smoke tests
├── benchmarks/
│   └── bench_flash_attention.py # triton.testing perf report
└── README.md
```