← all posts
AI 2026.05.03 · 11 min read Advanced

FlashAttention은 어떻게 O(N²) 메모리 벽을 넘었나

Standard attention의 HBM 병목 원인부터 Online Softmax의 결합법칙, FlashAttention의 tiling 전략, v2/v3의 하드웨어 최적화까지, attention 효율화의 설계 계보를 추적한다.


LLM inference latency의 30–50%는 attention에서 나온다. 그런데 attention의 병목은 연산량이 아니다. 문제는 메모리다. GPU가 계산보다 HBM을 읽고 쓰는 데 더 많은 시간을 쓴다면, 더 빠른 칩을 사도 해결되지 않는다. FlashAttention은 왜 이 문제를 알고리즘 수준에서 해결할 수 있었는가?

표준 Attention의 병목: O(N²) HBM 왕복

Standard attention의 연산 파이프라인을 보면 문제가 명확하다.

Step 1:  Q, K, V를 HBM에서 SRAM으로 적재
Step 2:  S = QK^T 계산 → HBM에 저장 (N×N 행렬!)
Step 3:  HBM에서 S 재적재 → P = softmax(S) → HBM에 저장
Step 4:  P, V 재적재 → O = PV → HBM에 저장

N=2048,D=128N = 2048, D = 128일 때 arithmetic intensity를 계산하면 다음과 같다.

AI=4N2D4ND2+2N224×20482×12818 MB117\mathrm{AI} = \frac{4N^2D}{4ND \cdot 2 + 2N^2 \cdot 2} \approx \frac{4 \times 2048^2 \times 128}{18 \text{ MB}} \approx 117

A100의 compute-bound / memory-bound 경계값은 약 201이다. AI가 117이면 명백히 memory-bound 구간이다. GPU의 Tensor Core는 대부분 idle 상태로 HBM 전송이 끝나기를 기다린다.

더 나쁜 것은 AI가 NN과 무관하다는 점이다.

AI=O(N2D)O(N2+ND)O(D)for large N\mathrm{AI} = \frac{O(N^2 D)}{O(N^2 + ND)} \approx O(D) \quad \text{for large } N

NN을 아무리 키워도 arithmetic intensity는 DD 근방에 수렴한다. 이것이 long context에서 attention이 선형적으로 느려지는 수학적 이유다.

Online Softmax: 결합법칙이 tiling을 가능하게 한다

FlashAttention의 핵심 아이디어는 “S와 P를 HBM에 저장하지 않는 것”이다. 그러려면 전체 행을 보지 않고도 softmax를 올바르게 계산할 수 있어야 한다. 이를 가능하게 한 것이 online softmax (Milakov & Gimelshein, 2018)다.

스트림 x1,x2,,xNx_1, x_2, \ldots, x_N에서 현재 상태 (m,)(m, \ell)를 다음과 같이 갱신한다.

mi=max(mi1, xi)m_i = \max(m_{i-1},\ x_i) i=i1exp(mi1mi)+exp(ximi)\ell_i = \ell_{i-1} \cdot \exp(m_{i-1} - m_i) + \exp(x_i - m_i)
정리 1 · Online Softmax의 결합법칙

두 블록의 partial result (m1,1)(m_1, \ell_1)(m2,2)(m_2, \ell_2)를 다음과 같이 합칠 수 있다.

m=max(m1,m2),=1exp(m1m)+2exp(m2m)m = \max(m_1, m_2), \quad \ell = \ell_1 \exp(m_1 - m) + \ell_2 \exp(m_2 - m)

이 연산은 결합법칙을 만족하므로 블록 처리 순서와 무관하게 동일한 결과를 낸다.

▷ 증명

귀납법으로 N=i=1Nexp(ximN)\ell_N = \sum_{i=1}^N \exp(x_i - m_N)임을 보인다. Base case: 0=0\ell_0 = 0. Inductive step: i1=j<iexp(xjmi1)\ell_{i-1} = \sum_{j < i} \exp(x_j - m_{i-1})를 가정하면, 갱신 후

i1exp(mi1mi)=j<iexp(xjmi)\ell_{i-1} \exp(m_{i-1} - m_i) = \sum_{j < i} \exp(x_j - m_i)

여기에 exp(ximi)\exp(x_i - m_i)를 더하면 i=jiexp(xjmi)\ell_i = \sum_{j \leq i} \exp(x_j - m_i). 결합법칙은 max\max와 합산 모두 결합적이므로 성립한다. \square

