← all posts
AI 2026.05.03 · 11 min read Advanced

ZeRO는 왜 단계적으로 분산하는가

DDP의 16ψ 메모리 병목에서 출발해 ZeRO-1/2/3와 FSDP의 설계 결정까지, per-GPU 메모리를 1/N로 줄이는 원리를 추적한다.


70B 모델을 Data Parallelism으로만 학습하면 per-GPU 메모리는 1120 GB가 필요하다. A100 80GB가 14장 있어야 한다는 뜻이다. 왜 이렇게 많은가? 그리고 ZeRO는 이 숫자를 어떻게 70 GB까지 줄이는가?

16ψ — Adam FP16의 메모리 분해

모든 것은 하나의 등식에서 시작한다.

MDDP=2ψFP16 param+2ψFP16 grad+4ψFP32 master+4ψm+4ψv=16ψ\boxed{M_{\text{DDP}} = \underbrace{2\psi}_{\text{FP16 param}} + \underbrace{2\psi}_{\text{FP16 grad}} + \underbrace{4\psi}_{\text{FP32 master}} + \underbrace{4\psi}_{m} + \underbrace{4\psi}_{v} = 16\psi}

여기서 ψ\psi는 파라미터 수(바이트 단위). 각 항은 제거할 수 없다. FP16 파라미터는 forward의 계산 효율을 위해, FP32 master weight는 optimizer step의 수치 안정성을 위해, mmvv는 Adam의 exponential moving average를 위해 필수다.

DDP는 모든 GPU가 이 16ψ를 그대로 복제한다. 70B 모델에서 ψ = 70B이면 16×70=112016 \times 70 = 1120 GB. N=16으로 나눠도 rank당 70 GB — A100 80GB의 한계선이다. activation memory를 고려하면 이미 불가능에 가깝다.

ZeRO의 출발점은 명확하다: 이 16ψ를 단계적으로 N등분한다.

3단계 분산 — ZeRO-1, 2, 3의 수식

ZeRO는 세 component를 순서대로 sharding한다.

정리 1 · ZeRO Stage별 Per-rank 메모리

NN개의 rank, Adam optimizer (K=12K=12), ψ\psi 파라미터일 때:

MZeRO-1=4ψ+KψNM_{\text{ZeRO-1}} = 4\psi + \frac{K\psi}{N}MZeRO-2=2ψ+(2+K)ψNM_{\text{ZeRO-2}} = 2\psi + \frac{(2+K)\psi}{N}MZeRO-3=(4+K)ψN=16ψNM_{\text{ZeRO-3}} = \frac{(4+K)\psi}{N} = \frac{16\psi}{N}
▷ 증명

ZeRO-1: optimizer state (K=12K=12)만 N등분. FP16 params(2ψ)와 grads(2ψ)는 full replica. 각 rank는 mi,vi,θi(opt)m_i, v_i, \theta^{(opt)}_i (크기 Kψ/NK\psi/N)만 소유 → 4ψ+Kψ/N4\psi + K\psi/N.

ZeRO-2: gradient도 sharding. backward에서 AllReduce 대신 ReduceScatter를 사용하면 각 rank는 자신의 partition에 해당하는 summed gradient만 유지한다. FP16 grads: 2ψ2ψ/N2\psi \to 2\psi/N2ψ+(2+K)ψ/N2\psi + (2+K)\psi/N.

ZeRO-3: params마저 N등분. rank iiθiRψ/N\theta_i \in \mathbb{R}^{\psi/N}만 상시 보유. forward/backward에서 layer별로 AllGather → compute → free. 모든 component가 1/N1/N(4+K)ψ/N=16ψ/N(4+K)\psi/N = 16\psi/N. \square

70B, N=16에서의 구체적 수치:

StagePer-GPU (GB)감소율
DDP1120baseline
ZeRO-1332.570%
ZeRO-2201.382%
ZeRO-37093.8%

ZeRO-3이 되어서야 A100 80GB에 들어간다.

통신 패턴의 변화

메모리를 줄이는 대가로 통신 패턴이 바뀐다.

ZeRO-1: gradient AllReduce(2P\approx 2P) + parameter AllGather(2P\approx 2P) = 4P. DDP의 2배.

ZeRO-2: gradient AllReduce를 ReduceScatter로 교체. ReduceScatter는 AllReduce의 two-phase 중 all-gather phase를 생략한다 — 각 rank가 자신의 partition만 필요하기 때문이다. 결과: P+2P=3PP + 2P = 3P.

ReduceScatter:(N1)PNvsAllReduce:2(N1)PN\text{ReduceScatter}: \frac{(N-1)P}{N} \quad \text{vs} \quad \text{AllReduce}: \frac{2(N-1)P}{N}

ZeRO-3: layer별로 AllGather를 L번 수행하지만, 총량은 여전히 2P\approx 2P. 수량이 늘어날 뿐 총 바이트는 같다 — 대신 latency가 증가한다.

