FlashAttention: 2-4x Faster Exact Attention on GPUs
Replace PyTorch's scaled_dot_product_attention with FlashAttention kernels to cut transformer training memory by 3x or more and speed training up 2-4x via IO-aware tiling that fuses the softmax into the attention kernel and never materializes the N^2 attention matrix.
IO-Aware Kernel Design Cuts Memory and Boosts Speed
FlashAttention computes exact attention without ever materializing the full N^2 attention matrix or its gradient in HBM, tiling the computation so that working blocks stay in GPU SRAM and HBM reads/writes are minimized. This yields 2-4x end-to-end speedups for transformer training on A100 GPUs (e.g., 2.4x for GPT-2-style models) and 3-5x memory savings, enabling sequences up to 64k tokens on a single A100 versus a 16k baseline. The backward pass fuses the dP computation with dV, avoiding an extra softmax pass. FlashAttention-2 improves parallelism with better work partitioning (50-73% of peak TFLOPS on A100), supports bf16 on Ampere and newer, head dimensions up to 256, causal masks aligned to the bottom-right corner for decoder use, and sliding-window attention via window_size=(left, right).
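A minimal, purely illustrative Python sketch of the tiling idea (hypothetical helper name and block sizes; the real implementation is a fused CUDA kernel, not Python loops): each query tile streams over key/value tiles while maintaining a running row max and softmax denominator, so the softmax is computed online and the N^2 score matrix never exists in full.

    import math
    import torch

    def tiled_attention(q, k, v, block_q=128, block_k=128):
        # Single-head, non-causal attention computed one (block_q x block_k) tile at a time.
        seqlen, d = q.shape
        scale = 1.0 / math.sqrt(d)
        out = torch.empty_like(q)
        for i in range(0, seqlen, block_q):
            qi = q[i:i + block_q] * scale                                  # query tile ("in SRAM")
            m = torch.full((qi.shape[0],), float("-inf"), dtype=q.dtype, device=q.device)  # running row max
            l = torch.zeros(qi.shape[0], dtype=q.dtype, device=q.device)  # running softmax denominator
            acc = torch.zeros_like(qi)                                     # unnormalized output accumulator
            for j in range(0, seqlen, block_k):
                kj, vj = k[j:j + block_k], v[j:j + block_k]                # stream one K/V tile ("from HBM")
                s = qi @ kj.T                                              # scores for this tile only
                m_new = torch.maximum(m, s.max(dim=-1).values)
                p = torch.exp(s - m_new[:, None])                          # tile-local exponentials
                corr = torch.exp(m - m_new)                                # rescale earlier partial results
                l = l * corr + p.sum(dim=-1)
                acc = acc * corr[:, None] + p @ vj
                m = m_new
            out[i:i + block_q] = acc / l[:, None]                          # final per-row normalization
        return out

The result matches softmax(q @ k.T / sqrt(d)) @ v up to floating-point error; the CUDA kernels additionally parallelize over batch, heads, and query blocks.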
Trade-offs: requires Ampere or newer GPUs (A100, RTX 30xx/40xx, H100); the backward pass for head dim > 192 originally required A100/H100 but since v2.5.5 also runs on consumer GPUs when dropout is disabled. The deterministic-backward option trades a little speed and memory for reproducibility.
Installation Matches Hardware for Peak Performance
Install via pip install flash-attn --no-build-isolation (a 3-5 minute compile with ninja on a 64-core machine; CUDA 12+). Requires PyTorch 2.2+ plus packaging, psutil, and ninja. Set MAX_JOBS=4 to cap parallel compile jobs on low-RAM machines. ROCm 6.0+ supports MI200+/RDNA3/4 GPUs via the composable_kernel backend (default; fp16/bf16 forward and backward) or the Triton backend (fp16/bf16/fp32; causal/MQA/GQA/paged/FP8). The Nvidia and ROCm PyTorch containers ship the required dependencies.
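A quick post-install sanity check (assumes a CUDA GPU is visible; not part of the repo):

    import torch
    import flash_attn

    print("flash-attn version:", flash_attn.__version__)
    major, minor = torch.cuda.get_device_capability()
    # FlashAttention-2 requires Ampere or newer, i.e. compute capability >= 8.0.
    assert (major, minor) >= (8, 0), f"compute capability {major}.{minor} is below 8.0 (Ampere)"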
Beta FlashAttention-3 (H100/H800, CUDA 12.3+, FP16/BF16 forward and backward, FP8 forward) ships as a separate install; FlashAttention-4 (CuTeDSL, H100/B200, pip install flash-attn-4[cu13]) targets Hopper/Blackwell. The Hugging Face kernels library offers a drop-in via get_kernel('kernels-community/flash-attn2').
Usage Replaces Standard Attention with KV Cache Support
Core: out = flash_attn_func(q, k, v, softmax_scale=1/math.sqrt(d), causal=True, dropout_p=0.0), or flash_attn_qkvpacked_func(qkv) for packed inputs (faster backward). Supports MQA/GQA (requires nheads_q % nheads_kv == 0), ALiBi (alibi_slopes), softcapping (Gemma/Grok), paged KV cache (block_table), and variable sequence lengths.
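A minimal usage sketch (tensor layout (batch, seqlen, nheads, headdim) in fp16/bf16 on GPU; softmax_scale defaults to 1/sqrt(headdim) when omitted):

    import torch
    from flash_attn import flash_attn_func

    batch, seqlen, nheads, headdim = 2, 1024, 16, 64
    q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
    k = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
    v = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")

    # Exact causal attention; the (seqlen x seqlen) score matrix is never materialized.
    out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)  # (batch, seqlen, nheads, headdim)

    # GQA/MQA: K/V may carry fewer heads, as long as nheads_q % nheads_kv == 0.
    k_gqa = torch.randn(batch, seqlen, 4, headdim, dtype=torch.float16, device="cuda")
    v_gqa = torch.randn(batch, seqlen, 4, headdim, dtype=torch.float16, device="cuda")
    out_gqa = flash_attn_func(q, k_gqa, v_gqa, causal=True)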
Inference: flash_attn_with_kvcache(q, k_cache, v_cache, k=new_k, v=new_v, rotary_cos/sin, cache_seqlens) updates the cache in place, applies RoPE, and handles causal/local masks. The causal mask is aligned to the bottom-right corner: with seqlen_q=2 and seqlen_k=5, the first query attends to the first four keys and the second query to all five. A full MHA integration lives in flash_attn/modules/mha.py. Set dropout_p=0.0 for eval and deterministic=True in the backward pass for reproducibility.
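A sketch of a single decoding step with the KV cache (placeholder sizes; rotary and paged options omitted): new K/V is written into the preallocated cache at each sequence's cache_seqlens offset, then the query attends over the filled prefix.

    import torch
    from flash_attn import flash_attn_with_kvcache

    batch, nheads, headdim, max_seqlen = 2, 16, 64, 4096
    k_cache = torch.zeros(batch, max_seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
    v_cache = torch.zeros_like(k_cache)
    cache_seqlens = torch.full((batch,), 37, dtype=torch.int32, device="cuda")  # tokens already cached

    # One new token per sequence (seqlen_q = 1), the common autoregressive case.
    q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")
    k_new, v_new = torch.randn_like(q), torch.randn_like(q)

    # Appends k_new/v_new into the caches in place at cache_seqlens, then runs causal attention.
    out = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
                                  cache_seqlens=cache_seqlens, causal=True)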
Evolutions Unlock New Workloads
v2.0: 2x-faster rewrite, flash_attn_varlen_* for ragged batches (sketch below). v2.1+: causal-mask realignment, inference optimizations (split KV load for seqlen_q=1). v2.3+: sliding window (Mistral 7B). v2.4+: ALiBi, deterministic backward. v2.5+: PagedAttention. v2.6+: softcapping. v2.7+: torch.compile compatibility. Widely adopted (usage.md lists integrations).
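A sketch of the varlen path for ragged batches (packed tokens plus int32 cumulative offsets, per the flash_attn_varlen_* interface):

    import torch
    from flash_attn import flash_attn_varlen_func

    nheads, headdim = 16, 64
    seqlens = [5, 3, 8]                    # three sequences of different lengths, no padding
    total = sum(seqlens)

    # Packed layout: all tokens concatenated along dim 0 -> (total_tokens, nheads, headdim).
    q = torch.randn(total, nheads, headdim, dtype=torch.float16, device="cuda")
    k, v = torch.randn_like(q), torch.randn_like(q)

    # Cumulative offsets marking where each sequence starts/ends: [0, 5, 8, 16].
    cu_seqlens = torch.tensor([0, 5, 8, 16], dtype=torch.int32, device="cuda")

    out = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(seqlens), max_seqlen_k=max(seqlens),
        causal=True,
    )                                      # (total_tokens, nheads, headdim)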