Foundation Model Engineering

7.4 Flash Attention 1, 2, 3

3장에서 우리는 자가 어텐션(Self-attention)의 복잡도를 분석하면서 시간과 메모리 복잡도가 시퀀스 길이(LL)의 제곱에 비례하여 증가한다는 점(O(L2)O(L^2))을 확인했습니다. 파운데이션 모델이 128k에서 1M 이상의 거대한 컨텍스트를 처리하도록 확장됨에 따라, 이 이차식(quadratic) 스케일링은 치명적인 병목 현상이 됩니다.

그러나 어텐션을 확장하는 데 있어 주된 병목은 종종 연산량(FLOPs)이 아니라 메모리 대역폭(Memory Bandwidth, I/O) 입니다. 느린 GPU HBM(High Bandwidth Memory)과 빠른 칩 내부 SRAM 사이에서 거대한 어텐션 행렬을 읽고 쓰는 데 걸리는 시간이 실행 시간을 지배합니다.

FlashAttention 은 어텐션을 IO-Aware 하게 만듦으로써 트랜스포머 학습에 혁명을 일으켰습니다. HBM과 SRAM 간의 메모리 읽기/쓰기 횟수를 줄여, 수학적으로 완벽하게 동일한 결과를 내면서도 엄청난 속도 향상을 달성합니다.

이 섹션에서는 최초의 타일링(Tiling) 개념부터 FlashAttention-3의 하드웨어 특화 최적화에 이르기까지 FlashAttention의 진화를 추적합니다.


1. 어텐션의 메모리 벽 (The Memory Wall)

표준 어텐션은 행렬 S=QKT\mathbf{S} = \mathbf{Q}\mathbf{K}^T 를 계산하여 HBM에 쓰고, 이를 다시 읽어와 소프트맥스(softmax)를 계산한 뒤 P=softmax(S)\mathbf{P} = \text{softmax}(\mathbf{S}) 를 HBM에 쓰고, 마지막으로 이를 다시 읽어와 O=PV\mathbf{O} = \mathbf{P}\mathbf{V} 를 계산합니다.

시퀀스 길이가 L=4096L=4096 일 때, 어텐션 행렬 하나만 해도 헤드당 4096×4096×4 bytes64 MB4096 \times 4096 \times 4 \text{ bytes} \approx 64 \text{ MB} 의 메모리를 소비합니다. 이 행렬을 HBM과 SRAM 사이로 계속해서 이동시키는 것은 엄청난 시간 낭비를 초래합니다.


2. FlashAttention-1: 타일링과 재계산 (Tiling and Recomputation)

Dao 등(2022) [1] 에 의해 도입된 FlashAttention-1은 두 가지 주요 기술을 사용하여 메모리 병목 현상을 해결합니다.

2.1 타일링 (Tiling)

FlashAttention은 L×LL \times L 어텐션 행렬을 한 번에 모두 계산하는 대신, Q,K,V\mathbf{Q}, \mathbf{K}, \mathbf{V} 를 GPU의 빠른 SRAM에 들어갈 수 있는 크기의 블록(타일)으로 나눕니다.

  1. Q,K,V\mathbf{Q}, \mathbf{K}, \mathbf{V} 의 블록을 SRAM으로 로드합니다.
  2. 해당 블록에 대한 어텐션을 계산합니다.
  3. HBM에 있는 출력값을 업데이트합니다.

한 번에 전체 행을 보지 않고 블록 간에 소프트맥스를 올바르게 계산하기 위해, FlashAttention은 실행 중인 최댓값과 지수 합을 추적하는 online softmax 기법을 활용합니다.

2.2 재계산 (Recomputation, 역전파 시)