트레이드오프

ZeRO-1 → ZeRO-2는 통신량이 오히려 감소(4P → 3P)하면서 메모리도 줄어드는 “free lunch”다. ZeRO-2 → ZeRO-3은 통신 총량은 비슷하지만 round-trip 횟수가 L배 증가해 latency sensitive 환경에서 throughput이 하락한다. NVLink(~900 GB/s) 환경에서는 허용 범위이지만, 1 Gbps Ethernet에서는 치명적이다.

ZeRO-3의 구현 — Layer-by-layer AllGather

ZeRO-3의 핵심 패턴은 단순하다.

# Forward phase
for layer_idx in range(num_layers):
    params = allgather_dist(sharded_params[layer_idx])  # 1.75 GB (70B/16/80)
    activation = forward_layer(params, input)
    activation_cache.append(activation)
    free(params)  # 즉시 해제

# Backward phase (역순)
for layer_idx in range(num_layers - 1, -1, -1):
    params = allgather_dist(sharded_params[layer_idx])
    grad_input, grad_params = backward_layer(params, activation_cache[layer_idx], grad_output)
    grad_scatter = reducescatter_dist(grad_params)  # 자신의 partition만 보유
    free(params)

“AllGather → use → free” 패턴이다. 어느 시점에도 GPU에는 1개 layer의 parameter(70B/16/80 ≈ 1.75 GB)만 임시로 존재한다. 나머지 상시 메모리는 sharded grads(8.75 GB)와 optimizer state(52.5 GB)뿐 — 합계 약 63 GB.

FSDP — PyTorch가 ZeRO-3를 내재화한 방식

2023년 Meta의 FSDP(Zhao et al.)는 ZeRO-3와 수학적으로 동등하지만, 두 가지를 추가했다.

첫째, backward prefetch. backward phase에서 layer ii를 계산하는 동안 layer i1i-1의 parameter를 미리 AllGather한다. 통신 latency를 compute와 overlap시켜 throughput을 5~10% 개선한다.

둘째, PyTorch native 통합. 별도 엔진 없이 torch.distributed와 seamless하게 동작하며, IDE breakpoint가 그대로 작동한다.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

model_fsdp = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,   # ZeRO-3
    mixed_precision=MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,   # gradient reduction은 FP32 필수
    ),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    auto_wrap_policy=transformer_auto_wrap_policy,    # layer별 sharding
)

reduce_dtype=torch.float32는 생략할 수 없다. BF16으로 ReduceScatter를 수행하면 small gradient가 underflow되어 수렴이 불안정해진다.

2024년 기준, LLaMA-2와 LLaMA-3는 FSDP로 학습됐다. ZeRO-3는 DeepSpeed 의존성이 필요한 specialized case에 사용된다.

ZeRO-Offload와 ZeRO-Infinity — 메모리 계층의 확장

GPU 메모리조차 모자랄 때 두 가지 확장이 있다.

ZeRO-Offload: optimizer state(Kψ/NK\psi/N)를 CPU RAM으로 이동. GPU 메모리가 (4+K)ψ/N4ψ/N(4+K)\psi/N \to 4\psi/N으로 감소한다. 70B, N=16에서 70 GB → 17.5 GB. 대가는 PCIe 전송 지연(~100 ms/step, PCIe 4.0 32 GB/s 기준)으로 throughput이 약 30% 하락한다.

ZeRO-Infinity: parameter까지 NVMe(SSD)로 offload. 400B+ 모델을 10장의 GPU로 구동할 수 있다. 하지만 NVMe 대역폭(PCIe 4.0 SSD ≈ 7 GB/s)이 GPU HBM(2000 GB/s)의 1/300이므로 I/O가 critical path를 완전히 지배한다. 70B 모델 1 epoch에 수십 시간이 소요된다. 연구/탐색용이며 production에는 사용하지 않는다.

정리

  • DDP의 16ψ는 FP16 param(2ψ) + FP16 grad(2ψ) + Adam FP32 state(12ψ)의 필연적 구성이다.
  • ZeRO는 이 세 component를 순서대로 N등분해 per-GPU를 16ψ16ψ/N16\psi \to 16\psi/N으로 줄인다.
  • ZeRO-2의 ReduceScatter는 AllReduce 대비 통신량이 절반이면서 메모리도 줄어드는 유일한 “공짜” 개선이다.
  • ZeRO-3/FSDP는 layer-by-layer AllGather로 구현되며, NVLink 환경에서만 실용적이다.
  • ZeRO-Offload는 메모리-throughput trade-off가 허용 가능하지만, ZeRO-Infinity는 extreme scale 탐색 전용이다.

ZeRO-3가 해결한 것은 메모리 병목이지, 계산 병목이 아니다. 다음 글에서는 ZeRO-3 이후에도 남아 있는 activation memory 병목과, gradient checkpointing이 이를 어떻게 O(L)O(\sqrt{L})로 줄이는지 추적한다.