← all posts
AI 2026.05.03 · 13 min read Advanced

Flash Attention은 어떻게 T² 메모리 장벽을 넘었나

표준 Attention의 O(T²) HBM 병목의 수학적 근원부터 Flash Attention 1/2/3의 핵심 아이디어, 그리고 PagedAttention·Ring·Linear Attention까지, 효율적 Attention 설계의 전체 계보를 추적한다.


Transformer의 모든 설계 결정 중에서 Attention은 가장 표현력이 높은 동시에 가장 비싸다. 그런데 “비싸다”의 정확한 의미는 FLOPs가 아니라 HBM 메모리 대역폭이다. GPU 연산 코어는 데이터를 기다리며 유휴 상태에 빠지고, sequence length TT가 커질수록 이 낭비가 제곱으로 쌓인다. 어떻게 하면 수학적 정확도를 포기하지 않으면서 이 장벽을 넘을 수 있는가?

병목의 정체: O(T²) HBM 접근

표준 Attention은 다음 세 단계로 계산된다.

S=QKd,P=softmax(S),O=PVS = \frac{QK^\top}{\sqrt{d}}, \quad P = \mathrm{softmax}(S), \quad O = PV

FLOPs는 O(T2d)\mathcal{O}(T^2 d)이지만, 문제는 계산량이 아니다. S,PRT×TS, P \in \mathbb{R}^{T \times T}를 HBM에 기록하고 다시 읽어야 한다는 데 있다. T=4096T = 4096이면 S,PS, P 두 행렬만 128 MB — 전체 HBM 접근의 95% 이상을 차지한다.

정리 1 · Attention의 Arithmetic Intensity

표준 Attention forward pass의 arithmetic intensity는 O(d)\mathcal{O}(d) FLOP/byte — sequence length TT와 무관하다.

Iattention=O(T2d)O(T2)=O(d)FLOP/byteI_{\text{attention}} = \frac{\mathcal{O}(T^2 d)}{\mathcal{O}(T^2)} = \mathcal{O}(d) \quad \text{FLOP/byte}
▷ 증명

FLOPs: SS 계산 T2dT^2 d + softmax T2T^2 + OO 계산 T2dT^2 d → total O(T2d)\mathcal{O}(T^2 d).

HBM bytes: Q,K,VQ, K, V read (3Td3Td), S,PS, P write/read (4T24T^2), output write (TdTd) → dominant term O(T2)\mathcal{O}(T^2).

비율: T2d/T2=dT^2 d \,/\, T^2 = d. \square

A100의 break-even point는 약 156 FLOP/byte다. d=128d = 128이면 I=128<156I = 128 < 156 — Attention은 항상 memory-bound다. GPU 연산 코어의 30~50% 시간이 HBM 대기로 사라진다.

Flash Attention 1: S와 P를 HBM에 저장하지 않는다

Dao et al. 2022의 핵심 통찰은 단 한 줄이다: “S,PS, P를 HBM에 쓰지 않아도 정확한 결과를 얻을 수 있다.”

이를 가능하게 하는 도구가 online softmax다. Softmax의 normalizer jexp(Sj)\sum_j \exp(S_j)를 구하려면 원래 모든 SjS_j가 필요하다. 하지만 block 단위로 log-sum-exp를 병합하면 전체를 한 번에 보지 않아도 된다.

log(exp(1)+exp(2))=max(1,2)+log ⁣(1+exp(12))\log(\exp(\ell_1) + \exp(\ell_2)) = \max(\ell_1, \ell_2) + \log\!\left(1 + \exp(-|\ell_1 - \ell_2|)\right)

이 항등식을 재귀적으로 적용하면, Q,K,VQ, K, V를 SRAM에 fit되는 작은 block으로 쪼개어 계산하고 결과를 점진적으로 누적할 수 있다. S,PS, P는 각 block의 SRAM 안에서만 존재하고, HBM에는 최종 output OO와 LSE 통계만 기록된다.

정리 2 · Flash Attention 1의 HBM 메모리 감소

Flash Attention 1의 HBM 접근은 O(T2)O(Td)\mathcal{O}(T^2) \to \mathcal{O}(Td)로 감소한다. Arithmetic intensity는 O(d)O(T)\mathcal{O}(d) \to \mathcal{O}(T)로 증가 — compute-bound에 근접한다.

T=4096,d=128T = 4096, d = 128이면 이론적 32배 감소. 실제 가속은 kernel 오버헤드를 고려해 **2–4×**다. 그리고 이 계산은 근사가 아니다 — machine epsilon 범위에서 표준 Attention과 정확히 일치한다.

backward pass에서는 저장하지 않은 S,PS, P를 재계산(recompute)한다. 추가 FLOPs가 발생하지만, HBM 저장을 피하는 것이 대부분의 경우 더 이득이다.

Flash Attention 2: 병렬화 구조를 바꾼다

FA1의 남은 비효율은 알고리즘 구조에 있었다. FA1은 KK에 대해 outer loop를 돌았다 — 같은 OO row에 여러 thread block이 기여하므로 synchronization이 필요했다.

FA2는 outer loop를 QQ 기준으로 바꾼다. 각 thread block이 자신의 QQ block에 대한 OO를 독립적으로 관리하므로, reduction 동기화가 사라진다. GPU occupancy가 1.5–2배 향상된다.