역전파(Backward pass)를 위해 L×LL \times L 어텐션 행렬을 저장하는 것은 메모리 절약 효과를 무력화하므로, FlashAttention은 이를 저장하지 않습니다. 대신 역전파 과정에서 저장된 Q,K,V\mathbf{Q}, \mathbf{K}, \mathbf{V} 블록을 사용하여 SRAM에서 즉석으로 어텐션 행렬을 재계산 합니다. 이는 약간의 추가 연산을 대가로 메모리 사용량을 획기적으로 줄이는 트레이드오프입니다.


3. FlashAttention-2: 더 빠른 속도와 향상된 병렬화

FlashAttention-2 (2023) [2] 는 작업 분할(Work partitioning)과 병렬화를 최적화하여 오리지널 버전을 개선했습니다.

  • 더 나은 작업 분할: 배치(Batch)와 헤드(Head) 차원뿐만 아니라 시퀀스 길이 차원에 대해서도 연산을 병렬화하여 GPU 활용도를 높였습니다.
  • 비행렬 곱 연산(Non-MatMul FLOPs) 감소: GPU에서 행렬 곱에 비해 상대적으로 느린 지수(exponential) 연산 횟수를 줄이도록 online softmax 계산 방식을 리팩토링했습니다.

FlashAttention-2는 1버전에 비해 최대 2배의 속도 향상을 달성했으며, A100 GPU의 이론적 피크 FLOPs의 최대 70%에 도달했습니다.


4. FlashAttention-3: Hopper 및 FP8 최적화

FlashAttention-3 (2024) [3] 는 새로운 하드웨어 기능을 활용하기 위해 NVIDIA Hopper 아키텍처(예: H100)를 타겟으로 설계되었습니다.

  • WGMMA (Warpgroup Matrix Multiply-Accumulate): H100의 새로운 하드웨어 명령어를 활용하여 행렬 곱셈 속도를 대폭 향상시켰습니다.
  • 비동기 실행 (Asynchronous Execution): 행렬 곱셈 연산과 HBM-SRAM 간의 데이터 이동을 중첩(Overlap)시켜 데이터 전송 지연 시간을 숨겼습니다.
  • FP8 지원: 새로운 8비트 부동 소수점 포맷을 네이티브하게 지원하여 훨씬 더 빠른 연산과 더 낮은 메모리 대역폭 사용을 가능하게 했습니다.

4.5 비교 표: FlashAttention의 진화

특성FlashAttention-1FlashAttention-2FlashAttention-3
주요 초점IO-Awareness 및 타일링병렬화 및 작업 분할비동기성 및 Hopper 최적화
타겟 하드웨어Ampere (A100) 및 이전 세대Ampere (A100) 및 이후 세대Hopper (H100) 특화
병렬화배치, 헤드배치, 헤드, 시퀀스 길이배치, 헤드, 시퀀스 길이
I/O 전략동기식 타일링동기식 타일링비동기식 (TMA)
정밀도FP16/BF16FP16/BF16FP8 지원
피크 FLOPs (A100)~30-40%~70%해당 없음 (H100에 최적화됨)

5. PyTorch 구현: Flash Attention 사용하기

현대 PyTorch (2.0+) 에서는 scaled_dot_product_attention 함수를 통해 Flash Attention 을 매우 쉽게 사용할 수 있습니다. 사용자가 직접 CUDA 코드를 작성할 필요가 없으며, 하드웨어가 지원하는 경우 PyTorch가 자동으로 Flash Attention 을 사용합니다.

import torch
import torch.nn.functional as F

# CUDA 사용 가능 여부 확인 및 디바이스 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 차원 정의
batch_size = 4
num_heads = 8
seq_len = 2048
head_dim = 64

# 무작위 Query, Key, Value 텐서 생성
q = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
k = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
v = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)

# PyTorch 2.0+ Scaled Dot Product Attention
# 입력값과 하드웨어에 따라 최적의 구현체(FlashAttention, Memory Efficient, 또는 Math)를 자동으로 선택합니다.
with torch.inference_mode():
    # 표준적인 사용법
    output = F.scaled_dot_product_attention(q, k, v)
    
    print(f"Output shape: {output.shape}")

