flash-attention: Triton kernel

A pure-Triton implementation of Flash Attention 1 (Dao et al., 2022) packaged for the Hugging Face Kernel Hub.

Unlike hand-written CUDA implementations, this kernel is written entirely in Python/Triton and is JIT-compiled at runtime, making it easy to read, modify, and experiment with.

Algorithm

Flash Attention avoids materialising the full N × N attention matrix in HBM by fusing the softmax and the value-weighted sum into a single tiled pass using the online softmax trick (Milakov & Gimelshein, 2018):

O_i ← softmax(Q_i · Kᵀ) · V        (tiled over K/V, never storing the full score matrix S)

Memory usage drops from O(N²) to O(N · d), removing what is otherwise the primary memory bottleneck for long-context inference and training.
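
For intuition, here is a minimal, unoptimised PyTorch sketch of the same tiling idea on a single head (shapes [N, d]). It keeps only running row-wise max and normaliser statistics instead of the full score matrix; the block size and names are illustrative rather than the kernel's actual parameters, and the 1/√d scaling is omitted to match the formula above.

import torch

def tiled_attention(q, k, v, block_n=128):
    """Reference sketch of online-softmax attention over K/V tiles (q, k, v: [N, d])."""
    q, k, v = q.float(), k.float(), v.float()                  # fp32 for a simple reference
    N, d = q.shape
    out = torch.zeros_like(q)
    m = torch.full((N, 1), float("-inf"), device=q.device)     # running row-wise max
    l = torch.zeros((N, 1), device=q.device)                   # running softmax normaliser
    for start in range(0, N, block_n):
        kb, vb = k[start:start + block_n], v[start:start + block_n]
        s = q @ kb.T                                            # scores for this K/V block only
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)                                # probabilities in the new scale
        alpha = torch.exp(m - m_new)                            # rescale contributions of earlier blocks
        l = alpha * l + p.sum(dim=-1, keepdim=True)
        out = alpha * out + p @ vb
        m = m_new
    return out / l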

Usage

Via the kernels package

import torch
from kernels import get_kernel

fa = get_kernel("kernels-community/flash-attention", version=1)

B, H, N, d = 2, 8, 1024, 64
q = torch.randn(B, H, N, d, device="cuda", dtype=torch.float16)
k = torch.randn(B, H, N, d, device="cuda", dtype=torch.float16)
v = torch.randn(B, H, N, d, device="cuda", dtype=torch.float16)

out = fa.flash_attention_forward(q, k, v, causal=False)
print(out.shape)  # [2, 8, 1024, 64]
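
To sanity-check the result you can compare it against PyTorch's built-in attention on the same tensors. This assumes the kernel applies the standard 1/√d softmax scaling (as torch's SDPA does), and the fp16 tolerance below is a loose assumption, so treat it as a smoke test rather than a documented guarantee.

import torch.nn.functional as F

ref = F.scaled_dot_product_attention(q, k, v, is_causal=False)
torch.testing.assert_close(out, ref, atol=2e-2, rtol=2e-2)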

Local development

# 1. Clone
git clone https://huggingface.co/kernels-community/flash-attention
cd flash-attention

# 2. Install dependencies
pip install torch triton pytest

# 3. Run tests
pytest tests/ -v

# 4. Run benchmark
python benchmarks/bench_flash_attention.py
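
A new correctness test can follow the same pattern as the existing ones. The sketch below is illustrative: the import path mirrors the repository layout but is an assumption, so adapt it to whatever tests/test_flash_attention.py actually imports.

# illustrative test; adjust the import to the project's actual entry point
import math
import pytest
import torch

from flash_attention_kernel.flash_attention import flash_attention_forward  # assumed path

@pytest.mark.parametrize("seq_len", [128, 1024])
def test_matches_naive_reference(seq_len):
    torch.manual_seed(0)
    q, k, v = (torch.randn(2, 8, seq_len, 64, device="cuda", dtype=torch.float16)
               for _ in range(3))
    out = flash_attention_forward(q, k, v, causal=False)
    # naive fp32 reference with the usual 1/sqrt(d) scaling (assumed to match the kernel)
    scores = q.float() @ k.float().transpose(-2, -1) / math.sqrt(q.shape[-1])
    ref = torch.softmax(scores, dim=-1) @ v.float()
    torch.testing.assert_close(out.float(), ref, atol=2e-2, rtol=2e-2)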

Performance

Sequence length   Flash-Attn (Triton)   PyTorch reference   Speedup
128               0.11 ms               0.19 ms             1.70×
256               0.15 ms               0.24 ms             1.57×
512               0.17 ms               0.43 ms             2.47×
1024              0.29 ms               1.78 ms             6.15×
2048              0.79 ms               7.11 ms             8.98×
4096              2.54 ms               27.01 ms            10.63×
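
The table comes from benchmarks/bench_flash_attention.py; a simplified measurement in the same spirit can be made with triton.testing.do_bench. The batch/head/head-dim values and the naive reference implementation below are assumptions, and absolute timings depend heavily on the GPU.

import math
import torch
import triton.testing
from kernels import get_kernel

fa = get_kernel("kernels-community/flash-attention", version=1)

def naive_attention(q, k, v):
    # plain softmax(QKᵀ/√d)V reference, computed in fp32 for stability
    s = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return torch.softmax(s.float(), dim=-1).to(q.dtype) @ v

for n in (128, 512, 1024, 4096):
    q, k, v = (torch.randn(2, 8, n, 64, device="cuda", dtype=torch.float16)
               for _ in range(3))
    ms_fa = triton.testing.do_bench(lambda: fa.flash_attention_forward(q, k, v, causal=False))
    ms_ref = triton.testing.do_bench(lambda: naive_attention(q, k, v))
    print(f"N={n}: triton {ms_fa:.2f} ms  |  PyTorch ref {ms_ref:.2f} ms  |  {ms_ref / ms_fa:.2f}x")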

Repository structure

flash-attention-1-triton/
├── build.toml                        # kernel-builder configuration
├── flake.nix                         # Nix build environment
├── flash_attention_kernel/
│   └── flash_attention.py            # Triton forward/backward kernels + launcher
├── torch-ext/
│   ├── torch_binding.h               # C++ op declaration
│   ├── torch_binding.cpp             # Torch op registration
│   └── flash_attention/
│       └── __init__.py               # Python-level wrapper (uses _ops alias)
├── tests/
│   └── test_flash_attention.py       # pytest correctness & smoke tests
├── benchmarks/
│   └── bench_flash_attention.py      # triton.testing perf report
└── README.md

References

Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. arXiv:2205.14135.

Milakov, M., & Gimelshein, N. (2018). Online normalizer calculation for softmax. arXiv:1805.02867.