7.4 Flash Attention 1, 2, 3
3장에서 우리는 자가 어텐션(Self-attention)의 복잡도를 분석하면서 시간과 메모리 복잡도가 시퀀스 길이()의 제곱에 비례하여 증가한다는 점()을 확인했습니다. 파운데이션 모델이 128k에서 1M 이상의 거대한 컨텍스트를 처리하도록 확장됨에 따라, 이 이차식(quadratic) 스케일링은 치명적인 병목 현상이 됩니다.
그러나 어텐션을 확장하는 데 있어 주된 병목은 종종 연산량(FLOPs)이 아니라 메모리 대역폭(Memory Bandwidth, I/O) 입니다. 느린 GPU HBM(High Bandwidth Memory)과 빠른 칩 내부 SRAM 사이에서 거대한 어텐션 행렬을 읽고 쓰는 데 걸리는 시간이 실행 시간을 지배합니다.
FlashAttention 은 어텐션을 IO-Aware 하게 만듦으로써 트랜스포머 학습에 혁명을 일으켰습니다. HBM과 SRAM 간의 메모리 읽기/쓰기 횟수를 줄여, 수학적으로 완벽하게 동일한 결과를 내면서도 엄청난 속도 향상을 달성합니다.
이 섹션에서는 최초의 타일링(Tiling) 개념부터 FlashAttention-3의 하드웨어 특화 최적화에 이르기까지 FlashAttention의 진화를 추적합니다.
1. 어텐션의 메모리 벽 (The Memory Wall)
표준 어텐션은 행렬 를 계산하여 HBM에 쓰고, 이를 다시 읽어와 소프트맥스(softmax)를 계산한 뒤 를 HBM에 쓰고, 마지막으로 이를 다시 읽어와 를 계산합니다.
시퀀스 길이가 일 때, 어텐션 행렬 하나만 해도 헤드당 의 메모리를 소비합니다. 이 행렬을 HBM과 SRAM 사이로 계속해서 이동시키는 것은 엄청난 시간 낭비를 초래합니다.
2. FlashAttention-1: 타일링과 재계산 (Tiling and Recomputation)
Dao 등(2022) [1] 에 의해 도입된 FlashAttention-1은 두 가지 주요 기술을 사용하여 메모리 병목 현상을 해결합니다.
2.1 타일링 (Tiling)
FlashAttention은 어텐션 행렬을 한 번에 모두 계산하는 대신, 를 GPU의 빠른 SRAM에 들어갈 수 있는 크기의 블록(타일)으로 나눕니다.
- 의 블록을 SRAM으로 로드합니다.
- 해당 블록에 대한 어텐션을 계산합니다.
- HBM에 있는 출력값을 업데이트합니다.
한 번에 전체 행을 보지 않고 블록 간에 소프트맥스를 올바르게 계산하기 위해, FlashAttention은 실행 중인 최댓값과 지수 합을 추적하는 online softmax 기법을 활용합니다.
2.2 재계산 (Recomputation, 역전파 시)
역전파(Backward pass)를 위해 어텐션 행렬을 저장하는 것은 메모리 절약 효과를 무력화하므로, FlashAttention은 이를 저장하지 않습니다. 대신 역전파 과정에서 저장된 블록을 사용하여 SRAM에서 즉석으로 어텐션 행렬을 재계산 합니다. 이는 약간의 추가 연산을 대가로 메모리 사용량을 획기적으로 줄이는 트레이드오프입니다.
3. FlashAttention-2: 더 빠른 속도와 향상된 병렬화
FlashAttention-2 (2023) [2] 는 작업 분할(Work partitioning)과 병렬화를 최적화하여 오리지널 버전을 개선했습니다.
- 더 나은 작업 분할: 배치(Batch)와 헤드(Head) 차원뿐만 아니라 시퀀스 길이 차원에 대해서도 연산을 병렬화하여 GPU 활용도를 높였습니다.
- 비행렬 곱 연산(Non-MatMul FLOPs) 감소: GPU에서 행렬 곱에 비해 상대적으로 느린 지수(exponential) 연산 횟수를 줄이도록 online softmax 계산 방식을 리팩토링했습니다.
FlashAttention-2는 1버전에 비해 최대 2배의 속도 향상을 달성했으며, A100 GPU의 이론적 피크 FLOPs의 최대 70%에 도달했습니다.
4. FlashAttention-3: Hopper 및 FP8 최적화
FlashAttention-3 (2024) [3] 는 새로운 하드웨어 기능을 활용하기 위해 NVIDIA Hopper 아키텍처(예: H100)를 타겟으로 설계되었습니다.
- WGMMA (Warpgroup Matrix Multiply-Accumulate): H100의 새로운 하드웨어 명령어를 활용하여 행렬 곱셈 속도를 대폭 향상시켰습니다.
- 비동기 실행 (Asynchronous Execution): 행렬 곱셈 연산과 HBM-SRAM 간의 데이터 이동을 중첩(Overlap)시켜 데이터 전송 지연 시간을 숨겼습니다.
- FP8 지원: 새로운 8비트 부동 소수점 포맷을 네이티브하게 지원하여 훨씬 더 빠른 연산과 더 낮은 메모리 대역폭 사용을 가능하게 했습니다.
4.5 비교 표: FlashAttention의 진화
| 특성 | FlashAttention-1 | FlashAttention-2 | FlashAttention-3 |
|---|---|---|---|
| 주요 초점 | IO-Awareness 및 타일링 | 병렬화 및 작업 분할 | 비동기성 및 Hopper 최적화 |
| 타겟 하드웨어 | Ampere (A100) 및 이전 세대 | Ampere (A100) 및 이후 세대 | Hopper (H100) 특화 |
| 병렬화 | 배치, 헤드 | 배치, 헤드, 시퀀스 길이 | 배치, 헤드, 시퀀스 길이 |
| I/O 전략 | 동기식 타일링 | 동기식 타일링 | 비동기식 (TMA) |
| 정밀도 | FP16/BF16 | FP16/BF16 | FP8 지원 |
| 피크 FLOPs (A100) | ~30-40% | ~70% | 해당 없음 (H100에 최적화됨) |
5. PyTorch 구현: Flash Attention 사용하기
현대 PyTorch (2.0+) 에서는 scaled_dot_product_attention 함수를 통해 Flash Attention 을 매우 쉽게 사용할 수 있습니다. 사용자가 직접 CUDA 코드를 작성할 필요가 없으며, 하드웨어가 지원하는 경우 PyTorch가 자동으로 Flash Attention 을 사용합니다.
import torch
import torch.nn.functional as F
# CUDA 사용 가능 여부 확인 및 디바이스 설정
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 차원 정의
batch_size = 4
num_heads = 8
seq_len = 2048
head_dim = 64
# 무작위 Query, Key, Value 텐서 생성
q = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
k = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
v = torch.randn(batch_size, num_heads, seq_len, head_dim, device=device)
# PyTorch 2.0+ Scaled Dot Product Attention
# 입력값과 하드웨어에 따라 최적의 구현체(FlashAttention, Memory Efficient, 또는 Math)를 자동으로 선택합니다.
with torch.inference_mode():
# 표준적인 사용법
output = F.scaled_dot_product_attention(q, k, v)
print(f"Output shape: {output.shape}")
# 특정 구현체를 강제하거나 비활성화하려면 컨텍스트 매니저를 사용할 수 있습니다:
# with torch.nn.attention.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
# output = F.scaled_dot_product_attention(q, k, v)
Quizzes
Quiz 1: FlashAttention 이 해결하고자 하는 주된 병목 현상은 무엇입니까? 연산량(FLOPs)입니까, 아니면 메모리 대역폭(I/O)입니까?
주된 병목 현상은 메모리 대역폭(I/O)입니다. 느린 HBM과 빠른 칩 내부 SRAM 사이에서 거대한 어텐션 행렬을 읽고 쓰는 데 걸리는 시간이, 행렬 곱셈에 필요한 실제 부동 소수점 연산 시간보다 훨씬 큽니다.
Quiz 2: FlashAttention-1은 어텐션 행렬의 전체 행을 한 번에 SRAM에 로드하지 않고 어떻게 소프트맥스 함수를 올바르게 계산합니까?
FlashAttention은 online softmax 라는 기법을 사용합니다. 행을 블록(타일) 단위로 처리하면서 실행 중인 통계치(최댓값 및 지수 합)를 유지합니다. 새로운 블록으로 이동할 때 새로운 최댓값을 기반으로 기존에 누적된 결과를 재조정(rescale)함으로써, 메모리에 전체 행을 올리지 않고도 정확한 소프트맥스 결과를 계산할 수 있습니다.
Quiz 3: 왜 FlashAttention 은 순전파 시에 어텐션 행렬을 저장하지 않고 역전파 시에 이를 재계산합니까?
순전파 시에 어텐션 행렬을 저장하면 의 메모리가 필요하므로 FlashAttention 의 주된 목적인 메모리 절약 효과가 사라집니다. 따라서 저장된 블록을 사용하여 빠른 SRAM에서 즉석으로 어텐션 행렬을 재계산함으로써, 약간의 추가 연산을 수행하는 대신 메모리 대역폭과 저장 비용을 획기적으로 줄일 수 있습니다.
Quiz 4: FP16을 적용하고 시퀀스 길이가 일 때, 단일 헤드에 대한 중간 어텐션 행렬 ()의 메모리 점유율을 계산하고, FlashAttention이 이 시나리오에서 Out-Of-Memory(OOM) 에러를 방지하는 원리를 설명하시오.
FP16을 사용할 경우, 각 원소는 2바이트를 소비합니다. 중간 행렬의 차원은 개의 원소입니다. 따라서 헤드당 메모리 사용량은 입니다. 만약 32개의 헤드를 가진 표준 모델이라면 총 1테라바이트가 넘는 VRAM을 요구하여 즉각 OOM 에러가 발생합니다. FlashAttention은 SRAM 타이링(Tiling) 기법을 활용하여 매우 작은 블록(예: )만을 SRAM에 올려 계산하므로, 거대한 중간 행렬을 HBM에 절대 기록하지 않아 OOM 에러를 방지합니다.
References
- Dao, T., Fu, D., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and memory-efficient exact attention with IO-awareness. arXiv:2205.14135.
- Dao, T. (2023). FlashAttention-2: Faster attention with better parallelism and work partitioning. arXiv:2307.08691.
- Dao, T., & Haziza, N. (2024). FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision. arXiv:2407.08608.