이 결합법칙이 FlashAttention tiling의 수학적 토대다. softmax 전체 행을 메모리에 올리지 않아도, 블록마다 (m,)(m, \ell)을 누적하면 정확히 같은 결과를 얻는다.

FlashAttention: SRAM tiling으로 O(N²) HBM을 우회하다

FlashAttention의 알고리즘은 직관적으로 이렇게 읽힌다.

outer loop — Q 블록 Q_i:
    SRAM에 O_i = 0, m_i = -∞, ℓ_i = 0 초기화
    
    inner loop — K, V 블록 K_j, V_j:
        K_j, V_j를 HBM에서 SRAM으로 적재
        S_ij = Q_i · K_j^T / √D         (SRAM 내 계산)
        online softmax로 m_i, ℓ_i 갱신
        O_i += P_ij @ V_j                (SRAM 내 누적)
    
    O_i = O_i / ℓ_i (최종 정규화)
    O_i만 HBM에 기록

N×NN \times N 행렬 S, P는 HBM에 절대 기록되지 않는다. SRAM 안의 블록 크기 Br×BcB_r \times B_c만 잠깐 쓰인다.

이로 인한 HBM access는 다음과 같이 개선된다.

Θ(N2+ND)    Θ ⁣(N2DM)\Theta(N^2 + ND) \;\longrightarrow\; \Theta\!\left(\frac{N^2 D}{M}\right)

SRAM 크기 MM이 클수록 개선폭이 크다. A100 (M192KBM \approx 192\text{KB}, D=128D = 128)에서 N=8192N = 8192라면 이론적으로 67배 HBM access 감소다.

Backward는 PP를 저장하는 대신 (m,)(m, \ell)만 보관하고, 역전파 시 필요한 블록을 재계산한다.

Mact:O(N2)    O(N)M_{\text{act}} : O(N^2) \;\longrightarrow\; O(N)

재계산의 compute overhead는 약 30%지만, long context 학습에서 메모리 절감의 가치가 압도한다.

트레이드오프

FlashAttention은 compute를 약간 더 쓰는 대신 HBM access를 극적으로 줄인다. attention이 memory-bound인 이상 이 교환은 거의 항상 이득이다. 단, SRAM이 작은 엣지 GPU나 아주 짧은 시퀀스에서는 tiling overhead가 net loss가 될 수 있다.

v2와 v3: 알고리즘에서 하드웨어로

FlashAttention v1이 algorithmic breakthrough였다면, v2 (Dao, 2023)와 v3 (Shah, 2024)는 하드웨어를 더 깊이 파고든다.

v2의 핵심: outer loop를 K/V 블록이 아니라 Q 블록으로 바꿨다. 각 Q 블록이 자신의 (m,,O)(m, \ell, O) 상태를 완전히 소유하므로 thread block 간 write conflict가 사라진다. 부가 효과로 causal mask에서 j>ij > i인 K/V 블록 전체를 건너뛸 수 있어 compute를 ~50% 절감한다. 실측 결과 v1 대비 1.7× 추가 speedup.

v3의 핵심: H100의 비동기 하드웨어 기능을 활용한 3-way pipeline이다.

iteration k:
  [TMA: K_{k+2} HBM→SMEM 복사]     ← 비동기
  [WGMMA: K_{k+1} matmul]           ← 비동기 Tensor Core
  [softmax: K_k 결과 처리]          ← compute

세 stage가 서로 다른 블록에 대해 동시 진행되어 GPU utilization이 90%대에 도달한다. FP8 지원까지 더하면 BF16 대비 2× throughput 추가 확보. v3의 production 결과는 v1 대비 2–3× attention throughput이다.

정리

  • Standard attention의 AI는 O(D)O(D)에 수렴한다. NN을 키울수록 memory-bound가 심화되고, 더 빠른 GPU도 답이 아니다.
  • Online softmax의 결합법칙이 tiling을 가능하게 한다. 이것이 FlashAttention의 수학적 핵심이다.
  • FlashAttention은 exact다. approximation 없이 HBM access를 O(N2)O(N2D/M)O(N^2) \to O(N^2 D/M)으로 줄인다.
  • v2는 work partitioning, v3는 H100 async(TMA, WGMMA, FP8)로 각각 한 단계씩 더 나아갔다.

알고리즘의 정확도를 유지하면서 메모리 접근 패턴을 바꾸는 것 — FlashAttention은 “어떻게 계산하느냐”가 “무엇을 계산하느냐”만큼 중요하다는 것을 보여주는 가장 선명한 사례다.

REF
Milakov & Gimelshein · 2018 · Online normalizer calculation for softmax · arXiv