← all posts
AI 2026.05.03 · 11 min read Advanced

Transformer 학습에서 Activation Memory는 왜 폭발하는가

Forward pass 활성화 메모리의 수학적 분해부터 Gradient Checkpointing, Selective Recomputation, Sequence Parallelism까지 — 대규모 모델 학습의 메모리 병목을 추적한다.


GPT-3(175B)를 batch size 64, sequence length 2048로 학습하면 activation memory만 13TB에 달한다. 모델 파라미터(700GB)의 20배다. 왜 이렇게 되는가? 그리고 이 병목을 어떻게 5% compute overhead만으로 5배 이상 줄일 수 있는가?

Activation Memory의 정체

Backward pass에서 gradient를 계산하려면 forward pass의 모든 중간값이 필요하다. 이 중간값들이 activation이다. Transformer 한 layer의 activation memory는 다음과 같이 분해된다.

M=34bshsequential ops+bhdhs2attention scoresM_\ell = \underbrace{34 \cdot b \cdot s \cdot h}_{\text{sequential ops}} + \underbrace{b \cdot \frac{h}{d_h} \cdot s^2}_{\text{attention scores}}

첫 번째 항은 Q, K, V projection, FFN intermediate, LayerNorm output 등 선형적으로 쌓이는 항이다. 두 번째 항이 문제다. Attention score matrix는 배치, 헤드 수, sequence length의 제곱에 비례한다. h=4096,dh=64h=4096, d_h=64이면 헤드 수는 64이므로:

Ascore=b64s2A^{\text{score}} = b \cdot 64 \cdot s^2

s=4096s=4096이면 단일 layer에서 attention만 약 4GB다. L개 layer를 쌓으면 선형으로 증가한다.

정리 1 · Attention이 Sequential을 지배하는 임계점

s>2176s > 2176이면 attention memory가 sequential ops memory를 초과한다.

▷ 증명

b(h/dh)s234bsh=s34dh=s34×64s2176\frac{b \cdot (h/d_h) \cdot s^2}{34 \cdot b \cdot s \cdot h} = \frac{s}{34 \cdot d_h} = \frac{s}{34 \times 64} \approx \frac{s}{2176}. 이 비율이 1을 초과하는 조건은 s>2176s > 2176이다. \square

GPT-3 스케일(s=2048s=2048)은 이미 이 임계점 근방이다. s=100Ks=100\text{K}이면 attention score만 40GB(batch=1)로, 단일 GPU에서는 불가능하다.

Gradient Checkpointing — O(L)O(L)O(L)O(\sqrt{L})

Chen(2016)의 핵심 통찰은 단순하다. 모든 activation을 저장할 필요가 없다. Forward를 다시 계산하는 것이 저장보다 저렴하다면, 버리고 재계산하면 된다.

Standard:
Forward  → [a₁, a₂, ..., a_L 전부 저장]  — 메모리 O(L × M)
Backward → 저장된 activation 사용

Checkpointing:
Forward  → [a_√L, a_2√L, ..., a_L 만 저장]  — 메모리 O(√L × M)
Backward → stage별로 forward 재계산 후 사용

LL개 layer를 L\sqrt{L}개 stage로 나누면, backward 시 각 stage마다 local forward를 한 번 더 실행한다. 재계산 비용은 원래 forward의 약 (11/L)(1 - 1/\sqrt{L})배다.

Total compute(21L)×Ftotal\text{Total compute} \approx \left(2 - \frac{1}{\sqrt{L}}\right) \times F_{\text{total}}

Forward가 전체 training cycle의 절반이므로, 실질 overhead는 약 **33%**다. 메모리는 최대 90% 감소한다.

PyTorch에서 이 전략은 한 줄로 구현된다.

from torch.utils.checkpoint import checkpoint

for layer in self.layers:
    x = checkpoint(layer, x, use_reentrant=False)
use_reentrant=False

Dropout, BatchNorm 같은 확률적 연산은 forward 재계산 시 random state가 달라질 수 있다. use_reentrant=False는 forward의 random state를 저장해 backward 재계산 시 복원하므로, 최신 PyTorch에서 항상 이 옵션을 사용해야 한다.

Selective Recomputation — 5% overhead로 5배 절감

Standard checkpointing의 문제는 모든 activation을 동등하게 취급한다는 점이다. Korthikanti et al.(2022, Megatron-LM)은 activation마다 저장 비용과 재계산 비용이 다르다는 사실을 활용한다.

