Foundation Model Engineering

7.4 Flash Attention 1, 2, 3

In Chapter 3, we analyzed the complexity of self-attention and noted that its time and memory complexity scale quadratically with the sequence length ($O(L^2)$). As we push foundation models to handle massive contexts (e.g., 128k to 1M+ tokens), this quadratic scaling becomes a severe bottleneck.

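In practice, however, the main bottleneck of standard attention is often not compute (FLOPs) but memory bandwidth (I/O). On an A100, for example, High Bandwidth Memory (HBM) delivers roughly 1.5-2.0 TB/s, while on-chip SRAM is about an order of magnitude faster (around 19 TB/s) but holds only 192 KB per streaming multiprocessor [1]. The time spent moving the large attention matrix between slow HBM and fast SRAM dominates the execution time.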

FlashAttention changed how transformers are trained by making attention IO-aware: it restructures the computation to minimize reads and writes between HBM and SRAM. This yields large wall-clock speedups while producing exactly the same result as standard attention (it is not an approximation).

In this section, we trace the evolution of FlashAttention from its original tiling concept to the hardware-specific optimizations of FlashAttention-3.


1. The Memory Wall in Attention

Standard attention computes the matrix $\mathbf{S} = \mathbf{Q}\mathbf{K}^T$, writes it to HBM, reads it back to compute the softmax, writes $\mathbf{P} = \text{softmax}(\mathbf{S})$ to HBM, and finally reads $\mathbf{P}$ back to compute $\mathbf{O} = \mathbf{P}\mathbf{V}$.
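To make the intermediate tensors concrete, here is a minimal NumPy sketch of standard attention for a single head. It assumes the usual $1/\sqrt{d}$ scaling of $\mathbf{S}$ and no masking, and the function name `naive_attention` is hypothetical. The point is that `S` and `P` are full $L \times L$ arrays; on a GPU these are exactly the tensors that are written to and read back from HBM.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Standard attention for one head; materializes the full L x L matrices S and P."""
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                      # (L, L) attention scores
    S_max = S.max(axis=1, keepdims=True)          # subtract row max for numerical stability
    P = np.exp(S - S_max)
    P = P / P.sum(axis=1, keepdims=True)          # (L, L) softmax probabilities
    O = P @ V                                     # (L, d) output
    return O, S, P

rng = np.random.default_rng(0)
L, d = 256, 64
Q, K, V = (rng.standard_normal((L, d)) for _ in range(3))
O, S, P = naive_attention(Q, K, V)
print(O.shape, S.shape, P.shape)  # (256, 64) (256, 256) (256, 256)
```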

For a sequence length of $L = 4096$ in FP32 (4 bytes per element), the attention matrix alone occupies $4096 \times 4096 \times 4 \text{ bytes} = 2^{26} \text{ bytes} = 64 \text{ MiB}$ per head, per sequence. This is far larger than the on-chip SRAM, so the matrix must make several round trips through HBM, and that data movement, rather than the arithmetic, dominates the runtime.
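The arithmetic above generalizes to any batch size and head count. The small helper below is a hypothetical illustration; the second configuration (batch 8, 32 heads) is an assumed example, not a figure from the text, and it counts only $\mathbf{S}$ for one layer ($\mathbf{P}$ has the same size again).

```python
def attention_matrix_bytes(seq_len, bytes_per_elem=4, batch=1, heads=1):
    """Bytes needed to materialize one L x L attention matrix for every (batch, head) pair."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

MiB = 2 ** 20
GiB = 2 ** 30

# One head, one sequence, FP32 (the case discussed above)
per_head = attention_matrix_bytes(4096, bytes_per_elem=4)
print(per_head / MiB)   # 64.0

# Assumed configuration: batch 8, 32 heads, FP32, S only, one layer
total = attention_matrix_bytes(4096, bytes_per_elem=4, batch=8, heads=32)
print(total / GiB)      # 16.0
```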


2. FlashAttention-1: Tiling and Recomputation

Introduced by Dao et al. (2022) [1], FlashAttention-1 addresses the memory wall by never materializing the full $O(L^2)$ attention matrix in slow HBM. It relies on two core ideas:

2.1 Tiling (Forward Pass)

FlashAttention splits $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$ into blocks, loads them from HBM into the small but fast on-chip SRAM, computes attention for each pair of blocks there, and writes only the output back to HBM. The difficulty is the softmax: its denominator depends on an entire row of the attention matrix, which is never available at once. FlashAttention solves this with online softmax (Milakov and Gimelshein, 2018), which keeps a running maximum and a running sum of exponentials for each row and rescales the partial output whenever a new block changes the maximum. The final result is mathematically exact, yet the full matrix is never materialized.
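The NumPy sketch below illustrates tiling with online softmax. It loosely follows the loop structure of Algorithm 1 in [1]: an outer loop over $\mathbf{K}$/$\mathbf{V}$ blocks, an inner loop over $\mathbf{Q}$ blocks, and an output that is renormalized after every tile. It omits masking, dropout, and all GPU details. The function name and block sizes are illustrative assumptions, and the arrays `O`, `l`, `m` simply stand in for the buffers FlashAttention keeps in HBM; only one `Sij` tile exists at a time.

```python
import numpy as np

def flash_attention_v1_forward(Q, K, V, block_q=64, block_kv=64):
    """Simplified FA-1-style tiled forward pass for one head (no masking or dropout)."""
    L, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    O = np.zeros((L, d))          # output (stands in for the HBM buffer)
    l = np.zeros(L)               # running softmax denominators, one per query row
    m = np.full(L, -np.inf)       # running row maxima, one per query row
    for j in range(0, L, block_kv):                  # outer loop: K/V blocks
        Kj, Vj = K[j:j + block_kv], V[j:j + block_kv]
        for i in range(0, L, block_q):               # inner loop: Q blocks
            Qi = Q[i:i + block_q]
            Sij = Qi @ Kj.T * scale                  # one (block_q, block_kv) tile
            m_tile = Sij.max(axis=1)
            P_tile = np.exp(Sij - m_tile[:, None])
            l_tile = P_tile.sum(axis=1)
            m_old, l_old = m[i:i + block_q], l[i:i + block_q]
            m_new = np.maximum(m_old, m_tile)
            alpha = np.exp(m_old - m_new)            # rescales previous statistics
            beta = np.exp(m_tile - m_new)            # rescales the current tile
            l_new = alpha * l_old + beta * l_tile
            # Undo the old normalization, add this tile's contribution, renormalize
            O[i:i + block_q] = ((l_old * alpha)[:, None] * O[i:i + block_q]
                                + beta[:, None] * (P_tile @ Vj)) / l_new[:, None]
            m[i:i + block_q], l[i:i + block_q] = m_new, l_new
    return O

# Compare against a non-tiled reference that materializes the full matrix
rng = np.random.default_rng(0)
L, d = 256, 64
Q, K, V = (rng.standard_normal((L, d)) for _ in range(3))
S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(axis=1, keepdims=True))
reference = (P / P.sum(axis=1, keepdims=True)) @ V
print(np.allclose(flash_attention_v1_forward(Q, K, V), reference))  # True
```

The first visit to a row starts from $m = -\infty$, so `alpha` is exactly zero and the empty initial state contributes nothing. Block sizes need not divide $L$; the last slice is simply shorter.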

2.2 Recomputation (Backward Pass)

To compute gradients during the backward pass, standard attention reads back the $L \times L$ attention matrix stored during the forward pass. FlashAttention avoids storing it: the forward pass saves only the output and small per-row softmax statistics, and the backward pass recomputes the attention matrix tile by tile in SRAM from the blocks of $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$. Recomputation adds FLOPs, but it drastically reduces HBM reads and writes, so the backward pass still runs faster, and FlashAttention achieves a net speedup of roughly 2-4x over standard attention.
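Recomputation is cheap in memory because a tile of $\mathbf{P}$ can be rebuilt exactly from $\mathbf{Q}$, $\mathbf{K}$ and one number per row. FA-1 saves the running maximum $m$ and sum $\ell$. FA-2 folds them into a single log-sum-exp, $\text{lse} = m + \log \ell$, which is the form this NumPy sketch uses: $\mathbf{P}_{ij} = \exp(\mathbf{S}_{ij} - \text{lse}_i)$. The function names are hypothetical, and for brevity the statistics come from a non-tiled forward pass rather than a tiled kernel.

```python
import numpy as np

def forward_with_lse(Q, K, V):
    """Reference forward pass returning O and the per-row log-sum-exp (O(L) extra memory)."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    m = S.max(axis=1)
    lse = m + np.log(np.exp(S - m[:, None]).sum(axis=1))   # log of each row's softmax denominator
    O = np.exp(S - lse[:, None]) @ V
    return O, lse

def recompute_p_tile(Q, K, lse, rows, cols):
    """Rebuild one tile of P = softmax(S) from Q, K and the saved statistics only."""
    S_tile = Q[rows] @ K[cols].T / np.sqrt(Q.shape[1])
    return np.exp(S_tile - lse[rows, None])

rng = np.random.default_rng(0)
L, d = 256, 64
Q, K, V = (rng.standard_normal((L, d)) for _ in range(3))
O, lse = forward_with_lse(Q, K, V)

# Full P, built here only to check the recomputed tile
S = Q @ K.T / np.sqrt(d)
P_full = np.exp(S - S.max(axis=1, keepdims=True))
P_full /= P_full.sum(axis=1, keepdims=True)

rows, cols = slice(64, 128), slice(128, 192)
P_tile = recompute_p_tile(Q, K, lse, rows, cols)
print(np.allclose(P_tile, P_full[rows, cols]))  # True
print(lse.nbytes, P_full.nbytes)                # 2048 524288
```

The saved statistics grow linearly with $L$, while the matrix they replace grows quadratically.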


3. FlashAttention-2: Better Parallelism and Work Partitioning

FlashAttention-2 (2023) [2] observed that although FA-1 reduced I/O, it used the GPU's compute poorly, reaching only 25-40% of the A100's theoretical peak FLOPs/s. It introduced several algorithmic refinements:

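  • Parallelism over Sequence Length: FA-1 parallelized over batch size and number of heads, launching one thread block per (batch, head) pair. For small batch sizes or long sequences, this left many of the GPU's streaming multiprocessors (SMs) idle (an A100 has 108). FA-2 additionally parallelizes over the sequence length by assigning each block of $\mathbf{Q}$ to its own thread block, which significantly improves utilization.
  • Refactored Online Softmax: FA-2 reduces the number of non-matrix-multiplication FLOPs (rescaling and division) by keeping the output unnormalized until the end of the loop and saving a single log-sum-exp per row for the backward pass. This matters because non-matmul FLOPs are much more expensive: an A100 delivers 312 TFLOPs/s of FP16/BF16 matmul on Tensor Cores but only 19.5 TFLOPs/s of non-matmul FP32.
  • Better Work Partitioning Between Warps: Within a thread block, FA-1 split $\mathbf{K}$ and $\mathbf{V}$ across warps, which required the warps to write partial results to shared memory and synchronize. FA-2 splits $\mathbf{Q}$ across warps instead, so each warp produces its own slice of the output without inter-warp communication.
  • Support for Head Dimensions up to 256: FA-1 supported head dimensions up to 128; FA-2 extends this to 256, which some modern architectures use.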
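A minimal NumPy sketch of the first two FA-2 ideas follows (one head, no masking, no GPU details; all function names and the block size 128 are hypothetical). `fa2_forward_q_block` processes a single query block against all $\mathbf{K}$/$\mathbf{V}$ blocks. It rescales only the old accumulator, divides by the denominator once at the end, and returns the per-row log-sum-exp that FA-2 saves for the backward pass. Because no query block depends on another, each one could run in its own thread block. `num_work_units` counts those independent units under each parallelization scheme.

```python
import math
import numpy as np

def fa2_forward_q_block(Qi, K, V, block_kv=64):
    """FA-2-style work for ONE query block: unnormalized accumulator, one division at the end."""
    scale = 1.0 / np.sqrt(Qi.shape[1])
    acc = np.zeros_like(Qi)                  # unnormalized output accumulator
    l = np.zeros(Qi.shape[0])                # running softmax denominators
    m = np.full(Qi.shape[0], -np.inf)        # running row maxima
    for j in range(0, K.shape[0], block_kv):
        S = Qi @ K[j:j + block_kv].T * scale
        m_new = np.maximum(m, S.max(axis=1))
        P = np.exp(S - m_new[:, None])
        alpha = np.exp(m - m_new)            # only the old accumulator needs rescaling
        l = alpha * l + P.sum(axis=1)
        acc = alpha[:, None] * acc + P @ V[j:j + block_kv]
        m = m_new
    return acc / l[:, None], m + np.log(l)   # single normalization; log-sum-exp for backward

def fa2_forward(Q, K, V, block_q=64):
    # Iterations of this loop are independent: on a GPU they run in parallel
    blocks = [fa2_forward_q_block(Q[i:i + block_q], K, V) for i in range(0, Q.shape[0], block_q)]
    return np.concatenate([o for o, _ in blocks])

def num_work_units(batch, heads, seq_len, block_q=None):
    """Independent thread blocks: batch * heads (FA-1 style), times query blocks (FA-2 style)."""
    q_blocks = 1 if block_q is None else math.ceil(seq_len / block_q)
    return batch * heads * q_blocks

rng = np.random.default_rng(0)
L, d = 256, 64
Q, K, V = (rng.standard_normal((L, d)) for _ in range(3))
S = Q @ K.T / np.sqrt(d)
P = np.exp(S - S.max(axis=1, keepdims=True))
reference = (P / P.sum(axis=1, keepdims=True)) @ V
print(np.allclose(fa2_forward(Q, K, V), reference))  # True

# Batch 1, 16 heads, 16k tokens, assumed block_q = 128
print(num_work_units(1, 16, 16384), num_work_units(1, 16, 16384, block_q=128))  # 16 2048
```

With only 16 units of work, most of an A100's 108 SMs would sit idle. Splitting along the sequence gives 2048 units, enough to keep every SM busy.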

Together, these changes yield roughly a 2x speedup over FA-1, reaching 50-73% of the A100's theoretical peak FLOPs/s.


4. FlashAttention-3: Asynchrony and Low-Precision for Hopper

FlashAttention-3 (2024) [3] targets NVIDIA's Hopper architecture (H100). FA-2 reaches only about 35% of the H100's peak throughput because it does not use Hopper's new hardware features. FA-3 exploits those features:

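  • Asynchronous Execution (Overlapping Compute and I/O): Hopper introduced the TMA (Tensor Memory Accelerator), a hardware unit that copies tiles between HBM and shared memory asynchronously, independent of the Tensor Cores. FA-3 uses warp specialization (producer warps issue TMA loads while consumer warps run matrix multiplies) to overlap loading the next blocks of $\mathbf{K}$ and $\mathbf{V}$ with computation on the current block. It also interleaves the softmax of one block with the matrix multiplications of another, so that data movement and softmax latency are largely hidden.
  • WGMMA (Warpgroup Matrix Multiply-Accumulate): FA-3 uses Hopper's new asynchronous WGMMA instructions, which operate on a warpgroup (four warps) at a time and are needed to reach peak Tensor Core throughput. The older mma.sync instructions reach only about two-thirds of the H100's peak.
  • Native FP8 Support: FP8 halves the bytes per element compared with FP16, and the H100's FP8 Tensor Cores offer roughly twice the peak FP16 throughput. FA-3 manages the scaling factors that low-precision matrix multiplication requires. It uses block quantization (one scale per block) and incoherent processing (multiplying $\mathbf{Q}$ and $\mathbf{K}$ by a random orthogonal matrix to spread out outliers) to keep numerical error low. The paper reports close to 1.2 PFLOPs/s in FP8, versus up to 740 TFLOPs/s (about 75% utilization) in FP16.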
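The NumPy sketch below shows why the choice of scaling granularity matters in FP8. It is only a numerical simulation. It emulates FP8 E4M3 rounding in float64 (3 mantissa bits, subnormals with spacing $2^{-9}$, saturation at 448), whereas real kernels use hardware FP8 types. It then compares one scale for the whole tensor against one scale per 64-row block, on a tensor whose first block contains assumed large outliers. Function names and the block size are hypothetical, and incoherent processing is not modeled.

```python
import numpy as np

E4M3_MAX = 448.0              # largest finite FP8 E4M3 value
E4M3_MIN_NORMAL = 2.0 ** -6   # smallest normal E4M3 value

def round_to_e4m3(x):
    """Simulate rounding float64 values to FP8 E4M3 (saturating, with subnormals)."""
    x = np.clip(x, -E4M3_MAX, E4M3_MAX)
    mant, exp = np.frexp(x)                              # x = mant * 2**exp, 0.5 <= |mant| < 1
    normal = np.ldexp(np.round(mant * 16) / 16, exp)     # keep 4 significant bits
    subnormal = np.round(x * 2.0 ** 9) / 2.0 ** 9        # fixed spacing 2**-9 near zero
    return np.where(np.abs(x) < E4M3_MIN_NORMAL, subnormal, normal)

def fake_quantize(x, block_rows=None):
    """Scale each block so its max |x| maps to E4M3_MAX, round, then dequantize.

    block_rows=None uses a single scale for the whole tensor.
    """
    if block_rows is None:
        block_rows = x.shape[0]
    out = np.empty_like(x)
    for i in range(0, x.shape[0], block_rows):
        blk = x[i:i + block_rows]
        amax = np.abs(blk).max()
        scale = amax / E4M3_MAX if amax > 0 else 1.0
        out[i:i + block_rows] = round_to_e4m3(blk / scale) * scale
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((256, 64))
x[:64] *= 1e5                          # assumed outlier block: first 64 rows are huge
ordinary = slice(64, None)             # the remaining, ordinary rows

def rel_err(approx):
    """Relative error measured on the ordinary rows only."""
    return np.linalg.norm(approx[ordinary] - x[ordinary]) / np.linalg.norm(x[ordinary])

err_tensor = rel_err(fake_quantize(x))
err_block = rel_err(fake_quantize(x, block_rows=64))
print(f"per-tensor scale: {err_tensor:.3f}  per-block scale: {err_block:.3f}")
```

With a single scale, the outliers push the ordinary values into the subnormal range, where most of their precision is lost. Per-block scales keep each block within the normal range of E4M3.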

4.1 Comparison Table: Evolution of FlashAttention

| Feature | FlashAttention-1 | FlashAttention-2 | FlashAttention-3 |
|---|---|---|---|
| Primary focus | IO-awareness & tiling | Parallelism & work partitioning | Asynchrony & Hopper optimization |
| Hardware target | Ampere (A100) and older | Ampere (A100) and newer | Hopper (H100) specific |
| Parallelism | Batch, heads | Batch, heads, sequence length | Batch, heads, sequence length |
| I/O strategy | Synchronous tiling | Synchronous tiling | Asynchronous (TMA, warp specialization) |
| Precision | FP16/BF16 | FP16/BF16 | FP16/BF16, FP8 |
| Fraction of peak FLOPs/s | ~25-40% (A100) | ~50-73% (A100) | up to ~75% (H100, FP16) |

5. PyTorch Implementation: Using Flash Attention

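PyTorch 2.0+ exposes FlashAttention through torch.nn.functional.scaled_dot_product_attention (SDPA), so you don't need to write CUDA code. SDPA dispatches to a FlashAttention kernel when the hardware and inputs qualify (for example, a supported CUDA GPU with fp16 or bf16 tensors) and otherwise falls back to its memory-efficient or math backend.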

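import torch
import torch.nn.functional as F

# Use a GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The FlashAttention backend only accepts fp16/bf16 tensors on CUDA;
# with fp32 inputs, SDPA silently picks a different backend.
dtype = torch.float16 if device.type == "cuda" else torch.float32

# Define dimensions
batch_size = 4
num_heads = 8
seq_len = 2048
head_dim = 64

# Create random Query, Key, Value tensors of shape (batch, heads, seq_len, head_dim)
q = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
k = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)
v = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype)

# PyTorch 2.0+ Scaled Dot Product Attention (SDPA).
# It automatically selects the best eligible implementation (FlashAttention,
# memory-efficient, or math) based on the inputs and hardware.
# Pass is_causal=True for decoder-style causal masking.
with torch.inference_mode():
    output = F.scaled_dot_product_attention(q, k, v)
    print(f"Output shape: {output.shape}")

# To restrict SDPA to the FlashAttention backend (PyTorch 2.3+), use a context manager.
# It raises an error if the inputs or hardware are not supported by that backend.
# from torch.nn.attention import SDPBackend, sdpa_kernel
# with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
#     output = F.scaled_dot_product_attention(q, k, v)
# Older releases use torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False,
# enable_mem_efficient=False), which is deprecated in newer versions.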

Quizzes

Quiz 1: What is the primary bottleneck that FlashAttention attempts to solve? Is it compute (FLOPs) or memory bandwidth (I/O)?
Answer: The primary bottleneck is memory bandwidth (I/O). The time spent reading and writing the large $L \times L$ attention matrix between the slow High Bandwidth Memory (HBM) and the fast on-chip SRAM dominates the execution time, rather than the floating-point operations required for the matrix multiplications.

Quiz 2: How does FlashAttention-1 compute the softmax function correctly without loading the entire row of the attention matrix into SRAM at once?
Answer: FlashAttention uses a technique called online softmax. It processes the row in blocks (tiles) and maintains running statistics (specifically the running maximum and the sum of exponentials). When moving to a new block, it rescales the accumulated results using the new maximum, which lets it compute the exact softmax result without ever materializing the full row in memory.

Quiz 3: Why does FlashAttention recompute the attention matrix during the backward pass instead of storing it during the forward pass?
Answer: Storing the $L \times L$ attention matrix during the forward pass would require $O(L^2)$ memory, defeating one of FlashAttention's main purposes (reducing memory footprint and HBM traffic). Instead, FlashAttention saves only the output and per-row softmax statistics ($O(L)$ extra memory). During the backward pass it recomputes the attention matrix tile by tile in fast SRAM from the blocks of $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$ and those statistics. It thus trades a small amount of extra compute for a large reduction in memory traffic and storage.

Quiz 4: Calculate the memory footprint of the intermediate attention matrix ($\mathbf{S} = \mathbf{Q}\mathbf{K}^T$) for a single head with sequence length $L = 128{,}000$ using FP16. How does FlashAttention avoid an Out-Of-Memory (OOM) error in this scenario?
Answer: In FP16, each element occupies 2 bytes. The intermediate matrix has $L \times L = 128{,}000 \times 128{,}000 = 1.6384 \times 10^{10}$ elements, so its footprint is $1.6384 \times 10^{10} \times 2 \text{ bytes} \approx 32.77 \text{ GB}$ per head. For a model with 32 heads, a single layer would need about $32 \times 32.77 \approx 1{,}049$ GB, over 1 TB of memory, causing an immediate OOM. FlashAttention avoids this with SRAM tiling: it materializes only small blocks (e.g., $64 \times 64$) at a time, so the full intermediate matrix is never written to HBM.


References

  1. Dao, T., Fu, D., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and memory-efficient exact attention with IO-awareness. arXiv:2205.14135.
  2. Dao, T. (2023). FlashAttention-2: Faster attention with better parallelism and work partitioning. arXiv:2307.08691.
  3. Shah, J., Bikshandi, G., Zhang, Y., Thakkar, V., Ramani, P., & Dao, T. (2024). FlashAttention-3: Fast and accurate attention with asynchrony and low-precision. arXiv:2407.08608.