Transformer 학습에서 Activation Memory는 왜 폭발하는가
Forward pass 활성화 메모리의 수학적 분해부터 Gradient Checkpointing, Selective Recomputation, Sequence Parallelism까지 — 대규모 모델 학습의 메모리 병목을 추적한다.
- 01 분산 학습의 통신은 왜 전부 AllReduce로 귀결되는가
- 02 Data Parallelism의 수학 — AllReduce는 왜 정확한가
- 03 Tensor Parallelism은 왜 AllReduce가 정확히 2번인가
- 04 Pipeline Bubble은 어떻게 줄어드는가
- 05 ZeRO는 왜 단계적으로 분산하는가
- 06 Transformer 학습에서 Activation Memory는 왜 폭발하는가
- 07 분산 학습의 네 가지 축 — 3D Parallelism, MoE, Checkpoint, Elastic
GPT-3(175B)를 batch size 64, sequence length 2048로 학습하면 activation memory만 13TB에 달한다. 모델 파라미터(700GB)의 20배다. 왜 이렇게 되는가? 그리고 이 병목을 어떻게 5% compute overhead만으로 5배 이상 줄일 수 있는가?
Activation Memory의 정체
Backward pass에서 gradient를 계산하려면 forward pass의 모든 중간값이 필요하다. 이 중간값들이 activation이다. Transformer 한 layer의 activation memory는 다음과 같이 분해된다.
첫 번째 항은 Q, K, V projection, FFN intermediate, LayerNorm output 등 선형적으로 쌓이는 항이다. 두 번째 항이 문제다. Attention score matrix는 배치, 헤드 수, sequence length의 제곱에 비례한다. 이면 헤드 수는 64이므로:
이면 단일 layer에서 attention만 약 4GB다. L개 layer를 쌓으면 선형으로 증가한다.
이면 attention memory가 sequential ops memory를 초과한다.
. 이 비율이 1을 초과하는 조건은 이다.
GPT-3 스케일()은 이미 이 임계점 근방이다. 이면 attention score만 40GB(batch=1)로, 단일 GPU에서는 불가능하다.
Gradient Checkpointing — 을 로
Chen(2016)의 핵심 통찰은 단순하다. 모든 activation을 저장할 필요가 없다. Forward를 다시 계산하는 것이 저장보다 저렴하다면, 버리고 재계산하면 된다.
Standard:
Forward → [a₁, a₂, ..., a_L 전부 저장] — 메모리 O(L × M)
Backward → 저장된 activation 사용
Checkpointing:
Forward → [a_√L, a_2√L, ..., a_L 만 저장] — 메모리 O(√L × M)
Backward → stage별로 forward 재계산 후 사용
개 layer를 개 stage로 나누면, backward 시 각 stage마다 local forward를 한 번 더 실행한다. 재계산 비용은 원래 forward의 약 배다.
Forward가 전체 training cycle의 절반이므로, 실질 overhead는 약 **33%**다. 메모리는 최대 90% 감소한다.
PyTorch에서 이 전략은 한 줄로 구현된다.
from torch.utils.checkpoint import checkpoint
for layer in self.layers:
x = checkpoint(layer, x, use_reentrant=False)
Dropout, BatchNorm 같은 확률적 연산은 forward 재계산 시 random state가 달라질 수 있다. use_reentrant=False는 forward의 random state를 저장해 backward 재계산 시 복원하므로, 최신 PyTorch에서 항상 이 옵션을 사용해야 한다.
Selective Recomputation — 5% overhead로 5배 절감
Standard checkpointing의 문제는 모든 activation을 동등하게 취급한다는 점이다. Korthikanti et al.(2022, Megatron-LM)은 activation마다 저장 비용과 재계산 비용이 다르다는 사실을 활용한다.
| Activation | 저장 크기 | 재계산 비용 | 결정 |
|---|---|---|---|
| Attention scores | (大) | (大) | 저장 |
| FFN intermediate | (中) | (中) | 저장 |
| LayerNorm output | (小) | element-wise, (小) | 재계산 |
| Dropout mask | binary, 小 | random state 필요 | 저장 |
핵심 원칙: 저장 비용이 크고 재계산 비용도 큰 것(matmul류)은 저장하고, 저장 비용이 작고 재계산이 저렴한 것(norm류)은 재계산한다.
이 선택이 Pareto frontier에 위치하는 이유는 명확하다. GPU에서 메모리 대역폭(~4.8TB/s)과 연산 처리량(~1000 TFLOPS)의 비율을 고려하면, LayerNorm처럼 element-wise인 연산은 재계산 시간이 무시할 만하다. 반면 attention matmul의 재계산은 significant한 compute를 소비한다. 따라서 LayerNorm만 재계산으로 돌리면 compute overhead는 5% 수준에 머물면서 메모리는 80% 이상 줄어든다.
Sequence Parallelism — 통신 비용 없이 분산
Checkpointing과 Selective Recomputation이 단일 GPU 내 최적화라면, Sequence Parallelism(SP)은 activation을 여러 GPU에 분산한다.
아이디어는 sequence axis를 개 GPU로 shard하는 것이다.
LayerNorm, Dropout은 local sequence만으로 처리 가능하므로 통신이 불필요하다. Attention에서는 K와 V의 full sequence가 필요하므로 AllGather가 필요하고, FFN 이후 gradient 합산에는 ReduceScatter가 필요하다.
핵심 정리는 이 통신량이 Tensor Parallelism의 AllReduce와 동일하다는 것이다.
따라서 TP-P와 SP-P를 조합해도 통신량이 늘지 않는다. 대신 메모리는 sequential ops에서 , attention에서 로 줄어든다.
Seq=65536, P=4:
Without SP: ~1086 GB (단일 GPU, 불가능)
With SP: ~43 GB (각 GPU, A100에서 가능)
트레이드오프
세 기법은 각각 다른 비용을 치른다.
- Standard Checkpointing: 구현 1줄, 33% compute overhead, 90% memory 감소. Recurrent 모델(LSTM)엔 적용 불가.
- Selective Recomputation: 구현 복잡, 5% compute overhead, 80% memory 감소. Sparse attention, MoE에선 비용 추정이 달라진다.
- Sequence Parallelism: 통신 오버헤드 0%, 메모리 감소. Variable-length sequence에선 padding이 필요하고, seq/P가 너무 작으면 비효율적이다.
FlashAttention은 이 그림 밖에서 attention score matrix 자체를 타일 단위로 재계산해 메모리를 로 줄이며 세 기법 모두와 조합 가능하다.
정리
- Transformer activation memory의 병목은 attention score matrix다. 이면 attention이 전체 activation을 지배한다.
- Gradient Checkpointing은 stage 경계만 저장해 메모리를 로 줄이되, 33% compute overhead가 따른다.
- Selective Recomputation은 재계산 비용이 저렴한 activation(LayerNorm 등)만 골라 재계산하며, 5% overhead로 5배 절감의 Pareto frontier를 달성한다.
- Sequence Parallelism은 AllGather + ReduceScatter = AllReduce 동등성 덕분에 통신 비용 추가 없이 activation을 개 GPU에 분산한다.
세 기법의 조합이 100K+ token 학습을 현실로 만든다.