Activation저장 크기재계산 비용결정
Attention scoresbnhs2b \cdot n_h \cdot s^2 (大)O(bs2h)O(b \cdot s^2 \cdot h) (大)저장
FFN intermediatebs4hb \cdot s \cdot 4h (中)O(bsh)O(b \cdot s \cdot h) (中)저장
LayerNorm outputbshb \cdot s \cdot h (小)element-wise, O(1)O(1) (小)재계산
Dropout maskbinary, 小random state 필요저장

핵심 원칙: 저장 비용이 크고 재계산 비용도 큰 것(matmul류)은 저장하고, 저장 비용이 작고 재계산이 저렴한 것(norm류)은 재계산한다.

이 선택이 Pareto frontier에 위치하는 이유는 명확하다. GPU에서 메모리 대역폭(~4.8TB/s)과 연산 처리량(~1000 TFLOPS)의 비율을 고려하면, LayerNorm처럼 element-wise인 연산은 재계산 시간이 무시할 만하다. 반면 attention matmul의 재계산은 significant한 compute를 소비한다. 따라서 LayerNorm만 재계산으로 돌리면 compute overhead는 5% 수준에 머물면서 메모리는 80% 이상 줄어든다.

Selective:5% compute overhead×5× memory reduction\boxed{\text{Selective:} \quad 5\% \text{ compute overhead} \times 5\text{×} \text{ memory reduction}}

Sequence Parallelism — 통신 비용 없이 분산

Checkpointing과 Selective Recomputation이 단일 GPU 내 최적화라면, Sequence Parallelism(SP)은 activation을 여러 GPU에 분산한다.

아이디어는 sequence axis를 PP개 GPU로 shard하는 것이다.

xRb×s×h    x(i)Rb×s/P×hon GPU i\mathbf{x} \in \mathbb{R}^{b \times s \times h} \;\to\; \mathbf{x}^{(i)} \in \mathbb{R}^{b \times s/P \times h} \quad \text{on GPU } i

LayerNorm, Dropout은 local sequence만으로 처리 가능하므로 통신이 불필요하다. Attention에서는 K와 V의 full sequence가 필요하므로 AllGather가 필요하고, FFN 이후 gradient 합산에는 ReduceScatter가 필요하다.

핵심 정리는 이 통신량이 Tensor Parallelism의 AllReduce와 동일하다는 것이다.

MAllGather+MReduceScatter=2×b×s×h×4=MAllReduceM_{\text{AllGather}} + M_{\text{ReduceScatter}} = 2 \times b \times s \times h \times 4 = M_{\text{AllReduce}}

따라서 TP-P와 SP-P를 조합해도 통신량이 늘지 않는다. 대신 메모리는 sequential ops에서 1/P1/P, attention에서 1/P21/P^2로 줄어든다.

Seq=65536, P=4:
  Without SP: ~1086 GB (단일 GPU, 불가능)
  With SP:    ~43 GB (각 GPU, A100에서 가능)

트레이드오프

트레이드오프

세 기법은 각각 다른 비용을 치른다.

  • Standard Checkpointing: 구현 1줄, 33% compute overhead, 90% memory 감소. Recurrent 모델(LSTM)엔 적용 불가.
  • Selective Recomputation: 구현 복잡, 5% compute overhead, 80% memory 감소. Sparse attention, MoE에선 비용 추정이 달라진다.
  • Sequence Parallelism: 통신 오버헤드 0%, 메모리 1/P1/P 감소. Variable-length sequence에선 padding이 필요하고, seq/P가 너무 작으면 비효율적이다.

FlashAttention은 이 그림 밖에서 attention score matrix 자체를 타일 단위로 재계산해 O(s2)O(s^2) 메모리를 O(s)O(s)로 줄이며 세 기법 모두와 조합 가능하다.

정리

  • Transformer activation memory의 병목은 O(s2)O(s^2) attention score matrix다. s>2176s > 2176이면 attention이 전체 activation을 지배한다.
  • Gradient Checkpointing은 L\sqrt{L} stage 경계만 저장해 메모리를 O(L)O(\sqrt{L})로 줄이되, 33% compute overhead가 따른다.
  • Selective Recomputation은 재계산 비용이 저렴한 activation(LayerNorm 등)만 골라 재계산하며, 5% overhead로 5배 절감의 Pareto frontier를 달성한다.
  • Sequence Parallelism은 AllGather + ReduceScatter = AllReduce 동등성 덕분에 통신 비용 추가 없이 activation을 PP개 GPU에 분산한다.

세 기법의 조합이 100K+ token 학습을 현실로 만든다.

REF
REF