# 특정 구현체를 강제하거나 비활성화하려면 컨텍스트 매니저를 사용할 수 있습니다:
# with torch.nn.attention.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
#     output = F.scaled_dot_product_attention(q, k, v)

Quizzes

Quiz 1: FlashAttention 이 해결하고자 하는 주된 병목 현상은 무엇입니까? 연산량(FLOPs)입니까, 아니면 메모리 대역폭(I/O)입니까? 주된 병목 현상은 메모리 대역폭(I/O)입니다. 느린 HBM과 빠른 칩 내부 SRAM 사이에서 거대한 L×LL \times L 어텐션 행렬을 읽고 쓰는 데 걸리는 시간이, 행렬 곱셈에 필요한 실제 부동 소수점 연산 시간보다 훨씬 큽니다.

Quiz 2: FlashAttention-1은 어텐션 행렬의 전체 행을 한 번에 SRAM에 로드하지 않고 어떻게 소프트맥스 함수를 올바르게 계산합니까? FlashAttention은 online softmax 라는 기법을 사용합니다. 행을 블록(타일) 단위로 처리하면서 실행 중인 통계치(최댓값 및 지수 합)를 유지합니다. 새로운 블록으로 이동할 때 새로운 최댓값을 기반으로 기존에 누적된 결과를 재조정(rescale)함으로써, 메모리에 전체 행을 올리지 않고도 정확한 소프트맥스 결과를 계산할 수 있습니다.

Quiz 3: 왜 FlashAttention 은 순전파 시에 어텐션 행렬을 저장하지 않고 역전파 시에 이를 재계산합니까? 순전파 시에 L×LL \times L 어텐션 행렬을 저장하면 O(L2)O(L^2) 의 메모리가 필요하므로 FlashAttention 의 주된 목적인 메모리 절약 효과가 사라집니다. 따라서 저장된 Q,K,V\mathbf{Q}, \mathbf{K}, \mathbf{V} 블록을 사용하여 빠른 SRAM에서 즉석으로 어텐션 행렬을 재계산함으로써, 약간의 추가 연산을 수행하는 대신 메모리 대역폭과 저장 비용을 획기적으로 줄일 수 있습니다.

Quiz 4: FP16을 적용하고 시퀀스 길이가 L=128,000L = 128,000일 때, 단일 헤드에 대한 중간 어텐션 행렬 (S=QKT\mathbf{S} = \mathbf{Q}\mathbf{K}^T)의 메모리 점유율을 계산하고, FlashAttention이 이 시나리오에서 Out-Of-Memory(OOM) 에러를 방지하는 원리를 설명하시오. FP16을 사용할 경우, 각 원소는 2바이트를 소비합니다. 중간 행렬의 차원은 L×L=128,000×128,000=1.6384×1010L \times L = 128,000 \times 128,000 = 1.6384 \times 10^{10}개의 원소입니다. 따라서 헤드당 메모리 사용량은 1.6384×1010×2 바이트32.77 GB1.6384 \times 10^{10} \times 2 \text{ 바이트} \approx 32.77 \text{ GB} 입니다. 만약 32개의 헤드를 가진 표준 모델이라면 총 1테라바이트가 넘는 VRAM을 요구하여 즉각 OOM 에러가 발생합니다. FlashAttention은 SRAM 타이링(Tiling) 기법을 활용하여 매우 작은 블록(예: 64×6464 \times 64)만을 SRAM에 올려 계산하므로, 거대한 중간 행렬을 HBM에 절대 기록하지 않아 OOM 에러를 방지합니다.


References

  1. Dao, T., Fu, D., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and memory-efficient exact attention with IO-awareness. arXiv:2205.14135.
  2. Dao, T. (2023). FlashAttention-2: Faster attention with better parallelism and work partitioning. arXiv:2307.08691.
  3. Dao, T., & Haziza, N. (2024). FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision. arXiv:2407.08608.