FlashAttention은 어떻게 O(N²) 메모리 벽을 넘었나
Standard attention의 HBM 병목 원인부터 Online Softmax의 결합법칙, FlashAttention의 tiling 전략, v2/v3의 하드웨어 최적화까지, attention 효율화의 설계 계보를 추적한다.
- 01 모델 효율화의 4축 — Memory, Compute, Latency, Throughput
- 02 Pruning은 무엇을 제거하는가 — saliency에서 hardware까지
- 03 LLM Quantization은 왜 scale 결정의 문제인가
- 04 Knowledge Distillation은 왜 단순 압축이 아닌가
- 05 모델 압축의 4축은 어떻게 하나의 철학으로 수렴하는가
- 06 FlashAttention은 어떻게 O(N²) 메모리 벽을 넘었나
- 07 LLM을 어떻게 배포할 것인가 — serving 스택의 네 층
LLM inference latency의 30–50%는 attention에서 나온다. 그런데 attention의 병목은 연산량이 아니다. 문제는 메모리다. GPU가 계산보다 HBM을 읽고 쓰는 데 더 많은 시간을 쓴다면, 더 빠른 칩을 사도 해결되지 않는다. FlashAttention은 왜 이 문제를 알고리즘 수준에서 해결할 수 있었는가?
표준 Attention의 병목: O(N²) HBM 왕복
Standard attention의 연산 파이프라인을 보면 문제가 명확하다.
Step 1: Q, K, V를 HBM에서 SRAM으로 적재
Step 2: S = QK^T 계산 → HBM에 저장 (N×N 행렬!)
Step 3: HBM에서 S 재적재 → P = softmax(S) → HBM에 저장
Step 4: P, V 재적재 → O = PV → HBM에 저장
일 때 arithmetic intensity를 계산하면 다음과 같다.
A100의 compute-bound / memory-bound 경계값은 약 201이다. AI가 117이면 명백히 memory-bound 구간이다. GPU의 Tensor Core는 대부분 idle 상태로 HBM 전송이 끝나기를 기다린다.
더 나쁜 것은 AI가 과 무관하다는 점이다.
을 아무리 키워도 arithmetic intensity는 근방에 수렴한다. 이것이 long context에서 attention이 선형적으로 느려지는 수학적 이유다.
Online Softmax: 결합법칙이 tiling을 가능하게 한다
FlashAttention의 핵심 아이디어는 “S와 P를 HBM에 저장하지 않는 것”이다. 그러려면 전체 행을 보지 않고도 softmax를 올바르게 계산할 수 있어야 한다. 이를 가능하게 한 것이 online softmax (Milakov & Gimelshein, 2018)다.
스트림 에서 현재 상태 를 다음과 같이 갱신한다.
두 블록의 partial result 과 를 다음과 같이 합칠 수 있다.
이 연산은 결합법칙을 만족하므로 블록 처리 순서와 무관하게 동일한 결과를 낸다.
귀납법으로 임을 보인다. Base case: . Inductive step: 를 가정하면, 갱신 후
여기에 를 더하면 . 결합법칙은 와 합산 모두 결합적이므로 성립한다.
이 결합법칙이 FlashAttention tiling의 수학적 토대다. softmax 전체 행을 메모리에 올리지 않아도, 블록마다 을 누적하면 정확히 같은 결과를 얻는다.
FlashAttention: SRAM tiling으로 O(N²) HBM을 우회하다
FlashAttention의 알고리즘은 직관적으로 이렇게 읽힌다.
outer loop — Q 블록 Q_i:
SRAM에 O_i = 0, m_i = -∞, ℓ_i = 0 초기화
inner loop — K, V 블록 K_j, V_j:
K_j, V_j를 HBM에서 SRAM으로 적재
S_ij = Q_i · K_j^T / √D (SRAM 내 계산)
online softmax로 m_i, ℓ_i 갱신
O_i += P_ij @ V_j (SRAM 내 누적)
O_i = O_i / ℓ_i (최종 정규화)
O_i만 HBM에 기록
행렬 S, P는 HBM에 절대 기록되지 않는다. SRAM 안의 블록 크기 만 잠깐 쓰인다.
이로 인한 HBM access는 다음과 같이 개선된다.
SRAM 크기 이 클수록 개선폭이 크다. A100 (, )에서 라면 이론적으로 67배 HBM access 감소다.
Backward는 를 저장하는 대신 만 보관하고, 역전파 시 필요한 블록을 재계산한다.
재계산의 compute overhead는 약 30%지만, long context 학습에서 메모리 절감의 가치가 압도한다.
FlashAttention은 compute를 약간 더 쓰는 대신 HBM access를 극적으로 줄인다. attention이 memory-bound인 이상 이 교환은 거의 항상 이득이다. 단, SRAM이 작은 엣지 GPU나 아주 짧은 시퀀스에서는 tiling overhead가 net loss가 될 수 있다.
v2와 v3: 알고리즘에서 하드웨어로
FlashAttention v1이 algorithmic breakthrough였다면, v2 (Dao, 2023)와 v3 (Shah, 2024)는 하드웨어를 더 깊이 파고든다.
v2의 핵심: outer loop를 K/V 블록이 아니라 Q 블록으로 바꿨다. 각 Q 블록이 자신의 상태를 완전히 소유하므로 thread block 간 write conflict가 사라진다. 부가 효과로 causal mask에서 인 K/V 블록 전체를 건너뛸 수 있어 compute를 ~50% 절감한다. 실측 결과 v1 대비 1.7× 추가 speedup.
v3의 핵심: H100의 비동기 하드웨어 기능을 활용한 3-way pipeline이다.
iteration k:
[TMA: K_{k+2} HBM→SMEM 복사] ← 비동기
[WGMMA: K_{k+1} matmul] ← 비동기 Tensor Core
[softmax: K_k 결과 처리] ← compute
세 stage가 서로 다른 블록에 대해 동시 진행되어 GPU utilization이 90%대에 도달한다. FP8 지원까지 더하면 BF16 대비 2× throughput 추가 확보. v3의 production 결과는 v1 대비 2–3× attention throughput이다.
정리
- Standard attention의 AI는 에 수렴한다. 을 키울수록 memory-bound가 심화되고, 더 빠른 GPU도 답이 아니다.
- Online softmax의 결합법칙이 tiling을 가능하게 한다. 이것이 FlashAttention의 수학적 핵심이다.
- FlashAttention은 exact다. approximation 없이 HBM access를 으로 줄인다.
- v2는 work partitioning, v3는 H100 async(TMA, WGMMA, FP8)로 각각 한 단계씩 더 나아갔다.
알고리즘의 정확도를 유지하면서 메모리 접근 패턴을 바꾸는 것 — FlashAttention은 “어떻게 계산하느냐”가 “무엇을 계산하느냐”만큼 중요하다는 것을 보여주는 가장 선명한 사례다.