ZeRO는 왜 단계적으로 분산하는가
DDP의 16ψ 메모리 병목에서 출발해 ZeRO-1/2/3와 FSDP의 설계 결정까지, per-GPU 메모리를 1/N로 줄이는 원리를 추적한다.
70B 모델을 Data Parallelism으로만 학습하면 per-GPU 메모리는 1120 GB가 필요하다. A100 80GB가 14장 있어야 한다는 뜻이다. 왜 이렇게 많은가? 그리고 ZeRO는 이 숫자를 어떻게 70 GB까지 줄이는가?
16ψ — Adam FP16의 메모리 분해
모든 것은 하나의 등식에서 시작한다.
여기서 는 파라미터 수(바이트 단위). 각 항은 제거할 수 없다. FP16 파라미터는 forward의 계산 효율을 위해, FP32 master weight는 optimizer step의 수치 안정성을 위해, 과 는 Adam의 exponential moving average를 위해 필수다.
DDP는 모든 GPU가 이 16ψ를 그대로 복제한다. 70B 모델에서 ψ = 70B이면 GB. N=16으로 나눠도 rank당 70 GB — A100 80GB의 한계선이다. activation memory를 고려하면 이미 불가능에 가깝다.
ZeRO의 출발점은 명확하다: 이 16ψ를 단계적으로 N등분한다.
3단계 분산 — ZeRO-1, 2, 3의 수식
ZeRO는 세 component를 순서대로 sharding한다.
개의 rank, Adam optimizer (), 파라미터일 때:
ZeRO-1: optimizer state ()만 N등분. FP16 params(2ψ)와 grads(2ψ)는 full replica. 각 rank는 (크기 )만 소유 → .
ZeRO-2: gradient도 sharding. backward에서 AllReduce 대신 ReduceScatter를 사용하면 각 rank는 자신의 partition에 해당하는 summed gradient만 유지한다. FP16 grads: → .
ZeRO-3: params마저 N등분. rank 는 만 상시 보유. forward/backward에서 layer별로 AllGather → compute → free. 모든 component가 → .
70B, N=16에서의 구체적 수치:
| Stage | Per-GPU (GB) | 감소율 |
|---|---|---|
| DDP | 1120 | baseline |
| ZeRO-1 | 332.5 | 70% |
| ZeRO-2 | 201.3 | 82% |
| ZeRO-3 | 70 | 93.8% |
ZeRO-3이 되어서야 A100 80GB에 들어간다.
통신 패턴의 변화
메모리를 줄이는 대가로 통신 패턴이 바뀐다.
ZeRO-1: gradient AllReduce() + parameter AllGather() = 4P. DDP의 2배.
ZeRO-2: gradient AllReduce를 ReduceScatter로 교체. ReduceScatter는 AllReduce의 two-phase 중 all-gather phase를 생략한다 — 각 rank가 자신의 partition만 필요하기 때문이다. 결과: .
ZeRO-3: layer별로 AllGather를 L번 수행하지만, 총량은 여전히 . 수량이 늘어날 뿐 총 바이트는 같다 — 대신 latency가 증가한다.
ZeRO-1 → ZeRO-2는 통신량이 오히려 감소(4P → 3P)하면서 메모리도 줄어드는 “free lunch”다. ZeRO-2 → ZeRO-3은 통신 총량은 비슷하지만 round-trip 횟수가 L배 증가해 latency sensitive 환경에서 throughput이 하락한다. NVLink(~900 GB/s) 환경에서는 허용 범위이지만, 1 Gbps Ethernet에서는 치명적이다.
ZeRO-3의 구현 — Layer-by-layer AllGather
ZeRO-3의 핵심 패턴은 단순하다.
# Forward phase
for layer_idx in range(num_layers):
params = allgather_dist(sharded_params[layer_idx]) # 1.75 GB (70B/16/80)
activation = forward_layer(params, input)
activation_cache.append(activation)
free(params) # 즉시 해제
# Backward phase (역순)
for layer_idx in range(num_layers - 1, -1, -1):
params = allgather_dist(sharded_params[layer_idx])
grad_input, grad_params = backward_layer(params, activation_cache[layer_idx], grad_output)
grad_scatter = reducescatter_dist(grad_params) # 자신의 partition만 보유
free(params)
“AllGather → use → free” 패턴이다. 어느 시점에도 GPU에는 1개 layer의 parameter(70B/16/80 ≈ 1.75 GB)만 임시로 존재한다. 나머지 상시 메모리는 sharded grads(8.75 GB)와 optimizer state(52.5 GB)뿐 — 합계 약 63 GB.
FSDP — PyTorch가 ZeRO-3를 내재화한 방식
2023년 Meta의 FSDP(Zhao et al.)는 ZeRO-3와 수학적으로 동등하지만, 두 가지를 추가했다.
첫째, backward prefetch. backward phase에서 layer 를 계산하는 동안 layer 의 parameter를 미리 AllGather한다. 통신 latency를 compute와 overlap시켜 throughput을 5~10% 개선한다.
둘째, PyTorch native 통합. 별도 엔진 없이 torch.distributed와 seamless하게 동작하며, IDE breakpoint가 그대로 작동한다.
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
model_fsdp = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD, # ZeRO-3
mixed_precision=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32, # gradient reduction은 FP32 필수
),
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
auto_wrap_policy=transformer_auto_wrap_policy, # layer별 sharding
)
reduce_dtype=torch.float32는 생략할 수 없다. BF16으로 ReduceScatter를 수행하면 small gradient가 underflow되어 수렴이 불안정해진다.
2024년 기준, LLaMA-2와 LLaMA-3는 FSDP로 학습됐다. ZeRO-3는 DeepSpeed 의존성이 필요한 specialized case에 사용된다.
ZeRO-Offload와 ZeRO-Infinity — 메모리 계층의 확장
GPU 메모리조차 모자랄 때 두 가지 확장이 있다.
ZeRO-Offload: optimizer state()를 CPU RAM으로 이동. GPU 메모리가 으로 감소한다. 70B, N=16에서 70 GB → 17.5 GB. 대가는 PCIe 전송 지연(~100 ms/step, PCIe 4.0 32 GB/s 기준)으로 throughput이 약 30% 하락한다.
ZeRO-Infinity: parameter까지 NVMe(SSD)로 offload. 400B+ 모델을 10장의 GPU로 구동할 수 있다. 하지만 NVMe 대역폭(PCIe 4.0 SSD ≈ 7 GB/s)이 GPU HBM(2000 GB/s)의 1/300이므로 I/O가 critical path를 완전히 지배한다. 70B 모델 1 epoch에 수십 시간이 소요된다. 연구/탐색용이며 production에는 사용하지 않는다.
정리
- DDP의 16ψ는 FP16 param(2ψ) + FP16 grad(2ψ) + Adam FP32 state(12ψ)의 필연적 구성이다.
- ZeRO는 이 세 component를 순서대로 N등분해 per-GPU를 으로 줄인다.
- ZeRO-2의 ReduceScatter는 AllReduce 대비 통신량이 절반이면서 메모리도 줄어드는 유일한 “공짜” 개선이다.
- ZeRO-3/FSDP는 layer-by-layer AllGather로 구현되며, NVLink 환경에서만 실용적이다.
- ZeRO-Offload는 메모리-throughput trade-off가 허용 가능하지만, ZeRO-Infinity는 extreme scale 탐색 전용이다.
ZeRO-3가 해결한 것은 메모리 병목이지, 계산 병목이 아니다. 다음 글에서는 ZeRO-3 이후에도 남아 있는 activation memory 병목과, gradient checkpointing이 이를 어떻게 로 줄이는지 추적한다.