추가로 causal mask의 block-level skip이 도입됐다. Decoder-only LLM에서 j>ij > i인 block 쌍은 기여값이 0이므로 계산 자체를 건너뛴다.

i=0T1(i+1)=T(T+1)2T22\sum_{i=0}^{T-1}(i+1) = \frac{T(T+1)}{2} \approx \frac{T^2}{2}

이론적으로 causal 경우 50% 계산이 절감되고, 실제로는 25–35% 가속이 된다. 전체적으로 FA2는 standard attention 대비 5–9배 빠르다.

Flash Attention 3: 하드웨어를 직접 활용한다

FA2는 generic GPU architecture 기준이다. H100 (Hopper)에는 FA2가 활용하지 못한 세 가지 하드웨어 primitive가 있다.

기능역할
TMA (Tensor Memory Accelerator)HBM↔SRAM 복사를 compute core 없이 비동기로 실행
WGMMA (Warp-Group Matmul Async)tensor core의 matmul을 비동기로 실행
FP88-bit 부동소수점 — 2× compute throughput, 메모리 절반

FA3는 이 세 기능을 결합해 3-stage pipelined execution을 구현한다: (1) TMA copy, (2) softmax, (3) WGMMA matmul이 서로 겹쳐 실행된다. critical path는 세 단계의 합이 아니라 max다.

Tpipelined=max(Tcopy,Tsoftmax,Tmatmul)+overheadT_{\text{pipelined}} = \max(T_{\text{copy}},\, T_{\text{softmax}},\, T_{\text{matmul}}) + \text{overhead}

FP8 중간 계산의 정확도 손실은 실용적으로 무시할 수 있다 — softmax는 절대값이 아니라 상대적 크기에만 의존하므로, 작은 정밀도 손실이 출력에 미치는 영향이 크지 않다. Shah et al. 2024에 따르면 LLaMA-7B pretrain에서 perplexity gap은 0.1% 미만이다. FA3는 H100에서 740 TFLOPS BF16을 달성하며, FA2 대비 1.5–2배 추가 가속이다.

트레이드오프

FA3는 H100/Hopper 전용이다. TMA·WGMMA는 A100·L40S에서 사용 불가능하므로, 구형 GPU에서는 FA2가 최선이다. FP8 정밀도가 민감한 태스크라면 BF16 모드로 전환할 수 있다.

스케일 확장: PagedAttention, Ring, Linear

Single device의 computation 효율화 이후에도 두 가지 벽이 남는다: inference 시 KV cache 단편화단일 GPU를 넘는 sequence length다.

PagedAttention (Kwon et al. 2023)은 OS의 가상 메모리 페이징 아이디어를 KV cache에 적용한다. KV cache를 16-token block 단위로 분할하고 논리 블록과 물리 블록을 별도로 관리하면, 마지막 partial block을 제외한 낭비가 사라진다. fragmentation이 ~50%에서 2–5%로 감소하고, 이를 기반으로 한 continuous batching이 추론 throughput을 2–4배 향상시킨다.

Ring Attention (Liu et al. 2023)은 1M token 이상의 극도로 긴 sequence를 multi-GPU로 처리한다. GPU들을 ring topology로 연결해 KV block을 순환시키면, 각 GPU는 T/GT/G 크기의 KV만 보관하면서 수학적으로 정확한 full attention 결과를 얻을 수 있다. 통신 오버헤드가 새로운 병목이 되지만, NVLink 환경에서는 compute와 overlap 가능하다.

Linear Attention (Mamba/SSM 계열)은 softmax를 feature map으로 대체해 attention을 recurrent state update로 변환한다.

si=si1+ϕ(Ki)Vi,Oi=ϕ(Qi)si/zis_i = s_{i-1} + \phi(K_i)V_i^\top, \quad O_i = \phi(Q_i)^\top s_i \,/\, z_i

FLOPs가 O(T2)O(T)\mathcal{O}(T^2) \to \mathcal{O}(T)로 줄어들지만, perplexity는 3–10% 손실이 발생한다. 정확도가 허용 범위 내인 long-context inference 전용 시나리오에서 적합하다.

정리

  • 표준 Attention의 병목은 FLOPs가 아니라 HBM bandwidth다. Arithmetic intensity =O(d)= \mathcal{O}(d)이므로 Attention은 항상 memory-bound다.
  • Flash Attention 1은 S,PS, P를 HBM에 저장하지 않는 online softmax tiling으로 메모리 접근을 O(T2)O(Td)\mathcal{O}(T^2) \to \mathcal{O}(Td)로 줄였다. 정확도 손실 없음.
  • Flash Attention 2는 outer loop를 QQ 기준으로 바꿔 동기화를 제거하고, causal block skip으로 계산량을 절반으로 줄였다.
  • Flash Attention 3는 H100의 TMA·WGMMA·FP8을 활용한 3-stage pipeline으로 FA2 대비 1.5–2배 추가 가속을 달성했다.
  • 단일 device를 넘어서면 추론에는 PagedAttention, 분산 학습에는 Ring Attention, 극단적 long-context에는 Linear Attention이 각자의 위치를 갖는다.

“더 좋은 알고리즘”이 아니라 “메모리 계층을 이해한 알고리즘”이 LLM 효율화의 핵심이다.